From 18c01f5bdc4b2349869b861aaebfbc4c7c9e9c21 Mon Sep 17 00:00:00 2001
From: Alex Black
Date: Tue, 12 Nov 2019 21:15:44 +1100
Subject: [PATCH] Add SameDiff memory reuse memory manager (array cache) (#39)

* Attention op comments

Signed-off-by: AlexDBlack

* ArrayCacheMemoryMgr - first pass

Signed-off-by: AlexDBlack

* Tweak array cache for use with SameDiff identity arrays

Signed-off-by: AlexDBlack

* ArrayCacheMemoryMgr javadoc and properly get max memory

Signed-off-by: AlexDBlack

* LRU cache policy + add tests

Signed-off-by: AlexDBlack

* Fixes

Signed-off-by: AlexDBlack

* Resize arrays internally if required for ArrayCacheMemoryMgr

Signed-off-by: AlexDBlack

* Test improvement

Signed-off-by: AlexDBlack

* Small polish

Signed-off-by: AlexDBlack
---
 .../CompareTrainingImplementations.java       |   1 +
 .../include/helpers/impl/AttentionHelper.cpp  |  10 +-
 .../nn/multi_head_dot_product_attention.cpp   |  21 +-
 .../samediff/internal/InferenceSession.java   |  42 +--
 .../internal/memory/ArrayCacheMemoryMgr.java  | 292 ++++++++++++++++++
 .../nd4j/autodiff/samediff/MemoryMgrTest.java | 119 +++++++
 6 files changed, 442 insertions(+), 43 deletions(-)
 create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/memory/ArrayCacheMemoryMgr.java
 create mode 100644 nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/MemoryMgrTest.java

diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java
index 12564f01a..fa0fc335f 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java
@@ -98,6 +98,7 @@ public class CompareTrainingImplementations extends BaseDL4JTest {
 
         SDVariable diff = sd.f().squaredDifference(a1, label);
         SDVariable lossMse = diff.mean();
+        lossMse.markAsLoss();
 
         IUpdater updater;
         double lr;
diff --git a/libnd4j/include/helpers/impl/AttentionHelper.cpp b/libnd4j/include/helpers/impl/AttentionHelper.cpp
index 4e7393a8e..3cfee1c08 100644
--- a/libnd4j/include/helpers/impl/AttentionHelper.cpp
+++ b/libnd4j/include/helpers/impl/AttentionHelper.cpp
@@ -34,16 +34,16 @@ namespace nd4j {
         auto numHeads = projectionMatrix->sizeAt(0);
         auto projectedSize = projectionMatrix->sizeAt(1);
 
-        auto inputPerm = input->permute({1, 0, 2});
-        auto inputPrep = inputPerm.reshape('c', {input->sizeAt(1), (miniBatchSize * seqLength)});
-        auto projectionPrep = projectionMatrix->reshape('c', {numHeads * projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)});
+        auto inputPerm = input->permute({1, 0, 2});  //[batch, nIn, timeSteps] -> [nIn, batch, timeSteps]
+        auto inputPrep = inputPerm.reshape('c', {input->sizeAt(1), (miniBatchSize * seqLength)});  //[nIn, batch*timeSteps]
+        auto projectionPrep = projectionMatrix->reshape('c', {numHeads * projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)});  //[nHeads, hS, nIn] -> [nHeads*hS, nIn]
 
-        NDArray projected('c', {numHeads * projectionMatrix->sizeAt(1), (miniBatchSize * seqLength)}, input->dataType(), context);
+        NDArray projected('c', {numHeads * projectionMatrix->sizeAt(1), (miniBatchSize * seqLength)}, input->dataType(), context);  //[nHeads*hS, batch*timeSteps]
         nd4j::ops::matmul mmul;
         mmul.execute({&projectionPrep, &inputPrep}, {&projected}, {}, {}, {});
         projected.reshapei({numHeads, projectedSize, miniBatchSize, seqLength});
-        projected.permutei({2, 0, 1, 3});
+        projected.permutei({2, 0, 1, 3});  //[minibatch, numHeads, projectedSize, seqLength]
 
         return projected;
     }
diff --git a/libnd4j/include/ops/declarable/generic/nn/multi_head_dot_product_attention.cpp b/libnd4j/include/ops/declarable/generic/nn/multi_head_dot_product_attention.cpp
index 45324300d..2123317b5 100644
--- a/libnd4j/include/ops/declarable/generic/nn/multi_head_dot_product_attention.cpp
+++ b/libnd4j/include/ops/declarable/generic/nn/multi_head_dot_product_attention.cpp
@@ -28,13 +28,13 @@ namespace nd4j {
 namespace ops {
 
     CUSTOM_OP_IMPL(multi_head_dot_product_attention, 7, -1, false, 0, 2) {
-        auto queries = INPUT_VARIABLE(0);
-        auto keys = INPUT_VARIABLE(1);
-        auto values = INPUT_VARIABLE(2);
-        auto Wq = INPUT_VARIABLE(3);
-        auto Wk = INPUT_VARIABLE(4);
-        auto Wv = INPUT_VARIABLE(5);
-        auto Wo = INPUT_VARIABLE(6);
+        auto queries = INPUT_VARIABLE(0);  //[batch, nIn, timeSteps]
+        auto keys = INPUT_VARIABLE(1);     //[batch, nIn, timeSteps]
+        auto values = INPUT_VARIABLE(2);   //[batch, nIn, timeSteps]
+        auto Wq = INPUT_VARIABLE(3);       //[nHeads, headSize, nIn]
+        auto Wk = INPUT_VARIABLE(4);       //[nHeads, headSize, nIn]
+        auto Wv = INPUT_VARIABLE(5);       //[nHeads, headSize, nIn]
+        auto Wo = INPUT_VARIABLE(6);       //[nHeads * headSize, nOut]
 
         auto mask = block.width() > 7 ? INPUT_VARIABLE(7) : nullptr;
 
@@ -93,11 +93,12 @@ namespace ops {
 
 
         // Project queries, keys, values
-        auto projectedQueries = AttentionHelper::multiHeadProject(queries, Wq, block.launchContext());
-        auto projectedKeys = AttentionHelper::multiHeadProject(keys, Wk, block.launchContext());
-        auto projectedValues = AttentionHelper::multiHeadProject(values, Wv, block.launchContext());
+        auto projectedQueries = AttentionHelper::multiHeadProject(queries, Wq, block.launchContext());  //[minibatch, numHeads, projectedSize, seqLength]
+        auto projectedKeys = AttentionHelper::multiHeadProject(keys, Wk, block.launchContext());        //[minibatch, numHeads, projectedSize, seqLength]
+        auto projectedValues = AttentionHelper::multiHeadProject(values, Wv, block.launchContext());    //[minibatch, numHeads, projectedSize, seqLength]
 
         // Apply Attention
+        // attnResults = [minibatch, numHeads, projectedSize, seqLength]
         NDArray attnResults('c', {projectedQueries.sizeAt(0), projectedValues.sizeAt(1), projectedValues.sizeAt(2), projectedQueries.sizeAt(3)}, projectedValues.dataType(), block.launchContext());
         nd4j::ops::dot_product_attention attention;
         attention.execute({&projectedQueries, &projectedKeys, &projectedValues, mask}, {&attnResults, weights ?
                OUTPUT_VARIABLE(1) : nullptr}, {}, {normalization, weights}, {});
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java
index 55165b530..32a1cc362 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java
@@ -24,7 +24,7 @@ import org.nd4j.autodiff.listeners.Listener;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.autodiff.samediff.VariableType;
-import org.nd4j.autodiff.samediff.internal.memory.ArrayCloseMemoryMgr;
+import org.nd4j.autodiff.samediff.internal.memory.ArrayCacheMemoryMgr;
 import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.memory.MemoryWorkspace;
@@ -84,8 +84,7 @@ public class InferenceSession extends AbstractSession {
 
     public InferenceSession(@NonNull SameDiff sameDiff) {
         super(sameDiff);
-
-        mmgr = new ArrayCloseMemoryMgr();   //TODO replace this with new (planned) array reuse memory manager
+        mmgr = new ArrayCacheMemoryMgr();
     }
 
     @Override
@@ -215,7 +214,6 @@ public class InferenceSession extends AbstractSession {
         }
 
         INDArray[] out = doExec(op.getOp(), outputFrameIter, opInputs, allIterInputs, constAndPhInputs);
-        op.getOp().clearArrays();
 
         if (log.isTraceEnabled()) {
             StringBuilder sb = new StringBuilder();
@@ -254,6 +252,7 @@ public class InferenceSession extends AbstractSession {
                 }
             }
         }
+        op.getOp().clearArrays();
 
         //Record array uses for memory management/deallocation
 
@@ -842,11 +841,10 @@ public class InferenceSession extends AbstractSession {
                     reqShape = reqShape.asDataType(dt);
                 }
 
-                if (currOutput == null || currOutput.wasClosed() || !currOutput.shapeDescriptor().equals(reqShape) || currOutput.isEmpty() != reqShape.isEmpty() || isLoop) {
-                    boolean isOutput = allReqVariables.contains(outNames[i]);
-                    INDArray out = mmgr.allocate(isOutput, reqShape);
-                    customOp.setOutputArgument(i, out);
-                }
+                //Always allocate new output array, rely on memory manager for efficient memory management and array reuse etc
+                boolean isOutput = allReqVariables.contains(outNames[i]);
+                INDArray out = mmgr.allocate(isOutput, reqShape);
+                customOp.setOutputArgument(i, out);
             }
         } else if (df instanceof Op) {
@@ -893,29 +891,17 @@ public class InferenceSession extends AbstractSession {
 
             //Check output shape; allocate a new Z if required
             //For example, if minibatch size has changed since last op execution
+            boolean isOutput = allReqVariables.contains(((BaseOp) op).outputVariablesNames()[0]);
             if (emptyReduce) {
-                INDArray z = op.z();
-                if (z == null || !op.x().equalShapes(z) || isLoop) {
-                    //Note: edge case: [x,y].sum(empty) = [x,y] for TF import compatibility.
-                    z = mmgr.allocate(false, op.x().dataType(), op.x().shape());
-                    op.setZ(z);
-                }
+                //Always allocate new output array, rely on memory manager for efficient memory management and array reuse etc
+                INDArray z = mmgr.allocate(false, op.x().dataType(), op.x().shape());
+                op.setZ(z);
             } else {
                 List<LongShapeDescriptor> outputShape = ((BaseOp) op).calculateOutputShape();
                 Preconditions.checkState(outputShape != null && outputShape.size() == 1, "Could not calculate output shape for op: %s", op.getClass());
-                INDArray z = op.z();
-                if (z == null || z.wasClosed() || !outputShape.get(0).equals(z.shapeDescriptor()) || isLoop) {
-                    if (log.isTraceEnabled()) {
-                        log.trace("Existing op result (z) array shape for op {} was {}, allocating new array of shape {}",
-                                op.getClass().getSimpleName(), (z == null ? null : Arrays.toString(z.shape())), outputShape.get(0).toString());
-                    }
-
-                    LongShapeDescriptor lsd = outputShape.get(0);
-
-                    boolean isOutput = allReqVariables.contains(((BaseOp) op).outputVariablesNames()[0]);
-                    z = mmgr.allocate(isOutput, lsd);
-                    op.setZ(z);
-                }
+                LongShapeDescriptor lsd = outputShape.get(0);
+                INDArray z = mmgr.allocate(isOutput, lsd);
+                op.setZ(z);
             }
         }
 
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/memory/ArrayCacheMemoryMgr.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/memory/ArrayCacheMemoryMgr.java
new file mode 100644
index 000000000..c802dd4e2
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/memory/ArrayCacheMemoryMgr.java
@@ -0,0 +1,292 @@
+package org.nd4j.autodiff.samediff.internal.memory;
+
+import lombok.*;
+import org.bytedeco.javacpp.Pointer;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.shape.LongShapeDescriptor;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.util.ArrayUtil;
+
+import java.util.*;
+
+/**
+ * ArrayCacheMemoryMgr reuses arrays to reduce the number of memory allocations and deallocations.<br>
+ * Memory allocations and deallocations can be quite expensive, especially on GPUs.<br>
+ * Note that when arrays are reused, they are reused for the same datatype only.<br>
+ * If caching a released array would result in the maximum cache size being exceeded, the oldest arrays will
+ * be deallocated first, until the new array can fit in the cache.
+ * <br>
+ * <br>
+ * By default, the following parameters are used for the cache:
+ * <ul>
+ *     <li>Maximum cache size: 0.25 x max memory, where:</li>
+ *     <ul>
+ *         <li>CPU: max memory is determined using {@link Pointer#maxBytes()}</li>
+ *         <li>GPU: max memory is determined using GPU 0 total memory</li>
+ *     </ul>
+ *     <li>Larger array max multiple: 2.0</li>
+ *     <ul>
+ *         <li>This means: if an exact array size can't be provided from the cache, use the next smallest array with a buffer up to 2.0x larger than requested</li>
+ *         <li>If no cached arrays of size < 2x requested exist, allocate a new array</li>
+ *     </ul>
+ *     <li>Small array threshold: 1024 elements</li>
+ *     <ul>
+ *         <li>This means: the "larger array max multiple" doesn't apply below this level. For example, we might return a size 1 array backed by a size 1023 buffer</li>
+ *     </ul>
+ * </ul>
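+ * <br>
+ * For example, a cache that is limited to half of the maximum memory, but otherwise uses the defaults described
+ * above, might be constructed roughly as follows (the argument values here are illustrative only):
+ * <pre>{@code
+ * ArrayCacheMemoryMgr mmgr = new ArrayCacheMemoryMgr(0.5, 1024, 2.0);  //maxMemFrac, smallArrayThreshold, largerArrayMaxMultiple
+ * }</pre>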
+ *
+ * @author Alex Black
+ */
+@Getter
+public class ArrayCacheMemoryMgr extends AbstractMemoryMgr {
+
+    private final double maxMemFrac;
+    private final long smallArrayThreshold;
+    private final double largerArrayMaxMultiple;
+
+    private final long maxCacheBytes;
+    private final long totalMemBytes;
+
+    private long currentCacheSize = 0;
+    private Map<DataType, ArrayStore> arrayStores = new HashMap<>();
+
+    private LinkedHashSet<Long> lruCache = new LinkedHashSet<>();
+    private Map<Long, INDArray> lruCacheValues = new HashMap<>();
+
+    /**
+     * Create an ArrayCacheMemoryMgr with default settings as per {@link ArrayCacheMemoryMgr}
+     */
+    public ArrayCacheMemoryMgr() {
+        this(0.25, 1024, 2.0);
+    }
+
+    /**
+     * @param maxMemFrac             Maximum memory fraction to use as cache
+     * @param smallArrayThreshold    Below this size (elements), don't apply the "largerArrayMaxMultiple" rule
+     * @param largerArrayMaxMultiple Maximum multiple of the requested size to return from the cache. If an array of size
+     *                               1024 is requested, and largerArrayMaxMultiple is 2.0, then we'll return from the cache
+     *                               the array with the smallest data buffer up to 2.0*1024 elements; otherwise we'll return
+     *                               a new array
+     */
+    public ArrayCacheMemoryMgr(double maxMemFrac, long smallArrayThreshold, double largerArrayMaxMultiple) {
+        Preconditions.checkArgument(maxMemFrac > 0 && maxMemFrac < 1, "Maximum memory fraction for cache must be between 0.0 and 1.0, got %s", maxMemFrac);
+        Preconditions.checkArgument(smallArrayThreshold >= 0, "Small array threshold must be >= 0, got %s", smallArrayThreshold);
+        Preconditions.checkArgument(largerArrayMaxMultiple >= 1.0, "Larger array max multiple must be >= 1.0, got %s", largerArrayMaxMultiple);
+        this.maxMemFrac = maxMemFrac;
+        this.smallArrayThreshold = smallArrayThreshold;
+        this.largerArrayMaxMultiple = largerArrayMaxMultiple;
+
+        if(isCpu()){
+            totalMemBytes = Pointer.maxBytes();
+        } else {
+            Properties p = Nd4j.getExecutioner().getEnvironmentInformation();
+            List devList = (List) p.get("cuda.devicesInformation");
+            Map m = (Map) devList.get(0);
+            totalMemBytes = (Long) m.get("cuda.totalMemory");
+        }
+        maxCacheBytes = (long)(maxMemFrac * totalMemBytes);
+    }
+
+    private boolean isCpu(){
+        String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
+        return !"CUDA".equalsIgnoreCase(backend);
+    }
+
+    @Override
+    public INDArray allocate(boolean detached, DataType dataType, long... shape) {
+        if (arrayStores.containsKey(dataType)) {
+            INDArray arr = arrayStores.get(dataType).get(shape);
+            if (arr != null) {
+                //Decrement cache size
+                currentCacheSize -= dataType.width() * arr.data().length();
+
+                return arr; //Allocated from cache
+            }
+        }
+
+        //Allocation failed, allocate new array
+        return Nd4j.createUninitializedDetached(dataType, shape);
+    }
+
+    @Override
+    public INDArray allocate(boolean detached, LongShapeDescriptor descriptor) {
+        return allocate(detached, descriptor.dataType(), descriptor.getShape());
+    }
+
+    @Override
+    public void release(@NonNull INDArray array) {
+        //Check for multiple releases of the array
+        long id = array.getId();
+        Preconditions.checkState(!lruCache.contains(id), "Array was released multiple times: id=%s, shape=%ndShape", id, array);
+
+
+        DataType dt = array.dataType();
+        long thisBytes = array.data().length() * dt.width();
+        if(array.dataType() == DataType.UTF8) {
+            //Don't cache string arrays due to variable length buffers
+            if(array.closeable())
+                array.close();
+        } else if (currentCacheSize + thisBytes > maxCacheBytes) {
+            if(thisBytes > maxCacheBytes){
+                //Can't store even if we clear everything - too large
+                if(array.closeable())
+                    array.close();
+                return;
+            }
+
+            //Need to deallocate some arrays to stay under limit - do in "oldest first" order
+            Iterator<Long> iter = lruCache.iterator();
+            while(currentCacheSize + thisBytes > maxCacheBytes){
+                long next = iter.next();
+                iter.remove();
+                INDArray nextOldest = lruCacheValues.remove(next);
+                DataType ndt = nextOldest.dataType();
+                long nextBytes = ndt.width() * nextOldest.data().length();
+                arrayStores.get(ndt).removeObject(nextOldest);
+                currentCacheSize -= nextBytes;
+
+                if(nextOldest.closeable())
+                    nextOldest.close();
+            }
+
+            //After clearing space - can now cache
+            cacheArray(array);
+        } else {
+            //OK to cache
+            cacheArray(array);
+        }
+
+        //Store in LRU cache for "last used" removal if we exceed cache size
+        lruCache.add(array.getId());
+        lruCacheValues.put(array.getId(), array);
+    }
+
+    private void cacheArray(INDArray array){
+        DataType dt = array.dataType();
+        if (!arrayStores.containsKey(dt))
+            arrayStores.put(dt, new ArrayStore());
+        arrayStores.get(dt).add(array);
+        currentCacheSize += array.data().length() * dt.width();
+
+        lruCache.add(array.getId());
+        lruCacheValues.put(array.getId(), array);
+    }
+
+    @Override
+    public void close() {
+        for (ArrayStore as : arrayStores.values()) {
+            as.close();
+        }
+    }
+
+
+    @Getter
+    public class ArrayStore {
+        private INDArray[] sorted = new INDArray[1000]; //TODO resizing, don't hardcode
+        private long[] lengths = new long[1000];
+        private long lengthSum;
+        private long bytesSum;
+        private int size;
+
+        private void add(@NonNull INDArray array) {
+            //Resize arrays
+            if(size == sorted.length){
+                sorted = Arrays.copyOf(sorted, 2*sorted.length);
+                lengths = Arrays.copyOf(lengths, 2*lengths.length);
+            }
+
+            long length = array.data().length();
+            int idx = Arrays.binarySearch(lengths, 0, size, length);
+            if (idx < 0) {
+                idx = -idx - 1; //See binarySearch javadoc
+            }
+            for (int i = size - 1; i >= idx; i--) {
+                sorted[i + 1] = sorted[i];
+                lengths[i + 1] = lengths[i];
+            }
+            sorted[idx] = array;
+            lengths[idx] = length;
+            size++;
+            lengthSum += length;
+            bytesSum += length * array.dataType().width();
+        }
+
+        private INDArray get(long[] shape) {
+            if (size == 0)
+                return null;
+
+            long length = shape.length == 0 ? 1 : ArrayUtil.prod(shape);
+
+            int idx = Arrays.binarySearch(lengths, 0, size, length);
+            if (idx < 0) {
+                idx = -idx - 1;
+                if (idx >= size) {
+                    //Largest array is smaller than required -> can't return from cache
+                    return null;
+                }
+                INDArray nextSmallest = sorted[idx];
+                long nextSmallestLength = nextSmallest.data().length();
+                long nextSmallestLengthBytes = nextSmallestLength * nextSmallest.dataType().width();
+
+                boolean tooLarge = (length > (long) (nextSmallestLength * largerArrayMaxMultiple));
+
+                if (nextSmallestLengthBytes > smallArrayThreshold && tooLarge) {
+                    return null;
+                } // If less than smallArrayThreshold, ok, return as is
+            }
+
+            //Remove
+            INDArray arr = removeIdx(idx);
+
+            lruCache.remove(arr.getId());
+            lruCacheValues.remove(arr.getId());
+
+            //Create a new array with the specified buffer. This is for 2 reasons:
+            //(a) the cached array and requested array sizes may differ (though this is easy to check for)
+            //(b) Some SameDiff array use tracking uses *object identity* - so we want different objects when reusing arrays
+            //    to avoid issues there
+            return Nd4j.create(arr.data(), shape);
+        }
+
+        private void removeObject(INDArray array){
+            long length = array.data().length();
+            int idx = Arrays.binarySearch(lengths, 0, size, length);
+            Preconditions.checkState(idx >= 0, "Cannot remove array from ArrayStore: no array with this length exists in the cache");
+            boolean found = false;
+            int i = 0;
+            while(!found && i <= size && lengths[i] == length){
+                found = sorted[i++] == array; //Object equality
+            }
+            Preconditions.checkState(found, "Cannot remove array: not found in ArrayCache");
+            removeIdx(i - 1);
+        }
+
+        private INDArray removeIdx(int idx){
+            INDArray arr = sorted[idx];
+            for (int i = idx; i < size; i++) {
+                sorted[i] = sorted[i + 1];
+                lengths[i] = lengths[i + 1];
+            }
+            sorted[size] = null;
+            lengths[size] = 0;
+            size--;
+
+            bytesSum -= (arr.data().length() * arr.dataType().width());
+            lengthSum -= arr.data().length();
+
+            return arr;
+        }
+
+        private void close() {
+            for (int i = 0; i < size; i++) {
+                if (sorted[i].closeable())
+                    sorted[i].close();
+                lengths[i] = 0;
+            }
+            lengthSum = 0;
+            bytesSum = 0;
+            size = 0;
+        }
+    }
+}
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/MemoryMgrTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/MemoryMgrTest.java
new file mode 100644
index 000000000..6505bee20
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/MemoryMgrTest.java
@@ -0,0 +1,119 @@
+package org.nd4j.autodiff.samediff;
+
+import org.junit.Test;
+import org.nd4j.autodiff.samediff.internal.memory.ArrayCacheMemoryMgr;
+import org.nd4j.linalg.BaseNd4jTest;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.factory.Nd4jBackend;
+
+import java.lang.reflect.Field;
+
+import static org.junit.Assert.*;
+
+public class MemoryMgrTest extends BaseNd4jTest {
+
+    public MemoryMgrTest(Nd4jBackend b){
+        super(b);
+    }
+
+    @Override
+    public char ordering(){
+        return 'c';
+    }
+
+    @Test
+    public void testArrayReuseTooLarge() throws Exception {
+
+        ArrayCacheMemoryMgr mmgr = new ArrayCacheMemoryMgr();
+        Field f = ArrayCacheMemoryMgr.class.getDeclaredField("maxCacheBytes");
+        f.setAccessible(true);
+        f.set(mmgr, 1000);
+
+        assertEquals(1000, mmgr.getMaxCacheBytes());
+
+        INDArray[] arrays = new INDArray[100];
+        for( int i=0; i
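For reference, the allocate/release cycle that the new memory manager is built around looks roughly like the following minimal sketch. The class name, variable names, shapes and data type below are illustrative only; the ArrayCacheMemoryMgr methods used are the ones added in this patch.

import org.nd4j.autodiff.samediff.internal.memory.ArrayCacheMemoryMgr;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;

public class ArrayCacheUsageSketch {
    public static void main(String[] args) {
        ArrayCacheMemoryMgr mmgr = new ArrayCacheMemoryMgr();            //Defaults: 0.25 x max memory, threshold 1024 elements, multiple 2.0

        INDArray out = mmgr.allocate(false, DataType.FLOAT, 32, 128);    //Cache is empty -> a new, uninitialized array is allocated
        // ... use "out" as an op output ...
        mmgr.release(out);                                               //Small enough to fit under the cache limit -> buffer is cached, not closed

        INDArray reused = mmgr.allocate(false, DataType.FLOAT, 32, 128);
        //"reused" is a different INDArray object, but its data buffer may be the buffer released above
    }
}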