- concat empty scalar fix
- couple of tests for empty scalar concat Signed-off-by: raver119 <raver119@gmail.com>master
parent
729dc5e879
commit
fb8de5006f
|
@ -124,8 +124,10 @@ DECLARE_SHAPE_FN(concat) {
|
|||
for(int i = 0; i < block.width(); ++i) {
|
||||
|
||||
if(inputShape->at(i)[0] == 0) {
|
||||
// FIXME, use this instead: block.dataType()
|
||||
arrShapes.push_back(ConstantShapeHelper::getInstance()->vectorShapeInfo(1, INPUT_VARIABLE(0)->dataType()));
|
||||
if (shape::isEmpty(inputShape->at(i)))
|
||||
arrShapes.push_back(ConstantShapeHelper::getInstance()->vectorShapeInfo(0, INPUT_VARIABLE(0)->dataType()));
|
||||
else
|
||||
arrShapes.push_back(ConstantShapeHelper::getInstance()->vectorShapeInfo(1, INPUT_VARIABLE(0)->dataType()));
|
||||
}
|
||||
else{
|
||||
arrShapes.push_back(inputShape->at(i));
|
||||
|
@ -165,7 +167,9 @@ DECLARE_SHAPE_FN(concat) {
|
|||
}
|
||||
|
||||
for(int i = 1; i < numOfArrs; ++i)
|
||||
outShapeInfo[axis + 1] += arrShapes[i][axis + 1];
|
||||
if (!shape::isEmpty(arrShapes[i])) {
|
||||
outShapeInfo[axis + 1] += arrShapes[i][axis + 1];
|
||||
}
|
||||
|
||||
ShapeUtils::updateStridesAndType(outShapeInfo, arrShapes[0], shape::order(arrShapes[0]));
|
||||
|
||||
|
|
|
@ -107,6 +107,46 @@ TEST_F(EmptyTests, Test_Concat_2) {
|
|||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(EmptyTests, Test_Concat_3) {
|
||||
auto empty = NDArrayFactory::empty<float>(); //NDArrayFactory::empty_<float>();
|
||||
auto scalar1 = NDArrayFactory::create<float>(1.0f);
|
||||
auto scalar2 = NDArrayFactory::create<float>(2.0f);
|
||||
auto exp = NDArrayFactory::create<float>('c', {2}, {1.f, 2.f});
|
||||
|
||||
ASSERT_TRUE(empty.isEmpty());
|
||||
|
||||
nd4j::ops::concat op;
|
||||
auto result = op.execute({&empty, &scalar1, &scalar2}, {}, {0});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
z->printIndexedBuffer("z");
|
||||
ASSERT_EQ(exp, *z);
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(EmptyTests, Test_Concat_4) {
|
||||
auto empty = NDArrayFactory::empty<float>(); //NDArrayFactory::empty_<float>();
|
||||
auto scalar1 = NDArrayFactory::create<float>(1.0f);
|
||||
auto scalar2 = NDArrayFactory::create<float>(2.0f);
|
||||
auto exp = NDArrayFactory::create<float>('c', {2}, {1.f, 2.f});
|
||||
|
||||
ASSERT_TRUE(empty.isEmpty());
|
||||
|
||||
nd4j::ops::concat op;
|
||||
auto result = op.execute({&scalar1, &empty, &scalar2}, {}, {0});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
auto z = result->at(0);
|
||||
|
||||
z->printIndexedBuffer("z");
|
||||
ASSERT_EQ(exp, *z);
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(EmptyTests, Test_Reshape_1) {
|
||||
auto vector = NDArrayFactory::create<float>('c', {1}, {119.0f});
|
||||
auto exp = NDArrayFactory::create<float>(119.f);
|
||||
|
|
Loading…
Reference in New Issue