From 70ee8ba91d5002f1868b85001f465a511ca6944f Mon Sep 17 00:00:00 2001
From: Alex Black
Date: Fri, 23 Aug 2019 20:54:24 +1000
Subject: [PATCH] Optimization / fix for DL4J SameDiff layers (#156)

Signed-off-by: AlexDBlack
---
 .../layers/samediff/SameDiffGraphVertex.java  | 65 ++++++++++++-------
 .../nn/layers/samediff/SameDiffLayer.java     | 17 +++--
 .../layers/samediff/SameDiffOutputLayer.java  | 32 ++++-----
 .../org/nd4j/autodiff/samediff/SameDiff.java  | 41 +++++++++---
 4 files changed, 100 insertions(+), 55 deletions(-)

diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffGraphVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffGraphVertex.java
index 40fa6aaa2..f1f4b536d 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffGraphVertex.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffGraphVertex.java
@@ -99,25 +99,28 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
             doInit();
         }
 
+        Map<String,INDArray> phMap = new HashMap<>();
         config.validateInput(inputs);
         for(int i=0; i<inputs.length; i++){
-            sameDiff.associateArrayWithVariable(inputs[i], config.getVertexParams().getInputs().get(i));
+            phMap.put(config.getVertexParams().getInputs().get(i), inputs[i]);
         }
         if(paramTable != null && paramTable.size() > 0) {
-            for (String s : paramTable.keySet()) {
-                sameDiff.associateArrayWithVariable(paramTable.get(s), s);
+            //Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
+            //TODO Find a more efficient solution for this
+            for (Map.Entry<String,INDArray> e : paramTable.entrySet()) {
+                INDArray arr = e.getValue();
+                sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
             }
         }
-        Map<String,INDArray> out = sameDiff.exec(null, outputKey);
-        INDArray result = out.get(outputKey);
+        INDArray result = sameDiff.outputSingle(phMap, outputKey);
 
         //Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
         sameDiff.clearPlaceholders(true);
@@ -136,10 +139,10 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
             doInit();
         }
 
+        List<String> inputNames = config.getVertexParams().getInputs();
         if(!sameDiff.hasGradientFunction()) {
             //Create when scoped out, to ensure any arrays are not in WS
-            List<String> inputs = config.getVertexParams().getInputs();
-            String[] inArr = inputs.toArray(new String[inputs.size()]);
+            String[] inArr = inputNames.toArray(new String[inputNames.size()]);
             sameDiff.createGradFunction(inArr);
         }
         config.validateInput(inputs);
@@ -149,25 +152,31 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
         for(String s : inputs){
             phMap.put(s, this.inputs[i++]);
         }
-        if(maskArrays != null){
-            for( int j=0; j<maskArrays.length; j++ ){
-                phMap.put(inputs.get(j) + "_mask", maskArrays[j]);
-            }
-        }
-        for(String s : paramTable.keySet() ){
-            sameDiff.associateArrayWithVariable(paramTable.get(s), s);
+        //Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
+        //TODO Find a more efficient solution for this
+        for (Map.Entry<String,INDArray> e : paramTable.entrySet()) {
+            INDArray arr = e.getValue();
+            sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
         }
 
-        sameDiff.execBackwards(phMap);
+        List<String> required = new ArrayList<>(inputNames.size());    //Ensure that the input placeholder gradients are calculated
+        for(String s : inputNames){
+            required.add(sameDiff.getVariable(s).gradient().getVarName());
+        }
+        sameDiff.execBackwards(phMap, required);
         for(String s : paramTable.keySet() ){
             INDArray sdGrad = sameDiff.grad(s).getArr();
             INDArray dl4jGrad = gradTable.get(s);
@@ -176,9 +185,17 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
         }
 
         dLdIns = new INDArray[inputs.size()];
+        String fnName = fn.getGradPlaceholderName();
         for(int j=0; j<inputs.size(); j++ ){
-            dLdIns[j] = sameDiff.grad(inputs.get(j)).getArr();
+            String name = inputs.get(j);
+            dLdIns[j] = sameDiff.grad(name).getArr();
+
+            if(dLdIns[j] == null && fnName.equals(name)){
+                //Edge case: no gradient array was stored for this input placeholder;
+                // the gradient with respect to this input is the incoming epsilon itself
+                dLdIns[j] = epsilon;
+            }
         }
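For context on the recurring change above: DL4J parameter arrays are views of a single flattened parameter buffer, while SameDiff's DeviceLocal storage cannot hold views, which is why each iteration now copies values in via `assignArray` instead of re-associating the view. A minimal ND4J sketch of the view/copy distinction being relied on (class name and shapes are illustrative, not from the patch):

```java
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class ViewVsCopySketch {
    public static void main(String[] args) {
        // DL4J keeps all of a network's parameters in one flattened buffer;
        // each individual parameter array is a view (a slice) of that buffer.
        INDArray flattened = Nd4j.linspace(1, 6, 6).reshape(2, 3);
        INDArray paramView = flattened.getRow(0);   // shares memory with 'flattened'

        System.out.println(paramView.isView());     // true: unsafe to hand to DeviceLocal

        // The approach assignArray takes: dup() a view so the copy owns its
        // own buffer, then write the values into the existing stored array.
        INDArray safe = paramView.isView() ? paramView.dup() : paramView;
        System.out.println(safe.isView());          // false
    }
}
```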
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java
@@ ... @@ public class SameDiffLayer extends AbstractLayer<org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer> {
             phMap.put(MASK_KEY, layerConf().onesMaskForInput(input));
         }
 
-        for(String s : paramTable.keySet() ) {
-            sameDiff.associateArrayWithVariable(paramTable.get(s), s);
+        //Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
+        //TODO Find a more efficient solution for this
+        for (Map.Entry<String,INDArray> e : paramTable.entrySet()) {
+            INDArray arr = e.getValue();
+            sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
         }
 
-        Map<String,INDArray> out = sameDiff.exec(phMap, outputKey);
+        Map<String,INDArray> out = sameDiff.output(phMap, outputKey);
         INDArray result = out.get(outputKey);
 
         //Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
@@ -131,9 +134,11 @@ public class SameDiffLayer extends AbstractLayer<org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer> {
 
             org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf();
             bl.validateInput(input);
 
-            for(String s : paramTable.keySet() ){
-                //TODO this should only be necessary, in theory, once!
-                sameDiff.associateArrayWithVariable(paramTable.get(s), s);
+            //Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
+            //TODO Find a more efficient solution for this
+            for (Map.Entry<String,INDArray> e : paramTable.entrySet()) {
+                INDArray arr = e.getValue();
+                sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
             }
 
             Map<String,INDArray> phMap = new HashMap<>();
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffOutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffOutputLayer.java
index 29b58628f..e5ca125cd 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffOutputLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffOutputLayer.java
@@ -99,8 +99,11 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.conf.layers.samediff.SameDiffOutputLayer> {
             doInit();
         }
 
-        for(String s : paramTable.keySet() ){
-            sameDiff.associateArrayWithVariable(paramTable.get(s), s);
+        //Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
+        //TODO Find a more efficient solution for this
+        for (Map.Entry<String,INDArray> e : paramTable.entrySet()) {
+            INDArray arr = e.getValue();
+            sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
         }
 
         Map<String,INDArray> phMap = new HashMap<>();
@@ -111,7 +114,7 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.conf.layers.samediff.SameDiffOutputLayer> {
@@ ... @@
-        for(String s : paramTable.keySet() ){
-            sameDiff.associateArrayWithVariable(paramTable.get(s), s);
+        //Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
+        //TODO Find a more efficient solution for this
+        for (Map.Entry<String,INDArray> e : paramTable.entrySet()) {
+            INDArray arr = e.getValue();
+            sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
         }
 
         List<String> gradVarNames = new ArrayList<>();
@@ -297,8 +291,10 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.conf.layers.samediff.SameDiffOutputLayer> {
             for (Map.Entry<String,INDArray> e : p.entrySet()) {
-                sameDiff.associateArrayWithVariable(e.getValue(), sameDiff.getVariable(e.getKey()));
+                INDArray arr = e.getValue();
+                sameDiff.associateArrayWithVariable(arr, sameDiff.getVariable(e.getKey()));
             }
 
             this.outputKey = layerOutput.getVarName();
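The backprop paths in this patch now pass an explicit list of required gradient names to `execBackwards`, so that input placeholder gradients are computed even though no parameter update depends on them. A self-contained sketch of that call pattern; the graph and the names `"in"` and `"w"` are invented for illustration, and it assumes the lone scalar output is treated as the loss, as SameDiff did at the time:

```java
import java.util.*;

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class RequiredGradientsSketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 3);
        SDVariable w = sd.var("w", DataType.FLOAT, 3, 2);
        SDVariable loss = in.mmul(w).sum();     // scalar output, treated as the loss

        Map<String, INDArray> phMap = new HashMap<>();
        phMap.put("in", Nd4j.rand(4, 3));

        // As in the patch: list the placeholder gradients explicitly so they
        // are guaranteed to be computed, not just the parameter gradients.
        List<String> required = new ArrayList<>();
        required.add(sd.getVariable("in").gradient().getVarName());

        sd.execBackwards(phMap, required);
        INDArray dLdIn = sd.grad("in").getArr();    // gradient w.r.t. the input
        System.out.println(Arrays.toString(dLdIn.shape()));
    }
}
```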
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java
index 8fee57dad..e6f30d12e 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java
@@ -807,9 +807,9 @@ public class SameDiff extends SDBaseOps {
         SDVariable v = getVariable(varName);
         if (v.isConstant()) {
-            constantArrays.put(varName, new DeviceLocalNDArray(arr));
+            constantArrays.put(varName, new DeviceLocalNDArray(arr, true));
         } else if (v.getVariableType() == VariableType.VARIABLE) {
-            variablesArrays.put(varName, new DeviceLocalNDArray(arr));
+            variablesArrays.put(varName, new DeviceLocalNDArray(arr, true));
         } else if (v.isPlaceHolder()) {
             long tid = Thread.currentThread().getId();
             if (!placeholdersPerThread.containsKey(tid)) {
@@ -1033,10 +1033,10 @@ public class SameDiff extends SDBaseOps {
         switch (variable.getVariableType()) {
             case VARIABLE:
-                variablesArrays.put(variable.getVarName(), new DeviceLocalNDArray(arr));
+                variablesArrays.put(variable.getVarName(), new DeviceLocalNDArray(arr, true));  //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
                 break;
             case CONSTANT:
-                constantArrays.put(variable.getVarName(), new DeviceLocalNDArray(arr));
+                constantArrays.put(variable.getVarName(), new DeviceLocalNDArray(arr, true));
                 break;
             case ARRAY:
                 // FIXME: remove this before release
@@ -1077,6 +1077,29 @@ public class SameDiff extends SDBaseOps {
         }
     }
 
+    /**
+     * Update the constant or variable type SDVariable with the values from the specified
+     * array. Note that unlike {@link #associateArrayWithVariable(INDArray, String)} this method will take the
+     * values from the argument array and assign them to the current array.
+     * The actual array (INDArray object) will not be stored or otherwise used within the SameDiff instance.
+     *
+     * @param arr      Array values to set
+     * @param variable Variable to update the array of. Must be a CONSTANT or VARIABLE type SDVariable
+     */
+    public void assignArray(@NonNull INDArray arr, @NonNull SDVariable variable){
+        Preconditions.checkState(variable.getVariableType() == VariableType.VARIABLE || variable.getVariableType() == VariableType.CONSTANT,
+                "assignArray method can only be used with VARIABLE or CONSTANT type SDVariables, variable \"%s\" has type %s", variable.getVarName(), variable.getVariableType());
+
+        //DeviceLocal doesn't work with views
+        if(arr.isView())
+            arr = arr.dup();
+
+        if(variable.getVariableType() == VariableType.VARIABLE ){
+            variablesArrays.get(variable.getVarName()).update(arr);
+        } else {
+            constantArrays.get(variable.getVarName()).update(arr);
+        }
+    }
+
     /**
      * Associate a {@link SameDiff} namespace as a sub function.
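A short usage sketch for the `assignArray` method added above; the variable name and arrays are illustrative. Unlike `associateArrayWithVariable`, the passed INDArray is never stored: its values are copied into the variable's existing DeviceLocal array, with views `dup()`'d first:

```java
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;

public class AssignArraySketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable w = sd.var("w", Nd4j.zeros(2, 2));

        // A view into a larger buffer, as DL4J parameter arrays are:
        INDArray big = Nd4j.rand(4, 4);
        INDArray view = big.get(NDArrayIndex.interval(0, 2), NDArrayIndex.interval(0, 2));

        // Copies the view's values into w's existing array; 'view' itself is
        // not stored, so later changes to 'big' do not leak into the graph.
        sd.assignArray(view, w);
    }
}
```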
@@ -3256,7 +3279,7 @@ public class SameDiff extends SDBaseOps {
         SDVariable v = new SDVariable(name, VariableType.CONSTANT, this, constant.shape(), constant.dataType(), null);
         name = v.getVarName();
         variables.put(name, Variable.builder().name(name).variable(v).build());
-        constantArrays.put(name, new DeviceLocalNDArray(constant));
+        constantArrays.put(name, new DeviceLocalNDArray(constant, true));   //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
         return v;
     }
 
@@ -3630,7 +3653,7 @@ public class SameDiff extends SDBaseOps {
             INDArray arr = variable.getArr();
             Preconditions.checkNotNull(arr, "Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", variable);
 
-            constantArrays.put(n, new DeviceLocalNDArray(arr));
+            constantArrays.put(n, new DeviceLocalNDArray(arr, true));   //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
             variablesArrays.remove(n);
             if (!placeholdersPerThread.isEmpty()) {
                 for (Map<String,INDArray> m : placeholdersPerThread.values()) {
@@ -3728,7 +3751,7 @@ public class SameDiff extends SDBaseOps {
             INDArray arr = variable.getArr();
             Preconditions.checkNotNull(arr, "Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", variable);
 
-            variablesArrays.put(n, new DeviceLocalNDArray(arr));
+            variablesArrays.put(n, new DeviceLocalNDArray(arr, true));  //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
             constantArrays.remove(n);
             if (!placeholdersPerThread.isEmpty()) {
                 for (Map<String,INDArray> m : placeholdersPerThread.values()) {
@@ -3807,13 +3830,13 @@ public class SameDiff extends SDBaseOps {
                 case VARIABLE:
                     DeviceLocalNDArray dl = variablesArrays.remove(e.getKey());
                     INDArray arr = dl.get();
                     INDArray newArr = arr.castTo(d);
-                    variablesArrays.put(e.getKey(), new DeviceLocalNDArray(newArr));
+                    variablesArrays.put(e.getKey(), new DeviceLocalNDArray(newArr, true));  //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
                     break;
                 case CONSTANT:
                     DeviceLocalNDArray dl2 = constantArrays.remove(e.getKey());
                     INDArray arr2 = dl2.get();
                     INDArray newArr2 = arr2.castTo(d);
-                    constantArrays.put(e.getKey(), new DeviceLocalNDArray(newArr2));
+                    constantArrays.put(e.getKey(), new DeviceLocalNDArray(newArr2, true));  //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
                     break;
                 case PLACEHOLDER:
                     Map<String,INDArray> m = placeholdersPerThread.get(Thread.currentThread().getId());
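Every `DeviceLocalNDArray` in this patch is now constructed with a second argument of `true`. Assuming that flag enables the delayed initialization the inline comments describe (replicate to other devices lazily, on first access, rather than eagerly for every device at construction time), here is a minimal sketch of the behavior being relied on; the import path matches ND4J at the time of the patch but is stated as an assumption:

```java
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.DeviceLocalNDArray;

public class DeviceLocalSketch {
    public static void main(String[] args) {
        INDArray arr = Nd4j.linspace(1, 4, 4);

        // 'true' = delayed initialization: no per-device copies are made up
        // front, which avoids wasted replication when only one thread/device
        // ever touches the array.
        DeviceLocalNDArray dl = new DeviceLocalNDArray(arr, true);

        // update(...) writes new values into the existing per-device copies;
        // this is what SameDiff.assignArray uses instead of re-wrapping views.
        dl.update(Nd4j.linspace(5, 8, 4));
        System.out.println(dl.get());
    }
}
```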