- 2 mod tests
- ModOp mapping added Signed-off-by: raver119 <raver119@gmail.com>master
parent
01cb57041a
commit
b8ab1a00b0
|
@ -2200,7 +2200,6 @@ public class DifferentialFunctionFactory {
|
|||
public SDVariable muli(SDVariable differentialFunction, SDVariable i_v) {
|
||||
validateDifferentialFunctionsameDiff(differentialFunction);
|
||||
return new MulOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, true).outputVariable();
|
||||
|
||||
}
|
||||
|
||||
public SDVariable mod(SDVariable differentialFunction, SDVariable i_v) {
|
||||
|
|
|
@ -452,6 +452,7 @@ public class ImportClassMapping {
|
|||
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorModOp.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldAddOp.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldAtan2Op.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldDivOp.class,
|
||||
|
|
|
@ -3565,6 +3565,17 @@ public class SameDiffTests extends BaseNd4jTest {
|
|||
SD = SameDiff.fromFlatBuffers(SD.asFlatBuffers(false));
|
||||
|
||||
assertEquals(115, SD.exec(null, outName).get(outName).getInt(0));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMod_1(){
|
||||
val sd = SameDiff.create();
|
||||
val initial = sd.constant("initial", Nd4j.createFromArray(5.f, 6.f, 7.f));
|
||||
val four = sd.constant("four", 4.0f);
|
||||
val mod = initial.mod("mod", four);
|
||||
|
||||
val e = Nd4j.createFromArray(1.f, 2.f, 3.f);
|
||||
|
||||
assertEquals(e, mod.eval());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -33,6 +33,7 @@ import org.nd4j.linalg.api.ops.executioner.OpStatus;
|
|||
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.MmulBp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp;
|
||||
import org.nd4j.linalg.api.ops.random.compat.RandomStandardNormal;
|
||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
||||
import org.nd4j.linalg.api.shape.Shape;
|
||||
|
@ -569,6 +570,17 @@ public class CustomOpsTests extends BaseNd4jTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMod_1() {
|
||||
val x = Nd4j.createFromArray(5.f, 6.f, 7.f);
|
||||
val y = Nd4j.scalar(4.f);
|
||||
val e = Nd4j.createFromArray(1.f, 2.f, 3.f);
|
||||
|
||||
val z = Nd4j.exec(new ModOp(new INDArray[]{x, y}, new INDArray[]{}))[0];
|
||||
|
||||
assertEquals(e, z);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testScalarVector_edge_1() {
|
||||
val x = Nd4j.scalar(2.0f);
|
||||
|
|
Loading…
Reference in New Issue