SameDiff: make use of DeviceLocal configurable (#32)

* #8340 make DeviceLocal configurable

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Small fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Javadoc

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* DL4J SameDiff layers: use SingleThreadArrayHolder to avoid assigns + DeviceLocalNDArray overhead

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Javadoc

Signed-off-by: AlexDBlack <blacka101@gmail.com>
Branch: master
Author: Alex Black, 2019-11-06 18:52:41 +11:00 (committed by GitHub)
Parent: 7583ccfa15
Commit: df8b4e607a
7 changed files with 282 additions and 104 deletions
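
The headline change: SameDiff's internal storage for variable and constant arrays is now pluggable via an ArrayHolder abstraction, rather than always being wrapped in DeviceLocalNDArray. Below is a minimal sketch of how a single-threaded caller could opt out of the DeviceLocal overhead using the setArrayHolders method added in this commit; the class name ArrayHolderOptIn and the example variable are illustrative only.

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.array.SingleThreadArrayHolder;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class ArrayHolderOptIn {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();

        // Replace the default (thread-safe, DeviceLocal-backed) holders with simple map-backed ones.
        // initialize=false: there is nothing to transfer yet, as no variables have been created.
        sd.setArrayHolders(new SingleThreadArrayHolder(), new SingleThreadArrayHolder(), false);

        // Variables created from here on are stored directly in the new holders,
        // so their arrays are not copied into DeviceLocalNDArray instances.
        SDVariable w = sd.var("w", Nd4j.rand(3, 4));
        INDArray stored = sd.getArrForVarName("w");
        System.out.println(stored.shapeInfoToString());
    }
}
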

File: SameDiffGraphVertex.java

@@ -31,6 +31,7 @@ import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.array.SingleThreadArrayHolder;
import org.nd4j.autodiff.samediff.internal.InferenceSession;
import org.nd4j.autodiff.samediff.internal.SessionMemMgr;
import org.nd4j.base.Preconditions;
@@ -133,16 +134,6 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
}
is.setMmgr(mmgr);
if(paramTable != null && paramTable.size() > 0) {
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
//TODO Find a more efficient solution for this
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
INDArray arr = e.getValue();
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
}
}
INDArray result = sameDiff.outputSingle(phMap, outputKey);
//Edge case: "vertex" is just an identity activation, for example
@@ -212,17 +203,8 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
String epsName = fn.getGradPlaceholderName();
phMap.put(epsName, epsilon);
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
//TODO Find a more efficient solution for this
List<String> required = new ArrayList<>(inputNames.size()); //Ensure that the input placeholder gradients are calculated
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
INDArray arr = e.getValue();
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
}
List<String> required = new ArrayList<>(config.getVertexParams().getInputs()); //Ensure that the input placeholder gradients are calculated
required.addAll(paramTable.keySet());
required.addAll(inputNames);
Map<String,INDArray> gradsMap = sameDiff.calculateGradients(phMap, required);
for(String s : paramTable.keySet() ){
@@ -279,6 +261,8 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
protected void doInit(){
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
sameDiff = SameDiff.create();
//Use SingleThreadArrayHolder so we can use views (also don't need multithreading here, DL4J is not thread safe)
sameDiff.setArrayHolders(new SingleThreadArrayHolder(), new SingleThreadArrayHolder(), false);
inputVars = new LinkedHashMap<>();
LinkedHashMap<String, SDVariable> maskVars = new LinkedHashMap<>();

File: SameDiffLayer.java

@@ -26,6 +26,7 @@ import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.AbstractLayer;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.array.SingleThreadArrayHolder;
import org.nd4j.autodiff.samediff.internal.InferenceSession;
import org.nd4j.autodiff.samediff.internal.SessionMemMgr;
import org.nd4j.base.Preconditions;
@@ -100,13 +101,6 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
phMap.put(MASK_KEY, layerConf().onesMaskForInput(input));
}
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
//TODO Find a more efficient solution for this
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
INDArray arr = e.getValue();
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
}
//Configure memory management for SameDiff instance - use DL4J workspaces
String wsNameWorking = workspaceMgr.getWorkspaceName(ArrayType.FF_WORKING_MEM);
String wsNameOutput = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATIONS);
@@ -179,13 +173,6 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf();
bl.validateInput(input);
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
//TODO Find a more efficient solution for this
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
INDArray arr = e.getValue();
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
}
Map<String,INDArray> phMap = new HashMap<>();
phMap.put(INPUT_KEY, input);
phMap.put(fn.getGradPlaceholderName(), epsilon);
@@ -300,6 +287,8 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf();
sameDiff = SameDiff.create();
//Use SingleThreadArrayHolder so we can use views (also don't need multithreading here, DL4J is not thread safe)
sameDiff.setArrayHolders(new SingleThreadArrayHolder(), new SingleThreadArrayHolder(), false);
Map<String, INDArray> p = paramTable();
long[] inputShape = input.shape().clone();

File: SameDiffOutputLayer.java

@@ -29,6 +29,7 @@ import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.array.SingleThreadArrayHolder;
import org.nd4j.autodiff.samediff.internal.InferenceSession;
import org.nd4j.autodiff.samediff.internal.SessionMemMgr;
import org.nd4j.base.Preconditions;
@@ -119,15 +120,6 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
}
is.setMmgr(mmgr);
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
//TODO Find a more efficient solution for this
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
INDArray arr = e.getValue();
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
}
Map<String,INDArray> phMap = new HashMap<>();
phMap.put(INPUT_KEY, input);
if(!activations && layerConf().labelsRequired() && labels != null) {
@@ -193,13 +185,6 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
sameDiff.createGradFunction(INPUT_KEY);
}
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
//TODO Find a more efficient solution for this
for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
INDArray arr = e.getValue();
sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
}
List<String> gradVarNames = new ArrayList<>();
gradVarNames.addAll(paramTable.keySet());
gradVarNames.add(INPUT_KEY);
@@ -317,6 +302,8 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
org.deeplearning4j.nn.conf.layers.samediff.SameDiffOutputLayer bl = layerConf();
sameDiff = SameDiff.create();
//Use SingleThreadArrayHolder so we can use views (also don't need multithreading here, DL4J is not thread safe)
sameDiff.setArrayHolders(new SingleThreadArrayHolder(), new SingleThreadArrayHolder(), false);
Map<String, INDArray> p = paramTable();
long[] inputShape = input.shape().clone();
@@ -339,7 +326,6 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
Preconditions.checkNotNull(layerOutput, "Invalid output: layer output is null");
outputVar = layerOutput;
//Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
for (Map.Entry<String, INDArray> e : p.entrySet()) {
INDArray arr = e.getValue();
sameDiff.associateArrayWithVariable(arr, sameDiff.getVariable(e.getKey()));
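
All three DL4J layer classes above follow the same pattern: doInit switches the layer's SameDiff instance to SingleThreadArrayHolder, the (possibly view) parameter arrays are associated once via associateArrayWithVariable, and the per-iteration assignArray copies are deleted. Below is a hedged sketch of that pattern in isolation; the class name and method shape are illustrative, not taken from the DL4J sources.

import java.util.Map;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.array.SingleThreadArrayHolder;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;

public class ViewBackedParamsSketch {
    private SameDiff sameDiff;

    /** Build the graph once; parameter arrays may be views of a larger DL4J parameter vector. */
    public void doInit(Map<String, INDArray> paramTable, long[] inputShape) {
        sameDiff = SameDiff.create();
        // Single-threaded holder: views are stored as-is, so no per-iteration assign is needed
        sameDiff.setArrayHolders(new SingleThreadArrayHolder(), new SingleThreadArrayHolder(), false);

        SDVariable input = sameDiff.placeHolder("input", DataType.FLOAT, inputShape);
        for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
            SDVariable param = sameDiff.var(e.getKey(), e.getValue().dataType(), e.getValue().shape());
            // Associate the (possibly view) parameter array once, instead of assigning it every iteration
            sameDiff.associateArrayWithVariable(e.getValue(), param);
        }
        // ... define the forward pass here using "input" and the parameter variables ...
    }
}
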

File: ArrayHolder.java (new file)

@@ -0,0 +1,69 @@
package org.nd4j.autodiff.samediff;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Collection;
/**
* Holds a set of arrays keyed by a String name, functioning essentially like a {@code Map<String,INDArray>}.<br>
* Implementations may have different internal ways of storing arrays, however.<br>
* For example, for single-threaded applications: {@link org.nd4j.autodiff.samediff.array.SingleThreadArrayHolder}<br>
* and for multi-threaded use: {@link org.nd4j.autodiff.samediff.array.ThreadSafeArrayHolder}
*
* @author Alex Black
*/
public interface ArrayHolder {
/**
* @return True if an array by that name exists
*/
boolean hasArray(String name);
/**
* @param name Name of the array to get
* @return The array, or null if no array with that name exists
*/
INDArray getArray(String name);
/**
* Set the array for the specified name (new array, or replace if it already exists)
*
* @param name Name of the array
* @param array Array to set
*/
void setArray(String name, INDArray array);
/**
* Remove the array from the ArrayHolder, returning it (if it exists)
*
* @param name Name of the array to remove
* @return The now-removed array
*/
INDArray removeArray(String name);
/**
* @return Number of arrays in the ArrayHolder
*/
int size();
/**
* Initialize from the specified array holder.
* This clears all internal arrays, and adds all arrays from the specified array holder.
*
* @param arrayHolder Array holder to initialize this one from
*/
void initFrom(ArrayHolder arrayHolder);
/**
* @return Names of the arrays currently in the ArrayHolder
*/
Collection<String> arrayNames();
/**
* Rename the entry with the specified name
*
* @param from Original name
* @param to New name
*/
void rename(String from, String to);
}
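
As a quick illustration of the contract above, here is a small sketch that exercises an ArrayHolder directly through the SingleThreadArrayHolder implementation added later in this commit; the names and shapes are arbitrary.

import org.nd4j.autodiff.samediff.ArrayHolder;
import org.nd4j.autodiff.samediff.array.SingleThreadArrayHolder;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class ArrayHolderContractDemo {
    public static void main(String[] args) {
        ArrayHolder holder = new SingleThreadArrayHolder();

        holder.setArray("w", Nd4j.ones(2, 2));       // create
        holder.setArray("w", Nd4j.zeros(2, 2));      // replace under the same name
        System.out.println(holder.hasArray("w"));    // true
        System.out.println(holder.size());           // 1

        holder.rename("w", "w2");                    // the key changes, the array does not
        INDArray removed = holder.removeArray("w2");
        System.out.println(removed);                 // the 2x2 array of zeros
        System.out.println(holder.hasArray("w2"));   // false
    }
}
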

File: SameDiff.java

@@ -30,6 +30,8 @@ import org.nd4j.autodiff.listeners.impl.HistoryListener;
import org.nd4j.autodiff.listeners.records.History;
import org.nd4j.autodiff.listeners.records.LossCurve;
import org.nd4j.autodiff.samediff.api.OutAndGrad;
import org.nd4j.autodiff.samediff.array.SingleThreadArrayHolder;
import org.nd4j.autodiff.samediff.array.ThreadSafeArrayHolder;
import org.nd4j.autodiff.samediff.config.BatchOutputConfig;
import org.nd4j.autodiff.samediff.config.EvaluationConfig;
import org.nd4j.autodiff.samediff.config.FitConfig;
@@ -122,8 +124,8 @@ public class SameDiff extends SDBaseOps {
@Getter
private final Map<Long, InferenceSession> sessions = new ConcurrentHashMap<>(); //Key: thread ID
private final Map<String, DeviceLocalNDArray> constantArrays = new ConcurrentHashMap<>();
private final Map<String, DeviceLocalNDArray> variablesArrays = new ConcurrentHashMap<>(); //TODO issues with DeviceLocal + mutable / changed during training?
private ArrayHolder constantArrays = new ThreadSafeArrayHolder(true);
private ArrayHolder variablesArrays = new ThreadSafeArrayHolder(true);
private final Map<Long, Map<String, INDArray>> placeholdersPerThread = new ConcurrentHashMap<>(); //Placeholders for each thread - if the user sets them
private final List<String> lossVariables = new ArrayList<>();
@@ -346,6 +348,23 @@ public class SameDiff extends SDBaseOps {
return listeners;
}
/**
* Set the array holders for variable and constant arrays.<br>
* <b>NOTE:</b> this is reserved for developers and internal use; the vast majority of users should not need to call it.<br>
* See {@link ArrayHolder} for more details
*
* @param variableArrayHolder Array holder for variable arrays
* @param constantArrayHolder Array holder for constant arrays
* @param initialize If true: transfer any arrays from the current array holders to the new/specified ones
*/
public void setArrayHolders(@NonNull ArrayHolder variableArrayHolder, @NonNull ArrayHolder constantArrayHolder, boolean initialize){
if(initialize){
variableArrayHolder.initFrom(this.variablesArrays);
constantArrayHolder.initFrom(this.constantArrays);
}
this.variablesArrays = variableArrayHolder;
this.constantArrays = constantArrayHolder;
}
/**
* @return The current name scope, if any (null otherwise). See {@link #withNameScope(String)} for more details.
@@ -674,9 +693,9 @@ public class SameDiff extends SDBaseOps {
SDVariable v = getVariable(varName);
if (v.isConstant()) {
constantArrays.put(varName, new DeviceLocalNDArray(arr, true));
constantArrays.setArray(varName, arr);
} else if (v.getVariableType() == VariableType.VARIABLE) {
variablesArrays.put(varName, new DeviceLocalNDArray(arr, true));
variablesArrays.setArray(varName, arr);
} else if (v.isPlaceHolder()) {
long tid = Thread.currentThread().getId();
if (!placeholdersPerThread.containsKey(tid)) {
@@ -699,12 +718,12 @@ public class SameDiff extends SDBaseOps {
SDVariable var = getVariable(varName);
switch (var.getVariableType()) {
case VARIABLE:
return variablesArrays.containsKey(varName);
return variablesArrays.hasArray(varName);
case ARRAY:
long tid = Thread.currentThread().getId();
return sessions.containsKey(tid) && sessions.get(tid).contains(varName, InferenceSession.OUTER_FRAME, 0, null);
case CONSTANT:
return constantArrays.containsKey(varName);
return constantArrays.hasArray(varName);
case PLACEHOLDER:
return placeholdersPerThread.containsKey(Thread.currentThread().getId()) &&
placeholdersPerThread.get(Thread.currentThread().getId()).containsKey(varName);
@@ -724,11 +743,11 @@ public class SameDiff extends SDBaseOps {
SDVariable v = variables.get(varName).getVariable();
switch (v.getVariableType()) {
case VARIABLE:
return variablesArrays.get(varName).get();
return variablesArrays.getArray(varName);
case CONSTANT:
if (!constantArrays.containsKey(varName))
if (!constantArrays.hasArray(varName))
return null;
return constantArrays.get(varName).get();
return constantArrays.getArray(varName);
case ARRAY:
//Only stored in inference session...
InferenceSession s = sessions.get(Thread.currentThread().getId());
@@ -781,31 +800,16 @@ public class SameDiff extends SDBaseOps {
sessions.put(Thread.currentThread().getId(), new InferenceSession(this));
}
boolean duped = false;
if (arr.isAttached()) {
arr = arr.detach();
duped = true;
}
if (arr.isView()) {
arr = arr.dup();
duped = true;
}
if (!duped && variable.getVariableType() == VariableType.VARIABLE) {
for (DeviceLocalNDArray otherArr : variablesArrays.values()) {
if (otherArr.get() == arr) { //Check for exact same object, to avoid array reuse (can result in unexpected behaviour)
arr = arr.dup();
break;
}
}
}
switch (variable.getVariableType()) {
case VARIABLE:
variablesArrays.put(variable.name(), new DeviceLocalNDArray(arr, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
variablesArrays.setArray(variable.name(), arr);
break;
case CONSTANT:
constantArrays.put(variable.name(), new DeviceLocalNDArray(arr, true));
constantArrays.setArray(variable.name(), arr);
break;
case ARRAY:
throw new UnsupportedOperationException("Cannot associate array with SDVariable of type ARRAY - arrays for" +
@@ -859,9 +863,9 @@ public class SameDiff extends SDBaseOps {
arr = arr.dup();
if(variable.getVariableType() == VariableType.VARIABLE ){
variablesArrays.get(variable.name()).update(arr);
variablesArrays.setArray(variable.name(), arr);
} else {
constantArrays.get(variable.name()).update(arr);
constantArrays.setArray(variable.name(), arr);
}
}
@@ -2715,7 +2719,7 @@ public class SameDiff extends SDBaseOps {
SDVariable v = new SDVariable(name, VariableType.CONSTANT, this, constant.shape(), constant.dataType());
name = v.name();
variables.put(name, Variable.builder().name(name).variable(v).build());
constantArrays.put(name, new DeviceLocalNDArray(constant, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
constantArrays.setArray(name, constant);
return v;
}
@@ -2792,7 +2796,7 @@ public class SameDiff extends SDBaseOps {
if(variableType == VariableType.VARIABLE){
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
INDArray vArr = weightInitScheme.create(dataType, shape);
variablesArrays.put(name, new DeviceLocalNDArray(vArr, true));
variablesArrays.setArray(name, vArr);
}
}
@@ -2924,7 +2928,7 @@ public class SameDiff extends SDBaseOps {
SDVariable r = new SDVariable(v.name(), v.getVariableType(), this, v.getShape(), v.dataType());
addVariable(r);
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){
variablesArrays.put(v.name(), new DeviceLocalNDArray(v.getArr().dup(), true));
variablesArrays.setArray(v.name(), v.getArr().dup());
}
return r;
case ARRAY:
@@ -3014,20 +3018,17 @@ public class SameDiff extends SDBaseOps {
arr = arr.detach();
duped = true;
}
if (arr.isView()) {
arr = arr.dup();
duped = true;
}
if (!duped) {
for (DeviceLocalNDArray otherArr : variablesArrays.values()) {
if (otherArr.get() == arr) { //Check for exact same object, to avoid array reuse (can result in unexpected behaviour)
for (String s : variablesArrays.arrayNames()) {
if (variablesArrays.getArray(s) == arr) { //Check for exact same object, to avoid array reuse (can result in unexpected behaviour)
arr = arr.dup();
break;
}
}
}
SDVariable ret = new SDVariable(name, VariableType.VARIABLE, this, arr.shape(), arr.dataType());
associateArrayWithVariable(arr, ret);
@@ -3085,8 +3086,8 @@ public class SameDiff extends SDBaseOps {
INDArray arr = variable.getArr();
Preconditions.checkNotNull(arr, "Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", variable);
constantArrays.put(n, new DeviceLocalNDArray(arr, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
variablesArrays.remove(n);
constantArrays.setArray(n, arr); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
variablesArrays.removeArray(n);
if (!placeholdersPerThread.isEmpty()) {
for (Map<String, INDArray> m : placeholdersPerThread.values()) {
m.remove(n);
@@ -3183,8 +3184,8 @@ public class SameDiff extends SDBaseOps {
INDArray arr = variable.getArr();
Preconditions.checkNotNull(arr, "Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", variable);
variablesArrays.put(n, new DeviceLocalNDArray(arr, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
constantArrays.remove(n);
variablesArrays.setArray(n, arr); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
constantArrays.removeArray(n);
if (!placeholdersPerThread.isEmpty()) {
for (Map<String, INDArray> m : placeholdersPerThread.values()) {
m.remove(n);
@@ -3260,16 +3261,14 @@ public class SameDiff extends SDBaseOps {
switch (v.getVariableType()) {
case VARIABLE:
DeviceLocalNDArray dl = variablesArrays.remove(e.getKey());
INDArray arr = dl.get();
INDArray arr = variablesArrays.removeArray(e.getKey());
INDArray newArr = arr.castTo(d);
variablesArrays.put(e.getKey(), new DeviceLocalNDArray(newArr, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
variablesArrays.setArray(e.getKey(), newArr); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
break;
case CONSTANT:
DeviceLocalNDArray dl2 = constantArrays.remove(e.getKey());
INDArray arr2 = dl2.get();
INDArray arr2 = constantArrays.removeArray(e.getKey());
INDArray newArr2 = arr2.castTo(d);
constantArrays.put(e.getKey(), new DeviceLocalNDArray(newArr2, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
constantArrays.setArray(e.getKey(), newArr2); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads
break;
case PLACEHOLDER:
Map<String, INDArray> m = placeholdersPerThread.get(Thread.currentThread().getId());
@@ -3409,14 +3408,12 @@ public class SameDiff extends SDBaseOps {
variables.remove(from);
variables.put(to, v);
if(v.getVariable().getVariableType() == VariableType.CONSTANT && constantArrays.containsKey(from)){
DeviceLocalNDArray dl = constantArrays.remove(from);
constantArrays.put(to, dl);
if(v.getVariable().getVariableType() == VariableType.CONSTANT && constantArrays.hasArray(from)){
constantArrays.rename(from, to);
}
if(v.getVariable().getVariableType() == VariableType.VARIABLE && variablesArrays.containsKey(from)){
DeviceLocalNDArray dl = variablesArrays.remove(from);
variablesArrays.put(to, dl);
if(v.getVariable().getVariableType() == VariableType.VARIABLE && variablesArrays.hasArray(from)){
variablesArrays.rename(from, to);
}
if(v.getVariable().getVariableType() == VariableType.PLACEHOLDER ){
@@ -4187,6 +4184,8 @@ public class SameDiff extends SDBaseOps {
@Override
public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> inputs, SDVariable[] variableInputs) {
sameDiff.setArrayHolders(new SingleThreadArrayHolder(), new SingleThreadArrayHolder(), false); //Training isn't thread safe, no need to use DeviceLocal, even with lazy init
//Propagate graph to this samediff instance which will also contain the backward
if (SameDiff.this.debugMode) {
sameDiff.enableDebugMode();
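
For completeness, a sketch of switching holders on a SameDiff instance that already owns arrays, based on the setArrayHolders/initFrom behaviour shown earlier in this file's diff: passing initialize=true copies the existing arrays into the new holders before the swap. The class name SwitchHoldersLater is illustrative.

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.array.SingleThreadArrayHolder;
import org.nd4j.linalg.factory.Nd4j;

public class SwitchHoldersLater {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable w = sd.var("w", Nd4j.rand(3, 4));   // stored in the default ThreadSafeArrayHolder

        // initialize=true: the new holders are populated from the current ones via initFrom,
        // so "w" keeps its array after the holders are replaced
        sd.setArrayHolders(new SingleThreadArrayHolder(), new SingleThreadArrayHolder(), true);
        System.out.println(sd.getArrForVarName("w"));  // still present
    }
}
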

File: SingleThreadArrayHolder.java (new file)

@@ -0,0 +1,66 @@
package org.nd4j.autodiff.samediff.array;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.ArrayHolder;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
/**
* A simple {@link ArrayHolder}, backed by a plain {@code Map<String, INDArray>} internally.
* Provides no thread safety guarantees.
*
* @author Alex Black
*/
public class SingleThreadArrayHolder implements ArrayHolder {
private final Map<String, INDArray> map = new HashMap<>();
@Override
public boolean hasArray(@NonNull String name) {
return map.containsKey(name);
}
@Override
public INDArray getArray(@NonNull String name) {
return map.get(name);
}
@Override
public void setArray(@NonNull String name, @NonNull INDArray array) {
map.put(name, array);
}
@Override
public INDArray removeArray(@NonNull String name) {
return map.remove(name);
}
@Override
public int size() {
return map.size();
}
@Override
public void initFrom(ArrayHolder arrayHolder) {
map.clear();
Collection<String> names = arrayHolder.arrayNames();
for (String n : names) {
map.put(n, arrayHolder.getArray(n));
}
}
@Override
public Collection<String> arrayNames() {
return Collections.unmodifiableCollection(map.keySet());
}
@Override
public void rename(String from, String to) {
INDArray arr = map.remove(from);
map.put(to, arr);
}
}

File: ThreadSafeArrayHolder.java (new file)

@@ -0,0 +1,85 @@
package org.nd4j.autodiff.samediff.array;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.ArrayHolder;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.util.DeviceLocalNDArray;
import java.util.Collection;
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
* An {@link ArrayHolder} that uses the thread-safe {@link DeviceLocalNDArray} internally.
*
* @author Alex Black
*/
public class ThreadSafeArrayHolder implements ArrayHolder {
private final Map<String, DeviceLocalNDArray> map = new ConcurrentHashMap<>();
private final boolean lazyInit;
/**
* @param lazyInit If true: use lazy initialization for {@link DeviceLocalNDArray}
*/
public ThreadSafeArrayHolder(boolean lazyInit) {
this.lazyInit = lazyInit;
}
@Override
public boolean hasArray(@NonNull String name) {
return map.containsKey(name);
}
@Override
public INDArray getArray(@NonNull String name) {
return map.get(name).get();
}
@Override
public void setArray(@NonNull String name, @NonNull INDArray array) {
if (array.isView())
array = array.dup(); //Device local doesn't support views
if (!map.containsKey(name)) {
DeviceLocalNDArray dla = new DeviceLocalNDArray(array, lazyInit);
map.put(name, dla);
} else {
DeviceLocalNDArray dla = map.get(name);
dla.update(array);
}
}
@Override
public INDArray removeArray(@NonNull String name) {
DeviceLocalNDArray arr = map.remove(name);
if (arr == null)
return null;
return arr.get();
}
@Override
public int size() {
return map.size();
}
@Override
public void initFrom(ArrayHolder arrayHolder) {
map.clear();
Collection<String> names = arrayHolder.arrayNames();
for (String n : names) {
setArray(n, arrayHolder.getArray(n));
}
}
@Override
public Collection<String> arrayNames() {
return Collections.unmodifiableCollection(map.keySet());
}
@Override
public void rename(@NonNull String from, @NonNull String to) {
DeviceLocalNDArray dl = map.remove(from);
map.put(to, dl);
}
}
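
One behavioural difference between the two implementations is worth noting: ThreadSafeArrayHolder duplicates views before wrapping them in DeviceLocalNDArray (which does not support views), while SingleThreadArrayHolder stores whatever reference it is given. Below is a small sketch of that difference; it assumes getRow returns a view, as it normally does in ND4J, and the class name is illustrative.

import org.nd4j.autodiff.samediff.ArrayHolder;
import org.nd4j.autodiff.samediff.array.SingleThreadArrayHolder;
import org.nd4j.autodiff.samediff.array.ThreadSafeArrayHolder;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class ViewHandlingDemo {
    public static void main(String[] args) {
        INDArray base = Nd4j.linspace(1, 12, 12).reshape(3, 4);
        INDArray rowView = base.getRow(0);                          // a view of "base"

        ArrayHolder single = new SingleThreadArrayHolder();
        single.setArray("row", rowView);
        System.out.println(single.getArray("row") == rowView);      // true: same object, still a view

        ArrayHolder threadSafe = new ThreadSafeArrayHolder(true);
        threadSafe.setArray("row", rowView);
        System.out.println(threadSafe.getArray("row") == rowView);  // false: the view was duplicated
        System.out.println(threadSafe.getArray("row").isView());    // false
    }
}
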