From 87cf665e22e2bb1296999ddd38ab5f84f5f7330a Mon Sep 17 00:00:00 2001 From: agibsonccc Date: Fri, 5 Mar 2021 10:30:46 +0900 Subject: [PATCH] Keras cropping fixes --- .../convolutional/KerasConvolutionUtils.java | 27 ++++++++++--------- .../layers/convolutional/KerasCropping2D.java | 2 ++ 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolutionUtils.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolutionUtils.java index 271dcd4a4..0a03b1dc7 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolutionUtils.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolutionUtils.java @@ -59,7 +59,7 @@ public class KerasConvolutionUtils { List stridesList = (List) innerConfig.get(conf.getLAYER_FIELD_CONVOLUTION_STRIDES()); strides = ArrayUtil.toArray(stridesList); } else if (innerConfig.containsKey(conf.getLAYER_FIELD_SUBSAMPLE_LENGTH()) && dimension == 1) { - /* 1D Convolutional layers. */ + /* 1D Convolutional layers. */ if ((int) layerConfig.get("keras_version") == 2) { @SuppressWarnings("unchecked") List stridesList = (List) innerConfig.get(conf.getLAYER_FIELD_SUBSAMPLE_LENGTH()); @@ -163,7 +163,7 @@ public class KerasConvolutionUtils { * @throws InvalidKerasConfigurationException Invalid Keras configuration */ static int[] getUpsamplingSizeFromConfig(Map layerConfig, int dimension, - KerasLayerConfiguration conf) + KerasLayerConfiguration conf) throws InvalidKerasConfigurationException { Map innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf); int[] size; @@ -200,7 +200,7 @@ public class KerasConvolutionUtils { if (kerasMajorVersion != 2) { if (innerConfig.containsKey(conf.getLAYER_FIELD_NB_ROW()) && dimension == 2 && innerConfig.containsKey(conf.getLAYER_FIELD_NB_COL())) { - /* 2D Convolutional layers. */ + /* 2D Convolutional layers. */ List kernelSizeList = new ArrayList<>(); kernelSizeList.add((Integer) innerConfig.get(conf.getLAYER_FIELD_NB_ROW())); kernelSizeList.add((Integer) innerConfig.get(conf.getLAYER_FIELD_NB_COL())); @@ -208,23 +208,23 @@ public class KerasConvolutionUtils { } else if (innerConfig.containsKey(conf.getLAYER_FIELD_3D_KERNEL_1()) && dimension == 3 && innerConfig.containsKey(conf.getLAYER_FIELD_3D_KERNEL_2()) && innerConfig.containsKey(conf.getLAYER_FIELD_3D_KERNEL_3())) { - /* 3D Convolutional layers. */ + /* 3D Convolutional layers. */ List kernelSizeList = new ArrayList<>(); kernelSizeList.add((Integer) innerConfig.get(conf.getLAYER_FIELD_3D_KERNEL_1())); kernelSizeList.add((Integer) innerConfig.get(conf.getLAYER_FIELD_3D_KERNEL_2())); kernelSizeList.add((Integer) innerConfig.get(conf.getLAYER_FIELD_3D_KERNEL_3())); kernelSize = ArrayUtil.toArray(kernelSizeList); } else if (innerConfig.containsKey(conf.getLAYER_FIELD_FILTER_LENGTH()) && dimension == 1) { - /* 1D Convolutional layers. */ + /* 1D Convolutional layers. */ int filterLength = (int) innerConfig.get(conf.getLAYER_FIELD_FILTER_LENGTH()); kernelSize = new int[]{filterLength}; } else if (innerConfig.containsKey(conf.getLAYER_FIELD_POOL_SIZE()) && dimension >= 2) { - /* 2D/3D Pooling layers. */ + /* 2D/3D Pooling layers. */ @SuppressWarnings("unchecked") List kernelSizeList = (List) innerConfig.get(conf.getLAYER_FIELD_POOL_SIZE()); kernelSize = ArrayUtil.toArray(kernelSizeList); } else if (innerConfig.containsKey(conf.getLAYER_FIELD_POOL_1D_SIZE()) && dimension == 1) { - /* 1D Pooling layers. */ + /* 1D Pooling layers. */ int poolSize1D = (int) innerConfig.get(conf.getLAYER_FIELD_POOL_1D_SIZE()); kernelSize = new int[]{poolSize1D}; } else { @@ -242,17 +242,17 @@ public class KerasConvolutionUtils { List kernelSizeList = (List) innerConfig.get(conf.getLAYER_FIELD_KERNEL_SIZE()); kernelSize = ArrayUtil.toArray(kernelSizeList); } else if (innerConfig.containsKey(conf.getLAYER_FIELD_FILTER_LENGTH()) && dimension == 1) { - /* 1D Convolutional layers. */ + /* 1D Convolutional layers. */ @SuppressWarnings("unchecked") List kernelSizeList = (List) innerConfig.get(conf.getLAYER_FIELD_FILTER_LENGTH()); kernelSize = ArrayUtil.toArray(kernelSizeList); } else if (innerConfig.containsKey(conf.getLAYER_FIELD_POOL_SIZE()) && dimension >= 2) { - /* 2D Pooling layers. */ + /* 2D Pooling layers. */ @SuppressWarnings("unchecked") List kernelSizeList = (List) innerConfig.get(conf.getLAYER_FIELD_POOL_SIZE()); kernelSize = ArrayUtil.toArray(kernelSizeList); } else if (innerConfig.containsKey(conf.getLAYER_FIELD_POOL_1D_SIZE()) && dimension == 1) { - /* 1D Pooling layers. */ + /* 1D Pooling layers. */ @SuppressWarnings("unchecked") List kernelSizeList = (List) innerConfig.get(conf.getLAYER_FIELD_POOL_1D_SIZE()); kernelSize = ArrayUtil.toArray(kernelSizeList); @@ -364,16 +364,17 @@ public class KerasConvolutionUtils { } if ((paddingNoCast.size() == dimension) && !isNested) { - for (int i=0; i < dimension; i++) + for (int i = 0; i < dimension; i++) paddingList.add((int) paddingNoCast.get(i)); padding = ArrayUtil.toArray(paddingList); } else if ((paddingNoCast.size() == dimension) && isNested) { - for (int j=0; j < dimension; j++) { + for (int j = 0; j < dimension; j++) { @SuppressWarnings("unchecked") - List item = (List) paddingNoCast.get(0); + List item = (List) paddingNoCast.get(j); paddingList.add((item.get(0))); paddingList.add((item.get(1))); } + padding = ArrayUtil.toArray(paddingList); } else { throw new InvalidKerasConfigurationException("Found Keras ZeroPadding" + dimension diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasCropping2D.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasCropping2D.java index b4df34c5b..66d49d37a 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasCropping2D.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasCropping2D.java @@ -29,6 +29,8 @@ import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; +import org.nd4j.common.util.ArrayUtil; +import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Map;