Optimization / fix for DL4J SameDiff layers (#156)
Signed-off-by: AlexDBlack <blacka101@gmail.com>master
parent
fb8de5006f
commit
70ee8ba91d
|
@ -99,25 +99,28 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
|
||||||
doInit();
|
doInit();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Map<String,INDArray> phMap = new HashMap<>();
|
||||||
config.validateInput(inputs);
|
config.validateInput(inputs);
|
||||||
for(int i=0; i<inputs.length; i++ ){
|
for(int i=0; i<inputs.length; i++ ){
|
||||||
String name = config.getVertexParams().getInputs().get(i);
|
String name = config.getVertexParams().getInputs().get(i);
|
||||||
final String maskName = name + "_mask";
|
final String maskName = name + "_mask";
|
||||||
sameDiff.associateArrayWithVariable(inputs[i].dup(), sameDiff.getVariable(name));
|
phMap.put(name, inputs[i]);
|
||||||
if(maskArrays != null && maskArrays[i] != null) {
|
if(maskArrays != null && maskArrays[i] != null) {
|
||||||
sameDiff.associateArrayWithVariable(maskArrays[i].dup(), maskName);
|
phMap.put(maskName, maskArrays[i]);
|
||||||
}else{
|
}else{
|
||||||
sameDiff.associateArrayWithVariable(createMask(dataType, inputs[i].shape()), maskName);
|
phMap.put(maskName, createMask(dataType, inputs[i].shape()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if(paramTable != null && paramTable.size() > 0) {
|
if(paramTable != null && paramTable.size() > 0) {
|
||||||
for (String s : paramTable.keySet()) {
|
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
|
||||||
sameDiff.associateArrayWithVariable(paramTable.get(s), s);
|
//TODO Find a more efficient solution for this
|
||||||
|
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
|
||||||
|
INDArray arr = e.getValue();
|
||||||
|
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Map<String,INDArray> out = sameDiff.exec(null, outputKey);
|
INDArray result = sameDiff.outputSingle(phMap, outputKey);
|
||||||
INDArray result = out.get(outputKey);
|
|
||||||
|
|
||||||
//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
|
//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
|
||||||
sameDiff.clearPlaceholders(true);
|
sameDiff.clearPlaceholders(true);
|
||||||
|
@ -136,10 +139,10 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
|
||||||
doInit();
|
doInit();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
List<String> inputNames = config.getVertexParams().getInputs();
|
||||||
if(!sameDiff.hasGradientFunction()) {
|
if(!sameDiff.hasGradientFunction()) {
|
||||||
//Create when scoped out, to ensure any arrays are not in WS
|
//Create when scoped out, to ensure any arrays are not in WS
|
||||||
List<String> inputs = config.getVertexParams().getInputs();
|
String[] inArr = inputNames.toArray(new String[inputNames.size()]);
|
||||||
String[] inArr = inputs.toArray(new String[inputs.size()]);
|
|
||||||
sameDiff.createGradFunction(inArr);
|
sameDiff.createGradFunction(inArr);
|
||||||
}
|
}
|
||||||
config.validateInput(inputs);
|
config.validateInput(inputs);
|
||||||
|
@ -149,25 +152,31 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
|
||||||
for(String s : inputs){
|
for(String s : inputs){
|
||||||
phMap.put(s, this.inputs[i++]);
|
phMap.put(s, this.inputs[i++]);
|
||||||
}
|
}
|
||||||
if(maskArrays != null){
|
for( int j=0; j<this.inputs.length; j++ ){
|
||||||
for( int j=0; j<maskArrays.length; j++ ){
|
String name = inputs.get(j);
|
||||||
String name = inputs.get(j);
|
final String maskName = name + "_mask";
|
||||||
final String maskName = name + "_mask";
|
if(maskArrays != null && maskArrays[j] != null) {
|
||||||
if(maskArrays[j] != null) {
|
phMap.put(maskName, maskArrays[j]);
|
||||||
sameDiff.associateArrayWithVariable(maskArrays[j].dup(), maskName);
|
}else{
|
||||||
}
|
phMap.put(maskName, createMask(dataType, this.inputs[j].shape()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
String epsName = fn.getGradPlaceholderName();
|
String epsName = fn.getGradPlaceholderName();
|
||||||
phMap.put(epsName, epsilon);
|
phMap.put(epsName, epsilon);
|
||||||
|
|
||||||
|
|
||||||
for(String s : paramTable.keySet() ){
|
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
|
||||||
//TODO this should only be necessary, in theory, once!
|
//TODO Find a more efficient solution for this
|
||||||
sameDiff.associateArrayWithVariable(paramTable.get(s), s);
|
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
|
||||||
|
INDArray arr = e.getValue();
|
||||||
|
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
|
||||||
}
|
}
|
||||||
|
|
||||||
sameDiff.execBackwards(phMap);
|
List<String> required = new ArrayList<>(inputNames.size()); //Ensure that the input placeholder gradients are calculated
|
||||||
|
for(String s : inputNames){
|
||||||
|
required.add(sameDiff.getVariable(s).gradient().getVarName());
|
||||||
|
}
|
||||||
|
sameDiff.execBackwards(phMap, required);
|
||||||
for(String s : paramTable.keySet() ){
|
for(String s : paramTable.keySet() ){
|
||||||
INDArray sdGrad = sameDiff.grad(s).getArr();
|
INDArray sdGrad = sameDiff.grad(s).getArr();
|
||||||
INDArray dl4jGrad = gradTable.get(s);
|
INDArray dl4jGrad = gradTable.get(s);
|
||||||
|
@ -176,9 +185,17 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
|
||||||
}
|
}
|
||||||
|
|
||||||
dLdIns = new INDArray[inputs.size()];
|
dLdIns = new INDArray[inputs.size()];
|
||||||
|
String fnName = fn.getGradPlaceholderName();
|
||||||
for(int j=0; j<inputs.size(); j++ ){
|
for(int j=0; j<inputs.size(); j++ ){
|
||||||
String name = inputs.get(j);
|
String name = inputs.get(j);
|
||||||
dLdIns[j] = sameDiff.grad(name).getArr();
|
dLdIns[j] = sameDiff.grad(name).getArr();
|
||||||
|
|
||||||
|
String gradName = sameDiff.grad(inputNames.get(j)).getVarName();
|
||||||
|
if(dLdIns[j] == null && fnName.equals(gradName)){
|
||||||
|
//Edge case with lambda vertices like identity: SameDiff doesn't store the placeholders
|
||||||
|
// So, this getArr() can be trying to get placeholder from SameDiff instance, when it's available here
|
||||||
|
dLdIns[j] = epsilon;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -218,9 +235,13 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
|
||||||
int i=0;
|
int i=0;
|
||||||
for(String s : config.getVertexParams().getInputs()){
|
for(String s : config.getVertexParams().getInputs()){
|
||||||
val inputShape = inputs[i++].shape().clone();
|
val inputShape = inputs[i++].shape().clone();
|
||||||
SDVariable inputVar = sameDiff.var(s, dataType, inputShape);
|
INDArray maskTemp = createMask(dataType, inputShape);
|
||||||
|
inputShape[0] = -1;
|
||||||
|
SDVariable inputVar = sameDiff.placeHolder(s, dataType, inputShape);
|
||||||
inputVars.put(s, inputVar);
|
inputVars.put(s, inputVar);
|
||||||
SDVariable maskVar = sameDiff.constant(s + "_mask", createMask(dataType, inputShape));
|
long[] maskShape = maskTemp.shape().clone();
|
||||||
|
maskShape[0] = -1;
|
||||||
|
SDVariable maskVar = sameDiff.placeHolder(s + "_mask", maskTemp.dataType(), maskShape);
|
||||||
maskVars.put(s, maskVar);
|
maskVars.put(s, maskVar);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -96,11 +96,14 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
|
||||||
phMap.put(MASK_KEY, layerConf().onesMaskForInput(input));
|
phMap.put(MASK_KEY, layerConf().onesMaskForInput(input));
|
||||||
}
|
}
|
||||||
|
|
||||||
for(String s : paramTable.keySet() ) {
|
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
|
||||||
sameDiff.associateArrayWithVariable(paramTable.get(s), s);
|
//TODO Find a more efficient solution for this
|
||||||
|
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
|
||||||
|
INDArray arr = e.getValue();
|
||||||
|
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
|
||||||
}
|
}
|
||||||
|
|
||||||
Map<String,INDArray> out = sameDiff.exec(phMap, outputKey);
|
Map<String,INDArray> out = sameDiff.output(phMap, outputKey);
|
||||||
INDArray result = out.get(outputKey);
|
INDArray result = out.get(outputKey);
|
||||||
|
|
||||||
//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
|
//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
|
||||||
|
@ -131,9 +134,11 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
|
||||||
org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf();
|
org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf();
|
||||||
bl.validateInput(input);
|
bl.validateInput(input);
|
||||||
|
|
||||||
for(String s : paramTable.keySet() ){
|
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
|
||||||
//TODO this should only be necessary, in theory, once!
|
//TODO Find a more efficient solution for this
|
||||||
sameDiff.associateArrayWithVariable(paramTable.get(s), s);
|
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
|
||||||
|
INDArray arr = e.getValue();
|
||||||
|
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
|
||||||
}
|
}
|
||||||
|
|
||||||
Map<String,INDArray> phMap = new HashMap<>();
|
Map<String,INDArray> phMap = new HashMap<>();
|
||||||
|
|
|
@ -99,8 +99,11 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
|
||||||
doInit();
|
doInit();
|
||||||
}
|
}
|
||||||
|
|
||||||
for(String s : paramTable.keySet() ) {
|
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
|
||||||
sameDiff.associateArrayWithVariable(paramTable.get(s), s);
|
//TODO Find a more efficient solution for this
|
||||||
|
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
|
||||||
|
INDArray arr = e.getValue();
|
||||||
|
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
|
||||||
}
|
}
|
||||||
|
|
||||||
Map<String,INDArray> phMap = new HashMap<>();
|
Map<String,INDArray> phMap = new HashMap<>();
|
||||||
|
@ -111,7 +114,7 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
|
||||||
|
|
||||||
String s = activations ? layerConf().activationsVertexName() : outputVar.getVarName();
|
String s = activations ? layerConf().activationsVertexName() : outputVar.getVarName();
|
||||||
|
|
||||||
INDArray out = sameDiff.execSingle(phMap, s);
|
INDArray out = sameDiff.outputSingle(phMap, s);
|
||||||
|
|
||||||
//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
|
//Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
|
||||||
sameDiff.clearPlaceholders(true);
|
sameDiff.clearPlaceholders(true);
|
||||||
|
@ -149,20 +152,11 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
|
||||||
sameDiff.createGradFunction(INPUT_KEY);
|
sameDiff.createGradFunction(INPUT_KEY);
|
||||||
}
|
}
|
||||||
|
|
||||||
INDArray castInput = input.castTo(dataType);
|
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
|
||||||
if(castInput.isAttached())
|
//TODO Find a more efficient solution for this
|
||||||
castInput = castInput.dup();
|
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
|
||||||
sameDiff.associateArrayWithVariable(castInput, sameDiff.getVariable(INPUT_KEY));
|
INDArray arr = e.getValue();
|
||||||
if(layerConf().labelsRequired()) {
|
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
|
||||||
INDArray castLabels = labels.castTo(dataType);
|
|
||||||
if(castLabels.isAttached())
|
|
||||||
castLabels = castLabels.dup();
|
|
||||||
sameDiff.associateArrayWithVariable(castLabels, sameDiff.getVariable(LABELS_KEY));
|
|
||||||
}
|
|
||||||
|
|
||||||
for(String s : paramTable.keySet() ){
|
|
||||||
//TODO this should only be necessary, in theory, once!
|
|
||||||
sameDiff.associateArrayWithVariable(paramTable.get(s), s);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
List<String> gradVarNames = new ArrayList<>();
|
List<String> gradVarNames = new ArrayList<>();
|
||||||
|
@ -297,8 +291,10 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
|
||||||
Preconditions.checkNotNull(layerOutput, "Invalid output: layer output is null");
|
Preconditions.checkNotNull(layerOutput, "Invalid output: layer output is null");
|
||||||
outputVar = layerOutput;
|
outputVar = layerOutput;
|
||||||
|
|
||||||
|
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
|
||||||
for (Map.Entry<String, INDArray> e : p.entrySet()) {
|
for (Map.Entry<String, INDArray> e : p.entrySet()) {
|
||||||
sameDiff.associateArrayWithVariable(e.getValue(), sameDiff.getVariable(e.getKey()));
|
INDArray arr = e.getValue();
|
||||||
|
sameDiff.associateArrayWithVariable(arr, sameDiff.getVariable(e.getKey()));
|
||||||
}
|
}
|
||||||
|
|
||||||
this.outputKey = layerOutput.getVarName();
|
this.outputKey = layerOutput.getVarName();
|
||||||
|
|
|
@ -807,9 +807,9 @@ public class SameDiff extends SDBaseOps {
|
||||||
|
|
||||||
SDVariable v = getVariable(varName);
|
SDVariable v = getVariable(varName);
|
||||||
if (v.isConstant()) {
|
if (v.isConstant()) {
|
||||||
constantArrays.put(varName, new DeviceLocalNDArray(arr));
|
constantArrays.put(varName, new DeviceLocalNDArray(arr, true));
|
||||||
} else if (v.getVariableType() == VariableType.VARIABLE) {
|
} else if (v.getVariableType() == VariableType.VARIABLE) {
|
||||||
variablesArrays.put(varName, new DeviceLocalNDArray(arr));
|
variablesArrays.put(varName, new DeviceLocalNDArray(arr, true));
|
||||||
} else if (v.isPlaceHolder()) {
|
} else if (v.isPlaceHolder()) {
|
||||||
long tid = Thread.currentThread().getId();
|
long tid = Thread.currentThread().getId();
|
||||||
if (!placeholdersPerThread.containsKey(tid)) {
|
if (!placeholdersPerThread.containsKey(tid)) {
|
||||||
|
@ -1033,10 +1033,10 @@ public class SameDiff extends SDBaseOps {
|
||||||
|
|
||||||
switch (variable.getVariableType()) {
|
switch (variable.getVariableType()) {
|
||||||
case VARIABLE:
|
case VARIABLE:
|
||||||
variablesArrays.put(variable.getVarName(), new DeviceLocalNDArray(arr));
|
variablesArrays.put(variable.getVarName(), new DeviceLocalNDArray(arr, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
|
||||||
break;
|
break;
|
||||||
case CONSTANT:
|
case CONSTANT:
|
||||||
constantArrays.put(variable.getVarName(), new DeviceLocalNDArray(arr));
|
constantArrays.put(variable.getVarName(), new DeviceLocalNDArray(arr, true));
|
||||||
break;
|
break;
|
||||||
case ARRAY:
|
case ARRAY:
|
||||||
// FIXME: remove this before release
|
// FIXME: remove this before release
|
||||||
|
@ -1077,6 +1077,29 @@ public class SameDiff extends SDBaseOps {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Update the constant or variable type SDVariable with the values from the specified
|
||||||
|
* array. Note that unlike {@link #associateArrayWithVariable(INDArray, String)} this method will take the
|
||||||
|
* values from the argument array and assign it to the current array.
|
||||||
|
* The actual array (INDArray object) will not be stored or otherwise used within the SameDiff instance.
|
||||||
|
* @param arr Array values to set
|
||||||
|
* @param variable Variable to update the array of. Must be CONSTANT or VARIBLE type SDVariable
|
||||||
|
*/
|
||||||
|
public void assignArray(@NonNull INDArray arr, @NonNull SDVariable variable){
|
||||||
|
Preconditions.checkState(variable.getVariableType() == VariableType.VARIABLE || variable.getVariableType() == VariableType.CONSTANT,
|
||||||
|
"assignArray method can only be used with VARIBLE or CONSTANT type SDVariables, variable \"%s\" has type %s", variable.getVarName(), variable.getVariableType());
|
||||||
|
|
||||||
|
//DeviceLocal doesn't work with views
|
||||||
|
if(arr.isView())
|
||||||
|
arr = arr.dup();
|
||||||
|
|
||||||
|
if(variable.getVariableType() == VariableType.VARIABLE ){
|
||||||
|
variablesArrays.get(variable.getVarName()).update(arr);
|
||||||
|
} else {
|
||||||
|
constantArrays.get(variable.getVarName()).update(arr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Associate a {@link SameDiff} namespace as a sub function.
|
* Associate a {@link SameDiff} namespace as a sub function.
|
||||||
|
@ -3256,7 +3279,7 @@ public class SameDiff extends SDBaseOps {
|
||||||
SDVariable v = new SDVariable(name, VariableType.CONSTANT, this, constant.shape(), constant.dataType(), null);
|
SDVariable v = new SDVariable(name, VariableType.CONSTANT, this, constant.shape(), constant.dataType(), null);
|
||||||
name = v.getVarName();
|
name = v.getVarName();
|
||||||
variables.put(name, Variable.builder().name(name).variable(v).build());
|
variables.put(name, Variable.builder().name(name).variable(v).build());
|
||||||
constantArrays.put(name, new DeviceLocalNDArray(constant));
|
constantArrays.put(name, new DeviceLocalNDArray(constant, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
|
||||||
return v;
|
return v;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3630,7 +3653,7 @@ public class SameDiff extends SDBaseOps {
|
||||||
INDArray arr = variable.getArr();
|
INDArray arr = variable.getArr();
|
||||||
Preconditions.checkNotNull(arr, "Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", variable);
|
Preconditions.checkNotNull(arr, "Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", variable);
|
||||||
|
|
||||||
constantArrays.put(n, new DeviceLocalNDArray(arr));
|
constantArrays.put(n, new DeviceLocalNDArray(arr, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
|
||||||
variablesArrays.remove(n);
|
variablesArrays.remove(n);
|
||||||
if (!placeholdersPerThread.isEmpty()) {
|
if (!placeholdersPerThread.isEmpty()) {
|
||||||
for (Map<String, INDArray> m : placeholdersPerThread.values()) {
|
for (Map<String, INDArray> m : placeholdersPerThread.values()) {
|
||||||
|
@ -3728,7 +3751,7 @@ public class SameDiff extends SDBaseOps {
|
||||||
INDArray arr = variable.getArr();
|
INDArray arr = variable.getArr();
|
||||||
Preconditions.checkNotNull(arr, "Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", variable);
|
Preconditions.checkNotNull(arr, "Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", variable);
|
||||||
|
|
||||||
variablesArrays.put(n, new DeviceLocalNDArray(arr));
|
variablesArrays.put(n, new DeviceLocalNDArray(arr, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
|
||||||
constantArrays.remove(n);
|
constantArrays.remove(n);
|
||||||
if (!placeholdersPerThread.isEmpty()) {
|
if (!placeholdersPerThread.isEmpty()) {
|
||||||
for (Map<String, INDArray> m : placeholdersPerThread.values()) {
|
for (Map<String, INDArray> m : placeholdersPerThread.values()) {
|
||||||
|
@ -3807,13 +3830,13 @@ public class SameDiff extends SDBaseOps {
|
||||||
DeviceLocalNDArray dl = variablesArrays.remove(e.getKey());
|
DeviceLocalNDArray dl = variablesArrays.remove(e.getKey());
|
||||||
INDArray arr = dl.get();
|
INDArray arr = dl.get();
|
||||||
INDArray newArr = arr.castTo(d);
|
INDArray newArr = arr.castTo(d);
|
||||||
variablesArrays.put(e.getKey(), new DeviceLocalNDArray(newArr));
|
variablesArrays.put(e.getKey(), new DeviceLocalNDArray(newArr, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
|
||||||
break;
|
break;
|
||||||
case CONSTANT:
|
case CONSTANT:
|
||||||
DeviceLocalNDArray dl2 = constantArrays.remove(e.getKey());
|
DeviceLocalNDArray dl2 = constantArrays.remove(e.getKey());
|
||||||
INDArray arr2 = dl2.get();
|
INDArray arr2 = dl2.get();
|
||||||
INDArray newArr2 = arr2.castTo(d);
|
INDArray newArr2 = arr2.castTo(d);
|
||||||
constantArrays.put(e.getKey(), new DeviceLocalNDArray(newArr2));
|
constantArrays.put(e.getKey(), new DeviceLocalNDArray(newArr2, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
|
||||||
break;
|
break;
|
||||||
case PLACEHOLDER:
|
case PLACEHOLDER:
|
||||||
Map<String, INDArray> m = placeholdersPerThread.get(Thread.currentThread().getId());
|
Map<String, INDArray> m = placeholdersPerThread.get(Thread.currentThread().getId());
|
||||||
|
|
Loading…
Reference in New Issue