Keras cropping fixes
This commit is contained in:
		
							parent
							
								
									2ae9f58909
								
							
						
					
					
						commit
						87cf665e22
					
				| @ -59,7 +59,7 @@ public class KerasConvolutionUtils { | ||||
|             List<Integer> stridesList = (List<Integer>) innerConfig.get(conf.getLAYER_FIELD_CONVOLUTION_STRIDES()); | ||||
|             strides = ArrayUtil.toArray(stridesList); | ||||
|         } else if (innerConfig.containsKey(conf.getLAYER_FIELD_SUBSAMPLE_LENGTH()) && dimension == 1) { | ||||
|            /* 1D Convolutional layers. */ | ||||
|             /* 1D Convolutional layers. */ | ||||
|             if ((int) layerConfig.get("keras_version") == 2) { | ||||
|                 @SuppressWarnings("unchecked") | ||||
|                 List<Integer> stridesList = (List<Integer>) innerConfig.get(conf.getLAYER_FIELD_SUBSAMPLE_LENGTH()); | ||||
| @ -163,7 +163,7 @@ public class KerasConvolutionUtils { | ||||
|      * @throws InvalidKerasConfigurationException Invalid Keras configuration | ||||
|      */ | ||||
|     static int[] getUpsamplingSizeFromConfig(Map<String, Object> layerConfig, int dimension, | ||||
|                                                     KerasLayerConfiguration conf) | ||||
|                                              KerasLayerConfiguration conf) | ||||
|             throws InvalidKerasConfigurationException { | ||||
|         Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf); | ||||
|         int[] size; | ||||
| @ -200,7 +200,7 @@ public class KerasConvolutionUtils { | ||||
|         if (kerasMajorVersion != 2) { | ||||
|             if (innerConfig.containsKey(conf.getLAYER_FIELD_NB_ROW()) && dimension == 2 | ||||
|                     && innerConfig.containsKey(conf.getLAYER_FIELD_NB_COL())) { | ||||
|             /* 2D Convolutional layers. */ | ||||
|                 /* 2D Convolutional layers. */ | ||||
|                 List<Integer> kernelSizeList = new ArrayList<>(); | ||||
|                 kernelSizeList.add((Integer) innerConfig.get(conf.getLAYER_FIELD_NB_ROW())); | ||||
|                 kernelSizeList.add((Integer) innerConfig.get(conf.getLAYER_FIELD_NB_COL())); | ||||
| @ -208,23 +208,23 @@ public class KerasConvolutionUtils { | ||||
|             } else if (innerConfig.containsKey(conf.getLAYER_FIELD_3D_KERNEL_1()) && dimension == 3 | ||||
|                     && innerConfig.containsKey(conf.getLAYER_FIELD_3D_KERNEL_2()) | ||||
|                     && innerConfig.containsKey(conf.getLAYER_FIELD_3D_KERNEL_3())) { | ||||
|             /* 3D Convolutional layers. */ | ||||
|                 /* 3D Convolutional layers. */ | ||||
|                 List<Integer> kernelSizeList = new ArrayList<>(); | ||||
|                 kernelSizeList.add((Integer) innerConfig.get(conf.getLAYER_FIELD_3D_KERNEL_1())); | ||||
|                 kernelSizeList.add((Integer) innerConfig.get(conf.getLAYER_FIELD_3D_KERNEL_2())); | ||||
|                 kernelSizeList.add((Integer) innerConfig.get(conf.getLAYER_FIELD_3D_KERNEL_3())); | ||||
|                 kernelSize = ArrayUtil.toArray(kernelSizeList); | ||||
|             } else if (innerConfig.containsKey(conf.getLAYER_FIELD_FILTER_LENGTH()) && dimension == 1) { | ||||
|             /* 1D Convolutional layers. */ | ||||
|                 /* 1D Convolutional layers. */ | ||||
|                 int filterLength = (int) innerConfig.get(conf.getLAYER_FIELD_FILTER_LENGTH()); | ||||
|                 kernelSize = new int[]{filterLength}; | ||||
|             } else if (innerConfig.containsKey(conf.getLAYER_FIELD_POOL_SIZE()) && dimension >= 2) { | ||||
|             /* 2D/3D Pooling layers. */ | ||||
|                 /* 2D/3D Pooling layers. */ | ||||
|                 @SuppressWarnings("unchecked") | ||||
|                 List<Integer> kernelSizeList = (List<Integer>) innerConfig.get(conf.getLAYER_FIELD_POOL_SIZE()); | ||||
|                 kernelSize = ArrayUtil.toArray(kernelSizeList); | ||||
|             } else if (innerConfig.containsKey(conf.getLAYER_FIELD_POOL_1D_SIZE()) && dimension == 1) { | ||||
|             /* 1D Pooling layers. */ | ||||
|                 /* 1D Pooling layers. */ | ||||
|                 int poolSize1D = (int) innerConfig.get(conf.getLAYER_FIELD_POOL_1D_SIZE()); | ||||
|                 kernelSize = new int[]{poolSize1D}; | ||||
|             } else { | ||||
| @ -242,17 +242,17 @@ public class KerasConvolutionUtils { | ||||
|                 List<Integer> kernelSizeList = (List<Integer>) innerConfig.get(conf.getLAYER_FIELD_KERNEL_SIZE()); | ||||
|                 kernelSize = ArrayUtil.toArray(kernelSizeList); | ||||
|             } else if (innerConfig.containsKey(conf.getLAYER_FIELD_FILTER_LENGTH()) && dimension == 1) { | ||||
|             /* 1D Convolutional layers. */ | ||||
|                 /* 1D Convolutional layers. */ | ||||
|                 @SuppressWarnings("unchecked") | ||||
|                 List<Integer> kernelSizeList = (List<Integer>) innerConfig.get(conf.getLAYER_FIELD_FILTER_LENGTH()); | ||||
|                 kernelSize = ArrayUtil.toArray(kernelSizeList); | ||||
|             } else if (innerConfig.containsKey(conf.getLAYER_FIELD_POOL_SIZE()) && dimension >= 2) { | ||||
|             /* 2D Pooling layers. */ | ||||
|                 /* 2D Pooling layers. */ | ||||
|                 @SuppressWarnings("unchecked") | ||||
|                 List<Integer> kernelSizeList = (List<Integer>) innerConfig.get(conf.getLAYER_FIELD_POOL_SIZE()); | ||||
|                 kernelSize = ArrayUtil.toArray(kernelSizeList); | ||||
|             } else if (innerConfig.containsKey(conf.getLAYER_FIELD_POOL_1D_SIZE()) && dimension == 1) { | ||||
|             /* 1D Pooling layers. */ | ||||
|                 /* 1D Pooling layers. */ | ||||
|                 @SuppressWarnings("unchecked") | ||||
|                 List<Integer> kernelSizeList = (List<Integer>) innerConfig.get(conf.getLAYER_FIELD_POOL_1D_SIZE()); | ||||
|                 kernelSize = ArrayUtil.toArray(kernelSizeList); | ||||
| @ -364,16 +364,17 @@ public class KerasConvolutionUtils { | ||||
|                 } | ||||
| 
 | ||||
|                 if ((paddingNoCast.size() == dimension) && !isNested) { | ||||
|                     for (int i=0; i < dimension; i++) | ||||
|                     for (int i = 0; i < dimension; i++) | ||||
|                         paddingList.add((int) paddingNoCast.get(i)); | ||||
|                     padding = ArrayUtil.toArray(paddingList); | ||||
|                 } else if ((paddingNoCast.size() == dimension) && isNested) { | ||||
|                     for (int j=0; j < dimension; j++) { | ||||
|                     for (int j = 0; j < dimension; j++) { | ||||
|                         @SuppressWarnings("unchecked") | ||||
|                         List<Integer> item = (List<Integer>) paddingNoCast.get(0); | ||||
|                         List<Integer> item = (List<Integer>) paddingNoCast.get(j); | ||||
|                         paddingList.add((item.get(0))); | ||||
|                         paddingList.add((item.get(1))); | ||||
|                     } | ||||
| 
 | ||||
|                     padding = ArrayUtil.toArray(paddingList); | ||||
|                 } else { | ||||
|                     throw new InvalidKerasConfigurationException("Found Keras ZeroPadding" + dimension | ||||
|  | ||||
| @ -29,6 +29,8 @@ import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D; | ||||
| import org.deeplearning4j.nn.modelimport.keras.KerasLayer; | ||||
| import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; | ||||
| import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; | ||||
| import org.nd4j.common.util.ArrayUtil; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| 
 | ||||
| import java.util.Map; | ||||
| 
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user