From fb8de5006ff433b7cd0018f85d98ce44a0e56690 Mon Sep 17 00:00:00 2001 From: raver119 Date: Fri, 23 Aug 2019 13:16:50 +0300 Subject: [PATCH] - concat empty scalar fix - couple of tests for empty scalar concat Signed-off-by: raver119 --- .../declarable/generic/transforms/concat.cpp | 10 +++-- libnd4j/tests_cpu/layers_tests/EmptyTests.cpp | 40 +++++++++++++++++++ 2 files changed, 47 insertions(+), 3 deletions(-) diff --git a/libnd4j/include/ops/declarable/generic/transforms/concat.cpp b/libnd4j/include/ops/declarable/generic/transforms/concat.cpp index ac211d17d..5eb278f68 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/concat.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/concat.cpp @@ -124,8 +124,10 @@ DECLARE_SHAPE_FN(concat) { for(int i = 0; i < block.width(); ++i) { if(inputShape->at(i)[0] == 0) { - // FIXME, use this instead: block.dataType() - arrShapes.push_back(ConstantShapeHelper::getInstance()->vectorShapeInfo(1, INPUT_VARIABLE(0)->dataType())); + if (shape::isEmpty(inputShape->at(i))) + arrShapes.push_back(ConstantShapeHelper::getInstance()->vectorShapeInfo(0, INPUT_VARIABLE(0)->dataType())); + else + arrShapes.push_back(ConstantShapeHelper::getInstance()->vectorShapeInfo(1, INPUT_VARIABLE(0)->dataType())); } else{ arrShapes.push_back(inputShape->at(i)); @@ -165,7 +167,9 @@ DECLARE_SHAPE_FN(concat) { } for(int i = 1; i < numOfArrs; ++i) - outShapeInfo[axis + 1] += arrShapes[i][axis + 1]; + if (!shape::isEmpty(arrShapes[i])) { + outShapeInfo[axis + 1] += arrShapes[i][axis + 1]; + } ShapeUtils::updateStridesAndType(outShapeInfo, arrShapes[0], shape::order(arrShapes[0])); diff --git a/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp b/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp index 65de2729f..b9445cc70 100644 --- a/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/EmptyTests.cpp @@ -107,6 +107,46 @@ TEST_F(EmptyTests, Test_Concat_2) { delete result; } +TEST_F(EmptyTests, Test_Concat_3) { + auto empty = NDArrayFactory::empty(); //NDArrayFactory::empty_(); + auto scalar1 = NDArrayFactory::create(1.0f); + auto scalar2 = NDArrayFactory::create(2.0f); + auto exp = NDArrayFactory::create('c', {2}, {1.f, 2.f}); + + ASSERT_TRUE(empty.isEmpty()); + + nd4j::ops::concat op; + auto result = op.execute({&empty, &scalar1, &scalar2}, {}, {0}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + z->printIndexedBuffer("z"); + ASSERT_EQ(exp, *z); + + delete result; +} + +TEST_F(EmptyTests, Test_Concat_4) { + auto empty = NDArrayFactory::empty(); //NDArrayFactory::empty_(); + auto scalar1 = NDArrayFactory::create(1.0f); + auto scalar2 = NDArrayFactory::create(2.0f); + auto exp = NDArrayFactory::create('c', {2}, {1.f, 2.f}); + + ASSERT_TRUE(empty.isEmpty()); + + nd4j::ops::concat op; + auto result = op.execute({&scalar1, &empty, &scalar2}, {}, {0}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + z->printIndexedBuffer("z"); + ASSERT_EQ(exp, *z); + + delete result; +} + TEST_F(EmptyTests, Test_Reshape_1) { auto vector = NDArrayFactory::create('c', {1}, {119.0f}); auto exp = NDArrayFactory::create(119.f);