parent
25e5c23eae
commit
e92f7218f3
|
@ -438,6 +438,27 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
assertEquals(failed.toString(), 0, failed.size());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testScatterUpdate(){
|
||||
INDArray x = Nd4j.linspace(DataType.FLOAT, 1, 30, 1).reshape(10, 3);
|
||||
INDArray updates = Nd4j.create(new float[][]{
|
||||
{100, 101, 102},
|
||||
{200, 201, 202}});
|
||||
INDArray indices = Nd4j.createFromArray(2, 5);
|
||||
|
||||
INDArray exp = x.dup();
|
||||
exp.putRow(2, updates.getRow(0));
|
||||
exp.putRow(5, updates.getRow(1));
|
||||
|
||||
INDArray out = exp.ulike();
|
||||
Nd4j.exec(DynamicCustomOp.builder("scatter_upd")
|
||||
.addInputs(x, indices, updates)
|
||||
.addOutputs(out)
|
||||
.build());
|
||||
|
||||
assertEquals(exp, out);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGatherGradient() {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
@ -1688,4 +1709,41 @@ public class MiscOpValidation extends BaseOpValidation {
|
|||
|
||||
Nd4j.getExecutioner().exec(op);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testHistogramFixedWidth(){
|
||||
//Bins: [-inf, 0.2), [0.2, 0.4), [0.4, 0.6), [0.6, 0.8), [0.8, inf]
|
||||
INDArray in = Nd4j.createFromArray(0.0, 0.1, 0.1, 0.3, 0.5, 0.5, 0.9);
|
||||
INDArray range = Nd4j.createFromArray(0.0, 1.0);
|
||||
INDArray n = Nd4j.scalar(5);
|
||||
|
||||
INDArray out = Nd4j.create(DataType.INT, 5);
|
||||
|
||||
Nd4j.exec(DynamicCustomOp.builder("histogram_fixed_width")
|
||||
.addInputs(in, range, n)
|
||||
.addOutputs(out)
|
||||
.build());
|
||||
|
||||
INDArray exp = Nd4j.createFromArray(3, 1, 2, 0, 1);
|
||||
assertEquals(exp, out);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testListDiff(){
|
||||
INDArray x = Nd4j.createFromArray(0, 1, 2, 3);
|
||||
INDArray y = Nd4j.createFromArray(3, 1);
|
||||
|
||||
INDArray out = Nd4j.create(DataType.INT, 2);
|
||||
INDArray outIdx = Nd4j.create(DataType.INT, 2);
|
||||
|
||||
Nd4j.exec(DynamicCustomOp.builder("listdiff")
|
||||
.addInputs(x, y)
|
||||
.addOutputs(out, outIdx)
|
||||
.build());
|
||||
|
||||
INDArray exp = Nd4j.createFromArray(0, 2);
|
||||
|
||||
assertEquals(exp, out); //Values in x not in y
|
||||
assertEquals(exp, outIdx); //Indices of the values in x not in y
|
||||
}
|
||||
}
|
||||
|
|
|
@ -27,6 +27,7 @@ import org.nd4j.autodiff.validation.TestCase;
|
|||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.iter.NdIndexIterator;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.bool.All;
|
||||
import org.nd4j.linalg.api.ops.random.custom.RandomBernoulli;
|
||||
import org.nd4j.linalg.api.ops.random.custom.RandomExponential;
|
||||
import org.nd4j.linalg.api.ops.random.impl.BinomialDistribution;
|
||||
|
@ -371,6 +372,14 @@ public class RandomOpValidation extends BaseOpValidation {
|
|||
|
||||
assertNull(OpValidation.validate(tc));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAllEmptyReduce(){
|
||||
INDArray x = Nd4j.createFromArray(true, true, true);
|
||||
All all = new All(x);
|
||||
all.setEmptyReduce(true); //For TF compatibility - empty array for axis (which means no-op - and NOT all array reduction)
|
||||
INDArray out = Nd4j.exec(all);
|
||||
assertEquals(x, out);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1342,6 +1342,26 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
assertEquals(failed.toString(), 0, failed.size());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSegmentMean(){
|
||||
INDArray x = Nd4j.linspace(DataType.FLOAT, 1, 18, 1).reshape(6, 3);
|
||||
INDArray segmentIds = Nd4j.createFromArray(0, 0, 1, 1, 2, 2);
|
||||
|
||||
INDArray out = Nd4j.create(DataType.FLOAT, 3, 3);
|
||||
|
||||
Nd4j.exec(DynamicCustomOp.builder("segment_mean")
|
||||
.addInputs(x, segmentIds)
|
||||
.addOutputs(out)
|
||||
.build());
|
||||
|
||||
INDArray exp = out.like();
|
||||
exp.putRow(0, x.getRow(0).add(x.getRow(1)).muli(0.5));
|
||||
exp.putRow(1, x.getRow(2).add(x.getRow(3)).muli(0.5));
|
||||
exp.putRow(2, x.getRow(4).add(x.getRow(5)).muli(0.5));
|
||||
|
||||
assertEquals(exp, out);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSequenceMask() {
|
||||
OpValidationSuite.ignoreFailing(); //2018-01-09: output datatype issue?
|
||||
|
|
|
@ -415,6 +415,24 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
assertNull(err, err);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDynamicPartition2(){
|
||||
INDArray data = Nd4j.createFromArray(2, 1, 2, 0);
|
||||
INDArray partitions = Nd4j.createFromArray(0, 2, 1, 0);
|
||||
INDArray[] out = Nd4j.exec(DynamicCustomOp.builder("dynamic_partition")
|
||||
.addOutputs(Nd4j.createUninitialized(DataType.INT, 2), Nd4j.createUninitialized(DataType.INT, 1), Nd4j.createUninitialized(DataType.INT, 1))
|
||||
.addIntegerArguments(3) //3 partitions
|
||||
.addInputs(data, partitions).build());
|
||||
|
||||
INDArray exp0 = Nd4j.createFromArray(2, 0);
|
||||
INDArray exp1 = Nd4j.createFromArray(2);
|
||||
INDArray exp2 = Nd4j.createFromArray(1);
|
||||
|
||||
assertEquals(exp0, out[0]); //Usually just gives [0,0]
|
||||
assertEquals(exp1, out[1]);
|
||||
assertEquals(exp2, out[2]);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testDynamicStitch() {
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
@ -1612,6 +1630,27 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testTopK1(){
|
||||
INDArray x = Nd4j.createFromArray(0.0, 0.0, 0.0, 10.0, 0.0);
|
||||
INDArray k = Nd4j.scalar(1);
|
||||
INDArray outValue = Nd4j.create(DataType.DOUBLE, 1);
|
||||
INDArray outIdx = Nd4j.create(DataType.INT, 1);
|
||||
|
||||
Nd4j.exec(DynamicCustomOp.builder("top_k")
|
||||
.addInputs(x, k)
|
||||
.addOutputs(outValue, outIdx)
|
||||
.addBooleanArguments(false) //not sorted
|
||||
.addIntegerArguments(1)
|
||||
.build());
|
||||
|
||||
INDArray expValue = Nd4j.createFromArray(10.0);
|
||||
INDArray expIdx = Nd4j.createFromArray(3);
|
||||
|
||||
assertEquals(expValue, outValue);
|
||||
assertEquals(expIdx, outIdx);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testInTopK() {
|
||||
for( int k=4; k>= 1; k--){
|
||||
|
|
Loading…
Reference in New Issue