parent
8d73a7a410
commit
0bed17c97f
|
@ -45,6 +45,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.layers.*;
|
import org.deeplearning4j.nn.conf.layers.*;
|
||||||
import org.deeplearning4j.nn.conf.layers.variational.BernoulliReconstructionDistribution;
|
import org.deeplearning4j.nn.conf.layers.variational.BernoulliReconstructionDistribution;
|
||||||
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
|
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
|
||||||
|
import org.deeplearning4j.nn.conf.serde.CavisMapper;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.deeplearning4j.nn.weights.WeightInit;
|
import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
import org.deeplearning4j.optimize.api.BaseTrainingListener;
|
import org.deeplearning4j.optimize.api.BaseTrainingListener;
|
||||||
|
@ -924,8 +925,8 @@ public class TestEarlyStopping extends BaseDL4JTest {
|
||||||
};
|
};
|
||||||
|
|
||||||
for(EpochTerminationCondition e : etc ){
|
for(EpochTerminationCondition e : etc ){
|
||||||
String s = NeuralNetConfiguration.mapper().writeValueAsString(e);
|
String s = CavisMapper.getMapper(CavisMapper.Type.JSON).writeValueAsString(e);
|
||||||
EpochTerminationCondition c = NeuralNetConfiguration.mapper().readValue(s, EpochTerminationCondition.class);
|
EpochTerminationCondition c = CavisMapper.getMapper(CavisMapper.Type.JSON).readValue(s, EpochTerminationCondition.class);
|
||||||
assertEquals(e, c);
|
assertEquals(e, c);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -936,8 +937,8 @@ public class TestEarlyStopping extends BaseDL4JTest {
|
||||||
};
|
};
|
||||||
|
|
||||||
for(IterationTerminationCondition i : itc ){
|
for(IterationTerminationCondition i : itc ){
|
||||||
String s = NeuralNetConfiguration.mapper().writeValueAsString(i);
|
String s = CavisMapper.getMapper(CavisMapper.Type.JSON).writeValueAsString(i);
|
||||||
IterationTerminationCondition c = NeuralNetConfiguration.mapper().readValue(s, IterationTerminationCondition.class);
|
IterationTerminationCondition c = CavisMapper.getMapper(CavisMapper.Type.JSON).readValue(s, IterationTerminationCondition.class);
|
||||||
assertEquals(i, c);
|
assertEquals(i, c);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -112,7 +112,7 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest {
|
||||||
|
|
||||||
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
||||||
.dataType(DataType.DOUBLE)
|
.dataType(DataType.DOUBLE)
|
||||||
.updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL)
|
.updater(new NoOp())
|
||||||
.dist(new NormalDistribution(0, 1))
|
.dist(new NormalDistribution(0, 1))
|
||||||
.layer(0, Convolution3D.builder().activation(afn).kernelSize(kernel)
|
.layer(0, Convolution3D.builder().activation(afn).kernelSize(kernel)
|
||||||
.stride(stride).nIn(convNIn).nOut(convNOut1).hasBias(false)
|
.stride(stride).nIn(convNIn).nOut(convNOut1).hasBias(false)
|
||||||
|
|
|
@ -32,6 +32,7 @@ import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
|
||||||
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
||||||
import org.deeplearning4j.nn.conf.layers.LossLayer;
|
import org.deeplearning4j.nn.conf.layers.LossLayer;
|
||||||
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.serde.CavisMapper;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
|
@ -336,7 +337,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest {
|
||||||
// to ensure that we carry the parameters through
|
// to ensure that we carry the parameters through
|
||||||
// the serializer.
|
// the serializer.
|
||||||
try{
|
try{
|
||||||
ObjectMapper m = NeuralNetConfiguration.mapper();
|
ObjectMapper m = CavisMapper.getMapper(CavisMapper.Type.JSON);
|
||||||
String s = m.writeValueAsString(lossFunctions[i]);
|
String s = m.writeValueAsString(lossFunctions[i]);
|
||||||
ILossFunction lf2 = m.readValue(s, lossFunctions[i].getClass());
|
ILossFunction lf2 = m.readValue(s, lossFunctions[i].getClass());
|
||||||
lossFunctions[i] = lf2;
|
lossFunctions[i] = lf2;
|
||||||
|
|
|
@ -23,6 +23,7 @@ package org.deeplearning4j.regressiontest;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.distribution.*;
|
import org.deeplearning4j.nn.conf.distribution.*;
|
||||||
|
import org.deeplearning4j.nn.conf.serde.CavisMapper;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
|
||||||
|
@ -38,7 +39,7 @@ public class TestDistributionDeserializer extends BaseDL4JTest {
|
||||||
new Distribution[] {new NormalDistribution(3, 0.5), new UniformDistribution(-2, 1),
|
new Distribution[] {new NormalDistribution(3, 0.5), new UniformDistribution(-2, 1),
|
||||||
new GaussianDistribution(2, 1.0), new BinomialDistribution(10, 0.3)};
|
new GaussianDistribution(2, 1.0), new BinomialDistribution(10, 0.3)};
|
||||||
|
|
||||||
ObjectMapper om = NeuralNetConfiguration.mapper();
|
ObjectMapper om = CavisMapper.getMapper(CavisMapper.Type.JSON);
|
||||||
|
|
||||||
for (Distribution d : distributions) {
|
for (Distribution d : distributions) {
|
||||||
String json = om.writeValueAsString(d);
|
String json = om.writeValueAsString(d);
|
||||||
|
@ -50,7 +51,7 @@ public class TestDistributionDeserializer extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testDistributionDeserializerLegacyFormat() throws Exception {
|
public void testDistributionDeserializerLegacyFormat() throws Exception {
|
||||||
ObjectMapper om = NeuralNetConfiguration.mapper();
|
ObjectMapper om = CavisMapper.getMapper(CavisMapper.Type.JSON);
|
||||||
|
|
||||||
String normalJson = "{\n" + " \"normal\" : {\n" + " \"mean\" : 0.1,\n"
|
String normalJson = "{\n" + " \"normal\" : {\n" + " \"mean\" : 0.1,\n"
|
||||||
+ " \"std\" : 1.2\n" + " }\n" + " }";
|
+ " \"std\" : 1.2\n" + " }\n" + " }";
|
||||||
|
|
|
@ -41,8 +41,8 @@ public class JsonTest extends BaseDL4JTest {
|
||||||
|
|
||||||
};
|
};
|
||||||
for(InputPreProcessor p : pp ){
|
for(InputPreProcessor p : pp ){
|
||||||
String s = NeuralNetConfiguration.mapper().writeValueAsString(p);
|
String s = CavisMapper.getMapper(CavisMapper.Type.JSON).writeValueAsString(p);
|
||||||
InputPreProcessor p2 = NeuralNetConfiguration.mapper().readValue(s, InputPreProcessor.class);
|
InputPreProcessor p2 = CavisMapper.getMapper(CavisMapper.Type.JSON).readValue(s, InputPreProcessor.class);
|
||||||
assertEquals(p, p2);
|
assertEquals(p, p2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -34,6 +34,7 @@ import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
|
||||||
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex;
|
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex;
|
||||||
import org.deeplearning4j.nn.conf.memory.MemoryReport;
|
import org.deeplearning4j.nn.conf.memory.MemoryReport;
|
||||||
import org.deeplearning4j.nn.conf.memory.NetworkMemoryReport;
|
import org.deeplearning4j.nn.conf.memory.NetworkMemoryReport;
|
||||||
|
import org.deeplearning4j.nn.conf.serde.CavisMapper;
|
||||||
import org.deeplearning4j.nn.conf.serde.JsonMappers;
|
import org.deeplearning4j.nn.conf.serde.JsonMappers;
|
||||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||||
import org.deeplearning4j.nn.weights.WeightInit;
|
import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
|
@ -110,7 +111,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable {
|
||||||
* @return YAML representation of configuration
|
* @return YAML representation of configuration
|
||||||
*/
|
*/
|
||||||
public String toYaml() {
|
public String toYaml() {
|
||||||
ObjectMapper mapper = NeuralNetConfiguration.mapperYaml();
|
ObjectMapper mapper = CavisMapper.getMapper(CavisMapper.Type.YAML);
|
||||||
synchronized (mapper) {
|
synchronized (mapper) {
|
||||||
try {
|
try {
|
||||||
return mapper.writeValueAsString(this);
|
return mapper.writeValueAsString(this);
|
||||||
|
@ -127,7 +128,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable {
|
||||||
* @return {@link ComputationGraphConfiguration}
|
* @return {@link ComputationGraphConfiguration}
|
||||||
*/
|
*/
|
||||||
public static ComputationGraphConfiguration fromYaml(String json) {
|
public static ComputationGraphConfiguration fromYaml(String json) {
|
||||||
ObjectMapper mapper = NeuralNetConfiguration.mapperYaml();
|
ObjectMapper mapper = CavisMapper.getMapper(CavisMapper.Type.YAML);
|
||||||
try {
|
try {
|
||||||
return mapper.readValue(json, ComputationGraphConfiguration.class);
|
return mapper.readValue(json, ComputationGraphConfiguration.class);
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
|
@ -140,7 +141,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable {
|
||||||
*/
|
*/
|
||||||
public String toJson() {
|
public String toJson() {
|
||||||
//As per NeuralNetConfiguration.toJson()
|
//As per NeuralNetConfiguration.toJson()
|
||||||
ObjectMapper mapper = NeuralNetConfiguration.mapper();
|
ObjectMapper mapper =CavisMapper.getMapper(CavisMapper.Type.JSON);
|
||||||
synchronized (mapper) {
|
synchronized (mapper) {
|
||||||
//JSON mappers are supposed to be thread safe: however, in practice they seem to miss fields occasionally
|
//JSON mappers are supposed to be thread safe: however, in practice they seem to miss fields occasionally
|
||||||
//when writeValueAsString is used by multiple threads. This results in invalid JSON. See issue #3243
|
//when writeValueAsString is used by multiple threads. This results in invalid JSON. See issue #3243
|
||||||
|
@ -160,7 +161,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable {
|
||||||
*/
|
*/
|
||||||
public static ComputationGraphConfiguration fromJson(String json) {
|
public static ComputationGraphConfiguration fromJson(String json) {
|
||||||
//As per NeuralNetConfiguration.fromJson()
|
//As per NeuralNetConfiguration.fromJson()
|
||||||
ObjectMapper mapper = NeuralNetConfiguration.mapper();
|
ObjectMapper mapper =CavisMapper.getMapper(CavisMapper.Type.JSON);
|
||||||
ComputationGraphConfiguration conf;
|
ComputationGraphConfiguration conf;
|
||||||
try {
|
try {
|
||||||
conf = mapper.readValue(json, ComputationGraphConfiguration.class);
|
conf = mapper.readValue(json, ComputationGraphConfiguration.class);
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.deeplearning4j.nn.conf;
|
package org.deeplearning4j.nn.conf;
|
||||||
|
|
||||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||||
import com.fasterxml.jackson.databind.JsonNode;
|
import com.fasterxml.jackson.databind.JsonNode;
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
@ -39,6 +40,7 @@ import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration;
|
import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
||||||
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
|
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.serde.CavisMapper;
|
||||||
import org.deeplearning4j.nn.conf.serde.JsonMappers;
|
import org.deeplearning4j.nn.conf.serde.JsonMappers;
|
||||||
import org.deeplearning4j.nn.conf.stepfunctions.StepFunction;
|
import org.deeplearning4j.nn.conf.stepfunctions.StepFunction;
|
||||||
import org.deeplearning4j.nn.conf.weightnoise.IWeightNoise;
|
import org.deeplearning4j.nn.conf.weightnoise.IWeightNoise;
|
||||||
|
@ -417,113 +419,6 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
|
||||||
@Getter @Setter @lombok.Builder.Default private double biasInit = 0.0;
|
@Getter @Setter @lombok.Builder.Default private double biasInit = 0.0;
|
||||||
@Getter @Setter @lombok.Builder.Default private double gainInit = 1.0;
|
@Getter @Setter @lombok.Builder.Default private double gainInit = 1.0;
|
||||||
|
|
||||||
/**
|
|
||||||
* Handle {@link WeightInit} and {@link Distribution} from legacy configs in Json format. Copied
|
|
||||||
* from handling of {@link Activation} above.
|
|
||||||
*
|
|
||||||
* @return True if all is well and layer iteration shall continue. False else-wise.
|
|
||||||
*/
|
|
||||||
private static boolean handleLegacyWeightInitFromJson(
|
|
||||||
String json, LayerConfiguration l, ObjectMapper mapper, JsonNode confs, int layerCount) {
|
|
||||||
if ((l instanceof BaseLayerConfiguration)
|
|
||||||
&& ((BaseLayerConfiguration) l).getWeightInit() == null) {
|
|
||||||
try {
|
|
||||||
JsonNode jsonNode = mapper.readTree(json);
|
|
||||||
if (confs == null) {
|
|
||||||
confs = jsonNode.get("confs");
|
|
||||||
}
|
|
||||||
if (confs instanceof ArrayNode) {
|
|
||||||
ArrayNode layerConfs = (ArrayNode) confs;
|
|
||||||
JsonNode outputLayerNNCNode = layerConfs.get(layerCount);
|
|
||||||
if (outputLayerNNCNode == null) {
|
|
||||||
return false; // Should never happen...
|
|
||||||
}
|
|
||||||
JsonNode layerWrapperNode = outputLayerNNCNode.get("layer");
|
|
||||||
|
|
||||||
if (layerWrapperNode == null || layerWrapperNode.size() != 1) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
JsonNode layerNode = layerWrapperNode.elements().next();
|
|
||||||
JsonNode weightInit =
|
|
||||||
layerNode.get("weightInit"); // Should only have 1 element: "dense", "output", etc
|
|
||||||
JsonNode distribution = layerNode.get("dist");
|
|
||||||
|
|
||||||
Distribution dist = null;
|
|
||||||
if (distribution != null) {
|
|
||||||
dist = mapper.treeToValue(distribution, Distribution.class);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (weightInit != null) {
|
|
||||||
final IWeightInit wi =
|
|
||||||
WeightInit.valueOf(weightInit.asText()).getWeightInitFunction(dist);
|
|
||||||
((BaseLayerConfiguration) l).setWeightInit(wi);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} catch (IOException e) {
|
|
||||||
log.warn(
|
|
||||||
"ILayer with null WeightInit detected: " + l.getName() + ", could not parse JSON",
|
|
||||||
e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Object mapper for serialization of configurations
|
|
||||||
*
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public static ObjectMapper mapperYaml() {
|
|
||||||
return JsonMappers.getMapperYaml();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Object mapper for serialization of configurations
|
|
||||||
*
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public static ObjectMapper mapper() {
|
|
||||||
return JsonMappers.getMapper();
|
|
||||||
}
|
|
||||||
|
|
||||||
public static NeuralNetBaseBuilderConfiguration fromYaml(String input) {
|
|
||||||
throw new RuntimeException("Needs fixing - not supported."); // TODO
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @return JSON representation of NN configuration
|
|
||||||
*/
|
|
||||||
public String toYaml() {
|
|
||||||
ObjectMapper mapper = NeuralNetBaseBuilderConfiguration.mapperYaml();
|
|
||||||
synchronized (mapper) {
|
|
||||||
try {
|
|
||||||
return mapper.writeValueAsString(this);
|
|
||||||
} catch (com.fasterxml.jackson.core.JsonProcessingException e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @return JSON representation of NN configuration
|
|
||||||
*/
|
|
||||||
public String toJson() {
|
|
||||||
ObjectMapper mapper = NeuralNetBaseBuilderConfiguration.mapper();
|
|
||||||
synchronized (mapper) {
|
|
||||||
// JSON mappers are supposed to be thread safe: however, in practice they seem to miss fields
|
|
||||||
// occasionally
|
|
||||||
// when writeValueAsString is used by multiple threads. This results in invalid JSON. See
|
|
||||||
// issue #3243
|
|
||||||
try {
|
|
||||||
return mapper.writeValueAsString(this);
|
|
||||||
} catch (com.fasterxml.jackson.core.JsonProcessingException e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public NeuralNetBaseBuilderConfiguration clone() {
|
public NeuralNetBaseBuilderConfiguration clone() {
|
||||||
NeuralNetBaseBuilderConfiguration clone;
|
NeuralNetBaseBuilderConfiguration clone;
|
||||||
|
@ -562,15 +457,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
|
||||||
|
|
||||||
List<Object> innerConfigurations$value = new ArrayList<>(); // initialize with an empty list
|
List<Object> innerConfigurations$value = new ArrayList<>(); // initialize with an empty list
|
||||||
|
|
||||||
public B activation(Activation activation) {
|
|
||||||
this.activation = activation;
|
|
||||||
return self();
|
|
||||||
}
|
|
||||||
@JsonIgnore
|
|
||||||
public B activation(IActivation activation) {
|
|
||||||
this.activation = activation;
|
|
||||||
return self();
|
|
||||||
}
|
|
||||||
/**
|
/**
|
||||||
* Set constraints to be applied to all layers. Default: no constraints.<br>
|
* Set constraints to be applied to all layers. Default: no constraints.<br>
|
||||||
* Constraints can be used to enforce certain conditions (non-negativity of parameters, max-norm
|
* Constraints can be used to enforce certain conditions (non-negativity of parameters, max-norm
|
||||||
|
@ -897,6 +784,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
|
||||||
*
|
*
|
||||||
* @param distribution Distribution to use for weight initialization
|
* @param distribution Distribution to use for weight initialization
|
||||||
*/
|
*/
|
||||||
|
@JsonIgnore @Deprecated
|
||||||
public B weightInit(Distribution distribution) {
|
public B weightInit(Distribution distribution) {
|
||||||
this.weightInit$value = new WeightInitDistribution(distribution);
|
this.weightInit$value = new WeightInitDistribution(distribution);
|
||||||
this.weightInit$set = true;
|
this.weightInit$set = true;
|
||||||
|
@ -909,6 +797,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
|
||||||
return self();
|
return self();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@JsonProperty("weightInit") //this is needed for Jackson < 2.4, otherwise JsonIgnore on the other setters will ignore this also
|
||||||
public B weightInit(IWeightInit iWeightInit) {
|
public B weightInit(IWeightInit iWeightInit) {
|
||||||
this.weightInit$value = iWeightInit;
|
this.weightInit$value = iWeightInit;
|
||||||
this.weightInit$set = true;
|
this.weightInit$set = true;
|
||||||
|
@ -921,6 +810,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
|
||||||
* @param distribution
|
* @param distribution
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
|
@JsonIgnore
|
||||||
public B dist(@NonNull Distribution distribution) {
|
public B dist(@NonNull Distribution distribution) {
|
||||||
return weightInit(distribution);
|
return weightInit(distribution);
|
||||||
}
|
}
|
||||||
|
@ -951,5 +841,6 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
|
||||||
innerConfigurations$set = true;
|
innerConfigurations$set = true;
|
||||||
return self();
|
return self();
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,39 +24,24 @@ import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||||
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
|
||||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
import com.fasterxml.jackson.databind.*;
|
import com.fasterxml.jackson.databind.*;
|
||||||
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
|
import java.util.*;
|
||||||
import com.fasterxml.jackson.databind.exc.InvalidTypeIdException;
|
import java.util.stream.Collectors;
|
||||||
import com.fasterxml.jackson.databind.json.JsonMapper;
|
|
||||||
import com.fasterxml.jackson.databind.node.ArrayNode;
|
|
||||||
import lombok.*;
|
import lombok.*;
|
||||||
import lombok.experimental.SuperBuilder;
|
import lombok.experimental.SuperBuilder;
|
||||||
import lombok.extern.jackson.Jacksonized;
|
import lombok.extern.jackson.Jacksonized;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import net.brutex.ai.dnn.api.IModel;
|
import net.brutex.ai.dnn.api.IModel;
|
||||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.layers.*;
|
import org.deeplearning4j.nn.conf.layers.*;
|
||||||
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
|
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
|
||||||
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
|
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
|
||||||
import org.deeplearning4j.nn.conf.memory.MemoryReport;
|
import org.deeplearning4j.nn.conf.memory.MemoryReport;
|
||||||
import org.deeplearning4j.nn.conf.memory.NetworkMemoryReport;
|
import org.deeplearning4j.nn.conf.memory.NetworkMemoryReport;
|
||||||
import org.deeplearning4j.nn.conf.serde.JsonMappers;
|
import org.deeplearning4j.nn.conf.serde.CavisMapper;
|
||||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
|
||||||
import org.deeplearning4j.nn.weights.WeightInit;
|
|
||||||
import org.deeplearning4j.util.OutputLayerUtil;
|
import org.deeplearning4j.util.OutputLayerUtil;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.learning.config.IUpdater;
|
import org.nd4j.linalg.learning.config.IUpdater;
|
||||||
import org.nd4j.linalg.learning.config.Sgd;
|
import org.nd4j.linalg.learning.config.Sgd;
|
||||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
|
||||||
import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT;
|
|
||||||
import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
|
|
||||||
import org.nd4j.linalg.lossfunctions.impl.LossMSE;
|
|
||||||
import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.util.*;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Deeplearning4j is a domain-specific language to configure deep neural networks, which are made of
|
* Deeplearning4j is a domain-specific language to configure deep neural networks, which are made of
|
||||||
|
@ -64,71 +49,50 @@ import java.util.stream.Collectors;
|
||||||
* and their hyperparameters. Hyperparameters are variables that determine how a neural network
|
* and their hyperparameters. Hyperparameters are variables that determine how a neural network
|
||||||
* learns. They include how many times to update the weights of the model, how to initialize those
|
* learns. They include how many times to update the weights of the model, how to initialize those
|
||||||
* weights, which activation function to attach to the nodes, which optimization algorithm to use,
|
* weights, which activation function to attach to the nodes, which optimization algorithm to use,
|
||||||
* and how fast the model should learn. This is what one configuration would look like:
|
* and how fast the model should learn. This is what one configuration would look like: <br>
|
||||||
* <br/><br/>
|
* <br>
|
||||||
*
|
* NeuralNetConfiguration conf = NeuralNetConfiguration.builder()<br>
|
||||||
* NeuralNetConfiguration conf = NeuralNetConfiguration.builder()<br/>
|
* .weightInit(WeightInit.XAVIER) .activation(Activation.RELU)<br>
|
||||||
* .weightInit(WeightInit.XAVIER) .activation(Activation.RELU)<br/>
|
* .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)<br>
|
||||||
* .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)<br/>
|
* .updater(new Sgd(0.05)) //... other hyperparameters <br>
|
||||||
* .updater(new Sgd(0.05)) //... other hyperparameters <br/>
|
* .backprop(true)<br>
|
||||||
* .backprop(true)<br/>
|
* .build();<br>
|
||||||
* .build();<br/><br/>
|
* <br>
|
||||||
*
|
* With Deeplearning4j, you add a layer by calling layer on the
|
||||||
* With Deeplearning4j, you add a layer
|
* NeuralNetConfiguration.NeuralNetConfigurationBuilder(), specifying its place in the order of
|
||||||
* by calling layer on the NeuralNetConfiguration.NeuralNetConfigurationBuilder(), specifying its place in the order of
|
|
||||||
* layers (the zero-indexed layer below is the input layer), the number of input and output nodes,
|
* layers (the zero-indexed layer below is the input layer), the number of input and output nodes,
|
||||||
* nIn and nOut, as well as the type: DenseLayer.<br/><br/>
|
* nIn and nOut, as well as the type: DenseLayer.<br>
|
||||||
*
|
* <br>
|
||||||
* .layer(0, DenseLayer.builder().nIn(784).nOut(250)<br/>
|
* .layer(0, DenseLayer.builder().nIn(784).nOut(250)<br>
|
||||||
* .build())<br/><br/>
|
* .build())<br>
|
||||||
*
|
* <br>
|
||||||
* Once you've configured your net, you train the
|
* Once you've configured your net, you train the model with model.fit.
|
||||||
* model with model.fit.
|
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Jacksonized
|
@JsonIgnoreProperties(value = {"net"})
|
||||||
@JsonIgnoreProperties(value={"net"}, ignoreUnknown = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
@EqualsAndHashCode(exclude = {"net"}, callSuper = true)
|
|
||||||
// @JsonIdentityInfo(generator= ObjectIdGenerators.IntSequenceGenerator.class, property="@id")
|
// @JsonIdentityInfo(generator= ObjectIdGenerators.IntSequenceGenerator.class, property="@id")
|
||||||
|
|
||||||
// The inner builder, that we can then extend ...
|
// The inner builder, that we can then extend ...
|
||||||
|
@Jacksonized
|
||||||
@SuperBuilder // TODO fix access
|
@SuperBuilder // TODO fix access
|
||||||
public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
|
public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
|
||||||
|
|
||||||
|
|
||||||
private IModel net;
|
|
||||||
private static final int DEFAULT_TBPTT_LENGTH = 20;
|
private static final int DEFAULT_TBPTT_LENGTH = 20;
|
||||||
private boolean initCalled = false;
|
|
||||||
|
|
||||||
|
@Getter @Setter @NonNull @lombok.Builder.Default @Deprecated
|
||||||
@Getter
|
|
||||||
@Setter
|
|
||||||
@NonNull
|
|
||||||
@lombok.Builder.Default
|
|
||||||
@Deprecated
|
|
||||||
protected WorkspaceMode trainingWorkspaceMode = WorkspaceMode.ENABLED;
|
protected WorkspaceMode trainingWorkspaceMode = WorkspaceMode.ENABLED;
|
||||||
@Getter
|
|
||||||
@Setter
|
@Getter @Setter @NonNull @lombok.Builder.Default @Deprecated
|
||||||
@NonNull
|
|
||||||
@lombok.Builder.Default
|
|
||||||
@Deprecated
|
|
||||||
protected WorkspaceMode inferenceWorkspaceMode = WorkspaceMode.ENABLED;
|
protected WorkspaceMode inferenceWorkspaceMode = WorkspaceMode.ENABLED;
|
||||||
|
|
||||||
|
@Getter @Setter @lombok.Builder.Default protected int iterationCount = 0;
|
||||||
@Getter
|
|
||||||
@Setter
|
|
||||||
@lombok.Builder.Default
|
|
||||||
protected int iterationCount = 0;
|
|
||||||
// Counter for the number of epochs completed so far. Used for per-epoch schedules
|
// Counter for the number of epochs completed so far. Used for per-epoch schedules
|
||||||
@Getter
|
@Getter @Setter @lombok.Builder.Default protected int epochCount = 0;
|
||||||
@Setter
|
@lombok.Builder.Default protected double dampingFactor = 100;
|
||||||
@lombok.Builder.Default
|
@EqualsAndHashCode.Exclude private IModel net;
|
||||||
protected int epochCount = 0;
|
private boolean initCalled = false;
|
||||||
@lombok.Builder.Default
|
|
||||||
protected double dampingFactor = 100;
|
|
||||||
// gradient keys used for ensuring order when getting and setting the gradient
|
// gradient keys used for ensuring order when getting and setting the gradient
|
||||||
@lombok.Builder.Default private LinkedHashSet<String> netWideVariables = new LinkedHashSet<>();
|
@lombok.Builder.Default private LinkedHashSet<String> netWideVariables = new LinkedHashSet<>();
|
||||||
|
|
||||||
|
@ -143,22 +107,19 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
|
||||||
*/
|
*/
|
||||||
@Getter @Setter @Builder.Default private IUpdater updater = new Sgd();
|
@Getter @Setter @Builder.Default private IUpdater updater = new Sgd();
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sets the cuDNN algo mode for convolutional layers, which impacts performance and memory usage of cuDNN.
|
* Sets the cuDNN algo mode for convolutional layers, which impacts performance and memory usage
|
||||||
* See {@link ConvolutionLayer.AlgoMode} for details. Defaults to "PREFER_FASTEST", but "NO_WORKSPACE" uses less memory.
|
* of cuDNN. See {@link ConvolutionLayer.AlgoMode} for details. Defaults to "PREFER_FASTEST", but
|
||||||
* <br>
|
* "NO_WORKSPACE" uses less memory. <br>
|
||||||
* Note: values set by this method will be applied to all applicable layers in the network, unless a different
|
* Note: values set by this method will be applied to all applicable layers in the network, unless
|
||||||
* value is explicitly set on a given layer. In other words: values set via this method are used as the default
|
* a different value is explicitly set on a given layer. In other words: values set via this
|
||||||
* value, and can be overridden on a per-layer basis.
|
* method are used as the default value, and can be overridden on a per-layer basis.
|
||||||
|
*
|
||||||
* @param cudnnAlgoMode cuDNN algo mode to use
|
* @param cudnnAlgoMode cuDNN algo mode to use
|
||||||
*/
|
*/
|
||||||
@Getter
|
@Getter @Setter @lombok.Builder.Default
|
||||||
@Setter
|
|
||||||
@lombok.Builder.Default
|
|
||||||
private ConvolutionLayer.AlgoMode cudnnAlgoMode = ConvolutionLayer.AlgoMode.PREFER_FASTEST;
|
private ConvolutionLayer.AlgoMode cudnnAlgoMode = ConvolutionLayer.AlgoMode.PREFER_FASTEST;
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Create a neural net configuration from json
|
* Create a neural net configuration from json
|
||||||
*
|
*
|
||||||
|
@ -166,270 +127,23 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
|
||||||
* @return {@link NeuralNetConfiguration}
|
* @return {@link NeuralNetConfiguration}
|
||||||
*/
|
*/
|
||||||
public static NeuralNetConfiguration fromJson(String json) {
|
public static NeuralNetConfiguration fromJson(String json) {
|
||||||
//ObjectMapper mapper = NeuralNetConfiguration.mapper();
|
ObjectMapper mapper = CavisMapper.getMapper(CavisMapper.Type.JSON);
|
||||||
JsonMapper mapper = JsonMapper.builder()
|
|
||||||
.enable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES)
|
|
||||||
.build();
|
|
||||||
try {
|
try {
|
||||||
return mapper.readValue(json, NeuralNetConfiguration.class);
|
return mapper.readValue(json, NeuralNetConfiguration.class);
|
||||||
} catch (JsonProcessingException e) {
|
} catch (JsonProcessingException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
/*
|
|
||||||
try {
|
|
||||||
conf = mapper.readValue(json, NeuralNetConfiguration.class);
|
|
||||||
} catch (InvalidTypeIdException e) {
|
|
||||||
if (e.getMessage().contains("@class")) {
|
|
||||||
try {
|
|
||||||
//JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format
|
|
||||||
return JsonMappers.getLegacyMapper().readValue(json, NeuralNetConfiguration.class);
|
|
||||||
} catch (InvalidTypeIdException e2) {
|
|
||||||
//Check for legacy custom layers: "Could not resolve type id 'CustomLayer' as a subtype of [simple type, class org.deeplearning4j.nn.conf.layers.ILayer]: known type ids = [Bidirectional, CenterLossOutputLayer, CnnLossLayer, ..."
|
|
||||||
//1.0.0-beta5: dropping support for custom layers defined in pre-1.0.0-beta format. Built-in layers from these formats still work
|
|
||||||
String msg = e2.getMessage();
|
|
||||||
if (msg != null && msg.contains("Could not resolve type id")) {
|
|
||||||
throw new RuntimeException(
|
|
||||||
"Error deserializing NeuralNetConfiguration - configuration may have a custom " +
|
|
||||||
"layer, vertex or preprocessor, in pre version 1.0.0-beta JSON format.\nModels in legacy format with custom"
|
|
||||||
+
|
|
||||||
" layers should be loaded in 1.0.0-beta to 1.0.0-beta4 and saved again, before loading in the current version of DL4J",
|
|
||||||
e);
|
|
||||||
}
|
|
||||||
throw new RuntimeException(e2);
|
|
||||||
} catch (IOException e2) {
|
|
||||||
throw new RuntimeException(e2);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
} catch (IOException e) {
|
|
||||||
//Check if this exception came from legacy deserializer...
|
|
||||||
String msg = e.getMessage();
|
|
||||||
if (msg != null && msg.contains("legacy")) {
|
|
||||||
throw new RuntimeException(
|
|
||||||
"Error deserializing NeuralNetConfiguration - configuration may have a custom " +
|
|
||||||
"layer, vertex or preprocessor, in pre version 1.0.0-alpha JSON format. These layers can be "
|
|
||||||
+
|
|
||||||
"deserialized by first registering them with NeuralNetConfiguration.registerLegacyCustomClassesForJSON(Class...)",
|
|
||||||
e);
|
|
||||||
}
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
|
|
||||||
//To maintain backward compatibility after loss function refactoring (configs generated with v0.5.0 or earlier)
|
|
||||||
// Previously: enumeration used for loss functions. Now: use classes
|
|
||||||
// IN the past, could have only been an OutputLayer or RnnOutputLayer using these enums
|
|
||||||
int layerCount = 0;
|
|
||||||
JsonNode confs = null;
|
|
||||||
for (LayerConfiguration nnc : conf.getFlattenedLayerConfigurations()) {
|
|
||||||
LayerConfiguration l = nnc;
|
|
||||||
if (l instanceof BaseOutputLayer && ((BaseOutputLayer) l).getLossFunction() == null) {
|
|
||||||
//lossFn field null -> may be an old config format, with lossFunction field being for the enum
|
|
||||||
//if so, try walking the JSON graph to extract out the appropriate enum value
|
|
||||||
|
|
||||||
BaseOutputLayer ol = (BaseOutputLayer) l;
|
|
||||||
try {
|
|
||||||
JsonNode jsonNode = mapper.readTree(json);
|
|
||||||
if (confs == null) {
|
|
||||||
confs = jsonNode.get("confs");
|
|
||||||
}
|
|
||||||
if (confs instanceof ArrayNode) {
|
|
||||||
ArrayNode layerConfs = (ArrayNode) confs;
|
|
||||||
JsonNode outputLayerNNCNode = layerConfs.get(layerCount);
|
|
||||||
if (outputLayerNNCNode == null) {
|
|
||||||
throw new RuntimeException(
|
|
||||||
"should never happen"); //return conf; //Should never happen...
|
|
||||||
}
|
|
||||||
JsonNode outputLayerNode = outputLayerNNCNode.get("layer");
|
|
||||||
|
|
||||||
JsonNode lossFunctionNode = null;
|
|
||||||
if (outputLayerNode.has("output")) {
|
|
||||||
lossFunctionNode = outputLayerNode.get("output").get("lossFunction");
|
|
||||||
} else if (outputLayerNode.has("rnnoutput")) {
|
|
||||||
lossFunctionNode = outputLayerNode.get("rnnoutput").get("lossFunction");
|
|
||||||
}
|
|
||||||
|
|
||||||
if (lossFunctionNode != null) {
|
|
||||||
String lossFunctionEnumStr = lossFunctionNode.asText();
|
|
||||||
LossFunctions.LossFunction lossFunction = null;
|
|
||||||
try {
|
|
||||||
lossFunction = LossFunctions.LossFunction.valueOf(lossFunctionEnumStr);
|
|
||||||
} catch (Exception e) {
|
|
||||||
log.warn(
|
|
||||||
"OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON",
|
|
||||||
e);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (lossFunction != null) {
|
|
||||||
switch (lossFunction) {
|
|
||||||
case MSE:
|
|
||||||
ol.setLossFunction(new LossMSE());
|
|
||||||
break;
|
|
||||||
case XENT:
|
|
||||||
ol.setLossFunction(new LossBinaryXENT());
|
|
||||||
break;
|
|
||||||
case NEGATIVELOGLIKELIHOOD:
|
|
||||||
ol.setLossFunction(new LossNegativeLogLikelihood());
|
|
||||||
break;
|
|
||||||
case MCXENT:
|
|
||||||
ol.setLossFunction(new LossMCXENT());
|
|
||||||
break;
|
|
||||||
|
|
||||||
//Remaining: TODO
|
|
||||||
case SQUARED_LOSS:
|
|
||||||
case RECONSTRUCTION_CROSSENTROPY:
|
|
||||||
default:
|
|
||||||
log.warn(
|
|
||||||
"OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not set loss function for {}",
|
|
||||||
lossFunction);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} else {
|
|
||||||
log.warn(
|
|
||||||
"OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON: layer 'confs' field is not an ArrayNode (is: {})",
|
|
||||||
(confs != null ? confs.getClass() : null));
|
|
||||||
}
|
|
||||||
} catch (IOException e) {
|
|
||||||
log.warn(
|
|
||||||
"OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON",
|
|
||||||
e);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//Also, pre 0.7.2: activation functions were Strings ("activationFunction" field), not classes ("activationFn")
|
|
||||||
//Try to load the old format if necessary, and create the appropriate IActivation instance
|
|
||||||
if ((l instanceof BaseLayerConfiguration) && ((BaseLayerConfiguration) l).getActivationFn() == null) {
|
|
||||||
try {
|
|
||||||
JsonNode jsonNode = mapper.readTree(json);
|
|
||||||
if (confs == null) {
|
|
||||||
confs = jsonNode.get("confs");
|
|
||||||
}
|
|
||||||
if (confs instanceof ArrayNode) {
|
|
||||||
ArrayNode layerConfs = (ArrayNode) confs;
|
|
||||||
JsonNode outputLayerNNCNode = layerConfs.get(layerCount);
|
|
||||||
if (outputLayerNNCNode == null) {
|
|
||||||
throw new RuntimeException(
|
|
||||||
"Should never happen"); //return conf; //Should never happen...
|
|
||||||
}
|
|
||||||
JsonNode layerWrapperNode = outputLayerNNCNode.get("layer");
|
|
||||||
|
|
||||||
if (layerWrapperNode == null || layerWrapperNode.size() != 1) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
JsonNode layerNode = layerWrapperNode.elements().next();
|
|
||||||
JsonNode activationFunction = layerNode.get(
|
|
||||||
"activationFunction"); //Should only have 1 element: "dense", "output", etc
|
|
||||||
|
|
||||||
if (activationFunction != null) {
|
|
||||||
Activation ia = Activation.fromString(activationFunction.asText());
|
|
||||||
((BaseLayerConfiguration) l).setActivation(ia.getActivationFunction());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} catch (IOException e) {
|
|
||||||
log.warn(
|
|
||||||
"ILayer with null ActivationFn field or pre-0.7.2 activation function detected: could not parse JSON",
|
|
||||||
e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!handleLegacyWeightInitFromJson(json, l, mapper, confs, layerCount)) {
|
|
||||||
return conf;
|
|
||||||
}
|
|
||||||
|
|
||||||
layerCount++;
|
|
||||||
}
|
|
||||||
return conf;
|
|
||||||
|
|
||||||
*/
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Handle {@link WeightInit} and {@link Distribution} from legacy configs in Json format. Copied
|
|
||||||
* from handling of {@link Activation} above.
|
|
||||||
*
|
|
||||||
* @return True if all is well and layer iteration shall continue. False else-wise.
|
|
||||||
*/
|
|
||||||
private static boolean handleLegacyWeightInitFromJson(String json, LayerConfiguration l,
|
|
||||||
ObjectMapper mapper,
|
|
||||||
JsonNode confs, int layerCount) {
|
|
||||||
if ((l instanceof BaseLayerConfiguration) ) { //&& ((BaseLayerConfiguration) l).getWeightInit() == null) {
|
|
||||||
try {
|
|
||||||
JsonNode jsonNode = mapper.readTree(json);
|
|
||||||
if (confs == null) {
|
|
||||||
confs = jsonNode.get("confs");
|
|
||||||
}
|
|
||||||
if (confs instanceof ArrayNode) {
|
|
||||||
ArrayNode layerConfs = (ArrayNode) confs;
|
|
||||||
JsonNode outputLayerNNCNode = layerConfs.get(layerCount);
|
|
||||||
if (outputLayerNNCNode == null) {
|
|
||||||
return false; //Should never happen...
|
|
||||||
}
|
|
||||||
JsonNode layerWrapperNode = outputLayerNNCNode.get("layer");
|
|
||||||
|
|
||||||
if (layerWrapperNode == null || layerWrapperNode.size() != 1) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
JsonNode layerNode = layerWrapperNode.elements().next();
|
|
||||||
JsonNode weightInit = layerNode.get(
|
|
||||||
"weightInit"); //Should only have 1 element: "dense", "output", etc
|
|
||||||
JsonNode distribution = layerNode.get("dist");
|
|
||||||
|
|
||||||
Distribution dist = null;
|
|
||||||
if (distribution != null) {
|
|
||||||
dist = mapper.treeToValue(distribution, Distribution.class);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (weightInit != null) {
|
|
||||||
final IWeightInit wi = WeightInit.valueOf(weightInit.asText())
|
|
||||||
.getWeightInitFunction(dist);
|
|
||||||
((BaseLayerConfiguration) l).setWeightInit(wi);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} catch (IOException e) {
|
|
||||||
log.warn(
|
|
||||||
"ILayer with null WeightInit detected: " + l.getName() + ", could not parse JSON",
|
|
||||||
e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Object mapper for serialization of configurations
|
|
||||||
*
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public static ObjectMapper mapperYaml() {
|
|
||||||
return JsonMappers.getMapperYaml();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Object mapper for serialization of configurations
|
|
||||||
*
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public static ObjectMapper mapper() {
|
|
||||||
return JsonMappers.getMapper();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public static NeuralNetConfiguration fromYaml(String input) {
|
public static NeuralNetConfiguration fromYaml(String input) {
|
||||||
throw new RuntimeException("Needs fixing - not supported."); // TODO
|
throw new RuntimeException("Needs fixing - not supported."); // TODO
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @return JSON representation of NN configuration
|
* @return JSON representation of NN configuration
|
||||||
*/
|
*/
|
||||||
public String toYaml() {
|
public String toYaml() {
|
||||||
ObjectMapper mapper = NeuralNetConfiguration.mapperYaml();
|
ObjectMapper mapper = CavisMapper.getMapper(CavisMapper.Type.YAML);
|
||||||
synchronized (mapper) {
|
synchronized (mapper) {
|
||||||
try {
|
try {
|
||||||
return mapper.writeValueAsString(this);
|
return mapper.writeValueAsString(this);
|
||||||
|
@ -443,14 +157,12 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
|
||||||
* @return JSON representation of NN configuration
|
* @return JSON representation of NN configuration
|
||||||
*/
|
*/
|
||||||
public String toJson() {
|
public String toJson() {
|
||||||
JsonMapper mapper = JsonMapper.builder()
|
ObjectMapper mapper = CavisMapper.getMapper(CavisMapper.Type.JSON);
|
||||||
.enable(SerializationFeature.INDENT_OUTPUT)
|
|
||||||
.enable(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY)
|
|
||||||
.build();
|
|
||||||
//ObjectMapper mapper = NeuralNetConfiguration.mapper();
|
|
||||||
synchronized (mapper) {
|
synchronized (mapper) {
|
||||||
//JSON mappers are supposed to be thread safe: however, in practice they seem to miss fields occasionally
|
// JSON mappers are supposed to be thread safe: however, in practice they seem to miss fields
|
||||||
//when writeValueAsString is used by multiple threads. This results in invalid JSON. See issue #3243
|
// occasionally
|
||||||
|
// when writeValueAsString is used by multiple threads. This results in invalid JSON. See
|
||||||
|
// issue #3243
|
||||||
try {
|
try {
|
||||||
return mapper.writeValueAsString(this);
|
return mapper.writeValueAsString(this);
|
||||||
} catch (com.fasterxml.jackson.core.JsonProcessingException e) {
|
} catch (com.fasterxml.jackson.core.JsonProcessingException e) {
|
||||||
|
@ -469,7 +181,9 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
|
||||||
public NeuralNetConfiguration clone() {
|
public NeuralNetConfiguration clone() {
|
||||||
NeuralNetConfiguration clone;
|
NeuralNetConfiguration clone;
|
||||||
clone = (NeuralNetConfiguration) super.clone();
|
clone = (NeuralNetConfiguration) super.clone();
|
||||||
if(getStepFunction() != null) { clone.setStepFunction(getStepFunction().clone()); }
|
if (getStepFunction() != null) {
|
||||||
|
clone.setStepFunction(getStepFunction().clone());
|
||||||
|
}
|
||||||
clone.netWideVariables = new LinkedHashSet<>(netWideVariables);
|
clone.netWideVariables = new LinkedHashSet<>(netWideVariables);
|
||||||
clone.setInnerConfigurations(new ArrayList<>(innerConfigurations));
|
clone.setInnerConfigurations(new ArrayList<>(innerConfigurations));
|
||||||
|
|
||||||
|
@ -489,20 +203,15 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
|
||||||
clone.setDataType(this.getDataType());
|
clone.setDataType(this.getDataType());
|
||||||
|
|
||||||
return clone;
|
return clone;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/** */
|
||||||
*
|
|
||||||
*/
|
|
||||||
@Override
|
@Override
|
||||||
public void init() {
|
public void init() {
|
||||||
if (initCalled) return;
|
if (initCalled) return;
|
||||||
initCalled = true;
|
initCalled = true;
|
||||||
|
|
||||||
/**
|
/** Run init() for each layer */
|
||||||
* Run init() for each layer
|
|
||||||
*/
|
|
||||||
for (NeuralNetConfiguration nconf : getNetConfigurations()) {
|
for (NeuralNetConfiguration nconf : getNetConfigurations()) {
|
||||||
nconf.init();
|
nconf.init();
|
||||||
}
|
}
|
||||||
|
@ -514,24 +223,31 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
|
||||||
// innerConfigurations.add(0, this); //put this configuration at first place
|
// innerConfigurations.add(0, this); //put this configuration at first place
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Inherit network wide configuration setting to those layer configurations
|
* Inherit network wide configuration setting to those layer configurations that do not have an
|
||||||
* that do not have an individual setting (nor a default)
|
* individual setting (nor a default)
|
||||||
*/
|
*/
|
||||||
for (LayerConfiguration lconf : this.getFlattenedLayerConfigurations()) {
|
for (LayerConfiguration lconf : this.getFlattenedLayerConfigurations()) {
|
||||||
lconf.runInheritance();
|
lconf.runInheritance();
|
||||||
}
|
}
|
||||||
|
|
||||||
getLayerConfigurations().stream().forEach( lconf -> lconf.setNetConfiguration(this)); //set this as net config for all layers (defined in here, not stacked
|
getLayerConfigurations().stream()
|
||||||
|
.forEach(
|
||||||
|
lconf ->
|
||||||
|
lconf.setNetConfiguration(
|
||||||
|
this)); // set this as net config for all layers (defined in here, not stacked
|
||||||
|
|
||||||
// Validate BackpropType setting
|
// Validate BackpropType setting
|
||||||
if ((tbpttBackLength != DEFAULT_TBPTT_LENGTH || tbpttFwdLength != DEFAULT_TBPTT_LENGTH)
|
if ((tbpttBackLength != DEFAULT_TBPTT_LENGTH || tbpttFwdLength != DEFAULT_TBPTT_LENGTH)
|
||||||
&& backpropType != BackpropType.TruncatedBPTT) {
|
&& backpropType != BackpropType.TruncatedBPTT) {
|
||||||
log.warn("Truncated backpropagation through time lengths have been configured with values "
|
log.warn(
|
||||||
|
"Truncated backpropagation through time lengths have been configured with values "
|
||||||
+ tbpttFwdLength
|
+ tbpttFwdLength
|
||||||
+ " and " + tbpttBackLength + " but backprop type is set to " + backpropType
|
+ " and "
|
||||||
+ ". TBPTT configuration" +
|
+ tbpttBackLength
|
||||||
" settings will only take effect if backprop type is set to BackpropType.TruncatedBPTT");
|
+ " but backprop type is set to "
|
||||||
|
+ backpropType
|
||||||
|
+ ". TBPTT configuration"
|
||||||
|
+ " settings will only take effect if backprop type is set to BackpropType.TruncatedBPTT");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (backpropType == BackpropType.TruncatedBPTT && isValidateTbpttConfig()) {
|
if (backpropType == BackpropType.TruncatedBPTT && isValidateTbpttConfig()) {
|
||||||
|
@ -541,21 +257,24 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
|
||||||
if (l instanceof LastTimeStep || l instanceof GlobalPoolingLayer) {
|
if (l instanceof LastTimeStep || l instanceof GlobalPoolingLayer) {
|
||||||
throw new IllegalStateException(
|
throw new IllegalStateException(
|
||||||
"Invalid network configuration detected: Truncated backpropagation through time (TBPTT)"
|
"Invalid network configuration detected: Truncated backpropagation through time (TBPTT)"
|
||||||
+
|
+ " cannot be used with layer "
|
||||||
" cannot be used with layer " + i + " of type " + l.getClass().getName()
|
+ i
|
||||||
+ ": TBPTT is incompatible with this layer type (which is designed " +
|
+ " of type "
|
||||||
"to process entire sequences at once, and does support the type of sequence segments that TPBTT uses).\n"
|
+ l.getClass().getName()
|
||||||
+
|
+ ": TBPTT is incompatible with this layer type (which is designed "
|
||||||
"This check can be disabled using validateTbpttConfig(false) but this is not recommended.");
|
+ "to process entire sequences at once, and does support the type of sequence segments that TPBTT uses).\n"
|
||||||
|
+ "This check can be disabled using validateTbpttConfig(false) but this is not recommended.");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (getInputType() == null && inputPreProcessors.get(0) == null) {
|
if (getInputType() == null && inputPreProcessors.get(0) == null) {
|
||||||
// User hasn't set the InputType. Sometimes we can infer it...
|
// User hasn't set the InputType. Sometimes we can infer it...
|
||||||
// For example, Dense/RNN layers, where preprocessor isn't set -> user is *probably* going to feed in
|
// For example, Dense/RNN layers, where preprocessor isn't set -> user is *probably* going to
|
||||||
|
// feed in
|
||||||
// standard feedforward or RNN data
|
// standard feedforward or RNN data
|
||||||
//This isn't the most elegant implementation, but should avoid breaking backward compatibility here
|
// This isn't the most elegant implementation, but should avoid breaking backward
|
||||||
|
// compatibility here
|
||||||
// Can't infer InputType for CNN layers, however (don't know image dimensions/depth)
|
// Can't infer InputType for CNN layers, however (don't know image dimensions/depth)
|
||||||
LayerConfiguration firstLayer = getFlattenedLayerConfigurations().get(0);
|
LayerConfiguration firstLayer = getFlattenedLayerConfigurations().get(0);
|
||||||
if (firstLayer instanceof BaseRecurrentLayer) {
|
if (firstLayer instanceof BaseRecurrentLayer) {
|
||||||
|
@ -564,9 +283,11 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
|
||||||
if (nIn > 0) {
|
if (nIn > 0) {
|
||||||
setInputType(InputType.recurrent(nIn, brl.getDataFormat()));
|
setInputType(InputType.recurrent(nIn, brl.getDataFormat()));
|
||||||
}
|
}
|
||||||
} else if (firstLayer instanceof DenseLayer || firstLayer instanceof EmbeddingLayer
|
} else if (firstLayer instanceof DenseLayer
|
||||||
|
|| firstLayer instanceof EmbeddingLayer
|
||||||
|| firstLayer instanceof OutputLayer) {
|
|| firstLayer instanceof OutputLayer) {
|
||||||
//Can't just use "instanceof FeedForwardLayer" here. ConvolutionLayer is also a FeedForwardLayer
|
// Can't just use "instanceof FeedForwardLayer" here. ConvolutionLayer is also a
|
||||||
|
// FeedForwardLayer
|
||||||
FeedForwardLayer ffl = (FeedForwardLayer) firstLayer;
|
FeedForwardLayer ffl = (FeedForwardLayer) firstLayer;
|
||||||
val nIn = ffl.getNIn();
|
val nIn = ffl.getNIn();
|
||||||
if (nIn > 0) {
|
if (nIn > 0) {
|
||||||
|
@ -579,7 +300,8 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
|
||||||
// Builder.inputType field can be set in 1 of 4 ways:
|
// Builder.inputType field can be set in 1 of 4 ways:
|
||||||
// 1. User calls setInputType directly
|
// 1. User calls setInputType directly
|
||||||
// 2. Via ConvolutionLayerSetup -> internally calls setInputType(InputType.convolutional(...))
|
// 2. Via ConvolutionLayerSetup -> internally calls setInputType(InputType.convolutional(...))
|
||||||
// 3. Via the above code: i.e., assume input is as expected by the RNN or dense layer -> sets the inputType field
|
// 3. Via the above code: i.e., assume input is as expected by the RNN or dense layer -> sets
|
||||||
|
// the inputType field
|
||||||
if (inputPreProcessors == null) {
|
if (inputPreProcessors == null) {
|
||||||
inputPreProcessors = new HashMap<>();
|
inputPreProcessors = new HashMap<>();
|
||||||
}
|
}
|
||||||
|
@ -608,26 +330,32 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
|
||||||
if (l instanceof DenseLayer && getInputType() instanceof InputType.InputTypeRecurrent) {
|
if (l instanceof DenseLayer && getInputType() instanceof InputType.InputTypeRecurrent) {
|
||||||
FeedForwardLayer feedForwardLayer = (FeedForwardLayer) l;
|
FeedForwardLayer feedForwardLayer = (FeedForwardLayer) l;
|
||||||
if (getInputType() instanceof InputType.InputTypeRecurrent) {
|
if (getInputType() instanceof InputType.InputTypeRecurrent) {
|
||||||
InputType.InputTypeRecurrent recurrent = (InputType.InputTypeRecurrent) getInputType();
|
InputType.InputTypeRecurrent recurrent =
|
||||||
|
(InputType.InputTypeRecurrent) getInputType();
|
||||||
feedForwardLayer.setNIn(recurrent.getTimeSeriesLength());
|
feedForwardLayer.setNIn(recurrent.getTimeSeriesLength());
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
l.setNIn(currentInputType,
|
l.setNIn(
|
||||||
isOverrideNinUponBuild()); //Don't override the nIn setting, if it's manually set by the user
|
currentInputType,
|
||||||
|
isOverrideNinUponBuild()); // Don't override the nIn setting, if it's manually set
|
||||||
|
// by the user
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
l.setNIn(currentInputType,
|
l.setNIn(
|
||||||
isOverrideNinUponBuild()); //Don't override the nIn setting, if it's manually set by the user
|
currentInputType,
|
||||||
|
isOverrideNinUponBuild()); // Don't override the nIn setting, if it's manually set
|
||||||
|
// by the user
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
l.setNIn(currentInputType,
|
l.setNIn(
|
||||||
isOverrideNinUponBuild()); //Don't override the nIn setting, if it's manually set by the user
|
currentInputType,
|
||||||
|
isOverrideNinUponBuild()); // Don't override the nIn setting, if it's manually set by
|
||||||
|
// the user
|
||||||
}
|
}
|
||||||
|
|
||||||
currentInputType = l.getOutputType(i, currentInputType);
|
currentInputType = l.getOutputType(i, currentInputType);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Nd4j.getRandom().setSeed(getSeed());
|
Nd4j.getRandom().setSeed(getSeed());
|
||||||
|
@ -669,19 +397,21 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
|
||||||
inputType = preproc.getOutputType(inputType);
|
inputType = preproc.getOutputType(inputType);
|
||||||
}
|
}
|
||||||
|
|
||||||
LayerMemoryReport report = getFlattenedLayerConfigurations().get(i).getMemoryReport(inputType);
|
LayerMemoryReport report =
|
||||||
|
getFlattenedLayerConfigurations().get(i).getMemoryReport(inputType);
|
||||||
memoryReportMap.put(layerName, report);
|
memoryReportMap.put(layerName, report);
|
||||||
|
|
||||||
inputType = getFlattenedLayerConfigurations().get(i).getOutputType(i, inputType);
|
inputType = getFlattenedLayerConfigurations().get(i).getOutputType(i, inputType);
|
||||||
}
|
}
|
||||||
|
|
||||||
return new NetworkMemoryReport(memoryReportMap, NeuralNetConfiguration.class,
|
return new NetworkMemoryReport(
|
||||||
"MultiLayerNetwork", inputType);
|
memoryReportMap, NeuralNetConfiguration.class, "MultiLayerNetwork", inputType);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* For the given input shape/type for the network, return a list of activation sizes for each
|
* For the given input shape/type for the network, return a list of activation sizes for each
|
||||||
* layer in the network.<br> i.e., list.get(i) is the output activation sizes for layer i
|
* layer in the network.<br>
|
||||||
|
* i.e., list.get(i) is the output activation sizes for layer i
|
||||||
*
|
*
|
||||||
* @param inputType Input type for the network
|
* @param inputType Input type for the network
|
||||||
* @return A lits of activation types for the network, indexed by layer number
|
* @return A lits of activation types for the network, indexed by layer number
|
||||||
|
@ -715,23 +445,30 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
|
||||||
public void addNetWideVariable(String variable) {
|
public void addNetWideVariable(String variable) {
|
||||||
if (!netWideVariables.contains(variable)) {
|
if (!netWideVariables.contains(variable)) {
|
||||||
netWideVariables.add(variable);
|
netWideVariables.add(variable);
|
||||||
log.trace("Adding neural network wide variable '{}' to the list of variables. New length is {}.", variable, netWideVariables.size());
|
log.trace(
|
||||||
|
"Adding neural network wide variable '{}' to the list of variables. New length is {}.",
|
||||||
|
variable,
|
||||||
|
netWideVariables.size());
|
||||||
}
|
}
|
||||||
log.trace("Skipped adding neural network wide variable '{}' to the list of variables. It was already present. Length remains {}.", variable, netWideVariables.size());
|
log.trace(
|
||||||
|
"Skipped adding neural network wide variable '{}' to the list of variables. It was already present. Length remains {}.",
|
||||||
|
variable,
|
||||||
|
netWideVariables.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
public void clearNetWideVariable() {
|
public void clearNetWideVariable() {
|
||||||
|
|
||||||
netWideVariables.clear();
|
netWideVariables.clear();
|
||||||
log.trace("Adding neural network wide variables have been cleared. New length is {}.", netWideVariables.size());
|
log.trace(
|
||||||
|
"Adding neural network wide variables have been cleared. New length is {}.",
|
||||||
|
netWideVariables.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* From the list of layers and neural net configurations, only return the Layer Configurations that
|
* From the list of layers and neural net configurations, only return the Layer Configurations
|
||||||
* are defined in this neural network (it does not include embedded neural network configuration
|
* that are defined in this neural network (it does not include embedded neural network
|
||||||
* layers)
|
* configuration layers)
|
||||||
|
*
|
||||||
* @return list with layer configurations
|
* @return list with layer configurations
|
||||||
*/
|
*/
|
||||||
@JsonIgnore
|
@JsonIgnore
|
||||||
|
@ -743,7 +480,9 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* From the list of layers and neural net configurations, only return the neural net configurations
|
* From the list of layers and neural net configurations, only return the neural net
|
||||||
|
* configurations
|
||||||
|
*
|
||||||
* @return list with neural net configurations
|
* @return list with neural net configurations
|
||||||
*/
|
*/
|
||||||
// @Synchronized("innerConfigurationsLock")
|
// @Synchronized("innerConfigurationsLock")
|
||||||
|
@ -769,21 +508,27 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
|
||||||
public List<LayerConfiguration> getFlattenedLayerConfigurations(NeuralNetConfiguration conf) {
|
public List<LayerConfiguration> getFlattenedLayerConfigurations(NeuralNetConfiguration conf) {
|
||||||
List<LayerConfiguration> ret = new ArrayList<>(); // create the final return list
|
List<LayerConfiguration> ret = new ArrayList<>(); // create the final return list
|
||||||
// When properly initialized, _this_ configuration is set first in the list, however we
|
// When properly initialized, _this_ configuration is set first in the list, however we
|
||||||
//can find cases where this is not true, thus the first configuration is another net or layer configuration
|
// can find cases where this is not true, thus the first configuration is another net or layer
|
||||||
|
// configuration
|
||||||
// and should not be skipped. In essence, skip first configuration if that is "this".
|
// and should not be skipped. In essence, skip first configuration if that is "this".
|
||||||
// TODO: skipping not needed anymore as we removed _this_ from innerConfigurations
|
// TODO: skipping not needed anymore as we removed _this_ from innerConfigurations
|
||||||
int iSkip = 0;
|
int iSkip = 0;
|
||||||
if(conf.getInnerConfigurations().size()>0 && conf.getInnerConfigurations().get(0).equals(this)) { iSkip=1;}
|
if (conf.getInnerConfigurations().size() > 0
|
||||||
conf.getInnerConfigurations().stream().skip(iSkip)
|
&& conf.getInnerConfigurations().get(0).equals(this)) {
|
||||||
.forEach(obj -> {
|
iSkip = 1;
|
||||||
|
}
|
||||||
|
conf.getInnerConfigurations().stream()
|
||||||
|
.skip(iSkip)
|
||||||
|
.forEach(
|
||||||
|
obj -> {
|
||||||
// if Layer Config, include in list and inherit parameters from this conf
|
// if Layer Config, include in list and inherit parameters from this conf
|
||||||
//else if neural net configuration, call self recursively to resolve layer configurations
|
// else if neural net configuration, call self recursively to resolve layer
|
||||||
|
// configurations
|
||||||
if (obj instanceof LayerConfiguration) {
|
if (obj instanceof LayerConfiguration) {
|
||||||
((LayerConfiguration) obj).setNetConfiguration(conf);
|
((LayerConfiguration) obj).setNetConfiguration(conf);
|
||||||
ret.add((LayerConfiguration) obj);
|
ret.add((LayerConfiguration) obj);
|
||||||
} else if (obj instanceof NeuralNetConfiguration)
|
} else if (obj instanceof NeuralNetConfiguration)
|
||||||
ret.addAll(getFlattenedLayerConfigurations(
|
ret.addAll(getFlattenedLayerConfigurations((NeuralNetConfiguration) obj));
|
||||||
(NeuralNetConfiguration) obj));
|
|
||||||
else {
|
else {
|
||||||
log.error(
|
log.error(
|
||||||
"The list of layers and neural network configurations does contain an object of {}. Element will be ignored.",
|
"The list of layers and neural network configurations does contain an object of {}. Element will be ignored.",
|
||||||
|
@ -794,8 +539,9 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sames as {@link #getFlattenedLayerConfigurations(NeuralNetConfiguration)}, but uses this configurations
|
* Sames as {@link #getFlattenedLayerConfigurations(NeuralNetConfiguration)}, but uses this
|
||||||
* list of configurations
|
* configurations list of configurations
|
||||||
|
*
|
||||||
* @return list of layer configurations
|
* @return list of layer configurations
|
||||||
*/
|
*/
|
||||||
@JsonIgnore
|
@JsonIgnore
|
||||||
|
@ -805,6 +551,7 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Add a new layer to the first position
|
* Add a new layer to the first position
|
||||||
|
*
|
||||||
* @param layer configuration
|
* @param layer configuration
|
||||||
*/
|
*/
|
||||||
public void setLayer(@NonNull LayerConfiguration layer) {
|
public void setLayer(@NonNull LayerConfiguration layer) {
|
||||||
|
@ -817,26 +564,28 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Deprecated, do not use. Workaround for old tests
|
* Deprecated, do not use. Workaround for old tests and getFlattenedLayerConfigurations().get(0);
|
||||||
* and getFlattenedLayerConfigurations().get(0);
|
*
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
@Deprecated @JsonIgnore
|
@Deprecated
|
||||||
|
@JsonIgnore
|
||||||
public LayerConfiguration getFirstLayer() {
|
public LayerConfiguration getFirstLayer() {
|
||||||
log.warn("This getFirstLayer method is an ugly workaround and will be removed.");
|
log.warn("This getFirstLayer method is an ugly workaround and will be removed.");
|
||||||
return getFlattenedLayerConfigurations().get(0);
|
return getFlattenedLayerConfigurations().get(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
|
||||||
protected boolean canEqual(final Object other) {
|
protected boolean canEqual(final Object other) {
|
||||||
return other instanceof NeuralNetConfiguration;
|
return other instanceof NeuralNetConfiguration;
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
public abstract static class NeuralNetConfigurationBuilder<
|
||||||
public static abstract class NeuralNetConfigurationBuilder<C extends NeuralNetConfiguration,
|
C extends NeuralNetConfiguration,
|
||||||
B extends NeuralNetConfiguration.NeuralNetConfigurationBuilder<C, B>> extends
|
B extends NeuralNetConfiguration.NeuralNetConfigurationBuilder<C, B>>
|
||||||
NeuralNetBaseBuilderConfigurationBuilder<C, B> {
|
extends NeuralNetBaseBuilderConfigurationBuilder<C, B> {
|
||||||
|
|
||||||
public ComputationGraphConfiguration.GraphBuilder graphBuilder() {
|
public ComputationGraphConfiguration.GraphBuilder graphBuilder() {
|
||||||
return new ComputationGraphConfiguration.GraphBuilder(this);
|
return new ComputationGraphConfiguration.GraphBuilder(this);
|
||||||
|
@ -849,6 +598,5 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
|
||||||
throw new RuntimeException(ex);
|
throw new RuntimeException(ex);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,8 +25,8 @@ import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
|
|
||||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "type",
|
|
||||||
defaultImpl = LegacyDistributionHelper.class)
|
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, property = "@class")
|
||||||
public abstract class Distribution implements Serializable, Cloneable {
|
public abstract class Distribution implements Serializable, Cloneable {
|
||||||
|
|
||||||
private static final long serialVersionUID = 5401741214954998498L;
|
private static final long serialVersionUID = 5401741214954998498L;
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.distribution;
|
package org.deeplearning4j.nn.conf.distribution;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import com.fasterxml.jackson.annotation.JsonCreator;
|
import com.fasterxml.jackson.annotation.JsonCreator;
|
||||||
|
@ -48,21 +49,7 @@ public class NormalDistribution extends Distribution {
|
||||||
this.std = std;
|
this.std = std;
|
||||||
}
|
}
|
||||||
|
|
||||||
public double getMean() {
|
|
||||||
return mean;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setMean(double mean) {
|
|
||||||
this.mean = mean;
|
|
||||||
}
|
|
||||||
|
|
||||||
public double getStd() {
|
|
||||||
return std;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setStd(double std) {
|
|
||||||
this.std = std;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
|
|
|
@ -24,6 +24,7 @@ import lombok.EqualsAndHashCode;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import org.deeplearning4j.nn.conf.CacheMode;
|
import org.deeplearning4j.nn.conf.CacheMode;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.serde.CavisMapper;
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
|
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
|
||||||
|
@ -151,7 +152,7 @@ public abstract class MemoryReport {
|
||||||
|
|
||||||
public String toJson() {
|
public String toJson() {
|
||||||
try {
|
try {
|
||||||
return NeuralNetConfiguration.mapper().writeValueAsString(this);
|
return CavisMapper.getMapper(CavisMapper.Type.JSON).writeValueAsString(this);
|
||||||
} catch (JsonProcessingException e) {
|
} catch (JsonProcessingException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
|
@ -159,7 +160,7 @@ public abstract class MemoryReport {
|
||||||
|
|
||||||
public String toYaml() {
|
public String toYaml() {
|
||||||
try {
|
try {
|
||||||
return NeuralNetConfiguration.mapperYaml().writeValueAsString(this);
|
return CavisMapper.getMapper(CavisMapper.Type.YAML).writeValueAsString(this);
|
||||||
} catch (JsonProcessingException e) {
|
} catch (JsonProcessingException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
|
@ -167,7 +168,7 @@ public abstract class MemoryReport {
|
||||||
|
|
||||||
public static MemoryReport fromJson(String json) {
|
public static MemoryReport fromJson(String json) {
|
||||||
try {
|
try {
|
||||||
return NeuralNetConfiguration.mapper().readValue(json, MemoryReport.class);
|
return CavisMapper.getMapper(CavisMapper.Type.JSON).readValue(json, MemoryReport.class);
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
|
@ -175,7 +176,7 @@ public abstract class MemoryReport {
|
||||||
|
|
||||||
public static MemoryReport fromYaml(String yaml) {
|
public static MemoryReport fromYaml(String yaml) {
|
||||||
try {
|
try {
|
||||||
return NeuralNetConfiguration.mapperYaml().readValue(yaml, MemoryReport.class);
|
return CavisMapper.getMapper(CavisMapper.Type.JSON).readValue(yaml, MemoryReport.class);
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
|
|
|
@ -251,7 +251,7 @@ public abstract class BaseNetConfigDeserializer<T> extends StdDeserializer<T> im
|
||||||
Distribution d = null;
|
Distribution d = null;
|
||||||
if(w == WeightInit.DISTRIBUTION && on.has("dist")){
|
if(w == WeightInit.DISTRIBUTION && on.has("dist")){
|
||||||
String dist = on.get("dist").toString();
|
String dist = on.get("dist").toString();
|
||||||
d = NeuralNetConfiguration.mapper().readValue(dist, Distribution.class);
|
d = CavisMapper.getMapper(CavisMapper.Type.JSON).readValue(dist, Distribution.class);
|
||||||
}
|
}
|
||||||
IWeightInit iwi = w.getWeightInitFunction(d);
|
IWeightInit iwi = w.getWeightInitFunction(d);
|
||||||
baseLayerConfiguration.setWeightInit(iwi);
|
baseLayerConfiguration.setWeightInit(iwi);
|
||||||
|
|
|
@ -0,0 +1,60 @@
|
||||||
|
/*
|
||||||
|
*
|
||||||
|
* ******************************************************************************
|
||||||
|
* *
|
||||||
|
* * This program and the accompanying materials are made available under the
|
||||||
|
* * terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
* *
|
||||||
|
* * See the NOTICE file distributed with this work for additional
|
||||||
|
* * information regarding copyright ownership.
|
||||||
|
* * Unless required by applicable law or agreed to in writing, software
|
||||||
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* * License for the specific language governing permissions and limitations
|
||||||
|
* * under the License.
|
||||||
|
* *
|
||||||
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
|
* *****************************************************************************
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.deeplearning4j.nn.conf.serde;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.databind.DeserializationFeature;
|
||||||
|
import com.fasterxml.jackson.databind.MapperFeature;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import com.fasterxml.jackson.databind.SerializationFeature;
|
||||||
|
import com.fasterxml.jackson.databind.json.JsonMapper;
|
||||||
|
import com.fasterxml.jackson.dataformat.yaml.YAMLMapper;
|
||||||
|
import lombok.NonNull;
|
||||||
|
|
||||||
|
public class CavisMapper {
|
||||||
|
|
||||||
|
public static ObjectMapper getMapper(@NonNull Type type) {
|
||||||
|
ObjectMapper mapper;
|
||||||
|
switch (type) {
|
||||||
|
case JSON:
|
||||||
|
mapper = JsonMapper.builder()
|
||||||
|
.enable(SerializationFeature.INDENT_OUTPUT)
|
||||||
|
.enable(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY)
|
||||||
|
.enable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES)
|
||||||
|
.enable(DeserializationFeature.FAIL_ON_INVALID_SUBTYPE)
|
||||||
|
.enable(DeserializationFeature.FAIL_ON_IGNORED_PROPERTIES)
|
||||||
|
.build();
|
||||||
|
break;
|
||||||
|
case YAML:
|
||||||
|
mapper = YAMLMapper.builder().build();
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw new RuntimeException("Mapper type not recognised.");
|
||||||
|
}
|
||||||
|
|
||||||
|
return mapper;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static enum Type {
|
||||||
|
JSON,
|
||||||
|
YAML
|
||||||
|
}
|
||||||
|
}
|
|
@ -96,7 +96,7 @@ public class ComputationGraphConfigurationDeserializer
|
||||||
}
|
}
|
||||||
jsonSubString = s.substring((int) charOffsetStart - 1, charOffsetEnd.intValue());
|
jsonSubString = s.substring((int) charOffsetStart - 1, charOffsetEnd.intValue());
|
||||||
|
|
||||||
ObjectMapper om = NeuralNetConfiguration.mapper();
|
ObjectMapper om = CavisMapper.getMapper(CavisMapper.Type.JSON);
|
||||||
JsonNode rootNode = om.readTree(jsonSubString);
|
JsonNode rootNode = om.readTree(jsonSubString);
|
||||||
|
|
||||||
ObjectNode verticesNode = (ObjectNode) rootNode.get("vertices");
|
ObjectNode verticesNode = (ObjectNode) rootNode.get("vertices");
|
||||||
|
|
|
@ -78,7 +78,7 @@ public class NeuralNetConfigurationDeserializer extends BaseNetConfigDeserialize
|
||||||
}
|
}
|
||||||
String jsonSubString = s.substring((int) charOffsetStart - 1, (int) charOffsetEnd);
|
String jsonSubString = s.substring((int) charOffsetStart - 1, (int) charOffsetEnd);
|
||||||
|
|
||||||
ObjectMapper om = NeuralNetConfiguration.mapper();
|
ObjectMapper om = CavisMapper.getMapper(CavisMapper.Type.JSON);
|
||||||
JsonNode rootNode = om.readTree(jsonSubString);
|
JsonNode rootNode = om.readTree(jsonSubString);
|
||||||
|
|
||||||
ArrayNode confsNode = (ArrayNode)rootNode.get("confs");
|
ArrayNode confsNode = (ArrayNode)rootNode.get("confs");
|
||||||
|
|
|
@ -47,6 +47,7 @@ import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
|
||||||
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
|
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.layers.LayerValidation;
|
import org.deeplearning4j.nn.conf.layers.LayerValidation;
|
||||||
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
|
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.serde.CavisMapper;
|
||||||
import org.deeplearning4j.nn.conf.stepfunctions.StepFunction;
|
import org.deeplearning4j.nn.conf.stepfunctions.StepFunction;
|
||||||
import org.deeplearning4j.nn.conf.weightnoise.IWeightNoise;
|
import org.deeplearning4j.nn.conf.weightnoise.IWeightNoise;
|
||||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||||
|
@ -118,7 +119,7 @@ public class FineTuneConfiguration {
|
||||||
|
|
||||||
public static FineTuneConfiguration fromJson(String json) {
|
public static FineTuneConfiguration fromJson(String json) {
|
||||||
try {
|
try {
|
||||||
return NeuralNetConfiguration.mapper().readValue(json, FineTuneConfiguration.class);
|
return CavisMapper.getMapper(CavisMapper.Type.JSON).readValue(json, FineTuneConfiguration.class);
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
|
@ -126,7 +127,7 @@ public class FineTuneConfiguration {
|
||||||
|
|
||||||
public static FineTuneConfiguration fromYaml(String yaml) {
|
public static FineTuneConfiguration fromYaml(String yaml) {
|
||||||
try {
|
try {
|
||||||
return NeuralNetConfiguration.mapperYaml().readValue(yaml, FineTuneConfiguration.class);
|
return CavisMapper.getMapper(CavisMapper.Type.YAML).readValue(yaml, FineTuneConfiguration.class);
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
|
@ -322,7 +323,7 @@ public class FineTuneConfiguration {
|
||||||
|
|
||||||
public String toJson() {
|
public String toJson() {
|
||||||
try {
|
try {
|
||||||
return NeuralNetConfiguration.mapper().writeValueAsString(this);
|
return CavisMapper.getMapper(CavisMapper.Type.JSON).writeValueAsString(this);
|
||||||
} catch (JsonProcessingException e) {
|
} catch (JsonProcessingException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
|
@ -330,7 +331,7 @@ public class FineTuneConfiguration {
|
||||||
|
|
||||||
public String toYaml() {
|
public String toYaml() {
|
||||||
try {
|
try {
|
||||||
return NeuralNetConfiguration.mapperYaml().writeValueAsString(this);
|
return CavisMapper.getMapper(CavisMapper.Type.YAML).writeValueAsString(this);
|
||||||
} catch (JsonProcessingException e) {
|
} catch (JsonProcessingException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,7 +26,7 @@ import com.fasterxml.jackson.annotation.JsonTypeInfo;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
|
|
||||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, property = "@class")
|
||||||
@JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY, getterVisibility = JsonAutoDetect.Visibility.NONE,
|
@JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY, getterVisibility = JsonAutoDetect.Visibility.NONE,
|
||||||
setterVisibility = JsonAutoDetect.Visibility.NONE)
|
setterVisibility = JsonAutoDetect.Visibility.NONE)
|
||||||
public interface IWeightInit extends Serializable {
|
public interface IWeightInit extends Serializable {
|
||||||
|
|
|
@ -20,19 +20,24 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.weights;
|
package org.deeplearning4j.nn.weights;
|
||||||
|
|
||||||
import lombok.EqualsAndHashCode;
|
import com.fasterxml.jackson.annotation.JsonCreator;
|
||||||
|
import lombok.*;
|
||||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
||||||
import org.deeplearning4j.nn.conf.distribution.Distributions;
|
import org.deeplearning4j.nn.conf.distribution.Distributions;
|
||||||
|
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.rng.distribution.impl.OrthogonalDistribution;
|
import org.nd4j.linalg.api.rng.distribution.impl.OrthogonalDistribution;
|
||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode
|
||||||
|
@NoArgsConstructor
|
||||||
public class WeightInitDistribution implements IWeightInit {
|
public class WeightInitDistribution implements IWeightInit {
|
||||||
|
|
||||||
private final Distribution distribution;
|
@Getter @Setter
|
||||||
|
private Distribution distribution;
|
||||||
|
|
||||||
public WeightInitDistribution(@JsonProperty("distribution") Distribution distribution) {
|
|
||||||
|
public WeightInitDistribution(@NonNull Distribution distribution) {
|
||||||
if(distribution == null) {
|
if(distribution == null) {
|
||||||
// Would fail later below otherwise
|
// Would fail later below otherwise
|
||||||
throw new IllegalArgumentException("Must set distribution!");
|
throw new IllegalArgumentException("Must set distribution!");
|
||||||
|
@ -40,6 +45,7 @@ public class WeightInitDistribution implements IWeightInit {
|
||||||
this.distribution = distribution;
|
this.distribution = distribution;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public INDArray init(double fanIn, double fanOut, long[] shape, char order, INDArray paramView) {
|
public INDArray init(double fanIn, double fanOut, long[] shape, char order, INDArray paramView) {
|
||||||
//org.nd4j.linalg.api.rng.distribution.Distribution not serializable
|
//org.nd4j.linalg.api.rng.distribution.Distribution not serializable
|
||||||
|
|
|
@ -0,0 +1,176 @@
|
||||||
|
/*
|
||||||
|
*
|
||||||
|
* ******************************************************************************
|
||||||
|
* *
|
||||||
|
* * This program and the accompanying materials are made available under the
|
||||||
|
* * terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
* *
|
||||||
|
* * See the NOTICE file distributed with this work for additional
|
||||||
|
* * information regarding copyright ownership.
|
||||||
|
* * Unless required by applicable law or agreed to in writing, software
|
||||||
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* * License for the specific language governing permissions and limitations
|
||||||
|
* * under the License.
|
||||||
|
* *
|
||||||
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
|
* *****************************************************************************
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
package net.brutex.ai.dnn.serde;
|
||||||
|
|
||||||
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
||||||
|
import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.*;
|
||||||
|
class NeuralNetConfigurationSerdeTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void toYaml() {
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void toJson() {
|
||||||
|
final var conf = NeuralNetConfiguration.builder()
|
||||||
|
.weightInit(new NormalDistribution(3, 2))
|
||||||
|
.layer(DenseLayer.builder().nIn(100).nOut(30).build())
|
||||||
|
.layer(OutputLayer.builder().lossFunction(LossFunctions.LossFunction.SQUARED_LOSS).build())
|
||||||
|
.build();
|
||||||
|
assertEquals(conf.toJson(), NeuralNetConfiguration.fromJson(conf.toJson()).toJson());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void toJson2() {
|
||||||
|
final var conf = NeuralNetConfiguration.builder()
|
||||||
|
.weightInit(WeightInit.IDENTITY)
|
||||||
|
.layer(DenseLayer.builder().nIn(100).nOut(30).build())
|
||||||
|
.layer(OutputLayer.builder().lossFunction(LossFunctions.LossFunction.SQUARED_LOSS).build())
|
||||||
|
.build();
|
||||||
|
assertEquals(conf.toJson(), NeuralNetConfiguration.fromJson(conf.toJson()).toJson());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void fromJson() {
|
||||||
|
final String json = "{\n" +
|
||||||
|
" \"activation\" : null,\n" +
|
||||||
|
" \"seed\" : 12345, \n" +
|
||||||
|
" \"allParamConstraints\" : [ ],\n" +
|
||||||
|
" \"backpropType\" : \"Standard\",\n" +
|
||||||
|
" \"biasConstraints\" : [ ],\n" +
|
||||||
|
" \"biasInit\" : 0.0,\n" +
|
||||||
|
" \"biasUpdater\" : null,\n" +
|
||||||
|
" \"cacheMode\" : \"NONE\",\n" +
|
||||||
|
" \"constrainWeights\" : [ ],\n" +
|
||||||
|
" \"convolutionMode\" : \"Truncate\",\n" +
|
||||||
|
" \"cudnnAlgoMode\" : \"PREFER_FASTEST\",\n" +
|
||||||
|
" \"dampingFactor\" : 100.0,\n" +
|
||||||
|
" \"dataType\" : \"FLOAT\",\n" +
|
||||||
|
" \"epochCount\" : 0,\n" +
|
||||||
|
" \"gainInit\" : 1.0,\n" +
|
||||||
|
" \"gradientNormalization\" : \"None\",\n" +
|
||||||
|
" \"gradientNormalizationThreshold\" : 0.0,\n" +
|
||||||
|
" \"idropOut\" : null,\n" +
|
||||||
|
" \"inferenceWorkspaceMode\" : \"ENABLED\",\n" +
|
||||||
|
" \"initCalled\" : false,\n" +
|
||||||
|
" \"innerConfigurations\" : [ {\n" +
|
||||||
|
" \"org.deeplearning4j.nn.conf.layers.DenseLayer\" : {\n" +
|
||||||
|
" \"activation\" : [ \"org.nd4j.linalg.activations.Activation\", \"IDENTITY\" ],\n" +
|
||||||
|
" \"allParamConstraints\" : null,\n" +
|
||||||
|
" \"biasConstraints\" : null,\n" +
|
||||||
|
" \"biasInit\" : 0.0,\n" +
|
||||||
|
" \"biasUpdater\" : null,\n" +
|
||||||
|
" \"constrainWeights\" : [ ],\n" +
|
||||||
|
" \"constraints\" : null,\n" +
|
||||||
|
" \"dataType\" : null,\n" +
|
||||||
|
" \"dropOut\" : null,\n" +
|
||||||
|
" \"gainInit\" : 0.0,\n" +
|
||||||
|
" \"gradientNormalization\" : \"None\",\n" +
|
||||||
|
" \"gradientNormalizationThreshold\" : 1.0,\n" +
|
||||||
|
" \"hasBias\" : true,\n" +
|
||||||
|
" \"hasLayerNorm\" : false,\n" +
|
||||||
|
" \"name\" : null,\n" +
|
||||||
|
" \"nout\" : 30,\n" +
|
||||||
|
" \"regularization\" : [ ],\n" +
|
||||||
|
" \"regularizationBias\" : [ ],\n" +
|
||||||
|
" \"type\" : \"UNKNOWN\",\n" +
|
||||||
|
" \"variables\" : [ ],\n" +
|
||||||
|
" \"weightConstraints\" : null,\n" +
|
||||||
|
" \"weightInit\" : null,\n" +
|
||||||
|
" \"weightNoise\" : null\n" +
|
||||||
|
" }\n" +
|
||||||
|
" }, {\n" +
|
||||||
|
" \"org.deeplearning4j.nn.conf.layers.OutputLayer\" : {\n" +
|
||||||
|
" \"activation\" : [ \"org.nd4j.linalg.activations.Activation\", \"IDENTITY\" ],\n" +
|
||||||
|
" \"allParamConstraints\" : null,\n" +
|
||||||
|
" \"biasConstraints\" : null,\n" +
|
||||||
|
" \"biasInit\" : 0.0,\n" +
|
||||||
|
" \"biasUpdater\" : null,\n" +
|
||||||
|
" \"constrainWeights\" : [ ],\n" +
|
||||||
|
" \"constraints\" : null,\n" +
|
||||||
|
" \"dataType\" : null,\n" +
|
||||||
|
" \"dropOut\" : null,\n" +
|
||||||
|
" \"gainInit\" : 0.0,\n" +
|
||||||
|
" \"gradientNormalization\" : \"None\",\n" +
|
||||||
|
" \"gradientNormalizationThreshold\" : 1.0,\n" +
|
||||||
|
" \"hasBias\" : true,\n" +
|
||||||
|
" \"lossFunction\" : {\n" +
|
||||||
|
" \"@class\" : \"org.nd4j.linalg.lossfunctions.impl.LossMSE\"\n" +
|
||||||
|
" },\n" +
|
||||||
|
" \"name\" : null,\n" +
|
||||||
|
" \"nout\" : 0,\n" +
|
||||||
|
" \"regularization\" : [ ],\n" +
|
||||||
|
" \"regularizationBias\" : [ ],\n" +
|
||||||
|
" \"type\" : \"UNKNOWN\",\n" +
|
||||||
|
" \"variables\" : [ ],\n" +
|
||||||
|
" \"weightConstraints\" : null,\n" +
|
||||||
|
" \"weightInit\" : null,\n" +
|
||||||
|
" \"weightNoise\" : null\n" +
|
||||||
|
" }\n" +
|
||||||
|
" } ],\n" +
|
||||||
|
" \"inputPreProcessors\" : { },\n" +
|
||||||
|
" \"inputType\" : null,\n" +
|
||||||
|
" \"iterationCount\" : 0,\n" +
|
||||||
|
" \"maxNumLineSearchIterations\" : 5,\n" +
|
||||||
|
" \"miniBatch\" : false,\n" +
|
||||||
|
" \"minimize\" : true,\n" +
|
||||||
|
" \"name\" : \"Anonymous INeuralNetworkConfiguration\",\n" +
|
||||||
|
" \"netWideVariables\" : [ ],\n" +
|
||||||
|
" \"optimizationAlgo\" : \"STOCHASTIC_GRADIENT_DESCENT\",\n" +
|
||||||
|
" \"overrideNinUponBuild\" : true,\n" +
|
||||||
|
" \"regularization\" : [ ],\n" +
|
||||||
|
" \"regularizationBias\" : [ ],\n" +
|
||||||
|
" \"stepFunction\" : null,\n" +
|
||||||
|
" \"tbpttBackLength\" : 20,\n" +
|
||||||
|
" \"tbpttFwdLength\" : 20,\n" +
|
||||||
|
" \"trainingWorkspaceMode\" : \"ENABLED\",\n" +
|
||||||
|
" \"updater\" : {\n" +
|
||||||
|
" \"@class\" : \"org.nd4j.linalg.learning.config.Sgd\",\n" +
|
||||||
|
" \"learningRate\" : 0.001\n" +
|
||||||
|
" },\n" +
|
||||||
|
" \"validateOutputLayerConfig\" : true,\n" +
|
||||||
|
" \"validateTbpttConfig\" : true,\n" +
|
||||||
|
" \"weightInit\" : {\n" +
|
||||||
|
" \"@class\" : \"org.deeplearning4j.nn.weights.WeightInitIdentity\",\n" +
|
||||||
|
" \"scale\" : null\n" +
|
||||||
|
" },\n" +
|
||||||
|
" \"weightNoise\" : null\n" +
|
||||||
|
"}";
|
||||||
|
final var conf = NeuralNetConfiguration.builder()
|
||||||
|
.weightInit(WeightInit.IDENTITY)
|
||||||
|
.layer(DenseLayer.builder().nIn(100).nOut(30).build())
|
||||||
|
.layer(OutputLayer.builder().lossFunction(LossFunctions.LossFunction.SQUARED_LOSS).build())
|
||||||
|
.seed(12345)
|
||||||
|
.build();
|
||||||
|
NeuralNetConfiguration conf2 = NeuralNetConfiguration.fromJson(json);
|
||||||
|
assertEquals(conf.toJson(),conf2.toJson());
|
||||||
|
assertEquals(conf, conf2);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -42,6 +42,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.graph.GraphVertex;
|
import org.deeplearning4j.nn.conf.graph.GraphVertex;
|
||||||
import org.deeplearning4j.nn.conf.graph.LayerVertex;
|
import org.deeplearning4j.nn.conf.graph.LayerVertex;
|
||||||
import org.deeplearning4j.nn.conf.layers.*;
|
import org.deeplearning4j.nn.conf.layers.*;
|
||||||
|
import org.deeplearning4j.nn.conf.serde.CavisMapper;
|
||||||
import org.deeplearning4j.nn.conf.serde.JsonMappers;
|
import org.deeplearning4j.nn.conf.serde.JsonMappers;
|
||||||
import org.deeplearning4j.ui.VertxUIServer;
|
import org.deeplearning4j.ui.VertxUIServer;
|
||||||
import org.deeplearning4j.ui.api.HttpMethod;
|
import org.deeplearning4j.ui.api.HttpMethod;
|
||||||
|
@ -909,7 +910,7 @@ public class TrainModule implements UIModule {
|
||||||
} else {
|
} else {
|
||||||
try {
|
try {
|
||||||
NeuralNetConfiguration layer =
|
NeuralNetConfiguration layer =
|
||||||
NeuralNetConfiguration.mapper().readValue(config, NeuralNetConfiguration.class);
|
CavisMapper.getMapper(CavisMapper.Type.JSON).readValue(config, NeuralNetConfiguration.class);
|
||||||
return new Triple<>(null, null, layer);
|
return new Triple<>(null, null, layer);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("",e);
|
log.error("",e);
|
||||||
|
|
Loading…
Reference in New Issue