diff --git a/libnd4j/include/ops/declarable/OpDescriptor.h b/libnd4j/include/ops/declarable/OpDescriptor.h index a5175914d..2c857f3c0 100644 --- a/libnd4j/include/ops/declarable/OpDescriptor.h +++ b/libnd4j/include/ops/declarable/OpDescriptor.h @@ -87,6 +87,10 @@ namespace nd4j { std::map> _outputTypes; std::map> _inputTypes; + + // field for ops that allow data type override at runtime + bool _dtypeOverride = false; + bool checkDataTypesMatch(nd4j::DataType needle, std::vector &haystack) const; public: // default constructor @@ -164,6 +168,7 @@ namespace nd4j { OpDescriptor* setAllowedOutputTypes(int index, nd4j::DataType dtype); OpDescriptor* setAllowedInputTypes(nd4j::DataType dtype); OpDescriptor* setAllowedOutputTypes(nd4j::DataType dtype); + OpDescriptor* allowOverride(bool reallyAllow); OpDescriptor* setSameMode(bool reallySame); OpDescriptor* setInputType(int idx, nd4j::DataType dtype); OpDescriptor* setOutputType(int idx, nd4j::DataType dtype); diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/rank.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/rank.cpp index 9ef0ed12b..6035b267d 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/rank.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/rank.cpp @@ -31,7 +31,8 @@ namespace nd4j { REQUIRE_TRUE(output->isScalar(), 0, "Rank output should be scalar"); - output->assign(input->rankOf()); + output->p(0, input->rankOf()); + output->syncToDevice(); return Status::OK(); } @@ -43,7 +44,8 @@ namespace nd4j { DECLARE_TYPES(rank) { getOpDescriptor() ->setAllowedInputTypes(nd4j::DataType::ANY) - ->setAllowedOutputTypes({ALL_INTS}); + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}) + ->allowOverride(true); } } } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/size.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/size.cpp index c4c3a1f55..ed8927f98 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/size.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/size.cpp @@ -31,7 +31,8 @@ namespace nd4j { REQUIRE_TRUE(output->isScalar(), 0, "Size output should be scalar"); - output->assign(input->lengthOf()); + output->p(0, input->lengthOf()); + output->syncToDevice(); return Status::OK(); } @@ -42,7 +43,8 @@ namespace nd4j { DECLARE_TYPES(size) { getOpDescriptor() ->setAllowedInputTypes(nd4j::DataType::ANY) - ->setAllowedOutputTypes({ALL_INTS}); + ->setAllowedOutputTypes({ALL_INTS, ALL_FLOATS}) + ->allowOverride(true); } } } diff --git a/libnd4j/include/ops/declarable/generic/shape/size_at.cpp b/libnd4j/include/ops/declarable/generic/shape/size_at.cpp index c0ea21f81..0f1e9c669 100644 --- a/libnd4j/include/ops/declarable/generic/shape/size_at.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/size_at.cpp @@ -35,7 +35,8 @@ namespace nd4j { REQUIRE_TRUE(dim < input->rankOf(), 0, "Size_At: Dim can't be higher then input rank") - output->assign(input->sizeAt(dim)); + output->p(0, input->sizeAt(dim)); + output->syncToDevice(); return Status::OK(); } @@ -47,7 +48,8 @@ namespace nd4j { DECLARE_TYPES(size_at) { getOpDescriptor() ->setAllowedInputTypes(nd4j::DataType::ANY) - ->setAllowedOutputTypes(DataType::INT64); + ->setAllowedOutputTypes(DataType::INT64) + ->allowOverride(true); } } } diff --git a/libnd4j/include/ops/declarable/impl/OpDescriptor.cpp b/libnd4j/include/ops/declarable/impl/OpDescriptor.cpp index 7f91b2ae4..5139a95cc 100644 --- a/libnd4j/include/ops/declarable/impl/OpDescriptor.cpp +++ b/libnd4j/include/ops/declarable/impl/OpDescriptor.cpp @@ -166,6 +166,11 @@ namespace nd4j { return this; } + OpDescriptor* OpDescriptor::allowOverride(bool allowOverride) { + _dtypeOverride = allowOverride; + return this; + } + OpDescriptor* OpDescriptor::setAllowedInputTypes(const nd4j::DataType dtype) { _allowedIns.clear(); _allowedIns.emplace_back(dtype); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp index 0bed375cc..a23d5421e 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp @@ -54,4 +54,16 @@ TEST_F(DeclarableOpsTests16, test_scatter_update_119) { ASSERT_EQ(e, *z); delete result; +} + +TEST_F(DeclarableOpsTests16, test_size_dtype_1) { + auto x = NDArrayFactory::create('c', {3}, {1, 1, 1}); + auto z = NDArrayFactory::create(0.0f); + auto e = NDArrayFactory::create(3.0f); + + nd4j::ops::size op; + auto status = op.execute({&x}, {&z}, {}, {}, {}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(e, z); } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp index 1661bd99d..ef3710371 100644 --- a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp @@ -1189,6 +1189,22 @@ TEST_F(JavaInteropTests, test_ismax_view) { delete t; } +TEST_F(JavaInteropTests, test_size_dtype_1) { + auto x = NDArrayFactory::create('c', {3}, {1, 1, 1}); + auto z = NDArrayFactory::create(0.0f); + auto e = NDArrayFactory::create(3.0f); + + Context ctx(1); + ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo()); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + + nd4j::ops::size op; + auto status = op.execute(&ctx); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(e, z); +} + /* TEST_F(JavaInteropTests, Test_Results_Conversion_1) { auto pl = nd4j::graph::readFlatBuffers("./resources/gru_dynamic_mnist.fb"); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java index 0541f914a..2c7962c5d 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java @@ -9272,6 +9272,7 @@ public static final int PREALLOC_SIZE = 33554432; public native OpDescriptor setAllowedOutputTypes(int index, @Cast("nd4j::DataType") int dtype); public native OpDescriptor setAllowedInputTypes(@Cast("nd4j::DataType") int dtype); public native OpDescriptor setAllowedOutputTypes(@Cast("nd4j::DataType") int dtype); + public native OpDescriptor allowOverride(@Cast("bool") boolean reallyAllow); public native OpDescriptor setSameMode(@Cast("bool") boolean reallySame); public native OpDescriptor setInputType(int idx, @Cast("nd4j::DataType") int dtype); public native OpDescriptor setOutputType(int idx, @Cast("nd4j::DataType") int dtype); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index 4fee4b3b7..43a05a33a 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -11557,6 +11557,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); public native OpDescriptor setAllowedOutputTypes(int index, @Cast("nd4j::DataType") int dtype); public native OpDescriptor setAllowedInputTypes(@Cast("nd4j::DataType") int dtype); public native OpDescriptor setAllowedOutputTypes(@Cast("nd4j::DataType") int dtype); + public native OpDescriptor allowOverride(@Cast("bool") boolean reallyAllow); public native OpDescriptor setSameMode(@Cast("bool") boolean reallySame); public native OpDescriptor setInputType(int idx, @Cast("nd4j::DataType") int dtype); public native OpDescriptor setOutputType(int idx, @Cast("nd4j::DataType") int dtype); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java index e325adea8..a1bc56703 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java @@ -44,6 +44,7 @@ import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.nativeblas.NativeOpsHolder; +import java.util.ArrayList; import java.util.List; import static org.junit.Assert.*; @@ -641,4 +642,34 @@ public class CustomOpsTests extends BaseNd4jTest { assertEquals(result1, result2); } + + @Test + public void testSizeTypes(){ + List failed = new ArrayList<>(); + for(DataType dt : new DataType[]{DataType.LONG, DataType.INT, DataType.SHORT, DataType.BYTE, + DataType.UINT64, DataType.UINT32, DataType.UINT16, DataType.UBYTE, + DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.BFLOAT16}) { + + INDArray in = Nd4j.create(DataType.FLOAT, 100); + INDArray out = Nd4j.scalar(dt, 0); + INDArray e = Nd4j.scalar(dt, 100); + + DynamicCustomOp op = DynamicCustomOp.builder("size") + .addInputs(in) + .addOutputs(out) + .build(); + + try { + Nd4j.exec(op); + + assertEquals(e, out); + } catch (Throwable t){ + failed.add(dt); + } + } + + if(!failed.isEmpty()){ + fail("Failed datatypes: " + failed.toString()); + } + } }