allocation of buffers backed by workspaces with 1 method call instead of 3 now (#420)
Signed-off-by: raver119@gmail.com <raver119@gmail.com>
This commit is contained in:
		
							parent
							
								
									dad6bc5ed2
								
							
						
					
					
						commit
						bc8a9d1996
					
				| @ -1637,6 +1637,8 @@ ND4J_EXPORT Nd4jPointer lcBlasHandle(OpaqueLaunchContext* lc); | ||||
| ND4J_EXPORT Nd4jPointer lcSolverHandle(OpaqueLaunchContext* lc); | ||||
| 
 | ||||
| ND4J_EXPORT OpaqueDataBuffer* allocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth); | ||||
| ND4J_EXPORT OpaqueDataBuffer* dbAllocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth); | ||||
| ND4J_EXPORT OpaqueDataBuffer* dbCreateExternalDataBuffer(Nd4jLong elements, int dataType, Nd4jPointer primary, Nd4jPointer special); | ||||
| ND4J_EXPORT OpaqueDataBuffer* dbCreateView(OpaqueDataBuffer *dataBuffer, Nd4jLong length, Nd4jLong offset); | ||||
| ND4J_EXPORT Nd4jPointer dbPrimaryBuffer(OpaqueDataBuffer *dataBuffer); | ||||
| ND4J_EXPORT Nd4jPointer dbSpecialBuffer(OpaqueDataBuffer *dataBuffer); | ||||
|  | ||||
| @ -3115,6 +3115,10 @@ bool isOptimalRequirementsMet() { | ||||
| #endif | ||||
| } | ||||
| 
 | ||||
| OpaqueDataBuffer* dbAllocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth) { | ||||
|     return allocateDataBuffer(elements, dataType, allocateBoth); | ||||
| } | ||||
| 
 | ||||
| OpaqueDataBuffer* allocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth) { | ||||
|     try { | ||||
|         auto dtype = DataTypeUtils::fromInt(dataType); | ||||
| @ -3138,6 +3142,18 @@ void deleteDataBuffer(OpaqueDataBuffer *dataBuffer) { | ||||
|     delete dataBuffer; | ||||
| } | ||||
| 
 | ||||
| OpaqueDataBuffer* dbCreateExternalDataBuffer(Nd4jLong elements, int dataType, Nd4jPointer primary, Nd4jPointer special) { | ||||
|     auto buffer = dbAllocateDataBuffer(0, dataType, false); | ||||
| 
 | ||||
|     if (primary != nullptr) | ||||
|         buffer->setPrimary(primary, elements); | ||||
| 
 | ||||
|     if (special != nullptr) | ||||
|         buffer->setSpecial(special, elements); | ||||
| 
 | ||||
|     return buffer; | ||||
| } | ||||
| 
 | ||||
| void dbSetPrimaryBuffer(OpaqueDataBuffer *dataBuffer, Nd4jPointer primaryBuffer, Nd4jLong numBytes) { | ||||
|     dataBuffer->setPrimary(primaryBuffer, numBytes); | ||||
| } | ||||
|  | ||||
| @ -3802,6 +3802,22 @@ void ctxSetExecutionMode(OpaqueContext* ptr, int execMode) { | ||||
|     ptr->setExecutionMode((samediff::ExecutionMode) execMode); | ||||
| } | ||||
| 
 | ||||
| OpaqueDataBuffer* dbCreateExternalDataBuffer(Nd4jLong elements, int dataType, Nd4jPointer primary, Nd4jPointer special) { | ||||
|     auto buffer = dbAllocateDataBuffer(0, dataType, false); | ||||
| 
 | ||||
|     if (primary != nullptr) | ||||
|         buffer->setPrimary(primary, elements); | ||||
| 
 | ||||
|     if (special != nullptr) | ||||
|         buffer->setSpecial(special, elements); | ||||
| 
 | ||||
|     return buffer; | ||||
| } | ||||
| 
 | ||||
| OpaqueDataBuffer* dbAllocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth) { | ||||
|     return allocateDataBuffer(elements, dataType, allocateBoth); | ||||
| } | ||||
| 
 | ||||
| OpaqueDataBuffer* allocateDataBuffer(Nd4jLong elements, int dataType, bool allocateBoth) { | ||||
|     try { | ||||
|         auto dtype = DataTypeUtils::fromInt(dataType); | ||||
|  | ||||
| @ -1201,6 +1201,8 @@ public interface NativeOps { | ||||
| 
 | ||||
| 
 | ||||
|     OpaqueDataBuffer allocateDataBuffer(long elements, int dataType, boolean allocateBoth); | ||||
|     OpaqueDataBuffer dbAllocateDataBuffer(long elements, int dataType, boolean allocateBoth); | ||||
|     OpaqueDataBuffer dbCreateExternalDataBuffer(long elements, int dataType, Pointer primary, Pointer special); | ||||
|     OpaqueDataBuffer dbCreateView(OpaqueDataBuffer dataBuffer, long length, long offset); | ||||
|     Pointer dbPrimaryBuffer(OpaqueDataBuffer dataBuffer); | ||||
|     Pointer dbSpecialBuffer(OpaqueDataBuffer dataBuffer); | ||||
|  | ||||
| @ -35,6 +35,11 @@ public class OpaqueDataBuffer extends Pointer { | ||||
| 
 | ||||
|     public OpaqueDataBuffer(Pointer p) { super(p); } | ||||
| 
 | ||||
| 
 | ||||
|     public static OpaqueDataBuffer externalizedDataBuffer(long numElements, @NonNull DataType dataType, Pointer primary, Pointer special) { | ||||
|         return NativeOpsHolder.getInstance().getDeviceNativeOps().dbCreateExternalDataBuffer(numElements, dataType.toInt(), primary, special); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * This method allocates new InteropDataBuffer and returns pointer to it | ||||
|      * @param numElements | ||||
|  | ||||
| @ -101,9 +101,8 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda | ||||
| 
 | ||||
|         initTypeAndSize(); | ||||
| 
 | ||||
|         ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(0, this.type, false); | ||||
|         ptrDataBuffer = OpaqueDataBuffer.externalizedDataBuffer(length, this.type,  pointer, specialPointer); | ||||
|         this.allocationPoint = new AllocationPoint(ptrDataBuffer, this.type.width() * length); | ||||
|         this.allocationPoint.setPointers(pointer, specialPointer, length); | ||||
| 
 | ||||
|         Nd4j.getDeallocatorService().pickObject(this); | ||||
|     } | ||||
| @ -411,14 +410,11 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda | ||||
|         this.offset = 0; | ||||
|         this.originalOffset = 0; | ||||
| 
 | ||||
|         // allocating empty databuffer | ||||
|         ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(0, type, false); | ||||
| 
 | ||||
|         if (workspace.getWorkspaceConfiguration().getPolicyMirroring() == MirroringPolicy.FULL) { | ||||
|             val devicePtr = workspace.alloc(length * elementSize, MemoryKind.DEVICE, type, initialize); | ||||
| 
 | ||||
|             // allocate from workspace, and pass it  to native DataBuffer | ||||
|             ptrDataBuffer.setSpecialBuffer(devicePtr, this.length); | ||||
|             ptrDataBuffer = OpaqueDataBuffer.externalizedDataBuffer(this.length, type, null, devicePtr); | ||||
| 
 | ||||
|             if (initialize) { | ||||
|                 val ctx = AtomicAllocator.getInstance().getDeviceContext(); | ||||
| @ -428,7 +424,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda | ||||
|         }  else { | ||||
|             // we can register this pointer as device, because it's pinned memory | ||||
|             val devicePtr = workspace.alloc(length * elementSize, MemoryKind.HOST, type, initialize); | ||||
|             ptrDataBuffer.setSpecialBuffer(devicePtr, this.length); | ||||
|             ptrDataBuffer = OpaqueDataBuffer.externalizedDataBuffer(this.length, type, null, devicePtr); | ||||
| 
 | ||||
|             if (initialize) { | ||||
|                 val ctx = AtomicAllocator.getInstance().getDeviceContext(); | ||||
|  | ||||
| @ -29,6 +29,7 @@ import org.nd4j.linalg.api.buffer.DataBuffer; | ||||
| import org.nd4j.linalg.api.buffer.DataType; | ||||
| import org.nd4j.linalg.api.memory.MemoryWorkspace; | ||||
| import org.nd4j.nativeblas.NativeOpsHolder; | ||||
| import org.nd4j.nativeblas.OpaqueDataBuffer; | ||||
| 
 | ||||
| import java.nio.ByteBuffer; | ||||
| 
 | ||||
| @ -73,9 +74,7 @@ public class CudaLongDataBuffer extends BaseCudaDataBuffer { | ||||
|         initTypeAndSize(); | ||||
| 
 | ||||
|         // creating empty native DataBuffer and filling it with pointers | ||||
|         ptrDataBuffer = NativeOpsHolder.getInstance().getDeviceNativeOps().allocateDataBuffer(0, DataType.INT64.toInt(), false); | ||||
|         NativeOpsHolder.getInstance().getDeviceNativeOps().dbSetPrimaryBuffer(ptrDataBuffer, hostPointer, numberOfElements); | ||||
|         NativeOpsHolder.getInstance().getDeviceNativeOps().dbSetSpecialBuffer(ptrDataBuffer, devicePointer, numberOfElements); | ||||
|         ptrDataBuffer = OpaqueDataBuffer.externalizedDataBuffer(numberOfElements, DataType.INT64, hostPointer, devicePointer); | ||||
| 
 | ||||
|         // setting up java side of things | ||||
|         this.pointer = new CudaPointer(hostPointer, numberOfElements).asLongPointer(); | ||||
|  | ||||
| @ -3139,6 +3139,8 @@ public native @Cast("Nd4jPointer") Pointer lcBlasHandle(OpaqueLaunchContext lc); | ||||
| public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc); | ||||
| 
 | ||||
| public native OpaqueDataBuffer allocateDataBuffer(@Cast("Nd4jLong") long elements, int dataType, @Cast("bool") boolean allocateBoth); | ||||
| public native OpaqueDataBuffer dbAllocateDataBuffer(@Cast("Nd4jLong") long elements, int dataType, @Cast("bool") boolean allocateBoth); | ||||
| public native OpaqueDataBuffer dbCreateExternalDataBuffer(@Cast("Nd4jLong") long elements, int dataType, @Cast("Nd4jPointer") Pointer primary, @Cast("Nd4jPointer") Pointer special); | ||||
| public native OpaqueDataBuffer dbCreateView(OpaqueDataBuffer dataBuffer, @Cast("Nd4jLong") long length, @Cast("Nd4jLong") long offset); | ||||
| public native @Cast("Nd4jPointer") Pointer dbPrimaryBuffer(OpaqueDataBuffer dataBuffer); | ||||
| public native @Cast("Nd4jPointer") Pointer dbSpecialBuffer(OpaqueDataBuffer dataBuffer); | ||||
|  | ||||
| @ -460,9 +460,6 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo | ||||
|         if (length < 0) | ||||
|             throw new IllegalArgumentException("Unable to create a buffer of length <= 0"); | ||||
| 
 | ||||
|         // creating empty native DataBuffer | ||||
|         ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(0, dataType(), false); | ||||
| 
 | ||||
|         if (dataType() == DataType.DOUBLE) { | ||||
|             attached = true; | ||||
|             parentWorkspace = workspace; | ||||
| @ -559,7 +556,7 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo | ||||
|         } | ||||
| 
 | ||||
|         // storing pointer into native DataBuffer | ||||
|         ptrDataBuffer.setPrimaryBuffer(pointer, length); | ||||
|         ptrDataBuffer = OpaqueDataBuffer.externalizedDataBuffer(length, dataType(), this.pointer, null); | ||||
| 
 | ||||
|         // adding deallocator reference | ||||
|         Nd4j.getDeallocatorService().pickObject(this); | ||||
| @ -570,8 +567,7 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo | ||||
|     public BaseCpuDataBuffer(Pointer pointer, Indexer indexer, long length) { | ||||
|         super(pointer, indexer, length); | ||||
| 
 | ||||
|         ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(0, type, false); | ||||
|         ptrDataBuffer.setPrimaryBuffer(this.pointer, length); | ||||
|         ptrDataBuffer = OpaqueDataBuffer.externalizedDataBuffer(length, dataType(), this.pointer, null); | ||||
|         Nd4j.getDeallocatorService().pickObject(this);; | ||||
|     } | ||||
| 
 | ||||
| @ -633,8 +629,7 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo | ||||
| 
 | ||||
|         pointer = workspace.alloc(data.length * getElementSize(), dataType(), false).asFloatPointer().put(data); | ||||
| 
 | ||||
|         this.ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(0, dataType(), false); | ||||
|         this.ptrDataBuffer.setPrimaryBuffer(pointer, this.length); | ||||
|         ptrDataBuffer = OpaqueDataBuffer.externalizedDataBuffer(length, dataType(), this.pointer, null); | ||||
|         Nd4j.getDeallocatorService().pickObject(this); | ||||
| 
 | ||||
|         workspaceGenerationId = workspace.getGenerationId(); | ||||
| @ -655,8 +650,7 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo | ||||
| 
 | ||||
|         pointer = workspace.alloc(data.length * getElementSize(), dataType(), false).asDoublePointer().put(data); | ||||
| 
 | ||||
|         this.ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(0, dataType(), false); | ||||
|         this.ptrDataBuffer.setPrimaryBuffer(pointer, this.length); | ||||
|         ptrDataBuffer = OpaqueDataBuffer.externalizedDataBuffer(length, dataType(), this.pointer, null); | ||||
|         Nd4j.getDeallocatorService().pickObject(this); | ||||
| 
 | ||||
|         workspaceGenerationId = workspace.getGenerationId(); | ||||
| @ -678,8 +672,7 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo | ||||
| 
 | ||||
|         pointer = workspace.alloc(data.length * getElementSize(), dataType(), false).asIntPointer().put(data); | ||||
| 
 | ||||
|         this.ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(0, dataType(), false); | ||||
|         this.ptrDataBuffer.setPrimaryBuffer(pointer, this.length); | ||||
|         ptrDataBuffer = OpaqueDataBuffer.externalizedDataBuffer(length, dataType(), this.pointer, null); | ||||
|         Nd4j.getDeallocatorService().pickObject(this); | ||||
| 
 | ||||
|         workspaceGenerationId = workspace.getGenerationId(); | ||||
| @ -700,8 +693,7 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo | ||||
| 
 | ||||
|         pointer = workspace.alloc(data.length * getElementSize(), dataType(), false).asLongPointer().put(data); | ||||
| 
 | ||||
|         this.ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(0, dataType(), false); | ||||
|         this.ptrDataBuffer.setPrimaryBuffer(pointer, this.length); | ||||
|         ptrDataBuffer = OpaqueDataBuffer.externalizedDataBuffer(length, dataType(), this.pointer, null); | ||||
|         Nd4j.getDeallocatorService().pickObject(this); | ||||
| 
 | ||||
|         workspaceGenerationId = workspace.getGenerationId(); | ||||
|  | ||||
| @ -124,8 +124,7 @@ public class LongBuffer extends BaseCpuDataBuffer { | ||||
| 
 | ||||
|         // we still want this buffer to have native representation | ||||
| 
 | ||||
|         ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(0, DataType.INT64, false); | ||||
|         NativeOpsHolder.getInstance().getDeviceNativeOps().dbSetPrimaryBuffer(ptrDataBuffer, this.pointer, numberOfElements); | ||||
|         ptrDataBuffer = OpaqueDataBuffer.externalizedDataBuffer(numberOfElements, DataType.INT64, this.pointer, null); | ||||
| 
 | ||||
|         Nd4j.getDeallocatorService().pickObject(this); | ||||
|     } | ||||
|  | ||||
| @ -3143,6 +3143,8 @@ public native @Cast("Nd4jPointer") Pointer lcBlasHandle(OpaqueLaunchContext lc); | ||||
| public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc); | ||||
| 
 | ||||
| public native OpaqueDataBuffer allocateDataBuffer(@Cast("Nd4jLong") long elements, int dataType, @Cast("bool") boolean allocateBoth); | ||||
| public native OpaqueDataBuffer dbAllocateDataBuffer(@Cast("Nd4jLong") long elements, int dataType, @Cast("bool") boolean allocateBoth); | ||||
| public native OpaqueDataBuffer dbCreateExternalDataBuffer(@Cast("Nd4jLong") long elements, int dataType, @Cast("Nd4jPointer") Pointer primary, @Cast("Nd4jPointer") Pointer special); | ||||
| public native OpaqueDataBuffer dbCreateView(OpaqueDataBuffer dataBuffer, @Cast("Nd4jLong") long length, @Cast("Nd4jLong") long offset); | ||||
| public native @Cast("Nd4jPointer") Pointer dbPrimaryBuffer(OpaqueDataBuffer dataBuffer); | ||||
| public native @Cast("Nd4jPointer") Pointer dbSpecialBuffer(OpaqueDataBuffer dataBuffer); | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user