Using @SuperBuilder for LayerConfigurations

Signed-off-by: brian <brian@brutex.de>
master
Brian Rosenberger 2023-04-27 15:48:34 +02:00
parent 396dbec24e
commit 7628bbdd53
37 changed files with 938 additions and 808 deletions

View File

@ -43,7 +43,7 @@ import org.nd4j.linalg.learning.config.IUpdater;
@ToString(callSuper = true) @ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder") @SuperBuilder(builderMethodName = "innerBuilder")
public class ActivationLayer extends NoParamLayer { public class ActivationLayer extends NoParamLayer {
@ -133,8 +133,12 @@ public class ActivationLayer extends NoParamLayer {
public static abstract class ActivationLayerBuilder< public static abstract class ActivationLayerBuilder<
C extends ActivationLayer, B extends ActivationLayerBuilder<C, B>> C extends ActivationLayer, B extends ActivationLayerBuilder<C, B>>
extends NoParamLayer.NoParamLayerBuilder<C, B> { extends NoParamLayer.NoParamLayerBuilder<C, B> {
public C build() {
C l = this.initBuild(); }
private static final class ActivationLayerBuilderImpl extends ActivationLayerBuilder<ActivationLayer, ActivationLayerBuilderImpl> {
public ActivationLayer build() {
ActivationLayer l = this.initBuild();
l.initializeConstraints(); l.initializeConstraints();
return l; return l;
} }

View File

@ -33,7 +33,7 @@ import org.deeplearning4j.nn.conf.inputs.InputType;
@ToString(callSuper = true) @ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder() @SuperBuilder
public abstract class BaseUpsamplingLayer extends NoParamLayer { public abstract class BaseUpsamplingLayer extends NoParamLayer {
/** /**

View File

@ -25,10 +25,10 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import lombok.*; import lombok.*;
import lombok.experimental.SuperBuilder; import lombok.experimental.SuperBuilder;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import net.brutex.ai.dnn.api.LayerType; import net.brutex.ai.dnn.api.LayerType;
import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@ -48,101 +48,117 @@ import org.nd4j.linalg.learning.regularization.Regularization;
@Data @Data
@ToString(callSuper = true) @ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder") @SuperBuilder(builderMethodName = "innerBuilder")
public class BatchNormalization extends FeedForwardLayer { public class BatchNormalization extends FeedForwardLayer {
/** /**
* At test time: we can use a global estimate of the mean and variance, calculated using a moving average of the * At test time: we can use a global estimate of the mean and variance, calculated using a moving
* batch means/variances. This moving average is implemented as:<br> globalMeanEstimate = decay * * average of the batch means/variances. This moving average is implemented as:<br>
* globalMeanEstimate + (1-decay) * batchMean<br> globalVarianceEstimate = decay * globalVarianceEstimate + * globalMeanEstimate = decay * globalMeanEstimate + (1-decay) * batchMean<br>
* (1-decay) * batchVariance<br> * globalVarianceEstimate = decay * globalVarianceEstimate + (1-decay) * batchVariance<br>
* *
* @param decay Decay value to use for global stats calculation * @param decay Decay value to use for global stats calculation
*/ */
@lombok.Builder.Default @lombok.Builder.Default protected double decay = 0.9;
protected double decay = 0.9;
//Note: need to set defaults here in addition to builder, in case user uses no-op constructor... // Note: need to set defaults here in addition to builder, in case user uses no-op constructor...
/** /**
* Epsilon value for batch normalization; small floating point value added to variance (algorithm 1 in <a * Epsilon value for batch normalization; small floating point value added to variance (algorithm
* href="https://arxiv.org/pdf/1502.03167v3.pdf">https://arxiv.org/pdf/1502.03167v3.pdf</a>) to reduce/avoid * 1 in <a
* underflow issues.<br> Default: 1e-5 * href="https://arxiv.org/pdf/1502.03167v3.pdf">https://arxiv.org/pdf/1502.03167v3.pdf</a>) to
* reduce/avoid underflow issues.<br>
* Default: 1e-5
* *
* @param eps Epsilon values to use * @param eps Epsilon values to use
*/ */
@lombok.Builder.Default protected double eps = 1e-5; @lombok.Builder.Default protected double eps = 1e-5;
/** /**
* If doing minibatch training or not. Default: true. Under most circumstances, this should be set to true. If * If doing minibatch training or not. Default: true. Under most circumstances, this should be set
* doing full batch training (i.e., all examples in a single DataSet object - very small data sets) then this * to true. If doing full batch training (i.e., all examples in a single DataSet object - very
* should be set to false. Affects how global mean/variance estimates are calculated. * small data sets) then this should be set to false. Affects how global mean/variance estimates
* are calculated.
* *
* @param minibatch Minibatch parameter * @param minibatch Minibatch parameter
*/ */
@lombok.Builder.Default protected boolean isMinibatch = true; @lombok.Builder.Default protected boolean isMinibatch = true;
/** /**
* Used only when 'true' is passed to {@link BatchNormalizationBuilder#lockGammaBeta(boolean)}. Value is not used otherwise.<br> Default: * Used only when 'true' is passed to {@link BatchNormalizationBuilder#lockGammaBeta(boolean)}.
* 1.0 * Value is not used otherwise.<br>
* Default: 1.0
* *
* @param gamma Gamma parameter for all activations, used only with locked gamma/beta configuration mode * @param gamma Gamma parameter for all activations, used only with locked gamma/beta
* configuration mode
*/ */
@lombok.Builder.Default protected double gamma = 1.0; @lombok.Builder.Default protected double gamma = 1.0;
/** /**
* Used only when 'true' is passed to {@link BatchNormalizationBuilder#lockGammaBeta(boolean)}. Value is not used otherwise.<br> Default: * Used only when 'true' is passed to {@link BatchNormalizationBuilder#lockGammaBeta(boolean)}.
* 0.0 * Value is not used otherwise.<br>
* Default: 0.0
* *
* @param beta Beta parameter for all activations, used only with locked gamma/beta configuration mode * @param beta Beta parameter for all activations, used only with locked gamma/beta configuration
* mode
*/ */
@lombok.Builder.Default protected double beta = 0.0; @lombok.Builder.Default protected double beta = 0.0;
/** /**
* Set constraints to be applied to the beta parameter of this batch normalisation layer. Default: no * Set constraints to be applied to the beta parameter of this batch normalisation layer. Default:
* constraints.<br> Constraints can be used to enforce certain conditions (non-negativity of parameters, * no constraints.<br>
* max-norm regularization, etc). These constraints are applied at each iteration, after the parameters have * Constraints can be used to enforce certain conditions (non-negativity of parameters, max-norm
* been updated. * regularization, etc). These constraints are applied at each iteration, after the parameters
* * have been updated.
*/ */
protected List<LayerConstraint> betaConstraints; protected List<LayerConstraint> betaConstraints;
/** /**
* Set constraints to be applied to the gamma parameter of this batch normalisation layer. Default: no * Set constraints to be applied to the gamma parameter of this batch normalisation layer.
* constraints.<br> Constraints can be used to enforce certain conditions (non-negativity of parameters, * Default: no constraints.<br>
* max-norm regularization, etc). These constraints are applied at each iteration, after the parameters have * Constraints can be used to enforce certain conditions (non-negativity of parameters, max-norm
* been updated. * regularization, etc). These constraints are applied at each iteration, after the parameters
* * have been updated.
*/ */
protected List<LayerConstraint> gammaConstraints; protected List<LayerConstraint> gammaConstraints;
/** /**
* When using CuDNN or MKLDNN and an error is encountered, should fallback to the non-helper implementation be allowed? * When using CuDNN or MKLDNN and an error is encountered, should fallback to the non-helper
* If set to false, an exception in the helper will be propagated back to the user. If true, the built-in * implementation be allowed? If set to false, an exception in the helper will be propagated back
* (non-MKL/CuDNN) implementation for BatchNormalizationLayer will be used * to the user. If true, the built-in (non-MKL/CuDNN) implementation for BatchNormalizationLayer
* will be used
* *
* @param allowFallback Whether fallback to non-CuDNN implementation should be used * @param allowFallback Whether fallback to non-CuDNN implementation should be used
*/ */
@lombok.Builder.Default protected boolean cudnnAllowFallback = true; @lombok.Builder.Default protected boolean cudnnAllowFallback = true;
/** /**
* How should the moving average of variance be stored? Two different parameterizations are supported. * How should the moving average of variance be stored? Two different parameterizations are
* useLogStd(false): equivalent to 1.0.0-beta3 and earlier. The variance "parameter" is stored directly as * supported. useLogStd(false): equivalent to 1.0.0-beta3 and earlier. The variance "parameter" is
* variable<br> useLogStd(true): (Default) variance is stored as log10(stdev)<br> The motivation here is for * stored directly as variable<br>
* numerical stability (FP16 etc) and also distributed training: storing the variance directly can cause * useLogStd(true): (Default) variance is stored as log10(stdev)<br>
* numerical issues. For example, a standard deviation of 1e-3 (something that could be encountered in practice) * The motivation here is for numerical stability (FP16 etc) and also distributed training:
* gives a variance of 1e-6, which can be problematic for 16-bit floating point * storing the variance directly can cause numerical issues. For example, a standard deviation of
* * 1e-3 (something that could be encountered in practice) gives a variance of 1e-6, which can be
* How should the moving average of variance be stored? Two different parameterizations are supported. * problematic for 16-bit floating point
* useLogStd(false): equivalent to 1.0.0-beta3 and earlier. The variance "parameter" is stored directly as *
* variable<br> useLogStd(true): (Default) variance is stored as log10(stdev)<br> The motivation here is for * <p>How should the moving average of variance be stored? Two different parameterizations are
* numerical stability (FP16 etc) and also distributed training: storing the variance directly can cause * supported. useLogStd(false): equivalent to 1.0.0-beta3 and earlier. The variance "parameter" is
* numerical issues. For example, a standard deviation of 1e-3 (something that could be encountered in practice) * stored directly as variable<br>
* gives a variance of 1e-6, which can be problematic for 16-bit floating point * useLogStd(true): (Default) variance is stored as log10(stdev)<br>
* The motivation here is for numerical stability (FP16 etc) and also distributed training:
* storing the variance directly can cause numerical issues. For example, a standard deviation of
* 1e-3 (something that could be encountered in practice) gives a variance of 1e-6, which can be
* problematic for 16-bit floating point
*/ */
@lombok.Builder.Default protected boolean useLogStd = false; //Default for deserialized models (1.0.0-beta3) and earlier: store variance as variance. Post 1.0.0-beta3: use log stdev instead @lombok.Builder.Default
protected boolean useLogStd =
false; // Default for deserialized models (1.0.0-beta3) and earlier: store variance as
// variance. Post 1.0.0-beta3: use log stdev instead
/** /**
* Set the input and output array data format. Defaults to NCHW format - i.e., channels first. * Set the input and output array data format. Defaults to NCHW format - i.e., channels first. See
* See {@link CNN2DFormat} for more details * {@link CNN2DFormat} for more details
*
* @param format Format to use * @param format Format to use
*/ */
@lombok.Builder.Default protected CNN2DFormat dataFormat = CNN2DFormat.NCHW; //Default for deserialized models, 1.0.0-beta6 and earlier @lombok.Builder.Default
protected CNN2DFormat dataFormat =
CNN2DFormat.NCHW; // Default for deserialized models, 1.0.0-beta6 and earlier
private boolean lockGammaBeta; private boolean lockGammaBeta;
@ -151,14 +167,11 @@ public class BatchNormalization extends FeedForwardLayer {
} }
public static BatchNormalizationBuilder<?, ?> builder(double gamma, double beta) { public static BatchNormalizationBuilder<?, ?> builder(double gamma, double beta) {
return innerBuilder() return innerBuilder().gamma(gamma).beta(beta);
.gamma(gamma)
.beta(beta);
} }
public static BatchNormalizationBuilder<?, ?> builder(boolean lockGammaBeta) { public static BatchNormalizationBuilder<?, ?> builder(boolean lockGammaBeta) {
return innerBuilder() return innerBuilder().lockGammaBeta(lockGammaBeta);
.lockGammaBeta(lockGammaBeta);
} }
@Override @Override
@ -168,8 +181,13 @@ public class BatchNormalization extends FeedForwardLayer {
} }
@Override @Override
public Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> trainingListeners, public Layer instantiate(
int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { NeuralNetConfiguration conf,
Collection<TrainingListener> trainingListeners,
int layerIndex,
INDArray layerParamsView,
boolean initializeParams,
DataType networkDataType) {
this.setNetConfiguration(conf); this.setNetConfiguration(conf);
LayerValidation.assertNOutSet("BatchNormalization", getName(), layerIndex, getNOut()); LayerValidation.assertNOutSet("BatchNormalization", getName(), layerIndex, getNOut());
runInheritance(); runInheritance();
@ -196,20 +214,24 @@ public class BatchNormalization extends FeedForwardLayer {
if (inputType == null) { if (inputType == null) {
throw new IllegalStateException( throw new IllegalStateException(
"Invalid input type: Batch norm layer expected input of type CNN, got null for layer \"" "Invalid input type: Batch norm layer expected input of type CNN, got null for layer \""
+ getName() + "\""); + getName()
+ "\"");
} }
//Can handle CNN, flat CNN, CNN3D or FF input formats only // Can handle CNN, flat CNN, CNN3D or FF input formats only
switch (inputType.getType()) { switch (inputType.getType()) {
case FF: case FF:
case CNN: case CNN:
case CNNFlat: case CNNFlat:
case CNN3D: case CNN3D:
return inputType; //OK return inputType; // OK
default: default:
throw new IllegalStateException( throw new IllegalStateException(
"Invalid input type: Batch norm layer expected input of type CNN, CNN Flat or FF, got " "Invalid input type: Batch norm layer expected input of type CNN, CNN Flat or FF, got "
+ inputType + " for layer index " + layerIndex + ", layer name = " + inputType
+ " for layer index "
+ layerIndex
+ ", layer name = "
+ getName()); + getName());
} }
} }
@ -233,7 +255,10 @@ public class BatchNormalization extends FeedForwardLayer {
default: default:
throw new IllegalStateException( throw new IllegalStateException(
"Invalid input type: Batch norm layer expected input of type CNN, CNN Flat or FF, got " "Invalid input type: Batch norm layer expected input of type CNN, CNN Flat or FF, got "
+ inputType + " for layer " + getName() + "\""); + inputType
+ " for layer "
+ getName()
+ "\"");
} }
nOut = nIn; nOut = nIn;
} }
@ -252,8 +277,9 @@ public class BatchNormalization extends FeedForwardLayer {
} }
@Override @Override
public List<Regularization> getRegularizationByParam(String paramName){ public List<Regularization> getRegularizationByParam(String paramName) {
//Don't regularize batch norm params: similar to biases in the sense that there are not many of them... // Don't regularize batch norm params: similar to biases in the sense that there are not many of
// them...
return null; return null;
} }
@ -276,7 +302,7 @@ public class BatchNormalization extends FeedForwardLayer {
public LayerMemoryReport getMemoryReport(InputType inputType) { public LayerMemoryReport getMemoryReport(InputType inputType) {
InputType outputType = getOutputType(-1, inputType); InputType outputType = getOutputType(-1, inputType);
//TODO CuDNN helper etc // TODO CuDNN helper etc
val numParams = initializer().numParams(this); val numParams = initializer().numParams(this);
int updaterStateSize = 0; int updaterStateSize = 0;
@ -285,50 +311,63 @@ public class BatchNormalization extends FeedForwardLayer {
updaterStateSize += getUpdaterByParam(s).stateSize(nOut); updaterStateSize += getUpdaterByParam(s).stateSize(nOut);
} }
//During forward pass: working memory size approx. equal to 2x input size (copy ops, etc) // During forward pass: working memory size approx. equal to 2x input size (copy ops, etc)
val inferenceWorkingSize = 2 * inputType.arrayElementsPerExample(); val inferenceWorkingSize = 2 * inputType.arrayElementsPerExample();
//During training: we calculate mean and variance... result is equal to nOut, and INDEPENDENT of minibatch size // During training: we calculate mean and variance... result is equal to nOut, and INDEPENDENT
// of minibatch size
val trainWorkFixed = 2 * nOut; val trainWorkFixed = 2 * nOut;
//During backprop: multiple working arrays... output size, 2 * output size (indep. of example size), // During backprop: multiple working arrays... output size, 2 * output size (indep. of example
val trainWorkingSizePerExample = inferenceWorkingSize //Inference during backprop // size),
+ (outputType.arrayElementsPerExample() + 2 * nOut); //Backprop gradient calculation val trainWorkingSizePerExample =
inferenceWorkingSize // Inference during backprop
+ (outputType.arrayElementsPerExample() + 2 * nOut); // Backprop gradient calculation
return new LayerMemoryReport.Builder(name, BatchNormalization.class, inputType, outputType) return new LayerMemoryReport.Builder(name, BatchNormalization.class, inputType, outputType)
.standardMemory(numParams, updaterStateSize) .standardMemory(numParams, updaterStateSize)
.workingMemory(0, 0, trainWorkFixed, trainWorkingSizePerExample) //No additional memory (beyond activations) for inference .workingMemory(
.cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) //No caching 0,
0,
trainWorkFixed,
trainWorkingSizePerExample) // No additional memory (beyond activations) for inference
.cacheMemory(
MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) // No caching
.build(); .build();
} }
@Override @Override
public boolean isPretrainParam(String paramName) { public boolean isPretrainParam(String paramName) {
return false; //No pretrain params in BN return false; // No pretrain params in BN
} }
public static abstract class BatchNormalizationBuilder<C extends BatchNormalization, B extends BatchNormalizationBuilder<C, B>> extends FeedForwardLayerBuilder<C, B> { private static final class BatchNormalizationBuilderImpl
public C build() { extends BatchNormalizationBuilder<BatchNormalization, BatchNormalizationBuilderImpl> {
C l = this.initBuild(); public BatchNormalization build() {
BatchNormalization l = new BatchNormalization(this);
l.setType(LayerType.BN); l.setType(LayerType.BN);
l.initializeConstraints(); l.initializeConstraints();
return l; return l;
} }
}
public abstract static class BatchNormalizationBuilder<
C extends BatchNormalization, B extends BatchNormalizationBuilder<C, B>>
extends FeedForwardLayerBuilder<C, B> {
public B helperAllowFallback(boolean b) { public B helperAllowFallback(boolean b) {
this.cudnnAllowFallback$value = b; this.cudnnAllowFallback$value = b;
this.cudnnAllowFallback$set = true; this.cudnnAllowFallback$set = true;
return self(); return self();
} }
public B constrainBeta(LayerConstraint ... constraints) { public B constrainBeta(LayerConstraint... constraints) {
this.betaConstraints = List.of(constraints); this.betaConstraints = List.of(constraints);
return self(); return self();
} }
public B constrainGamma(LayerConstraint ... constraints) {
public B constrainGamma(LayerConstraint... constraints) {
this.gammaConstraints = List.of(constraints); this.gammaConstraints = List.of(constraints);
return self(); return self();
} }
} }
} }

View File

@ -38,9 +38,8 @@ import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder") @SuperBuilder(builderMethodName = "innerBuilder")
public class CapsuleLayer extends SameDiffLayer { public class CapsuleLayer extends SameDiffLayer {
private static final String WEIGHT_PARAM = "weight"; private static final String WEIGHT_PARAM = "weight";
@ -78,6 +77,18 @@ public class CapsuleLayer extends SameDiffLayer {
*/ */
@Builder.Default @Getter @Setter private int routings = 3; @Builder.Default @Getter @Setter private int routings = 3;
public static CapsuleLayerBuilder<?,?> builder() {
return innerBuilder()
;
}
public static CapsuleLayerBuilder<?,?> builder(int capsules, int capsulesDim, int routings) {
return innerBuilder()
.capsules(capsules)
.capsuleDimensions(capsulesDim)
.routings(routings);
}
@Override @Override
public void setNIn(InputType inputType, boolean override) { public void setNIn(InputType inputType, boolean override) {
if(inputType == null || inputType.getType() != Type.RNN) { if(inputType == null || inputType.getType() != Type.RNN) {
@ -185,16 +196,6 @@ public class CapsuleLayer extends SameDiffLayer {
return InputType.recurrent(capsules, capsuleDimensions); return InputType.recurrent(capsules, capsuleDimensions);
} }
public static CapsuleLayerBuilder<?,?> builder() {
return innerBuilder()
;
}
public static CapsuleLayerBuilder<?,?> builder(int capsules, int capsulesDim, int routings) {
return innerBuilder()
.capsules(capsules)
.capsuleDimensions(capsulesDim)
.routings(routings);
}
public static abstract class CapsuleLayerBuilder< public static abstract class CapsuleLayerBuilder<
C extends CapsuleLayer, B extends CapsuleLayerBuilder<C, B>> C extends CapsuleLayer, B extends CapsuleLayerBuilder<C, B>>
extends SameDiffLayerBuilder<C, B> { extends SameDiffLayerBuilder<C, B> {
@ -215,35 +216,37 @@ public class CapsuleLayer extends SameDiffLayer {
} }
public C build() {
C l = this.initBuild(); }
if (capsules <= 0 || capsuleDimensions <= 0 || routings$value <= 0) {
private static final class CapsuleLayerBuilderImpl extends CapsuleLayerBuilder<CapsuleLayer, CapsuleLayerBuilderImpl> {
public CapsuleLayer build() {
CapsuleLayer l = new CapsuleLayer(this);
if (l.getCapsules() <= 0 || l.getCapsuleDimensions() <= 0 || l.getRoutings() <= 0) {
throw new IllegalArgumentException( throw new IllegalArgumentException(
"Invalid configuration for Capsule ILayer (layer name = \"" "Invalid configuration for Capsule ILayer (layer name = \""
+ l.getName() + l.getName()
+ "\"):" + "\"):"
+ " capsules, capsuleDimensions, and routings must be > 0. Got: " + " capsules, capsuleDimensions, and routings must be > 0. Got: "
+ capsules + l.getCapsules()
+ ", " + ", "
+ capsuleDimensions + l.getCapsuleDimensions()
+ ", " + ", "
+ routings$value); + l.getRoutings());
} }
if (inputCapsules$value < 0 || inputCapsuleDimensions$value < 0) { if (l.getInputCapsules() < 0 || l.getInputCapsuleDimensions() < 0) {
throw new IllegalArgumentException( throw new IllegalArgumentException(
"Invalid configuration for Capsule ILayer (layer name = \"" "Invalid configuration for Capsule ILayer (layer name = \""
+ l.getName() + l.getName()
+ "\"):" + "\"):"
+ " inputCapsules and inputCapsuleDimensions must be >= 0 if set. Got: " + " inputCapsules and inputCapsuleDimensions must be >= 0 if set. Got: "
+ inputCapsules$value + l.getInputCapsules()
+ ", " + ", "
+ inputCapsuleDimensions$value); + l.getInputCapsuleDimensions() );
} }
return l; return l;
} }
} }
} }

View File

@ -20,6 +20,8 @@
package org.deeplearning4j.nn.conf.layers; package org.deeplearning4j.nn.conf.layers;
import java.util.Collection;
import java.util.Map;
import lombok.*; import lombok.*;
import lombok.experimental.SuperBuilder; import lombok.experimental.SuperBuilder;
import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Layer;
@ -30,36 +32,21 @@ import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.params.CenterLossParamInitializer; import org.deeplearning4j.nn.params.CenterLossParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import java.util.Collection;
import java.util.Map;
@Data @Data
@ToString(callSuper = true) @ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder(buildMethodName = "initBuild") @SuperBuilder
public class CenterLossOutputLayer extends BaseOutputLayer { public class CenterLossOutputLayer extends BaseOutputLayer {
@Builder.Default protected double alpha= 0.805; @Builder.Default protected double alpha= 0.805;
@Builder.Default protected double lambda = 2e-4; @Builder.Default protected double lambda = 2e-4;
@Builder.Default protected boolean gradientCheck = false; @Builder.Default protected boolean gradientCheck = false;
public static abstract class CenterLossOutputLayerBuilder<C extends CenterLossOutputLayer, B extends CenterLossOutputLayerBuilder<C,B>> extends
BaseOutputLayerBuilder<C, B> {
public C build() {
C l = initBuild();
l.initializeConstraints();
return l;
}
}
@Override @Override
public Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> trainingListeners, public Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> trainingListeners,
int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
@ -91,7 +78,6 @@ public static abstract class CenterLossOutputLayerBuilder<C extends CenterLossOu
return getUpdater(); return getUpdater();
} }
public boolean getGradientCheck() { public boolean getGradientCheck() {
return gradientCheck; return gradientCheck;
} }
@ -135,6 +121,24 @@ public static abstract class CenterLossOutputLayerBuilder<C extends CenterLossOu
.build(); .build();
} }
public static abstract class CenterLossOutputLayerBuilder<C extends CenterLossOutputLayer, B extends CenterLossOutputLayerBuilder<C,B>> extends
BaseOutputLayerBuilder<C, B> {
public C build() {
C l = initBuild();
l.initializeConstraints();
return l;
}
}
private static final class CenterLossOutputLayerBuilderImpl extends CenterLossOutputLayerBuilder<CenterLossOutputLayer,
CenterLossOutputLayerBuilderImpl> {
public CenterLossOutputLayer build() {
CenterLossOutputLayer l = new CenterLossOutputLayer(this);
l.initializeConstraints();
return l;
}
}
} }

View File

@ -29,6 +29,6 @@ import lombok.experimental.SuperBuilder;
@Data @Data
@ToString(callSuper = true) @ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder(buildMethodName = "initBuild") @SuperBuilder
public class Convolution1D extends Convolution1DLayer { public class Convolution1D extends Convolution1DLayer {
} }

View File

@ -47,9 +47,8 @@ import org.nd4j.linalg.api.ndarray.INDArray;
@Data @Data
@ToString(callSuper = true) @ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder") @SuperBuilder(builderMethodName = "innerBuilder")
public class Convolution1DLayer extends ConvolutionLayer { public class Convolution1DLayer extends ConvolutionLayer {
@Builder.Default private RNNFormat rnnDataFormat = RNNFormat.NCW;
/** /**
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last). * Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
* See {@link CNN2DFormat} for more details.<br> * See {@link CNN2DFormat} for more details.<br>
@ -60,6 +59,7 @@ public class Convolution1DLayer extends ConvolutionLayer {
@Builder.Default @Builder.Default
protected CNN2DFormat dataFormat = protected CNN2DFormat dataFormat =
CNN2DFormat.NCHW; // default value for legacy serialization reasons CNN2DFormat.NCHW; // default value for legacy serialization reasons
@Builder.Default private RNNFormat rnnDataFormat = RNNFormat.NCW;
/** /**
* Size of the convolution * Size of the convolution
* *
@ -183,17 +183,20 @@ public class Convolution1DLayer extends ConvolutionLayer {
return true; return true;
} }
public static abstract class Convolution1DLayerBuilder< private static final class Convolution1DLayerBuilderImpl extends ConvolutionLayerBuilder<ConvolutionLayer, Convolution1DLayerBuilderImpl> {
C extends ConvolutionLayer, B extends Convolution1DLayerBuilder<C, B>> public ConvolutionLayer build() {
extends ConvolutionLayerBuilder<C, B> { ConvolutionLayer l = initBuild();
public C build() { ConvolutionUtils.validateConvolutionModePadding(l.getConvolutionMode(), l.getPadding());
C l = initBuild();
ConvolutionUtils.validateConvolutionModePadding(l.getConvolutionMode(), padding$value);
ConvolutionUtils.validateCnnKernelStridePadding( ConvolutionUtils.validateCnnKernelStridePadding(
kernelSize$value, stride$value, padding$value); l.getKernelSize(), l.getStride(), l.getPadding());
l.initializeConstraints(); l.initializeConstraints();
return l; return l;
} }
}
public static abstract class Convolution1DLayerBuilder<
C extends ConvolutionLayer, B extends Convolution1DLayerBuilder<C, B>>
extends ConvolutionLayerBuilder<C, B> {
public B kernelSize(int @NonNull ... kernelSize) { public B kernelSize(int @NonNull ... kernelSize) {
this.kernelSize$value[0] = ValidationUtils.validate1NonNegative(kernelSize, "kernelSize")[0]; this.kernelSize$value[0] = ValidationUtils.validate1NonNegative(kernelSize, "kernelSize")[0];

View File

@ -30,6 +30,6 @@ import lombok.experimental.SuperBuilder;
@ToString(callSuper = true) @ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder(buildMethodName = "initBuild") @SuperBuilder
public class Convolution2D extends ConvolutionLayer { public class Convolution2D extends ConvolutionLayer {
} }

View File

@ -40,7 +40,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
@Data @Data
@ToString(callSuper = true) @ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder(builderMethodName = "innerBuilder", buildMethodName = "initBuild") @SuperBuilder(builderMethodName = "innerBuilder")
public class Convolution3D extends ConvolutionLayer { public class Convolution3D extends ConvolutionLayer {
/** /**
@ -235,17 +235,20 @@ public class Convolution3D extends ConvolutionLayer {
NDHWC NDHWC
} }
private static final class Convolution3DBuilderImpl extends Convolution3DBuilder<Convolution3D, Convolution3DBuilderImpl> {
public Convolution3D build() {
Convolution3D l = new Convolution3D(this);
ConvolutionUtils.validateConvolutionModePadding(l.getConvolutionMode(), l.getPadding());
Convolution3DUtils.validateCnn3DKernelStridePadding(l.getKernelSize(), l.getStride(), l.getPadding());
return l;
}
}
// public Builder(int[] kernelSize, int[] stride, int[] padding, int[] dilation) { // public Builder(int[] kernelSize, int[] stride, int[] padding, int[] dilation) {
// sup/er(kernelSize, stride, padding, dilation, 3); // sup/er(kernelSize, stride, padding, dilation, 3);
public static abstract class Convolution3DBuilder< public static abstract class Convolution3DBuilder<
C extends Convolution3D, B extends Convolution3DBuilder<C, B>> C extends Convolution3D, B extends Convolution3DBuilder<C, B>>
extends ConvolutionLayer.ConvolutionLayerBuilder<C, B> { extends ConvolutionLayer.ConvolutionLayerBuilder<C, B> {
public C build() {
ConvolutionUtils.validateConvolutionModePadding(convolutionMode$value, padding);
Convolution3DUtils.validateCnn3DKernelStridePadding(kernelSize, stride, padding);
C l = initBuild();
return l;
}
@Override // TODO we can use the parent builder and do not need to redefine the variables. @Override // TODO we can use the parent builder and do not need to redefine the variables.
// Validation can be done in override function! // Validation can be done in override function!

View File

@ -48,7 +48,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
*/ */
@ToString(callSuper = true) @ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder(builderMethodName = "innerBuilder", buildMethodName = "initBuild") @SuperBuilder(builderMethodName = "innerBuilder")
public class ConvolutionLayer extends FeedForwardLayer { public class ConvolutionLayer extends FeedForwardLayer {
/** /**
* Size of the convolution rows/columns * Size of the convolution rows/columns
@ -397,48 +397,7 @@ public class ConvolutionLayer extends FeedForwardLayer {
return self(); return self();
} }
public C build() {
ConvolutionUtils.validateConvolutionModePadding(convolutionMode$value, padding$value);
ConvolutionUtils.validateCnnKernelStridePadding(
kernelSize$value, stride$value, padding$value);
if (kernelSize$value.length != convolutionDim$value) {
throw new IllegalArgumentException(
"Kernel argument should be a "
+ convolutionDim$value
+ "d array, got "
+ Arrays.toString(kernelSize$value));
}
if (stride$value.length != convolutionDim$value) {
throw new IllegalArgumentException(
"Strides argument should be a "
+ convolutionDim$value
+ "d array, got "
+ Arrays.toString(stride$value));
}
if (padding$value.length != convolutionDim$value) {
throw new IllegalArgumentException(
"Padding argument should be a "
+ convolutionDim$value
+ "d array, got "
+ Arrays.toString(padding$value));
}
if (dilation$value.length != convolutionDim$value) {
throw new IllegalArgumentException(
"Dilation argument should be a "
+ convolutionDim$value
+ "d array, got "
+ Arrays.toString(dilation$value));
}
C l = initBuild();
l.setType(LayerType.CONV);
l.initializeConstraints();
return l;
}
/** /**
* When using CuDNN or MKLDNN and an error is encountered, should fallback to the non-helper * When using CuDNN or MKLDNN and an error is encountered, should fallback to the non-helper
@ -454,4 +413,47 @@ public class ConvolutionLayer extends FeedForwardLayer {
return self(); return self();
} }
} }
private static final class ConvolutionLayerBuilderImpl extends ConvolutionLayerBuilder<ConvolutionLayer, ConvolutionLayerBuilderImpl> {
public ConvolutionLayer build() {
ConvolutionLayer l = new ConvolutionLayer(this);
ConvolutionUtils.validateConvolutionModePadding(l.getConvolutionMode(), l.getPadding());
ConvolutionUtils.validateCnnKernelStridePadding(
l.getKernelSize(), l.getStride(), l.getPadding());
if (l.getKernelSize().length != l.getConvolutionDim()) {
throw new IllegalArgumentException(
"Kernel argument should be a "
+ l.getConvolutionDim()
+ "d array, got "
+ Arrays.toString(l.getKernelSize()));
}
if (l.getStride().length != l.getConvolutionDim()) {
throw new IllegalArgumentException(
"Strides argument should be a "
+ l.getConvolutionDim()
+ "d array, got "
+ Arrays.toString(l.getStride()));
}
if (l.getPadding().length != l.getConvolutionDim()) {
throw new IllegalArgumentException(
"Padding argument should be a "
+ l.getConvolutionDim()
+ "d array, got "
+ Arrays.toString(l.getPadding()));
}
if (l.getDilation().length != l.getConvolutionDim()) {
throw new IllegalArgumentException(
"Dilation argument should be a "
+ l.getConvolutionDim()
+ "d array, got "
+ Arrays.toString(l.getDilation()));
}
l.setType(LayerType.CONV);
l.initializeConstraints();
return l;
}
}
} }

View File

@ -46,7 +46,7 @@ import java.util.Map;
@Data @Data
@ToString(callSuper = true) @ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuild") @SuperBuilder(builderMethodName = "innerBuilder")
public class Deconvolution2D extends ConvolutionLayer { public class Deconvolution2D extends ConvolutionLayer {
@ -57,12 +57,15 @@ private CNN2DFormat format = CNN2DFormat.NCHW;
return false; return false;
} }
public static abstract class Deconvolution2DBuilder<C extends Deconvolution2D, B extends Deconvolution2DBuilder<C, B>> extends ConvolutionLayerBuilder<C, B> { private static final class Deconvolution2DBuilderImpl extends Deconvolution2DBuilder<Deconvolution2D, Deconvolution2DBuilderImpl> {
public C build() { public Deconvolution2D build() {
C l = initBuild(); Deconvolution2D l = new Deconvolution2D(this);
l.initializeConstraints(); l.initializeConstraints();
return l; return l;
} }
}
public static abstract class Deconvolution2DBuilder<C extends Deconvolution2D, B extends Deconvolution2DBuilder<C, B>> extends ConvolutionLayerBuilder<C, B> {
@Override @Override

View File

@ -44,7 +44,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
@Data @Data
@ToString(callSuper = true) @ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder") @SuperBuilder(builderMethodName = "innerBuilder")
public class Deconvolution3D extends ConvolutionLayer { public class Deconvolution3D extends ConvolutionLayer {
/** /**
* Set the convolution mode for the Convolution layer. See {@link ConvolutionMode} for more * Set the convolution mode for the Convolution layer. See {@link ConvolutionMode} for more
@ -56,6 +56,15 @@ public class Deconvolution3D extends ConvolutionLayer {
private Convolution3D.DataFormat dataFormat = private Convolution3D.DataFormat dataFormat =
Convolution3D.DataFormat.NCDHW; // in libnd4j: 1 - NCDHW, 0 - NDHWC Convolution3D.DataFormat.NCDHW; // in libnd4j: 1 - NCDHW, 0 - NDHWC
public static Deconvolution3DBuilder<?, ?> builder() {
return innerBuilder()
.kernelSize(new int[] {2, 2, 2})
.stride(new int[] {1, 1, 1})
.padding(new int[] {0, 0, 0})
.dilation(new int[] {1, 1, 1})
.convolutionDim(3);
}
protected boolean allowCausal() { protected boolean allowCausal() {
// Causal convolution - allowed for 1D only // Causal convolution - allowed for 1D only
return false; return false;
@ -69,13 +78,13 @@ public class Deconvolution3D extends ConvolutionLayer {
public Deconvolution3D clone() { public Deconvolution3D clone() {
Deconvolution3D clone = (Deconvolution3D) super.clone(); Deconvolution3D clone = (Deconvolution3D) super.clone();
if (clone.getKernelSize() != null) { if (clone.getKernelSize() != null) {
clone.setKernelSize( clone.getKernelSize().clone()); clone.setKernelSize(clone.getKernelSize().clone());
} }
if (clone.getStride() != null) { if (clone.getStride() != null) {
clone.setStride( clone.getStride().clone()); clone.setStride(clone.getStride().clone());
} }
if (clone.getPadding() != null) { if (clone.getPadding() != null) {
clone.setPadding( clone.getPadding().clone()); clone.setPadding(clone.getPadding().clone());
} }
return clone; return clone;
} }
@ -134,6 +143,11 @@ public class Deconvolution3D extends ConvolutionLayer {
} }
} }
// private int[] kernelSize;
// private int[] stride;
// private int[] padding;
// private int[] dilation;
@Override @Override
public InputType getOutputType(int layerIndex, InputType inputType) { public InputType getOutputType(int layerIndex, InputType inputType) {
if (inputType == null || inputType.getType() != InputType.Type.CNN3D) { if (inputType == null || inputType.getType() != InputType.Type.CNN3D) {
@ -158,29 +172,16 @@ public class Deconvolution3D extends ConvolutionLayer {
Deconvolution3DLayer.class); Deconvolution3DLayer.class);
} }
//private int[] kernelSize; public abstract static class Deconvolution3DBuilder<
//private int[] stride;
//private int[] padding;
//private int[] dilation;
public static abstract class Deconvolution3DBuilder<
C extends Deconvolution3D, B extends Deconvolution3DBuilder<C, B>> C extends Deconvolution3D, B extends Deconvolution3DBuilder<C, B>>
extends ConvolutionLayerBuilder<C, B> { extends ConvolutionLayerBuilder<C, B> {}
public C build() {
C l = initBuild(); private static final class Deconvolution3DBuilderImpl
extends Deconvolution3DBuilder<Deconvolution3D, Deconvolution3DBuilderImpl> {
public Deconvolution3D build() {
Deconvolution3D l = new Deconvolution3D(this);
l.initializeConstraints(); l.initializeConstraints();
return l; return l;
} }
} }
public static Deconvolution3DBuilder<?,?> builder() {
return innerBuilder()
.kernelSize(new int[] {2, 2, 2})
.stride(new int[] {1, 1, 1})
.padding(new int[] {0, 0, 0})
.dilation(new int[] {1, 1, 1})
.convolutionDim(3);
}
} }

View File

@ -40,7 +40,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
@Data @Data
@ToString(callSuper = true) @ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder(buildMethodName = "initBuild") @SuperBuilder
public class DepthwiseConvolution2D extends ConvolutionLayer { public class DepthwiseConvolution2D extends ConvolutionLayer {
/** /**
* Set channels multiplier for depth-wise convolution * Set channels multiplier for depth-wise convolution
@ -145,21 +145,25 @@ public class DepthwiseConvolution2D extends ConvolutionLayer {
this.dataFormat = ((InputType.InputTypeConvolutional) inputType).getFormat(); this.dataFormat = ((InputType.InputTypeConvolutional) inputType).getFormat();
} }
public abstract static class DepthwiseConvolution2DBuilder< private static final class DepthwiseConvolution2DBuilderImpl extends DepthwiseConvolution2DBuilder<DepthwiseConvolution2D, DepthwiseConvolution2DBuilderImpl> {
C extends DepthwiseConvolution2D, B extends DepthwiseConvolution2DBuilder<C, B>> public DepthwiseConvolution2D build() {
extends ConvolutionLayerBuilder<C, B> { DepthwiseConvolution2D l = new DepthwiseConvolution2D(this);
public C build() {
Preconditions.checkState( Preconditions.checkState(
depthMultiplier$value > 0, l.getDepthMultiplier() > 0,
"Depth multiplier must be > 0, got %s", "Depth multiplier must be > 0, got %s",
depthMultiplier$value); l.getDepthMultiplier());
C l = this.initBuild();
ConvolutionUtils.validateConvolutionModePadding(l.getConvolutionMode(), l.getPadding()); ConvolutionUtils.validateConvolutionModePadding(l.getConvolutionMode(), l.getPadding());
ConvolutionUtils.validateCnnKernelStridePadding( ConvolutionUtils.validateCnnKernelStridePadding(
l.getKernelSize(), l.getStride(), l.getPadding()); l.getKernelSize(), l.getStride(), l.getPadding());
l.initializeConstraints(); l.initializeConstraints();
return l; return l;
} }
}
public abstract static class DepthwiseConvolution2DBuilder<
C extends DepthwiseConvolution2D, B extends DepthwiseConvolution2DBuilder<C, B>>
extends ConvolutionLayerBuilder<C, B> {
@Override @Override
public B kernelSize(int... kernelSize) { public B kernelSize(int... kernelSize) {

View File

@ -20,6 +20,8 @@
package org.deeplearning4j.nn.conf.layers; package org.deeplearning4j.nn.conf.layers;
import java.util.Collection;
import java.util.Map;
import lombok.*; import lombok.*;
import lombok.experimental.Accessors; import lombok.experimental.Accessors;
import lombok.experimental.SuperBuilder; import lombok.experimental.SuperBuilder;
@ -30,95 +32,47 @@ import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.params.EmbeddingLayerParamInitializer; import org.deeplearning4j.nn.params.EmbeddingLayerParamInitializer;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.weights.embeddings.ArrayEmbeddingInitializer; import org.deeplearning4j.nn.weights.embeddings.ArrayEmbeddingInitializer;
import org.deeplearning4j.nn.weights.embeddings.EmbeddingInitializer; import org.deeplearning4j.nn.weights.embeddings.EmbeddingInitializer;
import org.deeplearning4j.nn.weights.embeddings.WeightInitEmbedding; import org.deeplearning4j.nn.weights.embeddings.WeightInitEmbedding;
import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationIdentity;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Collection;
import java.util.Map;
@Data @Data
@ToString(callSuper = true) @ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder") @SuperBuilder(builderMethodName = "innerBuilder")
public class EmbeddingLayer extends FeedForwardLayer { public class EmbeddingLayer extends FeedForwardLayer {
/** /**
* If true: include bias parameters in the layer. False (default): no bias. * If true: include bias parameters in the layer. False (default): no bias.
*
* @param hasBias If true: include bias parameters in this layer * @param hasBias If true: include bias parameters in this layer
*/ */
@Accessors @Builder.Default @Accessors @Builder.Default private boolean hasBias = false;
private boolean hasBias = false;
/** /**
*Default to Identity activation - i.e., don't inherit. * Default to Identity activation - i.e., don't inherit. For example, if user sets ReLU as global
* For example, if user sets ReLU as global default, they very likely don't intend to use it for Embedding layer also * default, they very likely don't intend to use it for Embedding layer also
*
*/ */
public static EmbeddingLayerBuilder<?, ?> builder() { public static EmbeddingLayerBuilder<?, ?> builder() {
return innerBuilder() return innerBuilder().activation(Activation.IDENTITY);
.activation(Activation.IDENTITY);
} }
public static abstract class EmbeddingLayerBuilder<C extends EmbeddingLayer, B extends EmbeddingLayerBuilder<C,B>>
extends FeedForwardLayerBuilder<C,B>{
public C build() {
C l = initBuild();
l.initializeConstraints();
return l;
}
/**
* Weight initialization scheme to use, for initial weight values
*
* @param weightInit
* @see WeightInit
*/
@Override @Override
public B weightInit(WeightInit weightInit) { public Layer instantiate(
if(weightInit.getWeightInitFunction() instanceof WeightInitEmbedding){ NeuralNetConfiguration conf,
long[] shape = ((WeightInitEmbedding) weightInit.getWeightInitFunction()).shape(); Collection<TrainingListener> trainingListeners,
nIn(shape[0]); int layerIndex,
nOut(shape[1]); INDArray layerParamsView,
} boolean initializeParams,
super.weightInit(weightInit); DataType networkDataType) {
return self();
}
/**
* Initialize the embedding layer using values from the specified array. Note that the array should have shape
* [vocabSize, vectorSize]. After copying values from the array to initialize the network parameters, the input
* array will be discarded (so that, if necessary, it can be garbage collected)
*
* @param vectors Vectors to initialize the embedding layer with
*/
public B weightInit(INDArray vectors){
weightInit(new ArrayEmbeddingInitializer(vectors));
return self();
}
/**
* Initialize the embedding layer using the specified EmbeddingInitializer - such as a Word2Vec instance
*
* @param embeddingInitializer Source of the embedding layer weights
*/
public B weightInit(EmbeddingInitializer embeddingInitializer) {
var weightIn = new WeightInitEmbedding(embeddingInitializer);
super.weightInit(weightIn);
return self();
}
}
@Override
public Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> trainingListeners,
int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
org.deeplearning4j.nn.layers.feedforward.embedding.EmbeddingLayer ret = org.deeplearning4j.nn.layers.feedforward.embedding.EmbeddingLayer ret =
new org.deeplearning4j.nn.layers.feedforward.embedding.EmbeddingLayer(lconf, networkDataType); new org.deeplearning4j.nn.layers.feedforward.embedding.EmbeddingLayer(
lconf, networkDataType);
runInheritance(); runInheritance();
ret.addTrainingListeners(trainingListeners); ret.addTrainingListeners(trainingListeners);
@ -137,20 +91,77 @@ public class EmbeddingLayer extends FeedForwardLayer {
@Override @Override
public LayerMemoryReport getMemoryReport(InputType inputType) { public LayerMemoryReport getMemoryReport(InputType inputType) {
//Basically a dense layer, but no dropout is possible here, and no epsilons // Basically a dense layer, but no dropout is possible here, and no epsilons
InputType outputType = getOutputType(-1, inputType); InputType outputType = getOutputType(-1, inputType);
val actElementsPerEx = outputType.arrayElementsPerExample(); val actElementsPerEx = outputType.arrayElementsPerExample();
val numParams = initializer().numParams(this); val numParams = initializer().numParams(this);
val updaterStateSize = (int) getIUpdater().stateSize(numParams); val updaterStateSize = (int) getIUpdater().stateSize(numParams);
//Embedding layer does not use caching. // Embedding layer does not use caching.
//Inference: no working memory - just activations (pullRows) // Inference: no working memory - just activations (pullRows)
//Training: preout op, the only in-place ops on epsilon (from layer above) + assign ops // Training: preout op, the only in-place ops on epsilon (from layer above) + assign ops
return new LayerMemoryReport.Builder(name, EmbeddingLayer.class, inputType, outputType) return new LayerMemoryReport.Builder(name, EmbeddingLayer.class, inputType, outputType)
.standardMemory(numParams, updaterStateSize).workingMemory(0, 0, 0, actElementsPerEx) .standardMemory(numParams, updaterStateSize)
.cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) //No caching .workingMemory(0, 0, 0, actElementsPerEx)
.cacheMemory(
MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) // No caching
.build(); .build();
} }
private static final class EmbeddingLayerBuilderImpl
extends EmbeddingLayerBuilder<EmbeddingLayer, EmbeddingLayerBuilderImpl> {
public EmbeddingLayer build() {
EmbeddingLayer l = new EmbeddingLayer(this);
l.initializeConstraints();
return l;
}
}
public abstract static class EmbeddingLayerBuilder<
C extends EmbeddingLayer, B extends EmbeddingLayerBuilder<C, B>>
extends FeedForwardLayerBuilder<C, B> {
/**
* Weight initialization scheme to use, for initial weight values
*
* @param weightInit
* @see WeightInit
*/
@Override
public B weightInit(WeightInit weightInit) {
if (weightInit.getWeightInitFunction() instanceof WeightInitEmbedding) {
long[] shape = ((WeightInitEmbedding) weightInit.getWeightInitFunction()).shape();
nIn(shape[0]);
nOut(shape[1]);
}
super.weightInit(weightInit);
return self();
}
/**
* Initialize the embedding layer using values from the specified array. Note that the array
* should have shape [vocabSize, vectorSize]. After copying values from the array to initialize
* the network parameters, the input array will be discarded (so that, if necessary, it can be
* garbage collected)
*
* @param vectors Vectors to initialize the embedding layer with
*/
public B weightInit(INDArray vectors) {
weightInit(new ArrayEmbeddingInitializer(vectors));
return self();
}
/**
* Initialize the embedding layer using the specified EmbeddingInitializer - such as a Word2Vec
* instance
*
* @param embeddingInitializer Source of the embedding layer weights
*/
public B weightInit(EmbeddingInitializer embeddingInitializer) {
var weightIn = new WeightInitEmbedding(embeddingInitializer);
super.weightInit(weightIn);
return self();
}
}
} }

View File

@ -46,7 +46,7 @@ import java.util.Map;
@Data @Data
@ToString(callSuper = true) @ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder") @SuperBuilder(builderMethodName = "innerBuilder")
public class EmbeddingSequenceLayer extends FeedForwardLayer { public class EmbeddingSequenceLayer extends FeedForwardLayer {
/** /**
* Set input sequence length for this embedding layer. * Set input sequence length for this embedding layer.
@ -70,13 +70,16 @@ public class EmbeddingSequenceLayer extends FeedForwardLayer {
@Builder.Default private boolean inferInputLength = false; // use input length as provided by input data @Builder.Default private boolean inferInputLength = false; // use input length as provided by input data
@Builder.Default private RNNFormat outputDataFormat = RNNFormat.NCW; //Default value for older deserialized models @Builder.Default private RNNFormat outputDataFormat = RNNFormat.NCW; //Default value for older deserialized models
public static abstract class EmbeddingSequenceLayerBuilder<C extends EmbeddingSequenceLayer, B extends EmbeddingSequenceLayerBuilder<C, B>> private static final class EmbeddingSequenceLayerBuilderImpl extends EmbeddingSequenceLayerBuilder<EmbeddingSequenceLayer, EmbeddingSequenceLayerBuilderImpl> {
extends FeedForwardLayerBuilder<C, B> { public EmbeddingSequenceLayer build() {
public C build() { EmbeddingSequenceLayer l = new EmbeddingSequenceLayer(this);
C l = initBuild();
l.initializeConstraints(); l.initializeConstraints();
return l; return l;
} }
}
public static abstract class EmbeddingSequenceLayerBuilder<C extends EmbeddingSequenceLayer, B extends EmbeddingSequenceLayerBuilder<C, B>>
extends FeedForwardLayerBuilder<C, B> {
public B weightInit(IWeightInit weightInit){ public B weightInit(IWeightInit weightInit){
if(weightInit instanceof WeightInitEmbedding){ if(weightInit instanceof WeightInitEmbedding){

View File

@ -20,6 +20,7 @@
package org.deeplearning4j.nn.conf.layers; package org.deeplearning4j.nn.conf.layers;
import java.util.*;
import lombok.*; import lombok.*;
import lombok.experimental.SuperBuilder; import lombok.experimental.SuperBuilder;
import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Layer;
@ -36,23 +37,21 @@ import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.*;
@Data @Data
@ToString(callSuper = true) @ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@Deprecated @Deprecated
@SuperBuilder(buildMethodName = "initBuild") @SuperBuilder
public class GravesBidirectionalLSTM extends BaseRecurrentLayer { public class GravesBidirectionalLSTM extends BaseRecurrentLayer {
public static abstract class GravesBidirectionalLSTMBuilder<C extends GravesBidirectionalLSTM, B extends /**
GravesBidirectionalLSTMBuilder<C, B>> extends BaseRecurrentLayerBuilder<C, B> { * When using CuDNN and an error is encountered, should fallback to the non-CuDNN implementatation be allowed?
public C build() { * If set to false, an exception in CuDNN will be propagated back to the user. If false, the built-in
C l = this.initBuild(); * (non-CuDNN) implementation for GravesBidirectionalLSTM will be used
l.initializeConstraints(); *
return l; */
} @Builder.Default
} protected boolean helperAllowFallback = true;
/** /**
* Set forget gate bias initalizations. Values in range 1-5 can potentially help with learning or longer-term * Set forget gate bias initalizations. Values in range 1-5 can potentially help with learning or longer-term
* dependencies. * dependencies.
@ -66,15 +65,6 @@ public class GravesBidirectionalLSTM extends BaseRecurrentLayer {
*/ */
@Builder.Default @Builder.Default
private IActivation gateActivationFunction = new ActivationSigmoid(); private IActivation gateActivationFunction = new ActivationSigmoid();
/**
* When using CuDNN and an error is encountered, should fallback to the non-CuDNN implementatation be allowed?
* If set to false, an exception in CuDNN will be propagated back to the user. If false, the built-in
* (non-CuDNN) implementation for GravesBidirectionalLSTM will be used
*
*/
@Builder.Default
protected boolean helperAllowFallback = true;
@Override @Override
protected void initializeConstraints() { protected void initializeConstraints() {
@ -121,5 +111,18 @@ public class GravesBidirectionalLSTM extends BaseRecurrentLayer {
return LSTMHelpers.getMemoryReport(this, inputType); return LSTMHelpers.getMemoryReport(this, inputType);
} }
private static final class GravesBidirectionalLSTMBuilderImpl extends GravesBidirectionalLSTMBuilder<GravesBidirectionalLSTM, GravesBidirectionalLSTMBuilderImpl> {
public GravesBidirectionalLSTM build() {
GravesBidirectionalLSTM l = new GravesBidirectionalLSTM(this);
l.initializeConstraints();
return l;
}
}
public static abstract class GravesBidirectionalLSTMBuilder<C extends GravesBidirectionalLSTM, B extends
GravesBidirectionalLSTMBuilder<C, B>> extends BaseRecurrentLayerBuilder<C, B> {
}
} }

View File

@ -43,7 +43,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
@Deprecated @Deprecated
@ToString(callSuper = true) @ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder(buildMethodName = "initBuild") @SuperBuilder
public class GravesLSTM extends AbstractLSTM { public class GravesLSTM extends AbstractLSTM {
private double forgetGateBiasInit; private double forgetGateBiasInit;
@ -103,9 +103,12 @@ public class GravesLSTM extends AbstractLSTM {
public abstract static class GravesLSTMBuilder< public abstract static class GravesLSTMBuilder<
C extends GravesLSTM, B extends GravesLSTMBuilder<C, B>> C extends GravesLSTM, B extends GravesLSTMBuilder<C, B>>
extends AbstractLSTMBuilder<C, B> { extends AbstractLSTMBuilder<C, B> {}
public C build() {
C l = initBuild(); private static final class GravesLSTMBuilderImpl
extends GravesLSTMBuilder<GravesLSTM, GravesLSTMBuilderImpl> {
public GravesLSTM build() {
GravesLSTM l = new GravesLSTM(this);
l.initializeConstraints(); l.initializeConstraints();
return l; return l;
} }

View File

@ -20,6 +20,10 @@
package org.deeplearning4j.nn.conf.layers; package org.deeplearning4j.nn.conf.layers;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Map;
import lombok.*; import lombok.*;
import lombok.experimental.SuperBuilder; import lombok.experimental.SuperBuilder;
import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Layer;
@ -31,31 +35,17 @@ import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.layers.recurrent.LSTMHelpers; import org.deeplearning4j.nn.layers.recurrent.LSTMHelpers;
import org.deeplearning4j.nn.params.LSTMParamInitializer; import org.deeplearning4j.nn.params.LSTMParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Map;
@Data @Data
@ToString(callSuper = true) @ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder(buildMethodName = "initBuild") @SuperBuilder
public class LSTM extends AbstractLSTM { public class LSTM extends AbstractLSTM {
private double forgetGateBiasInit; private double forgetGateBiasInit;
public static abstract class LSTMBuilder<C extends LSTM, B extends LSTMBuilder<C, B>> extends AbstractLSTMBuilder<C, B> {
@Override public C build() {
C l = this.initBuild();
l.initializeConstraints();
return l;
}
}
@Override @Override
protected void initializeConstraints() { protected void initializeConstraints() {
super.initializeConstraints(); super.initializeConstraints();
@ -72,13 +62,19 @@ public class LSTM extends AbstractLSTM {
} }
@Override @Override
public Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> trainingListeners, public Layer instantiate(
int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { NeuralNetConfiguration conf,
Collection<TrainingListener> trainingListeners,
int layerIndex,
INDArray layerParamsView,
boolean initializeParams,
DataType networkDataType) {
LayerValidation.assertNInNOutSet("LSTM", getName(), layerIndex, getNIn(), getNOut()); LayerValidation.assertNInNOutSet("LSTM", getName(), layerIndex, getNIn(), getNOut());
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
runInheritance(); runInheritance();
org.deeplearning4j.nn.layers.recurrent.LSTM ret = new org.deeplearning4j.nn.layers.recurrent.LSTM(lconf, networkDataType); org.deeplearning4j.nn.layers.recurrent.LSTM ret =
new org.deeplearning4j.nn.layers.recurrent.LSTM(lconf, networkDataType);
ret.addTrainingListeners(trainingListeners); ret.addTrainingListeners(trainingListeners);
ret.setIndex(layerIndex); ret.setIndex(layerIndex);
ret.setParamsViewArray(layerParamsView); ret.setParamsViewArray(layerParamsView);
@ -95,7 +91,19 @@ public class LSTM extends AbstractLSTM {
@Override @Override
public LayerMemoryReport getMemoryReport(InputType inputType) { public LayerMemoryReport getMemoryReport(InputType inputType) {
//TODO - CuDNN etc // TODO - CuDNN etc
return LSTMHelpers.getMemoryReport(this, inputType); return LSTMHelpers.getMemoryReport(this, inputType);
} }
public abstract static class LSTMBuilder<C extends LSTM, B extends LSTMBuilder<C, B>>
extends AbstractLSTMBuilder<C, B> {}
private static final class LSTMBuilderImpl extends LSTMBuilder<LSTM, LSTMBuilderImpl> {
@Override
public LSTM build() {
LSTM l = new LSTM(this);
l.initializeConstraints();
return l;
}
}
} }

View File

@ -354,15 +354,5 @@ public abstract class LayerConfiguration
biasConstraints = Arrays.asList(constraints); biasConstraints = Arrays.asList(constraints);
return self(); return self();
} }
/**
* we are doing this to avoid BUG https://github.com/projectlombok/lombok/issues/3419 as some
* child classes may specify their own buildMethodName in @SuperBuilder, but we use only
* "initBuild" here consequently
* @return
*/
public C initBuild() {
return build();
}
} }
} }

View File

@ -41,7 +41,7 @@ import org.nd4j.linalg.factory.Nd4j;
@Data @Data
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder(buildMethodName = "initBuild") @SuperBuilder
public class LearnedSelfAttentionLayer extends SameDiffLayer { public class LearnedSelfAttentionLayer extends SameDiffLayer {
private static final String WEIGHT_KEY_QUERY_PROJECTION = "Wq"; private static final String WEIGHT_KEY_QUERY_PROJECTION = "Wq";
private static final String WEIGHT_KEY_KEY_PROJECTION = "Wk"; private static final String WEIGHT_KEY_KEY_PROJECTION = "Wk";
@ -173,19 +173,24 @@ public class LearnedSelfAttentionLayer extends SameDiffLayer {
public static abstract class LearnedSelfAttentionLayerBuilder< public static abstract class LearnedSelfAttentionLayerBuilder<
C extends LearnedSelfAttentionLayer, B extends LearnedSelfAttentionLayerBuilder<C, B>> C extends LearnedSelfAttentionLayer, B extends LearnedSelfAttentionLayerBuilder<C, B>>
extends SameDiffLayerBuilder<C, B> { extends SameDiffLayerBuilder<C, B> {
public C build() {
Preconditions.checkArgument(
this.projectInput || this.nHeads == 1, "projectInput must be true when nHeads != 1");
Preconditions.checkArgument(
this.projectInput || nIn == nOut, "nIn must be equal to nOut when projectInput is false");
Preconditions.checkArgument(
!this.projectInput || nOut != 0, "nOut must be specified when projectInput is true");
Preconditions.checkArgument(
this.nOut % nHeads == 0 || headSize > 0,
"nOut isn't divided by nHeads cleanly. Specify the headSize manually.");
Preconditions.checkArgument(this.nQueries > 0, "You must set numQueries.");
return initBuild(); }
private static final class LearnedSelfAttentionLayerBuilderImpl extends LearnedSelfAttentionLayerBuilder<LearnedSelfAttentionLayer, LearnedSelfAttentionLayerBuilderImpl> {
public LearnedSelfAttentionLayer build() {
LearnedSelfAttentionLayer l = new LearnedSelfAttentionLayer(this);
Preconditions.checkArgument(
l.isProjectInput() || l.getNHeads() == 1, "projectInput must be true when nHeads != 1");
Preconditions.checkArgument(
l.isProjectInput() || l.getNIn() == l.getNOut(), "nIn must be equal to nOut when projectInput is false");
Preconditions.checkArgument(
!l.isProjectInput() || l.getNOut() != 0, "nOut must be specified when projectInput is true");
Preconditions.checkArgument(
l.getNOut() % l.getNHeads() == 0 || l.getHeadSize() > 0,
"nOut isn't divided by nHeads cleanly. Specify the headSize manually.");
Preconditions.checkArgument(l.getNQueries() > 0, "You must set numQueries.");
return l;
} }
} }
} }

View File

@ -48,19 +48,9 @@ import org.nd4j.linalg.factory.Nd4j;
@Data @Data
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@JsonIgnoreProperties({"paramShapes"}) @JsonIgnoreProperties({"paramShapes"})
@SuperBuilder(buildMethodName = "initBuild") @SuperBuilder
public class LocallyConnected1D extends SameDiffLayer { public class LocallyConnected1D extends SameDiffLayer {
public static abstract class LocallyConnected1DBuilder<C extends LocallyConnected1D, B extends LocallyConnected1DBuilder<C, B>> extends
SameDiffLayerBuilder<C, B> {
public C build() {
Convolution1DUtils.validateConvolutionModePadding(convolutionMode$value, padding$value);
Convolution1DUtils.validateCnn1DKernelStridePadding(kernelSize$value, stride$value, padding$value);
C l = initBuild();
return l;
}
}
private static final List<String> WEIGHT_KEYS = private static final List<String> WEIGHT_KEYS =
Collections.singletonList(ConvolutionParamInitializer.WEIGHT_KEY); Collections.singletonList(ConvolutionParamInitializer.WEIGHT_KEY);
private static final List<String> BIAS_KEYS = private static final List<String> BIAS_KEYS =
@ -89,10 +79,8 @@ public class LocallyConnected1D extends SameDiffLayer {
private int paddingR; // Right/bottom padding private int paddingR; // Right/bottom padding
/** Convolution mode for the layer. See {@link ConvolutionMode} for details */ /** Convolution mode for the layer. See {@link ConvolutionMode} for details */
@Builder.Default private ConvolutionMode convolutionMode = ConvolutionMode.Same; @Builder.Default private ConvolutionMode convolutionMode = ConvolutionMode.Same;
/** Dilation for the layer */ /** Dilation for the layer */
@Builder.Default private int dilation = 1; @Builder.Default private int dilation = 1;
/** If true (default is false) the layer will have a bias */ /** If true (default is false) the layer will have a bias */
@Builder.Default private boolean hasBias = true; @Builder.Default private boolean hasBias = true;
@ -272,4 +260,20 @@ public class LocallyConnected1D extends SameDiffLayer {
convolutionMode = global_conf.getConvolutionMode(); convolutionMode = global_conf.getConvolutionMode();
} }
} }
private static final class LocallyConnected1DBuilderImpl
extends LocallyConnected1DBuilder<LocallyConnected1D, LocallyConnected1DBuilderImpl> {
public LocallyConnected1D build() {
LocallyConnected1D l = new LocallyConnected1D(this);
Convolution1DUtils.validateConvolutionModePadding(l.getConvolutionMode(), l.getPadding());
Convolution1DUtils.validateCnn1DKernelStridePadding(
l.getKernelSize(), l.getStride(), l.getPadding());
return l;
}
}
public abstract static class LocallyConnected1DBuilder<
C extends LocallyConnected1D, B extends LocallyConnected1DBuilder<C, B>>
extends SameDiffLayerBuilder<C, B> {}
} }

View File

@ -41,7 +41,6 @@ import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.enums.PadMode; import org.nd4j.enums.PadMode;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -49,7 +48,7 @@ import org.nd4j.linalg.factory.Nd4j;
@Data @Data
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@JsonIgnoreProperties({"paramShapes"}) @JsonIgnoreProperties({"paramShapes"})
@SuperBuilder(buildMethodName = "initBuild") @SuperBuilder
public class LocallyConnected2D extends SameDiffLayer { public class LocallyConnected2D extends SameDiffLayer {
private static final List<String> WEIGHT_KEYS = private static final List<String> WEIGHT_KEYS =
@ -318,39 +317,43 @@ public class LocallyConnected2D extends SameDiffLayer {
} }
} }
public static abstract class LocallyConnected2DBuilder< private static final class LocallyConnected2DBuilderImpl
C extends LocallyConnected2D, B extends LocallyConnected2DBuilder<C, B>> extends LocallyConnected2DBuilder<LocallyConnected2D, LocallyConnected2DBuilderImpl> {
extends SameDiffLayerBuilder<C, B> { public LocallyConnected2D build() {
public C build() { LocallyConnected2D l = new LocallyConnected2D(this);
featureDim(kernel$value[0] * kernel$value[1] * (int) nIn); l.setFeatureDim(l.getKernel()[0] * l.getKernel()[1] * (int) l.getNIn());
C l = initBuild();
return l; return l;
} }
}
public B kernelSize(int ... kernel) { public abstract static class LocallyConnected2DBuilder<
C extends LocallyConnected2D, B extends LocallyConnected2DBuilder<C, B>>
extends SameDiffLayerBuilder<C, B> {
public B kernelSize(int... kernel) {
this.kernel$value = ValidationUtils.validate2NonNegative(kernel, false, "kernel"); this.kernel$value = ValidationUtils.validate2NonNegative(kernel, false, "kernel");
this.kernel$set = true; this.kernel$set = true;
return self(); return self();
} }
public B inputSize(int ... size) { public B inputSize(int... size) {
this.inputSize = size; this.inputSize = size;
return self(); return self();
} }
public B stride(int ... stride) { public B stride(int... stride) {
this.stride$value = ValidationUtils.validate2NonNegative(stride, false, "stride"); this.stride$value = ValidationUtils.validate2NonNegative(stride, false, "stride");
this.stride$set = true; this.stride$set = true;
return self(); return self();
} }
public B padding(int ... padding) { public B padding(int... padding) {
this.padding$value = ValidationUtils.validate2NonNegative(padding, false, "padding"); this.padding$value = ValidationUtils.validate2NonNegative(padding, false, "padding");
this.padding$set = true; this.padding$set = true;
return self(); return self();
} }
public B dilation(int ... dilation) { public B dilation(int... dilation) {
this.dilation$value = ValidationUtils.validate2NonNegative(dilation, false, "dilation"); this.dilation$value = ValidationUtils.validate2NonNegative(dilation, false, "dilation");
this.dilation$set = true; this.dilation$set = true;
return self(); return self();

View File

@ -22,7 +22,6 @@ package org.deeplearning4j.nn.conf.layers;
import java.util.Collection; import java.util.Collection;
import java.util.Map; import java.util.Map;
import lombok.*; import lombok.*;
import lombok.experimental.SuperBuilder; import lombok.experimental.SuperBuilder;
import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Layer;
@ -33,13 +32,12 @@ import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
@Data @Data
@ToString(callSuper = true) @ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder") @SuperBuilder(builderMethodName = "innerBuilder")
public class OutputLayer extends BaseOutputLayer { public class OutputLayer extends BaseOutputLayer {
{ // Set default activation function to softmax (to match default loss function MCXENT) { // Set default activation function to softmax (to match default loss function MCXENT)
@ -82,15 +80,16 @@ public class OutputLayer extends BaseOutputLayer {
return DefaultParamInitializer.getInstance(); return DefaultParamInitializer.getInstance();
} }
public static abstract class OutputLayerBuilder< public abstract static class OutputLayerBuilder<
C extends OutputLayer, B extends OutputLayerBuilder<C, B>> C extends OutputLayer, B extends OutputLayerBuilder<C, B>>
extends BaseOutputLayerBuilder<C, B> { extends BaseOutputLayerBuilder<C, B> {}
public C build() {
C l = this.initBuild(); private static final class OutputLayerBuilderImpl
extends OutputLayerBuilder<OutputLayer, OutputLayerBuilderImpl> {
public OutputLayer build() {
OutputLayer l = new OutputLayer(this);
l.initializeConstraints(); l.initializeConstraints();
return l; return l;
} }
} }
} }

View File

@ -40,7 +40,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
@Data @Data
@ToString(callSuper = true) @ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder") @SuperBuilder(builderMethodName = "innerBuilder")
public class PReLULayer extends BaseLayerConfiguration { public class PReLULayer extends BaseLayerConfiguration {
/** /**
* Explicitly set input shape of incoming activations so that parameters can be initialized * Explicitly set input shape of incoming activations so that parameters can be initialized
@ -129,14 +129,17 @@ public class PReLULayer extends BaseLayerConfiguration {
.build(); .build();
} }
public static abstract class PReLULayerBuilder< private static final class PReLULayerBuilderImpl extends PReLULayerBuilder<PReLULayer, PReLULayerBuilderImpl> {
C extends PReLULayer, B extends PReLULayerBuilder<C, B>> public PReLULayer build() {
extends BaseLayerConfigurationBuilder<C, B> { PReLULayer l = new PReLULayer(this);
public C build() {
C l = initBuild();
l.initializeConstraints(); l.initializeConstraints();
return l; return l;
} }
}
public static abstract class PReLULayerBuilder<
C extends PReLULayer, B extends PReLULayerBuilder<C, B>>
extends BaseLayerConfigurationBuilder<C, B> {
/** /**
* Explicitly set input shape of incoming activations so that parameters can be initialized * Explicitly set input shape of incoming activations so that parameters can be initialized

View File

@ -35,6 +35,6 @@ import lombok.experimental.SuperBuilder;
@ToString(callSuper = true) @ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder(buildMethodName = "initBuild") @SuperBuilder
public class Pooling1D extends Subsampling1DLayer { public class Pooling1D extends Subsampling1DLayer {
} }

View File

@ -35,6 +35,6 @@ import lombok.experimental.SuperBuilder;
@ToString(callSuper = true) @ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder(buildMethodName = "initBuild") @SuperBuilder
public class Pooling2D extends SubsamplingLayer { public class Pooling2D extends SubsamplingLayer {
} }

View File

@ -41,7 +41,7 @@ import org.nd4j.linalg.factory.Nd4j;
@Data @Data
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder") @SuperBuilder(builderMethodName = "innerBuilder")
public class PrimaryCapsules extends SameDiffLayer { public class PrimaryCapsules extends SameDiffLayer {
private static final String WEIGHT_PARAM = "weight"; private static final String WEIGHT_PARAM = "weight";
@ -335,7 +335,7 @@ public class PrimaryCapsules extends SameDiffLayer {
} }
} }
public static abstract class PrimaryCapsulesBuilder< public abstract static class PrimaryCapsulesBuilder<
C extends PrimaryCapsules, B extends PrimaryCapsulesBuilder<C, B>> C extends PrimaryCapsules, B extends PrimaryCapsulesBuilder<C, B>>
extends SameDiffLayerBuilder<C, B> { extends SameDiffLayerBuilder<C, B> {
@ -396,27 +396,30 @@ public class PrimaryCapsules extends SameDiffLayer {
this.useLeakyReLU$set = true; this.useLeakyReLU$set = true;
return self(); return self();
} }
}
public C build() { private static final class PrimaryCapsulesBuilderImpl
C l = initBuild(); extends PrimaryCapsulesBuilder<PrimaryCapsules, PrimaryCapsulesBuilderImpl> {
if (capsuleDimensions <= 0 || channels$value <= 0) { public PrimaryCapsules build() {
PrimaryCapsules l = new PrimaryCapsules(this);
if (l.getCapsuleDimensions() <= 0 || l.getChannels() <= 0) {
throw new IllegalArgumentException( throw new IllegalArgumentException(
"Invalid configuration for Primary Capsules (layer name = \"" "Invalid configuration for Primary Capsules (layer name = \""
+ l.getName() + l.getName()
+ "\"):" + "\"):"
+ " capsuleDimensions and channels must be > 0. Got: " + " capsuleDimensions and channels must be > 0. Got: "
+ capsuleDimensions + l.getCapsuleDimensions()
+ ", " + ", "
+ channels$value); + l.getChannels());
} }
if (capsules < 0) { if (l.getCapsules() < 0) {
throw new IllegalArgumentException( throw new IllegalArgumentException(
"Invalid configuration for Capsule ILayer (layer name = \"" "Invalid configuration for Capsule ILayer (layer name = \""
+ l.getName() + l.getName()
+ "\"):" + "\"):"
+ " capsules must be >= 0 if set. Got: " + " capsules must be >= 0 if set. Got: "
+ capsules); + l.getCapsules());
} }
return l; return l;
} }

View File

@ -43,21 +43,25 @@ import java.util.Map;
@Data @Data
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder(buildMethodName = "initBuild") @SuperBuilder
public class RecurrentAttentionLayer extends SameDiffLayer { public class RecurrentAttentionLayer extends SameDiffLayer {
private static final class RecurrentAttentionLayerBuilderImpl extends RecurrentAttentionLayerBuilder<RecurrentAttentionLayer, RecurrentAttentionLayerBuilderImpl> {
public RecurrentAttentionLayer build() {
RecurrentAttentionLayer l = new RecurrentAttentionLayer(this);
Preconditions.checkArgument(l.isProjectInput() || l.getNHeads() == 1, "projectInput must be true when nHeads != 1");
Preconditions.checkArgument(l.isProjectInput() || l.getNIn() == l.getNOut(), "nIn must be equal to nOut when projectInput is false");
Preconditions.checkArgument(!l.isProjectInput() || l.getNOut() != 0, "nOut must be specified when projectInput is true");
Preconditions.checkArgument(l.getNOut() % l.getNHeads() == 0 || l.getNHeads() > 0, "nOut isn't divided by nHeads cleanly. Specify the headSize manually.");
return l;
}
}
public static abstract class RecurrentAttentionLayerBuilder<C extends RecurrentAttentionLayer, B extends RecurrentAttentionLayerBuilder<C,B>> public static abstract class RecurrentAttentionLayerBuilder<C extends RecurrentAttentionLayer, B extends RecurrentAttentionLayerBuilder<C,B>>
extends SameDiffLayerBuilder<C,B> { extends SameDiffLayerBuilder<C,B> {
public C build() {
Preconditions.checkArgument(this.projectInput$value || this.nHeads == 1, "projectInput must be true when nHeads != 1");
Preconditions.checkArgument(this.projectInput$value || nIn == nOut, "nIn must be equal to nOut when projectInput is false");
Preconditions.checkArgument(!this.projectInput$value || nOut != 0, "nOut must be specified when projectInput is true");
Preconditions.checkArgument(this.nOut % nHeads == 0 || headSize > 0, "nOut isn't divided by nHeads cleanly. Specify the headSize manually.");
C l = initBuild();
return l;
}
} }
/** /**

View File

@ -24,7 +24,6 @@ import java.util.Collection;
import java.util.Map; import java.util.Map;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import lombok.ToString; import lombok.ToString;
import lombok.experimental.SuperBuilder; import lombok.experimental.SuperBuilder;
import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Layer;
@ -42,30 +41,34 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
@Data @Data
@ToString(callSuper = true) @ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder") @SuperBuilder(builderMethodName = "innerBuilder")
public class RnnOutputLayer extends BaseOutputLayer { public class RnnOutputLayer extends BaseOutputLayer {
/** /**
* @param rnnDataFormat Data format expected by the layer. NCW = [miniBatchSize, size, timeSeriesLength], * @param rnnDataFormat Data format expected by the layer. NCW = [miniBatchSize, size,
* NWC = [miniBatchSize, timeSeriesLength, size]. Defaults to NCW. * timeSeriesLength], NWC = [miniBatchSize, timeSeriesLength, size]. Defaults to NCW.
*/ */
private RNNFormat dataFormat; private RNNFormat dataFormat;
public static RnnOutputLayerBuilder<?,?> builder() { public static RnnOutputLayerBuilder<?, ?> builder() {
return innerBuilder(); return innerBuilder();
} }
/** /**
* @param lossFn Loss function for the output layer * @param lossFn Loss function for the output layer
*/ */
public static RnnOutputLayerBuilder<?,?> builder(LossFunctions.LossFunction lossFn) { public static RnnOutputLayerBuilder<?, ?> builder(LossFunctions.LossFunction lossFn) {
return innerBuilder() return innerBuilder().lossFunction(lossFn);
.lossFunction(lossFn);
} }
@Override @Override
public Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> trainingListeners, public Layer instantiate(
int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { NeuralNetConfiguration conf,
Collection<TrainingListener> trainingListeners,
int layerIndex,
INDArray layerParamsView,
boolean initializeParams,
DataType networkDataType) {
LayerValidation.assertNInNOutSet("RnnOutputLayer", getName(), layerIndex, getNIn(), getNOut()); LayerValidation.assertNInNOutSet("RnnOutputLayer", getName(), layerIndex, getNIn(), getNOut());
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
@ -88,8 +91,13 @@ public class RnnOutputLayer extends BaseOutputLayer {
@Override @Override
public InputType getOutputType(int layerIndex, InputType inputType) { public InputType getOutputType(int layerIndex, InputType inputType) {
if (inputType == null || inputType.getType() != InputType.Type.RNN) { if (inputType == null || inputType.getType() != InputType.Type.RNN) {
throw new IllegalStateException("Invalid input type for RnnOutputLayer (layer index = " + layerIndex throw new IllegalStateException(
+ ", layer name=\"" + getName() + "\"): Expected RNN input, got " + inputType); "Invalid input type for RnnOutputLayer (layer index = "
+ layerIndex
+ ", layer name=\""
+ getName()
+ "\"): Expected RNN input, got "
+ inputType);
} }
InputType.InputTypeRecurrent itr = (InputType.InputTypeRecurrent) inputType; InputType.InputTypeRecurrent itr = (InputType.InputTypeRecurrent) inputType;
@ -99,12 +107,15 @@ public class RnnOutputLayer extends BaseOutputLayer {
@Override @Override
public void setNIn(InputType inputType, boolean override) { public void setNIn(InputType inputType, boolean override) {
if (inputType == null || inputType.getType() != InputType.Type.RNN) { if (inputType == null || inputType.getType() != InputType.Type.RNN) {
throw new IllegalStateException("Invalid input type for RnnOutputLayer (layer name=\"" + getName() throw new IllegalStateException(
+ "\"): Expected RNN input, got " + inputType); "Invalid input type for RnnOutputLayer (layer name=\""
+ getName()
+ "\"): Expected RNN input, got "
+ inputType);
} }
InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType; InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
if(dataFormat == null || override) { if (dataFormat == null || override) {
this.dataFormat = r.getFormat(); this.dataFormat = r.getFormat();
} }
@ -118,12 +129,16 @@ public class RnnOutputLayer extends BaseOutputLayer {
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, dataFormat, getName()); return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, dataFormat, getName());
} }
public static abstract class RnnOutputLayerBuilder<C extends RnnOutputLayer, B extends RnnOutputLayerBuilder<C, B>> extends BaseOutputLayerBuilder<C, B> { public abstract static class RnnOutputLayerBuilder<
public C build() { C extends RnnOutputLayer, B extends RnnOutputLayerBuilder<C, B>>
C l = this.initBuild(); extends BaseOutputLayerBuilder<C, B> {}
private static final class RnnOutputLayerBuilderImpl
extends RnnOutputLayerBuilder<RnnOutputLayer, RnnOutputLayerBuilderImpl> {
public RnnOutputLayer build() {
RnnOutputLayer l = new RnnOutputLayer(this);
l.initializeConstraints(); l.initializeConstraints();
return l; return l;
} }
} }
} }

View File

@ -38,7 +38,7 @@ import org.nd4j.linalg.factory.Nd4j;
@Data @Data
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder(buildMethodName = "initBuild") @SuperBuilder
public class SelfAttentionLayer extends SameDiffLayer { public class SelfAttentionLayer extends SameDiffLayer {
private static final String WEIGHT_KEY_QUERY_PROJECTION = "Wq"; private static final String WEIGHT_KEY_QUERY_PROJECTION = "Wq";
private static final String WEIGHT_KEY_KEY_PROJECTION = "Wk"; private static final String WEIGHT_KEY_KEY_PROJECTION = "Wk";

View File

@ -44,7 +44,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
@Data @Data
@ToString(callSuper = true) @ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder") @SuperBuilder(builderMethodName = "innerBuilder")
public class SeparableConvolution2D extends ConvolutionLayer { public class SeparableConvolution2D extends ConvolutionLayer {
/** /**
* Set constraints to be applied to the point-wise convolution weight parameters of this layer. * Set constraints to be applied to the point-wise convolution weight parameters of this layer.

View File

@ -50,7 +50,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
@ToString(callSuper = true) @ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder(buildMethodName = "initBuild") @SuperBuilder
public class Subsampling1DLayer extends SubsamplingLayer { public class Subsampling1DLayer extends SubsamplingLayer {
@Override @Override
@ -153,11 +153,9 @@ public class Subsampling1DLayer extends SubsamplingLayer {
return true; return true;
} }
public static abstract class Subsampling1DLayerBuilder<C extends Subsampling1DLayer, B extends Subsampling1DLayerBuilder<C, B>> extends private static final class Subsampling1DLayerBuilderImpl extends Subsampling1DLayerBuilder<Subsampling1DLayer, Subsampling1DLayerBuilderImpl> {
SubsamplingLayerBuilder<C, B> { public Subsampling1DLayer build() {
Subsampling1DLayer l =new Subsampling1DLayer(this);
public C build() {
C l = this.initBuild();
if (l.getPoolingType() == org.deeplearning4j.nn.conf.layers.PoolingType.PNORM && l.getPnorm() <= 0) { if (l.getPoolingType() == org.deeplearning4j.nn.conf.layers.PoolingType.PNORM && l.getPnorm() <= 0) {
throw new IllegalStateException( throw new IllegalStateException(
"Incorrect Subsampling config: p-norm must be set when using PoolingType.PNORM"); "Incorrect Subsampling config: p-norm must be set when using PoolingType.PNORM");
@ -167,6 +165,11 @@ public class Subsampling1DLayer extends SubsamplingLayer {
return l; return l;
} }
}
public static abstract class Subsampling1DLayerBuilder<C extends Subsampling1DLayer, B extends Subsampling1DLayerBuilder<C, B>> extends
SubsamplingLayerBuilder<C, B> {
/** /**
* *
* @param kernelSize * @param kernelSize

View File

@ -45,7 +45,7 @@ import org.nd4j.linalg.learning.regularization.Regularization;
@Data @Data
@ToString(callSuper = true) @ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder(builderMethodName = "innerBuilder", buildMethodName = "initBuild") @SuperBuilder(builderMethodName = "innerBuilder")
public class Subsampling3DLayer extends NoParamLayer { public class Subsampling3DLayer extends NoParamLayer {
@Builder.Default protected ConvolutionMode convolutionMode = ConvolutionMode.Truncate; @Builder.Default protected ConvolutionMode convolutionMode = ConvolutionMode.Truncate;
@ -304,17 +304,22 @@ public class Subsampling3DLayer extends NoParamLayer {
return self(); return self();
} }
public C build() {
if (kernelSize.length != 3) { }
private static final class Subsampling3DLayerBuilderImpl extends Subsampling3DLayerBuilder<Subsampling3DLayer, Subsampling3DLayerBuilderImpl> {
public Subsampling3DLayer build() {
Subsampling3DLayer l = new Subsampling3DLayer(this);
if (l.getKernelSize().length != 3) {
throw new IllegalArgumentException("Kernel size must be length 3"); throw new IllegalArgumentException("Kernel size must be length 3");
} }
if (stride.length != 3) { if (l.getStride().length != 3) {
throw new IllegalArgumentException("Invalid stride, must be length 3"); throw new IllegalArgumentException("Invalid stride, must be length 3");
} }
C l = this.initBuild();
ConvolutionUtils.validateConvolutionModePadding(l.getConvolutionMode(), padding); ConvolutionUtils.validateConvolutionModePadding(l.getConvolutionMode(), l.getPadding());
Convolution3DUtils.validateCnn3DKernelStridePadding(kernelSize, stride, padding); Convolution3DUtils.validateCnn3DKernelStridePadding(l.getKernelSize(), l.getStride(), l.getPadding());
return l; return l;
} }
} }

View File

@ -45,7 +45,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
@Data @Data
@ToString(callSuper = true) @ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder") @SuperBuilder(builderMethodName = "innerBuilder")
public class SubsamplingLayer extends NoParamLayer { public class SubsamplingLayer extends NoParamLayer {
public static final CNN2DFormat DEFAULT_FORMAT = CNN2DFormat.NCHW; public static final CNN2DFormat DEFAULT_FORMAT = CNN2DFormat.NCHW;
@ -425,25 +425,7 @@ public class SubsamplingLayer extends NoParamLayer {
return self(); return self();
} }
public C build() {
if (kernelSize$value.length != 2) {
throw new IllegalArgumentException("Kernel size of should be rows x columns (a 2d array)");
}
if (stride$value.length != 2) {
throw new IllegalArgumentException("Invalid stride, must be length 2");
}
if (poolingType$value == org.deeplearning4j.nn.conf.layers.PoolingType.PNORM && pnorm <= 0) {
throw new IllegalStateException(
"Incorrect Subsampling config: p-norm must be set when using PoolingType.PNORM");
}
ConvolutionUtils.validateConvolutionModePadding(convolutionMode$value, padding$value);
ConvolutionUtils.validateCnnKernelStridePadding(
kernelSize$value, stride$value, padding$value);
C l = initBuild();
return l;
}
public B setConvolutionMode(ConvolutionMode convolutionMode){ public B setConvolutionMode(ConvolutionMode convolutionMode){
Preconditions.checkState(allowCausal$value || convolutionMode$value != ConvolutionMode.Causal, "Causal convolution mode can only be used with 1D" + Preconditions.checkState(allowCausal$value || convolutionMode$value != ConvolutionMode.Causal, "Causal convolution mode can only be used with 1D" +
@ -459,4 +441,25 @@ public class SubsamplingLayer extends NoParamLayer {
return self(); return self();
} }
} }
private static final class SubsamplingLayerBuilderImpl extends SubsamplingLayerBuilder<SubsamplingLayer, SubsamplingLayerBuilderImpl> {
public SubsamplingLayer build() {
SubsamplingLayer l = new SubsamplingLayer(this);
if (l.getKernelSize().length != 2) {
throw new IllegalArgumentException("Kernel size of should be rows x columns (a 2d array)");
}
if (l.getStride().length != 2) {
throw new IllegalArgumentException("Invalid stride, must be length 2");
}
if (l.getPoolingType() == org.deeplearning4j.nn.conf.layers.PoolingType.PNORM && l.getPnorm() <= 0) {
throw new IllegalStateException(
"Incorrect Subsampling config: p-norm must be set when using PoolingType.PNORM");
}
ConvolutionUtils.validateConvolutionModePadding(l.getConvolutionMode(), l.getPadding());
ConvolutionUtils.validateCnnKernelStridePadding(
l.getKernelSize(), l.getStride(), l.getPadding());
return l;
}
}
} }

View File

@ -41,7 +41,7 @@ import java.util.Map;
@Data @Data
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder(builderMethodName = "innerBuilder", buildMethodName = "initBuild") @SuperBuilder(builderMethodName = "innerBuilder")
public class ZeroPaddingLayer extends NoParamLayer { public class ZeroPaddingLayer extends NoParamLayer {
/** /**
* @param padding Padding value for top, bottom, left, and right. Must be length 4 array * @param padding Padding value for top, bottom, left, and right. Must be length 4 array

View File

@ -48,7 +48,7 @@ import org.nd4j.serde.jackson.shaded.NDArrayTextSerializer;
@EqualsAndHashCode(callSuper = false) @EqualsAndHashCode(callSuper = false)
@SuperBuilder(buildMethodName = "initBuild") @SuperBuilder
public class Yolo2OutputLayer extends LayerConfiguration { public class Yolo2OutputLayer extends LayerConfiguration {
/** /**

View File

@ -43,8 +43,7 @@ import org.nd4j.linalg.learning.regularization.Regularization;
public abstract class BaseWrapperLayerConfiguration extends LayerConfiguration { public abstract class BaseWrapperLayerConfiguration extends LayerConfiguration {
/** The configuration to of another layer to wrap */ /** The configuration to of another layer to wrap */
@Getter @Setter @Getter @Setter protected LayerConfiguration underlying;
protected LayerConfiguration underlying;
/** /**
* Set the net configuration for this configuration as well as for the underlying layer (if not * Set the net configuration for this configuration as well as for the underlying layer (if not