Ops exported for sameDiff

master
Alexander Stoyakin 2019-10-16 19:16:47 +03:00
parent 96a9a1a733
commit 99d77e1384
9 changed files with 80 additions and 0 deletions

View File

@ -33,6 +33,7 @@ import org.nd4j.linalg.api.blas.params.MMulTranspose;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.NoOp; import org.nd4j.linalg.api.ops.NoOp;
import org.nd4j.linalg.api.ops.custom.*;
import org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd; import org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd;
import org.nd4j.linalg.api.ops.impl.broadcast.BiasAddGrad; import org.nd4j.linalg.api.ops.impl.broadcast.BiasAddGrad;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter; import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
@ -2649,6 +2650,33 @@ public class DifferentialFunctionFactory {
return new NextIteration(sameDiff, x).outputVariable(); return new NextIteration(sameDiff, x).outputVariable();
} }
public SDVariable adjustContrast(SDVariable in, SDVariable factor) {
return new AdjustContrast(sameDiff, in, factor).outputVariable();
}
public SDVariable adjustContrastV2(SDVariable in, SDVariable factor) {
return new AdjustContrastV2(sameDiff, in, factor).outputVariable();
}
public SDVariable bitCast(SDVariable in, SDVariable dataType) {
return new BitCast(sameDiff, in, dataType).outputVariable();
}
public SDVariable compareAndBitpack(SDVariable threshold) {
return new CompareAndBitpack(sameDiff, threshold).outputVariable();
}
public SDVariable divideNoNan(SDVariable in1, SDVariable in2) {
return new DivideNoNan(sameDiff, in1, in2).outputVariable();
}
public SDVariable drawBoundingBoxes(SDVariable boxes, SDVariable colors) {
return new DrawBoundingBoxes(sameDiff, boxes, colors).outputVariable();
}
public SDVariable fakeQuantWithMinMaxVarsPerChannel(SDVariable x, SDVariable min, SDVariable max) {
return new FakeQuantWithMinMaxVarsPerChannel(sameDiff,x,min,max).outputVariable();
}
public String toString() { public String toString() {
return "DifferentialFunctionFactory{methodNames=" + methodNames + "}"; return "DifferentialFunctionFactory{methodNames=" + methodNames + "}";

View File

@ -1,5 +1,7 @@
package org.nd4j.linalg.api.ops.custom; package org.nd4j.linalg.api.ops.custom;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
@ -12,6 +14,10 @@ public class AdjustContrast extends BaseAdjustContrast {
super(in, factor, out); super(in, factor, out);
} }
public AdjustContrast(SameDiff sameDiff, SDVariable in, SDVariable factor) {
super(sameDiff,new SDVariable[]{in,factor});
}
@Override @Override
public String opName() { public String opName() {
return "adjust_contrast"; return "adjust_contrast";

View File

@ -1,5 +1,7 @@
package org.nd4j.linalg.api.ops.custom; package org.nd4j.linalg.api.ops.custom;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
@ -12,6 +14,10 @@ public class AdjustContrastV2 extends BaseAdjustContrast {
super(in, factor, out); super(in, factor, out);
} }
public AdjustContrastV2(SameDiff sameDiff, SDVariable in, SDVariable factor) {
super( sameDiff,new SDVariable[]{in,factor});
}
@Override @Override
public String opName() { public String opName() {
return "adjust_contrast_v2"; return "adjust_contrast_v2";

View File

@ -1,5 +1,7 @@
package org.nd4j.linalg.api.ops.custom; package org.nd4j.linalg.api.ops.custom;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
@ -16,4 +18,12 @@ public abstract class BaseAdjustContrast extends DynamicCustomOp {
addTArgument(factor); addTArgument(factor);
} }
public BaseAdjustContrast(SameDiff sameDiff, SDVariable[] vars) {
super("", sameDiff, vars);
}
public BaseAdjustContrast(SameDiff sameDiff, SDVariable in, SDVariable factor, SDVariable out) {
super(null, sameDiff, new SDVariable[]{in, factor, out});
}
} }

View File

@ -1,5 +1,7 @@
package org.nd4j.linalg.api.ops.custom; package org.nd4j.linalg.api.ops.custom;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
@ -14,6 +16,10 @@ public class BitCast extends DynamicCustomOp {
iArguments.add(Long.valueOf(dataType)); iArguments.add(Long.valueOf(dataType));
} }
public BitCast(SameDiff sameDiff, SDVariable in, SDVariable dataType) {
super("", sameDiff, new SDVariable[]{in, dataType});
}
@Override @Override
public String opName() { public String opName() {
return "bitcast"; return "bitcast";

View File

@ -1,5 +1,7 @@
package org.nd4j.linalg.api.ops.custom; package org.nd4j.linalg.api.ops.custom;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -13,6 +15,10 @@ public class CompareAndBitpack extends DynamicCustomOp {
outputArguments.add(out); outputArguments.add(out);
} }
public CompareAndBitpack(SameDiff sameDiff, SDVariable threshold) {
super("", sameDiff, new SDVariable[]{threshold});
}
@Override @Override
public String opName() { public String opName() {
return "compare_and_bitpack"; return "compare_and_bitpack";

View File

@ -1,6 +1,8 @@
package org.nd4j.linalg.api.ops.custom; package org.nd4j.linalg.api.ops.custom;
import org.apache.commons.math3.analysis.function.Divide; import org.apache.commons.math3.analysis.function.Divide;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
@ -14,6 +16,10 @@ public class DivideNoNan extends DynamicCustomOp {
outputArguments.add(out); outputArguments.add(out);
} }
public DivideNoNan(SameDiff sameDiff, SDVariable in1, SDVariable in2) {
super("", sameDiff, new SDVariable[]{in1, in2});
}
@Override @Override
public String opName() { public String opName() {
return "divide_no_nan"; return "divide_no_nan";

View File

@ -1,5 +1,7 @@
package org.nd4j.linalg.api.ops.custom; package org.nd4j.linalg.api.ops.custom;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
@ -14,6 +16,10 @@ public class DrawBoundingBoxes extends DynamicCustomOp {
outputArguments.add(output); outputArguments.add(output);
} }
public DrawBoundingBoxes(SameDiff sameDiff, SDVariable boxes, SDVariable colors) {
super("", sameDiff, new SDVariable[]{boxes, colors});
}
@Override @Override
public String opName() { public String opName() {
return "draw_bounding_boxes"; return "draw_bounding_boxes";

View File

@ -1,5 +1,7 @@
package org.nd4j.linalg.api.ops.custom; package org.nd4j.linalg.api.ops.custom;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
@ -18,6 +20,10 @@ public class FakeQuantWithMinMaxVarsPerChannel extends DynamicCustomOp {
outputArguments.add(output); outputArguments.add(output);
} }
public FakeQuantWithMinMaxVarsPerChannel(SameDiff sameDiff, SDVariable x, SDVariable min, SDVariable max) {
super("", sameDiff, new SDVariable[]{x, min, max});
}
@Override @Override
public String opName() { public String opName() {
return "fake_quant_with_min_max_vars_per_channel"; return "fake_quant_with_min_max_vars_per_channel";