Add test from reported issue (confirmed fixed) (#359)
Signed-off-by: Alex Black <blacka101@gmail.com>master
parent
1a35ebec2e
commit
986ec4b51a
|
@ -27,6 +27,7 @@ import org.nd4j.autodiff.validation.TestCase;
|
|||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.iter.NdIndexIterator;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.bool.All;
|
||||
import org.nd4j.linalg.api.ops.random.custom.RandomBernoulli;
|
||||
import org.nd4j.linalg.api.ops.random.custom.RandomExponential;
|
||||
|
@ -410,4 +411,29 @@ public class RandomOpValidation extends BaseOpValidation {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRandomExponential2(){
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
DynamicCustomOp op = DynamicCustomOp.builder("random_exponential")
|
||||
.addInputs(Nd4j.createFromArray(100))
|
||||
.addOutputs(Nd4j.create(DataType.FLOAT, 100))
|
||||
.addFloatingPointArguments(0.5)
|
||||
.build();
|
||||
|
||||
Nd4j.exec(op);
|
||||
|
||||
INDArray out = op.getOutputArgument(0);
|
||||
int count0 = out.eq(0.0).castTo(DataType.INT32).sumNumber().intValue();
|
||||
int count1 = out.eq(1.0).castTo(DataType.INT32).sumNumber().intValue();
|
||||
|
||||
assertEquals(0, count0);
|
||||
assertEquals(0, count1);
|
||||
|
||||
double min = out.minNumber().doubleValue();
|
||||
double max = out.maxNumber().doubleValue();
|
||||
|
||||
assertTrue(String.valueOf(min), min > 0.0);
|
||||
assertTrue(String.valueOf(max), max > 1.0);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue