SameDiff fixes and naming (#78)
* remove SDVariable inplace methods * import methods * npe fix in OpVal * removed SameDiff inplace ops from tests * Naming updates, moved to centralized methods in SameDiff, should use op_#:# for everything * quick fixes * javadoc * SDVariable eval with placeholders * use regex match * better matchingmaster
parent
ce0743da17
commit
ac321265a7
|
@ -657,19 +657,7 @@ public abstract class DifferentialFunction {
|
|||
if(sameDiff == null)
|
||||
this.ownName = UUID.randomUUID().toString();
|
||||
else {
|
||||
int argIndex = 0;
|
||||
String scope = sameDiff.currentNameScope();
|
||||
if(scope == null)
|
||||
scope = "";
|
||||
else
|
||||
scope = scope + "/";
|
||||
String varName = scope + sameDiff.generateNewVarName(opName(),argIndex);
|
||||
while(sameDiff.functionExists(varName)) {
|
||||
varName = scope + sameDiff.generateNewVarName(opName(), argIndex);
|
||||
argIndex++;
|
||||
}
|
||||
|
||||
this.ownName = varName;
|
||||
this.ownName = sameDiff.getOpName(opName());
|
||||
}
|
||||
|
||||
if(sameDiff != null && !(this instanceof SDVariable))
|
||||
|
|
|
@ -91,10 +91,7 @@ public class SDVariable extends DifferentialFunction implements Serializable {
|
|||
" SDVariables - variable \"%s\" is of type %s but was provided a weight initialization scheme %s", varName, varType, weightInitScheme);
|
||||
Preconditions.checkState(dataType != DataType.UNKNOWN, "Unknown datatype is not allowed for SDVariables (variable name: %s)", varName);
|
||||
|
||||
String nameScope = sameDiff.currentNameScope();
|
||||
if(nameScope != null && !varName.startsWith(nameScope + "/")){
|
||||
varName = nameScope + "/" + varName;
|
||||
}
|
||||
varName = sameDiff.generateNewVarName(varName, 0, true);
|
||||
|
||||
this.varName = varName;
|
||||
this.variableType = varType;
|
||||
|
@ -656,7 +653,7 @@ public class SDVariable extends DifferentialFunction implements Serializable {
|
|||
* See {@link #add(String, double)}
|
||||
*/
|
||||
public SDVariable add(double scalar) {
|
||||
return add(sameDiff.generateNewVarName(AddOp.OP_NAME,0),scalar);
|
||||
return add(null,scalar);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -676,7 +673,7 @@ public class SDVariable extends DifferentialFunction implements Serializable {
|
|||
* See {@link #add(String, SDVariable)}
|
||||
*/
|
||||
public SDVariable add(SDVariable other) {
|
||||
return add(sameDiff.generateNewVarName(AddOp.OP_NAME,0),other);
|
||||
return add(null,other);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -713,7 +710,7 @@ public class SDVariable extends DifferentialFunction implements Serializable {
|
|||
* See {@link #sub(String, double)}
|
||||
*/
|
||||
public SDVariable sub(double scalar) {
|
||||
return sub(sameDiff.generateNewVarName(SubOp.OP_NAME,0),scalar);
|
||||
return sub(null,scalar);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -733,7 +730,7 @@ public class SDVariable extends DifferentialFunction implements Serializable {
|
|||
* See {@link #sub(String, SDVariable)}
|
||||
*/
|
||||
public SDVariable sub(SDVariable x) {
|
||||
return sub(sameDiff.generateNewVarName(SubOp.OP_NAME,0),x);
|
||||
return sub(null,x);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -770,7 +767,7 @@ public class SDVariable extends DifferentialFunction implements Serializable {
|
|||
* See {@link #div(String,double)}
|
||||
*/
|
||||
public SDVariable div(double scalar) {
|
||||
return div(sameDiff.generateNewVarName(DivOp.OP_NAME,0),scalar);
|
||||
return div(null,scalar);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -790,7 +787,7 @@ public class SDVariable extends DifferentialFunction implements Serializable {
|
|||
* See {@link #div(String, SDVariable)}
|
||||
*/
|
||||
public SDVariable div(SDVariable x) {
|
||||
return div(sameDiff.generateNewVarName(DivOp.OP_NAME,0),x);
|
||||
return div(null,x);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -811,7 +808,7 @@ public class SDVariable extends DifferentialFunction implements Serializable {
|
|||
* See {@link #mul(String, double)}
|
||||
*/
|
||||
public SDVariable mul(double scalar) {
|
||||
return mul(sameDiff.generateNewVarName(MulOp.OP_NAME,0),scalar);
|
||||
return mul(null,scalar);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -832,7 +829,7 @@ public class SDVariable extends DifferentialFunction implements Serializable {
|
|||
* See {@link #mul(String, SDVariable)}
|
||||
*/
|
||||
public SDVariable mul(SDVariable x) {
|
||||
return mul(sameDiff.generateNewVarName(MulOp.OP_NAME,0),x);
|
||||
return mul(null,x);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -889,7 +886,7 @@ public class SDVariable extends DifferentialFunction implements Serializable {
|
|||
* See {@link #rsub(String, double)}
|
||||
*/
|
||||
public SDVariable rsub(double scalar) {
|
||||
return rsub(sameDiff.generateNewVarName(RSubOp.OP_NAME,0),scalar);
|
||||
return rsub(null,scalar);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -909,7 +906,7 @@ public class SDVariable extends DifferentialFunction implements Serializable {
|
|||
* See {@link #rsub(String, SDVariable)}
|
||||
*/
|
||||
public SDVariable rsub(SDVariable x) {
|
||||
return rsub(sameDiff.generateNewVarName(RSubOp.OP_NAME,0),x);
|
||||
return rsub(null,x);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -930,7 +927,7 @@ public class SDVariable extends DifferentialFunction implements Serializable {
|
|||
* See {@link #rdiv(String, double)}
|
||||
*/
|
||||
public SDVariable rdiv(double scalar) {
|
||||
return rdiv(sameDiff.generateNewVarName(RDivOp.OP_NAME,0),scalar);
|
||||
return rdiv(null,scalar);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -950,7 +947,7 @@ public class SDVariable extends DifferentialFunction implements Serializable {
|
|||
* See {@link #rdiv(String, SDVariable)}
|
||||
*/
|
||||
public SDVariable rdiv(SDVariable sameDiffVariable) {
|
||||
return rdiv(sameDiff.generateNewVarName(RDivOp.OP_NAME,0),sameDiffVariable);
|
||||
return rdiv(null,sameDiffVariable);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -968,134 +965,13 @@ public class SDVariable extends DifferentialFunction implements Serializable {
|
|||
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
*
|
||||
* @param sameDiffVariable
|
||||
* @return
|
||||
*/
|
||||
public SDVariable rsubi(double sameDiffVariable) {
|
||||
return rsubi(sameDiff.generateNewVarName(RSubOp.OP_NAME,0),sameDiffVariable);
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param sameDiffVariable
|
||||
* @return
|
||||
*/
|
||||
public SDVariable rdivi(double sameDiffVariable) {
|
||||
return rdivi(sameDiff.generateNewVarName(RDivOp.OP_NAME,0),sameDiffVariable);
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param sameDiffVariable
|
||||
* @return
|
||||
*/
|
||||
public SDVariable addi(double sameDiffVariable) {
|
||||
return addi(sameDiff.generateNewVarName(AddOp.OP_NAME,0),sameDiffVariable);
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param sameDiffVariable
|
||||
* @return
|
||||
*/
|
||||
public SDVariable subi(double sameDiffVariable) {
|
||||
return subi(sameDiff.generateNewVarName(SubOp.OP_NAME,0),sameDiffVariable);
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param sameDiffVariable
|
||||
* @return
|
||||
*/
|
||||
public SDVariable divi(double sameDiffVariable) {
|
||||
return divi(sameDiff.generateNewVarName(DivOp.OP_NAME,0),sameDiffVariable);
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param sameDiffVariable
|
||||
* @return
|
||||
*/
|
||||
public SDVariable muli(double sameDiffVariable) {
|
||||
return muli(sameDiff.generateNewVarName(MulOp.OP_NAME,0),sameDiffVariable);
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param sameDiffVariable
|
||||
* @return
|
||||
*/
|
||||
public SDVariable truncatedDiv(SDVariable sameDiffVariable) {
|
||||
return truncatedDiv(sameDiff.generateNewVarName(TruncateDivOp.OP_NAME,0),sameDiffVariable);
|
||||
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
*
|
||||
* @param sameDiffVariable
|
||||
* @return
|
||||
*/
|
||||
public SDVariable rsubi(SDVariable sameDiffVariable) {
|
||||
return rsubi(sameDiff.generateNewVarName(RSubOp.OP_NAME,0),sameDiffVariable);
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param sameDiffVariable
|
||||
* @return
|
||||
*/
|
||||
public SDVariable rdivi(SDVariable sameDiffVariable) {
|
||||
return rdivi(sameDiff.generateNewVarName(RDivOp.OP_NAME,0),sameDiffVariable);
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param sameDiffVariable
|
||||
* @return
|
||||
*/
|
||||
public SDVariable addi(SDVariable sameDiffVariable) {
|
||||
return addi(sameDiff.generateNewVarName(AddOp.OP_NAME,0),sameDiffVariable);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param sameDiffVariable
|
||||
* @return
|
||||
*/
|
||||
public SDVariable subi(SDVariable sameDiffVariable) {
|
||||
return subi(sameDiff.generateNewVarName(SubOp.OP_NAME,0),sameDiffVariable);
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param sameDiffVariable
|
||||
* @return
|
||||
*/
|
||||
public SDVariable divi(SDVariable sameDiffVariable) {
|
||||
return divi(sameDiff.generateNewVarName(DivOp.OP_NAME,0),sameDiffVariable);
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param sameDiffVariable
|
||||
* @return
|
||||
*/
|
||||
public SDVariable muli(SDVariable sameDiffVariable) {
|
||||
return muli(sameDiff.generateNewVarName(MulOp.OP_NAME,0),sameDiffVariable);
|
||||
return truncatedDiv(null,sameDiffVariable);
|
||||
|
||||
}
|
||||
|
||||
|
@ -1112,154 +988,17 @@ public class SDVariable extends DifferentialFunction implements Serializable {
|
|||
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param sameDiffVariable
|
||||
* @return
|
||||
*/
|
||||
public SDVariable rsubi(String varName, double sameDiffVariable) {
|
||||
val function = sameDiff.f().rsubi(this,sameDiffVariable);
|
||||
return sameDiff.updateVariableNameAndReference(function,varName);
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param sameDiffVariable
|
||||
* @return
|
||||
*/
|
||||
public SDVariable rdivi(String varName, double sameDiffVariable) {
|
||||
SDVariable function = sameDiff.f().rdivi(this
|
||||
,sameDiffVariable);
|
||||
return sameDiff.updateVariableNameAndReference(function,varName);
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param sameDiffVariable
|
||||
* @return
|
||||
*/
|
||||
public SDVariable addi(String varName, double sameDiffVariable) {
|
||||
val function = sameDiff.f().addi(this,sameDiffVariable);
|
||||
return sameDiff.updateVariableNameAndReference(function,varName);
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param sameDiffVariable
|
||||
* @return
|
||||
*/
|
||||
public SDVariable subi(String varName, double sameDiffVariable) {
|
||||
val function = sameDiff.f().subi(this,sameDiffVariable);
|
||||
return sameDiff.updateVariableNameAndReference(function,varName);
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param sameDiffVariable
|
||||
* @return
|
||||
*/
|
||||
public SDVariable divi(String varName, double sameDiffVariable) {
|
||||
val function = sameDiff.f().divi(this,sameDiffVariable);
|
||||
return sameDiff.updateVariableNameAndReference(function,varName);
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param sameDiffVariable
|
||||
* @return
|
||||
*/
|
||||
public SDVariable muli(String varName, double sameDiffVariable) {
|
||||
val function = sameDiff.f().muli(this,sameDiffVariable);
|
||||
return sameDiff.updateVariableNameAndReference(function,varName);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
/**
|
||||
*
|
||||
* @param sameDiffVariable
|
||||
* @return
|
||||
*/
|
||||
public SDVariable rsubi(String varName, SDVariable sameDiffVariable) {
|
||||
val result = sameDiff.f().rsubi(this,sameDiffVariable);
|
||||
return sameDiff.updateVariableNameAndReference(result,varName);
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param sameDiffVariable
|
||||
* @return
|
||||
*/
|
||||
public SDVariable rdivi(String varName, SDVariable sameDiffVariable) {
|
||||
val result = sameDiff.f().rdivi(this,sameDiffVariable);
|
||||
return sameDiff.updateVariableNameAndReference(result,varName);
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param sameDiffVariable
|
||||
* @return
|
||||
*/
|
||||
public SDVariable addi(String varName, SDVariable sameDiffVariable) {
|
||||
val result = sameDiff.f().addi(this,sameDiffVariable);
|
||||
return sameDiff.updateVariableNameAndReference(result,varName);
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public Op.Type opType() {
|
||||
return Op.Type.RETURN;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param sameDiffVariable
|
||||
* @return
|
||||
*/
|
||||
public SDVariable subi(String varName, SDVariable sameDiffVariable) {
|
||||
SDVariable left = this;
|
||||
SDVariable right = sameDiffVariable;
|
||||
val result = sameDiff.f().subi(left,right);
|
||||
return sameDiff.updateVariableNameAndReference(result,varName);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param sameDiffVariable
|
||||
* @return
|
||||
*/
|
||||
public SDVariable divi(String varName, SDVariable sameDiffVariable) {
|
||||
val result = sameDiff.f().divi(this,sameDiffVariable);
|
||||
result.setVarName(varName);
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param sameDiffVariable
|
||||
* @return
|
||||
*/
|
||||
public SDVariable muli(String varName, SDVariable sameDiffVariable) {
|
||||
SDVariable left = this;
|
||||
SDVariable right = sameDiffVariable;
|
||||
SDVariable result = sameDiff.f().muli(left,right);
|
||||
result.setVarName(varName);
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #squaredDifference(String, SDVariable)}
|
||||
*/
|
||||
public SDVariable squaredDifference(SDVariable x) {
|
||||
return squaredDifference(sameDiff.generateNewVarName(SquaredDifferenceOp.OP_NAME,0),x);
|
||||
return squaredDifference(null,x);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -1780,6 +1519,16 @@ public class SDVariable extends DifferentialFunction implements Serializable {
|
|||
}
|
||||
|
||||
|
||||
/**
|
||||
* Evaluate the result of this variable
|
||||
* @return
|
||||
*/
|
||||
public INDArray eval(Map<String, INDArray> placeholders) {
|
||||
sameDiff.exec(placeholders, getVarName());
|
||||
return getArr();
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "SDVariable(name=\"" + varName + "\",variableType=" + variableType + ",dtype=" + dataType +
|
||||
|
|
|
@ -22,6 +22,8 @@ import com.google.common.primitives.Ints;
|
|||
import com.google.flatbuffers.FlatBufferBuilder;
|
||||
import com.rits.cloning.Cloner;
|
||||
import com.rits.cloning.IFastCloner;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
import lombok.*;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
|
@ -43,6 +45,7 @@ import org.nd4j.autodiff.util.cloner.INDArrayFastCloner;
|
|||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.evaluation.IEvaluation;
|
||||
import org.nd4j.graph.*;
|
||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||
import org.nd4j.jackson.objectmapper.holder.ObjectMapperHolder;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.buffer.factory.DataBufferFactory;
|
||||
|
@ -97,6 +100,7 @@ import java.util.concurrent.atomic.AtomicInteger;
|
|||
import java.util.zip.ZipEntry;
|
||||
import java.util.zip.ZipFile;
|
||||
import java.util.zip.ZipOutputStream;
|
||||
import org.tensorflow.framework.GraphDef;
|
||||
|
||||
/**
|
||||
* SameDiff is the entrypoint for ND4J's automatic differentiation functionality.
|
||||
|
@ -2498,9 +2502,14 @@ public class SameDiff extends SDBaseOps {
|
|||
//TODO only allowing null datatype for TF import (it's fixed in a later step) - don't want this in the public API!
|
||||
public SDVariable var(@NonNull String name, @NonNull VariableType variableType, WeightInitScheme weightInitScheme,
|
||||
org.nd4j.linalg.api.buffer.DataType dataType, long... shape) {
|
||||
String withScope = nameWithScope(name);
|
||||
|
||||
if (variables.containsKey(withScope)) {
|
||||
|
||||
if (name == null || name.length() < 1)
|
||||
name = getNewVarName();
|
||||
else
|
||||
name = generateNewVarName(name, 0);
|
||||
|
||||
if (variables.containsKey(name)) {
|
||||
if(nameScopes.isEmpty()){
|
||||
throw new IllegalArgumentException("Another variable with the name " + name + " already exists (current name scope: \""
|
||||
+ currentNameScope() + "\"");
|
||||
|
@ -2509,9 +2518,6 @@ public class SameDiff extends SDBaseOps {
|
|||
}
|
||||
}
|
||||
|
||||
if (name == null || name.length() < 1)
|
||||
name = getNewVarName();
|
||||
|
||||
|
||||
SDVariable ret = new SDVariable(name, variableType, this, shape, dataType, weightInitScheme);
|
||||
addVariable(ret);
|
||||
|
@ -2647,12 +2653,7 @@ public class SameDiff extends SDBaseOps {
|
|||
}
|
||||
|
||||
private String getNewVarName() {
|
||||
String varName = "sd_var_" + String.valueOf(variableId);
|
||||
while (variables.containsKey(varName)) {
|
||||
variableId++;
|
||||
varName = "sd_var_" + String.valueOf(variableId);
|
||||
}
|
||||
return varName;
|
||||
return generateNewVarName("sd_var", 0, false);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -3442,37 +3443,6 @@ public class SameDiff extends SDBaseOps {
|
|||
}
|
||||
|
||||
|
||||
/**
|
||||
* Generate a new variable name based on the uniqueness of the base name and arg index<br>
|
||||
* For example, if baseName = "X" will return:<br>
|
||||
* "X" if "X" does not already exist, or "X:argIndex" if argIndex > 0<br>
|
||||
* "X_1" if "X" already exists, or "X_1:argIndex" if argIndex > 0<br>
|
||||
* "X_2" if "X" and "X_1" already exists, or "X_2:argIndex" if argIndex > 0<br>
|
||||
* And so on, until an unused name is found
|
||||
*
|
||||
* @param baseName the base name to use (use function.opName() where function is a {@link DifferentialFunction}
|
||||
* @param argIndex the arg index
|
||||
* @return the new generated name
|
||||
*/
|
||||
public String generateNewVarName(String baseName, int argIndex) {
|
||||
if (!variables.containsKey(baseName) && argIndex == 0) {
|
||||
return baseName;
|
||||
}
|
||||
|
||||
//need to find a new name
|
||||
int count = 0;
|
||||
String name = baseName + (count == 0 ? "" : "_" + count) + (argIndex > 0 ? ":" + argIndex : "");
|
||||
while (getVariable(name) != null) {
|
||||
name = baseName + "_" + (++count) + (argIndex > 0 ? ":" + argIndex : "");
|
||||
}
|
||||
|
||||
if (getVariable(name) != null) {
|
||||
throw new ND4JIllegalStateException("Converged on already generated variable!");
|
||||
}
|
||||
return name;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Generate the variables based on the given input op and return the output variable names.
|
||||
*
|
||||
|
@ -3486,6 +3456,9 @@ public class SameDiff extends SDBaseOps {
|
|||
if (baseName == null || baseName.isEmpty() && getBaseNameForFunction(function) != null)
|
||||
baseName = getBaseNameForFunction(function);
|
||||
|
||||
if (baseName == null)
|
||||
baseName = function.getOwnName();
|
||||
|
||||
if (baseName == null)
|
||||
baseName = function.opName();
|
||||
|
||||
|
@ -3594,7 +3567,8 @@ public class SameDiff extends SDBaseOps {
|
|||
* @return the set of names generated for each output of the function.
|
||||
*/
|
||||
public SDVariable[] generateOutputVariableForOp(DifferentialFunction function) {
|
||||
return generateOutputVariableForOp(function, function.opName(), false);
|
||||
return generateOutputVariableForOp(function,
|
||||
function.getOwnName() != null ? function.getOwnName() : function.opName(), false);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -4391,11 +4365,21 @@ public class SameDiff extends SDBaseOps {
|
|||
throw new NullPointerException("Null input: No variable found for updating!");
|
||||
}
|
||||
|
||||
if(newVarName != null) {
|
||||
String nameScope = currentNameScope();
|
||||
if (nameScope != null) {
|
||||
if (!newVarName.startsWith(nameScope + "/")) {
|
||||
newVarName = nameScope + "/" + newVarName;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if(newVarName != null && variables.containsKey(newVarName) && varToUpdate != variables.get(newVarName).getVariable()){
|
||||
throw new IllegalStateException("Variable name \"" + newVarName + "\" already exists for a different SDVariable");
|
||||
}
|
||||
|
||||
if (newVarName == null && variables.containsKey(varToUpdate.getVarName())) {
|
||||
if (newVarName == null && variables.containsKey(varToUpdate.getVarName())
|
||||
&& variables.get(varToUpdate.getVarName()).getVariable() != varToUpdate) {
|
||||
//Edge case: suppose we do m1=sd.mean(in), m2=sd.mean(m1) -> both initially have the name
|
||||
// "mean" and consequently a new variable name needs to be generated
|
||||
newVarName = generateNewVarName(varToUpdate.getVarName(), 0);
|
||||
|
@ -4405,13 +4389,6 @@ public class SameDiff extends SDBaseOps {
|
|||
return varToUpdate;
|
||||
}
|
||||
|
||||
String nameScope = currentNameScope();
|
||||
if(nameScope != null){
|
||||
if(!newVarName.startsWith(nameScope)){
|
||||
newVarName = nameScope + "/" + newVarName;
|
||||
}
|
||||
}
|
||||
|
||||
val oldVarName = varToUpdate.getVarName();
|
||||
varToUpdate.setVarName(newVarName);
|
||||
updateVariableName(oldVarName, newVarName);
|
||||
|
@ -5680,4 +5657,130 @@ public class SameDiff extends SDBaseOps {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Import a frozen Tensorflow graph to a new SameDiff graph.
|
||||
*
|
||||
* @param graphFile The text or binary file containing the graph
|
||||
* @return The imported graph
|
||||
*/
|
||||
public static SameDiff importFrozenTF(File graphFile){
|
||||
return TFGraphMapper.getInstance().importGraph(graphFile);
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #importFrozenTF(File)}
|
||||
*/
|
||||
public static SameDiff importFrozenTF(GraphDef graphDef){
|
||||
return TFGraphMapper.getInstance().importGraph(graphDef);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* See {@link #importFrozenTF(File)}
|
||||
*
|
||||
* Again, the input can be text or binary.
|
||||
*/
|
||||
public static SameDiff importFrozenTF(InputStream graph){
|
||||
return TFGraphMapper.getInstance().importGraph(graph);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Generate a new, distinct op name of the form <base>_#.
|
||||
*
|
||||
* Applies name scope if active.
|
||||
*
|
||||
* @param base The base name to use
|
||||
* @param force Whether to force the result name to be the same as base.
|
||||
*/
|
||||
public String getOpName(String base, boolean force){
|
||||
|
||||
base = nameWithScope(base);
|
||||
|
||||
if(force && ops.containsKey(base))
|
||||
throw new IllegalArgumentException("Op with name \"" + base + "\" already exists");
|
||||
else if(force)
|
||||
return base;
|
||||
|
||||
int start = 1;
|
||||
|
||||
// if we already have a name like "op_2", start from trying "op_3"
|
||||
if(base.contains("_")){
|
||||
// extract number used to generate base
|
||||
Matcher num = Pattern.compile("(.*)_(\\d+)").matcher(base);
|
||||
// extract argIndex used to generate base
|
||||
if(num.find()) {
|
||||
start = Integer.parseInt(num.group(2));
|
||||
base = num.group(1);
|
||||
}
|
||||
}
|
||||
|
||||
String name = base;
|
||||
for(int i = start ; true ; i++) {
|
||||
|
||||
// ensure that there are no variables that look like they are outputs of this op
|
||||
boolean varWithName = false;
|
||||
for(String varName : variables.keySet())
|
||||
if(varName.startsWith(name + ":") || varName.equals(name))
|
||||
varWithName = true;
|
||||
|
||||
if(!ops.containsKey(name) && !varWithName)
|
||||
break;
|
||||
|
||||
name = base + "_" + i;
|
||||
}
|
||||
return name;
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #getOpName(String, boolean)}
|
||||
* force is false
|
||||
*/
|
||||
public String getOpName(String base){
|
||||
return getOpName(base, false);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a new, distinct variable name of the form <base>_#[:#].
|
||||
*
|
||||
* Applies name scopes if active.
|
||||
*
|
||||
* @param base The base of the name.
|
||||
* @param argIndex The argument index, used in the ":#". A value of 0 (or negative) does not include the ":#" part.
|
||||
* @param existingOp Whether to generate an distinct operation name from base (if false), or just use base (if true).
|
||||
*/
|
||||
public String generateNewVarName(String base, int argIndex, boolean existingOp){
|
||||
|
||||
base = nameWithScope(base);
|
||||
|
||||
if(argIndex > 0 && base.contains(":")){
|
||||
Matcher num = Pattern.compile("(.*):(\\d+)").matcher(base);
|
||||
// extract argIndex used to generate base
|
||||
if(num.find()) {
|
||||
argIndex = Integer.parseInt(num.group(2)) + 1;
|
||||
base = num.group(1);
|
||||
}
|
||||
}
|
||||
|
||||
if(!existingOp)
|
||||
base = getOpName(base);
|
||||
|
||||
if(argIndex > 0)
|
||||
base += ":" + argIndex;
|
||||
|
||||
if(variables.containsKey(base))
|
||||
throw new IllegalArgumentException("Variable with name \"" + base + "\" already exists");
|
||||
|
||||
return base;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* See {@link #generateNewVarName(String, int, boolean)}
|
||||
* existingOp is true.
|
||||
*/
|
||||
@Override
|
||||
public String generateNewVarName(String base, int argIndex){
|
||||
return generateNewVarName(base, argIndex, true);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -471,7 +471,10 @@ public class OpValidation {
|
|||
backpropSeen.add(df.getClass());
|
||||
}
|
||||
for (Class c : backpropSeen) {
|
||||
gradCheckCoverageCountPerClass.put(c, gradCheckCoverageCountPerClass.get(c) + 1);
|
||||
if(gradCheckCoverageCountPerClass.containsKey(c))
|
||||
gradCheckCoverageCountPerClass.put(c, gradCheckCoverageCountPerClass.get(c) + 1);
|
||||
else
|
||||
gradCheckCoverageCountPerClass.put(c, 1);
|
||||
}
|
||||
|
||||
//Collect coverage information for forward pass (expected outputs)
|
||||
|
@ -491,15 +494,23 @@ public class OpValidation {
|
|||
|
||||
if (seen != null) {
|
||||
for (Class c : seen) {
|
||||
fwdPassCoverageCountPerClass.put(c, fwdPassCoverageCountPerClass.get(c) + 1);
|
||||
if(fwdPassCoverageCountPerClass.containsKey(c)) {
|
||||
fwdPassCoverageCountPerClass.put(c, fwdPassCoverageCountPerClass.get(c) + 1);
|
||||
} else {
|
||||
fwdPassCoverageCountPerClass.put(c, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static void collectCoverageInformation(OpTestCase testCase) {
|
||||
//TODO we're basically assuming subtypes of DynamicCustomOp here, for coverage... not DCO itself
|
||||
singleOpTestCountPerClass.put(testCase.op().getClass(),
|
||||
singleOpTestCountPerClass.get(testCase.op().getClass()) + 1);
|
||||
if(singleOpTestCountPerClass.containsKey(testCase.op().getClass())) {
|
||||
singleOpTestCountPerClass.put(testCase.op().getClass(),
|
||||
singleOpTestCountPerClass.get(testCase.op().getClass()) + 1);
|
||||
} else {
|
||||
singleOpTestCountPerClass.put(testCase.op().getClass(), 1);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -70,7 +70,7 @@ public class GraphExecutionerTest extends BaseNd4jTest {
|
|||
SameDiff sameDiff = SameDiff.create();
|
||||
INDArray ones = Nd4j.ones(4);
|
||||
SDVariable sdVariable = sameDiff.var("ones",ones);
|
||||
SDVariable result = sdVariable.addi(1.0);
|
||||
SDVariable result = sdVariable.add(1.0);
|
||||
SDVariable total = sameDiff.sum(result,Integer.MAX_VALUE);
|
||||
|
||||
val executioner = new NativeGraphExecutioner();
|
||||
|
@ -167,7 +167,7 @@ public class GraphExecutionerTest extends BaseNd4jTest {
|
|||
SameDiff sameDiff = SameDiff.create();
|
||||
INDArray ones = Nd4j.ones(4);
|
||||
SDVariable sdVariable = sameDiff.var("ones",ones);
|
||||
SDVariable result = sdVariable.addi(1.0);
|
||||
SDVariable result = sdVariable.add(1.0);
|
||||
SDVariable total = sameDiff.sum(result,Integer.MAX_VALUE);
|
||||
|
||||
val executioner = new NativeGraphExecutioner();
|
||||
|
|
|
@ -91,7 +91,7 @@ public class FailingSameDiffTests extends BaseNd4jTest {
|
|||
}, new SameDiffFunctionDefinition() {
|
||||
@Override
|
||||
public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> inputs, SDVariable[] variableInputs) {
|
||||
SDVariable ret = variableInputs[1].addi(1.0);
|
||||
SDVariable ret = variableInputs[1].add(1.0);
|
||||
return new SDVariable[]{variableInputs[0], ret};
|
||||
}
|
||||
}, new SDVariable[]{
|
||||
|
@ -116,7 +116,7 @@ public class FailingSameDiffTests extends BaseNd4jTest {
|
|||
}, new SameDiffFunctionDefinition() {
|
||||
@Override
|
||||
public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> inputs, SDVariable[] variableInputs) {
|
||||
SDVariable ret = variableInputs[1].addi(1.0);
|
||||
SDVariable ret = variableInputs[1].add(1.0);
|
||||
return new SDVariable[]{variableInputs[0], ret};
|
||||
}
|
||||
}, new SDVariable[]{
|
||||
|
@ -197,7 +197,7 @@ public class FailingSameDiffTests extends BaseNd4jTest {
|
|||
SDVariable w = sd.var("w", Nd4j.linspace(1,20,20, DataType.DOUBLE).reshape(4,5));
|
||||
SDVariable b = sd.var("b", Nd4j.linspace(1,5,5, DataType.DOUBLE).reshape(1,5));
|
||||
|
||||
SDVariable mmul = sd.mmul(in,w).addi(b);
|
||||
SDVariable mmul = sd.mmul(in,w).add(b);
|
||||
INDArray exp = in.getArr().mmul(w.getArr()).addiRowVector(b.getArr());
|
||||
|
||||
INDArray out = sd.execAndEndResult();
|
||||
|
|
|
@ -144,7 +144,7 @@ public class NameScopeTests extends BaseNd4jTest {
|
|||
|
||||
scope.close();
|
||||
|
||||
assertTrue("Var with name test/imax_1 exists", SD.variableMap().containsKey("test/imax_1"));
|
||||
assertTrue("Var with name test/imax exists", SD.variableMap().containsKey("test/imax"));
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
Loading…
Reference in New Issue