[WIP] Create and small fix (#67)

* - create op
- skip exec for empty inputs for non_max_suppression
- EmptyHandling idea

Signed-off-by: raver119 <raver119@gmail.com>

* Create op and mapping for it

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-11-21 13:31:20 +03:00 committed by GitHub
parent dc0036f2c6
commit 83cb0d9329
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 389 additions and 47 deletions

View File

@ -32,6 +32,7 @@
#include <array/ResultSet.h> #include <array/ResultSet.h>
#include <helpers/OpArgsHolder.h> #include <helpers/OpArgsHolder.h>
#include <dll.h> #include <dll.h>
#include <ops/declarable/EmptyHandling.h>
//#include <ops/declarable/declarable_ops.h> //#include <ops/declarable/declarable_ops.h>
#include <chrono> #include <chrono>
@ -111,7 +112,7 @@ namespace nd4j {
*/ */
int prepareOutputs(Context& block); int prepareOutputs(Context& block);
//std::vector<int>* calculateOutputShape(std::vector<int>* inputShape, nd4j::graph::Block<T>& block); virtual samediff::EmptyHandling emptyHandling();
public: public:
// for special cases, like BooleanOps // for special cases, like BooleanOps
DeclarableOp(); DeclarableOp();

View File

@ -0,0 +1,32 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#ifndef SAMEDIFF_EMPTYHANDLING_H
#define SAMEDIFF_EMPTYHANDLING_H
namespace samediff {
enum EmptyHandling {
EMPTY_SKIP = 1,
EMPTY_EXCEPTION = 2,
EMPTY_EXECUTE = 3
};
}
#endif //SAMEDIFF_EMPTYHANDLING_H

View File

@ -37,6 +37,9 @@ namespace nd4j {
else else
REQUIRE_TRUE(false, 0, "image.non_max_suppression: Max output size argument cannot be retrieved."); REQUIRE_TRUE(false, 0, "image.non_max_suppression: Max output size argument cannot be retrieved.");
if (boxes->isEmpty() || scales->isEmpty())
return Status::OK();
REQUIRE_TRUE(boxes->rankOf() == 2, 0, "image.non_max_suppression: The rank of boxes array should be 2, but %i is given", boxes->rankOf()); REQUIRE_TRUE(boxes->rankOf() == 2, 0, "image.non_max_suppression: The rank of boxes array should be 2, but %i is given", boxes->rankOf());
REQUIRE_TRUE(boxes->sizeAt(1) == 4, 0, "image.non_max_suppression: The last dimension of boxes array should be 4, but %i is given", boxes->sizeAt(1)); REQUIRE_TRUE(boxes->sizeAt(1) == 4, 0, "image.non_max_suppression: The last dimension of boxes array should be 4, but %i is given", boxes->sizeAt(1));
REQUIRE_TRUE(scales->rankOf() == 1 && scales->lengthOf() == boxes->sizeAt(0), 0, "image.non_max_suppression: The rank of scales array should be 1, but %i is given", boxes->rankOf()); REQUIRE_TRUE(scales->rankOf() == 1 && scales->lengthOf() == boxes->sizeAt(0), 0, "image.non_max_suppression: The rank of scales array should be 1, but %i is given", boxes->rankOf());

View File

@ -0,0 +1,57 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_shapes_of)
#include <ops/declarable/CustomOperations.h>
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(create, 1, 1, false, 0, 1) {
auto init = block.numB() > 0 ? B_ARG(0) : true;
if (init)
OUTPUT_VARIABLE(0)->nullify();
return Status::OK();
}
DECLARE_SHAPE_FN(create) {
auto shapeInput = INPUT_VARIABLE(0);
auto order = (char) INT_ARG(0);
auto dtype = DataTypeUtils::fromInt(INT_ARG(1));
REQUIRE_TRUE(order == 'c' || order == 'f', 0, "create: order must be either c or f");
auto shape = shapeInput->getBufferAsVector<Nd4jLong>();
return SHAPELIST(nd4j::ConstantShapeHelper::getInstance()->createShapeInfo(dtype, order, shape));
}
DECLARE_TYPES(create) {
getOpDescriptor()
->setAllowedInputTypes({ALL_INTS})
->setAllowedOutputTypes(nd4j::DataType::ANY);
}
}
}
#endif

View File

@ -99,6 +99,22 @@ namespace nd4j {
#if NOT_EXCLUDED(OP_evaluate_reduction_shape) #if NOT_EXCLUDED(OP_evaluate_reduction_shape)
DECLARE_CUSTOM_OP(evaluate_reduction_shape, 2, 1, false, 0, 0); DECLARE_CUSTOM_OP(evaluate_reduction_shape, 2, 1, false, 0, 0);
#endif #endif
/**
* This operation creates new array
* Input:
* array with shape values
*
* IArgs:
* order value
* data type value
*
* BArgs:
* initialization option
*/
#if NOT_EXCLUDED(OP_create)
DECLARE_CUSTOM_OP(create, 1, 1, false, 0, 1);
#endif
} }
} }

View File

@ -933,6 +933,10 @@ namespace nd4j {
return ND4J_STATUS_OK; return ND4J_STATUS_OK;
} }
samediff::EmptyHandling DeclarableOp::emptyHandling() {
return samediff::EmptyHandling::EMPTY_SKIP;
}
void DeclarableOp::registerTypes() { void DeclarableOp::registerTypes() {
this->getOpDescriptor()->setSameMode(true); this->getOpDescriptor()->setSameMode(true);
} }

View File

@ -149,36 +149,7 @@ import org.nd4j.linalg.api.ops.impl.scatter.ScatterMin;
import org.nd4j.linalg.api.ops.impl.scatter.ScatterMul; import org.nd4j.linalg.api.ops.impl.scatter.ScatterMul;
import org.nd4j.linalg.api.ops.impl.scatter.ScatterSub; import org.nd4j.linalg.api.ops.impl.scatter.ScatterSub;
import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate; import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate;
import org.nd4j.linalg.api.ops.impl.shape.Concat; import org.nd4j.linalg.api.ops.impl.shape.*;
import org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix;
import org.nd4j.linalg.api.ops.impl.shape.Cross;
import org.nd4j.linalg.api.ops.impl.shape.Diag;
import org.nd4j.linalg.api.ops.impl.shape.DiagPart;
import org.nd4j.linalg.api.ops.impl.shape.ExpandDims;
import org.nd4j.linalg.api.ops.impl.shape.Gather;
import org.nd4j.linalg.api.ops.impl.shape.GatherNd;
import org.nd4j.linalg.api.ops.impl.shape.MergeAvg;
import org.nd4j.linalg.api.ops.impl.shape.MergeMax;
import org.nd4j.linalg.api.ops.impl.shape.MeshGrid;
import org.nd4j.linalg.api.ops.impl.shape.OneHot;
import org.nd4j.linalg.api.ops.impl.shape.OnesLike;
import org.nd4j.linalg.api.ops.impl.shape.ParallelStack;
import org.nd4j.linalg.api.ops.impl.shape.Permute;
import org.nd4j.linalg.api.ops.impl.shape.Rank;
import org.nd4j.linalg.api.ops.impl.shape.ReductionShape;
import org.nd4j.linalg.api.ops.impl.shape.Repeat;
import org.nd4j.linalg.api.ops.impl.shape.Reshape;
import org.nd4j.linalg.api.ops.impl.shape.SequenceMask;
import org.nd4j.linalg.api.ops.impl.shape.Size;
import org.nd4j.linalg.api.ops.impl.shape.SizeAt;
import org.nd4j.linalg.api.ops.impl.shape.Slice;
import org.nd4j.linalg.api.ops.impl.shape.Squeeze;
import org.nd4j.linalg.api.ops.impl.shape.Stack;
import org.nd4j.linalg.api.ops.impl.shape.StridedSlice;
import org.nd4j.linalg.api.ops.impl.shape.Tile;
import org.nd4j.linalg.api.ops.impl.shape.Transpose;
import org.nd4j.linalg.api.ops.impl.shape.Unstack;
import org.nd4j.linalg.api.ops.impl.shape.ZerosLike;
import org.nd4j.linalg.api.ops.impl.shape.bp.SliceBp; import org.nd4j.linalg.api.ops.impl.shape.bp.SliceBp;
import org.nd4j.linalg.api.ops.impl.shape.bp.StridedSliceBp; import org.nd4j.linalg.api.ops.impl.shape.bp.StridedSliceBp;
import org.nd4j.linalg.api.ops.impl.shape.bp.TileBp; import org.nd4j.linalg.api.ops.impl.shape.bp.TileBp;
@ -339,6 +310,15 @@ public class DifferentialFunctionFactory {
return new ZerosLike(name, sameDiff(), input).outputVariable(); return new ZerosLike(name, sameDiff(), input).outputVariable();
} }
public SDVariable create(String name, SDVariable shape, boolean initialize, DataType dataType) {
return create(name, shape, 'c', initialize, dataType);
}
public SDVariable create(String name, SDVariable shape, char order, boolean initialize, DataType dataType) {
validateDifferentialFunctionsameDiff(shape);
return new Create(name, sameDiff(), shape, order, initialize, dataType).outputVariable();
}
public SDVariable onesLike(String name, SDVariable input, DataType dataType) { public SDVariable onesLike(String name, SDVariable input, DataType dataType) {
validateDifferentialFunctionsameDiff(input); validateDifferentialFunctionsameDiff(input);
return new OnesLike(name, sameDiff(), input, dataType).outputVariable(); return new OnesLike(name, sameDiff(), input, dataType).outputVariable();

View File

@ -65,6 +65,7 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.broadcast.BroadcastRSubOp.class, org.nd4j.linalg.api.ops.impl.broadcast.BroadcastRSubOp.class,
org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp.class, org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp.class,
org.nd4j.linalg.api.ops.impl.broadcast.BroadcastTo.class, org.nd4j.linalg.api.ops.impl.broadcast.BroadcastTo.class,
org.nd4j.linalg.api.ops.impl.shape.Create.class,
org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastEqualTo.class, org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastEqualTo.class,
org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThan.class, org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThan.class,
org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThanOrEqual.class, org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThanOrEqual.class,

View File

@ -0,0 +1,138 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.shape;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
/**
* This operation creates a new, optionally nullified, array with a given shape, order and data type
*
* @author raver119@gmail.com
*/
@Slf4j
public class Create extends DynamicCustomOp {
protected boolean initialize = false;
protected char order = 'c';
protected DataType outputType = DataType.FLOAT; //Allow customizing dtype for TF import
public Create() {
}
public Create(String name, SameDiff sameDiff, SDVariable input, boolean initialize) {
this(name, sameDiff, input, 'c', initialize, input.dataType());
}
public Create(String name, SameDiff sameDiff, SDVariable input, char order, boolean initialize, DataType dataType) {
super(name, sameDiff, new SDVariable[]{input}, false);
this.outputType = dataType;
this.initialize = initialize;
this.order = order;
addArgs();
}
public Create(INDArray shape, DataType dataType) {
this(shape, 'c', false, dataType);
}
public Create(INDArray shape, boolean initialize, DataType dataType) {
this(shape, 'c', initialize, dataType);
}
public Create(@NonNull INDArray shape, char order, boolean initialize, DataType dataType) {
super(new INDArray[]{shape}, new INDArray[0]);
this.order = order;
this.initialize = initialize;
this.outputType = dataType;
addArgs();
}
protected void addArgs() {
addBArgument(initialize);
addIArgument((int) order,outputType.toInt());
}
@Override
public String opName() {
return "create";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No op found for " + opName());
}
@Override
public String tensorflowName() {
return "Empty";
}
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
// convert output data type
if(attributesForNode.containsKey("dtype")) {
outputType = TFGraphMapper.convertType(attributesForNode.get("dtype").getType());
}
// get init field
if(attributesForNode.containsKey("init")) {
initialize = attributesForNode.get("init").getB();
}
// there's no order in TF, just plain C
this.order = 'c';
addArgs();
}
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
SDVariable ret = sameDiff.zerosLike(outputVariables()[0]);
return Arrays.asList(ret);
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
Preconditions.checkState(dataTypes.size() == 1, "Expected list with exactly 1 datatype for %s, got %s", getClass(), dataTypes);
if(outputType != null){
return Collections.singletonList(outputType);
} else {
//Output type is same as input type
return dataTypes;
}
}
}

View File

@ -9632,6 +9632,7 @@ public static final int PREALLOC_SIZE = 33554432;
// #include <array/ResultSet.h> // #include <array/ResultSet.h>
// #include <helpers/OpArgsHolder.h> // #include <helpers/OpArgsHolder.h>
// #include <dll.h> // #include <dll.h>
// #include <ops/declarable/EmptyHandling.h>
//#include <ops/declarable/declarable_ops.h> //#include <ops/declarable/declarable_ops.h>
// #include <chrono> // #include <chrono>

View File

@ -11834,6 +11834,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
// #include <array/ResultSet.h> // #include <array/ResultSet.h>
// #include <helpers/OpArgsHolder.h> // #include <helpers/OpArgsHolder.h>
// #include <dll.h> // #include <dll.h>
// #include <ops/declarable/EmptyHandling.h>
//#include <ops/declarable/declarable_ops.h> //#include <ops/declarable/declarable_ops.h>
// #include <chrono> // #include <chrono>
@ -17126,6 +17127,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -20576,7 +20578,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
} }
// #endif // #endif
/** /**
* This op make bilinear or nearest neighbor interpolated resize for given tensor * This op make bilinear or nearest neighbor interpolated resize for given tensor
* *
* input array: * input array:
@ -20612,7 +20614,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
} }
// #endif // #endif
/** /**
* This op make bilinear interpolated resize for given tensor * This op make bilinear interpolated resize for given tensor
* *
* input array: * input array:
@ -20647,7 +20649,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
} }
// #endif // #endif
/** /**
* This op make nearest neighbor interpolated resize for given tensor * This op make nearest neighbor interpolated resize for given tensor
* *
* input array: * input array:
@ -20659,7 +20661,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
* 1 - new height * 1 - new height
* *
* output array: * output array:
* the 4D-Tensor with calculated backproped dots * the 4D-Tensor with resized image (shape is {batch, newWidth, newHeight, channels})
* *
* CAUTION: either size tensor or a pair of int params should be provided. * CAUTION: either size tensor or a pair of int params should be provided.
*/ */
@ -20682,21 +20684,85 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
} }
// #endif // #endif
/** /**
* This op calculates backprop dot for two tensors along given dimensions * This op make bicubic interpolated resize for given tensor
* *
* input array: * input array:
* x: tensor to calculate dot for * 0 - 4D-Tensor with shape (batch, sizeX, sizeY, channels)
* y: tensor to calculate dot for * 1 - 1D-Tensor with 2 values (newWidth, newHeight)
* z: tensor with gradient output of the FF dot for x and y
*
* int arguments:
* list of integers - dimensions to calculate dot along,
* default corresponds to empty list in which case calculation
* is performed for all dimensions and scalar is returned.
* *
* output array: * output array:
* the tensor with calculated backproped dots * the 4D-Tensor with resized image (shape is {batch, newWidth, newHeight, channels})
*
*/
// #if NOT_EXCLUDED(OP_resize_bicubic)
@Namespace("nd4j::ops") public static class resize_bicubic extends DeclarableCustomOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public resize_bicubic(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public resize_bicubic(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public resize_bicubic position(long position) {
return (resize_bicubic)super.position(position);
}
public resize_bicubic() { super((Pointer)null); allocate(); }
private native void allocate();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif
/**
* This op make interpolated resize for given tensor with given algorithm.
* Supported algorithms are bilinear, bicubic, nearest_neighbor.
* Need to implement to full compatibility with TF: lanczos5, gaussian, area and mitchellcubic
*
* input array:
* 0 - 4D-Tensor with shape (batch, sizeX, sizeY, channels)
* 1 - 1D-Tensor with 2 values (newWidth, newHeight)
*
* optional int args:
* 0 - algorithm - bilinear by default
* optional bool args:
* 0 - preserve_aspect_ratio - default False
* 1 - antialias - default False
*
* output array:
* the 4D-Tensor with resized by given algorithm image (shape is {batch, newWidth, newHeight, channels})
*
*/
// #if NOT_EXCLUDED(OP_image_resize)
@Namespace("nd4j::ops") public static class image_resize extends DeclarableCustomOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public image_resize(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public image_resize(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public image_resize position(long position) {
return (image_resize)super.position(position);
}
public image_resize() { super((Pointer)null); allocate(); }
private native void allocate();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif
/**
* Copy a tensor setting everything outside a central band in each innermost matrix
*
* input array:
* x: given tensor with shape {..., M, N} - as vector (matrix) of matricies MxN
*
* int arguments:
* lower band
* upper band
*
* output array:
* matrix with given bands between lower and upper diagonals
* *
*/ */
@ -20736,7 +20802,8 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
} }
// #endif // #endif
/*
/**
* image.non_max_suppression op. * image.non_max_suppression op.
* input: * input:
* 0 - boxes - 2D-tensor with shape (num_boxes, 4) by float type * 0 - boxes - 2D-tensor with shape (num_boxes, 4) by float type
@ -21270,6 +21337,36 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
} }
// #endif // #endif
/**
* This operation creates new array
* Input:
* array with shape values
*
* IArgs:
* order value
* data type value
*
* BArgs:
* initialization option
*/
// #if NOT_EXCLUDED(OP_create)
@Namespace("nd4j::ops") public static class create extends DeclarableCustomOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public create(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public create(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public create position(long position) {
return (create)super.position(position);
}
public create() { super((Pointer)null); allocate(); }
private native void allocate();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif

View File

@ -34,6 +34,7 @@ import org.nd4j.linalg.api.ops.impl.controlflow.Where;
import org.nd4j.linalg.api.ops.impl.image.CropAndResize; import org.nd4j.linalg.api.ops.impl.image.CropAndResize;
import org.nd4j.linalg.api.ops.impl.image.ResizeBilinear; import org.nd4j.linalg.api.ops.impl.image.ResizeBilinear;
import org.nd4j.linalg.api.ops.impl.reduce.MmulBp; import org.nd4j.linalg.api.ops.impl.reduce.MmulBp;
import org.nd4j.linalg.api.ops.impl.shape.Create;
import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax; import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp;
@ -1085,4 +1086,15 @@ public class CustomOpsTests extends BaseNd4jTest {
assertEquals(1, lsd.size()); assertEquals(1, lsd.size());
assertArrayEquals(new long[]{1,10, 2}, lsd.get(0).getShape()); assertArrayEquals(new long[]{1,10, 2}, lsd.get(0).getShape());
} }
@Test
public void testCreateOp_1() {
val shape = Nd4j.createFromArray(new int[] {3, 4, 5});
val exp = Nd4j.create(DataType.INT, 3, 4, 5);
val result = Nd4j.exec(new Create(shape, 'c', true, DataType.INT))[0];
assertEquals(exp, result);
}
} }