casting fixes (#464)

master
Fariz Rahman 2020-05-13 18:36:38 +04:00 committed by GitHub
parent 60f103fb03
commit 1c15d0f33e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 3 deletions

View File

@ -63,10 +63,10 @@ public class KerasReLU extends KerasLayer {
double negativeSlope = 0.0; double negativeSlope = 0.0;
double threshold = 0.0; double threshold = 0.0;
if (innerConfig.containsKey("negative_slope")) { if (innerConfig.containsKey("negative_slope")) {
negativeSlope = (double) innerConfig.get("negative_slope"); negativeSlope = ((Number)innerConfig.get("negative_slope")).doubleValue();
} }
if (innerConfig.containsKey("threshold")) { if (innerConfig.containsKey("threshold")) {
threshold = (double) innerConfig.get("threshold"); threshold = ((Number)innerConfig.get("threshold")).doubleValue();
} }
this.layer = new ActivationLayer.Builder().name(this.layerName) this.layer = new ActivationLayer.Builder().name(this.layerName)

View File

@ -32,6 +32,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import java.util.HashMap; import java.util.HashMap;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
@ -351,6 +352,10 @@ public class KerasBatchNormalization extends KerasLayer {
private int getBatchNormAxis(Map<String, Object> layerConfig) private int getBatchNormAxis(Map<String, Object> layerConfig)
throws InvalidKerasConfigurationException { throws InvalidKerasConfigurationException {
Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf); Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
return (int) innerConfig.get(LAYER_FIELD_AXIS); Object batchNormAxis = innerConfig.get(LAYER_FIELD_AXIS);
if (batchNormAxis instanceof List){
return ((Number)((List)batchNormAxis).get(0)).intValue();
}
return ((Number)innerConfig.get(LAYER_FIELD_AXIS)).intValue();
} }
} }