diff --git a/libnd4j/include/ops/declarable/generic/transforms/concat.cpp b/libnd4j/include/ops/declarable/generic/transforms/concat.cpp index 5249758bf..3c165f64f 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/concat.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/concat.cpp @@ -167,9 +167,7 @@ DECLARE_SHAPE_FN(concat) { } for(int i = 1; i < numOfArrs; ++i) - if (!shape::isEmpty(arrShapes[i])) { - outShapeInfo[axis + 1] += arrShapes[i][axis + 1]; - } + outShapeInfo[axis + 1] += arrShapes[i][axis + 1]; ShapeUtils::updateStridesAndType(outShapeInfo, arrShapes[0], shape::order(arrShapes[0])); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java index e53cfa5ff..2965f367f 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java @@ -2109,6 +2109,38 @@ public class ShapeOpValidation extends BaseOpValidation { Nd4j.exec(op); } + @Test + public void testConcatEmpty2(){ + INDArray empty10a = Nd4j.create(DataType.INT, 1, 0); + INDArray empty10b = Nd4j.create(DataType.INT, 1, 0); + + DynamicCustomOp op = DynamicCustomOp.builder("concat") + .addInputs(empty10a, empty10b) + .addIntegerArguments(0) //axis = 0 + .build(); + + List l = op.calculateOutputShape(); + assertEquals(1, l.size()); + assertTrue(l.get(0).isEmpty()); + assertArrayEquals(new long[]{2, 0}, l.get(0).getShape()); + assertEquals(DataType.INT, l.get(0).dataType()); + + op.addOutputArgument(Nd4j.create(DataType.INT, 2, 0)); + Nd4j.exec(op); + + + op = DynamicCustomOp.builder("concat") + .addInputs(empty10a, empty10b) + .addIntegerArguments(1) //axis = 1 + .build(); + l = op.calculateOutputShape(); + assertEquals(1, l.size()); + assertTrue(l.get(0).isEmpty()); + assertArrayEquals(new long[]{1, 0}, l.get(0).getShape()); + op.addOutputArgument(Nd4j.create(DataType.INT, 1, 0)); + Nd4j.exec(op); + } + @Test public void testEmptyGather(){ /* @@ -2434,4 +2466,5 @@ public class ShapeOpValidation extends BaseOpValidation { .addInputs(Nd4j.createFromArray(1, 0)) .build(); } + }