diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/inputs/InputType.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/inputs/InputType.java index ce5e8c78f..99704c0bb 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/inputs/InputType.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/inputs/InputType.java @@ -57,6 +57,15 @@ public abstract class InputType implements Serializable { FF, RNN, CNN, CNNFlat, CNN3D } + public static CNN2DFormat getDefaultCNN2DFormat() { + return defaultCNN2DFormat; + } + + public static void setDefaultCNN2DFormat(CNN2DFormat defaultCNN2DFormat) { + InputType.defaultCNN2DFormat = defaultCNN2DFormat; + } + + private static CNN2DFormat defaultCNN2DFormat = CNN2DFormat.NCHW; @JsonIgnore public abstract Type getType(); @@ -137,7 +146,7 @@ public abstract class InputType implements Serializable { * @return InputTypeConvolutional */ public static InputType convolutional(long height, long width, long depth) { - return convolutional(height, width, depth, CNN2DFormat.NCHW); + return convolutional(height, width, depth, getDefaultCNN2DFormat()); } public static InputType convolutional(long height, long width, long depth, CNN2DFormat format){