Remove resolvePropertiesFromSameDiffBeforeExecution() (#172)

* remove unneeded resolveProperties methods

Signed-off-by: Ryan Nett <rnett@skymind.io>

* final fixes, make final to prevent more from being added

Signed-off-by: Ryan Nett <rnett@skymind.io>

* gather fix

Signed-off-by: Ryan Nett <rnett@skymind.io>

* deprecate DifferentialFunction resolveProps

Signed-off-by: Ryan Nett <rnett@skymind.io>
master
Ryan Nett 2019-08-27 18:02:41 -07:00 committed by Alex Black
parent b472d7d8c8
commit d31197db5f
15 changed files with 5 additions and 130 deletions

View File

@ -547,8 +547,11 @@ public abstract class DifferentialFunction {
/**
* Resolve properties and arguments right before execution of
* this operation.
*
* @deprecated Will be removed in the future. Ops should support array arguments. Should not bs used or overridden.
*/
public void resolvePropertiesFromSameDiffBeforeExecution() {
@Deprecated
public final void resolvePropertiesFromSameDiffBeforeExecution() {
val properties = sameDiff.propertiesToResolveForFunction(this);
val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(this);
val currentFields = this.propertiesForFunction();

View File

@ -268,14 +268,6 @@ public class Conv3D extends DynamicCustomOp {
return ret;
}
@Override
public void resolvePropertiesFromSameDiffBeforeExecution() {
if (numIArguments() < 1) {
addArgs();
}
}
@Override
public boolean isConfigProperties() {
return true;

View File

@ -66,11 +66,6 @@ public class ConfusionMatrix extends DynamicCustomOp {
//Looks like this is implemented in practice using a large collection of discrete ops - not single TF import op?
}
@Override
public void resolvePropertiesFromSameDiffBeforeExecution() {
super.resolvePropertiesFromSameDiffBeforeExecution();
}
@Override
public String opName() {
return "confusion_matrix";

View File

@ -74,7 +74,6 @@ public class Gather extends DynamicCustomOp {
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
}
@Override
@ -82,30 +81,6 @@ public class Gather extends DynamicCustomOp {
OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph);
}
@Override
public void resolvePropertiesFromSameDiffBeforeExecution() {
// super.resolvePropertiesFromSameDiffBeforeExecution();
if (indices != null && numInputArguments() < 2) {
if (numInputArguments() == 0) {
INDArray a = Nd4j.create(indices, new long[]{indices.length}, new long[]{1}, 'c', DataType.INT);
if (indices.length > 1)
a = a.reshape(indices.length);
else
a = a.reshape(new int[]{});
addInputArgument(args()[0].getArr(), a);
} else if (numInputArguments() == 1) {
addInputArgument(Nd4j.create(indices, new long[]{indices.length}, new long[]{1}, 'c', DataType.INT));
}
}
if (numIArguments() < 1) {
addIArgument(jaxis);
}
}
@Override
public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
Map<String, Map<String, PropertyMapping>> ret = new HashMap<>();

View File

@ -110,15 +110,6 @@ public class Repeat extends DynamicCustomOp {
super.initFromOnnx(node, initWith, attributesForNode, graph);
}
@Override
public void resolvePropertiesFromSameDiffBeforeExecution() {
if (numOutputArguments() < getDescriptor().getNumOutputs()) {
for (val output : outputVariables()) {
addOutputArgument(output.getArr());
}
}
}
@Override
public String onnxName() {
return "Repeat";

View File

@ -41,6 +41,7 @@ public class Squeeze extends DynamicCustomOp {
public Squeeze(SameDiff sameDiff, SDVariable arg, int[] squeezeDims) {
super(null, sameDiff, new SDVariable[]{arg});
this.squeezeDims = squeezeDims;
addIArgument(squeezeDims);
}
@Override
@ -53,14 +54,6 @@ public class Squeeze extends DynamicCustomOp {
addIArgument(squeezeDims);
}
@Override
public void resolvePropertiesFromSameDiffBeforeExecution() {
super.resolvePropertiesFromSameDiffBeforeExecution();
if (squeezeDims != null && numIArguments() < squeezeDims.length) {
addIArgument(squeezeDims);
}
}
@Override
public String opName() {
return "squeeze";

View File

@ -65,11 +65,6 @@ public class Transpose extends DynamicCustomOp {
public Transpose() {
}
@Override
public void resolvePropertiesFromSameDiffBeforeExecution() {
super.resolvePropertiesFromSameDiffBeforeExecution();
}
@Override
public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
Map<String, Map<String, PropertyMapping>> ret = new LinkedHashMap<>();

View File

@ -70,16 +70,6 @@ public class HistogramFixedWidth extends DynamicCustomOp {
//No op - just need the inputs
}
@Override
public void resolvePropertiesFromSameDiffBeforeExecution() {
if(args().length == 3 && iArguments.isEmpty()){
//Num bins is 3rd array
addIArgument(arg(2).getArr().getInt(0));
}
super.resolvePropertiesFromSameDiffBeforeExecution();
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
throw new UnsupportedOperationException("Not supported");

View File

@ -82,17 +82,6 @@ public class Pad extends DynamicCustomOp {
//Constant value is resolved just before execution
}
@Override
public void resolvePropertiesFromSameDiffBeforeExecution() {
if(args().length == 3){
INDArray arr = arg(2).getArr();
this.tArguments.clear();
this.tArguments.add(arr.getDouble(0));
}
super.resolvePropertiesFromSameDiffBeforeExecution();
}
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
//Pad backprop: it's basically slice op...

View File

@ -54,14 +54,6 @@ public class UnsortedSegmentMax extends DynamicCustomOp {
return "UnsortedSegmentMax";
}
@Override
public void resolvePropertiesFromSameDiffBeforeExecution() {
if(args().length == 3 && iArguments == null || iArguments.size() == 0){
addIArgument(arg(2).getArr().getInt(0));
}
super.resolvePropertiesFromSameDiffBeforeExecution();
}
@Override
public List<SDVariable> doDiff(List<SDVariable> gradients){
return Arrays.asList(f().unsortedSegmentMaxBp(arg(0), arg(1), gradients.get(0), numSegments));

View File

@ -54,14 +54,6 @@ public class UnsortedSegmentMean extends DynamicCustomOp {
return "UnsortedSegmentMean";
}
@Override
public void resolvePropertiesFromSameDiffBeforeExecution() {
if(args().length == 3 && iArguments == null || iArguments.size() == 0){
addIArgument(arg(2).getArr().getInt(0));
}
super.resolvePropertiesFromSameDiffBeforeExecution();
}
@Override
public List<SDVariable> doDiff(List<SDVariable> gradients){
return Arrays.asList(f().unsortedSegmentMeanBp(arg(0), arg(1), gradients.get(0), numSegments));

View File

@ -54,14 +54,6 @@ public class UnsortedSegmentMin extends DynamicCustomOp {
return "UnsortedSegmentMin";
}
@Override
public void resolvePropertiesFromSameDiffBeforeExecution() {
if(args().length == 3 && iArguments == null || iArguments.size() == 0){
addIArgument(arg(2).getArr().getInt(0));
}
super.resolvePropertiesFromSameDiffBeforeExecution();
}
@Override
public List<SDVariable> doDiff(List<SDVariable> gradients){
return Arrays.asList(f().unsortedSegmentMinBp(arg(0), arg(1), gradients.get(0), numSegments));

View File

@ -54,14 +54,6 @@ public class UnsortedSegmentProd extends DynamicCustomOp {
return "UnsortedSegmentProd";
}
@Override
public void resolvePropertiesFromSameDiffBeforeExecution() {
if(args().length == 3 && iArguments == null || iArguments.size() == 0){
addIArgument(arg(2).getArr().getInt(0));
}
super.resolvePropertiesFromSameDiffBeforeExecution();
}
@Override
public List<SDVariable> doDiff(List<SDVariable> gradients){
return Arrays.asList(f().unsortedSegmentProdBp(arg(0), arg(1), gradients.get(0), numSegments));

View File

@ -53,14 +53,6 @@ public class UnsortedSegmentSqrtN extends DynamicCustomOp {
return "UnsortedSegmentSqrtN";
}
@Override
public void resolvePropertiesFromSameDiffBeforeExecution() {
if(args().length == 3 && iArguments == null || iArguments.size() == 0){
addIArgument(arg(2).getArr().getInt(0));
}
super.resolvePropertiesFromSameDiffBeforeExecution();
}
@Override
public List<SDVariable> doDiff(List<SDVariable> gradients){
return Arrays.asList(f().unsortedSegmentSqrtNBp(arg(0), arg(1), gradients.get(0), numSegments));

View File

@ -55,14 +55,6 @@ public class UnsortedSegmentSum extends DynamicCustomOp {
return "UnsortedSegmentSum";
}
@Override
public void resolvePropertiesFromSameDiffBeforeExecution() {
if(args().length == 3 && iArguments == null || iArguments.size() == 0){
addIArgument(arg(2).getArr().getInt(0));
}
super.resolvePropertiesFromSameDiffBeforeExecution();
}
@Override
public List<SDVariable> doDiff(List<SDVariable> gradients){
return Arrays.asList(f().unsortedSegmentSumBp(arg(0), arg(1), gradients.get(0), numSegments));