- concat empty scalar fix
- couple of tests for empty scalar concat Signed-off-by: raver119 <raver119@gmail.com>master
parent
729dc5e879
commit
fb8de5006f
|
@ -124,7 +124,9 @@ DECLARE_SHAPE_FN(concat) {
|
||||||
for(int i = 0; i < block.width(); ++i) {
|
for(int i = 0; i < block.width(); ++i) {
|
||||||
|
|
||||||
if(inputShape->at(i)[0] == 0) {
|
if(inputShape->at(i)[0] == 0) {
|
||||||
// FIXME, use this instead: block.dataType()
|
if (shape::isEmpty(inputShape->at(i)))
|
||||||
|
arrShapes.push_back(ConstantShapeHelper::getInstance()->vectorShapeInfo(0, INPUT_VARIABLE(0)->dataType()));
|
||||||
|
else
|
||||||
arrShapes.push_back(ConstantShapeHelper::getInstance()->vectorShapeInfo(1, INPUT_VARIABLE(0)->dataType()));
|
arrShapes.push_back(ConstantShapeHelper::getInstance()->vectorShapeInfo(1, INPUT_VARIABLE(0)->dataType()));
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
|
@ -165,7 +167,9 @@ DECLARE_SHAPE_FN(concat) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for(int i = 1; i < numOfArrs; ++i)
|
for(int i = 1; i < numOfArrs; ++i)
|
||||||
|
if (!shape::isEmpty(arrShapes[i])) {
|
||||||
outShapeInfo[axis + 1] += arrShapes[i][axis + 1];
|
outShapeInfo[axis + 1] += arrShapes[i][axis + 1];
|
||||||
|
}
|
||||||
|
|
||||||
ShapeUtils::updateStridesAndType(outShapeInfo, arrShapes[0], shape::order(arrShapes[0]));
|
ShapeUtils::updateStridesAndType(outShapeInfo, arrShapes[0], shape::order(arrShapes[0]));
|
||||||
|
|
||||||
|
|
|
@ -107,6 +107,46 @@ TEST_F(EmptyTests, Test_Concat_2) {
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(EmptyTests, Test_Concat_3) {
|
||||||
|
auto empty = NDArrayFactory::empty<float>(); //NDArrayFactory::empty_<float>();
|
||||||
|
auto scalar1 = NDArrayFactory::create<float>(1.0f);
|
||||||
|
auto scalar2 = NDArrayFactory::create<float>(2.0f);
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {2}, {1.f, 2.f});
|
||||||
|
|
||||||
|
ASSERT_TRUE(empty.isEmpty());
|
||||||
|
|
||||||
|
nd4j::ops::concat op;
|
||||||
|
auto result = op.execute({&empty, &scalar1, &scalar2}, {}, {0});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
|
||||||
|
z->printIndexedBuffer("z");
|
||||||
|
ASSERT_EQ(exp, *z);
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(EmptyTests, Test_Concat_4) {
|
||||||
|
auto empty = NDArrayFactory::empty<float>(); //NDArrayFactory::empty_<float>();
|
||||||
|
auto scalar1 = NDArrayFactory::create<float>(1.0f);
|
||||||
|
auto scalar2 = NDArrayFactory::create<float>(2.0f);
|
||||||
|
auto exp = NDArrayFactory::create<float>('c', {2}, {1.f, 2.f});
|
||||||
|
|
||||||
|
ASSERT_TRUE(empty.isEmpty());
|
||||||
|
|
||||||
|
nd4j::ops::concat op;
|
||||||
|
auto result = op.execute({&scalar1, &empty, &scalar2}, {}, {0});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
|
||||||
|
z->printIndexedBuffer("z");
|
||||||
|
ASSERT_EQ(exp, *z);
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(EmptyTests, Test_Reshape_1) {
|
TEST_F(EmptyTests, Test_Reshape_1) {
|
||||||
auto vector = NDArrayFactory::create<float>('c', {1}, {119.0f});
|
auto vector = NDArrayFactory::create<float>('c', {1}, {119.0f});
|
||||||
auto exp = NDArrayFactory::create<float>(119.f);
|
auto exp = NDArrayFactory::create<float>(119.f);
|
||||||
|
|
Loading…
Reference in New Issue