parent
679e42199a
commit
24980efde3
|
@ -39,6 +39,7 @@ public class LogSumExp extends DynamicCustomOp {
|
|||
super(sameDiff, i_v);
|
||||
if(dimensions != null) {
|
||||
addIArgument(dimensions);
|
||||
this.dimensions = dimensions;
|
||||
}
|
||||
addTArgument(keepDims ? 1.0 : 0.0);
|
||||
this.keepDims = keepDims;
|
||||
|
|
|
@ -1970,4 +1970,24 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
INDArray log = Transforms.log(sum);
|
||||
assertEquals(log, out);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLogSumExp2(){
|
||||
|
||||
for( int dim=0; dim<=2; dim++ ) {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
INDArray inputArr = Nd4j.rand(DataType.DOUBLE, 3, 4, 5);
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable in = sd.var(inputArr);
|
||||
SDVariable lse = sd.math().logSumExp(in, dim);
|
||||
|
||||
INDArray exp = Transforms.exp(inputArr, true);
|
||||
INDArray sum = exp.sum(dim);
|
||||
INDArray log = Transforms.log(sum);
|
||||
|
||||
OpValidation.validate(new TestCase(sd)
|
||||
.expectedOutput(lse.name(), log)
|
||||
.gradientCheck(true));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue