dedicated lock for getCudaCublasHandle
Signed-off-by: raver119 <raver119@gmail.com>
This commit is contained in:
		
							parent
							
								
									2129d5bcac
								
							
						
					
					
						commit
						d3253aff3f
					
				@ -1180,16 +1180,23 @@ public class CudaZeroHandler implements MemoryHandler {
 | 
				
			|||||||
        return getCudaContext();
 | 
					        return getCudaContext();
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    //
 | 
				
			||||||
 | 
					    private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    protected cublasHandle_t getCudaCublasHandle(OpaqueLaunchContext lc) {
 | 
				
			||||||
    protected synchronized cublasHandle_t getCudaCublasHandle(OpaqueLaunchContext lc) {
 | 
					 | 
				
			||||||
        val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
 | 
					        val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
 | 
				
			||||||
        if (cublasHandles.get(deviceId) == null) {
 | 
					        try {
 | 
				
			||||||
            cublasHandles.remove(deviceId);
 | 
					            lock.writeLock().lock();
 | 
				
			||||||
            cublasHandles.add(deviceId, new cublasHandle_t(nativeOps.lcBlasHandle(lc)));
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return cublasHandles.get(deviceId);
 | 
					            if (cublasHandles.get(deviceId) == null) {
 | 
				
			||||||
 | 
					                cublasHandles.remove(deviceId);
 | 
				
			||||||
 | 
					                cublasHandles.add(deviceId, new cublasHandle_t(nativeOps.lcBlasHandle(lc)));
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            return cublasHandles.get(deviceId);
 | 
				
			||||||
 | 
					        } finally {
 | 
				
			||||||
 | 
					            lock.writeLock().unlock();
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /**
 | 
					    /**
 | 
				
			||||||
 | 
				
			|||||||
@ -16985,8 +16985,20 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
 | 
				
			|||||||
//         #endif
 | 
					//         #endif
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        /**
 | 
					        /**
 | 
				
			||||||
         * Returns a batched matrix tensor with new batched diagonal values.
 | 
					        * Inserts elements provided by diagonal array into the main diagonal of innermost matrices of input array
 | 
				
			||||||
         */
 | 
					        *
 | 
				
			||||||
 | 
					        * Input arrays:
 | 
				
			||||||
 | 
					        *    input:    input array, considered as batch of matrices
 | 
				
			||||||
 | 
					        *    diagonal: array containing elements to be inserted into input array,
 | 
				
			||||||
 | 
					        *              following rank condition should be satisfied: diagonal_rank = input_rank - 1,
 | 
				
			||||||
 | 
					        *              the shapes of diagonal and input arrays must be equal except last dimension of input array,
 | 
				
			||||||
 | 
					        *              for example if input_shape = [A,B,C,D] then diagonal_shape = [A,B,C],
 | 
				
			||||||
 | 
					        *              also last dimension of diagonal array should be equal to smaller of last and last but one input dimensions
 | 
				
			||||||
 | 
					        *              that is: diagonal_shape[-1] = min(input_shape[-1], input_shape[-2])
 | 
				
			||||||
 | 
					        *
 | 
				
			||||||
 | 
					        * Output array:
 | 
				
			||||||
 | 
					        *    has the same shape as input, corresponding diagonal elements are substituted
 | 
				
			||||||
 | 
					        */
 | 
				
			||||||
//         #if NOT_EXCLUDED(OP_matrix_set_diag)
 | 
					//         #if NOT_EXCLUDED(OP_matrix_set_diag)
 | 
				
			||||||
        @Namespace("nd4j::ops") public static class matrix_set_diag extends DeclarableOp {
 | 
					        @Namespace("nd4j::ops") public static class matrix_set_diag extends DeclarableOp {
 | 
				
			||||||
            static { Loader.load(); }
 | 
					            static { Loader.load(); }
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user