Add options for changing the default CNN2D format (#8996)

This means that when a model is being loaded (e.g., from a Keras H5 file), the default CNN2DFormat can be set. Currently it always defaults to 'Channels First' which can cause problems for some models.
master
Rhys Compton 2020-07-25 23:32:18 +12:00 committed by GitHub
parent 5ff8d28b89
commit 026ebabf77
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 10 additions and 1 deletions

View File

@ -57,6 +57,15 @@ public abstract class InputType implements Serializable {
FF, RNN, CNN, CNNFlat, CNN3D FF, RNN, CNN, CNNFlat, CNN3D
} }
public static CNN2DFormat getDefaultCNN2DFormat() {
return defaultCNN2DFormat;
}
public static void setDefaultCNN2DFormat(CNN2DFormat defaultCNN2DFormat) {
InputType.defaultCNN2DFormat = defaultCNN2DFormat;
}
private static CNN2DFormat defaultCNN2DFormat = CNN2DFormat.NCHW;
@JsonIgnore @JsonIgnore
public abstract Type getType(); public abstract Type getType();
@ -137,7 +146,7 @@ public abstract class InputType implements Serializable {
* @return InputTypeConvolutional * @return InputTypeConvolutional
*/ */
public static InputType convolutional(long height, long width, long depth) { public static InputType convolutional(long height, long width, long depth) {
return convolutional(height, width, depth, CNN2DFormat.NCHW); return convolutional(height, width, depth, getDefaultCNN2DFormat());
} }
public static InputType convolutional(long height, long width, long depth, CNN2DFormat format){ public static InputType convolutional(long height, long width, long depth, CNN2DFormat format){