allocation of buffers backed by workspaces with 1 method call instead of 3 now (#420)

Signed-off-by: raver119@gmail.com <raver119@gmail.com>
master
raver119 2020-04-28 20:38:16 +03:00 committed by GitHub
parent dad6bc5ed2
commit bc8a9d1996
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 57 additions and 26 deletions

View File

@ -1637,6 +1637,8 @@ ND4J_EXPORT Nd4jPointer lcBlasHandle(OpaqueLaunchContext* lc);
ND4J_EXPORT Nd4jPointer lcSolverHandle(OpaqueLaunchContext* lc);
ND4J_EXPORT OpaqueDataBuffer* allocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth);
ND4J_EXPORT OpaqueDataBuffer* dbAllocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth);
ND4J_EXPORT OpaqueDataBuffer* dbCreateExternalDataBuffer(Nd4jLong elements, int dataType, Nd4jPointer primary, Nd4jPointer special);
ND4J_EXPORT OpaqueDataBuffer* dbCreateView(OpaqueDataBuffer *dataBuffer, Nd4jLong length, Nd4jLong offset);
ND4J_EXPORT Nd4jPointer dbPrimaryBuffer(OpaqueDataBuffer *dataBuffer);
ND4J_EXPORT Nd4jPointer dbSpecialBuffer(OpaqueDataBuffer *dataBuffer);

View File

@ -3115,6 +3115,10 @@ bool isOptimalRequirementsMet() {
#endif
}
OpaqueDataBuffer* dbAllocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth) {
return allocateDataBuffer(elements, dataType, allocateBoth);
}
OpaqueDataBuffer* allocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth) {
try {
auto dtype = DataTypeUtils::fromInt(dataType);
@ -3138,6 +3142,18 @@ void deleteDataBuffer(OpaqueDataBuffer *dataBuffer) {
delete dataBuffer;
}
OpaqueDataBuffer* dbCreateExternalDataBuffer(Nd4jLong elements, int dataType, Nd4jPointer primary, Nd4jPointer special) {
auto buffer = dbAllocateDataBuffer(0, dataType, false);
if (primary != nullptr)
buffer->setPrimary(primary, elements);
if (special != nullptr)
buffer->setSpecial(special, elements);
return buffer;
}
void dbSetPrimaryBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer primaryBuffer, Nd4jLong numBytes) {
dataBuffer->setPrimary(primaryBuffer, numBytes);
}

View File

@ -3802,6 +3802,22 @@ void ctxSetExecutionMode(OpaqueContext* ptr, int execMode) {
ptr->setExecutionMode((samediff::ExecutionMode) execMode);
}
OpaqueDataBuffer* dbCreateExternalDataBuffer(Nd4jLong elements, int dataType, Nd4jPointer primary, Nd4jPointer special) {
auto buffer = dbAllocateDataBuffer(0, dataType, false);
if (primary != nullptr)
buffer->setPrimary(primary, elements);
if (special != nullptr)
buffer->setSpecial(special, elements);
return buffer;
}
OpaqueDataBuffer* dbAllocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth) {
return allocateDataBuffer(elements, dataType, allocateBoth);
}
OpaqueDataBuffer* allocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth) {
try {
auto dtype = DataTypeUtils::fromInt(dataType);

View File

@ -1201,6 +1201,8 @@ public interface NativeOps {
OpaqueDataBuffer allocateDataBuffer(long elements, int dataType, boolean allocateBoth);
OpaqueDataBuffer dbAllocateDataBuffer(long elements, int dataType, boolean allocateBoth);
OpaqueDataBuffer dbCreateExternalDataBuffer(long elements, int dataType, Pointer primary, Pointer special);
OpaqueDataBuffer dbCreateView(OpaqueDataBuffer dataBuffer, long length, long offset);
Pointer dbPrimaryBuffer(OpaqueDataBuffer dataBuffer);
Pointer dbSpecialBuffer(OpaqueDataBuffer dataBuffer);

View File

@ -35,6 +35,11 @@ public class OpaqueDataBuffer extends Pointer {
public OpaqueDataBuffer(Pointer p) { super(p); }
public static OpaqueDataBuffer externalizedDataBuffer(long numElements, @NonNull DataType dataType, Pointer primary, Pointer special) {
return NativeOpsHolder.getInstance().getDeviceNativeOps().dbCreateExternalDataBuffer(numElements, dataType.toInt(), primary, special);
}
/**
* This method allocates new InteropDataBuffer and returns pointer to it
* @param numElements

View File

@ -101,9 +101,8 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
initTypeAndSize();
ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(0, this.type, false);
ptrDataBuffer = OpaqueDataBuffer.externalizedDataBuffer(length, this.type, pointer, specialPointer);
this.allocationPoint = new AllocationPoint(ptrDataBuffer, this.type.width() * length);
this.allocationPoint.setPointers(pointer, specialPointer, length);
Nd4j.getDeallocatorService().pickObject(this);
}
@ -411,14 +410,11 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
this.offset = 0;
this.originalOffset = 0;
// allocating empty databuffer
ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(0, type, false);
if (workspace.getWorkspaceConfiguration().getPolicyMirroring() == MirroringPolicy.FULL) {
val devicePtr = workspace.alloc(length * elementSize, MemoryKind.DEVICE, type, initialize);
// allocate from workspace, and pass it to native DataBuffer
ptrDataBuffer.setSpecialBuffer(devicePtr, this.length);
ptrDataBuffer = OpaqueDataBuffer.externalizedDataBuffer(this.length, type, null, devicePtr);
if (initialize) {
val ctx = AtomicAllocator.getInstance().getDeviceContext();
@ -428,7 +424,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
} else {
// we can register this pointer as device, because it's pinned memory
val devicePtr = workspace.alloc(length * elementSize, MemoryKind.HOST, type, initialize);
ptrDataBuffer.setSpecialBuffer(devicePtr, this.length);
ptrDataBuffer = OpaqueDataBuffer.externalizedDataBuffer(this.length, type, null, devicePtr);
if (initialize) {
val ctx = AtomicAllocator.getInstance().getDeviceContext();

View File

@ -29,6 +29,7 @@ import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.nd4j.nativeblas.OpaqueDataBuffer;
import java.nio.ByteBuffer;
@ -73,9 +74,7 @@ public class CudaLongDataBuffer extends BaseCudaDataBuffer {
initTypeAndSize();
// creating empty native DataBuffer and filling it with pointers
ptrDataBuffer = NativeOpsHolder.getInstance().getDeviceNativeOps().allocateDataBuffer(0, DataType.INT64.toInt(), false);
NativeOpsHolder.getInstance().getDeviceNativeOps().dbSetPrimaryBuffer(ptrDataBuffer, hostPointer, numberOfElements);
NativeOpsHolder.getInstance().getDeviceNativeOps().dbSetSpecialBuffer(ptrDataBuffer, devicePointer, numberOfElements);
ptrDataBuffer = OpaqueDataBuffer.externalizedDataBuffer(numberOfElements, DataType.INT64, hostPointer, devicePointer);
// setting up java side of things
this.pointer = new CudaPointer(hostPointer, numberOfElements).asLongPointer();

View File

@ -3139,6 +3139,8 @@ public native @Cast("Nd4jPointer") Pointer lcBlasHandle(OpaqueLaunchContext lc);
public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc);
public native OpaqueDataBuffer allocateDataBuffer(@Cast("Nd4jLong") long elements, int dataType, @Cast("bool") boolean allocateBoth);
public native OpaqueDataBuffer dbAllocateDataBuffer(@Cast("Nd4jLong") long elements, int dataType, @Cast("bool") boolean allocateBoth);
public native OpaqueDataBuffer dbCreateExternalDataBuffer(@Cast("Nd4jLong") long elements, int dataType, @Cast("Nd4jPointer") Pointer primary, @Cast("Nd4jPointer") Pointer special);
public native OpaqueDataBuffer dbCreateView(OpaqueDataBuffer dataBuffer, @Cast("Nd4jLong") long length, @Cast("Nd4jLong") long offset);
public native @Cast("Nd4jPointer") Pointer dbPrimaryBuffer(OpaqueDataBuffer dataBuffer);
public native @Cast("Nd4jPointer") Pointer dbSpecialBuffer(OpaqueDataBuffer dataBuffer);

View File

@ -460,9 +460,6 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
if (length < 0)
throw new IllegalArgumentException("Unable to create a buffer of length <= 0");
// creating empty native DataBuffer
ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(0, dataType(), false);
if (dataType() == DataType.DOUBLE) {
attached = true;
parentWorkspace = workspace;
@ -559,7 +556,7 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
}
// storing pointer into native DataBuffer
ptrDataBuffer.setPrimaryBuffer(pointer, length);
ptrDataBuffer = OpaqueDataBuffer.externalizedDataBuffer(length, dataType(), this.pointer, null);
// adding deallocator reference
Nd4j.getDeallocatorService().pickObject(this);
@ -570,8 +567,7 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
public BaseCpuDataBuffer(Pointer pointer, Indexer indexer, long length) {
super(pointer, indexer, length);
ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(0, type, false);
ptrDataBuffer.setPrimaryBuffer(this.pointer, length);
ptrDataBuffer = OpaqueDataBuffer.externalizedDataBuffer(length, dataType(), this.pointer, null);
Nd4j.getDeallocatorService().pickObject(this);;
}
@ -633,8 +629,7 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
pointer = workspace.alloc(data.length * getElementSize(), dataType(), false).asFloatPointer().put(data);
this.ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(0, dataType(), false);
this.ptrDataBuffer.setPrimaryBuffer(pointer, this.length);
ptrDataBuffer = OpaqueDataBuffer.externalizedDataBuffer(length, dataType(), this.pointer, null);
Nd4j.getDeallocatorService().pickObject(this);
workspaceGenerationId = workspace.getGenerationId();
@ -655,8 +650,7 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
pointer = workspace.alloc(data.length * getElementSize(), dataType(), false).asDoublePointer().put(data);
this.ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(0, dataType(), false);
this.ptrDataBuffer.setPrimaryBuffer(pointer, this.length);
ptrDataBuffer = OpaqueDataBuffer.externalizedDataBuffer(length, dataType(), this.pointer, null);
Nd4j.getDeallocatorService().pickObject(this);
workspaceGenerationId = workspace.getGenerationId();
@ -678,8 +672,7 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
pointer = workspace.alloc(data.length * getElementSize(), dataType(), false).asIntPointer().put(data);
this.ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(0, dataType(), false);
this.ptrDataBuffer.setPrimaryBuffer(pointer, this.length);
ptrDataBuffer = OpaqueDataBuffer.externalizedDataBuffer(length, dataType(), this.pointer, null);
Nd4j.getDeallocatorService().pickObject(this);
workspaceGenerationId = workspace.getGenerationId();
@ -700,8 +693,7 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo
pointer = workspace.alloc(data.length * getElementSize(), dataType(), false).asLongPointer().put(data);
this.ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(0, dataType(), false);
this.ptrDataBuffer.setPrimaryBuffer(pointer, this.length);
ptrDataBuffer = OpaqueDataBuffer.externalizedDataBuffer(length, dataType(), this.pointer, null);
Nd4j.getDeallocatorService().pickObject(this);
workspaceGenerationId = workspace.getGenerationId();

View File

@ -124,8 +124,7 @@ public class LongBuffer extends BaseCpuDataBuffer {
// we still want this buffer to have native representation
ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(0, DataType.INT64, false);
NativeOpsHolder.getInstance().getDeviceNativeOps().dbSetPrimaryBuffer(ptrDataBuffer, this.pointer, numberOfElements);
ptrDataBuffer = OpaqueDataBuffer.externalizedDataBuffer(numberOfElements, DataType.INT64, this.pointer, null);
Nd4j.getDeallocatorService().pickObject(this);
}

View File

@ -3143,6 +3143,8 @@ public native @Cast("Nd4jPointer") Pointer lcBlasHandle(OpaqueLaunchContext lc);
public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc);
public native OpaqueDataBuffer allocateDataBuffer(@Cast("Nd4jLong") long elements, int dataType, @Cast("bool") boolean allocateBoth);
public native OpaqueDataBuffer dbAllocateDataBuffer(@Cast("Nd4jLong") long elements, int dataType, @Cast("bool") boolean allocateBoth);
public native OpaqueDataBuffer dbCreateExternalDataBuffer(@Cast("Nd4jLong") long elements, int dataType, @Cast("Nd4jPointer") Pointer primary, @Cast("Nd4jPointer") Pointer special);
public native OpaqueDataBuffer dbCreateView(OpaqueDataBuffer dataBuffer, @Cast("Nd4jLong") long length, @Cast("Nd4jLong") long offset);
public native @Cast("Nd4jPointer") Pointer dbPrimaryBuffer(OpaqueDataBuffer dataBuffer);
public native @Cast("Nd4jPointer") Pointer dbSpecialBuffer(OpaqueDataBuffer dataBuffer);