Signed-off-by: AlexDBlack <blacka101@gmail.com>master
parent
35ab4a72ba
commit
2be47082c9
|
@ -18,7 +18,6 @@ package org.nd4j.autodiff.samediff;
|
||||||
|
|
||||||
import lombok.*;
|
import lombok.*;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.nd4j.autodiff.listeners.ListenerEvaluations;
|
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.evaluation.IEvaluation;
|
import org.nd4j.evaluation.IEvaluation;
|
||||||
import org.nd4j.linalg.learning.config.IUpdater;
|
import org.nd4j.linalg.learning.config.IUpdater;
|
||||||
|
@ -64,6 +63,7 @@ public class TrainingConfig {
|
||||||
private int iterationCount;
|
private int iterationCount;
|
||||||
private int epochCount;
|
private int epochCount;
|
||||||
|
|
||||||
|
|
||||||
private Map<String, List<IEvaluation>> trainEvaluations = new HashMap<>();
|
private Map<String, List<IEvaluation>> trainEvaluations = new HashMap<>();
|
||||||
private Map<String, Integer> trainEvaluationLabels = new HashMap<>();
|
private Map<String, Integer> trainEvaluationLabels = new HashMap<>();
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/* ******************************************************************************
|
/* ******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -17,7 +18,6 @@
|
||||||
package org.nd4j.evaluation;
|
package org.nd4j.evaluation;
|
||||||
|
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.Getter;
|
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.evaluation.classification.*;
|
import org.nd4j.evaluation.classification.*;
|
||||||
|
@ -27,24 +27,13 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastTo;
|
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastTo;
|
||||||
import org.nd4j.linalg.api.shape.Shape;
|
import org.nd4j.linalg.api.shape.Shape;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.primitives.AtomicBoolean;
|
|
||||||
import org.nd4j.linalg.primitives.AtomicDouble;
|
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
import org.nd4j.linalg.primitives.Triple;
|
import org.nd4j.linalg.primitives.Triple;
|
||||||
import org.nd4j.linalg.primitives.serde.JsonDeserializerAtomicBoolean;
|
|
||||||
import org.nd4j.linalg.primitives.serde.JsonDeserializerAtomicDouble;
|
|
||||||
import org.nd4j.linalg.primitives.serde.JsonSerializerAtomicBoolean;
|
|
||||||
import org.nd4j.linalg.primitives.serde.JsonSerializerAtomicDouble;
|
|
||||||
import org.nd4j.linalg.util.ArrayUtil;
|
import org.nd4j.linalg.util.ArrayUtil;
|
||||||
import org.nd4j.shade.jackson.annotation.JsonAutoDetect;
|
import org.nd4j.serde.json.JsonMappers;
|
||||||
import org.nd4j.shade.jackson.core.JsonProcessingException;
|
import org.nd4j.shade.jackson.core.JsonProcessingException;
|
||||||
import org.nd4j.shade.jackson.databind.DeserializationFeature;
|
|
||||||
import org.nd4j.shade.jackson.databind.MapperFeature;
|
|
||||||
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
||||||
import org.nd4j.shade.jackson.databind.SerializationFeature;
|
|
||||||
import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException;
|
import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException;
|
||||||
import org.nd4j.shade.jackson.databind.module.SimpleModule;
|
|
||||||
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
|
@ -60,32 +49,6 @@ import java.util.List;
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode
|
||||||
public abstract class BaseEvaluation<T extends BaseEvaluation> implements IEvaluation<T> {
|
public abstract class BaseEvaluation<T extends BaseEvaluation> implements IEvaluation<T> {
|
||||||
|
|
||||||
@Getter
|
|
||||||
private static ObjectMapper objectMapper = configureMapper(new ObjectMapper());
|
|
||||||
@Getter
|
|
||||||
private static ObjectMapper yamlMapper = configureMapper(new ObjectMapper(new YAMLFactory()));
|
|
||||||
|
|
||||||
private static ObjectMapper configureMapper(ObjectMapper ret) {
|
|
||||||
ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
|
|
||||||
ret.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
|
|
||||||
ret.configure(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY, false);
|
|
||||||
ret.enable(SerializationFeature.INDENT_OUTPUT);
|
|
||||||
SimpleModule atomicModule = new SimpleModule();
|
|
||||||
atomicModule.addSerializer(AtomicDouble.class, new JsonSerializerAtomicDouble());
|
|
||||||
atomicModule.addSerializer(AtomicBoolean.class, new JsonSerializerAtomicBoolean());
|
|
||||||
atomicModule.addDeserializer(AtomicDouble.class, new JsonDeserializerAtomicDouble());
|
|
||||||
atomicModule.addDeserializer(AtomicBoolean.class, new JsonDeserializerAtomicBoolean());
|
|
||||||
ret.registerModule(atomicModule);
|
|
||||||
//Serialize fields only, not using getters
|
|
||||||
ret.setVisibilityChecker(ret.getSerializationConfig().getDefaultVisibilityChecker()
|
|
||||||
.withFieldVisibility(JsonAutoDetect.Visibility.ANY)
|
|
||||||
.withGetterVisibility(JsonAutoDetect.Visibility.NONE)
|
|
||||||
.withSetterVisibility(JsonAutoDetect.Visibility.NONE)
|
|
||||||
.withCreatorVisibility(JsonAutoDetect.Visibility.ANY)
|
|
||||||
);
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param yaml YAML representation
|
* @param yaml YAML representation
|
||||||
* @param clazz Class
|
* @param clazz Class
|
||||||
|
@ -94,7 +57,7 @@ public abstract class BaseEvaluation<T extends BaseEvaluation> implements IEvalu
|
||||||
*/
|
*/
|
||||||
public static <T extends IEvaluation> T fromYaml(String yaml, Class<T> clazz) {
|
public static <T extends IEvaluation> T fromYaml(String yaml, Class<T> clazz) {
|
||||||
try {
|
try {
|
||||||
return yamlMapper.readValue(yaml, clazz);
|
return JsonMappers.getYamlMapper().readValue(yaml, clazz);
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
|
@ -108,7 +71,7 @@ public abstract class BaseEvaluation<T extends BaseEvaluation> implements IEvalu
|
||||||
*/
|
*/
|
||||||
public static <T extends IEvaluation> T fromJson(String json, Class<T> clazz) {
|
public static <T extends IEvaluation> T fromJson(String json, Class<T> clazz) {
|
||||||
try {
|
try {
|
||||||
return objectMapper.readValue(json, clazz);
|
return JsonMappers.getMapper().readValue(json, clazz);
|
||||||
} catch (InvalidTypeIdException e) {
|
} catch (InvalidTypeIdException e) {
|
||||||
if (e.getMessage().contains("Could not resolve type id")) {
|
if (e.getMessage().contains("Could not resolve type id")) {
|
||||||
try {
|
try {
|
||||||
|
@ -332,7 +295,7 @@ public abstract class BaseEvaluation<T extends BaseEvaluation> implements IEvalu
|
||||||
@Override
|
@Override
|
||||||
public String toJson() {
|
public String toJson() {
|
||||||
try {
|
try {
|
||||||
return objectMapper.writeValueAsString(this);
|
return JsonMappers.getMapper().writeValueAsString(this);
|
||||||
} catch (JsonProcessingException e) {
|
} catch (JsonProcessingException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
|
@ -349,7 +312,7 @@ public abstract class BaseEvaluation<T extends BaseEvaluation> implements IEvalu
|
||||||
@Override
|
@Override
|
||||||
public String toYaml() {
|
public String toYaml() {
|
||||||
try {
|
try {
|
||||||
return yamlMapper.writeValueAsString(this);
|
return JsonMappers.getYamlMapper().writeValueAsString(this);
|
||||||
} catch (JsonProcessingException e) {
|
} catch (JsonProcessingException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.nd4j.evaluation.curves;
|
package org.nd4j.evaluation.curves;
|
||||||
|
|
||||||
import org.nd4j.evaluation.BaseEvaluation;
|
import org.nd4j.evaluation.BaseEvaluation;
|
||||||
|
import org.nd4j.serde.json.JsonMappers;
|
||||||
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
|
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
|
||||||
import org.nd4j.shade.jackson.core.JsonProcessingException;
|
import org.nd4j.shade.jackson.core.JsonProcessingException;
|
||||||
|
|
||||||
|
@ -87,7 +88,7 @@ public abstract class BaseCurve {
|
||||||
*/
|
*/
|
||||||
public String toJson() {
|
public String toJson() {
|
||||||
try {
|
try {
|
||||||
return BaseEvaluation.getObjectMapper().writeValueAsString(this);
|
return JsonMappers.getMapper().writeValueAsString(this);
|
||||||
} catch (JsonProcessingException e) {
|
} catch (JsonProcessingException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
|
@ -98,7 +99,7 @@ public abstract class BaseCurve {
|
||||||
*/
|
*/
|
||||||
public String toYaml() {
|
public String toYaml() {
|
||||||
try {
|
try {
|
||||||
return BaseEvaluation.getYamlMapper().writeValueAsString(this);
|
return JsonMappers.getYamlMapper().writeValueAsString(this);
|
||||||
} catch (JsonProcessingException e) {
|
} catch (JsonProcessingException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
|
@ -113,7 +114,7 @@ public abstract class BaseCurve {
|
||||||
*/
|
*/
|
||||||
public static <T extends BaseCurve> T fromJson(String json, Class<T> curveClass) {
|
public static <T extends BaseCurve> T fromJson(String json, Class<T> curveClass) {
|
||||||
try {
|
try {
|
||||||
return BaseEvaluation.getObjectMapper().readValue(json, curveClass);
|
return JsonMappers.getMapper().readValue(json, curveClass);
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
|
@ -128,7 +129,7 @@ public abstract class BaseCurve {
|
||||||
*/
|
*/
|
||||||
public static <T extends BaseCurve> T fromYaml(String yaml, Class<T> curveClass) {
|
public static <T extends BaseCurve> T fromYaml(String yaml, Class<T> curveClass) {
|
||||||
try {
|
try {
|
||||||
return BaseEvaluation.getYamlMapper().readValue(yaml, curveClass);
|
return JsonMappers.getYamlMapper().readValue(yaml, curveClass);
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.nd4j.evaluation.curves;
|
package org.nd4j.evaluation.curves;
|
||||||
|
|
||||||
import org.nd4j.evaluation.BaseEvaluation;
|
import org.nd4j.evaluation.BaseEvaluation;
|
||||||
|
import org.nd4j.serde.json.JsonMappers;
|
||||||
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
|
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
|
||||||
import org.nd4j.shade.jackson.core.JsonProcessingException;
|
import org.nd4j.shade.jackson.core.JsonProcessingException;
|
||||||
|
|
||||||
|
@ -46,7 +47,7 @@ public abstract class BaseHistogram {
|
||||||
*/
|
*/
|
||||||
public String toJson() {
|
public String toJson() {
|
||||||
try {
|
try {
|
||||||
return BaseEvaluation.getObjectMapper().writeValueAsString(this);
|
return JsonMappers.getMapper().writeValueAsString(this);
|
||||||
} catch (JsonProcessingException e) {
|
} catch (JsonProcessingException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
|
@ -57,7 +58,7 @@ public abstract class BaseHistogram {
|
||||||
*/
|
*/
|
||||||
public String toYaml() {
|
public String toYaml() {
|
||||||
try {
|
try {
|
||||||
return BaseEvaluation.getYamlMapper().writeValueAsString(this);
|
return JsonMappers.getYamlMapper().writeValueAsString(this);
|
||||||
} catch (JsonProcessingException e) {
|
} catch (JsonProcessingException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
|
@ -72,7 +73,7 @@ public abstract class BaseHistogram {
|
||||||
*/
|
*/
|
||||||
public static <T extends BaseHistogram> T fromJson(String json, Class<T> curveClass) {
|
public static <T extends BaseHistogram> T fromJson(String json, Class<T> curveClass) {
|
||||||
try {
|
try {
|
||||||
return BaseEvaluation.getObjectMapper().readValue(json, curveClass);
|
return JsonMappers.getMapper().readValue(json, curveClass);
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
|
@ -87,7 +88,7 @@ public abstract class BaseHistogram {
|
||||||
*/
|
*/
|
||||||
public static <T extends BaseHistogram> T fromYaml(String yaml, Class<T> curveClass) {
|
public static <T extends BaseHistogram> T fromYaml(String yaml, Class<T> curveClass) {
|
||||||
try {
|
try {
|
||||||
return BaseEvaluation.getYamlMapper().readValue(yaml, curveClass);
|
return JsonMappers.getYamlMapper().readValue(yaml, curveClass);
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
|
|
|
@ -36,7 +36,9 @@ public class ROCSerializer extends JsonSerializer<ROC> {
|
||||||
@Override
|
@Override
|
||||||
public void serialize(ROC roc, JsonGenerator jsonGenerator, SerializerProvider serializerProvider)
|
public void serialize(ROC roc, JsonGenerator jsonGenerator, SerializerProvider serializerProvider)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
if (roc.isExact()) {
|
boolean empty = roc.getExampleCount() == 0;
|
||||||
|
|
||||||
|
if (roc.isExact() && !empty) {
|
||||||
//For exact ROC implementation: force AUC and AUPRC calculation, so result can be stored in JSON, such
|
//For exact ROC implementation: force AUC and AUPRC calculation, so result can be stored in JSON, such
|
||||||
//that we have them once deserialized.
|
//that we have them once deserialized.
|
||||||
//Due to potentially huge size, exact mode doesn't store the original predictions in JSON
|
//Due to potentially huge size, exact mode doesn't store the original predictions in JSON
|
||||||
|
@ -47,9 +49,11 @@ public class ROCSerializer extends JsonSerializer<ROC> {
|
||||||
jsonGenerator.writeNumberField("countActualPositive", roc.getCountActualPositive());
|
jsonGenerator.writeNumberField("countActualPositive", roc.getCountActualPositive());
|
||||||
jsonGenerator.writeNumberField("countActualNegative", roc.getCountActualNegative());
|
jsonGenerator.writeNumberField("countActualNegative", roc.getCountActualNegative());
|
||||||
jsonGenerator.writeObjectField("counts", roc.getCounts());
|
jsonGenerator.writeObjectField("counts", roc.getCounts());
|
||||||
|
if(!empty) {
|
||||||
jsonGenerator.writeNumberField("auc", roc.calculateAUC());
|
jsonGenerator.writeNumberField("auc", roc.calculateAUC());
|
||||||
jsonGenerator.writeNumberField("auprc", roc.calculateAUCPR());
|
jsonGenerator.writeNumberField("auprc", roc.calculateAUCPR());
|
||||||
if (roc.isExact()) {
|
}
|
||||||
|
if (roc.isExact() && !empty) {
|
||||||
//Store ROC and PR curves only for exact mode... they are redundant + can be calculated again for thresholded mode
|
//Store ROC and PR curves only for exact mode... they are redundant + can be calculated again for thresholded mode
|
||||||
jsonGenerator.writeObjectField("rocCurve", roc.getRocCurve());
|
jsonGenerator.writeObjectField("rocCurve", roc.getRocCurve());
|
||||||
jsonGenerator.writeObjectField("prCurve", roc.getPrecisionRecallCurve());
|
jsonGenerator.writeObjectField("prCurve", roc.getPrecisionRecallCurve());
|
||||||
|
|
|
@ -17,10 +17,19 @@
|
||||||
package org.nd4j.serde.json;
|
package org.nd4j.serde.json;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.nd4j.linalg.primitives.AtomicBoolean;
|
||||||
|
import org.nd4j.linalg.primitives.AtomicDouble;
|
||||||
|
import org.nd4j.linalg.primitives.serde.JsonDeserializerAtomicBoolean;
|
||||||
|
import org.nd4j.linalg.primitives.serde.JsonDeserializerAtomicDouble;
|
||||||
|
import org.nd4j.linalg.primitives.serde.JsonSerializerAtomicBoolean;
|
||||||
|
import org.nd4j.linalg.primitives.serde.JsonSerializerAtomicDouble;
|
||||||
|
import org.nd4j.shade.jackson.annotation.JsonAutoDetect;
|
||||||
import org.nd4j.shade.jackson.databind.DeserializationFeature;
|
import org.nd4j.shade.jackson.databind.DeserializationFeature;
|
||||||
import org.nd4j.shade.jackson.databind.MapperFeature;
|
import org.nd4j.shade.jackson.databind.MapperFeature;
|
||||||
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
||||||
import org.nd4j.shade.jackson.databind.SerializationFeature;
|
import org.nd4j.shade.jackson.databind.SerializationFeature;
|
||||||
|
import org.nd4j.shade.jackson.databind.module.SimpleModule;
|
||||||
|
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* JSON mappers for serializing/deserializing objects
|
* JSON mappers for serializing/deserializing objects
|
||||||
|
@ -30,19 +39,41 @@ import org.nd4j.shade.jackson.databind.SerializationFeature;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class JsonMappers {
|
public class JsonMappers {
|
||||||
|
|
||||||
private static ObjectMapper jsonMapper = new ObjectMapper();
|
private static ObjectMapper jsonMapper = configureMapper(new ObjectMapper());
|
||||||
|
private static ObjectMapper yamlMapper = configureMapper(new ObjectMapper(new YAMLFactory()));
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @return The default/primary ObjectMapper for deserializing JSON network configurations in DL4J
|
* @return The default/primary ObjectMapper for deserializing JSON objects
|
||||||
*/
|
*/
|
||||||
public static ObjectMapper getMapper(){
|
public static ObjectMapper getMapper(){
|
||||||
return jsonMapper;
|
return jsonMapper;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static void configureMapper(ObjectMapper ret) {
|
/**
|
||||||
|
* @return The default/primary ObjectMapper for deserializing JSON objects
|
||||||
|
*/
|
||||||
|
public static ObjectMapper getYamlMapper(){
|
||||||
|
return jsonMapper;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static ObjectMapper configureMapper(ObjectMapper ret) {
|
||||||
ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
|
ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
|
||||||
ret.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
|
ret.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
|
||||||
ret.configure(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY, true);
|
ret.configure(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY, false);
|
||||||
ret.enable(SerializationFeature.INDENT_OUTPUT);
|
ret.enable(SerializationFeature.INDENT_OUTPUT);
|
||||||
|
SimpleModule atomicModule = new SimpleModule();
|
||||||
|
atomicModule.addSerializer(AtomicDouble.class, new JsonSerializerAtomicDouble());
|
||||||
|
atomicModule.addSerializer(AtomicBoolean.class, new JsonSerializerAtomicBoolean());
|
||||||
|
atomicModule.addDeserializer(AtomicDouble.class, new JsonDeserializerAtomicDouble());
|
||||||
|
atomicModule.addDeserializer(AtomicBoolean.class, new JsonDeserializerAtomicBoolean());
|
||||||
|
ret.registerModule(atomicModule);
|
||||||
|
//Serialize fields only, not using getters
|
||||||
|
ret.setVisibilityChecker(ret.getSerializationConfig().getDefaultVisibilityChecker()
|
||||||
|
.withFieldVisibility(JsonAutoDetect.Visibility.ANY)
|
||||||
|
.withGetterVisibility(JsonAutoDetect.Visibility.NONE)
|
||||||
|
.withSetterVisibility(JsonAutoDetect.Visibility.NONE)
|
||||||
|
.withCreatorVisibility(JsonAutoDetect.Visibility.ANY)
|
||||||
|
);
|
||||||
|
return ret;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -43,6 +43,9 @@ import org.nd4j.autodiff.samediff.api.OutAndGrad;
|
||||||
import org.nd4j.autodiff.samediff.impl.DefaultSameDiffConditional;
|
import org.nd4j.autodiff.samediff.impl.DefaultSameDiffConditional;
|
||||||
import org.nd4j.autodiff.validation.OpValidation;
|
import org.nd4j.autodiff.validation.OpValidation;
|
||||||
import org.nd4j.autodiff.validation.TestCase;
|
import org.nd4j.autodiff.validation.TestCase;
|
||||||
|
import org.nd4j.evaluation.IEvaluation;
|
||||||
|
import org.nd4j.evaluation.classification.*;
|
||||||
|
import org.nd4j.evaluation.regression.RegressionEvaluation;
|
||||||
import org.nd4j.linalg.BaseNd4jTest;
|
import org.nd4j.linalg.BaseNd4jTest;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.api.blas.params.MMulTranspose;
|
import org.nd4j.linalg.api.blas.params.MMulTranspose;
|
||||||
|
@ -3501,4 +3504,20 @@ public class SameDiffTests extends BaseNd4jTest {
|
||||||
Map<String, INDArray> map = sd.calculateGradients(null,"input", "concat");
|
Map<String, INDArray> map = sd.calculateGradients(null,"input", "concat");
|
||||||
assertEquals(map.get("input"), map.get("concat"));
|
assertEquals(map.get("input"), map.get("concat"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testTrainingConfigJson(){
|
||||||
|
for(IEvaluation e : new IEvaluation[]{new Evaluation(), new RegressionEvaluation(), new EvaluationBinary(), new ROC(),
|
||||||
|
new ROCMultiClass(), new ROCBinary(), new EvaluationCalibration()}) {
|
||||||
|
TrainingConfig config = new TrainingConfig.Builder()
|
||||||
|
.l2(1e-4)
|
||||||
|
.updater(new Adam(0.1))
|
||||||
|
.dataSetFeatureMapping("out").dataSetLabelMapping("label")
|
||||||
|
.trainEvaluation("out", 0, e)
|
||||||
|
.build();
|
||||||
|
String json = config.toJson();
|
||||||
|
TrainingConfig fromJson = TrainingConfig.fromJson(json);
|
||||||
|
assertEquals(config, fromJson);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue