parent
243bf866c4
commit
dbea687903
|
@ -6697,6 +6697,8 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
||||||
if (released || isEmpty())
|
if (released || isEmpty())
|
||||||
return;
|
return;
|
||||||
|
|
||||||
|
Nd4j.getExecutioner().commit();
|
||||||
|
|
||||||
if (!closeable())
|
if (!closeable())
|
||||||
throw new ND4JIllegalStateException("Can't release this INDArray");
|
throw new ND4JIllegalStateException("Can't release this INDArray");
|
||||||
|
|
||||||
|
@ -6715,5 +6717,11 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
||||||
return Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering());
|
return Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean wasClosed() {
|
||||||
|
if (released || data().wasClosed())
|
||||||
|
return true;
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1218,6 +1218,11 @@ public class BaseSparseNDArrayCOO extends BaseSparseNDArray {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean wasClosed() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean isS() {
|
public boolean isS() {
|
||||||
return false;
|
return false;
|
||||||
|
|
|
@ -263,6 +263,12 @@ public abstract class BaseSparseNDArrayCSR extends BaseSparseNDArray {
|
||||||
return false; //todo
|
return false; //todo
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean wasClosed() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int underlyingRank() {
|
public int underlyingRank() {
|
||||||
return rank;
|
return rank;
|
||||||
|
|
|
@ -2831,6 +2831,12 @@ public interface INDArray extends Serializable, AutoCloseable {
|
||||||
*/
|
*/
|
||||||
void close();
|
void close();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method checks if array or its buffer was closed before
|
||||||
|
* @return true if was closed, false otherwise
|
||||||
|
*/
|
||||||
|
boolean wasClosed();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method returns empty array with the same dtype/order/shape as this one
|
* This method returns empty array with the same dtype/order/shape as this one
|
||||||
* @return
|
* @return
|
||||||
|
|
|
@ -455,6 +455,14 @@ public class DefaultOpExecutioner implements OpExecutioner {
|
||||||
|
|
||||||
|
|
||||||
public long profilingConfigurableHookIn(CustomOp op) {
|
public long profilingConfigurableHookIn(CustomOp op) {
|
||||||
|
for (val arr: op.inputArguments())
|
||||||
|
if (arr.wasClosed())
|
||||||
|
throw new IllegalStateException("One of Input arguments was closed before call");
|
||||||
|
|
||||||
|
for (val arr: op.outputArguments())
|
||||||
|
if (arr.wasClosed())
|
||||||
|
throw new IllegalStateException("One of Output arguments was closed before call");
|
||||||
|
|
||||||
if (OpProfiler.getInstance().getConfig() == null)
|
if (OpProfiler.getInstance().getConfig() == null)
|
||||||
return System.nanoTime();
|
return System.nanoTime();
|
||||||
|
|
||||||
|
@ -471,6 +479,18 @@ public class DefaultOpExecutioner implements OpExecutioner {
|
||||||
}
|
}
|
||||||
|
|
||||||
public long profilingConfigurableHookIn(Op op, DataBuffer... tadBuffers) {
|
public long profilingConfigurableHookIn(Op op, DataBuffer... tadBuffers) {
|
||||||
|
if (op.x() != null)
|
||||||
|
if (op.x().wasClosed())
|
||||||
|
throw new IllegalStateException("Op.X argument was closed before call");
|
||||||
|
|
||||||
|
if (op.y() != null)
|
||||||
|
if (op.y().wasClosed())
|
||||||
|
throw new IllegalStateException("Op.Y argument was closed before call");
|
||||||
|
|
||||||
|
if (op.z() != null)
|
||||||
|
if (op.z().wasClosed())
|
||||||
|
throw new IllegalStateException("Op.Z argument was closed before call");
|
||||||
|
|
||||||
if (OpProfiler.getInstance().getConfig() == null)
|
if (OpProfiler.getInstance().getConfig() == null)
|
||||||
return System.nanoTime();
|
return System.nanoTime();
|
||||||
|
|
||||||
|
|
|
@ -69,6 +69,10 @@ public class AllocationPoint {
|
||||||
@Setter
|
@Setter
|
||||||
private boolean isAttached = false;
|
private boolean isAttached = false;
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
@Setter
|
||||||
|
private volatile boolean released = false;
|
||||||
|
|
||||||
// thread safety is guaranteed by allocLock
|
// thread safety is guaranteed by allocLock
|
||||||
private AllocationStatus allocationStatus = AllocationStatus.UNDEFINED;
|
private AllocationStatus allocationStatus = AllocationStatus.UNDEFINED;
|
||||||
|
|
||||||
|
|
|
@ -400,14 +400,19 @@ public class AtomicAllocator implements Allocator {
|
||||||
public void freeMemory(AllocationPoint point) {
|
public void freeMemory(AllocationPoint point) {
|
||||||
if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
|
if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
|
||||||
this.getMemoryHandler().getMemoryProvider().free(point);
|
this.getMemoryHandler().getMemoryProvider().free(point);
|
||||||
|
|
||||||
|
if (point.getHostPointer() != null) {
|
||||||
point.setAllocationStatus(AllocationStatus.HOST);
|
point.setAllocationStatus(AllocationStatus.HOST);
|
||||||
this.getMemoryHandler().getMemoryProvider().free(point);
|
this.getMemoryHandler().getMemoryProvider().free(point);
|
||||||
this.getMemoryHandler().forget(point, AllocationStatus.DEVICE);
|
this.getMemoryHandler().forget(point, AllocationStatus.DEVICE);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// call it only once
|
// call it only once
|
||||||
|
if (point.getHostPointer() != null) {
|
||||||
this.getMemoryHandler().getMemoryProvider().free(point);
|
this.getMemoryHandler().getMemoryProvider().free(point);
|
||||||
this.getMemoryHandler().forget(point, AllocationStatus.HOST);
|
this.getMemoryHandler().forget(point, AllocationStatus.HOST);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
allocationsMap.remove(point.getObjectId());
|
allocationsMap.remove(point.getObjectId());
|
||||||
}
|
}
|
||||||
|
|
|
@ -39,10 +39,10 @@ public class CudaDeallocator implements Deallocator {
|
||||||
public void deallocate() {
|
public void deallocate() {
|
||||||
log.trace("Deallocating CUDA memory");
|
log.trace("Deallocating CUDA memory");
|
||||||
// skipping any allocation that is coming from workspace
|
// skipping any allocation that is coming from workspace
|
||||||
if (point.isAttached()) {
|
if (point.isAttached() || point.isReleased()) {
|
||||||
// TODO: remove allocation point as well?
|
// TODO: remove allocation point as well?
|
||||||
if (!AtomicAllocator.getInstance().allocationsMap().containsKey(point.getObjectId()))
|
if (!AtomicAllocator.getInstance().allocationsMap().containsKey(point.getObjectId()))
|
||||||
throw new RuntimeException();
|
return;
|
||||||
|
|
||||||
AtomicAllocator.getInstance().getFlowController().waitTillReleased(point);
|
AtomicAllocator.getInstance().getFlowController().waitTillReleased(point);
|
||||||
|
|
||||||
|
|
|
@ -544,6 +544,9 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public long address() {
|
public long address() {
|
||||||
|
if (released)
|
||||||
|
throw new IllegalStateException("You can't use DataBuffer once it was released");
|
||||||
|
|
||||||
return allocationPoint.getPointers().getHostPointer().address();
|
return allocationPoint.getPointers().getHostPointer().address();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -554,6 +557,9 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Pointer pointer() {
|
public Pointer pointer() {
|
||||||
|
if (released)
|
||||||
|
throw new IllegalStateException("You can't use DataBuffer once it was released");
|
||||||
|
|
||||||
// FIXME: very bad thing,
|
// FIXME: very bad thing,
|
||||||
lazyAllocateHostPointer();
|
lazyAllocateHostPointer();
|
||||||
|
|
||||||
|
@ -1109,6 +1115,9 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Pointer addressPointer() {
|
public Pointer addressPointer() {
|
||||||
|
if (released)
|
||||||
|
throw new IllegalStateException("You can't use DataBuffer once it was released");
|
||||||
|
|
||||||
return AtomicAllocator.getInstance().getHostPointer(this);
|
return AtomicAllocator.getInstance().getHostPointer(this);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1562,7 +1571,10 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void release() {
|
protected void release() {
|
||||||
|
if (!released) {
|
||||||
AtomicAllocator.getInstance().freeMemory(allocationPoint);
|
AtomicAllocator.getInstance().freeMemory(allocationPoint);
|
||||||
|
allocationPoint.setReleased(true);
|
||||||
|
}
|
||||||
released = true;
|
released = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -616,14 +616,23 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public Indexer indexer() {
|
public Indexer indexer() {
|
||||||
|
if (released)
|
||||||
|
throw new IllegalStateException("You can't use DataBuffer once it was released");
|
||||||
|
|
||||||
return indexer;
|
return indexer;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Pointer pointer() {
|
public Pointer pointer() {
|
||||||
if (underlyingDataBuffer() != null && underlyingDataBuffer() != this)
|
if (released)
|
||||||
|
throw new IllegalStateException("You can't use DataBuffer once it was released");
|
||||||
|
|
||||||
|
if (underlyingDataBuffer() != null && underlyingDataBuffer() != this) {
|
||||||
|
if (underlyingDataBuffer().wasClosed())
|
||||||
|
throw new IllegalStateException("You can't use DataBuffer once it was released");
|
||||||
|
|
||||||
return underlyingDataBuffer().pointer();
|
return underlyingDataBuffer().pointer();
|
||||||
else {
|
} else {
|
||||||
if (underlyingDataBuffer() != null)
|
if (underlyingDataBuffer() != null)
|
||||||
if (((BaseDataBuffer) underlyingDataBuffer()).released)
|
if (((BaseDataBuffer) underlyingDataBuffer()).released)
|
||||||
throw new IllegalStateException("Underlying buffer was released via close() call");
|
throw new IllegalStateException("Underlying buffer was released via close() call");
|
||||||
|
@ -923,6 +932,8 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Pointer addressPointer() {
|
public Pointer addressPointer() {
|
||||||
|
if (released)
|
||||||
|
throw new IllegalStateException("You can't use DataBuffer once it was released");
|
||||||
|
|
||||||
if (offset() > 0) {
|
if (offset() > 0) {
|
||||||
Pointer ret;
|
Pointer ret;
|
||||||
|
@ -968,6 +979,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long address() {
|
public long address() {
|
||||||
|
if (released)
|
||||||
|
throw new IllegalStateException("You can't use DataBuffer once it was released");
|
||||||
|
|
||||||
return pointer().address() + getElementSize() * offset();
|
return pointer().address() + getElementSize() * offset();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1409,6 +1423,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public double getDouble(long i) {
|
public double getDouble(long i) {
|
||||||
|
if (released)
|
||||||
|
throw new IllegalStateException("You can't use DataBuffer once it was released");
|
||||||
|
|
||||||
if (indexer == null) {
|
if (indexer == null) {
|
||||||
throw new IllegalStateException("Indexer must never be null");
|
throw new IllegalStateException("Indexer must never be null");
|
||||||
}
|
}
|
||||||
|
@ -1444,6 +1461,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long getLong(long i) {
|
public long getLong(long i) {
|
||||||
|
if (released)
|
||||||
|
throw new IllegalStateException("You can't use DataBuffer once it was released");
|
||||||
|
|
||||||
switch (dataType()) {
|
switch (dataType()) {
|
||||||
case FLOAT:
|
case FLOAT:
|
||||||
return (long) ((FloatIndexer) indexer).get(offset() + i);
|
return (long) ((FloatIndexer) indexer).get(offset() + i);
|
||||||
|
@ -1480,6 +1500,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
protected short getShort(long i) {
|
protected short getShort(long i) {
|
||||||
|
if (released)
|
||||||
|
throw new IllegalStateException("You can't use DataBuffer once it was released");
|
||||||
|
|
||||||
switch (dataType()) {
|
switch (dataType()) {
|
||||||
case DOUBLE:
|
case DOUBLE:
|
||||||
return (short) ((DoubleIndexer) indexer).get(offset() + i);
|
return (short) ((DoubleIndexer) indexer).get(offset() + i);
|
||||||
|
@ -1518,6 +1541,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public float getFloat(long i) {
|
public float getFloat(long i) {
|
||||||
|
if (released)
|
||||||
|
throw new IllegalStateException("You can't use DataBuffer once it was released");
|
||||||
|
|
||||||
switch (dataType()) {
|
switch (dataType()) {
|
||||||
case DOUBLE:
|
case DOUBLE:
|
||||||
return (float) ((DoubleIndexer) indexer).get(offset() + i);
|
return (float) ((DoubleIndexer) indexer).get(offset() + i);
|
||||||
|
@ -1550,6 +1576,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int getInt(long i) {
|
public int getInt(long i) {
|
||||||
|
if (released)
|
||||||
|
throw new IllegalStateException("You can't use DataBuffer once it was released");
|
||||||
|
|
||||||
switch (dataType()) {
|
switch (dataType()) {
|
||||||
case DOUBLE:
|
case DOUBLE:
|
||||||
return (int) ((DoubleIndexer) indexer).get(offset() + i);
|
return (int) ((DoubleIndexer) indexer).get(offset() + i);
|
||||||
|
@ -1582,6 +1611,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Number getNumber(long i) {
|
public Number getNumber(long i) {
|
||||||
|
if (released)
|
||||||
|
throw new IllegalStateException("You can't use DataBuffer once it was released");
|
||||||
|
|
||||||
if (dataType() == DataType.DOUBLE)
|
if (dataType() == DataType.DOUBLE)
|
||||||
return getDouble(i);
|
return getDouble(i);
|
||||||
else if (dataType() == DataType.INT)
|
else if (dataType() == DataType.INT)
|
||||||
|
@ -1685,6 +1717,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void put(long i, float element) {
|
public void put(long i, float element) {
|
||||||
|
if (released)
|
||||||
|
throw new IllegalStateException("You can't use DataBuffer once it was released");
|
||||||
|
|
||||||
switch (dataType()) {
|
switch (dataType()) {
|
||||||
case BOOL:
|
case BOOL:
|
||||||
((BooleanIndexer) indexer).put(offset() + i, element == 0.0 ? false : true);
|
((BooleanIndexer) indexer).put(offset() + i, element == 0.0 ? false : true);
|
||||||
|
@ -1732,6 +1767,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void put(long i, double element) {
|
public void put(long i, double element) {
|
||||||
|
if (released)
|
||||||
|
throw new IllegalStateException("You can't use DataBuffer once it was released");
|
||||||
|
|
||||||
switch (dataType()) {
|
switch (dataType()) {
|
||||||
case BOOL:
|
case BOOL:
|
||||||
((BooleanIndexer) indexer).put(offset() + i, element > 0.0);
|
((BooleanIndexer) indexer).put(offset() + i, element > 0.0);
|
||||||
|
@ -1779,6 +1817,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void put(long i, int element) {
|
public void put(long i, int element) {
|
||||||
|
if (released)
|
||||||
|
throw new IllegalStateException("You can't use DataBuffer once it was released");
|
||||||
|
|
||||||
switch (dataType()) {
|
switch (dataType()) {
|
||||||
case BOOL:
|
case BOOL:
|
||||||
((BooleanIndexer) indexer).put(offset() + i, element == 0 ? false : true);
|
((BooleanIndexer) indexer).put(offset() + i, element == 0 ? false : true);
|
||||||
|
@ -1826,6 +1867,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void put(long i, boolean element) {
|
public void put(long i, boolean element) {
|
||||||
|
if (released)
|
||||||
|
throw new IllegalStateException("You can't use DataBuffer once it was released");
|
||||||
|
|
||||||
switch (dataType()) {
|
switch (dataType()) {
|
||||||
case BOOL:
|
case BOOL:
|
||||||
((BooleanIndexer) indexer).put(offset() + i, element);
|
((BooleanIndexer) indexer).put(offset() + i, element);
|
||||||
|
@ -1873,6 +1917,9 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void put(long i, long element) {
|
public void put(long i, long element) {
|
||||||
|
if (released)
|
||||||
|
throw new IllegalStateException("You can't use DataBuffer once it was released");
|
||||||
|
|
||||||
switch (dataType()) {
|
switch (dataType()) {
|
||||||
case BOOL:
|
case BOOL:
|
||||||
((BooleanIndexer) indexer).put(offset() + i, element == 0 ? false : true);
|
((BooleanIndexer) indexer).put(offset() + i, element == 0 ? false : true);
|
||||||
|
@ -2667,4 +2714,10 @@ public abstract class BaseDataBuffer implements DataBuffer {
|
||||||
public long platformAddress() {
|
public long platformAddress() {
|
||||||
return address();
|
return address();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean wasClosed() {
|
||||||
|
return released;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -703,4 +703,10 @@ public interface DataBuffer extends Serializable, AutoCloseable {
|
||||||
* PLEASE NOTE: This method is NOT safe by any means
|
* PLEASE NOTE: This method is NOT safe by any means
|
||||||
*/
|
*/
|
||||||
void close();
|
void close();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method checks if array or its buffer was closed before
|
||||||
|
* @return true if was closed, false otherwise
|
||||||
|
*/
|
||||||
|
boolean wasClosed();
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue