CUDA host pointer fix (#322)

* CUDA fix: host pointer propagation

Signed-off-by: raver119 <raver119@gmail.com>

* disable logging

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2020-03-16 12:07:34 +03:00 committed by GitHub
parent 2cd4522f94
commit bac130bd78
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 6 additions and 13 deletions

View File

@ -121,29 +121,22 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda
// allocating interop buffer // allocating interop buffer
this.ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(length, type, false); this.ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(length, type, false);
// passing existing pointer to native holder
this.ptrDataBuffer.setPrimaryBuffer(pointer, length);
//cuda specific bits //cuda specific bits
this.allocationPoint = new AllocationPoint(ptrDataBuffer, length * elementSize); this.allocationPoint = new AllocationPoint(ptrDataBuffer, length * elementSize);
Nd4j.getDeallocatorService().pickObject(this); Nd4j.getDeallocatorService().pickObject(this);
// now we're // now we're getting context and copying our stuff to device
val context = AtomicAllocator.getInstance().getDeviceContext(); val context = AtomicAllocator.getInstance().getDeviceContext();
val perfD = PerformanceTracker.getInstance().helperStartTransaction(); val perfD = PerformanceTracker.getInstance().helperStartTransaction();
if (allocationPoint.getHostPointer() != null) { NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), pointer, length * getElementSize(), CudaConstants.cudaMemcpyHostToDevice, context.getSpecialStream());
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getHostPointer(), pointer, length * getElementSize(), CudaConstants.cudaMemcpyHostToHost, context.getSpecialStream());
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), allocationPoint.getHostPointer(), length * getElementSize(), CudaConstants.cudaMemcpyHostToHost, context.getSpecialStream());
} else {
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), pointer, length * getElementSize(), CudaConstants.cudaMemcpyHostToDevice, context.getSpecialStream());
}
context.getSpecialStream().synchronize();
if (allocationPoint.getHostPointer() != null)
PerformanceTracker.getInstance().helperRegisterTransaction(allocationPoint.getDeviceId(), perfD / 2, allocationPoint.getNumberOfBytes(), MemcpyDirection.HOST_TO_HOST);
PerformanceTracker.getInstance().helperRegisterTransaction(allocationPoint.getDeviceId(), perfD / 2, allocationPoint.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE); PerformanceTracker.getInstance().helperRegisterTransaction(allocationPoint.getDeviceId(), perfD / 2, allocationPoint.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE);
context.getSpecialStream().synchronize();
} }
public BaseCudaDataBuffer(float[] data, boolean copy) { public BaseCudaDataBuffer(float[] data, boolean copy) {