Using @SuperBuilder for LayerConfigurations

Signed-off-by: brian <brian@brutex.de>
master
Brian Rosenberger 2023-04-25 13:25:23 +02:00
parent 9139940101
commit 55f8486fe3
16 changed files with 257 additions and 196 deletions

View File

@ -27,8 +27,7 @@ import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationSigmoid; import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.activations.impl.ActivationTanH; import org.nd4j.linalg.activations.impl.ActivationTanH;
@Data
@NoArgsConstructor
@ToString(callSuper = true) @ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder @SuperBuilder
@ -38,13 +37,13 @@ public abstract class AbstractLSTM extends BaseRecurrentLayer {
* Set forget gate bias initalizations. Values in range 1-5 can potentially help with learning or * Set forget gate bias initalizations. Values in range 1-5 can potentially help with learning or
* longer-term dependencies. * longer-term dependencies.
*/ */
@lombok.Builder.Default protected double forgetGateBiasInit = 1.0; @lombok.Builder.Default @Getter protected double forgetGateBiasInit = 1.0;
/** /**
* When using CuDNN and an error is encountered, should fallback to the non-CuDNN implementatation * When using CuDNN and an error is encountered, should fallback to the non-CuDNN implementatation
* be allowed? If set to false, an exception in CuDNN will be propagated back to the user. If * be allowed? If set to false, an exception in CuDNN will be propagated back to the user. If
* false, the built-in (non-CuDNN) implementation for LSTM/GravesLSTM will be used * false, the built-in (non-CuDNN) implementation for LSTM/GravesLSTM will be used
*/ */
@lombok.Builder.Default protected boolean helperAllowFallback = true; @lombok.Builder.Default @Getter protected boolean helperAllowFallback = true;
/** /**
* Activation function for the LSTM gates. Note: This should be bounded to range 0-1: sigmoid or * Activation function for the LSTM gates. Note: This should be bounded to range 0-1: sigmoid or
* hard sigmoid, for example * hard sigmoid, for example

View File

@ -135,7 +135,7 @@ public class ActivationLayer extends NoParamLayer {
public static abstract class ActivationLayerBuilder< public static abstract class ActivationLayerBuilder<
C extends ActivationLayer, B extends ActivationLayerBuilder<C, B>> C extends ActivationLayer, B extends ActivationLayerBuilder<C, B>>
extends NoParamLayerBuilder<C, B> { extends NoParamLayer.NoParamLayerBuilder<C, B> {
public C build() { public C build() {
C l = this.build(); C l = this.build();
l.initializeConstraints(); l.initializeConstraints();

View File

@ -36,8 +36,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Collection; import java.util.Collection;
import java.util.Map; import java.util.Map;
@Data
@NoArgsConstructor
@ToString(callSuper = true) @ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder @SuperBuilder
@ -47,13 +46,13 @@ public class AutoEncoder extends BasePretrainNetwork {
* Level of corruption - 0.0 (none) to 1.0 (all values corrupted) * Level of corruption - 0.0 (none) to 1.0 (all values corrupted)
* *
*/ */
@lombok.Builder.Default @lombok.Builder.Default @Getter @Setter
private double corruptionLevel = 3e-1f; private double corruptionLevel = 3e-1f;
/** /**
* Autoencoder sparity parameter * Autoencoder sparity parameter
* *
*/ */
@lombok.Builder.Default @lombok.Builder.Default @Getter @Setter
protected double sparsity = 0f; protected double sparsity = 0f;

View File

@ -45,9 +45,7 @@ import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.learning.regularization.WeightDecay; import org.nd4j.linalg.learning.regularization.WeightDecay;
/** A neural network layer. */ /** A neural network layer. */
@Data
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@NoArgsConstructor(force = true)
@SuperBuilder @SuperBuilder
public abstract class BaseLayerConfiguration extends LayerConfiguration public abstract class BaseLayerConfiguration extends LayerConfiguration
implements ITraininableLayerConfiguration, Serializable, Cloneable { implements ITraininableLayerConfiguration, Serializable, Cloneable {
@ -62,7 +60,7 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration
* *
* @param constraints Constraints to apply to all bias parameters of all layers * @param constraints Constraints to apply to all bias parameters of all layers
*/ */
@lombok.Builder.Default protected final List<LayerConstraint> biasConstraints = new ArrayList<>(); @lombok.Builder.Default @Getter protected final List<LayerConstraint> biasConstraints = new ArrayList<>();
/** /**
* Set constraints to be applied to all layers. Default: no constraints.<br> * Set constraints to be applied to all layers. Default: no constraints.<br>
* Constraints can be used to enforce certain conditions (non-negativity of parameters, max-norm * Constraints can be used to enforce certain conditions (non-negativity of parameters, max-norm
@ -74,27 +72,33 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration
* *
* @param constraints Constraints to apply to all weight parameters of all layers * @param constraints Constraints to apply to all weight parameters of all layers
*/ */
@lombok.Builder.Default @lombok.Builder.Default @Getter
protected final List<LayerConstraint> constrainWeights = new ArrayList<>(); protected final List<LayerConstraint> constrainWeights = new ArrayList<>();
/** Weight initialization scheme to use, for initial weight values */ /** Weight initialization scheme to use, for initial weight values */
@Getter @Setter
protected IWeightInit weightInit; protected IWeightInit weightInit;
/** Bias initialization value, for layers with biases. Defaults to 0 */ /** Bias initialization value, for layers with biases. Defaults to 0 */
@Getter @Setter @Builder.Default
protected double biasInit = 0.0; protected double biasInit = 0.0;
/** Gain initialization value, for layers with ILayer Normalization. Defaults to 1 */ /** Gain initialization value, for layers with ILayer Normalization. Defaults to 1 */
@Getter @Setter @Builder.Default
protected double gainInit = 0.0; protected double gainInit = 0.0;
/** Regularization for the parameters (excluding biases). */ /** Regularization for the parameters (excluding biases). */
@Builder.Default protected List<Regularization> regularization = new ArrayList<>(); @Builder.Default @Getter protected List<Regularization> regularization = new ArrayList<>();
/** Regularization for the bias parameters only */ /** Regularization for the bias parameters only */
@Builder.Default protected List<Regularization> regularizationBias = new ArrayList<>(); @Builder.Default @Getter
protected List<Regularization> regularizationBias = new ArrayList<>();
/** /**
* Gradient updater. For example, {@link org.nd4j.linalg.learning.config.Adam} or {@link * Gradient updater. For example, {@link org.nd4j.linalg.learning.config.Adam} or {@link
* org.nd4j.linalg.learning.config.Nesterovs} * org.nd4j.linalg.learning.config.Nesterovs}
*/ */
@Getter @Setter
protected IUpdater updater; protected IUpdater updater;
/** /**
* Gradient updater configuration, for the biases only. If not set, biases will use the updater as * Gradient updater configuration, for the biases only. If not set, biases will use the updater as
* set by {@link #setUpdater(IUpdater)} * set by {@link #setUpdater(IUpdater)}
*/ */
@Getter @Setter
protected IUpdater biasUpdater; protected IUpdater biasUpdater;
/** /**
* Gradient normalization strategy. Used to specify gradient renormalization, gradient clipping * Gradient normalization strategy. Used to specify gradient renormalization, gradient clipping
@ -103,7 +107,7 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration
* @see GradientNormalization * @see GradientNormalization
*/ */
@Builder.Default @Builder.Default
protected @Getter GradientNormalization gradientNormalization = protected @Getter @Setter GradientNormalization gradientNormalization =
GradientNormalization.None; // Clipping, rescale based on l2 norm, etc GradientNormalization.None; // Clipping, rescale based on l2 norm, etc
/** /**
* Threshold for gradient normalization, only used for GradientNormalization.ClipL2PerLayer, * Threshold for gradient normalization, only used for GradientNormalization.ClipL2PerLayer,
@ -113,10 +117,10 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration
* L2 threshold for first two types of clipping, or absolute value threshold for last type of * L2 threshold for first two types of clipping, or absolute value threshold for last type of
* clipping. * clipping.
*/ */
@Builder.Default @Builder.Default @Getter @Setter
protected double gradientNormalizationThreshold = protected double gradientNormalizationThreshold =
1.0; // Threshold for l2 and element-wise gradient clipping 1.0; // Threshold for l2 and element-wise gradient clipping
@Getter @Setter
private DataType dataType; private DataType dataType;
/** /**
@ -206,6 +210,7 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration
C extends BaseLayerConfiguration, B extends BaseLayerConfigurationBuilder<C, B>> C extends BaseLayerConfiguration, B extends BaseLayerConfigurationBuilder<C, B>>
extends LayerConfigurationBuilder<C, B> { extends LayerConfigurationBuilder<C, B> {
/** /**
* Set weight initialization scheme to random sampling via the specified distribution. * Set weight initialization scheme to random sampling via the specified distribution.
* Equivalent to: {@code .weightInit(new WeightInitDistribution(distribution))} * Equivalent to: {@code .weightInit(new WeightInitDistribution(distribution))}
@ -411,16 +416,5 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration
regularizationBias$set = true; regularizationBias$set = true;
return self(); return self();
} }
public B updater(IUpdater updater) {
this.updater = updater;
return self();
}
public B updater(Updater updater) {
this.updater = updater.getIUpdaterWithDefaultConfig();
return self();
}
} }
} }

View File

@ -26,8 +26,6 @@ import lombok.experimental.SuperBuilder;
import org.deeplearning4j.nn.params.PretrainParamInitializer; import org.deeplearning4j.nn.params.PretrainParamInitializer;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
@Data
@NoArgsConstructor
@ToString(callSuper = true) @ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@JsonIgnoreProperties("pretrain") @JsonIgnoreProperties("pretrain")

View File

@ -30,8 +30,7 @@ import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.IWeightInit;
@Data
@NoArgsConstructor
@ToString(callSuper = true) @ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder @SuperBuilder
@ -42,12 +41,13 @@ public abstract class BaseRecurrentLayer extends FeedForwardLayer {
* weight initialization as the layer input weights is also used for the recurrent weights. * weight initialization as the layer input weights is also used for the recurrent weights.
* *
*/ */
@Getter
protected IWeightInit weightInitRecurrent; protected IWeightInit weightInitRecurrent;
/** /**
* Set the format of data expected by the RNN. NCW = [miniBatchSize, size, timeSeriesLength], * Set the format of data expected by the RNN. NCW = [miniBatchSize, size, timeSeriesLength],
* NWC = [miniBatchSize, timeSeriesLength, size]. Defaults to NCW. * NWC = [miniBatchSize, timeSeriesLength, size]. Defaults to NCW.
*/ */
@Builder.Default @Builder.Default @Getter @Setter
protected RNNFormat dataFormat = RNNFormat.NCW; protected RNNFormat dataFormat = RNNFormat.NCW;
/** /**
* Set constraints to be applied to the RNN recurrent weight parameters of this layer. Default: no * Set constraints to be applied to the RNN recurrent weight parameters of this layer. Default: no
@ -55,6 +55,7 @@ public abstract class BaseRecurrentLayer extends FeedForwardLayer {
* max-norm regularization, etc). These constraints are applied at each iteration, after the parameters have * max-norm regularization, etc). These constraints are applied at each iteration, after the parameters have
* been updated. * been updated.
*/ */
@Getter
protected List<LayerConstraint> recurrentConstraints; protected List<LayerConstraint> recurrentConstraints;
/** /**
* Set constraints to be applied to the RNN input weight parameters of this layer. Default: no constraints.<br> * Set constraints to be applied to the RNN input weight parameters of this layer. Default: no constraints.<br>
@ -62,6 +63,7 @@ public abstract class BaseRecurrentLayer extends FeedForwardLayer {
* etc). These constraints are applied at each iteration, after the parameters have been updated. * etc). These constraints are applied at each iteration, after the parameters have been updated.
* *
*/ */
@Getter
protected List<LayerConstraint> inputWeightConstraints; protected List<LayerConstraint> inputWeightConstraints;
@Override @Override
@ -125,6 +127,4 @@ public abstract class BaseRecurrentLayer extends FeedForwardLayer {
return self(); return self();
} }
} }
} }

View File

@ -39,7 +39,6 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@Data @Data
@NoArgsConstructor
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder") @SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder")
public class CapsuleLayer extends SameDiffLayer { public class CapsuleLayer extends SameDiffLayer {
@ -78,11 +77,6 @@ public class CapsuleLayer extends SameDiffLayer {
* @return * @return
*/ */
@Builder.Default private int routings = 3; @Builder.Default private int routings = 3;
public CapsuleLayer(Builder builder){
}
@Override @Override
public void setNIn(InputType inputType, boolean override) { public void setNIn(InputType inputType, boolean override) {

View File

@ -41,16 +41,14 @@ import org.deeplearning4j.util.ValidationUtils;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@Data
@NoArgsConstructor
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder")
/** /**
* ConvolutionLayer nIn in the input layer is the number of channels nOut is the number of filters * ConvolutionLayer nIn in the input layer is the number of channels nOut is the number of filters
* to be used in the net or in other words the channels The builder specifies the filter/kernel * to be used in the net or in other words the channels The builder specifies the filter/kernel
* size, the stride and padding The pooling layer takes the kernel size * size, the stride and padding The pooling layer takes the kernel size
*/ */
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder")
public class ConvolutionLayer extends FeedForwardLayer { public class ConvolutionLayer extends FeedForwardLayer {
/** /**
* Size of the convolution rows/columns * Size of the convolution rows/columns

View File

@ -38,8 +38,7 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
/** Dense Layer Uses WeightInitXavier as default */ /** Dense Layer Uses WeightInitXavier as default */
@Data
@NoArgsConstructor
@ToString(callSuper = true) @ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder( @SuperBuilder(

View File

@ -20,10 +20,7 @@
package org.deeplearning4j.nn.conf.layers; package org.deeplearning4j.nn.conf.layers;
import lombok.Data; import lombok.*;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import lombok.ToString;
import lombok.experimental.SuperBuilder; import lombok.experimental.SuperBuilder;
import net.brutex.ai.dnn.api.LayerType; import net.brutex.ai.dnn.api.LayerType;
import org.deeplearning4j.nn.conf.DataFormat; import org.deeplearning4j.nn.conf.DataFormat;
@ -33,30 +30,29 @@ import org.deeplearning4j.nn.conf.preprocessor.Cnn3DToFeedForwardPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor; import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor; import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor;
@Data
@NoArgsConstructor
@ToString(callSuper = true) @ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder @SuperBuilder
public abstract class FeedForwardLayer extends BaseLayerConfiguration { public abstract class FeedForwardLayer extends BaseLayerConfiguration {
/** /**
* Number of inputs for the layer (usually the size of the last layer). <br> Note that for Convolutional layers, * Number of inputs for the layer (usually the size of the last layer). <br> Note that for Convolutional layers,
* this is the input channels, otherwise is the previous layer size. * this is the input channels, otherwise is the previous layer size.
* *
*/ */
@Getter
protected long nIn; protected long nIn;
/** /**
* Number of inputs for the layer (usually the size of the last layer). <br> Note that for Convolutional layers, * Number of inputs for the layer (usually the size of the last layer). <br> Note that for Convolutional layers,
* this is the input channels, otherwise is the previous layer size. * this is the input channels, otherwise is the previous layer size.
* *
*/ */
@Getter
protected long nOut; protected long nOut;
protected DataFormat timeDistributedFormat; protected DataFormat timeDistributedFormat;
{ //Initializer block //
setType(LayerType.FC); // { //Initializer block
} // setType(LayerType.FC);
//}
@Override @Override
public InputType getOutputType(int layerIndex, InputType inputType) { public InputType getOutputType(int layerIndex, InputType inputType) {
@ -129,4 +125,5 @@ public abstract class FeedForwardLayer extends BaseLayerConfiguration {
public boolean isPretrainParam(String paramName) { public boolean isPretrainParam(String paramName) {
return false; //No pretrain params in standard FF layers return false; //No pretrain params in standard FF layers
} }
} }

View File

@ -20,6 +20,10 @@
package org.deeplearning4j.nn.conf.layers; package org.deeplearning4j.nn.conf.layers;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Map;
import lombok.*; import lombok.*;
import lombok.experimental.SuperBuilder; import lombok.experimental.SuperBuilder;
import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Layer;
@ -36,33 +40,17 @@ import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Map;
@Deprecated @Deprecated
@Data
@NoArgsConstructor
@ToString(callSuper = true) @ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder(buildMethodName = "initBuild") @SuperBuilder(buildMethodName = "initBuild")
public class GravesLSTM extends AbstractLSTM { public class GravesLSTM extends AbstractLSTM {
public static abstract class GravesLSTMBuilder<C extends GravesLSTM, B extends GravesLSTMBuilder<C, B>> extends AbstractLSTMBuilder<C, B> {
public C build() {
C l = initBuild();
l.initializeConstraints();
return l;
}
}
private double forgetGateBiasInit; private double forgetGateBiasInit;
@Builder.Default @Builder.Default @Getter private IActivation gateActivationFunction = new ActivationSigmoid();
private IActivation gateActivationFunction = new ActivationSigmoid();
@Override @Override
protected void initializeConstraints( ) { protected void initializeConstraints() {
super.initializeConstraints(); super.initializeConstraints();
if (getRecurrentConstraints() != null) { if (getRecurrentConstraints() != null) {
if (constraints == null) { if (constraints == null) {
@ -77,8 +65,13 @@ public class GravesLSTM extends AbstractLSTM {
} }
@Override @Override
public Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> trainingListeners, public Layer instantiate(
int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { NeuralNetConfiguration conf,
Collection<TrainingListener> trainingListeners,
int layerIndex,
INDArray layerParamsView,
boolean initializeParams,
DataType networkDataType) {
LayerValidation.assertNInNOutSet("GravesLSTM", getName(), layerIndex, getNIn(), getNOut()); LayerValidation.assertNInNOutSet("GravesLSTM", getName(), layerIndex, getNIn(), getNOut());
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
@ -104,9 +97,17 @@ public class GravesLSTM extends AbstractLSTM {
@Override @Override
public LayerMemoryReport getMemoryReport(InputType inputType) { public LayerMemoryReport getMemoryReport(InputType inputType) {
//TODO - CuDNN etc // TODO - CuDNN etc
return LSTMHelpers.getMemoryReport(this, inputType); return LSTMHelpers.getMemoryReport(this, inputType);
} }
public abstract static class GravesLSTMBuilder<
C extends GravesLSTM, B extends GravesLSTMBuilder<C, B>>
extends AbstractLSTMBuilder<C, B> {
public C build() {
C l = initBuild();
l.initializeConstraints();
return l;
}
}
} }

View File

@ -42,7 +42,6 @@ import org.deeplearning4j.nn.conf.weightnoise.IWeightNoise;
import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationIdentity;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.learning.config.IUpdater;
@ -50,8 +49,6 @@ import org.nd4j.linalg.learning.regularization.Regularization;
/** A neural network layer. */ /** A neural network layer. */
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
@Data
@NoArgsConstructor
@EqualsAndHashCode @EqualsAndHashCode
// @JsonIdentityInfo(generator= ObjectIdGenerators.IntSequenceGenerator.class, property="@id") // @JsonIdentityInfo(generator= ObjectIdGenerators.IntSequenceGenerator.class, property="@id")
@Slf4j @Slf4j
@ -59,18 +56,16 @@ import org.nd4j.linalg.learning.regularization.Regularization;
public abstract class LayerConfiguration public abstract class LayerConfiguration
implements ILayerConfiguration, Serializable, Cloneable { // ITrainableLayerConfiguration implements ILayerConfiguration, Serializable, Cloneable { // ITrainableLayerConfiguration
protected String name; @Getter @Setter protected String name;
protected List<LayerConstraint> allParamConstraints; @Getter protected List<LayerConstraint> allParamConstraints;
protected List<LayerConstraint> weightConstraints; @Getter protected List<LayerConstraint> weightConstraints;
protected List<LayerConstraint> biasConstraints; @Getter protected List<LayerConstraint> biasConstraints;
protected List<LayerConstraint> constraints; @Getter protected List<LayerConstraint> constraints;
protected IWeightNoise weightNoise; @Getter @Setter protected IWeightNoise weightNoise;
@Builder.Default @Builder.Default private @Getter @Setter LinkedHashSet<String> variables = new LinkedHashSet<>();
private @Getter @Setter LinkedHashSet<String> variables = new LinkedHashSet<>(); @Getter @Setter private IDropout dropOut;
private IDropout dropOut;
/** The type of the layer, basically defines the base class and its properties */ /** The type of the layer, basically defines the base class and its properties */
@Builder.Default @Builder.Default @Getter @Setter @NonNull private LayerType type = LayerType.UNKNOWN;
@Getter @Setter @NonNull private LayerType type = LayerType.UNKNOWN;
/** /**
* A reference to the neural net configuration. This field is excluded from json serialization as * A reference to the neural net configuration. This field is excluded from json serialization as
* well as from equals check to avoid circular referenced. * well as from equals check to avoid circular referenced.
@ -87,8 +82,7 @@ public abstract class LayerConfiguration
* From an Activation, we can derive the IActivation (function) using {@link * From an Activation, we can derive the IActivation (function) using {@link
* Activation#getActivationFunction()} but not vice versa. The default is Identity Activation. * Activation#getActivationFunction()} but not vice versa. The default is Identity Activation.
*/ */
@Builder.Default @Builder.Default @Getter @Setter private IActivation activation = Activation.IDENTITY;
@Getter @Setter private IActivation activation = Activation.IDENTITY;
/** /**
* Get the activation interface (function) from the activation. The activation must have been set * Get the activation interface (function) from the activation. The activation must have been set
@ -293,9 +287,7 @@ public abstract class LayerConfiguration
public void setIUpdater(IUpdater iUpdater) { public void setIUpdater(IUpdater iUpdater) {
log.warn( log.warn(
"Setting an IUpdater on {} with name {} has no effect.", "Setting an IUpdater on {} with name {} has no effect.", getClass().getSimpleName(), name);
getClass().getSimpleName(),
getName());
} }
/** /**
@ -333,12 +325,29 @@ public abstract class LayerConfiguration
runInheritance(getNetConfiguration()); runInheritance(getNetConfiguration());
} }
public static abstract class LayerConfigurationBuilder<C extends LayerConfiguration, B extends LayerConfigurationBuilder<C, B>> { public abstract static class LayerConfigurationBuilder<
C extends LayerConfiguration, B extends LayerConfigurationBuilder<C, B>> {
private String name;
private List<LayerConstraint> allParamConstraints;
private List<LayerConstraint> weightConstraints;
private List<LayerConstraint> biasConstraints;
private List<LayerConstraint> constraints;
private IWeightNoise weightNoise;
private LinkedHashSet<String> variables$value;
private boolean variables$set;
private IDropout dropOut;
private @NonNull LayerType type$value;
private boolean type$set;
private NeuralNetConfiguration netConfiguration;
private IActivation activation$value;
private boolean activation$set;
public B activation(Activation activation) { public B activation(Activation activation) {
this.activation$value = activation; this.activation$value = activation;
this.activation$set = true; this.activation$set = true;
return self(); return self();
} }
public B activation(IActivation activation) { public B activation(IActivation activation) {
this.activation$value = activation; this.activation$value = activation;
this.activation$set = true; this.activation$set = true;
@ -349,6 +358,7 @@ public abstract class LayerConfiguration
this.dropOut = new Dropout(d); this.dropOut = new Dropout(d);
return self(); return self();
} }
public B dropOut(IDropout d) { public B dropOut(IDropout d) {
this.dropOut = d; this.dropOut = d;
return self(); return self();
@ -361,6 +371,95 @@ public abstract class LayerConfiguration
public B constrainWeights(LayerConstraint constraint) { public B constrainWeights(LayerConstraint constraint) {
return this.weightConstraints(List.of(constraint)); return this.weightConstraints(List.of(constraint));
} }
public B name(String name) {
this.name = name;
return self();
} }
public B allParamConstraints(List<LayerConstraint> allParamConstraints) {
this.allParamConstraints = allParamConstraints;
return self();
}
public B weightConstraints(List<LayerConstraint> weightConstraints) {
this.weightConstraints = weightConstraints;
return self();
}
public B biasConstraints(List<LayerConstraint> biasConstraints) {
this.biasConstraints = biasConstraints;
return self();
}
public B constraints(List<LayerConstraint> constraints) {
this.constraints = constraints;
return self();
}
public B weightNoise(IWeightNoise weightNoise) {
this.weightNoise = weightNoise;
return self();
}
public B variables(LinkedHashSet<String> variables) {
this.variables$value = variables;
this.variables$set = true;
return self();
}
public B type(@NonNull LayerType type) {
this.type$value = type;
this.type$set = true;
return self();
}
@JsonIgnore
public B netConfiguration(NeuralNetConfiguration netConfiguration) {
this.netConfiguration = netConfiguration;
return self();
}
protected abstract B self();
public abstract C build();
public String toString() {
return "LayerConfiguration.LayerConfigurationBuilder(name="
+ this.name
+ ", allParamConstraints="
+ this.allParamConstraints
+ ", weightConstraints="
+ this.weightConstraints
+ ", biasConstraints="
+ this.biasConstraints
+ ", constraints="
+ this.constraints
+ ", weightNoise="
+ this.weightNoise
+ ", variables$value="
+ this.variables$value
+ ", variables$set="
+ this.variables$set
+ ", dropOut="
+ this.dropOut
+ ", type$value="
+ this.type$value
+ ", type$set="
+ this.type$set
+ ", netConfiguration="
+ this.netConfiguration
+ ", activation$value="
+ this.activation$value
+ ", activation$set="
+ this.activation$set
+ ", variables$value="
+ this.variables$value
+ ", type$value="
+ this.type$value
+ ", activation$value="
+ this.activation$value
+ ")";
}
}
} }

View File

@ -21,7 +21,6 @@
package org.deeplearning4j.nn.conf.layers; package org.deeplearning4j.nn.conf.layers;
import java.util.List; import java.util.List;
import lombok.NoArgsConstructor;
import lombok.experimental.SuperBuilder; import lombok.experimental.SuperBuilder;
import net.brutex.ai.dnn.api.LayerType; import net.brutex.ai.dnn.api.LayerType;
import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.api.ParamInitializer;
@ -31,10 +30,8 @@ import org.deeplearning4j.nn.params.EmptyParamInitializer;
import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.regularization.Regularization; import org.nd4j.linalg.learning.regularization.Regularization;
@NoArgsConstructor
@SuperBuilder @SuperBuilder
public abstract class NoParamLayer extends LayerConfiguration { public abstract class NoParamLayer extends LayerConfiguration {
{ {
setType(LayerType.POOL); setType(LayerType.POOL);
} }
@ -68,4 +65,8 @@ public abstract class NoParamLayer extends LayerConfiguration {
public IUpdater getIUpdater() { public IUpdater getIUpdater() {
return Updater.NONE.getIUpdaterWithDefaultConfig(); return Updater.NONE.getIUpdaterWithDefaultConfig();
} }
public static abstract class NoParamLayerBuilder<C extends NoParamLayer, B extends NoParamLayerBuilder<C,B>>
extends LayerConfigurationBuilder<C,B>
{}
} }

View File

@ -48,10 +48,8 @@ import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.learning.regularization.WeightDecay; import org.nd4j.linalg.learning.regularization.WeightDecay;
@Slf4j @Slf4j
@Data
@EqualsAndHashCode(callSuper = true, doNotUseGetters = true) @EqualsAndHashCode(callSuper = true, doNotUseGetters = true)
@SuperBuilder @SuperBuilder
@NoArgsConstructor
public abstract class AbstractSameDiffLayer extends LayerConfiguration { public abstract class AbstractSameDiffLayer extends LayerConfiguration {
/** /**
@ -245,9 +243,10 @@ public abstract class AbstractSameDiffLayer extends LayerConfiguration {
} }
} }
public static abstract class AbstractSameDiffLayerBuilder< public abstract static class AbstractSameDiffLayerBuilder<
C extends AbstractSameDiffLayer, B extends AbstractSameDiffLayerBuilder<C, B>> C extends AbstractSameDiffLayer, B extends AbstractSameDiffLayerBuilder<C, B>>
extends LayerConfigurationBuilder<C, B> { extends LayerConfigurationBuilder<C, B> {
/** /**
* L1 regularization coefficient (weights only). Use {@link #l1Bias(double)} to configure the l1 * L1 regularization coefficient (weights only). Use {@link #l1Bias(double)} to configure the l1
* regularization coefficient for the bias. * regularization coefficient for the bias.

View File

@ -20,7 +20,7 @@
package org.deeplearning4j.nn.conf.layers.samediff; package org.deeplearning4j.nn.conf.layers.samediff;
import lombok.*; import lombok.EqualsAndHashCode;
import lombok.experimental.SuperBuilder; import lombok.experimental.SuperBuilder;
import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.api.MaskState;
@ -31,15 +31,15 @@ import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.primitives.Pair;
import java.util.Collection; import java.util.Collection;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
@Data
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder @SuperBuilder
public abstract class SameDiffLayer extends AbstractSameDiffLayer { public abstract class SameDiffLayer extends AbstractSameDiffLayer {
@ -47,14 +47,9 @@ public abstract class SameDiffLayer extends AbstractSameDiffLayer {
/** /**
* WeightInit, default is XAVIER. * WeightInit, default is XAVIER.
*/ */
@Builder.Default
protected WeightInit weightInit = WeightInit.XAVIER; protected WeightInit weightInit = WeightInit.XAVIER;
@Builder.Default
protected Map<String,IWeightInit> paramWeightInit = new HashMap<>(); protected Map<String,IWeightInit> paramWeightInit = new HashMap<>();
protected SameDiffLayer() {
//No op constructor for Jackson
}
/** /**
* Define the layer * Define the layer
@ -100,6 +95,7 @@ public abstract class SameDiffLayer extends AbstractSameDiffLayer {
return ret; return ret;
} }
public static abstract class SameDiffLayerBuilder<C extends SameDiffLayer, B extends SameDiffLayerBuilder<C, B>> extends AbstractSameDiffLayerBuilder<C,B> {
}
} }

View File

@ -21,8 +21,8 @@
package org.deeplearning4j.nn.conf.layers.wrapper; package org.deeplearning4j.nn.conf.layers.wrapper;
import java.util.List; import java.util.List;
import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.experimental.SuperBuilder; import lombok.experimental.SuperBuilder;
import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.InputPreProcessor;
@ -37,24 +37,12 @@ import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.regularization.Regularization; import org.nd4j.linalg.learning.regularization.Regularization;
@Data
@EqualsAndHashCode(callSuper = false) @EqualsAndHashCode(callSuper = false)
@SuperBuilder @SuperBuilder(builderMethodName = "innerBuilder")
public abstract class BaseWrapperLayerConfiguration extends LayerConfiguration { public abstract class BaseWrapperLayerConfiguration extends LayerConfiguration {
/** /** The configuration to of another layer to wrap */
* The configuration to of another layer to wrap @Getter protected LayerConfiguration underlying;
*/
protected LayerConfiguration underlying;
protected BaseWrapperLayerConfiguration() {
}
public BaseWrapperLayerConfiguration(LayerConfiguration underlying) {
this.underlying = underlying;
this.setNetConfiguration(underlying.getNetConfiguration());
}
/** /**
* Set the net configuration for this configuration as well as for the underlying layer (if not * Set the net configuration for this configuration as well as for the underlying layer (if not
@ -67,7 +55,7 @@ public abstract class BaseWrapperLayerConfiguration extends LayerConfiguration {
super.setNetConfiguration(netConfiguration); super.setNetConfiguration(netConfiguration);
if (underlying.getNetConfiguration() == null) { if (underlying.getNetConfiguration() == null) {
underlying.setNetConfiguration( underlying.setNetConfiguration(
netConfiguration); //also set netconf for underlying if not set netConfiguration); // also set netconf for underlying if not set
} }
} }
@ -87,14 +75,6 @@ public abstract class BaseWrapperLayerConfiguration extends LayerConfiguration {
return underlying.getDropOut(); return underlying.getDropOut();
} }
/**
* @param activationFn
*/
@Override
public void setActivation(IActivation activationFn) {
underlying.setActivation(activationFn);
}
/** /**
* @param iDropout * @param iDropout
*/ */
@ -103,6 +83,14 @@ public abstract class BaseWrapperLayerConfiguration extends LayerConfiguration {
underlying.setDropOut(iDropout); underlying.setDropOut(iDropout);
} }
/**
* @param activationFn
*/
@Override
public void setActivation(IActivation activationFn) {
underlying.setActivation(activationFn);
}
/** /**
* @param weightNoise * @param weightNoise
*/ */
@ -131,14 +119,6 @@ public abstract class BaseWrapperLayerConfiguration extends LayerConfiguration {
return underlying.getUpdaterByParam(paramName); return underlying.getUpdaterByParam(paramName);
} }
/**
* @param iUpdater
*/
@Override
public void setIUpdater(IUpdater iUpdater) {
underlying.setIUpdater(iUpdater);
}
/** /**
* @return * @return
*/ */
@ -147,6 +127,14 @@ public abstract class BaseWrapperLayerConfiguration extends LayerConfiguration {
return underlying.getIUpdater(); return underlying.getIUpdater();
} }
/**
* @param iUpdater
*/
@Override
public void setIUpdater(IUpdater iUpdater) {
underlying.setIUpdater(iUpdater);
}
@Override @Override
public ParamInitializer initializer() { public ParamInitializer initializer() {
return WrapperLayerParamInitializer.getInstance(); return WrapperLayerParamInitializer.getInstance();
@ -186,9 +174,8 @@ public abstract class BaseWrapperLayerConfiguration extends LayerConfiguration {
public void setName(String layerName) { public void setName(String layerName) {
super.setName(layerName); super.setName(layerName);
if (underlying != null) { if (underlying != null) {
//May be null at some points during JSON deserialization // May be null at some points during JSON deserialization
underlying.setName(layerName); underlying.setName(layerName);
} }
} }
} }