dedicated lock for getCudaCublasHandle
Signed-off-by: raver119 <raver119@gmail.com>
master
parent
2129d5bcac
commit
d3253aff3f
|
@ -1180,16 +1180,23 @@ public class CudaZeroHandler implements MemoryHandler {
|
|||
return getCudaContext();
|
||||
}
|
||||
|
||||
//
|
||||
private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
|
||||
|
||||
|
||||
protected synchronized cublasHandle_t getCudaCublasHandle(OpaqueLaunchContext lc) {
|
||||
protected cublasHandle_t getCudaCublasHandle(OpaqueLaunchContext lc) {
|
||||
val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
|
||||
if (cublasHandles.get(deviceId) == null) {
|
||||
cublasHandles.remove(deviceId);
|
||||
cublasHandles.add(deviceId, new cublasHandle_t(nativeOps.lcBlasHandle(lc)));
|
||||
}
|
||||
try {
|
||||
lock.writeLock().lock();
|
||||
|
||||
return cublasHandles.get(deviceId);
|
||||
if (cublasHandles.get(deviceId) == null) {
|
||||
cublasHandles.remove(deviceId);
|
||||
cublasHandles.add(deviceId, new cublasHandle_t(nativeOps.lcBlasHandle(lc)));
|
||||
}
|
||||
|
||||
return cublasHandles.get(deviceId);
|
||||
} finally {
|
||||
lock.writeLock().unlock();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -16985,8 +16985,20 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
|||
// #endif
|
||||
|
||||
/**
|
||||
* Returns a batched matrix tensor with new batched diagonal values.
|
||||
*/
|
||||
* Inserts elements provided by diagonal array into the main diagonal of innermost matrices of input array
|
||||
*
|
||||
* Input arrays:
|
||||
* input: input array, considered as batch of matrices
|
||||
* diagonal: array containing elements to be inserted into input array,
|
||||
* the following rank condition should be satisfied: diagonal_rank = input_rank - 1,
|
||||
* the shapes of the diagonal and input arrays must be equal except for the last dimension of the input array,
|
||||
* for example if input_shape = [A,B,C,D] then diagonal_shape = [A,B,C],
|
||||
* also, the last dimension of the diagonal array should be equal to the smaller of the last and second-to-last input dimensions,
|
||||
* that is: diagonal_shape[-1] = min(input_shape[-1], input_shape[-2])
|
||||
*
|
||||
* Output array:
|
||||
* has the same shape as input, corresponding diagonal elements are substituted
|
||||
*/
|
||||
// #if NOT_EXCLUDED(OP_matrix_set_diag)
|
||||
@Namespace("nd4j::ops") public static class matrix_set_diag extends DeclarableOp {
|
||||
static { Loader.load(); }
|
||||
|
|
Loading…
Reference in New Issue