diff --git a/libnd4j/include/ops/declarable/generic/shape/reshape.cpp b/libnd4j/include/ops/declarable/generic/shape/reshape.cpp index 5ac7686e2..023e9bf89 100644 --- a/libnd4j/include/ops/declarable/generic/shape/reshape.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/reshape.cpp @@ -81,43 +81,79 @@ DECLARE_SHAPE_FN(reshape) { REQUIRE_TRUE(!reshapeArgs.empty() || x->lengthOf() == 1, 0, "Reshape buffer should have at least 1 dimension !"); - Nd4jLong xLen = x->lengthOf(); - if(x->isEmpty()) { - xLen = 1; - for (uint i = 0; i < x->rankOf(); ++i) // take into account possible empty shapes - if(x->sizeAt(i) != 0) - xLen *= x->sizeAt(i); + // Nd4jLong xLen = x->lengthOf(); + // if(x->isEmpty()) { + // xLen = 1; + // for (uint i = 0; i < x->rankOf(); ++i) // take into account possible empty shapes + // if(x->sizeAt(i) != 0) + // xLen *= x->sizeAt(i); + // } + + // for (uint i = 0; i < reshapeArgs.size(); ++i) { + + // if (reshapeArgs[i] == -1) { + + // uint shapeLength = 1, numOfZeros = 0; + + // for(uint j = 0; j < i; ++j) + // if(reshapeArgs[j] != 0) + // shapeLength *= reshapeArgs[j]; + // else + // ++numOfZeros; + + // for(uint j = i + 1; j < reshapeArgs.size(); ++j) { + // REQUIRE_TRUE(reshapeArgs[j] != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed."); + // if(reshapeArgs[j] != 0) + // shapeLength *= reshapeArgs[j]; + // else + // ++numOfZeros; + // } + + // const auto dim = xLen / shapeLength; + + // if(x->isEmpty() && (1 == dim || 0 == numOfZeros)) + // shapeNew.push_back(0); + // else + // shapeNew.push_back(dim); + // } + // else + // shapeNew.push_back(reshapeArgs[i]); + // } + + Nd4jLong newShapeLen = 1; + int pos = -1; + bool newShapeEmpty = false; + + for (int i = 0; i < reshapeArgs.size(); ++i) { + + const int dim = reshapeArgs[i]; + + if (dim == -1) { + REQUIRE_TRUE(pos == -1, 0, "Reshape : Only one unknown dimension (-1) is allowed."); + pos = i; + shapeNew.push_back(1); + } + else if (dim == 0) { + shapeNew.push_back(0); + newShapeEmpty = true; + } + else { + shapeNew.push_back(dim); + newShapeLen *= dim; + } } - for (uint i = 0; i < reshapeArgs.size(); ++i) { + if (pos != -1) { - if (reshapeArgs[i] == -1) { - - uint shapeLength = 1, numOfZeros = 0; - - for(uint j = 0; j < i; ++j) - if(reshapeArgs[j] != 0) - shapeLength *= reshapeArgs[j]; - else - ++numOfZeros; - - for(uint j = i + 1; j < reshapeArgs.size(); ++j) { - REQUIRE_TRUE(reshapeArgs[j] != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed."); - if(reshapeArgs[j] != 0) - shapeLength *= reshapeArgs[j]; - else - ++numOfZeros; - } - - const auto dim = xLen / shapeLength; - - if(x->isEmpty() && (1 == dim || 0 == numOfZeros)) - shapeNew.push_back(0); - else - shapeNew.push_back(dim); + Nd4jLong xLen = x->lengthOf(); + if(x->isEmpty()) { + xLen = 1; + for (uint i = 0; i < x->rankOf(); ++i) // take into account possible empty shapes + if(x->sizeAt(i) > 0 || !newShapeEmpty) + xLen *= x->sizeAt(i); } - else - shapeNew.push_back(reshapeArgs[i]); + + shapeNew[pos] = xLen / newShapeLen; } auto len = shape::prodLong(shapeNew.data(), shapeNew.size()); @@ -126,6 +162,8 @@ DECLARE_SHAPE_FN(reshape) { return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(x->dataType(), orderNew, shapeNew)); } + + } } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp index db49c12f2..b4c9839ab 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp @@ -2288,7 +2288,7 @@ TEST_F(DeclarableOpsTests14, Reshape15) { auto shape0 = NDArrayFactory::create('c', {3}, {2, 0, -1}); auto shape1 = NDArrayFactory::create('c', {2}, {-1, 1}); - auto e0 = NDArrayFactory::create('c', {2, 0, 0}); + auto e0 = NDArrayFactory::create('c', {2, 0, 1}); auto e1 = NDArrayFactory::create('c', {0, 1}); sd::ops::reshape op; @@ -2374,6 +2374,7 @@ TEST_F(DeclarableOpsTests14, Reshape20) { NDArray x5('c', {0,2,10}, sd::DataType::FLOAT32); NDArray x6('c', {0,10,0}, sd::DataType::FLOAT32); NDArray x7('c', {0,1,2}, sd::DataType::FLOAT32); + NDArray x8('c', {1,2,0}, sd::DataType::FLOAT32); sd::ops::reshape op; @@ -2416,4 +2417,8 @@ TEST_F(DeclarableOpsTests14, Reshape20) { result = op.evaluate({&x7}, {}, {10,0,50,100}); ASSERT_EQ(ND4J_STATUS_OK, result.status()); ASSERT_TRUE(result.at(0)->isSameShape({10,0,50,100})); + + result = op.evaluate({&x7}, {}, {2,0,-1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0)->isSameShape({2,0,1})); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java index ddf0224db..c6b79b2b1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.shape; +import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; import onnx.Onnx; @@ -55,8 +56,12 @@ public class Reshape extends DynamicCustomOp { super(null, sameDiff, new SDVariable[]{i_v, shape}); } - public Reshape(INDArray in, INDArray shape, INDArray out){ - super(null, new INDArray[]{in, shape}, new INDArray[]{out}, null, (List)null); + public Reshape(INDArray in, INDArray shape){ + this(in, shape, null); + } + + public Reshape(@NonNull INDArray in, @NonNull INDArray shape, INDArray out){ + super(null, new INDArray[]{in, shape}, wrapOrNull(out), null, (List)null); } public Reshape(INDArray in, INDArray shape) { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index 5424d3c50..da91fb6cf 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -8255,11 +8255,11 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray arr0 = Nd4j.create(DataType.FLOAT, 2, 0); INDArray arr1 = Nd4j.create(DataType.FLOAT, 0, 1, 2); - INDArray out0 = Nd4j.exec(new Reshape(arr0, Nd4j.createFromArray(2, 0, -1), Nd4j.create(DataType.FLOAT, 2, 0, 0)))[0]; - INDArray out1 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(-1, 1), Nd4j.create(DataType.FLOAT, 0, 1)))[0]; - INDArray out2 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(10, -1), Nd4j.create(DataType.FLOAT, 10, 0)))[0]; + INDArray out0 = Nd4j.exec(new Reshape(arr0, Nd4j.createFromArray(2, 0, -1)))[0]; + INDArray out1 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(-1, 1)))[0]; + INDArray out2 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(10, -1)))[0]; - assertArrayEquals(new long[]{2, 0, 0}, out0.shape()); + assertArrayEquals(new long[]{2, 0, 1}, out0.shape()); assertArrayEquals(new long[]{0, 1}, out1.shape()); assertArrayEquals(new long[]{10, 0}, out2.shape()); }