Using @SuperBuilder for LayerConfigurations
Signed-off-by: brian <brian@brutex.de>master
parent
7628bbdd53
commit
cb236878a4
|
@ -138,7 +138,7 @@ public class ActivationLayer extends NoParamLayer {
|
|||
|
||||
private static final class ActivationLayerBuilderImpl extends ActivationLayerBuilder<ActivationLayer, ActivationLayerBuilderImpl> {
|
||||
public ActivationLayer build() {
|
||||
ActivationLayer l = this.initBuild();
|
||||
ActivationLayer l = new ActivationLayer(this);
|
||||
l.initializeConstraints();
|
||||
return l;
|
||||
}
|
||||
|
|
|
@ -121,13 +121,10 @@ public class CenterLossOutputLayer extends BaseOutputLayer {
|
|||
.build();
|
||||
}
|
||||
|
||||
|
||||
public static abstract class CenterLossOutputLayerBuilder<C extends CenterLossOutputLayer, B extends CenterLossOutputLayerBuilder<C,B>> extends
|
||||
BaseOutputLayerBuilder<C, B> {
|
||||
public C build() {
|
||||
C l = initBuild();
|
||||
l.initializeConstraints();
|
||||
return l;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private static final class CenterLossOutputLayerBuilderImpl extends CenterLossOutputLayerBuilder<CenterLossOutputLayer,
|
||||
|
|
|
@ -185,7 +185,7 @@ public class Convolution1DLayer extends ConvolutionLayer {
|
|||
|
||||
private static final class Convolution1DLayerBuilderImpl extends ConvolutionLayerBuilder<ConvolutionLayer, Convolution1DLayerBuilderImpl> {
|
||||
public ConvolutionLayer build() {
|
||||
ConvolutionLayer l = initBuild();
|
||||
ConvolutionLayer l = new ConvolutionLayer(this);
|
||||
ConvolutionUtils.validateConvolutionModePadding(l.getConvolutionMode(), l.getPadding());
|
||||
ConvolutionUtils.validateCnnKernelStridePadding(
|
||||
l.getKernelSize(), l.getStride(), l.getPadding());
|
||||
|
|
|
@ -144,21 +144,27 @@ public class SelfAttentionLayer extends SameDiffLayer {
|
|||
}
|
||||
}
|
||||
|
||||
public static abstract class SelfAttentionLayerBuilder<
|
||||
public abstract static class SelfAttentionLayerBuilder<
|
||||
C extends SelfAttentionLayer, B extends SelfAttentionLayerBuilder<C, B>>
|
||||
extends SameDiffLayerBuilder<C, B> {
|
||||
public C build() {
|
||||
extends SameDiffLayerBuilder<C, B> {}
|
||||
|
||||
private static final class SelfAttentionLayerBuilderImpl
|
||||
extends SelfAttentionLayerBuilder<SelfAttentionLayer, SelfAttentionLayerBuilderImpl> {
|
||||
public SelfAttentionLayer build() {
|
||||
SelfAttentionLayer l = new SelfAttentionLayer(this);
|
||||
Preconditions.checkArgument(
|
||||
this.projectInput || this.nHeads == 1, "projectInput must be true when nHeads != 1");
|
||||
l.isProjectInput() || l.getNHeads() == 1, "projectInput must be true when nHeads != 1");
|
||||
Preconditions.checkArgument(
|
||||
this.projectInput || nIn == nOut, "nIn must be equal to nOut when projectInput is false");
|
||||
l.isProjectInput() || l.getNIn() == l.getNOut(),
|
||||
"nIn must be equal to nOut when projectInput is false");
|
||||
Preconditions.checkArgument(
|
||||
!this.projectInput || nOut != 0, "nOut must be specified when projectInput is true");
|
||||
!l.isProjectInput() || l.getNOut() != 0,
|
||||
"nOut must be specified when projectInput is true");
|
||||
Preconditions.checkArgument(
|
||||
this.nOut % nHeads == 0 || headSize > 0,
|
||||
l.getNOut() % l.getNHeads() == 0 || l.getHeadSize() > 0,
|
||||
"nOut isn't divided by nHeads cleanly. Specify the headSize manually.");
|
||||
|
||||
return initBuild();
|
||||
return l;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -54,14 +54,6 @@ public class SeparableConvolution2D extends ConvolutionLayer {
|
|||
* have been updated.
|
||||
*/
|
||||
protected List<LayerConstraint> pointWiseConstraints;
|
||||
/**
|
||||
* Set channels multiplier of channels-wise step in separable convolution
|
||||
*
|
||||
* @param depthMultiplier integer value, for each input map we get depthMultipler outputs in
|
||||
* channels-wise step.
|
||||
* @return Builder
|
||||
*/
|
||||
@Builder.Default private int depthMultiplier = 1;
|
||||
/**
|
||||
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
|
||||
* See {@link CNN2DFormat} for more details.<br>
|
||||
|
@ -72,6 +64,15 @@ public class SeparableConvolution2D extends ConvolutionLayer {
|
|||
@Builder.Default
|
||||
protected CNN2DFormat dataFormat =
|
||||
CNN2DFormat.NCHW; // default value for legacy serialization reasons
|
||||
/**
|
||||
* Set channels multiplier of channels-wise step in separable convolution
|
||||
*
|
||||
* @param depthMultiplier integer value, for each input map we get depthMultipler outputs in
|
||||
* channels-wise step.
|
||||
* @return Builder
|
||||
*/
|
||||
@Builder.Default private int depthMultiplier = 1;
|
||||
|
||||
public static SeparableConvolution2DBuilder<?, ?> builder() {
|
||||
return innerBuilder();
|
||||
}
|
||||
|
@ -176,27 +177,9 @@ public class SeparableConvolution2D extends ConvolutionLayer {
|
|||
SeparableConvolution2DLayer.class);
|
||||
}
|
||||
|
||||
public static abstract class SeparableConvolution2DBuilder<
|
||||
public abstract static class SeparableConvolution2DBuilder<
|
||||
C extends SeparableConvolution2D, B extends SeparableConvolution2DBuilder<C, B>>
|
||||
extends ConvolutionLayerBuilder<C, B> {
|
||||
public C build() {
|
||||
C l = this.initBuild();
|
||||
if (l.getKernelSize().length != 2) {
|
||||
throw new IllegalArgumentException("Kernel size of should be rows x columns (a 2d array)");
|
||||
}
|
||||
|
||||
if (l.getStride().length != 2) {
|
||||
throw new IllegalArgumentException(
|
||||
"Stride should include stride for rows and columns (a 2d array)");
|
||||
}
|
||||
|
||||
if (l.getPadding().length != 2) {
|
||||
throw new IllegalArgumentException(
|
||||
"Padding should include padding for rows and columns (a 2d array)");
|
||||
}
|
||||
l.initializeConstraints();
|
||||
return l;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set constraints to be applied to the point-wise convolution weight parameters of this layer.
|
||||
|
@ -231,4 +214,27 @@ public class SeparableConvolution2D extends ConvolutionLayer {
|
|||
return self();
|
||||
}
|
||||
}
|
||||
|
||||
private static final class SeparableConvolution2DBuilderImpl
|
||||
extends SeparableConvolution2DBuilder<
|
||||
SeparableConvolution2D, SeparableConvolution2DBuilderImpl> {
|
||||
public SeparableConvolution2D build() {
|
||||
SeparableConvolution2D l = new SeparableConvolution2D(this);
|
||||
if (l.getKernelSize().length != 2) {
|
||||
throw new IllegalArgumentException("Kernel size of should be rows x columns (a 2d array)");
|
||||
}
|
||||
|
||||
if (l.getStride().length != 2) {
|
||||
throw new IllegalArgumentException(
|
||||
"Stride should include stride for rows and columns (a 2d array)");
|
||||
}
|
||||
|
||||
if (l.getPadding().length != 2) {
|
||||
throw new IllegalArgumentException(
|
||||
"Padding should include padding for rows and columns (a 2d array)");
|
||||
}
|
||||
l.initializeConstraints();
|
||||
return l;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,6 +20,9 @@
|
|||
|
||||
package org.deeplearning4j.nn.conf.layers;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.Map;
|
||||
import lombok.*;
|
||||
import lombok.experimental.SuperBuilder;
|
||||
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||
|
@ -35,10 +38,6 @@ import org.nd4j.common.base.Preconditions;
|
|||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.Map;
|
||||
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@SuperBuilder(builderMethodName = "innerBuilder")
|
||||
|
@ -48,6 +47,14 @@ public class ZeroPaddingLayer extends NoParamLayer {
|
|||
*/
|
||||
@Builder.Default
|
||||
private int[] padding = new int[] {0, 0, 0, 0}; // Padding: top, bottom, left, right
|
||||
/**
|
||||
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
|
||||
* See {@link CNN2DFormat} for more details.<br>
|
||||
* Default: NCHW
|
||||
*
|
||||
* @param format Format for activations (in and out)
|
||||
*/
|
||||
@Builder.Default private CNN2DFormat dataFormat = CNN2DFormat.NCHW;
|
||||
|
||||
public static ZeroPaddingLayerBuilder<?, ?> builder() {
|
||||
return innerBuilder();
|
||||
|
@ -58,60 +65,32 @@ public class ZeroPaddingLayer extends NoParamLayer {
|
|||
* @param padWidth Padding for both the left and right
|
||||
*/
|
||||
public static ZeroPaddingLayerBuilder<?, ?> builder(int padHeight, int padWidth) {
|
||||
return innerBuilder()
|
||||
.padding(padHeight, padHeight, padWidth, padWidth);
|
||||
return innerBuilder().padding(padHeight, padHeight, padWidth, padWidth);
|
||||
}
|
||||
|
||||
/**
|
||||
* @param padTop Top padding value
|
||||
* @param padBottom Bottom padding value
|
||||
* @param padLeft Left padding value
|
||||
* @param padRight Right padding value
|
||||
*/
|
||||
public static ZeroPaddingLayerBuilder<?,?> builder(int padTop, int padBottom, int padLeft, int padRight) {
|
||||
return innerBuilder()
|
||||
.padding(padTop, padBottom, padLeft, padRight);
|
||||
public static ZeroPaddingLayerBuilder<?, ?> builder(
|
||||
int padTop, int padBottom, int padLeft, int padRight) {
|
||||
return innerBuilder().padding(padTop, padBottom, padLeft, padRight);
|
||||
}
|
||||
|
||||
public static ZeroPaddingLayerBuilder<?, ?> builder(int[] padding) {
|
||||
return innerBuilder()
|
||||
.padding(padding);
|
||||
return innerBuilder().padding(padding);
|
||||
}
|
||||
|
||||
public static abstract class ZeroPaddingLayerBuilder<C extends ZeroPaddingLayer, B extends ZeroPaddingLayerBuilder<C, B>>
|
||||
extends NoParamLayerBuilder<C, B> {
|
||||
public C build() {
|
||||
if (padding$value == null || padding$value.length != 4) {
|
||||
throw new IllegalArgumentException(
|
||||
"Invalid padding values: must have exactly 4 values [top, bottom, left, right]." + " Got: "
|
||||
+ (padding$value == null ? null : Arrays.toString(padding$value)));
|
||||
}
|
||||
|
||||
C l = initBuild();
|
||||
l.initializeConstraints();
|
||||
return l;
|
||||
}
|
||||
|
||||
public B padding(int ... padding) {
|
||||
this.padding$value = ValidationUtils.validate4NonNegative(padding, "padding");
|
||||
this.padding$set = true;
|
||||
return self();
|
||||
}
|
||||
}
|
||||
/**
|
||||
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
|
||||
* See {@link CNN2DFormat} for more details.<br>
|
||||
* Default: NCHW
|
||||
* @param format Format for activations (in and out)
|
||||
*/
|
||||
@Builder.Default
|
||||
private CNN2DFormat dataFormat = CNN2DFormat.NCHW;
|
||||
|
||||
|
||||
|
||||
@Override
|
||||
public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf,
|
||||
Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView,
|
||||
boolean initializeParams, DataType networkDataType) {
|
||||
public org.deeplearning4j.nn.api.Layer instantiate(
|
||||
NeuralNetConfiguration conf,
|
||||
Collection<TrainingListener> trainingListeners,
|
||||
int layerIndex,
|
||||
INDArray layerParamsView,
|
||||
boolean initializeParams,
|
||||
DataType networkDataType) {
|
||||
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
|
||||
runInheritance();
|
||||
org.deeplearning4j.nn.layers.convolution.ZeroPaddingLayer ret =
|
||||
|
@ -137,8 +116,11 @@ runInheritance();
|
|||
|
||||
@Override
|
||||
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
|
||||
Preconditions.checkArgument(inputType != null, "Invalid input for ZeroPaddingLayer layer (layer name=\""
|
||||
+ getName() + "\"): InputType is null");
|
||||
Preconditions.checkArgument(
|
||||
inputType != null,
|
||||
"Invalid input for ZeroPaddingLayer layer (layer name=\""
|
||||
+ getName()
|
||||
+ "\"): InputType is null");
|
||||
return InputTypeUtil.getPreProcessorForInputTypeCnnLayers(inputType, getName());
|
||||
}
|
||||
|
||||
|
@ -148,9 +130,11 @@ runInheritance();
|
|||
|
||||
return new LayerMemoryReport.Builder(name, ZeroPaddingLayer.class, inputType, outputType)
|
||||
.standardMemory(0, 0) // No params
|
||||
//Inference and training is same - just output activations, no working memory in addition to that
|
||||
// Inference and training is same - just output activations, no working memory in addition
|
||||
// to that
|
||||
.workingMemory(0, 0, MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS)
|
||||
.cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) //No caching
|
||||
.cacheMemory(
|
||||
MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) // No caching
|
||||
.build();
|
||||
}
|
||||
|
||||
|
@ -160,5 +144,30 @@ runInheritance();
|
|||
this.dataFormat = c.getFormat();
|
||||
}
|
||||
|
||||
|
||||
private static final class ZeroPaddingLayerBuilderImpl
|
||||
extends ZeroPaddingLayerBuilder<ZeroPaddingLayer, ZeroPaddingLayerBuilderImpl> {
|
||||
public ZeroPaddingLayer build() {
|
||||
ZeroPaddingLayer l = new ZeroPaddingLayer(this);
|
||||
if (l.getPadding() == null || l.getPadding().length != 4) {
|
||||
throw new IllegalArgumentException(
|
||||
"Invalid padding values: must have exactly 4 values [top, bottom, left, right]."
|
||||
+ " Got: "
|
||||
+ (l.getPadding() == null ? null : Arrays.toString(l.getPadding())));
|
||||
}
|
||||
|
||||
l.initializeConstraints();
|
||||
return l;
|
||||
}
|
||||
}
|
||||
|
||||
public abstract static class ZeroPaddingLayerBuilder<
|
||||
C extends ZeroPaddingLayer, B extends ZeroPaddingLayerBuilder<C, B>>
|
||||
extends NoParamLayerBuilder<C, B> {
|
||||
|
||||
public B padding(int... padding) {
|
||||
this.padding$value = ValidationUtils.validate4NonNegative(padding, "padding");
|
||||
this.padding$set = true;
|
||||
return self();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -164,17 +164,21 @@ public class Yolo2OutputLayer extends LayerConfiguration {
|
|||
public static abstract class Yolo2OutputLayerBuilder<
|
||||
C extends Yolo2OutputLayer, B extends Yolo2OutputLayerBuilder<C, B>>
|
||||
extends LayerConfigurationBuilder<C, B> {
|
||||
public C build() {
|
||||
if (boundingBoxes == null) {
|
||||
|
||||
}
|
||||
private static final class Yolo2OutputLayerBuilderImpl extends Yolo2OutputLayerBuilder<Yolo2OutputLayer, Yolo2OutputLayerBuilderImpl> {
|
||||
public Yolo2OutputLayer build() {
|
||||
Yolo2OutputLayer l = new Yolo2OutputLayer(this);
|
||||
if (l.getBoundingBoxes() == null) {
|
||||
throw new IllegalStateException("Bounding boxes have not been set");
|
||||
}
|
||||
|
||||
if (boundingBoxes.rank() != 2 || boundingBoxes.size(1) != 2) {
|
||||
if (l.getBoundingBoxes().rank() != 2 || l.getBoundingBoxes().size(1) != 2) {
|
||||
throw new IllegalStateException(
|
||||
"Bounding box priors must have shape [nBoxes, 2]. Has shape: "
|
||||
+ Arrays.toString(boundingBoxes.shape()));
|
||||
+ Arrays.toString(l.getBoundingBoxes().shape()));
|
||||
}
|
||||
return initBuild();
|
||||
return l;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue