Add SameDiff memory reuse memory manager (array cache) (#39)
* Attention op comments
* ArrayCacheMemoryMgr - first pass
* Tweak array cache for use with SameDiff identity arrays
* ArrayCacheMemoryMgr javadoc and properly get max memory
* LRU cache policy + add tests
* Fixes
* Resize arrays internally if required for ArrayCacheMemoryMgr
* Test improvement
* Small polish

Signed-off-by: AlexDBlack <blacka101@gmail.com>

parent 0eda1e733e
commit 18c01f5bdc
@@ -98,6 +98,7 @@ public class CompareTrainingImplementations extends BaseDL4JTest {
         SDVariable diff = sd.f().squaredDifference(a1, label);
         SDVariable lossMse = diff.mean();
+        lossMse.markAsLoss();
 
         IUpdater updater;
         double lr;
 
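The added markAsLoss() call is the piece this test now relies on: SameDiff wants the loss variable identified explicitly before training. A minimal sketch of the pattern, assuming the same SameDiff APIs used in the diff above (the class name and placeholder shapes are illustrative, not from the PR):

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;

public class MarkLossSketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 3);
        SDVariable a1 = sd.placeHolder("a1", DataType.FLOAT, -1, 3);

        //Squared difference, reduced to a scalar MSE value
        SDVariable diff = sd.f().squaredDifference(a1, label);
        SDVariable lossMse = diff.mean();

        //Explicitly mark the scalar as the loss to minimize during training
        lossMse.markAsLoss();
    }
}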
@@ -34,16 +34,16 @@ namespace nd4j {
     auto numHeads = projectionMatrix->sizeAt(0);
     auto projectedSize = projectionMatrix->sizeAt(1);
 
-    auto inputPerm = input->permute({1, 0, 2});
-    auto inputPrep = inputPerm.reshape('c', {input->sizeAt(1), (miniBatchSize * seqLength)});
-    auto projectionPrep = projectionMatrix->reshape('c', {numHeads * projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)});
+    auto inputPerm = input->permute({1, 0, 2});    //[batch, nIn, timeSteps] -> [nIn, batch, timeSteps]
+    auto inputPrep = inputPerm.reshape('c', {input->sizeAt(1), (miniBatchSize * seqLength)});    //[nIn, batch*timeSteps]
+    auto projectionPrep = projectionMatrix->reshape('c', {numHeads * projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)});    //[nHeads, hS, nIn] -> [nHeads*hS, nIn]
 
-    NDArray projected('c', {numHeads * projectionMatrix->sizeAt(1), (miniBatchSize * seqLength)}, input->dataType(), context);
+    NDArray projected('c', {numHeads * projectionMatrix->sizeAt(1), (miniBatchSize * seqLength)}, input->dataType(), context);    //[nHeads*hS, batch*timeSteps]
     nd4j::ops::matmul mmul;
     mmul.execute({&projectionPrep, &inputPrep}, {&projected}, {}, {}, {});
 
     projected.reshapei({numHeads, projectedSize, miniBatchSize, seqLength});
-    projected.permutei({2, 0, 1, 3});
+    projected.permutei({2, 0, 1, 3});    //[minibatch, numHeads, projectedSize, seqLength]
 
     return projected;
 }
 
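The shape comments added above can be traced with plain ND4J from Java; a small sketch mirroring the same permute/reshape/matmul pipeline as the C++ helper (all sizes and the class name are illustrative):

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class MultiHeadProjectShapes {
    public static void main(String[] args) {
        int batch = 2, nIn = 4, timeSteps = 5, nHeads = 3, headSize = 6;

        INDArray input = Nd4j.rand(DataType.FLOAT, batch, nIn, timeSteps);      //[batch, nIn, timeSteps]
        INDArray projection = Nd4j.rand(DataType.FLOAT, nHeads, headSize, nIn); //[nHeads, headSize, nIn]

        INDArray inputPerm = input.permute(1, 0, 2);                            //[nIn, batch, timeSteps]
        INDArray inputPrep = inputPerm.reshape('c', nIn, batch * timeSteps);    //[nIn, batch*timeSteps]
        INDArray projPrep = projection.reshape('c', nHeads * headSize, nIn);    //[nHeads*headSize, nIn]

        INDArray projected = projPrep.mmul(inputPrep);                          //[nHeads*headSize, batch*timeSteps]
        projected = projected.reshape('c', nHeads, headSize, batch, timeSteps)
                .permute(2, 0, 1, 3);                                           //[batch, nHeads, headSize, timeSteps]

        System.out.println(java.util.Arrays.toString(projected.shape()));       //[2, 3, 6, 5]
    }
}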
@@ -28,13 +28,13 @@ namespace nd4j {
 namespace ops {
 
     CUSTOM_OP_IMPL(multi_head_dot_product_attention, 7, -1, false, 0, 2) {
-        auto queries = INPUT_VARIABLE(0);
-        auto keys = INPUT_VARIABLE(1);
-        auto values = INPUT_VARIABLE(2);
-        auto Wq = INPUT_VARIABLE(3);
-        auto Wk = INPUT_VARIABLE(4);
-        auto Wv = INPUT_VARIABLE(5);
-        auto Wo = INPUT_VARIABLE(6);
+        auto queries = INPUT_VARIABLE(0); //[batch, nIn, timeSteps]
+        auto keys = INPUT_VARIABLE(1); //[batch, nIn, timeSteps]
+        auto values = INPUT_VARIABLE(2); //[batch, nIn, timeSteps]
+        auto Wq = INPUT_VARIABLE(3); //[nHeads, headSize, nIn]
+        auto Wk = INPUT_VARIABLE(4); //[nHeads, headSize, nIn]
+        auto Wv = INPUT_VARIABLE(5); //[nHeads, headSize, nIn]
+        auto Wo = INPUT_VARIABLE(6); //[nHeads * headSize, nOut]
         auto mask = block.width() > 7 ? INPUT_VARIABLE(7) : nullptr;
 
 
@@ -93,11 +93,12 @@ namespace ops {
 
 
         // Project queries, keys, values
-        auto projectedQueries = AttentionHelper::multiHeadProject(queries, Wq, block.launchContext());
-        auto projectedKeys = AttentionHelper::multiHeadProject(keys, Wk, block.launchContext());
-        auto projectedValues = AttentionHelper::multiHeadProject(values, Wv, block.launchContext());
+        auto projectedQueries = AttentionHelper::multiHeadProject(queries, Wq, block.launchContext()); //[minibatch, numHeads, projectedSize, seqLength]
+        auto projectedKeys = AttentionHelper::multiHeadProject(keys, Wk, block.launchContext()); //[minibatch, numHeads, projectedSize, seqLength]
+        auto projectedValues = AttentionHelper::multiHeadProject(values, Wv, block.launchContext()); //[minibatch, numHeads, projectedSize, seqLength]
 
         // Apply Attention
+        // attnResults = [minibatch, numHeads, projectedSize, seqLength]
         NDArray attnResults('c', {projectedQueries.sizeAt(0), projectedValues.sizeAt(1), projectedValues.sizeAt(2), projectedQueries.sizeAt(3)}, projectedValues.dataType(), block.launchContext());
         nd4j::ops::dot_product_attention attention;
         attention.execute({&projectedQueries, &projectedKeys, &projectedValues, mask}, {&attnResults, weights ? OUTPUT_VARIABLE(1) : nullptr}, {}, {normalization, weights}, {});
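For reference, a hedged sketch of driving this op from Java. It assumes the SameDiff wrapper sd.nn().multiHeadDotProductAttention(...) with this argument order and a [batch, timeSteps] mask, so treat the call signature, mask shape, and class name as assumptions rather than confirmed API:

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;

public class MultiHeadAttentionSketch {
    public static void main(String[] args) {
        int batch = 2, nIn = 4, timeSteps = 5, nHeads = 3, headSize = 6, nOut = 4;

        SameDiff sd = SameDiff.create();
        //Input shapes follow the comments in the op implementation above
        SDVariable q = sd.var("queries", Nd4j.rand(DataType.FLOAT, batch, nIn, timeSteps));
        SDVariable k = sd.var("keys", Nd4j.rand(DataType.FLOAT, batch, nIn, timeSteps));
        SDVariable v = sd.var("values", Nd4j.rand(DataType.FLOAT, batch, nIn, timeSteps));
        SDVariable wq = sd.var("Wq", Nd4j.rand(DataType.FLOAT, nHeads, headSize, nIn));
        SDVariable wk = sd.var("Wk", Nd4j.rand(DataType.FLOAT, nHeads, headSize, nIn));
        SDVariable wv = sd.var("Wv", Nd4j.rand(DataType.FLOAT, nHeads, headSize, nIn));
        SDVariable wo = sd.var("Wo", Nd4j.rand(DataType.FLOAT, nHeads * headSize, nOut));
        SDVariable mask = sd.var("mask", Nd4j.ones(DataType.FLOAT, batch, timeSteps));

        //scaled=true applies the usual 1/sqrt(headSize) normalization
        SDVariable out = sd.nn().multiHeadDotProductAttention(q, k, v, wq, wk, wv, wo, mask, true);
        System.out.println(java.util.Arrays.toString(out.eval().shape())); //Expect [batch, nOut, timeSteps]
    }
}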
@@ -24,7 +24,7 @@ import org.nd4j.autodiff.listeners.Listener;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.autodiff.samediff.VariableType;
-import org.nd4j.autodiff.samediff.internal.memory.ArrayCloseMemoryMgr;
+import org.nd4j.autodiff.samediff.internal.memory.ArrayCacheMemoryMgr;
 import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.memory.MemoryWorkspace;
@@ -84,8 +84,7 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
 
     public InferenceSession(@NonNull SameDiff sameDiff) {
         super(sameDiff);
-
-        mmgr = new ArrayCloseMemoryMgr(); //TODO replace this with new (planned) array reuse memory manager
+        mmgr = new ArrayCacheMemoryMgr();
     }
 
     @Override
@@ -215,7 +214,6 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
         }
 
         INDArray[] out = doExec(op.getOp(), outputFrameIter, opInputs, allIterInputs, constAndPhInputs);
-        op.getOp().clearArrays();
 
         if (log.isTraceEnabled()) {
             StringBuilder sb = new StringBuilder();
@@ -254,6 +252,7 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
                 }
             }
         }
+        op.getOp().clearArrays();
 
 
         //Record array uses for memory management/deallocation
@@ -842,11 +841,10 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
                 reqShape = reqShape.asDataType(dt);
             }
 
-            if (currOutput == null || currOutput.wasClosed() || !currOutput.shapeDescriptor().equals(reqShape) || currOutput.isEmpty() != reqShape.isEmpty() || isLoop) {
-                boolean isOutput = allReqVariables.contains(outNames[i]);
-                INDArray out = mmgr.allocate(isOutput, reqShape);
-                customOp.setOutputArgument(i, out);
-            }
+            //Always allocate new output array, rely on memory manager for efficient memory management and array reuse etc
+            boolean isOutput = allReqVariables.contains(outNames[i]);
+            INDArray out = mmgr.allocate(isOutput, reqShape);
+            customOp.setOutputArgument(i, out);
         }
 
     } else if (df instanceof Op) {
@@ -893,29 +891,17 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
 
             //Check output shape; allocate a new Z if required
             //For example, if minibatch size has changed since last op execution
+            boolean isOutput = allReqVariables.contains(((BaseOp) op).outputVariablesNames()[0]);
             if (emptyReduce) {
-                INDArray z = op.z();
-                if (z == null || !op.x().equalShapes(z) || isLoop) {
-                    //Note: edge case: [x,y].sum(empty) = [x,y] for TF import compatibility.
-                    z = mmgr.allocate(false, op.x().dataType(), op.x().shape());
-                    op.setZ(z);
-                }
+                //Always allocate new output array, rely on memory manager for efficient memory management and array reuse etc
+                INDArray z = mmgr.allocate(false, op.x().dataType(), op.x().shape());
+                op.setZ(z);
             } else {
                 List<LongShapeDescriptor> outputShape = ((BaseOp) op).calculateOutputShape();
                 Preconditions.checkState(outputShape != null && outputShape.size() == 1, "Could not calculate output shape for op: %s", op.getClass());
-                INDArray z = op.z();
-                if (z == null || z.wasClosed() || !outputShape.get(0).equals(z.shapeDescriptor()) || isLoop) {
-                    if (log.isTraceEnabled()) {
-                        log.trace("Existing op result (z) array shape for op {} was {}, allocating new array of shape {}",
-                                op.getClass().getSimpleName(), (z == null ? null : Arrays.toString(z.shape())), outputShape.get(0).toString());
-                    }
-
-                    LongShapeDescriptor lsd = outputShape.get(0);
-
-                    boolean isOutput = allReqVariables.contains(((BaseOp) op).outputVariablesNames()[0]);
-                    z = mmgr.allocate(isOutput, lsd);
-                    op.setZ(z);
-                }
+                LongShapeDescriptor lsd = outputShape.get(0);
+                INDArray z = mmgr.allocate(isOutput, lsd);
+                op.setZ(z);
             }
         }
 
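The InferenceSession changes above all follow one pattern: op outputs are always requested from the session memory manager and handed back once consumed, so the reuse policy lives in one place instead of per-op shape checks. A minimal sketch of that allocate/release contract using the new manager (the class name MemMgrPatternSketch and the shapes are illustrative):

import org.nd4j.autodiff.samediff.internal.memory.ArrayCacheMemoryMgr;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;

public class MemMgrPatternSketch {
    public static void main(String[] args) {
        ArrayCacheMemoryMgr mmgr = new ArrayCacheMemoryMgr();

        //Op execution: always allocate the output via the memory manager
        INDArray z = mmgr.allocate(false, DataType.FLOAT, 32, 128);

        // ... op writes into z, downstream ops consume it ...

        //Once no longer referenced, return it to the cache instead of deallocating
        mmgr.release(z);

        //A later request of a similar size can now be served from the cache
        INDArray z2 = mmgr.allocate(false, DataType.FLOAT, 32, 128);

        mmgr.close(); //Deallocate everything still cached
    }
}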
@@ -0,0 +1,292 @@
package org.nd4j.autodiff.samediff.internal.memory;

import lombok.*;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;

import java.util.*;

/**
 * ArrayCacheMemoryMgr reuses arrays to reduce the number of memory allocations and deallocations.<br>
 * Memory allocations and deallocations can be quite expensive, especially on GPUs.<br>
 * Note that when arrays are reused, they are reused for the same datatype only.<br>
 * If caching a released array would result in the maximum cache size being exceeded, the oldest arrays will
 * be deallocated first, until the new array fits in the cache.
 * <br><br>
 * By default, the following parameters are used for the cache:
 * <ul>
 * <li>Maximum cache size: 0.25 x max memory, where:</li>
 * <ul>
 * <li>CPU: max memory is determined using {@link Pointer#maxBytes()}</li>
 * <li>GPU: max memory is determined using GPU 0 total memory</li>
 * </ul>
 * <li>Larger array max multiple: 2.0</li>
 * <ul>
 * <li>This means: if an exact array size can't be provided from the cache, use the next smallest array with a buffer up to 2.0x larger than requested</li>
 * <li>If no cached array of size < 2x requested exists, allocate a new array</li>
 * </ul>
 * <li>Small array threshold: 1024 elements</li>
 * <ul>
 * <li>This means: the "larger array max multiple" doesn't apply below this level. For example, we might return a size 1 array backed by a size 1023 buffer</li>
 * </ul>
 * </ul>
 *
 * @author Alex Black
 */
@Getter
public class ArrayCacheMemoryMgr extends AbstractMemoryMgr {

    private final double maxMemFrac;
    private final long smallArrayThreshold;
    private final double largerArrayMaxMultiple;

    private final long maxCacheBytes;
    private final long totalMemBytes;

    private long currentCacheSize = 0;
    private Map<DataType, ArrayStore> arrayStores = new HashMap<>();

    private LinkedHashSet<Long> lruCache = new LinkedHashSet<>();
    private Map<Long, INDArray> lruCacheValues = new HashMap<>();

    /**
     * Create an ArrayCacheMemoryMgr with default settings as per {@link ArrayCacheMemoryMgr}
     */
    public ArrayCacheMemoryMgr() {
        this(0.25, 1024, 2.0);
    }

    /**
     * @param maxMemFrac             Maximum memory fraction to use as cache
     * @param smallArrayThreshold    Below this size (elements), don't apply the "largerArrayMaxMultiple" rule
     * @param largerArrayMaxMultiple Maximum multiple of the requested size to return from the cache. If an array of size
     *                               1024 is requested, and largerArrayMaxMultiple is 2.0, then we'll return from the cache
     *                               the array with the smallest data buffer up to 2.0*1024 elements; otherwise we'll return
     *                               a new array
     */
    public ArrayCacheMemoryMgr(double maxMemFrac, long smallArrayThreshold, double largerArrayMaxMultiple) {
        Preconditions.checkArgument(maxMemFrac > 0 && maxMemFrac < 1, "Maximum memory fraction for cache must be between 0.0 and 1.0, got %s", maxMemFrac);
        Preconditions.checkArgument(smallArrayThreshold >= 0, "Small array threshold must be >= 0, got %s", smallArrayThreshold);
        Preconditions.checkArgument(largerArrayMaxMultiple >= 1.0, "Larger array max multiple must be >= 1.0, got %s", largerArrayMaxMultiple);
        this.maxMemFrac = maxMemFrac;
        this.smallArrayThreshold = smallArrayThreshold;
        this.largerArrayMaxMultiple = largerArrayMaxMultiple;

        if (isCpu()) {
            totalMemBytes = Pointer.maxBytes();
        } else {
            Properties p = Nd4j.getExecutioner().getEnvironmentInformation();
            List devList = (List) p.get("cuda.devicesInformation");
            Map m = (Map) devList.get(0);
            totalMemBytes = (Long) m.get("cuda.totalMemory");
        }
        maxCacheBytes = (long) (maxMemFrac * totalMemBytes);
    }

    private boolean isCpu() {
        String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
        return !"CUDA".equalsIgnoreCase(backend);
    }

    @Override
    public INDArray allocate(boolean detached, DataType dataType, long... shape) {
        if (arrayStores.containsKey(dataType)) {
            INDArray arr = arrayStores.get(dataType).get(shape);
            if (arr != null) {
                //Decrement cache size
                currentCacheSize -= dataType.width() * arr.data().length();

                return arr; //Allocated from cache
            }
        }

        //Allocation failed, allocate new array
        return Nd4j.createUninitializedDetached(dataType, shape);
    }

    @Override
    public INDArray allocate(boolean detached, LongShapeDescriptor descriptor) {
        return allocate(detached, descriptor.dataType(), descriptor.getShape());
    }

    @Override
    public void release(@NonNull INDArray array) {
        //Check for multiple releases of the array
        long id = array.getId();
        Preconditions.checkState(!lruCache.contains(id), "Array was released multiple times: id=%s, shape=%ndShape", id, array);

        DataType dt = array.dataType();
        long thisBytes = array.data().length() * dt.width();
        if (array.dataType() == DataType.UTF8) {
            //Don't cache string arrays due to variable length buffers
            if (array.closeable())
                array.close();
        } else if (currentCacheSize + thisBytes > maxCacheBytes) {
            if (thisBytes > maxCacheBytes) {
                //Can't store even if we clear everything - too large
                if (array.closeable())
                    array.close();
                return;
            }

            //Need to deallocate some arrays to stay under limit - do in "oldest first" order
            Iterator<Long> iter = lruCache.iterator();
            while (currentCacheSize + thisBytes > maxCacheBytes) {
                long next = iter.next();
                iter.remove();
                INDArray nextOldest = lruCacheValues.remove(next);
                DataType ndt = nextOldest.dataType();
                long nextBytes = ndt.width() * nextOldest.data().length();
                arrayStores.get(ndt).removeObject(nextOldest);
                currentCacheSize -= nextBytes;

                if (nextOldest.closeable())
                    nextOldest.close();
            }

            //After clearing space - can now cache
            cacheArray(array);
        } else {
            //OK to cache
            cacheArray(array);
        }

        //Store in LRU cache for "last used" removal if we exceed cache size
        lruCache.add(array.getId());
        lruCacheValues.put(array.getId(), array);
    }

    private void cacheArray(INDArray array) {
        DataType dt = array.dataType();
        if (!arrayStores.containsKey(dt))
            arrayStores.put(dt, new ArrayStore());
        arrayStores.get(dt).add(array);
        currentCacheSize += array.data().length() * dt.width();

        lruCache.add(array.getId());
        lruCacheValues.put(array.getId(), array);
    }

    @Override
    public void close() {
        for (ArrayStore as : arrayStores.values()) {
            as.close();
        }
    }


    @Getter
    public class ArrayStore {
        private INDArray[] sorted = new INDArray[1000]; //TODO resizing, don't hardcode
        private long[] lengths = new long[1000];
        private long lengthSum;
        private long bytesSum;
        private int size;

        private void add(@NonNull INDArray array) {
            //Resize arrays
            if (size == sorted.length) {
                sorted = Arrays.copyOf(sorted, 2 * sorted.length);
                lengths = Arrays.copyOf(lengths, 2 * lengths.length);
            }

            long length = array.data().length();
            int idx = Arrays.binarySearch(lengths, 0, size, length);
            if (idx < 0) {
                idx = -idx - 1; //See binarySearch javadoc
            }
            for (int i = size - 1; i >= idx; i--) {
                sorted[i + 1] = sorted[i];
                lengths[i + 1] = lengths[i];
            }
            sorted[idx] = array;
            lengths[idx] = length;
            size++;
            lengthSum += length;
            bytesSum += length * array.dataType().width();
        }

        private INDArray get(long[] shape) {
            if (size == 0)
                return null;

            long length = shape.length == 0 ? 1 : ArrayUtil.prod(shape);

            int idx = Arrays.binarySearch(lengths, 0, size, length);
            if (idx < 0) {
                idx = -idx - 1;
                if (idx >= size) {
                    //Largest array is smaller than required -> can't return from cache
                    return null;
                }
                INDArray nextSmallest = sorted[idx];
                long nextSmallestLength = nextSmallest.data().length();
                long nextSmallestLengthBytes = nextSmallestLength * nextSmallest.dataType().width();

                boolean tooLarge = (nextSmallestLength > (long) (length * largerArrayMaxMultiple));

                if (nextSmallestLengthBytes > smallArrayThreshold && tooLarge) {
                    return null;
                } // If less than smallArrayThreshold, ok, return as is
            }

            //Remove
            INDArray arr = removeIdx(idx);

            lruCache.remove(arr.getId());
            lruCacheValues.remove(arr.getId());

            //Create a new array with the specified buffer. This is for 2 reasons:
            //(a) the cached array and requested array sizes may differ (though this is easy to check for)
            //(b) Some SameDiff array use tracking uses *object identity* - so we want different objects when reusing arrays
            //    to avoid issues there
            return Nd4j.create(arr.data(), shape);
        }

        private void removeObject(INDArray array) {
            long length = array.data().length();
            int idx = Arrays.binarySearch(lengths, 0, size, length);
            Preconditions.checkState(idx >= 0, "Cannot remove array from ArrayStore: no array with this length exists in the cache");
            boolean found = false;
            int i = 0;
            while (!found && i < size) {
                found = sorted[i++] == array; //Object equality
            }
            Preconditions.checkState(found, "Cannot remove array: not found in ArrayCache");
            removeIdx(i - 1);
        }

        private INDArray removeIdx(int idx) {
            INDArray arr = sorted[idx];
            for (int i = idx; i < size; i++) {
                sorted[i] = sorted[i + 1];
                lengths[i] = lengths[i + 1];
            }
            sorted[size] = null;
            lengths[size] = 0;
            size--;

            bytesSum -= (arr.data().length() * arr.dataType().width());
            lengthSum -= arr.data().length();

            return arr;
        }

        private void close() {
            for (int i = 0; i < size; i++) {
                if (sorted[i].closeable())
                    sorted[i].close();
                lengths[i] = 0;
            }
            lengthSum = 0;
            bytesSum = 0;
            size = 0;
        }
    }
}
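To make the cache policy concrete, a short sketch of the reuse arithmetic under the default settings (0.25 max memory fraction, 1024 small-array threshold, 2.0 larger-array max multiple); the sizes and class name are illustrative, and behavior follows the javadoc above:

import org.nd4j.autodiff.samediff.internal.memory.ArrayCacheMemoryMgr;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class ArrayCacheSketch {
    public static void main(String[] args) {
        ArrayCacheMemoryMgr mmgr = new ArrayCacheMemoryMgr();

        //Release a 3000-element FLOAT array into the cache (3000 x 4 = 12000 bytes)
        mmgr.release(Nd4j.createUninitializedDetached(DataType.FLOAT, 3000));

        //Request 2000 elements: no exact match; the next cached buffer holds 3000 elements.
        //3000 <= 2.0 * 2000, so the cached buffer is reused. A new INDArray object is
        //returned over the old buffer, since SameDiff array tracking uses object identity.
        INDArray reused = mmgr.allocate(false, DataType.FLOAT, 2000);

        //Request 2000 elements again: the FLOAT store is now empty,
        //so a fresh detached array is allocated instead
        INDArray fresh = mmgr.allocate(false, DataType.FLOAT, 2000);

        mmgr.close(); //Deallocate anything still cached
    }
}

Had the cached buffer held, say, 5000 elements (more than 2.0x the 2000 requested), the request would fall through to a new allocation instead, keeping oversized buffers from being pinned to small arrays.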
@@ -0,0 +1,119 @@
package org.nd4j.autodiff.samediff;

import org.junit.Test;
import org.nd4j.autodiff.samediff.internal.memory.ArrayCacheMemoryMgr;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;

import java.lang.reflect.Field;

import static org.junit.Assert.*;

public class MemoryMgrTest extends BaseNd4jTest {

    public MemoryMgrTest(Nd4jBackend b) {
        super(b);
    }

    @Override
    public char ordering() {
        return 'c';
    }

    @Test
    public void testArrayReuseTooLarge() throws Exception {

        ArrayCacheMemoryMgr mmgr = new ArrayCacheMemoryMgr();
        Field f = ArrayCacheMemoryMgr.class.getDeclaredField("maxCacheBytes");
        f.setAccessible(true);
        f.set(mmgr, 1000);

        assertEquals(1000, mmgr.getMaxCacheBytes());

        INDArray[] arrays = new INDArray[100];
        for (int i = 0; i < arrays.length; i++) {
            arrays[i] = Nd4j.create(DataType.FLOAT, 25); //100 bytes each
        }

        for (int i = 0; i < 10; i++) {
            mmgr.release(arrays[i]);
        }

        assertEquals(1000, mmgr.getCurrentCacheSize());
        ArrayCacheMemoryMgr.ArrayStore as = mmgr.getArrayStores().get(DataType.FLOAT);
        assertEquals(1000, as.getBytesSum());
        assertEquals(250, as.getLengthSum());
        assertEquals(10, as.getSize());
        assertEquals(10, mmgr.getLruCache().size());
        assertEquals(10, mmgr.getLruCacheValues().size());


        //At this point: array store is full.
        //If we try to release more, the oldest (first released) values should be closed
        for (int i = 0; i < 10; i++) {
            INDArray toRelease = Nd4j.create(DataType.FLOAT, 25);
            mmgr.release(toRelease);
            //oldest N only should be closed by this point...
            for (int j = 0; j < 10; j++) {
                if (j <= i) {
                    //Should have been closed
                    assertTrue(arrays[j].wasClosed());
                } else {
                    //Should still be open
                    assertFalse(arrays[j].wasClosed());
                }
            }
        }


        assertEquals(1000, mmgr.getCurrentCacheSize());
        assertEquals(1000, as.getBytesSum());
        assertEquals(250, as.getLengthSum());
        assertEquals(10, as.getSize());
        assertEquals(10, mmgr.getLruCache().size());
        assertEquals(10, mmgr.getLruCacheValues().size());

        //now, allocate some values:
        for (int i = 1; i <= 10; i++) {
            INDArray a1 = mmgr.allocate(true, DataType.FLOAT, 25);
            assertEquals(1000 - i * 100, mmgr.getCurrentCacheSize());
            assertEquals(1000 - i * 100, as.getBytesSum());
            assertEquals(250 - i * 25, as.getLengthSum());
            assertEquals(10 - i, as.getSize());
            assertEquals(10 - i, mmgr.getLruCache().size());
            assertEquals(10 - i, mmgr.getLruCacheValues().size());
        }

        assertEquals(0, mmgr.getCurrentCacheSize());
        assertEquals(0, as.getBytesSum());
        assertEquals(0, as.getLengthSum());
        assertEquals(0, as.getSize());
        assertEquals(0, mmgr.getLruCache().size());
        assertEquals(0, mmgr.getLruCacheValues().size());
    }

    @Test
    public void testManyArrays() {

        ArrayCacheMemoryMgr mmgr = new ArrayCacheMemoryMgr();
        for (int i = 0; i < 1000; i++) {
            mmgr.release(Nd4j.scalar(0));
        }

        assertEquals(4 * 1000, mmgr.getCurrentCacheSize());
        assertEquals(1000, mmgr.getLruCache().size());
        assertEquals(1000, mmgr.getLruCacheValues().size());

        for (int i = 0; i < 1000; i++) {
            mmgr.release(Nd4j.scalar(0));
        }

        assertEquals(4 * 2000, mmgr.getCurrentCacheSize());
        assertEquals(2000, mmgr.getLruCache().size());
        assertEquals(2000, mmgr.getLruCacheValues().size());
    }

}