diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffGraphVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffGraphVertex.java index 1d2abe2b6..712265e05 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffGraphVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffGraphVertex.java @@ -31,6 +31,7 @@ import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.array.SingleThreadArrayHolder; import org.nd4j.autodiff.samediff.internal.InferenceSession; import org.nd4j.autodiff.samediff.internal.SessionMemMgr; import org.nd4j.base.Preconditions; @@ -133,16 +134,6 @@ public class SameDiffGraphVertex extends BaseGraphVertex { } is.setMmgr(mmgr); - - - if(paramTable != null && paramTable.size() > 0) { - //Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration - //TODO Find a more efficient solution for this - for (Map.Entry e : paramTable.entrySet()) { - INDArray arr = e.getValue(); - sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey())); - } - } INDArray result = sameDiff.outputSingle(phMap, outputKey); //Edge case: "vertex" is just an identity activation, for example @@ -212,17 +203,8 @@ public class SameDiffGraphVertex extends BaseGraphVertex { String epsName = fn.getGradPlaceholderName(); phMap.put(epsName, epsilon); - - //Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration - //TODO Find a more efficient solution for this - List required = new ArrayList<>(inputNames.size()); //Ensure that the input placeholder gradients are calculated - for (Map.Entry e : paramTable.entrySet()) { - INDArray arr = e.getValue(); - sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey())); - } - + List required = new ArrayList<>(config.getVertexParams().getInputs()); //Ensure that the input placeholder gradients are calculated required.addAll(paramTable.keySet()); - required.addAll(inputNames); Map gradsMap = sameDiff.calculateGradients(phMap, required); for(String s : paramTable.keySet() ){ @@ -279,6 +261,8 @@ public class SameDiffGraphVertex extends BaseGraphVertex { protected void doInit(){ try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { sameDiff = SameDiff.create(); + //Use SingleThreadArrayHolder so we can use views (also don't nede multithreading here, DL4J is not thread safe) + sameDiff.setArrayHolders(new SingleThreadArrayHolder(), new SingleThreadArrayHolder(), false); inputVars = new LinkedHashMap<>(); LinkedHashMap maskVars = new LinkedHashMap<>(); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java index 38a0f4075..fcf899544 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java @@ -26,6 +26,7 @@ import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.AbstractLayer; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.array.SingleThreadArrayHolder; import org.nd4j.autodiff.samediff.internal.InferenceSession; import org.nd4j.autodiff.samediff.internal.SessionMemMgr; import org.nd4j.base.Preconditions; @@ -100,13 +101,6 @@ public class SameDiffLayer extends AbstractLayer { phMap.put(MASK_KEY, layerConf().onesMaskForInput(input)); } - //Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration - //TODO Find a more efficient solution for this - for (Map.Entry e : paramTable.entrySet()) { - INDArray arr = e.getValue(); - sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey())); - } - //Configure memory management for SameDiff instance - use DL4J workspaces String wsNameWorking = workspaceMgr.getWorkspaceName(ArrayType.FF_WORKING_MEM); String wsNameOutput = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATIONS); @@ -179,13 +173,6 @@ public class SameDiffLayer extends AbstractLayer { org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf(); bl.validateInput(input); - //Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration - //TODO Find a more efficient solution for this - for (Map.Entry e : paramTable.entrySet()) { - INDArray arr = e.getValue(); - sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey())); - } - Map phMap = new HashMap<>(); phMap.put(INPUT_KEY, input); phMap.put(fn.getGradPlaceholderName(), epsilon); @@ -300,6 +287,8 @@ public class SameDiffLayer extends AbstractLayer { try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf(); sameDiff = SameDiff.create(); + //Use SingleThreadArrayHolder so we can use views (also don't nede multithreading here, DL4J is not thread safe) + sameDiff.setArrayHolders(new SingleThreadArrayHolder(), new SingleThreadArrayHolder(), false); Map p = paramTable(); long[] inputShape = input.shape().clone(); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffOutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffOutputLayer.java index 35c44d17d..184c0ea16 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffOutputLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffOutputLayer.java @@ -29,6 +29,7 @@ import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.array.SingleThreadArrayHolder; import org.nd4j.autodiff.samediff.internal.InferenceSession; import org.nd4j.autodiff.samediff.internal.SessionMemMgr; import org.nd4j.base.Preconditions; @@ -119,15 +120,6 @@ public class SameDiffOutputLayer extends AbstractLayer e : paramTable.entrySet()) { - INDArray arr = e.getValue(); - sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey())); - } - Map phMap = new HashMap<>(); phMap.put(INPUT_KEY, input); if(!activations && layerConf().labelsRequired() && labels != null) { @@ -193,13 +185,6 @@ public class SameDiffOutputLayer extends AbstractLayer e : paramTable.entrySet()) { - INDArray arr = e.getValue(); - sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey())); - } - List gradVarNames = new ArrayList<>(); gradVarNames.addAll(paramTable.keySet()); gradVarNames.add(INPUT_KEY); @@ -317,6 +302,8 @@ public class SameDiffOutputLayer extends AbstractLayer p = paramTable(); long[] inputShape = input.shape().clone(); @@ -339,7 +326,6 @@ public class SameDiffOutputLayer extends AbstractLayer e : p.entrySet()) { INDArray arr = e.getValue(); sameDiff.associateArrayWithVariable(arr, sameDiff.getVariable(e.getKey())); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ArrayHolder.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ArrayHolder.java new file mode 100644 index 000000000..9c8f59357 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ArrayHolder.java @@ -0,0 +1,69 @@ +package org.nd4j.autodiff.samediff; + +import org.nd4j.linalg.api.ndarray.INDArray; + +import java.util.Collection; + +/** + * Holds a set of arrays keyed by a String name, functioning essentially like a {@code Map}.
+ * Implementations may have different internal ways of storing arrays, however.
+ * For example for single threaded applications: {@link org.nd4j.autodiff.samediff.array.SingleThreadArrayHolder}
+ * And for multi-threaded: {@link org.nd4j.autodiff.samediff.array.ThreadSafeArrayHolder} + * + * @author Alex Black + */ +public interface ArrayHolder { + + /** + * @return True if an array by that name exists + */ + boolean hasArray(String name); + + /** + * @param name Name of the array to get + * @return The array, or null if no array with that name exists + */ + INDArray getArray(String name); + + /** + * Set the array for the specified name (new array, or replace if it already exists) + * + * @param name Name of the array + * @param array Array to set + */ + void setArray(String name, INDArray array); + + /** + * Remove the array from the ArrayHolder, returning it (if it exists) + * + * @param name Name of the array to return + * @return The now-removed array + */ + INDArray removeArray(String name); + + /** + * @return Number of arrays in the ArrayHolder + */ + int size(); + + /** + * Initialize from the specified array holder. + * This clears all internal arrays, and adds all arrays from the specified array holder + * + * @param arrayHolder Array holder to initialize this based on + */ + void initFrom(ArrayHolder arrayHolder); + + /** + * @return Names of the arrays currently in the ArrayHolder + */ + Collection arrayNames(); + + /** + * Rename the entry with the specified name + * + * @param from Original name + * @param to New name + */ + void rename(String from, String to); +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index c79677c1e..449c2ef78 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -30,6 +30,8 @@ import org.nd4j.autodiff.listeners.impl.HistoryListener; import org.nd4j.autodiff.listeners.records.History; import org.nd4j.autodiff.listeners.records.LossCurve; import org.nd4j.autodiff.samediff.api.OutAndGrad; +import org.nd4j.autodiff.samediff.array.SingleThreadArrayHolder; +import org.nd4j.autodiff.samediff.array.ThreadSafeArrayHolder; import org.nd4j.autodiff.samediff.config.BatchOutputConfig; import org.nd4j.autodiff.samediff.config.EvaluationConfig; import org.nd4j.autodiff.samediff.config.FitConfig; @@ -122,8 +124,8 @@ public class SameDiff extends SDBaseOps { @Getter private final Map sessions = new ConcurrentHashMap<>(); //Key: thread ID - private final Map constantArrays = new ConcurrentHashMap<>(); - private final Map variablesArrays = new ConcurrentHashMap<>(); //TODO issues with DeviceLocal + mutable / changed during training? + private ArrayHolder constantArrays = new ThreadSafeArrayHolder(true); + private ArrayHolder variablesArrays = new ThreadSafeArrayHolder(true); private final Map> placeholdersPerThread = new ConcurrentHashMap<>(); //Placeholders for each thread - if the user sets them private final List lossVariables = new ArrayList<>(); @@ -346,6 +348,23 @@ public class SameDiff extends SDBaseOps { return listeners; } + /** + * Set the array holders for variable and constant arrays
+ * NOTE: this is usually reserved for developers and internal use, and should not be needed by almost all users
+ * See {@link ArrayHolder} for more details + * + * @param variableArrayHolder Array holder for variable arrays + * @param constantArrayHolder Array holder for constant arrays + * @param initialize If true: transfer any arrays from the current array holders to the new/specified ones + */ + public void setArrayHolders(@NonNull ArrayHolder variableArrayHolder, @NonNull ArrayHolder constantArrayHolder, boolean initialize){ + if(initialize){ + variableArrayHolder.initFrom(this.variablesArrays); + constantArrayHolder.initFrom(this.constantArrays); + } + this.variablesArrays = variableArrayHolder; + this.constantArrays = constantArrayHolder; + } /** * @return The current name scope, if any (null otherwise). See {@link #withNameScope(String)} for more details. @@ -674,9 +693,9 @@ public class SameDiff extends SDBaseOps { SDVariable v = getVariable(varName); if (v.isConstant()) { - constantArrays.put(varName, new DeviceLocalNDArray(arr, true)); + constantArrays.setArray(varName, arr); } else if (v.getVariableType() == VariableType.VARIABLE) { - variablesArrays.put(varName, new DeviceLocalNDArray(arr, true)); + variablesArrays.setArray(varName, arr); } else if (v.isPlaceHolder()) { long tid = Thread.currentThread().getId(); if (!placeholdersPerThread.containsKey(tid)) { @@ -699,12 +718,12 @@ public class SameDiff extends SDBaseOps { SDVariable var = getVariable(varName); switch (var.getVariableType()) { case VARIABLE: - return variablesArrays.containsKey(varName); + return variablesArrays.hasArray(varName); case ARRAY: long tid = Thread.currentThread().getId(); return sessions.containsKey(tid) && sessions.get(tid).contains(varName, InferenceSession.OUTER_FRAME, 0, null); case CONSTANT: - return constantArrays.containsKey(varName); + return constantArrays.hasArray(varName); case PLACEHOLDER: return placeholdersPerThread.containsKey(Thread.currentThread().getId()) && placeholdersPerThread.get(Thread.currentThread().getId()).containsKey(varName); @@ -724,11 +743,11 @@ public class SameDiff extends SDBaseOps { SDVariable v = variables.get(varName).getVariable(); switch (v.getVariableType()) { case VARIABLE: - return variablesArrays.get(varName).get(); + return variablesArrays.getArray(varName); case CONSTANT: - if (!constantArrays.containsKey(varName)) + if (!constantArrays.hasArray(varName)) return null; - return constantArrays.get(varName).get(); + return constantArrays.getArray(varName); case ARRAY: //Only stored in inference session... InferenceSession s = sessions.get(Thread.currentThread().getId()); @@ -781,31 +800,16 @@ public class SameDiff extends SDBaseOps { sessions.put(Thread.currentThread().getId(), new InferenceSession(this)); } - boolean duped = false; if (arr.isAttached()) { arr = arr.detach(); - duped = true; - } - if (arr.isView()) { - arr = arr.dup(); - duped = true; - } - - if (!duped && variable.getVariableType() == VariableType.VARIABLE) { - for (DeviceLocalNDArray otherArr : variablesArrays.values()) { - if (otherArr.get() == arr) { //Check for exact same object, to avoid array reuse (can result in unexpected behaviour) - arr = arr.dup(); - break; - } - } } switch (variable.getVariableType()) { case VARIABLE: - variablesArrays.put(variable.name(), new DeviceLocalNDArray(arr, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads + variablesArrays.setArray(variable.name(), arr); break; case CONSTANT: - constantArrays.put(variable.name(), new DeviceLocalNDArray(arr, true)); + constantArrays.setArray(variable.name(), arr); break; case ARRAY: throw new UnsupportedOperationException("Cannot associate array with SDVariable of type ARRAY - arrays for" + @@ -859,9 +863,9 @@ public class SameDiff extends SDBaseOps { arr = arr.dup(); if(variable.getVariableType() == VariableType.VARIABLE ){ - variablesArrays.get(variable.name()).update(arr); + variablesArrays.setArray(variable.name(), arr); } else { - constantArrays.get(variable.name()).update(arr); + constantArrays.setArray(variable.name(), arr); } } @@ -2715,7 +2719,7 @@ public class SameDiff extends SDBaseOps { SDVariable v = new SDVariable(name, VariableType.CONSTANT, this, constant.shape(), constant.dataType()); name = v.name(); variables.put(name, Variable.builder().name(name).variable(v).build()); - constantArrays.put(name, new DeviceLocalNDArray(constant, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads + constantArrays.setArray(name, constant); return v; } @@ -2792,7 +2796,7 @@ public class SameDiff extends SDBaseOps { if(variableType == VariableType.VARIABLE){ try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { INDArray vArr = weightInitScheme.create(dataType, shape); - variablesArrays.put(name, new DeviceLocalNDArray(vArr, true)); + variablesArrays.setArray(name, vArr); } } @@ -2924,7 +2928,7 @@ public class SameDiff extends SDBaseOps { SDVariable r = new SDVariable(v.name(), v.getVariableType(), this, v.getShape(), v.dataType()); addVariable(r); try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){ - variablesArrays.put(v.name(), new DeviceLocalNDArray(v.getArr().dup(), true)); + variablesArrays.setArray(v.name(), v.getArr().dup()); } return r; case ARRAY: @@ -3014,20 +3018,17 @@ public class SameDiff extends SDBaseOps { arr = arr.detach(); duped = true; } - if (arr.isView()) { - arr = arr.dup(); - duped = true; - } if (!duped) { - for (DeviceLocalNDArray otherArr : variablesArrays.values()) { - if (otherArr.get() == arr) { //Check for exact same object, to avoid array reuse (can result in unexpected behaviour) + for (String s : variablesArrays.arrayNames()) { + if (variablesArrays.getArray(s) == arr) { //Check for exact same object, to avoid array reuse (can result in unexpected behaviour) arr = arr.dup(); break; } } } + SDVariable ret = new SDVariable(name, VariableType.VARIABLE, this, arr.shape(), arr.dataType()); associateArrayWithVariable(arr, ret); @@ -3085,8 +3086,8 @@ public class SameDiff extends SDBaseOps { INDArray arr = variable.getArr(); Preconditions.checkNotNull(arr, "Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", variable); - constantArrays.put(n, new DeviceLocalNDArray(arr, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads - variablesArrays.remove(n); + constantArrays.setArray(n, arr); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads + variablesArrays.removeArray(n); if (!placeholdersPerThread.isEmpty()) { for (Map m : placeholdersPerThread.values()) { m.remove(n); @@ -3183,8 +3184,8 @@ public class SameDiff extends SDBaseOps { INDArray arr = variable.getArr(); Preconditions.checkNotNull(arr, "Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", variable); - variablesArrays.put(n, new DeviceLocalNDArray(arr, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads - constantArrays.remove(n); + variablesArrays.setArray(n, arr); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads + constantArrays.removeArray(n); if (!placeholdersPerThread.isEmpty()) { for (Map m : placeholdersPerThread.values()) { m.remove(n); @@ -3260,16 +3261,14 @@ public class SameDiff extends SDBaseOps { switch (v.getVariableType()) { case VARIABLE: - DeviceLocalNDArray dl = variablesArrays.remove(e.getKey()); - INDArray arr = dl.get(); + INDArray arr = variablesArrays.removeArray(e.getKey()); INDArray newArr = arr.castTo(d); - variablesArrays.put(e.getKey(), new DeviceLocalNDArray(newArr, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads + variablesArrays.setArray(e.getKey(), newArr); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads break; case CONSTANT: - DeviceLocalNDArray dl2 = constantArrays.remove(e.getKey()); - INDArray arr2 = dl2.get(); + INDArray arr2 = constantArrays.removeArray(e.getKey()); INDArray newArr2 = arr2.castTo(d); - constantArrays.put(e.getKey(), new DeviceLocalNDArray(newArr2, true)); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads + constantArrays.setArray(e.getKey(), newArr2); //DeviceLocal with delayed initialization, in case we don't actually need multiple threads break; case PLACEHOLDER: Map m = placeholdersPerThread.get(Thread.currentThread().getId()); @@ -3409,14 +3408,12 @@ public class SameDiff extends SDBaseOps { variables.remove(from); variables.put(to, v); - if(v.getVariable().getVariableType() == VariableType.CONSTANT && constantArrays.containsKey(from)){ - DeviceLocalNDArray dl = constantArrays.remove(from); - constantArrays.put(to, dl); + if(v.getVariable().getVariableType() == VariableType.CONSTANT && constantArrays.hasArray(from)){ + constantArrays.rename(from, to); } - if(v.getVariable().getVariableType() == VariableType.VARIABLE && variablesArrays.containsKey(from)){ - DeviceLocalNDArray dl = variablesArrays.remove(from); - variablesArrays.put(to, dl); + if(v.getVariable().getVariableType() == VariableType.VARIABLE && variablesArrays.hasArray(from)){ + variablesArrays.rename(from, to); } if(v.getVariable().getVariableType() == VariableType.PLACEHOLDER ){ @@ -4187,6 +4184,8 @@ public class SameDiff extends SDBaseOps { @Override public SDVariable[] define(SameDiff sameDiff, Map inputs, SDVariable[] variableInputs) { + sameDiff.setArrayHolders(new SingleThreadArrayHolder(), new SingleThreadArrayHolder(), false); //Training isn't thread safe, no need to use DeviceLocal, even with lazy init + //Propagate graph to this samediff instance which will also contain the backward if (SameDiff.this.debugMode) { sameDiff.enableDebugMode(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/array/SingleThreadArrayHolder.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/array/SingleThreadArrayHolder.java new file mode 100644 index 000000000..3f67b57bd --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/array/SingleThreadArrayHolder.java @@ -0,0 +1,66 @@ +package org.nd4j.autodiff.samediff.array; + +import lombok.NonNull; +import org.nd4j.autodiff.samediff.ArrayHolder; +import org.nd4j.linalg.api.ndarray.INDArray; + +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +/** + * A simple {@link ArrayHolder} that uses a simple {@code Map} internally. + * No thread safety guarantees + * + * @author Alex Black + */ +public class SingleThreadArrayHolder implements ArrayHolder { + + private final Map map = new HashMap<>(); + + @Override + public boolean hasArray(@NonNull String name) { + return map.containsKey(name); + } + + @Override + public INDArray getArray(@NonNull String name) { + return map.get(name); + } + + @Override + public void setArray(@NonNull String name, @NonNull INDArray array) { + map.put(name, array); + } + + @Override + public INDArray removeArray(@NonNull String name) { + return map.remove(name); + } + + @Override + public int size() { + return map.size(); + } + + @Override + public void initFrom(ArrayHolder arrayHolder) { + map.clear(); + Collection names = arrayHolder.arrayNames(); + for (String n : names) { + map.put(n, arrayHolder.getArray(n)); + } + } + + @Override + public Collection arrayNames() { + return Collections.unmodifiableCollection(map.keySet()); + } + + @Override + public void rename(String from, String to) { + INDArray arr = map.remove(from); + map.put(to, arr); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/array/ThreadSafeArrayHolder.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/array/ThreadSafeArrayHolder.java new file mode 100644 index 000000000..34832d45f --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/array/ThreadSafeArrayHolder.java @@ -0,0 +1,85 @@ +package org.nd4j.autodiff.samediff.array; + +import lombok.NonNull; +import org.nd4j.autodiff.samediff.ArrayHolder; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.util.DeviceLocalNDArray; + +import java.util.Collection; +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * An {@link ArrayHolder} that uses the thread safe {@link DeviceLocalNDArray} internally + * + * @author Alex Black + */ +public class ThreadSafeArrayHolder implements ArrayHolder { + + private final Map map = new ConcurrentHashMap<>(); + private final boolean lazyInit; + + /** + * @param lazyInit If true: use lazy initialization for {@link DeviceLocalNDArray} + */ + public ThreadSafeArrayHolder(boolean lazyInit) { + this.lazyInit = lazyInit; + } + + @Override + public boolean hasArray(@NonNull String name) { + return map.containsKey(name); + } + + @Override + public INDArray getArray(@NonNull String name) { + return map.get(name).get(); + } + + @Override + public void setArray(@NonNull String name, @NonNull INDArray array) { + if (array.isView()) + array = array.dup(); //Device local doesn't support views + if (!map.containsKey(name)) { + DeviceLocalNDArray dla = new DeviceLocalNDArray(array, lazyInit); + map.put(name, dla); + } else { + DeviceLocalNDArray dla = map.get(name); + dla.update(array); + } + } + + @Override + public INDArray removeArray(@NonNull String name) { + DeviceLocalNDArray arr = map.remove(name); + if (arr == null) + return null; + return arr.get(); + } + + @Override + public int size() { + return map.size(); + } + + @Override + public void initFrom(ArrayHolder arrayHolder) { + map.clear(); + Collection names = arrayHolder.arrayNames(); + for (String n : names) { + setArray(n, arrayHolder.getArray(n)); + } + } + + @Override + public Collection arrayNames() { + return Collections.unmodifiableCollection(map.keySet()); + } + + @Override + public void rename(@NonNull String from, @NonNull String to) { + DeviceLocalNDArray dl = map.remove(from); + map.put(to, dl); + } +}