Fixing tests

Signed-off-by: brian <brian@brutex.de>
enhance-build-infrastructure
Brian Rosenberger 2023-05-08 19:12:46 +02:00
parent 2c8c6d9624
commit c758cf918f
7 changed files with 26 additions and 13 deletions

View File

@ -27,6 +27,7 @@ import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode; import com.fasterxml.jackson.databind.node.ArrayNode;
import lombok.*; import lombok.*;
import lombok.experimental.SuperBuilder; import lombok.experimental.SuperBuilder;
import lombok.extern.jackson.Jacksonized;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import net.brutex.ai.dnn.api.INeuralNetworkConfiguration; import net.brutex.ai.dnn.api.INeuralNetworkConfiguration;
import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.api.OptimizationAlgorithm;
@ -89,6 +90,7 @@ import java.util.*;
@Slf4j @Slf4j
// The inner builder, that we can then extend ... // The inner builder, that we can then extend ...
@SuperBuilder // TODO fix access @SuperBuilder // TODO fix access
@Jacksonized
@EqualsAndHashCode @EqualsAndHashCode
public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetworkConfiguration { public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetworkConfiguration {
@ -895,7 +897,6 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
* *
* @param distribution Distribution to use for weight initialization * @param distribution Distribution to use for weight initialization
*/ */
@JsonIgnore
public B weightInit(Distribution distribution) { public B weightInit(Distribution distribution) {
this.weightInit$value = new WeightInitDistribution(distribution); this.weightInit$value = new WeightInitDistribution(distribution);
this.weightInit$set = true; this.weightInit$set = true;

View File

@ -24,8 +24,12 @@ import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.MapperFeature;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import com.fasterxml.jackson.databind.exc.InvalidTypeIdException; import com.fasterxml.jackson.databind.exc.InvalidTypeIdException;
import com.fasterxml.jackson.databind.json.JsonMapper;
import com.fasterxml.jackson.databind.node.ArrayNode; import com.fasterxml.jackson.databind.node.ArrayNode;
import lombok.*; import lombok.*;
import lombok.experimental.SuperBuilder; import lombok.experimental.SuperBuilder;
@ -165,7 +169,8 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
* @return {@link NeuralNetConfiguration} * @return {@link NeuralNetConfiguration}
*/ */
public static NeuralNetConfiguration fromJson(String json) { public static NeuralNetConfiguration fromJson(String json) {
ObjectMapper mapper = NeuralNetConfiguration.mapper(); //ObjectMapper mapper = NeuralNetConfiguration.mapper();
JsonMapper mapper = JsonMapper.builder().build();
try { try {
return mapper.readValue(json, NeuralNetConfiguration.class); return mapper.readValue(json, NeuralNetConfiguration.class);
} catch (JsonProcessingException e) { } catch (JsonProcessingException e) {
@ -439,7 +444,11 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
* @return JSON representation of NN configuration * @return JSON representation of NN configuration
*/ */
public String toJson() { public String toJson() {
ObjectMapper mapper = NeuralNetConfiguration.mapper(); JsonMapper mapper = JsonMapper.builder()
.enable(SerializationFeature.INDENT_OUTPUT)
.enable(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY)
.build();
//ObjectMapper mapper = NeuralNetConfiguration.mapper();
synchronized (mapper) { synchronized (mapper) {
//JSON mappers are supposed to be thread safe: however, in practice they seem to miss fields occasionally //JSON mappers are supposed to be thread safe: however, in practice they seem to miss fields occasionally
//when writeValueAsString is used by multiple threads. This results in invalid JSON. See issue #3243 //when writeValueAsString is used by multiple threads. This results in invalid JSON. See issue #3243

View File

@ -90,6 +90,7 @@ public abstract class InputType implements Serializable {
* *
* @return int[] * @return int[]
*/ */
@JsonIgnore
public long[] getShape() { public long[] getShape() {
return getShape(false); return getShape(false);
} }
@ -431,7 +432,7 @@ public abstract class InputType implements Serializable {
return height * width * depth * channels; return height * width * depth * channels;
} }
@Override @Override @JsonIgnore
public long[] getShape(boolean includeBatchDim) { public long[] getShape(boolean includeBatchDim) {
if(dataFormat == Convolution3D.DataFormat.NDHWC){ if(dataFormat == Convolution3D.DataFormat.NDHWC){
if(includeBatchDim) return new long[]{-1, depth, height, width, channels}; if(includeBatchDim) return new long[]{-1, depth, height, width, channels};

View File

@ -40,6 +40,7 @@ import org.nd4j.linalg.factory.Nd4j;
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder(builderMethodName = "innerBuilder") @SuperBuilder(builderMethodName = "innerBuilder")
@NoArgsConstructor
public class CapsuleLayer extends SameDiffLayer { public class CapsuleLayer extends SameDiffLayer {
private static final String WEIGHT_PARAM = "weight"; private static final String WEIGHT_PARAM = "weight";

View File

@ -92,6 +92,7 @@ public abstract class LayerConfiguration
* *
* @return activation function * @return activation function
*/ */
@JsonIgnore
public IActivation getActivationFn() { public IActivation getActivationFn() {
if (activation == null) if (activation == null)
throw new RuntimeException( throw new RuntimeException(

View File

@ -30,12 +30,12 @@ import com.fasterxml.jackson.annotation.JsonProperty;
@Data @Data
public class LossFunctionWrapper implements ReconstructionDistribution { public class LossFunctionWrapper implements ReconstructionDistribution {
private final IActivation activationFn; private final IActivation activation;
private final ILossFunction lossFunction; private final ILossFunction lossFunction;
public LossFunctionWrapper(@JsonProperty("activationFn") IActivation activationFn, public LossFunctionWrapper(@JsonProperty("activation") IActivation activation,
@JsonProperty("lossFunction") ILossFunction lossFunction) { @JsonProperty("lossFunction") ILossFunction lossFunction) {
this.activationFn = activationFn; this.activation = activation;
this.lossFunction = lossFunction; this.lossFunction = lossFunction;
} }
@ -59,17 +59,17 @@ public class LossFunctionWrapper implements ReconstructionDistribution {
//NOTE: The returned value here is NOT negative log probability, but it (the loss function value) //NOTE: The returned value here is NOT negative log probability, but it (the loss function value)
// is equivalent, in terms of being something we want to minimize... // is equivalent, in terms of being something we want to minimize...
return lossFunction.computeScore(x, preOutDistributionParams, activationFn, null, average); return lossFunction.computeScore(x, preOutDistributionParams, activation, null, average);
} }
@Override @Override
public INDArray exampleNegLogProbability(INDArray x, INDArray preOutDistributionParams) { public INDArray exampleNegLogProbability(INDArray x, INDArray preOutDistributionParams) {
return lossFunction.computeScoreArray(x, preOutDistributionParams, activationFn, null); return lossFunction.computeScoreArray(x, preOutDistributionParams, activation, null);
} }
@Override @Override
public INDArray gradient(INDArray x, INDArray preOutDistributionParams) { public INDArray gradient(INDArray x, INDArray preOutDistributionParams) {
return lossFunction.computeGradient(x, preOutDistributionParams, activationFn, null); return lossFunction.computeGradient(x, preOutDistributionParams, activation, null);
} }
@Override @Override
@ -82,11 +82,11 @@ public class LossFunctionWrapper implements ReconstructionDistribution {
public INDArray generateAtMean(INDArray preOutDistributionParams) { public INDArray generateAtMean(INDArray preOutDistributionParams) {
//Loss functions: not probabilistic -> not random //Loss functions: not probabilistic -> not random
INDArray out = preOutDistributionParams.dup(); INDArray out = preOutDistributionParams.dup();
return activationFn.getActivation(out, true); return activation.getActivation(out, true);
} }
@Override @Override
public String toString() { public String toString() {
return "LossFunctionWrapper(afn=" + activationFn + "," + lossFunction + ")"; return "LossFunctionWrapper(afn=" + activation + "," + lossFunction + ")";
} }
} }

View File

@ -118,7 +118,7 @@ import org.nd4j.linalg.workspace.WorkspaceUtils;
*/ */
@Slf4j @Slf4j
// @JsonIdentityInfo(generator = ObjectIdGenerators.IntSequenceGenerator.class, property = "@id") // @JsonIdentityInfo(generator = ObjectIdGenerators.IntSequenceGenerator.class, property = "@id")
@JsonIgnoreProperties({"helper", "net", "initCalled", "iupdater", "activationFn"}) @JsonIgnoreProperties({"helper", "net", "initCalled", "iupdater"})
public class MultiLayerNetwork extends ArtificialNeuralNetwork public class MultiLayerNetwork extends ArtificialNeuralNetwork
implements Serializable, Classifier, Layer, ITrainableLayer { implements Serializable, Classifier, Layer, ITrainableLayer {