Using @SuperBuilder for LayerConfigurations
Signed-off-by: brian <brian@brutex.de>master
parent
7628bbdd53
commit
cb236878a4
|
@ -138,7 +138,7 @@ public class ActivationLayer extends NoParamLayer {
|
||||||
|
|
||||||
private static final class ActivationLayerBuilderImpl extends ActivationLayerBuilder<ActivationLayer, ActivationLayerBuilderImpl> {
|
private static final class ActivationLayerBuilderImpl extends ActivationLayerBuilder<ActivationLayer, ActivationLayerBuilderImpl> {
|
||||||
public ActivationLayer build() {
|
public ActivationLayer build() {
|
||||||
ActivationLayer l = this.initBuild();
|
ActivationLayer l = new ActivationLayer(this);
|
||||||
l.initializeConstraints();
|
l.initializeConstraints();
|
||||||
return l;
|
return l;
|
||||||
}
|
}
|
||||||
|
|
|
@ -121,13 +121,10 @@ public class CenterLossOutputLayer extends BaseOutputLayer {
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public static abstract class CenterLossOutputLayerBuilder<C extends CenterLossOutputLayer, B extends CenterLossOutputLayerBuilder<C,B>> extends
|
public static abstract class CenterLossOutputLayerBuilder<C extends CenterLossOutputLayer, B extends CenterLossOutputLayerBuilder<C,B>> extends
|
||||||
BaseOutputLayerBuilder<C, B> {
|
BaseOutputLayerBuilder<C, B> {
|
||||||
public C build() {
|
|
||||||
C l = initBuild();
|
|
||||||
l.initializeConstraints();
|
|
||||||
return l;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private static final class CenterLossOutputLayerBuilderImpl extends CenterLossOutputLayerBuilder<CenterLossOutputLayer,
|
private static final class CenterLossOutputLayerBuilderImpl extends CenterLossOutputLayerBuilder<CenterLossOutputLayer,
|
||||||
|
|
|
@ -185,7 +185,7 @@ public class Convolution1DLayer extends ConvolutionLayer {
|
||||||
|
|
||||||
private static final class Convolution1DLayerBuilderImpl extends ConvolutionLayerBuilder<ConvolutionLayer, Convolution1DLayerBuilderImpl> {
|
private static final class Convolution1DLayerBuilderImpl extends ConvolutionLayerBuilder<ConvolutionLayer, Convolution1DLayerBuilderImpl> {
|
||||||
public ConvolutionLayer build() {
|
public ConvolutionLayer build() {
|
||||||
ConvolutionLayer l = initBuild();
|
ConvolutionLayer l = new ConvolutionLayer(this);
|
||||||
ConvolutionUtils.validateConvolutionModePadding(l.getConvolutionMode(), l.getPadding());
|
ConvolutionUtils.validateConvolutionModePadding(l.getConvolutionMode(), l.getPadding());
|
||||||
ConvolutionUtils.validateCnnKernelStridePadding(
|
ConvolutionUtils.validateCnnKernelStridePadding(
|
||||||
l.getKernelSize(), l.getStride(), l.getPadding());
|
l.getKernelSize(), l.getStride(), l.getPadding());
|
||||||
|
|
|
@ -144,21 +144,27 @@ public class SelfAttentionLayer extends SameDiffLayer {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static abstract class SelfAttentionLayerBuilder<
|
public abstract static class SelfAttentionLayerBuilder<
|
||||||
C extends SelfAttentionLayer, B extends SelfAttentionLayerBuilder<C, B>>
|
C extends SelfAttentionLayer, B extends SelfAttentionLayerBuilder<C, B>>
|
||||||
extends SameDiffLayerBuilder<C, B> {
|
extends SameDiffLayerBuilder<C, B> {}
|
||||||
public C build() {
|
|
||||||
|
private static final class SelfAttentionLayerBuilderImpl
|
||||||
|
extends SelfAttentionLayerBuilder<SelfAttentionLayer, SelfAttentionLayerBuilderImpl> {
|
||||||
|
public SelfAttentionLayer build() {
|
||||||
|
SelfAttentionLayer l = new SelfAttentionLayer(this);
|
||||||
Preconditions.checkArgument(
|
Preconditions.checkArgument(
|
||||||
this.projectInput || this.nHeads == 1, "projectInput must be true when nHeads != 1");
|
l.isProjectInput() || l.getNHeads() == 1, "projectInput must be true when nHeads != 1");
|
||||||
Preconditions.checkArgument(
|
Preconditions.checkArgument(
|
||||||
this.projectInput || nIn == nOut, "nIn must be equal to nOut when projectInput is false");
|
l.isProjectInput() || l.getNIn() == l.getNOut(),
|
||||||
|
"nIn must be equal to nOut when projectInput is false");
|
||||||
Preconditions.checkArgument(
|
Preconditions.checkArgument(
|
||||||
!this.projectInput || nOut != 0, "nOut must be specified when projectInput is true");
|
!l.isProjectInput() || l.getNOut() != 0,
|
||||||
|
"nOut must be specified when projectInput is true");
|
||||||
Preconditions.checkArgument(
|
Preconditions.checkArgument(
|
||||||
this.nOut % nHeads == 0 || headSize > 0,
|
l.getNOut() % l.getNHeads() == 0 || l.getHeadSize() > 0,
|
||||||
"nOut isn't divided by nHeads cleanly. Specify the headSize manually.");
|
"nOut isn't divided by nHeads cleanly. Specify the headSize manually.");
|
||||||
|
|
||||||
return initBuild();
|
return l;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -54,14 +54,6 @@ public class SeparableConvolution2D extends ConvolutionLayer {
|
||||||
* have been updated.
|
* have been updated.
|
||||||
*/
|
*/
|
||||||
protected List<LayerConstraint> pointWiseConstraints;
|
protected List<LayerConstraint> pointWiseConstraints;
|
||||||
/**
|
|
||||||
* Set channels multiplier of channels-wise step in separable convolution
|
|
||||||
*
|
|
||||||
* @param depthMultiplier integer value, for each input map we get depthMultipler outputs in
|
|
||||||
* channels-wise step.
|
|
||||||
* @return Builder
|
|
||||||
*/
|
|
||||||
@Builder.Default private int depthMultiplier = 1;
|
|
||||||
/**
|
/**
|
||||||
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
|
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
|
||||||
* See {@link CNN2DFormat} for more details.<br>
|
* See {@link CNN2DFormat} for more details.<br>
|
||||||
|
@ -72,6 +64,15 @@ public class SeparableConvolution2D extends ConvolutionLayer {
|
||||||
@Builder.Default
|
@Builder.Default
|
||||||
protected CNN2DFormat dataFormat =
|
protected CNN2DFormat dataFormat =
|
||||||
CNN2DFormat.NCHW; // default value for legacy serialization reasons
|
CNN2DFormat.NCHW; // default value for legacy serialization reasons
|
||||||
|
/**
|
||||||
|
* Set channels multiplier of channels-wise step in separable convolution
|
||||||
|
*
|
||||||
|
* @param depthMultiplier integer value, for each input map we get depthMultipler outputs in
|
||||||
|
* channels-wise step.
|
||||||
|
* @return Builder
|
||||||
|
*/
|
||||||
|
@Builder.Default private int depthMultiplier = 1;
|
||||||
|
|
||||||
public static SeparableConvolution2DBuilder<?, ?> builder() {
|
public static SeparableConvolution2DBuilder<?, ?> builder() {
|
||||||
return innerBuilder();
|
return innerBuilder();
|
||||||
}
|
}
|
||||||
|
@ -109,13 +110,13 @@ public class SeparableConvolution2D extends ConvolutionLayer {
|
||||||
public SeparableConvolution2D clone() {
|
public SeparableConvolution2D clone() {
|
||||||
SeparableConvolution2D clone = (SeparableConvolution2D) super.clone();
|
SeparableConvolution2D clone = (SeparableConvolution2D) super.clone();
|
||||||
if (clone.getKernelSize() != null) {
|
if (clone.getKernelSize() != null) {
|
||||||
clone.setKernelSize( clone.getKernelSize().clone());
|
clone.setKernelSize(clone.getKernelSize().clone());
|
||||||
}
|
}
|
||||||
if (clone.getStride() != null) {
|
if (clone.getStride() != null) {
|
||||||
clone.setStride( clone.getStride().clone());
|
clone.setStride(clone.getStride().clone());
|
||||||
}
|
}
|
||||||
if (clone.getPadding() != null) {
|
if (clone.getPadding() != null) {
|
||||||
clone.setPadding( clone.getPadding().clone());
|
clone.setPadding(clone.getPadding().clone());
|
||||||
}
|
}
|
||||||
return clone;
|
return clone;
|
||||||
}
|
}
|
||||||
|
@ -176,27 +177,9 @@ public class SeparableConvolution2D extends ConvolutionLayer {
|
||||||
SeparableConvolution2DLayer.class);
|
SeparableConvolution2DLayer.class);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static abstract class SeparableConvolution2DBuilder<
|
public abstract static class SeparableConvolution2DBuilder<
|
||||||
C extends SeparableConvolution2D, B extends SeparableConvolution2DBuilder<C, B>>
|
C extends SeparableConvolution2D, B extends SeparableConvolution2DBuilder<C, B>>
|
||||||
extends ConvolutionLayerBuilder<C, B> {
|
extends ConvolutionLayerBuilder<C, B> {
|
||||||
public C build() {
|
|
||||||
C l = this.initBuild();
|
|
||||||
if (l.getKernelSize().length != 2) {
|
|
||||||
throw new IllegalArgumentException("Kernel size of should be rows x columns (a 2d array)");
|
|
||||||
}
|
|
||||||
|
|
||||||
if (l.getStride().length != 2) {
|
|
||||||
throw new IllegalArgumentException(
|
|
||||||
"Stride should include stride for rows and columns (a 2d array)");
|
|
||||||
}
|
|
||||||
|
|
||||||
if (l.getPadding().length != 2) {
|
|
||||||
throw new IllegalArgumentException(
|
|
||||||
"Padding should include padding for rows and columns (a 2d array)");
|
|
||||||
}
|
|
||||||
l.initializeConstraints();
|
|
||||||
return l;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Set constraints to be applied to the point-wise convolution weight parameters of this layer.
|
* Set constraints to be applied to the point-wise convolution weight parameters of this layer.
|
||||||
|
@ -231,4 +214,27 @@ public class SeparableConvolution2D extends ConvolutionLayer {
|
||||||
return self();
|
return self();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static final class SeparableConvolution2DBuilderImpl
|
||||||
|
extends SeparableConvolution2DBuilder<
|
||||||
|
SeparableConvolution2D, SeparableConvolution2DBuilderImpl> {
|
||||||
|
public SeparableConvolution2D build() {
|
||||||
|
SeparableConvolution2D l = new SeparableConvolution2D(this);
|
||||||
|
if (l.getKernelSize().length != 2) {
|
||||||
|
throw new IllegalArgumentException("Kernel size of should be rows x columns (a 2d array)");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (l.getStride().length != 2) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"Stride should include stride for rows and columns (a 2d array)");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (l.getPadding().length != 2) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"Padding should include padding for rows and columns (a 2d array)");
|
||||||
|
}
|
||||||
|
l.initializeConstraints();
|
||||||
|
return l;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,6 +20,9 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.layers;
|
package org.deeplearning4j.nn.conf.layers;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.Collection;
|
||||||
|
import java.util.Map;
|
||||||
import lombok.*;
|
import lombok.*;
|
||||||
import lombok.experimental.SuperBuilder;
|
import lombok.experimental.SuperBuilder;
|
||||||
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||||
|
@ -35,10 +38,6 @@ import org.nd4j.common.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.Collection;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
@SuperBuilder(builderMethodName = "innerBuilder")
|
@SuperBuilder(builderMethodName = "innerBuilder")
|
||||||
|
@ -47,9 +46,17 @@ public class ZeroPaddingLayer extends NoParamLayer {
|
||||||
* @param padding Padding value for top, bottom, left, and right. Must be length 4 array
|
* @param padding Padding value for top, bottom, left, and right. Must be length 4 array
|
||||||
*/
|
*/
|
||||||
@Builder.Default
|
@Builder.Default
|
||||||
private int[] padding = new int[] {0, 0, 0, 0}; //Padding: top, bottom, left, right
|
private int[] padding = new int[] {0, 0, 0, 0}; // Padding: top, bottom, left, right
|
||||||
|
/**
|
||||||
|
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
|
||||||
|
* See {@link CNN2DFormat} for more details.<br>
|
||||||
|
* Default: NCHW
|
||||||
|
*
|
||||||
|
* @param format Format for activations (in and out)
|
||||||
|
*/
|
||||||
|
@Builder.Default private CNN2DFormat dataFormat = CNN2DFormat.NCHW;
|
||||||
|
|
||||||
public static ZeroPaddingLayerBuilder<?,?> builder() {
|
public static ZeroPaddingLayerBuilder<?, ?> builder() {
|
||||||
return innerBuilder();
|
return innerBuilder();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -57,63 +64,35 @@ public class ZeroPaddingLayer extends NoParamLayer {
|
||||||
* @param padHeight Padding for both the top and bottom
|
* @param padHeight Padding for both the top and bottom
|
||||||
* @param padWidth Padding for both the left and right
|
* @param padWidth Padding for both the left and right
|
||||||
*/
|
*/
|
||||||
public static ZeroPaddingLayerBuilder<?,?> builder(int padHeight, int padWidth) {
|
public static ZeroPaddingLayerBuilder<?, ?> builder(int padHeight, int padWidth) {
|
||||||
return innerBuilder()
|
return innerBuilder().padding(padHeight, padHeight, padWidth, padWidth);
|
||||||
.padding(padHeight, padHeight, padWidth, padWidth);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param padTop Top padding value
|
* @param padTop Top padding value
|
||||||
* @param padBottom Bottom padding value
|
* @param padBottom Bottom padding value
|
||||||
* @param padLeft Left padding value
|
* @param padLeft Left padding value
|
||||||
* @param padRight Right padding value
|
* @param padRight Right padding value
|
||||||
*/
|
*/
|
||||||
public static ZeroPaddingLayerBuilder<?,?> builder(int padTop, int padBottom, int padLeft, int padRight) {
|
public static ZeroPaddingLayerBuilder<?, ?> builder(
|
||||||
return innerBuilder()
|
int padTop, int padBottom, int padLeft, int padRight) {
|
||||||
.padding(padTop, padBottom, padLeft, padRight);
|
return innerBuilder().padding(padTop, padBottom, padLeft, padRight);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static ZeroPaddingLayerBuilder<?,?> builder(int[] padding) {
|
public static ZeroPaddingLayerBuilder<?, ?> builder(int[] padding) {
|
||||||
return innerBuilder()
|
return innerBuilder().padding(padding);
|
||||||
.padding(padding);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public static abstract class ZeroPaddingLayerBuilder<C extends ZeroPaddingLayer, B extends ZeroPaddingLayerBuilder<C, B>>
|
|
||||||
extends NoParamLayerBuilder<C, B> {
|
|
||||||
public C build() {
|
|
||||||
if (padding$value == null || padding$value.length != 4) {
|
|
||||||
throw new IllegalArgumentException(
|
|
||||||
"Invalid padding values: must have exactly 4 values [top, bottom, left, right]." + " Got: "
|
|
||||||
+ (padding$value == null ? null : Arrays.toString(padding$value)));
|
|
||||||
}
|
|
||||||
|
|
||||||
C l = initBuild();
|
|
||||||
l.initializeConstraints();
|
|
||||||
return l;
|
|
||||||
}
|
|
||||||
|
|
||||||
public B padding(int ... padding) {
|
|
||||||
this.padding$value = ValidationUtils.validate4NonNegative(padding, "padding");
|
|
||||||
this.padding$set = true;
|
|
||||||
return self();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
/**
|
|
||||||
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
|
|
||||||
* See {@link CNN2DFormat} for more details.<br>
|
|
||||||
* Default: NCHW
|
|
||||||
* @param format Format for activations (in and out)
|
|
||||||
*/
|
|
||||||
@Builder.Default
|
|
||||||
private CNN2DFormat dataFormat = CNN2DFormat.NCHW;
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf,
|
public org.deeplearning4j.nn.api.Layer instantiate(
|
||||||
Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView,
|
NeuralNetConfiguration conf,
|
||||||
boolean initializeParams, DataType networkDataType) {
|
Collection<TrainingListener> trainingListeners,
|
||||||
|
int layerIndex,
|
||||||
|
INDArray layerParamsView,
|
||||||
|
boolean initializeParams,
|
||||||
|
DataType networkDataType) {
|
||||||
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
|
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
|
||||||
runInheritance();
|
runInheritance();
|
||||||
org.deeplearning4j.nn.layers.convolution.ZeroPaddingLayer ret =
|
org.deeplearning4j.nn.layers.convolution.ZeroPaddingLayer ret =
|
||||||
new org.deeplearning4j.nn.layers.convolution.ZeroPaddingLayer(lconf, networkDataType);
|
new org.deeplearning4j.nn.layers.convolution.ZeroPaddingLayer(lconf, networkDataType);
|
||||||
ret.addTrainingListeners(trainingListeners);
|
ret.addTrainingListeners(trainingListeners);
|
||||||
|
@ -130,15 +109,18 @@ runInheritance();
|
||||||
int outH = hwd[0] + padding[0] + padding[1];
|
int outH = hwd[0] + padding[0] + padding[1];
|
||||||
int outW = hwd[1] + padding[2] + padding[3];
|
int outW = hwd[1] + padding[2] + padding[3];
|
||||||
|
|
||||||
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional)inputType;
|
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
|
||||||
|
|
||||||
return InputType.convolutional(outH, outW, hwd[2], c.getFormat());
|
return InputType.convolutional(outH, outW, hwd[2], c.getFormat());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
|
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
|
||||||
Preconditions.checkArgument(inputType != null, "Invalid input for ZeroPaddingLayer layer (layer name=\""
|
Preconditions.checkArgument(
|
||||||
+ getName() + "\"): InputType is null");
|
inputType != null,
|
||||||
|
"Invalid input for ZeroPaddingLayer layer (layer name=\""
|
||||||
|
+ getName()
|
||||||
|
+ "\"): InputType is null");
|
||||||
return InputTypeUtil.getPreProcessorForInputTypeCnnLayers(inputType, getName());
|
return InputTypeUtil.getPreProcessorForInputTypeCnnLayers(inputType, getName());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -147,18 +129,45 @@ runInheritance();
|
||||||
InputType outputType = getOutputType(-1, inputType);
|
InputType outputType = getOutputType(-1, inputType);
|
||||||
|
|
||||||
return new LayerMemoryReport.Builder(name, ZeroPaddingLayer.class, inputType, outputType)
|
return new LayerMemoryReport.Builder(name, ZeroPaddingLayer.class, inputType, outputType)
|
||||||
.standardMemory(0, 0) //No params
|
.standardMemory(0, 0) // No params
|
||||||
//Inference and training is same - just output activations, no working memory in addition to that
|
// Inference and training is same - just output activations, no working memory in addition
|
||||||
|
// to that
|
||||||
.workingMemory(0, 0, MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS)
|
.workingMemory(0, 0, MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS)
|
||||||
.cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) //No caching
|
.cacheMemory(
|
||||||
|
MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) // No caching
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void setNIn(InputType inputType, boolean override) {
|
public void setNIn(InputType inputType, boolean override) {
|
||||||
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional)inputType;
|
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
|
||||||
this.dataFormat = c.getFormat();
|
this.dataFormat = c.getFormat();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static final class ZeroPaddingLayerBuilderImpl
|
||||||
|
extends ZeroPaddingLayerBuilder<ZeroPaddingLayer, ZeroPaddingLayerBuilderImpl> {
|
||||||
|
public ZeroPaddingLayer build() {
|
||||||
|
ZeroPaddingLayer l = new ZeroPaddingLayer(this);
|
||||||
|
if (l.getPadding() == null || l.getPadding().length != 4) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"Invalid padding values: must have exactly 4 values [top, bottom, left, right]."
|
||||||
|
+ " Got: "
|
||||||
|
+ (l.getPadding() == null ? null : Arrays.toString(l.getPadding())));
|
||||||
|
}
|
||||||
|
|
||||||
|
l.initializeConstraints();
|
||||||
|
return l;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public abstract static class ZeroPaddingLayerBuilder<
|
||||||
|
C extends ZeroPaddingLayer, B extends ZeroPaddingLayerBuilder<C, B>>
|
||||||
|
extends NoParamLayerBuilder<C, B> {
|
||||||
|
|
||||||
|
public B padding(int... padding) {
|
||||||
|
this.padding$value = ValidationUtils.validate4NonNegative(padding, "padding");
|
||||||
|
this.padding$set = true;
|
||||||
|
return self();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -164,17 +164,21 @@ public class Yolo2OutputLayer extends LayerConfiguration {
|
||||||
public static abstract class Yolo2OutputLayerBuilder<
|
public static abstract class Yolo2OutputLayerBuilder<
|
||||||
C extends Yolo2OutputLayer, B extends Yolo2OutputLayerBuilder<C, B>>
|
C extends Yolo2OutputLayer, B extends Yolo2OutputLayerBuilder<C, B>>
|
||||||
extends LayerConfigurationBuilder<C, B> {
|
extends LayerConfigurationBuilder<C, B> {
|
||||||
public C build() {
|
|
||||||
if (boundingBoxes == null) {
|
}
|
||||||
|
private static final class Yolo2OutputLayerBuilderImpl extends Yolo2OutputLayerBuilder<Yolo2OutputLayer, Yolo2OutputLayerBuilderImpl> {
|
||||||
|
public Yolo2OutputLayer build() {
|
||||||
|
Yolo2OutputLayer l = new Yolo2OutputLayer(this);
|
||||||
|
if (l.getBoundingBoxes() == null) {
|
||||||
throw new IllegalStateException("Bounding boxes have not been set");
|
throw new IllegalStateException("Bounding boxes have not been set");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (boundingBoxes.rank() != 2 || boundingBoxes.size(1) != 2) {
|
if (l.getBoundingBoxes().rank() != 2 || l.getBoundingBoxes().size(1) != 2) {
|
||||||
throw new IllegalStateException(
|
throw new IllegalStateException(
|
||||||
"Bounding box priors must have shape [nBoxes, 2]. Has shape: "
|
"Bounding box priors must have shape [nBoxes, 2]. Has shape: "
|
||||||
+ Arrays.toString(boundingBoxes.shape()));
|
+ Arrays.toString(l.getBoundingBoxes().shape()));
|
||||||
}
|
}
|
||||||
return initBuild();
|
return l;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue