Using @SuperBuilder for LayerConfigurations

Signed-off-by: brian <brian@brutex.de>
master
Brian Rosenberger 2023-04-25 11:41:33 +02:00
parent f6100c362d
commit ad870c5281
118 changed files with 3278 additions and 3071 deletions

View File

@ -221,7 +221,7 @@ public class TestMiscFunctions extends BaseSparkTest {
int nIn = 10; int nIn = 10;
NeuralNetConfiguration mlc = NeuralNetConfiguration.builder().list() NeuralNetConfiguration mlc = NeuralNetConfiguration.builder().list()
.layer(0, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder() .layer(0, org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.builder()
.reconstructionDistribution( .reconstructionDistribution(
new GaussianReconstructionDistribution(Activation.IDENTITY)) new GaussianReconstructionDistribution(Activation.IDENTITY))
.nIn(nIn).nOut(5).encoderLayerSizes(12).decoderLayerSizes(13).build()) .nIn(nIn).nOut(5).encoderLayerSizes(12).decoderLayerSizes(13).build())
@ -261,7 +261,7 @@ public class TestMiscFunctions extends BaseSparkTest {
NeuralNetConfiguration mlc = NeuralNetConfiguration.builder() NeuralNetConfiguration mlc = NeuralNetConfiguration.builder()
.list().layer(0, .list().layer(0,
new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder() org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.builder()
.reconstructionDistribution(new LossFunctionWrapper( .reconstructionDistribution(new LossFunctionWrapper(
Activation.IDENTITY, new LossMSE())) Activation.IDENTITY, new LossMSE()))
.nIn(nIn).nOut(5).encoderLayerSizes(12).decoderLayerSizes(13) .nIn(nIn).nOut(5).encoderLayerSizes(12).decoderLayerSizes(13)

View File

@ -96,7 +96,7 @@ public class App {
private static LayerConfiguration[] genLayers() { private static LayerConfiguration[] genLayers() {
return new LayerConfiguration[] { return new LayerConfiguration[] {
DenseLayer.builder().nIn(INPUT).nOut(X_DIM*Y_DIM*CHANNELS).weightInit(WeightInit.NORMAL).build(), DenseLayer.builder().nIn(INPUT).nOut(X_DIM*Y_DIM*CHANNELS).weightInit(WeightInit.NORMAL).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(), ActivationLayer.builder(Activation.LEAKYRELU).build(),
DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(), DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(),
ActivationLayer.builder(new ActivationLReLU(0.2)).build(), ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
DenseLayer.builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM).build(), DenseLayer.builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM).build(),

View File

@ -331,10 +331,10 @@ public class CNN2DTestCases {
.build(), .build(),
"leaky_re_lu_8") "leaky_re_lu_8")
.addLayer("outputs", .addLayer("outputs",
new Yolo2OutputLayer.Builder() Yolo2OutputLayer.builder()
.lambdaNoObj(lambdaNoObj) .lambdaNoObj(lambdaNoObj)
.lambdaCoord(lambdaCoord) .lambdaCoord(lambdaCoord)
.boundingBoxPriors(priors) .boundingBoxes(priors)
.build(), .build(),
"convolution2d_9") "convolution2d_9")
.setOutputs("outputs") .setOutputs("outputs")

View File

@ -322,7 +322,7 @@ public class RNNTestCases {
.updater(new Adam(5e-2)) .updater(new Adam(5e-2))
.l1(1e-3).l2(1e-3) .l1(1e-3).l2(1e-3)
.layer(0, Bidirectional.builder(LSTM.builder().activation(Activation.TANH).nOut(10).build())) .layer(0, Bidirectional.builder(LSTM.builder().activation(Activation.TANH).nOut(10).build()).build())
.layer(GlobalPoolingLayer.builder().poolingType(PoolingType.AVG).build()) .layer(GlobalPoolingLayer.builder().poolingType(PoolingType.AVG).build())
.layer(OutputLayer.builder().nOut(6) .layer(OutputLayer.builder().nOut(6)
.lossFunction(LossFunctions.LossFunction.MCXENT) .lossFunction(LossFunctions.LossFunction.MCXENT)

View File

@ -22,9 +22,11 @@ package org.nd4j.linalg.activations;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.activations.impl.*; import org.nd4j.linalg.activations.impl.*;
import org.nd4j.linalg.api.ndarray.INDArray;
public enum Activation { public enum Activation implements IActivation {
CUBE, ELU, HARDSIGMOID, HARDTANH, IDENTITY, LEAKYRELU, RATIONALTANH, RELU, RELU6, CUBE, ELU, HARDSIGMOID, HARDTANH, IDENTITY, LEAKYRELU, RATIONALTANH, RELU, RELU6,
RRELU, SIGMOID, SOFTMAX, SOFTPLUS, SOFTSIGN, TANH, RECTIFIEDTANH, SELU, SWISH, RRELU, SIGMOID, SOFTMAX, SOFTPLUS, SOFTSIGN, TANH, RECTIFIEDTANH, SELU, SWISH,
THRESHOLDEDRELU, GELU, MISH; THRESHOLDEDRELU, GELU, MISH;
@ -149,4 +151,44 @@ public enum Activation {
throw new UnsupportedOperationException("Activation function not yet supported: " + this); throw new UnsupportedOperationException("Activation function not yet supported: " + this);
} }
} }
/**
* Carry out activation function on the input array (usually known as 'preOut' or 'z')
* Implementations must overwrite "in", transform in place and return "in"
* Can support separate behaviour during test
*
* @param in input array.
* @param training true when training.
* @return transformed activation
*/
@Override
public INDArray getActivation(INDArray in, boolean training) {
return getActivationFunction().getActivation(in, training);
}
/**
* Backpropagate the errors through the activation function, given input z and epsilon dL/da.<br>
* Returns 2 INDArrays:<br>
* (a) The gradient dL/dz, calculated from dL/da, and<br>
* (b) The parameter gradients dL/dW, where w is the weights in the activation function. For activation functions
* with no gradients, this will be null.
*
* @param in Input, before applying the activation function (z, or 'preOut')
* @param epsilon Gradient to be backpropagated: dL/da, where L is the loss function
* @return dL/dz and dL/dW, for weights w (null if activation function has no weights)
*/
@Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
return getActivationFunction().backprop(in, epsilon);
}
/**
*
* @param inputSize
* @return
*/
@Override
public int numParams(int inputSize) {
return getActivationFunction().numParams(inputSize);
}
} }

View File

@ -872,7 +872,7 @@ public class TestEarlyStopping extends BaseDL4JTest {
.nIn(10) .nIn(10)
.nOut(10) .nOut(10)
.activation(Activation.TANH) .activation(Activation.TANH)
.gateActivationFunction(Activation.SIGMOID) .gateActivationFunction(Activation.SIGMOID.getActivationFunction())
.dropOut(0.5) .dropOut(0.5)
.build()) .build())
.layer(1, RnnOutputLayer.builder() .layer(1, RnnOutputLayer.builder()

View File

@ -90,8 +90,8 @@ public class AttentionLayerTest extends BaseDL4JTest {
.weightInit(WeightInit.XAVIER) .weightInit(WeightInit.XAVIER)
.layer(LSTM.builder().nOut(layerSize).build()) .layer(LSTM.builder().nOut(layerSize).build())
.layer( projectInput ? .layer( projectInput ?
new SelfAttentionLayer.Builder().nOut(4).nHeads(2).projectInput(true).build() SelfAttentionLayer.builder().nOut(4).nHeads(2).projectInput(true).build()
: new SelfAttentionLayer.Builder().nHeads(1).projectInput(false).build() : SelfAttentionLayer.builder().nHeads(1).projectInput(false).build()
) )
.layer(GlobalPoolingLayer.builder().poolingType(PoolingType.MAX).build()) .layer(GlobalPoolingLayer.builder().poolingType(PoolingType.MAX).build())
.layer(OutputLayer.builder().nOut(nOut).activation(Activation.SOFTMAX) .layer(OutputLayer.builder().nOut(nOut).activation(Activation.SOFTMAX)
@ -151,8 +151,8 @@ public class AttentionLayerTest extends BaseDL4JTest {
.list() .list()
.layer(LSTM.builder().nOut(layerSize).build()) .layer(LSTM.builder().nOut(layerSize).build())
.layer( projectInput ? .layer( projectInput ?
new LearnedSelfAttentionLayer.Builder().nOut(4).nHeads(2).nQueries(numQueries).projectInput(true).build() LearnedSelfAttentionLayer.builder().nOut(4).nHeads(2).nQueries(numQueries).projectInput(true).build()
: new LearnedSelfAttentionLayer.Builder().nHeads(1).nQueries(numQueries).projectInput(false).build() : LearnedSelfAttentionLayer.builder().nHeads(1).nQueries(numQueries).projectInput(false).build()
) )
.layer(GlobalPoolingLayer.builder().poolingType(PoolingType.MAX).build()) .layer(GlobalPoolingLayer.builder().poolingType(PoolingType.MAX).build())
.layer(OutputLayer.builder().nOut(nOut).activation(Activation.SOFTMAX) .layer(OutputLayer.builder().nOut(nOut).activation(Activation.SOFTMAX)
@ -191,8 +191,8 @@ public class AttentionLayerTest extends BaseDL4JTest {
.list() .list()
.layer(LSTM.builder().nOut(layerSize).build()) .layer(LSTM.builder().nOut(layerSize).build())
.layer( projectInput ? .layer( projectInput ?
new LearnedSelfAttentionLayer.Builder().nOut(4).nHeads(2).nQueries(numQueries).projectInput(true).build() LearnedSelfAttentionLayer.builder().nOut(4).nHeads(2).nQueries(numQueries).projectInput(true).build()
: new LearnedSelfAttentionLayer.Builder().nHeads(1).nQueries(numQueries).projectInput(false).build() : LearnedSelfAttentionLayer.builder().nHeads(1).nQueries(numQueries).projectInput(false).build()
) )
.layer(GlobalPoolingLayer.builder().poolingType(PoolingType.MAX).build()) .layer(GlobalPoolingLayer.builder().poolingType(PoolingType.MAX).build())
.layer(OutputLayer.builder().nOut(nOut).activation(Activation.SOFTMAX) .layer(OutputLayer.builder().nOut(nOut).activation(Activation.SOFTMAX)
@ -245,7 +245,7 @@ public class AttentionLayerTest extends BaseDL4JTest {
.weightInit(WeightInit.XAVIER) .weightInit(WeightInit.XAVIER)
.list() .list()
.layer(LSTM.builder().nOut(layerSize).build()) .layer(LSTM.builder().nOut(layerSize).build())
.layer(new RecurrentAttentionLayer.Builder().nIn(layerSize).nOut(layerSize).nHeads(1).projectInput(false).hasBias(false).build()) .layer(RecurrentAttentionLayer.builder().nIn(layerSize).nOut(layerSize).nHeads(1).projectInput(false).hasBias(false).build())
.layer(GlobalPoolingLayer.builder().poolingType(PoolingType.AVG).build()) .layer(GlobalPoolingLayer.builder().poolingType(PoolingType.AVG).build())
.layer(OutputLayer.builder().nOut(nOut).activation(Activation.SOFTMAX) .layer(OutputLayer.builder().nOut(nOut).activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build()) .lossFunction(LossFunctions.LossFunction.MCXENT).build())
@ -308,7 +308,7 @@ public class AttentionLayerTest extends BaseDL4JTest {
.weightInit(WeightInit.XAVIER) .weightInit(WeightInit.XAVIER)
.list() .list()
.layer(LSTM.builder().nOut(layerSize).build()) .layer(LSTM.builder().nOut(layerSize).build())
.layer(new RecurrentAttentionLayer.Builder().nIn(layerSize).nOut(layerSize).nHeads(1).projectInput(false).hasBias(false).build()) .layer(RecurrentAttentionLayer.builder().nIn(layerSize).nOut(layerSize).nHeads(1).projectInput(false).hasBias(false).build())
.layer(GlobalPoolingLayer.builder().poolingType(PoolingType.AVG).build()) .layer(GlobalPoolingLayer.builder().poolingType(PoolingType.AVG).build())
.layer(OutputLayer.builder().nOut(nOut).activation(Activation.SOFTMAX) .layer(OutputLayer.builder().nOut(nOut).activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build()) .lossFunction(LossFunctions.LossFunction.MCXENT).build())

View File

@ -363,7 +363,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
.nOut(1) .nOut(1)
.build()) // output: (5-2+0)/1+1 = 4 .build()) // output: (5-2+0)/1+1 = 4
.layer( .layer(
new SpaceToDepthLayer.Builder(blocks, SpaceToDepthLayer.DataFormat.NCHW) SpaceToDepthLayer.builder().blockSize(blocks).dataFormat(CNN2DFormat.NCHW)
.build()) // (mb,1,4,4) -> (mb,4,2,2) .build()) // (mb,1,4,4) -> (mb,4,2,2)
.layer( .layer(
OutputLayer.builder() OutputLayer.builder()
@ -450,10 +450,10 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
ConvolutionLayer.builder(kernel) ConvolutionLayer.builder(kernel)
.nIn(inputDepth) .nIn(inputDepth)
.nOut(3) .nOut(3)
.dataFormat(format) .convFormat(format)
.build()) .build())
.layer( .layer(
new SpaceToBatchLayer.Builder(blocks) SpaceToBatchLayer.builder(blocks)
.dataFormat(format) .dataFormat(format)
.build()) // trivial space to batch .build()) // trivial space to batch
.layer( .layer(
@ -546,7 +546,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
.layer( .layer(
ConvolutionLayer.builder(kernel, stride, padding) ConvolutionLayer.builder(kernel, stride, padding)
.nIn(inputDepth) .nIn(inputDepth)
.dataFormat(format) .convFormat(format)
.nOut(3) .nOut(3)
.build()) // output: (5-2+0)/1+1 = 4 .build()) // output: (5-2+0)/1+1 = 4
.layer( .layer(
@ -641,7 +641,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
0, 0,
ConvolutionLayer.builder(kernel, stride, padding) ConvolutionLayer.builder(kernel, stride, padding)
.nIn(inputDepth) .nIn(inputDepth)
.dataFormat(format) .convFormat(format)
.nOut(3) .nOut(3)
.build()) // output: (5-2+0)/1+1 = 4 .build()) // output: (5-2+0)/1+1 = 4
.layer( .layer(
@ -750,7 +750,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
0, 0,
ConvolutionLayer.builder(kernel, stride, padding) ConvolutionLayer.builder(kernel, stride, padding)
.nIn(inputDepth) .nIn(inputDepth)
.dataFormat(format) .convFormat(format)
.nOut(3) .nOut(3)
.build()) // output: (5-2+0)/1+1 = 4 .build()) // output: (5-2+0)/1+1 = 4
.layer( .layer(
@ -765,7 +765,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
.layer( .layer(
2, 2,
ConvolutionLayer.builder(kernel, stride, padding) ConvolutionLayer.builder(kernel, stride, padding)
.dataFormat(format) .convFormat(format)
.nIn(3) .nIn(3)
.nOut(2) .nOut(2)
.build()) // Output: (3-2+0)/1+1 = 2 .build()) // Output: (3-2+0)/1+1 = 2
@ -849,7 +849,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
ConvolutionLayer.builder() ConvolutionLayer.builder()
.kernelSize(2, 2) .kernelSize(2, 2)
.stride(1, 1) .stride(1, 1)
.dataFormat(format) .convFormat(format)
.padding(0, 0) .padding(0, 0)
.nIn(inputDepth) .nIn(inputDepth)
.nOut(2) .nOut(2)
@ -861,7 +861,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
.nOut(7) .nOut(7)
.kernelSize(2, 2) .kernelSize(2, 2)
.dataFormat(format) .dataFormat(format)
.setInputSize(4, 4) .inputSize(new int[]{4, 4})
.convolutionMode(ConvolutionMode.Strict) .convolutionMode(ConvolutionMode.Strict)
.hasBias(false) .hasBias(false)
.stride(1, 1) .stride(1, 1)
@ -873,7 +873,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
.nIn(7) .nIn(7)
.nOut(2) .nOut(2)
.kernelSize(2, 2) .kernelSize(2, 2)
.dataFormat(format) .convFormat(format)
.stride(1, 1) .stride(1, 1)
.padding(0, 0) .padding(0, 0)
.build()) // (3-2+0)/1+1 = 2 .build()) // (3-2+0)/1+1 = 2
@ -959,7 +959,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
ConvolutionLayer.builder() ConvolutionLayer.builder()
.kernelSize(2, 2) .kernelSize(2, 2)
.stride(1, 1) .stride(1, 1)
.dataFormat(format) .convFormat(format)
.padding(0, 0) .padding(0, 0)
.nIn(inputDepth) .nIn(inputDepth)
.nOut(2) .nOut(2)
@ -970,7 +970,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
.nIn(2) .nIn(2)
.nOut(2) .nOut(2)
.kernelSize(2, 2) .kernelSize(2, 2)
.dataFormat(format) .convFormat(format)
.stride(1, 1) .stride(1, 1)
.padding(0, 0) .padding(0, 0)
.build()) // (4-2+0)/1+1 = 3 .build()) // (4-2+0)/1+1 = 3
@ -980,7 +980,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
.nIn(2) .nIn(2)
.nOut(2) .nOut(2)
.kernelSize(2, 2) .kernelSize(2, 2)
.dataFormat(format) .convFormat(format)
.stride(1, 1) .stride(1, 1)
.padding(0, 0) .padding(0, 0)
.build()) // (3-2+0)/1+1 = 2 .build()) // (3-2+0)/1+1 = 2
@ -1076,7 +1076,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
ConvolutionLayer.builder() ConvolutionLayer.builder()
.name("layer 0") .name("layer 0")
.kernelSize(k, k) .kernelSize(k, k)
.dataFormat(format) .convFormat(format)
.stride(1, 1) .stride(1, 1)
.padding(0, 0) .padding(0, 0)
.nIn(inputDepth) .nIn(inputDepth)
@ -1097,7 +1097,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
.nIn(2) .nIn(2)
.nOut(2) .nOut(2)
.kernelSize(k, k) .kernelSize(k, k)
.dataFormat(format) .convFormat(format)
.stride(1, 1) .stride(1, 1)
.padding(0, 0) .padding(0, 0)
.build()) .build())
@ -1181,7 +1181,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
ConvolutionLayer.builder() ConvolutionLayer.builder()
.name("layer 0") .name("layer 0")
.kernelSize(k, k) .kernelSize(k, k)
.dataFormat(format) .convFormat(format)
.stride(stride, stride) .stride(stride, stride)
.padding(0, 0) .padding(0, 0)
.nIn(inputDepth) .nIn(inputDepth)
@ -1297,7 +1297,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
.layer( .layer(
0, 0,
ConvolutionLayer.builder(kernel, stride, padding) ConvolutionLayer.builder(kernel, stride, padding)
.dataFormat(format) .convFormat(format)
.nIn(inputDepth) .nIn(inputDepth)
.nOut(3) .nOut(3)
.build()) // output: (6-2+0)/1+1 = 5 .build()) // output: (6-2+0)/1+1 = 5
@ -1307,7 +1307,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
ConvolutionLayer.builder(kernel, stride, padding) ConvolutionLayer.builder(kernel, stride, padding)
.nIn(3) .nIn(3)
.nOut(3) .nOut(3)
.dataFormat(format) .convFormat(format)
.build()) // output: (6-2+0)/1+1 = 5 .build()) // output: (6-2+0)/1+1 = 5
.layer( .layer(
3, 3,
@ -1436,7 +1436,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
.name("deconvolution_2D_layer") .name("deconvolution_2D_layer")
.kernelSize(k, k) .kernelSize(k, k)
.stride(s, s) .stride(s, s)
.dataFormat(format) .convFormat(format)
.dilation(d, d) .dilation(d, d)
.convolutionMode(cm) .convolutionMode(cm)
.nIn(inputDepth) .nIn(inputDepth)
@ -1530,7 +1530,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
.stride(s, s) .stride(s, s)
.dilation(d, d) .dilation(d, d)
.depthMultiplier(3) .depthMultiplier(3)
.dataFormat(format) .convFormat(format)
.nIn(inputDepth) .nIn(inputDepth)
.nOut(2) .nOut(2)
.build()) .build())
@ -1621,7 +1621,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
.kernelSize(k, k) .kernelSize(k, k)
.stride(s, s) .stride(s, s)
.dilation(d, d) .dilation(d, d)
.dataFormat(format) .convFormat(format)
.nIn(inputDepth) .nIn(inputDepth)
.nOut(2) .nOut(2)
.build()); .build());
@ -1642,7 +1642,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
.kernelSize(k, k) .kernelSize(k, k)
.stride(s, s) .stride(s, s)
.dilation(d, d) .dilation(d, d)
.dataFormat(format) .convFormat(format)
.build()); .build());
} }
@ -1732,14 +1732,14 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
.list() .list()
.layer( .layer(
ConvolutionLayer.builder(kernel, stride, padding) ConvolutionLayer.builder(kernel, stride, padding)
.dataFormat(format) .convFormat(format)
.nIn(inputDepth) .nIn(inputDepth)
.nOut(2) .nOut(2)
.build()) // output: (6-2+0)/1+1 = 5 .build()) // output: (6-2+0)/1+1 = 5
.layer(Cropping2D.builder(crop).dataFormat(format).build()) .layer(Cropping2D.builder(crop).dataFormat(format).build())
.layer( .layer(
ConvolutionLayer.builder(kernel, stride, padding) ConvolutionLayer.builder(kernel, stride, padding)
.dataFormat(format) .convFormat(format)
.nIn(2) .nIn(2)
.nOut(2) .nOut(2)
.build()) .build())
@ -1857,7 +1857,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
.stride(1, 1) .stride(1, 1)
.nIn(nIn) .nIn(nIn)
.nOut(nIn) .nOut(nIn)
.dataFormat(format) .convFormat(format)
.build()) .build())
.layer( .layer(
DepthwiseConvolution2D.builder() DepthwiseConvolution2D.builder()

View File

@ -82,7 +82,7 @@ public class CapsnetGradientCheckTest extends BaseDL4JTest {
.seed(123) .seed(123)
.updater(new NoOp()) .updater(new NoOp())
.dist(new UniformDistribution(-6, 6)) .dist(new UniformDistribution(-6, 6))
.layer(new PrimaryCapsules.Builder(primaryCapsDim, primarpCapsChannel) .layer(PrimaryCapsules.builder(primaryCapsDim, primarpCapsChannel)
.kernelSize(3, 3) .kernelSize(3, 3)
.stride(2, 2) .stride(2, 2)
.build()) .build())

View File

@ -131,7 +131,7 @@ public class GlobalPoolingGradientCheckTests extends BaseDL4JTest {
.updater(new NoOp()) .updater(new NoOp())
.dist(new NormalDistribution(0, 1.0)).seed(12345L).list() .dist(new NormalDistribution(0, 1.0)).seed(12345L).list()
.layer(0, ConvolutionLayer.builder().kernelSize(2, 2).stride(1, 1) .layer(0, ConvolutionLayer.builder().kernelSize(2, 2).stride(1, 1)
.dataFormat(nchw ? CNN2DFormat.NCHW : CNN2DFormat.NHWC) .convFormat(nchw ? CNN2DFormat.NCHW : CNN2DFormat.NHWC)
.nOut(layerDepth) .nOut(layerDepth)
.build()) .build())
.layer(1, GlobalPoolingLayer.builder().poolingType(pt).build()) .layer(1, GlobalPoolingLayer.builder().poolingType(pt).build())

View File

@ -345,10 +345,10 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
.dist(new NormalDistribution(0, 0.1)) .dist(new NormalDistribution(0, 0.1))
.updater(new NoOp()).graphBuilder().addInputs("input") .updater(new NoOp()).graphBuilder().addInputs("input")
.addLayer("l1", ConvolutionLayer.builder().kernelSize(2, 2).stride(1, 1).padding(0, 0) .addLayer("l1", ConvolutionLayer.builder().kernelSize(2, 2).stride(1, 1).padding(0, 0)
.dataFormat(format) .convFormat(format)
.nIn(2).nOut(2).activation(Activation.TANH).build(), "input") .nIn(2).nOut(2).activation(Activation.TANH).build(), "input")
.addLayer("l2", ConvolutionLayer.builder().kernelSize(2, 2).stride(1, 1) .addLayer("l2", ConvolutionLayer.builder().kernelSize(2, 2).stride(1, 1)
.padding(0, 0).dataFormat(format) .padding(0, 0).convFormat(format)
.nIn(2).nOut(2).activation(Activation.TANH).build(), "input") .nIn(2).nOut(2).activation(Activation.TANH).build(), "input")
.addVertex("merge", new MergeVertex(), "l1", "l2") .addVertex("merge", new MergeVertex(), "l1", "l2")
.addLayer("outputLayer", .addLayer("outputLayer",

View File

@ -116,7 +116,7 @@ public class RnnGradientChecks extends BaseDL4JTest {
.layer(Bidirectional.builder(m, .layer(Bidirectional.builder(m,
(simple ? (simple ?
SimpleRnn.builder().nIn(3).nOut(3).hasLayerNorm(hasLayerNorm).build() : SimpleRnn.builder().nIn(3).nOut(3).hasLayerNorm(hasLayerNorm).build() :
LSTM.builder().nIn(3).nOut(3).build()))) LSTM.builder().nIn(3).nOut(3).build())).build())
.layer(RnnOutputLayer.builder().nOut(nOut).activation(Activation.SOFTMAX).build()) .layer(RnnOutputLayer.builder().nOut(nOut).activation(Activation.SOFTMAX).build())
.build(); .build();

View File

@ -115,12 +115,11 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
.activation(a) .activation(a)
.l1(l1[i]).l2(l2[i]) .l1(l1[i]).l2(l2[i])
.convolutionMode(ConvolutionMode.Same) .convolutionMode(ConvolutionMode.Same)
.list()
.layer(ConvolutionLayer.builder().kernelSize(2, 2).stride(1, 1) .layer(ConvolutionLayer.builder().kernelSize(2, 2).stride(1, 1)
.dataFormat(format) .convFormat(format)
.nIn(depthIn).nOut(yoloDepth).build())//output: (5-2+0)/1+1 = 4 .nIn(depthIn).nOut(yoloDepth).build())//output: (5-2+0)/1+1 = 4
.layer(new Yolo2OutputLayer.Builder() .layer(Yolo2OutputLayer.builder()
.boundingBoxPriors(bbPrior) .boundingBoxes(bbPrior)
.build()) .build())
.inputType(InputType.convolutional(h, w, depthIn, format)) .inputType(InputType.convolutional(h, w, depthIn, format))
.build(); .build();
@ -237,8 +236,8 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
.layer(ConvolutionLayer.builder().kernelSize(3,3).stride(1,1).nOut(4).build()) .layer(ConvolutionLayer.builder().kernelSize(3,3).stride(1,1).nOut(4).build())
.layer(SubsamplingLayer.builder().kernelSize(2,2).stride(2,2).build()) .layer(SubsamplingLayer.builder().kernelSize(2,2).stride(2,2).build())
.layer(ConvolutionLayer.builder().activation(Activation.IDENTITY).kernelSize(3,3).stride(1,1).nOut(depthOut).build()) .layer(ConvolutionLayer.builder().activation(Activation.IDENTITY).kernelSize(3,3).stride(1,1).nOut(depthOut).build())
.layer(new Yolo2OutputLayer.Builder() .layer(Yolo2OutputLayer.builder()
.boundingBoxPriors(bbPriors) .boundingBoxes(bbPriors)
.build()) .build())
.inputType(InputType.convolutional(h,w,c)) .inputType(InputType.convolutional(h,w,c))
.build(); .build();

View File

@ -437,7 +437,7 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
.layer(DenseLayer.builder().nIn(10).nOut(10).build()) .layer(DenseLayer.builder().nIn(10).nOut(10).build())
.layer(!lossLayer ? OutputLayer.builder().nIn(10).nOut(nOut[i]) .layer(!lossLayer ? OutputLayer.builder().nIn(10).nOut(nOut[i])
.activation(activations[i]).lossFunction(lf[i]).build() .activation(activations[i]).lossFunction(lf[i]).build()
: LossLayer.builder().lossFunction().activation(activations[i]).lossFunction(lf[i]) : LossLayer.builder().activation(activations[i]).lossFunction(lf[i].getILossFunction())
.build()) .build())
.validateOutputLayerConfig(validate) .validateOutputLayerConfig(validate)
.build(); .build();

View File

@ -48,6 +48,7 @@ import org.nd4j.linalg.learning.config.RmsProp;
import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.List;
import java.util.Map; import java.util.Map;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@ -230,7 +231,8 @@ public class TestConstraints extends BaseDL4JTest {
.biasInit(0.2) .biasInit(0.2)
.layer(DenseLayer.builder().nIn(12).nOut(10) .layer(DenseLayer.builder().nIn(12).nOut(10)
.constrainAllParameters(lc).build()) .allParamConstraints(List.of(lc))
.build())
.layer(OutputLayer.builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(8).build()) .layer(OutputLayer.builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(8).build())
.build(); .build();

View File

@ -201,21 +201,21 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
.addInputs("input1", "input2", "input3") .addInputs("input1", "input2", "input3")
.addLayer("dense1", .addLayer("dense1",
DenseLayer.builder().nIn(featuresz).nOut(midsz) DenseLayer.builder().nIn(featuresz).nOut(midsz)
.activation(new ActivationTanH()).build(), .activation(Activation.TANH).build(),
"input1") "input1")
.addLayer("dense2", .addLayer("dense2",
DenseLayer.builder().nIn(featuresz).nOut(midsz) DenseLayer.builder().nIn(featuresz).nOut(midsz)
.activation(new ActivationTanH()).build(), .activation(Activation.TANH).build(),
"input2") "input2")
.addLayer("dense3", .addLayer("dense3",
DenseLayer.builder().nIn(featuresz).nOut(midsz) DenseLayer.builder().nIn(featuresz).nOut(midsz)
.activation(new ActivationTanH()).build(), .activation(Activation.TANH).build(),
"input3") "input3")
.addVertex("elementwiseAdd", new ElementWiseVertex(ElementWiseVertex.Op.Add), "dense1", .addVertex("elementwiseAdd", new ElementWiseVertex(ElementWiseVertex.Op.Add), "dense1",
"dense2", "dense3") "dense2", "dense3")
.addLayer("output", .addLayer("output",
OutputLayer.builder().nIn(midsz).nOut(outputsz) OutputLayer.builder().nIn(midsz).nOut(outputsz)
.activation(new ActivationSigmoid()) .activation(Activation.SIGMOID)
.lossFunction(LossFunction.MSE).build(), .lossFunction(LossFunction.MSE).build(),
"elementwiseAdd") "elementwiseAdd")
.setOutputs("output").build(); .setOutputs("output").build();
@ -377,21 +377,21 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
.addInputs("input1", "input2", "input3") .addInputs("input1", "input2", "input3")
.addLayer("dense1", .addLayer("dense1",
DenseLayer.builder().nIn(featuresz).nOut(midsz) DenseLayer.builder().nIn(featuresz).nOut(midsz)
.activation(new ActivationTanH()).build(), .activation(Activation.TANH).build(),
"input1") "input1")
.addLayer("dense2", .addLayer("dense2",
DenseLayer.builder().nIn(featuresz).nOut(midsz) DenseLayer.builder().nIn(featuresz).nOut(midsz)
.activation(new ActivationTanH()).build(), .activation(Activation.TANH).build(),
"input2") "input2")
.addLayer("dense3", .addLayer("dense3",
DenseLayer.builder().nIn(featuresz).nOut(midsz) DenseLayer.builder().nIn(featuresz).nOut(midsz)
.activation(new ActivationTanH()).build(), .activation(Activation.TANH).build(),
"input3") "input3")
.addVertex("elementwiseProduct", new ElementWiseVertex(ElementWiseVertex.Op.Product), "dense1", .addVertex("elementwiseProduct", new ElementWiseVertex(ElementWiseVertex.Op.Product), "dense1",
"dense2", "dense3") "dense2", "dense3")
.addLayer("output", .addLayer("output",
OutputLayer.builder().nIn(midsz).nOut(outputsz) OutputLayer.builder().nIn(midsz).nOut(outputsz)
.activation(new ActivationSigmoid()) .activation(Activation.SIGMOID)
.lossFunction(LossFunction.MSE).build(), .lossFunction(LossFunction.MSE).build(),
"elementwiseProduct") "elementwiseProduct")
.setOutputs("output").build(); .setOutputs("output").build();
@ -552,17 +552,17 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
.addInputs("input1", "input2") .addInputs("input1", "input2")
.addLayer("dense1", .addLayer("dense1",
DenseLayer.builder().nIn(featuresz).nOut(midsz) DenseLayer.builder().nIn(featuresz).nOut(midsz)
.activation(new ActivationTanH()).build(), .activation(Activation.TANH).build(),
"input1") "input1")
.addLayer("dense2", .addLayer("dense2",
DenseLayer.builder().nIn(featuresz).nOut(midsz) DenseLayer.builder().nIn(featuresz).nOut(midsz)
.activation(new ActivationTanH()).build(), .activation(Activation.TANH).build(),
"input2") "input2")
.addVertex("elementwiseSubtract", new ElementWiseVertex(ElementWiseVertex.Op.Subtract), .addVertex("elementwiseSubtract", new ElementWiseVertex(ElementWiseVertex.Op.Subtract),
"dense1", "dense2") "dense1", "dense2")
.addLayer("output", .addLayer("output",
OutputLayer.builder().nIn(midsz).nOut(outputsz) OutputLayer.builder().nIn(midsz).nOut(outputsz)
.activation(new ActivationSigmoid()) .activation(Activation.SIGMOID)
.lossFunction(LossFunction.MSE).build(), .lossFunction(LossFunction.MSE).build(),
"elementwiseSubtract") "elementwiseSubtract")
.setOutputs("output").build(); .setOutputs("output").build();

View File

@ -493,7 +493,7 @@ public class DTypeTests extends BaseDL4JTest {
secondLast = ConvolutionLayer.builder().kernelSize(2, 2).stride(1, 1).nOut(3).activation(Activation.TANH).build(); secondLast = ConvolutionLayer.builder().kernelSize(2, 2).stride(1, 1).nOut(3).activation(Activation.TANH).build();
break; break;
case 4: case 4:
ol = new Yolo2OutputLayer.Builder().boundingBoxPriors(Nd4j.create(new double[][]{{1.0, 1.0}, {2.0, 2.0}}).castTo(networkDtype)).build(); ol = Yolo2OutputLayer.builder().boundingBoxes(Nd4j.create(new double[][]{{1.0, 1.0}, {2.0, 2.0}}).castTo(networkDtype)).build();
secondLast = ConvolutionLayer.builder().kernelSize(2, 2).stride(1, 1).nOut(14).activation(Activation.TANH).build(); secondLast = ConvolutionLayer.builder().kernelSize(2, 2).stride(1, 1).nOut(14).activation(Activation.TANH).build();
break; break;
default: default:
@ -817,8 +817,8 @@ public class DTypeTests extends BaseDL4JTest {
.convolutionMode(ConvolutionMode.Same) .convolutionMode(ConvolutionMode.Same)
.updater(new Adam(1e-2)) .updater(new Adam(1e-2))
.list() .list()
.layer(new SpaceToBatchLayer.Builder().blocks(1, 1).build()) .layer(SpaceToBatchLayer.builder().blockSize(1, 1).build())
.layer(new SpaceToDepthLayer.Builder().blocks(2).build()) .layer(SpaceToDepthLayer.builder().blockSize(2).build())
.layer(OutputLayer.builder().nOut(10).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build()) .layer(OutputLayer.builder().nOut(10).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build())
.inputType(InputType.convolutional(28, 28, 5)) .inputType(InputType.convolutional(28, 28, 5))
.build(); .build();
@ -907,7 +907,7 @@ public class DTypeTests extends BaseDL4JTest {
.layer(Bidirectional.builder(LSTM.builder().nIn(5).nOut(5).activation(Activation.TANH).build()).build()) .layer(Bidirectional.builder(LSTM.builder().nIn(5).nOut(5).activation(Activation.TANH).build()).build())
.layer(new TimeDistributed(DenseLayer.builder().nIn(10).nOut(5).activation(Activation.TANH).build())) .layer(new TimeDistributed(DenseLayer.builder().nIn(10).nOut(5).activation(Activation.TANH).build()))
.layer(SimpleRnn.builder().nIn(5).nOut(5).build()) .layer(SimpleRnn.builder().nIn(5).nOut(5).build())
.layer(new MaskZeroLayer.Builder().underlying(SimpleRnn.builder().nIn(5).nOut(5).build()).maskValue(0.0).build()) .layer(MaskZeroLayer.builder().underlying(SimpleRnn.builder().nIn(5).nOut(5).build()).maskingValue(0.0).build())
.layer(secondLast) .layer(secondLast)
.layer(ol) .layer(ol)
.build(); .build();
@ -986,7 +986,7 @@ public class DTypeTests extends BaseDL4JTest {
.updater(new NoOp()) .updater(new NoOp())
.dist(new UniformDistribution(-6, 6)) .dist(new UniformDistribution(-6, 6))
.layer(new PrimaryCapsules.Builder(primaryCapsDim, primarpCapsChannel) .layer(PrimaryCapsules.builder(primaryCapsDim, primarpCapsChannel)
.kernelSize(3, 3) .kernelSize(3, 3)
.stride(2, 2) .stride(2, 2)
.build()) .build())
@ -1400,9 +1400,9 @@ public class DTypeTests extends BaseDL4JTest {
.weightInit(WeightInit.XAVIER) .weightInit(WeightInit.XAVIER)
.list() .list()
.layer(LSTM.builder().nOut(layerSize).build()) .layer(LSTM.builder().nOut(layerSize).build())
.layer(new SelfAttentionLayer.Builder().nOut(8).nHeads(2).projectInput(true).build()) .layer(SelfAttentionLayer.builder().nOut(8).nHeads(2).projectInput(true).build())
.layer(new LearnedSelfAttentionLayer.Builder().nOut(8).nHeads(2).nQueries(numQueries).projectInput(true).build()) .layer(LearnedSelfAttentionLayer.builder().nOut(8).nHeads(2).nQueries(numQueries).projectInput(true).build())
.layer(new RecurrentAttentionLayer.Builder().nIn(layerSize).nOut(layerSize).nHeads(1).projectInput(false).hasBias(false).build()) .layer(RecurrentAttentionLayer.builder().nIn(layerSize).nOut(layerSize).nHeads(1).projectInput(false).hasBias(false).build())
.layer(GlobalPoolingLayer.builder().poolingType(PoolingType.MAX).build()) .layer(GlobalPoolingLayer.builder().poolingType(PoolingType.MAX).build())
.layer(OutputLayer.builder().nOut(nOut).activation(Activation.SOFTMAX) .layer(OutputLayer.builder().nOut(nOut).activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build()) .lossFunction(LossFunctions.LossFunction.MCXENT).build())

View File

@ -161,7 +161,7 @@ public class TestCompGraphCNN extends BaseDL4JTest {
imageHeight)) imageHeight))
.addLayer("conv1", ConvolutionLayer.builder() .addLayer("conv1", ConvolutionLayer.builder()
.kernelSize(kernelHeight, kernelWidth).stride(1, 1) .kernelSize(kernelHeight, kernelWidth).stride(1, 1)
.dataFormat(CNN2DFormat.NCHW) .convFormat(CNN2DFormat.NCHW)
.nIn(nChannels).nOut(2).weightInit(WeightInit.XAVIER) .nIn(nChannels).nOut(2).weightInit(WeightInit.XAVIER)
.activation(Activation.RELU).build(), "input") .activation(Activation.RELU).build(), "input")
.addLayer("pool1", .addLayer("pool1",

View File

@ -1163,7 +1163,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
"act") "act")
.addLayer("drop", DropoutLayer.builder(0.5).build(), "pool") .addLayer("drop", DropoutLayer.builder(0.5).build(), "pool")
.addLayer("dense", DenseLayer.builder().nIn(1).nOut(1).build(), "drop") .addLayer("dense", DenseLayer.builder().nIn(1).nOut(1).build(), "drop")
.addLayer("loss", LossLayer.builder().lossFunction(LossFunctions.LossFunction.MCXENT) .addLayer("loss", LossLayer.builder().lossFunction(LossFunctions.LossFunction.MCXENT.getILossFunction())
.build(), "dense") .build(), "dense")
.allowDisconnected(true) .allowDisconnected(true)
.setOutputs("loss").build(); .setOutputs("loss").build();
@ -1457,7 +1457,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
.graphBuilder() .graphBuilder()
.addInputs("in") .addInputs("in")
.layer("0", SubsamplingLayer.builder().kernelSize(2,2).stride(2,2).build(), "in") .layer("0", SubsamplingLayer.builder().kernelSize(2,2).stride(2,2).build(), "in")
.layer("1", LossLayer.builder().lossFunction().activation(Activation.SIGMOID).lossFunction(LossFunctions.LossFunction.MSE).build(), "0") .layer("1", LossLayer.builder().activation(Activation.SIGMOID).lossFunction(LossFunctions.LossFunction.MSE.getILossFunction()).build(), "0")
.setOutputs("1") .setOutputs("1")
.setInputTypes(InputType.convolutionalFlat(28,28,1)) .setInputTypes(InputType.convolutionalFlat(28,28,1))
.build(); .build();
@ -1791,7 +1791,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
.nIn(10).nOut(5) .nIn(10).nOut(5)
.activation(Activation.TANH) .activation(Activation.TANH)
.dropOut(new GaussianNoise(0.05)) .dropOut(new GaussianNoise(0.05))
.build()) .build()).build()
,"merge") ,"merge")
.addLayer("out1", .addLayer("out1",
RnnOutputLayer.builder().activation(Activation.SOFTMAX) RnnOutputLayer.builder().activation(Activation.SOFTMAX)
@ -1986,10 +1986,10 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
.updater(new Adam()) .updater(new Adam())
.graphBuilder() .graphBuilder()
.addInputs("x_emb") .addInputs("x_emb")
.addLayer("agg_lstm", Bidirectional.builder(CONCAT, LSTM.builder().nOut(hiddenSize/2).build()), "x_emb") .addLayer("agg_lstm", Bidirectional.builder(CONCAT, LSTM.builder().nOut(hiddenSize/2).build()).build(), "x_emb")
.addLayer("agg_att", DenseLayer.builder().nIn(100).nOut(1).activation(Activation.SOFTMAX).build(), "agg_lstm") .addLayer("agg_att", DenseLayer.builder().nIn(100).nOut(1).activation(Activation.SOFTMAX).build(), "agg_lstm")
.addVertex("att", new PreprocessorVertex(new ComposableInputPreProcessor(new FeedForwardToRnnPreProcessor(), new PermutePreprocessor(0,2,1), new RnnToFeedForwardPreProcessor())), "agg_att") .addVertex("att", new PreprocessorVertex(new ComposableInputPreProcessor(new FeedForwardToRnnPreProcessor(), new PermutePreprocessor(0,2,1), new RnnToFeedForwardPreProcessor())), "agg_att")
.addLayer("att_repeat", new RepeatVector.Builder(hiddenSize).build(),"att") .addLayer("att_repeat", RepeatVector.builder().repetitionFactor(hiddenSize).build(),"att")
.addVertex("att_trans", new PreprocessorVertex(new PermutePreprocessor(0, 2, 1)), "att_repeat") .addVertex("att_trans", new PreprocessorVertex(new PermutePreprocessor(0, 2, 1)), "att_repeat")
.addVertex("mult", new ElementWiseVertex(ElementWiseVertex.Op.Product), "agg_lstm", "att_trans") .addVertex("mult", new ElementWiseVertex(ElementWiseVertex.Op.Product), "agg_lstm", "att_trans")
.addLayer("sum", GlobalPoolingLayer.builder().build(), "mult") .addLayer("sum", GlobalPoolingLayer.builder().build(), "mult")
@ -2197,16 +2197,16 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
.addInputs("in") .addInputs("in")
.layer("l0", ConvolutionLayer.builder() .layer("l0", ConvolutionLayer.builder()
.nOut(16) .nOut(16)
.dataFormat(CNN2DFormat.NHWC) .convFormat(CNN2DFormat.NHWC)
.kernelSize(2,2).stride(1,1) .kernelSize(2,2).stride(1,1)
.build(), "in") .build(), "in")
.layer("l1", ConvolutionLayer.builder() .layer("l1", ConvolutionLayer.builder()
.nOut(8) .nOut(8)
.dataFormat(CNN2DFormat.NHWC) .convFormat(CNN2DFormat.NHWC)
.kernelSize(2,2).stride(1,1) .kernelSize(2,2).stride(1,1)
.build(), "in") .build(), "in")
.addVertex("merge", new MergeVertex(), "l0", "l1") .addVertex("merge", new MergeVertex(), "l0", "l1")
.layer("out", CnnLossLayer.builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build(), "merge") .layer("out", CnnLossLayer.builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE.getILossFunction()).build(), "merge")
.setOutputs("out") .setOutputs("out")
.setInputTypes(InputType.convolutional(32, 32, 3, CNN2DFormat.NHWC)) .setInputTypes(InputType.convolutional(32, 32, 3, CNN2DFormat.NHWC))
.build(); .build();

View File

@ -357,7 +357,7 @@ public class ActivationLayerTest extends BaseDL4JTest {
.activation(Activation.RATIONALTANH) .activation(Activation.RATIONALTANH)
.layer(DenseLayer.builder().nIn(10).nOut(10).build()) .layer(DenseLayer.builder().nIn(10).nOut(10).build())
.layer(ActivationLayer.builder()) .layer(ActivationLayer.builder().build())
.layer(ActivationLayer.builder().build()) .layer(ActivationLayer.builder().build())
.layer(ActivationLayer.builder().activation(Activation.ELU).build()) .layer(ActivationLayer.builder().activation(Activation.ELU).build())
.layer( .layer(
@ -404,7 +404,7 @@ public class ActivationLayerTest extends BaseDL4JTest {
.graphBuilder() .graphBuilder()
.addInputs("in") .addInputs("in")
.addLayer("0", DenseLayer.builder().nIn(10).nOut(10).build(), "in") .addLayer("0", DenseLayer.builder().nIn(10).nOut(10).build(), "in")
.addLayer("1", ActivationLayer.builder(), "0") .addLayer("1", ActivationLayer.builder().build(), "0")
.addLayer("2", ActivationLayer.builder().build(), "1") .addLayer("2", ActivationLayer.builder().build(), "1")
.addLayer("3", ActivationLayer.builder().activation(Activation.ELU).build(), "2") .addLayer("3", ActivationLayer.builder().activation(Activation.ELU).build(), "2")
.addLayer( .addLayer(

View File

@ -63,7 +63,7 @@ public class CapsNetMNISTTest extends BaseDL4JTest {
.kernelSize(9, 9) .kernelSize(9, 9)
.stride(3, 3) .stride(3, 3)
.build()) .build())
.layer(new PrimaryCapsules.Builder(8, 8) .layer(PrimaryCapsules.builder(8, 8)
.kernelSize(7, 7) .kernelSize(7, 7)
.stride(2, 2) .stride(2, 2)
.build()) .build())

View File

@ -44,7 +44,7 @@ public class PrimaryCapsulesTest extends BaseDL4JTest {
@Test @Test
public void testOutputType(){ public void testOutputType(){
PrimaryCapsules layer = new PrimaryCapsules.Builder(8, 8) PrimaryCapsules layer = PrimaryCapsules.builder(8, 8)
.kernelSize(7, 7) .kernelSize(7, 7)
.stride(2, 2) .stride(2, 2)
.build(); .build();
@ -57,7 +57,7 @@ public class PrimaryCapsulesTest extends BaseDL4JTest {
@Test @Test
public void testInputType(){ public void testInputType(){
PrimaryCapsules layer = new PrimaryCapsules.Builder(8, 8) PrimaryCapsules layer = PrimaryCapsules.builder(8, 8)
.kernelSize(7, 7) .kernelSize(7, 7)
.stride(2, 2) .stride(2, 2)
.build(); .build();
@ -72,7 +72,7 @@ public class PrimaryCapsulesTest extends BaseDL4JTest {
@Test @Test
public void testConfig(){ public void testConfig(){
PrimaryCapsules layer1 = new PrimaryCapsules.Builder(8, 10) PrimaryCapsules layer1 = PrimaryCapsules.builder(8, 10)
.kernelSize(5, 5) .kernelSize(5, 5)
.stride(4, 4) .stride(4, 4)
.useLeakyReLU(0.5) .useLeakyReLU(0.5)
@ -84,22 +84,22 @@ public class PrimaryCapsulesTest extends BaseDL4JTest {
assertArrayEquals(new int[]{4, 4}, layer1.getStride()); assertArrayEquals(new int[]{4, 4}, layer1.getStride());
assertArrayEquals(new int[]{0, 0}, layer1.getPadding()); assertArrayEquals(new int[]{0, 0}, layer1.getPadding());
assertArrayEquals(new int[]{1, 1}, layer1.getDilation()); assertArrayEquals(new int[]{1, 1}, layer1.getDilation());
assertTrue(layer1.isUseRelu()); assertTrue(layer1.isUseRelU());
assertEquals(0.5, layer1.getLeak(), 0.001); assertEquals(0.5, layer1.getUseLeakyReLU(), 0.001);
PrimaryCapsules layer2 = new PrimaryCapsules.Builder(8, 10) PrimaryCapsules layer2 = PrimaryCapsules.builder(8, 10)
.kernelSize(5, 5) .kernelSize(5, 5)
.stride(4, 4) .stride(4, 4)
.build(); .build();
assertFalse(layer2.isUseRelu()); assertFalse(layer2.isUseRelU());
PrimaryCapsules layer3 = new PrimaryCapsules.Builder(8, 10) PrimaryCapsules layer3 = PrimaryCapsules.builder(8, 10)
.kernelSize(5, 5) .kernelSize(5, 5)
.stride(4, 4) .stride(4, 4)
.useReLU() .useReLU()
.build(); .build();
assertTrue(layer3.isUseRelu()); assertTrue(layer3.isUseRelU());
assertEquals(0, layer3.getLeak(), 0.001); assertEquals(0, layer3.getUseLeakyReLU(), 0.001);
} }
@ -108,7 +108,7 @@ public class PrimaryCapsulesTest extends BaseDL4JTest {
NeuralNetConfiguration conf = NeuralNetConfiguration.builder() NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
.seed(123) .seed(123)
.list() .list()
.layer(new PrimaryCapsules.Builder(8, 10) .layer(PrimaryCapsules.builder(8, 10)
.kernelSize(5, 5) .kernelSize(5, 5)
.stride(4, 4) .stride(4, 4)
.useLeakyReLU(0.5) .useLeakyReLU(0.5)

View File

@ -561,7 +561,7 @@ public class ConvDataFormatTests extends BaseDL4JTest {
.kernelSize(3, 3) .kernelSize(3, 3)
.stride(2, 2) .stride(2, 2)
.activation(Activation.TANH) .activation(Activation.TANH)
.dataFormat(format) .convFormat(format)
.nOut(3) .nOut(3)
.helperAllowFallback(false) .helperAllowFallback(false)
.build(), format, cm, null); .build(), format, cm, null);
@ -685,14 +685,14 @@ public class ConvDataFormatTests extends BaseDL4JTest {
return getNetWithLayer(Deconvolution2D.builder().nOut(2) return getNetWithLayer(Deconvolution2D.builder().nOut(2)
.activation(Activation.TANH) .activation(Activation.TANH)
.kernelSize(2,2) .kernelSize(2,2)
.dataFormat(format) .convFormat(format)
.stride(2,2) .stride(2,2)
.build(), format, cm, null); .build(), format, cm, null);
} else { } else {
return getNetWithLayer(Deconvolution2D.builder().nOut(2) return getNetWithLayer(Deconvolution2D.builder().nOut(2)
.activation(Activation.TANH) .activation(Activation.TANH)
.kernelSize(2,2) .kernelSize(2,2)
.dataFormat(format) .convFormat(format)
.stride(2,2) .stride(2,2)
.build(), format, cm, null); .build(), format, cm, null);
} }
@ -715,26 +715,26 @@ public class ConvDataFormatTests extends BaseDL4JTest {
private MultiLayerNetwork getSpaceToDepthNet(CNN2DFormat format, boolean setOnLayerAlso) { private MultiLayerNetwork getSpaceToDepthNet(CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) { if (setOnLayerAlso) {
return getNetWithLayer(new SpaceToDepthLayer.Builder() return getNetWithLayer(SpaceToDepthLayer.builder()
.blocks(2) .blockSize(2)
.dataFormat(format) .dataFormat(format)
.build(), format, ConvolutionMode.Same, null); .build(), format, ConvolutionMode.Same, null);
} else { } else {
return getNetWithLayer(new SpaceToDepthLayer.Builder() return getNetWithLayer(SpaceToDepthLayer.builder()
.blocks(2) .blockSize(2)
.build(), format, ConvolutionMode.Same, null); .build(), format, ConvolutionMode.Same, null);
} }
} }
private MultiLayerNetwork getSpaceToBatchNet(CNN2DFormat format, boolean setOnLayerAlso) { private MultiLayerNetwork getSpaceToBatchNet(CNN2DFormat format, boolean setOnLayerAlso) {
if (setOnLayerAlso) { if (setOnLayerAlso) {
return getNetWithLayer(new SpaceToBatchLayer.Builder() return getNetWithLayer(SpaceToBatchLayer.builder()
.blocks(2, 2) .blockSize(2, 2)
.dataFormat(format) .dataFormat(format)
.build(), format, ConvolutionMode.Same, InputType.convolutional(16, 16, 3, format)); .build(), format, ConvolutionMode.Same, InputType.convolutional(16, 16, 3, format));
} else { } else {
return getNetWithLayer(new SpaceToBatchLayer.Builder() return getNetWithLayer(SpaceToBatchLayer.builder()
.blocks(2, 2) .blockSize(2, 2)
.build(), format, ConvolutionMode.Same, InputType.convolutional(16, 16, 3, format)); .build(), format, ConvolutionMode.Same, InputType.convolutional(16, 16, 3, format));
} }
} }
@ -807,7 +807,7 @@ public class ConvDataFormatTests extends BaseDL4JTest {
.kernelSize(3, 3) .kernelSize(3, 3)
.stride(2, 2) .stride(2, 2)
.activation(Activation.TANH) .activation(Activation.TANH)
.dataFormat(format) .convFormat(format)
.nOut(3) .nOut(3)
.helperAllowFallback(false) .helperAllowFallback(false)
.build()); .build());
@ -988,7 +988,7 @@ public class ConvDataFormatTests extends BaseDL4JTest {
switch (i){ switch (i){
case 0: case 0:
b.layer(ConvolutionLayer.builder().kernelSize(2,2).nIn(3).nOut(3).dataFormat(df).build()); b.layer(ConvolutionLayer.builder().kernelSize(2,2).nIn(3).nOut(3).convFormat(df).build());
b.inputType(InputType.convolutional(12,12,3,df)); b.inputType(InputType.convolutional(12,12,3,df));
break; break;
case 1: case 1:
@ -996,7 +996,7 @@ public class ConvDataFormatTests extends BaseDL4JTest {
b.inputType(InputType.convolutional(12,12,3,df)); b.inputType(InputType.convolutional(12,12,3,df));
break; break;
case 2: case 2:
b.layer(Deconvolution2D.builder().dataFormat(df).kernelSize(2,2).nIn(3).nOut(3).build()); b.layer(Deconvolution2D.builder().convFormat(df).kernelSize(2,2).nIn(3).nOut(3).build());
b.inputType(InputType.convolutional(12,12,3,df)); b.inputType(InputType.convolutional(12,12,3,df));
break; break;
case 3: case 3:

View File

@ -27,6 +27,7 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder; import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
@ -370,7 +371,7 @@ public class ConvolutionLayerSetupTest extends BaseDL4JTest {
NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().list() NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().list()
.layer(ConvolutionLayer.builder(2, 2).padding(0, 0).stride(2, 2).nIn(1).nOut(3).build()) //(28-2+0)/2+1 = 14 .layer(ConvolutionLayer.builder(2, 2).padding(0, 0).stride(2, 2).nIn(1).nOut(3).build()) //(28-2+0)/2+1 = 14
.layer(new SpaceToBatchLayer.Builder(blocks).build()) // Divide space dimensions by blocks, i.e. 14/2 = 7 .layer(SpaceToBatchLayer.builder(blocks).build()) // Divide space dimensions by blocks, i.e. 14/2 = 7
.layer(OutputLayer.builder().nOut(3).activation(Activation.SOFTMAX).build()) .layer(OutputLayer.builder().nOut(3).activation(Activation.SOFTMAX).build())
.inputType(InputType.convolutional(28, 28, 1)); .inputType(InputType.convolutional(28, 28, 1));
@ -389,11 +390,11 @@ public class ConvolutionLayerSetupTest extends BaseDL4JTest {
int blocks = 2; int blocks = 2;
NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().list() NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder()
//(28-2+0)/2+1 = 14 -> 14x14x3 out //(28-2+0)/2+1 = 14 -> 14x14x3 out
.layer(ConvolutionLayer.builder(2, 2).padding(0, 0).stride(2, 2).nIn(1).nOut(3).build()) .layer(ConvolutionLayer.builder(2, 2).padding(0, 0).stride(2, 2).nIn(1).nOut(3).build())
// Divide space dimensions by blocks, i.e. 14/2 = 7 -> 7x7x12 out (3x2x2 depth) // Divide space dimensions by blocks, i.e. 14/2 = 7 -> 7x7x12 out (3x2x2 depth)
.layer(new SpaceToDepthLayer.Builder(blocks, SpaceToDepthLayer.DataFormat.NCHW).build()) .layer(SpaceToDepthLayer.builder().blockSize(blocks).dataFormat(CNN2DFormat.NCHW).build())
.layer(OutputLayer.builder().nIn(3 * 2 * 2).nOut(3).activation(Activation.SOFTMAX).build()) // nIn of the next layer gets multiplied by 2*2. .layer(OutputLayer.builder().nIn(3 * 2 * 2).nOut(3).activation(Activation.SOFTMAX).build()) // nIn of the next layer gets multiplied by 2*2.
.inputType(InputType.convolutional(28, 28, 1)); .inputType(InputType.convolutional(28, 28, 1));

View File

@ -71,7 +71,7 @@ public class LocallyConnectedLayerTest extends BaseDL4JTest {
.layer(LocallyConnected2D.builder().kernelSize(8, 8).nIn(3) .layer(LocallyConnected2D.builder().kernelSize(8, 8).nIn(3)
.stride(4, 4).nOut(16).dropOut(0.5) .stride(4, 4).nOut(16).dropOut(0.5)
.convolutionMode(ConvolutionMode.Strict) .convolutionMode(ConvolutionMode.Strict)
.setInputSize(28, 28) .inputSize(28, 28)
.activation(Activation.RELU).weightInit( .activation(Activation.RELU).weightInit(
WeightInit.XAVIER) WeightInit.XAVIER)
.build()) .build())
@ -94,11 +94,10 @@ public class LocallyConnectedLayerTest extends BaseDL4JTest {
NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().seed(123) NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().seed(123)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).l2(2e-4) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).l2(2e-4)
.updater(new Nesterovs(0.9)).dropOut(0.5) .updater(new Nesterovs(0.9)).dropOut(0.5)
.list()
.layer(LocallyConnected1D.builder().kernelSize(4).nIn(3) .layer(LocallyConnected1D.builder().kernelSize(4).nIn(3)
.stride(1).nOut(16).dropOut(0.5) .stride(1).nOut(16).dropOut(0.5)
.convolutionMode(ConvolutionMode.Strict) .convolutionMode(ConvolutionMode.Strict)
.setInputSize(28) .inputSize(28)
.activation(Activation.RELU).weightInit( .activation(Activation.RELU).weightInit(
WeightInit.XAVIER) WeightInit.XAVIER)
.build()) .build())

View File

@ -61,7 +61,7 @@ public class SpaceToDepthTest extends BaseDL4JTest {
private Layer getSpaceToDepthLayer() { private Layer getSpaceToDepthLayer() {
NeuralNetConfiguration conf = NeuralNetConfiguration.builder() NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).seed(123) .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).seed(123)
.layer(new SpaceToDepthLayer.Builder(blockSize, dataFormat).build()).build(); .layer(SpaceToDepthLayer.builder().blockSize(blockSize).dataFormat(dataFormat.toFormat()).build()).build();
return conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, null, true, Nd4j.defaultFloatingPointType()); return conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, null, true, Nd4j.defaultFloatingPointType());
} }

View File

@ -20,6 +20,8 @@
package org.deeplearning4j.nn.layers.custom; package org.deeplearning4j.nn.layers.custom;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.DenseLayer;
@ -30,31 +32,38 @@ import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestCustomActivation extends BaseDL4JTest { public class TestCustomActivation extends BaseDL4JTest {
@Test @Test
public void testCustomActivationFn() { public void testCustomActivationFn() {
//Second: let's create a MultiLayerCofiguration with one, and check JSON and YAML config actually works... // Second: let's create a MultiLayerCofiguration with one, and check JSON and YAML config
// actually works...
NeuralNetConfiguration conf = NeuralNetConfiguration.builder().updater(new Sgd(0.1)).list() NeuralNetConfiguration conf =
.layer(0, DenseLayer.builder().nIn(10).nOut(10).activation(new CustomActivation()).build()) NeuralNetConfiguration.builder()
.layer(1, OutputLayer.builder().lossFunction(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(10).build()) .updater(new Sgd(0.1))
.layer(
0, DenseLayer.builder().nIn(10).nOut(10).activation(new CustomActivation()).build())
.layer(
1,
OutputLayer.builder()
.lossFunction(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX)
.nIn(10)
.nOut(10)
.build())
.build(); .build();
String json = conf.toJson(); String json = conf.toJson();
String yaml = conf.toYaml(); String yaml = conf.toYaml();
// System.out.println(json); // System.out.println(json);
NeuralNetConfiguration confFromJson = NeuralNetConfiguration.fromJson(json); NeuralNetConfiguration confFromJson = NeuralNetConfiguration.fromJson(json);
assertEquals(conf, confFromJson); assertEquals(conf, confFromJson);
NeuralNetConfiguration confFromYaml = NeuralNetConfiguration.fromYaml(yaml); NeuralNetConfiguration confFromYaml = NeuralNetConfiguration.fromYaml(yaml);
assertEquals(conf, confFromYaml); assertEquals(conf, confFromYaml);
} }
} }

View File

@ -119,7 +119,7 @@ public class TestCustomLayers extends BaseDL4JTest {
NeuralNetConfiguration conf = NeuralNetConfiguration conf =
NeuralNetConfiguration.builder().seed(12345).list() NeuralNetConfiguration.builder().seed(12345).list()
.layer(0, DenseLayer.builder().nIn(10).nOut(10).build()) .layer(0, DenseLayer.builder().nIn(10).nOut(10).build())
.layer(1, new CustomOutputLayer.builder().lossFunction(LossFunctions.LossFunction.MCXENT) .layer(1, CustomOutputLayer.builder().lossFunction(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX) .activation(Activation.SOFTMAX)
.nIn(10).nOut(10).build()) .nIn(10).nOut(10).build())
.build(); .build();
@ -172,7 +172,7 @@ public class TestCustomLayers extends BaseDL4JTest {
ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345) ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345)
.graphBuilder().addInputs("in") .graphBuilder().addInputs("in")
.addLayer("0", DenseLayer.builder().nIn(10).nOut(10).build(), "in").addLayer("1", .addLayer("0", DenseLayer.builder().nIn(10).nOut(10).build(), "in").addLayer("1",
new CustomOutputLayer.builder().lossFunction(LossFunctions.LossFunction.MCXENT).nIn(10) CustomOutputLayer.builder().lossFunction(LossFunctions.LossFunction.MCXENT).nIn(10)
.nOut(10).activation(Activation.SOFTMAX).build(), .nOut(10).activation(Activation.SOFTMAX).build(),
"0") "0")
.setOutputs("1").build(); .setOutputs("1").build();

View File

@ -91,8 +91,8 @@ public class TestYolo2OutputLayer extends BaseDL4JTest {
.l2(0.01) .l2(0.01)
.list() .list()
.layer(ConvolutionLayer.builder().nIn(depth).nOut(depth).kernelSize(1,1).build()) .layer(ConvolutionLayer.builder().nIn(depth).nOut(depth).kernelSize(1,1).build())
.layer(new Yolo2OutputLayer.Builder() .layer(Yolo2OutputLayer.builder()
.boundingBoxPriors(bbPrior) .boundingBoxes(bbPrior)
.build()) .build())
.build(); .build();
@ -179,8 +179,8 @@ public class TestYolo2OutputLayer extends BaseDL4JTest {
NeuralNetConfiguration conf = NeuralNetConfiguration.builder() NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
.list() .list()
.layer(ConvolutionLayer.builder().nIn(1).nOut(1).kernelSize(1,1).build()) .layer(ConvolutionLayer.builder().nIn(1).nOut(1).kernelSize(1,1).build())
.layer(new Yolo2OutputLayer.Builder() .layer(Yolo2OutputLayer.builder()
.boundingBoxPriors(bbPrior) .boundingBoxes(bbPrior)
.build()) .build())
.build(); .build();
@ -337,8 +337,8 @@ public class TestYolo2OutputLayer extends BaseDL4JTest {
NeuralNetConfiguration conf = NeuralNetConfiguration.builder() NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
.list() .list()
.layer(ConvolutionLayer.builder().kernelSize(3,3).stride(1,1).nIn(3).nOut(3).build()) .layer(ConvolutionLayer.builder().kernelSize(3,3).stride(1,1).nIn(3).nOut(3).build())
.layer(new Yolo2OutputLayer.Builder() .layer(Yolo2OutputLayer.builder()
.boundingBoxPriors(bbPriors) .boundingBoxes(bbPriors)
.build()) .build())
.build(); .build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
@ -506,8 +506,8 @@ public class TestYolo2OutputLayer extends BaseDL4JTest {
.layer(ConvolutionLayer.builder().kernelSize(5,5).stride(2,2).nOut(256).build()) .layer(ConvolutionLayer.builder().kernelSize(5,5).stride(2,2).nOut(256).build())
.layer(SubsamplingLayer.builder().kernelSize(2,2).stride(2,2)/*.poolingType(SubsamplingLayer.PoolingType.AVG)*/.build()) .layer(SubsamplingLayer.builder().kernelSize(2,2).stride(2,2)/*.poolingType(SubsamplingLayer.PoolingType.AVG)*/.build())
.layer(ConvolutionLayer.builder().activation(Activation.IDENTITY).kernelSize(5,5).stride(1,1).nOut(depthOut).build()) .layer(ConvolutionLayer.builder().activation(Activation.IDENTITY).kernelSize(5,5).stride(1,1).nOut(depthOut).build())
.layer(new Yolo2OutputLayer.Builder() .layer(Yolo2OutputLayer.builder()
.boundingBoxPriors(bbPriors) .boundingBoxes(bbPriors)
.build()) .build())
.inputType(InputType.convolutional(h,w,c)) .inputType(InputType.convolutional(h,w,c))
.build(); .build();
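Editor's note: besides the builder() entry point, the YOLO output layer renames its prior setter from boundingBoxPriors to boundingBoxes. A hedged sketch of the new call, assuming a [numPriors, 2] matrix of (width, height) priors as in these tests:

import org.deeplearning4j.nn.conf.layers.objdetect.Yolo2OutputLayer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

class YoloBuilderSketch {
    // Sketch only: the setter rename is boundingBoxPriors -> boundingBoxes.
    static Yolo2OutputLayer yoloOutput() {
        INDArray bbPriors = Nd4j.create(new double[][]{{1, 1}, {2, 2}}); // two (w, h) priors
        return Yolo2OutputLayer.builder()
                .boundingBoxes(bbPriors)
                .build();
    }
}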

View File

@ -209,7 +209,7 @@ public class RnnDataFormatTests extends BaseDL4JTest {
return getNetWithLayer(GravesBidirectionalLSTM.builder().nOut(3) return getNetWithLayer(GravesBidirectionalLSTM.builder().nOut(3)
.dataFormat(format).build(), format, lastTimeStep, maskZeros); .dataFormat(format).build(), format, lastTimeStep, maskZeros);
} else { } else {
return getNetWithLayer(new GravesBidirectionalLSTM.Builder().nOut(3).build(), format, lastTimeStep, maskZeros); return getNetWithLayer(GravesBidirectionalLSTM.builder().nOut(3).build(), format, lastTimeStep, maskZeros);
} }
} }
private MultiLayerNetwork getGravesLstmNet(RNNFormat format, boolean setOnLayerAlso, boolean lastTimeStep, boolean maskZeros) { private MultiLayerNetwork getGravesLstmNet(RNNFormat format, boolean setOnLayerAlso, boolean lastTimeStep, boolean maskZeros) {
@ -240,7 +240,7 @@ public class RnnDataFormatTests extends BaseDL4JTest {
} }
private MultiLayerNetwork getNetWithLayer(LayerConfiguration layer, RNNFormat format, boolean lastTimeStep, boolean maskZeros) { private MultiLayerNetwork getNetWithLayer(LayerConfiguration layer, RNNFormat format, boolean lastTimeStep, boolean maskZeros) {
if (maskZeros){ if (maskZeros){
layer = new MaskZeroLayer.Builder().setMaskValue(0.).setUnderlying(layer).build(); layer = MaskZeroLayer.builder().maskingValue(0.).underlying(layer).build();
} }
if(lastTimeStep){ if(lastTimeStep){
layer = new LastTimeStep(layer); layer = new LastTimeStep(layer);
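Editor's note: the masking wrapper follows the same migration with two setter renames, setMaskValue becomes maskingValue and setUnderlying becomes underlying. A sketch of wrapping a recurrent layer as the test above does; exact package locations are assumed to match upstream DL4J:

import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer;

class MaskWrapperSketch {
    // Sketch only: setMaskValue/setUnderlying become maskingValue/underlying on the builder.
    static LayerConfiguration wrap(boolean maskZeros, boolean lastTimeStep) {
        LayerConfiguration layer = LSTM.builder().nIn(10).nOut(10).build();
        if (maskZeros) {
            layer = MaskZeroLayer.builder().maskingValue(0.0).underlying(layer).build();
        }
        if (lastTimeStep) {
            layer = new LastTimeStep(layer); // LastTimeStep keeps its plain constructor
        }
        return layer;
    }
}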

View File

@ -27,6 +27,7 @@ import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.LSTM; import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInitDistribution;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -48,17 +49,17 @@ public class TestRecurrentWeightInit extends BaseDL4JTest {
switch (i) { switch (i) {
case 0: case 0:
b.layer(LSTM.builder().nIn(10).nOut(10) b.layer(LSTM.builder().nIn(10).nOut(10)
.weightInitRecurrent(new UniformDistribution(2, 3)) .weightInitRecurrent(new WeightInitDistribution(new UniformDistribution(2, 3)))
.build()); .build());
break; break;
case 1: case 1:
b.layer(GravesLSTM.builder().nIn(10).nOut(10) b.layer(GravesLSTM.builder().nIn(10).nOut(10)
.weightInitRecurrent(new UniformDistribution(2, 3)) .weightInitRecurrent(new WeightInitDistribution(new UniformDistribution(2, 3)))
.build()); .build());
break; break;
case 2: case 2:
b.layer(SimpleRnn.builder().nIn(10).nOut(10) b.layer(SimpleRnn.builder().nIn(10).nOut(10)
.weightInitRecurrent(new UniformDistribution(2, 3)).build()); .weightInitRecurrent(new WeightInitDistribution(new UniformDistribution(2, 3))).build());
break; break;
default: default:
throw new RuntimeException(); throw new RuntimeException();
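Editor's note: the recurrent weight initialization setter now expects an IWeightInit rather than a raw Distribution, so the distribution is wrapped in WeightInitDistribution. A sketch of the new form:

import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.weights.WeightInitDistribution;

class RecurrentInitSketch {
    // Sketch only: weightInitRecurrent(distribution) becomes
    // weightInitRecurrent(new WeightInitDistribution(distribution)).
    static LSTM lstmWithRecurrentInit() {
        return LSTM.builder()
                .nIn(10).nOut(10)
                .weightInitRecurrent(new WeightInitDistribution(new UniformDistribution(2, 3)))
                .build();
    }
}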

View File

@ -145,8 +145,8 @@ public class TestTimeDistributed extends BaseDL4JTest {
l2 = SimpleRnn.builder().nOut(5).build(); l2 = SimpleRnn.builder().nOut(5).build();
break; break;
case 2: case 2:
l0 = Bidirectional.builder(LSTM.builder().nOut(5).build()); l0 = Bidirectional.builder(LSTM.builder().nOut(5).build()).build();
l2 = Bidirectional.builder(LSTM.builder().nOut(5).build()); l2 = Bidirectional.builder(LSTM.builder().nOut(5).build()).build();
break; break;
default: default:
throw new RuntimeException("Not implemented: " + rnnType); throw new RuntimeException("Not implemented: " + rnnType);

View File

@ -67,7 +67,7 @@ public class TestSameDiffConv extends BaseDL4JTest {
NeuralNetConfiguration conf = NeuralNetConfiguration.builder() NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
.list() .list()
.layer(new SameDiffConv.Builder().nIn(nIn).nOut(nOut).kernelSize(kH, kW).build()) .layer(SameDiffConv.builder().nIn(nIn).nOut(nOut).kernelSize(kH, kW).build())
.build(); .build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
@ -131,7 +131,7 @@ public class TestSameDiffConv extends BaseDL4JTest {
.dataType(DataType.DOUBLE) .dataType(DataType.DOUBLE)
.seed(12345) .seed(12345)
.list() .list()
.layer(new SameDiffConv.Builder() .layer(SameDiffConv.builder()
.weightInit(WeightInit.XAVIER) .weightInit(WeightInit.XAVIER)
.nIn(nIn) .nIn(nIn)
.nOut(nOut) .nOut(nOut)
@ -142,7 +142,7 @@ public class TestSameDiffConv extends BaseDL4JTest {
.activation(a) .activation(a)
.hasBias(hasBias) .hasBias(hasBias)
.build()) .build())
.layer(new SameDiffConv.Builder() .layer(SameDiffConv.builder()
.weightInit(WeightInit.XAVIER) .weightInit(WeightInit.XAVIER)
.nIn(nOut) .nIn(nOut)
.nOut(nOut) .nOut(nOut)
@ -273,7 +273,7 @@ public class TestSameDiffConv extends BaseDL4JTest {
.trainingWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE) .trainingWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE)
.inferenceWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE) .inferenceWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE)
.list() .list()
.layer(new SameDiffConv.Builder() .layer(SameDiffConv.builder()
.weightInit(WeightInit.XAVIER) .weightInit(WeightInit.XAVIER)
.nIn(nIn) .nIn(nIn)
.nOut(nOut) .nOut(nOut)
@ -284,7 +284,7 @@ public class TestSameDiffConv extends BaseDL4JTest {
.activation(Activation.TANH) .activation(Activation.TANH)
.hasBias(hasBias) .hasBias(hasBias)
.build()) .build())
.layer(new SameDiffConv.Builder() .layer(SameDiffConv.builder()
.weightInit(WeightInit.XAVIER) .weightInit(WeightInit.XAVIER)
.nIn(nOut) .nIn(nOut)
.nOut(nOut) .nOut(nOut)

View File

@ -65,7 +65,7 @@ public class TestSameDiffDense extends BaseDL4JTest {
NeuralNetConfiguration conf = NeuralNetConfiguration.builder() NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
.list() .list()
.layer(new SameDiffDense.Builder().nIn(nIn).nOut(nOut).build()) .layer(SameDiffDense.builder().nIn(nIn).nOut(nOut).build())
.build(); .build();
MultiLayerNetwork net = new MultiLayerNetwork(conf); MultiLayerNetwork net = new MultiLayerNetwork(conf);
@ -106,7 +106,7 @@ public class TestSameDiffDense extends BaseDL4JTest {
.inferenceWorkspaceMode(wsm) .inferenceWorkspaceMode(wsm)
.trainingWorkspaceMode(wsm) .trainingWorkspaceMode(wsm)
.list() .list()
.layer(new SameDiffDense.Builder().nIn(nIn).nOut(nOut) .layer(SameDiffDense.builder().nIn(nIn).nOut(nOut)
.activation(a) .activation(a)
.build()) .build())
.build(); .build();
@ -178,10 +178,10 @@ public class TestSameDiffDense extends BaseDL4JTest {
NeuralNetConfiguration conf = NeuralNetConfiguration.builder() NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
.seed(12345) .seed(12345)
.list() .list()
.layer(new SameDiffDense.Builder().nIn(nIn).nOut(nOut) .layer(SameDiffDense.builder().nIn(nIn).nOut(nOut)
.weightInit(WeightInit.XAVIER) .weightInit(WeightInit.XAVIER)
.activation(a).build()) .activation(a).build())
.layer(new SameDiffDense.Builder().nIn(nOut).nOut(nOut) .layer(SameDiffDense.builder().nIn(nOut).nOut(nOut)
.weightInit(WeightInit.XAVIER) .weightInit(WeightInit.XAVIER)
.activation(a).build()) .activation(a).build())
.layer(OutputLayer.builder().nIn(nOut).nOut(nOut) .layer(OutputLayer.builder().nIn(nOut).nOut(nOut)
@ -267,7 +267,7 @@ public class TestSameDiffDense extends BaseDL4JTest {
.trainingWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE) .trainingWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE)
.inferenceWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE) .inferenceWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE)
.list() .list()
.layer(new SameDiffDense.Builder().nIn(nIn).nOut(nOut) .layer(SameDiffDense.builder().nIn(nIn).nOut(nOut)
.activation(a) .activation(a)
.build()) .build())
.layer(OutputLayer.builder().nIn(nOut).nOut(nOut).activation(Activation.SOFTMAX) .layer(OutputLayer.builder().nIn(nOut).nOut(nOut).activation(Activation.SOFTMAX)
@ -357,8 +357,8 @@ public class TestSameDiffDense extends BaseDL4JTest {
.inferenceWorkspaceMode(wsm) .inferenceWorkspaceMode(wsm)
.updater(new Adam(0.1)) .updater(new Adam(0.1))
.list() .list()
.layer(new SameDiffDense.Builder().nIn(nIn).nOut(5).activation(Activation.TANH).build()) .layer(SameDiffDense.builder().nIn(nIn).nOut(5).activation(Activation.TANH).build())
.layer(new SameDiffDense.Builder().nIn(5).nOut(5).activation(Activation.TANH).build()) .layer(SameDiffDense.builder().nIn(5).nOut(5).activation(Activation.TANH).build())
.layer(OutputLayer.builder().nIn(5).nOut(nOut).activation(Activation.SOFTMAX) .layer(OutputLayer.builder().nIn(5).nOut(nOut).activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build()) .lossFunction(LossFunctions.LossFunction.MCXENT).build())
.build(); .build();
@ -428,8 +428,8 @@ public class TestSameDiffDense extends BaseDL4JTest {
.trainingWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE) .trainingWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE)
.inferenceWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE) .inferenceWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE)
.list() .list()
.layer(new SameDiffDense.Builder().nIn(nIn).nOut(nOut).activation(a).build()) .layer(SameDiffDense.builder().nIn(nIn).nOut(nOut).activation(a).build())
.layer(new SameDiffDense.Builder().nIn(nOut).nOut(nOut).activation(a).build()) .layer(SameDiffDense.builder().nIn(nOut).nOut(nOut).activation(a).build())
.layer(OutputLayer.builder().nIn(nOut).nOut(nOut).activation(Activation.SOFTMAX) .layer(OutputLayer.builder().nIn(nOut).nOut(nOut).activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build()) .lossFunction(LossFunctions.LossFunction.MCXENT).build())
//.inputType(InputType.feedForward(nIn)) //TODO //.inputType(InputType.feedForward(nIn)) //TODO

View File

@ -60,7 +60,7 @@ public class TestSameDiffOutput extends BaseDL4JTest {
.updater(new Adam(0.01)) .updater(new Adam(0.01))
.list() .list()
.layer(DenseLayer.builder().nIn(5).nOut(5).activation(Activation.TANH).build()) .layer(DenseLayer.builder().nIn(5).nOut(5).activation(Activation.TANH).build())
.layer(LossLayer.builder().lossFunction().activation(Activation.IDENTITY).lossFunction(LossFunctions.LossFunction.MSE).build()) .layer(LossLayer.builder().activation(Activation.IDENTITY).lossFunction(LossFunctions.LossFunction.MSE.getILossFunction()).build())
.build(); .build();
MultiLayerNetwork netSD = new MultiLayerNetwork(confSD); MultiLayerNetwork netSD = new MultiLayerNetwork(confSD);
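Editor's note: the generated LossLayer builder takes an ILossFunction directly, so the enum constant is converted with getILossFunction() instead of going through the old no-argument lossFunction() shim. A sketch:

import org.deeplearning4j.nn.conf.layers.LossLayer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;

class LossLayerSketch {
    // Sketch only: the builder takes an ILossFunction, so the enum is unwrapped
    // via LossFunctions.LossFunction.MSE.getILossFunction().
    static LossLayer mseLossLayer() {
        return LossLayer.builder()
                .activation(Activation.IDENTITY)
                .lossFunction(LossFunctions.LossFunction.MSE.getILossFunction())
                .build();
    }
}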

View File

@ -21,6 +21,7 @@
package org.deeplearning4j.nn.layers.samediff.testlayers; package org.deeplearning4j.nn.layers.samediff.testlayers;
import lombok.*; import lombok.*;
import lombok.experimental.SuperBuilder;
import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@ -45,52 +46,62 @@ import java.util.*;
@Data @Data
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@JsonIgnoreProperties({"paramShapes"}) @JsonIgnoreProperties({"paramShapes"})
@NoArgsConstructor
@SuperBuilder
public class SameDiffConv extends SameDiffLayer { public class SameDiffConv extends SameDiffLayer {
public static abstract class SameDiffConvBuilder<C extends SameDiffConv, B extends SameDiffConvBuilder<C, B>> extends
SameDiffLayerBuilder<C, B> {
public B kernelSize(int... k) {
this.kernelSize$value = k;
this.kernelSize$set = true;
return self();
}
public B stride(int... s) {
this.stride$value = s;
this.stride$set = true;
return self();
}
public B padding(int... p) {
this.padding$value = p;
this.padding$set = true;
return self();
}
}
private static final List<String> WEIGHT_KEYS = Collections.singletonList(ConvolutionParamInitializer.WEIGHT_KEY); private static final List<String> WEIGHT_KEYS = Collections.singletonList(ConvolutionParamInitializer.WEIGHT_KEY);
private static final List<String> BIAS_KEYS = Collections.singletonList(ConvolutionParamInitializer.BIAS_KEY); private static final List<String> BIAS_KEYS = Collections.singletonList(ConvolutionParamInitializer.BIAS_KEY);
//Order to match 'vanilla' conv layer implementation, for easy comparison //Order to match 'vanilla' conv layer implementation, for easy comparison
private static final List<String> PARAM_KEYS = Arrays.asList(ConvolutionParamInitializer.BIAS_KEY, ConvolutionParamInitializer.WEIGHT_KEY); private static final List<String> PARAM_KEYS = Arrays.asList(ConvolutionParamInitializer.BIAS_KEY, ConvolutionParamInitializer.WEIGHT_KEY);
private long nIn;
private long nOut;
private Activation activation;
private int[] kernel;
private int[] stride;
private int[] padding;
private ConvolutionMode cm;
private int[] dilation;
private boolean hasBias;
protected SameDiffConv(Builder b) { private int nIn;
super(b); private int nOut;
this.nIn = b.nIn; @Builder.Default private Activation activation = Activation.TANH;
this.nOut = b.nOut; @Builder.Default private int[] kernelSize = new int[]{2, 2};
this.activation = b.activation;
this.kernel = b.kernel; @Builder.Default private int[] stride = new int[]{1, 1};
this.stride = b.stride; @Builder.Default private int[] padding = new int[]{0, 0};
this.padding = b.padding; @Builder.Default private int[] dilation = new int[]{1, 1};
this.cm = b.cm; @Builder.Default private ConvolutionMode convolutionMode = ConvolutionMode.Same;
this.dilation = b.dilation; @Builder.Default private boolean hasBias = true;
this.hasBias = b.hasBias;
}
private SameDiffConv(){
//No arg constructor for Jackson/JSON serialization
}
@Override @Override
public InputType getOutputType(int layerIndex, InputType inputType) { public InputType getOutputType(int layerIndex, InputType inputType) {
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType; InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
return InputTypeUtil.getOutputTypeCnnLayers(inputType, kernel, stride, padding, new int[]{1, 1}, return InputTypeUtil.getOutputTypeCnnLayers(inputType, kernelSize, stride, padding, new int[]{1, 1},
cm, nOut, layerIndex, getName(), SameDiffConv.class); convolutionMode, nOut, layerIndex, getName(), SameDiffConv.class);
} }
@Override @Override
public void setNIn(InputType inputType, boolean override) { public void setNIn(InputType inputType, boolean override) {
if (nIn <= 0 || override) { if (nIn <= 0 || override) {
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType; InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
this.nIn = c.getChannels(); this.nIn = (int) c.getChannels();
} }
} }
@ -102,7 +113,7 @@ public class SameDiffConv extends SameDiffLayer {
@Override @Override
public void defineParameters(SDLayerParams params) { public void defineParameters(SDLayerParams params) {
params.clear(); params.clear();
val weightsShape = new long[]{kernel[0], kernel[1], nIn, nOut}; //[kH, kW, iC, oC] in libnd4j val weightsShape = new long[]{kernelSize[0], kernelSize[1], nIn, nOut}; //[kH, kW, iC, oC] in libnd4j
params.addWeightParam(ConvolutionParamInitializer.WEIGHT_KEY, weightsShape); params.addWeightParam(ConvolutionParamInitializer.WEIGHT_KEY, weightsShape);
if(hasBias) { if(hasBias) {
val biasShape = new long[]{1, nOut}; val biasShape = new long[]{1, nOut};
@ -113,8 +124,8 @@ public class SameDiffConv extends SameDiffLayer {
@Override @Override
public void initializeParameters(Map<String, INDArray> params) { public void initializeParameters(Map<String, INDArray> params) {
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
double fanIn = nIn * kernel[0] * kernel[1]; double fanIn = nIn * kernelSize[0] * kernelSize[1];
double fanOut = nOut * kernel[0] * kernel[1] / ((double) stride[0] * stride[1]); double fanOut = nOut * kernelSize[0] * kernelSize[1] / ((double) stride[0] * stride[1]);
for (Map.Entry<String, INDArray> e : params.entrySet()) { for (Map.Entry<String, INDArray> e : params.entrySet()) {
if(paramWeightInit != null && paramWeightInit.containsKey(e.getKey())){ if(paramWeightInit != null && paramWeightInit.containsKey(e.getKey())){
paramWeightInit.get(e.getKey()).init(fanIn, fanOut, e.getValue().shape(), 'c', e.getValue()); paramWeightInit.get(e.getKey()).init(fanIn, fanOut, e.getValue().shape(), 'c', e.getValue());
@ -135,11 +146,11 @@ public class SameDiffConv extends SameDiffLayer {
SDVariable w = paramTable.get(ConvolutionParamInitializer.WEIGHT_KEY); SDVariable w = paramTable.get(ConvolutionParamInitializer.WEIGHT_KEY);
Conv2DConfig c = Conv2DConfig.builder() Conv2DConfig c = Conv2DConfig.builder()
.kH(kernel[0]).kW(kernel[1]) .kH(kernelSize[0]).kW(kernelSize[1])
.pH(padding[0]).pW(padding[1]) .pH(padding[0]).pW(padding[1])
.sH(stride[0]).sW(stride[1]) .sH(stride[0]).sW(stride[1])
.dH(dilation[0]).dW(dilation[1]) .dH(dilation[0]).dW(dilation[1])
.isSameMode(this.cm == ConvolutionMode.Same) .isSameMode(this.convolutionMode == ConvolutionMode.Same)
.build(); .build();
SDVariable conv = null; SDVariable conv = null;
@ -159,72 +170,10 @@ public class SameDiffConv extends SameDiffLayer {
if (activation == null) { if (activation == null) {
activation = SameDiffLayerUtils.fromIActivation(clone.getActivation()); activation = SameDiffLayerUtils.fromIActivation(clone.getActivation());
} }
if (cm == null) { if (convolutionMode == null) {
cm = clone.getConvolutionMode(); convolutionMode = clone.getConvolutionMode();
} }
} }
public static class Builder extends SameDiffLayer.Builder<Builder> {
private int nIn;
private int nOut;
private Activation activation = Activation.TANH;
private int[] kernel = new int[]{2, 2};
private int[] stride = new int[]{1, 1};
private int[] padding = new int[]{0, 0};
private int[] dilation = new int[]{1, 1};
private ConvolutionMode cm = ConvolutionMode.Same;
private boolean hasBias = true;
public Builder nIn(int nIn) {
this.nIn = nIn;
return this;
}
public Builder nOut(int nOut) {
this.nOut = nOut;
return this;
}
public Builder activation(Activation activation) {
this.activation = activation;
return this;
}
public Builder kernelSize(int... k) {
this.kernel = k;
return this;
}
public Builder stride(int... s) {
this.stride = s;
return this;
}
public Builder padding(int... p) {
this.padding = p;
return this;
}
public Builder convolutionMode(ConvolutionMode cm) {
this.cm = cm;
return this;
}
public Builder dilation(int... d) {
this.dilation = d;
return this;
}
public Builder hasBias(boolean hasBias){
this.hasBias = hasBias;
return this;
}
@Override
public SameDiffConv build() {
return new SameDiffConv(this);
}
}
} }
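Editor's note: the hand-written inner Builder of SameDiffConv is replaced by Lombok's @SuperBuilder. Defaults move onto the fields via @Builder.Default, and the only setters still written by hand are the varargs ones, which must assign both the hidden field$value and field$set members so Lombok treats the default as overridden. A reduced, hypothetical sketch of that pattern (not the full SameDiffConv, and "ExampleConv" is an invented name):

import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.experimental.SuperBuilder;

// Sketch only: a stripped-down illustration of the @SuperBuilder + @Builder.Default pattern.
@Getter
@NoArgsConstructor
@SuperBuilder
class ExampleConv {
    @lombok.Builder.Default private int[] kernelSize = new int[]{2, 2};
    @lombok.Builder.Default private int[] stride = new int[]{1, 1};

    // Custom varargs setter: assigns both the $value and $set fields that Lombok
    // generates for @Builder.Default members, mirroring SameDiffConvBuilder above.
    public static abstract class ExampleConvBuilder<C extends ExampleConv, B extends ExampleConvBuilder<C, B>> {
        public B kernelSize(int... k) {
            this.kernelSize$value = k;
            this.kernelSize$set = true;
            return self();
        }
    }
}

If the $set flag were left false, Lombok would silently replace the supplied array with the declared default at build time, which is why the manual setters touch both fields.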

View File

@ -22,6 +22,8 @@ package org.deeplearning4j.nn.layers.samediff.testlayers;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import lombok.experimental.SuperBuilder;
import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
@ -40,30 +42,22 @@ import java.util.*;
@Data @Data
@EqualsAndHashCode(callSuper = true, exclude = {"paramShapes"}) @EqualsAndHashCode(callSuper = true, exclude = {"paramShapes"})
@NoArgsConstructor()
@JsonIgnoreProperties("paramShapes") @JsonIgnoreProperties("paramShapes")
@SuperBuilder
public class SameDiffDense extends SameDiffLayer { public class SameDiffDense extends SameDiffLayer {
private static final List<String> W_KEYS = Collections.singletonList(DefaultParamInitializer.WEIGHT_KEY); private static final List<String> W_KEYS = Collections.singletonList(DefaultParamInitializer.WEIGHT_KEY);
private static final List<String> B_KEYS = Collections.singletonList(DefaultParamInitializer.BIAS_KEY); private static final List<String> B_KEYS = Collections.singletonList(DefaultParamInitializer.BIAS_KEY);
private static final List<String> PARAM_KEYS = Arrays.asList(DefaultParamInitializer.WEIGHT_KEY, DefaultParamInitializer.BIAS_KEY); private static final List<String> PARAM_KEYS = Arrays.asList(DefaultParamInitializer.WEIGHT_KEY, DefaultParamInitializer.BIAS_KEY);
private Map<String,long[]> paramShapes; private final Map<String, long[]> paramShapes = new HashMap<>();
private long nIn; private long nIn;
private long nOut; private long nOut;
private Activation activation; private Activation activation;
protected SameDiffDense(Builder builder) {
super(builder);
nIn = builder.nIn;
nOut = builder.nOut;
activation = builder.activation;
}
private SameDiffDense(){
//No op constructor for Jackson
}
@Override @Override
public InputType getOutputType(int layerIndex, InputType inputType) { public InputType getOutputType(int layerIndex, InputType inputType) {
@ -128,31 +122,5 @@ public class SameDiffDense extends SameDiffLayer {
return 'f'; return 'f';
} }
public static class Builder extends SameDiffLayer.Builder<Builder> {
private int nIn;
private int nOut;
private Activation activation;
public Builder nIn(int nIn){
this.nIn = nIn;
return this;
}
public Builder nOut(int nOut){
this.nOut = nOut;
return this;
}
public Builder activation(Activation activation){
this.activation = activation;
return this;
}
@Override
public SameDiffDense build() {
return new SameDiffDense(this);
}
}
} }

View File

@ -58,7 +58,7 @@ public class TestVAE extends BaseDL4JTest {
NeuralNetConfiguration mlc = NeuralNetConfiguration mlc =
NeuralNetConfiguration.builder() NeuralNetConfiguration.builder()
.layer(0, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder() .layer(0, org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.builder()
.nIn(10).nOut(5).encoderLayerSizes(12).decoderLayerSizes(13) .nIn(10).nOut(5).encoderLayerSizes(12).decoderLayerSizes(13)
.build()) .build())
.build(); .build();
@ -95,7 +95,7 @@ public class TestVAE extends BaseDL4JTest {
for (int i = 0; i < encLayerSizes.length; i++) { for (int i = 0; i < encLayerSizes.length; i++) {
NeuralNetConfiguration mlc = NeuralNetConfiguration.builder().list().layer(0, NeuralNetConfiguration mlc = NeuralNetConfiguration.builder().list().layer(0,
new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder().nIn(10) org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.builder().nIn(10)
.nOut(5).encoderLayerSizes(encLayerSizes[i]).decoderLayerSizes(13).build()) .nOut(5).encoderLayerSizes(encLayerSizes[i]).decoderLayerSizes(13).build())
.build(); .build();
@ -121,7 +121,7 @@ public class TestVAE extends BaseDL4JTest {
int inputSize = 3; int inputSize = 3;
NeuralNetConfiguration mlc = NeuralNetConfiguration.builder().list() NeuralNetConfiguration mlc = NeuralNetConfiguration.builder().list()
.layer(0, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder() .layer(0, org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.builder()
.nIn(inputSize).nOut(4).encoderLayerSizes(5).decoderLayerSizes(6).build()) .nIn(inputSize).nOut(4).encoderLayerSizes(5).decoderLayerSizes(6).build())
.build(); .build();
@ -159,7 +159,7 @@ public class TestVAE extends BaseDL4JTest {
public void testParamGradientOrderAndViews() { public void testParamGradientOrderAndViews() {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
NeuralNetConfiguration mlc = NeuralNetConfiguration.builder().list() NeuralNetConfiguration mlc = NeuralNetConfiguration.builder().list()
.layer(0, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder() .layer(0, org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.builder()
.nIn(10).nOut(5).encoderLayerSizes(12, 13).decoderLayerSizes(14, 15).build()) .nIn(10).nOut(5).encoderLayerSizes(12, 13).decoderLayerSizes(14, 15).build())
.build(); .build();
@ -217,7 +217,7 @@ public class TestVAE extends BaseDL4JTest {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
NeuralNetConfiguration mlc = NeuralNetConfiguration.builder().seed(12345).list() NeuralNetConfiguration mlc = NeuralNetConfiguration.builder().seed(12345).list()
.layer(0, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder() .layer(0, org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.builder()
.nIn(10).nOut(5).encoderLayerSizes(12, 13).decoderLayerSizes(14, 15).build()) .nIn(10).nOut(5).encoderLayerSizes(12, 13).decoderLayerSizes(14, 15).build())
.layer(1, OutputLayer.builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(5).nOut(6) .layer(1, OutputLayer.builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(5).nOut(6)
.activation(new ActivationTanH()).build()) .activation(new ActivationTanH()).build())
@ -269,22 +269,22 @@ public class TestVAE extends BaseDL4JTest {
public void testJsonYaml() { public void testJsonYaml() {
NeuralNetConfiguration config = NeuralNetConfiguration.builder().seed(12345).list() NeuralNetConfiguration config = NeuralNetConfiguration.builder().seed(12345).list()
.layer(0, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder() .layer(0, org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.builder()
.reconstructionDistribution(new GaussianReconstructionDistribution(Activation.IDENTITY)) .reconstructionDistribution(new GaussianReconstructionDistribution(Activation.IDENTITY))
.nIn(3).nOut(4).encoderLayerSizes(5).decoderLayerSizes(6).build()) .nIn(3).nOut(4).encoderLayerSizes(5).decoderLayerSizes(6).build())
.layer(1, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder() .layer(1, org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.builder()
.reconstructionDistribution(new GaussianReconstructionDistribution(Activation.TANH)) .reconstructionDistribution(new GaussianReconstructionDistribution(Activation.TANH))
.nIn(7).nOut(8).encoderLayerSizes(9).decoderLayerSizes(10).build()) .nIn(7).nOut(8).encoderLayerSizes(9).decoderLayerSizes(10).build())
.layer(2, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder() .layer(2, org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.builder()
.reconstructionDistribution(new BernoulliReconstructionDistribution()).nIn(11) .reconstructionDistribution(new BernoulliReconstructionDistribution()).nIn(11)
.nOut(12).encoderLayerSizes(13).decoderLayerSizes(14).build()) .nOut(12).encoderLayerSizes(13).decoderLayerSizes(14).build())
.layer(3, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder() .layer(3, org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.builder()
.reconstructionDistribution(new ExponentialReconstructionDistribution(Activation.TANH)) .reconstructionDistribution(new ExponentialReconstructionDistribution(Activation.TANH))
.nIn(11).nOut(12).encoderLayerSizes(13).decoderLayerSizes(14).build()) .nIn(11).nOut(12).encoderLayerSizes(13).decoderLayerSizes(14).build())
.layer(4, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder() .layer(4, org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.builder()
.lossFunction(new ActivationTanH(), LossFunctions.LossFunction.MSE).nIn(11) .lossFunction(new ActivationTanH(), LossFunctions.LossFunction.MSE).nIn(11)
.nOut(12).encoderLayerSizes(13).decoderLayerSizes(14).build()) .nOut(12).encoderLayerSizes(13).decoderLayerSizes(14).build())
.layer(5, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder() .layer(5, org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.builder()
.reconstructionDistribution(new CompositeReconstructionDistribution.Builder() .reconstructionDistribution(new CompositeReconstructionDistribution.Builder()
.addDistribution(5, new GaussianReconstructionDistribution()) .addDistribution(5, new GaussianReconstructionDistribution())
.addDistribution(5, .addDistribution(5,

View File

@ -59,7 +59,7 @@ public class TestMemoryReports extends BaseDL4JTest {
l.add(new Pair<>(DropoutLayer.builder().nIn(20).nOut(20).build(), InputType.feedForward(20))); l.add(new Pair<>(DropoutLayer.builder().nIn(20).nOut(20).build(), InputType.feedForward(20)));
l.add(new Pair<>(EmbeddingLayer.builder().nIn(1).nOut(20).build(), InputType.feedForward(20))); l.add(new Pair<>(EmbeddingLayer.builder().nIn(1).nOut(20).build(), InputType.feedForward(20)));
l.add(new Pair<>(OutputLayer.builder().nIn(20).nOut(20).build(), InputType.feedForward(20))); l.add(new Pair<>(OutputLayer.builder().nIn(20).nOut(20).build(), InputType.feedForward(20)));
l.add(new Pair<>(LossLayer.builder().lossFunction().build(), InputType.feedForward(20))); l.add(new Pair<>(LossLayer.builder().build(), InputType.feedForward(20)));
//RNN layers: //RNN layers:
l.add(new Pair<>(GravesLSTM.builder().nIn(20).nOut(20).build(), InputType.recurrent(20, 30))); l.add(new Pair<>(GravesLSTM.builder().nIn(20).nOut(20).build(), InputType.recurrent(20, 30)));

View File

@ -469,7 +469,7 @@ public class WorkspaceTests extends BaseDL4JTest {
.addLayer("a", GravesLSTM.builder().nOut(300).activation(Activation.HARDTANH).build(), "embeddings") .addLayer("a", GravesLSTM.builder().nOut(300).activation(Activation.HARDTANH).build(), "embeddings")
.addVertex("b", new LastTimeStepVertex("in"), "a") .addVertex("b", new LastTimeStepVertex("in"), "a")
.addLayer("c", DenseLayer.builder().nOut(300).activation(Activation.HARDTANH).build(), "b") .addLayer("c", DenseLayer.builder().nOut(300).activation(Activation.HARDTANH).build(), "b")
.addLayer("output", LossLayer.builder().lossFunction().lossFunction(LossFunctions.LossFunction.COSINE_PROXIMITY).build(), "c") .addLayer("output", LossLayer.builder().lossFunction(LossFunctions.LossFunction.COSINE_PROXIMITY.getILossFunction()).build(), "c")
.setOutputs("output") .setOutputs("output")
.build(); .build();

View File

@ -1455,10 +1455,10 @@ public class MultiLayerTest extends BaseDL4JTest {
NeuralNetConfiguration conf = NeuralNetConfiguration.builder() NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
.l2(0.01) .l2(0.01)
.list()
.layer(ConvolutionLayer.builder().nIn(depth).nOut(depth).kernelSize(1, 1).build()) .layer(ConvolutionLayer.builder().nIn(depth).nOut(depth).kernelSize(1, 1).build())
.layer(new Yolo2OutputLayer.Builder() .layer(Yolo2OutputLayer.builder()
.boundingBoxPriors(bbPrior) .boundingBoxes(bbPrior)
.build()) .build())
.build(); .build();

View File

@ -500,10 +500,10 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest {
.addInputs(inputName) .addInputs(inputName)
.setOutputs(outputName) .setOutputs(outputName)
.setInputTypes(InputType.inferInputTypes(input)) .setInputTypes(InputType.inferInputTypes(input))
.addLayer(firstConv, new Convolution2D.Builder(3, 3) .addLayer(firstConv, Convolution2D.builder(3, 3)
.nOut(10) .nOut(10)
.build(), inputName) .build(), inputName)
.addLayer(secondConv, new Convolution2D.Builder(1, 1) .addLayer(secondConv, Convolution2D.builder(1, 1)
.nOut(3) .nOut(3)
.build(), firstConv) .build(), firstConv)
.addLayer(outputName, OutputLayer.builder() .addLayer(outputName, OutputLayer.builder()
@ -546,11 +546,11 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest {
.addInputs(inputName) .addInputs(inputName)
.setOutputs(outputName) .setOutputs(outputName)
.setInputTypes(InputType.inferInputTypes(input)) .setInputTypes(InputType.inferInputTypes(input))
.addLayer(changeNoutName, new Convolution2D.Builder(1, 1) .addLayer(changeNoutName, Convolution2D.builder(1, 1)
.nOut(10) .nOut(10)
.build(), inputName) .build(), inputName)
.addLayer(poolName, SubsamplingLayer.builder(1,1).build(), changeNoutName) .addLayer(poolName, SubsamplingLayer.builder(1,1).build(), changeNoutName)
.addLayer(afterPoolName, new Convolution2D.Builder(1, 1) .addLayer(afterPoolName, Convolution2D.builder(1, 1)
.nOut(7) .nOut(7)
.build(), poolName) .build(), poolName)
.addLayer(outputName, OutputLayer.builder() .addLayer(outputName, OutputLayer.builder()
@ -583,7 +583,7 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest {
.graphBuilder() .graphBuilder()
.addInputs("in") .addInputs("in")
.layer("l0", LSTM.builder().nIn(5).nOut(5).build(), "in") .layer("l0", LSTM.builder().nIn(5).nOut(5).build(), "in")
.layer("l1", new RecurrentAttentionLayer.Builder().nHeads(1).headSize(5).nIn(5).nOut(5).build(), "l0") .layer("l1", RecurrentAttentionLayer.builder().nHeads(1).headSize(5).nIn(5).nOut(5).build(), "l0")
.layer("out", RnnOutputLayer.builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build(), "l1") .layer("out", RnnOutputLayer.builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build(), "l1")
.setOutputs("out") .setOutputs("out")
.build(); .build();

View File

@ -28,7 +28,6 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.MergeVertex; import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.graph.SubsetVertex; import org.deeplearning4j.nn.conf.graph.SubsetVertex;
import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer.Builder;
import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
@ -214,9 +213,9 @@ public class TransferLearningHelperTest extends BaseDL4JTest {
MultiLayerNetwork modelToFineTune = new MultiLayerNetwork( MultiLayerNetwork modelToFineTune = new MultiLayerNetwork(
(NeuralNetConfiguration) overallConf.clone() (NeuralNetConfiguration) overallConf.clone()
.layer(0, new Builder().nIn(4).nOut(3).build()) .layer(0, DenseLayer.builder().nIn(4).nOut(3).build())
.layer(1, new Builder().nIn(3).nOut(2).build()) .layer(1, DenseLayer.builder().nIn(3).nOut(2).build())
.layer(2, new Builder().nIn(2).nOut(3).build()) .layer(2, DenseLayer.builder().nIn(2).nOut(3).build())
.layer(3, OutputLayer.builder( .layer(3, OutputLayer.builder(
LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3) LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3)
.build()) .build())
@ -233,7 +232,7 @@ public class TransferLearningHelperTest extends BaseDL4JTest {
Nd4j.hstack(modelToFineTune.getLayer(2).getParams(), modelToFineTune.getLayer(3).getParams()); Nd4j.hstack(modelToFineTune.getLayer(2).getParams(), modelToFineTune.getLayer(3).getParams());
MultiLayerNetwork notFrozen = new MultiLayerNetwork( MultiLayerNetwork notFrozen = new MultiLayerNetwork(
(NeuralNetConfiguration) overallConf.clone().list() (NeuralNetConfiguration) overallConf.clone().list()
.layer(0, new Builder().nIn(2).nOut(3).build()) .layer(0, DenseLayer.builder().nIn(2).nOut(3).build())
.layer(1, OutputLayer.builder( .layer(1, OutputLayer.builder(
LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3) LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3)
.build()) .build())

View File

@ -32,7 +32,7 @@ import org.deeplearning4j.nn.conf.distribution.ConstantDistribution;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.layers.DenseLayer.Builder;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor; import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor; import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.RnnToCnnPreProcessor; import org.deeplearning4j.nn.conf.preprocessor.RnnToCnnPreProcessor;
@ -74,7 +74,7 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
MultiLayerNetwork modelToFineTune = new MultiLayerNetwork( MultiLayerNetwork modelToFineTune = new MultiLayerNetwork(
(NeuralNetConfiguration) confToChange.list() (NeuralNetConfiguration) confToChange.list()
.layer(0, new Builder().nIn(4).nOut(3).build()) .layer(0, DenseLayer.builder().nIn(4).nOut(3).build())
.layer(1, OutputLayer.builder( .layer(1, OutputLayer.builder(
LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3) LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3)
.build()) .build())
@ -101,7 +101,7 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
.updater(new RmsProp(0.5)).l2(0.4); .updater(new RmsProp(0.5)).l2(0.4);
MultiLayerNetwork expectedModel = new MultiLayerNetwork((NeuralNetConfiguration) confSet.list() MultiLayerNetwork expectedModel = new MultiLayerNetwork((NeuralNetConfiguration) confSet.list()
.layer(0, new Builder().nIn(4).nOut(3).build()) .layer(0, DenseLayer.builder().nIn(4).nOut(3).build())
.layer(1, OutputLayer.builder( .layer(1, OutputLayer.builder(
LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3) LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3)
.build()) .build())
@ -651,8 +651,8 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
.weightInit(new ConstantDistribution(666)) .weightInit(new ConstantDistribution(666))
.list() .list()
.inputType(InputType.inferInputTypes(input)[0]) .inputType(InputType.inferInputTypes(input)[0])
.layer(new Convolution2D.Builder(3, 3).nOut(10).build()) .layer(Convolution2D.builder(3, 3).nOut(10).build())
.layer(new Convolution2D.Builder(1, 1).nOut(3).build()) .layer(Convolution2D.builder(1, 1).nOut(3).build())
.layer(OutputLayer.builder().nOut(2).lossFunction(LossFunctions.LossFunction.MSE) .layer(OutputLayer.builder().nOut(2).lossFunction(LossFunctions.LossFunction.MSE)
.build()).build()); .build()).build());
net.init(); net.init();
@ -682,9 +682,9 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
MultiLayerNetwork net = new MultiLayerNetwork( NeuralNetConfiguration.builder() MultiLayerNetwork net = new MultiLayerNetwork( NeuralNetConfiguration.builder()
.list() .list()
.inputType(InputType.inferInputTypes(input)[0]) .inputType(InputType.inferInputTypes(input)[0])
.layer(new Convolution2D.Builder(1, 1).nOut(10).build()) .layer(Convolution2D.builder(1, 1).nOut(10).build())
.layer(SubsamplingLayer.builder(1,1).build()) .layer(SubsamplingLayer.builder(1,1).build())
.layer(new Convolution2D.Builder(1, 1).nOut(7).build()) .layer(Convolution2D.builder(1, 1).nOut(7).build())
.layer(OutputLayer.builder().activation(Activation.SOFTMAX).nOut(2).build()) .layer(OutputLayer.builder().activation(Activation.SOFTMAX).nOut(2).build())
.build()); .build());
net.init(); net.init();
@ -712,7 +712,7 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
.weightInit(WeightInit.XAVIER) .weightInit(WeightInit.XAVIER)
.list() .list()
.layer(LSTM.builder().nOut(8).build()) .layer(LSTM.builder().nOut(8).build())
.layer( new SelfAttentionLayer.Builder().nOut(4).nHeads(2).projectInput(true).build()) .layer( SelfAttentionLayer.builder().nOut(4).nHeads(2).projectInput(true).build())
.layer(GlobalPoolingLayer.builder().poolingType(PoolingType.MAX).build()) .layer(GlobalPoolingLayer.builder().poolingType(PoolingType.MAX).build())
.layer(OutputLayer.builder().nOut(2).activation(Activation.SOFTMAX) .layer(OutputLayer.builder().nOut(2).activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build()) .lossFunction(LossFunctions.LossFunction.MCXENT).build())

View File

@ -52,7 +52,7 @@ public class WeightInitIdentityTest extends BaseDL4JTest {
.graphBuilder() .graphBuilder()
.addInputs(inputName) .addInputs(inputName)
.setOutputs(output) .setOutputs(output)
.layer(conv, new Convolution1DLayer.Builder(7) .layer(conv, Convolution1DLayer.builder(7)
.convolutionMode(ConvolutionMode.Same) .convolutionMode(ConvolutionMode.Same)
.nOut(input.size(1)) .nOut(input.size(1))
.weightInit(new WeightInitIdentity()) .weightInit(new WeightInitIdentity())
@ -115,7 +115,7 @@ public class WeightInitIdentityTest extends BaseDL4JTest {
.weightInit(new WeightInitIdentity()) .weightInit(new WeightInitIdentity())
.activation(new ActivationIdentity()) .activation(new ActivationIdentity())
.build(), inputName) .build(), inputName)
.layer(output, new Cnn3DLossLayer.Builder(Convolution3D.DataFormat.NCDHW).activation(new ActivationIdentity()).build(), conv) .layer(output, Cnn3DLossLayer.builder().dataFormat(Convolution3D.DataFormat.NCDHW).activation(new ActivationIdentity()).build(), conv)
.build()); .build());
graph.init(); graph.init();

View File

@ -249,7 +249,7 @@ public class ModelGuesserTest extends BaseDL4JTest {
int nOut = 6; int nOut = 6;
NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345).l1(0.01).l2(0.01) NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345).l1(0.01).l2(0.01)
.updater(new Sgd(0.1)).activation(Activation.TANH).weightInit(WeightInit.XAVIER).list() .updater(new Sgd(0.1)).activation(Activation.TANH).weightInit(WeightInit.XAVIER)
.layer(0, DenseLayer.builder().nIn(nIn).nOut(20).build()) .layer(0, DenseLayer.builder().nIn(nIn).nOut(20).build())
.layer(1, DenseLayer.builder().nIn(20).nOut(30).build()).layer(2, OutputLayer.builder() .layer(1, DenseLayer.builder().nIn(20).nOut(30).build()).layer(2, OutputLayer.builder()
.lossFunction(LossFunctions.LossFunction.MSE).nIn(30).nOut(nOut).build()) .lossFunction(LossFunctions.LossFunction.MSE).nIn(30).nOut(nOut).build())

View File

@ -56,10 +56,10 @@ public class KerasSpaceToDepth extends KerasLayer {
// TODO: we hard-code block size here to import YOLO9000. This size is not available as property // TODO: we hard-code block size here to import YOLO9000. This size is not available as property
// in the hdf5 file outside of the serialized lambda function (that we can't really well deserialize). // in the hdf5 file outside of the serialized lambda function (that we can't really well deserialize).
SpaceToDepthLayer.Builder builder = new SpaceToDepthLayer.Builder() var builder = SpaceToDepthLayer.builder()
.blocks(2) .blockSize(2)
//the default data format is tensorflow/NWHC for keras import //the default data format is tensorflow/NWHC for keras import
.dataFormat(SpaceToDepthLayer.DataFormat.NHWC) .dataFormat(SpaceToDepthLayer.DataFormat.NHWC.toFormat())
.name(name); .name(name);
this.layer = builder.build(); this.layer = builder.build();
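Editor's note: the Keras importers use the same builder entry points, here with blocks renamed to blockSize and the layer-local DataFormat enum mapped through toFormat(). A sketch of the assembled config, with the block size hard-coded to 2 as in the importer above:

import org.deeplearning4j.nn.conf.layers.SpaceToDepthLayer;

class SpaceToDepthSketch {
    // Sketch only: blocks -> blockSize, and the NHWC enum is converted with toFormat().
    static SpaceToDepthLayer yoloSpaceToDepth(String name) {
        return SpaceToDepthLayer.builder()
                .blockSize(2)
                .dataFormat(SpaceToDepthLayer.DataFormat.NHWC.toFormat())
                .name(name)
                .build();
    }
}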

View File

@ -63,7 +63,7 @@ public class KerasUpsampling3D extends KerasLayer {
int[] size = KerasConvolutionUtils.getUpsamplingSizeFromConfig(layerConfig, 3, conf); int[] size = KerasConvolutionUtils.getUpsamplingSizeFromConfig(layerConfig, 3, conf);
// TODO: make sure to allow different sizes. // TODO: make sure to allow different sizes.
Upsampling3D.Builder builder = new Upsampling3D.Builder() var builder = Upsampling3D.builder()
.name(this.name) .name(this.name)
.dropOut(this.dropout) .dropOut(this.dropout)
.size(size[0]); .size(size[0]);

View File

@ -59,7 +59,7 @@ public class KerasLRN extends KerasLayer {
super(layerConfig, enforceTrainingConfig); super(layerConfig, enforceTrainingConfig);
Map<String, Object> lrnParams = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf); Map<String, Object> lrnParams = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
LocalResponseNormalization.Builder builder = LocalResponseNormalization.builder().name(this.name) var builder = LocalResponseNormalization.builder().name(this.name)
.dropOut(this.dropout).alpha((double) lrnParams.get("alpha")) .dropOut(this.dropout).alpha((double) lrnParams.get("alpha"))
.beta((double) lrnParams.get("beta")).k((int) lrnParams.get("k")).n((int) lrnParams.get("n")); .beta((double) lrnParams.get("beta")).k((int) lrnParams.get("k")).n((int) lrnParams.get("n"));
this.layer = builder.build(); this.layer = builder.build();

View File

@ -33,6 +33,7 @@ import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasConvolu
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils; import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.params.ConvolutionParamInitializer; import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.HashMap; import java.util.HashMap;
@ -98,7 +99,7 @@ public class KerasLocallyConnected1D extends KerasConvolution {
LocallyConnected1D.LocallyConnected1DBuilder builder = LocallyConnected1D.builder().name(this.name) LocallyConnected1D.LocallyConnected1DBuilder builder = LocallyConnected1D.builder().name(this.name)
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout) .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
.activation(getActivationFromConfig(layerConfig, conf)) .activation(getActivationFromConfig(layerConfig, conf))
.weightInit(conf.getKERAS_PARAM_NAME_W(), init) .weightInit(WeightInit.valueOf(conf.getKERAS_PARAM_NAME_W()))
.l1(this.weightL1Regularization).l2(this.weightL2Regularization) .l1(this.weightL1Regularization).l2(this.weightL2Regularization)
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf)) .convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
.kernelSize(getKernelSizeFromConfig(layerConfig, 1, conf, kerasMajorVersion)[0]) .kernelSize(getKernelSizeFromConfig(layerConfig, 1, conf, kerasMajorVersion)[0])

View File

@ -99,7 +99,7 @@ public class KerasLocallyConnected2D extends KerasConvolution {
LocallyConnected2D.LocallyConnected2DBuilder builder = LocallyConnected2D.builder().name(this.name) LocallyConnected2D.LocallyConnected2DBuilder builder = LocallyConnected2D.builder().name(this.name)
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout) .nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
.activation(getActivationFromConfig(layerConfig, conf)) .activation(getActivationFromConfig(layerConfig, conf))
.weightInit(conf.getKERAS_PARAM_NAME_W(), init) .weightInit(init.enumValue())
.l1(this.weightL1Regularization).l2(this.weightL2Regularization) .l1(this.weightL1Regularization).l2(this.weightL2Regularization)
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf)) .convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
.kernelSize(getKernelSizeFromConfig(layerConfig, 2, conf, kerasMajorVersion)) .kernelSize(getKernelSizeFromConfig(layerConfig, 2, conf, kerasMajorVersion))

View File

@ -130,12 +130,14 @@ public class KerasBatchNormalization extends KerasLayer {
BatchNormalization.BatchNormalizationBuilder builder =BatchNormalization.builder() BatchNormalization.BatchNormalizationBuilder builder =BatchNormalization.builder()
.name(this.name) .name(this.name)
.dropOut(this.dropout) .dropOut(this.dropout)
.minibatch(true)
.isMinibatch(true)
.lockGammaBeta(false) .lockGammaBeta(false)
.useLogStd(false) .useLogStd(false)
.decay(getMomentumFromConfig(layerConfig)) .decay(getMomentumFromConfig(layerConfig))
.eps(getEpsFromConfig(layerConfig)); .eps(getEpsFromConfig(layerConfig));
if (betaConstraint != null) if (betaConstraint != null)
builder.constrainBeta(betaConstraint); builder.constrainBeta(betaConstraint);
if (gammaConstraint != null) if (gammaConstraint != null)
builder.constrainGamma(gammaConstraint); builder.constrainGamma(gammaConstraint);

View File

@ -58,11 +58,11 @@ public class KerasModelImportTest extends BaseDL4JTest {
MultiLayerNetwork model = loadModel("modelimport/keras/weights/conv2dnchw/simpleconv2d.hdf5"); MultiLayerNetwork model = loadModel("modelimport/keras/weights/conv2dnchw/simpleconv2d.hdf5");
List<LayerConfiguration> layerConfigs = model.getNetConfiguration().getFlattenedLayerConfigurations(); List<LayerConfiguration> layerConfigs = model.getNetConfiguration().getFlattenedLayerConfigurations();
ConvolutionLayer convolutionLayer = (ConvolutionLayer) layerConfigs.get(0); ConvolutionLayer convolutionLayer = (ConvolutionLayer) layerConfigs.get(0);
assertEquals(CNN2DFormat.NCHW,convolutionLayer.getDataFormat()); assertEquals(CNN2DFormat.NCHW,convolutionLayer.getConvFormat());
SubsamplingLayer subsamplingLayer = (SubsamplingLayer) layerConfigs.get(1); SubsamplingLayer subsamplingLayer = (SubsamplingLayer) layerConfigs.get(1);
assertEquals(CNN2DFormat.NHWC,subsamplingLayer.getDataFormat()); assertEquals(CNN2DFormat.NHWC,subsamplingLayer.getDataFormat());
ConvolutionLayer convolutionLayer1 = (ConvolutionLayer) layerConfigs.get(2); ConvolutionLayer convolutionLayer1 = (ConvolutionLayer) layerConfigs.get(2);
assertEquals(CNN2DFormat.NHWC,convolutionLayer1.getDataFormat()); assertEquals(CNN2DFormat.NHWC,convolutionLayer1.getConvFormat());
model.output(Nd4j.zeros(1,1,28,28)); model.output(Nd4j.zeros(1,1,28,28));
assertNotNull(model); assertNotNull(model);

View File

@ -60,8 +60,8 @@ public class KerasYolo9000PredictTest extends BaseDL4JTest {
ComputationGraph model = new TransferLearning.GraphBuilder(graph) ComputationGraph model = new TransferLearning.GraphBuilder(graph)
.addLayer("outputs", .addLayer("outputs",
new org.deeplearning4j.nn.conf.layers.objdetect.Yolo2OutputLayer.Builder() org.deeplearning4j.nn.conf.layers.objdetect.Yolo2OutputLayer.builder()
.boundingBoxPriors(priors) .boundingBoxes(priors)
.build(), .build(),
"conv2d_23") "conv2d_23")
.setOutputs("outputs") .setOutputs("outputs")

View File

@ -126,7 +126,7 @@ public class KerasLocallyConnected1DTest extends BaseDL4JTest {
assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0); assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0);
assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0); assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0);
assertEquals(new Dropout(DROPOUT_DL4J), layer.getDropOut()); assertEquals(new Dropout(DROPOUT_DL4J), layer.getDropOut());
assertEquals(KERNEL_SIZE, layer.getKernel()); assertEquals(KERNEL_SIZE, layer.getKernelSize());
assertEquals(STRIDE, layer.getStride()); assertEquals(STRIDE, layer.getStride());
assertEquals(N_OUT, layer.getNOut()); assertEquals(N_OUT, layer.getNOut());
assertEquals(ConvolutionMode.Truncate, layer.getConvolutionMode()); assertEquals(ConvolutionMode.Truncate, layer.getConvolutionMode());

View File

@ -21,6 +21,7 @@
package net.brutex.ai.dnn.api; package net.brutex.ai.dnn.api;
public interface ILayerConfiguration { public interface ILayerConfiguration {

View File

@ -561,13 +561,12 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
List<Object> innerConfigurations$value = new ArrayList<>(); // initialize with an empty list List<Object> innerConfigurations$value = new ArrayList<>(); // initialize with an empty list
public B activation(IActivation activation) { public B activation(Activation activation) {
this.activation = activation; this.activation = activation;
return self(); return self();
} }
public B activation(IActivation activation) {
public B activation(Activation activation) { this.activation = activation;
this.activation = activation.getActivationFunction();
return self(); return self();
} }
/** /**

View File

@ -157,9 +157,9 @@ public class AttentionVertex extends SameDiffVertex {
val Wv = paramTable.get(WEIGHT_KEY_VALUE_PROJECTION); val Wv = paramTable.get(WEIGHT_KEY_VALUE_PROJECTION);
val Wo = paramTable.get(WEIGHT_KEY_OUT_PROJECTION); val Wo = paramTable.get(WEIGHT_KEY_OUT_PROJECTION);
attention = sameDiff.nn.multiHeadDotProductAttention(getLayerName(), queries, keys, values, Wq, Wk, Wv, Wo, mask, true); attention = sameDiff.nn.multiHeadDotProductAttention(getName(), queries, keys, values, Wq, Wk, Wv, Wo, mask, true);
}else{ }else{
attention = sameDiff.nn.dotProductAttention(getLayerName(), queries, keys, values, mask, true); attention = sameDiff.nn.dotProductAttention(getName(), queries, keys, values, mask, true);
} }
if(maskVars != null){ if(maskVars != null){

View File

@ -53,11 +53,11 @@ public class ActivationLayer extends NoParamLayer {
public static ActivationLayerBuilder<?, ?> builder(Activation activation) { public static ActivationLayerBuilder<?, ?> builder(Activation activation) {
return innerBuilder().activation(activation); return innerBuilder().activation(activation);
} }
public static ActivationLayerBuilder<?, ?> builder(IActivation activation) { public static ActivationLayerBuilder<?, ?> builder(IActivation activation) {
return innerBuilder().activation(activation); return innerBuilder().activation(activation);
} }
public static ActivationLayerBuilder<?, ?> builder() { public static ActivationLayerBuilder<?, ?> builder() {
return innerBuilder(); return innerBuilder();
} }

View File

@ -25,6 +25,7 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import lombok.*; import lombok.*;
import lombok.experimental.SuperBuilder; import lombok.experimental.SuperBuilder;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import net.brutex.ai.dnn.api.LayerType; import net.brutex.ai.dnn.api.LayerType;
import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.api.ParamInitializer;
@ -80,19 +81,38 @@ public class BatchNormalization extends FeedForwardLayer {
@lombok.Builder.Default protected boolean isMinibatch = true; @lombok.Builder.Default protected boolean isMinibatch = true;
/** /**
* Used only when 'true' is passed to {@link #lockGammaBeta(boolean)}. Value is not used otherwise.<br> Default: * Used only when 'true' is passed to {@link BatchNormalizationBuilder#lockGammaBeta(boolean)}. Value is not used otherwise.<br> Default:
* 1.0 * 1.0
* *
* @param gamma Gamma parameter for all activations, used only with locked gamma/beta configuration mode * @param gamma Gamma parameter for all activations, used only with locked gamma/beta configuration mode
*/ */
@lombok.Builder.Default protected double gamma = 1.0; @lombok.Builder.Default protected double gamma = 1.0;
/** /**
* Used only when 'true' is passed to {@link #lockGammaBeta(boolean)}. Value is not used otherwise.<br> Default: * Used only when 'true' is passed to {@link BatchNormalizationBuilder#lockGammaBeta(boolean)}. Value is not used otherwise.<br> Default:
* 0.0 * 0.0
* *
* @param beta Beta parameter for all activations, used only with locked gamma/beta configuration mode * @param beta Beta parameter for all activations, used only with locked gamma/beta configuration mode
*/ */
@lombok.Builder.Default protected double beta = 0.0; @lombok.Builder.Default protected double beta = 0.0;
/**
* Set constraints to be applied to the beta parameter of this batch normalisation layer. Default: no
* constraints.<br> Constraints can be used to enforce certain conditions (non-negativity of parameters,
* max-norm regularization, etc). These constraints are applied at each iteration, after the parameters have
* been updated.
*
*/
protected List<LayerConstraint> betaConstraints;
/**
* Set constraints to be applied to the gamma parameter of this batch normalisation layer. Default: no
* constraints.<br> Constraints can be used to enforce certain conditions (non-negativity of parameters,
* max-norm regularization, etc). These constraints are applied at each iteration, after the parameters have
* been updated.
*
*/
protected List<LayerConstraint> gammaConstraints;
/** /**
* When using CuDNN or MKLDNN and an error is encountered, should fallback to the non-helper implementation be allowed? * When using CuDNN or MKLDNN and an error is encountered, should fallback to the non-helper implementation be allowed?
* If set to false, an exception in the helper will be propagated back to the user. If true, the built-in * If set to false, an exception in the helper will be propagated back to the user. If true, the built-in
@ -298,6 +318,15 @@ public class BatchNormalization extends FeedForwardLayer {
this.cudnnAllowFallback$set = true; this.cudnnAllowFallback$set = true;
return self(); return self();
} }
public B constrainBeta(LayerConstraint ... constraints) {
this.betaConstraints = List.of(constraints);
return self();
}
public B constrainGamma(LayerConstraint ... constraints) {
this.gammaConstraints = List.of(constraints);
return self();
}
} }
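The constrainBeta(...) and constrainGamma(...) methods added above replace the old setter-style configuration on the removed Builder. A hedged sketch of how a caller might use them, assuming BatchNormalization exposes a public builder() entry point like the other layers in this change (MaxNormConstraint is an existing DL4J constraint, shown here only as an example):

import org.deeplearning4j.nn.conf.constraint.MaxNormConstraint;
import org.deeplearning4j.nn.conf.layers.BatchNormalization;

BatchNormalization bn = BatchNormalization.builder()
    .nOut(64)
    // constraints are applied after each parameter update, per the javadoc above
    .constrainBeta(new MaxNormConstraint(2.0, 1))
    .constrainGamma(new MaxNormConstraint(2.0, 1))
    .build();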

View File

@ -51,6 +51,16 @@ import org.nd4j.linalg.api.ndarray.INDArray;
@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder") @SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder")
public class Convolution1DLayer extends ConvolutionLayer { public class Convolution1DLayer extends ConvolutionLayer {
@Builder.Default private RNNFormat rnnDataFormat = RNNFormat.NCW; @Builder.Default private RNNFormat rnnDataFormat = RNNFormat.NCW;
/**
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
* See {@link CNN2DFormat} for more details.<br>
* Default: NCHW
*
* @param format Format for activations (in and out)
*/
@Builder.Default
protected CNN2DFormat dataFormat =
CNN2DFormat.NCHW; // default value for legacy serialization reasons
/** /**
* Size of the convolution * Size of the convolution
* *

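With the new dataFormat field and the existing rnnDataFormat default, a 1D convolution can be configured entirely through generated setters. A tentative sketch, assuming the class also exposes a public builder() wrapper around innerBuilder() and inherits nIn/nOut from the FeedForwardLayer builder:

import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;

// rnnDataFormat defaults to NCW; NWC can be requested explicitly
Convolution1DLayer conv1d = Convolution1DLayer.builder()
    .nIn(32)
    .nOut(64)
    .rnnDataFormat(RNNFormat.NWC)
    .build();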
View File

@ -26,10 +26,7 @@ import lombok.*;
import lombok.experimental.SuperBuilder; import lombok.experimental.SuperBuilder;
import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.layers.convolution.Convolution3DLayer; import org.deeplearning4j.nn.layers.convolution.Convolution3DLayer;
import org.deeplearning4j.nn.params.Convolution3DParamInitializer; import org.deeplearning4j.nn.params.Convolution3DParamInitializer;

View File

@ -65,6 +65,18 @@ public class ConvolutionLayer extends FeedForwardLayer {
* details Default is {@link ConvolutionMode}.Truncate. * details Default is {@link ConvolutionMode}.Truncate.
*/ */
@Builder.Default protected ConvolutionMode convolutionMode = ConvolutionMode.Truncate; @Builder.Default protected ConvolutionMode convolutionMode = ConvolutionMode.Truncate;
/**
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
* See {@link CNN2DFormat} for more details.<br>
* Default: NCHW
*
* @param format Format for activations (in and out)
*/
@Builder.Default
protected CNN2DFormat convFormat =
CNN2DFormat.NCHW; // default value for legacy serialization reasons
/** /**
* Kernel dilation. Default: {1, 1}, which is standard convolutions. Used for implementing dilated * Kernel dilation. Default: {1, 1}, which is standard convolutions. Used for implementing dilated
* convolutions, which are also known as atrous convolutions. * convolutions, which are also known as atrous convolutions.
@ -86,16 +98,7 @@ public class ConvolutionLayer extends FeedForwardLayer {
* false, the built-in (non-CuDNN) implementation for ConvolutionLayer will be used * false, the built-in (non-CuDNN) implementation for ConvolutionLayer will be used
*/ */
@Builder.Default protected boolean cudnnAllowFallback = true; @Builder.Default protected boolean cudnnAllowFallback = true;
/**
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
* See {@link CNN2DFormat} for more details.<br>
* Default: NCHW
*
* @param format Format for activations (in and out)
*/
@Builder.Default
protected CNN2DFormat dataFormat =
CNN2DFormat.NCHW; // default value for legacy serialization reasons
/** Defaults to "PREFER_FASTEST", but "NO_WORKSPACE" uses less memory. */ /** Defaults to "PREFER_FASTEST", but "NO_WORKSPACE" uses less memory. */
@Builder.Default protected AlgoMode cudnnAlgoMode = AlgoMode.PREFER_FASTEST; @Builder.Default protected AlgoMode cudnnAlgoMode = AlgoMode.PREFER_FASTEST;
@ -179,7 +182,7 @@ public class ConvolutionLayer extends FeedForwardLayer {
nOut, nOut,
layerIndex, layerIndex,
getName(), getName(),
dataFormat, convFormat,
ConvolutionLayer.class); ConvolutionLayer.class);
} }
@ -196,11 +199,11 @@ public class ConvolutionLayer extends FeedForwardLayer {
if (!defaultValueOverriden || nIn <= 0 || override) { if (!defaultValueOverriden || nIn <= 0 || override) {
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType; InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
this.nIn = c.getChannels(); this.nIn = c.getChannels();
this.dataFormat = ((InputType.InputTypeConvolutional) inputType).getFormat(); this.convFormat = ((InputType.InputTypeConvolutional) inputType).getFormat();
} }
if (dataFormat == null || override) if (convFormat == null || override)
this.dataFormat = ((InputType.InputTypeConvolutional) inputType).getFormat(); this.convFormat = ((InputType.InputTypeConvolutional) inputType).getFormat();
} }
@Override @Override

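Because the activation format field is renamed from dataFormat to convFormat, the generated builder setter changes name as well. A hedged sketch of requesting channels-last activations, assuming the varargs kernelSize/stride setters that the rest of this builder provides:

import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;

// convFormat replaces the old dataFormat property; NCHW remains the default
ConvolutionLayer conv = ConvolutionLayer.builder()
    .nIn(3)
    .nOut(16)
    .kernelSize(3, 3)
    .stride(1, 1)
    .convFormat(CNN2DFormat.NHWC)
    .build();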
View File

@ -53,6 +53,16 @@ public class DepthwiseConvolution2D extends ConvolutionLayer {
*/ */
@Builder.Default @Builder.Default
protected int depthMultiplier = 1; protected int depthMultiplier = 1;
/**
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
* See {@link CNN2DFormat} for more details.<br>
* Default: NCHW
*
* @param format Format for activations (in and out)
*/
@Builder.Default
protected CNN2DFormat dataFormat =
CNN2DFormat.NCHW; // default value for legacy serialization reasons
/** /**
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last). * Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
* See {@link CNN2DFormat} for more details.<br> * See {@link CNN2DFormat} for more details.<br>

View File

@ -36,6 +36,7 @@ import org.deeplearning4j.nn.weights.embeddings.ArrayEmbeddingInitializer;
import org.deeplearning4j.nn.weights.embeddings.EmbeddingInitializer; import org.deeplearning4j.nn.weights.embeddings.EmbeddingInitializer;
import org.deeplearning4j.nn.weights.embeddings.WeightInitEmbedding; import org.deeplearning4j.nn.weights.embeddings.WeightInitEmbedding;
import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationIdentity; import org.nd4j.linalg.activations.impl.ActivationIdentity;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -63,7 +64,7 @@ public class EmbeddingLayer extends FeedForwardLayer {
*/ */
public static EmbeddingLayerBuilder<?, ?> builder() { public static EmbeddingLayerBuilder<?, ?> builder() {
return innerBuilder() return innerBuilder()
.activation(new ActivationIdentity()); .activation(Activation.IDENTITY);
} }
public static abstract class EmbeddingLayerBuilder<C extends EmbeddingLayer, B extends EmbeddingLayerBuilder<C,B>> public static abstract class EmbeddingLayerBuilder<C extends EmbeddingLayer, B extends EmbeddingLayerBuilder<C,B>>

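builder() now seeds the activation with the Activation.IDENTITY enum rather than a fresh ActivationIdentity instance; either way the built layer applies an identity activation. A small sketch, assuming the nIn/nOut setters inherited from the FeedForwardLayer builder:

import org.deeplearning4j.nn.conf.layers.EmbeddingLayer;

// identity activation is applied by default; only the table dimensions are required
EmbeddingLayer emb = EmbeddingLayer.builder()
    .nIn(10_000)   // vocabulary size
    .nOut(128)     // embedding dimension
    .build();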
View File

@ -88,7 +88,7 @@ public abstract class LayerConfiguration
* Activation#getActivationFunction()} but not vice versa. The default is Identity Activation. * Activation#getActivationFunction()} but not vice versa. The default is Identity Activation.
*/ */
@Builder.Default @Builder.Default
@Getter @Setter private IActivation activation = new ActivationIdentity(); @Getter @Setter private IActivation activation = Activation.IDENTITY;
/** /**
* Get the activation interface (function) from the activation. The activation must have been set * Get the activation interface (function) from the activation. The activation must have been set
@ -335,7 +335,7 @@ public abstract class LayerConfiguration
public static abstract class LayerConfigurationBuilder<C extends LayerConfiguration, B extends LayerConfigurationBuilder<C, B>> { public static abstract class LayerConfigurationBuilder<C extends LayerConfiguration, B extends LayerConfigurationBuilder<C, B>> {
public B activation(Activation activation) { public B activation(Activation activation) {
this.activation$value = activation.getActivationFunction(); this.activation$value = activation;
this.activation$set = true; this.activation$set = true;
return self(); return self();
} }
@ -344,6 +344,7 @@ public abstract class LayerConfiguration
this.activation$set = true; this.activation$set = true;
return self(); return self();
} }
public B dropOut(double d) { public B dropOut(double d) {
this.dropOut = new Dropout(d); this.dropOut = new Dropout(d);
return self(); return self();
@ -352,6 +353,14 @@ public abstract class LayerConfiguration
this.dropOut = d; this.dropOut = d;
return self(); return self();
} }
public B constrainBias(LayerConstraint constraint) {
return this.biasConstraints(List.of(constraint));
}
public B constrainWeights(LayerConstraint constraint) {
return this.weightConstraints(List.of(constraint));
}
} }
} }
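The base builder now stores the Activation enum directly and gains constrainBias/constrainWeights shortcuts that wrap a single LayerConstraint into a list. A hedged sketch using DenseLayer purely as an example subclass (NonNegativeConstraint is an existing DL4J constraint, chosen only for illustration):

import org.deeplearning4j.nn.conf.constraint.NonNegativeConstraint;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.nd4j.linalg.activations.Activation;

DenseLayer dense = DenseLayer.builder()
    .nIn(100)
    .nOut(50)
    .activation(Activation.RELU)                     // enum overload on the shared builder
    .dropOut(0.5)                                    // shorthand for new Dropout(0.5)
    .constrainWeights(new NonNegativeConstraint())   // wrapped into weightConstraints
    .constrainBias(new NonNegativeConstraint())      // wrapped into biasConstraints
    .build();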

View File

@ -20,7 +20,9 @@
package org.deeplearning4j.nn.conf.layers; package org.deeplearning4j.nn.conf.layers;
import java.util.Map;
import lombok.*; import lombok.*;
import lombok.experimental.SuperBuilder;
import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.RNNFormat;
@ -32,39 +34,35 @@ import org.nd4j.autodiff.samediff.SDIndex;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions; import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair;
import java.util.Map;
@Data @Data
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder(buildMethodName = "initBuild")
public class LearnedSelfAttentionLayer extends SameDiffLayer { public class LearnedSelfAttentionLayer extends SameDiffLayer {
private long nIn;
private long nOut;
private int nHeads;
private long headSize;
private boolean projectInput;
private int nQueries;
private static final String WEIGHT_KEY_QUERY_PROJECTION = "Wq"; private static final String WEIGHT_KEY_QUERY_PROJECTION = "Wq";
private static final String WEIGHT_KEY_KEY_PROJECTION = "Wk"; private static final String WEIGHT_KEY_KEY_PROJECTION = "Wk";
private static final String WEIGHT_KEY_VALUE_PROJECTION = "Wv"; private static final String WEIGHT_KEY_VALUE_PROJECTION = "Wv";
private static final String WEIGHT_KEY_OUT_PROJECTION = "Wo"; private static final String WEIGHT_KEY_OUT_PROJECTION = "Wo";
private static final String WEIGHT_QUERIES = "Q"; private static final String WEIGHT_QUERIES = "Q";
/** Number of inputs to the layer (input size) */
private int nIn;
/** Number of outputs (output size) */
private int nOut;
/** Number of Attention Heads */
private int nHeads;
/** Size of attention heads */
private int headSize;
/** Project input before applying attention or not. */
private boolean projectInput;
/** Number of queries to learn */
private int nQueries;
private LearnedSelfAttentionLayer(){/*No arg constructor for serialization*/} private LearnedSelfAttentionLayer() {
/*No arg constructor for serialization*/
protected LearnedSelfAttentionLayer(Builder builder){
super(builder);
nIn = builder.nIn;
nOut = builder.nOut;
nHeads = builder.nHeads;
headSize = builder.headSize == 0 ? nOut / nHeads : builder.headSize;
projectInput = builder.projectInput;
nQueries = builder.nQueries;
} }
@Override @Override
@ -75,27 +73,34 @@ public class LearnedSelfAttentionLayer extends SameDiffLayer {
@Override @Override
public void setNIn(InputType inputType, boolean override) { public void setNIn(InputType inputType, boolean override) {
if (inputType == null || inputType.getType() != InputType.Type.RNN) { if (inputType == null || inputType.getType() != InputType.Type.RNN) {
throw new IllegalStateException("Invalid input for Learned Self Attention layer (layer name = \"" + getName() throw new IllegalStateException(
+ "\"): expect RNN input type with size > 0. Got: " + inputType); "Invalid input for Learned Self Attention layer (layer name = \""
+ getName()
+ "\"): expect RNN input type with size > 0. Got: "
+ inputType);
} }
if (nIn <= 0 || override) { if (nIn <= 0 || override) {
InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType; InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
this.nIn = r.getSize(); this.nIn = (int) r.getSize();
} }
} }
@Override @Override
public InputType getOutputType(int layerIndex, InputType inputType) { public InputType getOutputType(int layerIndex, InputType inputType) {
if (inputType == null || inputType.getType() != InputType.Type.RNN) { if (inputType == null || inputType.getType() != InputType.Type.RNN) {
throw new IllegalStateException("Invalid input for Learned Self Attention layer (layer index = " + layerIndex throw new IllegalStateException(
+ ", layer name = \"" + getName() + "\"): expect RNN input type with size > 0. Got: " "Invalid input for Learned Self Attention layer (layer index = "
+ layerIndex
+ ", layer name = \""
+ getName()
+ "\"): expect RNN input type with size > 0. Got: "
+ inputType); + inputType);
} }
if(projectInput){ if (projectInput) {
return InputType.recurrent(nOut, nQueries); return InputType.recurrent(nOut, nQueries);
}else{ } else {
return InputType.recurrent(nIn, nQueries); return InputType.recurrent(nIn, nQueries);
} }
} }
@ -106,7 +111,7 @@ public class LearnedSelfAttentionLayer extends SameDiffLayer {
params.addWeightParam(WEIGHT_QUERIES, 1, nIn, nQueries); params.addWeightParam(WEIGHT_QUERIES, 1, nIn, nQueries);
if(projectInput){ if (projectInput) {
params.addWeightParam(WEIGHT_KEY_QUERY_PROJECTION, nHeads, headSize, nIn); params.addWeightParam(WEIGHT_KEY_QUERY_PROJECTION, nHeads, headSize, nIn);
params.addWeightParam(WEIGHT_KEY_KEY_PROJECTION, nHeads, headSize, nIn); params.addWeightParam(WEIGHT_KEY_KEY_PROJECTION, nHeads, headSize, nIn);
params.addWeightParam(WEIGHT_KEY_VALUE_PROJECTION, nHeads, headSize, nIn); params.addWeightParam(WEIGHT_KEY_VALUE_PROJECTION, nHeads, headSize, nIn);
@ -118,137 +123,72 @@ public class LearnedSelfAttentionLayer extends SameDiffLayer {
public void initializeParameters(Map<String, INDArray> params) { public void initializeParameters(Map<String, INDArray> params) {
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
for (Map.Entry<String, INDArray> e : params.entrySet()) { for (Map.Entry<String, INDArray> e : params.entrySet()) {
if(e.getKey().equals(WEIGHT_KEY_OUT_PROJECTION)){ if (e.getKey().equals(WEIGHT_KEY_OUT_PROJECTION)) {
WeightInitUtil.initWeights(nIn, headSize, e.getValue().shape(), weightInit, null, 'c', e.getValue()); WeightInitUtil.initWeights(
}else if(e.getKey().equals(WEIGHT_QUERIES)){ nIn, headSize, e.getValue().shape(), weightInit, null, 'c', e.getValue());
WeightInitUtil.initWeights(nIn, nQueries, e.getValue().shape(), weightInit, null, 'c', e.getValue()); } else if (e.getKey().equals(WEIGHT_QUERIES)) {
}else{ WeightInitUtil.initWeights(
WeightInitUtil.initWeights(nHeads * headSize, nOut, e.getValue().shape(), weightInit, null, 'c', e.getValue()); nIn, nQueries, e.getValue().shape(), weightInit, null, 'c', e.getValue());
} else {
WeightInitUtil.initWeights(
nHeads * headSize, nOut, e.getValue().shape(), weightInit, null, 'c', e.getValue());
} }
} }
} }
} }
@Override @Override
public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, Map<String, SDVariable> paramTable, SDVariable mask) { public SDVariable defineLayer(
SameDiff sameDiff,
SDVariable layerInput,
Map<String, SDVariable> paramTable,
SDVariable mask) {
val baseQueries = paramTable.get(WEIGHT_QUERIES); val baseQueries = paramTable.get(WEIGHT_QUERIES);
val batchSize = layerInput.shape().get(SDIndex.point(0)); val batchSize = layerInput.shape().get(SDIndex.point(0));
val tileAxis = sameDiff.scatterUpdate(sameDiff.onesLike(layerInput.shape()), sameDiff.constant(0), batchSize); val tileAxis =
sameDiff.scatterUpdate(
sameDiff.onesLike(layerInput.shape()), sameDiff.constant(0), batchSize);
val queries = sameDiff.tile(baseQueries, tileAxis); val queries = sameDiff.tile(baseQueries, tileAxis);
if(projectInput){ if (projectInput) {
val Wq = paramTable.get(WEIGHT_KEY_QUERY_PROJECTION); val Wq = paramTable.get(WEIGHT_KEY_QUERY_PROJECTION);
val Wk = paramTable.get(WEIGHT_KEY_KEY_PROJECTION); val Wk = paramTable.get(WEIGHT_KEY_KEY_PROJECTION);
val Wv = paramTable.get(WEIGHT_KEY_VALUE_PROJECTION); val Wv = paramTable.get(WEIGHT_KEY_VALUE_PROJECTION);
val Wo = paramTable.get(WEIGHT_KEY_OUT_PROJECTION); val Wo = paramTable.get(WEIGHT_KEY_OUT_PROJECTION);
return sameDiff.nn.multiHeadDotProductAttention(getName(), queries, layerInput, layerInput, Wq, Wk, Wv, Wo, mask, true); return sameDiff.nn.multiHeadDotProductAttention(
}else{ getName(), queries, layerInput, layerInput, Wq, Wk, Wv, Wo, mask, true);
return sameDiff.nn.dotProductAttention(getName(), queries, layerInput, layerInput, mask, true); } else {
return sameDiff.nn.dotProductAttention(
getName(), queries, layerInput, layerInput, mask, true);
} }
} }
@Override @Override
public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) { public Pair<INDArray, MaskState> feedForwardMaskArray(
// No further mask propagation here, as the results have taken any mask into account, like in a global pooling layer INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
// No further mask propagation here, as the results have taken any mask into account, like in a
// global pooling layer
return null; return null;
} }
@Getter public abstract static class LearnedSelfAttentionLayerBuilder<
@Setter C extends LearnedSelfAttentionLayer, B extends LearnedSelfAttentionLayerBuilder<C, B>>
public static class Builder extends SameDiffLayer.Builder<LearnedSelfAttentionLayer.Builder> { extends SameDiffLayerBuilder<C, B> {
public C build() {
/** Preconditions.checkArgument(
* Number of inputs to the layer (input size) this.projectInput || this.nHeads == 1, "projectInput must be true when nHeads != 1");
*/ Preconditions.checkArgument(
private int nIn; this.projectInput || nIn == nOut, "nIn must be equal to nOut when projectInput is false");
Preconditions.checkArgument(
/** !this.projectInput || nOut != 0, "nOut must be specified when projectInput is true");
* Number of outputs (output size) Preconditions.checkArgument(
*/ this.nOut % nHeads == 0 || headSize > 0,
private int nOut; "nOut isn't divided by nHeads cleanly. Specify the headSize manually.");
/**
* Number of Attention Heads
*/
private int nHeads;
/**
* Size of attention heads
*/
private int headSize;
/**
* Project input before applying attention or not.
*/
private boolean projectInput;
/**
* Number of queries to learn
*/
private int nQueries;
/**
* @param nIn Number of inputs to the layer (input size)
*/
public Builder nIn(int nIn) {
this.nIn = nIn;
return this;
}
/**
* @param nOut Number of outputs (output size)
*/
public Builder nOut(int nOut) {
this.nOut = nOut;
return this;
}
/**
* Number of Attention Heads
*/
public Builder nHeads(int nHeads){
this.nHeads = nHeads;
return this;
}
/**
* Size of attention heads
*/
public Builder headSize(int headSize){
this.headSize = headSize;
return this;
}
/**
* Project input before applying attention or not.
*/
public Builder projectInput(boolean projectInput){
this.projectInput = projectInput;
return this;
}
/**
* Number of queries to learn
*/
public Builder nQueries(int nQueries){
this.nQueries = nQueries;
return this;
}
@Override
@SuppressWarnings("unchecked")
public LearnedSelfAttentionLayer build() {
Preconditions.checkArgument(this.projectInput || this.nHeads == 1, "projectInput must be true when nHeads != 1");
Preconditions.checkArgument(this.projectInput || nIn == nOut, "nIn must be equal to nOut when projectInput is false");
Preconditions.checkArgument(!this.projectInput || nOut != 0, "nOut must be specified when projectInput is true");
Preconditions.checkArgument(this.nOut % nHeads == 0 || headSize > 0, "nOut isn't divided by nHeads cleanly. Specify the headSize manually.");
Preconditions.checkArgument(this.nQueries > 0, "You must set numQueries."); Preconditions.checkArgument(this.nQueries > 0, "You must set numQueries.");
return new LearnedSelfAttentionLayer(this); return initBuild();
} }
} }
} }
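The custom build() above re-applies the preconditions that the removed Builder enforced. A usage sketch that satisfies them (nOut divides evenly by nHeads, projectInput is set, nQueries is positive); the values themselves are illustrative only:

import org.deeplearning4j.nn.conf.layers.LearnedSelfAttentionLayer;

LearnedSelfAttentionLayer attn = LearnedSelfAttentionLayer.builder()
    .nIn(128)
    .nOut(128)
    .nHeads(4)          // 128 divides by 4 heads, so no explicit headSize is needed
    .nQueries(16)       // mandatory: number of learned queries
    .projectInput(true)
    .build();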

View File

@ -41,6 +41,7 @@ import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.enums.PadMode; import org.nd4j.enums.PadMode;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -333,6 +334,11 @@ public class LocallyConnected2D extends SameDiffLayer {
return self(); return self();
} }
public B inputSize(int ... size) {
this.inputSize = size;
return self();
}
public B stride(int ... stride) { public B stride(int ... stride) {
this.stride$value = ValidationUtils.validate2NonNegative(stride, false, "stride"); this.stride$value = ValidationUtils.validate2NonNegative(stride, false, "stride");
this.stride$set = true; this.stride$set = true;

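The new inputSize(int...) setter sits next to the validated stride(int...) override shown above. A tentative sketch; the nIn/nOut/kernelSize setters are assumed from the rest of this class's builder, which is not part of this hunk:

import org.deeplearning4j.nn.conf.layers.LocallyConnected2D;

// inputSize must be known so the per-position weights can be shaped
LocallyConnected2D lc = LocallyConnected2D.builder()
    .nIn(3)
    .nOut(16)
    .kernelSize(3, 3)
    .stride(1, 1)
    .inputSize(28, 28)
    .build();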
View File

@ -20,7 +20,9 @@
package org.deeplearning4j.nn.conf.layers; package org.deeplearning4j.nn.conf.layers;
import java.util.Map;
import lombok.*; import lombok.*;
import lombok.experimental.SuperBuilder;
import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InputType.InputTypeConvolutional; import org.deeplearning4j.nn.conf.inputs.InputType.InputTypeConvolutional;
@ -37,87 +39,191 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import java.util.Map;
@Data @Data
@NoArgsConstructor @NoArgsConstructor
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder")
public class PrimaryCapsules extends SameDiffLayer { public class PrimaryCapsules extends SameDiffLayer {
private int[] kernelSize;
private int[] stride;
private int[] padding;
private int[] dilation;
private int inputChannels;
private int channels;
private boolean hasBias;
private int capsules;
private int capsuleDimensions;
private ConvolutionMode convolutionMode = ConvolutionMode.Truncate;
private boolean useRelu = false;
private double leak = 0;
private static final String WEIGHT_PARAM = "weight"; private static final String WEIGHT_PARAM = "weight";
private static final String BIAS_PARAM = "bias"; private static final String BIAS_PARAM = "bias";
/**
* Sets the kernel size of the 2d convolution
*
* @param kernelSize
* @return
*/
@Builder.Default private int[] kernelSize = new int[] {9, 9};
/**
* Sets the stride of the 2d convolution
*
* @param stride
* @return
*/
@Builder.Default private int[] stride = new int[] {2, 2};
/**
* Sets the padding of the 2d convolution
*
* @param padding
* @return
*/
@Builder.Default private int[] padding = new int[] {0, 0};
/**
* Sets the dilation of the 2d convolution
*
* @param dilation
* @return
*/
@Builder.Default private int[] dilation = new int[] {1, 1};
public PrimaryCapsules(Builder builder){ private int inputChannels;
super(builder); /**
* Sets the number of channels to use in the 2d convolution.
*
* <p>Note that the actual number of channels is channels * capsuleDimensions
*
* <p>Does the same thing as nOut()
*
* @param channels
* @return
*/
@Builder.Default private int channels = 32;
this.kernelSize = builder.kernelSize; @Builder.Default private boolean hasBias = true;
this.stride = builder.stride; /**
this.padding = builder.padding; * Usually inferred automatically.
this.dilation = builder.dilation; *
this.channels = builder.channels; * @param capsules
this.hasBias = builder.hasBias; * @return
this.capsules = builder.capsules; */
this.capsuleDimensions = builder.capsuleDimensions; private int capsules;
this.convolutionMode = builder.convolutionMode; /**
this.useRelu = builder.useRelu; * Sets the number of dimensions to use in the capsules.
this.leak = builder.leak; *
* @param capsuleDimensions
* @return
*/
private int capsuleDimensions;
/**
* The convolution mode to use in the 2d convolution
*
* @param convolutionMode
* @return
*/
@Builder.Default private ConvolutionMode convolutionMode = ConvolutionMode.Truncate;
/**
* Whether to use a ReLU activation on the 2d convolution
*
* @param useRelu
* @return
*/
@Builder.Default private boolean useRelU = false;
/**
* Use a LeakyReLU activation on the 2d convolution
*
* @param leak the alpha value for the LeakyReLU activation.
* @return
*/
@Builder.Default private double useLeakyReLU = 0;
if(capsuleDimensions <= 0 || channels <= 0){ public static PrimaryCapsulesBuilder<?, ?> builder() {
throw new IllegalArgumentException("Invalid configuration for Primary Capsules (layer name = \"" return innerBuilder();
+ name + "\"):"
+ " capsuleDimensions and channels must be > 0. Got: "
+ capsuleDimensions + ", " + channels);
} }
if(capsules < 0){ public static PrimaryCapsulesBuilder<?, ?> builder(
throw new IllegalArgumentException("Invalid configuration for Capsule ILayer (layer name = \"" int capsuleDimensions,
+ name + "\"):" int channels,
+ " capsules must be >= 0 if set. Got: " int[] kernelSize,
+ capsules); int[] stride,
int[] padding,
int[] dilation,
ConvolutionMode convolutionMode) {
return innerBuilder()
.capsuleDimensions(capsuleDimensions)
.channels(channels)
.kernelSize(kernelSize)
.stride(stride)
.padding(padding)
.dilation(dilation)
.convolutionMode(convolutionMode);
} }
public static PrimaryCapsulesBuilder<?, ?> builder(
int capsuleDimensions,
int channels,
int[] kernelSize,
int[] stride,
int[] padding,
int[] dilation) {
return innerBuilder()
.capsuleDimensions(capsuleDimensions)
.channels(channels)
.kernelSize(kernelSize)
.stride(stride)
.padding(padding)
.dilation(dilation);
}
public static PrimaryCapsulesBuilder<?, ?> builder(
int capsuleDimensions, int channels, int[] kernelSize, int[] stride, int[] padding) {
return innerBuilder()
.capsuleDimensions(capsuleDimensions)
.channels(channels)
.kernelSize(kernelSize)
.stride(stride)
.padding(padding);
}
public static PrimaryCapsulesBuilder<?, ?> builder(
int capsuleDimensions, int channels, int[] kernelSize, int[] stride) {
return innerBuilder()
.capsuleDimensions(capsuleDimensions)
.channels(channels)
.kernelSize(kernelSize)
.stride(stride);
}
public static PrimaryCapsulesBuilder<?, ?> builder(
int capsuleDimensions, int channels, int[] kernelSize) {
return innerBuilder()
.capsuleDimensions(capsuleDimensions)
.channels(channels)
.kernelSize(kernelSize);
}
public static PrimaryCapsulesBuilder<?, ?> builder(int capsuleDimensions, int channels) {
return innerBuilder().capsuleDimensions(capsuleDimensions).channels(channels);
} }
@Override @Override
public SDVariable defineLayer(SameDiff SD, SDVariable input, Map<String, SDVariable> paramTable, SDVariable mask) { public SDVariable defineLayer(
Conv2DConfig conf = Conv2DConfig.builder() SameDiff SD, SDVariable input, Map<String, SDVariable> paramTable, SDVariable mask) {
.kH(kernelSize[0]).kW(kernelSize[1]) Conv2DConfig conf =
.sH(stride[0]).sW(stride[1]) Conv2DConfig.builder()
.pH(padding[0]).pW(padding[1]) .kH(kernelSize[0])
.dH(dilation[0]).dW(dilation[1]) .kW(kernelSize[1])
.sH(stride[0])
.sW(stride[1])
.pH(padding[0])
.pW(padding[1])
.dH(dilation[0])
.dW(dilation[1])
.isSameMode(convolutionMode == ConvolutionMode.Same) .isSameMode(convolutionMode == ConvolutionMode.Same)
.build(); .build();
SDVariable conved; SDVariable conved;
if(hasBias){ if (hasBias) {
conved = SD.cnn.conv2d(input, paramTable.get(WEIGHT_PARAM), paramTable.get(BIAS_PARAM), conf); conved = SD.cnn.conv2d(input, paramTable.get(WEIGHT_PARAM), paramTable.get(BIAS_PARAM), conf);
} else { } else {
conved = SD.cnn.conv2d(input, paramTable.get(WEIGHT_PARAM), conf); conved = SD.cnn.conv2d(input, paramTable.get(WEIGHT_PARAM), conf);
} }
if(useRelu){ if (useRelU) {
if(leak == 0) { if (useLeakyReLU == 0) {
conved = SD.nn.relu(conved, 0); conved = SD.nn.relu(conved, 0);
} else { } else {
conved = SD.nn.leakyRelu(conved, leak); conved = SD.nn.leakyRelu(conved, useLeakyReLU);
} }
} }
@ -128,10 +234,14 @@ public class PrimaryCapsules extends SameDiffLayer {
@Override @Override
public void defineParameters(SDLayerParams params) { public void defineParameters(SDLayerParams params) {
params.clear(); params.clear();
params.addWeightParam(WEIGHT_PARAM, params.addWeightParam(
kernelSize[0], kernelSize[1], inputChannels, (long) capsuleDimensions * channels); WEIGHT_PARAM,
kernelSize[0],
kernelSize[1],
inputChannels,
(long) capsuleDimensions * channels);
if(hasBias){ if (hasBias) {
params.addBiasParam(BIAS_PARAM, (long) capsuleDimensions * channels); params.addBiasParam(BIAS_PARAM, (long) capsuleDimensions * channels);
} }
} }
@ -142,11 +252,16 @@ public class PrimaryCapsules extends SameDiffLayer {
for (Map.Entry<String, INDArray> e : params.entrySet()) { for (Map.Entry<String, INDArray> e : params.entrySet()) {
if (BIAS_PARAM.equals(e.getKey())) { if (BIAS_PARAM.equals(e.getKey())) {
e.getValue().assign(0); e.getValue().assign(0);
} else if(WEIGHT_PARAM.equals(e.getKey())){ } else if (WEIGHT_PARAM.equals(e.getKey())) {
double fanIn = inputChannels * kernelSize[0] * kernelSize[1]; double fanIn = inputChannels * kernelSize[0] * kernelSize[1];
double fanOut = capsuleDimensions * channels * kernelSize[0] * kernelSize[1] / ((double) stride[0] * stride[1]); double fanOut =
WeightInitUtil.initWeights(fanIn, fanOut, e.getValue().shape(), weightInit, null, 'c', capsuleDimensions
e.getValue()); * channels
* kernelSize[0]
* kernelSize[1]
/ ((double) stride[0] * stride[1]);
WeightInitUtil.initWeights(
fanIn, fanOut, e.getValue().shape(), weightInit, null, 'c', e.getValue());
} }
} }
} }
@ -155,19 +270,33 @@ public class PrimaryCapsules extends SameDiffLayer {
@Override @Override
public InputType getOutputType(int layerIndex, InputType inputType) { public InputType getOutputType(int layerIndex, InputType inputType) {
if (inputType == null || inputType.getType() != Type.CNN) { if (inputType == null || inputType.getType() != Type.CNN) {
throw new IllegalStateException("Invalid input for Primary Capsules layer (layer name = \"" throw new IllegalStateException(
+ name + "\"): expect CNN input. Got: " + inputType); "Invalid input for Primary Capsules layer (layer name = \""
+ name
+ "\"): expect CNN input. Got: "
+ inputType);
} }
if(capsules > 0){ if (capsules > 0) {
return InputType.recurrent(capsules, capsuleDimensions); return InputType.recurrent(capsules, capsuleDimensions);
} else { } else {
InputTypeConvolutional out = (InputTypeConvolutional) InputTypeUtil InputTypeConvolutional out =
.getOutputTypeCnnLayers(inputType, kernelSize, stride, padding, dilation, convolutionMode, (InputTypeConvolutional)
(long) capsuleDimensions * channels, -1, getName(), PrimaryCapsules.class); InputTypeUtil.getOutputTypeCnnLayers(
inputType,
kernelSize,
stride,
padding,
dilation,
convolutionMode,
(long) capsuleDimensions * channels,
-1,
getName(),
PrimaryCapsules.class);
return InputType.recurrent((int) (out.getChannels() * out.getHeight() * out.getWidth() / capsuleDimensions), return InputType.recurrent(
(int) (out.getChannels() * out.getHeight() * out.getWidth() / capsuleDimensions),
capsuleDimensions); capsuleDimensions);
} }
} }
@ -175,250 +304,122 @@ public class PrimaryCapsules extends SameDiffLayer {
@Override @Override
public void setNIn(InputType inputType, boolean override) { public void setNIn(InputType inputType, boolean override) {
if (inputType == null || inputType.getType() != Type.CNN) { if (inputType == null || inputType.getType() != Type.CNN) {
throw new IllegalStateException("Invalid input for Primary Capsules layer (layer name = \"" throw new IllegalStateException(
+ name + "\"): expect CNN input. Got: " + inputType); "Invalid input for Primary Capsules layer (layer name = \""
+ name
+ "\"): expect CNN input. Got: "
+ inputType);
} }
InputTypeConvolutional ci = (InputTypeConvolutional) inputType; InputTypeConvolutional ci = (InputTypeConvolutional) inputType;
this.inputChannels = (int) ci.getChannels(); this.inputChannels = (int) ci.getChannels();
if(capsules <= 0 || override) { if (capsules <= 0 || override) {
InputTypeConvolutional out = (InputTypeConvolutional) InputTypeUtil InputTypeConvolutional out =
.getOutputTypeCnnLayers(inputType, kernelSize, stride, padding, dilation, convolutionMode, (InputTypeConvolutional)
(long) capsuleDimensions * channels, -1, getName(), PrimaryCapsules.class); InputTypeUtil.getOutputTypeCnnLayers(
inputType,
kernelSize,
stride,
padding,
dilation,
convolutionMode,
(long) capsuleDimensions * channels,
-1,
getName(),
PrimaryCapsules.class);
this.capsules = (int) (out.getChannels() * out.getHeight() * out.getWidth() / capsuleDimensions); this.capsules =
(int) (out.getChannels() * out.getHeight() * out.getWidth() / capsuleDimensions);
} }
} }
@Getter public abstract static class PrimaryCapsulesBuilder<
@Setter C extends PrimaryCapsules, B extends PrimaryCapsulesBuilder<C, B>>
public static class Builder extends SameDiffLayer.Builder<Builder>{ extends SameDiffLayerBuilder<C, B> {
@Setter(AccessLevel.NONE) public B kernelSize(int... kernelSize) {
private int[] kernelSize = new int[]{9, 9}; this.kernelSize$value = ValidationUtils.validate2NonNegative(kernelSize, true, "kernelSize");
this.kernelSize$set = true;
@Setter(AccessLevel.NONE) return self();
private int[] stride = new int[]{2, 2};
@Setter(AccessLevel.NONE)
private int[] padding = new int[]{0, 0};
@Setter(AccessLevel.NONE)
private int[] dilation = new int[]{1, 1};
private int channels = 32;
private boolean hasBias = true;
private int capsules;
private int capsuleDimensions;
private ConvolutionMode convolutionMode = ConvolutionMode.Truncate;
private boolean useRelu = false;
private double leak = 0;
public void setKernelSize(int... kernelSize){
this.kernelSize = ValidationUtils.validate2NonNegative(kernelSize, true, "kernelSize");
} }
public void setStride(int... stride){ public B stride(int... stride) {
this.stride = ValidationUtils.validate2NonNegative(stride, true, "stride"); this.stride$value = ValidationUtils.validate2NonNegative(stride, true, "stride");
this.stride$set = true;
return self();
} }
public void setPadding(int... padding){ public B padding(int... padding) {
this.padding = ValidationUtils.validate2NonNegative(padding, true, "padding"); this.padding$value = ValidationUtils.validate2NonNegative(padding, true, "padding");
this.padding$set = true;
return self();
} }
public void setDilation(int... dilation){ public B dilation(int... dilation) {
this.dilation = ValidationUtils.validate2NonNegative(dilation, true, "dilation"); this.dilation$value = ValidationUtils.validate2NonNegative(dilation, true, "dilation");
this.dilation$set = true;
return self();
} }
public Builder(int capsuleDimensions, int channels,
int[] kernelSize, int[] stride, int[] padding, int[] dilation,
ConvolutionMode convolutionMode){
this.capsuleDimensions = capsuleDimensions;
this.channels = channels;
this.setKernelSize(kernelSize);
this.setStride(stride);
this.setPadding(padding);
this.setDilation(dilation);
this.convolutionMode = convolutionMode;
}
public Builder(int capsuleDimensions, int channels,
int[] kernelSize, int[] stride, int[] padding, int[] dilation){
this(capsuleDimensions, channels, kernelSize, stride, padding, dilation, ConvolutionMode.Truncate);
}
public Builder(int capsuleDimensions, int channels,
int[] kernelSize, int[] stride, int[] padding){
this(capsuleDimensions, channels, kernelSize, stride, padding, new int[]{1, 1}, ConvolutionMode.Truncate);
}
public Builder(int capsuleDimensions, int channels,
int[] kernelSize, int[] stride){
this(capsuleDimensions, channels, kernelSize, stride, new int[]{0, 0}, new int[]{1, 1}, ConvolutionMode.Truncate);
}
public Builder(int capsuleDimensions, int channels,
int[] kernelSize){
this(capsuleDimensions, channels, kernelSize, new int[]{2, 2}, new int[]{0, 0}, new int[]{1, 1}, ConvolutionMode.Truncate);
}
public Builder(int capsuleDimensions, int channels){
this(capsuleDimensions, channels, new int[]{9, 9}, new int[]{2, 2}, new int[]{0, 0}, new int[]{1, 1}, ConvolutionMode.Truncate);
}
/**
* Sets the kernel size of the 2d convolution
*
* @see ConvolutionLayer.Builder#kernelSize(int...)
* @param kernelSize
* @return
*/
public Builder kernelSize(int... kernelSize){
this.setKernelSize(kernelSize);
return this;
}
/**
* Sets the stride of the 2d convolution
*
* @see ConvolutionLayer.Builder#stride(int...)
* @param stride
* @return
*/
public Builder stride(int... stride){
this.setStride(stride);
return this;
}
/**
* Sets the padding of the 2d convolution
*
* @see ConvolutionLayer.Builder#padding(int...)
* @param padding
* @return
*/
public Builder padding(int... padding){
this.setPadding(padding);
return this;
}
/**
* Sets the dilation of the 2d convolution
*
* @see ConvolutionLayer.Builder#dilation(int...)
* @param dilation
* @return
*/
public Builder dilation(int... dilation){
this.setDilation(dilation);
return this;
}
/** /**
* Sets the number of channels to use in the 2d convolution. * Sets the number of channels to use in the 2d convolution.
* *
* Note that the actual number of channels is channels * capsuleDimensions * <p>Note that the actual number of channels is channels * capsuleDimensions
* *
* Does the same thing as nOut() * <p>Does the same thing as channels()
*
* @param channels
* @return
*/
public Builder channels(int channels){
this.channels = channels;
return this;
}
/**
* Sets the number of channels to use in the 2d convolution.
*
* Note that the actual number of channels is channels * capsuleDimensions
*
* Does the same thing as channels()
* *
* @param nOut * @param nOut
* @return * @return
*/ */
public Builder nOut(int nOut){ public B nOut(int nOut) {
return channels(nOut); return channels(nOut);
} }
/**
* Sets the number of dimensions to use in the capsules.
* @param capsuleDimensions
* @return
*/
public Builder capsuleDimensions(int capsuleDimensions){
this.capsuleDimensions = capsuleDimensions;
return this;
}
/**
* Usually inferred automatically.
* @param capsules
* @return
*/
public Builder capsules(int capsules){
this.capsules = capsules;
return this;
}
public Builder hasBias(boolean hasBias){
this.hasBias = hasBias;
return this;
}
/**
* The convolution mode to use in the 2d convolution
* @param convolutionMode
* @return
*/
public Builder convolutionMode(ConvolutionMode convolutionMode){
this.convolutionMode = convolutionMode;
return this;
}
/**
* Whether to use a ReLU activation on the 2d convolution
* @param useRelu
* @return
*/
public Builder useReLU(boolean useRelu){
this.useRelu = useRelu;
return this;
}
/** /**
* Use a ReLU activation on the 2d convolution * Use a ReLU activation on the 2d convolution
*
* @return * @return
*/ */
public Builder useReLU(){ public B useReLU() {
return useReLU(true); return useRelU(true);
} }
/** /**
* Use a LeakyReLU activation on the 2d convolution * Use a LeakyReLU activation on the 2d convolution. Implies {@link #useReLU()} set true.
*
* @param leak the alpha value for the LeakyReLU activation. * @param leak the alpha value for the LeakyReLU activation.
* @return * @return
*/ */
public Builder useLeakyReLU(double leak){ public B useLeakyReLU(double leak) {
this.useRelu = true; this.useRelU(true);
this.leak = leak; this.useLeakyReLU$value = leak;
return this; this.useLeakyReLU$set = true;
return self();
} }
@Override public C build() {
public <E extends LayerConfiguration> E build() { C l = initBuild();
return (E) new PrimaryCapsules(this); if (capsuleDimensions <= 0 || channels$value <= 0) {
throw new IllegalArgumentException(
"Invalid configuration for Primary Capsules (layer name = \""
+ l.getName()
+ "\"):"
+ " capsuleDimensions and channels must be > 0. Got: "
+ capsuleDimensions
+ ", "
+ channels$value);
}
if (capsules < 0) {
throw new IllegalArgumentException(
"Invalid configuration for Capsule ILayer (layer name = \""
+ l.getName()
+ "\"):"
+ " capsules must be >= 0 if set. Got: "
+ capsules);
}
return l;
} }
} }
} }
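The static builder(...) overloads above mirror the old telescoping constructors, and build() now carries the validation that used to live in the constructor. A sketch using the two-argument overload; the numbers are illustrative only:

import org.deeplearning4j.nn.conf.layers.PrimaryCapsules;

// 8-dimensional capsules over 16 channel groups, LeakyReLU applied after the convolution
PrimaryCapsules caps = PrimaryCapsules.builder(8, 16)   // capsuleDimensions, channels
    .kernelSize(9, 9)
    .stride(2, 2)
    .useLeakyReLU(0.01)   // also switches useRelU on
    .build();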

View File

@ -21,6 +21,7 @@
package org.deeplearning4j.nn.conf.layers; package org.deeplearning4j.nn.conf.layers;
import lombok.*; import lombok.*;
import lombok.experimental.SuperBuilder;
import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.RNNFormat;
@ -41,15 +42,63 @@ import org.nd4j.linalg.factory.Nd4j;
import java.util.Map; import java.util.Map;
@Data @Data
@NoArgsConstructor
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder(buildMethodName = "initBuild")
public class RecurrentAttentionLayer extends SameDiffLayer { public class RecurrentAttentionLayer extends SameDiffLayer {
private long nIn;
private long nOut; public static abstract class RecurrentAttentionLayerBuilder<C extends RecurrentAttentionLayer, B extends RecurrentAttentionLayerBuilder<C,B>>
extends SameDiffLayerBuilder<C,B> {
public C build() {
Preconditions.checkArgument(this.projectInput$value || this.nHeads == 1, "projectInput must be true when nHeads != 1");
Preconditions.checkArgument(this.projectInput$value || nIn == nOut, "nIn must be equal to nOut when projectInput is false");
Preconditions.checkArgument(!this.projectInput$value || nOut != 0, "nOut must be specified when projectInput is true");
Preconditions.checkArgument(this.nOut % nHeads == 0 || headSize > 0, "nOut isn't divided by nHeads cleanly. Specify the headSize manually.");
C l = initBuild();
return l;
}
}
/**
* Number of inputs to the layer (input size)
*/
private int nIn;
/**
* Number of outputs (output size)
*/
private int nOut;
/**
* Number of Attention Heads
*/
private int nHeads; private int nHeads;
private long headSize;
private boolean projectInput; /**
private Activation activation; * Size of attention heads
private boolean hasBias; */
private int headSize;
/**
* Project input before applying attention or not.
*/
@Builder.Default
private boolean projectInput = true;
/**
* If true (default is true) the layer will have a bias
*/
@Builder.Default
private boolean hasBias = true;
/**
* Activation function for the layer
*/
@Builder.Default
private Activation activation = Activation.TANH;
private static final String WEIGHT_KEY_QUERY_PROJECTION = "Wq"; private static final String WEIGHT_KEY_QUERY_PROJECTION = "Wq";
private static final String WEIGHT_KEY_KEY_PROJECTION = "Wk"; private static final String WEIGHT_KEY_KEY_PROJECTION = "Wk";
@ -60,18 +109,7 @@ public class RecurrentAttentionLayer extends SameDiffLayer {
private static final String RECURRENT_WEIGHT_KEY = SimpleRnnParamInitializer.RECURRENT_WEIGHT_KEY; private static final String RECURRENT_WEIGHT_KEY = SimpleRnnParamInitializer.RECURRENT_WEIGHT_KEY;
private int timeSteps; private int timeSteps;
private RecurrentAttentionLayer(){/*No arg constructor for serialization*/}
protected RecurrentAttentionLayer(Builder builder){
super(builder);
nIn = builder.nIn;
nOut = builder.nOut;
nHeads = builder.nHeads;
headSize = builder.headSize == 0 ? nOut / nHeads : builder.headSize;
projectInput = builder.projectInput;
activation = builder.activation;
hasBias = builder.hasBias;
}
@Override @Override
public InputPreProcessor getPreProcessorForInputType(InputType inputType) { public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
@ -87,7 +125,7 @@ public class RecurrentAttentionLayer extends SameDiffLayer {
if (nIn <= 0 || override) { if (nIn <= 0 || override) {
InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType; InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
this.nIn = r.getSize(); this.nIn = (int) r.getSize();
} }
} }
@ -206,109 +244,5 @@ public class RecurrentAttentionLayer extends SameDiffLayer {
return sameDiff.concat(2, outputSlices); return sameDiff.concat(2, outputSlices);
} }
@Getter
@Setter
public static class Builder extends SameDiffLayer.Builder<RecurrentAttentionLayer.Builder> {
/**
* Number of inputs to the layer (input size)
*/
private int nIn;
/**
* Number of outputs (output size)
*/
private int nOut;
/**
* Number of Attention Heads
*/
private int nHeads;
/**
* Size of attention heads
*/
private int headSize;
/**
* Project input before applying attention or not.
*/
private boolean projectInput = true;
/**
* If true (default is true) the layer will have a bias
*/
private boolean hasBias = true;
/**
* Activation function for the layer
*/
private Activation activation = Activation.TANH;
/**
* @param nIn Number of inputs to the layer (input size)
*/
public Builder nIn(int nIn) {
this.nIn = nIn;
return this;
}
/**
* @param nOut Number of outputs (output size)
*/
public Builder nOut(int nOut) {
this.nOut = nOut;
return this;
}
/**
* Number of Attention Heads
*/
public Builder nHeads(int nHeads){
this.nHeads = nHeads;
return this;
}
/**
* Size of attention heads
*/
public Builder headSize(int headSize){
this.headSize = headSize;
return this;
}
/**
* Project input before applying attention or not.
*/
public Builder projectInput(boolean projectInput){
this.projectInput = projectInput;
return this;
}
/**
* @param hasBias If true (default is true) the layer will have a bias
*/
public Builder hasBias(boolean hasBias) {
this.hasBias = hasBias;
return this;
}
/**
* @param activation Activation function for the layer
*/
public Builder activation(Activation activation) {
this.activation = activation;
return this;
}
@Override
@SuppressWarnings("unchecked")
public RecurrentAttentionLayer build() {
Preconditions.checkArgument(this.projectInput || this.nHeads == 1, "projectInput must be true when nHeads != 1");
Preconditions.checkArgument(this.projectInput || nIn == nOut, "nIn must be equal to nOut when projectInput is false");
Preconditions.checkArgument(!this.projectInput || nOut != 0, "nOut must be specified when projectInput is true");
Preconditions.checkArgument(this.nOut % nHeads == 0 || headSize > 0, "nOut isn't divided by nHeads cleanly. Specify the headSize manually.");
return new RecurrentAttentionLayer(this);
}
}
} }
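projectInput, hasBias and the TANH activation now come from @Builder.Default values, and build() repeats the old Builder's preconditions. A hedged sketch that satisfies them; projectInput is set explicitly because the precondition reads the builder's raw field:

import org.deeplearning4j.nn.conf.layers.RecurrentAttentionLayer;

RecurrentAttentionLayer rattn = RecurrentAttentionLayer.builder()
    .nIn(64)
    .nOut(64)
    .nHeads(4)           // 64 divides by 4 heads, so nOut splits cleanly
    .projectInput(true)
    .build();            // hasBias=true and Activation.TANH come from the defaults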

View File

@ -20,7 +20,9 @@
package org.deeplearning4j.nn.conf.layers; package org.deeplearning4j.nn.conf.layers;
import java.util.Map;
import lombok.*; import lombok.*;
import lombok.experimental.SuperBuilder;
import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InputType;
@ -34,32 +36,25 @@ import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import java.util.Map;
@Data @Data
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@NoArgsConstructor()
@SuperBuilder(buildMethodName = "initBuild")
public class SelfAttentionLayer extends SameDiffLayer { public class SelfAttentionLayer extends SameDiffLayer {
private long nIn;
private long nOut;
private int nHeads;
private long headSize;
private boolean projectInput;
private static final String WEIGHT_KEY_QUERY_PROJECTION = "Wq"; private static final String WEIGHT_KEY_QUERY_PROJECTION = "Wq";
private static final String WEIGHT_KEY_KEY_PROJECTION = "Wk"; private static final String WEIGHT_KEY_KEY_PROJECTION = "Wk";
private static final String WEIGHT_KEY_VALUE_PROJECTION = "Wv"; private static final String WEIGHT_KEY_VALUE_PROJECTION = "Wv";
private static final String WEIGHT_KEY_OUT_PROJECTION = "Wo"; private static final String WEIGHT_KEY_OUT_PROJECTION = "Wo";
/** Number of inputs to the layer (input size) */
private SelfAttentionLayer(){/*No arg constructor for serialization*/} private int nIn;
/** Number of outputs (output size) */
protected SelfAttentionLayer(Builder builder){ private int nOut;
super(builder); /** Number of Attention Heads */
nIn = builder.nIn; private int nHeads;
nOut = builder.nOut; /** Size of attention heads */
nHeads = builder.nHeads; private int headSize;
headSize = builder.headSize == 0 ? nOut / nHeads : builder.headSize; /** Project input before applying attention or not. */
projectInput = builder.projectInput; private boolean projectInput;
}
@Override @Override
public InputPreProcessor getPreProcessorForInputType(InputType inputType) { public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
@ -69,29 +64,36 @@ public class SelfAttentionLayer extends SameDiffLayer {
@Override @Override
public void setNIn(InputType inputType, boolean override) { public void setNIn(InputType inputType, boolean override) {
if (inputType == null || inputType.getType() != InputType.Type.RNN) { if (inputType == null || inputType.getType() != InputType.Type.RNN) {
throw new IllegalStateException("Invalid input for Self Attention layer (layer name = \"" + getName() throw new IllegalStateException(
+ "\"): expect RNN input type with size > 0. Got: " + inputType); "Invalid input for Self Attention layer (layer name = \""
+ getName()
+ "\"): expect RNN input type with size > 0. Got: "
+ inputType);
} }
if (nIn <= 0 || override) { if (nIn <= 0 || override) {
InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType; InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
this.nIn = r.getSize(); this.nIn = (int) r.getSize();
} }
} }
@Override @Override
public InputType getOutputType(int layerIndex, InputType inputType) { public InputType getOutputType(int layerIndex, InputType inputType) {
if (inputType == null || inputType.getType() != InputType.Type.RNN) { if (inputType == null || inputType.getType() != InputType.Type.RNN) {
throw new IllegalStateException("Invalid input for Self Attention layer (layer index = " + layerIndex throw new IllegalStateException(
+ ", layer name = \"" + getName() + "\"): expect RNN input type with size > 0. Got: " "Invalid input for Self Attention layer (layer index = "
+ layerIndex
+ ", layer name = \""
+ getName()
+ "\"): expect RNN input type with size > 0. Got: "
+ inputType); + inputType);
} }
InputType.InputTypeRecurrent itr = (InputType.InputTypeRecurrent) inputType; InputType.InputTypeRecurrent itr = (InputType.InputTypeRecurrent) inputType;
if(projectInput){ if (projectInput) {
return InputType.recurrent(nOut, itr.getTimeSeriesLength()); return InputType.recurrent(nOut, itr.getTimeSeriesLength());
}else{ } else {
return InputType.recurrent(nIn, itr.getTimeSeriesLength()); return InputType.recurrent(nIn, itr.getTimeSeriesLength());
} }
} }
@ -100,7 +102,7 @@ public class SelfAttentionLayer extends SameDiffLayer {
public void defineParameters(SDLayerParams params) { public void defineParameters(SDLayerParams params) {
params.clear(); params.clear();
if(projectInput){ if (projectInput) {
params.addWeightParam(WEIGHT_KEY_QUERY_PROJECTION, nHeads, headSize, nIn); params.addWeightParam(WEIGHT_KEY_QUERY_PROJECTION, nHeads, headSize, nIn);
params.addWeightParam(WEIGHT_KEY_KEY_PROJECTION, nHeads, headSize, nIn); params.addWeightParam(WEIGHT_KEY_KEY_PROJECTION, nHeads, headSize, nIn);
params.addWeightParam(WEIGHT_KEY_VALUE_PROJECTION, nHeads, headSize, nIn); params.addWeightParam(WEIGHT_KEY_VALUE_PROJECTION, nHeads, headSize, nIn);
@ -112,108 +114,52 @@ public class SelfAttentionLayer extends SameDiffLayer {
public void initializeParameters(Map<String, INDArray> params) { public void initializeParameters(Map<String, INDArray> params) {
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
for (Map.Entry<String, INDArray> e : params.entrySet()) { for (Map.Entry<String, INDArray> e : params.entrySet()) {
if(e.getKey().equals(WEIGHT_KEY_OUT_PROJECTION)){ if (e.getKey().equals(WEIGHT_KEY_OUT_PROJECTION)) {
WeightInitUtil.initWeights(nIn, headSize, e.getValue().shape(), weightInit, null, 'c', e.getValue()); WeightInitUtil.initWeights(
}else{ nIn, headSize, e.getValue().shape(), weightInit, null, 'c', e.getValue());
WeightInitUtil.initWeights(nHeads * headSize, nOut, e.getValue().shape(), weightInit, null, 'c', e.getValue()); } else {
WeightInitUtil.initWeights(
nHeads * headSize, nOut, e.getValue().shape(), weightInit, null, 'c', e.getValue());
} }
} }
} }
} }
@Override @Override
public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, Map<String, SDVariable> paramTable, SDVariable mask) { public SDVariable defineLayer(
if(projectInput){ SameDiff sameDiff,
SDVariable layerInput,
Map<String, SDVariable> paramTable,
SDVariable mask) {
if (projectInput) {
val Wq = paramTable.get(WEIGHT_KEY_QUERY_PROJECTION); val Wq = paramTable.get(WEIGHT_KEY_QUERY_PROJECTION);
val Wk = paramTable.get(WEIGHT_KEY_KEY_PROJECTION); val Wk = paramTable.get(WEIGHT_KEY_KEY_PROJECTION);
val Wv = paramTable.get(WEIGHT_KEY_VALUE_PROJECTION); val Wv = paramTable.get(WEIGHT_KEY_VALUE_PROJECTION);
val Wo = paramTable.get(WEIGHT_KEY_OUT_PROJECTION); val Wo = paramTable.get(WEIGHT_KEY_OUT_PROJECTION);
return sameDiff.nn.multiHeadDotProductAttention(getName(), layerInput, layerInput, layerInput, Wq, Wk, Wv, Wo, mask, true); return sameDiff.nn.multiHeadDotProductAttention(
}else{ getName(), layerInput, layerInput, layerInput, Wq, Wk, Wv, Wo, mask, true);
return sameDiff.nn.dotProductAttention(getName(), layerInput, layerInput, layerInput, mask, true); } else {
return sameDiff.nn.dotProductAttention(
getName(), layerInput, layerInput, layerInput, mask, true);
} }
} }
public abstract static class SelfAttentionLayerBuilder<
C extends SelfAttentionLayer, B extends SelfAttentionLayerBuilder<C, B>>
extends SameDiffLayerBuilder<C, B> {
public C build() {
Preconditions.checkArgument(
this.projectInput || this.nHeads == 1, "projectInput must be true when nHeads != 1");
Preconditions.checkArgument(
this.projectInput || nIn == nOut, "nIn must be equal to nOut when projectInput is false");
Preconditions.checkArgument(
!this.projectInput || nOut != 0, "nOut must be specified when projectInput is true");
Preconditions.checkArgument(
this.nOut % nHeads == 0 || headSize > 0,
"nOut isn't divided by nHeads cleanly. Specify the headSize manually.");
@Getter return initBuild();
@Setter
public static class Builder extends SameDiffLayer.Builder<SelfAttentionLayer.Builder> {
/**
* Number of inputs to the layer (input size)
*/
private int nIn;
/**
* Number of outputs (output size)
*/
private int nOut;
/**
* Number of Attention Heads
*/
private int nHeads;
/**
* Size of attention heads
*/
private int headSize;
/**
* Project input before applying attention or not.
*/
private boolean projectInput;
/**
* @param nIn Number of inputs to the layer (input size)
*/
public Builder nIn(int nIn) {
this.nIn = nIn;
return this;
}
/**
* @param nOut Number of outputs (output size)
*/
public Builder nOut(int nOut) {
this.nOut = nOut;
return this;
}
/**
* Number of Attention Heads
*/
public Builder nHeads(int nHeads){
this.nHeads = nHeads;
return this;
}
/**
* Size of attention heads
*/
public Builder headSize(int headSize){
this.headSize = headSize;
return this;
}
/**
* Project input before applying attention or not.
*/
public Builder projectInput(boolean projectInput){
this.projectInput = projectInput;
return this;
}
@Override
@SuppressWarnings("unchecked")
public SelfAttentionLayer build() {
Preconditions.checkArgument(this.projectInput || this.nHeads == 1, "projectInput must be true when nHeads != 1");
Preconditions.checkArgument(this.projectInput || nIn == nOut, "nIn must be equal to nOut when projectInput is false");
Preconditions.checkArgument(!this.projectInput || nOut != 0, "nOut must be specified when projectInput is true");
Preconditions.checkArgument(this.nOut % nHeads == 0 || headSize > 0, "nOut isn't divided by nHeads cleanly. Specify the headSize manually.");
return new SelfAttentionLayer(this);
} }
} }
} }
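With the hand-written inner Builder removed, a SelfAttentionLayer is configured through the Lombok-generated fluent builder. A minimal usage sketch, assuming the class carries @SuperBuilder like the other LayerConfigurations in this commit and that the nIn/nOut/nHeads/projectInput setters keep the field names shown above:

// Hedged sketch, not part of the diff: the overridden build() above still enforces the
// projectInput/nHeads/nOut preconditions before delegating to initBuild().
SelfAttentionLayer attention = SelfAttentionLayer.builder()
        .nIn(128)            // RNN input size; also inferred by setNIn(...) from the input type
        .nOut(128)
        .nHeads(4)           // nOut must be divisible by nHeads unless headSize is set explicitly
        .projectInput(true)  // required whenever nHeads != 1
        .build();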

View File

@ -63,7 +63,16 @@ public class SeparableConvolution2D extends ConvolutionLayer {
* @return Builder * @return Builder
*/ */
@Builder.Default private int depthMultiplier = 1; @Builder.Default private int depthMultiplier = 1;
/**
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
* See {@link CNN2DFormat} for more details.<br>
* Default: NCHW
*
* @param format Format for activations (in and out)
*/
@Builder.Default
protected CNN2DFormat dataFormat =
CNN2DFormat.NCHW; // default value for legacy serialization reasons
public static SeparableConvolution2DBuilder<?, ?> builder() { public static SeparableConvolution2DBuilder<?, ?> builder() {
return innerBuilder(); return innerBuilder();
} }
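The public builder() now simply exposes the generated innerBuilder(), so call sites keep compiling, and dataFormat picks up its NCHW default unless overridden. A rough sketch, assuming the usual nIn/nOut setters inherited from the convolution layer hierarchy are still available:

// Sketch under the assumptions above; depthMultiplier and dataFormat are the
// @Builder.Default fields added in this file.
SeparableConvolution2D sepConv = SeparableConvolution2D.builder()
        .nIn(3)
        .nOut(16)
        .depthMultiplier(2)              // default is 1
        .dataFormat(CNN2DFormat.NHWC)    // default is NCHW (kept for legacy serialization)
        .build();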

View File

@ -20,7 +20,10 @@
package org.deeplearning4j.nn.conf.layers; package org.deeplearning4j.nn.conf.layers;
import java.util.Collection;
import java.util.Map;
import lombok.*; import lombok.*;
import lombok.experimental.SuperBuilder;
import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.InputPreProcessor;
@ -35,27 +38,50 @@ import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Collection;
import java.util.Map;
@Data @Data
@NoArgsConstructor @NoArgsConstructor
@ToString(callSuper = true) @ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder(builderMethodName = "innerBuilder")
public class SpaceToBatchLayer extends NoParamLayer { public class SpaceToBatchLayer extends NoParamLayer {
/**
* Block size for SpaceToBatch layer. Should be a length 2 array for the height and width
* dimensions
*/
protected int[] blockSize;
/** A 2d array, with format [[padTop, padBottom], [padLeft, padRight]] */
@Builder.Default protected int[][] padding = new int[][] {{0, 0}, {0, 0}};
/**
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
* See {@link CNN2DFormat} for more details.<br>
* Default: NCHW
*
* @param format Format for activations (in and out)
*/
@Builder.Default protected CNN2DFormat dataFormat = CNN2DFormat.NCHW;
public static SpaceToBatchLayerBuilder<?, ?> builder() {
return innerBuilder();
}
// TODO: throw error when block and padding dims don't match // TODO: throw error when block and padding dims don't match
protected int[] blocks; /**
protected int[][] padding; * @param blocks Block size for SpaceToBatch layer. Should be a length 2 array for the height and
protected CNN2DFormat format = CNN2DFormat.NCHW; * width dimensions
*/
public static SpaceToBatchLayerBuilder<?, ?> builder(int[] blocks) {
return innerBuilder().blockSize(blocks);
}
/**
protected SpaceToBatchLayer(Builder builder) { * @param blocks Block size for SpaceToBatch layer. Should be a length 2 array for the height and
super(builder); * width dimensions
this.blocks = builder.blocks; * @param padding Padding - should be a 2d array, with format [[padTop, padBottom], [padLeft,
this.padding = builder.padding; * padRight]]
this.format = builder.format; */
public static SpaceToBatchLayerBuilder<?, ?> builder(int[] blocks, int[][] padding) {
return innerBuilder().blockSize(blocks).padding(padding);
} }
@Override @Override
@ -64,9 +90,13 @@ public class SpaceToBatchLayer extends NoParamLayer {
} }
@Override @Override
public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, public org.deeplearning4j.nn.api.Layer instantiate(
Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView, NeuralNetConfiguration conf,
boolean initializeParams, DataType networkDataType) { Collection<TrainingListener> trainingListeners,
int layerIndex,
INDArray layerParamsView,
boolean initializeParams,
DataType networkDataType) {
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
org.deeplearning4j.nn.layers.convolution.SpaceToBatch ret = org.deeplearning4j.nn.layers.convolution.SpaceToBatch ret =
@ -83,23 +113,31 @@ public class SpaceToBatchLayer extends NoParamLayer {
@Override @Override
public LayerMemoryReport getMemoryReport(InputType inputType) { public LayerMemoryReport getMemoryReport(InputType inputType) {
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType; InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
InputType.InputTypeConvolutional outputType = (InputType.InputTypeConvolutional) getOutputType(-1, inputType); InputType.InputTypeConvolutional outputType =
(InputType.InputTypeConvolutional) getOutputType(-1, inputType);
return new LayerMemoryReport.Builder(name, SpaceToBatchLayer.class, inputType, outputType) return new LayerMemoryReport.Builder(name, SpaceToBatchLayer.class, inputType, outputType)
.standardMemory(0, 0) //No params .standardMemory(0, 0) // No params
.cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) //No caching .cacheMemory(
MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) // No caching
.build(); .build();
} }
@Override @Override
public InputType getOutputType(int layerIndex, InputType inputType) { public InputType getOutputType(int layerIndex, InputType inputType) {
if (inputType == null || inputType.getType() != InputType.Type.CNN) { if (inputType == null || inputType.getType() != InputType.Type.CNN) {
throw new IllegalStateException("Invalid input for Subsampling layer (layer name=\"" + getName() throw new IllegalStateException(
+ "\"): Expected CNN input, got " + inputType); "Invalid input for Subsampling layer (layer name=\""
+ getName()
+ "\"): Expected CNN input, got "
+ inputType);
} }
InputType.InputTypeConvolutional i = (InputType.InputTypeConvolutional) inputType; InputType.InputTypeConvolutional i = (InputType.InputTypeConvolutional) inputType;
return InputType.convolutional((i.getHeight() + padding[0][0] + padding[0][1]) / blocks[0], return InputType.convolutional(
(i.getWidth() + padding[1][0] + padding[1][1]) / blocks[1], i.getChannels(), i.getFormat()); (i.getHeight() + padding[0][0] + padding[0][1]) / blockSize[0],
(i.getWidth() + padding[1][0] + padding[1][1]) / blockSize[1],
i.getChannels(),
i.getFormat());
} }
@Override @Override
@ -107,17 +145,21 @@ public class SpaceToBatchLayer extends NoParamLayer {
return EmptyParamInitializer.getInstance(); return EmptyParamInitializer.getInstance();
} }
@Override @Override
public void setNIn(InputType inputType, boolean override) { public void setNIn(InputType inputType, boolean override) {
Preconditions.checkState(inputType.getType() == InputType.Type.CNN, "Only CNN input types can be used with SpaceToBatchLayer, got %s", inputType); Preconditions.checkState(
this.format = ((InputType.InputTypeConvolutional)inputType).getFormat(); inputType.getType() == InputType.Type.CNN,
"Only CNN input types can be used with SpaceToBatchLayer, got %s",
inputType);
this.dataFormat = ((InputType.InputTypeConvolutional) inputType).getFormat();
} }
@Override @Override
public InputPreProcessor getPreProcessorForInputType(InputType inputType) { public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
if (inputType == null) { if (inputType == null) {
throw new IllegalStateException("Invalid input for space to batch layer (layer name=\"" + getName() throw new IllegalStateException(
"Invalid input for space to batch layer (layer name=\""
+ getName()
+ "\"): input is null"); + "\"): input is null");
} }
return InputTypeUtil.getPreProcessorForInputTypeCnnLayers(inputType, getName()); return InputTypeUtil.getPreProcessorForInputTypeCnnLayers(inputType, getName());
@ -128,102 +170,28 @@ public class SpaceToBatchLayer extends NoParamLayer {
throw new UnsupportedOperationException("SpaceToBatchLayer does not contain parameters"); throw new UnsupportedOperationException("SpaceToBatchLayer does not contain parameters");
} }
public abstract static class SpaceToBatchLayerBuilder<
@NoArgsConstructor C extends SpaceToBatchLayer, B extends SpaceToBatchLayerBuilder<C, B>>
@Getter extends NoParamLayerBuilder<C, B> {
@Setter
public static class Builder<T extends Builder<T>> extends LayerConfiguration.Builder<T> {
/** /**
* Block size for SpaceToBatch layer. Should be a length 2 array for the height and width * @param blocks Block size for SpaceToBatch layer. Should be a length 2 array for the height
* dimensions * and width dimensions
* @return
*/ */
@Setter(AccessLevel.NONE) public B blockSize(int... blocks) {
protected int[] blocks; this.blockSize = ValidationUtils.validate2NonNegative(blocks, false, "blocks");
return self();
/**
* A 2d array, with format [[padTop, padBottom], [padLeft, padRight]]
*/
protected int[][] padding;
protected CNN2DFormat format = CNN2DFormat.NCHW;
/**
* @param blocks Block size for SpaceToBatch layer. Should be a length 2 array for the height and width
* dimensions
*/
public void setBlocks(int... blocks) {
this.blocks = ValidationUtils.validate2NonNegative(blocks, false, "blocks");
} }
/** /**
* @param padding Padding - should be a 2d array, with format [[padTop, padBottom], [padLeft, padRight]] * @param padding Padding - should be a 2d array, with format [[padTop, padBottom], [padLeft,
* padRight]]
* @return
*/ */
public void setPadding(int[][] padding) { public B padding(int[][] padding) {
this.padding = ValidationUtils.validate2x2NonNegative(padding, "padding"); this.padding$value = ValidationUtils.validate2x2NonNegative(padding, "padding");
} this.padding$set = true;
return self();
/**
* @param blocks Block size for SpaceToBatch layer. Should be a length 2 array for the height and width
* dimensions
*/
public Builder(int[] blocks) {
this.setBlocks(blocks);
this.setPadding(new int[][] {{0, 0}, {0, 0}});
}
/**
* @param blocks Block size for SpaceToBatch layer. Should be a length 2 array for the height and width
* dimensions
* @param padding Padding - should be a 2d array, with format [[padTop, padBottom], [padLeft, padRight]]
*/
public Builder(int[] blocks, int[][] padding) {
this.setBlocks(blocks);
this.setPadding(padding);
}
/**
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
* See {@link CNN2DFormat} for more details.<br>
* Default: NCHW
* @param format Format for activations (in and out)
*/
public T dataFormat(CNN2DFormat format){
this.format = format;
return (T)this;
}
/**
* @param blocks Block size for SpaceToBatch layer. Should be a length 2 array for the height and width
* dimensions
*/
public T blocks(int... blocks) {
this.setBlocks(blocks);
return (T) this;
}
/**
* @param padding Padding - should be a 2d array, with format [[padTop, padBottom], [padLeft, padRight]]
*/
public T padding(int[][] padding) {
this.setPadding(padding);
return (T) this;
}
@Override
public T name(String layerName) {
this.setLayerName(layerName);
return (T) this;
}
@Override
@SuppressWarnings("unchecked")
public SpaceToBatchLayer build() {
if(padding == null)
setPadding(new int[][] {{0, 0}, {0, 0}});
return new SpaceToBatchLayer(this);
} }
} }
} }
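The two removed Builder constructors map onto the new static builder(int[]) and builder(int[], int[][]) factories, which pre-populate blockSize and padding on the SuperBuilder. A migration sketch (call sites assumed, not shown in this hunk):

// before: new SpaceToBatchLayer.Builder(new int[]{2, 2}).build()
SpaceToBatchLayer s2b = SpaceToBatchLayer.builder(new int[]{2, 2}).build();

// block size plus explicit [[padTop, padBottom], [padLeft, padRight]] padding
SpaceToBatchLayer s2bPadded = SpaceToBatchLayer
        .builder(new int[]{2, 2}, new int[][]{{1, 1}, {1, 1}})
        .dataFormat(CNN2DFormat.NHWC)   // @Builder.Default is NCHW
        .build();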

View File

@ -21,6 +21,7 @@
package org.deeplearning4j.nn.conf.layers; package org.deeplearning4j.nn.conf.layers;
import lombok.*; import lombok.*;
import lombok.experimental.SuperBuilder;
import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.InputPreProcessor;
@ -40,6 +41,7 @@ import java.util.Map;
@NoArgsConstructor @NoArgsConstructor
@ToString(callSuper = true) @ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true) @EqualsAndHashCode(callSuper = true)
@SuperBuilder
public class SpaceToDepthLayer extends NoParamLayer { public class SpaceToDepthLayer extends NoParamLayer {
/** /**
@ -53,16 +55,20 @@ public class SpaceToDepthLayer extends NoParamLayer {
return this == NCHW ? CNN2DFormat.NCHW : CNN2DFormat.NHWC; return this == NCHW ? CNN2DFormat.NCHW : CNN2DFormat.NHWC;
} }
} }
/**
* @param blockSize Block size
*/
protected int blockSize; protected int blockSize;
protected CNN2DFormat dataFormat; /**
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
* See {@link CNN2DFormat} for more details.<br>
* Default: NCHW
* @param dataFormat Format for activations (in and out)
*/
@Builder.Default
protected CNN2DFormat dataFormat = CNN2DFormat.NCHW;
protected SpaceToDepthLayer(Builder builder) {
super(builder);
this.setBlockSize(builder.blockSize);
this.setDataFormat(builder.dataFormat);
}
@Override @Override
public SpaceToDepthLayer clone() { public SpaceToDepthLayer clone() {
@ -74,7 +80,7 @@ public class SpaceToDepthLayer extends NoParamLayer {
Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView, Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView,
boolean initializeParams, DataType networkDataType) { boolean initializeParams, DataType networkDataType) {
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
runInheritance();
org.deeplearning4j.nn.layers.convolution.SpaceToDepth ret = org.deeplearning4j.nn.layers.convolution.SpaceToDepth ret =
new org.deeplearning4j.nn.layers.convolution.SpaceToDepth(lconf, networkDataType); new org.deeplearning4j.nn.layers.convolution.SpaceToDepth(lconf, networkDataType);
ret.addTrainingListeners(trainingListeners); ret.addTrainingListeners(trainingListeners);
@ -133,78 +139,5 @@ public class SpaceToDepthLayer extends NoParamLayer {
} }
@NoArgsConstructor
@Getter
@Setter
public static class Builder<T extends Builder<T>> extends LayerConfiguration.Builder<T> {
protected int blockSize;
/**
* Data format for input activations. Note DL4J uses NCHW in most cases
*/
protected CNN2DFormat dataFormat = CNN2DFormat.NCHW;
/**
* @param blockSize Block size
*/
public Builder(int blockSize) {
this.setBlockSize(blockSize);
}
/**
* @param blockSize Block size
* @param dataFormat Data format for input activations. Note DL4J uses NCHW in most cases
*/
@Deprecated
public Builder(int blockSize, DataFormat dataFormat) {
this(blockSize, dataFormat.toFormat());
}
public Builder(int blockSize, CNN2DFormat dataFormat) {
this.setBlockSize(blockSize);
this.setDataFormat(dataFormat);
}
/**
* @param blockSize Block size
*/
public T blocks(int blockSize) {
this.setBlockSize(blockSize);
return (T) this;
}
/**
* @param dataFormat Data format for input activations. Note DL4J uses NCHW in most cases
* @deprecated Use {@link #dataFormat(CNN2DFormat)}
*/
@Deprecated
public T dataFormat(DataFormat dataFormat) {
return dataFormat(dataFormat.toFormat());
}
/**
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
* See {@link CNN2DFormat} for more details.<br>
* Default: NCHW
* @param dataFormat Format for activations (in and out)
*/
public T dataFormat(CNN2DFormat dataFormat) {
this.setDataFormat(dataFormat);
return (T) this;
}
@Override
public T name(String layerName) {
this.setLayerName(layerName);
return (T) this;
}
@Override
@SuppressWarnings("unchecked")
public SpaceToDepthLayer build() {
return new SpaceToDepthLayer(this);
}
}
} }
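SpaceToDepthLayer loses its hand-written Builder entirely; blockSize and dataFormat are now plain @SuperBuilder fields. A minimal sketch of the assumed replacement for the removed Builder(int, CNN2DFormat) constructor:

// before: new SpaceToDepthLayer.Builder(2, CNN2DFormat.NHWC).build()
SpaceToDepthLayer s2d = SpaceToDepthLayer.builder()
        .blockSize(2)
        .dataFormat(CNN2DFormat.NHWC)   // default is NCHW
        .build();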

View File

@ -20,10 +20,14 @@
package org.deeplearning4j.nn.conf.layers.objdetect; package org.deeplearning4j.nn.conf.layers.objdetect;
import lombok.Data; import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import lombok.EqualsAndHashCode; import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import lombok.Getter; import java.util.Arrays;
import lombok.Setter; import java.util.Collection;
import java.util.List;
import java.util.Map;
import lombok.*;
import lombok.experimental.SuperBuilder;
import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.CNN2DFormat;
@ -41,44 +45,57 @@ import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.lossfunctions.ILossFunction; import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.impl.LossL2; import org.nd4j.linalg.lossfunctions.impl.LossL2;
import org.nd4j.serde.jackson.shaded.NDArrayTextSerializer; import org.nd4j.serde.jackson.shaded.NDArrayTextSerializer;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
@Data @Data
@EqualsAndHashCode(callSuper = false) @EqualsAndHashCode(callSuper = false)
@SuperBuilder(buildMethodName = "initBuild")
public class Yolo2OutputLayer extends LayerConfiguration { public class Yolo2OutputLayer extends LayerConfiguration {
private double lambdaCoord; /**
private double lambdaNoObj; * Loss function coefficient for position and size/scale components of the loss function. Default
private ILossFunction lossPositionScale; * (as per paper): 5
private ILossFunction lossClassPredictions; */
@Builder.Default private double lambdaCoord = 5;
/**
* Loss function coefficient for the "no object confidence" components of the loss function.
* Default (as per paper): 0.5
*/
@Builder.Default private double lambdaNoObj = 0.5;
/** Loss function for position/scale component of the loss function */
@Builder.Default private ILossFunction lossPositionScale = new LossL2();
/**
* Loss function for the class predictions - defaults to L2 loss (i.e., sum of squared errors, as
* per the paper), however Loss MCXENT could also be used (which is more common for
* classification).
*/
@Builder.Default private ILossFunction lossClassPredictions = new LossL2();
/**
* Bounding box priors dimensions [width, height]. For N bounding boxes, input has shape [rows,
* columns] = [N, 2] Note that dimensions should be specified as fraction of grid size. For
* example, a network with 13x13 output, a value of 1.0 would correspond to one grid cell; a value
* of 13 would correspond to the entire image.
*/
@JsonSerialize(using = NDArrayTextSerializer.class) @JsonSerialize(using = NDArrayTextSerializer.class)
@JsonDeserialize(using = BoundingBoxesDeserializer.class) @JsonDeserialize(using = BoundingBoxesDeserializer.class)
@Builder.Default
private INDArray boundingBoxes; private INDArray boundingBoxes;
private CNN2DFormat format = CNN2DFormat.NCHW; //Default for serialization of old formats @Builder.Default
private CNN2DFormat format = CNN2DFormat.NCHW; // Default for serialization of old formats
private Yolo2OutputLayer() { private Yolo2OutputLayer() {
//No-arg constructor for Jackson JSON // No-arg constructor for Jackson JSON
}
private Yolo2OutputLayer(Builder builder) {
super(builder);
this.lambdaCoord = builder.lambdaCoord;
this.lambdaNoObj = builder.lambdaNoObj;
this.lossPositionScale = builder.lossPositionScale;
this.lossClassPredictions = builder.lossClassPredictions;
this.boundingBoxes = builder.boundingBoxes;
} }
@Override @Override
public Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> trainingListeners, public Layer instantiate(
int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { NeuralNetConfiguration conf,
Collection<TrainingListener> trainingListeners,
int layerIndex,
INDArray layerParamsView,
boolean initializeParams,
DataType networkDataType) {
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer ret = org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer ret =
@ -99,7 +116,7 @@ public class Yolo2OutputLayer extends LayerConfiguration {
@Override @Override
public InputType getOutputType(int layerIndex, InputType inputType) { public InputType getOutputType(int layerIndex, InputType inputType) {
return inputType; //Same shape output as input return inputType; // Same shape output as input
} }
@Override @Override
@ -126,133 +143,41 @@ public class Yolo2OutputLayer extends LayerConfiguration {
@Override @Override
public List<Regularization> getRegularizationByParam(String paramName) { public List<Regularization> getRegularizationByParam(String paramName) {
//Not applicable // Not applicable
return null; return null;
} }
@Override @Override
public boolean isPretrainParam(String paramName) { public boolean isPretrainParam(String paramName) {
return false; //No params return false; // No params
} }
@Override @Override
public LayerMemoryReport getMemoryReport(InputType inputType) { public LayerMemoryReport getMemoryReport(InputType inputType) {
long numValues = inputType.arrayElementsPerExample(); long numValues = inputType.arrayElementsPerExample();
//This is a VERY rough estimate... // This is a VERY rough estimate...
return new LayerMemoryReport.Builder(name, Yolo2OutputLayer.class, inputType, inputType) return new LayerMemoryReport.Builder(name, Yolo2OutputLayer.class, inputType, inputType)
.standardMemory(0, 0) //No params .standardMemory(0, 0) // No params
.workingMemory(0, numValues, 0, 6 * numValues).cacheMemory(0, 0) //No cache .workingMemory(0, numValues, 0, 6 * numValues)
.cacheMemory(0, 0) // No cache
.build(); .build();
} }
@Getter public static abstract class Yolo2OutputLayerBuilder<
@Setter C extends Yolo2OutputLayer, B extends Yolo2OutputLayerBuilder<C, B>>
public static class Builder extends LayerConfiguration.Builder<Builder> { extends LayerConfigurationBuilder<C, B> {
public C build() {
/** if (boundingBoxes$value == null) {
* Loss function coefficient for position and size/scale components of the loss function. Default (as per
* paper): 5
*
*/
private double lambdaCoord = 5;
/**
* Loss function coefficient for the "no object confidence" components of the loss function. Default (as per
* paper): 0.5
*
*/
private double lambdaNoObj = 0.5;
/**
* Loss function for position/scale component of the loss function
*
*/
private ILossFunction lossPositionScale = new LossL2();
/**
* Loss function for the class predictions - defaults to L2 loss (i.e., sum of squared errors, as per the
* paper), however Loss MCXENT could also be used (which is more common for classification).
*
*/
private ILossFunction lossClassPredictions = new LossL2();
/**
* Bounding box priors dimensions [width, height]. For N bounding boxes, input has shape [rows, columns] = [N,
* 2] Note that dimensions should be specified as fraction of grid size. For example, a network with 13x13
* output, a value of 1.0 would correspond to one grid cell; a value of 13 would correspond to the entire
* image.
*
*/
private INDArray boundingBoxes;
/**
* Loss function coefficient for position and size/scale components of the loss function. Default (as per
* paper): 5
*
* @param lambdaCoord Lambda value for size/scale component of loss function
*/
public Builder lambdaCoord(double lambdaCoord) {
this.setLambdaCoord(lambdaCoord);
return this;
}
/**
* Loss function coefficient for the "no object confidence" components of the loss function. Default (as per
* paper): 0.5
*
* @param lambdaNoObj Lambda value for no-object (confidence) component of the loss function
*/
public Builder lambdaNoObj(double lambdaNoObj) {
this.setLambdaNoObj(lambdaNoObj);
return this;
}
/**
* Loss function for position/scale component of the loss function
*
* @param lossPositionScale Loss function for position/scale
*/
public Builder lossPositionScale(ILossFunction lossPositionScale) {
this.setLossPositionScale(lossPositionScale);
return this;
}
/**
* Loss function for the class predictions - defaults to L2 loss (i.e., sum of squared errors, as per the
* paper), however Loss MCXENT could also be used (which is more common for classification).
*
* @param lossClassPredictions Loss function for the class prediction error component of the YOLO loss function
*/
public Builder lossClassPredictions(ILossFunction lossClassPredictions) {
this.setLossClassPredictions(lossClassPredictions);
return this;
}
/**
* Bounding box priors dimensions [width, height]. For N bounding boxes, input has shape [rows, columns] = [N,
* 2] Note that dimensions should be specified as fraction of grid size. For example, a network with 13x13
* output, a value of 1.0 would correspond to one grid cell; a value of 13 would correspond to the entire
* image.
*
* @param boundingBoxes Bounding box prior dimensions (width, height)
*/
public Builder boundingBoxPriors(INDArray boundingBoxes) {
this.setBoundingBoxes(boundingBoxes);
return this;
}
@Override
public Yolo2OutputLayer build() {
if (boundingBoxes == null) {
throw new IllegalStateException("Bounding boxes have not been set"); throw new IllegalStateException("Bounding boxes have not been set");
} }
if (boundingBoxes.rank() != 2 || boundingBoxes.size(1) != 2) { if (boundingBoxes$value.rank() != 2 || boundingBoxes$value.size(1) != 2) {
throw new IllegalStateException("Bounding box priors must have shape [nBoxes, 2]. Has shape: " throw new IllegalStateException(
+ Arrays.toString(boundingBoxes.shape())); "Bounding box priors must have shape [nBoxes, 2]. Has shape: "
+ Arrays.toString(boundingBoxes$value.shape()));
} }
return initBuild();
return new Yolo2OutputLayer(this);
} }
} }
} }
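Defaults that previously lived in the old Builder (lambdaCoord = 5, lambdaNoObj = 0.5, L2 losses) are now @Builder.Default field initializers, and the priors go through the generated boundingBoxes(...) setter rather than the removed boundingBoxPriors(...). A usage sketch, with the priors array shaped as the overridden build() expects:

// Hedged sketch: [nBoxes, 2] priors, expressed as fractions of the output grid size.
INDArray priors = Nd4j.create(new double[][]{{1.0, 1.3}, {2.5, 3.0}});

Yolo2OutputLayer yolo = Yolo2OutputLayer.builder()
        .boundingBoxes(priors)           // build() fails fast if this is missing or not [N, 2]
        .lambdaCoord(5.0)                // defaults shown above; set here only for illustration
        .lambdaNoObj(0.5)
        .lossPositionScale(new LossL2())
        .build();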

View File

@ -48,7 +48,7 @@ public class SimpleRnn extends BaseRecurrentLayer {
* If true (default = false): enable layer normalization on this layer * If true (default = false): enable layer normalization on this layer
* *
*/ */
@lombok.Builder.Default @Accessors @lombok.Builder.Default @Accessors @Getter
private boolean hasLayerNorm = false; private boolean hasLayerNorm = false;
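With @Getter added, hasLayerNorm is readable from the configuration as well as settable through the builder. A one-line sketch, assuming SimpleRnn follows the same SuperBuilder pattern as the other layers in this commit:

// hasLayerNorm defaults to false via @Builder.Default
SimpleRnn rnn = SimpleRnn.builder().nIn(10).nOut(10).hasLayerNorm(true).build();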

View File

@ -20,6 +20,9 @@
package org.deeplearning4j.nn.conf.layers.samediff; package org.deeplearning4j.nn.conf.layers.samediff;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import lombok.*; import lombok.*;
import lombok.experimental.SuperBuilder; import lombok.experimental.SuperBuilder;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@ -44,11 +47,6 @@ import org.nd4j.linalg.learning.regularization.L2Regularization;
import org.nd4j.linalg.learning.regularization.Regularization; import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.learning.regularization.WeightDecay; import org.nd4j.linalg.learning.regularization.WeightDecay;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
@Slf4j @Slf4j
@Data @Data
@EqualsAndHashCode(callSuper = true, doNotUseGetters = true) @EqualsAndHashCode(callSuper = true, doNotUseGetters = true)
@ -59,15 +57,17 @@ public abstract class AbstractSameDiffLayer extends LayerConfiguration {
/** /**
* The regularization for the parameters (excluding biases) - for example {@link WeightDecay} * The regularization for the parameters (excluding biases) - for example {@link WeightDecay}
* *
* -- SETTER -- * <p>-- SETTER -- Set the regularization for the parameters (excluding biases) - for example
* Set the regularization for the parameters (excluding biases) - for example {@link WeightDecay} * {@link WeightDecay}
* @param regularization Regularization to apply for the network parameters/weights (excluding biases) *
* @param regularization Regularization to apply for the network parameters/weights (excluding
* biases)
*/ */
protected List<Regularization> regularization; protected List<Regularization> regularization;
/** /**
* The regularization for the biases only - for example {@link WeightDecay} * The regularization for the biases only - for example {@link WeightDecay} -- SETTER -- Set the
* -- SETTER -- * regularization for the biases only - for example {@link WeightDecay}
* Set the regularization for the biases only - for example {@link WeightDecay} *
* @param regularizationBias Regularization to apply for the network biases only * @param regularizationBias Regularization to apply for the network biases only
*/ */
protected List<Regularization> regularizationBias; protected List<Regularization> regularizationBias;
@ -79,14 +79,13 @@ public abstract class AbstractSameDiffLayer extends LayerConfiguration {
*/ */
protected @Getter @Setter IUpdater updater; protected @Getter @Setter IUpdater updater;
/** /**
* Gradient updater configuration, for the biases only. If not set, biases will use the updater as set by {@link * Gradient updater configuration, for the biases only. If not set, biases will use the updater as
* #setUpdater(IUpdater)} * set by {@link #setUpdater(IUpdater)}
* *
* @param biasUpdater Updater to use for bias parameters * @param biasUpdater Updater to use for bias parameters
*/ */
protected @Getter @Setter IUpdater biasUpdater; protected @Getter @Setter IUpdater biasUpdater;
protected GradientNormalization gradientNormalization; protected GradientNormalization gradientNormalization;
protected double gradientNormalizationThreshold = Double.NaN; protected double gradientNormalizationThreshold = Double.NaN;
@ -94,9 +93,9 @@ public abstract class AbstractSameDiffLayer extends LayerConfiguration {
@Override @Override
public List<Regularization> getRegularizationByParam(String paramName) { public List<Regularization> getRegularizationByParam(String paramName) {
if(layerParams.isWeightParam(paramName)){ if (layerParams.isWeightParam(paramName)) {
return regularization; return regularization;
} else if(layerParams.isBiasParam(paramName)){ } else if (layerParams.isBiasParam(paramName)) {
return regularizationBias; return regularizationBias;
} }
return null; return null;
@ -112,23 +111,23 @@ public abstract class AbstractSameDiffLayer extends LayerConfiguration {
@Override @Override
public void setNIn(InputType inputType, boolean override) { public void setNIn(InputType inputType, boolean override) {
//Default implementation: no-op // Default implementation: no-op
} }
@Override @Override
public InputPreProcessor getPreProcessorForInputType(InputType inputType) { public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
//Default implementation: no-op // Default implementation: no-op
return null; return null;
} }
public void applyGlobalConfigToLayer(
public void applyGlobalConfigToLayer(NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) { NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) {
//Default implementation: no op // Default implementation: no op
} }
/** /**
* Define the parameters for the network. Use {@link SDLayerParams#addWeightParam(String, long...)} and {@link * Define the parameters for the network. Use {@link SDLayerParams#addWeightParam(String,
* SDLayerParams#addBiasParam(String, long...)} * long...)} and {@link SDLayerParams#addBiasParam(String, long...)}
* *
* @param params Object used to set parameters for this layer * @param params Object used to set parameters for this layer
*/ */
@ -142,11 +141,15 @@ public abstract class AbstractSameDiffLayer extends LayerConfiguration {
public abstract void initializeParameters(Map<String, INDArray> params); public abstract void initializeParameters(Map<String, INDArray> params);
@Override @Override
public abstract org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, public abstract org.deeplearning4j.nn.api.Layer instantiate(
Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView, NeuralNetConfiguration conf,
boolean initializeParams, DataType networkDataType); Collection<TrainingListener> trainingListeners,
int layerIndex,
INDArray layerParamsView,
boolean initializeParams,
DataType networkDataType);
//================================================================================================================== // ==================================================================================================================
@Override @Override
public ParamInitializer initializer() { public ParamInitializer initializer() {
@ -157,7 +160,8 @@ public abstract class AbstractSameDiffLayer extends LayerConfiguration {
public IUpdater getUpdaterByParam(String paramName) { public IUpdater getUpdaterByParam(String paramName) {
if (biasUpdater != null && initializer().isBiasParam(this, paramName)) { if (biasUpdater != null && initializer().isBiasParam(this, paramName)) {
return biasUpdater; return biasUpdater;
} else if (initializer().isBiasParam(this, paramName) || initializer().isWeightParam(this, paramName)) { } else if (initializer().isBiasParam(this, paramName)
|| initializer().isWeightParam(this, paramName)) {
return updater; return updater;
} }
throw new IllegalStateException("Unknown parameter key: " + paramName); throw new IllegalStateException("Unknown parameter key: " + paramName);
@ -170,12 +174,12 @@ public abstract class AbstractSameDiffLayer extends LayerConfiguration {
@Override @Override
public LayerMemoryReport getMemoryReport(InputType inputType) { public LayerMemoryReport getMemoryReport(InputType inputType) {
return new LayerMemoryReport(); //TODO return new LayerMemoryReport(); // TODO
} }
/** /**
* Returns the memory layout ('c' or 'f' order - i.e., row/column major) of the parameters. In most cases, this * Returns the memory layout ('c' or 'f' order - i.e., row/column major) of the parameters. In
* can/should be left * most cases, this can/should be left
* *
* @param param Name of the parameter * @param param Name of the parameter
* @return Memory layout ('c' or 'f') of the parameter * @return Memory layout ('c' or 'f') of the parameter
@ -185,7 +189,8 @@ public abstract class AbstractSameDiffLayer extends LayerConfiguration {
} }
protected void initWeights(int fanIn, int fanOut, WeightInit weightInit, INDArray array) { protected void initWeights(int fanIn, int fanOut, WeightInit weightInit, INDArray array) {
WeightInitUtil.initWeights(fanIn, fanOut, array.shape(), weightInit, null, paramReshapeOrder(null), array); WeightInitUtil.initWeights(
fanIn, fanOut, array.shape(), weightInit, null, paramReshapeOrder(null), array);
} }
public void applyGlobalConfig(NeuralNetConfiguration.NeuralNetConfigurationBuilder b) { public void applyGlobalConfig(NeuralNetConfiguration.NeuralNetConfigurationBuilder b) {
@ -213,63 +218,74 @@ public abstract class AbstractSameDiffLayer extends LayerConfiguration {
} }
/** /**
* This method generates an "all ones" mask array for use in the SameDiff model when none is provided. * This method generates an "all ones" mask array for use in the SameDiff model when none is
* provided.
*
* @param input Input to the layer * @param input Input to the layer
* @return A mask array - should be same datatype as the input (usually) * @return A mask array - should be same datatype as the input (usually)
*/ */
public INDArray onesMaskForInput(INDArray input){ public INDArray onesMaskForInput(INDArray input) {
if(input.rank() == 2){ if (input.rank() == 2) {
return Nd4j.ones(input.dataType(), input.size(0), 1); return Nd4j.ones(input.dataType(), input.size(0), 1);
} else if(input.rank() == 3){ } else if (input.rank() == 3) {
return Nd4j.ones(input.dataType(), input.size(0), input.size(2)); //mask: [mb, length] vs. input [mb, nIn, length] return Nd4j.ones(
} else if(input.rank() == 4){ input.dataType(),
//CNN style - return [mb, 1, 1, 1] for broadcast... input.size(0),
input.size(2)); // mask: [mb, length] vs. input [mb, nIn, length]
} else if (input.rank() == 4) {
// CNN style - return [mb, 1, 1, 1] for broadcast...
return Nd4j.ones(input.dataType(), input.size(0), 1, 1, 1); return Nd4j.ones(input.dataType(), input.size(0), 1, 1, 1);
} else if(input.rank() == 5){ } else if (input.rank() == 5) {
//CNN3D style - return [mb, 1, 1, 1, 1] for broadcast... // CNN3D style - return [mb, 1, 1, 1, 1] for broadcast...
return Nd4j.ones(input.dataType(), input.size(0), 1, 1, 1, 1); return Nd4j.ones(input.dataType(), input.size(0), 1, 1, 1, 1);
} else { } else {
throw new IllegalStateException("When using masking with rank 1 or 6+ inputs, the onesMaskForInput method must be implemented, " + throw new IllegalStateException(
"in order to determine the correct mask shape for this layer"); "When using masking with rank 1 or 6+ inputs, the onesMaskForInput method must be implemented, "
+ "in order to determine the correct mask shape for this layer");
} }
} }
public static abstract class AbstractSameDiffLayerBuilder<C extends AbstractSameDiffLayer, B extends AbstractSameDiffLayerBuilder<C, B>> { public abstract static class AbstractSameDiffLayerBuilder<
C extends AbstractSameDiffLayer, B extends AbstractSameDiffLayerBuilder<C, B>>
extends LayerConfigurationBuilder<C, B> {
/** /**
* L1 regularization coefficient (weights only). Use {@link #l1Bias(double)} to configure the l1 regularization * L1 regularization coefficient (weights only). Use {@link #l1Bias(double)} to configure the l1
* coefficient for the bias. * regularization coefficient for the bias.
*/ */
public B l1(double l1) { public B l1(double l1) {
//Check if existing L1 exists; if so, replace it // Check if existing L1 exists; if so, replace it
NetworkUtils.removeInstances(this.regularization, L1Regularization.class); NetworkUtils.removeInstances(this.regularization, L1Regularization.class);
if(l1 > 0.0) { if (l1 > 0.0) {
this.regularization.add(new L1Regularization(l1)); this.regularization.add(new L1Regularization(l1));
} }
return self(); return self();
} }
/** /**
* L2 regularization coefficient (weights only). Use {@link #l2Bias(double)} to configure the l2 regularization * L2 regularization coefficient (weights only). Use {@link #l2Bias(double)} to configure the l2
* coefficient for the bias.<br> * regularization coefficient for the bias.<br>
* <b>Note</b>: Generally, {@link WeightDecay} (set via {@link #weightDecay(double,boolean)} should be preferred to * <b>Note</b>: Generally, {@link WeightDecay} (set via {@link #weightDecay(double,boolean)}
* L2 regularization. See {@link WeightDecay} javadoc for further details.<br> * should be preferred to L2 regularization. See {@link WeightDecay} javadoc for further
* details.<br>
*/ */
public B l2(double l2) { public B l2(double l2) {
//Check if existing L2 exists; if so, replace it. Also remove weight decay - it doesn't make sense to use both // Check if existing L2 exists; if so, replace it. Also remove weight decay - it doesn't make
// sense to use both
NetworkUtils.removeInstances(this.regularization, L2Regularization.class); NetworkUtils.removeInstances(this.regularization, L2Regularization.class);
if(l2 > 0.0) { if (l2 > 0.0) {
NetworkUtils.removeInstancesWithWarning(this.regularization, WeightDecay.class, "WeightDecay regularization removed: incompatible with added L2 regularization"); NetworkUtils.removeInstancesWithWarning(
this.regularization,
WeightDecay.class,
"WeightDecay regularization removed: incompatible with added L2 regularization");
this.regularization.add(new L2Regularization(l2)); this.regularization.add(new L2Regularization(l2));
} }
return self(); return self();
} }
/** /** L1 regularization coefficient for the bias. Default: 0. See also {@link #l1(double)} */
* L1 regularization coefficient for the bias. Default: 0. See also {@link #l1(double)}
*/
public B l1Bias(double l1Bias) { public B l1Bias(double l1Bias) {
NetworkUtils.removeInstances(this.regularizationBias, L1Regularization.class); NetworkUtils.removeInstances(this.regularizationBias, L1Regularization.class);
if(l1Bias > 0.0) { if (l1Bias > 0.0) {
this.regularizationBias.add(new L1Regularization(l1Bias)); this.regularizationBias.add(new L1Regularization(l1Bias));
} }
return self(); return self();
@ -277,13 +293,17 @@ public abstract class AbstractSameDiffLayer extends LayerConfiguration {
/** /**
* L2 regularization coefficient for the bias. Default: 0. See also {@link #l2(double)}<br> * L2 regularization coefficient for the bias. Default: 0. See also {@link #l2(double)}<br>
* <b>Note</b>: Generally, {@link WeightDecay} (set via {@link #weightDecayBias(double,boolean)} should be preferred to * <b>Note</b>: Generally, {@link WeightDecay} (set via {@link #weightDecayBias(double,boolean)}
* L2 regularization. See {@link WeightDecay} javadoc for further details.<br> * should be preferred to L2 regularization. See {@link WeightDecay} javadoc for further
* details.<br>
*/ */
public B l2Bias(double l2Bias) { public B l2Bias(double l2Bias) {
NetworkUtils.removeInstances(this.regularizationBias, L2Regularization.class); NetworkUtils.removeInstances(this.regularizationBias, L2Regularization.class);
if(l2Bias > 0.0) { if (l2Bias > 0.0) {
NetworkUtils.removeInstancesWithWarning(this.regularizationBias, WeightDecay.class, "WeightDecay bias regularization removed: incompatible with added L2 regularization"); NetworkUtils.removeInstancesWithWarning(
this.regularizationBias,
WeightDecay.class,
"WeightDecay bias regularization removed: incompatible with added L2 regularization");
this.regularizationBias.add(new L2Regularization(l2Bias)); this.regularizationBias.add(new L2Regularization(l2Bias));
} }
return self(); return self();
@ -291,7 +311,8 @@ public abstract class AbstractSameDiffLayer extends LayerConfiguration {
/** /**
* Add weight decay regularization for the network parameters (excluding biases).<br> * Add weight decay regularization for the network parameters (excluding biases).<br>
* This applies weight decay <i>with</i> multiplying the learning rate - see {@link WeightDecay} for more details.<br> * This applies weight decay <i>with</i> multiplying the learning rate - see {@link WeightDecay}
* for more details.<br>
* *
* @param coefficient Weight decay regularization coefficient * @param coefficient Weight decay regularization coefficient
* @see #weightDecay(double, boolean) * @see #weightDecay(double, boolean)
@ -301,25 +322,31 @@ public abstract class AbstractSameDiffLayer extends LayerConfiguration {
} }
/** /**
* Add weight decay regularization for the network parameters (excluding biases). See {@link WeightDecay} for more details.<br> * Add weight decay regularization for the network parameters (excluding biases). See {@link
* WeightDecay} for more details.<br>
* *
* @param coefficient Weight decay regularization coefficient * @param coefficient Weight decay regularization coefficient
* @param applyLR Whether the learning rate should be multiplied in when performing weight decay updates. See {@link WeightDecay} for more details. * @param applyLR Whether the learning rate should be multiplied in when performing weight decay
* updates. See {@link WeightDecay} for more details.
* @see #weightDecay(double, boolean) * @see #weightDecay(double, boolean)
*/ */
public B weightDecay(double coefficient, boolean applyLR) { public B weightDecay(double coefficient, boolean applyLR) {
//Check if existing weight decay if it exists; if so, replace it. Also remove L2 - it doesn't make sense to use both // Check if existing weight decay if it exists; if so, replace it. Also remove L2 - it doesn't
// make sense to use both
NetworkUtils.removeInstances(this.regularization, WeightDecay.class); NetworkUtils.removeInstances(this.regularization, WeightDecay.class);
if(coefficient > 0.0) { if (coefficient > 0.0) {
NetworkUtils.removeInstancesWithWarning(this.regularization, L2Regularization.class, "L2 regularization removed: incompatible with added WeightDecay regularization"); NetworkUtils.removeInstancesWithWarning(
this.regularization,
L2Regularization.class,
"L2 regularization removed: incompatible with added WeightDecay regularization");
this.regularization.add(new WeightDecay(coefficient, applyLR)); this.regularization.add(new WeightDecay(coefficient, applyLR));
} }
return self(); return self();
} }
/** /**
* Weight decay for the biases only - see {@link #weightDecay(double)} for more details. * Weight decay for the biases only - see {@link #weightDecay(double)} for more details. This
* This applies weight decay <i>with</i> multiplying the learning rate.<br> * applies weight decay <i>with</i> multiplying the learning rate.<br>
* *
* @param coefficient Weight decay regularization coefficient * @param coefficient Weight decay regularization coefficient
* @see #weightDecayBias(double, boolean) * @see #weightDecayBias(double, boolean)
@ -334,10 +361,14 @@ public abstract class AbstractSameDiffLayer extends LayerConfiguration {
* @param coefficient Weight decay regularization coefficient * @param coefficient Weight decay regularization coefficient
*/ */
public B weightDecayBias(double coefficient, boolean applyLR) { public B weightDecayBias(double coefficient, boolean applyLR) {
//Check if existing weight decay if it exists; if so, replace it. Also remove L2 - it doesn't make sense to use both // Check if existing weight decay if it exists; if so, replace it. Also remove L2 - it doesn't
// make sense to use both
NetworkUtils.removeInstances(this.regularizationBias, WeightDecay.class); NetworkUtils.removeInstances(this.regularizationBias, WeightDecay.class);
if(coefficient > 0.0) { if (coefficient > 0.0) {
NetworkUtils.removeInstancesWithWarning(this.regularizationBias, L2Regularization.class, "L2 bias regularization removed: incompatible with added WeightDecay regularization"); NetworkUtils.removeInstancesWithWarning(
this.regularizationBias,
L2Regularization.class,
"L2 bias regularization removed: incompatible with added WeightDecay regularization");
this.regularizationBias.add(new WeightDecay(coefficient, applyLR)); this.regularizationBias.add(new WeightDecay(coefficient, applyLR));
} }
return self(); return self();
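The regularization helpers now live on the shared AbstractSameDiffLayerBuilder, so every SameDiff layer builder inherits the same replace-don't-stack semantics for L2 versus weight decay. A sketch using an assumed concrete subclass:

// Hedged sketch: l2(...) and weightDecay(...) each remove the other's entry
// (see removeInstancesWithWarning above) rather than applying both at once.
SelfAttentionLayer layer = SelfAttentionLayer.builder()
        .nIn(64).nOut(64).nHeads(4).projectInput(true)
        .l2(1e-4)                  // adds L2Regularization for the weights
        .l1Bias(1e-5)              // bias-only L1
        .weightDecay(1e-5, true)   // replaces the earlier L2 entry with WeightDecay
        .build();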

View File

@ -176,7 +176,7 @@ public abstract class SameDiffVertex extends GraphVertex implements ITraininable
} }
@Override @Override
public String getLayerName() { public String getName() {
return name; return name;
} }

View File

@ -285,5 +285,14 @@ public class VariationalAutoencoder extends BasePretrainNetwork {
super.nOut(nOut); super.nOut(nOut);
return self(); return self();
} }
public B pzxActivationFunction(IActivation activation) {
this.pzxActivationFunction$value = activation;
this.pzxActivationFunction$set = true;
return self();
}
public B pzxActivationFunction(Activation activation) {
return this.pzxActivationFunction(activation.getActivationFunction());
}
} }
} }
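The two pzxActivationFunction overloads let callers pass either the Activation enum or a raw IActivation; both write the same $value/$set pair that @Builder.Default generates. A sketch of the assumed call site:

VariationalAutoencoder vae = VariationalAutoencoder.builder()
        .nIn(784)
        .nOut(32)
        .encoderLayerSizes(256, 64)                   // assumed unchanged varargs setters
        .decoderLayerSizes(64, 256)
        .pzxActivationFunction(Activation.IDENTITY)   // enum overload added above
        .build();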

View File

@ -107,7 +107,7 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
INDArray origInput = input; INDArray origInput = input;
INDArray origEps = epsilon; INDArray origEps = epsilon;
if(getTypedLayerConfiguration().getDataFormat() != CNN2DFormat.NCHW) { if(getTypedLayerConfiguration().getConvFormat() != CNN2DFormat.NCHW) {
input = input.permute(0,3,1,2); //NHWC to NCHW input = input.permute(0,3,1,2); //NHWC to NCHW
epsilon = epsilon.permute(0,3,1,2); //NHWC to NCHW epsilon = epsilon.permute(0,3,1,2); //NHWC to NCHW
} }
@ -151,7 +151,7 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
Pair<INDArray, INDArray> p = preOutput4d(true, true, workspaceMgr); Pair<INDArray, INDArray> p = preOutput4d(true, true, workspaceMgr);
INDArray z = p.getFirst(); INDArray z = p.getFirst();
CNN2DFormat f = getTypedLayerConfiguration().getDataFormat(); CNN2DFormat f = getTypedLayerConfiguration().getConvFormat();
if(f != CNN2DFormat.NCHW){ if(f != CNN2DFormat.NCHW){
z = z.permute(0,3,1,2); //NHWC to NCHW z = z.permute(0,3,1,2); //NHWC to NCHW
} }
@ -159,7 +159,7 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
if (helper != null && (helperCountFail == 0 || !getTypedLayerConfiguration().isCudnnAllowFallback())) { if (helper != null && (helperCountFail == 0 || !getTypedLayerConfiguration().isCudnnAllowFallback())) {
INDArray helperDelta = delta; INDArray helperDelta = delta;
if(getTypedLayerConfiguration().getDataFormat() == CNN2DFormat.NHWC) if(getTypedLayerConfiguration().getConvFormat() == CNN2DFormat.NHWC)
helperDelta = delta.permute(0,2,3,1); //NCHW to NHWC helperDelta = delta.permute(0,2,3,1); //NCHW to NHWC
if(!hasBias() && !(helper instanceof MKLDNNConvHelper)){ if(!hasBias() && !(helper instanceof MKLDNNConvHelper)){
@ -177,7 +177,7 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
ret = helper.backpropGradient(origInput, weights, bias, helperDelta, kernel, strides, ret = helper.backpropGradient(origInput, weights, bias, helperDelta, kernel, strides,
pad, biasGradView, weightGradView, afn, pad, biasGradView, weightGradView, afn,
getTypedLayerConfiguration().getCudnnAlgoMode(), getTypedLayerConfiguration().getCudnnBwdFilterAlgo(), getTypedLayerConfiguration().getCudnnBwdDataAlgo(), getTypedLayerConfiguration().getCudnnAlgoMode(), getTypedLayerConfiguration().getCudnnBwdFilterAlgo(), getTypedLayerConfiguration().getCudnnBwdDataAlgo(),
convolutionMode, dilation, getTypedLayerConfiguration().getDataFormat(), workspaceMgr); convolutionMode, dilation, getTypedLayerConfiguration().getConvFormat(), workspaceMgr);
} catch (ND4JOpProfilerException e){ } catch (ND4JOpProfilerException e){
throw e; //NaN panic etc for debugging throw e; //NaN panic etc for debugging
} catch (Exception e){ } catch (Exception e){
@ -261,7 +261,7 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
epsNext = backpropDropOutIfPresent(epsNext); epsNext = backpropDropOutIfPresent(epsNext);
if(getTypedLayerConfiguration().getDataFormat() != CNN2DFormat.NCHW){ if(getTypedLayerConfiguration().getConvFormat()!= CNN2DFormat.NCHW){
epsNext = epsNext.permute(0,2,3,1); //NCHW to NHWC epsNext = epsNext.permute(0,2,3,1); //NCHW to NHWC
} }
@ -295,7 +295,7 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
} }
protected void validateInputDepth(long inDepth) { protected void validateInputDepth(long inDepth) {
CNN2DFormat format = getTypedLayerConfiguration().getDataFormat(); CNN2DFormat format = getTypedLayerConfiguration().getConvFormat();
int dim = format == CNN2DFormat.NHWC ? 3 : 1; int dim = format == CNN2DFormat.NHWC ? 3 : 1;
if (input.size(dim) != inDepth) { if (input.size(dim) != inDepth) {
String layerName = layerConfiguration.getName(); String layerName = layerConfiguration.getName();
@ -304,7 +304,7 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
String s = "Cannot do forward pass in Convolution layer (layer name = " + layerName String s = "Cannot do forward pass in Convolution layer (layer name = " + layerName
+ ", layer index = " + index + "): input array channels does not match CNN layer configuration" + ", layer index = " + index + "): input array channels does not match CNN layer configuration"
+ " (data format = " + format + ", data input channels = " + input.size(dim) + ", " + getTypedLayerConfiguration().getDataFormat().dimensionNames() + " (data format = " + format + ", data input channels = " + input.size(dim) + ", " + getTypedLayerConfiguration().getConvFormat().dimensionNames()
+ "=" + Arrays.toString(input.shape()) + "; expected" + " input channels = " + inDepth + ") " + "=" + Arrays.toString(input.shape()) + "; expected" + " input channels = " + inDepth + ") "
+ layerId(); + layerId();
@ -337,7 +337,7 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
INDArray input = this.input.castTo(dataType); INDArray input = this.input.castTo(dataType);
INDArray inputOrig = input; INDArray inputOrig = input;
if(getTypedLayerConfiguration().getDataFormat() == CNN2DFormat.NHWC) { if(getTypedLayerConfiguration().getConvFormat() == CNN2DFormat.NHWC) {
input = input.permute(0,3,1,2).dup(); //NHWC to NCHW input = input.permute(0,3,1,2).dup(); //NHWC to NCHW
} }
@ -421,7 +421,7 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
INDArray ret = null; INDArray ret = null;
try { try {
ret = helper.preOutput(inputOrig, weights, bias, kernel, strides, pad, getTypedLayerConfiguration().getCudnnAlgoMode(), ret = helper.preOutput(inputOrig, weights, bias, kernel, strides, pad, getTypedLayerConfiguration().getCudnnAlgoMode(),
getTypedLayerConfiguration().getCudnnFwdAlgo(), convolutionMode, dilation, getTypedLayerConfiguration().getDataFormat(), workspaceMgr); getTypedLayerConfiguration().getCudnnFwdAlgo(), convolutionMode, dilation, getTypedLayerConfiguration().getConvFormat(), workspaceMgr);
} catch (ND4JOpProfilerException e){ } catch (ND4JOpProfilerException e){
throw e; //NaN panic etc for debugging throw e; //NaN panic etc for debugging
} catch (Exception e){ } catch (Exception e){
@ -498,7 +498,7 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
} }
} }
if(getTypedLayerConfiguration().getDataFormat() == CNN2DFormat.NHWC) { if(getTypedLayerConfiguration().getConvFormat() == CNN2DFormat.NHWC) {
z = z.permute(0,2,3,1); //NCHW to NHWC z = z.permute(0,2,3,1); //NCHW to NHWC
z = workspaceMgr.dup(ArrayType.ACTIVATIONS, z); z = workspaceMgr.dup(ArrayType.ACTIVATIONS, z);
} }
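
The ConvolutionLayer hunks above consistently replace the configuration accessor getDataFormat() with getConvFormat() and permute activations between NHWC and NCHW around the convolution op. A minimal sketch of that layout handling, assuming only that the accessor returns a CNN2DFormat; the helper class and method names below are illustrative, not part of the library:

import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.nd4j.linalg.api.ndarray.INDArray;

// Illustrative helpers: bring NHWC activations into the internal NCHW layout and back.
final class ConvFormatSketch {

    // Permute NHWC input to NCHW; no-op when the layer is already configured as NCHW.
    static INDArray toNchw(INDArray input, CNN2DFormat convFormat) {
        return convFormat == CNN2DFormat.NHWC
                ? input.permute(0, 3, 1, 2).dup()   // [n,h,w,c] -> [n,c,h,w]
                : input;
    }

    // Permute NCHW output back to NHWC when the layer is configured as NHWC.
    static INDArray fromNchw(INDArray output, CNN2DFormat convFormat) {
        return convFormat == CNN2DFormat.NHWC
                ? output.permute(0, 2, 3, 1)        // [n,c,h,w] -> [n,h,w,c]
                : output;
    }
}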

View File

@ -61,13 +61,13 @@ public class Deconvolution2DLayer extends ConvolutionLayer {
if (input.rank() != 4) { if (input.rank() != 4) {
throw new DL4JInvalidInputException("Got rank " + input.rank() throw new DL4JInvalidInputException("Got rank " + input.rank()
+ " array as input to Deconvolution2DLayer with shape " + Arrays.toString(input.shape()) + " array as input to Deconvolution2DLayer with shape " + Arrays.toString(input.shape())
+ ". Expected rank 4 array with shape " + getTypedLayerConfiguration().getDataFormat().dimensionNames() + ". " + ". Expected rank 4 array with shape " + getTypedLayerConfiguration().getConvFormat().dimensionNames() + ". "
+ layerId()); + layerId());
} }
INDArray weights = getParamWithNoise(DeconvolutionParamInitializer.WEIGHT_KEY, true, workspaceMgr); INDArray weights = getParamWithNoise(DeconvolutionParamInitializer.WEIGHT_KEY, true, workspaceMgr);
CNN2DFormat format = getTypedLayerConfiguration().getDataFormat(); CNN2DFormat format = getTypedLayerConfiguration().getConvFormat();
boolean nchw = format == CNN2DFormat.NCHW; boolean nchw = format == CNN2DFormat.NCHW;
int hDim = nchw ? 2 : 1; int hDim = nchw ? 2 : 1;
int wDim = nchw ? 3 : 2; int wDim = nchw ? 3 : 2;
@ -166,7 +166,7 @@ public class Deconvolution2DLayer extends ConvolutionLayer {
+ " " + layerId()); + " " + layerId());
} }
CNN2DFormat format = getTypedLayerConfiguration().getDataFormat(); CNN2DFormat format = getTypedLayerConfiguration().getConvFormat();
boolean nchw = format == CNN2DFormat.NCHW; boolean nchw = format == CNN2DFormat.NCHW;
int cDim = nchw ? 1 : 3; int cDim = nchw ? 1 : 3;
int hDim = nchw ? 2 : 1; int hDim = nchw ? 2 : 1;
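
Deconvolution2DLayer reads the same renamed getConvFormat() accessor to decide which axes hold channels, height and width. A small sketch of that index selection; the wrapper class is made up for illustration:

import org.deeplearning4j.nn.conf.CNN2DFormat;

// Illustrative only: the per-axis indices used by the layer follow directly from the format.
final class Cnn2dAxes {
    final int cDim;   // channels axis: 1 for NCHW, 3 for NHWC
    final int hDim;   // height axis:   2 for NCHW, 1 for NHWC
    final int wDim;   // width axis:    3 for NCHW, 2 for NHWC

    Cnn2dAxes(CNN2DFormat format) {
        boolean nchw = format == CNN2DFormat.NCHW;
        this.cDim = nchw ? 1 : 3;
        this.hDim = nchw ? 2 : 1;
        this.wDim = nchw ? 3 : 2;
    }
}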

View File

@ -59,12 +59,12 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer {
@Override @Override
public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
assertInputSet(true); assertInputSet(true);
CNN2DFormat format = getTypedLayerConfiguration().getDataFormat(); CNN2DFormat format = getTypedLayerConfiguration().getConvFormat();
boolean nchw = format == CNN2DFormat.NCHW; boolean nchw = format == CNN2DFormat.NCHW;
if (input.rank() != 4) { if (input.rank() != 4) {
throw new DL4JInvalidInputException("Got rank " + input.rank() throw new DL4JInvalidInputException("Got rank " + input.rank()
+ " array as input to Convolution layer with shape " + Arrays.toString(input.shape()) + " array as input to Convolution layer with shape " + Arrays.toString(input.shape())
+ ". Expected rank 4 array with shape " + getTypedLayerConfiguration().getDataFormat().dimensionNames() + ". " + ". Expected rank 4 array with shape " + getTypedLayerConfiguration().getConvFormat().dimensionNames() + ". "
+ layerId()); + layerId());
} }
INDArray bias; INDArray bias;
@ -158,7 +158,7 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer {
throw new DL4JInvalidInputException("Got rank " + input.rank() throw new DL4JInvalidInputException("Got rank " + input.rank()
+ " array as input to DepthwiseConvolution2D (layer name = " + layerName + ", layer index = " + " array as input to DepthwiseConvolution2D (layer name = " + layerName + ", layer index = "
+ index + ") with shape " + Arrays.toString(input.shape()) + ". " + index + ") with shape " + Arrays.toString(input.shape()) + ". "
+ "Expected rank 4 array with shape " + getTypedLayerConfiguration().getDataFormat().dimensionNames() + "." + "Expected rank 4 array with shape " + getTypedLayerConfiguration().getConvFormat().dimensionNames() + "."
+ (input.rank() == 2 + (input.rank() == 2
? " (Wrong input type (see InputType.convolutionalFlat()) or wrong data type?)" ? " (Wrong input type (see InputType.convolutionalFlat()) or wrong data type?)"
: "") + " " + layerId()); : "") + " " + layerId());
@ -166,7 +166,7 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer {
INDArray input = this.input.castTo(dataType); //no-op if correct dtype INDArray input = this.input.castTo(dataType); //no-op if correct dtype
CNN2DFormat format = getTypedLayerConfiguration().getDataFormat(); CNN2DFormat format = getTypedLayerConfiguration().getConvFormat();
boolean nchw = format == CNN2DFormat.NCHW; boolean nchw = format == CNN2DFormat.NCHW;
long inDepth = depthWiseWeights.size(2); long inDepth = depthWiseWeights.size(2);

View File

@ -63,7 +63,7 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer {
if (input.rank() != 4) { if (input.rank() != 4) {
throw new DL4JInvalidInputException("Got rank " + input.rank() throw new DL4JInvalidInputException("Got rank " + input.rank()
+ " array as input to SubsamplingLayer with shape " + Arrays.toString(input.shape()) + " array as input to SubsamplingLayer with shape " + Arrays.toString(input.shape())
+ ". Expected rank 4 array with shape " + getTypedLayerConfiguration().getDataFormat().dimensionNames() + ". " + ". Expected rank 4 array with shape " + getTypedLayerConfiguration().getConvFormat().dimensionNames() + ". "
+ layerId()); + layerId());
} }
INDArray bias; INDArray bias;
@ -74,7 +74,7 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer {
INDArray input = this.input.castTo(dataType); INDArray input = this.input.castTo(dataType);
CNN2DFormat format = getTypedLayerConfiguration().getDataFormat(); CNN2DFormat format = getTypedLayerConfiguration().getConvFormat();
boolean nchw = format == CNN2DFormat.NCHW; boolean nchw = format == CNN2DFormat.NCHW;
long miniBatch = input.size(0); long miniBatch = input.size(0);
@ -167,7 +167,7 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer {
getParamWithNoise(SeparableConvolutionParamInitializer.POINT_WISE_WEIGHT_KEY, training, workspaceMgr); getParamWithNoise(SeparableConvolutionParamInitializer.POINT_WISE_WEIGHT_KEY, training, workspaceMgr);
INDArray input = this.input.castTo(dataType); INDArray input = this.input.castTo(dataType);
if(getTypedLayerConfiguration().getDataFormat() == CNN2DFormat.NHWC) { if(getTypedLayerConfiguration().getConvFormat() == CNN2DFormat.NHWC) {
input = input.permute(0,3,1,2).dup(); input = input.permute(0,3,1,2).dup();
} }
@ -182,7 +182,7 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer {
throw new DL4JInvalidInputException("Got rank " + input.rank() throw new DL4JInvalidInputException("Got rank " + input.rank()
+ " array as input to SeparableConvolution2D (layer name = " + layerName + ", layer index = " + " array as input to SeparableConvolution2D (layer name = " + layerName + ", layer index = "
+ index + ") with shape " + Arrays.toString(input.shape()) + ". " + index + ") with shape " + Arrays.toString(input.shape()) + ". "
+ "Expected rank 4 array with shape " + getTypedLayerConfiguration().getDataFormat().dimensionNames() + "." + "Expected rank 4 array with shape " + getTypedLayerConfiguration().getConvFormat().dimensionNames() + "."
+ (input.rank() == 2 + (input.rank() == 2
? " (Wrong input type (see InputType.convolutionalFlat()) or wrong data type?)" ? " (Wrong input type (see InputType.convolutionalFlat()) or wrong data type?)"
: "") : "")
@ -199,7 +199,7 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer {
String s = "Cannot do forward pass in SeparableConvolution2D layer (layer name = " + layerName String s = "Cannot do forward pass in SeparableConvolution2D layer (layer name = " + layerName
+ ", layer index = " + index + "): input array channels does not match CNN layer configuration" + ", layer index = " + index + "): input array channels does not match CNN layer configuration"
+ " (data format = " + getTypedLayerConfiguration().getDataFormat() + ", data input channels = " + input.size(1) + ", [minibatch,inputDepth,height,width]=" + " (data format = " + getTypedLayerConfiguration().getConvFormat() + ", data input channels = " + input.size(1) + ", [minibatch,inputDepth,height,width]="
+ Arrays.toString(input.shape()) + "; expected" + " input channels = " + inDepth + ") " + Arrays.toString(input.shape()) + "; expected" + " input channels = " + inDepth + ") "
+ layerId(); + layerId();
@ -287,7 +287,7 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer {
.build(); .build();
Nd4j.getExecutioner().exec(op); Nd4j.getExecutioner().exec(op);
if(getTypedLayerConfiguration().getDataFormat() == CNN2DFormat.NHWC) { if(getTypedLayerConfiguration().getConvFormat() == CNN2DFormat.NHWC) {
output = output.permute(0,2,3,1); //NCHW to NHWC output = output.permute(0,2,3,1); //NCHW to NHWC
} }
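
DepthwiseConvolution2DLayer and SeparableConvolution2DLayer share the rank-4 input check shown above, now phrased via getConvFormat(). A hedged sketch of that validation as a standalone method; it mirrors the pattern in the diff but is not the layers' actual private helper:

import java.util.Arrays;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.nd4j.linalg.api.ndarray.INDArray;

final class Rank4Check {

    // Reject anything that is not a rank-4 tensor laid out according to the configured format.
    static void validate(INDArray input, CNN2DFormat format, String layerId) {
        if (input.rank() != 4) {
            throw new DL4JInvalidInputException("Got rank " + input.rank()
                    + " array as input with shape " + Arrays.toString(input.shape())
                    + ". Expected rank 4 array with shape " + format.dimensionNames()
                    + ". " + layerId);
        }
    }
}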

View File

@ -47,7 +47,7 @@ public class SpaceToBatch extends AbstractLayer<org.deeplearning4j.nn.conf.layer
} }
private int[] getBlocks() { private int[] getBlocks() {
return getTypedLayerConfiguration().getBlocks(); return getTypedLayerConfiguration().getBlockSize();
} }
private int[][] getPadding() { private int[][] getPadding() {
@ -55,7 +55,7 @@ public class SpaceToBatch extends AbstractLayer<org.deeplearning4j.nn.conf.layer
} }
private INDArray getBlocksArray() { private INDArray getBlocksArray() {
int[] intBlocks = getTypedLayerConfiguration().getBlocks(); int[] intBlocks = getTypedLayerConfiguration().getBlockSize();
return Nd4j.createFromArray(intBlocks); return Nd4j.createFromArray(intBlocks);
} }
@ -77,7 +77,7 @@ public class SpaceToBatch extends AbstractLayer<org.deeplearning4j.nn.conf.layer
INDArray input = this.input.castTo(dataType); //Cast to network dtype if required (no-op if already correct type) INDArray input = this.input.castTo(dataType); //Cast to network dtype if required (no-op if already correct type)
boolean nchw = getTypedLayerConfiguration().getFormat() == CNN2DFormat.NCHW; boolean nchw = getTypedLayerConfiguration().getDataFormat() == CNN2DFormat.NCHW;
INDArray outEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), input.shape(), 'c'); INDArray outEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), input.shape(), 'c');
@ -104,7 +104,7 @@ public class SpaceToBatch extends AbstractLayer<org.deeplearning4j.nn.conf.layer
if (input.rank() != 4) { if (input.rank() != 4) {
throw new DL4JInvalidInputException("Got rank " + input.rank() throw new DL4JInvalidInputException("Got rank " + input.rank()
+ " array as input to space to batch with shape " + Arrays.toString(input.shape()) + " array as input to space to batch with shape " + Arrays.toString(input.shape())
+ ". Expected rank 4 array with shape " + getTypedLayerConfiguration().getFormat().dimensionNames() + ". " + ". Expected rank 4 array with shape " + getTypedLayerConfiguration().getDataFormat().dimensionNames() + ". "
+ layerId()); + layerId());
} }
@ -112,7 +112,7 @@ public class SpaceToBatch extends AbstractLayer<org.deeplearning4j.nn.conf.layer
return preOutput; return preOutput;
} }
boolean nchw = getTypedLayerConfiguration().getFormat() == CNN2DFormat.NCHW; boolean nchw = getTypedLayerConfiguration().getDataFormat() == CNN2DFormat.NCHW;
long inMiniBatch = input.size(0); long inMiniBatch = input.size(0);
long depth = input.size(nchw ? 1 : 3); long depth = input.size(nchw ? 1 : 3);
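
In SpaceToBatch the block-size accessor moves from getBlocks() to getBlockSize(). A minimal sketch of the call-site pattern, with the real layer configuration replaced by a stand-in interface:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

final class SpaceToBatchSketch {

    // Stand-in for the layer configuration; only the renamed getter matters here.
    interface BlockConfig {
        int[] getBlockSize();   // previously getBlocks()
    }

    // Turn the configured block sizes into an INDArray, as the layer does internally.
    static INDArray blocksArray(BlockConfig conf) {
        return Nd4j.createFromArray(conf.getBlockSize());
    }
}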

View File

@ -87,7 +87,7 @@ public class ZeroPadding1DLayer extends AbstractLayer<org.deeplearning4j.nn.conf
@Override @Override
public Layer clone() { public Layer clone() {
return ZeroPadding1DLayer.builder(layerConfiguration.clone(), dataType); return new ZeroPadding1DLayer(layerConfiguration.clone(), dataType);
} }
@Override @Override

View File

@ -312,6 +312,6 @@ public class SimpleRnn extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.lay
@Override @Override
public boolean hasLayerNorm(){ public boolean hasLayerNorm(){
return getTypedLayerConfiguration().hasLayerNorm(); return getTypedLayerConfiguration().isHasLayerNorm();
} }
} }

View File

@ -167,7 +167,7 @@ public class SimpleRnnParamInitializer extends AbstractParamInitializer {
protected boolean hasLayerNorm(LayerConfiguration layer){ protected boolean hasLayerNorm(LayerConfiguration layer){
if(layer instanceof SimpleRnn){ if(layer instanceof SimpleRnn){
return ((SimpleRnn) layer).hasLayerNorm(); return ((SimpleRnn) layer).isHasLayerNorm();
} }
return false; return false;
} }
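
Both call sites above change from hasLayerNorm() to isHasLayerNorm() because Lombok derives an is-prefixed getter from a boolean field. A minimal sketch of that naming on a made-up configuration class:

import lombok.Getter;
import lombok.experimental.SuperBuilder;

// Illustrative only: a boolean field named hasLayerNorm yields a generated isHasLayerNorm() getter.
@Getter
@SuperBuilder
class RnnConfigSketch {
    private boolean hasLayerNorm;
}

// Usage:
//   RnnConfigSketch c = RnnConfigSketch.builder().hasLayerNorm(true).build();
//   boolean layerNorm = c.isHasLayerNorm();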

View File

@ -21,39 +21,74 @@
package org.deeplearning4j.nn.weights; package org.deeplearning4j.nn.weights;
import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.weights.embeddings.EmbeddingInitializer;
import org.deeplearning4j.nn.weights.embeddings.WeightInitEmbedding;
public enum WeightInit { public enum WeightInit {
DISTRIBUTION, ZERO, ONES, SIGMOID_UNIFORM, NORMAL, LECUN_NORMAL, UNIFORM, XAVIER, XAVIER_UNIFORM, XAVIER_FAN_IN, XAVIER_LEGACY, RELU, DISTRIBUTION,
RELU_UNIFORM, IDENTITY, LECUN_UNIFORM, VAR_SCALING_NORMAL_FAN_IN, VAR_SCALING_NORMAL_FAN_OUT, VAR_SCALING_NORMAL_FAN_AVG, ZERO,
VAR_SCALING_UNIFORM_FAN_IN, VAR_SCALING_UNIFORM_FAN_OUT, VAR_SCALING_UNIFORM_FAN_AVG; ONES,
SIGMOID_UNIFORM,
NORMAL,
LECUN_NORMAL,
UNIFORM,
XAVIER,
XAVIER_UNIFORM,
XAVIER_FAN_IN,
XAVIER_LEGACY,
RELU,
RELU_UNIFORM,
IDENTITY,
LECUN_UNIFORM,
VAR_SCALING_NORMAL_FAN_IN,
VAR_SCALING_NORMAL_FAN_OUT,
VAR_SCALING_NORMAL_FAN_AVG,
CONSTANT,
EMBEDDING,
VAR_SCALING_UNIFORM_FAN_IN,
VAR_SCALING_UNIFORM_FAN_OUT,
VAR_SCALING_UNIFORM_FAN_AVG;
/**
 * Create an instance of the weight initialization function.
 *
 * @param distribution distribution to sample weights from (only used for {@code DISTRIBUTION})
 * @return a new {@link IWeightInit} instance
 */
public IWeightInit getWeightInitFunction(Distribution distribution) {
switch (this) {
case DISTRIBUTION:
return new WeightInitDistribution(distribution);
default:
return getWeightInitFunction();
}
}
/**
 * Create an instance of the weight initialization function for an embedding layer.
 *
 * @param initializer embedding initializer (only used for {@code EMBEDDING})
 * @return a new {@link IWeightInit} instance
 */
public IWeightInit getWeightInitFunction(EmbeddingInitializer initializer) {
switch (this) {
case EMBEDDING:
return new WeightInitEmbedding(initializer);
default:
return getWeightInitFunction();
}
}
/** /**
* Create an instance of the weight initialization function * Create an instance of the weight initialization function
* *
* @return a new {@link IWeightInit} instance * @return a new {@link IWeightInit} instance
*/ */
public IWeightInit getWeightInitFunction() { public IWeightInit getWeightInitFunction() {
return getWeightInitFunction(null);
}
/**
* Create an instance of the weight initialization function
*
* @param distribution Distribution of the weights (Only used in case DISTRIBUTION)
* @return a new {@link IWeightInit} instance
*/
public IWeightInit getWeightInitFunction(Distribution distribution) {
switch (this) { switch (this) {
case CONSTANT:
return new WeightInitConstant();
case ZERO: case ZERO:
return new WeightInitConstant(0.0); return new WeightInitConstant(0.0);
case ONES: case ONES:
return new WeightInitConstant(1.0); return new WeightInitConstant(1.0);
case DISTRIBUTION:
return new WeightInitDistribution(distribution);
case SIGMOID_UNIFORM: case SIGMOID_UNIFORM:
return new WeightInitSigmoidUniform(); return new WeightInitSigmoidUniform();
case LECUN_NORMAL: //Fall through: these 3 are equivalent case LECUN_NORMAL: // Fall through: these 3 are equivalent
case XAVIER_FAN_IN: case XAVIER_FAN_IN:
case NORMAL: case NORMAL:
return new WeightInitNormal(); return new WeightInitNormal();
@ -87,7 +122,8 @@ public enum WeightInit {
return new WeightInitVarScalingUniformFanAvg(); return new WeightInitVarScalingUniformFanAvg();
default: default:
throw new UnsupportedOperationException("Unknown or not supported weight initialization function: " + this); throw new UnsupportedOperationException(
"Unknown or not supported weight initialization function: " + this);
} }
} }
} }
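
Besides reflowing the constants one per line, the enum gains CONSTANT and EMBEDDING and two overloads that fall back to the no-argument resolver. A small usage sketch; the distribution parameters are arbitrary:

import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInit;

public class WeightInitUsageSketch {
    public static void main(String[] args) {
        // Plain schemes resolve via the no-argument overload.
        IWeightInit xavier = WeightInit.XAVIER.getWeightInitFunction();

        // DISTRIBUTION needs the Distribution overload; mean and std here are made up.
        IWeightInit fromDist = WeightInit.DISTRIBUTION
                .getWeightInitFunction(new NormalDistribution(0.0, 0.01));

        System.out.println(xavier.getClass().getSimpleName() + " / " + fromDist.getClass().getSimpleName());
    }
}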

View File

@ -42,4 +42,13 @@ public class WeightInitConstant implements IWeightInit {
paramView.assign(value); paramView.assign(value);
return paramView.reshape(order, shape); return paramView.reshape(order, shape);
} }
/**
 * @return the {@link WeightInit} enum constant corresponding to this initializer
 */
@Override
public WeightInit enumValue() {
return WeightInit.CONSTANT;
}
} }
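
Each concrete IWeightInit now reports its enum constant through enumValue(), making the mapping between enum and function object round-trippable. A minimal sketch, assuming enumValue() is declared on the IWeightInit interface as the @Override annotations imply:

import org.deeplearning4j.nn.weights.IWeightInit;
import org.deeplearning4j.nn.weights.WeightInit;

public class EnumValueRoundTrip {
    public static void main(String[] args) {
        IWeightInit fn = WeightInit.CONSTANT.getWeightInitFunction();  // a WeightInitConstant instance
        WeightInit back = fn.enumValue();                              // WeightInit.CONSTANT again
        System.out.println(back == WeightInit.CONSTANT);               // true
    }
}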

View File

@ -51,4 +51,13 @@ public class WeightInitDistribution implements IWeightInit {
} }
return paramView.reshape(order, shape); return paramView.reshape(order, shape);
} }
/**
 * @return the {@link WeightInit} enum constant corresponding to this initializer
 */
@Override
public WeightInit enumValue() {
return WeightInit.DISTRIBUTION;
}
} }

View File

@ -38,4 +38,13 @@ public class WeightInitLecunUniform implements IWeightInit {
Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-b, b)); Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-b, b));
return paramView.reshape(order, shape); return paramView.reshape(order, shape);
} }
/**
 * @return the {@link WeightInit} enum constant corresponding to this initializer
 */
@Override
public WeightInit enumValue() {
return WeightInit.LECUN_UNIFORM;
}
} }

View File

@ -38,4 +38,13 @@ public class WeightInitNormal implements IWeightInit {
Nd4j.randn(paramView).divi(FastMath.sqrt(fanIn)); Nd4j.randn(paramView).divi(FastMath.sqrt(fanIn));
return paramView.reshape(order, shape); return paramView.reshape(order, shape);
} }
/**
 * @return the {@link WeightInit} enum constant corresponding to this initializer
 */
@Override
public WeightInit enumValue() {
return WeightInit.NORMAL;
}
} }

View File

@ -33,4 +33,13 @@ public class WeightInitRelu implements IWeightInit {
Nd4j.randn(paramView).muli(FastMath.sqrt(2.0 / fanIn)); //N(0, 2/nIn) Nd4j.randn(paramView).muli(FastMath.sqrt(2.0 / fanIn)); //N(0, 2/nIn)
return paramView.reshape(order, shape); return paramView.reshape(order, shape);
} }
/**
 * @return the {@link WeightInit} enum constant corresponding to this initializer
 */
@Override
public WeightInit enumValue() {
return WeightInit.RELU;
}
} }

View File

@ -35,4 +35,13 @@ public class WeightInitReluUniform implements IWeightInit {
Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-u, u)); //U(-sqrt(6/fanIn), sqrt(6/fanIn) Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-u, u)); //U(-sqrt(6/fanIn), sqrt(6/fanIn)
return paramView.reshape(order, shape); return paramView.reshape(order, shape);
} }
/**
 * @return the {@link WeightInit} enum constant corresponding to this initializer
 */
@Override
public WeightInit enumValue() {
return WeightInit.RELU_UNIFORM;
}
} }

View File

@ -35,4 +35,13 @@ public class WeightInitSigmoidUniform implements IWeightInit {
Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-r, r)); Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-r, r));
return paramView.reshape(order, shape); return paramView.reshape(order, shape);
} }
/**
 * @return the {@link WeightInit} enum constant corresponding to this initializer
 */
@Override
public WeightInit enumValue() {
return WeightInit.SIGMOID_UNIFORM;
}
} }

View File

@ -34,4 +34,13 @@ public class WeightInitUniform implements IWeightInit {
Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-a, a)); Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-a, a));
return paramView.reshape(order, shape); return paramView.reshape(order, shape);
} }
/**
 * @return the {@link WeightInit} enum constant corresponding to this initializer
 */
@Override
public WeightInit enumValue() {
return WeightInit.UNIFORM;
}
} }

View File

@ -50,4 +50,13 @@ public class WeightInitVarScalingNormalFanAvg implements IWeightInit {
Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, std)); Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, std));
return paramView.reshape(order, shape); return paramView.reshape(order, shape);
} }
/**
 * @return the {@link WeightInit} enum constant corresponding to this initializer
 */
@Override
public WeightInit enumValue() {
return WeightInit.VAR_SCALING_NORMAL_FAN_AVG;
}
} }

View File

@ -48,4 +48,13 @@ public class WeightInitVarScalingNormalFanIn implements IWeightInit {
Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, std)); Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, std));
return paramView.reshape(order, shape); return paramView.reshape(order, shape);
} }
/**
 * @return the {@link WeightInit} enum constant corresponding to this initializer
 */
@Override
public WeightInit enumValue() {
return WeightInit.VAR_SCALING_NORMAL_FAN_IN;
}
} }

View File

@ -50,4 +50,13 @@ public class WeightInitVarScalingNormalFanOut implements IWeightInit {
Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, std)); Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, std));
return paramView.reshape(order, shape); return paramView.reshape(order, shape);
} }
/**
 * @return the {@link WeightInit} enum constant corresponding to this initializer
 */
@Override
public WeightInit enumValue() {
return WeightInit.VAR_SCALING_NORMAL_FAN_OUT;
}
} }

Some files were not shown because too many files have changed in this diff.