Fixes Keras import issue (8373) (#54)

* test

* fix

* rem prn

* add test
master
Fariz Rahman 2019-11-20 07:49:04 +05:30 committed by Alex Black
parent 66b84b38cf
commit 630409cd53
2 changed files with 35 additions and 6 deletions

View File

@ -43,13 +43,14 @@ import static org.nd4j.linalg.util.ArrayUtil.prodLong;
@Data
@Slf4j
@EqualsAndHashCode(callSuper = false)
@JsonIgnoreProperties({"hasMiniBatchDimension", "miniBatchSize"})
@JsonIgnoreProperties({"hasMiniBatchDimension", "miniBatchSize", "staticTargetShape"})
public class ReshapePreprocessor extends BaseInputPreProcessor {
private long[] inputShape;
private long[] targetShape;
private boolean hasMiniBatchDimension = false;
private int miniBatchSize;
private long[] staticTargetShape;
public ReshapePreprocessor(@JsonProperty("inputShape") long[] inputShape, @JsonProperty("targetShape") long[] targetShape) {
this.inputShape = inputShape;
@ -80,6 +81,16 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
// the target shape read from a keras config does not have mini-batch size
// included. We prepend it here dynamically.
long[] targetShape;
if (staticTargetShape != null){
targetShape = prependMiniBatchSize(staticTargetShape, miniBatchSize);
hasMiniBatchDimension = true;
this.miniBatchSize = miniBatchSize;
}
else{
targetShape = this.targetShape;
}
if (!this.hasMiniBatchDimension) {
targetShape = prependMiniBatchSize(targetShape, miniBatchSize);
inputShape = prependMiniBatchSize(inputShape, miniBatchSize);
@ -95,7 +106,7 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
if(input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input)){
input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c');
}
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.reshape(this.targetShape));
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.reshape(targetShape));
} else {
throw new IllegalStateException("Input shape " + Arrays.toString(input.shape())
+ " and output shape" + Arrays.toString(inputShape) + " do not match");
@ -122,20 +133,27 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
@Override
public InputType getOutputType(InputType inputType) throws InvalidInputTypeException {
val shape = hasMiniBatchDimension ? targetShape : prependMiniBatchSize(targetShape, 0);
InputType ret;
switch (shape.length) {
case 2:
return InputType.feedForward(shape[1]);
ret = InputType.feedForward(shape[1]);
break;
case 3:
return InputType.recurrent(shape[2], shape[1]);
ret = InputType.recurrent(shape[2], shape[1]);
break;
case 4:
if (inputShape.length == 1 || inputType.getType() == InputType.Type.RNN){
return InputType.convolutional(shape[1], shape[2], shape[3]);
ret = InputType.convolutional(shape[1], shape[2], shape[3]);
}else {
return InputType.convolutional(shape[2], shape[3], shape[1]);
ret = InputType.convolutional(shape[2], shape[3], shape[1]);
}
break;
default:
throw new UnsupportedOperationException(
"Cannot infer input type for reshape array " + Arrays.toString(shape));
}
this.staticTargetShape = ret.getShape();
return ret;
}
}

View File

@ -255,6 +255,17 @@ public class Keras2ModelConfigurationTest extends BaseDL4JTest {
}
}
@Test
public void ReshapeEmbeddingConcatTest() throws Exception{
try(InputStream is = Resources.asStream("/modelimport/keras/configs/keras2/reshape_embedding_concat.json")) {
ComputationGraphConfiguration config =
new KerasModel().modelBuilder().modelJsonInputStream(is)
.enforceTrainingConfig(false).buildModel().getComputationGraphConfiguration();
ComputationGraph model = new ComputationGraph(config);
model.init();
model.outputSingle(Nd4j.zeros(1, 1), Nd4j.zeros(1, 1), Nd4j.zeros(1, 1));
}
}
private void runSequentialConfigTest(String path) throws Exception {
try(InputStream is = Resources.asStream(path)) {