From 483c3d7b8cb69d1d7e436b450e2c7b4218f5f2cb Mon Sep 17 00:00:00 2001
From: Alex Black
Date: Mon, 2 Mar 2020 16:15:49 +1100
Subject: [PATCH] Assorted SameDiff/DL4J fixes (#279)

* #8565 Normalizer toString/hashcode

Signed-off-by: Alex Black

* #8731 ImagePreProcessingScaler labels/segmentation fix

Signed-off-by: Alex Black

* #8691 Fix SameDiffLayer/Vertex finetuning and parameter setting support

Signed-off-by: Alex Black

* #8663 DL4J embedding layer weight init - don't depend on vocab size

Signed-off-by: Alex Black

* EmbeddingLayer test tweak

Signed-off-by: Alex Black
---
 .../embedding/EmbeddingLayerTest.java         |  76 ++++++++++++-
 .../samediff/SameDiffCustomLayerTests.java    |   6 ++
 .../testlayers/SameDiffDenseVertex.java       |   6 ++
 .../SameDiffSimpleLambdaVertex.java           |   1 +
 .../TransferLearningCompGraphTest.java        | 100 ++++++++++++++++++
 .../TransferLearningMLNTest.java              |  50 +++++++++
 .../conf/ComputationGraphConfiguration.java   |   3 +-
 .../nn/conf/graph/AttentionVertex.java        |  14 +++
 .../nn/conf/layers/EmbeddingLayer.java        |   3 +-
 .../conf/layers/EmbeddingSequenceLayer.java   |   4 +-
 .../layers/samediff/SameDiffLambdaVertex.java |  11 ++
 .../conf/layers/samediff/SameDiffVertex.java  |   7 +-
 .../nn/graph/ComputationGraph.java            |   3 +-
 .../nn/layers/samediff/SameDiffLayer.java     |  13 ++-
 .../EmbeddingLayerParamInitializer.java       |  52 +++++++++
 .../AbstractMultiDataSetNormalizer.java       |   2 -
 .../ImagePreProcessingScaler.java             |  10 +-
 .../dataset/ImagePreProcessortTest.java       |  37 +++++++
 .../nd4j/linalg/dataset/NormalizerTests.java  |  78 ++++++++++++--
 19 files changed, 449 insertions(+), 27 deletions(-)
 create mode 100644 deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/params/EmbeddingLayerParamInitializer.java
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java
index 972302d85..d53522c5d 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java
@@ -47,8 +47,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Random;
 
-import static org.junit.Assert.assertArrayEquals;
-import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.*;
 
 public class EmbeddingLayerTest extends BaseDL4JTest {
 
@@ -725,4 +724,77 @@ public class EmbeddingLayerTest extends BaseDL4JTest {
         assertEquals(new ActivationIdentity(), l2.getActivationFn());
     }
 
+
+    @Test
+    public void testEmbeddingWeightInit(){
+        // https://github.com/eclipse/deeplearning4j/issues/8663
+        //The embedding layer weight initialization should be independent of the vocabulary size (nIn setting)
+
+        for(WeightInit wi : new WeightInit[]{WeightInit.XAVIER, WeightInit.RELU, WeightInit.XAVIER_UNIFORM, WeightInit.LECUN_NORMAL, WeightInit.VAR_SCALING_NORMAL_FAN_OUT}) {
+
+            for (boolean seq : new boolean[]{false, true}) {
+
+                Nd4j.getRandom().setSeed(12345);
+                MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+                        .seed(12345)
+                        .list()
+                        .layer(seq ?
+                                new EmbeddingSequenceLayer.Builder().weightInit(wi).nIn(100).nOut(100).build() :
+                                new EmbeddingLayer.Builder().weightInit(wi).nIn(100).nOut(100).build())
+                        .build();
+                MultiLayerNetwork net = new MultiLayerNetwork(conf);
+                net.init();
+
+                Nd4j.getRandom().setSeed(12345);
+                MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()
+                        .seed(12345)
+                        .list()
+                        .layer(seq ?
+                                new EmbeddingSequenceLayer.Builder().weightInit(wi).nIn(100).nOut(100).build() :
+                                new EmbeddingLayer.Builder().weightInit(wi).nIn(100).nOut(100).build())
+                        .build();
+                MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
+                net2.init();
+
+                Nd4j.getRandom().setSeed(12345);
+                MultiLayerConfiguration conf3 = new NeuralNetConfiguration.Builder()
+                        .seed(12345)
+                        .list()
+                        .layer(seq ?
+                                new EmbeddingSequenceLayer.Builder().weightInit(wi).nIn(100000).nOut(100).build() :
+                                new EmbeddingLayer.Builder().weightInit(wi).nIn(100000).nOut(100).build())
+                        .build();
+                MultiLayerNetwork net3 = new MultiLayerNetwork(conf3);
+                net3.init();
+
+                INDArray p1 = net.params();
+                INDArray p2 = net2.params();
+                INDArray p3 = net3.params();
+                assertEquals(p1, p2);
+
+                double m1 = p1.meanNumber().doubleValue();
+                double s1 = p1.stdNumber().doubleValue();
+
+                double m3 = p3.meanNumber().doubleValue();
+                double s3 = p3.stdNumber().doubleValue();
+
+                String str = (seq ? "EmbeddingSequenceLayer" : "EmbeddingLayer") + " - " + wi;
+
+                assertEquals(str, m1, m3, 0.1);
+                assertEquals(str, s1, s3, 0.1);
+
+                double re = relErr(s1, s3);
+                assertTrue(str + " - " + re, re < 0.05);
+            }
+        }
+
+    }
+
+    public static double relErr(double d1, double d2){
+        if(d1 == 0.0 && d2 == 0.0)
+            return 0.0;
+        return Math.abs(d1 - d2) / (Math.abs(d1) + Math.abs(d2));
+    }
+
 }
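Note on the expectations above: fan-in-dependent schemes such as Xavier draw weights with variance 2 / (fanIn + fanOut). If the vocabulary size (nIn) were used as the fan-in, the embedding weights would shrink as the vocabulary grows; with the fan-in fixed at 1 (only one row is active per example), the statistics match across nIn=100 and nIn=100000, which is what the mean/std assertions check. A minimal standalone sketch of the arithmetic (values and class name are illustrative only, not part of this patch):

    public class XavierStdSketch {
        public static void main(String[] args) {
            long nOut = 100;
            for (long vocab : new long[]{100, 100_000}) {
                // Xavier: std = sqrt(2 / (fanIn + fanOut))
                double stdVocabFanIn = Math.sqrt(2.0 / (vocab + nOut)); // old behaviour: depends on vocab size
                double stdUnitFanIn = Math.sqrt(2.0 / (1 + nOut));      // new behaviour: fanIn = 1, vocab-independent
                System.out.printf("vocab=%d: vocab-fanIn std=%.5f, unit-fanIn std=%.5f%n",
                        vocab, stdVocabFanIn, stdUnitFanIn);
            }
        }
    }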
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/SameDiffCustomLayerTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/SameDiffCustomLayerTests.java
index e672c3b0d..034c3e03b 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/SameDiffCustomLayerTests.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/SameDiffCustomLayerTests.java
@@ -21,6 +21,7 @@ import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
 import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.graph.GraphVertex;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
 import org.deeplearning4j.nn.conf.layers.OutputLayer;
@@ -136,6 +137,11 @@ public class SameDiffCustomLayerTests extends BaseDL4JTest {
     }
 
     private class ValidatingSameDiffVertex extends SameDiffVertex {
+        @Override
+        public GraphVertex clone() {
+            return new ValidatingSameDiffVertex();
+        }
+
         @Override
         public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException {
             return vertexInputs[0];
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffDenseVertex.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffDenseVertex.java
index 3e3631d5b..481816472 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffDenseVertex.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffDenseVertex.java
@@ -18,6 +18,7 @@ package org.deeplearning4j.nn.layers.samediff.testlayers;
 
 import lombok.Data;
 import lombok.NoArgsConstructor;
+import org.deeplearning4j.nn.conf.graph.GraphVertex;
 import org.deeplearning4j.nn.conf.layers.samediff.SDVertexParams;
 import org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex;
 import org.deeplearning4j.nn.params.DefaultParamInitializer;
@@ -74,4 +75,9 @@ public class SameDiffDenseVertex extends SameDiffVertex {
     public char paramReshapeOrder(String paramName){
         return 'f';     //To match DL4J DenseLayer - for easy comparison
     }
+
+    @Override
+    public GraphVertex clone() {
+        return new SameDiffDenseVertex(nIn, nOut, activation, weightInit);
+    }
 }
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffSimpleLambdaVertex.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffSimpleLambdaVertex.java
index 98894c882..d9513cf80 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffSimpleLambdaVertex.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffSimpleLambdaVertex.java
@@ -16,6 +16,7 @@
 
 package org.deeplearning4j.nn.layers.samediff.testlayers;
 
+import org.deeplearning4j.nn.conf.graph.GraphVertex;
 import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaVertex;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
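The clone() overrides above follow from the SameDiffVertex change later in this patch: the base class no longer supplies a throwing clone(), so each concrete vertex returns a fresh configuration instance with its fields copied. A minimal sketch of the pattern for a hypothetical custom vertex (MyCustomVertex and its nIn/nOut fields are illustrative only; the remaining SameDiffVertex methods are omitted):

    public class MyCustomVertex extends SameDiffVertex {
        private long nIn;
        private long nOut;

        public MyCustomVertex(long nIn, long nOut) {
            this.nIn = nIn;
            this.nOut = nOut;
        }

        @Override
        public GraphVertex clone() {
            // Copy every configuration field so the clone is usable
            // independently of the original instance
            return new MyCustomVertex(nIn, nOut);
        }

        // defineVertex(...), defineParametersAndInputs(...), etc. as usual
    }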
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningCompGraphTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningCompGraphTest.java
index 094ce531d..93341dbe2 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningCompGraphTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningCompGraphTest.java
@@ -24,6 +24,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.constraint.UnitNormConstraint;
 import org.deeplearning4j.nn.conf.distribution.ConstantDistribution;
 import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
+import org.deeplearning4j.nn.conf.graph.AttentionVertex;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.layers.*;
 import org.deeplearning4j.nn.conf.layers.misc.FrozenLayer;
@@ -35,6 +36,7 @@ import org.deeplearning4j.nn.weights.WeightInitDistribution;
 import org.deeplearning4j.nn.weights.WeightInitXavier;
 import org.junit.Test;
 import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.DataSet;
 import org.nd4j.linalg.factory.Nd4j;
@@ -44,6 +46,9 @@ import org.nd4j.linalg.learning.config.RmsProp;
 import org.nd4j.linalg.learning.config.Sgd;
 import org.nd4j.linalg.lossfunctions.LossFunctions;
 
+import java.util.HashMap;
+import java.util.Map;
+
 import static org.junit.Assert.*;
 
 /**
@@ -565,4 +570,99 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest {
         assertEquals("Incorrect number of inputs!", 5, newGraph.layerInputSize(afterPoolName));
         newGraph.output(input);
     }
+
+
+
+
+    @Test
+    public void testTransferLearningSameDiffLayersGraph(){
+
+        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
+
+                .graphBuilder()
+                .addInputs("in")
+                .layer("l0", new LSTM.Builder().nIn(5).nOut(5).build(), "in")
+                .layer("l1", new RecurrentAttentionLayer.Builder().nHeads(1).headSize(5).nIn(5).nOut(5).build(), "l0")
+                .layer("out", new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build(), "l1")
+                .setOutputs("out")
+                .build();
+
+        ComputationGraph cg = new ComputationGraph(conf);
+        cg.init();
+
+        INDArray arr = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
+        INDArray out = cg.output(arr)[0];
+
+
+        ComputationGraph cg2 = new TransferLearning.GraphBuilder(cg).removeVertexAndConnections("out")
+                .fineTuneConfiguration(FineTuneConfiguration.builder().updater(new Adam(0.01)).build())
+                .removeVertexAndConnections("out")
+                .addLayer("newOut", new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build(), "l1")
+                .setOutputs("newOut")
+                .build();
+
+        cg2.output(arr);
+
+        Map<String, INDArray> m = new HashMap<>(cg.paramTable());
+        m.put("newOut_W", m.remove("out_W"));
+        m.put("newOut_b", m.remove("out_b"));
+        cg2.setParamTable(m);
+
+        Map<String, INDArray> p1 = cg.paramTable();
+        Map<String, INDArray> p2 = cg2.paramTable();
+        for(String s : p1.keySet()){
+            INDArray i1 = p1.get(s);
+            INDArray i2 = p2.get(s.replaceAll("out", "newOut"));
+            assertEquals(s, i1, i2);
+        }
+
+        INDArray out2 = cg2.outputSingle(arr);
+        assertEquals(out, out2);
+    }
+
+    @Test
+    public void testTransferLearningSameDiffLayersGraphVertex(){
+
+        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
+
+                .graphBuilder()
+                .addInputs("in")
+                .layer("l0", new LSTM.Builder().nIn(5).nOut(5).build(), "in")
+                .addVertex("l1", new AttentionVertex.Builder().nHeads(1).headSize(5).nInKeys(5).nInQueries(5).nInValues(5).nOut(5).build(), "l0", "l0", "l0")
+                .layer("out", new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build(), "l1")
+                .setOutputs("out")
+                .build();
+
+        ComputationGraph cg = new ComputationGraph(conf);
+        cg.init();
+
+        INDArray arr = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
+        INDArray out = cg.output(arr)[0];
+
+
+        ComputationGraph cg2 = new TransferLearning.GraphBuilder(cg).removeVertexAndConnections("out")
+                .fineTuneConfiguration(FineTuneConfiguration.builder().updater(new Adam(0.01)).build())
+                .removeVertexAndConnections("out")
+                .addLayer("newOut", new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build(), "l1")
+                .setOutputs("newOut")
+                .build();
+
+        cg2.output(arr);
+
+        Map<String, INDArray> m = new HashMap<>(cg.paramTable());
+        m.put("newOut_W", m.remove("out_W"));
+        m.put("newOut_b", m.remove("out_b"));
+        cg2.setParamTable(m);
+
+        Map<String, INDArray> p1 = cg.paramTable();
+        Map<String, INDArray> p2 = cg2.paramTable();
+        for(String s : p1.keySet()){
+            INDArray i1 = p1.get(s);
+            INDArray i2 = p2.get(s.replaceAll("out", "newOut"));
+            assertEquals(s, i1, i2);
+        }
+
+        INDArray out2 = cg2.outputSingle(arr);
+        assertEquals(out, out2);
+    }
 }
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningMLNTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningMLNTest.java
index 4d15010c9..1ac2bd7bb 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningMLNTest.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningMLNTest.java
@@ -41,6 +41,7 @@ import org.deeplearning4j.nn.weights.WeightInitRelu;
 import org.deeplearning4j.nn.weights.WeightInitXavier;
 import org.junit.Test;
 import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.DataSet;
 import org.nd4j.linalg.factory.Nd4j;
@@ -48,6 +49,8 @@ import org.nd4j.linalg.learning.config.*;
 import org.nd4j.linalg.lossfunctions.LossFunctions;
 import org.nd4j.shade.jackson.core.JsonProcessingException;
 
+import java.util.Map;
+
 import static org.junit.Assert.*;
 
 /**
@@ -689,4 +692,51 @@ public class TransferLearningMLNTest extends BaseDL4JTest {
         assertEquals("Incorrect number of inputs!", 5, newNet.layerInputSize(2));
         newNet.output(input);
     }
+
+
+    @Test
+    public void testTransferLearningSameDiffLayers(){
+
+        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+                .dataType(DataType.DOUBLE)
+                .activation(Activation.TANH)
+                .updater(new Adam(0.01))
+                .weightInit(WeightInit.XAVIER)
+                .list()
+                .layer(new LSTM.Builder().nOut(8).build())
+                .layer( new SelfAttentionLayer.Builder().nOut(4).nHeads(2).projectInput(true).build())
+                .layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.MAX).build())
+                .layer(new OutputLayer.Builder().nOut(2).activation(Activation.SOFTMAX)
+                        .lossFunction(LossFunctions.LossFunction.MCXENT).build())
+                .setInputType(InputType.recurrent(4))
+                .build();
+
+        MultiLayerNetwork net = new MultiLayerNetwork(conf);
+        net.init();
+
+        INDArray in = Nd4j.rand(DataType.FLOAT, 3, 4, 5);
+        INDArray out = net.output(in);
+
+        MultiLayerNetwork net2 = new TransferLearning.Builder(net)
+                .fineTuneConfiguration(FineTuneConfiguration.builder().updater(new Adam(0.01)).build())
+                .removeLayersFromOutput(1)
+                .addLayer(new OutputLayer.Builder().nIn(4).nOut(2).activation(Activation.SOFTMAX)
+                        .lossFunction(LossFunctions.LossFunction.MCXENT).build())
+                .build();
+
+        net2.setParam("3_W", net.getParam("3_W"));
+        net2.setParam("3_b", net.getParam("3_b"));
+
+        Map<String, INDArray> p1 = net.paramTable();
+        Map<String, INDArray> p2 = net2.paramTable();
+        for(String s : p1.keySet()){
+            INDArray i1 = p1.get(s);
+            INDArray i2 = p2.get(s);
+            assertEquals(s, i1, i2);
+        }
+
+        INDArray out2 = net2.output(in);
+
+        assertEquals(out, out2);
+    }
 }
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java
index 5a5ce5665..6a4bc01aa 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java
@@ -427,7 +427,8 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable {
         if(!disconnected.isEmpty() && !allowNoOutput){  //If allowing no output: by definition we have disconnected vertices
             throw new IllegalStateException("Invalid configuration: disconnected vertices found - " + disconnected
                             + ". Disconnected vertices are those that do not connect to either another vertex, and are also"
-                            + " not a network output. To disable this error (i.e., allow network configurations with"
+                            + " not a network output. This vertex can be set as an output using setOutputs(String...). "
+                            + "To disable this error (i.e., allow network configurations with"
                             + " disconnected vertices) use GraphBuilder.allowDisconnected(true)");
         }
     }
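For context on the improved error message: a vertex is "disconnected" when nothing consumes its output and it is not declared as a network output. A minimal configuration that would trigger the exception, plus the two ways out the message now points to (a sketch; names and layer sizes are arbitrary):

    ComputationGraphConfiguration c = new NeuralNetConfiguration.Builder()
            .graphBuilder()
            .addInputs("in")
            .layer("dangling", new DenseLayer.Builder().nIn(10).nOut(10).build(), "in")
            .layer("out", new OutputLayer.Builder().nIn(10).nOut(2).build(), "in")
            .setOutputs("out")                 // "dangling" feeds nothing -> IllegalStateException at build()
            //.setOutputs("out", "dangling")   // fix 1: declare it as an output
            //.allowDisconnected(true)         // fix 2: explicitly allow disconnected vertices
            .build();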
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/AttentionVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/AttentionVertex.java
index 6ca3b35de..eb2a78b11 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/AttentionVertex.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/AttentionVertex.java
@@ -72,6 +72,20 @@ public class AttentionVertex extends SameDiffVertex {
         this.weightInit = builder.weightInit;
     }
 
+    @Override
+    public AttentionVertex clone() {
+        AttentionVertex av = new AttentionVertex();
+        av.nInKeys = nInKeys;
+        av.nInValues = nInValues;
+        av.nInQueries = nInQueries;
+        av.nOut = nOut;
+        av.headSize = headSize;
+        av.nHeads = nHeads;
+        av.projectInput = projectInput;
+        av.weightInit = weightInit;
+        return av;
+    }
+
     @Override
     public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException {
         InputType.InputTypeRecurrent queries = (InputType.InputTypeRecurrent) vertexInputs[0];
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingLayer.java
index 6cd8630d0..6478b6d59 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingLayer.java
@@ -24,6 +24,7 @@ import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
 import org.deeplearning4j.nn.conf.memory.MemoryReport;
 import org.deeplearning4j.nn.params.DefaultParamInitializer;
+import org.deeplearning4j.nn.params.EmbeddingLayerParamInitializer;
 import org.deeplearning4j.nn.weights.IWeightInit;
 import org.deeplearning4j.nn.weights.embeddings.ArrayEmbeddingInitializer;
 import org.deeplearning4j.nn.weights.embeddings.EmbeddingInitializer;
@@ -79,7 +80,7 @@ public class EmbeddingLayer extends FeedForwardLayer {
 
     @Override
     public ParamInitializer initializer() {
-        return DefaultParamInitializer.getInstance();
+        return EmbeddingLayerParamInitializer.getInstance();
    }
 
     @Override
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingSequenceLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingSequenceLayer.java
index 93585a1d0..9b7725801 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingSequenceLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingSequenceLayer.java
@@ -24,7 +24,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
 import org.deeplearning4j.nn.conf.memory.MemoryReport;
-import org.deeplearning4j.nn.params.DefaultParamInitializer;
+import org.deeplearning4j.nn.params.EmbeddingLayerParamInitializer;
 import org.deeplearning4j.nn.weights.IWeightInit;
 import org.deeplearning4j.nn.weights.embeddings.ArrayEmbeddingInitializer;
 import org.deeplearning4j.nn.weights.embeddings.EmbeddingInitializer;
@@ -92,7 +92,7 @@ public class EmbeddingSequenceLayer extends FeedForwardLayer {
 
     @Override
     public ParamInitializer initializer() {
-        return DefaultParamInitializer.getInstance();
+        return EmbeddingLayerParamInitializer.getInstance();
     }
 
     @Override
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLambdaVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLambdaVertex.java
index 6890f83e8..d8fc9fd48 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLambdaVertex.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLambdaVertex.java
@@ -16,11 +16,13 @@
 
 package org.deeplearning4j.nn.conf.layers.samediff;
 
+import org.deeplearning4j.nn.conf.graph.GraphVertex;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.ndarray.INDArray;
 
+import java.lang.reflect.InvocationTargetException;
 import java.util.*;
 
@@ -75,6 +77,15 @@
         //No op, for lambda vertex
     }
 
+    @Override
+    public GraphVertex clone() {
+        try {
+            return getClass().getConstructor().newInstance();
+        } catch (Exception e){
+            throw new RuntimeException("Unable to create new instance of class " + getClass().getName() + " from no-arg constructor");
+        }
+    }
+
     protected VertexInputs getInputs(SameDiff sd) {
         if (inputs == null) {
             inputs = new VertexInputs(sd);
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffVertex.java
index bd15b870b..f0066dac7 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffVertex.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffVertex.java
@@ -24,6 +24,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.graph.GraphVertex;
 import org.deeplearning4j.nn.conf.inputs.InputType;
 import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
+import org.deeplearning4j.nn.conf.layers.Layer;
 import org.deeplearning4j.nn.conf.memory.MemoryReport;
 import org.deeplearning4j.nn.graph.ComputationGraph;
 import org.deeplearning4j.nn.layers.samediff.SameDiffGraphVertex;
@@ -36,6 +37,7 @@ import org.nd4j.linalg.learning.regularization.Regularization;
 import org.nd4j.linalg.primitives.Pair;
 import org.nd4j.linalg.util.ArrayUtil;
 
+import java.lang.reflect.Field;
 import java.util.List;
 import java.util.Map;
 
@@ -99,11 +101,6 @@ public abstract class SameDiffVertex extends GraphVertex implements TrainingConf
         return vertexParams;
     }
 
-    @Override
-    public GraphVertex clone() {
-        throw new UnsupportedOperationException("Not yet implemented");
-    }
-
     @Override
     public long numParams(boolean backprop) {
         SDLayerParams params = getVertexParams();
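The reflection-based clone() above means any concrete SameDiffLambdaVertex must have an accessible no-arg constructor, otherwise cloning (e.g. during transfer learning or configuration cloning) fails at runtime. A minimal sketch of a conforming lambda vertex (the class itself is illustrative, not part of this patch):

    import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaVertex;
    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;

    public class ElementwiseProductVertex extends SameDiffLambdaVertex {

        public ElementwiseProductVertex() {
            //No-arg constructor, required by getClass().getConstructor().newInstance() in clone()
        }

        @Override
        public SDVariable defineVertex(SameDiff sameDiff, VertexInputs inputs) {
            // Element-wise product of the two vertex inputs
            return inputs.getInput(0).mul(inputs.getInput(1));
        }
    }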
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java
index eb84fefd6..0a34fe95a 100755
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java
@@ -3394,7 +3394,8 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
 
     @Override
     public void setParamTable(@NonNull Map<String, INDArray> paramTable) {
-        Preconditions.checkArgument(paramTable.keySet().equals(paramTable().keySet()), "Cannot set param table: parameter set keys are not equal");
+        Map<String, INDArray> m = paramTable();
+        Preconditions.checkArgument(paramTable.keySet().equals(m.keySet()), "Cannot set param table: parameter set keys are not equal");
         Map<String, INDArray> current = paramTable();
         //Check shapes before doing partial assigment to avoid leaving net in incorrect state
         for(String s : current.keySet()){
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java
index fcf899544..64c2ea25e 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java
@@ -237,9 +237,16 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
 
     @Override
     public void setParams(INDArray params) {
-        if (params != null) {
-            throw new UnsupportedOperationException("Not supported");
-        }
+        if(this.params == null && params == null)
+            return;
+        if(this.params == null)
+            throw new IllegalStateException("Cannot set parameters of length " + params.length() + " to a layer with no parameters");
+        if(params == null)
+            throw new IllegalStateException("Cannot set null parameters");
+
+        Preconditions.checkState(this.params.length() == params.length(), "Cannot assign parameter vector of length %s to a layer with %s parameters",
+                params.length(), this.params.length());
+        this.params.assign(params);
     }
 
     protected void setParams(INDArray params, char order) {
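The SameDiffLayer.setParams change makes SameDiff layers consistent with built-in layers when a flattened parameter vector is assigned, which is what the transfer-learning tests above rely on. Roughly (a sketch; net and net2 are assumed to be identically-configured networks containing SameDiff layers):

    INDArray flat = net.params();   // flattened vector over all layers, SameDiff layers included
    net2.setParams(flat.dup());     // previously threw UnsupportedOperationException for SameDiff layers;
                                    // now assigns, after checking the lengths match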
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.deeplearning4j.nn.params;
+
+import lombok.val;
+import org.deeplearning4j.nn.weights.IWeightInit;
+import org.deeplearning4j.nn.weights.WeightInitUtil;
+import org.nd4j.linalg.api.ndarray.INDArray;
+
+/**
+ * Parameter initializer for EmbeddingLayer and EmbeddingSequenceLayer
+ *
+ * @author Alex Black
+ */
+public class EmbeddingLayerParamInitializer extends DefaultParamInitializer {
+
+    private static final EmbeddingLayerParamInitializer INSTANCE = new EmbeddingLayerParamInitializer();
+
+    public static EmbeddingLayerParamInitializer getInstance() {
+        return INSTANCE;
+    }
+
+
+    protected INDArray createWeightMatrix(long nIn, long nOut, IWeightInit weightInit,
+                                          INDArray weightParamView, boolean initializeParameters) {
+        val shape = new long[] {nIn, nOut};
+
+        if (initializeParameters) {
+            INDArray ret = weightInit.init(1, //Fan in - note that fanIn=1 for embedding layer... if we used layer nIn (i.e., vocab size) the init would depend on vocab size (which doesn't make sense)
+                    nOut, //Fan out
+                    shape, IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, weightParamView);
+            return ret;
+        } else {
+            return WeightInitUtil.reshapeWeights(shape, weightParamView);
+        }
+    }
+
+}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/AbstractMultiDataSetNormalizer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/AbstractMultiDataSetNormalizer.java
index 78f11f1cb..3a713ad63 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/AbstractMultiDataSetNormalizer.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/AbstractMultiDataSetNormalizer.java
@@ -79,7 +79,6 @@ public abstract class AbstractMultiDataSetNormalizer<S extends NormalizerStats>
     }
 
     protected List<S> getFeatureStats() {
-        assertIsFit();
         return featureStats;
     }
 
@@ -88,7 +87,6 @@ public abstract class AbstractMultiDataSetNormalizer<S extends NormalizerStats>
     }
 
     protected List<S> getLabelStats() {
-        assertIsFit();
         return labelStats;
     }
 
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/ImagePreProcessingScaler.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/ImagePreProcessingScaler.java
index 5fc4491c8..77f3c39d9 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/ImagePreProcessingScaler.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/ImagePreProcessingScaler.java
@@ -44,6 +44,7 @@ public class ImagePreProcessingScaler implements DataNormalization {
     private double minRange, maxRange;
     private double maxPixelVal;
     private int maxBits;
+    private boolean fitLabels = false;
 
     public ImagePreProcessingScaler() {
         this(0, 1, 8);
@@ -94,7 +95,10 @@ public class ImagePreProcessingScaler implements DataNormalization {
     @Override
     public void preProcess(DataSet toPreProcess) {
         INDArray features = toPreProcess.getFeatures();
-        this.preProcess(features);
+        preProcess(features);
+        if(fitLabels && toPreProcess.getLabels() != null){
+            preProcess(toPreProcess.getLabels());
+        }
     }
 
     public void preProcess(INDArray features) {
@@ -139,6 +143,7 @@ public class ImagePreProcessingScaler implements DataNormalization {
     @Override
     public void revert(DataSet toRevert) {
         revertFeatures(toRevert.getFeatures());
+        revertLabels(toRevert.getLabels());
     }
 
     @Override
@@ -177,10 +182,11 @@ public class ImagePreProcessingScaler implements DataNormalization {
     @Override
     public void fitLabel(boolean fitLabels) {
         //No-op
+        this.fitLabels = fitLabels;
     }
 
     @Override
     public boolean isFitLabel() {
-        return false;
+        return fitLabels;
     }
 }
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/ImagePreProcessortTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/ImagePreProcessortTest.java
index c3c89c296..1d1558db7 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/ImagePreProcessortTest.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/ImagePreProcessortTest.java
@@ -22,6 +22,7 @@ import org.junit.runners.Parameterized;
 import org.nd4j.linalg.BaseNd4jTest;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.adapter.SingletonDataSetIterator;
 import org.nd4j.linalg.dataset.api.preprocessor.ImageMultiPreProcessingScaler;
 import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
 import org.nd4j.linalg.factory.Nd4j;
@@ -156,6 +157,42 @@ public class ImagePreProcessortTest extends BaseNd4jTest {
         assertEquals(orig, before);
     }
 
+
+    @Test
+    public void testSegmentation(){
+
+        INDArray f = Nd4j.math().floor(Nd4j.rand(DataType.FLOAT, 3, 3, 16, 16).muli(255));
+        INDArray l = Nd4j.math().floor(Nd4j.rand(DataType.FLOAT, 3, 10, 8, 8).muli(255));
+
+        ImagePreProcessingScaler s = new ImagePreProcessingScaler();
+        s.fitLabel(true);
+
+        s.fit(new DataSet(f,l));
+
+        INDArray expF = f.div(255);
+        INDArray expL = l.div(255);
+
+        DataSet d = new DataSet(f.dup(), l.dup());
+        s.transform(d);
+        assertEquals(expF, d.getFeatures());
+        assertEquals(expL, d.getLabels());
+
+
+        s.fit(new SingletonDataSetIterator(new DataSet(f, l)));
+
+        INDArray f2 = f.dup();
+        INDArray l2 = l.dup();
+        s.transform(f2);
+        s.transformLabel(l2);
+        assertEquals(expF, f2);
+        assertEquals(expL, l2);
+
+        s.revertFeatures(f2);
+        s.revertLabels(l2);
+        assertEquals(f, f2);
+        assertEquals(l, l2);
+    }
+
     @Override
     public char ordering() {
         return 'c';
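Typical use of the new label scaling in a pipeline, e.g. image segmentation where the labels are per-pixel masks stored with the same 0-255 encoding as the features (a sketch; the data here is random and purely illustrative):

    INDArray features = Nd4j.rand(DataType.FLOAT, 2, 3, 16, 16).muli(255);
    INDArray masks = Nd4j.rand(DataType.FLOAT, 2, 10, 16, 16).muli(255);  // per-pixel labels
    ImagePreProcessingScaler scaler = new ImagePreProcessingScaler();     // scales to [0, 1]
    scaler.fitLabel(true);                                                // previously a no-op; now enables label scaling
    DataSetIterator iter = new SingletonDataSetIterator(new DataSet(features, masks));
    iter.setPreProcessor(scaler);             // features and labels are both scaled as data is iterated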
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerTests.java
index 23d928a11..5fa8524a4 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/dataset/NormalizerTests.java
@@ -22,11 +22,10 @@ import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
 import org.nd4j.linalg.BaseNd4jTest;
 import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
 import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
 import org.nd4j.linalg.dataset.api.iterator.TestDataSetIterator;
-import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
-import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
-import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
+import org.nd4j.linalg.dataset.api.preprocessor.*;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.factory.Nd4jBackend;
 import org.nd4j.linalg.indexing.NDArrayIndex;
@@ -189,11 +188,11 @@ public class NormalizerTests extends BaseNd4jTest {
 
             INDArray zeros = Nd4j.zeros(shouldBe0_1.shape());
 
-            for (int j = 0; j < 2; j++) {
-                System.out.println(ds.getFeatures().get(NDArrayIndex.point(j), NDArrayIndex.all(),
-                                NDArrayIndex.all()));
-                System.out.println();
-            }
+//            for (int j = 0; j < 2; j++) {
+//                System.out.println(ds.getFeatures().get(NDArrayIndex.point(j), NDArrayIndex.all(),
+//                                NDArrayIndex.all()));
+//                System.out.println();
+//            }
 
             assertEquals(zeros, shouldBe0_1);
             assertEquals(zeros, shouldBe0_2);
@@ -218,6 +217,69 @@ public class NormalizerTests extends BaseNd4jTest {
         }
     }
 
+    @Test
+    public void testNormalizerToStringHashCode(){
+        //https://github.com/eclipse/deeplearning4j/issues/8565
+
+        testNormalizer(new NormalizerMinMaxScaler());
+        NormalizerMinMaxScaler n1 = new NormalizerMinMaxScaler();
+        n1.fitLabel(true);
+        testNormalizer(n1);
+
+        testNormalizer(new NormalizerStandardize());
+        NormalizerStandardize n2 = new NormalizerStandardize();
+        n2.fitLabel(true);
+        testNormalizer(n2);
+
+        testNormalizer(new ImagePreProcessingScaler());
+        ImagePreProcessingScaler n3 = new ImagePreProcessingScaler();
+        n3.fitLabel(true);
+        testNormalizer(n3);
+
+        testNormalizer(new VGG16ImagePreProcessor());
+        VGG16ImagePreProcessor n4 = new VGG16ImagePreProcessor();
+        n4.fitLabel(true);
+        testNormalizer(n4);
+    }
+
+    private static void testNormalizer(DataNormalization n){
+        n.toString();
+        n.hashCode();
+
+        n.fit(new IrisDataSetIterator(30, 150));
+
+        n.toString();
+        n.hashCode();
+    }
+
+    @Test
+    public void testMultiNormalizerToStringHashCode(){
+        //https://github.com/eclipse/deeplearning4j/issues/8565
+
+        testMultiNormalizer(new MultiNormalizerMinMaxScaler());
+        MultiNormalizerMinMaxScaler n1 = new MultiNormalizerMinMaxScaler();
+        n1.fitLabel(true);
+        testMultiNormalizer(n1);
+
+        testMultiNormalizer(new MultiNormalizerStandardize());
+        MultiNormalizerStandardize n2 = new MultiNormalizerStandardize();
+        n2.fitLabel(true);
+        testMultiNormalizer(n2);
+
+        testMultiNormalizer(new ImageMultiPreProcessingScaler(0));
+    }
+
+    private static void testMultiNormalizer(MultiDataNormalization n){
+        n.toString();
+        n.hashCode();
+
+        n.fit(new MultiDataSetIteratorAdapter(new IrisDataSetIterator(30, 150)));
+
+        n.toString();
+        n.hashCode();
+    }
+
+
     @Override
     public char ordering() {
         return 'c';
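The NormalizerTests additions pin down the #8565 fix: with the assertIsFit() calls removed from the stats getters in AbstractMultiDataSetNormalizer, the generated toString()/hashCode() can be called on a normalizer before fit() without throwing (previously, simply logging an unfitted normalizer could fail). A minimal sketch of the behaviour the tests exercise:

    NormalizerStandardize n = new NormalizerStandardize();
    System.out.println(n);        // safe before fit() after this patch
    int h = n.hashCode();         // likewise
    n.fit(new IrisDataSetIterator(30, 150));
    System.out.println(n);        // and still fine after fitting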