diff --git a/libnd4j/include/helpers/cuda/TrueBroadcastHelper.cu b/libnd4j/include/helpers/cuda/TrueBroadcastHelper.cu index fdbd001fd..12c3eb0c5 100644 --- a/libnd4j/include/helpers/cuda/TrueBroadcastHelper.cu +++ b/libnd4j/include/helpers/cuda/TrueBroadcastHelper.cu @@ -42,9 +42,12 @@ __global__ static void trueBroadcastCuda(const void* vx, const Nd4jLong* xShapeI auto z = reinterpret_cast(vz); __shared__ int xRank, yRank, zRank; - __shared__ Nd4jLong zLen, totalThreads; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen + __shared__ Nd4jLong zLen, totalThreads, *sharedMem; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + xRank = shape::rank(xShapeInfo); yRank = shape::rank(yShapeInfo); zRank = shape::rank(zShapeInfo); @@ -54,9 +57,9 @@ __global__ static void trueBroadcastCuda(const void* vx, const Nd4jLong* xShapeI } __syncthreads(); - Nd4jLong xCoords[MAX_RANK]; // = sharedMem + threadIdx.x * (xRank + yRank + zRank); - Nd4jLong yCoords[MAX_RANK]; // = xCoords + xRank; - Nd4jLong zCoords[MAX_RANK]; // = yCoords + yRank; + auto xCoords = sharedMem + threadIdx.x * (xRank + yRank + zRank); + auto yCoords = xCoords + xRank; + auto zCoords = yCoords + yRank; const auto tid = blockIdx.x * blockDim.x + threadIdx.x; @@ -66,19 +69,17 @@ __global__ static void trueBroadcastCuda(const void* vx, const Nd4jLong* xShapeI for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) { - if(ix >= 0) { - if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) + if(ix >= 0) + if(xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) xCoords[ix--] = zCoords[iz]; else xCoords[ix--] = 0; - } - if(iy >= 0) { - if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) + if(iy >= 0) + if(yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) yCoords[iy--] = zCoords[iz]; else yCoords[iy--] = 0; - } } const auto xOffset = shape::getOffset(xShapeInfo, xCoords); @@ -93,6 +94,7 @@ 
__global__ static void trueBroadcastCuda(const void* vx, const Nd4jLong* xShapeI template template void TrueBroadcastHelper::execLauncher(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo) { + trueBroadcastCuda<<>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); } @@ -102,9 +104,9 @@ void TrueBroadcastHelper::exec(const nd4j::broadcast::Ops opNum, const ND dim3 launchDims; - launchDims.x = 128; //MAX_NUM_THREADS / 8; // threadsPerBlock - launchDims.y = 256; //(zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid - launchDims.z = 1024; // sharedMem + launchDims.y = MAX_NUM_THREADS / 8; // threadsPerBlock + launchDims.x = (zArr.lengthOf() + launchDims.y - 1) / launchDims.y; // blocksPerGrid + launchDims.z = sizeof(Nd4jLong) * launchDims.y * (xArr.rankOf() + yArr.rankOf() + zArr.rankOf()) + 128; // sharedMem PointersManager manager(xArr.getContext(), "TrueBroadcastHelper::exec"); @@ -126,9 +128,12 @@ __global__ static void trueBroadcastBoolCuda(const void* vx, const Nd4jLong* xSh auto z = reinterpret_cast(vz); __shared__ int xRank, yRank, zRank; - __shared__ Nd4jLong zLen, totalThreads; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen + __shared__ Nd4jLong zLen, totalThreads, *sharedMem; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + xRank = shape::rank(xShapeInfo); yRank = shape::rank(yShapeInfo); zRank = shape::rank(zShapeInfo); @@ -138,9 +143,9 @@ __global__ static void trueBroadcastBoolCuda(const void* vx, const Nd4jLong* xSh } __syncthreads(); - Nd4jLong xCoords[MAX_RANK]; // = sharedMem + threadIdx.x * (xRank + yRank + zRank); - Nd4jLong yCoords[MAX_RANK]; // = xCoords + xRank; - Nd4jLong zCoords[MAX_RANK]; // = yCoords + yRank; + auto xCoords = sharedMem + threadIdx.x * (xRank + yRank + 
zRank); + auto yCoords = xCoords + xRank; + auto zCoords = yCoords + yRank; const auto tid = blockIdx.x * blockDim.x + threadIdx.x; @@ -184,9 +189,10 @@ template void TrueBroadcastBoolHelper::exec(const nd4j::broadcast::BoolOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) { dim3 launchDims; - launchDims.x = 128; //MAX_NUM_THREADS / 8; // threadsPerBlock - launchDims.y = 256; //(zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid - launchDims.z = 1024; // sharedMem + + launchDims.y = MAX_NUM_THREADS / 8; // threadsPerBlock + launchDims.x = (zArr.lengthOf() + launchDims.y - 1) / launchDims.y; // blocksPerGrid + launchDims.z = sizeof(Nd4jLong) * launchDims.y * (xArr.rankOf() + yArr.rankOf() + zArr.rankOf()) + 128; // sharedMem PointersManager manager(xArr.getContext(), "TrueBroadcastBoolHelper::exec"); @@ -208,9 +214,12 @@ __global__ static void trueBroadcastIntCuda(const void* vx, const Nd4jLong* xSha auto z = reinterpret_cast(vz); __shared__ int xRank, yRank, zRank; - __shared__ Nd4jLong zLen, totalThreads; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen + __shared__ Nd4jLong zLen, totalThreads, *sharedMem; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); + xRank = shape::rank(xShapeInfo); yRank = shape::rank(yShapeInfo); zRank = shape::rank(zShapeInfo); @@ -220,9 +229,9 @@ __global__ static void trueBroadcastIntCuda(const void* vx, const Nd4jLong* xSha } __syncthreads(); - Nd4jLong xCoords[MAX_RANK]; // = sharedMem + threadIdx.x * (xRank + yRank + zRank); - Nd4jLong yCoords[MAX_RANK]; // = xCoords + xRank; - Nd4jLong zCoords[MAX_RANK]; // = yCoords + yRank; + auto xCoords = sharedMem + threadIdx.x * (xRank + yRank + zRank); + auto yCoords = xCoords + xRank; + auto zCoords = yCoords + yRank; const auto tid = blockIdx.x * blockDim.x + threadIdx.x; @@ -266,9 +275,10 @@ template void 
TrueBroadcastIntHelper::exec(const nd4j::broadcast::IntOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) { dim3 launchDims; - launchDims.x = 128; //MAX_NUM_THREADS / 8; // threadsPerBlock - launchDims.y = 256; //(zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid - launchDims.z = 1024; // sharedMem + + launchDims.y = MAX_NUM_THREADS / 8; // threadsPerBlock + launchDims.x = (zArr.lengthOf() + launchDims.y - 1) / launchDims.y; // blocksPerGrid + launchDims.z = sizeof(Nd4jLong) * launchDims.y * (xArr.rankOf() + yArr.rankOf() + zArr.rankOf()) + 128; // sharedMem PointersManager manager(xArr.getContext(), "TrueBroadcastIntHelper::exec");