#8470 TrainingConfig json fix for Evaluation instances (#93)

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-11-30 20:08:30 +11:00 committed by GitHub
parent 35ab4a72ba
commit 2be47082c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 81 additions and 62 deletions

View File

@ -18,7 +18,6 @@ package org.nd4j.autodiff.samediff;
import lombok.*; import lombok.*;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.listeners.ListenerEvaluations;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.evaluation.IEvaluation; import org.nd4j.evaluation.IEvaluation;
import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.learning.config.IUpdater;
@ -64,6 +63,7 @@ public class TrainingConfig {
private int iterationCount; private int iterationCount;
private int epochCount; private int epochCount;
private Map<String, List<IEvaluation>> trainEvaluations = new HashMap<>(); private Map<String, List<IEvaluation>> trainEvaluations = new HashMap<>();
private Map<String, Integer> trainEvaluationLabels = new HashMap<>(); private Map<String, Integer> trainEvaluationLabels = new HashMap<>();

View File

@ -1,5 +1,6 @@
/* ****************************************************************************** /* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -17,7 +18,6 @@
package org.nd4j.evaluation; package org.nd4j.evaluation;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.NonNull; import lombok.NonNull;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.evaluation.classification.*; import org.nd4j.evaluation.classification.*;
@ -27,24 +27,13 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastTo; import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastTo;
import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.AtomicBoolean;
import org.nd4j.linalg.primitives.AtomicDouble;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.primitives.Triple; import org.nd4j.linalg.primitives.Triple;
import org.nd4j.linalg.primitives.serde.JsonDeserializerAtomicBoolean;
import org.nd4j.linalg.primitives.serde.JsonDeserializerAtomicDouble;
import org.nd4j.linalg.primitives.serde.JsonSerializerAtomicBoolean;
import org.nd4j.linalg.primitives.serde.JsonSerializerAtomicDouble;
import org.nd4j.linalg.util.ArrayUtil; import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.shade.jackson.annotation.JsonAutoDetect; import org.nd4j.serde.json.JsonMappers;
import org.nd4j.shade.jackson.core.JsonProcessingException; import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.DeserializationFeature;
import org.nd4j.shade.jackson.databind.MapperFeature;
import org.nd4j.shade.jackson.databind.ObjectMapper; import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.databind.SerializationFeature;
import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException; import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException;
import org.nd4j.shade.jackson.databind.module.SimpleModule;
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
import java.io.IOException; import java.io.IOException;
import java.io.Serializable; import java.io.Serializable;
@ -60,32 +49,6 @@ import java.util.List;
@EqualsAndHashCode @EqualsAndHashCode
public abstract class BaseEvaluation<T extends BaseEvaluation> implements IEvaluation<T> { public abstract class BaseEvaluation<T extends BaseEvaluation> implements IEvaluation<T> {
@Getter
private static ObjectMapper objectMapper = configureMapper(new ObjectMapper());
@Getter
private static ObjectMapper yamlMapper = configureMapper(new ObjectMapper(new YAMLFactory()));
private static ObjectMapper configureMapper(ObjectMapper ret) {
ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
ret.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
ret.configure(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY, false);
ret.enable(SerializationFeature.INDENT_OUTPUT);
SimpleModule atomicModule = new SimpleModule();
atomicModule.addSerializer(AtomicDouble.class, new JsonSerializerAtomicDouble());
atomicModule.addSerializer(AtomicBoolean.class, new JsonSerializerAtomicBoolean());
atomicModule.addDeserializer(AtomicDouble.class, new JsonDeserializerAtomicDouble());
atomicModule.addDeserializer(AtomicBoolean.class, new JsonDeserializerAtomicBoolean());
ret.registerModule(atomicModule);
//Serialize fields only, not using getters
ret.setVisibilityChecker(ret.getSerializationConfig().getDefaultVisibilityChecker()
.withFieldVisibility(JsonAutoDetect.Visibility.ANY)
.withGetterVisibility(JsonAutoDetect.Visibility.NONE)
.withSetterVisibility(JsonAutoDetect.Visibility.NONE)
.withCreatorVisibility(JsonAutoDetect.Visibility.ANY)
);
return ret;
}
/** /**
* @param yaml YAML representation * @param yaml YAML representation
* @param clazz Class * @param clazz Class
@ -94,7 +57,7 @@ public abstract class BaseEvaluation<T extends BaseEvaluation> implements IEvalu
*/ */
public static <T extends IEvaluation> T fromYaml(String yaml, Class<T> clazz) { public static <T extends IEvaluation> T fromYaml(String yaml, Class<T> clazz) {
try { try {
return yamlMapper.readValue(yaml, clazz); return JsonMappers.getYamlMapper().readValue(yaml, clazz);
} catch (IOException e) { } catch (IOException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
@ -108,7 +71,7 @@ public abstract class BaseEvaluation<T extends BaseEvaluation> implements IEvalu
*/ */
public static <T extends IEvaluation> T fromJson(String json, Class<T> clazz) { public static <T extends IEvaluation> T fromJson(String json, Class<T> clazz) {
try { try {
return objectMapper.readValue(json, clazz); return JsonMappers.getMapper().readValue(json, clazz);
} catch (InvalidTypeIdException e) { } catch (InvalidTypeIdException e) {
if (e.getMessage().contains("Could not resolve type id")) { if (e.getMessage().contains("Could not resolve type id")) {
try { try {
@ -332,7 +295,7 @@ public abstract class BaseEvaluation<T extends BaseEvaluation> implements IEvalu
@Override @Override
public String toJson() { public String toJson() {
try { try {
return objectMapper.writeValueAsString(this); return JsonMappers.getMapper().writeValueAsString(this);
} catch (JsonProcessingException e) { } catch (JsonProcessingException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
@ -349,7 +312,7 @@ public abstract class BaseEvaluation<T extends BaseEvaluation> implements IEvalu
@Override @Override
public String toYaml() { public String toYaml() {
try { try {
return yamlMapper.writeValueAsString(this); return JsonMappers.getYamlMapper().writeValueAsString(this);
} catch (JsonProcessingException e) { } catch (JsonProcessingException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }

View File

@ -17,6 +17,7 @@
package org.nd4j.evaluation.curves; package org.nd4j.evaluation.curves;
import org.nd4j.evaluation.BaseEvaluation; import org.nd4j.evaluation.BaseEvaluation;
import org.nd4j.serde.json.JsonMappers;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo; import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import org.nd4j.shade.jackson.core.JsonProcessingException; import org.nd4j.shade.jackson.core.JsonProcessingException;
@ -87,7 +88,7 @@ public abstract class BaseCurve {
*/ */
public String toJson() { public String toJson() {
try { try {
return BaseEvaluation.getObjectMapper().writeValueAsString(this); return JsonMappers.getMapper().writeValueAsString(this);
} catch (JsonProcessingException e) { } catch (JsonProcessingException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
@ -98,7 +99,7 @@ public abstract class BaseCurve {
*/ */
public String toYaml() { public String toYaml() {
try { try {
return BaseEvaluation.getYamlMapper().writeValueAsString(this); return JsonMappers.getYamlMapper().writeValueAsString(this);
} catch (JsonProcessingException e) { } catch (JsonProcessingException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
@ -113,7 +114,7 @@ public abstract class BaseCurve {
*/ */
public static <T extends BaseCurve> T fromJson(String json, Class<T> curveClass) { public static <T extends BaseCurve> T fromJson(String json, Class<T> curveClass) {
try { try {
return BaseEvaluation.getObjectMapper().readValue(json, curveClass); return JsonMappers.getMapper().readValue(json, curveClass);
} catch (IOException e) { } catch (IOException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
@ -128,7 +129,7 @@ public abstract class BaseCurve {
*/ */
public static <T extends BaseCurve> T fromYaml(String yaml, Class<T> curveClass) { public static <T extends BaseCurve> T fromYaml(String yaml, Class<T> curveClass) {
try { try {
return BaseEvaluation.getYamlMapper().readValue(yaml, curveClass); return JsonMappers.getYamlMapper().readValue(yaml, curveClass);
} catch (IOException e) { } catch (IOException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }

View File

@ -17,6 +17,7 @@
package org.nd4j.evaluation.curves; package org.nd4j.evaluation.curves;
import org.nd4j.evaluation.BaseEvaluation; import org.nd4j.evaluation.BaseEvaluation;
import org.nd4j.serde.json.JsonMappers;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo; import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import org.nd4j.shade.jackson.core.JsonProcessingException; import org.nd4j.shade.jackson.core.JsonProcessingException;
@ -46,7 +47,7 @@ public abstract class BaseHistogram {
*/ */
public String toJson() { public String toJson() {
try { try {
return BaseEvaluation.getObjectMapper().writeValueAsString(this); return JsonMappers.getMapper().writeValueAsString(this);
} catch (JsonProcessingException e) { } catch (JsonProcessingException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
@ -57,7 +58,7 @@ public abstract class BaseHistogram {
*/ */
public String toYaml() { public String toYaml() {
try { try {
return BaseEvaluation.getYamlMapper().writeValueAsString(this); return JsonMappers.getYamlMapper().writeValueAsString(this);
} catch (JsonProcessingException e) { } catch (JsonProcessingException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
@ -72,7 +73,7 @@ public abstract class BaseHistogram {
*/ */
public static <T extends BaseHistogram> T fromJson(String json, Class<T> curveClass) { public static <T extends BaseHistogram> T fromJson(String json, Class<T> curveClass) {
try { try {
return BaseEvaluation.getObjectMapper().readValue(json, curveClass); return JsonMappers.getMapper().readValue(json, curveClass);
} catch (IOException e) { } catch (IOException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
@ -87,7 +88,7 @@ public abstract class BaseHistogram {
*/ */
public static <T extends BaseHistogram> T fromYaml(String yaml, Class<T> curveClass) { public static <T extends BaseHistogram> T fromYaml(String yaml, Class<T> curveClass) {
try { try {
return BaseEvaluation.getYamlMapper().readValue(yaml, curveClass); return JsonMappers.getYamlMapper().readValue(yaml, curveClass);
} catch (IOException e) { } catch (IOException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }

View File

@ -36,7 +36,9 @@ public class ROCSerializer extends JsonSerializer<ROC> {
@Override @Override
public void serialize(ROC roc, JsonGenerator jsonGenerator, SerializerProvider serializerProvider) public void serialize(ROC roc, JsonGenerator jsonGenerator, SerializerProvider serializerProvider)
throws IOException { throws IOException {
if (roc.isExact()) { boolean empty = roc.getExampleCount() == 0;
if (roc.isExact() && !empty) {
//For exact ROC implementation: force AUC and AUPRC calculation, so result can be stored in JSON, such //For exact ROC implementation: force AUC and AUPRC calculation, so result can be stored in JSON, such
//that we have them once deserialized. //that we have them once deserialized.
//Due to potentially huge size, exact mode doesn't store the original predictions in JSON //Due to potentially huge size, exact mode doesn't store the original predictions in JSON
@ -47,9 +49,11 @@ public class ROCSerializer extends JsonSerializer<ROC> {
jsonGenerator.writeNumberField("countActualPositive", roc.getCountActualPositive()); jsonGenerator.writeNumberField("countActualPositive", roc.getCountActualPositive());
jsonGenerator.writeNumberField("countActualNegative", roc.getCountActualNegative()); jsonGenerator.writeNumberField("countActualNegative", roc.getCountActualNegative());
jsonGenerator.writeObjectField("counts", roc.getCounts()); jsonGenerator.writeObjectField("counts", roc.getCounts());
if(!empty) {
jsonGenerator.writeNumberField("auc", roc.calculateAUC()); jsonGenerator.writeNumberField("auc", roc.calculateAUC());
jsonGenerator.writeNumberField("auprc", roc.calculateAUCPR()); jsonGenerator.writeNumberField("auprc", roc.calculateAUCPR());
if (roc.isExact()) { }
if (roc.isExact() && !empty) {
//Store ROC and PR curves only for exact mode... they are redundant + can be calculated again for thresholded mode //Store ROC and PR curves only for exact mode... they are redundant + can be calculated again for thresholded mode
jsonGenerator.writeObjectField("rocCurve", roc.getRocCurve()); jsonGenerator.writeObjectField("rocCurve", roc.getRocCurve());
jsonGenerator.writeObjectField("prCurve", roc.getPrecisionRecallCurve()); jsonGenerator.writeObjectField("prCurve", roc.getPrecisionRecallCurve());

View File

@ -17,10 +17,19 @@
package org.nd4j.serde.json; package org.nd4j.serde.json;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.nd4j.linalg.primitives.AtomicBoolean;
import org.nd4j.linalg.primitives.AtomicDouble;
import org.nd4j.linalg.primitives.serde.JsonDeserializerAtomicBoolean;
import org.nd4j.linalg.primitives.serde.JsonDeserializerAtomicDouble;
import org.nd4j.linalg.primitives.serde.JsonSerializerAtomicBoolean;
import org.nd4j.linalg.primitives.serde.JsonSerializerAtomicDouble;
import org.nd4j.shade.jackson.annotation.JsonAutoDetect;
import org.nd4j.shade.jackson.databind.DeserializationFeature; import org.nd4j.shade.jackson.databind.DeserializationFeature;
import org.nd4j.shade.jackson.databind.MapperFeature; import org.nd4j.shade.jackson.databind.MapperFeature;
import org.nd4j.shade.jackson.databind.ObjectMapper; import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.databind.SerializationFeature; import org.nd4j.shade.jackson.databind.SerializationFeature;
import org.nd4j.shade.jackson.databind.module.SimpleModule;
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
/** /**
* JSON mappers for serializing/deserializing objects * JSON mappers for serializing/deserializing objects
@ -30,19 +39,41 @@ import org.nd4j.shade.jackson.databind.SerializationFeature;
@Slf4j @Slf4j
public class JsonMappers { public class JsonMappers {
private static ObjectMapper jsonMapper = new ObjectMapper(); private static ObjectMapper jsonMapper = configureMapper(new ObjectMapper());
private static ObjectMapper yamlMapper = configureMapper(new ObjectMapper(new YAMLFactory()));
/** /**
* @return The default/primary ObjectMapper for deserializing JSON network configurations in DL4J * @return The default/primary ObjectMapper for deserializing JSON objects
*/ */
public static ObjectMapper getMapper(){ public static ObjectMapper getMapper(){
return jsonMapper; return jsonMapper;
} }
private static void configureMapper(ObjectMapper ret) { /**
* @return The default/primary ObjectMapper for deserializing JSON objects
*/
public static ObjectMapper getYamlMapper(){
return jsonMapper;
}
private static ObjectMapper configureMapper(ObjectMapper ret) {
ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
ret.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false); ret.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
ret.configure(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY, true); ret.configure(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY, false);
ret.enable(SerializationFeature.INDENT_OUTPUT); ret.enable(SerializationFeature.INDENT_OUTPUT);
SimpleModule atomicModule = new SimpleModule();
atomicModule.addSerializer(AtomicDouble.class, new JsonSerializerAtomicDouble());
atomicModule.addSerializer(AtomicBoolean.class, new JsonSerializerAtomicBoolean());
atomicModule.addDeserializer(AtomicDouble.class, new JsonDeserializerAtomicDouble());
atomicModule.addDeserializer(AtomicBoolean.class, new JsonDeserializerAtomicBoolean());
ret.registerModule(atomicModule);
//Serialize fields only, not using getters
ret.setVisibilityChecker(ret.getSerializationConfig().getDefaultVisibilityChecker()
.withFieldVisibility(JsonAutoDetect.Visibility.ANY)
.withGetterVisibility(JsonAutoDetect.Visibility.NONE)
.withSetterVisibility(JsonAutoDetect.Visibility.NONE)
.withCreatorVisibility(JsonAutoDetect.Visibility.ANY)
);
return ret;
} }
} }

View File

@ -43,6 +43,9 @@ import org.nd4j.autodiff.samediff.api.OutAndGrad;
import org.nd4j.autodiff.samediff.impl.DefaultSameDiffConditional; import org.nd4j.autodiff.samediff.impl.DefaultSameDiffConditional;
import org.nd4j.autodiff.validation.OpValidation; import org.nd4j.autodiff.validation.OpValidation;
import org.nd4j.autodiff.validation.TestCase; import org.nd4j.autodiff.validation.TestCase;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.*;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.blas.params.MMulTranspose; import org.nd4j.linalg.api.blas.params.MMulTranspose;
@ -3501,4 +3504,20 @@ public class SameDiffTests extends BaseNd4jTest {
Map<String, INDArray> map = sd.calculateGradients(null,"input", "concat"); Map<String, INDArray> map = sd.calculateGradients(null,"input", "concat");
assertEquals(map.get("input"), map.get("concat")); assertEquals(map.get("input"), map.get("concat"));
} }
@Test
public void testTrainingConfigJson(){
for(IEvaluation e : new IEvaluation[]{new Evaluation(), new RegressionEvaluation(), new EvaluationBinary(), new ROC(),
new ROCMultiClass(), new ROCBinary(), new EvaluationCalibration()}) {
TrainingConfig config = new TrainingConfig.Builder()
.l2(1e-4)
.updater(new Adam(0.1))
.dataSetFeatureMapping("out").dataSetLabelMapping("label")
.trainEvaluation("out", 0, e)
.build();
String json = config.toJson();
TrainingConfig fromJson = TrainingConfig.fromJson(json);
assertEquals(config, fromJson);
}
}
} }