diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestSameDiffUI.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestSameDiffUI.java index 235e16054..d46e26818 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestSameDiffUI.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestSameDiffUI.java @@ -71,8 +71,8 @@ public class TestSameDiffUI extends BaseDL4JTest { lfw.registerEventName("accuracy"); lfw.registerEventName("precision"); long t = System.currentTimeMillis(); - for( int iter=0; iter<50; iter++) { - double d = Math.cos(0.1*iter); + for( int iter = 0; iter < 50; iter++) { + double d = Math.cos(0.1 * iter); d *= d; lfw.writeScalarEvent("accuracy", LogFileWriter.EventSubtype.EVALUATION, t + iter, iter, 0, d); @@ -84,7 +84,7 @@ public class TestSameDiffUI extends BaseDL4JTest { lfw.registerEventName("histogramDiscrete"); lfw.registerEventName("histogramEqualSpacing"); lfw.registerEventName("histogramCustomBins"); - for( int i=0; i<3; i++ ){ + for(int i = 0; i < 3; i++) { INDArray discreteY = Nd4j.createFromArray(0, 1, 2); lfw.writeHistogramEventDiscrete("histogramDiscrete", LogFileWriter.EventSubtype.TUNING_METRIC, t+i, i, 0, Arrays.asList("zero", "one", "two"), discreteY); diff --git a/libnd4j/include/ops/declarable/generic/shape/reshape.cpp b/libnd4j/include/ops/declarable/generic/shape/reshape.cpp index 6c4b6f369..be581bfc0 100644 --- a/libnd4j/include/ops/declarable/generic/shape/reshape.cpp +++ b/libnd4j/include/ops/declarable/generic/shape/reshape.cpp @@ -21,152 +21,178 @@ // #include -#if NOT_EXCLUDED(OP_reshape) + #if NOT_EXCLUDED(OP_reshape) -#include + #include -namespace sd { -namespace ops { + namespace sd { + namespace ops { ////////////////////////////////////////////////////////////////////////// // here iArgs is a vector with (optional) negative of order as first element: // ({-order, dim1, dim2, dim3, ...}) -CUSTOM_OP_IMPL(reshape, 1, 1, false, 0, -2) { + CUSTOM_OP_IMPL(reshape, 1, 1, false, 0, -2) { - auto x = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); + auto x = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - //Special case: empty.reshape() -> return empty - if (x->isEmpty()) { + //Special case: empty.reshape() -> return empty + if (x->isEmpty()) { REQUIRE_TRUE(z->isEmpty(), 0, "Reshape: when input is empty, output must also be empty"); return Status::OK(); //No op - } + } - REQUIRE_TRUE(x->lengthOf() == z->lengthOf(), 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), z->lengthOf()); + REQUIRE_TRUE(x->lengthOf() == z->lengthOf(), 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), z->lengthOf()); - if (Environment::getInstance().isDebugAndVerbose()) + if (Environment::getInstance().isDebugAndVerbose()) nd4j_printv("Reshape: new shape", z->getShapeAsVector()); - z->assign(x->reshape(z->ordering(), z->getShapeAsVector())); + z->assign(x->reshape(z->ordering(), z->getShapeAsVector())); - return Status::OK(); -} + return Status::OK(); + } -DECLARE_TYPES(reshape) { - getOpDescriptor() - ->setAllowedInputTypes(0, sd::DataType::ANY) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setSameMode(true); -} + DECLARE_TYPES(reshape) { + getOpDescriptor() + ->setAllowedInputTypes(0, sd::DataType::ANY) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setSameMode(true); + } -DECLARE_SHAPE_FN(reshape) { + DECLARE_SHAPE_FN(reshape) { - const auto x = INPUT_VARIABLE(0); + const auto x = INPUT_VARIABLE(0); - std::vector reshapeArgs; - std::vector shapeNew; - char orderNew = 'c'; - - if (block.width() == 1) { + std::vector reshapeArgs; + std::vector shapeNew; + char orderNew = 'c'; + /** + * NOTE: The value here is negative as a flag. + * A negative value signifies 1 of 3 values: + * -1 -> dynamic shape + * -99 -> c ordering + * -102 -> f ordering + * + */ + if (block.width() == 1) { reshapeArgs = *block.getIArguments(); if(!reshapeArgs.empty()) { - orderNew = (char) -reshapeArgs[0]; - if(orderNew == 'c' || orderNew == 'f') - reshapeArgs.erase(reshapeArgs.begin()); // remove first element being order in this case + char potentialOrdering = (char) -reshapeArgs[0]; + orderNew = potentialOrdering; + if(potentialOrdering != 'c' && potentialOrdering != 'f') { + throw std::runtime_error("reshape:: Value passed in must be -99 or -102 for the ordering if an int array is present. -99 represents c ordering and -102 represents f ordering. This number is negative for the long array case to flag the difference between an ordering and a dimension being specified."); } - } - else { + + nd4j_debug("Reshape Ordering is %c int ordering is %d\n",orderNew,-reshapeArgs[0]); + + if(orderNew == 'c' || orderNew == 'f') + reshapeArgs.erase(reshapeArgs.begin()); // remove first element being order in this case + } + } + else { reshapeArgs = INPUT_VARIABLE(1)->getBufferAsVector(); - orderNew = block.numI() > 0 ? (char) -INT_ARG(0) : 'c'; - } + if(block.numI() > 0) { + //Note here that the ordering for this case can not be negative. + // Negative is used in the long array case to be used as a flag to differntiate between a 99 or 102 shaped array and + //the ordering. You can't have a -99 or -102 shaped array. + char potentialOrdering = (char) reshapeArgs[0]; + if(potentialOrdering != 'c' && potentialOrdering != 'f') { + throw std::runtime_error("reshape:: Value passed in must be -99 or -102 for the ordering if an int array is present. -99 represents c ordering and -102 represents f ordering."); + } - REQUIRE_TRUE(!reshapeArgs.empty() || x->lengthOf() == 1, 0, "Reshape buffer should have at least 1 dimension !"); + orderNew = potentialOrdering; + } + else + orderNew = 'c'; + } - // Nd4jLong xLen = x->lengthOf(); - // if(x->isEmpty()) { - // xLen = 1; - // for (uint i = 0; i < x->rankOf(); ++i) // take into account possible empty shapes - // if(x->sizeAt(i) != 0) - // xLen *= x->sizeAt(i); - // } + REQUIRE_TRUE(!reshapeArgs.empty() || x->lengthOf() == 1, 0, "Reshape buffer should have at least 1 dimension !"); - // for (uint i = 0; i < reshapeArgs.size(); ++i) { + // Nd4jLong xLen = x->lengthOf(); + // if(x->isEmpty()) { + // xLen = 1; + // for (uint i = 0; i < x->rankOf(); ++i) // take into account possible empty shapes + // if(x->sizeAt(i) != 0) + // xLen *= x->sizeAt(i); + // } - // if (reshapeArgs[i] == -1) { + // for (uint i = 0; i < reshapeArgs.size(); ++i) { - // uint shapeLength = 1, numOfZeros = 0; + // if (reshapeArgs[i] == -1) { - // for(uint j = 0; j < i; ++j) - // if(reshapeArgs[j] != 0) - // shapeLength *= reshapeArgs[j]; - // else - // ++numOfZeros; + // uint shapeLength = 1, numOfZeros = 0; - // for(uint j = i + 1; j < reshapeArgs.size(); ++j) { - // REQUIRE_TRUE(reshapeArgs[j] != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed."); - // if(reshapeArgs[j] != 0) - // shapeLength *= reshapeArgs[j]; - // else - // ++numOfZeros; - // } + // for(uint j = 0; j < i; ++j) + // if(reshapeArgs[j] != 0) + // shapeLength *= reshapeArgs[j]; + // else + // ++numOfZeros; - // const auto dim = xLen / shapeLength; + // for(uint j = i + 1; j < reshapeArgs.size(); ++j) { + // REQUIRE_TRUE(reshapeArgs[j] != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed."); + // if(reshapeArgs[j] != 0) + // shapeLength *= reshapeArgs[j]; + // else + // ++numOfZeros; + // } - // if(x->isEmpty() && (1 == dim || 0 == numOfZeros)) - // shapeNew.push_back(0); - // else - // shapeNew.push_back(dim); - // } - // else - // shapeNew.push_back(reshapeArgs[i]); - // } + // const auto dim = xLen / shapeLength; - Nd4jLong newShapeLen = 1; - int pos = -1; - bool newShapeEmpty = false; + // if(x->isEmpty() && (1 == dim || 0 == numOfZeros)) + // shapeNew.push_back(0); + // else + // shapeNew.push_back(dim); + // } + // else + // shapeNew.push_back(reshapeArgs[i]); + // } - for (int i = 0; i < reshapeArgs.size(); ++i) { + Nd4jLong newShapeLen = 1; + int pos = -1; + bool newShapeEmpty = false; + + for (int i = 0; i < reshapeArgs.size(); ++i) { const int dim = reshapeArgs[i]; if (dim == -1) { - REQUIRE_TRUE(pos == -1, 0, "Reshape : Only one unknown dimension (-1) is allowed."); - pos = i; - shapeNew.push_back(1); + REQUIRE_TRUE(pos == -1, 0, "Reshape : Only one unknown dimension (-1) is allowed."); + pos = i; + shapeNew.push_back(1); } else if (dim == 0) { - shapeNew.push_back(0); - newShapeEmpty = true; + shapeNew.push_back(0); + newShapeEmpty = true; } else { - shapeNew.push_back(dim); - newShapeLen *= dim; + shapeNew.push_back(dim); + newShapeLen *= dim; + } } - } - if (pos != -1) { + if (pos != -1) { Nd4jLong xLen = x->lengthOf(); if(x->isEmpty()) { - xLen = 1; - for (uint i = 0; i < x->rankOf(); ++i) // take into account possible empty shapes - if(x->sizeAt(i) > 0 || !newShapeEmpty) - xLen *= x->sizeAt(i); + xLen = 1; + for (uint i = 0; i < x->rankOf(); ++i) // take into account possible empty shapes + if(x->sizeAt(i) > 0 || !newShapeEmpty) + xLen *= x->sizeAt(i); } shapeNew[pos] = xLen / newShapeLen; - } + } - auto len = shape::prodLong(shapeNew.data(), shapeNew.size()); - REQUIRE_TRUE(x->lengthOf() == len, 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), len); + auto len = shape::prodLong(shapeNew.data(), shapeNew.size()); + REQUIRE_TRUE(x->lengthOf() == len, 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), len); - return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(x->dataType(), orderNew, shapeNew)); -} + return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(x->dataType(), orderNew, shapeNew)); + } -} -} + } + } -#endif \ No newline at end of file + #endif \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/pom.xml b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/pom.xml index b32ba8a17..d3d707ab5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/pom.xml +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/pom.xml @@ -209,6 +209,7 @@ org.apache.maven.plugins maven-jar-plugin + 3.2.0 diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java index 3eb020766..d4bbb2d93 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java @@ -28,17 +28,23 @@ import lombok.val; import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.VariableType; +import org.nd4j.autodiff.samediff.internal.SameDiffOp; +import org.nd4j.autodiff.samediff.internal.Variable; import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper; import org.nd4j.common.base.Preconditions; import org.nd4j.imports.converters.DifferentialFunctionClassHolder; import org.nd4j.imports.descriptors.properties.AttributeAdapter; import org.nd4j.imports.descriptors.properties.PropertyMapping; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.Op; import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.exception.ND4JIllegalStateException; +import org.nd4j.linalg.indexing.functions.Zero; import org.nd4j.shade.jackson.annotation.JsonIgnore; +import org.nd4j.weightinit.impl.ZeroInitScheme; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -46,6 +52,8 @@ import org.tensorflow.framework.NodeDef; import java.lang.reflect.Field; import java.util.*; +import static org.nd4j.imports.VariableUtils.stripVarSuffix; + @Data @Slf4j @@ -261,7 +269,7 @@ public abstract class DifferentialFunction { target.set(o, value); } catch (IllegalAccessException e){ throw new RuntimeException("Error setting configuration field \"" + propertyName + "\" for config field \"" + propertyName - + "\" on class " + getClass().getName()); + + "\" on class " + getClass().getName()); } } else { @@ -549,7 +557,7 @@ public abstract class DifferentialFunction { public String[] argNames(){ SDVariable[] args = args(); String[] out = new String[args.length]; - for( int i=0; i funcs = this.variables.get(variableName).getInputsForOp(); - if (funcs == null) { - funcs = new ArrayList<>(); - this.variables.get(variableName).setInputsForOp(funcs); + if(this.variables.containsKey(variableName)) { + List funcs = this.variables.get(variableName).getInputsForOp(); + if (funcs == null) { + funcs = new ArrayList<>(); + this.variables.get(variableName).setInputsForOp(funcs); + } + if (!funcs.contains(function.getOwnName())) //Avoid duplicates for function names. + funcs.add(function.getOwnName()); } - if (!funcs.contains(function.getOwnName())) //Avoid duplicates for function names. - funcs.add(function.getOwnName()); + } } @@ -2541,7 +2544,7 @@ public class SameDiff extends SDBaseOps { validateListenerActivations(activeListeners, operation); - Map ret = directExecHelper(placeholders, At.defaultAt(operation), null, Collections.emptyList(), activeListeners, outputs); + Map ret = directExecHelper(placeholders, At.defaultAt(operation), null, Collections.emptyList(), activeListeners, outputs); for (Listener l : activeListeners) { l.operationEnd(this, operation); @@ -3316,16 +3319,18 @@ public class SameDiff extends SDBaseOps { /** * Rename the specified variable to the new name. - * + * Note here we also specify the op. + * Sometimes, ops have multiple outputs and after the first rename of the variable + * we lose the reference to the correct op to modify. + * @param opToReName the op to rename * @param from The variable to rename - this variable must exist * @param to The new name for the variable - no variable with this name must already exist */ - public void renameVariable(String from, String to) { + public void renameVariable(SameDiffOp opToReName,String from, String to) { Preconditions.checkState(variables.containsKey(from), "Cannot rename variable \"%s\": no variable with this name exists", from); Preconditions.checkState(!variables.containsKey(to), "Cannot rename variable \"%s\" to name \"%s\": a variable with name \"%s\" already exists", from, to, to); Variable v = variables.get(from); - SameDiffOp opToReName = ops.get(stripVarSuffix(from)); v.setName(to); v.getVariable().setVarName(to); if (v.getInputsForOp() != null) { @@ -3374,52 +3379,16 @@ public class SameDiff extends SDBaseOps { } if (v.getOutputOfOp() != null) { - SameDiffOp op = ops.get(stripVarSuffix(from)); - if(op != null && op.getOutputsOfOp() != null) { - List newOuts = new ArrayList<>(op.getOutputsOfOp()); - while (newOuts.contains(from)) { - newOuts.set(newOuts.indexOf(from), to); - } - - //find other aliases and ensure those get updated as well, - //after this any other versions of the op may not be discoverable - //due to the renaming - String strippedVarSuffix = stripVarSuffix(from); - for(int i = 0; i < newOuts.size(); i++) { - String newOut = newOuts.get(i); - if(stripVarSuffix(newOut).equals(strippedVarSuffix)) { - val idx = newOut.lastIndexOf(':'); - val newString = to + newOut.substring(idx); - newOuts.set(i,newString); - } - } - - op.setOutputsOfOp(newOuts); - } - else if(op != null) { - op.setOutputsOfOp(Arrays.asList(to)); + SameDiffOp op = ops.get(v.getOutputOfOp()); + List newOuts = new ArrayList<>(op.getOutputsOfOp()); + while (newOuts.contains(from)) { + newOuts.set(newOuts.indexOf(from), to); } + op.setOutputsOfOp(newOuts); } variables.remove(from); variables.put(to, v); - //set as just op name, update to set as the name of the output - if(opToReName != null && opToReName.getOp() != null && opToReName.getOp().isOwnNameSetWithDefault()) { - ops.remove(from); - opToReName.getOp().setOwnName(to); - ops.put(to,opToReName); - opToReName.setName(to); - } - - for(Variable variable : variables.values()) { - if(variable.getInputsForOp() != null && variable.getInputsForOp().contains(from)) { - variable.getInputsForOp().set(variable.getInputsForOp().indexOf(from),to); - } - - if(variable.getOutputOfOp() != null && variable.getOutputOfOp().equals(from)) { - variable.setOutputOfOp(to); - } - } if(v.getVariable().getVariableType() == VariableType.CONSTANT && constantArrays.hasArray(from)) { constantArrays.rename(from, to); @@ -3496,6 +3465,18 @@ public class SameDiff extends SDBaseOps { } + /** + * Rename the specified variable to the new name. + * + * @param from The variable to rename - this variable must exist + * @param to The new name for the variable - no variable with this name must already exist + */ + public void renameVariable(String from, String to) { + SameDiffOp op = ops.get(stripVarSuffix(from)); + renameVariable(op,from,to); + } + + /** * Remove an argument for a function. Note that if this function does not contain the argument, it will just be a no op. * @@ -4557,6 +4538,7 @@ public class SameDiff extends SDBaseOps { associateSameDiffWithOpsAndVariables(); } + /** * Try to infer the loss variable/s (usually loss variables). Note that this is not reliable in general. */ @@ -4604,16 +4586,19 @@ public class SameDiff extends SDBaseOps { return variables.get(varName).getVariable().isPlaceHolder(); } + /** * Updates the variable name property on the passed in variable, the reference in samediff, and returns the variable. *

* Note that if null for the new variable is passed in, it will just return the original input variable. - * + * @param opToRename note we pass in the op here for times when an op may have multiple outputs + * when this is the case, we need to pass in the op to rename otherwise context gets lost + * and subsequent rename attempts will not operate on the op. * @param varToUpdate the variable to update * @param newVarName the new variable name * @return the passed in variable */ - public SDVariable updateVariableNameAndReference(SDVariable varToUpdate, String newVarName) { + public SDVariable updateVariableNameAndReference(SameDiffOp opToRename,SDVariable varToUpdate, String newVarName) { if (varToUpdate == null) { throw new NullPointerException("Null input: No variable found for updating!"); } @@ -4644,10 +4629,24 @@ public class SameDiff extends SDBaseOps { val oldVarName = varToUpdate.name(); varToUpdate.setVarName(newVarName); - renameVariable(oldVarName, newVarName); + renameVariable(opToRename,oldVarName, newVarName); return varToUpdate; } + /** + * Updates the variable name property on the passed in variable, the reference in samediff, and returns the variable. + *

+ * Note that if null for the new variable is passed in, it will just return the original input variable. + * + * @param varToUpdate the variable to update + * @param newVarName the new variable name + * @return the passed in variable + */ + public SDVariable updateVariableNameAndReference(SDVariable varToUpdate, String newVarName) { + SameDiffOp op = ops.get(varToUpdate.name()); + return updateVariableNameAndReference(op,varToUpdate,newVarName); + } + /** * Updates the variable name property on the passed in variables, its reference in samediff, and returns the variable. @@ -4791,7 +4790,7 @@ public class SameDiff extends SDBaseOps { val flatNodes = new ArrayList(); // first of all we build VariableSpace dump - val variableList = new ArrayList(variables()); + val variableList = new ArrayList<>(variables()); val reverseMap = new LinkedHashMap(); val forwardMap = new LinkedHashMap(); val framesMap = new LinkedHashMap(); @@ -5244,18 +5243,18 @@ public class SameDiff extends SDBaseOps { Variable v2 = sd.variables.get(n); //Reconstruct control dependencies - if(v.controlDepsLength() > 0){ + if(v.controlDepsLength() > 0) { int num = v.controlDepsLength(); List l = new ArrayList<>(num); - for( int i=0; i 0){ + if(v.controlDepForOpLength() > 0) { int num = v.controlDepForOpLength(); List l = new ArrayList<>(num); - for( int i=0; i 0) { int num = v.controlDepsForVarLength(); List l = new ArrayList<>(num); - for( int i = 0; i < num; i++ ){ + for(int i = 0; i < num; i++) { l.add(v.controlDepsForVar(i)); } v2.setControlDepsForVar(l); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java index dd197a090..0b0079cb1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java @@ -35,6 +35,8 @@ import org.nd4j.common.function.Predicate; import java.util.*; +import static org.nd4j.imports.VariableUtils.stripVarSuffix; + @Slf4j public abstract class AbstractSession { @@ -156,7 +158,6 @@ public abstract class AbstractSession { //Step 2: Check that we have required placeholders List phNames = sameDiff.inputs(); - log.info("Placeholder names were " + phNames); if (placeholderValues == null || !placeholderValues.keySet().containsAll(phNames)) { /* We only have a subset of all placeholders Validate that we have all *required* placeholder values. Some might not be needed to calculate the requested outputs @@ -184,15 +185,9 @@ public abstract class AbstractSession { } if (required && (placeholderValues == null || !placeholderValues.containsKey(s))) { - if(placeholderValues != null) - throw new IllegalStateException( - "An input placeholder \"" + s + "\" is required to calculate the requested outputs," + - " but a placeholder value was not provided. Placeholders specified were " + placeholderValues.keySet()); - else { - throw new IllegalStateException( - "An input placeholder \"" + s + "\" is required to calculate the requested outputs," + - " but a placeholder value was not provided. Place holder values were null! "); - } + throw new IllegalStateException( + "An input placeholder \"" + s + "\" is required to calculate the requested outputs," + + " but a placeholder value was not provided"); } } } @@ -282,7 +277,7 @@ public abstract class AbstractSession { log.trace("Beginning execution step {}: {}", step, es); FrameIter outFrameIter; - boolean skipDepUpdate = false; //Only used for Switch ops, which have slightly different handling... + boolean skipDepUpdate = false; //Only used for Switch ops, which have slighly different handling... boolean skipMarkSatisfied = false; //Only for enter ops, because of different frame/iter if (es.getType() == ExecType.CONSTANT || es.getType() == ExecType.VARIABLE) { VarId vid = new VarId(es.getName(), OUTER_FRAME, 0, null); @@ -514,7 +509,7 @@ public abstract class AbstractSession { * Execution failed - can't calculate all requested outputs, and there's nothing left to calculate. * Throws an exception with a useful message * - * @param userRequestedUnique All outputs that the user reqeseted + * @param userRequestedUnique All outputs that the user requested * @param out Current outputs * @param step Execution step */ @@ -544,7 +539,7 @@ public abstract class AbstractSession { } } String s = sb.toString(); -// System.out.println(sameDiff.summary()); + System.out.println(sameDiff.summary()); throw new IllegalStateException(s); } @@ -564,6 +559,52 @@ public abstract class AbstractSession { List outNames = op.getOutputsOfOp(); for (String s : outNames) { Variable v = sameDiff.getVariables().get(s); + if(v != null) { + List inputsToOps = v.getInputsForOp(); + if (inputsToOps != null) { + for (String opName : inputsToOps) { + if (subgraphOps.contains(opName)) { + //We've just executed X, and there's dependency X -> Y + //But, there also might be a Z -> Y that we should mark as needed for Y + addDependenciesForOp(opName, outFrameIter); + } + } + } + + + //Also add control dependencies (variable) + List cdForOps = v.getControlDepsForOp(); + if (cdForOps != null) { + for (String opName : cdForOps) { + if (subgraphOps.contains(opName)) { + //We've just executed X, and there's dependency X -> Y + //But, there also might be a Z -> Y that we should mark as needed for Y + addDependenciesForOp(opName, outFrameIter); + } + } + } + } + + } + } else if (t == ExecType.VARIABLE || t == ExecType.CONSTANT || t == ExecType.PLACEHOLDER) { + Variable v = sameDiff.getVariables().get(n); + if(v != null) { + List inputsToOps = v.getInputsForOp(); + if (inputsToOps != null) { + for (String opName : inputsToOps) { + if (subgraphOps.contains(opName)) { + addDependenciesForOp(opName, outFrameIter); + } + } + } + } + + } else if (justExecuted.getType() == ExecType.SWITCH_L || justExecuted.getType() == ExecType.SWITCH_R) { + SameDiffOp op = sameDiff.getOps().get(n); + List outNames = op.getOutputsOfOp(); + String branchVarName = (justExecuted.getType() == ExecType.SWITCH_L ? outNames.get(0) : outNames.get(1)); + Variable v = sameDiff.getVariables().get(branchVarName); + if(v != null) { List inputsToOps = v.getInputsForOp(); if (inputsToOps != null) { for (String opName : inputsToOps) { @@ -574,45 +615,8 @@ public abstract class AbstractSession { } } } + } - - //Also add control dependencies (variable) - List cdForOps = v.getControlDepsForOp(); - if (cdForOps != null) { - for (String opName : cdForOps) { - if (subgraphOps.contains(opName)) { - //We've just executed X, and there's dependency X -> Y - //But, there also might be a Z -> Y that we should mark as needed for Y - addDependenciesForOp(opName, outFrameIter); - } - } - } - } - } else if (t == ExecType.VARIABLE || t == ExecType.CONSTANT || t == ExecType.PLACEHOLDER) { - Variable v = sameDiff.getVariables().get(n); - List inputsToOps = v.getInputsForOp(); - if (inputsToOps != null) { - for (String opName : inputsToOps) { - if (subgraphOps.contains(opName)) { - addDependenciesForOp(opName, outFrameIter); - } - } - } - } else if (justExecuted.getType() == ExecType.SWITCH_L || justExecuted.getType() == ExecType.SWITCH_R) { - SameDiffOp op = sameDiff.getOps().get(n); - List outNames = op.getOutputsOfOp(); - String branchVarName = (justExecuted.getType() == ExecType.SWITCH_L ? outNames.get(0) : outNames.get(1)); - Variable v = sameDiff.getVariables().get(branchVarName); - List inputsToOps = v.getInputsForOp(); - if (inputsToOps != null) { - for (String opName : inputsToOps) { - if (subgraphOps.contains(opName)) { - //We've just executed X, and there's dependency X -> Y - //But, there also might be a Z -> Y that we should mark as needed for Y - addDependenciesForOp(opName, outFrameIter); - } - } - } } else { throw new UnsupportedOperationException("Unknown or not yet implemented exec type: " + justExecuted); } @@ -691,8 +695,17 @@ public abstract class AbstractSession { return new ExecStep(ExecType.CONSTANT, v.getVariable().name(), new FrameIter(OUTER_FRAME, 0, null)); } else { //Array type. Must be output of an op + if(v.getOutputOfOp() == null) { + v = sameDiff.getVariables().get(stripVarSuffix(v.getName())); + } + String outOfOp = v.getOutputOfOp(); SameDiffOp sdo = sameDiff.getOps().get(outOfOp); + + if(sdo == null) { + throw new IllegalStateException("Samediff output op named " + v.getName() + " did not have any ops associated with it."); + } + if (sdo.getOp() instanceof Switch) { //For dependency tracking purposes, we track left and right output branches of switch op separately //Otherwise, ops depending both branches will be marked as available if we just rely on "op has been executed" @@ -771,7 +784,11 @@ public abstract class AbstractSession { if (!subgraph.contains(varName)) { String[] opInputs = opName == null ? null : sameDiff.getInputsForOp(sameDiff.getOpById(opName)); - List controlDeps = sameDiff.getVariables().get(varName).getControlDeps(); + Variable currVar = sameDiff.getVariables().get(varName); + log.trace("Adding " + varName + " to subgraph for output."); + List opInputsFor = currVar.getInputsForOp(); + List controlDeps = currVar.getControlDeps(); + String output = currVar.getOutputOfOp(); int numInputs = (opInputs == null ? 0 : opInputs.length); if (controlDeps != null) { //Also count variable control dependencies as inputs - even a constant may not be available for use @@ -781,6 +798,8 @@ public abstract class AbstractSession { if (numInputs == 0 && opName != null) { zeroInputOpsInSubgraph.add(opName); } + + subgraph.add(varName); if (opName != null) { @@ -796,11 +815,14 @@ public abstract class AbstractSession { } } } + + } if (opName != null) { //To execute op - and hence get this variable: need inputs to that op - String[] inputs = sameDiff.getInputsForOp(sameDiff.getOpById(opName)); + DifferentialFunction opById = sameDiff.getOpById(opName); + String[] inputs = sameDiff.getInputsForOp(opById); for (String s2 : inputs) { if (!subgraph.contains(s2)) { processingQueue.add(s2); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java index 12c42048f..930947772 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java @@ -263,7 +263,7 @@ public class InferenceSession extends AbstractSession inputsForOps = v.getInputsForOp(); if (inputsForOps != null) { for (String opName : inputsForOps) { @@ -799,9 +799,9 @@ public class InferenceSession extends AbstractSession uniqueArgNames = new HashSet<>(); Collections.addAll(uniqueArgNames, argNames); - Preconditions.checkState(uniqueArgNames.size() == (numNonConstIns + numConstPhIns + numNonConstInsAllIters), + /* Preconditions.checkState(uniqueArgNames.size() == (numNonConstIns + numConstPhIns + numNonConstInsAllIters), "Different number of arg names as op inputs for op %s (%s): arg names %s vs. op inputs %s+%s", df.getClass().getSimpleName(), - opName, uniqueArgNames, opInputs, constAndPhInputs); + opName, uniqueArgNames, opInputs, constAndPhInputs);*/ } else { Preconditions.checkState(numArgs == (numNonConstIns + numConstPhIns), "Different number of arg names as op inputs for op %s (%s): arg names %s vs. op inputs %s+%s", df.getClass().getSimpleName(), diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/SameDiffOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/SameDiffOp.java index d1f708bdd..a5d0487a9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/SameDiffOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/SameDiffOp.java @@ -20,7 +20,6 @@ package org.nd4j.autodiff.samediff.internal; -import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; import lombok.NoArgsConstructor; @@ -29,9 +28,7 @@ import org.nd4j.autodiff.functions.DifferentialFunction; import java.util.List; @Data -@AllArgsConstructor @NoArgsConstructor -@Builder public class SameDiffOp { protected String name; protected DifferentialFunction op; //Actual op (note: should be mutable: i.e., cloneable, no arrays set) @@ -40,4 +37,71 @@ public class SameDiffOp { protected List controlDeps; //Name of SDVariables as control dependencies (not data inputs, but need to be available before exec) protected List varControlDeps; //Variables (constants, placeholders, etc) that are control dependencies for this op protected List controlDepFor; //Name of the variables that this op is a control dependency for + + @Builder + public SameDiffOp(String name, DifferentialFunction op, List inputsToOp, List outputsOfOp, List controlDeps, List varControlDeps, List controlDepFor) { + this.name = name; + this.op = op; + this.inputsToOp = inputsToOp; + this.outputsOfOp = outputsOfOp; + this.controlDeps = controlDeps; + this.varControlDeps = varControlDeps; + this.controlDepFor = controlDepFor; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public DifferentialFunction getOp() { + return op; + } + + public void setOp(DifferentialFunction op) { + this.op = op; + } + + public List getInputsToOp() { + return inputsToOp; + } + + public void setInputsToOp(List inputsToOp) { + this.inputsToOp = inputsToOp; + } + + public List getOutputsOfOp() { + return outputsOfOp; + } + + public void setOutputsOfOp(List outputsOfOp) { + this.outputsOfOp = outputsOfOp; + } + + public List getControlDeps() { + return controlDeps; + } + + public void setControlDeps(List controlDeps) { + this.controlDeps = controlDeps; + } + + public List getVarControlDeps() { + return varControlDeps; + } + + public void setVarControlDeps(List varControlDeps) { + this.varControlDeps = varControlDeps; + } + + public List getControlDepFor() { + return controlDepFor; + } + + public void setControlDepFor(List controlDepFor) { + this.controlDepFor = controlDepFor; + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/GradCheckUtil.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/GradCheckUtil.java index 7945e59d8..52b04e910 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/GradCheckUtil.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/GradCheckUtil.java @@ -147,16 +147,16 @@ public class GradCheckUtil { listenerIdx = 0; } else { boolean found = false; - int i=0; - for(Listener l : listenersBefore){ - if(l instanceof NonInplaceValidationListener){ + int i = 0; + for(Listener l : listenersBefore) { + if(l instanceof NonInplaceValidationListener) { found = true; listenerIdx = i; break; } i++; } - if(!found){ + if(!found) { sd.addListeners(new NonInplaceValidationListener()); listenerIdx = i; } @@ -199,7 +199,7 @@ public class GradCheckUtil { int totalCount = 0; double maxError = 0.0; Random r = new Random(12345); - for(SDVariable s : sd.variables()){ + for(SDVariable s : sd.variables()) { if (fnOutputs.contains(s.name()) || !s.dataType().isFPType()) { //This is not an input to the graph, or is not a floating point input (so can't be gradient checked) continue; @@ -210,7 +210,7 @@ public class GradCheckUtil { continue; } - if(s.dataType() != DataType.DOUBLE){ + if(s.dataType() != DataType.DOUBLE) { log.warn("DataType for variable {} is not double (is: {}) may cause precision issues in gradient checks", s.name(), s.dataType()); } @@ -222,7 +222,7 @@ public class GradCheckUtil { } Iterator iter; - if(maxPerParam > 0 && subset != null && maxPerParam < a.length()){ + if(maxPerParam > 0 && subset != null && maxPerParam < a.length()) { //Subset case long[] shape = a.shape(); List l = new ArrayList<>(); @@ -243,7 +243,7 @@ public class GradCheckUtil { //Every N long everyN = n / maxPerParam; long curr = 0; - while(curr < n){ + while(curr < n) { long[] pos = Shape.ind2subC(shape, curr); l.add(pos); curr += everyN; @@ -262,8 +262,8 @@ public class GradCheckUtil { Preconditions.checkState(varMask.dataType() == DataType.BOOL, "Variable \"%s\": Gradient check mask must be BOOLEAN datatype, got %s", s.name(), varMask.dataType()); } - int i=0; - while(iter.hasNext()){ + int i = 0; + while(iter.hasNext()) { long[] idx = iter.next(); String strIdx = null; if(print){ @@ -593,7 +593,7 @@ public class GradCheckUtil { Set varSetStr = new HashSet<>(); for(SDVariable v : vars){ - if(varSetStr.contains(v.name())){ + if(varSetStr.contains(v.name())) { throw new IllegalStateException("Variable with name " + v.name() + " already encountered"); } varSetStr.add(v.name()); @@ -605,15 +605,15 @@ public class GradCheckUtil { Preconditions.checkState(dfs.length == ops.size(), "All functions not present in incomingArgsReverse"); for(DifferentialFunction df : dfs){ Preconditions.checkState(ops.containsKey(df.getOwnName()), df.getOwnName() + " not present in ops map"); - - List str = ops.get(df.getOwnName()).getInputsToOp(); + SameDiffOp sameDiffOp = ops.get(df.getOwnName()); + List str = sameDiffOp.getInputsToOp(); if(str != null) { for (String s : str) { Preconditions.checkState(varSetStr.contains(s), "Variable " + s + " in op inputs not a known variable name"); } } - str = ops.get(df.getOwnName()).getOutputsOfOp(); + str = sameDiffOp.getOutputsOfOp(); if(str != null) { for (String s : str) { Preconditions.checkState(varSetStr.contains(s), "Variable " + s + " in op outputs not a known variable name"); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java index 11f801dbc..8873e4b7a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java @@ -145,12 +145,12 @@ public class OpValidation { SameDiff sameDiff = testCase.sameDiff(); List listeners = sameDiff.getListeners(); - if(listeners.isEmpty()){ + if(listeners.isEmpty()) { sameDiff.addListeners(new NonInplaceValidationListener()); } else { boolean found = false; for(Listener l : listeners){ - if(l instanceof NonInplaceValidationListener){ + if(l instanceof NonInplaceValidationListener) { found = true; break; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/ir/MapperNamespace.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/ir/MapperNamespace.java index d9e42fbb9..0f1f7d837 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/ir/MapperNamespace.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/ir/MapperNamespace.java @@ -1,22 +1,5 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: mapper.proto package org.nd4j.ir; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/ir/OpNamespace.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/ir/OpNamespace.java index caa2f36b6..a7e826739 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/ir/OpNamespace.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/ir/OpNamespace.java @@ -1,22 +1,5 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: op.proto package org.nd4j.ir; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/ir/TensorNamespace.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/ir/TensorNamespace.java index ca20c2b41..434bda3a8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/ir/TensorNamespace.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/ir/TensorNamespace.java @@ -1,22 +1,5 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tensor.proto package org.nd4j.ir; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/MmulBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/MmulBp.java index e94e09fd9..58ed6f027 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/MmulBp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/MmulBp.java @@ -89,7 +89,7 @@ public class MmulBp extends DynamicCustomOp { @Override - public List calculateOutputDataTypes(List dataTypes){ + public List calculateOutputDataTypes(List dataTypes) { Preconditions.checkState(dataTypes != null && dataTypes.size() == 3, "Expected exactly 3 inputs to matmul_bp op, got %s", dataTypes); Preconditions.checkState(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType() && dataTypes.get(0).isFPType(), "Inputs to matmul_bp op must both be a floating" + "point type: got %s", dataTypes); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java index c79c71713..e0cff5ab3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java @@ -49,6 +49,11 @@ public class Reshape extends DynamicCustomOp { public Reshape(SameDiff sameDiff, SDVariable i_v, long[] shape) { super(null, sameDiff, new SDVariable[]{i_v}); this.shape = shape; + //c ordering: see (char) 99 for c ordering and (char) 'f' is 102 + //note it has to be negative for the long array case only + //to flag the difference between an ordering being specified + //and a dimension. + addIArgument(-99); addIArgument(shape); } @@ -59,6 +64,11 @@ public class Reshape extends DynamicCustomOp { public Reshape(INDArray in, long... shape) { super(new INDArray[]{in}, null); this.shape = shape; + //c ordering: see (char) 99 for c ordering and (char) 'f' is 102 + //note it has to be negative for the long array case only + //to flag the difference between an ordering being specified + //and a dimension. + addIArgument(-99); addIArgument(shape); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java index d0f6850d9..737b44b5f 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java @@ -1725,7 +1725,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { for( int i = 0; i < result.size(); i++) { arr[i] = result.get(i).toString(); } - log.trace("Calculated output shapes for op {} - {}", op.getClass().getName(), Arrays.toString(arr)); + + DifferentialFunction differentialFunction = (DifferentialFunction) op; + log.trace("Calculated output shapes for op of name {} and type {} - {}",differentialFunction.getOwnName(), op.getClass().getName(), Arrays.toString(arr)); } return result; } @@ -2022,7 +2024,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { @Override public DataBuffer createShapeInfo(long[] shape, long[] stride, long elementWiseStride, char order, DataType dtype, long extras) { - val dbf = loop.shapeBufferEx(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, extras); + OpaqueConstantShapeBuffer dbf = loop.shapeBufferEx(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, extras); if (loop.lastErrorCode() != 0) throw new RuntimeException(loop.lastErrorMessage()); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index f1dad2e8f..457e31eff 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -1,22 +1,4 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ +// Targeted by JavaCPP version 1.5.4: DO NOT EDIT THIS FILE package org.nd4j.nativeblas; @@ -305,13 +287,15 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper { // Parsed from array/DataType.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -358,13 +342,15 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper { // Parsed from array/DataBuffer.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -581,13 +567,15 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper { // Parsed from array/ConstantDataBuffer.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -756,13 +744,15 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper { // Parsed from array/ConstantDescriptor.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -842,13 +832,15 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper { // Parsed from array/TadPack.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -910,13 +902,15 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper { // Parsed from execution/ErrorReference.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -1036,13 +1030,15 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper { // Parsed from system/Environment.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -1157,13 +1153,15 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper { // Parsed from types/utf8string.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -1222,13 +1220,15 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper { // Parsed from legacy/NativeOps.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -3464,13 +3464,15 @@ public native @Cast("char*") String buildInfo(); // Parsed from memory/ExternalWorkspace.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -3522,13 +3524,15 @@ public native @Cast("char*") String buildInfo(); // Parsed from memory/Workspace.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -3616,13 +3620,15 @@ public native @Cast("char*") String buildInfo(); // Parsed from indexing/NDIndex.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -3722,13 +3728,15 @@ public native @Cast("char*") String buildInfo(); // Parsed from indexing/IndicesList.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -3764,13 +3772,15 @@ public native @Cast("char*") String buildInfo(); // Parsed from graph/VariableType.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -3800,13 +3810,15 @@ public native @Cast("char*") String buildInfo(); // Parsed from graph/ArgumentsList.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -3868,13 +3880,15 @@ public native @Cast("char*") String buildInfo(); // Parsed from types/pair.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -3922,13 +3936,15 @@ public native @Cast("char*") String buildInfo(); // Parsed from array/NDArray.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -5210,13 +5226,15 @@ public native @Cast("char*") String buildInfo(); // Parsed from array/NDArrayList.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -5285,13 +5303,15 @@ public native @Cast("char*") String buildInfo(); // Parsed from array/ResultSet.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -5363,13 +5383,15 @@ public native @Cast("char*") String buildInfo(); // Parsed from graph/RandomGenerator.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -5513,13 +5535,15 @@ public native @Cast("char*") String buildInfo(); // Parsed from graph/Variable.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -5635,13 +5659,15 @@ public native @Cast("char*") String buildInfo(); // Parsed from graph/VariablesSet.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -5692,13 +5718,15 @@ public native @Cast("char*") String buildInfo(); // Parsed from graph/FlowPath.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -5784,13 +5812,15 @@ public native @Cast("char*") String buildInfo(); // Parsed from graph/Intervals.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -5851,13 +5881,15 @@ public native @Cast("char*") String buildInfo(); // Parsed from graph/Stash.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -5959,13 +5991,15 @@ public native @Cast("char*") String buildInfo(); // Parsed from graph/GraphState.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -6075,13 +6109,15 @@ public native @Cast("char*") String buildInfo(); // Parsed from graph/VariableSpace.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -6201,13 +6237,15 @@ public native @Cast("char*") String buildInfo(); // Parsed from helpers/helper_generator.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -6450,13 +6488,15 @@ public native @Cast("char*") String buildInfo(); // Parsed from graph/profiling/GraphProfile.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -6559,13 +6599,15 @@ public native @Cast("char*") String buildInfo(); // Parsed from graph/profiling/NodeProfile.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -7009,13 +7051,15 @@ public native @Cast("char*") String buildInfo(); // Parsed from graph/ResultWrapper.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -7056,13 +7100,15 @@ public native @Cast("char*") String buildInfo(); // Parsed from helpers/shape.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -9558,13 +9604,15 @@ public static final int PREALLOC_SIZE = 33554432; // Parsed from helpers/OpArgsHolder.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -9665,13 +9713,15 @@ public static final int PREALLOC_SIZE = 33554432; // Parsed from array/ShapeList.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -9746,13 +9796,15 @@ public static final int PREALLOC_SIZE = 33554432; // Parsed from system/type_boilerplate.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -10392,13 +10444,15 @@ public static final int ALL_FLOATS =BFLOAT16; // Parsed from system/op_boilerplate.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -11971,13 +12025,15 @@ public static final double TAD_THRESHOLD = TAD_THRESHOLD(); // Parsed from ops/InputType.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -12007,13 +12063,15 @@ public static final double TAD_THRESHOLD = TAD_THRESHOLD(); // Parsed from ops/declarable/OpDescriptor.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -12162,13 +12220,15 @@ public static final double TAD_THRESHOLD = TAD_THRESHOLD(); // Parsed from ops/declarable/PlatformHelper.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -12248,13 +12308,15 @@ public static final double TAD_THRESHOLD = TAD_THRESHOLD(); // Parsed from ops/declarable/BroadcastableOp.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -12292,13 +12354,15 @@ public static final double TAD_THRESHOLD = TAD_THRESHOLD(); // Parsed from ops/declarable/BroadcastableBoolOp.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -12336,13 +12400,15 @@ public static final double TAD_THRESHOLD = TAD_THRESHOLD(); // Parsed from ops/declarable/DeclarableOp.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -12513,13 +12579,15 @@ public static final double TAD_THRESHOLD = TAD_THRESHOLD(); // Parsed from ops/declarable/DeclarableListOp.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -12561,13 +12629,15 @@ public static final double TAD_THRESHOLD = TAD_THRESHOLD(); // Parsed from ops/declarable/DeclarableReductionOp.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -12601,13 +12671,15 @@ public static final double TAD_THRESHOLD = TAD_THRESHOLD(); // Parsed from ops/declarable/DeclarableCustomOp.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -12641,13 +12713,15 @@ public static final double TAD_THRESHOLD = TAD_THRESHOLD(); // Parsed from ops/declarable/BooleanOp.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -12689,13 +12763,15 @@ public static final double TAD_THRESHOLD = TAD_THRESHOLD(); // Parsed from ops/declarable/LogicOp.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -12741,13 +12817,15 @@ public static final double TAD_THRESHOLD = TAD_THRESHOLD(); // Parsed from ops/declarable/OpRegistrator.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -12838,13 +12916,15 @@ public static final double TAD_THRESHOLD = TAD_THRESHOLD(); // Parsed from ops/declarable/CustomOperations.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -13036,13 +13116,15 @@ public static final double TAD_THRESHOLD = TAD_THRESHOLD(); // Parsed from ops/declarable/headers/activations.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -13865,13 +13947,15 @@ public static final double TAD_THRESHOLD = TAD_THRESHOLD(); // Parsed from ops/declarable/headers/boolean.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -14250,13 +14334,15 @@ public static final double TAD_THRESHOLD = TAD_THRESHOLD(); // Parsed from ops/declarable/headers/broadcastable.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -15423,13 +15509,15 @@ public static final double TAD_THRESHOLD = TAD_THRESHOLD(); // Parsed from ops/declarable/headers/convo.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -16382,13 +16470,15 @@ public static final double TAD_THRESHOLD = TAD_THRESHOLD(); // Parsed from ops/declarable/headers/list.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -16689,13 +16779,15 @@ public static final double TAD_THRESHOLD = TAD_THRESHOLD(); // Parsed from ops/declarable/headers/recurrent.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -17489,13 +17581,15 @@ public static final double TAD_THRESHOLD = TAD_THRESHOLD(); // Parsed from ops/declarable/headers/transforms.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -23043,13 +23137,15 @@ public static final double TAD_THRESHOLD = TAD_THRESHOLD(); // Parsed from ops/declarable/headers/shape.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -23712,13 +23808,15 @@ public static final double TAD_THRESHOLD = TAD_THRESHOLD(); // Parsed from ops/declarable/headers/nn.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -24260,13 +24358,15 @@ public static final double TAD_THRESHOLD = TAD_THRESHOLD(); // Parsed from ops/declarable/headers/blas.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -24524,13 +24624,15 @@ public static final double TAD_THRESHOLD = TAD_THRESHOLD(); // Parsed from ops/declarable/headers/tests.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -24673,13 +24775,15 @@ public static final double TAD_THRESHOLD = TAD_THRESHOLD(); // Parsed from ops/declarable/headers/bitwise.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -24949,13 +25053,15 @@ public static final double TAD_THRESHOLD = TAD_THRESHOLD(); // Parsed from ops/declarable/headers/loss.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -25739,13 +25845,15 @@ public static final double TAD_THRESHOLD = TAD_THRESHOLD(); // Parsed from ops/declarable/headers/datatypes.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -26005,13 +26113,15 @@ public static final double TAD_THRESHOLD = TAD_THRESHOLD(); // Parsed from execution/ContextBuffers.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -26085,13 +26195,15 @@ public static final double TAD_THRESHOLD = TAD_THRESHOLD(); // Parsed from execution/LaunchContext.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -26350,13 +26462,15 @@ public static final int SHAPE_DESC_INCORRECT_RANK = 4; //rank > 32 or shape size // Parsed from array/TadDescriptor.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -26435,13 +26549,15 @@ public static final int SHAPE_DESC_INCORRECT_RANK = 4; //rank > 32 or shape size // Parsed from helpers/DebugInfo.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the @@ -26507,13 +26623,15 @@ public static final int SHAPE_DESC_INCORRECT_RANK = 4; //rank > 32 or shape size // Parsed from ops/declarable/headers/third_party.h -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. +/* ****************************************************************************** + * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the diff --git a/nd4j/nd4j-backends/nd4j-tests/ops-removed-old.txt b/nd4j/nd4j-backends/nd4j-tests/ops-removed-old.txt index 870f040eb..201dc67b4 100644 --- a/nd4j/nd4j-backends/nd4j-tests/ops-removed-old.txt +++ b/nd4j/nd4j-backends/nd4j-tests/ops-removed-old.txt @@ -1,3 +1,19 @@ -alpha -Sum/reduction_indices -Sum +in_0 +while/Const +while/add/y +in_0/read +while/Enter +while/Enter_1 +while/Merge +while/Merge_1 +while/Less +while/LoopCond +while/Switch +while/Switch_1 +while/Identity +while/Exit +while/Identity_1 +while/Exit_1 +while/add +while/NextIteration_1 +while/NextIteration diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestOpMapping.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestOpMapping.java index daab61aa2..85515321b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestOpMapping.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestOpMapping.java @@ -121,6 +121,7 @@ import java.util.Set; import java.util.regex.Matcher; import java.util.regex.Pattern; +@Ignore("No longer relevant after model import rewrite.") public class TestOpMapping extends BaseNd4jTest { Set> subTypes; @@ -151,7 +152,7 @@ public class TestOpMapping extends BaseNd4jTest { Map onnxOpNameMapping = ImportClassMapping.getOnnxOpMappingFunctions(); - for(Class c : subTypes){ + for(Class c : subTypes) { if(Modifier.isAbstract(c.getModifiers()) || Modifier.isInterface(c.getModifiers()) || ILossFunction.class.isAssignableFrom(c)) continue; diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/execution/GraphExecutionerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/execution/GraphExecutionerTest.java deleted file mode 100644 index ac087f2fc..000000000 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/execution/GraphExecutionerTest.java +++ /dev/null @@ -1,178 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.nd4j.autodiff.execution; - -import lombok.extern.slf4j.Slf4j; -import lombok.val; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Test; -import org.nd4j.OpValidationSuite; -import org.nd4j.autodiff.execution.conf.ExecutionMode; -import org.nd4j.autodiff.execution.conf.ExecutorConfiguration; -import org.nd4j.autodiff.execution.conf.OutputMode; -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.linalg.BaseNd4jTest; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.executioner.OpExecutioner; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.factory.Nd4jBackend; - -import java.io.DataOutputStream; -import java.io.FileOutputStream; -import java.nio.ByteBuffer; - -import static org.junit.Assert.assertEquals; - -@Slf4j -public class GraphExecutionerTest extends BaseNd4jTest { - - public GraphExecutionerTest(Nd4jBackend b){ - super(b); - } - - @Override - public char ordering(){ - return 'c'; - } - - protected static ExecutorConfiguration configVarSpace = ExecutorConfiguration.builder().outputMode(OutputMode.VARIABLE_SPACE).build(); - protected static ExecutorConfiguration configExplicit = ExecutorConfiguration.builder().outputMode(OutputMode.EXPLICIT).build(); - protected static ExecutorConfiguration configImplicit = ExecutorConfiguration.builder().outputMode(OutputMode.IMPLICIT).build(); - - @Before - public void setUp() { - // - } - - @Test - @Ignore - public void testConversion() throws Exception { - SameDiff sameDiff = SameDiff.create(); - INDArray ones = Nd4j.ones(4); - SDVariable sdVariable = sameDiff.var("ones",ones); - SDVariable result = sdVariable.add(1.0); - SDVariable total = sameDiff.sum(result,Integer.MAX_VALUE); - - val executioner = new NativeGraphExecutioner(); - - ByteBuffer buffer = executioner.convertToFlatBuffers(sameDiff, ExecutorConfiguration.builder().profilingMode(OpExecutioner.ProfilingMode.DISABLED).executionMode(ExecutionMode.SEQUENTIAL).outputMode(OutputMode.IMPLICIT).build()); - - val offset = buffer.position(); - val array = buffer.array(); - - try (val fos = new FileOutputStream("../../libnd4j/tests/resources/adam_sum.fb"); val dos = new DataOutputStream(fos)) { - dos.write(array, offset, array.length - offset); - } - - - //INDArray[] res = executioner.executeGraph(sameDiff); - //assertEquals(8.0, res[0].getDouble(0), 1e-5); - /* - INDArray output = null; - for(int i = 0; i < 5; i++) { - output = sameDiff.execAndEndResult(ops); - System.out.println("Ones " + ones); - System.out.println(output); - } - - assertEquals(Nd4j.valueArrayOf(4,7),ones); - assertEquals(28,output.getDouble(0),1e-1); - */ - } - - - /** - * VarSpace should dump everything. 4 variables in our case - * @throws Exception - */ - @Test - public void testEquality1() { - OpValidationSuite.ignoreFailing(); //Failing 2019/01/24 - GraphExecutioner executionerA = new BasicGraphExecutioner(); - GraphExecutioner executionerB = new NativeGraphExecutioner(); - - SameDiff sameDiff = SameDiff.create(); - INDArray ones = Nd4j.ones(4); - SDVariable sdVariable = sameDiff.var("ones",ones); - SDVariable scalarOne = sameDiff.var("scalar",Nd4j.scalar(1.0)); - SDVariable result = sdVariable.add(scalarOne); - SDVariable total = sameDiff.sum(result,Integer.MAX_VALUE); - - log.info("TOTAL: {}; Id: {}", total.name(), total); - - INDArray[] resB = executionerB.executeGraph(sameDiff, configVarSpace); - - //Variables: ones, scalar, result, total - assertEquals(sameDiff.variables().size(), resB.length); - assertEquals(Nd4j.ones(4), resB[0]); - assertEquals(Nd4j.scalar(1), resB[1]); - assertEquals(Nd4j.create(new float[]{2f, 2f, 2f, 2f}), resB[2]); - assertEquals(Nd4j.scalar(8.0), resB[3]); - } - - - /** - * Implicit should return tree edges. So, one variable - * @throws Exception - */ - @Test - public void testEquality2() { - OpValidationSuite.ignoreFailing(); //Failing 2019/01/24 - GraphExecutioner executionerA = new BasicGraphExecutioner(); - GraphExecutioner executionerB = new NativeGraphExecutioner(); - - SameDiff sameDiff = SameDiff.create(); - INDArray ones = Nd4j.ones(4); - SDVariable sdVariable = sameDiff.var("ones",ones); - SDVariable scalarOne = sameDiff.var("add1",Nd4j.scalar(1.0)); - SDVariable result = sdVariable.add(scalarOne); - SDVariable total = sameDiff.sum(result,Integer.MAX_VALUE); - -// log.info("ID: {}",sameDiff.getGraph().getVertex(1).getValue().getId()); - - INDArray[] resB = executionerB.executeGraph(sameDiff, configImplicit); - - assertEquals(1, resB.length); - assertEquals(Nd4j.scalar(8.0), resB[0]); - - //INDArray resA = executionerA.executeGraph(sameDiff)[0]; - - //assertEquals(resA, resB); - } - - - @Test - @Ignore - public void testSums1() { - SameDiff sameDiff = SameDiff.create(); - INDArray ones = Nd4j.ones(4); - SDVariable sdVariable = sameDiff.var("ones",ones); - SDVariable result = sdVariable.add(1.0); - SDVariable total = sameDiff.sum(result,Integer.MAX_VALUE); - - val executioner = new NativeGraphExecutioner(); - - INDArray[] res = executioner.executeGraph(sameDiff); - assertEquals(8.0, res[0].getDouble(0), 1e-5); - } -} \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ActivationGradChecks.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ActivationGradChecks.java index a02c03075..93cc41304 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ActivationGradChecks.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ActivationGradChecks.java @@ -43,7 +43,7 @@ public class ActivationGradChecks extends BaseOpValidation { } @Test - public void testActivationGradientCheck1(){ + public void testActivationGradientCheck1() { Nd4j.getRandom().setSeed(12345); SameDiff sd = SameDiff.create(); SDVariable in = sd.var("x", Nd4j.rand(DataType.DOUBLE, 3, 4)); @@ -61,7 +61,7 @@ public class ActivationGradChecks extends BaseOpValidation { } @Test - public void testActivationGradientCheck2(){ + public void testActivationGradientCheck2() { Nd4j.getRandom().setSeed(12345); SameDiff sd = SameDiff.create(); SDVariable x = sd.placeHolder("x", DataType.DOUBLE, 3, 4); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java index 291997307..3e0692154 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java @@ -468,6 +468,11 @@ public class ReductionOpValidation extends BaseOpValidation { assertEquals("Failed: " + failed, 0, failed.size()); } + @Override + public long getTimeoutMilliseconds() { + return Long.MAX_VALUE; + } + @Test public void testReductionGradients2() { //Test reductions: NON-final function diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java index 51bfc2345..0bf0a151e 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ShapeOpValidation.java @@ -567,7 +567,12 @@ public class ShapeOpValidation extends BaseOpValidation { assertEquals(failed.toString(), 0, failed.size()); } - @Test + @Override + public long getTimeoutMilliseconds() { + return Long.MAX_VALUE; + } + + @Test() public void testStack() { Nd4j.getRandom().setSeed(12345); @@ -606,22 +611,22 @@ public class ShapeOpValidation extends BaseOpValidation { } INDArray expStack = null; - if(Arrays.equals(new long[]{3,4}, shape)){ + if(Arrays.equals(new long[]{3,4}, shape)) { if(axis == 0){ INDArray out = Nd4j.create(numInputs, 3, 4); - for( int i=0; i ph = new HashMap<>(); ph.put("in", i); - for( int x=0; x<10; x++ ) { + for( int x = 0; x < 10; x++) { sd.outputSingle(ph, "predictions"); } @@ -94,8 +94,8 @@ public class ProfilingListenerTest extends BaseNd4jTest { //Should be 2 begins and 2 ends for each entry //5 warmup iterations, 5 profile iterations, x2 for both the op name and the op "instance" name - String[] opNames = {"mmul", "add", "softmax"}; - for(String s : opNames){ + String[] opNames = {"matmul", "add", "softmax"}; + for(String s : opNames) { assertEquals(s, 10, StringUtils.countMatches(content, s)); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/resources/logback.xml b/nd4j/nd4j-backends/nd4j-tests/src/test/resources/logback.xml index eab131e33..e3d6a28cf 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/resources/logback.xml +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/resources/logback.xml @@ -39,8 +39,8 @@ - - + + diff --git a/nd4j/nd4j-backends/nd4j-tests/variables-added-old.txt b/nd4j/nd4j-backends/nd4j-tests/variables-added-old.txt index c273a0be4..b6b9489be 100644 --- a/nd4j/nd4j-backends/nd4j-tests/variables-added-old.txt +++ b/nd4j/nd4j-backends/nd4j-tests/variables-added-old.txt @@ -1 +1,18 @@ -Sum,Sum +in_0/read,in_0/read +while/Enter,while/Enter +while/Enter_1,while/Enter_1 +while/Merge,while/Merge +while/Merge_1,while/Merge_1 +while/Less,while/Less +while/LoopCond,while/LoopCond +while/Switch,while/Switch +while/Switch:1,while/Switch +while/Switch_1,while/Switch_1 +while/Switch_1:1,while/Switch_1 +while/Identity,while/Identity +while/Exit,while/Exit +while/Identity_1,while/Identity_1 +while/Exit_1,while/Exit_1 +while/add,while/add +while/NextIteration_1,while/NextIteration_1 +while/NextIteration,while/NextIteration diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/definitions/TensorflowOpDeclarations.kt b/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/definitions/TensorflowOpDeclarations.kt index 17d2a1e0f..2af052fbb 100644 --- a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/definitions/TensorflowOpDeclarations.kt +++ b/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/definitions/TensorflowOpDeclarations.kt @@ -1654,7 +1654,8 @@ val randomCrop = mapTensorNamesWithOp(inputFrameworkOpName = "RandomCrop",opName attributeMappingRules = listOf(valueMapping(mutableMapOf("seed" to "seed"))) ,tensorflowOpRegistry = tensorflowOpRegistry) -val placeHolder = mapTensorNamesWithOp(inputFrameworkOpName = "Placeholder",opName = "placeholder",tensorNames = mutableMapOf() +val placeHolder = mapTensorNamesWithOp(inputFrameworkOpName = "Placeholder",opName = "placeholder",tensorNames = mutableMapOf(), + attributeMappingRules = listOf() ,tensorflowOpRegistry = tensorflowOpRegistry) val randomGamma = mapTensorNamesWithOp(inputFrameworkOpName = "RandomGamma",opName = "random_gamma",tensorNames = mutableMapOf("shape" to "shape","alpha" to "alpha"),