Using @SuperBuilder for LayerConfigurations

Signed-off-by: brian <brian@brutex.de>
Branch: master
Brian Rosenberger 2023-04-25 14:03:06 +02:00
parent 55f8486fe3, commit 391a1ad397
4 changed files with 37 additions and 167 deletions
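This commit replaces the hand-written hierarchical builders of the layer-configuration classes with Lombok's @SuperBuilder. For orientation, here is a minimal sketch of the pattern the generated builders follow; Base, Child, and SuperBuilderDemo are illustrative stand-ins, not DL4J classes:

import lombok.Getter;
import lombok.experimental.SuperBuilder;

// Minimal sketch of the @SuperBuilder pattern this commit adopts.
@Getter
@SuperBuilder
class Base {
  private int units;
}

@Getter
@SuperBuilder
class Child extends Base {
  private String name;
}

class SuperBuilderDemo {
  public static void main(String[] args) {
    // One fluent chain sets fields from both levels of the hierarchy,
    // because every generated builder method returns the self type B.
    Child c = Child.builder().name("conv").units(42).build();
    System.out.println(c.getName() + " " + c.getUnits()); // conv 42
  }
}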

BaseLayerConfiguration.java

@@ -210,7 +210,6 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration
C extends BaseLayerConfiguration, B extends BaseLayerConfigurationBuilder<C, B>>
extends LayerConfigurationBuilder<C, B> {
/**
* Set weight initialization scheme to random sampling via the specified distribution.
* Equivalent to: {@code .weightInit(new WeightInitDistribution(distribution))}
@@ -394,7 +393,7 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration
public B constrainBias(LayerConstraint... constraints) {
biasConstraints$value = Arrays.asList(constraints);
biasConstraints$set = true;
-      return (B) this;
+      return self();
}
/**
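The hunk above swaps the unchecked (B) this cast for self(): in a @SuperBuilder hierarchy, every abstract builder declares protected abstract B self() and only the final concrete builder implements it, so customized methods such as constrainBias can return the most-derived builder type without a cast warning. A hand-rolled sketch of the idiom, using hypothetical Shape/Circle types rather than DL4J classes:

// Hand-rolled version of the self-type idiom that @SuperBuilder generates.
abstract class ShapeBuilder<C, B extends ShapeBuilder<C, B>> {
  protected double opacity;

  protected abstract B self();
  public abstract C build();

  public B opacity(double o) {
    this.opacity = o;
    return self(); // statically typed as B: no unchecked (B) this cast
  }
}

final class CircleBuilder extends ShapeBuilder<String, CircleBuilder> {
  private double radius;

  public CircleBuilder radius(double r) {
    this.radius = r;
    return this;
  }

  @Override protected CircleBuilder self() { return this; }

  // Builds a String purely to keep the sketch self-contained.
  @Override public String build() {
    return "circle r=" + radius + ", opacity=" + opacity;
  }
}

A chain like new CircleBuilder().opacity(0.5).radius(2.0).build() type-checks because opacity(...) already returns CircleBuilder rather than the base builder type.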

ConvolutionLayer.java

@@ -73,7 +73,7 @@ public class ConvolutionLayer extends FeedForwardLayer
*/
@Builder.Default
protected CNN2DFormat convFormat =
      CNN2DFormat.NCHW; // default value for legacy serialization reasons
/**
* Kernel dilation. Default: {1, 1}, which is standard convolutions. Used for implementing dilated
@@ -110,6 +110,19 @@ public class ConvolutionLayer extends FeedForwardLayer
@Builder.Default @JsonIgnore @EqualsAndHashCode.Exclude
private boolean defaultValueOverriden = false;
+  public static ConvolutionLayerBuilder<?, ?> builder(int... kernelSize) {
+    return innerBuilder().kernelSize(kernelSize);
+  }
+  public static ConvolutionLayerBuilder<?, ?> builder(int[] kernelSize, int[] stride) {
+    return innerBuilder().kernelSize(kernelSize).stride(stride);
+  }
+  public static ConvolutionLayerBuilder<?, ?> builder(
+      int[] kernelSize, int[] stride, int[] padding) {
+    return innerBuilder().kernelSize(kernelSize).stride(stride).padding(padding);
+  }
public boolean hasBias() {
return hasBias;
}
@@ -201,16 +214,14 @@ public class ConvolutionLayer extends FeedForwardLayer
}
if (convFormat == null || override)
      this.convFormat = ((InputType.InputTypeConvolutional) inputType).getFormat();
}
@Override
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
if (inputType == null) {
throw new IllegalStateException(
"Invalid input for Convolution layer (layer name=\""
+ getName()
+ "\"): input is null");
"Invalid input for Convolution layer (layer name=\"" + getName() + "\"): input is null");
}
return InputTypeUtil.getPreProcessorForInputTypeCnnLayers(inputType, getName());
@@ -272,7 +283,6 @@ public class ConvolutionLayer extends FeedForwardLayer
.cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, cachedPerEx)
.build();
}
/**
* The "PREFER_FASTEST" mode will pick the fastest algorithm for the specified parameters from the
* {@link FwdAlgo}, {@link BwdFilterAlgo}, and {@link BwdDataAlgo} lists, but they may be very
@@ -320,6 +330,7 @@ public class ConvolutionLayer extends FeedForwardLayer
FFT_TILING,
COUNT
}
/**
* The backward data algorithm to use when {@link AlgoMode} is set to "USER_SPECIFIED".
*
@@ -335,25 +346,7 @@ public class ConvolutionLayer extends FeedForwardLayer
COUNT
}
-  public static ConvolutionLayerBuilder<?, ?> builder(int... kernelSize) {
-    return innerBuilder().kernelSize(kernelSize);
-  }
-  public static ConvolutionLayerBuilder<?, ?> builder(int[] kernelSize, int[] stride) {
-    return innerBuilder()
-        .kernelSize(kernelSize)
-        .stride(stride);
-  }
-  public static ConvolutionLayerBuilder<?, ?> builder(int[] kernelSize, int[] stride, int[] padding) {
-    return innerBuilder()
-        .kernelSize(kernelSize)
-        .stride(stride)
-        .padding(padding);
-  }
-  public static abstract class ConvolutionLayerBuilder<
+  public abstract static class ConvolutionLayerBuilder<
C extends ConvolutionLayer, B extends ConvolutionLayerBuilder<C, B>>
extends FeedForwardLayerBuilder<C, B> {
@@ -438,7 +431,6 @@ public class ConvolutionLayer extends FeedForwardLayer
C l = this.initBuild();
l.setType(LayerType.CONV);
l.initializeConstraints();
return l;
}
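With the overloaded builder(...) factories consolidated at the top of the class, construction stays fluent. A hedged usage sketch: the nOut(...) call is assumed to be the Lombok-generated setter for the nOut field inherited from FeedForwardLayer, and the wrapping ConvBuilderDemo class is purely illustrative:

import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;

class ConvBuilderDemo {
  static ConvolutionLayer smallConv() {
    // Kernel-only case via the varargs overload; nOut(...) is assumed to be
    // generated from FeedForwardLayer's nOut field.
    return ConvolutionLayer.builder(3, 3).nOut(16).build();
  }

  static ConvolutionLayer stridedConv() {
    // Kernel + stride + padding via the three-array overload.
    return ConvolutionLayer
        .builder(new int[] {3, 3}, new int[] {1, 1}, new int[] {0, 0})
        .nOut(64)
        .build();
  }
}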

FeedForwardLayer.java

@@ -20,9 +20,11 @@
package org.deeplearning4j.nn.conf.layers;
-import lombok.*;
+import lombok.EqualsAndHashCode;
+import lombok.Getter;
+import lombok.Setter;
+import lombok.ToString;
+import lombok.experimental.SuperBuilder;
import net.brutex.ai.dnn.api.LayerType;
import org.deeplearning4j.nn.conf.DataFormat;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType;
@@ -41,6 +43,13 @@ public abstract class FeedForwardLayer extends BaseLayerConfiguration
*/
@Getter
protected long nIn;
+  public void setNIn(int in) {
+    this.nIn = in;
+  }
+  public void setNIn(long in) {
+    this.nIn = in;
+  }
/**
* Number of inputs for the layer (usually the size of the last layer). <br> Note that for Convolutional layers,
* this is the input channels, otherwise is the previous layer size.
@@ -49,6 +58,13 @@ public abstract class FeedForwardLayer extends BaseLayerConfiguration
@Getter
protected long nOut;
protected DataFormat timeDistributedFormat;
+  protected FeedForwardLayer(FeedForwardLayerBuilder<?, ?> b) {
+    super(b);
+    this.nIn = b.nIn;
+    this.nOut = b.nOut;
+    this.timeDistributedFormat = b.timeDistributedFormat;
+  }
//
// { //Initializer block
// setType(LayerType.FC);
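The hand-written builder-consuming constructor added above mirrors the one @SuperBuilder would otherwise generate; Lombok accepts a manually supplied builder-based constructor and skips generating its own, which leaves room for extra initialization logic. A sketch of the shape, with hypothetical Parent/Middle classes standing in for the DL4J hierarchy:

import lombok.experimental.SuperBuilder;

// Hypothetical two-level hierarchy showing the builder-consuming
// constructor shape that @SuperBuilder expects.
@SuperBuilder
abstract class Parent {
  protected String name;
}

@SuperBuilder
abstract class Middle extends Parent {
  protected long nIn;

  // Same shape as the FeedForwardLayer constructor in this commit:
  // pass the builder up the hierarchy, then copy this level's fields.
  protected Middle(MiddleBuilder<?, ?> b) {
    super(b);
    this.nIn = b.nIn;
  }
}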

LayerConfiguration.java

@@ -325,141 +325,4 @@ public abstract class LayerConfiguration
runInheritance(getNetConfiguration());
}
-  public abstract static class LayerConfigurationBuilder<
-      C extends LayerConfiguration, B extends LayerConfigurationBuilder<C, B>> {
-    private String name;
-    private List<LayerConstraint> allParamConstraints;
-    private List<LayerConstraint> weightConstraints;
-    private List<LayerConstraint> biasConstraints;
-    private List<LayerConstraint> constraints;
-    private IWeightNoise weightNoise;
-    private LinkedHashSet<String> variables$value;
-    private boolean variables$set;
-    private IDropout dropOut;
-    private @NonNull LayerType type$value;
-    private boolean type$set;
-    private NeuralNetConfiguration netConfiguration;
-    private IActivation activation$value;
-    private boolean activation$set;
-    public B activation(Activation activation) {
-      this.activation$value = activation;
-      this.activation$set = true;
-      return self();
-    }
-    public B activation(IActivation activation) {
-      this.activation$value = activation;
-      this.activation$set = true;
-      return self();
-    }
-    public B dropOut(double d) {
-      this.dropOut = new Dropout(d);
-      return self();
-    }
-    public B dropOut(IDropout d) {
-      this.dropOut = d;
-      return self();
-    }
-    public B constrainBias(LayerConstraint constraint) {
-      return this.biasConstraints(List.of(constraint));
-    }
-    public B constrainWeights(LayerConstraint constraint) {
-      return this.weightConstraints(List.of(constraint));
-    }
-    public B name(String name) {
-      this.name = name;
-      return self();
-    }
-    public B allParamConstraints(List<LayerConstraint> allParamConstraints) {
-      this.allParamConstraints = allParamConstraints;
-      return self();
-    }
-    public B weightConstraints(List<LayerConstraint> weightConstraints) {
-      this.weightConstraints = weightConstraints;
-      return self();
-    }
-    public B biasConstraints(List<LayerConstraint> biasConstraints) {
-      this.biasConstraints = biasConstraints;
-      return self();
-    }
-    public B constraints(List<LayerConstraint> constraints) {
-      this.constraints = constraints;
-      return self();
-    }
-    public B weightNoise(IWeightNoise weightNoise) {
-      this.weightNoise = weightNoise;
-      return self();
-    }
-    public B variables(LinkedHashSet<String> variables) {
-      this.variables$value = variables;
-      this.variables$set = true;
-      return self();
-    }
-    public B type(@NonNull LayerType type) {
-      this.type$value = type;
-      this.type$set = true;
-      return self();
-    }
-    @JsonIgnore
-    public B netConfiguration(NeuralNetConfiguration netConfiguration) {
-      this.netConfiguration = netConfiguration;
-      return self();
-    }
-    protected abstract B self();
-    public abstract C build();
-    public String toString() {
-      return "LayerConfiguration.LayerConfigurationBuilder(name="
-          + this.name
-          + ", allParamConstraints="
-          + this.allParamConstraints
-          + ", weightConstraints="
-          + this.weightConstraints
-          + ", biasConstraints="
-          + this.biasConstraints
-          + ", constraints="
-          + this.constraints
-          + ", weightNoise="
-          + this.weightNoise
-          + ", variables$value="
-          + this.variables$value
-          + ", variables$set="
-          + this.variables$set
-          + ", dropOut="
-          + this.dropOut
-          + ", type$value="
-          + this.type$value
-          + ", type$set="
-          + this.type$set
-          + ", netConfiguration="
-          + this.netConfiguration
-          + ", activation$value="
-          + this.activation$value
-          + ", activation$set="
-          + this.activation$set
-          + ", variables$value="
-          + this.variables$value
-          + ", type$value="
-          + this.type$value
-          + ", activation$value="
-          + this.activation$value
-          + ")";
-    }
-  }
}
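The block deleted above hand-maintained exactly what Lombok now generates: fluent setters returning self(), plus field$value/field$set pairs so that build() can distinguish "never set" from "explicitly set to the default". A compact sketch of that default-tracking behaviour; Config and DefaultsDemo are hypothetical stand-ins, not DL4J classes:

import lombok.Builder;
import lombok.Getter;
import lombok.experimental.SuperBuilder;

// Config demonstrates @Builder.Default tracking via $value/$set fields.
@Getter
@SuperBuilder
class Config {
  @Builder.Default private String activation = "relu";
}

class DefaultsDemo {
  public static void main(String[] args) {
    // Internally the builder stages the value in activation$value and flips
    // activation$set; build() applies the default only if the flag is false.
    System.out.println(Config.builder().build().getActivation());                     // relu
    System.out.println(Config.builder().activation("tanh").build().getActivation());  // tanh
  }
}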