Using @SuperBuilder for LayerConfigurations

Signed-off-by: brian <brian@brutex.de>
master
Brian Rosenberger 2023-04-25 14:03:06 +02:00
parent 55f8486fe3
commit 391a1ad397
4 changed files with 37 additions and 167 deletions

View File

@ -210,7 +210,6 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration
C extends BaseLayerConfiguration, B extends BaseLayerConfigurationBuilder<C, B>> C extends BaseLayerConfiguration, B extends BaseLayerConfigurationBuilder<C, B>>
extends LayerConfigurationBuilder<C, B> { extends LayerConfigurationBuilder<C, B> {
/** /**
* Set weight initialization scheme to random sampling via the specified distribution. * Set weight initialization scheme to random sampling via the specified distribution.
* Equivalent to: {@code .weightInit(new WeightInitDistribution(distribution))} * Equivalent to: {@code .weightInit(new WeightInitDistribution(distribution))}
@ -394,7 +393,7 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration
public B constrainBias(LayerConstraint... constraints) { public B constrainBias(LayerConstraint... constraints) {
biasConstraints$value = Arrays.asList(constraints); biasConstraints$value = Arrays.asList(constraints);
biasConstraints$set = true; biasConstraints$set = true;
return (B) this; return self();
} }
/** /**

View File

@ -73,7 +73,7 @@ public class ConvolutionLayer extends FeedForwardLayer {
*/ */
@Builder.Default @Builder.Default
protected CNN2DFormat convFormat = protected CNN2DFormat convFormat =
CNN2DFormat.NCHW; // default value for legacy serialization reasons CNN2DFormat.NCHW; // default value for legacy serialization reasons
/** /**
* Kernel dilation. Default: {1, 1}, which is standard convolutions. Used for implementing dilated * Kernel dilation. Default: {1, 1}, which is standard convolutions. Used for implementing dilated
@ -110,6 +110,19 @@ public class ConvolutionLayer extends FeedForwardLayer {
@Builder.Default @JsonIgnore @EqualsAndHashCode.Exclude @Builder.Default @JsonIgnore @EqualsAndHashCode.Exclude
private boolean defaultValueOverriden = false; private boolean defaultValueOverriden = false;
/**
 * Creates a layer builder with the kernel size already applied.
 *
 * @param kernelSize kernel size values, forwarded unchanged to the builder's {@code kernelSize}
 *     setter
 * @return the builder obtained from {@code innerBuilder()} after the kernel size has been set
 */
public static ConvolutionLayerBuilder<?, ?> builder(int... kernelSize) {
  return innerBuilder()
      .kernelSize(kernelSize);
}
/**
 * Creates a layer builder with kernel size and stride already applied.
 *
 * @param kernelSize kernel size values, applied first
 * @param stride stride values, applied after the kernel size
 * @return the pre-configured builder
 */
public static ConvolutionLayerBuilder<?, ?> builder(int[] kernelSize, int[] stride) {
  // The single-argument overload performs innerBuilder().kernelSize(kernelSize); an int[]
  // argument is handed to the varargs parameter as-is, so the resulting call chain is unchanged.
  return builder(kernelSize).stride(stride);
}
/**
 * Creates a layer builder with kernel size, stride and padding already applied.
 *
 * @param kernelSize kernel size values, applied first
 * @param stride stride values, applied second
 * @param padding padding values, applied last
 * @return the pre-configured builder
 */
public static ConvolutionLayerBuilder<?, ?> builder(
    int[] kernelSize, int[] stride, int[] padding) {
  // Reuse the two-argument overload for kernel size and stride, then add the padding.
  return builder(kernelSize, stride).padding(padding);
}
public boolean hasBias() { public boolean hasBias() {
return hasBias; return hasBias;
} }
@ -201,16 +214,14 @@ public class ConvolutionLayer extends FeedForwardLayer {
} }
if (convFormat == null || override) if (convFormat == null || override)
this.convFormat = ((InputType.InputTypeConvolutional) inputType).getFormat(); this.convFormat = ((InputType.InputTypeConvolutional) inputType).getFormat();
} }
@Override @Override
public InputPreProcessor getPreProcessorForInputType(InputType inputType) { public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
if (inputType == null) { if (inputType == null) {
throw new IllegalStateException( throw new IllegalStateException(
"Invalid input for Convolution layer (layer name=\"" "Invalid input for Convolution layer (layer name=\"" + getName() + "\"): input is null");
+ getName()
+ "\"): input is null");
} }
return InputTypeUtil.getPreProcessorForInputTypeCnnLayers(inputType, getName()); return InputTypeUtil.getPreProcessorForInputTypeCnnLayers(inputType, getName());
@ -272,7 +283,6 @@ public class ConvolutionLayer extends FeedForwardLayer {
.cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, cachedPerEx) .cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, cachedPerEx)
.build(); .build();
} }
/** /**
* The "PREFER_FASTEST" mode will pick the fastest algorithm for the specified parameters from the * The "PREFER_FASTEST" mode will pick the fastest algorithm for the specified parameters from the
* {@link FwdAlgo}, {@link BwdFilterAlgo}, and {@link BwdDataAlgo} lists, but they may be very * {@link FwdAlgo}, {@link BwdFilterAlgo}, and {@link BwdDataAlgo} lists, but they may be very
@ -320,6 +330,7 @@ public class ConvolutionLayer extends FeedForwardLayer {
FFT_TILING, FFT_TILING,
COUNT COUNT
} }
/** /**
* The backward data algorithm to use when {@link AlgoMode} is set to "USER_SPECIFIED". * The backward data algorithm to use when {@link AlgoMode} is set to "USER_SPECIFIED".
* *
@ -335,25 +346,7 @@ public class ConvolutionLayer extends FeedForwardLayer {
COUNT COUNT
} }
public static ConvolutionLayerBuilder<?, ?> builder(int... kernelSize) { public abstract static class ConvolutionLayerBuilder<
return innerBuilder().kernelSize(kernelSize);
}
/**
 * Returns a builder on which the given kernel size and stride have been set, in that order.
 *
 * @param kernelSize values passed to the builder's {@code kernelSize} setter
 * @param stride values passed to the builder's {@code stride} setter
 * @return the builder produced by {@code innerBuilder()} with both values applied
 */
public static ConvolutionLayerBuilder<?, ?> builder(int[] kernelSize, int[] stride) {
  return innerBuilder().kernelSize(kernelSize).stride(stride);
}
/**
 * Returns a builder on which the given kernel size, stride and padding have been set, in that
 * order.
 *
 * @param kernelSize values passed to the builder's {@code kernelSize} setter
 * @param stride values passed to the builder's {@code stride} setter
 * @param padding values passed to the builder's {@code padding} setter
 * @return the builder produced by {@code innerBuilder()} with all three values applied
 */
public static ConvolutionLayerBuilder<?, ?> builder(
    int[] kernelSize, int[] stride, int[] padding) {
  return innerBuilder().kernelSize(kernelSize).stride(stride).padding(padding);
}
public static abstract class ConvolutionLayerBuilder<
C extends ConvolutionLayer, B extends ConvolutionLayerBuilder<C, B>> C extends ConvolutionLayer, B extends ConvolutionLayerBuilder<C, B>>
extends FeedForwardLayerBuilder<C, B> { extends FeedForwardLayerBuilder<C, B> {
@ -438,7 +431,6 @@ public class ConvolutionLayer extends FeedForwardLayer {
C l = this.initBuild(); C l = this.initBuild();
l.setType(LayerType.CONV); l.setType(LayerType.CONV);
l.initializeConstraints(); l.initializeConstraints();
return l; return l;
} }

View File

@ -20,9 +20,11 @@
package org.deeplearning4j.nn.conf.layers; package org.deeplearning4j.nn.conf.layers;
import lombok.*; import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import lombok.ToString;
import lombok.experimental.SuperBuilder; import lombok.experimental.SuperBuilder;
import net.brutex.ai.dnn.api.LayerType;
import org.deeplearning4j.nn.conf.DataFormat; import org.deeplearning4j.nn.conf.DataFormat;
import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
@ -41,6 +43,13 @@ public abstract class FeedForwardLayer extends BaseLayerConfiguration {
*/ */
@Getter @Getter
protected long nIn; protected long nIn;
/**
 * Sets {@code nIn} from an {@code int} value.
 *
 * @param in new value; widened to {@code long} before it is stored
 */
public void setNIn(int in) {
  final long widened = in;
  this.nIn = widened;
}
/**
 * Sets {@code nIn} from a {@code long} value.
 *
 * @param in new value, stored as given
 */
public void setNIn(long in) {
  nIn = in;
}
/** /**
* Number of inputs for the layer (usually the size of the last layer). <br> Note that for Convolutional layers, * Number of inputs for the layer (usually the size of the last layer). <br> Note that for Convolutional layers,
* this is the input channels, otherwise is the previous layer size. * this is the input channels, otherwise is the previous layer size.
@ -49,6 +58,13 @@ public abstract class FeedForwardLayer extends BaseLayerConfiguration {
@Getter @Getter
protected long nOut; protected long nOut;
protected DataFormat timeDistributedFormat; protected DataFormat timeDistributedFormat;
/**
 * Builder-based constructor: hands the builder to the superclass constructor, then copies the
 * values this class declares itself from the builder.
 *
 * @param b builder supplying {@code nIn}, {@code nOut} and {@code timeDistributedFormat}
 */
protected FeedForwardLayer(FeedForwardLayerBuilder<?, ?> b) {
  super(b);
  // Plain field copies; they are independent of each other, so their order does not matter.
  timeDistributedFormat = b.timeDistributedFormat;
  nOut = b.nOut;
  nIn = b.nIn;
}
// //
// { //Initializer block // { //Initializer block
// setType(LayerType.FC); // setType(LayerType.FC);

View File

@ -325,141 +325,4 @@ public abstract class LayerConfiguration
runInheritance(getNetConfiguration()); runInheritance(getNetConfiguration());
} }
/**
 * Abstract, self-typed fluent builder for {@link LayerConfiguration} subclasses.
 *
 * <p>Every setter stores its argument and returns {@link #self()} so that calls can be chained
 * while keeping the concrete builder type {@code B}. The {@code variables}, {@code type} and
 * {@code activation} properties are kept as a {@code $value}/{@code $set} pair: the value plus a
 * flag recording that the setter was called. NOTE(review): the naming matches Lombok's
 * {@code @Builder.Default} convention, where {@code build()} falls back to a default when the
 * flag is false; that logic is not part of this class, so confirm against the concrete builders.
 *
 * @param <C> type of layer configuration produced by {@link #build()}
 * @param <B> concrete builder type returned by the fluent setters
 */
public abstract static class LayerConfigurationBuilder<
    C extends LayerConfiguration, B extends LayerConfigurationBuilder<C, B>> {
  private String name;
  private List<LayerConstraint> allParamConstraints;
  private List<LayerConstraint> weightConstraints;
  private List<LayerConstraint> biasConstraints;
  private List<LayerConstraint> constraints;
  private IWeightNoise weightNoise;
  private LinkedHashSet<String> variables$value;
  private boolean variables$set;
  private IDropout dropOut;
  private @NonNull LayerType type$value;
  private boolean type$set;
  private NeuralNetConfiguration netConfiguration;
  private IActivation activation$value;
  private boolean activation$set;

  /**
   * Sets the activation from an {@link Activation} value and marks it as explicitly set.
   *
   * @param activation activation to store
   * @return this builder
   */
  public B activation(Activation activation) {
    this.activation$value = activation;
    this.activation$set = true;
    return self();
  }

  /**
   * Sets the activation from an {@link IActivation} instance and marks it as explicitly set.
   *
   * @param activation activation to store
   * @return this builder
   */
  public B activation(IActivation activation) {
    this.activation$value = activation;
    this.activation$set = true;
    return self();
  }

  /**
   * Sets the dropout to a new {@link Dropout} constructed from the given value.
   *
   * @param d argument passed to the {@code Dropout(double)} constructor
   * @return this builder
   */
  public B dropOut(double d) {
    this.dropOut = new Dropout(d);
    return self();
  }

  /**
   * Sets the dropout instance.
   *
   * @param d dropout to store
   * @return this builder
   */
  public B dropOut(IDropout d) {
    this.dropOut = d;
    return self();
  }

  /**
   * Replaces the bias constraints with an immutable single-element list.
   *
   * @param constraint the only bias constraint; must not be {@code null}
   * @return this builder
   * @throws NullPointerException if {@code constraint} is {@code null} (raised by
   *     {@link List#of(Object)})
   */
  public B constrainBias(LayerConstraint constraint) {
    return this.biasConstraints(List.of(constraint));
  }

  /**
   * Replaces the weight constraints with an immutable single-element list.
   *
   * @param constraint the only weight constraint; must not be {@code null}
   * @return this builder
   * @throws NullPointerException if {@code constraint} is {@code null} (raised by
   *     {@link List#of(Object)})
   */
  public B constrainWeights(LayerConstraint constraint) {
    return this.weightConstraints(List.of(constraint));
  }

  /**
   * Sets the layer name.
   *
   * @param name name to store
   * @return this builder
   */
  public B name(String name) {
    this.name = name;
    return self();
  }

  /**
   * Sets the list stored as {@code allParamConstraints}.
   *
   * @param allParamConstraints list to store (kept by reference, not copied)
   * @return this builder
   */
  public B allParamConstraints(List<LayerConstraint> allParamConstraints) {
    this.allParamConstraints = allParamConstraints;
    return self();
  }

  /**
   * Sets the list stored as {@code weightConstraints}.
   *
   * @param weightConstraints list to store (kept by reference, not copied)
   * @return this builder
   */
  public B weightConstraints(List<LayerConstraint> weightConstraints) {
    this.weightConstraints = weightConstraints;
    return self();
  }

  /**
   * Sets the list stored as {@code biasConstraints}.
   *
   * @param biasConstraints list to store (kept by reference, not copied)
   * @return this builder
   */
  public B biasConstraints(List<LayerConstraint> biasConstraints) {
    this.biasConstraints = biasConstraints;
    return self();
  }

  /**
   * Sets the list stored as {@code constraints}.
   *
   * @param constraints list to store (kept by reference, not copied)
   * @return this builder
   */
  public B constraints(List<LayerConstraint> constraints) {
    this.constraints = constraints;
    return self();
  }

  /**
   * Sets the weight noise.
   *
   * @param weightNoise weight noise to store
   * @return this builder
   */
  public B weightNoise(IWeightNoise weightNoise) {
    this.weightNoise = weightNoise;
    return self();
  }

  /**
   * Sets the variables and marks them as explicitly set.
   *
   * @param variables set to store (kept by reference, not copied)
   * @return this builder
   */
  public B variables(LinkedHashSet<String> variables) {
    this.variables$value = variables;
    this.variables$set = true;
    return self();
  }

  /**
   * Sets the layer type and marks it as explicitly set.
   *
   * @param type layer type to store; annotated {@code @NonNull}
   * @return this builder
   */
  public B type(@NonNull LayerType type) {
    this.type$value = type;
    this.type$set = true;
    return self();
  }

  /**
   * Sets the network configuration reference.
   *
   * @param netConfiguration configuration to store
   * @return this builder
   */
  @JsonIgnore
  public B netConfiguration(NeuralNetConfiguration netConfiguration) {
    this.netConfiguration = netConfiguration;
    return self();
  }

  /**
   * Returns this builder typed as the concrete builder type.
   *
   * @return this builder as {@code B}
   */
  protected abstract B self();

  /**
   * Creates the configuration object from the values collected so far.
   *
   * @return the built configuration
   */
  public abstract C build();

  /**
   * Returns a diagnostic string with every builder field listed exactly once.
   *
   * <p>An earlier version appended {@code variables$value}, {@code type$value} and
   * {@code activation$value} a second time at the end of the string; the repeated entries
   * carried no extra information and have been removed.
   */
  @Override
  public String toString() {
    return "LayerConfiguration.LayerConfigurationBuilder(name="
        + this.name
        + ", allParamConstraints="
        + this.allParamConstraints
        + ", weightConstraints="
        + this.weightConstraints
        + ", biasConstraints="
        + this.biasConstraints
        + ", constraints="
        + this.constraints
        + ", weightNoise="
        + this.weightNoise
        + ", variables$value="
        + this.variables$value
        + ", variables$set="
        + this.variables$set
        + ", dropOut="
        + this.dropOut
        + ", type$value="
        + this.type$value
        + ", type$set="
        + this.type$set
        + ", netConfiguration="
        + this.netConfiguration
        + ", activation$value="
        + this.activation$value
        + ", activation$set="
        + this.activation$set
        + ")";
  }
}
} }