Fix for concat op shape function (empty shapes) (#167)
Signed-off-by: AlexDBlack <blacka101@gmail.com>master
parent
9a513a9aa6
commit
b417ca21bf
|
@ -167,9 +167,7 @@ DECLARE_SHAPE_FN(concat) {
|
|||
}
|
||||
|
||||
for(int i = 1; i < numOfArrs; ++i)
|
||||
if (!shape::isEmpty(arrShapes[i])) {
|
||||
outShapeInfo[axis + 1] += arrShapes[i][axis + 1];
|
||||
}
|
||||
|
||||
ShapeUtils::updateStridesAndType(outShapeInfo, arrShapes[0], shape::order(arrShapes[0]));
|
||||
|
||||
|
|
|
@ -2109,6 +2109,38 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
Nd4j.exec(op);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testConcatEmpty2(){
|
||||
INDArray empty10a = Nd4j.create(DataType.INT, 1, 0);
|
||||
INDArray empty10b = Nd4j.create(DataType.INT, 1, 0);
|
||||
|
||||
DynamicCustomOp op = DynamicCustomOp.builder("concat")
|
||||
.addInputs(empty10a, empty10b)
|
||||
.addIntegerArguments(0) //axis = 0
|
||||
.build();
|
||||
|
||||
List<LongShapeDescriptor> l = op.calculateOutputShape();
|
||||
assertEquals(1, l.size());
|
||||
assertTrue(l.get(0).isEmpty());
|
||||
assertArrayEquals(new long[]{2, 0}, l.get(0).getShape());
|
||||
assertEquals(DataType.INT, l.get(0).dataType());
|
||||
|
||||
op.addOutputArgument(Nd4j.create(DataType.INT, 2, 0));
|
||||
Nd4j.exec(op);
|
||||
|
||||
|
||||
op = DynamicCustomOp.builder("concat")
|
||||
.addInputs(empty10a, empty10b)
|
||||
.addIntegerArguments(1) //axis = 1
|
||||
.build();
|
||||
l = op.calculateOutputShape();
|
||||
assertEquals(1, l.size());
|
||||
assertTrue(l.get(0).isEmpty());
|
||||
assertArrayEquals(new long[]{1, 0}, l.get(0).getShape());
|
||||
op.addOutputArgument(Nd4j.create(DataType.INT, 1, 0));
|
||||
Nd4j.exec(op);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEmptyGather(){
|
||||
/*
|
||||
|
@ -2434,4 +2466,5 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
.addInputs(Nd4j.createFromArray(1, 0))
|
||||
.build();
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue