Remove resolvePropertiesFromSameDiffBeforeExecution() (#172)
* remove unneeded resolveProperties methods
  Signed-off-by: Ryan Nett <rnett@skymind.io>
* final fixes, make final to prevent more from being added
  Signed-off-by: Ryan Nett <rnett@skymind.io>
* gather fix
  Signed-off-by: Ryan Nett <rnett@skymind.io>
* deprecate DifferentialFunction resolveProps
  Signed-off-by: Ryan Nett <rnett@skymind.io>
parent b472d7d8c8
commit d31197db5f
@@ -547,8 +547,11 @@ public abstract class DifferentialFunction {
     /**
      * Resolve properties and arguments right before execution of
      * this operation.
+     *
+     * @deprecated Will be removed in the future. Ops should support array arguments. Should not be used or overridden.
      */
-    public void resolvePropertiesFromSameDiffBeforeExecution() {
+    @Deprecated
+    public final void resolvePropertiesFromSameDiffBeforeExecution() {
         val properties = sameDiff.propertiesToResolveForFunction(this);
         val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(this);
         val currentFields = this.propertiesForFunction();
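The deprecation note above points ops toward taking such values as array inputs instead of patching fields in just before execution. A minimal sketch of that direction, assuming a hypothetical op (NumBinsOp and its parameters are illustrative only, not part of this commit) and using only the DynamicCustomOp constructor style visible later in this diff:

    // Sketch: the value that used to be injected in
    // resolvePropertiesFromSameDiffBeforeExecution() is passed as a regular
    // SDVariable input at construction time, so no pre-execution fix-up is needed.
    public class NumBinsOp extends DynamicCustomOp {
        public NumBinsOp(SameDiff sameDiff, SDVariable input, SDVariable numBins) {
            super(null, sameDiff, new SDVariable[]{input, numBins});
        }
    }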
@@ -268,14 +268,6 @@ public class Conv3D extends DynamicCustomOp {
         return ret;
     }
 
-
-    @Override
-    public void resolvePropertiesFromSameDiffBeforeExecution() {
-        if (numIArguments() < 1) {
-            addArgs();
-        }
-    }
-
     @Override
     public boolean isConfigProperties() {
         return true;
@@ -66,11 +66,6 @@ public class ConfusionMatrix extends DynamicCustomOp {
         //Looks like this is implemented in practice using a large collection of discrete ops - not single TF import op?
     }
 
-    @Override
-    public void resolvePropertiesFromSameDiffBeforeExecution() {
-        super.resolvePropertiesFromSameDiffBeforeExecution();
-    }
-
     @Override
     public String opName() {
         return "confusion_matrix";
@@ -74,7 +74,6 @@ public class Gather extends DynamicCustomOp {
     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
         TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
-
     }
 
     @Override
@@ -82,30 +81,6 @@ public class Gather extends DynamicCustomOp {
         OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph);
     }
 
-
-    @Override
-    public void resolvePropertiesFromSameDiffBeforeExecution() {
-        // super.resolvePropertiesFromSameDiffBeforeExecution();
-        if (indices != null && numInputArguments() < 2) {
-            if (numInputArguments() == 0) {
-                INDArray a = Nd4j.create(indices, new long[]{indices.length}, new long[]{1}, 'c', DataType.INT);
-                if (indices.length > 1)
-                    a = a.reshape(indices.length);
-                else
-                    a = a.reshape(new int[]{});
-
-                addInputArgument(args()[0].getArr(), a);
-            } else if (numInputArguments() == 1) {
-                addInputArgument(Nd4j.create(indices, new long[]{indices.length}, new long[]{1}, 'c', DataType.INT));
-            }
-
-        }
-
-        if (numIArguments() < 1) {
-            addIArgument(jaxis);
-        }
-    }
-
     @Override
     public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
         Map<String, Map<String, PropertyMapping>> ret = new HashMap<>();
@@ -110,15 +110,6 @@ public class Repeat extends DynamicCustomOp {
         super.initFromOnnx(node, initWith, attributesForNode, graph);
     }
 
-    @Override
-    public void resolvePropertiesFromSameDiffBeforeExecution() {
-        if (numOutputArguments() < getDescriptor().getNumOutputs()) {
-            for (val output : outputVariables()) {
-                addOutputArgument(output.getArr());
-            }
-        }
-    }
-
     @Override
     public String onnxName() {
         return "Repeat";
@@ -41,6 +41,7 @@ public class Squeeze extends DynamicCustomOp {
     public Squeeze(SameDiff sameDiff, SDVariable arg, int[] squeezeDims) {
         super(null, sameDiff, new SDVariable[]{arg});
         this.squeezeDims = squeezeDims;
+        addIArgument(squeezeDims);
     }
 
     @Override
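With addIArgument(squeezeDims) now called in the constructor, the integer args exist as soon as the op is built, which is why the pre-execution override is dropped in the next hunk. A minimal usage sketch (sd and x are assumed to be an existing SameDiff instance and SDVariable):

    // squeezeDims become integer args at construction time; no later
    // resolvePropertiesFromSameDiffBeforeExecution() step is required.
    Squeeze squeeze = new Squeeze(sd, x, new int[]{0});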
@@ -53,14 +54,6 @@ public class Squeeze extends DynamicCustomOp {
         addIArgument(squeezeDims);
     }
 
-    @Override
-    public void resolvePropertiesFromSameDiffBeforeExecution() {
-        super.resolvePropertiesFromSameDiffBeforeExecution();
-        if (squeezeDims != null && numIArguments() < squeezeDims.length) {
-            addIArgument(squeezeDims);
-        }
-    }
-
     @Override
     public String opName() {
         return "squeeze";
@@ -65,11 +65,6 @@ public class Transpose extends DynamicCustomOp {
     public Transpose() {
     }
 
-    @Override
-    public void resolvePropertiesFromSameDiffBeforeExecution() {
-        super.resolvePropertiesFromSameDiffBeforeExecution();
-    }
-
     @Override
     public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
         Map<String, Map<String, PropertyMapping>> ret = new LinkedHashMap<>();
@@ -70,16 +70,6 @@ public class HistogramFixedWidth extends DynamicCustomOp {
         //No op - just need the inputs
     }
 
-    @Override
-    public void resolvePropertiesFromSameDiffBeforeExecution() {
-        if(args().length == 3 && iArguments.isEmpty()){
-            //Num bins is 3rd array
-            addIArgument(arg(2).getArr().getInt(0));
-        }
-        super.resolvePropertiesFromSameDiffBeforeExecution();
-    }
-
-
     @Override
     public List<SDVariable> doDiff(List<SDVariable> f1) {
         throw new UnsupportedOperationException("Not supported");
@@ -82,17 +82,6 @@ public class Pad extends DynamicCustomOp {
         //Constant value is resolved just before execution
     }
 
-    @Override
-    public void resolvePropertiesFromSameDiffBeforeExecution() {
-        if(args().length == 3){
-            INDArray arr = arg(2).getArr();
-            this.tArguments.clear();
-            this.tArguments.add(arr.getDouble(0));
-        }
-        super.resolvePropertiesFromSameDiffBeforeExecution();
-    }
-
-
     @Override
     public List<SDVariable> doDiff(List<SDVariable> i_v) {
         //Pad backprop: it's basically slice op...
@@ -54,14 +54,6 @@ public class UnsortedSegmentMax extends DynamicCustomOp {
         return "UnsortedSegmentMax";
     }
 
-    @Override
-    public void resolvePropertiesFromSameDiffBeforeExecution() {
-        if(args().length == 3 && iArguments == null || iArguments.size() == 0){
-            addIArgument(arg(2).getArr().getInt(0));
-        }
-        super.resolvePropertiesFromSameDiffBeforeExecution();
-    }
-
     @Override
     public List<SDVariable> doDiff(List<SDVariable> gradients){
         return Arrays.asList(f().unsortedSegmentMaxBp(arg(0), arg(1), gradients.get(0), numSegments));
@@ -54,14 +54,6 @@ public class UnsortedSegmentMean extends DynamicCustomOp {
         return "UnsortedSegmentMean";
     }
 
-    @Override
-    public void resolvePropertiesFromSameDiffBeforeExecution() {
-        if(args().length == 3 && iArguments == null || iArguments.size() == 0){
-            addIArgument(arg(2).getArr().getInt(0));
-        }
-        super.resolvePropertiesFromSameDiffBeforeExecution();
-    }
-
     @Override
     public List<SDVariable> doDiff(List<SDVariable> gradients){
         return Arrays.asList(f().unsortedSegmentMeanBp(arg(0), arg(1), gradients.get(0), numSegments));
@@ -54,14 +54,6 @@ public class UnsortedSegmentMin extends DynamicCustomOp {
         return "UnsortedSegmentMin";
     }
 
-    @Override
-    public void resolvePropertiesFromSameDiffBeforeExecution() {
-        if(args().length == 3 && iArguments == null || iArguments.size() == 0){
-            addIArgument(arg(2).getArr().getInt(0));
-        }
-        super.resolvePropertiesFromSameDiffBeforeExecution();
-    }
-
     @Override
     public List<SDVariable> doDiff(List<SDVariable> gradients){
         return Arrays.asList(f().unsortedSegmentMinBp(arg(0), arg(1), gradients.get(0), numSegments));
@@ -54,14 +54,6 @@ public class UnsortedSegmentProd extends DynamicCustomOp {
         return "UnsortedSegmentProd";
     }
 
-    @Override
-    public void resolvePropertiesFromSameDiffBeforeExecution() {
-        if(args().length == 3 && iArguments == null || iArguments.size() == 0){
-            addIArgument(arg(2).getArr().getInt(0));
-        }
-        super.resolvePropertiesFromSameDiffBeforeExecution();
-    }
-
     @Override
     public List<SDVariable> doDiff(List<SDVariable> gradients){
         return Arrays.asList(f().unsortedSegmentProdBp(arg(0), arg(1), gradients.get(0), numSegments));
@@ -53,14 +53,6 @@ public class UnsortedSegmentSqrtN extends DynamicCustomOp {
         return "UnsortedSegmentSqrtN";
     }
 
-    @Override
-    public void resolvePropertiesFromSameDiffBeforeExecution() {
-        if(args().length == 3 && iArguments == null || iArguments.size() == 0){
-            addIArgument(arg(2).getArr().getInt(0));
-        }
-        super.resolvePropertiesFromSameDiffBeforeExecution();
-    }
-
     @Override
     public List<SDVariable> doDiff(List<SDVariable> gradients){
         return Arrays.asList(f().unsortedSegmentSqrtNBp(arg(0), arg(1), gradients.get(0), numSegments));
@@ -55,14 +55,6 @@ public class UnsortedSegmentSum extends DynamicCustomOp {
         return "UnsortedSegmentSum";
     }
 
-    @Override
-    public void resolvePropertiesFromSameDiffBeforeExecution() {
-        if(args().length == 3 && iArguments == null || iArguments.size() == 0){
-            addIArgument(arg(2).getArr().getInt(0));
-        }
-        super.resolvePropertiesFromSameDiffBeforeExecution();
-    }
-
     @Override
     public List<SDVariable> doDiff(List<SDVariable> gradients){
         return Arrays.asList(f().unsortedSegmentSumBp(arg(0), arg(1), gradients.get(0), numSegments));