SameDiff cleanup and fixes (#12)

* #8160 Remove resolvePropertiesFromSameDiffBeforeExecution

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* SameDiff API cleanup

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* More SameDiff cleanup

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Small fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8248 Switch SameDiff variable init from lazy to creation time for more predictable behaviour

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8252 TanhDerivative javadoc

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8225 Deconvolution2D input validation

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8265 Switch SameDiff.outputs() to user settable, instead of unreliable 'best guess'

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8224 SameDiff.zero and .one create constants, not variables

Signed-off-by: AlexDBlack <blacka101@gmail.com>
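For context, a minimal sketch of what this change means for calling code (variable names here are illustrative, not from the commit): zero/one results can no longer be trained, so trainable zero-initialized parameters should be created via var(...) instead, as the updated CompareTrainingImplementations test later in this diff does.

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.factory.Nd4j;

    // Sketch: sd.zero(...) and sd.one(...) now create CONSTANT-type variables
    SameDiff sd = SameDiff.create();
    SDVariable zeroConst = sd.zero("zeroConst", 1, 10);                   // constant; not trainable
    // For a trainable zero-initialized parameter, create a VARIABLE explicitly:
    SDVariable bias = sd.var("bias", Nd4j.create(DataType.FLOAT, 1, 10));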

* More cleanup and fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Small test fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Small fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* DL4J SameDiff fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Re-add hack for Deconvolution2DLayer until #8315 is resolved

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8270 Move CUDA device/version logging to Java; can be disabled via existing org.nd4j.log.initialization system property

Signed-off-by: AlexDBlack <blacka101@gmail.com>
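For reference, a sketch of disabling that logging; the property name comes from the commit message, while the value and the exact call shown here are assumptions (any mechanism that sets the property before ND4J initializes, e.g. a -D JVM argument, should behave the same):

    // Assumption: setting the property to "false" before ND4J is first
    // initialized suppresses the initialization/device logging:
    System.setProperty("org.nd4j.log.initialization", "false");
    // equivalently, as a JVM argument: -Dorg.nd4j.log.initialization=false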

* All ND4J init logging checks system property

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Small tweak

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Remove redundant device logging

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* One more fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* UX improvements

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Deconv fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Add deconv tests

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Cleanup

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Remove debug code

Signed-off-by: AlexDBlack <blacka101@gmail.com>
Branch: master
Author: Alex Black, 2019-10-26 12:38:08 +11:00 (committed by GitHub)
Parent: d98784197a
Commit: d333d29099
92 changed files with 1204 additions and 1956 deletions

View File

@@ -21,6 +21,7 @@ import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
 import org.deeplearning4j.eval.Evaluation;
 import org.deeplearning4j.exception.DL4JException;
+import org.deeplearning4j.exception.DL4JInvalidInputException;
 import org.deeplearning4j.nn.api.Layer;
 import org.deeplearning4j.nn.api.OptimizationAlgorithm;
 import org.deeplearning4j.nn.conf.ConvolutionMode;
@@ -693,4 +694,22 @@ public class ConvolutionLayerTest extends BaseDL4JTest {
         INDArray out = net.output(in);
         assertArrayEquals(new long[]{2,7,6}, out.shape());
     }
+
+    @Test
+    public void testDeconvBadInput(){
+        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+                .list()
+                .layer(new Deconvolution2D.Builder().nIn(5).nOut(3).build())
+                .build();
+        MultiLayerNetwork net = new MultiLayerNetwork(conf);
+        net.init();
+
+        INDArray badInput = Nd4j.create(DataType.FLOAT, 1, 10, 5, 5);
+        try {
+            net.output(badInput);
+        } catch (DL4JInvalidInputException e){
+            String msg = e.getMessage();
+            assertTrue(msg, msg.contains("Deconvolution2D") && msg.contains("input") && msg.contains("channels"));
+        }
+    }
 }

View File

@@ -86,10 +86,10 @@ public class CompareTrainingImplementations extends BaseDL4JTest {
                 SDVariable label = sd.placeHolder("label", DataType.DOUBLE, -1, 3);

                 SDVariable w0 = sd.var("w0", new XavierInitScheme('c', 4, 10), DataType.DOUBLE, 4, 10);
-                SDVariable b0 = sd.zero("b0", 1, 10);
+                SDVariable b0 = sd.var("b0", Nd4j.create(DataType.DOUBLE, 1, 10));

                 SDVariable w1 = sd.var("w1", new XavierInitScheme('c', 10, 3), DataType.DOUBLE, 10, 3);
-                SDVariable b1 = sd.zero("b1", 1, 3);
+                SDVariable b1 = sd.var("b1", Nd4j.create(DataType.DOUBLE, 1, 3));

                 SDVariable z0 = in.mmul(w0).add(b0);
                 SDVariable a0 = sd.nn().tanh(z0);
@@ -172,8 +172,8 @@ public class CompareTrainingImplementations extends BaseDL4JTest {
                 Map<String,INDArray> placeholders = new HashMap<>();
                 placeholders.put("input", f);
                 placeholders.put("label", l);
-                Map<String,INDArray> map = sd.output(placeholders, lossMse.getVarName(), a1.getVarName());
-                INDArray outSd = map.get(a1.getVarName());
+                Map<String,INDArray> map = sd.output(placeholders, lossMse.name(), a1.name());
+                INDArray outSd = map.get(a1.name());
                 INDArray outDl4j = net.output(f);
                 assertEquals(testName, outDl4j, outSd);
@@ -187,7 +187,7 @@ public class CompareTrainingImplementations extends BaseDL4JTest {
                 //Check score
                 double scoreDl4j = net.score();
-                double scoreSd = map.get(lossMse.getVarName()).getDouble(0) + sd.calcRegularizationScore();
+                double scoreSd = map.get(lossMse.name()).getDouble(0) + sd.calcRegularizationScore();
                 assertEquals(testName, scoreDl4j, scoreSd, 1e-6);

                 double lossRegScoreSD = sd.calcRegularizationScore();
@@ -197,15 +197,15 @@ public class CompareTrainingImplementations extends BaseDL4JTest {
                 //Check gradients (before updater applied)
                 Map<String,INDArray> grads = net.gradient().gradientForVariable();
-                sd.execBackwards(placeholders);
+                Map<String,INDArray> gm = sd.calculateGradients(placeholders, b1.name(), w1.name(), b0.name(), w0.name());

                 //Note that the SameDiff gradients don't include the L1/L2 terms at present just from execBackwards()... these are added in fitting only
                 //We can check correctness though with training param checks later
                 if(l1Val == 0 && l2Val == 0 && wdVal == 0) {
-                    assertEquals(testName, grads.get("1_b"), b1.getGradient().getArr());
-                    assertEquals(testName, grads.get("1_W"), w1.getGradient().getArr());
-                    assertEquals(testName, grads.get("0_b"), b0.getGradient().getArr());
-                    assertEquals(testName, grads.get("0_W"), w0.getGradient().getArr());
+                    assertEquals(testName, grads.get("1_b"), gm.get(b1.name()));
+                    assertEquals(testName, grads.get("1_W"), gm.get(w1.name()));
+                    assertEquals(testName, grads.get("0_b"), gm.get(b0.name()));
+                    assertEquals(testName, grads.get("0_W"), gm.get(w0.name()));
                 }
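The replacement API used above returns the requested gradients directly as a map, instead of requiring execBackwards(...) followed by grad(...).getArr(). A minimal usage sketch, reusing the variables from the test above:

    // Gradients are keyed by the name of the variable they correspond to:
    Map<String,INDArray> gradMap = sd.calculateGradients(placeholders, w0.name(), b0.name());
    INDArray dLdW0 = gradMap.get(w0.name());   // gradient of the loss with respect to w0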

View File

@@ -65,7 +65,7 @@ public abstract class SameDiffLambdaVertex extends SameDiffVertex {
             defineVertex(temp, tempInputs);
             List<String> list = new ArrayList<>();
             for (Integer i : tempInputs.map.keySet()) {
-                list.add(tempInputs.map.get(i).getVarName());
+                list.add(tempInputs.map.get(i).name());
             }
             params.defineInputs(list.toArray(new String[list.size()]));
         }

View File

@@ -176,8 +176,10 @@ public class Deconvolution2DLayer extends ConvolutionLayer {
         int outDepth = (int) weights.size(1);

         if (input.size(1) != inDepth && input.size(3) == inDepth) {
+            //TODO AB 2019/10/25 this is an ugly "pseudo-NHWC support" hack that needs to be removed ASAP
+            //https://github.com/eclipse/deeplearning4j/issues/8315
             input = input.permute(0, 3, 1, 2);
-        } else if (input.size(1) != inDepth && input.size(3) != inDepth) {
+        } else if (input.size(1) != inDepth ) {
             String layerName = conf.getLayer().getLayerName();
             if (layerName == null)
                 layerName = "(not named)";

View File

@@ -192,7 +192,7 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
                 String name = inputs.get(j);
                 dLdIns[j] = sameDiff.grad(name).getArr();

-                String gradName = sameDiff.grad(inputNames.get(j)).getVarName();
+                String gradName = sameDiff.grad(inputNames.get(j)).name();
                 if(dLdIns[j] == null && fnName.equals(gradName)){
                     //Edge case with lambda vertices like identity: SameDiff doesn't store the placeholders
                     // So, this getArr() can be trying to get placeholder from SameDiff instance, when it's available here
@@ -271,7 +271,7 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
             fn = sameDiff.f().externalErrors(layerOutput);
             fn.outputVariable();

-            this.outputKey = outputVar.getVarName();
+            this.outputKey = outputVar.name();
         }
     }

View File

@@ -302,7 +302,7 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
             fn = sameDiff.f().externalErrors(layerOutput);
             fn.outputVariable();

-            this.outputKey = outputVar.getVarName();
+            this.outputKey = outputVar.name();
         }
     }

View File

@@ -112,7 +112,7 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
                 phMap.put(LABELS_KEY, labels);
             }

-            String s = activations ? layerConf().activationsVertexName() : outputVar.getVarName();
+            String s = activations ? layerConf().activationsVertexName() : outputVar.name();

             INDArray out = sameDiff.outputSingle(phMap, s);
@@ -160,31 +160,35 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
             }

             List<String> gradVarNames = new ArrayList<>();
-            for(String s : paramTable.keySet()){
-                gradVarNames.add(sameDiff.getVariable(s).getGradient().getVarName());
-            }
-            gradVarNames.add(sameDiff.grad(INPUT_KEY).getVarName());
+            gradVarNames.addAll(paramTable.keySet());
+            gradVarNames.add(INPUT_KEY);

             Map<String,INDArray> phMap = new HashMap<>();
             phMap.put(INPUT_KEY, input);
             phMap.put(LABELS_KEY, labels);

-            sameDiff.execBackwards(phMap, gradVarNames);
+            Map<String,INDArray> grads = sameDiff.calculateGradients(phMap, gradVarNames);
             for(String s : paramTable.keySet() ){
-                INDArray sdGrad = sameDiff.grad(s).getArr();
+                INDArray sdGrad = grads.get(s);
                 INDArray dl4jGrad = gradTable.get(s);
                 dl4jGrad.assign(sdGrad);                                            //TODO OPTIMIZE THIS
                 g.gradientForVariable().put(s, dl4jGrad);
+                if(sdGrad.closeable()){
+                    sdGrad.close();
+                }
             }

-            dLdIn = sameDiff.grad(INPUT_KEY).getArr();
+            dLdIn = grads.get(INPUT_KEY);
         }

         //Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
         sameDiff.clearPlaceholders(true);
         sameDiff.clearOpInputs();

-        return new Pair<>(g, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn));   //TODO OPTIMIZE THIS
+        Pair<Gradient,INDArray> p = new Pair<>(g, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn));   //TODO OPTIMIZE THIS
+        if(dLdIn.closeable())
+            dLdIn.close();
+        return p;
     }

     /**Returns the parameters of the neural network as a flattened row vector
@@ -297,7 +301,7 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
             sameDiff.associateArrayWithVariable(arr, sameDiff.getVariable(e.getKey()));
         }

-        this.outputKey = layerOutput.getVarName();
+        this.outputKey = layerOutput.name();
     }
 }
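A note on the close() calls introduced above: the arrays returned by calculateGradients(...) are owned by the caller, so once their values have been copied into DL4J's gradient views they can be released eagerly rather than left for garbage collection. The pattern, in isolation (names from the surrounding method):

    INDArray sdGrad = grads.get(s);        // gradient array owned by this code
    INDArray dl4jGrad = gradTable.get(s);  // DL4J's persistent gradient view
    dl4jGrad.assign(sdGrad);               // copy the values into the view
    if (sdGrad.closeable()) {
        sdGrad.close();                    // free the off-heap buffer immediately
    }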

View File

@@ -66,13 +66,6 @@ namespace nd4j {
 #endif

 #ifdef __CUDABLAS__
-        BlasVersionHelper ver;
-        _blasMajorVersion = ver._blasMajorVersion;
-        _blasMinorVersion = ver._blasMinorVersion;
-        _blasPatchVersion = ver._blasPatchVersion;
-        printf("ND4J CUDA build version: %i.%i.%i\n", _blasMajorVersion, _blasMinorVersion, _blasPatchVersion);
-        fflush(stdout);
         int devCnt = 0;
         cudaGetDeviceCount(&devCnt);
         auto devProperties = new cudaDeviceProp[devCnt];
@@ -83,10 +76,12 @@ namespace nd4j {
             //cudaDeviceSetLimit(cudaLimitStackSize, 4096);
             Pair p(devProperties[i].major, devProperties[i].minor);
             _capabilities.emplace_back(p);
-            printf("CUDA device %i: [%s]; cc: [%i.%i]; Total memory: [%lld];\n", i, devProperties[i].name, devProperties[i].major, devProperties[i].minor, (Nd4jLong) devProperties[i].totalGlobalMem);
         }
-        fflush(stdout);

+        BlasVersionHelper ver;
+        _blasMajorVersion = ver._blasMajorVersion;
+        _blasMinorVersion = ver._blasMinorVersion;
+        _blasPatchVersion = ver._blasPatchVersion;

         cudaSetDevice(0);
         delete[] devProperties;
@@ -203,6 +198,18 @@ namespace nd4j {
 #endif
     }

+    int Environment::blasMajorVersion(){
+        return _blasMajorVersion;
+    }
+
+    int Environment::blasMinorVersion(){
+        return _blasMinorVersion;
+    }
+
+    int Environment::blasPatchVersion(){
+        return _blasPatchVersion;
+    }
+
     nd4j::Environment *nd4j::Environment::_instance = 0;
 }

View File

@@ -97,6 +97,10 @@ namespace nd4j{
         bool isCPU();

+        int blasMajorVersion();
+        int blasMinorVersion();
+        int blasPatchVersion();
+
         std::vector<Pair>& capabilities();
     };
 }

View File

@@ -66,8 +66,10 @@ CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) {
     if(!isNCHW)
         output = new NDArray(output->permute({0, 3, 1, 2}));                            // [bS, oH, oW, oC] -> [bS, oC, oH, oW]

-    if(isSameMode)                       // SAME
-        ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
+    if(isSameMode){                      // SAME
+        //Note: we're intentionally swapping iH and oH, to calculate the padding for a "normal" conv (not deconv) forward pass
+        ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW);
+    }

     NDArray columns(input->ordering(), {bS, oC, kH, kW, iH, iW}, input->dataType(), block.launchContext());
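Why the swap is correct (a sketch using the standard SAME-padding arithmetic, not code from this repository): deconv2d's forward pass is the backward pass of a normal convolution that maps oH down to iH, so the SAME padding for the deconv is whatever that normal conv would need to map oH -> iH. The helper below is hypothetical and only illustrates the formula:

    // Total SAME padding for a normal conv mapping `in` -> `out`, kernel k,
    // stride s, dilation d (textbook formula):
    static int samePaddingTotal(int in, int out, int k, int s, int d) {
        int effK = (k - 1) * d + 1;                     // effective kernel size
        return Math.max((out - 1) * s + effK - in, 0);  // pad so the shapes line up
    }
    // For deconv2d, the deconv's input/output heights swap roles:
    // padTotal = samePaddingTotal(oH, iH, kH, sH, dH);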

View File

@@ -442,22 +442,12 @@ public abstract class DifferentialFunction {
         setInstanceId();
         if(sameDiff != null) {
             sameDiff.addArgsFor(args, this);
-            for (int i = 0; i < args.length; i++) {
-                if (args[i].isPlaceHolder()) {
-                    sameDiff.addPropertyToResolve(this, args[i].getVarName());
-                }
-            }
         }
     }

     public void replaceArg(int i, SDVariable newArg){
         if(sameDiff != null){
             sameDiff.replaceArgFor(i, newArg, this);
-            if(args()[i].isPlaceHolder() && !newArg.isPlaceHolder()){
-                sameDiff.removePropertyToResolve(this, args()[i].getVarName());
-            } else if(!args()[i].isPlaceHolder() && newArg.isPlaceHolder()){
-                sameDiff.addPropertyToResolve(this, newArg.getVarName());
-            }
         }
     }
@@ -483,7 +473,7 @@ public abstract class DifferentialFunction {
         SDVariable[] outputVars = outputVariables();
         String[] out = new String[outputVars.length];
         for( int i=0; i<out.length; i++ ){
-            out[i] = outputVars[i].getVarName();
+            out[i] = outputVars[i].name();
         }
         return out;
     }
@@ -538,69 +528,11 @@ public abstract class DifferentialFunction {
         SDVariable[] args = args();
         String[] out = new String[args.length];
         for( int i=0; i<args.length; i++ ){
-            out[i] = args[i].getVarName();
+            out[i] = args[i].name();
         }
         return out;
     }

-    /**
-     * Resolve properties and arguments right before execution of
-     * this operation.
-     *
-     * @deprecated Will be removed in the future. Ops should support array arguments. Should not bs used or overridden.
-     */
-    @Deprecated
-    public final void resolvePropertiesFromSameDiffBeforeExecution() {
-        val properties = sameDiff.propertiesToResolveForFunction(this);
-        val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(this);
-        val currentFields = this.propertiesForFunction();
-        for(val property : properties) {
-            //property maybe a variable which is only an array
-            //just skip if this is the case
-            if(!fields.containsKey(property))
-                continue;
-
-            val var = sameDiff.getVarNameForFieldAndFunction(this,property);
-            if(var == null)
-                continue;   //Rarely (like Conv2D) properties will be optional. For example kH/kW args will be inferred from weight shape
-
-            val fieldType = fields.get(property);
-            val varArr = sameDiff.getArrForVarName(var);
-
-            //already defined
-            if(currentFields.containsKey(property)) {
-                continue;
-            }
-
-            /**
-             * Possible cause:
-             * Might be related to output name alignment.
-             *
-             */
-            if(varArr == null) {
-                throw new ND4JIllegalStateException("Unable to set null array!");
-            }
-
-            if(fieldType.getType().equals(int[].class)) {
-                setValueFor(fieldType,varArr.data().asInt());
-            }
-            else if(fieldType.equals(double[].class)) {
-                setValueFor(fieldType,varArr.data().asDouble());
-            }
-            else if(fieldType.equals(int.class)) {
-                setValueFor(fieldType,varArr.getInt(0));
-            }
-            else if(fieldType.equals(double.class)) {
-                setValueFor(fieldType,varArr.getDouble(0));
-            }
-        }
-    }
-
     /**
      * Return the first argument
      * @return
@@ -639,13 +571,12 @@ public abstract class DifferentialFunction {
                 SDVariable gradVar = f().add(grad, vals.get(i));
                 vals.set(i, gradVar);
-                sameDiff.setGradientForVariableName(var.getVarName(), gradVar);
+                sameDiff.setGradientForVariableName(var.name(), gradVar);
             } else {
                 SDVariable gradVar = vals.get(i);
-                sameDiff.updateVariableNameAndReference(gradVar,var.getVarName() + "-grad");
-                sameDiff.setGradientForVariableName(var.getVarName(), gradVar);
-                sameDiff.setForwardVariableForVarName(gradVar.getVarName(),var);
+                sameDiff.updateVariableNameAndReference(gradVar,var.name() + "-grad");
+                sameDiff.setGradientForVariableName(var.name(), gradVar);
             }
         }

View File

@@ -184,7 +184,6 @@ import org.nd4j.linalg.api.ops.impl.shape.bp.StridedSliceBp;
 import org.nd4j.linalg.api.ops.impl.shape.bp.TileBp;
 import org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation;
 import org.nd4j.linalg.api.ops.impl.summarystats.Variance;
-import org.nd4j.linalg.api.ops.impl.transforms.Constant;
 import org.nd4j.linalg.api.ops.impl.transforms.Pad;
 import org.nd4j.linalg.api.ops.impl.transforms.ReluLayer;
 import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
@@ -352,12 +351,6 @@ public class DifferentialFunctionFactory {
         }
     }

-    public Constant val(SDVariable iX) {
-        return new Constant(sameDiff(), iX,
-                iX.getShape());
-    }
-
     public ExternalErrorsFunction externalErrors(SDVariable... inputs) {
         return externalErrors(null, inputs);
     }
@@ -384,10 +377,6 @@ public class DifferentialFunctionFactory {
         return new OnesLike(name, sameDiff(), input, dataType).outputVariable();
     }

-    public SDVariable constant(SDVariable input, long... shape) {
-        return new Constant(sameDiff(), input, (shape != null && shape.length > 0 ? shape : null)).outputVariable();
-    }
-
     public SDVariable linspace(SDVariable lower, SDVariable upper, SDVariable count, DataType dt) {
         return new org.nd4j.linalg.api.ops.impl.shape.Linspace(sameDiff(), lower, upper, count, dt).outputVariable();
     }
@@ -1056,7 +1045,7 @@ public class DifferentialFunctionFactory {
     public SDVariable gradientBackwardsMarker(SDVariable iX) {
-        return new GradientBackwardsMarker(sameDiff(), iX, sameDiff.scalar(iX.getVarName() + "-pairgrad", 1.0)).outputVariable();
+        return new GradientBackwardsMarker(sameDiff(), iX, sameDiff.scalar(iX.name() + "-pairgrad", 1.0)).outputVariable();
     }

     public SDVariable abs(SDVariable iX) {

View File

@@ -178,7 +178,7 @@ public class ListenerEvaluations {
          * @param evaluations The evaluations to run
          */
         public Builder trainEvaluation(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations) {
-            return trainEvaluation(variable.getVarName(), labelIndex, evaluations);
+            return trainEvaluation(variable.name(), labelIndex, evaluations);
         }

         /**
@@ -202,7 +202,7 @@ public class ListenerEvaluations {
          * @param evaluations The evaluations to run
          */
         public Builder validationEvaluation(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations) {
-            return validationEvaluation(variable.getVarName(), labelIndex, evaluations);
+            return validationEvaluation(variable.name(), labelIndex, evaluations);
         }

         /**

View File

@@ -167,7 +167,7 @@ public class ListenerVariables {
             String[] names = new String[variables.length];
             for (int i = 0; i < variables.length; i++)
-                names[i] = variables[i].getVarName();
+                names[i] = variables[i].name();

             return requireVariables(op, names);
         }

View File

@@ -226,7 +226,7 @@ public class UIListener extends BaseListener {
         List<SDVariable> sdVars = sd.variables();
         List<String> varNames = new ArrayList<>(sdVars.size());
         for(SDVariable v : sdVars){
-            varNames.add(v.getVarName());
+            varNames.add(v.name());
         }

         if(varNames.size() != vars.size() || !varNames.containsAll(vars)){

View File

@@ -91,7 +91,7 @@ public class EvaluationRecord {
      * @param param The target param/variable
      */
     public List<IEvaluation> evaluations(SDVariable param) {
-        return evaluations(param.getVarName());
+        return evaluations(param.name());
     }

     /**
@@ -105,7 +105,7 @@ public class EvaluationRecord {
      * Get the evaluation for param at the specified index
      */
     public IEvaluation evaluation(SDVariable param, int index) {
-        return evaluation(param.getVarName(), index);
+        return evaluation(param.name(), index);
     }

     /**
@@ -132,7 +132,7 @@ public class EvaluationRecord {
      * @param param The target param/variable
      */
     public <T extends IEvaluation> T evaluation(SDVariable param) {
-        return evaluation(param.getVarName());
+        return evaluation(param.name());
     }

     /**
@@ -174,7 +174,7 @@ public class EvaluationRecord {
      * @param evalClass The type of evaluation to look for
      */
     public <T extends IEvaluation<T>> T evaluation(SDVariable param, Class<T> evalClass) {
-        return evaluation(param.getVarName(), evalClass);
+        return evaluation(param.name(), evalClass);
     }

     /**
@@ -209,7 +209,7 @@ public class EvaluationRecord {
      * @param metric The metric to calculate
      */
     public double getValue(SDVariable param, IMetric metric) {
-        return getValue(param.getVarName(), metric);
+        return getValue(param.name(), metric);
     }

     /**
@@ -235,7 +235,7 @@ public class EvaluationRecord {
      * @param metric The metric to calculate
      */
     public double getValue(SDVariable param, int index, IMetric metric) {
-        return getValue(param.getVarName(), index, metric);
+        return getValue(param.name(), index, metric);
     }
 }

View File

@@ -125,7 +125,7 @@ public class History {
      * Only works if there is only one evaluation with the given metric for param
      */
     public List<Double> trainingEval(SDVariable param, IMetric metric){
-        return trainingEval(param.getVarName(), metric);
+        return trainingEval(param.name(), metric);
     }

     /**
@@ -149,7 +149,7 @@ public class History {
      * Index determines the evaluation used not the epoch's results to return.
      */
     public List<Double> trainingEval(SDVariable param, int index, IMetric metric){
-        return trainingEval(param.getVarName(), index, metric);
+        return trainingEval(param.name(), index, metric);
     }

     /**
@@ -184,7 +184,7 @@ public class History {
      * Only works if there is only one evaluation for param.
      */
     public List<IEvaluation> trainingEval(SDVariable param){
-        return trainingEval(param.getVarName());
+        return trainingEval(param.name());
     }

     /**
@@ -208,7 +208,7 @@ public class History {
      * Index determines the evaluation used not the epoch's results to return.
      */
     public List<IEvaluation> trainingEval(SDVariable param, int index){
-        return trainingEval(param.getVarName(), index);
+        return trainingEval(param.name(), index);
     }

     /**
@@ -230,7 +230,7 @@ public class History {
      * Only works if there is only one evaluation with the given metric for param
      */
     public List<Double> validationEval(SDVariable param, IMetric metric){
-        return validationEval(param.getVarName(), metric);
+        return validationEval(param.name(), metric);
     }

     /**
@@ -254,7 +254,7 @@ public class History {
      * Index determines the evaluation used not the epoch's results to return.
      */
     public List<Double> validationEval(SDVariable param, int index, IMetric metric){
-        return validationEval(param.getVarName(), index, metric);
+        return validationEval(param.name(), index, metric);
     }

     /**
@@ -289,7 +289,7 @@ public class History {
      * Only works if there is only one evaluation for param.
      */
     public List<IEvaluation> validationEval(SDVariable param){
-        return validationEval(param.getVarName());
+        return validationEval(param.name());
     }

     /**
@@ -313,7 +313,7 @@ public class History {
      * Index determines the evaluation used not the epoch's results to return.
      */
     public List<IEvaluation> validationEval(SDVariable param, int index){
-        return validationEval(param.getVarName(), index);
+        return validationEval(param.name(), index);
     }

     /**

View File

@@ -116,7 +116,7 @@ public class LossCurve {
      * Return all mean loss values for a given variable
      */
     public float[] meanLoss(@NonNull SDVariable loss){
-        return meanLoss(loss.getVarName());
+        return meanLoss(loss.name());
     }

     /**
@@ -143,7 +143,7 @@ public class LossCurve {
      * See {@link #meanLoss(int)}
      */
     public float meanLoss(@NonNull SDVariable loss, int epoch){
-        return meanLoss(loss.getVarName(), epoch);
+        return meanLoss(loss.name(), epoch);
     }

     /**
@@ -162,7 +162,7 @@ public class LossCurve {
      * Return the mean loss value for a given variable on the last epoch.
      */
     public float lastMeanLoss(@NonNull SDVariable loss){
-        return lastMeanLoss(loss.getVarName());
+        return lastMeanLoss(loss.name());
     }

     /**
@@ -189,7 +189,7 @@ public class LossCurve {
      * A positive delta means the loss is increasing, and a negative delta means it is decreasing.
      */
     public double lastMeanDelta(SDVariable loss){
-        return lastMeanDelta(loss.getVarName());
+        return lastMeanDelta(loss.name());
     }

     /**
/** /**

View File

@@ -59,10 +59,6 @@ public class SDVariable implements Serializable {
     @Setter
     protected VariableType variableType;

-    @Getter
-    @Setter
-    protected WeightInitScheme weightInitScheme;
-
     @Setter(AccessLevel.NONE)
     protected long[] shape;
@@ -75,9 +71,7 @@ public class SDVariable implements Serializable {
     // autogen_tag::sdvars::start
-    public SDVariable(@NonNull String varName, @NonNull VariableType varType, @NonNull SameDiff sameDiff, long[] shape, DataType dataType, WeightInitScheme weightInitScheme){
-        Preconditions.checkState(weightInitScheme == null || varType == VariableType.VARIABLE, "Weight initalization schemes can only be applied to VARIABLE type" +
-                " SDVariables - variable \"%s\" is of type %s but was provided a weight initialization scheme %s", varName, varType, weightInitScheme);
+    public SDVariable(@NonNull String varName, @NonNull VariableType varType, @NonNull SameDiff sameDiff, long[] shape, DataType dataType){
         Preconditions.checkState(dataType != DataType.UNKNOWN, "Unknown datatype is not allowed for SDVariables (variable name: %s)", varName);

         varName = sameDiff.generateNewVarName(varName, 0, true);
@@ -86,10 +80,25 @@ public class SDVariable implements Serializable {
         this.varName = varName;
         this.variableType = varType;
         this.dataType = dataType;
-        this.weightInitScheme = weightInitScheme;
         this.shape = shape;
     }

+    /**
+     * Get the name of the SDVariable
+     * @return Name of the variable
+     */
+    public String name(){
+        return varName;
+    }
+
+    /**
+     * @deprecated Use {@link #name()}
+     */
+    @Deprecated
+    public String getVarName(){
+        return name();
+    }
+
     /**
      * Returns true if this variable is a place holder
      * @return
@@ -102,30 +111,6 @@ public class SDVariable implements Serializable {
         return variableType == VariableType.CONSTANT;
     }

-    /**
-     * Allocate and return a new array
-     * based on the vertex id and weight initialization.
-     * @return the allocated array
-     */
-    public INDArray storeAndAllocateNewArray() {
-        Preconditions.checkState(variableType == VariableType.VARIABLE, "Unable to allocate and store array for variable of type %s: only" +
-                " VARIABLE type variables can be initialized using this method", variableType);
-
-        if(!sameDiff.arrayAlreadyExistsForVarName(varName)){
-            long[] shape = getShape();
-            INDArray arr = getWeightInitScheme().create(dataType(), shape);
-            sameDiff.associateArrayWithVariable(arr, this);
-            if(log.isTraceEnabled()){
-                log.trace("Generated and stored new array for variable \"{}\": shape {}", getVarName(), Arrays.toString(arr.shape()));
-            }
-            return arr;
-        }
-
-        //Variable type SDVariables: shape should never change (i.e., these are params in the net!)
-        INDArray ret = getArr();
-        return ret;
-    }
-
     /**
      * A getter for the allocated ndarray with this {@link SDVariable}.
      *
@@ -155,30 +140,14 @@ public class SDVariable implements Serializable {
     public INDArray getArr(boolean enforceExistence){
         if(sameDiff.arrayAlreadyExistsForVarName(getVarName()))
             return sameDiff.getArrForVarName(getVarName());

         if(variableType == VariableType.ARRAY){
             throw new UnsupportedOperationException("Cannot get array for ARRAY type SDVariable - use SDVariable.exec or SameDiff.output instead");
         }

-        //initialize value if it's actually a scalar constant (zero or 1 typically...)
-        if(variableType == VariableType.VARIABLE && weightInitScheme != null && shape != null){
-            INDArray arr = weightInitScheme.create(dataType, shape);
-            sameDiff.associateArrayWithVariable(arr, this);
-            if(log.isTraceEnabled()){
-                log.trace("getArr() for variable \"{}\" allocated new array: shape {}", getVarName(), Arrays.toString(getShape()));
-            }
-            return arr;
-        } else if(sameDiff.getShapeForVarName(getVarName()) == null) {
-            if (enforceExistence) {
-                throw new IllegalStateException("Cannot get array for SDVariable \"" + getVarName() + "\": no array has" +
-                        " been defined, and array shape cannot be calculated");
-            }
-            if(log.isTraceEnabled()){
-                log.trace("SDVariable.getArr(): could not get array for variable {}: shape is null", getVarName());
-            }
-            return null;
-        }
-        return sameDiff.getArrForVarName(getVarName());
+        INDArray ret = sameDiff.getArrForVarName(getVarName());
+        if(enforceExistence && ret == null){
+            throw new IllegalStateException("No array exists for variable \"" + name() + "\"");
+        }
+        return ret;
     }
@@ -215,21 +184,13 @@ public class SDVariable implements Serializable {
      * @return Shape of the variable
      */
     public long[] getShape() {
-        if (variableType == VariableType.PLACEHOLDER && getArr() == null) {
-            if (shape != null)
-                return shape;
-            else
-                return new long[0];
-        }
-
-        long[] initialShape = sameDiff.getShapeForVarName(getVarName());
-        if(initialShape == null && variableType != VariableType.ARRAY) {
-            val arr = getArr();
-            if(arr != null)
-                return arr.shape();
-        }
-
-        return initialShape;
+        if (variableType == VariableType.PLACEHOLDER ) {
+            return shape;
+        } else if(variableType == VariableType.VARIABLE || variableType == VariableType.CONSTANT){
+            return getArr().shape();
+        }
+
+        return null;
     }

     public void setShape(long... shape){
@@ -1488,8 +1449,8 @@ public class SDVariable implements Serializable {
      * @return
      */
     public INDArray eval() {
-        sameDiff.exec(null, getVarName());
-        return getArr();
+        Map<String,INDArray> m = sameDiff.output((Map<String,INDArray>)null, name());
+        return m.get(name());
     }
@@ -1498,8 +1459,8 @@ public class SDVariable implements Serializable {
      * @return
      */
     public INDArray eval(Map<String, INDArray> placeholders) {
-        sameDiff.exec(placeholders, getVarName());
-        return getArr();
+        Map<String,INDArray> m = sameDiff.output(placeholders, name());
+        return m.get(name());
     }
@@ -1519,7 +1480,7 @@ public class SDVariable implements Serializable {
      */
     public void addControlDependency(SDVariable controlDependency){
         Variable vThis = sameDiff.getVariables().get(getVarName());
-        Variable vCD = sameDiff.getVariables().get(controlDependency.getVarName());
+        Variable vCD = sameDiff.getVariables().get(controlDependency.name());

         //If possible: add control dependency on ops
         if(vThis.getOutputOfOp() != null && vCD.getOutputOfOp() != null ){
@@ -1729,7 +1690,6 @@ public class SDVariable implements Serializable {
         SDVariable v = new SDVariable();
         v.varName = varName;
         v.variableType = variableType;
-        v.weightInitScheme = weightInitScheme;
        v.shape = shape == null ? null : shape.clone();
        v.dataType = dataType;
        v.sameDiff = sd;
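In summary, a sketch of the renamed accessor and the reworked eval() behaviour (myVar is illustrative, not from the commit):

    // name() is the accessor going forward; getVarName() remains as a deprecated alias:
    String n = myVar.name();

    // eval() now routes through SameDiff.output(...) and returns the value from the
    // result map, rather than executing and then reading back via getArr():
    INDArray value = myVar.eval();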

View File

@@ -440,7 +440,7 @@ public class TrainingConfig {
          * @param evaluations The evaluations to run
          */
         public Builder trainEvaluation(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations){
-            return trainEvaluation(variable.getVarName(), labelIndex, evaluations);
+            return trainEvaluation(variable.name(), labelIndex, evaluations);
         }

         /**
@@ -468,7 +468,7 @@ public class TrainingConfig {
          * @param evaluations The evaluations to run
          */
         public Builder validationEvaluation(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations){
-            return validationEvaluation(variable.getVarName(), labelIndex, evaluations);
+            return validationEvaluation(variable.name(), labelIndex, evaluations);
         }

         /**
/** /**

View File

@@ -73,7 +73,7 @@ public class BatchOutputConfig {
     public BatchOutputConfig output(@NonNull SDVariable... outputs){
         String[] outNames = new String[outputs.length];
         for(int i = 0 ; i < outputs.length ; i++){
-            outNames[i] = outputs[i].getVarName();
+            outNames[i] = outputs[i].name();
         }

         return output(outNames);
@@ -104,7 +104,7 @@ public class BatchOutputConfig {
      * See {@link #input(String, INDArray)}
      */
     public BatchOutputConfig input(@NonNull SDVariable variable, @NonNull INDArray placeholder){
-        return input(variable.getVarName(), placeholder);
+        return input(variable.name(), placeholder);
     }

     /**
@@ -132,19 +132,35 @@ public class BatchOutputConfig {
         return this;
     }

+    /**
+     * @deprecated Use {@link #output()}
+     */
+    @Deprecated
+    public Map<String, INDArray> exec() {
+        return output();
+    }
+
     /**
      * Do inference and return the results
      */
-    public Map<String, INDArray> exec(){
+    public Map<String,INDArray> output(){
         return sd.output(placeholders, listeners, outputs.toArray(new String[0]));
     }

+    /**
+     * @deprecated Use {@link #outputSingle()}
+     */
+    @Deprecated
+    public INDArray execSingle() {
+        return outputSingle();
+    }
+
     /**
      * Do inference and return the results for the single output
      *
      * Only works if exactly one output is specified
      */
-    public INDArray execSingle(){
+    public INDArray outputSingle(){
         Preconditions.checkState(outputs.size() == 1,
                 "Can only use execSingle() when exactly one output is specified, there were %s", outputs.size());
         return exec().get(outputs.get(0));
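Callers migrate by renaming the terminal calls; a sketch, assuming SameDiff#batchOutput() as the entry point (variable names illustrative):

    Map<String,INDArray> results = sd.batchOutput()
            .input("input", features)      // placeholder name + array
            .output(prediction)            // SDVariable overload delegates via name()
            .output();                     // renamed from exec(), which is now deprecated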

View File

@@ -81,7 +81,7 @@ public class EvaluationConfig {
      * See {@link #evaluate(String, int, IEvaluation[])}
      */
     public EvaluationConfig evaluate(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations){
-        return evaluate(variable.getVarName(), labelIndex, evaluations);
+        return evaluate(variable.name(), labelIndex, evaluations);
     }

     /**
@@ -106,7 +106,7 @@ public class EvaluationConfig {
      * See {@link #evaluate(String, IEvaluation[])}
      */
     public EvaluationConfig evaluate(@NonNull SDVariable variable, @NonNull IEvaluation... evaluations){
-        return evaluate(variable.getVarName(), evaluations);
+        return evaluate(variable.name(), evaluations);
     }

     /**
@@ -129,7 +129,7 @@ public class EvaluationConfig {
      * See {@link #labelIndex(String, int)}
      */
     public EvaluationConfig labelIndex(@NonNull SDVariable variable, int labelIndex){
-        return labelIndex(variable.getVarName(), labelIndex);
+        return labelIndex(variable.name(), labelIndex);
     }

     /**
/** /**

View File

@@ -75,7 +75,7 @@ public class OutputConfig {
     public OutputConfig output(@NonNull SDVariable... outputs) {
         String[] outNames = new String[outputs.length];
         for (int i = 0; i < outputs.length; i++) {
-            outNames[i] = outputs[i].getVarName();
+            outNames[i] = outputs[i].name();
         }
         return output(outNames);

View File

@@ -204,10 +204,10 @@ public abstract class AbstractSession<T, O> {
                 VariableType vt = v.getVariableType();
                 if (vt == VariableType.VARIABLE || vt == VariableType.CONSTANT) {
                     ExecType et = vt == VariableType.VARIABLE ? ExecType.VARIABLE : ExecType.CONSTANT;
-                    ExecStep es = new ExecStep(et, v.getVarName(), new FrameIter(OUTER_FRAME, 0, null));
+                    ExecStep es = new ExecStep(et, v.name(), new FrameIter(OUTER_FRAME, 0, null));
                     dt.addDependency(es, start);

-                    Variable var = sameDiff.getVariables().get(v.getVarName());
+                    Variable var = sameDiff.getVariables().get(v.name());
                     if (var.getControlDeps() != null) {
                         addVarControlDeps(es, var);     //Before this variable can be considered available for use, we need specified op to be executed
                     }
@@ -668,11 +668,11 @@ public abstract class AbstractSession<T, O> {
         Variable v = sameDiff.getVariables().get(varName);
         VariableType vt = v.getVariable().getVariableType();
         if (vt == VariableType.VARIABLE) {
-            return new ExecStep(ExecType.VARIABLE, v.getVariable().getVarName(), new FrameIter(OUTER_FRAME, 0, null));
+            return new ExecStep(ExecType.VARIABLE, v.getVariable().name(), new FrameIter(OUTER_FRAME, 0, null));
         } else if (vt == VariableType.PLACEHOLDER) {
-            return new ExecStep(ExecType.PLACEHOLDER, v.getVariable().getVarName(), new FrameIter(OUTER_FRAME, 0, null));
+            return new ExecStep(ExecType.PLACEHOLDER, v.getVariable().name(), new FrameIter(OUTER_FRAME, 0, null));
         } else if (vt == VariableType.CONSTANT) {
-            return new ExecStep(ExecType.CONSTANT, v.getVariable().getVarName(), new FrameIter(OUTER_FRAME, 0, null));
+            return new ExecStep(ExecType.CONSTANT, v.getVariable().name(), new FrameIter(OUTER_FRAME, 0, null));
         } else {
             //Array type. Must be output of an op
             String outOfOp = v.getOutputOfOp();

View File

@@ -98,9 +98,9 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
         //TODO we shouldn't be clearing this on every single iteration, in 99.5% of cases variables will be same as last iteration...
         for (SDVariable v : sameDiff.variables()) {
             if (v.getVariableType() == VariableType.CONSTANT) {
-                arrayUseTracker.addDependency(v.getArr(), new ConstantDep(v.getVarName()));
+                arrayUseTracker.addDependency(v.getArr(), new ConstantDep(v.name()));
             } else if (v.getVariableType() == VariableType.VARIABLE) {
-                arrayUseTracker.addDependency(v.getArr(), new VariableDep(v.getVarName()));
+                arrayUseTracker.addDependency(v.getArr(), new VariableDep(v.name()));
             }
         }
@@ -484,7 +484,7 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
         if (op instanceof TensorArray) {
             //Create a TensorArray
-            VarId vid = outputFrameIter.toVarId(op.outputVariable().getVarName());
+            VarId vid = outputFrameIter.toVarId(op.outputVariable().name());
             Preconditions.checkState(!tensorArrays.containsKey(vid), "TensorArray already exists for %s when executing TensorArrayV3", vid);
             tensorArrays.put(vid, new ArrayList<INDArray>());
@@ -504,18 +504,18 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
             SDVariable inTensorArray = op.arg(0);   //Dummy variable representing the tensor array

             //Work out the frame/iteration:
-            VarId v = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false));
+            VarId v = (opInputs == null ? null : lookup(inTensorArray.name(), opInputs, false));
             if (v == null && allIterInputs != null) {
-                v = lookup(inTensorArray.getVarName(), allIterInputs, false);
+                v = lookup(inTensorArray.name(), allIterInputs, false);
             }

-            Preconditions.checkState(v != null, "Could not find input %s", inTensorArray.getVarName());
+            Preconditions.checkState(v != null, "Could not find input %s", inTensorArray.name());

-            while (sameDiff.getVariableOutputOp(inTensorArray.getVarName()) instanceof Enter) {
+            while (sameDiff.getVariableOutputOp(inTensorArray.name()) instanceof Enter) {
                 //Handle the Enter case: this is like TensorArray -> Enter -> TensorArrayRead
                 //TODO also TensorArrayWrite, scatter, etc??
-                inTensorArray = sameDiff.getVariableOutputOp(inTensorArray.getVarName()).arg();
-                v = v.getParentFrame().toVarId(inTensorArray.getVarName());
+                inTensorArray = sameDiff.getVariableOutputOp(inTensorArray.name()).arg();
+                v = v.getParentFrame().toVarId(inTensorArray.name());
             }

             List<INDArray> list = getTensorArrays().get(v);
@@ -528,31 +528,31 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
             //TensorArrayWrite - also has a scalar 0.0 that it returns...
             SDVariable inTensorArray = op.arg(0);   //Dummy variable representing the tensor array
             //Work out the varid (frame/iteration) of the tensor array:
-            VarId tArr = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false));
+            VarId tArr = (opInputs == null ? null : lookup(inTensorArray.name(), opInputs, false));
             if (tArr == null && allIterInputs != null) {
-                tArr = lookup(inTensorArray.getVarName(), allIterInputs, false);
+                tArr = lookup(inTensorArray.name(), allIterInputs, false);
             }

-            Preconditions.checkState(tArr != null, "Could not find input %s", inTensorArray.getVarName());
+            Preconditions.checkState(tArr != null, "Could not find input %s", inTensorArray.name());

-            while (sameDiff.getVariableOutputOp(inTensorArray.getVarName()) instanceof Enter) {
+            while (sameDiff.getVariableOutputOp(inTensorArray.name()) instanceof Enter) {
                 //Handle the Enter case: this is like TensorArray -> Enter -> TensorArrayWrite
                 //TODO also TensorArrayScatter, etc??
-                inTensorArray = sameDiff.getVariableOutputOp(inTensorArray.getVarName()).arg();
-                tArr = tArr.getParentFrame().toVarId(inTensorArray.getVarName());
+                inTensorArray = sameDiff.getVariableOutputOp(inTensorArray.name()).arg();
+                tArr = tArr.getParentFrame().toVarId(inTensorArray.name());
             }

             //Input 0 is the TensorArray (or dummy variable that represents it) - but sometimes Enter, in TensorArray -> Enter -> TensorARrayRead
             //Input 1 is the index
             //Input 2 is the value to write

-            String idxName = op.arg(1).getVarName();
+            String idxName = op.arg(1).name();
             SDVariable idxSDV = sameDiff.getVariable(idxName);
             INDArray idxArr = getArray(idxSDV, opInputs, allIterInputs);
             Preconditions.checkState(idxArr.isScalar(), "Index variable ID for TensorArrayWrite should be a scalar, got %ndShape", idxArr);
             int idx = idxArr.getInt(0);

-            String inName = op.arg(2).getVarName();
+            String inName = op.arg(2).name();
             SDVariable inSDV = sameDiff.getVariable(inName);
             INDArray arr = getArray(inSDV, opInputs, allIterInputs);
             Preconditions.checkState(arr != null, "Could not find array for %s", inName);
@@ -577,9 +577,9 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
             //Index 0 is the TensorArray (or dummy variable that represents it)
             SDVariable inTensorArray = op.arg(0);   //Dummy variable representing the tensor array
             //Work out the varid (frame/iteration) of the tensor array:
-            VarId tArr = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false));
+            VarId tArr = (opInputs == null ? null : lookup(inTensorArray.name(), opInputs, false));
             if (tArr == null && allIterInputs != null) {
-                tArr = lookup(inTensorArray.getVarName(), allIterInputs, false);
+                tArr = lookup(inTensorArray.name(), allIterInputs, false);
             }
             List<INDArray> l = tensorArrays.get(tArr);
             Preconditions.checkState(l != null, "Could not find TensorArray: %s", tArr);
@@ -588,9 +588,9 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
             return new INDArray[]{scalar};
         } else if (op instanceof TensorArrayConcat) {
             SDVariable inTensorArray = op.arg(0);   //Dummy variable representing the tensor array
-            VarId tArr = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false));
+            VarId tArr = (opInputs == null ? null : lookup(inTensorArray.name(), opInputs, false));
             if (tArr == null && allIterInputs != null) {
-                tArr = lookup(inTensorArray.getVarName(), allIterInputs, false);
+                tArr = lookup(inTensorArray.name(), allIterInputs, false);
             }
             List<INDArray> l = tensorArrays.get(tArr);
@@ -605,14 +605,14 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
             //Input 1: the indices (1d integer vector)

             SDVariable inTensorArray = op.arg(0);   //Dummy variable representing the tensor array
-            VarId tArr = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false));
+            VarId tArr = (opInputs == null ? null : lookup(inTensorArray.name(), opInputs, false));
             if (tArr == null && allIterInputs != null) {
-                tArr = lookup(inTensorArray.getVarName(), allIterInputs, false);
+                tArr = lookup(inTensorArray.name(), allIterInputs, false);
             }
             List<INDArray> l = tensorArrays.get(tArr);
             Preconditions.checkState(l != null, "Could not find TensorArray: %s", tArr);

-            String indicesName = op.arg(1).getVarName();
+            String indicesName = op.arg(1).name();
             SDVariable indicesSDV = sameDiff.getVariable(indicesName);
             INDArray idxArr = getArray(indicesSDV, opInputs, allIterInputs);
             Preconditions.checkState(idxArr.isVector(), "Indices variable for TensorArrayGather should be a vector, got %ndShape for %s", idxArr, indicesName);
@@ -644,22 +644,22 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
             //Input 2: The values to scatter

             SDVariable inTensorArray = op.arg(0);   //Dummy variable representing the tensor array
-            TensorArray ta = (TensorArray) sameDiff.getVariableOutputOp(inTensorArray.getVarName());
-            VarId tArr = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false));
+            TensorArray ta = (TensorArray) sameDiff.getVariableOutputOp(inTensorArray.name());
+            VarId tArr = (opInputs == null ? null : lookup(inTensorArray.name(), opInputs, false));
             if (tArr == null && allIterInputs != null) {
-                tArr = lookup(inTensorArray.getVarName(), allIterInputs, false);
+                tArr = lookup(inTensorArray.name(), allIterInputs, false);
             }
             List<INDArray> l = tensorArrays.get(tArr);
             Preconditions.checkState(l != null, "Could not find TensorArray: %s", tArr);

-            String indicesName = op.arg(1).getVarName();
+            String indicesName = op.arg(1).name();
             SDVariable indicesSDV = sameDiff.getVariable(indicesName);
             INDArray idxArr = getArray(indicesSDV, opInputs, allIterInputs);
             Preconditions.checkState(idxArr.isVector(), "Indices variable for TensorArrayScatter should be a vector, got %ndShape for %s", idxArr, indicesName);
             Preconditions.checkState(idxArr.dataType().isIntType(), "Indices variable for TensorArrayScatter should be an integer type, got %s for array %s", idxArr.dataType(), indicesName);
             int[] idxs = idxArr.toIntVector();

-            String valuesName = op.arg(2).getVarName();
+            String valuesName = op.arg(2).name();
             SDVariable valuesSDV = sameDiff.getVariable(valuesName);
             INDArray valuesArr = getArray(valuesSDV, opInputs, allIterInputs);
@@ -697,18 +697,18 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
             //Input 2: the size of each split (1d integer vector)

             SDVariable inTensorArray = op.arg(0);   //Dummy variable representing the tensor array
-            VarId tArr = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false));
+            VarId tArr = (opInputs == null ? null : lookup(inTensorArray.name(), opInputs, false));
             if (tArr == null && allIterInputs != null) {
-                tArr = lookup(inTensorArray.getVarName(), allIterInputs, false);
+                tArr = lookup(inTensorArray.name(), allIterInputs, false);
} }
List<INDArray> l = tensorArrays.get(tArr); List<INDArray> l = tensorArrays.get(tArr);
Preconditions.checkState(l != null, "Could not find TensorArray: %s", tArr); Preconditions.checkState(l != null, "Could not find TensorArray: %s", tArr);
String splitName = op.arg(1).getVarName(); String splitName = op.arg(1).name();
INDArray splitArr = getArray(sameDiff.getVariable(splitName), opInputs, allIterInputs); INDArray splitArr = getArray(sameDiff.getVariable(splitName), opInputs, allIterInputs);
String sizeName = op.arg(2).getVarName(); String sizeName = op.arg(2).name();
SDVariable sizeSDV = sameDiff.getVariable(sizeName); SDVariable sizeSDV = sameDiff.getVariable(sizeName);
INDArray sizeArr = getArray(sizeSDV, opInputs, allIterInputs); INDArray sizeArr = getArray(sizeSDV, opInputs, allIterInputs);
Preconditions.checkState(sizeArr.isVector(), "Indices variable for TensorArraySplit should be a vector, got %ndShape for %s", sizeArr, sizeName); Preconditions.checkState(sizeArr.isVector(), "Indices variable for TensorArraySplit should be a vector, got %ndShape for %s", sizeArr, sizeName);
@ -803,7 +803,7 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
VarId vid = lookup(s, opInputs, allIterInputs, true); VarId vid = lookup(s, opInputs, allIterInputs, true);
args[i] = nodeOutputs.get(vid); args[i] = nodeOutputs.get(vid);
} }
Preconditions.checkNotNull(args[i], "Could not parameterize op %s: array %s (variable %s) is null", opName, i, v.getVarName()); Preconditions.checkNotNull(args[i], "Could not parameterize op %s: array %s (variable %s) is null", opName, i, v.name());
i++; i++;
} }
} }
@ -825,7 +825,6 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
return sdo; return sdo;
} }
df.resolvePropertiesFromSameDiffBeforeExecution(); //TODO This is to be removed
List<LongShapeDescriptor> outShape = customOp.calculateOutputShape(); List<LongShapeDescriptor> outShape = customOp.calculateOutputShape();
Preconditions.checkState(outShape != null && outShape.size() > 0, "Failed to calculate output shapes for op %s (%s) - no shapes were returned by calculateOutputShape()", customOp.opName(), customOp.getOwnName()); Preconditions.checkState(outShape != null && outShape.size() > 0, "Failed to calculate output shapes for op %s (%s) - no shapes were returned by calculateOutputShape()", customOp.opName(), customOp.getOwnName());
String[] outNames = df.outputVariablesNames(); String[] outNames = df.outputVariablesNames();
@ -918,7 +917,6 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
op.setZ(z); op.setZ(z);
} }
} }
df.resolvePropertiesFromSameDiffBeforeExecution();
} }
return sdo; return sdo;
@ -926,12 +924,12 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
protected INDArray getArray(SDVariable sdv, Collection<VarId> opInputs, Collection<VarId> allIterInputs) { protected INDArray getArray(SDVariable sdv, Collection<VarId> opInputs, Collection<VarId> allIterInputs) {
String n = sdv.getVarName(); String n = sdv.name();
if (sdv.getVariableType() == VariableType.CONSTANT || sdv.getVariableType() == VariableType.VARIABLE) { if (sdv.getVariableType() == VariableType.CONSTANT || sdv.getVariableType() == VariableType.VARIABLE) {
return getConstantOrVariable(n); return getConstantOrVariable(n);
} else { } else {
VarId inVarId = lookup(n, opInputs, allIterInputs, false); VarId inVarId = lookup(n, opInputs, allIterInputs, false);
Preconditions.checkState(inVarId != null, "Could not find array for variable %s", sdv.getVarName()); Preconditions.checkState(inVarId != null, "Could not find array for variable %s", sdv.name());
return nodeOutputs.get(inVarId); return nodeOutputs.get(inVarId);
} }
} }
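
Note: as the lookups above show, InferenceSession tracks each TensorArray as a plain List<INDArray> keyed by the frame/iteration VarId of the variable that represents it. A rough standalone sketch of the gather step over such a list (illustrative only, not the actual op implementation; Nd4j.stack along dimension 0 is assumed as the combining step):

    import java.util.ArrayList;
    import java.util.List;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class TensorArrayGatherSketch {
        public static void main(String[] args) {
            // The TensorArray, modelled as a plain list of entries
            List<INDArray> ta = new ArrayList<>();
            ta.add(Nd4j.createFromArray(1.0, 2.0));
            ta.add(Nd4j.createFromArray(3.0, 4.0));
            ta.add(Nd4j.createFromArray(5.0, 6.0));

            // Gather entries 0 and 2, then stack along a new leading dimension
            int[] idxs = {0, 2};
            INDArray[] selected = new INDArray[idxs.length];
            for (int i = 0; i < idxs.length; i++) {
                selected[i] = ta.get(idxs[i]);
            }
            INDArray gathered = Nd4j.stack(0, selected);   // shape [2, 2]
            System.out.println(gathered);
        }
    }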

@@ -88,9 +88,9 @@ public class TrainingSession extends InferenceSession {
 continue;
 }
-requiredActivations.add(grad.getVarName());
-gradVarToVarMap.put(grad.getVarName(), s);
+requiredActivations.add(grad.name());
+gradVarToVarMap.put(grad.name(), s);
 }
 //Set up losses

@@ -3266,7 +3266,7 @@ public abstract class SDBaseOps {
 if (cond_result.dataType() != DataType.BOOL)
-throw new IllegalStateException("Can not use " + cond_result.getVarName() + " as the condition of an While loop, the condition must be a boolean.");
+throw new IllegalStateException("Can not use " + cond_result.name() + " as the condition of an While loop, the condition must be a boolean.");
 final Set<String> alreadyEntered = Sets.newHashSet();
@@ -3275,7 +3275,7 @@
 for(int i = 0 ; i < loopVars.length ; i++){
 SDVariable[] s = f().switchOp(merged[i], cond_result);
 trueSwitches[i] = s[1];
-alreadyEntered.add(s[1].getVarName());
+alreadyEntered.add(s[1].name());
 exits[i] = f().exit(s[0]);
 }
@@ -3290,17 +3290,17 @@
 @Override
 public SDVariable intercept(SDVariable argument) {
-if(!declared.contains(argument.getVarName()))
+if(!declared.contains(argument.name()))
 return argument;
-if(alreadyEntered.contains(argument.getVarName()))
+if(alreadyEntered.contains(argument.name()))
 return argument;
-if(done.containsKey(argument.getVarName()))
-return done.get(argument.getVarName());
+if(done.containsKey(argument.name()))
+return done.get(argument.name());
 SDVariable e = f().enter(argument, frameName, true);
-done.put(argument.getVarName(), e);
+done.put(argument.name(), e);
 return e;
 }
 });
@@ -3371,7 +3371,7 @@
 //cleanup partially added block
 for(SDVariable v : sd().getVariablesInScope(ifScope))
-sd().getVariables().remove(v.getVarName());
+sd().getVariables().remove(v.name());
 for(SameDiffOp op : sd().getOpsInScope(ifScope)) {
 for(String in : op.getInputsToOp()){
@@ -3381,7 +3381,7 @@
 }
-throw new IllegalStateException("Can not use " + pred.getVarName()
+throw new IllegalStateException("Can not use " + pred.name()
 + " as the condition of an If statement, the condition must be a boolean.");
 }
@@ -3394,15 +3394,15 @@
 public SDVariable intercept(SDVariable argument) {
 // if its declared in the if, we don't care acout it
-if(!declared.contains(argument.getVarName()))
+if(!declared.contains(argument.name()))
 return argument;
 // if we've already added a switch, move on
-if(switches.containsKey(argument.getVarName()))
-return switches.get(argument.getVarName())[1];
+if(switches.containsKey(argument.name()))
+return switches.get(argument.name())[1];
 SDVariable[] s = f().switchOp(argument, pred);
-switches.put(argument.getVarName(), s);
+switches.put(argument.name(), s);
 return s[1];
 }
 });
@@ -3410,9 +3410,9 @@
 SDVariable trueOut = trueBody.define(sd());
 sd().removeArgumentInterceptor();
-if(declared.contains(trueOut.getVarName())) {
+if(declared.contains(trueOut.name())) {
 SDVariable[] s = f().switchOp(trueOut, pred);
-switches.put(trueOut.getVarName(), s);
+switches.put(trueOut.name(), s);
 trueOut = s[1];
 }
@@ -3424,15 +3424,15 @@
 public SDVariable intercept(SDVariable argument) {
 // if its declared in the if, we don't care acout it
-if(!declared2.contains(argument.getVarName()))
+if(!declared2.contains(argument.name()))
 return argument;
 // if we've already added a switch, move on
-if(switches.containsKey(argument.getVarName()))
-return switches.get(argument.getVarName())[0];
+if(switches.containsKey(argument.name()))
+return switches.get(argument.name())[0];
 SDVariable[] s = f().switchOp(argument, pred);
-switches.put(argument.getVarName(), s);
+switches.put(argument.name(), s);
 return s[0];
 }
 });
@@ -3440,9 +3440,9 @@
 SDVariable falseOut = falseBody.define(sd());
 sd().removeArgumentInterceptor();
-if(declared2.contains(falseOut.getVarName())) {
+if(declared2.contains(falseOut.name())) {
 SDVariable[] s = f().switchOp(falseOut, pred);
-switches.put(falseOut.getVarName(), s);
+switches.put(falseOut.name(), s);
 falseOut = s[0];
 }
 falseScope.close();
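
Note: the while/if lowering above follows the TensorFlow control-flow scheme: every outer-scope variable referenced inside a branch is routed through a switch op, so only the taken branch sees a value, and results are joined back via merge/exit. A minimal value-level model of the switch routing (illustrative only; the real op operates on graph variables, and the s[0]/s[1] indexing matches the false/true branches used above):

    public class SwitchSketch {
        // out[0] carries the value when pred is false, out[1] when pred is true
        static Object[] switchOp(Object value, boolean pred) {
            return pred ? new Object[]{null, value} : new Object[]{value, null};
        }

        public static void main(String[] args) {
            Object[] s = switchOp("x", true);
            System.out.println("false branch: " + s[0] + ", true branch: " + s[1]);
        }
    }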

@@ -37,7 +37,7 @@ public class SDValidation {
 if (v == null)
 return;
 if (v.dataType() == DataType.BOOL || v.dataType() == DataType.UTF8)
-throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to variable \"" + v.getVarName() + "\" with non-numerical data type " + v.dataType());
+throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to variable \"" + v.name() + "\" with non-numerical data type " + v.dataType());
 }
 /**
@@ -52,7 +52,7 @@
 return;
 if (v.dataType() == DataType.BOOL || v.dataType() == DataType.UTF8)
 throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be an numerical type type; got variable \"" +
-v.getVarName() + "\" with non-integer data type " + v.dataType());
+v.name() + "\" with non-integer data type " + v.dataType());
 }
 /**
@@ -65,8 +65,8 @@
 */
 protected static void validateNumerical(String opName, SDVariable v1, SDVariable v2) {
 if (v1.dataType() == DataType.BOOL || v1.dataType() == DataType.UTF8 || v2.dataType() == DataType.BOOL || v2.dataType() == DataType.UTF8)
-throw new IllegalStateException("Cannot perform operation \"" + opName + "\" on variables \"" + v1.getVarName() + "\" and \"" +
-v2.getVarName() + "\" if one or both variables are non-numerical: " + v1.dataType() + " and " + v2.dataType());
+throw new IllegalStateException("Cannot perform operation \"" + opName + "\" on variables \"" + v1.name() + "\" and \"" +
+v2.name() + "\" if one or both variables are non-numerical: " + v1.dataType() + " and " + v2.dataType());
 }
 /**
@@ -79,7 +79,7 @@
 if (v == null)
 return;
 if (!v.dataType().isIntType())
-throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to variable \"" + v.getVarName() + "\" with non-integer data type " + v.dataType());
+throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to variable \"" + v.name() + "\" with non-integer data type " + v.dataType());
 }
 /**
@@ -94,7 +94,7 @@
 return;
 if (!v.dataType().isIntType())
 throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be an integer type; got variable \"" +
-v.getVarName() + "\" with non-integer data type " + v.dataType());
+v.name() + "\" with non-integer data type " + v.dataType());
 }
 /**
@@ -107,7 +107,7 @@
 if (v == null)
 return;
 if (!v.dataType().isFPType())
-throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to variable \"" + v.getVarName() + "\" with non-floating point data type " + v.dataType());
+throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to variable \"" + v.name() + "\" with non-floating point data type " + v.dataType());
 }
 /**
@@ -122,7 +122,7 @@
 return;
 if (!v.dataType().isFPType())
 throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be an floating point type; got variable \"" +
-v.getVarName() + "\" with non-floating point data type " + v.dataType());
+v.name() + "\" with non-floating point data type " + v.dataType());
 }
 /**
@@ -135,7 +135,7 @@
 if (v == null)
 return;
 if (v.dataType() != DataType.BOOL)
-throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to variable \"" + v.getVarName() + "\" with non-boolean point data type " + v.dataType());
+throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to variable \"" + v.name() + "\" with non-boolean point data type " + v.dataType());
 }
 /**
@@ -150,7 +150,7 @@
 return;
 if (v.dataType() != DataType.BOOL)
 throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be an boolean variable; got variable \"" +
-v.getVarName() + "\" with non-boolean data type " + v.dataType());
+v.name() + "\" with non-boolean data type " + v.dataType());
 }
 /**
@@ -162,8 +162,8 @@
 */
 protected static void validateBool(String opName, SDVariable v1, SDVariable v2) {
 if (v1.dataType() != DataType.BOOL || v2.dataType() != DataType.BOOL)
-throw new IllegalStateException("Cannot perform operation \"" + opName + "\" on variables \"" + v1.getVarName() + "\" and \"" +
-v2.getVarName() + "\" if one or both variables are non-boolean: " + v1.dataType() + " and " + v2.dataType());
+throw new IllegalStateException("Cannot perform operation \"" + opName + "\" on variables \"" + v1.name() + "\" and \"" +
+v2.name() + "\" if one or both variables are non-boolean: " + v1.dataType() + " and " + v2.dataType());
 }
 /**
@@ -190,7 +190,7 @@
 String[] names = new String[vars.length];
 DataType[] dtypes = new DataType[vars.length];
 for (int j = 0; j < vars.length; j++) {
-names[j] = vars[j].getVarName();
+names[j] = vars[j].name();
 dtypes[j] = vars[j].dataType();
 }
 throw new IllegalStateException("Cannot perform operation \"" + opName + "\" to variables with different datatypes:" +
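
Note: these validators are what surface datatype errors at graph construction time. A hypothetical usage sketch (assuming the standard SameDiff builder API):

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.buffer.DataType;

    public class ValidationExample {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            SDVariable b = sd.var("b", DataType.BOOL, 2, 2);
            try {
                sd.math().abs(b);   // numerical op applied to a BOOL variable
            } catch (IllegalStateException e) {
                // Expected along the lines of:
                // Cannot apply operation "abs" to variable "b" with non-numerical data type BOOL
                System.out.println(e.getMessage());
            }
        }
    }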

@@ -763,7 +763,7 @@ public class FlatBuffersMapper {
 SDVariable[] inputs = node.args();
 for (SDVariable input : inputs) {
-String varName = input.getVarName();
+String varName = input.name();
 int outIdx;
 if (sameDiff.getVariables().get(varName).getOutputOfOp() != null) {
 DifferentialFunction df = sameDiff.getOps().get(sameDiff.getVariables().get(varName).getOutputOfOp()).getOp();

@@ -69,8 +69,8 @@ public class GraphTransformUtil {
 // we want to end up with (x -> A -> z)
 List<DifferentialFunction> allSubGraphFns = sg.allFunctionsInSubgraph();
 for (int i = 0; i < oldOutputs.size(); i++) {
-String oldOutVarName = oldOutputs.get(i).getVarName();
-String newOutVarName = newOutputs.get(i).getVarName();
+String oldOutVarName = oldOutputs.get(i).name();
+String newOutVarName = newOutputs.get(i).name();
 Preconditions.checkState(!oldOutVarName.equals(newOutVarName), "Reusing old variables not yet implemented");
 //Update inputs for ops: if X->opA, and now Y->opA, then X.inputsForOps contains "opA"; Y.inputsForOps should be updated
@@ -133,7 +133,7 @@
 //Step 2: Update input variables: if X -> (subgraph) exists, then X.inputsForOp needs to be updated
 List<SDVariable> inputs = sg.inputs();
 for (SDVariable v : inputs) {
-Variable var = sd.getVariables().get(v.getVarName());
+Variable var = sd.getVariables().get(v.name());
 if (var.getInputsForOp() != null) {
 List<String> newInputsForOp = new ArrayList<>(var.getInputsForOp());
 for (String opName : var.getInputsForOp()) {
@@ -160,7 +160,7 @@
 SDVariable[] outputs = df.outputVariables();
 if (outputs != null) {
 for (SDVariable v : outputs) {
-vars.remove(v.getVarName());
+vars.remove(v.name());
 }
 }
 }

@@ -62,7 +62,7 @@ public class SubGraph {
 //But suppose same subgraph, but connection y -> a exists; then Y must be an output, because it's used somewhere else
 List<SDVariable> filteredOutputs = new ArrayList<>(allOutputs.size());
 for(SDVariable v : allOutputs){
-Variable var = sameDiff.getVariables().get(v.getVarName());
+Variable var = sameDiff.getVariables().get(v.name());
 List<String> inputsFor = var.getInputsForOp();
 boolean allInSubgraph = true;
 if(inputsFor != null){

@@ -77,7 +77,7 @@ public class SubGraphPredicate extends OpPredicate {
 }
 SDVariable in = inputs[inNum];
-DifferentialFunction df = sameDiff.getVariableOutputOp(in.getVarName());
+DifferentialFunction df = sameDiff.getVariableOutputOp(in.name());
 if (df == null || !e.getValue().matches(sameDiff, df)) {
 return false;
 }
@@ -103,7 +103,7 @@
 for(Map.Entry<Integer,OpPredicate> entry : opInputSubgraphPredicates.entrySet()){
 OpPredicate p2 = entry.getValue();
 SDVariable arg = rootFn.arg(entry.getKey());
-DifferentialFunction df = sd.getVariableOutputOp(arg.getVarName());
+DifferentialFunction df = sd.getVariableOutputOp(arg.name());
 if(df != null){
 childNodes.add(df);

@@ -107,7 +107,7 @@ public class GradCheckUtil {
 Set<String> fnOutputs = new HashSet<>();
 for(DifferentialFunction f : sd.ops()){
 for(SDVariable s : f.outputVariables()){
-fnOutputs.add(s.getVarName());
+fnOutputs.add(s.name());
 }
 }
@@ -171,7 +171,7 @@
 Map<String,INDArray> grad = new HashMap<>();
 for(SDVariable v : sd.variables()){
-if (fnOutputs.contains(v.getVarName())) {
+if (fnOutputs.contains(v.name())) {
 //This is not an input to the graph
 continue;
 }
@@ -179,20 +179,20 @@
 //Skip non-fp variables, or variables that don't impact loss function value
 continue;
 }
-SDVariable g = sd.grad(v.getVarName());
+SDVariable g = sd.grad(v.name());
 if(g == null){
-throw new IllegalStateException("Null gradient variable for \"" + v.getVarName() + "\"");
+throw new IllegalStateException("Null gradient variable for \"" + v.name() + "\"");
 }
-INDArray ga = gm.get(v.getVarName());
+INDArray ga = gm.get(v.name());
 if(ga == null){
-throw new IllegalStateException("Null gradient array encountered for variable: " + v.getVarName());
+throw new IllegalStateException("Null gradient array encountered for variable: " + v.name());
 }
 if(!Arrays.equals(v.getArr().shape(), ga.shape())){
 throw new IllegalStateException("Gradient shape does not match variable shape for variable \"" +
-v.getVarName() + "\": shape " + Arrays.toString(v.getArr().shape()) + " vs. gradient shape " +
+v.name() + "\": shape " + Arrays.toString(v.getArr().shape()) + " vs. gradient shape " +
 Arrays.toString(ga.shape()));
 }
-grad.put(v.getVarName(), ga.dup());
+grad.put(v.name(), ga.dup());
 }
 //Validate gradients for each variable:
@@ -201,25 +201,25 @@
 double maxError = 0.0;
 Random r = new Random(12345);
 for(SDVariable s : sd.variables()){
-if (fnOutputs.contains(s.getVarName()) || !s.dataType().isFPType()) {
+if (fnOutputs.contains(s.name()) || !s.dataType().isFPType()) {
 //This is not an input to the graph, or is not a floating point input (so can't be gradient checked)
 continue;
 }
-if(skipVariables != null && skipVariables.contains(s.getVarName())){
-log.info("Grad check: skipping variable \"{}\"", s.getVarName());
+if(skipVariables != null && skipVariables.contains(s.name())){
+log.info("Grad check: skipping variable \"{}\"", s.name());
 continue;
 }
 if(s.dataType() != DataType.DOUBLE){
-log.warn("DataType for variable {} is not double (is: {}) may cause precision issues in gradient checks", s.getVarName(), s.dataType());
+log.warn("DataType for variable {} is not double (is: {}) may cause precision issues in gradient checks", s.name(), s.dataType());
 }
-String name = s.getVarName();
+String name = s.name();
 INDArray a = s.getArr();
 long n = a.length();
 if(print){
-log.info("Starting test for variable \"{}\" with {} values", s.getVarName(), n);
+log.info("Starting test for variable \"{}\" with {} values", s.name(), n);
 }
 Iterator<long[]> iter;
@@ -256,11 +256,11 @@
 iter = new NdIndexIterator('c',a.shape());
 }
-INDArray varMask = (gradCheckMask == null ? null : gradCheckMask.get(s.getVarName()));
+INDArray varMask = (gradCheckMask == null ? null : gradCheckMask.get(s.name()));
 if(varMask != null){
-Preconditions.checkState(a.equalShapes(varMask), "Variable \"%s\": Gradient check mask and array shapes must be equal: got %s vs. mask shape %s", s.getVarName(), a.shape(), varMask.shape());
-Preconditions.checkState(varMask.dataType() == DataType.BOOL, "Variable \"%s\": Gradient check mask must be BOOLEAN datatype, got %s", s.getVarName(), varMask.dataType());
+Preconditions.checkState(a.equalShapes(varMask), "Variable \"%s\": Gradient check mask and array shapes must be equal: got %s vs. mask shape %s", s.name(), a.shape(), varMask.shape());
+Preconditions.checkState(varMask.dataType() == DataType.BOOL, "Variable \"%s\": Gradient check mask must be BOOLEAN datatype, got %s", s.name(), varMask.dataType());
 }
 int i=0;
@@ -281,12 +281,12 @@
 double orig = a.getDouble(idx);
 a.putScalar(idx, orig+eps);
 double scorePlus = 0.0;
-Map<String,INDArray> m = sd.exec(placeholderValues, lossFnVariables);//.get(outName).sumNumber().doubleValue();
+Map<String,INDArray> m = sd.output(placeholderValues, lossFnVariables);//.get(outName).sumNumber().doubleValue();
 for(INDArray arr : m.values()){
 scorePlus += arr.sumNumber().doubleValue();
 }
 a.putScalar(idx, orig-eps);
-m = sd.exec(placeholderValues, lossFnVariables);
+m = sd.output(placeholderValues, lossFnVariables);
 double scoreMinus = 0.0;
 for(INDArray arr : m.values()){
 scoreMinus += arr.sumNumber().doubleValue();
@@ -294,9 +294,9 @@
 a.putScalar(idx, orig);
 double numericalGrad = (scorePlus - scoreMinus) / (2 * eps);
-INDArray aGrad = grad.get(s.getVarName());
+INDArray aGrad = grad.get(s.name());
 if(aGrad == null){
-log.warn("No gradient array for variable \"{}\" was found, skipping variable...", s.getVarName());
+log.warn("No gradient array for variable \"{}\" was found, skipping variable...", s.name());
 continue;
 }
 double analyticGrad = aGrad.getDouble(idx);
@@ -497,12 +497,12 @@
 listener.setIdx(idx);
 listener.setEps(config.getEps());
 double scorePlus = 0.0;
-Map<String,INDArray> m = sd.exec(config.getPlaceholderValues(), lossFnVariables);
+Map<String,INDArray> m = sd.output(config.getPlaceholderValues(), lossFnVariables);
 for(INDArray arr : m.values()){
 scorePlus += arr.sumNumber().doubleValue();
 }
 listener.setEps(-config.getEps());
-m = sd.exec(config.getPlaceholderValues(), lossFnVariables);
+m = sd.output(config.getPlaceholderValues(), lossFnVariables);
 double scoreMinus = 0.0;
 for(INDArray arr : m.values()){
 scoreMinus += arr.sumNumber().doubleValue();
@@ -597,10 +597,10 @@
 Set<String> varSetStr = new HashSet<>();
 for(SDVariable v : vars){
-if(varSetStr.contains(v.getVarName())){
-throw new IllegalStateException("Variable with name " + v.getVarName() + " already encountered");
+if(varSetStr.contains(v.name())){
+throw new IllegalStateException("Variable with name " + v.name() + " already encountered");
 }
-varSetStr.add(v.getVarName());
+varSetStr.add(v.name());
 }
 Preconditions.checkState(vars.size() == varSetStr.size(), "Duplicate variables in variables() list");
@@ -645,7 +645,7 @@
 Map<String, Variable> variableMap = sd.getVariables();
 Preconditions.checkState(vars.size() == variableMap.size(), "Variable map size check failed");
 for(Map.Entry<String, Variable> e : variableMap.entrySet()){
-Preconditions.checkState(e.getKey().equals(e.getValue().getVariable().getVarName()), "Name not equal");
+Preconditions.checkState(e.getKey().equals(e.getValue().getVariable().name()), "Name not equal");
 }
 if(generateAndCheckGradFn) {
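
Note: the scorePlus/scoreMinus bookkeeping above implements a standard central-difference check: each parameter is perturbed by ±eps and the numerical gradient (f(x+eps) - f(x-eps)) / (2*eps) is compared against the analytic one. The same idea in a self-contained sketch (hypothetical helper, not part of GradCheckUtil):

    import java.util.function.DoubleUnaryOperator;

    public class CentralDifferenceSketch {
        // Numerical derivative of f at x via central differences
        static double numericalGrad(DoubleUnaryOperator f, double x, double eps) {
            return (f.applyAsDouble(x + eps) - f.applyAsDouble(x - eps)) / (2 * eps);
        }

        public static void main(String[] args) {
            DoubleUnaryOperator f = v -> v * v * v;   // f(x) = x^3, so f'(x) = 3x^2
            double analytic = 3 * 2.0 * 2.0;          // 12.0 at x = 2
            double numerical = numericalGrad(f, 2.0, 1e-5);
            System.out.printf("analytic=%.6f numerical=%.6f%n", analytic, numerical);
        }
    }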

@@ -208,7 +208,7 @@ public class OpValidation {
 e.getKey() + "\" but SameDiff instance does not have a variable for this name" + testCase.testNameErrMsg());
 }
-INDArray actual = out.get(v.getVarName());
+INDArray actual = out.get(v.name());
 if (actual == null) {
 throw new IllegalStateException("Null INDArray after forward pass for variable \"" + e.getKey() + "\"");
 }
@@ -271,8 +271,8 @@
 for( int i=0; i<vars.size(); i++ ){
 SDVariable vO = vars.get(i);
 SDVariable vD = varsDe.get(i);
-Preconditions.checkState(vO.getVarName().equals(vD.getVarName()), "Names should be equal for variable %s: expected %s vs %s",
-i, vO.getVarName(), vD.getVarName());
+Preconditions.checkState(vO.name().equals(vD.name()), "Names should be equal for variable %s: expected %s vs %s",
+i, vO.name(), vD.name());
 }
 //Check ops:

@@ -121,7 +121,7 @@ public class TestCase {
 * @param output Expected INDArray
 */
 public TestCase expected(@NonNull SDVariable var, @NonNull INDArray output) {
-return expected(var.getVarName(), output);
+return expected(var.name(), output);
 }
 /**
@@ -135,7 +135,7 @@
 }
 public TestCase expected(SDVariable var, Function<INDArray,String> validationFn){
-return expected(var.getVarName(), validationFn);
+return expected(var.name(), validationFn);
 }
 /**

@@ -487,11 +487,15 @@ public class LogFileWriter {
 //Create outputs list:
 List<String> outputs = sd.outputs();
+int outputsOffset = 0;
+if(outputs != null && !outputs.isEmpty()) {
 int[] outputListStrOffsets = new int[outputs.size()];
 for (int i = 0; i < outputListStrOffsets.length; i++) {
 outputListStrOffsets[i] = fbb.createString(outputs.get(i));
 }
-int outputsOffset = UIGraphStructure.createInputsVector(fbb, outputListStrOffsets);
+outputsOffset = UIGraphStructure.createInputsVector(fbb, outputListStrOffsets);
+}
 //Create variables list
 Map<String,Variable> varMap = sd.getVariables();
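
Note: the null/empty guard here leans on a FlatBuffers convention: an offset of 0 marks a field as absent, so when there are no outputs the vector is simply never written. The pattern in isolation (a sketch; createVectorOfTables stands in for the generated createInputsVector-style helper):

    import com.google.flatbuffers.FlatBufferBuilder;
    import java.util.List;

    public class VectorGuardSketch {
        // Returns 0 (the "absent" offset) when there is nothing to serialize
        static int createStringVector(FlatBufferBuilder fbb, List<String> values) {
            if (values == null || values.isEmpty()) {
                return 0;
            }
            int[] offsets = new int[values.size()];
            for (int i = 0; i < offsets.length; i++) {
                offsets[i] = fbb.createString(values.get(i));
            }
            return fbb.createVectorOfTables(offsets);
        }
    }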

@@ -46,6 +46,7 @@ public class ImportClassMapping {
 org.nd4j.linalg.api.ops.custom.BarnesEdgeForces.class,
 org.nd4j.linalg.api.ops.custom.BarnesHutGains.class,
 org.nd4j.linalg.api.ops.custom.BarnesHutSymmetrize.class,
+org.nd4j.linalg.api.ops.custom.KnnMinDistance.class,
 org.nd4j.linalg.api.ops.custom.SpTreeCell.class,
 org.nd4j.linalg.api.ops.custom.Flatten.class,
 org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd.class,
@@ -322,7 +323,6 @@
 org.nd4j.linalg.api.ops.impl.transforms.BinCount.class,
 org.nd4j.linalg.api.ops.impl.transforms.CheckNumerics.class,
 org.nd4j.linalg.api.ops.impl.transforms.Cholesky.class,
-org.nd4j.linalg.api.ops.impl.transforms.Constant.class,
 org.nd4j.linalg.api.ops.impl.transforms.Histogram.class,
 org.nd4j.linalg.api.ops.impl.transforms.HistogramFixedWidth.class,
 org.nd4j.linalg.api.ops.impl.transforms.IdentityN.class,

@@ -810,8 +810,6 @@ public class TFGraphMapper {
 on.setValueFor(currentField, tensor.getFloat(0));
 }
 }
-} else {
-on.getSameDiff().addPropertyToResolve(on, entry.getKey());
 }
 }
 }

@@ -63,19 +63,11 @@ public abstract class BaseBroadcastBoolOp extends BaseOp implements BroadcastOp
 this.sameDiff = sameDiff;
 this.inPlace = inPlace;
 this.dimension = dimension;
-if(Shape.isPlaceholderShape(i_v1.getShape())) {
-sameDiff.addPropertyToResolve(this,i_v1.getVarName());
-}
-if(Shape.isPlaceholderShape(i_v2.getShape())) {
-sameDiff.addPropertyToResolve(this,i_v2.getVarName());
-}
 sameDiff.addArgsFor(new SDVariable[]{i_v1,i_v2},this);
 } else {
 throw new IllegalArgumentException("Input not null variables.");
 }
 }
 public BaseBroadcastBoolOp(SameDiff sameDiff) {

@@ -64,19 +64,10 @@ public abstract class BaseBroadcastOp extends BaseOp implements BroadcastOp {
 this.sameDiff = sameDiff;
 this.inPlace = inPlace;
 this.dimension = dimension;
-if(Shape.isPlaceholderShape(i_v1.getShape())) {
-sameDiff.addPropertyToResolve(this,i_v1.getVarName());
-}
-if(Shape.isPlaceholderShape(i_v2.getShape())) {
-sameDiff.addPropertyToResolve(this,i_v2.getVarName());
-}
 sameDiff.addArgsFor(new SDVariable[]{i_v1,i_v2},this);
 } else {
 throw new IllegalArgumentException("Input not null variables.");
 }
 }
 public BaseBroadcastOp(SameDiff sameDiff) {

@@ -53,11 +53,8 @@ public abstract class BaseIndexAccumulation extends BaseOp implements IndexAccum
 this.dimensions = dimensions;
 f().validateDifferentialFunctionsameDiff(i_v);
 sameDiff.addArgsFor(new SDVariable[]{i_v},this);
-if(Shape.isPlaceholderShape(i_v.getShape())) {
-sameDiff.addPropertyToResolve(this,i_v.getVarName());
-}
-this.xVertexId = i_v.getVarName();
+this.xVertexId = i_v.name();
 } else {
 throw new IllegalArgumentException("Input not null variable.");
 }
@@ -75,17 +72,9 @@ public abstract class BaseIndexAccumulation extends BaseOp implements IndexAccum
 this.dimensions = dimensions;
 f().validateDifferentialFunctionsameDiff(i_v);
 f().validateDifferentialFunctionsameDiff(i_v2);
-this.xVertexId = i_v.getVarName();
-this.yVertexId = i_v2.getVarName();
+this.xVertexId = i_v.name();
+this.yVertexId = i_v2.name();
 sameDiff.addArgsFor(new SDVariable[]{i_v,i_v2},this);
-if(Shape.isPlaceholderShape(i_v.getShape())) {
-sameDiff.addPropertyToResolve(this,i_v.getVarName());
-}
-if(Shape.isPlaceholderShape(i_v2.getShape())) {
-sameDiff.addPropertyToResolve(this,i_v2.getVarName());
-}
 } else {
 throw new IllegalArgumentException("Input not null variable.");
 }

@@ -247,7 +247,7 @@ public abstract class BaseOp extends DifferentialFunction implements Op {
 val outputNames = sameDiff.getOutputsForOp(this);
 //no need to dynamically create if already exists
 if(outputNames != null) {
-zVertexId = sameDiff.getVariable(outputNames[0]).getVarName();
+zVertexId = sameDiff.getVariable(outputNames[0]).name();
 return new SDVariable[]{sameDiff.getVariable(outputNames[0])};
@@ -261,7 +261,7 @@ public abstract class BaseOp extends DifferentialFunction implements Op {
 return newVars;
 }
-sameDiff.setArrayForVariable(newVars[0].getVarName(),inputArr);
+sameDiff.setArrayForVariable(newVars[0].name(),inputArr);
 z = inputArr;
 if(sameDiff.getOutputsForOp(this) == null)
 sameDiff.addOutgoingFor(newVars,this);

@@ -61,7 +61,7 @@ public abstract class BaseReduceOp extends BaseOp implements ReduceOp {
 this.dimensions = dimensions;
 f().validateDifferentialFunctionsameDiff(i_v);
 this.keepDims = keepDims;
-this.xVertexId = i_v.getVarName();
+this.xVertexId = i_v.name();
 sameDiff.addArgsFor(new String[]{xVertexId},this);
 } else {
 throw new IllegalArgumentException("Input not null variable.");
@@ -81,8 +81,8 @@ public abstract class BaseReduceOp extends BaseOp implements ReduceOp {
 this.dimensions = dimensions;
-this.xVertexId = i_v.getVarName();
-this.yVertexId = i_v2.getVarName();
+this.xVertexId = i_v.name();
+this.yVertexId = i_v2.name();
 f().validateDifferentialFunctionsameDiff(i_v);
 f().validateDifferentialFunctionsameDiff(i_v2);
 this.keepDims = keepDims;

@@ -74,11 +74,8 @@ public abstract class BaseScalarBoolOp extends BaseOp implements ScalarOp {
 super(sameDiff,inPlace,extraArgs);
 this.scalarValue = Nd4j.scalar(i_v.dataType(), scalar);
 if (i_v != null) {
-this.xVertexId = i_v.getVarName();
+this.xVertexId = i_v.name();
 sameDiff.addArgsFor(new String[]{xVertexId},this);
-if(Shape.isPlaceholderShape(i_v.getShape())) {
-sameDiff.addPropertyToResolve(this,i_v.getVarName());
-}
 f().validateDifferentialFunctionsameDiff(i_v);
 } else {
 throw new IllegalArgumentException("Input not null variable.");

@@ -93,11 +93,8 @@ public abstract class BaseScalarOp extends BaseOp implements ScalarOp {
 Object[] extraArgs) {
 super(sameDiff,inPlace,extraArgs);
 this.scalarValue = Nd4j.scalar(i_v.dataType(), scalar);
-this.xVertexId = i_v.getVarName();
+this.xVertexId = i_v.name();
 sameDiff.addArgsFor(new String[]{xVertexId},this);
-if(Shape.isPlaceholderShape(i_v.getShape())) {
-sameDiff.addPropertyToResolve(this,i_v.getVarName());
-}
 f().validateDifferentialFunctionsameDiff(i_v);
 }

@@ -56,16 +56,9 @@ public abstract class BaseTransformOp extends BaseOp implements TransformOp {
 f().validateDifferentialFunctionsameDiff(i_v2);
 this.sameDiff = sameDiff;
 this.inPlace = inPlace;
-this.xVertexId = i_v1.getVarName();
-this.yVertexId = i_v2.getVarName();
+this.xVertexId = i_v1.name();
+this.yVertexId = i_v2.name();
 sameDiff.addArgsFor(new SDVariable[]{i_v1,i_v2},this);
-if(Shape.isPlaceholderShape(i_v1.getShape())) {
-sameDiff.addPropertyToResolve(this,i_v1.getVarName());
-}
-if(Shape.isPlaceholderShape(i_v2.getShape())) {
-sameDiff.addPropertyToResolve(this,i_v2.getVarName());
-}
 } else {
 throw new IllegalArgumentException("Input not null variables.");
 }
@@ -87,18 +80,9 @@
 f().validateDifferentialFunctionsameDiff(i_v1);
 f().validateDifferentialFunctionsameDiff(i_v2);
 this.sameDiff = sameDiff;
-this.xVertexId = i_v1.getVarName();
-this.yVertexId = i_v2.getVarName();
+this.xVertexId = i_v1.name();
+this.yVertexId = i_v2.name();
 sameDiff.addArgsFor(new SDVariable[]{i_v1,i_v2},this);
-if(Shape.isPlaceholderShape(i_v1.getShape())) {
-sameDiff.addPropertyToResolve(this,i_v1.getVarName());
-}
-if(Shape.isPlaceholderShape(i_v2.getShape())) {
-sameDiff.addPropertyToResolve(this,i_v2.getVarName());
-}
 } else {
 throw new IllegalArgumentException("Input not null variables.");
 }
@@ -130,14 +114,8 @@
 if (i_v != null) {
 f().validateDifferentialFunctionsameDiff(i_v);
-this.xVertexId = i_v.getVarName();
+this.xVertexId = i_v.name();
 sameDiff.addArgsFor(new SDVariable[]{i_v},this);
-if(Shape.isPlaceholderShape(i_v.getShape())) {
-sameDiff.addPropertyToResolve(this,i_v.getVarName());
-}
 } else {
 throw new IllegalArgumentException("Input must not null variable.");
 }

@@ -223,7 +223,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
 if (args().length >= 1) {
 val arr = args()[0].getArr();
 if (arr != null) {
-sameDiff.setArrayForVariable(newVars[0].getVarName(), arr);
+sameDiff.setArrayForVariable(newVars[0].name(), arr);
 addOutputArgument(arr);
 }
 }

@@ -57,7 +57,7 @@ public class ExternalErrorsFunction extends DynamicCustomOp {
 public ExternalErrorsFunction(){ }
 public String getGradPlaceholderName(){
-return arg().getVarName() + "-grad";
+return arg().name() + "-grad";
 }
 @Override
@@ -70,7 +70,7 @@ public class ExternalErrorsFunction extends DynamicCustomOp {
 out = sameDiff.getVariable(name);
 } else {
 out = sameDiff.zero(name, Nd4j.dataType(), 1);
-sameDiff.getOps().get(getOwnName()).setOutputsOfOp(Collections.singletonList(out.getVarName()));
+sameDiff.getOps().get(getOwnName()).setOutputsOfOp(Collections.singletonList(out.name()));
 sameDiff.getVariables().get(name).setOutputOfOp(getOwnName());
 }
 }
@@ -83,7 +83,7 @@ public class ExternalErrorsFunction extends DynamicCustomOp {
 if (gradVariables == null) {
 gradVariables = new HashMap<>();
 for(SDVariable arg : args()){
-INDArray gradArr = gradients.get(arg.getVarName());
+INDArray gradArr = gradients.get(arg.name());
 SDVariable grad;
 DataType dt = arg.dataType();
 String n = getGradPlaceholderName();
@@ -94,7 +94,7 @@ public class ExternalErrorsFunction extends DynamicCustomOp {
 } else {
 grad = sameDiff.var(n, VariableType.PLACEHOLDER, null, dt);
 }
-gradVariables.put(arg.getVarName(), grad);
+gradVariables.put(arg.name(), grad);
 out.add(grad);
 }
 }

@@ -196,12 +196,12 @@ public class DeConv2D extends DynamicCustomOp {
 val paddingMode = aPadding.getS().toStringUtf8();
 val args = args();
-INDArray arr = sameDiff.getVariable(args[1].getVarName()).getArr();
+INDArray arr = sameDiff.getVariable(args[1].name()).getArr();
 if (arr == null) {
 arr = TFGraphMapper.getNDArrayFromTensor(nodeDef);
 // TODO: arguable. it might be easier to permute weights once
 //arr = (arr.permute(3, 2, 0, 1).dup('c'));
-val varForOp = initWith.getVariable(args[1].getVarName());
+val varForOp = initWith.getVariable(args[1].name());
 if (arr != null)
 initWith.associateArrayWithVariable(arr, varForOp);

@@ -158,10 +158,10 @@ public class DeConv3D extends DynamicCustomOp {
 val paddingMode = aPadding.getS().toStringUtf8();
 val args = args();
-INDArray arr = sameDiff.getVariable(args[1].getVarName()).getArr();
+INDArray arr = sameDiff.getVariable(args[1].name()).getArr();
 if (arr == null) {
 arr = TFGraphMapper.getNDArrayFromTensor(nodeDef);
-val varForOp = initWith.getVariable(args[1].getVarName());
+val varForOp = initWith.getVariable(args[1].name());
 if (arr != null)
 initWith.associateArrayWithVariable(arr, varForOp);
 }

@@ -193,12 +193,6 @@ public class Mmul extends DynamicCustomOp {
 .transposeA(isTransposeA).transposeB(isTransposeB)
 .build();
 this.mt = mMulTranspose;
-val args = args();
-for(val arg : args) {
-if(sameDiff.isPlaceHolder(arg.getVarName()) || arg.getShape() == null) {
-sameDiff.addPropertyToResolve(this,arg.getVarName());
-}
-}
 iArguments.clear();
 addIArgument(ArrayUtil.fromBoolean(mt.isTransposeA()), ArrayUtil.fromBoolean(mt.isTransposeB()));
 }

@@ -130,10 +130,7 @@ public class Concat extends DynamicCustomOp {
 val variable = initWith.getVariable(input);
 // concat dimension is only possible
-if (variable != null && variable.getArr() == null) {
-sameDiff.addPropertyToResolve(this, input);
-} else if (variable != null) {
+if (variable != null) {
 val arr = variable.getArr();
 if (arr.length() == 1) {
 concatDimension = arr.getInt(0);

@@ -124,13 +124,7 @@ public class Transpose extends DynamicCustomOp {
 return;
 }
-INDArray arr = sameDiff.getArrForVarName(arg().getVarName());
-if (arr == null) {
-val arrVar = sameDiff.getVariable(arg().getVarName());
-arr = arrVar.getWeightInitScheme().create(arrVar.dataType(), arrVar.getShape());
-sameDiff.setArrayForVariable(arg().getVarName(), arr);
-}
+INDArray arr = sameDiff.getArrForVarName(arg().name());
 if(permuteArrayOp != null){
 addInputArgument(arr, permuteArrayOp);
@@ -138,16 +132,12 @@ public class Transpose extends DynamicCustomOp {
 addInputArgument(arr);
 }
+if (arr != null && permuteDims == null) {
 this.permuteDims = ArrayUtil.reverseCopy(ArrayUtil.range(0, arr.rank()));
 }
 if (permuteDims != null && permuteDims.length < arg().getShape().length)
 throw new ND4JIllegalStateException("Illegal permute found. Not all dimensions specified");
 }
 @Override

@@ -72,7 +72,7 @@ public class TensorArrayConcat extends BaseTensorOp {
 public List<DataType> calculateOutputDataTypes(java.util.List<org.nd4j.linalg.api.buffer.DataType> inputDataType){
 //Same output type as the TensorArray - which is defined by input 0
 SDVariable tArr = arg(0);
-TensorArray t3 = (TensorArray) sameDiff.getVariableOutputOp(tArr.getVarName());
+TensorArray t3 = (TensorArray) sameDiff.getVariableOutputOp(tArr.name());
 org.nd4j.linalg.api.buffer.DataType dt = t3.getTensorArrayDataType();
 return Collections.singletonList(dt);
 }

@@ -72,7 +72,7 @@ public class TensorArrayGather extends BaseTensorOp {
 public List<DataType> calculateOutputDataTypes(java.util.List<org.nd4j.linalg.api.buffer.DataType> inputDataType){
 //Same output type as the TensorArray - which is defined by input 0
 SDVariable tArr = arg(0);
-TensorArray t3 = (TensorArray) sameDiff.getVariableOutputOp(tArr.getVarName());
+TensorArray t3 = (TensorArray) sameDiff.getVariableOutputOp(tArr.name());
 org.nd4j.linalg.api.buffer.DataType dt = t3.getTensorArrayDataType();
 return Collections.singletonList(dt);
 }

@@ -73,7 +73,7 @@ public class TensorArrayRead extends BaseTensorOp {
 dt = importDataType;
 } else {
 SDVariable tArr = arg(0);
-DifferentialFunction op = sameDiff.getVariableOutputOp(tArr.getVarName());
+DifferentialFunction op = sameDiff.getVariableOutputOp(tArr.name());
 TensorArray t3 = (TensorArray) op;
 dt = t3.getTensorArrayDataType();
 }

@@ -71,9 +71,9 @@ public class CheckNumerics extends DynamicCustomOp {
 SDVariable msg = initWith.constant(name + "/message", Nd4j.scalar(str));
 List<String> newInputs = new ArrayList<>(2);
 newInputs.addAll(initWith.getOps().get(name).getInputsToOp());
-newInputs.add(msg.getVarName());
+newInputs.add(msg.name());
 initWith.getOps().get(name).setInputsToOp(newInputs);
-initWith.getVariables().get(msg.getVarName()).setInputsForOp(Collections.singletonList(getOwnName()));
+initWith.getVariables().get(msg.name()).setInputsForOp(Collections.singletonList(getOwnName()));
 }
 @Override
 public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){

@@ -1,91 +0,0 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.nd4j.linalg.api.ops.impl.transforms;

import lombok.Data;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ops.BaseTransformOp;
import org.nd4j.linalg.api.ops.BaseTransformSameOp;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.UUID;

@Data
public class Constant extends BaseTransformSameOp {

    public Constant() {
    }

    protected Constant(SameDiff sameDiff,
                       SDVariable i_v,
                       long[] shape,
                       boolean inPlace) {
        super();
        sameDiff.putOrUpdateShapeForVarName(i_v.getVarName(), shape, false);
        this.xVertexId = i_v.getVarName();
        this.inPlace = inPlace;
        this.sameDiff = sameDiff;
    }

    public Constant(SameDiff sameDiff, SDVariable i_v, long[] shape) {
        this(sameDiff, i_v, shape, false);
    }

    @Override
    public List<SDVariable> doDiff(List<SDVariable> i_v) {
        return Collections.singletonList(sameDiff.zerosLike(arg()));
    }

    @Override
    public DifferentialFunction dup() {
        Constant ret = new Constant(sameDiff, sameDiff.getVariable(outputVariables()[0].getVarName())
                , sameDiff.getShapeForVarName(outputVariables()[0].getVarName()));
        Constant differentialFunction = ret;
        return differentialFunction;
    }

    @Override
    public int opNum() {
        return 15;
    }

    @Override
    public String opName() {
        return "constant";
    }

    @Override
    public String onnxName() {
        throw new NoOpNameFoundException("No onnx op opName found for " + opName());
    }

    @Override
    public String tensorflowName() {
        throw new NoOpNameFoundException("No tensorflow opName found for " + opName());
    }
}
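
With this op deleted, a constant in a graph comes straight from a concrete array via SameDiff.constant, as elsewhere in this diff - a minimal sketch (shape and value here are illustrative):

    SameDiff sd = SameDiff.create();
    // Fixed, non-trainable value baked into the graph:
    SDVariable c = sd.constant(Nd4j.valueArrayOf(new long[]{3, 4, 5}, 3.0));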

@@ -118,7 +118,7 @@ public class MaxOut extends BaseTransformOp {
        if(arg() == null)
            throw new ND4JIllegalStateException("No arg found for op!");

-       val arr = sameDiff.getArrForVarName(arg().getVarName());
+       val arr = sameDiff.getArrForVarName(arg().name());
        if(arr == null)
            return Collections.emptyList();

@@ -28,13 +28,19 @@ import java.util.Collections;
import java.util.List;

/**
- *
+ * TanhDerivative: calculates dL/dIn from dL/dOut and In
 */
public class TanhDerivative extends DynamicCustomOp {
    public TanhDerivative(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
        super(sameDiff, new SDVariable[]{i_v1, i_v2});
    }

+   /**
+    *
+    * @param x Input
+    * @param y Gradient at output (dL/dOut)
+    * @param z Output array, gradient at input (dL/dIn - to be calculated)
+    */
    public TanhDerivative(INDArray x, INDArray y, INDArray z) {
        super(null, new INDArray[]{x, y}, new INDArray[]{z});
    }
@@ -42,6 +48,10 @@ public class TanhDerivative extends DynamicCustomOp {
    public TanhDerivative() {
    }

+   /**
+    * @param x Input
+    * @param y Gradient at output (dL/dOut)
+    */
    public TanhDerivative(INDArray x, INDArray y) {
        this(x, y, null);
    }
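
For reference, a minimal sketch of the identity this op computes, written with plain Nd4j calls (shapes and variable names here are illustrative, not part of the op's API):

    INDArray x = Nd4j.rand(DataType.DOUBLE, 2, 3);          // input
    INDArray dLdOut = Nd4j.rand(DataType.DOUBLE, 2, 3);     // gradient at output (y)
    INDArray tanh = Transforms.tanh(x, true);               // copy-mode tanh, x untouched
    INDArray dLdIn = dLdOut.mul(tanh.mul(tanh).rsub(1.0));  // dL/dIn = dL/dOut * (1 - tanh(x)^2)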

@@ -43,11 +43,8 @@ public abstract class BaseRandomOp extends BaseOp implements RandomOp {
    public BaseRandomOp(SameDiff sameDiff, SDVariable i_v) {
        Preconditions.checkNotNull(i_v, "Input variable can't be null with this constructor");
        this.sameDiff = sameDiff;
-       this.xVertexId = i_v.getVarName();
+       this.xVertexId = i_v.name();
        sameDiff.addArgsFor(new String[]{xVertexId},this);
-       if(Shape.isPlaceholderShape(i_v.getShape())) {
-           sameDiff.addPropertyToResolve(this,i_v.getVarName());
-       }
    }

    public BaseRandomOp(SameDiff sd, long[] shape){
@@ -73,11 +70,7 @@ public abstract class BaseRandomOp extends BaseOp implements RandomOp {
        if(shape != null){
            return Collections.singletonList(LongShapeDescriptor.fromShape(shape, Nd4j.defaultFloatingPointType()));
        } else {
-           List<LongShapeDescriptor> ret = new ArrayList<>(1);
-           val shape = sameDiff.getShapeForVarName(args()[0].getVarName());
-           if (shape != null)
-               ret.add(LongShapeDescriptor.fromShape(shape, Shape.pickPairwiseDataType(args()[0].dataType(), Nd4j.dataType())));
-           return ret;
+           return Collections.singletonList(LongShapeDescriptor.fromShape(shape, Shape.pickPairwiseDataType(args()[0].dataType(), Nd4j.dataType())));
        }
    }

@@ -5212,6 +5212,8 @@ public class Nd4j {
                    }
                }
            }
+
+           backend.logBackendInit();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
@@ -5625,19 +5627,38 @@ public class Nd4j {
     * @return an ndarray created from the in memory
     * numpy pointer
     */
    @SuppressWarnings("WeakerAccess")
    public static INDArray createFromNpyPointer(Pointer pointer) {
        return INSTANCE.createFromNpyPointer(pointer);
    }

    /**
-    * Create from a given Numpy .npy file.
+    * Create an INDArray from a given Numpy .npy file.
+    *
+    * @param path Path to the .npy file to read
+    * @return the created ndarray
+    */
+   public static INDArray readNpy(@NonNull String path){
+       return readNpy(new File(path));
+   }
+
+   /**
+    * Create an INDArray from a given Numpy .npy file.
     *
     * @param file the file to create the ndarray from
     * @return the created ndarray
     */
-   public static INDArray createFromNpyFile(File file) {
+   public static INDArray readNpy(@NonNull File file){
+       return createFromNpyFile(file);
+   }
+
+   /**
+    * Create an INDArray from a given Numpy .npy file.
+    *
+    * @param file the file to create the ndarray from
+    * @return the created ndarray
+    */
+   public static INDArray createFromNpyFile(@NonNull File file) {
        if (!file.exists())
            throw new IllegalArgumentException("File [" + file.getAbsolutePath() + "] doesn't exist");
@@ -5654,7 +5675,7 @@ public class Nd4j {
     * @return the loaded ndarray
     */
    @SuppressWarnings("unused")
-   public static INDArray createNpyFromInputStream(InputStream is) throws IOException {
+   public static INDArray createNpyFromInputStream(@NonNull InputStream is) throws IOException {
        byte[] content = IOUtils.toByteArray(is);
        return createNpyFromByteArray(content);
    }
@@ -5668,7 +5689,7 @@ public class Nd4j {
     * @param input the input byte array with the npy format
     * @return the equivalent {@link INDArray}
     */
-   public static INDArray createNpyFromByteArray(byte[] input) {
+   public static INDArray createNpyFromByteArray(@NonNull byte[] input) {
        ByteBuffer byteBuffer = ByteBuffer.allocateDirect(input.length);
        byteBuffer.put(input);
        byteBuffer.rewind();
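
A quick usage sketch for the new convenience overloads (the path is hypothetical):

    INDArray a = Nd4j.readNpy("/tmp/data.npy");            // String overload, wraps the path in a File
    INDArray b = Nd4j.readNpy(new File("/tmp/data.npy"));  // File overload, delegates to createFromNpyFile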

@@ -20,6 +20,7 @@ import java.util.Properties;
import lombok.Getter;
import org.bytedeco.javacpp.Loader;
import org.nd4j.config.ND4JEnvironmentVars;
+import org.nd4j.config.ND4JSystemProperties;
import org.nd4j.context.Nd4jContext;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
@@ -101,7 +102,12 @@ public class NativeOpsHolder {
            }
            //deviceNativeOps.setOmpNumThreads(4);

+           String logInitProperty = System.getProperty(ND4JSystemProperties.LOG_INITIALIZATION, "true");
+           boolean logInit = Boolean.parseBoolean(logInitProperty);
+
+           if(logInit) {
                log.info("Number of threads used for OpenMP: {}", deviceNativeOps.ompGetMaxThreads());
+           }
        } catch (Exception | Error e) {
            throw new RuntimeException(
                    "ND4J is probably missing dependencies. For more information, please refer to: http://nd4j.org/getstarted.html",

@@ -47,7 +47,7 @@ public class MemoryTracker {
            val f = new AtomicLong(NativeOpsHolder.getInstance().getDeviceNativeOps().getDeviceFreeMemory(i));

-           log.debug("Free memory on device_{}: {}", i, f);
+           //log.debug("Free memory on device_{}: {}", i, f);
            freePerDevice.add(i, f);
        }
    }

@@ -16,14 +16,24 @@
package org.nd4j.linalg.jcublas;

+import lombok.extern.slf4j.Slf4j;
import org.bytedeco.javacpp.Loader;
+import org.nd4j.config.ND4JSystemProperties;
+import org.nd4j.linalg.api.environment.Nd4jEnvironment;
+import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.linalg.io.Resource;
+import org.nd4j.nativeblas.Nd4jCuda;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Properties;

/**
 *
 */
+@Slf4j
public class JCublasBackend extends Nd4jBackend {
@@ -76,4 +86,34 @@ public class JCublasBackend extends Nd4jBackend {
        return JCublasNDArray.class;
    }

+   @Override
+   public void logBackendInit() {
+       String logInitProperty = System.getProperty(ND4JSystemProperties.LOG_INITIALIZATION, "true");
+       boolean logInit = Boolean.parseBoolean(logInitProperty);
+
+       if(logInit) {
+           try {
+               Nd4jCuda.Environment e = Nd4jCuda.Environment.getInstance();
+               int blasMajor = e.blasMajorVersion();
+               int blasMinor = e.blasMinorVersion();
+               int blasPatch = e.blasPatchVersion();
+               log.info("ND4J CUDA build version: {}.{}.{}", blasMajor, blasMinor, blasPatch);
+               int nGPUs = Nd4jEnvironment.getEnvironment().getNumGpus();
+
+               Properties props = Nd4j.getExecutioner().getEnvironmentInformation();
+               List<Map<String, Object>> devicesList = (List<Map<String, Object>>) props.get(Nd4jEnvironment.CUDA_DEVICE_INFORMATION_KEY);
+
+               for (int i = 0; i < nGPUs; i++) {
+                   Map<String, Object> dev = devicesList.get(i);
+                   String name = (String) dev.get(Nd4jEnvironment.CUDA_DEVICE_NAME_KEY);
+                   int major = ((Number) dev.get(Nd4jEnvironment.CUDA_DEVICE_MAJOR_VERSION_KEY)).intValue();
+                   int minor = ((Number) dev.get(Nd4jEnvironment.CUDA_DEVICE_MINOR_VERSION_KEY)).intValue();
+                   long totalMem = ((Number) dev.get(Nd4jEnvironment.CUDA_TOTAL_MEMORY_KEY)).longValue();
+                   log.info("CUDA device {}: [{}]; cc: [{}.{}]; Total memory: [{}]", i, name, major, minor, totalMem);
+               }
+           } catch (Throwable t) {
+               log.debug("Error logging CUDA backend versions and devices", t);
+           }
+       }
+   }
}
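
Both this method and the NativeOpsHolder change key off the same flag; a sketch of how the init logging can be suppressed (assuming the property constant resolves to org.nd4j.log.initialization):

    // On the command line:
    //   java -Dorg.nd4j.log.initialization=false ...
    // Or programmatically, before ND4J is first initialized:
    System.setProperty(ND4JSystemProperties.LOG_INITIALIZATION, "false");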

@@ -1890,14 +1890,6 @@ public class CudaExecutioner extends DefaultOpExecutioner {
    @SuppressWarnings("unchecked")
    public void printEnvironmentInformation() {
        super.printEnvironmentInformation();
-       Properties env = getEnvironmentInformation();
-       List<Map<String, Object>> devicesList = (List<Map<String, Object>>) env.get(Nd4jEnvironment.CUDA_DEVICE_INFORMATION_KEY);
-       for (Map<String, Object> dev : devicesList) {
-           log.info("Device Name: [{}]; CC: [{}.{}]; Total/free memory: [{}]", dev.get(Nd4jEnvironment.CUDA_DEVICE_NAME_KEY),
-                   dev.get(Nd4jEnvironment.CUDA_DEVICE_MAJOR_VERSION_KEY), dev.get(Nd4jEnvironment.CUDA_DEVICE_MINOR_VERSION_KEY), dev.get(Nd4jEnvironment.CUDA_TOTAL_MEMORY_KEY));
-       }
    }

    @Override

@@ -60,4 +60,9 @@ public class CpuBackend extends Nd4jBackend {
    public Class getNDArrayClass() {
        return NDArray.class;
    }

+   @Override
+   public void logBackendInit() {
+       //No additional logging for CPU backend
+   }
}

@@ -146,8 +146,8 @@ public class TestSessions extends BaseNd4jTest {
        System.out.println("----------------------------------");
        InferenceSession is = new InferenceSession(sd);

-       // String outName = merge.getVarName();
-       String outName = outVar.getVarName();
+       // String outName = merge.name();
+       String outName = outVar.name();
        Map<String,INDArray> outMap = is.output(Collections.singletonList(outName), m, null,
                Collections.<String>emptyList(), null, At.defaultAt(Operation.TRAINING));
@@ -181,7 +181,7 @@ public class TestSessions extends BaseNd4jTest {
        m.put("b", bArr);

        InferenceSession is = new InferenceSession(sd);
-       String n = merge.getVarName();
+       String n = merge.name();

        System.out.println("----------------------------------");
        Map<String,INDArray> outMap = is.output(Collections.singletonList(n), m, null, Collections.<String>emptyList(),

@@ -118,7 +118,7 @@ public class GraphExecutionerTest extends BaseNd4jTest {
        SDVariable result = sdVariable.add(scalarOne);
        SDVariable total = sameDiff.sum(result,Integer.MAX_VALUE);

-       log.info("TOTAL: {}; Id: {}", total.getVarName(), total);
+       log.info("TOTAL: {}; Id: {}", total.name(), total);

        INDArray[] resB = executionerB.executeGraph(sameDiff, configVarSpace);

@@ -79,7 +79,7 @@ public class LayerOpValidation extends BaseOpValidation {
        TestCase tc = new TestCase(sameDiff)
                .gradientCheck(true)
-               .expectedOutput(res.getVarName(), exp);
+               .expectedOutput(res.name(), exp);

        System.out.println(sameDiff.summary());
        System.out.println("============================");
@@ -112,7 +112,7 @@ public class LayerOpValidation extends BaseOpValidation {
        TestCase tc = new TestCase(sameDiff)
                .gradientCheck(true)
-               .expectedOutput(res.getVarName(), exp);
+               .expectedOutput(res.name(), exp);

        String err = OpValidation.validate(tc);
@@ -137,7 +137,7 @@ public class LayerOpValidation extends BaseOpValidation {
        TestCase tc = new TestCase(sameDiff)
                .gradientCheck(true)
-               .expectedOutput(res.getVarName(), exp);
+               .expectedOutput(res.name(), exp);

        String err = OpValidation.validate(tc);
        assertNull(err);
@@ -591,7 +591,7 @@ public class LayerOpValidation extends BaseOpValidation {
        SDVariable out = sd.cnn().sconv2d(vars, c);
        out = sd.nn().tanh("out", out);

-       INDArray outArr = sd.execAndEndResult();
+       INDArray outArr = out.eval();
        //Expected output size: out = (in - k + 2*p)/s + 1 = (28-2+0)/1+1 = 27
        val outShape = outArr.shape();
        assertArrayEquals(new long[]{mb, depthWise * nIn, 27, 27}, outShape);
@@ -637,7 +637,7 @@ public class LayerOpValidation extends BaseOpValidation {
        SDVariable out = sd.cnn().sconv2d(vars, c);
        out = sd.nn().tanh("out", out);

-       INDArray outArr = sd.execAndEndResult();
+       INDArray outArr = out.eval();
        //Expected output size: out = (in - k + 2*p)/s + 1 = (8-2+0)/1+1 = 7
        val outShape = outArr.shape();
        assertArrayEquals(new long[]{mb, nOut, 7, 7}, outShape);
@@ -688,7 +688,7 @@ public class LayerOpValidation extends BaseOpValidation {
        SDVariable out = sd.cnn().deconv2d(vars, deconv);
        out = sd.nn().tanh("out", out);

-       INDArray outArr = sd.execAndEndResult();
+       INDArray outArr = out.eval();
        //Expected output size: out = (in + k + 2*p)/ s - 1 = (8 + 2+0)/1 - 1 = 9
        val outShape = outArr.shape();
        assertArrayEquals(new long[]{mb, nOut, 9, 9}, outShape);
@@ -736,7 +736,7 @@ public class LayerOpValidation extends BaseOpValidation {
        SDVariable out = sd.cnn().conv2d("conv", vars, c);
        out = sd.nn().tanh("out", out);

-       INDArray outArr = sd.execAndEndResult();
+       INDArray outArr = out.eval();
        //Expected output size: out = (in - k + 2*p)/s + 1 = (28-2+0)/1+1 = 27
        val outShape = outArr.shape();
        assertArrayEquals(new long[]{mb, nOut, 27, 27}, outShape);
@@ -770,7 +770,7 @@ public class LayerOpValidation extends BaseOpValidation {
        SDVariable outPool = sd.cnn().maxPooling2d(in, pooling2DConfig);
        SDVariable out = sd.nn().tanh("out", outPool);

-       INDArray outArr = sd.execAndEndResult();
+       INDArray outArr = out.eval();
        val outShape = outArr.shape();
        // oH = (iH - (kH + (kH-1)*(dH-1)) + 2*pH)/sH + 1;
        assertArrayEquals(new long[]{mb, nIn, 7, 7}, outShape);
@@ -828,7 +828,7 @@ public class LayerOpValidation extends BaseOpValidation {
        SDVariable outPool = sd.cnn().avgPooling2d(in, pooling2DConfig);
        SDVariable out = sd.nn().tanh("out", outPool);

-       INDArray outArr = sd.execAndEndResult();
+       INDArray outArr = out.eval();
        val outShape = outArr.shape();
        // oH = (iH - (kH + (kH-1)*(dH-1)) + 2*pH)/sH + 1;
        assertArrayEquals(new long[]{mb, nIn, 7, 7}, outShape);
@@ -996,7 +996,7 @@ public class LayerOpValidation extends BaseOpValidation {
                }
        );

-       TestCase tc = new TestCase(sd).gradientCheck(false).expectedOutput(res.getVarName(), expected);
+       TestCase tc = new TestCase(sd).gradientCheck(false).expectedOutput(res.name(), expected);

        String err = OpValidation.validate(tc);
        assertNull(err);

@@ -423,11 +423,11 @@ public class MiscOpValidation extends BaseOpValidation {
        }

        SDVariable loss = sd.sum(scatter);  //.standardDeviation(scatter, true);  //.sum(scatter);  //TODO stdev might be better here as gradients are non-symmetrical...
-       sd.execAndEndResult();

        TestCase tc = new TestCase(sd)
                .expected(scatter, exp)
-               .gradCheckSkipVariables(indices.getVarName());
+               .gradCheckSkipVariables(indices.name());

        String error = OpValidation.validate(tc);
        if(error != null){
@@ -493,7 +493,7 @@ public class MiscOpValidation extends BaseOpValidation {
        TestCase tc = new TestCase(sd)
                .testName(msg)
-               .gradCheckSkipVariables(indices.getVarName());
+               .gradCheckSkipVariables(indices.name());

        if (gatherExp != null) {
            tc.expected(gather, gatherExp);
@@ -589,16 +589,16 @@ public class MiscOpValidation extends BaseOpValidation {
        Map<String,INDArray> m = sameDiff.outputAll(null);
        Map<String,INDArray> gm = sameDiff.calculateGradients(null, m.keySet());

-       SDVariable finalResult = sameDiff.grad(sum.getVarName());
+       SDVariable finalResult = sameDiff.grad(sum.name());

-       SDVariable cGrad = sameDiff.grad(varMulPre.getVarName());
+       SDVariable cGrad = sameDiff.grad(varMulPre.name());

-       SDVariable mulGradResult = sameDiff.grad(varMul.getVarName());
-       SDVariable aGrad = sameDiff.grad(sdVariable.getVarName());
-       SDVariable wGrad = sameDiff.grad(sdVariable1.getVarName());
-       SDVariable dGrad = sameDiff.grad(varMul.getVarName());
+       SDVariable mulGradResult = sameDiff.grad(varMul.name());
+       SDVariable aGrad = sameDiff.grad(sdVariable.name());
+       SDVariable wGrad = sameDiff.grad(sdVariable1.name());
+       SDVariable dGrad = sameDiff.grad(varMul.name());

-       INDArray scalarGradTest = gm.get(sum.getVarName());
+       INDArray scalarGradTest = gm.get(sum.name());
        assertEquals(scalar, scalarGradTest);
@@ -738,11 +738,10 @@ public class MiscOpValidation extends BaseOpValidation {
        SDVariable B2 = sd.var("B2", B);

        SDVariable[] batchMul = sd.batchMmul(new SDVariable[] {A1, A2}, new SDVariable[] {B1, B2});
-       sd.exec(Collections.emptyMap(), sd.outputs());
+       Map<String,INDArray> m = sd.output(Collections.emptyMap(), sd.outputs());

-       INDArray resultingMatrix = batchMul[0].getArr();
-       System.out.print(resultingMatrix);
+       INDArray resultingMatrix = m.get(batchMul[0].name());
+       //System.out.print(resultingMatrix);
    }
@@ -770,14 +769,14 @@ public class MiscOpValidation extends BaseOpValidation {
        SDVariable mmul = sd.f().mmul(f, s, mt);
        sd.updateVariableNameAndReference(mmul, "mmul");

-       INDArray out = sd.execAndEndResult();
+       INDArray out = mmul.eval();

        INDArray exp = first.transpose().mmul(second);
        assertEquals(exp, out);

        SDVariable loss = sd.standardDeviation(mmul, true);
        String err = OpValidation.validate(new TestCase(sd)
-               .expected(mmul.getVarName(), exp));
+               .expected(mmul.name(), exp));

        assertNull(err);
    }
@@ -1287,7 +1286,7 @@ public class MiscOpValidation extends BaseOpValidation {
        SDVariable var = sd.var("in", i);
        SDVariable diag = sd.math().diagPart(var);

-       INDArray out = sd.execAndEndResult();
+       INDArray out = diag.eval();
        assertEquals(1, out.rank());
    }
@@ -1644,10 +1643,10 @@ public class MiscOpValidation extends BaseOpValidation {
        SDVariable v = new StopGradient(sd, w).outputVariable();
        SDVariable loss = v.std(true);

-       Map<String,INDArray> gm = sd.calculateGradients(null, v.getVarName(), w.getVarName());
+       Map<String,INDArray> gm = sd.calculateGradients(null, v.name(), w.name());

-       INDArray vArr = gm.get(v.getVarName());
-       INDArray wArr = gm.get(w.getVarName());
+       INDArray vArr = gm.get(v.name());
+       INDArray wArr = gm.get(w.name());

        System.out.println(vArr);
        System.out.println(wArr);
@@ -1669,18 +1668,18 @@ public class MiscOpValidation extends BaseOpValidation {
        INDArray expLoss = in.std(true);

        String err = OpValidation.validate(new TestCase(sd)
-               .expectedOutput(checkNumerics.getVarName(), in)
+               .expectedOutput(checkNumerics.name(), in)
                .placeholderValue("in", in)
                .expectedOutput("loss", expLoss));
        Preconditions.checkState(err == null, err);

        //Also check that it actually does what it's supposed to:
-       sd.execAll(Collections.singletonMap("in", in));
+       sd.outputAll(Collections.singletonMap("in", in));

        in.putScalar(0, Double.NaN);
        try {
-           sd.execAll(Collections.singletonMap("in", in));
+           sd.outputAll(Collections.singletonMap("in", in));
            fail("Expected exception");
        } catch (Throwable t){
            //OK
@@ -1688,14 +1687,14 @@ public class MiscOpValidation extends BaseOpValidation {
        in.putScalar(0, Double.POSITIVE_INFINITY);
        try {
-           sd.execAll(Collections.singletonMap("in", in));
+           sd.outputAll(Collections.singletonMap("in", in));
            fail("Expected exception");
        } catch (Throwable t){
            //OK
        }

        in.putScalar(0, 0.0);
-       sd.execAll(Collections.singletonMap("in", in));
+       sd.outputAll(Collections.singletonMap("in", in));
    }

    @Test
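
The test updates above (and in the files that follow) all apply the same API migration; a minimal before/after sketch, with placeholder and variable names that are illustrative only:

    // Old:  sd.exec(phMap, names);  sd.execAll(phMap);  sd.execAndEndResult();  sd.execBackwards(phMap);
    Map<String,INDArray> some = sd.output(phMap, "myOut");   // selected outputs by name
    Map<String,INDArray> all = sd.outputAll(phMap);          // all outputs
    INDArray single = myOutVar.eval();                       // one variable's value
    Map<String,INDArray> grads = sd.calculateGradients(phMap, sd.getVariables().keySet());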

@@ -117,8 +117,8 @@ public class ReductionOpValidation extends BaseOpValidation {
        SDVariable loss = nonZero.add(zero).castTo(DataType.DOUBLE).std(true);

        String error = OpValidation.validate(new TestCase(sd)
-               .expectedOutput(nonZero.getVarName(), Nd4j.scalar(DataType.LONG, i == 0 ? 2.0 : 4.0))
-               .expectedOutput(zero.getVarName(), Nd4j.scalar(DataType.LONG, i == 0 ? 2.0 : 0.0))
+               .expectedOutput(nonZero.name(), Nd4j.scalar(DataType.LONG, i == 0 ? 2.0 : 4.0))
+               .expectedOutput(zero.name(), Nd4j.scalar(DataType.LONG, i == 0 ? 2.0 : 0.0))
                .gradientCheck(false)
        );
        if (error != null)
@@ -148,7 +148,7 @@ public class ReductionOpValidation extends BaseOpValidation {
        SDVariable zeroFraction = sd.math().zeroFraction(input);

        String error = OpValidation.validate(new TestCase(sd)
-               .expectedOutput(zeroFraction.getVarName(), Nd4j.scalar(i == 0 ? 0.5f : 0.0f))
+               .expectedOutput(zeroFraction.name(), Nd4j.scalar(i == 0 ? 0.5f : 0.0f))
                .gradientCheck(i != 0)
        );
        if (error != null)
@@ -429,7 +429,7 @@ public class ReductionOpValidation extends BaseOpValidation {
        tc.gradientCheck(gradientCheckable);
        if(exp != null){
-           tc.expectedOutput(loss.getVarName(), exp);
+           tc.expectedOutput(loss.name(), exp);
        }

        String error = OpValidation.validate(tc);
@@ -996,7 +996,7 @@ public class ReductionOpValidation extends BaseOpValidation {
        String msg = name + " - dims=" + Arrays.toString(reduceDims);

-       INDArray out = sd.execAndEndResult();
+       INDArray out = reduced.eval();

        log.info(msg + " - expected shape: " + Arrays.toString(expShape) + ", out=" + Arrays.toString(out.shape())
                + ", outExp=" + Arrays.toString(expOut.shape()));
@@ -1069,10 +1069,10 @@ public class ReductionOpValidation extends BaseOpValidation {
        sd.associateArrayWithVariable(inputArr, input);
        sd.associateArrayWithVariable(labelArr, label);

-       INDArray result = sd.execAndEndResult();
+       INDArray result = loss.eval();
        assertEquals(1, result.length());

-       sd.execBackwards(Collections.emptyMap());
+       sd.calculateGradients(Collections.emptyMap(), sd.getVariables().keySet());
    }
}

@@ -76,11 +76,11 @@ public class RnnOpValidation extends BaseOpValidation {
        LSTMCellOutputs v = sd.rnn().lstmCell(x, cLast, yLast, weights, conf);  //Output order: i, c, f, o, z, h, y
        List<String> toExec = new ArrayList<>();
        for(SDVariable sdv : v.getAllOutputs()){
-           toExec.add(sdv.getVarName());
+           toExec.add(sdv.name());
        }

        //Test forward pass:
-       Map<String,INDArray> m = sd.exec(null, toExec);
+       Map<String,INDArray> m = sd.output(null, toExec);

        //Weights and bias order: [i, f, z, o]
@@ -179,11 +179,11 @@ public class RnnOpValidation extends BaseOpValidation {
        LSTMCellOutputs v = sd.rnn().lstmCell(x, cLast, yLast, weights, conf);  //Output order: i, c, f, o, z, h, y
        List<String> toExec = new ArrayList<>();
        for(SDVariable sdv : v.getAllOutputs()){
-           toExec.add(sdv.getVarName());
+           toExec.add(sdv.name());
        }

        //Test forward pass:
-       Map<String,INDArray> m = sd.exec(null, toExec);
+       Map<String,INDArray> m = sd.output(null, toExec);

        INDArray out0 = Nd4j.create(new float[]{0.27817473f, 0.53092605f}, new int[]{1,2});     //Input mod gate
        INDArray out1 = Nd4j.create(new float[]{-0.18100877f, 0.19417824f}, new int[]{1,2});    //CS (pre tanh)
@@ -233,11 +233,11 @@ public class RnnOpValidation extends BaseOpValidation {
        List<SDVariable> v = sd.rnn().gru("gru", x, hLast, weights).getAllOutputs();
        List<String> toExec = new ArrayList<>();
        for(SDVariable sdv : v){
-           toExec.add(sdv.getVarName());
+           toExec.add(sdv.name());
        }

        //Test forward pass:
-       Map<String,INDArray> m = sd.exec(null, toExec);
+       Map<String,INDArray> m = sd.output(null, toExec);

        //Weights and bias order: [r, u], [c]

@@ -128,7 +128,7 @@ public class ShapeOpValidation extends BaseOpValidation {
        //Using stdev here: mean/sum would backprop the same gradient for each input...
        SDVariable stdev = sd.standardDeviation("out", reshape, true);

-       INDArray out = sd.execAndEndResult();
+       INDArray out = stdev.eval();
        INDArray expOut = in.getArr().std(true, Integer.MAX_VALUE);

        String msg = "toShape=" + Arrays.toString(toShape) + ", order=" + order;
@@ -247,7 +247,7 @@ public class ShapeOpValidation extends BaseOpValidation {
        Map<String,INDArray> m = sd.outputAll(null);
        INDArray expOut = in.getArr().std(true);

-       assertArrayEquals(expExpandShape, m.get(expand.getVarName()).shape());
+       assertArrayEquals(expExpandShape, m.get(expand.name()).shape());
        INDArray expExpand = inArr.dup('c').reshape(expExpandShape);

        String msg = "expandDim=" + i + ", source=" + p.getSecond();
@@ -256,7 +256,7 @@ public class ShapeOpValidation extends BaseOpValidation {
        TestCase tc = new TestCase(sd);
        tc.testName(msg)
                .expectedOutput("out", expOut)
-               .expectedOutput(expand.getVarName(), expExpand);
+               .expectedOutput(expand.name(), expExpand);

        String error = OpValidation.validate(tc);
        if(error != null){
@@ -306,17 +306,17 @@ public class ShapeOpValidation extends BaseOpValidation {
        Map<String,INDArray> m = sd.outputAll(null);

-       INDArray squeezed = m.get(squeeze.getVarName());
+       INDArray squeezed = m.get(squeeze.name());
        // assertArrayEquals(expShapePostSqueeze, squeezed.shape());

-       INDArray out = sd.execAndEndResult();
+       INDArray out = m.get(stdev.name());

        INDArray expOut = in.getArr().std(true, Integer.MAX_VALUE);
        assertEquals(expOut, out);

        String msg = "squeezeDim=" + i + ", source=" + p.getSecond();
        TestCase tc = new TestCase(sd)
                .testName(msg)
-               .expected(squeeze.getVarName(), exp)
+               .expected(squeeze.name(), exp)
                .expectedOutput("out", expOut);
@@ -618,7 +618,7 @@ public class ShapeOpValidation extends BaseOpValidation {
        SDVariable stack = sd.stack(axis, in);

-       INDArray out = sd.execAndEndResult();
+       INDArray out = stack.eval();
        assertArrayEquals(expOutShape, out.shape());

        if (ArrayUtil.prodLong(shape) == 1) {
@@ -714,7 +714,7 @@ public class ShapeOpValidation extends BaseOpValidation {
        Map<String,INDArray> m = sd.outputAll(null);
        for (SDVariable v : unstacked) {
-           assertArrayEquals(msg, shape, m.get(v.getVarName()).shape());
+           assertArrayEquals(msg, shape, m.get(v.name()).shape());
        }

        TestCase tc = new TestCase(sd).testName(msg);
@@ -884,7 +884,7 @@ public class ShapeOpValidation extends BaseOpValidation {
        INDArray exp = arr.dup('c').reshape('c', 4,3);

        String err = OpValidation.validate(new TestCase(sameDiff)
-               .expectedOutput(result1.getVarName(), exp));
+               .expectedOutput(result1.name(), exp));
        assertNull(err);
    }
@@ -920,7 +920,7 @@ public class ShapeOpValidation extends BaseOpValidation {
        SDVariable result = sameDiff.transpose(x);
        SDVariable loss = sameDiff.standardDeviation(result, true);

-       String err = OpValidation.validate(new TestCase(sameDiff).expectedOutput(result.getVarName(), arr.transpose()));
+       String err = OpValidation.validate(new TestCase(sameDiff).expectedOutput(result.name(), arr.transpose()));
        assertNull(err);
    }
@@ -1022,17 +1022,16 @@ public class ShapeOpValidation extends BaseOpValidation {
        SameDiff sd = SameDiff.create();
        INDArray ia = Nd4j.create(new double[]{1,2,3});
        SDVariable in = sd.var(ia);
-       SDVariable constant = sd.constant(in, 3);
-       SDVariable loss = constant.std(true);
+       SDVariable loss = in.std(true);

-       assertNull(OpValidation.validate(new TestCase(sd).expected(constant, ia)));
+       assertNull(OpValidation.validate(new TestCase(sd).expected(in, ia)));

        //Case 1: shape is provided + scalar
        sd = SameDiff.create();
        ia = Nd4j.scalar(3.0);
        in = sd.var(ia);
-       constant = sd.constant(in, 3,4,5);
+       SDVariable constant = sd.constant(Nd4j.create(DataType.FLOAT, 3,4,5));
        INDArray exp = Nd4j.valueArrayOf(new long[]{3,4,5}, 3.0);
        loss = constant.std(true);
@@ -1149,7 +1148,7 @@ public class ShapeOpValidation extends BaseOpValidation {
        SDVariable loss = sameDiff.standardDeviation(result, true);

        String err = OpValidation.validate(new TestCase(sameDiff)
-               .expected(result.getVarName(), expected)
+               .expected(result.name(), expected)
                .gradientCheck(false));
        assertNull(err);
    }
@@ -1172,7 +1171,7 @@ public class ShapeOpValidation extends BaseOpValidation {
        INDArray outExp = Nd4j.scalar(d);

        String err = OpValidation.validate(new TestCase(sd)
-               .expected(md.getVarName(), outExp));
+               .expected(md.name(), outExp));
        assertNull(err);
    }
@@ -1196,7 +1195,7 @@ public class ShapeOpValidation extends BaseOpValidation {
        INDArray outExp = Nd4j.scalar(d);

        String err = OpValidation.validate(new TestCase(sd)
-               .expected(md.getVarName(), outExp));
+               .expected(md.name(), outExp));
        assertNull(err);
    }
@@ -1227,7 +1226,7 @@ public class ShapeOpValidation extends BaseOpValidation {
        INDArray outExp = Nd4j.scalar(d);

        String err = OpValidation.validate(new TestCase(sd)
-               .expected(md.getVarName(), outExp));
+               .expected(md.name(), outExp));
        assertNull(err);
    }
@@ -1247,7 +1246,7 @@ public class ShapeOpValidation extends BaseOpValidation {
        //System.out.println(d);

        String err = OpValidation.validate(new TestCase(sd)
-               .expected(md.getVarName(), Nd4j.scalar(d)));
+               .expected(md.name(), Nd4j.scalar(d)));
        assertNull(err);
    }
@@ -1332,7 +1331,7 @@ public class ShapeOpValidation extends BaseOpValidation {
                .testName(op)
                .expected(sm, exp)
                .gradientCheck(true)
-               .gradCheckSkipVariables(segments.getVarName());
+               .gradCheckSkipVariables(segments.name());

        String err = OpValidation.validate(tc);
        if(err != null)
@@ -1383,7 +1382,7 @@ public class ShapeOpValidation extends BaseOpValidation {
        String err = OpValidation.validate(new TestCase(sameDiff)
                .expected(result1, expected)
-               .gradCheckSkipVariables(lengths.getVarName()));
+               .gradCheckSkipVariables(lengths.name()));
        assertNull(err);

        // Test with dynamic maxlen
@@ -1591,8 +1590,8 @@ public class ShapeOpValidation extends BaseOpValidation {
        INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 6, 6).reshape(2, 3));
        SDVariable x = sameDiff.var("x", arr);
        SDVariable result = sameDiff.permute(x, 1, 0);
-       sameDiff.execAll(null);
+       Map<String,INDArray> m = sameDiff.outputAll(null);

-       assertArrayEquals(new long[]{3, 2}, result.getShape());
+       assertArrayEquals(new long[]{3, 2}, m.get(result.name()).shape());
    }
@@ -1629,10 +1628,10 @@ public class ShapeOpValidation extends BaseOpValidation {
        SDVariable slice_full = sd.slice(in, new int[]{0, 0}, new int[]{3, 4});
        SDVariable subPart = sd.slice(in, new int[]{1, 2}, new int[]{2, 2});

-       sd.exec(Collections.emptyMap(), sd.outputs());
+       Map<String,INDArray> m = sd.outputAll(Collections.emptyMap());

-       assertEquals(inArr, slice_full.getArr());
-       assertEquals(inArr.get(interval(1, 3), interval(2, 4)), subPart.getArr());
+       assertEquals(inArr, m.get(slice_full.name()));
+       assertEquals(inArr.get(interval(1, 3), interval(2, 4)), m.get(subPart.name()));
    }
@@ -1645,10 +1644,10 @@ public class ShapeOpValidation extends BaseOpValidation {
        SDVariable slice_full = sd.slice(in, new int[]{0, 0, 0}, new int[]{3, 4, 5});
        SDVariable subPart = sd.slice(in, new int[]{1, 2, 3}, new int[]{2, 2, 1});

-       sd.exec(Collections.emptyMap(), sd.outputs());
+       Map<String,INDArray> m = sd.outputAll(null);

-       assertEquals(inArr, slice_full.getArr());
-       assertEquals(inArr.get(interval(1, 3), interval(2, 4), interval(3, 4)), subPart.getArr());
+       assertEquals(inArr, m.get(slice_full.name()));
+       assertEquals(inArr.get(interval(1, 3), interval(2, 4), interval(3, 4)), m.get(subPart.name()));
    }

    @Test
@@ -1661,7 +1660,7 @@ public class ShapeOpValidation extends BaseOpValidation {
        SDVariable subPart = sd.stridedSlice(in, new int[]{1, 2}, new int[]{3, 4}, new int[]{1, 1});
        // SDVariable subPart2 = sd.stridedSlice(in, new int[]{0, 0}, new int[]{4, 5}, new int[]{2, 2});

-       sd.execAll(null);
+       sd.outputAll(null);

        assertEquals(inArr, slice_full.getArr());
        assertEquals(inArr.get(interval(1, 3), interval(2, 4)), subPart.getArr());
@@ -1678,7 +1677,7 @@ public class ShapeOpValidation extends BaseOpValidation {
        SDVariable slice1 = sd.stridedSlice(in, new int[]{-999, 0}, new int[]{2, 4}, new int[]{1, 1}, 1 << 1, 0, 0, 0, 0);
        SDVariable slice2 = sd.stridedSlice(in, new int[]{1, 0}, new int[]{-999, 4}, new int[]{1, 1}, 0, 1, 0, 0, 0);

-       sd.execAll(null);
+       sd.outputAll(null);

        assertEquals(inArr.get(NDArrayIndex.interval(0, 2), NDArrayIndex.all()), slice1.getArr());
        assertEquals(inArr.get(NDArrayIndex.interval(1, 3), NDArrayIndex.all()), slice2.getArr());
@@ -1695,7 +1694,7 @@ public class ShapeOpValidation extends BaseOpValidation {
        //[1:3,...,1:4] -> [1:3,:,1:4]
        SDVariable slice2 = sd.stridedSlice(in, new int[]{1, 1}, new int[]{3, 4}, new int[]{1, 1}, 0, 0, 1 << 1, 0, 0);

-       sd.execAll(Collections.emptyMap());
+       sd.outputAll(Collections.emptyMap());

        assertEquals(inArr.get(interval(1, 3), all(), all()), slice.getArr());
        assertEquals(inArr.get(interval(1, 3), all(), all()), slice2.getArr());
@@ -1708,7 +1707,7 @@ public class ShapeOpValidation extends BaseOpValidation {
        SDVariable in = sd.var("in", inArr);
        SDVariable slice = sd.stridedSlice(in, new int[]{-999, 0, 0, 0}, new int[]{-999, 3, 4, 5}, new int[]{-999, 1, 1, 1}, 0, 0, 0, 1, 0);

-       INDArray out = sd.execAndEndResult();
+       INDArray out = slice.eval();

        assertArrayEquals(new long[]{1, 3, 4, 5}, out.shape());
        assertEquals(inArr, out.get(point(0), all(), all(), all()));
@@ -1720,7 +1719,7 @@ public class ShapeOpValidation extends BaseOpValidation {
        SameDiff sd = SameDiff.create();
        SDVariable in = sd.var("in", inArr);
        SDVariable slice = sd.stridedSlice(in, new int[]{1, 1, -999, 1}, new int[]{3, 3, -999, 4}, new int[]{1, 1, -999, 1}, 0, 0, 0, 1 << 2, 0);

-       INDArray out = sd.execAndEndResult();
+       INDArray out = slice.eval();

        assertArrayEquals(new long[]{2, 2, 1, 3}, slice.getArr().shape());
    }
@@ -1735,7 +1734,7 @@ public class ShapeOpValidation extends BaseOpValidation {
        SDVariable slice2 = sd.stridedSlice(in, new int[]{2, 0, 0}, new int[]{-999, 4, 5}, new int[]{1, 1, 1}, 0, 0, 0, 0, 1);
        SDVariable slice3 = sd.stridedSlice(in, new int[]{1, 2, 1}, new int[]{-999, -999, 5}, new int[]{1, 1, 1}, 0, 0, 0, 0, 1 | 1 << 1);

-       sd.execAll(null);
+       sd.outputAll(null);

        assertEquals(inArr.get(point(0), all(), all()), slice.getArr());
        assertEquals(inArr.get(point(2), all(), all()), slice2.getArr());
@@ -1880,8 +1879,8 @@ public class ShapeOpValidation extends BaseOpValidation {
        // log.info(sd.summary());

-       sd.exec(Collections.emptyMap(), Lists.newArrayList(s));
-       sd.execBackwards(Collections.emptyMap());
+       sd.output(Collections.emptyMap(), Lists.newArrayList(s));
+       sd.calculateGradients(Collections.emptyMap(), sd.getVariables().keySet());
    }
}
@@ -2405,8 +2404,8 @@ public class ShapeOpValidation extends BaseOpValidation {
        SDVariable gathered = sd.gather(input, indices, 1);
        SDVariable loss = gathered.std(true);

-       sd.exec(null, gathered.getVarName());
-       sd.setLossVariables(gathered.getVarName());
+       sd.output((Map<String,INDArray>)null, gathered.name());
+       sd.setLossVariables(gathered.name());

        String err = OpValidation.validate(new TestCase(sd)
                .gradCheckEpsilon(1e-3)

@@ -115,37 +115,37 @@ public class TransformOpValidation extends BaseOpValidation {
            switch (i){
                case 0:
                    out = in.mul(2);
-                   tc.expectedOutput(out.getVarName(), inArr.mul(2));
+                   tc.expectedOutput(out.name(), inArr.mul(2));
                    msg = "mul - " + inOrder;
                    break;
                case 1:
                    out = in.div(2);
-                   tc.expectedOutput(out.getVarName(), inArr.div(2));
+                   tc.expectedOutput(out.name(), inArr.div(2));
                    msg = "div - " + inOrder;
                    break;
                case 2:
                    out = in.add(2);
-                   tc.expectedOutput(out.getVarName(), inArr.add(2));
+                   tc.expectedOutput(out.name(), inArr.add(2));
                    msg = "add - " + inOrder;
                    break;
                case 3:
                    out = in.sub(2);
-                   tc.expectedOutput(out.getVarName(), inArr.sub(2));
+                   tc.expectedOutput(out.name(), inArr.sub(2));
                    msg = "sub - " + inOrder;
                    break;
                case 4:
                    out = in.rdiv(2);
-                   tc.expectedOutput(out.getVarName(), inArr.rdiv(2));
+                   tc.expectedOutput(out.name(), inArr.rdiv(2));
                    msg = "rdiv - " + inOrder;
                    break;
                case 5:
                    out = in.rsub(2);
-                   tc.expectedOutput(out.getVarName(), inArr.rsub(2));
+                   tc.expectedOutput(out.name(), inArr.rsub(2));
                    msg = "rsub - " + inOrder;
                    break;
                case 6:
                    out = sd.math().pow(in,2);
-                   tc.expectedOutput(out.getVarName(), Transforms.pow(inArr, 2));
+                   tc.expectedOutput(out.name(), Transforms.pow(inArr, 2));
                    msg = "pow - " + inOrder;
                    break;
                case 7:
@@ -584,219 +584,219 @@ public class TransformOpValidation extends BaseOpValidation {
            switch (i) {
                case 0:
                    t = in.add(5.0);
-                   tc.expectedOutput(t.getVarName(), ia.add(5.0));
+                   tc.expectedOutput(t.name(), ia.add(5.0));
                    break;
                case 1:
                    t = in.sub(5.0);
-                   tc.expectedOutput(t.getVarName(), ia.sub(5.0));
+                   tc.expectedOutput(t.name(), ia.sub(5.0));
                    break;
                case 2:
                    t = in.mul(2.5);
-                   tc.expectedOutput(t.getVarName(), ia.mul(2.5));
+                   tc.expectedOutput(t.name(), ia.mul(2.5));
                    break;
                case 3:
                    t = in.div(4.0);
-                   tc.expectedOutput(t.getVarName(), ia.div(4.0));
+                   tc.expectedOutput(t.name(), ia.div(4.0));
                    break;
                case 4:
                    t = in.rsub(5.0);
-                   tc.expectedOutput(t.getVarName(), ia.rsub(5.0));
+                   tc.expectedOutput(t.name(), ia.rsub(5.0));
                    break;
                case 5:
                    t = in.rdiv(1.0);
-                   tc.expectedOutput(t.getVarName(), ia.rdiv(1.0));
+                   tc.expectedOutput(t.name(), ia.rdiv(1.0));
                    break;
                case 6:
                    t = sd.math().pow(in, 2.5);
                    ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut);
-                   tc.expectedOutput(t.getVarName(), Transforms.pow(ia, 2.5, true));
+                   tc.expectedOutput(t.name(), Transforms.pow(ia, 2.5, true));
                    break;
                case 7:
                    t = sd.nn().sigmoid(in);
                    ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut).muli(2).subi(1.0);
-                   tc.expectedOutput(t.getVarName(), Transforms.sigmoid(ia, true));
+                   tc.expectedOutput(t.name(), Transforms.sigmoid(ia, true));
                    break;
                case 8:
                    t = sd.math().tanh(in);
                    ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut).muli(2).subi(1.0);
-                   tc.expectedOutput(t.getVarName(), Transforms.tanh(ia, true));
+                   tc.expectedOutput(t.name(), Transforms.tanh(ia, true));
                    break;
                case 9:
                    ia.assign(Nd4j.rand(DataType.DOUBLE, ia.shape()));
                    t = sd.math().tan(in);
                    ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut);
-                   tc.expectedOutput(t.getVarName(), Transforms.tan(ia));
+                   tc.expectedOutput(t.name(), Transforms.tan(ia));
                    break;
                case 10:
                    t = sd.math().cos(in);
-                   tc.expectedOutput(t.getVarName(), Transforms.cos(ia, true));
+                   tc.expectedOutput(t.name(), Transforms.cos(ia, true));
                    break;
                case 11:
                    t = sd.math().sin(in);
-                   tc.expectedOutput(t.getVarName(), Transforms.sin(ia, true));
+                   tc.expectedOutput(t.name(), Transforms.sin(ia, true));
                    break;
                case 12:
                    t = sd.nn().softplus(in);
-                   tc.expectedOutput(t.getVarName(), Transforms.softPlus(ia, true));
+                   tc.expectedOutput(t.name(), Transforms.softPlus(ia, true));
                    break;
                case 13:
                    t = sd.math().log(in);
                    ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut);
-                   tc.expectedOutput(t.getVarName(), Transforms.log(ia, true));
+                   tc.expectedOutput(t.name(), Transforms.log(ia, true));
                    break;
                case 14:
                    t = sd.math().neg(in);
                    INDArray exp14 = ia.neg();
-                   tc.expectedOutput(t.getVarName(), exp14);
+                   tc.expectedOutput(t.name(), exp14);
                    break;
                case 15:
                    t = sd.math().acos(in);
                    ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut).muli(1.8).subi(0.9);
-                   tc.expectedOutput(t.getVarName(), Transforms.acos(ia, true));
+                   tc.expectedOutput(t.name(), Transforms.acos(ia, true));
                    break;
                case 16:
                    t = sd.math().acosh(in);
                    ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut).addi(1.01); //Only defined for x >= 1
-                   tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new ACosh(ia.dup())));
+                   tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new ACosh(ia.dup())));
                    break;
                case 17:
                    t = sd.math().asin(in);
                    ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut).muli(1.8).subi(0.9);
-                   tc.expectedOutput(t.getVarName(), Transforms.asin(ia, true));
+                   tc.expectedOutput(t.name(), Transforms.asin(ia, true));
                    break;
                case 18:
                    t = sd.math().atan(in);
                    ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut).muli(4).subi(2);
-                   tc.expectedOutput(t.getVarName(), Transforms.atan(ia, true));
+                   tc.expectedOutput(t.name(), Transforms.atan(ia, true));
                    break;
                case 19:
                    t = sd.math().atanh(in);
                    ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut).muli(1.8).subi(0.9);
-                   tc.expectedOutput(t.getVarName(), Transforms.atanh(ia, true));
+                   tc.expectedOutput(t.name(), Transforms.atanh(ia, true));
                    break;
                case 20:
                    t = sd.math().cosh(in);
-                   tc.expectedOutput(t.getVarName(), Transforms.cosh(ia, true));
+                   tc.expectedOutput(t.name(), Transforms.cosh(ia, true));
                    break;
                case 21:
                    t = sd.math().cube(in);
-                   tc.expectedOutput(t.getVarName(), Transforms.pow(ia, 3.0, true));
+                   tc.expectedOutput(t.name(), Transforms.pow(ia, 3.0, true));
                    break;
                case 22:
                    t = sd.nn().elu(in);
-                   tc.expectedOutput(t.getVarName(), Transforms.elu(ia, true));
+                   tc.expectedOutput(t.name(), Transforms.elu(ia, true));
                    break;
                case 23:
                    //TODO SHOULDN'T THIS HAVE A DIMENSION ARG???
                    t = sd.nn().softmax(in);
                    ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut);
-                   tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new SoftMax(ia.dup()))[0]);
+                   tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new SoftMax(ia.dup()))[0]);
                    break;
                case 24:
                    t = sd.math().sqrt(in);
                    ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut);
-                   tc.expectedOutput(t.getVarName(), Transforms.sqrt(ia, true));
+                   tc.expectedOutput(t.name(), Transforms.sqrt(ia, true));
                    break;
                case 25:
                    t = sd.math().square(in);
-                   tc.expectedOutput(t.getVarName(), Transforms.pow(ia, 2.0, true));
+                   tc.expectedOutput(t.name(), Transforms.pow(ia, 2.0, true));
                    break;
                case 26:
                    t = sd.transpose(in);
-                   tc.expectedOutput(t.getVarName(), ia.transpose().dup());
+                   tc.expectedOutput(t.name(), ia.transpose().dup());
                    break;
                case 27:
                    t = sd.math().abs(in);
-                   tc.expectedOutput(t.getVarName(), Transforms.abs(ia, true));
+                   tc.expectedOutput(t.name(), Transforms.abs(ia, true));
                    break;
                case 28:
                    t = sd.math().sinh(in);
-                   tc.expectedOutput(t.getVarName(), Transforms.sinh(ia, true));
+                   tc.expectedOutput(t.name(), Transforms.sinh(ia, true));
                    break;
                case 29:
                    t = sd.math().asinh(in);
-                   tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new ASinh(ia.dup())));
+                   tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new ASinh(ia.dup())));
                    break;
                case 30:
                    t = sd.math().exp(in);
-                   tc.expectedOutput(t.getVarName(), Transforms.exp(ia, true));
+                   tc.expectedOutput(t.name(), Transforms.exp(ia, true));
                    break;
                case 31:
                    t = sd.math().floor(in);
-                   tc.expectedOutput(t.getVarName(), Transforms.floor(ia, true));
+                   tc.expectedOutput(t.name(), Transforms.floor(ia, true));
                    break;
                case 32:
                    t = sd.nn().relu(in, 0.0);
                    ia = Nd4j.rand(minibatch, nOut);
-                   tc.expectedOutput(t.getVarName(), Transforms.relu(ia, true));
+                   tc.expectedOutput(t.name(), Transforms.relu(ia, true));
                    break;
                case 33:
                    t = sd.nn().hardTanh(in);
                    ia = Nd4j.rand(minibatch, nOut).muli(2).subi(1.0);
-                   tc.expectedOutput(t.getVarName(), Transforms.hardTanh(ia, true));
+                   tc.expectedOutput(t.name(), Transforms.hardTanh(ia, true));
                    break;
                case 34:
                    t = sd.nn().logSigmoid(in);
-                   tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new LogSigmoid(ia.dup())));
+                   tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new LogSigmoid(ia.dup())));
                    break;
                case 35:
                    t = sd.nn().swish(in);
-                   tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new Swish(ia.dup())));
+                   tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new Swish(ia.dup())));
                    break;
                case 36:
                    t = sd.math().sign(in);
-                   tc.expectedOutput(t.getVarName(), Transforms.sign(ia, true));
+                   tc.expectedOutput(t.name(), Transforms.sign(ia, true));
                    break;
                case 37:
                    t = sd.nn().softsign(in);
-                   tc.expectedOutput(t.getVarName(), Transforms.softsign(ia, true));
+                   tc.expectedOutput(t.name(), Transforms.softsign(ia, true));
                    break;
                case 38:
                    t = sd.nn().leakyRelu(in, 0.0);
                    ia = Nd4j.rand(minibatch, nOut);
-                   tc.expectedOutput(t.getVarName(), Transforms.leakyRelu(ia, true));
+                   tc.expectedOutput(t.name(), Transforms.leakyRelu(ia, true));
                    break;
                case 39:
                    if(OpValidationSuite.IGNORE_FAILING)
                        continue;
                    t = sd.nn().logSoftmax(in);
                    ia = Nd4j.rand(minibatch, nOut).muli(10).subi(5);
tc.expectedOutput(t.getVarName(), Transforms.log(Transforms.softmax(ia, true))); tc.expectedOutput(t.name(), Transforms.log(Transforms.softmax(ia, true)));
stdevLoss = true; stdevLoss = true;
break; break;
case 40: case 40:
t = sd.nn().selu(in); t = sd.nn().selu(in);
tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new SELU(ia.dup()))); tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new SELU(ia.dup())));
break; break;
case 41: case 41:
t = sd.gt(in, 1.0).castTo(DataType.DOUBLE); t = sd.gt(in, 1.0).castTo(DataType.DOUBLE);
tc.expectedOutput(t.getVarName(), ia.gt(1.0).castTo(DataType.DOUBLE)).gradientCheck(false); tc.expectedOutput(t.name(), ia.gt(1.0).castTo(DataType.DOUBLE)).gradientCheck(false);
break; break;
case 42: case 42:
t = sd.gte(in, 1.0).castTo(DataType.DOUBLE); t = sd.gte(in, 1.0).castTo(DataType.DOUBLE);
tc.expectedOutput(t.getVarName(), ia.gte(1.0).castTo(DataType.DOUBLE)).gradientCheck(false); tc.expectedOutput(t.name(), ia.gte(1.0).castTo(DataType.DOUBLE)).gradientCheck(false);
break; break;
case 43: case 43:
t = sd.lt(in, 1.0).castTo(DataType.DOUBLE); t = sd.lt(in, 1.0).castTo(DataType.DOUBLE);
tc.expectedOutput(t.getVarName(), ia.lt(1.0).castTo(DataType.DOUBLE)).gradientCheck(false); tc.expectedOutput(t.name(), ia.lt(1.0).castTo(DataType.DOUBLE)).gradientCheck(false);
break; break;
case 44: case 44:
t = sd.lte(in, 1.0).castTo(DataType.DOUBLE); t = sd.lte(in, 1.0).castTo(DataType.DOUBLE);
tc.expectedOutput(t.getVarName(), ia.lte(1.0).castTo(DataType.DOUBLE)).gradientCheck(false); tc.expectedOutput(t.name(), ia.lte(1.0).castTo(DataType.DOUBLE)).gradientCheck(false);
break; break;
case 45: case 45:
t = sd.eq(in, 2.0).castTo(DataType.DOUBLE); t = sd.eq(in, 2.0).castTo(DataType.DOUBLE);
ia = Nd4j.linspace(1, minibatch * nOut, minibatch * nOut, DataType.DOUBLE).reshape('c', minibatch, nOut); ia = Nd4j.linspace(1, minibatch * nOut, minibatch * nOut, DataType.DOUBLE).reshape('c', minibatch, nOut);
tc.expectedOutput(t.getVarName(), ia.eq(2.0).castTo(DataType.DOUBLE)).gradientCheck(false); tc.expectedOutput(t.name(), ia.eq(2.0).castTo(DataType.DOUBLE)).gradientCheck(false);
break; break;
case 46: case 46:
t = sd.neq(in, 2.0).castTo(DataType.DOUBLE); t = sd.neq(in, 2.0).castTo(DataType.DOUBLE);
ia = Nd4j.linspace(1, minibatch * nOut, minibatch * nOut, DataType.DOUBLE).reshape('c', minibatch, nOut); ia = Nd4j.linspace(1, minibatch * nOut, minibatch * nOut, DataType.DOUBLE).reshape('c', minibatch, nOut);
tc.expectedOutput(t.getVarName(), ia.neq(2.0).castTo(DataType.DOUBLE)).gradientCheck(false); tc.expectedOutput(t.name(), ia.neq(2.0).castTo(DataType.DOUBLE)).gradientCheck(false);
break; break;
case 47: case 47:
t = sd.math().ceil(in); t = sd.math().ceil(in);
tc.expectedOutput(t.getVarName(), Transforms.ceil(ia, true)); tc.expectedOutput(t.name(), Transforms.ceil(ia, true));
break; break;
case 48: case 48:
ia = Nd4j.randn(DataType.DOUBLE, ia.shape()).muli(2); ia = Nd4j.randn(DataType.DOUBLE, ia.shape()).muli(2);
@ -804,7 +804,7 @@ public class TransformOpValidation extends BaseOpValidation {
INDArray expOut48 = ia.dup(); INDArray expOut48 = ia.dup();
BooleanIndexing.replaceWhere(expOut48, -3, Conditions.lessThan(-3)); BooleanIndexing.replaceWhere(expOut48, -3, Conditions.lessThan(-3));
BooleanIndexing.replaceWhere(expOut48, 2, Conditions.greaterThan(2)); BooleanIndexing.replaceWhere(expOut48, 2, Conditions.greaterThan(2));
tc.expectedOutput(t.getVarName(), expOut48); tc.expectedOutput(t.name(), expOut48);
break; break;
case 49: case 49:
//Clip by norm, dimension 0, some below threshold, some above //Clip by norm, dimension 0, some below threshold, some above
@ -825,7 +825,7 @@ public class TransformOpValidation extends BaseOpValidation {
expOut49.putColumn(j, origCol.mul(clip / origCol.norm2Number().doubleValue())); expOut49.putColumn(j, origCol.mul(clip / origCol.norm2Number().doubleValue()));
} }
} }
tc.expectedOutput(t.getVarName(), expOut49); tc.expectedOutput(t.name(), expOut49);
//System.out.println(expOut.norm2(0)); //System.out.println(expOut.norm2(0));
break; break;
//TODO clip by norm along other dimensions //TODO clip by norm along other dimensions
@ -837,7 +837,7 @@ public class TransformOpValidation extends BaseOpValidation {
.addIntegerArguments(dim) .addIntegerArguments(dim)
.addInputs(ia).addOutputs(expOut50).build(); .addInputs(ia).addOutputs(expOut50).build();
Nd4j.getExecutioner().exec(reverse); Nd4j.getExecutioner().exec(reverse);
tc.expectedOutput(t.getVarName(), expOut50); tc.expectedOutput(t.name(), expOut50);
break; break;
case 51: case 51:
dim = 0; dim = 0;
@ -850,7 +850,7 @@ public class TransformOpValidation extends BaseOpValidation {
.addIntegerArguments((exclusive) ? 1 : 0, (reverseBool) ? 1 : 0, dim) .addIntegerArguments((exclusive) ? 1 : 0, (reverseBool) ? 1 : 0, dim)
.addInputs(ia).addOutputs(expOut51).build(); .addInputs(ia).addOutputs(expOut51).build();
Nd4j.getExecutioner().exec(cumsum); Nd4j.getExecutioner().exec(cumsum);
tc.expectedOutput(t.getVarName(), expOut51); tc.expectedOutput(t.name(), expOut51);
break; break;
case 52: case 52:
if(OpValidationSuite.IGNORE_FAILING){ if(OpValidationSuite.IGNORE_FAILING){
@ -869,7 +869,7 @@ public class TransformOpValidation extends BaseOpValidation {
expOut52.putScalar(s0, s1, prod); expOut52.putScalar(s0, s1, prod);
} }
} }
tc.expectedOutput(t.getVarName(), expOut52); tc.expectedOutput(t.name(), expOut52);
break; break;
case 53: case 53:
if(OpValidationSuite.IGNORE_FAILING){ if(OpValidationSuite.IGNORE_FAILING){
@ -881,90 +881,90 @@ public class TransformOpValidation extends BaseOpValidation {
INDArray expOut53 = Nd4j.create(DataType.DOUBLE, 2, 2); INDArray expOut53 = Nd4j.create(DataType.DOUBLE, 2, 2);
DynamicCustomOp op = DynamicCustomOp.builder("diag").addInputs(ia).addOutputs(expOut53).build(); DynamicCustomOp op = DynamicCustomOp.builder("diag").addInputs(ia).addOutputs(expOut53).build();
Nd4j.getExecutioner().exec(op); Nd4j.getExecutioner().exec(op);
tc.expectedOutput(t.getVarName(), expOut53); tc.expectedOutput(t.name(), expOut53);
break; break;
case 54: case 54:
t = sd.math().erf(in); t = sd.math().erf(in);
INDArray expOut54 = Nd4j.createUninitialized(DataType.DOUBLE, ia.shape(), ia.ordering()); INDArray expOut54 = Nd4j.createUninitialized(DataType.DOUBLE, ia.shape(), ia.ordering());
Nd4j.getExecutioner().exec(new Erf(ia, expOut54)); Nd4j.getExecutioner().exec(new Erf(ia, expOut54));
tc.expectedOutput(t.getVarName(), expOut54); tc.expectedOutput(t.name(), expOut54);
break; break;
case 55: case 55:
t = sd.math().erfc(in); t = sd.math().erfc(in);
tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new Erfc(ia, Nd4j.createUninitialized(ia.shape(), ia.ordering())))); tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new Erfc(ia, Nd4j.createUninitialized(ia.shape(), ia.ordering()))));
break; break;
case 56: case 56:
t = sd.math().expm1(in); t = sd.math().expm1(in);
tc.expectedOutput(t.getVarName(),Transforms.expm1(ia, true)); tc.expectedOutput(t.name(),Transforms.expm1(ia, true));
break; break;
case 57: case 57:
t = sd.math().log1p(in); t = sd.math().log1p(in);
ia = Nd4j.rand(minibatch, nOut); ia = Nd4j.rand(minibatch, nOut);
tc.expectedOutput(t.getVarName(), Transforms.log1p(ia, true)); tc.expectedOutput(t.name(), Transforms.log1p(ia, true));
break; break;
case 58: case 58:
t = sd.math().round(in); t = sd.math().round(in);
tc.expectedOutput(t.getVarName(), Transforms.round(ia, true)); tc.expectedOutput(t.name(), Transforms.round(ia, true));
break; break;
case 59: case 59:
ia = Nd4j.create(new float[]{4, 2}).castTo(DataType.DOUBLE); ia = Nd4j.create(new float[]{4, 2}).castTo(DataType.DOUBLE);
// in = sd.var("in", new int[]{1, 2}); // in = sd.var("in", new int[]{1, 2});
t = sd.math().rsqrt(in); t = sd.math().rsqrt(in);
tc.expectedOutput(t.getVarName(),Nd4j.getExecutioner().exec(new RSqrt(ia, Nd4j.create(ia.shape(), ia.ordering())))); tc.expectedOutput(t.name(),Nd4j.getExecutioner().exec(new RSqrt(ia, Nd4j.create(ia.shape(), ia.ordering()))));
break; break;
case 60: case 60:
t = sd.nn().relu6(in, 0); t = sd.nn().relu6(in, 0);
ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut); ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut);
tc.expectedOutput(t.getVarName(),Transforms.relu6(ia, true)); tc.expectedOutput(t.name(),Transforms.relu6(ia, true));
break; break;
case 61: case 61:
ia = Nd4j.create(new float[] {2, 2}).castTo(DataType.DOUBLE); ia = Nd4j.create(new float[] {2, 2}).castTo(DataType.DOUBLE);
sd.associateArrayWithVariable(ia, in); sd.associateArrayWithVariable(ia, in);
double value = 42; double value = 42;
t = sd.fill(in.castTo(DataType.INT), DataType.DOUBLE, value); t = sd.fill(in.castTo(DataType.INT), DataType.DOUBLE, value);
tc.expectedOutput(t.getVarName(), Nd4j.valueArrayOf(new int[]{2,2}, 42)).gradientCheck(false); tc.expectedOutput(t.name(), Nd4j.valueArrayOf(new int[]{2,2}, 42)).gradientCheck(false);
opName = "fill"; opName = "fill";
break; break;
case 62: case 62:
t = sd.nn().hardSigmoid(in); t = sd.nn().hardSigmoid(in);
tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new HardSigmoid(ia, ia.dup()))); tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new HardSigmoid(ia, ia.dup())));
break; break;
case 63: case 63:
t = sd.scalarMax(in, 0.5); t = sd.scalarMax(in, 0.5);
tc.expectedOutput(t.getVarName(), Transforms.max(ia, 0.5, true)); tc.expectedOutput(t.name(), Transforms.max(ia, 0.5, true));
break; break;
case 64: case 64:
t = sd.scalarMin(in, 0.5); t = sd.scalarMin(in, 0.5);
tc.expectedOutput(t.getVarName(), Transforms.min(ia, 0.5, true)); tc.expectedOutput(t.name(), Transforms.min(ia, 0.5, true));
break; break;
case 65: case 65:
t = sd.assign(in, 0.5); t = sd.assign(in, 0.5);
tc.expectedOutput(t.getVarName(), ia.dup().assign(0.5)); tc.expectedOutput(t.name(), ia.dup().assign(0.5));
break; break;
case 66: case 66:
t = sd.scalarFloorMod(in, 0.5); t = sd.scalarFloorMod(in, 0.5);
tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new ScalarFMod(ia.dup(), 0.5))); tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new ScalarFMod(ia.dup(), 0.5)));
break; break;
case 67: case 67:
t = sd.math().reciprocal(in); t = sd.math().reciprocal(in);
tc.expectedOutput(t.getVarName(), ia.rdiv(1.0)); tc.expectedOutput(t.name(), ia.rdiv(1.0));
break; break;
case 68: case 68:
t = sd.shape(in).castTo(DataType.DOUBLE); t = sd.shape(in).castTo(DataType.DOUBLE);
tc.expectedOutput(t.getVarName(), Nd4j.create(ArrayUtil.toDouble(ia.shape()))).gradientCheck(false); tc.expectedOutput(t.name(), Nd4j.create(ArrayUtil.toDouble(ia.shape()))).gradientCheck(false);
break; break;
case 69: case 69:
t = sd.rank(in).castTo(DataType.DOUBLE); t = sd.rank(in).castTo(DataType.DOUBLE);
tc.expectedOutput(t.getVarName(), Nd4j.scalar((double)ia.rank())).gradientCheck(false); tc.expectedOutput(t.name(), Nd4j.scalar((double)ia.rank())).gradientCheck(false);
break; break;
case 70: case 70:
t = sd.onesLike(in); t = sd.onesLike(in);
tc.expectedOutput(t.getVarName(), Nd4j.ones(ia.shape())); tc.expectedOutput(t.name(), Nd4j.ones(ia.shape()));
break; break;
case 71: case 71:
ia = Nd4j.randn(DataType.DOUBLE, nOut, nOut); ia = Nd4j.randn(DataType.DOUBLE, nOut, nOut);
t = sd.math().diagPart(in); t = sd.math().diagPart(in);
tc.expectedOutput(t.getVarName(), Nd4j.create(new double[]{ia.getDouble(0,0), ia.getDouble(1,1), ia.getDouble(2,2), ia.getDouble(3,3)}).castTo(DataType.DOUBLE)); tc.expectedOutput(t.name(), Nd4j.create(new double[]{ia.getDouble(0,0), ia.getDouble(1,1), ia.getDouble(2,2), ia.getDouble(3,3)}).castTo(DataType.DOUBLE));
break; break;
case 72: case 72:
t = sd.identity(in); t = sd.identity(in);
@ -1087,109 +1087,109 @@ public class TransformOpValidation extends BaseOpValidation {
switch (i) { switch (i) {
case 0: case 0:
t = in1.add(in2); t = in1.add(in2);
tc.expectedOutput(t.getVarName(), ia.add(ib)); tc.expectedOutput(t.name(), ia.add(ib));
break; break;
case 1: case 1:
t = in1.sub(in2); t = in1.sub(in2);
tc.expectedOutput(t.getVarName(),ia.sub(ib)); tc.expectedOutput(t.name(),ia.sub(ib));
break; break;
case 2: case 2:
t = in1.mul(in2); t = in1.mul(in2);
tc.expectedOutput(t.getVarName(), ia.mul(ib)); tc.expectedOutput(t.name(), ia.mul(ib));
break; break;
case 3: case 3:
t = in1.div(in2); t = in1.div(in2);
tc.expectedOutput(t.getVarName(), ia.div(ib)); tc.expectedOutput(t.name(), ia.div(ib));
break; break;
case 4: case 4:
t = in1.rsub(in2); t = in1.rsub(in2);
tc.expectedOutput(t.getVarName(), ia.rsub(ib)); tc.expectedOutput(t.name(), ia.rsub(ib));
break; break;
case 5: case 5:
ia.assign(Nd4j.rand(ia.shape())).addi(0.5); ia.assign(Nd4j.rand(ia.shape())).addi(0.5);
ib.assign(Nd4j.rand(ib.shape())).addi(0.5); ib.assign(Nd4j.rand(ib.shape())).addi(0.5);
t = in1.rdiv(in2); t = in1.rdiv(in2);
tc.expectedOutput(t.getVarName(), ia.rdiv(ib)); tc.expectedOutput(t.name(), ia.rdiv(ib));
break; break;
case 6: case 6:
t = sd.eq(in1, in2); t = sd.eq(in1, in2);
opName = "eq"; opName = "eq";
tc.expectedOutput(t.getVarName(), ia.eq(ib)).gradientCheck(false); tc.expectedOutput(t.name(), ia.eq(ib)).gradientCheck(false);
break; break;
case 7: case 7:
t = sd.neq(in1, in2); t = sd.neq(in1, in2);
opName = "neq"; opName = "neq";
tc.expectedOutput(t.getVarName(), ia.neq(ib)).gradientCheck(false);; tc.expectedOutput(t.name(), ia.neq(ib)).gradientCheck(false);;
break; break;
case 8: case 8:
t = sd.gt(in1, in2); t = sd.gt(in1, in2);
opName = "gt"; opName = "gt";
tc.expectedOutput(t.getVarName(), ia.gt(ib)).gradientCheck(false); tc.expectedOutput(t.name(), ia.gt(ib)).gradientCheck(false);
break; break;
case 9: case 9:
t = sd.lt(in1, in2); t = sd.lt(in1, in2);
opName = "lt"; opName = "lt";
tc.expectedOutput(t.getVarName(), ia.lt(ib)).gradientCheck(false); tc.expectedOutput(t.name(), ia.lt(ib)).gradientCheck(false);
break; break;
case 10: case 10:
t = sd.gte(in1, in2); t = sd.gte(in1, in2);
opName = "gte"; opName = "gte";
INDArray expOut10 = Nd4j.create(DataType.BOOL, ia.shape()); INDArray expOut10 = Nd4j.create(DataType.BOOL, ia.shape());
Nd4j.getExecutioner().exec(new GreaterThanOrEqual(new INDArray[]{ia, ib}, new INDArray[]{expOut10})); Nd4j.getExecutioner().exec(new GreaterThanOrEqual(new INDArray[]{ia, ib}, new INDArray[]{expOut10}));
tc.expectedOutput(t.getVarName(), expOut10).gradientCheck(false); tc.expectedOutput(t.name(), expOut10).gradientCheck(false);
break; break;
case 11: case 11:
t = sd.lte(in1, in2); t = sd.lte(in1, in2);
opName = "lte"; opName = "lte";
INDArray expOut11 = Nd4j.create(DataType.BOOL, ia.shape()); INDArray expOut11 = Nd4j.create(DataType.BOOL, ia.shape());
Nd4j.getExecutioner().exec(new LessThanOrEqual(new INDArray[]{ia, ib}, new INDArray[]{expOut11})); Nd4j.getExecutioner().exec(new LessThanOrEqual(new INDArray[]{ia, ib}, new INDArray[]{expOut11}));
tc.expectedOutput(t.getVarName(), expOut11).gradientCheck(false); tc.expectedOutput(t.name(), expOut11).gradientCheck(false);
break; break;
case 12: case 12:
ia = Nd4j.getExecutioner().exec(new BernoulliDistribution(ia, 0.5)); ia = Nd4j.getExecutioner().exec(new BernoulliDistribution(ia, 0.5));
ib = Nd4j.getExecutioner().exec(new BernoulliDistribution(ib, 0.5)); ib = Nd4j.getExecutioner().exec(new BernoulliDistribution(ib, 0.5));
t = sd.math().or(in1.castTo(DataType.BOOL), in2.castTo(DataType.BOOL)); t = sd.math().or(in1.castTo(DataType.BOOL), in2.castTo(DataType.BOOL));
opName = "or"; opName = "or";
tc.expectedOutput(t.getVarName(), Transforms.or(ia.castTo(DataType.BOOL), ib.castTo(DataType.BOOL))).gradientCheck(false); tc.expectedOutput(t.name(), Transforms.or(ia.castTo(DataType.BOOL), ib.castTo(DataType.BOOL))).gradientCheck(false);
break; break;
case 13: case 13:
ib = Nd4j.randn(DataType.DOUBLE, nOut, nOut); ib = Nd4j.randn(DataType.DOUBLE, nOut, nOut);
t = sd.mmul(in1, in2); t = sd.mmul(in1, in2);
tc.expectedOutput(t.getVarName(), ia.mmul(ib)); tc.expectedOutput(t.name(), ia.mmul(ib));
break; break;
case 14: case 14:
t = sd.max(in1, in2); t = sd.max(in1, in2);
tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new Max(ia, ib, ia.dup()))[0]); tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new Max(ia, ib, ia.dup()))[0]);
break; break;
case 15: case 15:
t = sd.min(in1, in2); t = sd.min(in1, in2);
tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new Min(ia, ib, ia.dup()))[0]); tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new Min(ia, ib, ia.dup()))[0]);
break; break;
case 16: case 16:
ia = Nd4j.getExecutioner().exec(new BernoulliDistribution(ia, 0.5)); ia = Nd4j.getExecutioner().exec(new BernoulliDistribution(ia, 0.5));
ib = Nd4j.getExecutioner().exec(new BernoulliDistribution(ib, 0.5)); ib = Nd4j.getExecutioner().exec(new BernoulliDistribution(ib, 0.5));
t = sd.math().and(in1.castTo(DataType.BOOL), in2.castTo(DataType.BOOL)); t = sd.math().and(in1.castTo(DataType.BOOL), in2.castTo(DataType.BOOL));
opName = "and"; opName = "and";
tc.expectedOutput(t.getVarName(), Transforms.and(ia.castTo(DataType.BOOL), ib.castTo(DataType.BOOL))).gradientCheck(false); tc.expectedOutput(t.name(), Transforms.and(ia.castTo(DataType.BOOL), ib.castTo(DataType.BOOL))).gradientCheck(false);
break; break;
case 17: case 17:
ia = Nd4j.getExecutioner().exec(new BernoulliDistribution(ia, 0.5)); ia = Nd4j.getExecutioner().exec(new BernoulliDistribution(ia, 0.5));
ib = Nd4j.getExecutioner().exec(new BernoulliDistribution(ib, 0.5)); ib = Nd4j.getExecutioner().exec(new BernoulliDistribution(ib, 0.5));
t = sd.math().xor(in1.castTo(DataType.BOOL), in2.castTo(DataType.BOOL)); t = sd.math().xor(in1.castTo(DataType.BOOL), in2.castTo(DataType.BOOL));
opName = "xor"; opName = "xor";
tc.expectedOutput(t.getVarName(), Transforms.xor(ia.castTo(DataType.BOOL), ib.castTo(DataType.BOOL))).gradientCheck(false); tc.expectedOutput(t.name(), Transforms.xor(ia.castTo(DataType.BOOL), ib.castTo(DataType.BOOL))).gradientCheck(false);
break; break;
case 18: case 18:
t = sd.assign(in1, in2); t = sd.assign(in1, in2);
tc.expectedOutput(t.getVarName(), ib); tc.expectedOutput(t.name(), ib);
break; break;
case 19: case 19:
t = sd.math().atan2(in1, in2); t = sd.math().atan2(in1, in2);
tc.expectedOutput(t.getVarName(), Transforms.atan2(ib, ia)); //Note: y,x order for samediff; x,y order for transforms tc.expectedOutput(t.name(), Transforms.atan2(ib, ia)); //Note: y,x order for samediff; x,y order for transforms
break; break;
case 20: case 20:
t = sd.math().mergeAdd(in1, in2, in2); t = sd.math().mergeAdd(in1, in2, in2);
tc.expectedOutput(t.getVarName(), ia.add(ib).add(ib)); tc.expectedOutput(t.name(), ia.add(ib).add(ib));
break; break;
case 21: case 21:
t = in1.squaredDifference(in2); t = in1.squaredDifference(in2);
@ -1199,7 +1199,7 @@ public class TransformOpValidation extends BaseOpValidation {
.addOutputs(expOut21) .addOutputs(expOut21)
.build(); .build();
Nd4j.getExecutioner().exec(squareDiff); Nd4j.getExecutioner().exec(squareDiff);
tc.expectedOutput(t.getVarName(), expOut21); tc.expectedOutput(t.name(), expOut21);
break; break;
case 22: case 22:
//set diag //set diag
@ -1210,7 +1210,7 @@ public class TransformOpValidation extends BaseOpValidation {
expOut22.putScalar(j,j, ib.getDouble(j)); expOut22.putScalar(j,j, ib.getDouble(j));
} }
t = sd.math().setDiag(in1, in2); t = sd.math().setDiag(in1, in2);
tc.expectedOutput(t.getVarName(), expOut22); tc.expectedOutput(t.name(), expOut22);
break; break;
default: default:
throw new RuntimeException(); throw new RuntimeException();
@ -1341,7 +1341,6 @@ public class TransformOpValidation extends BaseOpValidation {
} }
} }
//TODO UPDATE TO OP VALIDATION OR DELETE
@Test @Test
public void testLogGrad() { public void testLogGrad() {
SameDiff sameDiff = SameDiff.create(); SameDiff sameDiff = SameDiff.create();
@ -1349,7 +1348,7 @@ public class TransformOpValidation extends BaseOpValidation {
SDVariable log = sameDiff.math().log(input); SDVariable log = sameDiff.math().log(input);
SDVariable sum = sameDiff.sum(log, Integer.MAX_VALUE); SDVariable sum = sameDiff.sum(log, Integer.MAX_VALUE);
INDArray result = null; INDArray result = null;
sameDiff.execBackwards(Collections.emptyMap()); sameDiff.calculateGradients(Collections.emptyMap(), sameDiff.getVariables().keySet());
} }
@ -1362,8 +1361,8 @@ public class TransformOpValidation extends BaseOpValidation {
SDVariable input = sameDiff.var("x", inputs.get("x")); SDVariable input = sameDiff.var("x", inputs.get("x"));
SDVariable sigmoid = sameDiff.nn().sigmoid(input); SDVariable sigmoid = sameDiff.nn().sigmoid(input);
SDVariable sum = sameDiff.sum(sigmoid, Integer.MAX_VALUE); SDVariable sum = sameDiff.sum(sigmoid, Integer.MAX_VALUE);
sameDiff.execBackwards(Collections.emptyMap()); Map<String,INDArray> m = sameDiff.calculateGradients(Collections.emptyMap(), sameDiff.getVariables().keySet());
INDArray arr = input.gradient().getArr(); INDArray arr = m.get(input.name());
assertTrue(Nd4j.create(new double[][]{ assertTrue(Nd4j.create(new double[][]{
{0.1966, 0.1050}, {0.1966, 0.1050},
{0.0452, 0.0177} {0.0452, 0.0177}
@ -1384,12 +1383,12 @@ public class TransformOpValidation extends BaseOpValidation {
public void testRank0EdgeCase(){ public void testRank0EdgeCase(){
SameDiff sd = SameDiff.create(); SameDiff sd = SameDiff.create();
SDVariable v1 = sd.sum(sd.var(Nd4j.create(new double[]{4, 4}))); SDVariable v1 = sd.sum(sd.var(Nd4j.create(new double[]{4, 4})));
double d0 = sd.execAndEndResult().getDouble(0); double d0 = v1.eval().getDouble(0);
assertEquals(8, d0, 0); assertEquals(8, d0, 0);
SDVariable v2 = sd.sum(sd.var(Nd4j.create(new double[]{4, 4}))).div(2.0); SDVariable v2 = sd.sum(sd.var(Nd4j.create(new double[]{4, 4}))).div(2.0);
sd.exec(Collections.emptyMap(), sd.outputs()); Map<String,INDArray> m = sd.outputAll(Collections.emptyMap());
double d1 = v2.getArr().getDouble(0); double d1 = m.get(v2.name()).getDouble(0);
assertEquals(4, d1, 0); assertEquals(4, d1, 0);
} }
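
The pattern above is the heart of this commit's API cleanup: SDVariable.getVarName() becomes SDVariable.name(), and the old exec/execAndEndResult entry points give way to eval/outputAll. A minimal forward-pass sketch (the graph and variable names here are illustrative, not from the tests):

    SameDiff sd = SameDiff.create();
    SDVariable x = sd.var(Nd4j.create(new double[]{4, 4}));
    SDVariable sum = sd.sum(x);                  // scalar: 8.0
    INDArray direct = sum.eval();                // evaluate a single variable directly
    Map<String, INDArray> all = sd.outputAll(Collections.emptyMap());  // or all outputs at once
    assertEquals(direct, all.get(sum.name()));   // outputs are keyed by variable name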


@@ -87,12 +87,12 @@ public class FailingSameDiffTests extends BaseNd4jTest {
SDVariable tanh = sd.math().tanh(in);
INDArray exp = Transforms.tanh(in.getArr(), true);
- INDArray out = sd.execAndEndResult();
+ INDArray out = tanh.eval();
assertEquals(exp, out);
//Now, replace with minibatch 5:
in.setArray(Nd4j.linspace(1,20,20, DataType.DOUBLE).reshape(5,4));
- INDArray out2 = sd.execAndEndResult();
+ INDArray out2 = tanh.eval();
assertArrayEquals(new long[]{5,4}, out2.shape());
exp = Transforms.tanh(in.getArr(), true);
@@ -124,12 +124,12 @@ public class FailingSameDiffTests extends BaseNd4jTest {
SDVariable mmul = sd.mmul(in,w).add(b);
INDArray exp = in.getArr().mmul(w.getArr()).addiRowVector(b.getArr());
- INDArray out = sd.execAndEndResult();
+ INDArray out = mmul.eval();
assertEquals(exp, out);
//Now, replace with minibatch 5:
in.setArray(Nd4j.linspace(1,20,20, DataType.DOUBLE).reshape(5,4));
- INDArray out2 = sd.execAndEndResult();
+ INDArray out2 = mmul.eval();
assertArrayEquals(new long[]{5,5}, out2.shape());
exp = in.getArr().mmul(w.getArr()).addiRowVector(b.getArr());
@@ -137,11 +137,10 @@ public class FailingSameDiffTests extends BaseNd4jTest {
//Generate gradient function, and exec
SDVariable loss = mmul.std(true);
- sd.execBackwards(Collections.emptyMap());
+ sd.calculateGradients(Collections.emptyMap(), sd.getVariables().keySet());
in.setArray(Nd4j.linspace(1,12,12, DataType.DOUBLE).reshape(3,4));
- sd.execAndEndResult();
- out2 = mmul.getArr();
+ out2 = mmul.eval();
assertArrayEquals(new long[]{3,5}, out2.shape());
}
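
Gradients follow the same pattern: execBackwards is gone, and calculateGradients both runs the backward pass and returns the requested gradient arrays as a map keyed by variable name. A sketch, assuming 'sd' already defines a scalar loss and a variable named "in":

    Map<String, INDArray> grads =
            sd.calculateGradients(Collections.emptyMap(), sd.getVariables().keySet());
    INDArray dLdIn = grads.get("in");   // gradient of the loss w.r.t. the variable named "in"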


@@ -173,7 +173,7 @@ public class FlatBufferSerdeTest extends BaseNd4jTest {
}
if(execFirst){
- sd.exec(Collections.singletonMap("in", arr), Collections.singletonList(x.getVarName()));
+ sd.output(Collections.singletonMap("in", arr), Collections.singletonList(x.name()));
}
File f = testDir.newFile();
@@ -186,7 +186,7 @@ public class FlatBufferSerdeTest extends BaseNd4jTest {
List<SDVariable> varsRestored = restored.variables();
assertEquals(varsOrig.size(), varsRestored.size());
for (int j = 0; j < varsOrig.size(); j++) {
- assertEquals(varsOrig.get(j).getVarName(), varsRestored.get(j).getVarName());
+ assertEquals(varsOrig.get(j).name(), varsRestored.get(j).name());
}
DifferentialFunction[] fOrig = sd.ops();
@@ -200,10 +200,10 @@ public class FlatBufferSerdeTest extends BaseNd4jTest {
assertEquals(sd.getLossVariables(), restored.getLossVariables());
- Map<String,INDArray> m = sd.exec(Collections.singletonMap("in", arr), Collections.singletonList(x.getVarName()));
- INDArray outOrig = m.get(x.getVarName());
- Map<String,INDArray> m2 = restored.exec(Collections.singletonMap("in", arr), Collections.singletonList(x.getVarName()));
- INDArray outRestored = m2.get(x.getVarName());
+ Map<String,INDArray> m = sd.output(Collections.singletonMap("in", arr), Collections.singletonList(x.name()));
+ INDArray outOrig = m.get(x.name());
+ Map<String,INDArray> m2 = restored.output(Collections.singletonMap("in", arr), Collections.singletonList(x.name()));
+ INDArray outRestored = m2.get(x.name());
assertEquals(String.valueOf(i), outOrig, outRestored);
@@ -320,7 +320,7 @@ public class FlatBufferSerdeTest extends BaseNd4jTest {
if(v.isPlaceHolder() || v.getVariableType() == VariableType.ARRAY)
continue;
- SDVariable v2 = sd2.getVariable(v.getVarName());
+ SDVariable v2 = sd2.getVariable(v.name());
INDArray a1 = v.getArr();
INDArray a2 = v2.getArr();
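
The serde assertions reduce to a simple round trip: compute outputs with output(...), save, restore, recompute with the same placeholders, and compare. Sketch (file handling elided; 'x' stands for whichever output variable is under test):

    Map<String, INDArray> placeholders = Collections.singletonMap("in", arr);
    Map<String, INDArray> before = sd.output(placeholders, Collections.singletonList(x.name()));
    // ... write sd to FlatBuffers and read it back as 'restored', as in the test above ...
    Map<String, INDArray> after = restored.output(placeholders, Collections.singletonList(x.name()));
    assertEquals(before.get(x.name()), after.get(x.name()));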


@@ -57,17 +57,17 @@ public class GraphTransformUtilTests extends BaseNd4jTest {
SDVariable sub = add.sub(add2);
- assertTrue(OpPredicate.classEquals(AddOp.class).matches(sd, sd.getVariableOutputOp(add.getVarName())));
- assertTrue(OpPredicate.classEquals(AddOp.class).matches(sd, sd.getVariableOutputOp(add2.getVarName())));
- assertFalse(OpPredicate.classEquals(AddOp.class).matches(sd, sd.getVariableOutputOp(sub.getVarName())));
- assertTrue(OpPredicate.opNameEquals(AddOp.OP_NAME).matches(sd, sd.getVariableOutputOp(add.getVarName())));
- assertTrue(OpPredicate.opNameEquals(AddOp.OP_NAME).matches(sd, sd.getVariableOutputOp(add2.getVarName())));
- assertFalse(OpPredicate.opNameEquals(AddOp.OP_NAME).matches(sd, sd.getVariableOutputOp(sub.getVarName())));
- assertTrue(OpPredicate.opNameMatches(".*dd").matches(sd, sd.getVariableOutputOp(add.getVarName())));
- assertTrue(OpPredicate.opNameMatches("ad.*").matches(sd, sd.getVariableOutputOp(add2.getVarName())));
- assertFalse(OpPredicate.opNameMatches(".*dd").matches(sd, sd.getVariableOutputOp(sub.getVarName())));
+ assertTrue(OpPredicate.classEquals(AddOp.class).matches(sd, sd.getVariableOutputOp(add.name())));
+ assertTrue(OpPredicate.classEquals(AddOp.class).matches(sd, sd.getVariableOutputOp(add2.name())));
+ assertFalse(OpPredicate.classEquals(AddOp.class).matches(sd, sd.getVariableOutputOp(sub.name())));
+ assertTrue(OpPredicate.opNameEquals(AddOp.OP_NAME).matches(sd, sd.getVariableOutputOp(add.name())));
+ assertTrue(OpPredicate.opNameEquals(AddOp.OP_NAME).matches(sd, sd.getVariableOutputOp(add2.name())));
+ assertFalse(OpPredicate.opNameEquals(AddOp.OP_NAME).matches(sd, sd.getVariableOutputOp(sub.name())));
+ assertTrue(OpPredicate.opNameMatches(".*dd").matches(sd, sd.getVariableOutputOp(add.name())));
+ assertTrue(OpPredicate.opNameMatches("ad.*").matches(sd, sd.getVariableOutputOp(add2.name())));
+ assertFalse(OpPredicate.opNameMatches(".*dd").matches(sd, sd.getVariableOutputOp(sub.name())));
SubGraphPredicate p = SubGraphPredicate.withRoot(OpPredicate.classEquals(AddOp.class));
@@ -76,11 +76,11 @@ public class GraphTransformUtilTests extends BaseNd4jTest {
assertEquals(2, l.size());
SubGraph sg1 = l.get(0);
- assertTrue(sg1.getRootNode() == sd.getVariableOutputOp(add.getVarName()));
+ assertTrue(sg1.getRootNode() == sd.getVariableOutputOp(add.name()));
assertEquals(0, sg1.getChildNodes().size());
SubGraph sg2 = l.get(1);
- assertTrue(sg2.getRootNode() == sd.getVariableOutputOp(add2.getVarName()));
+ assertTrue(sg2.getRootNode() == sd.getVariableOutputOp(add2.name()));
assertEquals(0, sg2.getChildNodes().size());
}
@@ -118,7 +118,7 @@ public class GraphTransformUtilTests extends BaseNd4jTest {
});
INDArray exp2 = p1.div(p2).mul(p1.sub(p2));
- INDArray out2 = sd2.getVariable(mul.getVarName()).eval();
+ INDArray out2 = sd2.getVariable(mul.name()).eval();
assertEquals(exp2, out2);


@@ -33,18 +33,18 @@ public class NameScopeTests extends BaseNd4jTest {
SDVariable v = sd.var("x");
try(NameScope ns = sd.withNameScope("nameScope")){
SDVariable v2 = sd.var("x2");
- assertEquals("nameScope/x2", v2.getVarName());
+ assertEquals("nameScope/x2", v2.name());
assertTrue(sd.getVariables().containsKey("nameScope/x2"));
assertEquals("nameScope", sd.currentNameScope());
SDVariable v3 = sd.var("x");
- assertEquals("nameScope/x", v3.getVarName());
+ assertEquals("nameScope/x", v3.name());
assertTrue(sd.getVariables().containsKey("nameScope/x"));
try(NameScope ns2 = sd.withNameScope("scope2")){
assertEquals("nameScope/scope2", sd.currentNameScope());
SDVariable v4 = sd.var("x");
- assertEquals("nameScope/scope2/x", v4.getVarName());
+ assertEquals("nameScope/scope2/x", v4.name());
assertTrue(sd.getVariables().containsKey("nameScope/scope2/x"));
}
@@ -76,19 +76,19 @@ public class NameScopeTests extends BaseNd4jTest {
}
SDVariable a = sd.var("a", DataType.FLOAT, 1);
- assertEquals("x", x.getVarName());
- assertEquals("s1/y", y.getVarName());
- assertEquals("s1/s2/z", z.getVarName());
- assertEquals("a", a.getVarName());
+ assertEquals("x", x.name());
+ assertEquals("s1/y", y.name());
+ assertEquals("s1/s2/z", z.name());
+ assertEquals("a", a.name());
- assertTrue(add.getVarName(), add.getVarName().startsWith("s1/"));
- assertEquals("s1/addxy", addWithName.getVarName());
- assertTrue(merge.getVarName(), merge.getVarName().startsWith("s1/s2/"));
- assertEquals("s1/s2/mmax", mergeWithName.getVarName());
+ assertTrue(add.name(), add.name().startsWith("s1/"));
+ assertEquals("s1/addxy", addWithName.name());
+ assertTrue(merge.name(), merge.name().startsWith("s1/s2/"));
+ assertEquals("s1/s2/mmax", mergeWithName.name());
Set<String> allowedVarNames = new HashSet<>(Arrays.asList("x", "s1/y", "s1/s2/z", "a",
- add.getVarName(), addWithName.getVarName(), merge.getVarName(), mergeWithName.getVarName()));
+ add.name(), addWithName.name(), merge.name(), mergeWithName.name()));
Set<String> allowedOpNames = new HashSet<>();
//Check op names:
@@ -102,8 +102,8 @@ public class NameScopeTests extends BaseNd4jTest {
//Check fields - Variable, SDOp, etc
for(Variable v : sd.getVariables().values()){
- assertTrue(v.getVariable().getVarName(), allowedVarNames.contains(v.getVariable().getVarName()));
- assertEquals(v.getName(), v.getVariable().getVarName());
+ assertTrue(v.getVariable().name(), allowedVarNames.contains(v.getVariable().name()));
+ assertEquals(v.getName(), v.getVariable().name());
if(v.getInputsForOp() != null){
for(String s : v.getInputsForOp()){
assertTrue(s, allowedOpNames.contains(s));
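
The assertions above all hinge on one rule: a variable created inside withNameScope("s") gets its name prefixed with "s/", and scopes nest. Compact sketch:

    SameDiff sd = SameDiff.create();
    try (NameScope scope = sd.withNameScope("s1")) {
        SDVariable y = sd.var("y");      // registered as "s1/y"
        assertEquals("s1/y", y.name());
    }
    SDVariable x = sd.var("x");          // outside any scope: plain "x"
    assertEquals("x", x.name());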


@@ -108,14 +108,14 @@ public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTest {
sd.fit(ds);
}
- for(String s : new String[]{"w", "b", badd.getVarName(), add.getVarName(), "l1", "l2"}){
+ for(String s : new String[]{"w", "b", badd.name(), add.name(), "l1", "l2"}){
SDVariable gradVar = sd.getVariable(s).gradient();
assertNotNull(s, gradVar);
}
//Unused:
assertFalse(shape.hasGradient());
try{ assertNull(shape.gradient()); } catch (IllegalStateException e){ assertTrue(e.getMessage().contains("only floating point variables")); }
- for(String s : new String[]{unused1.getVarName(), unused2.getVarName(), unused3.getVarName()}){
+ for(String s : new String[]{unused1.name(), unused2.name(), unused3.name()}){
assertNull(sd.getVariable(s).gradient());
}
}
@@ -151,20 +151,20 @@ public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTest {
sd.setLossVariables("loss1");
sd.createGradFunction();
for(SDVariable v : new SDVariable[]{ph1, w1, b1, mmul1, badd1, loss1}){
- assertNotNull(v.getVarName(), v.gradient());
+ assertNotNull(v.name(), v.gradient());
}
for(SDVariable v : new SDVariable[]{ph2, w2, b2, mmul2, badd2, loss2}){
- assertNull(v.getVarName(), v.gradient());
+ assertNull(v.name(), v.gradient());
}
//Now, set to other loss function
sd.setLossVariables("loss2");
sd.createGradFunction();
for(SDVariable v : new SDVariable[]{ph1, w1, b1, mmul1, badd1, loss1}){
- assertNull(v.getVarName(), v.gradient());
+ assertNull(v.name(), v.gradient());
}
for(SDVariable v : new SDVariable[]{ph2, w2, b2, mmul2, badd2, loss2}){
- assertNotNull(v.getVarName(), v.gradient());
+ assertNotNull(v.name(), v.gradient());
}
//Train the first side of the graph. The other side should remain unmodified!
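
In short: setLossVariables(...) plus createGradFunction() build gradients only for variables on a path to the currently active loss; everything else reports a null gradient. Sketch, assuming the two-headed graph from the test (loss1 fed by w1, loss2 fed by w2):

    sd.setLossVariables("loss1");
    sd.createGradFunction();
    assertNotNull(w1.gradient());   // contributes to loss1
    assertNull(w2.gradient());      // only reaches loss2, so no gradient is defined for it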


@@ -109,6 +109,8 @@ public class FileReadWriteTests extends BaseNd4jTest {
for (int i = 0; i < s.outputsLength(); i++) {
outputs.add(s.outputs(i));
}
+ if(outputs.isEmpty())
+ outputs = null;
assertEquals(sd.outputs(), outputs);
//Check variables


@@ -63,7 +63,7 @@ public class UIListenerTest {
Map<String, INDArray> m = new HashMap<>();
iter.reset();
m.put("in", iter.next().getFeatures());
- INDArray out = sd.execSingle(m, "softmax");
+ INDArray out = sd.outputSingle(m, "softmax");
assertNotNull(out);
assertArrayEquals(new long[]{150, 3}, out.shape());
}
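
outputSingle is the one-output convenience form of output: it returns the requested INDArray directly rather than a one-entry map. Sketch, with an assumed placeholder "in" and output "softmax":

    Map<String, INDArray> placeholders = Collections.singletonMap("in", features);
    INDArray probs = sd.outputSingle(placeholders, "softmax");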


@@ -78,7 +78,7 @@ public class ExecutionTests extends BaseNd4jTest {
val tg = TFGraphMapper.importGraphTxt(new ClassPathResource("tf_graphs/reduce_dim.pb.txt").getInputStream(), null, null);
System.out.println(tg.summary());
- Map<String,INDArray> result_0 = tg.exec(Collections.emptyMap(), tg.outputs());
+ Map<String,INDArray> result_0 = tg.outputAll(null);
val exp_0 = Nd4j.create(DataType.FLOAT, 3).assign(3.0);
assertEquals(exp_0, result_0.get("Sum"));
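
For an imported graph with no placeholders to feed, outputAll accepts null (or an empty map) and returns every graph output keyed by name:

    Map<String, INDArray> results = tg.outputAll(null);  // nothing to feed
    INDArray sum = results.get("Sum");                   // look results up by output name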


@@ -174,7 +174,7 @@ public class BERTGraphTest extends BaseNd4jTest {
//Find pre-dropout input variable:
SDVariable newOut = null;
for(SDVariable v : inputs){
- if(v.getVarName().endsWith("/BiasAdd") || v.getVarName().endsWith("/Softmax") || v.getVarName().endsWith("/add_1") || v.getVarName().endsWith("/Tanh")){
+ if(v.name().endsWith("/BiasAdd") || v.name().endsWith("/Softmax") || v.name().endsWith("/add_1") || v.name().endsWith("/Tanh")){
newOut = v;
break;
}
@@ -249,7 +249,7 @@ public class BERTGraphTest extends BaseNd4jTest {
placeholderValues.put("IteratorGetNext:1", mask);
placeholderValues.put("IteratorGetNext:4", segmentIdxs);
- Map<String, INDArray> out = sd.exec(placeholderValues, "loss/Softmax");
+ Map<String, INDArray> out = sd.output(placeholderValues, "loss/Softmax");
INDArray softmax = out.get("loss/Softmax");
// System.out.println("OUTPUT - Softmax");
// System.out.println(softmax);
@@ -335,8 +335,8 @@ public class BERTGraphTest extends BaseNd4jTest {
//For training, convert weights and biases from constants to variables:
for(SDVariable v : sd.variables()){
- if(v.isConstant() && v.dataType().isFPType() && !v.getArr().isScalar() && !floatConstants.contains(v.getVarName())){ //Skip scalars - trainable params
- log.info("Converting to variable: {} - dtype: {} - shape: {}", v.getVarName(), v.dataType(), Arrays.toString(v.getArr().shape()));
+ if(v.isConstant() && v.dataType().isFPType() && !v.getArr().isScalar() && !floatConstants.contains(v.name())){ //Skip scalars - trainable params
+ log.info("Converting to variable: {} - dtype: {} - shape: {}", v.name(), v.dataType(), Arrays.toString(v.getArr().shape()));
v.convertToVariable();
}
}
@@ -393,14 +393,14 @@ public class BERTGraphTest extends BaseNd4jTest {
placeholderValues.put("IteratorGetNext:4", segmentIdxs);
placeholderValues.put("label", labelArr);
- INDArray lossArr = sd.exec(placeholderValues, "loss").get("loss");
+ INDArray lossArr = sd.output(placeholderValues, "loss").get("loss");
assertTrue(lossArr.isScalar());
double scoreBefore = lossArr.getDouble(0);
for( int i=0; i<5; i++ ){
sd.fit(mds);
}
- lossArr = sd.output(placeholderValues, "loss").get("loss");
+ lossArr = sd.output(placeholderValues, "loss").get("loss");
assertTrue(lossArr.isScalar());
double scoreAfter = lossArr.getDouble(0);
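
The fine-tuning recipe above generalizes to any imported frozen graph: promote the floating-point, non-scalar constants (the frozen weights) to variables, after which fit(...) can train them. A minimal sketch of that step:

    for (SDVariable v : sd.variables()) {
        // Frozen weights arrive as constants; scalar constants are skipped as in the test above
        if (v.isConstant() && v.dataType().isFPType() && !v.getArr().isScalar()) {
            v.convertToVariable();
        }
    }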


@@ -105,11 +105,9 @@ public class ValidateZooModelPredictions extends BaseNd4jTest {
//Perform inference
List<String> inputs = sd.inputs();
assertEquals(1, inputs.size());
- List<String> outputs = sd.outputs();
- assertEquals(1, outputs.size());
- String out = outputs.get(0);
- Map<String,INDArray> m = sd.exec(Collections.singletonMap(inputs.get(0), img), out);
+ String out = "MobilenetV1/Predictions/Softmax";
+ Map<String,INDArray> m = sd.output(Collections.singletonMap(inputs.get(0), img), out);
INDArray outArr = m.get(out);
@@ -167,7 +165,7 @@ public class ValidateZooModelPredictions extends BaseNd4jTest {
assertEquals(1, inputs.size());
String out = "softmax_tensor";
- Map<String,INDArray> m = sd.exec(Collections.singletonMap(inputs.get(0), img), out);
+ Map<String,INDArray> m = sd.output(Collections.singletonMap(inputs.get(0), img), out);
INDArray outArr = m.get(out);


@@ -106,7 +106,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
//g.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/mnist.fb"), ExecutorConfiguration.builder().outputMode(OutputMode.VARIABLE_SPACE).build());
- g.execAndEndResult();
+ g.outputAll(null);
}
@@ -129,7 +129,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
val graph = TFGraphMapper.importGraph(new ClassPathResource("/tf_graphs/argmax.pb.txt").getInputStream());
log.info(graph.asFlatPrint());
- val result = graph.execAndEndResult();
+ val result = graph.outputAll(null).get(graph.outputs().get(0));
val exp = Nd4j.createFromArray(new long[]{2, 2, 2});
@@ -222,7 +222,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
graph.var("Placeholder", p0);
graph.var("Placeholder_1", p1);
- val res = graph.execAndEndResult();
+ val res = graph.outputAll(null).get(graph.outputs().get(0));
@@ -341,7 +341,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
val constIn = tg.getVariable("StridedSlice/input");
assertNotNull(constIn);
- val arr = tg.getArrForVarName(constIn.getVarName());
+ val arr = tg.getArrForVarName(constIn.name());
assertEquals(139.5, arr.sumNumber().doubleValue(), 1e-5);
@@ -651,9 +651,8 @@ public class TensorFlowImportTest extends BaseNd4jTest {
INDArray input = Nd4j.linspace(1,40,40, DataType.FLOAT).reshape(10,4);
INDArray expectedOutput = Nd4j.linspace(1,40,40, DataType.FLOAT).reshape(10,4).addRowVector(Nd4j.linspace(1,4,4, DataType.FLOAT));
- INDArray actual = graph.execSingle(Collections.singletonMap("input",input), graph.outputs().get(0));
+ INDArray actual = graph.outputSingle(Collections.singletonMap("input",input), graph.outputs().get(0));
assertEquals(input,graph.getVariable("input").getArr());
- assertArrayEquals(input.shape(),graph.getShapeForVarName(graph.getVariable("input").getVarName()));
assertEquals(expectedOutput,actual);
}
@@ -665,13 +664,13 @@ public class TensorFlowImportTest extends BaseNd4jTest {
val variables = new HashMap<String, SDVariable>();
for (val var : tg.variables()) {
- variables.put(var.getVarName(), var);
+ variables.put(var.name(), var);
}
val functions = new HashMap<String, DifferentialFunction>();
for (val func: tg.ops()) {
val ownName = func.getOwnName();
- String outName = func.outputVariables()[0].getVarName();
+ String outName = func.outputVariables()[0].name();
assertTrue("Missing ownName: [" + ownName +"]",variables.containsKey(ownName));
assertEquals(ownName, outName);
@@ -704,7 +703,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simpleif_0.fb"));
//log.info("{}", tg.asFlatPrint());
- val array = tg.execAndEndResult();
+ val array = tg.outputAll(null).get(tg.outputs().get(0));
val exp = Nd4j.create(2, 2).assign(1);
assertNotNull(array);
assertEquals(exp, array);
@@ -723,7 +722,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
//log.info("{}", tg.asFlatPrint());
- val array = tg.execAndEndResult();
+ val array = tg.outputAll(null).get(tg.outputs().get(0));
val exp = Nd4j.create(2, 2).assign(1);
assertNotNull(array);
assertEquals(exp, array);
@@ -741,7 +740,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
//log.info("{}", tg.asFlatPrint());
/*
- val array = tg.execAndEndResult();
+ val array = tg.outputAll(null).get(tg.outputs().get(0));
val exp = Nd4j.create(2, 2).assign(2);
assertNotNull(array);
assertEquals(exp, array);*/
@@ -759,7 +758,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
//log.info("{}", tg.asFlatPrint());
- val array = tg.execAndEndResult();
+ val array = tg.outputAll(null).get(tg.outputs().get(0));
val exp = Nd4j.create(2, 2).assign(4);
assertNotNull(array);
assertEquals(exp, array);
@@ -780,7 +779,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
//log.info("{}", tg.asFlatPrint());
- val array = tg.execAndEndResult();
+ INDArray array = tg.outputAll(null).get(tg.outputs().get(0));
val exp = Nd4j.create(2, 2).assign(-1);
assertNotNull(array);
assertEquals(exp, array);
@@ -800,7 +799,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
//log.info("{}", tg.asFlatPrint());
- val array = tg.execAndEndResult();
+ val array = tg.outputAll(null).get(tg.outputs().get(0));
val exp = Nd4j.create(2, 2).assign(-3);
assertNotNull(array);
assertEquals(exp, array);
@@ -822,7 +821,8 @@ public class TensorFlowImportTest extends BaseNd4jTest {
//log.info("{}", tg.asFlatPrint());
- val array = tg.execAndEndResult();
+ Map<String,INDArray> m = tg.outputAll(null);
+ val array = m.get(tg.outputs().get(0));
//val array = tg.getVariable("output").getArr();
val exp = Nd4j.create(2, 2).assign(15.0);
assertNotNull(array);
@@ -968,7 +968,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
assertNotNull(tg);
val input_matrix = Nd4j.ones(3, 2);
- val array = tg.execSingle(Collections.singletonMap("input_matrix", input_matrix), tg.outputs().get(0));
+ val array = tg.outputSingle(Collections.singletonMap("input_matrix", input_matrix), tg.outputs().get(0));
val exp = Nd4j.create(new float[] {1, 1, 2, 2, 3, 3}, new int[]{3, 2});
@@ -982,7 +982,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
val input_matrix = Nd4j.ones(3, 2);
- val array = tg.exec(Collections.singletonMap("input_matrix", input_matrix), tg.outputs().get(0));
+ val array = tg.output(Collections.singletonMap("input_matrix", input_matrix), tg.outputs().get(0)).get(tg.outputs().get(0));
val exp = Nd4j.create(new float[] {2, 2}, new int[]{2});
@@ -997,7 +997,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/tensor_array_unstack.pb.txt").getInputStream());
assertNotNull(tg);
- val array = tg.execSingle(Collections.emptyMap(), tg.outputs().get(0));
+ val array = tg.outputSingle(Collections.emptyMap(), tg.outputs().get(0));
val exp = Nd4j.create(new float[] {5, 6, 7, 8}, new int[]{4});
@@ -1011,7 +1011,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
val input_matrix = Nd4j.linspace(1, 10, 10, DataType.FLOAT).reshape(5, 2);
log.info("Graph: {}", tg.asFlatPrint());
- val array = tg.execSingle(Collections.singletonMap("input_matrix", input_matrix), tg.outputs().get(0));
+ val array = tg.outputSingle(Collections.singletonMap("input_matrix", input_matrix), tg.outputs().get(0));
val exp = Nd4j.create(new float[] {3,6, 9,12, 15,18, 21,24, 27,30}, new int[]{5, 2});
@@ -1023,7 +1023,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
Nd4j.create(1);
val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/losses/log_loss_rank2_axis1_SUM_OVER_BATCH_SIZE/frozen_model.pb").getInputStream());
- tg.execAndEndResult();
+ tg.outputAll(null);
}
@Test
@@ -1040,7 +1040,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
for (int e = 0; e < 1000; e++){
val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/reduce_any/rank0/frozen_model.pb").getInputStream());
- Map<String,INDArray> result = tg.exec(Collections.emptyMap(), tg.outputs());
+ Map<String,INDArray> result = tg.output(Collections.emptyMap(), tg.outputs());
assertNotNull(result);
assertTrue(result.size() > 0);
@@ -1052,7 +1052,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
Nd4j.create(1);
val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/transforms/logicalxor_3,4_3,4/frozen_model.pb").getInputStream());
- tg.execAndEndResult();
+ tg.outputAll(null);
}
@Test
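
All of these import tests now share one shape: map the frozen TensorFlow graph, feed placeholders, and read results by output name. Condensed sketch (the resource path is illustrative, not one of the test fixtures):

    SameDiff tg = TFGraphMapper.importGraph(
            new ClassPathResource("tf_graphs/some_model.pb.txt").getInputStream());
    INDArray result = tg.outputSingle(
            Collections.singletonMap("input_matrix", Nd4j.ones(3, 2)),
            tg.outputs().get(0));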


@@ -118,7 +118,7 @@ public class ImportModelDebugger {
List<String> outputs = sd.outputs();
- sd.exec(ph, outputs);
+ sd.output(ph, outputs);
}


@@ -0,0 +1,93 @@
package org.nd4j.linalg.convolution;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.resources.Resources;
import java.io.File;
import java.util.*;
import static org.junit.Assert.*;
public class DeconvTests extends BaseNd4jTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
public DeconvTests(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering() {
return 'c';
}
@Test
public void compareKeras() throws Exception {
File f = testDir.newFolder();
Resources.copyDirectory("keras/deconv", f);
File[] files = f.listFiles();
Set<String> tests = new HashSet<>();
for(File file : files){
String n = file.getName();
if(!n.startsWith("mb"))
continue;
int idx = n.lastIndexOf('_');
String name = n.substring(0, idx);
tests.add(name);
}
List<String> l = new ArrayList<>(tests);
Collections.sort(l);
assertFalse(l.isEmpty());
for(String s : l){
String s2 = s.replaceAll("[a-zA-Z]", "");
String[] nums = s2.split("_");
int mb = Integer.parseInt(nums[0]);
int k = Integer.parseInt(nums[1]);
int size = Integer.parseInt(nums[2]);
int stride = Integer.parseInt(nums[3]);
boolean same = s.contains("same");
int d = Integer.parseInt(nums[5]);
boolean nchw = s.contains("nchw");
INDArray w = Nd4j.readNpy(new File(f, s + "_W.npy"));
INDArray b = Nd4j.readNpy(new File(f, s + "_b.npy"));
INDArray in = Nd4j.readNpy(new File(f, s + "_in.npy")).castTo(DataType.FLOAT);
INDArray expOut = Nd4j.readNpy(new File(f, s + "_out.npy"));
CustomOp op = DynamicCustomOp.builder("deconv2d")
.addInputs(in, w, b)
.addIntegerArguments(
k, k,
stride, stride,
0, 0, //padding
d, d,
same ? 1 : 0,
nchw ? 0 : 1)
.callInplace(false)
.build();
INDArray out = Nd4j.create(op.calculateOutputShape().get(0));
out.assign(Double.NaN);
op.addOutputArgument(out);
Nd4j.exec(op);
boolean eq = expOut.equalsWithEps(out, 1e-4);
assertTrue(eq);
}
}
}
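Note: the integer arguments the test passes to deconv2d follow a fixed layout: kernel (kH, kW), strides (sH, sW), padding (pH, pW), dilation (dH, dW), a SAME/VALID flag, then a data-format flag (0 = NCHW, 1 = NHWC). A standalone sketch with fixed sizes — the shapes, and in particular the [kH, kW, oC, iC] weight layout, are assumptions made for illustration rather than something this test pins down:

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.CustomOp;
    import org.nd4j.linalg.api.ops.DynamicCustomOp;
    import org.nd4j.linalg.factory.Nd4j;

    public class Deconv2dSketch {
        public static void main(String[] args) {
            INDArray in = Nd4j.create(DataType.FLOAT, 1, 2, 7, 7); // NCHW: minibatch 1, 2 channels, 7x7
            INDArray w = Nd4j.create(DataType.FLOAT, 2, 2, 3, 2);  // assumed [kH, kW, oC, iC] layout
            INDArray b = Nd4j.create(DataType.FLOAT, 3);           // one bias value per output channel

            CustomOp op = DynamicCustomOp.builder("deconv2d")
                    .addInputs(in, w, b)
                    .addIntegerArguments(
                            2, 2,   // kernel h, w
                            1, 1,   // stride h, w
                            0, 0,   // padding h, w
                            1, 1,   // dilation h, w
                            1,      // 1 = SAME, 0 = VALID
                            0)      // 0 = NCHW, 1 = NHWC
                    .callInplace(false)
                    .build();

            // Same pattern as the test: ask the op for its output shape, then execute into that buffer
            INDArray out = Nd4j.create(op.calculateOutputShape().get(0));
            op.addOutputArgument(out);
            Nd4j.exec(op);
        }
    }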
@@ -21,8 +21,6 @@ import org.nd4j.config.ND4JEnvironmentVars;
 import org.nd4j.config.ND4JSystemProperties;
 import org.nd4j.context.Nd4jContext;
 import org.nd4j.linalg.io.Resource;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;

 import java.io.File;
 import java.io.IOException;
@@ -152,6 +150,9 @@ public abstract class Nd4jBackend {
      */
     public static Nd4jBackend load() throws NoAvailableBackendException {

+        String logInitProperty = System.getProperty(ND4JSystemProperties.LOG_INITIALIZATION, "true");
+        boolean logInit = Boolean.parseBoolean(logInitProperty);
+
         List<Nd4jBackend> backends = new ArrayList<>(1);
         ServiceLoader<Nd4jBackend> loader = ServiceLoader.load(Nd4jBackend.class);
         try {
@@ -183,7 +184,9 @@ public abstract class Nd4jBackend {
                 error = e.getMessage();
             }
             if (!available) {
+                if(logInit) {
                     log.warn("Skipped [{}] backend (unavailable): {}", backend.getClass().getSimpleName(), error);
+                }
                 continue;
             }
@@ -193,7 +196,9 @@ public abstract class Nd4jBackend {
                 e.printStackTrace();
             }
+            if(logInit) {
                 log.info("Loaded [{}] backend", backend.getClass().getSimpleName());
+            }
             return backend;
         }
@@ -273,6 +278,8 @@ public abstract class Nd4jBackend {
         return getClass().getName();
     }

+    public abstract void logBackendInit();
+
     @SuppressWarnings("serial")
     public static class NoAvailableBackendException extends Exception {
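Note: backend-load logging now honours the ND4JSystemProperties.LOG_INITIALIZATION system property (default "true"). A quick sketch of turning it off programmatically — the property must be set before the first Nd4j call triggers backend loading; it can equally be passed as a -D JVM argument:

    import org.nd4j.config.ND4JSystemProperties;
    import org.nd4j.linalg.factory.Nd4j;

    public class QuietInitSketch {
        public static void main(String[] args) {
            // Suppress the "Loaded [...] backend" / "Skipped [...] backend" messages above
            System.setProperty(ND4JSystemProperties.LOG_INITIALIZATION, "false");
            Nd4j.create(1); // backend loads without the init log lines
        }
    }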
@@ -200,7 +200,7 @@ public class SameDiffServlet<I, O> implements ModelServingServlet<I, O> {
                 map.put(n, mds.getFeatures(cnt++));
             }

-            val output = sdModel.exec(map, orderedOutputNodes);
+            val output = sdModel.output(map, orderedOutputNodes);
             val arrays = new INDArray[output.size()];

             // now we need to get ordered output arrays, as specified in server constructor
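Note: output(...) returns a name-keyed map, while the servlet must hand arrays back in the order fixed by its constructor. A hypothetical helper showing the kind of re-ordering the truncated hunk goes on to perform (names are illustrative, not the servlet's actual code):

    import org.nd4j.linalg.api.ndarray.INDArray;

    import java.util.List;
    import java.util.Map;

    class OrderedOutputsSketch {
        // Arrange a name -> array result map into the caller-specified order
        static INDArray[] toOrdered(Map<String, INDArray> output, List<String> orderedOutputNodes) {
            INDArray[] arrays = new INDArray[output.size()];
            for (int i = 0; i < orderedOutputNodes.size(); i++) {
                arrays[i] = output.get(orderedOutputNodes.get(i));
            }
            return arrays;
        }
    }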