Ops exported for sameDiff
This commit is contained in:
		
							parent
							
								
									96a9a1a733
								
							
						
					
					
						commit
						99d77e1384
					
				| @ -33,6 +33,7 @@ import org.nd4j.linalg.api.blas.params.MMulTranspose; | |||||||
| import org.nd4j.linalg.api.buffer.DataType; | import org.nd4j.linalg.api.buffer.DataType; | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
| import org.nd4j.linalg.api.ops.NoOp; | import org.nd4j.linalg.api.ops.NoOp; | ||||||
|  | import org.nd4j.linalg.api.ops.custom.*; | ||||||
| import org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd; | import org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd; | ||||||
| import org.nd4j.linalg.api.ops.impl.broadcast.BiasAddGrad; | import org.nd4j.linalg.api.ops.impl.broadcast.BiasAddGrad; | ||||||
| import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter; | import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter; | ||||||
| @ -2649,6 +2650,33 @@ public class DifferentialFunctionFactory { | |||||||
|         return new NextIteration(sameDiff, x).outputVariable(); |         return new NextIteration(sameDiff, x).outputVariable(); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     public SDVariable adjustContrast(SDVariable in, SDVariable factor) { | ||||||
|  |         return new AdjustContrast(sameDiff, in, factor).outputVariable(); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     public SDVariable adjustContrastV2(SDVariable in, SDVariable factor) { | ||||||
|  |         return new AdjustContrastV2(sameDiff, in, factor).outputVariable(); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     public SDVariable bitCast(SDVariable in, SDVariable dataType) { | ||||||
|  |         return new BitCast(sameDiff, in, dataType).outputVariable(); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     public SDVariable compareAndBitpack(SDVariable threshold) { | ||||||
|  |         return new CompareAndBitpack(sameDiff, threshold).outputVariable(); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     public SDVariable divideNoNan(SDVariable in1, SDVariable in2) { | ||||||
|  |         return new DivideNoNan(sameDiff, in1, in2).outputVariable(); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     public SDVariable drawBoundingBoxes(SDVariable boxes, SDVariable colors) { | ||||||
|  |         return new DrawBoundingBoxes(sameDiff, boxes, colors).outputVariable(); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     public SDVariable fakeQuantWithMinMaxVarsPerChannel(SDVariable x, SDVariable min, SDVariable max) { | ||||||
|  |         return new FakeQuantWithMinMaxVarsPerChannel(sameDiff,x,min,max).outputVariable(); | ||||||
|  |     } | ||||||
| 
 | 
 | ||||||
|     public String toString() { |     public String toString() { | ||||||
|         return "DifferentialFunctionFactory{methodNames=" + methodNames + "}"; |         return "DifferentialFunctionFactory{methodNames=" + methodNames + "}"; | ||||||
|  | |||||||
| @ -1,5 +1,7 @@ | |||||||
| package org.nd4j.linalg.api.ops.custom; | package org.nd4j.linalg.api.ops.custom; | ||||||
| 
 | 
 | ||||||
|  | import org.nd4j.autodiff.samediff.SDVariable; | ||||||
|  | import org.nd4j.autodiff.samediff.SameDiff; | ||||||
| import org.nd4j.base.Preconditions; | import org.nd4j.base.Preconditions; | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
| import org.nd4j.linalg.api.ops.DynamicCustomOp; | import org.nd4j.linalg.api.ops.DynamicCustomOp; | ||||||
| @ -12,6 +14,10 @@ public class AdjustContrast extends BaseAdjustContrast { | |||||||
|         super(in, factor, out); |         super(in, factor, out); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     public AdjustContrast(SameDiff sameDiff, SDVariable in, SDVariable factor) { | ||||||
|  |         super(sameDiff,new SDVariable[]{in,factor}); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     @Override |     @Override | ||||||
|     public String opName() { |     public String opName() { | ||||||
|         return "adjust_contrast"; |         return "adjust_contrast"; | ||||||
|  | |||||||
| @ -1,5 +1,7 @@ | |||||||
| package org.nd4j.linalg.api.ops.custom; | package org.nd4j.linalg.api.ops.custom; | ||||||
| 
 | 
 | ||||||
|  | import org.nd4j.autodiff.samediff.SDVariable; | ||||||
|  | import org.nd4j.autodiff.samediff.SameDiff; | ||||||
| import org.nd4j.base.Preconditions; | import org.nd4j.base.Preconditions; | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
| import org.nd4j.linalg.api.ops.DynamicCustomOp; | import org.nd4j.linalg.api.ops.DynamicCustomOp; | ||||||
| @ -12,6 +14,10 @@ public class AdjustContrastV2 extends BaseAdjustContrast { | |||||||
|         super(in, factor, out); |         super(in, factor, out); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     public AdjustContrastV2(SameDiff sameDiff, SDVariable in, SDVariable factor) { | ||||||
|  |         super( sameDiff,new SDVariable[]{in,factor}); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     @Override |     @Override | ||||||
|     public String opName() { |     public String opName() { | ||||||
|         return "adjust_contrast_v2"; |         return "adjust_contrast_v2"; | ||||||
|  | |||||||
| @ -1,5 +1,7 @@ | |||||||
| package org.nd4j.linalg.api.ops.custom; | package org.nd4j.linalg.api.ops.custom; | ||||||
| 
 | 
 | ||||||
|  | import org.nd4j.autodiff.samediff.SDVariable; | ||||||
|  | import org.nd4j.autodiff.samediff.SameDiff; | ||||||
| import org.nd4j.base.Preconditions; | import org.nd4j.base.Preconditions; | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
| import org.nd4j.linalg.api.ops.DynamicCustomOp; | import org.nd4j.linalg.api.ops.DynamicCustomOp; | ||||||
| @ -16,4 +18,12 @@ public abstract class BaseAdjustContrast extends DynamicCustomOp { | |||||||
| 
 | 
 | ||||||
|         addTArgument(factor); |         addTArgument(factor); | ||||||
|     } |     } | ||||||
|  | 
 | ||||||
|  |     public BaseAdjustContrast(SameDiff sameDiff, SDVariable[] vars) { | ||||||
|  |         super("", sameDiff, vars); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     public BaseAdjustContrast(SameDiff sameDiff, SDVariable in, SDVariable factor, SDVariable out) { | ||||||
|  |         super(null, sameDiff, new SDVariable[]{in, factor, out}); | ||||||
|  |     } | ||||||
| } | } | ||||||
| @ -1,5 +1,7 @@ | |||||||
| package org.nd4j.linalg.api.ops.custom; | package org.nd4j.linalg.api.ops.custom; | ||||||
| 
 | 
 | ||||||
|  | import org.nd4j.autodiff.samediff.SDVariable; | ||||||
|  | import org.nd4j.autodiff.samediff.SameDiff; | ||||||
| import org.nd4j.linalg.api.buffer.DataType; | import org.nd4j.linalg.api.buffer.DataType; | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
| import org.nd4j.linalg.api.ops.DynamicCustomOp; | import org.nd4j.linalg.api.ops.DynamicCustomOp; | ||||||
| @ -14,6 +16,10 @@ public class BitCast extends DynamicCustomOp { | |||||||
|         iArguments.add(Long.valueOf(dataType)); |         iArguments.add(Long.valueOf(dataType)); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     public BitCast(SameDiff sameDiff, SDVariable in, SDVariable dataType) { | ||||||
|  |         super("", sameDiff, new SDVariable[]{in, dataType}); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     @Override |     @Override | ||||||
|     public String opName() { |     public String opName() { | ||||||
|         return "bitcast"; |         return "bitcast"; | ||||||
|  | |||||||
| @ -1,5 +1,7 @@ | |||||||
| package org.nd4j.linalg.api.ops.custom; | package org.nd4j.linalg.api.ops.custom; | ||||||
| 
 | 
 | ||||||
|  | import org.nd4j.autodiff.samediff.SDVariable; | ||||||
|  | import org.nd4j.autodiff.samediff.SameDiff; | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
| import org.nd4j.linalg.api.ops.DynamicCustomOp; | import org.nd4j.linalg.api.ops.DynamicCustomOp; | ||||||
| import org.nd4j.linalg.factory.Nd4j; | import org.nd4j.linalg.factory.Nd4j; | ||||||
| @ -13,6 +15,10 @@ public class CompareAndBitpack extends DynamicCustomOp { | |||||||
|         outputArguments.add(out); |         outputArguments.add(out); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     public CompareAndBitpack(SameDiff sameDiff, SDVariable threshold) { | ||||||
|  |         super("", sameDiff, new SDVariable[]{threshold}); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     @Override |     @Override | ||||||
|     public String opName() { |     public String opName() { | ||||||
|         return "compare_and_bitpack"; |         return "compare_and_bitpack"; | ||||||
|  | |||||||
| @ -1,6 +1,8 @@ | |||||||
| package org.nd4j.linalg.api.ops.custom; | package org.nd4j.linalg.api.ops.custom; | ||||||
| 
 | 
 | ||||||
| import org.apache.commons.math3.analysis.function.Divide; | import org.apache.commons.math3.analysis.function.Divide; | ||||||
|  | import org.nd4j.autodiff.samediff.SDVariable; | ||||||
|  | import org.nd4j.autodiff.samediff.SameDiff; | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
| import org.nd4j.linalg.api.ops.DynamicCustomOp; | import org.nd4j.linalg.api.ops.DynamicCustomOp; | ||||||
| 
 | 
 | ||||||
| @ -14,6 +16,10 @@ public class DivideNoNan extends DynamicCustomOp { | |||||||
|         outputArguments.add(out); |         outputArguments.add(out); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     public DivideNoNan(SameDiff sameDiff, SDVariable in1, SDVariable in2) { | ||||||
|  |         super("", sameDiff, new SDVariable[]{in1, in2}); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     @Override |     @Override | ||||||
|     public String opName() { |     public String opName() { | ||||||
|         return "divide_no_nan"; |         return "divide_no_nan"; | ||||||
|  | |||||||
| @ -1,5 +1,7 @@ | |||||||
| package org.nd4j.linalg.api.ops.custom; | package org.nd4j.linalg.api.ops.custom; | ||||||
| 
 | 
 | ||||||
|  | import org.nd4j.autodiff.samediff.SDVariable; | ||||||
|  | import org.nd4j.autodiff.samediff.SameDiff; | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
| import org.nd4j.linalg.api.ops.DynamicCustomOp; | import org.nd4j.linalg.api.ops.DynamicCustomOp; | ||||||
| 
 | 
 | ||||||
| @ -14,6 +16,10 @@ public class DrawBoundingBoxes extends DynamicCustomOp { | |||||||
|         outputArguments.add(output); |         outputArguments.add(output); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     public DrawBoundingBoxes(SameDiff sameDiff, SDVariable boxes, SDVariable colors) { | ||||||
|  |         super("", sameDiff, new SDVariable[]{boxes, colors}); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     @Override |     @Override | ||||||
|     public String opName() { |     public String opName() { | ||||||
|         return "draw_bounding_boxes"; |         return "draw_bounding_boxes"; | ||||||
|  | |||||||
| @ -1,5 +1,7 @@ | |||||||
| package org.nd4j.linalg.api.ops.custom; | package org.nd4j.linalg.api.ops.custom; | ||||||
| 
 | 
 | ||||||
|  | import org.nd4j.autodiff.samediff.SDVariable; | ||||||
|  | import org.nd4j.autodiff.samediff.SameDiff; | ||||||
| import org.nd4j.base.Preconditions; | import org.nd4j.base.Preconditions; | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; | import org.nd4j.linalg.api.ndarray.INDArray; | ||||||
| import org.nd4j.linalg.api.ops.DynamicCustomOp; | import org.nd4j.linalg.api.ops.DynamicCustomOp; | ||||||
| @ -18,6 +20,10 @@ public class FakeQuantWithMinMaxVarsPerChannel extends DynamicCustomOp { | |||||||
|         outputArguments.add(output); |         outputArguments.add(output); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     public FakeQuantWithMinMaxVarsPerChannel(SameDiff sameDiff, SDVariable x, SDVariable min, SDVariable max) { | ||||||
|  |         super("", sameDiff, new SDVariable[]{x, min, max}); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     @Override |     @Override | ||||||
|     public String opName() { |     public String opName() { | ||||||
|         return "fake_quant_with_min_max_vars_per_channel"; |         return "fake_quant_with_min_max_vars_per_channel"; | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user