Using @SuperBuilder for LayerConfigurations

Signed-off-by: brian <brian@brutex.de>
Branch: master
parent 396dbec24e
commit 7628bbdd53
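Most of the hunks below make the same mechanical change: the Lombok-generated builder entry point is renamed to innerBuilder() (or the custom buildMethodName is simply dropped), the layer class exposes its own static builder() overloads, and validation or initialisation that used to live in a hand-written build() of the abstract *Builder class moves into a private *BuilderImpl subclass. A minimal sketch of that pattern, with illustrative names (ExampleLayer, decay) that are not taken from the commit, assuming Lombok's @SuperBuilder is available on the classpath:

import lombok.experimental.SuperBuilder;

// Sketch only: the generated static factory is renamed to innerBuilder() so the
// class can publish its own builder() overloads that delegate to it.
@SuperBuilder(builderMethodName = "innerBuilder")
class ExampleLayer {

  // @Builder.Default keeps this value when the setter is never called.
  @lombok.Builder.Default protected double decay = 0.9;

  // Public entry point replacing the generated builder(); additional overloads
  // with convenience arguments can delegate to innerBuilder() the same way.
  public static ExampleLayerBuilder<?, ?> builder() {
    return innerBuilder();
  }

  // Hand-written part of the builder hierarchy; Lombok fills in the setters.
  public abstract static class ExampleLayerBuilder<
      C extends ExampleLayer, B extends ExampleLayerBuilder<C, B>> {
    // extra fluent helpers shared with subclasses would go here
  }

  // Concrete builder: validation/initialisation happens once, in build().
  private static final class ExampleLayerBuilderImpl
      extends ExampleLayerBuilder<ExampleLayer, ExampleLayerBuilderImpl> {
    @Override
    public ExampleLayer build() {
      ExampleLayer l = new ExampleLayer(this);
      // the real layer classes run e.g. l.initializeConstraints() here
      return l;
    }
  }
}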
@@ -43,7 +43,7 @@ import org.nd4j.linalg.learning.config.IUpdater;
 
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
-@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder")
+@SuperBuilder(builderMethodName = "innerBuilder")
 public class ActivationLayer extends NoParamLayer {
 
 
@@ -133,8 +133,12 @@ public class ActivationLayer extends NoParamLayer {
 public static abstract class ActivationLayerBuilder<
 C extends ActivationLayer, B extends ActivationLayerBuilder<C, B>>
 extends NoParamLayer.NoParamLayerBuilder<C, B> {
-public C build() {
-C l = this.initBuild();
+}
+
+private static final class ActivationLayerBuilderImpl extends ActivationLayerBuilder<ActivationLayer, ActivationLayerBuilderImpl> {
+public ActivationLayer build() {
+ActivationLayer l = this.initBuild();
 l.initializeConstraints();
 return l;
 }
+}
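The recursive generic signature used by these builder classes (C for the concrete layer being built, B for the concrete builder) is what lets a subclass builder inherit every setter of its parent builder while build() still returns the subclass type. A two-level sketch with illustrative names, assuming Lombok's @SuperBuilder; this is not code from the commit:

import lombok.experimental.SuperBuilder;

@SuperBuilder
class BaseLayerConf {
  protected String name;
}

@SuperBuilder
class SpecialLayerConf extends BaseLayerConf {
  protected int units;
}

class BuilderDemo {
  public static void main(String[] args) {
    // Parent and child setters can be mixed freely; build() returns the child type.
    SpecialLayerConf conf = SpecialLayerConf.builder()
        .units(32)        // declared on SpecialLayerConf's builder
        .name("example")  // inherited from BaseLayerConf's builder
        .build();
    System.out.println(conf.name + ": " + conf.units);
  }
}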
@@ -33,7 +33,7 @@ import org.deeplearning4j.nn.conf.inputs.InputType;
 
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
-@SuperBuilder()
+@SuperBuilder
 public abstract class BaseUpsamplingLayer extends NoParamLayer {
 
 /**
@@ -25,10 +25,10 @@ import java.util.List;
 import java.util.Map;
 import lombok.*;
 import lombok.experimental.SuperBuilder;
-import org.deeplearning4j.nn.api.layers.LayerConstraint;
 import net.brutex.ai.dnn.api.LayerType;
 import org.deeplearning4j.nn.api.Layer;
 import org.deeplearning4j.nn.api.ParamInitializer;
+import org.deeplearning4j.nn.api.layers.LayerConstraint;
 import org.deeplearning4j.nn.conf.CNN2DFormat;
 import org.deeplearning4j.nn.conf.InputPreProcessor;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@ -48,287 +48,326 @@ import org.nd4j.linalg.learning.regularization.Regularization;
|
||||||
@Data
|
@Data
|
||||||
@ToString(callSuper = true)
|
@ToString(callSuper = true)
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder")
|
@SuperBuilder(builderMethodName = "innerBuilder")
|
||||||
public class BatchNormalization extends FeedForwardLayer {
|
public class BatchNormalization extends FeedForwardLayer {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* At test time: we can use a global estimate of the mean and variance, calculated using a moving average of the
|
* At test time: we can use a global estimate of the mean and variance, calculated using a moving
|
||||||
* batch means/variances. This moving average is implemented as:<br> globalMeanEstimate = decay *
|
* average of the batch means/variances. This moving average is implemented as:<br>
|
||||||
* globalMeanEstimate + (1-decay) * batchMean<br> globalVarianceEstimate = decay * globalVarianceEstimate +
|
* globalMeanEstimate = decay * globalMeanEstimate + (1-decay) * batchMean<br>
|
||||||
* (1-decay) * batchVariance<br>
|
* globalVarianceEstimate = decay * globalVarianceEstimate + (1-decay) * batchVariance<br>
|
||||||
*
|
*
|
||||||
* @param decay Decay value to use for global stats calculation
|
* @param decay Decay value to use for global stats calculation
|
||||||
*/
|
*/
|
||||||
@lombok.Builder.Default
|
@lombok.Builder.Default protected double decay = 0.9;
|
||||||
protected double decay = 0.9;
|
|
||||||
|
|
||||||
//Note: need to set defaults here in addition to builder, in case user uses no-op constructor...
|
// Note: need to set defaults here in addition to builder, in case user uses no-op constructor...
|
||||||
/**
|
/**
|
||||||
* Epsilon value for batch normalization; small floating point value added to variance (algorithm 1 in <a
|
* Epsilon value for batch normalization; small floating point value added to variance (algorithm
|
||||||
* href="https://arxiv.org/pdf/1502.03167v3.pdf">https://arxiv.org/pdf/1502.03167v3.pdf</a>) to reduce/avoid
|
* 1 in <a
|
||||||
* underflow issues.<br> Default: 1e-5
|
* href="https://arxiv.org/pdf/1502.03167v3.pdf">https://arxiv.org/pdf/1502.03167v3.pdf</a>) to
|
||||||
*
|
* reduce/avoid underflow issues.<br>
|
||||||
* @param eps Epsilon values to use
|
* Default: 1e-5
|
||||||
*/
|
*
|
||||||
@lombok.Builder.Default protected double eps = 1e-5;
|
* @param eps Epsilon values to use
|
||||||
/**
|
*/
|
||||||
* If doing minibatch training or not. Default: true. Under most circumstances, this should be set to true. If
|
@lombok.Builder.Default protected double eps = 1e-5;
|
||||||
* doing full batch training (i.e., all examples in a single DataSet object - very small data sets) then this
|
/**
|
||||||
* should be set to false. Affects how global mean/variance estimates are calculated.
|
* If doing minibatch training or not. Default: true. Under most circumstances, this should be set
|
||||||
*
|
* to true. If doing full batch training (i.e., all examples in a single DataSet object - very
|
||||||
* @param minibatch Minibatch parameter
|
* small data sets) then this should be set to false. Affects how global mean/variance estimates
|
||||||
*/
|
* are calculated.
|
||||||
@lombok.Builder.Default protected boolean isMinibatch = true;
|
*
|
||||||
|
* @param minibatch Minibatch parameter
|
||||||
|
*/
|
||||||
|
@lombok.Builder.Default protected boolean isMinibatch = true;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Used only when 'true' is passed to {@link BatchNormalizationBuilder#lockGammaBeta(boolean)}. Value is not used otherwise.<br> Default:
|
* Used only when 'true' is passed to {@link BatchNormalizationBuilder#lockGammaBeta(boolean)}.
|
||||||
* 1.0
|
* Value is not used otherwise.<br>
|
||||||
*
|
* Default: 1.0
|
||||||
* @param gamma Gamma parameter for all activations, used only with locked gamma/beta configuration mode
|
*
|
||||||
*/
|
* @param gamma Gamma parameter for all activations, used only with locked gamma/beta
|
||||||
@lombok.Builder.Default protected double gamma = 1.0;
|
* configuration mode
|
||||||
/**
|
*/
|
||||||
* Used only when 'true' is passed to {@link BatchNormalizationBuilder#lockGammaBeta(boolean)}. Value is not used otherwise.<br> Default:
|
@lombok.Builder.Default protected double gamma = 1.0;
|
||||||
* 0.0
|
/**
|
||||||
*
|
* Used only when 'true' is passed to {@link BatchNormalizationBuilder#lockGammaBeta(boolean)}.
|
||||||
* @param beta Beta parameter for all activations, used only with locked gamma/beta configuration mode
|
* Value is not used otherwise.<br>
|
||||||
*/
|
* Default: 0.0
|
||||||
@lombok.Builder.Default protected double beta = 0.0;
|
*
|
||||||
/**
|
* @param beta Beta parameter for all activations, used only with locked gamma/beta configuration
|
||||||
* Set constraints to be applied to the beta parameter of this batch normalisation layer. Default: no
|
* mode
|
||||||
* constraints.<br> Constraints can be used to enforce certain conditions (non-negativity of parameters,
|
*/
|
||||||
* max-norm regularization, etc). These constraints are applied at each iteration, after the parameters have
|
@lombok.Builder.Default protected double beta = 0.0;
|
||||||
* been updated.
|
/**
|
||||||
*
|
* Set constraints to be applied to the beta parameter of this batch normalisation layer. Default:
|
||||||
*/
|
* no constraints.<br>
|
||||||
protected List<LayerConstraint> betaConstraints;
|
* Constraints can be used to enforce certain conditions (non-negativity of parameters, max-norm
|
||||||
|
* regularization, etc). These constraints are applied at each iteration, after the parameters
|
||||||
|
* have been updated.
|
||||||
|
*/
|
||||||
|
protected List<LayerConstraint> betaConstraints;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Set constraints to be applied to the gamma parameter of this batch normalisation layer. Default: no
|
* Set constraints to be applied to the gamma parameter of this batch normalisation layer.
|
||||||
* constraints.<br> Constraints can be used to enforce certain conditions (non-negativity of parameters,
|
* Default: no constraints.<br>
|
||||||
* max-norm regularization, etc). These constraints are applied at each iteration, after the parameters have
|
* Constraints can be used to enforce certain conditions (non-negativity of parameters, max-norm
|
||||||
* been updated.
|
* regularization, etc). These constraints are applied at each iteration, after the parameters
|
||||||
*
|
* have been updated.
|
||||||
*/
|
*/
|
||||||
protected List<LayerConstraint> gammaConstraints;
|
protected List<LayerConstraint> gammaConstraints;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* When using CuDNN or MKLDNN and an error is encountered, should fallback to the non-helper
|
||||||
|
* implementation be allowed? If set to false, an exception in the helper will be propagated back
|
||||||
|
* to the user. If true, the built-in (non-MKL/CuDNN) implementation for BatchNormalizationLayer
|
||||||
|
* will be used
|
||||||
|
*
|
||||||
|
* @param allowFallback Whether fallback to non-CuDNN implementation should be used
|
||||||
|
*/
|
||||||
|
@lombok.Builder.Default protected boolean cudnnAllowFallback = true;
|
||||||
|
/**
|
||||||
|
* How should the moving average of variance be stored? Two different parameterizations are
|
||||||
|
* supported. useLogStd(false): equivalent to 1.0.0-beta3 and earlier. The variance "parameter" is
|
||||||
|
* stored directly as variable<br>
|
||||||
|
* useLogStd(true): (Default) variance is stored as log10(stdev)<br>
|
||||||
|
* The motivation here is for numerical stability (FP16 etc) and also distributed training:
|
||||||
|
* storing the variance directly can cause numerical issues. For example, a standard deviation of
|
||||||
|
* 1e-3 (something that could be encountered in practice) gives a variance of 1e-6, which can be
|
||||||
|
* problematic for 16-bit floating point
|
||||||
|
*
|
||||||
|
* <p>How should the moving average of variance be stored? Two different parameterizations are
|
||||||
|
* supported. useLogStd(false): equivalent to 1.0.0-beta3 and earlier. The variance "parameter" is
|
||||||
|
* stored directly as variable<br>
|
||||||
|
* useLogStd(true): (Default) variance is stored as log10(stdev)<br>
|
||||||
|
* The motivation here is for numerical stability (FP16 etc) and also distributed training:
|
||||||
|
* storing the variance directly can cause numerical issues. For example, a standard deviation of
|
||||||
|
* 1e-3 (something that could be encountered in practice) gives a variance of 1e-6, which can be
|
||||||
|
* problematic for 16-bit floating point
|
||||||
|
*/
|
||||||
|
@lombok.Builder.Default
|
||||||
|
protected boolean useLogStd =
|
||||||
|
false; // Default for deserialized models (1.0.0-beta3) and earlier: store variance as
|
||||||
|
// variance. Post 1.0.0-beta3: use log stdev instead
|
||||||
|
/**
|
||||||
|
* Set the input and output array data format. Defaults to NCHW format - i.e., channels first. See
|
||||||
|
* {@link CNN2DFormat} for more details
|
||||||
|
*
|
||||||
|
* @param format Format to use
|
||||||
|
*/
|
||||||
|
@lombok.Builder.Default
|
||||||
|
protected CNN2DFormat dataFormat =
|
||||||
|
CNN2DFormat.NCHW; // Default for deserialized models, 1.0.0-beta6 and earlier
|
||||||
|
|
||||||
/**
|
private boolean lockGammaBeta;
|
||||||
* When using CuDNN or MKLDNN and an error is encountered, should fallback to the non-helper implementation be allowed?
|
|
||||||
* If set to false, an exception in the helper will be propagated back to the user. If true, the built-in
|
|
||||||
* (non-MKL/CuDNN) implementation for BatchNormalizationLayer will be used
|
|
||||||
*
|
|
||||||
* @param allowFallback Whether fallback to non-CuDNN implementation should be used
|
|
||||||
*/
|
|
||||||
@lombok.Builder.Default protected boolean cudnnAllowFallback = true;
|
|
||||||
/**
|
|
||||||
* How should the moving average of variance be stored? Two different parameterizations are supported.
|
|
||||||
* useLogStd(false): equivalent to 1.0.0-beta3 and earlier. The variance "parameter" is stored directly as
|
|
||||||
* variable<br> useLogStd(true): (Default) variance is stored as log10(stdev)<br> The motivation here is for
|
|
||||||
* numerical stability (FP16 etc) and also distributed training: storing the variance directly can cause
|
|
||||||
* numerical issues. For example, a standard deviation of 1e-3 (something that could be encountered in practice)
|
|
||||||
* gives a variance of 1e-6, which can be problematic for 16-bit floating point
|
|
||||||
*
|
|
||||||
* How should the moving average of variance be stored? Two different parameterizations are supported.
|
|
||||||
* useLogStd(false): equivalent to 1.0.0-beta3 and earlier. The variance "parameter" is stored directly as
|
|
||||||
* variable<br> useLogStd(true): (Default) variance is stored as log10(stdev)<br> The motivation here is for
|
|
||||||
* numerical stability (FP16 etc) and also distributed training: storing the variance directly can cause
|
|
||||||
* numerical issues. For example, a standard deviation of 1e-3 (something that could be encountered in practice)
|
|
||||||
* gives a variance of 1e-6, which can be problematic for 16-bit floating point
|
|
||||||
*/
|
|
||||||
@lombok.Builder.Default protected boolean useLogStd = false; //Default for deserialized models (1.0.0-beta3) and earlier: store variance as variance. Post 1.0.0-beta3: use log stdev instead
|
|
||||||
/**
|
|
||||||
* Set the input and output array data format. Defaults to NCHW format - i.e., channels first.
|
|
||||||
* See {@link CNN2DFormat} for more details
|
|
||||||
* @param format Format to use
|
|
||||||
*/
|
|
||||||
@lombok.Builder.Default protected CNN2DFormat dataFormat = CNN2DFormat.NCHW; //Default for deserialized models, 1.0.0-beta6 and earlier
|
|
||||||
|
|
||||||
private boolean lockGammaBeta;
|
public static BatchNormalizationBuilder<?, ?> builder() {
|
||||||
|
return innerBuilder();
|
||||||
|
}
|
||||||
|
|
||||||
public static BatchNormalizationBuilder<?, ?> builder() {
|
public static BatchNormalizationBuilder<?, ?> builder(double gamma, double beta) {
|
||||||
return innerBuilder();
|
return innerBuilder().gamma(gamma).beta(beta);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static BatchNormalizationBuilder<?, ?> builder(boolean lockGammaBeta) {
|
||||||
|
return innerBuilder().lockGammaBeta(lockGammaBeta);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public BatchNormalization clone() {
|
||||||
|
BatchNormalization clone = (BatchNormalization) super.clone();
|
||||||
|
return clone;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Layer instantiate(
|
||||||
|
NeuralNetConfiguration conf,
|
||||||
|
Collection<TrainingListener> trainingListeners,
|
||||||
|
int layerIndex,
|
||||||
|
INDArray layerParamsView,
|
||||||
|
boolean initializeParams,
|
||||||
|
DataType networkDataType) {
|
||||||
|
this.setNetConfiguration(conf);
|
||||||
|
LayerValidation.assertNOutSet("BatchNormalization", getName(), layerIndex, getNOut());
|
||||||
|
runInheritance();
|
||||||
|
|
||||||
|
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
|
||||||
|
org.deeplearning4j.nn.layers.normalization.BatchNormalization ret =
|
||||||
|
new org.deeplearning4j.nn.layers.normalization.BatchNormalization(lconf, networkDataType);
|
||||||
|
ret.addTrainingListeners(trainingListeners);
|
||||||
|
ret.setIndex(layerIndex);
|
||||||
|
ret.setParamsViewArray(layerParamsView);
|
||||||
|
Map<String, INDArray> paramTable = initializer().init(this, layerParamsView, initializeParams);
|
||||||
|
ret.setParamTable(paramTable);
|
||||||
|
ret.setLayerConfiguration(lconf);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ParamInitializer initializer() {
|
||||||
|
return BatchNormalizationParamInitializer.getInstance();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public InputType getOutputType(int layerIndex, InputType inputType) {
|
||||||
|
if (inputType == null) {
|
||||||
|
throw new IllegalStateException(
|
||||||
|
"Invalid input type: Batch norm layer expected input of type CNN, got null for layer \""
|
||||||
|
+ getName()
|
||||||
|
+ "\"");
|
||||||
}
|
}
|
||||||
|
|
||||||
public static BatchNormalizationBuilder<?, ?> builder(double gamma, double beta) {
|
// Can handle CNN, flat CNN, CNN3D or FF input formats only
|
||||||
return innerBuilder()
|
switch (inputType.getType()) {
|
||||||
.gamma(gamma)
|
case FF:
|
||||||
.beta(beta);
|
case CNN:
|
||||||
|
case CNNFlat:
|
||||||
|
case CNN3D:
|
||||||
|
return inputType; // OK
|
||||||
|
default:
|
||||||
|
throw new IllegalStateException(
|
||||||
|
"Invalid input type: Batch norm layer expected input of type CNN, CNN Flat or FF, got "
|
||||||
|
+ inputType
|
||||||
|
+ " for layer index "
|
||||||
|
+ layerIndex
|
||||||
|
+ ", layer name = "
|
||||||
|
+ getName());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void setNIn(InputType inputType, boolean override) {
|
||||||
|
if (nIn <= 0 || override) {
|
||||||
|
switch (inputType.getType()) {
|
||||||
|
case FF:
|
||||||
|
nIn = ((InputType.InputTypeFeedForward) inputType).getSize();
|
||||||
|
break;
|
||||||
|
case CNN:
|
||||||
|
nIn = ((InputType.InputTypeConvolutional) inputType).getChannels();
|
||||||
|
dataFormat = ((InputType.InputTypeConvolutional) inputType).getFormat();
|
||||||
|
break;
|
||||||
|
case CNN3D:
|
||||||
|
nIn = ((InputType.InputTypeConvolutional3D) inputType).getChannels();
|
||||||
|
break;
|
||||||
|
case CNNFlat:
|
||||||
|
nIn = ((InputType.InputTypeConvolutionalFlat) inputType).getDepth();
|
||||||
|
default:
|
||||||
|
throw new IllegalStateException(
|
||||||
|
"Invalid input type: Batch norm layer expected input of type CNN, CNN Flat or FF, got "
|
||||||
|
+ inputType
|
||||||
|
+ " for layer "
|
||||||
|
+ getName()
|
||||||
|
+ "\"");
|
||||||
|
}
|
||||||
|
nOut = nIn;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
|
||||||
|
if (inputType.getType() == InputType.Type.CNNFlat) {
|
||||||
|
InputType.InputTypeConvolutionalFlat i = (InputType.InputTypeConvolutionalFlat) inputType;
|
||||||
|
return new FeedForwardToCnnPreProcessor(i.getHeight(), i.getWidth(), i.getDepth());
|
||||||
|
} else if (inputType.getType() == InputType.Type.RNN) {
|
||||||
|
return new RnnToFeedForwardPreProcessor();
|
||||||
}
|
}
|
||||||
|
|
||||||
public static BatchNormalizationBuilder<?, ?> builder(boolean lockGammaBeta) {
|
return null;
|
||||||
return innerBuilder()
|
}
|
||||||
.lockGammaBeta(lockGammaBeta);
|
|
||||||
|
@Override
|
||||||
|
public List<Regularization> getRegularizationByParam(String paramName) {
|
||||||
|
// Don't regularize batch norm params: similar to biases in the sense that there are not many of
|
||||||
|
// them...
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public IUpdater getUpdaterByParam(String paramName) {
|
||||||
|
switch (paramName) {
|
||||||
|
case BatchNormalizationParamInitializer.BETA:
|
||||||
|
case BatchNormalizationParamInitializer.GAMMA:
|
||||||
|
return getUpdater();
|
||||||
|
case BatchNormalizationParamInitializer.GLOBAL_MEAN:
|
||||||
|
case BatchNormalizationParamInitializer.GLOBAL_VAR:
|
||||||
|
case BatchNormalizationParamInitializer.GLOBAL_LOG_STD:
|
||||||
|
return new NoOp();
|
||||||
|
default:
|
||||||
|
throw new IllegalArgumentException("Unknown parameter: \"" + paramName + "\"");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public LayerMemoryReport getMemoryReport(InputType inputType) {
|
||||||
|
InputType outputType = getOutputType(-1, inputType);
|
||||||
|
|
||||||
|
// TODO CuDNN helper etc
|
||||||
|
|
||||||
|
val numParams = initializer().numParams(this);
|
||||||
|
int updaterStateSize = 0;
|
||||||
|
|
||||||
|
for (String s : BatchNormalizationParamInitializer.getInstance().paramKeys(this)) {
|
||||||
|
updaterStateSize += getUpdaterByParam(s).stateSize(nOut);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
// During forward pass: working memory size approx. equal to 2x input size (copy ops, etc)
|
||||||
public BatchNormalization clone() {
|
val inferenceWorkingSize = 2 * inputType.arrayElementsPerExample();
|
||||||
BatchNormalization clone = (BatchNormalization) super.clone();
|
|
||||||
return clone;
|
// During training: we calculate mean and variance... result is equal to nOut, and INDEPENDENT
|
||||||
|
// of minibatch size
|
||||||
|
val trainWorkFixed = 2 * nOut;
|
||||||
|
// During backprop: multiple working arrays... output size, 2 * output size (indep. of example
|
||||||
|
// size),
|
||||||
|
val trainWorkingSizePerExample =
|
||||||
|
inferenceWorkingSize // Inference during backprop
|
||||||
|
+ (outputType.arrayElementsPerExample() + 2 * nOut); // Backprop gradient calculation
|
||||||
|
|
||||||
|
return new LayerMemoryReport.Builder(name, BatchNormalization.class, inputType, outputType)
|
||||||
|
.standardMemory(numParams, updaterStateSize)
|
||||||
|
.workingMemory(
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
trainWorkFixed,
|
||||||
|
trainWorkingSizePerExample) // No additional memory (beyond activations) for inference
|
||||||
|
.cacheMemory(
|
||||||
|
MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) // No caching
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isPretrainParam(String paramName) {
|
||||||
|
return false; // No pretrain params in BN
|
||||||
|
}
|
||||||
|
|
||||||
|
private static final class BatchNormalizationBuilderImpl
|
||||||
|
extends BatchNormalizationBuilder<BatchNormalization, BatchNormalizationBuilderImpl> {
|
||||||
|
public BatchNormalization build() {
|
||||||
|
BatchNormalization l = new BatchNormalization(this);
|
||||||
|
l.setType(LayerType.BN);
|
||||||
|
l.initializeConstraints();
|
||||||
|
return l;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public abstract static class BatchNormalizationBuilder<
|
||||||
|
C extends BatchNormalization, B extends BatchNormalizationBuilder<C, B>>
|
||||||
|
extends FeedForwardLayerBuilder<C, B> {
|
||||||
|
|
||||||
|
public B helperAllowFallback(boolean b) {
|
||||||
|
this.cudnnAllowFallback$value = b;
|
||||||
|
this.cudnnAllowFallback$set = true;
|
||||||
|
return self();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
public B constrainBeta(LayerConstraint... constraints) {
|
||||||
public Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> trainingListeners,
|
this.betaConstraints = List.of(constraints);
|
||||||
int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
|
return self();
|
||||||
this.setNetConfiguration(conf);
|
|
||||||
LayerValidation.assertNOutSet("BatchNormalization", getName(), layerIndex, getNOut());
|
|
||||||
runInheritance();
|
|
||||||
|
|
||||||
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
|
|
||||||
org.deeplearning4j.nn.layers.normalization.BatchNormalization ret =
|
|
||||||
new org.deeplearning4j.nn.layers.normalization.BatchNormalization(lconf, networkDataType);
|
|
||||||
ret.addTrainingListeners(trainingListeners);
|
|
||||||
ret.setIndex(layerIndex);
|
|
||||||
ret.setParamsViewArray(layerParamsView);
|
|
||||||
Map<String, INDArray> paramTable = initializer().init(this, layerParamsView, initializeParams);
|
|
||||||
ret.setParamTable(paramTable);
|
|
||||||
ret.setLayerConfiguration(lconf);
|
|
||||||
return ret;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
public B constrainGamma(LayerConstraint... constraints) {
|
||||||
public ParamInitializer initializer() {
|
this.gammaConstraints = List.of(constraints);
|
||||||
return BatchNormalizationParamInitializer.getInstance();
|
return self();
|
||||||
}
|
}
|
||||||
|
}
|
||||||
@Override
|
|
||||||
public InputType getOutputType(int layerIndex, InputType inputType) {
|
|
||||||
if (inputType == null) {
|
|
||||||
throw new IllegalStateException(
|
|
||||||
"Invalid input type: Batch norm layer expected input of type CNN, got null for layer \""
|
|
||||||
+ getName() + "\"");
|
|
||||||
}
|
|
||||||
|
|
||||||
//Can handle CNN, flat CNN, CNN3D or FF input formats only
|
|
||||||
switch (inputType.getType()) {
|
|
||||||
case FF:
|
|
||||||
case CNN:
|
|
||||||
case CNNFlat:
|
|
||||||
case CNN3D:
|
|
||||||
return inputType; //OK
|
|
||||||
default:
|
|
||||||
throw new IllegalStateException(
|
|
||||||
"Invalid input type: Batch norm layer expected input of type CNN, CNN Flat or FF, got "
|
|
||||||
+ inputType + " for layer index " + layerIndex + ", layer name = "
|
|
||||||
+ getName());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void setNIn(InputType inputType, boolean override) {
|
|
||||||
if (nIn <= 0 || override) {
|
|
||||||
switch (inputType.getType()) {
|
|
||||||
case FF:
|
|
||||||
nIn = ((InputType.InputTypeFeedForward) inputType).getSize();
|
|
||||||
break;
|
|
||||||
case CNN:
|
|
||||||
nIn = ((InputType.InputTypeConvolutional) inputType).getChannels();
|
|
||||||
dataFormat = ((InputType.InputTypeConvolutional) inputType).getFormat();
|
|
||||||
break;
|
|
||||||
case CNN3D:
|
|
||||||
nIn = ((InputType.InputTypeConvolutional3D) inputType).getChannels();
|
|
||||||
break;
|
|
||||||
case CNNFlat:
|
|
||||||
nIn = ((InputType.InputTypeConvolutionalFlat) inputType).getDepth();
|
|
||||||
default:
|
|
||||||
throw new IllegalStateException(
|
|
||||||
"Invalid input type: Batch norm layer expected input of type CNN, CNN Flat or FF, got "
|
|
||||||
+ inputType + " for layer " + getName() + "\"");
|
|
||||||
}
|
|
||||||
nOut = nIn;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
|
|
||||||
if (inputType.getType() == InputType.Type.CNNFlat) {
|
|
||||||
InputType.InputTypeConvolutionalFlat i = (InputType.InputTypeConvolutionalFlat) inputType;
|
|
||||||
return new FeedForwardToCnnPreProcessor(i.getHeight(), i.getWidth(), i.getDepth());
|
|
||||||
} else if (inputType.getType() == InputType.Type.RNN) {
|
|
||||||
return new RnnToFeedForwardPreProcessor();
|
|
||||||
}
|
|
||||||
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<Regularization> getRegularizationByParam(String paramName){
|
|
||||||
//Don't regularize batch norm params: similar to biases in the sense that there are not many of them...
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public IUpdater getUpdaterByParam(String paramName) {
|
|
||||||
switch (paramName) {
|
|
||||||
case BatchNormalizationParamInitializer.BETA:
|
|
||||||
case BatchNormalizationParamInitializer.GAMMA:
|
|
||||||
return getUpdater();
|
|
||||||
case BatchNormalizationParamInitializer.GLOBAL_MEAN:
|
|
||||||
case BatchNormalizationParamInitializer.GLOBAL_VAR:
|
|
||||||
case BatchNormalizationParamInitializer.GLOBAL_LOG_STD:
|
|
||||||
return new NoOp();
|
|
||||||
default:
|
|
||||||
throw new IllegalArgumentException("Unknown parameter: \"" + paramName + "\"");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public LayerMemoryReport getMemoryReport(InputType inputType) {
|
|
||||||
InputType outputType = getOutputType(-1, inputType);
|
|
||||||
|
|
||||||
//TODO CuDNN helper etc
|
|
||||||
|
|
||||||
val numParams = initializer().numParams(this);
|
|
||||||
int updaterStateSize = 0;
|
|
||||||
|
|
||||||
for (String s : BatchNormalizationParamInitializer.getInstance().paramKeys(this)) {
|
|
||||||
updaterStateSize += getUpdaterByParam(s).stateSize(nOut);
|
|
||||||
}
|
|
||||||
|
|
||||||
//During forward pass: working memory size approx. equal to 2x input size (copy ops, etc)
|
|
||||||
val inferenceWorkingSize = 2 * inputType.arrayElementsPerExample();
|
|
||||||
|
|
||||||
//During training: we calculate mean and variance... result is equal to nOut, and INDEPENDENT of minibatch size
|
|
||||||
val trainWorkFixed = 2 * nOut;
|
|
||||||
//During backprop: multiple working arrays... output size, 2 * output size (indep. of example size),
|
|
||||||
val trainWorkingSizePerExample = inferenceWorkingSize //Inference during backprop
|
|
||||||
+ (outputType.arrayElementsPerExample() + 2 * nOut); //Backprop gradient calculation
|
|
||||||
|
|
||||||
return new LayerMemoryReport.Builder(name, BatchNormalization.class, inputType, outputType)
|
|
||||||
.standardMemory(numParams, updaterStateSize)
|
|
||||||
.workingMemory(0, 0, trainWorkFixed, trainWorkingSizePerExample) //No additional memory (beyond activations) for inference
|
|
||||||
.cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) //No caching
|
|
||||||
.build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean isPretrainParam(String paramName) {
|
|
||||||
return false; //No pretrain params in BN
|
|
||||||
}
|
|
||||||
|
|
||||||
public static abstract class BatchNormalizationBuilder<C extends BatchNormalization, B extends BatchNormalizationBuilder<C, B>> extends FeedForwardLayerBuilder<C, B> {
|
|
||||||
public C build() {
|
|
||||||
C l = this.initBuild();
|
|
||||||
l.setType(LayerType.BN);
|
|
||||||
l.initializeConstraints();
|
|
||||||
return l;
|
|
||||||
}
|
|
||||||
public B helperAllowFallback(boolean b) {
|
|
||||||
this.cudnnAllowFallback$value = b;
|
|
||||||
this.cudnnAllowFallback$set = true;
|
|
||||||
return self();
|
|
||||||
}
|
|
||||||
|
|
||||||
public B constrainBeta(LayerConstraint ... constraints) {
|
|
||||||
this.betaConstraints = List.of(constraints);
|
|
||||||
return self();
|
|
||||||
}
|
|
||||||
public B constrainGamma(LayerConstraint ... constraints) {
|
|
||||||
this.gammaConstraints = List.of(constraints);
|
|
||||||
return self();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@@ -38,9 +38,8 @@ import org.nd4j.linalg.api.memory.MemoryWorkspace;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
 
-
 @EqualsAndHashCode(callSuper = true)
-@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder")
+@SuperBuilder(builderMethodName = "innerBuilder")
 public class CapsuleLayer extends SameDiffLayer {
 
 private static final String WEIGHT_PARAM = "weight";
@@ -78,6 +77,18 @@ public class CapsuleLayer extends SameDiffLayer {
 */
 @Builder.Default @Getter @Setter private int routings = 3;
 
+public static CapsuleLayerBuilder<?,?> builder() {
+return innerBuilder()
+;
+}
+
+public static CapsuleLayerBuilder<?,?> builder(int capsules, int capsulesDim, int routings) {
+return innerBuilder()
+.capsules(capsules)
+.capsuleDimensions(capsulesDim)
+.routings(routings);
+}
+
 @Override
 public void setNIn(InputType inputType, boolean override) {
 if(inputType == null || inputType.getType() != Type.RNN) {
@@ -185,16 +196,6 @@ public class CapsuleLayer extends SameDiffLayer {
 return InputType.recurrent(capsules, capsuleDimensions);
 }
 
-public static CapsuleLayerBuilder<?,?> builder() {
-return innerBuilder()
-;
-}
-public static CapsuleLayerBuilder<?,?> builder(int capsules, int capsulesDim, int routings) {
-return innerBuilder()
-.capsules(capsules)
-.capsuleDimensions(capsulesDim)
-.routings(routings);
-}
 public static abstract class CapsuleLayerBuilder<
 C extends CapsuleLayer, B extends CapsuleLayerBuilder<C, B>>
 extends SameDiffLayerBuilder<C, B> {
@ -215,35 +216,37 @@ public class CapsuleLayer extends SameDiffLayer {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public C build() {
|
|
||||||
C l = this.initBuild();
|
|
||||||
if (capsules <= 0 || capsuleDimensions <= 0 || routings$value <= 0) {
|
|
||||||
throw new IllegalArgumentException(
|
|
||||||
"Invalid configuration for Capsule ILayer (layer name = \""
|
|
||||||
+ l.getName()
|
|
||||||
+ "\"):"
|
|
||||||
+ " capsules, capsuleDimensions, and routings must be > 0. Got: "
|
|
||||||
+ capsules
|
|
||||||
+ ", "
|
|
||||||
+ capsuleDimensions
|
|
||||||
+ ", "
|
|
||||||
+ routings$value);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (inputCapsules$value < 0 || inputCapsuleDimensions$value < 0) {
|
|
||||||
throw new IllegalArgumentException(
|
|
||||||
"Invalid configuration for Capsule ILayer (layer name = \""
|
|
||||||
+ l.getName()
|
|
||||||
+ "\"):"
|
|
||||||
+ " inputCapsules and inputCapsuleDimensions must be >= 0 if set. Got: "
|
|
||||||
+ inputCapsules$value
|
|
||||||
+ ", "
|
|
||||||
+ inputCapsuleDimensions$value);
|
|
||||||
}
|
|
||||||
|
|
||||||
return l;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static final class CapsuleLayerBuilderImpl extends CapsuleLayerBuilder<CapsuleLayer, CapsuleLayerBuilderImpl> {
|
||||||
|
public CapsuleLayer build() {
|
||||||
|
CapsuleLayer l = new CapsuleLayer(this);
|
||||||
|
if (l.getCapsules() <= 0 || l.getCapsuleDimensions() <= 0 || l.getRoutings() <= 0) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"Invalid configuration for Capsule ILayer (layer name = \""
|
||||||
|
+ l.getName()
|
||||||
|
+ "\"):"
|
||||||
|
+ " capsules, capsuleDimensions, and routings must be > 0. Got: "
|
||||||
|
+ l.getCapsules()
|
||||||
|
+ ", "
|
||||||
|
+ l.getCapsuleDimensions()
|
||||||
|
+ ", "
|
||||||
|
+ l.getRoutings());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (l.getInputCapsules() < 0 || l.getInputCapsuleDimensions() < 0) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"Invalid configuration for Capsule ILayer (layer name = \""
|
||||||
|
+ l.getName()
|
||||||
|
+ "\"):"
|
||||||
|
+ " inputCapsules and inputCapsuleDimensions must be >= 0 if set. Got: "
|
||||||
|
+ l.getInputCapsules()
|
||||||
|
+ ", "
|
||||||
|
+ l.getInputCapsuleDimensions() );
|
||||||
|
}
|
||||||
|
return l;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@@ -20,6 +20,8 @@
 
 package org.deeplearning4j.nn.conf.layers;
 
+import java.util.Collection;
+import java.util.Map;
 import lombok.*;
 import lombok.experimental.SuperBuilder;
 import org.deeplearning4j.nn.api.Layer;
@@ -30,36 +32,21 @@ import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
 import org.deeplearning4j.nn.conf.memory.MemoryReport;
 import org.deeplearning4j.nn.params.CenterLossParamInitializer;
 import org.deeplearning4j.optimize.api.TrainingListener;
-import org.nd4j.linalg.activations.impl.ActivationSoftmax;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.learning.config.IUpdater;
 import org.nd4j.linalg.learning.config.NoOp;
-import org.nd4j.linalg.lossfunctions.ILossFunction;
-import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
-
-import java.util.Collection;
-import java.util.Map;
 
 @Data
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
-@SuperBuilder(buildMethodName = "initBuild")
+@SuperBuilder
 public class CenterLossOutputLayer extends BaseOutputLayer {
 
 @Builder.Default protected double alpha= 0.805;
 @Builder.Default protected double lambda = 2e-4;
 @Builder.Default protected boolean gradientCheck = false;
 
-public static abstract class CenterLossOutputLayerBuilder<C extends CenterLossOutputLayer, B extends CenterLossOutputLayerBuilder<C,B>> extends
-BaseOutputLayerBuilder<C, B> {
-public C build() {
-C l = initBuild();
-l.initializeConstraints();
-return l;
-}
-}
-
 @Override
 public Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> trainingListeners,
 int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
@@ -91,7 +78,6 @@ public static abstract class CenterLossOutputLayerBuilder<C extends CenterLossOu
 return getUpdater();
 }
 
-
 public boolean getGradientCheck() {
 return gradientCheck;
 }
@@ -135,6 +121,24 @@ public static abstract class CenterLossOutputLayerBuilder<C extends CenterLossOu
 .build();
 }
 
+public static abstract class CenterLossOutputLayerBuilder<C extends CenterLossOutputLayer, B extends CenterLossOutputLayerBuilder<C,B>> extends
+BaseOutputLayerBuilder<C, B> {
+public C build() {
+C l = initBuild();
+l.initializeConstraints();
+return l;
+}
+}
+
+private static final class CenterLossOutputLayerBuilderImpl extends CenterLossOutputLayerBuilder<CenterLossOutputLayer,
+CenterLossOutputLayerBuilderImpl> {
+public CenterLossOutputLayer build() {
+CenterLossOutputLayer l = new CenterLossOutputLayer(this);
+l.initializeConstraints();
+return l;
+}
+}
+
 }
 
 
@@ -29,6 +29,6 @@ import lombok.experimental.SuperBuilder;
 @Data
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
-@SuperBuilder(buildMethodName = "initBuild")
+@SuperBuilder
 public class Convolution1D extends Convolution1DLayer {
 }
@@ -47,9 +47,8 @@ import org.nd4j.linalg.api.ndarray.INDArray;
 @Data
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
-@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder")
+@SuperBuilder(builderMethodName = "innerBuilder")
 public class Convolution1DLayer extends ConvolutionLayer {
-@Builder.Default private RNNFormat rnnDataFormat = RNNFormat.NCW;
 /**
 * Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
 * See {@link CNN2DFormat} for more details.<br>
@@ -60,6 +59,7 @@ public class Convolution1DLayer extends ConvolutionLayer {
 @Builder.Default
 protected CNN2DFormat dataFormat =
 CNN2DFormat.NCHW; // default value for legacy serialization reasons
+@Builder.Default private RNNFormat rnnDataFormat = RNNFormat.NCW;
 /**
 * Size of the convolution
 *
@@ -183,17 +183,20 @@ public class Convolution1DLayer extends ConvolutionLayer {
 return true;
 }
 
-public static abstract class Convolution1DLayerBuilder<
-C extends ConvolutionLayer, B extends Convolution1DLayerBuilder<C, B>>
-extends ConvolutionLayerBuilder<C, B> {
-public C build() {
-C l = initBuild();
-ConvolutionUtils.validateConvolutionModePadding(l.getConvolutionMode(), padding$value);
+private static final class Convolution1DLayerBuilderImpl extends ConvolutionLayerBuilder<ConvolutionLayer, Convolution1DLayerBuilderImpl> {
+public ConvolutionLayer build() {
+ConvolutionLayer l = initBuild();
+ConvolutionUtils.validateConvolutionModePadding(l.getConvolutionMode(), l.getPadding());
 ConvolutionUtils.validateCnnKernelStridePadding(
-kernelSize$value, stride$value, padding$value);
+l.getKernelSize(), l.getStride(), l.getPadding());
 l.initializeConstraints();
 return l;
 }
+}
+
+public static abstract class Convolution1DLayerBuilder<
+C extends ConvolutionLayer, B extends Convolution1DLayerBuilder<C, B>>
+extends ConvolutionLayerBuilder<C, B> {
 
 public B kernelSize(int @NonNull ... kernelSize) {
 this.kernelSize$value[0] = ValidationUtils.validate1NonNegative(kernelSize, "kernelSize")[0];
@@ -30,6 +30,6 @@ import lombok.experimental.SuperBuilder;
 
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
-@SuperBuilder(buildMethodName = "initBuild")
+@SuperBuilder
 public class Convolution2D extends ConvolutionLayer {
 }
@@ -40,7 +40,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
 @Data
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
-@SuperBuilder(builderMethodName = "innerBuilder", buildMethodName = "initBuild")
+@SuperBuilder(builderMethodName = "innerBuilder")
 public class Convolution3D extends ConvolutionLayer {
 
 /**
@@ -235,17 +235,20 @@ public class Convolution3D extends ConvolutionLayer {
 NDHWC
 }
 
+private static final class Convolution3DBuilderImpl extends Convolution3DBuilder<Convolution3D, Convolution3DBuilderImpl> {
+public Convolution3D build() {
+Convolution3D l = new Convolution3D(this);
+ConvolutionUtils.validateConvolutionModePadding(l.getConvolutionMode(), l.getPadding());
+Convolution3DUtils.validateCnn3DKernelStridePadding(l.getKernelSize(), l.getStride(), l.getPadding());
+return l;
+}
+}
 // public Builder(int[] kernelSize, int[] stride, int[] padding, int[] dilation) {
 // sup/er(kernelSize, stride, padding, dilation, 3);
 public static abstract class Convolution3DBuilder<
 C extends Convolution3D, B extends Convolution3DBuilder<C, B>>
 extends ConvolutionLayer.ConvolutionLayerBuilder<C, B> {
-public C build() {
-ConvolutionUtils.validateConvolutionModePadding(convolutionMode$value, padding);
-Convolution3DUtils.validateCnn3DKernelStridePadding(kernelSize, stride, padding);
-C l = initBuild();
-return l;
-}
-
 @Override // TODO we can use the parent builder and do not need to redefine the variables.
 // Validation can be done in override function!
@@ -48,7 +48,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
 */
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
-@SuperBuilder(builderMethodName = "innerBuilder", buildMethodName = "initBuild")
+@SuperBuilder(builderMethodName = "innerBuilder")
 public class ConvolutionLayer extends FeedForwardLayer {
 /**
 * Size of the convolution rows/columns
@ -397,48 +397,7 @@ public class ConvolutionLayer extends FeedForwardLayer {
|
||||||
return self();
|
return self();
|
||||||
}
|
}
|
||||||
|
|
||||||
public C build() {
|
|
||||||
ConvolutionUtils.validateConvolutionModePadding(convolutionMode$value, padding$value);
|
|
||||||
ConvolutionUtils.validateCnnKernelStridePadding(
|
|
||||||
kernelSize$value, stride$value, padding$value);
|
|
||||||
|
|
||||||
if (kernelSize$value.length != convolutionDim$value) {
|
|
||||||
throw new IllegalArgumentException(
|
|
||||||
"Kernel argument should be a "
|
|
||||||
+ convolutionDim$value
|
|
||||||
+ "d array, got "
|
|
||||||
+ Arrays.toString(kernelSize$value));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (stride$value.length != convolutionDim$value) {
|
|
||||||
throw new IllegalArgumentException(
|
|
||||||
"Strides argument should be a "
|
|
||||||
+ convolutionDim$value
|
|
||||||
+ "d array, got "
|
|
||||||
+ Arrays.toString(stride$value));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (padding$value.length != convolutionDim$value) {
|
|
||||||
throw new IllegalArgumentException(
|
|
||||||
"Padding argument should be a "
|
|
||||||
+ convolutionDim$value
|
|
||||||
+ "d array, got "
|
|
||||||
+ Arrays.toString(padding$value));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (dilation$value.length != convolutionDim$value) {
|
|
||||||
throw new IllegalArgumentException(
|
|
||||||
"Dilation argument should be a "
|
|
||||||
+ convolutionDim$value
|
|
||||||
+ "d array, got "
|
|
||||||
+ Arrays.toString(dilation$value));
|
|
||||||
}
|
|
||||||
|
|
||||||
C l = initBuild();
|
|
||||||
l.setType(LayerType.CONV);
|
|
||||||
l.initializeConstraints();
|
|
||||||
return l;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* When using CuDNN or MKLDNN and an error is encountered, should fallback to the non-helper
|
* When using CuDNN or MKLDNN and an error is encountered, should fallback to the non-helper
|
||||||
|
@ -454,4 +413,47 @@ public class ConvolutionLayer extends FeedForwardLayer {
|
||||||
return self();
|
return self();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
private static final class ConvolutionLayerBuilderImpl extends ConvolutionLayerBuilder<ConvolutionLayer, ConvolutionLayerBuilderImpl> {
|
||||||
|
public ConvolutionLayer build() {
|
||||||
|
ConvolutionLayer l = new ConvolutionLayer(this);
|
||||||
|
ConvolutionUtils.validateConvolutionModePadding(l.getConvolutionMode(), l.getPadding());
|
||||||
|
ConvolutionUtils.validateCnnKernelStridePadding(
|
||||||
|
l.getKernelSize(), l.getStride(), l.getPadding());
|
||||||
|
|
||||||
|
if (l.getKernelSize().length != l.getConvolutionDim()) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"Kernel argument should be a "
|
||||||
|
+ l.getConvolutionDim()
|
||||||
|
+ "d array, got "
|
||||||
|
+ Arrays.toString(l.getKernelSize()));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (l.getStride().length != l.getConvolutionDim()) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"Strides argument should be a "
|
||||||
|
+ l.getConvolutionDim()
|
||||||
|
+ "d array, got "
|
||||||
|
+ Arrays.toString(l.getStride()));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (l.getPadding().length != l.getConvolutionDim()) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"Padding argument should be a "
|
||||||
|
+ l.getConvolutionDim()
|
||||||
|
+ "d array, got "
|
||||||
|
+ Arrays.toString(l.getPadding()));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (l.getDilation().length != l.getConvolutionDim()) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"Dilation argument should be a "
|
||||||
|
+ l.getConvolutionDim()
|
||||||
|
+ "d array, got "
|
||||||
|
+ Arrays.toString(l.getDilation()));
|
||||||
|
}
|
||||||
|
l.setType(LayerType.CONV);
|
||||||
|
l.initializeConstraints();
|
||||||
|
return l;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@@ -46,7 +46,7 @@ import java.util.Map;
 @Data
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
-@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuild")
+@SuperBuilder(builderMethodName = "innerBuilder")
 public class Deconvolution2D extends ConvolutionLayer {
 
 
@@ -57,12 +57,15 @@ private CNN2DFormat format = CNN2DFormat.NCHW;
 return false;
 }
 
-public static abstract class Deconvolution2DBuilder<C extends Deconvolution2D, B extends Deconvolution2DBuilder<C, B>> extends ConvolutionLayerBuilder<C, B> {
-public C build() {
-C l = initBuild();
+private static final class Deconvolution2DBuilderImpl extends Deconvolution2DBuilder<Deconvolution2D, Deconvolution2DBuilderImpl> {
+public Deconvolution2D build() {
+Deconvolution2D l = new Deconvolution2D(this);
 l.initializeConstraints();
 return l;
 }
+}
+
+public static abstract class Deconvolution2DBuilder<C extends Deconvolution2D, B extends Deconvolution2DBuilder<C, B>> extends ConvolutionLayerBuilder<C, B> {
 
 
 @Override
@@ -44,7 +44,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
 @Data
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
-@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder")
+@SuperBuilder(builderMethodName = "innerBuilder")
 public class Deconvolution3D extends ConvolutionLayer {
 /**
 * Set the convolution mode for the Convolution layer. See {@link ConvolutionMode} for more
@@ -56,6 +56,15 @@ public class Deconvolution3D extends ConvolutionLayer {
 private Convolution3D.DataFormat dataFormat =
 Convolution3D.DataFormat.NCDHW; // in libnd4j: 1 - NCDHW, 0 - NDHWC
 
+public static Deconvolution3DBuilder<?, ?> builder() {
+return innerBuilder()
+.kernelSize(new int[] {2, 2, 2})
+.stride(new int[] {1, 1, 1})
+.padding(new int[] {0, 0, 0})
+.dilation(new int[] {1, 1, 1})
+.convolutionDim(3);
+}
+
 protected boolean allowCausal() {
 // Causal convolution - allowed for 1D only
 return false;
@@ -69,13 +78,13 @@ public class Deconvolution3D extends ConvolutionLayer {
 public Deconvolution3D clone() {
 Deconvolution3D clone = (Deconvolution3D) super.clone();
 if (clone.getKernelSize() != null) {
-clone.setKernelSize( clone.getKernelSize().clone());
+clone.setKernelSize(clone.getKernelSize().clone());
 }
 if (clone.getStride() != null) {
-clone.setStride( clone.getStride().clone());
+clone.setStride(clone.getStride().clone());
 }
 if (clone.getPadding() != null) {
-clone.setPadding( clone.getPadding().clone());
+clone.setPadding(clone.getPadding().clone());
 }
 return clone;
 }
@@ -134,6 +143,11 @@ public class Deconvolution3D extends ConvolutionLayer {
 }
 }
 
+// private int[] kernelSize;
+// private int[] stride;
+// private int[] padding;
+// private int[] dilation;
+
 @Override
 public InputType getOutputType(int layerIndex, InputType inputType) {
 if (inputType == null || inputType.getType() != InputType.Type.CNN3D) {
@ -158,29 +172,16 @@ public class Deconvolution3D extends ConvolutionLayer {
|
||||||
Deconvolution3DLayer.class);
|
Deconvolution3DLayer.class);
|
||||||
}
|
}
|
||||||
|
|
||||||
//private int[] kernelSize;
|
public abstract static class Deconvolution3DBuilder<
|
||||||
//private int[] stride;
|
|
||||||
//private int[] padding;
|
|
||||||
//private int[] dilation;
|
|
||||||
|
|
||||||
public static abstract class Deconvolution3DBuilder<
|
|
||||||
C extends Deconvolution3D, B extends Deconvolution3DBuilder<C, B>>
|
C extends Deconvolution3D, B extends Deconvolution3DBuilder<C, B>>
|
||||||
extends ConvolutionLayerBuilder<C, B> {
|
extends ConvolutionLayerBuilder<C, B> {}
|
||||||
public C build() {
|
|
||||||
C l = initBuild();
|
private static final class Deconvolution3DBuilderImpl
|
||||||
|
extends Deconvolution3DBuilder<Deconvolution3D, Deconvolution3DBuilderImpl> {
|
||||||
|
public Deconvolution3D build() {
|
||||||
|
Deconvolution3D l = new Deconvolution3D(this);
|
||||||
l.initializeConstraints();
|
l.initializeConstraints();
|
||||||
return l;
|
return l;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Deconvolution3DBuilder<?,?> builder() {
|
|
||||||
return innerBuilder()
|
|
||||||
.kernelSize(new int[] {2, 2, 2})
|
|
||||||
.stride(new int[] {1, 1, 1})
|
|
||||||
.padding(new int[] {0, 0, 0})
|
|
||||||
.dilation(new int[] {1, 1, 1})
|
|
||||||
.convolutionDim(3);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
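A minimal usage sketch of the relocated Deconvolution3D.builder() entry point, assuming the Lombok-generated setters shown above; the sizes are hypothetical example values:

// builder() seeds the 3D defaults (kernel 2x2x2, stride/dilation 1, padding 0, convolutionDim 3);
// build() is resolved by the private Deconvolution3DBuilderImpl, which also runs initializeConstraints().
Deconvolution3D deconv =
    Deconvolution3D.builder()
        .kernelSize(new int[] {3, 3, 3})  // overrides the 2x2x2 default
        .stride(new int[] {2, 2, 2})
        .build();
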
@@ -40,7 +40,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
 @Data
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
-@SuperBuilder(buildMethodName = "initBuild")
+@SuperBuilder
 public class DepthwiseConvolution2D extends ConvolutionLayer {
   /**
    * Set channels multiplier for depth-wise convolution
@@ -145,21 +145,25 @@ public class DepthwiseConvolution2D extends ConvolutionLayer {
     this.dataFormat = ((InputType.InputTypeConvolutional) inputType).getFormat();
   }

-  public abstract static class DepthwiseConvolution2DBuilder<
-      C extends DepthwiseConvolution2D, B extends DepthwiseConvolution2DBuilder<C, B>>
-      extends ConvolutionLayerBuilder<C, B> {
-    public C build() {
+  private static final class DepthwiseConvolution2DBuilderImpl extends DepthwiseConvolution2DBuilder<DepthwiseConvolution2D, DepthwiseConvolution2DBuilderImpl> {
+    public DepthwiseConvolution2D build() {
+      DepthwiseConvolution2D l = new DepthwiseConvolution2D(this);
       Preconditions.checkState(
-          depthMultiplier$value > 0,
+          l.getDepthMultiplier() > 0,
           "Depth multiplier must be > 0, got %s",
-          depthMultiplier$value);
-      C l = this.initBuild();
+          l.getDepthMultiplier());
       ConvolutionUtils.validateConvolutionModePadding(l.getConvolutionMode(), l.getPadding());
       ConvolutionUtils.validateCnnKernelStridePadding(
           l.getKernelSize(), l.getStride(), l.getPadding());
       l.initializeConstraints();
       return l;
     }
+  }
+
+  public abstract static class DepthwiseConvolution2DBuilder<
+      C extends DepthwiseConvolution2D, B extends DepthwiseConvolution2DBuilder<C, B>>
+      extends ConvolutionLayerBuilder<C, B> {

     @Override
     public B kernelSize(int... kernelSize) {

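A sketch of the effect of moving the depth-multiplier check into DepthwiseConvolution2DBuilderImpl.build(); the depthMultiplier setter is assumed to be the Lombok-generated one for the @Builder.Default field, and the value is hypothetical:

// With plain @SuperBuilder the public entry point is the generated builder();
// the invalid value is only rejected when build() runs the Preconditions check.
try {
  DepthwiseConvolution2D.builder()
      .depthMultiplier(0)  // must be > 0
      .build();
} catch (IllegalStateException e) {
  // "Depth multiplier must be > 0, got 0"
}
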
@@ -20,6 +20,8 @@
 package org.deeplearning4j.nn.conf.layers;

+import java.util.Collection;
+import java.util.Map;
 import lombok.*;
 import lombok.experimental.Accessors;
 import lombok.experimental.SuperBuilder;
@@ -30,127 +32,136 @@ import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
 import org.deeplearning4j.nn.conf.memory.MemoryReport;
 import org.deeplearning4j.nn.params.EmbeddingLayerParamInitializer;
-import org.deeplearning4j.nn.weights.IWeightInit;
 import org.deeplearning4j.nn.weights.WeightInit;
 import org.deeplearning4j.nn.weights.embeddings.ArrayEmbeddingInitializer;
 import org.deeplearning4j.nn.weights.embeddings.EmbeddingInitializer;
 import org.deeplearning4j.nn.weights.embeddings.WeightInitEmbedding;
 import org.deeplearning4j.optimize.api.TrainingListener;
 import org.nd4j.linalg.activations.Activation;
-import org.nd4j.linalg.activations.impl.ActivationIdentity;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;

-import java.util.Collection;
-import java.util.Map;
-
 @Data
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
-@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder")
+@SuperBuilder(builderMethodName = "innerBuilder")
 public class EmbeddingLayer extends FeedForwardLayer {
   /**
    * If true: include bias parameters in the layer. False (default): no bias.
+   *
    * @param hasBias If true: include bias parameters in this layer
    */
-  @Accessors @Builder.Default
-  private boolean hasBias = false;
+  @Accessors @Builder.Default private boolean hasBias = false;

   /**
-   *Default to Identity activation - i.e., don't inherit.
-   * For example, if user sets ReLU as global default, they very likely don't intend to use it for Embedding layer also
-   *
+   * Default to Identity activation - i.e., don't inherit. For example, if user sets ReLU as global
+   * default, they very likely don't intend to use it for Embedding layer also
    */
   public static EmbeddingLayerBuilder<?, ?> builder() {
-    return innerBuilder()
-        .activation(Activation.IDENTITY);
+    return innerBuilder().activation(Activation.IDENTITY);
   }

   @Override
-  public Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> trainingListeners,
-      int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
+  public Layer instantiate(
+      NeuralNetConfiguration conf,
+      Collection<TrainingListener> trainingListeners,
+      int layerIndex,
+      INDArray layerParamsView,
+      boolean initializeParams,
+      DataType networkDataType) {
     LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
     org.deeplearning4j.nn.layers.feedforward.embedding.EmbeddingLayer ret =
-        new org.deeplearning4j.nn.layers.feedforward.embedding.EmbeddingLayer(lconf, networkDataType);
+        new org.deeplearning4j.nn.layers.feedforward.embedding.EmbeddingLayer(
+            lconf, networkDataType);
     runInheritance();

     ret.addTrainingListeners(trainingListeners);
     ret.setIndex(layerIndex);
     ret.setParamsViewArray(layerParamsView);
     Map<String, INDArray> paramTable = initializer().init(this, layerParamsView, initializeParams);
     ret.setParamTable(paramTable);
     ret.setLayerConfiguration(lconf);
     return ret;
   }

   @Override
   public ParamInitializer initializer() {
     return EmbeddingLayerParamInitializer.getInstance();
   }

   @Override
   public LayerMemoryReport getMemoryReport(InputType inputType) {
-    //Basically a dense layer, but no dropout is possible here, and no epsilons
+    // Basically a dense layer, but no dropout is possible here, and no epsilons
     InputType outputType = getOutputType(-1, inputType);

     val actElementsPerEx = outputType.arrayElementsPerExample();
     val numParams = initializer().numParams(this);
     val updaterStateSize = (int) getIUpdater().stateSize(numParams);

-    //Embedding layer does not use caching.
-    //Inference: no working memory - just activations (pullRows)
-    //Training: preout op, the only in-place ops on epsilon (from layer above) + assign ops
+    // Embedding layer does not use caching.
+    // Inference: no working memory - just activations (pullRows)
+    // Training: preout op, the only in-place ops on epsilon (from layer above) + assign ops

     return new LayerMemoryReport.Builder(name, EmbeddingLayer.class, inputType, outputType)
-        .standardMemory(numParams, updaterStateSize).workingMemory(0, 0, 0, actElementsPerEx)
-        .cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) //No caching
+        .standardMemory(numParams, updaterStateSize)
+        .workingMemory(0, 0, 0, actElementsPerEx)
+        .cacheMemory(
+            MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) // No caching
         .build();
   }

+  private static final class EmbeddingLayerBuilderImpl
+      extends EmbeddingLayerBuilder<EmbeddingLayer, EmbeddingLayerBuilderImpl> {
+    public EmbeddingLayer build() {
+      EmbeddingLayer l = new EmbeddingLayer(this);
+      l.initializeConstraints();
+      return l;
+    }
+  }
+
-  public static abstract class EmbeddingLayerBuilder<C extends EmbeddingLayer, B extends EmbeddingLayerBuilder<C,B>>
-      extends FeedForwardLayerBuilder<C,B>{
-    public C build() {
-      C l = initBuild();
-      l.initializeConstraints();
-      return l;
-    }
+  public abstract static class EmbeddingLayerBuilder<
+          C extends EmbeddingLayer, B extends EmbeddingLayerBuilder<C, B>>
+      extends FeedForwardLayerBuilder<C, B> {

     /**
      * Weight initialization scheme to use, for initial weight values
      *
      * @param weightInit
      * @see WeightInit
      */
     @Override
     public B weightInit(WeightInit weightInit) {
-      if(weightInit.getWeightInitFunction() instanceof WeightInitEmbedding){
+      if (weightInit.getWeightInitFunction() instanceof WeightInitEmbedding) {
         long[] shape = ((WeightInitEmbedding) weightInit.getWeightInitFunction()).shape();
         nIn(shape[0]);
         nOut(shape[1]);
       }
       super.weightInit(weightInit);
       return self();
     }

     /**
-     * Initialize the embedding layer using values from the specified array. Note that the array should have shape
-     * [vocabSize, vectorSize]. After copying values from the array to initialize the network parameters, the input
-     * array will be discarded (so that, if necessary, it can be garbage collected)
+     * Initialize the embedding layer using values from the specified array. Note that the array
+     * should have shape [vocabSize, vectorSize]. After copying values from the array to initialize
+     * the network parameters, the input array will be discarded (so that, if necessary, it can be
+     * garbage collected)
      *
      * @param vectors Vectors to initialize the embedding layer with
      */
-    public B weightInit(INDArray vectors){
+    public B weightInit(INDArray vectors) {
       weightInit(new ArrayEmbeddingInitializer(vectors));
       return self();
     }

     /**
-     * Initialize the embedding layer using the specified EmbeddingInitializer - such as a Word2Vec instance
+     * Initialize the embedding layer using the specified EmbeddingInitializer - such as a Word2Vec
+     * instance
      *
      * @param embeddingInitializer Source of the embedding layer weights
      */
     public B weightInit(EmbeddingInitializer embeddingInitializer) {
       var weightIn = new WeightInitEmbedding(embeddingInitializer);
       super.weightInit(weightIn);
       return self();
     }
   }
 }

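A usage sketch of the EmbeddingLayer builder overloads kept in this diff; the vocabulary and vector sizes are hypothetical:

// builder() defaults the activation to IDENTITY; weightInit(INDArray) wraps the
// [vocabSize, vectorSize] matrix in an ArrayEmbeddingInitializer / WeightInitEmbedding.
INDArray vectors = Nd4j.rand(10000, 300);
EmbeddingLayer embedding =
    EmbeddingLayer.builder()
        .weightInit(vectors)
        .hasBias(true)
        .build();
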
@@ -46,7 +46,7 @@ import java.util.Map;
 @Data
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
-@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder")
+@SuperBuilder(builderMethodName = "innerBuilder")
 public class EmbeddingSequenceLayer extends FeedForwardLayer {
   /**
    * Set input sequence length for this embedding layer.
@@ -70,13 +70,16 @@ public class EmbeddingSequenceLayer extends FeedForwardLayer {
   @Builder.Default private boolean inferInputLength = false; // use input length as provided by input data
   @Builder.Default private RNNFormat outputDataFormat = RNNFormat.NCW; //Default value for older deserialized models

+  private static final class EmbeddingSequenceLayerBuilderImpl extends EmbeddingSequenceLayerBuilder<EmbeddingSequenceLayer, EmbeddingSequenceLayerBuilderImpl> {
+    public EmbeddingSequenceLayer build() {
+      EmbeddingSequenceLayer l = new EmbeddingSequenceLayer(this);
+      l.initializeConstraints();
+      return l;
+    }
+  }
+
   public static abstract class EmbeddingSequenceLayerBuilder<C extends EmbeddingSequenceLayer, B extends EmbeddingSequenceLayerBuilder<C, B>>
       extends FeedForwardLayerBuilder<C, B> {
-    public C build() {
-      C l = initBuild();
-      l.initializeConstraints();
-      return l;
-    }

     public B weightInit(IWeightInit weightInit){
       if(weightInit instanceof WeightInitEmbedding){

@@ -20,6 +20,7 @@
 package org.deeplearning4j.nn.conf.layers;

+import java.util.*;
 import lombok.*;
 import lombok.experimental.SuperBuilder;
 import org.deeplearning4j.nn.api.Layer;
@@ -36,23 +37,21 @@ import org.nd4j.linalg.activations.impl.ActivationSigmoid;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;

-import java.util.*;
-
 @Data
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
 @Deprecated
-@SuperBuilder(buildMethodName = "initBuild")
+@SuperBuilder
 public class GravesBidirectionalLSTM extends BaseRecurrentLayer {

-  public static abstract class GravesBidirectionalLSTMBuilder<C extends GravesBidirectionalLSTM, B extends
-      GravesBidirectionalLSTMBuilder<C, B>> extends BaseRecurrentLayerBuilder<C, B> {
-    public C build() {
-      C l = this.initBuild();
-      l.initializeConstraints();
-      return l;
-    }
-  }
+  /**
+   * When using CuDNN and an error is encountered, should fallback to the non-CuDNN implementatation be allowed?
+   * If set to false, an exception in CuDNN will be propagated back to the user. If false, the built-in
+   * (non-CuDNN) implementation for GravesBidirectionalLSTM will be used
+   *
+   */
+  @Builder.Default
+  protected boolean helperAllowFallback = true;
+
   /**
    * Set forget gate bias initalizations. Values in range 1-5 can potentially help with learning or longer-term
    * dependencies.
@@ -66,15 +65,6 @@ public class GravesBidirectionalLSTM extends BaseRecurrentLayer {
    */
   @Builder.Default
   private IActivation gateActivationFunction = new ActivationSigmoid();
-  /**
-   * When using CuDNN and an error is encountered, should fallback to the non-CuDNN implementatation be allowed?
-   * If set to false, an exception in CuDNN will be propagated back to the user. If false, the built-in
-   * (non-CuDNN) implementation for GravesBidirectionalLSTM will be used
-   *
-   */
-  @Builder.Default
-  protected boolean helperAllowFallback = true;

   @Override
   protected void initializeConstraints() {
@@ -121,5 +111,18 @@ public class GravesBidirectionalLSTM extends BaseRecurrentLayer {
     return LSTMHelpers.getMemoryReport(this, inputType);
   }

+  private static final class GravesBidirectionalLSTMBuilderImpl extends GravesBidirectionalLSTMBuilder<GravesBidirectionalLSTM, GravesBidirectionalLSTMBuilderImpl> {
+    public GravesBidirectionalLSTM build() {
+      GravesBidirectionalLSTM l = new GravesBidirectionalLSTM(this);
+      l.initializeConstraints();
+      return l;
+    }
+  }
+
+  public static abstract class GravesBidirectionalLSTMBuilder<C extends GravesBidirectionalLSTM, B extends
+      GravesBidirectionalLSTMBuilder<C, B>> extends BaseRecurrentLayerBuilder<C, B> {
+
+  }
+
 }

@@ -43,7 +43,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
 @Deprecated
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
-@SuperBuilder(buildMethodName = "initBuild")
+@SuperBuilder
 public class GravesLSTM extends AbstractLSTM {

   private double forgetGateBiasInit;
@@ -103,9 +103,12 @@ public class GravesLSTM extends AbstractLSTM {

   public abstract static class GravesLSTMBuilder<
           C extends GravesLSTM, B extends GravesLSTMBuilder<C, B>>
-      extends AbstractLSTMBuilder<C, B> {
-    public C build() {
-      C l = initBuild();
+      extends AbstractLSTMBuilder<C, B> {}
+
+  private static final class GravesLSTMBuilderImpl
+      extends GravesLSTMBuilder<GravesLSTM, GravesLSTMBuilderImpl> {
+    public GravesLSTM build() {
+      GravesLSTM l = new GravesLSTM(this);
       l.initializeConstraints();
       return l;
     }

@@ -20,6 +20,10 @@
 package org.deeplearning4j.nn.conf.layers;

+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Map;
 import lombok.*;
 import lombok.experimental.SuperBuilder;
 import org.deeplearning4j.nn.api.Layer;
@@ -31,71 +35,75 @@ import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
 import org.deeplearning4j.nn.layers.recurrent.LSTMHelpers;
 import org.deeplearning4j.nn.params.LSTMParamInitializer;
 import org.deeplearning4j.optimize.api.TrainingListener;
-import org.nd4j.linalg.activations.Activation;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;

-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.Collections;
-import java.util.Map;
-
 @Data
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
-@SuperBuilder(buildMethodName = "initBuild")
+@SuperBuilder
 public class LSTM extends AbstractLSTM {

   private double forgetGateBiasInit;

-  public static abstract class LSTMBuilder<C extends LSTM, B extends LSTMBuilder<C, B>> extends AbstractLSTMBuilder<C, B> {
-    @Override public C build() {
-      C l = this.initBuild();
-      l.initializeConstraints();
-      return l;
-    }
-  }
-
   @Override
   protected void initializeConstraints() {
     super.initializeConstraints();
     if (recurrentConstraints != null) {
       if (constraints == null) {
         constraints = new ArrayList<>();
       }
       for (LayerConstraint c : recurrentConstraints) {
         LayerConstraint c2 = c.clone();
         c2.setParams(Collections.singleton(LSTMParamInitializer.RECURRENT_WEIGHT_KEY));
         constraints.add(c2);
       }
     }
   }

   @Override
-  public Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> trainingListeners,
-      int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
+  public Layer instantiate(
+      NeuralNetConfiguration conf,
+      Collection<TrainingListener> trainingListeners,
+      int layerIndex,
+      INDArray layerParamsView,
+      boolean initializeParams,
+      DataType networkDataType) {
     LayerValidation.assertNInNOutSet("LSTM", getName(), layerIndex, getNIn(), getNOut());
     LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
     runInheritance();

-    org.deeplearning4j.nn.layers.recurrent.LSTM ret = new org.deeplearning4j.nn.layers.recurrent.LSTM(lconf, networkDataType);
+    org.deeplearning4j.nn.layers.recurrent.LSTM ret =
+        new org.deeplearning4j.nn.layers.recurrent.LSTM(lconf, networkDataType);
     ret.addTrainingListeners(trainingListeners);
     ret.setIndex(layerIndex);
     ret.setParamsViewArray(layerParamsView);
     Map<String, INDArray> paramTable = initializer().init(this, layerParamsView, initializeParams);
     ret.setParamTable(paramTable);
     ret.setLayerConfiguration(lconf);
     return ret;
   }

   @Override
   public ParamInitializer initializer() {
     return LSTMParamInitializer.getInstance();
   }

   @Override
   public LayerMemoryReport getMemoryReport(InputType inputType) {
-    //TODO - CuDNN etc
+    // TODO - CuDNN etc
     return LSTMHelpers.getMemoryReport(this, inputType);
   }
+
+  public abstract static class LSTMBuilder<C extends LSTM, B extends LSTMBuilder<C, B>>
+      extends AbstractLSTMBuilder<C, B> {}
+
+  private static final class LSTMBuilderImpl extends LSTMBuilder<LSTM, LSTMBuilderImpl> {
+    @Override
+    public LSTM build() {
+      LSTM l = new LSTM(this);
+      l.initializeConstraints();
+      return l;
+    }
+  }
 }

@@ -354,15 +354,5 @@ public abstract class LayerConfiguration
       biasConstraints = Arrays.asList(constraints);
       return self();
     }
-
-    /**
-     * we are doing this to avoid BUG https://github.com/projectlombok/lombok/issues/3419 as some
-     * child classes may specify their own buildMethodName in @SuperBuilder, but we use only
-     * "initBuild" here consequently
-     * @return
-     */
-    public C initBuild() {
-      return build();
-    }
   }
 }

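For orientation, a self-contained sketch of the pattern this commit applies throughout, written against hypothetical classes rather than the real hierarchy: the abstract @SuperBuilder builder keeps only fluent setters, and any post-construction work lives in the hand-written private *BuilderImpl, which removes the need for the initBuild()/buildMethodName workaround for Lombok issue 3419.

import lombok.experimental.SuperBuilder;

@SuperBuilder
class ExampleLayer {
  private int size;

  // Lombok fills in the missing members of a hand-written builder pair.
  public abstract static class ExampleLayerBuilder<
          C extends ExampleLayer, B extends ExampleLayerBuilder<C, B>> {
    // shared fluent setters only; no custom build() here
  }

  private static final class ExampleLayerBuilderImpl
      extends ExampleLayerBuilder<ExampleLayer, ExampleLayerBuilderImpl> {
    @Override
    public ExampleLayer build() {
      ExampleLayer l = new ExampleLayer(this);
      // validation / initialization hooks go here
      return l;
    }
  }
}
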
@@ -41,7 +41,7 @@ import org.nd4j.linalg.factory.Nd4j;

 @Data
 @EqualsAndHashCode(callSuper = true)
-@SuperBuilder(buildMethodName = "initBuild")
+@SuperBuilder
 public class LearnedSelfAttentionLayer extends SameDiffLayer {
   private static final String WEIGHT_KEY_QUERY_PROJECTION = "Wq";
   private static final String WEIGHT_KEY_KEY_PROJECTION = "Wk";
@@ -173,19 +173,24 @@ public class LearnedSelfAttentionLayer extends SameDiffLayer {
   public static abstract class LearnedSelfAttentionLayerBuilder<
           C extends LearnedSelfAttentionLayer, B extends LearnedSelfAttentionLayerBuilder<C, B>>
       extends SameDiffLayerBuilder<C, B> {
-    public C build() {
-      Preconditions.checkArgument(
-          this.projectInput || this.nHeads == 1, "projectInput must be true when nHeads != 1");
-      Preconditions.checkArgument(
-          this.projectInput || nIn == nOut, "nIn must be equal to nOut when projectInput is false");
-      Preconditions.checkArgument(
-          !this.projectInput || nOut != 0, "nOut must be specified when projectInput is true");
-      Preconditions.checkArgument(
-          this.nOut % nHeads == 0 || headSize > 0,
-          "nOut isn't divided by nHeads cleanly. Specify the headSize manually.");
-      Preconditions.checkArgument(this.nQueries > 0, "You must set numQueries.");
-
-      return initBuild();
+  }
+
+  private static final class LearnedSelfAttentionLayerBuilderImpl extends LearnedSelfAttentionLayerBuilder<LearnedSelfAttentionLayer, LearnedSelfAttentionLayerBuilderImpl> {
+    public LearnedSelfAttentionLayer build() {
+      LearnedSelfAttentionLayer l = new LearnedSelfAttentionLayer(this);
+      Preconditions.checkArgument(
+          l.isProjectInput() || l.getNHeads() == 1, "projectInput must be true when nHeads != 1");
+      Preconditions.checkArgument(
+          l.isProjectInput() || l.getNIn() == l.getNOut(), "nIn must be equal to nOut when projectInput is false");
+      Preconditions.checkArgument(
+          !l.isProjectInput() || l.getNOut() != 0, "nOut must be specified when projectInput is true");
+      Preconditions.checkArgument(
+          l.getNOut() % l.getNHeads() == 0 || l.getHeadSize() > 0,
+          "nOut isn't divided by nHeads cleanly. Specify the headSize manually.");
+      Preconditions.checkArgument(l.getNQueries() > 0, "You must set numQueries.");
+
+      return l;
     }
   }
 }

@@ -48,19 +48,9 @@ import org.nd4j.linalg.factory.Nd4j;
 @Data
 @EqualsAndHashCode(callSuper = true)
 @JsonIgnoreProperties({"paramShapes"})
-@SuperBuilder(buildMethodName = "initBuild")
+@SuperBuilder
 public class LocallyConnected1D extends SameDiffLayer {

-  public static abstract class LocallyConnected1DBuilder<C extends LocallyConnected1D, B extends LocallyConnected1DBuilder<C, B>> extends
-      SameDiffLayerBuilder<C, B> {
-    public C build() {
-      Convolution1DUtils.validateConvolutionModePadding(convolutionMode$value, padding$value);
-      Convolution1DUtils.validateCnn1DKernelStridePadding(kernelSize$value, stride$value, padding$value);
-      C l = initBuild();
-      return l;
-    }
-  }
-
   private static final List<String> WEIGHT_KEYS =
       Collections.singletonList(ConvolutionParamInitializer.WEIGHT_KEY);
   private static final List<String> BIAS_KEYS =
@@ -89,10 +79,8 @@ public class LocallyConnected1D extends SameDiffLayer {
   private int paddingR; // Right/bottom padding
   /** Convolution mode for the layer. See {@link ConvolutionMode} for details */
   @Builder.Default private ConvolutionMode convolutionMode = ConvolutionMode.Same;
-
   /** Dilation for the layer */
   @Builder.Default private int dilation = 1;
-
   /** If true (default is false) the layer will have a bias */
   @Builder.Default private boolean hasBias = true;

@@ -272,4 +260,20 @@ public class LocallyConnected1D extends SameDiffLayer {
       convolutionMode = global_conf.getConvolutionMode();
     }
   }

+  private static final class LocallyConnected1DBuilderImpl
+      extends LocallyConnected1DBuilder<LocallyConnected1D, LocallyConnected1DBuilderImpl> {
+    public LocallyConnected1D build() {
+      LocallyConnected1D l = new LocallyConnected1D(this);
+      Convolution1DUtils.validateConvolutionModePadding(l.getConvolutionMode(), l.getPadding());
+      Convolution1DUtils.validateCnn1DKernelStridePadding(
+          l.getKernelSize(), l.getStride(), l.getPadding());
+
+      return l;
+    }
+  }
+
+  public abstract static class LocallyConnected1DBuilder<
+          C extends LocallyConnected1D, B extends LocallyConnected1DBuilder<C, B>>
+      extends SameDiffLayerBuilder<C, B> {}
 }

@@ -41,7 +41,6 @@ import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.enums.PadMode;
 import org.nd4j.linalg.activations.Activation;
-import org.nd4j.linalg.activations.IActivation;
 import org.nd4j.linalg.api.memory.MemoryWorkspace;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
@@ -49,7 +48,7 @@ import org.nd4j.linalg.factory.Nd4j;
 @Data
 @EqualsAndHashCode(callSuper = true)
 @JsonIgnoreProperties({"paramShapes"})
-@SuperBuilder(buildMethodName = "initBuild")
+@SuperBuilder
 public class LocallyConnected2D extends SameDiffLayer {

   private static final List<String> WEIGHT_KEYS =
@@ -318,40 +317,44 @@ public class LocallyConnected2D extends SameDiffLayer {
     }
   }

-  public static abstract class LocallyConnected2DBuilder<
-      C extends LocallyConnected2D, B extends LocallyConnected2DBuilder<C, B>>
-      extends SameDiffLayerBuilder<C, B> {
-    public C build() {
-      featureDim(kernel$value[0] * kernel$value[1] * (int) nIn);
-      C l = initBuild();
+  private static final class LocallyConnected2DBuilderImpl
+      extends LocallyConnected2DBuilder<LocallyConnected2D, LocallyConnected2DBuilderImpl> {
+    public LocallyConnected2D build() {
+      LocallyConnected2D l = new LocallyConnected2D(this);
+      l.setFeatureDim(l.getKernel()[0] * l.getKernel()[1] * (int) l.getNIn());
       return l;
     }
+  }

-    public B kernelSize(int ... kernel) {
+  public abstract static class LocallyConnected2DBuilder<
+          C extends LocallyConnected2D, B extends LocallyConnected2DBuilder<C, B>>
+      extends SameDiffLayerBuilder<C, B> {
+
+    public B kernelSize(int... kernel) {
       this.kernel$value = ValidationUtils.validate2NonNegative(kernel, false, "kernel");
       this.kernel$set = true;
       return self();
     }

-    public B inputSize(int ... size) {
+    public B inputSize(int... size) {
       this.inputSize = size;
       return self();
     }

-    public B stride(int ... stride) {
+    public B stride(int... stride) {
       this.stride$value = ValidationUtils.validate2NonNegative(stride, false, "stride");
       this.stride$set = true;
       return self();
     }

-    public B padding(int ... padding) {
+    public B padding(int... padding) {
       this.padding$value = ValidationUtils.validate2NonNegative(padding, false, "padding");
       this.padding$set = true;
       return self();
     }

-    public B dilation(int ... dilation) {
+    public B dilation(int... dilation) {
       this.dilation$value = ValidationUtils.validate2NonNegative(dilation, false, "dilation");
       this.dilation$set = true;
       return self();
     }

@@ -22,7 +22,6 @@ package org.deeplearning4j.nn.conf.layers;

 import java.util.Collection;
 import java.util.Map;
-
 import lombok.*;
 import lombok.experimental.SuperBuilder;
 import org.deeplearning4j.nn.api.Layer;
@@ -33,13 +32,12 @@ import org.deeplearning4j.optimize.api.TrainingListener;
 import org.nd4j.linalg.activations.Activation;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.lossfunctions.ILossFunction;
 import org.nd4j.linalg.lossfunctions.LossFunctions;

 @Data
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
-@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder")
+@SuperBuilder(builderMethodName = "innerBuilder")
 public class OutputLayer extends BaseOutputLayer {

   { // Set default activation function to softmax (to match default loss function MCXENT)
@@ -82,15 +80,16 @@ public class OutputLayer extends BaseOutputLayer {
     return DefaultParamInitializer.getInstance();
   }

-  public static abstract class OutputLayerBuilder<
+  public abstract static class OutputLayerBuilder<
       C extends OutputLayer, B extends OutputLayerBuilder<C, B>>
-      extends BaseOutputLayerBuilder<C, B> {
-    public C build() {
-      C l = this.initBuild();
+      extends BaseOutputLayerBuilder<C, B> {}
+
+  private static final class OutputLayerBuilderImpl
+      extends OutputLayerBuilder<OutputLayer, OutputLayerBuilderImpl> {
+    public OutputLayer build() {
+      OutputLayer l = new OutputLayer(this);
       l.initializeConstraints();
       return l;
     }
   }
 }

@@ -40,7 +40,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
 @Data
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
-@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder")
+@SuperBuilder(builderMethodName = "innerBuilder")
 public class PReLULayer extends BaseLayerConfiguration {
   /**
    * Explicitly set input shape of incoming activations so that parameters can be initialized
@@ -129,14 +129,17 @@ public class PReLULayer extends BaseLayerConfiguration {
         .build();
   }

-  public static abstract class PReLULayerBuilder<
-      C extends PReLULayer, B extends PReLULayerBuilder<C, B>>
-      extends BaseLayerConfigurationBuilder<C, B> {
-    public C build() {
-      C l = initBuild();
+  private static final class PReLULayerBuilderImpl extends PReLULayerBuilder<PReLULayer, PReLULayerBuilderImpl> {
+    public PReLULayer build() {
+      PReLULayer l = new PReLULayer(this);
       l.initializeConstraints();
       return l;
     }
+  }
+
+  public static abstract class PReLULayerBuilder<
+      C extends PReLULayer, B extends PReLULayerBuilder<C, B>>
+      extends BaseLayerConfigurationBuilder<C, B> {

     /**
      * Explicitly set input shape of incoming activations so that parameters can be initialized

@@ -35,6 +35,6 @@ import lombok.experimental.SuperBuilder;

 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
-@SuperBuilder(buildMethodName = "initBuild")
+@SuperBuilder
 public class Pooling1D extends Subsampling1DLayer {
 }

@@ -35,6 +35,6 @@ import lombok.experimental.SuperBuilder;

 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
-@SuperBuilder(buildMethodName = "initBuild")
+@SuperBuilder
 public class Pooling2D extends SubsamplingLayer {
 }

@@ -41,7 +41,7 @@ import org.nd4j.linalg.factory.Nd4j;

 @Data
 @EqualsAndHashCode(callSuper = true)
-@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder")
+@SuperBuilder(builderMethodName = "innerBuilder")
 public class PrimaryCapsules extends SameDiffLayer {

   private static final String WEIGHT_PARAM = "weight";
@@ -335,7 +335,7 @@ public class PrimaryCapsules extends SameDiffLayer {
     }
   }

-  public static abstract class PrimaryCapsulesBuilder<
+  public abstract static class PrimaryCapsulesBuilder<
       C extends PrimaryCapsules, B extends PrimaryCapsulesBuilder<C, B>>
       extends SameDiffLayerBuilder<C, B> {

@@ -396,27 +396,30 @@ public class PrimaryCapsules extends SameDiffLayer {
       this.useLeakyReLU$set = true;
       return self();
     }
+  }

-    public C build() {
-      C l = initBuild();
-      if (capsuleDimensions <= 0 || channels$value <= 0) {
+  private static final class PrimaryCapsulesBuilderImpl
+      extends PrimaryCapsulesBuilder<PrimaryCapsules, PrimaryCapsulesBuilderImpl> {
+    public PrimaryCapsules build() {
+      PrimaryCapsules l = new PrimaryCapsules(this);
+      if (l.getCapsuleDimensions() <= 0 || l.getChannels() <= 0) {
         throw new IllegalArgumentException(
             "Invalid configuration for Primary Capsules (layer name = \""
                 + l.getName()
                 + "\"):"
                 + " capsuleDimensions and channels must be > 0. Got: "
-                + capsuleDimensions
+                + l.getCapsuleDimensions()
                 + ", "
-                + channels$value);
+                + l.getChannels());
       }

-      if (capsules < 0) {
+      if (l.getCapsules() < 0) {
         throw new IllegalArgumentException(
             "Invalid configuration for Capsule ILayer (layer name = \""
                 + l.getName()
                 + "\"):"
                 + " capsules must be >= 0 if set. Got: "
-                + capsules);
+                + l.getCapsules());
       }
       return l;
     }

@@ -43,21 +43,25 @@ import java.util.Map;

 @Data
 @EqualsAndHashCode(callSuper = true)
-@SuperBuilder(buildMethodName = "initBuild")
+@SuperBuilder
 public class RecurrentAttentionLayer extends SameDiffLayer {

+  private static final class RecurrentAttentionLayerBuilderImpl extends RecurrentAttentionLayerBuilder<RecurrentAttentionLayer, RecurrentAttentionLayerBuilderImpl> {
+    public RecurrentAttentionLayer build() {
+      RecurrentAttentionLayer l = new RecurrentAttentionLayer(this);
+      Preconditions.checkArgument(l.isProjectInput() || l.getNHeads() == 1, "projectInput must be true when nHeads != 1");
+      Preconditions.checkArgument(l.isProjectInput() || l.getNIn() == l.getNOut(), "nIn must be equal to nOut when projectInput is false");
+      Preconditions.checkArgument(!l.isProjectInput() || l.getNOut() != 0, "nOut must be specified when projectInput is true");
+      Preconditions.checkArgument(l.getNOut() % l.getNHeads() == 0 || l.getNHeads() > 0, "nOut isn't divided by nHeads cleanly. Specify the headSize manually.");
+
+      return l;
+    }
+  }
+
   public static abstract class RecurrentAttentionLayerBuilder<C extends RecurrentAttentionLayer, B extends RecurrentAttentionLayerBuilder<C,B>>
       extends SameDiffLayerBuilder<C,B> {
-
-    public C build() {
-      Preconditions.checkArgument(this.projectInput$value || this.nHeads == 1, "projectInput must be true when nHeads != 1");
-      Preconditions.checkArgument(this.projectInput$value || nIn == nOut, "nIn must be equal to nOut when projectInput is false");
-      Preconditions.checkArgument(!this.projectInput$value || nOut != 0, "nOut must be specified when projectInput is true");
-      Preconditions.checkArgument(this.nOut % nHeads == 0 || headSize > 0, "nOut isn't divided by nHeads cleanly. Specify the headSize manually.");
-
-      C l = initBuild();
-      return l;
-    }
   }

   /**

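A configuration sketch for the reworked RecurrentAttentionLayer checks, assuming the nIn/nOut/nHeads/projectInput builder setters implied by the Preconditions above; the sizes are hypothetical:

// The argument checks now run in RecurrentAttentionLayerBuilderImpl.build()
// against the constructed layer's getters.
RecurrentAttentionLayer attention =
    RecurrentAttentionLayer.builder()
        .nIn(128)
        .nOut(128)
        .nHeads(4)
        .projectInput(true)
        .build();
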
@ -24,7 +24,6 @@ import java.util.Collection;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.NoArgsConstructor;
|
|
||||||
import lombok.ToString;
|
import lombok.ToString;
|
||||||
import lombok.experimental.SuperBuilder;
|
import lombok.experimental.SuperBuilder;
|
||||||
import org.deeplearning4j.nn.api.Layer;
|
import org.deeplearning4j.nn.api.Layer;
|
||||||
|
@ -42,88 +41,104 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
@Data
|
@Data
|
||||||
@ToString(callSuper = true)
|
@ToString(callSuper = true)
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder")
|
@SuperBuilder(builderMethodName = "innerBuilder")
|
||||||
public class RnnOutputLayer extends BaseOutputLayer {
|
public class RnnOutputLayer extends BaseOutputLayer {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param rnnDataFormat Data format expected by the layer. NCW = [miniBatchSize, size, timeSeriesLength],
|
* @param rnnDataFormat Data format expected by the layer. NCW = [miniBatchSize, size,
|
||||||
* NWC = [miniBatchSize, timeSeriesLength, size]. Defaults to NCW.
|
* timeSeriesLength], NWC = [miniBatchSize, timeSeriesLength, size]. Defaults to NCW.
|
||||||
*/
|
*/
|
||||||
private RNNFormat dataFormat;
|
private RNNFormat dataFormat;
|
||||||
|
|
||||||
public static RnnOutputLayerBuilder<?,?> builder() {
|
public static RnnOutputLayerBuilder<?, ?> builder() {
|
||||||
return innerBuilder();
|
return innerBuilder();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param lossFn Loss function for the output layer
|
||||||
|
*/
|
||||||
|
public static RnnOutputLayerBuilder<?, ?> builder(LossFunctions.LossFunction lossFn) {
|
||||||
|
return innerBuilder().lossFunction(lossFn);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Layer instantiate(
|
||||||
|
NeuralNetConfiguration conf,
|
||||||
|
Collection<TrainingListener> trainingListeners,
|
||||||
|
int layerIndex,
        INDArray layerParamsView,
        boolean initializeParams,
        DataType networkDataType) {
    LayerValidation.assertNInNOutSet("RnnOutputLayer", getName(), layerIndex, getNIn(), getNOut());
    LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);

    org.deeplearning4j.nn.layers.recurrent.RnnOutputLayer ret =
        new org.deeplearning4j.nn.layers.recurrent.RnnOutputLayer(lconf, networkDataType);
    ret.addTrainingListeners(trainingListeners);
    ret.setIndex(layerIndex);
    ret.setParamsViewArray(layerParamsView);
    Map<String, INDArray> paramTable = initializer().init(this, layerParamsView, initializeParams);
    ret.setParamTable(paramTable);
    ret.setLayerConfiguration(lconf);
    return ret;
  }

  /**
   * @param lossFn Loss function for the output layer
   */
  public static RnnOutputLayerBuilder<?, ?> builder(LossFunctions.LossFunction lossFn) {
    return innerBuilder()
        .lossFunction(lossFn);
  }

  @Override
  public ParamInitializer initializer() {
    return DefaultParamInitializer.getInstance();
  }

  @Override
  public InputType getOutputType(int layerIndex, InputType inputType) {
    if (inputType == null || inputType.getType() != InputType.Type.RNN) {
      throw new IllegalStateException(
          "Invalid input type for RnnOutputLayer (layer index = "
              + layerIndex
              + ", layer name=\""
              + getName()
              + "\"): Expected RNN input, got "
              + inputType);
    }
    InputType.InputTypeRecurrent itr = (InputType.InputTypeRecurrent) inputType;

    return InputType.recurrent(nOut, itr.getTimeSeriesLength(), itr.getFormat());
  }

  @Override
  public void setNIn(InputType inputType, boolean override) {
    if (inputType == null || inputType.getType() != InputType.Type.RNN) {
      throw new IllegalStateException(
          "Invalid input type for RnnOutputLayer (layer name=\""
              + getName()
              + "\"): Expected RNN input, got "
              + inputType);
    }

    InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
    if (dataFormat == null || override) {
      this.dataFormat = r.getFormat();
    }

    if (nIn <= 0 || override) {
      this.nIn = r.getSize();
    }
  }

  @Override
  public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
    return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, dataFormat, getName());
  }

-  public static abstract class RnnOutputLayerBuilder<C extends RnnOutputLayer, B extends RnnOutputLayerBuilder<C, B>> extends BaseOutputLayerBuilder<C, B> {
-    public C build() {
-      C l = this.initBuild();
-      l.initializeConstraints();
-      return l;
-    }
-  }
+  public abstract static class RnnOutputLayerBuilder<
+          C extends RnnOutputLayer, B extends RnnOutputLayerBuilder<C, B>>
+      extends BaseOutputLayerBuilder<C, B> {}
+
+  private static final class RnnOutputLayerBuilderImpl
+      extends RnnOutputLayerBuilder<RnnOutputLayer, RnnOutputLayerBuilderImpl> {
+    public RnnOutputLayer build() {
+      RnnOutputLayer l = new RnnOutputLayer(this);
+      l.initializeConstraints();
+      return l;
+    }
+  }
}
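Note: the generated entry point is renamed to innerBuilder() so that the hand-written static builder(...) factories above can preconfigure required fields, and constraint setup now runs inside RnnOutputLayerBuilderImpl.build(). A minimal usage sketch follows; the nIn()/nOut() setter names are assumed to come from the inherited FeedForwardLayer builder chain and are not part of this diff:

import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;

class RnnOutputLayerBuilderExample {
  public static void main(String[] args) {
    // Sizes are illustrative; nIn()/nOut() are inherited @SuperBuilder setters (assumed here).
    RnnOutputLayer out =
        RnnOutputLayer.builder(LossFunction.MCXENT)
            .nIn(256)  // size of the incoming recurrent activations
            .nOut(10)  // number of output classes
            .build();  // RnnOutputLayerBuilderImpl.build() also calls initializeConstraints()
    System.out.println(out);
  }
}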

@@ -38,7 +38,7 @@ import org.nd4j.linalg.factory.Nd4j;

@Data
@EqualsAndHashCode(callSuper = true)
-@SuperBuilder(buildMethodName = "initBuild")
+@SuperBuilder
public class SelfAttentionLayer extends SameDiffLayer {
  private static final String WEIGHT_KEY_QUERY_PROJECTION = "Wq";
  private static final String WEIGHT_KEY_KEY_PROJECTION = "Wk";

@@ -44,7 +44,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;

@Data
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
-@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder")
+@SuperBuilder(builderMethodName = "innerBuilder")
public class SeparableConvolution2D extends ConvolutionLayer {
  /**
   * Set constraints to be applied to the point-wise convolution weight parameters of this layer.

@@ -50,7 +50,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;

@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
-@SuperBuilder(buildMethodName = "initBuild")
+@SuperBuilder
public class Subsampling1DLayer extends SubsamplingLayer {

  @Override
@@ -153,11 +153,9 @@ public class Subsampling1DLayer extends SubsamplingLayer {
    return true;
  }

-  public static abstract class Subsampling1DLayerBuilder<C extends Subsampling1DLayer, B extends Subsampling1DLayerBuilder<C, B>> extends
-      SubsamplingLayerBuilder<C, B> {
-
-    public C build() {
-      C l = this.initBuild();
+  private static final class Subsampling1DLayerBuilderImpl extends Subsampling1DLayerBuilder<Subsampling1DLayer, Subsampling1DLayerBuilderImpl> {
+    public Subsampling1DLayer build() {
+      Subsampling1DLayer l = new Subsampling1DLayer(this);
      if (l.getPoolingType() == org.deeplearning4j.nn.conf.layers.PoolingType.PNORM && l.getPnorm() <= 0) {
        throw new IllegalStateException(
            "Incorrect Subsampling config: p-norm must be set when using PoolingType.PNORM");
@@ -167,6 +165,11 @@ public class Subsampling1DLayer extends SubsamplingLayer {
      return l;
    }
  }
+
+  public static abstract class Subsampling1DLayerBuilder<C extends Subsampling1DLayer, B extends Subsampling1DLayerBuilder<C, B>> extends
+      SubsamplingLayerBuilder<C, B> {

  /**
   *
   * @param kernelSize

@@ -45,7 +45,7 @@ import org.nd4j.linalg.learning.regularization.Regularization;

@Data
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
-@SuperBuilder(builderMethodName = "innerBuilder", buildMethodName = "initBuild")
+@SuperBuilder(builderMethodName = "innerBuilder")
public class Subsampling3DLayer extends NoParamLayer {

  @Builder.Default protected ConvolutionMode convolutionMode = ConvolutionMode.Truncate;

@@ -304,17 +304,22 @@ public class Subsampling3DLayer extends NoParamLayer {
      return self();
    }

-    public C build() {
-      if (kernelSize.length != 3) {
+  }
+
+  private static final class Subsampling3DLayerBuilderImpl extends Subsampling3DLayerBuilder<Subsampling3DLayer, Subsampling3DLayerBuilderImpl> {
+    public Subsampling3DLayer build() {
+      Subsampling3DLayer l = new Subsampling3DLayer(this);
+      if (l.getKernelSize().length != 3) {
        throw new IllegalArgumentException("Kernel size must be length 3");
      }

-      if (stride.length != 3) {
+      if (l.getStride().length != 3) {
        throw new IllegalArgumentException("Invalid stride, must be length 3");
      }
-      C l = this.initBuild();
-      ConvolutionUtils.validateConvolutionModePadding(l.getConvolutionMode(), padding);
-      Convolution3DUtils.validateCnn3DKernelStridePadding(kernelSize, stride, padding);
+      ConvolutionUtils.validateConvolutionModePadding(l.getConvolutionMode(), l.getPadding());
+      Convolution3DUtils.validateCnn3DKernelStridePadding(l.getKernelSize(), l.getStride(), l.getPadding());
      return l;
    }
  }
}

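Note: the same Lombok customization recurs across these files: the buildMethodName = "initBuild" indirection is dropped, and validation moves into a hand-written private *BuilderImpl whose build() inspects the fully constructed object. A standalone sketch of that pattern, with illustrative class and field names that are not taken from this codebase:

import lombok.experimental.SuperBuilder;

@SuperBuilder
class BaseLayerConfig {
  protected String name;
}

@SuperBuilder
class PoolingLayerConfig extends BaseLayerConfig {
  protected int kernelSize;

  // Hand-written builder headers must match what Lombok would generate;
  // Lombok detects them and fills in the missing field setters and self() plumbing.
  public abstract static class PoolingLayerConfigBuilder<
          C extends PoolingLayerConfig, B extends PoolingLayerConfigBuilder<C, B>>
      extends BaseLayerConfig.BaseLayerConfigBuilder<C, B> {}

  private static final class PoolingLayerConfigBuilderImpl
      extends PoolingLayerConfigBuilder<PoolingLayerConfig, PoolingLayerConfigBuilderImpl> {
    @Override
    public PoolingLayerConfig build() {
      // Construct first, then validate the finished object, as the builders in this commit do.
      PoolingLayerConfig l = new PoolingLayerConfig(this);
      if (l.kernelSize <= 0) {
        throw new IllegalArgumentException("kernelSize must be positive");
      }
      return l;
    }
  }
}

class BuilderImplExample {
  public static void main(String[] args) {
    PoolingLayerConfig cfg =
        PoolingLayerConfig.builder().name("pool1").kernelSize(2).build();
    System.out.println(cfg.kernelSize);
  }
}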
@@ -45,7 +45,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;

@Data
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
-@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder")
+@SuperBuilder(builderMethodName = "innerBuilder")
public class SubsamplingLayer extends NoParamLayer {

  public static final CNN2DFormat DEFAULT_FORMAT = CNN2DFormat.NCHW;

@@ -425,25 +425,7 @@
      return self();
    }

-    public C build() {
-      if (kernelSize$value.length != 2) {
-        throw new IllegalArgumentException("Kernel size of should be rows x columns (a 2d array)");
-      }
-
-      if (stride$value.length != 2) {
-        throw new IllegalArgumentException("Invalid stride, must be length 2");
-      }
-      if (poolingType$value == org.deeplearning4j.nn.conf.layers.PoolingType.PNORM && pnorm <= 0) {
-        throw new IllegalStateException(
-            "Incorrect Subsampling config: p-norm must be set when using PoolingType.PNORM");
-      }
-      ConvolutionUtils.validateConvolutionModePadding(convolutionMode$value, padding$value);
-      ConvolutionUtils.validateCnnKernelStridePadding(
-          kernelSize$value, stride$value, padding$value);
-
-      C l = initBuild();
-      return l;
-    }

    public B setConvolutionMode(ConvolutionMode convolutionMode){
      Preconditions.checkState(allowCausal$value || convolutionMode$value != ConvolutionMode.Causal, "Causal convolution mode can only be used with 1D" +

@@ -459,4 +441,25 @@
      return self();
    }
  }
+
+  private static final class SubsamplingLayerBuilderImpl extends SubsamplingLayerBuilder<SubsamplingLayer, SubsamplingLayerBuilderImpl> {
+    public SubsamplingLayer build() {
+      SubsamplingLayer l = new SubsamplingLayer(this);
+      if (l.getKernelSize().length != 2) {
+        throw new IllegalArgumentException("Kernel size of should be rows x columns (a 2d array)");
+      }
+
+      if (l.getStride().length != 2) {
+        throw new IllegalArgumentException("Invalid stride, must be length 2");
+      }
+      if (l.getPoolingType() == org.deeplearning4j.nn.conf.layers.PoolingType.PNORM && l.getPnorm() <= 0) {
+        throw new IllegalStateException(
+            "Incorrect Subsampling config: p-norm must be set when using PoolingType.PNORM");
+      }
+      ConvolutionUtils.validateConvolutionModePadding(l.getConvolutionMode(), l.getPadding());
+      ConvolutionUtils.validateCnnKernelStridePadding(
+          l.getKernelSize(), l.getStride(), l.getPadding());
+
+      return l;
+    }
+  }
}

@@ -41,7 +41,7 @@ import java.util.Map;

@Data
@EqualsAndHashCode(callSuper = true)
-@SuperBuilder(builderMethodName = "innerBuilder", buildMethodName = "initBuild")
+@SuperBuilder(builderMethodName = "innerBuilder")
public class ZeroPaddingLayer extends NoParamLayer {
  /**
   * @param padding Padding value for top, bottom, left, and right. Must be length 4 array

@@ -48,7 +48,7 @@ import org.nd4j.serde.jackson.shaded.NDArrayTextSerializer;

@EqualsAndHashCode(callSuper = false)
-@SuperBuilder(buildMethodName = "initBuild")
+@SuperBuilder
public class Yolo2OutputLayer extends LayerConfiguration {

  /**

@@ -43,8 +43,7 @@ import org.nd4j.linalg.learning.regularization.Regularization;
public abstract class BaseWrapperLayerConfiguration extends LayerConfiguration {

  /** The configuration to of another layer to wrap */
-  @Getter @Setter
-  protected LayerConfiguration underlying;
+  @Getter @Setter protected LayerConfiguration underlying;

  /**
   * Set the net configuration for this configuration as well as for the underlying layer (if not