Using @SuperBuilder for LayerConfigurations

Signed-off-by: brian <brian@brutex.de>
master
Brian Rosenberger 2023-04-25 13:25:23 +02:00
parent 9139940101
commit 55f8486fe3
16 changed files with 257 additions and 196 deletions

View File

@ -27,8 +27,7 @@ import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.activations.impl.ActivationTanH;
@Data
@NoArgsConstructor
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
@SuperBuilder
@ -38,13 +37,13 @@ public abstract class AbstractLSTM extends BaseRecurrentLayer {
* Set forget gate bias initalizations. Values in range 1-5 can potentially help with learning or
* longer-term dependencies.
*/
@lombok.Builder.Default protected double forgetGateBiasInit = 1.0;
@lombok.Builder.Default @Getter protected double forgetGateBiasInit = 1.0;
/**
* When using CuDNN and an error is encountered, should fallback to the non-CuDNN implementatation
* be allowed? If set to false, an exception in CuDNN will be propagated back to the user. If
* false, the built-in (non-CuDNN) implementation for LSTM/GravesLSTM will be used
*/
@lombok.Builder.Default protected boolean helperAllowFallback = true;
@lombok.Builder.Default @Getter protected boolean helperAllowFallback = true;
/**
* Activation function for the LSTM gates. Note: This should be bounded to range 0-1: sigmoid or
* hard sigmoid, for example

View File

@ -135,7 +135,7 @@ public class ActivationLayer extends NoParamLayer {
public static abstract class ActivationLayerBuilder<
C extends ActivationLayer, B extends ActivationLayerBuilder<C, B>>
extends NoParamLayerBuilder<C, B> {
extends NoParamLayer.NoParamLayerBuilder<C, B> {
public C build() {
C l = this.build();
l.initializeConstraints();

View File

@ -36,8 +36,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Collection;
import java.util.Map;
@Data
@NoArgsConstructor
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
@SuperBuilder
@ -47,13 +46,13 @@ public class AutoEncoder extends BasePretrainNetwork {
* Level of corruption - 0.0 (none) to 1.0 (all values corrupted)
*
*/
@lombok.Builder.Default
@lombok.Builder.Default @Getter @Setter
private double corruptionLevel = 3e-1f;
/**
* Autoencoder sparity parameter
*
*/
@lombok.Builder.Default
@lombok.Builder.Default @Getter @Setter
protected double sparsity = 0f;

View File

@ -45,9 +45,7 @@ import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.learning.regularization.WeightDecay;
/** A neural network layer. */
@Data
@EqualsAndHashCode(callSuper = true)
@NoArgsConstructor(force = true)
@SuperBuilder
public abstract class BaseLayerConfiguration extends LayerConfiguration
implements ITraininableLayerConfiguration, Serializable, Cloneable {
@ -62,7 +60,7 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration
*
* @param constraints Constraints to apply to all bias parameters of all layers
*/
@lombok.Builder.Default protected final List<LayerConstraint> biasConstraints = new ArrayList<>();
@lombok.Builder.Default @Getter protected final List<LayerConstraint> biasConstraints = new ArrayList<>();
/**
* Set constraints to be applied to all layers. Default: no constraints.<br>
* Constraints can be used to enforce certain conditions (non-negativity of parameters, max-norm
@ -74,27 +72,33 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration
*
* @param constraints Constraints to apply to all weight parameters of all layers
*/
@lombok.Builder.Default
@lombok.Builder.Default @Getter
protected final List<LayerConstraint> constrainWeights = new ArrayList<>();
/** Weight initialization scheme to use, for initial weight values */
@Getter @Setter
protected IWeightInit weightInit;
/** Bias initialization value, for layers with biases. Defaults to 0 */
@Getter @Setter @Builder.Default
protected double biasInit = 0.0;
/** Gain initialization value, for layers with ILayer Normalization. Defaults to 1 */
@Getter @Setter @Builder.Default
protected double gainInit = 0.0;
/** Regularization for the parameters (excluding biases). */
@Builder.Default protected List<Regularization> regularization = new ArrayList<>();
@Builder.Default @Getter protected List<Regularization> regularization = new ArrayList<>();
/** Regularization for the bias parameters only */
@Builder.Default protected List<Regularization> regularizationBias = new ArrayList<>();
@Builder.Default @Getter
protected List<Regularization> regularizationBias = new ArrayList<>();
/**
* Gradient updater. For example, {@link org.nd4j.linalg.learning.config.Adam} or {@link
* org.nd4j.linalg.learning.config.Nesterovs}
*/
@Getter @Setter
protected IUpdater updater;
/**
* Gradient updater configuration, for the biases only. If not set, biases will use the updater as
* set by {@link #setUpdater(IUpdater)}
*/
@Getter @Setter
protected IUpdater biasUpdater;
/**
* Gradient normalization strategy. Used to specify gradient renormalization, gradient clipping
@ -103,7 +107,7 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration
* @see GradientNormalization
*/
@Builder.Default
protected @Getter GradientNormalization gradientNormalization =
protected @Getter @Setter GradientNormalization gradientNormalization =
GradientNormalization.None; // Clipping, rescale based on l2 norm, etc
/**
* Threshold for gradient normalization, only used for GradientNormalization.ClipL2PerLayer,
@ -113,10 +117,10 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration
* L2 threshold for first two types of clipping, or absolute value threshold for last type of
* clipping.
*/
@Builder.Default
@Builder.Default @Getter @Setter
protected double gradientNormalizationThreshold =
1.0; // Threshold for l2 and element-wise gradient clipping
@Getter @Setter
private DataType dataType;
/**
@ -206,6 +210,7 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration
C extends BaseLayerConfiguration, B extends BaseLayerConfigurationBuilder<C, B>>
extends LayerConfigurationBuilder<C, B> {
/**
* Set weight initialization scheme to random sampling via the specified distribution.
* Equivalent to: {@code .weightInit(new WeightInitDistribution(distribution))}
@ -411,16 +416,5 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration
regularizationBias$set = true;
return self();
}
public B updater(IUpdater updater) {
this.updater = updater;
return self();
}
public B updater(Updater updater) {
this.updater = updater.getIUpdaterWithDefaultConfig();
return self();
}
}
}

View File

@ -26,8 +26,6 @@ import lombok.experimental.SuperBuilder;
import org.deeplearning4j.nn.params.PretrainParamInitializer;
import org.nd4j.linalg.lossfunctions.LossFunctions;
@Data
@NoArgsConstructor
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
@JsonIgnoreProperties("pretrain")

View File

@ -30,8 +30,7 @@ import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.weights.IWeightInit;
@Data
@NoArgsConstructor
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
@SuperBuilder
@ -42,12 +41,13 @@ public abstract class BaseRecurrentLayer extends FeedForwardLayer {
* weight initialization as the layer input weights is also used for the recurrent weights.
*
*/
@Getter
protected IWeightInit weightInitRecurrent;
/**
* Set the format of data expected by the RNN. NCW = [miniBatchSize, size, timeSeriesLength],
* NWC = [miniBatchSize, timeSeriesLength, size]. Defaults to NCW.
*/
@Builder.Default
@Builder.Default @Getter @Setter
protected RNNFormat dataFormat = RNNFormat.NCW;
/**
* Set constraints to be applied to the RNN recurrent weight parameters of this layer. Default: no
@ -55,6 +55,7 @@ public abstract class BaseRecurrentLayer extends FeedForwardLayer {
* max-norm regularization, etc). These constraints are applied at each iteration, after the parameters have
* been updated.
*/
@Getter
protected List<LayerConstraint> recurrentConstraints;
/**
* Set constraints to be applied to the RNN input weight parameters of this layer. Default: no constraints.<br>
@ -62,6 +63,7 @@ public abstract class BaseRecurrentLayer extends FeedForwardLayer {
* etc). These constraints are applied at each iteration, after the parameters have been updated.
*
*/
@Getter
protected List<LayerConstraint> inputWeightConstraints;
@Override
@ -125,6 +127,4 @@ public abstract class BaseRecurrentLayer extends FeedForwardLayer {
return self();
}
}
}

View File

@ -39,7 +39,6 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@Data
@NoArgsConstructor
@EqualsAndHashCode(callSuper = true)
@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder")
public class CapsuleLayer extends SameDiffLayer {
@ -78,11 +77,6 @@ public class CapsuleLayer extends SameDiffLayer {
* @return
*/
@Builder.Default private int routings = 3;
public CapsuleLayer(Builder builder){
}
@Override
public void setNIn(InputType inputType, boolean override) {

View File

@ -41,16 +41,14 @@ import org.deeplearning4j.util.ValidationUtils;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@Data
@NoArgsConstructor
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder")
/**
* ConvolutionLayer nIn in the input layer is the number of channels nOut is the number of filters
* to be used in the net or in other words the channels The builder specifies the filter/kernel
* size, the stride and padding The pooling layer takes the kernel size
*/
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder")
public class ConvolutionLayer extends FeedForwardLayer {
/**
* Size of the convolution rows/columns

View File

@ -38,8 +38,7 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
/** Dense Layer Uses WeightInitXavier as default */
@Data
@NoArgsConstructor
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
@SuperBuilder(

View File

@ -20,10 +20,7 @@
package org.deeplearning4j.nn.conf.layers;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import lombok.ToString;
import lombok.*;
import lombok.experimental.SuperBuilder;
import net.brutex.ai.dnn.api.LayerType;
import org.deeplearning4j.nn.conf.DataFormat;
@ -33,30 +30,29 @@ import org.deeplearning4j.nn.conf.preprocessor.Cnn3DToFeedForwardPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor;
@Data
@NoArgsConstructor
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
@SuperBuilder
public abstract class FeedForwardLayer extends BaseLayerConfiguration {
/**
* Number of inputs for the layer (usually the size of the last layer). <br> Note that for Convolutional layers,
* this is the input channels, otherwise is the previous layer size.
*
*/
@Getter
protected long nIn;
/**
* Number of inputs for the layer (usually the size of the last layer). <br> Note that for Convolutional layers,
* this is the input channels, otherwise is the previous layer size.
*
*/
@Getter
protected long nOut;
protected DataFormat timeDistributedFormat;
{ //Initializer block
setType(LayerType.FC);
}
//
// { //Initializer block
// setType(LayerType.FC);
//}
@Override
public InputType getOutputType(int layerIndex, InputType inputType) {
@ -129,4 +125,5 @@ public abstract class FeedForwardLayer extends BaseLayerConfiguration {
public boolean isPretrainParam(String paramName) {
return false; //No pretrain params in standard FF layers
}
}

View File

@ -20,6 +20,10 @@
package org.deeplearning4j.nn.conf.layers;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Map;
import lombok.*;
import lombok.experimental.SuperBuilder;
import org.deeplearning4j.nn.api.Layer;
@ -36,33 +40,17 @@ import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Map;
@Deprecated
@Data
@NoArgsConstructor
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
@SuperBuilder(buildMethodName = "initBuild")
public class GravesLSTM extends AbstractLSTM {
public static abstract class GravesLSTMBuilder<C extends GravesLSTM, B extends GravesLSTMBuilder<C, B>> extends AbstractLSTMBuilder<C, B> {
public C build() {
C l = initBuild();
l.initializeConstraints();
return l;
}
}
private double forgetGateBiasInit;
@Builder.Default
private IActivation gateActivationFunction = new ActivationSigmoid();
@Builder.Default @Getter private IActivation gateActivationFunction = new ActivationSigmoid();
@Override
protected void initializeConstraints( ) {
protected void initializeConstraints() {
super.initializeConstraints();
if (getRecurrentConstraints() != null) {
if (constraints == null) {
@ -77,8 +65,13 @@ public class GravesLSTM extends AbstractLSTM {
}
@Override
public Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> trainingListeners,
int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
public Layer instantiate(
NeuralNetConfiguration conf,
Collection<TrainingListener> trainingListeners,
int layerIndex,
INDArray layerParamsView,
boolean initializeParams,
DataType networkDataType) {
LayerValidation.assertNInNOutSet("GravesLSTM", getName(), layerIndex, getNIn(), getNOut());
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
@ -104,9 +97,17 @@ public class GravesLSTM extends AbstractLSTM {
@Override
public LayerMemoryReport getMemoryReport(InputType inputType) {
//TODO - CuDNN etc
// TODO - CuDNN etc
return LSTMHelpers.getMemoryReport(this, inputType);
}
public abstract static class GravesLSTMBuilder<
C extends GravesLSTM, B extends GravesLSTMBuilder<C, B>>
extends AbstractLSTMBuilder<C, B> {
public C build() {
C l = initBuild();
l.initializeConstraints();
return l;
}
}
}

View File

@ -42,7 +42,6 @@ import org.deeplearning4j.nn.conf.weightnoise.IWeightNoise;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationIdentity;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.learning.config.IUpdater;
@ -50,8 +49,6 @@ import org.nd4j.linalg.learning.regularization.Regularization;
/** A neural network layer. */
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
@Data
@NoArgsConstructor
@EqualsAndHashCode
// @JsonIdentityInfo(generator= ObjectIdGenerators.IntSequenceGenerator.class, property="@id")
@Slf4j
@ -59,18 +56,16 @@ import org.nd4j.linalg.learning.regularization.Regularization;
public abstract class LayerConfiguration
implements ILayerConfiguration, Serializable, Cloneable { // ITrainableLayerConfiguration
protected String name;
protected List<LayerConstraint> allParamConstraints;
protected List<LayerConstraint> weightConstraints;
protected List<LayerConstraint> biasConstraints;
protected List<LayerConstraint> constraints;
protected IWeightNoise weightNoise;
@Builder.Default
private @Getter @Setter LinkedHashSet<String> variables = new LinkedHashSet<>();
private IDropout dropOut;
@Getter @Setter protected String name;
@Getter protected List<LayerConstraint> allParamConstraints;
@Getter protected List<LayerConstraint> weightConstraints;
@Getter protected List<LayerConstraint> biasConstraints;
@Getter protected List<LayerConstraint> constraints;
@Getter @Setter protected IWeightNoise weightNoise;
@Builder.Default private @Getter @Setter LinkedHashSet<String> variables = new LinkedHashSet<>();
@Getter @Setter private IDropout dropOut;
/** The type of the layer, basically defines the base class and its properties */
@Builder.Default
@Getter @Setter @NonNull private LayerType type = LayerType.UNKNOWN;
@Builder.Default @Getter @Setter @NonNull private LayerType type = LayerType.UNKNOWN;
/**
* A reference to the neural net configuration. This field is excluded from json serialization as
* well as from equals check to avoid circular referenced.
@ -87,8 +82,7 @@ public abstract class LayerConfiguration
* From an Activation, we can derive the IActivation (function) using {@link
* Activation#getActivationFunction()} but not vice versa. The default is Identity Activation.
*/
@Builder.Default
@Getter @Setter private IActivation activation = Activation.IDENTITY;
@Builder.Default @Getter @Setter private IActivation activation = Activation.IDENTITY;
/**
* Get the activation interface (function) from the activation. The activation must have been set
@ -293,9 +287,7 @@ public abstract class LayerConfiguration
public void setIUpdater(IUpdater iUpdater) {
log.warn(
"Setting an IUpdater on {} with name {} has no effect.",
getClass().getSimpleName(),
getName());
"Setting an IUpdater on {} with name {} has no effect.", getClass().getSimpleName(), name);
}
/**
@ -333,12 +325,29 @@ public abstract class LayerConfiguration
runInheritance(getNetConfiguration());
}
public static abstract class LayerConfigurationBuilder<C extends LayerConfiguration, B extends LayerConfigurationBuilder<C, B>> {
public abstract static class LayerConfigurationBuilder<
C extends LayerConfiguration, B extends LayerConfigurationBuilder<C, B>> {
private String name;
private List<LayerConstraint> allParamConstraints;
private List<LayerConstraint> weightConstraints;
private List<LayerConstraint> biasConstraints;
private List<LayerConstraint> constraints;
private IWeightNoise weightNoise;
private LinkedHashSet<String> variables$value;
private boolean variables$set;
private IDropout dropOut;
private @NonNull LayerType type$value;
private boolean type$set;
private NeuralNetConfiguration netConfiguration;
private IActivation activation$value;
private boolean activation$set;
public B activation(Activation activation) {
this.activation$value = activation;
this.activation$set = true;
return self();
}
public B activation(IActivation activation) {
this.activation$value = activation;
this.activation$set = true;
@ -349,6 +358,7 @@ public abstract class LayerConfiguration
this.dropOut = new Dropout(d);
return self();
}
public B dropOut(IDropout d) {
this.dropOut = d;
return self();
@ -361,6 +371,95 @@ public abstract class LayerConfiguration
public B constrainWeights(LayerConstraint constraint) {
return this.weightConstraints(List.of(constraint));
}
public B name(String name) {
this.name = name;
return self();
}
public B allParamConstraints(List<LayerConstraint> allParamConstraints) {
this.allParamConstraints = allParamConstraints;
return self();
}
public B weightConstraints(List<LayerConstraint> weightConstraints) {
this.weightConstraints = weightConstraints;
return self();
}
public B biasConstraints(List<LayerConstraint> biasConstraints) {
this.biasConstraints = biasConstraints;
return self();
}
public B constraints(List<LayerConstraint> constraints) {
this.constraints = constraints;
return self();
}
public B weightNoise(IWeightNoise weightNoise) {
this.weightNoise = weightNoise;
return self();
}
public B variables(LinkedHashSet<String> variables) {
this.variables$value = variables;
this.variables$set = true;
return self();
}
public B type(@NonNull LayerType type) {
this.type$value = type;
this.type$set = true;
return self();
}
@JsonIgnore
public B netConfiguration(NeuralNetConfiguration netConfiguration) {
this.netConfiguration = netConfiguration;
return self();
}
protected abstract B self();
public abstract C build();
public String toString() {
return "LayerConfiguration.LayerConfigurationBuilder(name="
+ this.name
+ ", allParamConstraints="
+ this.allParamConstraints
+ ", weightConstraints="
+ this.weightConstraints
+ ", biasConstraints="
+ this.biasConstraints
+ ", constraints="
+ this.constraints
+ ", weightNoise="
+ this.weightNoise
+ ", variables$value="
+ this.variables$value
+ ", variables$set="
+ this.variables$set
+ ", dropOut="
+ this.dropOut
+ ", type$value="
+ this.type$value
+ ", type$set="
+ this.type$set
+ ", netConfiguration="
+ this.netConfiguration
+ ", activation$value="
+ this.activation$value
+ ", activation$set="
+ this.activation$set
+ ", variables$value="
+ this.variables$value
+ ", type$value="
+ this.type$value
+ ", activation$value="
+ this.activation$value
+ ")";
}
}
}

View File

@ -21,7 +21,6 @@
package org.deeplearning4j.nn.conf.layers;
import java.util.List;
import lombok.NoArgsConstructor;
import lombok.experimental.SuperBuilder;
import net.brutex.ai.dnn.api.LayerType;
import org.deeplearning4j.nn.api.ParamInitializer;
@ -31,10 +30,8 @@ import org.deeplearning4j.nn.params.EmptyParamInitializer;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.regularization.Regularization;
@NoArgsConstructor
@SuperBuilder
public abstract class NoParamLayer extends LayerConfiguration {
{
setType(LayerType.POOL);
}
@ -68,4 +65,8 @@ public abstract class NoParamLayer extends LayerConfiguration {
public IUpdater getIUpdater() {
return Updater.NONE.getIUpdaterWithDefaultConfig();
}
public static abstract class NoParamLayerBuilder<C extends NoParamLayer, B extends NoParamLayerBuilder<C,B>>
extends LayerConfigurationBuilder<C,B>
{}
}

View File

@ -48,10 +48,8 @@ import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.learning.regularization.WeightDecay;
@Slf4j
@Data
@EqualsAndHashCode(callSuper = true, doNotUseGetters = true)
@SuperBuilder
@NoArgsConstructor
public abstract class AbstractSameDiffLayer extends LayerConfiguration {
/**
@ -245,9 +243,10 @@ public abstract class AbstractSameDiffLayer extends LayerConfiguration {
}
}
public static abstract class AbstractSameDiffLayerBuilder<
public abstract static class AbstractSameDiffLayerBuilder<
C extends AbstractSameDiffLayer, B extends AbstractSameDiffLayerBuilder<C, B>>
extends LayerConfigurationBuilder<C, B> {
/**
* L1 regularization coefficient (weights only). Use {@link #l1Bias(double)} to configure the l1
* regularization coefficient for the bias.

View File

@ -20,7 +20,7 @@
package org.deeplearning4j.nn.conf.layers.samediff;
import lombok.*;
import lombok.EqualsAndHashCode;
import lombok.experimental.SuperBuilder;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
@ -31,15 +31,15 @@ import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.primitives.Pair;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
@Data
@EqualsAndHashCode(callSuper = true)
@SuperBuilder
public abstract class SameDiffLayer extends AbstractSameDiffLayer {
@ -47,14 +47,9 @@ public abstract class SameDiffLayer extends AbstractSameDiffLayer {
/**
* WeightInit, default is XAVIER.
*/
@Builder.Default
protected WeightInit weightInit = WeightInit.XAVIER;
@Builder.Default
protected Map<String,IWeightInit> paramWeightInit = new HashMap<>();
protected SameDiffLayer() {
//No op constructor for Jackson
}
/**
* Define the layer
@ -100,6 +95,7 @@ public abstract class SameDiffLayer extends AbstractSameDiffLayer {
return ret;
}
public static abstract class SameDiffLayerBuilder<C extends SameDiffLayer, B extends SameDiffLayerBuilder<C, B>> extends AbstractSameDiffLayerBuilder<C,B> {
}
}

View File

@ -21,8 +21,8 @@
package org.deeplearning4j.nn.conf.layers.wrapper;
import java.util.List;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.experimental.SuperBuilder;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.InputPreProcessor;
@ -37,24 +37,12 @@ import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.regularization.Regularization;
@Data
@EqualsAndHashCode(callSuper = false)
@SuperBuilder
@SuperBuilder(builderMethodName = "innerBuilder")
public abstract class BaseWrapperLayerConfiguration extends LayerConfiguration {
/**
* The configuration to of another layer to wrap
*/
protected LayerConfiguration underlying;
protected BaseWrapperLayerConfiguration() {
}
public BaseWrapperLayerConfiguration(LayerConfiguration underlying) {
this.underlying = underlying;
this.setNetConfiguration(underlying.getNetConfiguration());
}
/** The configuration to of another layer to wrap */
@Getter protected LayerConfiguration underlying;
/**
* Set the net configuration for this configuration as well as for the underlying layer (if not
@ -67,7 +55,7 @@ public abstract class BaseWrapperLayerConfiguration extends LayerConfiguration {
super.setNetConfiguration(netConfiguration);
if (underlying.getNetConfiguration() == null) {
underlying.setNetConfiguration(
netConfiguration); //also set netconf for underlying if not set
netConfiguration); // also set netconf for underlying if not set
}
}
@ -87,14 +75,6 @@ public abstract class BaseWrapperLayerConfiguration extends LayerConfiguration {
return underlying.getDropOut();
}
/**
* @param activationFn
*/
@Override
public void setActivation(IActivation activationFn) {
underlying.setActivation(activationFn);
}
/**
* @param iDropout
*/
@ -103,6 +83,14 @@ public abstract class BaseWrapperLayerConfiguration extends LayerConfiguration {
underlying.setDropOut(iDropout);
}
/**
* @param activationFn
*/
@Override
public void setActivation(IActivation activationFn) {
underlying.setActivation(activationFn);
}
/**
* @param weightNoise
*/
@ -131,14 +119,6 @@ public abstract class BaseWrapperLayerConfiguration extends LayerConfiguration {
return underlying.getUpdaterByParam(paramName);
}
/**
* @param iUpdater
*/
@Override
public void setIUpdater(IUpdater iUpdater) {
underlying.setIUpdater(iUpdater);
}
/**
* @return
*/
@ -147,6 +127,14 @@ public abstract class BaseWrapperLayerConfiguration extends LayerConfiguration {
return underlying.getIUpdater();
}
/**
* @param iUpdater
*/
@Override
public void setIUpdater(IUpdater iUpdater) {
underlying.setIUpdater(iUpdater);
}
@Override
public ParamInitializer initializer() {
return WrapperLayerParamInitializer.getInstance();
@ -186,9 +174,8 @@ public abstract class BaseWrapperLayerConfiguration extends LayerConfiguration {
public void setName(String layerName) {
super.setName(layerName);
if (underlying != null) {
//May be null at some points during JSON deserialization
// May be null at some points during JSON deserialization
underlying.setName(layerName);
}
}
}