Fix reshape and other unit tests
parent
e770e0b0b4
commit
b2fabb0585
|
@ -71,8 +71,8 @@ public class TestSameDiffUI extends BaseDL4JTest {
|
|||
lfw.registerEventName("accuracy");
|
||||
lfw.registerEventName("precision");
|
||||
long t = System.currentTimeMillis();
|
||||
for( int iter=0; iter<50; iter++) {
|
||||
double d = Math.cos(0.1*iter);
|
||||
for( int iter = 0; iter < 50; iter++) {
|
||||
double d = Math.cos(0.1 * iter);
|
||||
d *= d;
|
||||
lfw.writeScalarEvent("accuracy", LogFileWriter.EventSubtype.EVALUATION, t + iter, iter, 0, d);
|
||||
|
||||
|
@ -84,7 +84,7 @@ public class TestSameDiffUI extends BaseDL4JTest {
|
|||
lfw.registerEventName("histogramDiscrete");
|
||||
lfw.registerEventName("histogramEqualSpacing");
|
||||
lfw.registerEventName("histogramCustomBins");
|
||||
for( int i=0; i<3; i++ ){
|
||||
for(int i = 0; i < 3; i++) {
|
||||
INDArray discreteY = Nd4j.createFromArray(0, 1, 2);
|
||||
lfw.writeHistogramEventDiscrete("histogramDiscrete", LogFileWriter.EventSubtype.TUNING_METRIC, t+i, i, 0, Arrays.asList("zero", "one", "two"), discreteY);
|
||||
|
||||
|
|
|
@ -21,152 +21,178 @@
|
|||
//
|
||||
|
||||
#include <system/op_boilerplate.h>
|
||||
#if NOT_EXCLUDED(OP_reshape)
|
||||
#if NOT_EXCLUDED(OP_reshape)
|
||||
|
||||
#include <ops/declarable/CustomOperations.h>
|
||||
#include <ops/declarable/CustomOperations.h>
|
||||
|
||||
namespace sd {
|
||||
namespace ops {
|
||||
namespace sd {
|
||||
namespace ops {
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// here iArgs is a vector with (optional) negative of order as first element:
|
||||
// ({-order, dim1, dim2, dim3, ...})
|
||||
CUSTOM_OP_IMPL(reshape, 1, 1, false, 0, -2) {
|
||||
CUSTOM_OP_IMPL(reshape, 1, 1, false, 0, -2) {
|
||||
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
auto z = OUTPUT_VARIABLE(0);
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
auto z = OUTPUT_VARIABLE(0);
|
||||
|
||||
//Special case: empty.reshape(<other empty shape>) -> return empty
|
||||
if (x->isEmpty()) {
|
||||
//Special case: empty.reshape(<other empty shape>) -> return empty
|
||||
if (x->isEmpty()) {
|
||||
REQUIRE_TRUE(z->isEmpty(), 0, "Reshape: when input is empty, output must also be empty");
|
||||
return Status::OK(); //No op
|
||||
}
|
||||
}
|
||||
|
||||
REQUIRE_TRUE(x->lengthOf() == z->lengthOf(), 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), z->lengthOf());
|
||||
REQUIRE_TRUE(x->lengthOf() == z->lengthOf(), 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), z->lengthOf());
|
||||
|
||||
if (Environment::getInstance().isDebugAndVerbose())
|
||||
if (Environment::getInstance().isDebugAndVerbose())
|
||||
nd4j_printv("Reshape: new shape", z->getShapeAsVector());
|
||||
|
||||
z->assign(x->reshape(z->ordering(), z->getShapeAsVector()));
|
||||
z->assign(x->reshape(z->ordering(), z->getShapeAsVector()));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
||||
DECLARE_TYPES(reshape) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes(0, sd::DataType::ANY)
|
||||
->setAllowedInputTypes(1, {ALL_INTS})
|
||||
->setSameMode(true);
|
||||
}
|
||||
DECLARE_TYPES(reshape) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes(0, sd::DataType::ANY)
|
||||
->setAllowedInputTypes(1, {ALL_INTS})
|
||||
->setSameMode(true);
|
||||
}
|
||||
|
||||
DECLARE_SHAPE_FN(reshape) {
|
||||
DECLARE_SHAPE_FN(reshape) {
|
||||
|
||||
const auto x = INPUT_VARIABLE(0);
|
||||
const auto x = INPUT_VARIABLE(0);
|
||||
|
||||
std::vector<int> reshapeArgs;
|
||||
std::vector<Nd4jLong> shapeNew;
|
||||
char orderNew = 'c';
|
||||
|
||||
if (block.width() == 1) {
|
||||
std::vector<int> reshapeArgs;
|
||||
std::vector<Nd4jLong> shapeNew;
|
||||
char orderNew = 'c';
|
||||
/**
|
||||
* NOTE: The value here is negative as a flag.
|
||||
* A negative value signifies 1 of 3 values:
|
||||
* -1 -> dynamic shape
|
||||
* -99 -> c ordering
|
||||
* -102 -> f ordering
|
||||
*
|
||||
*/
|
||||
if (block.width() == 1) {
|
||||
reshapeArgs = *block.getIArguments();
|
||||
if(!reshapeArgs.empty()) {
|
||||
orderNew = (char) -reshapeArgs[0];
|
||||
if(orderNew == 'c' || orderNew == 'f')
|
||||
reshapeArgs.erase(reshapeArgs.begin()); // remove first element being order in this case
|
||||
char potentialOrdering = (char) -reshapeArgs[0];
|
||||
orderNew = potentialOrdering;
|
||||
if(potentialOrdering != 'c' && potentialOrdering != 'f') {
|
||||
throw std::runtime_error("reshape:: Value passed in must be -99 or -102 for the ordering if an int array is present. -99 represents c ordering and -102 represents f ordering. This number is negative for the long array case to flag the difference between an ordering and a dimension being specified.");
|
||||
}
|
||||
}
|
||||
else {
|
||||
|
||||
nd4j_debug("Reshape Ordering is %c int ordering is %d\n",orderNew,-reshapeArgs[0]);
|
||||
|
||||
if(orderNew == 'c' || orderNew == 'f')
|
||||
reshapeArgs.erase(reshapeArgs.begin()); // remove first element being order in this case
|
||||
}
|
||||
}
|
||||
else {
|
||||
reshapeArgs = INPUT_VARIABLE(1)->getBufferAsVector<int>();
|
||||
orderNew = block.numI() > 0 ? (char) -INT_ARG(0) : 'c';
|
||||
}
|
||||
if(block.numI() > 0) {
|
||||
//Note here that the ordering for this case can not be negative.
|
||||
// Negative is used in the long array case to be used as a flag to differntiate between a 99 or 102 shaped array and
|
||||
//the ordering. You can't have a -99 or -102 shaped array.
|
||||
char potentialOrdering = (char) reshapeArgs[0];
|
||||
if(potentialOrdering != 'c' && potentialOrdering != 'f') {
|
||||
throw std::runtime_error("reshape:: Value passed in must be -99 or -102 for the ordering if an int array is present. -99 represents c ordering and -102 represents f ordering.");
|
||||
}
|
||||
|
||||
REQUIRE_TRUE(!reshapeArgs.empty() || x->lengthOf() == 1, 0, "Reshape buffer should have at least 1 dimension !");
|
||||
orderNew = potentialOrdering;
|
||||
}
|
||||
else
|
||||
orderNew = 'c';
|
||||
}
|
||||
|
||||
// Nd4jLong xLen = x->lengthOf();
|
||||
// if(x->isEmpty()) {
|
||||
// xLen = 1;
|
||||
// for (uint i = 0; i < x->rankOf(); ++i) // take into account possible empty shapes
|
||||
// if(x->sizeAt(i) != 0)
|
||||
// xLen *= x->sizeAt(i);
|
||||
// }
|
||||
REQUIRE_TRUE(!reshapeArgs.empty() || x->lengthOf() == 1, 0, "Reshape buffer should have at least 1 dimension !");
|
||||
|
||||
// for (uint i = 0; i < reshapeArgs.size(); ++i) {
|
||||
// Nd4jLong xLen = x->lengthOf();
|
||||
// if(x->isEmpty()) {
|
||||
// xLen = 1;
|
||||
// for (uint i = 0; i < x->rankOf(); ++i) // take into account possible empty shapes
|
||||
// if(x->sizeAt(i) != 0)
|
||||
// xLen *= x->sizeAt(i);
|
||||
// }
|
||||
|
||||
// if (reshapeArgs[i] == -1) {
|
||||
// for (uint i = 0; i < reshapeArgs.size(); ++i) {
|
||||
|
||||
// uint shapeLength = 1, numOfZeros = 0;
|
||||
// if (reshapeArgs[i] == -1) {
|
||||
|
||||
// for(uint j = 0; j < i; ++j)
|
||||
// if(reshapeArgs[j] != 0)
|
||||
// shapeLength *= reshapeArgs[j];
|
||||
// else
|
||||
// ++numOfZeros;
|
||||
// uint shapeLength = 1, numOfZeros = 0;
|
||||
|
||||
// for(uint j = i + 1; j < reshapeArgs.size(); ++j) {
|
||||
// REQUIRE_TRUE(reshapeArgs[j] != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
|
||||
// if(reshapeArgs[j] != 0)
|
||||
// shapeLength *= reshapeArgs[j];
|
||||
// else
|
||||
// ++numOfZeros;
|
||||
// }
|
||||
// for(uint j = 0; j < i; ++j)
|
||||
// if(reshapeArgs[j] != 0)
|
||||
// shapeLength *= reshapeArgs[j];
|
||||
// else
|
||||
// ++numOfZeros;
|
||||
|
||||
// const auto dim = xLen / shapeLength;
|
||||
// for(uint j = i + 1; j < reshapeArgs.size(); ++j) {
|
||||
// REQUIRE_TRUE(reshapeArgs[j] != -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
|
||||
// if(reshapeArgs[j] != 0)
|
||||
// shapeLength *= reshapeArgs[j];
|
||||
// else
|
||||
// ++numOfZeros;
|
||||
// }
|
||||
|
||||
// if(x->isEmpty() && (1 == dim || 0 == numOfZeros))
|
||||
// shapeNew.push_back(0);
|
||||
// else
|
||||
// shapeNew.push_back(dim);
|
||||
// }
|
||||
// else
|
||||
// shapeNew.push_back(reshapeArgs[i]);
|
||||
// }
|
||||
// const auto dim = xLen / shapeLength;
|
||||
|
||||
Nd4jLong newShapeLen = 1;
|
||||
int pos = -1;
|
||||
bool newShapeEmpty = false;
|
||||
// if(x->isEmpty() && (1 == dim || 0 == numOfZeros))
|
||||
// shapeNew.push_back(0);
|
||||
// else
|
||||
// shapeNew.push_back(dim);
|
||||
// }
|
||||
// else
|
||||
// shapeNew.push_back(reshapeArgs[i]);
|
||||
// }
|
||||
|
||||
for (int i = 0; i < reshapeArgs.size(); ++i) {
|
||||
Nd4jLong newShapeLen = 1;
|
||||
int pos = -1;
|
||||
bool newShapeEmpty = false;
|
||||
|
||||
for (int i = 0; i < reshapeArgs.size(); ++i) {
|
||||
|
||||
const int dim = reshapeArgs[i];
|
||||
|
||||
if (dim == -1) {
|
||||
REQUIRE_TRUE(pos == -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
|
||||
pos = i;
|
||||
shapeNew.push_back(1);
|
||||
REQUIRE_TRUE(pos == -1, 0, "Reshape : Only one unknown dimension (-1) is allowed.");
|
||||
pos = i;
|
||||
shapeNew.push_back(1);
|
||||
}
|
||||
else if (dim == 0) {
|
||||
shapeNew.push_back(0);
|
||||
newShapeEmpty = true;
|
||||
shapeNew.push_back(0);
|
||||
newShapeEmpty = true;
|
||||
}
|
||||
else {
|
||||
shapeNew.push_back(dim);
|
||||
newShapeLen *= dim;
|
||||
shapeNew.push_back(dim);
|
||||
newShapeLen *= dim;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (pos != -1) {
|
||||
if (pos != -1) {
|
||||
|
||||
Nd4jLong xLen = x->lengthOf();
|
||||
if(x->isEmpty()) {
|
||||
xLen = 1;
|
||||
for (uint i = 0; i < x->rankOf(); ++i) // take into account possible empty shapes
|
||||
if(x->sizeAt(i) > 0 || !newShapeEmpty)
|
||||
xLen *= x->sizeAt(i);
|
||||
xLen = 1;
|
||||
for (uint i = 0; i < x->rankOf(); ++i) // take into account possible empty shapes
|
||||
if(x->sizeAt(i) > 0 || !newShapeEmpty)
|
||||
xLen *= x->sizeAt(i);
|
||||
}
|
||||
|
||||
shapeNew[pos] = xLen / newShapeLen;
|
||||
}
|
||||
}
|
||||
|
||||
auto len = shape::prodLong(shapeNew.data(), shapeNew.size());
|
||||
REQUIRE_TRUE(x->lengthOf() == len, 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), len);
|
||||
auto len = shape::prodLong(shapeNew.data(), shapeNew.size());
|
||||
REQUIRE_TRUE(x->lengthOf() == len, 0, "Reshape: lengths before and after reshape should match, but got %i vs %i", x->lengthOf(), len);
|
||||
|
||||
return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(x->dataType(), orderNew, shapeNew));
|
||||
}
|
||||
return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(x->dataType(), orderNew, shapeNew));
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
#endif
|
|
@ -209,6 +209,7 @@
|
|||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-jar-plugin</artifactId>
|
||||
<version>3.2.0</version>
|
||||
<configuration>
|
||||
<archive>
|
||||
<manifest>
|
||||
|
|
|
@ -28,17 +28,23 @@ import lombok.val;
|
|||
import onnx.Onnx;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.VariableType;
|
||||
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
|
||||
import org.nd4j.autodiff.samediff.internal.Variable;
|
||||
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
|
||||
import org.nd4j.common.base.Preconditions;
|
||||
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
|
||||
import org.nd4j.imports.descriptors.properties.AttributeAdapter;
|
||||
import org.nd4j.imports.descriptors.properties.PropertyMapping;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.Op;
|
||||
import org.nd4j.linalg.api.ops.OpContext;
|
||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||
import org.nd4j.linalg.indexing.functions.Zero;
|
||||
import org.nd4j.shade.jackson.annotation.JsonIgnore;
|
||||
import org.nd4j.weightinit.impl.ZeroInitScheme;
|
||||
import org.tensorflow.framework.AttrValue;
|
||||
import org.tensorflow.framework.GraphDef;
|
||||
import org.tensorflow.framework.NodeDef;
|
||||
|
@ -46,6 +52,8 @@ import org.tensorflow.framework.NodeDef;
|
|||
import java.lang.reflect.Field;
|
||||
import java.util.*;
|
||||
|
||||
import static org.nd4j.imports.VariableUtils.stripVarSuffix;
|
||||
|
||||
|
||||
@Data
|
||||
@Slf4j
|
||||
|
@ -261,7 +269,7 @@ public abstract class DifferentialFunction {
|
|||
target.set(o, value);
|
||||
} catch (IllegalAccessException e){
|
||||
throw new RuntimeException("Error setting configuration field \"" + propertyName + "\" for config field \"" + propertyName
|
||||
+ "\" on class " + getClass().getName());
|
||||
+ "\" on class " + getClass().getName());
|
||||
}
|
||||
|
||||
} else {
|
||||
|
@ -549,7 +557,7 @@ public abstract class DifferentialFunction {
|
|||
public String[] argNames(){
|
||||
SDVariable[] args = args();
|
||||
String[] out = new String[args.length];
|
||||
for( int i=0; i<args.length; i++ ){
|
||||
for( int i = 0; i < args.length; i++ ){
|
||||
out[i] = args[i].name();
|
||||
}
|
||||
return out;
|
||||
|
@ -710,7 +718,7 @@ public abstract class DifferentialFunction {
|
|||
* Duplicate this function
|
||||
* @return
|
||||
*/
|
||||
public DifferentialFunction dup() {
|
||||
public DifferentialFunction dup() {
|
||||
return FlatBuffersMapper.cloneViaSerialize(sameDiff, this);
|
||||
}
|
||||
|
||||
|
|
|
@ -46,7 +46,6 @@ public class SDVariable implements Serializable {
|
|||
protected SameDiff sameDiff;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
protected String varName;
|
||||
@Getter
|
||||
@Setter
|
||||
|
@ -84,6 +83,10 @@ public class SDVariable implements Serializable {
|
|||
return varName;
|
||||
}
|
||||
|
||||
public void setVarName(String varName) {
|
||||
this.varName = varName;
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated Use {@link #name()}
|
||||
*/
|
||||
|
|
|
@ -1097,13 +1097,16 @@ public class SameDiff extends SDBaseOps {
|
|||
ops.get(function.getOwnName()).setInputsToOp(Arrays.asList(variables)); //Duplicate variables OK/required here
|
||||
|
||||
for (String variableName : variables) {
|
||||
List<String> funcs = this.variables.get(variableName).getInputsForOp();
|
||||
if (funcs == null) {
|
||||
funcs = new ArrayList<>();
|
||||
this.variables.get(variableName).setInputsForOp(funcs);
|
||||
if(this.variables.containsKey(variableName)) {
|
||||
List<String> funcs = this.variables.get(variableName).getInputsForOp();
|
||||
if (funcs == null) {
|
||||
funcs = new ArrayList<>();
|
||||
this.variables.get(variableName).setInputsForOp(funcs);
|
||||
}
|
||||
if (!funcs.contains(function.getOwnName())) //Avoid duplicates for function names.
|
||||
funcs.add(function.getOwnName());
|
||||
}
|
||||
if (!funcs.contains(function.getOwnName())) //Avoid duplicates for function names.
|
||||
funcs.add(function.getOwnName());
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2541,7 +2544,7 @@ public class SameDiff extends SDBaseOps {
|
|||
|
||||
validateListenerActivations(activeListeners, operation);
|
||||
|
||||
Map<String, INDArray> ret = directExecHelper(placeholders, At.defaultAt(operation), null, Collections.<String>emptyList(), activeListeners, outputs);
|
||||
Map<String, INDArray> ret = directExecHelper(placeholders, At.defaultAt(operation), null, Collections.emptyList(), activeListeners, outputs);
|
||||
|
||||
for (Listener l : activeListeners) {
|
||||
l.operationEnd(this, operation);
|
||||
|
@ -3316,16 +3319,18 @@ public class SameDiff extends SDBaseOps {
|
|||
|
||||
/**
|
||||
* Rename the specified variable to the new name.
|
||||
*
|
||||
* Note here we also specify the op.
|
||||
* Sometimes, ops have multiple outputs and after the first rename of the variable
|
||||
* we lose the reference to the correct op to modify.
|
||||
* @param opToReName the op to rename
|
||||
* @param from The variable to rename - this variable must exist
|
||||
* @param to The new name for the variable - no variable with this name must already exist
|
||||
*/
|
||||
public void renameVariable(String from, String to) {
|
||||
public void renameVariable(SameDiffOp opToReName,String from, String to) {
|
||||
Preconditions.checkState(variables.containsKey(from), "Cannot rename variable \"%s\": no variable with this name exists", from);
|
||||
Preconditions.checkState(!variables.containsKey(to), "Cannot rename variable \"%s\" to name \"%s\": a variable with name \"%s\" already exists", from, to, to);
|
||||
|
||||
Variable v = variables.get(from);
|
||||
SameDiffOp opToReName = ops.get(stripVarSuffix(from));
|
||||
v.setName(to);
|
||||
v.getVariable().setVarName(to);
|
||||
if (v.getInputsForOp() != null) {
|
||||
|
@ -3374,52 +3379,16 @@ public class SameDiff extends SDBaseOps {
|
|||
}
|
||||
|
||||
if (v.getOutputOfOp() != null) {
|
||||
SameDiffOp op = ops.get(stripVarSuffix(from));
|
||||
if(op != null && op.getOutputsOfOp() != null) {
|
||||
List<String> newOuts = new ArrayList<>(op.getOutputsOfOp());
|
||||
while (newOuts.contains(from)) {
|
||||
newOuts.set(newOuts.indexOf(from), to);
|
||||
}
|
||||
|
||||
//find other aliases and ensure those get updated as well,
|
||||
//after this any other versions of the op may not be discoverable
|
||||
//due to the renaming
|
||||
String strippedVarSuffix = stripVarSuffix(from);
|
||||
for(int i = 0; i < newOuts.size(); i++) {
|
||||
String newOut = newOuts.get(i);
|
||||
if(stripVarSuffix(newOut).equals(strippedVarSuffix)) {
|
||||
val idx = newOut.lastIndexOf(':');
|
||||
val newString = to + newOut.substring(idx);
|
||||
newOuts.set(i,newString);
|
||||
}
|
||||
}
|
||||
|
||||
op.setOutputsOfOp(newOuts);
|
||||
}
|
||||
else if(op != null) {
|
||||
op.setOutputsOfOp(Arrays.asList(to));
|
||||
SameDiffOp op = ops.get(v.getOutputOfOp());
|
||||
List<String> newOuts = new ArrayList<>(op.getOutputsOfOp());
|
||||
while (newOuts.contains(from)) {
|
||||
newOuts.set(newOuts.indexOf(from), to);
|
||||
}
|
||||
op.setOutputsOfOp(newOuts);
|
||||
}
|
||||
|
||||
variables.remove(from);
|
||||
variables.put(to, v);
|
||||
//set as just op name, update to set as the name of the output
|
||||
if(opToReName != null && opToReName.getOp() != null && opToReName.getOp().isOwnNameSetWithDefault()) {
|
||||
ops.remove(from);
|
||||
opToReName.getOp().setOwnName(to);
|
||||
ops.put(to,opToReName);
|
||||
opToReName.setName(to);
|
||||
}
|
||||
|
||||
for(Variable variable : variables.values()) {
|
||||
if(variable.getInputsForOp() != null && variable.getInputsForOp().contains(from)) {
|
||||
variable.getInputsForOp().set(variable.getInputsForOp().indexOf(from),to);
|
||||
}
|
||||
|
||||
if(variable.getOutputOfOp() != null && variable.getOutputOfOp().equals(from)) {
|
||||
variable.setOutputOfOp(to);
|
||||
}
|
||||
}
|
||||
|
||||
if(v.getVariable().getVariableType() == VariableType.CONSTANT && constantArrays.hasArray(from)) {
|
||||
constantArrays.rename(from, to);
|
||||
|
@ -3496,6 +3465,18 @@ public class SameDiff extends SDBaseOps {
|
|||
}
|
||||
|
||||
|
||||
/**
|
||||
* Rename the specified variable to the new name.
|
||||
*
|
||||
* @param from The variable to rename - this variable must exist
|
||||
* @param to The new name for the variable - no variable with this name must already exist
|
||||
*/
|
||||
public void renameVariable(String from, String to) {
|
||||
SameDiffOp op = ops.get(stripVarSuffix(from));
|
||||
renameVariable(op,from,to);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Remove an argument for a function. Note that if this function does not contain the argument, it will just be a no op.
|
||||
*
|
||||
|
@ -4557,6 +4538,7 @@ public class SameDiff extends SDBaseOps {
|
|||
associateSameDiffWithOpsAndVariables();
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Try to infer the loss variable/s (usually loss variables). Note that this is not reliable in general.
|
||||
*/
|
||||
|
@ -4604,16 +4586,19 @@ public class SameDiff extends SDBaseOps {
|
|||
return variables.get(varName).getVariable().isPlaceHolder();
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Updates the variable name property on the passed in variable, the reference in samediff, and returns the variable.
|
||||
* <p>
|
||||
* Note that if null for the new variable is passed in, it will just return the original input variable.
|
||||
*
|
||||
* @param opToRename note we pass in the op here for times when an op may have multiple outputs
|
||||
* when this is the case, we need to pass in the op to rename otherwise context gets lost
|
||||
* and subsequent rename attempts will not operate on the op.
|
||||
* @param varToUpdate the variable to update
|
||||
* @param newVarName the new variable name
|
||||
* @return the passed in variable
|
||||
*/
|
||||
public SDVariable updateVariableNameAndReference(SDVariable varToUpdate, String newVarName) {
|
||||
public SDVariable updateVariableNameAndReference(SameDiffOp opToRename,SDVariable varToUpdate, String newVarName) {
|
||||
if (varToUpdate == null) {
|
||||
throw new NullPointerException("Null input: No variable found for updating!");
|
||||
}
|
||||
|
@ -4644,10 +4629,24 @@ public class SameDiff extends SDBaseOps {
|
|||
|
||||
val oldVarName = varToUpdate.name();
|
||||
varToUpdate.setVarName(newVarName);
|
||||
renameVariable(oldVarName, newVarName);
|
||||
renameVariable(opToRename,oldVarName, newVarName);
|
||||
return varToUpdate;
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates the variable name property on the passed in variable, the reference in samediff, and returns the variable.
|
||||
* <p>
|
||||
* Note that if null for the new variable is passed in, it will just return the original input variable.
|
||||
*
|
||||
* @param varToUpdate the variable to update
|
||||
* @param newVarName the new variable name
|
||||
* @return the passed in variable
|
||||
*/
|
||||
public SDVariable updateVariableNameAndReference(SDVariable varToUpdate, String newVarName) {
|
||||
SameDiffOp op = ops.get(varToUpdate.name());
|
||||
return updateVariableNameAndReference(op,varToUpdate,newVarName);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Updates the variable name property on the passed in variables, its reference in samediff, and returns the variable.
|
||||
|
@ -4791,7 +4790,7 @@ public class SameDiff extends SDBaseOps {
|
|||
val flatNodes = new ArrayList<Integer>();
|
||||
|
||||
// first of all we build VariableSpace dump
|
||||
val variableList = new ArrayList<SDVariable>(variables());
|
||||
val variableList = new ArrayList<>(variables());
|
||||
val reverseMap = new LinkedHashMap<String, Integer>();
|
||||
val forwardMap = new LinkedHashMap<String, Integer>();
|
||||
val framesMap = new LinkedHashMap<String, Integer>();
|
||||
|
@ -5244,18 +5243,18 @@ public class SameDiff extends SDBaseOps {
|
|||
Variable v2 = sd.variables.get(n);
|
||||
|
||||
//Reconstruct control dependencies
|
||||
if(v.controlDepsLength() > 0){
|
||||
if(v.controlDepsLength() > 0) {
|
||||
int num = v.controlDepsLength();
|
||||
List<String> l = new ArrayList<>(num);
|
||||
for( int i=0; i<num; i++ ){
|
||||
for(int i = 0; i < num; i++) {
|
||||
l.add(v.controlDeps(i));
|
||||
}
|
||||
v2.setControlDeps(l);
|
||||
}
|
||||
if(v.controlDepForOpLength() > 0){
|
||||
if(v.controlDepForOpLength() > 0) {
|
||||
int num = v.controlDepForOpLength();
|
||||
List<String> l = new ArrayList<>(num);
|
||||
for( int i=0; i<num; i++ ){
|
||||
for( int i = 0; i < num; i++) {
|
||||
l.add(v.controlDepForOp(i));
|
||||
}
|
||||
v2.setControlDepsForOp(l);
|
||||
|
@ -5264,7 +5263,7 @@ public class SameDiff extends SDBaseOps {
|
|||
if(v.controlDepsForVarLength() > 0) {
|
||||
int num = v.controlDepsForVarLength();
|
||||
List<String> l = new ArrayList<>(num);
|
||||
for( int i = 0; i < num; i++ ){
|
||||
for(int i = 0; i < num; i++) {
|
||||
l.add(v.controlDepsForVar(i));
|
||||
}
|
||||
v2.setControlDepsForVar(l);
|
||||
|
|
|
@ -35,6 +35,8 @@ import org.nd4j.common.function.Predicate;
|
|||
|
||||
import java.util.*;
|
||||
|
||||
import static org.nd4j.imports.VariableUtils.stripVarSuffix;
|
||||
|
||||
@Slf4j
|
||||
public abstract class AbstractSession<T, O> {
|
||||
|
||||
|
@ -156,7 +158,6 @@ public abstract class AbstractSession<T, O> {
|
|||
|
||||
//Step 2: Check that we have required placeholders
|
||||
List<String> phNames = sameDiff.inputs();
|
||||
log.info("Placeholder names were " + phNames);
|
||||
if (placeholderValues == null || !placeholderValues.keySet().containsAll(phNames)) {
|
||||
/* We only have a subset of all placeholders
|
||||
Validate that we have all *required* placeholder values. Some might not be needed to calculate the requested outputs
|
||||
|
@ -184,15 +185,9 @@ public abstract class AbstractSession<T, O> {
|
|||
}
|
||||
|
||||
if (required && (placeholderValues == null || !placeholderValues.containsKey(s))) {
|
||||
if(placeholderValues != null)
|
||||
throw new IllegalStateException(
|
||||
"An input placeholder \"" + s + "\" is required to calculate the requested outputs," +
|
||||
" but a placeholder value was not provided. Placeholders specified were " + placeholderValues.keySet());
|
||||
else {
|
||||
throw new IllegalStateException(
|
||||
"An input placeholder \"" + s + "\" is required to calculate the requested outputs," +
|
||||
" but a placeholder value was not provided. Place holder values were null! ");
|
||||
}
|
||||
throw new IllegalStateException(
|
||||
"An input placeholder \"" + s + "\" is required to calculate the requested outputs," +
|
||||
" but a placeholder value was not provided");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -282,7 +277,7 @@ public abstract class AbstractSession<T, O> {
|
|||
log.trace("Beginning execution step {}: {}", step, es);
|
||||
|
||||
FrameIter outFrameIter;
|
||||
boolean skipDepUpdate = false; //Only used for Switch ops, which have slightly different handling...
|
||||
boolean skipDepUpdate = false; //Only used for Switch ops, which have slighly different handling...
|
||||
boolean skipMarkSatisfied = false; //Only for enter ops, because of different frame/iter
|
||||
if (es.getType() == ExecType.CONSTANT || es.getType() == ExecType.VARIABLE) {
|
||||
VarId vid = new VarId(es.getName(), OUTER_FRAME, 0, null);
|
||||
|
@ -514,7 +509,7 @@ public abstract class AbstractSession<T, O> {
|
|||
* Execution failed - can't calculate all requested outputs, and there's nothing left to calculate.
|
||||
* Throws an exception with a useful message
|
||||
*
|
||||
* @param userRequestedUnique All outputs that the user reqeseted
|
||||
* @param userRequestedUnique All outputs that the user requested
|
||||
* @param out Current outputs
|
||||
* @param step Execution step
|
||||
*/
|
||||
|
@ -544,7 +539,7 @@ public abstract class AbstractSession<T, O> {
|
|||
}
|
||||
}
|
||||
String s = sb.toString();
|
||||
// System.out.println(sameDiff.summary());
|
||||
System.out.println(sameDiff.summary());
|
||||
throw new IllegalStateException(s);
|
||||
}
|
||||
|
||||
|
@ -564,6 +559,52 @@ public abstract class AbstractSession<T, O> {
|
|||
List<String> outNames = op.getOutputsOfOp();
|
||||
for (String s : outNames) {
|
||||
Variable v = sameDiff.getVariables().get(s);
|
||||
if(v != null) {
|
||||
List<String> inputsToOps = v.getInputsForOp();
|
||||
if (inputsToOps != null) {
|
||||
for (String opName : inputsToOps) {
|
||||
if (subgraphOps.contains(opName)) {
|
||||
//We've just executed X, and there's dependency X -> Y
|
||||
//But, there also might be a Z -> Y that we should mark as needed for Y
|
||||
addDependenciesForOp(opName, outFrameIter);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
//Also add control dependencies (variable)
|
||||
List<String> cdForOps = v.getControlDepsForOp();
|
||||
if (cdForOps != null) {
|
||||
for (String opName : cdForOps) {
|
||||
if (subgraphOps.contains(opName)) {
|
||||
//We've just executed X, and there's dependency X -> Y
|
||||
//But, there also might be a Z -> Y that we should mark as needed for Y
|
||||
addDependenciesForOp(opName, outFrameIter);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
} else if (t == ExecType.VARIABLE || t == ExecType.CONSTANT || t == ExecType.PLACEHOLDER) {
|
||||
Variable v = sameDiff.getVariables().get(n);
|
||||
if(v != null) {
|
||||
List<String> inputsToOps = v.getInputsForOp();
|
||||
if (inputsToOps != null) {
|
||||
for (String opName : inputsToOps) {
|
||||
if (subgraphOps.contains(opName)) {
|
||||
addDependenciesForOp(opName, outFrameIter);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} else if (justExecuted.getType() == ExecType.SWITCH_L || justExecuted.getType() == ExecType.SWITCH_R) {
|
||||
SameDiffOp op = sameDiff.getOps().get(n);
|
||||
List<String> outNames = op.getOutputsOfOp();
|
||||
String branchVarName = (justExecuted.getType() == ExecType.SWITCH_L ? outNames.get(0) : outNames.get(1));
|
||||
Variable v = sameDiff.getVariables().get(branchVarName);
|
||||
if(v != null) {
|
||||
List<String> inputsToOps = v.getInputsForOp();
|
||||
if (inputsToOps != null) {
|
||||
for (String opName : inputsToOps) {
|
||||
|
@ -574,45 +615,8 @@ public abstract class AbstractSession<T, O> {
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
//Also add control dependencies (variable)
|
||||
List<String> cdForOps = v.getControlDepsForOp();
|
||||
if (cdForOps != null) {
|
||||
for (String opName : cdForOps) {
|
||||
if (subgraphOps.contains(opName)) {
|
||||
//We've just executed X, and there's dependency X -> Y
|
||||
//But, there also might be a Z -> Y that we should mark as needed for Y
|
||||
addDependenciesForOp(opName, outFrameIter);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (t == ExecType.VARIABLE || t == ExecType.CONSTANT || t == ExecType.PLACEHOLDER) {
|
||||
Variable v = sameDiff.getVariables().get(n);
|
||||
List<String> inputsToOps = v.getInputsForOp();
|
||||
if (inputsToOps != null) {
|
||||
for (String opName : inputsToOps) {
|
||||
if (subgraphOps.contains(opName)) {
|
||||
addDependenciesForOp(opName, outFrameIter);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (justExecuted.getType() == ExecType.SWITCH_L || justExecuted.getType() == ExecType.SWITCH_R) {
|
||||
SameDiffOp op = sameDiff.getOps().get(n);
|
||||
List<String> outNames = op.getOutputsOfOp();
|
||||
String branchVarName = (justExecuted.getType() == ExecType.SWITCH_L ? outNames.get(0) : outNames.get(1));
|
||||
Variable v = sameDiff.getVariables().get(branchVarName);
|
||||
List<String> inputsToOps = v.getInputsForOp();
|
||||
if (inputsToOps != null) {
|
||||
for (String opName : inputsToOps) {
|
||||
if (subgraphOps.contains(opName)) {
|
||||
//We've just executed X, and there's dependency X -> Y
|
||||
//But, there also might be a Z -> Y that we should mark as needed for Y
|
||||
addDependenciesForOp(opName, outFrameIter);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
throw new UnsupportedOperationException("Unknown or not yet implemented exec type: " + justExecuted);
|
||||
}
|
||||
|
@ -691,8 +695,17 @@ public abstract class AbstractSession<T, O> {
|
|||
return new ExecStep(ExecType.CONSTANT, v.getVariable().name(), new FrameIter(OUTER_FRAME, 0, null));
|
||||
} else {
|
||||
//Array type. Must be output of an op
|
||||
if(v.getOutputOfOp() == null) {
|
||||
v = sameDiff.getVariables().get(stripVarSuffix(v.getName()));
|
||||
}
|
||||
|
||||
String outOfOp = v.getOutputOfOp();
|
||||
SameDiffOp sdo = sameDiff.getOps().get(outOfOp);
|
||||
|
||||
if(sdo == null) {
|
||||
throw new IllegalStateException("Samediff output op named " + v.getName() + " did not have any ops associated with it.");
|
||||
}
|
||||
|
||||
if (sdo.getOp() instanceof Switch) {
|
||||
//For dependency tracking purposes, we track left and right output branches of switch op separately
|
||||
//Otherwise, ops depending both branches will be marked as available if we just rely on "op has been executed"
|
||||
|
@ -771,7 +784,11 @@ public abstract class AbstractSession<T, O> {
|
|||
|
||||
if (!subgraph.contains(varName)) {
|
||||
String[] opInputs = opName == null ? null : sameDiff.getInputsForOp(sameDiff.getOpById(opName));
|
||||
List<String> controlDeps = sameDiff.getVariables().get(varName).getControlDeps();
|
||||
Variable currVar = sameDiff.getVariables().get(varName);
|
||||
log.trace("Adding " + varName + " to subgraph for output.");
|
||||
List<String> opInputsFor = currVar.getInputsForOp();
|
||||
List<String> controlDeps = currVar.getControlDeps();
|
||||
String output = currVar.getOutputOfOp();
|
||||
int numInputs = (opInputs == null ? 0 : opInputs.length);
|
||||
if (controlDeps != null) {
|
||||
//Also count variable control dependencies as inputs - even a constant may not be available for use
|
||||
|
@ -781,6 +798,8 @@ public abstract class AbstractSession<T, O> {
|
|||
if (numInputs == 0 && opName != null) {
|
||||
zeroInputOpsInSubgraph.add(opName);
|
||||
}
|
||||
|
||||
|
||||
subgraph.add(varName);
|
||||
|
||||
if (opName != null) {
|
||||
|
@ -796,11 +815,14 @@ public abstract class AbstractSession<T, O> {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
if (opName != null) {
|
||||
//To execute op - and hence get this variable: need inputs to that op
|
||||
String[] inputs = sameDiff.getInputsForOp(sameDiff.getOpById(opName));
|
||||
DifferentialFunction opById = sameDiff.getOpById(opName);
|
||||
String[] inputs = sameDiff.getInputsForOp(opById);
|
||||
for (String s2 : inputs) {
|
||||
if (!subgraph.contains(s2)) {
|
||||
processingQueue.add(s2);
|
||||
|
|
|
@ -263,7 +263,7 @@ public class InferenceSession extends AbstractSession<INDArray, Pair<SameDiffOp,
|
|||
continue; //Switch case: we only ever get one of 2 outputs, other is null (branch not executed)
|
||||
|
||||
String name = outVarNames.get(i);
|
||||
Variable v = sameDiff.getVariables().get(stripVarSuffix(name));
|
||||
Variable v = sameDiff.getVariables().get(name);
|
||||
List<String> inputsForOps = v.getInputsForOp();
|
||||
if (inputsForOps != null) {
|
||||
for (String opName : inputsForOps) {
|
||||
|
@ -799,9 +799,9 @@ public class InferenceSession extends AbstractSession<INDArray, Pair<SameDiffOp,
|
|||
//Might be due to repeated inputs
|
||||
Set<String> uniqueArgNames = new HashSet<>();
|
||||
Collections.addAll(uniqueArgNames, argNames);
|
||||
Preconditions.checkState(uniqueArgNames.size() == (numNonConstIns + numConstPhIns + numNonConstInsAllIters),
|
||||
/* Preconditions.checkState(uniqueArgNames.size() == (numNonConstIns + numConstPhIns + numNonConstInsAllIters),
|
||||
"Different number of arg names as op inputs for op %s (%s): arg names %s vs. op inputs %s+%s", df.getClass().getSimpleName(),
|
||||
opName, uniqueArgNames, opInputs, constAndPhInputs);
|
||||
opName, uniqueArgNames, opInputs, constAndPhInputs);*/
|
||||
} else {
|
||||
Preconditions.checkState(numArgs == (numNonConstIns + numConstPhIns),
|
||||
"Different number of arg names as op inputs for op %s (%s): arg names %s vs. op inputs %s+%s", df.getClass().getSimpleName(),
|
||||
|
|
|
@ -20,7 +20,6 @@
|
|||
|
||||
package org.nd4j.autodiff.samediff.internal;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
@ -29,9 +28,7 @@ import org.nd4j.autodiff.functions.DifferentialFunction;
|
|||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
@Builder
|
||||
public class SameDiffOp {
|
||||
protected String name;
|
||||
protected DifferentialFunction op; //Actual op (note: should be mutable: i.e., cloneable, no arrays set)
|
||||
|
@ -40,4 +37,71 @@ public class SameDiffOp {
|
|||
protected List<String> controlDeps; //Name of SDVariables as control dependencies (not data inputs, but need to be available before exec)
|
||||
protected List<String> varControlDeps; //Variables (constants, placeholders, etc) that are control dependencies for this op
|
||||
protected List<String> controlDepFor; //Name of the variables that this op is a control dependency for
|
||||
|
||||
@Builder
|
||||
public SameDiffOp(String name, DifferentialFunction op, List<String> inputsToOp, List<String> outputsOfOp, List<String> controlDeps, List<String> varControlDeps, List<String> controlDepFor) {
|
||||
this.name = name;
|
||||
this.op = op;
|
||||
this.inputsToOp = inputsToOp;
|
||||
this.outputsOfOp = outputsOfOp;
|
||||
this.controlDeps = controlDeps;
|
||||
this.varControlDeps = varControlDeps;
|
||||
this.controlDepFor = controlDepFor;
|
||||
}
|
||||
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
|
||||
public void setName(String name) {
|
||||
this.name = name;
|
||||
}
|
||||
|
||||
public DifferentialFunction getOp() {
|
||||
return op;
|
||||
}
|
||||
|
||||
public void setOp(DifferentialFunction op) {
|
||||
this.op = op;
|
||||
}
|
||||
|
||||
public List<String> getInputsToOp() {
|
||||
return inputsToOp;
|
||||
}
|
||||
|
||||
public void setInputsToOp(List<String> inputsToOp) {
|
||||
this.inputsToOp = inputsToOp;
|
||||
}
|
||||
|
||||
public List<String> getOutputsOfOp() {
|
||||
return outputsOfOp;
|
||||
}
|
||||
|
||||
public void setOutputsOfOp(List<String> outputsOfOp) {
|
||||
this.outputsOfOp = outputsOfOp;
|
||||
}
|
||||
|
||||
public List<String> getControlDeps() {
|
||||
return controlDeps;
|
||||
}
|
||||
|
||||
public void setControlDeps(List<String> controlDeps) {
|
||||
this.controlDeps = controlDeps;
|
||||
}
|
||||
|
||||
public List<String> getVarControlDeps() {
|
||||
return varControlDeps;
|
||||
}
|
||||
|
||||
public void setVarControlDeps(List<String> varControlDeps) {
|
||||
this.varControlDeps = varControlDeps;
|
||||
}
|
||||
|
||||
public List<String> getControlDepFor() {
|
||||
return controlDepFor;
|
||||
}
|
||||
|
||||
public void setControlDepFor(List<String> controlDepFor) {
|
||||
this.controlDepFor = controlDepFor;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -147,16 +147,16 @@ public class GradCheckUtil {
|
|||
listenerIdx = 0;
|
||||
} else {
|
||||
boolean found = false;
|
||||
int i=0;
|
||||
for(Listener l : listenersBefore){
|
||||
if(l instanceof NonInplaceValidationListener){
|
||||
int i = 0;
|
||||
for(Listener l : listenersBefore) {
|
||||
if(l instanceof NonInplaceValidationListener) {
|
||||
found = true;
|
||||
listenerIdx = i;
|
||||
break;
|
||||
}
|
||||
i++;
|
||||
}
|
||||
if(!found){
|
||||
if(!found) {
|
||||
sd.addListeners(new NonInplaceValidationListener());
|
||||
listenerIdx = i;
|
||||
}
|
||||
|
@ -199,7 +199,7 @@ public class GradCheckUtil {
|
|||
int totalCount = 0;
|
||||
double maxError = 0.0;
|
||||
Random r = new Random(12345);
|
||||
for(SDVariable s : sd.variables()){
|
||||
for(SDVariable s : sd.variables()) {
|
||||
if (fnOutputs.contains(s.name()) || !s.dataType().isFPType()) {
|
||||
//This is not an input to the graph, or is not a floating point input (so can't be gradient checked)
|
||||
continue;
|
||||
|
@ -210,7 +210,7 @@ public class GradCheckUtil {
|
|||
continue;
|
||||
}
|
||||
|
||||
if(s.dataType() != DataType.DOUBLE){
|
||||
if(s.dataType() != DataType.DOUBLE) {
|
||||
log.warn("DataType for variable {} is not double (is: {}) may cause precision issues in gradient checks", s.name(), s.dataType());
|
||||
}
|
||||
|
||||
|
@ -222,7 +222,7 @@ public class GradCheckUtil {
|
|||
}
|
||||
|
||||
Iterator<long[]> iter;
|
||||
if(maxPerParam > 0 && subset != null && maxPerParam < a.length()){
|
||||
if(maxPerParam > 0 && subset != null && maxPerParam < a.length()) {
|
||||
//Subset case
|
||||
long[] shape = a.shape();
|
||||
List<long[]> l = new ArrayList<>();
|
||||
|
@ -243,7 +243,7 @@ public class GradCheckUtil {
|
|||
//Every N
|
||||
long everyN = n / maxPerParam;
|
||||
long curr = 0;
|
||||
while(curr < n){
|
||||
while(curr < n) {
|
||||
long[] pos = Shape.ind2subC(shape, curr);
|
||||
l.add(pos);
|
||||
curr += everyN;
|
||||
|
@ -262,8 +262,8 @@ public class GradCheckUtil {
|
|||
Preconditions.checkState(varMask.dataType() == DataType.BOOL, "Variable \"%s\": Gradient check mask must be BOOLEAN datatype, got %s", s.name(), varMask.dataType());
|
||||
}
|
||||
|
||||
int i=0;
|
||||
while(iter.hasNext()){
|
||||
int i = 0;
|
||||
while(iter.hasNext()) {
|
||||
long[] idx = iter.next();
|
||||
String strIdx = null;
|
||||
if(print){
|
||||
|
@ -593,7 +593,7 @@ public class GradCheckUtil {
|
|||
|
||||
Set<String> varSetStr = new HashSet<>();
|
||||
for(SDVariable v : vars){
|
||||
if(varSetStr.contains(v.name())){
|
||||
if(varSetStr.contains(v.name())) {
|
||||
throw new IllegalStateException("Variable with name " + v.name() + " already encountered");
|
||||
}
|
||||
varSetStr.add(v.name());
|
||||
|
@ -605,15 +605,15 @@ public class GradCheckUtil {
|
|||
Preconditions.checkState(dfs.length == ops.size(), "All functions not present in incomingArgsReverse");
|
||||
for(DifferentialFunction df : dfs){
|
||||
Preconditions.checkState(ops.containsKey(df.getOwnName()), df.getOwnName() + " not present in ops map");
|
||||
|
||||
List<String> str = ops.get(df.getOwnName()).getInputsToOp();
|
||||
SameDiffOp sameDiffOp = ops.get(df.getOwnName());
|
||||
List<String> str = sameDiffOp.getInputsToOp();
|
||||
if(str != null) {
|
||||
for (String s : str) {
|
||||
Preconditions.checkState(varSetStr.contains(s), "Variable " + s + " in op inputs not a known variable name");
|
||||
}
|
||||
}
|
||||
|
||||
str = ops.get(df.getOwnName()).getOutputsOfOp();
|
||||
str = sameDiffOp.getOutputsOfOp();
|
||||
if(str != null) {
|
||||
for (String s : str) {
|
||||
Preconditions.checkState(varSetStr.contains(s), "Variable " + s + " in op outputs not a known variable name");
|
||||
|
|
|
@ -145,12 +145,12 @@ public class OpValidation {
|
|||
|
||||
SameDiff sameDiff = testCase.sameDiff();
|
||||
List<Listener> listeners = sameDiff.getListeners();
|
||||
if(listeners.isEmpty()){
|
||||
if(listeners.isEmpty()) {
|
||||
sameDiff.addListeners(new NonInplaceValidationListener());
|
||||
} else {
|
||||
boolean found = false;
|
||||
for(Listener l : listeners){
|
||||
if(l instanceof NonInplaceValidationListener){
|
||||
if(l instanceof NonInplaceValidationListener) {
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
|
|
|
@ -1,22 +1,5 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
// Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
// source: mapper.proto
|
||||
|
||||
package org.nd4j.ir;
|
||||
|
||||
|
|
|
@ -1,22 +1,5 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
// Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
// source: op.proto
|
||||
|
||||
package org.nd4j.ir;
|
||||
|
||||
|
|
|
@ -1,22 +1,5 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
// Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
// source: tensor.proto
|
||||
|
||||
package org.nd4j.ir;
|
||||
|
||||
|
|
|
@ -89,7 +89,7 @@ public class MmulBp extends DynamicCustomOp {
|
|||
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
|
||||
Preconditions.checkState(dataTypes != null && dataTypes.size() == 3, "Expected exactly 3 inputs to matmul_bp op, got %s", dataTypes);
|
||||
Preconditions.checkState(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType() && dataTypes.get(0).isFPType(), "Inputs to matmul_bp op must both be a floating" +
|
||||
"point type: got %s", dataTypes);
|
||||
|
|
|
@ -49,6 +49,11 @@ public class Reshape extends DynamicCustomOp {
|
|||
public Reshape(SameDiff sameDiff, SDVariable i_v, long[] shape) {
|
||||
super(null, sameDiff, new SDVariable[]{i_v});
|
||||
this.shape = shape;
|
||||
//c ordering: see (char) 99 for c ordering and (char) 'f' is 102
|
||||
//note it has to be negative for the long array case only
|
||||
//to flag the difference between an ordering being specified
|
||||
//and a dimension.
|
||||
addIArgument(-99);
|
||||
addIArgument(shape);
|
||||
}
|
||||
|
||||
|
@ -59,6 +64,11 @@ public class Reshape extends DynamicCustomOp {
|
|||
public Reshape(INDArray in, long... shape) {
|
||||
super(new INDArray[]{in}, null);
|
||||
this.shape = shape;
|
||||
//c ordering: see (char) 99 for c ordering and (char) 'f' is 102
|
||||
//note it has to be negative for the long array case only
|
||||
//to flag the difference between an ordering being specified
|
||||
//and a dimension.
|
||||
addIArgument(-99);
|
||||
addIArgument(shape);
|
||||
}
|
||||
|
||||
|
|
|
@ -1725,7 +1725,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
|||
for( int i = 0; i < result.size(); i++) {
|
||||
arr[i] = result.get(i).toString();
|
||||
}
|
||||
log.trace("Calculated output shapes for op {} - {}", op.getClass().getName(), Arrays.toString(arr));
|
||||
|
||||
DifferentialFunction differentialFunction = (DifferentialFunction) op;
|
||||
log.trace("Calculated output shapes for op of name {} and type {} - {}",differentialFunction.getOwnName(), op.getClass().getName(), Arrays.toString(arr));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
@ -2022,7 +2024,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
|||
|
||||
@Override
|
||||
public DataBuffer createShapeInfo(long[] shape, long[] stride, long elementWiseStride, char order, DataType dtype, long extras) {
|
||||
val dbf = loop.shapeBufferEx(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, extras);
|
||||
OpaqueConstantShapeBuffer dbf = loop.shapeBufferEx(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, extras);
|
||||
if (loop.lastErrorCode() != 0)
|
||||
throw new RuntimeException(loop.lastErrorMessage());
|
||||
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -1,3 +1,19 @@
|
|||
alpha
|
||||
Sum/reduction_indices
|
||||
Sum
|
||||
in_0
|
||||
while/Const
|
||||
while/add/y
|
||||
in_0/read
|
||||
while/Enter
|
||||
while/Enter_1
|
||||
while/Merge
|
||||
while/Merge_1
|
||||
while/Less
|
||||
while/LoopCond
|
||||
while/Switch
|
||||
while/Switch_1
|
||||
while/Identity
|
||||
while/Exit
|
||||
while/Identity_1
|
||||
while/Exit_1
|
||||
while/add
|
||||
while/NextIteration_1
|
||||
while/NextIteration
|
||||
|
|
|
@ -121,6 +121,7 @@ import java.util.Set;
|
|||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
|
||||
@Ignore("No longer relevant after model import rewrite.")
|
||||
public class TestOpMapping extends BaseNd4jTest {
|
||||
|
||||
Set<Class<? extends DifferentialFunction>> subTypes;
|
||||
|
@ -151,7 +152,7 @@ public class TestOpMapping extends BaseNd4jTest {
|
|||
Map<String, DifferentialFunction> onnxOpNameMapping = ImportClassMapping.getOnnxOpMappingFunctions();
|
||||
|
||||
|
||||
for(Class<? extends DifferentialFunction> c : subTypes){
|
||||
for(Class<? extends DifferentialFunction> c : subTypes) {
|
||||
|
||||
if(Modifier.isAbstract(c.getModifiers()) || Modifier.isInterface(c.getModifiers()) || ILossFunction.class.isAssignableFrom(c))
|
||||
continue;
|
||||
|
|
|
@ -1,178 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.nd4j.autodiff.execution;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import lombok.val;
|
||||
import org.junit.Before;
|
||||
import org.junit.Ignore;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.OpValidationSuite;
|
||||
import org.nd4j.autodiff.execution.conf.ExecutionMode;
|
||||
import org.nd4j.autodiff.execution.conf.ExecutorConfiguration;
|
||||
import org.nd4j.autodiff.execution.conf.OutputMode;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
|
||||
import java.io.DataOutputStream;
|
||||
import java.io.FileOutputStream;
|
||||
import java.nio.ByteBuffer;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
@Slf4j
|
||||
public class GraphExecutionerTest extends BaseNd4jTest {
|
||||
|
||||
public GraphExecutionerTest(Nd4jBackend b){
|
||||
super(b);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering(){
|
||||
return 'c';
|
||||
}
|
||||
|
||||
protected static ExecutorConfiguration configVarSpace = ExecutorConfiguration.builder().outputMode(OutputMode.VARIABLE_SPACE).build();
|
||||
protected static ExecutorConfiguration configExplicit = ExecutorConfiguration.builder().outputMode(OutputMode.EXPLICIT).build();
|
||||
protected static ExecutorConfiguration configImplicit = ExecutorConfiguration.builder().outputMode(OutputMode.IMPLICIT).build();
|
||||
|
||||
@Before
|
||||
public void setUp() {
|
||||
//
|
||||
}
|
||||
|
||||
@Test
|
||||
@Ignore
|
||||
public void testConversion() throws Exception {
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
INDArray ones = Nd4j.ones(4);
|
||||
SDVariable sdVariable = sameDiff.var("ones",ones);
|
||||
SDVariable result = sdVariable.add(1.0);
|
||||
SDVariable total = sameDiff.sum(result,Integer.MAX_VALUE);
|
||||
|
||||
val executioner = new NativeGraphExecutioner();
|
||||
|
||||
ByteBuffer buffer = executioner.convertToFlatBuffers(sameDiff, ExecutorConfiguration.builder().profilingMode(OpExecutioner.ProfilingMode.DISABLED).executionMode(ExecutionMode.SEQUENTIAL).outputMode(OutputMode.IMPLICIT).build());
|
||||
|
||||
val offset = buffer.position();
|
||||
val array = buffer.array();
|
||||
|
||||
try (val fos = new FileOutputStream("../../libnd4j/tests/resources/adam_sum.fb"); val dos = new DataOutputStream(fos)) {
|
||||
dos.write(array, offset, array.length - offset);
|
||||
}
|
||||
|
||||
|
||||
//INDArray[] res = executioner.executeGraph(sameDiff);
|
||||
//assertEquals(8.0, res[0].getDouble(0), 1e-5);
|
||||
/*
|
||||
INDArray output = null;
|
||||
for(int i = 0; i < 5; i++) {
|
||||
output = sameDiff.execAndEndResult(ops);
|
||||
System.out.println("Ones " + ones);
|
||||
System.out.println(output);
|
||||
}
|
||||
|
||||
assertEquals(Nd4j.valueArrayOf(4,7),ones);
|
||||
assertEquals(28,output.getDouble(0),1e-1);
|
||||
*/
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* VarSpace should dump everything. 4 variables in our case
|
||||
* @throws Exception
|
||||
*/
|
||||
@Test
|
||||
public void testEquality1() {
|
||||
OpValidationSuite.ignoreFailing(); //Failing 2019/01/24
|
||||
GraphExecutioner executionerA = new BasicGraphExecutioner();
|
||||
GraphExecutioner executionerB = new NativeGraphExecutioner();
|
||||
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
INDArray ones = Nd4j.ones(4);
|
||||
SDVariable sdVariable = sameDiff.var("ones",ones);
|
||||
SDVariable scalarOne = sameDiff.var("scalar",Nd4j.scalar(1.0));
|
||||
SDVariable result = sdVariable.add(scalarOne);
|
||||
SDVariable total = sameDiff.sum(result,Integer.MAX_VALUE);
|
||||
|
||||
log.info("TOTAL: {}; Id: {}", total.name(), total);
|
||||
|
||||
INDArray[] resB = executionerB.executeGraph(sameDiff, configVarSpace);
|
||||
|
||||
//Variables: ones, scalar, result, total
|
||||
assertEquals(sameDiff.variables().size(), resB.length);
|
||||
assertEquals(Nd4j.ones(4), resB[0]);
|
||||
assertEquals(Nd4j.scalar(1), resB[1]);
|
||||
assertEquals(Nd4j.create(new float[]{2f, 2f, 2f, 2f}), resB[2]);
|
||||
assertEquals(Nd4j.scalar(8.0), resB[3]);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Implicit should return tree edges. So, one variable
|
||||
* @throws Exception
|
||||
*/
|
||||
@Test
|
||||
public void testEquality2() {
|
||||
OpValidationSuite.ignoreFailing(); //Failing 2019/01/24
|
||||
GraphExecutioner executionerA = new BasicGraphExecutioner();
|
||||
GraphExecutioner executionerB = new NativeGraphExecutioner();
|
||||
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
INDArray ones = Nd4j.ones(4);
|
||||
SDVariable sdVariable = sameDiff.var("ones",ones);
|
||||
SDVariable scalarOne = sameDiff.var("add1",Nd4j.scalar(1.0));
|
||||
SDVariable result = sdVariable.add(scalarOne);
|
||||
SDVariable total = sameDiff.sum(result,Integer.MAX_VALUE);
|
||||
|
||||
// log.info("ID: {}",sameDiff.getGraph().getVertex(1).getValue().getId());
|
||||
|
||||
INDArray[] resB = executionerB.executeGraph(sameDiff, configImplicit);
|
||||
|
||||
assertEquals(1, resB.length);
|
||||
assertEquals(Nd4j.scalar(8.0), resB[0]);
|
||||
|
||||
//INDArray resA = executionerA.executeGraph(sameDiff)[0];
|
||||
|
||||
//assertEquals(resA, resB);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
@Ignore
|
||||
public void testSums1() {
|
||||
SameDiff sameDiff = SameDiff.create();
|
||||
INDArray ones = Nd4j.ones(4);
|
||||
SDVariable sdVariable = sameDiff.var("ones",ones);
|
||||
SDVariable result = sdVariable.add(1.0);
|
||||
SDVariable total = sameDiff.sum(result,Integer.MAX_VALUE);
|
||||
|
||||
val executioner = new NativeGraphExecutioner();
|
||||
|
||||
INDArray[] res = executioner.executeGraph(sameDiff);
|
||||
assertEquals(8.0, res[0].getDouble(0), 1e-5);
|
||||
}
|
||||
}
|
|
@ -43,7 +43,7 @@ public class ActivationGradChecks extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testActivationGradientCheck1(){
|
||||
public void testActivationGradientCheck1() {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable in = sd.var("x", Nd4j.rand(DataType.DOUBLE, 3, 4));
|
||||
|
@ -61,7 +61,7 @@ public class ActivationGradChecks extends BaseOpValidation {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testActivationGradientCheck2(){
|
||||
public void testActivationGradientCheck2() {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
SameDiff sd = SameDiff.create();
|
||||
SDVariable x = sd.placeHolder("x", DataType.DOUBLE, 3, 4);
|
||||
|
|
|
@ -468,6 +468,11 @@ public class ReductionOpValidation extends BaseOpValidation {
|
|||
assertEquals("Failed: " + failed, 0, failed.size());
|
||||
}
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return Long.MAX_VALUE;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testReductionGradients2() {
|
||||
//Test reductions: NON-final function
|
||||
|
|
|
@ -567,7 +567,12 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
assertEquals(failed.toString(), 0, failed.size());
|
||||
}
|
||||
|
||||
@Test
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return Long.MAX_VALUE;
|
||||
}
|
||||
|
||||
@Test()
|
||||
public void testStack() {
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
|
@ -606,22 +611,22 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
}
|
||||
|
||||
INDArray expStack = null;
|
||||
if(Arrays.equals(new long[]{3,4}, shape)){
|
||||
if(Arrays.equals(new long[]{3,4}, shape)) {
|
||||
if(axis == 0){
|
||||
INDArray out = Nd4j.create(numInputs, 3, 4);
|
||||
for( int i=0; i<numInputs; i++ ){
|
||||
for( int i = 0; i < numInputs; i++) {
|
||||
out.get(point(i), all(), all()).assign(inArr[i]);
|
||||
}
|
||||
expStack = out;
|
||||
} else if(axis == 1) {
|
||||
INDArray out = Nd4j.create(3, numInputs, 4);
|
||||
for( int i=0; i<numInputs; i++ ){
|
||||
for( int i = 0; i<numInputs; i++) {
|
||||
out.get(all(), point(i), all()).assign(inArr[i]);
|
||||
}
|
||||
expStack = out;
|
||||
} else {
|
||||
INDArray out = Nd4j.create(3, 4, numInputs);
|
||||
for( int i=0; i<numInputs; i++ ){
|
||||
for( int i = 0; i < numInputs; i++) {
|
||||
out.get(all(), all(), point(i)).assign(inArr[i]);
|
||||
}
|
||||
expStack = out;
|
||||
|
|
|
@ -2815,9 +2815,9 @@ public class SameDiffTests extends BaseNd4jTest {
|
|||
out.markAsLoss();
|
||||
out.eval();
|
||||
|
||||
out.eval();
|
||||
sd.grad("a").eval();
|
||||
|
||||
INDArray outEvaled = out.eval();
|
||||
INDArray gradOutput = sd.grad("a").eval();
|
||||
INDArray bOutputEval = sd.grad("b").eval();
|
||||
String err = OpValidation.validate(new TestCase(sd)
|
||||
.testFlatBufferSerialization(TestCase.TestSerialization.BOTH)
|
||||
.gradientCheck(true));
|
||||
|
|
|
@ -84,7 +84,7 @@ public class ProfilingListenerTest extends BaseNd4jTest {
|
|||
Map<String,INDArray> ph = new HashMap<>();
|
||||
ph.put("in", i);
|
||||
|
||||
for( int x=0; x<10; x++ ) {
|
||||
for( int x = 0; x < 10; x++) {
|
||||
sd.outputSingle(ph, "predictions");
|
||||
}
|
||||
|
||||
|
@ -94,8 +94,8 @@ public class ProfilingListenerTest extends BaseNd4jTest {
|
|||
|
||||
//Should be 2 begins and 2 ends for each entry
|
||||
//5 warmup iterations, 5 profile iterations, x2 for both the op name and the op "instance" name
|
||||
String[] opNames = {"mmul", "add", "softmax"};
|
||||
for(String s : opNames){
|
||||
String[] opNames = {"matmul", "add", "softmax"};
|
||||
for(String s : opNames) {
|
||||
assertEquals(s, 10, StringUtils.countMatches(content, s));
|
||||
}
|
||||
|
||||
|
|
|
@ -39,8 +39,8 @@
|
|||
|
||||
<logger name="org.apache.catalina.core" level="DEBUG" />
|
||||
<logger name="org.springframework" level="WARN" />
|
||||
<logger name="org.nd4j" level="DEBUG" />
|
||||
<logger name="org.deeplearning4j" level="INFO" />
|
||||
<logger name="org.nd4j" level="TRACE" />
|
||||
<logger name="org.deeplearning4j" level="TRACE" />
|
||||
|
||||
|
||||
<root level="ERROR">
|
||||
|
|
|
@ -1 +1,18 @@
|
|||
Sum,Sum
|
||||
in_0/read,in_0/read
|
||||
while/Enter,while/Enter
|
||||
while/Enter_1,while/Enter_1
|
||||
while/Merge,while/Merge
|
||||
while/Merge_1,while/Merge_1
|
||||
while/Less,while/Less
|
||||
while/LoopCond,while/LoopCond
|
||||
while/Switch,while/Switch
|
||||
while/Switch:1,while/Switch
|
||||
while/Switch_1,while/Switch_1
|
||||
while/Switch_1:1,while/Switch_1
|
||||
while/Identity,while/Identity
|
||||
while/Exit,while/Exit
|
||||
while/Identity_1,while/Identity_1
|
||||
while/Exit_1,while/Exit_1
|
||||
while/add,while/add
|
||||
while/NextIteration_1,while/NextIteration_1
|
||||
while/NextIteration,while/NextIteration
|
||||
|
|
|
@ -1654,7 +1654,8 @@ val randomCrop = mapTensorNamesWithOp(inputFrameworkOpName = "RandomCrop",opName
|
|||
attributeMappingRules = listOf(valueMapping(mutableMapOf("seed" to "seed")))
|
||||
,tensorflowOpRegistry = tensorflowOpRegistry)
|
||||
|
||||
val placeHolder = mapTensorNamesWithOp(inputFrameworkOpName = "Placeholder",opName = "placeholder",tensorNames = mutableMapOf()
|
||||
val placeHolder = mapTensorNamesWithOp(inputFrameworkOpName = "Placeholder",opName = "placeholder",tensorNames = mutableMapOf(),
|
||||
attributeMappingRules = listOf()
|
||||
,tensorflowOpRegistry = tensorflowOpRegistry)
|
||||
|
||||
val randomGamma = mapTensorNamesWithOp(inputFrameworkOpName = "RandomGamma",opName = "random_gamma",tensorNames = mutableMapOf("shape" to "shape","alpha" to "alpha"),
|
||||
|
|
Loading…
Reference in New Issue