Fixing tests

Signed-off-by: brian <brian@brutex.de>
enhance-build-infrastructure
Brian Rosenberger 2023-05-08 19:12:46 +02:00
parent 2c8c6d9624
commit c758cf918f
7 changed files with 26 additions and 13 deletions

View File

@ -27,6 +27,7 @@ import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import lombok.*;
import lombok.experimental.SuperBuilder;
import lombok.extern.jackson.Jacksonized;
import lombok.extern.slf4j.Slf4j;
import net.brutex.ai.dnn.api.INeuralNetworkConfiguration;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
@ -89,6 +90,7 @@ import java.util.*;
@Slf4j
// The inner builder, that we can then extend ...
@SuperBuilder // TODO fix access
@Jacksonized
@EqualsAndHashCode
public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetworkConfiguration {
@ -895,7 +897,6 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
*
* @param distribution Distribution to use for weight initialization
*/
@JsonIgnore
public B weightInit(Distribution distribution) {
this.weightInit$value = new WeightInitDistribution(distribution);
this.weightInit$set = true;

View File

@ -24,8 +24,12 @@ import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.MapperFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import com.fasterxml.jackson.databind.exc.InvalidTypeIdException;
import com.fasterxml.jackson.databind.json.JsonMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import lombok.*;
import lombok.experimental.SuperBuilder;
@ -165,7 +169,8 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
* @return {@link NeuralNetConfiguration}
*/
public static NeuralNetConfiguration fromJson(String json) {
ObjectMapper mapper = NeuralNetConfiguration.mapper();
//ObjectMapper mapper = NeuralNetConfiguration.mapper();
JsonMapper mapper = JsonMapper.builder().build();
try {
return mapper.readValue(json, NeuralNetConfiguration.class);
} catch (JsonProcessingException e) {
@ -439,7 +444,11 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration {
* @return JSON representation of NN configuration
*/
public String toJson() {
ObjectMapper mapper = NeuralNetConfiguration.mapper();
JsonMapper mapper = JsonMapper.builder()
.enable(SerializationFeature.INDENT_OUTPUT)
.enable(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY)
.build();
//ObjectMapper mapper = NeuralNetConfiguration.mapper();
synchronized (mapper) {
//JSON mappers are supposed to be thread safe: however, in practice they seem to miss fields occasionally
//when writeValueAsString is used by multiple threads. This results in invalid JSON. See issue #3243

View File

@ -90,6 +90,7 @@ public abstract class InputType implements Serializable {
*
* @return int[]
*/
@JsonIgnore
public long[] getShape() {
return getShape(false);
}
@ -431,7 +432,7 @@ public abstract class InputType implements Serializable {
return height * width * depth * channels;
}
@Override
@Override @JsonIgnore
public long[] getShape(boolean includeBatchDim) {
if(dataFormat == Convolution3D.DataFormat.NDHWC){
if(includeBatchDim) return new long[]{-1, depth, height, width, channels};

View File

@ -40,6 +40,7 @@ import org.nd4j.linalg.factory.Nd4j;
@EqualsAndHashCode(callSuper = true)
@SuperBuilder(builderMethodName = "innerBuilder")
@NoArgsConstructor
public class CapsuleLayer extends SameDiffLayer {
private static final String WEIGHT_PARAM = "weight";

View File

@ -92,6 +92,7 @@ public abstract class LayerConfiguration
*
* @return activation function
*/
@JsonIgnore
public IActivation getActivationFn() {
if (activation == null)
throw new RuntimeException(

View File

@ -30,12 +30,12 @@ import com.fasterxml.jackson.annotation.JsonProperty;
@Data
public class LossFunctionWrapper implements ReconstructionDistribution {
private final IActivation activationFn;
private final IActivation activation;
private final ILossFunction lossFunction;
public LossFunctionWrapper(@JsonProperty("activationFn") IActivation activationFn,
public LossFunctionWrapper(@JsonProperty("activation") IActivation activation,
@JsonProperty("lossFunction") ILossFunction lossFunction) {
this.activationFn = activationFn;
this.activation = activation;
this.lossFunction = lossFunction;
}
@ -59,17 +59,17 @@ public class LossFunctionWrapper implements ReconstructionDistribution {
//NOTE: The returned value here is NOT negative log probability, but it (the loss function value)
// is equivalent, in terms of being something we want to minimize...
return lossFunction.computeScore(x, preOutDistributionParams, activationFn, null, average);
return lossFunction.computeScore(x, preOutDistributionParams, activation, null, average);
}
@Override
public INDArray exampleNegLogProbability(INDArray x, INDArray preOutDistributionParams) {
return lossFunction.computeScoreArray(x, preOutDistributionParams, activationFn, null);
return lossFunction.computeScoreArray(x, preOutDistributionParams, activation, null);
}
@Override
public INDArray gradient(INDArray x, INDArray preOutDistributionParams) {
return lossFunction.computeGradient(x, preOutDistributionParams, activationFn, null);
return lossFunction.computeGradient(x, preOutDistributionParams, activation, null);
}
@Override
@ -82,11 +82,11 @@ public class LossFunctionWrapper implements ReconstructionDistribution {
public INDArray generateAtMean(INDArray preOutDistributionParams) {
//Loss functions: not probabilistic -> not random
INDArray out = preOutDistributionParams.dup();
return activationFn.getActivation(out, true);
return activation.getActivation(out, true);
}
@Override
public String toString() {
return "LossFunctionWrapper(afn=" + activationFn + "," + lossFunction + ")";
return "LossFunctionWrapper(afn=" + activation + "," + lossFunction + ")";
}
}

View File

@ -118,7 +118,7 @@ import org.nd4j.linalg.workspace.WorkspaceUtils;
*/
@Slf4j
// @JsonIdentityInfo(generator = ObjectIdGenerators.IntSequenceGenerator.class, property = "@id")
@JsonIgnoreProperties({"helper", "net", "initCalled", "iupdater", "activationFn"})
@JsonIgnoreProperties({"helper", "net", "initCalled", "iupdater"})
public class MultiLayerNetwork extends ArtificialNeuralNetwork
implements Serializable, Classifier, Layer, ITrainableLayer {