parent
243bf866c4
commit
dbea687903
|
@ -6697,6 +6697,8 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
|||
if (released || isEmpty())
|
||||
return;
|
||||
|
||||
Nd4j.getExecutioner().commit();
|
||||
|
||||
if (!closeable())
|
||||
throw new ND4JIllegalStateException("Can't release this INDArray");
|
||||
|
||||
|
@ -6715,5 +6717,11 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
|||
return Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering());
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean wasClosed() {
|
||||
if (released || data().wasClosed())
|
||||
return true;
|
||||
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1218,6 +1218,11 @@ public class BaseSparseNDArrayCOO extends BaseSparseNDArray {
|
|||
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean wasClosed() {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isS() {
|
||||
return false;
|
||||
|
|
|
@ -263,6 +263,12 @@ public abstract class BaseSparseNDArrayCSR extends BaseSparseNDArray {
|
|||
return false; //todo
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean wasClosed() {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public int underlyingRank() {
|
||||
return rank;
|
||||
|
|
|
@ -2831,6 +2831,12 @@ public interface INDArray extends Serializable, AutoCloseable {
|
|||
*/
|
||||
void close();
|
||||
|
||||
/**
|
||||
* This method checks if array or its buffer was closed before
|
||||
* @return true if was closed, false otherwise
|
||||
*/
|
||||
boolean wasClosed();
|
||||
|
||||
/**
|
||||
* This method returns empty array with the same dtype/order/shape as this one
|
||||
* @return
|
||||
|
|
|
@ -455,6 +455,14 @@ public class DefaultOpExecutioner implements OpExecutioner {
|
|||
|
||||
|
||||
public long profilingConfigurableHookIn(CustomOp op) {
|
||||
for (val arr: op.inputArguments())
|
||||
if (arr.wasClosed())
|
||||
throw new IllegalStateException("One of Input arguments was closed before call");
|
||||
|
||||
for (val arr: op.outputArguments())
|
||||
if (arr.wasClosed())
|
||||
throw new IllegalStateException("One of Output arguments was closed before call");
|
||||
|
||||
if (OpProfiler.getInstance().getConfig() == null)
|
||||
return System.nanoTime();
|
||||
|
||||
|
@ -471,6 +479,18 @@ public class DefaultOpExecutioner implements OpExecutioner {
|
|||
}
|
||||
|
||||
public long profilingConfigurableHookIn(Op op, DataBuffer... tadBuffers) {
|
||||
if (op.x() != null)
|
||||
if (op.x().wasClosed())
|
||||
throw new IllegalStateException("Op.X argument was closed before call");
|
||||
|
||||
if (op.y() != null)
|
||||
if (op.y().wasClosed())
|
||||
throw new IllegalStateException("Op.Y argument was closed before call");
|
||||
|
||||
if (op.z() != null)
|
||||
if (op.z().wasClosed())
|
||||
throw new IllegalStateException("Op.Z argument was closed before call");
|
||||
|
||||
if (OpProfiler.getInstance().getConfig() == null)
|
||||
return System.nanoTime();
|
||||
|
||||
|
|
|
@ -69,6 +69,10 @@ public class AllocationPoint {
|
|||
@Setter
|
||||
private boolean isAttached = false;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
private volatile boolean released = false;
|
||||
|
||||
// thread safety is guaranteed by allocLock
|
||||
private AllocationStatus allocationStatus = AllocationStatus.UNDEFINED;
|
||||
|
||||
|
|
|
@ -400,14 +400,19 @@ public class AtomicAllocator implements Allocator {
|
|||
public void freeMemory(AllocationPoint point) {
|
||||
if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
|
||||
this.getMemoryHandler().getMemoryProvider().free(point);
|
||||
|
||||
if (point.getHostPointer() != null) {
|
||||
point.setAllocationStatus(AllocationStatus.HOST);
|
||||
this.getMemoryHandler().getMemoryProvider().free(point);
|
||||
this.getMemoryHandler().forget(point, AllocationStatus.DEVICE);
|
||||
}
|
||||
} else {
|
||||
// call it only once
|
||||
if (point.getHostPointer() != null) {
|
||||
this.getMemoryHandler().getMemoryProvider().free(point);
|
||||
this.getMemoryHandler().forget(point, AllocationStatus.HOST);
|
||||
}
|
||||
}
|
||||
|
||||
allocationsMap.remove(point.getObjectId());
|
||||
}
|
||||
|
|
|
@ -39,10 +39,10 @@ public class CudaDeallocator implements Deallocator {
|
|||
public void deallocate() {
|
||||
log.trace("Deallocating CUDA memory");
|
||||
// skipping any allocation that is coming from workspace
|
||||
if (point.isAttached()) {
|
||||
if (point.isAttached() || point.isReleased()) {
|
||||
// TODO: remove allocation point as well?
|
||||
if (!AtomicAllocator.getInstance().allocationsMap().containsKey(point.getObjectId()))
|
||||
throw new RuntimeException();
|
||||
return;
|
||||
|
||||
AtomicAllocator.getInstance().getFlowController().waitTillReleased(point);
|
||||
|
||||
|
|
|
@ -544,6 +544,9 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
|
|||
*/
|
||||
@Override
|
||||
public long address() {
|
||||
if (released)
|
||||
throw new IllegalStateException("You can't use DataBuffer once it was released");
|
||||
|
||||
return allocationPoint.getPointers().getHostPointer().address();
|
||||
}
|
||||
|
||||
|
@ -554,6 +557,9 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
|
|||
|
||||
@Override
|
||||
public Pointer pointer() {
|
||||
if (released)
|
||||
throw new IllegalStateException("You can't use DataBuffer once it was released");
|
||||
|
||||
// FIXME: very bad thing,
|
||||
lazyAllocateHostPointer();
|
||||
|
||||
|
@ -1109,6 +1115,9 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
|
|||
|
||||
@Override
|
||||
public Pointer addressPointer() {
|
||||
if (released)
|
||||
throw new IllegalStateException("You can't use DataBuffer once it was released");
|
||||
|
||||
return AtomicAllocator.getInstance().getHostPointer(this);
|
||||
}
|
||||
|
||||
|
@ -1562,7 +1571,10 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
|
|||
|
||||
@Override
|
||||
protected void release() {
|
||||
if (!released) {
|
||||
AtomicAllocator.getInstance().freeMemory(allocationPoint);
|
||||
allocationPoint.setReleased(true);
|
||||
}
|
||||
released = true;
|
||||
}
|
||||
|
||||
|
|
|
@ -616,14 +616,23 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
|||
*/
|
||||
@Override
|
||||
public Indexer indexer() {
|
||||
if (released)
|
||||
throw new IllegalStateException("You can't use DataBuffer once it was released");
|
||||
|
||||
return indexer;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Pointer pointer() {
|
||||
if (underlyingDataBuffer() != null && underlyingDataBuffer() != this)
|
||||
if (released)
|
||||
throw new IllegalStateException("You can't use DataBuffer once it was released");
|
||||
|
||||
if (underlyingDataBuffer() != null && underlyingDataBuffer() != this) {
|
||||
if (underlyingDataBuffer().wasClosed())
|
||||
throw new IllegalStateException("You can't use DataBuffer once it was released");
|
||||
|
||||
return underlyingDataBuffer().pointer();
|
||||
else {
|
||||
} else {
|
||||
if (underlyingDataBuffer() != null)
|
||||
if (((BaseDataBuffer) underlyingDataBuffer()).released)
|
||||
throw new IllegalStateException("Underlying buffer was released via close() call");
|
||||
|
@ -923,6 +932,8 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
|||
|
||||
@Override
|
||||
public Pointer addressPointer() {
|
||||
if (released)
|
||||
throw new IllegalStateException("You can't use DataBuffer once it was released");
|
||||
|
||||
if (offset() > 0) {
|
||||
Pointer ret;
|
||||
|
@ -968,6 +979,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
|||
|
||||
@Override
|
||||
public long address() {
|
||||
if (released)
|
||||
throw new IllegalStateException("You can't use DataBuffer once it was released");
|
||||
|
||||
return pointer().address() + getElementSize() * offset();
|
||||
}
|
||||
|
||||
|
@ -1409,6 +1423,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
|||
|
||||
@Override
|
||||
public double getDouble(long i) {
|
||||
if (released)
|
||||
throw new IllegalStateException("You can't use DataBuffer once it was released");
|
||||
|
||||
if (indexer == null) {
|
||||
throw new IllegalStateException("Indexer must never be null");
|
||||
}
|
||||
|
@ -1444,6 +1461,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
|||
|
||||
@Override
|
||||
public long getLong(long i) {
|
||||
if (released)
|
||||
throw new IllegalStateException("You can't use DataBuffer once it was released");
|
||||
|
||||
switch (dataType()) {
|
||||
case FLOAT:
|
||||
return (long) ((FloatIndexer) indexer).get(offset() + i);
|
||||
|
@ -1480,6 +1500,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
|||
* @return
|
||||
*/
|
||||
protected short getShort(long i) {
|
||||
if (released)
|
||||
throw new IllegalStateException("You can't use DataBuffer once it was released");
|
||||
|
||||
switch (dataType()) {
|
||||
case DOUBLE:
|
||||
return (short) ((DoubleIndexer) indexer).get(offset() + i);
|
||||
|
@ -1518,6 +1541,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
|||
|
||||
@Override
|
||||
public float getFloat(long i) {
|
||||
if (released)
|
||||
throw new IllegalStateException("You can't use DataBuffer once it was released");
|
||||
|
||||
switch (dataType()) {
|
||||
case DOUBLE:
|
||||
return (float) ((DoubleIndexer) indexer).get(offset() + i);
|
||||
|
@ -1550,6 +1576,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
|||
|
||||
@Override
|
||||
public int getInt(long i) {
|
||||
if (released)
|
||||
throw new IllegalStateException("You can't use DataBuffer once it was released");
|
||||
|
||||
switch (dataType()) {
|
||||
case DOUBLE:
|
||||
return (int) ((DoubleIndexer) indexer).get(offset() + i);
|
||||
|
@ -1582,6 +1611,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
|||
|
||||
@Override
|
||||
public Number getNumber(long i) {
|
||||
if (released)
|
||||
throw new IllegalStateException("You can't use DataBuffer once it was released");
|
||||
|
||||
if (dataType() == DataType.DOUBLE)
|
||||
return getDouble(i);
|
||||
else if (dataType() == DataType.INT)
|
||||
|
@ -1685,6 +1717,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
|||
|
||||
@Override
|
||||
public void put(long i, float element) {
|
||||
if (released)
|
||||
throw new IllegalStateException("You can't use DataBuffer once it was released");
|
||||
|
||||
switch (dataType()) {
|
||||
case BOOL:
|
||||
((BooleanIndexer) indexer).put(offset() + i, element == 0.0 ? false : true);
|
||||
|
@ -1732,6 +1767,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
|||
|
||||
@Override
|
||||
public void put(long i, double element) {
|
||||
if (released)
|
||||
throw new IllegalStateException("You can't use DataBuffer once it was released");
|
||||
|
||||
switch (dataType()) {
|
||||
case BOOL:
|
||||
((BooleanIndexer) indexer).put(offset() + i, element > 0.0);
|
||||
|
@ -1779,6 +1817,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
|||
|
||||
@Override
|
||||
public void put(long i, int element) {
|
||||
if (released)
|
||||
throw new IllegalStateException("You can't use DataBuffer once it was released");
|
||||
|
||||
switch (dataType()) {
|
||||
case BOOL:
|
||||
((BooleanIndexer) indexer).put(offset() + i, element == 0 ? false : true);
|
||||
|
@ -1826,6 +1867,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
|||
|
||||
@Override
|
||||
public void put(long i, boolean element) {
|
||||
if (released)
|
||||
throw new IllegalStateException("You can't use DataBuffer once it was released");
|
||||
|
||||
switch (dataType()) {
|
||||
case BOOL:
|
||||
((BooleanIndexer) indexer).put(offset() + i, element);
|
||||
|
@ -1873,6 +1917,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
|||
|
||||
@Override
|
||||
public void put(long i, long element) {
|
||||
if (released)
|
||||
throw new IllegalStateException("You can't use DataBuffer once it was released");
|
||||
|
||||
switch (dataType()) {
|
||||
case BOOL:
|
||||
((BooleanIndexer) indexer).put(offset() + i, element == 0 ? false : true);
|
||||
|
@ -2667,4 +2714,10 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
|||
public long platformAddress() {
|
||||
return address();
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public boolean wasClosed() {
|
||||
return released;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -703,4 +703,10 @@ public interface DataBuffer extends Serializable, AutoCloseable {
|
|||
* PLEASE NOTE: This method is NOT safe by any means
|
||||
*/
|
||||
void close();
|
||||
|
||||
/**
|
||||
* This method checks if array or its buffer was closed before
|
||||
* @return true if was closed, false otherwise
|
||||
*/
|
||||
boolean wasClosed();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue