Using @SuperBuilder for LayerConfigurations

Signed-off-by: brian <brian@brutex.de>

Branch: master
Author: Brian Rosenberger, 2023-04-25 16:44:47 +02:00
Parent: 8f524827e4
Commit: 3267b06bde
62 changed files with 122 additions and 108 deletions
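
The refactoring pattern is the same across most of the touched files: wrapper layer configurations such as LastTimeStep, TimeDistributed, FrozenLayer and FrozenLayerWithBackprop are no longer instantiated through constructors but through their Lombok @SuperBuilder builders. A minimal before/after sketch, condensed from the test diffs below (layerSize stands in for whatever size the surrounding test uses):

    // Before: wrapper layer configurations were created with constructors
    LayerConfiguration before = new LastTimeStep(
            SimpleRnn.builder().nOut(layerSize).build());

    // After: the same wrapper is assembled through its @SuperBuilder builder
    LayerConfiguration after = LastTimeStep.builder()
            .underlying(SimpleRnn.builder().nOut(layerSize).build())
            .build();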

View File

@@ -267,11 +267,11 @@ public class RnnGradientChecks extends BaseDL4JTest {
                 .activation(Activation.TANH)
                 .updater(new NoOp())
                 .weightInit(WeightInit.XAVIER)
-                .list()
                 .layer(simple ? SimpleRnn.builder().nOut(layerSize).hasLayerNorm(hasLayerNorm).build() :
                         LSTM.builder().nOut(layerSize).build())
-                .layer(new LastTimeStep(simple ? SimpleRnn.builder().nOut(layerSize).hasLayerNorm(hasLayerNorm).build() :
-                        LSTM.builder().nOut(layerSize).build()))
+                .layer(LastTimeStep.builder().underlying(simple ? SimpleRnn.builder().nOut(layerSize).hasLayerNorm(hasLayerNorm).build() :
+                        LSTM.builder().nOut(layerSize).build()).build())
                 .layer(OutputLayer.builder().nOut(nOut).activation(Activation.SOFTMAX)
                         .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                 .inputType(InputType.recurrent(nIn))
@@ -335,7 +335,7 @@ public class RnnGradientChecks extends BaseDL4JTest {
                 .weightInit(WeightInit.XAVIER)
                 .list()
                 .layer(LSTM.builder().nOut(layerSize).build())
-                .layer(new TimeDistributed(DenseLayer.builder().nOut(layerSize).activation(Activation.SOFTMAX).build()))
+                .layer(TimeDistributed.builder().underlying(DenseLayer.builder().nOut(layerSize).activation(Activation.SOFTMAX).build()).build())
                 .layer(RnnOutputLayer.builder().nOut(nOut).activation(Activation.SOFTMAX)
                         .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                 .inputType(InputType.recurrent(nIn))

View File

@@ -482,7 +482,7 @@ public class DTypeTests extends BaseDL4JTest {
                     break;
                 case 1:
                     ol = LossLayer.builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT.getILossFunction()).build();
-                    secondLast = new FrozenLayerWithBackprop(DenseLayer.builder().nOut(10).activation(Activation.SIGMOID).build());
+                    secondLast = FrozenLayerWithBackprop.builder().underlying(DenseLayer.builder().nOut(10).activation(Activation.SIGMOID).build()).build();
                     break;
                 case 2:
                     ol =CenterLossOutputLayer.builder().nOut(10).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build();
@@ -889,7 +889,7 @@ public class DTypeTests extends BaseDL4JTest {
                     break;
                 case 2:
                     ol = OutputLayer.builder().nOut(5).build();
-                    secondLast = new LastTimeStep(SimpleRnn.builder().nOut(5).activation(Activation.TANH).build());
+                    secondLast = LastTimeStep.builder().underlying(SimpleRnn.builder().nOut(5).activation(Activation.TANH).build()).build();
                     break;
                 default:
                     throw new RuntimeException();
@@ -905,7 +905,7 @@ public class DTypeTests extends BaseDL4JTest {
                     .layer(DenseLayer.builder().nOut(5).build())
                     .layer(GravesBidirectionalLSTM.builder().nIn(5).nOut(5).activation(Activation.TANH).build())
                     .layer(Bidirectional.builder(LSTM.builder().nIn(5).nOut(5).activation(Activation.TANH).build()).build())
-                    .layer(new TimeDistributed(DenseLayer.builder().nIn(10).nOut(5).activation(Activation.TANH).build()))
+                    .layer(TimeDistributed.builder().underlying(DenseLayer.builder().nIn(10).nOut(5).activation(Activation.TANH).build()).build())
                     .layer(SimpleRnn.builder().nIn(5).nOut(5).build())
                     .layer(MaskZeroLayer.builder().underlying(SimpleRnn.builder().nIn(5).nOut(5).build()).maskingValue(0.0).build())
                     .layer(secondLast)
@@ -1062,7 +1062,7 @@ public class DTypeTests extends BaseDL4JTest {
             INDArray input;
             if (test == 0) {
                 if (frozen) {
-                    conf.layer("0", new FrozenLayer(EmbeddingLayer.builder().nIn(5).nOut(5).build()), "in");
+                    conf.layer("0", FrozenLayer.builder(EmbeddingLayer.builder().nIn(5).nOut(5).build()).build(), "in");
                 } else {
                     conf.layer("0", EmbeddingLayer.builder().nIn(5).nOut(5).build(), "in");
                 }
@@ -1071,7 +1071,7 @@ public class DTypeTests extends BaseDL4JTest {
                 conf.setInputTypes(InputType.feedForward(1));
             } else if (test == 1) {
                 if (frozen) {
-                    conf.layer("0", new FrozenLayer(EmbeddingSequenceLayer.builder().nIn(5).nOut(5).build()), "in");
+                    conf.layer("0", FrozenLayer.builder(EmbeddingSequenceLayer.builder().nIn(5).nOut(5).build()).build(), "in");
                 } else {
                     conf.layer("0", EmbeddingSequenceLayer.builder().nIn(5).nOut(5).build(), "in");
                 }

View File

@@ -1925,7 +1925,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
                 .setOutputs("output")
                 .addLayer("0", ConvolutionLayer.builder().nOut(5).convolutionMode(ConvolutionMode.Same).build(),"input" )
                 .addVertex("dummyAdd", new ElementWiseVertex(ElementWiseVertex.Op.Add), "0")
-                .addLayer("output", new CnnLossLayer(), "dummyAdd")
+                .addLayer("output", CnnLossLayer.builder(), "dummyAdd")
                 .build());
         graph.init();
         graph.outputSingle(Nd4j.randn(1, 2, 10, 10));

View File

@@ -289,11 +289,11 @@ public class FrozenLayerTest extends BaseDL4JTest {
                 .build();
         NeuralNetConfiguration conf2 = NeuralNetConfiguration.builder().seed(12345).list().layer(0,
-                new org.deeplearning4j.nn.conf.layers.misc.FrozenLayer(DenseLayer.builder().nIn(10).nOut(10)
-                        .activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()))
-                .layer(1, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayer(
+                org.deeplearning4j.nn.conf.layers.misc.FrozenLayer.builder(DenseLayer.builder().nIn(10).nOut(10)
+                        .activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()).build())
+                .layer(1, org.deeplearning4j.nn.conf.layers.misc.FrozenLayer.builder(
                         DenseLayer.builder().nIn(10).nOut(10).activation(Activation.TANH)
-                                .weightInit(WeightInit.XAVIER).build()))
+                                .weightInit(WeightInit.XAVIER).build()).build())
                 .layer(2, org.deeplearning4j.nn.conf.layers.OutputLayer.builder().lossFunction(
                         LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10)
                         .nOut(10).build())

View File

@@ -60,11 +60,11 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
                 .build();
         NeuralNetConfiguration conf2 = NeuralNetConfiguration.builder().seed(12345).list().layer(0,
-                new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(DenseLayer.builder().nIn(10).nOut(10)
-                        .activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()))
-                .layer(1, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
+                org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop.builder(DenseLayer.builder().nIn(10).nOut(10)
+                        .activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()).build())
+                .layer(1, org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop.builder(
                         DenseLayer.builder().nIn(10).nOut(10).activation(Activation.TANH)
-                                .weightInit(WeightInit.XAVIER).build()))
+                                .weightInit(WeightInit.XAVIER).build()).build())
                 .layer(2, OutputLayer.builder(
                         LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10)
                         .nOut(10).build())
@@ -113,10 +113,10 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
         ComputationGraphConfiguration conf2 = NeuralNetConfiguration.builder().seed(12345).graphBuilder()
                 .addInputs("in")
-                .addLayer("0", new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
+                .addLayer("0", org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop.builder(
                         DenseLayer.builder().nIn(10).nOut(10).activation(Activation.TANH)
                                 .weightInit(WeightInit.XAVIER).build()), "in")
-                .addLayer("1", new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
+                .addLayer("1", org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop.builder(
                         DenseLayer.builder().nIn(10).nOut(10).activation(Activation.TANH)
                                 .weightInit(WeightInit.XAVIER).build()), "0")
                 .addLayer("2", OutputLayer.builder(
@@ -160,11 +160,11 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
                 .updater(new Sgd(2))
                 .list()
                 .layer(DenseLayer.builder().nIn(4).nOut(3).build())
-                .layer(new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
+                .layer(org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop.builder(
                         DenseLayer.builder().nIn(3).nOut(4).build()))
-                .layer(new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
+                .layer(org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop.builder(
                         DenseLayer.builder().nIn(4).nOut(2).build()))
-                .layer(new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
+                .layer(org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop.builder(
                         OutputLayer.builder().lossFunction(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(2).nOut(1).build()))
                 .build();
@@ -213,15 +213,15 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
                 .addInputs("input")
                 .addLayer(initialLayer, DenseLayer.builder().nIn(4).nOut(4).build(),"input")
                 .addLayer(frozenBranchUnfrozenLayer0, DenseLayer.builder().nIn(4).nOut(3).build(),initialLayer)
-                .addLayer(frozenBranchFrozenLayer1, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
+                .addLayer(frozenBranchFrozenLayer1, org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop.builder(
                         DenseLayer.builder().nIn(3).nOut(4).build()),frozenBranchUnfrozenLayer0)
-                .addLayer(frozenBranchFrozenLayer2, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
+                .addLayer(frozenBranchFrozenLayer2, org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop.builder(
                         DenseLayer.builder().nIn(4).nOut(2).build()),frozenBranchFrozenLayer1)
                 .addLayer(unfrozenLayer0, DenseLayer.builder().nIn(4).nOut(4).build(),initialLayer)
                 .addLayer(unfrozenLayer1, DenseLayer.builder().nIn(4).nOut(2).build(),unfrozenLayer0)
                 .addLayer(unfrozenBranch2, DenseLayer.builder().nIn(2).nOut(1).build(),unfrozenLayer1)
                 .addVertex("merge", new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2)
-                .addLayer(frozenBranchOutput,new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
+                .addLayer(frozenBranchOutput,org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop.builder(
                         OutputLayer.builder().lossFunction(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(3).nOut(1).build()),"merge")
                 .setOutputs(frozenBranchOutput)
                 .build();
@@ -269,9 +269,9 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
                 .updater(new Sgd(2))
                 .list()
                 .layer(0,DenseLayer.builder().nIn(4).nOut(3).build())
-                .layer(1,new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(DenseLayer.builder().nIn(3).nOut(4).build()))
-                .layer(2,new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(DenseLayer.builder().nIn(4).nOut(2).build()))
-                .layer(3,new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(OutputLayer.builder().lossFunction(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(2).nOut(1).build()))
+                .layer(1,org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop.builder(DenseLayer.builder().nIn(3).nOut(4).build()))
+                .layer(2,org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop.builder(DenseLayer.builder().nIn(4).nOut(2).build()))
+                .layer(3,org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop.builder(OutputLayer.builder().lossFunction(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(2).nOut(1).build()))
                 .build();
         MultiLayerNetwork frozenNetwork = new MultiLayerNetwork(confFrozen);
         frozenNetwork.init();
@@ -327,16 +327,16 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
                 .addInputs("input")
                 .addLayer(initialLayer,DenseLayer.builder().nIn(4).nOut(4).build(),"input")
                 .addLayer(frozenBranchUnfrozenLayer0,DenseLayer.builder().nIn(4).nOut(3).build(), initialLayer)
-                .addLayer(frozenBranchFrozenLayer1,new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
+                .addLayer(frozenBranchFrozenLayer1,org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop.builder(
                         DenseLayer.builder().nIn(3).nOut(4).build()),frozenBranchUnfrozenLayer0)
                 .addLayer(frozenBranchFrozenLayer2,
-                        new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
+                        org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop.builder(
                                 DenseLayer.builder().nIn(4).nOut(2).build()),frozenBranchFrozenLayer1)
                 .addLayer(unfrozenLayer0,DenseLayer.builder().nIn(4).nOut(4).build(),initialLayer)
                 .addLayer(unfrozenLayer1,DenseLayer.builder().nIn(4).nOut(2).build(),unfrozenLayer0)
                 .addLayer(unfrozenBranch2,DenseLayer.builder().nIn(2).nOut(1).build(),unfrozenLayer1)
                 .addVertex("merge",new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2)
-                .addLayer(frozenBranchOutput, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
+                .addLayer(frozenBranchOutput, org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop.builder(
                         OutputLayer.builder().lossFunction(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(3).nOut(1).build()),"merge")
                 .setOutputs(frozenBranchOutput)
                 .build();

View File

@@ -243,7 +243,7 @@ public class RnnDataFormatTests extends BaseDL4JTest {
             layer = MaskZeroLayer.builder().maskingValue(0.).underlying(layer).build();
         }
         if(lastTimeStep){
-            layer = new LastTimeStep(layer);
+            layer = LastTimeStep.builder(layer);
         }
         NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = (NeuralNetConfiguration.NeuralNetConfigurationBuilder) NeuralNetConfiguration.builder()
                 .seed(12345)

View File

@@ -63,7 +63,7 @@ public class TestLastTimeStepLayer extends BaseDL4JTest {
     public void testLastTimeStepVertex() {
         ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().graphBuilder().addInputs("in")
-                .addLayer("lastTS", new LastTimeStep(SimpleRnn.builder()
+                .addLayer("lastTS", LastTimeStep.builder(SimpleRnn.builder()
                         .nIn(5).nOut(6).dataFormat(rnnDataFormat).build()), "in")
                 .setOutputs("lastTS")
                 .build();
@@ -134,7 +134,7 @@ public class TestLastTimeStepLayer extends BaseDL4JTest {
                 .graphBuilder()
                 .addInputs("in")
                 .setInputTypes(InputType.recurrent(1, rnnDataFormat))
-                .addLayer("RNN", new LastTimeStep(LSTM.builder()
+                .addLayer("RNN", LastTimeStep.builder(LSTM.builder()
                         .nOut(10).dataFormat(rnnDataFormat)
                         .build()), "in")
                 .addLayer("dense", DenseLayer.builder()

View File

@@ -79,9 +79,8 @@ public class TestTimeDistributed extends BaseDL4JTest {
                 .inferenceWorkspaceMode(wsm)
                 .seed(12345)
                 .updater(new Adam(0.1))
-                .list()
                 .layer(LSTM.builder().nIn(3).nOut(3).dataFormat(rnnDataFormat).build())
-                .layer(new TimeDistributed(DenseLayer.builder().nIn(3).nOut(3).activation(Activation.TANH).build(), rnnDataFormat))
+                .layer(TimeDistributed.builder().underlying(DenseLayer.builder().nIn(3).nOut(3).activation(Activation.TANH).build()).rnnDataFormat(rnnDataFormat))
                 .layer(RnnOutputLayer.builder().nIn(3).nOut(3).activation(Activation.SOFTMAX).dataFormat(rnnDataFormat)
                         .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                 .inputType(InputType.recurrent(3, rnnDataFormat))

View File

@@ -314,7 +314,7 @@ public class TestMasking extends BaseDL4JTest {
                 )
                 .addInputs("m1", "m2")
                 .addVertex("stack", new StackVertex(), "m1", "m2")
-                .addLayer("lastUnStacked", new LastTimeStep(LSTM.builder().nIn(3).nOut(1).activation(Activation.TANH).build()), "stack")
+                .addLayer("lastUnStacked", LastTimeStep.builder(LSTM.builder().nIn(3).nOut(1).activation(Activation.TANH).build()), "stack")
                 .addVertex("unstacked1", new UnstackVertex(0, 2), "lastUnStacked")
                 .addVertex("unstacked2", new UnstackVertex(1, 2), "lastUnStacked")
                 .addVertex("restacked", new StackVertex(), "unstacked1", "unstacked2")

View File

@@ -336,12 +336,12 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest {
                 new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In")
                         .setInputTypes(InputType.convolutionalFlat(28,28, 3))
                         .addLayer("layer0",
-                                new FrozenLayer(ConvolutionLayer.builder(5, 5).nIn(3)
+                                FrozenLayer.builder(ConvolutionLayer.builder(5, 5).nIn(3)
                                         .stride(1, 1).nOut(20)
                                         .activation(Activation.IDENTITY).build()),
                                 "layer0In")
                         .addLayer("layer1",
-                                new FrozenLayer(SubsamplingLayer.builder(
+                                FrozenLayer.builder(SubsamplingLayer.builder(
                                         SubsamplingLayer.PoolingType.MAX)
                                         .kernelSize(2, 2).stride(2, 2)
                                         .build()),
@@ -430,11 +430,11 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest {
                 .weightInit(WeightInit.XAVIER)
                 .graphBuilder().addInputs("in")
                 .addLayer("blstm1",
-                        new FrozenLayer(GravesBidirectionalLSTM.builder().nIn(10).nOut(10)
+                        FrozenLayer.builder(GravesBidirectionalLSTM.builder().nIn(10).nOut(10)
                                 .activation(Activation.TANH).build()),
                         "in")
-                .addLayer("pool", new FrozenLayer(GlobalPoolingLayer.builder().build()), "blstm1")
-                .addLayer("dense", new FrozenLayer(DenseLayer.builder().nIn(10).nOut(10).build()), "pool")
+                .addLayer("pool", FrozenLayer.builder(GlobalPoolingLayer.builder().build()), "blstm1")
+                .addLayer("dense", FrozenLayer.builder(DenseLayer.builder().nIn(10).nOut(10).build()), "pool")
                 .addLayer("out", OutputLayer.builder().nIn(10).nOut(5).activation(Activation.SOFTMAX)
                         .updater(new Adam(0.1))
                         .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "dense")

View File

@@ -203,7 +203,7 @@ public class KerasLSTM extends KerasLayer {
         this.layer = builder.build();
         if (!returnSequences) {
-            this.layer = new LastTimeStep(this.layer);
+            this.layer = LastTimeStep.builder(this.layer);
         }
         if (maskingConfig.getFirst()) {
             this.layer = new MaskZeroLayer(this.layer, maskingConfig.getSecond());

View File

@@ -174,7 +174,7 @@ public class KerasSimpleRnn extends KerasLayer {
         this.layer = builder.build();
         if (!returnSequences) {
-            this.layer = new LastTimeStep(this.layer);
+            this.layer = LastTimeStep.builder(this.layer);
         }
         if (maskingConfig.getFirst()) {
             this.layer = new MaskZeroLayer(this.layer, maskingConfig.getSecond());

View File

@@ -819,6 +819,9 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable {
         public GraphBuilder addLayer(String layerName, LayerConfiguration layer, String... layerInputs) {
             return addLayer(layerName, layer, null, layerInputs);
         }
+        public GraphBuilder addLayer(String layerName, LayerConfiguration.LayerConfigurationBuilder<?,?> layer, String... layerInputs) {
+            return addLayer(layerName, layer.build(), null, layerInputs);
+        }
         /**
          * Add a layer, with no {@link InputPreProcessor}, with the specified name
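
A sketch of what the new overload allows in graph definitions; the layer name "dense0" and the builder settings are illustrative, only the overload itself comes from the hunk above:

    // The un-built layer builder is passed directly; addLayer() calls build() internally.
    graphBuilder.addLayer("dense0",
            DenseLayer.builder().nIn(10).nOut(10),   // no trailing .build() required
            "in");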

View File

@@ -661,7 +661,17 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
     public B layer(Integer index, @NonNull LayerConfiguration layer) {
       innerConfigurations$value.add(index, layer);
       innerConfigurations$set = true;
-      return (B) this;
+      return self();
     }
+
+    /**
+     * Set layer at index
+     *
+     * @param index where to insert
+     * @param layer the layer
+     * @return builder
+     */
+    public B layer(Integer index, @NonNull LayerConfiguration.LayerConfigurationBuilder<?,?> layer) {
+      return this.layer(index, layer.build());
+    }

     /**
@@ -675,6 +685,9 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor
       innerConfigurations$set = true;
       return (B) this;
     }
+    public B layer(@NonNull LayerConfiguration.LayerConfigurationBuilder<?, ?> layer) {
+      return this.layer(layer.build());
+    }

     // TODO this is a dirty workaround
     public boolean isOverrideNinUponBuild() {
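
With the two overloads above, list-style configurations can also accept un-built layer builders. An illustrative sketch (the layer sizes are arbitrary):

    NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
            // layer(LayerConfigurationBuilder): appended, built internally
            .layer(DenseLayer.builder().nIn(4).nOut(3))
            // layer(index, LayerConfigurationBuilder): inserted at the given position
            .layer(1, DenseLayer.builder().nIn(3).nOut(2))
            .build();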

View File

@@ -212,6 +212,16 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration
       C extends BaseLayerConfiguration, B extends BaseLayerConfigurationBuilder<C, B>>
       extends LayerConfigurationBuilder<C, B> {
+    public B updater(Updater upd) {
+      this.updater = upd.getIUpdaterWithDefaultConfig();
+      return self();
+    }
+
+    public B updater(IUpdater upd) {
+      this.updater = upd;
+      return self();
+    }
     /**
      * Set weight initialization scheme to random sampling via the specified distribution.
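
A sketch of the two updater overloads added above; the Adam updater appears elsewhere in this commit, while the Updater enum constant is an assumption based on standard DL4J usage:

    DenseLayer a = DenseLayer.builder().nIn(4).nOut(3)
            .updater(new Adam(0.1))     // updater(IUpdater): stored as-is
            .build();
    DenseLayer b = DenseLayer.builder().nIn(4).nOut(3)
            .updater(Updater.ADAM)      // updater(Updater): resolved via getIUpdaterWithDefaultConfig()
            .build();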

View File

@@ -38,7 +38,7 @@ import org.nd4j.linalg.api.memory.MemoryWorkspace;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
-@Data
 @EqualsAndHashCode(callSuper = true)
 @SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder")
 public class CapsuleLayer extends SameDiffLayer {
@@ -50,33 +50,33 @@ public class CapsuleLayer extends SameDiffLayer {
    * @param hasBias
    * @return
    */
-  @Builder.Default private boolean hasBias = false;
+  @Builder.Default @Getter @Setter private boolean hasBias = false;
   /**
    * Usually inferred automatically.
    * @param inputCapsules
    * @return
    */
-  @Builder.Default private long inputCapsules = 0;
+  @Builder.Default @Getter @Setter private long inputCapsules = 0;
   /**
    * Usually inferred automatically.
    * @param inputCapsuleDimensions
    * @return
    */
-  @Builder.Default private long inputCapsuleDimensions = 0;
+  @Builder.Default @Getter @Setter private long inputCapsuleDimensions = 0;
   /**
    * Set the number of capsules to use.
    * @param capsules
    * @return
    */
-  private int capsules;
-  private int capsuleDimensions;
+  @Getter @Setter private int capsules;
+  @Getter @Setter private int capsuleDimensions;
   /**
    * Set the number of dynamic routing iterations to use.
    * The default is 3 (recommendedded in Dynamic Routing Between Capsules)
    * @param routings
    * @return
    */
-  @Builder.Default private int routings = 3;
+  @Builder.Default @Getter @Setter private int routings = 3;
   @Override
   public void setNIn(InputType inputType, boolean override) {

View File

@@ -32,7 +32,6 @@ import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 @Data
-@NoArgsConstructor
 @EqualsAndHashCode(callSuper = true)
 @SuperBuilder
 public class CapsuleStrengthLayer extends SameDiffLambdaLayer {

View File

@@ -39,7 +39,6 @@ import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.lossfunctions.ILossFunction;
 @Data
-@NoArgsConstructor
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
 @SuperBuilder

View File

@@ -46,7 +46,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
  * to be used in the net or in other words the channels The builder specifies the filter/kernel
  * size, the stride and padding The pooling layer takes the kernel size
  */
-@Data
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
 @SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder")
@@ -72,7 +72,7 @@ public class ConvolutionLayer extends FeedForwardLayer {
    *
    * @param format Format for activations (in and out)
    */
-  @Builder.Default
+  @Builder.Default @Getter @Setter
   private CNN2DFormat convFormat =
       CNN2DFormat.NCHW; // default value for legacy serialization reasons
@@ -86,24 +86,29 @@ public class ConvolutionLayer extends FeedForwardLayer {
    * http://deeplearning.net/software/theano/tutorial/conv_arithmetic.html#dilated-convolutions</a>
    * <br>
    */
+  @Getter @Setter
   private @Builder.Default int[] dilation = new int[] {1, 1};
   /** Default is 2. Down-sample by a factor of 2 */
+  @Getter @Setter
   private @Builder.Default int[] stride = new int[] {1, 1};
+  @Getter @Setter
   private @Builder.Default int[] padding = new int[] {0, 0};
   /**
    * When using CuDNN and an error is encountered, should fallback to the non-CuDNN implementatation
    * be allowed? If set to false, an exception in CuDNN will be propagated back to the user. If
    * false, the built-in (non-CuDNN) implementation for ConvolutionLayer will be used
    */
+  @Getter
   @Builder.Default private boolean cudnnAllowFallback = true;
   /** Defaults to "PREFER_FASTEST", but "NO_WORKSPACE" uses less memory. */
+  @Getter
   @Builder.Default private AlgoMode cudnnAlgoMode = AlgoMode.PREFER_FASTEST;
   private FwdAlgo cudnnFwdAlgo;
   private BwdFilterAlgo cudnnBwdFilterAlgo;
   private BwdDataAlgo cudnnBwdDataAlgo;
+  @Getter @Setter
   @Builder.Default private int convolutionDim = 2; // 2D convolution by default
   /** Causal convolution - allowed for 1D only */
   @Builder.Default private boolean allowCausal = false;

View File

@@ -45,7 +45,6 @@ import java.util.Collection;
 import java.util.Map;
 @Data
-@NoArgsConstructor
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
 @SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder")

View File

@@ -44,7 +44,6 @@ import java.util.Collection;
 import java.util.Map;
 @Data
-@NoArgsConstructor
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
 @SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder")

View File

@@ -39,7 +39,6 @@ import org.nd4j.linalg.api.ndarray.INDArray;
 import java.util.*;
 @Data
-@NoArgsConstructor
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
 @Deprecated

View File

@@ -41,7 +41,6 @@ import java.util.Collections;
 import java.util.Map;
 @Data
-@NoArgsConstructor
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
 @SuperBuilder(buildMethodName = "initBuild")

View File

@@ -48,7 +48,7 @@ import org.nd4j.linalg.learning.config.IUpdater;
 import org.nd4j.linalg.learning.regularization.Regularization;
 /** A neural network layer. */
-@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
+//@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
 @EqualsAndHashCode
 // @JsonIdentityInfo(generator= ObjectIdGenerators.IntSequenceGenerator.class, property="@id")
 @Slf4j

View File

@@ -40,7 +40,6 @@ import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.learning.regularization.Regularization;
 @Data
-@NoArgsConstructor
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
 @SuperBuilder(builderMethodName = "innerBuilder")

View File

@@ -47,7 +47,6 @@ import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
 @Data
-@NoArgsConstructor
 @EqualsAndHashCode(callSuper = true)
 @JsonIgnoreProperties({"paramShapes"})
 @SuperBuilder(buildMethodName = "initBuild")

View File

@@ -40,7 +40,6 @@ import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.lossfunctions.ILossFunction;
 @Data
-@NoArgsConstructor
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
 @SuperBuilder

View File

@@ -37,7 +37,6 @@ import org.nd4j.linalg.lossfunctions.ILossFunction;
 import org.nd4j.linalg.lossfunctions.LossFunctions;
 @Data
-@NoArgsConstructor
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
 @SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder")

View File

@@ -38,7 +38,6 @@ import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 @Data
-@NoArgsConstructor
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
 @SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder")

View File

@@ -32,7 +32,6 @@ import lombok.ToString;
  * @author Max Pumperla
  */
 @Data
-@NoArgsConstructor
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
 public class Pooling1D extends Subsampling1DLayer {

View File

@@ -32,7 +32,6 @@ import lombok.ToString;
  * @author Max Pumperla
 */
 @Data
-@NoArgsConstructor
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
 public class Pooling2D extends SubsamplingLayer {

View File

@@ -40,7 +40,6 @@ import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
 import org.nd4j.linalg.factory.Nd4j;
 @Data
-@NoArgsConstructor
 @EqualsAndHashCode(callSuper = true)
 @SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder")
 public class PrimaryCapsules extends SameDiffLayer {

View File

@@ -42,7 +42,6 @@ import org.nd4j.linalg.factory.Nd4j;
 import java.util.Map;
 @Data
-@NoArgsConstructor
 @EqualsAndHashCode(callSuper = true)
 @SuperBuilder(buildMethodName = "initBuild")
 public class RecurrentAttentionLayer extends SameDiffLayer {

View File

@@ -40,7 +40,6 @@ import org.nd4j.linalg.lossfunctions.ILossFunction;
 import org.nd4j.linalg.lossfunctions.LossFunctions;
 @Data
-@NoArgsConstructor
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
 @SuperBuilder

View File

@@ -40,7 +40,6 @@ import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.lossfunctions.LossFunctions;
 @Data
-@NoArgsConstructor
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
 @SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder")

View File

@@ -38,7 +38,6 @@ import org.nd4j.linalg.factory.Nd4j;
 @Data
 @EqualsAndHashCode(callSuper = true)
-@NoArgsConstructor()
 @SuperBuilder(buildMethodName = "initBuild")
 public class SelfAttentionLayer extends SameDiffLayer {
     private static final String WEIGHT_KEY_QUERY_PROJECTION = "Wq";

View File

@@ -39,7 +39,6 @@ import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 @Data
-@NoArgsConstructor
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
 @SuperBuilder(builderMethodName = "innerBuilder")

View File

@@ -38,7 +38,6 @@ import java.util.Collection;
 import java.util.Map;
 @Data
-@NoArgsConstructor
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
 @SuperBuilder

View File

@@ -48,7 +48,6 @@ import org.nd4j.linalg.api.ndarray.INDArray;
  * wide.
 */
 @Data
-@NoArgsConstructor
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
 @SuperBuilder(buildMethodName = "initBuild")

View File

@@ -43,7 +43,6 @@ import org.nd4j.linalg.exception.ND4JArraySizeException;
 import org.nd4j.linalg.learning.regularization.Regularization;
 @Data
-@NoArgsConstructor
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
 @SuperBuilder(builderMethodName = "innerBuilder", buildMethodName = "initBuild")

View File

@@ -43,7 +43,6 @@ import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 @Data
-@NoArgsConstructor
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
 @SuperBuilder(buildMethodName = "initBuild", builderMethodName = "innerBuilder")

View File

@@ -41,7 +41,6 @@ import java.util.Collection;
 import java.util.Map;
 @Data
-@NoArgsConstructor
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
 @SuperBuilder(builderMethodName = "innerBuilder")

View File

@@ -39,7 +39,6 @@ import java.util.Collection;
 import java.util.Map;
 @Data
-@NoArgsConstructor
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
 @SuperBuilder(builderMethodName = "innerBuilder")

View File

@@ -37,7 +37,6 @@ import java.util.Collection;
 import java.util.Map;
 @Data
-@NoArgsConstructor
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
 @SuperBuilder(builderMethodName = "innerBuilder")

View File

@@ -38,7 +38,6 @@ import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 @Data
-@NoArgsConstructor
 @EqualsAndHashCode(callSuper = true)
 @SuperBuilder(builderMethodName = "innerBuilder")
 public class ZeroPadding1DLayer extends NoParamLayer {

View File

@@ -39,7 +39,6 @@ import java.util.Collection;
 import java.util.Map;
 @Data
-@NoArgsConstructor
 @EqualsAndHashCode(callSuper = true)
 @SuperBuilder(builderMethodName = "innerBuilder")
 public class ZeroPadding3DLayer extends NoParamLayer {

View File

@@ -40,7 +40,6 @@ import java.util.Collection;
 import java.util.Map;
 @Data
-@NoArgsConstructor
 @EqualsAndHashCode(callSuper = true)
 @SuperBuilder(builderMethodName = "innerBuilder", buildMethodName = "initBuild")
 public class ZeroPaddingLayer extends NoParamLayer {

View File

@@ -40,7 +40,6 @@ import org.nd4j.linalg.api.ndarray.INDArray;
 /** Amount of cropping to apply to both the top and the bottom of the input activations */
 @Data
-@NoArgsConstructor
 @EqualsAndHashCode(callSuper = true)
 @SuperBuilder(builderMethodName = "innerBuilder")
 public class Cropping1D extends NoParamLayer {

View File

@@ -41,7 +41,6 @@ import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 @Data
-@NoArgsConstructor
 @EqualsAndHashCode(callSuper = true)
 @SuperBuilder(builderMethodName = "innerBuilder")
 public class Cropping2D extends NoParamLayer {

View File

@@ -40,7 +40,6 @@ import java.util.Collection;
 import java.util.Map;
 @Data
-@NoArgsConstructor
 @EqualsAndHashCode(callSuper = true)
 @SuperBuilder(builderMethodName = "innerBuilder")
 public class Cropping3D extends NoParamLayer {

View File

@@ -42,7 +42,6 @@ import java.util.Map;
 @Data
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
-@NoArgsConstructor
 @SuperBuilder
 public class ElementWiseMultiplicationLayer extends org.deeplearning4j.nn.conf.layers.FeedForwardLayer {

View File

@@ -57,6 +57,9 @@ public class FrozenLayerWithBackprop extends BaseWrapperLayerConfiguration {
   public static FrozenLayerWithBackpropBuilder<?, ?> builder(LayerConfiguration innerConfiguration) {
     return innerBuilder().underlying(innerConfiguration);
   }
+  public static FrozenLayerWithBackpropBuilder<?, ?> builder(LayerConfigurationBuilder<?,?> innerConfiguration) {
+    return innerBuilder().underlying(innerConfiguration.build());
+  }
   public NeuralNetConfiguration getInnerConf(NeuralNetConfiguration conf) {
     NeuralNetConfiguration nnc = conf.clone();
     nnc.getLayerConfigurations().add(0, underlying);
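
An illustrative use of the new convenience overload; the wrapped DenseLayer settings are taken from the tests elsewhere in this commit:

    // The inner builder is built eagerly and stored as the underlying configuration.
    LayerConfiguration frozen = FrozenLayerWithBackprop
            .builder(DenseLayer.builder().nIn(10).nOut(10))   // no .build() on the inner layer
            .build();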

View File

@@ -39,7 +39,6 @@ import java.util.Collection;
 import java.util.Map;
 @Data
-@NoArgsConstructor
 @ToString(callSuper = true)
 @EqualsAndHashCode(callSuper = true)
 @SuperBuilder

View File

@@ -48,7 +48,6 @@ import java.util.Map;
 import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
-@NoArgsConstructor
 @Data
 @EqualsAndHashCode(callSuper = true, exclude = {"initializer"})
 @JsonIgnoreProperties({"initializer"})

View File

@@ -61,6 +61,7 @@ public abstract class AbstractSameDiffLayer extends LayerConfiguration {
    * @param regularization Regularization to apply for the network parameters/weights (excluding
    *     biases)
    */
+  @Getter
   protected List<Regularization> regularization;
   /**
    * The regularization for the biases only - for example {@link WeightDecay} -- SETTER -- Set the
@@ -68,6 +69,7 @@ public abstract class AbstractSameDiffLayer extends LayerConfiguration {
    *
    * @param regularizationBias Regularization to apply for the network biases only
    */
+  @Getter
   protected List<Regularization> regularizationBias;
   /**
    * Gradient updater. For example, {@link org.nd4j.linalg.learning.config.Adam} or {@link
@@ -83,10 +85,11 @@ public abstract class AbstractSameDiffLayer extends LayerConfiguration {
    * @param biasUpdater Updater to use for bias parameters
    */
   protected @Getter @Setter IUpdater biasUpdater;
+  @Getter @Setter
   protected GradientNormalization gradientNormalization;
+  @Getter @Setter
   protected double gradientNormalizationThreshold = Double.NaN;
+  @Getter @Setter
   private SDLayerParams layerParams;
   @Override

View File

@@ -28,7 +28,6 @@ import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import java.util.Map;
-@NoArgsConstructor
 @SuperBuilder
 public abstract class SameDiffLambdaLayer extends SameDiffLayer {

View File

@@ -45,7 +45,6 @@ import org.nd4j.linalg.lossfunctions.ILossFunction;
 import org.nd4j.linalg.lossfunctions.LossFunctions;
 @Data
-@NoArgsConstructor
 @EqualsAndHashCode(callSuper = true)
 @SuperBuilder
 public class VariationalAutoencoder extends BasePretrainNetwork {

View File

@@ -48,31 +48,31 @@ public class OCNNOutputLayer extends BaseOutputLayer {
    * The hidden layer size for the one class neural network. Note this would be nOut on a dense
    * layer. NOut in this neural net is always set to 1 though.
    */
-  @Builder.Default private int hiddenLayerSize; // embedded hidden layer size aka "K"
+  @Builder.Default @Getter private int hiddenLayerSize; // embedded hidden layer size aka "K"
   /** For nu definition see the paper */
-  @Builder.Default private double nu = 0.04;
+  @Builder.Default @Getter private double nu = 0.04;
   /**
    * The number of examples to use for computing the quantile for the r value update. This value
    * should generally be the same as the number of examples in the dataset
   */
-  @Builder.Default private int windowSize = 10000;
+  @Builder.Default @Getter private int windowSize = 10000;
   /**
    * The initial r value to use for ocnn for definition, see the paper, note this is only active
    * when {@link #configureR} is specified as true
   */
-  @Builder.Default private double initialRValue = 0.1;
+  @Builder.Default @Getter private double initialRValue = 0.1;
   /**
    * Whether to use the specified {@link #initialRValue} or use the weight initialization with the
    * neural network for the r value
   */
-  @Builder.Default private boolean configureR = true;
+  @Builder.Default @Getter private boolean configureR = true;
   /**
    * Psuedo code from keras: start_time = time.time() for epoch in range(100): # Train with each
    * example sess.run(updates, feed_dict={X: train_X,r:rvalue}) rvalue = nnScore(train_X, w_1, w_2,
    * g) with sess.as_default(): rvalue = rvalue.eval() rvalue = np.percentile(rvalue,q=100*nu)
    * print("Epoch = %d, r = %f" % (epoch + 1,rvalue))
   */
-  @Builder.Default private int lastEpochSinceRUpdated = 0;
+  @Builder.Default @Getter @Setter private int lastEpochSinceRUpdated = 0;
   @Override
   public Layer instantiate(

View File

@@ -21,8 +21,10 @@
 package org.deeplearning4j.nn.layers.ocnn;
+import lombok.Builder;
 import lombok.Getter;
 import lombok.Setter;
+import lombok.experimental.SuperBuilder;
 import org.deeplearning4j.nn.api.Layer;
 import org.deeplearning4j.nn.conf.layers.LayerConfiguration;
 import org.deeplearning4j.nn.gradient.DefaultGradient;
@@ -46,10 +48,12 @@ import static org.deeplearning4j.nn.layers.ocnn.OCNNParamInitializer.R_KEY;
 import static org.deeplearning4j.nn.layers.ocnn.OCNNParamInitializer.V_KEY;
 import static org.deeplearning4j.nn.layers.ocnn.OCNNParamInitializer.W_KEY;
 public class OCNNOutputLayer extends BaseOutputLayer<org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer> {
   @Setter
   @Getter
   private IActivation activation = new ActivationReLU();
   private static final IActivation relu = new ActivationReLU();

View File

@@ -21,16 +21,21 @@
 package org.deeplearning4j.nn.layers.util;
 import lombok.NoArgsConstructor;
+import lombok.experimental.SuperBuilder;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaLayer;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
-@NoArgsConstructor
+@SuperBuilder(builderMethodName = "innerBuilder")
 public class IdentityLayer extends SameDiffLambdaLayer {
-  public IdentityLayer(String name) {
-    this.name = name;
+  public static IdentityLayerBuilder<?,?> builder() {
+    return innerBuilder();
+  }
+  public static IdentityLayerBuilder<?,?> builder(String name) {
+    return innerBuilder()
+        .name(name);
   }
   @Override
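
With the constructor replaced by builder factories, an identity layer is obtained roughly as sketched here (the layer name is arbitrary):

    // Both factories delegate to the @SuperBuilder-generated innerBuilder().
    IdentityLayer unnamed = IdentityLayer.builder().build();
    IdentityLayer named = IdentityLayer.builder("identity-1").build();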

View File

@@ -235,7 +235,10 @@ chipList.each { thisChip ->
     /* Get VCVARS in case we want to build CUDA
      * MinGW64 g++ on MSYS is used otherwise */
-    if (thisChip.equals('cuda') && osdetector.os.startsWith("win") && !VISUAL_STUDIO_INSTALL_DIR.isEmpty()) {
+    if (thisChip.equals('cuda') && osdetector.os.startsWith("win")
+            && project.hasProperty("skip-native")
+            && !project.getProperty("skip-native").equals("true")
+            && !VISUAL_STUDIO_INSTALL_DIR.isEmpty()) {
         def proc = ["cmd.exe", "/c", "${VISUAL_STUDIO_VCVARS_CMD} > nul && set"].execute()
         it.environmentVariables = it.environmentVariables ?: [:]
         def lines = proc.text.split("\\r?\\n")
@@ -329,7 +332,8 @@ chipList.each { thisChip ->
     thisTask.properties = getBuildPlatform( thisChip, thisTask )
-    if(thisChip.equals('cuda') && osdetector.os.startsWith("win") && !VISUAL_STUDIO_INSTALL_DIR.isEmpty()) {
+    if(thisChip.equals('cuda') && osdetector.os.startsWith("win") && project.hasProperty("skip-native")
+            && !project.getProperty("skip-native").equals("true") && !VISUAL_STUDIO_INSTALL_DIR.isEmpty()) {
         def proc = ["cmd.exe", "/c", "${VISUAL_STUDIO_VCVARS_CMD} > nul && where.exe cl.exe"].execute()
         def outp = proc.text
         def cl = outp.replace("\\", "\\\\").trim()

View File

@@ -28,7 +28,9 @@
 ****************************************************************************/
 if (!hasProperty("VISUAL_STUDIO_INSTALL_DIR") && osdetector.os.equals("windows")) {
-    configureVisualStudio()
+    if (project.hasProperty("skip-native") && !project.getProperty("skip-native").equals("true")) {
+        configureVisualStudio()
+    }
 }
 def configureVisualStudio() {