CUDA sync tweaks (#194)
* ThreadLocal cache for CudaContext
  Signed-off-by: raver119 <raver119@gmail.com>
* temp commit
  Signed-off-by: raver119 <raver119@gmail.com>
* remove unwanted synchronization
  Signed-off-by: raver119 <raver119@gmail.com>
master
parent
7ef0ef907e
commit
9f719488b9
|
@ -353,18 +353,8 @@ public class AtomicAllocator implements Allocator {
|
|||
*/
|
||||
@Override
|
||||
public void synchronizeHostData(DataBuffer buffer) {
|
||||
// we don't want non-committed ops left behind
|
||||
Nd4j.getExecutioner().commit();
|
||||
|
||||
val oPtr = NativeOpsHolder.getInstance().getDeviceNativeOps().dbPrimaryBuffer(((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer());
|
||||
|
||||
// we actually need synchronization only in a device-dependent environment; it is a no-op otherwise. managed by native code
|
||||
NativeOpsHolder.getInstance().getDeviceNativeOps().dbSyncToPrimary(((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer());
|
||||
|
||||
val cPtr = NativeOpsHolder.getInstance().getDeviceNativeOps().dbPrimaryBuffer(((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer());
|
||||
|
||||
//assert oPtr.address() == cPtr.address();
|
||||
//assert buffer.address() == oPtr.address();
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -102,6 +102,8 @@ public class CudaZeroHandler implements MemoryHandler {
|
|||
|
||||
private final AffinityManager affinityManager = Nd4j.getAffinityManager();
|
||||
|
||||
private final transient ThreadLocal<CudaContext> tlContext = new ThreadLocal<>();
|
||||
|
||||
/*
|
||||
table for Thread, Device, Object allocations of device memory. Objects should be used to grab Allocation point from allocationsMap
|
||||
*/
|
||||
|
@ -1018,9 +1020,11 @@ public class CudaZeroHandler implements MemoryHandler {
|
|||
* @return
|
||||
*/
|
||||
public CudaContext getCudaContext() {
|
||||
var ctx = tlContext.get();
|
||||
if (ctx == null) {
|
||||
val lc = nativeOps.defaultLaunchContext();
|
||||
|
||||
return CudaContext.builder()
|
||||
ctx = CudaContext.builder()
|
||||
.bufferScalar(nativeOps.lcScalarPointer(lc))
|
||||
.bufferReduction(nativeOps.lcReductionPointer(lc))
|
||||
.bufferAllocation(nativeOps.lcAllocationPointer(lc))
|
||||
|
@ -1030,6 +1034,11 @@ public class CudaZeroHandler implements MemoryHandler {
|
|||
.cublasHandle(getCudaCublasHandle(lc))
|
||||
.solverHandle(new cusolverDnHandle_t(nativeOps.lcSolverHandle(lc)))
|
||||
.build();
|
||||
|
||||
tlContext.set(ctx);
|
||||
return ctx;
|
||||
} else
|
||||
return ctx;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -1610,8 +1610,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
|||
|
||||
@Override
|
||||
public void commit() {
|
||||
AtomicAllocator.getInstance().getDeviceContext().syncOldStream();
|
||||
AtomicAllocator.getInstance().getDeviceContext().syncSpecialStream();
|
||||
val ctx = AtomicAllocator.getInstance().getDeviceContext();
|
||||
ctx.syncOldStream();
|
||||
ctx.syncSpecialStream();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
Loading…
Reference in New Issue