CUDA sync tweaks (#194)

* ThreadLocal cache for CudaContext

Signed-off-by: raver119 <raver119@gmail.com>

* temp commit

Signed-off-by: raver119 <raver119@gmail.com>

* remove unwanted synchronization

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2020-01-28 10:55:06 +03:00 committed by GitHub
parent 7ef0ef907e
commit 9f719488b9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 60 additions and 60 deletions

View File

@ -353,18 +353,8 @@ public class AtomicAllocator implements Allocator {
*/ */
@Override @Override
public void synchronizeHostData(DataBuffer buffer) { public void synchronizeHostData(DataBuffer buffer) {
// we don't want non-committed ops left behind
Nd4j.getExecutioner().commit();
val oPtr = NativeOpsHolder.getInstance().getDeviceNativeOps().dbPrimaryBuffer(((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer());
// we actually need synchronization only in device-dependant environment. no-op otherwise. managed by native code // we actually need synchronization only in device-dependant environment. no-op otherwise. managed by native code
NativeOpsHolder.getInstance().getDeviceNativeOps().dbSyncToPrimary(((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer()); NativeOpsHolder.getInstance().getDeviceNativeOps().dbSyncToPrimary(((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer());
val cPtr = NativeOpsHolder.getInstance().getDeviceNativeOps().dbPrimaryBuffer(((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer());
//assert oPtr.address() == cPtr.address();
//assert buffer.address() == oPtr.address();
} }

View File

@ -102,6 +102,8 @@ public class CudaZeroHandler implements MemoryHandler {
private final AffinityManager affinityManager = Nd4j.getAffinityManager(); private final AffinityManager affinityManager = Nd4j.getAffinityManager();
private final transient ThreadLocal<CudaContext> tlContext = new ThreadLocal<>();
/* /*
table for Thread, Device, Object allocations of device memory. Objects should be used to grab Allocation point from allocationsMap table for Thread, Device, Object allocations of device memory. Objects should be used to grab Allocation point from allocationsMap
*/ */
@ -1018,9 +1020,11 @@ public class CudaZeroHandler implements MemoryHandler {
* @return * @return
*/ */
public CudaContext getCudaContext() { public CudaContext getCudaContext() {
var ctx = tlContext.get();
if (ctx == null) {
val lc = nativeOps.defaultLaunchContext(); val lc = nativeOps.defaultLaunchContext();
return CudaContext.builder() ctx = CudaContext.builder()
.bufferScalar(nativeOps.lcScalarPointer(lc)) .bufferScalar(nativeOps.lcScalarPointer(lc))
.bufferReduction(nativeOps.lcReductionPointer(lc)) .bufferReduction(nativeOps.lcReductionPointer(lc))
.bufferAllocation(nativeOps.lcAllocationPointer(lc)) .bufferAllocation(nativeOps.lcAllocationPointer(lc))
@ -1030,6 +1034,11 @@ public class CudaZeroHandler implements MemoryHandler {
.cublasHandle(getCudaCublasHandle(lc)) .cublasHandle(getCudaCublasHandle(lc))
.solverHandle(new cusolverDnHandle_t(nativeOps.lcSolverHandle(lc))) .solverHandle(new cusolverDnHandle_t(nativeOps.lcSolverHandle(lc)))
.build(); .build();
tlContext.set(ctx);
return ctx;
} else
return ctx;
} }
/** /**

View File

@ -1610,8 +1610,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
@Override @Override
public void commit() { public void commit() {
AtomicAllocator.getInstance().getDeviceContext().syncOldStream(); val ctx = AtomicAllocator.getInstance().getDeviceContext();
AtomicAllocator.getInstance().getDeviceContext().syncSpecialStream(); ctx.syncOldStream();
ctx.syncSpecialStream();
} }
@Override @Override