SameDiff: make use of DeviceLocal configurable (#32)
* #8340 make DeviceLocal configurable. Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Small fix. Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Javadoc. Signed-off-by: AlexDBlack <blacka101@gmail.com>
* DL4J SameDiff layers: use SingleThreadArrayHolder to avoid assigns + DeviceLocalNDArray overhead. Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Javadoc. Signed-off-by: AlexDBlack <blacka101@gmail.com>

Branch: master
parent 7583ccfa15
commit df8b4e607a
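In short: SameDiff previously stored every variable and constant array in a DeviceLocalNDArray, which cannot hold views and forced the DL4J SameDiff layers to re-assign their parameter arrays on every iteration. This commit introduces a pluggable ArrayHolder abstraction, keeps the thread-safe DeviceLocal-backed holder as the default, and lets single-threaded callers (such as the DL4J layers below) opt into a plain map-backed holder. A minimal sketch of the new configuration call follows; the wrapper class and main method are illustrative only, not part of the commit:

import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.array.SingleThreadArrayHolder;

public class ArrayHolderConfigExample {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();

        // Default behaviour: ThreadSafeArrayHolder (DeviceLocalNDArray-backed) for both
        // variables and constants. For single-threaded use, swap both holders out;
        // 'false' means nothing is copied from the current holders (nothing exists yet).
        sd.setArrayHolders(new SingleThreadArrayHolder(), new SingleThreadArrayHolder(), false);
    }
}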
@@ -31,6 +31,7 @@ import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.array.SingleThreadArrayHolder;
import org.nd4j.autodiff.samediff.internal.InferenceSession;
import org.nd4j.autodiff.samediff.internal.SessionMemMgr;
import org.nd4j.base.Preconditions;
@@ -133,16 +134,6 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
}
is.setMmgr(mmgr);

if(paramTable != null && paramTable.size() > 0) {
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
//TODO Find a more efficient solution for this
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
INDArray arr = e.getValue();
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
}
}
INDArray result = sameDiff.outputSingle(phMap, outputKey);

//Edge case: "vertex" is just an identity activation, for example
@@ -212,17 +203,8 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
String epsName = fn.getGradPlaceholderName();
phMap.put(epsName, epsilon);

//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
//TODO Find a more efficient solution for this
List<String> required = new ArrayList<>(inputNames.size()); //Ensure that the input placeholder gradients are calculated
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
INDArray arr = e.getValue();
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
}

List<String> required = new ArrayList<>(config.getVertexParams().getInputs()); //Ensure that the input placeholder gradients are calculated
required.addAll(paramTable.keySet());
required.addAll(inputNames);

Map<String,INDArray> gradsMap = sameDiff.calculateGradients(phMap, required);
for(String s : paramTable.keySet() ){
@@ -279,6 +261,8 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
protected void doInit(){
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
sameDiff = SameDiff.create();
//Use SingleThreadArrayHolder so we can use views (also don't need multithreading here, DL4J is not thread safe)
sameDiff.setArrayHolders(new SingleThreadArrayHolder(), new SingleThreadArrayHolder(), false);

inputVars = new LinkedHashMap<>();
LinkedHashMap<String, SDVariable> maskVars = new LinkedHashMap<>();
@@ -26,6 +26,7 @@ import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.AbstractLayer;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.array.SingleThreadArrayHolder;
import org.nd4j.autodiff.samediff.internal.InferenceSession;
import org.nd4j.autodiff.samediff.internal.SessionMemMgr;
import org.nd4j.base.Preconditions;
@@ -100,13 +101,6 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
phMap.put(MASK_KEY, layerConf().onesMaskForInput(input));
}

//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
//TODO Find a more efficient solution for this
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
INDArray arr = e.getValue();
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
}

//Configure memory management for SameDiff instance - use DL4J workspaces
String wsNameWorking = workspaceMgr.getWorkspaceName(ArrayType.FF_WORKING_MEM);
String wsNameOutput = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATIONS);
@@ -179,13 +173,6 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf();
bl.validateInput(input);

//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
//TODO Find a more efficient solution for this
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
INDArray arr = e.getValue();
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
}

Map<String,INDArray> phMap = new HashMap<>();
phMap.put(INPUT_KEY, input);
phMap.put(fn.getGradPlaceholderName(), epsilon);
@@ -300,6 +287,8 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf();
sameDiff = SameDiff.create();
//Use SingleThreadArrayHolder so we can use views (also don't need multithreading here, DL4J is not thread safe)
sameDiff.setArrayHolders(new SingleThreadArrayHolder(), new SingleThreadArrayHolder(), false);
Map<String, INDArray> p = paramTable();

long[] inputShape = input.shape().clone();
@@ -29,6 +29,7 @@ import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.array.SingleThreadArrayHolder;
import org.nd4j.autodiff.samediff.internal.InferenceSession;
import org.nd4j.autodiff.samediff.internal.SessionMemMgr;
import org.nd4j.base.Preconditions;
@@ -119,15 +120,6 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
}
is.setMmgr(mmgr);

//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
//TODO Find a more efficient solution for this
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
INDArray arr = e.getValue();
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
}

Map<String,INDArray> phMap = new HashMap<>();
phMap.put(INPUT_KEY, input);
if(!activations && layerConf().labelsRequired() && labels != null) {
@@ -193,13 +185,6 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
sameDiff.createGradFunction(INPUT_KEY);
}

//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
//TODO Find a more efficient solution for this
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
INDArray arr = e.getValue();
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
}

List<String> gradVarNames = new ArrayList<>();
gradVarNames.addAll(paramTable.keySet());
gradVarNames.add(INPUT_KEY);
@@ -317,6 +302,8 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
org.deeplearning4j.nn.conf.layers.samediff.SameDiffOutputLayer bl = layerConf();
sameDiff = SameDiff.create();
//Use SingleThreadArrayHolder so we can use views (also don't need multithreading here, DL4J is not thread safe)
sameDiff.setArrayHolders(new SingleThreadArrayHolder(), new SingleThreadArrayHolder(), false);
Map<String, INDArray> p = paramTable();

long[] inputShape = input.shape().clone();
@@ -339,7 +326,6 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
Preconditions.checkNotNull(layerOutput, "Invalid output: layer output is null");
outputVar = layerOutput;

//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
for (Map.Entry<String, INDArray> e : p.entrySet()) {
INDArray arr = e.getValue();
sameDiff.associateArrayWithVariable(arr, sameDiff.getVariable(e.getKey()));
@@ -0,0 +1,69 @@
package org.nd4j.autodiff.samediff;

import org.nd4j.linalg.api.ndarray.INDArray;

import java.util.Collection;

/**
 * Holds a set of arrays keyed by a String name, functioning essentially like a {@code Map<String,INDArray>}.<br>
 * Implementations may have different internal ways of storing arrays, however.<br>
 * For example for single threaded applications: {@link org.nd4j.autodiff.samediff.array.SingleThreadArrayHolder}<br>
 * And for multi-threaded: {@link org.nd4j.autodiff.samediff.array.ThreadSafeArrayHolder}
 *
 * @author Alex Black
 */
public interface ArrayHolder {

    /**
     * @return True if an array by that name exists
     */
    boolean hasArray(String name);

    /**
     * @param name Name of the array to get
     * @return The array, or null if no array with that name exists
     */
    INDArray getArray(String name);

    /**
     * Set the array for the specified name (new array, or replace if it already exists)
     *
     * @param name  Name of the array
     * @param array Array to set
     */
    void setArray(String name, INDArray array);

    /**
     * Remove the array from the ArrayHolder, returning it (if it exists)
     *
     * @param name Name of the array to return
     * @return The now-removed array
     */
    INDArray removeArray(String name);

    /**
     * @return Number of arrays in the ArrayHolder
     */
    int size();

    /**
     * Initialize from the specified array holder.
     * This clears all internal arrays, and adds all arrays from the specified array holder
     *
     * @param arrayHolder Array holder to initialize this based on
     */
    void initFrom(ArrayHolder arrayHolder);

    /**
     * @return Names of the arrays currently in the ArrayHolder
     */
    Collection<String> arrayNames();

    /**
     * Rename the entry with the specified name
     *
     * @param from Original name
     * @param to   New name
     */
    void rename(String from, String to);
}
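To make the contract above concrete, here is a small usage sketch against the interface, using the single-threaded implementation added later in this commit. The entry names ("w0", "weights") and the demo class are made up for illustration; the calls themselves are the ones declared above:

import org.nd4j.autodiff.samediff.ArrayHolder;
import org.nd4j.autodiff.samediff.array.SingleThreadArrayHolder;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class ArrayHolderDemo {
    public static void main(String[] args) {
        ArrayHolder holder = new SingleThreadArrayHolder();

        holder.setArray("w0", Nd4j.ones(2, 3));     // add a new entry
        INDArray w = holder.getArray("w0");         // fetch it back (same object for this implementation)
        System.out.println(holder.hasArray("w0"));  // true
        System.out.println(holder.size());          // 1
        System.out.println(w);

        holder.rename("w0", "weights");             // move the entry to a new key
        System.out.println(holder.hasArray("w0"));  // false
        System.out.println(holder.arrayNames());    // [weights]
    }
}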
@@ -30,6 +30,8 @@ import org.nd4j.autodiff.listeners.impl.HistoryListener;
import org.nd4j.autodiff.listeners.records.History;
import org.nd4j.autodiff.listeners.records.LossCurve;
import org.nd4j.autodiff.samediff.api.OutAndGrad;
import org.nd4j.autodiff.samediff.array.SingleThreadArrayHolder;
import org.nd4j.autodiff.samediff.array.ThreadSafeArrayHolder;
import org.nd4j.autodiff.samediff.config.BatchOutputConfig;
import org.nd4j.autodiff.samediff.config.EvaluationConfig;
import org.nd4j.autodiff.samediff.config.FitConfig;
@@ -122,8 +124,8 @@ public class SameDiff extends SDBaseOps {
@Getter
private final Map<Long, InferenceSession> sessions = new ConcurrentHashMap<>(); //Key: thread ID

private final Map<String, DeviceLocalNDArray> constantArrays = new ConcurrentHashMap<>();
private final Map<String, DeviceLocalNDArray> variablesArrays = new ConcurrentHashMap<>(); //TODO issues with DeviceLocal + mutable / changed during training?
private ArrayHolder constantArrays = new ThreadSafeArrayHolder(true);
private ArrayHolder variablesArrays = new ThreadSafeArrayHolder(true);
private final Map<Long, Map<String, INDArray>> placeholdersPerThread = new ConcurrentHashMap<>(); //Placeholders for each thread - if the user sets them

private final List<String> lossVariables = new ArrayList<>();
@@ -346,6 +348,23 @@ public class SameDiff extends SDBaseOps {
return listeners;
}

/**
 * Set the array holders for variable and constant arrays<br>
 * <b>NOTE:</b> this is usually reserved for developers and internal use, and should not be needed by most users<br>
 * See {@link ArrayHolder} for more details
 *
 * @param variableArrayHolder Array holder for variable arrays
 * @param constantArrayHolder Array holder for constant arrays
 * @param initialize If true: transfer any arrays from the current array holders to the new/specified ones
 */
public void setArrayHolders(@NonNull ArrayHolder variableArrayHolder, @NonNull ArrayHolder constantArrayHolder, boolean initialize){
if(initialize){
variableArrayHolder.initFrom(this.variablesArrays);
constantArrayHolder.initFrom(this.constantArrays);
}
this.variablesArrays = variableArrayHolder;
this.constantArrays = constantArrayHolder;
}

/**
 * @return The current name scope, if any (null otherwise). See {@link #withNameScope(String)} for more details.
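A short hedged sketch of the initialize flag, assuming the SameDiff, SDVariable, Nd4j and array-holder imports shown elsewhere in this commit (the variable name "w" is illustrative): with initialize set to true, the new holders are first populated from the current ones via ArrayHolder.initFrom, so existing values survive the swap.

SameDiff sd = SameDiff.create();
SDVariable w = sd.var("w", Nd4j.rand(3, 4));   // stored in the default ThreadSafeArrayHolder

// Swap in single-threaded holders; initialize == true copies the existing
// variable and constant arrays into the new holders before they take over.
sd.setArrayHolders(new SingleThreadArrayHolder(), new SingleThreadArrayHolder(), true);
System.out.println(sd.getVariable("w").getArr());  // same values as before the swap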
@@ -674,9 +693,9 @@ public class SameDiff extends SDBaseOps {

SDVariable v = getVariable(varName);
if (v.isConstant()) {
constantArrays.put(varName, new DeviceLocalNDArray(arr, true));
constantArrays.setArray(varName, arr);
} else if (v.getVariableType() == VariableType.VARIABLE) {
variablesArrays.put(varName, new DeviceLocalNDArray(arr, true));
variablesArrays.setArray(varName, arr);
} else if (v.isPlaceHolder()) {
long tid = Thread.currentThread().getId();
if (!placeholdersPerThread.containsKey(tid)) {
@@ -699,12 +718,12 @@ public class SameDiff extends SDBaseOps {
SDVariable var = getVariable(varName);
switch (var.getVariableType()) {
case VARIABLE:
return variablesArrays.containsKey(varName);
return variablesArrays.hasArray(varName);
case ARRAY:
long tid = Thread.currentThread().getId();
return sessions.containsKey(tid) && sessions.get(tid).contains(varName, InferenceSession.OUTER_FRAME, 0, null);
case CONSTANT:
return constantArrays.containsKey(varName);
return constantArrays.hasArray(varName);
case PLACEHOLDER:
return placeholdersPerThread.containsKey(Thread.currentThread().getId()) &&
placeholdersPerThread.get(Thread.currentThread().getId()).containsKey(varName);
@@ -724,11 +743,11 @@ public class SameDiff extends SDBaseOps {
SDVariable v = variables.get(varName).getVariable();
switch (v.getVariableType()) {
case VARIABLE:
return variablesArrays.get(varName).get();
return variablesArrays.getArray(varName);
case CONSTANT:
if (!constantArrays.containsKey(varName))
if (!constantArrays.hasArray(varName))
return null;
return constantArrays.get(varName).get();
return constantArrays.getArray(varName);
case ARRAY:
//Only stored in inference session...
InferenceSession s = sessions.get(Thread.currentThread().getId());
@@ -781,31 +800,16 @@ public class SameDiff extends SDBaseOps {
sessions.put(Thread.currentThread().getId(), new InferenceSession(this));
}

boolean duped = false;
if (arr.isAttached()) {
arr = arr.detach();
duped = true;
}
if (arr.isView()) {
arr = arr.dup();
duped = true;
}

if (!duped && variable.getVariableType() == VariableType.VARIABLE) {
for (DeviceLocalNDArray otherArr : variablesArrays.values()) {
if (otherArr.get() == arr) { //Check for exact same object, to avoid array reuse (can result in unexpected behaviour)
arr = arr.dup();
break;
}
}
}

switch (variable.getVariableType()) {
case VARIABLE:
variablesArrays.put(variable.name(), new DeviceLocalNDArray(arr, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
variablesArrays.setArray(variable.name(), arr);
break;
case CONSTANT:
constantArrays.put(variable.name(), new DeviceLocalNDArray(arr, true));
constantArrays.setArray(variable.name(), arr);
break;
case ARRAY:
throw new UnsupportedOperationException("Cannot associate array with SDVariable of type ARRAY - arrays for" +
@@ -859,9 +863,9 @@ public class SameDiff extends SDBaseOps {
arr = arr.dup();

if(variable.getVariableType() == VariableType.VARIABLE ){
variablesArrays.get(variable.name()).update(arr);
variablesArrays.setArray(variable.name(), arr);
} else {
constantArrays.get(variable.name()).update(arr);
constantArrays.setArray(variable.name(), arr);
}
}
@@ -2715,7 +2719,7 @@ public class SameDiff extends SDBaseOps {
SDVariable v = new SDVariable(name, VariableType.CONSTANT, this, constant.shape(), constant.dataType());
name = v.name();
variables.put(name, Variable.builder().name(name).variable(v).build());
constantArrays.put(name, new DeviceLocalNDArray(constant, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
constantArrays.setArray(name, constant);
return v;
}
@@ -2792,7 +2796,7 @@ public class SameDiff extends SDBaseOps {
if(variableType == VariableType.VARIABLE){
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
INDArray vArr = weightInitScheme.create(dataType, shape);
variablesArrays.put(name, new DeviceLocalNDArray(vArr, true));
variablesArrays.setArray(name, vArr);
}
}
@@ -2924,7 +2928,7 @@ public class SameDiff extends SDBaseOps {
SDVariable r = new SDVariable(v.name(), v.getVariableType(), this, v.getShape(), v.dataType());
addVariable(r);
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){
variablesArrays.put(v.name(), new DeviceLocalNDArray(v.getArr().dup(), true));
variablesArrays.setArray(v.name(), v.getArr().dup());
}
return r;
case ARRAY:
@@ -3014,20 +3018,17 @@ public class SameDiff extends SDBaseOps {
arr = arr.detach();
duped = true;
}
if (arr.isView()) {
arr = arr.dup();
duped = true;
}

if (!duped) {
for (DeviceLocalNDArray otherArr : variablesArrays.values()) {
if (otherArr.get() == arr) { //Check for exact same object, to avoid array reuse (can result in unexpected behaviour)
for (String s : variablesArrays.arrayNames()) {
if (variablesArrays.getArray(s) == arr) { //Check for exact same object, to avoid array reuse (can result in unexpected behaviour)
arr = arr.dup();
break;
}
}
}

SDVariable ret = new SDVariable(name, VariableType.VARIABLE, this, arr.shape(), arr.dataType());
associateArrayWithVariable(arr, ret);
@@ -3085,8 +3086,8 @@ public class SameDiff extends SDBaseOps {
INDArray arr = variable.getArr();
Preconditions.checkNotNull(arr, "Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", variable);

constantArrays.put(n, new DeviceLocalNDArray(arr, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
variablesArrays.remove(n);
constantArrays.setArray(n, arr); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
variablesArrays.removeArray(n);
if (!placeholdersPerThread.isEmpty()) {
for (Map<String, INDArray> m : placeholdersPerThread.values()) {
m.remove(n);
@@ -3183,8 +3184,8 @@ public class SameDiff extends SDBaseOps {
INDArray arr = variable.getArr();
Preconditions.checkNotNull(arr, "Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", variable);

variablesArrays.put(n, new DeviceLocalNDArray(arr, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
constantArrays.remove(n);
variablesArrays.setArray(n, arr); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
constantArrays.removeArray(n);
if (!placeholdersPerThread.isEmpty()) {
for (Map<String, INDArray> m : placeholdersPerThread.values()) {
m.remove(n);
@@ -3260,16 +3261,14 @@ public class SameDiff extends SDBaseOps {

switch (v.getVariableType()) {
case VARIABLE:
DeviceLocalNDArray dl = variablesArrays.remove(e.getKey());
INDArray arr = dl.get();
INDArray arr = variablesArrays.removeArray(e.getKey());
INDArray newArr = arr.castTo(d);
variablesArrays.put(e.getKey(), new DeviceLocalNDArray(newArr, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
variablesArrays.setArray(e.getKey(), newArr); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
break;
case CONSTANT:
DeviceLocalNDArray dl2 = constantArrays.remove(e.getKey());
INDArray arr2 = dl2.get();
INDArray arr2 = constantArrays.removeArray(e.getKey());
INDArray newArr2 = arr2.castTo(d);
constantArrays.put(e.getKey(), new DeviceLocalNDArray(newArr2, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
constantArrays.setArray(e.getKey(), newArr2); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
break;
case PLACEHOLDER:
Map<String, INDArray> m = placeholdersPerThread.get(Thread.currentThread().getId());
@@ -3409,14 +3408,12 @@ public class SameDiff extends SDBaseOps {
variables.remove(from);
variables.put(to, v);

if(v.getVariable().getVariableType() == VariableType.CONSTANT && constantArrays.containsKey(from)){
DeviceLocalNDArray dl = constantArrays.remove(from);
constantArrays.put(to, dl);
if(v.getVariable().getVariableType() == VariableType.CONSTANT && constantArrays.hasArray(from)){
constantArrays.rename(from, to);
}

if(v.getVariable().getVariableType() == VariableType.VARIABLE && variablesArrays.containsKey(from)){
DeviceLocalNDArray dl = variablesArrays.remove(from);
variablesArrays.put(to, dl);
if(v.getVariable().getVariableType() == VariableType.VARIABLE && variablesArrays.hasArray(from)){
variablesArrays.rename(from, to);
}

if(v.getVariable().getVariableType() == VariableType.PLACEHOLDER ){
@@ -4187,6 +4184,8 @@ public class SameDiff extends SDBaseOps {

@Override
public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> inputs, SDVariable[] variableInputs) {
sameDiff.setArrayHolders(new SingleThreadArrayHolder(), new SingleThreadArrayHolder(), false); //Training isn't thread safe, no need to use DeviceLocal, even with lazy init

//Propagate graph to this samediff instance which will also contain the backward
if (SameDiff.this.debugMode) {
sameDiff.enableDebugMode();
@@ -0,0 +1,66 @@
package org.nd4j.autodiff.samediff.array;

import lombok.NonNull;
import org.nd4j.autodiff.samediff.ArrayHolder;
import org.nd4j.linalg.api.ndarray.INDArray;

import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

/**
 * A simple {@link ArrayHolder} that uses a simple {@code Map<String, INDArray>} internally.
 * No thread safety guarantees
 *
 * @author Alex Black
 */
public class SingleThreadArrayHolder implements ArrayHolder {

    private final Map<String, INDArray> map = new HashMap<>();

    @Override
    public boolean hasArray(@NonNull String name) {
        return map.containsKey(name);
    }

    @Override
    public INDArray getArray(@NonNull String name) {
        return map.get(name);
    }

    @Override
    public void setArray(@NonNull String name, @NonNull INDArray array) {
        map.put(name, array);
    }

    @Override
    public INDArray removeArray(@NonNull String name) {
        return map.remove(name);
    }

    @Override
    public int size() {
        return map.size();
    }

    @Override
    public void initFrom(ArrayHolder arrayHolder) {
        map.clear();
        Collection<String> names = arrayHolder.arrayNames();
        for (String n : names) {
            map.put(n, arrayHolder.getArray(n));
        }
    }

    @Override
    public Collection<String> arrayNames() {
        return Collections.unmodifiableCollection(map.keySet());
    }

    @Override
    public void rename(String from, String to) {
        INDArray arr = map.remove(from);
        map.put(to, arr);
    }
}
@@ -0,0 +1,85 @@
package org.nd4j.autodiff.samediff.array;

import lombok.NonNull;
import org.nd4j.autodiff.samediff.ArrayHolder;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.util.DeviceLocalNDArray;

import java.util.Collection;
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * An {@link ArrayHolder} that uses the thread safe {@link DeviceLocalNDArray} internally
 *
 * @author Alex Black
 */
public class ThreadSafeArrayHolder implements ArrayHolder {

    private final Map<String, DeviceLocalNDArray> map = new ConcurrentHashMap<>();
    private final boolean lazyInit;

    /**
     * @param lazyInit If true: use lazy initialization for {@link DeviceLocalNDArray}
     */
    public ThreadSafeArrayHolder(boolean lazyInit) {
        this.lazyInit = lazyInit;
    }

    @Override
    public boolean hasArray(@NonNull String name) {
        return map.containsKey(name);
    }

    @Override
    public INDArray getArray(@NonNull String name) {
        return map.get(name).get();
    }

    @Override
    public void setArray(@NonNull String name, @NonNull INDArray array) {
        if (array.isView())
            array = array.dup(); //Device local doesn't support views
        if (!map.containsKey(name)) {
            DeviceLocalNDArray dla = new DeviceLocalNDArray(array, lazyInit);
            map.put(name, dla);
        } else {
            DeviceLocalNDArray dla = map.get(name);
            dla.update(array);
        }
    }

    @Override
    public INDArray removeArray(@NonNull String name) {
        DeviceLocalNDArray arr = map.remove(name);
        if (arr == null)
            return null;
        return arr.get();
    }

    @Override
    public int size() {
        return map.size();
    }

    @Override
    public void initFrom(ArrayHolder arrayHolder) {
        map.clear();
        Collection<String> names = arrayHolder.arrayNames();
        for (String n : names) {
            setArray(n, arrayHolder.getArray(n));
        }
    }

    @Override
    public Collection<String> arrayNames() {
        return Collections.unmodifiableCollection(map.keySet());
    }

    @Override
    public void rename(@NonNull String from, @NonNull String to) {
        DeviceLocalNDArray dl = map.remove(from);
        map.put(to, dl);
    }
}
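As a closing note, the two holders treat views differently, which is exactly why the DL4J layers above switch to SingleThreadArrayHolder: the thread-safe holder copies views before wrapping them in DeviceLocalNDArray, while the single-threaded holder stores whatever array it is given, views included. A short hedged sketch (the entry name and the surrounding snippet are illustrative; it assumes the same Nd4j and array-holder imports used in the examples above):

INDArray base = Nd4j.rand(2, 3);
INDArray row = base.getRow(0);                  // typically a view into 'base'

ArrayHolder single = new SingleThreadArrayHolder();
single.setArray("r", row);                      // stored as-is: still backed by 'base'

ArrayHolder threadSafe = new ThreadSafeArrayHolder(true);
threadSafe.setArray("r", row);                  // dup'd internally: DeviceLocalNDArray cannot hold views
System.out.println(threadSafe.getArray("r").isView());  // false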