From a856922fe97d9cc4284bb3130c1432c054acedcf Mon Sep 17 00:00:00 2001
From: Alex Black
Date: Sat, 16 Nov 2019 23:09:41 +1100
Subject: [PATCH] #8409 Fix compgraph backprop issue with dual embedding layers from single input (#52)

Signed-off-by: AlexDBlack
---
Review notes (text between the three-dash line and the diffstat is not part of
the commit message and is skipped when the patch is applied):

 * The line structure of this patch has been restored so that headers, hunk
   markers and +/-/context prefixes each start their own line. Leading
   indentation of the Java lines could not be recovered and has been re-created
   by convention; it may not match the target files byte-for-byte, so apply
   with --ignore-whitespace if context lines fail to match.
 * NOTE(review): in the new "currentEps == null" branch of ComputationGraph,
   gv.setEpsilon(currentEps) stores the null that was just read, and
   epsilons[j] is not consumed, whereas the other two branches both advance j.
   The enclosing loop is outside this hunk - confirm that a vertex with several
   inputs cannot end up reading a shifted epsilons[] entry, and that a non-null
   epsilons[j] is never silently dropped here.

 .../nn/graph/TestComputationGraphNetwork.java | 19 +++++++++++++++++++
 .../nn/graph/ComputationGraph.java            |  7 ++++++-
 2 files changed, 25 insertions(+), 1 deletion(-)

diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java
index daf657d0a..3e330d248 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java
@@ -2143,4 +2143,23 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
         INDArray in = Nd4j.create(DataType.FLOAT, 1, 3, 16, 16, 16);
         INDArray out = cg.outputSingle(in);
     }
+
+    @Test
+    public void testDualEmbedding(){
+        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
+                .graphBuilder()
+                .addInputs("in")
+                .addLayer("e1", new EmbeddingLayer.Builder().nIn(10).nOut(5).build(), "in")
+                .addLayer("e2", new EmbeddingLayer.Builder().nIn(10).nOut(5).build(), "in")
+                .addLayer("out", new OutputLayer.Builder().nIn(10).nOut(2).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build(), "e1", "e2")
+                .setOutputs("out")
+                .build();
+
+        ComputationGraph cg = new ComputationGraph(conf);
+        cg.init();
+
+        INDArray in = Nd4j.createFromArray(3).reshape(1, 1);
+        INDArray label = Nd4j.createFromArray(1, 0).reshape(1, 2);
+        cg.fit(new DataSet(in, label));
+    }
 }
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java
index 32d7bfb73..1be13ddf3 100755
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java
@@ -2734,7 +2734,12 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
                         if (setVertexEpsilon[gv.getVertexIndex()]) {
                             //This vertex: must output to multiple vertices... we want to add the epsilons here
                             INDArray currentEps = gv.getEpsilon();
-                            gv.setEpsilon(currentEps.addi(epsilons[j++])); //TODO is this always safe?
+                            if(currentEps == null){
+                                //Edge case: this can be null for dual embedding layer case - in -> e1, in -> e2
+                                gv.setEpsilon(currentEps);
+                            } else {
+                                gv.setEpsilon(currentEps.addi(epsilons[j++])); //TODO is this always safe?
+                            }
                         } else {
                             gv.setEpsilon(epsilons[j++]);
                         }