commit c758cf918f
parent 2c8c6d9624
NeuralNetBaseBuilderConfiguration.java

@@ -27,6 +27,7 @@ import com.fasterxml.jackson.databind.ObjectMapper;
 import com.fasterxml.jackson.databind.node.ArrayNode;
 import lombok.*;
 import lombok.experimental.SuperBuilder;
+import lombok.extern.jackson.Jacksonized;
 import lombok.extern.slf4j.Slf4j;
 import net.brutex.ai.dnn.api.INeuralNetworkConfiguration;
 import org.deeplearning4j.nn.api.OptimizationAlgorithm;
@@ -89,6 +90,7 @@ import java.util.*;
 @Slf4j
 // The inner builder, that we can then extend ...
 @SuperBuilder // TODO fix access
+@Jacksonized
 @EqualsAndHashCode
 public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetworkConfiguration {
 
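Note: a minimal sketch of the @Jacksonized + @SuperBuilder pattern this commit
introduces, using a hypothetical ExampleConfig class (not part of the commit).
@Jacksonized makes Jackson deserialize through the Lombok-generated builder, so
immutable builder-style configurations round-trip without setters or a default
constructor.

import com.fasterxml.jackson.databind.json.JsonMapper;
import lombok.Builder;
import lombok.Getter;
import lombok.experimental.SuperBuilder;
import lombok.extern.jackson.Jacksonized;

@Jacksonized
@SuperBuilder
@Getter
class ExampleConfig {
  private final String name;
  @Builder.Default private final int seed = 42; // default kept when absent in JSON
}

class ExampleConfigDemo {
  public static void main(String[] args) throws Exception {
    JsonMapper mapper = JsonMapper.builder().build();
    ExampleConfig cfg = ExampleConfig.builder().name("net").build();
    String json = mapper.writeValueAsString(cfg);                     // {"name":"net","seed":42}
    ExampleConfig back = mapper.readValue(json, ExampleConfig.class); // goes through the builder
    System.out.println(back.getName() + " seed=" + back.getSeed());
  }
}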
@@ -895,7 +897,6 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetworkConfiguration {
      *
      * @param distribution Distribution to use for weight initialization
      */
-    @JsonIgnore
     public B weightInit(Distribution distribution) {
       this.weightInit$value = new WeightInitDistribution(distribution);
       this.weightInit$set = true;
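Note: removing @JsonIgnore here matters once the class is @Jacksonized, because
Jackson then feeds JSON properties through these builder methods; an ignored
setter would silently drop weightInit on deserialization. A minimal sketch with
a hypothetical WeightedCfg class (the weightInit$value/weightInit$set fields are
the ones Lombok generates for @Builder.Default):

import lombok.Builder;
import lombok.Getter;
import lombok.extern.jackson.Jacksonized;

@Jacksonized
@Builder
@Getter
class WeightedCfg {
  @Builder.Default private final String weightInit = "XAVIER";

  static class WeightedCfgBuilder {
    // Hand-written builder customization, analogous to weightInit(Distribution)
    // in this hunk; it must stay visible to Jackson for round-trips to work.
    public WeightedCfgBuilder weightInit(String name) {
      this.weightInit$value = name;
      this.weightInit$set = true;
      return this;
    }
  }
}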
NeuralNetConfiguration.java

@@ -24,8 +24,12 @@ import com.fasterxml.jackson.annotation.JsonIgnore;
 import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
 import com.fasterxml.jackson.core.JsonProcessingException;
 import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.MapperFeature;
 import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.SerializationFeature;
+import com.fasterxml.jackson.databind.annotation.JsonSerialize;
 import com.fasterxml.jackson.databind.exc.InvalidTypeIdException;
+import com.fasterxml.jackson.databind.json.JsonMapper;
 import com.fasterxml.jackson.databind.node.ArrayNode;
 import lombok.*;
 import lombok.experimental.SuperBuilder;
@@ -165,7 +169,8 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
    * @return {@link NeuralNetConfiguration}
    */
   public static NeuralNetConfiguration fromJson(String json) {
-    ObjectMapper mapper = NeuralNetConfiguration.mapper();
+    //ObjectMapper mapper = NeuralNetConfiguration.mapper();
+    JsonMapper mapper = JsonMapper.builder().build();
     try {
       return mapper.readValue(json, NeuralNetConfiguration.class);
     } catch (JsonProcessingException e) {
@@ -439,7 +444,11 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
    * @return JSON representation of NN configuration
    */
   public String toJson() {
-    ObjectMapper mapper = NeuralNetConfiguration.mapper();
+    JsonMapper mapper = JsonMapper.builder()
+        .enable(SerializationFeature.INDENT_OUTPUT)
+        .enable(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY)
+        .build();
+    //ObjectMapper mapper = NeuralNetConfiguration.mapper();
     synchronized (mapper) {
       //JSON mappers are supposed to be thread safe: however, in practice they seem to miss fields occasionally
       //when writeValueAsString is used by multiple threads. This results in invalid JSON. See issue #3243
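Note: both fromJson() and toJson() now build a method-local JsonMapper instead
of reusing the shared NeuralNetConfiguration.mapper(). A minimal sketch of the
builder-style configuration (standard Jackson 2.x API; the Cfg POJO is
hypothetical), showing the effect of the two enabled features:

import com.fasterxml.jackson.databind.MapperFeature;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.fasterxml.jackson.databind.json.JsonMapper;

class MapperDemo {
  static class Cfg {
    public int seed = 42;
    public boolean backprop = true;
    public String activation = "relu";
  }

  public static void main(String[] args) throws Exception {
    JsonMapper mapper = JsonMapper.builder()
        .enable(SerializationFeature.INDENT_OUTPUT)           // pretty-printed output
        .enable(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY) // stable property order
        .build();                                             // immutable once built
    // Properties come out indented and in alphabetical order:
    // activation, backprop, seed
    System.out.println(mapper.writeValueAsString(new Cfg()));
  }
}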
InputType.java

@@ -90,6 +90,7 @@ public abstract class InputType implements Serializable {
    *
    * @return int[]
    */
+  @JsonIgnore
   public long[] getShape() {
     return getShape(false);
   }
@@ -431,7 +432,7 @@ public abstract class InputType implements Serializable {
       return height * width * depth * channels;
     }
 
-    @Override
+    @Override @JsonIgnore
     public long[] getShape(boolean includeBatchDim) {
       if(dataFormat == Convolution3D.DataFormat.NDHWC){
         if(includeBatchDim) return new long[]{-1, depth, height, width, channels};
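Note: a minimal sketch (hypothetical Box class, not from the commit) of why
these derived getters are annotated with @JsonIgnore: without it, Jackson
serializes getShape() as a synthetic "shape" property, and reading that JSON
back fails because no matching field or setter exists.

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.databind.json.JsonMapper;

class Box {
  public long height = 2;
  public long width = 3;

  @JsonIgnore // derived value: exclude from (de)serialization
  public long[] getShape() {
    return new long[]{height, width};
  }
}

class BoxDemo {
  public static void main(String[] args) throws Exception {
    JsonMapper mapper = JsonMapper.builder().build();
    String json = mapper.writeValueAsString(new Box());
    System.out.println(json); // {"height":2,"width":3}, no "shape" entry
    Box back = mapper.readValue(json, Box.class); // round-trips cleanly
    System.out.println(back.getShape().length);
  }
}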
CapsuleLayer.java

@@ -40,6 +40,7 @@ import org.nd4j.linalg.factory.Nd4j;
 
 @EqualsAndHashCode(callSuper = true)
 @SuperBuilder(builderMethodName = "innerBuilder")
+@NoArgsConstructor
 public class CapsuleLayer extends SameDiffLayer {
 
   private static final String WEIGHT_PARAM = "weight";
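Note: @SuperBuilder suppresses the implicit default constructor, so
@NoArgsConstructor restores one for code that instantiates the layer
reflectively. A minimal sketch with a hypothetical ExampleLayer:

import lombok.NoArgsConstructor;
import lombok.experimental.SuperBuilder;

@SuperBuilder
@NoArgsConstructor
class ExampleLayer {
  private int capsules;
}

class ExampleLayerDemo {
  public static void main(String[] args) throws ReflectiveOperationException {
    // Both construction paths now compile and run:
    ExampleLayer viaBuilder = ExampleLayer.builder().capsules(10).build();
    ExampleLayer viaReflection = ExampleLayer.class.getDeclaredConstructor().newInstance();
    System.out.println(viaBuilder + " / " + viaReflection);
  }
}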
LayerConfiguration.java

@@ -92,6 +92,7 @@ public abstract class LayerConfiguration
    *
    * @return activation function
    */
+  @JsonIgnore
   public IActivation getActivationFn() {
     if (activation == null)
       throw new RuntimeException(
LossFunctionWrapper.java

@@ -30,12 +30,12 @@ import com.fasterxml.jackson.annotation.JsonProperty;
 @Data
 public class LossFunctionWrapper implements ReconstructionDistribution {
 
-  private final IActivation activationFn;
+  private final IActivation activation;
   private final ILossFunction lossFunction;
 
-  public LossFunctionWrapper(@JsonProperty("activationFn") IActivation activationFn,
+  public LossFunctionWrapper(@JsonProperty("activation") IActivation activation,
       @JsonProperty("lossFunction") ILossFunction lossFunction) {
-    this.activationFn = activationFn;
+    this.activation = activation;
     this.lossFunction = lossFunction;
   }
 
@@ -59,17 +59,17 @@ public class LossFunctionWrapper implements ReconstructionDistribution {
     //NOTE: The returned value here is NOT negative log probability, but it (the loss function value)
     // is equivalent, in terms of being something we want to minimize...
 
-    return lossFunction.computeScore(x, preOutDistributionParams, activationFn, null, average);
+    return lossFunction.computeScore(x, preOutDistributionParams, activation, null, average);
   }
 
   @Override
   public INDArray exampleNegLogProbability(INDArray x, INDArray preOutDistributionParams) {
-    return lossFunction.computeScoreArray(x, preOutDistributionParams, activationFn, null);
+    return lossFunction.computeScoreArray(x, preOutDistributionParams, activation, null);
   }
 
   @Override
   public INDArray gradient(INDArray x, INDArray preOutDistributionParams) {
-    return lossFunction.computeGradient(x, preOutDistributionParams, activationFn, null);
+    return lossFunction.computeGradient(x, preOutDistributionParams, activation, null);
   }
 
   @Override
@@ -82,11 +82,11 @@ public class LossFunctionWrapper implements ReconstructionDistribution {
   public INDArray generateAtMean(INDArray preOutDistributionParams) {
     //Loss functions: not probabilistic -> not random
     INDArray out = preOutDistributionParams.dup();
-    return activationFn.getActivation(out, true);
+    return activation.getActivation(out, true);
   }
 
   @Override
   public String toString() {
-    return "LossFunctionWrapper(afn=" + activationFn + "," + lossFunction + ")";
+    return "LossFunctionWrapper(afn=" + activation + "," + lossFunction + ")";
   }
 }
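Note: renaming the field and its @JsonProperty from "activationFn" to
"activation" changes the JSON wire format, so documents written before this
commit no longer bind to the new key. A minimal sketch (hypothetical WrapperDto;
@JsonAlias is a standard Jackson annotation, not something this commit adds)
of one way to stay readable for legacy JSON:

import com.fasterxml.jackson.annotation.JsonAlias;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.json.JsonMapper;

class WrapperDto {
  @JsonAlias("activationFn")  // accept the pre-rename key on input
  @JsonProperty("activation") // emit the new key on output
  public String activation;
}

class WrapperDtoDemo {
  public static void main(String[] args) throws Exception {
    JsonMapper mapper = JsonMapper.builder().build();
    WrapperDto legacy = mapper.readValue("{\"activationFn\":\"tanh\"}", WrapperDto.class);
    System.out.println(mapper.writeValueAsString(legacy)); // {"activation":"tanh"}
  }
}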
MultiLayerNetwork.java

@@ -118,7 +118,7 @@ import org.nd4j.linalg.workspace.WorkspaceUtils;
  */
 @Slf4j
 // @JsonIdentityInfo(generator = ObjectIdGenerators.IntSequenceGenerator.class, property = "@id")
-@JsonIgnoreProperties({"helper", "net", "initCalled", "iupdater", "activationFn"})
+@JsonIgnoreProperties({"helper", "net", "initCalled", "iupdater"})
 public class MultiLayerNetwork extends ArtificialNeuralNetwork
     implements Serializable, Classifier, Layer, ITrainableLayer {
 
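Note: "activationFn" can be dropped from this ignore list, likely because the
property was renamed to "activation" above and getActivationFn() in
LayerConfiguration now carries @JsonIgnore directly. A minimal sketch
(hypothetical NetworkDto) of how the class-level annotation behaves:

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.databind.json.JsonMapper;

@JsonIgnoreProperties({"helper", "initCalled"})
class NetworkDto {
  public String name = "net";
  public boolean initCalled = true; // ignored: never appears in JSON

  public String getHelper() { return "transient runtime state"; } // also ignored
}

class NetworkDtoDemo {
  public static void main(String[] args) throws Exception {
    System.out.println(
        JsonMapper.builder().build().writeValueAsString(new NetworkDto())); // {"name":"net"}
  }
}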