SameDiff fixes and naming (#78)
* remove SDVariable inplace methods * import methods * npe fix in OpVal * removed SameDiff inplace ops from tests * Naming updates, moved to centralized methods in SameDiff, should use op_#:# for everything * quick fixes * javadoc * SDVariable eval with placeholders * use regex match * better matchingmaster
parent
ce0743da17
commit
ac321265a7
|
@ -657,19 +657,7 @@ public abstract class DifferentialFunction {
|
||||||
if(sameDiff == null)
|
if(sameDiff == null)
|
||||||
this.ownName = UUID.randomUUID().toString();
|
this.ownName = UUID.randomUUID().toString();
|
||||||
else {
|
else {
|
||||||
int argIndex = 0;
|
this.ownName = sameDiff.getOpName(opName());
|
||||||
String scope = sameDiff.currentNameScope();
|
|
||||||
if(scope == null)
|
|
||||||
scope = "";
|
|
||||||
else
|
|
||||||
scope = scope + "/";
|
|
||||||
String varName = scope + sameDiff.generateNewVarName(opName(),argIndex);
|
|
||||||
while(sameDiff.functionExists(varName)) {
|
|
||||||
varName = scope + sameDiff.generateNewVarName(opName(), argIndex);
|
|
||||||
argIndex++;
|
|
||||||
}
|
|
||||||
|
|
||||||
this.ownName = varName;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if(sameDiff != null && !(this instanceof SDVariable))
|
if(sameDiff != null && !(this instanceof SDVariable))
|
||||||
|
|
|
@ -91,10 +91,7 @@ public class SDVariable extends DifferentialFunction implements Serializable {
|
||||||
" SDVariables - variable \"%s\" is of type %s but was provided a weight initialization scheme %s", varName, varType, weightInitScheme);
|
" SDVariables - variable \"%s\" is of type %s but was provided a weight initialization scheme %s", varName, varType, weightInitScheme);
|
||||||
Preconditions.checkState(dataType != DataType.UNKNOWN, "Unknown datatype is not allowed for SDVariables (variable name: %s)", varName);
|
Preconditions.checkState(dataType != DataType.UNKNOWN, "Unknown datatype is not allowed for SDVariables (variable name: %s)", varName);
|
||||||
|
|
||||||
String nameScope = sameDiff.currentNameScope();
|
varName = sameDiff.generateNewVarName(varName, 0, true);
|
||||||
if(nameScope != null && !varName.startsWith(nameScope + "/")){
|
|
||||||
varName = nameScope + "/" + varName;
|
|
||||||
}
|
|
||||||
|
|
||||||
this.varName = varName;
|
this.varName = varName;
|
||||||
this.variableType = varType;
|
this.variableType = varType;
|
||||||
|
@ -656,7 +653,7 @@ public class SDVariable extends DifferentialFunction implements Serializable {
|
||||||
* See {@link #add(String, double)}
|
* See {@link #add(String, double)}
|
||||||
*/
|
*/
|
||||||
public SDVariable add(double scalar) {
|
public SDVariable add(double scalar) {
|
||||||
return add(sameDiff.generateNewVarName(AddOp.OP_NAME,0),scalar);
|
return add(null,scalar);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -676,7 +673,7 @@ public class SDVariable extends DifferentialFunction implements Serializable {
|
||||||
* See {@link #add(String, SDVariable)}
|
* See {@link #add(String, SDVariable)}
|
||||||
*/
|
*/
|
||||||
public SDVariable add(SDVariable other) {
|
public SDVariable add(SDVariable other) {
|
||||||
return add(sameDiff.generateNewVarName(AddOp.OP_NAME,0),other);
|
return add(null,other);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -713,7 +710,7 @@ public class SDVariable extends DifferentialFunction implements Serializable {
|
||||||
* See {@link #sub(String, double)}
|
* See {@link #sub(String, double)}
|
||||||
*/
|
*/
|
||||||
public SDVariable sub(double scalar) {
|
public SDVariable sub(double scalar) {
|
||||||
return sub(sameDiff.generateNewVarName(SubOp.OP_NAME,0),scalar);
|
return sub(null,scalar);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -733,7 +730,7 @@ public class SDVariable extends DifferentialFunction implements Serializable {
|
||||||
* See {@link #sub(String, SDVariable)}
|
* See {@link #sub(String, SDVariable)}
|
||||||
*/
|
*/
|
||||||
public SDVariable sub(SDVariable x) {
|
public SDVariable sub(SDVariable x) {
|
||||||
return sub(sameDiff.generateNewVarName(SubOp.OP_NAME,0),x);
|
return sub(null,x);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -770,7 +767,7 @@ public class SDVariable extends DifferentialFunction implements Serializable {
|
||||||
* See {@link #div(String,double)}
|
* See {@link #div(String,double)}
|
||||||
*/
|
*/
|
||||||
public SDVariable div(double scalar) {
|
public SDVariable div(double scalar) {
|
||||||
return div(sameDiff.generateNewVarName(DivOp.OP_NAME,0),scalar);
|
return div(null,scalar);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -790,7 +787,7 @@ public class SDVariable extends DifferentialFunction implements Serializable {
|
||||||
* See {@link #div(String, SDVariable)}
|
* See {@link #div(String, SDVariable)}
|
||||||
*/
|
*/
|
||||||
public SDVariable div(SDVariable x) {
|
public SDVariable div(SDVariable x) {
|
||||||
return div(sameDiff.generateNewVarName(DivOp.OP_NAME,0),x);
|
return div(null,x);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -811,7 +808,7 @@ public class SDVariable extends DifferentialFunction implements Serializable {
|
||||||
* See {@link #mul(String, double)}
|
* See {@link #mul(String, double)}
|
||||||
*/
|
*/
|
||||||
public SDVariable mul(double scalar) {
|
public SDVariable mul(double scalar) {
|
||||||
return mul(sameDiff.generateNewVarName(MulOp.OP_NAME,0),scalar);
|
return mul(null,scalar);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -832,7 +829,7 @@ public class SDVariable extends DifferentialFunction implements Serializable {
|
||||||
* See {@link #mul(String, SDVariable)}
|
* See {@link #mul(String, SDVariable)}
|
||||||
*/
|
*/
|
||||||
public SDVariable mul(SDVariable x) {
|
public SDVariable mul(SDVariable x) {
|
||||||
return mul(sameDiff.generateNewVarName(MulOp.OP_NAME,0),x);
|
return mul(null,x);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -889,7 +886,7 @@ public class SDVariable extends DifferentialFunction implements Serializable {
|
||||||
* See {@link #rsub(String, double)}
|
* See {@link #rsub(String, double)}
|
||||||
*/
|
*/
|
||||||
public SDVariable rsub(double scalar) {
|
public SDVariable rsub(double scalar) {
|
||||||
return rsub(sameDiff.generateNewVarName(RSubOp.OP_NAME,0),scalar);
|
return rsub(null,scalar);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -909,7 +906,7 @@ public class SDVariable extends DifferentialFunction implements Serializable {
|
||||||
* See {@link #rsub(String, SDVariable)}
|
* See {@link #rsub(String, SDVariable)}
|
||||||
*/
|
*/
|
||||||
public SDVariable rsub(SDVariable x) {
|
public SDVariable rsub(SDVariable x) {
|
||||||
return rsub(sameDiff.generateNewVarName(RSubOp.OP_NAME,0),x);
|
return rsub(null,x);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -930,7 +927,7 @@ public class SDVariable extends DifferentialFunction implements Serializable {
|
||||||
* See {@link #rdiv(String, double)}
|
* See {@link #rdiv(String, double)}
|
||||||
*/
|
*/
|
||||||
public SDVariable rdiv(double scalar) {
|
public SDVariable rdiv(double scalar) {
|
||||||
return rdiv(sameDiff.generateNewVarName(RDivOp.OP_NAME,0),scalar);
|
return rdiv(null,scalar);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -950,7 +947,7 @@ public class SDVariable extends DifferentialFunction implements Serializable {
|
||||||
* See {@link #rdiv(String, SDVariable)}
|
* See {@link #rdiv(String, SDVariable)}
|
||||||
*/
|
*/
|
||||||
public SDVariable rdiv(SDVariable sameDiffVariable) {
|
public SDVariable rdiv(SDVariable sameDiffVariable) {
|
||||||
return rdiv(sameDiff.generateNewVarName(RDivOp.OP_NAME,0),sameDiffVariable);
|
return rdiv(null,sameDiffVariable);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -968,134 +965,13 @@ public class SDVariable extends DifferentialFunction implements Serializable {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param sameDiffVariable
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public SDVariable rsubi(double sameDiffVariable) {
|
|
||||||
return rsubi(sameDiff.generateNewVarName(RSubOp.OP_NAME,0),sameDiffVariable);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param sameDiffVariable
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public SDVariable rdivi(double sameDiffVariable) {
|
|
||||||
return rdivi(sameDiff.generateNewVarName(RDivOp.OP_NAME,0),sameDiffVariable);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param sameDiffVariable
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public SDVariable addi(double sameDiffVariable) {
|
|
||||||
return addi(sameDiff.generateNewVarName(AddOp.OP_NAME,0),sameDiffVariable);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param sameDiffVariable
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public SDVariable subi(double sameDiffVariable) {
|
|
||||||
return subi(sameDiff.generateNewVarName(SubOp.OP_NAME,0),sameDiffVariable);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param sameDiffVariable
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public SDVariable divi(double sameDiffVariable) {
|
|
||||||
return divi(sameDiff.generateNewVarName(DivOp.OP_NAME,0),sameDiffVariable);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param sameDiffVariable
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public SDVariable muli(double sameDiffVariable) {
|
|
||||||
return muli(sameDiff.generateNewVarName(MulOp.OP_NAME,0),sameDiffVariable);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
* @param sameDiffVariable
|
* @param sameDiffVariable
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
public SDVariable truncatedDiv(SDVariable sameDiffVariable) {
|
public SDVariable truncatedDiv(SDVariable sameDiffVariable) {
|
||||||
return truncatedDiv(sameDiff.generateNewVarName(TruncateDivOp.OP_NAME,0),sameDiffVariable);
|
return truncatedDiv(null,sameDiffVariable);
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param sameDiffVariable
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public SDVariable rsubi(SDVariable sameDiffVariable) {
|
|
||||||
return rsubi(sameDiff.generateNewVarName(RSubOp.OP_NAME,0),sameDiffVariable);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param sameDiffVariable
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public SDVariable rdivi(SDVariable sameDiffVariable) {
|
|
||||||
return rdivi(sameDiff.generateNewVarName(RDivOp.OP_NAME,0),sameDiffVariable);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param sameDiffVariable
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public SDVariable addi(SDVariable sameDiffVariable) {
|
|
||||||
return addi(sameDiff.generateNewVarName(AddOp.OP_NAME,0),sameDiffVariable);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param sameDiffVariable
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public SDVariable subi(SDVariable sameDiffVariable) {
|
|
||||||
return subi(sameDiff.generateNewVarName(SubOp.OP_NAME,0),sameDiffVariable);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param sameDiffVariable
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public SDVariable divi(SDVariable sameDiffVariable) {
|
|
||||||
return divi(sameDiff.generateNewVarName(DivOp.OP_NAME,0),sameDiffVariable);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param sameDiffVariable
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public SDVariable muli(SDVariable sameDiffVariable) {
|
|
||||||
return muli(sameDiff.generateNewVarName(MulOp.OP_NAME,0),sameDiffVariable);
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1112,154 +988,17 @@ public class SDVariable extends DifferentialFunction implements Serializable {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param sameDiffVariable
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public SDVariable rsubi(String varName, double sameDiffVariable) {
|
|
||||||
val function = sameDiff.f().rsubi(this,sameDiffVariable);
|
|
||||||
return sameDiff.updateVariableNameAndReference(function,varName);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param sameDiffVariable
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public SDVariable rdivi(String varName, double sameDiffVariable) {
|
|
||||||
SDVariable function = sameDiff.f().rdivi(this
|
|
||||||
,sameDiffVariable);
|
|
||||||
return sameDiff.updateVariableNameAndReference(function,varName);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param sameDiffVariable
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public SDVariable addi(String varName, double sameDiffVariable) {
|
|
||||||
val function = sameDiff.f().addi(this,sameDiffVariable);
|
|
||||||
return sameDiff.updateVariableNameAndReference(function,varName);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param sameDiffVariable
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public SDVariable subi(String varName, double sameDiffVariable) {
|
|
||||||
val function = sameDiff.f().subi(this,sameDiffVariable);
|
|
||||||
return sameDiff.updateVariableNameAndReference(function,varName);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param sameDiffVariable
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public SDVariable divi(String varName, double sameDiffVariable) {
|
|
||||||
val function = sameDiff.f().divi(this,sameDiffVariable);
|
|
||||||
return sameDiff.updateVariableNameAndReference(function,varName);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param sameDiffVariable
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public SDVariable muli(String varName, double sameDiffVariable) {
|
|
||||||
val function = sameDiff.f().muli(this,sameDiffVariable);
|
|
||||||
return sameDiff.updateVariableNameAndReference(function,varName);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param sameDiffVariable
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public SDVariable rsubi(String varName, SDVariable sameDiffVariable) {
|
|
||||||
val result = sameDiff.f().rsubi(this,sameDiffVariable);
|
|
||||||
return sameDiff.updateVariableNameAndReference(result,varName);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param sameDiffVariable
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public SDVariable rdivi(String varName, SDVariable sameDiffVariable) {
|
|
||||||
val result = sameDiff.f().rdivi(this,sameDiffVariable);
|
|
||||||
return sameDiff.updateVariableNameAndReference(result,varName);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param sameDiffVariable
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public SDVariable addi(String varName, SDVariable sameDiffVariable) {
|
|
||||||
val result = sameDiff.f().addi(this,sameDiffVariable);
|
|
||||||
return sameDiff.updateVariableNameAndReference(result,varName);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Op.Type opType() {
|
public Op.Type opType() {
|
||||||
return Op.Type.RETURN;
|
return Op.Type.RETURN;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param sameDiffVariable
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public SDVariable subi(String varName, SDVariable sameDiffVariable) {
|
|
||||||
SDVariable left = this;
|
|
||||||
SDVariable right = sameDiffVariable;
|
|
||||||
val result = sameDiff.f().subi(left,right);
|
|
||||||
return sameDiff.updateVariableNameAndReference(result,varName);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param sameDiffVariable
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public SDVariable divi(String varName, SDVariable sameDiffVariable) {
|
|
||||||
val result = sameDiff.f().divi(this,sameDiffVariable);
|
|
||||||
result.setVarName(varName);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param sameDiffVariable
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
public SDVariable muli(String varName, SDVariable sameDiffVariable) {
|
|
||||||
SDVariable left = this;
|
|
||||||
SDVariable right = sameDiffVariable;
|
|
||||||
SDVariable result = sameDiff.f().muli(left,right);
|
|
||||||
result.setVarName(varName);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* See {@link #squaredDifference(String, SDVariable)}
|
* See {@link #squaredDifference(String, SDVariable)}
|
||||||
*/
|
*/
|
||||||
public SDVariable squaredDifference(SDVariable x) {
|
public SDVariable squaredDifference(SDVariable x) {
|
||||||
return squaredDifference(sameDiff.generateNewVarName(SquaredDifferenceOp.OP_NAME,0),x);
|
return squaredDifference(null,x);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -1780,6 +1519,16 @@ public class SDVariable extends DifferentialFunction implements Serializable {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Evaluate the result of this variable
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public INDArray eval(Map<String, INDArray> placeholders) {
|
||||||
|
sameDiff.exec(placeholders, getVarName());
|
||||||
|
return getArr();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String toString() {
|
public String toString() {
|
||||||
return "SDVariable(name=\"" + varName + "\",variableType=" + variableType + ",dtype=" + dataType +
|
return "SDVariable(name=\"" + varName + "\",variableType=" + variableType + ",dtype=" + dataType +
|
||||||
|
|
|
@ -22,6 +22,8 @@ import com.google.common.primitives.Ints;
|
||||||
import com.google.flatbuffers.FlatBufferBuilder;
|
import com.google.flatbuffers.FlatBufferBuilder;
|
||||||
import com.rits.cloning.Cloner;
|
import com.rits.cloning.Cloner;
|
||||||
import com.rits.cloning.IFastCloner;
|
import com.rits.cloning.IFastCloner;
|
||||||
|
import java.util.regex.Matcher;
|
||||||
|
import java.util.regex.Pattern;
|
||||||
import lombok.*;
|
import lombok.*;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.io.FileUtils;
|
import org.apache.commons.io.FileUtils;
|
||||||
|
@ -43,6 +45,7 @@ import org.nd4j.autodiff.util.cloner.INDArrayFastCloner;
|
||||||
import org.nd4j.base.Preconditions;
|
import org.nd4j.base.Preconditions;
|
||||||
import org.nd4j.evaluation.IEvaluation;
|
import org.nd4j.evaluation.IEvaluation;
|
||||||
import org.nd4j.graph.*;
|
import org.nd4j.graph.*;
|
||||||
|
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||||
import org.nd4j.jackson.objectmapper.holder.ObjectMapperHolder;
|
import org.nd4j.jackson.objectmapper.holder.ObjectMapperHolder;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.buffer.factory.DataBufferFactory;
|
import org.nd4j.linalg.api.buffer.factory.DataBufferFactory;
|
||||||
|
@ -97,6 +100,7 @@ import java.util.concurrent.atomic.AtomicInteger;
|
||||||
import java.util.zip.ZipEntry;
|
import java.util.zip.ZipEntry;
|
||||||
import java.util.zip.ZipFile;
|
import java.util.zip.ZipFile;
|
||||||
import java.util.zip.ZipOutputStream;
|
import java.util.zip.ZipOutputStream;
|
||||||
|
import org.tensorflow.framework.GraphDef;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* SameDiff is the entrypoint for ND4J's automatic differentiation functionality.
|
* SameDiff is the entrypoint for ND4J's automatic differentiation functionality.
|
||||||
|
@ -2498,9 +2502,14 @@ public class SameDiff extends SDBaseOps {
|
||||||
//TODO only allowing null datatype for TF import (it's fixed in a later step) - don't want this in the public API!
|
//TODO only allowing null datatype for TF import (it's fixed in a later step) - don't want this in the public API!
|
||||||
public SDVariable var(@NonNull String name, @NonNull VariableType variableType, WeightInitScheme weightInitScheme,
|
public SDVariable var(@NonNull String name, @NonNull VariableType variableType, WeightInitScheme weightInitScheme,
|
||||||
org.nd4j.linalg.api.buffer.DataType dataType, long... shape) {
|
org.nd4j.linalg.api.buffer.DataType dataType, long... shape) {
|
||||||
String withScope = nameWithScope(name);
|
|
||||||
|
|
||||||
if (variables.containsKey(withScope)) {
|
|
||||||
|
if (name == null || name.length() < 1)
|
||||||
|
name = getNewVarName();
|
||||||
|
else
|
||||||
|
name = generateNewVarName(name, 0);
|
||||||
|
|
||||||
|
if (variables.containsKey(name)) {
|
||||||
if(nameScopes.isEmpty()){
|
if(nameScopes.isEmpty()){
|
||||||
throw new IllegalArgumentException("Another variable with the name " + name + " already exists (current name scope: \""
|
throw new IllegalArgumentException("Another variable with the name " + name + " already exists (current name scope: \""
|
||||||
+ currentNameScope() + "\"");
|
+ currentNameScope() + "\"");
|
||||||
|
@ -2509,9 +2518,6 @@ public class SameDiff extends SDBaseOps {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (name == null || name.length() < 1)
|
|
||||||
name = getNewVarName();
|
|
||||||
|
|
||||||
|
|
||||||
SDVariable ret = new SDVariable(name, variableType, this, shape, dataType, weightInitScheme);
|
SDVariable ret = new SDVariable(name, variableType, this, shape, dataType, weightInitScheme);
|
||||||
addVariable(ret);
|
addVariable(ret);
|
||||||
|
@ -2647,12 +2653,7 @@ public class SameDiff extends SDBaseOps {
|
||||||
}
|
}
|
||||||
|
|
||||||
private String getNewVarName() {
|
private String getNewVarName() {
|
||||||
String varName = "sd_var_" + String.valueOf(variableId);
|
return generateNewVarName("sd_var", 0, false);
|
||||||
while (variables.containsKey(varName)) {
|
|
||||||
variableId++;
|
|
||||||
varName = "sd_var_" + String.valueOf(variableId);
|
|
||||||
}
|
|
||||||
return varName;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -3442,37 +3443,6 @@ public class SameDiff extends SDBaseOps {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Generate a new variable name based on the uniqueness of the base name and arg index<br>
|
|
||||||
* For example, if baseName = "X" will return:<br>
|
|
||||||
* "X" if "X" does not already exist, or "X:argIndex" if argIndex > 0<br>
|
|
||||||
* "X_1" if "X" already exists, or "X_1:argIndex" if argIndex > 0<br>
|
|
||||||
* "X_2" if "X" and "X_1" already exists, or "X_2:argIndex" if argIndex > 0<br>
|
|
||||||
* And so on, until an unused name is found
|
|
||||||
*
|
|
||||||
* @param baseName the base name to use (use function.opName() where function is a {@link DifferentialFunction}
|
|
||||||
* @param argIndex the arg index
|
|
||||||
* @return the new generated name
|
|
||||||
*/
|
|
||||||
public String generateNewVarName(String baseName, int argIndex) {
|
|
||||||
if (!variables.containsKey(baseName) && argIndex == 0) {
|
|
||||||
return baseName;
|
|
||||||
}
|
|
||||||
|
|
||||||
//need to find a new name
|
|
||||||
int count = 0;
|
|
||||||
String name = baseName + (count == 0 ? "" : "_" + count) + (argIndex > 0 ? ":" + argIndex : "");
|
|
||||||
while (getVariable(name) != null) {
|
|
||||||
name = baseName + "_" + (++count) + (argIndex > 0 ? ":" + argIndex : "");
|
|
||||||
}
|
|
||||||
|
|
||||||
if (getVariable(name) != null) {
|
|
||||||
throw new ND4JIllegalStateException("Converged on already generated variable!");
|
|
||||||
}
|
|
||||||
return name;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Generate the variables based on the given input op and return the output variable names.
|
* Generate the variables based on the given input op and return the output variable names.
|
||||||
*
|
*
|
||||||
|
@ -3486,6 +3456,9 @@ public class SameDiff extends SDBaseOps {
|
||||||
if (baseName == null || baseName.isEmpty() && getBaseNameForFunction(function) != null)
|
if (baseName == null || baseName.isEmpty() && getBaseNameForFunction(function) != null)
|
||||||
baseName = getBaseNameForFunction(function);
|
baseName = getBaseNameForFunction(function);
|
||||||
|
|
||||||
|
if (baseName == null)
|
||||||
|
baseName = function.getOwnName();
|
||||||
|
|
||||||
if (baseName == null)
|
if (baseName == null)
|
||||||
baseName = function.opName();
|
baseName = function.opName();
|
||||||
|
|
||||||
|
@ -3594,7 +3567,8 @@ public class SameDiff extends SDBaseOps {
|
||||||
* @return the set of names generated for each output of the function.
|
* @return the set of names generated for each output of the function.
|
||||||
*/
|
*/
|
||||||
public SDVariable[] generateOutputVariableForOp(DifferentialFunction function) {
|
public SDVariable[] generateOutputVariableForOp(DifferentialFunction function) {
|
||||||
return generateOutputVariableForOp(function, function.opName(), false);
|
return generateOutputVariableForOp(function,
|
||||||
|
function.getOwnName() != null ? function.getOwnName() : function.opName(), false);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -4391,11 +4365,21 @@ public class SameDiff extends SDBaseOps {
|
||||||
throw new NullPointerException("Null input: No variable found for updating!");
|
throw new NullPointerException("Null input: No variable found for updating!");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if(newVarName != null) {
|
||||||
|
String nameScope = currentNameScope();
|
||||||
|
if (nameScope != null) {
|
||||||
|
if (!newVarName.startsWith(nameScope + "/")) {
|
||||||
|
newVarName = nameScope + "/" + newVarName;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if(newVarName != null && variables.containsKey(newVarName) && varToUpdate != variables.get(newVarName).getVariable()){
|
if(newVarName != null && variables.containsKey(newVarName) && varToUpdate != variables.get(newVarName).getVariable()){
|
||||||
throw new IllegalStateException("Variable name \"" + newVarName + "\" already exists for a different SDVariable");
|
throw new IllegalStateException("Variable name \"" + newVarName + "\" already exists for a different SDVariable");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (newVarName == null && variables.containsKey(varToUpdate.getVarName())) {
|
if (newVarName == null && variables.containsKey(varToUpdate.getVarName())
|
||||||
|
&& variables.get(varToUpdate.getVarName()).getVariable() != varToUpdate) {
|
||||||
//Edge case: suppose we do m1=sd.mean(in), m2=sd.mean(m1) -> both initially have the name
|
//Edge case: suppose we do m1=sd.mean(in), m2=sd.mean(m1) -> both initially have the name
|
||||||
// "mean" and consequently a new variable name needs to be generated
|
// "mean" and consequently a new variable name needs to be generated
|
||||||
newVarName = generateNewVarName(varToUpdate.getVarName(), 0);
|
newVarName = generateNewVarName(varToUpdate.getVarName(), 0);
|
||||||
|
@ -4405,13 +4389,6 @@ public class SameDiff extends SDBaseOps {
|
||||||
return varToUpdate;
|
return varToUpdate;
|
||||||
}
|
}
|
||||||
|
|
||||||
String nameScope = currentNameScope();
|
|
||||||
if(nameScope != null){
|
|
||||||
if(!newVarName.startsWith(nameScope)){
|
|
||||||
newVarName = nameScope + "/" + newVarName;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
val oldVarName = varToUpdate.getVarName();
|
val oldVarName = varToUpdate.getVarName();
|
||||||
varToUpdate.setVarName(newVarName);
|
varToUpdate.setVarName(newVarName);
|
||||||
updateVariableName(oldVarName, newVarName);
|
updateVariableName(oldVarName, newVarName);
|
||||||
|
@ -5680,4 +5657,130 @@ public class SameDiff extends SDBaseOps {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Import a frozen Tensorflow graph to a new SameDiff graph.
|
||||||
|
*
|
||||||
|
* @param graphFile The text or binary file containing the graph
|
||||||
|
* @return The imported graph
|
||||||
|
*/
|
||||||
|
public static SameDiff importFrozenTF(File graphFile){
|
||||||
|
return TFGraphMapper.getInstance().importGraph(graphFile);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #importFrozenTF(File)}
|
||||||
|
*/
|
||||||
|
public static SameDiff importFrozenTF(GraphDef graphDef){
|
||||||
|
return TFGraphMapper.getInstance().importGraph(graphDef);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #importFrozenTF(File)}
|
||||||
|
*
|
||||||
|
* Again, the input can be text or binary.
|
||||||
|
*/
|
||||||
|
public static SameDiff importFrozenTF(InputStream graph){
|
||||||
|
return TFGraphMapper.getInstance().importGraph(graph);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate a new, distinct op name of the form <base>_#.
|
||||||
|
*
|
||||||
|
* Applies name scope if active.
|
||||||
|
*
|
||||||
|
* @param base The base name to use
|
||||||
|
* @param force Whether to force the result name to be the same as base.
|
||||||
|
*/
|
||||||
|
public String getOpName(String base, boolean force){
|
||||||
|
|
||||||
|
base = nameWithScope(base);
|
||||||
|
|
||||||
|
if(force && ops.containsKey(base))
|
||||||
|
throw new IllegalArgumentException("Op with name \"" + base + "\" already exists");
|
||||||
|
else if(force)
|
||||||
|
return base;
|
||||||
|
|
||||||
|
int start = 1;
|
||||||
|
|
||||||
|
// if we already have a name like "op_2", start from trying "op_3"
|
||||||
|
if(base.contains("_")){
|
||||||
|
// extract number used to generate base
|
||||||
|
Matcher num = Pattern.compile("(.*)_(\\d+)").matcher(base);
|
||||||
|
// extract argIndex used to generate base
|
||||||
|
if(num.find()) {
|
||||||
|
start = Integer.parseInt(num.group(2));
|
||||||
|
base = num.group(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
String name = base;
|
||||||
|
for(int i = start ; true ; i++) {
|
||||||
|
|
||||||
|
// ensure that there are no variables that look like they are outputs of this op
|
||||||
|
boolean varWithName = false;
|
||||||
|
for(String varName : variables.keySet())
|
||||||
|
if(varName.startsWith(name + ":") || varName.equals(name))
|
||||||
|
varWithName = true;
|
||||||
|
|
||||||
|
if(!ops.containsKey(name) && !varWithName)
|
||||||
|
break;
|
||||||
|
|
||||||
|
name = base + "_" + i;
|
||||||
|
}
|
||||||
|
return name;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* See {@link #getOpName(String, boolean)}
|
||||||
|
* force is false
|
||||||
|
*/
|
||||||
|
public String getOpName(String base){
|
||||||
|
return getOpName(base, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate a new, distinct variable name of the form <base>_#[:#].
|
||||||
|
*
|
||||||
|
* Applies name scopes if active.
|
||||||
|
*
|
||||||
|
* @param base The base of the name.
|
||||||
|
* @param argIndex The argument index, used in the ":#". A value of 0 (or negative) does not include the ":#" part.
|
||||||
|
* @param existingOp Whether to generate an distinct operation name from base (if false), or just use base (if true).
|
||||||
|
*/
|
||||||
|
public String generateNewVarName(String base, int argIndex, boolean existingOp){
|
||||||
|
|
||||||
|
base = nameWithScope(base);
|
||||||
|
|
||||||
|
if(argIndex > 0 && base.contains(":")){
|
||||||
|
Matcher num = Pattern.compile("(.*):(\\d+)").matcher(base);
|
||||||
|
// extract argIndex used to generate base
|
||||||
|
if(num.find()) {
|
||||||
|
argIndex = Integer.parseInt(num.group(2)) + 1;
|
||||||
|
base = num.group(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if(!existingOp)
|
||||||
|
base = getOpName(base);
|
||||||
|
|
||||||
|
if(argIndex > 0)
|
||||||
|
base += ":" + argIndex;
|
||||||
|
|
||||||
|
if(variables.containsKey(base))
|
||||||
|
throw new IllegalArgumentException("Variable with name \"" + base + "\" already exists");
|
||||||
|
|
||||||
|
return base;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* See {@link #generateNewVarName(String, int, boolean)}
|
||||||
|
* existingOp is true.
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public String generateNewVarName(String base, int argIndex){
|
||||||
|
return generateNewVarName(base, argIndex, true);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -471,7 +471,10 @@ public class OpValidation {
|
||||||
backpropSeen.add(df.getClass());
|
backpropSeen.add(df.getClass());
|
||||||
}
|
}
|
||||||
for (Class c : backpropSeen) {
|
for (Class c : backpropSeen) {
|
||||||
gradCheckCoverageCountPerClass.put(c, gradCheckCoverageCountPerClass.get(c) + 1);
|
if(gradCheckCoverageCountPerClass.containsKey(c))
|
||||||
|
gradCheckCoverageCountPerClass.put(c, gradCheckCoverageCountPerClass.get(c) + 1);
|
||||||
|
else
|
||||||
|
gradCheckCoverageCountPerClass.put(c, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
//Collect coverage information for forward pass (expected outputs)
|
//Collect coverage information for forward pass (expected outputs)
|
||||||
|
@ -491,15 +494,23 @@ public class OpValidation {
|
||||||
|
|
||||||
if (seen != null) {
|
if (seen != null) {
|
||||||
for (Class c : seen) {
|
for (Class c : seen) {
|
||||||
fwdPassCoverageCountPerClass.put(c, fwdPassCoverageCountPerClass.get(c) + 1);
|
if(fwdPassCoverageCountPerClass.containsKey(c)) {
|
||||||
|
fwdPassCoverageCountPerClass.put(c, fwdPassCoverageCountPerClass.get(c) + 1);
|
||||||
|
} else {
|
||||||
|
fwdPassCoverageCountPerClass.put(c, 1);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private static void collectCoverageInformation(OpTestCase testCase) {
|
private static void collectCoverageInformation(OpTestCase testCase) {
|
||||||
//TODO we're basically assuming subtypes of DynamicCustomOp here, for coverage... not DCO itself
|
//TODO we're basically assuming subtypes of DynamicCustomOp here, for coverage... not DCO itself
|
||||||
singleOpTestCountPerClass.put(testCase.op().getClass(),
|
if(singleOpTestCountPerClass.containsKey(testCase.op().getClass())) {
|
||||||
singleOpTestCountPerClass.get(testCase.op().getClass()) + 1);
|
singleOpTestCountPerClass.put(testCase.op().getClass(),
|
||||||
|
singleOpTestCountPerClass.get(testCase.op().getClass()) + 1);
|
||||||
|
} else {
|
||||||
|
singleOpTestCountPerClass.put(testCase.op().getClass(), 1);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -70,7 +70,7 @@ public class GraphExecutionerTest extends BaseNd4jTest {
|
||||||
SameDiff sameDiff = SameDiff.create();
|
SameDiff sameDiff = SameDiff.create();
|
||||||
INDArray ones = Nd4j.ones(4);
|
INDArray ones = Nd4j.ones(4);
|
||||||
SDVariable sdVariable = sameDiff.var("ones",ones);
|
SDVariable sdVariable = sameDiff.var("ones",ones);
|
||||||
SDVariable result = sdVariable.addi(1.0);
|
SDVariable result = sdVariable.add(1.0);
|
||||||
SDVariable total = sameDiff.sum(result,Integer.MAX_VALUE);
|
SDVariable total = sameDiff.sum(result,Integer.MAX_VALUE);
|
||||||
|
|
||||||
val executioner = new NativeGraphExecutioner();
|
val executioner = new NativeGraphExecutioner();
|
||||||
|
@ -167,7 +167,7 @@ public class GraphExecutionerTest extends BaseNd4jTest {
|
||||||
SameDiff sameDiff = SameDiff.create();
|
SameDiff sameDiff = SameDiff.create();
|
||||||
INDArray ones = Nd4j.ones(4);
|
INDArray ones = Nd4j.ones(4);
|
||||||
SDVariable sdVariable = sameDiff.var("ones",ones);
|
SDVariable sdVariable = sameDiff.var("ones",ones);
|
||||||
SDVariable result = sdVariable.addi(1.0);
|
SDVariable result = sdVariable.add(1.0);
|
||||||
SDVariable total = sameDiff.sum(result,Integer.MAX_VALUE);
|
SDVariable total = sameDiff.sum(result,Integer.MAX_VALUE);
|
||||||
|
|
||||||
val executioner = new NativeGraphExecutioner();
|
val executioner = new NativeGraphExecutioner();
|
||||||
|
|
|
@ -91,7 +91,7 @@ public class FailingSameDiffTests extends BaseNd4jTest {
|
||||||
}, new SameDiffFunctionDefinition() {
|
}, new SameDiffFunctionDefinition() {
|
||||||
@Override
|
@Override
|
||||||
public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> inputs, SDVariable[] variableInputs) {
|
public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> inputs, SDVariable[] variableInputs) {
|
||||||
SDVariable ret = variableInputs[1].addi(1.0);
|
SDVariable ret = variableInputs[1].add(1.0);
|
||||||
return new SDVariable[]{variableInputs[0], ret};
|
return new SDVariable[]{variableInputs[0], ret};
|
||||||
}
|
}
|
||||||
}, new SDVariable[]{
|
}, new SDVariable[]{
|
||||||
|
@ -116,7 +116,7 @@ public class FailingSameDiffTests extends BaseNd4jTest {
|
||||||
}, new SameDiffFunctionDefinition() {
|
}, new SameDiffFunctionDefinition() {
|
||||||
@Override
|
@Override
|
||||||
public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> inputs, SDVariable[] variableInputs) {
|
public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> inputs, SDVariable[] variableInputs) {
|
||||||
SDVariable ret = variableInputs[1].addi(1.0);
|
SDVariable ret = variableInputs[1].add(1.0);
|
||||||
return new SDVariable[]{variableInputs[0], ret};
|
return new SDVariable[]{variableInputs[0], ret};
|
||||||
}
|
}
|
||||||
}, new SDVariable[]{
|
}, new SDVariable[]{
|
||||||
|
@ -197,7 +197,7 @@ public class FailingSameDiffTests extends BaseNd4jTest {
|
||||||
SDVariable w = sd.var("w", Nd4j.linspace(1,20,20, DataType.DOUBLE).reshape(4,5));
|
SDVariable w = sd.var("w", Nd4j.linspace(1,20,20, DataType.DOUBLE).reshape(4,5));
|
||||||
SDVariable b = sd.var("b", Nd4j.linspace(1,5,5, DataType.DOUBLE).reshape(1,5));
|
SDVariable b = sd.var("b", Nd4j.linspace(1,5,5, DataType.DOUBLE).reshape(1,5));
|
||||||
|
|
||||||
SDVariable mmul = sd.mmul(in,w).addi(b);
|
SDVariable mmul = sd.mmul(in,w).add(b);
|
||||||
INDArray exp = in.getArr().mmul(w.getArr()).addiRowVector(b.getArr());
|
INDArray exp = in.getArr().mmul(w.getArr()).addiRowVector(b.getArr());
|
||||||
|
|
||||||
INDArray out = sd.execAndEndResult();
|
INDArray out = sd.execAndEndResult();
|
||||||
|
|
|
@ -144,7 +144,7 @@ public class NameScopeTests extends BaseNd4jTest {
|
||||||
|
|
||||||
scope.close();
|
scope.close();
|
||||||
|
|
||||||
assertTrue("Var with name test/imax_1 exists", SD.variableMap().containsKey("test/imax_1"));
|
assertTrue("Var with name test/imax exists", SD.variableMap().containsKey("test/imax"));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
Loading…
Reference in New Issue