Using @SuperBuilder for LayerConfigurations
Signed-off-by: brian <brian@brutex.de>master
parent
9139940101
commit
55f8486fe3
|
@ -27,8 +27,7 @@ import org.nd4j.linalg.activations.IActivation;
|
||||||
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
|
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
|
||||||
import org.nd4j.linalg.activations.impl.ActivationTanH;
|
import org.nd4j.linalg.activations.impl.ActivationTanH;
|
||||||
|
|
||||||
@Data
|
|
||||||
@NoArgsConstructor
|
|
||||||
@ToString(callSuper = true)
|
@ToString(callSuper = true)
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
@SuperBuilder
|
@SuperBuilder
|
||||||
|
@ -38,13 +37,13 @@ public abstract class AbstractLSTM extends BaseRecurrentLayer {
|
||||||
* Set forget gate bias initalizations. Values in range 1-5 can potentially help with learning or
|
* Set forget gate bias initalizations. Values in range 1-5 can potentially help with learning or
|
||||||
* longer-term dependencies.
|
* longer-term dependencies.
|
||||||
*/
|
*/
|
||||||
@lombok.Builder.Default protected double forgetGateBiasInit = 1.0;
|
@lombok.Builder.Default @Getter protected double forgetGateBiasInit = 1.0;
|
||||||
/**
|
/**
|
||||||
* When using CuDNN and an error is encountered, should fallback to the non-CuDNN implementatation
|
* When using CuDNN and an error is encountered, should fallback to the non-CuDNN implementatation
|
||||||
* be allowed? If set to false, an exception in CuDNN will be propagated back to the user. If
|
* be allowed? If set to false, an exception in CuDNN will be propagated back to the user. If
|
||||||
* false, the built-in (non-CuDNN) implementation for LSTM/GravesLSTM will be used
|
* false, the built-in (non-CuDNN) implementation for LSTM/GravesLSTM will be used
|
||||||
*/
|
*/
|
||||||
@lombok.Builder.Default protected boolean helperAllowFallback = true;
|
@lombok.Builder.Default @Getter protected boolean helperAllowFallback = true;
|
||||||
/**
|
/**
|
||||||
* Activation function for the LSTM gates. Note: This should be bounded to range 0-1: sigmoid or
|
* Activation function for the LSTM gates. Note: This should be bounded to range 0-1: sigmoid or
|
||||||
* hard sigmoid, for example
|
* hard sigmoid, for example
|
||||||
|
|
|
@ -135,7 +135,7 @@ public class ActivationLayer extends NoParamLayer {
|
||||||
|
|
||||||
public static abstract class ActivationLayerBuilder<
|
public static abstract class ActivationLayerBuilder<
|
||||||
C extends ActivationLayer, B extends ActivationLayerBuilder<C, B>>
|
C extends ActivationLayer, B extends ActivationLayerBuilder<C, B>>
|
||||||
extends NoParamLayerBuilder<C, B> {
|
extends NoParamLayer.NoParamLayerBuilder<C, B> {
|
||||||
public C build() {
|
public C build() {
|
||||||
C l = this.build();
|
C l = this.build();
|
||||||
l.initializeConstraints();
|
l.initializeConstraints();
|
||||||
|
|
|
@ -36,8 +36,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import java.util.Collection;
|
import java.util.Collection;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
@Data
|
|
||||||
@NoArgsConstructor
|
|
||||||
@ToString(callSuper = true)
|
@ToString(callSuper = true)
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
@SuperBuilder
|
@SuperBuilder
|
||||||
|
@ -47,13 +46,13 @@ public class AutoEncoder extends BasePretrainNetwork {
|
||||||
* Level of corruption - 0.0 (none) to 1.0 (all values corrupted)
|
* Level of corruption - 0.0 (none) to 1.0 (all values corrupted)
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
@lombok.Builder.Default
|
@lombok.Builder.Default @Getter @Setter
|
||||||
private double corruptionLevel = 3e-1f;
|
private double corruptionLevel = 3e-1f;
|
||||||
/**
|
/**
|
||||||
* Autoencoder sparity parameter
|
* Autoencoder sparity parameter
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
@lombok.Builder.Default
|
@lombok.Builder.Default @Getter @Setter
|
||||||
protected double sparsity = 0f;
|
protected double sparsity = 0f;
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -45,9 +45,7 @@ import org.nd4j.linalg.learning.regularization.Regularization;
|
||||||
import org.nd4j.linalg.learning.regularization.WeightDecay;
|
import org.nd4j.linalg.learning.regularization.WeightDecay;
|
||||||
|
|
||||||
/** A neural network layer. */
|
/** A neural network layer. */
|
||||||
@Data
|
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
@NoArgsConstructor(force = true)
|
|
||||||
@SuperBuilder
|
@SuperBuilder
|
||||||
public abstract class BaseLayerConfiguration extends LayerConfiguration
|
public abstract class BaseLayerConfiguration extends LayerConfiguration
|
||||||
implements ITraininableLayerConfiguration, Serializable, Cloneable {
|
implements ITraininableLayerConfiguration, Serializable, Cloneable {
|
||||||
|
@ -62,7 +60,7 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration
|
||||||
*
|
*
|
||||||
* @param constraints Constraints to apply to all bias parameters of all layers
|
* @param constraints Constraints to apply to all bias parameters of all layers
|
||||||
*/
|
*/
|
||||||
@lombok.Builder.Default protected final List<LayerConstraint> biasConstraints = new ArrayList<>();
|
@lombok.Builder.Default @Getter protected final List<LayerConstraint> biasConstraints = new ArrayList<>();
|
||||||
/**
|
/**
|
||||||
* Set constraints to be applied to all layers. Default: no constraints.<br>
|
* Set constraints to be applied to all layers. Default: no constraints.<br>
|
||||||
* Constraints can be used to enforce certain conditions (non-negativity of parameters, max-norm
|
* Constraints can be used to enforce certain conditions (non-negativity of parameters, max-norm
|
||||||
|
@ -74,27 +72,33 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration
|
||||||
*
|
*
|
||||||
* @param constraints Constraints to apply to all weight parameters of all layers
|
* @param constraints Constraints to apply to all weight parameters of all layers
|
||||||
*/
|
*/
|
||||||
@lombok.Builder.Default
|
@lombok.Builder.Default @Getter
|
||||||
protected final List<LayerConstraint> constrainWeights = new ArrayList<>();
|
protected final List<LayerConstraint> constrainWeights = new ArrayList<>();
|
||||||
/** Weight initialization scheme to use, for initial weight values */
|
/** Weight initialization scheme to use, for initial weight values */
|
||||||
|
@Getter @Setter
|
||||||
protected IWeightInit weightInit;
|
protected IWeightInit weightInit;
|
||||||
/** Bias initialization value, for layers with biases. Defaults to 0 */
|
/** Bias initialization value, for layers with biases. Defaults to 0 */
|
||||||
|
@Getter @Setter @Builder.Default
|
||||||
protected double biasInit = 0.0;
|
protected double biasInit = 0.0;
|
||||||
/** Gain initialization value, for layers with ILayer Normalization. Defaults to 1 */
|
/** Gain initialization value, for layers with ILayer Normalization. Defaults to 1 */
|
||||||
|
@Getter @Setter @Builder.Default
|
||||||
protected double gainInit = 0.0;
|
protected double gainInit = 0.0;
|
||||||
/** Regularization for the parameters (excluding biases). */
|
/** Regularization for the parameters (excluding biases). */
|
||||||
@Builder.Default protected List<Regularization> regularization = new ArrayList<>();
|
@Builder.Default @Getter protected List<Regularization> regularization = new ArrayList<>();
|
||||||
/** Regularization for the bias parameters only */
|
/** Regularization for the bias parameters only */
|
||||||
@Builder.Default protected List<Regularization> regularizationBias = new ArrayList<>();
|
@Builder.Default @Getter
|
||||||
|
protected List<Regularization> regularizationBias = new ArrayList<>();
|
||||||
/**
|
/**
|
||||||
* Gradient updater. For example, {@link org.nd4j.linalg.learning.config.Adam} or {@link
|
* Gradient updater. For example, {@link org.nd4j.linalg.learning.config.Adam} or {@link
|
||||||
* org.nd4j.linalg.learning.config.Nesterovs}
|
* org.nd4j.linalg.learning.config.Nesterovs}
|
||||||
*/
|
*/
|
||||||
|
@Getter @Setter
|
||||||
protected IUpdater updater;
|
protected IUpdater updater;
|
||||||
/**
|
/**
|
||||||
* Gradient updater configuration, for the biases only. If not set, biases will use the updater as
|
* Gradient updater configuration, for the biases only. If not set, biases will use the updater as
|
||||||
* set by {@link #setUpdater(IUpdater)}
|
* set by {@link #setUpdater(IUpdater)}
|
||||||
*/
|
*/
|
||||||
|
@Getter @Setter
|
||||||
protected IUpdater biasUpdater;
|
protected IUpdater biasUpdater;
|
||||||
/**
|
/**
|
||||||
* Gradient normalization strategy. Used to specify gradient renormalization, gradient clipping
|
* Gradient normalization strategy. Used to specify gradient renormalization, gradient clipping
|
||||||
|
@ -103,7 +107,7 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration
|
||||||
* @see GradientNormalization
|
* @see GradientNormalization
|
||||||
*/
|
*/
|
||||||
@Builder.Default
|
@Builder.Default
|
||||||
protected @Getter GradientNormalization gradientNormalization =
|
protected @Getter @Setter GradientNormalization gradientNormalization =
|
||||||
GradientNormalization.None; // Clipping, rescale based on l2 norm, etc
|
GradientNormalization.None; // Clipping, rescale based on l2 norm, etc
|
||||||
/**
|
/**
|
||||||
* Threshold for gradient normalization, only used for GradientNormalization.ClipL2PerLayer,
|
* Threshold for gradient normalization, only used for GradientNormalization.ClipL2PerLayer,
|
||||||
|
@ -113,10 +117,10 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration
|
||||||
* L2 threshold for first two types of clipping, or absolute value threshold for last type of
|
* L2 threshold for first two types of clipping, or absolute value threshold for last type of
|
||||||
* clipping.
|
* clipping.
|
||||||
*/
|
*/
|
||||||
@Builder.Default
|
@Builder.Default @Getter @Setter
|
||||||
protected double gradientNormalizationThreshold =
|
protected double gradientNormalizationThreshold =
|
||||||
1.0; // Threshold for l2 and element-wise gradient clipping
|
1.0; // Threshold for l2 and element-wise gradient clipping
|
||||||
|
@Getter @Setter
|
||||||
private DataType dataType;
|
private DataType dataType;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -206,6 +210,7 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration
|
||||||
C extends BaseLayerConfiguration, B extends BaseLayerConfigurationBuilder<C, B>>
|
C extends BaseLayerConfiguration, B extends BaseLayerConfigurationBuilder<C, B>>
|
||||||
extends LayerConfigurationBuilder<C, B> {
|
extends LayerConfigurationBuilder<C, B> {
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Set weight initialization scheme to random sampling via the specified distribution.
|
* Set weight initialization scheme to random sampling via the specified distribution.
|
||||||
* Equivalent to: {@code .weightInit(new WeightInitDistribution(distribution))}
|
* Equivalent to: {@code .weightInit(new WeightInitDistribution(distribution))}
|
||||||
|
@ -411,16 +416,5 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration
|
||||||
regularizationBias$set = true;
|
regularizationBias$set = true;
|
||||||
return self();
|
return self();
|
||||||
}
|
}
|
||||||
|
|
||||||
public B updater(IUpdater updater) {
|
|
||||||
this.updater = updater;
|
|
||||||
return self();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
public B updater(Updater updater) {
|
|
||||||
this.updater = updater.getIUpdaterWithDefaultConfig();
|
|
||||||
return self();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,8 +26,6 @@ import lombok.experimental.SuperBuilder;
|
||||||
import org.deeplearning4j.nn.params.PretrainParamInitializer;
|
import org.deeplearning4j.nn.params.PretrainParamInitializer;
|
||||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
|
|
||||||
@Data
|
|
||||||
@NoArgsConstructor
|
|
||||||
@ToString(callSuper = true)
|
@ToString(callSuper = true)
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
@JsonIgnoreProperties("pretrain")
|
@JsonIgnoreProperties("pretrain")
|
||||||
|
|
|
@ -30,8 +30,7 @@ import org.deeplearning4j.nn.conf.RNNFormat;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||||
|
|
||||||
@Data
|
|
||||||
@NoArgsConstructor
|
|
||||||
@ToString(callSuper = true)
|
@ToString(callSuper = true)
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
@SuperBuilder
|
@SuperBuilder
|
||||||
|
@ -42,12 +41,13 @@ public abstract class BaseRecurrentLayer extends FeedForwardLayer {
|
||||||
* weight initialization as the layer input weights is also used for the recurrent weights.
|
* weight initialization as the layer input weights is also used for the recurrent weights.
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
|
@Getter
|
||||||
protected IWeightInit weightInitRecurrent;
|
protected IWeightInit weightInitRecurrent;
|
||||||
/**
|
/**
|
||||||
* Set the format of data expected by the RNN. NCW = [miniBatchSize, size, timeSeriesLength],
|
* Set the format of data expected by the RNN. NCW = [miniBatchSize, size, timeSeriesLength],
|
||||||
* NWC = [miniBatchSize, timeSeriesLength, size]. Defaults to NCW.
|
* NWC = [miniBatchSize, timeSeriesLength, size]. Defaults to NCW.
|
||||||
*/
|
*/
|
||||||
@Builder.Default
|
@Builder.Default @Getter @Setter
|
||||||
protected RNNFormat dataFormat = RNNFormat.NCW;
|
protected RNNFormat dataFormat = RNNFormat.NCW;
|
||||||
/**
|
/**
|
||||||
* Set constraints to be applied to the RNN recurrent weight parameters of this layer. Default: no
|
* Set constraints to be applied to the RNN recurrent weight parameters of this layer. Default: no
|
||||||
|
@ -55,6 +55,7 @@ public abstract class BaseRecurrentLayer extends FeedForwardLayer {
|
||||||
* max-norm regularization, etc). These constraints are applied at each iteration, after the parameters have
|
* max-norm regularization, etc). These constraints are applied at each iteration, after the parameters have
|
||||||
* been updated.
|
* been updated.
|
||||||
*/
|
*/
|
||||||
|
@Getter
|
||||||
protected List<LayerConstraint> recurrentConstraints;
|
protected List<LayerConstraint> recurrentConstraints;
|
||||||
/**
|
/**
|
||||||
* Set constraints to be applied to the RNN input weight parameters of this layer. Default: no constraints.<br>
|
* Set constraints to be applied to the RNN input weight parameters of this layer. Default: no constraints.<br>
|
||||||
|
@ -62,6 +63,7 @@ public abstract class BaseRecurrentLayer extends FeedForwardLayer {
|
||||||
* etc). These constraints are applied at each iteration, after the parameters have been updated.
|
* etc). These constraints are applied at each iteration, after the parameters have been updated.
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
|
@Getter
|
||||||
protected List<LayerConstraint> inputWeightConstraints;
|
protected List<LayerConstraint> inputWeightConstraints;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -125,6 +127,4 @@ public abstract class BaseRecurrentLayer extends FeedForwardLayer {
|
||||||
return self();
|
return self();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -39,7 +39,6 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@NoArgsConstructor
|
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder")
|
@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder")
|
||||||
public class CapsuleLayer extends SameDiffLayer {
|
public class CapsuleLayer extends SameDiffLayer {
|
||||||
|
@ -78,11 +77,6 @@ public class CapsuleLayer extends SameDiffLayer {
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
@Builder.Default private int routings = 3;
|
@Builder.Default private int routings = 3;
|
||||||
public CapsuleLayer(Builder builder){
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void setNIn(InputType inputType, boolean override) {
|
public void setNIn(InputType inputType, boolean override) {
|
||||||
|
|
|
@ -41,16 +41,14 @@ import org.deeplearning4j.util.ValidationUtils;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
@Data
|
|
||||||
@NoArgsConstructor
|
|
||||||
@ToString(callSuper = true)
|
|
||||||
@EqualsAndHashCode(callSuper = true)
|
|
||||||
@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder")
|
|
||||||
/**
|
/**
|
||||||
* ConvolutionLayer nIn in the input layer is the number of channels nOut is the number of filters
|
* ConvolutionLayer nIn in the input layer is the number of channels nOut is the number of filters
|
||||||
* to be used in the net or in other words the channels The builder specifies the filter/kernel
|
* to be used in the net or in other words the channels The builder specifies the filter/kernel
|
||||||
* size, the stride and padding The pooling layer takes the kernel size
|
* size, the stride and padding The pooling layer takes the kernel size
|
||||||
*/
|
*/
|
||||||
|
@ToString(callSuper = true)
|
||||||
|
@EqualsAndHashCode(callSuper = true)
|
||||||
|
@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder")
|
||||||
public class ConvolutionLayer extends FeedForwardLayer {
|
public class ConvolutionLayer extends FeedForwardLayer {
|
||||||
/**
|
/**
|
||||||
* Size of the convolution rows/columns
|
* Size of the convolution rows/columns
|
||||||
|
|
|
@ -38,8 +38,7 @@ import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
/** Dense Layer Uses WeightInitXavier as default */
|
/** Dense Layer Uses WeightInitXavier as default */
|
||||||
@Data
|
|
||||||
@NoArgsConstructor
|
|
||||||
@ToString(callSuper = true)
|
@ToString(callSuper = true)
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
@SuperBuilder(
|
@SuperBuilder(
|
||||||
|
|
|
@ -20,10 +20,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.layers;
|
package org.deeplearning4j.nn.conf.layers;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.*;
|
||||||
import lombok.EqualsAndHashCode;
|
|
||||||
import lombok.NoArgsConstructor;
|
|
||||||
import lombok.ToString;
|
|
||||||
import lombok.experimental.SuperBuilder;
|
import lombok.experimental.SuperBuilder;
|
||||||
import net.brutex.ai.dnn.api.LayerType;
|
import net.brutex.ai.dnn.api.LayerType;
|
||||||
import org.deeplearning4j.nn.conf.DataFormat;
|
import org.deeplearning4j.nn.conf.DataFormat;
|
||||||
|
@ -33,30 +30,29 @@ import org.deeplearning4j.nn.conf.preprocessor.Cnn3DToFeedForwardPreProcessor;
|
||||||
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
|
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
|
||||||
import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor;
|
import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor;
|
||||||
|
|
||||||
@Data
|
|
||||||
@NoArgsConstructor
|
|
||||||
@ToString(callSuper = true)
|
@ToString(callSuper = true)
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
@SuperBuilder
|
@SuperBuilder
|
||||||
public abstract class FeedForwardLayer extends BaseLayerConfiguration {
|
public abstract class FeedForwardLayer extends BaseLayerConfiguration {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Number of inputs for the layer (usually the size of the last layer). <br> Note that for Convolutional layers,
|
* Number of inputs for the layer (usually the size of the last layer). <br> Note that for Convolutional layers,
|
||||||
* this is the input channels, otherwise is the previous layer size.
|
* this is the input channels, otherwise is the previous layer size.
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
|
@Getter
|
||||||
protected long nIn;
|
protected long nIn;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Number of inputs for the layer (usually the size of the last layer). <br> Note that for Convolutional layers,
|
* Number of inputs for the layer (usually the size of the last layer). <br> Note that for Convolutional layers,
|
||||||
* this is the input channels, otherwise is the previous layer size.
|
* this is the input channels, otherwise is the previous layer size.
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
|
@Getter
|
||||||
protected long nOut;
|
protected long nOut;
|
||||||
protected DataFormat timeDistributedFormat;
|
protected DataFormat timeDistributedFormat;
|
||||||
{ //Initializer block
|
//
|
||||||
setType(LayerType.FC);
|
// { //Initializer block
|
||||||
}
|
// setType(LayerType.FC);
|
||||||
|
//}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public InputType getOutputType(int layerIndex, InputType inputType) {
|
public InputType getOutputType(int layerIndex, InputType inputType) {
|
||||||
|
@ -129,4 +125,5 @@ public abstract class FeedForwardLayer extends BaseLayerConfiguration {
|
||||||
public boolean isPretrainParam(String paramName) {
|
public boolean isPretrainParam(String paramName) {
|
||||||
return false; //No pretrain params in standard FF layers
|
return false; //No pretrain params in standard FF layers
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,6 +20,10 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.layers;
|
package org.deeplearning4j.nn.conf.layers;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collection;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.Map;
|
||||||
import lombok.*;
|
import lombok.*;
|
||||||
import lombok.experimental.SuperBuilder;
|
import lombok.experimental.SuperBuilder;
|
||||||
import org.deeplearning4j.nn.api.Layer;
|
import org.deeplearning4j.nn.api.Layer;
|
||||||
|
@ -36,30 +40,14 @@ import org.nd4j.linalg.activations.impl.ActivationSigmoid;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Collection;
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
@Deprecated
|
@Deprecated
|
||||||
@Data
|
|
||||||
@NoArgsConstructor
|
|
||||||
@ToString(callSuper = true)
|
@ToString(callSuper = true)
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
@SuperBuilder(buildMethodName = "initBuild")
|
@SuperBuilder(buildMethodName = "initBuild")
|
||||||
public class GravesLSTM extends AbstractLSTM {
|
public class GravesLSTM extends AbstractLSTM {
|
||||||
|
|
||||||
public static abstract class GravesLSTMBuilder<C extends GravesLSTM, B extends GravesLSTMBuilder<C, B>> extends AbstractLSTMBuilder<C, B> {
|
|
||||||
public C build() {
|
|
||||||
C l = initBuild();
|
|
||||||
l.initializeConstraints();
|
|
||||||
return l;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
private double forgetGateBiasInit;
|
private double forgetGateBiasInit;
|
||||||
@Builder.Default
|
@Builder.Default @Getter private IActivation gateActivationFunction = new ActivationSigmoid();
|
||||||
private IActivation gateActivationFunction = new ActivationSigmoid();
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void initializeConstraints() {
|
protected void initializeConstraints() {
|
||||||
|
@ -77,8 +65,13 @@ public class GravesLSTM extends AbstractLSTM {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> trainingListeners,
|
public Layer instantiate(
|
||||||
int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
|
NeuralNetConfiguration conf,
|
||||||
|
Collection<TrainingListener> trainingListeners,
|
||||||
|
int layerIndex,
|
||||||
|
INDArray layerParamsView,
|
||||||
|
boolean initializeParams,
|
||||||
|
DataType networkDataType) {
|
||||||
LayerValidation.assertNInNOutSet("GravesLSTM", getName(), layerIndex, getNIn(), getNOut());
|
LayerValidation.assertNInNOutSet("GravesLSTM", getName(), layerIndex, getNIn(), getNOut());
|
||||||
|
|
||||||
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
|
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
|
||||||
|
@ -108,5 +101,13 @@ public class GravesLSTM extends AbstractLSTM {
|
||||||
return LSTMHelpers.getMemoryReport(this, inputType);
|
return LSTMHelpers.getMemoryReport(this, inputType);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public abstract static class GravesLSTMBuilder<
|
||||||
|
C extends GravesLSTM, B extends GravesLSTMBuilder<C, B>>
|
||||||
|
extends AbstractLSTMBuilder<C, B> {
|
||||||
|
public C build() {
|
||||||
|
C l = initBuild();
|
||||||
|
l.initializeConstraints();
|
||||||
|
return l;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -42,7 +42,6 @@ import org.deeplearning4j.nn.conf.weightnoise.IWeightNoise;
|
||||||
import org.deeplearning4j.optimize.api.TrainingListener;
|
import org.deeplearning4j.optimize.api.TrainingListener;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.activations.IActivation;
|
import org.nd4j.linalg.activations.IActivation;
|
||||||
import org.nd4j.linalg.activations.impl.ActivationIdentity;
|
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.learning.config.IUpdater;
|
import org.nd4j.linalg.learning.config.IUpdater;
|
||||||
|
@ -50,8 +49,6 @@ import org.nd4j.linalg.learning.regularization.Regularization;
|
||||||
|
|
||||||
/** A neural network layer. */
|
/** A neural network layer. */
|
||||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
|
||||||
@Data
|
|
||||||
@NoArgsConstructor
|
|
||||||
@EqualsAndHashCode
|
@EqualsAndHashCode
|
||||||
// @JsonIdentityInfo(generator= ObjectIdGenerators.IntSequenceGenerator.class, property="@id")
|
// @JsonIdentityInfo(generator= ObjectIdGenerators.IntSequenceGenerator.class, property="@id")
|
||||||
@Slf4j
|
@Slf4j
|
||||||
|
@ -59,18 +56,16 @@ import org.nd4j.linalg.learning.regularization.Regularization;
|
||||||
public abstract class LayerConfiguration
|
public abstract class LayerConfiguration
|
||||||
implements ILayerConfiguration, Serializable, Cloneable { // ITrainableLayerConfiguration
|
implements ILayerConfiguration, Serializable, Cloneable { // ITrainableLayerConfiguration
|
||||||
|
|
||||||
protected String name;
|
@Getter @Setter protected String name;
|
||||||
protected List<LayerConstraint> allParamConstraints;
|
@Getter protected List<LayerConstraint> allParamConstraints;
|
||||||
protected List<LayerConstraint> weightConstraints;
|
@Getter protected List<LayerConstraint> weightConstraints;
|
||||||
protected List<LayerConstraint> biasConstraints;
|
@Getter protected List<LayerConstraint> biasConstraints;
|
||||||
protected List<LayerConstraint> constraints;
|
@Getter protected List<LayerConstraint> constraints;
|
||||||
protected IWeightNoise weightNoise;
|
@Getter @Setter protected IWeightNoise weightNoise;
|
||||||
@Builder.Default
|
@Builder.Default private @Getter @Setter LinkedHashSet<String> variables = new LinkedHashSet<>();
|
||||||
private @Getter @Setter LinkedHashSet<String> variables = new LinkedHashSet<>();
|
@Getter @Setter private IDropout dropOut;
|
||||||
private IDropout dropOut;
|
|
||||||
/** The type of the layer, basically defines the base class and its properties */
|
/** The type of the layer, basically defines the base class and its properties */
|
||||||
@Builder.Default
|
@Builder.Default @Getter @Setter @NonNull private LayerType type = LayerType.UNKNOWN;
|
||||||
@Getter @Setter @NonNull private LayerType type = LayerType.UNKNOWN;
|
|
||||||
/**
|
/**
|
||||||
* A reference to the neural net configuration. This field is excluded from json serialization as
|
* A reference to the neural net configuration. This field is excluded from json serialization as
|
||||||
* well as from equals check to avoid circular referenced.
|
* well as from equals check to avoid circular referenced.
|
||||||
|
@ -87,8 +82,7 @@ public abstract class LayerConfiguration
|
||||||
* From an Activation, we can derive the IActivation (function) using {@link
|
* From an Activation, we can derive the IActivation (function) using {@link
|
||||||
* Activation#getActivationFunction()} but not vice versa. The default is Identity Activation.
|
* Activation#getActivationFunction()} but not vice versa. The default is Identity Activation.
|
||||||
*/
|
*/
|
||||||
@Builder.Default
|
@Builder.Default @Getter @Setter private IActivation activation = Activation.IDENTITY;
|
||||||
@Getter @Setter private IActivation activation = Activation.IDENTITY;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the activation interface (function) from the activation. The activation must have been set
|
* Get the activation interface (function) from the activation. The activation must have been set
|
||||||
|
@ -293,9 +287,7 @@ public abstract class LayerConfiguration
|
||||||
|
|
||||||
public void setIUpdater(IUpdater iUpdater) {
|
public void setIUpdater(IUpdater iUpdater) {
|
||||||
log.warn(
|
log.warn(
|
||||||
"Setting an IUpdater on {} with name {} has no effect.",
|
"Setting an IUpdater on {} with name {} has no effect.", getClass().getSimpleName(), name);
|
||||||
getClass().getSimpleName(),
|
|
||||||
getName());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -333,12 +325,29 @@ public abstract class LayerConfiguration
|
||||||
runInheritance(getNetConfiguration());
|
runInheritance(getNetConfiguration());
|
||||||
}
|
}
|
||||||
|
|
||||||
public static abstract class LayerConfigurationBuilder<C extends LayerConfiguration, B extends LayerConfigurationBuilder<C, B>> {
|
public abstract static class LayerConfigurationBuilder<
|
||||||
|
C extends LayerConfiguration, B extends LayerConfigurationBuilder<C, B>> {
|
||||||
|
private String name;
|
||||||
|
private List<LayerConstraint> allParamConstraints;
|
||||||
|
private List<LayerConstraint> weightConstraints;
|
||||||
|
private List<LayerConstraint> biasConstraints;
|
||||||
|
private List<LayerConstraint> constraints;
|
||||||
|
private IWeightNoise weightNoise;
|
||||||
|
private LinkedHashSet<String> variables$value;
|
||||||
|
private boolean variables$set;
|
||||||
|
private IDropout dropOut;
|
||||||
|
private @NonNull LayerType type$value;
|
||||||
|
private boolean type$set;
|
||||||
|
private NeuralNetConfiguration netConfiguration;
|
||||||
|
private IActivation activation$value;
|
||||||
|
private boolean activation$set;
|
||||||
|
|
||||||
public B activation(Activation activation) {
|
public B activation(Activation activation) {
|
||||||
this.activation$value = activation;
|
this.activation$value = activation;
|
||||||
this.activation$set = true;
|
this.activation$set = true;
|
||||||
return self();
|
return self();
|
||||||
}
|
}
|
||||||
|
|
||||||
public B activation(IActivation activation) {
|
public B activation(IActivation activation) {
|
||||||
this.activation$value = activation;
|
this.activation$value = activation;
|
||||||
this.activation$set = true;
|
this.activation$set = true;
|
||||||
|
@ -349,6 +358,7 @@ public abstract class LayerConfiguration
|
||||||
this.dropOut = new Dropout(d);
|
this.dropOut = new Dropout(d);
|
||||||
return self();
|
return self();
|
||||||
}
|
}
|
||||||
|
|
||||||
public B dropOut(IDropout d) {
|
public B dropOut(IDropout d) {
|
||||||
this.dropOut = d;
|
this.dropOut = d;
|
||||||
return self();
|
return self();
|
||||||
|
@ -361,6 +371,95 @@ public abstract class LayerConfiguration
|
||||||
public B constrainWeights(LayerConstraint constraint) {
|
public B constrainWeights(LayerConstraint constraint) {
|
||||||
return this.weightConstraints(List.of(constraint));
|
return this.weightConstraints(List.of(constraint));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public B name(String name) {
|
||||||
|
this.name = name;
|
||||||
|
return self();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public B allParamConstraints(List<LayerConstraint> allParamConstraints) {
|
||||||
|
this.allParamConstraints = allParamConstraints;
|
||||||
|
return self();
|
||||||
|
}
|
||||||
|
|
||||||
|
public B weightConstraints(List<LayerConstraint> weightConstraints) {
|
||||||
|
this.weightConstraints = weightConstraints;
|
||||||
|
return self();
|
||||||
|
}
|
||||||
|
|
||||||
|
public B biasConstraints(List<LayerConstraint> biasConstraints) {
|
||||||
|
this.biasConstraints = biasConstraints;
|
||||||
|
return self();
|
||||||
|
}
|
||||||
|
|
||||||
|
public B constraints(List<LayerConstraint> constraints) {
|
||||||
|
this.constraints = constraints;
|
||||||
|
return self();
|
||||||
|
}
|
||||||
|
|
||||||
|
public B weightNoise(IWeightNoise weightNoise) {
|
||||||
|
this.weightNoise = weightNoise;
|
||||||
|
return self();
|
||||||
|
}
|
||||||
|
|
||||||
|
public B variables(LinkedHashSet<String> variables) {
|
||||||
|
this.variables$value = variables;
|
||||||
|
this.variables$set = true;
|
||||||
|
return self();
|
||||||
|
}
|
||||||
|
|
||||||
|
public B type(@NonNull LayerType type) {
|
||||||
|
this.type$value = type;
|
||||||
|
this.type$set = true;
|
||||||
|
return self();
|
||||||
|
}
|
||||||
|
|
||||||
|
@JsonIgnore
|
||||||
|
public B netConfiguration(NeuralNetConfiguration netConfiguration) {
|
||||||
|
this.netConfiguration = netConfiguration;
|
||||||
|
return self();
|
||||||
|
}
|
||||||
|
|
||||||
|
protected abstract B self();
|
||||||
|
|
||||||
|
public abstract C build();
|
||||||
|
|
||||||
|
public String toString() {
|
||||||
|
return "LayerConfiguration.LayerConfigurationBuilder(name="
|
||||||
|
+ this.name
|
||||||
|
+ ", allParamConstraints="
|
||||||
|
+ this.allParamConstraints
|
||||||
|
+ ", weightConstraints="
|
||||||
|
+ this.weightConstraints
|
||||||
|
+ ", biasConstraints="
|
||||||
|
+ this.biasConstraints
|
||||||
|
+ ", constraints="
|
||||||
|
+ this.constraints
|
||||||
|
+ ", weightNoise="
|
||||||
|
+ this.weightNoise
|
||||||
|
+ ", variables$value="
|
||||||
|
+ this.variables$value
|
||||||
|
+ ", variables$set="
|
||||||
|
+ this.variables$set
|
||||||
|
+ ", dropOut="
|
||||||
|
+ this.dropOut
|
||||||
|
+ ", type$value="
|
||||||
|
+ this.type$value
|
||||||
|
+ ", type$set="
|
||||||
|
+ this.type$set
|
||||||
|
+ ", netConfiguration="
|
||||||
|
+ this.netConfiguration
|
||||||
|
+ ", activation$value="
|
||||||
|
+ this.activation$value
|
||||||
|
+ ", activation$set="
|
||||||
|
+ this.activation$set
|
||||||
|
+ ", variables$value="
|
||||||
|
+ this.variables$value
|
||||||
|
+ ", type$value="
|
||||||
|
+ this.type$value
|
||||||
|
+ ", activation$value="
|
||||||
|
+ this.activation$value
|
||||||
|
+ ")";
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,7 +21,6 @@
|
||||||
package org.deeplearning4j.nn.conf.layers;
|
package org.deeplearning4j.nn.conf.layers;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import lombok.NoArgsConstructor;
|
|
||||||
import lombok.experimental.SuperBuilder;
|
import lombok.experimental.SuperBuilder;
|
||||||
import net.brutex.ai.dnn.api.LayerType;
|
import net.brutex.ai.dnn.api.LayerType;
|
||||||
import org.deeplearning4j.nn.api.ParamInitializer;
|
import org.deeplearning4j.nn.api.ParamInitializer;
|
||||||
|
@ -31,10 +30,8 @@ import org.deeplearning4j.nn.params.EmptyParamInitializer;
|
||||||
import org.nd4j.linalg.learning.config.IUpdater;
|
import org.nd4j.linalg.learning.config.IUpdater;
|
||||||
import org.nd4j.linalg.learning.regularization.Regularization;
|
import org.nd4j.linalg.learning.regularization.Regularization;
|
||||||
|
|
||||||
@NoArgsConstructor
|
|
||||||
@SuperBuilder
|
@SuperBuilder
|
||||||
public abstract class NoParamLayer extends LayerConfiguration {
|
public abstract class NoParamLayer extends LayerConfiguration {
|
||||||
|
|
||||||
{
|
{
|
||||||
setType(LayerType.POOL);
|
setType(LayerType.POOL);
|
||||||
}
|
}
|
||||||
|
@ -68,4 +65,8 @@ public abstract class NoParamLayer extends LayerConfiguration {
|
||||||
public IUpdater getIUpdater() {
|
public IUpdater getIUpdater() {
|
||||||
return Updater.NONE.getIUpdaterWithDefaultConfig();
|
return Updater.NONE.getIUpdaterWithDefaultConfig();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static abstract class NoParamLayerBuilder<C extends NoParamLayer, B extends NoParamLayerBuilder<C,B>>
|
||||||
|
extends LayerConfigurationBuilder<C,B>
|
||||||
|
{}
|
||||||
}
|
}
|
||||||
|
|
|
@ -48,10 +48,8 @@ import org.nd4j.linalg.learning.regularization.Regularization;
|
||||||
import org.nd4j.linalg.learning.regularization.WeightDecay;
|
import org.nd4j.linalg.learning.regularization.WeightDecay;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Data
|
|
||||||
@EqualsAndHashCode(callSuper = true, doNotUseGetters = true)
|
@EqualsAndHashCode(callSuper = true, doNotUseGetters = true)
|
||||||
@SuperBuilder
|
@SuperBuilder
|
||||||
@NoArgsConstructor
|
|
||||||
public abstract class AbstractSameDiffLayer extends LayerConfiguration {
|
public abstract class AbstractSameDiffLayer extends LayerConfiguration {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -245,9 +243,10 @@ public abstract class AbstractSameDiffLayer extends LayerConfiguration {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static abstract class AbstractSameDiffLayerBuilder<
|
public abstract static class AbstractSameDiffLayerBuilder<
|
||||||
C extends AbstractSameDiffLayer, B extends AbstractSameDiffLayerBuilder<C, B>>
|
C extends AbstractSameDiffLayer, B extends AbstractSameDiffLayerBuilder<C, B>>
|
||||||
extends LayerConfigurationBuilder<C, B> {
|
extends LayerConfigurationBuilder<C, B> {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* L1 regularization coefficient (weights only). Use {@link #l1Bias(double)} to configure the l1
|
* L1 regularization coefficient (weights only). Use {@link #l1Bias(double)} to configure the l1
|
||||||
* regularization coefficient for the bias.
|
* regularization coefficient for the bias.
|
||||||
|
|
|
@ -20,7 +20,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.layers.samediff;
|
package org.deeplearning4j.nn.conf.layers.samediff;
|
||||||
|
|
||||||
import lombok.*;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.experimental.SuperBuilder;
|
import lombok.experimental.SuperBuilder;
|
||||||
import org.deeplearning4j.nn.api.Layer;
|
import org.deeplearning4j.nn.api.Layer;
|
||||||
import org.deeplearning4j.nn.api.MaskState;
|
import org.deeplearning4j.nn.api.MaskState;
|
||||||
|
@ -31,15 +31,15 @@ import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
import org.deeplearning4j.optimize.api.TrainingListener;
|
import org.deeplearning4j.optimize.api.TrainingListener;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.common.primitives.Pair;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.common.primitives.Pair;
|
|
||||||
|
|
||||||
import java.util.Collection;
|
import java.util.Collection;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
@Data
|
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
@SuperBuilder
|
@SuperBuilder
|
||||||
public abstract class SameDiffLayer extends AbstractSameDiffLayer {
|
public abstract class SameDiffLayer extends AbstractSameDiffLayer {
|
||||||
|
@ -47,14 +47,9 @@ public abstract class SameDiffLayer extends AbstractSameDiffLayer {
|
||||||
/**
|
/**
|
||||||
* WeightInit, default is XAVIER.
|
* WeightInit, default is XAVIER.
|
||||||
*/
|
*/
|
||||||
@Builder.Default
|
|
||||||
protected WeightInit weightInit = WeightInit.XAVIER;
|
protected WeightInit weightInit = WeightInit.XAVIER;
|
||||||
@Builder.Default
|
|
||||||
protected Map<String,IWeightInit> paramWeightInit = new HashMap<>();
|
protected Map<String,IWeightInit> paramWeightInit = new HashMap<>();
|
||||||
|
|
||||||
protected SameDiffLayer() {
|
|
||||||
//No op constructor for Jackson
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Define the layer
|
* Define the layer
|
||||||
|
@ -100,6 +95,7 @@ public abstract class SameDiffLayer extends AbstractSameDiffLayer {
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static abstract class SameDiffLayerBuilder<C extends SameDiffLayer, B extends SameDiffLayerBuilder<C, B>> extends AbstractSameDiffLayerBuilder<C,B> {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -21,8 +21,8 @@
|
||||||
package org.deeplearning4j.nn.conf.layers.wrapper;
|
package org.deeplearning4j.nn.conf.layers.wrapper;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import lombok.Data;
|
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
|
import lombok.Getter;
|
||||||
import lombok.experimental.SuperBuilder;
|
import lombok.experimental.SuperBuilder;
|
||||||
import org.deeplearning4j.nn.api.ParamInitializer;
|
import org.deeplearning4j.nn.api.ParamInitializer;
|
||||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||||
|
@ -37,24 +37,12 @@ import org.nd4j.linalg.activations.IActivation;
|
||||||
import org.nd4j.linalg.learning.config.IUpdater;
|
import org.nd4j.linalg.learning.config.IUpdater;
|
||||||
import org.nd4j.linalg.learning.regularization.Regularization;
|
import org.nd4j.linalg.learning.regularization.Regularization;
|
||||||
|
|
||||||
@Data
|
|
||||||
@EqualsAndHashCode(callSuper = false)
|
@EqualsAndHashCode(callSuper = false)
|
||||||
@SuperBuilder
|
@SuperBuilder(builderMethodName = "innerBuilder")
|
||||||
public abstract class BaseWrapperLayerConfiguration extends LayerConfiguration {
|
public abstract class BaseWrapperLayerConfiguration extends LayerConfiguration {
|
||||||
|
|
||||||
/**
|
/** The configuration to of another layer to wrap */
|
||||||
* The configuration to of another layer to wrap
|
@Getter protected LayerConfiguration underlying;
|
||||||
*/
|
|
||||||
protected LayerConfiguration underlying;
|
|
||||||
|
|
||||||
|
|
||||||
protected BaseWrapperLayerConfiguration() {
|
|
||||||
}
|
|
||||||
|
|
||||||
public BaseWrapperLayerConfiguration(LayerConfiguration underlying) {
|
|
||||||
this.underlying = underlying;
|
|
||||||
this.setNetConfiguration(underlying.getNetConfiguration());
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Set the net configuration for this configuration as well as for the underlying layer (if not
|
* Set the net configuration for this configuration as well as for the underlying layer (if not
|
||||||
|
@ -87,14 +75,6 @@ public abstract class BaseWrapperLayerConfiguration extends LayerConfiguration {
|
||||||
return underlying.getDropOut();
|
return underlying.getDropOut();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* @param activationFn
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public void setActivation(IActivation activationFn) {
|
|
||||||
underlying.setActivation(activationFn);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param iDropout
|
* @param iDropout
|
||||||
*/
|
*/
|
||||||
|
@ -103,6 +83,14 @@ public abstract class BaseWrapperLayerConfiguration extends LayerConfiguration {
|
||||||
underlying.setDropOut(iDropout);
|
underlying.setDropOut(iDropout);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param activationFn
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public void setActivation(IActivation activationFn) {
|
||||||
|
underlying.setActivation(activationFn);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param weightNoise
|
* @param weightNoise
|
||||||
*/
|
*/
|
||||||
|
@ -131,14 +119,6 @@ public abstract class BaseWrapperLayerConfiguration extends LayerConfiguration {
|
||||||
return underlying.getUpdaterByParam(paramName);
|
return underlying.getUpdaterByParam(paramName);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* @param iUpdater
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public void setIUpdater(IUpdater iUpdater) {
|
|
||||||
underlying.setIUpdater(iUpdater);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
|
@ -147,6 +127,14 @@ public abstract class BaseWrapperLayerConfiguration extends LayerConfiguration {
|
||||||
return underlying.getIUpdater();
|
return underlying.getIUpdater();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param iUpdater
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public void setIUpdater(IUpdater iUpdater) {
|
||||||
|
underlying.setIUpdater(iUpdater);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ParamInitializer initializer() {
|
public ParamInitializer initializer() {
|
||||||
return WrapperLayerParamInitializer.getInstance();
|
return WrapperLayerParamInitializer.getInstance();
|
||||||
|
@ -190,5 +178,4 @@ public abstract class BaseWrapperLayerConfiguration extends LayerConfiguration {
|
||||||
underlying.setName(layerName);
|
underlying.setName(layerName);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue