Using @SuperBuilder for LayerConfigurations
Signed-off-by: brian <brian@brutex.de>
master
parent
f6100c362d
commit
ad870c5281
|
@ -221,7 +221,7 @@ public class TestMiscFunctions extends BaseSparkTest {
|
||||||
int nIn = 10;
|
int nIn = 10;
|
||||||
|
|
||||||
NeuralNetConfiguration mlc = NeuralNetConfiguration.builder().list()
|
NeuralNetConfiguration mlc = NeuralNetConfiguration.builder().list()
|
||||||
.layer(0, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder()
|
.layer(0, org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.builder()
|
||||||
.reconstructionDistribution(
|
.reconstructionDistribution(
|
||||||
new GaussianReconstructionDistribution(Activation.IDENTITY))
|
new GaussianReconstructionDistribution(Activation.IDENTITY))
|
||||||
.nIn(nIn).nOut(5).encoderLayerSizes(12).decoderLayerSizes(13).build())
|
.nIn(nIn).nOut(5).encoderLayerSizes(12).decoderLayerSizes(13).build())
|
||||||
|
@ -261,7 +261,7 @@ public class TestMiscFunctions extends BaseSparkTest {
|
||||||
|
|
||||||
NeuralNetConfiguration mlc = NeuralNetConfiguration.builder()
|
NeuralNetConfiguration mlc = NeuralNetConfiguration.builder()
|
||||||
.list().layer(0,
|
.list().layer(0,
|
||||||
new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder()
|
org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.builder()
|
||||||
.reconstructionDistribution(new LossFunctionWrapper(
|
.reconstructionDistribution(new LossFunctionWrapper(
|
||||||
Activation.IDENTITY, new LossMSE()))
|
Activation.IDENTITY, new LossMSE()))
|
||||||
.nIn(nIn).nOut(5).encoderLayerSizes(12).decoderLayerSizes(13)
|
.nIn(nIn).nOut(5).encoderLayerSizes(12).decoderLayerSizes(13)
|
||||||
|
|
|
@ -96,7 +96,7 @@ public class App {
|
||||||
private static LayerConfiguration[] genLayers() {
|
private static LayerConfiguration[] genLayers() {
|
||||||
return new LayerConfiguration[] {
|
return new LayerConfiguration[] {
|
||||||
DenseLayer.builder().nIn(INPUT).nOut(X_DIM*Y_DIM*CHANNELS).weightInit(WeightInit.NORMAL).build(),
|
DenseLayer.builder().nIn(INPUT).nOut(X_DIM*Y_DIM*CHANNELS).weightInit(WeightInit.NORMAL).build(),
|
||||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
ActivationLayer.builder(Activation.LEAKYRELU).build(),
|
||||||
DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(),
|
DenseLayer.builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(),
|
||||||
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
ActivationLayer.builder(new ActivationLReLU(0.2)).build(),
|
||||||
DenseLayer.builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM).build(),
|
DenseLayer.builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM).build(),
|
||||||
|
|
|
@ -331,10 +331,10 @@ public class CNN2DTestCases {
|
||||||
.build(),
|
.build(),
|
||||||
"leaky_re_lu_8")
|
"leaky_re_lu_8")
|
||||||
.addLayer("outputs",
|
.addLayer("outputs",
|
||||||
new Yolo2OutputLayer.Builder()
|
Yolo2OutputLayer.builder()
|
||||||
.lambdaNoObj(lambdaNoObj)
|
.lambdaNoObj(lambdaNoObj)
|
||||||
.lambdaCoord(lambdaCoord)
|
.lambdaCoord(lambdaCoord)
|
||||||
.boundingBoxPriors(priors)
|
.boundingBoxes(priors)
|
||||||
.build(),
|
.build(),
|
||||||
"convolution2d_9")
|
"convolution2d_9")
|
||||||
.setOutputs("outputs")
|
.setOutputs("outputs")
|
||||||
|
|
|
@ -322,7 +322,7 @@ public class RNNTestCases {
|
||||||
.updater(new Adam(5e-2))
|
.updater(new Adam(5e-2))
|
||||||
.l1(1e-3).l2(1e-3)
|
.l1(1e-3).l2(1e-3)
|
||||||
|
|
||||||
.layer(0, Bidirectional.builder(LSTM.builder().activation(Activation.TANH).nOut(10).build()))
|
.layer(0, Bidirectional.builder(LSTM.builder().activation(Activation.TANH).nOut(10).build()).build())
|
||||||
.layer(GlobalPoolingLayer.builder().poolingType(PoolingType.AVG).build())
|
.layer(GlobalPoolingLayer.builder().poolingType(PoolingType.AVG).build())
|
||||||
.layer(OutputLayer.builder().nOut(6)
|
.layer(OutputLayer.builder().nOut(6)
|
||||||
.lossFunction(LossFunctions.LossFunction.MCXENT)
|
.lossFunction(LossFunctions.LossFunction.MCXENT)
|
||||||
|
|
|
@ -22,9 +22,11 @@ package org.nd4j.linalg.activations;
|
||||||
|
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.common.primitives.Pair;
|
||||||
import org.nd4j.linalg.activations.impl.*;
|
import org.nd4j.linalg.activations.impl.*;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
public enum Activation {
|
public enum Activation implements IActivation {
|
||||||
CUBE, ELU, HARDSIGMOID, HARDTANH, IDENTITY, LEAKYRELU, RATIONALTANH, RELU, RELU6,
|
CUBE, ELU, HARDSIGMOID, HARDTANH, IDENTITY, LEAKYRELU, RATIONALTANH, RELU, RELU6,
|
||||||
RRELU, SIGMOID, SOFTMAX, SOFTPLUS, SOFTSIGN, TANH, RECTIFIEDTANH, SELU, SWISH,
|
RRELU, SIGMOID, SOFTMAX, SOFTPLUS, SOFTSIGN, TANH, RECTIFIEDTANH, SELU, SWISH,
|
||||||
THRESHOLDEDRELU, GELU, MISH;
|
THRESHOLDEDRELU, GELU, MISH;
|
||||||
|
@ -149,4 +151,44 @@ public enum Activation {
|
||||||
throw new UnsupportedOperationException("Activation function not yet supported: " + this);
|
throw new UnsupportedOperationException("Activation function not yet supported: " + this);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Carry out activation function on the input array (usually known as 'preOut' or 'z')
|
||||||
|
* Implementations must overwrite "in", transform in place and return "in"
|
||||||
|
* Can support separate behaviour during test
|
||||||
|
*
|
||||||
|
* @param in input array.
|
||||||
|
* @param training true when training.
|
||||||
|
* @return transformed activation
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public INDArray getActivation(INDArray in, boolean training) {
|
||||||
|
return getActivationFunction().getActivation(in, training);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Backpropagate the errors through the activation function, given input z and epsilon dL/da.<br>
|
||||||
|
* Returns 2 INDArrays:<br>
|
||||||
|
* (a) The gradient dL/dz, calculated from dL/da, and<br>
|
||||||
|
* (b) The parameter gradients dL/dW, where W are the weights in the activation function. For activation functions
|
||||||
|
* with no gradients, this will be null.
|
||||||
|
*
|
||||||
|
* @param in Input, before applying the activation function (z, or 'preOut')
|
||||||
|
* @param epsilon Gradient to be backpropagated: dL/da, where L is the loss function
|
||||||
|
* @return dL/dz and dL/dW, for weights W (null if activation function has no weights)
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
|
||||||
|
return getActivationFunction().backprop(in, epsilon);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
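* Number of parameters for the given input size, delegated to {@code getActivationFunction().numParams(inputSize)}.
|
||||||
|
* @param inputSize input size passed through to the underlying activation function
|
||||||
|
* @return the value returned by the underlying activation function's {@code numParams}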
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public int numParams(int inputSize) {
|
||||||
|
return getActivationFunction().numParams(inputSize);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -872,7 +872,7 @@ public class TestEarlyStopping extends BaseDL4JTest {
|
||||||
.nIn(10)
|
.nIn(10)
|
||||||
.nOut(10)
|
.nOut(10)
|
||||||
.activation(Activation.TANH)
|
.activation(Activation.TANH)
|
||||||
.gateActivationFunction(Activation.SIGMOID)
|
.gateActivationFunction(Activation.SIGMOID.getActivationFunction())
|
||||||
.dropOut(0.5)
|
.dropOut(0.5)
|
||||||
.build())
|
.build())
|
||||||
.layer(1, RnnOutputLayer.builder()
|
.layer(1, RnnOutputLayer.builder()
|
||||||
|
|
|
@ -90,8 +90,8 @@ public class AttentionLayerTest extends BaseDL4JTest {
|
||||||
.weightInit(WeightInit.XAVIER)
|
.weightInit(WeightInit.XAVIER)
|
||||||
.layer(LSTM.builder().nOut(layerSize).build())
|
.layer(LSTM.builder().nOut(layerSize).build())
|
||||||
.layer( projectInput ?
|
.layer( projectInput ?
|
||||||
new SelfAttentionLayer.Builder().nOut(4).nHeads(2).projectInput(true).build()
|
SelfAttentionLayer.builder().nOut(4).nHeads(2).projectInput(true).build()
|
||||||
: new SelfAttentionLayer.Builder().nHeads(1).projectInput(false).build()
|
: SelfAttentionLayer.builder().nHeads(1).projectInput(false).build()
|
||||||
)
|
)
|
||||||
.layer(GlobalPoolingLayer.builder().poolingType(PoolingType.MAX).build())
|
.layer(GlobalPoolingLayer.builder().poolingType(PoolingType.MAX).build())
|
||||||
.layer(OutputLayer.builder().nOut(nOut).activation(Activation.SOFTMAX)
|
.layer(OutputLayer.builder().nOut(nOut).activation(Activation.SOFTMAX)
|
||||||
|
@ -151,8 +151,8 @@ public class AttentionLayerTest extends BaseDL4JTest {
|
||||||
.list()
|
.list()
|
||||||
.layer(LSTM.builder().nOut(layerSize).build())
|
.layer(LSTM.builder().nOut(layerSize).build())
|
||||||
.layer( projectInput ?
|
.layer( projectInput ?
|
||||||
new LearnedSelfAttentionLayer.Builder().nOut(4).nHeads(2).nQueries(numQueries).projectInput(true).build()
|
LearnedSelfAttentionLayer.builder().nOut(4).nHeads(2).nQueries(numQueries).projectInput(true).build()
|
||||||
: new LearnedSelfAttentionLayer.Builder().nHeads(1).nQueries(numQueries).projectInput(false).build()
|
: LearnedSelfAttentionLayer.builder().nHeads(1).nQueries(numQueries).projectInput(false).build()
|
||||||
)
|
)
|
||||||
.layer(GlobalPoolingLayer.builder().poolingType(PoolingType.MAX).build())
|
.layer(GlobalPoolingLayer.builder().poolingType(PoolingType.MAX).build())
|
||||||
.layer(OutputLayer.builder().nOut(nOut).activation(Activation.SOFTMAX)
|
.layer(OutputLayer.builder().nOut(nOut).activation(Activation.SOFTMAX)
|
||||||
|
@ -191,8 +191,8 @@ public class AttentionLayerTest extends BaseDL4JTest {
|
||||||
.list()
|
.list()
|
||||||
.layer(LSTM.builder().nOut(layerSize).build())
|
.layer(LSTM.builder().nOut(layerSize).build())
|
||||||
.layer( projectInput ?
|
.layer( projectInput ?
|
||||||
new LearnedSelfAttentionLayer.Builder().nOut(4).nHeads(2).nQueries(numQueries).projectInput(true).build()
|
LearnedSelfAttentionLayer.builder().nOut(4).nHeads(2).nQueries(numQueries).projectInput(true).build()
|
||||||
: new LearnedSelfAttentionLayer.Builder().nHeads(1).nQueries(numQueries).projectInput(false).build()
|
: LearnedSelfAttentionLayer.builder().nHeads(1).nQueries(numQueries).projectInput(false).build()
|
||||||
)
|
)
|
||||||
.layer(GlobalPoolingLayer.builder().poolingType(PoolingType.MAX).build())
|
.layer(GlobalPoolingLayer.builder().poolingType(PoolingType.MAX).build())
|
||||||
.layer(OutputLayer.builder().nOut(nOut).activation(Activation.SOFTMAX)
|
.layer(OutputLayer.builder().nOut(nOut).activation(Activation.SOFTMAX)
|
||||||
|
@ -245,7 +245,7 @@ public class AttentionLayerTest extends BaseDL4JTest {
|
||||||
.weightInit(WeightInit.XAVIER)
|
.weightInit(WeightInit.XAVIER)
|
||||||
.list()
|
.list()
|
||||||
.layer(LSTM.builder().nOut(layerSize).build())
|
.layer(LSTM.builder().nOut(layerSize).build())
|
||||||
.layer(new RecurrentAttentionLayer.Builder().nIn(layerSize).nOut(layerSize).nHeads(1).projectInput(false).hasBias(false).build())
|
.layer(RecurrentAttentionLayer.builder().nIn(layerSize).nOut(layerSize).nHeads(1).projectInput(false).hasBias(false).build())
|
||||||
.layer(GlobalPoolingLayer.builder().poolingType(PoolingType.AVG).build())
|
.layer(GlobalPoolingLayer.builder().poolingType(PoolingType.AVG).build())
|
||||||
.layer(OutputLayer.builder().nOut(nOut).activation(Activation.SOFTMAX)
|
.layer(OutputLayer.builder().nOut(nOut).activation(Activation.SOFTMAX)
|
||||||
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
||||||
|
@ -308,7 +308,7 @@ public class AttentionLayerTest extends BaseDL4JTest {
|
||||||
.weightInit(WeightInit.XAVIER)
|
.weightInit(WeightInit.XAVIER)
|
||||||
.list()
|
.list()
|
||||||
.layer(LSTM.builder().nOut(layerSize).build())
|
.layer(LSTM.builder().nOut(layerSize).build())
|
||||||
.layer(new RecurrentAttentionLayer.Builder().nIn(layerSize).nOut(layerSize).nHeads(1).projectInput(false).hasBias(false).build())
|
.layer(RecurrentAttentionLayer.builder().nIn(layerSize).nOut(layerSize).nHeads(1).projectInput(false).hasBias(false).build())
|
||||||
.layer(GlobalPoolingLayer.builder().poolingType(PoolingType.AVG).build())
|
.layer(GlobalPoolingLayer.builder().poolingType(PoolingType.AVG).build())
|
||||||
.layer(OutputLayer.builder().nOut(nOut).activation(Activation.SOFTMAX)
|
.layer(OutputLayer.builder().nOut(nOut).activation(Activation.SOFTMAX)
|
||||||
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
||||||
|
|
|
@ -363,7 +363,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
.nOut(1)
|
.nOut(1)
|
||||||
.build()) // output: (5-2+0)/1+1 = 4
|
.build()) // output: (5-2+0)/1+1 = 4
|
||||||
.layer(
|
.layer(
|
||||||
new SpaceToDepthLayer.Builder(blocks, SpaceToDepthLayer.DataFormat.NCHW)
|
SpaceToDepthLayer.builder().blockSize(blocks).dataFormat(CNN2DFormat.NCHW)
|
||||||
.build()) // (mb,1,4,4) -> (mb,4,2,2)
|
.build()) // (mb,1,4,4) -> (mb,4,2,2)
|
||||||
.layer(
|
.layer(
|
||||||
OutputLayer.builder()
|
OutputLayer.builder()
|
||||||
|
@ -450,10 +450,10 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
ConvolutionLayer.builder(kernel)
|
ConvolutionLayer.builder(kernel)
|
||||||
.nIn(inputDepth)
|
.nIn(inputDepth)
|
||||||
.nOut(3)
|
.nOut(3)
|
||||||
.dataFormat(format)
|
.convFormat(format)
|
||||||
.build())
|
.build())
|
||||||
.layer(
|
.layer(
|
||||||
new SpaceToBatchLayer.Builder(blocks)
|
SpaceToBatchLayer.builder(blocks)
|
||||||
.dataFormat(format)
|
.dataFormat(format)
|
||||||
.build()) // trivial space to batch
|
.build()) // trivial space to batch
|
||||||
.layer(
|
.layer(
|
||||||
|
@ -546,7 +546,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
.layer(
|
.layer(
|
||||||
ConvolutionLayer.builder(kernel, stride, padding)
|
ConvolutionLayer.builder(kernel, stride, padding)
|
||||||
.nIn(inputDepth)
|
.nIn(inputDepth)
|
||||||
.dataFormat(format)
|
.convFormat(format)
|
||||||
.nOut(3)
|
.nOut(3)
|
||||||
.build()) // output: (5-2+0)/1+1 = 4
|
.build()) // output: (5-2+0)/1+1 = 4
|
||||||
.layer(
|
.layer(
|
||||||
|
@ -641,7 +641,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
0,
|
0,
|
||||||
ConvolutionLayer.builder(kernel, stride, padding)
|
ConvolutionLayer.builder(kernel, stride, padding)
|
||||||
.nIn(inputDepth)
|
.nIn(inputDepth)
|
||||||
.dataFormat(format)
|
.convFormat(format)
|
||||||
.nOut(3)
|
.nOut(3)
|
||||||
.build()) // output: (5-2+0)/1+1 = 4
|
.build()) // output: (5-2+0)/1+1 = 4
|
||||||
.layer(
|
.layer(
|
||||||
|
@ -750,7 +750,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
0,
|
0,
|
||||||
ConvolutionLayer.builder(kernel, stride, padding)
|
ConvolutionLayer.builder(kernel, stride, padding)
|
||||||
.nIn(inputDepth)
|
.nIn(inputDepth)
|
||||||
.dataFormat(format)
|
.convFormat(format)
|
||||||
.nOut(3)
|
.nOut(3)
|
||||||
.build()) // output: (5-2+0)/1+1 = 4
|
.build()) // output: (5-2+0)/1+1 = 4
|
||||||
.layer(
|
.layer(
|
||||||
|
@ -765,7 +765,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
.layer(
|
.layer(
|
||||||
2,
|
2,
|
||||||
ConvolutionLayer.builder(kernel, stride, padding)
|
ConvolutionLayer.builder(kernel, stride, padding)
|
||||||
.dataFormat(format)
|
.convFormat(format)
|
||||||
.nIn(3)
|
.nIn(3)
|
||||||
.nOut(2)
|
.nOut(2)
|
||||||
.build()) // Output: (3-2+0)/1+1 = 2
|
.build()) // Output: (3-2+0)/1+1 = 2
|
||||||
|
@ -849,7 +849,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
ConvolutionLayer.builder()
|
ConvolutionLayer.builder()
|
||||||
.kernelSize(2, 2)
|
.kernelSize(2, 2)
|
||||||
.stride(1, 1)
|
.stride(1, 1)
|
||||||
.dataFormat(format)
|
.convFormat(format)
|
||||||
.padding(0, 0)
|
.padding(0, 0)
|
||||||
.nIn(inputDepth)
|
.nIn(inputDepth)
|
||||||
.nOut(2)
|
.nOut(2)
|
||||||
|
@ -861,7 +861,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
.nOut(7)
|
.nOut(7)
|
||||||
.kernelSize(2, 2)
|
.kernelSize(2, 2)
|
||||||
.dataFormat(format)
|
.dataFormat(format)
|
||||||
.setInputSize(4, 4)
|
.inputSize(new int[]{4, 4})
|
||||||
.convolutionMode(ConvolutionMode.Strict)
|
.convolutionMode(ConvolutionMode.Strict)
|
||||||
.hasBias(false)
|
.hasBias(false)
|
||||||
.stride(1, 1)
|
.stride(1, 1)
|
||||||
|
@ -873,7 +873,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
.nIn(7)
|
.nIn(7)
|
||||||
.nOut(2)
|
.nOut(2)
|
||||||
.kernelSize(2, 2)
|
.kernelSize(2, 2)
|
||||||
.dataFormat(format)
|
.convFormat(format)
|
||||||
.stride(1, 1)
|
.stride(1, 1)
|
||||||
.padding(0, 0)
|
.padding(0, 0)
|
||||||
.build()) // (3-2+0)/1+1 = 2
|
.build()) // (3-2+0)/1+1 = 2
|
||||||
|
@ -959,7 +959,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
ConvolutionLayer.builder()
|
ConvolutionLayer.builder()
|
||||||
.kernelSize(2, 2)
|
.kernelSize(2, 2)
|
||||||
.stride(1, 1)
|
.stride(1, 1)
|
||||||
.dataFormat(format)
|
.convFormat(format)
|
||||||
.padding(0, 0)
|
.padding(0, 0)
|
||||||
.nIn(inputDepth)
|
.nIn(inputDepth)
|
||||||
.nOut(2)
|
.nOut(2)
|
||||||
|
@ -970,7 +970,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
.nIn(2)
|
.nIn(2)
|
||||||
.nOut(2)
|
.nOut(2)
|
||||||
.kernelSize(2, 2)
|
.kernelSize(2, 2)
|
||||||
.dataFormat(format)
|
.convFormat(format)
|
||||||
.stride(1, 1)
|
.stride(1, 1)
|
||||||
.padding(0, 0)
|
.padding(0, 0)
|
||||||
.build()) // (4-2+0)/1+1 = 3
|
.build()) // (4-2+0)/1+1 = 3
|
||||||
|
@ -980,7 +980,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
.nIn(2)
|
.nIn(2)
|
||||||
.nOut(2)
|
.nOut(2)
|
||||||
.kernelSize(2, 2)
|
.kernelSize(2, 2)
|
||||||
.dataFormat(format)
|
.convFormat(format)
|
||||||
.stride(1, 1)
|
.stride(1, 1)
|
||||||
.padding(0, 0)
|
.padding(0, 0)
|
||||||
.build()) // (3-2+0)/1+1 = 2
|
.build()) // (3-2+0)/1+1 = 2
|
||||||
|
@ -1076,7 +1076,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
ConvolutionLayer.builder()
|
ConvolutionLayer.builder()
|
||||||
.name("layer 0")
|
.name("layer 0")
|
||||||
.kernelSize(k, k)
|
.kernelSize(k, k)
|
||||||
.dataFormat(format)
|
.convFormat(format)
|
||||||
.stride(1, 1)
|
.stride(1, 1)
|
||||||
.padding(0, 0)
|
.padding(0, 0)
|
||||||
.nIn(inputDepth)
|
.nIn(inputDepth)
|
||||||
|
@ -1097,7 +1097,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
.nIn(2)
|
.nIn(2)
|
||||||
.nOut(2)
|
.nOut(2)
|
||||||
.kernelSize(k, k)
|
.kernelSize(k, k)
|
||||||
.dataFormat(format)
|
.convFormat(format)
|
||||||
.stride(1, 1)
|
.stride(1, 1)
|
||||||
.padding(0, 0)
|
.padding(0, 0)
|
||||||
.build())
|
.build())
|
||||||
|
@ -1181,7 +1181,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
ConvolutionLayer.builder()
|
ConvolutionLayer.builder()
|
||||||
.name("layer 0")
|
.name("layer 0")
|
||||||
.kernelSize(k, k)
|
.kernelSize(k, k)
|
||||||
.dataFormat(format)
|
.convFormat(format)
|
||||||
.stride(stride, stride)
|
.stride(stride, stride)
|
||||||
.padding(0, 0)
|
.padding(0, 0)
|
||||||
.nIn(inputDepth)
|
.nIn(inputDepth)
|
||||||
|
@ -1297,7 +1297,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
.layer(
|
.layer(
|
||||||
0,
|
0,
|
||||||
ConvolutionLayer.builder(kernel, stride, padding)
|
ConvolutionLayer.builder(kernel, stride, padding)
|
||||||
.dataFormat(format)
|
.convFormat(format)
|
||||||
.nIn(inputDepth)
|
.nIn(inputDepth)
|
||||||
.nOut(3)
|
.nOut(3)
|
||||||
.build()) // output: (6-2+0)/1+1 = 5
|
.build()) // output: (6-2+0)/1+1 = 5
|
||||||
|
@ -1307,7 +1307,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
ConvolutionLayer.builder(kernel, stride, padding)
|
ConvolutionLayer.builder(kernel, stride, padding)
|
||||||
.nIn(3)
|
.nIn(3)
|
||||||
.nOut(3)
|
.nOut(3)
|
||||||
.dataFormat(format)
|
.convFormat(format)
|
||||||
.build()) // output: (6-2+0)/1+1 = 5
|
.build()) // output: (6-2+0)/1+1 = 5
|
||||||
.layer(
|
.layer(
|
||||||
3,
|
3,
|
||||||
|
@ -1436,7 +1436,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
.name("deconvolution_2D_layer")
|
.name("deconvolution_2D_layer")
|
||||||
.kernelSize(k, k)
|
.kernelSize(k, k)
|
||||||
.stride(s, s)
|
.stride(s, s)
|
||||||
.dataFormat(format)
|
.convFormat(format)
|
||||||
.dilation(d, d)
|
.dilation(d, d)
|
||||||
.convolutionMode(cm)
|
.convolutionMode(cm)
|
||||||
.nIn(inputDepth)
|
.nIn(inputDepth)
|
||||||
|
@ -1530,7 +1530,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
.stride(s, s)
|
.stride(s, s)
|
||||||
.dilation(d, d)
|
.dilation(d, d)
|
||||||
.depthMultiplier(3)
|
.depthMultiplier(3)
|
||||||
.dataFormat(format)
|
.convFormat(format)
|
||||||
.nIn(inputDepth)
|
.nIn(inputDepth)
|
||||||
.nOut(2)
|
.nOut(2)
|
||||||
.build())
|
.build())
|
||||||
|
@ -1621,7 +1621,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
.kernelSize(k, k)
|
.kernelSize(k, k)
|
||||||
.stride(s, s)
|
.stride(s, s)
|
||||||
.dilation(d, d)
|
.dilation(d, d)
|
||||||
.dataFormat(format)
|
.convFormat(format)
|
||||||
.nIn(inputDepth)
|
.nIn(inputDepth)
|
||||||
.nOut(2)
|
.nOut(2)
|
||||||
.build());
|
.build());
|
||||||
|
@ -1642,7 +1642,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
.kernelSize(k, k)
|
.kernelSize(k, k)
|
||||||
.stride(s, s)
|
.stride(s, s)
|
||||||
.dilation(d, d)
|
.dilation(d, d)
|
||||||
.dataFormat(format)
|
.convFormat(format)
|
||||||
.build());
|
.build());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1732,14 +1732,14 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
.list()
|
.list()
|
||||||
.layer(
|
.layer(
|
||||||
ConvolutionLayer.builder(kernel, stride, padding)
|
ConvolutionLayer.builder(kernel, stride, padding)
|
||||||
.dataFormat(format)
|
.convFormat(format)
|
||||||
.nIn(inputDepth)
|
.nIn(inputDepth)
|
||||||
.nOut(2)
|
.nOut(2)
|
||||||
.build()) // output: (6-2+0)/1+1 = 5
|
.build()) // output: (6-2+0)/1+1 = 5
|
||||||
.layer(Cropping2D.builder(crop).dataFormat(format).build())
|
.layer(Cropping2D.builder(crop).dataFormat(format).build())
|
||||||
.layer(
|
.layer(
|
||||||
ConvolutionLayer.builder(kernel, stride, padding)
|
ConvolutionLayer.builder(kernel, stride, padding)
|
||||||
.dataFormat(format)
|
.convFormat(format)
|
||||||
.nIn(2)
|
.nIn(2)
|
||||||
.nOut(2)
|
.nOut(2)
|
||||||
.build())
|
.build())
|
||||||
|
@ -1857,7 +1857,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
|
||||||
.stride(1, 1)
|
.stride(1, 1)
|
||||||
.nIn(nIn)
|
.nIn(nIn)
|
||||||
.nOut(nIn)
|
.nOut(nIn)
|
||||||
.dataFormat(format)
|
.convFormat(format)
|
||||||
.build())
|
.build())
|
||||||
.layer(
|
.layer(
|
||||||
DepthwiseConvolution2D.builder()
|
DepthwiseConvolution2D.builder()
|
||||||
|
|
|
@ -82,7 +82,7 @@ public class CapsnetGradientCheckTest extends BaseDL4JTest {
|
||||||
.seed(123)
|
.seed(123)
|
||||||
.updater(new NoOp())
|
.updater(new NoOp())
|
||||||
.dist(new UniformDistribution(-6, 6))
|
.dist(new UniformDistribution(-6, 6))
|
||||||
.layer(new PrimaryCapsules.Builder(primaryCapsDim, primarpCapsChannel)
|
.layer(PrimaryCapsules.builder(primaryCapsDim, primarpCapsChannel)
|
||||||
.kernelSize(3, 3)
|
.kernelSize(3, 3)
|
||||||
.stride(2, 2)
|
.stride(2, 2)
|
||||||
.build())
|
.build())
|
||||||
|
|
|
@ -131,7 +131,7 @@ public class GlobalPoolingGradientCheckTests extends BaseDL4JTest {
|
||||||
.updater(new NoOp())
|
.updater(new NoOp())
|
||||||
.dist(new NormalDistribution(0, 1.0)).seed(12345L).list()
|
.dist(new NormalDistribution(0, 1.0)).seed(12345L).list()
|
||||||
.layer(0, ConvolutionLayer.builder().kernelSize(2, 2).stride(1, 1)
|
.layer(0, ConvolutionLayer.builder().kernelSize(2, 2).stride(1, 1)
|
||||||
.dataFormat(nchw ? CNN2DFormat.NCHW : CNN2DFormat.NHWC)
|
.convFormat(nchw ? CNN2DFormat.NCHW : CNN2DFormat.NHWC)
|
||||||
.nOut(layerDepth)
|
.nOut(layerDepth)
|
||||||
.build())
|
.build())
|
||||||
.layer(1, GlobalPoolingLayer.builder().poolingType(pt).build())
|
.layer(1, GlobalPoolingLayer.builder().poolingType(pt).build())
|
||||||
|
|
|
@ -345,10 +345,10 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest {
|
||||||
.dist(new NormalDistribution(0, 0.1))
|
.dist(new NormalDistribution(0, 0.1))
|
||||||
.updater(new NoOp()).graphBuilder().addInputs("input")
|
.updater(new NoOp()).graphBuilder().addInputs("input")
|
||||||
.addLayer("l1", ConvolutionLayer.builder().kernelSize(2, 2).stride(1, 1).padding(0, 0)
|
.addLayer("l1", ConvolutionLayer.builder().kernelSize(2, 2).stride(1, 1).padding(0, 0)
|
||||||
.dataFormat(format)
|
.convFormat(format)
|
||||||
.nIn(2).nOut(2).activation(Activation.TANH).build(), "input")
|
.nIn(2).nOut(2).activation(Activation.TANH).build(), "input")
|
||||||
.addLayer("l2", ConvolutionLayer.builder().kernelSize(2, 2).stride(1, 1)
|
.addLayer("l2", ConvolutionLayer.builder().kernelSize(2, 2).stride(1, 1)
|
||||||
.padding(0, 0).dataFormat(format)
|
.padding(0, 0).convFormat(format)
|
||||||
.nIn(2).nOut(2).activation(Activation.TANH).build(), "input")
|
.nIn(2).nOut(2).activation(Activation.TANH).build(), "input")
|
||||||
.addVertex("merge", new MergeVertex(), "l1", "l2")
|
.addVertex("merge", new MergeVertex(), "l1", "l2")
|
||||||
.addLayer("outputLayer",
|
.addLayer("outputLayer",
|
||||||
|
|
|
@ -116,7 +116,7 @@ public class RnnGradientChecks extends BaseDL4JTest {
|
||||||
.layer(Bidirectional.builder(m,
|
.layer(Bidirectional.builder(m,
|
||||||
(simple ?
|
(simple ?
|
||||||
SimpleRnn.builder().nIn(3).nOut(3).hasLayerNorm(hasLayerNorm).build() :
|
SimpleRnn.builder().nIn(3).nOut(3).hasLayerNorm(hasLayerNorm).build() :
|
||||||
LSTM.builder().nIn(3).nOut(3).build())))
|
LSTM.builder().nIn(3).nOut(3).build())).build())
|
||||||
.layer(RnnOutputLayer.builder().nOut(nOut).activation(Activation.SOFTMAX).build())
|
.layer(RnnOutputLayer.builder().nOut(nOut).activation(Activation.SOFTMAX).build())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
|
|
@ -115,12 +115,11 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
|
||||||
.activation(a)
|
.activation(a)
|
||||||
.l1(l1[i]).l2(l2[i])
|
.l1(l1[i]).l2(l2[i])
|
||||||
.convolutionMode(ConvolutionMode.Same)
|
.convolutionMode(ConvolutionMode.Same)
|
||||||
.list()
|
|
||||||
.layer(ConvolutionLayer.builder().kernelSize(2, 2).stride(1, 1)
|
.layer(ConvolutionLayer.builder().kernelSize(2, 2).stride(1, 1)
|
||||||
.dataFormat(format)
|
.convFormat(format)
|
||||||
.nIn(depthIn).nOut(yoloDepth).build())//output: (5-2+0)/1+1 = 4
|
.nIn(depthIn).nOut(yoloDepth).build())//output: (5-2+0)/1+1 = 4
|
||||||
.layer(new Yolo2OutputLayer.Builder()
|
.layer(Yolo2OutputLayer.builder()
|
||||||
.boundingBoxPriors(bbPrior)
|
.boundingBoxes(bbPrior)
|
||||||
.build())
|
.build())
|
||||||
.inputType(InputType.convolutional(h, w, depthIn, format))
|
.inputType(InputType.convolutional(h, w, depthIn, format))
|
||||||
.build();
|
.build();
|
||||||
|
@ -237,8 +236,8 @@ public class YoloGradientCheckTests extends BaseDL4JTest {
|
||||||
.layer(ConvolutionLayer.builder().kernelSize(3,3).stride(1,1).nOut(4).build())
|
.layer(ConvolutionLayer.builder().kernelSize(3,3).stride(1,1).nOut(4).build())
|
||||||
.layer(SubsamplingLayer.builder().kernelSize(2,2).stride(2,2).build())
|
.layer(SubsamplingLayer.builder().kernelSize(2,2).stride(2,2).build())
|
||||||
.layer(ConvolutionLayer.builder().activation(Activation.IDENTITY).kernelSize(3,3).stride(1,1).nOut(depthOut).build())
|
.layer(ConvolutionLayer.builder().activation(Activation.IDENTITY).kernelSize(3,3).stride(1,1).nOut(depthOut).build())
|
||||||
.layer(new Yolo2OutputLayer.Builder()
|
.layer(Yolo2OutputLayer.builder()
|
||||||
.boundingBoxPriors(bbPriors)
|
.boundingBoxes(bbPriors)
|
||||||
.build())
|
.build())
|
||||||
.inputType(InputType.convolutional(h,w,c))
|
.inputType(InputType.convolutional(h,w,c))
|
||||||
.build();
|
.build();
|
||||||
|
|
|
@ -437,7 +437,7 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest {
|
||||||
.layer(DenseLayer.builder().nIn(10).nOut(10).build())
|
.layer(DenseLayer.builder().nIn(10).nOut(10).build())
|
||||||
.layer(!lossLayer ? OutputLayer.builder().nIn(10).nOut(nOut[i])
|
.layer(!lossLayer ? OutputLayer.builder().nIn(10).nOut(nOut[i])
|
||||||
.activation(activations[i]).lossFunction(lf[i]).build()
|
.activation(activations[i]).lossFunction(lf[i]).build()
|
||||||
: LossLayer.builder().lossFunction().activation(activations[i]).lossFunction(lf[i])
|
: LossLayer.builder().activation(activations[i]).lossFunction(lf[i].getILossFunction())
|
||||||
.build())
|
.build())
|
||||||
.validateOutputLayerConfig(validate)
|
.validateOutputLayerConfig(validate)
|
||||||
.build();
|
.build();
|
||||||
|
|
|
@ -48,6 +48,7 @@ import org.nd4j.linalg.learning.config.RmsProp;
|
||||||
import org.nd4j.linalg.learning.config.Sgd;
|
import org.nd4j.linalg.learning.config.Sgd;
|
||||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
@ -230,7 +231,8 @@ public class TestConstraints extends BaseDL4JTest {
|
||||||
.biasInit(0.2)
|
.biasInit(0.2)
|
||||||
|
|
||||||
.layer(DenseLayer.builder().nIn(12).nOut(10)
|
.layer(DenseLayer.builder().nIn(12).nOut(10)
|
||||||
.constrainAllParameters(lc).build())
|
.allParamConstraints(List.of(lc))
|
||||||
|
.build())
|
||||||
.layer(OutputLayer.builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(8).build())
|
.layer(OutputLayer.builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(8).build())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
|
|
@ -201,21 +201,21 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
|
||||||
.addInputs("input1", "input2", "input3")
|
.addInputs("input1", "input2", "input3")
|
||||||
.addLayer("dense1",
|
.addLayer("dense1",
|
||||||
DenseLayer.builder().nIn(featuresz).nOut(midsz)
|
DenseLayer.builder().nIn(featuresz).nOut(midsz)
|
||||||
.activation(new ActivationTanH()).build(),
|
.activation(Activation.TANH).build(),
|
||||||
"input1")
|
"input1")
|
||||||
.addLayer("dense2",
|
.addLayer("dense2",
|
||||||
DenseLayer.builder().nIn(featuresz).nOut(midsz)
|
DenseLayer.builder().nIn(featuresz).nOut(midsz)
|
||||||
.activation(new ActivationTanH()).build(),
|
.activation(Activation.TANH).build(),
|
||||||
"input2")
|
"input2")
|
||||||
.addLayer("dense3",
|
.addLayer("dense3",
|
||||||
DenseLayer.builder().nIn(featuresz).nOut(midsz)
|
DenseLayer.builder().nIn(featuresz).nOut(midsz)
|
||||||
.activation(new ActivationTanH()).build(),
|
.activation(Activation.TANH).build(),
|
||||||
"input3")
|
"input3")
|
||||||
.addVertex("elementwiseAdd", new ElementWiseVertex(ElementWiseVertex.Op.Add), "dense1",
|
.addVertex("elementwiseAdd", new ElementWiseVertex(ElementWiseVertex.Op.Add), "dense1",
|
||||||
"dense2", "dense3")
|
"dense2", "dense3")
|
||||||
.addLayer("output",
|
.addLayer("output",
|
||||||
OutputLayer.builder().nIn(midsz).nOut(outputsz)
|
OutputLayer.builder().nIn(midsz).nOut(outputsz)
|
||||||
.activation(new ActivationSigmoid())
|
.activation(Activation.SIGMOID)
|
||||||
.lossFunction(LossFunction.MSE).build(),
|
.lossFunction(LossFunction.MSE).build(),
|
||||||
"elementwiseAdd")
|
"elementwiseAdd")
|
||||||
.setOutputs("output").build();
|
.setOutputs("output").build();
|
||||||
|
@ -377,21 +377,21 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
|
||||||
.addInputs("input1", "input2", "input3")
|
.addInputs("input1", "input2", "input3")
|
||||||
.addLayer("dense1",
|
.addLayer("dense1",
|
||||||
DenseLayer.builder().nIn(featuresz).nOut(midsz)
|
DenseLayer.builder().nIn(featuresz).nOut(midsz)
|
||||||
.activation(new ActivationTanH()).build(),
|
.activation(Activation.TANH).build(),
|
||||||
"input1")
|
"input1")
|
||||||
.addLayer("dense2",
|
.addLayer("dense2",
|
||||||
DenseLayer.builder().nIn(featuresz).nOut(midsz)
|
DenseLayer.builder().nIn(featuresz).nOut(midsz)
|
||||||
.activation(new ActivationTanH()).build(),
|
.activation(Activation.TANH).build(),
|
||||||
"input2")
|
"input2")
|
||||||
.addLayer("dense3",
|
.addLayer("dense3",
|
||||||
DenseLayer.builder().nIn(featuresz).nOut(midsz)
|
DenseLayer.builder().nIn(featuresz).nOut(midsz)
|
||||||
.activation(new ActivationTanH()).build(),
|
.activation(Activation.TANH).build(),
|
||||||
"input3")
|
"input3")
|
||||||
.addVertex("elementwiseProduct", new ElementWiseVertex(ElementWiseVertex.Op.Product), "dense1",
|
.addVertex("elementwiseProduct", new ElementWiseVertex(ElementWiseVertex.Op.Product), "dense1",
|
||||||
"dense2", "dense3")
|
"dense2", "dense3")
|
||||||
.addLayer("output",
|
.addLayer("output",
|
||||||
OutputLayer.builder().nIn(midsz).nOut(outputsz)
|
OutputLayer.builder().nIn(midsz).nOut(outputsz)
|
||||||
.activation(new ActivationSigmoid())
|
.activation(Activation.SIGMOID)
|
||||||
.lossFunction(LossFunction.MSE).build(),
|
.lossFunction(LossFunction.MSE).build(),
|
||||||
"elementwiseProduct")
|
"elementwiseProduct")
|
||||||
.setOutputs("output").build();
|
.setOutputs("output").build();
|
||||||
|
@ -552,17 +552,17 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
|
||||||
.addInputs("input1", "input2")
|
.addInputs("input1", "input2")
|
||||||
.addLayer("dense1",
|
.addLayer("dense1",
|
||||||
DenseLayer.builder().nIn(featuresz).nOut(midsz)
|
DenseLayer.builder().nIn(featuresz).nOut(midsz)
|
||||||
.activation(new ActivationTanH()).build(),
|
.activation(Activation.TANH).build(),
|
||||||
"input1")
|
"input1")
|
||||||
.addLayer("dense2",
|
.addLayer("dense2",
|
||||||
DenseLayer.builder().nIn(featuresz).nOut(midsz)
|
DenseLayer.builder().nIn(featuresz).nOut(midsz)
|
||||||
.activation(new ActivationTanH()).build(),
|
.activation(Activation.TANH).build(),
|
||||||
"input2")
|
"input2")
|
||||||
.addVertex("elementwiseSubtract", new ElementWiseVertex(ElementWiseVertex.Op.Subtract),
|
.addVertex("elementwiseSubtract", new ElementWiseVertex(ElementWiseVertex.Op.Subtract),
|
||||||
"dense1", "dense2")
|
"dense1", "dense2")
|
||||||
.addLayer("output",
|
.addLayer("output",
|
||||||
OutputLayer.builder().nIn(midsz).nOut(outputsz)
|
OutputLayer.builder().nIn(midsz).nOut(outputsz)
|
||||||
.activation(new ActivationSigmoid())
|
.activation(Activation.SIGMOID)
|
||||||
.lossFunction(LossFunction.MSE).build(),
|
.lossFunction(LossFunction.MSE).build(),
|
||||||
"elementwiseSubtract")
|
"elementwiseSubtract")
|
||||||
.setOutputs("output").build();
|
.setOutputs("output").build();
|
||||||
|
|
|
@ -493,7 +493,7 @@ public class DTypeTests extends BaseDL4JTest {
|
||||||
secondLast = ConvolutionLayer.builder().kernelSize(2, 2).stride(1, 1).nOut(3).activation(Activation.TANH).build();
|
secondLast = ConvolutionLayer.builder().kernelSize(2, 2).stride(1, 1).nOut(3).activation(Activation.TANH).build();
|
||||||
break;
|
break;
|
||||||
case 4:
|
case 4:
|
||||||
ol = new Yolo2OutputLayer.Builder().boundingBoxPriors(Nd4j.create(new double[][]{{1.0, 1.0}, {2.0, 2.0}}).castTo(networkDtype)).build();
|
ol = Yolo2OutputLayer.builder().boundingBoxes(Nd4j.create(new double[][]{{1.0, 1.0}, {2.0, 2.0}}).castTo(networkDtype)).build();
|
||||||
secondLast = ConvolutionLayer.builder().kernelSize(2, 2).stride(1, 1).nOut(14).activation(Activation.TANH).build();
|
secondLast = ConvolutionLayer.builder().kernelSize(2, 2).stride(1, 1).nOut(14).activation(Activation.TANH).build();
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
|
@ -817,8 +817,8 @@ public class DTypeTests extends BaseDL4JTest {
|
||||||
.convolutionMode(ConvolutionMode.Same)
|
.convolutionMode(ConvolutionMode.Same)
|
||||||
.updater(new Adam(1e-2))
|
.updater(new Adam(1e-2))
|
||||||
.list()
|
.list()
|
||||||
.layer(new SpaceToBatchLayer.Builder().blocks(1, 1).build())
|
.layer(SpaceToBatchLayer.builder().blockSize(1, 1).build())
|
||||||
.layer(new SpaceToDepthLayer.Builder().blocks(2).build())
|
.layer(SpaceToDepthLayer.builder().blockSize(2).build())
|
||||||
.layer(OutputLayer.builder().nOut(10).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
.layer(OutputLayer.builder().nOut(10).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
||||||
.inputType(InputType.convolutional(28, 28, 5))
|
.inputType(InputType.convolutional(28, 28, 5))
|
||||||
.build();
|
.build();
|
||||||
|
@ -907,7 +907,7 @@ public class DTypeTests extends BaseDL4JTest {
|
||||||
.layer(Bidirectional.builder(LSTM.builder().nIn(5).nOut(5).activation(Activation.TANH).build()).build())
|
.layer(Bidirectional.builder(LSTM.builder().nIn(5).nOut(5).activation(Activation.TANH).build()).build())
|
||||||
.layer(new TimeDistributed(DenseLayer.builder().nIn(10).nOut(5).activation(Activation.TANH).build()))
|
.layer(new TimeDistributed(DenseLayer.builder().nIn(10).nOut(5).activation(Activation.TANH).build()))
|
||||||
.layer(SimpleRnn.builder().nIn(5).nOut(5).build())
|
.layer(SimpleRnn.builder().nIn(5).nOut(5).build())
|
||||||
.layer(new MaskZeroLayer.Builder().underlying(SimpleRnn.builder().nIn(5).nOut(5).build()).maskValue(0.0).build())
|
.layer(MaskZeroLayer.builder().underlying(SimpleRnn.builder().nIn(5).nOut(5).build()).maskingValue(0.0).build())
|
||||||
.layer(secondLast)
|
.layer(secondLast)
|
||||||
.layer(ol)
|
.layer(ol)
|
||||||
.build();
|
.build();
|
||||||
|
@ -986,7 +986,7 @@ public class DTypeTests extends BaseDL4JTest {
|
||||||
.updater(new NoOp())
|
.updater(new NoOp())
|
||||||
.dist(new UniformDistribution(-6, 6))
|
.dist(new UniformDistribution(-6, 6))
|
||||||
|
|
||||||
.layer(new PrimaryCapsules.Builder(primaryCapsDim, primarpCapsChannel)
|
.layer(PrimaryCapsules.builder(primaryCapsDim, primarpCapsChannel)
|
||||||
.kernelSize(3, 3)
|
.kernelSize(3, 3)
|
||||||
.stride(2, 2)
|
.stride(2, 2)
|
||||||
.build())
|
.build())
|
||||||
|
@ -1400,9 +1400,9 @@ public class DTypeTests extends BaseDL4JTest {
|
||||||
.weightInit(WeightInit.XAVIER)
|
.weightInit(WeightInit.XAVIER)
|
||||||
.list()
|
.list()
|
||||||
.layer(LSTM.builder().nOut(layerSize).build())
|
.layer(LSTM.builder().nOut(layerSize).build())
|
||||||
.layer(new SelfAttentionLayer.Builder().nOut(8).nHeads(2).projectInput(true).build())
|
.layer(SelfAttentionLayer.builder().nOut(8).nHeads(2).projectInput(true).build())
|
||||||
.layer(new LearnedSelfAttentionLayer.Builder().nOut(8).nHeads(2).nQueries(numQueries).projectInput(true).build())
|
.layer(LearnedSelfAttentionLayer.builder().nOut(8).nHeads(2).nQueries(numQueries).projectInput(true).build())
|
||||||
.layer(new RecurrentAttentionLayer.Builder().nIn(layerSize).nOut(layerSize).nHeads(1).projectInput(false).hasBias(false).build())
|
.layer(RecurrentAttentionLayer.builder().nIn(layerSize).nOut(layerSize).nHeads(1).projectInput(false).hasBias(false).build())
|
||||||
.layer(GlobalPoolingLayer.builder().poolingType(PoolingType.MAX).build())
|
.layer(GlobalPoolingLayer.builder().poolingType(PoolingType.MAX).build())
|
||||||
.layer(OutputLayer.builder().nOut(nOut).activation(Activation.SOFTMAX)
|
.layer(OutputLayer.builder().nOut(nOut).activation(Activation.SOFTMAX)
|
||||||
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
||||||
|
|
|
@ -161,7 +161,7 @@ public class TestCompGraphCNN extends BaseDL4JTest {
|
||||||
imageHeight))
|
imageHeight))
|
||||||
.addLayer("conv1", ConvolutionLayer.builder()
|
.addLayer("conv1", ConvolutionLayer.builder()
|
||||||
.kernelSize(kernelHeight, kernelWidth).stride(1, 1)
|
.kernelSize(kernelHeight, kernelWidth).stride(1, 1)
|
||||||
.dataFormat(CNN2DFormat.NCHW)
|
.convFormat(CNN2DFormat.NCHW)
|
||||||
.nIn(nChannels).nOut(2).weightInit(WeightInit.XAVIER)
|
.nIn(nChannels).nOut(2).weightInit(WeightInit.XAVIER)
|
||||||
.activation(Activation.RELU).build(), "input")
|
.activation(Activation.RELU).build(), "input")
|
||||||
.addLayer("pool1",
|
.addLayer("pool1",
|
||||||
|
|
|
@ -1163,7 +1163,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
|
||||||
"act")
|
"act")
|
||||||
.addLayer("drop", DropoutLayer.builder(0.5).build(), "pool")
|
.addLayer("drop", DropoutLayer.builder(0.5).build(), "pool")
|
||||||
.addLayer("dense", DenseLayer.builder().nIn(1).nOut(1).build(), "drop")
|
.addLayer("dense", DenseLayer.builder().nIn(1).nOut(1).build(), "drop")
|
||||||
.addLayer("loss", LossLayer.builder().lossFunction(LossFunctions.LossFunction.MCXENT)
|
.addLayer("loss", LossLayer.builder().lossFunction(LossFunctions.LossFunction.MCXENT.getILossFunction())
|
||||||
.build(), "dense")
|
.build(), "dense")
|
||||||
.allowDisconnected(true)
|
.allowDisconnected(true)
|
||||||
.setOutputs("loss").build();
|
.setOutputs("loss").build();
|
||||||
|
@ -1457,7 +1457,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
|
||||||
.graphBuilder()
|
.graphBuilder()
|
||||||
.addInputs("in")
|
.addInputs("in")
|
||||||
.layer("0", SubsamplingLayer.builder().kernelSize(2,2).stride(2,2).build(), "in")
|
.layer("0", SubsamplingLayer.builder().kernelSize(2,2).stride(2,2).build(), "in")
|
||||||
.layer("1", LossLayer.builder().lossFunction().activation(Activation.SIGMOID).lossFunction(LossFunctions.LossFunction.MSE).build(), "0")
|
.layer("1", LossLayer.builder().activation(Activation.SIGMOID).lossFunction(LossFunctions.LossFunction.MSE.getILossFunction()).build(), "0")
|
||||||
.setOutputs("1")
|
.setOutputs("1")
|
||||||
.setInputTypes(InputType.convolutionalFlat(28,28,1))
|
.setInputTypes(InputType.convolutionalFlat(28,28,1))
|
||||||
.build();
|
.build();
|
||||||
|
@ -1791,7 +1791,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
|
||||||
.nIn(10).nOut(5)
|
.nIn(10).nOut(5)
|
||||||
.activation(Activation.TANH)
|
.activation(Activation.TANH)
|
||||||
.dropOut(new GaussianNoise(0.05))
|
.dropOut(new GaussianNoise(0.05))
|
||||||
.build())
|
.build()).build()
|
||||||
,"merge")
|
,"merge")
|
||||||
.addLayer("out1",
|
.addLayer("out1",
|
||||||
RnnOutputLayer.builder().activation(Activation.SOFTMAX)
|
RnnOutputLayer.builder().activation(Activation.SOFTMAX)
|
||||||
|
@ -1986,10 +1986,10 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
|
||||||
.updater(new Adam())
|
.updater(new Adam())
|
||||||
.graphBuilder()
|
.graphBuilder()
|
||||||
.addInputs("x_emb")
|
.addInputs("x_emb")
|
||||||
.addLayer("agg_lstm", Bidirectional.builder(CONCAT, LSTM.builder().nOut(hiddenSize/2).build()), "x_emb")
|
.addLayer("agg_lstm", Bidirectional.builder(CONCAT, LSTM.builder().nOut(hiddenSize/2).build()).build(), "x_emb")
|
||||||
.addLayer("agg_att", DenseLayer.builder().nIn(100).nOut(1).activation(Activation.SOFTMAX).build(), "agg_lstm")
|
.addLayer("agg_att", DenseLayer.builder().nIn(100).nOut(1).activation(Activation.SOFTMAX).build(), "agg_lstm")
|
||||||
.addVertex("att", new PreprocessorVertex(new ComposableInputPreProcessor(new FeedForwardToRnnPreProcessor(), new PermutePreprocessor(0,2,1), new RnnToFeedForwardPreProcessor())), "agg_att")
|
.addVertex("att", new PreprocessorVertex(new ComposableInputPreProcessor(new FeedForwardToRnnPreProcessor(), new PermutePreprocessor(0,2,1), new RnnToFeedForwardPreProcessor())), "agg_att")
|
||||||
.addLayer("att_repeat", new RepeatVector.Builder(hiddenSize).build(),"att")
|
.addLayer("att_repeat", RepeatVector.builder().repetitionFactor(hiddenSize).build(),"att")
|
||||||
.addVertex("att_trans", new PreprocessorVertex(new PermutePreprocessor(0, 2, 1)), "att_repeat")
|
.addVertex("att_trans", new PreprocessorVertex(new PermutePreprocessor(0, 2, 1)), "att_repeat")
|
||||||
.addVertex("mult", new ElementWiseVertex(ElementWiseVertex.Op.Product), "agg_lstm", "att_trans")
|
.addVertex("mult", new ElementWiseVertex(ElementWiseVertex.Op.Product), "agg_lstm", "att_trans")
|
||||||
.addLayer("sum", GlobalPoolingLayer.builder().build(), "mult")
|
.addLayer("sum", GlobalPoolingLayer.builder().build(), "mult")
|
||||||
|
@ -2197,16 +2197,16 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
|
||||||
.addInputs("in")
|
.addInputs("in")
|
||||||
.layer("l0", ConvolutionLayer.builder()
|
.layer("l0", ConvolutionLayer.builder()
|
||||||
.nOut(16)
|
.nOut(16)
|
||||||
.dataFormat(CNN2DFormat.NHWC)
|
.convFormat(CNN2DFormat.NHWC)
|
||||||
.kernelSize(2,2).stride(1,1)
|
.kernelSize(2,2).stride(1,1)
|
||||||
.build(), "in")
|
.build(), "in")
|
||||||
.layer("l1", ConvolutionLayer.builder()
|
.layer("l1", ConvolutionLayer.builder()
|
||||||
.nOut(8)
|
.nOut(8)
|
||||||
.dataFormat(CNN2DFormat.NHWC)
|
.convFormat(CNN2DFormat.NHWC)
|
||||||
.kernelSize(2,2).stride(1,1)
|
.kernelSize(2,2).stride(1,1)
|
||||||
.build(), "in")
|
.build(), "in")
|
||||||
.addVertex("merge", new MergeVertex(), "l0", "l1")
|
.addVertex("merge", new MergeVertex(), "l0", "l1")
|
||||||
.layer("out", CnnLossLayer.builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build(), "merge")
|
.layer("out", CnnLossLayer.builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE.getILossFunction()).build(), "merge")
|
||||||
.setOutputs("out")
|
.setOutputs("out")
|
||||||
.setInputTypes(InputType.convolutional(32, 32, 3, CNN2DFormat.NHWC))
|
.setInputTypes(InputType.convolutional(32, 32, 3, CNN2DFormat.NHWC))
|
||||||
.build();
|
.build();
|
||||||
|
|
|
@ -357,7 +357,7 @@ public class ActivationLayerTest extends BaseDL4JTest {
|
||||||
.activation(Activation.RATIONALTANH)
|
.activation(Activation.RATIONALTANH)
|
||||||
|
|
||||||
.layer(DenseLayer.builder().nIn(10).nOut(10).build())
|
.layer(DenseLayer.builder().nIn(10).nOut(10).build())
|
||||||
.layer(ActivationLayer.builder())
|
.layer(ActivationLayer.builder().build())
|
||||||
.layer(ActivationLayer.builder().build())
|
.layer(ActivationLayer.builder().build())
|
||||||
.layer(ActivationLayer.builder().activation(Activation.ELU).build())
|
.layer(ActivationLayer.builder().activation(Activation.ELU).build())
|
||||||
.layer(
|
.layer(
|
||||||
|
@ -404,7 +404,7 @@ public class ActivationLayerTest extends BaseDL4JTest {
|
||||||
.graphBuilder()
|
.graphBuilder()
|
||||||
.addInputs("in")
|
.addInputs("in")
|
||||||
.addLayer("0", DenseLayer.builder().nIn(10).nOut(10).build(), "in")
|
.addLayer("0", DenseLayer.builder().nIn(10).nOut(10).build(), "in")
|
||||||
.addLayer("1", ActivationLayer.builder(), "0")
|
.addLayer("1", ActivationLayer.builder().build(), "0")
|
||||||
.addLayer("2", ActivationLayer.builder().build(), "1")
|
.addLayer("2", ActivationLayer.builder().build(), "1")
|
||||||
.addLayer("3", ActivationLayer.builder().activation(Activation.ELU).build(), "2")
|
.addLayer("3", ActivationLayer.builder().activation(Activation.ELU).build(), "2")
|
||||||
.addLayer(
|
.addLayer(
|
||||||
|
|
|
@ -63,7 +63,7 @@ public class CapsNetMNISTTest extends BaseDL4JTest {
|
||||||
.kernelSize(9, 9)
|
.kernelSize(9, 9)
|
||||||
.stride(3, 3)
|
.stride(3, 3)
|
||||||
.build())
|
.build())
|
||||||
.layer(new PrimaryCapsules.Builder(8, 8)
|
.layer(PrimaryCapsules.builder(8, 8)
|
||||||
.kernelSize(7, 7)
|
.kernelSize(7, 7)
|
||||||
.stride(2, 2)
|
.stride(2, 2)
|
||||||
.build())
|
.build())
|
||||||
|
|
|
@ -44,7 +44,7 @@ public class PrimaryCapsulesTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testOutputType(){
|
public void testOutputType(){
|
||||||
PrimaryCapsules layer = new PrimaryCapsules.Builder(8, 8)
|
PrimaryCapsules layer = PrimaryCapsules.builder(8, 8)
|
||||||
.kernelSize(7, 7)
|
.kernelSize(7, 7)
|
||||||
.stride(2, 2)
|
.stride(2, 2)
|
||||||
.build();
|
.build();
|
||||||
|
@ -57,7 +57,7 @@ public class PrimaryCapsulesTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testInputType(){
|
public void testInputType(){
|
||||||
PrimaryCapsules layer = new PrimaryCapsules.Builder(8, 8)
|
PrimaryCapsules layer = PrimaryCapsules.builder(8, 8)
|
||||||
.kernelSize(7, 7)
|
.kernelSize(7, 7)
|
||||||
.stride(2, 2)
|
.stride(2, 2)
|
||||||
.build();
|
.build();
|
||||||
|
@ -72,7 +72,7 @@ public class PrimaryCapsulesTest extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testConfig(){
|
public void testConfig(){
|
||||||
PrimaryCapsules layer1 = new PrimaryCapsules.Builder(8, 10)
|
PrimaryCapsules layer1 = PrimaryCapsules.builder(8, 10)
|
||||||
.kernelSize(5, 5)
|
.kernelSize(5, 5)
|
||||||
.stride(4, 4)
|
.stride(4, 4)
|
||||||
.useLeakyReLU(0.5)
|
.useLeakyReLU(0.5)
|
||||||
|
@ -84,22 +84,22 @@ public class PrimaryCapsulesTest extends BaseDL4JTest {
|
||||||
assertArrayEquals(new int[]{4, 4}, layer1.getStride());
|
assertArrayEquals(new int[]{4, 4}, layer1.getStride());
|
||||||
assertArrayEquals(new int[]{0, 0}, layer1.getPadding());
|
assertArrayEquals(new int[]{0, 0}, layer1.getPadding());
|
||||||
assertArrayEquals(new int[]{1, 1}, layer1.getDilation());
|
assertArrayEquals(new int[]{1, 1}, layer1.getDilation());
|
||||||
assertTrue(layer1.isUseRelu());
|
assertTrue(layer1.isUseRelU());
|
||||||
assertEquals(0.5, layer1.getLeak(), 0.001);
|
assertEquals(0.5, layer1.getUseLeakyReLU(), 0.001);
|
||||||
|
|
||||||
PrimaryCapsules layer2 = new PrimaryCapsules.Builder(8, 10)
|
PrimaryCapsules layer2 = PrimaryCapsules.builder(8, 10)
|
||||||
.kernelSize(5, 5)
|
.kernelSize(5, 5)
|
||||||
.stride(4, 4)
|
.stride(4, 4)
|
||||||
.build();
|
.build();
|
||||||
assertFalse(layer2.isUseRelu());
|
assertFalse(layer2.isUseRelU());
|
||||||
|
|
||||||
PrimaryCapsules layer3 = new PrimaryCapsules.Builder(8, 10)
|
PrimaryCapsules layer3 = PrimaryCapsules.builder(8, 10)
|
||||||
.kernelSize(5, 5)
|
.kernelSize(5, 5)
|
||||||
.stride(4, 4)
|
.stride(4, 4)
|
||||||
.useReLU()
|
.useReLU()
|
||||||
.build();
|
.build();
|
||||||
assertTrue(layer3.isUseRelu());
|
assertTrue(layer3.isUseRelU());
|
||||||
assertEquals(0, layer3.getLeak(), 0.001);
|
assertEquals(0, layer3.getUseLeakyReLU(), 0.001);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -108,7 +108,7 @@ public class PrimaryCapsulesTest extends BaseDL4JTest {
|
||||||
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
||||||
.seed(123)
|
.seed(123)
|
||||||
.list()
|
.list()
|
||||||
.layer(new PrimaryCapsules.Builder(8, 10)
|
.layer(PrimaryCapsules.builder(8, 10)
|
||||||
.kernelSize(5, 5)
|
.kernelSize(5, 5)
|
||||||
.stride(4, 4)
|
.stride(4, 4)
|
||||||
.useLeakyReLU(0.5)
|
.useLeakyReLU(0.5)
|
||||||
|
|
|
@ -561,7 +561,7 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
||||||
.kernelSize(3, 3)
|
.kernelSize(3, 3)
|
||||||
.stride(2, 2)
|
.stride(2, 2)
|
||||||
.activation(Activation.TANH)
|
.activation(Activation.TANH)
|
||||||
.dataFormat(format)
|
.convFormat(format)
|
||||||
.nOut(3)
|
.nOut(3)
|
||||||
.helperAllowFallback(false)
|
.helperAllowFallback(false)
|
||||||
.build(), format, cm, null);
|
.build(), format, cm, null);
|
||||||
|
@ -685,14 +685,14 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
||||||
return getNetWithLayer(Deconvolution2D.builder().nOut(2)
|
return getNetWithLayer(Deconvolution2D.builder().nOut(2)
|
||||||
.activation(Activation.TANH)
|
.activation(Activation.TANH)
|
||||||
.kernelSize(2,2)
|
.kernelSize(2,2)
|
||||||
.dataFormat(format)
|
.convFormat(format)
|
||||||
.stride(2,2)
|
.stride(2,2)
|
||||||
.build(), format, cm, null);
|
.build(), format, cm, null);
|
||||||
} else {
|
} else {
|
||||||
return getNetWithLayer(Deconvolution2D.builder().nOut(2)
|
return getNetWithLayer(Deconvolution2D.builder().nOut(2)
|
||||||
.activation(Activation.TANH)
|
.activation(Activation.TANH)
|
||||||
.kernelSize(2,2)
|
.kernelSize(2,2)
|
||||||
.dataFormat(format)
|
.convFormat(format)
|
||||||
.stride(2,2)
|
.stride(2,2)
|
||||||
.build(), format, cm, null);
|
.build(), format, cm, null);
|
||||||
}
|
}
|
||||||
|
@ -715,26 +715,26 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
||||||
|
|
||||||
private MultiLayerNetwork getSpaceToDepthNet(CNN2DFormat format, boolean setOnLayerAlso) {
|
private MultiLayerNetwork getSpaceToDepthNet(CNN2DFormat format, boolean setOnLayerAlso) {
|
||||||
if (setOnLayerAlso) {
|
if (setOnLayerAlso) {
|
||||||
return getNetWithLayer(new SpaceToDepthLayer.Builder()
|
return getNetWithLayer(SpaceToDepthLayer.builder()
|
||||||
.blocks(2)
|
.blockSize(2)
|
||||||
.dataFormat(format)
|
.dataFormat(format)
|
||||||
.build(), format, ConvolutionMode.Same, null);
|
.build(), format, ConvolutionMode.Same, null);
|
||||||
} else {
|
} else {
|
||||||
return getNetWithLayer(new SpaceToDepthLayer.Builder()
|
return getNetWithLayer(SpaceToDepthLayer.builder()
|
||||||
.blocks(2)
|
.blockSize(2)
|
||||||
.build(), format, ConvolutionMode.Same, null);
|
.build(), format, ConvolutionMode.Same, null);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private MultiLayerNetwork getSpaceToBatchNet(CNN2DFormat format, boolean setOnLayerAlso) {
|
private MultiLayerNetwork getSpaceToBatchNet(CNN2DFormat format, boolean setOnLayerAlso) {
|
||||||
if (setOnLayerAlso) {
|
if (setOnLayerAlso) {
|
||||||
return getNetWithLayer(new SpaceToBatchLayer.Builder()
|
return getNetWithLayer(SpaceToBatchLayer.builder()
|
||||||
.blocks(2, 2)
|
.blockSize(2, 2)
|
||||||
.dataFormat(format)
|
.dataFormat(format)
|
||||||
.build(), format, ConvolutionMode.Same, InputType.convolutional(16, 16, 3, format));
|
.build(), format, ConvolutionMode.Same, InputType.convolutional(16, 16, 3, format));
|
||||||
} else {
|
} else {
|
||||||
return getNetWithLayer(new SpaceToBatchLayer.Builder()
|
return getNetWithLayer(SpaceToBatchLayer.builder()
|
||||||
.blocks(2, 2)
|
.blockSize(2, 2)
|
||||||
.build(), format, ConvolutionMode.Same, InputType.convolutional(16, 16, 3, format));
|
.build(), format, ConvolutionMode.Same, InputType.convolutional(16, 16, 3, format));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -807,7 +807,7 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
||||||
.kernelSize(3, 3)
|
.kernelSize(3, 3)
|
||||||
.stride(2, 2)
|
.stride(2, 2)
|
||||||
.activation(Activation.TANH)
|
.activation(Activation.TANH)
|
||||||
.dataFormat(format)
|
.convFormat(format)
|
||||||
.nOut(3)
|
.nOut(3)
|
||||||
.helperAllowFallback(false)
|
.helperAllowFallback(false)
|
||||||
.build());
|
.build());
|
||||||
|
@ -988,7 +988,7 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
||||||
|
|
||||||
switch (i){
|
switch (i){
|
||||||
case 0:
|
case 0:
|
||||||
b.layer(ConvolutionLayer.builder().kernelSize(2,2).nIn(3).nOut(3).dataFormat(df).build());
|
b.layer(ConvolutionLayer.builder().kernelSize(2,2).nIn(3).nOut(3).convFormat(df).build());
|
||||||
b.inputType(InputType.convolutional(12,12,3,df));
|
b.inputType(InputType.convolutional(12,12,3,df));
|
||||||
break;
|
break;
|
||||||
case 1:
|
case 1:
|
||||||
|
@ -996,7 +996,7 @@ public class ConvDataFormatTests extends BaseDL4JTest {
|
||||||
b.inputType(InputType.convolutional(12,12,3,df));
|
b.inputType(InputType.convolutional(12,12,3,df));
|
||||||
break;
|
break;
|
||||||
case 2:
|
case 2:
|
||||||
b.layer(Deconvolution2D.builder().dataFormat(df).kernelSize(2,2).nIn(3).nOut(3).build());
|
b.layer(Deconvolution2D.builder().convFormat(df).kernelSize(2,2).nIn(3).nOut(3).build());
|
||||||
b.inputType(InputType.convolutional(12,12,3,df));
|
b.inputType(InputType.convolutional(12,12,3,df));
|
||||||
break;
|
break;
|
||||||
case 3:
|
case 3:
|
||||||
|
|
|
@ -27,6 +27,7 @@ import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
|
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
|
||||||
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||||
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
||||||
|
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
|
@ -370,7 +371,7 @@ public class ConvolutionLayerSetupTest extends BaseDL4JTest {
|
||||||
|
|
||||||
NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().list()
|
NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().list()
|
||||||
.layer(ConvolutionLayer.builder(2, 2).padding(0, 0).stride(2, 2).nIn(1).nOut(3).build()) //(28-2+0)/2+1 = 14
|
.layer(ConvolutionLayer.builder(2, 2).padding(0, 0).stride(2, 2).nIn(1).nOut(3).build()) //(28-2+0)/2+1 = 14
|
||||||
.layer(new SpaceToBatchLayer.Builder(blocks).build()) // Divide space dimensions by blocks, i.e. 14/2 = 7
|
.layer(SpaceToBatchLayer.builder(blocks).build()) // Divide space dimensions by blocks, i.e. 14/2 = 7
|
||||||
.layer(OutputLayer.builder().nOut(3).activation(Activation.SOFTMAX).build())
|
.layer(OutputLayer.builder().nOut(3).activation(Activation.SOFTMAX).build())
|
||||||
.inputType(InputType.convolutional(28, 28, 1));
|
.inputType(InputType.convolutional(28, 28, 1));
|
||||||
|
|
||||||
|
@ -389,11 +390,11 @@ public class ConvolutionLayerSetupTest extends BaseDL4JTest {
|
||||||
|
|
||||||
int blocks = 2;
|
int blocks = 2;
|
||||||
|
|
||||||
NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().list()
|
NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder()
|
||||||
//(28-2+0)/2+1 = 14 -> 14x14x3 out
|
//(28-2+0)/2+1 = 14 -> 14x14x3 out
|
||||||
.layer(ConvolutionLayer.builder(2, 2).padding(0, 0).stride(2, 2).nIn(1).nOut(3).build())
|
.layer(ConvolutionLayer.builder(2, 2).padding(0, 0).stride(2, 2).nIn(1).nOut(3).build())
|
||||||
// Divide space dimensions by blocks, i.e. 14/2 = 7 -> 7x7x12 out (3x2x2 depth)
|
// Divide space dimensions by blocks, i.e. 14/2 = 7 -> 7x7x12 out (3x2x2 depth)
|
||||||
.layer(new SpaceToDepthLayer.Builder(blocks, SpaceToDepthLayer.DataFormat.NCHW).build())
|
.layer(SpaceToDepthLayer.builder().blockSize(blocks).dataFormat(CNN2DFormat.NCHW).build())
|
||||||
.layer(OutputLayer.builder().nIn(3 * 2 * 2).nOut(3).activation(Activation.SOFTMAX).build()) // nIn of the next layer gets multiplied by 2*2.
|
.layer(OutputLayer.builder().nIn(3 * 2 * 2).nOut(3).activation(Activation.SOFTMAX).build()) // nIn of the next layer gets multiplied by 2*2.
|
||||||
.inputType(InputType.convolutional(28, 28, 1));
|
.inputType(InputType.convolutional(28, 28, 1));
|
||||||
|
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -71,7 +71,7 @@ public class LocallyConnectedLayerTest extends BaseDL4JTest {
|
||||||
.layer(LocallyConnected2D.builder().kernelSize(8, 8).nIn(3)
|
.layer(LocallyConnected2D.builder().kernelSize(8, 8).nIn(3)
|
||||||
.stride(4, 4).nOut(16).dropOut(0.5)
|
.stride(4, 4).nOut(16).dropOut(0.5)
|
||||||
.convolutionMode(ConvolutionMode.Strict)
|
.convolutionMode(ConvolutionMode.Strict)
|
||||||
.setInputSize(28, 28)
|
.inputSize(28, 28)
|
||||||
.activation(Activation.RELU).weightInit(
|
.activation(Activation.RELU).weightInit(
|
||||||
WeightInit.XAVIER)
|
WeightInit.XAVIER)
|
||||||
.build())
|
.build())
|
||||||
|
@ -94,11 +94,10 @@ public class LocallyConnectedLayerTest extends BaseDL4JTest {
|
||||||
NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().seed(123)
|
NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().seed(123)
|
||||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).l2(2e-4)
|
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).l2(2e-4)
|
||||||
.updater(new Nesterovs(0.9)).dropOut(0.5)
|
.updater(new Nesterovs(0.9)).dropOut(0.5)
|
||||||
.list()
|
|
||||||
.layer(LocallyConnected1D.builder().kernelSize(4).nIn(3)
|
.layer(LocallyConnected1D.builder().kernelSize(4).nIn(3)
|
||||||
.stride(1).nOut(16).dropOut(0.5)
|
.stride(1).nOut(16).dropOut(0.5)
|
||||||
.convolutionMode(ConvolutionMode.Strict)
|
.convolutionMode(ConvolutionMode.Strict)
|
||||||
.setInputSize(28)
|
.inputSize(28)
|
||||||
.activation(Activation.RELU).weightInit(
|
.activation(Activation.RELU).weightInit(
|
||||||
WeightInit.XAVIER)
|
WeightInit.XAVIER)
|
||||||
.build())
|
.build())
|
||||||
|
|
|
@ -61,7 +61,7 @@ public class SpaceToDepthTest extends BaseDL4JTest {
|
||||||
private Layer getSpaceToDepthLayer() {
|
private Layer getSpaceToDepthLayer() {
|
||||||
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
||||||
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).seed(123)
|
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).seed(123)
|
||||||
.layer(new SpaceToDepthLayer.Builder(blockSize, dataFormat).build()).build();
|
.layer(SpaceToDepthLayer.builder().blockSize(blockSize).dataFormat(dataFormat.toFormat()).build()).build();
|
||||||
return conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, null, true, Nd4j.defaultFloatingPointType());
|
return conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, null, true, Nd4j.defaultFloatingPointType());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,8 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.layers.custom;
|
package org.deeplearning4j.nn.layers.custom;
|
||||||
|
|
||||||
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
|
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
||||||
|
@ -30,31 +32,38 @@ import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.learning.config.Sgd;
|
import org.nd4j.linalg.learning.config.Sgd;
|
||||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
|
||||||
|
|
||||||
public class TestCustomActivation extends BaseDL4JTest {
|
public class TestCustomActivation extends BaseDL4JTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCustomActivationFn() {
|
public void testCustomActivationFn() {
|
||||||
//Second: let's create a MultiLayerCofiguration with one, and check JSON and YAML config actually works...
|
// Second: let's create a MultiLayerCofiguration with one, and check JSON and YAML config
|
||||||
|
// actually works...
|
||||||
|
|
||||||
NeuralNetConfiguration conf = NeuralNetConfiguration.builder().updater(new Sgd(0.1)).list()
|
NeuralNetConfiguration conf =
|
||||||
.layer(0, DenseLayer.builder().nIn(10).nOut(10).activation(new CustomActivation()).build())
|
NeuralNetConfiguration.builder()
|
||||||
.layer(1, OutputLayer.builder().lossFunction(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(10).build())
|
.updater(new Sgd(0.1))
|
||||||
|
|
||||||
|
.layer(
|
||||||
|
0, DenseLayer.builder().nIn(10).nOut(10).activation(new CustomActivation()).build())
|
||||||
|
.layer(
|
||||||
|
1,
|
||||||
|
OutputLayer.builder()
|
||||||
|
.lossFunction(LossFunctions.LossFunction.MCXENT)
|
||||||
|
.activation(Activation.SOFTMAX)
|
||||||
|
.nIn(10)
|
||||||
|
.nOut(10)
|
||||||
|
.build())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
String json = conf.toJson();
|
String json = conf.toJson();
|
||||||
String yaml = conf.toYaml();
|
String yaml = conf.toYaml();
|
||||||
|
|
||||||
// System.out.println(json);
|
// System.out.println(json);
|
||||||
|
|
||||||
NeuralNetConfiguration confFromJson = NeuralNetConfiguration.fromJson(json);
|
NeuralNetConfiguration confFromJson = NeuralNetConfiguration.fromJson(json);
|
||||||
assertEquals(conf, confFromJson);
|
assertEquals(conf, confFromJson);
|
||||||
|
|
||||||
NeuralNetConfiguration confFromYaml = NeuralNetConfiguration.fromYaml(yaml);
|
NeuralNetConfiguration confFromYaml = NeuralNetConfiguration.fromYaml(yaml);
|
||||||
assertEquals(conf, confFromYaml);
|
assertEquals(conf, confFromYaml);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -119,7 +119,7 @@ public class TestCustomLayers extends BaseDL4JTest {
|
||||||
NeuralNetConfiguration conf =
|
NeuralNetConfiguration conf =
|
||||||
NeuralNetConfiguration.builder().seed(12345).list()
|
NeuralNetConfiguration.builder().seed(12345).list()
|
||||||
.layer(0, DenseLayer.builder().nIn(10).nOut(10).build())
|
.layer(0, DenseLayer.builder().nIn(10).nOut(10).build())
|
||||||
.layer(1, new CustomOutputLayer.builder().lossFunction(LossFunctions.LossFunction.MCXENT)
|
.layer(1, CustomOutputLayer.builder().lossFunction(LossFunctions.LossFunction.MCXENT)
|
||||||
.activation(Activation.SOFTMAX)
|
.activation(Activation.SOFTMAX)
|
||||||
.nIn(10).nOut(10).build())
|
.nIn(10).nOut(10).build())
|
||||||
.build();
|
.build();
|
||||||
|
@ -172,7 +172,7 @@ public class TestCustomLayers extends BaseDL4JTest {
|
||||||
ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345)
|
ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345)
|
||||||
.graphBuilder().addInputs("in")
|
.graphBuilder().addInputs("in")
|
||||||
.addLayer("0", DenseLayer.builder().nIn(10).nOut(10).build(), "in").addLayer("1",
|
.addLayer("0", DenseLayer.builder().nIn(10).nOut(10).build(), "in").addLayer("1",
|
||||||
new CustomOutputLayer.builder().lossFunction(LossFunctions.LossFunction.MCXENT).nIn(10)
|
CustomOutputLayer.builder().lossFunction(LossFunctions.LossFunction.MCXENT).nIn(10)
|
||||||
.nOut(10).activation(Activation.SOFTMAX).build(),
|
.nOut(10).activation(Activation.SOFTMAX).build(),
|
||||||
"0")
|
"0")
|
||||||
.setOutputs("1").build();
|
.setOutputs("1").build();
|
||||||
|
|
|
@ -91,8 +91,8 @@ public class TestYolo2OutputLayer extends BaseDL4JTest {
|
||||||
.l2(0.01)
|
.l2(0.01)
|
||||||
.list()
|
.list()
|
||||||
.layer(ConvolutionLayer.builder().nIn(depth).nOut(depth).kernelSize(1,1).build())
|
.layer(ConvolutionLayer.builder().nIn(depth).nOut(depth).kernelSize(1,1).build())
|
||||||
.layer(new Yolo2OutputLayer.Builder()
|
.layer(Yolo2OutputLayer.builder()
|
||||||
.boundingBoxPriors(bbPrior)
|
.boundingBoxes(bbPrior)
|
||||||
.build())
|
.build())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
@ -179,8 +179,8 @@ public class TestYolo2OutputLayer extends BaseDL4JTest {
|
||||||
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
||||||
.list()
|
.list()
|
||||||
.layer(ConvolutionLayer.builder().nIn(1).nOut(1).kernelSize(1,1).build())
|
.layer(ConvolutionLayer.builder().nIn(1).nOut(1).kernelSize(1,1).build())
|
||||||
.layer(new Yolo2OutputLayer.Builder()
|
.layer(Yolo2OutputLayer.builder()
|
||||||
.boundingBoxPriors(bbPrior)
|
.boundingBoxes(bbPrior)
|
||||||
.build())
|
.build())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
@ -337,8 +337,8 @@ public class TestYolo2OutputLayer extends BaseDL4JTest {
|
||||||
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
||||||
.list()
|
.list()
|
||||||
.layer(ConvolutionLayer.builder().kernelSize(3,3).stride(1,1).nIn(3).nOut(3).build())
|
.layer(ConvolutionLayer.builder().kernelSize(3,3).stride(1,1).nIn(3).nOut(3).build())
|
||||||
.layer(new Yolo2OutputLayer.Builder()
|
.layer(Yolo2OutputLayer.builder()
|
||||||
.boundingBoxPriors(bbPriors)
|
.boundingBoxes(bbPriors)
|
||||||
.build())
|
.build())
|
||||||
.build();
|
.build();
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
|
@ -506,8 +506,8 @@ public class TestYolo2OutputLayer extends BaseDL4JTest {
|
||||||
.layer(ConvolutionLayer.builder().kernelSize(5,5).stride(2,2).nOut(256).build())
|
.layer(ConvolutionLayer.builder().kernelSize(5,5).stride(2,2).nOut(256).build())
|
||||||
.layer(SubsamplingLayer.builder().kernelSize(2,2).stride(2,2)/*.poolingType(SubsamplingLayer.PoolingType.AVG)*/.build())
|
.layer(SubsamplingLayer.builder().kernelSize(2,2).stride(2,2)/*.poolingType(SubsamplingLayer.PoolingType.AVG)*/.build())
|
||||||
.layer(ConvolutionLayer.builder().activation(Activation.IDENTITY).kernelSize(5,5).stride(1,1).nOut(depthOut).build())
|
.layer(ConvolutionLayer.builder().activation(Activation.IDENTITY).kernelSize(5,5).stride(1,1).nOut(depthOut).build())
|
||||||
.layer(new Yolo2OutputLayer.Builder()
|
.layer(Yolo2OutputLayer.builder()
|
||||||
.boundingBoxPriors(bbPriors)
|
.boundingBoxes(bbPriors)
|
||||||
.build())
|
.build())
|
||||||
.inputType(InputType.convolutional(h,w,c))
|
.inputType(InputType.convolutional(h,w,c))
|
||||||
.build();
|
.build();
|
||||||
|
|
|
@ -209,7 +209,7 @@ public class RnnDataFormatTests extends BaseDL4JTest {
|
||||||
return getNetWithLayer(GravesBidirectionalLSTM.builder().nOut(3)
|
return getNetWithLayer(GravesBidirectionalLSTM.builder().nOut(3)
|
||||||
.dataFormat(format).build(), format, lastTimeStep, maskZeros);
|
.dataFormat(format).build(), format, lastTimeStep, maskZeros);
|
||||||
} else {
|
} else {
|
||||||
return getNetWithLayer(new GravesBidirectionalLSTM.Builder().nOut(3).build(), format, lastTimeStep, maskZeros);
|
return getNetWithLayer(GravesBidirectionalLSTM.builder().nOut(3).build(), format, lastTimeStep, maskZeros);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
private MultiLayerNetwork getGravesLstmNet(RNNFormat format, boolean setOnLayerAlso, boolean lastTimeStep, boolean maskZeros) {
|
private MultiLayerNetwork getGravesLstmNet(RNNFormat format, boolean setOnLayerAlso, boolean lastTimeStep, boolean maskZeros) {
|
||||||
|
@ -240,7 +240,7 @@ public class RnnDataFormatTests extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
private MultiLayerNetwork getNetWithLayer(LayerConfiguration layer, RNNFormat format, boolean lastTimeStep, boolean maskZeros) {
|
private MultiLayerNetwork getNetWithLayer(LayerConfiguration layer, RNNFormat format, boolean lastTimeStep, boolean maskZeros) {
|
||||||
if (maskZeros){
|
if (maskZeros){
|
||||||
layer = new MaskZeroLayer.Builder().setMaskValue(0.).setUnderlying(layer).build();
|
layer = MaskZeroLayer.builder().maskingValue(0.).underlying(layer).build();
|
||||||
}
|
}
|
||||||
if(lastTimeStep){
|
if(lastTimeStep){
|
||||||
layer = new LastTimeStep(layer);
|
layer = new LastTimeStep(layer);
|
||||||
|
|
|
@ -27,6 +27,7 @@ import org.deeplearning4j.nn.conf.layers.GravesLSTM;
|
||||||
import org.deeplearning4j.nn.conf.layers.LSTM;
|
import org.deeplearning4j.nn.conf.layers.LSTM;
|
||||||
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
|
import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
|
import org.deeplearning4j.nn.weights.WeightInitDistribution;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
|
@ -48,17 +49,17 @@ public class TestRecurrentWeightInit extends BaseDL4JTest {
|
||||||
switch (i) {
|
switch (i) {
|
||||||
case 0:
|
case 0:
|
||||||
b.layer(LSTM.builder().nIn(10).nOut(10)
|
b.layer(LSTM.builder().nIn(10).nOut(10)
|
||||||
.weightInitRecurrent(new UniformDistribution(2, 3))
|
.weightInitRecurrent(new WeightInitDistribution(new UniformDistribution(2, 3)))
|
||||||
.build());
|
.build());
|
||||||
break;
|
break;
|
||||||
case 1:
|
case 1:
|
||||||
b.layer(GravesLSTM.builder().nIn(10).nOut(10)
|
b.layer(GravesLSTM.builder().nIn(10).nOut(10)
|
||||||
.weightInitRecurrent(new UniformDistribution(2, 3))
|
.weightInitRecurrent(new WeightInitDistribution(new UniformDistribution(2, 3)))
|
||||||
.build());
|
.build());
|
||||||
break;
|
break;
|
||||||
case 2:
|
case 2:
|
||||||
b.layer(SimpleRnn.builder().nIn(10).nOut(10)
|
b.layer(SimpleRnn.builder().nIn(10).nOut(10)
|
||||||
.weightInitRecurrent(new UniformDistribution(2, 3)).build());
|
.weightInitRecurrent(new WeightInitDistribution(new UniformDistribution(2, 3))).build());
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
throw new RuntimeException();
|
throw new RuntimeException();
|
||||||
|
|
|
@ -145,8 +145,8 @@ public class TestTimeDistributed extends BaseDL4JTest {
|
||||||
l2 = SimpleRnn.builder().nOut(5).build();
|
l2 = SimpleRnn.builder().nOut(5).build();
|
||||||
break;
|
break;
|
||||||
case 2:
|
case 2:
|
||||||
l0 = Bidirectional.builder(LSTM.builder().nOut(5).build());
|
l0 = Bidirectional.builder(LSTM.builder().nOut(5).build()).build();
|
||||||
l2 = Bidirectional.builder(LSTM.builder().nOut(5).build());
|
l2 = Bidirectional.builder(LSTM.builder().nOut(5).build()).build();
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
throw new RuntimeException("Not implemented: " + rnnType);
|
throw new RuntimeException("Not implemented: " + rnnType);
|
||||||
|
|
|
@ -67,7 +67,7 @@ public class TestSameDiffConv extends BaseDL4JTest {
|
||||||
|
|
||||||
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
||||||
.list()
|
.list()
|
||||||
.layer(new SameDiffConv.Builder().nIn(nIn).nOut(nOut).kernelSize(kH, kW).build())
|
.layer(SameDiffConv.builder().nIn(nIn).nOut(nOut).kernelSize(kH, kW).build())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
|
@ -131,7 +131,7 @@ public class TestSameDiffConv extends BaseDL4JTest {
|
||||||
.dataType(DataType.DOUBLE)
|
.dataType(DataType.DOUBLE)
|
||||||
.seed(12345)
|
.seed(12345)
|
||||||
.list()
|
.list()
|
||||||
.layer(new SameDiffConv.Builder()
|
.layer(SameDiffConv.builder()
|
||||||
.weightInit(WeightInit.XAVIER)
|
.weightInit(WeightInit.XAVIER)
|
||||||
.nIn(nIn)
|
.nIn(nIn)
|
||||||
.nOut(nOut)
|
.nOut(nOut)
|
||||||
|
@ -142,7 +142,7 @@ public class TestSameDiffConv extends BaseDL4JTest {
|
||||||
.activation(a)
|
.activation(a)
|
||||||
.hasBias(hasBias)
|
.hasBias(hasBias)
|
||||||
.build())
|
.build())
|
||||||
.layer(new SameDiffConv.Builder()
|
.layer(SameDiffConv.builder()
|
||||||
.weightInit(WeightInit.XAVIER)
|
.weightInit(WeightInit.XAVIER)
|
||||||
.nIn(nOut)
|
.nIn(nOut)
|
||||||
.nOut(nOut)
|
.nOut(nOut)
|
||||||
|
@ -273,7 +273,7 @@ public class TestSameDiffConv extends BaseDL4JTest {
|
||||||
.trainingWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE)
|
.trainingWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE)
|
||||||
.inferenceWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE)
|
.inferenceWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE)
|
||||||
.list()
|
.list()
|
||||||
.layer(new SameDiffConv.Builder()
|
.layer(SameDiffConv.builder()
|
||||||
.weightInit(WeightInit.XAVIER)
|
.weightInit(WeightInit.XAVIER)
|
||||||
.nIn(nIn)
|
.nIn(nIn)
|
||||||
.nOut(nOut)
|
.nOut(nOut)
|
||||||
|
@ -284,7 +284,7 @@ public class TestSameDiffConv extends BaseDL4JTest {
|
||||||
.activation(Activation.TANH)
|
.activation(Activation.TANH)
|
||||||
.hasBias(hasBias)
|
.hasBias(hasBias)
|
||||||
.build())
|
.build())
|
||||||
.layer(new SameDiffConv.Builder()
|
.layer(SameDiffConv.builder()
|
||||||
.weightInit(WeightInit.XAVIER)
|
.weightInit(WeightInit.XAVIER)
|
||||||
.nIn(nOut)
|
.nIn(nOut)
|
||||||
.nOut(nOut)
|
.nOut(nOut)
|
||||||
|
|
|
@ -65,7 +65,7 @@ public class TestSameDiffDense extends BaseDL4JTest {
|
||||||
|
|
||||||
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
||||||
.list()
|
.list()
|
||||||
.layer(new SameDiffDense.Builder().nIn(nIn).nOut(nOut).build())
|
.layer(SameDiffDense.builder().nIn(nIn).nOut(nOut).build())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
|
@ -106,7 +106,7 @@ public class TestSameDiffDense extends BaseDL4JTest {
|
||||||
.inferenceWorkspaceMode(wsm)
|
.inferenceWorkspaceMode(wsm)
|
||||||
.trainingWorkspaceMode(wsm)
|
.trainingWorkspaceMode(wsm)
|
||||||
.list()
|
.list()
|
||||||
.layer(new SameDiffDense.Builder().nIn(nIn).nOut(nOut)
|
.layer(SameDiffDense.builder().nIn(nIn).nOut(nOut)
|
||||||
.activation(a)
|
.activation(a)
|
||||||
.build())
|
.build())
|
||||||
.build();
|
.build();
|
||||||
|
@ -178,10 +178,10 @@ public class TestSameDiffDense extends BaseDL4JTest {
|
||||||
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
||||||
.seed(12345)
|
.seed(12345)
|
||||||
.list()
|
.list()
|
||||||
.layer(new SameDiffDense.Builder().nIn(nIn).nOut(nOut)
|
.layer(SameDiffDense.builder().nIn(nIn).nOut(nOut)
|
||||||
.weightInit(WeightInit.XAVIER)
|
.weightInit(WeightInit.XAVIER)
|
||||||
.activation(a).build())
|
.activation(a).build())
|
||||||
.layer(new SameDiffDense.Builder().nIn(nOut).nOut(nOut)
|
.layer(SameDiffDense.builder().nIn(nOut).nOut(nOut)
|
||||||
.weightInit(WeightInit.XAVIER)
|
.weightInit(WeightInit.XAVIER)
|
||||||
.activation(a).build())
|
.activation(a).build())
|
||||||
.layer(OutputLayer.builder().nIn(nOut).nOut(nOut)
|
.layer(OutputLayer.builder().nIn(nOut).nOut(nOut)
|
||||||
|
@ -267,7 +267,7 @@ public class TestSameDiffDense extends BaseDL4JTest {
|
||||||
.trainingWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE)
|
.trainingWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE)
|
||||||
.inferenceWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE)
|
.inferenceWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE)
|
||||||
.list()
|
.list()
|
||||||
.layer(new SameDiffDense.Builder().nIn(nIn).nOut(nOut)
|
.layer(SameDiffDense.builder().nIn(nIn).nOut(nOut)
|
||||||
.activation(a)
|
.activation(a)
|
||||||
.build())
|
.build())
|
||||||
.layer(OutputLayer.builder().nIn(nOut).nOut(nOut).activation(Activation.SOFTMAX)
|
.layer(OutputLayer.builder().nIn(nOut).nOut(nOut).activation(Activation.SOFTMAX)
|
||||||
|
@ -357,8 +357,8 @@ public class TestSameDiffDense extends BaseDL4JTest {
|
||||||
.inferenceWorkspaceMode(wsm)
|
.inferenceWorkspaceMode(wsm)
|
||||||
.updater(new Adam(0.1))
|
.updater(new Adam(0.1))
|
||||||
.list()
|
.list()
|
||||||
.layer(new SameDiffDense.Builder().nIn(nIn).nOut(5).activation(Activation.TANH).build())
|
.layer(SameDiffDense.builder().nIn(nIn).nOut(5).activation(Activation.TANH).build())
|
||||||
.layer(new SameDiffDense.Builder().nIn(5).nOut(5).activation(Activation.TANH).build())
|
.layer(SameDiffDense.builder().nIn(5).nOut(5).activation(Activation.TANH).build())
|
||||||
.layer(OutputLayer.builder().nIn(5).nOut(nOut).activation(Activation.SOFTMAX)
|
.layer(OutputLayer.builder().nIn(5).nOut(nOut).activation(Activation.SOFTMAX)
|
||||||
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
||||||
.build();
|
.build();
|
||||||
|
@ -428,8 +428,8 @@ public class TestSameDiffDense extends BaseDL4JTest {
|
||||||
.trainingWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE)
|
.trainingWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE)
|
||||||
.inferenceWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE)
|
.inferenceWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE)
|
||||||
.list()
|
.list()
|
||||||
.layer(new SameDiffDense.Builder().nIn(nIn).nOut(nOut).activation(a).build())
|
.layer(SameDiffDense.builder().nIn(nIn).nOut(nOut).activation(a).build())
|
||||||
.layer(new SameDiffDense.Builder().nIn(nOut).nOut(nOut).activation(a).build())
|
.layer(SameDiffDense.builder().nIn(nOut).nOut(nOut).activation(a).build())
|
||||||
.layer(OutputLayer.builder().nIn(nOut).nOut(nOut).activation(Activation.SOFTMAX)
|
.layer(OutputLayer.builder().nIn(nOut).nOut(nOut).activation(Activation.SOFTMAX)
|
||||||
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
||||||
//.inputType(InputType.feedForward(nIn)) //TODO
|
//.inputType(InputType.feedForward(nIn)) //TODO
|
||||||
|
|
|
@ -60,7 +60,7 @@ public class TestSameDiffOutput extends BaseDL4JTest {
|
||||||
.updater(new Adam(0.01))
|
.updater(new Adam(0.01))
|
||||||
.list()
|
.list()
|
||||||
.layer(DenseLayer.builder().nIn(5).nOut(5).activation(Activation.TANH).build())
|
.layer(DenseLayer.builder().nIn(5).nOut(5).activation(Activation.TANH).build())
|
||||||
.layer(LossLayer.builder().lossFunction().activation(Activation.IDENTITY).lossFunction(LossFunctions.LossFunction.MSE).build())
|
.layer(LossLayer.builder().activation(Activation.IDENTITY).lossFunction(LossFunctions.LossFunction.MSE.getILossFunction()).build())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
MultiLayerNetwork netSD = new MultiLayerNetwork(confSD);
|
MultiLayerNetwork netSD = new MultiLayerNetwork(confSD);
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.deeplearning4j.nn.layers.samediff.testlayers;
|
package org.deeplearning4j.nn.layers.samediff.testlayers;
|
||||||
|
|
||||||
import lombok.*;
|
import lombok.*;
|
||||||
|
import lombok.experimental.SuperBuilder;
|
||||||
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
||||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
@ -45,52 +46,62 @@ import java.util.*;
|
||||||
@Data
|
@Data
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
@JsonIgnoreProperties({"paramShapes"})
|
@JsonIgnoreProperties({"paramShapes"})
|
||||||
|
@NoArgsConstructor
|
||||||
|
@SuperBuilder
|
||||||
public class SameDiffConv extends SameDiffLayer {
|
public class SameDiffConv extends SameDiffLayer {
|
||||||
|
public static abstract class SameDiffConvBuilder<C extends SameDiffConv, B extends SameDiffConvBuilder<C, B>> extends
|
||||||
|
SameDiffLayerBuilder<C, B> {
|
||||||
|
public B kernelSize(int... k) {
|
||||||
|
this.kernelSize$value = k;
|
||||||
|
this.kernelSize$set = true;
|
||||||
|
return self();
|
||||||
|
}
|
||||||
|
|
||||||
|
public B stride(int... s) {
|
||||||
|
this.stride$value = s;
|
||||||
|
this.stride$set = true;
|
||||||
|
return self();
|
||||||
|
}
|
||||||
|
|
||||||
|
public B padding(int... p) {
|
||||||
|
this.padding$value = p;
|
||||||
|
this.padding$set = true;
|
||||||
|
return self();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private static final List<String> WEIGHT_KEYS = Collections.singletonList(ConvolutionParamInitializer.WEIGHT_KEY);
|
private static final List<String> WEIGHT_KEYS = Collections.singletonList(ConvolutionParamInitializer.WEIGHT_KEY);
|
||||||
private static final List<String> BIAS_KEYS = Collections.singletonList(ConvolutionParamInitializer.BIAS_KEY);
|
private static final List<String> BIAS_KEYS = Collections.singletonList(ConvolutionParamInitializer.BIAS_KEY);
|
||||||
//Order to match 'vanilla' conv layer implementation, for easy comparison
|
//Order to match 'vanilla' conv layer implementation, for easy comparison
|
||||||
private static final List<String> PARAM_KEYS = Arrays.asList(ConvolutionParamInitializer.BIAS_KEY, ConvolutionParamInitializer.WEIGHT_KEY);
|
private static final List<String> PARAM_KEYS = Arrays.asList(ConvolutionParamInitializer.BIAS_KEY, ConvolutionParamInitializer.WEIGHT_KEY);
|
||||||
|
|
||||||
private long nIn;
|
|
||||||
private long nOut;
|
|
||||||
private Activation activation;
|
|
||||||
private int[] kernel;
|
|
||||||
private int[] stride;
|
|
||||||
private int[] padding;
|
|
||||||
private ConvolutionMode cm;
|
|
||||||
private int[] dilation;
|
|
||||||
private boolean hasBias;
|
|
||||||
|
|
||||||
protected SameDiffConv(Builder b) {
|
private int nIn;
|
||||||
super(b);
|
private int nOut;
|
||||||
this.nIn = b.nIn;
|
@Builder.Default private Activation activation = Activation.TANH;
|
||||||
this.nOut = b.nOut;
|
@Builder.Default private int[] kernelSize = new int[]{2, 2};
|
||||||
this.activation = b.activation;
|
|
||||||
this.kernel = b.kernel;
|
@Builder.Default private int[] stride = new int[]{1, 1};
|
||||||
this.stride = b.stride;
|
@Builder.Default private int[] padding = new int[]{0, 0};
|
||||||
this.padding = b.padding;
|
@Builder.Default private int[] dilation = new int[]{1, 1};
|
||||||
this.cm = b.cm;
|
@Builder.Default private ConvolutionMode convolutionMode = ConvolutionMode.Same;
|
||||||
this.dilation = b.dilation;
|
@Builder.Default private boolean hasBias = true;
|
||||||
this.hasBias = b.hasBias;
|
|
||||||
}
|
|
||||||
|
|
||||||
private SameDiffConv(){
|
|
||||||
//No arg constructor for Jackson/JSON serialization
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public InputType getOutputType(int layerIndex, InputType inputType) {
|
public InputType getOutputType(int layerIndex, InputType inputType) {
|
||||||
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
|
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
|
||||||
return InputTypeUtil.getOutputTypeCnnLayers(inputType, kernel, stride, padding, new int[]{1, 1},
|
return InputTypeUtil.getOutputTypeCnnLayers(inputType, kernelSize, stride, padding, new int[]{1, 1},
|
||||||
cm, nOut, layerIndex, getName(), SameDiffConv.class);
|
convolutionMode, nOut, layerIndex, getName(), SameDiffConv.class);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void setNIn(InputType inputType, boolean override) {
|
public void setNIn(InputType inputType, boolean override) {
|
||||||
if (nIn <= 0 || override) {
|
if (nIn <= 0 || override) {
|
||||||
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
|
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
|
||||||
this.nIn = c.getChannels();
|
this.nIn = (int) c.getChannels();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -102,7 +113,7 @@ public class SameDiffConv extends SameDiffLayer {
|
||||||
@Override
|
@Override
|
||||||
public void defineParameters(SDLayerParams params) {
|
public void defineParameters(SDLayerParams params) {
|
||||||
params.clear();
|
params.clear();
|
||||||
val weightsShape = new long[]{kernel[0], kernel[1], nIn, nOut}; //[kH, kW, iC, oC] in libnd4j
|
val weightsShape = new long[]{kernelSize[0], kernelSize[1], nIn, nOut}; //[kH, kW, iC, oC] in libnd4j
|
||||||
params.addWeightParam(ConvolutionParamInitializer.WEIGHT_KEY, weightsShape);
|
params.addWeightParam(ConvolutionParamInitializer.WEIGHT_KEY, weightsShape);
|
||||||
if(hasBias) {
|
if(hasBias) {
|
||||||
val biasShape = new long[]{1, nOut};
|
val biasShape = new long[]{1, nOut};
|
||||||
|
@ -113,8 +124,8 @@ public class SameDiffConv extends SameDiffLayer {
|
||||||
@Override
|
@Override
|
||||||
public void initializeParameters(Map<String, INDArray> params) {
|
public void initializeParameters(Map<String, INDArray> params) {
|
||||||
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
||||||
double fanIn = nIn * kernel[0] * kernel[1];
|
double fanIn = nIn * kernelSize[0] * kernelSize[1];
|
||||||
double fanOut = nOut * kernel[0] * kernel[1] / ((double) stride[0] * stride[1]);
|
double fanOut = nOut * kernelSize[0] * kernelSize[1] / ((double) stride[0] * stride[1]);
|
||||||
for (Map.Entry<String, INDArray> e : params.entrySet()) {
|
for (Map.Entry<String, INDArray> e : params.entrySet()) {
|
||||||
if(paramWeightInit != null && paramWeightInit.containsKey(e.getKey())){
|
if(paramWeightInit != null && paramWeightInit.containsKey(e.getKey())){
|
||||||
paramWeightInit.get(e.getKey()).init(fanIn, fanOut, e.getValue().shape(), 'c', e.getValue());
|
paramWeightInit.get(e.getKey()).init(fanIn, fanOut, e.getValue().shape(), 'c', e.getValue());
|
||||||
|
@ -135,11 +146,11 @@ public class SameDiffConv extends SameDiffLayer {
|
||||||
SDVariable w = paramTable.get(ConvolutionParamInitializer.WEIGHT_KEY);
|
SDVariable w = paramTable.get(ConvolutionParamInitializer.WEIGHT_KEY);
|
||||||
|
|
||||||
Conv2DConfig c = Conv2DConfig.builder()
|
Conv2DConfig c = Conv2DConfig.builder()
|
||||||
.kH(kernel[0]).kW(kernel[1])
|
.kH(kernelSize[0]).kW(kernelSize[1])
|
||||||
.pH(padding[0]).pW(padding[1])
|
.pH(padding[0]).pW(padding[1])
|
||||||
.sH(stride[0]).sW(stride[1])
|
.sH(stride[0]).sW(stride[1])
|
||||||
.dH(dilation[0]).dW(dilation[1])
|
.dH(dilation[0]).dW(dilation[1])
|
||||||
.isSameMode(this.cm == ConvolutionMode.Same)
|
.isSameMode(this.convolutionMode == ConvolutionMode.Same)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
SDVariable conv = null;
|
SDVariable conv = null;
|
||||||
|
@ -159,72 +170,10 @@ public class SameDiffConv extends SameDiffLayer {
|
||||||
if (activation == null) {
|
if (activation == null) {
|
||||||
activation = SameDiffLayerUtils.fromIActivation(clone.getActivation());
|
activation = SameDiffLayerUtils.fromIActivation(clone.getActivation());
|
||||||
}
|
}
|
||||||
if (cm == null) {
|
if (convolutionMode == null) {
|
||||||
cm = clone.getConvolutionMode();
|
convolutionMode = clone.getConvolutionMode();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class Builder extends SameDiffLayer.Builder<Builder> {
|
|
||||||
|
|
||||||
private int nIn;
|
|
||||||
private int nOut;
|
|
||||||
private Activation activation = Activation.TANH;
|
|
||||||
private int[] kernel = new int[]{2, 2};
|
|
||||||
|
|
||||||
private int[] stride = new int[]{1, 1};
|
|
||||||
private int[] padding = new int[]{0, 0};
|
|
||||||
private int[] dilation = new int[]{1, 1};
|
|
||||||
private ConvolutionMode cm = ConvolutionMode.Same;
|
|
||||||
private boolean hasBias = true;
|
|
||||||
|
|
||||||
public Builder nIn(int nIn) {
|
|
||||||
this.nIn = nIn;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Builder nOut(int nOut) {
|
|
||||||
this.nOut = nOut;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Builder activation(Activation activation) {
|
|
||||||
this.activation = activation;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Builder kernelSize(int... k) {
|
|
||||||
this.kernel = k;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Builder stride(int... s) {
|
|
||||||
this.stride = s;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Builder padding(int... p) {
|
|
||||||
this.padding = p;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Builder convolutionMode(ConvolutionMode cm) {
|
|
||||||
this.cm = cm;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Builder dilation(int... d) {
|
|
||||||
this.dilation = d;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Builder hasBias(boolean hasBias){
|
|
||||||
this.hasBias = hasBias;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public SameDiffConv build() {
|
|
||||||
return new SameDiffConv(this);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,6 +22,8 @@ package org.deeplearning4j.nn.layers.samediff.testlayers;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
import lombok.experimental.SuperBuilder;
|
||||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
|
@ -40,30 +42,22 @@ import java.util.*;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@EqualsAndHashCode(callSuper = true, exclude = {"paramShapes"})
|
@EqualsAndHashCode(callSuper = true, exclude = {"paramShapes"})
|
||||||
|
@NoArgsConstructor()
|
||||||
@JsonIgnoreProperties("paramShapes")
|
@JsonIgnoreProperties("paramShapes")
|
||||||
|
@SuperBuilder
|
||||||
public class SameDiffDense extends SameDiffLayer {
|
public class SameDiffDense extends SameDiffLayer {
|
||||||
|
|
||||||
private static final List<String> W_KEYS = Collections.singletonList(DefaultParamInitializer.WEIGHT_KEY);
|
private static final List<String> W_KEYS = Collections.singletonList(DefaultParamInitializer.WEIGHT_KEY);
|
||||||
private static final List<String> B_KEYS = Collections.singletonList(DefaultParamInitializer.BIAS_KEY);
|
private static final List<String> B_KEYS = Collections.singletonList(DefaultParamInitializer.BIAS_KEY);
|
||||||
private static final List<String> PARAM_KEYS = Arrays.asList(DefaultParamInitializer.WEIGHT_KEY, DefaultParamInitializer.BIAS_KEY);
|
private static final List<String> PARAM_KEYS = Arrays.asList(DefaultParamInitializer.WEIGHT_KEY, DefaultParamInitializer.BIAS_KEY);
|
||||||
|
|
||||||
private Map<String,long[]> paramShapes;
|
private final Map<String, long[]> paramShapes = new HashMap<>();
|
||||||
|
|
||||||
private long nIn;
|
private long nIn;
|
||||||
private long nOut;
|
private long nOut;
|
||||||
private Activation activation;
|
private Activation activation;
|
||||||
|
|
||||||
protected SameDiffDense(Builder builder) {
|
|
||||||
super(builder);
|
|
||||||
|
|
||||||
nIn = builder.nIn;
|
|
||||||
nOut = builder.nOut;
|
|
||||||
activation = builder.activation;
|
|
||||||
}
|
|
||||||
|
|
||||||
private SameDiffDense(){
|
|
||||||
//No op constructor for Jackson
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public InputType getOutputType(int layerIndex, InputType inputType) {
|
public InputType getOutputType(int layerIndex, InputType inputType) {
|
||||||
|
@ -128,31 +122,5 @@ public class SameDiffDense extends SameDiffLayer {
|
||||||
return 'f';
|
return 'f';
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class Builder extends SameDiffLayer.Builder<Builder> {
|
|
||||||
|
|
||||||
private int nIn;
|
|
||||||
private int nOut;
|
|
||||||
|
|
||||||
private Activation activation;
|
|
||||||
|
|
||||||
public Builder nIn(int nIn){
|
|
||||||
this.nIn = nIn;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Builder nOut(int nOut){
|
|
||||||
this.nOut = nOut;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Builder activation(Activation activation){
|
|
||||||
this.activation = activation;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public SameDiffDense build() {
|
|
||||||
return new SameDiffDense(this);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -58,7 +58,7 @@ public class TestVAE extends BaseDL4JTest {
|
||||||
|
|
||||||
NeuralNetConfiguration mlc =
|
NeuralNetConfiguration mlc =
|
||||||
NeuralNetConfiguration.builder()
|
NeuralNetConfiguration.builder()
|
||||||
.layer(0, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder()
|
.layer(0, org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.builder()
|
||||||
.nIn(10).nOut(5).encoderLayerSizes(12).decoderLayerSizes(13)
|
.nIn(10).nOut(5).encoderLayerSizes(12).decoderLayerSizes(13)
|
||||||
.build())
|
.build())
|
||||||
.build();
|
.build();
|
||||||
|
@ -95,7 +95,7 @@ public class TestVAE extends BaseDL4JTest {
|
||||||
for (int i = 0; i < encLayerSizes.length; i++) {
|
for (int i = 0; i < encLayerSizes.length; i++) {
|
||||||
|
|
||||||
NeuralNetConfiguration mlc = NeuralNetConfiguration.builder().list().layer(0,
|
NeuralNetConfiguration mlc = NeuralNetConfiguration.builder().list().layer(0,
|
||||||
new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder().nIn(10)
|
org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.builder().nIn(10)
|
||||||
.nOut(5).encoderLayerSizes(encLayerSizes[i]).decoderLayerSizes(13).build())
|
.nOut(5).encoderLayerSizes(encLayerSizes[i]).decoderLayerSizes(13).build())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
@ -121,7 +121,7 @@ public class TestVAE extends BaseDL4JTest {
|
||||||
int inputSize = 3;
|
int inputSize = 3;
|
||||||
|
|
||||||
NeuralNetConfiguration mlc = NeuralNetConfiguration.builder().list()
|
NeuralNetConfiguration mlc = NeuralNetConfiguration.builder().list()
|
||||||
.layer(0, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder()
|
.layer(0, org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.builder()
|
||||||
.nIn(inputSize).nOut(4).encoderLayerSizes(5).decoderLayerSizes(6).build())
|
.nIn(inputSize).nOut(4).encoderLayerSizes(5).decoderLayerSizes(6).build())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
@ -159,7 +159,7 @@ public class TestVAE extends BaseDL4JTest {
|
||||||
public void testParamGradientOrderAndViews() {
|
public void testParamGradientOrderAndViews() {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
NeuralNetConfiguration mlc = NeuralNetConfiguration.builder().list()
|
NeuralNetConfiguration mlc = NeuralNetConfiguration.builder().list()
|
||||||
.layer(0, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder()
|
.layer(0, org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.builder()
|
||||||
.nIn(10).nOut(5).encoderLayerSizes(12, 13).decoderLayerSizes(14, 15).build())
|
.nIn(10).nOut(5).encoderLayerSizes(12, 13).decoderLayerSizes(14, 15).build())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
@ -217,7 +217,7 @@ public class TestVAE extends BaseDL4JTest {
|
||||||
|
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
NeuralNetConfiguration mlc = NeuralNetConfiguration.builder().seed(12345).list()
|
NeuralNetConfiguration mlc = NeuralNetConfiguration.builder().seed(12345).list()
|
||||||
.layer(0, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder()
|
.layer(0, org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.builder()
|
||||||
.nIn(10).nOut(5).encoderLayerSizes(12, 13).decoderLayerSizes(14, 15).build())
|
.nIn(10).nOut(5).encoderLayerSizes(12, 13).decoderLayerSizes(14, 15).build())
|
||||||
.layer(1, OutputLayer.builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(5).nOut(6)
|
.layer(1, OutputLayer.builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(5).nOut(6)
|
||||||
.activation(new ActivationTanH()).build())
|
.activation(new ActivationTanH()).build())
|
||||||
|
@ -269,22 +269,22 @@ public class TestVAE extends BaseDL4JTest {
|
||||||
public void testJsonYaml() {
|
public void testJsonYaml() {
|
||||||
|
|
||||||
NeuralNetConfiguration config = NeuralNetConfiguration.builder().seed(12345).list()
|
NeuralNetConfiguration config = NeuralNetConfiguration.builder().seed(12345).list()
|
||||||
.layer(0, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder()
|
.layer(0, org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.builder()
|
||||||
.reconstructionDistribution(new GaussianReconstructionDistribution(Activation.IDENTITY))
|
.reconstructionDistribution(new GaussianReconstructionDistribution(Activation.IDENTITY))
|
||||||
.nIn(3).nOut(4).encoderLayerSizes(5).decoderLayerSizes(6).build())
|
.nIn(3).nOut(4).encoderLayerSizes(5).decoderLayerSizes(6).build())
|
||||||
.layer(1, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder()
|
.layer(1, org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.builder()
|
||||||
.reconstructionDistribution(new GaussianReconstructionDistribution(Activation.TANH))
|
.reconstructionDistribution(new GaussianReconstructionDistribution(Activation.TANH))
|
||||||
.nIn(7).nOut(8).encoderLayerSizes(9).decoderLayerSizes(10).build())
|
.nIn(7).nOut(8).encoderLayerSizes(9).decoderLayerSizes(10).build())
|
||||||
.layer(2, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder()
|
.layer(2, org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.builder()
|
||||||
.reconstructionDistribution(new BernoulliReconstructionDistribution()).nIn(11)
|
.reconstructionDistribution(new BernoulliReconstructionDistribution()).nIn(11)
|
||||||
.nOut(12).encoderLayerSizes(13).decoderLayerSizes(14).build())
|
.nOut(12).encoderLayerSizes(13).decoderLayerSizes(14).build())
|
||||||
.layer(3, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder()
|
.layer(3, org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.builder()
|
||||||
.reconstructionDistribution(new ExponentialReconstructionDistribution(Activation.TANH))
|
.reconstructionDistribution(new ExponentialReconstructionDistribution(Activation.TANH))
|
||||||
.nIn(11).nOut(12).encoderLayerSizes(13).decoderLayerSizes(14).build())
|
.nIn(11).nOut(12).encoderLayerSizes(13).decoderLayerSizes(14).build())
|
||||||
.layer(4, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder()
|
.layer(4, org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.builder()
|
||||||
.lossFunction(new ActivationTanH(), LossFunctions.LossFunction.MSE).nIn(11)
|
.lossFunction(new ActivationTanH(), LossFunctions.LossFunction.MSE).nIn(11)
|
||||||
.nOut(12).encoderLayerSizes(13).decoderLayerSizes(14).build())
|
.nOut(12).encoderLayerSizes(13).decoderLayerSizes(14).build())
|
||||||
.layer(5, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder()
|
.layer(5, org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.builder()
|
||||||
.reconstructionDistribution(new CompositeReconstructionDistribution.Builder()
|
.reconstructionDistribution(new CompositeReconstructionDistribution.Builder()
|
||||||
.addDistribution(5, new GaussianReconstructionDistribution())
|
.addDistribution(5, new GaussianReconstructionDistribution())
|
||||||
.addDistribution(5,
|
.addDistribution(5,
|
||||||
|
|
|
@ -59,7 +59,7 @@ public class TestMemoryReports extends BaseDL4JTest {
|
||||||
l.add(new Pair<>(DropoutLayer.builder().nIn(20).nOut(20).build(), InputType.feedForward(20)));
|
l.add(new Pair<>(DropoutLayer.builder().nIn(20).nOut(20).build(), InputType.feedForward(20)));
|
||||||
l.add(new Pair<>(EmbeddingLayer.builder().nIn(1).nOut(20).build(), InputType.feedForward(20)));
|
l.add(new Pair<>(EmbeddingLayer.builder().nIn(1).nOut(20).build(), InputType.feedForward(20)));
|
||||||
l.add(new Pair<>(OutputLayer.builder().nIn(20).nOut(20).build(), InputType.feedForward(20)));
|
l.add(new Pair<>(OutputLayer.builder().nIn(20).nOut(20).build(), InputType.feedForward(20)));
|
||||||
l.add(new Pair<>(LossLayer.builder().lossFunction().build(), InputType.feedForward(20)));
|
l.add(new Pair<>(LossLayer.builder().build(), InputType.feedForward(20)));
|
||||||
|
|
||||||
//RNN layers:
|
//RNN layers:
|
||||||
l.add(new Pair<>(GravesLSTM.builder().nIn(20).nOut(20).build(), InputType.recurrent(20, 30)));
|
l.add(new Pair<>(GravesLSTM.builder().nIn(20).nOut(20).build(), InputType.recurrent(20, 30)));
|
||||||
|
|
|
@ -469,7 +469,7 @@ public class WorkspaceTests extends BaseDL4JTest {
|
||||||
.addLayer("a", GravesLSTM.builder().nOut(300).activation(Activation.HARDTANH).build(), "embeddings")
|
.addLayer("a", GravesLSTM.builder().nOut(300).activation(Activation.HARDTANH).build(), "embeddings")
|
||||||
.addVertex("b", new LastTimeStepVertex("in"), "a")
|
.addVertex("b", new LastTimeStepVertex("in"), "a")
|
||||||
.addLayer("c", DenseLayer.builder().nOut(300).activation(Activation.HARDTANH).build(), "b")
|
.addLayer("c", DenseLayer.builder().nOut(300).activation(Activation.HARDTANH).build(), "b")
|
||||||
.addLayer("output", LossLayer.builder().lossFunction().lossFunction(LossFunctions.LossFunction.COSINE_PROXIMITY).build(), "c")
|
.addLayer("output", LossLayer.builder().lossFunction(LossFunctions.LossFunction.COSINE_PROXIMITY.getILossFunction()).build(), "c")
|
||||||
.setOutputs("output")
|
.setOutputs("output")
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
|
|
@ -1455,10 +1455,10 @@ public class MultiLayerTest extends BaseDL4JTest {
|
||||||
|
|
||||||
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
|
||||||
.l2(0.01)
|
.l2(0.01)
|
||||||
.list()
|
|
||||||
.layer(ConvolutionLayer.builder().nIn(depth).nOut(depth).kernelSize(1, 1).build())
|
.layer(ConvolutionLayer.builder().nIn(depth).nOut(depth).kernelSize(1, 1).build())
|
||||||
.layer(new Yolo2OutputLayer.Builder()
|
.layer(Yolo2OutputLayer.builder()
|
||||||
.boundingBoxPriors(bbPrior)
|
.boundingBoxes(bbPrior)
|
||||||
.build())
|
.build())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
|
|
@ -500,10 +500,10 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest {
|
||||||
.addInputs(inputName)
|
.addInputs(inputName)
|
||||||
.setOutputs(outputName)
|
.setOutputs(outputName)
|
||||||
.setInputTypes(InputType.inferInputTypes(input))
|
.setInputTypes(InputType.inferInputTypes(input))
|
||||||
.addLayer(firstConv, new Convolution2D.Builder(3, 3)
|
.addLayer(firstConv, Convolution2D.builder(3, 3)
|
||||||
.nOut(10)
|
.nOut(10)
|
||||||
.build(), inputName)
|
.build(), inputName)
|
||||||
.addLayer(secondConv, new Convolution2D.Builder(1, 1)
|
.addLayer(secondConv, Convolution2D.builder(1, 1)
|
||||||
.nOut(3)
|
.nOut(3)
|
||||||
.build(), firstConv)
|
.build(), firstConv)
|
||||||
.addLayer(outputName, OutputLayer.builder()
|
.addLayer(outputName, OutputLayer.builder()
|
||||||
|
@ -546,11 +546,11 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest {
|
||||||
.addInputs(inputName)
|
.addInputs(inputName)
|
||||||
.setOutputs(outputName)
|
.setOutputs(outputName)
|
||||||
.setInputTypes(InputType.inferInputTypes(input))
|
.setInputTypes(InputType.inferInputTypes(input))
|
||||||
.addLayer(changeNoutName, new Convolution2D.Builder(1, 1)
|
.addLayer(changeNoutName, Convolution2D.builder(1, 1)
|
||||||
.nOut(10)
|
.nOut(10)
|
||||||
.build(), inputName)
|
.build(), inputName)
|
||||||
.addLayer(poolName, SubsamplingLayer.builder(1,1).build(), changeNoutName)
|
.addLayer(poolName, SubsamplingLayer.builder(1,1).build(), changeNoutName)
|
||||||
.addLayer(afterPoolName, new Convolution2D.Builder(1, 1)
|
.addLayer(afterPoolName, Convolution2D.builder(1, 1)
|
||||||
.nOut(7)
|
.nOut(7)
|
||||||
.build(), poolName)
|
.build(), poolName)
|
||||||
.addLayer(outputName, OutputLayer.builder()
|
.addLayer(outputName, OutputLayer.builder()
|
||||||
|
@ -583,7 +583,7 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest {
|
||||||
.graphBuilder()
|
.graphBuilder()
|
||||||
.addInputs("in")
|
.addInputs("in")
|
||||||
.layer("l0", LSTM.builder().nIn(5).nOut(5).build(), "in")
|
.layer("l0", LSTM.builder().nIn(5).nOut(5).build(), "in")
|
||||||
.layer("l1", new RecurrentAttentionLayer.Builder().nHeads(1).headSize(5).nIn(5).nOut(5).build(), "l0")
|
.layer("l1", RecurrentAttentionLayer.builder().nHeads(1).headSize(5).nIn(5).nOut(5).build(), "l0")
|
||||||
.layer("out", RnnOutputLayer.builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build(), "l1")
|
.layer("out", RnnOutputLayer.builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build(), "l1")
|
||||||
.setOutputs("out")
|
.setOutputs("out")
|
||||||
.build();
|
.build();
|
||||||
|
|
|
@ -28,7 +28,6 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.graph.MergeVertex;
|
import org.deeplearning4j.nn.conf.graph.MergeVertex;
|
||||||
import org.deeplearning4j.nn.conf.graph.SubsetVertex;
|
import org.deeplearning4j.nn.conf.graph.SubsetVertex;
|
||||||
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
||||||
import org.deeplearning4j.nn.conf.layers.DenseLayer.Builder;
|
|
||||||
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
|
@ -214,9 +213,9 @@ public class TransferLearningHelperTest extends BaseDL4JTest {
|
||||||
|
|
||||||
MultiLayerNetwork modelToFineTune = new MultiLayerNetwork(
|
MultiLayerNetwork modelToFineTune = new MultiLayerNetwork(
|
||||||
(NeuralNetConfiguration) overallConf.clone()
|
(NeuralNetConfiguration) overallConf.clone()
|
||||||
.layer(0, new Builder().nIn(4).nOut(3).build())
|
.layer(0, DenseLayer.builder().nIn(4).nOut(3).build())
|
||||||
.layer(1, new Builder().nIn(3).nOut(2).build())
|
.layer(1, DenseLayer.builder().nIn(3).nOut(2).build())
|
||||||
.layer(2, new Builder().nIn(2).nOut(3).build())
|
.layer(2, DenseLayer.builder().nIn(2).nOut(3).build())
|
||||||
.layer(3, OutputLayer.builder(
|
.layer(3, OutputLayer.builder(
|
||||||
LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3)
|
LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3)
|
||||||
.build())
|
.build())
|
||||||
|
@ -233,7 +232,7 @@ public class TransferLearningHelperTest extends BaseDL4JTest {
|
||||||
Nd4j.hstack(modelToFineTune.getLayer(2).getParams(), modelToFineTune.getLayer(3).getParams());
|
Nd4j.hstack(modelToFineTune.getLayer(2).getParams(), modelToFineTune.getLayer(3).getParams());
|
||||||
MultiLayerNetwork notFrozen = new MultiLayerNetwork(
|
MultiLayerNetwork notFrozen = new MultiLayerNetwork(
|
||||||
(NeuralNetConfiguration) overallConf.clone().list()
|
(NeuralNetConfiguration) overallConf.clone().list()
|
||||||
.layer(0, new Builder().nIn(2).nOut(3).build())
|
.layer(0, DenseLayer.builder().nIn(2).nOut(3).build())
|
||||||
.layer(1, OutputLayer.builder(
|
.layer(1, OutputLayer.builder(
|
||||||
LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3)
|
LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3)
|
||||||
.build())
|
.build())
|
||||||
|
|
|
@ -32,7 +32,7 @@ import org.deeplearning4j.nn.conf.distribution.ConstantDistribution;
|
||||||
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
|
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.layers.*;
|
import org.deeplearning4j.nn.conf.layers.*;
|
||||||
import org.deeplearning4j.nn.conf.layers.DenseLayer.Builder;
|
|
||||||
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
|
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
|
||||||
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor;
|
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor;
|
||||||
import org.deeplearning4j.nn.conf.preprocessor.RnnToCnnPreProcessor;
|
import org.deeplearning4j.nn.conf.preprocessor.RnnToCnnPreProcessor;
|
||||||
|
@ -74,7 +74,7 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
|
||||||
|
|
||||||
MultiLayerNetwork modelToFineTune = new MultiLayerNetwork(
|
MultiLayerNetwork modelToFineTune = new MultiLayerNetwork(
|
||||||
(NeuralNetConfiguration) confToChange.list()
|
(NeuralNetConfiguration) confToChange.list()
|
||||||
.layer(0, new Builder().nIn(4).nOut(3).build())
|
.layer(0, DenseLayer.builder().nIn(4).nOut(3).build())
|
||||||
.layer(1, OutputLayer.builder(
|
.layer(1, OutputLayer.builder(
|
||||||
LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3)
|
LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3)
|
||||||
.build())
|
.build())
|
||||||
|
@ -101,7 +101,7 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
|
||||||
.updater(new RmsProp(0.5)).l2(0.4);
|
.updater(new RmsProp(0.5)).l2(0.4);
|
||||||
|
|
||||||
MultiLayerNetwork expectedModel = new MultiLayerNetwork((NeuralNetConfiguration) confSet.list()
|
MultiLayerNetwork expectedModel = new MultiLayerNetwork((NeuralNetConfiguration) confSet.list()
|
||||||
.layer(0, new Builder().nIn(4).nOut(3).build())
|
.layer(0, DenseLayer.builder().nIn(4).nOut(3).build())
|
||||||
.layer(1, OutputLayer.builder(
|
.layer(1, OutputLayer.builder(
|
||||||
LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3)
|
LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3)
|
||||||
.build())
|
.build())
|
||||||
|
@ -651,8 +651,8 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
|
||||||
.weightInit(new ConstantDistribution(666))
|
.weightInit(new ConstantDistribution(666))
|
||||||
.list()
|
.list()
|
||||||
.inputType(InputType.inferInputTypes(input)[0])
|
.inputType(InputType.inferInputTypes(input)[0])
|
||||||
.layer(new Convolution2D.Builder(3, 3).nOut(10).build())
|
.layer(Convolution2D.builder(3, 3).nOut(10).build())
|
||||||
.layer(new Convolution2D.Builder(1, 1).nOut(3).build())
|
.layer(Convolution2D.builder(1, 1).nOut(3).build())
|
||||||
.layer(OutputLayer.builder().nOut(2).lossFunction(LossFunctions.LossFunction.MSE)
|
.layer(OutputLayer.builder().nOut(2).lossFunction(LossFunctions.LossFunction.MSE)
|
||||||
.build()).build());
|
.build()).build());
|
||||||
net.init();
|
net.init();
|
||||||
|
@ -682,9 +682,9 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork( NeuralNetConfiguration.builder()
|
MultiLayerNetwork net = new MultiLayerNetwork( NeuralNetConfiguration.builder()
|
||||||
.list()
|
.list()
|
||||||
.inputType(InputType.inferInputTypes(input)[0])
|
.inputType(InputType.inferInputTypes(input)[0])
|
||||||
.layer(new Convolution2D.Builder(1, 1).nOut(10).build())
|
.layer(Convolution2D.builder(1, 1).nOut(10).build())
|
||||||
.layer(SubsamplingLayer.builder(1,1).build())
|
.layer(SubsamplingLayer.builder(1,1).build())
|
||||||
.layer(new Convolution2D.Builder(1, 1).nOut(7).build())
|
.layer(Convolution2D.builder(1, 1).nOut(7).build())
|
||||||
.layer(OutputLayer.builder().activation(Activation.SOFTMAX).nOut(2).build())
|
.layer(OutputLayer.builder().activation(Activation.SOFTMAX).nOut(2).build())
|
||||||
.build());
|
.build());
|
||||||
net.init();
|
net.init();
|
||||||
|
@ -712,7 +712,7 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
|
||||||
.weightInit(WeightInit.XAVIER)
|
.weightInit(WeightInit.XAVIER)
|
||||||
.list()
|
.list()
|
||||||
.layer(LSTM.builder().nOut(8).build())
|
.layer(LSTM.builder().nOut(8).build())
|
||||||
.layer( new SelfAttentionLayer.Builder().nOut(4).nHeads(2).projectInput(true).build())
|
.layer( SelfAttentionLayer.builder().nOut(4).nHeads(2).projectInput(true).build())
|
||||||
.layer(GlobalPoolingLayer.builder().poolingType(PoolingType.MAX).build())
|
.layer(GlobalPoolingLayer.builder().poolingType(PoolingType.MAX).build())
|
||||||
.layer(OutputLayer.builder().nOut(2).activation(Activation.SOFTMAX)
|
.layer(OutputLayer.builder().nOut(2).activation(Activation.SOFTMAX)
|
||||||
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
||||||
|
|
|
@ -52,7 +52,7 @@ public class WeightInitIdentityTest extends BaseDL4JTest {
|
||||||
.graphBuilder()
|
.graphBuilder()
|
||||||
.addInputs(inputName)
|
.addInputs(inputName)
|
||||||
.setOutputs(output)
|
.setOutputs(output)
|
||||||
.layer(conv, new Convolution1DLayer.Builder(7)
|
.layer(conv, Convolution1DLayer.builder(7)
|
||||||
.convolutionMode(ConvolutionMode.Same)
|
.convolutionMode(ConvolutionMode.Same)
|
||||||
.nOut(input.size(1))
|
.nOut(input.size(1))
|
||||||
.weightInit(new WeightInitIdentity())
|
.weightInit(new WeightInitIdentity())
|
||||||
|
@ -115,7 +115,7 @@ public class WeightInitIdentityTest extends BaseDL4JTest {
|
||||||
.weightInit(new WeightInitIdentity())
|
.weightInit(new WeightInitIdentity())
|
||||||
.activation(new ActivationIdentity())
|
.activation(new ActivationIdentity())
|
||||||
.build(), inputName)
|
.build(), inputName)
|
||||||
.layer(output, new Cnn3DLossLayer.Builder(Convolution3D.DataFormat.NCDHW).activation(new ActivationIdentity()).build(), conv)
|
.layer(output, Cnn3DLossLayer.builder().dataFormat(Convolution3D.DataFormat.NCDHW).activation(new ActivationIdentity()).build(), conv)
|
||||||
.build());
|
.build());
|
||||||
graph.init();
|
graph.init();
|
||||||
|
|
||||||
|
|
|
@ -249,7 +249,7 @@ public class ModelGuesserTest extends BaseDL4JTest {
|
||||||
int nOut = 6;
|
int nOut = 6;
|
||||||
|
|
||||||
NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345).l1(0.01).l2(0.01)
|
NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345).l1(0.01).l2(0.01)
|
||||||
.updater(new Sgd(0.1)).activation(Activation.TANH).weightInit(WeightInit.XAVIER).list()
|
.updater(new Sgd(0.1)).activation(Activation.TANH).weightInit(WeightInit.XAVIER)
|
||||||
.layer(0, DenseLayer.builder().nIn(nIn).nOut(20).build())
|
.layer(0, DenseLayer.builder().nIn(nIn).nOut(20).build())
|
||||||
.layer(1, DenseLayer.builder().nIn(20).nOut(30).build()).layer(2, OutputLayer.builder()
|
.layer(1, DenseLayer.builder().nIn(20).nOut(30).build()).layer(2, OutputLayer.builder()
|
||||||
.lossFunction(LossFunctions.LossFunction.MSE).nIn(30).nOut(nOut).build())
|
.lossFunction(LossFunctions.LossFunction.MSE).nIn(30).nOut(nOut).build())
|
||||||
|
|
|
@ -56,10 +56,10 @@ public class KerasSpaceToDepth extends KerasLayer {
|
||||||
|
|
||||||
// TODO: we hard-code block size here to import YOLO9000. This size is not available as property
|
// TODO: we hard-code block size here to import YOLO9000. This size is not available as property
|
||||||
// in the hdf5 file outside of the serialized lambda function (that we can't really well deserialize).
|
// in the hdf5 file outside of the serialized lambda function (that we can't really well deserialize).
|
||||||
SpaceToDepthLayer.Builder builder = new SpaceToDepthLayer.Builder()
|
var builder = SpaceToDepthLayer.builder()
|
||||||
.blocks(2)
|
.blockSize(2)
|
||||||
//the default data format is tensorflow/NWHC for keras import
|
//the default data format is tensorflow/NWHC for keras import
|
||||||
.dataFormat(SpaceToDepthLayer.DataFormat.NHWC)
|
.dataFormat(SpaceToDepthLayer.DataFormat.NHWC.toFormat())
|
||||||
.name(name);
|
.name(name);
|
||||||
|
|
||||||
this.layer = builder.build();
|
this.layer = builder.build();
|
||||||
|
|
|
@ -63,7 +63,7 @@ public class KerasUpsampling3D extends KerasLayer {
|
||||||
int[] size = KerasConvolutionUtils.getUpsamplingSizeFromConfig(layerConfig, 3, conf);
|
int[] size = KerasConvolutionUtils.getUpsamplingSizeFromConfig(layerConfig, 3, conf);
|
||||||
// TODO: make sure to allow different sizes.
|
// TODO: make sure to allow different sizes.
|
||||||
|
|
||||||
Upsampling3D.Builder builder = new Upsampling3D.Builder()
|
var builder = Upsampling3D.builder()
|
||||||
.name(this.name)
|
.name(this.name)
|
||||||
.dropOut(this.dropout)
|
.dropOut(this.dropout)
|
||||||
.size(size[0]);
|
.size(size[0]);
|
||||||
|
|
|
@ -59,7 +59,7 @@ public class KerasLRN extends KerasLayer {
|
||||||
super(layerConfig, enforceTrainingConfig);
|
super(layerConfig, enforceTrainingConfig);
|
||||||
Map<String, Object> lrnParams = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
|
Map<String, Object> lrnParams = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf);
|
||||||
|
|
||||||
LocalResponseNormalization.Builder builder = LocalResponseNormalization.builder().name(this.name)
|
var builder = LocalResponseNormalization.builder().name(this.name)
|
||||||
.dropOut(this.dropout).alpha((double) lrnParams.get("alpha"))
|
.dropOut(this.dropout).alpha((double) lrnParams.get("alpha"))
|
||||||
.beta((double) lrnParams.get("beta")).k((int) lrnParams.get("k")).n((int) lrnParams.get("n"));
|
.beta((double) lrnParams.get("beta")).k((int) lrnParams.get("k")).n((int) lrnParams.get("n"));
|
||||||
this.layer = builder.build();
|
this.layer = builder.build();
|
||||||
|
|
|
@ -33,6 +33,7 @@ import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasConvolu
|
||||||
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
|
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
|
||||||
import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
|
import org.deeplearning4j.nn.params.ConvolutionParamInitializer;
|
||||||
import org.deeplearning4j.nn.weights.IWeightInit;
|
import org.deeplearning4j.nn.weights.IWeightInit;
|
||||||
|
import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
|
@ -98,7 +99,7 @@ public class KerasLocallyConnected1D extends KerasConvolution {
|
||||||
LocallyConnected1D.LocallyConnected1DBuilder builder = LocallyConnected1D.builder().name(this.name)
|
LocallyConnected1D.LocallyConnected1DBuilder builder = LocallyConnected1D.builder().name(this.name)
|
||||||
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
||||||
.activation(getActivationFromConfig(layerConfig, conf))
|
.activation(getActivationFromConfig(layerConfig, conf))
|
||||||
.weightInit(conf.getKERAS_PARAM_NAME_W(), init)
|
.weightInit(WeightInit.valueOf(conf.getKERAS_PARAM_NAME_W()))
|
||||||
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
|
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
|
||||||
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
|
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
|
||||||
.kernelSize(getKernelSizeFromConfig(layerConfig, 1, conf, kerasMajorVersion)[0])
|
.kernelSize(getKernelSizeFromConfig(layerConfig, 1, conf, kerasMajorVersion)[0])
|
||||||
|
|
|
@ -99,7 +99,7 @@ public class KerasLocallyConnected2D extends KerasConvolution {
|
||||||
LocallyConnected2D.LocallyConnected2DBuilder builder = LocallyConnected2D.builder().name(this.name)
|
LocallyConnected2D.LocallyConnected2DBuilder builder = LocallyConnected2D.builder().name(this.name)
|
||||||
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
.nOut(getNOutFromConfig(layerConfig, conf)).dropOut(this.dropout)
|
||||||
.activation(getActivationFromConfig(layerConfig, conf))
|
.activation(getActivationFromConfig(layerConfig, conf))
|
||||||
.weightInit(conf.getKERAS_PARAM_NAME_W(), init)
|
.weightInit(init.enumValue())
|
||||||
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
|
.l1(this.weightL1Regularization).l2(this.weightL2Regularization)
|
||||||
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
|
.convolutionMode(getConvolutionModeFromConfig(layerConfig, conf))
|
||||||
.kernelSize(getKernelSizeFromConfig(layerConfig, 2, conf, kerasMajorVersion))
|
.kernelSize(getKernelSizeFromConfig(layerConfig, 2, conf, kerasMajorVersion))
|
||||||
|
|
|
@ -130,12 +130,14 @@ public class KerasBatchNormalization extends KerasLayer {
|
||||||
BatchNormalization.BatchNormalizationBuilder builder =BatchNormalization.builder()
|
BatchNormalization.BatchNormalizationBuilder builder =BatchNormalization.builder()
|
||||||
.name(this.name)
|
.name(this.name)
|
||||||
.dropOut(this.dropout)
|
.dropOut(this.dropout)
|
||||||
.minibatch(true)
|
|
||||||
|
.isMinibatch(true)
|
||||||
.lockGammaBeta(false)
|
.lockGammaBeta(false)
|
||||||
.useLogStd(false)
|
.useLogStd(false)
|
||||||
.decay(getMomentumFromConfig(layerConfig))
|
.decay(getMomentumFromConfig(layerConfig))
|
||||||
.eps(getEpsFromConfig(layerConfig));
|
.eps(getEpsFromConfig(layerConfig));
|
||||||
if (betaConstraint != null)
|
if (betaConstraint != null)
|
||||||
|
|
||||||
builder.constrainBeta(betaConstraint);
|
builder.constrainBeta(betaConstraint);
|
||||||
if (gammaConstraint != null)
|
if (gammaConstraint != null)
|
||||||
builder.constrainGamma(gammaConstraint);
|
builder.constrainGamma(gammaConstraint);
|
||||||
|
|
|
@ -58,11 +58,11 @@ public class KerasModelImportTest extends BaseDL4JTest {
|
||||||
MultiLayerNetwork model = loadModel("modelimport/keras/weights/conv2dnchw/simpleconv2d.hdf5");
|
MultiLayerNetwork model = loadModel("modelimport/keras/weights/conv2dnchw/simpleconv2d.hdf5");
|
||||||
List<LayerConfiguration> layerConfigs = model.getNetConfiguration().getFlattenedLayerConfigurations();
|
List<LayerConfiguration> layerConfigs = model.getNetConfiguration().getFlattenedLayerConfigurations();
|
||||||
ConvolutionLayer convolutionLayer = (ConvolutionLayer) layerConfigs.get(0);
|
ConvolutionLayer convolutionLayer = (ConvolutionLayer) layerConfigs.get(0);
|
||||||
assertEquals(CNN2DFormat.NCHW,convolutionLayer.getDataFormat());
|
assertEquals(CNN2DFormat.NCHW,convolutionLayer.getConvFormat());
|
||||||
SubsamplingLayer subsamplingLayer = (SubsamplingLayer) layerConfigs.get(1);
|
SubsamplingLayer subsamplingLayer = (SubsamplingLayer) layerConfigs.get(1);
|
||||||
assertEquals(CNN2DFormat.NHWC,subsamplingLayer.getDataFormat());
|
assertEquals(CNN2DFormat.NHWC,subsamplingLayer.getDataFormat());
|
||||||
ConvolutionLayer convolutionLayer1 = (ConvolutionLayer) layerConfigs.get(2);
|
ConvolutionLayer convolutionLayer1 = (ConvolutionLayer) layerConfigs.get(2);
|
||||||
assertEquals(CNN2DFormat.NHWC,convolutionLayer1.getDataFormat());
|
assertEquals(CNN2DFormat.NHWC,convolutionLayer1.getConvFormat());
|
||||||
|
|
||||||
model.output(Nd4j.zeros(1,1,28,28));
|
model.output(Nd4j.zeros(1,1,28,28));
|
||||||
assertNotNull(model);
|
assertNotNull(model);
|
||||||
|
|
|
@ -60,8 +60,8 @@ public class KerasYolo9000PredictTest extends BaseDL4JTest {
|
||||||
|
|
||||||
ComputationGraph model = new TransferLearning.GraphBuilder(graph)
|
ComputationGraph model = new TransferLearning.GraphBuilder(graph)
|
||||||
.addLayer("outputs",
|
.addLayer("outputs",
|
||||||
new org.deeplearning4j.nn.conf.layers.objdetect.Yolo2OutputLayer.Builder()
|
org.deeplearning4j.nn.conf.layers.objdetect.Yolo2OutputLayer.builder()
|
||||||
.boundingBoxPriors(priors)
|
.boundingBoxes(priors)
|
||||||
.build(),
|
.build(),
|
||||||
"conv2d_23")
|
"conv2d_23")
|
||||||
.setOutputs("outputs")
|
.setOutputs("outputs")
|
||||||
|
|
|
@ -126,7 +126,7 @@ public class KerasLocallyConnected1DTest extends BaseDL4JTest {
|
||||||
assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0);
|
assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0);
|
||||||
assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0);
|
assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0);
|
||||||
assertEquals(new Dropout(DROPOUT_DL4J), layer.getDropOut());
|
assertEquals(new Dropout(DROPOUT_DL4J), layer.getDropOut());
|
||||||
assertEquals(KERNEL_SIZE, layer.getKernel());
|
assertEquals(KERNEL_SIZE, layer.getKernelSize());
|
||||||
assertEquals(STRIDE, layer.getStride());
|
assertEquals(STRIDE, layer.getStride());
|
||||||
assertEquals(N_OUT, layer.getNOut());
|
assertEquals(N_OUT, layer.getNOut());
|
||||||
assertEquals(ConvolutionMode.Truncate, layer.getConvolutionMode());
|
assertEquals(ConvolutionMode.Truncate, layer.getConvolutionMode());
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
|
|
||||||
package net.brutex.ai.dnn.api;
|
package net.brutex.ai.dnn.api;
|
||||||
|
|
||||||
|
|
||||||
public interface ILayerConfiguration {
|
public interface ILayerConfiguration {
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -561,13 +561,12 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
|
||||||
|
|
||||||
List<Object> innerConfigurations$value = new ArrayList<>(); // initialize with an empty list
|
List<Object> innerConfigurations$value = new ArrayList<>(); // initialize with an empty list
|
||||||
|
|
||||||
public B activation(IActivation activation) {
|
public B activation(Activation activation) {
|
||||||
this.activation = activation;
|
this.activation = activation;
|
||||||
return self();
|
return self();
|
||||||
}
|
}
|
||||||
|
public B activation(IActivation activation) {
|
||||||
public B activation(Activation activation) {
|
this.activation = activation;
|
||||||
this.activation = activation.getActivationFunction();
|
|
||||||
return self();
|
return self();
|
||||||
}
|
}
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -157,9 +157,9 @@ public class AttentionVertex extends SameDiffVertex {
|
||||||
val Wv = paramTable.get(WEIGHT_KEY_VALUE_PROJECTION);
|
val Wv = paramTable.get(WEIGHT_KEY_VALUE_PROJECTION);
|
||||||
val Wo = paramTable.get(WEIGHT_KEY_OUT_PROJECTION);
|
val Wo = paramTable.get(WEIGHT_KEY_OUT_PROJECTION);
|
||||||
|
|
||||||
attention = sameDiff.nn.multiHeadDotProductAttention(getLayerName(), queries, keys, values, Wq, Wk, Wv, Wo, mask, true);
|
attention = sameDiff.nn.multiHeadDotProductAttention(getName(), queries, keys, values, Wq, Wk, Wv, Wo, mask, true);
|
||||||
}else{
|
}else{
|
||||||
attention = sameDiff.nn.dotProductAttention(getLayerName(), queries, keys, values, mask, true);
|
attention = sameDiff.nn.dotProductAttention(getName(), queries, keys, values, mask, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
if(maskVars != null){
|
if(maskVars != null){
|
||||||
|
|
|
@ -53,11 +53,11 @@ public class ActivationLayer extends NoParamLayer {
|
||||||
public static ActivationLayerBuilder<?, ?> builder(Activation activation) {
|
public static ActivationLayerBuilder<?, ?> builder(Activation activation) {
|
||||||
return innerBuilder().activation(activation);
|
return innerBuilder().activation(activation);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static ActivationLayerBuilder<?, ?> builder(IActivation activation) {
|
public static ActivationLayerBuilder<?, ?> builder(IActivation activation) {
|
||||||
return innerBuilder().activation(activation);
|
return innerBuilder().activation(activation);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public static ActivationLayerBuilder<?, ?> builder() {
|
public static ActivationLayerBuilder<?, ?> builder() {
|
||||||
return innerBuilder();
|
return innerBuilder();
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,6 +25,7 @@ import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import lombok.*;
|
import lombok.*;
|
||||||
import lombok.experimental.SuperBuilder;
|
import lombok.experimental.SuperBuilder;
|
||||||
|
import org.deeplearning4j.nn.api.layers.LayerConstraint;
|
||||||
import net.brutex.ai.dnn.api.LayerType;
|
import net.brutex.ai.dnn.api.LayerType;
|
||||||
import org.deeplearning4j.nn.api.Layer;
|
import org.deeplearning4j.nn.api.Layer;
|
||||||
import org.deeplearning4j.nn.api.ParamInitializer;
|
import org.deeplearning4j.nn.api.ParamInitializer;
|
||||||
|
@ -80,19 +81,38 @@ public class BatchNormalization extends FeedForwardLayer {
|
||||||
@lombok.Builder.Default protected boolean isMinibatch = true;
|
@lombok.Builder.Default protected boolean isMinibatch = true;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Used only when 'true' is passed to {@link #lockGammaBeta(boolean)}. Value is not used otherwise.<br> Default:
|
* Used only when 'true' is passed to {@link BatchNormalizationBuilder#lockGammaBeta(boolean)}. Value is not used otherwise.<br> Default:
|
||||||
* 1.0
|
* 1.0
|
||||||
*
|
*
|
||||||
* @param gamma Gamma parameter for all activations, used only with locked gamma/beta configuration mode
|
* @param gamma Gamma parameter for all activations, used only with locked gamma/beta configuration mode
|
||||||
*/
|
*/
|
||||||
@lombok.Builder.Default protected double gamma = 1.0;
|
@lombok.Builder.Default protected double gamma = 1.0;
|
||||||
/**
|
/**
|
||||||
* Used only when 'true' is passed to {@link #lockGammaBeta(boolean)}. Value is not used otherwise.<br> Default:
|
* Used only when 'true' is passed to {@link BatchNormalizationBuilder#lockGammaBeta(boolean)}. Value is not used otherwise.<br> Default:
|
||||||
* 0.0
|
* 0.0
|
||||||
*
|
*
|
||||||
* @param beta Beta parameter for all activations, used only with locked gamma/beta configuration mode
|
* @param beta Beta parameter for all activations, used only with locked gamma/beta configuration mode
|
||||||
*/
|
*/
|
||||||
@lombok.Builder.Default protected double beta = 0.0;
|
@lombok.Builder.Default protected double beta = 0.0;
|
||||||
|
/**
|
||||||
|
* Set constraints to be applied to the beta parameter of this batch normalisation layer. Default: no
|
||||||
|
* constraints.<br> Constraints can be used to enforce certain conditions (non-negativity of parameters,
|
||||||
|
* max-norm regularization, etc). These constraints are applied at each iteration, after the parameters have
|
||||||
|
* been updated.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
protected List<LayerConstraint> betaConstraints;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set constraints to be applied to the gamma parameter of this batch normalisation layer. Default: no
|
||||||
|
* constraints.<br> Constraints can be used to enforce certain conditions (non-negativity of parameters,
|
||||||
|
* max-norm regularization, etc). These constraints are applied at each iteration, after the parameters have
|
||||||
|
* been updated.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
protected List<LayerConstraint> gammaConstraints;
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* When using CuDNN or MKLDNN and an error is encountered, should fallback to the non-helper implementation be allowed?
|
* When using CuDNN or MKLDNN and an error is encountered, should fallback to the non-helper implementation be allowed?
|
||||||
* If set to false, an exception in the helper will be propagated back to the user. If true, the built-in
|
* If set to false, an exception in the helper will be propagated back to the user. If true, the built-in
|
||||||
|
@ -298,6 +318,15 @@ public class BatchNormalization extends FeedForwardLayer {
|
||||||
this.cudnnAllowFallback$set = true;
|
this.cudnnAllowFallback$set = true;
|
||||||
return self();
|
return self();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public B constrainBeta(LayerConstraint ... constraints) {
|
||||||
|
this.betaConstraints = List.of(constraints);
|
||||||
|
return self();
|
||||||
|
}
|
||||||
|
public B constrainGamma(LayerConstraint ... constraints) {
|
||||||
|
this.gammaConstraints = List.of(constraints);
|
||||||
|
return self();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -51,6 +51,16 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder")
|
@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder")
|
||||||
public class Convolution1DLayer extends ConvolutionLayer {
|
public class Convolution1DLayer extends ConvolutionLayer {
|
||||||
@Builder.Default private RNNFormat rnnDataFormat = RNNFormat.NCW;
|
@Builder.Default private RNNFormat rnnDataFormat = RNNFormat.NCW;
|
||||||
|
/**
|
||||||
|
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
|
||||||
|
* See {@link CNN2DFormat} for more details.<br>
|
||||||
|
* Default: NCHW
|
||||||
|
*
|
||||||
|
* @param format Format for activations (in and out)
|
||||||
|
*/
|
||||||
|
@Builder.Default
|
||||||
|
protected CNN2DFormat dataFormat =
|
||||||
|
CNN2DFormat.NCHW; // default value for legacy serialization reasons
|
||||||
/**
|
/**
|
||||||
* Size of the convolution
|
* Size of the convolution
|
||||||
*
|
*
|
||||||
|
|
|
@ -26,10 +26,7 @@ import lombok.*;
|
||||||
import lombok.experimental.SuperBuilder;
|
import lombok.experimental.SuperBuilder;
|
||||||
import org.deeplearning4j.nn.api.Layer;
|
import org.deeplearning4j.nn.api.Layer;
|
||||||
import org.deeplearning4j.nn.api.ParamInitializer;
|
import org.deeplearning4j.nn.api.ParamInitializer;
|
||||||
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
import org.deeplearning4j.nn.conf.*;
|
||||||
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
|
||||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.layers.convolution.Convolution3DLayer;
|
import org.deeplearning4j.nn.layers.convolution.Convolution3DLayer;
|
||||||
import org.deeplearning4j.nn.params.Convolution3DParamInitializer;
|
import org.deeplearning4j.nn.params.Convolution3DParamInitializer;
|
||||||
|
|
|
@ -65,6 +65,18 @@ public class ConvolutionLayer extends FeedForwardLayer {
|
||||||
* details Default is {@link ConvolutionMode}.Truncate.
|
* details Default is {@link ConvolutionMode}.Truncate.
|
||||||
*/
|
*/
|
||||||
@Builder.Default protected ConvolutionMode convolutionMode = ConvolutionMode.Truncate;
|
@Builder.Default protected ConvolutionMode convolutionMode = ConvolutionMode.Truncate;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
|
||||||
|
* See {@link CNN2DFormat} for more details.<br>
|
||||||
|
* Default: NCHW
|
||||||
|
*
|
||||||
|
* @param format Format for activations (in and out)
|
||||||
|
*/
|
||||||
|
@Builder.Default
|
||||||
|
protected CNN2DFormat convFormat =
|
||||||
|
CNN2DFormat.NCHW; // default value for legacy serialization reasons
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Kernel dilation. Default: {1, 1}, which is standard convolutions. Used for implementing dilated
|
* Kernel dilation. Default: {1, 1}, which is standard convolutions. Used for implementing dilated
|
||||||
* convolutions, which are also known as atrous convolutions.
|
* convolutions, which are also known as atrous convolutions.
|
||||||
|
@ -86,16 +98,7 @@ public class ConvolutionLayer extends FeedForwardLayer {
|
||||||
* false, the built-in (non-CuDNN) implementation for ConvolutionLayer will be used
|
* false, the built-in (non-CuDNN) implementation for ConvolutionLayer will be used
|
||||||
*/
|
*/
|
||||||
@Builder.Default protected boolean cudnnAllowFallback = true;
|
@Builder.Default protected boolean cudnnAllowFallback = true;
|
||||||
/**
|
|
||||||
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
|
|
||||||
* See {@link CNN2DFormat} for more details.<br>
|
|
||||||
* Default: NCHW
|
|
||||||
*
|
|
||||||
* @param format Format for activations (in and out)
|
|
||||||
*/
|
|
||||||
@Builder.Default
|
|
||||||
protected CNN2DFormat dataFormat =
|
|
||||||
CNN2DFormat.NCHW; // default value for legacy serialization reasons
|
|
||||||
/** Defaults to "PREFER_FASTEST", but "NO_WORKSPACE" uses less memory. */
|
/** Defaults to "PREFER_FASTEST", but "NO_WORKSPACE" uses less memory. */
|
||||||
@Builder.Default protected AlgoMode cudnnAlgoMode = AlgoMode.PREFER_FASTEST;
|
@Builder.Default protected AlgoMode cudnnAlgoMode = AlgoMode.PREFER_FASTEST;
|
||||||
|
|
||||||
|
@ -179,7 +182,7 @@ public class ConvolutionLayer extends FeedForwardLayer {
|
||||||
nOut,
|
nOut,
|
||||||
layerIndex,
|
layerIndex,
|
||||||
getName(),
|
getName(),
|
||||||
dataFormat,
|
convFormat,
|
||||||
ConvolutionLayer.class);
|
ConvolutionLayer.class);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -196,11 +199,11 @@ public class ConvolutionLayer extends FeedForwardLayer {
|
||||||
if (!defaultValueOverriden || nIn <= 0 || override) {
|
if (!defaultValueOverriden || nIn <= 0 || override) {
|
||||||
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
|
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
|
||||||
this.nIn = c.getChannels();
|
this.nIn = c.getChannels();
|
||||||
this.dataFormat = ((InputType.InputTypeConvolutional) inputType).getFormat();
|
this.convFormat = ((InputType.InputTypeConvolutional) inputType).getFormat();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (dataFormat == null || override)
|
if (convFormat == null || override)
|
||||||
this.dataFormat = ((InputType.InputTypeConvolutional) inputType).getFormat();
|
this.convFormat = ((InputType.InputTypeConvolutional) inputType).getFormat();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -53,6 +53,16 @@ public class DepthwiseConvolution2D extends ConvolutionLayer {
|
||||||
*/
|
*/
|
||||||
@Builder.Default
|
@Builder.Default
|
||||||
protected int depthMultiplier = 1;
|
protected int depthMultiplier = 1;
|
||||||
|
/**
|
||||||
|
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
|
||||||
|
* See {@link CNN2DFormat} for more details.<br>
|
||||||
|
* Default: NCHW
|
||||||
|
*
|
||||||
|
* @param format Format for activations (in and out)
|
||||||
|
*/
|
||||||
|
@Builder.Default
|
||||||
|
protected CNN2DFormat dataFormat =
|
||||||
|
CNN2DFormat.NCHW; // default value for legacy serialization reasons
|
||||||
/**
|
/**
|
||||||
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
|
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
|
||||||
* See {@link CNN2DFormat} for more details.<br>
|
* See {@link CNN2DFormat} for more details.<br>
|
||||||
|
|
|
@ -36,6 +36,7 @@ import org.deeplearning4j.nn.weights.embeddings.ArrayEmbeddingInitializer;
|
||||||
import org.deeplearning4j.nn.weights.embeddings.EmbeddingInitializer;
|
import org.deeplearning4j.nn.weights.embeddings.EmbeddingInitializer;
|
||||||
import org.deeplearning4j.nn.weights.embeddings.WeightInitEmbedding;
|
import org.deeplearning4j.nn.weights.embeddings.WeightInitEmbedding;
|
||||||
import org.deeplearning4j.optimize.api.TrainingListener;
|
import org.deeplearning4j.optimize.api.TrainingListener;
|
||||||
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.activations.impl.ActivationIdentity;
|
import org.nd4j.linalg.activations.impl.ActivationIdentity;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -63,7 +64,7 @@ public class EmbeddingLayer extends FeedForwardLayer {
|
||||||
*/
|
*/
|
||||||
public static EmbeddingLayerBuilder<?, ?> builder() {
|
public static EmbeddingLayerBuilder<?, ?> builder() {
|
||||||
return innerBuilder()
|
return innerBuilder()
|
||||||
.activation(new ActivationIdentity());
|
.activation(Activation.IDENTITY);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static abstract class EmbeddingLayerBuilder<C extends EmbeddingLayer, B extends EmbeddingLayerBuilder<C,B>>
|
public static abstract class EmbeddingLayerBuilder<C extends EmbeddingLayer, B extends EmbeddingLayerBuilder<C,B>>
|
||||||
|
|
|
@ -88,7 +88,7 @@ public abstract class LayerConfiguration
|
||||||
* Activation#getActivationFunction()} but not vice versa. The default is Identity Activation.
|
* Activation#getActivationFunction()} but not vice versa. The default is Identity Activation.
|
||||||
*/
|
*/
|
||||||
@Builder.Default
|
@Builder.Default
|
||||||
@Getter @Setter private IActivation activation = new ActivationIdentity();
|
@Getter @Setter private IActivation activation = Activation.IDENTITY;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the activation interface (function) from the activation. The activation must have been set
|
* Get the activation interface (function) from the activation. The activation must have been set
|
||||||
|
@ -335,7 +335,7 @@ public abstract class LayerConfiguration
|
||||||
|
|
||||||
public static abstract class LayerConfigurationBuilder<C extends LayerConfiguration, B extends LayerConfigurationBuilder<C, B>> {
|
public static abstract class LayerConfigurationBuilder<C extends LayerConfiguration, B extends LayerConfigurationBuilder<C, B>> {
|
||||||
public B activation(Activation activation) {
|
public B activation(Activation activation) {
|
||||||
this.activation$value = activation.getActivationFunction();
|
this.activation$value = activation;
|
||||||
this.activation$set = true;
|
this.activation$set = true;
|
||||||
return self();
|
return self();
|
||||||
}
|
}
|
||||||
|
@ -344,6 +344,7 @@ public abstract class LayerConfiguration
|
||||||
this.activation$set = true;
|
this.activation$set = true;
|
||||||
return self();
|
return self();
|
||||||
}
|
}
|
||||||
|
|
||||||
public B dropOut(double d) {
|
public B dropOut(double d) {
|
||||||
this.dropOut = new Dropout(d);
|
this.dropOut = new Dropout(d);
|
||||||
return self();
|
return self();
|
||||||
|
@ -352,6 +353,14 @@ public abstract class LayerConfiguration
|
||||||
this.dropOut = d;
|
this.dropOut = d;
|
||||||
return self();
|
return self();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public B constrainBias(LayerConstraint constraint) {
|
||||||
|
return this.biasConstraints(List.of(constraint));
|
||||||
|
}
|
||||||
|
|
||||||
|
public B constrainWeights(LayerConstraint constraint) {
|
||||||
|
return this.weightConstraints(List.of(constraint));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,7 +20,9 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.layers;
|
package org.deeplearning4j.nn.conf.layers;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
import lombok.*;
|
import lombok.*;
|
||||||
|
import lombok.experimental.SuperBuilder;
|
||||||
import org.deeplearning4j.nn.api.MaskState;
|
import org.deeplearning4j.nn.api.MaskState;
|
||||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||||
import org.deeplearning4j.nn.conf.RNNFormat;
|
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||||
|
@ -32,39 +34,35 @@ import org.nd4j.autodiff.samediff.SDIndex;
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.common.base.Preconditions;
|
import org.nd4j.common.base.Preconditions;
|
||||||
|
import org.nd4j.common.primitives.Pair;
|
||||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.common.primitives.Pair;
|
|
||||||
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
|
@SuperBuilder(buildMethodName = "initBuild")
|
||||||
public class LearnedSelfAttentionLayer extends SameDiffLayer {
|
public class LearnedSelfAttentionLayer extends SameDiffLayer {
|
||||||
private long nIn;
|
|
||||||
private long nOut;
|
|
||||||
private int nHeads;
|
|
||||||
private long headSize;
|
|
||||||
private boolean projectInput;
|
|
||||||
private int nQueries;
|
|
||||||
|
|
||||||
private static final String WEIGHT_KEY_QUERY_PROJECTION = "Wq";
|
private static final String WEIGHT_KEY_QUERY_PROJECTION = "Wq";
|
||||||
private static final String WEIGHT_KEY_KEY_PROJECTION = "Wk";
|
private static final String WEIGHT_KEY_KEY_PROJECTION = "Wk";
|
||||||
private static final String WEIGHT_KEY_VALUE_PROJECTION = "Wv";
|
private static final String WEIGHT_KEY_VALUE_PROJECTION = "Wv";
|
||||||
private static final String WEIGHT_KEY_OUT_PROJECTION = "Wo";
|
private static final String WEIGHT_KEY_OUT_PROJECTION = "Wo";
|
||||||
private static final String WEIGHT_QUERIES = "Q";
|
private static final String WEIGHT_QUERIES = "Q";
|
||||||
|
/** Number of inputs to the layer (input size) */
|
||||||
|
private int nIn;
|
||||||
|
/** Number of outputs (output size) */
|
||||||
|
private int nOut;
|
||||||
|
/** Number of Attention Heads */
|
||||||
|
private int nHeads;
|
||||||
|
/** Size of attention heads */
|
||||||
|
private int headSize;
|
||||||
|
/** Project input before applying attention or not. */
|
||||||
|
private boolean projectInput;
|
||||||
|
/** Number of queries to learn */
|
||||||
|
private int nQueries;
|
||||||
|
|
||||||
private LearnedSelfAttentionLayer(){/*No arg constructor for serialization*/}
|
private LearnedSelfAttentionLayer() {
|
||||||
|
/*No arg constructor for serialization*/
|
||||||
protected LearnedSelfAttentionLayer(Builder builder){
|
|
||||||
super(builder);
|
|
||||||
nIn = builder.nIn;
|
|
||||||
nOut = builder.nOut;
|
|
||||||
nHeads = builder.nHeads;
|
|
||||||
headSize = builder.headSize == 0 ? nOut / nHeads : builder.headSize;
|
|
||||||
projectInput = builder.projectInput;
|
|
||||||
nQueries = builder.nQueries;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -75,27 +73,34 @@ public class LearnedSelfAttentionLayer extends SameDiffLayer {
|
||||||
@Override
|
@Override
|
||||||
public void setNIn(InputType inputType, boolean override) {
|
public void setNIn(InputType inputType, boolean override) {
|
||||||
if (inputType == null || inputType.getType() != InputType.Type.RNN) {
|
if (inputType == null || inputType.getType() != InputType.Type.RNN) {
|
||||||
throw new IllegalStateException("Invalid input for Learned Self Attention layer (layer name = \"" + getName()
|
throw new IllegalStateException(
|
||||||
+ "\"): expect RNN input type with size > 0. Got: " + inputType);
|
"Invalid input for Learned Self Attention layer (layer name = \""
|
||||||
|
+ getName()
|
||||||
|
+ "\"): expect RNN input type with size > 0. Got: "
|
||||||
|
+ inputType);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (nIn <= 0 || override) {
|
if (nIn <= 0 || override) {
|
||||||
InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
|
InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
|
||||||
this.nIn = r.getSize();
|
this.nIn = (int) r.getSize();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public InputType getOutputType(int layerIndex, InputType inputType) {
|
public InputType getOutputType(int layerIndex, InputType inputType) {
|
||||||
if (inputType == null || inputType.getType() != InputType.Type.RNN) {
|
if (inputType == null || inputType.getType() != InputType.Type.RNN) {
|
||||||
throw new IllegalStateException("Invalid input for Learned Self Attention layer (layer index = " + layerIndex
|
throw new IllegalStateException(
|
||||||
+ ", layer name = \"" + getName() + "\"): expect RNN input type with size > 0. Got: "
|
"Invalid input for Learned Self Attention layer (layer index = "
|
||||||
|
+ layerIndex
|
||||||
|
+ ", layer name = \""
|
||||||
|
+ getName()
|
||||||
|
+ "\"): expect RNN input type with size > 0. Got: "
|
||||||
+ inputType);
|
+ inputType);
|
||||||
}
|
}
|
||||||
|
|
||||||
if(projectInput){
|
if (projectInput) {
|
||||||
return InputType.recurrent(nOut, nQueries);
|
return InputType.recurrent(nOut, nQueries);
|
||||||
}else{
|
} else {
|
||||||
return InputType.recurrent(nIn, nQueries);
|
return InputType.recurrent(nIn, nQueries);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -106,7 +111,7 @@ public class LearnedSelfAttentionLayer extends SameDiffLayer {
|
||||||
|
|
||||||
params.addWeightParam(WEIGHT_QUERIES, 1, nIn, nQueries);
|
params.addWeightParam(WEIGHT_QUERIES, 1, nIn, nQueries);
|
||||||
|
|
||||||
if(projectInput){
|
if (projectInput) {
|
||||||
params.addWeightParam(WEIGHT_KEY_QUERY_PROJECTION, nHeads, headSize, nIn);
|
params.addWeightParam(WEIGHT_KEY_QUERY_PROJECTION, nHeads, headSize, nIn);
|
||||||
params.addWeightParam(WEIGHT_KEY_KEY_PROJECTION, nHeads, headSize, nIn);
|
params.addWeightParam(WEIGHT_KEY_KEY_PROJECTION, nHeads, headSize, nIn);
|
||||||
params.addWeightParam(WEIGHT_KEY_VALUE_PROJECTION, nHeads, headSize, nIn);
|
params.addWeightParam(WEIGHT_KEY_VALUE_PROJECTION, nHeads, headSize, nIn);
|
||||||
|
@ -118,137 +123,72 @@ public class LearnedSelfAttentionLayer extends SameDiffLayer {
|
||||||
public void initializeParameters(Map<String, INDArray> params) {
|
public void initializeParameters(Map<String, INDArray> params) {
|
||||||
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
||||||
for (Map.Entry<String, INDArray> e : params.entrySet()) {
|
for (Map.Entry<String, INDArray> e : params.entrySet()) {
|
||||||
if(e.getKey().equals(WEIGHT_KEY_OUT_PROJECTION)){
|
if (e.getKey().equals(WEIGHT_KEY_OUT_PROJECTION)) {
|
||||||
WeightInitUtil.initWeights(nIn, headSize, e.getValue().shape(), weightInit, null, 'c', e.getValue());
|
WeightInitUtil.initWeights(
|
||||||
}else if(e.getKey().equals(WEIGHT_QUERIES)){
|
nIn, headSize, e.getValue().shape(), weightInit, null, 'c', e.getValue());
|
||||||
WeightInitUtil.initWeights(nIn, nQueries, e.getValue().shape(), weightInit, null, 'c', e.getValue());
|
} else if (e.getKey().equals(WEIGHT_QUERIES)) {
|
||||||
}else{
|
WeightInitUtil.initWeights(
|
||||||
WeightInitUtil.initWeights(nHeads * headSize, nOut, e.getValue().shape(), weightInit, null, 'c', e.getValue());
|
nIn, nQueries, e.getValue().shape(), weightInit, null, 'c', e.getValue());
|
||||||
|
} else {
|
||||||
|
WeightInitUtil.initWeights(
|
||||||
|
nHeads * headSize, nOut, e.getValue().shape(), weightInit, null, 'c', e.getValue());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, Map<String, SDVariable> paramTable, SDVariable mask) {
|
public SDVariable defineLayer(
|
||||||
|
SameDiff sameDiff,
|
||||||
|
SDVariable layerInput,
|
||||||
|
Map<String, SDVariable> paramTable,
|
||||||
|
SDVariable mask) {
|
||||||
val baseQueries = paramTable.get(WEIGHT_QUERIES);
|
val baseQueries = paramTable.get(WEIGHT_QUERIES);
|
||||||
val batchSize = layerInput.shape().get(SDIndex.point(0));
|
val batchSize = layerInput.shape().get(SDIndex.point(0));
|
||||||
val tileAxis = sameDiff.scatterUpdate(sameDiff.onesLike(layerInput.shape()), sameDiff.constant(0), batchSize);
|
val tileAxis =
|
||||||
|
sameDiff.scatterUpdate(
|
||||||
|
sameDiff.onesLike(layerInput.shape()), sameDiff.constant(0), batchSize);
|
||||||
|
|
||||||
val queries = sameDiff.tile(baseQueries, tileAxis);
|
val queries = sameDiff.tile(baseQueries, tileAxis);
|
||||||
|
|
||||||
if(projectInput){
|
if (projectInput) {
|
||||||
val Wq = paramTable.get(WEIGHT_KEY_QUERY_PROJECTION);
|
val Wq = paramTable.get(WEIGHT_KEY_QUERY_PROJECTION);
|
||||||
val Wk = paramTable.get(WEIGHT_KEY_KEY_PROJECTION);
|
val Wk = paramTable.get(WEIGHT_KEY_KEY_PROJECTION);
|
||||||
val Wv = paramTable.get(WEIGHT_KEY_VALUE_PROJECTION);
|
val Wv = paramTable.get(WEIGHT_KEY_VALUE_PROJECTION);
|
||||||
val Wo = paramTable.get(WEIGHT_KEY_OUT_PROJECTION);
|
val Wo = paramTable.get(WEIGHT_KEY_OUT_PROJECTION);
|
||||||
|
|
||||||
return sameDiff.nn.multiHeadDotProductAttention(getName(), queries, layerInput, layerInput, Wq, Wk, Wv, Wo, mask, true);
|
return sameDiff.nn.multiHeadDotProductAttention(
|
||||||
}else{
|
getName(), queries, layerInput, layerInput, Wq, Wk, Wv, Wo, mask, true);
|
||||||
return sameDiff.nn.dotProductAttention(getName(), queries, layerInput, layerInput, mask, true);
|
} else {
|
||||||
|
return sameDiff.nn.dotProductAttention(
|
||||||
|
getName(), queries, layerInput, layerInput, mask, true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
|
public Pair<INDArray, MaskState> feedForwardMaskArray(
|
||||||
// No further mask propagation here, as the results have taken any mask into account, like in a global pooling layer
|
INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
|
||||||
|
// No further mask propagation here, as the results have taken any mask into account, like in a
|
||||||
|
// global pooling layer
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Getter
|
public abstract static class LearnedSelfAttentionLayerBuilder<
|
||||||
@Setter
|
C extends LearnedSelfAttentionLayer, B extends LearnedSelfAttentionLayerBuilder<C, B>>
|
||||||
public static class Builder extends SameDiffLayer.Builder<LearnedSelfAttentionLayer.Builder> {
|
extends SameDiffLayerBuilder<C, B> {
|
||||||
|
public C build() {
|
||||||
/**
|
Preconditions.checkArgument(
|
||||||
* Number of inputs to the layer (input size)
|
this.projectInput || this.nHeads == 1, "projectInput must be true when nHeads != 1");
|
||||||
*/
|
Preconditions.checkArgument(
|
||||||
private int nIn;
|
this.projectInput || nIn == nOut, "nIn must be equal to nOut when projectInput is false");
|
||||||
|
Preconditions.checkArgument(
|
||||||
/**
|
!this.projectInput || nOut != 0, "nOut must be specified when projectInput is true");
|
||||||
* Number of outputs (output size)
|
Preconditions.checkArgument(
|
||||||
*/
|
this.nOut % nHeads == 0 || headSize > 0,
|
||||||
private int nOut;
|
"nOut isn't divided by nHeads cleanly. Specify the headSize manually.");
|
||||||
|
|
||||||
/**
|
|
||||||
* Number of Attention Heads
|
|
||||||
*/
|
|
||||||
private int nHeads;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Size of attention heads
|
|
||||||
*/
|
|
||||||
private int headSize;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Project input before applying attention or not.
|
|
||||||
*/
|
|
||||||
private boolean projectInput;
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Number of queries to learn
|
|
||||||
*/
|
|
||||||
private int nQueries;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param nIn Number of inputs to the layer (input size)
|
|
||||||
*/
|
|
||||||
public Builder nIn(int nIn) {
|
|
||||||
this.nIn = nIn;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param nOut Number of outputs (output size)
|
|
||||||
*/
|
|
||||||
public Builder nOut(int nOut) {
|
|
||||||
this.nOut = nOut;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Number of Attention Heads
|
|
||||||
*/
|
|
||||||
public Builder nHeads(int nHeads){
|
|
||||||
this.nHeads = nHeads;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Size of attention heads
|
|
||||||
*/
|
|
||||||
public Builder headSize(int headSize){
|
|
||||||
this.headSize = headSize;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Project input before applying attention or not.
|
|
||||||
*/
|
|
||||||
public Builder projectInput(boolean projectInput){
|
|
||||||
this.projectInput = projectInput;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Number of queries to learn
|
|
||||||
*/
|
|
||||||
public Builder nQueries(int nQueries){
|
|
||||||
this.nQueries = nQueries;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
@SuppressWarnings("unchecked")
|
|
||||||
public LearnedSelfAttentionLayer build() {
|
|
||||||
Preconditions.checkArgument(this.projectInput || this.nHeads == 1, "projectInput must be true when nHeads != 1");
|
|
||||||
Preconditions.checkArgument(this.projectInput || nIn == nOut, "nIn must be equal to nOut when projectInput is false");
|
|
||||||
Preconditions.checkArgument(!this.projectInput || nOut != 0, "nOut must be specified when projectInput is true");
|
|
||||||
Preconditions.checkArgument(this.nOut % nHeads == 0 || headSize > 0, "nOut isn't divided by nHeads cleanly. Specify the headSize manually.");
|
|
||||||
Preconditions.checkArgument(this.nQueries > 0, "You must set numQueries.");
|
Preconditions.checkArgument(this.nQueries > 0, "You must set numQueries.");
|
||||||
|
|
||||||
return new LearnedSelfAttentionLayer(this);
|
return initBuild();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -41,6 +41,7 @@ import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.enums.PadMode;
|
import org.nd4j.enums.PadMode;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
|
import org.nd4j.linalg.activations.IActivation;
|
||||||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
@ -333,6 +334,11 @@ public class LocallyConnected2D extends SameDiffLayer {
|
||||||
return self();
|
return self();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public B inputSize(int ... size) {
|
||||||
|
this.inputSize = size;
|
||||||
|
return self();
|
||||||
|
}
|
||||||
|
|
||||||
public B stride(int ... stride) {
|
public B stride(int ... stride) {
|
||||||
this.stride$value = ValidationUtils.validate2NonNegative(stride, false, "stride");
|
this.stride$value = ValidationUtils.validate2NonNegative(stride, false, "stride");
|
||||||
this.stride$set = true;
|
this.stride$set = true;
|
||||||
|
|
|
@ -20,7 +20,9 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.layers;
|
package org.deeplearning4j.nn.conf.layers;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
import lombok.*;
|
import lombok.*;
|
||||||
|
import lombok.experimental.SuperBuilder;
|
||||||
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
import org.deeplearning4j.nn.conf.ConvolutionMode;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType.InputTypeConvolutional;
|
import org.deeplearning4j.nn.conf.inputs.InputType.InputTypeConvolutional;
|
||||||
|
@ -37,87 +39,191 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
|
@SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder")
|
||||||
public class PrimaryCapsules extends SameDiffLayer {
|
public class PrimaryCapsules extends SameDiffLayer {
|
||||||
|
|
||||||
private int[] kernelSize;
|
|
||||||
private int[] stride;
|
|
||||||
private int[] padding;
|
|
||||||
private int[] dilation;
|
|
||||||
private int inputChannels;
|
|
||||||
private int channels;
|
|
||||||
|
|
||||||
private boolean hasBias;
|
|
||||||
|
|
||||||
private int capsules;
|
|
||||||
private int capsuleDimensions;
|
|
||||||
|
|
||||||
private ConvolutionMode convolutionMode = ConvolutionMode.Truncate;
|
|
||||||
|
|
||||||
private boolean useRelu = false;
|
|
||||||
private double leak = 0;
|
|
||||||
|
|
||||||
private static final String WEIGHT_PARAM = "weight";
|
private static final String WEIGHT_PARAM = "weight";
|
||||||
private static final String BIAS_PARAM = "bias";
|
private static final String BIAS_PARAM = "bias";
|
||||||
|
/**
|
||||||
|
* Sets the kernel size of the 2d convolution
|
||||||
|
*
|
||||||
|
* @param kernelSize
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
@Builder.Default private int[] kernelSize = new int[] {9, 9};
|
||||||
|
/**
|
||||||
|
* Sets the stride of the 2d convolution
|
||||||
|
*
|
||||||
|
* @param stride
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
@Builder.Default private int[] stride = new int[] {2, 2};
|
||||||
|
/**
|
||||||
|
* Sets the padding of the 2d convolution
|
||||||
|
*
|
||||||
|
* @param padding
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
@Builder.Default private int[] padding = new int[] {0, 0};
|
||||||
|
/**
|
||||||
|
* Sets the dilation of the 2d convolution
|
||||||
|
*
|
||||||
|
* @param dilation
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
@Builder.Default private int[] dilation = new int[] {1, 1};
|
||||||
|
|
||||||
public PrimaryCapsules(Builder builder){
|
private int inputChannels;
|
||||||
super(builder);
|
/**
|
||||||
|
* Sets the number of channels to use in the 2d convolution.
|
||||||
|
*
|
||||||
|
* <p>Note that the actual number of channels is channels * capsuleDimensions
|
||||||
|
*
|
||||||
|
* <p>Does the same thing as nOut()
|
||||||
|
*
|
||||||
|
* @param channels
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
@Builder.Default private int channels = 32;
|
||||||
|
|
||||||
this.kernelSize = builder.kernelSize;
|
@Builder.Default private boolean hasBias = true;
|
||||||
this.stride = builder.stride;
|
/**
|
||||||
this.padding = builder.padding;
|
* Usually inferred automatically.
|
||||||
this.dilation = builder.dilation;
|
*
|
||||||
this.channels = builder.channels;
|
* @param capsules
|
||||||
this.hasBias = builder.hasBias;
|
* @return
|
||||||
this.capsules = builder.capsules;
|
*/
|
||||||
this.capsuleDimensions = builder.capsuleDimensions;
|
private int capsules;
|
||||||
this.convolutionMode = builder.convolutionMode;
|
/**
|
||||||
this.useRelu = builder.useRelu;
|
* Sets the number of dimensions to use in the capsules.
|
||||||
this.leak = builder.leak;
|
*
|
||||||
|
* @param capsuleDimensions
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
private int capsuleDimensions;
|
||||||
|
/**
|
||||||
|
* The convolution mode to use in the 2d convolution
|
||||||
|
*
|
||||||
|
* @param convolutionMode
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
@Builder.Default private ConvolutionMode convolutionMode = ConvolutionMode.Truncate;
|
||||||
|
/**
|
||||||
|
* Whether to use a ReLU activation on the 2d convolution
|
||||||
|
*
|
||||||
|
* @param useRelu
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
@Builder.Default private boolean useRelU = false;
|
||||||
|
/**
|
||||||
|
* Use a LeakyReLU activation on the 2d convolution
|
||||||
|
*
|
||||||
|
* @param leak the alpha value for the LeakyReLU activation.
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
@Builder.Default private double useLeakyReLU = 0;
|
||||||
|
|
||||||
if(capsuleDimensions <= 0 || channels <= 0){
|
public static PrimaryCapsulesBuilder<?, ?> builder() {
|
||||||
throw new IllegalArgumentException("Invalid configuration for Primary Capsules (layer name = \""
|
return innerBuilder();
|
||||||
+ name + "\"):"
|
|
||||||
+ " capsuleDimensions and channels must be > 0. Got: "
|
|
||||||
+ capsuleDimensions + ", " + channels);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if(capsules < 0){
|
public static PrimaryCapsulesBuilder<?, ?> builder(
|
||||||
throw new IllegalArgumentException("Invalid configuration for Capsule ILayer (layer name = \""
|
int capsuleDimensions,
|
||||||
+ name + "\"):"
|
int channels,
|
||||||
+ " capsules must be >= 0 if set. Got: "
|
int[] kernelSize,
|
||||||
+ capsules);
|
int[] stride,
|
||||||
|
int[] padding,
|
||||||
|
int[] dilation,
|
||||||
|
ConvolutionMode convolutionMode) {
|
||||||
|
return innerBuilder()
|
||||||
|
.capsuleDimensions(capsuleDimensions)
|
||||||
|
.channels(channels)
|
||||||
|
.kernelSize(kernelSize)
|
||||||
|
.stride(stride)
|
||||||
|
.padding(padding)
|
||||||
|
.dilation(dilation)
|
||||||
|
.convolutionMode(convolutionMode);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static PrimaryCapsulesBuilder<?, ?> builder(
|
||||||
|
int capsuleDimensions,
|
||||||
|
int channels,
|
||||||
|
int[] kernelSize,
|
||||||
|
int[] stride,
|
||||||
|
int[] padding,
|
||||||
|
int[] dilation) {
|
||||||
|
return innerBuilder()
|
||||||
|
.capsuleDimensions(capsuleDimensions)
|
||||||
|
.channels(channels)
|
||||||
|
.kernelSize(kernelSize)
|
||||||
|
.stride(stride)
|
||||||
|
.padding(padding)
|
||||||
|
.dilation(dilation);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PrimaryCapsulesBuilder<?, ?> builder(
|
||||||
|
int capsuleDimensions, int channels, int[] kernelSize, int[] stride, int[] padding) {
|
||||||
|
return innerBuilder()
|
||||||
|
.capsuleDimensions(capsuleDimensions)
|
||||||
|
.channels(channels)
|
||||||
|
.kernelSize(kernelSize)
|
||||||
|
.stride(stride)
|
||||||
|
.padding(padding);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PrimaryCapsulesBuilder<?, ?> builder(
|
||||||
|
int capsuleDimensions, int channels, int[] kernelSize, int[] stride) {
|
||||||
|
return innerBuilder()
|
||||||
|
.capsuleDimensions(capsuleDimensions)
|
||||||
|
.channels(channels)
|
||||||
|
.kernelSize(kernelSize)
|
||||||
|
.stride(stride);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PrimaryCapsulesBuilder<?, ?> builder(
|
||||||
|
int capsuleDimensions, int channels, int[] kernelSize) {
|
||||||
|
return innerBuilder()
|
||||||
|
.capsuleDimensions(capsuleDimensions)
|
||||||
|
.channels(channels)
|
||||||
|
.kernelSize(kernelSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static PrimaryCapsulesBuilder<?, ?> builder(int capsuleDimensions, int channels) {
|
||||||
|
return innerBuilder().capsuleDimensions(capsuleDimensions).channels(channels);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public SDVariable defineLayer(SameDiff SD, SDVariable input, Map<String, SDVariable> paramTable, SDVariable mask) {
|
public SDVariable defineLayer(
|
||||||
Conv2DConfig conf = Conv2DConfig.builder()
|
SameDiff SD, SDVariable input, Map<String, SDVariable> paramTable, SDVariable mask) {
|
||||||
.kH(kernelSize[0]).kW(kernelSize[1])
|
Conv2DConfig conf =
|
||||||
.sH(stride[0]).sW(stride[1])
|
Conv2DConfig.builder()
|
||||||
.pH(padding[0]).pW(padding[1])
|
.kH(kernelSize[0])
|
||||||
.dH(dilation[0]).dW(dilation[1])
|
.kW(kernelSize[1])
|
||||||
|
.sH(stride[0])
|
||||||
|
.sW(stride[1])
|
||||||
|
.pH(padding[0])
|
||||||
|
.pW(padding[1])
|
||||||
|
.dH(dilation[0])
|
||||||
|
.dW(dilation[1])
|
||||||
.isSameMode(convolutionMode == ConvolutionMode.Same)
|
.isSameMode(convolutionMode == ConvolutionMode.Same)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
SDVariable conved;
|
SDVariable conved;
|
||||||
|
|
||||||
if(hasBias){
|
if (hasBias) {
|
||||||
conved = SD.cnn.conv2d(input, paramTable.get(WEIGHT_PARAM), paramTable.get(BIAS_PARAM), conf);
|
conved = SD.cnn.conv2d(input, paramTable.get(WEIGHT_PARAM), paramTable.get(BIAS_PARAM), conf);
|
||||||
} else {
|
} else {
|
||||||
conved = SD.cnn.conv2d(input, paramTable.get(WEIGHT_PARAM), conf);
|
conved = SD.cnn.conv2d(input, paramTable.get(WEIGHT_PARAM), conf);
|
||||||
}
|
}
|
||||||
|
|
||||||
if(useRelu){
|
if (useRelU) {
|
||||||
if(leak == 0) {
|
if (useLeakyReLU == 0) {
|
||||||
conved = SD.nn.relu(conved, 0);
|
conved = SD.nn.relu(conved, 0);
|
||||||
} else {
|
} else {
|
||||||
conved = SD.nn.leakyRelu(conved, leak);
|
conved = SD.nn.leakyRelu(conved, useLeakyReLU);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -128,10 +234,14 @@ public class PrimaryCapsules extends SameDiffLayer {
|
||||||
@Override
|
@Override
|
||||||
public void defineParameters(SDLayerParams params) {
|
public void defineParameters(SDLayerParams params) {
|
||||||
params.clear();
|
params.clear();
|
||||||
params.addWeightParam(WEIGHT_PARAM,
|
params.addWeightParam(
|
||||||
kernelSize[0], kernelSize[1], inputChannels, (long) capsuleDimensions * channels);
|
WEIGHT_PARAM,
|
||||||
|
kernelSize[0],
|
||||||
|
kernelSize[1],
|
||||||
|
inputChannels,
|
||||||
|
(long) capsuleDimensions * channels);
|
||||||
|
|
||||||
if(hasBias){
|
if (hasBias) {
|
||||||
params.addBiasParam(BIAS_PARAM, (long) capsuleDimensions * channels);
|
params.addBiasParam(BIAS_PARAM, (long) capsuleDimensions * channels);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -142,11 +252,16 @@ public class PrimaryCapsules extends SameDiffLayer {
|
||||||
for (Map.Entry<String, INDArray> e : params.entrySet()) {
|
for (Map.Entry<String, INDArray> e : params.entrySet()) {
|
||||||
if (BIAS_PARAM.equals(e.getKey())) {
|
if (BIAS_PARAM.equals(e.getKey())) {
|
||||||
e.getValue().assign(0);
|
e.getValue().assign(0);
|
||||||
} else if(WEIGHT_PARAM.equals(e.getKey())){
|
} else if (WEIGHT_PARAM.equals(e.getKey())) {
|
||||||
double fanIn = inputChannels * kernelSize[0] * kernelSize[1];
|
double fanIn = inputChannels * kernelSize[0] * kernelSize[1];
|
||||||
double fanOut = capsuleDimensions * channels * kernelSize[0] * kernelSize[1] / ((double) stride[0] * stride[1]);
|
double fanOut =
|
||||||
WeightInitUtil.initWeights(fanIn, fanOut, e.getValue().shape(), weightInit, null, 'c',
|
capsuleDimensions
|
||||||
e.getValue());
|
* channels
|
||||||
|
* kernelSize[0]
|
||||||
|
* kernelSize[1]
|
||||||
|
/ ((double) stride[0] * stride[1]);
|
||||||
|
WeightInitUtil.initWeights(
|
||||||
|
fanIn, fanOut, e.getValue().shape(), weightInit, null, 'c', e.getValue());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -155,19 +270,33 @@ public class PrimaryCapsules extends SameDiffLayer {
|
||||||
@Override
|
@Override
|
||||||
public InputType getOutputType(int layerIndex, InputType inputType) {
|
public InputType getOutputType(int layerIndex, InputType inputType) {
|
||||||
if (inputType == null || inputType.getType() != Type.CNN) {
|
if (inputType == null || inputType.getType() != Type.CNN) {
|
||||||
throw new IllegalStateException("Invalid input for Primary Capsules layer (layer name = \""
|
throw new IllegalStateException(
|
||||||
+ name + "\"): expect CNN input. Got: " + inputType);
|
"Invalid input for Primary Capsules layer (layer name = \""
|
||||||
|
+ name
|
||||||
|
+ "\"): expect CNN input. Got: "
|
||||||
|
+ inputType);
|
||||||
}
|
}
|
||||||
|
|
||||||
if(capsules > 0){
|
if (capsules > 0) {
|
||||||
return InputType.recurrent(capsules, capsuleDimensions);
|
return InputType.recurrent(capsules, capsuleDimensions);
|
||||||
} else {
|
} else {
|
||||||
|
|
||||||
InputTypeConvolutional out = (InputTypeConvolutional) InputTypeUtil
|
InputTypeConvolutional out =
|
||||||
.getOutputTypeCnnLayers(inputType, kernelSize, stride, padding, dilation, convolutionMode,
|
(InputTypeConvolutional)
|
||||||
(long) capsuleDimensions * channels, -1, getName(), PrimaryCapsules.class);
|
InputTypeUtil.getOutputTypeCnnLayers(
|
||||||
|
inputType,
|
||||||
|
kernelSize,
|
||||||
|
stride,
|
||||||
|
padding,
|
||||||
|
dilation,
|
||||||
|
convolutionMode,
|
||||||
|
(long) capsuleDimensions * channels,
|
||||||
|
-1,
|
||||||
|
getName(),
|
||||||
|
PrimaryCapsules.class);
|
||||||
|
|
||||||
return InputType.recurrent((int) (out.getChannels() * out.getHeight() * out.getWidth() / capsuleDimensions),
|
return InputType.recurrent(
|
||||||
|
(int) (out.getChannels() * out.getHeight() * out.getWidth() / capsuleDimensions),
|
||||||
capsuleDimensions);
|
capsuleDimensions);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -175,250 +304,122 @@ public class PrimaryCapsules extends SameDiffLayer {
|
||||||
@Override
|
@Override
|
||||||
public void setNIn(InputType inputType, boolean override) {
|
public void setNIn(InputType inputType, boolean override) {
|
||||||
if (inputType == null || inputType.getType() != Type.CNN) {
|
if (inputType == null || inputType.getType() != Type.CNN) {
|
||||||
throw new IllegalStateException("Invalid input for Primary Capsules layer (layer name = \""
|
throw new IllegalStateException(
|
||||||
+ name + "\"): expect CNN input. Got: " + inputType);
|
"Invalid input for Primary Capsules layer (layer name = \""
|
||||||
|
+ name
|
||||||
|
+ "\"): expect CNN input. Got: "
|
||||||
|
+ inputType);
|
||||||
}
|
}
|
||||||
|
|
||||||
InputTypeConvolutional ci = (InputTypeConvolutional) inputType;
|
InputTypeConvolutional ci = (InputTypeConvolutional) inputType;
|
||||||
|
|
||||||
this.inputChannels = (int) ci.getChannels();
|
this.inputChannels = (int) ci.getChannels();
|
||||||
|
|
||||||
if(capsules <= 0 || override) {
|
if (capsules <= 0 || override) {
|
||||||
|
|
||||||
InputTypeConvolutional out = (InputTypeConvolutional) InputTypeUtil
|
InputTypeConvolutional out =
|
||||||
.getOutputTypeCnnLayers(inputType, kernelSize, stride, padding, dilation, convolutionMode,
|
(InputTypeConvolutional)
|
||||||
(long) capsuleDimensions * channels, -1, getName(), PrimaryCapsules.class);
|
InputTypeUtil.getOutputTypeCnnLayers(
|
||||||
|
inputType,
|
||||||
|
kernelSize,
|
||||||
|
stride,
|
||||||
|
padding,
|
||||||
|
dilation,
|
||||||
|
convolutionMode,
|
||||||
|
(long) capsuleDimensions * channels,
|
||||||
|
-1,
|
||||||
|
getName(),
|
||||||
|
PrimaryCapsules.class);
|
||||||
|
|
||||||
this.capsules = (int) (out.getChannels() * out.getHeight() * out.getWidth() / capsuleDimensions);
|
this.capsules =
|
||||||
|
(int) (out.getChannels() * out.getHeight() * out.getWidth() / capsuleDimensions);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Getter
|
public abstract static class PrimaryCapsulesBuilder<
|
||||||
@Setter
|
C extends PrimaryCapsules, B extends PrimaryCapsulesBuilder<C, B>>
|
||||||
public static class Builder extends SameDiffLayer.Builder<Builder>{
|
extends SameDiffLayerBuilder<C, B> {
|
||||||
|
|
||||||
@Setter(AccessLevel.NONE)
|
public B kernelSize(int... kernelSize) {
|
||||||
private int[] kernelSize = new int[]{9, 9};
|
this.kernelSize$value = ValidationUtils.validate2NonNegative(kernelSize, true, "kernelSize");
|
||||||
|
this.kernelSize$set = true;
|
||||||
@Setter(AccessLevel.NONE)
|
return self();
|
||||||
private int[] stride = new int[]{2, 2};
|
|
||||||
|
|
||||||
@Setter(AccessLevel.NONE)
|
|
||||||
private int[] padding = new int[]{0, 0};
|
|
||||||
|
|
||||||
@Setter(AccessLevel.NONE)
|
|
||||||
private int[] dilation = new int[]{1, 1};
|
|
||||||
|
|
||||||
private int channels = 32;
|
|
||||||
|
|
||||||
private boolean hasBias = true;
|
|
||||||
|
|
||||||
private int capsules;
|
|
||||||
private int capsuleDimensions;
|
|
||||||
|
|
||||||
private ConvolutionMode convolutionMode = ConvolutionMode.Truncate;
|
|
||||||
|
|
||||||
private boolean useRelu = false;
|
|
||||||
private double leak = 0;
|
|
||||||
|
|
||||||
|
|
||||||
public void setKernelSize(int... kernelSize){
|
|
||||||
this.kernelSize = ValidationUtils.validate2NonNegative(kernelSize, true, "kernelSize");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setStride(int... stride){
|
public B stride(int... stride) {
|
||||||
this.stride = ValidationUtils.validate2NonNegative(stride, true, "stride");
|
this.stride$value = ValidationUtils.validate2NonNegative(stride, true, "stride");
|
||||||
|
this.stride$set = true;
|
||||||
|
return self();
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setPadding(int... padding){
|
public B padding(int... padding) {
|
||||||
this.padding = ValidationUtils.validate2NonNegative(padding, true, "padding");
|
this.padding$value = ValidationUtils.validate2NonNegative(padding, true, "padding");
|
||||||
|
this.padding$set = true;
|
||||||
|
return self();
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setDilation(int... dilation){
|
public B dilation(int... dilation) {
|
||||||
this.dilation = ValidationUtils.validate2NonNegative(dilation, true, "dilation");
|
this.dilation$value = ValidationUtils.validate2NonNegative(dilation, true, "dilation");
|
||||||
|
this.dilation$set = true;
|
||||||
|
return self();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public Builder(int capsuleDimensions, int channels,
|
|
||||||
int[] kernelSize, int[] stride, int[] padding, int[] dilation,
|
|
||||||
ConvolutionMode convolutionMode){
|
|
||||||
this.capsuleDimensions = capsuleDimensions;
|
|
||||||
this.channels = channels;
|
|
||||||
this.setKernelSize(kernelSize);
|
|
||||||
this.setStride(stride);
|
|
||||||
this.setPadding(padding);
|
|
||||||
this.setDilation(dilation);
|
|
||||||
this.convolutionMode = convolutionMode;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Builder(int capsuleDimensions, int channels,
|
|
||||||
int[] kernelSize, int[] stride, int[] padding, int[] dilation){
|
|
||||||
this(capsuleDimensions, channels, kernelSize, stride, padding, dilation, ConvolutionMode.Truncate);
|
|
||||||
}
|
|
||||||
|
|
||||||
public Builder(int capsuleDimensions, int channels,
|
|
||||||
int[] kernelSize, int[] stride, int[] padding){
|
|
||||||
this(capsuleDimensions, channels, kernelSize, stride, padding, new int[]{1, 1}, ConvolutionMode.Truncate);
|
|
||||||
}
|
|
||||||
|
|
||||||
public Builder(int capsuleDimensions, int channels,
|
|
||||||
int[] kernelSize, int[] stride){
|
|
||||||
this(capsuleDimensions, channels, kernelSize, stride, new int[]{0, 0}, new int[]{1, 1}, ConvolutionMode.Truncate);
|
|
||||||
}
|
|
||||||
|
|
||||||
public Builder(int capsuleDimensions, int channels,
|
|
||||||
int[] kernelSize){
|
|
||||||
this(capsuleDimensions, channels, kernelSize, new int[]{2, 2}, new int[]{0, 0}, new int[]{1, 1}, ConvolutionMode.Truncate);
|
|
||||||
}
|
|
||||||
|
|
||||||
public Builder(int capsuleDimensions, int channels){
|
|
||||||
this(capsuleDimensions, channels, new int[]{9, 9}, new int[]{2, 2}, new int[]{0, 0}, new int[]{1, 1}, ConvolutionMode.Truncate);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Sets the kernel size of the 2d convolution
|
|
||||||
*
|
|
||||||
* @see ConvolutionLayer.Builder#kernelSize(int...)
|
|
||||||
* @param kernelSize
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public Builder kernelSize(int... kernelSize){
|
|
||||||
this.setKernelSize(kernelSize);
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Sets the stride of the 2d convolution
|
|
||||||
*
|
|
||||||
* @see ConvolutionLayer.Builder#stride(int...)
|
|
||||||
* @param stride
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public Builder stride(int... stride){
|
|
||||||
this.setStride(stride);
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Sets the padding of the 2d convolution
|
|
||||||
*
|
|
||||||
* @see ConvolutionLayer.Builder#padding(int...)
|
|
||||||
* @param padding
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public Builder padding(int... padding){
|
|
||||||
this.setPadding(padding);
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Sets the dilation of the 2d convolution
|
|
||||||
*
|
|
||||||
* @see ConvolutionLayer.Builder#dilation(int...)
|
|
||||||
* @param dilation
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public Builder dilation(int... dilation){
|
|
||||||
this.setDilation(dilation);
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sets the number of channels to use in the 2d convolution.
|
* Sets the number of channels to use in the 2d convolution.
|
||||||
*
|
*
|
||||||
* Note that the actual number of channels is channels * capsuleDimensions
|
* <p>Note that the actual number of channels is channels * capsuleDimensions
|
||||||
*
|
*
|
||||||
* Does the same thing as nOut()
|
* <p>Does the same thing as channels()
|
||||||
*
|
|
||||||
* @param channels
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public Builder channels(int channels){
|
|
||||||
this.channels = channels;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Sets the number of channels to use in the 2d convolution.
|
|
||||||
*
|
|
||||||
* Note that the actual number of channels is channels * capsuleDimensions
|
|
||||||
*
|
|
||||||
* Does the same thing as channels()
|
|
||||||
*
|
*
|
||||||
* @param nOut
|
* @param nOut
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public Builder nOut(int nOut){
|
public B nOut(int nOut) {
|
||||||
return channels(nOut);
|
return channels(nOut);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Sets the number of dimensions to use in the capsules.
|
|
||||||
* @param capsuleDimensions
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public Builder capsuleDimensions(int capsuleDimensions){
|
|
||||||
this.capsuleDimensions = capsuleDimensions;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Usually inferred automatically.
|
|
||||||
* @param capsules
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public Builder capsules(int capsules){
|
|
||||||
this.capsules = capsules;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Builder hasBias(boolean hasBias){
|
|
||||||
this.hasBias = hasBias;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The convolution mode to use in the 2d convolution
|
|
||||||
* @param convolutionMode
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public Builder convolutionMode(ConvolutionMode convolutionMode){
|
|
||||||
this.convolutionMode = convolutionMode;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Whether to use a ReLU activation on the 2d convolution
|
|
||||||
* @param useRelu
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public Builder useReLU(boolean useRelu){
|
|
||||||
this.useRelu = useRelu;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Use a ReLU activation on the 2d convolution
|
* Use a ReLU activation on the 2d convolution
|
||||||
|
*
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public Builder useReLU(){
|
public B useReLU() {
|
||||||
return useReLU(true);
|
return useRelU(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Use a LeakyReLU activation on the 2d convolution
|
* Use a LeakyReLU activation on the 2d convolution. Implies {@link #useReLU()} set true.
|
||||||
|
*
|
||||||
* @param leak the alpha value for the LeakyReLU activation.
|
* @param leak the alpha value for the LeakyReLU activation.
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public Builder useLeakyReLU(double leak){
|
public B useLeakyReLU(double leak) {
|
||||||
this.useRelu = true;
|
this.useRelU(true);
|
||||||
this.leak = leak;
|
this.useLeakyReLU$value = leak;
|
||||||
return this;
|
this.useLeakyReLU$set = true;
|
||||||
|
return self();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
public C build() {
|
||||||
public <E extends LayerConfiguration> E build() {
|
C l = initBuild();
|
||||||
return (E) new PrimaryCapsules(this);
|
if (capsuleDimensions <= 0 || channels$value <= 0) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"Invalid configuration for Primary Capsules (layer name = \""
|
||||||
|
+ l.getName()
|
||||||
|
+ "\"):"
|
||||||
|
+ " capsuleDimensions and channels must be > 0. Got: "
|
||||||
|
+ capsuleDimensions
|
||||||
|
+ ", "
|
||||||
|
+ channels$value);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (capsules < 0) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"Invalid configuration for Capsule ILayer (layer name = \""
|
||||||
|
+ l.getName()
|
||||||
|
+ "\"):"
|
||||||
|
+ " capsules must be >= 0 if set. Got: "
|
||||||
|
+ capsules);
|
||||||
|
}
|
||||||
|
return l;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.deeplearning4j.nn.conf.layers;
|
package org.deeplearning4j.nn.conf.layers;
|
||||||
|
|
||||||
import lombok.*;
|
import lombok.*;
|
||||||
|
import lombok.experimental.SuperBuilder;
|
||||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.RNNFormat;
|
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||||
|
@ -41,15 +42,63 @@ import org.nd4j.linalg.factory.Nd4j;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
@NoArgsConstructor
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
|
@SuperBuilder(buildMethodName = "initBuild")
|
||||||
public class RecurrentAttentionLayer extends SameDiffLayer {
|
public class RecurrentAttentionLayer extends SameDiffLayer {
|
||||||
private long nIn;
|
|
||||||
private long nOut;
|
public static abstract class RecurrentAttentionLayerBuilder<C extends RecurrentAttentionLayer, B extends RecurrentAttentionLayerBuilder<C,B>>
|
||||||
|
extends SameDiffLayerBuilder<C,B> {
|
||||||
|
|
||||||
|
public C build() {
|
||||||
|
Preconditions.checkArgument(this.projectInput$value || this.nHeads == 1, "projectInput must be true when nHeads != 1");
|
||||||
|
Preconditions.checkArgument(this.projectInput$value || nIn == nOut, "nIn must be equal to nOut when projectInput is false");
|
||||||
|
Preconditions.checkArgument(!this.projectInput$value || nOut != 0, "nOut must be specified when projectInput is true");
|
||||||
|
Preconditions.checkArgument(this.nOut % nHeads == 0 || headSize > 0, "nOut isn't divided by nHeads cleanly. Specify the headSize manually.");
|
||||||
|
|
||||||
|
C l = initBuild();
|
||||||
|
return l;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Number of inputs to the layer (input size)
|
||||||
|
*/
|
||||||
|
private int nIn;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Number of outputs (output size)
|
||||||
|
*/
|
||||||
|
private int nOut;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Number of Attention Heads
|
||||||
|
*/
|
||||||
private int nHeads;
|
private int nHeads;
|
||||||
private long headSize;
|
|
||||||
private boolean projectInput;
|
/**
|
||||||
private Activation activation;
|
* Size of attention heads
|
||||||
private boolean hasBias;
|
*/
|
||||||
|
private int headSize;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Project input before applying attention or not.
|
||||||
|
*/
|
||||||
|
@Builder.Default
|
||||||
|
private boolean projectInput = true;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* If true (default is true) the layer will have a bias
|
||||||
|
*/
|
||||||
|
@Builder.Default
|
||||||
|
private boolean hasBias = true;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Activation function for the layer
|
||||||
|
*/
|
||||||
|
@Builder.Default
|
||||||
|
private Activation activation = Activation.TANH;
|
||||||
|
|
||||||
|
|
||||||
private static final String WEIGHT_KEY_QUERY_PROJECTION = "Wq";
|
private static final String WEIGHT_KEY_QUERY_PROJECTION = "Wq";
|
||||||
private static final String WEIGHT_KEY_KEY_PROJECTION = "Wk";
|
private static final String WEIGHT_KEY_KEY_PROJECTION = "Wk";
|
||||||
|
@ -60,18 +109,7 @@ public class RecurrentAttentionLayer extends SameDiffLayer {
|
||||||
private static final String RECURRENT_WEIGHT_KEY = SimpleRnnParamInitializer.RECURRENT_WEIGHT_KEY;
|
private static final String RECURRENT_WEIGHT_KEY = SimpleRnnParamInitializer.RECURRENT_WEIGHT_KEY;
|
||||||
private int timeSteps;
|
private int timeSteps;
|
||||||
|
|
||||||
private RecurrentAttentionLayer(){/*No arg constructor for serialization*/}
|
|
||||||
|
|
||||||
protected RecurrentAttentionLayer(Builder builder){
|
|
||||||
super(builder);
|
|
||||||
nIn = builder.nIn;
|
|
||||||
nOut = builder.nOut;
|
|
||||||
nHeads = builder.nHeads;
|
|
||||||
headSize = builder.headSize == 0 ? nOut / nHeads : builder.headSize;
|
|
||||||
projectInput = builder.projectInput;
|
|
||||||
activation = builder.activation;
|
|
||||||
hasBias = builder.hasBias;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
|
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
|
||||||
|
@ -87,7 +125,7 @@ public class RecurrentAttentionLayer extends SameDiffLayer {
|
||||||
|
|
||||||
if (nIn <= 0 || override) {
|
if (nIn <= 0 || override) {
|
||||||
InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
|
InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
|
||||||
this.nIn = r.getSize();
|
this.nIn = (int) r.getSize();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -206,109 +244,5 @@ public class RecurrentAttentionLayer extends SameDiffLayer {
|
||||||
return sameDiff.concat(2, outputSlices);
|
return sameDiff.concat(2, outputSlices);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Getter
|
|
||||||
@Setter
|
|
||||||
public static class Builder extends SameDiffLayer.Builder<RecurrentAttentionLayer.Builder> {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Number of inputs to the layer (input size)
|
|
||||||
*/
|
|
||||||
private int nIn;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Number of outputs (output size)
|
|
||||||
*/
|
|
||||||
private int nOut;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Number of Attention Heads
|
|
||||||
*/
|
|
||||||
private int nHeads;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Size of attention heads
|
|
||||||
*/
|
|
||||||
private int headSize;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Project input before applying attention or not.
|
|
||||||
*/
|
|
||||||
private boolean projectInput = true;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* If true (default is true) the layer will have a bias
|
|
||||||
*/
|
|
||||||
private boolean hasBias = true;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Activation function for the layer
|
|
||||||
*/
|
|
||||||
private Activation activation = Activation.TANH;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param nIn Number of inputs to the layer (input size)
|
|
||||||
*/
|
|
||||||
public Builder nIn(int nIn) {
|
|
||||||
this.nIn = nIn;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param nOut Number of outputs (output size)
|
|
||||||
*/
|
|
||||||
public Builder nOut(int nOut) {
|
|
||||||
this.nOut = nOut;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Number of Attention Heads
|
|
||||||
*/
|
|
||||||
public Builder nHeads(int nHeads){
|
|
||||||
this.nHeads = nHeads;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Size of attention heads
|
|
||||||
*/
|
|
||||||
public Builder headSize(int headSize){
|
|
||||||
this.headSize = headSize;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Project input before applying attention or not.
|
|
||||||
*/
|
|
||||||
public Builder projectInput(boolean projectInput){
|
|
||||||
this.projectInput = projectInput;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param hasBias If true (default is true) the layer will have a bias
|
|
||||||
*/
|
|
||||||
public Builder hasBias(boolean hasBias) {
|
|
||||||
this.hasBias = hasBias;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param activation Activation function for the layer
|
|
||||||
*/
|
|
||||||
public Builder activation(Activation activation) {
|
|
||||||
this.activation = activation;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
@SuppressWarnings("unchecked")
|
|
||||||
public RecurrentAttentionLayer build() {
|
|
||||||
Preconditions.checkArgument(this.projectInput || this.nHeads == 1, "projectInput must be true when nHeads != 1");
|
|
||||||
Preconditions.checkArgument(this.projectInput || nIn == nOut, "nIn must be equal to nOut when projectInput is false");
|
|
||||||
Preconditions.checkArgument(!this.projectInput || nOut != 0, "nOut must be specified when projectInput is true");
|
|
||||||
Preconditions.checkArgument(this.nOut % nHeads == 0 || headSize > 0, "nOut isn't divided by nHeads cleanly. Specify the headSize manually.");
|
|
||||||
return new RecurrentAttentionLayer(this);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,7 +20,9 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.layers;
|
package org.deeplearning4j.nn.conf.layers;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
import lombok.*;
|
import lombok.*;
|
||||||
|
import lombok.experimental.SuperBuilder;
|
||||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||||
import org.deeplearning4j.nn.conf.RNNFormat;
|
import org.deeplearning4j.nn.conf.RNNFormat;
|
||||||
import org.deeplearning4j.nn.conf.inputs.InputType;
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||||||
|
@ -34,32 +36,25 @@ import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
|
@NoArgsConstructor()
|
||||||
|
@SuperBuilder(buildMethodName = "initBuild")
|
||||||
public class SelfAttentionLayer extends SameDiffLayer {
|
public class SelfAttentionLayer extends SameDiffLayer {
|
||||||
private long nIn;
|
|
||||||
private long nOut;
|
|
||||||
private int nHeads;
|
|
||||||
private long headSize;
|
|
||||||
private boolean projectInput;
|
|
||||||
|
|
||||||
private static final String WEIGHT_KEY_QUERY_PROJECTION = "Wq";
|
private static final String WEIGHT_KEY_QUERY_PROJECTION = "Wq";
|
||||||
private static final String WEIGHT_KEY_KEY_PROJECTION = "Wk";
|
private static final String WEIGHT_KEY_KEY_PROJECTION = "Wk";
|
||||||
private static final String WEIGHT_KEY_VALUE_PROJECTION = "Wv";
|
private static final String WEIGHT_KEY_VALUE_PROJECTION = "Wv";
|
||||||
private static final String WEIGHT_KEY_OUT_PROJECTION = "Wo";
|
private static final String WEIGHT_KEY_OUT_PROJECTION = "Wo";
|
||||||
|
/** Number of inputs to the layer (input size) */
|
||||||
private SelfAttentionLayer(){/*No arg constructor for serialization*/}
|
private int nIn;
|
||||||
|
/** Number of outputs (output size) */
|
||||||
protected SelfAttentionLayer(Builder builder){
|
private int nOut;
|
||||||
super(builder);
|
/** Number of Attention Heads */
|
||||||
nIn = builder.nIn;
|
private int nHeads;
|
||||||
nOut = builder.nOut;
|
/** Size of attention heads */
|
||||||
nHeads = builder.nHeads;
|
private int headSize;
|
||||||
headSize = builder.headSize == 0 ? nOut / nHeads : builder.headSize;
|
/** Project input before applying attention or not. */
|
||||||
projectInput = builder.projectInput;
|
private boolean projectInput;
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
|
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
|
||||||
|
@ -69,29 +64,36 @@ public class SelfAttentionLayer extends SameDiffLayer {
|
||||||
@Override
|
@Override
|
||||||
public void setNIn(InputType inputType, boolean override) {
|
public void setNIn(InputType inputType, boolean override) {
|
||||||
if (inputType == null || inputType.getType() != InputType.Type.RNN) {
|
if (inputType == null || inputType.getType() != InputType.Type.RNN) {
|
||||||
throw new IllegalStateException("Invalid input for Self Attention layer (layer name = \"" + getName()
|
throw new IllegalStateException(
|
||||||
+ "\"): expect RNN input type with size > 0. Got: " + inputType);
|
"Invalid input for Self Attention layer (layer name = \""
|
||||||
|
+ getName()
|
||||||
|
+ "\"): expect RNN input type with size > 0. Got: "
|
||||||
|
+ inputType);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (nIn <= 0 || override) {
|
if (nIn <= 0 || override) {
|
||||||
InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
|
InputType.InputTypeRecurrent r = (InputType.InputTypeRecurrent) inputType;
|
||||||
this.nIn = r.getSize();
|
this.nIn = (int) r.getSize();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public InputType getOutputType(int layerIndex, InputType inputType) {
|
public InputType getOutputType(int layerIndex, InputType inputType) {
|
||||||
if (inputType == null || inputType.getType() != InputType.Type.RNN) {
|
if (inputType == null || inputType.getType() != InputType.Type.RNN) {
|
||||||
throw new IllegalStateException("Invalid input for Self Attention layer (layer index = " + layerIndex
|
throw new IllegalStateException(
|
||||||
+ ", layer name = \"" + getName() + "\"): expect RNN input type with size > 0. Got: "
|
"Invalid input for Self Attention layer (layer index = "
|
||||||
|
+ layerIndex
|
||||||
|
+ ", layer name = \""
|
||||||
|
+ getName()
|
||||||
|
+ "\"): expect RNN input type with size > 0. Got: "
|
||||||
+ inputType);
|
+ inputType);
|
||||||
}
|
}
|
||||||
|
|
||||||
InputType.InputTypeRecurrent itr = (InputType.InputTypeRecurrent) inputType;
|
InputType.InputTypeRecurrent itr = (InputType.InputTypeRecurrent) inputType;
|
||||||
|
|
||||||
if(projectInput){
|
if (projectInput) {
|
||||||
return InputType.recurrent(nOut, itr.getTimeSeriesLength());
|
return InputType.recurrent(nOut, itr.getTimeSeriesLength());
|
||||||
}else{
|
} else {
|
||||||
return InputType.recurrent(nIn, itr.getTimeSeriesLength());
|
return InputType.recurrent(nIn, itr.getTimeSeriesLength());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -100,7 +102,7 @@ public class SelfAttentionLayer extends SameDiffLayer {
|
||||||
public void defineParameters(SDLayerParams params) {
|
public void defineParameters(SDLayerParams params) {
|
||||||
params.clear();
|
params.clear();
|
||||||
|
|
||||||
if(projectInput){
|
if (projectInput) {
|
||||||
params.addWeightParam(WEIGHT_KEY_QUERY_PROJECTION, nHeads, headSize, nIn);
|
params.addWeightParam(WEIGHT_KEY_QUERY_PROJECTION, nHeads, headSize, nIn);
|
||||||
params.addWeightParam(WEIGHT_KEY_KEY_PROJECTION, nHeads, headSize, nIn);
|
params.addWeightParam(WEIGHT_KEY_KEY_PROJECTION, nHeads, headSize, nIn);
|
||||||
params.addWeightParam(WEIGHT_KEY_VALUE_PROJECTION, nHeads, headSize, nIn);
|
params.addWeightParam(WEIGHT_KEY_VALUE_PROJECTION, nHeads, headSize, nIn);
|
||||||
|
@ -112,108 +114,52 @@ public class SelfAttentionLayer extends SameDiffLayer {
|
||||||
public void initializeParameters(Map<String, INDArray> params) {
|
public void initializeParameters(Map<String, INDArray> params) {
|
||||||
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
||||||
for (Map.Entry<String, INDArray> e : params.entrySet()) {
|
for (Map.Entry<String, INDArray> e : params.entrySet()) {
|
||||||
if(e.getKey().equals(WEIGHT_KEY_OUT_PROJECTION)){
|
if (e.getKey().equals(WEIGHT_KEY_OUT_PROJECTION)) {
|
||||||
WeightInitUtil.initWeights(nIn, headSize, e.getValue().shape(), weightInit, null, 'c', e.getValue());
|
WeightInitUtil.initWeights(
|
||||||
}else{
|
nIn, headSize, e.getValue().shape(), weightInit, null, 'c', e.getValue());
|
||||||
WeightInitUtil.initWeights(nHeads * headSize, nOut, e.getValue().shape(), weightInit, null, 'c', e.getValue());
|
} else {
|
||||||
|
WeightInitUtil.initWeights(
|
||||||
|
nHeads * headSize, nOut, e.getValue().shape(), weightInit, null, 'c', e.getValue());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, Map<String, SDVariable> paramTable, SDVariable mask) {
|
public SDVariable defineLayer(
|
||||||
if(projectInput){
|
SameDiff sameDiff,
|
||||||
|
SDVariable layerInput,
|
||||||
|
Map<String, SDVariable> paramTable,
|
||||||
|
SDVariable mask) {
|
||||||
|
if (projectInput) {
|
||||||
val Wq = paramTable.get(WEIGHT_KEY_QUERY_PROJECTION);
|
val Wq = paramTable.get(WEIGHT_KEY_QUERY_PROJECTION);
|
||||||
val Wk = paramTable.get(WEIGHT_KEY_KEY_PROJECTION);
|
val Wk = paramTable.get(WEIGHT_KEY_KEY_PROJECTION);
|
||||||
val Wv = paramTable.get(WEIGHT_KEY_VALUE_PROJECTION);
|
val Wv = paramTable.get(WEIGHT_KEY_VALUE_PROJECTION);
|
||||||
val Wo = paramTable.get(WEIGHT_KEY_OUT_PROJECTION);
|
val Wo = paramTable.get(WEIGHT_KEY_OUT_PROJECTION);
|
||||||
|
|
||||||
return sameDiff.nn.multiHeadDotProductAttention(getName(), layerInput, layerInput, layerInput, Wq, Wk, Wv, Wo, mask, true);
|
return sameDiff.nn.multiHeadDotProductAttention(
|
||||||
}else{
|
getName(), layerInput, layerInput, layerInput, Wq, Wk, Wv, Wo, mask, true);
|
||||||
return sameDiff.nn.dotProductAttention(getName(), layerInput, layerInput, layerInput, mask, true);
|
} else {
|
||||||
|
return sameDiff.nn.dotProductAttention(
|
||||||
|
getName(), layerInput, layerInput, layerInput, mask, true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public abstract static class SelfAttentionLayerBuilder<
|
||||||
|
C extends SelfAttentionLayer, B extends SelfAttentionLayerBuilder<C, B>>
|
||||||
|
extends SameDiffLayerBuilder<C, B> {
|
||||||
|
public C build() {
|
||||||
|
Preconditions.checkArgument(
|
||||||
|
this.projectInput || this.nHeads == 1, "projectInput must be true when nHeads != 1");
|
||||||
|
Preconditions.checkArgument(
|
||||||
|
this.projectInput || nIn == nOut, "nIn must be equal to nOut when projectInput is false");
|
||||||
|
Preconditions.checkArgument(
|
||||||
|
!this.projectInput || nOut != 0, "nOut must be specified when projectInput is true");
|
||||||
|
Preconditions.checkArgument(
|
||||||
|
this.nOut % nHeads == 0 || headSize > 0,
|
||||||
|
"nOut isn't divided by nHeads cleanly. Specify the headSize manually.");
|
||||||
|
|
||||||
@Getter
|
return initBuild();
|
||||||
@Setter
|
|
||||||
public static class Builder extends SameDiffLayer.Builder<SelfAttentionLayer.Builder> {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Number of inputs to the layer (input size)
|
|
||||||
*/
|
|
||||||
private int nIn;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Number of outputs (output size)
|
|
||||||
*/
|
|
||||||
private int nOut;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Number of Attention Heads
|
|
||||||
*/
|
|
||||||
private int nHeads;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Size of attention heads
|
|
||||||
*/
|
|
||||||
private int headSize;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Project input before applying attention or not.
|
|
||||||
*/
|
|
||||||
private boolean projectInput;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param nIn Number of inputs to the layer (input size)
|
|
||||||
*/
|
|
||||||
public Builder nIn(int nIn) {
|
|
||||||
this.nIn = nIn;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param nOut Number of outputs (output size)
|
|
||||||
*/
|
|
||||||
public Builder nOut(int nOut) {
|
|
||||||
this.nOut = nOut;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Number of Attention Heads
|
|
||||||
*/
|
|
||||||
public Builder nHeads(int nHeads){
|
|
||||||
this.nHeads = nHeads;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Size of attention heads
|
|
||||||
*/
|
|
||||||
public Builder headSize(int headSize){
|
|
||||||
this.headSize = headSize;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Project input before applying attention or not.
|
|
||||||
*/
|
|
||||||
public Builder projectInput(boolean projectInput){
|
|
||||||
this.projectInput = projectInput;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
@SuppressWarnings("unchecked")
|
|
||||||
public SelfAttentionLayer build() {
|
|
||||||
Preconditions.checkArgument(this.projectInput || this.nHeads == 1, "projectInput must be true when nHeads != 1");
|
|
||||||
Preconditions.checkArgument(this.projectInput || nIn == nOut, "nIn must be equal to nOut when projectInput is false");
|
|
||||||
Preconditions.checkArgument(!this.projectInput || nOut != 0, "nOut must be specified when projectInput is true");
|
|
||||||
Preconditions.checkArgument(this.nOut % nHeads == 0 || headSize > 0, "nOut isn't divided by nHeads cleanly. Specify the headSize manually.");
|
|
||||||
return new SelfAttentionLayer(this);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -63,7 +63,16 @@ public class SeparableConvolution2D extends ConvolutionLayer {
|
||||||
* @return Builder
|
* @return Builder
|
||||||
*/
|
*/
|
||||||
@Builder.Default private int depthMultiplier = 1;
|
@Builder.Default private int depthMultiplier = 1;
|
||||||
|
/**
|
||||||
|
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
|
||||||
|
* See {@link CNN2DFormat} for more details.<br>
|
||||||
|
* Default: NCHW
|
||||||
|
*
|
||||||
|
* @param format Format for activations (in and out)
|
||||||
|
*/
|
||||||
|
@Builder.Default
|
||||||
|
protected CNN2DFormat dataFormat =
|
||||||
|
CNN2DFormat.NCHW; // default value for legacy serialization reasons
|
||||||
public static SeparableConvolution2DBuilder<?, ?> builder() {
|
public static SeparableConvolution2DBuilder<?, ?> builder() {
|
||||||
return innerBuilder();
|
return innerBuilder();
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,7 +20,10 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.layers;
|
package org.deeplearning4j.nn.conf.layers;
|
||||||
|
|
||||||
|
import java.util.Collection;
|
||||||
|
import java.util.Map;
|
||||||
import lombok.*;
|
import lombok.*;
|
||||||
|
import lombok.experimental.SuperBuilder;
|
||||||
import org.deeplearning4j.nn.api.ParamInitializer;
|
import org.deeplearning4j.nn.api.ParamInitializer;
|
||||||
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||||
|
@ -35,27 +38,50 @@ import org.nd4j.common.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
|
||||||
import java.util.Collection;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
@ToString(callSuper = true)
|
@ToString(callSuper = true)
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
|
@SuperBuilder(builderMethodName = "innerBuilder")
|
||||||
public class SpaceToBatchLayer extends NoParamLayer {
|
public class SpaceToBatchLayer extends NoParamLayer {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Block size for SpaceToBatch layer. Should be a length 2 array for the height and width
|
||||||
|
* dimensions
|
||||||
|
*/
|
||||||
|
protected int[] blockSize;
|
||||||
|
/** A 2d array, with format [[padTop, padBottom], [padLeft, padRight]] */
|
||||||
|
@Builder.Default protected int[][] padding = new int[][] {{0, 0}, {0, 0}};
|
||||||
|
/**
|
||||||
|
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
|
||||||
|
* See {@link CNN2DFormat} for more details.<br>
|
||||||
|
* Default: NCHW
|
||||||
|
*
|
||||||
|
* @param format Format for activations (in and out)
|
||||||
|
*/
|
||||||
|
@Builder.Default protected CNN2DFormat dataFormat = CNN2DFormat.NCHW;
|
||||||
|
|
||||||
|
public static SpaceToBatchLayerBuilder<?, ?> builder() {
|
||||||
|
return innerBuilder();
|
||||||
|
}
|
||||||
// TODO: throw error when block and padding dims don't match
|
// TODO: throw error when block and padding dims don't match
|
||||||
|
|
||||||
protected int[] blocks;
|
/**
|
||||||
protected int[][] padding;
|
* @param blocks Block size for SpaceToBatch layer. Should be a length 2 array for the height and
|
||||||
protected CNN2DFormat format = CNN2DFormat.NCHW;
|
* width dimensions
|
||||||
|
*/
|
||||||
|
public static SpaceToBatchLayerBuilder<?, ?> builder(int[] blocks) {
|
||||||
|
return innerBuilder().blockSize(blocks);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
protected SpaceToBatchLayer(Builder builder) {
|
* @param blocks Block size for SpaceToBatch layer. Should be a length 2 array for the height and
|
||||||
super(builder);
|
* width dimensions
|
||||||
this.blocks = builder.blocks;
|
* @param padding Padding - should be a 2d array, with format [[padTop, padBottom], [padLeft,
|
||||||
this.padding = builder.padding;
|
* padRight]]
|
||||||
this.format = builder.format;
|
*/
|
||||||
|
public static SpaceToBatchLayerBuilder<?, ?> builder(int[] blocks, int[][] padding) {
|
||||||
|
return innerBuilder().blockSize(blocks).padding(padding);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -64,9 +90,13 @@ public class SpaceToBatchLayer extends NoParamLayer {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf,
|
public org.deeplearning4j.nn.api.Layer instantiate(
|
||||||
Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView,
|
NeuralNetConfiguration conf,
|
||||||
boolean initializeParams, DataType networkDataType) {
|
Collection<TrainingListener> trainingListeners,
|
||||||
|
int layerIndex,
|
||||||
|
INDArray layerParamsView,
|
||||||
|
boolean initializeParams,
|
||||||
|
DataType networkDataType) {
|
||||||
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
|
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
|
||||||
|
|
||||||
org.deeplearning4j.nn.layers.convolution.SpaceToBatch ret =
|
org.deeplearning4j.nn.layers.convolution.SpaceToBatch ret =
|
||||||
|
@ -83,23 +113,31 @@ public class SpaceToBatchLayer extends NoParamLayer {
|
||||||
@Override
|
@Override
|
||||||
public LayerMemoryReport getMemoryReport(InputType inputType) {
|
public LayerMemoryReport getMemoryReport(InputType inputType) {
|
||||||
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
|
InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
|
||||||
InputType.InputTypeConvolutional outputType = (InputType.InputTypeConvolutional) getOutputType(-1, inputType);
|
InputType.InputTypeConvolutional outputType =
|
||||||
|
(InputType.InputTypeConvolutional) getOutputType(-1, inputType);
|
||||||
|
|
||||||
return new LayerMemoryReport.Builder(name, SpaceToBatchLayer.class, inputType, outputType)
|
return new LayerMemoryReport.Builder(name, SpaceToBatchLayer.class, inputType, outputType)
|
||||||
.standardMemory(0, 0) //No params
|
.standardMemory(0, 0) // No params
|
||||||
.cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) //No caching
|
.cacheMemory(
|
||||||
|
MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) // No caching
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public InputType getOutputType(int layerIndex, InputType inputType) {
|
public InputType getOutputType(int layerIndex, InputType inputType) {
|
||||||
if (inputType == null || inputType.getType() != InputType.Type.CNN) {
|
if (inputType == null || inputType.getType() != InputType.Type.CNN) {
|
||||||
throw new IllegalStateException("Invalid input for Subsampling layer (layer name=\"" + getName()
|
throw new IllegalStateException(
|
||||||
+ "\"): Expected CNN input, got " + inputType);
|
"Invalid input for Subsampling layer (layer name=\""
|
||||||
|
+ getName()
|
||||||
|
+ "\"): Expected CNN input, got "
|
||||||
|
+ inputType);
|
||||||
}
|
}
|
||||||
InputType.InputTypeConvolutional i = (InputType.InputTypeConvolutional) inputType;
|
InputType.InputTypeConvolutional i = (InputType.InputTypeConvolutional) inputType;
|
||||||
return InputType.convolutional((i.getHeight() + padding[0][0] + padding[0][1]) / blocks[0],
|
return InputType.convolutional(
|
||||||
(i.getWidth() + padding[1][0] + padding[1][1]) / blocks[1], i.getChannels(), i.getFormat());
|
(i.getHeight() + padding[0][0] + padding[0][1]) / blockSize[0],
|
||||||
|
(i.getWidth() + padding[1][0] + padding[1][1]) / blockSize[1],
|
||||||
|
i.getChannels(),
|
||||||
|
i.getFormat());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -107,17 +145,21 @@ public class SpaceToBatchLayer extends NoParamLayer {
|
||||||
return EmptyParamInitializer.getInstance();
|
return EmptyParamInitializer.getInstance();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void setNIn(InputType inputType, boolean override) {
|
public void setNIn(InputType inputType, boolean override) {
|
||||||
Preconditions.checkState(inputType.getType() == InputType.Type.CNN, "Only CNN input types can be used with SpaceToBatchLayer, got %s", inputType);
|
Preconditions.checkState(
|
||||||
this.format = ((InputType.InputTypeConvolutional)inputType).getFormat();
|
inputType.getType() == InputType.Type.CNN,
|
||||||
|
"Only CNN input types can be used with SpaceToBatchLayer, got %s",
|
||||||
|
inputType);
|
||||||
|
this.dataFormat = ((InputType.InputTypeConvolutional) inputType).getFormat();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
|
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
|
||||||
if (inputType == null) {
|
if (inputType == null) {
|
||||||
throw new IllegalStateException("Invalid input for space to batch layer (layer name=\"" + getName()
|
throw new IllegalStateException(
|
||||||
|
"Invalid input for space to batch layer (layer name=\""
|
||||||
|
+ getName()
|
||||||
+ "\"): input is null");
|
+ "\"): input is null");
|
||||||
}
|
}
|
||||||
return InputTypeUtil.getPreProcessorForInputTypeCnnLayers(inputType, getName());
|
return InputTypeUtil.getPreProcessorForInputTypeCnnLayers(inputType, getName());
|
||||||
|
@ -128,102 +170,28 @@ public class SpaceToBatchLayer extends NoParamLayer {
|
||||||
throw new UnsupportedOperationException("SpaceToBatchLayer does not contain parameters");
|
throw new UnsupportedOperationException("SpaceToBatchLayer does not contain parameters");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public abstract static class SpaceToBatchLayerBuilder<
|
||||||
@NoArgsConstructor
|
C extends SpaceToBatchLayer, B extends SpaceToBatchLayerBuilder<C, B>>
|
||||||
@Getter
|
extends NoParamLayerBuilder<C, B> {
|
||||||
@Setter
|
|
||||||
public static class Builder<T extends Builder<T>> extends LayerConfiguration.Builder<T> {
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Block size for SpaceToBatch layer. Should be a length 2 array for the height and width
|
* @param blocks Block size for SpaceToBatch layer. Should be a length 2 array for the height
|
||||||
* dimensions
|
* and width dimensions
|
||||||
|
* @return
|
||||||
*/
|
*/
|
||||||
@Setter(AccessLevel.NONE)
|
public B blockSize(int... blocks) {
|
||||||
protected int[] blocks;
|
this.blockSize = ValidationUtils.validate2NonNegative(blocks, false, "blocks");
|
||||||
|
return self();
|
||||||
/**
|
|
||||||
* A 2d array, with format [[padTop, padBottom], [padLeft, padRight]]
|
|
||||||
*/
|
|
||||||
protected int[][] padding;
|
|
||||||
|
|
||||||
protected CNN2DFormat format = CNN2DFormat.NCHW;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param blocks Block size for SpaceToBatch layer. Should be a length 2 array for the height and width
|
|
||||||
* dimensions
|
|
||||||
*/
|
|
||||||
public void setBlocks(int... blocks) {
|
|
||||||
this.blocks = ValidationUtils.validate2NonNegative(blocks, false, "blocks");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param padding Padding - should be a 2d array, with format [[padTop, padBottom], [padLeft, padRight]]
|
* @param padding Padding - should be a 2d array, with format [[padTop, padBottom], [padLeft,
|
||||||
|
* padRight]]
|
||||||
|
* @return
|
||||||
*/
|
*/
|
||||||
public void setPadding(int[][] padding) {
|
public B padding(int[][] padding) {
|
||||||
this.padding = ValidationUtils.validate2x2NonNegative(padding, "padding");
|
this.padding$value = ValidationUtils.validate2x2NonNegative(padding, "padding");
|
||||||
}
|
this.padding$set = true;
|
||||||
|
return self();
|
||||||
|
|
||||||
/**
|
|
||||||
* @param blocks Block size for SpaceToBatch layer. Should be a length 2 array for the height and width
|
|
||||||
* dimensions
|
|
||||||
*/
|
|
||||||
public Builder(int[] blocks) {
|
|
||||||
this.setBlocks(blocks);
|
|
||||||
this.setPadding(new int[][] {{0, 0}, {0, 0}});
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param blocks Block size for SpaceToBatch layer. Should be a length 2 array for the height and width
|
|
||||||
* dimensions
|
|
||||||
* @param padding Padding - should be a 2d array, with format [[padTop, padBottom], [padLeft, padRight]]
|
|
||||||
*/
|
|
||||||
public Builder(int[] blocks, int[][] padding) {
|
|
||||||
this.setBlocks(blocks);
|
|
||||||
this.setPadding(padding);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
|
|
||||||
* See {@link CNN2DFormat} for more details.<br>
|
|
||||||
* Default: NCHW
|
|
||||||
* @param format Format for activations (in and out)
|
|
||||||
*/
|
|
||||||
public T dataFormat(CNN2DFormat format){
|
|
||||||
this.format = format;
|
|
||||||
return (T)this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param blocks Block size for SpaceToBatch layer. Should be a length 2 array for the height and width
|
|
||||||
* dimensions
|
|
||||||
*/
|
|
||||||
public T blocks(int... blocks) {
|
|
||||||
this.setBlocks(blocks);
|
|
||||||
return (T) this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param padding Padding - should be a 2d array, with format [[padTop, padBottom], [padLeft, padRight]]
|
|
||||||
*/
|
|
||||||
public T padding(int[][] padding) {
|
|
||||||
this.setPadding(padding);
|
|
||||||
return (T) this;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public T name(String layerName) {
|
|
||||||
this.setLayerName(layerName);
|
|
||||||
return (T) this;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
@SuppressWarnings("unchecked")
|
|
||||||
public SpaceToBatchLayer build() {
|
|
||||||
if(padding == null)
|
|
||||||
setPadding(new int[][] {{0, 0}, {0, 0}});
|
|
||||||
return new SpaceToBatchLayer(this);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
package org.deeplearning4j.nn.conf.layers;
|
package org.deeplearning4j.nn.conf.layers;
|
||||||
|
|
||||||
import lombok.*;
|
import lombok.*;
|
||||||
|
import lombok.experimental.SuperBuilder;
|
||||||
import org.deeplearning4j.nn.api.ParamInitializer;
|
import org.deeplearning4j.nn.api.ParamInitializer;
|
||||||
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||||
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
import org.deeplearning4j.nn.conf.InputPreProcessor;
|
||||||
|
@ -40,6 +41,7 @@ import java.util.Map;
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
@ToString(callSuper = true)
|
@ToString(callSuper = true)
|
||||||
@EqualsAndHashCode(callSuper = true)
|
@EqualsAndHashCode(callSuper = true)
|
||||||
|
@SuperBuilder
|
||||||
public class SpaceToDepthLayer extends NoParamLayer {
|
public class SpaceToDepthLayer extends NoParamLayer {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -53,16 +55,20 @@ public class SpaceToDepthLayer extends NoParamLayer {
|
||||||
return this == NCHW ? CNN2DFormat.NCHW : CNN2DFormat.NHWC;
|
return this == NCHW ? CNN2DFormat.NCHW : CNN2DFormat.NHWC;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
/**
|
||||||
|
* @param blockSize Block size
|
||||||
|
*/
|
||||||
protected int blockSize;
|
protected int blockSize;
|
||||||
protected CNN2DFormat dataFormat;
|
/**
|
||||||
|
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
|
||||||
|
* See {@link CNN2DFormat} for more details.<br>
|
||||||
|
* Default: NCHW
|
||||||
|
* @param dataFormat Format for activations (in and out)
|
||||||
|
*/
|
||||||
|
@Builder.Default
|
||||||
|
protected CNN2DFormat dataFormat = CNN2DFormat.NCHW;
|
||||||
|
|
||||||
|
|
||||||
protected SpaceToDepthLayer(Builder builder) {
|
|
||||||
super(builder);
|
|
||||||
this.setBlockSize(builder.blockSize);
|
|
||||||
this.setDataFormat(builder.dataFormat);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public SpaceToDepthLayer clone() {
|
public SpaceToDepthLayer clone() {
|
||||||
|
@ -74,7 +80,7 @@ public class SpaceToDepthLayer extends NoParamLayer {
|
||||||
Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView,
|
Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView,
|
||||||
boolean initializeParams, DataType networkDataType) {
|
boolean initializeParams, DataType networkDataType) {
|
||||||
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
|
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
|
||||||
|
runInheritance();
|
||||||
org.deeplearning4j.nn.layers.convolution.SpaceToDepth ret =
|
org.deeplearning4j.nn.layers.convolution.SpaceToDepth ret =
|
||||||
new org.deeplearning4j.nn.layers.convolution.SpaceToDepth(lconf, networkDataType);
|
new org.deeplearning4j.nn.layers.convolution.SpaceToDepth(lconf, networkDataType);
|
||||||
ret.addTrainingListeners(trainingListeners);
|
ret.addTrainingListeners(trainingListeners);
|
||||||
|
@ -133,78 +139,5 @@ public class SpaceToDepthLayer extends NoParamLayer {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@NoArgsConstructor
|
|
||||||
@Getter
|
|
||||||
@Setter
|
|
||||||
public static class Builder<T extends Builder<T>> extends LayerConfiguration.Builder<T> {
|
|
||||||
|
|
||||||
protected int blockSize;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Data format for input activations. Note DL4J uses NCHW in most cases
|
|
||||||
*/
|
|
||||||
protected CNN2DFormat dataFormat = CNN2DFormat.NCHW;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param blockSize Block size
|
|
||||||
*/
|
|
||||||
public Builder(int blockSize) {
|
|
||||||
this.setBlockSize(blockSize);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param blockSize Block size
|
|
||||||
* @param dataFormat Data format for input activations. Note DL4J uses NCHW in most cases
|
|
||||||
*/
|
|
||||||
@Deprecated
|
|
||||||
public Builder(int blockSize, DataFormat dataFormat) {
|
|
||||||
this(blockSize, dataFormat.toFormat());
|
|
||||||
}
|
|
||||||
|
|
||||||
public Builder(int blockSize, CNN2DFormat dataFormat) {
|
|
||||||
this.setBlockSize(blockSize);
|
|
||||||
this.setDataFormat(dataFormat);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param blockSize Block size
|
|
||||||
*/
|
|
||||||
public T blocks(int blockSize) {
|
|
||||||
this.setBlockSize(blockSize);
|
|
||||||
return (T) this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param dataFormat Data format for input activations. Note DL4J uses NCHW in most cases
|
|
||||||
* @deprecated Use {@link #dataFormat(CNN2DFormat)}
|
|
||||||
*/
|
|
||||||
@Deprecated
|
|
||||||
public T dataFormat(DataFormat dataFormat) {
|
|
||||||
return dataFormat(dataFormat.toFormat());
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last).
|
|
||||||
* See {@link CNN2DFormat} for more details.<br>
|
|
||||||
* Default: NCHW
|
|
||||||
* @param dataFormat Format for activations (in and out)
|
|
||||||
*/
|
|
||||||
public T dataFormat(CNN2DFormat dataFormat) {
|
|
||||||
this.setDataFormat(dataFormat);
|
|
||||||
return (T) this;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public T name(String layerName) {
|
|
||||||
this.setLayerName(layerName);
|
|
||||||
return (T) this;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
@SuppressWarnings("unchecked")
|
|
||||||
public SpaceToDepthLayer build() {
|
|
||||||
return new SpaceToDepthLayer(this);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,10 +20,14 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.layers.objdetect;
|
package org.deeplearning4j.nn.conf.layers.objdetect;
|
||||||
|
|
||||||
import lombok.Data;
|
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
|
||||||
import lombok.EqualsAndHashCode;
|
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
|
||||||
import lombok.Getter;
|
import java.util.Arrays;
|
||||||
import lombok.Setter;
|
import java.util.Collection;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import lombok.*;
|
||||||
|
import lombok.experimental.SuperBuilder;
|
||||||
import org.deeplearning4j.nn.api.Layer;
|
import org.deeplearning4j.nn.api.Layer;
|
||||||
import org.deeplearning4j.nn.api.ParamInitializer;
|
import org.deeplearning4j.nn.api.ParamInitializer;
|
||||||
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
import org.deeplearning4j.nn.conf.CNN2DFormat;
|
||||||
|
@ -41,44 +45,57 @@ import org.nd4j.linalg.learning.regularization.Regularization;
|
||||||
import org.nd4j.linalg.lossfunctions.ILossFunction;
|
import org.nd4j.linalg.lossfunctions.ILossFunction;
|
||||||
import org.nd4j.linalg.lossfunctions.impl.LossL2;
|
import org.nd4j.linalg.lossfunctions.impl.LossL2;
|
||||||
import org.nd4j.serde.jackson.shaded.NDArrayTextSerializer;
|
import org.nd4j.serde.jackson.shaded.NDArrayTextSerializer;
|
||||||
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
|
|
||||||
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
|
|
||||||
|
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.Collection;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@EqualsAndHashCode(callSuper = false)
|
@EqualsAndHashCode(callSuper = false)
|
||||||
|
@SuperBuilder(buildMethodName = "initBuild")
|
||||||
public class Yolo2OutputLayer extends LayerConfiguration {
|
public class Yolo2OutputLayer extends LayerConfiguration {
|
||||||
|
|
||||||
private double lambdaCoord;
|
/**
|
||||||
private double lambdaNoObj;
|
* Loss function coefficient for position and size/scale components of the loss function. Default
|
||||||
private ILossFunction lossPositionScale;
|
* (as per paper): 5
|
||||||
private ILossFunction lossClassPredictions;
|
*/
|
||||||
|
@Builder.Default private double lambdaCoord = 5;
|
||||||
|
/**
|
||||||
|
* Loss function coefficient for the "no object confidence" components of the loss function.
|
||||||
|
* Default (as per paper): 0.5
|
||||||
|
*/
|
||||||
|
@Builder.Default private double lambdaNoObj = 0.5;
|
||||||
|
/** Loss function for position/scale component of the loss function */
|
||||||
|
@Builder.Default private ILossFunction lossPositionScale = new LossL2();
|
||||||
|
/**
|
||||||
|
* Loss function for the class predictions - defaults to L2 loss (i.e., sum of squared errors, as
|
||||||
|
* per the paper), however Loss MCXENT could also be used (which is more common for
|
||||||
|
* classification).
|
||||||
|
*/
|
||||||
|
@Builder.Default private ILossFunction lossClassPredictions = new LossL2();
|
||||||
|
;
|
||||||
|
/**
|
||||||
|
* Bounding box priors dimensions [width, height]. For N bounding boxes, input has shape [rows,
|
||||||
|
* columns] = [N, 2] Note that dimensions should be specified as fraction of grid size. For
|
||||||
|
* example, a network with 13x13 output, a value of 1.0 would correspond to one grid cell; a value
|
||||||
|
* of 13 would correspond to the entire image.
|
||||||
|
*/
|
||||||
@JsonSerialize(using = NDArrayTextSerializer.class)
|
@JsonSerialize(using = NDArrayTextSerializer.class)
|
||||||
@JsonDeserialize(using = BoundingBoxesDeserializer.class)
|
@JsonDeserialize(using = BoundingBoxesDeserializer.class)
|
||||||
|
@Builder.Default
|
||||||
private INDArray boundingBoxes;
|
private INDArray boundingBoxes;
|
||||||
|
|
||||||
private CNN2DFormat format = CNN2DFormat.NCHW; //Default for serialization of old formats
|
@Builder.Default
|
||||||
|
private CNN2DFormat format = CNN2DFormat.NCHW; // Default for serialization of old formats
|
||||||
|
|
||||||
private Yolo2OutputLayer() {
|
private Yolo2OutputLayer() {
|
||||||
//No-arg constructor for Jackson JSON
|
// No-arg constructor for Jackson JSON
|
||||||
}
|
|
||||||
|
|
||||||
private Yolo2OutputLayer(Builder builder) {
|
|
||||||
super(builder);
|
|
||||||
this.lambdaCoord = builder.lambdaCoord;
|
|
||||||
this.lambdaNoObj = builder.lambdaNoObj;
|
|
||||||
this.lossPositionScale = builder.lossPositionScale;
|
|
||||||
this.lossClassPredictions = builder.lossClassPredictions;
|
|
||||||
this.boundingBoxes = builder.boundingBoxes;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> trainingListeners,
|
public Layer instantiate(
|
||||||
int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
|
NeuralNetConfiguration conf,
|
||||||
|
Collection<TrainingListener> trainingListeners,
|
||||||
|
int layerIndex,
|
||||||
|
INDArray layerParamsView,
|
||||||
|
boolean initializeParams,
|
||||||
|
DataType networkDataType) {
|
||||||
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
|
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
|
||||||
|
|
||||||
org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer ret =
|
org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer ret =
|
||||||
|
@ -99,7 +116,7 @@ public class Yolo2OutputLayer extends LayerConfiguration {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public InputType getOutputType(int layerIndex, InputType inputType) {
|
public InputType getOutputType(int layerIndex, InputType inputType) {
|
||||||
return inputType; //Same shape output as input
|
return inputType; // Same shape output as input
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -126,133 +143,41 @@ public class Yolo2OutputLayer extends LayerConfiguration {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<Regularization> getRegularizationByParam(String paramName) {
|
public List<Regularization> getRegularizationByParam(String paramName) {
|
||||||
//Not applicable
|
// Not applicable
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean isPretrainParam(String paramName) {
|
public boolean isPretrainParam(String paramName) {
|
||||||
return false; //No params
|
return false; // No params
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public LayerMemoryReport getMemoryReport(InputType inputType) {
|
public LayerMemoryReport getMemoryReport(InputType inputType) {
|
||||||
long numValues = inputType.arrayElementsPerExample();
|
long numValues = inputType.arrayElementsPerExample();
|
||||||
|
|
||||||
//This is a VERY rough estimate...
|
// This is a VERY rough estimate...
|
||||||
return new LayerMemoryReport.Builder(name, Yolo2OutputLayer.class, inputType, inputType)
|
return new LayerMemoryReport.Builder(name, Yolo2OutputLayer.class, inputType, inputType)
|
||||||
.standardMemory(0, 0) //No params
|
.standardMemory(0, 0) // No params
|
||||||
.workingMemory(0, numValues, 0, 6 * numValues).cacheMemory(0, 0) //No cache
|
.workingMemory(0, numValues, 0, 6 * numValues)
|
||||||
|
.cacheMemory(0, 0) // No cache
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Getter
|
public static abstract class Yolo2OutputLayerBuilder<
|
||||||
@Setter
|
C extends Yolo2OutputLayer, B extends Yolo2OutputLayerBuilder<C, B>>
|
||||||
public static class Builder extends LayerConfiguration.Builder<Builder> {
|
extends LayerConfigurationBuilder<C, B> {
|
||||||
|
public C build() {
|
||||||
/**
|
if (boundingBoxes$value == null) {
|
||||||
* Loss function coefficient for position and size/scale components of the loss function. Default (as per
|
|
||||||
* paper): 5
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
private double lambdaCoord = 5;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Loss function coefficient for the "no object confidence" components of the loss function. Default (as per
|
|
||||||
* paper): 0.5
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
private double lambdaNoObj = 0.5;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Loss function for position/scale component of the loss function
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
private ILossFunction lossPositionScale = new LossL2();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Loss function for the class predictions - defaults to L2 loss (i.e., sum of squared errors, as per the
|
|
||||||
* paper), however Loss MCXENT could also be used (which is more common for classification).
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
private ILossFunction lossClassPredictions = new LossL2();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Bounding box priors dimensions [width, height]. For N bounding boxes, input has shape [rows, columns] = [N,
|
|
||||||
* 2] Note that dimensions should be specified as fraction of grid size. For example, a network with 13x13
|
|
||||||
* output, a value of 1.0 would correspond to one grid cell; a value of 13 would correspond to the entire
|
|
||||||
* image.
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
private INDArray boundingBoxes;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Loss function coefficient for position and size/scale components of the loss function. Default (as per
|
|
||||||
* paper): 5
|
|
||||||
*
|
|
||||||
* @param lambdaCoord Lambda value for size/scale component of loss function
|
|
||||||
*/
|
|
||||||
public Builder lambdaCoord(double lambdaCoord) {
|
|
||||||
this.setLambdaCoord(lambdaCoord);
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Loss function coefficient for the "no object confidence" components of the loss function. Default (as per
|
|
||||||
* paper): 0.5
|
|
||||||
*
|
|
||||||
* @param lambdaNoObj Lambda value for no-object (confidence) component of the loss function
|
|
||||||
*/
|
|
||||||
public Builder lambdaNoObj(double lambdaNoObj) {
|
|
||||||
this.setLambdaNoObj(lambdaNoObj);
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Loss function for position/scale component of the loss function
|
|
||||||
*
|
|
||||||
* @param lossPositionScale Loss function for position/scale
|
|
||||||
*/
|
|
||||||
public Builder lossPositionScale(ILossFunction lossPositionScale) {
|
|
||||||
this.setLossPositionScale(lossPositionScale);
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Loss function for the class predictions - defaults to L2 loss (i.e., sum of squared errors, as per the
|
|
||||||
* paper), however Loss MCXENT could also be used (which is more common for classification).
|
|
||||||
*
|
|
||||||
* @param lossClassPredictions Loss function for the class prediction error component of the YOLO loss function
|
|
||||||
*/
|
|
||||||
public Builder lossClassPredictions(ILossFunction lossClassPredictions) {
|
|
||||||
this.setLossClassPredictions(lossClassPredictions);
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Bounding box priors dimensions [width, height]. For N bounding boxes, input has shape [rows, columns] = [N,
|
|
||||||
* 2] Note that dimensions should be specified as fraction of grid size. For example, a network with 13x13
|
|
||||||
* output, a value of 1.0 would correspond to one grid cell; a value of 13 would correspond to the entire
|
|
||||||
* image.
|
|
||||||
*
|
|
||||||
* @param boundingBoxes Bounding box prior dimensions (width, height)
|
|
||||||
*/
|
|
||||||
public Builder boundingBoxPriors(INDArray boundingBoxes) {
|
|
||||||
this.setBoundingBoxes(boundingBoxes);
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Yolo2OutputLayer build() {
|
|
||||||
if (boundingBoxes == null) {
|
|
||||||
throw new IllegalStateException("Bounding boxes have not been set");
|
throw new IllegalStateException("Bounding boxes have not been set");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (boundingBoxes.rank() != 2 || boundingBoxes.size(1) != 2) {
|
if (boundingBoxes$value.rank() != 2 || boundingBoxes$value.size(1) != 2) {
|
||||||
throw new IllegalStateException("Bounding box priors must have shape [nBoxes, 2]. Has shape: "
|
throw new IllegalStateException(
|
||||||
+ Arrays.toString(boundingBoxes.shape()));
|
"Bounding box priors must have shape [nBoxes, 2]. Has shape: "
|
||||||
|
+ Arrays.toString(boundingBoxes$value.shape()));
|
||||||
}
|
}
|
||||||
|
return initBuild();
|
||||||
return new Yolo2OutputLayer(this);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -48,7 +48,7 @@ public class SimpleRnn extends BaseRecurrentLayer {
|
||||||
* If true (default = false): enable layer normalization on this layer
|
* If true (default = false): enable layer normalization on this layer
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
@lombok.Builder.Default @Accessors
|
@lombok.Builder.Default @Accessors @Getter
|
||||||
private boolean hasLayerNorm = false;
|
private boolean hasLayerNorm = false;
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -20,6 +20,9 @@
|
||||||
|
|
||||||
package org.deeplearning4j.nn.conf.layers.samediff;
|
package org.deeplearning4j.nn.conf.layers.samediff;
|
||||||
|
|
||||||
|
import java.util.Collection;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
import lombok.*;
|
import lombok.*;
|
||||||
import lombok.experimental.SuperBuilder;
|
import lombok.experimental.SuperBuilder;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
@ -44,11 +47,6 @@ import org.nd4j.linalg.learning.regularization.L2Regularization;
|
||||||
import org.nd4j.linalg.learning.regularization.Regularization;
|
import org.nd4j.linalg.learning.regularization.Regularization;
|
||||||
import org.nd4j.linalg.learning.regularization.WeightDecay;
|
import org.nd4j.linalg.learning.regularization.WeightDecay;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Collection;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Data
|
@Data
|
||||||
@EqualsAndHashCode(callSuper = true, doNotUseGetters = true)
|
@EqualsAndHashCode(callSuper = true, doNotUseGetters = true)
|
||||||
|
@ -59,15 +57,17 @@ public abstract class AbstractSameDiffLayer extends LayerConfiguration {
|
||||||
/**
|
/**
|
||||||
* The regularization for the parameters (excluding biases) - for example {@link WeightDecay}
|
* The regularization for the parameters (excluding biases) - for example {@link WeightDecay}
|
||||||
*
|
*
|
||||||
* -- SETTER --
|
* <p>-- SETTER -- Set the regularization for the parameters (excluding biases) - for example
|
||||||
* Set the regularization for the parameters (excluding biases) - for example {@link WeightDecay}
|
* {@link WeightDecay}
|
||||||
* @param regularization Regularization to apply for the network parameters/weights (excluding biases)
|
*
|
||||||
|
* @param regularization Regularization to apply for the network parameters/weights (excluding
|
||||||
|
* biases)
|
||||||
*/
|
*/
|
||||||
protected List<Regularization> regularization;
|
protected List<Regularization> regularization;
|
||||||
/**
|
/**
|
||||||
* The regularization for the biases only - for example {@link WeightDecay}
|
* The regularization for the biases only - for example {@link WeightDecay} -- SETTER -- Set the
|
||||||
* -- SETTER --
|
* regularization for the biases only - for example {@link WeightDecay}
|
||||||
* Set the regularization for the biases only - for example {@link WeightDecay}
|
*
|
||||||
* @param regularizationBias Regularization to apply for the network biases only
|
* @param regularizationBias Regularization to apply for the network biases only
|
||||||
*/
|
*/
|
||||||
protected List<Regularization> regularizationBias;
|
protected List<Regularization> regularizationBias;
|
||||||
|
@ -79,14 +79,13 @@ public abstract class AbstractSameDiffLayer extends LayerConfiguration {
|
||||||
*/
|
*/
|
||||||
protected @Getter @Setter IUpdater updater;
|
protected @Getter @Setter IUpdater updater;
|
||||||
/**
|
/**
|
||||||
* Gradient updater configuration, for the biases only. If not set, biases will use the updater as set by {@link
|
* Gradient updater configuration, for the biases only. If not set, biases will use the updater as
|
||||||
* #setUpdater(IUpdater)}
|
* set by {@link #setUpdater(IUpdater)}
|
||||||
*
|
*
|
||||||
* @param biasUpdater Updater to use for bias parameters
|
* @param biasUpdater Updater to use for bias parameters
|
||||||
*/
|
*/
|
||||||
protected @Getter @Setter IUpdater biasUpdater;
|
protected @Getter @Setter IUpdater biasUpdater;
|
||||||
|
|
||||||
|
|
||||||
protected GradientNormalization gradientNormalization;
|
protected GradientNormalization gradientNormalization;
|
||||||
protected double gradientNormalizationThreshold = Double.NaN;
|
protected double gradientNormalizationThreshold = Double.NaN;
|
||||||
|
|
||||||
|
@ -94,9 +93,9 @@ public abstract class AbstractSameDiffLayer extends LayerConfiguration {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<Regularization> getRegularizationByParam(String paramName) {
|
public List<Regularization> getRegularizationByParam(String paramName) {
|
||||||
if(layerParams.isWeightParam(paramName)){
|
if (layerParams.isWeightParam(paramName)) {
|
||||||
return regularization;
|
return regularization;
|
||||||
} else if(layerParams.isBiasParam(paramName)){
|
} else if (layerParams.isBiasParam(paramName)) {
|
||||||
return regularizationBias;
|
return regularizationBias;
|
||||||
}
|
}
|
||||||
return null;
|
return null;
|
||||||
|
@ -112,23 +111,23 @@ public abstract class AbstractSameDiffLayer extends LayerConfiguration {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void setNIn(InputType inputType, boolean override) {
|
public void setNIn(InputType inputType, boolean override) {
|
||||||
//Default implementation: no-op
|
// Default implementation: no-op
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
|
public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
|
||||||
//Default implementation: no-op
|
// Default implementation: no-op
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void applyGlobalConfigToLayer(
|
||||||
public void applyGlobalConfigToLayer(NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) {
|
NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) {
|
||||||
//Default implementation: no op
|
// Default implementation: no op
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Define the parameters for the network. Use {@link SDLayerParams#addWeightParam(String, long...)} and {@link
|
* Define the parameters for the network. Use {@link SDLayerParams#addWeightParam(String,
|
||||||
* SDLayerParams#addBiasParam(String, long...)}
|
* long...)} and {@link SDLayerParams#addBiasParam(String, long...)}
|
||||||
*
|
*
|
||||||
* @param params Object used to set parameters for this layer
|
* @param params Object used to set parameters for this layer
|
||||||
*/
|
*/
|
||||||
|
@ -142,11 +141,15 @@ public abstract class AbstractSameDiffLayer extends LayerConfiguration {
|
||||||
public abstract void initializeParameters(Map<String, INDArray> params);
|
public abstract void initializeParameters(Map<String, INDArray> params);
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public abstract org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf,
|
public abstract org.deeplearning4j.nn.api.Layer instantiate(
|
||||||
Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView,
|
NeuralNetConfiguration conf,
|
||||||
boolean initializeParams, DataType networkDataType);
|
Collection<TrainingListener> trainingListeners,
|
||||||
|
int layerIndex,
|
||||||
|
INDArray layerParamsView,
|
||||||
|
boolean initializeParams,
|
||||||
|
DataType networkDataType);
|
||||||
|
|
||||||
//==================================================================================================================
|
// ==================================================================================================================
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ParamInitializer initializer() {
|
public ParamInitializer initializer() {
|
||||||
|
@ -157,7 +160,8 @@ public abstract class AbstractSameDiffLayer extends LayerConfiguration {
|
||||||
public IUpdater getUpdaterByParam(String paramName) {
|
public IUpdater getUpdaterByParam(String paramName) {
|
||||||
if (biasUpdater != null && initializer().isBiasParam(this, paramName)) {
|
if (biasUpdater != null && initializer().isBiasParam(this, paramName)) {
|
||||||
return biasUpdater;
|
return biasUpdater;
|
||||||
} else if (initializer().isBiasParam(this, paramName) || initializer().isWeightParam(this, paramName)) {
|
} else if (initializer().isBiasParam(this, paramName)
|
||||||
|
|| initializer().isWeightParam(this, paramName)) {
|
||||||
return updater;
|
return updater;
|
||||||
}
|
}
|
||||||
throw new IllegalStateException("Unknown parameter key: " + paramName);
|
throw new IllegalStateException("Unknown parameter key: " + paramName);
|
||||||
|
@ -170,12 +174,12 @@ public abstract class AbstractSameDiffLayer extends LayerConfiguration {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public LayerMemoryReport getMemoryReport(InputType inputType) {
|
public LayerMemoryReport getMemoryReport(InputType inputType) {
|
||||||
return new LayerMemoryReport(); //TODO
|
return new LayerMemoryReport(); // TODO
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the memory layout ('c' or 'f' order - i.e., row/column major) of the parameters. In most cases, this
|
* Returns the memory layout ('c' or 'f' order - i.e., row/column major) of the parameters. In
|
||||||
* can/should be left
|
* most cases, this can/should be left
|
||||||
*
|
*
|
||||||
* @param param Name of the parameter
|
* @param param Name of the parameter
|
||||||
* @return Memory layout ('c' or 'f') of the parameter
|
* @return Memory layout ('c' or 'f') of the parameter
|
||||||
|
@ -185,7 +189,8 @@ public abstract class AbstractSameDiffLayer extends LayerConfiguration {
|
||||||
}
|
}
|
||||||
|
|
||||||
protected void initWeights(int fanIn, int fanOut, WeightInit weightInit, INDArray array) {
|
protected void initWeights(int fanIn, int fanOut, WeightInit weightInit, INDArray array) {
|
||||||
WeightInitUtil.initWeights(fanIn, fanOut, array.shape(), weightInit, null, paramReshapeOrder(null), array);
|
WeightInitUtil.initWeights(
|
||||||
|
fanIn, fanOut, array.shape(), weightInit, null, paramReshapeOrder(null), array);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void applyGlobalConfig(NeuralNetConfiguration.NeuralNetConfigurationBuilder b) {
|
public void applyGlobalConfig(NeuralNetConfiguration.NeuralNetConfigurationBuilder b) {
|
||||||
|
@ -213,63 +218,74 @@ public abstract class AbstractSameDiffLayer extends LayerConfiguration {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method generates an "all ones" mask array for use in the SameDiff model when none is provided.
|
* This method generates an "all ones" mask array for use in the SameDiff model when none is
|
||||||
|
* provided.
|
||||||
|
*
|
||||||
* @param input Input to the layer
|
* @param input Input to the layer
|
||||||
* @return A mask array - should be same datatype as the input (usually)
|
* @return A mask array - should be same datatype as the input (usually)
|
||||||
*/
|
*/
|
||||||
public INDArray onesMaskForInput(INDArray input){
|
public INDArray onesMaskForInput(INDArray input) {
|
||||||
if(input.rank() == 2){
|
if (input.rank() == 2) {
|
||||||
return Nd4j.ones(input.dataType(), input.size(0), 1);
|
return Nd4j.ones(input.dataType(), input.size(0), 1);
|
||||||
} else if(input.rank() == 3){
|
} else if (input.rank() == 3) {
|
||||||
return Nd4j.ones(input.dataType(), input.size(0), input.size(2)); //mask: [mb, length] vs. input [mb, nIn, length]
|
return Nd4j.ones(
|
||||||
} else if(input.rank() == 4){
|
input.dataType(),
|
||||||
//CNN style - return [mb, 1, 1, 1] for broadcast...
|
input.size(0),
|
||||||
|
input.size(2)); // mask: [mb, length] vs. input [mb, nIn, length]
|
||||||
|
} else if (input.rank() == 4) {
|
||||||
|
// CNN style - return [mb, 1, 1, 1] for broadcast...
|
||||||
return Nd4j.ones(input.dataType(), input.size(0), 1, 1, 1);
|
return Nd4j.ones(input.dataType(), input.size(0), 1, 1, 1);
|
||||||
} else if(input.rank() == 5){
|
} else if (input.rank() == 5) {
|
||||||
//CNN3D style - return [mb, 1, 1, 1, 1] for broadcast...
|
// CNN3D style - return [mb, 1, 1, 1, 1] for broadcast...
|
||||||
return Nd4j.ones(input.dataType(), input.size(0), 1, 1, 1, 1);
|
return Nd4j.ones(input.dataType(), input.size(0), 1, 1, 1, 1);
|
||||||
} else {
|
} else {
|
||||||
throw new IllegalStateException("When using masking with rank 1 or 6+ inputs, the onesMaskForInput method must be implemented, " +
|
throw new IllegalStateException(
|
||||||
"in order to determine the correct mask shape for this layer");
|
"When using masking with rank 1 or 6+ inputs, the onesMaskForInput method must be implemented, "
|
||||||
|
+ "in order to determine the correct mask shape for this layer");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static abstract class AbstractSameDiffLayerBuilder<C extends AbstractSameDiffLayer, B extends AbstractSameDiffLayerBuilder<C, B>> {
|
public abstract static class AbstractSameDiffLayerBuilder<
|
||||||
|
C extends AbstractSameDiffLayer, B extends AbstractSameDiffLayerBuilder<C, B>>
|
||||||
|
extends LayerConfigurationBuilder<C, B> {
|
||||||
/**
|
/**
|
||||||
* L1 regularization coefficient (weights only). Use {@link #l1Bias(double)} to configure the l1 regularization
|
* L1 regularization coefficient (weights only). Use {@link #l1Bias(double)} to configure the l1
|
||||||
* coefficient for the bias.
|
* regularization coefficient for the bias.
|
||||||
*/
|
*/
|
||||||
public B l1(double l1) {
|
public B l1(double l1) {
|
||||||
//Check if existing L1 exists; if so, replace it
|
// Check if existing L1 exists; if so, replace it
|
||||||
NetworkUtils.removeInstances(this.regularization, L1Regularization.class);
|
NetworkUtils.removeInstances(this.regularization, L1Regularization.class);
|
||||||
if(l1 > 0.0) {
|
if (l1 > 0.0) {
|
||||||
this.regularization.add(new L1Regularization(l1));
|
this.regularization.add(new L1Regularization(l1));
|
||||||
}
|
}
|
||||||
return self();
|
return self();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* L2 regularization coefficient (weights only). Use {@link #l2Bias(double)} to configure the l2 regularization
|
* L2 regularization coefficient (weights only). Use {@link #l2Bias(double)} to configure the l2
|
||||||
* coefficient for the bias.<br>
|
* regularization coefficient for the bias.<br>
|
||||||
* <b>Note</b>: Generally, {@link WeightDecay} (set via {@link #weightDecay(double,boolean)} should be preferred to
|
* <b>Note</b>: Generally, {@link WeightDecay} (set via {@link #weightDecay(double,boolean)}
|
||||||
* L2 regularization. See {@link WeightDecay} javadoc for further details.<br>
|
* should be preferred to L2 regularization. See {@link WeightDecay} javadoc for further
|
||||||
|
* details.<br>
|
||||||
*/
|
*/
|
||||||
public B l2(double l2) {
|
public B l2(double l2) {
|
||||||
//Check if existing L2 exists; if so, replace it. Also remove weight decay - it doesn't make sense to use both
|
// Check if existing L2 exists; if so, replace it. Also remove weight decay - it doesn't make
|
||||||
|
// sense to use both
|
||||||
NetworkUtils.removeInstances(this.regularization, L2Regularization.class);
|
NetworkUtils.removeInstances(this.regularization, L2Regularization.class);
|
||||||
if(l2 > 0.0) {
|
if (l2 > 0.0) {
|
||||||
NetworkUtils.removeInstancesWithWarning(this.regularization, WeightDecay.class, "WeightDecay regularization removed: incompatible with added L2 regularization");
|
NetworkUtils.removeInstancesWithWarning(
|
||||||
|
this.regularization,
|
||||||
|
WeightDecay.class,
|
||||||
|
"WeightDecay regularization removed: incompatible with added L2 regularization");
|
||||||
this.regularization.add(new L2Regularization(l2));
|
this.regularization.add(new L2Regularization(l2));
|
||||||
}
|
}
|
||||||
return self();
|
return self();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/** L1 regularization coefficient for the bias. Default: 0. See also {@link #l1(double)} */
|
||||||
* L1 regularization coefficient for the bias. Default: 0. See also {@link #l1(double)}
|
|
||||||
*/
|
|
||||||
public B l1Bias(double l1Bias) {
|
public B l1Bias(double l1Bias) {
|
||||||
NetworkUtils.removeInstances(this.regularizationBias, L1Regularization.class);
|
NetworkUtils.removeInstances(this.regularizationBias, L1Regularization.class);
|
||||||
if(l1Bias > 0.0) {
|
if (l1Bias > 0.0) {
|
||||||
this.regularizationBias.add(new L1Regularization(l1Bias));
|
this.regularizationBias.add(new L1Regularization(l1Bias));
|
||||||
}
|
}
|
||||||
return self();
|
return self();
|
||||||
|
@ -277,13 +293,17 @@ public abstract class AbstractSameDiffLayer extends LayerConfiguration {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* L2 regularization coefficient for the bias. Default: 0. See also {@link #l2(double)}<br>
|
* L2 regularization coefficient for the bias. Default: 0. See also {@link #l2(double)}<br>
|
||||||
* <b>Note</b>: Generally, {@link WeightDecay} (set via {@link #weightDecayBias(double,boolean)} should be preferred to
|
* <b>Note</b>: Generally, {@link WeightDecay} (set via {@link #weightDecayBias(double,boolean)}
|
||||||
* L2 regularization. See {@link WeightDecay} javadoc for further details.<br>
|
* should be preferred to L2 regularization. See {@link WeightDecay} javadoc for further
|
||||||
|
* details.<br>
|
||||||
*/
|
*/
|
||||||
public B l2Bias(double l2Bias) {
|
public B l2Bias(double l2Bias) {
|
||||||
NetworkUtils.removeInstances(this.regularizationBias, L2Regularization.class);
|
NetworkUtils.removeInstances(this.regularizationBias, L2Regularization.class);
|
||||||
if(l2Bias > 0.0) {
|
if (l2Bias > 0.0) {
|
||||||
NetworkUtils.removeInstancesWithWarning(this.regularizationBias, WeightDecay.class, "WeightDecay bias regularization removed: incompatible with added L2 regularization");
|
NetworkUtils.removeInstancesWithWarning(
|
||||||
|
this.regularizationBias,
|
||||||
|
WeightDecay.class,
|
||||||
|
"WeightDecay bias regularization removed: incompatible with added L2 regularization");
|
||||||
this.regularizationBias.add(new L2Regularization(l2Bias));
|
this.regularizationBias.add(new L2Regularization(l2Bias));
|
||||||
}
|
}
|
||||||
return self();
|
return self();
|
||||||
|
@ -291,7 +311,8 @@ public abstract class AbstractSameDiffLayer extends LayerConfiguration {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Add weight decay regularization for the network parameters (excluding biases).<br>
|
* Add weight decay regularization for the network parameters (excluding biases).<br>
|
||||||
* This applies weight decay <i>with</i> multiplying the learning rate - see {@link WeightDecay} for more details.<br>
|
* This applies weight decay <i>with</i> multiplying the learning rate - see {@link WeightDecay}
|
||||||
|
* for more details.<br>
|
||||||
*
|
*
|
||||||
* @param coefficient Weight decay regularization coefficient
|
* @param coefficient Weight decay regularization coefficient
|
||||||
* @see #weightDecay(double, boolean)
|
* @see #weightDecay(double, boolean)
|
||||||
|
@ -301,25 +322,31 @@ public abstract class AbstractSameDiffLayer extends LayerConfiguration {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Add weight decay regularization for the network parameters (excluding biases). See {@link WeightDecay} for more details.<br>
|
* Add weight decay regularization for the network parameters (excluding biases). See {@link
|
||||||
|
* WeightDecay} for more details.<br>
|
||||||
*
|
*
|
||||||
* @param coefficient Weight decay regularization coefficient
|
* @param coefficient Weight decay regularization coefficient
|
||||||
* @param applyLR Whether the learning rate should be multiplied in when performing weight decay updates. See {@link WeightDecay} for more details.
|
* @param applyLR Whether the learning rate should be multiplied in when performing weight decay
|
||||||
|
* updates. See {@link WeightDecay} for more details.
|
||||||
* @see #weightDecay(double, boolean)
|
* @see #weightDecay(double, boolean)
|
||||||
*/
|
*/
|
||||||
public B weightDecay(double coefficient, boolean applyLR) {
|
public B weightDecay(double coefficient, boolean applyLR) {
|
||||||
//Check if existing weight decay if it exists; if so, replace it. Also remove L2 - it doesn't make sense to use both
|
// Check if existing weight decay if it exists; if so, replace it. Also remove L2 - it doesn't
|
||||||
|
// make sense to use both
|
||||||
NetworkUtils.removeInstances(this.regularization, WeightDecay.class);
|
NetworkUtils.removeInstances(this.regularization, WeightDecay.class);
|
||||||
if(coefficient > 0.0) {
|
if (coefficient > 0.0) {
|
||||||
NetworkUtils.removeInstancesWithWarning(this.regularization, L2Regularization.class, "L2 regularization removed: incompatible with added WeightDecay regularization");
|
NetworkUtils.removeInstancesWithWarning(
|
||||||
|
this.regularization,
|
||||||
|
L2Regularization.class,
|
||||||
|
"L2 regularization removed: incompatible with added WeightDecay regularization");
|
||||||
this.regularization.add(new WeightDecay(coefficient, applyLR));
|
this.regularization.add(new WeightDecay(coefficient, applyLR));
|
||||||
}
|
}
|
||||||
return self();
|
return self();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Weight decay for the biases only - see {@link #weightDecay(double)} for more details.
|
* Weight decay for the biases only - see {@link #weightDecay(double)} for more details. This
|
||||||
* This applies weight decay <i>with</i> multiplying the learning rate.<br>
|
* applies weight decay <i>with</i> multiplying the learning rate.<br>
|
||||||
*
|
*
|
||||||
* @param coefficient Weight decay regularization coefficient
|
* @param coefficient Weight decay regularization coefficient
|
||||||
* @see #weightDecayBias(double, boolean)
|
* @see #weightDecayBias(double, boolean)
|
||||||
|
@ -334,10 +361,14 @@ public abstract class AbstractSameDiffLayer extends LayerConfiguration {
|
||||||
* @param coefficient Weight decay regularization coefficient
|
* @param coefficient Weight decay regularization coefficient
|
||||||
*/
|
*/
|
||||||
public B weightDecayBias(double coefficient, boolean applyLR) {
|
public B weightDecayBias(double coefficient, boolean applyLR) {
|
||||||
//Check if existing weight decay if it exists; if so, replace it. Also remove L2 - it doesn't make sense to use both
|
// Check if existing weight decay if it exists; if so, replace it. Also remove L2 - it doesn't
|
||||||
|
// make sense to use both
|
||||||
NetworkUtils.removeInstances(this.regularizationBias, WeightDecay.class);
|
NetworkUtils.removeInstances(this.regularizationBias, WeightDecay.class);
|
||||||
if(coefficient > 0.0) {
|
if (coefficient > 0.0) {
|
||||||
NetworkUtils.removeInstancesWithWarning(this.regularizationBias, L2Regularization.class, "L2 bias regularization removed: incompatible with added WeightDecay regularization");
|
NetworkUtils.removeInstancesWithWarning(
|
||||||
|
this.regularizationBias,
|
||||||
|
L2Regularization.class,
|
||||||
|
"L2 bias regularization removed: incompatible with added WeightDecay regularization");
|
||||||
this.regularizationBias.add(new WeightDecay(coefficient, applyLR));
|
this.regularizationBias.add(new WeightDecay(coefficient, applyLR));
|
||||||
}
|
}
|
||||||
return self();
|
return self();
|
||||||
|
|
|
@ -176,7 +176,7 @@ public abstract class SameDiffVertex extends GraphVertex implements ITraininable
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String getLayerName() {
|
public String getName() {
|
||||||
return name;
|
return name;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -285,5 +285,14 @@ public class VariationalAutoencoder extends BasePretrainNetwork {
|
||||||
super.nOut(nOut);
|
super.nOut(nOut);
|
||||||
return self();
|
return self();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public B pzxActivationFunction(IActivation activation) {
|
||||||
|
this.pzxActivationFunction$value = activation;
|
||||||
|
this.pzxActivationFunction$set = true;
|
||||||
|
return self();
|
||||||
|
}
|
||||||
|
public B pzxActivationFunction(Activation activation) {
|
||||||
|
return this.pzxActivationFunction(activation.getActivationFunction());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -107,7 +107,7 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
|
||||||
|
|
||||||
INDArray origInput = input;
|
INDArray origInput = input;
|
||||||
INDArray origEps = epsilon;
|
INDArray origEps = epsilon;
|
||||||
if(getTypedLayerConfiguration().getDataFormat() != CNN2DFormat.NCHW) {
|
if(getTypedLayerConfiguration().getConvFormat() != CNN2DFormat.NCHW) {
|
||||||
input = input.permute(0,3,1,2); //NHWC to NCHW
|
input = input.permute(0,3,1,2); //NHWC to NCHW
|
||||||
epsilon = epsilon.permute(0,3,1,2); //NHWC to NCHW
|
epsilon = epsilon.permute(0,3,1,2); //NHWC to NCHW
|
||||||
}
|
}
|
||||||
|
@ -151,7 +151,7 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
|
||||||
|
|
||||||
Pair<INDArray, INDArray> p = preOutput4d(true, true, workspaceMgr);
|
Pair<INDArray, INDArray> p = preOutput4d(true, true, workspaceMgr);
|
||||||
INDArray z = p.getFirst();
|
INDArray z = p.getFirst();
|
||||||
CNN2DFormat f = getTypedLayerConfiguration().getDataFormat();
|
CNN2DFormat f = getTypedLayerConfiguration().getConvFormat();
|
||||||
if(f != CNN2DFormat.NCHW){
|
if(f != CNN2DFormat.NCHW){
|
||||||
z = z.permute(0,3,1,2); //NHWC to NCHW
|
z = z.permute(0,3,1,2); //NHWC to NCHW
|
||||||
}
|
}
|
||||||
|
@ -159,7 +159,7 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
|
||||||
|
|
||||||
if (helper != null && (helperCountFail == 0 || !getTypedLayerConfiguration().isCudnnAllowFallback())) {
|
if (helper != null && (helperCountFail == 0 || !getTypedLayerConfiguration().isCudnnAllowFallback())) {
|
||||||
INDArray helperDelta = delta;
|
INDArray helperDelta = delta;
|
||||||
if(getTypedLayerConfiguration().getDataFormat() == CNN2DFormat.NHWC)
|
if(getTypedLayerConfiguration().getConvFormat() == CNN2DFormat.NHWC)
|
||||||
helperDelta = delta.permute(0,2,3,1); //NCHW to NHWC
|
helperDelta = delta.permute(0,2,3,1); //NCHW to NHWC
|
||||||
|
|
||||||
if(!hasBias() && !(helper instanceof MKLDNNConvHelper)){
|
if(!hasBias() && !(helper instanceof MKLDNNConvHelper)){
|
||||||
|
@ -177,7 +177,7 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
|
||||||
ret = helper.backpropGradient(origInput, weights, bias, helperDelta, kernel, strides,
|
ret = helper.backpropGradient(origInput, weights, bias, helperDelta, kernel, strides,
|
||||||
pad, biasGradView, weightGradView, afn,
|
pad, biasGradView, weightGradView, afn,
|
||||||
getTypedLayerConfiguration().getCudnnAlgoMode(), getTypedLayerConfiguration().getCudnnBwdFilterAlgo(), getTypedLayerConfiguration().getCudnnBwdDataAlgo(),
|
getTypedLayerConfiguration().getCudnnAlgoMode(), getTypedLayerConfiguration().getCudnnBwdFilterAlgo(), getTypedLayerConfiguration().getCudnnBwdDataAlgo(),
|
||||||
convolutionMode, dilation, getTypedLayerConfiguration().getDataFormat(), workspaceMgr);
|
convolutionMode, dilation, getTypedLayerConfiguration().getConvFormat(), workspaceMgr);
|
||||||
} catch (ND4JOpProfilerException e){
|
} catch (ND4JOpProfilerException e){
|
||||||
throw e; //NaN panic etc for debugging
|
throw e; //NaN panic etc for debugging
|
||||||
} catch (Exception e){
|
} catch (Exception e){
|
||||||
|
@ -261,7 +261,7 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
|
||||||
|
|
||||||
epsNext = backpropDropOutIfPresent(epsNext);
|
epsNext = backpropDropOutIfPresent(epsNext);
|
||||||
|
|
||||||
if(getTypedLayerConfiguration().getDataFormat() != CNN2DFormat.NCHW){
|
if(getTypedLayerConfiguration().getConvFormat()!= CNN2DFormat.NCHW){
|
||||||
epsNext = epsNext.permute(0,2,3,1); //NCHW to NHWC
|
epsNext = epsNext.permute(0,2,3,1); //NCHW to NHWC
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -295,7 +295,7 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
|
||||||
}
|
}
|
||||||
|
|
||||||
protected void validateInputDepth(long inDepth) {
|
protected void validateInputDepth(long inDepth) {
|
||||||
CNN2DFormat format = getTypedLayerConfiguration().getDataFormat();
|
CNN2DFormat format = getTypedLayerConfiguration().getConvFormat();
|
||||||
int dim = format == CNN2DFormat.NHWC ? 3 : 1;
|
int dim = format == CNN2DFormat.NHWC ? 3 : 1;
|
||||||
if (input.size(dim) != inDepth) {
|
if (input.size(dim) != inDepth) {
|
||||||
String layerName = layerConfiguration.getName();
|
String layerName = layerConfiguration.getName();
|
||||||
|
@ -304,7 +304,7 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
|
||||||
|
|
||||||
String s = "Cannot do forward pass in Convolution layer (layer name = " + layerName
|
String s = "Cannot do forward pass in Convolution layer (layer name = " + layerName
|
||||||
+ ", layer index = " + index + "): input array channels does not match CNN layer configuration"
|
+ ", layer index = " + index + "): input array channels does not match CNN layer configuration"
|
||||||
+ " (data format = " + format + ", data input channels = " + input.size(dim) + ", " + getTypedLayerConfiguration().getDataFormat().dimensionNames()
|
+ " (data format = " + format + ", data input channels = " + input.size(dim) + ", " + getTypedLayerConfiguration().getConvFormat().dimensionNames()
|
||||||
+ "=" + Arrays.toString(input.shape()) + "; expected" + " input channels = " + inDepth + ") "
|
+ "=" + Arrays.toString(input.shape()) + "; expected" + " input channels = " + inDepth + ") "
|
||||||
+ layerId();
|
+ layerId();
|
||||||
|
|
||||||
|
@ -337,7 +337,7 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
|
||||||
|
|
||||||
INDArray input = this.input.castTo(dataType);
|
INDArray input = this.input.castTo(dataType);
|
||||||
INDArray inputOrig = input;
|
INDArray inputOrig = input;
|
||||||
if(getTypedLayerConfiguration().getDataFormat() == CNN2DFormat.NHWC) {
|
if(getTypedLayerConfiguration().getConvFormat() == CNN2DFormat.NHWC) {
|
||||||
input = input.permute(0,3,1,2).dup(); //NHWC to NCHW
|
input = input.permute(0,3,1,2).dup(); //NHWC to NCHW
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -421,7 +421,7 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
|
||||||
INDArray ret = null;
|
INDArray ret = null;
|
||||||
try {
|
try {
|
||||||
ret = helper.preOutput(inputOrig, weights, bias, kernel, strides, pad, getTypedLayerConfiguration().getCudnnAlgoMode(),
|
ret = helper.preOutput(inputOrig, weights, bias, kernel, strides, pad, getTypedLayerConfiguration().getCudnnAlgoMode(),
|
||||||
getTypedLayerConfiguration().getCudnnFwdAlgo(), convolutionMode, dilation, getTypedLayerConfiguration().getDataFormat(), workspaceMgr);
|
getTypedLayerConfiguration().getCudnnFwdAlgo(), convolutionMode, dilation, getTypedLayerConfiguration().getConvFormat(), workspaceMgr);
|
||||||
} catch (ND4JOpProfilerException e){
|
} catch (ND4JOpProfilerException e){
|
||||||
throw e; //NaN panic etc for debugging
|
throw e; //NaN panic etc for debugging
|
||||||
} catch (Exception e){
|
} catch (Exception e){
|
||||||
|
@ -498,7 +498,7 @@ public class ConvolutionLayer extends BaseLayer<org.deeplearning4j.nn.conf.layer
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if(getTypedLayerConfiguration().getDataFormat() == CNN2DFormat.NHWC) {
|
if(getTypedLayerConfiguration().getConvFormat() == CNN2DFormat.NHWC) {
|
||||||
z = z.permute(0,2,3,1); //NCHW to NHWC
|
z = z.permute(0,2,3,1); //NCHW to NHWC
|
||||||
z = workspaceMgr.dup(ArrayType.ACTIVATIONS, z);
|
z = workspaceMgr.dup(ArrayType.ACTIVATIONS, z);
|
||||||
}
|
}
|
||||||
|
|
|
@ -61,13 +61,13 @@ public class Deconvolution2DLayer extends ConvolutionLayer {
|
||||||
if (input.rank() != 4) {
|
if (input.rank() != 4) {
|
||||||
throw new DL4JInvalidInputException("Got rank " + input.rank()
|
throw new DL4JInvalidInputException("Got rank " + input.rank()
|
||||||
+ " array as input to Deconvolution2DLayer with shape " + Arrays.toString(input.shape())
|
+ " array as input to Deconvolution2DLayer with shape " + Arrays.toString(input.shape())
|
||||||
+ ". Expected rank 4 array with shape " + getTypedLayerConfiguration().getDataFormat().dimensionNames() + ". "
|
+ ". Expected rank 4 array with shape " + getTypedLayerConfiguration().getConvFormat().dimensionNames() + ". "
|
||||||
+ layerId());
|
+ layerId());
|
||||||
}
|
}
|
||||||
|
|
||||||
INDArray weights = getParamWithNoise(DeconvolutionParamInitializer.WEIGHT_KEY, true, workspaceMgr);
|
INDArray weights = getParamWithNoise(DeconvolutionParamInitializer.WEIGHT_KEY, true, workspaceMgr);
|
||||||
|
|
||||||
CNN2DFormat format = getTypedLayerConfiguration().getDataFormat();
|
CNN2DFormat format = getTypedLayerConfiguration().getConvFormat();
|
||||||
boolean nchw = format == CNN2DFormat.NCHW;
|
boolean nchw = format == CNN2DFormat.NCHW;
|
||||||
int hDim = nchw ? 2 : 1;
|
int hDim = nchw ? 2 : 1;
|
||||||
int wDim = nchw ? 3 : 2;
|
int wDim = nchw ? 3 : 2;
|
||||||
|
@ -166,7 +166,7 @@ public class Deconvolution2DLayer extends ConvolutionLayer {
|
||||||
+ " " + layerId());
|
+ " " + layerId());
|
||||||
}
|
}
|
||||||
|
|
||||||
CNN2DFormat format = getTypedLayerConfiguration().getDataFormat();
|
CNN2DFormat format = getTypedLayerConfiguration().getConvFormat();
|
||||||
boolean nchw = format == CNN2DFormat.NCHW;
|
boolean nchw = format == CNN2DFormat.NCHW;
|
||||||
int cDim = nchw ? 1 : 3;
|
int cDim = nchw ? 1 : 3;
|
||||||
int hDim = nchw ? 2 : 1;
|
int hDim = nchw ? 2 : 1;
|
||||||
|
|
|
@ -59,12 +59,12 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer {
|
||||||
@Override
|
@Override
|
||||||
public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
|
public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
|
||||||
assertInputSet(true);
|
assertInputSet(true);
|
||||||
CNN2DFormat format = getTypedLayerConfiguration().getDataFormat();
|
CNN2DFormat format = getTypedLayerConfiguration().getConvFormat();
|
||||||
boolean nchw = format == CNN2DFormat.NCHW;
|
boolean nchw = format == CNN2DFormat.NCHW;
|
||||||
if (input.rank() != 4) {
|
if (input.rank() != 4) {
|
||||||
throw new DL4JInvalidInputException("Got rank " + input.rank()
|
throw new DL4JInvalidInputException("Got rank " + input.rank()
|
||||||
+ " array as input to Convolution layer with shape " + Arrays.toString(input.shape())
|
+ " array as input to Convolution layer with shape " + Arrays.toString(input.shape())
|
||||||
+ ". Expected rank 4 array with shape " + getTypedLayerConfiguration().getDataFormat().dimensionNames() + ". "
|
+ ". Expected rank 4 array with shape " + getTypedLayerConfiguration().getConvFormat().dimensionNames() + ". "
|
||||||
+ layerId());
|
+ layerId());
|
||||||
}
|
}
|
||||||
INDArray bias;
|
INDArray bias;
|
||||||
|
@ -158,7 +158,7 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer {
|
||||||
throw new DL4JInvalidInputException("Got rank " + input.rank()
|
throw new DL4JInvalidInputException("Got rank " + input.rank()
|
||||||
+ " array as input to DepthwiseConvolution2D (layer name = " + layerName + ", layer index = "
|
+ " array as input to DepthwiseConvolution2D (layer name = " + layerName + ", layer index = "
|
||||||
+ index + ") with shape " + Arrays.toString(input.shape()) + ". "
|
+ index + ") with shape " + Arrays.toString(input.shape()) + ". "
|
||||||
+ "Expected rank 4 array with shape " + getTypedLayerConfiguration().getDataFormat().dimensionNames() + "."
|
+ "Expected rank 4 array with shape " + getTypedLayerConfiguration().getConvFormat().dimensionNames() + "."
|
||||||
+ (input.rank() == 2
|
+ (input.rank() == 2
|
||||||
? " (Wrong input type (see InputType.convolutionalFlat()) or wrong data type?)"
|
? " (Wrong input type (see InputType.convolutionalFlat()) or wrong data type?)"
|
||||||
: "") + " " + layerId());
|
: "") + " " + layerId());
|
||||||
|
@ -166,7 +166,7 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer {
|
||||||
|
|
||||||
INDArray input = this.input.castTo(dataType); //no-op if correct dtype
|
INDArray input = this.input.castTo(dataType); //no-op if correct dtype
|
||||||
|
|
||||||
CNN2DFormat format = getTypedLayerConfiguration().getDataFormat();
|
CNN2DFormat format = getTypedLayerConfiguration().getConvFormat();
|
||||||
boolean nchw = format == CNN2DFormat.NCHW;
|
boolean nchw = format == CNN2DFormat.NCHW;
|
||||||
|
|
||||||
long inDepth = depthWiseWeights.size(2);
|
long inDepth = depthWiseWeights.size(2);
|
||||||
|
|
|
@ -63,7 +63,7 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer {
|
||||||
if (input.rank() != 4) {
|
if (input.rank() != 4) {
|
||||||
throw new DL4JInvalidInputException("Got rank " + input.rank()
|
throw new DL4JInvalidInputException("Got rank " + input.rank()
|
||||||
+ " array as input to SubsamplingLayer with shape " + Arrays.toString(input.shape())
|
+ " array as input to SubsamplingLayer with shape " + Arrays.toString(input.shape())
|
||||||
+ ". Expected rank 4 array with shape " + getTypedLayerConfiguration().getDataFormat().dimensionNames() + ". "
|
+ ". Expected rank 4 array with shape " + getTypedLayerConfiguration().getConvFormat().dimensionNames() + ". "
|
||||||
+ layerId());
|
+ layerId());
|
||||||
}
|
}
|
||||||
INDArray bias;
|
INDArray bias;
|
||||||
|
@ -74,7 +74,7 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer {
|
||||||
|
|
||||||
INDArray input = this.input.castTo(dataType);
|
INDArray input = this.input.castTo(dataType);
|
||||||
|
|
||||||
CNN2DFormat format = getTypedLayerConfiguration().getDataFormat();
|
CNN2DFormat format = getTypedLayerConfiguration().getConvFormat();
|
||||||
boolean nchw = format == CNN2DFormat.NCHW;
|
boolean nchw = format == CNN2DFormat.NCHW;
|
||||||
|
|
||||||
long miniBatch = input.size(0);
|
long miniBatch = input.size(0);
|
||||||
|
@ -167,7 +167,7 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer {
|
||||||
getParamWithNoise(SeparableConvolutionParamInitializer.POINT_WISE_WEIGHT_KEY, training, workspaceMgr);
|
getParamWithNoise(SeparableConvolutionParamInitializer.POINT_WISE_WEIGHT_KEY, training, workspaceMgr);
|
||||||
|
|
||||||
INDArray input = this.input.castTo(dataType);
|
INDArray input = this.input.castTo(dataType);
|
||||||
if(getTypedLayerConfiguration().getDataFormat() == CNN2DFormat.NHWC) {
|
if(getTypedLayerConfiguration().getConvFormat() == CNN2DFormat.NHWC) {
|
||||||
input = input.permute(0,3,1,2).dup();
|
input = input.permute(0,3,1,2).dup();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -182,7 +182,7 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer {
|
||||||
throw new DL4JInvalidInputException("Got rank " + input.rank()
|
throw new DL4JInvalidInputException("Got rank " + input.rank()
|
||||||
+ " array as input to SeparableConvolution2D (layer name = " + layerName + ", layer index = "
|
+ " array as input to SeparableConvolution2D (layer name = " + layerName + ", layer index = "
|
||||||
+ index + ") with shape " + Arrays.toString(input.shape()) + ". "
|
+ index + ") with shape " + Arrays.toString(input.shape()) + ". "
|
||||||
+ "Expected rank 4 array with shape " + getTypedLayerConfiguration().getDataFormat().dimensionNames() + "."
|
+ "Expected rank 4 array with shape " + getTypedLayerConfiguration().getConvFormat().dimensionNames() + "."
|
||||||
+ (input.rank() == 2
|
+ (input.rank() == 2
|
||||||
? " (Wrong input type (see InputType.convolutionalFlat()) or wrong data type?)"
|
? " (Wrong input type (see InputType.convolutionalFlat()) or wrong data type?)"
|
||||||
: "")
|
: "")
|
||||||
|
@ -199,7 +199,7 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer {
|
||||||
|
|
||||||
String s = "Cannot do forward pass in SeparableConvolution2D layer (layer name = " + layerName
|
String s = "Cannot do forward pass in SeparableConvolution2D layer (layer name = " + layerName
|
||||||
+ ", layer index = " + index + "): input array channels does not match CNN layer configuration"
|
+ ", layer index = " + index + "): input array channels does not match CNN layer configuration"
|
||||||
+ " (data format = " + getTypedLayerConfiguration().getDataFormat() + ", data input channels = " + input.size(1) + ", [minibatch,inputDepth,height,width]="
|
+ " (data format = " + getTypedLayerConfiguration().getConvFormat() + ", data input channels = " + input.size(1) + ", [minibatch,inputDepth,height,width]="
|
||||||
+ Arrays.toString(input.shape()) + "; expected" + " input channels = " + inDepth + ") "
|
+ Arrays.toString(input.shape()) + "; expected" + " input channels = " + inDepth + ") "
|
||||||
+ layerId();
|
+ layerId();
|
||||||
|
|
||||||
|
@ -287,7 +287,7 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer {
|
||||||
.build();
|
.build();
|
||||||
Nd4j.getExecutioner().exec(op);
|
Nd4j.getExecutioner().exec(op);
|
||||||
|
|
||||||
if(getTypedLayerConfiguration().getDataFormat() == CNN2DFormat.NHWC) {
|
if(getTypedLayerConfiguration().getConvFormat() == CNN2DFormat.NHWC) {
|
||||||
output = output.permute(0,2,3,1); //NCHW to NHWC
|
output = output.permute(0,2,3,1); //NCHW to NHWC
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -47,7 +47,7 @@ public class SpaceToBatch extends AbstractLayer<org.deeplearning4j.nn.conf.layer
|
||||||
}
|
}
|
||||||
|
|
||||||
private int[] getBlocks() {
|
private int[] getBlocks() {
|
||||||
return getTypedLayerConfiguration().getBlocks();
|
return getTypedLayerConfiguration().getBlockSize();
|
||||||
}
|
}
|
||||||
|
|
||||||
private int[][] getPadding() {
|
private int[][] getPadding() {
|
||||||
|
@ -55,7 +55,7 @@ public class SpaceToBatch extends AbstractLayer<org.deeplearning4j.nn.conf.layer
|
||||||
}
|
}
|
||||||
|
|
||||||
private INDArray getBlocksArray() {
|
private INDArray getBlocksArray() {
|
||||||
int[] intBlocks = getTypedLayerConfiguration().getBlocks();
|
int[] intBlocks = getTypedLayerConfiguration().getBlockSize();
|
||||||
return Nd4j.createFromArray(intBlocks);
|
return Nd4j.createFromArray(intBlocks);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -77,7 +77,7 @@ public class SpaceToBatch extends AbstractLayer<org.deeplearning4j.nn.conf.layer
|
||||||
|
|
||||||
INDArray input = this.input.castTo(dataType); //Cast to network dtype if required (no-op if already correct type)
|
INDArray input = this.input.castTo(dataType); //Cast to network dtype if required (no-op if already correct type)
|
||||||
|
|
||||||
boolean nchw = getTypedLayerConfiguration().getFormat() == CNN2DFormat.NCHW;
|
boolean nchw = getTypedLayerConfiguration().getDataFormat() == CNN2DFormat.NCHW;
|
||||||
|
|
||||||
INDArray outEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), input.shape(), 'c');
|
INDArray outEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), input.shape(), 'c');
|
||||||
|
|
||||||
|
@ -104,7 +104,7 @@ public class SpaceToBatch extends AbstractLayer<org.deeplearning4j.nn.conf.layer
|
||||||
if (input.rank() != 4) {
|
if (input.rank() != 4) {
|
||||||
throw new DL4JInvalidInputException("Got rank " + input.rank()
|
throw new DL4JInvalidInputException("Got rank " + input.rank()
|
||||||
+ " array as input to space to batch with shape " + Arrays.toString(input.shape())
|
+ " array as input to space to batch with shape " + Arrays.toString(input.shape())
|
||||||
+ ". Expected rank 4 array with shape " + getTypedLayerConfiguration().getFormat().dimensionNames() + ". "
|
+ ". Expected rank 4 array with shape " + getTypedLayerConfiguration().getDataFormat().dimensionNames() + ". "
|
||||||
+ layerId());
|
+ layerId());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -112,7 +112,7 @@ public class SpaceToBatch extends AbstractLayer<org.deeplearning4j.nn.conf.layer
|
||||||
return preOutput;
|
return preOutput;
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean nchw = getTypedLayerConfiguration().getFormat() == CNN2DFormat.NCHW;
|
boolean nchw = getTypedLayerConfiguration().getDataFormat() == CNN2DFormat.NCHW;
|
||||||
|
|
||||||
long inMiniBatch = input.size(0);
|
long inMiniBatch = input.size(0);
|
||||||
long depth = input.size(nchw ? 1 : 3);
|
long depth = input.size(nchw ? 1 : 3);
|
||||||
|
|
|
@ -87,7 +87,7 @@ public class ZeroPadding1DLayer extends AbstractLayer<org.deeplearning4j.nn.conf
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Layer clone() {
|
public Layer clone() {
|
||||||
return ZeroPadding1DLayer.builder(layerConfiguration.clone(), dataType);
|
return new ZeroPadding1DLayer(layerConfiguration.clone(), dataType);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -312,6 +312,6 @@ public class SimpleRnn extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.lay
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean hasLayerNorm(){
|
public boolean hasLayerNorm(){
|
||||||
return getTypedLayerConfiguration().hasLayerNorm();
|
return getTypedLayerConfiguration().isHasLayerNorm();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -167,7 +167,7 @@ public class SimpleRnnParamInitializer extends AbstractParamInitializer {
|
||||||
|
|
||||||
protected boolean hasLayerNorm(LayerConfiguration layer){
|
protected boolean hasLayerNorm(LayerConfiguration layer){
|
||||||
if(layer instanceof SimpleRnn){
|
if(layer instanceof SimpleRnn){
|
||||||
return ((SimpleRnn) layer).hasLayerNorm();
|
return ((SimpleRnn) layer).isHasLayerNorm();
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,39 +21,74 @@
|
||||||
package org.deeplearning4j.nn.weights;
|
package org.deeplearning4j.nn.weights;
|
||||||
|
|
||||||
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
import org.deeplearning4j.nn.conf.distribution.Distribution;
|
||||||
|
import org.deeplearning4j.nn.weights.embeddings.EmbeddingInitializer;
|
||||||
|
import org.deeplearning4j.nn.weights.embeddings.WeightInitEmbedding;
|
||||||
|
|
||||||
public enum WeightInit {
|
public enum WeightInit {
|
||||||
DISTRIBUTION, ZERO, ONES, SIGMOID_UNIFORM, NORMAL, LECUN_NORMAL, UNIFORM, XAVIER, XAVIER_UNIFORM, XAVIER_FAN_IN, XAVIER_LEGACY, RELU,
|
DISTRIBUTION,
|
||||||
RELU_UNIFORM, IDENTITY, LECUN_UNIFORM, VAR_SCALING_NORMAL_FAN_IN, VAR_SCALING_NORMAL_FAN_OUT, VAR_SCALING_NORMAL_FAN_AVG,
|
ZERO,
|
||||||
VAR_SCALING_UNIFORM_FAN_IN, VAR_SCALING_UNIFORM_FAN_OUT, VAR_SCALING_UNIFORM_FAN_AVG;
|
ONES,
|
||||||
|
SIGMOID_UNIFORM,
|
||||||
|
NORMAL,
|
||||||
|
LECUN_NORMAL,
|
||||||
|
UNIFORM,
|
||||||
|
XAVIER,
|
||||||
|
XAVIER_UNIFORM,
|
||||||
|
XAVIER_FAN_IN,
|
||||||
|
XAVIER_LEGACY,
|
||||||
|
RELU,
|
||||||
|
RELU_UNIFORM,
|
||||||
|
IDENTITY,
|
||||||
|
LECUN_UNIFORM,
|
||||||
|
VAR_SCALING_NORMAL_FAN_IN,
|
||||||
|
VAR_SCALING_NORMAL_FAN_OUT,
|
||||||
|
VAR_SCALING_NORMAL_FAN_AVG,
|
||||||
|
CONSTANT,
|
||||||
|
EMBEDDING,
|
||||||
|
VAR_SCALING_UNIFORM_FAN_IN,
|
||||||
|
VAR_SCALING_UNIFORM_FAN_OUT,
|
||||||
|
VAR_SCALING_UNIFORM_FAN_AVG;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create an instance of the weight initialization function
|
||||||
|
*
|
||||||
|
* @return a new {@link IWeightInit} instance
|
||||||
|
*/
|
||||||
|
public IWeightInit getWeightInitFunction(Distribution distribution) {
|
||||||
|
switch (this) {
|
||||||
|
case DISTRIBUTION:
|
||||||
|
return new WeightInitDistribution(distribution);
|
||||||
|
default:
|
||||||
|
return getWeightInitFunction();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public IWeightInit getWeightInitFunction(
|
||||||
|
EmbeddingInitializer initializer) { // EmbeddingInitializer
|
||||||
|
switch (this) {
|
||||||
|
case EMBEDDING:
|
||||||
|
return new WeightInitEmbedding(initializer);
|
||||||
|
default:
|
||||||
|
return getWeightInitFunction();
|
||||||
|
}
|
||||||
|
}
|
||||||
/**
|
/**
|
||||||
* Create an instance of the weight initialization function
|
* Create an instance of the weight initialization function
|
||||||
*
|
*
|
||||||
* @return a new {@link IWeightInit} instance
|
* @return a new {@link IWeightInit} instance
|
||||||
*/
|
*/
|
||||||
public IWeightInit getWeightInitFunction() {
|
public IWeightInit getWeightInitFunction() {
|
||||||
return getWeightInitFunction(null);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create an instance of the weight initialization function
|
|
||||||
*
|
|
||||||
* @param distribution Distribution of the weights (Only used in case DISTRIBUTION)
|
|
||||||
* @return a new {@link IWeightInit} instance
|
|
||||||
*/
|
|
||||||
public IWeightInit getWeightInitFunction(Distribution distribution) {
|
|
||||||
switch (this) {
|
switch (this) {
|
||||||
|
case CONSTANT:
|
||||||
|
return new WeightInitConstant();
|
||||||
case ZERO:
|
case ZERO:
|
||||||
return new WeightInitConstant(0.0);
|
return new WeightInitConstant(0.0);
|
||||||
case ONES:
|
case ONES:
|
||||||
return new WeightInitConstant(1.0);
|
return new WeightInitConstant(1.0);
|
||||||
case DISTRIBUTION:
|
|
||||||
return new WeightInitDistribution(distribution);
|
|
||||||
case SIGMOID_UNIFORM:
|
case SIGMOID_UNIFORM:
|
||||||
return new WeightInitSigmoidUniform();
|
return new WeightInitSigmoidUniform();
|
||||||
case LECUN_NORMAL: //Fall through: these 3 are equivalent
|
case LECUN_NORMAL: // Fall through: these 3 are equivalent
|
||||||
case XAVIER_FAN_IN:
|
case XAVIER_FAN_IN:
|
||||||
case NORMAL:
|
case NORMAL:
|
||||||
return new WeightInitNormal();
|
return new WeightInitNormal();
|
||||||
|
@ -87,7 +122,8 @@ public enum WeightInit {
|
||||||
return new WeightInitVarScalingUniformFanAvg();
|
return new WeightInitVarScalingUniformFanAvg();
|
||||||
|
|
||||||
default:
|
default:
|
||||||
throw new UnsupportedOperationException("Unknown or not supported weight initialization function: " + this);
|
throw new UnsupportedOperationException(
|
||||||
|
"Unknown or not supported weight initialization function: " + this);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -42,4 +42,13 @@ public class WeightInitConstant implements IWeightInit {
|
||||||
paramView.assign(value);
|
paramView.assign(value);
|
||||||
return paramView.reshape(order, shape);
|
return paramView.reshape(order, shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public WeightInit enumValue() {
|
||||||
|
return WeightInit.CONSTANT;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -51,4 +51,13 @@ public class WeightInitDistribution implements IWeightInit {
|
||||||
}
|
}
|
||||||
return paramView.reshape(order, shape);
|
return paramView.reshape(order, shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public WeightInit enumValue() {
|
||||||
|
return WeightInit.DISTRIBUTION;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -38,4 +38,13 @@ public class WeightInitLecunUniform implements IWeightInit {
|
||||||
Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-b, b));
|
Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-b, b));
|
||||||
return paramView.reshape(order, shape);
|
return paramView.reshape(order, shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public WeightInit enumValue() {
|
||||||
|
return WeightInit.LECUN_UNIFORM;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -38,4 +38,13 @@ public class WeightInitNormal implements IWeightInit {
|
||||||
Nd4j.randn(paramView).divi(FastMath.sqrt(fanIn));
|
Nd4j.randn(paramView).divi(FastMath.sqrt(fanIn));
|
||||||
return paramView.reshape(order, shape);
|
return paramView.reshape(order, shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public WeightInit enumValue() {
|
||||||
|
return WeightInit.NORMAL;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,4 +33,13 @@ public class WeightInitRelu implements IWeightInit {
|
||||||
Nd4j.randn(paramView).muli(FastMath.sqrt(2.0 / fanIn)); //N(0, 2/nIn)
|
Nd4j.randn(paramView).muli(FastMath.sqrt(2.0 / fanIn)); //N(0, 2/nIn)
|
||||||
return paramView.reshape(order, shape);
|
return paramView.reshape(order, shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public WeightInit enumValue() {
|
||||||
|
return WeightInit.RELU;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,4 +35,13 @@ public class WeightInitReluUniform implements IWeightInit {
|
||||||
Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-u, u)); //U(-sqrt(6/fanIn), sqrt(6/fanIn)
|
Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-u, u)); //U(-sqrt(6/fanIn), sqrt(6/fanIn)
|
||||||
return paramView.reshape(order, shape);
|
return paramView.reshape(order, shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public WeightInit enumValue() {
|
||||||
|
return WeightInit.RELU_UNIFORM;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,4 +35,13 @@ public class WeightInitSigmoidUniform implements IWeightInit {
|
||||||
Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-r, r));
|
Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-r, r));
|
||||||
return paramView.reshape(order, shape);
|
return paramView.reshape(order, shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public WeightInit enumValue() {
|
||||||
|
return WeightInit.SIGMOID_UNIFORM;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -34,4 +34,13 @@ public class WeightInitUniform implements IWeightInit {
|
||||||
Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-a, a));
|
Nd4j.rand(paramView, Nd4j.getDistributions().createUniform(-a, a));
|
||||||
return paramView.reshape(order, shape);
|
return paramView.reshape(order, shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public WeightInit enumValue() {
|
||||||
|
return WeightInit.UNIFORM;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -50,4 +50,13 @@ public class WeightInitVarScalingNormalFanAvg implements IWeightInit {
|
||||||
Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, std));
|
Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, std));
|
||||||
return paramView.reshape(order, shape);
|
return paramView.reshape(order, shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public WeightInit enumValue() {
|
||||||
|
return WeightInit.VAR_SCALING_NORMAL_FAN_AVG;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -48,4 +48,13 @@ public class WeightInitVarScalingNormalFanIn implements IWeightInit {
|
||||||
Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, std));
|
Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, std));
|
||||||
return paramView.reshape(order, shape);
|
return paramView.reshape(order, shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public WeightInit enumValue() {
|
||||||
|
return WeightInit.VAR_SCALING_NORMAL_FAN_IN;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -50,4 +50,13 @@ public class WeightInitVarScalingNormalFanOut implements IWeightInit {
|
||||||
Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, std));
|
Nd4j.exec(new TruncatedNormalDistribution(paramView, 0.0, std));
|
||||||
return paramView.reshape(order, shape);
|
return paramView.reshape(order, shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public WeightInit enumValue() {
|
||||||
|
return WeightInit.VAR_SCALING_NORMAL_FAN_OUT;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue