parent
1f9e1b6022
commit
62f93ac211
|
@ -174,13 +174,6 @@ namespace nd4j {
|
|||
e = 0;
|
||||
}
|
||||
|
||||
// //Special case: empty.reshape(-1) -> return empty
|
||||
// if (INPUT_VARIABLE(0)->isEmpty()) {
|
||||
// //
|
||||
// auto newShape = ConstantShapeHelper::getInstance()->emptyShapeInfo(ArrayOptions::dataType(inp));
|
||||
// return SHAPELIST(newShape);
|
||||
// }
|
||||
|
||||
std::vector<Nd4jLong> shapeNew;
|
||||
|
||||
int e2 = e;
|
||||
|
@ -226,11 +219,25 @@ namespace nd4j {
|
|||
//REQUIRE_TRUE(y->lengthOf() == 1 && y->e<Nd4jLong>(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]");
|
||||
auto shapeOf = y->getBufferAsVector<Nd4jLong>();
|
||||
Nd4jLong prod = 1;
|
||||
for (auto v:shapeOf)
|
||||
bool hasNegs = false;
|
||||
for (auto v:shapeOf) {
|
||||
if (v < 0) {
|
||||
hasNegs = true;
|
||||
v = 0;
|
||||
}
|
||||
|
||||
prod *= v;
|
||||
}
|
||||
|
||||
REQUIRE_TRUE(prod == 0, 0, "Reshape: in case of empty arrays reshape must return empty array as well");
|
||||
|
||||
// if there are -1s - we turn them into zeros
|
||||
if (hasNegs) {
|
||||
for (int e = 0; e < shapeOf.size(); e++)
|
||||
if (shapeOf[e] < 0)
|
||||
shapeOf[e] = 0;
|
||||
}
|
||||
|
||||
auto newShape = ShapeBuilders::createShapeInfo(ArrayOptions::dataType(inp), shape::order(inp), y->lengthOf(), shapeOf.data());
|
||||
return SHAPELIST(CONSTANT(newShape));
|
||||
}
|
||||
|
|
|
@ -287,3 +287,40 @@ TEST_F(EmptyTests, test_shaped_empty_4) {
|
|||
ASSERT_EQ(1, array.rankOf());
|
||||
ASSERT_EQ(shapeOf, array.getShapeAsVector());
|
||||
}
|
||||
|
||||
TEST_F(EmptyTests, test_empty_reshape_1) {
|
||||
/*
|
||||
INDArray arr0 = Nd4j.create(DataType.FLOAT, 2, 0);
|
||||
INDArray arr1 = Nd4j.create(DataType.FLOAT, 0, 1, 2);
|
||||
|
||||
INDArray out0 = Nd4j.exec(new Reshape(arr0, Nd4j.createFromArray(2, 0, -1), Nd4j.create(DataType.FLOAT, 2, 0, 0)))[0];
|
||||
INDArray out1 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(-1, 1), Nd4j.create(DataType.FLOAT, 0, 1)))[0];
|
||||
INDArray out2 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(10, -1), Nd4j.create(DataType.FLOAT, 10, 0)))[0];
|
||||
|
||||
assertArrayEquals(new long[]{2, 0, 0}, out0.shape());
|
||||
assertArrayEquals(new long[]{0, 1}, out1.shape());
|
||||
assertArrayEquals(new long[]{10, 0}, out2.shape());
|
||||
*/
|
||||
auto x0 = NDArrayFactory::create<float>('c', {2, 0});
|
||||
auto x1 = NDArrayFactory::create<float>('c', {0, 1, 2});
|
||||
|
||||
auto shape0 = NDArrayFactory::create<Nd4jLong>('c', {3}, {2, 0, -1});
|
||||
auto shape1 = NDArrayFactory::create<Nd4jLong>('c', {2}, {-1, 1});
|
||||
|
||||
auto e0 = NDArrayFactory::create<float>('c', {2, 0, 0});
|
||||
auto e1 = NDArrayFactory::create<float>('c', {0, 1});
|
||||
|
||||
nd4j::ops::reshape op;
|
||||
auto result0 = op.execute({&x0, &shape0}, {}, {});
|
||||
ASSERT_EQ(Status::OK(), result0->status());
|
||||
auto z0 = result0->at(0);
|
||||
ASSERT_EQ(e0, *z0);
|
||||
|
||||
auto result1 = op.execute({&x1, &shape1}, {}, {});
|
||||
ASSERT_EQ(Status::OK(), result1->status());
|
||||
auto z1 = result1->at(0);
|
||||
ASSERT_EQ(e1, *z1);
|
||||
|
||||
delete result0;
|
||||
delete result1;
|
||||
}
|
Loading…
Reference in New Issue