Fix for concat op shape function (empty shapes) (#167)
Signed-off-by: AlexDBlack <blacka101@gmail.com>master
parent
9a513a9aa6
commit
b417ca21bf
|
@ -167,9 +167,7 @@ DECLARE_SHAPE_FN(concat) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for(int i = 1; i < numOfArrs; ++i)
|
for(int i = 1; i < numOfArrs; ++i)
|
||||||
if (!shape::isEmpty(arrShapes[i])) {
|
outShapeInfo[axis + 1] += arrShapes[i][axis + 1];
|
||||||
outShapeInfo[axis + 1] += arrShapes[i][axis + 1];
|
|
||||||
}
|
|
||||||
|
|
||||||
ShapeUtils::updateStridesAndType(outShapeInfo, arrShapes[0], shape::order(arrShapes[0]));
|
ShapeUtils::updateStridesAndType(outShapeInfo, arrShapes[0], shape::order(arrShapes[0]));
|
||||||
|
|
||||||
|
|
|
@ -2109,6 +2109,38 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
Nd4j.exec(op);
|
Nd4j.exec(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testConcatEmpty2(){
|
||||||
|
INDArray empty10a = Nd4j.create(DataType.INT, 1, 0);
|
||||||
|
INDArray empty10b = Nd4j.create(DataType.INT, 1, 0);
|
||||||
|
|
||||||
|
DynamicCustomOp op = DynamicCustomOp.builder("concat")
|
||||||
|
.addInputs(empty10a, empty10b)
|
||||||
|
.addIntegerArguments(0) //axis = 0
|
||||||
|
.build();
|
||||||
|
|
||||||
|
List<LongShapeDescriptor> l = op.calculateOutputShape();
|
||||||
|
assertEquals(1, l.size());
|
||||||
|
assertTrue(l.get(0).isEmpty());
|
||||||
|
assertArrayEquals(new long[]{2, 0}, l.get(0).getShape());
|
||||||
|
assertEquals(DataType.INT, l.get(0).dataType());
|
||||||
|
|
||||||
|
op.addOutputArgument(Nd4j.create(DataType.INT, 2, 0));
|
||||||
|
Nd4j.exec(op);
|
||||||
|
|
||||||
|
|
||||||
|
op = DynamicCustomOp.builder("concat")
|
||||||
|
.addInputs(empty10a, empty10b)
|
||||||
|
.addIntegerArguments(1) //axis = 1
|
||||||
|
.build();
|
||||||
|
l = op.calculateOutputShape();
|
||||||
|
assertEquals(1, l.size());
|
||||||
|
assertTrue(l.get(0).isEmpty());
|
||||||
|
assertArrayEquals(new long[]{1, 0}, l.get(0).getShape());
|
||||||
|
op.addOutputArgument(Nd4j.create(DataType.INT, 1, 0));
|
||||||
|
Nd4j.exec(op);
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testEmptyGather(){
|
public void testEmptyGather(){
|
||||||
/*
|
/*
|
||||||
|
@ -2434,4 +2466,5 @@ public class ShapeOpValidation extends BaseOpValidation {
|
||||||
.addInputs(Nd4j.createFromArray(1, 0))
|
.addInputs(Nd4j.createFromArray(1, 0))
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue