Fix LogSumExp along dimension (#35)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
This commit is contained in:
		
							parent
							
								
									679e42199a
								
							
						
					
					
						commit
						24980efde3
					
				| @ -39,6 +39,7 @@ public class LogSumExp extends DynamicCustomOp { | ||||
|         super(sameDiff, i_v); | ||||
|         if(dimensions != null) { | ||||
|             addIArgument(dimensions); | ||||
|             this.dimensions = dimensions; | ||||
|         } | ||||
|         addTArgument(keepDims ? 1.0 : 0.0); | ||||
|         this.keepDims = keepDims; | ||||
|  | ||||
| @ -1970,4 +1970,24 @@ public class TransformOpValidation extends BaseOpValidation { | ||||
|         INDArray log = Transforms.log(sum); | ||||
|         assertEquals(log, out); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testLogSumExp2(){ | ||||
| 
 | ||||
|         for( int dim=0; dim<=2; dim++ ) { | ||||
|             Nd4j.getRandom().setSeed(12345); | ||||
|             INDArray inputArr = Nd4j.rand(DataType.DOUBLE, 3, 4, 5); | ||||
|             SameDiff sd = SameDiff.create(); | ||||
|             SDVariable in = sd.var(inputArr); | ||||
|             SDVariable lse = sd.math().logSumExp(in, dim); | ||||
| 
 | ||||
|             INDArray exp = Transforms.exp(inputArr, true); | ||||
|             INDArray sum = exp.sum(dim); | ||||
|             INDArray log = Transforms.log(sum); | ||||
| 
 | ||||
|             OpValidation.validate(new TestCase(sd) | ||||
|                     .expectedOutput(lse.name(), log) | ||||
|                     .gradientCheck(true)); | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user