better handling of INDArray.close() (#154)

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-08-23 10:24:56 +03:00 committed by GitHub
parent 243bf866c4
commit dbea687903
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 135 additions and 10 deletions

View File

@ -6697,6 +6697,8 @@ public abstract class BaseNDArray implements INDArray, Iterable {
if (released || isEmpty())
return;
Nd4j.getExecutioner().commit();
if (!closeable())
throw new ND4JIllegalStateException("Can't release this INDArray");
@ -6715,5 +6717,11 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering());
}
@Override
public boolean wasClosed() {
if (released || data().wasClosed())
return true;
return false;
}
}

View File

@ -1218,6 +1218,11 @@ public class BaseSparseNDArrayCOO extends BaseSparseNDArray {
}
@Override
public boolean wasClosed() {
return false;
}
@Override
public boolean isS() {
return false;

View File

@ -263,6 +263,12 @@ public abstract class BaseSparseNDArrayCSR extends BaseSparseNDArray {
return false; //todo
}
@Override
public boolean wasClosed() {
return false;
}
@Override
public int underlyingRank() {
return rank;

View File

@ -2831,6 +2831,12 @@ public interface INDArray extends Serializable, AutoCloseable {
*/
void close();
/**
* This method checks if array or its buffer was closed before
* @return true if was closed, false otherwise
*/
boolean wasClosed();
/**
* This method returns empty array with the same dtype/order/shape as this one
* @return

View File

@ -455,6 +455,14 @@ public class DefaultOpExecutioner implements OpExecutioner {
public long profilingConfigurableHookIn(CustomOp op) {
for (val arr: op.inputArguments())
if (arr.wasClosed())
throw new IllegalStateException("One of Input arguments was closed before call");
for (val arr: op.outputArguments())
if (arr.wasClosed())
throw new IllegalStateException("One of Output arguments was closed before call");
if (OpProfiler.getInstance().getConfig() == null)
return System.nanoTime();
@ -471,6 +479,18 @@ public class DefaultOpExecutioner implements OpExecutioner {
}
public long profilingConfigurableHookIn(Op op, DataBuffer... tadBuffers) {
if (op.x() != null)
if (op.x().wasClosed())
throw new IllegalStateException("Op.X argument was closed before call");
if (op.y() != null)
if (op.y().wasClosed())
throw new IllegalStateException("Op.Y argument was closed before call");
if (op.z() != null)
if (op.z().wasClosed())
throw new IllegalStateException("Op.Z argument was closed before call");
if (OpProfiler.getInstance().getConfig() == null)
return System.nanoTime();

View File

@ -69,6 +69,10 @@ public class AllocationPoint {
@Setter
private boolean isAttached = false;
@Getter
@Setter
private volatile boolean released = false;
// thread safety is guaranteed by allocLock
private AllocationStatus allocationStatus = AllocationStatus.UNDEFINED;

View File

@ -400,14 +400,19 @@ public class AtomicAllocator implements Allocator {
public void freeMemory(AllocationPoint point) {
if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
this.getMemoryHandler().getMemoryProvider().free(point);
if (point.getHostPointer() != null) {
point.setAllocationStatus(AllocationStatus.HOST);
this.getMemoryHandler().getMemoryProvider().free(point);
this.getMemoryHandler().forget(point, AllocationStatus.DEVICE);
}
} else {
// call it only once
if (point.getHostPointer() != null) {
this.getMemoryHandler().getMemoryProvider().free(point);
this.getMemoryHandler().forget(point, AllocationStatus.HOST);
}
}
allocationsMap.remove(point.getObjectId());
}

View File

@ -39,10 +39,10 @@ public class CudaDeallocator implements Deallocator {
public void deallocate() {
log.trace("Deallocating CUDA memory");
// skipping any allocation that is coming from workspace
if (point.isAttached()) {
if (point.isAttached() || point.isReleased()) {
// TODO: remove allocation point as well?
if (!AtomicAllocator.getInstance().allocationsMap().containsKey(point.getObjectId()))
throw new RuntimeException();
return;
AtomicAllocator.getInstance().getFlowController().waitTillReleased(point);

View File

@ -544,6 +544,9 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
*/
@Override
public long address() {
if (released)
throw new IllegalStateException("You can't use DataBuffer once it was released");
return allocationPoint.getPointers().getHostPointer().address();
}
@ -554,6 +557,9 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
@Override
public Pointer pointer() {
if (released)
throw new IllegalStateException("You can't use DataBuffer once it was released");
// FIXME: very bad thing,
lazyAllocateHostPointer();
@ -1109,6 +1115,9 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
@Override
public Pointer addressPointer() {
if (released)
throw new IllegalStateException("You can't use DataBuffer once it was released");
return AtomicAllocator.getInstance().getHostPointer(this);
}
@ -1562,7 +1571,10 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
@Override
protected void release() {
if (!released) {
AtomicAllocator.getInstance().freeMemory(allocationPoint);
allocationPoint.setReleased(true);
}
released = true;
}

View File

@ -616,14 +616,23 @@ public abstract class BaseDataBuffer implements DataBuffer {
*/
@Override
public Indexer indexer() {
if (released)
throw new IllegalStateException("You can't use DataBuffer once it was released");
return indexer;
}
@Override
public Pointer pointer() {
if (underlyingDataBuffer() != null && underlyingDataBuffer() != this)
if (released)
throw new IllegalStateException("You can't use DataBuffer once it was released");
if (underlyingDataBuffer() != null && underlyingDataBuffer() != this) {
if (underlyingDataBuffer().wasClosed())
throw new IllegalStateException("You can't use DataBuffer once it was released");
return underlyingDataBuffer().pointer();
else {
} else {
if (underlyingDataBuffer() != null)
if (((BaseDataBuffer) underlyingDataBuffer()).released)
throw new IllegalStateException("Underlying buffer was released via close() call");
@ -923,6 +932,8 @@ public abstract class BaseDataBuffer implements DataBuffer {
@Override
public Pointer addressPointer() {
if (released)
throw new IllegalStateException("You can't use DataBuffer once it was released");
if (offset() > 0) {
Pointer ret;
@ -968,6 +979,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
@Override
public long address() {
if (released)
throw new IllegalStateException("You can't use DataBuffer once it was released");
return pointer().address() + getElementSize() * offset();
}
@ -1409,6 +1423,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
@Override
public double getDouble(long i) {
if (released)
throw new IllegalStateException("You can't use DataBuffer once it was released");
if (indexer == null) {
throw new IllegalStateException("Indexer must never be null");
}
@ -1444,6 +1461,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
@Override
public long getLong(long i) {
if (released)
throw new IllegalStateException("You can't use DataBuffer once it was released");
switch (dataType()) {
case FLOAT:
return (long) ((FloatIndexer) indexer).get(offset() + i);
@ -1480,6 +1500,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
* @return
*/
protected short getShort(long i) {
if (released)
throw new IllegalStateException("You can't use DataBuffer once it was released");
switch (dataType()) {
case DOUBLE:
return (short) ((DoubleIndexer) indexer).get(offset() + i);
@ -1518,6 +1541,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
@Override
public float getFloat(long i) {
if (released)
throw new IllegalStateException("You can't use DataBuffer once it was released");
switch (dataType()) {
case DOUBLE:
return (float) ((DoubleIndexer) indexer).get(offset() + i);
@ -1550,6 +1576,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
@Override
public int getInt(long i) {
if (released)
throw new IllegalStateException("You can't use DataBuffer once it was released");
switch (dataType()) {
case DOUBLE:
return (int) ((DoubleIndexer) indexer).get(offset() + i);
@ -1582,6 +1611,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
@Override
public Number getNumber(long i) {
if (released)
throw new IllegalStateException("You can't use DataBuffer once it was released");
if (dataType() == DataType.DOUBLE)
return getDouble(i);
else if (dataType() == DataType.INT)
@ -1685,6 +1717,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
@Override
public void put(long i, float element) {
if (released)
throw new IllegalStateException("You can't use DataBuffer once it was released");
switch (dataType()) {
case BOOL:
((BooleanIndexer) indexer).put(offset() + i, element == 0.0 ? false : true);
@ -1732,6 +1767,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
@Override
public void put(long i, double element) {
if (released)
throw new IllegalStateException("You can't use DataBuffer once it was released");
switch (dataType()) {
case BOOL:
((BooleanIndexer) indexer).put(offset() + i, element > 0.0);
@ -1779,6 +1817,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
@Override
public void put(long i, int element) {
if (released)
throw new IllegalStateException("You can't use DataBuffer once it was released");
switch (dataType()) {
case BOOL:
((BooleanIndexer) indexer).put(offset() + i, element == 0 ? false : true);
@ -1826,6 +1867,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
@Override
public void put(long i, boolean element) {
if (released)
throw new IllegalStateException("You can't use DataBuffer once it was released");
switch (dataType()) {
case BOOL:
((BooleanIndexer) indexer).put(offset() + i, element);
@ -1873,6 +1917,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
@Override
public void put(long i, long element) {
if (released)
throw new IllegalStateException("You can't use DataBuffer once it was released");
switch (dataType()) {
case BOOL:
((BooleanIndexer) indexer).put(offset() + i, element == 0 ? false : true);
@ -2667,4 +2714,10 @@ public abstract class BaseDataBuffer implements DataBuffer {
public long platformAddress() {
return address();
}
@Override
public boolean wasClosed() {
return released;
}
}

View File

@ -703,4 +703,10 @@ public interface DataBuffer extends Serializable, AutoCloseable {
* PLEASE NOTE: This method is NOT safe by any means
*/
void close();
/**
* This method checks if array or its buffer was closed before
* @return true if was closed, false otherwise
*/
boolean wasClosed();
}