Keras cropping fixes

master
agibsonccc 2021-03-05 10:30:46 +09:00
parent 2ae9f58909
commit 87cf665e22
2 changed files with 16 additions and 13 deletions

View File

@ -370,10 +370,11 @@ public class KerasConvolutionUtils {
} else if ((paddingNoCast.size() == dimension) && isNested) { } else if ((paddingNoCast.size() == dimension) && isNested) {
for (int j = 0; j < dimension; j++) { for (int j = 0; j < dimension; j++) {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
List<Integer> item = (List<Integer>) paddingNoCast.get(0); List<Integer> item = (List<Integer>) paddingNoCast.get(j);
paddingList.add((item.get(0))); paddingList.add((item.get(0)));
paddingList.add((item.get(1))); paddingList.add((item.get(1)));
} }
padding = ArrayUtil.toArray(paddingList); padding = ArrayUtil.toArray(paddingList);
} else { } else {
throw new InvalidKerasConfigurationException("Found Keras ZeroPadding" + dimension throw new InvalidKerasConfigurationException("Found Keras ZeroPadding" + dimension

View File

@ -29,6 +29,8 @@ import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Map; import java.util.Map;