Fix for concat op shape function (empty shapes) (#167)

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-08-26 23:10:28 +10:00 committed by GitHub
parent 9a513a9aa6
commit b417ca21bf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 34 additions and 3 deletions

View File

@ -167,9 +167,7 @@ DECLARE_SHAPE_FN(concat) {
} }
for(int i = 1; i < numOfArrs; ++i) for(int i = 1; i < numOfArrs; ++i)
if (!shape::isEmpty(arrShapes[i])) {
outShapeInfo[axis + 1] += arrShapes[i][axis + 1]; outShapeInfo[axis + 1] += arrShapes[i][axis + 1];
}
ShapeUtils::updateStridesAndType(outShapeInfo, arrShapes[0], shape::order(arrShapes[0])); ShapeUtils::updateStridesAndType(outShapeInfo, arrShapes[0], shape::order(arrShapes[0]));

View File

@ -2109,6 +2109,38 @@ public class ShapeOpValidation extends BaseOpValidation {
Nd4j.exec(op); Nd4j.exec(op);
} }
@Test
public void testConcatEmpty2(){
INDArray empty10a = Nd4j.create(DataType.INT, 1, 0);
INDArray empty10b = Nd4j.create(DataType.INT, 1, 0);
DynamicCustomOp op = DynamicCustomOp.builder("concat")
.addInputs(empty10a, empty10b)
.addIntegerArguments(0) //axis = 0
.build();
List<LongShapeDescriptor> l = op.calculateOutputShape();
assertEquals(1, l.size());
assertTrue(l.get(0).isEmpty());
assertArrayEquals(new long[]{2, 0}, l.get(0).getShape());
assertEquals(DataType.INT, l.get(0).dataType());
op.addOutputArgument(Nd4j.create(DataType.INT, 2, 0));
Nd4j.exec(op);
op = DynamicCustomOp.builder("concat")
.addInputs(empty10a, empty10b)
.addIntegerArguments(1) //axis = 1
.build();
l = op.calculateOutputShape();
assertEquals(1, l.size());
assertTrue(l.get(0).isEmpty());
assertArrayEquals(new long[]{1, 0}, l.get(0).getShape());
op.addOutputArgument(Nd4j.create(DataType.INT, 1, 0));
Nd4j.exec(op);
}
@Test @Test
public void testEmptyGather(){ public void testEmptyGather(){
/* /*
@ -2434,4 +2466,5 @@ public class ShapeOpValidation extends BaseOpValidation {
.addInputs(Nd4j.createFromArray(1, 0)) .addInputs(Nd4j.createFromArray(1, 0))
.build(); .build();
} }
} }