negative handling for empty arrays (#146)

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-12-24 13:23:25 +03:00 committed by GitHub
parent 1f9e1b6022
commit 62f93ac211
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 52 additions and 8 deletions

View File

@ -173,13 +173,6 @@ namespace nd4j {
order = shape::order(inp);
e = 0;
}
// //Special case: empty.reshape(-1) -> return empty
// if (INPUT_VARIABLE(0)->isEmpty()) {
// //
// auto newShape = ConstantShapeHelper::getInstance()->emptyShapeInfo(ArrayOptions::dataType(inp));
// return SHAPELIST(newShape);
// }
std::vector<Nd4jLong> shapeNew;
@ -226,11 +219,25 @@ namespace nd4j {
//REQUIRE_TRUE(y->lengthOf() == 1 && y->e<Nd4jLong>(0) == -1, 0, "Reshape: when input is empty, shape must be [-1]");
auto shapeOf = y->getBufferAsVector<Nd4jLong>();
Nd4jLong prod = 1;
for (auto v:shapeOf)
bool hasNegs = false;
for (auto v:shapeOf) {
if (v < 0) {
hasNegs = true;
v = 0;
}
prod *= v;
}
REQUIRE_TRUE(prod == 0, 0, "Reshape: in case of empty arrays reshape must return empty array as well");
// if there are -1s - we turn them into zeros
if (hasNegs) {
for (int e = 0; e < shapeOf.size(); e++)
if (shapeOf[e] < 0)
shapeOf[e] = 0;
}
auto newShape = ShapeBuilders::createShapeInfo(ArrayOptions::dataType(inp), shape::order(inp), y->lengthOf(), shapeOf.data());
return SHAPELIST(CONSTANT(newShape));
}

View File

@ -286,4 +286,41 @@ TEST_F(EmptyTests, test_shaped_empty_4) {
ASSERT_TRUE(array.isEmpty());
ASSERT_EQ(1, array.rankOf());
ASSERT_EQ(shapeOf, array.getShapeAsVector());
}
TEST_F(EmptyTests, test_empty_reshape_1) {
/*
INDArray arr0 = Nd4j.create(DataType.FLOAT, 2, 0);
INDArray arr1 = Nd4j.create(DataType.FLOAT, 0, 1, 2);
INDArray out0 = Nd4j.exec(new Reshape(arr0, Nd4j.createFromArray(2, 0, -1), Nd4j.create(DataType.FLOAT, 2, 0, 0)))[0];
INDArray out1 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(-1, 1), Nd4j.create(DataType.FLOAT, 0, 1)))[0];
INDArray out2 = Nd4j.exec(new Reshape(arr1, Nd4j.createFromArray(10, -1), Nd4j.create(DataType.FLOAT, 10, 0)))[0];
assertArrayEquals(new long[]{2, 0, 0}, out0.shape());
assertArrayEquals(new long[]{0, 1}, out1.shape());
assertArrayEquals(new long[]{10, 0}, out2.shape());
*/
auto x0 = NDArrayFactory::create<float>('c', {2, 0});
auto x1 = NDArrayFactory::create<float>('c', {0, 1, 2});
auto shape0 = NDArrayFactory::create<Nd4jLong>('c', {3}, {2, 0, -1});
auto shape1 = NDArrayFactory::create<Nd4jLong>('c', {2}, {-1, 1});
auto e0 = NDArrayFactory::create<float>('c', {2, 0, 0});
auto e1 = NDArrayFactory::create<float>('c', {0, 1});
nd4j::ops::reshape op;
auto result0 = op.execute({&x0, &shape0}, {}, {});
ASSERT_EQ(Status::OK(), result0->status());
auto z0 = result0->at(0);
ASSERT_EQ(e0, *z0);
auto result1 = op.execute({&x1, &shape1}, {}, {});
ASSERT_EQ(Status::OK(), result1->status());
auto z1 = result1->at(0);
ASSERT_EQ(e1, *z1);
delete result0;
delete result1;
}