diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java index 0ad489cbf..f8e4827ce 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java @@ -2200,7 +2200,6 @@ public class DifferentialFunctionFactory { public SDVariable muli(SDVariable differentialFunction, SDVariable i_v) { validateDifferentialFunctionsameDiff(differentialFunction); return new MulOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, true).outputVariable(); - } public SDVariable mod(SDVariable differentialFunction, SDVariable i_v) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index c7f8cbd64..4c6ce710d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -452,6 +452,7 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorModOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp.class, + org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldAddOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldAtan2Op.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldDivOp.class, diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java index f4984679a..becd05aa7 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java @@ -3565,6 +3565,17 @@ public class SameDiffTests extends BaseNd4jTest { SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false)); assertEquals(115, SD.exec(null, outName).get(outName).getInt(0)); + } + @Test + public void testMod_1(){ + val sd = SameDiff.create(); + val initial = sd.constant("initial", Nd4j.createFromArray(5.f, 6.f, 7.f)); + val four = sd.constant("four", 4.0f); + val mod = initial.mod("mod", four); + + val e = Nd4j.createFromArray(1.f, 2.f, 3.f); + + assertEquals(e, mod.eval()); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java index 57f2e75d9..675f60007 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java @@ -33,6 +33,7 @@ import org.nd4j.linalg.api.ops.executioner.OpStatus; import org.nd4j.linalg.api.ops.impl.reduce.Mmul; import org.nd4j.linalg.api.ops.impl.reduce.MmulBp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp; import org.nd4j.linalg.api.ops.random.compat.RandomStandardNormal; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.api.shape.Shape; @@ -569,6 +570,17 @@ public class CustomOpsTests extends BaseNd4jTest { } } + @Test + public void testMod_1() { + val x = Nd4j.createFromArray(5.f, 6.f, 7.f); + val y = Nd4j.scalar(4.f); + val e = Nd4j.createFromArray(1.f, 2.f, 3.f); + + val z = Nd4j.exec(new ModOp(new INDArray[]{x, y}, new INDArray[]{}))[0]; + + assertEquals(e, z); + } + @Test public void testScalarVector_edge_1() { val x = Nd4j.scalar(2.0f);