Signed-off-by: AlexDBlack <blacka101@gmail.com>
This commit is contained in:
		
							parent
							
								
									09a827fb6d
								
							
						
					
					
						commit
						a856922fe9
					
				@ -2143,4 +2143,23 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
 | 
			
		||||
        INDArray in = Nd4j.create(DataType.FLOAT, 1, 3, 16, 16, 16);
 | 
			
		||||
        INDArray out = cg.outputSingle(in);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    @Test
 | 
			
		||||
    public void testDualEmbedding(){
 | 
			
		||||
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
 | 
			
		||||
                .graphBuilder()
 | 
			
		||||
                .addInputs("in")
 | 
			
		||||
                .addLayer("e1", new EmbeddingLayer.Builder().nIn(10).nOut(5).build(), "in")
 | 
			
		||||
                .addLayer("e2", new EmbeddingLayer.Builder().nIn(10).nOut(5).build(), "in")
 | 
			
		||||
                .addLayer("out", new OutputLayer.Builder().nIn(10).nOut(2).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build(), "e1", "e2")
 | 
			
		||||
                .setOutputs("out")
 | 
			
		||||
                .build();
 | 
			
		||||
 | 
			
		||||
        ComputationGraph cg = new ComputationGraph(conf);
 | 
			
		||||
        cg.init();
 | 
			
		||||
 | 
			
		||||
        INDArray in = Nd4j.createFromArray(3).reshape(1, 1);
 | 
			
		||||
        INDArray label = Nd4j.createFromArray(1, 0).reshape(1, 2);
 | 
			
		||||
        cg.fit(new DataSet(in, label));
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -2734,7 +2734,12 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
 | 
			
		||||
                        if (setVertexEpsilon[gv.getVertexIndex()]) {
 | 
			
		||||
                            //This vertex: must output to multiple vertices... we want to add the epsilons here
 | 
			
		||||
                            INDArray currentEps = gv.getEpsilon();
 | 
			
		||||
                            gv.setEpsilon(currentEps.addi(epsilons[j++]));  //TODO is this always safe?
 | 
			
		||||
                            if(currentEps == null){
 | 
			
		||||
                                //Edge case: this can be null for dual embedding layer case - in -> e1, in -> e2
 | 
			
		||||
                                gv.setEpsilon(currentEps);
 | 
			
		||||
                            } else {
 | 
			
		||||
                                gv.setEpsilon(currentEps.addi(epsilons[j++]));  //TODO is this always safe?
 | 
			
		||||
                            }
 | 
			
		||||
                        } else {
 | 
			
		||||
                            gv.setEpsilon(epsilons[j++]);
 | 
			
		||||
                        }
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user