diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java
index 106ac9c3a..9b8c1012c 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java
@@ -1180,16 +1180,23 @@ public class CudaZeroHandler implements MemoryHandler {
         return getCudaContext();
     }
 
+    //
+    private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
-
-    protected synchronized cublasHandle_t getCudaCublasHandle(OpaqueLaunchContext lc) {
+    protected cublasHandle_t getCudaCublasHandle(OpaqueLaunchContext lc) {
         val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
 
-        if (cublasHandles.get(deviceId) == null) {
-            cublasHandles.remove(deviceId);
-            cublasHandles.add(deviceId, new cublasHandle_t(nativeOps.lcBlasHandle(lc)));
-        }
+        try {
+            lock.writeLock().lock();
 
-        return cublasHandles.get(deviceId);
+            if (cublasHandles.get(deviceId) == null) {
+                cublasHandles.remove(deviceId);
+                cublasHandles.add(deviceId, new cublasHandle_t(nativeOps.lcBlasHandle(lc)));
+            }
+
+            return cublasHandles.get(deviceId);
+        } finally {
+            lock.writeLock().unlock();
+        }
     }
 
     /**
diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java
index 8e150f618..9554a94e9 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java
@@ -16985,8 +16985,20 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
 // #endif
 
     /**
-     * Returns a batched matrix tensor with new batched diagonal values.
-     */
+     * Inserts elements provided by diagonal array into the main diagonal of innermost matrices of input array
+     *
+     * Input arrays:
+     * input:    input array, considered as batch of matrices
+     * diagonal: array containing elements to be inserted into input array,
+     *           following rank condition should be satisfied: diagonal_rank = input_rank - 1,
+     *           the shapes of diagonal and input arrays must be equal except last dimension of input array,
+     *           for example if input_shape = [A,B,C,D] then diagonal_shape = [A,B,C],
+     *           also last dimension of diagonal array should be equal to smaller of last and last but one input dimensions
+     *           that is: diagonal_shape[-1] = min(input_shape[-1], input_shape[-2])
+     *
+     * Output array:
+     * has the same shape as input, corresponding diagonal elements are substituted
+     */
 // #if NOT_EXCLUDED(OP_matrix_set_diag)
         @Namespace("nd4j::ops") public static class matrix_set_diag extends DeclarableOp {
             static { Loader.load(); }
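
For reference, below is a small standalone Java sketch (not part of the patch; the class and method names such as MatrixSetDiagShapeCheck and validate are made up for illustration) of the shape contract spelled out in the new matrix_set_diag comment: the diagonal array drops the last input dimension, and its own last dimension must equal min(input_shape[-1], input_shape[-2]).

// Sketch only: checks the shape rules documented for matrix_set_diag above.
// All names here are illustrative; this helper does not exist in ND4J.
public final class MatrixSetDiagShapeCheck {

    static void validate(long[] inputShape, long[] diagonalShape) {
        if (inputShape.length < 2)
            throw new IllegalArgumentException("input must be at least rank 2 (a batch of matrices)");

        // diagonal_rank = input_rank - 1
        if (diagonalShape.length != inputShape.length - 1)
            throw new IllegalArgumentException("diagonal rank must be input rank - 1");

        // leading (batch) dimensions of diagonal must match input, e.g. [A,B,C,D] -> [A,B,C]
        for (int i = 0; i < diagonalShape.length - 1; i++)
            if (diagonalShape[i] != inputShape[i])
                throw new IllegalArgumentException("dimension " + i + " of diagonal must match input");

        // diagonal_shape[-1] = min(input_shape[-1], input_shape[-2])
        long expected = Math.min(inputShape[inputShape.length - 1], inputShape[inputShape.length - 2]);
        if (diagonalShape[diagonalShape.length - 1] != expected)
            throw new IllegalArgumentException("last dimension of diagonal must be " + expected);
    }

    public static void main(String[] args) {
        // input [2,3,4,5] -> diagonal must be [2,3,4], since min(4,5) = 4
        validate(new long[]{2, 3, 4, 5}, new long[]{2, 3, 4});
        System.out.println("shapes ok");
    }
}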