Optimization / fix for DL4J SameDiff layers (#156)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
parent
fb8de5006f
commit
70ee8ba91d
|
@ -99,25 +99,28 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
|
|||
doInit();
|
||||
}
|
||||
|
||||
Map<String,INDArray> phMap = new HashMap<>();
|
||||
config.validateInput(inputs);
|
||||
for(int i=0; i<inputs.length; i++ ){
|
||||
String name = config.getVertexParams().getInputs().get(i);
|
||||
final String maskName = name + "_mask";
|
||||
sameDiff.associateArrayWithVariable(inputs[i].dup(), sameDiff.getVariable(name));
|
||||
phMap.put(name, inputs[i]);
|
||||
if(maskArrays != null && maskArrays[i] != null) {
|
||||
sameDiff.associateArrayWithVariable(maskArrays[i].dup(), maskName);
|
||||
phMap.put(maskName, maskArrays[i]);
|
||||
}else{
|
||||
sameDiff.associateArrayWithVariable(createMask(dataType, inputs[i].shape()), maskName);
|
||||
phMap.put(maskName, createMask(dataType, inputs[i].shape()));
|
||||
}
|
||||
}
|
||||
|
||||
if(paramTable != null && paramTable.size() > 0) {
|
||||
for (String s : paramTable.keySet()) {
|
||||
sameDiff.associateArrayWithVariable(paramTable.get(s), s);
|
||||
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
|
||||
//TODO Find a more efficient solution for this
|
||||
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
|
||||
INDArray arr = e.getValue();
|
||||
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
|
||||
}
|
||||
}
|
||||
Map<String,INDArray> out = sameDiff.exec(null, outputKey);
|
||||
INDArray result = out.get(outputKey);
|
||||
INDArray result = sameDiff.outputSingle(phMap, outputKey);
|
||||
|
||||
//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
|
||||
sameDiff.clearPlaceholders(true);
|
||||
|
@ -136,10 +139,10 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
|
|||
doInit();
|
||||
}
|
||||
|
||||
List<String> inputNames = config.getVertexParams().getInputs();
|
||||
if(!sameDiff.hasGradientFunction()) {
|
||||
//Create when scoped out, to ensure any arrays are not in WS
|
||||
List<String> inputs = config.getVertexParams().getInputs();
|
||||
String[] inArr = inputs.toArray(new String[inputs.size()]);
|
||||
String[] inArr = inputNames.toArray(new String[inputNames.size()]);
|
||||
sameDiff.createGradFunction(inArr);
|
||||
}
|
||||
config.validateInput(inputs);
|
||||
|
@ -149,25 +152,31 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
|
|||
for(String s : inputs){
|
||||
phMap.put(s, this.inputs[i++]);
|
||||
}
|
||||
if(maskArrays != null){
|
||||
for( int j=0; j<maskArrays.length; j++ ){
|
||||
String name = inputs.get(j);
|
||||
final String maskName = name + "_mask";
|
||||
if(maskArrays[j] != null) {
|
||||
sameDiff.associateArrayWithVariable(maskArrays[j].dup(), maskName);
|
||||
}
|
||||
for( int j=0; j<this.inputs.length; j++ ){
|
||||
String name = inputs.get(j);
|
||||
final String maskName = name + "_mask";
|
||||
if(maskArrays != null && maskArrays[j] != null) {
|
||||
phMap.put(maskName, maskArrays[j]);
|
||||
}else{
|
||||
phMap.put(maskName, createMask(dataType, this.inputs[j].shape()));
|
||||
}
|
||||
}
|
||||
String epsName = fn.getGradPlaceholderName();
|
||||
phMap.put(epsName, epsilon);
|
||||
|
||||
|
||||
for(String s : paramTable.keySet() ){
|
||||
//TODO this should only be necessary, in theory, once!
|
||||
sameDiff.associateArrayWithVariable(paramTable.get(s), s);
|
||||
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
|
||||
//TODO Find a more efficient solution for this
|
||||
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
|
||||
INDArray arr = e.getValue();
|
||||
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
|
||||
}
|
||||
|
||||
sameDiff.execBackwards(phMap);
|
||||
List<String> required = new ArrayList<>(inputNames.size()); //Ensure that the input placeholder gradients are calculated
|
||||
for(String s : inputNames){
|
||||
required.add(sameDiff.getVariable(s).gradient().getVarName());
|
||||
}
|
||||
sameDiff.execBackwards(phMap, required);
|
||||
for(String s : paramTable.keySet() ){
|
||||
INDArray sdGrad = sameDiff.grad(s).getArr();
|
||||
INDArray dl4jGrad = gradTable.get(s);
|
||||
|
@ -176,9 +185,17 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
|
|||
}
|
||||
|
||||
dLdIns = new INDArray[inputs.size()];
|
||||
String fnName = fn.getGradPlaceholderName();
|
||||
for(int j=0; j<inputs.size(); j++ ){
|
||||
String name = inputs.get(j);
|
||||
dLdIns[j] = sameDiff.grad(name).getArr();
|
||||
|
||||
String gradName = sameDiff.grad(inputNames.get(j)).getVarName();
|
||||
if(dLdIns[j] == null && fnName.equals(gradName)){
|
||||
//Edge case with lambda vertices like identity: SameDiff doesn't store the placeholders
|
||||
// So this getArr() call may be trying to fetch the placeholder array from the SameDiff instance, even though it is already available here (as epsilon)
|
||||
dLdIns[j] = epsilon;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -218,9 +235,13 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
|
|||
int i=0;
|
||||
for(String s : config.getVertexParams().getInputs()){
|
||||
val inputShape = inputs[i++].shape().clone();
|
||||
SDVariable inputVar = sameDiff.var(s, dataType, inputShape);
|
||||
INDArray maskTemp = createMask(dataType, inputShape);
|
||||
inputShape[0] = -1;
|
||||
SDVariable inputVar = sameDiff.placeHolder(s, dataType, inputShape);
|
||||
inputVars.put(s, inputVar);
|
||||
SDVariable maskVar = sameDiff.constant(s + "_mask", createMask(dataType, inputShape));
|
||||
long[] maskShape = maskTemp.shape().clone();
|
||||
maskShape[0] = -1;
|
||||
SDVariable maskVar = sameDiff.placeHolder(s + "_mask", maskTemp.dataType(), maskShape);
|
||||
maskVars.put(s, maskVar);
|
||||
}
|
||||
|
||||
|
|
|
@ -96,11 +96,14 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
|
|||
phMap.put(MASK_KEY, layerConf().onesMaskForInput(input));
|
||||
}
|
||||
|
||||
for(String s : paramTable.keySet() ) {
|
||||
sameDiff.associateArrayWithVariable(paramTable.get(s), s);
|
||||
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
|
||||
//TODO Find a more efficient solution for this
|
||||
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
|
||||
INDArray arr = e.getValue();
|
||||
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
|
||||
}
|
||||
|
||||
Map<String,INDArray> out = sameDiff.exec(phMap, outputKey);
|
||||
Map<String,INDArray> out = sameDiff.output(phMap, outputKey);
|
||||
INDArray result = out.get(outputKey);
|
||||
|
||||
//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
|
||||
|
@ -131,9 +134,11 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
|
|||
org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf();
|
||||
bl.validateInput(input);
|
||||
|
||||
for(String s : paramTable.keySet() ){
|
||||
//TODO this should only be necessary, in theory, once!
|
||||
sameDiff.associateArrayWithVariable(paramTable.get(s), s);
|
||||
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
|
||||
//TODO Find a more efficient solution for this
|
||||
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
|
||||
INDArray arr = e.getValue();
|
||||
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
|
||||
}
|
||||
|
||||
Map<String,INDArray> phMap = new HashMap<>();
|
||||
|
|
|
@ -99,8 +99,11 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
|
|||
doInit();
|
||||
}
|
||||
|
||||
for(String s : paramTable.keySet() ) {
|
||||
sameDiff.associateArrayWithVariable(paramTable.get(s), s);
|
||||
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
|
||||
//TODO Find a more efficient solution for this
|
||||
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
|
||||
INDArray arr = e.getValue();
|
||||
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
|
||||
}
|
||||
|
||||
Map<String,INDArray> phMap = new HashMap<>();
|
||||
|
@ -111,7 +114,7 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
|
|||
|
||||
String s = activations ? layerConf().activationsVertexName() : outputVar.getVarName();
|
||||
|
||||
INDArray out = sameDiff.execSingle(phMap, s);
|
||||
INDArray out = sameDiff.outputSingle(phMap, s);
|
||||
|
||||
//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
|
||||
sameDiff.clearPlaceholders(true);
|
||||
|
@ -149,20 +152,11 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
|
|||
sameDiff.createGradFunction(INPUT_KEY);
|
||||
}
|
||||
|
||||
INDArray castInput = input.castTo(dataType);
|
||||
if(castInput.isAttached())
|
||||
castInput = castInput.dup();
|
||||
sameDiff.associateArrayWithVariable(castInput, sameDiff.getVariable(INPUT_KEY));
|
||||
if(layerConf().labelsRequired()) {
|
||||
INDArray castLabels = labels.castTo(dataType);
|
||||
if(castLabels.isAttached())
|
||||
castLabels = castLabels.dup();
|
||||
sameDiff.associateArrayWithVariable(castLabels, sameDiff.getVariable(LABELS_KEY));
|
||||
}
|
||||
|
||||
for(String s : paramTable.keySet() ){
|
||||
//TODO this should only be necessary, in theory, once!
|
||||
sameDiff.associateArrayWithVariable(paramTable.get(s), s);
|
||||
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
|
||||
//TODO Find a more efficient solution for this
|
||||
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
|
||||
INDArray arr = e.getValue();
|
||||
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
|
||||
}
|
||||
|
||||
List<String> gradVarNames = new ArrayList<>();
|
||||
|
@ -297,8 +291,10 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
|
|||
Preconditions.checkNotNull(layerOutput, "Invalid output: layer output is null");
|
||||
outputVar = layerOutput;
|
||||
|
||||
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
|
||||
for (Map.Entry<String, INDArray> e : p.entrySet()) {
|
||||
sameDiff.associateArrayWithVariable(e.getValue(), sameDiff.getVariable(e.getKey()));
|
||||
INDArray arr = e.getValue();
|
||||
sameDiff.associateArrayWithVariable(arr, sameDiff.getVariable(e.getKey()));
|
||||
}
|
||||
|
||||
this.outputKey = layerOutput.getVarName();
|
||||
|
|
|
@ -807,9 +807,9 @@ public class SameDiff extends SDBaseOps {
|
|||
|
||||
SDVariable v = getVariable(varName);
|
||||
if (v.isConstant()) {
|
||||
constantArrays.put(varName, new DeviceLocalNDArray(arr));
|
||||
constantArrays.put(varName, new DeviceLocalNDArray(arr, true));
|
||||
} else if (v.getVariableType() == VariableType.VARIABLE) {
|
||||
variablesArrays.put(varName, new DeviceLocalNDArray(arr));
|
||||
variablesArrays.put(varName, new DeviceLocalNDArray(arr, true));
|
||||
} else if (v.isPlaceHolder()) {
|
||||
long tid = Thread.currentThread().getId();
|
||||
if (!placeholdersPerThread.containsKey(tid)) {
|
||||
|
@ -1033,10 +1033,10 @@ public class SameDiff extends SDBaseOps {
|
|||
|
||||
switch (variable.getVariableType()) {
|
||||
case VARIABLE:
|
||||
variablesArrays.put(variable.getVarName(), new DeviceLocalNDArray(arr));
|
||||
variablesArrays.put(variable.getVarName(), new DeviceLocalNDArray(arr, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
|
||||
break;
|
||||
case CONSTANT:
|
||||
constantArrays.put(variable.getVarName(), new DeviceLocalNDArray(arr));
|
||||
constantArrays.put(variable.getVarName(), new DeviceLocalNDArray(arr, true));
|
||||
break;
|
||||
case ARRAY:
|
||||
// FIXME: remove this before release
|
||||
|
@ -1077,6 +1077,29 @@ public class SameDiff extends SDBaseOps {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Update the constant or variable type SDVariable with the values from the specified
|
||||
* array. Note that unlike {@link #associateArrayWithVariable(INDArray, String)} this method will take the
|
||||
* values from the argument array and assign them to the current array.
|
||||
* The actual array (INDArray object) will not be stored or otherwise used within the SameDiff instance.
|
||||
* @param arr Array values to set. If this array is a view, a copy (dup) of it is used instead
|
||||
* @param variable Variable to update the array of. Must be CONSTANT or VARIABLE type SDVariable
|
||||
*/
|
||||
public void assignArray(@NonNull INDArray arr, @NonNull SDVariable variable){
|
||||
Preconditions.checkState(variable.getVariableType() == VariableType.VARIABLE || variable.getVariableType() == VariableType.CONSTANT,
|
||||
"assignArray method can only be used with VARIBLE or CONSTANT type SDVariables, variable \"%s\" has type %s", variable.getVarName(), variable.getVariableType());
|
||||
|
||||
//DeviceLocal doesn't work with views, so a view is copied into a standalone array before the update
|
||||
if(arr.isView())
|
||||
arr = arr.dup();
|
||||
|
||||
if(variable.getVariableType() == VariableType.VARIABLE ){
|
||||
variablesArrays.get(variable.getVarName()).update(arr);
|
||||
} else {
|
||||
constantArrays.get(variable.getVarName()).update(arr);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Associate a {@link SameDiff} namespace as a sub function.
|
||||
|
@ -3256,7 +3279,7 @@ public class SameDiff extends SDBaseOps {
|
|||
SDVariable v = new SDVariable(name, VariableType.CONSTANT, this, constant.shape(), constant.dataType(), null);
|
||||
name = v.getVarName();
|
||||
variables.put(name, Variable.builder().name(name).variable(v).build());
|
||||
constantArrays.put(name, new DeviceLocalNDArray(constant));
|
||||
constantArrays.put(name, new DeviceLocalNDArray(constant, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
|
||||
return v;
|
||||
}
|
||||
|
||||
|
@ -3630,7 +3653,7 @@ public class SameDiff extends SDBaseOps {
|
|||
INDArray arr = variable.getArr();
|
||||
Preconditions.checkNotNull(arr, "Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", variable);
|
||||
|
||||
constantArrays.put(n, new DeviceLocalNDArray(arr));
|
||||
constantArrays.put(n, new DeviceLocalNDArray(arr, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
|
||||
variablesArrays.remove(n);
|
||||
if (!placeholdersPerThread.isEmpty()) {
|
||||
for (Map<String, INDArray> m : placeholdersPerThread.values()) {
|
||||
|
@ -3728,7 +3751,7 @@ public class SameDiff extends SDBaseOps {
|
|||
INDArray arr = variable.getArr();
|
||||
Preconditions.checkNotNull(arr, "Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", variable);
|
||||
|
||||
variablesArrays.put(n, new DeviceLocalNDArray(arr));
|
||||
variablesArrays.put(n, new DeviceLocalNDArray(arr, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
|
||||
constantArrays.remove(n);
|
||||
if (!placeholdersPerThread.isEmpty()) {
|
||||
for (Map<String, INDArray> m : placeholdersPerThread.values()) {
|
||||
|
@ -3807,13 +3830,13 @@ public class SameDiff extends SDBaseOps {
|
|||
DeviceLocalNDArray dl = variablesArrays.remove(e.getKey());
|
||||
INDArray arr = dl.get();
|
||||
INDArray newArr = arr.castTo(d);
|
||||
variablesArrays.put(e.getKey(), new DeviceLocalNDArray(newArr));
|
||||
variablesArrays.put(e.getKey(), new DeviceLocalNDArray(newArr, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
|
||||
break;
|
||||
case CONSTANT:
|
||||
DeviceLocalNDArray dl2 = constantArrays.remove(e.getKey());
|
||||
INDArray arr2 = dl2.get();
|
||||
INDArray newArr2 = arr2.castTo(d);
|
||||
constantArrays.put(e.getKey(), new DeviceLocalNDArray(newArr2));
|
||||
constantArrays.put(e.getKey(), new DeviceLocalNDArray(newArr2, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
|
||||
break;
|
||||
case PLACEHOLDER:
|
||||
Map<String, INDArray> m = placeholdersPerThread.get(Thread.currentThread().getId());
|
||||
|
|
Loading…
Reference in New Issue