parent
1f9e1b6022
commit
62f93ac211
|
@ -173,13 +173,6 @@ namespace nd4j {
|
||||||
order = shape::order(inp);
|
order = shape::order(inp);
|
||||||
e = 0;
|
e = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
// //Special case: empty.reshape(-1) -> return empty
|
|
||||||
// if (INPUT_VARIABLE(0)->isEmpty()) {
|
|
||||||
// //
|
|
||||||
// auto newShape = ConstantShapeHelper::getInstance()->emptyShapeInfo(ArrayOptions::dataType(inp));
|
|
||||||
// return SHAPELIST(newShape);
|
|
||||||
// }
|
|
||||||
|
|
||||||
std::vector<Nd4jLong> shapeNew;
|
std::vector<Nd4jLong> shapeNew;
|
||||||
|
|
||||||
|
@ -226,11 +219,25 @@ namespace nd4j {
|
||||||
//REQUIRE_TRUE(y->lengthOf() == 1 && y->e<Nd4jLong>(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]");
|
//REQUIRE_TRUE(y->lengthOf() == 1 && y->e<Nd4jLong>(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]");
|
||||||
auto shapeOf = y->getBufferAsVector<Nd4jLong>();
|
auto shapeOf = y->getBufferAsVector<Nd4jLong>();
|
||||||
Nd4jLong prod = 1;
|
Nd4jLong prod = 1;
|
||||||
for (auto v:shapeOf)
|
bool hasNegs = false;
|
||||||
|
for (auto v:shapeOf) {
|
||||||
|
if (v < 0) {
|
||||||
|
hasNegs = true;
|
||||||
|
v = 0;
|
||||||
|
}
|
||||||
|
|
||||||
prod *= v;
|
prod *= v;
|
||||||
|
}
|
||||||
|
|
||||||
REQUIRE_TRUE(prod == 0, 0, "Reshape: in case of empty arrays reshape must return empty array as well");
|
REQUIRE_TRUE(prod == 0, 0, "Reshape: in case of empty arrays reshape must return empty array as well");
|
||||||
|
|
||||||
|
// if there are -1s - we turn them into zeros
|
||||||
|
if (hasNegs) {
|
||||||
|
for (int e = 0; e < shapeOf.size(); e++)
|
||||||
|
if (shapeOf[e] < 0)
|
||||||
|
shapeOf[e] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
auto newShape = ShapeBuilders::createShapeInfo(ArrayOptions::dataType(inp), shape::order(inp), y->lengthOf(), shapeOf.data());
|
auto newShape = ShapeBuilders::createShapeInfo(ArrayOptions::dataType(inp), shape::order(inp), y->lengthOf(), shapeOf.data());
|
||||||
return SHAPELIST(CONSTANT(newShape));
|
return SHAPELIST(CONSTANT(newShape));
|
||||||
}
|
}
|
||||||
|
|
|
@ -286,4 +286,41 @@ TEST_F(EmptyTests, test_shaped_empty_4) {
|
||||||
ASSERT_TRUE(array.isEmpty());
|
ASSERT_TRUE(array.isEmpty());
|
||||||
ASSERT_EQ(1, array.rankOf());
|
ASSERT_EQ(1, array.rankOf());
|
||||||
ASSERT_EQ(shapeOf, array.getShapeAsVector());
|
ASSERT_EQ(shapeOf, array.getShapeAsVector());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(EmptyTests, test_empty_reshape_1) {
|
||||||
|
/*
|
||||||
|
INDArray arr0 = Nd4j.create(DataType.FLOAT, 2, 0);
|
||||||
|
INDArray arr1 = Nd4j.create(DataType.FLOAT, 0, 1, 2);
|
||||||
|
|
||||||
|
INDArray out0 = Nd4j.exec(new Reshape(arr0, Nd4j.createFromArray(2, 0, -1), Nd4j.create(DataType.FLOAT, 2, 0, 0)))[0];
|
||||||
|
INDArray out1 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(-1, 1), Nd4j.create(DataType.FLOAT, 0, 1)))[0];
|
||||||
|
INDArray out2 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(10, -1), Nd4j.create(DataType.FLOAT, 10, 0)))[0];
|
||||||
|
|
||||||
|
assertArrayEquals(new long[]{2, 0, 0}, out0.shape());
|
||||||
|
assertArrayEquals(new long[]{0, 1}, out1.shape());
|
||||||
|
assertArrayEquals(new long[]{10, 0}, out2.shape());
|
||||||
|
*/
|
||||||
|
auto x0 = NDArrayFactory::create<float>('c', {2, 0});
|
||||||
|
auto x1 = NDArrayFactory::create<float>('c', {0, 1, 2});
|
||||||
|
|
||||||
|
auto shape0 = NDArrayFactory::create<Nd4jLong>('c', {3}, {2, 0, -1});
|
||||||
|
auto shape1 = NDArrayFactory::create<Nd4jLong>('c', {2}, {-1, 1});
|
||||||
|
|
||||||
|
auto e0 = NDArrayFactory::create<float>('c', {2, 0, 0});
|
||||||
|
auto e1 = NDArrayFactory::create<float>('c', {0, 1});
|
||||||
|
|
||||||
|
nd4j::ops::reshape op;
|
||||||
|
auto result0 = op.execute({&x0, &shape0}, {}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result0->status());
|
||||||
|
auto z0 = result0->at(0);
|
||||||
|
ASSERT_EQ(e0, *z0);
|
||||||
|
|
||||||
|
auto result1 = op.execute({&x1, &shape1}, {}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result1->status());
|
||||||
|
auto z1 = result1->at(0);
|
||||||
|
ASSERT_EQ(e1, *z1);
|
||||||
|
|
||||||
|
delete result0;
|
||||||
|
delete result1;
|
||||||
}
|
}
|
Loading…
Reference in New Issue