parent
cb4c9377b1
commit
18828f9725
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.nd4j.jita.handler.impl;
|
||||
|
||||
import org.nd4j.nativeblas.OpaqueLaunchContext;
|
||||
import org.nd4j.shade.guava.collect.HashBasedTable;
|
||||
import org.nd4j.shade.guava.collect.Table;
|
||||
import lombok.Getter;
|
||||
|
@ -105,6 +106,8 @@ public class CudaZeroHandler implements MemoryHandler {
|
|||
|
||||
private final AllocationStatus INITIAL_LOCATION;
|
||||
|
||||
private final List<cublasHandle_t> cublasHandles = new ArrayList<>();
|
||||
|
||||
private final AffinityManager affinityManager = Nd4j.getAffinityManager();
|
||||
|
||||
/*
|
||||
|
@ -162,6 +165,7 @@ public class CudaZeroHandler implements MemoryHandler {
|
|||
int numDevices = NativeOpsHolder.getInstance().getDeviceNativeOps().getAvailableDevices();
|
||||
for (int i = 0; i < numDevices; i++) {
|
||||
deviceAllocations.add(new ConcurrentHashMap<Long, Long>());
|
||||
cublasHandles.add(null);
|
||||
}
|
||||
|
||||
if (NativeOpsHolder.getInstance().getDeviceNativeOps().getDeviceMajor(0) < 3) {
|
||||
|
@ -1176,6 +1180,17 @@ public class CudaZeroHandler implements MemoryHandler {
|
|||
return getCudaContext();
|
||||
}
|
||||
|
||||
|
||||
|
||||
protected synchronized cublasHandle_t getCudaCublasHandle(OpaqueLaunchContext lc) {
|
||||
val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
|
||||
if (cublasHandles.get(deviceId) == null)
|
||||
cublasHandles.remove(deviceId);
|
||||
cublasHandles.add(deviceId, new cublasHandle_t(nativeOps.lcBlasHandle(lc)));
|
||||
|
||||
return cublasHandles.get(deviceId);
|
||||
}
|
||||
|
||||
/**
|
||||
* This method returns CudaContext for current thread. If context doesn't exist - it gets created first.
|
||||
* @return
|
||||
|
@ -1183,8 +1198,6 @@ public class CudaZeroHandler implements MemoryHandler {
|
|||
public CudaContext getCudaContext() {
|
||||
val lc = nativeOps.defaultLaunchContext();
|
||||
|
||||
// TODO: maybe make ThreadLocal cache for context?
|
||||
|
||||
return CudaContext.builder()
|
||||
.bufferScalar(nativeOps.lcScalarPointer(lc))
|
||||
.bufferReduction(nativeOps.lcReductionPointer(lc))
|
||||
|
@ -1192,7 +1205,7 @@ public class CudaZeroHandler implements MemoryHandler {
|
|||
.bufferSpecial(nativeOps.lcScalarPointer(lc))
|
||||
.oldStream(new cudaStream_t(nativeOps.lcExecutionStream(lc)))
|
||||
.specialStream(new cudaStream_t(nativeOps.lcCopyStream(lc)))
|
||||
.cublasHandle(new cublasHandle_t(nativeOps.lcBlasHandle(lc)))
|
||||
.cublasHandle(getCudaCublasHandle(lc))
|
||||
.solverHandle(new cusolverDnHandle_t(nativeOps.lcSolverHandle(lc)))
|
||||
.build();
|
||||
}
|
||||
|
|
|
@ -3830,6 +3830,9 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc
|
|||
* @param writeList
|
||||
* @param readList
|
||||
*/
|
||||
// TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list
|
||||
|
||||
// TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list
|
||||
|
||||
|
||||
/**
|
||||
|
|
|
@ -3830,6 +3830,9 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc
|
|||
* @param writeList
|
||||
* @param readList
|
||||
*/
|
||||
// TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list
|
||||
|
||||
// TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list
|
||||
|
||||
|
||||
/**
|
||||
|
|
Loading…
Reference in New Issue