CUDA host pointer fix (#322)
* CUDA fix: host pointer propagation Signed-off-by: raver119 <raver119@gmail.com> * disable logging Signed-off-by: raver119 <raver119@gmail.com>
This commit is contained in:
		
							parent
							
								
									2cd4522f94
								
							
						
					
					
						commit
						bac130bd78
					
				| @ -121,29 +121,22 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda | ||||
|         // allocating interop buffer | ||||
|         this.ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(length, type, false); | ||||
| 
 | ||||
|         // passing existing pointer to native holder | ||||
|         this.ptrDataBuffer.setPrimaryBuffer(pointer, length); | ||||
| 
 | ||||
|         //cuda specific bits | ||||
|         this.allocationPoint = new AllocationPoint(ptrDataBuffer, length * elementSize); | ||||
|         Nd4j.getDeallocatorService().pickObject(this); | ||||
| 
 | ||||
|         // now we're | ||||
|         // now we're getting context and copying our stuff to device | ||||
|         val context = AtomicAllocator.getInstance().getDeviceContext(); | ||||
| 
 | ||||
|         val perfD = PerformanceTracker.getInstance().helperStartTransaction(); | ||||
| 
 | ||||
|         if (allocationPoint.getHostPointer() != null) { | ||||
|             NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getHostPointer(), pointer, length * getElementSize(), CudaConstants.cudaMemcpyHostToHost, context.getSpecialStream()); | ||||
|             NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), allocationPoint.getHostPointer(), length * getElementSize(), CudaConstants.cudaMemcpyHostToHost, context.getSpecialStream()); | ||||
|         } else { | ||||
|         NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), pointer, length * getElementSize(), CudaConstants.cudaMemcpyHostToDevice, context.getSpecialStream()); | ||||
|         } | ||||
| 
 | ||||
|         context.getSpecialStream().synchronize(); | ||||
| 
 | ||||
|         if (allocationPoint.getHostPointer() != null) | ||||
|             PerformanceTracker.getInstance().helperRegisterTransaction(allocationPoint.getDeviceId(), perfD / 2, allocationPoint.getNumberOfBytes(), MemcpyDirection.HOST_TO_HOST); | ||||
| 
 | ||||
|         PerformanceTracker.getInstance().helperRegisterTransaction(allocationPoint.getDeviceId(), perfD / 2, allocationPoint.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE); | ||||
| 
 | ||||
|         context.getSpecialStream().synchronize(); | ||||
|     } | ||||
| 
 | ||||
|     public BaseCudaDataBuffer(float[] data, boolean copy) { | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user