casting fixes (#464)
parent
60f103fb03
commit
1c15d0f33e
|
@ -63,10 +63,10 @@ public class KerasReLU extends KerasLayer {
|
||||||
double negativeSlope = 0.0;
|
double negativeSlope = 0.0;
|
||||||
double threshold = 0.0;
|
double threshold = 0.0;
|
||||||
if (innerConfig.containsKey("negative_slope")) {
|
if (innerConfig.containsKey("negative_slope")) {
|
||||||
negativeSlope = (double) innerConfig.get("negative_slope");
|
negativeSlope = ((Number)innerConfig.get("negative_slope")).doubleValue();
|
||||||
}
|
}
|
||||||
if (innerConfig.containsKey("threshold")) {
|
if (innerConfig.containsKey("threshold")) {
|
||||||
threshold = (double) innerConfig.get("threshold");
|
threshold = ((Number)innerConfig.get("threshold")).doubleValue();
|
||||||
}
|
}
|
||||||
|
|
||||||
this.layer = new ActivationLayer.Builder().name(this.layerName)
|
this.layer = new ActivationLayer.Builder().name(this.layerName)
|
||||||
|
|
|
@ -32,6 +32,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
|
||||||
|
@ -351,6 +352,10 @@ public class KerasBatchNormalization extends KerasLayer {
|
||||||
private int getBatchNormAxis(Map<String, Object> layerConfig)
|
private int getBatchNormAxis(Map<String, Object> layerConfig)
|
||||||
throws InvalidKerasConfigurationException {
|
throws InvalidKerasConfigurationException {
|
||||||
Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
|
Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
|
||||||
return (int) innerConfig.get(LAYER_FIELD_AXIS);
|
Object batchNormAxis = innerConfig.get(LAYER_FIELD_AXIS);
|
||||||
|
if (batchNormAxis instanceof List){
|
||||||
|
return ((Number)((List)batchNormAxis).get(0)).intValue();
|
||||||
|
}
|
||||||
|
return ((Number)innerConfig.get(LAYER_FIELD_AXIS)).intValue();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue