Fix LogSumExp along dimension (#35)

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-11-07 23:36:15 +11:00 committed by GitHub
parent 679e42199a
commit 24980efde3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 0 deletions

View File

@ -39,6 +39,7 @@ public class LogSumExp extends DynamicCustomOp {
super(sameDiff, i_v);
if(dimensions != null) {
addIArgument(dimensions);
this.dimensions = dimensions;
}
addTArgument(keepDims ? 1.0 : 0.0);
this.keepDims = keepDims;

View File

@ -1970,4 +1970,24 @@ public class TransformOpValidation extends BaseOpValidation {
INDArray log = Transforms.log(sum);
assertEquals(log, out);
}
@Test
public void testLogSumExp2() {
    // Validates logSumExp reduced along each axis of a rank-3 input against a
    // manually computed reference: log(sum(exp(x), dim)). Also runs a gradient check.
    for (int dim = 0; dim <= 2; dim++) {
        Nd4j.getRandom().setSeed(12345);
        INDArray input = Nd4j.rand(DataType.DOUBLE, 3, 4, 5);

        SameDiff sd = SameDiff.create();
        SDVariable in = sd.var(input);
        SDVariable lse = sd.math().logSumExp(in, dim);

        // Reference value built directly from ND4J transforms (exp copies the input).
        INDArray expected = Transforms.log(Transforms.exp(input, true).sum(dim));

        OpValidation.validate(new TestCase(sd)
                .expectedOutput(lse.name(), expected)
                .gradientCheck(true));
    }
}
}