casting fixes (#464)
parent
60f103fb03
commit
1c15d0f33e
|
@ -63,10 +63,10 @@ public class KerasReLU extends KerasLayer {
|
|||
double negativeSlope = 0.0;
|
||||
double threshold = 0.0;
|
||||
if (innerConfig.containsKey("negative_slope")) {
|
||||
negativeSlope = (double) innerConfig.get("negative_slope");
|
||||
negativeSlope = ((Number)innerConfig.get("negative_slope")).doubleValue();
|
||||
}
|
||||
if (innerConfig.containsKey("threshold")) {
|
||||
threshold = (double) innerConfig.get("threshold");
|
||||
threshold = ((Number)innerConfig.get("threshold")).doubleValue();
|
||||
}
|
||||
|
||||
this.layer = new ActivationLayer.Builder().name(this.layerName)
|
||||
|
|
|
@ -32,6 +32,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
|||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
|
@ -351,6 +352,10 @@ public class KerasBatchNormalization extends KerasLayer {
|
|||
private int getBatchNormAxis(Map<String, Object> layerConfig)
|
||||
throws InvalidKerasConfigurationException {
|
||||
Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
|
||||
return (int) innerConfig.get(LAYER_FIELD_AXIS);
|
||||
Object batchNormAxis = innerConfig.get(LAYER_FIELD_AXIS);
|
||||
if (batchNormAxis instanceof List){
|
||||
return ((Number)((List)batchNormAxis).get(0)).intValue();
|
||||
}
|
||||
return ((Number)innerConfig.get(LAYER_FIELD_AXIS)).intValue();
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue