Misc fixes (#66)
* Small fixes Signed-off-by: Alex Black <blacka101@gmail.com> * Flaky test fix Signed-off-by: Alex Black <blacka101@gmail.com>master
parent
d94bc7257c
commit
7939cf384b
|
@ -163,7 +163,7 @@ public class LossOpValidation extends BaseOpValidation {
|
||||||
//Loss loss aka binary cross entropy loss
|
//Loss loss aka binary cross entropy loss
|
||||||
//Labels are random bernoulli
|
//Labels are random bernoulli
|
||||||
Nd4j.getExecutioner().exec(new BernoulliDistribution(labelsArr, 0.5));
|
Nd4j.getExecutioner().exec(new BernoulliDistribution(labelsArr, 0.5));
|
||||||
predictionsArr = Nd4j.rand(predictionsArr.shape());
|
predictionsArr = Nd4j.rand(predictionsArr.shape()).muli(0.8).addi(0.1);
|
||||||
INDArray logP = Transforms.log(predictionsArr.add(eps), true);
|
INDArray logP = Transforms.log(predictionsArr.add(eps), true);
|
||||||
INDArray log1p = Transforms.log(predictionsArr.rsub(1.0).add(eps), true);
|
INDArray log1p = Transforms.log(predictionsArr.rsub(1.0).add(eps), true);
|
||||||
expOut = labelsArr.mul(logP).addi(labelsArr.rsub(1).mul(log1p)).negi();
|
expOut = labelsArr.mul(logP).addi(labelsArr.rsub(1).mul(log1p)).negi();
|
||||||
|
|
|
@ -535,7 +535,7 @@ public class ReductionOpValidation extends BaseOpValidation {
|
||||||
name = "norm1";
|
name = "norm1";
|
||||||
break;
|
break;
|
||||||
case 8:
|
case 8:
|
||||||
maxRelError = 1e-4;
|
maxRelError = 1e-3; //Norm2 can also run into numerical precision issues
|
||||||
reduced = sd.norm2("reduced", second, reduceDim);
|
reduced = sd.norm2("reduced", second, reduceDim);
|
||||||
name = "norm2";
|
name = "norm2";
|
||||||
break;
|
break;
|
||||||
|
|
|
@ -310,7 +310,7 @@ public class NumpyFormatTests extends BaseNd4jTest {
|
||||||
INDArray act1 = Nd4j.createFromNpyFile(f);
|
INDArray act1 = Nd4j.createFromNpyFile(f);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test(expected = IllegalArgumentException.class)
|
||||||
public void testAbsentNumpyFile_2() throws Exception {
|
public void testAbsentNumpyFile_2() throws Exception {
|
||||||
val f = new File("c:/develop/batch-x-1.npy");
|
val f = new File("c:/develop/batch-x-1.npy");
|
||||||
INDArray act1 = Nd4j.createFromNpyFile(f);
|
INDArray act1 = Nd4j.createFromNpyFile(f);
|
||||||
|
|
Loading…
Reference in New Issue