parent
66b84b38cf
commit
630409cd53
|
@ -43,13 +43,14 @@ import static org.nd4j.linalg.util.ArrayUtil.prodLong;
|
||||||
@Data
|
@Data
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@EqualsAndHashCode(callSuper = false)
|
@EqualsAndHashCode(callSuper = false)
|
||||||
@JsonIgnoreProperties({"hasMiniBatchDimension", "miniBatchSize"})
|
@JsonIgnoreProperties({"hasMiniBatchDimension", "miniBatchSize", "staticTargetShape"})
|
||||||
public class ReshapePreprocessor extends BaseInputPreProcessor {
|
public class ReshapePreprocessor extends BaseInputPreProcessor {
|
||||||
|
|
||||||
private long[] inputShape;
|
private long[] inputShape;
|
||||||
private long[] targetShape;
|
private long[] targetShape;
|
||||||
private boolean hasMiniBatchDimension = false;
|
private boolean hasMiniBatchDimension = false;
|
||||||
private int miniBatchSize;
|
private int miniBatchSize;
|
||||||
|
private long[] staticTargetShape;
|
||||||
|
|
||||||
public ReshapePreprocessor(@JsonProperty("inputShape") long[] inputShape, @JsonProperty("targetShape") long[] targetShape) {
|
public ReshapePreprocessor(@JsonProperty("inputShape") long[] inputShape, @JsonProperty("targetShape") long[] targetShape) {
|
||||||
this.inputShape = inputShape;
|
this.inputShape = inputShape;
|
||||||
|
@ -80,6 +81,16 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
|
||||||
public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
|
public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
|
||||||
// the target shape read from a keras config does not have mini-batch size
|
// the target shape read from a keras config does not have mini-batch size
|
||||||
// included. We prepend it here dynamically.
|
// included. We prepend it here dynamically.
|
||||||
|
|
||||||
|
long[] targetShape;
|
||||||
|
if (staticTargetShape != null){
|
||||||
|
targetShape = prependMiniBatchSize(staticTargetShape, miniBatchSize);
|
||||||
|
hasMiniBatchDimension = true;
|
||||||
|
this.miniBatchSize = miniBatchSize;
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
targetShape = this.targetShape;
|
||||||
|
}
|
||||||
if (!this.hasMiniBatchDimension) {
|
if (!this.hasMiniBatchDimension) {
|
||||||
targetShape = prependMiniBatchSize(targetShape, miniBatchSize);
|
targetShape = prependMiniBatchSize(targetShape, miniBatchSize);
|
||||||
inputShape = prependMiniBatchSize(inputShape, miniBatchSize);
|
inputShape = prependMiniBatchSize(inputShape, miniBatchSize);
|
||||||
|
@ -95,7 +106,7 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
|
||||||
if(input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input)){
|
if(input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input)){
|
||||||
input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c');
|
input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c');
|
||||||
}
|
}
|
||||||
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.reshape(this.targetShape));
|
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.reshape(targetShape));
|
||||||
} else {
|
} else {
|
||||||
throw new IllegalStateException("Input shape " + Arrays.toString(input.shape())
|
throw new IllegalStateException("Input shape " + Arrays.toString(input.shape())
|
||||||
+ " and output shape" + Arrays.toString(inputShape) + " do not match");
|
+ " and output shape" + Arrays.toString(inputShape) + " do not match");
|
||||||
|
@ -122,20 +133,27 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
|
||||||
@Override
|
@Override
|
||||||
public InputType getOutputType(InputType inputType) throws InvalidInputTypeException {
|
public InputType getOutputType(InputType inputType) throws InvalidInputTypeException {
|
||||||
val shape = hasMiniBatchDimension ? targetShape : prependMiniBatchSize(targetShape, 0);
|
val shape = hasMiniBatchDimension ? targetShape : prependMiniBatchSize(targetShape, 0);
|
||||||
|
InputType ret;
|
||||||
switch (shape.length) {
|
switch (shape.length) {
|
||||||
case 2:
|
case 2:
|
||||||
return InputType.feedForward(shape[1]);
|
ret = InputType.feedForward(shape[1]);
|
||||||
|
break;
|
||||||
case 3:
|
case 3:
|
||||||
return InputType.recurrent(shape[2], shape[1]);
|
ret = InputType.recurrent(shape[2], shape[1]);
|
||||||
|
break;
|
||||||
case 4:
|
case 4:
|
||||||
if (inputShape.length == 1 || inputType.getType() == InputType.Type.RNN){
|
if (inputShape.length == 1 || inputType.getType() == InputType.Type.RNN){
|
||||||
return InputType.convolutional(shape[1], shape[2], shape[3]);
|
ret = InputType.convolutional(shape[1], shape[2], shape[3]);
|
||||||
}else {
|
}else {
|
||||||
return InputType.convolutional(shape[2], shape[3], shape[1]);
|
ret = InputType.convolutional(shape[2], shape[3], shape[1]);
|
||||||
}
|
}
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
throw new UnsupportedOperationException(
|
throw new UnsupportedOperationException(
|
||||||
"Cannot infer input type for reshape array " + Arrays.toString(shape));
|
"Cannot infer input type for reshape array " + Arrays.toString(shape));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
this.staticTargetShape = ret.getShape();
|
||||||
|
return ret;
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -255,6 +255,17 @@ public class Keras2ModelConfigurationTest extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void ReshapeEmbeddingConcatTest() throws Exception{
|
||||||
|
try(InputStream is = Resources.asStream("/modelimport/keras/configs/keras2/reshape_embedding_concat.json")) {
|
||||||
|
ComputationGraphConfiguration config =
|
||||||
|
new KerasModel().modelBuilder().modelJsonInputStream(is)
|
||||||
|
.enforceTrainingConfig(false).buildModel().getComputationGraphConfiguration();
|
||||||
|
ComputationGraph model = new ComputationGraph(config);
|
||||||
|
model.init();
|
||||||
|
model.outputSingle(Nd4j.zeros(1, 1), Nd4j.zeros(1, 1), Nd4j.zeros(1, 1));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private void runSequentialConfigTest(String path) throws Exception {
|
private void runSequentialConfigTest(String path) throws Exception {
|
||||||
try(InputStream is = Resources.asStream(path)) {
|
try(InputStream is = Resources.asStream(path)) {
|
||||||
|
|
Loading…
Reference in New Issue