From 026ebabf77e3282215fcc179778be6f7d90784e0 Mon Sep 17 00:00:00 2001 From: Rhys Compton Date: Sat, 25 Jul 2020 23:32:18 +1200 Subject: [PATCH] Add options for changing the default CNN2D format (#8996) This means that when a model is being loaded (e.g., from a Keras H5 file), the default CNN2DFormat can be set. Currently it always defaults to 'Channels First' which can cause problems for some models. --- .../org/deeplearning4j/nn/conf/inputs/InputType.java | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/inputs/InputType.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/inputs/InputType.java index ce5e8c78f..99704c0bb 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/inputs/InputType.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/inputs/InputType.java @@ -57,6 +57,15 @@ public abstract class InputType implements Serializable { FF, RNN, CNN, CNNFlat, CNN3D } + public static CNN2DFormat getDefaultCNN2DFormat() { + return defaultCNN2DFormat; + } + + public static void setDefaultCNN2DFormat(CNN2DFormat defaultCNN2DFormat) { + InputType.defaultCNN2DFormat = defaultCNN2DFormat; + } + + private static CNN2DFormat defaultCNN2DFormat = CNN2DFormat.NCHW; @JsonIgnore public abstract Type getType(); @@ -137,7 +146,7 @@ public abstract class InputType implements Serializable { * @return InputTypeConvolutional */ public static InputType convolutional(long height, long width, long depth) { - return convolutional(height, width, depth, CNN2DFormat.NCHW); + return convolutional(height, width, depth, getDefaultCNN2DFormat()); } public static InputType convolutional(long height, long width, long depth, CNN2DFormat format){