diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java index 053f3a70b..e6e8962ec 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java @@ -27,6 +27,7 @@ import org.nd4j.autodiff.validation.TestCase; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.reduce.bool.All; import org.nd4j.linalg.api.ops.random.custom.RandomBernoulli; import org.nd4j.linalg.api.ops.random.custom.RandomExponential; @@ -410,4 +411,29 @@ public class RandomOpValidation extends BaseOpValidation { } } } + + @Test + public void testRandomExponential2(){ + Nd4j.getRandom().setSeed(12345); + DynamicCustomOp op = DynamicCustomOp.builder("random_exponential") + .addInputs(Nd4j.createFromArray(100)) + .addOutputs(Nd4j.create(DataType.FLOAT, 100)) + .addFloatingPointArguments(0.5) + .build(); + + Nd4j.exec(op); + + INDArray out = op.getOutputArgument(0); + int count0 = out.eq(0.0).castTo(DataType.INT32).sumNumber().intValue(); + int count1 = out.eq(1.0).castTo(DataType.INT32).sumNumber().intValue(); + + assertEquals(0, count0); + assertEquals(0, count1); + + double min = out.minNumber().doubleValue(); + double max = out.maxNumber().doubleValue(); + + assertTrue(String.valueOf(min), min > 0.0); + assertTrue(String.valueOf(max), max > 1.0); + } }