cublasHandle sharing + lock

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-09-02 16:52:10 +03:00
parent cb4c9377b1
commit 18828f9725
3 changed files with 22 additions and 3 deletions

View File

@ -16,6 +16,7 @@
package org.nd4j.jita.handler.impl;
import org.nd4j.nativeblas.OpaqueLaunchContext;
import org.nd4j.shade.guava.collect.HashBasedTable;
import org.nd4j.shade.guava.collect.Table;
import lombok.Getter;
@ -105,6 +106,8 @@ public class CudaZeroHandler implements MemoryHandler {
private final AllocationStatus INITIAL_LOCATION;
private final List<cublasHandle_t> cublasHandles = new ArrayList<>();
private final AffinityManager affinityManager = Nd4j.getAffinityManager();
/*
@ -162,6 +165,7 @@ public class CudaZeroHandler implements MemoryHandler {
int numDevices = NativeOpsHolder.getInstance().getDeviceNativeOps().getAvailableDevices();
for (int i = 0; i < numDevices; i++) {
deviceAllocations.add(new ConcurrentHashMap<Long, Long>());
cublasHandles.add(null);
}
if (NativeOpsHolder.getInstance().getDeviceNativeOps().getDeviceMajor(0) < 3) {
@ -1176,6 +1180,17 @@ public class CudaZeroHandler implements MemoryHandler {
return getCudaContext();
}
/**
 * Returns the cuBLAS handle cached for the device attached to the calling thread, creating and caching it
 * on first use.
 *
 * The method is synchronized, so the lookup and the lazy creation of a per-device entry cannot interleave
 * with another call to this method on the same handler instance.
 *
 * @param lc launch context whose BLAS handle pointer gets wrapped when nothing is cached yet for the device
 * @return the cached cublasHandle_t for the current thread's device
 */
protected synchronized cublasHandle_t getCudaCublasHandle(OpaqueLaunchContext lc) {
    // explicit int: avoids accidentally hitting List.remove(Object)/boxed overloads on the handle list
    final int deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();

    cublasHandle_t handle = cublasHandles.get(deviceId);
    if (handle == null) {
        // fill the null slot in place: set() keeps the list size and every other device index intact,
        // whereas add(index, ...) would shift later entries and grow the list
        handle = new cublasHandle_t(nativeOps.lcBlasHandle(lc));
        cublasHandles.set(deviceId, handle);
    }

    return handle;
}
/**
* This method returns the CudaContext for the current thread. If the context doesn't exist, it is created first.
* @return CudaContext for the current thread
@ -1183,8 +1198,6 @@ public class CudaZeroHandler implements MemoryHandler {
public CudaContext getCudaContext() {
val lc = nativeOps.defaultLaunchContext();
// TODO: maybe make ThreadLocal cache for context?
return CudaContext.builder()
.bufferScalar(nativeOps.lcScalarPointer(lc))
.bufferReduction(nativeOps.lcReductionPointer(lc))
@ -1192,7 +1205,7 @@ public class CudaZeroHandler implements MemoryHandler {
.bufferSpecial(nativeOps.lcScalarPointer(lc))
.oldStream(new cudaStream_t(nativeOps.lcExecutionStream(lc)))
.specialStream(new cudaStream_t(nativeOps.lcCopyStream(lc)))
.cublasHandle(new cublasHandle_t(nativeOps.lcBlasHandle(lc)))
.cublasHandle(getCudaCublasHandle(lc))
.solverHandle(new cusolverDnHandle_t(nativeOps.lcSolverHandle(lc)))
.build();
}

View File

@ -3830,6 +3830,9 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc
* @param writeList
* @param readList
*/
// TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list
// TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list
/**

View File

@ -3830,6 +3830,9 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc
* @param writeList
* @param readList
*/
// TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list
// TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list
/**