Using @SuperBuilder for LayerConfigurations
Signed-off-by: brian <brian@brutex.de>
This commit is contained in:
		
							parent
							
								
									55f8486fe3
								
							
						
					
					
						commit
						391a1ad397
					
				@ -210,7 +210,6 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration
 | 
			
		||||
          C extends BaseLayerConfiguration, B extends BaseLayerConfigurationBuilder<C, B>>
 | 
			
		||||
      extends LayerConfigurationBuilder<C, B> {
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Set weight initialization scheme to random sampling via the specified distribution.
 | 
			
		||||
     * Equivalent to: {@code .weightInit(new WeightInitDistribution(distribution))}
 | 
			
		||||
@ -394,7 +393,7 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration
 | 
			
		||||
    public B constrainBias(LayerConstraint... constraints) {
 | 
			
		||||
      biasConstraints$value = Arrays.asList(constraints);
 | 
			
		||||
      biasConstraints$set = true;
 | 
			
		||||
      return (B) this;
 | 
			
		||||
      return self();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
 | 
			
		||||
@ -110,6 +110,19 @@ public class ConvolutionLayer extends FeedForwardLayer {
 | 
			
		||||
  @Builder.Default @JsonIgnore @EqualsAndHashCode.Exclude
 | 
			
		||||
  private boolean defaultValueOverriden = false;
 | 
			
		||||
 | 
			
		||||
  public static ConvolutionLayerBuilder<?, ?> builder(int... kernelSize) {
 | 
			
		||||
    return innerBuilder().kernelSize(kernelSize);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  public static ConvolutionLayerBuilder<?, ?> builder(int[] kernelSize, int[] stride) {
 | 
			
		||||
    return innerBuilder().kernelSize(kernelSize).stride(stride);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  public static ConvolutionLayerBuilder<?, ?> builder(
 | 
			
		||||
      int[] kernelSize, int[] stride, int[] padding) {
 | 
			
		||||
    return innerBuilder().kernelSize(kernelSize).stride(stride).padding(padding);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  public boolean hasBias() {
 | 
			
		||||
    return hasBias;
 | 
			
		||||
  }
 | 
			
		||||
@ -208,9 +221,7 @@ public class ConvolutionLayer extends FeedForwardLayer {
 | 
			
		||||
  public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
 | 
			
		||||
    if (inputType == null) {
 | 
			
		||||
      throw new IllegalStateException(
 | 
			
		||||
          "Invalid input for Convolution layer (layer name=\""
 | 
			
		||||
              + getName()
 | 
			
		||||
              + "\"): input is null");
 | 
			
		||||
          "Invalid input for Convolution layer (layer name=\"" + getName() + "\"): input is null");
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    return InputTypeUtil.getPreProcessorForInputTypeCnnLayers(inputType, getName());
 | 
			
		||||
@ -272,7 +283,6 @@ public class ConvolutionLayer extends FeedForwardLayer {
 | 
			
		||||
        .cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, cachedPerEx)
 | 
			
		||||
        .build();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * The "PREFER_FASTEST" mode will pick the fastest algorithm for the specified parameters from the
 | 
			
		||||
   * {@link FwdAlgo}, {@link BwdFilterAlgo}, and {@link BwdDataAlgo} lists, but they may be very
 | 
			
		||||
@ -320,6 +330,7 @@ public class ConvolutionLayer extends FeedForwardLayer {
 | 
			
		||||
    FFT_TILING,
 | 
			
		||||
    COUNT
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * The backward data algorithm to use when {@link AlgoMode} is set to "USER_SPECIFIED".
 | 
			
		||||
   *
 | 
			
		||||
@ -335,25 +346,7 @@ public class ConvolutionLayer extends FeedForwardLayer {
 | 
			
		||||
    COUNT
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  public static ConvolutionLayerBuilder<?, ?> builder(int... kernelSize) {
 | 
			
		||||
    return innerBuilder().kernelSize(kernelSize);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  public static ConvolutionLayerBuilder<?, ?> builder(int[] kernelSize, int[] stride) {
 | 
			
		||||
    return innerBuilder()
 | 
			
		||||
            .kernelSize(kernelSize)
 | 
			
		||||
            .stride(stride);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  public static ConvolutionLayerBuilder<?, ?> builder(int[] kernelSize, int[] stride, int[] padding) {
 | 
			
		||||
    return innerBuilder()
 | 
			
		||||
            .kernelSize(kernelSize)
 | 
			
		||||
            .stride(stride)
 | 
			
		||||
            .padding(padding);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
  public static abstract class ConvolutionLayerBuilder<
 | 
			
		||||
  public abstract static class ConvolutionLayerBuilder<
 | 
			
		||||
          C extends ConvolutionLayer, B extends ConvolutionLayerBuilder<C, B>>
 | 
			
		||||
      extends FeedForwardLayerBuilder<C, B> {
 | 
			
		||||
 | 
			
		||||
@ -438,7 +431,6 @@ public class ConvolutionLayer extends FeedForwardLayer {
 | 
			
		||||
 | 
			
		||||
      C l = this.initBuild();
 | 
			
		||||
      l.setType(LayerType.CONV);
 | 
			
		||||
 | 
			
		||||
      l.initializeConstraints();
 | 
			
		||||
      return l;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -20,9 +20,11 @@
 | 
			
		||||
 | 
			
		||||
package org.deeplearning4j.nn.conf.layers;
 | 
			
		||||
 | 
			
		||||
import lombok.*;
 | 
			
		||||
import lombok.EqualsAndHashCode;
 | 
			
		||||
import lombok.Getter;
 | 
			
		||||
import lombok.Setter;
 | 
			
		||||
import lombok.ToString;
 | 
			
		||||
import lombok.experimental.SuperBuilder;
 | 
			
		||||
import net.brutex.ai.dnn.api.LayerType;
 | 
			
		||||
import org.deeplearning4j.nn.conf.DataFormat;
 | 
			
		||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
 | 
			
		||||
import org.deeplearning4j.nn.conf.inputs.InputType;
 | 
			
		||||
@ -41,6 +43,13 @@ public abstract class FeedForwardLayer extends BaseLayerConfiguration {
 | 
			
		||||
     */
 | 
			
		||||
    @Getter
 | 
			
		||||
    protected long nIn;
 | 
			
		||||
 | 
			
		||||
    public void setNIn(int in) {
 | 
			
		||||
        this.nIn = in;
 | 
			
		||||
    }
 | 
			
		||||
    public void setNIn(long in) {
 | 
			
		||||
        this.nIn = in;
 | 
			
		||||
    }
 | 
			
		||||
    /**
 | 
			
		||||
     * Number of inputs for the layer (usually the size of the last layer). <br> Note that for Convolutional layers,
 | 
			
		||||
     * this is the input channels, otherwise is the previous layer size.
 | 
			
		||||
@ -49,6 +58,13 @@ public abstract class FeedForwardLayer extends BaseLayerConfiguration {
 | 
			
		||||
    @Getter
 | 
			
		||||
    protected long nOut;
 | 
			
		||||
    protected DataFormat timeDistributedFormat;
 | 
			
		||||
 | 
			
		||||
    protected FeedForwardLayer(FeedForwardLayerBuilder<?, ?> b) {
 | 
			
		||||
        super(b);
 | 
			
		||||
        this.nIn = b.nIn;
 | 
			
		||||
        this.nOut = b.nOut;
 | 
			
		||||
        this.timeDistributedFormat = b.timeDistributedFormat;
 | 
			
		||||
    }
 | 
			
		||||
//
 | 
			
		||||
  //  { //Initializer block
 | 
			
		||||
      // setType(LayerType.FC);
 | 
			
		||||
 | 
			
		||||
@ -325,141 +325,4 @@ public abstract class LayerConfiguration
 | 
			
		||||
    runInheritance(getNetConfiguration());
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  public abstract static class LayerConfigurationBuilder<
 | 
			
		||||
      C extends LayerConfiguration, B extends LayerConfigurationBuilder<C, B>> {
 | 
			
		||||
    private String name;
 | 
			
		||||
    private List<LayerConstraint> allParamConstraints;
 | 
			
		||||
    private List<LayerConstraint> weightConstraints;
 | 
			
		||||
    private List<LayerConstraint> biasConstraints;
 | 
			
		||||
    private List<LayerConstraint> constraints;
 | 
			
		||||
    private IWeightNoise weightNoise;
 | 
			
		||||
    private LinkedHashSet<String> variables$value;
 | 
			
		||||
    private boolean variables$set;
 | 
			
		||||
    private IDropout dropOut;
 | 
			
		||||
    private @NonNull LayerType type$value;
 | 
			
		||||
    private boolean type$set;
 | 
			
		||||
    private NeuralNetConfiguration netConfiguration;
 | 
			
		||||
    private IActivation activation$value;
 | 
			
		||||
    private boolean activation$set;
 | 
			
		||||
 | 
			
		||||
    public B activation(Activation activation) {
 | 
			
		||||
      this.activation$value = activation;
 | 
			
		||||
      this.activation$set = true;
 | 
			
		||||
      return self();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public B activation(IActivation activation) {
 | 
			
		||||
      this.activation$value = activation;
 | 
			
		||||
      this.activation$set = true;
 | 
			
		||||
      return self();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public B dropOut(double d) {
 | 
			
		||||
      this.dropOut = new Dropout(d);
 | 
			
		||||
      return self();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public B dropOut(IDropout d) {
 | 
			
		||||
      this.dropOut = d;
 | 
			
		||||
      return self();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public B constrainBias(LayerConstraint constraint) {
 | 
			
		||||
      return this.biasConstraints(List.of(constraint));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public B constrainWeights(LayerConstraint constraint) {
 | 
			
		||||
      return this.weightConstraints(List.of(constraint));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public B name(String name) {
 | 
			
		||||
      this.name = name;
 | 
			
		||||
      return self();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public B allParamConstraints(List<LayerConstraint> allParamConstraints) {
 | 
			
		||||
      this.allParamConstraints = allParamConstraints;
 | 
			
		||||
      return self();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public B weightConstraints(List<LayerConstraint> weightConstraints) {
 | 
			
		||||
      this.weightConstraints = weightConstraints;
 | 
			
		||||
      return self();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public B biasConstraints(List<LayerConstraint> biasConstraints) {
 | 
			
		||||
      this.biasConstraints = biasConstraints;
 | 
			
		||||
      return self();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public B constraints(List<LayerConstraint> constraints) {
 | 
			
		||||
      this.constraints = constraints;
 | 
			
		||||
      return self();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public B weightNoise(IWeightNoise weightNoise) {
 | 
			
		||||
      this.weightNoise = weightNoise;
 | 
			
		||||
      return self();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public B variables(LinkedHashSet<String> variables) {
 | 
			
		||||
      this.variables$value = variables;
 | 
			
		||||
      this.variables$set = true;
 | 
			
		||||
      return self();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    public B type(@NonNull LayerType type) {
 | 
			
		||||
      this.type$value = type;
 | 
			
		||||
      this.type$set = true;
 | 
			
		||||
      return self();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @JsonIgnore
 | 
			
		||||
    public B netConfiguration(NeuralNetConfiguration netConfiguration) {
 | 
			
		||||
      this.netConfiguration = netConfiguration;
 | 
			
		||||
      return self();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    protected abstract B self();
 | 
			
		||||
 | 
			
		||||
    public abstract C build();
 | 
			
		||||
 | 
			
		||||
    public String toString() {
 | 
			
		||||
      return "LayerConfiguration.LayerConfigurationBuilder(name="
 | 
			
		||||
          + this.name
 | 
			
		||||
          + ", allParamConstraints="
 | 
			
		||||
          + this.allParamConstraints
 | 
			
		||||
          + ", weightConstraints="
 | 
			
		||||
          + this.weightConstraints
 | 
			
		||||
          + ", biasConstraints="
 | 
			
		||||
          + this.biasConstraints
 | 
			
		||||
          + ", constraints="
 | 
			
		||||
          + this.constraints
 | 
			
		||||
          + ", weightNoise="
 | 
			
		||||
          + this.weightNoise
 | 
			
		||||
          + ", variables$value="
 | 
			
		||||
          + this.variables$value
 | 
			
		||||
          + ", variables$set="
 | 
			
		||||
          + this.variables$set
 | 
			
		||||
          + ", dropOut="
 | 
			
		||||
          + this.dropOut
 | 
			
		||||
          + ", type$value="
 | 
			
		||||
          + this.type$value
 | 
			
		||||
          + ", type$set="
 | 
			
		||||
          + this.type$set
 | 
			
		||||
          + ", netConfiguration="
 | 
			
		||||
          + this.netConfiguration
 | 
			
		||||
          + ", activation$value="
 | 
			
		||||
          + this.activation$value
 | 
			
		||||
          + ", activation$set="
 | 
			
		||||
          + this.activation$set
 | 
			
		||||
          + ", variables$value="
 | 
			
		||||
          + this.variables$value
 | 
			
		||||
          + ", type$value="
 | 
			
		||||
          + this.type$value
 | 
			
		||||
          + ", activation$value="
 | 
			
		||||
          + this.activation$value
 | 
			
		||||
          + ")";
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user