From 630409cd537643f3cc793304ab69d90e72e877b7 Mon Sep 17 00:00:00 2001 From: Fariz Rahman Date: Wed, 20 Nov 2019 07:49:04 +0530 Subject: [PATCH] Fixes Keras import issue (8373) (#54) * test * fix * rem prn * add test --- .../preprocessors/ReshapePreprocessor.java | 30 +++++++++++++++---- .../Keras2ModelConfigurationTest.java | 11 +++++++ 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/ReshapePreprocessor.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/ReshapePreprocessor.java index 77c6369c5..dbd5ccd8c 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/ReshapePreprocessor.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/ReshapePreprocessor.java @@ -43,13 +43,14 @@ import static org.nd4j.linalg.util.ArrayUtil.prodLong; @Data @Slf4j @EqualsAndHashCode(callSuper = false) -@JsonIgnoreProperties({"hasMiniBatchDimension", "miniBatchSize"}) +@JsonIgnoreProperties({"hasMiniBatchDimension", "miniBatchSize", "staticTargetShape"}) public class ReshapePreprocessor extends BaseInputPreProcessor { private long[] inputShape; private long[] targetShape; private boolean hasMiniBatchDimension = false; private int miniBatchSize; + private long[] staticTargetShape; public ReshapePreprocessor(@JsonProperty("inputShape") long[] inputShape, @JsonProperty("targetShape") long[] targetShape) { this.inputShape = inputShape; @@ -80,6 +81,16 @@ public class ReshapePreprocessor extends BaseInputPreProcessor { public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) { // the target shape read from a keras config does not have mini-batch size // included. We prepend it here dynamically. + + long[] targetShape; + if (staticTargetShape != null){ + targetShape = prependMiniBatchSize(staticTargetShape, miniBatchSize); + hasMiniBatchDimension = true; + this.miniBatchSize = miniBatchSize; + } + else{ + targetShape = this.targetShape; + } if (!this.hasMiniBatchDimension) { targetShape = prependMiniBatchSize(targetShape, miniBatchSize); inputShape = prependMiniBatchSize(inputShape, miniBatchSize); @@ -95,7 +106,7 @@ public class ReshapePreprocessor extends BaseInputPreProcessor { if(input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input)){ input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c'); } - return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.reshape(this.targetShape)); + return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.reshape(targetShape)); } else { throw new IllegalStateException("Input shape " + Arrays.toString(input.shape()) + " and output shape" + Arrays.toString(inputShape) + " do not match"); @@ -122,20 +133,27 @@ public class ReshapePreprocessor extends BaseInputPreProcessor { @Override public InputType getOutputType(InputType inputType) throws InvalidInputTypeException { val shape = hasMiniBatchDimension ? targetShape : prependMiniBatchSize(targetShape, 0); + InputType ret; switch (shape.length) { case 2: - return InputType.feedForward(shape[1]); + ret = InputType.feedForward(shape[1]); + break; case 3: - return InputType.recurrent(shape[2], shape[1]); + ret = InputType.recurrent(shape[2], shape[1]); + break; case 4: if (inputShape.length == 1 || inputType.getType() == InputType.Type.RNN){ - return InputType.convolutional(shape[1], shape[2], shape[3]); + ret = InputType.convolutional(shape[1], shape[2], shape[3]); }else { - return InputType.convolutional(shape[2], shape[3], shape[1]); + ret = InputType.convolutional(shape[2], shape[3], shape[1]); } + break; default: throw new UnsupportedOperationException( "Cannot infer input type for reshape array " + Arrays.toString(shape)); + } + this.staticTargetShape = ret.getShape(); + return ret; } } \ No newline at end of file diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java index 81103d315..db03128f7 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java @@ -255,6 +255,17 @@ public class Keras2ModelConfigurationTest extends BaseDL4JTest { } } + @Test + public void ReshapeEmbeddingConcatTest() throws Exception{ + try(InputStream is = Resources.asStream("/modelimport/keras/configs/keras2/reshape_embedding_concat.json")) { + ComputationGraphConfiguration config = + new KerasModel().modelBuilder().modelJsonInputStream(is) + .enforceTrainingConfig(false).buildModel().getComputationGraphConfiguration(); + ComputationGraph model = new ComputationGraph(config); + model.init(); + model.outputSingle(Nd4j.zeros(1, 1), Nd4j.zeros(1, 1), Nd4j.zeros(1, 1)); + } + } private void runSequentialConfigTest(String path) throws Exception { try(InputStream is = Resources.asStream(path)) {