parent
66b84b38cf
commit
630409cd53
|
@ -43,13 +43,14 @@ import static org.nd4j.linalg.util.ArrayUtil.prodLong;
|
|||
@Data
|
||||
@Slf4j
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
@JsonIgnoreProperties({"hasMiniBatchDimension", "miniBatchSize"})
|
||||
@JsonIgnoreProperties({"hasMiniBatchDimension", "miniBatchSize", "staticTargetShape"})
|
||||
public class ReshapePreprocessor extends BaseInputPreProcessor {
|
||||
|
||||
private long[] inputShape;
|
||||
private long[] targetShape;
|
||||
private boolean hasMiniBatchDimension = false;
|
||||
private int miniBatchSize;
|
||||
private long[] staticTargetShape;
|
||||
|
||||
public ReshapePreprocessor(@JsonProperty("inputShape") long[] inputShape, @JsonProperty("targetShape") long[] targetShape) {
|
||||
this.inputShape = inputShape;
|
||||
|
@ -80,6 +81,16 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
|
|||
public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
|
||||
// the target shape read from a keras config does not have mini-batch size
|
||||
// included. We prepend it here dynamically.
|
||||
|
||||
long[] targetShape;
|
||||
if (staticTargetShape != null){
|
||||
targetShape = prependMiniBatchSize(staticTargetShape, miniBatchSize);
|
||||
hasMiniBatchDimension = true;
|
||||
this.miniBatchSize = miniBatchSize;
|
||||
}
|
||||
else{
|
||||
targetShape = this.targetShape;
|
||||
}
|
||||
if (!this.hasMiniBatchDimension) {
|
||||
targetShape = prependMiniBatchSize(targetShape, miniBatchSize);
|
||||
inputShape = prependMiniBatchSize(inputShape, miniBatchSize);
|
||||
|
@ -95,7 +106,7 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
|
|||
if(input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input)){
|
||||
input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c');
|
||||
}
|
||||
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.reshape(this.targetShape));
|
||||
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.reshape(targetShape));
|
||||
} else {
|
||||
throw new IllegalStateException("Input shape " + Arrays.toString(input.shape())
|
||||
+ " and output shape" + Arrays.toString(inputShape) + " do not match");
|
||||
|
@ -122,20 +133,27 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
|
|||
@Override
|
||||
public InputType getOutputType(InputType inputType) throws InvalidInputTypeException {
|
||||
val shape = hasMiniBatchDimension ? targetShape : prependMiniBatchSize(targetShape, 0);
|
||||
InputType ret;
|
||||
switch (shape.length) {
|
||||
case 2:
|
||||
return InputType.feedForward(shape[1]);
|
||||
ret = InputType.feedForward(shape[1]);
|
||||
break;
|
||||
case 3:
|
||||
return InputType.recurrent(shape[2], shape[1]);
|
||||
ret = InputType.recurrent(shape[2], shape[1]);
|
||||
break;
|
||||
case 4:
|
||||
if (inputShape.length == 1 || inputType.getType() == InputType.Type.RNN){
|
||||
return InputType.convolutional(shape[1], shape[2], shape[3]);
|
||||
ret = InputType.convolutional(shape[1], shape[2], shape[3]);
|
||||
}else {
|
||||
return InputType.convolutional(shape[2], shape[3], shape[1]);
|
||||
ret = InputType.convolutional(shape[2], shape[3], shape[1]);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
throw new UnsupportedOperationException(
|
||||
"Cannot infer input type for reshape array " + Arrays.toString(shape));
|
||||
|
||||
}
|
||||
this.staticTargetShape = ret.getShape();
|
||||
return ret;
|
||||
}
|
||||
}
|
|
@ -255,6 +255,17 @@ public class Keras2ModelConfigurationTest extends BaseDL4JTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void ReshapeEmbeddingConcatTest() throws Exception{
|
||||
try(InputStream is = Resources.asStream("/modelimport/keras/configs/keras2/reshape_embedding_concat.json")) {
|
||||
ComputationGraphConfiguration config =
|
||||
new KerasModel().modelBuilder().modelJsonInputStream(is)
|
||||
.enforceTrainingConfig(false).buildModel().getComputationGraphConfiguration();
|
||||
ComputationGraph model = new ComputationGraph(config);
|
||||
model.init();
|
||||
model.outputSingle(Nd4j.zeros(1, 1), Nd4j.zeros(1, 1), Nd4j.zeros(1, 1));
|
||||
}
|
||||
}
|
||||
|
||||
private void runSequentialConfigTest(String path) throws Exception {
|
||||
try(InputStream is = Resources.asStream(path)) {
|
||||
|
|
Loading…
Reference in New Issue