diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/LogSumExp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/LogSumExp.java index ff29189cd..5ccca2dd2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/LogSumExp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/LogSumExp.java @@ -39,6 +39,7 @@ public class LogSumExp extends DynamicCustomOp { super(sameDiff, i_v); if(dimensions != null) { addIArgument(dimensions); + this.dimensions = dimensions; } addTArgument(keepDims ? 1.0 : 0.0); this.keepDims = keepDims; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java index bb3bab213..6a42d21e1 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java @@ -1970,4 +1970,24 @@ public class TransformOpValidation extends BaseOpValidation { INDArray log = Transforms.log(sum); assertEquals(log, out); } + + @Test + public void testLogSumExp2(){ + + for( int dim=0; dim<=2; dim++ ) { + Nd4j.getRandom().setSeed(12345); + INDArray inputArr = Nd4j.rand(DataType.DOUBLE, 3, 4, 5); + SameDiff sd = SameDiff.create(); + SDVariable in = sd.var(inputArr); + SDVariable lse = sd.math().logSumExp(in, dim); + + INDArray exp = Transforms.exp(inputArr, true); + INDArray sum = exp.sum(dim); + INDArray log = Transforms.log(sum); + + OpValidation.validate(new TestCase(sd) + .expectedOutput(lse.name(), log) + .gradientCheck(true)); + } + } }