Add new tests (#171)

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-08-27 12:15:56 +10:00 committed by GitHub
parent 25e5c23eae
commit e92f7218f3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 126 additions and 0 deletions

View File

@ -438,6 +438,27 @@ public class MiscOpValidation extends BaseOpValidation {
assertEquals(failed.toString(), 0, failed.size()); assertEquals(failed.toString(), 0, failed.size());
} }
@Test
public void testScatterUpdate(){
INDArray x = Nd4j.linspace(DataType.FLOAT, 1, 30, 1).reshape(10, 3);
INDArray updates = Nd4j.create(new float[][]{
{100, 101, 102},
{200, 201, 202}});
INDArray indices = Nd4j.createFromArray(2, 5);
INDArray exp = x.dup();
exp.putRow(2, updates.getRow(0));
exp.putRow(5, updates.getRow(1));
INDArray out = exp.ulike();
Nd4j.exec(DynamicCustomOp.builder("scatter_upd")
.addInputs(x, indices, updates)
.addOutputs(out)
.build());
assertEquals(exp, out);
}
@Test @Test
public void testGatherGradient() { public void testGatherGradient() {
Nd4j.getRandom().setSeed(12345); Nd4j.getRandom().setSeed(12345);
@ -1688,4 +1709,41 @@ public class MiscOpValidation extends BaseOpValidation {
Nd4j.getExecutioner().exec(op); Nd4j.getExecutioner().exec(op);
} }
@Test
public void testHistogramFixedWidth(){
//Bins: [-inf, 0.2), [0.2, 0.4), [0.4, 0.6), [0.6, 0.8), [0.8, inf]
INDArray in = Nd4j.createFromArray(0.0, 0.1, 0.1, 0.3, 0.5, 0.5, 0.9);
INDArray range = Nd4j.createFromArray(0.0, 1.0);
INDArray n = Nd4j.scalar(5);
INDArray out = Nd4j.create(DataType.INT, 5);
Nd4j.exec(DynamicCustomOp.builder("histogram_fixed_width")
.addInputs(in, range, n)
.addOutputs(out)
.build());
INDArray exp = Nd4j.createFromArray(3, 1, 2, 0, 1);
assertEquals(exp, out);
}
@Test
public void testListDiff(){
INDArray x = Nd4j.createFromArray(0, 1, 2, 3);
INDArray y = Nd4j.createFromArray(3, 1);
INDArray out = Nd4j.create(DataType.INT, 2);
INDArray outIdx = Nd4j.create(DataType.INT, 2);
Nd4j.exec(DynamicCustomOp.builder("listdiff")
.addInputs(x, y)
.addOutputs(out, outIdx)
.build());
INDArray exp = Nd4j.createFromArray(0, 2);
assertEquals(exp, out); //Values in x not in y
assertEquals(exp, outIdx); //Indices of the values in x not in y
}
} }

View File

@ -27,6 +27,7 @@ import org.nd4j.autodiff.validation.TestCase;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.reduce.bool.All;
import org.nd4j.linalg.api.ops.random.custom.RandomBernoulli; import org.nd4j.linalg.api.ops.random.custom.RandomBernoulli;
import org.nd4j.linalg.api.ops.random.custom.RandomExponential; import org.nd4j.linalg.api.ops.random.custom.RandomExponential;
import org.nd4j.linalg.api.ops.random.impl.BinomialDistribution; import org.nd4j.linalg.api.ops.random.impl.BinomialDistribution;
@ -371,6 +372,14 @@ public class RandomOpValidation extends BaseOpValidation {
assertNull(OpValidation.validate(tc)); assertNull(OpValidation.validate(tc));
} }
}
@Test
public void testAllEmptyReduce(){
INDArray x = Nd4j.createFromArray(true, true, true);
All all = new All(x);
all.setEmptyReduce(true); //For TF compatibility - empty array for axis (which means no-op - and NOT all array reduction)
INDArray out = Nd4j.exec(all);
assertEquals(x, out);
} }
} }

View File

@ -1342,6 +1342,26 @@ public class ShapeOpValidation extends BaseOpValidation {
assertEquals(failed.toString(), 0, failed.size()); assertEquals(failed.toString(), 0, failed.size());
} }
@Test
public void testSegmentMean(){
INDArray x = Nd4j.linspace(DataType.FLOAT, 1, 18, 1).reshape(6, 3);
INDArray segmentIds = Nd4j.createFromArray(0, 0, 1, 1, 2, 2);
INDArray out = Nd4j.create(DataType.FLOAT, 3, 3);
Nd4j.exec(DynamicCustomOp.builder("segment_mean")
.addInputs(x, segmentIds)
.addOutputs(out)
.build());
INDArray exp = out.like();
exp.putRow(0, x.getRow(0).add(x.getRow(1)).muli(0.5));
exp.putRow(1, x.getRow(2).add(x.getRow(3)).muli(0.5));
exp.putRow(2, x.getRow(4).add(x.getRow(5)).muli(0.5));
assertEquals(exp, out);
}
@Test @Test
public void testSequenceMask() { public void testSequenceMask() {
OpValidationSuite.ignoreFailing(); //2018-01-09: output datatype issue? OpValidationSuite.ignoreFailing(); //2018-01-09: output datatype issue?

View File

@ -415,6 +415,24 @@ public class TransformOpValidation extends BaseOpValidation {
assertNull(err, err); assertNull(err, err);
} }
@Test
public void testDynamicPartition2(){
INDArray data = Nd4j.createFromArray(2, 1, 2, 0);
INDArray partitions = Nd4j.createFromArray(0, 2, 1, 0);
INDArray[] out = Nd4j.exec(DynamicCustomOp.builder("dynamic_partition")
.addOutputs(Nd4j.createUninitialized(DataType.INT, 2), Nd4j.createUninitialized(DataType.INT, 1), Nd4j.createUninitialized(DataType.INT, 1))
.addIntegerArguments(3) //3 partitions
.addInputs(data, partitions).build());
INDArray exp0 = Nd4j.createFromArray(2, 0);
INDArray exp1 = Nd4j.createFromArray(2);
INDArray exp2 = Nd4j.createFromArray(1);
assertEquals(exp0, out[0]); //Usually just gives [0,0]
assertEquals(exp1, out[1]);
assertEquals(exp2, out[2]);
}
@Test @Test
public void testDynamicStitch() { public void testDynamicStitch() {
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
@ -1612,6 +1630,27 @@ public class TransformOpValidation extends BaseOpValidation {
} }
} }
@Test
public void testTopK1(){
INDArray x = Nd4j.createFromArray(0.0, 0.0, 0.0, 10.0, 0.0);
INDArray k = Nd4j.scalar(1);
INDArray outValue = Nd4j.create(DataType.DOUBLE, 1);
INDArray outIdx = Nd4j.create(DataType.INT, 1);
Nd4j.exec(DynamicCustomOp.builder("top_k")
.addInputs(x, k)
.addOutputs(outValue, outIdx)
.addBooleanArguments(false) //not sorted
.addIntegerArguments(1)
.build());
INDArray expValue = Nd4j.createFromArray(10.0);
INDArray expIdx = Nd4j.createFromArray(3);
assertEquals(expValue, outValue);
assertEquals(expIdx, outIdx);
}
@Test @Test
public void testInTopK() { public void testInTopK() {
for( int k=4; k>= 1; k--){ for( int k=4; k>= 1; k--){