dedicated lock for getCudaCublasHandle
Signed-off-by: raver119 <raver119@gmail.com>master
parent
2129d5bcac
commit
d3253aff3f
|
@ -1180,16 +1180,23 @@ public class CudaZeroHandler implements MemoryHandler {
|
||||||
return getCudaContext();
|
return getCudaContext();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
|
||||||
|
|
||||||
|
protected cublasHandle_t getCudaCublasHandle(OpaqueLaunchContext lc) {
|
||||||
protected synchronized cublasHandle_t getCudaCublasHandle(OpaqueLaunchContext lc) {
|
|
||||||
val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
|
val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
|
||||||
|
try {
|
||||||
|
lock.writeLock().lock();
|
||||||
|
|
||||||
if (cublasHandles.get(deviceId) == null) {
|
if (cublasHandles.get(deviceId) == null) {
|
||||||
cublasHandles.remove(deviceId);
|
cublasHandles.remove(deviceId);
|
||||||
cublasHandles.add(deviceId, new cublasHandle_t(nativeOps.lcBlasHandle(lc)));
|
cublasHandles.add(deviceId, new cublasHandle_t(nativeOps.lcBlasHandle(lc)));
|
||||||
}
|
}
|
||||||
|
|
||||||
return cublasHandles.get(deviceId);
|
return cublasHandles.get(deviceId);
|
||||||
|
} finally {
|
||||||
|
lock.writeLock().unlock();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -16985,7 +16985,19 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
// #endif
|
// #endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns a batched matrix tensor with new batched diagonal values.
|
* Inserts elements provided by diagonal array into the main diagonal of innermost matrices of input array
|
||||||
|
*
|
||||||
|
* Input arrays:
|
||||||
|
* input: input array, considered as batch of matrices
|
||||||
|
* diagonal: array containing elements to be inserted into input array,
|
||||||
|
* following rank condition should be satisfied: diagonal_rank = input_rank - 1,
|
||||||
|
* the shapes of diagonal and input arrays must be equal except last dimension of input array,
|
||||||
|
* for example if input_shape = [A,B,C,D] then diagonal_shape = [A,B,C],
|
||||||
|
* also last dimension of diagonal array should be equal to smaller of last and last but one input dimensions
|
||||||
|
* that is: diagonal_shape[-1] = min(input_shape[-1], input_shape[-2])
|
||||||
|
*
|
||||||
|
* Output array:
|
||||||
|
* has the same shape as input, corresponding diagonal elements are substituted
|
||||||
*/
|
*/
|
||||||
// #if NOT_EXCLUDED(OP_matrix_set_diag)
|
// #if NOT_EXCLUDED(OP_matrix_set_diag)
|
||||||
@Namespace("nd4j::ops") public static class matrix_set_diag extends DeclarableOp {
|
@Namespace("nd4j::ops") public static class matrix_set_diag extends DeclarableOp {
|
||||||
|
|
Loading…
Reference in New Issue