parent
679e42199a
commit
24980efde3
|
@ -39,6 +39,7 @@ public class LogSumExp extends DynamicCustomOp {
|
||||||
super(sameDiff, i_v);
|
super(sameDiff, i_v);
|
||||||
if(dimensions != null) {
|
if(dimensions != null) {
|
||||||
addIArgument(dimensions);
|
addIArgument(dimensions);
|
||||||
|
this.dimensions = dimensions;
|
||||||
}
|
}
|
||||||
addTArgument(keepDims ? 1.0 : 0.0);
|
addTArgument(keepDims ? 1.0 : 0.0);
|
||||||
this.keepDims = keepDims;
|
this.keepDims = keepDims;
|
||||||
|
|
|
@ -1970,4 +1970,24 @@ public class TransformOpValidation extends BaseOpValidation {
|
||||||
INDArray log = Transforms.log(sum);
|
INDArray log = Transforms.log(sum);
|
||||||
assertEquals(log, out);
|
assertEquals(log, out);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testLogSumExp2(){
|
||||||
|
|
||||||
|
for( int dim=0; dim<=2; dim++ ) {
|
||||||
|
Nd4j.getRandom().setSeed(12345);
|
||||||
|
INDArray inputArr = Nd4j.rand(DataType.DOUBLE, 3, 4, 5);
|
||||||
|
SameDiff sd = SameDiff.create();
|
||||||
|
SDVariable in = sd.var(inputArr);
|
||||||
|
SDVariable lse = sd.math().logSumExp(in, dim);
|
||||||
|
|
||||||
|
INDArray exp = Transforms.exp(inputArr, true);
|
||||||
|
INDArray sum = exp.sum(dim);
|
||||||
|
INDArray log = Transforms.log(sum);
|
||||||
|
|
||||||
|
OpValidation.validate(new TestCase(sd)
|
||||||
|
.expectedOutput(lse.name(), log)
|
||||||
|
.gradientCheck(true));
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue