CUDA sync tweaks (#194)

* ThreadLocal cache for CudaContext

Signed-off-by: raver119 <raver119@gmail.com>

* temp commit

Signed-off-by: raver119 <raver119@gmail.com>

* remove unwanted synchronization

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2020-01-28 10:55:06 +03:00 committed by GitHub
parent 7ef0ef907e
commit 9f719488b9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 60 additions and 60 deletions

View File

@ -353,18 +353,8 @@ public class AtomicAllocator implements Allocator {
*/
@Override
public void synchronizeHostData(DataBuffer buffer) {
// we don't want non-committed ops left behind
Nd4j.getExecutioner().commit();
val oPtr = NativeOpsHolder.getInstance().getDeviceNativeOps().dbPrimaryBuffer(((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer());
// we actually need synchronization only in device-dependant environment. no-op otherwise. managed by native code
NativeOpsHolder.getInstance().getDeviceNativeOps().dbSyncToPrimary(((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer());
val cPtr = NativeOpsHolder.getInstance().getDeviceNativeOps().dbPrimaryBuffer(((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer());
//assert oPtr.address() == cPtr.address();
//assert buffer.address() == oPtr.address();
}

View File

@ -102,6 +102,8 @@ public class CudaZeroHandler implements MemoryHandler {
private final AffinityManager affinityManager = Nd4j.getAffinityManager();
private final transient ThreadLocal<CudaContext> tlContext = new ThreadLocal<>();
/*
table for Thread, Device, Object allocations of device memory. Objects should be used to grab Allocation point from allocationsMap
*/
@ -1018,9 +1020,11 @@ public class CudaZeroHandler implements MemoryHandler {
* @return
*/
public CudaContext getCudaContext() {
var ctx = tlContext.get();
if (ctx == null) {
val lc = nativeOps.defaultLaunchContext();
return CudaContext.builder()
ctx = CudaContext.builder()
.bufferScalar(nativeOps.lcScalarPointer(lc))
.bufferReduction(nativeOps.lcReductionPointer(lc))
.bufferAllocation(nativeOps.lcAllocationPointer(lc))
@ -1030,6 +1034,11 @@ public class CudaZeroHandler implements MemoryHandler {
.cublasHandle(getCudaCublasHandle(lc))
.solverHandle(new cusolverDnHandle_t(nativeOps.lcSolverHandle(lc)))
.build();
tlContext.set(ctx);
return ctx;
} else
return ctx;
}
/**

View File

@ -1610,8 +1610,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
@Override
public void commit() {
AtomicAllocator.getInstance().getDeviceContext().syncOldStream();
AtomicAllocator.getInstance().getDeviceContext().syncSpecialStream();
val ctx = AtomicAllocator.getInstance().getDeviceContext();
ctx.syncOldStream();
ctx.syncSpecialStream();
}
@Override