From 2be47082c901162ab7e99b0708df13f3901ddd99 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Sat, 30 Nov 2019 20:08:30 +1100 Subject: [PATCH] #8470 TrainingConfig json fix for Evaluation instances (#93) Signed-off-by: AlexDBlack --- .../autodiff/samediff/TrainingConfig.java | 4 +- .../org/nd4j/evaluation/BaseEvaluation.java | 51 +++---------------- .../org/nd4j/evaluation/curves/BaseCurve.java | 9 ++-- .../nd4j/evaluation/curves/BaseHistogram.java | 9 ++-- .../nd4j/evaluation/serde/ROCSerializer.java | 12 +++-- .../java/org/nd4j/serde/json/JsonMappers.java | 39 ++++++++++++-- .../nd4j/autodiff/samediff/SameDiffTests.java | 19 +++++++ 7 files changed, 81 insertions(+), 62 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/TrainingConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/TrainingConfig.java index d50daddb8..25aec0028 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/TrainingConfig.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/TrainingConfig.java @@ -1,4 +1,4 @@ -/******************************************************************************* +/* ****************************************************************************** * Copyright (c) 2015-2019 Skymind, Inc. * * This program and the accompanying materials are made available under the @@ -18,7 +18,6 @@ package org.nd4j.autodiff.samediff; import lombok.*; import lombok.extern.slf4j.Slf4j; -import org.nd4j.autodiff.listeners.ListenerEvaluations; import org.nd4j.base.Preconditions; import org.nd4j.evaluation.IEvaluation; import org.nd4j.linalg.learning.config.IUpdater; @@ -64,6 +63,7 @@ public class TrainingConfig { private int iterationCount; private int epochCount; + private Map> trainEvaluations = new HashMap<>(); private Map trainEvaluationLabels = new HashMap<>(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/BaseEvaluation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/BaseEvaluation.java index fd08e4270..3f4ce04f3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/BaseEvaluation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/BaseEvaluation.java @@ -1,5 +1,6 @@ -/******************************************************************************* +/* ****************************************************************************** * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -17,7 +18,6 @@ package org.nd4j.evaluation; import lombok.EqualsAndHashCode; -import lombok.Getter; import lombok.NonNull; import org.nd4j.base.Preconditions; import org.nd4j.evaluation.classification.*; @@ -27,24 +27,13 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastTo; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.primitives.AtomicBoolean; -import org.nd4j.linalg.primitives.AtomicDouble; import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Triple; -import org.nd4j.linalg.primitives.serde.JsonDeserializerAtomicBoolean; -import org.nd4j.linalg.primitives.serde.JsonDeserializerAtomicDouble; -import org.nd4j.linalg.primitives.serde.JsonSerializerAtomicBoolean; -import org.nd4j.linalg.primitives.serde.JsonSerializerAtomicDouble; import org.nd4j.linalg.util.ArrayUtil; -import org.nd4j.shade.jackson.annotation.JsonAutoDetect; +import org.nd4j.serde.json.JsonMappers; import org.nd4j.shade.jackson.core.JsonProcessingException; -import org.nd4j.shade.jackson.databind.DeserializationFeature; -import org.nd4j.shade.jackson.databind.MapperFeature; import org.nd4j.shade.jackson.databind.ObjectMapper; -import org.nd4j.shade.jackson.databind.SerializationFeature; import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException; -import org.nd4j.shade.jackson.databind.module.SimpleModule; -import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory; import java.io.IOException; import java.io.Serializable; @@ -60,32 +49,6 @@ import java.util.List; @EqualsAndHashCode public abstract class BaseEvaluation implements IEvaluation { - @Getter - private static ObjectMapper objectMapper = configureMapper(new ObjectMapper()); - @Getter - private static ObjectMapper yamlMapper = configureMapper(new ObjectMapper(new YAMLFactory())); - - private static ObjectMapper configureMapper(ObjectMapper ret) { - ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); - ret.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false); - ret.configure(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY, false); - ret.enable(SerializationFeature.INDENT_OUTPUT); - SimpleModule atomicModule = new SimpleModule(); - atomicModule.addSerializer(AtomicDouble.class, new JsonSerializerAtomicDouble()); - atomicModule.addSerializer(AtomicBoolean.class, new JsonSerializerAtomicBoolean()); - atomicModule.addDeserializer(AtomicDouble.class, new JsonDeserializerAtomicDouble()); - atomicModule.addDeserializer(AtomicBoolean.class, new JsonDeserializerAtomicBoolean()); - ret.registerModule(atomicModule); - //Serialize fields only, not using getters - ret.setVisibilityChecker(ret.getSerializationConfig().getDefaultVisibilityChecker() - .withFieldVisibility(JsonAutoDetect.Visibility.ANY) - .withGetterVisibility(JsonAutoDetect.Visibility.NONE) - .withSetterVisibility(JsonAutoDetect.Visibility.NONE) - .withCreatorVisibility(JsonAutoDetect.Visibility.ANY) - ); - return ret; - } - /** * @param yaml YAML representation * @param clazz Class @@ -94,7 +57,7 @@ public abstract class BaseEvaluation implements IEvalu */ public static T fromYaml(String yaml, Class clazz) { try { - return yamlMapper.readValue(yaml, clazz); + return JsonMappers.getYamlMapper().readValue(yaml, clazz); } catch (IOException e) { throw new RuntimeException(e); } @@ -108,7 +71,7 @@ public abstract class BaseEvaluation implements IEvalu */ public static T fromJson(String json, Class clazz) { try { - return objectMapper.readValue(json, clazz); + return JsonMappers.getMapper().readValue(json, clazz); } catch (InvalidTypeIdException e) { if (e.getMessage().contains("Could not resolve type id")) { try { @@ -332,7 +295,7 @@ public abstract class BaseEvaluation implements IEvalu @Override public String toJson() { try { - return objectMapper.writeValueAsString(this); + return JsonMappers.getMapper().writeValueAsString(this); } catch (JsonProcessingException e) { throw new RuntimeException(e); } @@ -349,7 +312,7 @@ public abstract class BaseEvaluation implements IEvalu @Override public String toYaml() { try { - return yamlMapper.writeValueAsString(this); + return JsonMappers.getYamlMapper().writeValueAsString(this); } catch (JsonProcessingException e) { throw new RuntimeException(e); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/curves/BaseCurve.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/curves/BaseCurve.java index ee9339da4..2e61e80bd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/curves/BaseCurve.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/curves/BaseCurve.java @@ -17,6 +17,7 @@ package org.nd4j.evaluation.curves; import org.nd4j.evaluation.BaseEvaluation; +import org.nd4j.serde.json.JsonMappers; import org.nd4j.shade.jackson.annotation.JsonTypeInfo; import org.nd4j.shade.jackson.core.JsonProcessingException; @@ -87,7 +88,7 @@ public abstract class BaseCurve { */ public String toJson() { try { - return BaseEvaluation.getObjectMapper().writeValueAsString(this); + return JsonMappers.getMapper().writeValueAsString(this); } catch (JsonProcessingException e) { throw new RuntimeException(e); } @@ -98,7 +99,7 @@ public abstract class BaseCurve { */ public String toYaml() { try { - return BaseEvaluation.getYamlMapper().writeValueAsString(this); + return JsonMappers.getYamlMapper().writeValueAsString(this); } catch (JsonProcessingException e) { throw new RuntimeException(e); } @@ -113,7 +114,7 @@ public abstract class BaseCurve { */ public static T fromJson(String json, Class curveClass) { try { - return BaseEvaluation.getObjectMapper().readValue(json, curveClass); + return JsonMappers.getMapper().readValue(json, curveClass); } catch (IOException e) { throw new RuntimeException(e); } @@ -128,7 +129,7 @@ public abstract class BaseCurve { */ public static T fromYaml(String yaml, Class curveClass) { try { - return BaseEvaluation.getYamlMapper().readValue(yaml, curveClass); + return JsonMappers.getYamlMapper().readValue(yaml, curveClass); } catch (IOException e) { throw new RuntimeException(e); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/curves/BaseHistogram.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/curves/BaseHistogram.java index a941f2088..1adcc32d0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/curves/BaseHistogram.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/curves/BaseHistogram.java @@ -17,6 +17,7 @@ package org.nd4j.evaluation.curves; import org.nd4j.evaluation.BaseEvaluation; +import org.nd4j.serde.json.JsonMappers; import org.nd4j.shade.jackson.annotation.JsonTypeInfo; import org.nd4j.shade.jackson.core.JsonProcessingException; @@ -46,7 +47,7 @@ public abstract class BaseHistogram { */ public String toJson() { try { - return BaseEvaluation.getObjectMapper().writeValueAsString(this); + return JsonMappers.getMapper().writeValueAsString(this); } catch (JsonProcessingException e) { throw new RuntimeException(e); } @@ -57,7 +58,7 @@ public abstract class BaseHistogram { */ public String toYaml() { try { - return BaseEvaluation.getYamlMapper().writeValueAsString(this); + return JsonMappers.getYamlMapper().writeValueAsString(this); } catch (JsonProcessingException e) { throw new RuntimeException(e); } @@ -72,7 +73,7 @@ public abstract class BaseHistogram { */ public static T fromJson(String json, Class curveClass) { try { - return BaseEvaluation.getObjectMapper().readValue(json, curveClass); + return JsonMappers.getMapper().readValue(json, curveClass); } catch (IOException e) { throw new RuntimeException(e); } @@ -87,7 +88,7 @@ public abstract class BaseHistogram { */ public static T fromYaml(String yaml, Class curveClass) { try { - return BaseEvaluation.getYamlMapper().readValue(yaml, curveClass); + return JsonMappers.getYamlMapper().readValue(yaml, curveClass); } catch (IOException e) { throw new RuntimeException(e); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/serde/ROCSerializer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/serde/ROCSerializer.java index 236407527..331585ad4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/serde/ROCSerializer.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/serde/ROCSerializer.java @@ -36,7 +36,9 @@ public class ROCSerializer extends JsonSerializer { @Override public void serialize(ROC roc, JsonGenerator jsonGenerator, SerializerProvider serializerProvider) throws IOException { - if (roc.isExact()) { + boolean empty = roc.getExampleCount() == 0; + + if (roc.isExact() && !empty) { //For exact ROC implementation: force AUC and AUPRC calculation, so result can be stored in JSON, such //that we have them once deserialized. //Due to potentially huge size, exact mode doesn't store the original predictions in JSON @@ -47,9 +49,11 @@ public class ROCSerializer extends JsonSerializer { jsonGenerator.writeNumberField("countActualPositive", roc.getCountActualPositive()); jsonGenerator.writeNumberField("countActualNegative", roc.getCountActualNegative()); jsonGenerator.writeObjectField("counts", roc.getCounts()); - jsonGenerator.writeNumberField("auc", roc.calculateAUC()); - jsonGenerator.writeNumberField("auprc", roc.calculateAUCPR()); - if (roc.isExact()) { + if(!empty) { + jsonGenerator.writeNumberField("auc", roc.calculateAUC()); + jsonGenerator.writeNumberField("auprc", roc.calculateAUCPR()); + } + if (roc.isExact() && !empty) { //Store ROC and PR curves only for exact mode... they are redundant + can be calculated again for thresholded mode jsonGenerator.writeObjectField("rocCurve", roc.getRocCurve()); jsonGenerator.writeObjectField("prCurve", roc.getPrecisionRecallCurve()); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/json/JsonMappers.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/json/JsonMappers.java index 81bb46e75..4a1344ae3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/json/JsonMappers.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/json/JsonMappers.java @@ -17,10 +17,19 @@ package org.nd4j.serde.json; import lombok.extern.slf4j.Slf4j; +import org.nd4j.linalg.primitives.AtomicBoolean; +import org.nd4j.linalg.primitives.AtomicDouble; +import org.nd4j.linalg.primitives.serde.JsonDeserializerAtomicBoolean; +import org.nd4j.linalg.primitives.serde.JsonDeserializerAtomicDouble; +import org.nd4j.linalg.primitives.serde.JsonSerializerAtomicBoolean; +import org.nd4j.linalg.primitives.serde.JsonSerializerAtomicDouble; +import org.nd4j.shade.jackson.annotation.JsonAutoDetect; import org.nd4j.shade.jackson.databind.DeserializationFeature; import org.nd4j.shade.jackson.databind.MapperFeature; import org.nd4j.shade.jackson.databind.ObjectMapper; import org.nd4j.shade.jackson.databind.SerializationFeature; +import org.nd4j.shade.jackson.databind.module.SimpleModule; +import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory; /** * JSON mappers for serializing/deserializing objects @@ -30,19 +39,41 @@ import org.nd4j.shade.jackson.databind.SerializationFeature; @Slf4j public class JsonMappers { - private static ObjectMapper jsonMapper = new ObjectMapper(); + private static ObjectMapper jsonMapper = configureMapper(new ObjectMapper()); + private static ObjectMapper yamlMapper = configureMapper(new ObjectMapper(new YAMLFactory())); /** - * @return The default/primary ObjectMapper for deserializing JSON network configurations in DL4J + * @return The default/primary ObjectMapper for deserializing JSON objects */ public static ObjectMapper getMapper(){ return jsonMapper; } - private static void configureMapper(ObjectMapper ret) { + /** + * @return The default/primary ObjectMapper for deserializing JSON objects + */ + public static ObjectMapper getYamlMapper(){ + return jsonMapper; + } + + private static ObjectMapper configureMapper(ObjectMapper ret) { ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); ret.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false); - ret.configure(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY, true); + ret.configure(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY, false); ret.enable(SerializationFeature.INDENT_OUTPUT); + SimpleModule atomicModule = new SimpleModule(); + atomicModule.addSerializer(AtomicDouble.class, new JsonSerializerAtomicDouble()); + atomicModule.addSerializer(AtomicBoolean.class, new JsonSerializerAtomicBoolean()); + atomicModule.addDeserializer(AtomicDouble.class, new JsonDeserializerAtomicDouble()); + atomicModule.addDeserializer(AtomicBoolean.class, new JsonDeserializerAtomicBoolean()); + ret.registerModule(atomicModule); + //Serialize fields only, not using getters + ret.setVisibilityChecker(ret.getSerializationConfig().getDefaultVisibilityChecker() + .withFieldVisibility(JsonAutoDetect.Visibility.ANY) + .withGetterVisibility(JsonAutoDetect.Visibility.NONE) + .withSetterVisibility(JsonAutoDetect.Visibility.NONE) + .withCreatorVisibility(JsonAutoDetect.Visibility.ANY) + ); + return ret; } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java index e10ffcddb..db8d7d551 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java @@ -43,6 +43,9 @@ import org.nd4j.autodiff.samediff.api.OutAndGrad; import org.nd4j.autodiff.samediff.impl.DefaultSameDiffConditional; import org.nd4j.autodiff.validation.OpValidation; import org.nd4j.autodiff.validation.TestCase; +import org.nd4j.evaluation.IEvaluation; +import org.nd4j.evaluation.classification.*; +import org.nd4j.evaluation.regression.RegressionEvaluation; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.blas.params.MMulTranspose; @@ -3501,4 +3504,20 @@ public class SameDiffTests extends BaseNd4jTest { Map map = sd.calculateGradients(null,"input", "concat"); assertEquals(map.get("input"), map.get("concat")); } + + @Test + public void testTrainingConfigJson(){ + for(IEvaluation e : new IEvaluation[]{new Evaluation(), new RegressionEvaluation(), new EvaluationBinary(), new ROC(), + new ROCMultiClass(), new ROCBinary(), new EvaluationCalibration()}) { + TrainingConfig config = new TrainingConfig.Builder() + .l2(1e-4) + .updater(new Adam(0.1)) + .dataSetFeatureMapping("out").dataSetLabelMapping("label") + .trainEvaluation("out", 0, e) + .build(); + String json = config.toJson(); + TrainingConfig fromJson = TrainingConfig.fromJson(json); + assertEquals(config, fromJson); + } + } }