dedicated lock for getCudaCublasHandle

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-09-02 20:01:13 +03:00
parent 2129d5bcac
commit d3253aff3f
2 changed files with 28 additions and 9 deletions

View File

@ -1180,16 +1180,23 @@ public class CudaZeroHandler implements MemoryHandler {
return getCudaContext();
}
//
private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
protected synchronized cublasHandle_t getCudaCublasHandle(OpaqueLaunchContext lc) {
protected cublasHandle_t getCudaCublasHandle(OpaqueLaunchContext lc) {
val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
if (cublasHandles.get(deviceId) == null) {
cublasHandles.remove(deviceId);
cublasHandles.add(deviceId, new cublasHandle_t(nativeOps.lcBlasHandle(lc)));
}
try {
lock.writeLock().lock();
return cublasHandles.get(deviceId);
if (cublasHandles.get(deviceId) == null) {
cublasHandles.remove(deviceId);
cublasHandles.add(deviceId, new cublasHandle_t(nativeOps.lcBlasHandle(lc)));
}
return cublasHandles.get(deviceId);
} finally {
lock.writeLock().unlock();
}
}
/**

View File

@ -16985,8 +16985,20 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
// #endif
/**
* Returns a batched matrix tensor with new batched diagonal values.
*/
* Inserts elements provided by diagonal array into the main diagonal of innermost matrices of input array
*
* Input arrays:
* input: input array, considered as batch of matrices
* diagonal: array containing elements to be inserted into input array,
* following rank condition should be satisfied: diagonal_rank = input_rank - 1,
* the shapes of diagonal and input arrays must be equal except last dimension of input array,
* for example if input_shape = [A,B,C,D] then diagonal_shape = [A,B,C],
* also last dimension of diagonal array should be equal to smaller of last and last but one input dimensions
* that is: diagonal_shape[-1] = min(input_shape[-1], input_shape[-2])
*
* Output array:
* has the same shape as input, corresponding diagonal elements are substituted
*/
// #if NOT_EXCLUDED(OP_matrix_set_diag)
@Namespace("nd4j::ops") public static class matrix_set_diag extends DeclarableOp {
static { Loader.load(); }