Optimization / fix for DL4J SameDiff layers (#156)

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-08-23 20:54:24 +10:00 committed by GitHub
parent fb8de5006f
commit 70ee8ba91d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 100 additions and 55 deletions

View File

@ -99,25 +99,28 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
doInit();
}
Map<String,INDArray> phMap = new HashMap<>();
config.validateInput(inputs);
for(int i=0; i<inputs.length; i++ ){
String name = config.getVertexParams().getInputs().get(i);
final String maskName = name + "_mask";
sameDiff.associateArrayWithVariable(inputs[i].dup(), sameDiff.getVariable(name));
phMap.put(name, inputs[i]);
if(maskArrays != null && maskArrays[i] != null) {
sameDiff.associateArrayWithVariable(maskArrays[i].dup(), maskName);
phMap.put(maskName, maskArrays[i]);
}else{
sameDiff.associateArrayWithVariable(createMask(dataType, inputs[i].shape()), maskName);
phMap.put(maskName, createMask(dataType, inputs[i].shape()));
}
}
if(paramTable != null && paramTable.size() > 0) {
for (String s : paramTable.keySet()) {
sameDiff.associateArrayWithVariable(paramTable.get(s), s);
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
//TODO Find a more efficient solution for this
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
INDArray arr = e.getValue();
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
}
}
Map<String,INDArray> out = sameDiff.exec(null, outputKey);
INDArray result = out.get(outputKey);
INDArray result = sameDiff.outputSingle(phMap, outputKey);
//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
sameDiff.clearPlaceholders(true);
@ -136,10 +139,10 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
doInit();
}
List<String> inputNames = config.getVertexParams().getInputs();
if(!sameDiff.hasGradientFunction()) {
//Create when scoped out, to ensure any arrays are not in WS
List<String> inputs = config.getVertexParams().getInputs();
String[] inArr = inputs.toArray(new String[inputs.size()]);
String[] inArr = inputNames.toArray(new String[inputNames.size()]);
sameDiff.createGradFunction(inArr);
}
config.validateInput(inputs);
@ -149,25 +152,31 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
for(String s : inputs){
phMap.put(s, this.inputs[i++]);
}
if(maskArrays != null){
for( int j=0; j<maskArrays.length; j++ ){
String name = inputs.get(j);
final String maskName = name + "_mask";
if(maskArrays[j] != null) {
sameDiff.associateArrayWithVariable(maskArrays[j].dup(), maskName);
}
for( int j=0; j<this.inputs.length; j++ ){
String name = inputs.get(j);
final String maskName = name + "_mask";
if(maskArrays != null && maskArrays[j] != null) {
phMap.put(maskName, maskArrays[j]);
}else{
phMap.put(maskName, createMask(dataType, this.inputs[j].shape()));
}
}
String epsName = fn.getGradPlaceholderName();
phMap.put(epsName, epsilon);
for(String s : paramTable.keySet() ){
//TODO this should only be necessary, in theory, once!
sameDiff.associateArrayWithVariable(paramTable.get(s), s);
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
//TODO Find a more efficient solution for this
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
INDArray arr = e.getValue();
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
}
sameDiff.execBackwards(phMap);
List<String> required = new ArrayList<>(inputNames.size()); //Ensure that the input placeholder gradients are calculated
for(String s : inputNames){
required.add(sameDiff.getVariable(s).gradient().getVarName());
}
sameDiff.execBackwards(phMap, required);
for(String s : paramTable.keySet() ){
INDArray sdGrad = sameDiff.grad(s).getArr();
INDArray dl4jGrad = gradTable.get(s);
@ -176,9 +185,17 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
}
dLdIns = new INDArray[inputs.size()];
String fnName = fn.getGradPlaceholderName();
for(int j=0; j<inputs.size(); j++ ){
String name = inputs.get(j);
dLdIns[j] = sameDiff.grad(name).getArr();
String gradName = sameDiff.grad(inputNames.get(j)).getVarName();
if(dLdIns[j] == null && fnName.equals(gradName)){
//Edge case with lambda vertices like identity: SameDiff doesn't store the placeholders
// So, this getArr() can be trying to get placeholder from SameDiff instance, when it's available here
dLdIns[j] = epsilon;
}
}
}
@ -218,9 +235,13 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
int i=0;
for(String s : config.getVertexParams().getInputs()){
val inputShape = inputs[i++].shape().clone();
SDVariable inputVar = sameDiff.var(s, dataType, inputShape);
INDArray maskTemp = createMask(dataType, inputShape);
inputShape[0] = -1;
SDVariable inputVar = sameDiff.placeHolder(s, dataType, inputShape);
inputVars.put(s, inputVar);
SDVariable maskVar = sameDiff.constant(s + "_mask", createMask(dataType, inputShape));
long[] maskShape = maskTemp.shape().clone();
maskShape[0] = -1;
SDVariable maskVar = sameDiff.placeHolder(s + "_mask", maskTemp.dataType(), maskShape);
maskVars.put(s, maskVar);
}

View File

@ -96,11 +96,14 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
phMap.put(MASK_KEY, layerConf().onesMaskForInput(input));
}
for(String s : paramTable.keySet() ) {
sameDiff.associateArrayWithVariable(paramTable.get(s), s);
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
//TODO Find a more efficient solution for this
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
INDArray arr = e.getValue();
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
}
Map<String,INDArray> out = sameDiff.exec(phMap, outputKey);
Map<String,INDArray> out = sameDiff.output(phMap, outputKey);
INDArray result = out.get(outputKey);
//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
@ -131,9 +134,11 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf();
bl.validateInput(input);
for(String s : paramTable.keySet() ){
//TODO this should only be necessary, in theory, once!
sameDiff.associateArrayWithVariable(paramTable.get(s), s);
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
//TODO Find a more efficient solution for this
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
INDArray arr = e.getValue();
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
}
Map<String,INDArray> phMap = new HashMap<>();

View File

@ -99,8 +99,11 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
doInit();
}
for(String s : paramTable.keySet() ) {
sameDiff.associateArrayWithVariable(paramTable.get(s), s);
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
//TODO Find a more efficient solution for this
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
INDArray arr = e.getValue();
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
}
Map<String,INDArray> phMap = new HashMap<>();
@ -111,7 +114,7 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
String s = activations ? layerConf().activationsVertexName() : outputVar.getVarName();
INDArray out = sameDiff.execSingle(phMap, s);
INDArray out = sameDiff.outputSingle(phMap, s);
//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
sameDiff.clearPlaceholders(true);
@ -149,20 +152,11 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
sameDiff.createGradFunction(INPUT_KEY);
}
INDArray castInput = input.castTo(dataType);
if(castInput.isAttached())
castInput = castInput.dup();
sameDiff.associateArrayWithVariable(castInput, sameDiff.getVariable(INPUT_KEY));
if(layerConf().labelsRequired()) {
INDArray castLabels = labels.castTo(dataType);
if(castLabels.isAttached())
castLabels = castLabels.dup();
sameDiff.associateArrayWithVariable(castLabels, sameDiff.getVariable(LABELS_KEY));
}
for(String s : paramTable.keySet() ){
//TODO this should only be necessary, in theory, once!
sameDiff.associateArrayWithVariable(paramTable.get(s), s);
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
//TODO Find a more efficient solution for this
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
INDArray arr = e.getValue();
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
}
List<String> gradVarNames = new ArrayList<>();
@ -297,8 +291,10 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
Preconditions.checkNotNull(layerOutput, "Invalid output: layer output is null");
outputVar = layerOutput;
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
for (Map.Entry<String, INDArray> e : p.entrySet()) {
sameDiff.associateArrayWithVariable(e.getValue(), sameDiff.getVariable(e.getKey()));
INDArray arr = e.getValue();
sameDiff.associateArrayWithVariable(arr, sameDiff.getVariable(e.getKey()));
}
this.outputKey = layerOutput.getVarName();

View File

@ -807,9 +807,9 @@ public class SameDiff extends SDBaseOps {
SDVariable v = getVariable(varName);
if (v.isConstant()) {
constantArrays.put(varName, new DeviceLocalNDArray(arr));
constantArrays.put(varName, new DeviceLocalNDArray(arr, true));
} else if (v.getVariableType() == VariableType.VARIABLE) {
variablesArrays.put(varName, new DeviceLocalNDArray(arr));
variablesArrays.put(varName, new DeviceLocalNDArray(arr, true));
} else if (v.isPlaceHolder()) {
long tid = Thread.currentThread().getId();
if (!placeholdersPerThread.containsKey(tid)) {
@ -1033,10 +1033,10 @@ public class SameDiff extends SDBaseOps {
switch (variable.getVariableType()) {
case VARIABLE:
variablesArrays.put(variable.getVarName(), new DeviceLocalNDArray(arr));
variablesArrays.put(variable.getVarName(), new DeviceLocalNDArray(arr, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
break;
case CONSTANT:
constantArrays.put(variable.getVarName(), new DeviceLocalNDArray(arr));
constantArrays.put(variable.getVarName(), new DeviceLocalNDArray(arr, true));
break;
case ARRAY:
// FIXME: remove this before release
@ -1077,6 +1077,29 @@ public class SameDiff extends SDBaseOps {
}
}
/**
 * Update the array of a CONSTANT or VARIABLE type {@link SDVariable} with the values from the
 * specified array. Note that unlike {@link #associateArrayWithVariable(INDArray, String)} this
 * method copies the values from the argument array into the variable's existing device-local
 * array; the argument INDArray object itself will not be stored or otherwise referenced by the
 * SameDiff instance afterwards.
 *
 * @param arr      Array values to copy into the variable's array
 * @param variable Variable to update the array of. Must be a CONSTANT or VARIABLE type SDVariable
 * @throws IllegalStateException if the variable is not CONSTANT or VARIABLE type
 */
public void assignArray(@NonNull INDArray arr, @NonNull SDVariable variable){
    // Hoist the type lookup - it is needed for both the precondition and the branch below
    VariableType vt = variable.getVariableType();
    Preconditions.checkState(vt == VariableType.VARIABLE || vt == VariableType.CONSTANT,
            "assignArray method can only be used with VARIABLE or CONSTANT type SDVariables, variable \"%s\" has type %s", variable.getVarName(), vt);

    //DeviceLocal doesn't work with views, so materialize a contiguous copy before updating
    if(arr.isView())
        arr = arr.dup();

    if(vt == VariableType.VARIABLE){
        variablesArrays.get(variable.getVarName()).update(arr);
    } else {
        constantArrays.get(variable.getVarName()).update(arr);
    }
}
/**
* Associate a {@link SameDiff} namespace as a sub function.
@ -3256,7 +3279,7 @@ public class SameDiff extends SDBaseOps {
SDVariable v = new SDVariable(name, VariableType.CONSTANT, this, constant.shape(), constant.dataType(), null);
name = v.getVarName();
variables.put(name, Variable.builder().name(name).variable(v).build());
constantArrays.put(name, new DeviceLocalNDArray(constant));
constantArrays.put(name, new DeviceLocalNDArray(constant, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
return v;
}
@ -3630,7 +3653,7 @@ public class SameDiff extends SDBaseOps {
INDArray arr = variable.getArr();
Preconditions.checkNotNull(arr, "Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", variable);
constantArrays.put(n, new DeviceLocalNDArray(arr));
constantArrays.put(n, new DeviceLocalNDArray(arr, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
variablesArrays.remove(n);
if (!placeholdersPerThread.isEmpty()) {
for (Map<String, INDArray> m : placeholdersPerThread.values()) {
@ -3728,7 +3751,7 @@ public class SameDiff extends SDBaseOps {
INDArray arr = variable.getArr();
Preconditions.checkNotNull(arr, "Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", variable);
variablesArrays.put(n, new DeviceLocalNDArray(arr));
variablesArrays.put(n, new DeviceLocalNDArray(arr, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
constantArrays.remove(n);
if (!placeholdersPerThread.isEmpty()) {
for (Map<String, INDArray> m : placeholdersPerThread.values()) {
@ -3807,13 +3830,13 @@ public class SameDiff extends SDBaseOps {
DeviceLocalNDArray dl = variablesArrays.remove(e.getKey());
INDArray arr = dl.get();
INDArray newArr = arr.castTo(d);
variablesArrays.put(e.getKey(), new DeviceLocalNDArray(newArr));
variablesArrays.put(e.getKey(), new DeviceLocalNDArray(newArr, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
break;
case CONSTANT:
DeviceLocalNDArray dl2 = constantArrays.remove(e.getKey());
INDArray arr2 = dl2.get();
INDArray newArr2 = arr2.castTo(d);
constantArrays.put(e.getKey(), new DeviceLocalNDArray(newArr2));
constantArrays.put(e.getKey(), new DeviceLocalNDArray(newArr2, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
break;
case PLACEHOLDER:
Map<String, INDArray> m = placeholdersPerThread.get(Thread.currentThread().getId());