CUDA host pointer fix (#322)
* CUDA fix: host pointer propagation Signed-off-by: raver119 <raver119@gmail.com> * disable logging Signed-off-by: raver119 <raver119@gmail.com>master
parent
2cd4522f94
commit
bac130bd78
|
@ -121,29 +121,22 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
|
||||||
// allocating interop buffer
|
// allocating interop buffer
|
||||||
this.ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(length, type, false);
|
this.ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(length, type, false);
|
||||||
|
|
||||||
|
// passing existing pointer to native holder
|
||||||
|
this.ptrDataBuffer.setPrimaryBuffer(pointer, length);
|
||||||
|
|
||||||
//cuda specific bits
|
//cuda specific bits
|
||||||
this.allocationPoint = new AllocationPoint(ptrDataBuffer, length * elementSize);
|
this.allocationPoint = new AllocationPoint(ptrDataBuffer, length * elementSize);
|
||||||
Nd4j.getDeallocatorService().pickObject(this);
|
Nd4j.getDeallocatorService().pickObject(this);
|
||||||
|
|
||||||
// now we're
|
// now we're getting context and copying our stuff to device
|
||||||
val context = AtomicAllocator.getInstance().getDeviceContext();
|
val context = AtomicAllocator.getInstance().getDeviceContext();
|
||||||
|
|
||||||
val perfD = PerformanceTracker.getInstance().helperStartTransaction();
|
val perfD = PerformanceTracker.getInstance().helperStartTransaction();
|
||||||
|
|
||||||
if (allocationPoint.getHostPointer() != null) {
|
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), pointer, length * getElementSize(), CudaConstants.cudaMemcpyHostToDevice, context.getSpecialStream());
|
||||||
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getHostPointer(), pointer, length * getElementSize(), CudaConstants.cudaMemcpyHostToHost, context.getSpecialStream());
|
|
||||||
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), allocationPoint.getHostPointer(), length * getElementSize(), CudaConstants.cudaMemcpyHostToHost, context.getSpecialStream());
|
|
||||||
} else {
|
|
||||||
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), pointer, length * getElementSize(), CudaConstants.cudaMemcpyHostToDevice, context.getSpecialStream());
|
|
||||||
}
|
|
||||||
|
|
||||||
context.getSpecialStream().synchronize();
|
|
||||||
|
|
||||||
if (allocationPoint.getHostPointer() != null)
|
|
||||||
PerformanceTracker.getInstance().helperRegisterTransaction(allocationPoint.getDeviceId(), perfD / 2, allocationPoint.getNumberOfBytes(), MemcpyDirection.HOST_TO_HOST);
|
|
||||||
|
|
||||||
PerformanceTracker.getInstance().helperRegisterTransaction(allocationPoint.getDeviceId(), perfD / 2, allocationPoint.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE);
|
PerformanceTracker.getInstance().helperRegisterTransaction(allocationPoint.getDeviceId(), perfD / 2, allocationPoint.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE);
|
||||||
|
context.getSpecialStream().synchronize();
|
||||||
}
|
}
|
||||||
|
|
||||||
public BaseCudaDataBuffer(float[] data, boolean copy) {
|
public BaseCudaDataBuffer(float[] data, boolean copy) {
|
||||||
|
|
Loading…
Reference in New Issue