diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java index 2f1cab334..f944d20cf 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java @@ -121,29 +121,22 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda // allocating interop buffer this.ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(length, type, false); + // passing existing pointer to native holder + this.ptrDataBuffer.setPrimaryBuffer(pointer, length); + //cuda specific bits this.allocationPoint = new AllocationPoint(ptrDataBuffer, length * elementSize); Nd4j.getDeallocatorService().pickObject(this); - // now we're + // now we're getting context and copying our stuff to device val context = AtomicAllocator.getInstance().getDeviceContext(); val perfD = PerformanceTracker.getInstance().helperStartTransaction(); - if (allocationPoint.getHostPointer() != null) { - NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getHostPointer(), pointer, length * getElementSize(), CudaConstants.cudaMemcpyHostToHost, context.getSpecialStream()); - NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), allocationPoint.getHostPointer(), length * getElementSize(), CudaConstants.cudaMemcpyHostToHost, context.getSpecialStream()); - } else { - NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), pointer, length * getElementSize(), CudaConstants.cudaMemcpyHostToDevice, context.getSpecialStream()); - } - - context.getSpecialStream().synchronize(); - - if (allocationPoint.getHostPointer() != null) - PerformanceTracker.getInstance().helperRegisterTransaction(allocationPoint.getDeviceId(), perfD / 2, allocationPoint.getNumberOfBytes(), MemcpyDirection.HOST_TO_HOST); + NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), pointer, length * getElementSize(), CudaConstants.cudaMemcpyHostToDevice, context.getSpecialStream()); PerformanceTracker.getInstance().helperRegisterTransaction(allocationPoint.getDeviceId(), perfD / 2, allocationPoint.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE); - + context.getSpecialStream().synchronize(); } public BaseCudaDataBuffer(float[] data, boolean copy) {