From 99d77e138412309eb8411b8d91ec7613950ecd7c Mon Sep 17 00:00:00 2001 From: Alexander Stoyakin Date: Wed, 16 Oct 2019 19:16:47 +0300 Subject: [PATCH] Ops exported for sameDiff --- .../DifferentialFunctionFactory.java | 28 +++++++++++++++++++ .../linalg/api/ops/custom/AdjustContrast.java | 6 ++++ .../api/ops/custom/AdjustContrastV2.java | 6 ++++ .../api/ops/custom/BaseAdjustContrast.java | 10 +++++++ .../nd4j/linalg/api/ops/custom/BitCast.java | 6 ++++ .../api/ops/custom/CompareAndBitpack.java | 6 ++++ .../linalg/api/ops/custom/DivideNoNan.java | 6 ++++ .../api/ops/custom/DrawBoundingBoxes.java | 6 ++++ .../FakeQuantWithMinMaxVarsPerChannel.java | 6 ++++ 9 files changed, 80 insertions(+) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java index 1a40fbd11..4b042dded 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java @@ -33,6 +33,7 @@ import org.nd4j.linalg.api.blas.params.MMulTranspose; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.NoOp; +import org.nd4j.linalg.api.ops.custom.*; import org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd; import org.nd4j.linalg.api.ops.impl.broadcast.BiasAddGrad; import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter; @@ -2649,6 +2650,33 @@ public class DifferentialFunctionFactory { return new NextIteration(sameDiff, x).outputVariable(); } + public SDVariable adjustContrast(SDVariable in, SDVariable factor) { + return new AdjustContrast(sameDiff, in, factor).outputVariable(); + } + + public SDVariable adjustContrastV2(SDVariable in, SDVariable factor) { + return new AdjustContrastV2(sameDiff, in, factor).outputVariable(); + } + + public SDVariable bitCast(SDVariable in, SDVariable dataType) { + return new BitCast(sameDiff, in, dataType).outputVariable(); + } + + public SDVariable compareAndBitpack(SDVariable threshold) { + return new CompareAndBitpack(sameDiff, threshold).outputVariable(); + } + + public SDVariable divideNoNan(SDVariable in1, SDVariable in2) { + return new DivideNoNan(sameDiff, in1, in2).outputVariable(); + } + + public SDVariable drawBoundingBoxes(SDVariable boxes, SDVariable colors) { + return new DrawBoundingBoxes(sameDiff, boxes, colors).outputVariable(); + } + + public SDVariable fakeQuantWithMinMaxVarsPerChannel(SDVariable x, SDVariable min, SDVariable max) { + return new FakeQuantWithMinMaxVarsPerChannel(sameDiff,x,min,max).outputVariable(); + } public String toString() { return "DifferentialFunctionFactory{methodNames=" + methodNames + "}"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java index aad384a26..181b1657d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrast.java @@ -1,5 +1,7 @@ package org.nd4j.linalg.api.ops.custom; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; @@ -12,6 +14,10 @@ public class AdjustContrast extends BaseAdjustContrast { super(in, factor, out); } + public AdjustContrast(SameDiff sameDiff, SDVariable in, SDVariable factor) { + super(sameDiff,new SDVariable[]{in,factor}); + } + @Override public String opName() { return "adjust_contrast"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrastV2.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrastV2.java index 4be4ae098..74359da7f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrastV2.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/AdjustContrastV2.java @@ -1,5 +1,7 @@ package org.nd4j.linalg.api.ops.custom; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; @@ -12,6 +14,10 @@ public class AdjustContrastV2 extends BaseAdjustContrast { super(in, factor, out); } + public AdjustContrastV2(SameDiff sameDiff, SDVariable in, SDVariable factor) { + super( sameDiff,new SDVariable[]{in,factor}); + } + @Override public String opName() { return "adjust_contrast_v2"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BaseAdjustContrast.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BaseAdjustContrast.java index 7057118c5..fe14fe69c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BaseAdjustContrast.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BaseAdjustContrast.java @@ -1,5 +1,7 @@ package org.nd4j.linalg.api.ops.custom; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; @@ -16,4 +18,12 @@ public abstract class BaseAdjustContrast extends DynamicCustomOp { addTArgument(factor); } + + public BaseAdjustContrast(SameDiff sameDiff, SDVariable[] vars) { + super("", sameDiff, vars); + } + + public BaseAdjustContrast(SameDiff sameDiff, SDVariable in, SDVariable factor, SDVariable out) { + super(null, sameDiff, new SDVariable[]{in, factor, out}); + } } \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BitCast.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BitCast.java index 7a1f125c6..fbfad0305 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BitCast.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/BitCast.java @@ -1,5 +1,7 @@ package org.nd4j.linalg.api.ops.custom; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; @@ -14,6 +16,10 @@ public class BitCast extends DynamicCustomOp { iArguments.add(Long.valueOf(dataType)); } + public BitCast(SameDiff sameDiff, SDVariable in, SDVariable dataType) { + super("", sameDiff, new SDVariable[]{in, dataType}); + } + @Override public String opName() { return "bitcast"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/CompareAndBitpack.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/CompareAndBitpack.java index 4f0aad2ee..eb0762f0f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/CompareAndBitpack.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/CompareAndBitpack.java @@ -1,5 +1,7 @@ package org.nd4j.linalg.api.ops.custom; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.factory.Nd4j; @@ -13,6 +15,10 @@ public class CompareAndBitpack extends DynamicCustomOp { outputArguments.add(out); } + public CompareAndBitpack(SameDiff sameDiff, SDVariable threshold) { + super("", sameDiff, new SDVariable[]{threshold}); + } + @Override public String opName() { return "compare_and_bitpack"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DivideNoNan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DivideNoNan.java index b2eafb791..ce67b14f9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DivideNoNan.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DivideNoNan.java @@ -1,6 +1,8 @@ package org.nd4j.linalg.api.ops.custom; import org.apache.commons.math3.analysis.function.Divide; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; @@ -14,6 +16,10 @@ public class DivideNoNan extends DynamicCustomOp { outputArguments.add(out); } + public DivideNoNan(SameDiff sameDiff, SDVariable in1, SDVariable in2) { + super("", sameDiff, new SDVariable[]{in1, in2}); + } + @Override public String opName() { return "divide_no_nan"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DrawBoundingBoxes.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DrawBoundingBoxes.java index c6cf04b62..2ac6e6458 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DrawBoundingBoxes.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/DrawBoundingBoxes.java @@ -1,5 +1,7 @@ package org.nd4j.linalg.api.ops.custom; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; @@ -14,6 +16,10 @@ public class DrawBoundingBoxes extends DynamicCustomOp { outputArguments.add(output); } + public DrawBoundingBoxes(SameDiff sameDiff, SDVariable boxes, SDVariable colors) { + super("", sameDiff, new SDVariable[]{boxes, colors}); + } + @Override public String opName() { return "draw_bounding_boxes"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FakeQuantWithMinMaxVarsPerChannel.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FakeQuantWithMinMaxVarsPerChannel.java index 3bdcf6dd3..2043732d4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FakeQuantWithMinMaxVarsPerChannel.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/FakeQuantWithMinMaxVarsPerChannel.java @@ -1,5 +1,7 @@ package org.nd4j.linalg.api.ops.custom; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; @@ -18,6 +20,10 @@ public class FakeQuantWithMinMaxVarsPerChannel extends DynamicCustomOp { outputArguments.add(output); } + public FakeQuantWithMinMaxVarsPerChannel(SameDiff sameDiff, SDVariable x, SDVariable min, SDVariable max) { + super("", sameDiff, new SDVariable[]{x, min, max}); + } + @Override public String opName() { return "fake_quant_with_min_max_vars_per_channel";