- 2 mod tests

- ModOp mapping added

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-08-19 12:57:24 +03:00
parent 01cb57041a
commit b8ab1a00b0
4 changed files with 24 additions and 1 deletions

View File

@ -2200,7 +2200,6 @@ public class DifferentialFunctionFactory {
public SDVariable muli(SDVariable differentialFunction, SDVariable i_v) {
validateDifferentialFunctionsameDiff(differentialFunction);
return new MulOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, true).outputVariable();
}
public SDVariable mod(SDVariable differentialFunction, SDVariable i_v) {

View File

@ -452,6 +452,7 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorModOp.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldAddOp.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldAtan2Op.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldDivOp.class,

View File

@ -3565,6 +3565,17 @@ public class SameDiffTests extends BaseNd4jTest {
SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false));
assertEquals(115, SD.exec(null, outName).get(outName).getInt(0));
}
@Test
public void testMod_1(){
val sd = SameDiff.create();
val initial = sd.constant("initial", Nd4j.createFromArray(5.f, 6.f, 7.f));
val four = sd.constant("four", 4.0f);
val mod = initial.mod("mod", four);
val e = Nd4j.createFromArray(1.f, 2.f, 3.f);
assertEquals(e, mod.eval());
}
}

View File

@ -33,6 +33,7 @@ import org.nd4j.linalg.api.ops.executioner.OpStatus;
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
import org.nd4j.linalg.api.ops.impl.reduce.MmulBp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp;
import org.nd4j.linalg.api.ops.random.compat.RandomStandardNormal;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.Shape;
@ -569,6 +570,17 @@ public class CustomOpsTests extends BaseNd4jTest {
}
}
@Test
public void testMod_1() {
val x = Nd4j.createFromArray(5.f, 6.f, 7.f);
val y = Nd4j.scalar(4.f);
val e = Nd4j.createFromArray(1.f, 2.f, 3.f);
val z = Nd4j.exec(new ModOp(new INDArray[]{x, y}, new INDArray[]{}))[0];
assertEquals(e, z);
}
@Test
public void testScalarVector_edge_1() {
val x = Nd4j.scalar(2.0f);