From 83cb0d9329b099856b335d67aea678ceb0477030 Mon Sep 17 00:00:00 2001 From: raver119 Date: Thu, 21 Nov 2019 13:31:20 +0300 Subject: [PATCH] [WIP] Create and small fix (#67) * - create op - skip exec for empty inputs for non_max_suppression - EmptyHandling idea Signed-off-by: raver119 * Create op and mapping for it Signed-off-by: raver119 --- libnd4j/include/ops/declarable/DeclarableOp.h | 3 +- .../include/ops/declarable/EmptyHandling.h | 32 ++++ .../parity_ops/non_max_suppression.cpp | 3 + .../ops/declarable/generic/shape/create.cpp | 57 ++++++++ .../include/ops/declarable/headers/shape.h | 16 ++ .../ops/declarable/impl/DeclarableOp.cpp | 4 + .../DifferentialFunctionFactory.java | 40 ++--- .../converters/ImportClassMapping.java | 1 + .../linalg/api/ops/impl/shape/Create.java | 138 ++++++++++++++++++ .../java/org/nd4j/nativeblas/Nd4jCuda.java | 1 + .../java/org/nd4j/nativeblas/Nd4jCpu.java | 129 ++++++++++++++-- .../nd4j/linalg/custom/CustomOpsTests.java | 12 ++ 12 files changed, 389 insertions(+), 47 deletions(-) create mode 100644 libnd4j/include/ops/declarable/EmptyHandling.h create mode 100644 libnd4j/include/ops/declarable/generic/shape/create.cpp create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Create.java diff --git a/libnd4j/include/ops/declarable/DeclarableOp.h b/libnd4j/include/ops/declarable/DeclarableOp.h index 5da74860b..ea1f20d34 100644 --- a/libnd4j/include/ops/declarable/DeclarableOp.h +++ b/libnd4j/include/ops/declarable/DeclarableOp.h @@ -32,6 +32,7 @@ #include #include #include +#include //#include #include @@ -111,7 +112,7 @@ namespace nd4j { */ int prepareOutputs(Context& block); - //std::vector* calculateOutputShape(std::vector* inputShape, nd4j::graph::Block& block); + virtual samediff::EmptyHandling emptyHandling(); public: // for special cases, like BooleanOps DeclarableOp(); diff --git a/libnd4j/include/ops/declarable/EmptyHandling.h b/libnd4j/include/ops/declarable/EmptyHandling.h new file mode 100644 index 000000000..c25fea498 --- /dev/null +++ b/libnd4j/include/ops/declarable/EmptyHandling.h @@ -0,0 +1,32 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef SAMEDIFF_EMPTYHANDLING_H +#define SAMEDIFF_EMPTYHANDLING_H + +namespace samediff { + enum EmptyHandling { + EMPTY_SKIP = 1, + EMPTY_EXCEPTION = 2, + EMPTY_EXECUTE = 3 + }; +} + +#endif //SAMEDIFF_EMPTYHANDLING_H diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression.cpp index e2fe58b7a..c56e32f31 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/non_max_suppression.cpp @@ -37,6 +37,9 @@ namespace nd4j { else REQUIRE_TRUE(false, 0, "image.non_max_suppression: Max output size argument cannot be retrieved."); + if (boxes->isEmpty() || scales->isEmpty()) + return Status::OK(); + REQUIRE_TRUE(boxes->rankOf() == 2, 0, "image.non_max_suppression: The rank of boxes array should be 2, but %i is given", boxes->rankOf()); REQUIRE_TRUE(boxes->sizeAt(1) == 4, 0, "image.non_max_suppression: The last dimension of boxes array should be 4, but %i is given", boxes->sizeAt(1)); REQUIRE_TRUE(scales->rankOf() == 1 && scales->lengthOf() == boxes->sizeAt(0), 0, "image.non_max_suppression: The rank of scales array should be 1, but %i is given", boxes->rankOf()); diff --git a/libnd4j/include/ops/declarable/generic/shape/create.cpp b/libnd4j/include/ops/declarable/generic/shape/create.cpp new file mode 100644 index 000000000..e743a5cad --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/shape/create.cpp @@ -0,0 +1,57 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#if NOT_EXCLUDED(OP_shapes_of) + +#include + +namespace nd4j { + namespace ops { + CUSTOM_OP_IMPL(create, 1, 1, false, 0, 1) { + auto init = block.numB() > 0 ? B_ARG(0) : true; + + if (init) + OUTPUT_VARIABLE(0)->nullify(); + + return Status::OK(); + } + + DECLARE_SHAPE_FN(create) { + auto shapeInput = INPUT_VARIABLE(0); + auto order = (char) INT_ARG(0); + auto dtype = DataTypeUtils::fromInt(INT_ARG(1)); + + REQUIRE_TRUE(order == 'c' || order == 'f', 0, "create: order must be either c or f"); + + auto shape = shapeInput->getBufferAsVector(); + + return SHAPELIST(nd4j::ConstantShapeHelper::getInstance()->createShapeInfo(dtype, order, shape)); + } + + DECLARE_TYPES(create) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS}) + ->setAllowedOutputTypes(nd4j::DataType::ANY); + } + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/shape.h b/libnd4j/include/ops/declarable/headers/shape.h index 9e0c19a7b..3d47c24bf 100644 --- a/libnd4j/include/ops/declarable/headers/shape.h +++ b/libnd4j/include/ops/declarable/headers/shape.h @@ -99,6 +99,22 @@ namespace nd4j { #if NOT_EXCLUDED(OP_evaluate_reduction_shape) DECLARE_CUSTOM_OP(evaluate_reduction_shape, 2, 1, false, 0, 0); #endif + + /** + * This operation creates new array + * Input: + * array with shape values + * + * IArgs: + * order value + * data type value + * + * BArgs: + * initialization option + */ + #if NOT_EXCLUDED(OP_create) + DECLARE_CUSTOM_OP(create, 1, 1, false, 0, 1); + #endif } } diff --git a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp index 8c65ac25e..3aef09bcd 100644 --- a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp +++ b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp @@ -933,6 +933,10 @@ namespace nd4j { return ND4J_STATUS_OK; } + samediff::EmptyHandling DeclarableOp::emptyHandling() { + return samediff::EmptyHandling::EMPTY_SKIP; + } + void DeclarableOp::registerTypes() { this->getOpDescriptor()->setSameMode(true); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java index 445be0a6a..abea31459 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java @@ -149,36 +149,7 @@ import org.nd4j.linalg.api.ops.impl.scatter.ScatterMin; import org.nd4j.linalg.api.ops.impl.scatter.ScatterMul; import org.nd4j.linalg.api.ops.impl.scatter.ScatterSub; import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate; -import org.nd4j.linalg.api.ops.impl.shape.Concat; -import org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix; -import org.nd4j.linalg.api.ops.impl.shape.Cross; -import org.nd4j.linalg.api.ops.impl.shape.Diag; -import org.nd4j.linalg.api.ops.impl.shape.DiagPart; -import org.nd4j.linalg.api.ops.impl.shape.ExpandDims; -import org.nd4j.linalg.api.ops.impl.shape.Gather; -import org.nd4j.linalg.api.ops.impl.shape.GatherNd; -import org.nd4j.linalg.api.ops.impl.shape.MergeAvg; -import org.nd4j.linalg.api.ops.impl.shape.MergeMax; -import org.nd4j.linalg.api.ops.impl.shape.MeshGrid; -import org.nd4j.linalg.api.ops.impl.shape.OneHot; -import org.nd4j.linalg.api.ops.impl.shape.OnesLike; -import org.nd4j.linalg.api.ops.impl.shape.ParallelStack; -import org.nd4j.linalg.api.ops.impl.shape.Permute; -import org.nd4j.linalg.api.ops.impl.shape.Rank; -import org.nd4j.linalg.api.ops.impl.shape.ReductionShape; -import org.nd4j.linalg.api.ops.impl.shape.Repeat; -import org.nd4j.linalg.api.ops.impl.shape.Reshape; -import org.nd4j.linalg.api.ops.impl.shape.SequenceMask; -import org.nd4j.linalg.api.ops.impl.shape.Size; -import org.nd4j.linalg.api.ops.impl.shape.SizeAt; -import org.nd4j.linalg.api.ops.impl.shape.Slice; -import org.nd4j.linalg.api.ops.impl.shape.Squeeze; -import org.nd4j.linalg.api.ops.impl.shape.Stack; -import org.nd4j.linalg.api.ops.impl.shape.StridedSlice; -import org.nd4j.linalg.api.ops.impl.shape.Tile; -import org.nd4j.linalg.api.ops.impl.shape.Transpose; -import org.nd4j.linalg.api.ops.impl.shape.Unstack; -import org.nd4j.linalg.api.ops.impl.shape.ZerosLike; +import org.nd4j.linalg.api.ops.impl.shape.*; import org.nd4j.linalg.api.ops.impl.shape.bp.SliceBp; import org.nd4j.linalg.api.ops.impl.shape.bp.StridedSliceBp; import org.nd4j.linalg.api.ops.impl.shape.bp.TileBp; @@ -339,6 +310,15 @@ public class DifferentialFunctionFactory { return new ZerosLike(name, sameDiff(), input).outputVariable(); } + public SDVariable create(String name, SDVariable shape, boolean initialize, DataType dataType) { + return create(name, shape, 'c', initialize, dataType); + } + + public SDVariable create(String name, SDVariable shape, char order, boolean initialize, DataType dataType) { + validateDifferentialFunctionsameDiff(shape); + return new Create(name, sameDiff(), shape, order, initialize, dataType).outputVariable(); + } + public SDVariable onesLike(String name, SDVariable input, DataType dataType) { validateDifferentialFunctionsameDiff(input); return new OnesLike(name, sameDiff(), input, dataType).outputVariable(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index 186bd96f0..7b19406ef 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -65,6 +65,7 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.broadcast.BroadcastRSubOp.class, org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp.class, org.nd4j.linalg.api.ops.impl.broadcast.BroadcastTo.class, + org.nd4j.linalg.api.ops.impl.shape.Create.class, org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastEqualTo.class, org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThan.class, org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThanOrEqual.class, diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Create.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Create.java new file mode 100644 index 000000000..0b966008c --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Create.java @@ -0,0 +1,138 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.shape; + +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.imports.NoOpNameFoundException; +import org.nd4j.imports.graphmapper.tf.TFGraphMapper; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.GraphDef; +import org.tensorflow.framework.NodeDef; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +/** + * This operation creates a new, optionally nullified, array with a given shape, order and data type + * + * @author raver119@gmail.com + */ +@Slf4j +public class Create extends DynamicCustomOp { + + protected boolean initialize = false; + protected char order = 'c'; + protected DataType outputType = DataType.FLOAT; //Allow customizing dtype for TF import + + public Create() { + } + + public Create(String name, SameDiff sameDiff, SDVariable input, boolean initialize) { + this(name, sameDiff, input, 'c', initialize, input.dataType()); + } + + public Create(String name, SameDiff sameDiff, SDVariable input, char order, boolean initialize, DataType dataType) { + super(name, sameDiff, new SDVariable[]{input}, false); + this.outputType = dataType; + this.initialize = initialize; + this.order = order; + + addArgs(); + } + + public Create(INDArray shape, DataType dataType) { + this(shape, 'c', false, dataType); + } + + public Create(INDArray shape, boolean initialize, DataType dataType) { + this(shape, 'c', initialize, dataType); + } + + public Create(@NonNull INDArray shape, char order, boolean initialize, DataType dataType) { + super(new INDArray[]{shape}, new INDArray[0]); + this.order = order; + this.initialize = initialize; + this.outputType = dataType; + + addArgs(); + } + + protected void addArgs() { + addBArgument(initialize); + addIArgument((int) order,outputType.toInt()); + } + + @Override + public String opName() { + return "create"; + } + + + @Override + public String onnxName() { + throw new NoOpNameFoundException("No op found for " + opName()); + } + + @Override + public String tensorflowName() { + return "Empty"; + } + + + @Override + public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { + // convert output data type + if(attributesForNode.containsKey("dtype")) { + outputType = TFGraphMapper.convertType(attributesForNode.get("dtype").getType()); + } + + // get init field + if(attributesForNode.containsKey("init")) { + initialize = attributesForNode.get("init").getB(); + } + + // there's no order in TF, just plain C + this.order = 'c'; + addArgs(); + } + + @Override + public List doDiff(List i_v) { + SDVariable ret = sameDiff.zerosLike(outputVariables()[0]); + return Arrays.asList(ret); + } + + @Override + public List calculateOutputDataTypes(List dataTypes){ + Preconditions.checkState(dataTypes.size() == 1, "Expected list with exactly 1 datatype for %s, got %s", getClass(), dataTypes); + if(outputType != null){ + return Collections.singletonList(outputType); + } else { + //Output type is same as input type + return dataTypes; + } + } +} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java index e03937d0f..1799ceb22 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java @@ -9632,6 +9632,7 @@ public static final int PREALLOC_SIZE = 33554432; // #include // #include // #include +// #include //#include // #include diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index 84f8b2c12..828d9b290 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -11834,6 +11834,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); // #include // #include // #include +// #include //#include // #include @@ -17126,6 +17127,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -20576,7 +20578,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); } // #endif - /** + /** * This op make bilinear or nearest neighbor interpolated resize for given tensor * * input array: @@ -20612,7 +20614,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); } // #endif - /** + /** * This op make bilinear interpolated resize for given tensor * * input array: @@ -20647,7 +20649,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); } // #endif - /** + /** * This op make nearest neighbor interpolated resize for given tensor * * input array: @@ -20659,7 +20661,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * 1 - new height * * output array: - * the 4D-Tensor with calculated backproped dots + * the 4D-Tensor with resized image (shape is {batch, newWidth, newHeight, channels}) * * CAUTION: either size tensor or a pair of int params should be provided. */ @@ -20682,21 +20684,85 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); } // #endif - /** - * This op calculates backprop dot for two tensors along given dimensions + /** + * This op make bicubic interpolated resize for given tensor * * input array: - * x: tensor to calculate dot for - * y: tensor to calculate dot for - * z: tensor with gradient output of the FF dot for x and y - * - * int arguments: - * list of integers - dimensions to calculate dot along, - * default corresponds to empty list in which case calculation - * is performed for all dimensions and scalar is returned. + * 0 - 4D-Tensor with shape (batch, sizeX, sizeY, channels) + * 1 - 1D-Tensor with 2 values (newWidth, newHeight) * * output array: - * the tensor with calculated backproped dots + * the 4D-Tensor with resized image (shape is {batch, newWidth, newHeight, channels}) + * + */ +// #if NOT_EXCLUDED(OP_resize_bicubic) + @Namespace("nd4j::ops") public static class resize_bicubic extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public resize_bicubic(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public resize_bicubic(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public resize_bicubic position(long position) { + return (resize_bicubic)super.position(position); + } + + public resize_bicubic() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif + + /** + * This op make interpolated resize for given tensor with given algorithm. + * Supported algorithms are bilinear, bicubic, nearest_neighbor. + * Need to implement to full compatibility with TF: lanczos5, gaussian, area and mitchellcubic + * + * input array: + * 0 - 4D-Tensor with shape (batch, sizeX, sizeY, channels) + * 1 - 1D-Tensor with 2 values (newWidth, newHeight) + * + * optional int args: + * 0 - algorithm - bilinear by default + * optional bool args: + * 0 - preserve_aspect_ratio - default False + * 1 - antialias - default False + * + * output array: + * the 4D-Tensor with resized by given algorithm image (shape is {batch, newWidth, newHeight, channels}) + * + */ + +// #if NOT_EXCLUDED(OP_image_resize) + @Namespace("nd4j::ops") public static class image_resize extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public image_resize(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public image_resize(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public image_resize position(long position) { + return (image_resize)super.position(position); + } + + public image_resize() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif + + /** + * Copy a tensor setting everything outside a central band in each innermost matrix + * + * input array: + * x: given tensor with shape {..., M, N} - as vector (matrix) of matricies MxN + * + * int arguments: + * lower band + * upper band + * + * output array: + * matrix with given bands between lower and upper diagonals * */ @@ -20736,7 +20802,8 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } // #endif - /* + + /** * image.non_max_suppression op. * input: * 0 - boxes - 2D-tensor with shape (num_boxes, 4) by float type @@ -21270,6 +21337,36 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } // #endif + + /** + * This operation creates new array + * Input: + * array with shape values + * + * IArgs: + * order value + * data type value + * + * BArgs: + * initialization option + */ +// #if NOT_EXCLUDED(OP_create) + @Namespace("nd4j::ops") public static class create extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public create(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public create(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public create position(long position) { + return (create)super.position(position); + } + + public create() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java index b2e6268d3..84dd02cd4 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java @@ -34,6 +34,7 @@ import org.nd4j.linalg.api.ops.impl.controlflow.Where; import org.nd4j.linalg.api.ops.impl.image.CropAndResize; import org.nd4j.linalg.api.ops.impl.image.ResizeBilinear; import org.nd4j.linalg.api.ops.impl.reduce.MmulBp; +import org.nd4j.linalg.api.ops.impl.shape.Create; import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp; @@ -1085,4 +1086,15 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(1, lsd.size()); assertArrayEquals(new long[]{1,10, 2}, lsd.get(0).getShape()); } + + + @Test + public void testCreateOp_1() { + val shape = Nd4j.createFromArray(new int[] {3, 4, 5}); + val exp = Nd4j.create(DataType.INT, 3, 4, 5); + + val result = Nd4j.exec(new Create(shape, 'c', true, DataType.INT))[0]; + + assertEquals(exp, result); + } }