diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java index 5f5ab2a7c..f4d167866 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java @@ -657,19 +657,7 @@ public abstract class DifferentialFunction { if(sameDiff == null) this.ownName = UUID.randomUUID().toString(); else { - int argIndex = 0; - String scope = sameDiff.currentNameScope(); - if(scope == null) - scope = ""; - else - scope = scope + "/"; - String varName = scope + sameDiff.generateNewVarName(opName(),argIndex); - while(sameDiff.functionExists(varName)) { - varName = scope + sameDiff.generateNewVarName(opName(), argIndex); - argIndex++; - } - - this.ownName = varName; + this.ownName = sameDiff.getOpName(opName()); } if(sameDiff != null && !(this instanceof SDVariable)) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java index 8ac789f9b..f618b1186 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java @@ -91,10 +91,7 @@ public class SDVariable extends DifferentialFunction implements Serializable { " SDVariables - variable \"%s\" is of type %s but was provided a weight initialization scheme %s", varName, varType, weightInitScheme); Preconditions.checkState(dataType != DataType.UNKNOWN, "Unknown datatype is not allowed for SDVariables (variable name: %s)", varName); - String nameScope = sameDiff.currentNameScope(); - if(nameScope != null && !varName.startsWith(nameScope + "/")){ - varName = nameScope + "/" + varName; - } + varName = sameDiff.generateNewVarName(varName, 0, true); this.varName = varName; this.variableType = varType; @@ -656,7 +653,7 @@ public class SDVariable extends DifferentialFunction implements Serializable { * See {@link #add(String, double)} */ public SDVariable add(double scalar) { - return add(sameDiff.generateNewVarName(AddOp.OP_NAME,0),scalar); + return add(null,scalar); } /** @@ -676,7 +673,7 @@ public class SDVariable extends DifferentialFunction implements Serializable { * See {@link #add(String, SDVariable)} */ public SDVariable add(SDVariable other) { - return add(sameDiff.generateNewVarName(AddOp.OP_NAME,0),other); + return add(null,other); } /** @@ -713,7 +710,7 @@ public class SDVariable extends DifferentialFunction implements Serializable { * See {@link #sub(String, double)} */ public SDVariable sub(double scalar) { - return sub(sameDiff.generateNewVarName(SubOp.OP_NAME,0),scalar); + return sub(null,scalar); } /** @@ -733,7 +730,7 @@ public class SDVariable extends DifferentialFunction implements Serializable { * See {@link #sub(String, SDVariable)} */ public SDVariable sub(SDVariable x) { - return sub(sameDiff.generateNewVarName(SubOp.OP_NAME,0),x); + return sub(null,x); } /** @@ -770,7 +767,7 @@ public class SDVariable extends DifferentialFunction implements Serializable { * See {@link #div(String,double)} */ public SDVariable div(double scalar) { - return div(sameDiff.generateNewVarName(DivOp.OP_NAME,0),scalar); + return div(null,scalar); } /** @@ -790,7 +787,7 @@ public class SDVariable extends DifferentialFunction implements Serializable { * See {@link #div(String, SDVariable)} */ public SDVariable div(SDVariable x) { - return div(sameDiff.generateNewVarName(DivOp.OP_NAME,0),x); + return div(null,x); } /** @@ -811,7 +808,7 @@ public class SDVariable extends DifferentialFunction implements Serializable { * See {@link #mul(String, double)} */ public SDVariable mul(double scalar) { - return mul(sameDiff.generateNewVarName(MulOp.OP_NAME,0),scalar); + return mul(null,scalar); } /** @@ -832,7 +829,7 @@ public class SDVariable extends DifferentialFunction implements Serializable { * See {@link #mul(String, SDVariable)} */ public SDVariable mul(SDVariable x) { - return mul(sameDiff.generateNewVarName(MulOp.OP_NAME,0),x); + return mul(null,x); } /** @@ -889,7 +886,7 @@ public class SDVariable extends DifferentialFunction implements Serializable { * See {@link #rsub(String, double)} */ public SDVariable rsub(double scalar) { - return rsub(sameDiff.generateNewVarName(RSubOp.OP_NAME,0),scalar); + return rsub(null,scalar); } /** @@ -909,7 +906,7 @@ public class SDVariable extends DifferentialFunction implements Serializable { * See {@link #rsub(String, SDVariable)} */ public SDVariable rsub(SDVariable x) { - return rsub(sameDiff.generateNewVarName(RSubOp.OP_NAME,0),x); + return rsub(null,x); } /** @@ -930,7 +927,7 @@ public class SDVariable extends DifferentialFunction implements Serializable { * See {@link #rdiv(String, double)} */ public SDVariable rdiv(double scalar) { - return rdiv(sameDiff.generateNewVarName(RDivOp.OP_NAME,0),scalar); + return rdiv(null,scalar); } /** @@ -950,7 +947,7 @@ public class SDVariable extends DifferentialFunction implements Serializable { * See {@link #rdiv(String, SDVariable)} */ public SDVariable rdiv(SDVariable sameDiffVariable) { - return rdiv(sameDiff.generateNewVarName(RDivOp.OP_NAME,0),sameDiffVariable); + return rdiv(null,sameDiffVariable); } /** @@ -968,134 +965,13 @@ public class SDVariable extends DifferentialFunction implements Serializable { } - - /** - * - * @param sameDiffVariable - * @return - */ - public SDVariable rsubi(double sameDiffVariable) { - return rsubi(sameDiff.generateNewVarName(RSubOp.OP_NAME,0),sameDiffVariable); - - } - - /** - * - * @param sameDiffVariable - * @return - */ - public SDVariable rdivi(double sameDiffVariable) { - return rdivi(sameDiff.generateNewVarName(RDivOp.OP_NAME,0),sameDiffVariable); - - } - - /** - * - * @param sameDiffVariable - * @return - */ - public SDVariable addi(double sameDiffVariable) { - return addi(sameDiff.generateNewVarName(AddOp.OP_NAME,0),sameDiffVariable); - - } - - /** - * - * @param sameDiffVariable - * @return - */ - public SDVariable subi(double sameDiffVariable) { - return subi(sameDiff.generateNewVarName(SubOp.OP_NAME,0),sameDiffVariable); - - } - - /** - * - * @param sameDiffVariable - * @return - */ - public SDVariable divi(double sameDiffVariable) { - return divi(sameDiff.generateNewVarName(DivOp.OP_NAME,0),sameDiffVariable); - - } - - /** - * - * @param sameDiffVariable - * @return - */ - public SDVariable muli(double sameDiffVariable) { - return muli(sameDiff.generateNewVarName(MulOp.OP_NAME,0),sameDiffVariable); - - } - /** * * @param sameDiffVariable * @return */ public SDVariable truncatedDiv(SDVariable sameDiffVariable) { - return truncatedDiv(sameDiff.generateNewVarName(TruncateDivOp.OP_NAME,0),sameDiffVariable); - - } - - - /** - * - * @param sameDiffVariable - * @return - */ - public SDVariable rsubi(SDVariable sameDiffVariable) { - return rsubi(sameDiff.generateNewVarName(RSubOp.OP_NAME,0),sameDiffVariable); - - } - - /** - * - * @param sameDiffVariable - * @return - */ - public SDVariable rdivi(SDVariable sameDiffVariable) { - return rdivi(sameDiff.generateNewVarName(RDivOp.OP_NAME,0),sameDiffVariable); - - } - - /** - * - * @param sameDiffVariable - * @return - */ - public SDVariable addi(SDVariable sameDiffVariable) { - return addi(sameDiff.generateNewVarName(AddOp.OP_NAME,0),sameDiffVariable); - } - - /** - * - * @param sameDiffVariable - * @return - */ - public SDVariable subi(SDVariable sameDiffVariable) { - return subi(sameDiff.generateNewVarName(SubOp.OP_NAME,0),sameDiffVariable); - - } - - /** - * - * @param sameDiffVariable - * @return - */ - public SDVariable divi(SDVariable sameDiffVariable) { - return divi(sameDiff.generateNewVarName(DivOp.OP_NAME,0),sameDiffVariable); - - } - - /** - * - * @param sameDiffVariable - * @return - */ - public SDVariable muli(SDVariable sameDiffVariable) { - return muli(sameDiff.generateNewVarName(MulOp.OP_NAME,0),sameDiffVariable); + return truncatedDiv(null,sameDiffVariable); } @@ -1112,154 +988,17 @@ public class SDVariable extends DifferentialFunction implements Serializable { } - /** - * - * @param sameDiffVariable - * @return - */ - public SDVariable rsubi(String varName, double sameDiffVariable) { - val function = sameDiff.f().rsubi(this,sameDiffVariable); - return sameDiff.updateVariableNameAndReference(function,varName); - - } - - /** - * - * @param sameDiffVariable - * @return - */ - public SDVariable rdivi(String varName, double sameDiffVariable) { - SDVariable function = sameDiff.f().rdivi(this - ,sameDiffVariable); - return sameDiff.updateVariableNameAndReference(function,varName); - - } - - /** - * - * @param sameDiffVariable - * @return - */ - public SDVariable addi(String varName, double sameDiffVariable) { - val function = sameDiff.f().addi(this,sameDiffVariable); - return sameDiff.updateVariableNameAndReference(function,varName); - - } - - /** - * - * @param sameDiffVariable - * @return - */ - public SDVariable subi(String varName, double sameDiffVariable) { - val function = sameDiff.f().subi(this,sameDiffVariable); - return sameDiff.updateVariableNameAndReference(function,varName); - - } - - /** - * - * @param sameDiffVariable - * @return - */ - public SDVariable divi(String varName, double sameDiffVariable) { - val function = sameDiff.f().divi(this,sameDiffVariable); - return sameDiff.updateVariableNameAndReference(function,varName); - - } - - /** - * - * @param sameDiffVariable - * @return - */ - public SDVariable muli(String varName, double sameDiffVariable) { - val function = sameDiff.f().muli(this,sameDiffVariable); - return sameDiff.updateVariableNameAndReference(function,varName); - } - - - - - /** - * - * @param sameDiffVariable - * @return - */ - public SDVariable rsubi(String varName, SDVariable sameDiffVariable) { - val result = sameDiff.f().rsubi(this,sameDiffVariable); - return sameDiff.updateVariableNameAndReference(result,varName); - - } - - /** - * - * @param sameDiffVariable - * @return - */ - public SDVariable rdivi(String varName, SDVariable sameDiffVariable) { - val result = sameDiff.f().rdivi(this,sameDiffVariable); - return sameDiff.updateVariableNameAndReference(result,varName); - - } - - /** - * - * @param sameDiffVariable - * @return - */ - public SDVariable addi(String varName, SDVariable sameDiffVariable) { - val result = sameDiff.f().addi(this,sameDiffVariable); - return sameDiff.updateVariableNameAndReference(result,varName); - - } - @Override public Op.Type opType() { return Op.Type.RETURN; } - /** - * - * @param sameDiffVariable - * @return - */ - public SDVariable subi(String varName, SDVariable sameDiffVariable) { - SDVariable left = this; - SDVariable right = sameDiffVariable; - val result = sameDiff.f().subi(left,right); - return sameDiff.updateVariableNameAndReference(result,varName); - } - - /** - * - * @param sameDiffVariable - * @return - */ - public SDVariable divi(String varName, SDVariable sameDiffVariable) { - val result = sameDiff.f().divi(this,sameDiffVariable); - result.setVarName(varName); - return result; - } - - /** - * - * @param sameDiffVariable - * @return - */ - public SDVariable muli(String varName, SDVariable sameDiffVariable) { - SDVariable left = this; - SDVariable right = sameDiffVariable; - SDVariable result = sameDiff.f().muli(left,right); - result.setVarName(varName); - return result; - } /** * See {@link #squaredDifference(String, SDVariable)} */ public SDVariable squaredDifference(SDVariable x) { - return squaredDifference(sameDiff.generateNewVarName(SquaredDifferenceOp.OP_NAME,0),x); + return squaredDifference(null,x); } /** @@ -1780,6 +1519,16 @@ public class SDVariable extends DifferentialFunction implements Serializable { } + /** + * Evaluate the result of this variable + * @return + */ + public INDArray eval(Map placeholders) { + sameDiff.exec(placeholders, getVarName()); + return getArr(); + } + + @Override public String toString() { return "SDVariable(name=\"" + varName + "\",variableType=" + variableType + ",dtype=" + dataType + diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index 2b2ba63dc..0da607a42 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -22,6 +22,8 @@ import com.google.common.primitives.Ints; import com.google.flatbuffers.FlatBufferBuilder; import com.rits.cloning.Cloner; import com.rits.cloning.IFastCloner; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import lombok.*; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.FileUtils; @@ -43,6 +45,7 @@ import org.nd4j.autodiff.util.cloner.INDArrayFastCloner; import org.nd4j.base.Preconditions; import org.nd4j.evaluation.IEvaluation; import org.nd4j.graph.*; +import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.jackson.objectmapper.holder.ObjectMapperHolder; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.factory.DataBufferFactory; @@ -97,6 +100,7 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.zip.ZipEntry; import java.util.zip.ZipFile; import java.util.zip.ZipOutputStream; +import org.tensorflow.framework.GraphDef; /** * SameDiff is the entrypoint for ND4J's automatic differentiation functionality. @@ -2498,9 +2502,14 @@ public class SameDiff extends SDBaseOps { //TODO only allowing null datatype for TF import (it's fixed in a later step) - don't want this in the public API! public SDVariable var(@NonNull String name, @NonNull VariableType variableType, WeightInitScheme weightInitScheme, org.nd4j.linalg.api.buffer.DataType dataType, long... shape) { - String withScope = nameWithScope(name); - if (variables.containsKey(withScope)) { + + if (name == null || name.length() < 1) + name = getNewVarName(); + else + name = generateNewVarName(name, 0); + + if (variables.containsKey(name)) { if(nameScopes.isEmpty()){ throw new IllegalArgumentException("Another variable with the name " + name + " already exists (current name scope: \"" + currentNameScope() + "\""); @@ -2509,9 +2518,6 @@ public class SameDiff extends SDBaseOps { } } - if (name == null || name.length() < 1) - name = getNewVarName(); - SDVariable ret = new SDVariable(name, variableType, this, shape, dataType, weightInitScheme); addVariable(ret); @@ -2647,12 +2653,7 @@ public class SameDiff extends SDBaseOps { } private String getNewVarName() { - String varName = "sd_var_" + String.valueOf(variableId); - while (variables.containsKey(varName)) { - variableId++; - varName = "sd_var_" + String.valueOf(variableId); - } - return varName; + return generateNewVarName("sd_var", 0, false); } /** @@ -3442,37 +3443,6 @@ public class SameDiff extends SDBaseOps { } - /** - * Generate a new variable name based on the uniqueness of the base name and arg index
- * For example, if baseName = "X" will return:
- * "X" if "X" does not already exist, or "X:argIndex" if argIndex > 0
- * "X_1" if "X" already exists, or "X_1:argIndex" if argIndex > 0
- * "X_2" if "X" and "X_1" already exists, or "X_2:argIndex" if argIndex > 0
- * And so on, until an unused name is found - * - * @param baseName the base name to use (use function.opName() where function is a {@link DifferentialFunction} - * @param argIndex the arg index - * @return the new generated name - */ - public String generateNewVarName(String baseName, int argIndex) { - if (!variables.containsKey(baseName) && argIndex == 0) { - return baseName; - } - - //need to find a new name - int count = 0; - String name = baseName + (count == 0 ? "" : "_" + count) + (argIndex > 0 ? ":" + argIndex : ""); - while (getVariable(name) != null) { - name = baseName + "_" + (++count) + (argIndex > 0 ? ":" + argIndex : ""); - } - - if (getVariable(name) != null) { - throw new ND4JIllegalStateException("Converged on already generated variable!"); - } - return name; - } - - /** * Generate the variables based on the given input op and return the output variable names. * @@ -3486,6 +3456,9 @@ public class SameDiff extends SDBaseOps { if (baseName == null || baseName.isEmpty() && getBaseNameForFunction(function) != null) baseName = getBaseNameForFunction(function); + if (baseName == null) + baseName = function.getOwnName(); + if (baseName == null) baseName = function.opName(); @@ -3594,7 +3567,8 @@ public class SameDiff extends SDBaseOps { * @return the set of names generated for each output of the function. */ public SDVariable[] generateOutputVariableForOp(DifferentialFunction function) { - return generateOutputVariableForOp(function, function.opName(), false); + return generateOutputVariableForOp(function, + function.getOwnName() != null ? function.getOwnName() : function.opName(), false); } /** @@ -4391,11 +4365,21 @@ public class SameDiff extends SDBaseOps { throw new NullPointerException("Null input: No variable found for updating!"); } + if(newVarName != null) { + String nameScope = currentNameScope(); + if (nameScope != null) { + if (!newVarName.startsWith(nameScope + "/")) { + newVarName = nameScope + "/" + newVarName; + } + } + } + if(newVarName != null && variables.containsKey(newVarName) && varToUpdate != variables.get(newVarName).getVariable()){ throw new IllegalStateException("Variable name \"" + newVarName + "\" already exists for a different SDVariable"); } - if (newVarName == null && variables.containsKey(varToUpdate.getVarName())) { + if (newVarName == null && variables.containsKey(varToUpdate.getVarName()) + && variables.get(varToUpdate.getVarName()).getVariable() != varToUpdate) { //Edge case: suppose we do m1=sd.mean(in), m2=sd.mean(m1) -> both initially have the name // "mean" and consequently a new variable name needs to be generated newVarName = generateNewVarName(varToUpdate.getVarName(), 0); @@ -4405,13 +4389,6 @@ public class SameDiff extends SDBaseOps { return varToUpdate; } - String nameScope = currentNameScope(); - if(nameScope != null){ - if(!newVarName.startsWith(nameScope)){ - newVarName = nameScope + "/" + newVarName; - } - } - val oldVarName = varToUpdate.getVarName(); varToUpdate.setVarName(newVarName); updateVariableName(oldVarName, newVarName); @@ -5680,4 +5657,130 @@ public class SameDiff extends SDBaseOps { } } + /** + * Import a frozen Tensorflow graph to a new SameDiff graph. + * + * @param graphFile The text or binary file containing the graph + * @return The imported graph + */ + public static SameDiff importFrozenTF(File graphFile){ + return TFGraphMapper.getInstance().importGraph(graphFile); + } + + /** + * See {@link #importFrozenTF(File)} + */ + public static SameDiff importFrozenTF(GraphDef graphDef){ + return TFGraphMapper.getInstance().importGraph(graphDef); + } + + + /** + * See {@link #importFrozenTF(File)} + * + * Again, the input can be text or binary. + */ + public static SameDiff importFrozenTF(InputStream graph){ + return TFGraphMapper.getInstance().importGraph(graph); + } + + + /** + * Generate a new, distinct op name of the form <base>_#. + * + * Applies name scope if active. + * + * @param base The base name to use + * @param force Whether to force the result name to be the same as base. + */ + public String getOpName(String base, boolean force){ + + base = nameWithScope(base); + + if(force && ops.containsKey(base)) + throw new IllegalArgumentException("Op with name \"" + base + "\" already exists"); + else if(force) + return base; + + int start = 1; + + // if we already have a name like "op_2", start from trying "op_3" + if(base.contains("_")){ + // extract number used to generate base + Matcher num = Pattern.compile("(.*)_(\\d+)").matcher(base); + // extract argIndex used to generate base + if(num.find()) { + start = Integer.parseInt(num.group(2)); + base = num.group(1); + } + } + + String name = base; + for(int i = start ; true ; i++) { + + // ensure that there are no variables that look like they are outputs of this op + boolean varWithName = false; + for(String varName : variables.keySet()) + if(varName.startsWith(name + ":") || varName.equals(name)) + varWithName = true; + + if(!ops.containsKey(name) && !varWithName) + break; + + name = base + "_" + i; + } + return name; + } + + /** + * See {@link #getOpName(String, boolean)} + * force is false + */ + public String getOpName(String base){ + return getOpName(base, false); + } + + /** + * Generate a new, distinct variable name of the form <base>_#[:#]. + * + * Applies name scopes if active. + * + * @param base The base of the name. + * @param argIndex The argument index, used in the ":#". A value of 0 (or negative) does not include the ":#" part. + * @param existingOp Whether to generate an distinct operation name from base (if false), or just use base (if true). + */ + public String generateNewVarName(String base, int argIndex, boolean existingOp){ + + base = nameWithScope(base); + + if(argIndex > 0 && base.contains(":")){ + Matcher num = Pattern.compile("(.*):(\\d+)").matcher(base); + // extract argIndex used to generate base + if(num.find()) { + argIndex = Integer.parseInt(num.group(2)) + 1; + base = num.group(1); + } + } + + if(!existingOp) + base = getOpName(base); + + if(argIndex > 0) + base += ":" + argIndex; + + if(variables.containsKey(base)) + throw new IllegalArgumentException("Variable with name \"" + base + "\" already exists"); + + return base; + } + + /** + * + * See {@link #generateNewVarName(String, int, boolean)} + * existingOp is true. + */ + @Override + public String generateNewVarName(String base, int argIndex){ + return generateNewVarName(base, argIndex, true); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java index 33d03e490..586c279c1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java @@ -471,7 +471,10 @@ public class OpValidation { backpropSeen.add(df.getClass()); } for (Class c : backpropSeen) { - gradCheckCoverageCountPerClass.put(c, gradCheckCoverageCountPerClass.get(c) + 1); + if(gradCheckCoverageCountPerClass.containsKey(c)) + gradCheckCoverageCountPerClass.put(c, gradCheckCoverageCountPerClass.get(c) + 1); + else + gradCheckCoverageCountPerClass.put(c, 1); } //Collect coverage information for forward pass (expected outputs) @@ -491,15 +494,23 @@ public class OpValidation { if (seen != null) { for (Class c : seen) { - fwdPassCoverageCountPerClass.put(c, fwdPassCoverageCountPerClass.get(c) + 1); + if(fwdPassCoverageCountPerClass.containsKey(c)) { + fwdPassCoverageCountPerClass.put(c, fwdPassCoverageCountPerClass.get(c) + 1); + } else { + fwdPassCoverageCountPerClass.put(c, 1); + } } } } private static void collectCoverageInformation(OpTestCase testCase) { //TODO we're basically assuming subtypes of DynamicCustomOp here, for coverage... not DCO itself - singleOpTestCountPerClass.put(testCase.op().getClass(), - singleOpTestCountPerClass.get(testCase.op().getClass()) + 1); + if(singleOpTestCountPerClass.containsKey(testCase.op().getClass())) { + singleOpTestCountPerClass.put(testCase.op().getClass(), + singleOpTestCountPerClass.get(testCase.op().getClass()) + 1); + } else { + singleOpTestCountPerClass.put(testCase.op().getClass(), 1); + } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/execution/GraphExecutionerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/execution/GraphExecutionerTest.java index d684391c5..177a4b795 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/execution/GraphExecutionerTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/execution/GraphExecutionerTest.java @@ -70,7 +70,7 @@ public class GraphExecutionerTest extends BaseNd4jTest { SameDiff sameDiff = SameDiff.create(); INDArray ones = Nd4j.ones(4); SDVariable sdVariable = sameDiff.var("ones",ones); - SDVariable result = sdVariable.addi(1.0); + SDVariable result = sdVariable.add(1.0); SDVariable total = sameDiff.sum(result,Integer.MAX_VALUE); val executioner = new NativeGraphExecutioner(); @@ -167,7 +167,7 @@ public class GraphExecutionerTest extends BaseNd4jTest { SameDiff sameDiff = SameDiff.create(); INDArray ones = Nd4j.ones(4); SDVariable sdVariable = sameDiff.var("ones",ones); - SDVariable result = sdVariable.addi(1.0); + SDVariable result = sdVariable.add(1.0); SDVariable total = sameDiff.sum(result,Integer.MAX_VALUE); val executioner = new NativeGraphExecutioner(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FailingSameDiffTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FailingSameDiffTests.java index 235a8660b..bec6e0349 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FailingSameDiffTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/FailingSameDiffTests.java @@ -91,7 +91,7 @@ public class FailingSameDiffTests extends BaseNd4jTest { }, new SameDiffFunctionDefinition() { @Override public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable ret = variableInputs[1].addi(1.0); + SDVariable ret = variableInputs[1].add(1.0); return new SDVariable[]{variableInputs[0], ret}; } }, new SDVariable[]{ @@ -116,7 +116,7 @@ public class FailingSameDiffTests extends BaseNd4jTest { }, new SameDiffFunctionDefinition() { @Override public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { - SDVariable ret = variableInputs[1].addi(1.0); + SDVariable ret = variableInputs[1].add(1.0); return new SDVariable[]{variableInputs[0], ret}; } }, new SDVariable[]{ @@ -197,7 +197,7 @@ public class FailingSameDiffTests extends BaseNd4jTest { SDVariable w = sd.var("w", Nd4j.linspace(1,20,20, DataType.DOUBLE).reshape(4,5)); SDVariable b = sd.var("b", Nd4j.linspace(1,5,5, DataType.DOUBLE).reshape(1,5)); - SDVariable mmul = sd.mmul(in,w).addi(b); + SDVariable mmul = sd.mmul(in,w).add(b); INDArray exp = in.getArr().mmul(w.getArr()).addiRowVector(b.getArr()); INDArray out = sd.execAndEndResult(); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/NameScopeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/NameScopeTests.java index 3b1f57d33..659fc5438 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/NameScopeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/NameScopeTests.java @@ -144,7 +144,7 @@ public class NameScopeTests extends BaseNd4jTest { scope.close(); - assertTrue("Var with name test/imax_1 exists", SD.variableMap().containsKey("test/imax_1")); + assertTrue("Var with name test/imax exists", SD.variableMap().containsKey("test/imax")); } @Test