Using @SuperBuilder for LayerConfigurations
Signed-off-by: brian <brian@brutex.de>master
parent
f6100c362d
commit
ad870c5281
|
@ -221,7 +221,7 @@ public class TestMiscFunctions extends BaseSparkTest {
|
|||
int nIn = 10;
|
||||
|
||||
NeuralNetConfiguration mlc = NeuralNetConfiguration.builder().list()
|
||||
.layer(0, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder()
|
||||
.layer(0, org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.builder()
|
||||
.reconstructionDistribution(
|
||||
new GaussianReconstructionDistribution(Activation.IDENTITY))
|
||||
.nIn(nIn).nOut(5).encoderLayerSizes(12).decoderLayerSizes(13).build())
|
||||
|
@ -261,7 +261,7 @@ public class TestMiscFunctions extends BaseSparkTest {
|
|||
|
||||
NeuralNetConfiguration mlc = NeuralNetConfiguration.builder()
|
||||
.list().layer(0,
|
||||
new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder()
|
||||
org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.builder()
|
||||
.reconstructionDistribution(new LossFunctionWrapper(
|
||||
Activation.IDENTITY, new LossMSE()))
|
||||
.nIn(nIn).nOut(5).encoderLayerSizes(12).decoderLayerSizes(13)
|
||||
|
|
|
@ -96,7 +96,7 @@ public class App {
|
|||
private static LayerConfiguration[] genLayers() {
|
||||
return new LayerConfiguration[] {
|
||||
DenseLayer.builder().nIn(INPUT).nOut(X_DIM*Y_DIM*CHANNELS).weightInit(WeightInit.NORMAL).build(),
|
||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||
ActivationLayer.builder(Activation.LEAKYRELU).build(),
|
||||
DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(),
|
||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||
DenseLayer.builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM).build(),
|
||||
|
|
|
@ -331,10 +331,10 @@ public class CNN2DTestCases {
|
|||
.build(),
|
||||
"leaky_re_lu_8")
|
||||
.addLayer("outputs",
|
||||
new Yolo2OutputLayer.Builder()
|
||||
Yolo2OutputLayer.builder()
|
||||
.lambdaNoObj(lambdaNoObj)
|
||||
.lambdaCoord(lambdaCoord)
|
||||
.boundingBoxPriors(priors)
|
||||
.boundingBoxes(priors)
|
||||
.build(),
|
||||
"convolution2d_9")
|
||||
.setOutputs("outputs")
|
||||
|
|
|
@ -322,7 +322,7 @@ public class RNNTestCases {
|
|||
.updater(new Adam(5e-2))
|
||||
.l1(1e-3).l2(1e-3)
|
||||
|
||||
.layer(0, Bidirectional.builder(LSTM.builder().activation(Activation.TANH).nOut(10).build()))
|
||||
.layer(0, Bidirectional.builder(LSTM.builder().activation(Activation.TANH).nOut(10).build()).build())
|
||||
.layer(GlobalPoolingLayer.builder().poolingType(PoolingType.AVG).build())
|
||||
.layer(OutputLayer.builder().nOut(6)
|
||||
.lossFunction(LossFunctions.LossFunction.MCXENT)
|
||||
|
|
|
@ -22,9 +22,11 @@ package org.nd4j.linalg.activations;
|
|||
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.common.primitives.Pair;
|
||||
import org.nd4j.linalg.activations.impl.*;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
public enum Activation {
|
||||
public enum Activation implements IActivation {
|
||||
CUBE, ELU, HARDSIGMOID, HARDTANH, IDENTITY, LEAKYRELU, RATIONALTANH, RELU, RELU6,
|
||||
RRELU, SIGMOID, SOFTMAX, SOFTPLUS, SOFTSIGN, TANH, RECTIFIEDTANH, SELU, SWISH,
|
||||
THRESHOLDEDRELU, GELU, MISH;
|
||||
|
@ -149,4 +151,44 @@ public enum Activation {
|
|||
throw new UnsupportedOperationException("Activation function not yet supported: " + this);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Carry out activation function on the input array (usually known as 'preOut' or 'z')
|
||||
* Implementations must overwrite "in", transform in place and return "in"
|
||||
* Can support separate behaviour during test
|
||||
*
|
||||
* @param in input array.
|
||||
* @param training true when training.
|
||||
* @return transformed activation
|
||||
*/
|
||||
@Override
|
||||
public INDArray getActivation(INDArray in, boolean training) {
|
||||
return getActivationFunction().getActivation(in, training);
|
||||
}
|
||||
|
||||
/**
|
||||
* Backpropagate the errors through the activation function, given input z and epsilon dL/da.<br>
|
||||
* Returns 2 INDArrays:<br>
|
||||
* (a) The gradient dL/dz, calculated from dL/da, and<br>
|
||||
* (b) The parameter gradients dL/dW, where w is the weights in the activation function. For activation functions
|
||||
* with no gradients, this will be null.
|
||||
*
|
||||
* @param in Input, before applying the activation function (z, or 'preOut')
|
||||
* @param epsilon Gradient to be backpropagated: dL/da, where L is the loss function
|
||||
* @return dL/dz and dL/dW, for weights w (null if activation function has no weights)
|
||||
*/
|
||||
@Override
|
||||
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
|
||||
return getActivationFunction().backprop(in, epsilon);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param inputSize
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public int numParams(int inputSize) {
|
||||
return getActivationFunction().numParams(inputSize);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -872,7 +872,7 @@ public class TestEarlyStopping extends BaseDL4JTest {
|
|||
.nIn(10)
|
||||
.nOut(10)
|
||||
.activation(Activation.TANH)
|
||||
.gateActivationFunction(Activation.SIGMOID)
|
||||
.gateActivationFunction(Activation.SIGMOID.getActivationFunction())
|
||||
.dropOut(0.5)
|
||||
.build())
|
||||
.layer(1, RnnOutputLayer.builder()
|
||||
|
|
|
@ -90,8 +90,8 @@ public class AttentionLayerTest extends BaseDL4JTest {
|
|||
.weightInit(WeightInit.XAVIER)
|
||||
.layer(LSTM.builder().nOut(layerSize).build())
|
||||
.layer( projectInput ?
|
||||
new SelfAttentionLayer.Builder().nOut(4).nHeads(2).projectInput(true).build()
|
||||
: new SelfAttentionLayer.Builder().nHeads(1).projectInput(false).build()
|
||||
SelfAttentionLayer.builder().nOut(4).nHeads(2).projectInput(true).build()
|
||||
: SelfAttentionLayer.builder().nHeads(1).projectInput(false).build()
|
||||
)
|
||||
.layer(GlobalPoolingLayer.builder().poolingType(PoolingType.MAX).build())
|
||||
.layer(OutputLayer.builder().nOut(nOut).activation(Activation.SOFTMAX)
|
||||
|
@ -151,8 +151,8 @@ public class AttentionLayerTest extends BaseDL4JTest {
|
|||
.list()
|
||||
.layer(LSTM.builder().nOut(layerSize).build())
|
||||
.layer( projectInput ?
|
||||
new LearnedSelfAttentionLayer.Builder().nOut(4).nHeads(2).nQueries(numQueries).projectInput(true).build()
|
||||
: new LearnedSelfAttentionLayer.Builder().nHeads(1).nQueries(numQueries).projectInput(false).build()
|
||||
LearnedSelfAttentionLayer.builder().nOut(4).nHeads(2).nQueries(numQueries).projectInput(true).build()
|
||||
: LearnedSelfAttentionLayer.builder().nHeads(1).nQueries(numQueries).projectInput(false).build()
|
||||
)
|
||||
.layer(GlobalPoolingLayer.builder().poolingType(PoolingType.MAX).build())
|
||||
.layer(OutputLayer.builder().nOut(nOut).activation(Activation.SOFTMAX)
|
||||
|
@ -191,8 +191,8 @@ public class AttentionLayerTest extends BaseDL4JTest {
|
|||
.list()
|
||||
.layer(LSTM.builder().nOut(layerSize).build())
|
||||
.layer( projectInput ?
|
||||
new LearnedSelfAttentionLayer.Builder().nOut(4).nHeads(2).nQueries(numQueries).projectInput(true).build()
|
||||
: new LearnedSelfAttentionLayer.Builder().nHeads(1).nQueries(numQueries).projectInput(false).build()
|
||||
LearnedSelfAttentionLayer.builder().nOut(4).nHeads(2).nQueries(numQueries).projectInput(true).build()
|
||||
: LearnedSelfAttentionLayer.builder().nHeads(1).nQueries(numQueries).projectInput(false).build()
|
||||
)
|
||||
.layer(GlobalPoolingLayer.builder().poolingType(PoolingType.MAX).build())
|
||||
.layer(OutputLayer.builder().nOut(nOut).activation(Activation.SOFTMAX)
|
||||
|
@ -245,7 +245,7 @@ public class AttentionLayerTest extends BaseDL4JTest {
|
|||
.weightInit(WeightInit.XAVIER)
|
||||
.list()
|
||||
.layer(LSTM.builder().nOut(layerSize).build())
|
||||
.layer(new RecurrentAttentionLayer.Builder().nIn(layerSize).nOut(layerSize).nHeads(1).projectInput(false).hasBias(false).build())
|
||||
.layer(RecurrentAttentionLayer.builder().nIn(layerSize).nOut(layerSize).nHeads(1).projectInput(false).hasBias(false).build())
|
||||
.layer(GlobalPoolingLayer.builder().poolingType(PoolingType.AVG).build())
|
||||
.layer(OutputLayer.builder().nOut(nOut).activation(Activation.SOFTMAX)
|
||||
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
||||
|
@ -308,7 +308,7 @@ public class AttentionLayerTest extends BaseDL4JTest {
|
|||
.weightInit(WeightInit.XAVIER)
|
||||
.list()
|
||||
.layer(LSTM.builder().nOut(layerSize).build())
|
||||
.layer(new RecurrentAttentionLayer.Builder().nIn(layerSize).nOut(layerSize).nHeads(1).projectInput(false).hasBias(false).build())
|
||||
.layer(RecurrentAttentionLayer.builder().nIn(layerSize).nOut(layerSize).nHeads(1).projectInput(false).hasBias(false).build())
|
||||
.layer(GlobalPoolingLayer.builder().poolingType(PoolingType.AVG).build())
|
||||
.layer(OutputLayer.builder().nOut(nOut).activation(Activation.SOFTMAX)
|
||||
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
||||
|
|
|
@ -363,7 +363,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
.nOut(1)
|
||||
.build()) // output: (5-2+0)/1+1 = 4
|
||||
.layer(
|
||||
new SpaceToDepthLayer.Builder(blocks, SpaceToDepthLayer.DataFormat.NCHW)
|
||||
SpaceToDepthLayer.builder().blockSize(blocks).dataFormat(CNN2DFormat.NCHW)
|
||||
.build()) // (mb,1,4,4) -> (mb,4,2,2)
|
||||
.layer(
|
||||
OutputLayer.builder()
|
||||
|
@ -450,10 +450,10 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
ConvolutionLayer.builder(kernel)
|
||||
.nIn(inputDepth)
|
||||
.nOut(3)
|
||||
.dataFormat(format)
|
||||
.convFormat(format)
|
||||
.build())
|
||||
.layer(
|
||||
new SpaceToBatchLayer.Builder(blocks)
|
||||
SpaceToBatchLayer.builder(blocks)
|
||||
.dataFormat(format)
|
||||
.build()) // trivial space to batch
|
||||
.layer(
|
||||
|
@ -546,7 +546,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
.layer(
|
||||
ConvolutionLayer.builder(kernel, stride, padding)
|
||||
.nIn(inputDepth)
|
||||
.dataFormat(format)
|
||||
.convFormat(format)
|
||||
.nOut(3)
|
||||
.build()) // output: (5-2+0)/1+1 = 4
|
||||
.layer(
|
||||
|
@ -641,7 +641,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
0,
|
||||
ConvolutionLayer.builder(kernel, stride, padding)
|
||||
.nIn(inputDepth)
|
||||
.dataFormat(format)
|
||||
.convFormat(format)
|
||||
.nOut(3)
|
||||
.build()) // output: (5-2+0)/1+1 = 4
|
||||
.layer(
|
||||
|
@ -750,7 +750,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
0,
|
||||
ConvolutionLayer.builder(kernel, stride, padding)
|
||||
.nIn(inputDepth)
|
||||
.dataFormat(format)
|
||||
.convFormat(format)
|
||||
.nOut(3)
|
||||
.build()) // output: (5-2+0)/1+1 = 4
|
||||
.layer(
|
||||
|
@ -765,7 +765,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
.layer(
|
||||
2,
|
||||
ConvolutionLayer.builder(kernel, stride, padding)
|
||||
.dataFormat(format)
|
||||
.convFormat(format)
|
||||
.nIn(3)
|
||||
.nOut(2)
|
||||
.build()) // Output: (3-2+0)/1+1 = 2
|
||||
|
@ -849,7 +849,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
ConvolutionLayer.builder()
|
||||
.kernelSize(2, 2)
|
||||
.stride(1, 1)
|
||||
.dataFormat(format)
|
||||
.convFormat(format)
|
||||
.padding(0, 0)
|
||||
.nIn(inputDepth)
|
||||
.nOut(2)
|
||||
|
@ -861,7 +861,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
.nOut(7)
|
||||
.kernelSize(2, 2)
|
||||
.dataFormat(format)
|
||||
.setInputSize(4, 4)
|
||||
.inputSize(new int[]{4, 4})
|
||||
.convolutionMode(ConvolutionMode.Strict)
|
||||
.hasBias(false)
|
||||
.stride(1, 1)
|
||||
|
@ -873,7 +873,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
.nIn(7)
|
||||
.nOut(2)
|
||||
.kernelSize(2, 2)
|
||||
.dataFormat(format)
|
||||
.convFormat(format)
|
||||
.stride(1, 1)
|
||||
.padding(0, 0)
|
||||
.build()) // (3-2+0)/1+1 = 2
|
||||
|
@ -959,7 +959,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
ConvolutionLayer.builder()
|
||||
.kernelSize(2, 2)
|
||||
.stride(1, 1)
|
||||
.dataFormat(format)
|
||||
.convFormat(format)
|
||||
.padding(0, 0)
|
||||
.nIn(inputDepth)
|
||||
.nOut(2)
|
||||
|
@ -970,7 +970,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
.nIn(2)
|
||||
.nOut(2)
|
||||
.kernelSize(2, 2)
|
||||
.dataFormat(format)
|
||||
.convFormat(format)
|
||||
.stride(1, 1)
|
||||
.padding(0, 0)
|
||||
.build()) // (4-2+0)/1+1 = 3
|
||||
|
@ -980,7 +980,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
.nIn(2)
|
||||
.nOut(2)
|
||||
.kernelSize(2, 2)
|
||||
.dataFormat(format)
|
||||
.convFormat(format)
|
||||
.stride(1, 1)
|
||||
.padding(0, 0)
|
||||
.build()) // (3-2+0)/1+1 = 2
|
||||
|
@ -1076,7 +1076,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
ConvolutionLayer.builder()
|
||||
.name("layer 0")
|
||||
.kernelSize(k, k)
|
||||
.dataFormat(format)
|
||||
.convFormat(format)
|
||||
.stride(1, 1)
|
||||
.padding(0, 0)
|
||||
.nIn(inputDepth)
|
||||
|
@ -1097,7 +1097,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
.nIn(2)
|
||||
.nOut(2)
|
||||
.kernelSize(k, k)
|
||||
.dataFormat(format)
|
||||
.convFormat(format)
|
||||
.stride(1, 1)
|
||||
.padding(0, 0)
|
||||
.build())
|
||||
|
@ -1181,7 +1181,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
ConvolutionLayer.builder()
|
||||
.name("layer 0")
|
||||
.kernelSize(k, k)
|
||||
.dataFormat(format)
|
||||
.convFormat(format)
|
||||
.stride(stride, stride)
|
||||
.padding(0, 0)
|
||||
.nIn(inputDepth)
|
||||
|
@ -1297,7 +1297,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
.layer(
|
||||
0,
|
||||
ConvolutionLayer.builder(kernel, stride, padding)
|
||||
.dataFormat(format)
|
||||
.convFormat(format)
|
||||
.nIn(inputDepth)
|
||||
.nOut(3)
|
||||
.build()) // output: (6-2+0)/1+1 = 5
|
||||
|
@ -1307,7 +1307,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
ConvolutionLayer.builder(kernel, stride, padding)
|
||||
.nIn(3)
|
||||
.nOut(3)
|
||||
.dataFormat(format)
|
||||
.convFormat(format)
|
||||
.build()) // output: (6-2+0)/1+1 = 5
|
||||
.layer(
|
||||
3,
|
||||
|
@ -1436,7 +1436,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
.name("deconvolution_2D_layer")
|
||||
.kernelSize(k, k)
|
||||
.stride(s, s)
|
||||
.dataFormat(format)
|
||||
.convFormat(format)
|
||||
.dilation(d, d)
|
||||
.convolutionMode(cm)
|
||||
.nIn(inputDepth)
|
||||
|
@ -1530,7 +1530,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
.stride(s, s)
|
||||
.dilation(d, d)
|
||||
.depthMultiplier(3)
|
||||
.dataFormat(format)
|
||||
.convFormat(format)
|
||||
.nIn(inputDepth)
|
||||
.nOut(2)
|
||||
.build())
|
||||
|
@ -1621,7 +1621,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
.kernelSize(k, k)
|
||||
.stride(s, s)
|
||||
.dilation(d, d)
|
||||
.dataFormat(format)
|
||||
.convFormat(format)
|
||||
.nIn(inputDepth)
|
||||
.nOut(2)
|
||||
.build());
|
||||
|
@ -1642,7 +1642,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
.kernelSize(k, k)
|
||||
.stride(s, s)
|
||||
.dilation(d, d)
|
||||
.dataFormat(format)
|
||||
.convFormat(format)
|
||||
.build());
|
||||
}
|
||||
|
||||
|
@ -1732,14 +1732,14 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
.list()
|
||||
.layer(
|
||||
ConvolutionLayer.builder(kernel, stride, padding)
|
||||
.dataFormat(format)
|
||||
.convFormat(format)
|
||||
.nIn(inputDepth)
|
||||
.nOut(2)
|
||||
.build()) // output: (6-2+0)/1+1 = 5
|
||||
.layer(Cropping2D.builder(crop).dataFormat(format).build())
|
||||
.layer(
|
||||
ConvolutionLayer.builder(kernel, stride, padding)
|
||||
.dataFormat(format)
|
||||
.convFormat(format)
|
||||
.nIn(2)
|
||||
.nOut(2)
|
||||
.build())
|
||||
|
@ -1857,7 +1857,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
|||
.stride(1, 1)
|
||||
.nIn(nIn)
|
||||
.nOut(nIn)
|
||||
.dataFormat(format)
|
||||
.convFormat(format)
|
||||
.build())
|
||||
.layer(
|
||||
DepthwiseConvolution2D.builder()
|
||||
|
|
|
@ -82,7 +82,7 @@ public class CapsnetGradientCheckTest extends BaseDL4JTest {
|
|||
.seed(123)
|
||||
.updater(new NoOp())
|
||||
.dist(new UniformDistribution(-6, 6))
|
||||
.layer(new PrimaryCapsules.Builder(primaryCapsDim, primarpCapsChannel)
|
||||
.layer(PrimaryCapsules.builder(primaryCapsDim, primarpCapsChannel)
|
||||
.kernelSize(3, 3)
|
||||
.stride(2, 2)
|
||||
.build())
|
||||
|
|
|
@ -131,7 +131,7 @@ public class GlobalPoolingGradientCheckTests extends BaseDL4JTest {
|
|||
.updater(new NoOp())
|
||||
.dist(new NormalDistribution(0, 1.0)).seed(12345L).list()
|
||||
.layer(0, ConvolutionLayer.builder().kernelSize(2, 2).stride(1, 1)
|
||||
.dataFormat(nchw ? CNN2DFormat.NCHW : CNN2DFormat.NHWC)
|
||||
.convFormat(nchw ? CNN2DFormat.NCHW : CNN2DFormat.NHWC)
|
||||
.nOut(layerDepth)
|
||||
.build())
|
||||
.layer(1, GlobalPoolingLayer.builder().poolingType(pt).build())
|
||||
|
|
|
@ -345,10 +345,10 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
|
|||
.dist(new NormalDistribution(0, 0.1))
|
||||
.updater(new NoOp()).graphBuilder().addInputs("input")
|
||||
.addLayer("l1", ConvolutionLayer.builder().kernelSize(2, 2).stride(1, 1).padding(0, 0)
|
||||
.dataFormat(format)
|
||||
.convFormat(format)
|
||||
.nIn(2).nOut(2).activation(Activation.TANH).build(), "input")
|
||||
.addLayer("l2", ConvolutionLayer.builder().kernelSize(2, 2).stride(1, 1)
|
||||
.padding(0, 0).dataFormat(format)
|
||||
.padding(0, 0).convFormat(format)
|
||||
.nIn(2).nOut(2).activation(Activation.TANH).build(), "input")
|
||||
.addVertex("merge", new MergeVertex(), "l1", "l2")
|
||||
.addLayer("outputLayer",
|
||||
|
|
|
@ -116,7 +116,7 @@ public class RnnGradientChecks extends BaseDL4JTest {
|
|||
.layer(Bidirectional.builder(m,
|
||||
(simple ?
|
||||
SimpleRnn.builder().nIn(3).nOut(3).hasLayerNorm(hasLayerNorm).build() :
|
||||
LSTM.builder().nIn(3).nOut(3).build())))
|
||||
LSTM.builder().nIn(3).nOut(3).build())).build())
|
||||
.layer(RnnOutputLayer.builder().nOut(nOut).activation(Activation.SOFTMAX).build())
|
||||
.build();
|
||||
|
||||
|
|
|
@ -115,12 +115,11 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
|
|||
.activation(a)
|
||||
.l1(l1[i]).l2(l2[i])
|
||||
.convolutionMode(ConvolutionMode.Same)
|
||||
.list()
|
||||
.layer(ConvolutionLayer.builder().kernelSize(2, 2).stride(1, 1)
|
||||
.dataFormat(format)
|
||||
.convFormat(format)
|
||||
.nIn(depthIn).nOut(yoloDepth).build())//output: (5-2+0)/1+1 = 4
|
||||
.layer(new Yolo2OutputLayer.Builder()
|
||||
.boundingBoxPriors(bbPrior)
|
||||
.layer(Yolo2OutputLayer.builder()
|
||||
.boundingBoxes(bbPrior)
|
||||
.build())
|
||||
.inputType(InputType.convolutional(h, w, depthIn, format))
|
||||
.build();
|
||||
|
@ -237,8 +236,8 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
|
|||
.layer(ConvolutionLayer.builder().kernelSize(3,3).stride(1,1).nOut(4).build())
|
||||
.layer(SubsamplingLayer.builder().kernelSize(2,2).stride(2,2).build())
|
||||
.layer(ConvolutionLayer.builder().activation(Activation.IDENTITY).kernelSize(3,3).stride(1,1).nOut(depthOut).build())
|
||||
.layer(new Yolo2OutputLayer.Builder()
|
||||
.boundingBoxPriors(bbPriors)
|
||||
.layer(Yolo2OutputLayer.builder()
|
||||
.boundingBoxes(bbPriors)
|
||||
.build())
|
||||
.inputType(InputType.convolutional(h,w,c))
|
||||
.build();
|
||||
|
|
|
@ -437,7 +437,7 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
|
|||
.layer(DenseLayer.builder().nIn(10).nOut(10).build())
|
||||
.layer(!lossLayer ? OutputLayer.builder().nIn(10).nOut(nOut[i])
|
||||
.activation(activations[i]).lossFunction(lf[i]).build()
|
||||
: LossLayer.builder().lossFunction().activation(activations[i]).lossFunction(lf[i])
|
||||
: LossLayer.builder().activation(activations[i]).lossFunction(lf[i].getILossFunction())
|
||||
.build())
|
||||
.validateOutputLayerConfig(validate)
|
||||
.build();
|
||||
|
|
|
@ -48,6 +48,7 @@ import org.nd4j.linalg.learning.config.RmsProp;
|
|||
import org.nd4j.linalg.learning.config.Sgd;
|
||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
@ -230,7 +231,8 @@ public class TestConstraints extends BaseDL4JTest {
|
|||
.biasInit(0.2)
|
||||
|
||||
.layer(DenseLayer.builder().nIn(12).nOut(10)
|
||||
.constrainAllParameters(lc).build())
|
||||
.allParamConstraints(List.of(lc))
|
||||
.build())
|
||||
.layer(OutputLayer.builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(8).build())
|
||||
.build();
|
||||
|
||||
|
|
|
@ -201,21 +201,21 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
|
|||
.addInputs("input1", "input2", "input3")
|
||||
.addLayer("dense1",
|
||||
DenseLayer.builder().nIn(featuresz).nOut(midsz)
|
||||
.activation(new ActivationTanH()).build(),
|
||||
.activation(Activation.TANH).build(),
|
||||
"input1")
|
||||
.addLayer("dense2",
|
||||
DenseLayer.builder().nIn(featuresz).nOut(midsz)
|
||||
.activation(new ActivationTanH()).build(),
|
||||
.activation(Activation.TANH).build(),
|
||||
"input2")
|
||||
.addLayer("dense3",
|
||||
DenseLayer.builder().nIn(featuresz).nOut(midsz)
|
||||
.activation(new ActivationTanH()).build(),
|
||||
.activation(Activation.TANH).build(),
|
||||
"input3")
|
||||
.addVertex("elementwiseAdd", new ElementWiseVertex(ElementWiseVertex.Op.Add), "dense1",
|
||||
"dense2", "dense3")
|
||||
.addLayer("output",
|
||||
OutputLayer.builder().nIn(midsz).nOut(outputsz)
|
||||
.activation(new ActivationSigmoid())
|
||||
.activation(Activation.SIGMOID)
|
||||
.lossFunction(LossFunction.MSE).build(),
|
||||
"elementwiseAdd")
|
||||
.setOutputs("output").build();
|
||||
|
@ -377,21 +377,21 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
|
|||
.addInputs("input1", "input2", "input3")
|
||||
.addLayer("dense1",
|
||||
DenseLayer.builder().nIn(featuresz).nOut(midsz)
|
||||
.activation(new ActivationTanH()).build(),
|
||||
.activation(Activation.TANH).build(),
|
||||
"input1")
|
||||
.addLayer("dense2",
|
||||
DenseLayer.builder().nIn(featuresz).nOut(midsz)
|
||||
.activation(new ActivationTanH()).build(),
|
||||
.activation(Activation.TANH).build(),
|
||||
"input2")
|
||||
.addLayer("dense3",
|
||||
DenseLayer.builder().nIn(featuresz).nOut(midsz)
|
||||
.activation(new ActivationTanH()).build(),
|
||||
.activation(Activation.TANH).build(),
|
||||
"input3")
|
||||
.addVertex("elementwiseProduct", new ElementWiseVertex(ElementWiseVertex.Op.Product), "dense1",
|
||||
"dense2", "dense3")
|
||||
.addLayer("output",
|
||||
OutputLayer.builder().nIn(midsz).nOut(outputsz)
|
||||
.activation(new ActivationSigmoid())
|
||||
.activation(Activation.SIGMOID)
|
||||
.lossFunction(LossFunction.MSE).build(),
|
||||
"elementwiseProduct")
|
||||
.setOutputs("output").build();
|
||||
|
@ -552,17 +552,17 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
|
|||
.addInputs("input1", "input2")
|
||||
.addLayer("dense1",
|
||||
DenseLayer.builder().nIn(featuresz).nOut(midsz)
|
||||
.activation(new ActivationTanH()).build(),
|
||||
.activation(Activation.TANH).build(),
|
||||
"input1")
|
||||
.addLayer("dense2",
|
||||
DenseLayer.builder().nIn(featuresz).nOut(midsz)
|
||||
.activation(new ActivationTanH()).build(),
|
||||
.activation(Activation.TANH).build(),
|
||||
"input2")
|
||||
.addVertex("elementwiseSubtract", new ElementWiseVertex(ElementWiseVertex.Op.Subtract),
|
||||
"dense1", "dense2")
|
||||
.addLayer("output",
|
||||
OutputLayer.builder().nIn(midsz).nOut(outputsz)
|
||||
.activation(new ActivationSigmoid())
|
||||
.activation(Activation.SIGMOID)
|
||||
.lossFunction(LossFunction.MSE).build(),
|
||||
"elementwiseSubtract")
|
||||
.setOutputs("output").build();
|
||||
|
|
|
@ -493,7 +493,7 @@ public class DTypeTests extends BaseDL4JTest {
|
|||
secondLast = ConvolutionLayer.builder().kernelSize(2, 2).stride(1, 1).nOut(3).activation(Activation.TANH).build();
|
||||
break;
|
||||
case 4:
|
||||
ol = new Yolo2OutputLayer.Builder().boundingBoxPriors(Nd4j.create(new double[][]{{1.0, 1.0}, {2.0, 2.0}}).castTo(networkDtype)).build();
|
||||
ol = Yolo2OutputLayer.builder().boundingBoxes(Nd4j.create(new double[][]{{1.0, 1.0}, {2.0, 2.0}}).castTo(networkDtype)).build();
|
||||
secondLast = ConvolutionLayer.builder().kernelSize(2, 2).stride(1, 1).nOut(14).activation(Activation.TANH).build();
|
||||
break;
|
||||
default:
|
||||
|
@ -817,8 +817,8 @@ public class DTypeTests extends BaseDL4JTest {
|
|||
.convolutionMode(ConvolutionMode.Same)
|
||||
.updater(new Adam(1e-2))
|
||||
.list()
|
||||
.layer(new SpaceToBatchLayer.Builder().blocks(1, 1).build())
|
||||
.layer(new SpaceToDepthLayer.Builder().blocks(2).build())
|
||||
.layer(SpaceToBatchLayer.builder().blockSize(1, 1).build())
|
||||
.layer(SpaceToDepthLayer.builder().blockSize(2).build())
|
||||
.layer(OutputLayer.builder().nOut(10).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
||||
.inputType(InputType.convolutional(28, 28, 5))
|
||||
.build();
|
||||
|
@ -907,7 +907,7 @@ public class DTypeTests extends BaseDL4JTest {
|
|||
.layer(Bidirectional.builder(LSTM.builder().nIn(5).nOut(5).activation(Activation.TANH).build()).build())
|
||||
.layer(new TimeDistributed(DenseLayer.builder().nIn(10).nOut(5).activation(Activation.TANH).build()))
|
||||
.layer(SimpleRnn.builder().nIn(5).nOut(5).build())
|
||||
.layer(new MaskZeroLayer.Builder().underlying(SimpleRnn.builder().nIn(5).nOut(5).build()).maskValue(0.0).build())
|
||||
.layer(MaskZeroLayer.builder().underlying(SimpleRnn.builder().nIn(5).nOut(5).build()).maskingValue(0.0).build())
|
||||
.layer(secondLast)
|
||||
.layer(ol)
|
||||
.build();
|
||||
|
@ -986,7 +986,7 @@ public class DTypeTests extends BaseDL4JTest {
|
|||
.updater(new NoOp())
|
||||
.dist(new UniformDistribution(-6, 6))
|
||||
|
||||
.layer(new PrimaryCapsules.Builder(primaryCapsDim, primarpCapsChannel)
|
||||
.layer(PrimaryCapsules.builder(primaryCapsDim, primarpCapsChannel)
|
||||
.kernelSize(3, 3)
|
||||
.stride(2, 2)
|
||||
.build())
|
||||
|
@ -1400,9 +1400,9 @@ public class DTypeTests extends BaseDL4JTest {
|
|||
.weightInit(WeightInit.XAVIER)
|
||||
.list()
|
||||
.layer(LSTM.builder().nOut(layerSize).build())
|
||||
.layer(new SelfAttentionLayer.Builder().nOut(8).nHeads(2).projectInput(true).build())
|
||||
.layer(new LearnedSelfAttentionLayer.Builder().nOut(8).nHeads(2).nQueries(numQueries).projectInput(true).build())
|
||||
.layer(new RecurrentAttentionLayer.Builder().nIn(layerSize).nOut(layerSize).nHeads(1).projectInput(false).hasBias(false).build())
|
||||
.layer(SelfAttentionLayer.builder().nOut(8).nHeads(2).projectInput(true).build())
|
||||
.layer(LearnedSelfAttentionLayer.builder().nOut(8).nHeads(2).nQueries(numQueries).projectInput(true).build())
|
||||
.layer(RecurrentAttentionLayer.builder().nIn(layerSize).nOut(layerSize).nHeads(1).projectInput(false).hasBias(false).build())
|
||||
.layer(GlobalPoolingLayer.builder().poolingType(PoolingType.MAX).build())
|
||||
.layer(OutputLayer.builder().nOut(nOut).activation(Activation.SOFTMAX)
|
||||
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
||||
|
|
|
@ -161,7 +161,7 @@ public class TestCompGraphCNN extends BaseDL4JTest {
|
|||
imageHeight))
|
||||
.addLayer("conv1", ConvolutionLayer.builder()
|
||||
.kernelSize(kernelHeight, kernelWidth).stride(1, 1)
|
||||
.dataFormat(CNN2DFormat.NCHW)
|
||||
.convFormat(CNN2DFormat.NCHW)
|
||||
.nIn(nChannels).nOut(2).weightInit(WeightInit.XAVIER)
|
||||
.activation(Activation.RELU).build(), "input")
|
||||
.addLayer("pool1",
|
||||
|
|
|
@ -1163,7 +1163,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
|
|||
"act")
|
||||
.addLayer("drop", DropoutLayer.builder(0.5).build(), "pool")
|
||||
.addLayer("dense", DenseLayer.builder().nIn(1).nOut(1).build(), "drop")
|
||||
.addLayer("loss", LossLayer.builder().lossFunction(LossFunctions.LossFunction.MCXENT)
|
||||
.addLayer("loss", LossLayer.builder().lossFunction(LossFunctions.LossFunction.MCXENT.getILossFunction())
|
||||
.build(), "dense")
|
||||
.allowDisconnected(true)
|
||||
.setOutputs("loss").build();
|
||||
|
@ -1457,7 +1457,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
|
|||
.graphBuilder()
|
||||
.addInputs("in")
|
||||
.layer("0", SubsamplingLayer.builder().kernelSize(2,2).stride(2,2).build(), "in")
|
||||
.layer("1", LossLayer.builder().lossFunction().activation(Activation.SIGMOID).lossFunction(LossFunctions.LossFunction.MSE).build(), "0")
|
||||
.layer("1", LossLayer.builder().activation(Activation.SIGMOID).lossFunction(LossFunctions.LossFunction.MSE.getILossFunction()).build(), "0")
|
||||
.setOutputs("1")
|
||||
.setInputTypes(InputType.convolutionalFlat(28,28,1))
|
||||
.build();
|
||||
|
@ -1791,7 +1791,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
|
|||
.nIn(10).nOut(5)
|
||||
.activation(Activation.TANH)
|
||||
.dropOut(new GaussianNoise(0.05))
|
||||
.build())
|
||||
.build()).build()
|
||||
,"merge")
|
||||
.addLayer("out1",
|
||||
RnnOutputLayer.builder().activation(Activation.SOFTMAX)
|
||||
|
@ -1986,10 +1986,10 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
|
|||
.updater(new Adam())
|
||||
.graphBuilder()
|
||||
.addInputs("x_emb")
|
||||
.addLayer("agg_lstm", Bidirectional.builder(CONCAT, LSTM.builder().nOut(hiddenSize/2).build()), "x_emb")
|
||||
.addLayer("agg_lstm", Bidirectional.builder(CONCAT, LSTM.builder().nOut(hiddenSize/2).build()).build(), "x_emb")
|
||||
.addLayer("agg_att", DenseLayer.builder().nIn(100).nOut(1).activation(Activation.SOFTMAX).build(), "agg_lstm")
|
||||
.addVertex("att", new PreprocessorVertex(new ComposableInputPreProcessor(new FeedForwardToRnnPreProcessor(), new PermutePreprocessor(0,2,1), new RnnToFeedForwardPreProcessor())), "agg_att")
|
||||
.addLayer("att_repeat", new RepeatVector.Builder(hiddenSize).build(),"att")
|
||||
.addLayer("att_repeat", RepeatVector.builder().repetitionFactor(hiddenSize).build(),"att")
|
||||
.addVertex("att_trans", new PreprocessorVertex(new PermutePreprocessor(0, 2, 1)), "att_repeat")
|
||||
.addVertex("mult", new ElementWiseVertex(ElementWiseVertex.Op.Product), "agg_lstm", "att_trans")
|
||||
.addLayer("sum", GlobalPoolingLayer.builder().build(), "mult")
|
||||
|
@ -2197,16 +2197,16 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
|
|||
.addInputs("in")
|
||||
.layer("l0", ConvolutionLayer.builder()
|
||||
.nOut(16)
|
||||
.dataFormat(CNN2DFormat.NHWC)
|
||||
.convFormat(CNN2DFormat.NHWC)
|
||||
.kernelSize(2,2).stride(1,1)
|
||||
.build(), "in")
|
||||
.layer("l1", ConvolutionLayer.builder()
|
||||
.nOut(8)
|
||||
.dataFormat(CNN2DFormat.NHWC)
|
||||
.convFormat(CNN2DFormat.NHWC)
|
||||
.kernelSize(2,2).stride(1,1)
|
||||
.build(), "in")
|
||||
.addVertex("merge", new MergeVertex(), "l0", "l1")
|
||||
.layer("out", CnnLossLayer.builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build(), "merge")
|
||||
.layer("out", CnnLossLayer.builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE.getILossFunction()).build(), "merge")
|
||||
.setOutputs("out")
|
||||
.setInputTypes(InputType.convolutional(32, 32, 3, CNN2DFormat.NHWC))
|
||||
.build();
|
||||
|
|
|
@ -357,7 +357,7 @@ public class ActivationLayerTest extends BaseDL4JTest {
|
|||
.activation(Activation.RATIONALTANH)
|
||||
|
||||
.layer(DenseLayer.builder().nIn(10).nOut(10).build())
|
||||
.layer(ActivationLayer.builder())
|
||||
.layer(ActivationLayer.builder().build())
|
||||
.layer(ActivationLayer.builder().build())
|
||||
.layer(ActivationLayer.builder().activation(Activation.ELU).build())
|
||||
.layer(
|
||||
|
@ -404,7 +404,7 @@ public class ActivationLayerTest extends BaseDL4JTest {
|
|||
.graphBuilder()
|
||||
.addInputs("in")
|
||||
.addLayer("0", DenseLayer.builder().nIn(10).nOut(10).build(), "in")
|
||||
.addLayer("1", ActivationLayer.builder(), "0")
|
||||
.addLayer("1", ActivationLayer.builder().build(), "0")
|
||||
.addLayer("2", ActivationLayer.builder().build(), "1")
|
||||
.addLayer("3", ActivationLayer.builder().activation(Activation.ELU).build(), "2")
|
||||
.addLayer(
|
||||
|
|
|
@ -63,7 +63,7 @@ public class CapsNetMNISTTest extends BaseDL4JTest {
|
|||
.kernelSize(9, 9)
|
||||
.stride(3, 3)
|
||||
.build())
|
||||
.layer(new PrimaryCapsules.Builder(8, 8)
|
||||
.layer(PrimaryCapsules.builder(8, 8)
|
||||
.kernelSize(7, 7)
|
||||
.stride(2, 2)
|
||||
.build())
|
||||
|
|
|
@ -44,7 +44,7 @@ public class PrimaryCapsulesTest extends BaseDL4JTest {
|
|||
|
||||
@Test
|
||||
public void testOutputType(){
|
||||
PrimaryCapsules layer = new PrimaryCapsules.Builder(8, 8)
|
||||
PrimaryCapsules layer = PrimaryCapsules.builder(8, 8)
|
||||
.kernelSize(7, 7)
|
||||
.stride(2, 2)
|
||||
.build();
|
||||
|
@ -57,7 +57,7 @@ public class PrimaryCapsulesTest extends BaseDL4JTest {
|
|||
|
||||
@Test
|
||||
public void testInputType(){
|
||||
PrimaryCapsules layer = new PrimaryCapsules.Builder(8, 8)
|
||||
PrimaryCapsules layer = PrimaryCapsules.builder(8, 8)
|
||||
.kernelSize(7, 7)
|
||||
.stride(2, 2)
|
||||
.build();
|
||||
|
@ -72,7 +72,7 @@ public class PrimaryCapsulesTest extends BaseDL4JTest {
|
|||
|
||||
@Test
|
||||
public void testConfig(){
|
||||
PrimaryCapsules layer1 = new PrimaryCapsules.Builder(8, 10)
|
||||
PrimaryCapsules layer1 = PrimaryCapsules.builder(8, 10)
|
||||
.kernelSize(5, 5)
|
||||
.stride(4, 4)
|
||||
.useLeakyReLU(0.5)
|
||||
|
@ -84,22 +84,22 @@ public class PrimaryCapsulesTest extends BaseDL4JTest {
|
|||
assertArrayEquals(new int[]{4, 4}, layer1.getStride());
|
||||
assertArrayEquals(new int[]{0, 0}, layer1.getPadding());
|
||||
assertArrayEquals(new int[]{1, 1}, layer1.getDilation());
|
||||
assertTrue(layer1.isUseRelu());
|
||||
assertEquals(0.5, layer1.getLeak(), 0.001);
|
||||
assertTrue(layer1.isUseRelU());
|
||||
assertEquals(0.5, layer1.getUseLeakyReLU(), 0.001);
|
||||
|
||||
PrimaryCapsules layer2 = new PrimaryCapsules.Builder(8, 10)
|
||||
PrimaryCapsules layer2 = PrimaryCapsules.builder(8, 10)
|
||||
.kernelSize(5, 5)
|
||||
.stride(4, 4)
|
||||
.build();
|
||||
assertFalse(layer2.isUseRelu());
|
||||
assertFalse(layer2.isUseRelU());
|
||||
|
||||
PrimaryCapsules layer3 = new PrimaryCapsules.Builder(8, 10)
|
||||
PrimaryCapsules layer3 = PrimaryCapsules.builder(8, 10)
|
||||
.kernelSize(5, 5)
|
||||
.stride(4, 4)
|
||||
.useReLU()
|
||||
.build();
|
||||
assertTrue(layer3.isUseRelu());
|
||||
assertEquals(0, layer3.getLeak(), 0.001);
|
||||
assertTrue(layer3.isUseRelU());
|
||||
assertEquals(0, layer3.getUseLeakyReLU(), 0.001);
|
||||
|
||||
}
|
||||
|
||||
|
@ -108,7 +108,7 @@ public class PrimaryCapsulesTest extends BaseDL4JTest {
|
|||
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
||||
.seed(123)
|
||||
.list()
|
||||
.layer(new PrimaryCapsules.Builder(8, 10)
|
||||
.layer(PrimaryCapsules.builder(8, 10)
|
||||
.kernelSize(5, 5)
|
||||
.stride(4, 4)
|
||||
.useLeakyReLU(0.5)
|
||||
|
|
|
@ -561,7 +561,7 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
.kernelSize(3, 3)
|
||||
.stride(2, 2)
|
||||
.activation(Activation.TANH)
|
||||
.dataFormat(format)
|
||||
.convFormat(format)
|
||||
.nOut(3)
|
||||
.helperAllowFallback(false)
|
||||
.build(), format, cm, null);
|
||||
|
@ -685,14 +685,14 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
return getNetWithLayer(Deconvolution2D.builder().nOut(2)
|
||||
.activation(Activation.TANH)
|
||||
.kernelSize(2,2)
|
||||
.dataFormat(format)
|
||||
.convFormat(format)
|
||||
.stride(2,2)
|
||||
.build(), format, cm, null);
|
||||
} else {
|
||||
return getNetWithLayer(Deconvolution2D.builder().nOut(2)
|
||||
.activation(Activation.TANH)
|
||||
.kernelSize(2,2)
|
||||
.dataFormat(format)
|
||||
.convFormat(format)
|
||||
.stride(2,2)
|
||||
.build(), format, cm, null);
|
||||
}
|
||||
|
@ -715,26 +715,26 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
|
||||
private MultiLayerNetwork getSpaceToDepthNet(CNN2DFormat format, boolean setOnLayerAlso) {
|
||||
if (setOnLayerAlso) {
|
||||
return getNetWithLayer(new SpaceToDepthLayer.Builder()
|
||||
.blocks(2)
|
||||
return getNetWithLayer(SpaceToDepthLayer.builder()
|
||||
.blockSize(2)
|
||||
.dataFormat(format)
|
||||
.build(), format, ConvolutionMode.Same, null);
|
||||
} else {
|
||||
return getNetWithLayer(new SpaceToDepthLayer.Builder()
|
||||
.blocks(2)
|
||||
return getNetWithLayer(SpaceToDepthLayer.builder()
|
||||
.blockSize(2)
|
||||
.build(), format, ConvolutionMode.Same, null);
|
||||
}
|
||||
}
|
||||
|
||||
private MultiLayerNetwork getSpaceToBatchNet(CNN2DFormat format, boolean setOnLayerAlso) {
|
||||
if (setOnLayerAlso) {
|
||||
return getNetWithLayer(new SpaceToBatchLayer.Builder()
|
||||
.blocks(2, 2)
|
||||
return getNetWithLayer(SpaceToBatchLayer.builder()
|
||||
.blockSize(2, 2)
|
||||
.dataFormat(format)
|
||||
.build(), format, ConvolutionMode.Same, InputType.convolutional(16, 16, 3, format));
|
||||
} else {
|
||||
return getNetWithLayer(new SpaceToBatchLayer.Builder()
|
||||
.blocks(2, 2)
|
||||
return getNetWithLayer(SpaceToBatchLayer.builder()
|
||||
.blockSize(2, 2)
|
||||
.build(), format, ConvolutionMode.Same, InputType.convolutional(16, 16, 3, format));
|
||||
}
|
||||
}
|
||||
|
@ -807,7 +807,7 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
.kernelSize(3, 3)
|
||||
.stride(2, 2)
|
||||
.activation(Activation.TANH)
|
||||
.dataFormat(format)
|
||||
.convFormat(format)
|
||||
.nOut(3)
|
||||
.helperAllowFallback(false)
|
||||
.build());
|
||||
|
@ -988,7 +988,7 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
|
||||
switch (i){
|
||||
case 0:
|
||||
b.layer(ConvolutionLayer.builder().kernelSize(2,2).nIn(3).nOut(3).dataFormat(df).build());
|
||||
b.layer(ConvolutionLayer.builder().kernelSize(2,2).nIn(3).nOut(3).convFormat(df).build());
|
||||
b.inputType(InputType.convolutional(12,12,3,df));
|
||||
break;
|
||||
case 1:
|
||||
|
@ -996,7 +996,7 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
|||
b.inputType(InputType.convolutional(12,12,3,df));
|
||||
break;
|
||||
case 2:
|
||||
b.layer(Deconvolution2D.builder().dataFormat(df).kernelSize(2,2).nIn(3).nOut(3).build());
|
||||
b.layer(Deconvolution2D.builder().convFormat(df).kernelSize(2,2).nIn(3).nOut(3).build());
|
||||
b.inputType(InputType.convolutional(12,12,3,df));
|
||||
break;
|
||||
case 3:
|
||||
|
|
|
@ -27,6 +27,7 @@ import org.deeplearning4j.BaseDL4JTest;
|
|||
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
|
||||
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
||||
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
|
@ -370,7 +371,7 @@ public class ConvolutionLayerSetupTest extends BaseDL4JTest {
|
|||
|
||||
NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().list()
|
||||
.layer(ConvolutionLayer.builder(2, 2).padding(0, 0).stride(2, 2).nIn(1).nOut(3).build()) //(28-2+0)/2+1 = 14
|
||||
.layer(new SpaceToBatchLayer.Builder(blocks).build()) // Divide space dimensions by blocks, i.e. 14/2 = 7
|
||||
.layer(SpaceToBatchLayer.builder(blocks).build()) // Divide space dimensions by blocks, i.e. 14/2 = 7
|
||||
.layer(OutputLayer.builder().nOut(3).activation(Activation.SOFTMAX).build())
|
||||
.inputType(InputType.convolutional(28, 28, 1));
|
||||
|
||||
|
@ -389,11 +390,11 @@ public class ConvolutionLayerSetupTest extends BaseDL4JTest {
|
|||
|
||||
int blocks = 2;
|
||||
|
||||
NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().list()
|
||||
NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder()
|
||||
//(28-2+0)/2+1 = 14 -> 14x14x3 out
|
||||
.layer(ConvolutionLayer.builder(2, 2).padding(0, 0).stride(2, 2).nIn(1).nOut(3).build())
|
||||
// Divide space dimensions by blocks, i.e. 14/2 = 7 -> 7x7x12 out (3x2x2 depth)
|
||||
.layer(new SpaceToDepthLayer.Builder(blocks, SpaceToDepthLayer.DataFormat.NCHW).build())
|
||||
.layer(SpaceToDepthLayer.builder().blockSize(blocks).dataFormat(CNN2DFormat.NCHW).build())
|
||||
.layer(OutputLayer.builder().nIn(3 * 2 * 2).nOut(3).activation(Activation.SOFTMAX).build()) // nIn of the next layer gets multiplied by 2*2.
|
||||
.inputType(InputType.convolutional(28, 28, 1));
|
||||
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -71,7 +71,7 @@ public class LocallyConnectedLayerTest extends BaseDL4JTest {
|
|||
.layer(LocallyConnected2D.builder().kernelSize(8, 8).nIn(3)
|
||||
.stride(4, 4).nOut(16).dropOut(0.5)
|
||||
.convolutionMode(ConvolutionMode.Strict)
|
||||
.setInputSize(28, 28)
|
||||
.inputSize(28, 28)
|
||||
.activation(Activation.RELU).weightInit(
|
||||
WeightInit.XAVIER)
|
||||
.build())
|
||||
|
@ -94,11 +94,10 @@ public class LocallyConnectedLayerTest extends BaseDL4JTest {
|
|||
NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().seed(123)
|
||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).l2(2e-4)
|
||||
.updater(new Nesterovs(0.9)).dropOut(0.5)
|
||||
.list()
|
||||
.layer(LocallyConnected1D.builder().kernelSize(4).nIn(3)
|
||||
.stride(1).nOut(16).dropOut(0.5)
|
||||
.convolutionMode(ConvolutionMode.Strict)
|
||||
.setInputSize(28)
|
||||
.inputSize(28)
|
||||
.activation(Activation.RELU).weightInit(
|
||||
WeightInit.XAVIER)
|
||||
.build())
|
||||
|
|
|
@ -61,7 +61,7 @@ public class SpaceToDepthTest extends BaseDL4JTest {
|
|||
private Layer getSpaceToDepthLayer() {
|
||||
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
||||
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).seed(123)
|
||||
.layer(new SpaceToDepthLayer.Builder(blockSize, dataFormat).build()).build();
|
||||
.layer(SpaceToDepthLayer.builder().blockSize(blockSize).dataFormat(dataFormat.toFormat()).build()).build();
|
||||
return conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, null, true, Nd4j.defaultFloatingPointType());
|
||||
}
|
||||
|
||||
|
|
|
@ -20,6 +20,8 @@
|
|||
|
||||
package org.deeplearning4j.nn.layers.custom;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
||||
|
@ -30,31 +32,38 @@ import org.nd4j.linalg.activations.Activation;
|
|||
import org.nd4j.linalg.learning.config.Sgd;
|
||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
public class TestCustomActivation extends BaseDL4JTest {
|
||||
|
||||
@Test
|
||||
public void testCustomActivationFn() {
|
||||
//Second: let's create a MultiLayerCofiguration with one, and check JSON and YAML config actually works...
|
||||
@Test
|
||||
public void testCustomActivationFn() {
|
||||
// Second: let's create a MultiLayerCofiguration with one, and check JSON and YAML config
|
||||
// actually works...
|
||||
|
||||
NeuralNetConfiguration conf = NeuralNetConfiguration.builder().updater(new Sgd(0.1)).list()
|
||||
.layer(0, DenseLayer.builder().nIn(10).nOut(10).activation(new CustomActivation()).build())
|
||||
.layer(1, OutputLayer.builder().lossFunction(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(10).build())
|
||||
.build();
|
||||
NeuralNetConfiguration conf =
|
||||
NeuralNetConfiguration.builder()
|
||||
.updater(new Sgd(0.1))
|
||||
|
||||
String json = conf.toJson();
|
||||
String yaml = conf.toYaml();
|
||||
.layer(
|
||||
0, DenseLayer.builder().nIn(10).nOut(10).activation(new CustomActivation()).build())
|
||||
.layer(
|
||||
1,
|
||||
OutputLayer.builder()
|
||||
.lossFunction(LossFunctions.LossFunction.MCXENT)
|
||||
.activation(Activation.SOFTMAX)
|
||||
.nIn(10)
|
||||
.nOut(10)
|
||||
.build())
|
||||
.build();
|
||||
|
||||
// System.out.println(json);
|
||||
String json = conf.toJson();
|
||||
String yaml = conf.toYaml();
|
||||
|
||||
NeuralNetConfiguration confFromJson = NeuralNetConfiguration.fromJson(json);
|
||||
assertEquals(conf, confFromJson);
|
||||
// System.out.println(json);
|
||||
|
||||
NeuralNetConfiguration confFromYaml = NeuralNetConfiguration.fromYaml(yaml);
|
||||
assertEquals(conf, confFromYaml);
|
||||
|
||||
}
|
||||
NeuralNetConfiguration confFromJson = NeuralNetConfiguration.fromJson(json);
|
||||
assertEquals(conf, confFromJson);
|
||||
|
||||
NeuralNetConfiguration confFromYaml = NeuralNetConfiguration.fromYaml(yaml);
|
||||
assertEquals(conf, confFromYaml);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -119,7 +119,7 @@ public class TestCustomLayers extends BaseDL4JTest {
|
|||
NeuralNetConfiguration conf =
|
||||
NeuralNetConfiguration.builder().seed(12345).list()
|
||||
.layer(0, DenseLayer.builder().nIn(10).nOut(10).build())
|
||||
.layer(1, new CustomOutputLayer.builder().lossFunction(LossFunctions.LossFunction.MCXENT)
|
||||
.layer(1, CustomOutputLayer.builder().lossFunction(LossFunctions.LossFunction.MCXENT)
|
||||
.activation(Activation.SOFTMAX)
|
||||
.nIn(10).nOut(10).build())
|
||||
.build();
|
||||
|
@ -172,7 +172,7 @@ public class TestCustomLayers extends BaseDL4JTest {
|
|||
ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345)
|
||||
.graphBuilder().addInputs("in")
|
||||
.addLayer("0", DenseLayer.builder().nIn(10).nOut(10).build(), "in").addLayer("1",
|
||||
new CustomOutputLayer.builder().lossFunction(LossFunctions.LossFunction.MCXENT).nIn(10)
|
||||
CustomOutputLayer.builder().lossFunction(LossFunctions.LossFunction.MCXENT).nIn(10)
|
||||
.nOut(10).activation(Activation.SOFTMAX).build(),
|
||||
"0")
|
||||
.setOutputs("1").build();
|
||||
|
|
|
@ -91,8 +91,8 @@ public class TestYolo2OutputLayer extends BaseDL4JTest {
|
|||
.l2(0.01)
|
||||
.list()
|
||||
.layer(ConvolutionLayer.builder().nIn(depth).nOut(depth).kernelSize(1,1).build())
|
||||
.layer(new Yolo2OutputLayer.Builder()
|
||||
.boundingBoxPriors(bbPrior)
|
||||
.layer(Yolo2OutputLayer.builder()
|
||||
.boundingBoxes(bbPrior)
|
||||
.build())
|
||||
.build();
|
||||
|
||||
|
@ -179,8 +179,8 @@ public class TestYolo2OutputLayer extends BaseDL4JTest {
|
|||
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
||||
.list()
|
||||
.layer(ConvolutionLayer.builder().nIn(1).nOut(1).kernelSize(1,1).build())
|
||||
.layer(new Yolo2OutputLayer.Builder()
|
||||
.boundingBoxPriors(bbPrior)
|
||||
.layer(Yolo2OutputLayer.builder()
|
||||
.boundingBoxes(bbPrior)
|
||||
.build())
|
||||
.build();
|
||||
|
||||
|
@ -337,8 +337,8 @@ public class TestYolo2OutputLayer extends BaseDL4JTest {
|
|||
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
||||
.list()
|
||||
.layer(ConvolutionLayer.builder().kernelSize(3,3).stride(1,1).nIn(3).nOut(3).build())
|
||||
.layer(new Yolo2OutputLayer.Builder()
|
||||
.boundingBoxPriors(bbPriors)
|
||||
.layer(Yolo2OutputLayer.builder()
|
||||
.boundingBoxes(bbPriors)
|
||||
.build())
|
||||
.build();
|
||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||
|
@ -506,8 +506,8 @@ public class TestYolo2OutputLayer extends BaseDL4JTest {
|
|||
.layer(ConvolutionLayer.builder().kernelSize(5,5).stride(2,2).nOut(256).build())
|
||||
.layer(SubsamplingLayer.builder().kernelSize(2,2).stride(2,2)/*.poolingType(SubsamplingLayer.PoolingType.AVG)*/.build())
|
||||
.layer(ConvolutionLayer.builder().activation(Activation.IDENTITY).kernelSize(5,5).stride(1,1).nOut(depthOut).build())
|
||||
.layer(new Yolo2OutputLayer.Builder()
|
||||
.boundingBoxPriors(bbPriors)
|
||||
.layer(Yolo2OutputLayer.builder()
|
||||
.boundingBoxes(bbPriors)
|
||||
.build())
|
||||
.inputType(InputType.convolutional(h,w,c))
|
||||
.build();
|
||||
|
|
|
@ -209,7 +209,7 @@ public class RnnDataFormatTests extends BaseDL4JTest {
|
|||
return getNetWithLayer(GravesBidirectionalLSTM.builder().nOut(3)
|
||||
.dataFormat(format).build(), format, lastTimeStep, maskZeros);
|
||||
} else {
|
||||
return getNetWithLayer(new GravesBidirectionalLSTM.Builder().nOut(3).build(), format, lastTimeStep, maskZeros);
|
||||
return getNetWithLayer(GravesBidirectionalLSTM.builder().nOut(3).build(), format, lastTimeStep, maskZeros);
|
||||
}
|
||||
}
|
||||
private MultiLayerNetwork getGravesLstmNet(RNNFormat format, boolean setOnLayerAlso, boolean lastTimeStep, boolean maskZeros) {
|
||||
|
@ -240,7 +240,7 @@ public class RnnDataFormatTests extends BaseDL4JTest {
|
|||
}
|
||||
private MultiLayerNetwork getNetWithLayer(LayerConfiguration layer, RNNFormat format, boolean lastTimeStep, boolean maskZeros) {
|
||||
if (maskZeros){
|
||||
layer = new MaskZeroLayer.Builder().setMaskValue(0.).setUnderlying(layer).build();
|
||||
layer = MaskZeroLayer.builder().maskingValue(0.).underlying(layer).build();
|
||||
}
|
||||
if(lastTimeStep){
|
||||
layer = new LastTimeStep(layer);
|
||||
|
|
|
@ -27,6 +27,7 @@ import org.deeplearning4j.nn.conf.layers.GravesLSTM;
|
|||
import org.deeplearning4j.nn.conf.layers.LSTM;
|
||||
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.deeplearning4j.nn.weights.WeightInitDistribution;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
|
@ -48,17 +49,17 @@ public class TestRecurrentWeightInit extends BaseDL4JTest {
|
|||
switch (i) {
|
||||
case 0:
|
||||
b.layer(LSTM.builder().nIn(10).nOut(10)
|
||||
.weightInitRecurrent(new UniformDistribution(2, 3))
|
||||
.weightInitRecurrent(new WeightInitDistribution(new UniformDistribution(2, 3)))
|
||||
.build());
|
||||
break;
|
||||
case 1:
|
||||
b.layer(GravesLSTM.builder().nIn(10).nOut(10)
|
||||
.weightInitRecurrent(new UniformDistribution(2, 3))
|
||||
.weightInitRecurrent(new WeightInitDistribution(new UniformDistribution(2, 3)))
|
||||
.build());
|
||||
break;
|
||||
case 2:
|
||||
b.layer(SimpleRnn.builder().nIn(10).nOut(10)
|
||||
.weightInitRecurrent(new UniformDistribution(2, 3)).build());
|
||||
.weightInitRecurrent(new WeightInitDistribution(new UniformDistribution(2, 3))).build());
|
||||
break;
|
||||
default:
|
||||
throw new RuntimeException();
|
||||
|
|
|
@ -145,8 +145,8 @@ public class TestTimeDistributed extends BaseDL4JTest {
|
|||
l2 = SimpleRnn.builder().nOut(5).build();
|
||||
break;
|
||||
case 2:
|
||||
l0 = Bidirectional.builder(LSTM.builder().nOut(5).build());
|
||||
l2 = Bidirectional.builder(LSTM.builder().nOut(5).build());
|
||||
l0 = Bidirectional.builder(LSTM.builder().nOut(5).build()).build();
|
||||
l2 = Bidirectional.builder(LSTM.builder().nOut(5).build()).build();
|
||||
break;
|
||||
default:
|
||||
throw new RuntimeException("Not implemented: " + rnnType);
|
||||
|
|
|
@ -67,7 +67,7 @@ public class TestSameDiffConv extends BaseDL4JTest {
|
|||
|
||||
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
||||
.list()
|
||||
.layer(new SameDiffConv.Builder().nIn(nIn).nOut(nOut).kernelSize(kH, kW).build())
|
||||
.layer(SameDiffConv.builder().nIn(nIn).nOut(nOut).kernelSize(kH, kW).build())
|
||||
.build();
|
||||
|
||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||
|
@ -131,7 +131,7 @@ public class TestSameDiffConv extends BaseDL4JTest {
|
|||
.dataType(DataType.DOUBLE)
|
||||
.seed(12345)
|
||||
.list()
|
||||
.layer(new SameDiffConv.Builder()
|
||||
.layer(SameDiffConv.builder()
|
||||
.weightInit(WeightInit.XAVIER)
|
||||
.nIn(nIn)
|
||||
.nOut(nOut)
|
||||
|
@ -142,7 +142,7 @@ public class TestSameDiffConv extends BaseDL4JTest {
|
|||
.activation(a)
|
||||
.hasBias(hasBias)
|
||||
.build())
|
||||
.layer(new SameDiffConv.Builder()
|
||||
.layer(SameDiffConv.builder()
|
||||
.weightInit(WeightInit.XAVIER)
|
||||
.nIn(nOut)
|
||||
.nOut(nOut)
|
||||
|
@ -273,7 +273,7 @@ public class TestSameDiffConv extends BaseDL4JTest {
|
|||
.trainingWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE)
|
||||
.inferenceWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE)
|
||||
.list()
|
||||
.layer(new SameDiffConv.Builder()
|
||||
.layer(SameDiffConv.builder()
|
||||
.weightInit(WeightInit.XAVIER)
|
||||
.nIn(nIn)
|
||||
.nOut(nOut)
|
||||
|
@ -284,7 +284,7 @@ public class TestSameDiffConv extends BaseDL4JTest {
|
|||
.activation(Activation.TANH)
|
||||
.hasBias(hasBias)
|
||||
.build())
|
||||
.layer(new SameDiffConv.Builder()
|
||||
.layer(SameDiffConv.builder()
|
||||
.weightInit(WeightInit.XAVIER)
|
||||
.nIn(nOut)
|
||||
.nOut(nOut)
|
||||
|
|
|
@ -65,7 +65,7 @@ public class TestSameDiffDense extends BaseDL4JTest {
|
|||
|
||||
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
||||
.list()
|
||||
.layer(new SameDiffDense.Builder().nIn(nIn).nOut(nOut).build())
|
||||
.layer(SameDiffDense.builder().nIn(nIn).nOut(nOut).build())
|
||||
.build();
|
||||
|
||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||
|
@ -106,7 +106,7 @@ public class TestSameDiffDense extends BaseDL4JTest {
|
|||
.inferenceWorkspaceMode(wsm)
|
||||
.trainingWorkspaceMode(wsm)
|
||||
.list()
|
||||
.layer(new SameDiffDense.Builder().nIn(nIn).nOut(nOut)
|
||||
.layer(SameDiffDense.builder().nIn(nIn).nOut(nOut)
|
||||
.activation(a)
|
||||
.build())
|
||||
.build();
|
||||
|
@ -178,10 +178,10 @@ public class TestSameDiffDense extends BaseDL4JTest {
|
|||
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
||||
.seed(12345)
|
||||
.list()
|
||||
.layer(new SameDiffDense.Builder().nIn(nIn).nOut(nOut)
|
||||
.layer(SameDiffDense.builder().nIn(nIn).nOut(nOut)
|
||||
.weightInit(WeightInit.XAVIER)
|
||||
.activation(a).build())
|
||||
.layer(new SameDiffDense.Builder().nIn(nOut).nOut(nOut)
|
||||
.layer(SameDiffDense.builder().nIn(nOut).nOut(nOut)
|
||||
.weightInit(WeightInit.XAVIER)
|
||||
.activation(a).build())
|
||||
.layer(OutputLayer.builder().nIn(nOut).nOut(nOut)
|
||||
|
@ -267,7 +267,7 @@ public class TestSameDiffDense extends BaseDL4JTest {
|
|||
.trainingWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE)
|
||||
.inferenceWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE)
|
||||
.list()
|
||||
.layer(new SameDiffDense.Builder().nIn(nIn).nOut(nOut)
|
||||
.layer(SameDiffDense.builder().nIn(nIn).nOut(nOut)
|
||||
.activation(a)
|
||||
.build())
|
||||
.layer(OutputLayer.builder().nIn(nOut).nOut(nOut).activation(Activation.SOFTMAX)
|
||||
|
@ -357,8 +357,8 @@ public class TestSameDiffDense extends BaseDL4JTest {
|
|||
.inferenceWorkspaceMode(wsm)
|
||||
.updater(new Adam(0.1))
|
||||
.list()
|
||||
.layer(new SameDiffDense.Builder().nIn(nIn).nOut(5).activation(Activation.TANH).build())
|
||||
.layer(new SameDiffDense.Builder().nIn(5).nOut(5).activation(Activation.TANH).build())
|
||||
.layer(SameDiffDense.builder().nIn(nIn).nOut(5).activation(Activation.TANH).build())
|
||||
.layer(SameDiffDense.builder().nIn(5).nOut(5).activation(Activation.TANH).build())
|
||||
.layer(OutputLayer.builder().nIn(5).nOut(nOut).activation(Activation.SOFTMAX)
|
||||
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
||||
.build();
|
||||
|
@ -428,8 +428,8 @@ public class TestSameDiffDense extends BaseDL4JTest {
|
|||
.trainingWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE)
|
||||
.inferenceWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE)
|
||||
.list()
|
||||
.layer(new SameDiffDense.Builder().nIn(nIn).nOut(nOut).activation(a).build())
|
||||
.layer(new SameDiffDense.Builder().nIn(nOut).nOut(nOut).activation(a).build())
|
||||
.layer(SameDiffDense.builder().nIn(nIn).nOut(nOut).activation(a).build())
|
||||
.layer(SameDiffDense.builder().nIn(nOut).nOut(nOut).activation(a).build())
|
||||
.layer(OutputLayer.builder().nIn(nOut).nOut(nOut).activation(Activation.SOFTMAX)
|
||||
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
||||
//.inputType(InputType.feedForward(nIn)) //TODO
|
||||
|
|
|
@ -60,7 +60,7 @@ public class TestSameDiffOutput extends BaseDL4JTest {
|
|||
.updater(new Adam(0.01))
|
||||
.list()
|
||||
.layer(DenseLayer.builder().nIn(5).nOut(5).activation(Activation.TANH).build())
|
||||
.layer(LossLayer.builder().lossFunction().activation(Activation.IDENTITY).lossFunction(LossFunctions.LossFunction.MSE).build())
|
||||
.layer(LossLayer.builder().activation(Activation.IDENTITY).lossFunction(LossFunctions.LossFunction.MSE.getILossFunction()).build())
|
||||
.build();
|
||||
|
||||
MultiLayerNetwork netSD = new MultiLayerNetwork(confSD);
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
package org.deeplearning4j.nn.layers.samediff.testlayers;
|
||||
|
||||
import lombok.*;
|
||||
import lombok.experimental.SuperBuilder;
|
||||
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
|
@ -45,52 +46,62 @@ import java.util.*;
|
|||
@Data
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@JsonIgnoreProperties({"paramShapes"})
|
||||
@NoArgsConstructor
|
||||
@SuperBuilder
|
||||
public class SameDiffConv extends SameDiffLayer {
|
||||
public static abstract class SameDiffConvBuilder<C extends SameDiffConv, B extends SameDiffConvBuilder<C, B>> extends
|
||||
SameDiffLayerBuilder<C, B> {
|
||||
public B kernelSize(int... k) {
|
||||
this.kernelSize$value = k;
|
||||
this.kernelSize$set = true;
|
||||
return self();
|
||||
}
|
||||
|
||||
public B stride(int... s) {
|
||||
this.stride$value = s;
|
||||
this.stride$set = true;
|
||||
return self();
|
||||
}
|
||||
|
||||
public B padding(int... p) {
|
||||
this.padding$value = p;
|
||||
this.padding$set = true;
|
||||
return self();
|
||||
}
|
||||
}
|
||||
|
||||
private static final List<String> WEIGHT_KEYS = Collections.singletonList(ConvolutionParamInitializer.WEIGHT_KEY);
|
||||
private static final List<String> BIAS_KEYS = Collections.singletonList(ConvolutionParamInitializer.BIAS_KEY);
|
||||
//Order to match 'vanilla' conv layer implementation, for easy comparison
|
||||
private static final List<String> PARAM_KEYS = Arrays.asList(ConvolutionParamInitializer.BIAS_KEY, ConvolutionParamInitializer.WEIGHT_KEY);
|
||||
|
||||
private long nIn;
|
||||
private long nOut;
|
||||
private Activation activation;
|
||||
private int[] kernel;
|
||||
private int[] stride;
|
||||
private int[] padding;
|
||||
private ConvolutionMode cm;
|
||||
private int[] dilation;
|
||||
private boolean hasBias;
|
||||
|
||||
protected SameDiffConv(Builder b) {
|
||||
super(b);
|
||||
this.nIn = b.nIn;
|
||||
this.nOut = b.nOut;
|
||||
this.activation = b.activation;
|
||||
this.kernel = b.kernel;
|
||||
this.stride = b.stride;
|
||||
this.padding = b.padding;
|
||||
this.cm = b.cm;
|
||||
this.dilation = b.dilation;
|
||||
this.hasBias = b.hasBias;
|
||||
}
|
||||
private int nIn;
|
||||
private int nOut;
|
||||
@Builder.Default private Activation activation = Activation.TANH;
|
||||
@Builder.Default private int[] kernelSize = new int[]{2, 2};
|
||||
|
||||
@Builder.Default private int[] stride = new int[]{1, 1};
|
||||
@Builder.Default private int[] padding = new int[]{0, 0};
|
||||
@Builder.Default private int[] dilation = new int[]{1, 1};
|
||||
@Builder.Default private ConvolutionMode convolutionMode = ConvolutionMode.Same;
|
||||
@Builder.Default private boolean hasBias = true;
|
||||
|
||||
|
||||
|
||||
private SameDiffConv(){
|
||||
//No arg constructor for Jackson/JSON serialization
|
||||
}
|
||||
|
||||
@Override
|
||||
public InputType getOutputType(int layerIndex, InputType inputType) {
|
||||
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
|
||||
return InputTypeUtil.getOutputTypeCnnLayers(inputType, kernel, stride, padding, new int[]{1, 1},
|
||||
cm, nOut, layerIndex, getName(), SameDiffConv.class);
|
||||
return InputTypeUtil.getOutputTypeCnnLayers(inputType, kernelSize, stride, padding, new int[]{1, 1},
|
||||
convolutionMode, nOut, layerIndex, getName(), SameDiffConv.class);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setNIn(InputType inputType, boolean override) {
|
||||
if (nIn <= 0 || override) {
|
||||
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
|
||||
this.nIn = c.getChannels();
|
||||
this.nIn = (int) c.getChannels();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -102,7 +113,7 @@ public class SameDiffConv extends SameDiffLayer {
|
|||
@Override
|
||||
public void defineParameters(SDLayerParams params) {
|
||||
params.clear();
|
||||
val weightsShape = new long[]{kernel[0], kernel[1], nIn, nOut}; //[kH, kW, iC, oC] in libnd4j
|
||||
val weightsShape = new long[]{kernelSize[0], kernelSize[1], nIn, nOut}; //[kH, kW, iC, oC] in libnd4j
|
||||
params.addWeightParam(ConvolutionParamInitializer.WEIGHT_KEY, weightsShape);
|
||||
if(hasBias) {
|
||||
val biasShape = new long[]{1, nOut};
|
||||
|
@ -113,8 +124,8 @@ public class SameDiffConv extends SameDiffLayer {
|
|||
@Override
|
||||
public void initializeParameters(Map<String, INDArray> params) {
|
||||
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
||||
double fanIn = nIn * kernel[0] * kernel[1];
|
||||
double fanOut = nOut * kernel[0] * kernel[1] / ((double) stride[0] * stride[1]);
|
||||
double fanIn = nIn * kernelSize[0] * kernelSize[1];
|
||||
double fanOut = nOut * kernelSize[0] * kernelSize[1] / ((double) stride[0] * stride[1]);
|
||||
for (Map.Entry<String, INDArray> e : params.entrySet()) {
|
||||
if(paramWeightInit != null && paramWeightInit.containsKey(e.getKey())){
|
||||
paramWeightInit.get(e.getKey()).init(fanIn, fanOut, e.getValue().shape(), 'c', e.getValue());
|
||||
|
@ -135,11 +146,11 @@ public class SameDiffConv extends SameDiffLayer {
|
|||
SDVariable w = paramTable.get(ConvolutionParamInitializer.WEIGHT_KEY);
|
||||
|
||||
Conv2DConfig c = Conv2DConfig.builder()
|
||||
.kH(kernel[0]).kW(kernel[1])
|
||||
.kH(kernelSize[0]).kW(kernelSize[1])
|
||||
.pH(padding[0]).pW(padding[1])
|
||||
.sH(stride[0]).sW(stride[1])
|
||||
.dH(dilation[0]).dW(dilation[1])
|
||||
.isSameMode(this.cm == ConvolutionMode.Same)
|
||||
.isSameMode(this.convolutionMode == ConvolutionMode.Same)
|
||||
.build();
|
||||
|
||||
SDVariable conv = null;
|
||||
|
@ -159,72 +170,10 @@ public class SameDiffConv extends SameDiffLayer {
|
|||
if (activation == null) {
|
||||
activation = SameDiffLayerUtils.fromIActivation(clone.getActivation());
|
||||
}
|
||||
if (cm == null) {
|
||||
cm = clone.getConvolutionMode();
|
||||
if (convolutionMode == null) {
|
||||
convolutionMode = clone.getConvolutionMode();
|
||||
}
|
||||
}
|
||||
|
||||
public static class Builder extends SameDiffLayer.Builder<Builder> {
|
||||
|
||||
private int nIn;
|
||||
private int nOut;
|
||||
private Activation activation = Activation.TANH;
|
||||
private int[] kernel = new int[]{2, 2};
|
||||
|
||||
private int[] stride = new int[]{1, 1};
|
||||
private int[] padding = new int[]{0, 0};
|
||||
private int[] dilation = new int[]{1, 1};
|
||||
private ConvolutionMode cm = ConvolutionMode.Same;
|
||||
private boolean hasBias = true;
|
||||
|
||||
public Builder nIn(int nIn) {
|
||||
this.nIn = nIn;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder nOut(int nOut) {
|
||||
this.nOut = nOut;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder activation(Activation activation) {
|
||||
this.activation = activation;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder kernelSize(int... k) {
|
||||
this.kernel = k;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder stride(int... s) {
|
||||
this.stride = s;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder padding(int... p) {
|
||||
this.padding = p;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder convolutionMode(ConvolutionMode cm) {
|
||||
this.cm = cm;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder dilation(int... d) {
|
||||
this.dilation = d;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder hasBias(boolean hasBias){
|
||||
this.hasBias = hasBias;
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public SameDiffConv build() {
|
||||
return new SameDiffConv(this);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -22,6 +22,8 @@ package org.deeplearning4j.nn.layers.samediff.testlayers;
|
|||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.experimental.SuperBuilder;
|
||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
|
@ -40,30 +42,22 @@ import java.util.*;
|
|||
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = true, exclude = {"paramShapes"})
|
||||
@NoArgsConstructor()
|
||||
@JsonIgnoreProperties("paramShapes")
|
||||
@SuperBuilder
|
||||
public class SameDiffDense extends SameDiffLayer {
|
||||
|
||||
private static final List<String> W_KEYS = Collections.singletonList(DefaultParamInitializer.WEIGHT_KEY);
|
||||
private static final List<String> B_KEYS = Collections.singletonList(DefaultParamInitializer.BIAS_KEY);
|
||||
private static final List<String> PARAM_KEYS = Arrays.asList(DefaultParamInitializer.WEIGHT_KEY, DefaultParamInitializer.BIAS_KEY);
|
||||
|
||||
private Map<String,long[]> paramShapes;
|
||||
private final Map<String, long[]> paramShapes = new HashMap<>();
|
||||
|
||||
private long nIn;
|
||||
private long nOut;
|
||||
private Activation activation;
|
||||
|
||||
protected SameDiffDense(Builder builder) {
|
||||
super(builder);
|
||||
|
||||
nIn = builder.nIn;
|
||||
nOut = builder.nOut;
|
||||
activation = builder.activation;
|
||||
}
|
||||
|
||||
private SameDiffDense(){
|
||||
//No op constructor for Jackson
|
||||
}
|
||||
|
||||
@Override
|
||||
public InputType getOutputType(int layerIndex, InputType inputType) {
|
||||
|
@ -128,31 +122,5 @@ public class SameDiffDense extends SameDiffLayer {
|
|||
return 'f';
|
||||
}
|
||||
|
||||
public static class Builder extends SameDiffLayer.Builder<Builder> {
|
||||
|
||||
private int nIn;
|
||||
private int nOut;
|
||||
|
||||
private Activation activation;
|
||||
|
||||
public Builder nIn(int nIn){
|
||||
this.nIn = nIn;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder nOut(int nOut){
|
||||
this.nOut = nOut;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder activation(Activation activation){
|
||||
this.activation = activation;
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public SameDiffDense build() {
|
||||
return new SameDiffDense(this);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -58,7 +58,7 @@ public class TestVAE extends BaseDL4JTest {
|
|||
|
||||
NeuralNetConfiguration mlc =
|
||||
NeuralNetConfiguration.builder()
|
||||
.layer(0, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder()
|
||||
.layer(0, org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.builder()
|
||||
.nIn(10).nOut(5).encoderLayerSizes(12).decoderLayerSizes(13)
|
||||
.build())
|
||||
.build();
|
||||
|
@ -95,7 +95,7 @@ public class TestVAE extends BaseDL4JTest {
|
|||
for (int i = 0; i < encLayerSizes.length; i++) {
|
||||
|
||||
NeuralNetConfiguration mlc = NeuralNetConfiguration.builder().list().layer(0,
|
||||
new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder().nIn(10)
|
||||
org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.builder().nIn(10)
|
||||
.nOut(5).encoderLayerSizes(encLayerSizes[i]).decoderLayerSizes(13).build())
|
||||
.build();
|
||||
|
||||
|
@ -121,7 +121,7 @@ public class TestVAE extends BaseDL4JTest {
|
|||
int inputSize = 3;
|
||||
|
||||
NeuralNetConfiguration mlc = NeuralNetConfiguration.builder().list()
|
||||
.layer(0, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder()
|
||||
.layer(0, org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.builder()
|
||||
.nIn(inputSize).nOut(4).encoderLayerSizes(5).decoderLayerSizes(6).build())
|
||||
.build();
|
||||
|
||||
|
@ -159,7 +159,7 @@ public class TestVAE extends BaseDL4JTest {
|
|||
public void testParamGradientOrderAndViews() {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
NeuralNetConfiguration mlc = NeuralNetConfiguration.builder().list()
|
||||
.layer(0, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder()
|
||||
.layer(0, org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.builder()
|
||||
.nIn(10).nOut(5).encoderLayerSizes(12, 13).decoderLayerSizes(14, 15).build())
|
||||
.build();
|
||||
|
||||
|
@ -217,7 +217,7 @@ public class TestVAE extends BaseDL4JTest {
|
|||
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
NeuralNetConfiguration mlc = NeuralNetConfiguration.builder().seed(12345).list()
|
||||
.layer(0, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder()
|
||||
.layer(0, org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.builder()
|
||||
.nIn(10).nOut(5).encoderLayerSizes(12, 13).decoderLayerSizes(14, 15).build())
|
||||
.layer(1, OutputLayer.builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(5).nOut(6)
|
||||
.activation(new ActivationTanH()).build())
|
||||
|
@ -269,22 +269,22 @@ public class TestVAE extends BaseDL4JTest {
|
|||
public void testJsonYaml() {
|
||||
|
||||
NeuralNetConfiguration config = NeuralNetConfiguration.builder().seed(12345).list()
|
||||
.layer(0, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder()
|
||||
.layer(0, org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.builder()
|
||||
.reconstructionDistribution(new GaussianReconstructionDistribution(Activation.IDENTITY))
|
||||
.nIn(3).nOut(4).encoderLayerSizes(5).decoderLayerSizes(6).build())
|
||||
.layer(1, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder()
|
||||
.layer(1, org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.builder()
|
||||
.reconstructionDistribution(new GaussianReconstructionDistribution(Activation.TANH))
|
||||
.nIn(7).nOut(8).encoderLayerSizes(9).decoderLayerSizes(10).build())
|
||||
.layer(2, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder()
|
||||
.layer(2, org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.builder()
|
||||
.reconstructionDistribution(new BernoulliReconstructionDistribution()).nIn(11)
|
||||
.nOut(12).encoderLayerSizes(13).decoderLayerSizes(14).build())
|
||||
.layer(3, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder()
|
||||
.layer(3, org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.builder()
|
||||
.reconstructionDistribution(new ExponentialReconstructionDistribution(Activation.TANH))
|
||||
.nIn(11).nOut(12).encoderLayerSizes(13).decoderLayerSizes(14).build())
|
||||
.layer(4, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder()
|
||||
.layer(4, org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.builder()
|
||||
.lossFunction(new ActivationTanH(), LossFunctions.LossFunction.MSE).nIn(11)
|
||||
.nOut(12).encoderLayerSizes(13).decoderLayerSizes(14).build())
|
||||
.layer(5, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder()
|
||||
.layer(5, org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.builder()
|
||||
.reconstructionDistribution(new CompositeReconstructionDistribution.Builder()
|
||||
.addDistribution(5, new GaussianReconstructionDistribution())
|
||||
.addDistribution(5,
|
||||
|
|
|
@ -59,7 +59,7 @@ public class TestMemoryReports extends BaseDL4JTest {
|
|||
l.add(new Pair<>(DropoutLayer.builder().nIn(20).nOut(20).build(), InputType.feedForward(20)));
|
||||
l.add(new Pair<>(EmbeddingLayer.builder().nIn(1).nOut(20).build(), InputType.feedForward(20)));
|
||||
l.add(new Pair<>(OutputLayer.builder().nIn(20).nOut(20).build(), InputType.feedForward(20)));
|
||||
l.add(new Pair<>(LossLayer.builder().lossFunction().build(), InputType.feedForward(20)));
|
||||
l.add(new Pair<>(LossLayer.builder().build(), InputType.feedForward(20)));
|
||||
|
||||
//RNN layers:
|
||||
l.add(new Pair<>(GravesLSTM.builder().nIn(20).nOut(20).build(), InputType.recurrent(20, 30)));
|
||||
|
|
|
@ -469,7 +469,7 @@ public class WorkspaceTests extends BaseDL4JTest {
|
|||
.addLayer("a", GravesLSTM.builder().nOut(300).activation(Activation.HARDTANH).build(), "embeddings")
|
||||
.addVertex("b", new LastTimeStepVertex("in"), "a")
|
||||
.addLayer("c", DenseLayer.builder().nOut(300).activation(Activation.HARDTANH).build(), "b")
|
||||
.addLayer("output", LossLayer.builder().lossFunction().lossFunction(LossFunctions.LossFunction.COSINE_PROXIMITY).build(), "c")
|
||||
.addLayer("output", LossLayer.builder().lossFunction(LossFunctions.LossFunction.COSINE_PROXIMITY.getILossFunction()).build(), "c")
|
||||
.setOutputs("output")
|
||||
.build();
|
||||
|
||||
|
|
|
@ -1455,10 +1455,10 @@ public class MultiLayerTest extends BaseDL4JTest {
|
|||
|
||||
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
||||
.l2(0.01)
|
||||
.list()
|
||||
|
||||
.layer(ConvolutionLayer.builder().nIn(depth).nOut(depth).kernelSize(1, 1).build())
|
||||
.layer(new Yolo2OutputLayer.Builder()
|
||||
.boundingBoxPriors(bbPrior)
|
||||
.layer(Yolo2OutputLayer.builder()
|
||||
.boundingBoxes(bbPrior)
|
||||
.build())
|
||||
.build();
|
||||
|
||||
|
|
|
@ -500,10 +500,10 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest {
|
|||
.addInputs(inputName)
|
||||
.setOutputs(outputName)
|
||||
.setInputTypes(InputType.inferInputTypes(input))
|
||||
.addLayer(firstConv, new Convolution2D.Builder(3, 3)
|
||||
.addLayer(firstConv, Convolution2D.builder(3, 3)
|
||||
.nOut(10)
|
||||
.build(), inputName)
|
||||
.addLayer(secondConv, new Convolution2D.Builder(1, 1)
|
||||
.addLayer(secondConv, Convolution2D.builder(1, 1)
|
||||
.nOut(3)
|
||||
.build(), firstConv)
|
||||
.addLayer(outputName, OutputLayer.builder()
|
||||
|
@ -546,11 +546,11 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest {
|
|||
.addInputs(inputName)
|
||||
.setOutputs(outputName)
|
||||
.setInputTypes(InputType.inferInputTypes(input))
|
||||
.addLayer(changeNoutName, new Convolution2D.Builder(1, 1)
|
||||
.addLayer(changeNoutName, Convolution2D.builder(1, 1)
|
||||
.nOut(10)
|
||||
.build(), inputName)
|
||||
.addLayer(poolName, SubsamplingLayer.builder(1,1).build(), changeNoutName)
|
||||
.addLayer(afterPoolName, new Convolution2D.Builder(1, 1)
|
||||
.addLayer(afterPoolName, Convolution2D.builder(1, 1)
|
||||
.nOut(7)
|
||||
.build(), poolName)
|
||||
.addLayer(outputName, OutputLayer.builder()
|
||||
|
@ -583,7 +583,7 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest {
|
|||
.graphBuilder()
|
||||
.addInputs("in")
|
||||
.layer("l0", LSTM.builder().nIn(5).nOut(5).build(), "in")
|
||||
.layer("l1", new RecurrentAttentionLayer.Builder().nHeads(1).headSize(5).nIn(5).nOut(5).build(), "l0")
|
||||
.layer("l1", RecurrentAttentionLayer.builder().nHeads(1).headSize(5).nIn(5).nOut(5).build(), "l0")
|
||||
.layer("out", RnnOutputLayer.builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build(), "l1")
|
||||
.setOutputs("out")
|
||||
.build();
|
||||
|
|
|
@ -28,7 +28,6 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
|||
import org.deeplearning4j.nn.conf.graph.MergeVertex;
|
||||
import org.deeplearning4j.nn.conf.graph.SubsetVertex;
|
||||
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.DenseLayer.Builder;
|
||||
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
|
@ -214,9 +213,9 @@ public class TransferLearningHelperTest extends BaseDL4JTest {
|
|||
|
||||
MultiLayerNetwork modelToFineTune = new MultiLayerNetwork(
|
||||
(NeuralNetConfiguration) overallConf.clone()
|
||||
.layer(0, new Builder().nIn(4).nOut(3).build())
|
||||
.layer(1, new Builder().nIn(3).nOut(2).build())
|
||||
.layer(2, new Builder().nIn(2).nOut(3).build())
|
||||
.layer(0, DenseLayer.builder().nIn(4).nOut(3).build())
|
||||
.layer(1, DenseLayer.builder().nIn(3).nOut(2).build())
|
||||
.layer(2, DenseLayer.builder().nIn(2).nOut(3).build())
|
||||
.layer(3, OutputLayer.builder(
|
||||
LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3)
|
||||
.build())
|
||||
|
@ -233,7 +232,7 @@ public class TransferLearningHelperTest extends BaseDL4JTest {
|
|||
Nd4j.hstack(modelToFineTune.getLayer(2).getParams(), modelToFineTune.getLayer(3).getParams());
|
||||
MultiLayerNetwork notFrozen = new MultiLayerNetwork(
|
||||
(NeuralNetConfiguration) overallConf.clone().list()
|
||||
.layer(0, new Builder().nIn(2).nOut(3).build())
|
||||
.layer(0, DenseLayer.builder().nIn(2).nOut(3).build())
|
||||
.layer(1, OutputLayer.builder(
|
||||
LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3)
|
||||
.build())
|
||||
|
|
|
@ -32,7 +32,7 @@ import org.deeplearning4j.nn.conf.distribution.ConstantDistribution;
|
|||
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.layers.*;
|
||||
import org.deeplearning4j.nn.conf.layers.DenseLayer.Builder;
|
||||
|
||||
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.preprocessor.RnnToCnnPreProcessor;
|
||||
|
@ -74,7 +74,7 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
|
|||
|
||||
MultiLayerNetwork modelToFineTune = new MultiLayerNetwork(
|
||||
(NeuralNetConfiguration) confToChange.list()
|
||||
.layer(0, new Builder().nIn(4).nOut(3).build())
|
||||
.layer(0, DenseLayer.builder().nIn(4).nOut(3).build())
|
||||
.layer(1, OutputLayer.builder(
|
||||
LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3)
|
||||
.build())
|
||||
|
@ -101,7 +101,7 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
|
|||
.updater(new RmsProp(0.5)).l2(0.4);
|
||||
|
||||
MultiLayerNetwork expectedModel = new MultiLayerNetwork((NeuralNetConfiguration) confSet.list()
|
||||
.layer(0, new Builder().nIn(4).nOut(3).build())
|
||||
.layer(0, DenseLayer.builder().nIn(4).nOut(3).build())
|
||||
.layer(1, OutputLayer.builder(
|
||||
LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3)
|
||||
.build())
|
||||
|
@ -651,8 +651,8 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
|
|||
.weightInit(new ConstantDistribution(666))
|
||||
.list()
|
||||
.inputType(InputType.inferInputTypes(input)[0])
|
||||
.layer(new Convolution2D.Builder(3, 3).nOut(10).build())
|
||||
.layer(new Convolution2D.Builder(1, 1).nOut(3).build())
|
||||
.layer(Convolution2D.builder(3, 3).nOut(10).build())
|
||||
.layer(Convolution2D.builder(1, 1).nOut(3).build())
|
||||
.layer(OutputLayer.builder().nOut(2).lossFunction(LossFunctions.LossFunction.MSE)
|
||||
.build()).build());
|
||||
net.init();
|
||||
|
@ -682,9 +682,9 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
|
|||
MultiLayerNetwork net = new MultiLayerNetwork( NeuralNetConfiguration.builder()
|
||||
.list()
|
||||
.inputType(InputType.inferInputTypes(input)[0])
|
||||
.layer(new Convolution2D.Builder(1, 1).nOut(10).build())
|
||||
.layer(Convolution2D.builder(1, 1).nOut(10).build())
|
||||
.layer(SubsamplingLayer.builder(1,1).build())
|
||||
.layer(new Convolution2D.Builder(1, 1).nOut(7).build())
|
||||
.layer(Convolution2D.builder(1, 1).nOut(7).build())
|
||||
.layer(OutputLayer.builder().activation(Activation.SOFTMAX).nOut(2).build())
|
||||
.build());
|
||||
net.init();
|
||||
|
@ -712,7 +712,7 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
|
|||
.weightInit(WeightInit.XAVIER)
|
||||
.list()
|
||||
.layer(LSTM.builder().nOut(8).build())
|
||||
.layer( new SelfAttentionLayer.Builder().nOut(4).nHeads(2).projectInput(true).build())
|
||||
.layer( SelfAttentionLayer.builder().nOut(4).nHeads(2).projectInput(true).build())
|
||||
.layer(GlobalPoolingLayer.builder().poolingType(PoolingType.MAX).build())
|
||||
.layer(OutputLayer.builder().nOut(2).activation(Activation.SOFTMAX)
|
||||
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
||||
|
|
|
@ -52,7 +52,7 @@ public class WeightInitIdentityTest extends BaseDL4JTest {
|
|||
.graphBuilder()
|
||||
.addInputs(inputName)
|
||||
.setOutputs(output)
|
||||
.layer(conv, new Convolution1DLayer.Builder(7)
|
||||
.layer(conv, Convolution1DLayer.builder(7)
|
||||
.convolutionMode(ConvolutionMode.Same)
|
||||
.nOut(input.size(1))
|
||||
.weightInit(new WeightInitIdentity())
|
||||
|
@ -115,7 +115,7 @@ public class WeightInitIdentityTest extends BaseDL4JTest {
|
|||
.weightInit(new WeightInitIdentity())
|
||||
.activation(new ActivationIdentity())
|
||||
.build(), inputName)
|
||||
.layer(output, new Cnn3DLossLayer.Builder(Convolution3D.DataFormat.NCDHW).activation(new ActivationIdentity()).build(), conv)
|
||||
.layer(output, Cnn3DLossLayer.builder().dataFormat(Convolution3D.DataFormat.NCDHW).activation(new ActivationIdentity()).build(), conv)
|
||||
.build());
|
||||
graph.init();
|
||||
|
||||
|
|
|
@ -249,7 +249,7 @@ public class ModelGuesserTest extends BaseDL4JTest {
|
|||
int nOut = 6;
|
||||
|
||||
NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345).l1(0.01).l2(0.01)
|
||||
.updater(new Sgd(0.1)).activation(Activation.TANH).weightInit(WeightInit.XAVIER).list()
|
||||
.updater(new Sgd(0.1)).activation(Activation.TANH).weightInit(WeightInit.XAVIER)
|
||||
.layer(0, DenseLayer.builder().nIn(nIn).nOut(20).build())
|
||||
.layer(1, DenseLayer.builder().nIn(20).nOut(30).build()).layer(2, OutputLayer.builder()
|
||||
.lossFunction(LossFunctions.LossFunction.MSE).nIn(30).nOut(nOut).build())
|
||||
|
|
|
@ -56,10 +56,10 @@ public class KerasSpaceToDepth extends KerasLayer {
|
|||
|
||||
// TODO: we hard-code block size here to import YOLO9000. This size is not available as property
|
||||
// in the hdf5 file outside of the serialized lambda function (that we can't really well deserialize).
|
||||
SpaceToDepthLayer.Builder builder = new SpaceToDepthLayer.Builder()
|
||||
.blocks(2)
|
||||
var builder = SpaceToDepthLayer.builder()
|
||||
.blockSize(2)
|
||||
//the default data format is tensorflow/NWHC for keras import
|
||||
.dataFormat(SpaceToDepthLayer.DataFormat.NHWC)
|
||||
.dataFormat(SpaceToDepthLayer.DataFormat.NHWC.toFormat())
|
||||
.name(name);
|
||||
|
||||
this.layer = builder.build();
|
||||
|
|
|
@ -63,7 +63,7 @@ public class KerasUpsampling3D extends KerasLayer {
|
|||
int[] size = KerasConvolutionUtils.getUpsamplingSizeFromConfig(layerConfig, 3, conf);
|
||||
// TODO: make sure to allow different sizes.
|
||||
|
||||
Upsampling3D.Builder builder = new Upsampling3D.Builder()
|
||||
var builder = Upsampling3D.builder()
|
||||
.name(this.name)
|
||||
.dropOut(this.dropout)
|
||||
.size(size[0]);
|
||||
|
|
|
@ -59,7 +59,7 @@ public class KerasLRN extends KerasLayer {
|
|||
super(layerConfig, enforceTrainingConfig);
|
||||
Map<String, Object> lrnParams = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
|
||||
|
||||
LocalResponseNormalization.Builder builder = LocalResponseNormalization.builder().name(this.name)
|
||||
var builder = LocalResponseNormalization.builder().name(this.name)
|
||||
.dropOut(this.dropout).alpha((double) lrnParams.get("alpha"))
|
||||
.beta((double) lrnParams.get("beta")).k((int) lrnParams.get("k")).n((int) lrnParams.get("n"));
|
||||
this.layer = builder.build();
|
||||
|
|
|
@ -33,6 +33,7 @@ import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasConvolu
|
|||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
|
||||
import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
|
||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
import java.util.HashMap;
|
||||
|
@ -98,7 +99,7 @@ public class KerasLocallyConnected1D extends KerasConvolution {
|
|||
LocallyConnected1D.LocallyConnected1DBuilder builder = LocallyConnected1D.builder().name(this.name)
|
||||
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
||||
.activation(getActivationFromConfig(layerConfig, conf))
|
||||
.weightInit(conf.getKERAS_PARAM_NAME_W(), init)
|
||||
.weightInit(WeightInit.valueOf(conf.getKERAS_PARAM_NAME_W()))
|
||||
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
|
||||
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
|
||||
.kernelSize(getKernelSizeFromConfig(layerConfig, 1, conf, kerasMajorVersion)[0])
|
||||
|
|
|
@ -99,7 +99,7 @@ public class KerasLocallyConnected2D extends KerasConvolution {
|
|||
LocallyConnected2D.LocallyConnected2DBuilder builder = LocallyConnected2D.builder().name(this.name)
|
||||
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
||||
.activation(getActivationFromConfig(layerConfig, conf))
|
||||
.weightInit(conf.getKERAS_PARAM_NAME_W(), init)
|
||||
.weightInit(init.enumValue())
|
||||
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
|
||||
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
|
||||
.kernelSize(getKernelSizeFromConfig(layerConfig, 2, conf, kerasMajorVersion))
|
||||
|
|
|
@ -130,12 +130,14 @@ public class KerasBatchNormalization extends KerasLayer {
|
|||
BatchNormalization.BatchNormalizationBuilder builder =BatchNormalization.builder()
|
||||
.name(this.name)
|
||||
.dropOut(this.dropout)
|
||||
.minibatch(true)
|
||||
|
||||
.isMinibatch(true)
|
||||
.lockGammaBeta(false)
|
||||
.useLogStd(false)
|
||||
.decay(getMomentumFromConfig(layerConfig))
|
||||
.eps(getEpsFromConfig(layerConfig));
|
||||
if (betaConstraint != null)
|
||||
|
||||
builder.constrainBeta(betaConstraint);
|
||||
if (gammaConstraint != null)
|
||||
builder.constrainGamma(gammaConstraint);
|
||||
|
|
|
@ -58,11 +58,11 @@ public class KerasModelImportTest extends BaseDL4JTest {
|
|||
MultiLayerNetwork model = loadModel("modelimport/keras/weights/conv2dnchw/simpleconv2d.hdf5");
|
||||
List<LayerConfiguration> layerConfigs = model.getNetConfiguration().getFlattenedLayerConfigurations();
|
||||
ConvolutionLayer convolutionLayer = (ConvolutionLayer) layerConfigs.get(0);
|
||||
assertEquals(CNN2DFormat.NCHW,convolutionLayer.getDataFormat());
|
||||
assertEquals(CNN2DFormat.NCHW,convolutionLayer.getConvFormat());
|
||||
SubsamplingLayer subsamplingLayer = (SubsamplingLayer) layerConfigs.get(1);
|
||||
assertEquals(CNN2DFormat.NHWC,subsamplingLayer.getDataFormat());
|
||||
ConvolutionLayer convolutionLayer1 = (ConvolutionLayer) layerConfigs.get(2);
|
||||
assertEquals(CNN2DFormat.NHWC,convolutionLayer1.getDataFormat());
|
||||
assertEquals(CNN2DFormat.NHWC,convolutionLayer1.getConvFormat());
|
||||
|
||||
model.output(Nd4j.zeros(1,1,28,28));
|
||||
assertNotNull(model);
|
||||
|
|
|
@ -60,8 +60,8 @@ public class KerasYolo9000PredictTest extends BaseDL4JTest {
|
|||
|
||||
ComputationGraph model = new TransferLearning.GraphBuilder(graph)
|
||||
.addLayer("outputs",
|
||||
new org.deeplearning4j.nn.conf.layers.objdetect.Yolo2OutputLayer.Builder()
|
||||
.boundingBoxPriors(priors)
|
||||
org.deeplearning4j.nn.conf.layers.objdetect.Yolo2OutputLayer.builder()
|
||||
.boundingBoxes(priors)
|
||||
.build(),
|
||||
"conv2d_23")
|
||||
.setOutputs("outputs")
|
||||
|
|
|
@ -126,7 +126,7 @@ public class KerasLocallyConnected1DTest extends BaseDL4JTest {
|
|||
assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0);
|
||||
assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0);
|
||||
assertEquals(new Dropout(DROPOUT_DL4J), layer.getDropOut());
|
||||
assertEquals(KERNEL_SIZE, layer.getKernel());
|
||||
assertEquals(KERNEL_SIZE, layer.getKernelSize());
|
||||
assertEquals(STRIDE, layer.getStride());
|
||||
assertEquals(N_OUT, layer.getNOut());
|
||||
assertEquals(ConvolutionMode.Truncate, layer.getConvolutionMode());
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
|
||||
package net.brutex.ai.dnn.api;
|
||||
|
||||
|
||||
public interface ILayerConfiguration {
|
||||
|
||||
|
||||
|
|
|
@ -561,13 +561,12 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
|
|||
|
||||
List<Object> innerConfigurations$value = new ArrayList<>(); // initialize with an empty list
|
||||
|
||||
public B activation(IActivation activation) {
|
||||
public B activation(Activation activation) {
|
||||
this.activation = activation;
|
||||
return self();
|
||||
}
|
||||
|
||||
public B activation(Activation activation) {
|
||||
this.activation = activation.getActivationFunction();
|
||||
public B activation(IActivation activation) {
|
||||
this.activation = activation;
|
||||
return self();
|
||||
}
|
||||
/**
|
||||
|
|
|
@ -157,9 +157,9 @@ public class AttentionVertex extends SameDiffVertex {
|
|||
val Wv = paramTable.get(WEIGHT_KEY_VALUE_PROJECTION);
|
||||
val Wo = paramTable.get(WEIGHT_KEY_OUT_PROJECTION);
|
||||
|
||||
attention = sameDiff.nn.multiHeadDotProductAttention(getLayerName(), queries, keys, values, Wq, Wk, Wv, Wo, mask, true);
|
||||
attention = sameDiff.nn.multiHeadDotProductAttention(getName(), queries, keys, values, Wq, Wk, Wv, Wo, mask, true);
|
||||
}else{
|
||||
attention = sameDiff.nn.dotProductAttention(getLayerName(), queries, keys, values, mask, true);
|
||||
attention = sameDiff.nn.dotProductAttention(getName(), queries, keys, values, mask, true);
|
||||
}
|
||||
|
||||
if(maskVars != null){
|
||||
|
|
|
@ -53,11 +53,11 @@ public class ActivationLayer extends NoParamLayer {
|
|||
public static ActivationLayerBuilder<?, ?> builder(Activation activation) {
|
||||
return innerBuilder().activation(activation);
|
||||
}
|
||||
|
||||
public static ActivationLayerBuilder<?, ?> builder(IActivation activation) {
|
||||
return innerBuilder().activation(activation);
|
||||
}
|
||||
|
||||
|
||||
public static ActivationLayerBuilder<?, ?> builder() {
|
||||
return innerBuilder();
|
||||
}
|
||||
|
|
|
@ -25,6 +25,7 @@ import java.util.List;
|
|||
import java.util.Map;
|
||||
import lombok.*;
|
||||
import lombok.experimental.SuperBuilder;
|
||||
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||
import net.brutex.ai.dnn.api.LayerType;
|
||||
import org.deeplearning4j.nn.api.Layer;
|
||||
import org.deeplearning4j.nn.api.ParamInitializer;
|
||||
|
@ -80,19 +81,38 @@ public class BatchNormalization extends FeedForwardLayer {
|
|||
@lombok.Builder.Default protected boolean isMinibatch = true;
|
||||
|
||||
/**
|
||||
* Used only when 'true' is passed to {@link #lockGammaBeta(boolean)}. Value is not used otherwise.<br> Default:
|
||||
* Used only when 'true' is passed to {@link BatchNormalizationBuilder#lockGammaBeta(boolean)}. Value is not used otherwise.<br> Default:
|
||||
* 1.0
|
||||
*
|
||||
* @param gamma Gamma parameter for all activations, used only with locked gamma/beta configuration mode
|
||||
*/
|
||||
@lombok.Builder.Default protected double gamma = 1.0;
|
||||
/**
|
||||
* Used only when 'true' is passed to {@link #lockGammaBeta(boolean)}. Value is not used otherwise.<br> Default:
|
||||
* Used only when 'true' is passed to {@link BatchNormalizationBuilder#lockGammaBeta(boolean)}. Value is not used otherwise.<br> Default:
|
||||
* 0.0
|
||||
*
|
||||
* @param beta Beta parameter for all activations, used only with locked gamma/beta configuration mode
|
||||
*/
|
||||
@lombok.Builder.Default protected double beta = 0.0;
|
||||
/**
|
||||
* Set constraints to be applied to the beta parameter of this batch normalisation layer. Default: no
|
||||
* constraints.<br> Constraints can be used to enforce certain conditions (non-negativity of parameters,
|
||||
* max-norm regularization, etc). These constraints are applied at each iteration, after the parameters have
|
||||
* been updated.
|
||||
*
|
||||
*/
|
||||
protected List<LayerConstraint> betaConstraints;
|
||||
|
||||
/**
|
||||
* Set constraints to be applied to the gamma parameter of this batch normalisation layer. Default: no
|
||||
* constraints.<br> Constraints can be used to enforce certain conditions (non-negativity of parameters,
|
||||
* max-norm regularization, etc). These constraints are applied at each iteration, after the parameters have
|
||||
* been updated.
|
||||
*
|
||||
*/
|
||||
protected List<LayerConstraint> gammaConstraints;
|
||||
|
||||
|
||||
/**
|
||||
* When using CuDNN or MKLDNN and an error is encountered, should fallback to the non-helper implementation be allowed?
|
||||
* If set to false, an exception in the helper will be propagated back to the user. If true, the built-in
|
||||
|
@ -298,6 +318,15 @@ public class BatchNormalization extends FeedForwardLayer {
|
|||
this.cudnnAllowFallback$set = true;
|
||||
return self();
|
||||
}
|
||||
|
||||
public B constrainBeta(LayerConstraint ... constraints) {
|
||||
this.betaConstraints = List.of(constraints);
|
||||
return self();
|
||||
}
|
||||
public B constrainGamma(LayerConstraint ... constraints) {
|
||||
this.gammaConstraints = List.of(constraints);
|
||||
return self();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -51,6 +51,16 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
|||
@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder")
|
||||
public class Convolution1DLayer extends ConvolutionLayer {
|
||||
@Builder.Default private RNNFormat rnnDataFormat = RNNFormat.NCW;
|
||||
/**
|
||||
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
|
||||
* See {@link CNN2DFormat} for more details.<br>
|
||||
* Default: NCHW
|
||||
*
|
||||
* @param format Format for activations (in and out)
|
||||
*/
|
||||
@Builder.Default
|
||||
protected CNN2DFormat dataFormat =
|
||||
CNN2DFormat.NCHW; // default value for legacy serialization reasons
|
||||
/**
|
||||
* Size of the convolution
|
||||
*
|
||||
|
|
|
@ -26,10 +26,7 @@ import lombok.*;
|
|||
import lombok.experimental.SuperBuilder;
|
||||
import org.deeplearning4j.nn.api.Layer;
|
||||
import org.deeplearning4j.nn.api.ParamInitializer;
|
||||
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.conf.*;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.layers.convolution.Convolution3DLayer;
|
||||
import org.deeplearning4j.nn.params.Convolution3DParamInitializer;
|
||||
|
|
|
@ -65,6 +65,18 @@ public class ConvolutionLayer extends FeedForwardLayer {
|
|||
* details Default is {@link ConvolutionMode}.Truncate.
|
||||
*/
|
||||
@Builder.Default protected ConvolutionMode convolutionMode = ConvolutionMode.Truncate;
|
||||
|
||||
/**
|
||||
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
|
||||
* See {@link CNN2DFormat} for more details.<br>
|
||||
* Default: NCHW
|
||||
*
|
||||
* @param format Format for activations (in and out)
|
||||
*/
|
||||
@Builder.Default
|
||||
protected CNN2DFormat convFormat =
|
||||
CNN2DFormat.NCHW; // default value for legacy serialization reasons
|
||||
|
||||
/**
|
||||
* Kernel dilation. Default: {1, 1}, which is standard convolutions. Used for implementing dilated
|
||||
* convolutions, which are also known as atrous convolutions.
|
||||
|
@ -86,16 +98,7 @@ public class ConvolutionLayer extends FeedForwardLayer {
|
|||
* false, the built-in (non-CuDNN) implementation for ConvolutionLayer will be used
|
||||
*/
|
||||
@Builder.Default protected boolean cudnnAllowFallback = true;
|
||||
/**
|
||||
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
|
||||
* See {@link CNN2DFormat} for more details.<br>
|
||||
* Default: NCHW
|
||||
*
|
||||
* @param format Format for activations (in and out)
|
||||
*/
|
||||
@Builder.Default
|
||||
protected CNN2DFormat dataFormat =
|
||||
CNN2DFormat.NCHW; // default value for legacy serialization reasons
|
||||
|
||||
/** Defaults to "PREFER_FASTEST", but "NO_WORKSPACE" uses less memory. */
|
||||
@Builder.Default protected AlgoMode cudnnAlgoMode = AlgoMode.PREFER_FASTEST;
|
||||
|
||||
|
@ -179,7 +182,7 @@ public class ConvolutionLayer extends FeedForwardLayer {
|
|||
nOut,
|
||||
layerIndex,
|
||||
getName(),
|
||||
dataFormat,
|
||||
convFormat,
|
||||
ConvolutionLayer.class);
|
||||
}
|
||||
|
||||
|
@ -196,11 +199,11 @@ public class ConvolutionLayer extends FeedForwardLayer {
|
|||
if (!defaultValueOverriden || nIn <= 0 || override) {
|
||||
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
|
||||
this.nIn = c.getChannels();
|
||||
this.dataFormat = ((InputType.InputTypeConvolutional) inputType).getFormat();
|
||||
this.convFormat = ((InputType.InputTypeConvolutional) inputType).getFormat();
|
||||
}
|
||||
|
||||
if (dataFormat == null || override)
|
||||
this.dataFormat = ((InputType.InputTypeConvolutional) inputType).getFormat();
|
||||
if (convFormat == null || override)
|
||||
this.convFormat = ((InputType.InputTypeConvolutional) inputType).getFormat();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -53,6 +53,16 @@ public class DepthwiseConvolution2D extends ConvolutionLayer {
|
|||
*/
|
||||
@Builder.Default
|
||||
protected int depthMultiplier = 1;
|
||||
/**
|
||||
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
|
||||
* See {@link CNN2DFormat} for more details.<br>
|
||||
* Default: NCHW
|
||||
*
|
||||
* @param format Format for activations (in and out)
|
||||
*/
|
||||
@Builder.Default
|
||||
protected CNN2DFormat dataFormat =
|
||||
CNN2DFormat.NCHW; // default value for legacy serialization reasons
|
||||
/**
|
||||
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
|
||||
* See {@link CNN2DFormat} for more details.<br>
|
||||
|
|
|
@ -36,6 +36,7 @@ import org.deeplearning4j.nn.weights.embeddings.ArrayEmbeddingInitializer;
|
|||
import org.deeplearning4j.nn.weights.embeddings.EmbeddingInitializer;
|
||||
import org.deeplearning4j.nn.weights.embeddings.WeightInitEmbedding;
|
||||
import org.deeplearning4j.optimize.api.TrainingListener;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.activations.impl.ActivationIdentity;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -63,7 +64,7 @@ public class EmbeddingLayer extends FeedForwardLayer {
|
|||
*/
|
||||
public static EmbeddingLayerBuilder<?, ?> builder() {
|
||||
return innerBuilder()
|
||||
.activation(new ActivationIdentity());
|
||||
.activation(Activation.IDENTITY);
|
||||
}
|
||||
|
||||
public static abstract class EmbeddingLayerBuilder<C extends EmbeddingLayer, B extends EmbeddingLayerBuilder<C,B>>
|
||||
|
|
|
@ -88,7 +88,7 @@ public abstract class LayerConfiguration
|
|||
* Activation#getActivationFunction()} but not vice versa. The default is Identity Activation.
|
||||
*/
|
||||
@Builder.Default
|
||||
@Getter @Setter private IActivation activation = new ActivationIdentity();
|
||||
@Getter @Setter private IActivation activation = Activation.IDENTITY;
|
||||
|
||||
/**
|
||||
* Get the activation interface (function) from the activation. The activation must have been set
|
||||
|
@ -333,9 +333,9 @@ public abstract class LayerConfiguration
|
|||
runInheritance(getNetConfiguration());
|
||||
}
|
||||
|
||||
public static abstract class LayerConfigurationBuilder<C extends LayerConfiguration, B extends LayerConfigurationBuilder<C, B>> {
|
||||
public static abstract class LayerConfigurationBuilder<C extends LayerConfiguration, B extends LayerConfigurationBuilder<C, B>> {
|
||||
public B activation(Activation activation) {
|
||||
this.activation$value = activation.getActivationFunction();
|
||||
this.activation$value = activation;
|
||||
this.activation$set = true;
|
||||
return self();
|
||||
}
|
||||
|
@ -344,6 +344,7 @@ public abstract class LayerConfiguration
|
|||
this.activation$set = true;
|
||||
return self();
|
||||
}
|
||||
|
||||
public B dropOut(double d) {
|
||||
this.dropOut = new Dropout(d);
|
||||
return self();
|
||||
|
@ -352,6 +353,14 @@ public abstract class LayerConfiguration
|
|||
this.dropOut = d;
|
||||
return self();
|
||||
}
|
||||
|
||||
public B constrainBias(LayerConstraint constraint) {
|
||||
return this.biasConstraints(List.of(constraint));
|
||||
}
|
||||
|
||||
public B constrainWeights(LayerConstraint constraint) {
|
||||
return this.weightConstraints(List.of(constraint));
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -20,7 +20,9 @@
|
|||
|
||||
package org.deeplearning4j.nn.conf.layers;
|
||||
|
||||
import java.util.Map;
|
||||
import lombok.*;
|
||||
import lombok.experimental.SuperBuilder;
|
||||
import org.deeplearning4j.nn.api.MaskState;
|
||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||
|
@ -32,223 +34,161 @@ import org.nd4j.autodiff.samediff.SDIndex;
|
|||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.common.base.Preconditions;
|
||||
import org.nd4j.common.primitives.Pair;
|
||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.common.primitives.Pair;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@SuperBuilder(buildMethodName = "initBuild")
|
||||
public class LearnedSelfAttentionLayer extends SameDiffLayer {
|
||||
private long nIn;
|
||||
private long nOut;
|
||||
private int nHeads;
|
||||
private long headSize;
|
||||
private boolean projectInput;
|
||||
private int nQueries;
|
||||
private static final String WEIGHT_KEY_QUERY_PROJECTION = "Wq";
|
||||
private static final String WEIGHT_KEY_KEY_PROJECTION = "Wk";
|
||||
private static final String WEIGHT_KEY_VALUE_PROJECTION = "Wv";
|
||||
private static final String WEIGHT_KEY_OUT_PROJECTION = "Wo";
|
||||
private static final String WEIGHT_QUERIES = "Q";
|
||||
/** Number of inputs to the layer (input size) */
|
||||
private int nIn;
|
||||
/** Number of outputs (output size) */
|
||||
private int nOut;
|
||||
/** Number of Attention Heads */
|
||||
private int nHeads;
|
||||
/** Size of attention heads */
|
||||
private int headSize;
|
||||
/** Project input before applying attention or not. */
|
||||
private boolean projectInput;
|
||||
/** Number of queries to learn */
|
||||
private int nQueries;
|
||||
|
||||
private static final String WEIGHT_KEY_QUERY_PROJECTION = "Wq";
|
||||
private static final String WEIGHT_KEY_KEY_PROJECTION = "Wk";
|
||||
private static final String WEIGHT_KEY_VALUE_PROJECTION = "Wv";
|
||||
private static final String WEIGHT_KEY_OUT_PROJECTION = "Wo";
|
||||
private static final String WEIGHT_QUERIES = "Q";
|
||||
private LearnedSelfAttentionLayer() {
|
||||
/*No arg constructor for serialization*/
|
||||
}
|
||||
|
||||
private LearnedSelfAttentionLayer(){/*No arg constructor for serialization*/}
|
||||
@Override
|
||||
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
|
||||
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, RNNFormat.NCW, getName());
|
||||
}
|
||||
|
||||
protected LearnedSelfAttentionLayer(Builder builder){
|
||||
super(builder);
|
||||
nIn = builder.nIn;
|
||||
nOut = builder.nOut;
|
||||
nHeads = builder.nHeads;
|
||||
headSize = builder.headSize == 0 ? nOut / nHeads : builder.headSize;
|
||||
projectInput = builder.projectInput;
|
||||
nQueries = builder.nQueries;
|
||||
@Override
|
||||
public void setNIn(InputType inputType, boolean override) {
|
||||
if (inputType == null || inputType.getType() != InputType.Type.RNN) {
|
||||
throw new IllegalStateException(
|
||||
"Invalid input for Learned Self Attention layer (layer name = \""
|
||||
+ getName()
|
||||
+ "\"): expect RNN input type with size > 0. Got: "
|
||||
+ inputType);
|
||||
}
|
||||
|
||||
@Override
|
||||
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
|
||||
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, RNNFormat.NCW, getName());
|
||||
if (nIn <= 0 || override) {
|
||||
InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
|
||||
this.nIn = (int) r.getSize();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public InputType getOutputType(int layerIndex, InputType inputType) {
|
||||
if (inputType == null || inputType.getType() != InputType.Type.RNN) {
|
||||
throw new IllegalStateException(
|
||||
"Invalid input for Learned Self Attention layer (layer index = "
|
||||
+ layerIndex
|
||||
+ ", layer name = \""
|
||||
+ getName()
|
||||
+ "\"): expect RNN input type with size > 0. Got: "
|
||||
+ inputType);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setNIn(InputType inputType, boolean override) {
|
||||
if (inputType == null || inputType.getType() != InputType.Type.RNN) {
|
||||
throw new IllegalStateException("Invalid input for Learned Self Attention layer (layer name = \"" + getName()
|
||||
+ "\"): expect RNN input type with size > 0. Got: " + inputType);
|
||||
}
|
||||
|
||||
if (nIn <= 0 || override) {
|
||||
InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
|
||||
this.nIn = r.getSize();
|
||||
}
|
||||
if (projectInput) {
|
||||
return InputType.recurrent(nOut, nQueries);
|
||||
} else {
|
||||
return InputType.recurrent(nIn, nQueries);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public InputType getOutputType(int layerIndex, InputType inputType) {
|
||||
if (inputType == null || inputType.getType() != InputType.Type.RNN) {
|
||||
throw new IllegalStateException("Invalid input for Learned Self Attention layer (layer index = " + layerIndex
|
||||
+ ", layer name = \"" + getName() + "\"): expect RNN input type with size > 0. Got: "
|
||||
+ inputType);
|
||||
}
|
||||
@Override
|
||||
public void defineParameters(SDLayerParams params) {
|
||||
params.clear();
|
||||
|
||||
if(projectInput){
|
||||
return InputType.recurrent(nOut, nQueries);
|
||||
}else{
|
||||
return InputType.recurrent(nIn, nQueries);
|
||||
}
|
||||
params.addWeightParam(WEIGHT_QUERIES, 1, nIn, nQueries);
|
||||
|
||||
if (projectInput) {
|
||||
params.addWeightParam(WEIGHT_KEY_QUERY_PROJECTION, nHeads, headSize, nIn);
|
||||
params.addWeightParam(WEIGHT_KEY_KEY_PROJECTION, nHeads, headSize, nIn);
|
||||
params.addWeightParam(WEIGHT_KEY_VALUE_PROJECTION, nHeads, headSize, nIn);
|
||||
params.addWeightParam(WEIGHT_KEY_OUT_PROJECTION, nHeads * headSize, nOut);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void defineParameters(SDLayerParams params) {
|
||||
params.clear();
|
||||
|
||||
params.addWeightParam(WEIGHT_QUERIES, 1, nIn, nQueries);
|
||||
|
||||
if(projectInput){
|
||||
params.addWeightParam(WEIGHT_KEY_QUERY_PROJECTION, nHeads, headSize, nIn);
|
||||
params.addWeightParam(WEIGHT_KEY_KEY_PROJECTION, nHeads, headSize, nIn);
|
||||
params.addWeightParam(WEIGHT_KEY_VALUE_PROJECTION, nHeads, headSize, nIn);
|
||||
params.addWeightParam(WEIGHT_KEY_OUT_PROJECTION, nHeads * headSize, nOut);
|
||||
@Override
|
||||
public void initializeParameters(Map<String, INDArray> params) {
|
||||
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
||||
for (Map.Entry<String, INDArray> e : params.entrySet()) {
|
||||
if (e.getKey().equals(WEIGHT_KEY_OUT_PROJECTION)) {
|
||||
WeightInitUtil.initWeights(
|
||||
nIn, headSize, e.getValue().shape(), weightInit, null, 'c', e.getValue());
|
||||
} else if (e.getKey().equals(WEIGHT_QUERIES)) {
|
||||
WeightInitUtil.initWeights(
|
||||
nIn, nQueries, e.getValue().shape(), weightInit, null, 'c', e.getValue());
|
||||
} else {
|
||||
WeightInitUtil.initWeights(
|
||||
nHeads * headSize, nOut, e.getValue().shape(), weightInit, null, 'c', e.getValue());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void initializeParameters(Map<String, INDArray> params) {
|
||||
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
||||
for (Map.Entry<String, INDArray> e : params.entrySet()) {
|
||||
if(e.getKey().equals(WEIGHT_KEY_OUT_PROJECTION)){
|
||||
WeightInitUtil.initWeights(nIn, headSize, e.getValue().shape(), weightInit, null, 'c', e.getValue());
|
||||
}else if(e.getKey().equals(WEIGHT_QUERIES)){
|
||||
WeightInitUtil.initWeights(nIn, nQueries, e.getValue().shape(), weightInit, null, 'c', e.getValue());
|
||||
}else{
|
||||
WeightInitUtil.initWeights(nHeads * headSize, nOut, e.getValue().shape(), weightInit, null, 'c', e.getValue());
|
||||
}
|
||||
}
|
||||
}
|
||||
@Override
|
||||
public SDVariable defineLayer(
|
||||
SameDiff sameDiff,
|
||||
SDVariable layerInput,
|
||||
Map<String, SDVariable> paramTable,
|
||||
SDVariable mask) {
|
||||
val baseQueries = paramTable.get(WEIGHT_QUERIES);
|
||||
val batchSize = layerInput.shape().get(SDIndex.point(0));
|
||||
val tileAxis =
|
||||
sameDiff.scatterUpdate(
|
||||
sameDiff.onesLike(layerInput.shape()), sameDiff.constant(0), batchSize);
|
||||
|
||||
val queries = sameDiff.tile(baseQueries, tileAxis);
|
||||
|
||||
if (projectInput) {
|
||||
val Wq = paramTable.get(WEIGHT_KEY_QUERY_PROJECTION);
|
||||
val Wk = paramTable.get(WEIGHT_KEY_KEY_PROJECTION);
|
||||
val Wv = paramTable.get(WEIGHT_KEY_VALUE_PROJECTION);
|
||||
val Wo = paramTable.get(WEIGHT_KEY_OUT_PROJECTION);
|
||||
|
||||
return sameDiff.nn.multiHeadDotProductAttention(
|
||||
getName(), queries, layerInput, layerInput, Wq, Wk, Wv, Wo, mask, true);
|
||||
} else {
|
||||
return sameDiff.nn.dotProductAttention(
|
||||
getName(), queries, layerInput, layerInput, mask, true);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Pair<INDArray, MaskState> feedForwardMaskArray(
|
||||
INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
|
||||
// No further mask propagation here, as the results have taken any mask into account, like in a
|
||||
// global pooling layer
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, Map<String, SDVariable> paramTable, SDVariable mask) {
|
||||
val baseQueries = paramTable.get(WEIGHT_QUERIES);
|
||||
val batchSize = layerInput.shape().get(SDIndex.point(0));
|
||||
val tileAxis = sameDiff.scatterUpdate(sameDiff.onesLike(layerInput.shape()), sameDiff.constant(0), batchSize);
|
||||
public abstract static class LearnedSelfAttentionLayerBuilder<
|
||||
C extends LearnedSelfAttentionLayer, B extends LearnedSelfAttentionLayerBuilder<C, B>>
|
||||
extends SameDiffLayerBuilder<C, B> {
|
||||
public C build() {
|
||||
Preconditions.checkArgument(
|
||||
this.projectInput || this.nHeads == 1, "projectInput must be true when nHeads != 1");
|
||||
Preconditions.checkArgument(
|
||||
this.projectInput || nIn == nOut, "nIn must be equal to nOut when projectInput is false");
|
||||
Preconditions.checkArgument(
|
||||
!this.projectInput || nOut != 0, "nOut must be specified when projectInput is true");
|
||||
Preconditions.checkArgument(
|
||||
this.nOut % nHeads == 0 || headSize > 0,
|
||||
"nOut isn't divided by nHeads cleanly. Specify the headSize manually.");
|
||||
Preconditions.checkArgument(this.nQueries > 0, "You must set numQueries.");
|
||||
|
||||
val queries = sameDiff.tile(baseQueries, tileAxis);
|
||||
|
||||
if(projectInput){
|
||||
val Wq = paramTable.get(WEIGHT_KEY_QUERY_PROJECTION);
|
||||
val Wk = paramTable.get(WEIGHT_KEY_KEY_PROJECTION);
|
||||
val Wv = paramTable.get(WEIGHT_KEY_VALUE_PROJECTION);
|
||||
val Wo = paramTable.get(WEIGHT_KEY_OUT_PROJECTION);
|
||||
|
||||
return sameDiff.nn.multiHeadDotProductAttention(getName(), queries, layerInput, layerInput, Wq, Wk, Wv, Wo, mask, true);
|
||||
}else{
|
||||
return sameDiff.nn.dotProductAttention(getName(), queries, layerInput, layerInput, mask, true);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
|
||||
// No further mask propagation here, as the results have taken any mask into account, like in a global pooling layer
|
||||
return null;
|
||||
}
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
public static class Builder extends SameDiffLayer.Builder<LearnedSelfAttentionLayer.Builder> {
|
||||
|
||||
/**
|
||||
* Number of inputs to the layer (input size)
|
||||
*/
|
||||
private int nIn;
|
||||
|
||||
/**
|
||||
* Number of outputs (output size)
|
||||
*/
|
||||
private int nOut;
|
||||
|
||||
/**
|
||||
* Number of Attention Heads
|
||||
*/
|
||||
private int nHeads;
|
||||
|
||||
/**
|
||||
* Size of attention heads
|
||||
*/
|
||||
private int headSize;
|
||||
|
||||
/**
|
||||
* Project input before applying attention or not.
|
||||
*/
|
||||
private boolean projectInput;
|
||||
|
||||
|
||||
/**
|
||||
* Number of queries to learn
|
||||
*/
|
||||
private int nQueries;
|
||||
|
||||
/**
|
||||
* @param nIn Number of inputs to the layer (input size)
|
||||
*/
|
||||
public Builder nIn(int nIn) {
|
||||
this.nIn = nIn;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param nOut Number of outputs (output size)
|
||||
*/
|
||||
public Builder nOut(int nOut) {
|
||||
this.nOut = nOut;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Number of Attention Heads
|
||||
*/
|
||||
public Builder nHeads(int nHeads){
|
||||
this.nHeads = nHeads;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Size of attention heads
|
||||
*/
|
||||
public Builder headSize(int headSize){
|
||||
this.headSize = headSize;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Project input before applying attention or not.
|
||||
*/
|
||||
public Builder projectInput(boolean projectInput){
|
||||
this.projectInput = projectInput;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Number of queries to learn
|
||||
*/
|
||||
public Builder nQueries(int nQueries){
|
||||
this.nQueries = nQueries;
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
@SuppressWarnings("unchecked")
|
||||
public LearnedSelfAttentionLayer build() {
|
||||
Preconditions.checkArgument(this.projectInput || this.nHeads == 1, "projectInput must be true when nHeads != 1");
|
||||
Preconditions.checkArgument(this.projectInput || nIn == nOut, "nIn must be equal to nOut when projectInput is false");
|
||||
Preconditions.checkArgument(!this.projectInput || nOut != 0, "nOut must be specified when projectInput is true");
|
||||
Preconditions.checkArgument(this.nOut % nHeads == 0 || headSize > 0, "nOut isn't divided by nHeads cleanly. Specify the headSize manually.");
|
||||
Preconditions.checkArgument(this.nQueries > 0, "You must set numQueries.");
|
||||
|
||||
return new LearnedSelfAttentionLayer(this);
|
||||
}
|
||||
return initBuild();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -41,6 +41,7 @@ import org.nd4j.autodiff.samediff.SDVariable;
|
|||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.enums.PadMode;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.activations.IActivation;
|
||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -333,6 +334,11 @@ public class LocallyConnected2D extends SameDiffLayer {
|
|||
return self();
|
||||
}
|
||||
|
||||
public B inputSize(int ... size) {
|
||||
this.inputSize = size;
|
||||
return self();
|
||||
}
|
||||
|
||||
public B stride(int ... stride) {
|
||||
this.stride$value = ValidationUtils.validate2NonNegative(stride, false, "stride");
|
||||
this.stride$set = true;
|
||||
|
|
|
@ -20,7 +20,9 @@
|
|||
|
||||
package org.deeplearning4j.nn.conf.layers;
|
||||
|
||||
import java.util.Map;
|
||||
import lombok.*;
|
||||
import lombok.experimental.SuperBuilder;
|
||||
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType.InputTypeConvolutional;
|
||||
|
@ -37,388 +39,387 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
|||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder")
|
||||
public class PrimaryCapsules extends SameDiffLayer {
|
||||
|
||||
private int[] kernelSize;
|
||||
private int[] stride;
|
||||
private int[] padding;
|
||||
private int[] dilation;
|
||||
private int inputChannels;
|
||||
private int channels;
|
||||
private static final String WEIGHT_PARAM = "weight";
|
||||
private static final String BIAS_PARAM = "bias";
|
||||
/**
|
||||
* Sets the kernel size of the 2d convolution
|
||||
*
|
||||
* @param kernelSize
|
||||
* @return
|
||||
*/
|
||||
@Builder.Default private int[] kernelSize = new int[] {9, 9};
|
||||
/**
|
||||
* Sets the stride of the 2d convolution
|
||||
*
|
||||
* @param stride
|
||||
* @return
|
||||
*/
|
||||
@Builder.Default private int[] stride = new int[] {2, 2};
|
||||
/**
|
||||
* Sets the padding of the 2d convolution
|
||||
*
|
||||
* @param padding
|
||||
* @return
|
||||
*/
|
||||
@Builder.Default private int[] padding = new int[] {0, 0};
|
||||
/**
|
||||
* Sets the dilation of the 2d convolution
|
||||
*
|
||||
* @param dilation
|
||||
* @return
|
||||
*/
|
||||
@Builder.Default private int[] dilation = new int[] {1, 1};
|
||||
|
||||
private boolean hasBias;
|
||||
private int inputChannels;
|
||||
/**
|
||||
* Sets the number of channels to use in the 2d convolution.
|
||||
*
|
||||
* <p>Note that the actual number of channels is channels * capsuleDimensions
|
||||
*
|
||||
* <p>Does the same thing as nOut()
|
||||
*
|
||||
* @param channels
|
||||
* @return
|
||||
*/
|
||||
@Builder.Default private int channels = 32;
|
||||
|
||||
private int capsules;
|
||||
private int capsuleDimensions;
|
||||
@Builder.Default private boolean hasBias = true;
|
||||
/**
|
||||
* Usually inferred automatically.
|
||||
*
|
||||
* @param capsules
|
||||
* @return
|
||||
*/
|
||||
private int capsules;
|
||||
/**
|
||||
* Sets the number of dimensions to use in the capsules.
|
||||
*
|
||||
* @param capsuleDimensions
|
||||
* @return
|
||||
*/
|
||||
private int capsuleDimensions;
|
||||
/**
|
||||
* The convolution mode to use in the 2d convolution
|
||||
*
|
||||
* @param convolutionMode
|
||||
* @return
|
||||
*/
|
||||
@Builder.Default private ConvolutionMode convolutionMode = ConvolutionMode.Truncate;
|
||||
/**
|
||||
* Whether to use a ReLU activation on the 2d convolution
|
||||
*
|
||||
* @param useRelu
|
||||
* @return
|
||||
*/
|
||||
@Builder.Default private boolean useRelU = false;
|
||||
/**
|
||||
* Use a LeakyReLU activation on the 2d convolution
|
||||
*
|
||||
* @param leak the alpha value for the LeakyReLU activation.
|
||||
* @return
|
||||
*/
|
||||
@Builder.Default private double useLeakyReLU = 0;
|
||||
|
||||
private ConvolutionMode convolutionMode = ConvolutionMode.Truncate;
|
||||
public static PrimaryCapsulesBuilder<?, ?> builder() {
|
||||
return innerBuilder();
|
||||
}
|
||||
|
||||
private boolean useRelu = false;
|
||||
private double leak = 0;
|
||||
public static PrimaryCapsulesBuilder<?, ?> builder(
|
||||
int capsuleDimensions,
|
||||
int channels,
|
||||
int[] kernelSize,
|
||||
int[] stride,
|
||||
int[] padding,
|
||||
int[] dilation,
|
||||
ConvolutionMode convolutionMode) {
|
||||
return innerBuilder()
|
||||
.capsuleDimensions(capsuleDimensions)
|
||||
.channels(channels)
|
||||
.kernelSize(kernelSize)
|
||||
.stride(stride)
|
||||
.padding(padding)
|
||||
.dilation(dilation)
|
||||
.convolutionMode(convolutionMode);
|
||||
}
|
||||
|
||||
private static final String WEIGHT_PARAM = "weight";
|
||||
private static final String BIAS_PARAM = "bias";
|
||||
public static PrimaryCapsulesBuilder<?, ?> builder(
|
||||
int capsuleDimensions,
|
||||
int channels,
|
||||
int[] kernelSize,
|
||||
int[] stride,
|
||||
int[] padding,
|
||||
int[] dilation) {
|
||||
return innerBuilder()
|
||||
.capsuleDimensions(capsuleDimensions)
|
||||
.channels(channels)
|
||||
.kernelSize(kernelSize)
|
||||
.stride(stride)
|
||||
.padding(padding)
|
||||
.dilation(dilation);
|
||||
}
|
||||
|
||||
public PrimaryCapsules(Builder builder){
|
||||
super(builder);
|
||||
public static PrimaryCapsulesBuilder<?, ?> builder(
|
||||
int capsuleDimensions, int channels, int[] kernelSize, int[] stride, int[] padding) {
|
||||
return innerBuilder()
|
||||
.capsuleDimensions(capsuleDimensions)
|
||||
.channels(channels)
|
||||
.kernelSize(kernelSize)
|
||||
.stride(stride)
|
||||
.padding(padding);
|
||||
}
|
||||
|
||||
this.kernelSize = builder.kernelSize;
|
||||
this.stride = builder.stride;
|
||||
this.padding = builder.padding;
|
||||
this.dilation = builder.dilation;
|
||||
this.channels = builder.channels;
|
||||
this.hasBias = builder.hasBias;
|
||||
this.capsules = builder.capsules;
|
||||
this.capsuleDimensions = builder.capsuleDimensions;
|
||||
this.convolutionMode = builder.convolutionMode;
|
||||
this.useRelu = builder.useRelu;
|
||||
this.leak = builder.leak;
|
||||
public static PrimaryCapsulesBuilder<?, ?> builder(
|
||||
int capsuleDimensions, int channels, int[] kernelSize, int[] stride) {
|
||||
return innerBuilder()
|
||||
.capsuleDimensions(capsuleDimensions)
|
||||
.channels(channels)
|
||||
.kernelSize(kernelSize)
|
||||
.stride(stride);
|
||||
}
|
||||
|
||||
if(capsuleDimensions <= 0 || channels <= 0){
|
||||
throw new IllegalArgumentException("Invalid configuration for Primary Capsules (layer name = \""
|
||||
+ name + "\"):"
|
||||
+ " capsuleDimensions and channels must be > 0. Got: "
|
||||
+ capsuleDimensions + ", " + channels);
|
||||
}
|
||||
public static PrimaryCapsulesBuilder<?, ?> builder(
|
||||
int capsuleDimensions, int channels, int[] kernelSize) {
|
||||
return innerBuilder()
|
||||
.capsuleDimensions(capsuleDimensions)
|
||||
.channels(channels)
|
||||
.kernelSize(kernelSize);
|
||||
}
|
||||
|
||||
if(capsules < 0){
|
||||
throw new IllegalArgumentException("Invalid configuration for Capsule ILayer (layer name = \""
|
||||
+ name + "\"):"
|
||||
+ " capsules must be >= 0 if set. Got: "
|
||||
+ capsules);
|
||||
}
|
||||
public static PrimaryCapsulesBuilder<?, ?> builder(int capsuleDimensions, int channels) {
|
||||
return innerBuilder().capsuleDimensions(capsuleDimensions).channels(channels);
|
||||
}
|
||||
|
||||
@Override
|
||||
public SDVariable defineLayer(
|
||||
SameDiff SD, SDVariable input, Map<String, SDVariable> paramTable, SDVariable mask) {
|
||||
Conv2DConfig conf =
|
||||
Conv2DConfig.builder()
|
||||
.kH(kernelSize[0])
|
||||
.kW(kernelSize[1])
|
||||
.sH(stride[0])
|
||||
.sW(stride[1])
|
||||
.pH(padding[0])
|
||||
.pW(padding[1])
|
||||
.dH(dilation[0])
|
||||
.dW(dilation[1])
|
||||
.isSameMode(convolutionMode == ConvolutionMode.Same)
|
||||
.build();
|
||||
|
||||
SDVariable conved;
|
||||
|
||||
if (hasBias) {
|
||||
conved = SD.cnn.conv2d(input, paramTable.get(WEIGHT_PARAM), paramTable.get(BIAS_PARAM), conf);
|
||||
} else {
|
||||
conved = SD.cnn.conv2d(input, paramTable.get(WEIGHT_PARAM), conf);
|
||||
}
|
||||
|
||||
@Override
|
||||
public SDVariable defineLayer(SameDiff SD, SDVariable input, Map<String, SDVariable> paramTable, SDVariable mask) {
|
||||
Conv2DConfig conf = Conv2DConfig.builder()
|
||||
.kH(kernelSize[0]).kW(kernelSize[1])
|
||||
.sH(stride[0]).sW(stride[1])
|
||||
.pH(padding[0]).pW(padding[1])
|
||||
.dH(dilation[0]).dW(dilation[1])
|
||||
.isSameMode(convolutionMode == ConvolutionMode.Same)
|
||||
.build();
|
||||
|
||||
SDVariable conved;
|
||||
|
||||
if(hasBias){
|
||||
conved = SD.cnn.conv2d(input, paramTable.get(WEIGHT_PARAM), paramTable.get(BIAS_PARAM), conf);
|
||||
} else {
|
||||
conved = SD.cnn.conv2d(input, paramTable.get(WEIGHT_PARAM), conf);
|
||||
}
|
||||
|
||||
if(useRelu){
|
||||
if(leak == 0) {
|
||||
conved = SD.nn.relu(conved, 0);
|
||||
} else {
|
||||
conved = SD.nn.leakyRelu(conved, leak);
|
||||
}
|
||||
}
|
||||
|
||||
SDVariable reshaped = conved.reshape(-1, capsules, capsuleDimensions);
|
||||
return CapsuleUtils.squash(SD, reshaped, 2);
|
||||
if (useRelU) {
|
||||
if (useLeakyReLU == 0) {
|
||||
conved = SD.nn.relu(conved, 0);
|
||||
} else {
|
||||
conved = SD.nn.leakyRelu(conved, useLeakyReLU);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void defineParameters(SDLayerParams params) {
|
||||
params.clear();
|
||||
params.addWeightParam(WEIGHT_PARAM,
|
||||
kernelSize[0], kernelSize[1], inputChannels, (long) capsuleDimensions * channels);
|
||||
SDVariable reshaped = conved.reshape(-1, capsules, capsuleDimensions);
|
||||
return CapsuleUtils.squash(SD, reshaped, 2);
|
||||
}
|
||||
|
||||
if(hasBias){
|
||||
params.addBiasParam(BIAS_PARAM, (long) capsuleDimensions * channels);
|
||||
@Override
|
||||
public void defineParameters(SDLayerParams params) {
|
||||
params.clear();
|
||||
params.addWeightParam(
|
||||
WEIGHT_PARAM,
|
||||
kernelSize[0],
|
||||
kernelSize[1],
|
||||
inputChannels,
|
||||
(long) capsuleDimensions * channels);
|
||||
|
||||
if (hasBias) {
|
||||
params.addBiasParam(BIAS_PARAM, (long) capsuleDimensions * channels);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void initializeParameters(Map<String, INDArray> params) {
|
||||
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
||||
for (Map.Entry<String, INDArray> e : params.entrySet()) {
|
||||
if (BIAS_PARAM.equals(e.getKey())) {
|
||||
e.getValue().assign(0);
|
||||
} else if (WEIGHT_PARAM.equals(e.getKey())) {
|
||||
double fanIn = inputChannels * kernelSize[0] * kernelSize[1];
|
||||
double fanOut =
|
||||
capsuleDimensions
|
||||
* channels
|
||||
* kernelSize[0]
|
||||
* kernelSize[1]
|
||||
/ ((double) stride[0] * stride[1]);
|
||||
WeightInitUtil.initWeights(
|
||||
fanIn, fanOut, e.getValue().shape(), weightInit, null, 'c', e.getValue());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public InputType getOutputType(int layerIndex, InputType inputType) {
|
||||
if (inputType == null || inputType.getType() != Type.CNN) {
|
||||
throw new IllegalStateException(
|
||||
"Invalid input for Primary Capsules layer (layer name = \""
|
||||
+ name
|
||||
+ "\"): expect CNN input. Got: "
|
||||
+ inputType);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void initializeParameters(Map<String, INDArray> params) {
|
||||
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
||||
for (Map.Entry<String, INDArray> e : params.entrySet()) {
|
||||
if (BIAS_PARAM.equals(e.getKey())) {
|
||||
e.getValue().assign(0);
|
||||
} else if(WEIGHT_PARAM.equals(e.getKey())){
|
||||
double fanIn = inputChannels * kernelSize[0] * kernelSize[1];
|
||||
double fanOut = capsuleDimensions * channels * kernelSize[0] * kernelSize[1] / ((double) stride[0] * stride[1]);
|
||||
WeightInitUtil.initWeights(fanIn, fanOut, e.getValue().shape(), weightInit, null, 'c',
|
||||
e.getValue());
|
||||
}
|
||||
}
|
||||
}
|
||||
if (capsules > 0) {
|
||||
return InputType.recurrent(capsules, capsuleDimensions);
|
||||
} else {
|
||||
|
||||
InputTypeConvolutional out =
|
||||
(InputTypeConvolutional)
|
||||
InputTypeUtil.getOutputTypeCnnLayers(
|
||||
inputType,
|
||||
kernelSize,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
convolutionMode,
|
||||
(long) capsuleDimensions * channels,
|
||||
-1,
|
||||
getName(),
|
||||
PrimaryCapsules.class);
|
||||
|
||||
return InputType.recurrent(
|
||||
(int) (out.getChannels() * out.getHeight() * out.getWidth() / capsuleDimensions),
|
||||
capsuleDimensions);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setNIn(InputType inputType, boolean override) {
|
||||
if (inputType == null || inputType.getType() != Type.CNN) {
|
||||
throw new IllegalStateException(
|
||||
"Invalid input for Primary Capsules layer (layer name = \""
|
||||
+ name
|
||||
+ "\"): expect CNN input. Got: "
|
||||
+ inputType);
|
||||
}
|
||||
|
||||
@Override
|
||||
public InputType getOutputType(int layerIndex, InputType inputType) {
|
||||
if (inputType == null || inputType.getType() != Type.CNN) {
|
||||
throw new IllegalStateException("Invalid input for Primary Capsules layer (layer name = \""
|
||||
+ name + "\"): expect CNN input. Got: " + inputType);
|
||||
}
|
||||
InputTypeConvolutional ci = (InputTypeConvolutional) inputType;
|
||||
|
||||
if(capsules > 0){
|
||||
return InputType.recurrent(capsules, capsuleDimensions);
|
||||
} else {
|
||||
this.inputChannels = (int) ci.getChannels();
|
||||
|
||||
InputTypeConvolutional out = (InputTypeConvolutional) InputTypeUtil
|
||||
.getOutputTypeCnnLayers(inputType, kernelSize, stride, padding, dilation, convolutionMode,
|
||||
(long) capsuleDimensions * channels, -1, getName(), PrimaryCapsules.class);
|
||||
if (capsules <= 0 || override) {
|
||||
|
||||
return InputType.recurrent((int) (out.getChannels() * out.getHeight() * out.getWidth() / capsuleDimensions),
|
||||
capsuleDimensions);
|
||||
}
|
||||
InputTypeConvolutional out =
|
||||
(InputTypeConvolutional)
|
||||
InputTypeUtil.getOutputTypeCnnLayers(
|
||||
inputType,
|
||||
kernelSize,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
convolutionMode,
|
||||
(long) capsuleDimensions * channels,
|
||||
-1,
|
||||
getName(),
|
||||
PrimaryCapsules.class);
|
||||
|
||||
this.capsules =
|
||||
(int) (out.getChannels() * out.getHeight() * out.getWidth() / capsuleDimensions);
|
||||
}
|
||||
}
|
||||
|
||||
public abstract static class PrimaryCapsulesBuilder<
|
||||
C extends PrimaryCapsules, B extends PrimaryCapsulesBuilder<C, B>>
|
||||
extends SameDiffLayerBuilder<C, B> {
|
||||
|
||||
public B kernelSize(int... kernelSize) {
|
||||
this.kernelSize$value = ValidationUtils.validate2NonNegative(kernelSize, true, "kernelSize");
|
||||
this.kernelSize$set = true;
|
||||
return self();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setNIn(InputType inputType, boolean override) {
|
||||
if (inputType == null || inputType.getType() != Type.CNN) {
|
||||
throw new IllegalStateException("Invalid input for Primary Capsules layer (layer name = \""
|
||||
+ name + "\"): expect CNN input. Got: " + inputType);
|
||||
}
|
||||
|
||||
InputTypeConvolutional ci = (InputTypeConvolutional) inputType;
|
||||
|
||||
this.inputChannels = (int) ci.getChannels();
|
||||
|
||||
if(capsules <= 0 || override) {
|
||||
|
||||
InputTypeConvolutional out = (InputTypeConvolutional) InputTypeUtil
|
||||
.getOutputTypeCnnLayers(inputType, kernelSize, stride, padding, dilation, convolutionMode,
|
||||
(long) capsuleDimensions * channels, -1, getName(), PrimaryCapsules.class);
|
||||
|
||||
this.capsules = (int) (out.getChannels() * out.getHeight() * out.getWidth() / capsuleDimensions);
|
||||
}
|
||||
public B stride(int... stride) {
|
||||
this.stride$value = ValidationUtils.validate2NonNegative(stride, true, "stride");
|
||||
this.stride$set = true;
|
||||
return self();
|
||||
}
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
public static class Builder extends SameDiffLayer.Builder<Builder>{
|
||||
|
||||
@Setter(AccessLevel.NONE)
|
||||
private int[] kernelSize = new int[]{9, 9};
|
||||
|
||||
@Setter(AccessLevel.NONE)
|
||||
private int[] stride = new int[]{2, 2};
|
||||
|
||||
@Setter(AccessLevel.NONE)
|
||||
private int[] padding = new int[]{0, 0};
|
||||
|
||||
@Setter(AccessLevel.NONE)
|
||||
private int[] dilation = new int[]{1, 1};
|
||||
|
||||
private int channels = 32;
|
||||
|
||||
private boolean hasBias = true;
|
||||
|
||||
private int capsules;
|
||||
private int capsuleDimensions;
|
||||
|
||||
private ConvolutionMode convolutionMode = ConvolutionMode.Truncate;
|
||||
|
||||
private boolean useRelu = false;
|
||||
private double leak = 0;
|
||||
|
||||
|
||||
public void setKernelSize(int... kernelSize){
|
||||
this.kernelSize = ValidationUtils.validate2NonNegative(kernelSize, true, "kernelSize");
|
||||
}
|
||||
|
||||
public void setStride(int... stride){
|
||||
this.stride = ValidationUtils.validate2NonNegative(stride, true, "stride");
|
||||
}
|
||||
|
||||
public void setPadding(int... padding){
|
||||
this.padding = ValidationUtils.validate2NonNegative(padding, true, "padding");
|
||||
}
|
||||
|
||||
public void setDilation(int... dilation){
|
||||
this.dilation = ValidationUtils.validate2NonNegative(dilation, true, "dilation");
|
||||
}
|
||||
|
||||
|
||||
public Builder(int capsuleDimensions, int channels,
|
||||
int[] kernelSize, int[] stride, int[] padding, int[] dilation,
|
||||
ConvolutionMode convolutionMode){
|
||||
this.capsuleDimensions = capsuleDimensions;
|
||||
this.channels = channels;
|
||||
this.setKernelSize(kernelSize);
|
||||
this.setStride(stride);
|
||||
this.setPadding(padding);
|
||||
this.setDilation(dilation);
|
||||
this.convolutionMode = convolutionMode;
|
||||
}
|
||||
|
||||
public Builder(int capsuleDimensions, int channels,
|
||||
int[] kernelSize, int[] stride, int[] padding, int[] dilation){
|
||||
this(capsuleDimensions, channels, kernelSize, stride, padding, dilation, ConvolutionMode.Truncate);
|
||||
}
|
||||
|
||||
public Builder(int capsuleDimensions, int channels,
|
||||
int[] kernelSize, int[] stride, int[] padding){
|
||||
this(capsuleDimensions, channels, kernelSize, stride, padding, new int[]{1, 1}, ConvolutionMode.Truncate);
|
||||
}
|
||||
|
||||
public Builder(int capsuleDimensions, int channels,
|
||||
int[] kernelSize, int[] stride){
|
||||
this(capsuleDimensions, channels, kernelSize, stride, new int[]{0, 0}, new int[]{1, 1}, ConvolutionMode.Truncate);
|
||||
}
|
||||
|
||||
public Builder(int capsuleDimensions, int channels,
|
||||
int[] kernelSize){
|
||||
this(capsuleDimensions, channels, kernelSize, new int[]{2, 2}, new int[]{0, 0}, new int[]{1, 1}, ConvolutionMode.Truncate);
|
||||
}
|
||||
|
||||
public Builder(int capsuleDimensions, int channels){
|
||||
this(capsuleDimensions, channels, new int[]{9, 9}, new int[]{2, 2}, new int[]{0, 0}, new int[]{1, 1}, ConvolutionMode.Truncate);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the kernel size of the 2d convolution
|
||||
*
|
||||
* @see ConvolutionLayer.Builder#kernelSize(int...)
|
||||
* @param kernelSize
|
||||
* @return
|
||||
*/
|
||||
public Builder kernelSize(int... kernelSize){
|
||||
this.setKernelSize(kernelSize);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the stride of the 2d convolution
|
||||
*
|
||||
* @see ConvolutionLayer.Builder#stride(int...)
|
||||
* @param stride
|
||||
* @return
|
||||
*/
|
||||
public Builder stride(int... stride){
|
||||
this.setStride(stride);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the padding of the 2d convolution
|
||||
*
|
||||
* @see ConvolutionLayer.Builder#padding(int...)
|
||||
* @param padding
|
||||
* @return
|
||||
*/
|
||||
public Builder padding(int... padding){
|
||||
this.setPadding(padding);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the dilation of the 2d convolution
|
||||
*
|
||||
* @see ConvolutionLayer.Builder#dilation(int...)
|
||||
* @param dilation
|
||||
* @return
|
||||
*/
|
||||
public Builder dilation(int... dilation){
|
||||
this.setDilation(dilation);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the number of channels to use in the 2d convolution.
|
||||
*
|
||||
* Note that the actual number of channels is channels * capsuleDimensions
|
||||
*
|
||||
* Does the same thing as nOut()
|
||||
*
|
||||
* @param channels
|
||||
* @return
|
||||
*/
|
||||
public Builder channels(int channels){
|
||||
this.channels = channels;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the number of channels to use in the 2d convolution.
|
||||
*
|
||||
* Note that the actual number of channels is channels * capsuleDimensions
|
||||
*
|
||||
* Does the same thing as channels()
|
||||
*
|
||||
* @param nOut
|
||||
* @return
|
||||
*/
|
||||
public Builder nOut(int nOut){
|
||||
return channels(nOut);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the number of dimensions to use in the capsules.
|
||||
* @param capsuleDimensions
|
||||
* @return
|
||||
*/
|
||||
public Builder capsuleDimensions(int capsuleDimensions){
|
||||
this.capsuleDimensions = capsuleDimensions;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Usually inferred automatically.
|
||||
* @param capsules
|
||||
* @return
|
||||
*/
|
||||
public Builder capsules(int capsules){
|
||||
this.capsules = capsules;
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder hasBias(boolean hasBias){
|
||||
this.hasBias = hasBias;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* The convolution mode to use in the 2d convolution
|
||||
* @param convolutionMode
|
||||
* @return
|
||||
*/
|
||||
public Builder convolutionMode(ConvolutionMode convolutionMode){
|
||||
this.convolutionMode = convolutionMode;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Whether to use a ReLU activation on the 2d convolution
|
||||
* @param useRelu
|
||||
* @return
|
||||
*/
|
||||
public Builder useReLU(boolean useRelu){
|
||||
this.useRelu = useRelu;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Use a ReLU activation on the 2d convolution
|
||||
* @return
|
||||
*/
|
||||
public Builder useReLU(){
|
||||
return useReLU(true);
|
||||
}
|
||||
|
||||
/**
|
||||
* Use a LeakyReLU activation on the 2d convolution
|
||||
* @param leak the alpha value for the LeakyReLU activation.
|
||||
* @return
|
||||
*/
|
||||
public Builder useLeakyReLU(double leak){
|
||||
this.useRelu = true;
|
||||
this.leak = leak;
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public <E extends LayerConfiguration> E build() {
|
||||
return (E) new PrimaryCapsules(this);
|
||||
}
|
||||
public B padding(int... padding) {
|
||||
this.padding$value = ValidationUtils.validate2NonNegative(padding, true, "padding");
|
||||
this.padding$set = true;
|
||||
return self();
|
||||
}
|
||||
|
||||
public B dilation(int... dilation) {
|
||||
this.dilation$value = ValidationUtils.validate2NonNegative(dilation, true, "dilation");
|
||||
this.dilation$set = true;
|
||||
return self();
|
||||
}
|
||||
/**
|
||||
* Sets the number of channels to use in the 2d convolution.
|
||||
*
|
||||
* <p>Note that the actual number of channels is channels * capsuleDimensions
|
||||
*
|
||||
* <p>Does the same thing as channels()
|
||||
*
|
||||
* @param nOut
|
||||
* @return
|
||||
*/
|
||||
public B nOut(int nOut) {
|
||||
return channels(nOut);
|
||||
}
|
||||
/**
|
||||
* Use a ReLU activation on the 2d convolution
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
public B useReLU() {
|
||||
return useRelU(true);
|
||||
}
|
||||
|
||||
/**
|
||||
* Use a LeakyReLU activation on the 2d convolution. Implies {@link #useReLU()} set true.
|
||||
*
|
||||
* @param leak the alpha value for the LeakyReLU activation.
|
||||
* @return
|
||||
*/
|
||||
public B useLeakyReLU(double leak) {
|
||||
this.useRelU(true);
|
||||
this.useLeakyReLU$value = leak;
|
||||
this.useLeakyReLU$set = true;
|
||||
return self();
|
||||
}
|
||||
|
||||
public C build() {
|
||||
C l = initBuild();
|
||||
if (capsuleDimensions <= 0 || channels$value <= 0) {
|
||||
throw new IllegalArgumentException(
|
||||
"Invalid configuration for Primary Capsules (layer name = \""
|
||||
+ l.getName()
|
||||
+ "\"):"
|
||||
+ " capsuleDimensions and channels must be > 0. Got: "
|
||||
+ capsuleDimensions
|
||||
+ ", "
|
||||
+ channels$value);
|
||||
}
|
||||
|
||||
if (capsules < 0) {
|
||||
throw new IllegalArgumentException(
|
||||
"Invalid configuration for Capsule ILayer (layer name = \""
|
||||
+ l.getName()
|
||||
+ "\"):"
|
||||
+ " capsules must be >= 0 if set. Got: "
|
||||
+ capsules);
|
||||
}
|
||||
return l;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
package org.deeplearning4j.nn.conf.layers;
|
||||
|
||||
import lombok.*;
|
||||
import lombok.experimental.SuperBuilder;
|
||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||
|
@ -41,15 +42,63 @@ import org.nd4j.linalg.factory.Nd4j;
|
|||
import java.util.Map;
|
||||
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@SuperBuilder(buildMethodName = "initBuild")
|
||||
public class RecurrentAttentionLayer extends SameDiffLayer {
|
||||
private long nIn;
|
||||
private long nOut;
|
||||
|
||||
public static abstract class RecurrentAttentionLayerBuilder<C extends RecurrentAttentionLayer, B extends RecurrentAttentionLayerBuilder<C,B>>
|
||||
extends SameDiffLayerBuilder<C,B> {
|
||||
|
||||
public C build() {
|
||||
Preconditions.checkArgument(this.projectInput$value || this.nHeads == 1, "projectInput must be true when nHeads != 1");
|
||||
Preconditions.checkArgument(this.projectInput$value || nIn == nOut, "nIn must be equal to nOut when projectInput is false");
|
||||
Preconditions.checkArgument(!this.projectInput$value || nOut != 0, "nOut must be specified when projectInput is true");
|
||||
Preconditions.checkArgument(this.nOut % nHeads == 0 || headSize > 0, "nOut isn't divided by nHeads cleanly. Specify the headSize manually.");
|
||||
|
||||
C l = initBuild();
|
||||
return l;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Number of inputs to the layer (input size)
|
||||
*/
|
||||
private int nIn;
|
||||
|
||||
/**
|
||||
* Number of outputs (output size)
|
||||
*/
|
||||
private int nOut;
|
||||
|
||||
/**
|
||||
* Number of Attention Heads
|
||||
*/
|
||||
private int nHeads;
|
||||
private long headSize;
|
||||
private boolean projectInput;
|
||||
private Activation activation;
|
||||
private boolean hasBias;
|
||||
|
||||
/**
|
||||
* Size of attention heads
|
||||
*/
|
||||
private int headSize;
|
||||
|
||||
/**
|
||||
* Project input before applying attention or not.
|
||||
*/
|
||||
@Builder.Default
|
||||
private boolean projectInput = true;
|
||||
|
||||
/**
|
||||
* If true (default is true) the layer will have a bias
|
||||
*/
|
||||
@Builder.Default
|
||||
private boolean hasBias = true;
|
||||
|
||||
/**
|
||||
* Activation function for the layer
|
||||
*/
|
||||
@Builder.Default
|
||||
private Activation activation = Activation.TANH;
|
||||
|
||||
|
||||
private static final String WEIGHT_KEY_QUERY_PROJECTION = "Wq";
|
||||
private static final String WEIGHT_KEY_KEY_PROJECTION = "Wk";
|
||||
|
@ -60,18 +109,7 @@ public class RecurrentAttentionLayer extends SameDiffLayer {
|
|||
private static final String RECURRENT_WEIGHT_KEY = SimpleRnnParamInitializer.RECURRENT_WEIGHT_KEY;
|
||||
private int timeSteps;
|
||||
|
||||
private RecurrentAttentionLayer(){/*No arg constructor for serialization*/}
|
||||
|
||||
protected RecurrentAttentionLayer(Builder builder){
|
||||
super(builder);
|
||||
nIn = builder.nIn;
|
||||
nOut = builder.nOut;
|
||||
nHeads = builder.nHeads;
|
||||
headSize = builder.headSize == 0 ? nOut / nHeads : builder.headSize;
|
||||
projectInput = builder.projectInput;
|
||||
activation = builder.activation;
|
||||
hasBias = builder.hasBias;
|
||||
}
|
||||
|
||||
@Override
|
||||
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
|
||||
|
@ -87,7 +125,7 @@ public class RecurrentAttentionLayer extends SameDiffLayer {
|
|||
|
||||
if (nIn <= 0 || override) {
|
||||
InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
|
||||
this.nIn = r.getSize();
|
||||
this.nIn = (int) r.getSize();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -206,109 +244,5 @@ public class RecurrentAttentionLayer extends SameDiffLayer {
|
|||
return sameDiff.concat(2, outputSlices);
|
||||
}
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
public static class Builder extends SameDiffLayer.Builder<RecurrentAttentionLayer.Builder> {
|
||||
|
||||
/**
|
||||
* Number of inputs to the layer (input size)
|
||||
*/
|
||||
private int nIn;
|
||||
|
||||
/**
|
||||
* Number of outputs (output size)
|
||||
*/
|
||||
private int nOut;
|
||||
|
||||
/**
|
||||
* Number of Attention Heads
|
||||
*/
|
||||
private int nHeads;
|
||||
|
||||
/**
|
||||
* Size of attention heads
|
||||
*/
|
||||
private int headSize;
|
||||
|
||||
/**
|
||||
* Project input before applying attention or not.
|
||||
*/
|
||||
private boolean projectInput = true;
|
||||
|
||||
/**
|
||||
* If true (default is true) the layer will have a bias
|
||||
*/
|
||||
private boolean hasBias = true;
|
||||
|
||||
/**
|
||||
* Activation function for the layer
|
||||
*/
|
||||
private Activation activation = Activation.TANH;
|
||||
|
||||
/**
|
||||
* @param nIn Number of inputs to the layer (input size)
|
||||
*/
|
||||
public Builder nIn(int nIn) {
|
||||
this.nIn = nIn;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param nOut Number of outputs (output size)
|
||||
*/
|
||||
public Builder nOut(int nOut) {
|
||||
this.nOut = nOut;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Number of Attention Heads
|
||||
*/
|
||||
public Builder nHeads(int nHeads){
|
||||
this.nHeads = nHeads;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Size of attention heads
|
||||
*/
|
||||
public Builder headSize(int headSize){
|
||||
this.headSize = headSize;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Project input before applying attention or not.
|
||||
*/
|
||||
public Builder projectInput(boolean projectInput){
|
||||
this.projectInput = projectInput;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param hasBias If true (default is true) the layer will have a bias
|
||||
*/
|
||||
public Builder hasBias(boolean hasBias) {
|
||||
this.hasBias = hasBias;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param activation Activation function for the layer
|
||||
*/
|
||||
public Builder activation(Activation activation) {
|
||||
this.activation = activation;
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
@SuppressWarnings("unchecked")
|
||||
public RecurrentAttentionLayer build() {
|
||||
Preconditions.checkArgument(this.projectInput || this.nHeads == 1, "projectInput must be true when nHeads != 1");
|
||||
Preconditions.checkArgument(this.projectInput || nIn == nOut, "nIn must be equal to nOut when projectInput is false");
|
||||
Preconditions.checkArgument(!this.projectInput || nOut != 0, "nOut must be specified when projectInput is true");
|
||||
Preconditions.checkArgument(this.nOut % nHeads == 0 || headSize > 0, "nOut isn't divided by nHeads cleanly. Specify the headSize manually.");
|
||||
return new RecurrentAttentionLayer(this);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,7 +20,9 @@
|
|||
|
||||
package org.deeplearning4j.nn.conf.layers;
|
||||
|
||||
import java.util.Map;
|
||||
import lombok.*;
|
||||
import lombok.experimental.SuperBuilder;
|
||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||
|
@ -34,186 +36,130 @@ import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
|||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@NoArgsConstructor()
|
||||
@SuperBuilder(buildMethodName = "initBuild")
|
||||
public class SelfAttentionLayer extends SameDiffLayer {
|
||||
private long nIn;
|
||||
private long nOut;
|
||||
private int nHeads;
|
||||
private long headSize;
|
||||
private boolean projectInput;
|
||||
private static final String WEIGHT_KEY_QUERY_PROJECTION = "Wq";
|
||||
private static final String WEIGHT_KEY_KEY_PROJECTION = "Wk";
|
||||
private static final String WEIGHT_KEY_VALUE_PROJECTION = "Wv";
|
||||
private static final String WEIGHT_KEY_OUT_PROJECTION = "Wo";
|
||||
/** Number of inputs to the layer (input size) */
|
||||
private int nIn;
|
||||
/** Number of outputs (output size) */
|
||||
private int nOut;
|
||||
/** Number of Attention Heads */
|
||||
private int nHeads;
|
||||
/** Size of attention heads */
|
||||
private int headSize;
|
||||
/** Project input before applying attention or not. */
|
||||
private boolean projectInput;
|
||||
|
||||
private static final String WEIGHT_KEY_QUERY_PROJECTION = "Wq";
|
||||
private static final String WEIGHT_KEY_KEY_PROJECTION = "Wk";
|
||||
private static final String WEIGHT_KEY_VALUE_PROJECTION = "Wv";
|
||||
private static final String WEIGHT_KEY_OUT_PROJECTION = "Wo";
|
||||
@Override
|
||||
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
|
||||
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, RNNFormat.NCW, getName());
|
||||
}
|
||||
|
||||
private SelfAttentionLayer(){/*No arg constructor for serialization*/}
|
||||
|
||||
protected SelfAttentionLayer(Builder builder){
|
||||
super(builder);
|
||||
nIn = builder.nIn;
|
||||
nOut = builder.nOut;
|
||||
nHeads = builder.nHeads;
|
||||
headSize = builder.headSize == 0 ? nOut / nHeads : builder.headSize;
|
||||
projectInput = builder.projectInput;
|
||||
@Override
|
||||
public void setNIn(InputType inputType, boolean override) {
|
||||
if (inputType == null || inputType.getType() != InputType.Type.RNN) {
|
||||
throw new IllegalStateException(
|
||||
"Invalid input for Self Attention layer (layer name = \""
|
||||
+ getName()
|
||||
+ "\"): expect RNN input type with size > 0. Got: "
|
||||
+ inputType);
|
||||
}
|
||||
|
||||
@Override
|
||||
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
|
||||
return InputTypeUtil.getPreprocessorForInputTypeRnnLayers(inputType, RNNFormat.NCW, getName());
|
||||
if (nIn <= 0 || override) {
|
||||
InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
|
||||
this.nIn = (int) r.getSize();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public InputType getOutputType(int layerIndex, InputType inputType) {
|
||||
if (inputType == null || inputType.getType() != InputType.Type.RNN) {
|
||||
throw new IllegalStateException(
|
||||
"Invalid input for Self Attention layer (layer index = "
|
||||
+ layerIndex
|
||||
+ ", layer name = \""
|
||||
+ getName()
|
||||
+ "\"): expect RNN input type with size > 0. Got: "
|
||||
+ inputType);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setNIn(InputType inputType, boolean override) {
|
||||
if (inputType == null || inputType.getType() != InputType.Type.RNN) {
|
||||
throw new IllegalStateException("Invalid input for Self Attention layer (layer name = \"" + getName()
|
||||
+ "\"): expect RNN input type with size > 0. Got: " + inputType);
|
||||
}
|
||||
InputType.InputTypeRecurrent itr = (InputType.InputTypeRecurrent) inputType;
|
||||
|
||||
if (nIn <= 0 || override) {
|
||||
InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
|
||||
this.nIn = r.getSize();
|
||||
}
|
||||
if (projectInput) {
|
||||
return InputType.recurrent(nOut, itr.getTimeSeriesLength());
|
||||
} else {
|
||||
return InputType.recurrent(nIn, itr.getTimeSeriesLength());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public InputType getOutputType(int layerIndex, InputType inputType) {
|
||||
if (inputType == null || inputType.getType() != InputType.Type.RNN) {
|
||||
throw new IllegalStateException("Invalid input for Self Attention layer (layer index = " + layerIndex
|
||||
+ ", layer name = \"" + getName() + "\"): expect RNN input type with size > 0. Got: "
|
||||
+ inputType);
|
||||
}
|
||||
@Override
|
||||
public void defineParameters(SDLayerParams params) {
|
||||
params.clear();
|
||||
|
||||
InputType.InputTypeRecurrent itr = (InputType.InputTypeRecurrent) inputType;
|
||||
|
||||
if(projectInput){
|
||||
return InputType.recurrent(nOut, itr.getTimeSeriesLength());
|
||||
}else{
|
||||
return InputType.recurrent(nIn, itr.getTimeSeriesLength());
|
||||
}
|
||||
if (projectInput) {
|
||||
params.addWeightParam(WEIGHT_KEY_QUERY_PROJECTION, nHeads, headSize, nIn);
|
||||
params.addWeightParam(WEIGHT_KEY_KEY_PROJECTION, nHeads, headSize, nIn);
|
||||
params.addWeightParam(WEIGHT_KEY_VALUE_PROJECTION, nHeads, headSize, nIn);
|
||||
params.addWeightParam(WEIGHT_KEY_OUT_PROJECTION, nHeads * headSize, nOut);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void defineParameters(SDLayerParams params) {
|
||||
params.clear();
|
||||
|
||||
if(projectInput){
|
||||
params.addWeightParam(WEIGHT_KEY_QUERY_PROJECTION, nHeads, headSize, nIn);
|
||||
params.addWeightParam(WEIGHT_KEY_KEY_PROJECTION, nHeads, headSize, nIn);
|
||||
params.addWeightParam(WEIGHT_KEY_VALUE_PROJECTION, nHeads, headSize, nIn);
|
||||
params.addWeightParam(WEIGHT_KEY_OUT_PROJECTION, nHeads * headSize, nOut);
|
||||
@Override
|
||||
public void initializeParameters(Map<String, INDArray> params) {
|
||||
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
||||
for (Map.Entry<String, INDArray> e : params.entrySet()) {
|
||||
if (e.getKey().equals(WEIGHT_KEY_OUT_PROJECTION)) {
|
||||
WeightInitUtil.initWeights(
|
||||
nIn, headSize, e.getValue().shape(), weightInit, null, 'c', e.getValue());
|
||||
} else {
|
||||
WeightInitUtil.initWeights(
|
||||
nHeads * headSize, nOut, e.getValue().shape(), weightInit, null, 'c', e.getValue());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void initializeParameters(Map<String, INDArray> params) {
|
||||
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
||||
for (Map.Entry<String, INDArray> e : params.entrySet()) {
|
||||
if(e.getKey().equals(WEIGHT_KEY_OUT_PROJECTION)){
|
||||
WeightInitUtil.initWeights(nIn, headSize, e.getValue().shape(), weightInit, null, 'c', e.getValue());
|
||||
}else{
|
||||
WeightInitUtil.initWeights(nHeads * headSize, nOut, e.getValue().shape(), weightInit, null, 'c', e.getValue());
|
||||
}
|
||||
}
|
||||
}
|
||||
@Override
|
||||
public SDVariable defineLayer(
|
||||
SameDiff sameDiff,
|
||||
SDVariable layerInput,
|
||||
Map<String, SDVariable> paramTable,
|
||||
SDVariable mask) {
|
||||
if (projectInput) {
|
||||
val Wq = paramTable.get(WEIGHT_KEY_QUERY_PROJECTION);
|
||||
val Wk = paramTable.get(WEIGHT_KEY_KEY_PROJECTION);
|
||||
val Wv = paramTable.get(WEIGHT_KEY_VALUE_PROJECTION);
|
||||
val Wo = paramTable.get(WEIGHT_KEY_OUT_PROJECTION);
|
||||
|
||||
return sameDiff.nn.multiHeadDotProductAttention(
|
||||
getName(), layerInput, layerInput, layerInput, Wq, Wk, Wv, Wo, mask, true);
|
||||
} else {
|
||||
return sameDiff.nn.dotProductAttention(
|
||||
getName(), layerInput, layerInput, layerInput, mask, true);
|
||||
}
|
||||
}
|
||||
|
||||
public abstract static class SelfAttentionLayerBuilder<
|
||||
C extends SelfAttentionLayer, B extends SelfAttentionLayerBuilder<C, B>>
|
||||
extends SameDiffLayerBuilder<C, B> {
|
||||
public C build() {
|
||||
Preconditions.checkArgument(
|
||||
this.projectInput || this.nHeads == 1, "projectInput must be true when nHeads != 1");
|
||||
Preconditions.checkArgument(
|
||||
this.projectInput || nIn == nOut, "nIn must be equal to nOut when projectInput is false");
|
||||
Preconditions.checkArgument(
|
||||
!this.projectInput || nOut != 0, "nOut must be specified when projectInput is true");
|
||||
Preconditions.checkArgument(
|
||||
this.nOut % nHeads == 0 || headSize > 0,
|
||||
"nOut isn't divided by nHeads cleanly. Specify the headSize manually.");
|
||||
|
||||
@Override
|
||||
public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, Map<String, SDVariable> paramTable, SDVariable mask) {
|
||||
if(projectInput){
|
||||
val Wq = paramTable.get(WEIGHT_KEY_QUERY_PROJECTION);
|
||||
val Wk = paramTable.get(WEIGHT_KEY_KEY_PROJECTION);
|
||||
val Wv = paramTable.get(WEIGHT_KEY_VALUE_PROJECTION);
|
||||
val Wo = paramTable.get(WEIGHT_KEY_OUT_PROJECTION);
|
||||
|
||||
return sameDiff.nn.multiHeadDotProductAttention(getName(), layerInput, layerInput, layerInput, Wq, Wk, Wv, Wo, mask, true);
|
||||
}else{
|
||||
return sameDiff.nn.dotProductAttention(getName(), layerInput, layerInput, layerInput, mask, true);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
public static class Builder extends SameDiffLayer.Builder<SelfAttentionLayer.Builder> {
|
||||
|
||||
/**
|
||||
* Number of inputs to the layer (input size)
|
||||
*/
|
||||
private int nIn;
|
||||
|
||||
/**
|
||||
* Number of outputs (output size)
|
||||
*/
|
||||
private int nOut;
|
||||
|
||||
/**
|
||||
* Number of Attention Heads
|
||||
*/
|
||||
private int nHeads;
|
||||
|
||||
/**
|
||||
* Size of attention heads
|
||||
*/
|
||||
private int headSize;
|
||||
|
||||
/**
|
||||
* Project input before applying attention or not.
|
||||
*/
|
||||
private boolean projectInput;
|
||||
|
||||
/**
|
||||
* @param nIn Number of inputs to the layer (input size)
|
||||
*/
|
||||
public Builder nIn(int nIn) {
|
||||
this.nIn = nIn;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param nOut Number of outputs (output size)
|
||||
*/
|
||||
public Builder nOut(int nOut) {
|
||||
this.nOut = nOut;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Number of Attention Heads
|
||||
*/
|
||||
public Builder nHeads(int nHeads){
|
||||
this.nHeads = nHeads;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Size of attention heads
|
||||
*/
|
||||
public Builder headSize(int headSize){
|
||||
this.headSize = headSize;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Project input before applying attention or not.
|
||||
*/
|
||||
public Builder projectInput(boolean projectInput){
|
||||
this.projectInput = projectInput;
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
@SuppressWarnings("unchecked")
|
||||
public SelfAttentionLayer build() {
|
||||
Preconditions.checkArgument(this.projectInput || this.nHeads == 1, "projectInput must be true when nHeads != 1");
|
||||
Preconditions.checkArgument(this.projectInput || nIn == nOut, "nIn must be equal to nOut when projectInput is false");
|
||||
Preconditions.checkArgument(!this.projectInput || nOut != 0, "nOut must be specified when projectInput is true");
|
||||
Preconditions.checkArgument(this.nOut % nHeads == 0 || headSize > 0, "nOut isn't divided by nHeads cleanly. Specify the headSize manually.");
|
||||
return new SelfAttentionLayer(this);
|
||||
}
|
||||
return initBuild();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -63,7 +63,16 @@ public class SeparableConvolution2D extends ConvolutionLayer {
|
|||
* @return Builder
|
||||
*/
|
||||
@Builder.Default private int depthMultiplier = 1;
|
||||
|
||||
/**
|
||||
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
|
||||
* See {@link CNN2DFormat} for more details.<br>
|
||||
* Default: NCHW
|
||||
*
|
||||
* @param format Format for activations (in and out)
|
||||
*/
|
||||
@Builder.Default
|
||||
protected CNN2DFormat dataFormat =
|
||||
CNN2DFormat.NCHW; // default value for legacy serialization reasons
|
||||
public static SeparableConvolution2DBuilder<?, ?> builder() {
|
||||
return innerBuilder();
|
||||
}
|
||||
|
|
|
@ -20,7 +20,10 @@
|
|||
|
||||
package org.deeplearning4j.nn.conf.layers;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.Map;
|
||||
import lombok.*;
|
||||
import lombok.experimental.SuperBuilder;
|
||||
import org.deeplearning4j.nn.api.ParamInitializer;
|
||||
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||
|
@ -35,195 +38,160 @@ import org.nd4j.common.base.Preconditions;
|
|||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.Map;
|
||||
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@ToString(callSuper = true)
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@SuperBuilder(builderMethodName = "innerBuilder")
|
||||
public class SpaceToBatchLayer extends NoParamLayer {
|
||||
|
||||
// TODO: throw error when block and padding dims don't match
|
||||
/**
|
||||
* Block size for SpaceToBatch layer. Should be a length 2 array for the height and width
|
||||
* dimensions
|
||||
*/
|
||||
protected int[] blockSize;
|
||||
/** A 2d array, with format [[padTop, padBottom], [padLeft, padRight]] */
|
||||
@Builder.Default protected int[][] padding = new int[][] {{0, 0}, {0, 0}};
|
||||
/**
|
||||
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
|
||||
* See {@link CNN2DFormat} for more details.<br>
|
||||
* Default: NCHW
|
||||
*
|
||||
* @param format Format for activations (in and out)
|
||||
*/
|
||||
@Builder.Default protected CNN2DFormat dataFormat = CNN2DFormat.NCHW;
|
||||
|
||||
protected int[] blocks;
|
||||
protected int[][] padding;
|
||||
protected CNN2DFormat format = CNN2DFormat.NCHW;
|
||||
public static SpaceToBatchLayerBuilder<?, ?> builder() {
|
||||
return innerBuilder();
|
||||
}
|
||||
// TODO: throw error when block and padding dims don't match
|
||||
|
||||
/**
|
||||
* @param blocks Block size for SpaceToBatch layer. Should be a length 2 array for the height and
|
||||
* width dimensions
|
||||
*/
|
||||
public static SpaceToBatchLayerBuilder<?, ?> builder(int[] blocks) {
|
||||
return innerBuilder().blockSize(blocks);
|
||||
}
|
||||
|
||||
protected SpaceToBatchLayer(Builder builder) {
|
||||
super(builder);
|
||||
this.blocks = builder.blocks;
|
||||
this.padding = builder.padding;
|
||||
this.format = builder.format;
|
||||
/**
|
||||
* @param blocks Block size for SpaceToBatch layer. Should be a length 2 array for the height and
|
||||
* width dimensions
|
||||
* @param padding Padding - should be a 2d array, with format [[padTop, padBottom], [padLeft,
|
||||
* padRight]]
|
||||
*/
|
||||
public static SpaceToBatchLayerBuilder<?, ?> builder(int[] blocks, int[][] padding) {
|
||||
return innerBuilder().blockSize(blocks).padding(padding);
|
||||
}
|
||||
|
||||
@Override
|
||||
public SpaceToBatchLayer clone() {
|
||||
return (SpaceToBatchLayer) super.clone();
|
||||
}
|
||||
|
||||
@Override
|
||||
public org.deeplearning4j.nn.api.Layer instantiate(
|
||||
NeuralNetConfiguration conf,
|
||||
Collection<TrainingListener> trainingListeners,
|
||||
int layerIndex,
|
||||
INDArray layerParamsView,
|
||||
boolean initializeParams,
|
||||
DataType networkDataType) {
|
||||
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
|
||||
|
||||
org.deeplearning4j.nn.layers.convolution.SpaceToBatch ret =
|
||||
new org.deeplearning4j.nn.layers.convolution.SpaceToBatch(lconf, networkDataType);
|
||||
ret.addTrainingListeners(trainingListeners);
|
||||
ret.setIndex(layerIndex);
|
||||
ret.setParamsViewArray(layerParamsView);
|
||||
Map<String, INDArray> paramTable = initializer().init(this, layerParamsView, initializeParams);
|
||||
ret.setParamTable(paramTable);
|
||||
ret.setLayerConfiguration(lconf);
|
||||
return ret;
|
||||
}
|
||||
|
||||
@Override
|
||||
public LayerMemoryReport getMemoryReport(InputType inputType) {
|
||||
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
|
||||
InputType.InputTypeConvolutional outputType =
|
||||
(InputType.InputTypeConvolutional) getOutputType(-1, inputType);
|
||||
|
||||
return new LayerMemoryReport.Builder(name, SpaceToBatchLayer.class, inputType, outputType)
|
||||
.standardMemory(0, 0) // No params
|
||||
.cacheMemory(
|
||||
MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) // No caching
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public InputType getOutputType(int layerIndex, InputType inputType) {
|
||||
if (inputType == null || inputType.getType() != InputType.Type.CNN) {
|
||||
throw new IllegalStateException(
|
||||
"Invalid input for Subsampling layer (layer name=\""
|
||||
+ getName()
|
||||
+ "\"): Expected CNN input, got "
|
||||
+ inputType);
|
||||
}
|
||||
InputType.InputTypeConvolutional i = (InputType.InputTypeConvolutional) inputType;
|
||||
return InputType.convolutional(
|
||||
(i.getHeight() + padding[0][0] + padding[0][1]) / blockSize[0],
|
||||
(i.getWidth() + padding[1][0] + padding[1][1]) / blockSize[1],
|
||||
i.getChannels(),
|
||||
i.getFormat());
|
||||
}
|
||||
|
||||
@Override
|
||||
public ParamInitializer initializer() {
|
||||
return EmptyParamInitializer.getInstance();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setNIn(InputType inputType, boolean override) {
|
||||
Preconditions.checkState(
|
||||
inputType.getType() == InputType.Type.CNN,
|
||||
"Only CNN input types can be used with SpaceToBatchLayer, got %s",
|
||||
inputType);
|
||||
this.dataFormat = ((InputType.InputTypeConvolutional) inputType).getFormat();
|
||||
}
|
||||
|
||||
@Override
|
||||
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
|
||||
if (inputType == null) {
|
||||
throw new IllegalStateException(
|
||||
"Invalid input for space to batch layer (layer name=\""
|
||||
+ getName()
|
||||
+ "\"): input is null");
|
||||
}
|
||||
return InputTypeUtil.getPreProcessorForInputTypeCnnLayers(inputType, getName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isPretrainParam(String paramName) {
|
||||
throw new UnsupportedOperationException("SpaceToBatchLayer does not contain parameters");
|
||||
}
|
||||
|
||||
public abstract static class SpaceToBatchLayerBuilder<
|
||||
C extends SpaceToBatchLayer, B extends SpaceToBatchLayerBuilder<C, B>>
|
||||
extends NoParamLayerBuilder<C, B> {
|
||||
/**
|
||||
* @param blocks Block size for SpaceToBatch layer. Should be a length 2 array for the height
|
||||
* and width dimensions
|
||||
* @return
|
||||
*/
|
||||
public B blockSize(int... blocks) {
|
||||
this.blockSize = ValidationUtils.validate2NonNegative(blocks, false, "blocks");
|
||||
return self();
|
||||
}
|
||||
|
||||
@Override
|
||||
public SpaceToBatchLayer clone() {
|
||||
return (SpaceToBatchLayer) super.clone();
|
||||
/**
|
||||
* @param padding Padding - should be a 2d array, with format [[padTop, padBottom], [padLeft,
|
||||
* padRight]]
|
||||
* @return
|
||||
*/
|
||||
public B padding(int[][] padding) {
|
||||
this.padding$value = ValidationUtils.validate2x2NonNegative(padding, "padding");
|
||||
this.padding$set = true;
|
||||
return self();
|
||||
}
|
||||
|
||||
@Override
|
||||
public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf,
|
||||
Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView,
|
||||
boolean initializeParams, DataType networkDataType) {
|
||||
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
|
||||
|
||||
org.deeplearning4j.nn.layers.convolution.SpaceToBatch ret =
|
||||
new org.deeplearning4j.nn.layers.convolution.SpaceToBatch(lconf, networkDataType);
|
||||
ret.addTrainingListeners(trainingListeners);
|
||||
ret.setIndex(layerIndex);
|
||||
ret.setParamsViewArray(layerParamsView);
|
||||
Map<String, INDArray> paramTable = initializer().init(this, layerParamsView, initializeParams);
|
||||
ret.setParamTable(paramTable);
|
||||
ret.setLayerConfiguration(lconf);
|
||||
return ret;
|
||||
}
|
||||
|
||||
@Override
|
||||
public LayerMemoryReport getMemoryReport(InputType inputType) {
|
||||
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
|
||||
InputType.InputTypeConvolutional outputType = (InputType.InputTypeConvolutional) getOutputType(-1, inputType);
|
||||
|
||||
return new LayerMemoryReport.Builder(name, SpaceToBatchLayer.class, inputType, outputType)
|
||||
.standardMemory(0, 0) //No params
|
||||
.cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) //No caching
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public InputType getOutputType(int layerIndex, InputType inputType) {
|
||||
if (inputType == null || inputType.getType() != InputType.Type.CNN) {
|
||||
throw new IllegalStateException("Invalid input for Subsampling layer (layer name=\"" + getName()
|
||||
+ "\"): Expected CNN input, got " + inputType);
|
||||
}
|
||||
InputType.InputTypeConvolutional i = (InputType.InputTypeConvolutional) inputType;
|
||||
return InputType.convolutional((i.getHeight() + padding[0][0] + padding[0][1]) / blocks[0],
|
||||
(i.getWidth() + padding[1][0] + padding[1][1]) / blocks[1], i.getChannels(), i.getFormat());
|
||||
}
|
||||
|
||||
@Override
|
||||
public ParamInitializer initializer() {
|
||||
return EmptyParamInitializer.getInstance();
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void setNIn(InputType inputType, boolean override) {
|
||||
Preconditions.checkState(inputType.getType() == InputType.Type.CNN, "Only CNN input types can be used with SpaceToBatchLayer, got %s", inputType);
|
||||
this.format = ((InputType.InputTypeConvolutional)inputType).getFormat();
|
||||
}
|
||||
|
||||
@Override
|
||||
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
|
||||
if (inputType == null) {
|
||||
throw new IllegalStateException("Invalid input for space to batch layer (layer name=\"" + getName()
|
||||
+ "\"): input is null");
|
||||
}
|
||||
return InputTypeUtil.getPreProcessorForInputTypeCnnLayers(inputType, getName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isPretrainParam(String paramName) {
|
||||
throw new UnsupportedOperationException("SpaceToBatchLayer does not contain parameters");
|
||||
}
|
||||
|
||||
|
||||
@NoArgsConstructor
|
||||
@Getter
|
||||
@Setter
|
||||
public static class Builder<T extends Builder<T>> extends LayerConfiguration.Builder<T> {
|
||||
|
||||
/**
|
||||
* Block size for SpaceToBatch layer. Should be a length 2 array for the height and width
|
||||
* dimensions
|
||||
*/
|
||||
@Setter(AccessLevel.NONE)
|
||||
protected int[] blocks;
|
||||
|
||||
/**
|
||||
* A 2d array, with format [[padTop, padBottom], [padLeft, padRight]]
|
||||
*/
|
||||
protected int[][] padding;
|
||||
|
||||
protected CNN2DFormat format = CNN2DFormat.NCHW;
|
||||
|
||||
/**
|
||||
* @param blocks Block size for SpaceToBatch layer. Should be a length 2 array for the height and width
|
||||
* dimensions
|
||||
*/
|
||||
public void setBlocks(int... blocks) {
|
||||
this.blocks = ValidationUtils.validate2NonNegative(blocks, false, "blocks");
|
||||
}
|
||||
|
||||
/**
|
||||
* @param padding Padding - should be a 2d array, with format [[padTop, padBottom], [padLeft, padRight]]
|
||||
*/
|
||||
public void setPadding(int[][] padding) {
|
||||
this.padding = ValidationUtils.validate2x2NonNegative(padding, "padding");
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* @param blocks Block size for SpaceToBatch layer. Should be a length 2 array for the height and width
|
||||
* dimensions
|
||||
*/
|
||||
public Builder(int[] blocks) {
|
||||
this.setBlocks(blocks);
|
||||
this.setPadding(new int[][] {{0, 0}, {0, 0}});
|
||||
}
|
||||
|
||||
/**
|
||||
* @param blocks Block size for SpaceToBatch layer. Should be a length 2 array for the height and width
|
||||
* dimensions
|
||||
* @param padding Padding - should be a 2d array, with format [[padTop, padBottom], [padLeft, padRight]]
|
||||
*/
|
||||
public Builder(int[] blocks, int[][] padding) {
|
||||
this.setBlocks(blocks);
|
||||
this.setPadding(padding);
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
|
||||
* See {@link CNN2DFormat} for more details.<br>
|
||||
* Default: NCHW
|
||||
* @param format Format for activations (in and out)
|
||||
*/
|
||||
public T dataFormat(CNN2DFormat format){
|
||||
this.format = format;
|
||||
return (T)this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param blocks Block size for SpaceToBatch layer. Should be a length 2 array for the height and width
|
||||
* dimensions
|
||||
*/
|
||||
public T blocks(int... blocks) {
|
||||
this.setBlocks(blocks);
|
||||
return (T) this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param padding Padding - should be a 2d array, with format [[padTop, padBottom], [padLeft, padRight]]
|
||||
*/
|
||||
public T padding(int[][] padding) {
|
||||
this.setPadding(padding);
|
||||
return (T) this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public T name(String layerName) {
|
||||
this.setLayerName(layerName);
|
||||
return (T) this;
|
||||
}
|
||||
|
||||
@Override
|
||||
@SuppressWarnings("unchecked")
|
||||
public SpaceToBatchLayer build() {
|
||||
if(padding == null)
|
||||
setPadding(new int[][] {{0, 0}, {0, 0}});
|
||||
return new SpaceToBatchLayer(this);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
package org.deeplearning4j.nn.conf.layers;
|
||||
|
||||
import lombok.*;
|
||||
import lombok.experimental.SuperBuilder;
|
||||
import org.deeplearning4j.nn.api.ParamInitializer;
|
||||
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||
|
@ -40,6 +41,7 @@ import java.util.Map;
|
|||
@NoArgsConstructor
|
||||
@ToString(callSuper = true)
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@SuperBuilder
|
||||
public class SpaceToDepthLayer extends NoParamLayer {
|
||||
|
||||
/**
|
||||
|
@ -53,16 +55,20 @@ public class SpaceToDepthLayer extends NoParamLayer {
|
|||
return this == NCHW ? CNN2DFormat.NCHW : CNN2DFormat.NHWC;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @param blockSize Block size
|
||||
*/
|
||||
protected int blockSize;
|
||||
protected CNN2DFormat dataFormat;
|
||||
/**
|
||||
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
|
||||
* See {@link CNN2DFormat} for more details.<br>
|
||||
* Default: NCHW
|
||||
* @param dataFormat Format for activations (in and out)
|
||||
*/
|
||||
@Builder.Default
|
||||
protected CNN2DFormat dataFormat = CNN2DFormat.NCHW;
|
||||
|
||||
|
||||
protected SpaceToDepthLayer(Builder builder) {
|
||||
super(builder);
|
||||
this.setBlockSize(builder.blockSize);
|
||||
this.setDataFormat(builder.dataFormat);
|
||||
}
|
||||
|
||||
@Override
|
||||
public SpaceToDepthLayer clone() {
|
||||
|
@ -74,7 +80,7 @@ public class SpaceToDepthLayer extends NoParamLayer {
|
|||
Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView,
|
||||
boolean initializeParams, DataType networkDataType) {
|
||||
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
|
||||
|
||||
runInheritance();
|
||||
org.deeplearning4j.nn.layers.convolution.SpaceToDepth ret =
|
||||
new org.deeplearning4j.nn.layers.convolution.SpaceToDepth(lconf, networkDataType);
|
||||
ret.addTrainingListeners(trainingListeners);
|
||||
|
@ -133,78 +139,5 @@ public class SpaceToDepthLayer extends NoParamLayer {
|
|||
}
|
||||
|
||||
|
||||
@NoArgsConstructor
|
||||
@Getter
|
||||
@Setter
|
||||
public static class Builder<T extends Builder<T>> extends LayerConfiguration.Builder<T> {
|
||||
|
||||
protected int blockSize;
|
||||
|
||||
/**
|
||||
* Data format for input activations. Note DL4J uses NCHW in most cases
|
||||
*/
|
||||
protected CNN2DFormat dataFormat = CNN2DFormat.NCHW;
|
||||
|
||||
/**
|
||||
* @param blockSize Block size
|
||||
*/
|
||||
public Builder(int blockSize) {
|
||||
this.setBlockSize(blockSize);
|
||||
}
|
||||
|
||||
/**
|
||||
* @param blockSize Block size
|
||||
* @param dataFormat Data format for input activations. Note DL4J uses NCHW in most cases
|
||||
*/
|
||||
@Deprecated
|
||||
public Builder(int blockSize, DataFormat dataFormat) {
|
||||
this(blockSize, dataFormat.toFormat());
|
||||
}
|
||||
|
||||
public Builder(int blockSize, CNN2DFormat dataFormat) {
|
||||
this.setBlockSize(blockSize);
|
||||
this.setDataFormat(dataFormat);
|
||||
}
|
||||
|
||||
/**
|
||||
* @param blockSize Block size
|
||||
*/
|
||||
public T blocks(int blockSize) {
|
||||
this.setBlockSize(blockSize);
|
||||
return (T) this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param dataFormat Data format for input activations. Note DL4J uses NCHW in most cases
|
||||
* @deprecated Use {@link #dataFormat(CNN2DFormat)}
|
||||
*/
|
||||
@Deprecated
|
||||
public T dataFormat(DataFormat dataFormat) {
|
||||
return dataFormat(dataFormat.toFormat());
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
|
||||
* See {@link CNN2DFormat} for more details.<br>
|
||||
* Default: NCHW
|
||||
* @param dataFormat Format for activations (in and out)
|
||||
*/
|
||||
public T dataFormat(CNN2DFormat dataFormat) {
|
||||
this.setDataFormat(dataFormat);
|
||||
return (T) this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public T name(String layerName) {
|
||||
this.setLayerName(layerName);
|
||||
return (T) this;
|
||||
}
|
||||
|
||||
@Override
|
||||
@SuppressWarnings("unchecked")
|
||||
public SpaceToDepthLayer build() {
|
||||
return new SpaceToDepthLayer(this);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -20,10 +20,14 @@
|
|||
|
||||
package org.deeplearning4j.nn.conf.layers.objdetect;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
|
||||
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import lombok.*;
|
||||
import lombok.experimental.SuperBuilder;
|
||||
import org.deeplearning4j.nn.api.Layer;
|
||||
import org.deeplearning4j.nn.api.ParamInitializer;
|
||||
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||
|
@ -41,218 +45,139 @@ import org.nd4j.linalg.learning.regularization.Regularization;
|
|||
import org.nd4j.linalg.lossfunctions.ILossFunction;
|
||||
import org.nd4j.linalg.lossfunctions.impl.LossL2;
|
||||
import org.nd4j.serde.jackson.shaded.NDArrayTextSerializer;
|
||||
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
|
||||
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = false)
|
||||
@SuperBuilder(buildMethodName = "initBuild")
|
||||
public class Yolo2OutputLayer extends LayerConfiguration {
|
||||
|
||||
private double lambdaCoord;
|
||||
private double lambdaNoObj;
|
||||
private ILossFunction lossPositionScale;
|
||||
private ILossFunction lossClassPredictions;
|
||||
@JsonSerialize(using = NDArrayTextSerializer.class)
|
||||
@JsonDeserialize(using = BoundingBoxesDeserializer.class)
|
||||
private INDArray boundingBoxes;
|
||||
/**
|
||||
* Loss function coefficient for position and size/scale components of the loss function. Default
|
||||
* (as per paper): 5
|
||||
*/
|
||||
@Builder.Default private double lambdaCoord = 5;
|
||||
/**
|
||||
* Loss function coefficient for the "no object confidence" components of the loss function.
|
||||
* Default (as per paper): 0.5
|
||||
*/
|
||||
@Builder.Default private double lambdaNoObj = 0.5;
|
||||
/** Loss function for position/scale component of the loss function */
|
||||
@Builder.Default private ILossFunction lossPositionScale = new LossL2();
|
||||
/**
|
||||
* Loss function for the class predictions - defaults to L2 loss (i.e., sum of squared errors, as
|
||||
* per the paper), however Loss MCXENT could also be used (which is more common for
|
||||
* classification).
|
||||
*/
|
||||
@Builder.Default private ILossFunction lossClassPredictions = new LossL2();
|
||||
;
|
||||
/**
|
||||
* Bounding box priors dimensions [width, height]. For N bounding boxes, input has shape [rows,
|
||||
* columns] = [N, 2] Note that dimensions should be specified as fraction of grid size. For
|
||||
* example, a network with 13x13 output, a value of 1.0 would correspond to one grid cell; a value
|
||||
* of 13 would correspond to the entire image.
|
||||
*/
|
||||
@JsonSerialize(using = NDArrayTextSerializer.class)
|
||||
@JsonDeserialize(using = BoundingBoxesDeserializer.class)
|
||||
@Builder.Default
|
||||
private INDArray boundingBoxes;
|
||||
|
||||
private CNN2DFormat format = CNN2DFormat.NCHW; //Default for serialization of old formats
|
||||
@Builder.Default
|
||||
private CNN2DFormat format = CNN2DFormat.NCHW; // Default for serialization of old formats
|
||||
|
||||
private Yolo2OutputLayer() {
|
||||
//No-arg constructor for Jackson JSON
|
||||
}
|
||||
private Yolo2OutputLayer() {
|
||||
// No-arg constructor for Jackson JSON
|
||||
}
|
||||
|
||||
private Yolo2OutputLayer(Builder builder) {
|
||||
super(builder);
|
||||
this.lambdaCoord = builder.lambdaCoord;
|
||||
this.lambdaNoObj = builder.lambdaNoObj;
|
||||
this.lossPositionScale = builder.lossPositionScale;
|
||||
this.lossClassPredictions = builder.lossClassPredictions;
|
||||
this.boundingBoxes = builder.boundingBoxes;
|
||||
}
|
||||
@Override
|
||||
public Layer instantiate(
|
||||
NeuralNetConfiguration conf,
|
||||
Collection<TrainingListener> trainingListeners,
|
||||
int layerIndex,
|
||||
INDArray layerParamsView,
|
||||
boolean initializeParams,
|
||||
DataType networkDataType) {
|
||||
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
|
||||
|
||||
@Override
|
||||
public Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> trainingListeners,
|
||||
int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
|
||||
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
|
||||
org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer ret =
|
||||
new org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer(lconf, networkDataType);
|
||||
ret.addTrainingListeners(trainingListeners);
|
||||
ret.setIndex(layerIndex);
|
||||
ret.setParamsViewArray(layerParamsView);
|
||||
Map<String, INDArray> paramTable = initializer().init(this, layerParamsView, initializeParams);
|
||||
ret.setParamTable(paramTable);
|
||||
ret.setLayerConfiguration(lconf);
|
||||
return ret;
|
||||
}
|
||||
|
||||
org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer ret =
|
||||
new org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer(lconf, networkDataType);
|
||||
ret.addTrainingListeners(trainingListeners);
|
||||
ret.setIndex(layerIndex);
|
||||
ret.setParamsViewArray(layerParamsView);
|
||||
Map<String, INDArray> paramTable = initializer().init(this, layerParamsView, initializeParams);
|
||||
ret.setParamTable(paramTable);
|
||||
ret.setLayerConfiguration(lconf);
|
||||
return ret;
|
||||
}
|
||||
@Override
|
||||
public ParamInitializer initializer() {
|
||||
return EmptyParamInitializer.getInstance();
|
||||
}
|
||||
|
||||
@Override
|
||||
public ParamInitializer initializer() {
|
||||
return EmptyParamInitializer.getInstance();
|
||||
}
|
||||
@Override
|
||||
public InputType getOutputType(int layerIndex, InputType inputType) {
|
||||
return inputType; // Same shape output as input
|
||||
}
|
||||
|
||||
@Override
|
||||
public InputType getOutputType(int layerIndex, InputType inputType) {
|
||||
return inputType; //Same shape output as input
|
||||
}
|
||||
@Override
|
||||
public void setNIn(InputType inputType, boolean override) {
|
||||
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
|
||||
this.format = c.getFormat();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setNIn(InputType inputType, boolean override) {
|
||||
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
|
||||
this.format = c.getFormat();
|
||||
}
|
||||
|
||||
@Override
|
||||
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
|
||||
switch (inputType.getType()) {
|
||||
case FF:
|
||||
case RNN:
|
||||
throw new UnsupportedOperationException("Cannot use FF or RNN input types");
|
||||
case CNN:
|
||||
return null;
|
||||
case CNNFlat:
|
||||
InputType.InputTypeConvolutionalFlat cf = (InputType.InputTypeConvolutionalFlat) inputType;
|
||||
return new FeedForwardToCnnPreProcessor(cf.getHeight(), cf.getWidth(), cf.getDepth());
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Regularization> getRegularizationByParam(String paramName) {
|
||||
//Not applicable
|
||||
@Override
|
||||
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
|
||||
switch (inputType.getType()) {
|
||||
case FF:
|
||||
case RNN:
|
||||
throw new UnsupportedOperationException("Cannot use FF or RNN input types");
|
||||
case CNN:
|
||||
return null;
|
||||
case CNNFlat:
|
||||
InputType.InputTypeConvolutionalFlat cf = (InputType.InputTypeConvolutionalFlat) inputType;
|
||||
return new FeedForwardToCnnPreProcessor(cf.getHeight(), cf.getWidth(), cf.getDepth());
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isPretrainParam(String paramName) {
|
||||
return false; //No params
|
||||
}
|
||||
@Override
|
||||
public LayerMemoryReport getMemoryReport(InputType inputType) {
|
||||
long numValues = inputType.arrayElementsPerExample();
|
||||
|
||||
//This is a VERY rough estimate...
|
||||
return new LayerMemoryReport.Builder(name, Yolo2OutputLayer.class, inputType, inputType)
|
||||
.standardMemory(0, 0) //No params
|
||||
.workingMemory(0, numValues, 0, 6 * numValues).cacheMemory(0, 0) //No cache
|
||||
.build();
|
||||
}
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
public static class Builder extends LayerConfiguration.Builder<Builder> {
|
||||
|
||||
/**
|
||||
* Loss function coefficient for position and size/scale components of the loss function. Default (as per
|
||||
* paper): 5
|
||||
*
|
||||
*/
|
||||
private double lambdaCoord = 5;
|
||||
|
||||
/**
|
||||
* Loss function coefficient for the "no object confidence" components of the loss function. Default (as per
|
||||
* paper): 0.5
|
||||
*
|
||||
*/
|
||||
private double lambdaNoObj = 0.5;
|
||||
|
||||
/**
|
||||
* Loss function for position/scale component of the loss function
|
||||
*
|
||||
*/
|
||||
private ILossFunction lossPositionScale = new LossL2();
|
||||
|
||||
/**
|
||||
* Loss function for the class predictions - defaults to L2 loss (i.e., sum of squared errors, as per the
|
||||
* paper), however Loss MCXENT could also be used (which is more common for classification).
|
||||
*
|
||||
*/
|
||||
private ILossFunction lossClassPredictions = new LossL2();
|
||||
|
||||
/**
|
||||
* Bounding box priors dimensions [width, height]. For N bounding boxes, input has shape [rows, columns] = [N,
|
||||
* 2] Note that dimensions should be specified as fraction of grid size. For example, a network with 13x13
|
||||
* output, a value of 1.0 would correspond to one grid cell; a value of 13 would correspond to the entire
|
||||
* image.
|
||||
*
|
||||
*/
|
||||
private INDArray boundingBoxes;
|
||||
|
||||
/**
|
||||
* Loss function coefficient for position and size/scale components of the loss function. Default (as per
|
||||
* paper): 5
|
||||
*
|
||||
* @param lambdaCoord Lambda value for size/scale component of loss function
|
||||
*/
|
||||
public Builder lambdaCoord(double lambdaCoord) {
|
||||
this.setLambdaCoord(lambdaCoord);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Loss function coefficient for the "no object confidence" components of the loss function. Default (as per
|
||||
* paper): 0.5
|
||||
*
|
||||
* @param lambdaNoObj Lambda value for no-object (confidence) component of the loss function
|
||||
*/
|
||||
public Builder lambdaNoObj(double lambdaNoObj) {
|
||||
this.setLambdaNoObj(lambdaNoObj);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Loss function for position/scale component of the loss function
|
||||
*
|
||||
* @param lossPositionScale Loss function for position/scale
|
||||
*/
|
||||
public Builder lossPositionScale(ILossFunction lossPositionScale) {
|
||||
this.setLossPositionScale(lossPositionScale);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Loss function for the class predictions - defaults to L2 loss (i.e., sum of squared errors, as per the
|
||||
* paper), however Loss MCXENT could also be used (which is more common for classification).
|
||||
*
|
||||
* @param lossClassPredictions Loss function for the class prediction error component of the YOLO loss function
|
||||
*/
|
||||
public Builder lossClassPredictions(ILossFunction lossClassPredictions) {
|
||||
this.setLossClassPredictions(lossClassPredictions);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Bounding box priors dimensions [width, height]. For N bounding boxes, input has shape [rows, columns] = [N,
|
||||
* 2] Note that dimensions should be specified as fraction of grid size. For example, a network with 13x13
|
||||
* output, a value of 1.0 would correspond to one grid cell; a value of 13 would correspond to the entire
|
||||
* image.
|
||||
*
|
||||
* @param boundingBoxes Bounding box prior dimensions (width, height)
|
||||
*/
|
||||
public Builder boundingBoxPriors(INDArray boundingBoxes) {
|
||||
this.setBoundingBoxes(boundingBoxes);
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Yolo2OutputLayer build() {
|
||||
if (boundingBoxes == null) {
|
||||
throw new IllegalStateException("Bounding boxes have not been set");
|
||||
}
|
||||
|
||||
if (boundingBoxes.rank() != 2 || boundingBoxes.size(1) != 2) {
|
||||
throw new IllegalStateException("Bounding box priors must have shape [nBoxes, 2]. Has shape: "
|
||||
+ Arrays.toString(boundingBoxes.shape()));
|
||||
}
|
||||
|
||||
return new Yolo2OutputLayer(this);
|
||||
}
|
||||
@Override
|
||||
public List<Regularization> getRegularizationByParam(String paramName) {
|
||||
// Not applicable
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isPretrainParam(String paramName) {
|
||||
return false; // No params
|
||||
}
|
||||
|
||||
@Override
|
||||
public LayerMemoryReport getMemoryReport(InputType inputType) {
|
||||
long numValues = inputType.arrayElementsPerExample();
|
||||
|
||||
// This is a VERY rough estimate...
|
||||
return new LayerMemoryReport.Builder(name, Yolo2OutputLayer.class, inputType, inputType)
|
||||
.standardMemory(0, 0) // No params
|
||||
.workingMemory(0, numValues, 0, 6 * numValues)
|
||||
.cacheMemory(0, 0) // No cache
|
||||
.build();
|
||||
}
|
||||
|
||||
public static abstract class Yolo2OutputLayerBuilder<
|
||||
C extends Yolo2OutputLayer, B extends Yolo2OutputLayerBuilder<C, B>>
|
||||
extends LayerConfigurationBuilder<C, B> {
|
||||
public C build() {
|
||||
if (boundingBoxes$value == null) {
|
||||
throw new IllegalStateException("Bounding boxes have not been set");
|
||||
}
|
||||
|
||||
if (boundingBoxes$value.rank() != 2 || boundingBoxes$value.size(1) != 2) {
|
||||
throw new IllegalStateException(
|
||||
"Bounding box priors must have shape [nBoxes, 2]. Has shape: "
|
||||
+ Arrays.toString(boundingBoxes$value.shape()));
|
||||
}
|
||||
return initBuild();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -48,7 +48,7 @@ public class SimpleRnn extends BaseRecurrentLayer {
|
|||
* If true (default = false): enable layer normalization on this layer
|
||||
*
|
||||
*/
|
||||
@lombok.Builder.Default @Accessors
|
||||
@lombok.Builder.Default @Accessors @Getter
|
||||
private boolean hasLayerNorm = false;
|
||||
|
||||
|
||||
|
|
|
@ -20,6 +20,9 @@
|
|||
|
||||
package org.deeplearning4j.nn.conf.layers.samediff;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import lombok.*;
|
||||
import lombok.experimental.SuperBuilder;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
@ -44,11 +47,6 @@ import org.nd4j.linalg.learning.regularization.L2Regularization;
|
|||
import org.nd4j.linalg.learning.regularization.Regularization;
|
||||
import org.nd4j.linalg.learning.regularization.WeightDecay;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@Slf4j
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = true, doNotUseGetters = true)
|
||||
|
@ -56,291 +54,324 @@ import java.util.Map;
|
|||
@NoArgsConstructor
|
||||
public abstract class AbstractSameDiffLayer extends LayerConfiguration {
|
||||
|
||||
/**
|
||||
* The regularization for the parameters (excluding biases) - for example {@link WeightDecay}
|
||||
*
|
||||
* <p>-- SETTER -- Set the regularization for the parameters (excluding biases) - for example
|
||||
* {@link WeightDecay}
|
||||
*
|
||||
* @param regularization Regularization to apply for the network parameters/weights (excluding
|
||||
* biases)
|
||||
*/
|
||||
protected List<Regularization> regularization;
|
||||
/**
|
||||
* The regularization for the biases only - for example {@link WeightDecay} -- SETTER -- Set the
|
||||
* regularization for the biases only - for example {@link WeightDecay}
|
||||
*
|
||||
* @param regularizationBias Regularization to apply for the network biases only
|
||||
*/
|
||||
protected List<Regularization> regularizationBias;
|
||||
/**
|
||||
* Gradient updater. For example, {@link org.nd4j.linalg.learning.config.Adam} or {@link
|
||||
* org.nd4j.linalg.learning.config.Nesterovs}
|
||||
*
|
||||
* @param updater Updater to use
|
||||
*/
|
||||
protected @Getter @Setter IUpdater updater;
|
||||
/**
|
||||
* Gradient updater configuration, for the biases only. If not set, biases will use the updater as
|
||||
* set by {@link #setUpdater(IUpdater)}
|
||||
*
|
||||
* @param biasUpdater Updater to use for bias parameters
|
||||
*/
|
||||
protected @Getter @Setter IUpdater biasUpdater;
|
||||
|
||||
protected GradientNormalization gradientNormalization;
|
||||
protected double gradientNormalizationThreshold = Double.NaN;
|
||||
|
||||
private SDLayerParams layerParams;
|
||||
|
||||
@Override
|
||||
public List<Regularization> getRegularizationByParam(String paramName) {
|
||||
if (layerParams.isWeightParam(paramName)) {
|
||||
return regularization;
|
||||
} else if (layerParams.isBiasParam(paramName)) {
|
||||
return regularizationBias;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
public SDLayerParams getLayerParams() {
|
||||
if (layerParams == null) {
|
||||
layerParams = new SDLayerParams();
|
||||
defineParameters(layerParams);
|
||||
}
|
||||
return layerParams;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setNIn(InputType inputType, boolean override) {
|
||||
// Default implementation: no-op
|
||||
}
|
||||
|
||||
@Override
|
||||
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
|
||||
// Default implementation: no-op
|
||||
return null;
|
||||
}
|
||||
|
||||
public void applyGlobalConfigToLayer(
|
||||
NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) {
|
||||
// Default implementation: no op
|
||||
}
|
||||
|
||||
/**
|
||||
* Define the parameters for the network. Use {@link SDLayerParams#addWeightParam(String,
|
||||
* long...)} and {@link SDLayerParams#addBiasParam(String, long...)}
|
||||
*
|
||||
* @param params Object used to set parameters for this layer
|
||||
*/
|
||||
public abstract void defineParameters(SDLayerParams params);
|
||||
|
||||
/**
|
||||
* Set the initial parameter values for this layer, if required
|
||||
*
|
||||
* @param params Parameter arrays that may be initialized
|
||||
*/
|
||||
public abstract void initializeParameters(Map<String, INDArray> params);
|
||||
|
||||
@Override
|
||||
public abstract org.deeplearning4j.nn.api.Layer instantiate(
|
||||
NeuralNetConfiguration conf,
|
||||
Collection<TrainingListener> trainingListeners,
|
||||
int layerIndex,
|
||||
INDArray layerParamsView,
|
||||
boolean initializeParams,
|
||||
DataType networkDataType);
|
||||
|
||||
// ==================================================================================================================
|
||||
|
||||
@Override
|
||||
public ParamInitializer initializer() {
|
||||
return SameDiffParamInitializer.getInstance();
|
||||
}
|
||||
|
||||
@Override
|
||||
public IUpdater getUpdaterByParam(String paramName) {
|
||||
if (biasUpdater != null && initializer().isBiasParam(this, paramName)) {
|
||||
return biasUpdater;
|
||||
} else if (initializer().isBiasParam(this, paramName)
|
||||
|| initializer().isWeightParam(this, paramName)) {
|
||||
return updater;
|
||||
}
|
||||
throw new IllegalStateException("Unknown parameter key: " + paramName);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isPretrainParam(String paramName) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public LayerMemoryReport getMemoryReport(InputType inputType) {
|
||||
return new LayerMemoryReport(); // TODO
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the memory layout ('c' or 'f' order - i.e., row/column major) of the parameters. In
|
||||
* most cases, this can/should be left
|
||||
*
|
||||
* @param param Name of the parameter
|
||||
* @return Memory layout ('c' or 'f') of the parameter
|
||||
*/
|
||||
public char paramReshapeOrder(String param) {
|
||||
return 'c';
|
||||
}
|
||||
|
||||
protected void initWeights(int fanIn, int fanOut, WeightInit weightInit, INDArray array) {
|
||||
WeightInitUtil.initWeights(
|
||||
fanIn, fanOut, array.shape(), weightInit, null, paramReshapeOrder(null), array);
|
||||
}
|
||||
|
||||
public void applyGlobalConfig(NeuralNetConfiguration.NeuralNetConfigurationBuilder b) {
|
||||
NeuralNetConfiguration bConf = b.build();
|
||||
if (regularization == null || regularization.isEmpty()) {
|
||||
regularization = bConf.getRegularization();
|
||||
}
|
||||
if (regularizationBias == null || regularizationBias.isEmpty()) {
|
||||
regularizationBias = bConf.getRegularizationBias();
|
||||
}
|
||||
if (updater == null) {
|
||||
updater = bConf.getUpdater();
|
||||
}
|
||||
if (biasUpdater == null) {
|
||||
biasUpdater = bConf.getBiasUpdater();
|
||||
}
|
||||
if (gradientNormalization == null) {
|
||||
gradientNormalization = bConf.getGradientNormalization();
|
||||
}
|
||||
if (Double.isNaN(gradientNormalizationThreshold)) {
|
||||
gradientNormalizationThreshold = bConf.getGradientNormalizationThreshold();
|
||||
}
|
||||
|
||||
applyGlobalConfigToLayer(b);
|
||||
}
|
||||
|
||||
/**
|
||||
* This method generates an "all ones" mask array for use in the SameDiff model when none is
|
||||
* provided.
|
||||
*
|
||||
* @param input Input to the layer
|
||||
* @return A mask array - should be same datatype as the input (usually)
|
||||
*/
|
||||
public INDArray onesMaskForInput(INDArray input) {
|
||||
if (input.rank() == 2) {
|
||||
return Nd4j.ones(input.dataType(), input.size(0), 1);
|
||||
} else if (input.rank() == 3) {
|
||||
return Nd4j.ones(
|
||||
input.dataType(),
|
||||
input.size(0),
|
||||
input.size(2)); // mask: [mb, length] vs. input [mb, nIn, length]
|
||||
} else if (input.rank() == 4) {
|
||||
// CNN style - return [mb, 1, 1, 1] for broadcast...
|
||||
return Nd4j.ones(input.dataType(), input.size(0), 1, 1, 1);
|
||||
} else if (input.rank() == 5) {
|
||||
// CNN3D style - return [mb, 1, 1, 1, 1] for broadcast...
|
||||
return Nd4j.ones(input.dataType(), input.size(0), 1, 1, 1, 1);
|
||||
} else {
|
||||
throw new IllegalStateException(
|
||||
"When using masking with rank 1 or 6+ inputs, the onesMaskForInput method must be implemented, "
|
||||
+ "in order to determine the correct mask shape for this layer");
|
||||
}
|
||||
}
|
||||
|
||||
public abstract static class AbstractSameDiffLayerBuilder<
|
||||
C extends AbstractSameDiffLayer, B extends AbstractSameDiffLayerBuilder<C, B>>
|
||||
extends LayerConfigurationBuilder<C, B> {
|
||||
/**
|
||||
* The regularization for the parameters (excluding biases) - for example {@link WeightDecay}
|
||||
* L1 regularization coefficient (weights only). Use {@link #l1Bias(double)} to configure the l1
|
||||
* regularization coefficient for the bias.
|
||||
*/
|
||||
public B l1(double l1) {
|
||||
// Check if existing L1 exists; if so, replace it
|
||||
NetworkUtils.removeInstances(this.regularization, L1Regularization.class);
|
||||
if (l1 > 0.0) {
|
||||
this.regularization.add(new L1Regularization(l1));
|
||||
}
|
||||
return self();
|
||||
}
|
||||
|
||||
/**
|
||||
* L2 regularization coefficient (weights only). Use {@link #l2Bias(double)} to configure the l2
|
||||
* regularization coefficient for the bias.<br>
|
||||
* <b>Note</b>: Generally, {@link WeightDecay} (set via {@link #weightDecay(double,boolean)}
|
||||
* should be preferred to L2 regularization. See {@link WeightDecay} javadoc for further
|
||||
* details.<br>
|
||||
*/
|
||||
public B l2(double l2) {
|
||||
// Check if existing L2 exists; if so, replace it. Also remove weight decay - it doesn't make
|
||||
// sense to use both
|
||||
NetworkUtils.removeInstances(this.regularization, L2Regularization.class);
|
||||
if (l2 > 0.0) {
|
||||
NetworkUtils.removeInstancesWithWarning(
|
||||
this.regularization,
|
||||
WeightDecay.class,
|
||||
"WeightDecay regularization removed: incompatible with added L2 regularization");
|
||||
this.regularization.add(new L2Regularization(l2));
|
||||
}
|
||||
return self();
|
||||
}
|
||||
|
||||
/** L1 regularization coefficient for the bias. Default: 0. See also {@link #l1(double)} */
|
||||
public B l1Bias(double l1Bias) {
|
||||
NetworkUtils.removeInstances(this.regularizationBias, L1Regularization.class);
|
||||
if (l1Bias > 0.0) {
|
||||
this.regularizationBias.add(new L1Regularization(l1Bias));
|
||||
}
|
||||
return self();
|
||||
}
|
||||
|
||||
/**
|
||||
* L2 regularization coefficient for the bias. Default: 0. See also {@link #l2(double)}<br>
|
||||
* <b>Note</b>: Generally, {@link WeightDecay} (set via {@link #weightDecayBias(double,boolean)}
|
||||
* should be preferred to L2 regularization. See {@link WeightDecay} javadoc for further
|
||||
* details.<br>
|
||||
*/
|
||||
public B l2Bias(double l2Bias) {
|
||||
NetworkUtils.removeInstances(this.regularizationBias, L2Regularization.class);
|
||||
if (l2Bias > 0.0) {
|
||||
NetworkUtils.removeInstancesWithWarning(
|
||||
this.regularizationBias,
|
||||
WeightDecay.class,
|
||||
"WeightDecay bias regularization removed: incompatible with added L2 regularization");
|
||||
this.regularizationBias.add(new L2Regularization(l2Bias));
|
||||
}
|
||||
return self();
|
||||
}
|
||||
|
||||
/**
|
||||
* Add weight decay regularization for the network parameters (excluding biases).<br>
|
||||
* This applies weight decay <i>with</i> multiplying the learning rate - see {@link WeightDecay}
|
||||
* for more details.<br>
|
||||
*
|
||||
* -- SETTER --
|
||||
* Set the regularization for the parameters (excluding biases) - for example {@link WeightDecay}
|
||||
* @param regularization Regularization to apply for the network parameters/weights (excluding biases)
|
||||
* @param coefficient Weight decay regularization coefficient
|
||||
* @see #weightDecay(double, boolean)
|
||||
*/
|
||||
protected List<Regularization> regularization;
|
||||
public B weightDecay(double coefficient) {
|
||||
return weightDecay(coefficient, true);
|
||||
}
|
||||
|
||||
/**
|
||||
* The regularization for the biases only - for example {@link WeightDecay}
|
||||
* -- SETTER --
|
||||
* Set the regularization for the biases only - for example {@link WeightDecay}
|
||||
* @param regularizationBias Regularization to apply for the network biases only
|
||||
*/
|
||||
protected List<Regularization> regularizationBias;
|
||||
/**
|
||||
* Gradient updater. For example, {@link org.nd4j.linalg.learning.config.Adam} or {@link
|
||||
* org.nd4j.linalg.learning.config.Nesterovs}
|
||||
* Add weight decay regularization for the network parameters (excluding biases). See {@link
|
||||
* WeightDecay} for more details.<br>
|
||||
*
|
||||
* @param updater Updater to use
|
||||
* @param coefficient Weight decay regularization coefficient
|
||||
* @param applyLR Whether the learning rate should be multiplied in when performing weight decay
|
||||
* updates. See {@link WeightDecay} for more details.
|
||||
* @see #weightDecay(double, boolean)
|
||||
*/
|
||||
protected @Getter @Setter IUpdater updater;
|
||||
public B weightDecay(double coefficient, boolean applyLR) {
|
||||
// Check if existing weight decay if it exists; if so, replace it. Also remove L2 - it doesn't
|
||||
// make sense to use both
|
||||
NetworkUtils.removeInstances(this.regularization, WeightDecay.class);
|
||||
if (coefficient > 0.0) {
|
||||
NetworkUtils.removeInstancesWithWarning(
|
||||
this.regularization,
|
||||
L2Regularization.class,
|
||||
"L2 regularization removed: incompatible with added WeightDecay regularization");
|
||||
this.regularization.add(new WeightDecay(coefficient, applyLR));
|
||||
}
|
||||
return self();
|
||||
}
|
||||
|
||||
/**
|
||||
* Gradient updater configuration, for the biases only. If not set, biases will use the updater as set by {@link
|
||||
* #setUpdater(IUpdater)}
|
||||
* Weight decay for the biases only - see {@link #weightDecay(double)} for more details. This
|
||||
* applies weight decay <i>with</i> multiplying the learning rate.<br>
|
||||
*
|
||||
* @param biasUpdater Updater to use for bias parameters
|
||||
* @param coefficient Weight decay regularization coefficient
|
||||
* @see #weightDecayBias(double, boolean)
|
||||
*/
|
||||
protected @Getter @Setter IUpdater biasUpdater;
|
||||
|
||||
|
||||
protected GradientNormalization gradientNormalization;
|
||||
protected double gradientNormalizationThreshold = Double.NaN;
|
||||
|
||||
private SDLayerParams layerParams;
|
||||
|
||||
@Override
|
||||
public List<Regularization> getRegularizationByParam(String paramName) {
|
||||
if(layerParams.isWeightParam(paramName)){
|
||||
return regularization;
|
||||
} else if(layerParams.isBiasParam(paramName)){
|
||||
return regularizationBias;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
public SDLayerParams getLayerParams() {
|
||||
if (layerParams == null) {
|
||||
layerParams = new SDLayerParams();
|
||||
defineParameters(layerParams);
|
||||
}
|
||||
return layerParams;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setNIn(InputType inputType, boolean override) {
|
||||
//Default implementation: no-op
|
||||
}
|
||||
|
||||
@Override
|
||||
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
|
||||
//Default implementation: no-op
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
public void applyGlobalConfigToLayer(NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) {
|
||||
//Default implementation: no op
|
||||
public B weightDecayBias(double coefficient) {
|
||||
return weightDecayBias(coefficient, true);
|
||||
}
|
||||
|
||||
/**
|
||||
* Define the parameters for the network. Use {@link SDLayerParams#addWeightParam(String, long...)} and {@link
|
||||
* SDLayerParams#addBiasParam(String, long...)}
|
||||
* Weight decay for the biases only - see {@link #weightDecay(double)} for more details<br>
|
||||
*
|
||||
* @param params Object used to set parameters for this layer
|
||||
* @param coefficient Weight decay regularization coefficient
|
||||
*/
|
||||
public abstract void defineParameters(SDLayerParams params);
|
||||
|
||||
/**
|
||||
* Set the initial parameter values for this layer, if required
|
||||
*
|
||||
* @param params Parameter arrays that may be initialized
|
||||
*/
|
||||
public abstract void initializeParameters(Map<String, INDArray> params);
|
||||
|
||||
@Override
|
||||
public abstract org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf,
|
||||
Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView,
|
||||
boolean initializeParams, DataType networkDataType);
|
||||
|
||||
//==================================================================================================================
|
||||
|
||||
@Override
|
||||
public ParamInitializer initializer() {
|
||||
return SameDiffParamInitializer.getInstance();
|
||||
}
|
||||
|
||||
@Override
|
||||
public IUpdater getUpdaterByParam(String paramName) {
|
||||
if (biasUpdater != null && initializer().isBiasParam(this, paramName)) {
|
||||
return biasUpdater;
|
||||
} else if (initializer().isBiasParam(this, paramName) || initializer().isWeightParam(this, paramName)) {
|
||||
return updater;
|
||||
}
|
||||
throw new IllegalStateException("Unknown parameter key: " + paramName);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isPretrainParam(String paramName) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public LayerMemoryReport getMemoryReport(InputType inputType) {
|
||||
return new LayerMemoryReport(); //TODO
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the memory layout ('c' or 'f' order - i.e., row/column major) of the parameters. In most cases, this
|
||||
* can/should be left
|
||||
*
|
||||
* @param param Name of the parameter
|
||||
* @return Memory layout ('c' or 'f') of the parameter
|
||||
*/
|
||||
public char paramReshapeOrder(String param) {
|
||||
return 'c';
|
||||
}
|
||||
|
||||
protected void initWeights(int fanIn, int fanOut, WeightInit weightInit, INDArray array) {
|
||||
WeightInitUtil.initWeights(fanIn, fanOut, array.shape(), weightInit, null, paramReshapeOrder(null), array);
|
||||
}
|
||||
|
||||
public void applyGlobalConfig(NeuralNetConfiguration.NeuralNetConfigurationBuilder b) {
|
||||
NeuralNetConfiguration bConf = b.build();
|
||||
if (regularization == null || regularization.isEmpty()) {
|
||||
regularization = bConf.getRegularization();
|
||||
}
|
||||
if (regularizationBias == null || regularizationBias.isEmpty()) {
|
||||
regularizationBias = bConf.getRegularizationBias();
|
||||
}
|
||||
if (updater == null) {
|
||||
updater = bConf.getUpdater();
|
||||
}
|
||||
if (biasUpdater == null) {
|
||||
biasUpdater = bConf.getBiasUpdater();
|
||||
}
|
||||
if (gradientNormalization == null) {
|
||||
gradientNormalization = bConf.getGradientNormalization();
|
||||
}
|
||||
if (Double.isNaN(gradientNormalizationThreshold)) {
|
||||
gradientNormalizationThreshold = bConf.getGradientNormalizationThreshold();
|
||||
}
|
||||
|
||||
applyGlobalConfigToLayer(b);
|
||||
}
|
||||
|
||||
/**
|
||||
* This method generates an "all ones" mask array for use in the SameDiff model when none is provided.
|
||||
* @param input Input to the layer
|
||||
* @return A mask array - should be same datatype as the input (usually)
|
||||
*/
|
||||
public INDArray onesMaskForInput(INDArray input){
|
||||
if(input.rank() == 2){
|
||||
return Nd4j.ones(input.dataType(), input.size(0), 1);
|
||||
} else if(input.rank() == 3){
|
||||
return Nd4j.ones(input.dataType(), input.size(0), input.size(2)); //mask: [mb, length] vs. input [mb, nIn, length]
|
||||
} else if(input.rank() == 4){
|
||||
//CNN style - return [mb, 1, 1, 1] for broadcast...
|
||||
return Nd4j.ones(input.dataType(), input.size(0), 1, 1, 1);
|
||||
} else if(input.rank() == 5){
|
||||
//CNN3D style - return [mb, 1, 1, 1, 1] for broadcast...
|
||||
return Nd4j.ones(input.dataType(), input.size(0), 1, 1, 1, 1);
|
||||
} else {
|
||||
throw new IllegalStateException("When using masking with rank 1 or 6+ inputs, the onesMaskForInput method must be implemented, " +
|
||||
"in order to determine the correct mask shape for this layer");
|
||||
}
|
||||
}
|
||||
|
||||
public static abstract class AbstractSameDiffLayerBuilder<C extends AbstractSameDiffLayer, B extends AbstractSameDiffLayerBuilder<C, B>> {
|
||||
/**
|
||||
* L1 regularization coefficient (weights only). Use {@link #l1Bias(double)} to configure the l1 regularization
|
||||
* coefficient for the bias.
|
||||
*/
|
||||
public B l1(double l1) {
|
||||
//Check if existing L1 exists; if so, replace it
|
||||
NetworkUtils.removeInstances(this.regularization, L1Regularization.class);
|
||||
if(l1 > 0.0) {
|
||||
this.regularization.add(new L1Regularization(l1));
|
||||
}
|
||||
return self();
|
||||
}
|
||||
|
||||
/**
|
||||
* L2 regularization coefficient (weights only). Use {@link #l2Bias(double)} to configure the l2 regularization
|
||||
* coefficient for the bias.<br>
|
||||
* <b>Note</b>: Generally, {@link WeightDecay} (set via {@link #weightDecay(double,boolean)} should be preferred to
|
||||
* L2 regularization. See {@link WeightDecay} javadoc for further details.<br>
|
||||
*/
|
||||
public B l2(double l2) {
|
||||
//Check if existing L2 exists; if so, replace it. Also remove weight decay - it doesn't make sense to use both
|
||||
NetworkUtils.removeInstances(this.regularization, L2Regularization.class);
|
||||
if(l2 > 0.0) {
|
||||
NetworkUtils.removeInstancesWithWarning(this.regularization, WeightDecay.class, "WeightDecay regularization removed: incompatible with added L2 regularization");
|
||||
this.regularization.add(new L2Regularization(l2));
|
||||
}
|
||||
return self();
|
||||
}
|
||||
|
||||
/**
|
||||
* L1 regularization coefficient for the bias. Default: 0. See also {@link #l1(double)}
|
||||
*/
|
||||
public B l1Bias(double l1Bias) {
|
||||
NetworkUtils.removeInstances(this.regularizationBias, L1Regularization.class);
|
||||
if(l1Bias > 0.0) {
|
||||
this.regularizationBias.add(new L1Regularization(l1Bias));
|
||||
}
|
||||
return self();
|
||||
}
|
||||
|
||||
/**
|
||||
* L2 regularization coefficient for the bias. Default: 0. See also {@link #l2(double)}<br>
|
||||
* <b>Note</b>: Generally, {@link WeightDecay} (set via {@link #weightDecayBias(double,boolean)} should be preferred to
|
||||
* L2 regularization. See {@link WeightDecay} javadoc for further details.<br>
|
||||
*/
|
||||
public B l2Bias(double l2Bias) {
|
||||
NetworkUtils.removeInstances(this.regularizationBias, L2Regularization.class);
|
||||
if(l2Bias > 0.0) {
|
||||
NetworkUtils.removeInstancesWithWarning(this.regularizationBias, WeightDecay.class, "WeightDecay bias regularization removed: incompatible with added L2 regularization");
|
||||
this.regularizationBias.add(new L2Regularization(l2Bias));
|
||||
}
|
||||
return self();
|
||||
}
|
||||
|
||||
/**
|
||||
* Add weight decay regularization for the network parameters (excluding biases).<br>
|
||||
* This applies weight decay <i>with</i> multiplying the learning rate - see {@link WeightDecay} for more details.<br>
|
||||
*
|
||||
* @param coefficient Weight decay regularization coefficient
|
||||
* @see #weightDecay(double, boolean)
|
||||
*/
|
||||
public B weightDecay(double coefficient) {
|
||||
return weightDecay(coefficient, true);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add weight decay regularization for the network parameters (excluding biases). See {@link WeightDecay} for more details.<br>
|
||||
*
|
||||
* @param coefficient Weight decay regularization coefficient
|
||||
* @param applyLR Whether the learning rate should be multiplied in when performing weight decay updates. See {@link WeightDecay} for more details.
|
||||
* @see #weightDecay(double, boolean)
|
||||
*/
|
||||
public B weightDecay(double coefficient, boolean applyLR) {
|
||||
//Check if existing weight decay if it exists; if so, replace it. Also remove L2 - it doesn't make sense to use both
|
||||
NetworkUtils.removeInstances(this.regularization, WeightDecay.class);
|
||||
if(coefficient > 0.0) {
|
||||
NetworkUtils.removeInstancesWithWarning(this.regularization, L2Regularization.class, "L2 regularization removed: incompatible with added WeightDecay regularization");
|
||||
this.regularization.add(new WeightDecay(coefficient, applyLR));
|
||||
}
|
||||
return self();
|
||||
}
|
||||
|
||||
/**
|
||||
* Weight decay for the biases only - see {@link #weightDecay(double)} for more details.
|
||||
* This applies weight decay <i>with</i> multiplying the learning rate.<br>
|
||||
*
|
||||
* @param coefficient Weight decay regularization coefficient
|
||||
* @see #weightDecayBias(double, boolean)
|
||||
*/
|
||||
public B weightDecayBias(double coefficient) {
|
||||
return weightDecayBias(coefficient, true);
|
||||
}
|
||||
|
||||
/**
|
||||
* Weight decay for the biases only - see {@link #weightDecay(double)} for more details<br>
|
||||
*
|
||||
* @param coefficient Weight decay regularization coefficient
|
||||
*/
|
||||
public B weightDecayBias(double coefficient, boolean applyLR) {
|
||||
//Check if existing weight decay if it exists; if so, replace it. Also remove L2 - it doesn't make sense to use both
|
||||
NetworkUtils.removeInstances(this.regularizationBias, WeightDecay.class);
|
||||
if(coefficient > 0.0) {
|
||||
NetworkUtils.removeInstancesWithWarning(this.regularizationBias, L2Regularization.class, "L2 bias regularization removed: incompatible with added WeightDecay regularization");
|
||||
this.regularizationBias.add(new WeightDecay(coefficient, applyLR));
|
||||
}
|
||||
return self();
|
||||
}
|
||||
public B weightDecayBias(double coefficient, boolean applyLR) {
|
||||
// Check if existing weight decay if it exists; if so, replace it. Also remove L2 - it doesn't
|
||||
// make sense to use both
|
||||
NetworkUtils.removeInstances(this.regularizationBias, WeightDecay.class);
|
||||
if (coefficient > 0.0) {
|
||||
NetworkUtils.removeInstancesWithWarning(
|
||||
this.regularizationBias,
|
||||
L2Regularization.class,
|
||||
"L2 bias regularization removed: incompatible with added WeightDecay regularization");
|
||||
this.regularizationBias.add(new WeightDecay(coefficient, applyLR));
|
||||
}
|
||||
return self();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -176,7 +176,7 @@ public abstract class SameDiffVertex extends GraphVertex implements ITraininable
|
|||
}
|
||||
|
||||
@Override
|
||||
public String getLayerName() {
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
|
||||
|
|
|
@ -285,5 +285,14 @@ public class VariationalAutoencoder extends BasePretrainNetwork {
|
|||
super.nOut(nOut);
|
||||
return self();
|
||||
}
|
||||
|
||||
public B pzxActivationFunction(IActivation activation) {
|
||||
this.pzxActivationFunction$value = activation;
|
||||
this.pzxActivationFunction$set = true;
|
||||
return self();
|
||||
}
|
||||
public B pzxActivationFunction(Activation activation) {
|
||||
return this.pzxActivationFunction(activation.getActivationFunction());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -107,7 +107,7 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
|
|||
|
||||
INDArray origInput = input;
|
||||
INDArray origEps = epsilon;
|
||||
if(getTypedLayerConfiguration().getDataFormat() != CNN2DFormat.NCHW) {
|
||||
if(getTypedLayerConfiguration().getConvFormat() != CNN2DFormat.NCHW) {
|
||||
input = input.permute(0,3,1,2); //NHWC to NCHW
|
||||
epsilon = epsilon.permute(0,3,1,2); //NHWC to NCHW
|
||||
}
|
||||
|
@ -151,7 +151,7 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
|
|||
|
||||
Pair<INDArray, INDArray> p = preOutput4d(true, true, workspaceMgr);
|
||||
INDArray z = p.getFirst();
|
||||
CNN2DFormat f = getTypedLayerConfiguration().getDataFormat();
|
||||
CNN2DFormat f = getTypedLayerConfiguration().getConvFormat();
|
||||
if(f != CNN2DFormat.NCHW){
|
||||
z = z.permute(0,3,1,2); //NHWC to NCHW
|
||||
}
|
||||
|
@ -159,7 +159,7 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
|
|||
|
||||
if (helper != null && (helperCountFail == 0 || !getTypedLayerConfiguration().isCudnnAllowFallback())) {
|
||||
INDArray helperDelta = delta;
|
||||
if(getTypedLayerConfiguration().getDataFormat() == CNN2DFormat.NHWC)
|
||||
if(getTypedLayerConfiguration().getConvFormat() == CNN2DFormat.NHWC)
|
||||
helperDelta = delta.permute(0,2,3,1); //NCHW to NHWC
|
||||
|
||||
if(!hasBias() && !(helper instanceof MKLDNNConvHelper)){
|
||||
|
@ -177,7 +177,7 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
|
|||
ret = helper.backpropGradient(origInput, weights, bias, helperDelta, kernel, strides,
|
||||
pad, biasGradView, weightGradView, afn,
|
||||
getTypedLayerConfiguration().getCudnnAlgoMode(), getTypedLayerConfiguration().getCudnnBwdFilterAlgo(), getTypedLayerConfiguration().getCudnnBwdDataAlgo(),
|
||||
convolutionMode, dilation, getTypedLayerConfiguration().getDataFormat(), workspaceMgr);
|
||||
convolutionMode, dilation, getTypedLayerConfiguration().getConvFormat(), workspaceMgr);
|
||||
} catch (ND4JOpProfilerException e){
|
||||
throw e; //NaN panic etc for debugging
|
||||
} catch (Exception e){
|
||||
|
@ -261,7 +261,7 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
|
|||
|
||||
epsNext = backpropDropOutIfPresent(epsNext);
|
||||
|
||||
if(getTypedLayerConfiguration().getDataFormat() != CNN2DFormat.NCHW){
|
||||
if(getTypedLayerConfiguration().getConvFormat()!= CNN2DFormat.NCHW){
|
||||
epsNext = epsNext.permute(0,2,3,1); //NCHW to NHWC
|
||||
}
|
||||
|
||||
|
@ -295,7 +295,7 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
|
|||
}
|
||||
|
||||
protected void validateInputDepth(long inDepth) {
|
||||
CNN2DFormat format = getTypedLayerConfiguration().getDataFormat();
|
||||
CNN2DFormat format = getTypedLayerConfiguration().getConvFormat();
|
||||
int dim = format == CNN2DFormat.NHWC ? 3 : 1;
|
||||
if (input.size(dim) != inDepth) {
|
||||
String layerName = layerConfiguration.getName();
|
||||
|
@ -304,7 +304,7 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
|
|||
|
||||
String s = "Cannot do forward pass in Convolution layer (layer name = " + layerName
|
||||
+ ", layer index = " + index + "): input array channels does not match CNN layer configuration"
|
||||
+ " (data format = " + format + ", data input channels = " + input.size(dim) + ", " + getTypedLayerConfiguration().getDataFormat().dimensionNames()
|
||||
+ " (data format = " + format + ", data input channels = " + input.size(dim) + ", " + getTypedLayerConfiguration().getConvFormat().dimensionNames()
|
||||
+ "=" + Arrays.toString(input.shape()) + "; expected" + " input channels = " + inDepth + ") "
|
||||
+ layerId();
|
||||
|
||||
|
@ -337,7 +337,7 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
|
|||
|
||||
INDArray input = this.input.castTo(dataType);
|
||||
INDArray inputOrig = input;
|
||||
if(getTypedLayerConfiguration().getDataFormat() == CNN2DFormat.NHWC) {
|
||||
if(getTypedLayerConfiguration().getConvFormat() == CNN2DFormat.NHWC) {
|
||||
input = input.permute(0,3,1,2).dup(); //NHWC to NCHW
|
||||
}
|
||||
|
||||
|
@ -421,7 +421,7 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
|
|||
INDArray ret = null;
|
||||
try {
|
||||
ret = helper.preOutput(inputOrig, weights, bias, kernel, strides, pad, getTypedLayerConfiguration().getCudnnAlgoMode(),
|
||||
getTypedLayerConfiguration().getCudnnFwdAlgo(), convolutionMode, dilation, getTypedLayerConfiguration().getDataFormat(), workspaceMgr);
|
||||
getTypedLayerConfiguration().getCudnnFwdAlgo(), convolutionMode, dilation, getTypedLayerConfiguration().getConvFormat(), workspaceMgr);
|
||||
} catch (ND4JOpProfilerException e){
|
||||
throw e; //NaN panic etc for debugging
|
||||
} catch (Exception e){
|
||||
|
@ -498,7 +498,7 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
|
|||
}
|
||||
}
|
||||
|
||||
if(getTypedLayerConfiguration().getDataFormat() == CNN2DFormat.NHWC) {
|
||||
if(getTypedLayerConfiguration().getConvFormat() == CNN2DFormat.NHWC) {
|
||||
z = z.permute(0,2,3,1); //NCHW to NHWC
|
||||
z = workspaceMgr.dup(ArrayType.ACTIVATIONS, z);
|
||||
}
|
||||
|
|
|
@ -61,13 +61,13 @@ public class Deconvolution2DLayer extends ConvolutionLayer {
|
|||
if (input.rank() != 4) {
|
||||
throw new DL4JInvalidInputException("Got rank " + input.rank()
|
||||
+ " array as input to Deconvolution2DLayer with shape " + Arrays.toString(input.shape())
|
||||
+ ". Expected rank 4 array with shape " + getTypedLayerConfiguration().getDataFormat().dimensionNames() + ". "
|
||||
+ ". Expected rank 4 array with shape " + getTypedLayerConfiguration().getConvFormat().dimensionNames() + ". "
|
||||
+ layerId());
|
||||
}
|
||||
|
||||
INDArray weights = getParamWithNoise(DeconvolutionParamInitializer.WEIGHT_KEY, true, workspaceMgr);
|
||||
|
||||
CNN2DFormat format = getTypedLayerConfiguration().getDataFormat();
|
||||
CNN2DFormat format = getTypedLayerConfiguration().getConvFormat();
|
||||
boolean nchw = format == CNN2DFormat.NCHW;
|
||||
int hDim = nchw ? 2 : 1;
|
||||
int wDim = nchw ? 3 : 2;
|
||||
|
@ -166,7 +166,7 @@ public class Deconvolution2DLayer extends ConvolutionLayer {
|
|||
+ " " + layerId());
|
||||
}
|
||||
|
||||
CNN2DFormat format = getTypedLayerConfiguration().getDataFormat();
|
||||
CNN2DFormat format = getTypedLayerConfiguration().getConvFormat();
|
||||
boolean nchw = format == CNN2DFormat.NCHW;
|
||||
int cDim = nchw ? 1 : 3;
|
||||
int hDim = nchw ? 2 : 1;
|
||||
|
|
|
@ -59,12 +59,12 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer {
|
|||
@Override
|
||||
public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
|
||||
assertInputSet(true);
|
||||
CNN2DFormat format = getTypedLayerConfiguration().getDataFormat();
|
||||
CNN2DFormat format = getTypedLayerConfiguration().getConvFormat();
|
||||
boolean nchw = format == CNN2DFormat.NCHW;
|
||||
if (input.rank() != 4) {
|
||||
throw new DL4JInvalidInputException("Got rank " + input.rank()
|
||||
+ " array as input to Convolution layer with shape " + Arrays.toString(input.shape())
|
||||
+ ". Expected rank 4 array with shape " + getTypedLayerConfiguration().getDataFormat().dimensionNames() + ". "
|
||||
+ ". Expected rank 4 array with shape " + getTypedLayerConfiguration().getConvFormat().dimensionNames() + ". "
|
||||
+ layerId());
|
||||
}
|
||||
INDArray bias;
|
||||
|
@ -158,7 +158,7 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer {
|
|||
throw new DL4JInvalidInputException("Got rank " + input.rank()
|
||||
+ " array as input to DepthwiseConvolution2D (layer name = " + layerName + ", layer index = "
|
||||
+ index + ") with shape " + Arrays.toString(input.shape()) + ". "
|
||||
+ "Expected rank 4 array with shape " + getTypedLayerConfiguration().getDataFormat().dimensionNames() + "."
|
||||
+ "Expected rank 4 array with shape " + getTypedLayerConfiguration().getConvFormat().dimensionNames() + "."
|
||||
+ (input.rank() == 2
|
||||
? " (Wrong input type (see InputType.convolutionalFlat()) or wrong data type?)"
|
||||
: "") + " " + layerId());
|
||||
|
@ -166,7 +166,7 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer {
|
|||
|
||||
INDArray input = this.input.castTo(dataType); //no-op if correct dtype
|
||||
|
||||
CNN2DFormat format = getTypedLayerConfiguration().getDataFormat();
|
||||
CNN2DFormat format = getTypedLayerConfiguration().getConvFormat();
|
||||
boolean nchw = format == CNN2DFormat.NCHW;
|
||||
|
||||
long inDepth = depthWiseWeights.size(2);
|
||||
|
|
|
@ -63,7 +63,7 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer {
|
|||
if (input.rank() != 4) {
|
||||
throw new DL4JInvalidInputException("Got rank " + input.rank()
|
||||
+ " array as input to SubsamplingLayer with shape " + Arrays.toString(input.shape())
|
||||
+ ". Expected rank 4 array with shape " + getTypedLayerConfiguration().getDataFormat().dimensionNames() + ". "
|
||||
+ ". Expected rank 4 array with shape " + getTypedLayerConfiguration().getConvFormat().dimensionNames() + ". "
|
||||
+ layerId());
|
||||
}
|
||||
INDArray bias;
|
||||
|
@ -74,7 +74,7 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer {
|
|||
|
||||
INDArray input = this.input.castTo(dataType);
|
||||
|
||||
CNN2DFormat format = getTypedLayerConfiguration().getDataFormat();
|
||||
CNN2DFormat format = getTypedLayerConfiguration().getConvFormat();
|
||||
boolean nchw = format == CNN2DFormat.NCHW;
|
||||
|
||||
long miniBatch = input.size(0);
|
||||
|
@ -167,7 +167,7 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer {
|
|||
getParamWithNoise(SeparableConvolutionParamInitializer.POINT_WISE_WEIGHT_KEY, training, workspaceMgr);
|
||||
|
||||
INDArray input = this.input.castTo(dataType);
|
||||
if(getTypedLayerConfiguration().getDataFormat() == CNN2DFormat.NHWC) {
|
||||
if(getTypedLayerConfiguration().getConvFormat() == CNN2DFormat.NHWC) {
|
||||
input = input.permute(0,3,1,2).dup();
|
||||
}
|
||||
|
||||
|
@ -182,7 +182,7 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer {
|
|||
throw new DL4JInvalidInputException("Got rank " + input.rank()
|
||||
+ " array as input to SeparableConvolution2D (layer name = " + layerName + ", layer index = "
|
||||
+ index + ") with shape " + Arrays.toString(input.shape()) + ". "
|
||||
+ "Expected rank 4 array with shape " + getTypedLayerConfiguration().getDataFormat().dimensionNames() + "."
|
||||
+ "Expected rank 4 array with shape " + getTypedLayerConfiguration().getConvFormat().dimensionNames() + "."
|
||||
+ (input.rank() == 2
|
||||
? " (Wrong input type (see InputType.convolutionalFlat()) or wrong data type?)"
|
||||
: "")
|
||||
|
@ -199,7 +199,7 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer {
|
|||
|
||||
String s = "Cannot do forward pass in SeparableConvolution2D layer (layer name = " + layerName
|
||||
+ ", layer index = " + index + "): input array channels does not match CNN layer configuration"
|
||||
+ " (data format = " + getTypedLayerConfiguration().getDataFormat() + ", data input channels = " + input.size(1) + ", [minibatch,inputDepth,height,width]="
|
||||
+ " (data format = " + getTypedLayerConfiguration().getConvFormat() + ", data input channels = " + input.size(1) + ", [minibatch,inputDepth,height,width]="
|
||||
+ Arrays.toString(input.shape()) + "; expected" + " input channels = " + inDepth + ") "
|
||||
+ layerId();
|
||||
|
||||
|
@ -287,7 +287,7 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer {
|
|||
.build();
|
||||
Nd4j.getExecutioner().exec(op);
|
||||
|
||||
if(getTypedLayerConfiguration().getDataFormat() == CNN2DFormat.NHWC) {
|
||||
if(getTypedLayerConfiguration().getConvFormat() == CNN2DFormat.NHWC) {
|
||||
output = output.permute(0,2,3,1); //NCHW to NHWC
|
||||
|
||||
}
|
||||
|
|
|
@ -47,7 +47,7 @@ public class SpaceToBatch extends AbstractLayer<org.deeplearning4j.nn.conf.layer
|
|||
}
|
||||
|
||||
private int[] getBlocks() {
|
||||
return getTypedLayerConfiguration().getBlocks();
|
||||
return getTypedLayerConfiguration().getBlockSize();
|
||||
}
|
||||
|
||||
private int[][] getPadding() {
|
||||
|
@ -55,7 +55,7 @@ public class SpaceToBatch extends AbstractLayer<org.deeplearning4j.nn.conf.layer
|
|||
}
|
||||
|
||||
private INDArray getBlocksArray() {
|
||||
int[] intBlocks = getTypedLayerConfiguration().getBlocks();
|
||||
int[] intBlocks = getTypedLayerConfiguration().getBlockSize();
|
||||
return Nd4j.createFromArray(intBlocks);
|
||||
}
|
||||
|
||||
|
@ -77,7 +77,7 @@ public class SpaceToBatch extends AbstractLayer<org.deeplearning4j.nn.conf.layer
|
|||
|
||||
INDArray input = this.input.castTo(dataType); //Cast to network dtype if required (no-op if already correct type)
|
||||
|
||||
boolean nchw = getTypedLayerConfiguration().getFormat() == CNN2DFormat.NCHW;
|
||||
boolean nchw = getTypedLayerConfiguration().getDataFormat() == CNN2DFormat.NCHW;
|
||||
|
||||
INDArray outEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), input.shape(), 'c');
|
||||
|
||||
|
@ -104,7 +104,7 @@ public class SpaceToBatch extends AbstractLayer<org.deeplearning4j.nn.conf.layer
|
|||
if (input.rank() != 4) {
|
||||
throw new DL4JInvalidInputException("Got rank " + input.rank()
|
||||
+ " array as input to space to batch with shape " + Arrays.toString(input.shape())
|
||||
+ ". Expected rank 4 array with shape " + getTypedLayerConfiguration().getFormat().dimensionNames() + ". "
|
||||
+ ". Expected rank 4 array with shape " + getTypedLayerConfiguration().getDataFormat().dimensionNames() + ". "
|
||||
+ layerId());
|
||||
}
|
||||
|
||||
|
@ -112,7 +112,7 @@ public class SpaceToBatch extends AbstractLayer<org.deeplearning4j.nn.conf.layer
|
|||
return preOutput;
|
||||
}
|
||||
|
||||
boolean nchw = getTypedLayerConfiguration().getFormat() == CNN2DFormat.NCHW;
|
||||
boolean nchw = getTypedLayerConfiguration().getDataFormat() == CNN2DFormat.NCHW;
|
||||
|
||||
long inMiniBatch = input.size(0);
|
||||
long depth = input.size(nchw ? 1 : 3);
|
||||
|
|
|
@ -87,7 +87,7 @@ public class ZeroPadding1DLayer extends AbstractLayer<org.deeplearning4j.nn.conf
|
|||
|
||||
@Override
|
||||
public Layer clone() {
|
||||
return ZeroPadding1DLayer.builder(layerConfiguration.clone(), dataType);
|
||||
return new ZeroPadding1DLayer(layerConfiguration.clone(), dataType);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -312,6 +312,6 @@ public class SimpleRnn extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.lay
|
|||
|
||||
@Override
|
||||
public boolean hasLayerNorm(){
|
||||
return getTypedLayerConfiguration().hasLayerNorm();
|
||||
return getTypedLayerConfiguration().isHasLayerNorm();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -167,7 +167,7 @@ public class SimpleRnnParamInitializer extends AbstractParamInitializer {
|
|||
|
||||
protected boolean hasLayerNorm(LayerConfiguration layer){
|
||||
if(layer instanceof SimpleRnn){
|
||||
return ((SimpleRnn) layer).hasLayerNorm();
|
||||
return ((SimpleRnn) layer).isHasLayerNorm();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -21,73 +21,109 @@
|
|||
package org.deeplearning4j.nn.weights;
|
||||
|
||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
||||
import org.deeplearning4j.nn.weights.embeddings.EmbeddingInitializer;
|
||||
import org.deeplearning4j.nn.weights.embeddings.WeightInitEmbedding;
|
||||
|
||||
public enum WeightInit {
|
||||
DISTRIBUTION, ZERO, ONES, SIGMOID_UNIFORM, NORMAL, LECUN_NORMAL, UNIFORM, XAVIER, XAVIER_UNIFORM, XAVIER_FAN_IN, XAVIER_LEGACY, RELU,
|
||||
RELU_UNIFORM, IDENTITY, LECUN_UNIFORM, VAR_SCALING_NORMAL_FAN_IN, VAR_SCALING_NORMAL_FAN_OUT, VAR_SCALING_NORMAL_FAN_AVG,
|
||||
VAR_SCALING_UNIFORM_FAN_IN, VAR_SCALING_UNIFORM_FAN_OUT, VAR_SCALING_UNIFORM_FAN_AVG;
|
||||
DISTRIBUTION,
|
||||
ZERO,
|
||||
ONES,
|
||||
SIGMOID_UNIFORM,
|
||||
NORMAL,
|
||||
LECUN_NORMAL,
|
||||
UNIFORM,
|
||||
XAVIER,
|
||||
XAVIER_UNIFORM,
|
||||
XAVIER_FAN_IN,
|
||||
XAVIER_LEGACY,
|
||||
RELU,
|
||||
RELU_UNIFORM,
|
||||
IDENTITY,
|
||||
LECUN_UNIFORM,
|
||||
VAR_SCALING_NORMAL_FAN_IN,
|
||||
VAR_SCALING_NORMAL_FAN_OUT,
|
||||
VAR_SCALING_NORMAL_FAN_AVG,
|
||||
CONSTANT,
|
||||
EMBEDDING,
|
||||
VAR_SCALING_UNIFORM_FAN_IN,
|
||||
VAR_SCALING_UNIFORM_FAN_OUT,
|
||||
VAR_SCALING_UNIFORM_FAN_AVG;
|
||||
|
||||
|
||||
/**
|
||||
* Create an instance of the weight initialization function
|
||||
*
|
||||
* @return a new {@link IWeightInit} instance
|
||||
*/
|
||||
public IWeightInit getWeightInitFunction() {
|
||||
return getWeightInitFunction(null);
|
||||
/**
|
||||
* Create an instance of the weight initialization function
|
||||
*
|
||||
* @return a new {@link IWeightInit} instance
|
||||
*/
|
||||
public IWeightInit getWeightInitFunction(Distribution distribution) {
|
||||
switch (this) {
|
||||
case DISTRIBUTION:
|
||||
return new WeightInitDistribution(distribution);
|
||||
default:
|
||||
return getWeightInitFunction();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create an instance of the weight initialization function
|
||||
*
|
||||
* @param distribution Distribution of the weights (Only used in case DISTRIBUTION)
|
||||
* @return a new {@link IWeightInit} instance
|
||||
*/
|
||||
public IWeightInit getWeightInitFunction(Distribution distribution) {
|
||||
switch (this) {
|
||||
case ZERO:
|
||||
return new WeightInitConstant(0.0);
|
||||
case ONES:
|
||||
return new WeightInitConstant(1.0);
|
||||
case DISTRIBUTION:
|
||||
return new WeightInitDistribution(distribution);
|
||||
case SIGMOID_UNIFORM:
|
||||
return new WeightInitSigmoidUniform();
|
||||
case LECUN_NORMAL: //Fall through: these 3 are equivalent
|
||||
case XAVIER_FAN_IN:
|
||||
case NORMAL:
|
||||
return new WeightInitNormal();
|
||||
case UNIFORM:
|
||||
return new WeightInitUniform();
|
||||
case XAVIER:
|
||||
return new WeightInitXavier();
|
||||
case XAVIER_UNIFORM:
|
||||
return new WeightInitXavierUniform();
|
||||
case XAVIER_LEGACY:
|
||||
return new WeightInitXavierLegacy();
|
||||
case RELU:
|
||||
return new WeightInitRelu();
|
||||
case RELU_UNIFORM:
|
||||
return new WeightInitReluUniform();
|
||||
case IDENTITY:
|
||||
return new WeightInitIdentity();
|
||||
case LECUN_UNIFORM:
|
||||
return new WeightInitLecunUniform();
|
||||
case VAR_SCALING_NORMAL_FAN_IN:
|
||||
return new WeightInitVarScalingNormalFanIn();
|
||||
case VAR_SCALING_NORMAL_FAN_OUT:
|
||||
return new WeightInitVarScalingNormalFanOut();
|
||||
case VAR_SCALING_NORMAL_FAN_AVG:
|
||||
return new WeightInitVarScalingNormalFanAvg();
|
||||
case VAR_SCALING_UNIFORM_FAN_IN:
|
||||
return new WeightInitVarScalingUniformFanIn();
|
||||
case VAR_SCALING_UNIFORM_FAN_OUT:
|
||||
return new WeightInitVarScalingUniformFanOut();
|
||||
case VAR_SCALING_UNIFORM_FAN_AVG:
|
||||
return new WeightInitVarScalingUniformFanAvg();
|
||||
|
||||
default:
|
||||
throw new UnsupportedOperationException("Unknown or not supported weight initialization function: " + this);
|
||||
}
|
||||
public IWeightInit getWeightInitFunction(
|
||||
EmbeddingInitializer initializer) { // EmbeddingInitializer
|
||||
switch (this) {
|
||||
case EMBEDDING:
|
||||
return new WeightInitEmbedding(initializer);
|
||||
default:
|
||||
return getWeightInitFunction();
|
||||
}
|
||||
}
|
||||
/**
|
||||
* Create an instance of the weight initialization function
|
||||
*
|
||||
* @return a new {@link IWeightInit} instance
|
||||
*/
|
||||
public IWeightInit getWeightInitFunction() {
|
||||
switch (this) {
|
||||
case CONSTANT:
|
||||
return new WeightInitConstant();
|
||||
case ZERO:
|
||||
return new WeightInitConstant(0.0);
|
||||
case ONES:
|
||||
return new WeightInitConstant(1.0);
|
||||
|
||||
case SIGMOID_UNIFORM:
|
||||
return new WeightInitSigmoidUniform();
|
||||
case LECUN_NORMAL: // Fall through: these 3 are equivalent
|
||||
case XAVIER_FAN_IN:
|
||||
case NORMAL:
|
||||
return new WeightInitNormal();
|
||||
case UNIFORM:
|
||||
return new WeightInitUniform();
|
||||
case XAVIER:
|
||||
return new WeightInitXavier();
|
||||
case XAVIER_UNIFORM:
|
||||
return new WeightInitXavierUniform();
|
||||
case XAVIER_LEGACY:
|
||||
return new WeightInitXavierLegacy();
|
||||
case RELU:
|
||||
return new WeightInitRelu();
|
||||
case RELU_UNIFORM:
|
||||
return new WeightInitReluUniform();
|
||||
case IDENTITY:
|
||||
return new WeightInitIdentity();
|
||||
case LECUN_UNIFORM:
|
||||
return new WeightInitLecunUniform();
|
||||
case VAR_SCALING_NORMAL_FAN_IN:
|
||||
return new WeightInitVarScalingNormalFanIn();
|
||||
case VAR_SCALING_NORMAL_FAN_OUT:
|
||||
return new WeightInitVarScalingNormalFanOut();
|
||||
case VAR_SCALING_NORMAL_FAN_AVG:
|
||||
return new WeightInitVarScalingNormalFanAvg();
|
||||
case VAR_SCALING_UNIFORM_FAN_IN:
|
||||
return new WeightInitVarScalingUniformFanIn();
|
||||
case VAR_SCALING_UNIFORM_FAN_OUT:
|
||||
return new WeightInitVarScalingUniformFanOut();
|
||||
case VAR_SCALING_UNIFORM_FAN_AVG:
|
||||
return new WeightInitVarScalingUniformFanAvg();
|
||||
|
||||
default:
|
||||
throw new UnsupportedOperationException(
|
||||
"Unknown or not supported weight initialization function: " + this);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -42,4 +42,13 @@ public class WeightInitConstant implements IWeightInit {
|
|||
paramView.assign(value);
|
||||
return paramView.reshape(order, shape);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public WeightInit enumValue() {
|
||||
return WeightInit.CONSTANT;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -51,4 +51,13 @@ public class WeightInitDistribution implements IWeightInit {
|
|||
}
|
||||
return paramView.reshape(order, shape);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public WeightInit enumValue() {
|
||||
return WeightInit.DISTRIBUTION;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -38,4 +38,13 @@ public class WeightInitLecunUniform implements IWeightInit {
|
|||
Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-b, b));
|
||||
return paramView.reshape(order, shape);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public WeightInit enumValue() {
|
||||
return WeightInit.LECUN_UNIFORM;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -38,4 +38,13 @@ public class WeightInitNormal implements IWeightInit {
|
|||
Nd4j.randn(paramView).divi(FastMath.sqrt(fanIn));
|
||||
return paramView.reshape(order, shape);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public WeightInit enumValue() {
|
||||
return WeightInit.NORMAL;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -33,4 +33,13 @@ public class WeightInitRelu implements IWeightInit {
|
|||
Nd4j.randn(paramView).muli(FastMath.sqrt(2.0 / fanIn)); //N(0, 2/nIn)
|
||||
return paramView.reshape(order, shape);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public WeightInit enumValue() {
|
||||
return WeightInit.RELU;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -35,4 +35,13 @@ public class WeightInitReluUniform implements IWeightInit {
|
|||
Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-u, u)); //U(-sqrt(6/fanIn), sqrt(6/fanIn)
|
||||
return paramView.reshape(order, shape);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public WeightInit enumValue() {
|
||||
return WeightInit.RELU_UNIFORM;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -35,4 +35,13 @@ public class WeightInitSigmoidUniform implements IWeightInit {
|
|||
Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-r, r));
|
||||
return paramView.reshape(order, shape);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public WeightInit enumValue() {
|
||||
return WeightInit.SIGMOID_UNIFORM;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -34,4 +34,13 @@ public class WeightInitUniform implements IWeightInit {
|
|||
Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-a, a));
|
||||
return paramView.reshape(order, shape);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public WeightInit enumValue() {
|
||||
return WeightInit.UNIFORM;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -50,4 +50,13 @@ public class WeightInitVarScalingNormalFanAvg implements IWeightInit {
|
|||
Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, std));
|
||||
return paramView.reshape(order, shape);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public WeightInit enumValue() {
|
||||
return WeightInit.VAR_SCALING_NORMAL_FAN_AVG;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -48,4 +48,13 @@ public class WeightInitVarScalingNormalFanIn implements IWeightInit {
|
|||
Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, std));
|
||||
return paramView.reshape(order, shape);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public WeightInit enumValue() {
|
||||
return WeightInit.VAR_SCALING_NORMAL_FAN_IN;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -50,4 +50,13 @@ public class WeightInitVarScalingNormalFanOut implements IWeightInit {
|
|||
Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, std));
|
||||
return paramView.reshape(order, shape);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public WeightInit enumValue() {
|
||||
return WeightInit.VAR_SCALING_NORMAL_FAN_OUT;
|
||||
}
|
||||
}
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue