Signed-off-by: AlexDBlack <blacka101@gmail.com>
This commit is contained in:
		
							parent
							
								
									09a827fb6d
								
							
						
					
					
						commit
						a856922fe9
					
				| @ -2143,4 +2143,23 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { | |||||||
|         INDArray in = Nd4j.create(DataType.FLOAT, 1, 3, 16, 16, 16); |         INDArray in = Nd4j.create(DataType.FLOAT, 1, 3, 16, 16, 16); | ||||||
|         INDArray out = cg.outputSingle(in); |         INDArray out = cg.outputSingle(in); | ||||||
|     } |     } | ||||||
|  | 
 | ||||||
|  |     @Test | ||||||
|  |     public void testDualEmbedding(){ | ||||||
|  |         ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() | ||||||
|  |                 .graphBuilder() | ||||||
|  |                 .addInputs("in") | ||||||
|  |                 .addLayer("e1", new EmbeddingLayer.Builder().nIn(10).nOut(5).build(), "in") | ||||||
|  |                 .addLayer("e2", new EmbeddingLayer.Builder().nIn(10).nOut(5).build(), "in") | ||||||
|  |                 .addLayer("out", new OutputLayer.Builder().nIn(10).nOut(2).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build(), "e1", "e2") | ||||||
|  |                 .setOutputs("out") | ||||||
|  |                 .build(); | ||||||
|  | 
 | ||||||
|  |         ComputationGraph cg = new ComputationGraph(conf); | ||||||
|  |         cg.init(); | ||||||
|  | 
 | ||||||
|  |         INDArray in = Nd4j.createFromArray(3).reshape(1, 1); | ||||||
|  |         INDArray label = Nd4j.createFromArray(1, 0).reshape(1, 2); | ||||||
|  |         cg.fit(new DataSet(in, label)); | ||||||
|  |     } | ||||||
| } | } | ||||||
|  | |||||||
| @ -2734,7 +2734,12 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { | |||||||
|                         if (setVertexEpsilon[gv.getVertexIndex()]) { |                         if (setVertexEpsilon[gv.getVertexIndex()]) { | ||||||
|                             //This vertex: must output to multiple vertices... we want to add the epsilons here |                             //This vertex: must output to multiple vertices... we want to add the epsilons here | ||||||
|                             INDArray currentEps = gv.getEpsilon(); |                             INDArray currentEps = gv.getEpsilon(); | ||||||
|                             gv.setEpsilon(currentEps.addi(epsilons[j++]));  //TODO is this always safe? |                             if(currentEps == null){ | ||||||
|  |                                 //Edge case: this can be null for dual embedding layer case - in -> e1, in -> e2 | ||||||
|  |                                 gv.setEpsilon(currentEps); | ||||||
|  |                             } else { | ||||||
|  |                                 gv.setEpsilon(currentEps.addi(epsilons[j++]));  //TODO is this always safe? | ||||||
|  |                             } | ||||||
|                         } else { |                         } else { | ||||||
|                             gv.setEpsilon(epsilons[j++]); |                             gv.setEpsilon(epsilons[j++]); | ||||||
|                         } |                         } | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user