SameDiff cleanup and fixes (#12)

* #8160 Remove resolvePropertiesFromSameDiffBeforeExecution
* SameDiff API cleanup
* More SameDiff cleanup
* Small fixes
* #8248 Switch SameDiff variable init from lazy to creation time for more predictable behaviour
* #8252 TanhDerivative javadoc
* #8225 Deconvolution2D input validation
* #8265 Switch SameDiff.outputs() to user settable, instead of unreliable 'best guess'
* #8224 SameDiff.zero and .one create constants, not variables
* More cleanup and fixes
* Small test fix
* Small fix
* DL4J SameDiff fixes
* Re-add hack for Deconvolution2DLayer until #8315 is resolved
* #8270 Move CUDA device/version logging to Java; can be disabled via existing org.nd4j.log.initialization system property
* All ND4J init logging checks system property
* Small tweak
* Remove redundant device logging
* One more fix
* UX improvements
* Deconv fix
* Add deconv tests
* Cleanup
* Remove debug code

Signed-off-by: AlexDBlack <blacka101@gmail.com>
parent d98784197a
commit d333d29099
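The API changes that recur throughout this diff are the rename of SDVariable.getVarName() to name() (getVarName() remains as a deprecated alias), inference via SameDiff.output(...) instead of the old exec(...) entry points, and gradient retrieval via SameDiff.calculateGradients(...) instead of execBackwards(...). A minimal sketch of the new call pattern, assuming a toy graph (the variable names and shapes here are illustrative only, not taken from this commit):

    import java.util.Collections;
    import java.util.Map;

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class SameDiffApiSketch {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);
            SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 4, 1));
            SDVariable loss = in.mmul(w).sum("loss");
            sd.setLossVariables("loss");

            // name() is the new accessor; getVarName() is kept but deprecated
            System.out.println(loss.name());

            Map<String, INDArray> placeholders =
                    Collections.singletonMap("in", Nd4j.rand(DataType.FLOAT, 3, 4));

            // output(...) replaces the old exec(...) methods for inference
            Map<String, INDArray> out = sd.output(placeholders, loss.name());

            // calculateGradients(...) replaces execBackwards(...) and returns the gradient arrays directly
            Map<String, INDArray> grads = sd.calculateGradients(placeholders, w.name());
            System.out.println(out.get(loss.name()) + " " + grads.get(w.name()));
        }
    }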
@@ -21,6 +21,7 @@ import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
 import org.deeplearning4j.eval.Evaluation;
 import org.deeplearning4j.exception.DL4JException;
+import org.deeplearning4j.exception.DL4JInvalidInputException;
 import org.deeplearning4j.nn.api.Layer;
 import org.deeplearning4j.nn.api.OptimizationAlgorithm;
 import org.deeplearning4j.nn.conf.ConvolutionMode;

@@ -693,4 +694,22 @@ public class ConvolutionLayerTest extends BaseDL4JTest {
         INDArray out = net.output(in);
         assertArrayEquals(new long[]{2,7,6}, out.shape());
     }
+
+    @Test
+    public void testDeconvBadInput(){
+        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+                .list()
+                .layer(new Deconvolution2D.Builder().nIn(5).nOut(3).build())
+                .build();
+        MultiLayerNetwork net = new MultiLayerNetwork(conf);
+        net.init();
+
+        INDArray badInput = Nd4j.create(DataType.FLOAT, 1, 10, 5, 5);
+        try {
+            net.output(badInput);
+        } catch (DL4JInvalidInputException e){
+            String msg = e.getMessage();
+            assertTrue(msg, msg.contains("Deconvolution2D") && msg.contains("input") && msg.contains("channels"));
+        }
+    }
 }
@@ -86,10 +86,10 @@ public class CompareTrainingImplementations extends BaseDL4JTest {
         SDVariable label = sd.placeHolder("label", DataType.DOUBLE, -1, 3);

         SDVariable w0 = sd.var("w0", new XavierInitScheme('c', 4, 10), DataType.DOUBLE, 4, 10);
-        SDVariable b0 = sd.zero("b0", 1, 10);
+        SDVariable b0 = sd.var("b0", Nd4j.create(DataType.DOUBLE, 1, 10));

         SDVariable w1 = sd.var("w1", new XavierInitScheme('c', 10, 3), DataType.DOUBLE, 10, 3);
-        SDVariable b1 = sd.zero("b1", 1, 3);
+        SDVariable b1 = sd.var("b1", Nd4j.create(DataType.DOUBLE, 1, 3));

         SDVariable z0 = in.mmul(w0).add(b0);
         SDVariable a0 = sd.nn().tanh(z0);

@@ -172,8 +172,8 @@ public class CompareTrainingImplementations extends BaseDL4JTest {
         Map<String,INDArray> placeholders = new HashMap<>();
         placeholders.put("input", f);
         placeholders.put("label", l);
-        Map<String,INDArray> map = sd.output(placeholders, lossMse.getVarName(), a1.getVarName());
-        INDArray outSd = map.get(a1.getVarName());
+        Map<String,INDArray> map = sd.output(placeholders, lossMse.name(), a1.name());
+        INDArray outSd = map.get(a1.name());
         INDArray outDl4j = net.output(f);

         assertEquals(testName, outDl4j, outSd);

@@ -187,7 +187,7 @@ public class CompareTrainingImplementations extends BaseDL4JTest {
         //Check score
         double scoreDl4j = net.score();
-        double scoreSd = map.get(lossMse.getVarName()).getDouble(0) + sd.calcRegularizationScore();
+        double scoreSd = map.get(lossMse.name()).getDouble(0) + sd.calcRegularizationScore();
         assertEquals(testName, scoreDl4j, scoreSd, 1e-6);

         double lossRegScoreSD = sd.calcRegularizationScore();

@@ -197,15 +197,15 @@ public class CompareTrainingImplementations extends BaseDL4JTest {
         //Check gradients (before updater applied)
         Map<String,INDArray> grads = net.gradient().gradientForVariable();
-        sd.execBackwards(placeholders);
+        Map<String,INDArray> gm = sd.calculateGradients(placeholders, b1.name(), w1.name(), b0.name(), w0.name());

         //Note that the SameDiff gradients don't include the L1/L2 terms at present just from execBackwards()... these are added in fitting only
         //We can check correctness though with training param checks later
         if(l1Val == 0 && l2Val == 0 && wdVal == 0) {
-            assertEquals(testName, grads.get("1_b"), b1.getGradient().getArr());
-            assertEquals(testName, grads.get("1_W"), w1.getGradient().getArr());
-            assertEquals(testName, grads.get("0_b"), b0.getGradient().getArr());
-            assertEquals(testName, grads.get("0_W"), w0.getGradient().getArr());
+            assertEquals(testName, grads.get("1_b"), gm.get(b1.name()));
+            assertEquals(testName, grads.get("1_W"), gm.get(w1.name()));
+            assertEquals(testName, grads.get("0_b"), gm.get(b0.name()));
+            assertEquals(testName, grads.get("0_W"), gm.get(w0.name()));
         }
@@ -65,7 +65,7 @@ public abstract class SameDiffLambdaVertex extends SameDiffVertex {
        defineVertex(temp, tempInputs);
        List<String> list = new ArrayList<>();
        for (Integer i : tempInputs.map.keySet()) {
-            list.add(tempInputs.map.get(i).getVarName());
+            list.add(tempInputs.map.get(i).name());
        }
        params.defineInputs(list.toArray(new String[list.size()]));
    }
@@ -176,8 +176,10 @@ public class Deconvolution2DLayer extends ConvolutionLayer {
        int outDepth = (int) weights.size(1);

        if (input.size(1) != inDepth && input.size(3) == inDepth) {
            //TODO AB 2019/10/25 this is an ugly "pseudo-NHWC support" hack that needs to be removed ASAD
            //https://github.com/eclipse/deeplearning4j/issues/8315
            input = input.permute(0, 3, 1, 2);
-        } else if (input.size(1) != inDepth && input.size(3) != inDepth) {
+        } else if (input.size(1) != inDepth ) {
            String layerName = conf.getLayer().getLayerName();
            if (layerName == null)
                layerName = "(not named)";
@@ -192,7 +192,7 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
                String name = inputs.get(j);
                dLdIns[j] = sameDiff.grad(name).getArr();

-                String gradName = sameDiff.grad(inputNames.get(j)).getVarName();
+                String gradName = sameDiff.grad(inputNames.get(j)).name();
                if(dLdIns[j] == null && fnName.equals(gradName)){
                    //Edge case with lambda vertices like identity: SameDiff doesn't store the placeholders
                    // So, this getArr() can be trying to get placeholder from SameDiff instance, when it's available here

@@ -271,7 +271,7 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
            fn = sameDiff.f().externalErrors(layerOutput);
            fn.outputVariable();

-            this.outputKey = outputVar.getVarName();
+            this.outputKey = outputVar.name();
        }
    }
@@ -302,7 +302,7 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
            fn = sameDiff.f().externalErrors(layerOutput);
            fn.outputVariable();

-            this.outputKey = outputVar.getVarName();
+            this.outputKey = outputVar.name();
        }
    }
@@ -112,7 +112,7 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
            phMap.put(LABELS_KEY, labels);
        }

-        String s = activations ? layerConf().activationsVertexName() : outputVar.getVarName();
+        String s = activations ? layerConf().activationsVertexName() : outputVar.name();

        INDArray out = sameDiff.outputSingle(phMap, s);

@@ -160,31 +160,35 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
        }

        List<String> gradVarNames = new ArrayList<>();
-        for(String s : paramTable.keySet()){
-            gradVarNames.add(sameDiff.getVariable(s).getGradient().getVarName());
-        }
-        gradVarNames.add(sameDiff.grad(INPUT_KEY).getVarName());
+        gradVarNames.addAll(paramTable.keySet());
+        gradVarNames.add(INPUT_KEY);

        Map<String,INDArray> phMap = new HashMap<>();
        phMap.put(INPUT_KEY, input);
        phMap.put(LABELS_KEY, labels);

-        sameDiff.execBackwards(phMap, gradVarNames);
+        Map<String,INDArray> grads = sameDiff.calculateGradients(phMap, gradVarNames);
        for(String s : paramTable.keySet() ){
-            INDArray sdGrad = sameDiff.grad(s).getArr();
+            INDArray sdGrad = grads.get(s);
            INDArray dl4jGrad = gradTable.get(s);
            dl4jGrad.assign(sdGrad);    //TODO OPTIMIZE THIS
            g.gradientForVariable().put(s, dl4jGrad);
+            if(sdGrad.closeable()){
+                sdGrad.close();
+            }
        }

-        dLdIn = sameDiff.grad(INPUT_KEY).getArr();
+        dLdIn = grads.get(INPUT_KEY);
        }

        //Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
        sameDiff.clearPlaceholders(true);
        sameDiff.clearOpInputs();

-        return new Pair<>(g, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn));  //TODO OPTIMIZE THIS
+        Pair<Gradient,INDArray> p = new Pair<>(g, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn));     //TODO OPTIMIZE THIS
+        if(dLdIn.closeable())
+            dLdIn.close();
+        return p;
    }

    /**Returns the parameters of the neural network as a flattened row vector

@@ -297,7 +301,7 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
            sameDiff.associateArrayWithVariable(arr, sameDiff.getVariable(e.getKey()));
        }

-        this.outputKey = layerOutput.getVarName();
+        this.outputKey = layerOutput.name();
        }
    }
@@ -66,13 +66,6 @@ namespace nd4j {
#endif

#ifdef __CUDABLAS__
-        BlasVersionHelper ver;
-        _blasMajorVersion = ver._blasMajorVersion;
-        _blasMinorVersion = ver._blasMinorVersion;
-        _blasPatchVersion = ver._blasPatchVersion;
-        printf("ND4J CUDA build version: %i.%i.%i\n", _blasMajorVersion, _blasMinorVersion, _blasPatchVersion);
-        fflush(stdout);
-
        int devCnt = 0;
        cudaGetDeviceCount(&devCnt);
        auto devProperties = new cudaDeviceProp[devCnt];

@@ -83,10 +76,12 @@ namespace nd4j {
            //cudaDeviceSetLimit(cudaLimitStackSize, 4096);
            Pair p(devProperties[i].major, devProperties[i].minor);
            _capabilities.emplace_back(p);
-            printf("CUDA device %i: [%s]; cc: [%i.%i]; Total memory: [%lld];\n", i, devProperties[i].name, devProperties[i].major, devProperties[i].minor, (Nd4jLong) devProperties[i].totalGlobalMem);
        }
-        fflush(stdout);
+
+        BlasVersionHelper ver;
+        _blasMajorVersion = ver._blasMajorVersion;
+        _blasMinorVersion = ver._blasMinorVersion;
+        _blasPatchVersion = ver._blasPatchVersion;

        cudaSetDevice(0);
        delete[] devProperties;

@@ -203,6 +198,18 @@ namespace nd4j {
#endif
    }

+    int Environment::blasMajorVersion(){
+        return _blasMajorVersion;
+    }
+
+    int Environment::blasMinorVersion(){
+        return _blasMinorVersion;
+    }
+
+    int Environment::blasPatchVersion(){
+        return _blasPatchVersion;
+    }
+
    nd4j::Environment *nd4j::Environment::_instance = 0;

}

@@ -97,6 +97,10 @@ namespace nd4j{

    bool isCPU();

+    int blasMajorVersion();
+    int blasMinorVersion();
+    int blasPatchVersion();
+
    std::vector<Pair>& capabilities();
};
}
@@ -66,8 +66,10 @@ CUSTOM_OP_IMPL(deconv2d, 2, 1, false, 0, 9) {
    if(!isNCHW)
        output = new NDArray(output->permute({0, 3, 1, 2}));                   // [bS, oH, oW, oC] -> [bS, oC, oH, oW]

-    if(isSameMode)                       // SAME
-        ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
+    if(isSameMode){                      // SAME
+        //Note: we're intentionally swapping iH and oH, to calculate the padding for a "normal" conv (not deconv) forward pass
+        ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW);
+    }

    NDArray columns(input->ordering(), {bS, oC, kH, kW, iH, iW}, input->dataType(), block.launchContext());
@@ -442,22 +442,12 @@ public abstract class DifferentialFunction {
        setInstanceId();
        if(sameDiff != null) {
            sameDiff.addArgsFor(args, this);
-            for (int i = 0; i < args.length; i++) {
-                if (args[i].isPlaceHolder()) {
-                    sameDiff.addPropertyToResolve(this, args[i].getVarName());
-                }
-            }
        }
    }

    public void replaceArg(int i, SDVariable newArg){
        if(sameDiff != null){
            sameDiff.replaceArgFor(i, newArg, this);
-            if(args()[i].isPlaceHolder() && !newArg.isPlaceHolder()){
-                sameDiff.removePropertyToResolve(this, args()[i].getVarName());
-            } else if(!args()[i].isPlaceHolder() && newArg.isPlaceHolder()){
-                sameDiff.addPropertyToResolve(this, newArg.getVarName());
-            }
        }
    }
@@ -483,7 +473,7 @@ public abstract class DifferentialFunction {
        SDVariable[] outputVars = outputVariables();
        String[] out = new String[outputVars.length];
        for( int i=0; i<out.length; i++ ){
-            out[i] = outputVars[i].getVarName();
+            out[i] = outputVars[i].name();
        }
        return out;
    }
@@ -538,69 +528,11 @@ public abstract class DifferentialFunction {
        SDVariable[] args = args();
        String[] out = new String[args.length];
        for( int i=0; i<args.length; i++ ){
-            out[i] = args[i].getVarName();
+            out[i] = args[i].name();
        }
        return out;
    }

-    /**
-     * Resolve properties and arguments right before execution of
-     * this operation.
-     *
-     * @deprecated Will be removed in the future. Ops should support array arguments. Should not bs used or overridden.
-     */
-    @Deprecated
-    public final void resolvePropertiesFromSameDiffBeforeExecution() {
-        val properties = sameDiff.propertiesToResolveForFunction(this);
-        val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(this);
-        val currentFields = this.propertiesForFunction();
-
-        for(val property : properties) {
-            //property maybe a variable which is only an array
-            //just skip if this is the case
-            if(!fields.containsKey(property))
-                continue;
-
-            val var = sameDiff.getVarNameForFieldAndFunction(this,property);
-            if(var == null)
-                continue; //Rarely (like Conv2D) properties will be optional. For example kH/kW args will be inferred from weight shape
-            val fieldType = fields.get(property);
-            val varArr = sameDiff.getArrForVarName(var);
-            //already defined
-            if(currentFields.containsKey(property)) {
-                continue;
-            }
-
-            /**
-             * Possible cause:
-             * Might be related to output name alignment.
-             *
-             */
-            if(varArr == null) {
-                throw new ND4JIllegalStateException("Unable to set null array!");
-            }
-
-            if(fieldType.getType().equals(int[].class)) {
-                setValueFor(fieldType,varArr.data().asInt());
-            }
-
-            else if(fieldType.equals(double[].class)) {
-                setValueFor(fieldType,varArr.data().asDouble());
-            }
-
-            else if(fieldType.equals(int.class)) {
-                setValueFor(fieldType,varArr.getInt(0));
-            }
-
-            else if(fieldType.equals(double.class)) {
-                setValueFor(fieldType,varArr.getDouble(0));
-            }
-
-        }
-
-    }
-
    /**
     * Return the first argument
     * @return
@@ -639,13 +571,12 @@ public abstract class DifferentialFunction {

                SDVariable gradVar = f().add(grad, vals.get(i));
                vals.set(i, gradVar);
-                sameDiff.setGradientForVariableName(var.getVarName(), gradVar);
+                sameDiff.setGradientForVariableName(var.name(), gradVar);
            } else {
                SDVariable gradVar = vals.get(i);

-                sameDiff.updateVariableNameAndReference(gradVar,var.getVarName() + "-grad");
-                sameDiff.setGradientForVariableName(var.getVarName(), gradVar);
-                sameDiff.setForwardVariableForVarName(gradVar.getVarName(),var);
+                sameDiff.updateVariableNameAndReference(gradVar,var.name() + "-grad");
+                sameDiff.setGradientForVariableName(var.name(), gradVar);

            }
        }
@@ -184,7 +184,6 @@ import org.nd4j.linalg.api.ops.impl.shape.bp.StridedSliceBp;
 import org.nd4j.linalg.api.ops.impl.shape.bp.TileBp;
 import org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation;
 import org.nd4j.linalg.api.ops.impl.summarystats.Variance;
-import org.nd4j.linalg.api.ops.impl.transforms.Constant;
 import org.nd4j.linalg.api.ops.impl.transforms.Pad;
 import org.nd4j.linalg.api.ops.impl.transforms.ReluLayer;
 import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;

@@ -352,12 +351,6 @@ public class DifferentialFunctionFactory {
        }
    }

-    public Constant val(SDVariable iX) {
-        return new Constant(sameDiff(), iX,
-                iX.getShape());
-    }
-
    public ExternalErrorsFunction externalErrors(SDVariable... inputs) {
        return externalErrors(null, inputs);
    }
@@ -384,10 +377,6 @@ public class DifferentialFunctionFactory {
        return new OnesLike(name, sameDiff(), input, dataType).outputVariable();
    }

-    public SDVariable constant(SDVariable input, long... shape) {
-        return new Constant(sameDiff(), input, (shape != null && shape.length > 0 ? shape : null)).outputVariable();
-    }
-
    public SDVariable linspace(SDVariable lower, SDVariable upper, SDVariable count, DataType dt) {
        return new org.nd4j.linalg.api.ops.impl.shape.Linspace(sameDiff(), lower, upper, count, dt).outputVariable();
    }
@@ -1056,7 +1045,7 @@ public class DifferentialFunctionFactory {

    public SDVariable gradientBackwardsMarker(SDVariable iX) {
-        return new GradientBackwardsMarker(sameDiff(), iX, sameDiff.scalar(iX.getVarName() + "-pairgrad", 1.0)).outputVariable();
+        return new GradientBackwardsMarker(sameDiff(), iX, sameDiff.scalar(iX.name() + "-pairgrad", 1.0)).outputVariable();
    }

    public SDVariable abs(SDVariable iX) {
@@ -178,7 +178,7 @@ public class ListenerEvaluations {
     * @param evaluations The evaluations to run
     */
    public Builder trainEvaluation(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations) {
-        return trainEvaluation(variable.getVarName(), labelIndex, evaluations);
+        return trainEvaluation(variable.name(), labelIndex, evaluations);
    }

    /**

@@ -202,7 +202,7 @@ public class ListenerEvaluations {
     * @param evaluations The evaluations to run
     */
    public Builder validationEvaluation(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations) {
-        return validationEvaluation(variable.getVarName(), labelIndex, evaluations);
+        return validationEvaluation(variable.name(), labelIndex, evaluations);
    }

    /**
@@ -167,7 +167,7 @@ public class ListenerVariables {
        String[] names = new String[variables.length];

        for (int i = 0; i < variables.length; i++)
-            names[i] = variables[i].getVarName();
+            names[i] = variables[i].name();

        return requireVariables(op, names);
    }
@@ -226,7 +226,7 @@ public class UIListener extends BaseListener {
        List<SDVariable> sdVars = sd.variables();
        List<String> varNames = new ArrayList<>(sdVars.size());
        for(SDVariable v : sdVars){
-            varNames.add(v.getVarName());
+            varNames.add(v.name());
        }

        if(varNames.size() != vars.size() || !varNames.containsAll(vars)){
@@ -91,7 +91,7 @@ public class EvaluationRecord {
     * @param param The target param/variable
     */
    public List<IEvaluation> evaluations(SDVariable param) {
-        return evaluations(param.getVarName());
+        return evaluations(param.name());
    }

    /**

@@ -105,7 +105,7 @@ public class EvaluationRecord {
     * Get the evaluation for param at the specified index
     */
    public IEvaluation evaluation(SDVariable param, int index) {
-        return evaluation(param.getVarName(), index);
+        return evaluation(param.name(), index);
    }

    /**

@@ -132,7 +132,7 @@ public class EvaluationRecord {
     * @param param The target param/variable
     */
    public <T extends IEvaluation> T evaluation(SDVariable param) {
-        return evaluation(param.getVarName());
+        return evaluation(param.name());
    }

    /**

@@ -174,7 +174,7 @@ public class EvaluationRecord {
     * @param evalClass The type of evaluation to look for
     */
    public <T extends IEvaluation<T>> T evaluation(SDVariable param, Class<T> evalClass) {
-        return evaluation(param.getVarName(), evalClass);
+        return evaluation(param.name(), evalClass);
    }

    /**

@@ -209,7 +209,7 @@ public class EvaluationRecord {
     * @param metric The metric to calculate
     */
    public double getValue(SDVariable param, IMetric metric) {
-        return getValue(param.getVarName(), metric);
+        return getValue(param.name(), metric);
    }

    /**

@@ -235,7 +235,7 @@ public class EvaluationRecord {
     * @param metric The metric to calculate
     */
    public double getValue(SDVariable param, int index, IMetric metric) {
-        return getValue(param.getVarName(), index, metric);
+        return getValue(param.name(), index, metric);
    }

}
@@ -125,7 +125,7 @@ public class History {
     * Only works if there is only one evaluation with the given metric for param
     */
    public List<Double> trainingEval(SDVariable param, IMetric metric){
-        return trainingEval(param.getVarName(), metric);
+        return trainingEval(param.name(), metric);
    }

    /**

@@ -149,7 +149,7 @@ public class History {
     * Index determines the evaluation used not the epoch's results to return.
     */
    public List<Double> trainingEval(SDVariable param, int index, IMetric metric){
-        return trainingEval(param.getVarName(), index, metric);
+        return trainingEval(param.name(), index, metric);
    }

    /**

@@ -184,7 +184,7 @@ public class History {
     * Only works if there is only one evaluation for param.
     */
    public List<IEvaluation> trainingEval(SDVariable param){
-        return trainingEval(param.getVarName());
+        return trainingEval(param.name());
    }

    /**

@@ -208,7 +208,7 @@ public class History {
     * Index determines the evaluation used not the epoch's results to return.
     */
    public List<IEvaluation> trainingEval(SDVariable param, int index){
-        return trainingEval(param.getVarName(), index);
+        return trainingEval(param.name(), index);
    }

    /**

@@ -230,7 +230,7 @@ public class History {
     * Only works if there is only one evaluation with the given metric for param
     */
    public List<Double> validationEval(SDVariable param, IMetric metric){
-        return validationEval(param.getVarName(), metric);
+        return validationEval(param.name(), metric);
    }

    /**

@@ -254,7 +254,7 @@ public class History {
     * Index determines the evaluation used not the epoch's results to return.
     */
    public List<Double> validationEval(SDVariable param, int index, IMetric metric){
-        return validationEval(param.getVarName(), index, metric);
+        return validationEval(param.name(), index, metric);
    }

    /**

@@ -289,7 +289,7 @@ public class History {
     * Only works if there is only one evaluation for param.
     */
    public List<IEvaluation> validationEval(SDVariable param){
-        return validationEval(param.getVarName());
+        return validationEval(param.name());
    }

    /**

@@ -313,7 +313,7 @@ public class History {
     * Index determines the evaluation used not the epoch's results to return.
     */
    public List<IEvaluation> validationEval(SDVariable param, int index){
-        return validationEval(param.getVarName(), index);
+        return validationEval(param.name(), index);
    }

    /**
@@ -116,7 +116,7 @@ public class LossCurve {
     * Return all mean loss values for a given variable
     */
    public float[] meanLoss(@NonNull SDVariable loss){
-        return meanLoss(loss.getVarName());
+        return meanLoss(loss.name());
    }

    /**

@@ -143,7 +143,7 @@ public class LossCurve {
     * See {@link #meanLoss(int)}
     */
    public float meanLoss(@NonNull SDVariable loss, int epoch){
-        return meanLoss(loss.getVarName(), epoch);
+        return meanLoss(loss.name(), epoch);
    }

    /**

@@ -162,7 +162,7 @@ public class LossCurve {
     * Return the mean loss value for a given variable on the last epoch.
     */
    public float lastMeanLoss(@NonNull SDVariable loss){
-        return lastMeanLoss(loss.getVarName());
+        return lastMeanLoss(loss.name());
    }

    /**

@@ -189,7 +189,7 @@ public class LossCurve {
     * A positive delta means the loss is increasing, and a negative delta means it is decreasing.
     */
    public double lastMeanDelta(SDVariable loss){
-        return lastMeanDelta(loss.getVarName());
+        return lastMeanDelta(loss.name());
    }

    /**
@@ -59,10 +59,6 @@ public class SDVariable implements Serializable {
    @Setter
    protected VariableType variableType;

-    @Getter
-    @Setter
-    protected WeightInitScheme weightInitScheme;
-
    @Setter(AccessLevel.NONE)
    protected long[] shape;

@@ -75,9 +71,7 @@ public class SDVariable implements Serializable {
    // autogen_tag::sdvars::start

-    public SDVariable(@NonNull String varName, @NonNull VariableType varType, @NonNull SameDiff sameDiff, long[] shape, DataType dataType, WeightInitScheme weightInitScheme){
-        Preconditions.checkState(weightInitScheme == null || varType == VariableType.VARIABLE, "Weight initalization schemes can only be applied to VARIABLE type" +
-                " SDVariables - variable \"%s\" is of type %s but was provided a weight initialization scheme %s", varName, varType, weightInitScheme);
+    public SDVariable(@NonNull String varName, @NonNull VariableType varType, @NonNull SameDiff sameDiff, long[] shape, DataType dataType){
        Preconditions.checkState(dataType != DataType.UNKNOWN, "Unknown datatype is not allowed for SDVariables (variable name: %s)", varName);

        varName = sameDiff.generateNewVarName(varName, 0, true);
@@ -86,10 +80,25 @@ public class SDVariable implements Serializable {
        this.varName = varName;
        this.variableType = varType;
        this.dataType = dataType;
-        this.weightInitScheme = weightInitScheme;
        this.shape = shape;
    }

+    /**
+     * Get the name of the SDVariable
+     * @return Name of the variable
+     */
+    public String name(){
+        return varName;
+    }
+
+    /**
+     * @deprecated Use {@link #name()}
+     */
+    @Deprecated
+    public String getVarName(){
+        return name();
+    }
+
    /**
     * Returns true if this variable is a place holder
     * @return
@@ -102,30 +111,6 @@ public class SDVariable implements Serializable {
        return variableType == VariableType.CONSTANT;
    }

-    /**
-     * Allocate and return a new array
-     * based on the vertex id and weight initialization.
-     * @return the allocated array
-     */
-    public INDArray storeAndAllocateNewArray() {
-        Preconditions.checkState(variableType == VariableType.VARIABLE, "Unable to allocate and store array for variable of type %s: only" +
-                " VARIABLE type variables can be initialized using this method", variableType);
-
-        if(!sameDiff.arrayAlreadyExistsForVarName(varName)){
-            long[] shape = getShape();
-            INDArray arr = getWeightInitScheme().create(dataType(), shape);
-            sameDiff.associateArrayWithVariable(arr, this);
-            if(log.isTraceEnabled()){
-                log.trace("Generated and stored new array for variable \"{}\": shape {}", getVarName(), Arrays.toString(arr.shape()));
-            }
-            return arr;
-        }
-
-        //Variable type SDVariables: shape should never change (i.e., these are params in the net!)
-        INDArray ret = getArr();
-        return ret;
-    }
-
    /**
     * A getter for the allocated ndarray with this {@link SDVariable}.
     *
@@ -155,30 +140,14 @@ public class SDVariable implements Serializable {
    public INDArray getArr(boolean enforceExistence){
        if(sameDiff.arrayAlreadyExistsForVarName(getVarName()))
            return sameDiff.getArrForVarName(getVarName());

        if(variableType == VariableType.ARRAY){
            throw new UnsupportedOperationException("Cannot get array for ARRAY type SDVariable - use SDVariable.exec or SameDiff.output instead");
        }

-        //initialize value if it's actually a scalar constant (zero or 1 typically...)
-        if(variableType == VariableType.VARIABLE && weightInitScheme != null && shape != null){
-            INDArray arr = weightInitScheme.create(dataType, shape);
-            sameDiff.associateArrayWithVariable(arr, this);
-            if(log.isTraceEnabled()){
-                log.trace("getArr() for variable \"{}\" allocated new array: shape {}", getVarName(), Arrays.toString(getShape()));
-            }
-            return arr;
-        } else if(sameDiff.getShapeForVarName(getVarName()) == null) {
-            if (enforceExistence) {
-                throw new IllegalStateException("Cannot get array for SDVariable \"" + getVarName() + "\": no array has" +
-                        " been defined, and array shape cannot be calculated");
-            }
-            if(log.isTraceEnabled()){
-                log.trace("SDVariable.getArr(): could not get array for variable {}: shape is null", getVarName());
-            }
-            return null;
-        }
-        return sameDiff.getArrForVarName(getVarName());
+        INDArray ret = sameDiff.getArrForVarName(getVarName());
+        if(enforceExistence && ret == null){
+            throw new IllegalStateException("No array exists for variable \"" + name() + "\"");
+        }
+        return ret;
    }
@@ -215,21 +184,13 @@ public class SDVariable implements Serializable {
     * @return Shape of the variable
     */
    public long[] getShape() {
-        if (variableType == VariableType.PLACEHOLDER && getArr() == null) {
-            if (shape != null)
+        if (variableType == VariableType.PLACEHOLDER ) {
                return shape;
-            else
-                return new long[0];
+        } else if(variableType == VariableType.VARIABLE || variableType == VariableType.CONSTANT){
+            return getArr().shape();
        }

-        long[] initialShape = sameDiff.getShapeForVarName(getVarName());
-        if(initialShape == null && variableType != VariableType.ARRAY) {
-            val arr = getArr();
-            if(arr != null)
-                return arr.shape();
-        }
-
-        return initialShape;
+        return null;
    }

    public void setShape(long... shape){
@@ -1488,8 +1449,8 @@ public class SDVariable implements Serializable {
     * @return
     */
    public INDArray eval() {
-        sameDiff.exec(null, getVarName());
-        return getArr();
+        Map<String,INDArray> m = sameDiff.output((Map<String,INDArray>)null, name());
+        return m.get(name());
    }

@@ -1498,8 +1459,8 @@ public class SDVariable implements Serializable {
     * @return
     */
    public INDArray eval(Map<String, INDArray> placeholders) {
-        sameDiff.exec(placeholders, getVarName());
-        return getArr();
+        Map<String,INDArray> m = sameDiff.output(placeholders, name());
+        return m.get(name());
    }

@@ -1519,7 +1480,7 @@ public class SDVariable implements Serializable {
     */
    public void addControlDependency(SDVariable controlDependency){
        Variable vThis = sameDiff.getVariables().get(getVarName());
-        Variable vCD = sameDiff.getVariables().get(controlDependency.getVarName());
+        Variable vCD = sameDiff.getVariables().get(controlDependency.name());

        //If possible: add control dependency on ops
        if(vThis.getOutputOfOp() != null && vCD.getOutputOfOp() != null ){

@@ -1729,7 +1690,6 @@ public class SDVariable implements Serializable {
        SDVariable v = new SDVariable();
        v.varName = varName;
        v.variableType = variableType;
-        v.weightInitScheme = weightInitScheme;
        v.shape = shape == null ? null : shape.clone();
        v.dataType = dataType;
        v.sameDiff = sd;
File diff suppressed because it is too large
@@ -440,7 +440,7 @@ public class TrainingConfig {
     * @param evaluations The evaluations to run
     */
    public Builder trainEvaluation(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations){
-        return trainEvaluation(variable.getVarName(), labelIndex, evaluations);
+        return trainEvaluation(variable.name(), labelIndex, evaluations);
    }

    /**

@@ -468,7 +468,7 @@ public class TrainingConfig {
     * @param evaluations The evaluations to run
     */
    public Builder validationEvaluation(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations){
-        return validationEvaluation(variable.getVarName(), labelIndex, evaluations);
+        return validationEvaluation(variable.name(), labelIndex, evaluations);
    }

    /**
@@ -73,7 +73,7 @@ public class BatchOutputConfig {
    public BatchOutputConfig output(@NonNull SDVariable... outputs){
        String[] outNames = new String[outputs.length];
        for(int i = 0 ; i < outputs.length ; i++){
-            outNames[i] = outputs[i].getVarName();
+            outNames[i] = outputs[i].name();
        }

        return output(outNames);

@@ -104,7 +104,7 @@ public class BatchOutputConfig {
     * See {@link #input(String, INDArray)}
     */
    public BatchOutputConfig input(@NonNull SDVariable variable, @NonNull INDArray placeholder){
-        return input(variable.getVarName(), placeholder);
+        return input(variable.name(), placeholder);
    }

    /**

@@ -132,19 +132,35 @@ public class BatchOutputConfig {
        return this;
    }

+    /**
+     * @deprecated Use {@link #output()}
+     */
+    @Deprecated
+    public Map<String, INDArray> exec() {
+        return output();
+    }
+
    /**
     * Do inference and return the results
     */
-    public Map<String, INDArray> exec(){
+    public Map<String,INDArray> output(){
        return sd.output(placeholders, listeners, outputs.toArray(new String[0]));
    }

+    /**
+     * @deprecated Use {@link #outputSingle()}
+     */
+    @Deprecated
+    public INDArray execSingle() {
+        return outputSingle();
+    }
+
    /**
     * Do inference and return the results for the single output
     *
     * Only works if exactly one output is specified
     */
-    public INDArray execSingle(){
+    public INDArray outputSingle(){
        Preconditions.checkState(outputs.size() == 1,
                "Can only use execSingle() when exactly one output is specified, there were %s", outputs.size());
        return exec().get(outputs.get(0));
@@ -81,7 +81,7 @@ public class EvaluationConfig {
     * See {@link #evaluate(String, int, IEvaluation[])}
     */
    public EvaluationConfig evaluate(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations){
-        return evaluate(variable.getVarName(), labelIndex, evaluations);
+        return evaluate(variable.name(), labelIndex, evaluations);
    }

    /**

@@ -106,7 +106,7 @@ public class EvaluationConfig {
     * See {@link #evaluate(String, IEvaluation[])}
     */
    public EvaluationConfig evaluate(@NonNull SDVariable variable, @NonNull IEvaluation... evaluations){
-        return evaluate(variable.getVarName(), evaluations);
+        return evaluate(variable.name(), evaluations);
    }

    /**

@@ -129,7 +129,7 @@ public class EvaluationConfig {
     * See {@link #labelIndex(String, int)}
     */
    public EvaluationConfig labelIndex(@NonNull SDVariable variable, int labelIndex){
-        return labelIndex(variable.getVarName(), labelIndex);
+        return labelIndex(variable.name(), labelIndex);
    }

    /**
@@ -75,7 +75,7 @@ public class OutputConfig {
    public OutputConfig output(@NonNull SDVariable... outputs) {
        String[] outNames = new String[outputs.length];
        for (int i = 0; i < outputs.length; i++) {
-            outNames[i] = outputs[i].getVarName();
+            outNames[i] = outputs[i].name();
        }

        return output(outNames);
@@ -204,10 +204,10 @@ public abstract class AbstractSession<T, O> {
            VariableType vt = v.getVariableType();
            if (vt == VariableType.VARIABLE || vt == VariableType.CONSTANT) {
                ExecType et = vt == VariableType.VARIABLE ? ExecType.VARIABLE : ExecType.CONSTANT;
-                ExecStep es = new ExecStep(et, v.getVarName(), new FrameIter(OUTER_FRAME, 0, null));
+                ExecStep es = new ExecStep(et, v.name(), new FrameIter(OUTER_FRAME, 0, null));
                dt.addDependency(es, start);

-                Variable var = sameDiff.getVariables().get(v.getVarName());
+                Variable var = sameDiff.getVariables().get(v.name());
                if (var.getControlDeps() != null) {
                    addVarControlDeps(es, var);     //Before this variable can be considered available for use, we need specified op to be executed
                }

@@ -668,11 +668,11 @@ public abstract class AbstractSession<T, O> {
        Variable v = sameDiff.getVariables().get(varName);
        VariableType vt = v.getVariable().getVariableType();
        if (vt == VariableType.VARIABLE) {
-            return new ExecStep(ExecType.VARIABLE, v.getVariable().getVarName(), new FrameIter(OUTER_FRAME, 0, null));
+            return new ExecStep(ExecType.VARIABLE, v.getVariable().name(), new FrameIter(OUTER_FRAME, 0, null));
        } else if (vt == VariableType.PLACEHOLDER) {
-            return new ExecStep(ExecType.PLACEHOLDER, v.getVariable().getVarName(), new FrameIter(OUTER_FRAME, 0, null));
+            return new ExecStep(ExecType.PLACEHOLDER, v.getVariable().name(), new FrameIter(OUTER_FRAME, 0, null));
        } else if (vt == VariableType.CONSTANT) {
-            return new ExecStep(ExecType.CONSTANT, v.getVariable().getVarName(), new FrameIter(OUTER_FRAME, 0, null));
+            return new ExecStep(ExecType.CONSTANT, v.getVariable().name(), new FrameIter(OUTER_FRAME, 0, null));
        } else {
            //Array type. Must be output of an op
            String outOfOp = v.getOutputOfOp();
@@ -98,9 +98,9 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
        //TODO we shouldn't be clearing this on every single iteration, in 99.5% of cases variables will be same as last iteration...
        for (SDVariable v : sameDiff.variables()) {
            if (v.getVariableType() == VariableType.CONSTANT) {
-                arrayUseTracker.addDependency(v.getArr(), new ConstantDep(v.getVarName()));
+                arrayUseTracker.addDependency(v.getArr(), new ConstantDep(v.name()));
            } else if (v.getVariableType() == VariableType.VARIABLE) {
-                arrayUseTracker.addDependency(v.getArr(), new VariableDep(v.getVarName()));
+                arrayUseTracker.addDependency(v.getArr(), new VariableDep(v.name()));
            }
        }

@@ -484,7 +484,7 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {

        if (op instanceof TensorArray) {
            //Create a TensorArray
-            VarId vid = outputFrameIter.toVarId(op.outputVariable().getVarName());
+            VarId vid = outputFrameIter.toVarId(op.outputVariable().name());
            Preconditions.checkState(!tensorArrays.containsKey(vid), "TensorArray already exists for %s when executing TensorArrayV3", vid);
            tensorArrays.put(vid, new ArrayList<INDArray>());

@@ -504,18 +504,18 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
            SDVariable inTensorArray = op.arg(0);   //Dummy variable representing the tensor array

            //Work out the frame/iteration:
-            VarId v = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false));
+            VarId v = (opInputs == null ? null : lookup(inTensorArray.name(), opInputs, false));
            if (v == null && allIterInputs != null) {
-                v = lookup(inTensorArray.getVarName(), allIterInputs, false);
+                v = lookup(inTensorArray.name(), allIterInputs, false);
            }

-            Preconditions.checkState(v != null, "Could not find input %s", inTensorArray.getVarName());
+            Preconditions.checkState(v != null, "Could not find input %s", inTensorArray.name());

-            while (sameDiff.getVariableOutputOp(inTensorArray.getVarName()) instanceof Enter) {
+            while (sameDiff.getVariableOutputOp(inTensorArray.name()) instanceof Enter) {
                //Handle the Enter case: this is like TensorArray -> Enter -> TensorArrayRead
                //TODO also TensorArrayWrite, scatter, etc??
-                inTensorArray = sameDiff.getVariableOutputOp(inTensorArray.getVarName()).arg();
-                v = v.getParentFrame().toVarId(inTensorArray.getVarName());
+                inTensorArray = sameDiff.getVariableOutputOp(inTensorArray.name()).arg();
+                v = v.getParentFrame().toVarId(inTensorArray.name());
            }

            List<INDArray> list = getTensorArrays().get(v);
@@ -528,31 +528,31 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
            //TensorArrayWrite - also has a scalar 0.0 that it returns...
            SDVariable inTensorArray = op.arg(0);   //Dummy variable representing the tensor array
            //Work out the varid (frame/iteration) of the tensor array:
-            VarId tArr = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false));
+            VarId tArr = (opInputs == null ? null : lookup(inTensorArray.name(), opInputs, false));
            if (tArr == null && allIterInputs != null) {
-                tArr = lookup(inTensorArray.getVarName(), allIterInputs, false);
+                tArr = lookup(inTensorArray.name(), allIterInputs, false);
            }

-            Preconditions.checkState(tArr != null, "Could not find input %s", inTensorArray.getVarName());
+            Preconditions.checkState(tArr != null, "Could not find input %s", inTensorArray.name());

-            while (sameDiff.getVariableOutputOp(inTensorArray.getVarName()) instanceof Enter) {
+            while (sameDiff.getVariableOutputOp(inTensorArray.name()) instanceof Enter) {
                //Handle the Enter case: this is like TensorArray -> Enter -> TensorArrayWrite
                //TODO also TensorArrayScatter, etc??
-                inTensorArray = sameDiff.getVariableOutputOp(inTensorArray.getVarName()).arg();
-                tArr = tArr.getParentFrame().toVarId(inTensorArray.getVarName());
+                inTensorArray = sameDiff.getVariableOutputOp(inTensorArray.name()).arg();
+                tArr = tArr.getParentFrame().toVarId(inTensorArray.name());
            }

            //Input 0 is the TensorArray (or dummy variable that represents it) - but sometimes Enter, in TensorArray -> Enter -> TensorARrayRead
            //Input 1 is the index
            //Input 2 is the value to write

-            String idxName = op.arg(1).getVarName();
+            String idxName = op.arg(1).name();
            SDVariable idxSDV = sameDiff.getVariable(idxName);
            INDArray idxArr = getArray(idxSDV, opInputs, allIterInputs);
            Preconditions.checkState(idxArr.isScalar(), "Index variable ID for TensorArrayWrite should be a scalar, got %ndShape", idxArr);
            int idx = idxArr.getInt(0);

-            String inName = op.arg(2).getVarName();
+            String inName = op.arg(2).name();
            SDVariable inSDV = sameDiff.getVariable(inName);
            INDArray arr = getArray(inSDV, opInputs, allIterInputs);
            Preconditions.checkState(arr != null, "Could not find array for %s", inName);

@@ -577,9 +577,9 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
            //Index 0 is the TensorArray (or dummy variable that represents it)
            SDVariable inTensorArray = op.arg(0);   //Dummy variable representing the tensor array
            //Work out the varid (frame/iteration) of the tensor array:
-            VarId tArr = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false));
+            VarId tArr = (opInputs == null ? null : lookup(inTensorArray.name(), opInputs, false));
            if (tArr == null && allIterInputs != null) {
-                tArr = lookup(inTensorArray.getVarName(), allIterInputs, false);
+                tArr = lookup(inTensorArray.name(), allIterInputs, false);
            }
            List<INDArray> l = tensorArrays.get(tArr);
            Preconditions.checkState(l != null, "Could not find TensorArray: %s", tArr);

@@ -588,9 +588,9 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
            return new INDArray[]{scalar};
        } else if (op instanceof TensorArrayConcat) {
            SDVariable inTensorArray = op.arg(0);   //Dummy variable representing the tensor array
-            VarId tArr = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false));
+            VarId tArr = (opInputs == null ? null : lookup(inTensorArray.name(), opInputs, false));
            if (tArr == null && allIterInputs != null) {
-                tArr = lookup(inTensorArray.getVarName(), allIterInputs, false);
+                tArr = lookup(inTensorArray.name(), allIterInputs, false);
            }
            List<INDArray> l = tensorArrays.get(tArr);
@@ -605,14 +605,14 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
            //Input 1: the indices (1d integer vector)

            SDVariable inTensorArray = op.arg(0);   //Dummy variable representing the tensor array
-            VarId tArr = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false));
+            VarId tArr = (opInputs == null ? null : lookup(inTensorArray.name(), opInputs, false));
            if (tArr == null && allIterInputs != null) {
-                tArr = lookup(inTensorArray.getVarName(), allIterInputs, false);
+                tArr = lookup(inTensorArray.name(), allIterInputs, false);
            }
            List<INDArray> l = tensorArrays.get(tArr);
            Preconditions.checkState(l != null, "Could not find TensorArray: %s", tArr);

-            String indicesName = op.arg(1).getVarName();
+            String indicesName = op.arg(1).name();
            SDVariable indicesSDV = sameDiff.getVariable(indicesName);
            INDArray idxArr = getArray(indicesSDV, opInputs, allIterInputs);
            Preconditions.checkState(idxArr.isVector(), "Indices variable for TensorArrayGather should be a vector, got %ndShape for %s", idxArr, indicesName);

@@ -644,22 +644,22 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
            //Input 2: The values to scatter

            SDVariable inTensorArray = op.arg(0);   //Dummy variable representing the tensor array
-            TensorArray ta = (TensorArray) sameDiff.getVariableOutputOp(inTensorArray.getVarName());
-            VarId tArr = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false));
+            TensorArray ta = (TensorArray) sameDiff.getVariableOutputOp(inTensorArray.name());
+            VarId tArr = (opInputs == null ? null : lookup(inTensorArray.name(), opInputs, false));
            if (tArr == null && allIterInputs != null) {
-                tArr = lookup(inTensorArray.getVarName(), allIterInputs, false);
+                tArr = lookup(inTensorArray.name(), allIterInputs, false);
            }
            List<INDArray> l = tensorArrays.get(tArr);
            Preconditions.checkState(l != null, "Could not find TensorArray: %s", tArr);

-            String indicesName = op.arg(1).getVarName();
+            String indicesName = op.arg(1).name();
            SDVariable indicesSDV = sameDiff.getVariable(indicesName);
            INDArray idxArr = getArray(indicesSDV, opInputs, allIterInputs);
            Preconditions.checkState(idxArr.isVector(), "Indices variable for TensorArrayScatter should be a vector, got %ndShape for %s", idxArr, indicesName);
            Preconditions.checkState(idxArr.dataType().isIntType(), "Indices variable for TensorArrayScatter should be an integer type, got %s for array %s", idxArr.dataType(), indicesName);
            int[] idxs = idxArr.toIntVector();

-            String valuesName = op.arg(2).getVarName();
+            String valuesName = op.arg(2).name();
            SDVariable valuesSDV = sameDiff.getVariable(valuesName);
            INDArray valuesArr = getArray(valuesSDV, opInputs, allIterInputs);

@@ -697,18 +697,18 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
            //Input 2: the size of each split (1d integer vector)

            SDVariable inTensorArray = op.arg(0);   //Dummy variable representing the tensor array
-            VarId tArr = (opInputs == null ? null : lookup(inTensorArray.getVarName(), opInputs, false));
+            VarId tArr = (opInputs == null ? null : lookup(inTensorArray.name(), opInputs, false));
            if (tArr == null && allIterInputs != null) {
-                tArr = lookup(inTensorArray.getVarName(), allIterInputs, false);
+                tArr = lookup(inTensorArray.name(), allIterInputs, false);
            }
            List<INDArray> l = tensorArrays.get(tArr);
            Preconditions.checkState(l != null, "Could not find TensorArray: %s", tArr);

-            String splitName = op.arg(1).getVarName();
+            String splitName = op.arg(1).name();
            INDArray splitArr = getArray(sameDiff.getVariable(splitName), opInputs, allIterInputs);

-            String sizeName = op.arg(2).getVarName();
+            String sizeName = op.arg(2).name();
            SDVariable sizeSDV = sameDiff.getVariable(sizeName);
            INDArray sizeArr = getArray(sizeSDV, opInputs, allIterInputs);
            Preconditions.checkState(sizeArr.isVector(), "Indices variable for TensorArraySplit should be a vector, got %ndShape for %s", sizeArr, sizeName);
@@ -803,7 +803,7 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
                    VarId vid = lookup(s, opInputs, allIterInputs, true);
                    args[i] = nodeOutputs.get(vid);
                }
-                Preconditions.checkNotNull(args[i], "Could not parameterize op %s: array %s (variable %s) is null", opName, i, v.getVarName());
+                Preconditions.checkNotNull(args[i], "Could not parameterize op %s: array %s (variable %s) is null", opName, i, v.name());
                i++;
            }
        }

@@ -825,7 +825,6 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
            return sdo;
        }

-        df.resolvePropertiesFromSameDiffBeforeExecution();     //TODO This is to be removed
        List<LongShapeDescriptor> outShape = customOp.calculateOutputShape();
        Preconditions.checkState(outShape != null && outShape.size() > 0, "Failed to calculate output shapes for op %s (%s) - no shapes were returned by calculateOutputShape()", customOp.opName(), customOp.getOwnName());
        String[] outNames = df.outputVariablesNames();

@@ -918,7 +917,6 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
                op.setZ(z);
            }
        }
-        df.resolvePropertiesFromSameDiffBeforeExecution();
        }

        return sdo;

@@ -926,12 +924,12 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {

    protected INDArray getArray(SDVariable sdv, Collection<VarId> opInputs, Collection<VarId> allIterInputs) {
-        String n = sdv.getVarName();
+        String n = sdv.name();
        if (sdv.getVariableType() == VariableType.CONSTANT || sdv.getVariableType() == VariableType.VARIABLE) {
            return getConstantOrVariable(n);
        } else {
            VarId inVarId = lookup(n, opInputs, allIterInputs, false);
-            Preconditions.checkState(inVarId != null, "Could not find array for variable %s", sdv.getVarName());
+            Preconditions.checkState(inVarId != null, "Could not find array for variable %s", sdv.name());
            return nodeOutputs.get(inVarId);
        }
    }
@@ -88,9 +88,9 @@ public class TrainingSession extends InferenceSession {
                continue;
            }

-            requiredActivations.add(grad.getVarName());
+            requiredActivations.add(grad.name());

-            gradVarToVarMap.put(grad.getVarName(), s);
+            gradVarToVarMap.put(grad.name(), s);
        }

        //Set up losses
@@ -3266,7 +3266,7 @@ public abstract class SDBaseOps {

        if (cond_result.dataType() != DataType.BOOL)
-            throw new IllegalStateException("Can not use " + cond_result.getVarName() + " as the condition of an While loop, the condition must be a boolean.");
+            throw new IllegalStateException("Can not use " + cond_result.name() + " as the condition of an While loop, the condition must be a boolean.");


        final Set<String> alreadyEntered = Sets.newHashSet();

@@ -3275,7 +3275,7 @@ public abstract class SDBaseOps {
        for(int i = 0 ; i < loopVars.length ; i++){
            SDVariable[] s = f().switchOp(merged[i], cond_result);
            trueSwitches[i] = s[1];
-            alreadyEntered.add(s[1].getVarName());
+            alreadyEntered.add(s[1].name());
            exits[i] = f().exit(s[0]);
        }

@@ -3290,17 +3290,17 @@ public abstract class SDBaseOps {
            @Override
            public SDVariable intercept(SDVariable argument) {

-                if(!declared.contains(argument.getVarName()))
+                if(!declared.contains(argument.name()))
                    return argument;

-                if(alreadyEntered.contains(argument.getVarName()))
+                if(alreadyEntered.contains(argument.name()))
                    return argument;

-                if(done.containsKey(argument.getVarName()))
-                    return done.get(argument.getVarName());
+                if(done.containsKey(argument.name()))
+                    return done.get(argument.name());

                SDVariable e = f().enter(argument, frameName, true);
-                done.put(argument.getVarName(), e);
+                done.put(argument.name(), e);
                return e;
            }
        });
@ -3371,7 +3371,7 @@ public abstract class SDBaseOps {
|
|||
//cleanup partially added block
|
||||
|
||||
for(SDVariable v : sd().getVariablesInScope(ifScope))
|
||||
sd().getVariables().remove(v.getVarName());
|
||||
sd().getVariables().remove(v.name());
|
||||
|
||||
for(SameDiffOp op : sd().getOpsInScope(ifScope)) {
|
||||
for(String in : op.getInputsToOp()){
|
||||
|
@ -3381,7 +3381,7 @@ public abstract class SDBaseOps {
}


throw new IllegalStateException("Can not use " + pred.getVarName()
throw new IllegalStateException("Cannot use " + pred.name()
+ " as the condition of an If statement, the condition must be a boolean.");
}

@ -3394,15 +3394,15 @@ public abstract class SDBaseOps {
public SDVariable intercept(SDVariable argument) {

// if it's declared in the if, we don't care about it
if(!declared.contains(argument.getVarName()))
if(!declared.contains(argument.name()))
return argument;

// if we've already added a switch, move on
if(switches.containsKey(argument.getVarName()))
return switches.get(argument.getVarName())[1];
if(switches.containsKey(argument.name()))
return switches.get(argument.name())[1];

SDVariable[] s = f().switchOp(argument, pred);
switches.put(argument.getVarName(), s);
switches.put(argument.name(), s);
return s[1];
}
});

@ -3410,9 +3410,9 @@ public abstract class SDBaseOps {
|
|||
SDVariable trueOut = trueBody.define(sd());
|
||||
sd().removeArgumentInterceptor();
|
||||
|
||||
if(declared.contains(trueOut.getVarName())) {
|
||||
if(declared.contains(trueOut.name())) {
|
||||
SDVariable[] s = f().switchOp(trueOut, pred);
|
||||
switches.put(trueOut.getVarName(), s);
|
||||
switches.put(trueOut.name(), s);
|
||||
trueOut = s[1];
|
||||
}
|
||||
|
||||
|
@ -3424,15 +3424,15 @@ public abstract class SDBaseOps {
public SDVariable intercept(SDVariable argument) {

// if it's declared in the if, we don't care about it
if(!declared2.contains(argument.getVarName()))
if(!declared2.contains(argument.name()))
return argument;

// if we've already added a switch, move on
if(switches.containsKey(argument.getVarName()))
return switches.get(argument.getVarName())[0];
if(switches.containsKey(argument.name()))
return switches.get(argument.name())[0];

SDVariable[] s = f().switchOp(argument, pred);
switches.put(argument.getVarName(), s);
switches.put(argument.name(), s);
return s[0];
}
});

@ -3440,9 +3440,9 @@ public abstract class SDBaseOps {
|
|||
SDVariable falseOut = falseBody.define(sd());
|
||||
sd().removeArgumentInterceptor();
|
||||
|
||||
if(declared2.contains(falseOut.getVarName())) {
|
||||
if(declared2.contains(falseOut.name())) {
|
||||
SDVariable[] s = f().switchOp(falseOut, pred);
|
||||
switches.put(falseOut.getVarName(), s);
|
||||
switches.put(falseOut.name(), s);
|
||||
falseOut = s[0];
|
||||
}
|
||||
falseScope.close();
|
||||
|
|
|
@ -37,7 +37,7 @@ public class SDValidation {
|
|||
if (v == null)
|
||||
return;
|
||||
if (v.dataType() == DataType.BOOL || v.dataType() == DataType.UTF8)
|
||||
throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to variable \"" + v.getVarName() + "\" with non-numerical data type " + v.dataType());
|
||||
throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to variable \"" + v.name() + "\" with non-numerical data type " + v.dataType());
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -52,7 +52,7 @@ public class SDValidation {
return;
if (v.dataType() == DataType.BOOL || v.dataType() == DataType.UTF8)
throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be a numerical type; got variable \"" +
v.getVarName() + "\" with non-integer data type " + v.dataType());
v.name() + "\" with non-numerical data type " + v.dataType());
}

/**
|
||||
|
@ -65,8 +65,8 @@ public class SDValidation {
|
|||
*/
|
||||
protected static void validateNumerical(String opName, SDVariable v1, SDVariable v2) {
|
||||
if (v1.dataType() == DataType.BOOL || v1.dataType() == DataType.UTF8 || v2.dataType() == DataType.BOOL || v2.dataType() == DataType.UTF8)
|
||||
throw new IllegalStateException("Cannot perform operation \"" + opName + "\" on variables \"" + v1.getVarName() + "\" and \"" +
|
||||
v2.getVarName() + "\" if one or both variables are non-numerical: " + v1.dataType() + " and " + v2.dataType());
|
||||
throw new IllegalStateException("Cannot perform operation \"" + opName + "\" on variables \"" + v1.name() + "\" and \"" +
|
||||
v2.name() + "\" if one or both variables are non-numerical: " + v1.dataType() + " and " + v2.dataType());
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -79,7 +79,7 @@ public class SDValidation {
|
|||
if (v == null)
|
||||
return;
|
||||
if (!v.dataType().isIntType())
|
||||
throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to variable \"" + v.getVarName() + "\" with non-integer data type " + v.dataType());
|
||||
throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to variable \"" + v.name() + "\" with non-integer data type " + v.dataType());
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -94,7 +94,7 @@ public class SDValidation {
|
|||
return;
|
||||
if (!v.dataType().isIntType())
|
||||
throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be an integer type; got variable \"" +
|
||||
v.getVarName() + "\" with non-integer data type " + v.dataType());
|
||||
v.name() + "\" with non-integer data type " + v.dataType());
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -107,7 +107,7 @@ public class SDValidation {
|
|||
if (v == null)
|
||||
return;
|
||||
if (!v.dataType().isFPType())
|
||||
throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to variable \"" + v.getVarName() + "\" with non-floating point data type " + v.dataType());
|
||||
throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to variable \"" + v.name() + "\" with non-floating point data type " + v.dataType());
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -122,7 +122,7 @@ public class SDValidation {
return;
if (!v.dataType().isFPType())
throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be a floating point type; got variable \"" +
v.getVarName() + "\" with non-floating point data type " + v.dataType());
v.name() + "\" with non-floating point data type " + v.dataType());
}

/**
|
||||
|
@ -135,7 +135,7 @@ public class SDValidation {
if (v == null)
return;
if (v.dataType() != DataType.BOOL)
throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to variable \"" + v.getVarName() + "\" with non-boolean point data type " + v.dataType());
throw new IllegalStateException("Cannot apply operation \"" + opName + "\" to variable \"" + v.name() + "\" with non-boolean data type " + v.dataType());
}

/**
|
||||
|
@ -150,7 +150,7 @@ public class SDValidation {
return;
if (v.dataType() != DataType.BOOL)
throw new IllegalStateException("Input \"" + inputName + "\" for operation \"" + opName + "\" must be a boolean variable; got variable \"" +
v.getVarName() + "\" with non-boolean data type " + v.dataType());
v.name() + "\" with non-boolean data type " + v.dataType());
}

/**
|
||||
|
@ -162,8 +162,8 @@ public class SDValidation {
|
|||
*/
|
||||
protected static void validateBool(String opName, SDVariable v1, SDVariable v2) {
|
||||
if (v1.dataType() != DataType.BOOL || v2.dataType() != DataType.BOOL)
|
||||
throw new IllegalStateException("Cannot perform operation \"" + opName + "\" on variables \"" + v1.getVarName() + "\" and \"" +
|
||||
v2.getVarName() + "\" if one or both variables are non-boolean: " + v1.dataType() + " and " + v2.dataType());
|
||||
throw new IllegalStateException("Cannot perform operation \"" + opName + "\" on variables \"" + v1.name() + "\" and \"" +
|
||||
v2.name() + "\" if one or both variables are non-boolean: " + v1.dataType() + " and " + v2.dataType());
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -190,7 +190,7 @@ public class SDValidation {
|
|||
String[] names = new String[vars.length];
|
||||
DataType[] dtypes = new DataType[vars.length];
|
||||
for (int j = 0; j < vars.length; j++) {
|
||||
names[j] = vars[j].getVarName();
|
||||
names[j] = vars[j].name();
|
||||
dtypes[j] = vars[j].dataType();
|
||||
}
|
||||
throw new IllegalStateException("Cannot perform operation \"" + opName + "\" to variables with different datatypes:" +
|
||||
|
|
|
@ -763,7 +763,7 @@ public class FlatBuffersMapper {
|
|||
|
||||
SDVariable[] inputs = node.args();
|
||||
for (SDVariable input : inputs) {
|
||||
String varName = input.getVarName();
|
||||
String varName = input.name();
|
||||
int outIdx;
|
||||
if (sameDiff.getVariables().get(varName).getOutputOfOp() != null) {
|
||||
DifferentialFunction df = sameDiff.getOps().get(sameDiff.getVariables().get(varName).getOutputOfOp()).getOp();
|
||||
|
|
|
@ -69,8 +69,8 @@ public class GraphTransformUtil {
|
|||
// we want to end up with (x -> A -> z)
|
||||
List<DifferentialFunction> allSubGraphFns = sg.allFunctionsInSubgraph();
|
||||
for (int i = 0; i < oldOutputs.size(); i++) {
|
||||
String oldOutVarName = oldOutputs.get(i).getVarName();
|
||||
String newOutVarName = newOutputs.get(i).getVarName();
|
||||
String oldOutVarName = oldOutputs.get(i).name();
|
||||
String newOutVarName = newOutputs.get(i).name();
|
||||
Preconditions.checkState(!oldOutVarName.equals(newOutVarName), "Reusing old variables not yet implemented");
|
||||
|
||||
//Update inputs for ops: if X->opA, and now Y->opA, then X.inputsForOps contains "opA"; Y.inputsForOps should be updated
|
||||
|
@ -133,7 +133,7 @@ public class GraphTransformUtil {
|
|||
//Step 2: Update input variables: if X -> (subgraph) exists, then X.inputsForOp needs to be updated
|
||||
List<SDVariable> inputs = sg.inputs();
|
||||
for (SDVariable v : inputs) {
|
||||
Variable var = sd.getVariables().get(v.getVarName());
|
||||
Variable var = sd.getVariables().get(v.name());
|
||||
if (var.getInputsForOp() != null) {
|
||||
List<String> newInputsForOp = new ArrayList<>(var.getInputsForOp());
|
||||
for (String opName : var.getInputsForOp()) {
|
||||
|
@ -160,7 +160,7 @@ public class GraphTransformUtil {
|
|||
SDVariable[] outputs = df.outputVariables();
|
||||
if (outputs != null) {
|
||||
for (SDVariable v : outputs) {
|
||||
vars.remove(v.getVarName());
|
||||
vars.remove(v.name());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -62,7 +62,7 @@ public class SubGraph {
|
|||
//But suppose same subgraph, but connection y -> a exists; then Y must be an output, because it's used somewhere else
|
||||
List<SDVariable> filteredOutputs = new ArrayList<>(allOutputs.size());
|
||||
for(SDVariable v : allOutputs){
|
||||
Variable var = sameDiff.getVariables().get(v.getVarName());
|
||||
Variable var = sameDiff.getVariables().get(v.name());
|
||||
List<String> inputsFor = var.getInputsForOp();
|
||||
boolean allInSubgraph = true;
|
||||
if(inputsFor != null){
|
||||
|
|
|
@ -77,7 +77,7 @@ public class SubGraphPredicate extends OpPredicate {
|
|||
}
|
||||
|
||||
SDVariable in = inputs[inNum];
|
||||
DifferentialFunction df = sameDiff.getVariableOutputOp(in.getVarName());
|
||||
DifferentialFunction df = sameDiff.getVariableOutputOp(in.name());
|
||||
if (df == null || !e.getValue().matches(sameDiff, df)) {
|
||||
return false;
|
||||
}
|
||||
|
@ -103,7 +103,7 @@ public class SubGraphPredicate extends OpPredicate {
|
|||
for(Map.Entry<Integer,OpPredicate> entry : opInputSubgraphPredicates.entrySet()){
|
||||
OpPredicate p2 = entry.getValue();
|
||||
SDVariable arg = rootFn.arg(entry.getKey());
|
||||
DifferentialFunction df = sd.getVariableOutputOp(arg.getVarName());
|
||||
DifferentialFunction df = sd.getVariableOutputOp(arg.name());
|
||||
if(df != null){
|
||||
childNodes.add(df);
|
||||
|
||||
|
|
|
@ -107,7 +107,7 @@ public class GradCheckUtil {
|
|||
Set<String> fnOutputs = new HashSet<>();
|
||||
for(DifferentialFunction f : sd.ops()){
|
||||
for(SDVariable s : f.outputVariables()){
|
||||
fnOutputs.add(s.getVarName());
|
||||
fnOutputs.add(s.name());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -171,7 +171,7 @@ public class GradCheckUtil {
|
|||
|
||||
Map<String,INDArray> grad = new HashMap<>();
|
||||
for(SDVariable v : sd.variables()){
|
||||
if (fnOutputs.contains(v.getVarName())) {
|
||||
if (fnOutputs.contains(v.name())) {
|
||||
//This is not an input to the graph
|
||||
continue;
|
||||
}
|
||||
|
@ -179,20 +179,20 @@ public class GradCheckUtil {
|
|||
//Skip non-fp variables, or variables that don't impact loss function value
|
||||
continue;
|
||||
}
|
||||
SDVariable g = sd.grad(v.getVarName());
|
||||
SDVariable g = sd.grad(v.name());
|
||||
if(g == null){
|
||||
throw new IllegalStateException("Null gradient variable for \"" + v.getVarName() + "\"");
|
||||
throw new IllegalStateException("Null gradient variable for \"" + v.name() + "\"");
|
||||
}
|
||||
INDArray ga = gm.get(v.getVarName());
|
||||
INDArray ga = gm.get(v.name());
|
||||
if(ga == null){
|
||||
throw new IllegalStateException("Null gradient array encountered for variable: " + v.getVarName());
|
||||
throw new IllegalStateException("Null gradient array encountered for variable: " + v.name());
|
||||
}
|
||||
if(!Arrays.equals(v.getArr().shape(), ga.shape())){
|
||||
throw new IllegalStateException("Gradient shape does not match variable shape for variable \"" +
|
||||
v.getVarName() + "\": shape " + Arrays.toString(v.getArr().shape()) + " vs. gradient shape " +
|
||||
v.name() + "\": shape " + Arrays.toString(v.getArr().shape()) + " vs. gradient shape " +
|
||||
Arrays.toString(ga.shape()));
|
||||
}
|
||||
grad.put(v.getVarName(), ga.dup());
|
||||
grad.put(v.name(), ga.dup());
|
||||
}
|
||||
|
||||
//Validate gradients for each variable:
|
||||
|
@ -201,25 +201,25 @@ public class GradCheckUtil {
|
|||
double maxError = 0.0;
|
||||
Random r = new Random(12345);
|
||||
for(SDVariable s : sd.variables()){
|
||||
if (fnOutputs.contains(s.getVarName()) || !s.dataType().isFPType()) {
|
||||
if (fnOutputs.contains(s.name()) || !s.dataType().isFPType()) {
|
||||
//This is not an input to the graph, or is not a floating point input (so can't be gradient checked)
|
||||
continue;
|
||||
}
|
||||
|
||||
if(skipVariables != null && skipVariables.contains(s.getVarName())){
|
||||
log.info("Grad check: skipping variable \"{}\"", s.getVarName());
|
||||
if(skipVariables != null && skipVariables.contains(s.name())){
|
||||
log.info("Grad check: skipping variable \"{}\"", s.name());
|
||||
continue;
|
||||
}
|
||||
|
||||
if(s.dataType() != DataType.DOUBLE){
|
||||
log.warn("DataType for variable {} is not double (is: {}) may cause precision issues in gradient checks", s.getVarName(), s.dataType());
|
||||
log.warn("DataType for variable {} is not double (is: {}) may cause precision issues in gradient checks", s.name(), s.dataType());
|
||||
}
|
||||
|
||||
String name = s.getVarName();
|
||||
String name = s.name();
|
||||
INDArray a = s.getArr();
|
||||
long n = a.length();
|
||||
if(print){
|
||||
log.info("Starting test for variable \"{}\" with {} values", s.getVarName(), n);
|
||||
log.info("Starting test for variable \"{}\" with {} values", s.name(), n);
|
||||
}
|
||||
|
||||
Iterator<long[]> iter;
|
||||
|
@ -256,11 +256,11 @@ public class GradCheckUtil {
|
|||
iter = new NdIndexIterator('c',a.shape());
|
||||
}
|
||||
|
||||
INDArray varMask = (gradCheckMask == null ? null : gradCheckMask.get(s.getVarName()));
|
||||
INDArray varMask = (gradCheckMask == null ? null : gradCheckMask.get(s.name()));
|
||||
|
||||
if(varMask != null){
|
||||
Preconditions.checkState(a.equalShapes(varMask), "Variable \"%s\": Gradient check mask and array shapes must be equal: got %s vs. mask shape %s", s.getVarName(), a.shape(), varMask.shape());
|
||||
Preconditions.checkState(varMask.dataType() == DataType.BOOL, "Variable \"%s\": Gradient check mask must be BOOLEAN datatype, got %s", s.getVarName(), varMask.dataType());
|
||||
Preconditions.checkState(a.equalShapes(varMask), "Variable \"%s\": Gradient check mask and array shapes must be equal: got %s vs. mask shape %s", s.name(), a.shape(), varMask.shape());
|
||||
Preconditions.checkState(varMask.dataType() == DataType.BOOL, "Variable \"%s\": Gradient check mask must be BOOLEAN datatype, got %s", s.name(), varMask.dataType());
|
||||
}
|
||||
|
||||
int i=0;
|
||||
|
@ -281,12 +281,12 @@ public class GradCheckUtil {
double orig = a.getDouble(idx);
a.putScalar(idx, orig+eps);
double scorePlus = 0.0;
Map<String,INDArray> m = sd.exec(placeholderValues, lossFnVariables);//.get(outName).sumNumber().doubleValue();
Map<String,INDArray> m = sd.output(placeholderValues, lossFnVariables);//.get(outName).sumNumber().doubleValue();
for(INDArray arr : m.values()){
scorePlus += arr.sumNumber().doubleValue();
}
a.putScalar(idx, orig-eps);
m = sd.exec(placeholderValues, lossFnVariables);
m = sd.output(placeholderValues, lossFnVariables);
double scoreMinus = 0.0;
for(INDArray arr : m.values()){
scoreMinus += arr.sumNumber().doubleValue();

@ -294,9 +294,9 @@ public class GradCheckUtil {
a.putScalar(idx, orig);

double numericalGrad = (scorePlus - scoreMinus) / (2 * eps);
INDArray aGrad = grad.get(s.getVarName());
INDArray aGrad = grad.get(s.name());
if(aGrad == null){
log.warn("No gradient array for variable \"{}\" was found, skipping variable...", s.getVarName());
log.warn("No gradient array for variable \"{}\" was found, skipping variable...", s.name());
continue;
}
double analyticGrad = aGrad.getDouble(idx);
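The gradient check above perturbs one array element by +eps and -eps, re-runs the graph with sd.output(...), and compares the centered difference against the analytic gradient. A minimal standalone sketch of the same estimator for a scalar function, independent of SameDiff (values are illustrative):

import java.util.function.DoubleUnaryOperator;

public class CentralDifferenceSketch {
    // Centered finite difference: f'(x) ~= (f(x+eps) - f(x-eps)) / (2*eps)
    static double numericalGrad(DoubleUnaryOperator f, double x, double eps) {
        return (f.applyAsDouble(x + eps) - f.applyAsDouble(x - eps)) / (2 * eps);
    }

    public static void main(String[] args) {
        double x = 0.3;
        double eps = 1e-5;
        double numeric = numericalGrad(Math::tanh, x, eps);
        double analytic = 1.0 - Math.tanh(x) * Math.tanh(x);   // d/dx tanh(x) = 1 - tanh^2(x)
        System.out.println(Math.abs(numeric - analytic));      // very small, roughly 1e-10
    }
}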
@ -497,12 +497,12 @@ public class GradCheckUtil {
|
|||
listener.setIdx(idx);
|
||||
listener.setEps(config.getEps());
|
||||
double scorePlus = 0.0;
|
||||
Map<String,INDArray> m = sd.exec(config.getPlaceholderValues(), lossFnVariables);
|
||||
Map<String,INDArray> m = sd.output(config.getPlaceholderValues(), lossFnVariables);
|
||||
for(INDArray arr : m.values()){
|
||||
scorePlus += arr.sumNumber().doubleValue();
|
||||
}
|
||||
listener.setEps(-config.getEps());
|
||||
m = sd.exec(config.getPlaceholderValues(), lossFnVariables);
|
||||
m = sd.output(config.getPlaceholderValues(), lossFnVariables);
|
||||
double scoreMinus = 0.0;
|
||||
for(INDArray arr : m.values()){
|
||||
scoreMinus += arr.sumNumber().doubleValue();
|
||||
|
@ -597,10 +597,10 @@ public class GradCheckUtil {
|
|||
|
||||
Set<String> varSetStr = new HashSet<>();
|
||||
for(SDVariable v : vars){
|
||||
if(varSetStr.contains(v.getVarName())){
|
||||
throw new IllegalStateException("Variable with name " + v.getVarName() + " already encountered");
|
||||
if(varSetStr.contains(v.name())){
|
||||
throw new IllegalStateException("Variable with name " + v.name() + " already encountered");
|
||||
}
|
||||
varSetStr.add(v.getVarName());
|
||||
varSetStr.add(v.name());
|
||||
}
|
||||
Preconditions.checkState(vars.size() == varSetStr.size(), "Duplicate variables in variables() list");
|
||||
|
||||
|
@ -645,7 +645,7 @@ public class GradCheckUtil {
|
|||
Map<String, Variable> variableMap = sd.getVariables();
|
||||
Preconditions.checkState(vars.size() == variableMap.size(), "Variable map size check failed");
|
||||
for(Map.Entry<String, Variable> e : variableMap.entrySet()){
|
||||
Preconditions.checkState(e.getKey().equals(e.getValue().getVariable().getVarName()), "Name not equal");
|
||||
Preconditions.checkState(e.getKey().equals(e.getValue().getVariable().name()), "Name not equal");
|
||||
}
|
||||
|
||||
if(generateAndCheckGradFn) {
|
||||
|
|
|
@ -208,7 +208,7 @@ public class OpValidation {
|
|||
e.getKey() + "\" but SameDiff instance does not have a variable for this name" + testCase.testNameErrMsg());
|
||||
}
|
||||
|
||||
INDArray actual = out.get(v.getVarName());
|
||||
INDArray actual = out.get(v.name());
|
||||
if (actual == null) {
|
||||
throw new IllegalStateException("Null INDArray after forward pass for variable \"" + e.getKey() + "\"");
|
||||
}
|
||||
|
@ -271,8 +271,8 @@ public class OpValidation {
|
|||
for( int i=0; i<vars.size(); i++ ){
|
||||
SDVariable vO = vars.get(i);
|
||||
SDVariable vD = varsDe.get(i);
|
||||
Preconditions.checkState(vO.getVarName().equals(vD.getVarName()), "Names should be equal for variable %s: expected %s vs %s",
|
||||
i, vO.getVarName(), vD.getVarName());
|
||||
Preconditions.checkState(vO.name().equals(vD.name()), "Names should be equal for variable %s: expected %s vs %s",
|
||||
i, vO.name(), vD.name());
|
||||
}
|
||||
|
||||
//Check ops:
|
||||
|
|
|
@ -121,7 +121,7 @@ public class TestCase {
* @param output Expected INDArray
*/
public TestCase expected(@NonNull SDVariable var, @NonNull INDArray output) {
return expected(var.getVarName(), output);
return expected(var.name(), output);
}

/**

@ -135,7 +135,7 @@ public class TestCase {
}

public TestCase expected(SDVariable var, Function<INDArray,String> validationFn){
return expected(var.getVarName(), validationFn);
return expected(var.name(), validationFn);
}

/**

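TestCase.expected(SDVariable, ...) now keys the expectation by var.name(). A minimal sketch of how such a test case is assembled and validated, mirroring the LayerOpValidation usages later in this diff; the graph, array values and package paths (org.nd4j.autodiff.validation) are assumptions for illustration:

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.validation.OpValidation;   // assumed package
import org.nd4j.autodiff.validation.TestCase;       // assumed package
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

public class TestCaseSketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        INDArray inArr = Nd4j.create(new double[][]{{0.1, -0.2}, {0.3, 0.4}});
        SDVariable in = sd.var("in", inArr);
        SDVariable out = sd.nn().tanh("out", in);

        TestCase tc = new TestCase(sd)
                .gradientCheck(false)
                .expected(out, Transforms.tanh(inArr, true));   // stored under out.name(), i.e. "out"
        String err = OpValidation.validate(tc);
        System.out.println(err == null ? "ok" : err);
    }
}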
@ -487,11 +487,15 @@ public class LogFileWriter {

//Create outputs list:
List<String> outputs = sd.outputs();
int outputsOffset = 0;
if(outputs != null && !outputs.isEmpty()) {
int[] outputListStrOffsets = new int[outputs.size()];
for (int i = 0; i < outputListStrOffsets.length; i++) {
outputListStrOffsets[i] = fbb.createString(outputs.get(i));
}
int outputsOffset = UIGraphStructure.createInputsVector(fbb, outputListStrOffsets);
outputsOffset = UIGraphStructure.createInputsVector(fbb, outputListStrOffsets);
}


//Create variables list
Map<String,Variable> varMap = sd.getVariables();

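The LogFileWriter change above fixes a variable-shadowing bug: the old inner declaration int outputsOffset = ... created a new local, so the outer outputsOffset stayed 0 and the outputs list never made it into the serialized graph structure. A minimal sketch of the bug pattern and the fix, with hypothetical names:

public class ShadowingSketch {
    static int computeOffset() {
        return 7;
    }

    public static void main(String[] args) {
        int outputsOffset = 0;              // the value the later serialization code reads
        boolean haveOutputs = true;
        if (haveOutputs) {
            // Buggy form: "int outputsOffset = computeOffset();" declares a shadowing local,
            // leaving the outer variable at 0. Fixed form assigns the existing variable:
            outputsOffset = computeOffset();
        }
        System.out.println(outputsOffset);  // 7 with the fix; 0 with the shadowed local
    }
}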
@ -46,6 +46,7 @@ public class ImportClassMapping {
|
|||
org.nd4j.linalg.api.ops.custom.BarnesEdgeForces.class,
|
||||
org.nd4j.linalg.api.ops.custom.BarnesHutGains.class,
|
||||
org.nd4j.linalg.api.ops.custom.BarnesHutSymmetrize.class,
|
||||
org.nd4j.linalg.api.ops.custom.KnnMinDistance.class,
|
||||
org.nd4j.linalg.api.ops.custom.SpTreeCell.class,
|
||||
org.nd4j.linalg.api.ops.custom.Flatten.class,
|
||||
org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd.class,
|
||||
|
@ -322,7 +323,6 @@ public class ImportClassMapping {
|
|||
org.nd4j.linalg.api.ops.impl.transforms.BinCount.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.CheckNumerics.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.Cholesky.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.Constant.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.Histogram.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.HistogramFixedWidth.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.IdentityN.class,
|
||||
|
|
|
@ -810,8 +810,6 @@ public class TFGraphMapper {
|
|||
on.setValueFor(currentField, tensor.getFloat(0));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
on.getSameDiff().addPropertyToResolve(on, entry.getKey());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -63,19 +63,11 @@ public abstract class BaseBroadcastBoolOp extends BaseOp implements BroadcastOp
|
|||
this.sameDiff = sameDiff;
|
||||
this.inPlace = inPlace;
|
||||
this.dimension = dimension;
|
||||
if(Shape.isPlaceholderShape(i_v1.getShape())) {
|
||||
sameDiff.addPropertyToResolve(this,i_v1.getVarName());
|
||||
}
|
||||
|
||||
if(Shape.isPlaceholderShape(i_v2.getShape())) {
|
||||
sameDiff.addPropertyToResolve(this,i_v2.getVarName());
|
||||
}
|
||||
sameDiff.addArgsFor(new SDVariable[]{i_v1,i_v2},this);
|
||||
|
||||
} else {
|
||||
throw new IllegalArgumentException("Input not null variables.");
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public BaseBroadcastBoolOp(SameDiff sameDiff) {
|
||||
|
|
|
@ -64,19 +64,10 @@ public abstract class BaseBroadcastOp extends BaseOp implements BroadcastOp {
|
|||
this.sameDiff = sameDiff;
|
||||
this.inPlace = inPlace;
|
||||
this.dimension = dimension;
|
||||
if(Shape.isPlaceholderShape(i_v1.getShape())) {
|
||||
sameDiff.addPropertyToResolve(this,i_v1.getVarName());
|
||||
}
|
||||
|
||||
if(Shape.isPlaceholderShape(i_v2.getShape())) {
|
||||
sameDiff.addPropertyToResolve(this,i_v2.getVarName());
|
||||
}
|
||||
sameDiff.addArgsFor(new SDVariable[]{i_v1,i_v2},this);
|
||||
|
||||
} else {
|
||||
throw new IllegalArgumentException("Input not null variables.");
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public BaseBroadcastOp(SameDiff sameDiff) {
|
||||
|
|
|
@ -53,11 +53,8 @@ public abstract class BaseIndexAccumulation extends BaseOp implements IndexAccum
|
|||
this.dimensions = dimensions;
|
||||
f().validateDifferentialFunctionsameDiff(i_v);
|
||||
sameDiff.addArgsFor(new SDVariable[]{i_v},this);
|
||||
if(Shape.isPlaceholderShape(i_v.getShape())) {
|
||||
sameDiff.addPropertyToResolve(this,i_v.getVarName());
|
||||
}
|
||||
|
||||
this.xVertexId = i_v.getVarName();
|
||||
this.xVertexId = i_v.name();
|
||||
} else {
|
||||
throw new IllegalArgumentException("Input not null variable.");
|
||||
}
|
||||
|
@ -75,17 +72,9 @@ public abstract class BaseIndexAccumulation extends BaseOp implements IndexAccum
|
|||
this.dimensions = dimensions;
|
||||
f().validateDifferentialFunctionsameDiff(i_v);
|
||||
f().validateDifferentialFunctionsameDiff(i_v2);
|
||||
this.xVertexId = i_v.getVarName();
|
||||
this.yVertexId = i_v2.getVarName();
|
||||
this.xVertexId = i_v.name();
|
||||
this.yVertexId = i_v2.name();
|
||||
sameDiff.addArgsFor(new SDVariable[]{i_v,i_v2},this);
|
||||
|
||||
if(Shape.isPlaceholderShape(i_v.getShape())) {
|
||||
sameDiff.addPropertyToResolve(this,i_v.getVarName());
|
||||
}
|
||||
|
||||
if(Shape.isPlaceholderShape(i_v2.getShape())) {
|
||||
sameDiff.addPropertyToResolve(this,i_v2.getVarName());
|
||||
}
|
||||
} else {
|
||||
throw new IllegalArgumentException("Input not null variable.");
|
||||
}
|
||||
|
|
|
@ -247,7 +247,7 @@ public abstract class BaseOp extends DifferentialFunction implements Op {
|
|||
val outputNames = sameDiff.getOutputsForOp(this);
|
||||
//no need to dynamically create if already exists
|
||||
if(outputNames != null) {
|
||||
zVertexId = sameDiff.getVariable(outputNames[0]).getVarName();
|
||||
zVertexId = sameDiff.getVariable(outputNames[0]).name();
|
||||
|
||||
|
||||
return new SDVariable[]{sameDiff.getVariable(outputNames[0])};
|
||||
|
@ -261,7 +261,7 @@ public abstract class BaseOp extends DifferentialFunction implements Op {
|
|||
return newVars;
|
||||
}
|
||||
|
||||
sameDiff.setArrayForVariable(newVars[0].getVarName(),inputArr);
|
||||
sameDiff.setArrayForVariable(newVars[0].name(),inputArr);
|
||||
z = inputArr;
|
||||
if(sameDiff.getOutputsForOp(this) == null)
|
||||
sameDiff.addOutgoingFor(newVars,this);
|
||||
|
|
|
@ -61,7 +61,7 @@ public abstract class BaseReduceOp extends BaseOp implements ReduceOp {
|
|||
this.dimensions = dimensions;
|
||||
f().validateDifferentialFunctionsameDiff(i_v);
|
||||
this.keepDims = keepDims;
|
||||
this.xVertexId = i_v.getVarName();
|
||||
this.xVertexId = i_v.name();
|
||||
sameDiff.addArgsFor(new String[]{xVertexId},this);
|
||||
} else {
|
||||
throw new IllegalArgumentException("Input not null variable.");
|
||||
|
@ -81,8 +81,8 @@ public abstract class BaseReduceOp extends BaseOp implements ReduceOp {
|
|||
|
||||
this.dimensions = dimensions;
|
||||
|
||||
this.xVertexId = i_v.getVarName();
|
||||
this.yVertexId = i_v2.getVarName();
|
||||
this.xVertexId = i_v.name();
|
||||
this.yVertexId = i_v2.name();
|
||||
f().validateDifferentialFunctionsameDiff(i_v);
|
||||
f().validateDifferentialFunctionsameDiff(i_v2);
|
||||
this.keepDims = keepDims;
|
||||
|
|
|
@ -74,11 +74,8 @@ public abstract class BaseScalarBoolOp extends BaseOp implements ScalarOp {
|
|||
super(sameDiff,inPlace,extraArgs);
|
||||
this.scalarValue = Nd4j.scalar(i_v.dataType(), scalar);
|
||||
if (i_v != null) {
|
||||
this.xVertexId = i_v.getVarName();
|
||||
this.xVertexId = i_v.name();
|
||||
sameDiff.addArgsFor(new String[]{xVertexId},this);
|
||||
if(Shape.isPlaceholderShape(i_v.getShape())) {
|
||||
sameDiff.addPropertyToResolve(this,i_v.getVarName());
|
||||
}
|
||||
f().validateDifferentialFunctionsameDiff(i_v);
|
||||
} else {
|
||||
throw new IllegalArgumentException("Input not null variable.");
|
||||
|
|
|
@ -93,11 +93,8 @@ public abstract class BaseScalarOp extends BaseOp implements ScalarOp {
|
|||
Object[] extraArgs) {
|
||||
super(sameDiff,inPlace,extraArgs);
|
||||
this.scalarValue = Nd4j.scalar(i_v.dataType(), scalar);
|
||||
this.xVertexId = i_v.getVarName();
|
||||
this.xVertexId = i_v.name();
|
||||
sameDiff.addArgsFor(new String[]{xVertexId},this);
|
||||
if(Shape.isPlaceholderShape(i_v.getShape())) {
|
||||
sameDiff.addPropertyToResolve(this,i_v.getVarName());
|
||||
}
|
||||
f().validateDifferentialFunctionsameDiff(i_v);
|
||||
}
|
||||
|
||||
|
|
|
@ -56,16 +56,9 @@ public abstract class BaseTransformOp extends BaseOp implements TransformOp {
|
|||
f().validateDifferentialFunctionsameDiff(i_v2);
|
||||
this.sameDiff = sameDiff;
|
||||
this.inPlace = inPlace;
|
||||
this.xVertexId = i_v1.getVarName();
|
||||
this.yVertexId = i_v2.getVarName();
|
||||
this.xVertexId = i_v1.name();
|
||||
this.yVertexId = i_v2.name();
|
||||
sameDiff.addArgsFor(new SDVariable[]{i_v1,i_v2},this);
|
||||
if(Shape.isPlaceholderShape(i_v1.getShape())) {
|
||||
sameDiff.addPropertyToResolve(this,i_v1.getVarName());
|
||||
}
|
||||
|
||||
if(Shape.isPlaceholderShape(i_v2.getShape())) {
|
||||
sameDiff.addPropertyToResolve(this,i_v2.getVarName());
|
||||
}
|
||||
} else {
|
||||
throw new IllegalArgumentException("Input not null variables.");
|
||||
}
|
||||
|
@ -87,18 +80,9 @@ public abstract class BaseTransformOp extends BaseOp implements TransformOp {
|
|||
f().validateDifferentialFunctionsameDiff(i_v1);
|
||||
f().validateDifferentialFunctionsameDiff(i_v2);
|
||||
this.sameDiff = sameDiff;
|
||||
this.xVertexId = i_v1.getVarName();
|
||||
this.yVertexId = i_v2.getVarName();
|
||||
this.xVertexId = i_v1.name();
|
||||
this.yVertexId = i_v2.name();
|
||||
sameDiff.addArgsFor(new SDVariable[]{i_v1,i_v2},this);
|
||||
|
||||
if(Shape.isPlaceholderShape(i_v1.getShape())) {
|
||||
sameDiff.addPropertyToResolve(this,i_v1.getVarName());
|
||||
}
|
||||
|
||||
if(Shape.isPlaceholderShape(i_v2.getShape())) {
|
||||
sameDiff.addPropertyToResolve(this,i_v2.getVarName());
|
||||
}
|
||||
|
||||
} else {
|
||||
throw new IllegalArgumentException("Input not null variables.");
|
||||
}
|
||||
|
@ -130,14 +114,8 @@ public abstract class BaseTransformOp extends BaseOp implements TransformOp {
|
|||
|
||||
if (i_v != null) {
|
||||
f().validateDifferentialFunctionsameDiff(i_v);
|
||||
this.xVertexId = i_v.getVarName();
|
||||
this.xVertexId = i_v.name();
|
||||
sameDiff.addArgsFor(new SDVariable[]{i_v},this);
|
||||
|
||||
if(Shape.isPlaceholderShape(i_v.getShape())) {
|
||||
sameDiff.addPropertyToResolve(this,i_v.getVarName());
|
||||
}
|
||||
|
||||
|
||||
} else {
|
||||
throw new IllegalArgumentException("Input must not null variable.");
|
||||
}
|
||||
|
|
|
@ -223,7 +223,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
|
|||
if (args().length >= 1) {
|
||||
val arr = args()[0].getArr();
|
||||
if (arr != null) {
|
||||
sameDiff.setArrayForVariable(newVars[0].getVarName(), arr);
|
||||
sameDiff.setArrayForVariable(newVars[0].name(), arr);
|
||||
addOutputArgument(arr);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -57,7 +57,7 @@ public class ExternalErrorsFunction extends DynamicCustomOp {
|
|||
public ExternalErrorsFunction(){ }
|
||||
|
||||
public String getGradPlaceholderName(){
|
||||
return arg().getVarName() + "-grad";
|
||||
return arg().name() + "-grad";
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -70,7 +70,7 @@ public class ExternalErrorsFunction extends DynamicCustomOp {
|
|||
out = sameDiff.getVariable(name);
|
||||
} else {
|
||||
out = sameDiff.zero(name, Nd4j.dataType(), 1);
|
||||
sameDiff.getOps().get(getOwnName()).setOutputsOfOp(Collections.singletonList(out.getVarName()));
|
||||
sameDiff.getOps().get(getOwnName()).setOutputsOfOp(Collections.singletonList(out.name()));
|
||||
sameDiff.getVariables().get(name).setOutputOfOp(getOwnName());
|
||||
}
|
||||
}
|
||||
|
@ -83,7 +83,7 @@ public class ExternalErrorsFunction extends DynamicCustomOp {
|
|||
if (gradVariables == null) {
|
||||
gradVariables = new HashMap<>();
|
||||
for(SDVariable arg : args()){
|
||||
INDArray gradArr = gradients.get(arg.getVarName());
|
||||
INDArray gradArr = gradients.get(arg.name());
|
||||
SDVariable grad;
|
||||
DataType dt = arg.dataType();
|
||||
String n = getGradPlaceholderName();
|
||||
|
@ -94,7 +94,7 @@ public class ExternalErrorsFunction extends DynamicCustomOp {
|
|||
} else {
|
||||
grad = sameDiff.var(n, VariableType.PLACEHOLDER, null, dt);
|
||||
}
|
||||
gradVariables.put(arg.getVarName(), grad);
|
||||
gradVariables.put(arg.name(), grad);
|
||||
out.add(grad);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -196,12 +196,12 @@ public class DeConv2D extends DynamicCustomOp {
|
|||
val paddingMode = aPadding.getS().toStringUtf8();
|
||||
|
||||
val args = args();
|
||||
INDArray arr = sameDiff.getVariable(args[1].getVarName()).getArr();
|
||||
INDArray arr = sameDiff.getVariable(args[1].name()).getArr();
|
||||
if (arr == null) {
|
||||
arr = TFGraphMapper.getNDArrayFromTensor(nodeDef);
|
||||
// TODO: arguable. it might be easier to permute weights once
|
||||
//arr = (arr.permute(3, 2, 0, 1).dup('c'));
|
||||
val varForOp = initWith.getVariable(args[1].getVarName());
|
||||
val varForOp = initWith.getVariable(args[1].name());
|
||||
if (arr != null)
|
||||
initWith.associateArrayWithVariable(arr, varForOp);
|
||||
|
||||
|
|
|
@ -158,10 +158,10 @@ public class DeConv3D extends DynamicCustomOp {
|
|||
val paddingMode = aPadding.getS().toStringUtf8();
|
||||
|
||||
val args = args();
|
||||
INDArray arr = sameDiff.getVariable(args[1].getVarName()).getArr();
|
||||
INDArray arr = sameDiff.getVariable(args[1].name()).getArr();
|
||||
if (arr == null) {
|
||||
arr = TFGraphMapper.getNDArrayFromTensor(nodeDef);
|
||||
val varForOp = initWith.getVariable(args[1].getVarName());
|
||||
val varForOp = initWith.getVariable(args[1].name());
|
||||
if (arr != null)
|
||||
initWith.associateArrayWithVariable(arr, varForOp);
|
||||
}
|
||||
|
|
|
@ -193,12 +193,6 @@ public class Mmul extends DynamicCustomOp {
|
|||
.transposeA(isTransposeA).transposeB(isTransposeB)
|
||||
.build();
|
||||
this.mt = mMulTranspose;
|
||||
val args = args();
|
||||
for(val arg : args) {
|
||||
if(sameDiff.isPlaceHolder(arg.getVarName()) || arg.getShape() == null) {
|
||||
sameDiff.addPropertyToResolve(this,arg.getVarName());
|
||||
}
|
||||
}
|
||||
iArguments.clear();
|
||||
addIArgument(ArrayUtil.fromBoolean(mt.isTransposeA()), ArrayUtil.fromBoolean(mt.isTransposeB()));
|
||||
}
|
||||
|
|
|
@ -130,10 +130,7 @@ public class Concat extends DynamicCustomOp {
|
|||
|
||||
val variable = initWith.getVariable(input);
|
||||
// concat dimension is only possible
|
||||
if (variable != null && variable.getArr() == null) {
|
||||
sameDiff.addPropertyToResolve(this, input);
|
||||
|
||||
} else if (variable != null) {
|
||||
if (variable != null) {
|
||||
val arr = variable.getArr();
|
||||
if (arr.length() == 1) {
|
||||
concatDimension = arr.getInt(0);
|
||||
|
|
|
@ -124,13 +124,7 @@ public class Transpose extends DynamicCustomOp {
|
|||
return;
|
||||
}
|
||||
|
||||
INDArray arr = sameDiff.getArrForVarName(arg().getVarName());
|
||||
if (arr == null) {
|
||||
val arrVar = sameDiff.getVariable(arg().getVarName());
|
||||
|
||||
arr = arrVar.getWeightInitScheme().create(arrVar.dataType(), arrVar.getShape());
|
||||
sameDiff.setArrayForVariable(arg().getVarName(), arr);
|
||||
}
|
||||
INDArray arr = sameDiff.getArrForVarName(arg().name());
|
||||
|
||||
if(permuteArrayOp != null){
|
||||
addInputArgument(arr, permuteArrayOp);
|
||||
|
@ -138,16 +132,12 @@ public class Transpose extends DynamicCustomOp {
|
|||
addInputArgument(arr);
|
||||
}
|
||||
|
||||
|
||||
|
||||
if (arr != null && permuteDims == null) {
|
||||
this.permuteDims = ArrayUtil.reverseCopy(ArrayUtil.range(0, arr.rank()));
|
||||
}
|
||||
|
||||
if (permuteDims != null && permuteDims.length < arg().getShape().length)
|
||||
throw new ND4JIllegalStateException("Illegal permute found. Not all dimensions specified");
|
||||
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -72,7 +72,7 @@ public class TensorArrayConcat extends BaseTensorOp {
|
|||
public List<DataType> calculateOutputDataTypes(java.util.List<org.nd4j.linalg.api.buffer.DataType> inputDataType){
|
||||
//Same output type as the TensorArray - which is defined by input 0
|
||||
SDVariable tArr = arg(0);
|
||||
TensorArray t3 = (TensorArray) sameDiff.getVariableOutputOp(tArr.getVarName());
|
||||
TensorArray t3 = (TensorArray) sameDiff.getVariableOutputOp(tArr.name());
|
||||
org.nd4j.linalg.api.buffer.DataType dt = t3.getTensorArrayDataType();
|
||||
return Collections.singletonList(dt);
|
||||
}
|
||||
|
|
|
@ -72,7 +72,7 @@ public class TensorArrayGather extends BaseTensorOp {
|
|||
public List<DataType> calculateOutputDataTypes(java.util.List<org.nd4j.linalg.api.buffer.DataType> inputDataType){
|
||||
//Same output type as the TensorArray - which is defined by input 0
|
||||
SDVariable tArr = arg(0);
|
||||
TensorArray t3 = (TensorArray) sameDiff.getVariableOutputOp(tArr.getVarName());
|
||||
TensorArray t3 = (TensorArray) sameDiff.getVariableOutputOp(tArr.name());
|
||||
org.nd4j.linalg.api.buffer.DataType dt = t3.getTensorArrayDataType();
|
||||
return Collections.singletonList(dt);
|
||||
}
|
||||
|
|
|
@ -73,7 +73,7 @@ public class TensorArrayRead extends BaseTensorOp {
|
|||
dt = importDataType;
|
||||
} else {
|
||||
SDVariable tArr = arg(0);
|
||||
DifferentialFunction op = sameDiff.getVariableOutputOp(tArr.getVarName());
|
||||
DifferentialFunction op = sameDiff.getVariableOutputOp(tArr.name());
|
||||
TensorArray t3 = (TensorArray) op;
|
||||
dt = t3.getTensorArrayDataType();
|
||||
}
|
||||
|
|
|
@ -71,9 +71,9 @@ public class CheckNumerics extends DynamicCustomOp {
|
|||
SDVariable msg = initWith.constant(name + "/message", Nd4j.scalar(str));
|
||||
List<String> newInputs = new ArrayList<>(2);
|
||||
newInputs.addAll(initWith.getOps().get(name).getInputsToOp());
|
||||
newInputs.add(msg.getVarName());
|
||||
newInputs.add(msg.name());
|
||||
initWith.getOps().get(name).setInputsToOp(newInputs);
|
||||
initWith.getVariables().get(msg.getVarName()).setInputsForOp(Collections.singletonList(getOwnName())); }
|
||||
initWith.getVariables().get(msg.name()).setInputsForOp(Collections.singletonList(getOwnName())); }
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
|
|
|
@ -1,91 +0,0 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.linalg.api.ops.impl.transforms;
|
||||
|
||||
import lombok.Data;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
import org.nd4j.linalg.api.ops.BaseTransformOp;
|
||||
import org.nd4j.linalg.api.ops.BaseTransformSameOp;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
|
||||
@Data
|
||||
public class Constant extends BaseTransformSameOp {
|
||||
|
||||
|
||||
public Constant() {
|
||||
}
|
||||
|
||||
|
||||
protected Constant(SameDiff sameDiff,
|
||||
SDVariable i_v,
|
||||
long[] shape,
|
||||
boolean inPlace) {
|
||||
super();
|
||||
sameDiff.putOrUpdateShapeForVarName(i_v.getVarName(), shape, false);
|
||||
this.xVertexId = i_v.getVarName();
|
||||
this.inPlace = inPlace;
|
||||
this.sameDiff = sameDiff;
|
||||
}
|
||||
|
||||
public Constant(SameDiff sameDiff, SDVariable i_v, long[] shape) {
|
||||
this(sameDiff, i_v, shape, false);
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
||||
return Collections.singletonList(sameDiff.zerosLike(arg()));
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public DifferentialFunction dup() {
|
||||
Constant ret = new Constant(sameDiff, sameDiff.getVariable(outputVariables()[0].getVarName())
|
||||
, sameDiff.getShapeForVarName(outputVariables()[0].getVarName()));
|
||||
Constant differentialFunction = ret;
|
||||
return differentialFunction;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public int opNum() {
|
||||
return 15;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "constant";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String onnxName() {
|
||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
throw new NoOpNameFoundException("No tensorflow opName found for " + opName());
|
||||
}
|
||||
|
||||
}
|
|
@ -118,7 +118,7 @@ public class MaxOut extends BaseTransformOp {
|
|||
if(arg() == null)
|
||||
throw new ND4JIllegalStateException("No arg found for op!");
|
||||
|
||||
val arr = sameDiff.getArrForVarName(arg().getVarName());
|
||||
val arr = sameDiff.getArrForVarName(arg().name());
|
||||
if(arr == null)
|
||||
return Collections.emptyList();
|
||||
|
||||
|
|
|
@ -28,13 +28,19 @@ import java.util.Collections;
|
|||
import java.util.List;
|
||||
|
||||
/**
*
* TanhDerivative: calculates dL/dIn from dL/dOut and In
*/
|
||||
public class TanhDerivative extends DynamicCustomOp {
|
||||
public TanhDerivative(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
|
||||
super(sameDiff, new SDVariable[]{i_v1, i_v2});
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param x Input
|
||||
* @param y Gradient at output (dL/dOut)
|
||||
* @param z Output array, gradient at input (dL/dIn - to be calculated)
|
||||
*/
|
||||
public TanhDerivative(INDArray x, INDArray y, INDArray z) {
|
||||
super(null, new INDArray[]{x, y}, new INDArray[]{z});
|
||||
}
|
||||
|
@ -42,6 +48,10 @@ public class TanhDerivative extends DynamicCustomOp {
|
|||
public TanhDerivative() {
|
||||
}
|
||||
|
||||
/**
|
||||
* @param x Input
|
||||
* @param y Gradient at output (dL/dOut)
|
||||
*/
|
||||
public TanhDerivative(INDArray x, INDArray y) {
|
||||
this(x, y, null);
|
||||
}
|
||||
|
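Per the updated javadoc, this op computes dL/dIn from the forward-pass input and dL/dOut; since tanh'(x) = 1 - tanh(x)^2, the result is dL/dOut * (1 - tanh(x)^2). A minimal sketch of exercising the new two-array constructor; the import path and the Nd4j.exec(CustomOp) execution call are assumptions here, and the values are illustrative:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative;  // assumed package
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

public class TanhBackpropSketch {
    public static void main(String[] args) {
        INDArray x = Nd4j.create(new double[]{-1.0, 0.0, 2.0});       // forward-pass input
        INDArray dLdOut = Nd4j.create(new double[]{0.5, 1.0, 2.0});   // gradient at the output

        INDArray dLdIn = Nd4j.exec(new TanhDerivative(x, dLdOut))[0];

        // Reference: dL/dIn = dL/dOut * (1 - tanh(x)^2)
        INDArray t = Transforms.tanh(x, true);
        INDArray expected = dLdOut.mul(t.mul(t).rsub(1.0));
        System.out.println(dLdIn.equalsWithEps(expected, 1e-6));
    }
}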
|
|
@ -43,11 +43,8 @@ public abstract class BaseRandomOp extends BaseOp implements RandomOp {
|
|||
public BaseRandomOp(SameDiff sameDiff, SDVariable i_v) {
|
||||
Preconditions.checkNotNull(i_v, "Input variable can't be null with this constructor");
|
||||
this.sameDiff = sameDiff;
|
||||
this.xVertexId = i_v.getVarName();
|
||||
this.xVertexId = i_v.name();
|
||||
sameDiff.addArgsFor(new String[]{xVertexId},this);
|
||||
if(Shape.isPlaceholderShape(i_v.getShape())) {
|
||||
sameDiff.addPropertyToResolve(this,i_v.getVarName());
|
||||
}
|
||||
}
|
||||
|
||||
public BaseRandomOp(SameDiff sd, long[] shape){
|
||||
|
@ -73,11 +70,7 @@ public abstract class BaseRandomOp extends BaseOp implements RandomOp {
|
|||
if(shape != null){
|
||||
return Collections.singletonList(LongShapeDescriptor.fromShape(shape, Nd4j.defaultFloatingPointType()));
|
||||
} else {
|
||||
List<LongShapeDescriptor> ret = new ArrayList<>(1);
|
||||
val shape = sameDiff.getShapeForVarName(args()[0].getVarName());
|
||||
if (shape != null)
|
||||
ret.add(LongShapeDescriptor.fromShape(shape, Shape.pickPairwiseDataType(args()[0].dataType(), Nd4j.dataType())));
|
||||
return ret;
|
||||
return Collections.singletonList(LongShapeDescriptor.fromShape(shape, Shape.pickPairwiseDataType(args()[0].dataType(), Nd4j.dataType())));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -5212,6 +5212,8 @@ public class Nd4j {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
backend.logBackendInit();
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
@ -5625,19 +5627,38 @@ public class Nd4j {
|
|||
* @return an ndarray created from the in memory
|
||||
* numpy pointer
|
||||
*/
|
||||
|
||||
@SuppressWarnings("WeakerAccess")
|
||||
public static INDArray createFromNpyPointer(Pointer pointer) {
|
||||
return INSTANCE.createFromNpyPointer(pointer);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create from a given Numpy .npy file.
|
||||
* Create an INDArray from a given Numpy .npy file.
|
||||
*
|
||||
* @param path Path to the .npy file to read
|
||||
* @return the created ndarray
|
||||
*/
|
||||
public static INDArray readNpy(@NonNull String path){
|
||||
return readNpy(new File(path));
|
||||
}
|
||||
|
||||
/**
|
||||
* Create an INDArray from a given Numpy .npy file.
|
||||
*
|
||||
* @param file the file to create the ndarray from
|
||||
* @return the created ndarray
|
||||
*/
|
||||
public static INDArray createFromNpyFile(File file) {
|
||||
public static INDArray readNpy(@NonNull File file){
|
||||
return createFromNpyFile(file);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create an INDArray from a given Numpy .npy file.
|
||||
*
|
||||
* @param file the file to create the ndarray from
|
||||
* @return the created ndarray
|
||||
*/
|
||||
public static INDArray createFromNpyFile(@NonNull File file) {
|
||||
if (!file.exists())
|
||||
throw new IllegalArgumentException("File [" + file.getAbsolutePath() + "] doesn't exist");
|
||||
|
||||
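The new readNpy overloads are thin wrappers: readNpy(String) resolves a File, and both delegate to createFromNpyFile. A minimal usage sketch; the file path is a placeholder:

import java.io.File;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class ReadNpySketch {
    public static void main(String[] args) {
        // Load a NumPy .npy file into an INDArray via either overload
        INDArray a = Nd4j.readNpy("/tmp/weights.npy");
        INDArray b = Nd4j.readNpy(new File("/tmp/weights.npy"));
        System.out.println(a.shapeInfoToString());
        System.out.println(a.equals(b));   // same file, same array
    }
}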
|
@ -5654,7 +5675,7 @@ public class Nd4j {
|
|||
* @return the loaded ndarray
|
||||
*/
|
||||
@SuppressWarnings("unused")
|
||||
public static INDArray createNpyFromInputStream(InputStream is) throws IOException {
|
||||
public static INDArray createNpyFromInputStream(@NonNull InputStream is) throws IOException {
|
||||
byte[] content = IOUtils.toByteArray(is);
|
||||
return createNpyFromByteArray(content);
|
||||
}
|
||||
|
@ -5668,7 +5689,7 @@ public class Nd4j {
|
|||
* @param input the input byte array with the npy format
|
||||
* @return the equivalent {@link INDArray}
|
||||
*/
|
||||
public static INDArray createNpyFromByteArray(byte[] input) {
|
||||
public static INDArray createNpyFromByteArray(@NonNull byte[] input) {
|
||||
ByteBuffer byteBuffer = ByteBuffer.allocateDirect(input.length);
|
||||
byteBuffer.put(input);
|
||||
byteBuffer.rewind();
|
||||
|
|
|
@ -20,6 +20,7 @@ import java.util.Properties;
|
|||
import lombok.Getter;
|
||||
import org.bytedeco.javacpp.Loader;
|
||||
import org.nd4j.config.ND4JEnvironmentVars;
|
||||
import org.nd4j.config.ND4JSystemProperties;
|
||||
import org.nd4j.context.Nd4jContext;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.slf4j.Logger;
|
||||
|
@ -101,7 +102,12 @@ public class NativeOpsHolder {
|
|||
}
|
||||
//deviceNativeOps.setOmpNumThreads(4);
|
||||
|
||||
String logInitProperty = System.getProperty(ND4JSystemProperties.LOG_INITIALIZATION, "true");
|
||||
boolean logInit = Boolean.parseBoolean(logInitProperty);
|
||||
|
||||
if(logInit) {
|
||||
log.info("Number of threads used for OpenMP: {}", deviceNativeOps.ompGetMaxThreads());
|
||||
}
|
||||
} catch (Exception | Error e) {
|
||||
throw new RuntimeException(
|
||||
"ND4J is probably missing dependencies. For more information, please refer to: http://nd4j.org/getstarted.html",
|
||||
|
|
|
@ -47,7 +47,7 @@ public class MemoryTracker {
|
|||
|
||||
val f = new AtomicLong(NativeOpsHolder.getInstance().getDeviceNativeOps().getDeviceFreeMemory(i));
|
||||
|
||||
log.debug("Free memory on device_{}: {}", i, f);
|
||||
//log.debug("Free memory on device_{}: {}", i, f);
|
||||
freePerDevice.add(i, f);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,14 +16,24 @@
|
|||
|
||||
package org.nd4j.linalg.jcublas;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.bytedeco.javacpp.Loader;
|
||||
import org.nd4j.config.ND4JSystemProperties;
|
||||
import org.nd4j.linalg.api.environment.Nd4jEnvironment;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
import org.nd4j.linalg.io.ClassPathResource;
|
||||
import org.nd4j.linalg.io.Resource;
|
||||
import org.nd4j.nativeblas.Nd4jCuda;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Properties;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
@Slf4j
|
||||
public class JCublasBackend extends Nd4jBackend {
|
||||
|
||||
|
||||
|
@ -76,4 +86,34 @@ public class JCublasBackend extends Nd4jBackend {
|
|||
return JCublasNDArray.class;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void logBackendInit() {
|
||||
String logInitProperty = System.getProperty(ND4JSystemProperties.LOG_INITIALIZATION, "true");
|
||||
boolean logInit = Boolean.parseBoolean(logInitProperty);
|
||||
|
||||
if(logInit) {
|
||||
try {
|
||||
Nd4jCuda.Environment e = Nd4jCuda.Environment.getInstance();
|
||||
int blasMajor = e.blasMajorVersion();
|
||||
int blasMinor = e.blasMinorVersion();
|
||||
int blasPatch = e.blasPatchVersion();
|
||||
log.info("ND4J CUDA build version: {}.{}.{}", blasMajor, blasMinor, blasPatch);
|
||||
int nGPUs = Nd4jEnvironment.getEnvironment().getNumGpus();
|
||||
|
||||
Properties props = Nd4j.getExecutioner().getEnvironmentInformation();
|
||||
List<Map<String, Object>> devicesList = (List<Map<String, Object>>) props.get(Nd4jEnvironment.CUDA_DEVICE_INFORMATION_KEY);
|
||||
|
||||
for (int i = 0; i < nGPUs; i++) {
|
||||
Map<String, Object> dev = devicesList.get(i);
|
||||
String name = (String) dev.get(Nd4jEnvironment.CUDA_DEVICE_NAME_KEY);
|
||||
int major = ((Number) dev.get(Nd4jEnvironment.CUDA_DEVICE_MAJOR_VERSION_KEY)).intValue();
|
||||
int minor = ((Number) dev.get(Nd4jEnvironment.CUDA_DEVICE_MINOR_VERSION_KEY)).intValue();
|
||||
long totalMem = ((Number) dev.get(Nd4jEnvironment.CUDA_TOTAL_MEMORY_KEY)).longValue();
|
||||
log.info("CUDA device {}: [{}]; cc: [{}.{}]; Total memory: [{}]", i, name, major, minor, totalMem);
|
||||
}
|
||||
} catch (Throwable t) {
|
||||
log.debug("Error logging CUDA backend versions and devices", t);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
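Both the NativeOpsHolder and JCublasBackend changes gate their startup logging on the ND4JSystemProperties.LOG_INITIALIZATION system property (default true). A minimal sketch of switching that logging off; the property has to be set (or passed as a -D JVM argument) before the first Nd4j call triggers backend initialization:

import org.nd4j.config.ND4JSystemProperties;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class QuietInitSketch {
    public static void main(String[] args) {
        // Suppress the OpenMP / CUDA device logging shown in logBackendInit() above
        System.setProperty(ND4JSystemProperties.LOG_INITIALIZATION, "false");

        INDArray x = Nd4j.ones(2, 2);   // first Nd4j use initializes the backend
        System.out.println(x);
    }
}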
|
|
@ -1890,14 +1890,6 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
|||
@SuppressWarnings("unchecked")
|
||||
public void printEnvironmentInformation() {
|
||||
super.printEnvironmentInformation();
|
||||
|
||||
Properties env = getEnvironmentInformation();
|
||||
|
||||
List<Map<String, Object>> devicesList = (List<Map<String, Object>>) env.get(Nd4jEnvironment.CUDA_DEVICE_INFORMATION_KEY);
|
||||
for (Map<String, Object> dev : devicesList) {
|
||||
log.info("Device Name: [{}]; CC: [{}.{}]; Total/free memory: [{}]", dev.get(Nd4jEnvironment.CUDA_DEVICE_NAME_KEY),
|
||||
dev.get(Nd4jEnvironment.CUDA_DEVICE_MAJOR_VERSION_KEY), dev.get(Nd4jEnvironment.CUDA_DEVICE_MINOR_VERSION_KEY), dev.get(Nd4jEnvironment.CUDA_TOTAL_MEMORY_KEY));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -60,4 +60,9 @@ public class CpuBackend extends Nd4jBackend {
|
|||
public Class getNDArrayClass() {
|
||||
return NDArray.class;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void logBackendInit() {
|
||||
//No additional logging for CPU backend
|
||||
}
|
||||
}
|
||||
|
|
|
@ -146,8 +146,8 @@ public class TestSessions extends BaseNd4jTest {
|
|||
|
||||
System.out.println("----------------------------------");
|
||||
InferenceSession is = new InferenceSession(sd);
|
||||
// String outName = merge.getVarName();
|
||||
String outName = outVar.getVarName();
|
||||
// String outName = merge.name();
|
||||
String outName = outVar.name();
|
||||
Map<String,INDArray> outMap = is.output(Collections.singletonList(outName), m, null,
|
||||
Collections.<String>emptyList(), null, At.defaultAt(Operation.TRAINING));
|
||||
|
||||
|
@ -181,7 +181,7 @@ public class TestSessions extends BaseNd4jTest {
|
|||
m.put("b", bArr);
|
||||
|
||||
InferenceSession is = new InferenceSession(sd);
|
||||
String n = merge.getVarName();
|
||||
String n = merge.name();
|
||||
|
||||
System.out.println("----------------------------------");
|
||||
Map<String,INDArray> outMap = is.output(Collections.singletonList(n), m, null, Collections.<String>emptyList(),
|
||||
|
|
|
@ -118,7 +118,7 @@ public class GraphExecutionerTest extends BaseNd4jTest {
|
|||
SDVariable result = sdVariable.add(scalarOne);
|
||||
SDVariable total = sameDiff.sum(result,Integer.MAX_VALUE);
|
||||
|
||||
log.info("TOTAL: {}; Id: {}", total.getVarName(), total);
|
||||
log.info("TOTAL: {}; Id: {}", total.name(), total);
|
||||
|
||||
INDArray[] resB = executionerB.executeGraph(sameDiff, configVarSpace);
|
||||
|
||||
|
|
|
@ -79,7 +79,7 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
|
||||
TestCase tc = new TestCase(sameDiff)
|
||||
.gradientCheck(true)
|
||||
.expectedOutput(res.getVarName(), exp);
|
||||
.expectedOutput(res.name(), exp);
|
||||
|
||||
System.out.println(sameDiff.summary());
|
||||
System.out.println("============================");
|
||||
|
@ -112,7 +112,7 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
|
||||
TestCase tc = new TestCase(sameDiff)
|
||||
.gradientCheck(true)
|
||||
.expectedOutput(res.getVarName(), exp);
|
||||
.expectedOutput(res.name(), exp);
|
||||
|
||||
|
||||
String err = OpValidation.validate(tc);
|
||||
|
@ -137,7 +137,7 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
|
||||
TestCase tc = new TestCase(sameDiff)
|
||||
.gradientCheck(true)
|
||||
.expectedOutput(res.getVarName(), exp);
|
||||
.expectedOutput(res.name(), exp);
|
||||
|
||||
String err = OpValidation.validate(tc);
|
||||
assertNull(err);
|
||||
|
@ -591,7 +591,7 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
SDVariable out = sd.cnn().sconv2d(vars, c);
|
||||
out = sd.nn().tanh("out", out);
|
||||
|
||||
INDArray outArr = sd.execAndEndResult();
|
||||
INDArray outArr = out.eval();
|
||||
//Expected output size: out = (in - k + 2*p)/s + 1 = (28-2+0)/1+1 = 27
|
||||
val outShape = outArr.shape();
|
||||
assertArrayEquals(new long[]{mb, depthWise * nIn, 27, 27}, outShape);
|
||||
|
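These test updates replace sd.execAndEndResult() with SDVariable.eval(), which executes the graph and returns that variable's array. A minimal sketch of the new pattern; the graph and shapes are illustrative:

import java.util.Collections;
import java.util.Map;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class EvalMigrationSketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable in = sd.var("in", Nd4j.rand(2, 3));
        SDVariable out = sd.nn().tanh("out", in);

        // Old pattern migrated away from in this commit: INDArray outArr = sd.execAndEndResult();
        INDArray outArr = out.eval();   // evaluates the graph for this variable

        // Equivalent map-based form used elsewhere in these tests:
        Map<String, INDArray> m = sd.output(Collections.<String, INDArray>emptyMap(),
                Collections.singletonList("out"));
        System.out.println(outArr.equals(m.get("out")));
    }
}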
@ -637,7 +637,7 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
SDVariable out = sd.cnn().sconv2d(vars, c);
|
||||
out = sd.nn().tanh("out", out);
|
||||
|
||||
INDArray outArr = sd.execAndEndResult();
|
||||
INDArray outArr = out.eval();
|
||||
//Expected output size: out = (in - k + 2*p)/s + 1 = (8-2+0)/1+1 = 7
|
||||
val outShape = outArr.shape();
|
||||
assertArrayEquals(new long[]{mb, nOut, 7, 7}, outShape);
|
||||
|
@ -688,7 +688,7 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
SDVariable out = sd.cnn().deconv2d(vars, deconv);
|
||||
out = sd.nn().tanh("out", out);
|
||||
|
||||
INDArray outArr = sd.execAndEndResult();
|
||||
INDArray outArr = out.eval();
|
||||
//Expected output size: out = (in + k + 2*p)/ s - 1 = (8 + 2+0)/1 - 1 = 9
|
||||
val outShape = outArr.shape();
|
||||
assertArrayEquals(new long[]{mb, nOut, 9, 9}, outShape);
|
||||
|
@ -736,7 +736,7 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
SDVariable out = sd.cnn().conv2d("conv", vars, c);
|
||||
out = sd.nn().tanh("out", out);
|
||||
|
||||
INDArray outArr = sd.execAndEndResult();
|
||||
INDArray outArr = out.eval();
|
||||
//Expected output size: out = (in - k + 2*p)/s + 1 = (28-2+0)/1+1 = 27
|
||||
val outShape = outArr.shape();
|
||||
assertArrayEquals(new long[]{mb, nOut, 27, 27}, outShape);
|
||||
|
@ -770,7 +770,7 @@ public class LayerOpValidation extends BaseOpValidation {
|
|||
SDVariable outPool = sd.cnn().maxPooling2d(in, pooling2DConfig);
|
||||
SDVariable out = sd.nn().tanh("out", outPool);
|
||||
|
||||
INDArray outArr = sd.execAndEndResult();
|
||||
INDArray outArr = out.eval();
|
||||
val outShape = outArr.shape();
|
||||
// oH = (iH - (kH + (kH-1)*(dH-1)) + 2*pH)/sH + 1;
|
||||
assertArrayEquals(new long[]{mb, nIn, 7, 7}, outShape);
|
||||
|
@@ -828,7 +828,7 @@ public class LayerOpValidation extends BaseOpValidation {
SDVariable outPool = sd.cnn().avgPooling2d(in, pooling2DConfig);
SDVariable out = sd.nn().tanh("out", outPool);
INDArray outArr = sd.execAndEndResult();
INDArray outArr = out.eval();
val outShape = outArr.shape();
// oH = (iH - (kH + (kH-1)*(dH-1)) + 2*pH)/sH + 1;
assertArrayEquals(new long[]{mb, nIn, 7, 7}, outShape);
@@ -996,7 +996,7 @@ public class LayerOpValidation extends BaseOpValidation {
}
);
TestCase tc = new TestCase(sd).gradientCheck(false).expectedOutput(res.getVarName(), expected);
TestCase tc = new TestCase(sd).gradientCheck(false).expectedOutput(res.name(), expected);
String err = OpValidation.validate(tc);
assertNull(err);
@@ -423,11 +423,11 @@ public class MiscOpValidation extends BaseOpValidation {
}
SDVariable loss = sd.sum(scatter); //.standardDeviation(scatter, true); //.sum(scatter); //TODO stdev might be better here as gradients are non-symmetrical...
sd.execAndEndResult();
TestCase tc = new TestCase(sd)
.expected(scatter, exp)
.gradCheckSkipVariables(indices.getVarName());
.gradCheckSkipVariables(indices.name());
String error = OpValidation.validate(tc);
if(error != null){
@@ -493,7 +493,7 @@ public class MiscOpValidation extends BaseOpValidation {
TestCase tc = new TestCase(sd)
.testName(msg)
.gradCheckSkipVariables(indices.getVarName());
.gradCheckSkipVariables(indices.name());
if (gatherExp != null) {
tc.expected(gather, gatherExp);
@@ -589,16 +589,16 @@ public class MiscOpValidation extends BaseOpValidation {
Map<String,INDArray> m = sameDiff.outputAll(null);
Map<String,INDArray> gm = sameDiff.calculateGradients(null, m.keySet());
SDVariable finalResult = sameDiff.grad(sum.getVarName());
SDVariable finalResult = sameDiff.grad(sum.name());
SDVariable cGrad = sameDiff.grad(varMulPre.getVarName());
SDVariable cGrad = sameDiff.grad(varMulPre.name());
SDVariable mulGradResult = sameDiff.grad(varMul.getVarName());
SDVariable aGrad = sameDiff.grad(sdVariable.getVarName());
SDVariable wGrad = sameDiff.grad(sdVariable1.getVarName());
SDVariable dGrad = sameDiff.grad(varMul.getVarName());
SDVariable mulGradResult = sameDiff.grad(varMul.name());
SDVariable aGrad = sameDiff.grad(sdVariable.name());
SDVariable wGrad = sameDiff.grad(sdVariable1.name());
SDVariable dGrad = sameDiff.grad(varMul.name());
INDArray scalarGradTest = gm.get(sum.getVarName());
INDArray scalarGradTest = gm.get(sum.name());
assertEquals(scalar, scalarGradTest);
@@ -738,11 +738,10 @@ public class MiscOpValidation extends BaseOpValidation {
SDVariable B2 = sd.var("B2", B);
SDVariable[] batchMul = sd.batchMmul(new SDVariable[] {A1, A2}, new SDVariable[] {B1, B2});
sd.exec(Collections.emptyMap(), sd.outputs());
INDArray resultingMatrix = batchMul[0].getArr();
System.out.print(resultingMatrix);
Map<String,INDArray> m = sd.output(Collections.emptyMap(), sd.outputs());
INDArray resultingMatrix = m.get(batchMul[0].name());
//System.out.print(resultingMatrix);
}
@@ -770,14 +769,14 @@ public class MiscOpValidation extends BaseOpValidation {
SDVariable mmul = sd.f().mmul(f, s, mt);
sd.updateVariableNameAndReference(mmul, "mmul");
INDArray out = sd.execAndEndResult();
INDArray out = mmul.eval();
INDArray exp = first.transpose().mmul(second);
assertEquals(exp, out);
SDVariable loss = sd.standardDeviation(mmul, true);
String err = OpValidation.validate(new TestCase(sd)
.expected(mmul.getVarName(), exp));
.expected(mmul.name(), exp));
assertNull(err);
}
@@ -1287,7 +1286,7 @@ public class MiscOpValidation extends BaseOpValidation {
SDVariable var = sd.var("in", i);
SDVariable diag = sd.math().diagPart(var);
INDArray out = sd.execAndEndResult();
INDArray out = diag.eval();
assertEquals(1, out.rank());
}
@@ -1644,10 +1643,10 @@ public class MiscOpValidation extends BaseOpValidation {
SDVariable v = new StopGradient(sd, w).outputVariable();
SDVariable loss = v.std(true);
Map<String,INDArray> gm = sd.calculateGradients(null, v.getVarName(), w.getVarName());
Map<String,INDArray> gm = sd.calculateGradients(null, v.name(), w.name());
INDArray vArr = gm.get(v.getVarName());
INDArray wArr = gm.get(w.getVarName());
INDArray vArr = gm.get(v.name());
INDArray wArr = gm.get(w.name());
System.out.println(vArr);
System.out.println(wArr);
@@ -1669,18 +1668,18 @@ public class MiscOpValidation extends BaseOpValidation {
INDArray expLoss = in.std(true);
String err = OpValidation.validate(new TestCase(sd)
.expectedOutput(checkNumerics.getVarName(), in)
.expectedOutput(checkNumerics.name(), in)
.placeholderValue("in", in)
.expectedOutput("loss", expLoss));
Preconditions.checkState(err == null, err);
//Also check that it actually does what it's supposed to:
sd.execAll(Collections.singletonMap("in", in));
sd.outputAll(Collections.singletonMap("in", in));
in.putScalar(0, Double.NaN);
try {
sd.execAll(Collections.singletonMap("in", in));
sd.outputAll(Collections.singletonMap("in", in));
fail("Expected exception");
} catch (Throwable t){
//OK
@@ -1688,14 +1687,14 @@ public class MiscOpValidation extends BaseOpValidation {
in.putScalar(0, Double.POSITIVE_INFINITY);
try {
sd.execAll(Collections.singletonMap("in", in));
sd.outputAll(Collections.singletonMap("in", in));
fail("Expected exception");
} catch (Throwable t){
//OK
}
in.putScalar(0, 0.0);
sd.execAll(Collections.singletonMap("in", in));
sd.outputAll(Collections.singletonMap("in", in));
}
@Test
@@ -117,8 +117,8 @@ public class ReductionOpValidation extends BaseOpValidation {
SDVariable loss = nonZero.add(zero).castTo(DataType.DOUBLE).std(true);
String error = OpValidation.validate(new TestCase(sd)
.expectedOutput(nonZero.getVarName(), Nd4j.scalar(DataType.LONG, i == 0 ? 2.0 : 4.0))
.expectedOutput(zero.getVarName(), Nd4j.scalar(DataType.LONG, i == 0 ? 2.0 : 0.0))
.expectedOutput(nonZero.name(), Nd4j.scalar(DataType.LONG, i == 0 ? 2.0 : 4.0))
.expectedOutput(zero.name(), Nd4j.scalar(DataType.LONG, i == 0 ? 2.0 : 0.0))
.gradientCheck(false)
);
if (error != null)
@@ -148,7 +148,7 @@ public class ReductionOpValidation extends BaseOpValidation {
SDVariable zeroFraction = sd.math().zeroFraction(input);
String error = OpValidation.validate(new TestCase(sd)
.expectedOutput(zeroFraction.getVarName(), Nd4j.scalar(i == 0 ? 0.5f : 0.0f))
.expectedOutput(zeroFraction.name(), Nd4j.scalar(i == 0 ? 0.5f : 0.0f))
.gradientCheck(i != 0)
);
if (error != null)
@@ -429,7 +429,7 @@ public class ReductionOpValidation extends BaseOpValidation {
tc.gradientCheck(gradientCheckable);
if(exp != null){
tc.expectedOutput(loss.getVarName(), exp);
tc.expectedOutput(loss.name(), exp);
}
String error = OpValidation.validate(tc);
@@ -996,7 +996,7 @@ public class ReductionOpValidation extends BaseOpValidation {
String msg = name + " - dims=" + Arrays.toString(reduceDims);
INDArray out = sd.execAndEndResult();
INDArray out = reduced.eval();
log.info(msg + " - expected shape: " + Arrays.toString(expShape) + ", out=" + Arrays.toString(out.shape())
+ ", outExp=" + Arrays.toString(expOut.shape()));
@@ -1069,10 +1069,10 @@ public class ReductionOpValidation extends BaseOpValidation {
sd.associateArrayWithVariable(inputArr, input);
sd.associateArrayWithVariable(labelArr, label);
INDArray result = sd.execAndEndResult();
INDArray result = loss.eval();
assertEquals(1, result.length());
sd.execBackwards(Collections.emptyMap());
sd.calculateGradients(Collections.emptyMap(), sd.getVariables().keySet());
}
}
@@ -76,11 +76,11 @@ public class RnnOpValidation extends BaseOpValidation {
LSTMCellOutputs v = sd.rnn().lstmCell(x, cLast, yLast, weights, conf); //Output order: i, c, f, o, z, h, y
List<String> toExec = new ArrayList<>();
for(SDVariable sdv : v.getAllOutputs()){
toExec.add(sdv.getVarName());
toExec.add(sdv.name());
}
//Test forward pass:
Map<String,INDArray> m = sd.exec(null, toExec);
Map<String,INDArray> m = sd.output(null, toExec);
//Weights and bias order: [i, f, z, o]
@@ -179,11 +179,11 @@ public class RnnOpValidation extends BaseOpValidation {
LSTMCellOutputs v = sd.rnn().lstmCell(x, cLast, yLast, weights, conf); //Output order: i, c, f, o, z, h, y
List<String> toExec = new ArrayList<>();
for(SDVariable sdv : v.getAllOutputs()){
toExec.add(sdv.getVarName());
toExec.add(sdv.name());
}
//Test forward pass:
Map<String,INDArray> m = sd.exec(null, toExec);
Map<String,INDArray> m = sd.output(null, toExec);
INDArray out0 = Nd4j.create(new float[]{0.27817473f, 0.53092605f}, new int[]{1,2}); //Input mod gate
INDArray out1 = Nd4j.create(new float[]{-0.18100877f, 0.19417824f}, new int[]{1,2}); //CS (pre tanh)
@@ -233,11 +233,11 @@ public class RnnOpValidation extends BaseOpValidation {
List<SDVariable> v = sd.rnn().gru("gru", x, hLast, weights).getAllOutputs();
List<String> toExec = new ArrayList<>();
for(SDVariable sdv : v){
toExec.add(sdv.getVarName());
toExec.add(sdv.name());
}
//Test forward pass:
Map<String,INDArray> m = sd.exec(null, toExec);
Map<String,INDArray> m = sd.output(null, toExec);
//Weights and bias order: [r, u], [c]
@ -128,7 +128,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
//Using stdev here: mean/sum would backprop the same gradient for each input...
|
||||
SDVariable stdev = sd.standardDeviation("out", reshape, true);
|
||||
|
||||
INDArray out = sd.execAndEndResult();
|
||||
INDArray out = stdev.eval();
|
||||
INDArray expOut = in.getArr().std(true, Integer.MAX_VALUE);
|
||||
|
||||
String msg = "toShape=" + Arrays.toString(toShape) + ", order=" + order;
|
||||
|
@ -247,7 +247,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
Map<String,INDArray> m = sd.outputAll(null);
|
||||
INDArray expOut = in.getArr().std(true);
|
||||
|
||||
assertArrayEquals(expExpandShape, m.get(expand.getVarName()).shape());
|
||||
assertArrayEquals(expExpandShape, m.get(expand.name()).shape());
|
||||
INDArray expExpand = inArr.dup('c').reshape(expExpandShape);
|
||||
|
||||
String msg = "expandDim=" + i + ", source=" + p.getSecond();
|
||||
|
@ -256,7 +256,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
TestCase tc = new TestCase(sd);
|
||||
tc.testName(msg)
|
||||
.expectedOutput("out", expOut)
|
||||
.expectedOutput(expand.getVarName(), expExpand);
|
||||
.expectedOutput(expand.name(), expExpand);
|
||||
|
||||
String error = OpValidation.validate(tc);
|
||||
if(error != null){
|
||||
|
@ -306,17 +306,17 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
|
||||
Map<String,INDArray> m = sd.outputAll(null);
|
||||
|
||||
INDArray squeezed = m.get(squeeze.getVarName());
|
||||
INDArray squeezed = m.get(squeeze.name());
|
||||
// assertArrayEquals(expShapePostSqueeze, squeezed.shape());
|
||||
|
||||
INDArray out = sd.execAndEndResult();
|
||||
INDArray out = m.get(stdev.name());
|
||||
INDArray expOut = in.getArr().std(true, Integer.MAX_VALUE);
|
||||
assertEquals(expOut, out);
|
||||
|
||||
String msg = "squeezeDim=" + i + ", source=" + p.getSecond();
|
||||
TestCase tc = new TestCase(sd)
|
||||
.testName(msg)
|
||||
.expected(squeeze.getVarName(), exp)
|
||||
.expected(squeeze.name(), exp)
|
||||
.expectedOutput("out", expOut);
|
||||
|
||||
|
||||
|
@ -618,7 +618,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
|
||||
SDVariable stack = sd.stack(axis, in);
|
||||
|
||||
INDArray out = sd.execAndEndResult();
|
||||
INDArray out = stack.eval();
|
||||
assertArrayEquals(expOutShape, out.shape());
|
||||
|
||||
if (ArrayUtil.prodLong(shape) == 1) {
|
||||
|
@ -714,7 +714,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
|
||||
Map<String,INDArray> m = sd.outputAll(null);
|
||||
for (SDVariable v : unstacked) {
|
||||
assertArrayEquals(msg, shape, m.get(v.getVarName()).shape());
|
||||
assertArrayEquals(msg, shape, m.get(v.name()).shape());
|
||||
}
|
||||
|
||||
TestCase tc = new TestCase(sd).testName(msg);
|
||||
|
@ -884,7 +884,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
INDArray exp = arr.dup('c').reshape('c', 4,3);
|
||||
|
||||
String err = OpValidation.validate(new TestCase(sameDiff)
|
||||
.expectedOutput(result1.getVarName(), exp));
|
||||
.expectedOutput(result1.name(), exp));
|
||||
|
||||
assertNull(err);
|
||||
}
|
||||
|
@ -920,7 +920,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
SDVariable result = sameDiff.transpose(x);
|
||||
SDVariable loss = sameDiff.standardDeviation(result, true);
|
||||
|
||||
String err = OpValidation.validate(new TestCase(sameDiff).expectedOutput(result.getVarName(), arr.transpose()));
|
||||
String err = OpValidation.validate(new TestCase(sameDiff).expectedOutput(result.name(), arr.transpose()));
|
||||
assertNull(err);
|
||||
}
|
||||
|
||||
|
@ -1022,17 +1022,16 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
SameDiff sd = SameDiff.create();
|
||||
INDArray ia = Nd4j.create(new double[]{1,2,3});
|
||||
SDVariable in = sd.var(ia);
|
||||
SDVariable constant = sd.constant(in, 3);
|
||||
SDVariable loss = constant.std(true);
|
||||
SDVariable loss = in.std(true);
|
||||
|
||||
assertNull(OpValidation.validate(new TestCase(sd).expected(constant, ia)));
|
||||
assertNull(OpValidation.validate(new TestCase(sd).expected(in, ia)));
|
||||
|
||||
//Case 1: shape is provided + scalar
|
||||
|
||||
sd = SameDiff.create();
|
||||
ia = Nd4j.scalar(3.0);
|
||||
in = sd.var(ia);
|
||||
constant = sd.constant(in, 3,4,5);
|
||||
SDVariable constant = sd.constant(Nd4j.create(DataType.FLOAT, 3,4,5));
|
||||
INDArray exp = Nd4j.valueArrayOf(new long[]{3,4,5}, 3.0);
|
||||
loss = constant.std(true);
|
||||
|
||||
|
@ -1149,7 +1148,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
|
||||
SDVariable loss = sameDiff.standardDeviation(result, true);
|
||||
String err = OpValidation.validate(new TestCase(sameDiff)
|
||||
.expected(result.getVarName(), expected)
|
||||
.expected(result.name(), expected)
|
||||
.gradientCheck(false));
|
||||
assertNull(err);
|
||||
}
|
||||
|
@ -1172,7 +1171,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
INDArray outExp = Nd4j.scalar(d);
|
||||
|
||||
String err = OpValidation.validate(new TestCase(sd)
|
||||
.expected(md.getVarName(), outExp));
|
||||
.expected(md.name(), outExp));
|
||||
assertNull(err);
|
||||
}
|
||||
|
||||
|
@ -1196,7 +1195,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
INDArray outExp = Nd4j.scalar(d);
|
||||
|
||||
String err = OpValidation.validate(new TestCase(sd)
|
||||
.expected(md.getVarName(), outExp));
|
||||
.expected(md.name(), outExp));
|
||||
assertNull(err);
|
||||
}
|
||||
|
||||
|
@ -1227,7 +1226,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
INDArray outExp = Nd4j.scalar(d);
|
||||
|
||||
String err = OpValidation.validate(new TestCase(sd)
|
||||
.expected(md.getVarName(), outExp));
|
||||
.expected(md.name(), outExp));
|
||||
assertNull(err);
|
||||
}
|
||||
|
||||
|
@ -1247,7 +1246,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
//System.out.println(d);
|
||||
|
||||
String err = OpValidation.validate(new TestCase(sd)
|
||||
.expected(md.getVarName(), Nd4j.scalar(d)));
|
||||
.expected(md.name(), Nd4j.scalar(d)));
|
||||
assertNull(err);
|
||||
}
|
||||
|
||||
|
@ -1332,7 +1331,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
.testName(op)
|
||||
.expected(sm, exp)
|
||||
.gradientCheck(true)
|
||||
.gradCheckSkipVariables(segments.getVarName());
|
||||
.gradCheckSkipVariables(segments.name());
|
||||
|
||||
String err = OpValidation.validate(tc);
|
||||
if(err != null)
|
||||
|
@ -1383,7 +1382,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
|
||||
String err = OpValidation.validate(new TestCase(sameDiff)
|
||||
.expected(result1, expected)
|
||||
.gradCheckSkipVariables(lengths.getVarName()));
|
||||
.gradCheckSkipVariables(lengths.name()));
|
||||
assertNull(err);
|
||||
|
||||
// Test with dynamic maxlen
|
||||
|
@ -1591,8 +1590,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
INDArray arr = Transforms.sigmoid(Nd4j.linspace(1, 6, 6).reshape(2, 3));
|
||||
SDVariable x = sameDiff.var("x", arr);
|
||||
SDVariable result = sameDiff.permute(x, 1, 0);
|
||||
sameDiff.execAll(null);
|
||||
assertArrayEquals(new long[]{3, 2}, result.getShape());
|
||||
Map<String,INDArray> m = sameDiff.outputAll(null);
|
||||
assertArrayEquals(new long[]{3, 2}, m.get(result.name()).shape());
|
||||
|
||||
}
|
||||
|
||||
|
@ -1629,10 +1628,10 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
SDVariable slice_full = sd.slice(in, new int[]{0, 0}, new int[]{3, 4});
|
||||
SDVariable subPart = sd.slice(in, new int[]{1, 2}, new int[]{2, 2});
|
||||
|
||||
sd.exec(Collections.emptyMap(), sd.outputs());
|
||||
Map<String,INDArray> m = sd.outputAll(Collections.emptyMap());
|
||||
|
||||
assertEquals(inArr, slice_full.getArr());
|
||||
assertEquals(inArr.get(interval(1, 3), interval(2, 4)), subPart.getArr());
|
||||
assertEquals(inArr, m.get(slice_full.name()));
|
||||
assertEquals(inArr.get(interval(1, 3), interval(2, 4)), m.get(subPart.name()));
|
||||
}
|
||||
|
||||
|
||||
|
@ -1645,10 +1644,10 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
SDVariable slice_full = sd.slice(in, new int[]{0, 0, 0}, new int[]{3, 4, 5});
|
||||
SDVariable subPart = sd.slice(in, new int[]{1, 2, 3}, new int[]{2, 2, 1});
|
||||
|
||||
sd.exec(Collections.emptyMap(), sd.outputs());
|
||||
Map<String,INDArray> m = sd.outputAll(null);
|
||||
|
||||
assertEquals(inArr, slice_full.getArr());
|
||||
assertEquals(inArr.get(interval(1, 3), interval(2, 4), interval(3, 4)), subPart.getArr());
|
||||
assertEquals(inArr, m.get(slice_full.name()));
|
||||
assertEquals(inArr.get(interval(1, 3), interval(2, 4), interval(3, 4)), m.get(subPart.name()));
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -1661,7 +1660,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
SDVariable subPart = sd.stridedSlice(in, new int[]{1, 2}, new int[]{3, 4}, new int[]{1, 1});
|
||||
// SDVariable subPart2 = sd.stridedSlice(in, new int[]{0, 0}, new int[]{4, 5}, new int[]{2, 2});
|
||||
|
||||
sd.execAll(null);
|
||||
sd.outputAll(null);
|
||||
|
||||
assertEquals(inArr, slice_full.getArr());
|
||||
assertEquals(inArr.get(interval(1, 3), interval(2, 4)), subPart.getArr());
|
||||
|
@ -1678,7 +1677,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
SDVariable slice1 = sd.stridedSlice(in, new int[]{-999, 0}, new int[]{2, 4}, new int[]{1, 1}, 1 << 1, 0, 0, 0, 0);
|
||||
SDVariable slice2 = sd.stridedSlice(in, new int[]{1, 0}, new int[]{-999, 4}, new int[]{1, 1}, 0, 1, 0, 0, 0);
|
||||
|
||||
sd.execAll(null);
|
||||
sd.outputAll(null);
|
||||
|
||||
assertEquals(inArr.get(NDArrayIndex.interval(0, 2), NDArrayIndex.all()), slice1.getArr());
|
||||
assertEquals(inArr.get(NDArrayIndex.interval(1, 3), NDArrayIndex.all()), slice2.getArr());
|
||||
|
@ -1695,7 +1694,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
//[1:3,...,1:4] -> [1:3,:,1:4]
|
||||
SDVariable slice2 = sd.stridedSlice(in, new int[]{1, 1}, new int[]{3, 4}, new int[]{1, 1}, 0, 0, 1 << 1, 0, 0);
|
||||
|
||||
sd.execAll(Collections.emptyMap());
|
||||
sd.outputAll(Collections.emptyMap());
|
||||
|
||||
assertEquals(inArr.get(interval(1, 3), all(), all()), slice.getArr());
|
||||
assertEquals(inArr.get(interval(1, 3), all(), all()), slice2.getArr());
|
||||
|
@ -1708,7 +1707,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
SDVariable in = sd.var("in", inArr);
|
||||
SDVariable slice = sd.stridedSlice(in, new int[]{-999, 0, 0, 0}, new int[]{-999, 3, 4, 5}, new int[]{-999, 1, 1, 1}, 0, 0, 0, 1, 0);
|
||||
|
||||
INDArray out = sd.execAndEndResult();
|
||||
INDArray out = slice.eval();
|
||||
|
||||
assertArrayEquals(new long[]{1, 3, 4, 5}, out.shape());
|
||||
assertEquals(inArr, out.get(point(0), all(), all(), all()));
|
||||
|
@ -1720,7 +1719,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
SameDiff sd = SameDiff.create();
|
||||
SDVariable in = sd.var("in", inArr);
|
||||
SDVariable slice = sd.stridedSlice(in, new int[]{1, 1, -999, 1}, new int[]{3, 3, -999, 4}, new int[]{1, 1, -999, 1}, 0, 0, 0, 1 << 2, 0);
|
||||
INDArray out = sd.execAndEndResult();
|
||||
INDArray out = slice.eval();
|
||||
|
||||
assertArrayEquals(new long[]{2, 2, 1, 3}, slice.getArr().shape());
|
||||
}
|
||||
|
@ -1735,7 +1734,7 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
SDVariable slice2 = sd.stridedSlice(in, new int[]{2, 0, 0}, new int[]{-999, 4, 5}, new int[]{1, 1, 1}, 0, 0, 0, 0, 1);
|
||||
SDVariable slice3 = sd.stridedSlice(in, new int[]{1, 2, 1}, new int[]{-999, -999, 5}, new int[]{1, 1, 1}, 0, 0, 0, 0, 1 | 1 << 1);
|
||||
|
||||
sd.execAll(null);
|
||||
sd.outputAll(null);
|
||||
|
||||
assertEquals(inArr.get(point(0), all(), all()), slice.getArr());
|
||||
assertEquals(inArr.get(point(2), all(), all()), slice2.getArr());
|
||||
|
@ -1880,8 +1879,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
|
||||
|
||||
// log.info(sd.summary());
|
||||
sd.exec(Collections.emptyMap(), Lists.newArrayList(s));
|
||||
sd.execBackwards(Collections.emptyMap());
|
||||
sd.output(Collections.emptyMap(), Lists.newArrayList(s));
|
||||
sd.calculateGradients(Collections.emptyMap(), sd.getVariables().keySet());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2405,8 +2404,8 @@ public class ShapeOpValidation extends BaseOpValidation {
|
|||
SDVariable gathered = sd.gather(input, indices, 1);
|
||||
SDVariable loss = gathered.std(true);
|
||||
|
||||
sd.exec(null, gathered.getVarName());
|
||||
sd.setLossVariables(gathered.getVarName());
|
||||
sd.output((Map<String,INDArray>)null, gathered.name());
|
||||
sd.setLossVariables(gathered.name());
|
||||
|
||||
String err = OpValidation.validate(new TestCase(sd)
|
||||
.gradCheckEpsilon(1e-3)
|
||||
|
|
|
@ -115,37 +115,37 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
switch (i){
|
||||
case 0:
|
||||
out = in.mul(2);
|
||||
tc.expectedOutput(out.getVarName(), inArr.mul(2));
|
||||
tc.expectedOutput(out.name(), inArr.mul(2));
|
||||
msg = "mul - " + inOrder;
|
||||
break;
|
||||
case 1:
|
||||
out = in.div(2);
|
||||
tc.expectedOutput(out.getVarName(), inArr.div(2));
|
||||
tc.expectedOutput(out.name(), inArr.div(2));
|
||||
msg = "div - " + inOrder;
|
||||
break;
|
||||
case 2:
|
||||
out = in.add(2);
|
||||
tc.expectedOutput(out.getVarName(), inArr.add(2));
|
||||
tc.expectedOutput(out.name(), inArr.add(2));
|
||||
msg = "add - " + inOrder;
|
||||
break;
|
||||
case 3:
|
||||
out = in.sub(2);
|
||||
tc.expectedOutput(out.getVarName(), inArr.sub(2));
|
||||
tc.expectedOutput(out.name(), inArr.sub(2));
|
||||
msg = "sub - " + inOrder;
|
||||
break;
|
||||
case 4:
|
||||
out = in.rdiv(2);
|
||||
tc.expectedOutput(out.getVarName(), inArr.rdiv(2));
|
||||
tc.expectedOutput(out.name(), inArr.rdiv(2));
|
||||
msg = "rdiv - " + inOrder;
|
||||
break;
|
||||
case 5:
|
||||
out = in.rsub(2);
|
||||
tc.expectedOutput(out.getVarName(), inArr.rsub(2));
|
||||
tc.expectedOutput(out.name(), inArr.rsub(2));
|
||||
msg = "rsub - " + inOrder;
|
||||
break;
|
||||
case 6:
|
||||
out = sd.math().pow(in,2);
|
||||
tc.expectedOutput(out.getVarName(), Transforms.pow(inArr, 2));
|
||||
tc.expectedOutput(out.name(), Transforms.pow(inArr, 2));
|
||||
msg = "pow - " + inOrder;
|
||||
break;
|
||||
case 7:
|
||||
|
@ -584,219 +584,219 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
switch (i) {
|
||||
case 0:
|
||||
t = in.add(5.0);
|
||||
tc.expectedOutput(t.getVarName(), ia.add(5.0));
|
||||
tc.expectedOutput(t.name(), ia.add(5.0));
|
||||
break;
|
||||
case 1:
|
||||
t = in.sub(5.0);
|
||||
tc.expectedOutput(t.getVarName(), ia.sub(5.0));
|
||||
tc.expectedOutput(t.name(), ia.sub(5.0));
|
||||
break;
|
||||
case 2:
|
||||
t = in.mul(2.5);
|
||||
tc.expectedOutput(t.getVarName(), ia.mul(2.5));
|
||||
tc.expectedOutput(t.name(), ia.mul(2.5));
|
||||
break;
|
||||
case 3:
|
||||
t = in.div(4.0);
|
||||
tc.expectedOutput(t.getVarName(), ia.div(4.0));
|
||||
tc.expectedOutput(t.name(), ia.div(4.0));
|
||||
break;
|
||||
case 4:
|
||||
t = in.rsub(5.0);
|
||||
tc.expectedOutput(t.getVarName(), ia.rsub(5.0));
|
||||
tc.expectedOutput(t.name(), ia.rsub(5.0));
|
||||
break;
|
||||
case 5:
|
||||
t = in.rdiv(1.0);
|
||||
tc.expectedOutput(t.getVarName(), ia.rdiv(1.0));
|
||||
tc.expectedOutput(t.name(), ia.rdiv(1.0));
|
||||
break;
|
||||
case 6:
|
||||
t = sd.math().pow(in, 2.5);
|
||||
ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut);
|
||||
tc.expectedOutput(t.getVarName(), Transforms.pow(ia, 2.5, true));
|
||||
tc.expectedOutput(t.name(), Transforms.pow(ia, 2.5, true));
|
||||
break;
|
||||
case 7:
|
||||
t = sd.nn().sigmoid(in);
|
||||
ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut).muli(2).subi(1.0);
|
||||
tc.expectedOutput(t.getVarName(), Transforms.sigmoid(ia, true));
|
||||
tc.expectedOutput(t.name(), Transforms.sigmoid(ia, true));
|
||||
break;
|
||||
case 8:
|
||||
t = sd.math().tanh(in);
|
||||
ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut).muli(2).subi(1.0);
|
||||
tc.expectedOutput(t.getVarName(), Transforms.tanh(ia, true));
|
||||
tc.expectedOutput(t.name(), Transforms.tanh(ia, true));
|
||||
break;
|
||||
case 9:
|
||||
ia.assign(Nd4j.rand(DataType.DOUBLE, ia.shape()));
|
||||
t = sd.math().tan(in);
|
||||
ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut);
|
||||
tc.expectedOutput(t.getVarName(), Transforms.tan(ia));
|
||||
tc.expectedOutput(t.name(), Transforms.tan(ia));
|
||||
break;
|
||||
case 10:
|
||||
t = sd.math().cos(in);
|
||||
tc.expectedOutput(t.getVarName(), Transforms.cos(ia, true));
|
||||
tc.expectedOutput(t.name(), Transforms.cos(ia, true));
|
||||
break;
|
||||
case 11:
|
||||
t = sd.math().sin(in);
|
||||
tc.expectedOutput(t.getVarName(), Transforms.sin(ia, true));
|
||||
tc.expectedOutput(t.name(), Transforms.sin(ia, true));
|
||||
break;
|
||||
case 12:
|
||||
t = sd.nn().softplus(in);
|
||||
tc.expectedOutput(t.getVarName(), Transforms.softPlus(ia, true));
|
||||
tc.expectedOutput(t.name(), Transforms.softPlus(ia, true));
|
||||
break;
|
||||
case 13:
|
||||
t = sd.math().log(in);
|
||||
ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut);
|
||||
tc.expectedOutput(t.getVarName(), Transforms.log(ia, true));
|
||||
tc.expectedOutput(t.name(), Transforms.log(ia, true));
|
||||
break;
|
||||
case 14:
|
||||
t = sd.math().neg(in);
|
||||
INDArray exp14 = ia.neg();
|
||||
tc.expectedOutput(t.getVarName(), exp14);
|
||||
tc.expectedOutput(t.name(), exp14);
|
||||
break;
|
||||
case 15:
|
||||
t = sd.math().acos(in);
|
||||
ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut).muli(1.8).subi(0.9);
|
||||
tc.expectedOutput(t.getVarName(), Transforms.acos(ia, true));
|
||||
tc.expectedOutput(t.name(), Transforms.acos(ia, true));
|
||||
break;
|
||||
case 16:
|
||||
t = sd.math().acosh(in);
|
||||
ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut).addi(1.01); //Only defined for x >= 1
|
||||
tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new ACosh(ia.dup())));
|
||||
tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new ACosh(ia.dup())));
|
||||
break;
|
||||
case 17:
|
||||
t = sd.math().asin(in);
|
||||
ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut).muli(1.8).subi(0.9);
|
||||
tc.expectedOutput(t.getVarName(), Transforms.asin(ia, true));
|
||||
tc.expectedOutput(t.name(), Transforms.asin(ia, true));
|
||||
break;
|
||||
case 18:
|
||||
t = sd.math().atan(in);
|
||||
ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut).muli(4).subi(2);
|
||||
tc.expectedOutput(t.getVarName(), Transforms.atan(ia, true));
|
||||
tc.expectedOutput(t.name(), Transforms.atan(ia, true));
|
||||
break;
|
||||
case 19:
|
||||
t = sd.math().atanh(in);
|
||||
ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut).muli(1.8).subi(0.9);
|
||||
tc.expectedOutput(t.getVarName(), Transforms.atanh(ia, true));
|
||||
tc.expectedOutput(t.name(), Transforms.atanh(ia, true));
|
||||
break;
|
||||
case 20:
|
||||
t = sd.math().cosh(in);
|
||||
tc.expectedOutput(t.getVarName(), Transforms.cosh(ia, true));
|
||||
tc.expectedOutput(t.name(), Transforms.cosh(ia, true));
|
||||
break;
|
||||
case 21:
|
||||
t = sd.math().cube(in);
|
||||
tc.expectedOutput(t.getVarName(), Transforms.pow(ia, 3.0, true));
|
||||
tc.expectedOutput(t.name(), Transforms.pow(ia, 3.0, true));
|
||||
break;
|
||||
case 22:
|
||||
t = sd.nn().elu(in);
|
||||
tc.expectedOutput(t.getVarName(), Transforms.elu(ia, true));
|
||||
tc.expectedOutput(t.name(), Transforms.elu(ia, true));
|
||||
break;
|
||||
case 23:
|
||||
//TODO SHOULDN'T THIS HAVE A DIMENSION ARG???
|
||||
t = sd.nn().softmax(in);
|
||||
ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut);
|
||||
tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new SoftMax(ia.dup()))[0]);
|
||||
tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new SoftMax(ia.dup()))[0]);
|
||||
break;
|
||||
case 24:
|
||||
t = sd.math().sqrt(in);
|
||||
ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut);
|
||||
tc.expectedOutput(t.getVarName(), Transforms.sqrt(ia, true));
|
||||
tc.expectedOutput(t.name(), Transforms.sqrt(ia, true));
|
||||
break;
|
||||
case 25:
|
||||
t = sd.math().square(in);
|
||||
tc.expectedOutput(t.getVarName(), Transforms.pow(ia, 2.0, true));
|
||||
tc.expectedOutput(t.name(), Transforms.pow(ia, 2.0, true));
|
||||
break;
|
||||
case 26:
|
||||
t = sd.transpose(in);
|
||||
tc.expectedOutput(t.getVarName(), ia.transpose().dup());
|
||||
tc.expectedOutput(t.name(), ia.transpose().dup());
|
||||
break;
|
||||
case 27:
|
||||
t = sd.math().abs(in);
|
||||
tc.expectedOutput(t.getVarName(), Transforms.abs(ia, true));
|
||||
tc.expectedOutput(t.name(), Transforms.abs(ia, true));
|
||||
break;
|
||||
case 28:
|
||||
t = sd.math().sinh(in);
|
||||
tc.expectedOutput(t.getVarName(), Transforms.sinh(ia, true));
|
||||
tc.expectedOutput(t.name(), Transforms.sinh(ia, true));
|
||||
break;
|
||||
case 29:
|
||||
t = sd.math().asinh(in);
|
||||
tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new ASinh(ia.dup())));
|
||||
tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new ASinh(ia.dup())));
|
||||
break;
|
||||
case 30:
|
||||
t = sd.math().exp(in);
|
||||
tc.expectedOutput(t.getVarName(), Transforms.exp(ia, true));
|
||||
tc.expectedOutput(t.name(), Transforms.exp(ia, true));
|
||||
break;
|
||||
case 31:
|
||||
t = sd.math().floor(in);
|
||||
tc.expectedOutput(t.getVarName(), Transforms.floor(ia, true));
|
||||
tc.expectedOutput(t.name(), Transforms.floor(ia, true));
|
||||
break;
|
||||
case 32:
|
||||
t = sd.nn().relu(in, 0.0);
|
||||
ia = Nd4j.rand(minibatch, nOut);
|
||||
tc.expectedOutput(t.getVarName(), Transforms.relu(ia, true));
|
||||
tc.expectedOutput(t.name(), Transforms.relu(ia, true));
|
||||
break;
|
||||
case 33:
|
||||
t = sd.nn().hardTanh(in);
|
||||
ia = Nd4j.rand(minibatch, nOut).muli(2).subi(1.0);
|
||||
tc.expectedOutput(t.getVarName(), Transforms.hardTanh(ia, true));
|
||||
tc.expectedOutput(t.name(), Transforms.hardTanh(ia, true));
|
||||
break;
|
||||
case 34:
|
||||
t = sd.nn().logSigmoid(in);
|
||||
tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new LogSigmoid(ia.dup())));
|
||||
tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new LogSigmoid(ia.dup())));
|
||||
break;
|
||||
case 35:
|
||||
t = sd.nn().swish(in);
|
||||
tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new Swish(ia.dup())));
|
||||
tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new Swish(ia.dup())));
|
||||
break;
|
||||
case 36:
|
||||
t = sd.math().sign(in);
|
||||
tc.expectedOutput(t.getVarName(), Transforms.sign(ia, true));
|
||||
tc.expectedOutput(t.name(), Transforms.sign(ia, true));
|
||||
break;
|
||||
case 37:
|
||||
t = sd.nn().softsign(in);
|
||||
tc.expectedOutput(t.getVarName(), Transforms.softsign(ia, true));
|
||||
tc.expectedOutput(t.name(), Transforms.softsign(ia, true));
|
||||
break;
|
||||
case 38:
|
||||
t = sd.nn().leakyRelu(in, 0.0);
|
||||
ia = Nd4j.rand(minibatch, nOut);
|
||||
tc.expectedOutput(t.getVarName(), Transforms.leakyRelu(ia, true));
|
||||
tc.expectedOutput(t.name(), Transforms.leakyRelu(ia, true));
|
||||
break;
|
||||
case 39:
|
||||
if(OpValidationSuite.IGNORE_FAILING)
|
||||
continue;
|
||||
t = sd.nn().logSoftmax(in);
|
||||
ia = Nd4j.rand(minibatch, nOut).muli(10).subi(5);
|
||||
tc.expectedOutput(t.getVarName(), Transforms.log(Transforms.softmax(ia, true)));
|
||||
tc.expectedOutput(t.name(), Transforms.log(Transforms.softmax(ia, true)));
|
||||
stdevLoss = true;
|
||||
break;
|
||||
case 40:
|
||||
t = sd.nn().selu(in);
|
||||
tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new SELU(ia.dup())));
|
||||
tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new SELU(ia.dup())));
|
||||
break;
|
||||
case 41:
|
||||
t = sd.gt(in, 1.0).castTo(DataType.DOUBLE);
|
||||
tc.expectedOutput(t.getVarName(), ia.gt(1.0).castTo(DataType.DOUBLE)).gradientCheck(false);
|
||||
tc.expectedOutput(t.name(), ia.gt(1.0).castTo(DataType.DOUBLE)).gradientCheck(false);
|
||||
break;
|
||||
case 42:
|
||||
t = sd.gte(in, 1.0).castTo(DataType.DOUBLE);
|
||||
tc.expectedOutput(t.getVarName(), ia.gte(1.0).castTo(DataType.DOUBLE)).gradientCheck(false);
|
||||
tc.expectedOutput(t.name(), ia.gte(1.0).castTo(DataType.DOUBLE)).gradientCheck(false);
|
||||
break;
|
||||
case 43:
|
||||
t = sd.lt(in, 1.0).castTo(DataType.DOUBLE);
|
||||
tc.expectedOutput(t.getVarName(), ia.lt(1.0).castTo(DataType.DOUBLE)).gradientCheck(false);
|
||||
tc.expectedOutput(t.name(), ia.lt(1.0).castTo(DataType.DOUBLE)).gradientCheck(false);
|
||||
break;
|
||||
case 44:
|
||||
t = sd.lte(in, 1.0).castTo(DataType.DOUBLE);
|
||||
tc.expectedOutput(t.getVarName(), ia.lte(1.0).castTo(DataType.DOUBLE)).gradientCheck(false);
|
||||
tc.expectedOutput(t.name(), ia.lte(1.0).castTo(DataType.DOUBLE)).gradientCheck(false);
|
||||
break;
|
||||
case 45:
|
||||
t = sd.eq(in, 2.0).castTo(DataType.DOUBLE);
|
||||
ia = Nd4j.linspace(1, minibatch * nOut, minibatch * nOut, DataType.DOUBLE).reshape('c', minibatch, nOut);
|
||||
tc.expectedOutput(t.getVarName(), ia.eq(2.0).castTo(DataType.DOUBLE)).gradientCheck(false);
|
||||
tc.expectedOutput(t.name(), ia.eq(2.0).castTo(DataType.DOUBLE)).gradientCheck(false);
|
||||
break;
|
||||
case 46:
|
||||
t = sd.neq(in, 2.0).castTo(DataType.DOUBLE);
|
||||
ia = Nd4j.linspace(1, minibatch * nOut, minibatch * nOut, DataType.DOUBLE).reshape('c', minibatch, nOut);
|
||||
tc.expectedOutput(t.getVarName(), ia.neq(2.0).castTo(DataType.DOUBLE)).gradientCheck(false);
|
||||
tc.expectedOutput(t.name(), ia.neq(2.0).castTo(DataType.DOUBLE)).gradientCheck(false);
|
||||
break;
|
||||
case 47:
|
||||
t = sd.math().ceil(in);
|
||||
tc.expectedOutput(t.getVarName(), Transforms.ceil(ia, true));
|
||||
tc.expectedOutput(t.name(), Transforms.ceil(ia, true));
|
||||
break;
|
||||
case 48:
|
||||
ia = Nd4j.randn(DataType.DOUBLE, ia.shape()).muli(2);
|
||||
|
@ -804,7 +804,7 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
INDArray expOut48 = ia.dup();
|
||||
BooleanIndexing.replaceWhere(expOut48, -3, Conditions.lessThan(-3));
|
||||
BooleanIndexing.replaceWhere(expOut48, 2, Conditions.greaterThan(2));
|
||||
tc.expectedOutput(t.getVarName(), expOut48);
|
||||
tc.expectedOutput(t.name(), expOut48);
|
||||
break;
|
||||
case 49:
|
||||
//Clip by norm, dimension 0, some below threshold, some above
|
||||
|
@ -825,7 +825,7 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
expOut49.putColumn(j, origCol.mul(clip / origCol.norm2Number().doubleValue()));
|
||||
}
|
||||
}
|
||||
tc.expectedOutput(t.getVarName(), expOut49);
|
||||
tc.expectedOutput(t.name(), expOut49);
|
||||
//System.out.println(expOut.norm2(0));
|
||||
break;
|
||||
//TODO clip by norm along other dimensions
|
||||
|
@ -837,7 +837,7 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
.addIntegerArguments(dim)
|
||||
.addInputs(ia).addOutputs(expOut50).build();
|
||||
Nd4j.getExecutioner().exec(reverse);
|
||||
tc.expectedOutput(t.getVarName(), expOut50);
|
||||
tc.expectedOutput(t.name(), expOut50);
|
||||
break;
|
||||
case 51:
|
||||
dim = 0;
|
||||
|
@ -850,7 +850,7 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
.addIntegerArguments((exclusive) ? 1 : 0, (reverseBool) ? 1 : 0, dim)
|
||||
.addInputs(ia).addOutputs(expOut51).build();
|
||||
Nd4j.getExecutioner().exec(cumsum);
|
||||
tc.expectedOutput(t.getVarName(), expOut51);
|
||||
tc.expectedOutput(t.name(), expOut51);
|
||||
break;
|
||||
case 52:
|
||||
if(OpValidationSuite.IGNORE_FAILING){
|
||||
|
@ -869,7 +869,7 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
expOut52.putScalar(s0, s1, prod);
|
||||
}
|
||||
}
|
||||
tc.expectedOutput(t.getVarName(), expOut52);
|
||||
tc.expectedOutput(t.name(), expOut52);
|
||||
break;
|
||||
case 53:
|
||||
if(OpValidationSuite.IGNORE_FAILING){
|
||||
|
@ -881,90 +881,90 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
INDArray expOut53 = Nd4j.create(DataType.DOUBLE, 2, 2);
|
||||
DynamicCustomOp op = DynamicCustomOp.builder("diag").addInputs(ia).addOutputs(expOut53).build();
|
||||
Nd4j.getExecutioner().exec(op);
|
||||
tc.expectedOutput(t.getVarName(), expOut53);
|
||||
tc.expectedOutput(t.name(), expOut53);
|
||||
break;
|
||||
case 54:
|
||||
t = sd.math().erf(in);
|
||||
INDArray expOut54 = Nd4j.createUninitialized(DataType.DOUBLE, ia.shape(), ia.ordering());
|
||||
Nd4j.getExecutioner().exec(new Erf(ia, expOut54));
|
||||
tc.expectedOutput(t.getVarName(), expOut54);
|
||||
tc.expectedOutput(t.name(), expOut54);
|
||||
break;
|
||||
case 55:
|
||||
t = sd.math().erfc(in);
|
||||
tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new Erfc(ia, Nd4j.createUninitialized(ia.shape(), ia.ordering()))));
|
||||
tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new Erfc(ia, Nd4j.createUninitialized(ia.shape(), ia.ordering()))));
|
||||
break;
|
||||
case 56:
|
||||
t = sd.math().expm1(in);
|
||||
tc.expectedOutput(t.getVarName(),Transforms.expm1(ia, true));
|
||||
tc.expectedOutput(t.name(),Transforms.expm1(ia, true));
|
||||
break;
|
||||
case 57:
|
||||
t = sd.math().log1p(in);
|
||||
ia = Nd4j.rand(minibatch, nOut);
|
||||
tc.expectedOutput(t.getVarName(), Transforms.log1p(ia, true));
|
||||
tc.expectedOutput(t.name(), Transforms.log1p(ia, true));
|
||||
break;
|
||||
case 58:
|
||||
t = sd.math().round(in);
|
||||
tc.expectedOutput(t.getVarName(), Transforms.round(ia, true));
|
||||
tc.expectedOutput(t.name(), Transforms.round(ia, true));
|
||||
break;
|
||||
case 59:
|
||||
ia = Nd4j.create(new float[]{4, 2}).castTo(DataType.DOUBLE);
|
||||
// in = sd.var("in", new int[]{1, 2});
|
||||
t = sd.math().rsqrt(in);
|
||||
tc.expectedOutput(t.getVarName(),Nd4j.getExecutioner().exec(new RSqrt(ia, Nd4j.create(ia.shape(), ia.ordering()))));
|
||||
tc.expectedOutput(t.name(),Nd4j.getExecutioner().exec(new RSqrt(ia, Nd4j.create(ia.shape(), ia.ordering()))));
|
||||
break;
|
||||
case 60:
|
||||
t = sd.nn().relu6(in, 0);
|
||||
ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut);
|
||||
tc.expectedOutput(t.getVarName(),Transforms.relu6(ia, true));
|
||||
tc.expectedOutput(t.name(),Transforms.relu6(ia, true));
|
||||
break;
|
||||
case 61:
|
||||
ia = Nd4j.create(new float[] {2, 2}).castTo(DataType.DOUBLE);
|
||||
sd.associateArrayWithVariable(ia, in);
|
||||
double value = 42;
|
||||
t = sd.fill(in.castTo(DataType.INT), DataType.DOUBLE, value);
|
||||
tc.expectedOutput(t.getVarName(), Nd4j.valueArrayOf(new int[]{2,2}, 42)).gradientCheck(false);
|
||||
tc.expectedOutput(t.name(), Nd4j.valueArrayOf(new int[]{2,2}, 42)).gradientCheck(false);
|
||||
opName = "fill";
|
||||
break;
|
||||
case 62:
|
||||
t = sd.nn().hardSigmoid(in);
|
||||
tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new HardSigmoid(ia, ia.dup())));
|
||||
tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new HardSigmoid(ia, ia.dup())));
|
||||
break;
|
||||
case 63:
|
||||
t = sd.scalarMax(in, 0.5);
|
||||
tc.expectedOutput(t.getVarName(), Transforms.max(ia, 0.5, true));
|
||||
tc.expectedOutput(t.name(), Transforms.max(ia, 0.5, true));
|
||||
break;
|
||||
case 64:
|
||||
t = sd.scalarMin(in, 0.5);
|
||||
tc.expectedOutput(t.getVarName(), Transforms.min(ia, 0.5, true));
|
||||
tc.expectedOutput(t.name(), Transforms.min(ia, 0.5, true));
|
||||
break;
|
||||
case 65:
|
||||
t = sd.assign(in, 0.5);
|
||||
tc.expectedOutput(t.getVarName(), ia.dup().assign(0.5));
|
||||
tc.expectedOutput(t.name(), ia.dup().assign(0.5));
|
||||
break;
|
||||
case 66:
|
||||
t = sd.scalarFloorMod(in, 0.5);
|
||||
tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new ScalarFMod(ia.dup(), 0.5)));
|
||||
tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new ScalarFMod(ia.dup(), 0.5)));
|
||||
break;
|
||||
case 67:
|
||||
t = sd.math().reciprocal(in);
|
||||
tc.expectedOutput(t.getVarName(), ia.rdiv(1.0));
|
||||
tc.expectedOutput(t.name(), ia.rdiv(1.0));
|
||||
break;
|
||||
case 68:
|
||||
t = sd.shape(in).castTo(DataType.DOUBLE);
|
||||
tc.expectedOutput(t.getVarName(), Nd4j.create(ArrayUtil.toDouble(ia.shape()))).gradientCheck(false);
|
||||
tc.expectedOutput(t.name(), Nd4j.create(ArrayUtil.toDouble(ia.shape()))).gradientCheck(false);
|
||||
break;
|
||||
case 69:
|
||||
t = sd.rank(in).castTo(DataType.DOUBLE);
|
||||
tc.expectedOutput(t.getVarName(), Nd4j.scalar((double)ia.rank())).gradientCheck(false);
|
||||
tc.expectedOutput(t.name(), Nd4j.scalar((double)ia.rank())).gradientCheck(false);
|
||||
break;
|
||||
case 70:
|
||||
t = sd.onesLike(in);
|
||||
tc.expectedOutput(t.getVarName(), Nd4j.ones(ia.shape()));
|
||||
tc.expectedOutput(t.name(), Nd4j.ones(ia.shape()));
|
||||
break;
|
||||
case 71:
|
||||
ia = Nd4j.randn(DataType.DOUBLE, nOut, nOut);
|
||||
t = sd.math().diagPart(in);
|
||||
tc.expectedOutput(t.getVarName(), Nd4j.create(new double[]{ia.getDouble(0,0), ia.getDouble(1,1), ia.getDouble(2,2), ia.getDouble(3,3)}).castTo(DataType.DOUBLE));
|
||||
tc.expectedOutput(t.name(), Nd4j.create(new double[]{ia.getDouble(0,0), ia.getDouble(1,1), ia.getDouble(2,2), ia.getDouble(3,3)}).castTo(DataType.DOUBLE));
|
||||
break;
|
||||
case 72:
|
||||
t = sd.identity(in);
|
||||
|
@ -1087,109 +1087,109 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
switch (i) {
|
||||
case 0:
|
||||
t = in1.add(in2);
|
||||
tc.expectedOutput(t.getVarName(), ia.add(ib));
|
||||
tc.expectedOutput(t.name(), ia.add(ib));
|
||||
break;
|
||||
case 1:
|
||||
t = in1.sub(in2);
|
||||
tc.expectedOutput(t.getVarName(),ia.sub(ib));
|
||||
tc.expectedOutput(t.name(),ia.sub(ib));
|
||||
break;
|
||||
case 2:
|
||||
t = in1.mul(in2);
|
||||
tc.expectedOutput(t.getVarName(), ia.mul(ib));
|
||||
tc.expectedOutput(t.name(), ia.mul(ib));
|
||||
break;
|
||||
case 3:
|
||||
t = in1.div(in2);
|
||||
tc.expectedOutput(t.getVarName(), ia.div(ib));
|
||||
tc.expectedOutput(t.name(), ia.div(ib));
|
||||
break;
|
||||
case 4:
|
||||
t = in1.rsub(in2);
|
||||
tc.expectedOutput(t.getVarName(), ia.rsub(ib));
|
||||
tc.expectedOutput(t.name(), ia.rsub(ib));
|
||||
break;
|
||||
case 5:
|
||||
ia.assign(Nd4j.rand(ia.shape())).addi(0.5);
|
||||
ib.assign(Nd4j.rand(ib.shape())).addi(0.5);
|
||||
t = in1.rdiv(in2);
|
||||
tc.expectedOutput(t.getVarName(), ia.rdiv(ib));
|
||||
tc.expectedOutput(t.name(), ia.rdiv(ib));
|
||||
break;
|
||||
case 6:
|
||||
t = sd.eq(in1, in2);
|
||||
opName = "eq";
|
||||
tc.expectedOutput(t.getVarName(), ia.eq(ib)).gradientCheck(false);
|
||||
tc.expectedOutput(t.name(), ia.eq(ib)).gradientCheck(false);
|
||||
break;
|
||||
case 7:
|
||||
t = sd.neq(in1, in2);
|
||||
opName = "neq";
|
||||
tc.expectedOutput(t.getVarName(), ia.neq(ib)).gradientCheck(false);;
|
||||
tc.expectedOutput(t.name(), ia.neq(ib)).gradientCheck(false);;
|
||||
break;
|
||||
case 8:
|
||||
t = sd.gt(in1, in2);
|
||||
opName = "gt";
|
||||
tc.expectedOutput(t.getVarName(), ia.gt(ib)).gradientCheck(false);
|
||||
tc.expectedOutput(t.name(), ia.gt(ib)).gradientCheck(false);
|
||||
break;
|
||||
case 9:
|
||||
t = sd.lt(in1, in2);
|
||||
opName = "lt";
|
||||
tc.expectedOutput(t.getVarName(), ia.lt(ib)).gradientCheck(false);
|
||||
tc.expectedOutput(t.name(), ia.lt(ib)).gradientCheck(false);
|
||||
break;
|
||||
case 10:
|
||||
t = sd.gte(in1, in2);
|
||||
opName = "gte";
|
||||
INDArray expOut10 = Nd4j.create(DataType.BOOL, ia.shape());
|
||||
Nd4j.getExecutioner().exec(new GreaterThanOrEqual(new INDArray[]{ia, ib}, new INDArray[]{expOut10}));
|
||||
tc.expectedOutput(t.getVarName(), expOut10).gradientCheck(false);
|
||||
tc.expectedOutput(t.name(), expOut10).gradientCheck(false);
|
||||
break;
|
||||
case 11:
|
||||
t = sd.lte(in1, in2);
|
||||
opName = "lte";
|
||||
INDArray expOut11 = Nd4j.create(DataType.BOOL, ia.shape());
|
||||
Nd4j.getExecutioner().exec(new LessThanOrEqual(new INDArray[]{ia, ib}, new INDArray[]{expOut11}));
|
||||
tc.expectedOutput(t.getVarName(), expOut11).gradientCheck(false);
|
||||
tc.expectedOutput(t.name(), expOut11).gradientCheck(false);
|
||||
break;
|
||||
case 12:
|
||||
ia = Nd4j.getExecutioner().exec(new BernoulliDistribution(ia, 0.5));
|
||||
ib = Nd4j.getExecutioner().exec(new BernoulliDistribution(ib, 0.5));
|
||||
t = sd.math().or(in1.castTo(DataType.BOOL), in2.castTo(DataType.BOOL));
|
||||
opName = "or";
|
||||
tc.expectedOutput(t.getVarName(), Transforms.or(ia.castTo(DataType.BOOL), ib.castTo(DataType.BOOL))).gradientCheck(false);
|
||||
tc.expectedOutput(t.name(), Transforms.or(ia.castTo(DataType.BOOL), ib.castTo(DataType.BOOL))).gradientCheck(false);
|
||||
break;
|
||||
case 13:
|
||||
ib = Nd4j.randn(DataType.DOUBLE, nOut, nOut);
|
||||
t = sd.mmul(in1, in2);
|
||||
tc.expectedOutput(t.getVarName(), ia.mmul(ib));
|
||||
tc.expectedOutput(t.name(), ia.mmul(ib));
|
||||
break;
|
||||
case 14:
|
||||
t = sd.max(in1, in2);
|
||||
tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new Max(ia, ib, ia.dup()))[0]);
|
||||
tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new Max(ia, ib, ia.dup()))[0]);
|
||||
break;
|
||||
case 15:
|
||||
t = sd.min(in1, in2);
|
||||
tc.expectedOutput(t.getVarName(), Nd4j.getExecutioner().exec(new Min(ia, ib, ia.dup()))[0]);
|
||||
tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new Min(ia, ib, ia.dup()))[0]);
|
||||
break;
|
||||
case 16:
|
||||
ia = Nd4j.getExecutioner().exec(new BernoulliDistribution(ia, 0.5));
|
||||
ib = Nd4j.getExecutioner().exec(new BernoulliDistribution(ib, 0.5));
|
||||
t = sd.math().and(in1.castTo(DataType.BOOL), in2.castTo(DataType.BOOL));
|
||||
opName = "and";
|
||||
tc.expectedOutput(t.getVarName(), Transforms.and(ia.castTo(DataType.BOOL), ib.castTo(DataType.BOOL))).gradientCheck(false);
|
||||
tc.expectedOutput(t.name(), Transforms.and(ia.castTo(DataType.BOOL), ib.castTo(DataType.BOOL))).gradientCheck(false);
|
||||
break;
|
||||
case 17:
|
||||
ia = Nd4j.getExecutioner().exec(new BernoulliDistribution(ia, 0.5));
|
||||
ib = Nd4j.getExecutioner().exec(new BernoulliDistribution(ib, 0.5));
|
||||
t = sd.math().xor(in1.castTo(DataType.BOOL), in2.castTo(DataType.BOOL));
|
||||
opName = "xor";
|
||||
tc.expectedOutput(t.getVarName(), Transforms.xor(ia.castTo(DataType.BOOL), ib.castTo(DataType.BOOL))).gradientCheck(false);
|
||||
tc.expectedOutput(t.name(), Transforms.xor(ia.castTo(DataType.BOOL), ib.castTo(DataType.BOOL))).gradientCheck(false);
|
||||
break;
|
||||
case 18:
|
||||
t = sd.assign(in1, in2);
|
||||
tc.expectedOutput(t.getVarName(), ib);
|
||||
tc.expectedOutput(t.name(), ib);
|
||||
break;
|
||||
case 19:
|
||||
t = sd.math().atan2(in1, in2);
|
||||
tc.expectedOutput(t.getVarName(), Transforms.atan2(ib, ia)); //Note: y,x order for samediff; x,y order for transforms
|
||||
tc.expectedOutput(t.name(), Transforms.atan2(ib, ia)); //Note: y,x order for samediff; x,y order for transforms
|
||||
break;
|
||||
case 20:
|
||||
t = sd.math().mergeAdd(in1, in2, in2);
|
||||
tc.expectedOutput(t.getVarName(), ia.add(ib).add(ib));
|
||||
tc.expectedOutput(t.name(), ia.add(ib).add(ib));
|
||||
break;
|
||||
case 21:
|
||||
t = in1.squaredDifference(in2);
|
||||
|
@ -1199,7 +1199,7 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
.addOutputs(expOut21)
|
||||
.build();
|
||||
Nd4j.getExecutioner().exec(squareDiff);
|
||||
tc.expectedOutput(t.getVarName(), expOut21);
|
||||
tc.expectedOutput(t.name(), expOut21);
|
||||
break;
|
||||
case 22:
|
||||
//set diag
|
||||
|
@ -1210,7 +1210,7 @@ public class TransformOpValidation extends BaseOpValidation {
|
|||
expOut22.putScalar(j,j, ib.getDouble(j));
|
||||
}
|
||||
t = sd.math().setDiag(in1, in2);
|
||||
tc.expectedOutput(t.getVarName(), expOut22);
|
||||
tc.expectedOutput(t.name(), expOut22);
|
||||
break;
|
||||
default:
|
||||
throw new RuntimeException();
|
||||
|
@@ -1341,7 +1341,6 @@ public class TransformOpValidation extends BaseOpValidation {
}
}
//TODO UPDATE TO OP VALIDATION OR DELETE
@Test
public void testLogGrad() {
SameDiff sameDiff = SameDiff.create();
@@ -1349,7 +1348,7 @@ public class TransformOpValidation extends BaseOpValidation {
SDVariable log = sameDiff.math().log(input);
SDVariable sum = sameDiff.sum(log, Integer.MAX_VALUE);
INDArray result = null;
sameDiff.execBackwards(Collections.emptyMap());
sameDiff.calculateGradients(Collections.emptyMap(), sameDiff.getVariables().keySet());
}
@@ -1362,8 +1361,8 @@ public class TransformOpValidation extends BaseOpValidation {
SDVariable input = sameDiff.var("x", inputs.get("x"));
SDVariable sigmoid = sameDiff.nn().sigmoid(input);
SDVariable sum = sameDiff.sum(sigmoid, Integer.MAX_VALUE);
sameDiff.execBackwards(Collections.emptyMap());
INDArray arr = input.gradient().getArr();
Map<String,INDArray> m = sameDiff.calculateGradients(Collections.emptyMap(), sameDiff.getVariables().keySet());
INDArray arr = m.get(input.name());
assertTrue(Nd4j.create(new double[][]{
{0.1966, 0.1050},
{0.0452, 0.0177}
@@ -1384,12 +1383,12 @@ public class TransformOpValidation extends BaseOpValidation {
public void testRank0EdgeCase(){
SameDiff sd = SameDiff.create();
SDVariable v1 = sd.sum(sd.var(Nd4j.create(new double[]{4, 4})));
double d0 = sd.execAndEndResult().getDouble(0);
double d0 = v1.eval().getDouble(0);
assertEquals(8, d0, 0);
SDVariable v2 = sd.sum(sd.var(Nd4j.create(new double[]{4, 4}))).div(2.0);
sd.exec(Collections.emptyMap(), sd.outputs());
double d1 = v2.getArr().getDouble(0);
Map<String,INDArray> m = sd.outputAll(Collections.emptyMap());
double d1 = m.get(v2.name()).getDouble(0);
assertEquals(4, d1, 0);
}
@@ -87,12 +87,12 @@ public class FailingSameDiffTests extends BaseNd4jTest {
SDVariable tanh = sd.math().tanh(in);
INDArray exp = Transforms.tanh(in.getArr(), true);
INDArray out = sd.execAndEndResult();
INDArray out = tanh.eval();
assertEquals(exp, out);
//Now, replace with minibatch 5:
in.setArray(Nd4j.linspace(1,20,20, DataType.DOUBLE).reshape(5,4));
INDArray out2 = sd.execAndEndResult();
INDArray out2 = tanh.eval();
assertArrayEquals(new long[]{5,4}, out2.shape());
exp = Transforms.tanh(in.getArr(), true);
@@ -124,12 +124,12 @@ public class FailingSameDiffTests extends BaseNd4jTest {
SDVariable mmul = sd.mmul(in,w).add(b);
INDArray exp = in.getArr().mmul(w.getArr()).addiRowVector(b.getArr());
INDArray out = sd.execAndEndResult();
INDArray out = mmul.eval();
assertEquals(exp, out);
//Now, replace with minibatch 5:
in.setArray(Nd4j.linspace(1,20,20, DataType.DOUBLE).reshape(5,4));
INDArray out2 = sd.execAndEndResult();
INDArray out2 = mmul.eval();
assertArrayEquals(new long[]{5,5}, out2.shape());
exp = in.getArr().mmul(w.getArr()).addiRowVector(b.getArr());
@@ -137,11 +137,10 @@ public class FailingSameDiffTests extends BaseNd4jTest {
//Generate gradient function, and exec
SDVariable loss = mmul.std(true);
sd.execBackwards(Collections.emptyMap());
sd.calculateGradients(Collections.emptyMap(), sd.getVariables().keySet());
in.setArray(Nd4j.linspace(1,12,12, DataType.DOUBLE).reshape(3,4));
sd.execAndEndResult();
out2 = mmul.getArr();
out2 = mmul.eval();
assertArrayEquals(new long[]{3,5}, out2.shape());
}
@@ -173,7 +173,7 @@ public class FlatBufferSerdeTest extends BaseNd4jTest {
}
if(execFirst){
sd.exec(Collections.singletonMap("in", arr), Collections.singletonList(x.getVarName()));
sd.output(Collections.singletonMap("in", arr), Collections.singletonList(x.name()));
}
File f = testDir.newFile();
@@ -186,7 +186,7 @@ public class FlatBufferSerdeTest extends BaseNd4jTest {
List<SDVariable> varsRestored = restored.variables();
assertEquals(varsOrig.size(), varsRestored.size());
for (int j = 0; j < varsOrig.size(); j++) {
assertEquals(varsOrig.get(j).getVarName(), varsRestored.get(j).getVarName());
assertEquals(varsOrig.get(j).name(), varsRestored.get(j).name());
}
DifferentialFunction[] fOrig = sd.ops();
@@ -200,10 +200,10 @@ public class FlatBufferSerdeTest extends BaseNd4jTest {
assertEquals(sd.getLossVariables(), restored.getLossVariables());
Map<String,INDArray> m = sd.exec(Collections.singletonMap("in", arr), Collections.singletonList(x.getVarName()));
INDArray outOrig = m.get(x.getVarName());
Map<String,INDArray> m2 = restored.exec(Collections.singletonMap("in", arr), Collections.singletonList(x.getVarName()));
INDArray outRestored = m2.get(x.getVarName());
Map<String,INDArray> m = sd.output(Collections.singletonMap("in", arr), Collections.singletonList(x.name()));
INDArray outOrig = m.get(x.name());
Map<String,INDArray> m2 = restored.output(Collections.singletonMap("in", arr), Collections.singletonList(x.name()));
INDArray outRestored = m2.get(x.name());
assertEquals(String.valueOf(i), outOrig, outRestored);
@@ -320,7 +320,7 @@ public class FlatBufferSerdeTest extends BaseNd4jTest {
if(v.isPlaceHolder() || v.getVariableType() == VariableType.ARRAY)
continue;
SDVariable v2 = sd2.getVariable(v.getVarName());
SDVariable v2 = sd2.getVariable(v.name());
INDArray a1 = v.getArr();
INDArray a2 = v2.getArr();
@@ -57,17 +57,17 @@ public class GraphTransformUtilTests extends BaseNd4jTest {
SDVariable sub = add.sub(add2);
assertTrue(OpPredicate.classEquals(AddOp.class).matches(sd, sd.getVariableOutputOp(add.getVarName())));
assertTrue(OpPredicate.classEquals(AddOp.class).matches(sd, sd.getVariableOutputOp(add2.getVarName())));
assertFalse(OpPredicate.classEquals(AddOp.class).matches(sd, sd.getVariableOutputOp(sub.getVarName())));
assertTrue(OpPredicate.classEquals(AddOp.class).matches(sd, sd.getVariableOutputOp(add.name())));
assertTrue(OpPredicate.classEquals(AddOp.class).matches(sd, sd.getVariableOutputOp(add2.name())));
assertFalse(OpPredicate.classEquals(AddOp.class).matches(sd, sd.getVariableOutputOp(sub.name())));
assertTrue(OpPredicate.opNameEquals(AddOp.OP_NAME).matches(sd, sd.getVariableOutputOp(add.getVarName())));
assertTrue(OpPredicate.opNameEquals(AddOp.OP_NAME).matches(sd, sd.getVariableOutputOp(add2.getVarName())));
assertFalse(OpPredicate.opNameEquals(AddOp.OP_NAME).matches(sd, sd.getVariableOutputOp(sub.getVarName())));
assertTrue(OpPredicate.opNameEquals(AddOp.OP_NAME).matches(sd, sd.getVariableOutputOp(add.name())));
assertTrue(OpPredicate.opNameEquals(AddOp.OP_NAME).matches(sd, sd.getVariableOutputOp(add2.name())));
assertFalse(OpPredicate.opNameEquals(AddOp.OP_NAME).matches(sd, sd.getVariableOutputOp(sub.name())));
assertTrue(OpPredicate.opNameMatches(".*dd").matches(sd, sd.getVariableOutputOp(add.getVarName())));
assertTrue(OpPredicate.opNameMatches("ad.*").matches(sd, sd.getVariableOutputOp(add2.getVarName())));
assertFalse(OpPredicate.opNameMatches(".*dd").matches(sd, sd.getVariableOutputOp(sub.getVarName())));
assertTrue(OpPredicate.opNameMatches(".*dd").matches(sd, sd.getVariableOutputOp(add.name())));
assertTrue(OpPredicate.opNameMatches("ad.*").matches(sd, sd.getVariableOutputOp(add2.name())));
assertFalse(OpPredicate.opNameMatches(".*dd").matches(sd, sd.getVariableOutputOp(sub.name())));
SubGraphPredicate p = SubGraphPredicate.withRoot(OpPredicate.classEquals(AddOp.class));
@ -76,11 +76,11 @@ public class GraphTransformUtilTests extends BaseNd4jTest {
|
|||
assertEquals(2, l.size());
|
||||
|
||||
SubGraph sg1 = l.get(0);
|
||||
assertTrue(sg1.getRootNode() == sd.getVariableOutputOp(add.getVarName()));
|
||||
assertTrue(sg1.getRootNode() == sd.getVariableOutputOp(add.name()));
|
||||
assertEquals(0, sg1.getChildNodes().size());
|
||||
|
||||
SubGraph sg2 = l.get(1);
|
||||
assertTrue(sg2.getRootNode() == sd.getVariableOutputOp(add2.getVarName()));
|
||||
assertTrue(sg2.getRootNode() == sd.getVariableOutputOp(add2.name()));
|
||||
assertEquals(0, sg2.getChildNodes().size());
|
||||
}
|
||||
|
||||
|
@ -118,7 +118,7 @@ public class GraphTransformUtilTests extends BaseNd4jTest {
|
|||
});
|
||||
|
||||
INDArray exp2 = p1.div(p2).mul(p1.sub(p2));
|
||||
INDArray out2 = sd2.getVariable(mul.getVarName()).eval();
|
||||
INDArray out2 = sd2.getVariable(mul.name()).eval();
|
||||
assertEquals(exp2, out2);
|
||||
|
||||
|
||||
|
|
|
@@ -33,18 +33,18 @@ public class NameScopeTests extends BaseNd4jTest {
    SDVariable v = sd.var("x");
    try(NameScope ns = sd.withNameScope("nameScope")){
        SDVariable v2 = sd.var("x2");
-       assertEquals("nameScope/x2", v2.getVarName());
+       assertEquals("nameScope/x2", v2.name());
        assertTrue(sd.getVariables().containsKey("nameScope/x2"));
        assertEquals("nameScope", sd.currentNameScope());

        SDVariable v3 = sd.var("x");
-       assertEquals("nameScope/x", v3.getVarName());
+       assertEquals("nameScope/x", v3.name());
        assertTrue(sd.getVariables().containsKey("nameScope/x"));

        try(NameScope ns2 = sd.withNameScope("scope2")){
            assertEquals("nameScope/scope2", sd.currentNameScope());
            SDVariable v4 = sd.var("x");
-           assertEquals("nameScope/scope2/x", v4.getVarName());
+           assertEquals("nameScope/scope2/x", v4.name());
            assertTrue(sd.getVariables().containsKey("nameScope/scope2/x"));
        }

@@ -76,19 +76,19 @@ public class NameScopeTests extends BaseNd4jTest {
    }
    SDVariable a = sd.var("a", DataType.FLOAT, 1);

-   assertEquals("x", x.getVarName());
-   assertEquals("s1/y", y.getVarName());
-   assertEquals("s1/s2/z", z.getVarName());
-   assertEquals("a", a.getVarName());
+   assertEquals("x", x.name());
+   assertEquals("s1/y", y.name());
+   assertEquals("s1/s2/z", z.name());
+   assertEquals("a", a.name());

-   assertTrue(add.getVarName(), add.getVarName().startsWith("s1/"));
-   assertEquals("s1/addxy", addWithName.getVarName());
+   assertTrue(add.name(), add.name().startsWith("s1/"));
+   assertEquals("s1/addxy", addWithName.name());

-   assertTrue(merge.getVarName(), merge.getVarName().startsWith("s1/s2/"));
-   assertEquals("s1/s2/mmax", mergeWithName.getVarName());
+   assertTrue(merge.name(), merge.name().startsWith("s1/s2/"));
+   assertEquals("s1/s2/mmax", mergeWithName.name());

    Set<String> allowedVarNames = new HashSet<>(Arrays.asList("x", "s1/y", "s1/s2/z", "a",
-           add.getVarName(), addWithName.getVarName(), merge.getVarName(), mergeWithName.getVarName()));
+           add.name(), addWithName.name(), merge.name(), mergeWithName.name()));
    Set<String> allowedOpNames = new HashSet<>();

    //Check op names:

@@ -102,8 +102,8 @@ public class NameScopeTests extends BaseNd4jTest {

    //Check fields - Variable, SDOp, etc
    for(Variable v : sd.getVariables().values()){
-       assertTrue(v.getVariable().getVarName(), allowedVarNames.contains(v.getVariable().getVarName()));
-       assertEquals(v.getName(), v.getVariable().getVarName());
+       assertTrue(v.getVariable().name(), allowedVarNames.contains(v.getVariable().name()));
+       assertEquals(v.getName(), v.getVariable().name());
        if(v.getInputsForOp() != null){
            for(String s : v.getInputsForOp()){
                assertTrue(s, allowedOpNames.contains(s));
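A compact sketch of the name-scope behaviour these assertions cover: withNameScope(...) returns an AutoCloseable scope, and every variable or op created inside it gets the scope path prefixed to its name. The fragment below mirrors the test's own x / s1/y / s1/s2/z / a naming and introduces nothing new:

    SameDiff sd = SameDiff.create();
    SDVariable x = sd.var("x", DataType.FLOAT, 1);              // -> "x"
    try (NameScope s1 = sd.withNameScope("s1")) {
        SDVariable y = sd.var("y", DataType.FLOAT, 1);          // -> "s1/y"
        try (NameScope s2 = sd.withNameScope("s2")) {
            SDVariable z = sd.var("z", DataType.FLOAT, 1);      // -> "s1/s2/z"
        }
    }
    SDVariable a = sd.var("a", DataType.FLOAT, 1);              // outside the scopes -> "a"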
@@ -108,14 +108,14 @@ public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTest {
        sd.fit(ds);
    }

-   for(String s : new String[]{"w", "b", badd.getVarName(), add.getVarName(), "l1", "l2"}){
+   for(String s : new String[]{"w", "b", badd.name(), add.name(), "l1", "l2"}){
        SDVariable gradVar = sd.getVariable(s).gradient();
        assertNotNull(s, gradVar);
    }
    //Unused:
    assertFalse(shape.hasGradient());
    try{ assertNull(shape.gradient()); } catch (IllegalStateException e){ assertTrue(e.getMessage().contains("only floating point variables")); }
-   for(String s : new String[]{unused1.getVarName(), unused2.getVarName(), unused3.getVarName()}){
+   for(String s : new String[]{unused1.name(), unused2.name(), unused3.name()}){
        assertNull(sd.getVariable(s).gradient());
    }
}

@@ -151,20 +151,20 @@ public class SameDiffSpecifiedLossVarsTests extends BaseNd4jTest {
    sd.setLossVariables("loss1");
    sd.createGradFunction();
    for(SDVariable v : new SDVariable[]{ph1, w1, b1, mmul1, badd1, loss1}){
-       assertNotNull(v.getVarName(), v.gradient());
+       assertNotNull(v.name(), v.gradient());
    }
    for(SDVariable v : new SDVariable[]{ph2, w2, b2, mmul2, badd2, loss2}){
-       assertNull(v.getVarName(), v.gradient());
+       assertNull(v.name(), v.gradient());
    }

    //Now, set to other loss function
    sd.setLossVariables("loss2");
    sd.createGradFunction();
    for(SDVariable v : new SDVariable[]{ph1, w1, b1, mmul1, badd1, loss1}){
-       assertNull(v.getVarName(), v.gradient());
+       assertNull(v.name(), v.gradient());
    }
    for(SDVariable v : new SDVariable[]{ph2, w2, b2, mmul2, badd2, loss2}){
-       assertNotNull(v.getVarName(), v.gradient());
+       assertNotNull(v.name(), v.gradient());
    }

    //Train the first side of the graph. The other side should remain unmodified!
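The test above exercises user-selectable loss variables: setLossVariables(...) picks which output the gradient function is built for, and gradient() returns null for variables that do not feed the selected loss. In outline, using the test's own variables (w1 feeds loss1, w2 feeds loss2):

    sd.setLossVariables("loss1");
    sd.createGradFunction();
    SDVariable gw1 = w1.gradient();   // non-null: w1 feeds loss1
    SDVariable gw2 = w2.gradient();   // null: w2 only feeds loss2

    sd.setLossVariables("loss2");     // switch losses and rebuild the gradient function
    sd.createGradFunction();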
File diff suppressed because it is too large
@@ -109,6 +109,8 @@ public class FileReadWriteTests extends BaseNd4jTest {
        for (int i = 0; i < s.outputsLength(); i++) {
            outputs.add(s.outputs(i));
        }
+       if(outputs.isEmpty())
+           outputs = null;
        assertEquals(sd.outputs(), outputs);

        //Check variables
@@ -63,7 +63,7 @@ public class UIListenerTest {
        Map<String, INDArray> m = new HashMap<>();
        iter.reset();
        m.put("in", iter.next().getFeatures());
-       INDArray out = sd.execSingle(m, "softmax");
+       INDArray out = sd.outputSingle(m, "softmax");
        assertNotNull(out);
        assertArrayEquals(new long[]{150, 3}, out.shape());
    }
@@ -78,7 +78,7 @@ public class ExecutionTests extends BaseNd4jTest {
        val tg = TFGraphMapper.importGraphTxt(new ClassPathResource("tf_graphs/reduce_dim.pb.txt").getInputStream(), null, null);
        System.out.println(tg.summary());

-       Map<String,INDArray> result_0 = tg.exec(Collections.emptyMap(), tg.outputs());
+       Map<String,INDArray> result_0 = tg.outputAll(null);
        val exp_0 = Nd4j.create(DataType.FLOAT, 3).assign(3.0);

        assertEquals(exp_0, result_0.get("Sum"));
@@ -174,7 +174,7 @@ public class BERTGraphTest extends BaseNd4jTest {
        //Find pre-dropout input variable:
        SDVariable newOut = null;
        for(SDVariable v : inputs){
-           if(v.getVarName().endsWith("/BiasAdd") || v.getVarName().endsWith("/Softmax") || v.getVarName().endsWith("/add_1") || v.getVarName().endsWith("/Tanh")){
+           if(v.name().endsWith("/BiasAdd") || v.name().endsWith("/Softmax") || v.name().endsWith("/add_1") || v.name().endsWith("/Tanh")){
                newOut = v;
                break;
            }

@@ -249,7 +249,7 @@ public class BERTGraphTest extends BaseNd4jTest {
        placeholderValues.put("IteratorGetNext:1", mask);
        placeholderValues.put("IteratorGetNext:4", segmentIdxs);

-       Map<String, INDArray> out = sd.exec(placeholderValues, "loss/Softmax");
+       Map<String, INDArray> out = sd.output(placeholderValues, "loss/Softmax");
        INDArray softmax = out.get("loss/Softmax");
        // System.out.println("OUTPUT - Softmax");
        // System.out.println(softmax);

@@ -335,8 +335,8 @@ public class BERTGraphTest extends BaseNd4jTest {

        //For training, convert weights and biases from constants to variables:
        for(SDVariable v : sd.variables()){
-           if(v.isConstant() && v.dataType().isFPType() && !v.getArr().isScalar() && !floatConstants.contains(v.getVarName())){ //Skip scalars - trainable params
-               log.info("Converting to variable: {} - dtype: {} - shape: {}", v.getVarName(), v.dataType(), Arrays.toString(v.getArr().shape()));
+           if(v.isConstant() && v.dataType().isFPType() && !v.getArr().isScalar() && !floatConstants.contains(v.name())){ //Skip scalars - trainable params
+               log.info("Converting to variable: {} - dtype: {} - shape: {}", v.name(), v.dataType(), Arrays.toString(v.getArr().shape()));
                v.convertToVariable();
            }
        }

@@ -393,14 +393,14 @@ public class BERTGraphTest extends BaseNd4jTest {
        placeholderValues.put("IteratorGetNext:4", segmentIdxs);
        placeholderValues.put("label", labelArr);

-       INDArray lossArr = sd.exec(placeholderValues, "loss").get("loss");
+       INDArray lossArr = sd.output(placeholderValues, "loss").get("loss");
        assertTrue(lossArr.isScalar());
        double scoreBefore = lossArr.getDouble(0);
        for( int i=0; i<5; i++ ){
            sd.fit(mds);
        }

-       lossArr = sd.exec(placeholderValues, "loss").get("loss");
+       lossArr = sd.output(placeholderValues, "loss").get("loss");
        assertTrue(lossArr.isScalar());
        double scoreAfter = lossArr.getDouble(0);
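For fine-tuning, the imported BERT graph is made trainable by promoting floating-point constants to variables. Stripped of the test-specific floatConstants skip list and logging, the core pattern from the hunk above is simply:

    for (SDVariable v : sd.variables()) {
        if (v.isConstant() && v.dataType().isFPType() && !v.getArr().isScalar()) {
            v.convertToVariable();   // imported weight/bias becomes a trainable variable
        }
    }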
@@ -105,11 +105,9 @@ public class ValidateZooModelPredictions extends BaseNd4jTest {
        //Perform inference
        List<String> inputs = sd.inputs();
        assertEquals(1, inputs.size());
-       List<String> outputs = sd.outputs();
-       assertEquals(1, outputs.size());

-       String out = outputs.get(0);
-       Map<String,INDArray> m = sd.exec(Collections.singletonMap(inputs.get(0), img), out);
+       String out = "MobilenetV1/Predictions/Softmax";
+       Map<String,INDArray> m = sd.output(Collections.singletonMap(inputs.get(0), img), out);

        INDArray outArr = m.get(out);

@@ -167,7 +165,7 @@ public class ValidateZooModelPredictions extends BaseNd4jTest {
        assertEquals(1, inputs.size());

        String out = "softmax_tensor";
-       Map<String,INDArray> m = sd.exec(Collections.singletonMap(inputs.get(0), img), out);
+       Map<String,INDArray> m = sd.output(Collections.singletonMap(inputs.get(0), img), out);

        INDArray outArr = m.get(out);
@@ -106,7 +106,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {

        //g.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/mnist.fb"), ExecutorConfiguration.builder().outputMode(OutputMode.VARIABLE_SPACE).build());

-       g.execAndEndResult();
+       g.outputAll(null);
    }


@@ -129,7 +129,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
        val graph = TFGraphMapper.importGraph(new ClassPathResource("/tf_graphs/argmax.pb.txt").getInputStream());

        log.info(graph.asFlatPrint());
-       val result = graph.execAndEndResult();
+       val result = graph.outputAll(null).get(graph.outputs().get(0));

        val exp = Nd4j.createFromArray(new long[]{2, 2, 2});

@@ -222,7 +222,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
        graph.var("Placeholder", p0);
        graph.var("Placeholder_1", p1);

-       val res = graph.execAndEndResult();
+       val res = graph.outputAll(null).get(graph.outputs().get(0));



@@ -341,7 +341,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
        val constIn = tg.getVariable("StridedSlice/input");
        assertNotNull(constIn);

-       val arr = tg.getArrForVarName(constIn.getVarName());
+       val arr = tg.getArrForVarName(constIn.name());
        assertEquals(139.5, arr.sumNumber().doubleValue(), 1e-5);


@@ -651,9 +651,8 @@ public class TensorFlowImportTest extends BaseNd4jTest {

        INDArray input = Nd4j.linspace(1,40,40, DataType.FLOAT).reshape(10,4);
        INDArray expectedOutput = Nd4j.linspace(1,40,40, DataType.FLOAT).reshape(10,4).addRowVector(Nd4j.linspace(1,4,4, DataType.FLOAT));
-       INDArray actual = graph.execSingle(Collections.singletonMap("input",input), graph.outputs().get(0));
+       INDArray actual = graph.outputSingle(Collections.singletonMap("input",input), graph.outputs().get(0));
        assertEquals(input,graph.getVariable("input").getArr());
-       assertArrayEquals(input.shape(),graph.getShapeForVarName(graph.getVariable("input").getVarName()));
        assertEquals(expectedOutput,actual);
    }

@@ -665,13 +664,13 @@ public class TensorFlowImportTest extends BaseNd4jTest {

        val variables = new HashMap<String, SDVariable>();
        for (val var : tg.variables()) {
-           variables.put(var.getVarName(), var);
+           variables.put(var.name(), var);
        }

        val functions = new HashMap<String, DifferentialFunction>();
        for (val func: tg.ops()) {
            val ownName = func.getOwnName();
-           String outName = func.outputVariables()[0].getVarName();
+           String outName = func.outputVariables()[0].name();

            assertTrue("Missing ownName: [" + ownName +"]",variables.containsKey(ownName));
            assertEquals(ownName, outName);

@@ -704,7 +703,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
        tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simpleif_0.fb"));

        //log.info("{}", tg.asFlatPrint());
-       val array = tg.execAndEndResult();
+       val array = tg.outputAll(null).get(tg.outputs().get(0));
        val exp = Nd4j.create(2, 2).assign(1);
        assertNotNull(array);
        assertEquals(exp, array);

@@ -723,7 +722,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
        //log.info("{}", tg.asFlatPrint());


-       val array = tg.execAndEndResult();
+       val array = tg.outputAll(null).get(tg.outputs().get(0));
        val exp = Nd4j.create(2, 2).assign(1);
        assertNotNull(array);
        assertEquals(exp, array);

@@ -741,7 +740,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {

        //log.info("{}", tg.asFlatPrint());
        /*
-       val array = tg.execAndEndResult();
+       val array = tg.outputAll(null).get(tg.outputs().get(0));
        val exp = Nd4j.create(2, 2).assign(2);
        assertNotNull(array);
        assertEquals(exp, array);*/

@@ -759,7 +758,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {

        //log.info("{}", tg.asFlatPrint());

-       val array = tg.execAndEndResult();
+       val array = tg.outputAll(null).get(tg.outputs().get(0));
        val exp = Nd4j.create(2, 2).assign(4);
        assertNotNull(array);
        assertEquals(exp, array);

@@ -780,7 +779,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {

        //log.info("{}", tg.asFlatPrint());

-       val array = tg.execAndEndResult();
+       INDArray array = tg.outputAll(null).get(tg.outputs().get(0));
        val exp = Nd4j.create(2, 2).assign(-1);
        assertNotNull(array);
        assertEquals(exp, array);

@@ -800,7 +799,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {

        //log.info("{}", tg.asFlatPrint());

-       val array = tg.execAndEndResult();
+       val array = tg.outputAll(null).get(tg.outputs().get(0));
        val exp = Nd4j.create(2, 2).assign(-3);
        assertNotNull(array);
        assertEquals(exp, array);

@@ -822,7 +821,8 @@ public class TensorFlowImportTest extends BaseNd4jTest {

        //log.info("{}", tg.asFlatPrint());

-       val array = tg.execAndEndResult();
+       Map<String,INDArray> m = tg.outputAll(null);
+       val array = m.get(tg.outputs().get(0));
        //val array = tg.getVariable("output").getArr();
        val exp = Nd4j.create(2, 2).assign(15.0);
        assertNotNull(array);

@@ -968,7 +968,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
        assertNotNull(tg);

        val input_matrix = Nd4j.ones(3, 2);
-       val array = tg.execSingle(Collections.singletonMap("input_matrix", input_matrix), tg.outputs().get(0));
+       val array = tg.outputSingle(Collections.singletonMap("input_matrix", input_matrix), tg.outputs().get(0));

        val exp = Nd4j.create(new float[] {1, 1, 2, 2, 3, 3}, new int[]{3, 2});

@@ -982,7 +982,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {

        val input_matrix = Nd4j.ones(3, 2);

-       val array = tg.exec(Collections.singletonMap("input_matrix", input_matrix), tg.outputs().get(0));
+       val array = tg.output(Collections.singletonMap("input_matrix", input_matrix), tg.outputs().get(0)).get(tg.outputs().get(0));

        val exp = Nd4j.create(new float[] {2, 2}, new int[]{2});

@@ -997,7 +997,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
        val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/tensor_array_unstack.pb.txt").getInputStream());
        assertNotNull(tg);

-       val array = tg.execSingle(Collections.emptyMap(), tg.outputs().get(0));
+       val array = tg.outputSingle(Collections.emptyMap(), tg.outputs().get(0));

        val exp = Nd4j.create(new float[] {5, 6, 7, 8}, new int[]{4});

@@ -1011,7 +1011,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {

        val input_matrix = Nd4j.linspace(1, 10, 10, DataType.FLOAT).reshape(5, 2);
        log.info("Graph: {}", tg.asFlatPrint());
-       val array = tg.execSingle(Collections.singletonMap("input_matrix", input_matrix), tg.outputs().get(0));
+       val array = tg.outputSingle(Collections.singletonMap("input_matrix", input_matrix), tg.outputs().get(0));

        val exp = Nd4j.create(new float[] {3,6, 9,12, 15,18, 21,24, 27,30}, new int[]{5, 2});

@@ -1023,7 +1023,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
        Nd4j.create(1);
        val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/losses/log_loss_rank2_axis1_SUM_OVER_BATCH_SIZE/frozen_model.pb").getInputStream());

-       tg.execAndEndResult();
+       tg.outputAll(null);
    }

    @Test

@@ -1040,7 +1040,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
        for (int e = 0; e < 1000; e++){
            val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/reduce_any/rank0/frozen_model.pb").getInputStream());

-           Map<String,INDArray> result = tg.exec(Collections.emptyMap(), tg.outputs());
+           Map<String,INDArray> result = tg.output(Collections.emptyMap(), tg.outputs());

            assertNotNull(result);
            assertTrue(result.size() > 0);

@@ -1052,7 +1052,7 @@ public class TensorFlowImportTest extends BaseNd4jTest {
        Nd4j.create(1);
        val tg = TFGraphMapper.importGraph(new ClassPathResource("tf_graphs/examples/transforms/logicalxor_3,4_3,4/frozen_model.pb").getInputStream());

-       tg.execAndEndResult();
+       tg.outputAll(null);
    }

    @Test
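All of the TensorFlow import tests above follow the same updated flow: build a SameDiff instance with TFGraphMapper.importGraph(...), then run it with outputAll(null) when no placeholders are needed, or output(...)/outputSingle(...) when they are, in place of the removed exec*/execAndEndResult() calls. A sketch assembled from the calls already used in these tests (paths and shapes taken from them, not invented here):

    SameDiff tg = TFGraphMapper.importGraph(
            new ClassPathResource("tf_graphs/argmax.pb.txt").getInputStream());

    Map<String, INDArray> all = tg.outputAll(null);              // every graph output, keyed by name
    INDArray first = all.get(tg.outputs().get(0));

    INDArray single = tg.outputSingle(
            Collections.singletonMap("input_matrix", Nd4j.ones(3, 2)),
            tg.outputs().get(0));                                // single requested output with a placeholder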
@@ -118,7 +118,7 @@ public class ImportModelDebugger {

        List<String> outputs = sd.outputs();

-       sd.exec(ph, outputs);
+       sd.output(ph, outputs);
    }

@@ -0,0 +1,93 @@
+package org.nd4j.linalg.convolution;
+
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+import org.nd4j.linalg.BaseNd4jTest;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.CustomOp;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.factory.Nd4jBackend;
+import org.nd4j.linalg.ops.transforms.Transforms;
+import org.nd4j.resources.Resources;
+
+import java.io.File;
+import java.util.*;
+
+import static org.junit.Assert.*;
+
+public class DeconvTests extends BaseNd4jTest {
+
+    @Rule
+    public TemporaryFolder testDir = new TemporaryFolder();
+
+    public DeconvTests(Nd4jBackend backend) {
+        super(backend);
+    }
+
+    @Override
+    public char ordering() {
+        return 'c';
+    }
+
+    @Test
+    public void compareKeras() throws Exception {
+        File f = testDir.newFolder();
+        Resources.copyDirectory("keras/deconv", f);
+
+        File[] files = f.listFiles();
+
+        Set<String> tests = new HashSet<>();
+        for(File file : files){
+            String n = file.getName();
+            if(!n.startsWith("mb"))
+                continue;
+
+            int idx = n.lastIndexOf('_');
+            String name = n.substring(0, idx);
+            tests.add(name);
+        }
+
+        List<String> l = new ArrayList<>(tests);
+        Collections.sort(l);
+        assertFalse(l.isEmpty());
+
+        for(String s : l){
+            String s2 = s.replaceAll("[a-zA-Z]", "");
+            String[] nums = s2.split("_");
+            int mb = Integer.parseInt(nums[0]);
+            int k = Integer.parseInt(nums[1]);
+            int size = Integer.parseInt(nums[2]);
+            int stride = Integer.parseInt(nums[3]);
+            boolean same = s.contains("same");
+            int d = Integer.parseInt(nums[5]);
+            boolean nchw = s.contains("nchw");
+
+            INDArray w = Nd4j.readNpy(new File(f, s + "_W.npy"));
+            INDArray b = Nd4j.readNpy(new File(f, s + "_b.npy"));
+            INDArray in = Nd4j.readNpy(new File(f, s + "_in.npy")).castTo(DataType.FLOAT);
+            INDArray expOut = Nd4j.readNpy(new File(f, s + "_out.npy"));
+
+            CustomOp op = DynamicCustomOp.builder("deconv2d")
+                    .addInputs(in, w, b)
+                    .addIntegerArguments(
+                            k, k,
+                            stride, stride,
+                            0, 0, //padding
+                            d, d,
+                            same ? 1 : 0,
+                            nchw ? 0 : 1)
+                    .callInplace(false)
+                    .build();
+            INDArray out = Nd4j.create(op.calculateOutputShape().get(0));
+            out.assign(Double.NaN);
+            op.addOutputArgument(out);
+            Nd4j.exec(op);
+
+            boolean eq = expOut.equalsWithEps(out, 1e-4);
+            assertTrue(eq);
+        }
+    }
+}
@@ -21,8 +21,6 @@ import org.nd4j.config.ND4JEnvironmentVars;
import org.nd4j.config.ND4JSystemProperties;
import org.nd4j.context.Nd4jContext;
import org.nd4j.linalg.io.Resource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.IOException;
@@ -152,6 +150,9 @@ public abstract class Nd4jBackend {
     */
    public static Nd4jBackend load() throws NoAvailableBackendException {

+       String logInitProperty = System.getProperty(ND4JSystemProperties.LOG_INITIALIZATION, "true");
+       boolean logInit = Boolean.parseBoolean(logInitProperty);
+
        List<Nd4jBackend> backends = new ArrayList<>(1);
        ServiceLoader<Nd4jBackend> loader = ServiceLoader.load(Nd4jBackend.class);
        try {

@@ -183,7 +184,9 @@ public abstract class Nd4jBackend {
                error = e.getMessage();
            }
            if (!available) {
+               if(logInit) {
                    log.warn("Skipped [{}] backend (unavailable): {}", backend.getClass().getSimpleName(), error);
+               }
                continue;
            }

@@ -193,7 +196,9 @@ public abstract class Nd4jBackend {
                e.printStackTrace();
            }

+           if(logInit) {
                log.info("Loaded [{}] backend", backend.getClass().getSimpleName());
+           }
            return backend;
        }

@@ -273,6 +278,8 @@ public abstract class Nd4jBackend {
        return getClass().getName();
    }

+   public abstract void logBackendInit();
+
    @SuppressWarnings("serial")
    public static class NoAvailableBackendException extends Exception {
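The load() changes gate all backend-initialization logging behind the ND4JSystemProperties.LOG_INITIALIZATION system property, which defaults to "true". A minimal sketch of suppressing those messages before the first ND4J call (the class name and usage below are illustrative, not from the patch):

    import org.nd4j.config.ND4JSystemProperties;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class QuietStartup {
        public static void main(String[] args) {
            // Must be set before the backend is loaded, i.e. before any Nd4j usage
            System.setProperty(ND4JSystemProperties.LOG_INITIALIZATION, "false");

            INDArray x = Nd4j.ones(2, 2);   // triggers Nd4jBackend.load() without the init log lines
            System.out.println(x);
        }
    }

The same property key can also be supplied as a -D JVM argument instead of being set programmatically.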
@@ -200,7 +200,7 @@ public class SameDiffServlet<I, O> implements ModelServingServlet<I, O> {
            map.put(n, mds.getFeatures(cnt++));
        }

-       val output = sdModel.exec(map, orderedOutputNodes);
+       val output = sdModel.output(map, orderedOutputNodes);
        val arrays = new INDArray[output.size()];

        // now we need to get ordered output arrays, as specified in server constructor