- correct reshape op for empty shapes (#354)
* - correct reshape op for empty shape in case of -1 at the end Signed-off-by: Yurii <iuriish@yahoo.com> * Fix test + new reshape op constructor Signed-off-by: Alex Black <blacka101@gmail.com> Co-authored-by: Alex Black <blacka101@gmail.com>master
parent
81ebfeead1
commit
48102c61d0
|
@ -81,43 +81,79 @@ DECLARE_SHAPE_FN(reshape) {
|
|||
|
||||
REQUIRE_TRUE(!reshapeArgs.empty() || x->lengthOf() == 1, 0, "Reshape buffer should have at least 1 dimension !");
|
||||
|
||||
Nd4jLong xLen = x->lengthOf();
|
||||
if(x->isEmpty()) {
|
||||
xLen = 1;
|
||||
for (uint i = 0; i < x->rankOf(); ++i) // take into account possible empty shapes
|
||||
if(x->sizeAt(i) != 0)
|
||||
xLen *= x->sizeAt(i);
|
||||
// Nd4jLong xLen = x->lengthOf();
|
||||
// if(x->isEmpty()) {
|
||||
// xLen = 1;
|
||||
// for (uint i = 0; i < x->rankOf(); ++i) // take into account possible empty shapes
|
||||
// if(x->sizeAt(i) != 0)
|
||||
// xLen *= x->sizeAt(i);
|
||||
// }
|
||||
|
||||
// for (uint i = 0; i < reshapeArgs.size(); ++i) {
|
||||
|
||||
// if (reshapeArgs[i] == -1) {
|
||||
|
||||
// uint shapeLength = 1, numOfZeros = 0;
|
||||
|
||||
// for(uint j = 0; j < i; ++j)
|
||||
// if(reshapeArgs[j] != 0)
|
||||
// shapeLength *= reshapeArgs[j];
|
||||
// else
|
||||
// ++numOfZeros;
|
||||
|
||||
// for(uint j = i + 1; j < reshapeArgs.size(); ++j) {
|
||||
// REQUIRE_TRUE(reshapeArgs[j] != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
|
||||
// if(reshapeArgs[j] != 0)
|
||||
// shapeLength *= reshapeArgs[j];
|
||||
// else
|
||||
// ++numOfZeros;
|
||||
// }
|
||||
|
||||
// const auto dim = xLen / shapeLength;
|
||||
|
||||
// if(x->isEmpty() && (1 == dim || 0 == numOfZeros))
|
||||
// shapeNew.push_back(0);
|
||||
// else
|
||||
// shapeNew.push_back(dim);
|
||||
// }
|
||||
// else
|
||||
// shapeNew.push_back(reshapeArgs[i]);
|
||||
// }
|
||||
|
||||
Nd4jLong newShapeLen = 1;
|
||||
int pos = -1;
|
||||
bool newShapeEmpty = false;
|
||||
|
||||
for (int i = 0; i < reshapeArgs.size(); ++i) {
|
||||
|
||||
const int dim = reshapeArgs[i];
|
||||
|
||||
if (dim == -1) {
|
||||
REQUIRE_TRUE(pos == -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
|
||||
pos = i;
|
||||
shapeNew.push_back(1);
|
||||
}
|
||||
else if (dim == 0) {
|
||||
shapeNew.push_back(0);
|
||||
newShapeEmpty = true;
|
||||
}
|
||||
else {
|
||||
shapeNew.push_back(dim);
|
||||
newShapeLen *= dim;
|
||||
}
|
||||
}
|
||||
|
||||
for (uint i = 0; i < reshapeArgs.size(); ++i) {
|
||||
if (pos != -1) {
|
||||
|
||||
if (reshapeArgs[i] == -1) {
|
||||
|
||||
uint shapeLength = 1, numOfZeros = 0;
|
||||
|
||||
for(uint j = 0; j < i; ++j)
|
||||
if(reshapeArgs[j] != 0)
|
||||
shapeLength *= reshapeArgs[j];
|
||||
else
|
||||
++numOfZeros;
|
||||
|
||||
for(uint j = i + 1; j < reshapeArgs.size(); ++j) {
|
||||
REQUIRE_TRUE(reshapeArgs[j] != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
|
||||
if(reshapeArgs[j] != 0)
|
||||
shapeLength *= reshapeArgs[j];
|
||||
else
|
||||
++numOfZeros;
|
||||
}
|
||||
|
||||
const auto dim = xLen / shapeLength;
|
||||
|
||||
if(x->isEmpty() && (1 == dim || 0 == numOfZeros))
|
||||
shapeNew.push_back(0);
|
||||
else
|
||||
shapeNew.push_back(dim);
|
||||
Nd4jLong xLen = x->lengthOf();
|
||||
if(x->isEmpty()) {
|
||||
xLen = 1;
|
||||
for (uint i = 0; i < x->rankOf(); ++i) // take into account possible empty shapes
|
||||
if(x->sizeAt(i) > 0 || !newShapeEmpty)
|
||||
xLen *= x->sizeAt(i);
|
||||
}
|
||||
else
|
||||
shapeNew.push_back(reshapeArgs[i]);
|
||||
|
||||
shapeNew[pos] = xLen / newShapeLen;
|
||||
}
|
||||
|
||||
auto len = shape::prodLong(shapeNew.data(), shapeNew.size());
|
||||
|
@ -126,6 +162,8 @@ DECLARE_SHAPE_FN(reshape) {
|
|||
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(x->dataType(), orderNew, shapeNew));
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -2288,7 +2288,7 @@ TEST_F(DeclarableOpsTests14, Reshape15) {
|
|||
auto shape0 = NDArrayFactory::create<Nd4jLong>('c', {3}, {2, 0, -1});
|
||||
auto shape1 = NDArrayFactory::create<Nd4jLong>('c', {2}, {-1, 1});
|
||||
|
||||
auto e0 = NDArrayFactory::create<float>('c', {2, 0, 0});
|
||||
auto e0 = NDArrayFactory::create<float>('c', {2, 0, 1});
|
||||
auto e1 = NDArrayFactory::create<float>('c', {0, 1});
|
||||
|
||||
sd::ops::reshape op;
|
||||
|
@ -2374,6 +2374,7 @@ TEST_F(DeclarableOpsTests14, Reshape20) {
|
|||
NDArray x5('c', {0,2,10}, sd::DataType::FLOAT32);
|
||||
NDArray x6('c', {0,10,0}, sd::DataType::FLOAT32);
|
||||
NDArray x7('c', {0,1,2}, sd::DataType::FLOAT32);
|
||||
NDArray x8('c', {1,2,0}, sd::DataType::FLOAT32);
|
||||
|
||||
sd::ops::reshape op;
|
||||
|
||||
|
@ -2416,4 +2417,8 @@ TEST_F(DeclarableOpsTests14, Reshape20) {
|
|||
result = op.evaluate({&x7}, {}, {10,0,50,100});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
ASSERT_TRUE(result.at(0)->isSameShape({10,0,50,100}));
|
||||
|
||||
result = op.evaluate({&x7}, {}, {2,0,-1});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
ASSERT_TRUE(result.at(0)->isSameShape({2,0,1}));
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.nd4j.linalg.api.ops.impl.shape;
|
||||
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import onnx.Onnx;
|
||||
|
@ -55,8 +56,12 @@ public class Reshape extends DynamicCustomOp {
|
|||
super(null, sameDiff, new SDVariable[]{i_v, shape});
|
||||
}
|
||||
|
||||
public Reshape(INDArray in, INDArray shape, INDArray out){
|
||||
super(null, new INDArray[]{in, shape}, new INDArray[]{out}, null, (List<Integer>)null);
|
||||
public Reshape(INDArray in, INDArray shape){
|
||||
this(in, shape, null);
|
||||
}
|
||||
|
||||
public Reshape(@NonNull INDArray in, @NonNull INDArray shape, INDArray out){
|
||||
super(null, new INDArray[]{in, shape}, wrapOrNull(out), null, (List<Integer>)null);
|
||||
}
|
||||
|
||||
public Reshape(INDArray in, INDArray shape) {
|
||||
|
|
|
@ -8255,11 +8255,11 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
|||
INDArray arr0 = Nd4j.create(DataType.FLOAT, 2, 0);
|
||||
INDArray arr1 = Nd4j.create(DataType.FLOAT, 0, 1, 2);
|
||||
|
||||
INDArray out0 = Nd4j.exec(new Reshape(arr0, Nd4j.createFromArray(2, 0, -1), Nd4j.create(DataType.FLOAT, 2, 0, 0)))[0];
|
||||
INDArray out1 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(-1, 1), Nd4j.create(DataType.FLOAT, 0, 1)))[0];
|
||||
INDArray out2 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(10, -1), Nd4j.create(DataType.FLOAT, 10, 0)))[0];
|
||||
INDArray out0 = Nd4j.exec(new Reshape(arr0, Nd4j.createFromArray(2, 0, -1)))[0];
|
||||
INDArray out1 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(-1, 1)))[0];
|
||||
INDArray out2 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(10, -1)))[0];
|
||||
|
||||
assertArrayEquals(new long[]{2, 0, 0}, out0.shape());
|
||||
assertArrayEquals(new long[]{2, 0, 1}, out0.shape());
|
||||
assertArrayEquals(new long[]{0, 1}, out1.shape());
|
||||
assertArrayEquals(new long[]{10, 0}, out2.shape());
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue