SameDiff fixes and naming (#78)

* remove SDVariable inplace methods

* import methods

* npe fix in OpVal

* removed SameDiff inplace ops from tests

* Naming updates, moved to centralized methods in SameDiff, should use op_#:# for everything

* quick fixes

* javadoc

* SDVariable eval with placeholders

* use regex match

* better matching
master
Ryan Nett 2019-07-25 00:25:30 -07:00 committed by AlexDBlack
parent ce0743da17
commit ac321265a7
7 changed files with 201 additions and 350 deletions

View File

@ -657,19 +657,7 @@ public abstract class DifferentialFunction {
if(sameDiff == null) if(sameDiff == null)
this.ownName = UUID.randomUUID().toString(); this.ownName = UUID.randomUUID().toString();
else { else {
int argIndex = 0; this.ownName = sameDiff.getOpName(opName());
String scope = sameDiff.currentNameScope();
if(scope == null)
scope = "";
else
scope = scope + "/";
String varName = scope + sameDiff.generateNewVarName(opName(),argIndex);
while(sameDiff.functionExists(varName)) {
varName = scope + sameDiff.generateNewVarName(opName(), argIndex);
argIndex++;
}
this.ownName = varName;
} }
if(sameDiff != null && !(this instanceof SDVariable)) if(sameDiff != null && !(this instanceof SDVariable))

View File

@ -91,10 +91,7 @@ public class SDVariable extends DifferentialFunction implements Serializable {
" SDVariables - variable \"%s\" is of type %s but was provided a weight initialization scheme %s", varName, varType, weightInitScheme); " SDVariables - variable \"%s\" is of type %s but was provided a weight initialization scheme %s", varName, varType, weightInitScheme);
Preconditions.checkState(dataType != DataType.UNKNOWN, "Unknown datatype is not allowed for SDVariables (variable name: %s)", varName); Preconditions.checkState(dataType != DataType.UNKNOWN, "Unknown datatype is not allowed for SDVariables (variable name: %s)", varName);
String nameScope = sameDiff.currentNameScope(); varName = sameDiff.generateNewVarName(varName, 0, true);
if(nameScope != null && !varName.startsWith(nameScope + "/")){
varName = nameScope + "/" + varName;
}
this.varName = varName; this.varName = varName;
this.variableType = varType; this.variableType = varType;
@ -656,7 +653,7 @@ public class SDVariable extends DifferentialFunction implements Serializable {
* See {@link #add(String, double)} * See {@link #add(String, double)}
*/ */
public SDVariable add(double scalar) { public SDVariable add(double scalar) {
return add(sameDiff.generateNewVarName(AddOp.OP_NAME,0),scalar); return add(null,scalar);
} }
/** /**
@ -676,7 +673,7 @@ public class SDVariable extends DifferentialFunction implements Serializable {
* See {@link #add(String, SDVariable)} * See {@link #add(String, SDVariable)}
*/ */
public SDVariable add(SDVariable other) { public SDVariable add(SDVariable other) {
return add(sameDiff.generateNewVarName(AddOp.OP_NAME,0),other); return add(null,other);
} }
/** /**
@ -713,7 +710,7 @@ public class SDVariable extends DifferentialFunction implements Serializable {
* See {@link #sub(String, double)} * See {@link #sub(String, double)}
*/ */
public SDVariable sub(double scalar) { public SDVariable sub(double scalar) {
return sub(sameDiff.generateNewVarName(SubOp.OP_NAME,0),scalar); return sub(null,scalar);
} }
/** /**
@ -733,7 +730,7 @@ public class SDVariable extends DifferentialFunction implements Serializable {
* See {@link #sub(String, SDVariable)} * See {@link #sub(String, SDVariable)}
*/ */
public SDVariable sub(SDVariable x) { public SDVariable sub(SDVariable x) {
return sub(sameDiff.generateNewVarName(SubOp.OP_NAME,0),x); return sub(null,x);
} }
/** /**
@ -770,7 +767,7 @@ public class SDVariable extends DifferentialFunction implements Serializable {
* See {@link #div(String,double)} * See {@link #div(String,double)}
*/ */
public SDVariable div(double scalar) { public SDVariable div(double scalar) {
return div(sameDiff.generateNewVarName(DivOp.OP_NAME,0),scalar); return div(null,scalar);
} }
/** /**
@ -790,7 +787,7 @@ public class SDVariable extends DifferentialFunction implements Serializable {
* See {@link #div(String, SDVariable)} * See {@link #div(String, SDVariable)}
*/ */
public SDVariable div(SDVariable x) { public SDVariable div(SDVariable x) {
return div(sameDiff.generateNewVarName(DivOp.OP_NAME,0),x); return div(null,x);
} }
/** /**
@ -811,7 +808,7 @@ public class SDVariable extends DifferentialFunction implements Serializable {
* See {@link #mul(String, double)} * See {@link #mul(String, double)}
*/ */
public SDVariable mul(double scalar) { public SDVariable mul(double scalar) {
return mul(sameDiff.generateNewVarName(MulOp.OP_NAME,0),scalar); return mul(null,scalar);
} }
/** /**
@ -832,7 +829,7 @@ public class SDVariable extends DifferentialFunction implements Serializable {
* See {@link #mul(String, SDVariable)} * See {@link #mul(String, SDVariable)}
*/ */
public SDVariable mul(SDVariable x) { public SDVariable mul(SDVariable x) {
return mul(sameDiff.generateNewVarName(MulOp.OP_NAME,0),x); return mul(null,x);
} }
/** /**
@ -889,7 +886,7 @@ public class SDVariable extends DifferentialFunction implements Serializable {
* See {@link #rsub(String, double)} * See {@link #rsub(String, double)}
*/ */
public SDVariable rsub(double scalar) { public SDVariable rsub(double scalar) {
return rsub(sameDiff.generateNewVarName(RSubOp.OP_NAME,0),scalar); return rsub(null,scalar);
} }
/** /**
@ -909,7 +906,7 @@ public class SDVariable extends DifferentialFunction implements Serializable {
* See {@link #rsub(String, SDVariable)} * See {@link #rsub(String, SDVariable)}
*/ */
public SDVariable rsub(SDVariable x) { public SDVariable rsub(SDVariable x) {
return rsub(sameDiff.generateNewVarName(RSubOp.OP_NAME,0),x); return rsub(null,x);
} }
/** /**
@ -930,7 +927,7 @@ public class SDVariable extends DifferentialFunction implements Serializable {
* See {@link #rdiv(String, double)} * See {@link #rdiv(String, double)}
*/ */
public SDVariable rdiv(double scalar) { public SDVariable rdiv(double scalar) {
return rdiv(sameDiff.generateNewVarName(RDivOp.OP_NAME,0),scalar); return rdiv(null,scalar);
} }
/** /**
@ -950,7 +947,7 @@ public class SDVariable extends DifferentialFunction implements Serializable {
* See {@link #rdiv(String, SDVariable)} * See {@link #rdiv(String, SDVariable)}
*/ */
public SDVariable rdiv(SDVariable sameDiffVariable) { public SDVariable rdiv(SDVariable sameDiffVariable) {
return rdiv(sameDiff.generateNewVarName(RDivOp.OP_NAME,0),sameDiffVariable); return rdiv(null,sameDiffVariable);
} }
/** /**
@ -968,134 +965,13 @@ public class SDVariable extends DifferentialFunction implements Serializable {
} }
/**
*
* @param sameDiffVariable
* @return
*/
public SDVariable rsubi(double sameDiffVariable) {
return rsubi(sameDiff.generateNewVarName(RSubOp.OP_NAME,0),sameDiffVariable);
}
/**
*
* @param sameDiffVariable
* @return
*/
public SDVariable rdivi(double sameDiffVariable) {
return rdivi(sameDiff.generateNewVarName(RDivOp.OP_NAME,0),sameDiffVariable);
}
/**
*
* @param sameDiffVariable
* @return
*/
public SDVariable addi(double sameDiffVariable) {
return addi(sameDiff.generateNewVarName(AddOp.OP_NAME,0),sameDiffVariable);
}
/**
*
* @param sameDiffVariable
* @return
*/
public SDVariable subi(double sameDiffVariable) {
return subi(sameDiff.generateNewVarName(SubOp.OP_NAME,0),sameDiffVariable);
}
/**
*
* @param sameDiffVariable
* @return
*/
public SDVariable divi(double sameDiffVariable) {
return divi(sameDiff.generateNewVarName(DivOp.OP_NAME,0),sameDiffVariable);
}
/**
*
* @param sameDiffVariable
* @return
*/
public SDVariable muli(double sameDiffVariable) {
return muli(sameDiff.generateNewVarName(MulOp.OP_NAME,0),sameDiffVariable);
}
/** /**
* *
* @param sameDiffVariable * @param sameDiffVariable
* @return * @return
*/ */
public SDVariable truncatedDiv(SDVariable sameDiffVariable) { public SDVariable truncatedDiv(SDVariable sameDiffVariable) {
return truncatedDiv(sameDiff.generateNewVarName(TruncateDivOp.OP_NAME,0),sameDiffVariable); return truncatedDiv(null,sameDiffVariable);
}
/**
*
* @param sameDiffVariable
* @return
*/
public SDVariable rsubi(SDVariable sameDiffVariable) {
return rsubi(sameDiff.generateNewVarName(RSubOp.OP_NAME,0),sameDiffVariable);
}
/**
*
* @param sameDiffVariable
* @return
*/
public SDVariable rdivi(SDVariable sameDiffVariable) {
return rdivi(sameDiff.generateNewVarName(RDivOp.OP_NAME,0),sameDiffVariable);
}
/**
*
* @param sameDiffVariable
* @return
*/
public SDVariable addi(SDVariable sameDiffVariable) {
return addi(sameDiff.generateNewVarName(AddOp.OP_NAME,0),sameDiffVariable);
}
/**
*
* @param sameDiffVariable
* @return
*/
public SDVariable subi(SDVariable sameDiffVariable) {
return subi(sameDiff.generateNewVarName(SubOp.OP_NAME,0),sameDiffVariable);
}
/**
*
* @param sameDiffVariable
* @return
*/
public SDVariable divi(SDVariable sameDiffVariable) {
return divi(sameDiff.generateNewVarName(DivOp.OP_NAME,0),sameDiffVariable);
}
/**
*
* @param sameDiffVariable
* @return
*/
public SDVariable muli(SDVariable sameDiffVariable) {
return muli(sameDiff.generateNewVarName(MulOp.OP_NAME,0),sameDiffVariable);
} }
@ -1112,154 +988,17 @@ public class SDVariable extends DifferentialFunction implements Serializable {
} }
/**
*
* @param sameDiffVariable
* @return
*/
public SDVariable rsubi(String varName, double sameDiffVariable) {
val function = sameDiff.f().rsubi(this,sameDiffVariable);
return sameDiff.updateVariableNameAndReference(function,varName);
}
/**
*
* @param sameDiffVariable
* @return
*/
public SDVariable rdivi(String varName, double sameDiffVariable) {
SDVariable function = sameDiff.f().rdivi(this
,sameDiffVariable);
return sameDiff.updateVariableNameAndReference(function,varName);
}
/**
*
* @param sameDiffVariable
* @return
*/
public SDVariable addi(String varName, double sameDiffVariable) {
val function = sameDiff.f().addi(this,sameDiffVariable);
return sameDiff.updateVariableNameAndReference(function,varName);
}
/**
*
* @param sameDiffVariable
* @return
*/
public SDVariable subi(String varName, double sameDiffVariable) {
val function = sameDiff.f().subi(this,sameDiffVariable);
return sameDiff.updateVariableNameAndReference(function,varName);
}
/**
*
* @param sameDiffVariable
* @return
*/
public SDVariable divi(String varName, double sameDiffVariable) {
val function = sameDiff.f().divi(this,sameDiffVariable);
return sameDiff.updateVariableNameAndReference(function,varName);
}
/**
*
* @param sameDiffVariable
* @return
*/
public SDVariable muli(String varName, double sameDiffVariable) {
val function = sameDiff.f().muli(this,sameDiffVariable);
return sameDiff.updateVariableNameAndReference(function,varName);
}
/**
*
* @param sameDiffVariable
* @return
*/
public SDVariable rsubi(String varName, SDVariable sameDiffVariable) {
val result = sameDiff.f().rsubi(this,sameDiffVariable);
return sameDiff.updateVariableNameAndReference(result,varName);
}
/**
*
* @param sameDiffVariable
* @return
*/
public SDVariable rdivi(String varName, SDVariable sameDiffVariable) {
val result = sameDiff.f().rdivi(this,sameDiffVariable);
return sameDiff.updateVariableNameAndReference(result,varName);
}
/**
*
* @param sameDiffVariable
* @return
*/
public SDVariable addi(String varName, SDVariable sameDiffVariable) {
val result = sameDiff.f().addi(this,sameDiffVariable);
return sameDiff.updateVariableNameAndReference(result,varName);
}
@Override @Override
public Op.Type opType() { public Op.Type opType() {
return Op.Type.RETURN; return Op.Type.RETURN;
} }
/**
*
* @param sameDiffVariable
* @return
*/
public SDVariable subi(String varName, SDVariable sameDiffVariable) {
SDVariable left = this;
SDVariable right = sameDiffVariable;
val result = sameDiff.f().subi(left,right);
return sameDiff.updateVariableNameAndReference(result,varName);
}
/**
*
* @param sameDiffVariable
* @return
*/
public SDVariable divi(String varName, SDVariable sameDiffVariable) {
val result = sameDiff.f().divi(this,sameDiffVariable);
result.setVarName(varName);
return result;
}
/**
*
* @param sameDiffVariable
* @return
*/
public SDVariable muli(String varName, SDVariable sameDiffVariable) {
SDVariable left = this;
SDVariable right = sameDiffVariable;
SDVariable result = sameDiff.f().muli(left,right);
result.setVarName(varName);
return result;
}
/** /**
* See {@link #squaredDifference(String, SDVariable)} * See {@link #squaredDifference(String, SDVariable)}
*/ */
public SDVariable squaredDifference(SDVariable x) { public SDVariable squaredDifference(SDVariable x) {
return squaredDifference(sameDiff.generateNewVarName(SquaredDifferenceOp.OP_NAME,0),x); return squaredDifference(null,x);
} }
/** /**
@ -1780,6 +1519,16 @@ public class SDVariable extends DifferentialFunction implements Serializable {
} }
/**
* Evaluate the result of this variable
* @return
*/
public INDArray eval(Map<String, INDArray> placeholders) {
sameDiff.exec(placeholders, getVarName());
return getArr();
}
@Override @Override
public String toString() { public String toString() {
return "SDVariable(name=\"" + varName + "\",variableType=" + variableType + ",dtype=" + dataType + return "SDVariable(name=\"" + varName + "\",variableType=" + variableType + ",dtype=" + dataType +

View File

@ -22,6 +22,8 @@ import com.google.common.primitives.Ints;
import com.google.flatbuffers.FlatBufferBuilder; import com.google.flatbuffers.FlatBufferBuilder;
import com.rits.cloning.Cloner; import com.rits.cloning.Cloner;
import com.rits.cloning.IFastCloner; import com.rits.cloning.IFastCloner;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import lombok.*; import lombok.*;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
@ -43,6 +45,7 @@ import org.nd4j.autodiff.util.cloner.INDArrayFastCloner;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.evaluation.IEvaluation; import org.nd4j.evaluation.IEvaluation;
import org.nd4j.graph.*; import org.nd4j.graph.*;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.jackson.objectmapper.holder.ObjectMapperHolder; import org.nd4j.jackson.objectmapper.holder.ObjectMapperHolder;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.factory.DataBufferFactory; import org.nd4j.linalg.api.buffer.factory.DataBufferFactory;
@ -97,6 +100,7 @@ import java.util.concurrent.atomic.AtomicInteger;
import java.util.zip.ZipEntry; import java.util.zip.ZipEntry;
import java.util.zip.ZipFile; import java.util.zip.ZipFile;
import java.util.zip.ZipOutputStream; import java.util.zip.ZipOutputStream;
import org.tensorflow.framework.GraphDef;
/** /**
* SameDiff is the entrypoint for ND4J's automatic differentiation functionality. * SameDiff is the entrypoint for ND4J's automatic differentiation functionality.
@ -2498,9 +2502,14 @@ public class SameDiff extends SDBaseOps {
//TODO only allowing null datatype for TF import (it's fixed in a later step) - don't want this in the public API! //TODO only allowing null datatype for TF import (it's fixed in a later step) - don't want this in the public API!
public SDVariable var(@NonNull String name, @NonNull VariableType variableType, WeightInitScheme weightInitScheme, public SDVariable var(@NonNull String name, @NonNull VariableType variableType, WeightInitScheme weightInitScheme,
org.nd4j.linalg.api.buffer.DataType dataType, long... shape) { org.nd4j.linalg.api.buffer.DataType dataType, long... shape) {
String withScope = nameWithScope(name);
if (variables.containsKey(withScope)) {
if (name == null || name.length() < 1)
name = getNewVarName();
else
name = generateNewVarName(name, 0);
if (variables.containsKey(name)) {
if(nameScopes.isEmpty()){ if(nameScopes.isEmpty()){
throw new IllegalArgumentException("Another variable with the name " + name + " already exists (current name scope: \"" throw new IllegalArgumentException("Another variable with the name " + name + " already exists (current name scope: \""
+ currentNameScope() + "\""); + currentNameScope() + "\"");
@ -2509,9 +2518,6 @@ public class SameDiff extends SDBaseOps {
} }
} }
if (name == null || name.length() < 1)
name = getNewVarName();
SDVariable ret = new SDVariable(name, variableType, this, shape, dataType, weightInitScheme); SDVariable ret = new SDVariable(name, variableType, this, shape, dataType, weightInitScheme);
addVariable(ret); addVariable(ret);
@ -2647,12 +2653,7 @@ public class SameDiff extends SDBaseOps {
} }
private String getNewVarName() { private String getNewVarName() {
String varName = "sd_var_" + String.valueOf(variableId); return generateNewVarName("sd_var", 0, false);
while (variables.containsKey(varName)) {
variableId++;
varName = "sd_var_" + String.valueOf(variableId);
}
return varName;
} }
/** /**
@ -3442,37 +3443,6 @@ public class SameDiff extends SDBaseOps {
} }
/**
* Generate a new variable name based on the uniqueness of the base name and arg index<br>
* For example, if baseName = "X" will return:<br>
* "X" if "X" does not already exist, or "X:argIndex" if argIndex > 0<br>
* "X_1" if "X" already exists, or "X_1:argIndex" if argIndex > 0<br>
* "X_2" if "X" and "X_1" already exists, or "X_2:argIndex" if argIndex > 0<br>
* And so on, until an unused name is found
*
* @param baseName the base name to use (use function.opName() where function is a {@link DifferentialFunction}
* @param argIndex the arg index
* @return the new generated name
*/
public String generateNewVarName(String baseName, int argIndex) {
if (!variables.containsKey(baseName) && argIndex == 0) {
return baseName;
}
//need to find a new name
int count = 0;
String name = baseName + (count == 0 ? "" : "_" + count) + (argIndex > 0 ? ":" + argIndex : "");
while (getVariable(name) != null) {
name = baseName + "_" + (++count) + (argIndex > 0 ? ":" + argIndex : "");
}
if (getVariable(name) != null) {
throw new ND4JIllegalStateException("Converged on already generated variable!");
}
return name;
}
/** /**
* Generate the variables based on the given input op and return the output variable names. * Generate the variables based on the given input op and return the output variable names.
* *
@ -3486,6 +3456,9 @@ public class SameDiff extends SDBaseOps {
if (baseName == null || baseName.isEmpty() && getBaseNameForFunction(function) != null) if (baseName == null || baseName.isEmpty() && getBaseNameForFunction(function) != null)
baseName = getBaseNameForFunction(function); baseName = getBaseNameForFunction(function);
if (baseName == null)
baseName = function.getOwnName();
if (baseName == null) if (baseName == null)
baseName = function.opName(); baseName = function.opName();
@ -3594,7 +3567,8 @@ public class SameDiff extends SDBaseOps {
* @return the set of names generated for each output of the function. * @return the set of names generated for each output of the function.
*/ */
public SDVariable[] generateOutputVariableForOp(DifferentialFunction function) { public SDVariable[] generateOutputVariableForOp(DifferentialFunction function) {
return generateOutputVariableForOp(function, function.opName(), false); return generateOutputVariableForOp(function,
function.getOwnName() != null ? function.getOwnName() : function.opName(), false);
} }
/** /**
@ -4391,11 +4365,21 @@ public class SameDiff extends SDBaseOps {
throw new NullPointerException("Null input: No variable found for updating!"); throw new NullPointerException("Null input: No variable found for updating!");
} }
if(newVarName != null) {
String nameScope = currentNameScope();
if (nameScope != null) {
if (!newVarName.startsWith(nameScope + "/")) {
newVarName = nameScope + "/" + newVarName;
}
}
}
if(newVarName != null && variables.containsKey(newVarName) && varToUpdate != variables.get(newVarName).getVariable()){ if(newVarName != null && variables.containsKey(newVarName) && varToUpdate != variables.get(newVarName).getVariable()){
throw new IllegalStateException("Variable name \"" + newVarName + "\" already exists for a different SDVariable"); throw new IllegalStateException("Variable name \"" + newVarName + "\" already exists for a different SDVariable");
} }
if (newVarName == null && variables.containsKey(varToUpdate.getVarName())) { if (newVarName == null && variables.containsKey(varToUpdate.getVarName())
&& variables.get(varToUpdate.getVarName()).getVariable() != varToUpdate) {
//Edge case: suppose we do m1=sd.mean(in), m2=sd.mean(m1) -> both initially have the name //Edge case: suppose we do m1=sd.mean(in), m2=sd.mean(m1) -> both initially have the name
// "mean" and consequently a new variable name needs to be generated // "mean" and consequently a new variable name needs to be generated
newVarName = generateNewVarName(varToUpdate.getVarName(), 0); newVarName = generateNewVarName(varToUpdate.getVarName(), 0);
@ -4405,13 +4389,6 @@ public class SameDiff extends SDBaseOps {
return varToUpdate; return varToUpdate;
} }
String nameScope = currentNameScope();
if(nameScope != null){
if(!newVarName.startsWith(nameScope)){
newVarName = nameScope + "/" + newVarName;
}
}
val oldVarName = varToUpdate.getVarName(); val oldVarName = varToUpdate.getVarName();
varToUpdate.setVarName(newVarName); varToUpdate.setVarName(newVarName);
updateVariableName(oldVarName, newVarName); updateVariableName(oldVarName, newVarName);
@ -5680,4 +5657,130 @@ public class SameDiff extends SDBaseOps {
} }
} }
/**
* Import a frozen Tensorflow graph to a new SameDiff graph.
*
* @param graphFile The text or binary file containing the graph
* @return The imported graph
*/
public static SameDiff importFrozenTF(File graphFile){
return TFGraphMapper.getInstance().importGraph(graphFile);
}
/**
* See {@link #importFrozenTF(File)}
*/
public static SameDiff importFrozenTF(GraphDef graphDef){
return TFGraphMapper.getInstance().importGraph(graphDef);
}
/**
* See {@link #importFrozenTF(File)}
*
* Again, the input can be text or binary.
*/
public static SameDiff importFrozenTF(InputStream graph){
return TFGraphMapper.getInstance().importGraph(graph);
}
/**
* Generate a new, distinct op name of the form &lt;base&gt;_#.
*
* Applies name scope if active.
*
* @param base The base name to use
* @param force Whether to force the result name to be the same as base.
*/
public String getOpName(String base, boolean force){
base = nameWithScope(base);
if(force && ops.containsKey(base))
throw new IllegalArgumentException("Op with name \"" + base + "\" already exists");
else if(force)
return base;
int start = 1;
// if we already have a name like "op_2", start from trying "op_3"
if(base.contains("_")){
// extract number used to generate base
Matcher num = Pattern.compile("(.*)_(\\d+)").matcher(base);
// extract argIndex used to generate base
if(num.find()) {
start = Integer.parseInt(num.group(2));
base = num.group(1);
}
}
String name = base;
for(int i = start ; true ; i++) {
// ensure that there are no variables that look like they are outputs of this op
boolean varWithName = false;
for(String varName : variables.keySet())
if(varName.startsWith(name + ":") || varName.equals(name))
varWithName = true;
if(!ops.containsKey(name) && !varWithName)
break;
name = base + "_" + i;
}
return name;
}
/**
* See {@link #getOpName(String, boolean)}
* force is false
*/
public String getOpName(String base){
return getOpName(base, false);
}
/**
* Generate a new, distinct variable name of the form &lt;base&gt;_#[:#].
*
* Applies name scopes if active.
*
* @param base The base of the name.
* @param argIndex The argument index, used in the ":#". A value of 0 (or negative) does not include the ":#" part.
* @param existingOp Whether to generate an distinct operation name from base (if false), or just use base (if true).
*/
public String generateNewVarName(String base, int argIndex, boolean existingOp){
base = nameWithScope(base);
if(argIndex > 0 && base.contains(":")){
Matcher num = Pattern.compile("(.*):(\\d+)").matcher(base);
// extract argIndex used to generate base
if(num.find()) {
argIndex = Integer.parseInt(num.group(2)) + 1;
base = num.group(1);
}
}
if(!existingOp)
base = getOpName(base);
if(argIndex > 0)
base += ":" + argIndex;
if(variables.containsKey(base))
throw new IllegalArgumentException("Variable with name \"" + base + "\" already exists");
return base;
}
/**
*
* See {@link #generateNewVarName(String, int, boolean)}
* existingOp is true.
*/
@Override
public String generateNewVarName(String base, int argIndex){
return generateNewVarName(base, argIndex, true);
}
} }

View File

@ -471,7 +471,10 @@ public class OpValidation {
backpropSeen.add(df.getClass()); backpropSeen.add(df.getClass());
} }
for (Class c : backpropSeen) { for (Class c : backpropSeen) {
gradCheckCoverageCountPerClass.put(c, gradCheckCoverageCountPerClass.get(c) + 1); if(gradCheckCoverageCountPerClass.containsKey(c))
gradCheckCoverageCountPerClass.put(c, gradCheckCoverageCountPerClass.get(c) + 1);
else
gradCheckCoverageCountPerClass.put(c, 1);
} }
//Collect coverage information for forward pass (expected outputs) //Collect coverage information for forward pass (expected outputs)
@ -491,15 +494,23 @@ public class OpValidation {
if (seen != null) { if (seen != null) {
for (Class c : seen) { for (Class c : seen) {
fwdPassCoverageCountPerClass.put(c, fwdPassCoverageCountPerClass.get(c) + 1); if(fwdPassCoverageCountPerClass.containsKey(c)) {
fwdPassCoverageCountPerClass.put(c, fwdPassCoverageCountPerClass.get(c) + 1);
} else {
fwdPassCoverageCountPerClass.put(c, 1);
}
} }
} }
} }
private static void collectCoverageInformation(OpTestCase testCase) { private static void collectCoverageInformation(OpTestCase testCase) {
//TODO we're basically assuming subtypes of DynamicCustomOp here, for coverage... not DCO itself //TODO we're basically assuming subtypes of DynamicCustomOp here, for coverage... not DCO itself
singleOpTestCountPerClass.put(testCase.op().getClass(), if(singleOpTestCountPerClass.containsKey(testCase.op().getClass())) {
singleOpTestCountPerClass.get(testCase.op().getClass()) + 1); singleOpTestCountPerClass.put(testCase.op().getClass(),
singleOpTestCountPerClass.get(testCase.op().getClass()) + 1);
} else {
singleOpTestCountPerClass.put(testCase.op().getClass(), 1);
}
} }

View File

@ -70,7 +70,7 @@ public class GraphExecutionerTest extends BaseNd4jTest {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
INDArray ones = Nd4j.ones(4); INDArray ones = Nd4j.ones(4);
SDVariable sdVariable = sameDiff.var("ones",ones); SDVariable sdVariable = sameDiff.var("ones",ones);
SDVariable result = sdVariable.addi(1.0); SDVariable result = sdVariable.add(1.0);
SDVariable total = sameDiff.sum(result,Integer.MAX_VALUE); SDVariable total = sameDiff.sum(result,Integer.MAX_VALUE);
val executioner = new NativeGraphExecutioner(); val executioner = new NativeGraphExecutioner();
@ -167,7 +167,7 @@ public class GraphExecutionerTest extends BaseNd4jTest {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
INDArray ones = Nd4j.ones(4); INDArray ones = Nd4j.ones(4);
SDVariable sdVariable = sameDiff.var("ones",ones); SDVariable sdVariable = sameDiff.var("ones",ones);
SDVariable result = sdVariable.addi(1.0); SDVariable result = sdVariable.add(1.0);
SDVariable total = sameDiff.sum(result,Integer.MAX_VALUE); SDVariable total = sameDiff.sum(result,Integer.MAX_VALUE);
val executioner = new NativeGraphExecutioner(); val executioner = new NativeGraphExecutioner();

View File

@ -91,7 +91,7 @@ public class FailingSameDiffTests extends BaseNd4jTest {
}, new SameDiffFunctionDefinition() { }, new SameDiffFunctionDefinition() {
@Override @Override
public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> inputs, SDVariable[] variableInputs) { public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> inputs, SDVariable[] variableInputs) {
SDVariable ret = variableInputs[1].addi(1.0); SDVariable ret = variableInputs[1].add(1.0);
return new SDVariable[]{variableInputs[0], ret}; return new SDVariable[]{variableInputs[0], ret};
} }
}, new SDVariable[]{ }, new SDVariable[]{
@ -116,7 +116,7 @@ public class FailingSameDiffTests extends BaseNd4jTest {
}, new SameDiffFunctionDefinition() { }, new SameDiffFunctionDefinition() {
@Override @Override
public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> inputs, SDVariable[] variableInputs) { public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> inputs, SDVariable[] variableInputs) {
SDVariable ret = variableInputs[1].addi(1.0); SDVariable ret = variableInputs[1].add(1.0);
return new SDVariable[]{variableInputs[0], ret}; return new SDVariable[]{variableInputs[0], ret};
} }
}, new SDVariable[]{ }, new SDVariable[]{
@ -197,7 +197,7 @@ public class FailingSameDiffTests extends BaseNd4jTest {
SDVariable w = sd.var("w", Nd4j.linspace(1,20,20, DataType.DOUBLE).reshape(4,5)); SDVariable w = sd.var("w", Nd4j.linspace(1,20,20, DataType.DOUBLE).reshape(4,5));
SDVariable b = sd.var("b", Nd4j.linspace(1,5,5, DataType.DOUBLE).reshape(1,5)); SDVariable b = sd.var("b", Nd4j.linspace(1,5,5, DataType.DOUBLE).reshape(1,5));
SDVariable mmul = sd.mmul(in,w).addi(b); SDVariable mmul = sd.mmul(in,w).add(b);
INDArray exp = in.getArr().mmul(w.getArr()).addiRowVector(b.getArr()); INDArray exp = in.getArr().mmul(w.getArr()).addiRowVector(b.getArr());
INDArray out = sd.execAndEndResult(); INDArray out = sd.execAndEndResult();

View File

@ -144,7 +144,7 @@ public class NameScopeTests extends BaseNd4jTest {
scope.close(); scope.close();
assertTrue("Var with name test/imax_1 exists", SD.variableMap().containsKey("test/imax_1")); assertTrue("Var with name test/imax exists", SD.variableMap().containsKey("test/imax"));
} }
@Test @Test