Add test from reported issue (confirmed fixed) (#359)

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2020-04-06 15:02:09 +10:00 committed by GitHub
parent 1a35ebec2e
commit 986ec4b51a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 26 additions and 0 deletions

View File

@ -27,6 +27,7 @@ import org.nd4j.autodiff.validation.TestCase;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.reduce.bool.All;
import org.nd4j.linalg.api.ops.random.custom.RandomBernoulli;
import org.nd4j.linalg.api.ops.random.custom.RandomExponential;
@ -410,4 +411,29 @@ public class RandomOpValidation extends BaseOpValidation {
}
}
}
@Test
public void testRandomExponential2(){
Nd4j.getRandom().setSeed(12345);
DynamicCustomOp op = DynamicCustomOp.builder("random_exponential")
.addInputs(Nd4j.createFromArray(100))
.addOutputs(Nd4j.create(DataType.FLOAT, 100))
.addFloatingPointArguments(0.5)
.build();
Nd4j.exec(op);
INDArray out = op.getOutputArgument(0);
int count0 = out.eq(0.0).castTo(DataType.INT32).sumNumber().intValue();
int count1 = out.eq(1.0).castTo(DataType.INT32).sumNumber().intValue();
assertEquals(0, count0);
assertEquals(0, count1);
double min = out.minNumber().doubleValue();
double max = out.maxNumber().doubleValue();
assertTrue(String.valueOf(min), min > 0.0);
assertTrue(String.valueOf(max), max > 1.0);
}
}