Fixes Keras import issue (8373) (#54)

* test

* fix

* rem prn

* add test
Branch: master
Author: Fariz Rahman, 2019-11-20 07:49:04 +05:30 (committed by Alex Black)
Parent: 66b84b38cf
Commit: 630409cd53
2 changed files with 35 additions and 6 deletions


@@ -43,13 +43,14 @@ import static org.nd4j.linalg.util.ArrayUtil.prodLong;
 @Data
 @Slf4j
 @EqualsAndHashCode(callSuper = false)
-@JsonIgnoreProperties({"hasMiniBatchDimension", "miniBatchSize"})
+@JsonIgnoreProperties({"hasMiniBatchDimension", "miniBatchSize", "staticTargetShape"})
 public class ReshapePreprocessor extends BaseInputPreProcessor {
     private long[] inputShape;
     private long[] targetShape;
     private boolean hasMiniBatchDimension = false;
     private int miniBatchSize;
+    private long[] staticTargetShape;
 
     public ReshapePreprocessor(@JsonProperty("inputShape") long[] inputShape, @JsonProperty("targetShape") long[] targetShape) {
         this.inputShape = inputShape;
@@ -80,6 +81,16 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
     public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
         // the target shape read from a keras config does not have mini-batch size
         // included. We prepend it here dynamically.
+        long[] targetShape;
+        if (staticTargetShape != null){
+            targetShape = prependMiniBatchSize(staticTargetShape, miniBatchSize);
+            hasMiniBatchDimension = true;
+            this.miniBatchSize = miniBatchSize;
+        }
+        else{
+            targetShape = this.targetShape;
+        }
         if (!this.hasMiniBatchDimension) {
             targetShape = prependMiniBatchSize(targetShape, miniBatchSize);
             inputShape = prependMiniBatchSize(inputShape, miniBatchSize);
@@ -95,7 +106,7 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
             if(input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input)){
                 input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c');
             }
-            return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.reshape(this.targetShape));
+            return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.reshape(targetShape));
         } else {
             throw new IllegalStateException("Input shape " + Arrays.toString(input.shape())
                     + " and output shape" + Arrays.toString(inputShape) + " do not match");
@@ -122,20 +133,27 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
     @Override
     public InputType getOutputType(InputType inputType) throws InvalidInputTypeException {
         val shape = hasMiniBatchDimension ? targetShape : prependMiniBatchSize(targetShape, 0);
+        InputType ret;
         switch (shape.length) {
             case 2:
-                return InputType.feedForward(shape[1]);
+                ret = InputType.feedForward(shape[1]);
+                break;
             case 3:
-                return InputType.recurrent(shape[2], shape[1]);
+                ret = InputType.recurrent(shape[2], shape[1]);
+                break;
             case 4:
                 if (inputShape.length == 1 || inputType.getType() == InputType.Type.RNN){
-                    return InputType.convolutional(shape[1], shape[2], shape[3]);
+                    ret = InputType.convolutional(shape[1], shape[2], shape[3]);
                 }else {
-                    return InputType.convolutional(shape[2], shape[3], shape[1]);
+                    ret = InputType.convolutional(shape[2], shape[3], shape[1]);
                 }
+                break;
             default:
                 throw new UnsupportedOperationException(
                         "Cannot infer input type for reshape array " + Arrays.toString(shape));
         }
+        this.staticTargetShape = ret.getShape();
+        return ret;
     }
 }
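
For readers unfamiliar with this class, here is a minimal usage sketch of the preprocessor changed above (my addition, not part of the commit). It illustrates the behaviour described by the comment in preProcess: the Keras-derived target shape carries no mini-batch dimension, so the current mini-batch size is prepended before the reshape. The two-argument constructor and the preProcess signature come from the hunks above; the import locations are assumed to be the standard DL4J/ND4J ones.

// NOTE: import locations below are assumed from the standard DL4J/ND4J package layout
import java.util.Arrays;

import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class ReshapePreprocessorSketch {
    public static void main(String[] args) {
        // Keras-style shapes without the mini-batch dimension: flatten 3x4 activations into 12
        ReshapePreprocessor reshape = new ReshapePreprocessor(new long[]{3, 4}, new long[]{12});

        INDArray activations = Nd4j.zeros(5, 3, 4);                       // mini-batch of 5
        INDArray out = reshape.preProcess(activations, 5, LayerWorkspaceMgr.noWorkspaces());

        System.out.println(Arrays.toString(out.shape()));                 // [5, 12]: mini-batch size was prepended
    }
}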


@@ -255,6 +255,17 @@ public class Keras2ModelConfigurationTest extends BaseDL4JTest {
         }
     }
 
+    @Test
+    public void ReshapeEmbeddingConcatTest() throws Exception{
+        try(InputStream is = Resources.asStream("/modelimport/keras/configs/keras2/reshape_embedding_concat.json")) {
+            ComputationGraphConfiguration config =
+                    new KerasModel().modelBuilder().modelJsonInputStream(is)
+                            .enforceTrainingConfig(false).buildModel().getComputationGraphConfiguration();
+            ComputationGraph model = new ComputationGraph(config);
+            model.init();
+            model.outputSingle(Nd4j.zeros(1, 1), Nd4j.zeros(1, 1), Nd4j.zeros(1, 1));
+        }
+    }
     private void runSequentialConfigTest(String path) throws Exception {
         try(InputStream is = Resources.asStream(path)) {
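
As a side note on the widened @JsonIgnoreProperties list in the first hunk: staticTargetShape is derived state (it is filled in by getOutputType), so Jackson skips it when a configuration containing this preprocessor is written to JSON and read back. Below is a minimal round-trip sketch (my addition, not part of the commit); the local file path is hypothetical and the package names are assumed.

// NOTE: the file path is hypothetical and import locations are assumed from the standard DL4J layout
import java.io.FileInputStream;
import java.io.InputStream;

import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.modelimport.keras.KerasModel;

public class ReshapeConfigRoundTrip {
    public static void main(String[] args) throws Exception {
        // hypothetical local copy of the Keras model config used by the new test above
        try (InputStream is = new FileInputStream("reshape_embedding_concat.json")) {
            ComputationGraphConfiguration config = new KerasModel().modelBuilder()
                    .modelJsonInputStream(is).enforceTrainingConfig(false)
                    .buildModel().getComputationGraphConfiguration();

            String json = config.toJson();                                   // staticTargetShape is not serialized
            ComputationGraphConfiguration restored = ComputationGraphConfiguration.fromJson(json);
            System.out.println(restored.getNetworkInputs());                 // configuration survives the round trip
        }
    }
}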