- permute threadsPerBlock and blocksPerGrid in signature of launching of cuda kernel for trueBroadcast op (#120)
Signed-off-by: Yurii <iuriish@yahoo.com>
This commit is contained in:
		
							parent
							
								
									0175ace4c3
								
							
						
					
					
						commit
						425c747330
					
				| @ -42,9 +42,12 @@ __global__ static void trueBroadcastCuda(const void* vx, const Nd4jLong* xShapeI | ||||
|           auto z = reinterpret_cast<Z*>(vz); | ||||
| 
 | ||||
|     __shared__ int xRank, yRank, zRank; | ||||
|     __shared__ Nd4jLong zLen, totalThreads;  // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen | ||||
|     __shared__ Nd4jLong zLen, totalThreads, *sharedMem;  // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen | ||||
| 
 | ||||
|     if (threadIdx.x == 0) { | ||||
|         extern __shared__ unsigned char shmem[]; | ||||
|         sharedMem = reinterpret_cast<Nd4jLong*>(shmem); | ||||
| 
 | ||||
|         xRank = shape::rank(xShapeInfo); | ||||
|         yRank = shape::rank(yShapeInfo); | ||||
|         zRank = shape::rank(zShapeInfo); | ||||
| @ -54,9 +57,9 @@ __global__ static void trueBroadcastCuda(const void* vx, const Nd4jLong* xShapeI | ||||
|     } | ||||
|     __syncthreads(); | ||||
| 
 | ||||
|     Nd4jLong xCoords[MAX_RANK]; // = sharedMem + threadIdx.x * (xRank + yRank + zRank); | ||||
|     Nd4jLong yCoords[MAX_RANK]; // = xCoords + xRank; | ||||
|     Nd4jLong zCoords[MAX_RANK]; // = yCoords + yRank; | ||||
|     auto xCoords = sharedMem + threadIdx.x * (xRank + yRank + zRank); | ||||
|     auto yCoords = xCoords + xRank; | ||||
|     auto zCoords = yCoords + yRank; | ||||
| 
 | ||||
|     const auto tid = blockIdx.x * blockDim.x + threadIdx.x; | ||||
| 
 | ||||
| @ -66,19 +69,17 @@ __global__ static void trueBroadcastCuda(const void* vx, const Nd4jLong* xShapeI | ||||
| 
 | ||||
|         for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) { | ||||
| 
 | ||||
|             if(ix >= 0) { | ||||
|                 if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) | ||||
|             if(ix >= 0) | ||||
|                 if(xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) | ||||
|                     xCoords[ix--] = zCoords[iz]; | ||||
|                 else | ||||
|                     xCoords[ix--] = 0; | ||||
|             } | ||||
| 
 | ||||
|             if(iy >= 0) { | ||||
|                 if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) | ||||
|             if(iy >= 0) | ||||
|                 if(yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) | ||||
|                     yCoords[iy--] = zCoords[iz]; | ||||
|                 else | ||||
|                     yCoords[iy--] = 0; | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         const auto xOffset = shape::getOffset(xShapeInfo, xCoords); | ||||
| @ -93,6 +94,7 @@ __global__ static void trueBroadcastCuda(const void* vx, const Nd4jLong* xShapeI | ||||
| template<typename X, typename Y, typename Z> | ||||
| template <typename OpType> | ||||
| void TrueBroadcastHelper<X,Y,Z>::execLauncher(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo) { | ||||
| 
 | ||||
|     trueBroadcastCuda<X, Y, Z, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); | ||||
| } | ||||
| 
 | ||||
| @ -102,9 +104,9 @@ void TrueBroadcastHelper<X,Y,Z>::exec(const nd4j::broadcast::Ops opNum, const ND | ||||
| 
 | ||||
|     dim3 launchDims; | ||||
| 
 | ||||
|     launchDims.x = 128; //MAX_NUM_THREADS / 8;   // threadsPerBlock | ||||
|     launchDims.y = 256; //(zArr.lengthOf() + launchDims.x - 1) / launchDims.x;  // blocksPerGrid | ||||
|     launchDims.z = 1024; // sharedMem | ||||
|     launchDims.y = MAX_NUM_THREADS / 8;   // threadsPerBlock | ||||
|     launchDims.x = (zArr.lengthOf() + launchDims.y - 1) / launchDims.y;  // blocksPerGrid | ||||
|     launchDims.z = sizeof(Nd4jLong) * launchDims.y * (xArr.rankOf() + yArr.rankOf() + zArr.rankOf()) + 128; // sharedMe | ||||
| 
 | ||||
|     PointersManager manager(xArr.getContext(), "TrueBroadcastHelper<X,Y,Z>::exec"); | ||||
| 
 | ||||
| @ -126,9 +128,12 @@ __global__ static void trueBroadcastBoolCuda(const void* vx, const Nd4jLong* xSh | ||||
|           auto z = reinterpret_cast<Z*>(vz); | ||||
| 
 | ||||
|     __shared__ int xRank, yRank, zRank; | ||||
|     __shared__ Nd4jLong zLen, totalThreads;  // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen | ||||
|     __shared__ Nd4jLong zLen, totalThreads, *sharedMem;  // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen | ||||
| 
 | ||||
|     if (threadIdx.x == 0) { | ||||
|         extern __shared__ unsigned char shmem[]; | ||||
|         sharedMem = reinterpret_cast<Nd4jLong*>(shmem); | ||||
| 
 | ||||
|         xRank = shape::rank(xShapeInfo); | ||||
|         yRank = shape::rank(yShapeInfo); | ||||
|         zRank = shape::rank(zShapeInfo); | ||||
| @ -138,9 +143,9 @@ __global__ static void trueBroadcastBoolCuda(const void* vx, const Nd4jLong* xSh | ||||
|     } | ||||
|     __syncthreads(); | ||||
| 
 | ||||
|     Nd4jLong xCoords[MAX_RANK]; // = sharedMem + threadIdx.x * (xRank + yRank + zRank); | ||||
|     Nd4jLong yCoords[MAX_RANK]; // = xCoords + xRank; | ||||
|     Nd4jLong zCoords[MAX_RANK]; // = yCoords + yRank; | ||||
|     auto xCoords = sharedMem + threadIdx.x * (xRank + yRank + zRank); | ||||
|     auto yCoords = xCoords + xRank; | ||||
|     auto zCoords = yCoords + yRank; | ||||
| 
 | ||||
|     const auto tid = blockIdx.x * blockDim.x + threadIdx.x; | ||||
| 
 | ||||
| @ -184,9 +189,10 @@ template<typename X, typename Y> | ||||
| void TrueBroadcastBoolHelper<X,Y>::exec(const nd4j::broadcast::BoolOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) { | ||||
| 
 | ||||
|     dim3 launchDims; | ||||
|     launchDims.x = 128; //MAX_NUM_THREADS / 8;   // threadsPerBlock | ||||
|     launchDims.y = 256; //(zArr.lengthOf() + launchDims.x - 1) / launchDims.x;  // blocksPerGrid | ||||
|     launchDims.z = 1024; // sharedMem | ||||
| 
 | ||||
|     launchDims.y = MAX_NUM_THREADS / 8;   // threadsPerBlock | ||||
|     launchDims.x = (zArr.lengthOf() + launchDims.y - 1) / launchDims.y;  // blocksPerGrid | ||||
|     launchDims.z = sizeof(Nd4jLong) * launchDims.y * (xArr.rankOf() + yArr.rankOf() + zArr.rankOf()) + 128; // sharedMe | ||||
| 
 | ||||
|     PointersManager manager(xArr.getContext(), "TrueBroadcastBoolHelper<X,Y>::exec"); | ||||
| 
 | ||||
| @ -208,9 +214,12 @@ __global__ static void trueBroadcastIntCuda(const void* vx, const Nd4jLong* xSha | ||||
|           auto z = reinterpret_cast<X*>(vz); | ||||
| 
 | ||||
|     __shared__ int xRank, yRank, zRank; | ||||
|     __shared__ Nd4jLong zLen, totalThreads;  // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen | ||||
|     __shared__ Nd4jLong zLen, totalThreads, *sharedMem;  // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen | ||||
| 
 | ||||
|     if (threadIdx.x == 0) { | ||||
|         extern __shared__ unsigned char shmem[]; | ||||
|         sharedMem = reinterpret_cast<Nd4jLong*>(shmem); | ||||
| 
 | ||||
|         xRank = shape::rank(xShapeInfo); | ||||
|         yRank = shape::rank(yShapeInfo); | ||||
|         zRank = shape::rank(zShapeInfo); | ||||
| @ -220,9 +229,9 @@ __global__ static void trueBroadcastIntCuda(const void* vx, const Nd4jLong* xSha | ||||
|     } | ||||
|     __syncthreads(); | ||||
| 
 | ||||
|     Nd4jLong xCoords[MAX_RANK]; // = sharedMem + threadIdx.x * (xRank + yRank + zRank); | ||||
|     Nd4jLong yCoords[MAX_RANK]; // = xCoords + xRank; | ||||
|     Nd4jLong zCoords[MAX_RANK]; // = yCoords + yRank; | ||||
|     auto xCoords = sharedMem + threadIdx.x * (xRank + yRank + zRank); | ||||
|     auto yCoords = xCoords + xRank; | ||||
|     auto zCoords = yCoords + yRank; | ||||
| 
 | ||||
|     const auto tid = blockIdx.x * blockDim.x + threadIdx.x; | ||||
| 
 | ||||
| @ -266,9 +275,10 @@ template<typename X> | ||||
| void TrueBroadcastIntHelper<X>::exec(const nd4j::broadcast::IntOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) { | ||||
| 
 | ||||
|     dim3 launchDims; | ||||
|     launchDims.x = 128; //MAX_NUM_THREADS / 8;   // threadsPerBlock | ||||
|     launchDims.y = 256; //(zArr.lengthOf() + launchDims.x - 1) / launchDims.x;  // blocksPerGrid | ||||
|     launchDims.z = 1024; // sharedMem | ||||
| 
 | ||||
|     launchDims.y = MAX_NUM_THREADS / 8;   // threadsPerBlock | ||||
|     launchDims.x = (zArr.lengthOf() + launchDims.y - 1) / launchDims.y;  // blocksPerGrid | ||||
|     launchDims.z = sizeof(Nd4jLong) * launchDims.y * (xArr.rankOf() + yArr.rankOf() + zArr.rankOf()) + 128; // sharedMe | ||||
| 
 | ||||
|     PointersManager manager(xArr.getContext(), "TrueBroadcastIntHelper<X>::exec"); | ||||
| 
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user