Small test fixes (#165)

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-08-26 11:45:33 +10:00 committed by GitHub
parent ece6a17b11
commit d607bec6f9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 12 deletions

View File

@ -291,7 +291,7 @@ public class RegressionEvalTest extends BaseNd4jTest {
for (Metric m : Metric.values()) { for (Metric m : Metric.values()) {
double d1 = e4d.scoreForMetric(m); double d1 = e4d.scoreForMetric(m);
double d2 = e2d.scoreForMetric(m); double d2 = e2d.scoreForMetric(m);
assertEquals(m.toString(), d2, d1, 1e-6); assertEquals(m.toString(), d2, d1, 1e-5);
} }
} }
@ -385,7 +385,7 @@ public class RegressionEvalTest extends BaseNd4jTest {
for(Metric m : Metric.values()){ for(Metric m : Metric.values()){
double d1 = e4d_m1.scoreForMetric(m); double d1 = e4d_m1.scoreForMetric(m);
double d2 = e2d_m1.scoreForMetric(m); double d2 = e2d_m1.scoreForMetric(m);
assertEquals(m.toString(), d2, d1, 1e-6); assertEquals(m.toString(), d2, d1, 1e-5);
} }
//Check per-output masking: //Check per-output masking:

View File

@ -551,15 +551,13 @@ public class SpecialTests extends BaseNd4jTest {
int[] inputShape = new int[]{1, 2, 2, 1}; int[] inputShape = new int[]{1, 2, 2, 1};
int M = 2; int M = 2;
int[] blockShape = new int[]{M, 1};
int[] paddingShape = new int[]{M, 2};
INDArray input = Nd4j.randn(inputShape).castTo(DataType.DOUBLE); INDArray input = Nd4j.randn(inputShape).castTo(DataType.DOUBLE);
INDArray blocks = Nd4j.create(new float[]{2, 2}, blockShape).castTo(DataType.INT); INDArray blocks = Nd4j.createFromArray(2, 2);
INDArray padding = Nd4j.create(new float[]{0, 0, 0, 0}, paddingShape).castTo(DataType.INT); INDArray padding = Nd4j.createFromArray(0, 0, 0, 0).reshape(2,2);
INDArray expOut = Nd4j.create(DataType.DOUBLE, miniBatch, 1, 1, 1); INDArray expOut = Nd4j.create(DataType.DOUBLE, miniBatch, 1, 1, 1);
val op = DynamicCustomOp.builder("space_to_batch") val op = DynamicCustomOp.builder("space_to_batch_nd")
.addInputs(input, blocks, padding) .addInputs(input, blocks, padding)
.addOutputs(expOut).build(); .addOutputs(expOut).build();
Nd4j.getExecutioner().execAndReturn(op); Nd4j.getExecutioner().execAndReturn(op);
@ -573,15 +571,13 @@ public class SpecialTests extends BaseNd4jTest {
int[] inputShape = new int[]{miniBatch, 1, 1, 1}; int[] inputShape = new int[]{miniBatch, 1, 1, 1};
int M = 2; int M = 2;
int[] blockShape = new int[]{M, 1};
int[] cropShape = new int[]{M, 2};
INDArray input = Nd4j.randn(inputShape).castTo(DataType.DOUBLE); INDArray input = Nd4j.randn(inputShape).castTo(DataType.DOUBLE);
INDArray blocks = Nd4j.create(new float[]{2, 2}, blockShape).castTo(DataType.INT); INDArray blocks = Nd4j.createFromArray(2, 2);
INDArray crops = Nd4j.create(new float[]{0, 0, 0, 0}, cropShape).castTo(DataType.INT); INDArray crops = Nd4j.createFromArray(0, 0, 0, 0).reshape(2,2);
INDArray expOut = Nd4j.create(DataType.DOUBLE, 1, 2, 2, 1); INDArray expOut = Nd4j.create(DataType.DOUBLE, 1, 2, 2, 1);
DynamicCustomOp op = DynamicCustomOp.builder("batch_to_space") DynamicCustomOp op = DynamicCustomOp.builder("batch_to_space_nd")
.addInputs(input, blocks, crops) .addInputs(input, blocks, crops)
.addOutputs(expOut).build(); .addOutputs(expOut).build();
Nd4j.getExecutioner().execAndReturn(op); Nd4j.getExecutioner().execAndReturn(op);