From 9bb5798caca11753a1957c3f94929a09a709f618 Mon Sep 17 00:00:00 2001 From: raver119 Date: Sun, 2 Feb 2020 23:14:00 +0300 Subject: [PATCH] Null arrays fix (#208) * don't skip null arrays Signed-off-by: raver119 * one test tweak Signed-off-by: raver119 --- libnd4j/include/ops/declarable/DeclarableOp.h | 4 +-- .../declarable/generic/nn/convo/conv1d.cpp | 2 +- .../ops/declarable/impl/DeclarableOp.cpp | 21 ++++--------- .../layers_tests/DeclarableOpsTests19.cpp | 31 ++++++++++++++++++- .../java/org/nd4j/linalg/rng/RandomTests.java | 4 ++- 5 files changed, 42 insertions(+), 20 deletions(-) diff --git a/libnd4j/include/ops/declarable/DeclarableOp.h b/libnd4j/include/ops/declarable/DeclarableOp.h index ff8fe9e83..78f5fcaa4 100644 --- a/libnd4j/include/ops/declarable/DeclarableOp.h +++ b/libnd4j/include/ops/declarable/DeclarableOp.h @@ -171,7 +171,7 @@ namespace nd4j { Nd4jStatus execute(const std::vector &inputs, const std::vector &outputs); - template + template ::value>> Nd4jStatus execute(const std::vector &inputs, const std::vector &outputs, std::initializer_list tArgs); Nd4jStatus execute(const std::vector &inputs, const std::vector &outputs, const std::vector &tArgs, const std::vector &iArgs, const std::vector &bArgs = std::vector(), const std::vector &dArgs = std::vector(), bool isInplace = false); @@ -179,7 +179,7 @@ namespace nd4j { nd4j::ResultSet* evaluate(const std::vector &inputs); - template + template ::value>> nd4j::ResultSet* evaluate(const std::vector &inputs, std::initializer_list args); nd4j::ResultSet* evaluate(const std::vector &inputs, const std::vector &tArgs, const std::vector &iArgs, const std::vector &bArgs = std::vector(), const std::vector &dArgs = std::vector(), bool isInplace = false); diff --git a/libnd4j/include/ops/declarable/generic/nn/convo/conv1d.cpp b/libnd4j/include/ops/declarable/generic/nn/convo/conv1d.cpp index 2800e7185..9cd3285f3 100644 --- a/libnd4j/include/ops/declarable/generic/nn/convo/conv1d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/convo/conv1d.cpp @@ -222,7 +222,7 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) { auto gradWReshaped = gradW ->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC] nd4j::ops::conv2d_bp conv2dBP; - const Nd4jStatus status = conv2dBP.execute({&inputReshaped, &weightsReshaped, bias, &gradOReshaped}, {&gradIReshaped, &gradWReshaped, gradB}, {}, {1,kW, 1,sW, 0,pW, 1,dW, paddingMode, !isNCW}, {}); + auto status = conv2dBP.execute({&inputReshaped, &weightsReshaped, bias, &gradOReshaped}, {&gradIReshaped, &gradWReshaped, gradB}, {}, {1,kW, 1,sW, 0,pW, 1,dW, paddingMode, !isNCW}, {}); if (status != ND4J_STATUS_OK) return status; diff --git a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp index 7c4138d36..46d10b51c 100644 --- a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp +++ b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp @@ -165,10 +165,7 @@ namespace nd4j { // we build list of input shapes if (ctx.isFastPath()) { for (const auto p:ctx.fastpath_in()) { - if (p == nullptr) - continue; - - inSha.push_back(p->getShapeInfo()); + inSha.push_back(p == nullptr ? nullptr : p->getShapeInfo()); } } else { for (auto p: *ctx.inputs()) { @@ -184,6 +181,11 @@ namespace nd4j { } } + // if we override shape function, we'll return size of fastPath + if (ctx.isFastPath() && ctx.shapeFunctionOverride()) { + return (int) ctx.fastpath_out().size(); + } + // optionally saving input time if (Environment::getInstance()->isProfiling() && node != nullptr) { inputEnd = std::chrono::system_clock::now(); @@ -193,11 +195,6 @@ namespace nd4j { shapeStart = std::chrono::system_clock::now(); } - // if we override shape function, we'll return size of fastPath - if (ctx.isFastPath() && ctx.shapeFunctionOverride()) { - return (int) ctx.fastpath_out().size(); - } - auto outSha = this->calculateOutputShape(&inSha, ctx); results = outSha->size(); @@ -870,16 +867,10 @@ namespace nd4j { Context ctx(1); for (int e = 0; e < inputs.size(); e++) { - if (inputs[e] == nullptr) - break; - ctx.setInputArray(e, inputs[e]); } for (int e = 0; e < outputs.size(); e++) { - if (outputs[e] == nullptr) - break; - ctx.setOutputArray(e, outputs[e]); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp index 871bfe186..9883a9d79 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp @@ -37,4 +37,33 @@ public: printf("\n"); fflush(stdout); } -}; \ No newline at end of file +}; + +TEST_F(DeclarableOpsTests19, test_conv1d_bp_1) { + /* + DynamicCustomOp op = DynamicCustomOp.builder("conv1d_bp") + .addInputs( + Nd4j.create(DataType.FLOAT, 2,2,12), + Nd4j.create(DataType.FLOAT, 3,2,3), + Nd4j.create(DataType.FLOAT, 2,3,6) + ) + .addOutputs( + Nd4j.create(DataType.FLOAT, 2,2,12), + Nd4j.create(DataType.FLOAT, 3,2,3)) + .addIntegerArguments(3,2,0,1,2,0) + .build(); + + Nd4j.exec(op); + */ + + auto t = NDArrayFactory::create('c', {2, 2, 12}); + auto u = NDArrayFactory::create('c', {3, 2, 3}); + auto v = NDArrayFactory::create('c', {2, 3, 6}); + + nd4j::ops::conv1d_bp op; + auto result = op.evaluate({&t, &u, &v}, {3, 2, 0, 1, 2,0}); + ASSERT_EQ(Status::OK(), result->status()); + + + delete result; +} \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java index 8a06bd7e9..b2de46e1b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/rng/RandomTests.java @@ -810,9 +810,11 @@ public class RandomTests extends BaseNd4jTest { threads[x].start(); } - for (int x = 0; x < threads.length; x++) { + // we want all threads finished before comparing arrays + for (int x = 0; x < threads.length; x++) threads[x].join(); + for (int x = 0; x < threads.length; x++) { assertNotEquals(null, list.get(x)); if (x > 0) {