From e92f7218f3bfe5833faa00670a5144903e552d95 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Tue, 27 Aug 2019 12:15:56 +1000 Subject: [PATCH] Add new tests (#171) Signed-off-by: AlexDBlack --- .../opvalidation/MiscOpValidation.java | 58 +++++++++++++++++++ .../opvalidation/RandomOpValidation.java | 9 +++ .../opvalidation/ShapeOpValidation.java | 20 +++++++ .../opvalidation/TransformOpValidation.java | 39 +++++++++++++ 4 files changed, 126 insertions(+) diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java index 2a4b032b5..99a2f57ac 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/MiscOpValidation.java @@ -438,6 +438,27 @@ public class MiscOpValidation extends BaseOpValidation { assertEquals(failed.toString(), 0, failed.size()); } + @Test + public void testScatterUpdate(){ + INDArray x = Nd4j.linspace(DataType.FLOAT, 1, 30, 1).reshape(10, 3); + INDArray updates = Nd4j.create(new float[][]{ + {100, 101, 102}, + {200, 201, 202}}); + INDArray indices = Nd4j.createFromArray(2, 5); + + INDArray exp = x.dup(); + exp.putRow(2, updates.getRow(0)); + exp.putRow(5, updates.getRow(1)); + + INDArray out = exp.ulike(); + Nd4j.exec(DynamicCustomOp.builder("scatter_upd") + .addInputs(x, indices, updates) + .addOutputs(out) + .build()); + + assertEquals(exp, out); + } + @Test public void testGatherGradient() { Nd4j.getRandom().setSeed(12345); @@ -1688,4 +1709,41 @@ public class MiscOpValidation extends BaseOpValidation { Nd4j.getExecutioner().exec(op); } + + @Test + public void testHistogramFixedWidth(){ + //Bins: [-inf, 0.2), [0.2, 0.4), [0.4, 0.6), [0.6, 0.8), [0.8, inf] + INDArray in = Nd4j.createFromArray(0.0, 0.1, 0.1, 0.3, 0.5, 0.5, 0.9); + INDArray range = Nd4j.createFromArray(0.0, 1.0); + INDArray n = Nd4j.scalar(5); + + INDArray out = Nd4j.create(DataType.INT, 5); + + Nd4j.exec(DynamicCustomOp.builder("histogram_fixed_width") + .addInputs(in, range, n) + .addOutputs(out) + .build()); + + INDArray exp = Nd4j.createFromArray(3, 1, 2, 0, 1); + assertEquals(exp, out); + } + + @Test + public void testListDiff(){ + INDArray x = Nd4j.createFromArray(0, 1, 2, 3); + INDArray y = Nd4j.createFromArray(3, 1); + + INDArray out = Nd4j.create(DataType.INT, 2); + INDArray outIdx = Nd4j.create(DataType.INT, 2); + + Nd4j.exec(DynamicCustomOp.builder("listdiff") + .addInputs(x, y) + .addOutputs(out, outIdx) + .build()); + + INDArray exp = Nd4j.createFromArray(0, 2); + + assertEquals(exp, out); //Values in x not in y + assertEquals(exp, outIdx); //Indices of the values in x not in y + } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java index 646cae454..8d64f6404 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/RandomOpValidation.java @@ -27,6 +27,7 @@ import org.nd4j.autodiff.validation.TestCase; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.reduce.bool.All; import org.nd4j.linalg.api.ops.random.custom.RandomBernoulli; import org.nd4j.linalg.api.ops.random.custom.RandomExponential; import org.nd4j.linalg.api.ops.random.impl.BinomialDistribution; @@ -371,6 +372,14 @@ public class RandomOpValidation extends BaseOpValidation { assertNull(OpValidation.validate(tc)); } + } + @Test + public void testAllEmptyReduce(){ + INDArray x = Nd4j.createFromArray(true, true, true); + All all = new All(x); + all.setEmptyReduce(true); //For TF compatibility - empty array for axis (which means no-op - and NOT all array reduction) + INDArray out = Nd4j.exec(all); + assertEquals(x, out); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java index 2965f367f..ffb585183 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java @@ -1342,6 +1342,26 @@ public class ShapeOpValidation extends BaseOpValidation { assertEquals(failed.toString(), 0, failed.size()); } + @Test + public void testSegmentMean(){ + INDArray x = Nd4j.linspace(DataType.FLOAT, 1, 18, 1).reshape(6, 3); + INDArray segmentIds = Nd4j.createFromArray(0, 0, 1, 1, 2, 2); + + INDArray out = Nd4j.create(DataType.FLOAT, 3, 3); + + Nd4j.exec(DynamicCustomOp.builder("segment_mean") + .addInputs(x, segmentIds) + .addOutputs(out) + .build()); + + INDArray exp = out.like(); + exp.putRow(0, x.getRow(0).add(x.getRow(1)).muli(0.5)); + exp.putRow(1, x.getRow(2).add(x.getRow(3)).muli(0.5)); + exp.putRow(2, x.getRow(4).add(x.getRow(5)).muli(0.5)); + + assertEquals(exp, out); + } + @Test public void testSequenceMask() { OpValidationSuite.ignoreFailing(); //2018-01-09: output datatype issue? diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java index 0d177027d..9183a0884 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java @@ -415,6 +415,24 @@ public class TransformOpValidation extends BaseOpValidation { assertNull(err, err); } + @Test + public void testDynamicPartition2(){ + INDArray data = Nd4j.createFromArray(2, 1, 2, 0); + INDArray partitions = Nd4j.createFromArray(0, 2, 1, 0); + INDArray[] out = Nd4j.exec(DynamicCustomOp.builder("dynamic_partition") + .addOutputs(Nd4j.createUninitialized(DataType.INT, 2), Nd4j.createUninitialized(DataType.INT, 1), Nd4j.createUninitialized(DataType.INT, 1)) + .addIntegerArguments(3) //3 partitions + .addInputs(data, partitions).build()); + + INDArray exp0 = Nd4j.createFromArray(2, 0); + INDArray exp1 = Nd4j.createFromArray(2); + INDArray exp2 = Nd4j.createFromArray(1); + + assertEquals(exp0, out[0]); //Usually just gives [0,0] + assertEquals(exp1, out[1]); + assertEquals(exp2, out[2]); + } + @Test public void testDynamicStitch() { SameDiff sd = SameDiff.create(); @@ -1612,6 +1630,27 @@ public class TransformOpValidation extends BaseOpValidation { } } + @Test + public void testTopK1(){ + INDArray x = Nd4j.createFromArray(0.0, 0.0, 0.0, 10.0, 0.0); + INDArray k = Nd4j.scalar(1); + INDArray outValue = Nd4j.create(DataType.DOUBLE, 1); + INDArray outIdx = Nd4j.create(DataType.INT, 1); + + Nd4j.exec(DynamicCustomOp.builder("top_k") + .addInputs(x, k) + .addOutputs(outValue, outIdx) + .addBooleanArguments(false) //not sorted + .addIntegerArguments(1) + .build()); + + INDArray expValue = Nd4j.createFromArray(10.0); + INDArray expIdx = Nd4j.createFromArray(3); + + assertEquals(expValue, outValue); + assertEquals(expIdx, outIdx); + } + @Test public void testInTopK() { for( int k=4; k>= 1; k--){