Ops exported for sameDiff
parent
96a9a1a733
commit
99d77e1384
|
@ -33,6 +33,7 @@ import org.nd4j.linalg.api.blas.params.MMulTranspose;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.NoOp;
|
import org.nd4j.linalg.api.ops.NoOp;
|
||||||
|
import org.nd4j.linalg.api.ops.custom.*;
|
||||||
import org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd;
|
import org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd;
|
||||||
import org.nd4j.linalg.api.ops.impl.broadcast.BiasAddGrad;
|
import org.nd4j.linalg.api.ops.impl.broadcast.BiasAddGrad;
|
||||||
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
|
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
|
||||||
|
@ -2649,6 +2650,33 @@ public class DifferentialFunctionFactory {
|
||||||
return new NextIteration(sameDiff, x).outputVariable();
|
return new NextIteration(sameDiff, x).outputVariable();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public SDVariable adjustContrast(SDVariable in, SDVariable factor) {
|
||||||
|
return new AdjustContrast(sameDiff, in, factor).outputVariable();
|
||||||
|
}
|
||||||
|
|
||||||
|
public SDVariable adjustContrastV2(SDVariable in, SDVariable factor) {
|
||||||
|
return new AdjustContrastV2(sameDiff, in, factor).outputVariable();
|
||||||
|
}
|
||||||
|
|
||||||
|
public SDVariable bitCast(SDVariable in, SDVariable dataType) {
|
||||||
|
return new BitCast(sameDiff, in, dataType).outputVariable();
|
||||||
|
}
|
||||||
|
|
||||||
|
public SDVariable compareAndBitpack(SDVariable threshold) {
|
||||||
|
return new CompareAndBitpack(sameDiff, threshold).outputVariable();
|
||||||
|
}
|
||||||
|
|
||||||
|
public SDVariable divideNoNan(SDVariable in1, SDVariable in2) {
|
||||||
|
return new DivideNoNan(sameDiff, in1, in2).outputVariable();
|
||||||
|
}
|
||||||
|
|
||||||
|
public SDVariable drawBoundingBoxes(SDVariable boxes, SDVariable colors) {
|
||||||
|
return new DrawBoundingBoxes(sameDiff, boxes, colors).outputVariable();
|
||||||
|
}
|
||||||
|
|
||||||
|
public SDVariable fakeQuantWithMinMaxVarsPerChannel(SDVariable x, SDVariable min, SDVariable max) {
|
||||||
|
return new FakeQuantWithMinMaxVarsPerChannel(sameDiff,x,min,max).outputVariable();
|
||||||
|
}
|
||||||
|
|
||||||
public String toString() {
|
public String toString() {
|
||||||
return "DifferentialFunctionFactory{methodNames=" + methodNames + "}";
|
return "DifferentialFunctionFactory{methodNames=" + methodNames + "}";
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
package org.nd4j.linalg.api.ops.custom;
|
package org.nd4j.linalg.api.ops.custom;
|
||||||
|
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
@ -12,6 +14,10 @@ public class AdjustContrast extends BaseAdjustContrast {
|
||||||
super(in, factor, out);
|
super(in, factor, out);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public AdjustContrast(SameDiff sameDiff, SDVariable in, SDVariable factor) {
|
||||||
|
super(sameDiff,new SDVariable[]{in,factor});
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName() {
|
public String opName() {
|
||||||
return "adjust_contrast";
|
return "adjust_contrast";
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
package org.nd4j.linalg.api.ops.custom;
|
package org.nd4j.linalg.api.ops.custom;
|
||||||
|
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
@ -12,6 +14,10 @@ public class AdjustContrastV2 extends BaseAdjustContrast {
|
||||||
super(in, factor, out);
|
super(in, factor, out);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public AdjustContrastV2(SameDiff sameDiff, SDVariable in, SDVariable factor) {
|
||||||
|
super( sameDiff,new SDVariable[]{in,factor});
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName() {
|
public String opName() {
|
||||||
return "adjust_contrast_v2";
|
return "adjust_contrast_v2";
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
package org.nd4j.linalg.api.ops.custom;
|
package org.nd4j.linalg.api.ops.custom;
|
||||||
|
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
@ -16,4 +18,12 @@ public abstract class BaseAdjustContrast extends DynamicCustomOp {
|
||||||
|
|
||||||
addTArgument(factor);
|
addTArgument(factor);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public BaseAdjustContrast(SameDiff sameDiff, SDVariable[] vars) {
|
||||||
|
super("", sameDiff, vars);
|
||||||
|
}
|
||||||
|
|
||||||
|
public BaseAdjustContrast(SameDiff sameDiff, SDVariable in, SDVariable factor, SDVariable out) {
|
||||||
|
super(null, sameDiff, new SDVariable[]{in, factor, out});
|
||||||
|
}
|
||||||
}
|
}
|
|
@ -1,5 +1,7 @@
|
||||||
package org.nd4j.linalg.api.ops.custom;
|
package org.nd4j.linalg.api.ops.custom;
|
||||||
|
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
@ -14,6 +16,10 @@ public class BitCast extends DynamicCustomOp {
|
||||||
iArguments.add(Long.valueOf(dataType));
|
iArguments.add(Long.valueOf(dataType));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public BitCast(SameDiff sameDiff, SDVariable in, SDVariable dataType) {
|
||||||
|
super("", sameDiff, new SDVariable[]{in, dataType});
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName() {
|
public String opName() {
|
||||||
return "bitcast";
|
return "bitcast";
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
package org.nd4j.linalg.api.ops.custom;
|
package org.nd4j.linalg.api.ops.custom;
|
||||||
|
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
@ -13,6 +15,10 @@ public class CompareAndBitpack extends DynamicCustomOp {
|
||||||
outputArguments.add(out);
|
outputArguments.add(out);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public CompareAndBitpack(SameDiff sameDiff, SDVariable threshold) {
|
||||||
|
super("", sameDiff, new SDVariable[]{threshold});
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName() {
|
public String opName() {
|
||||||
return "compare_and_bitpack";
|
return "compare_and_bitpack";
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
package org.nd4j.linalg.api.ops.custom;
|
package org.nd4j.linalg.api.ops.custom;
|
||||||
|
|
||||||
import org.apache.commons.math3.analysis.function.Divide;
|
import org.apache.commons.math3.analysis.function.Divide;
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
|
@ -14,6 +16,10 @@ public class DivideNoNan extends DynamicCustomOp {
|
||||||
outputArguments.add(out);
|
outputArguments.add(out);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public DivideNoNan(SameDiff sameDiff, SDVariable in1, SDVariable in2) {
|
||||||
|
super("", sameDiff, new SDVariable[]{in1, in2});
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName() {
|
public String opName() {
|
||||||
return "divide_no_nan";
|
return "divide_no_nan";
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
package org.nd4j.linalg.api.ops.custom;
|
package org.nd4j.linalg.api.ops.custom;
|
||||||
|
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
|
||||||
|
@ -14,6 +16,10 @@ public class DrawBoundingBoxes extends DynamicCustomOp {
|
||||||
outputArguments.add(output);
|
outputArguments.add(output);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public DrawBoundingBoxes(SameDiff sameDiff, SDVariable boxes, SDVariable colors) {
|
||||||
|
super("", sameDiff, new SDVariable[]{boxes, colors});
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName() {
|
public String opName() {
|
||||||
return "draw_bounding_boxes";
|
return "draw_bounding_boxes";
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
package org.nd4j.linalg.api.ops.custom;
|
package org.nd4j.linalg.api.ops.custom;
|
||||||
|
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
@ -18,6 +20,10 @@ public class FakeQuantWithMinMaxVarsPerChannel extends DynamicCustomOp {
|
||||||
outputArguments.add(output);
|
outputArguments.add(output);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public FakeQuantWithMinMaxVarsPerChannel(SameDiff sameDiff, SDVariable x, SDVariable min, SDVariable max) {
|
||||||
|
super("", sameDiff, new SDVariable[]{x, min, max});
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName() {
|
public String opName() {
|
||||||
return "fake_quant_with_min_max_vars_per_channel";
|
return "fake_quant_with_min_max_vars_per_channel";
|
||||||
|
|
Loading…
Reference in New Issue