- correct reshape op for empty shapes (#354)

* - correct reshape op for empty shape in case of -1 at the end

Signed-off-by: Yurii <iuriish@yahoo.com>

* Fix test + new reshape op constructor

Signed-off-by: Alex Black <blacka101@gmail.com>

Co-authored-by: Alex Black <blacka101@gmail.com>
master
Yurii Shyrma 2020-04-01 07:13:34 +03:00 committed by GitHub
parent 81ebfeead1
commit 48102c61d0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 88 additions and 40 deletions

View File

@ -81,43 +81,79 @@ DECLARE_SHAPE_FN(reshape) {
REQUIRE_TRUE(!reshapeArgs.empty() || x->lengthOf() == 1, 0, "Reshape buffer should have at least 1 dimension !"); REQUIRE_TRUE(!reshapeArgs.empty() || x->lengthOf() == 1, 0, "Reshape buffer should have at least 1 dimension !");
Nd4jLong xLen = x->lengthOf(); // Nd4jLong xLen = x->lengthOf();
if(x->isEmpty()) { // if(x->isEmpty()) {
xLen = 1; // xLen = 1;
for (uint i = 0; i < x->rankOf(); ++i) // take into account possible empty shapes // for (uint i = 0; i < x->rankOf(); ++i) // take into account possible empty shapes
if(x->sizeAt(i) != 0) // if(x->sizeAt(i) != 0)
xLen *= x->sizeAt(i); // xLen *= x->sizeAt(i);
// }
// for (uint i = 0; i < reshapeArgs.size(); ++i) {
// if (reshapeArgs[i] == -1) {
// uint shapeLength = 1, numOfZeros = 0;
// for(uint j = 0; j < i; ++j)
// if(reshapeArgs[j] != 0)
// shapeLength *= reshapeArgs[j];
// else
// ++numOfZeros;
// for(uint j = i + 1; j < reshapeArgs.size(); ++j) {
// REQUIRE_TRUE(reshapeArgs[j] != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
// if(reshapeArgs[j] != 0)
// shapeLength *= reshapeArgs[j];
// else
// ++numOfZeros;
// }
// const auto dim = xLen / shapeLength;
// if(x->isEmpty() && (1 == dim || 0 == numOfZeros))
// shapeNew.push_back(0);
// else
// shapeNew.push_back(dim);
// }
// else
// shapeNew.push_back(reshapeArgs[i]);
// }
Nd4jLong newShapeLen = 1;
int pos = -1;
bool newShapeEmpty = false;
for (int i = 0; i < reshapeArgs.size(); ++i) {
const int dim = reshapeArgs[i];
if (dim == -1) {
REQUIRE_TRUE(pos == -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
pos = i;
shapeNew.push_back(1);
}
else if (dim == 0) {
shapeNew.push_back(0);
newShapeEmpty = true;
}
else {
shapeNew.push_back(dim);
newShapeLen *= dim;
}
} }
for (uint i = 0; i < reshapeArgs.size(); ++i) { if (pos != -1) {
if (reshapeArgs[i] == -1) { Nd4jLong xLen = x->lengthOf();
if(x->isEmpty()) {
uint shapeLength = 1, numOfZeros = 0; xLen = 1;
for (uint i = 0; i < x->rankOf(); ++i) // take into account possible empty shapes
for(uint j = 0; j < i; ++j) if(x->sizeAt(i) > 0 || !newShapeEmpty)
if(reshapeArgs[j] != 0) xLen *= x->sizeAt(i);
shapeLength *= reshapeArgs[j];
else
++numOfZeros;
for(uint j = i + 1; j < reshapeArgs.size(); ++j) {
REQUIRE_TRUE(reshapeArgs[j] != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
if(reshapeArgs[j] != 0)
shapeLength *= reshapeArgs[j];
else
++numOfZeros;
}
const auto dim = xLen / shapeLength;
if(x->isEmpty() && (1 == dim || 0 == numOfZeros))
shapeNew.push_back(0);
else
shapeNew.push_back(dim);
} }
else
shapeNew.push_back(reshapeArgs[i]); shapeNew[pos] = xLen / newShapeLen;
} }
auto len = shape::prodLong(shapeNew.data(), shapeNew.size()); auto len = shape::prodLong(shapeNew.data(), shapeNew.size());
@ -126,6 +162,8 @@ DECLARE_SHAPE_FN(reshape) {
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(x->dataType(), orderNew, shapeNew)); return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(x->dataType(), orderNew, shapeNew));
} }
} }
} }

View File

@ -2288,7 +2288,7 @@ TEST_F(DeclarableOpsTests14, Reshape15) {
auto shape0 = NDArrayFactory::create<Nd4jLong>('c', {3}, {2, 0, -1}); auto shape0 = NDArrayFactory::create<Nd4jLong>('c', {3}, {2, 0, -1});
auto shape1 = NDArrayFactory::create<Nd4jLong>('c', {2}, {-1, 1}); auto shape1 = NDArrayFactory::create<Nd4jLong>('c', {2}, {-1, 1});
auto e0 = NDArrayFactory::create<float>('c', {2, 0, 0}); auto e0 = NDArrayFactory::create<float>('c', {2, 0, 1});
auto e1 = NDArrayFactory::create<float>('c', {0, 1}); auto e1 = NDArrayFactory::create<float>('c', {0, 1});
sd::ops::reshape op; sd::ops::reshape op;
@ -2374,6 +2374,7 @@ TEST_F(DeclarableOpsTests14, Reshape20) {
NDArray x5('c', {0,2,10}, sd::DataType::FLOAT32); NDArray x5('c', {0,2,10}, sd::DataType::FLOAT32);
NDArray x6('c', {0,10,0}, sd::DataType::FLOAT32); NDArray x6('c', {0,10,0}, sd::DataType::FLOAT32);
NDArray x7('c', {0,1,2}, sd::DataType::FLOAT32); NDArray x7('c', {0,1,2}, sd::DataType::FLOAT32);
NDArray x8('c', {1,2,0}, sd::DataType::FLOAT32);
sd::ops::reshape op; sd::ops::reshape op;
@ -2416,4 +2417,8 @@ TEST_F(DeclarableOpsTests14, Reshape20) {
result = op.evaluate({&x7}, {}, {10,0,50,100}); result = op.evaluate({&x7}, {}, {10,0,50,100});
ASSERT_EQ(ND4J_STATUS_OK, result.status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
ASSERT_TRUE(result.at(0)->isSameShape({10,0,50,100})); ASSERT_TRUE(result.at(0)->isSameShape({10,0,50,100}));
result = op.evaluate({&x7}, {}, {2,0,-1});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
ASSERT_TRUE(result.at(0)->isSameShape({2,0,1}));
} }

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.shape; package org.nd4j.linalg.api.ops.impl.shape;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import lombok.val; import lombok.val;
import onnx.Onnx; import onnx.Onnx;
@ -55,8 +56,12 @@ public class Reshape extends DynamicCustomOp {
super(null, sameDiff, new SDVariable[]{i_v, shape}); super(null, sameDiff, new SDVariable[]{i_v, shape});
} }
public Reshape(INDArray in, INDArray shape, INDArray out){ public Reshape(INDArray in, INDArray shape){
super(null, new INDArray[]{in, shape}, new INDArray[]{out}, null, (List<Integer>)null); this(in, shape, null);
}
public Reshape(@NonNull INDArray in, @NonNull INDArray shape, INDArray out){
super(null, new INDArray[]{in, shape}, wrapOrNull(out), null, (List<Integer>)null);
} }
public Reshape(INDArray in, INDArray shape) { public Reshape(INDArray in, INDArray shape) {

View File

@ -8255,11 +8255,11 @@ public class Nd4jTestsC extends BaseNd4jTest {
INDArray arr0 = Nd4j.create(DataType.FLOAT, 2, 0); INDArray arr0 = Nd4j.create(DataType.FLOAT, 2, 0);
INDArray arr1 = Nd4j.create(DataType.FLOAT, 0, 1, 2); INDArray arr1 = Nd4j.create(DataType.FLOAT, 0, 1, 2);
INDArray out0 = Nd4j.exec(new Reshape(arr0, Nd4j.createFromArray(2, 0, -1), Nd4j.create(DataType.FLOAT, 2, 0, 0)))[0]; INDArray out0 = Nd4j.exec(new Reshape(arr0, Nd4j.createFromArray(2, 0, -1)))[0];
INDArray out1 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(-1, 1), Nd4j.create(DataType.FLOAT, 0, 1)))[0]; INDArray out1 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(-1, 1)))[0];
INDArray out2 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(10, -1), Nd4j.create(DataType.FLOAT, 10, 0)))[0]; INDArray out2 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(10, -1)))[0];
assertArrayEquals(new long[]{2, 0, 0}, out0.shape()); assertArrayEquals(new long[]{2, 0, 1}, out0.shape());
assertArrayEquals(new long[]{0, 1}, out1.shape()); assertArrayEquals(new long[]{0, 1}, out1.shape());
assertArrayEquals(new long[]{10, 0}, out2.shape()); assertArrayEquals(new long[]{10, 0}, out2.shape());
} }