From 62f93ac211e0a80c87b2fe1ca5cfe05408455fa3 Mon Sep 17 00:00:00 2001 From: raver119 Date: Tue, 24 Dec 2019 13:23:25 +0300 Subject: [PATCH] negative handling for empty arrays (#146) Signed-off-by: raver119 --- .../ops/declarable/generic/shape/reshape.cpp | 23 ++++++++---- libnd4j/tests_cpu/layers_tests/EmptyTests.cpp | 37 +++++++++++++++++++ 2 files changed, 52 insertions(+), 8 deletions(-) diff --git a/libnd4j/include/ops/declarable/generic/shape/reshape.cpp b/libnd4j/include/ops/declarable/generic/shape/reshape.cpp index 0bc80fa91..1d76138f2 100644 --- a/libnd4j/include/ops/declarable/generic/shape/reshape.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/reshape.cpp @@ -173,13 +173,6 @@ namespace nd4j { order = shape::order(inp); e = 0; } - -// //Special case: empty.reshape(-1) -> return empty -// if (INPUT_VARIABLE(0)->isEmpty()) { -// // -// auto newShape = ConstantShapeHelper::getInstance()->emptyShapeInfo(ArrayOptions::dataType(inp)); -// return SHAPELIST(newShape); -// } std::vector shapeNew; @@ -226,11 +219,25 @@ namespace nd4j { //REQUIRE_TRUE(y->lengthOf() == 1 && y->e(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]"); auto shapeOf = y->getBufferAsVector(); Nd4jLong prod = 1; - for (auto v:shapeOf) + bool hasNegs = false; + for (auto v:shapeOf) { + if (v < 0) { + hasNegs = true; + v = 0; + } + prod *= v; + } REQUIRE_TRUE(prod == 0, 0, "Reshape: in case of empty arrays reshape must return empty array as well"); + // if there are -1s - we turn them into zeros + if (hasNegs) { + for (int e = 0; e < shapeOf.size(); e++) + if (shapeOf[e] < 0) + shapeOf[e] = 0; + } + auto newShape = ShapeBuilders::createShapeInfo(ArrayOptions::dataType(inp), shape::order(inp), y->lengthOf(), shapeOf.data()); return SHAPELIST(CONSTANT(newShape)); } diff --git a/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp b/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp index ca1479210..f17c5aa5a 100644 --- a/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp @@ -286,4 +286,41 @@ TEST_F(EmptyTests, test_shaped_empty_4) { ASSERT_TRUE(array.isEmpty()); ASSERT_EQ(1, array.rankOf()); ASSERT_EQ(shapeOf, array.getShapeAsVector()); +} + +TEST_F(EmptyTests, test_empty_reshape_1) { + /* + INDArray arr0 = Nd4j.create(DataType.FLOAT, 2, 0); + INDArray arr1 = Nd4j.create(DataType.FLOAT, 0, 1, 2); + + INDArray out0 = Nd4j.exec(new Reshape(arr0, Nd4j.createFromArray(2, 0, -1), Nd4j.create(DataType.FLOAT, 2, 0, 0)))[0]; + INDArray out1 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(-1, 1), Nd4j.create(DataType.FLOAT, 0, 1)))[0]; + INDArray out2 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(10, -1), Nd4j.create(DataType.FLOAT, 10, 0)))[0]; + + assertArrayEquals(new long[]{2, 0, 0}, out0.shape()); + assertArrayEquals(new long[]{0, 1}, out1.shape()); + assertArrayEquals(new long[]{10, 0}, out2.shape()); + */ + auto x0 = NDArrayFactory::create('c', {2, 0}); + auto x1 = NDArrayFactory::create('c', {0, 1, 2}); + + auto shape0 = NDArrayFactory::create('c', {3}, {2, 0, -1}); + auto shape1 = NDArrayFactory::create('c', {2}, {-1, 1}); + + auto e0 = NDArrayFactory::create('c', {2, 0, 0}); + auto e1 = NDArrayFactory::create('c', {0, 1}); + + nd4j::ops::reshape op; + auto result0 = op.execute({&x0, &shape0}, {}, {}); + ASSERT_EQ(Status::OK(), result0->status()); + auto z0 = result0->at(0); + ASSERT_EQ(e0, *z0); + + auto result1 = op.execute({&x1, &shape1}, {}, {}); + ASSERT_EQ(Status::OK(), result1->status()); + auto z1 = result1->at(0); + ASSERT_EQ(e1, *z1); + + delete result0; + delete result1; } \ No newline at end of file