Using @SuperBuilder for LayerConfigurations

Signed-off-by: brian <brian@brutex.de>
master
Brian Rosenberger 2023-04-27 16:03:54 +02:00
parent 7628bbdd53
commit cb236878a4
7 changed files with 192 additions and 170 deletions

View File

@ -138,7 +138,7 @@ public class ActivationLayer extends NoParamLayer {
private static final class ActivationLayerBuilderImpl extends ActivationLayerBuilder<ActivationLayer, ActivationLayerBuilderImpl> {
public ActivationLayer build() {
ActivationLayer l = this.initBuild();
ActivationLayer l = new ActivationLayer(this);
l.initializeConstraints();
return l;
}

View File

@ -121,13 +121,10 @@ public class CenterLossOutputLayer extends BaseOutputLayer {
.build();
}
public static abstract class CenterLossOutputLayerBuilder<C extends CenterLossOutputLayer, B extends CenterLossOutputLayerBuilder<C,B>> extends
BaseOutputLayerBuilder<C, B> {
public C build() {
C l = initBuild();
l.initializeConstraints();
return l;
}
}
private static final class CenterLossOutputLayerBuilderImpl extends CenterLossOutputLayerBuilder<CenterLossOutputLayer,

View File

@ -185,7 +185,7 @@ public class Convolution1DLayer extends ConvolutionLayer {
private static final class Convolution1DLayerBuilderImpl extends ConvolutionLayerBuilder<ConvolutionLayer, Convolution1DLayerBuilderImpl> {
public ConvolutionLayer build() {
ConvolutionLayer l = initBuild();
ConvolutionLayer l = new ConvolutionLayer(this);
ConvolutionUtils.validateConvolutionModePadding(l.getConvolutionMode(), l.getPadding());
ConvolutionUtils.validateCnnKernelStridePadding(
l.getKernelSize(), l.getStride(), l.getPadding());

View File

@ -144,21 +144,27 @@ public class SelfAttentionLayer extends SameDiffLayer {
}
}
public static abstract class SelfAttentionLayerBuilder<
public abstract static class SelfAttentionLayerBuilder<
C extends SelfAttentionLayer, B extends SelfAttentionLayerBuilder<C, B>>
extends SameDiffLayerBuilder<C, B> {
public C build() {
extends SameDiffLayerBuilder<C, B> {}
private static final class SelfAttentionLayerBuilderImpl
extends SelfAttentionLayerBuilder<SelfAttentionLayer, SelfAttentionLayerBuilderImpl> {
public SelfAttentionLayer build() {
SelfAttentionLayer l = new SelfAttentionLayer(this);
Preconditions.checkArgument(
this.projectInput || this.nHeads == 1, "projectInput must be true when nHeads != 1");
l.isProjectInput() || l.getNHeads() == 1, "projectInput must be true when nHeads != 1");
Preconditions.checkArgument(
this.projectInput || nIn == nOut, "nIn must be equal to nOut when projectInput is false");
l.isProjectInput() || l.getNIn() == l.getNOut(),
"nIn must be equal to nOut when projectInput is false");
Preconditions.checkArgument(
!this.projectInput || nOut != 0, "nOut must be specified when projectInput is true");
!l.isProjectInput() || l.getNOut() != 0,
"nOut must be specified when projectInput is true");
Preconditions.checkArgument(
this.nOut % nHeads == 0 || headSize > 0,
l.getNOut() % l.getNHeads() == 0 || l.getHeadSize() > 0,
"nOut isn't divided by nHeads cleanly. Specify the headSize manually.");
return initBuild();
return l;
}
}
}

View File

@ -54,14 +54,6 @@ public class SeparableConvolution2D extends ConvolutionLayer {
* have been updated.
*/
protected List<LayerConstraint> pointWiseConstraints;
/**
* Set channels multiplier of channels-wise step in separable convolution
*
* @param depthMultiplier integer value, for each input map we get depthMultipler outputs in
* channels-wise step.
* @return Builder
*/
@Builder.Default private int depthMultiplier = 1;
/**
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
* See {@link CNN2DFormat} for more details.<br>
@ -72,6 +64,15 @@ public class SeparableConvolution2D extends ConvolutionLayer {
@Builder.Default
protected CNN2DFormat dataFormat =
CNN2DFormat.NCHW; // default value for legacy serialization reasons
/**
* Set channels multiplier of channels-wise step in separable convolution
*
* @param depthMultiplier integer value, for each input map we get depthMultipler outputs in
* channels-wise step.
* @return Builder
*/
@Builder.Default private int depthMultiplier = 1;
public static SeparableConvolution2DBuilder<?, ?> builder() {
return innerBuilder();
}
@ -176,27 +177,9 @@ public class SeparableConvolution2D extends ConvolutionLayer {
SeparableConvolution2DLayer.class);
}
public static abstract class SeparableConvolution2DBuilder<
public abstract static class SeparableConvolution2DBuilder<
C extends SeparableConvolution2D, B extends SeparableConvolution2DBuilder<C, B>>
extends ConvolutionLayerBuilder<C, B> {
public C build() {
C l = this.initBuild();
if (l.getKernelSize().length != 2) {
throw new IllegalArgumentException("Kernel size of should be rows x columns (a 2d array)");
}
if (l.getStride().length != 2) {
throw new IllegalArgumentException(
"Stride should include stride for rows and columns (a 2d array)");
}
if (l.getPadding().length != 2) {
throw new IllegalArgumentException(
"Padding should include padding for rows and columns (a 2d array)");
}
l.initializeConstraints();
return l;
}
/**
* Set constraints to be applied to the point-wise convolution weight parameters of this layer.
@ -231,4 +214,27 @@ public class SeparableConvolution2D extends ConvolutionLayer {
return self();
}
}
private static final class SeparableConvolution2DBuilderImpl
extends SeparableConvolution2DBuilder<
SeparableConvolution2D, SeparableConvolution2DBuilderImpl> {
public SeparableConvolution2D build() {
SeparableConvolution2D l = new SeparableConvolution2D(this);
if (l.getKernelSize().length != 2) {
throw new IllegalArgumentException("Kernel size of should be rows x columns (a 2d array)");
}
if (l.getStride().length != 2) {
throw new IllegalArgumentException(
"Stride should include stride for rows and columns (a 2d array)");
}
if (l.getPadding().length != 2) {
throw new IllegalArgumentException(
"Padding should include padding for rows and columns (a 2d array)");
}
l.initializeConstraints();
return l;
}
}
}

View File

@ -20,6 +20,9 @@
package org.deeplearning4j.nn.conf.layers;
import java.util.Arrays;
import java.util.Collection;
import java.util.Map;
import lombok.*;
import lombok.experimental.SuperBuilder;
import org.deeplearning4j.nn.conf.CNN2DFormat;
@ -35,10 +38,6 @@ import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Arrays;
import java.util.Collection;
import java.util.Map;
@Data
@EqualsAndHashCode(callSuper = true)
@SuperBuilder(builderMethodName = "innerBuilder")
@ -48,6 +47,14 @@ public class ZeroPaddingLayer extends NoParamLayer {
*/
@Builder.Default
private int[] padding = new int[] {0, 0, 0, 0}; // Padding: top, bottom, left, right
/**
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
* See {@link CNN2DFormat} for more details.<br>
* Default: NCHW
*
* @param format Format for activations (in and out)
*/
@Builder.Default private CNN2DFormat dataFormat = CNN2DFormat.NCHW;
public static ZeroPaddingLayerBuilder<?, ?> builder() {
return innerBuilder();
@ -58,60 +65,32 @@ public class ZeroPaddingLayer extends NoParamLayer {
* @param padWidth Padding for both the left and right
*/
public static ZeroPaddingLayerBuilder<?, ?> builder(int padHeight, int padWidth) {
return innerBuilder()
.padding(padHeight, padHeight, padWidth, padWidth);
return innerBuilder().padding(padHeight, padHeight, padWidth, padWidth);
}
/**
* @param padTop Top padding value
* @param padBottom Bottom padding value
* @param padLeft Left padding value
* @param padRight Right padding value
*/
public static ZeroPaddingLayerBuilder<?,?> builder(int padTop, int padBottom, int padLeft, int padRight) {
return innerBuilder()
.padding(padTop, padBottom, padLeft, padRight);
public static ZeroPaddingLayerBuilder<?, ?> builder(
int padTop, int padBottom, int padLeft, int padRight) {
return innerBuilder().padding(padTop, padBottom, padLeft, padRight);
}
public static ZeroPaddingLayerBuilder<?, ?> builder(int[] padding) {
return innerBuilder()
.padding(padding);
return innerBuilder().padding(padding);
}
public static abstract class ZeroPaddingLayerBuilder<C extends ZeroPaddingLayer, B extends ZeroPaddingLayerBuilder<C, B>>
extends NoParamLayerBuilder<C, B> {
public C build() {
if (padding$value == null || padding$value.length != 4) {
throw new IllegalArgumentException(
"Invalid padding values: must have exactly 4 values [top, bottom, left, right]." + " Got: "
+ (padding$value == null ? null : Arrays.toString(padding$value)));
}
C l = initBuild();
l.initializeConstraints();
return l;
}
public B padding(int ... padding) {
this.padding$value = ValidationUtils.validate4NonNegative(padding, "padding");
this.padding$set = true;
return self();
}
}
/**
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
* See {@link CNN2DFormat} for more details.<br>
* Default: NCHW
* @param format Format for activations (in and out)
*/
@Builder.Default
private CNN2DFormat dataFormat = CNN2DFormat.NCHW;
@Override
public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf,
Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView,
boolean initializeParams, DataType networkDataType) {
public org.deeplearning4j.nn.api.Layer instantiate(
NeuralNetConfiguration conf,
Collection<TrainingListener> trainingListeners,
int layerIndex,
INDArray layerParamsView,
boolean initializeParams,
DataType networkDataType) {
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
runInheritance();
org.deeplearning4j.nn.layers.convolution.ZeroPaddingLayer ret =
@ -137,8 +116,11 @@ runInheritance();
@Override
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
Preconditions.checkArgument(inputType != null, "Invalid input for ZeroPaddingLayer layer (layer name=\""
+ getName() + "\"): InputType is null");
Preconditions.checkArgument(
inputType != null,
"Invalid input for ZeroPaddingLayer layer (layer name=\""
+ getName()
+ "\"): InputType is null");
return InputTypeUtil.getPreProcessorForInputTypeCnnLayers(inputType, getName());
}
@ -148,9 +130,11 @@ runInheritance();
return new LayerMemoryReport.Builder(name, ZeroPaddingLayer.class, inputType, outputType)
.standardMemory(0, 0) // No params
//Inference and training is same - just output activations, no working memory in addition to that
// Inference and training is same - just output activations, no working memory in addition
// to that
.workingMemory(0, 0, MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS)
.cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) //No caching
.cacheMemory(
MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) // No caching
.build();
}
@ -160,5 +144,30 @@ runInheritance();
this.dataFormat = c.getFormat();
}
private static final class ZeroPaddingLayerBuilderImpl
extends ZeroPaddingLayerBuilder<ZeroPaddingLayer, ZeroPaddingLayerBuilderImpl> {
public ZeroPaddingLayer build() {
ZeroPaddingLayer l = new ZeroPaddingLayer(this);
if (l.getPadding() == null || l.getPadding().length != 4) {
throw new IllegalArgumentException(
"Invalid padding values: must have exactly 4 values [top, bottom, left, right]."
+ " Got: "
+ (l.getPadding() == null ? null : Arrays.toString(l.getPadding())));
}
l.initializeConstraints();
return l;
}
}
public abstract static class ZeroPaddingLayerBuilder<
C extends ZeroPaddingLayer, B extends ZeroPaddingLayerBuilder<C, B>>
extends NoParamLayerBuilder<C, B> {
public B padding(int... padding) {
this.padding$value = ValidationUtils.validate4NonNegative(padding, "padding");
this.padding$set = true;
return self();
}
}
}

View File

@ -164,17 +164,21 @@ public class Yolo2OutputLayer extends LayerConfiguration {
public static abstract class Yolo2OutputLayerBuilder<
C extends Yolo2OutputLayer, B extends Yolo2OutputLayerBuilder<C, B>>
extends LayerConfigurationBuilder<C, B> {
public C build() {
if (boundingBoxes == null) {
}
private static final class Yolo2OutputLayerBuilderImpl extends Yolo2OutputLayerBuilder<Yolo2OutputLayer, Yolo2OutputLayerBuilderImpl> {
public Yolo2OutputLayer build() {
Yolo2OutputLayer l = new Yolo2OutputLayer(this);
if (l.getBoundingBoxes() == null) {
throw new IllegalStateException("Bounding boxes have not been set");
}
if (boundingBoxes.rank() != 2 || boundingBoxes.size(1) != 2) {
if (l.getBoundingBoxes().rank() != 2 || l.getBoundingBoxes().size(1) != 2) {
throw new IllegalStateException(
"Bounding box priors must have shape [nBoxes, 2]. Has shape: "
+ Arrays.toString(boundingBoxes.shape()));
+ Arrays.toString(l.getBoundingBoxes().shape()));
}
return initBuild();
return l;
}
}
}