- 2 mod tests

- ModOp mapping added

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-08-19 12:57:24 +03:00
parent 01cb57041a
commit b8ab1a00b0
4 changed files with 24 additions and 1 deletions

View File

@ -2200,7 +2200,6 @@ public class DifferentialFunctionFactory {
public SDVariable muli(SDVariable differentialFunction, SDVariable i_v) { public SDVariable muli(SDVariable differentialFunction, SDVariable i_v) {
validateDifferentialFunctionsameDiff(differentialFunction); validateDifferentialFunctionsameDiff(differentialFunction);
return new MulOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, true).outputVariable(); return new MulOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, true).outputVariable();
} }
public SDVariable mod(SDVariable differentialFunction, SDVariable i_v) { public SDVariable mod(SDVariable differentialFunction, SDVariable i_v) {

View File

@ -452,6 +452,7 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorModOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorModOp.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldAddOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldAddOp.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldAtan2Op.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldAtan2Op.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldDivOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldDivOp.class,

View File

@ -3565,6 +3565,17 @@ public class SameDiffTests extends BaseNd4jTest {
SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false)); SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false));
assertEquals(115, SD.exec(null, outName).get(outName).getInt(0)); assertEquals(115, SD.exec(null, outName).get(outName).getInt(0));
}
@Test
public void testMod_1(){
val sd = SameDiff.create();
val initial = sd.constant("initial", Nd4j.createFromArray(5.f, 6.f, 7.f));
val four = sd.constant("four", 4.0f);
val mod = initial.mod("mod", four);
val e = Nd4j.createFromArray(1.f, 2.f, 3.f);
assertEquals(e, mod.eval());
} }
} }

View File

@ -33,6 +33,7 @@ import org.nd4j.linalg.api.ops.executioner.OpStatus;
import org.nd4j.linalg.api.ops.impl.reduce.Mmul; import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
import org.nd4j.linalg.api.ops.impl.reduce.MmulBp; import org.nd4j.linalg.api.ops.impl.reduce.MmulBp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp;
import org.nd4j.linalg.api.ops.random.compat.RandomStandardNormal; import org.nd4j.linalg.api.ops.random.compat.RandomStandardNormal;
import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.api.shape.Shape;
@ -569,6 +570,17 @@ public class CustomOpsTests extends BaseNd4jTest {
} }
} }
@Test
public void testMod_1() {
val x = Nd4j.createFromArray(5.f, 6.f, 7.f);
val y = Nd4j.scalar(4.f);
val e = Nd4j.createFromArray(1.f, 2.f, 3.f);
val z = Nd4j.exec(new ModOp(new INDArray[]{x, y}, new INDArray[]{}))[0];
assertEquals(e, z);
}
@Test @Test
public void testScalarVector_edge_1() { public void testScalarVector_edge_1() {
val x = Nd4j.scalar(2.0f); val x = Nd4j.scalar(2.0f);