Keras cropping fixes
parent
2ae9f58909
commit
87cf665e22
|
@ -364,16 +364,17 @@ public class KerasConvolutionUtils {
|
|||
}
|
||||
|
||||
if ((paddingNoCast.size() == dimension) && !isNested) {
|
||||
for (int i=0; i < dimension; i++)
|
||||
for (int i = 0; i < dimension; i++)
|
||||
paddingList.add((int) paddingNoCast.get(i));
|
||||
padding = ArrayUtil.toArray(paddingList);
|
||||
} else if ((paddingNoCast.size() == dimension) && isNested) {
|
||||
for (int j=0; j < dimension; j++) {
|
||||
for (int j = 0; j < dimension; j++) {
|
||||
@SuppressWarnings("unchecked")
|
||||
List<Integer> item = (List<Integer>) paddingNoCast.get(0);
|
||||
List<Integer> item = (List<Integer>) paddingNoCast.get(j);
|
||||
paddingList.add((item.get(0)));
|
||||
paddingList.add((item.get(1)));
|
||||
}
|
||||
|
||||
padding = ArrayUtil.toArray(paddingList);
|
||||
} else {
|
||||
throw new InvalidKerasConfigurationException("Found Keras ZeroPadding" + dimension
|
||||
|
|
|
@ -29,6 +29,8 @@ import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D;
|
|||
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
|
||||
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
|
||||
import org.nd4j.common.util.ArrayUtil;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
|
|
Loading…
Reference in New Issue