- correct reshape op for empty shapes (#354)
* - correct reshape op for empty shape in case of -1 at the end Signed-off-by: Yurii <iuriish@yahoo.com> * Fix test + new reshape op constructor Signed-off-by: Alex Black <blacka101@gmail.com> Co-authored-by: Alex Black <blacka101@gmail.com>master
parent
81ebfeead1
commit
48102c61d0
|
@ -81,43 +81,79 @@ DECLARE_SHAPE_FN(reshape) {
|
||||||
|
|
||||||
REQUIRE_TRUE(!reshapeArgs.empty() || x->lengthOf() == 1, 0, "Reshape buffer should have at least 1 dimension !");
|
REQUIRE_TRUE(!reshapeArgs.empty() || x->lengthOf() == 1, 0, "Reshape buffer should have at least 1 dimension !");
|
||||||
|
|
||||||
|
// Nd4jLong xLen = x->lengthOf();
|
||||||
|
// if(x->isEmpty()) {
|
||||||
|
// xLen = 1;
|
||||||
|
// for (uint i = 0; i < x->rankOf(); ++i) // take into account possible empty shapes
|
||||||
|
// if(x->sizeAt(i) != 0)
|
||||||
|
// xLen *= x->sizeAt(i);
|
||||||
|
// }
|
||||||
|
|
||||||
|
// for (uint i = 0; i < reshapeArgs.size(); ++i) {
|
||||||
|
|
||||||
|
// if (reshapeArgs[i] == -1) {
|
||||||
|
|
||||||
|
// uint shapeLength = 1, numOfZeros = 0;
|
||||||
|
|
||||||
|
// for(uint j = 0; j < i; ++j)
|
||||||
|
// if(reshapeArgs[j] != 0)
|
||||||
|
// shapeLength *= reshapeArgs[j];
|
||||||
|
// else
|
||||||
|
// ++numOfZeros;
|
||||||
|
|
||||||
|
// for(uint j = i + 1; j < reshapeArgs.size(); ++j) {
|
||||||
|
// REQUIRE_TRUE(reshapeArgs[j] != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
|
||||||
|
// if(reshapeArgs[j] != 0)
|
||||||
|
// shapeLength *= reshapeArgs[j];
|
||||||
|
// else
|
||||||
|
// ++numOfZeros;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// const auto dim = xLen / shapeLength;
|
||||||
|
|
||||||
|
// if(x->isEmpty() && (1 == dim || 0 == numOfZeros))
|
||||||
|
// shapeNew.push_back(0);
|
||||||
|
// else
|
||||||
|
// shapeNew.push_back(dim);
|
||||||
|
// }
|
||||||
|
// else
|
||||||
|
// shapeNew.push_back(reshapeArgs[i]);
|
||||||
|
// }
|
||||||
|
|
||||||
|
Nd4jLong newShapeLen = 1;
|
||||||
|
int pos = -1;
|
||||||
|
bool newShapeEmpty = false;
|
||||||
|
|
||||||
|
for (int i = 0; i < reshapeArgs.size(); ++i) {
|
||||||
|
|
||||||
|
const int dim = reshapeArgs[i];
|
||||||
|
|
||||||
|
if (dim == -1) {
|
||||||
|
REQUIRE_TRUE(pos == -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
|
||||||
|
pos = i;
|
||||||
|
shapeNew.push_back(1);
|
||||||
|
}
|
||||||
|
else if (dim == 0) {
|
||||||
|
shapeNew.push_back(0);
|
||||||
|
newShapeEmpty = true;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
shapeNew.push_back(dim);
|
||||||
|
newShapeLen *= dim;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (pos != -1) {
|
||||||
|
|
||||||
Nd4jLong xLen = x->lengthOf();
|
Nd4jLong xLen = x->lengthOf();
|
||||||
if(x->isEmpty()) {
|
if(x->isEmpty()) {
|
||||||
xLen = 1;
|
xLen = 1;
|
||||||
for (uint i = 0; i < x->rankOf(); ++i) // take into account possible empty shapes
|
for (uint i = 0; i < x->rankOf(); ++i) // take into account possible empty shapes
|
||||||
if(x->sizeAt(i) != 0)
|
if(x->sizeAt(i) > 0 || !newShapeEmpty)
|
||||||
xLen *= x->sizeAt(i);
|
xLen *= x->sizeAt(i);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (uint i = 0; i < reshapeArgs.size(); ++i) {
|
shapeNew[pos] = xLen / newShapeLen;
|
||||||
|
|
||||||
if (reshapeArgs[i] == -1) {
|
|
||||||
|
|
||||||
uint shapeLength = 1, numOfZeros = 0;
|
|
||||||
|
|
||||||
for(uint j = 0; j < i; ++j)
|
|
||||||
if(reshapeArgs[j] != 0)
|
|
||||||
shapeLength *= reshapeArgs[j];
|
|
||||||
else
|
|
||||||
++numOfZeros;
|
|
||||||
|
|
||||||
for(uint j = i + 1; j < reshapeArgs.size(); ++j) {
|
|
||||||
REQUIRE_TRUE(reshapeArgs[j] != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
|
|
||||||
if(reshapeArgs[j] != 0)
|
|
||||||
shapeLength *= reshapeArgs[j];
|
|
||||||
else
|
|
||||||
++numOfZeros;
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto dim = xLen / shapeLength;
|
|
||||||
|
|
||||||
if(x->isEmpty() && (1 == dim || 0 == numOfZeros))
|
|
||||||
shapeNew.push_back(0);
|
|
||||||
else
|
|
||||||
shapeNew.push_back(dim);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
shapeNew.push_back(reshapeArgs[i]);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
auto len = shape::prodLong(shapeNew.data(), shapeNew.size());
|
auto len = shape::prodLong(shapeNew.data(), shapeNew.size());
|
||||||
|
@ -126,6 +162,8 @@ DECLARE_SHAPE_FN(reshape) {
|
||||||
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(x->dataType(), orderNew, shapeNew));
|
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(x->dataType(), orderNew, shapeNew));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2288,7 +2288,7 @@ TEST_F(DeclarableOpsTests14, Reshape15) {
|
||||||
auto shape0 = NDArrayFactory::create<Nd4jLong>('c', {3}, {2, 0, -1});
|
auto shape0 = NDArrayFactory::create<Nd4jLong>('c', {3}, {2, 0, -1});
|
||||||
auto shape1 = NDArrayFactory::create<Nd4jLong>('c', {2}, {-1, 1});
|
auto shape1 = NDArrayFactory::create<Nd4jLong>('c', {2}, {-1, 1});
|
||||||
|
|
||||||
auto e0 = NDArrayFactory::create<float>('c', {2, 0, 0});
|
auto e0 = NDArrayFactory::create<float>('c', {2, 0, 1});
|
||||||
auto e1 = NDArrayFactory::create<float>('c', {0, 1});
|
auto e1 = NDArrayFactory::create<float>('c', {0, 1});
|
||||||
|
|
||||||
sd::ops::reshape op;
|
sd::ops::reshape op;
|
||||||
|
@ -2374,6 +2374,7 @@ TEST_F(DeclarableOpsTests14, Reshape20) {
|
||||||
NDArray x5('c', {0,2,10}, sd::DataType::FLOAT32);
|
NDArray x5('c', {0,2,10}, sd::DataType::FLOAT32);
|
||||||
NDArray x6('c', {0,10,0}, sd::DataType::FLOAT32);
|
NDArray x6('c', {0,10,0}, sd::DataType::FLOAT32);
|
||||||
NDArray x7('c', {0,1,2}, sd::DataType::FLOAT32);
|
NDArray x7('c', {0,1,2}, sd::DataType::FLOAT32);
|
||||||
|
NDArray x8('c', {1,2,0}, sd::DataType::FLOAT32);
|
||||||
|
|
||||||
sd::ops::reshape op;
|
sd::ops::reshape op;
|
||||||
|
|
||||||
|
@ -2416,4 +2417,8 @@ TEST_F(DeclarableOpsTests14, Reshape20) {
|
||||||
result = op.evaluate({&x7}, {}, {10,0,50,100});
|
result = op.evaluate({&x7}, {}, {10,0,50,100});
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
ASSERT_TRUE(result.at(0)->isSameShape({10,0,50,100}));
|
ASSERT_TRUE(result.at(0)->isSameShape({10,0,50,100}));
|
||||||
|
|
||||||
|
result = op.evaluate({&x7}, {}, {2,0,-1});
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||||
|
ASSERT_TRUE(result.at(0)->isSameShape({2,0,1}));
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.shape;
|
package org.nd4j.linalg.api.ops.impl.shape;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import onnx.Onnx;
|
import onnx.Onnx;
|
||||||
|
@ -55,8 +56,12 @@ public class Reshape extends DynamicCustomOp {
|
||||||
super(null, sameDiff, new SDVariable[]{i_v, shape});
|
super(null, sameDiff, new SDVariable[]{i_v, shape});
|
||||||
}
|
}
|
||||||
|
|
||||||
public Reshape(INDArray in, INDArray shape, INDArray out){
|
public Reshape(INDArray in, INDArray shape){
|
||||||
super(null, new INDArray[]{in, shape}, new INDArray[]{out}, null, (List<Integer>)null);
|
this(in, shape, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Reshape(@NonNull INDArray in, @NonNull INDArray shape, INDArray out){
|
||||||
|
super(null, new INDArray[]{in, shape}, wrapOrNull(out), null, (List<Integer>)null);
|
||||||
}
|
}
|
||||||
|
|
||||||
public Reshape(INDArray in, INDArray shape) {
|
public Reshape(INDArray in, INDArray shape) {
|
||||||
|
|
|
@ -8255,11 +8255,11 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
INDArray arr0 = Nd4j.create(DataType.FLOAT, 2, 0);
|
INDArray arr0 = Nd4j.create(DataType.FLOAT, 2, 0);
|
||||||
INDArray arr1 = Nd4j.create(DataType.FLOAT, 0, 1, 2);
|
INDArray arr1 = Nd4j.create(DataType.FLOAT, 0, 1, 2);
|
||||||
|
|
||||||
INDArray out0 = Nd4j.exec(new Reshape(arr0, Nd4j.createFromArray(2, 0, -1), Nd4j.create(DataType.FLOAT, 2, 0, 0)))[0];
|
INDArray out0 = Nd4j.exec(new Reshape(arr0, Nd4j.createFromArray(2, 0, -1)))[0];
|
||||||
INDArray out1 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(-1, 1), Nd4j.create(DataType.FLOAT, 0, 1)))[0];
|
INDArray out1 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(-1, 1)))[0];
|
||||||
INDArray out2 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(10, -1), Nd4j.create(DataType.FLOAT, 10, 0)))[0];
|
INDArray out2 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(10, -1)))[0];
|
||||||
|
|
||||||
assertArrayEquals(new long[]{2, 0, 0}, out0.shape());
|
assertArrayEquals(new long[]{2, 0, 1}, out0.shape());
|
||||||
assertArrayEquals(new long[]{0, 1}, out1.shape());
|
assertArrayEquals(new long[]{0, 1}, out1.shape());
|
||||||
assertArrayEquals(new long[]{10, 0}, out2.shape());
|
assertArrayEquals(new long[]{10, 0}, out2.shape());
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue