Squashed and signed, last try (#9136)

Signed-off-by: mjlorenzo305 <mario@mjlorenzo.com>
master
Mario Lorenzo 2021-01-20 21:50:36 -05:00 committed by GitHub
parent 2ec24c762f
commit 124d0a1965
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 3 deletions

View File

@ -106,7 +106,6 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) { public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
// the target shape read from a keras config does not have mini-batch size included. We prepend it here dynamically. // the target shape read from a keras config does not have mini-batch size included. We prepend it here dynamically.
long[] targetShape = getShape(this.targetShape, miniBatchSize); long[] targetShape = getShape(this.targetShape, miniBatchSize);
long[] inputShape = getShape(this.inputShape, miniBatchSize);
if (prodLong(input.shape()) == prodLong((targetShape))) { if (prodLong(input.shape()) == prodLong((targetShape))) {
if (input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input)) { if (input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input)) {
@ -115,7 +114,7 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.reshape(targetShape)); return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.reshape(targetShape));
} else { } else {
throw new IllegalStateException("Input shape " + Arrays.toString(input.shape()) throw new IllegalStateException("Input shape " + Arrays.toString(input.shape())
+ " and output shape" + Arrays.toString(inputShape) + " do not match"); + " and target shape" + Arrays.toString(targetShape) + " do not match");
} }
} }
@ -178,4 +177,4 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
} }
return ret; return ret;
} }
} }