diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_add.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_add.cpp
index 0bcbd2439..0d9465b02 100644
--- a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_add.cpp
+++ b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_add.cpp
@@ -35,6 +35,9 @@ OP_IMPL(scatter_add, 3, 1, true) {
 
     auto output = OUTPUT_VARIABLE(0);
 
+    if (!block.isInplace())
+        output->assign(input);
+
     const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
 
     const int inRank = input->rankOf();
@@ -68,10 +71,8 @@ OP_IMPL(scatter_add, 3, 1, true) {
         REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_ADD OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
     }
 
-    if (!block.isInplace())
-        output->assign(input);
-
-    helpers::scatter(block.launchContext(), pairwise::Add, *indices, *updates, *output, lock);
+    if (!indices->isEmpty())
+        helpers::scatter(block.launchContext(), pairwise::Add, *indices, *updates, *output, lock);
 
     return Status::OK();
 }
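Note on the change pattern repeated in every scatter op below: the non-inplace copy of the input into the output is moved ahead of the shape checks, and helpers::scatter is only invoked when the indices array is non-empty, so a call with empty indices degenerates to an identity copy. A minimal sketch of the resulting control flow (names taken from the hunks in this patch, not a literal excerpt of any one file):

    auto input   = INPUT_VARIABLE(0);
    auto indices = INPUT_VARIABLE(1);
    auto updates = INPUT_VARIABLE(2);
    auto output  = OUTPUT_VARIABLE(0);

    // for non-inplace execution the output starts out as a copy of the input,
    // so skipping the scatter still leaves valid data in the output
    if (!block.isInplace())
        output->assign(input);

    // ... rank/shape validation via REQUIRE_TRUE, unchanged by this patch ...

    // empty indices means there is nothing to scatter; the helper is skipped
    // and the output keeps the copied input values
    if (!indices->isEmpty())
        helpers::scatter(block.launchContext(), pairwise::Add, *indices, *updates, *output, lock);

    return Status::OK();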
diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_div.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_div.cpp
index a711916a1..dccc34e59 100644
--- a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_div.cpp
+++ b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_div.cpp
@@ -16,7 +16,7 @@
 //
 // @author Created by raver119 on 24.11.17.
-// @author Yurii Shyrma (iuriish@yahoo.com) 
+// @author Yurii Shyrma (iuriish@yahoo.com)
 //
 
 #include <op_boilerplate.h>
@@ -28,21 +28,24 @@ namespace nd4j {
 namespace ops {
 
     OP_IMPL(scatter_div, 3, 1, true) {
-        
+
         auto input = INPUT_VARIABLE(0);
         auto indices = INPUT_VARIABLE(1);
         auto updates = INPUT_VARIABLE(2);
         auto output = OUTPUT_VARIABLE(0);
 
+        if (!block.isInplace())
+            output->assign(input);
+
         const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
 
         const int inRank = input->rankOf();
         const int indRank = indices->rankOf();
         const int updRank = updates->rankOf();
-        
+
         REQUIRE_TRUE(inRank > 0, 0, "SCATTER_DIV OP: input should not be scalar !");
-        
+
         if(inRank == 1) {
             REQUIRE_TRUE(indices->isSameShape(updates), 0, "SCATTER_DIV OP: when input array has rank = 1 then indices and updates must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(indices).c_str(), ShapeUtils::shapeAsString(updates).c_str());
         }
@@ -50,28 +53,27 @@ namespace nd4j {
 
            std::vector<Nd4jLong> updShape = updates->getShapeAsVector();
            std::vector<Nd4jLong> inShape = input->getShapeAsVector();
-           std::vector<Nd4jLong> expectedUpdShape = {indices->lengthOf()};
+           std::vector<Nd4jLong> expectedUpdShape = {indices->lengthOf()};
            expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end());
-           
+
            REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_DIV OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
         }
         else {
-           
+
            REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, "SCATTER_DIV OP: wrong rank of updates array, expected is %i, but got %i instead !", indRank + inRank - 1 , updRank);
-           
           std::vector<Nd4jLong> updShape = updates->getShapeAsVector();
           std::vector<Nd4jLong> inShape = input->getShapeAsVector();
-          std::vector<Nd4jLong> expectedUpdShape = indices->getShapeAsVector();
+          std::vector<Nd4jLong> expectedUpdShape = indices->getShapeAsVector();
           expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end());
-          
+
           REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_DIV OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
         }
 
-        if (!block.isInplace())
-            output->assign(input);
+        if (!indices->isEmpty())
+            // ScatterHelper<T>::template scatterApply<simdOps::Divide<T>>(output, indices, updates);
+            helpers::scatter(block.launchContext(), pairwise::Divide, *indices, *updates, *output, lock);
 
-        // ScatterHelper<T>::template scatterApply<simdOps::Divide<T>>(output, indices, updates);
-        helpers::scatter(block.launchContext(), pairwise::Divide, *indices, *updates, *output, lock);
         return Status::OK();
     }
     DECLARE_SYN(ScatterDiv, scatter_div);
diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_max.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_max.cpp
index a9f0ab889..5d37a71d0 100644
--- a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_max.cpp
+++ b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_max.cpp
@@ -35,14 +35,17 @@ OP_IMPL(scatter_max, 3, 1, true) {
 
     auto output = OUTPUT_VARIABLE(0);
 
+    if (!block.isInplace())
+        output->assign(input);
+
     const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
 
     const int inRank = input->rankOf();
     const int indRank = indices->rankOf();
     const int updRank = updates->rankOf();
-    
+
     REQUIRE_TRUE(inRank > 0, 0, "SCATTER_MAX OP: input should not be scalar !");
-    
+
     if(inRank == 1) {
         REQUIRE_TRUE(indices->isSameShape(updates), 0, "SCATTER_MAX OP: when input array has rank = 1 then indices and updates must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(indices).c_str(), ShapeUtils::shapeAsString(updates).c_str());
     }
@@ -50,28 +53,26 @@ OP_IMPL(scatter_max, 3, 1, true) {
 
         std::vector<Nd4jLong> updShape = updates->getShapeAsVector();
         std::vector<Nd4jLong> inShape = input->getShapeAsVector();
-        std::vector<Nd4jLong> expectedUpdShape = {indices->lengthOf()};
+        std::vector<Nd4jLong> expectedUpdShape = {indices->lengthOf()};
         expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end());
-        
+
         REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MAX OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
     }
     else {
-        
+
         REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, "SCATTER_MAX OP: wrong rank of updates array, expected is %i, but got %i instead !", indRank + inRank - 1 , updRank);
-        
         std::vector<Nd4jLong> updShape = updates->getShapeAsVector();
         std::vector<Nd4jLong> inShape = input->getShapeAsVector();
-        std::vector<Nd4jLong> expectedUpdShape = indices->getShapeAsVector();
+        std::vector<Nd4jLong> expectedUpdShape = indices->getShapeAsVector();
         expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end());
-        
+
         REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MAX OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
     }
 
-    if (!block.isInplace())
-        output->assign(input);
+    if (!indices->isEmpty())
+        helpers::scatter(block.launchContext(), pairwise::MaxPairwise, *indices, *updates, *output, lock);
 
-    helpers::scatter(block.launchContext(), pairwise::MaxPairwise, *indices, *updates, *output, lock);
-
     return Status::OK();
 }
 DECLARE_SYN(ScatterMax, scatter_max);
diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_min.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_min.cpp
index cce22b6fb..1bed296f9 100644
--- a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_min.cpp
+++ b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_min.cpp
@@ -35,14 +35,17 @@ OP_IMPL(scatter_min, 3, 1, true) {
 
     auto output = OUTPUT_VARIABLE(0);
 
+    if (!block.isInplace())
+        output->assign(input);
+
     const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
 
     const int inRank = input->rankOf();
     const int indRank = indices->rankOf();
     const int updRank = updates->rankOf();
-    
+
     REQUIRE_TRUE(inRank > 0, 0, "SCATTER_MIN OP: input should not be scalar !");
-    
+
     if(inRank == 1) {
         REQUIRE_TRUE(indices->isSameShape(updates), 0, "SCATTER_MIN OP: when input array has rank = 1 then indices and updates must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(indices).c_str(), ShapeUtils::shapeAsString(updates).c_str());
     }
@@ -50,27 +53,25 @@ OP_IMPL(scatter_min, 3, 1, true) {
 
         std::vector<Nd4jLong> updShape = updates->getShapeAsVector();
         std::vector<Nd4jLong> inShape = input->getShapeAsVector();
-        std::vector<Nd4jLong> expectedUpdShape = {indices->lengthOf()};
+        std::vector<Nd4jLong> expectedUpdShape = {indices->lengthOf()};
         expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end());
-        
+
         REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MIN OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
     }
     else {
-        
+
         REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, "SCATTER_MIN OP: wrong rank of updates array, expected is %i, but got %i instead !", indRank + inRank - 1 , updRank);
-        
         std::vector<Nd4jLong> updShape = updates->getShapeAsVector();
         std::vector<Nd4jLong> inShape = input->getShapeAsVector();
-        std::vector<Nd4jLong> expectedUpdShape = indices->getShapeAsVector();
+        std::vector<Nd4jLong> expectedUpdShape = indices->getShapeAsVector();
         expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end());
-        
+
         REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MIN OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
     }
 
-    if (!block.isInplace())
-        output->assign(input);
-
-    helpers::scatter(block.launchContext(), pairwise::MinPairwise, *indices, *updates, *output, lock);
+    if (!indices->isEmpty())
+        helpers::scatter(block.launchContext(), pairwise::MinPairwise, *indices, *updates, *output, lock);
 
     return Status::OK();
 }
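The shape checks repeated in each of these files encode one rule: for a rank-1 input, updates must have exactly the shape of indices; otherwise the expected updates shape is the indices shape followed by the input's trailing dimensions, and updRank must equal indRank + inRank - 1. A small worked sketch of the higher-rank branch, reusing the insert pattern from the hunks above (the concrete shapes are illustrative, not taken from the patch):

    // input  shape {4, 3} -> inRank  = 2
    // indices shape {2}   -> indRank = 1, so updRank must be 1 + 2 - 1 = 2
    std::vector<Nd4jLong> inShape = {4, 3};
    std::vector<Nd4jLong> expectedUpdShape = {2};   // starts as the indices shape
    // append the input dimensions after axis 0: {2} ++ {3} = {2, 3}
    expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin() + 1, inShape.end());
    // updates must therefore be a {2, 3} array, or the REQUIRE_TRUE check fails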
diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_mul.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_mul.cpp
index 02eebb50c..46b9f7008 100644
--- a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_mul.cpp
+++ b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_mul.cpp
@@ -39,9 +39,13 @@ namespace nd4j {
         const int inRank = input->rankOf();
         const int indRank = indices->rankOf();
         const int updRank = updates->rankOf();
-        
+
+        if (!block.isInplace())
+            output->assign(input);
+
+
         REQUIRE_TRUE(inRank > 0, 0, "SCATTER_MUL OP: input should not be scalar !");
-        
+
         if(inRank == 1) {
             REQUIRE_TRUE(indices->isSameShape(updates), 0, "SCATTER_MUL OP: when input array has rank = 1 then indices and updates must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(indices).c_str(), ShapeUtils::shapeAsString(updates).c_str());
         }
@@ -49,27 +53,25 @@ namespace nd4j {
 
            std::vector<Nd4jLong> updShape = updates->getShapeAsVector();
            std::vector<Nd4jLong> inShape = input->getShapeAsVector();
-           std::vector<Nd4jLong> expectedUpdShape = {indices->lengthOf()};
+           std::vector<Nd4jLong> expectedUpdShape = {indices->lengthOf()};
           expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end());
-          
+
           REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MUL OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
         }
         else {
-          
+
           REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, "SCATTER_MUL OP: wrong rank of updates array, expected is %i, but got %i instead !", indRank + inRank - 1 , updRank);
-          
           std::vector<Nd4jLong> updShape = updates->getShapeAsVector();
           std::vector<Nd4jLong> inShape = input->getShapeAsVector();
-          std::vector<Nd4jLong> expectedUpdShape = indices->getShapeAsVector();
+          std::vector<Nd4jLong> expectedUpdShape = indices->getShapeAsVector();
           expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end());
-          
+
           REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_MUL OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
         }
 
-        if (!block.isInplace())
-            output->assign(input);
-
-        helpers::scatter(block.launchContext(), pairwise::Multiply, *indices, *updates, *output, lock);
+        if (!indices->isEmpty())
+            helpers::scatter(block.launchContext(), pairwise::Multiply, *indices, *updates, *output, lock);
 
         return Status::OK();
     }
diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_sub.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_sub.cpp
index de2bf4fa4..cf3745236 100644
--- a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_sub.cpp
+++ b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_sub.cpp
@@ -34,14 +34,17 @@ namespace nd4j {
 
         auto output = OUTPUT_VARIABLE(0);
 
+        if (!block.isInplace())
+            output->assign(input);
+
         const bool lock = block.getBArguments()->empty() ? false : B_ARG(0);
 
         const int inRank = input->rankOf();
         const int indRank = indices->rankOf();
         const int updRank = updates->rankOf();
-        
+
         REQUIRE_TRUE(inRank > 0, 0, "SCATTER_SUB OP: input should not be scalar !");
-        
+
         if(inRank == 1) {
             REQUIRE_TRUE(indices->isSameShape(updates), 0, "SCATTER_SUB OP: when input array has rank = 1 then indices and updates must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(indices).c_str(), ShapeUtils::shapeAsString(updates).c_str());
         }
@@ -49,29 +52,27 @@ namespace nd4j {
 
            std::vector<Nd4jLong> updShape = updates->getShapeAsVector();
            std::vector<Nd4jLong> inShape = input->getShapeAsVector();
-           std::vector<Nd4jLong> expectedUpdShape = {indices->lengthOf()};
+           std::vector<Nd4jLong> expectedUpdShape = {indices->lengthOf()};
           expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end());
-          
+
           REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_SUB OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
         }
         else {
-          
+
           REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, "SCATTER_SUB OP: wrong rank of updates array, expected is %i, but got %i instead !", indRank + inRank - 1 , updRank);
-          
           std::vector<Nd4jLong> updShape = updates->getShapeAsVector();
           std::vector<Nd4jLong> inShape = input->getShapeAsVector();
-          std::vector<Nd4jLong> expectedUpdShape = indices->getShapeAsVector();
+          std::vector<Nd4jLong> expectedUpdShape = indices->getShapeAsVector();
           expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end());
-          
+
           REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_SUB OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
         }
 
-        if (!block.isInplace())
-            output->assign(input);
-
-        // ScatterHelper<T>::template scatterApply<simdOps::Subtract<T>>(output, indices, updates);
-        helpers::scatter(block.launchContext(), pairwise::Subtract, *indices, *updates, *output, lock);
+        if (!indices->isEmpty())
+            // ScatterHelper<T>::template scatterApply<simdOps::Subtract<T>>(output, indices, updates);
+            helpers::scatter(block.launchContext(), pairwise::Subtract, *indices, *updates, *output, lock);
 
         return Status::OK();
     }
diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_upd.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_upd.cpp
index bc13581bf..55076e51e 100644
--- a/libnd4j/include/ops/declarable/generic/parity_ops/scatter_upd.cpp
+++ b/libnd4j/include/ops/declarable/generic/parity_ops/scatter_upd.cpp
@@ -33,14 +33,17 @@ namespace nd4j {
 
         auto output = OUTPUT_VARIABLE(0);
 
+        if (!block.isInplace())
+            output->assign(input);
+
         const bool lock = block.getBArguments()->empty() ? true : B_ARG(0);
 
         const int inRank = input->rankOf();
         const int indRank = indices->rankOf();
         const int updRank = updates->rankOf();
-        
+
         REQUIRE_TRUE(inRank > 0, 0, "SCATTER_UPD OP: input should not be scalar !");
-        
+
         if(inRank == 1) {
             REQUIRE_TRUE(indices->isSameShape(updates), 0, "SCATTER_UPD OP: when input array has rank = 1 then indices and updates must have the same shapes, but got %s and %s correspondingly !", ShapeUtils::shapeAsString(indices).c_str(), ShapeUtils::shapeAsString(updates).c_str());
         }
@@ -48,28 +51,26 @@ namespace nd4j {
 
           std::vector<Nd4jLong> updShape = updates->getShapeAsVector();
          std::vector<Nd4jLong> inShape = input->getShapeAsVector();
-         std::vector<Nd4jLong> expectedUpdShape = {indices->lengthOf()};
+         std::vector<Nd4jLong> expectedUpdShape = {indices->lengthOf()};
          expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end());
-         
+
          REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_UPD OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
         }
         else {
-         
+
          REQUIRE_TRUE(updRank == indRank + inRank - 1, 0, "SCATTER_UPD OP: wrong rank of updates array, expected is %i, but got %i instead !", indRank + inRank - 1 , updRank);
-         
          std::vector<Nd4jLong> updShape = updates->getShapeAsVector();
          std::vector<Nd4jLong> inShape = input->getShapeAsVector();
-         std::vector<Nd4jLong> expectedUpdShape = indices->getShapeAsVector();
+         std::vector<Nd4jLong> expectedUpdShape = indices->getShapeAsVector();
          expectedUpdShape.insert(expectedUpdShape.end(), inShape.begin()+1, inShape.end());
-         
+
          REQUIRE_TRUE(expectedUpdShape == updShape, 0, "SCATTER_UPD OP: wrong shape of updates array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedUpdShape).c_str(), ShapeUtils::shapeAsString(updShape).c_str());
         }
 
-        if (!block.isInplace())
-            output->assign(input);
-
-        // ScatterHelper<T>::template scatterApply<simdOps::Copy<T>>(output, indices, updates);
-        helpers::scatter(block.launchContext(), pairwise::CopyPws, *indices, *updates, *output, lock);
+        if (!indices->isEmpty())
+            // ScatterHelper<T>::template scatterApply<simdOps::Copy<T>>(output, indices, updates);
+            helpers::scatter(block.launchContext(), pairwise::CopyPws, *indices, *updates, *output, lock);
 
         return Status::OK();
     }
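With the guard in place, any of these scatter ops called with an empty indices array becomes a no-op update: the output is simply a copy of the input. The new tests below exercise this for scatter_upd; the same expectation should hold for the other six ops, sketched here for scatter_add using the same execute overload the tests use (array dtypes here are assumptions, not taken from the patch):

    auto x       = NDArrayFactory::create<float>('c', {5});
    auto indices = NDArrayFactory::create<int>('c', {0});    // empty: nothing to scatter
    auto updates = NDArrayFactory::create<float>('c', {0});
    x.linspace(1.0f);

    nd4j::ops::scatter_add op;
    auto result = op.execute({&x, &indices, &updates}, {}, {}, {true});
    ASSERT_EQ(Status::OK(), result->status());
    // the single output is expected to equal x, since no element was touched
    ASSERT_EQ(x, *result->at(0));
    delete result;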
diff --git a/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp b/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp
index b9445cc70..baba901bf 100644
--- a/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp
+++ b/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp
@@ -207,6 +207,46 @@ TEST_F(EmptyTests, Test_dup_1) {
     delete dup;
 }
 
+TEST_F(EmptyTests, test_empty_scatter_1) {
+    auto x = NDArrayFactory::create<float>('c', {5});
+    auto indices = NDArrayFactory::create<int>('c', {0});
+    auto updates = NDArrayFactory::create<float>('c', {0});
+
+    x.linspace(1.0f);
+
+    nd4j::ops::scatter_upd op;
+    auto result = op.execute({&x, &indices, &updates}, {}, {}, {true});
+    ASSERT_EQ(Status::OK(), result->status());
+
+    auto z = result->at(0);
+    ASSERT_EQ(x, *z);
+
+    delete result;
+}
+
+TEST_F(EmptyTests, test_empty_scatter_2) {
+    auto x = NDArrayFactory::create<float>('c', {5});
+    auto z = NDArrayFactory::create<float>('c', {5});
+    auto indices = NDArrayFactory::create<int>('c', {0});
+    auto updates = NDArrayFactory::create<float>('c', {0});
+
+    x.linspace(1.0f);
+
+    Context ctx(1);
+    ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo());
+    ctx.setInputArray(1, indices.buffer(), indices.shapeInfo(), indices.specialBuffer(), indices.specialShapeInfo());
+    ctx.setInputArray(2, updates.buffer(), updates.shapeInfo(), updates.specialBuffer(), updates.specialShapeInfo());
+    ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());
+    bool args[] = {true};
+    ctx.setBArguments(args, 1);
+
+    nd4j::ops::scatter_upd op;
+    auto result = op.execute(&ctx);
+    ASSERT_EQ(Status::OK(), result);
+
+    ASSERT_EQ(x, z);
+}
+
 TEST_F(EmptyTests, test_shaped_empty_1) {
     auto empty = NDArrayFactory::create<float>('c', {2, 0, 3});
     std::vector<Nd4jLong> shape = {2, 0, 3};
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java
index 95a087301..46dd786c6 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java
@@ -1732,6 +1732,8 @@ public abstract class BaseNDArray implements INDArray, Iterable {
         if(isEmpty())
             return this;
 
+        Nd4j.getCompressor().autoDecompress(this);
+
         // fixme: eventually it would be nice to have this in native code
         if (isS()) {
             val list = new ArrayList<String>();
@@ -1741,8 +1743,9 @@ public abstract class BaseNDArray implements INDArray, Iterable {
             return Nd4j.create(list, this.shape(), this.ordering());
         }
 
-        Nd4j.getCompressor().autoDecompress(this);
-        return Shape.toOffsetZeroCopy(this, order);
+        val z = Nd4j.createUninitialized(this.dataType(), this.shape(), order);
+        z.assign(this);
+        return z;
     }
 
     /**
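The Java-side change to BaseNDArray.dup() allocates an uninitialized array of the same data type and shape in the requested order and then assigns into it, instead of calling Shape.toOffsetZeroCopy, and it runs autoDecompress before the string (isS) branch so that branch also sees decompressed data. The copy-then-assign idea mirrors what the C++ ops above do for their non-inplace output; a rough C++ analogue built from the factory and assign calls used in the tests above (shapes and dtypes are illustrative assumptions):

    // allocate the target with the desired ordering, then copy values element-wise
    auto src  = NDArrayFactory::create<float>('c', {2, 3});
    src.linspace(1.0f);
    auto copy = NDArrayFactory::create<float>('f', {2, 3});  // same shape/dtype, different ordering
    copy.assign(src);                                        // value-wise copy; ordering is handled by assign
    // copy now compares equal to src, in the ASSERT_EQ(src, copy) sense used by the tests above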