CUDA host pointer fix (#322)
* CUDA fix: host pointer propagation Signed-off-by: raver119 <raver119@gmail.com> * disable logging Signed-off-by: raver119 <raver119@gmail.com>master
parent
2cd4522f94
commit
bac130bd78
|
@ -121,29 +121,22 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
|
|||
// allocating interop buffer
|
||||
this.ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(length, type, false);
|
||||
|
||||
// passing existing pointer to native holder
|
||||
this.ptrDataBuffer.setPrimaryBuffer(pointer, length);
|
||||
|
||||
//cuda specific bits
|
||||
this.allocationPoint = new AllocationPoint(ptrDataBuffer, length * elementSize);
|
||||
Nd4j.getDeallocatorService().pickObject(this);
|
||||
|
||||
// now we're
|
||||
// now we're getting context and copying our stuff to device
|
||||
val context = AtomicAllocator.getInstance().getDeviceContext();
|
||||
|
||||
val perfD = PerformanceTracker.getInstance().helperStartTransaction();
|
||||
|
||||
if (allocationPoint.getHostPointer() != null) {
|
||||
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getHostPointer(), pointer, length * getElementSize(), CudaConstants.cudaMemcpyHostToHost, context.getSpecialStream());
|
||||
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), allocationPoint.getHostPointer(), length * getElementSize(), CudaConstants.cudaMemcpyHostToHost, context.getSpecialStream());
|
||||
} else {
|
||||
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), pointer, length * getElementSize(), CudaConstants.cudaMemcpyHostToDevice, context.getSpecialStream());
|
||||
}
|
||||
|
||||
context.getSpecialStream().synchronize();
|
||||
|
||||
if (allocationPoint.getHostPointer() != null)
|
||||
PerformanceTracker.getInstance().helperRegisterTransaction(allocationPoint.getDeviceId(), perfD / 2, allocationPoint.getNumberOfBytes(), MemcpyDirection.HOST_TO_HOST);
|
||||
|
||||
PerformanceTracker.getInstance().helperRegisterTransaction(allocationPoint.getDeviceId(), perfD / 2, allocationPoint.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE);
|
||||
|
||||
context.getSpecialStream().synchronize();
|
||||
}
|
||||
|
||||
public BaseCudaDataBuffer(float[] data, boolean copy) {
|
||||
|
|
Loading…
Reference in New Issue