Add options for changing the default CNN2D format (#8996)
This means that when a model is being loaded (e.g., from a Keras H5 file), the default CNN2DFormat can be set. Currently it always defaults to 'Channels First' which can cause problems for some models.
This commit is contained in:
		
							parent
							
								
									5ff8d28b89
								
							
						
					
					
						commit
						026ebabf77
					
				| @ -57,6 +57,15 @@ public abstract class InputType implements Serializable { | ||||
|         FF, RNN, CNN, CNNFlat, CNN3D | ||||
|     } | ||||
| 
 | ||||
|     public static CNN2DFormat getDefaultCNN2DFormat() { | ||||
|         return defaultCNN2DFormat; | ||||
|     } | ||||
| 
 | ||||
|     public static void setDefaultCNN2DFormat(CNN2DFormat defaultCNN2DFormat) { | ||||
|         InputType.defaultCNN2DFormat = defaultCNN2DFormat; | ||||
|     } | ||||
| 
 | ||||
|     private static CNN2DFormat defaultCNN2DFormat = CNN2DFormat.NCHW; | ||||
| 
 | ||||
|     @JsonIgnore | ||||
|     public abstract Type getType(); | ||||
| @ -137,7 +146,7 @@ public abstract class InputType implements Serializable { | ||||
|      * @return InputTypeConvolutional | ||||
|      */ | ||||
|     public static InputType convolutional(long height, long width, long depth) { | ||||
|         return convolutional(height, width, depth, CNN2DFormat.NCHW); | ||||
|         return convolutional(height, width, depth, getDefaultCNN2DFormat()); | ||||
|     } | ||||
| 
 | ||||
|     public static InputType convolutional(long height, long width, long depth, CNN2DFormat format){ | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user