parent
cb4c9377b1
commit
18828f9725
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.nd4j.jita.handler.impl;
|
package org.nd4j.jita.handler.impl;
|
||||||
|
|
||||||
|
import org.nd4j.nativeblas.OpaqueLaunchContext;
|
||||||
import org.nd4j.shade.guava.collect.HashBasedTable;
|
import org.nd4j.shade.guava.collect.HashBasedTable;
|
||||||
import org.nd4j.shade.guava.collect.Table;
|
import org.nd4j.shade.guava.collect.Table;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
|
@ -105,6 +106,8 @@ public class CudaZeroHandler implements MemoryHandler {
|
||||||
|
|
||||||
private final AllocationStatus INITIAL_LOCATION;
|
private final AllocationStatus INITIAL_LOCATION;
|
||||||
|
|
||||||
|
private final List<cublasHandle_t> cublasHandles = new ArrayList<>();
|
||||||
|
|
||||||
private final AffinityManager affinityManager = Nd4j.getAffinityManager();
|
private final AffinityManager affinityManager = Nd4j.getAffinityManager();
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
@ -162,6 +165,7 @@ public class CudaZeroHandler implements MemoryHandler {
|
||||||
int numDevices = NativeOpsHolder.getInstance().getDeviceNativeOps().getAvailableDevices();
|
int numDevices = NativeOpsHolder.getInstance().getDeviceNativeOps().getAvailableDevices();
|
||||||
for (int i = 0; i < numDevices; i++) {
|
for (int i = 0; i < numDevices; i++) {
|
||||||
deviceAllocations.add(new ConcurrentHashMap<Long, Long>());
|
deviceAllocations.add(new ConcurrentHashMap<Long, Long>());
|
||||||
|
cublasHandles.add(null);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (NativeOpsHolder.getInstance().getDeviceNativeOps().getDeviceMajor(0) < 3) {
|
if (NativeOpsHolder.getInstance().getDeviceNativeOps().getDeviceMajor(0) < 3) {
|
||||||
|
@ -1176,6 +1180,17 @@ public class CudaZeroHandler implements MemoryHandler {
|
||||||
return getCudaContext();
|
return getCudaContext();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
protected synchronized cublasHandle_t getCudaCublasHandle(OpaqueLaunchContext lc) {
|
||||||
|
val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
|
||||||
|
if (cublasHandles.get(deviceId) == null)
|
||||||
|
cublasHandles.remove(deviceId);
|
||||||
|
cublasHandles.add(deviceId, new cublasHandle_t(nativeOps.lcBlasHandle(lc)));
|
||||||
|
|
||||||
|
return cublasHandles.get(deviceId);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method returns CudaContext for current thread. If context doesn't exist - it gets created first.
|
* This method returns CudaContext for current thread. If context doesn't exist - it gets created first.
|
||||||
* @return
|
* @return
|
||||||
|
@ -1183,8 +1198,6 @@ public class CudaZeroHandler implements MemoryHandler {
|
||||||
public CudaContext getCudaContext() {
|
public CudaContext getCudaContext() {
|
||||||
val lc = nativeOps.defaultLaunchContext();
|
val lc = nativeOps.defaultLaunchContext();
|
||||||
|
|
||||||
// TODO: maybe make ThreadLocal cache for context?
|
|
||||||
|
|
||||||
return CudaContext.builder()
|
return CudaContext.builder()
|
||||||
.bufferScalar(nativeOps.lcScalarPointer(lc))
|
.bufferScalar(nativeOps.lcScalarPointer(lc))
|
||||||
.bufferReduction(nativeOps.lcReductionPointer(lc))
|
.bufferReduction(nativeOps.lcReductionPointer(lc))
|
||||||
|
@ -1192,7 +1205,7 @@ public class CudaZeroHandler implements MemoryHandler {
|
||||||
.bufferSpecial(nativeOps.lcScalarPointer(lc))
|
.bufferSpecial(nativeOps.lcScalarPointer(lc))
|
||||||
.oldStream(new cudaStream_t(nativeOps.lcExecutionStream(lc)))
|
.oldStream(new cudaStream_t(nativeOps.lcExecutionStream(lc)))
|
||||||
.specialStream(new cudaStream_t(nativeOps.lcCopyStream(lc)))
|
.specialStream(new cudaStream_t(nativeOps.lcCopyStream(lc)))
|
||||||
.cublasHandle(new cublasHandle_t(nativeOps.lcBlasHandle(lc)))
|
.cublasHandle(getCudaCublasHandle(lc))
|
||||||
.solverHandle(new cusolverDnHandle_t(nativeOps.lcSolverHandle(lc)))
|
.solverHandle(new cusolverDnHandle_t(nativeOps.lcSolverHandle(lc)))
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
|
|
|
@ -3830,6 +3830,9 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc
|
||||||
* @param writeList
|
* @param writeList
|
||||||
* @param readList
|
* @param readList
|
||||||
*/
|
*/
|
||||||
|
// TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list
|
||||||
|
|
||||||
|
// TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -3830,6 +3830,9 @@ public native @Cast("Nd4jPointer") Pointer lcSolverHandle(OpaqueLaunchContext lc
|
||||||
* @param writeList
|
* @param writeList
|
||||||
* @param readList
|
* @param readList
|
||||||
*/
|
*/
|
||||||
|
// TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list
|
||||||
|
|
||||||
|
// TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
Loading…
Reference in New Issue