From dbea687903301fad56f21030e39e78ea3a914640 Mon Sep 17 00:00:00 2001 From: raver119 Date: Fri, 23 Aug 2019 10:24:56 +0300 Subject: [PATCH] better handling of INDArray.close() (#154) Signed-off-by: raver119 --- .../nd4j/linalg/api/ndarray/BaseNDArray.java | 8 +++ .../api/ndarray/BaseSparseNDArrayCOO.java | 5 ++ .../api/ndarray/BaseSparseNDArrayCSR.java | 6 ++ .../org/nd4j/linalg/api/ndarray/INDArray.java | 6 ++ .../ops/executioner/DefaultOpExecutioner.java | 20 +++++++ .../jita/allocator/impl/AllocationPoint.java | 4 ++ .../jita/allocator/impl/AtomicAllocator.java | 15 +++-- .../jita/allocator/impl/CudaDeallocator.java | 4 +- .../jcublas/buffer/BaseCudaDataBuffer.java | 14 ++++- .../linalg/api/buffer/BaseDataBuffer.java | 57 ++++++++++++++++++- .../nd4j/linalg/api/buffer/DataBuffer.java | 6 ++ 11 files changed, 135 insertions(+), 10 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java index c02b3f546..012b974fb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java @@ -6697,6 +6697,8 @@ public abstract class BaseNDArray implements INDArray, Iterable { if (released || isEmpty()) return; + Nd4j.getExecutioner().commit(); + if (!closeable()) throw new ND4JIllegalStateException("Can't release this INDArray"); @@ -6715,5 +6717,11 @@ public abstract class BaseNDArray implements INDArray, Iterable { return Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering()); } + @Override + public boolean wasClosed() { + if (released || data().wasClosed()) + return true; + return false; + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCOO.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCOO.java index 9c49f573d..1e85be0cd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCOO.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCOO.java @@ -1218,6 +1218,11 @@ public class BaseSparseNDArrayCOO extends BaseSparseNDArray { } + @Override + public boolean wasClosed() { + return false; + } + @Override public boolean isS() { return false; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCSR.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCSR.java index ee2e3f0a8..bd0f7c905 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCSR.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCSR.java @@ -263,6 +263,12 @@ public abstract class BaseSparseNDArrayCSR extends BaseSparseNDArray { return false; //todo } + @Override + public boolean wasClosed() { + return false; + } + + @Override public int underlyingRank() { return rank; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java index 7c547f4af..727b5db6d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java @@ -2831,6 +2831,12 @@ public interface INDArray extends Serializable, AutoCloseable { */ void close(); + /** + * This method checks if array or its buffer was closed before + * @return true if was closed, false otherwise + */ + boolean wasClosed(); + /** * This method returns empty array with the same dtype/order/shape as this one * @return diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/DefaultOpExecutioner.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/DefaultOpExecutioner.java index 330fb1c31..57606e452 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/DefaultOpExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/executioner/DefaultOpExecutioner.java @@ -455,6 +455,14 @@ public class DefaultOpExecutioner implements OpExecutioner { public long profilingConfigurableHookIn(CustomOp op) { + for (val arr: op.inputArguments()) + if (arr.wasClosed()) + throw new IllegalStateException("One of Input arguments was closed before call"); + + for (val arr: op.outputArguments()) + if (arr.wasClosed()) + throw new IllegalStateException("One of Output arguments was closed before call"); + if (OpProfiler.getInstance().getConfig() == null) return System.nanoTime(); @@ -471,6 +479,18 @@ public class DefaultOpExecutioner implements OpExecutioner { } public long profilingConfigurableHookIn(Op op, DataBuffer... tadBuffers) { + if (op.x() != null) + if (op.x().wasClosed()) + throw new IllegalStateException("Op.X argument was closed before call"); + + if (op.y() != null) + if (op.y().wasClosed()) + throw new IllegalStateException("Op.Y argument was closed before call"); + + if (op.z() != null) + if (op.z().wasClosed()) + throw new IllegalStateException("Op.Z argument was closed before call"); + if (OpProfiler.getInstance().getConfig() == null) return System.nanoTime(); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AllocationPoint.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AllocationPoint.java index e03871d43..f673a15d7 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AllocationPoint.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AllocationPoint.java @@ -69,6 +69,10 @@ public class AllocationPoint { @Setter private boolean isAttached = false; + @Getter + @Setter + private volatile boolean released = false; + // thread safety is guaranteed by allocLock private AllocationStatus allocationStatus = AllocationStatus.UNDEFINED; diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java index 873e97809..8fbf0a000 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java @@ -400,13 +400,18 @@ public class AtomicAllocator implements Allocator { public void freeMemory(AllocationPoint point) { if (point.getAllocationStatus() == AllocationStatus.DEVICE) { this.getMemoryHandler().getMemoryProvider().free(point); - point.setAllocationStatus(AllocationStatus.HOST); - this.getMemoryHandler().getMemoryProvider().free(point); - this.getMemoryHandler().forget(point, AllocationStatus.DEVICE); + + if (point.getHostPointer() != null) { + point.setAllocationStatus(AllocationStatus.HOST); + this.getMemoryHandler().getMemoryProvider().free(point); + this.getMemoryHandler().forget(point, AllocationStatus.DEVICE); + } } else { // call it only once - this.getMemoryHandler().getMemoryProvider().free(point); - this.getMemoryHandler().forget(point, AllocationStatus.HOST); + if (point.getHostPointer() != null) { + this.getMemoryHandler().getMemoryProvider().free(point); + this.getMemoryHandler().forget(point, AllocationStatus.HOST); + } } allocationsMap.remove(point.getObjectId()); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/CudaDeallocator.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/CudaDeallocator.java index e50dd059c..ae1ad93cd 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/CudaDeallocator.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/CudaDeallocator.java @@ -39,10 +39,10 @@ public class CudaDeallocator implements Deallocator { public void deallocate() { log.trace("Deallocating CUDA memory"); // skipping any allocation that is coming from workspace - if (point.isAttached()) { + if (point.isAttached() || point.isReleased()) { // TODO: remove allocation point as well? if (!AtomicAllocator.getInstance().allocationsMap().containsKey(point.getObjectId())) - throw new RuntimeException(); + return; AtomicAllocator.getInstance().getFlowController().waitTillReleased(point); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java index 1bec19dd0..5e0583d56 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java @@ -544,6 +544,9 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda */ @Override public long address() { + if (released) + throw new IllegalStateException("You can't use DataBuffer once it was released"); + return allocationPoint.getPointers().getHostPointer().address(); } @@ -554,6 +557,9 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda @Override public Pointer pointer() { + if (released) + throw new IllegalStateException("You can't use DataBuffer once it was released"); + // FIXME: very bad thing, lazyAllocateHostPointer(); @@ -1109,6 +1115,9 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda @Override public Pointer addressPointer() { + if (released) + throw new IllegalStateException("You can't use DataBuffer once it was released"); + return AtomicAllocator.getInstance().getHostPointer(this); } @@ -1562,7 +1571,10 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda @Override protected void release() { - AtomicAllocator.getInstance().freeMemory(allocationPoint); + if (!released) { + AtomicAllocator.getInstance().freeMemory(allocationPoint); + allocationPoint.setReleased(true); + } released = true; } diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java index 82fff8437..15249acc9 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java +++ b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java @@ -616,14 +616,23 @@ public abstract class BaseDataBuffer implements DataBuffer { */ @Override public Indexer indexer() { + if (released) + throw new IllegalStateException("You can't use DataBuffer once it was released"); + return indexer; } @Override public Pointer pointer() { - if (underlyingDataBuffer() != null && underlyingDataBuffer() != this) + if (released) + throw new IllegalStateException("You can't use DataBuffer once it was released"); + + if (underlyingDataBuffer() != null && underlyingDataBuffer() != this) { + if (underlyingDataBuffer().wasClosed()) + throw new IllegalStateException("You can't use DataBuffer once it was released"); + return underlyingDataBuffer().pointer(); - else { + } else { if (underlyingDataBuffer() != null) if (((BaseDataBuffer) underlyingDataBuffer()).released) throw new IllegalStateException("Underlying buffer was released via close() call"); @@ -923,6 +932,8 @@ public abstract class BaseDataBuffer implements DataBuffer { @Override public Pointer addressPointer() { + if (released) + throw new IllegalStateException("You can't use DataBuffer once it was released"); if (offset() > 0) { Pointer ret; @@ -968,6 +979,9 @@ public abstract class BaseDataBuffer implements DataBuffer { @Override public long address() { + if (released) + throw new IllegalStateException("You can't use DataBuffer once it was released"); + return pointer().address() + getElementSize() * offset(); } @@ -1409,6 +1423,9 @@ public abstract class BaseDataBuffer implements DataBuffer { @Override public double getDouble(long i) { + if (released) + throw new IllegalStateException("You can't use DataBuffer once it was released"); + if (indexer == null) { throw new IllegalStateException("Indexer must never be null"); } @@ -1444,6 +1461,9 @@ public abstract class BaseDataBuffer implements DataBuffer { @Override public long getLong(long i) { + if (released) + throw new IllegalStateException("You can't use DataBuffer once it was released"); + switch (dataType()) { case FLOAT: return (long) ((FloatIndexer) indexer).get(offset() + i); @@ -1480,6 +1500,9 @@ public abstract class BaseDataBuffer implements DataBuffer { * @return */ protected short getShort(long i) { + if (released) + throw new IllegalStateException("You can't use DataBuffer once it was released"); + switch (dataType()) { case DOUBLE: return (short) ((DoubleIndexer) indexer).get(offset() + i); @@ -1518,6 +1541,9 @@ public abstract class BaseDataBuffer implements DataBuffer { @Override public float getFloat(long i) { + if (released) + throw new IllegalStateException("You can't use DataBuffer once it was released"); + switch (dataType()) { case DOUBLE: return (float) ((DoubleIndexer) indexer).get(offset() + i); @@ -1550,6 +1576,9 @@ public abstract class BaseDataBuffer implements DataBuffer { @Override public int getInt(long i) { + if (released) + throw new IllegalStateException("You can't use DataBuffer once it was released"); + switch (dataType()) { case DOUBLE: return (int) ((DoubleIndexer) indexer).get(offset() + i); @@ -1582,6 +1611,9 @@ public abstract class BaseDataBuffer implements DataBuffer { @Override public Number getNumber(long i) { + if (released) + throw new IllegalStateException("You can't use DataBuffer once it was released"); + if (dataType() == DataType.DOUBLE) return getDouble(i); else if (dataType() == DataType.INT) @@ -1685,6 +1717,9 @@ public abstract class BaseDataBuffer implements DataBuffer { @Override public void put(long i, float element) { + if (released) + throw new IllegalStateException("You can't use DataBuffer once it was released"); + switch (dataType()) { case BOOL: ((BooleanIndexer) indexer).put(offset() + i, element == 0.0 ? false : true); @@ -1732,6 +1767,9 @@ public abstract class BaseDataBuffer implements DataBuffer { @Override public void put(long i, double element) { + if (released) + throw new IllegalStateException("You can't use DataBuffer once it was released"); + switch (dataType()) { case BOOL: ((BooleanIndexer) indexer).put(offset() + i, element > 0.0); @@ -1779,6 +1817,9 @@ public abstract class BaseDataBuffer implements DataBuffer { @Override public void put(long i, int element) { + if (released) + throw new IllegalStateException("You can't use DataBuffer once it was released"); + switch (dataType()) { case BOOL: ((BooleanIndexer) indexer).put(offset() + i, element == 0 ? false : true); @@ -1826,6 +1867,9 @@ public abstract class BaseDataBuffer implements DataBuffer { @Override public void put(long i, boolean element) { + if (released) + throw new IllegalStateException("You can't use DataBuffer once it was released"); + switch (dataType()) { case BOOL: ((BooleanIndexer) indexer).put(offset() + i, element); @@ -1873,6 +1917,9 @@ public abstract class BaseDataBuffer implements DataBuffer { @Override public void put(long i, long element) { + if (released) + throw new IllegalStateException("You can't use DataBuffer once it was released"); + switch (dataType()) { case BOOL: ((BooleanIndexer) indexer).put(offset() + i, element == 0 ? false : true); @@ -2667,4 +2714,10 @@ public abstract class BaseDataBuffer implements DataBuffer { public long platformAddress() { return address(); } + + + @Override + public boolean wasClosed() { + return released; + } } diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataBuffer.java b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataBuffer.java index 28b04f07f..9b1c2ecec 100644 --- a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataBuffer.java +++ b/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataBuffer.java @@ -703,4 +703,10 @@ public interface DataBuffer extends Serializable, AutoCloseable { * PLEASE NOTE: This method is NOT safe by any means */ void close(); + + /** + * This method checks if array or its buffer was closed before + * @return true if was closed, false otherwise + */ + boolean wasClosed(); }