Signed-off-by: AlexDBlack <blacka101@gmail.com>master
parent
09a827fb6d
commit
a856922fe9
|
@ -2143,4 +2143,23 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
|
||||||
INDArray in = Nd4j.create(DataType.FLOAT, 1, 3, 16, 16, 16);
|
INDArray in = Nd4j.create(DataType.FLOAT, 1, 3, 16, 16, 16);
|
||||||
INDArray out = cg.outputSingle(in);
|
INDArray out = cg.outputSingle(in);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testDualEmbedding(){
|
||||||
|
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||||
|
.graphBuilder()
|
||||||
|
.addInputs("in")
|
||||||
|
.addLayer("e1", new EmbeddingLayer.Builder().nIn(10).nOut(5).build(), "in")
|
||||||
|
.addLayer("e2", new EmbeddingLayer.Builder().nIn(10).nOut(5).build(), "in")
|
||||||
|
.addLayer("out", new OutputLayer.Builder().nIn(10).nOut(2).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build(), "e1", "e2")
|
||||||
|
.setOutputs("out")
|
||||||
|
.build();
|
||||||
|
|
||||||
|
ComputationGraph cg = new ComputationGraph(conf);
|
||||||
|
cg.init();
|
||||||
|
|
||||||
|
INDArray in = Nd4j.createFromArray(3).reshape(1, 1);
|
||||||
|
INDArray label = Nd4j.createFromArray(1, 0).reshape(1, 2);
|
||||||
|
cg.fit(new DataSet(in, label));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2734,7 +2734,12 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
||||||
if (setVertexEpsilon[gv.getVertexIndex()]) {
|
if (setVertexEpsilon[gv.getVertexIndex()]) {
|
||||||
//This vertex: must output to multiple vertices... we want to add the epsilons here
|
//This vertex: must output to multiple vertices... we want to add the epsilons here
|
||||||
INDArray currentEps = gv.getEpsilon();
|
INDArray currentEps = gv.getEpsilon();
|
||||||
|
if(currentEps == null){
|
||||||
|
//Edge case: this can be null for dual embedding layer case - in -> e1, in -> e2
|
||||||
|
gv.setEpsilon(currentEps);
|
||||||
|
} else {
|
||||||
gv.setEpsilon(currentEps.addi(epsilons[j++])); //TODO is this always safe?
|
gv.setEpsilon(currentEps.addi(epsilons[j++])); //TODO is this always safe?
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
gv.setEpsilon(epsilons[j++]);
|
gv.setEpsilon(epsilons[j++]);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue