- concat empty scalar fix

- couple of tests for empty scalar concat

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-08-23 13:16:50 +03:00
parent 729dc5e879
commit fb8de5006f
2 changed files with 47 additions and 3 deletions

View File

@ -124,8 +124,10 @@ DECLARE_SHAPE_FN(concat) {
for(int i = 0; i < block.width(); ++i) {
if(inputShape->at(i)[0] == 0) {
// FIXME, use this instead: block.dataType()
arrShapes.push_back(ConstantShapeHelper::getInstance()->vectorShapeInfo(1, INPUT_VARIABLE(0)->dataType()));
if (shape::isEmpty(inputShape->at(i)))
arrShapes.push_back(ConstantShapeHelper::getInstance()->vectorShapeInfo(0, INPUT_VARIABLE(0)->dataType()));
else
arrShapes.push_back(ConstantShapeHelper::getInstance()->vectorShapeInfo(1, INPUT_VARIABLE(0)->dataType()));
}
else{
arrShapes.push_back(inputShape->at(i));
@ -165,7 +167,9 @@ DECLARE_SHAPE_FN(concat) {
}
for(int i = 1; i < numOfArrs; ++i)
outShapeInfo[axis + 1] += arrShapes[i][axis + 1];
if (!shape::isEmpty(arrShapes[i])) {
outShapeInfo[axis + 1] += arrShapes[i][axis + 1];
}
ShapeUtils::updateStridesAndType(outShapeInfo, arrShapes[0], shape::order(arrShapes[0]));

View File

@ -107,6 +107,46 @@ TEST_F(EmptyTests, Test_Concat_2) {
delete result;
}
TEST_F(EmptyTests, Test_Concat_3) {
auto empty = NDArrayFactory::empty<float>(); //NDArrayFactory::empty_<float>();
auto scalar1 = NDArrayFactory::create<float>(1.0f);
auto scalar2 = NDArrayFactory::create<float>(2.0f);
auto exp = NDArrayFactory::create<float>('c', {2}, {1.f, 2.f});
ASSERT_TRUE(empty.isEmpty());
nd4j::ops::concat op;
auto result = op.execute({&empty, &scalar1, &scalar2}, {}, {0});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
z->printIndexedBuffer("z");
ASSERT_EQ(exp, *z);
delete result;
}
TEST_F(EmptyTests, Test_Concat_4) {
auto empty = NDArrayFactory::empty<float>(); //NDArrayFactory::empty_<float>();
auto scalar1 = NDArrayFactory::create<float>(1.0f);
auto scalar2 = NDArrayFactory::create<float>(2.0f);
auto exp = NDArrayFactory::create<float>('c', {2}, {1.f, 2.f});
ASSERT_TRUE(empty.isEmpty());
nd4j::ops::concat op;
auto result = op.execute({&scalar1, &empty, &scalar2}, {}, {0});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
z->printIndexedBuffer("z");
ASSERT_EQ(exp, *z);
delete result;
}
TEST_F(EmptyTests, Test_Reshape_1) {
auto vector = NDArrayFactory::create<float>('c', {1}, {119.0f});
auto exp = NDArrayFactory::create<float>(119.f);