From 18828f97252669ff5f10cd41e8c331a47e77183c Mon Sep 17 00:00:00 2001
From: raver119
Date: Mon, 2 Sep 2019 16:52:10 +0300
Subject: [PATCH] cublasHandle sharing + lock

Signed-off-by: raver119
---
Review note: getCudaCublasHandle() now braces the null check and stores the
handle with List.set(), so one handle per device is actually cached and shared.
Diffstat and hunk offsets below were adjusted by hand for the new line counts;
the blob ids on the "index" lines are still those of the original patch.

 .../jita/handler/impl/CudaZeroHandler.java    | 29 ++++++++++++++++++++++++++---
 .../java/org/nd4j/nativeblas/Nd4jCuda.java    |  3 +++
 .../java/org/nd4j/nativeblas/Nd4jCpu.java     |  3 +++
 3 files changed, 32 insertions(+), 3 deletions(-)

diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java
index fdd40f8cb..23301f4be 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java
@@ -16,6 +16,7 @@
 
 package org.nd4j.jita.handler.impl;
 
+import org.nd4j.nativeblas.OpaqueLaunchContext;
 import org.nd4j.shade.guava.collect.HashBasedTable;
 import org.nd4j.shade.guava.collect.Table;
 import lombok.Getter;
@@ -105,6 +106,8 @@ public class CudaZeroHandler implements MemoryHandler {
 
     private final AllocationStatus INITIAL_LOCATION;
 
+    private final List<cublasHandle_t> cublasHandles = new ArrayList<>();
+
     private final AffinityManager affinityManager = Nd4j.getAffinityManager();
 
     /*
@@ -162,6 +165,7 @@ public class CudaZeroHandler implements MemoryHandler {
         int numDevices = NativeOpsHolder.getInstance().getDeviceNativeOps().getAvailableDevices();
         for (int i = 0; i < numDevices; i++) {
             deviceAllocations.add(new ConcurrentHashMap());
+            cublasHandles.add(null);
         }
 
         if (NativeOpsHolder.getInstance().getDeviceNativeOps().getDeviceMajor(0) < 3) {
@@ -1176,6 +1180,27 @@ public class CudaZeroHandler implements MemoryHandler {
         return getCudaContext();
     }
 
+    /**
+     * Returns the cuBLAS handle cached for the device the calling thread is attached to,
+     * wrapping the launch context's handle on first use for that device.
+     *
+     * @param lc launch context the handle is taken from when none is cached yet
+     * @return handle shared by all callers on the current device
+     */
+    protected synchronized cublasHandle_t getCudaCublasHandle(OpaqueLaunchContext lc) {
+        // Primitive int on purpose: every list call below must be index-based.
+        final int deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
+
+        cublasHandle_t handle = cublasHandles.get(deviceId);
+        if (handle == null) {
+            handle = new cublasHandle_t(nativeOps.lcBlasHandle(lc));
+            // set() instead of remove()+add(): the list keeps exactly one slot per device.
+            cublasHandles.set(deviceId, handle);
+        }
+
+        return handle;
+    }
+
     /**
      * This method returns CudaContext for current thread. If context doesn't exist - it gets created first.
      * @return
@@ -1183,8 +1208,6 @@ public class CudaZeroHandler implements MemoryHandler {
     public CudaContext getCudaContext() {
         val lc = nativeOps.defaultLaunchContext();
 
-        // TODO: maybe make ThreadLocal cache for context?
-
         return CudaContext.builder()
                 .bufferScalar(nativeOps.lcScalarPointer(lc))
                 .bufferReduction(nativeOps.lcReductionPointer(lc))
@@ -1192,7 +1215,7 @@ public class CudaZeroHandler implements MemoryHandler {
                 .bufferSpecial(nativeOps.lcScalarPointer(lc))
                 .oldStream(new cudaStream_t(nativeOps.lcExecutionStream(lc)))
                 .specialStream(new cudaStream_t(nativeOps.lcCopyStream(lc)))
-                .cublasHandle(new cublasHandle_t(nativeOps.lcBlasHandle(lc)))
+                .cublasHandle(getCudaCublasHandle(lc))
                 .solverHandle(new cusolverDnHandle_t(nativeOps.lcSolverHandle(lc)))
                 .build();
     }
diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java
index 15f6c52ef..f3080f05a 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java
@@ -3830,6 +3830,9 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc
          * @param writeList
          * @param readList
          */
+        // TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list
+
+        // TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list
 
 
         /**
diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java
index 6983e20f0..8e150f618 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java
@@ -3830,6 +3830,9 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc
          * @param writeList
          * @param readList
          */
+        // TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list
+
+        // TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list
 
 
         /**