Add options for changing the default CNN2D format (#8996)
This means that when a model is being loaded (e.g., from a Keras H5 file), the default CNN2DFormat can be set. Currently it always defaults to 'Channels First' which can cause problems for some models.master
parent
5ff8d28b89
commit
026ebabf77
|
@ -57,6 +57,15 @@ public abstract class InputType implements Serializable {
|
|||
FF, RNN, CNN, CNNFlat, CNN3D
|
||||
}
|
||||
|
||||
public static CNN2DFormat getDefaultCNN2DFormat() {
|
||||
return defaultCNN2DFormat;
|
||||
}
|
||||
|
||||
public static void setDefaultCNN2DFormat(CNN2DFormat defaultCNN2DFormat) {
|
||||
InputType.defaultCNN2DFormat = defaultCNN2DFormat;
|
||||
}
|
||||
|
||||
private static CNN2DFormat defaultCNN2DFormat = CNN2DFormat.NCHW;
|
||||
|
||||
@JsonIgnore
|
||||
public abstract Type getType();
|
||||
|
@ -137,7 +146,7 @@ public abstract class InputType implements Serializable {
|
|||
* @return InputTypeConvolutional
|
||||
*/
|
||||
public static InputType convolutional(long height, long width, long depth) {
|
||||
return convolutional(height, width, depth, CNN2DFormat.NCHW);
|
||||
return convolutional(height, width, depth, getDefaultCNN2DFormat());
|
||||
}
|
||||
|
||||
public static InputType convolutional(long height, long width, long depth, CNN2DFormat format){
|
||||
|
|
Loading…
Reference in New Issue