parent
2ec24c762f
commit
124d0a1965
|
@ -106,7 +106,6 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
|
||||||
public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
|
public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
|
||||||
// the target shape read from a keras config does not have mini-batch size included. We prepend it here dynamically.
|
// the target shape read from a keras config does not have mini-batch size included. We prepend it here dynamically.
|
||||||
long[] targetShape = getShape(this.targetShape, miniBatchSize);
|
long[] targetShape = getShape(this.targetShape, miniBatchSize);
|
||||||
long[] inputShape = getShape(this.inputShape, miniBatchSize);
|
|
||||||
|
|
||||||
if (prodLong(input.shape()) == prodLong((targetShape))) {
|
if (prodLong(input.shape()) == prodLong((targetShape))) {
|
||||||
if (input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input)) {
|
if (input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input)) {
|
||||||
|
@ -115,7 +114,7 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
|
||||||
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.reshape(targetShape));
|
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.reshape(targetShape));
|
||||||
} else {
|
} else {
|
||||||
throw new IllegalStateException("Input shape " + Arrays.toString(input.shape())
|
throw new IllegalStateException("Input shape " + Arrays.toString(input.shape())
|
||||||
+ " and output shape" + Arrays.toString(inputShape) + " do not match");
|
+ " and target shape" + Arrays.toString(targetShape) + " do not match");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -178,4 +177,4 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
|
||||||
}
|
}
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue