Optimization / fix for DL4J SameDiff layers (#156)

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-08-23 20:54:24 +10:00 committed by GitHub
parent fb8de5006f
commit 70ee8ba91d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 100 additions and 55 deletions

View File

@ -99,25 +99,28 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
doInit(); doInit();
} }
Map<String,INDArray> phMap = new HashMap<>();
config.validateInput(inputs); config.validateInput(inputs);
for(int i=0; i<inputs.length; i++ ){ for(int i=0; i<inputs.length; i++ ){
String name = config.getVertexParams().getInputs().get(i); String name = config.getVertexParams().getInputs().get(i);
final String maskName = name + "_mask"; final String maskName = name + "_mask";
sameDiff.associateArrayWithVariable(inputs[i].dup(), sameDiff.getVariable(name)); phMap.put(name, inputs[i]);
if(maskArrays != null && maskArrays[i] != null) { if(maskArrays != null && maskArrays[i] != null) {
sameDiff.associateArrayWithVariable(maskArrays[i].dup(), maskName); phMap.put(maskName, maskArrays[i]);
}else{ }else{
sameDiff.associateArrayWithVariable(createMask(dataType, inputs[i].shape()), maskName); phMap.put(maskName, createMask(dataType, inputs[i].shape()));
} }
} }
if(paramTable != null && paramTable.size() > 0) { if(paramTable != null && paramTable.size() > 0) {
for (String s : paramTable.keySet()) { //Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
sameDiff.associateArrayWithVariable(paramTable.get(s), s); //TODO Find a more efficient solution for this
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
INDArray arr = e.getValue();
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
} }
} }
Map<String,INDArray> out = sameDiff.exec(null, outputKey); INDArray result = sameDiff.outputSingle(phMap, outputKey);
INDArray result = out.get(outputKey);
//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere //Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
sameDiff.clearPlaceholders(true); sameDiff.clearPlaceholders(true);
@ -136,10 +139,10 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
doInit(); doInit();
} }
List<String> inputNames = config.getVertexParams().getInputs();
if(!sameDiff.hasGradientFunction()) { if(!sameDiff.hasGradientFunction()) {
//Create when scoped out, to ensure any arrays are not in WS //Create when scoped out, to ensure any arrays are not in WS
List<String> inputs = config.getVertexParams().getInputs(); String[] inArr = inputNames.toArray(new String[inputNames.size()]);
String[] inArr = inputs.toArray(new String[inputs.size()]);
sameDiff.createGradFunction(inArr); sameDiff.createGradFunction(inArr);
} }
config.validateInput(inputs); config.validateInput(inputs);
@ -149,25 +152,31 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
for(String s : inputs){ for(String s : inputs){
phMap.put(s, this.inputs[i++]); phMap.put(s, this.inputs[i++]);
} }
if(maskArrays != null){ for( int j=0; j<this.inputs.length; j++ ){
for( int j=0; j<maskArrays.length; j++ ){
String name = inputs.get(j); String name = inputs.get(j);
final String maskName = name + "_mask"; final String maskName = name + "_mask";
if(maskArrays[j] != null) { if(maskArrays != null && maskArrays[j] != null) {
sameDiff.associateArrayWithVariable(maskArrays[j].dup(), maskName); phMap.put(maskName, maskArrays[j]);
} }else{
phMap.put(maskName, createMask(dataType, this.inputs[j].shape()));
} }
} }
String epsName = fn.getGradPlaceholderName(); String epsName = fn.getGradPlaceholderName();
phMap.put(epsName, epsilon); phMap.put(epsName, epsilon);
for(String s : paramTable.keySet() ){ //Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
//TODO this should only be necessary, in theory, once! //TODO Find a more efficient solution for this
sameDiff.associateArrayWithVariable(paramTable.get(s), s); for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
INDArray arr = e.getValue();
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
} }
sameDiff.execBackwards(phMap); List<String> required = new ArrayList<>(inputNames.size()); //Ensure that the input placeholder gradients are calculated
for(String s : inputNames){
required.add(sameDiff.getVariable(s).gradient().getVarName());
}
sameDiff.execBackwards(phMap, required);
for(String s : paramTable.keySet() ){ for(String s : paramTable.keySet() ){
INDArray sdGrad = sameDiff.grad(s).getArr(); INDArray sdGrad = sameDiff.grad(s).getArr();
INDArray dl4jGrad = gradTable.get(s); INDArray dl4jGrad = gradTable.get(s);
@ -176,9 +185,17 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
} }
dLdIns = new INDArray[inputs.size()]; dLdIns = new INDArray[inputs.size()];
String fnName = fn.getGradPlaceholderName();
for(int j=0; j<inputs.size(); j++ ){ for(int j=0; j<inputs.size(); j++ ){
String name = inputs.get(j); String name = inputs.get(j);
dLdIns[j] = sameDiff.grad(name).getArr(); dLdIns[j] = sameDiff.grad(name).getArr();
String gradName = sameDiff.grad(inputNames.get(j)).getVarName();
if(dLdIns[j] == null && fnName.equals(gradName)){
//Edge case with lambda vertices like identity: SameDiff doesn't store the placeholders
// So, this getArr() can be trying to get placeholder from SameDiff instance, when it's available here
dLdIns[j] = epsilon;
}
} }
} }
@ -218,9 +235,13 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
int i=0; int i=0;
for(String s : config.getVertexParams().getInputs()){ for(String s : config.getVertexParams().getInputs()){
val inputShape = inputs[i++].shape().clone(); val inputShape = inputs[i++].shape().clone();
SDVariable inputVar = sameDiff.var(s, dataType, inputShape); INDArray maskTemp = createMask(dataType, inputShape);
inputShape[0] = -1;
SDVariable inputVar = sameDiff.placeHolder(s, dataType, inputShape);
inputVars.put(s, inputVar); inputVars.put(s, inputVar);
SDVariable maskVar = sameDiff.constant(s + "_mask", createMask(dataType, inputShape)); long[] maskShape = maskTemp.shape().clone();
maskShape[0] = -1;
SDVariable maskVar = sameDiff.placeHolder(s + "_mask", maskTemp.dataType(), maskShape);
maskVars.put(s, maskVar); maskVars.put(s, maskVar);
} }

View File

@ -96,11 +96,14 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
phMap.put(MASK_KEY, layerConf().onesMaskForInput(input)); phMap.put(MASK_KEY, layerConf().onesMaskForInput(input));
} }
for(String s : paramTable.keySet() ) { //Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
sameDiff.associateArrayWithVariable(paramTable.get(s), s); //TODO Find a more efficient solution for this
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
INDArray arr = e.getValue();
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
} }
Map<String,INDArray> out = sameDiff.exec(phMap, outputKey); Map<String,INDArray> out = sameDiff.output(phMap, outputKey);
INDArray result = out.get(outputKey); INDArray result = out.get(outputKey);
//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere //Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
@ -131,9 +134,11 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf(); org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf();
bl.validateInput(input); bl.validateInput(input);
for(String s : paramTable.keySet() ){ //Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
//TODO this should only be necessary, in theory, once! //TODO Find a more efficient solution for this
sameDiff.associateArrayWithVariable(paramTable.get(s), s); for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
INDArray arr = e.getValue();
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
} }
Map<String,INDArray> phMap = new HashMap<>(); Map<String,INDArray> phMap = new HashMap<>();

View File

@ -99,8 +99,11 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
doInit(); doInit();
} }
for(String s : paramTable.keySet() ) { //Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
sameDiff.associateArrayWithVariable(paramTable.get(s), s); //TODO Find a more efficient solution for this
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
INDArray arr = e.getValue();
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
} }
Map<String,INDArray> phMap = new HashMap<>(); Map<String,INDArray> phMap = new HashMap<>();
@ -111,7 +114,7 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
String s = activations ? layerConf().activationsVertexName() : outputVar.getVarName(); String s = activations ? layerConf().activationsVertexName() : outputVar.getVarName();
INDArray out = sameDiff.execSingle(phMap, s); INDArray out = sameDiff.outputSingle(phMap, s);
//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere //Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
sameDiff.clearPlaceholders(true); sameDiff.clearPlaceholders(true);
@ -149,20 +152,11 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
sameDiff.createGradFunction(INPUT_KEY); sameDiff.createGradFunction(INPUT_KEY);
} }
INDArray castInput = input.castTo(dataType); //Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
if(castInput.isAttached()) //TODO Find a more efficient solution for this
castInput = castInput.dup(); for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
sameDiff.associateArrayWithVariable(castInput, sameDiff.getVariable(INPUT_KEY)); INDArray arr = e.getValue();
if(layerConf().labelsRequired()) { sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
INDArray castLabels = labels.castTo(dataType);
if(castLabels.isAttached())
castLabels = castLabels.dup();
sameDiff.associateArrayWithVariable(castLabels, sameDiff.getVariable(LABELS_KEY));
}
for(String s : paramTable.keySet() ){
//TODO this should only be necessary, in theory, once!
sameDiff.associateArrayWithVariable(paramTable.get(s), s);
} }
List<String> gradVarNames = new ArrayList<>(); List<String> gradVarNames = new ArrayList<>();
@ -297,8 +291,10 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
Preconditions.checkNotNull(layerOutput, "Invalid output: layer output is null"); Preconditions.checkNotNull(layerOutput, "Invalid output: layer output is null");
outputVar = layerOutput; outputVar = layerOutput;
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
for (Map.Entry<String, INDArray> e : p.entrySet()) { for (Map.Entry<String, INDArray> e : p.entrySet()) {
sameDiff.associateArrayWithVariable(e.getValue(), sameDiff.getVariable(e.getKey())); INDArray arr = e.getValue();
sameDiff.associateArrayWithVariable(arr, sameDiff.getVariable(e.getKey()));
} }
this.outputKey = layerOutput.getVarName(); this.outputKey = layerOutput.getVarName();

View File

@ -807,9 +807,9 @@ public class SameDiff extends SDBaseOps {
SDVariable v = getVariable(varName); SDVariable v = getVariable(varName);
if (v.isConstant()) { if (v.isConstant()) {
constantArrays.put(varName, new DeviceLocalNDArray(arr)); constantArrays.put(varName, new DeviceLocalNDArray(arr, true));
} else if (v.getVariableType() == VariableType.VARIABLE) { } else if (v.getVariableType() == VariableType.VARIABLE) {
variablesArrays.put(varName, new DeviceLocalNDArray(arr)); variablesArrays.put(varName, new DeviceLocalNDArray(arr, true));
} else if (v.isPlaceHolder()) { } else if (v.isPlaceHolder()) {
long tid = Thread.currentThread().getId(); long tid = Thread.currentThread().getId();
if (!placeholdersPerThread.containsKey(tid)) { if (!placeholdersPerThread.containsKey(tid)) {
@ -1033,10 +1033,10 @@ public class SameDiff extends SDBaseOps {
switch (variable.getVariableType()) { switch (variable.getVariableType()) {
case VARIABLE: case VARIABLE:
variablesArrays.put(variable.getVarName(), new DeviceLocalNDArray(arr)); variablesArrays.put(variable.getVarName(), new DeviceLocalNDArray(arr, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
break; break;
case CONSTANT: case CONSTANT:
constantArrays.put(variable.getVarName(), new DeviceLocalNDArray(arr)); constantArrays.put(variable.getVarName(), new DeviceLocalNDArray(arr, true));
break; break;
case ARRAY: case ARRAY:
// FIXME: remove this before release // FIXME: remove this before release
@ -1077,6 +1077,29 @@ public class SameDiff extends SDBaseOps {
} }
} }
/**
 * Update the array of a constant or variable type SDVariable with the values from the specified array.
 * Note that unlike {@link #associateArrayWithVariable(INDArray, String)} this method will take the
 * values from the argument array and assign them to the variable's current array.
 * The actual array (INDArray object) will not be stored or otherwise used within the SameDiff instance.
 * If the argument array is a view, it is duplicated first, as DeviceLocal does not work with views.
 *
 * @param arr      Array values to set
 * @param variable Variable to update the array of. Must be a CONSTANT or VARIABLE type SDVariable
 * @throws IllegalStateException if the variable is not of type CONSTANT or VARIABLE, or if no array
 *                               is currently stored for the variable
 */
public void assignArray(@NonNull INDArray arr, @NonNull SDVariable variable){
    final VariableType type = variable.getVariableType();
    final String name = variable.getVarName();
    Preconditions.checkState(type == VariableType.VARIABLE || type == VariableType.CONSTANT,
            "assignArray method can only be used with VARIABLE or CONSTANT type SDVariables, variable \"%s\" has type %s", name, type);

    //DeviceLocal doesn't work with views
    final INDArray values = arr.isView() ? arr.dup() : arr;

    //Look up the existing DeviceLocal holder; there is nothing to update if no array was ever set
    final DeviceLocalNDArray current = (type == VariableType.VARIABLE)
            ? variablesArrays.get(name)
            : constantArrays.get(name);
    Preconditions.checkState(current != null,
            "Cannot assign array: no array is currently associated with %s \"%s\"", type, name);

    current.update(values);
}
/** /**
* Associate a {@link SameDiff} namespace as a sub function. * Associate a {@link SameDiff} namespace as a sub function.
@ -3256,7 +3279,7 @@ public class SameDiff extends SDBaseOps {
SDVariable v = new SDVariable(name, VariableType.CONSTANT, this, constant.shape(), constant.dataType(), null); SDVariable v = new SDVariable(name, VariableType.CONSTANT, this, constant.shape(), constant.dataType(), null);
name = v.getVarName(); name = v.getVarName();
variables.put(name, Variable.builder().name(name).variable(v).build()); variables.put(name, Variable.builder().name(name).variable(v).build());
constantArrays.put(name, new DeviceLocalNDArray(constant)); constantArrays.put(name, new DeviceLocalNDArray(constant, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
return v; return v;
} }
@ -3630,7 +3653,7 @@ public class SameDiff extends SDBaseOps {
INDArray arr = variable.getArr(); INDArray arr = variable.getArr();
Preconditions.checkNotNull(arr, "Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", variable); Preconditions.checkNotNull(arr, "Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", variable);
constantArrays.put(n, new DeviceLocalNDArray(arr)); constantArrays.put(n, new DeviceLocalNDArray(arr, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
variablesArrays.remove(n); variablesArrays.remove(n);
if (!placeholdersPerThread.isEmpty()) { if (!placeholdersPerThread.isEmpty()) {
for (Map<String, INDArray> m : placeholdersPerThread.values()) { for (Map<String, INDArray> m : placeholdersPerThread.values()) {
@ -3728,7 +3751,7 @@ public class SameDiff extends SDBaseOps {
INDArray arr = variable.getArr(); INDArray arr = variable.getArr();
Preconditions.checkNotNull(arr, "Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", variable); Preconditions.checkNotNull(arr, "Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", variable);
variablesArrays.put(n, new DeviceLocalNDArray(arr)); variablesArrays.put(n, new DeviceLocalNDArray(arr, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
constantArrays.remove(n); constantArrays.remove(n);
if (!placeholdersPerThread.isEmpty()) { if (!placeholdersPerThread.isEmpty()) {
for (Map<String, INDArray> m : placeholdersPerThread.values()) { for (Map<String, INDArray> m : placeholdersPerThread.values()) {
@ -3807,13 +3830,13 @@ public class SameDiff extends SDBaseOps {
DeviceLocalNDArray dl = variablesArrays.remove(e.getKey()); DeviceLocalNDArray dl = variablesArrays.remove(e.getKey());
INDArray arr = dl.get(); INDArray arr = dl.get();
INDArray newArr = arr.castTo(d); INDArray newArr = arr.castTo(d);
variablesArrays.put(e.getKey(), new DeviceLocalNDArray(newArr)); variablesArrays.put(e.getKey(), new DeviceLocalNDArray(newArr, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
break; break;
case CONSTANT: case CONSTANT:
DeviceLocalNDArray dl2 = constantArrays.remove(e.getKey()); DeviceLocalNDArray dl2 = constantArrays.remove(e.getKey());
INDArray arr2 = dl2.get(); INDArray arr2 = dl2.get();
INDArray newArr2 = arr2.castTo(d); INDArray newArr2 = arr2.castTo(d);
constantArrays.put(e.getKey(), new DeviceLocalNDArray(newArr2)); constantArrays.put(e.getKey(), new DeviceLocalNDArray(newArr2, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
break; break;
case PLACEHOLDER: case PLACEHOLDER:
Map<String, INDArray> m = placeholdersPerThread.get(Thread.currentThread().getId()); Map<String, INDArray> m = placeholdersPerThread.get(Thread.currentThread().getId());