Keras cropping fixes
parent
2ae9f58909
commit
87cf665e22
|
@ -59,7 +59,7 @@ public class KerasConvolutionUtils {
|
||||||
List<Integer> stridesList = (List<Integer>) innerConfig.get(conf.getLAYER_FIELD_CONVOLUTION_STRIDES());
|
List<Integer> stridesList = (List<Integer>) innerConfig.get(conf.getLAYER_FIELD_CONVOLUTION_STRIDES());
|
||||||
strides = ArrayUtil.toArray(stridesList);
|
strides = ArrayUtil.toArray(stridesList);
|
||||||
} else if (innerConfig.containsKey(conf.getLAYER_FIELD_SUBSAMPLE_LENGTH()) && dimension == 1) {
|
} else if (innerConfig.containsKey(conf.getLAYER_FIELD_SUBSAMPLE_LENGTH()) && dimension == 1) {
|
||||||
/* 1D Convolutional layers. */
|
/* 1D Convolutional layers. */
|
||||||
if ((int) layerConfig.get("keras_version") == 2) {
|
if ((int) layerConfig.get("keras_version") == 2) {
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
List<Integer> stridesList = (List<Integer>) innerConfig.get(conf.getLAYER_FIELD_SUBSAMPLE_LENGTH());
|
List<Integer> stridesList = (List<Integer>) innerConfig.get(conf.getLAYER_FIELD_SUBSAMPLE_LENGTH());
|
||||||
|
@ -163,7 +163,7 @@ public class KerasConvolutionUtils {
|
||||||
* @throws InvalidKerasConfigurationException Invalid Keras configuration
|
* @throws InvalidKerasConfigurationException Invalid Keras configuration
|
||||||
*/
|
*/
|
||||||
static int[] getUpsamplingSizeFromConfig(Map<String, Object> layerConfig, int dimension,
|
static int[] getUpsamplingSizeFromConfig(Map<String, Object> layerConfig, int dimension,
|
||||||
KerasLayerConfiguration conf)
|
KerasLayerConfiguration conf)
|
||||||
throws InvalidKerasConfigurationException {
|
throws InvalidKerasConfigurationException {
|
||||||
Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
|
Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
|
||||||
int[] size;
|
int[] size;
|
||||||
|
@ -200,7 +200,7 @@ public class KerasConvolutionUtils {
|
||||||
if (kerasMajorVersion != 2) {
|
if (kerasMajorVersion != 2) {
|
||||||
if (innerConfig.containsKey(conf.getLAYER_FIELD_NB_ROW()) && dimension == 2
|
if (innerConfig.containsKey(conf.getLAYER_FIELD_NB_ROW()) && dimension == 2
|
||||||
&& innerConfig.containsKey(conf.getLAYER_FIELD_NB_COL())) {
|
&& innerConfig.containsKey(conf.getLAYER_FIELD_NB_COL())) {
|
||||||
/* 2D Convolutional layers. */
|
/* 2D Convolutional layers. */
|
||||||
List<Integer> kernelSizeList = new ArrayList<>();
|
List<Integer> kernelSizeList = new ArrayList<>();
|
||||||
kernelSizeList.add((Integer) innerConfig.get(conf.getLAYER_FIELD_NB_ROW()));
|
kernelSizeList.add((Integer) innerConfig.get(conf.getLAYER_FIELD_NB_ROW()));
|
||||||
kernelSizeList.add((Integer) innerConfig.get(conf.getLAYER_FIELD_NB_COL()));
|
kernelSizeList.add((Integer) innerConfig.get(conf.getLAYER_FIELD_NB_COL()));
|
||||||
|
@ -208,23 +208,23 @@ public class KerasConvolutionUtils {
|
||||||
} else if (innerConfig.containsKey(conf.getLAYER_FIELD_3D_KERNEL_1()) && dimension == 3
|
} else if (innerConfig.containsKey(conf.getLAYER_FIELD_3D_KERNEL_1()) && dimension == 3
|
||||||
&& innerConfig.containsKey(conf.getLAYER_FIELD_3D_KERNEL_2())
|
&& innerConfig.containsKey(conf.getLAYER_FIELD_3D_KERNEL_2())
|
||||||
&& innerConfig.containsKey(conf.getLAYER_FIELD_3D_KERNEL_3())) {
|
&& innerConfig.containsKey(conf.getLAYER_FIELD_3D_KERNEL_3())) {
|
||||||
/* 3D Convolutional layers. */
|
/* 3D Convolutional layers. */
|
||||||
List<Integer> kernelSizeList = new ArrayList<>();
|
List<Integer> kernelSizeList = new ArrayList<>();
|
||||||
kernelSizeList.add((Integer) innerConfig.get(conf.getLAYER_FIELD_3D_KERNEL_1()));
|
kernelSizeList.add((Integer) innerConfig.get(conf.getLAYER_FIELD_3D_KERNEL_1()));
|
||||||
kernelSizeList.add((Integer) innerConfig.get(conf.getLAYER_FIELD_3D_KERNEL_2()));
|
kernelSizeList.add((Integer) innerConfig.get(conf.getLAYER_FIELD_3D_KERNEL_2()));
|
||||||
kernelSizeList.add((Integer) innerConfig.get(conf.getLAYER_FIELD_3D_KERNEL_3()));
|
kernelSizeList.add((Integer) innerConfig.get(conf.getLAYER_FIELD_3D_KERNEL_3()));
|
||||||
kernelSize = ArrayUtil.toArray(kernelSizeList);
|
kernelSize = ArrayUtil.toArray(kernelSizeList);
|
||||||
} else if (innerConfig.containsKey(conf.getLAYER_FIELD_FILTER_LENGTH()) && dimension == 1) {
|
} else if (innerConfig.containsKey(conf.getLAYER_FIELD_FILTER_LENGTH()) && dimension == 1) {
|
||||||
/* 1D Convolutional layers. */
|
/* 1D Convolutional layers. */
|
||||||
int filterLength = (int) innerConfig.get(conf.getLAYER_FIELD_FILTER_LENGTH());
|
int filterLength = (int) innerConfig.get(conf.getLAYER_FIELD_FILTER_LENGTH());
|
||||||
kernelSize = new int[]{filterLength};
|
kernelSize = new int[]{filterLength};
|
||||||
} else if (innerConfig.containsKey(conf.getLAYER_FIELD_POOL_SIZE()) && dimension >= 2) {
|
} else if (innerConfig.containsKey(conf.getLAYER_FIELD_POOL_SIZE()) && dimension >= 2) {
|
||||||
/* 2D/3D Pooling layers. */
|
/* 2D/3D Pooling layers. */
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
List<Integer> kernelSizeList = (List<Integer>) innerConfig.get(conf.getLAYER_FIELD_POOL_SIZE());
|
List<Integer> kernelSizeList = (List<Integer>) innerConfig.get(conf.getLAYER_FIELD_POOL_SIZE());
|
||||||
kernelSize = ArrayUtil.toArray(kernelSizeList);
|
kernelSize = ArrayUtil.toArray(kernelSizeList);
|
||||||
} else if (innerConfig.containsKey(conf.getLAYER_FIELD_POOL_1D_SIZE()) && dimension == 1) {
|
} else if (innerConfig.containsKey(conf.getLAYER_FIELD_POOL_1D_SIZE()) && dimension == 1) {
|
||||||
/* 1D Pooling layers. */
|
/* 1D Pooling layers. */
|
||||||
int poolSize1D = (int) innerConfig.get(conf.getLAYER_FIELD_POOL_1D_SIZE());
|
int poolSize1D = (int) innerConfig.get(conf.getLAYER_FIELD_POOL_1D_SIZE());
|
||||||
kernelSize = new int[]{poolSize1D};
|
kernelSize = new int[]{poolSize1D};
|
||||||
} else {
|
} else {
|
||||||
|
@ -242,17 +242,17 @@ public class KerasConvolutionUtils {
|
||||||
List<Integer> kernelSizeList = (List<Integer>) innerConfig.get(conf.getLAYER_FIELD_KERNEL_SIZE());
|
List<Integer> kernelSizeList = (List<Integer>) innerConfig.get(conf.getLAYER_FIELD_KERNEL_SIZE());
|
||||||
kernelSize = ArrayUtil.toArray(kernelSizeList);
|
kernelSize = ArrayUtil.toArray(kernelSizeList);
|
||||||
} else if (innerConfig.containsKey(conf.getLAYER_FIELD_FILTER_LENGTH()) && dimension == 1) {
|
} else if (innerConfig.containsKey(conf.getLAYER_FIELD_FILTER_LENGTH()) && dimension == 1) {
|
||||||
/* 1D Convolutional layers. */
|
/* 1D Convolutional layers. */
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
List<Integer> kernelSizeList = (List<Integer>) innerConfig.get(conf.getLAYER_FIELD_FILTER_LENGTH());
|
List<Integer> kernelSizeList = (List<Integer>) innerConfig.get(conf.getLAYER_FIELD_FILTER_LENGTH());
|
||||||
kernelSize = ArrayUtil.toArray(kernelSizeList);
|
kernelSize = ArrayUtil.toArray(kernelSizeList);
|
||||||
} else if (innerConfig.containsKey(conf.getLAYER_FIELD_POOL_SIZE()) && dimension >= 2) {
|
} else if (innerConfig.containsKey(conf.getLAYER_FIELD_POOL_SIZE()) && dimension >= 2) {
|
||||||
/* 2D Pooling layers. */
|
/* 2D Pooling layers. */
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
List<Integer> kernelSizeList = (List<Integer>) innerConfig.get(conf.getLAYER_FIELD_POOL_SIZE());
|
List<Integer> kernelSizeList = (List<Integer>) innerConfig.get(conf.getLAYER_FIELD_POOL_SIZE());
|
||||||
kernelSize = ArrayUtil.toArray(kernelSizeList);
|
kernelSize = ArrayUtil.toArray(kernelSizeList);
|
||||||
} else if (innerConfig.containsKey(conf.getLAYER_FIELD_POOL_1D_SIZE()) && dimension == 1) {
|
} else if (innerConfig.containsKey(conf.getLAYER_FIELD_POOL_1D_SIZE()) && dimension == 1) {
|
||||||
/* 1D Pooling layers. */
|
/* 1D Pooling layers. */
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
List<Integer> kernelSizeList = (List<Integer>) innerConfig.get(conf.getLAYER_FIELD_POOL_1D_SIZE());
|
List<Integer> kernelSizeList = (List<Integer>) innerConfig.get(conf.getLAYER_FIELD_POOL_1D_SIZE());
|
||||||
kernelSize = ArrayUtil.toArray(kernelSizeList);
|
kernelSize = ArrayUtil.toArray(kernelSizeList);
|
||||||
|
@ -364,16 +364,17 @@ public class KerasConvolutionUtils {
|
||||||
}
|
}
|
||||||
|
|
||||||
if ((paddingNoCast.size() == dimension) && !isNested) {
|
if ((paddingNoCast.size() == dimension) && !isNested) {
|
||||||
for (int i=0; i < dimension; i++)
|
for (int i = 0; i < dimension; i++)
|
||||||
paddingList.add((int) paddingNoCast.get(i));
|
paddingList.add((int) paddingNoCast.get(i));
|
||||||
padding = ArrayUtil.toArray(paddingList);
|
padding = ArrayUtil.toArray(paddingList);
|
||||||
} else if ((paddingNoCast.size() == dimension) && isNested) {
|
} else if ((paddingNoCast.size() == dimension) && isNested) {
|
||||||
for (int j=0; j < dimension; j++) {
|
for (int j = 0; j < dimension; j++) {
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
List<Integer> item = (List<Integer>) paddingNoCast.get(0);
|
List<Integer> item = (List<Integer>) paddingNoCast.get(j);
|
||||||
paddingList.add((item.get(0)));
|
paddingList.add((item.get(0)));
|
||||||
paddingList.add((item.get(1)));
|
paddingList.add((item.get(1)));
|
||||||
}
|
}
|
||||||
|
|
||||||
padding = ArrayUtil.toArray(paddingList);
|
padding = ArrayUtil.toArray(paddingList);
|
||||||
} else {
|
} else {
|
||||||
throw new InvalidKerasConfigurationException("Found Keras ZeroPadding" + dimension
|
throw new InvalidKerasConfigurationException("Found Keras ZeroPadding" + dimension
|
||||||
|
|
|
@ -29,6 +29,8 @@ import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||||
|
import org.nd4j.common.util.ArrayUtil;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue