better handling of INDArray.close() (#154)

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-08-23 10:24:56 +03:00 committed by GitHub
parent 243bf866c4
commit dbea687903
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 135 additions and 10 deletions

View File

@ -6697,6 +6697,8 @@ public abstract class BaseNDArray implements INDArray, Iterable {
if (released || isEmpty()) if (released || isEmpty())
return; return;
Nd4j.getExecutioner().commit();
if (!closeable()) if (!closeable())
throw new ND4JIllegalStateException("Can't release this INDArray"); throw new ND4JIllegalStateException("Can't release this INDArray");
@ -6715,5 +6717,11 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering()); return Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering());
} }
@Override
public boolean wasClosed() {
if (released || data().wasClosed())
return true;
return false;
}
} }

View File

@ -1218,6 +1218,11 @@ public class BaseSparseNDArrayCOO extends BaseSparseNDArray {
} }
@Override
public boolean wasClosed() {
return false;
}
@Override @Override
public boolean isS() { public boolean isS() {
return false; return false;

View File

@ -263,6 +263,12 @@ public abstract class BaseSparseNDArrayCSR extends BaseSparseNDArray {
return false; //todo return false; //todo
} }
@Override
public boolean wasClosed() {
return false;
}
@Override @Override
public int underlyingRank() { public int underlyingRank() {
return rank; return rank;

View File

@ -2831,6 +2831,12 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
void close(); void close();
/**
* This method checks if array or its buffer was closed before
* @return true if was closed, false otherwise
*/
boolean wasClosed();
/** /**
* This method returns empty array with the same dtype/order/shape as this one * This method returns empty array with the same dtype/order/shape as this one
* @return * @return

View File

@ -455,6 +455,14 @@ public class DefaultOpExecutioner implements OpExecutioner {
public long profilingConfigurableHookIn(CustomOp op) { public long profilingConfigurableHookIn(CustomOp op) {
for (val arr: op.inputArguments())
if (arr.wasClosed())
throw new IllegalStateException("One of Input arguments was closed before call");
for (val arr: op.outputArguments())
if (arr.wasClosed())
throw new IllegalStateException("One of Output arguments was closed before call");
if (OpProfiler.getInstance().getConfig() == null) if (OpProfiler.getInstance().getConfig() == null)
return System.nanoTime(); return System.nanoTime();
@ -471,6 +479,18 @@ public class DefaultOpExecutioner implements OpExecutioner {
} }
public long profilingConfigurableHookIn(Op op, DataBuffer... tadBuffers) { public long profilingConfigurableHookIn(Op op, DataBuffer... tadBuffers) {
if (op.x() != null)
if (op.x().wasClosed())
throw new IllegalStateException("Op.X argument was closed before call");
if (op.y() != null)
if (op.y().wasClosed())
throw new IllegalStateException("Op.Y argument was closed before call");
if (op.z() != null)
if (op.z().wasClosed())
throw new IllegalStateException("Op.Z argument was closed before call");
if (OpProfiler.getInstance().getConfig() == null) if (OpProfiler.getInstance().getConfig() == null)
return System.nanoTime(); return System.nanoTime();

View File

@ -69,6 +69,10 @@ public class AllocationPoint {
@Setter @Setter
private boolean isAttached = false; private boolean isAttached = false;
@Getter
@Setter
private volatile boolean released = false;
// thread safety is guaranteed by allocLock // thread safety is guaranteed by allocLock
private AllocationStatus allocationStatus = AllocationStatus.UNDEFINED; private AllocationStatus allocationStatus = AllocationStatus.UNDEFINED;

View File

@ -400,13 +400,18 @@ public class AtomicAllocator implements Allocator {
public void freeMemory(AllocationPoint point) { public void freeMemory(AllocationPoint point) {
if (point.getAllocationStatus() == AllocationStatus.DEVICE) { if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
this.getMemoryHandler().getMemoryProvider().free(point); this.getMemoryHandler().getMemoryProvider().free(point);
point.setAllocationStatus(AllocationStatus.HOST);
this.getMemoryHandler().getMemoryProvider().free(point); if (point.getHostPointer() != null) {
this.getMemoryHandler().forget(point, AllocationStatus.DEVICE); point.setAllocationStatus(AllocationStatus.HOST);
this.getMemoryHandler().getMemoryProvider().free(point);
this.getMemoryHandler().forget(point, AllocationStatus.DEVICE);
}
} else { } else {
// call it only once // call it only once
this.getMemoryHandler().getMemoryProvider().free(point); if (point.getHostPointer() != null) {
this.getMemoryHandler().forget(point, AllocationStatus.HOST); this.getMemoryHandler().getMemoryProvider().free(point);
this.getMemoryHandler().forget(point, AllocationStatus.HOST);
}
} }
allocationsMap.remove(point.getObjectId()); allocationsMap.remove(point.getObjectId());

View File

@ -39,10 +39,10 @@ public class CudaDeallocator implements Deallocator {
public void deallocate() { public void deallocate() {
log.trace("Deallocating CUDA memory"); log.trace("Deallocating CUDA memory");
// skipping any allocation that is coming from workspace // skipping any allocation that is coming from workspace
if (point.isAttached()) { if (point.isAttached() || point.isReleased()) {
// TODO: remove allocation point as well? // TODO: remove allocation point as well?
if (!AtomicAllocator.getInstance().allocationsMap().containsKey(point.getObjectId())) if (!AtomicAllocator.getInstance().allocationsMap().containsKey(point.getObjectId()))
throw new RuntimeException(); return;
AtomicAllocator.getInstance().getFlowController().waitTillReleased(point); AtomicAllocator.getInstance().getFlowController().waitTillReleased(point);

View File

@ -544,6 +544,9 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
*/ */
@Override @Override
public long address() { public long address() {
if (released)
throw new IllegalStateException("You can't use DataBuffer once it was released");
return allocationPoint.getPointers().getHostPointer().address(); return allocationPoint.getPointers().getHostPointer().address();
} }
@ -554,6 +557,9 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
@Override @Override
public Pointer pointer() { public Pointer pointer() {
if (released)
throw new IllegalStateException("You can't use DataBuffer once it was released");
// FIXME: very bad thing, // FIXME: very bad thing,
lazyAllocateHostPointer(); lazyAllocateHostPointer();
@ -1109,6 +1115,9 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
@Override @Override
public Pointer addressPointer() { public Pointer addressPointer() {
if (released)
throw new IllegalStateException("You can't use DataBuffer once it was released");
return AtomicAllocator.getInstance().getHostPointer(this); return AtomicAllocator.getInstance().getHostPointer(this);
} }
@ -1562,7 +1571,10 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
@Override @Override
protected void release() { protected void release() {
AtomicAllocator.getInstance().freeMemory(allocationPoint); if (!released) {
AtomicAllocator.getInstance().freeMemory(allocationPoint);
allocationPoint.setReleased(true);
}
released = true; released = true;
} }

View File

@ -616,14 +616,23 @@ public abstract class BaseDataBuffer implements DataBuffer {
*/ */
@Override @Override
public Indexer indexer() { public Indexer indexer() {
if (released)
throw new IllegalStateException("You can't use DataBuffer once it was released");
return indexer; return indexer;
} }
@Override @Override
public Pointer pointer() { public Pointer pointer() {
if (underlyingDataBuffer() != null && underlyingDataBuffer() != this) if (released)
throw new IllegalStateException("You can't use DataBuffer once it was released");
if (underlyingDataBuffer() != null && underlyingDataBuffer() != this) {
if (underlyingDataBuffer().wasClosed())
throw new IllegalStateException("You can't use DataBuffer once it was released");
return underlyingDataBuffer().pointer(); return underlyingDataBuffer().pointer();
else { } else {
if (underlyingDataBuffer() != null) if (underlyingDataBuffer() != null)
if (((BaseDataBuffer) underlyingDataBuffer()).released) if (((BaseDataBuffer) underlyingDataBuffer()).released)
throw new IllegalStateException("Underlying buffer was released via close() call"); throw new IllegalStateException("Underlying buffer was released via close() call");
@ -923,6 +932,8 @@ public abstract class BaseDataBuffer implements DataBuffer {
@Override @Override
public Pointer addressPointer() { public Pointer addressPointer() {
if (released)
throw new IllegalStateException("You can't use DataBuffer once it was released");
if (offset() > 0) { if (offset() > 0) {
Pointer ret; Pointer ret;
@ -968,6 +979,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
@Override @Override
public long address() { public long address() {
if (released)
throw new IllegalStateException("You can't use DataBuffer once it was released");
return pointer().address() + getElementSize() * offset(); return pointer().address() + getElementSize() * offset();
} }
@ -1409,6 +1423,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
@Override @Override
public double getDouble(long i) { public double getDouble(long i) {
if (released)
throw new IllegalStateException("You can't use DataBuffer once it was released");
if (indexer == null) { if (indexer == null) {
throw new IllegalStateException("Indexer must never be null"); throw new IllegalStateException("Indexer must never be null");
} }
@ -1444,6 +1461,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
@Override @Override
public long getLong(long i) { public long getLong(long i) {
if (released)
throw new IllegalStateException("You can't use DataBuffer once it was released");
switch (dataType()) { switch (dataType()) {
case FLOAT: case FLOAT:
return (long) ((FloatIndexer) indexer).get(offset() + i); return (long) ((FloatIndexer) indexer).get(offset() + i);
@ -1480,6 +1500,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
* @return * @return
*/ */
protected short getShort(long i) { protected short getShort(long i) {
if (released)
throw new IllegalStateException("You can't use DataBuffer once it was released");
switch (dataType()) { switch (dataType()) {
case DOUBLE: case DOUBLE:
return (short) ((DoubleIndexer) indexer).get(offset() + i); return (short) ((DoubleIndexer) indexer).get(offset() + i);
@ -1518,6 +1541,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
@Override @Override
public float getFloat(long i) { public float getFloat(long i) {
if (released)
throw new IllegalStateException("You can't use DataBuffer once it was released");
switch (dataType()) { switch (dataType()) {
case DOUBLE: case DOUBLE:
return (float) ((DoubleIndexer) indexer).get(offset() + i); return (float) ((DoubleIndexer) indexer).get(offset() + i);
@ -1550,6 +1576,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
@Override @Override
public int getInt(long i) { public int getInt(long i) {
if (released)
throw new IllegalStateException("You can't use DataBuffer once it was released");
switch (dataType()) { switch (dataType()) {
case DOUBLE: case DOUBLE:
return (int) ((DoubleIndexer) indexer).get(offset() + i); return (int) ((DoubleIndexer) indexer).get(offset() + i);
@ -1582,6 +1611,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
@Override @Override
public Number getNumber(long i) { public Number getNumber(long i) {
if (released)
throw new IllegalStateException("You can't use DataBuffer once it was released");
if (dataType() == DataType.DOUBLE) if (dataType() == DataType.DOUBLE)
return getDouble(i); return getDouble(i);
else if (dataType() == DataType.INT) else if (dataType() == DataType.INT)
@ -1685,6 +1717,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
@Override @Override
public void put(long i, float element) { public void put(long i, float element) {
if (released)
throw new IllegalStateException("You can't use DataBuffer once it was released");
switch (dataType()) { switch (dataType()) {
case BOOL: case BOOL:
((BooleanIndexer) indexer).put(offset() + i, element == 0.0 ? false : true); ((BooleanIndexer) indexer).put(offset() + i, element == 0.0 ? false : true);
@ -1732,6 +1767,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
@Override @Override
public void put(long i, double element) { public void put(long i, double element) {
if (released)
throw new IllegalStateException("You can't use DataBuffer once it was released");
switch (dataType()) { switch (dataType()) {
case BOOL: case BOOL:
((BooleanIndexer) indexer).put(offset() + i, element > 0.0); ((BooleanIndexer) indexer).put(offset() + i, element > 0.0);
@ -1779,6 +1817,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
@Override @Override
public void put(long i, int element) { public void put(long i, int element) {
if (released)
throw new IllegalStateException("You can't use DataBuffer once it was released");
switch (dataType()) { switch (dataType()) {
case BOOL: case BOOL:
((BooleanIndexer) indexer).put(offset() + i, element == 0 ? false : true); ((BooleanIndexer) indexer).put(offset() + i, element == 0 ? false : true);
@ -1826,6 +1867,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
@Override @Override
public void put(long i, boolean element) { public void put(long i, boolean element) {
if (released)
throw new IllegalStateException("You can't use DataBuffer once it was released");
switch (dataType()) { switch (dataType()) {
case BOOL: case BOOL:
((BooleanIndexer) indexer).put(offset() + i, element); ((BooleanIndexer) indexer).put(offset() + i, element);
@ -1873,6 +1917,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
@Override @Override
public void put(long i, long element) { public void put(long i, long element) {
if (released)
throw new IllegalStateException("You can't use DataBuffer once it was released");
switch (dataType()) { switch (dataType()) {
case BOOL: case BOOL:
((BooleanIndexer) indexer).put(offset() + i, element == 0 ? false : true); ((BooleanIndexer) indexer).put(offset() + i, element == 0 ? false : true);
@ -2667,4 +2714,10 @@ public abstract class BaseDataBuffer implements DataBuffer {
public long platformAddress() { public long platformAddress() {
return address(); return address();
} }
@Override
public boolean wasClosed() {
return released;
}
} }

View File

@ -703,4 +703,10 @@ public interface DataBuffer extends Serializable, AutoCloseable {
* PLEASE NOTE: This method is NOT safe by any means * PLEASE NOTE: This method is NOT safe by any means
*/ */
void close(); void close();
/**
* This method checks if array or its buffer was closed before
* @return true if was closed, false otherwise
*/
boolean wasClosed();
} }