Using @SuperBuilder for LayerConfigurations

Signed-off-by: brian <brian@brutex.de>
master
Brian Rosenberger 2023-04-27 16:03:54 +02:00
parent 7628bbdd53
commit cb236878a4
7 changed files with 192 additions and 170 deletions

View File

@ -138,7 +138,7 @@ public class ActivationLayer extends NoParamLayer {
private static final class ActivationLayerBuilderImpl extends ActivationLayerBuilder<ActivationLayer, ActivationLayerBuilderImpl> { private static final class ActivationLayerBuilderImpl extends ActivationLayerBuilder<ActivationLayer, ActivationLayerBuilderImpl> {
public ActivationLayer build() { public ActivationLayer build() {
ActivationLayer l = this.initBuild(); ActivationLayer l = new ActivationLayer(this);
l.initializeConstraints(); l.initializeConstraints();
return l; return l;
} }

View File

@ -121,13 +121,10 @@ public class CenterLossOutputLayer extends BaseOutputLayer {
.build(); .build();
} }
public static abstract class CenterLossOutputLayerBuilder<C extends CenterLossOutputLayer, B extends CenterLossOutputLayerBuilder<C,B>> extends public static abstract class CenterLossOutputLayerBuilder<C extends CenterLossOutputLayer, B extends CenterLossOutputLayerBuilder<C,B>> extends
BaseOutputLayerBuilder<C, B> { BaseOutputLayerBuilder<C, B> {
public C build() {
C l = initBuild();
l.initializeConstraints();
return l;
}
} }
private static final class CenterLossOutputLayerBuilderImpl extends CenterLossOutputLayerBuilder<CenterLossOutputLayer, private static final class CenterLossOutputLayerBuilderImpl extends CenterLossOutputLayerBuilder<CenterLossOutputLayer,

View File

@ -185,7 +185,7 @@ public class Convolution1DLayer extends ConvolutionLayer {
private static final class Convolution1DLayerBuilderImpl extends ConvolutionLayerBuilder<ConvolutionLayer, Convolution1DLayerBuilderImpl> { private static final class Convolution1DLayerBuilderImpl extends ConvolutionLayerBuilder<ConvolutionLayer, Convolution1DLayerBuilderImpl> {
public ConvolutionLayer build() { public ConvolutionLayer build() {
ConvolutionLayer l = initBuild(); ConvolutionLayer l = new ConvolutionLayer(this);
ConvolutionUtils.validateConvolutionModePadding(l.getConvolutionMode(), l.getPadding()); ConvolutionUtils.validateConvolutionModePadding(l.getConvolutionMode(), l.getPadding());
ConvolutionUtils.validateCnnKernelStridePadding( ConvolutionUtils.validateCnnKernelStridePadding(
l.getKernelSize(), l.getStride(), l.getPadding()); l.getKernelSize(), l.getStride(), l.getPadding());

View File

@ -144,21 +144,27 @@ public class SelfAttentionLayer extends SameDiffLayer {
} }
} }
public static abstract class SelfAttentionLayerBuilder< public abstract static class SelfAttentionLayerBuilder<
C extends SelfAttentionLayer, B extends SelfAttentionLayerBuilder<C, B>> C extends SelfAttentionLayer, B extends SelfAttentionLayerBuilder<C, B>>
extends SameDiffLayerBuilder<C, B> { extends SameDiffLayerBuilder<C, B> {}
public C build() {
private static final class SelfAttentionLayerBuilderImpl
extends SelfAttentionLayerBuilder<SelfAttentionLayer, SelfAttentionLayerBuilderImpl> {
public SelfAttentionLayer build() {
SelfAttentionLayer l = new SelfAttentionLayer(this);
Preconditions.checkArgument( Preconditions.checkArgument(
this.projectInput || this.nHeads == 1, "projectInput must be true when nHeads != 1"); l.isProjectInput() || l.getNHeads() == 1, "projectInput must be true when nHeads != 1");
Preconditions.checkArgument( Preconditions.checkArgument(
this.projectInput || nIn == nOut, "nIn must be equal to nOut when projectInput is false"); l.isProjectInput() || l.getNIn() == l.getNOut(),
"nIn must be equal to nOut when projectInput is false");
Preconditions.checkArgument( Preconditions.checkArgument(
!this.projectInput || nOut != 0, "nOut must be specified when projectInput is true"); !l.isProjectInput() || l.getNOut() != 0,
"nOut must be specified when projectInput is true");
Preconditions.checkArgument( Preconditions.checkArgument(
this.nOut % nHeads == 0 || headSize > 0, l.getNOut() % l.getNHeads() == 0 || l.getHeadSize() > 0,
"nOut isn't divided by nHeads cleanly. Specify the headSize manually."); "nOut isn't divided by nHeads cleanly. Specify the headSize manually.");
return initBuild(); return l;
} }
} }
} }

View File

@ -54,14 +54,6 @@ public class SeparableConvolution2D extends ConvolutionLayer {
* have been updated. * have been updated.
*/ */
protected List<LayerConstraint> pointWiseConstraints; protected List<LayerConstraint> pointWiseConstraints;
/**
* Set channels multiplier of channels-wise step in separable convolution
*
* @param depthMultiplier integer value, for each input map we get depthMultipler outputs in
* channels-wise step.
* @return Builder
*/
@Builder.Default private int depthMultiplier = 1;
/** /**
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last). * Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
* See {@link CNN2DFormat} for more details.<br> * See {@link CNN2DFormat} for more details.<br>
@ -72,6 +64,15 @@ public class SeparableConvolution2D extends ConvolutionLayer {
@Builder.Default @Builder.Default
protected CNN2DFormat dataFormat = protected CNN2DFormat dataFormat =
CNN2DFormat.NCHW; // default value for legacy serialization reasons CNN2DFormat.NCHW; // default value for legacy serialization reasons
/**
* Set channels multiplier of channels-wise step in separable convolution
*
* @param depthMultiplier integer value, for each input map we get depthMultipler outputs in
* channels-wise step.
* @return Builder
*/
@Builder.Default private int depthMultiplier = 1;
public static SeparableConvolution2DBuilder<?, ?> builder() { public static SeparableConvolution2DBuilder<?, ?> builder() {
return innerBuilder(); return innerBuilder();
} }
@ -109,13 +110,13 @@ public class SeparableConvolution2D extends ConvolutionLayer {
public SeparableConvolution2D clone() { public SeparableConvolution2D clone() {
SeparableConvolution2D clone = (SeparableConvolution2D) super.clone(); SeparableConvolution2D clone = (SeparableConvolution2D) super.clone();
if (clone.getKernelSize() != null) { if (clone.getKernelSize() != null) {
clone.setKernelSize( clone.getKernelSize().clone()); clone.setKernelSize(clone.getKernelSize().clone());
} }
if (clone.getStride() != null) { if (clone.getStride() != null) {
clone.setStride( clone.getStride().clone()); clone.setStride(clone.getStride().clone());
} }
if (clone.getPadding() != null) { if (clone.getPadding() != null) {
clone.setPadding( clone.getPadding().clone()); clone.setPadding(clone.getPadding().clone());
} }
return clone; return clone;
} }
@ -176,27 +177,9 @@ public class SeparableConvolution2D extends ConvolutionLayer {
SeparableConvolution2DLayer.class); SeparableConvolution2DLayer.class);
} }
public static abstract class SeparableConvolution2DBuilder< public abstract static class SeparableConvolution2DBuilder<
C extends SeparableConvolution2D, B extends SeparableConvolution2DBuilder<C, B>> C extends SeparableConvolution2D, B extends SeparableConvolution2DBuilder<C, B>>
extends ConvolutionLayerBuilder<C, B> { extends ConvolutionLayerBuilder<C, B> {
public C build() {
C l = this.initBuild();
if (l.getKernelSize().length != 2) {
throw new IllegalArgumentException("Kernel size of should be rows x columns (a 2d array)");
}
if (l.getStride().length != 2) {
throw new IllegalArgumentException(
"Stride should include stride for rows and columns (a 2d array)");
}
if (l.getPadding().length != 2) {
throw new IllegalArgumentException(
"Padding should include padding for rows and columns (a 2d array)");
}
l.initializeConstraints();
return l;
}
/** /**
* Set constraints to be applied to the point-wise convolution weight parameters of this layer. * Set constraints to be applied to the point-wise convolution weight parameters of this layer.
@ -231,4 +214,27 @@ public class SeparableConvolution2D extends ConvolutionLayer {
return self(); return self();
} }
} }
private static final class SeparableConvolution2DBuilderImpl
extends SeparableConvolution2DBuilder<
SeparableConvolution2D, SeparableConvolution2DBuilderImpl> {
public SeparableConvolution2D build() {
SeparableConvolution2D l = new SeparableConvolution2D(this);
if (l.getKernelSize().length != 2) {
throw new IllegalArgumentException("Kernel size of should be rows x columns (a 2d array)");
}
if (l.getStride().length != 2) {
throw new IllegalArgumentException(
"Stride should include stride for rows and columns (a 2d array)");
}
if (l.getPadding().length != 2) {
throw new IllegalArgumentException(
"Padding should include padding for rows and columns (a 2d array)");
}
l.initializeConstraints();
return l;
}
}
} }

View File

@ -20,6 +20,9 @@
package org.deeplearning4j.nn.conf.layers; package org.deeplearning4j.nn.conf.layers;
import java.util.Arrays;
import java.util.Collection;
import java.util.Map;
import lombok.*; import lombok.*;
import lombok.experimental.SuperBuilder; import lombok.experimental.SuperBuilder;
import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.CNN2DFormat;
@ -35,10 +38,6 @@ import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Arrays;
import java.util.Collection;
import java.util.Map;
@Data @Data
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder(builderMethodName = "innerBuilder") @SuperBuilder(builderMethodName = "innerBuilder")
@ -47,9 +46,17 @@ public class ZeroPaddingLayer extends NoParamLayer {
* @param padding Padding value for top, bottom, left, and right. Must be length 4 array * @param padding Padding value for top, bottom, left, and right. Must be length 4 array
*/ */
@Builder.Default @Builder.Default
private int[] padding = new int[] {0, 0, 0, 0}; //Padding: top, bottom, left, right private int[] padding = new int[] {0, 0, 0, 0}; // Padding: top, bottom, left, right
/**
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
* See {@link CNN2DFormat} for more details.<br>
* Default: NCHW
*
* @param format Format for activations (in and out)
*/
@Builder.Default private CNN2DFormat dataFormat = CNN2DFormat.NCHW;
public static ZeroPaddingLayerBuilder<?,?> builder() { public static ZeroPaddingLayerBuilder<?, ?> builder() {
return innerBuilder(); return innerBuilder();
} }
@ -57,63 +64,35 @@ public class ZeroPaddingLayer extends NoParamLayer {
* @param padHeight Padding for both the top and bottom * @param padHeight Padding for both the top and bottom
* @param padWidth Padding for both the left and right * @param padWidth Padding for both the left and right
*/ */
public static ZeroPaddingLayerBuilder<?,?> builder(int padHeight, int padWidth) { public static ZeroPaddingLayerBuilder<?, ?> builder(int padHeight, int padWidth) {
return innerBuilder() return innerBuilder().padding(padHeight, padHeight, padWidth, padWidth);
.padding(padHeight, padHeight, padWidth, padWidth);
} }
/** /**
* @param padTop Top padding value * @param padTop Top padding value
* @param padBottom Bottom padding value * @param padBottom Bottom padding value
* @param padLeft Left padding value * @param padLeft Left padding value
* @param padRight Right padding value * @param padRight Right padding value
*/ */
public static ZeroPaddingLayerBuilder<?,?> builder(int padTop, int padBottom, int padLeft, int padRight) { public static ZeroPaddingLayerBuilder<?, ?> builder(
return innerBuilder() int padTop, int padBottom, int padLeft, int padRight) {
.padding(padTop, padBottom, padLeft, padRight); return innerBuilder().padding(padTop, padBottom, padLeft, padRight);
} }
public static ZeroPaddingLayerBuilder<?,?> builder(int[] padding) { public static ZeroPaddingLayerBuilder<?, ?> builder(int[] padding) {
return innerBuilder() return innerBuilder().padding(padding);
.padding(padding);
} }
public static abstract class ZeroPaddingLayerBuilder<C extends ZeroPaddingLayer, B extends ZeroPaddingLayerBuilder<C, B>>
extends NoParamLayerBuilder<C, B> {
public C build() {
if (padding$value == null || padding$value.length != 4) {
throw new IllegalArgumentException(
"Invalid padding values: must have exactly 4 values [top, bottom, left, right]." + " Got: "
+ (padding$value == null ? null : Arrays.toString(padding$value)));
}
C l = initBuild();
l.initializeConstraints();
return l;
}
public B padding(int ... padding) {
this.padding$value = ValidationUtils.validate4NonNegative(padding, "padding");
this.padding$set = true;
return self();
}
}
/**
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
* See {@link CNN2DFormat} for more details.<br>
* Default: NCHW
* @param format Format for activations (in and out)
*/
@Builder.Default
private CNN2DFormat dataFormat = CNN2DFormat.NCHW;
@Override @Override
public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, public org.deeplearning4j.nn.api.Layer instantiate(
Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView, NeuralNetConfiguration conf,
boolean initializeParams, DataType networkDataType) { Collection<TrainingListener> trainingListeners,
int layerIndex,
INDArray layerParamsView,
boolean initializeParams,
DataType networkDataType) {
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
runInheritance(); runInheritance();
org.deeplearning4j.nn.layers.convolution.ZeroPaddingLayer ret = org.deeplearning4j.nn.layers.convolution.ZeroPaddingLayer ret =
new org.deeplearning4j.nn.layers.convolution.ZeroPaddingLayer(lconf, networkDataType); new org.deeplearning4j.nn.layers.convolution.ZeroPaddingLayer(lconf, networkDataType);
ret.addTrainingListeners(trainingListeners); ret.addTrainingListeners(trainingListeners);
@ -130,15 +109,18 @@ runInheritance();
int outH = hwd[0] + padding[0] + padding[1]; int outH = hwd[0] + padding[0] + padding[1];
int outW = hwd[1] + padding[2] + padding[3]; int outW = hwd[1] + padding[2] + padding[3];
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional)inputType; InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
return InputType.convolutional(outH, outW, hwd[2], c.getFormat()); return InputType.convolutional(outH, outW, hwd[2], c.getFormat());
} }
@Override @Override
public InputPreProcessor getPreProcessorForInputType(InputType inputType) { public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
Preconditions.checkArgument(inputType != null, "Invalid input for ZeroPaddingLayer layer (layer name=\"" Preconditions.checkArgument(
+ getName() + "\"): InputType is null"); inputType != null,
"Invalid input for ZeroPaddingLayer layer (layer name=\""
+ getName()
+ "\"): InputType is null");
return InputTypeUtil.getPreProcessorForInputTypeCnnLayers(inputType, getName()); return InputTypeUtil.getPreProcessorForInputTypeCnnLayers(inputType, getName());
} }
@ -147,18 +129,45 @@ runInheritance();
InputType outputType = getOutputType(-1, inputType); InputType outputType = getOutputType(-1, inputType);
return new LayerMemoryReport.Builder(name, ZeroPaddingLayer.class, inputType, outputType) return new LayerMemoryReport.Builder(name, ZeroPaddingLayer.class, inputType, outputType)
.standardMemory(0, 0) //No params .standardMemory(0, 0) // No params
//Inference and training is same - just output activations, no working memory in addition to that // Inference and training is same - just output activations, no working memory in addition
// to that
.workingMemory(0, 0, MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) .workingMemory(0, 0, MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS)
.cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) //No caching .cacheMemory(
MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) // No caching
.build(); .build();
} }
@Override @Override
public void setNIn(InputType inputType, boolean override) { public void setNIn(InputType inputType, boolean override) {
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional)inputType; InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
this.dataFormat = c.getFormat(); this.dataFormat = c.getFormat();
} }
private static final class ZeroPaddingLayerBuilderImpl
extends ZeroPaddingLayerBuilder<ZeroPaddingLayer, ZeroPaddingLayerBuilderImpl> {
public ZeroPaddingLayer build() {
ZeroPaddingLayer l = new ZeroPaddingLayer(this);
if (l.getPadding() == null || l.getPadding().length != 4) {
throw new IllegalArgumentException(
"Invalid padding values: must have exactly 4 values [top, bottom, left, right]."
+ " Got: "
+ (l.getPadding() == null ? null : Arrays.toString(l.getPadding())));
}
l.initializeConstraints();
return l;
}
}
public abstract static class ZeroPaddingLayerBuilder<
C extends ZeroPaddingLayer, B extends ZeroPaddingLayerBuilder<C, B>>
extends NoParamLayerBuilder<C, B> {
public B padding(int... padding) {
this.padding$value = ValidationUtils.validate4NonNegative(padding, "padding");
this.padding$set = true;
return self();
}
}
} }

View File

@ -164,17 +164,21 @@ public class Yolo2OutputLayer extends LayerConfiguration {
public static abstract class Yolo2OutputLayerBuilder< public static abstract class Yolo2OutputLayerBuilder<
C extends Yolo2OutputLayer, B extends Yolo2OutputLayerBuilder<C, B>> C extends Yolo2OutputLayer, B extends Yolo2OutputLayerBuilder<C, B>>
extends LayerConfigurationBuilder<C, B> { extends LayerConfigurationBuilder<C, B> {
public C build() {
if (boundingBoxes == null) { }
private static final class Yolo2OutputLayerBuilderImpl extends Yolo2OutputLayerBuilder<Yolo2OutputLayer, Yolo2OutputLayerBuilderImpl> {
public Yolo2OutputLayer build() {
Yolo2OutputLayer l = new Yolo2OutputLayer(this);
if (l.getBoundingBoxes() == null) {
throw new IllegalStateException("Bounding boxes have not been set"); throw new IllegalStateException("Bounding boxes have not been set");
} }
if (boundingBoxes.rank() != 2 || boundingBoxes.size(1) != 2) { if (l.getBoundingBoxes().rank() != 2 || l.getBoundingBoxes().size(1) != 2) {
throw new IllegalStateException( throw new IllegalStateException(
"Bounding box priors must have shape [nBoxes, 2]. Has shape: " "Bounding box priors must have shape [nBoxes, 2]. Has shape: "
+ Arrays.toString(boundingBoxes.shape())); + Arrays.toString(l.getBoundingBoxes().shape()));
} }
return initBuild(); return l;
} }
} }
} }