Add SameDiff memory reuse memory manager (array cache) (#39)

* Attention op comments

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* ArrayCacheMemoryMgr - first pass

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Tweak array cache for use with SameDiff identity arrays

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* ArrayCacheMemoryMgr javadoc and properly get max memory

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* LRU cache policy + add tests

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Resize arrays internally if required for ArrayCacheMemoryMgr

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Test improvement

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Small polish

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-11-12 21:15:44 +11:00 committed by GitHub
parent 0eda1e733e
commit 18c01f5bdc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 442 additions and 43 deletions

View File

@ -98,6 +98,7 @@ public class CompareTrainingImplementations extends BaseDL4JTest {
SDVariable diff = sd.f().squaredDifference(a1, label); SDVariable diff = sd.f().squaredDifference(a1, label);
SDVariable lossMse = diff.mean(); SDVariable lossMse = diff.mean();
lossMse.markAsLoss();
IUpdater updater; IUpdater updater;
double lr; double lr;

View File

@ -34,16 +34,16 @@ namespace nd4j {
auto numHeads = projectionMatrix->sizeAt(0); auto numHeads = projectionMatrix->sizeAt(0);
auto projectedSize = projectionMatrix->sizeAt(1); auto projectedSize = projectionMatrix->sizeAt(1);
auto inputPerm = input->permute({1, 0, 2}); auto inputPerm = input->permute({1, 0, 2}); //[batch, nIn, timeSteps] -> [nIn, batch, timeSteps]
auto inputPrep = inputPerm.reshape('c', {input->sizeAt(1), (miniBatchSize * seqLength)}); auto inputPrep = inputPerm.reshape('c', {input->sizeAt(1), (miniBatchSize * seqLength)}); //[nIn, batch*timeSteps]
auto projectionPrep = projectionMatrix->reshape('c', {numHeads * projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)}); auto projectionPrep = projectionMatrix->reshape('c', {numHeads * projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)}); //[nHeads, hS, nIn] -> [nHeads*hS, nIn]
NDArray projected('c', {numHeads * projectionMatrix->sizeAt(1), (miniBatchSize * seqLength)}, input->dataType(), context); NDArray projected('c', {numHeads * projectionMatrix->sizeAt(1), (miniBatchSize * seqLength)}, input->dataType(), context); //[nHeads*hS, batch*timeSteps]
nd4j::ops::matmul mmul; nd4j::ops::matmul mmul;
mmul.execute({&projectionPrep, &inputPrep}, {&projected}, {}, {}, {}); mmul.execute({&projectionPrep, &inputPrep}, {&projected}, {}, {}, {});
projected.reshapei({numHeads, projectedSize, miniBatchSize, seqLength}); projected.reshapei({numHeads, projectedSize, miniBatchSize, seqLength});
projected.permutei({2, 0, 1, 3}); projected.permutei({2, 0, 1, 3}); //[minibatch, numHeads, projectedSize, seqLength]
return projected; return projected;
} }

View File

@ -28,13 +28,13 @@ namespace nd4j {
namespace ops { namespace ops {
CUSTOM_OP_IMPL(multi_head_dot_product_attention, 7, -1, false, 0, 2) { CUSTOM_OP_IMPL(multi_head_dot_product_attention, 7, -1, false, 0, 2) {
auto queries = INPUT_VARIABLE(0); auto queries = INPUT_VARIABLE(0); //[batch, nIn, timeSteps]
auto keys = INPUT_VARIABLE(1); auto keys = INPUT_VARIABLE(1); //[batch, nIn, timeSteps]
auto values = INPUT_VARIABLE(2); auto values = INPUT_VARIABLE(2); //[batch, nIn, timeSteps]
auto Wq = INPUT_VARIABLE(3); auto Wq = INPUT_VARIABLE(3); //[nHeads, headSize, nIn]
auto Wk = INPUT_VARIABLE(4); auto Wk = INPUT_VARIABLE(4); //[nHeads, headSize, nIn]
auto Wv = INPUT_VARIABLE(5); auto Wv = INPUT_VARIABLE(5); //[nHeads, headSize, nIn]
auto Wo = INPUT_VARIABLE(6); auto Wo = INPUT_VARIABLE(6); //[nHeads * headSize, nOut]
auto mask = block.width() > 7 ? INPUT_VARIABLE(7) : nullptr; auto mask = block.width() > 7 ? INPUT_VARIABLE(7) : nullptr;
@ -93,11 +93,12 @@ namespace ops {
// Project queries, keys, values // Project queries, keys, values
auto projectedQueries = AttentionHelper::multiHeadProject(queries, Wq, block.launchContext()); auto projectedQueries = AttentionHelper::multiHeadProject(queries, Wq, block.launchContext()); //[minibatch, numHeads, projectedSize, seqLength]
auto projectedKeys = AttentionHelper::multiHeadProject(keys, Wk, block.launchContext()); auto projectedKeys = AttentionHelper::multiHeadProject(keys, Wk, block.launchContext()); //[minibatch, numHeads, projectedSize, seqLength]
auto projectedValues = AttentionHelper::multiHeadProject(values, Wv, block.launchContext()); auto projectedValues = AttentionHelper::multiHeadProject(values, Wv, block.launchContext()); //[minibatch, numHeads, projectedSize, seqLength]
// Apply Attention // Apply Attention
// attnResults = [minibatch, numHeads, projectedSize, seqLenth
NDArray attnResults('c', {projectedQueries.sizeAt(0), projectedValues.sizeAt(1), projectedValues.sizeAt(2), projectedQueries.sizeAt(3)}, projectedValues.dataType(), block.launchContext()); NDArray attnResults('c', {projectedQueries.sizeAt(0), projectedValues.sizeAt(1), projectedValues.sizeAt(2), projectedQueries.sizeAt(3)}, projectedValues.dataType(), block.launchContext());
nd4j::ops::dot_product_attention attention; nd4j::ops::dot_product_attention attention;
attention.execute({&projectedQueries, &projectedKeys, &projectedValues, mask}, {&attnResults, weights ? OUTPUT_VARIABLE(1) : nullptr}, {}, {normalization, weights}, {}); attention.execute({&projectedQueries, &projectedKeys, &projectedValues, mask}, {&attnResults, weights ? OUTPUT_VARIABLE(1) : nullptr}, {}, {normalization, weights}, {});

View File

@ -24,7 +24,7 @@ import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType; import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.internal.memory.ArrayCloseMemoryMgr; import org.nd4j.autodiff.samediff.internal.memory.ArrayCacheMemoryMgr;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.MemoryWorkspace;
@ -84,8 +84,7 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
public InferenceSession(@NonNull SameDiff sameDiff) { public InferenceSession(@NonNull SameDiff sameDiff) {
super(sameDiff); super(sameDiff);
mmgr = new ArrayCacheMemoryMgr();
mmgr = new ArrayCloseMemoryMgr(); //TODO replace this with new (planned) array reuse memory manager
} }
@Override @Override
@ -215,7 +214,6 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
} }
INDArray[] out = doExec(op.getOp(), outputFrameIter, opInputs, allIterInputs, constAndPhInputs); INDArray[] out = doExec(op.getOp(), outputFrameIter, opInputs, allIterInputs, constAndPhInputs);
op.getOp().clearArrays();
if (log.isTraceEnabled()) { if (log.isTraceEnabled()) {
StringBuilder sb = new StringBuilder(); StringBuilder sb = new StringBuilder();
@ -254,6 +252,7 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
} }
} }
} }
op.getOp().clearArrays();
//Record array uses for memory management/deallocation //Record array uses for memory management/deallocation
@ -842,12 +841,11 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
reqShape = reqShape.asDataType(dt); reqShape = reqShape.asDataType(dt);
} }
if (currOutput == null || currOutput.wasClosed() || !currOutput.shapeDescriptor().equals(reqShape) || currOutput.isEmpty() != reqShape.isEmpty() || isLoop) { //Always allocate new output array, rely on memory manager for efficient memory management and array reuse etc
boolean isOutput = allReqVariables.contains(outNames[i]); boolean isOutput = allReqVariables.contains(outNames[i]);
INDArray out = mmgr.allocate(isOutput, reqShape); INDArray out = mmgr.allocate(isOutput, reqShape);
customOp.setOutputArgument(i, out); customOp.setOutputArgument(i, out);
} }
}
} else if (df instanceof Op) { } else if (df instanceof Op) {
Op op = (Op) df; Op op = (Op) df;
@ -893,31 +891,19 @@ public class InferenceSession extends AbstractSession<INDArray, SameDiffOp> {
//Check output shape; allocate a new Z if required //Check output shape; allocate a new Z if required
//For example, if minibatch size has changed since last op execution //For example, if minibatch size has changed since last op execution
boolean isOutput = allReqVariables.contains(((BaseOp) op).outputVariablesNames()[0]);
if (emptyReduce) { if (emptyReduce) {
INDArray z = op.z(); //Always allocate new output array, rely on memory manager for efficient memory management and array reuse etc
if (z == null || !op.x().equalShapes(z) || isLoop) { INDArray z = mmgr.allocate(false, op.x().dataType(), op.x().shape());
//Note: edge case: [x,y].sum(empty) = [x,y] for TF import compatibility.
z = mmgr.allocate(false, op.x().dataType(), op.x().shape());
op.setZ(z); op.setZ(z);
}
} else { } else {
List<LongShapeDescriptor> outputShape = ((BaseOp) op).calculateOutputShape(); List<LongShapeDescriptor> outputShape = ((BaseOp) op).calculateOutputShape();
Preconditions.checkState(outputShape != null && outputShape.size() == 1, "Could not calculate output shape for op: %s", op.getClass()); Preconditions.checkState(outputShape != null && outputShape.size() == 1, "Could not calculate output shape for op: %s", op.getClass());
INDArray z = op.z();
if (z == null || z.wasClosed() || !outputShape.get(0).equals(z.shapeDescriptor()) || isLoop) {
if (log.isTraceEnabled()) {
log.trace("Existing op result (z) array shape for op {} was {}, allocating new array of shape {}",
op.getClass().getSimpleName(), (z == null ? null : Arrays.toString(z.shape())), outputShape.get(0).toString());
}
LongShapeDescriptor lsd = outputShape.get(0); LongShapeDescriptor lsd = outputShape.get(0);
INDArray z = mmgr.allocate(isOutput, lsd);
boolean isOutput = allReqVariables.contains(((BaseOp) op).outputVariablesNames()[0]);
z = mmgr.allocate(isOutput, lsd);
op.setZ(z); op.setZ(z);
} }
} }
}
return sdo; return sdo;
} }

View File

@ -0,0 +1,292 @@
package org.nd4j.autodiff.samediff.internal.memory;
import lombok.*;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
import java.util.*;
/**
* ArrayCacheMemoryMgr reuses arrays to reduce the number of memory allocations and deallocations.<br>
* Memory allocations and deallocations can be quite expensive, especially on GPUs.<br>
* Note that when arrays are reused, they are reused for the same datatype only.<br>
* If caching a released array would result in the the maximum cache size being is exceeded, the oldest arrays will
* be deallocated first, until the new array can in the cache.
* <br><br>
* By default, the following parameters are used for the cache:
* <ul>
* <li>Maximum cache size: 0.25 x max memory, where:</li>
* <ul>
* <li>CPU: max memory is determined using {@link Pointer#maxBytes()}</li>
* <li>GPU: max memory is determined using GPU 0 total memory</li>
* </ul>
* <li>Larger array max multiple: 2.0</li>
* <ul>
* <li>This means: if an exact array size can't be provided from the cache, use the next smallest array with a buffer up to 2.0x larger than requested</li>
* <li>If no cached arrays of size &lt; 2x requested exists, allocate a new array</li>
* </ul>
* <li>Small array threshold: 1024 elements</li>
* <ul>
* <li>This means: the "larger array max multiple" doesn't apply below this level. For example, we might return a size 1 array backed by a size 1023 buffer</li>
* </ul>
* </ul>
*
* @author Alex Black
*/
@Getter
public class ArrayCacheMemoryMgr extends AbstractMemoryMgr {
private final double maxMemFrac;
private final long smallArrayThreshold;
private final double largerArrayMaxMultiple;
private final long maxCacheBytes;
private final long totalMemBytes;
private long currentCacheSize = 0;
private Map<DataType, ArrayStore> arrayStores = new HashMap<>();
private LinkedHashSet<Long> lruCache = new LinkedHashSet<>();
private Map<Long,INDArray> lruCacheValues = new HashMap<>();
/**
* Create an ArrayCacheMemoryMgr with default settings as per {@link ArrayCacheMemoryMgr}
*/
public ArrayCacheMemoryMgr() {
this(0.25, 1024, 2.0);
}
/**
* @param maxMemFrac Maximum memory fraciton to use as cache
* @param smallArrayThreshold Below this size (elements), don't apply the "largerArrayMaxMultiple" rule
* @param largerArrayMaxMultiple Maximum multiple of the requested size to return from the cache. If an array of size
* 1024 is requested, and largerArrayMaxMultiple is 2.0, then we'll return from the cache
* the array with the smallest data buffer up to 2.0*1024 elements; otherwise we'll return
* a new array
*/
public ArrayCacheMemoryMgr(double maxMemFrac, long smallArrayThreshold, double largerArrayMaxMultiple) {
Preconditions.checkArgument(maxMemFrac > 0 && maxMemFrac < 1, "Maximum memory fraction for cache must be between 0.0 and 1.0, got %s", maxMemFrac);
Preconditions.checkArgument(smallArrayThreshold >= 0, "Small array threshould must be >= 0, got %s", smallArrayThreshold);
Preconditions.checkArgument(largerArrayMaxMultiple >= 1.0, "Larger array max multiple must be >= 1.0, got %s", largerArrayMaxMultiple);
this.maxMemFrac = maxMemFrac;
this.smallArrayThreshold = smallArrayThreshold;
this.largerArrayMaxMultiple = largerArrayMaxMultiple;
if(isCpu()){
totalMemBytes = Pointer.maxBytes();
} else {
Properties p = Nd4j.getExecutioner().getEnvironmentInformation();
List devList = (List) p.get("cuda.devicesInformation");
Map m = (Map) devList.get(0);
totalMemBytes = (Long)m.get("cuda.totalMemory");
}
maxCacheBytes = (long)(maxMemFrac * totalMemBytes);
}
private boolean isCpu(){
String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
return !"CUDA".equalsIgnoreCase(backend);
}
@Override
public INDArray allocate(boolean detached, DataType dataType, long... shape) {
if (arrayStores.containsKey(dataType)) {
INDArray arr = arrayStores.get(dataType).get(shape);
if (arr != null) {
//Decrement cache size
currentCacheSize -= dataType.width() * arr.data().length();
return arr; //Allocated from cache
}
}
//Allocation failed, allocate new array
return Nd4j.createUninitializedDetached(dataType, shape);
}
@Override
public INDArray allocate(boolean detached, LongShapeDescriptor descriptor) {
return allocate(detached, descriptor.dataType(), descriptor.getShape());
}
@Override
public void release(@NonNull INDArray array) {
//Check for multiple releases of the array
long id = array.getId();
Preconditions.checkState(!lruCache.contains(id), "Array was released multiple times: id=%s, shape=%ndShape", id, array);
DataType dt = array.dataType();
long thisBytes = array.data().length() * dt.width();
if(array.dataType() == DataType.UTF8) {
//Don't cache string arrays due to variable length buffers
if(array.closeable())
array.close();
} else if (currentCacheSize + thisBytes > maxCacheBytes) {
if(thisBytes > maxCacheBytes){
//Can't store even if we clear everything - too large
if(array.closeable())
array.close();
return;
}
//Need to deallocate some arrays to stay under limit - do in "oldest first" order
Iterator<Long> iter = lruCache.iterator();
while(currentCacheSize + thisBytes > maxCacheBytes){
long next = iter.next();
iter.remove();
INDArray nextOldest = lruCacheValues.remove(next);
DataType ndt = nextOldest.dataType();
long nextBytes = ndt.width() * nextOldest.data().length();
arrayStores.get(ndt).removeObject(nextOldest);
currentCacheSize -= nextBytes;
if(nextOldest.closeable())
nextOldest.close();
}
//After clearing space - can now cache
cacheArray(array);
} else {
//OK to cache
cacheArray(array);
}
//Store in LRU cache for "last used" removal if we exceed cache size
lruCache.add(array.getId());
lruCacheValues.put(array.getId(), array);
}
private void cacheArray(INDArray array){
DataType dt = array.dataType();
if (!arrayStores.containsKey(dt))
arrayStores.put(dt, new ArrayStore());
arrayStores.get(dt).add(array);
currentCacheSize += array.data().length() * dt.width();
lruCache.add(array.getId());
lruCacheValues.put(array.getId(), array);
}
@Override
public void close() {
for (ArrayStore as : arrayStores.values()) {
as.close();
}
}
@Getter
public class ArrayStore {
private INDArray[] sorted = new INDArray[1000]; //TODO resizing, don't hardcode
private long[] lengths = new long[1000];
private long lengthSum;
private long bytesSum;
private int size;
private void add(@NonNull INDArray array) {
//Resize arrays
if(size == sorted.length){
sorted = Arrays.copyOf(sorted, 2*sorted.length);
lengths = Arrays.copyOf(lengths, 2*lengths.length);
}
long length = array.data().length();
int idx = Arrays.binarySearch(lengths, 0, size, length);
if (idx < 0) {
idx = -idx - 1; //See binarySearch javadoc
}
for (int i = size - 1; i >= idx; i--) {
sorted[i + 1] = sorted[i];
lengths[i + 1] = lengths[i];
}
sorted[idx] = array;
lengths[idx] = length;
size++;
lengthSum += length;
bytesSum += length * array.dataType().width();
}
private INDArray get(long[] shape) {
if (size == 0)
return null;
long length = shape.length == 0 ? 1 : ArrayUtil.prod(shape);
int idx = Arrays.binarySearch(lengths, 0, size, length);
if (idx < 0) {
idx = -idx - 1;
if (idx >= size) {
//Largest array is smaller than required -> can't return from cache
return null;
}
INDArray nextSmallest = sorted[idx];
long nextSmallestLength = nextSmallest.data().length();
long nextSmallestLengthBytes = nextSmallestLength * nextSmallest.dataType().width();
boolean tooLarge = (length > (long) (nextSmallestLength * largerArrayMaxMultiple));
if (nextSmallestLengthBytes > smallArrayThreshold && tooLarge) {
return null;
} // If less than smallArrayThreshold, ok, return as is
}
//Remove
INDArray arr = removeIdx(idx);
lruCache.remove(arr.getId());
lruCacheValues.remove(arr.getId());
//Create a new array with the specified buffer. This is for 2 reasons:
//(a) the cached array and requested array sizes may differ (though this is easy to check for)
//(b) Some SameDiff array use tracking uses *object identity* - so we want different objects when reusing arrays
// to avoid issues there
return Nd4j.create(arr.data(), shape);
}
private void removeObject(INDArray array){
long length = array.data().length();
int idx = Arrays.binarySearch(lengths, 0, size, length);
Preconditions.checkState(idx > 0, "Cannot remove array from ArrayStore: no array with this length exists in the cache");
boolean found = false;
int i = 0;
while(!found && i <= size && lengths[i] == length){
found = sorted[i++] == array; //Object equality
}
Preconditions.checkState(found, "Cannot remove array: not found in ArrayCache");
removeIdx(i - 1);
}
private INDArray removeIdx(int idx){
INDArray arr = sorted[idx];
for (int i = idx; i < size; i++) {
sorted[i] = sorted[i + 1];
lengths[i] = lengths[i + 1];
}
sorted[size] = null;
lengths[size] = 0;
size--;
bytesSum -= (arr.data().length() * arr.dataType().width());
lengthSum -= arr.data().length();
return arr;
}
private void close() {
for (int i = 0; i < size; i++) {
if (sorted[i].closeable())
sorted[i].close();
lengths[i] = 0;
}
lengthSum = 0;
bytesSum = 0;
size = 0;
}
}
}

View File

@ -0,0 +1,119 @@
package org.nd4j.autodiff.samediff;
import org.junit.Test;
import org.nd4j.autodiff.samediff.internal.memory.ArrayCacheMemoryMgr;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import java.lang.reflect.Field;
import static org.junit.Assert.*;
public class MemoryMgrTest extends BaseNd4jTest {
public MemoryMgrTest(Nd4jBackend b){
super(b);
}
@Override
public char ordering(){
return 'c';
}
@Test
public void testArrayReuseTooLarge() throws Exception {
ArrayCacheMemoryMgr mmgr = new ArrayCacheMemoryMgr();
Field f = ArrayCacheMemoryMgr.class.getDeclaredField("maxCacheBytes");
f.setAccessible(true);
f.set(mmgr, 1000);
assertEquals(1000, mmgr.getMaxCacheBytes());
INDArray[] arrays = new INDArray[100];
for( int i=0; i<arrays.length; i++ ){
arrays[i] = Nd4j.create(DataType.FLOAT, 25); //100 bytes each
}
for( int i=0; i<10; i++ ){
mmgr.release(arrays[i]);
}
assertEquals(1000, mmgr.getCurrentCacheSize());
ArrayCacheMemoryMgr.ArrayStore as = mmgr.getArrayStores().get(DataType.FLOAT);
assertEquals(1000, as.getBytesSum());
assertEquals(250, as.getLengthSum());
assertEquals(10, as.getSize());
assertEquals(10, mmgr.getLruCache().size());
assertEquals(10, mmgr.getLruCacheValues().size());
//At this point: array store is full.
//If we try to release more, the oldest (first released) values should be closed
for( int i=0; i<10; i++ ) {
INDArray toRelease = Nd4j.create(DataType.FLOAT, 25);
mmgr.release(toRelease);
//oldest N only should be closed by this point...
for( int j=0; j<10; j++ ){
if(j <= i){
//Should have been closed
assertTrue(arrays[j].wasClosed());
} else {
//Should still be open
assertFalse(arrays[j].wasClosed());
}
}
}
assertEquals(1000, mmgr.getCurrentCacheSize());
assertEquals(1000, as.getBytesSum());
assertEquals(250, as.getLengthSum());
assertEquals(10, as.getSize());
assertEquals(10, mmgr.getLruCache().size());
assertEquals(10, mmgr.getLruCacheValues().size());
//now, allocate some values:
for( int i=1; i<=10; i++ ) {
INDArray a1 = mmgr.allocate(true, DataType.FLOAT, 25);
assertEquals(1000 - i * 100, mmgr.getCurrentCacheSize());
assertEquals(1000 - i * 100, as.getBytesSum());
assertEquals(250 - i * 25, as.getLengthSum());
assertEquals(10 - i, as.getSize());
assertEquals(10 - i, mmgr.getLruCache().size());
assertEquals(10 - i, mmgr.getLruCacheValues().size());
}
assertEquals(0, mmgr.getCurrentCacheSize());
assertEquals(0, as.getBytesSum());
assertEquals(0, as.getLengthSum());
assertEquals(0, as.getSize());
assertEquals(0, mmgr.getLruCache().size());
assertEquals(0, mmgr.getLruCacheValues().size());
}
@Test
public void testManyArrays(){
ArrayCacheMemoryMgr mmgr = new ArrayCacheMemoryMgr();
for( int i=0; i<1000; i++ ){
mmgr.release(Nd4j.scalar(0));
}
assertEquals(4*1000, mmgr.getCurrentCacheSize());
assertEquals(1000, mmgr.getLruCache().size());
assertEquals(1000, mmgr.getLruCacheValues().size());
for( int i=0; i<1000; i++ ){
mmgr.release(Nd4j.scalar(0));
}
assertEquals(4*2000, mmgr.getCurrentCacheSize());
assertEquals(2000, mmgr.getLruCache().size());
assertEquals(2000, mmgr.getLruCacheValues().size());
}
}