diff --git a/libnd4j/include/helpers/cuda/TrueBroadcastHelper.cu b/libnd4j/include/helpers/cuda/TrueBroadcastHelper.cu
index 4b7904bca..f40690795 100644
--- a/libnd4j/include/helpers/cuda/TrueBroadcastHelper.cu
+++ b/libnd4j/include/helpers/cuda/TrueBroadcastHelper.cu
@@ -42,12 +42,9 @@ __global__ static void trueBroadcastCuda(const void* vx, const Nd4jLong* xShapeI
     auto z = reinterpret_cast<Z*>(vz);
 
     __shared__ int xRank, yRank, zRank;
-    __shared__ Nd4jLong zLen, totalThreads, *sharedMem;  // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen
+    __shared__ Nd4jLong zLen, totalThreads;  // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen
 
     if (threadIdx.x == 0) {
-        extern __shared__ unsigned char shmem[];
-        sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
-
         xRank = shape::rank(xShapeInfo);
         yRank = shape::rank(yShapeInfo);
         zRank = shape::rank(zShapeInfo);
@@ -57,9 +54,9 @@ __global__ static void trueBroadcastCuda(const void* vx, const Nd4jLong* xShapeI
     }
     __syncthreads();
 
-    auto xCoords = sharedMem + threadIdx.x * (xRank + yRank + zRank);
-    auto yCoords = xCoords + xRank;
-    auto zCoords = yCoords + yRank;
+    Nd4jLong xCoords[MAX_RANK]; // = sharedMem + threadIdx.x * (xRank + yRank + zRank);
+    Nd4jLong yCoords[MAX_RANK]; // = xCoords + xRank;
+    Nd4jLong zCoords[MAX_RANK]; // = yCoords + yRank;
 
     const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
 
@@ -94,7 +91,6 @@ __global__ static void trueBroadcastCuda(const void* vx, const Nd4jLong* xShapeI
 template <typename X, typename Y, typename Z>
 template <typename OpType>
 void TrueBroadcastHelper<X, Y, Z>::execLauncher(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo) {
-
     trueBroadcastCuda<X, Y, Z, OpType><<<launchDims.y, launchDims.x, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo);
 }
 
@@ -106,7 +102,7 @@ void TrueBroadcastHelper::exec(const nd4j::broadcast::Ops opNum, const ND
 
     launchDims.x = MAX_NUM_THREADS / 8;  // threadsPerBlock
     launchDims.y = (zArr.lengthOf() + launchDims.x - 1) / launchDims.x;  // blocksPerGrid
-    launchDims.z = sizeof(Nd4jLong) * launchDims.x * (xArr.rankOf() + yArr.rankOf() + zArr.rankOf()) + 128;  // sharedMem
+    launchDims.z = 1024;  // sharedMem
 
     PointersManager manager(xArr.getContext(), "TrueBroadcastHelper::exec");
 
@@ -128,12 +124,9 @@ __global__ static void trueBroadcastBoolCuda(const void* vx, const Nd4jLong* xSh
     auto z = reinterpret_cast<Z*>(vz);
 
     __shared__ int xRank, yRank, zRank;
-    __shared__ Nd4jLong zLen, totalThreads, *sharedMem;  // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen
+    __shared__ Nd4jLong zLen, totalThreads;  // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen
 
     if (threadIdx.x == 0) {
-        extern __shared__ unsigned char shmem[];
-        sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
-
         xRank = shape::rank(xShapeInfo);
         yRank = shape::rank(yShapeInfo);
         zRank = shape::rank(zShapeInfo);
@@ -143,9 +136,9 @@ __global__ static void trueBroadcastBoolCuda(const void* vx, const Nd4jLong* xSh
     }
     __syncthreads();
 
-    auto xCoords = sharedMem + threadIdx.x * (xRank + yRank + zRank);
-    auto yCoords = xCoords + xRank;
-    auto zCoords = yCoords + yRank;
+    Nd4jLong xCoords[MAX_RANK]; // = sharedMem + threadIdx.x * (xRank + yRank + zRank);
+    Nd4jLong yCoords[MAX_RANK]; // = xCoords + xRank;
+    Nd4jLong zCoords[MAX_RANK]; // = yCoords + yRank;
 
     const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
 
@@ -191,7 +184,7 @@ void TrueBroadcastBoolHelper::exec(const nd4j::broadcast::BoolOps opNum, co
     dim3 launchDims;
     launchDims.x = MAX_NUM_THREADS / 8;  // threadsPerBlock
     launchDims.y = (zArr.lengthOf() + launchDims.x - 1) / launchDims.x;  // blocksPerGrid
-    launchDims.z = sizeof(Nd4jLong) * launchDims.x * (xArr.rankOf() + yArr.rankOf() + zArr.rankOf()) + 128;  // sharedMem
+    launchDims.z = 1024;  // sharedMem
 
     PointersManager manager(xArr.getContext(), "TrueBroadcastBoolHelper::exec");
 
@@ -213,12 +206,9 @@ __global__ static void trueBroadcastIntCuda(const void* vx, const Nd4jLong* xSha
     auto z = reinterpret_cast<X*>(vz);
 
     __shared__ int xRank, yRank, zRank;
-    __shared__ Nd4jLong zLen, totalThreads, *sharedMem;  // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen
+    __shared__ Nd4jLong zLen, totalThreads;  // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen
 
     if (threadIdx.x == 0) {
-        extern __shared__ unsigned char shmem[];
-        sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
-
         xRank = shape::rank(xShapeInfo);
         yRank = shape::rank(yShapeInfo);
         zRank = shape::rank(zShapeInfo);
@@ -228,9 +218,9 @@ __global__ static void trueBroadcastIntCuda(const void* vx, const Nd4jLong* xSha
     }
     __syncthreads();
 
-    auto xCoords = sharedMem + threadIdx.x * (xRank + yRank + zRank);
-    auto yCoords = xCoords + xRank;
-    auto zCoords = yCoords + yRank;
+    Nd4jLong xCoords[MAX_RANK]; // = sharedMem + threadIdx.x * (xRank + yRank + zRank);
+    Nd4jLong yCoords[MAX_RANK]; // = xCoords + xRank;
+    Nd4jLong zCoords[MAX_RANK]; // = yCoords + yRank;
 
     const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
 
@@ -276,7 +266,7 @@ void TrueBroadcastIntHelper::exec(const nd4j::broadcast::IntOps opNum, const
     dim3 launchDims;
     launchDims.x = MAX_NUM_THREADS / 8;  // threadsPerBlock
     launchDims.y = (zArr.lengthOf() + launchDims.x - 1) / launchDims.x;  // blocksPerGrid
-    launchDims.z = sizeof(Nd4jLong) * launchDims.x * (xArr.rankOf() + yArr.rankOf() + zArr.rankOf()) + 128;  // sharedMem
+    launchDims.z = 1024;  // sharedMem
 
     PointersManager manager(xArr.getContext(), "TrueBroadcastIntHelper::exec");
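
Note: the change above trades per-thread slices of dynamically sized extern __shared__ memory for fixed-size local arrays (Nd4jLong xCoords[MAX_RANK]), so launchDims.z no longer has to be computed from the operand ranks. Below is a minimal, self-contained sketch of that pattern, not the libnd4j kernel itself: the kernel name, the MAX_RANK value, the float element type, and the row-major unflatten/flatten logic are all assumptions made for illustration.

// Hypothetical example kernel: per-thread coordinate buffers live in registers/local
// memory, so the launch needs no dynamic shared memory (third launch argument is 0).
#include <cstdio>
#include <cuda_runtime.h>

#define MAX_RANK 32   // assumed value for this sketch

__global__ static void copyByCoordsCuda(const float* x, float* z, const long long* shape, int rank, long long len) {

    long long coords[MAX_RANK];   // fixed-size, per thread; replaces the sharedMem + threadIdx.x * (...) slicing

    const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
    const auto totalThreads = static_cast<long long>(gridDim.x) * blockDim.x;

    for (long long i = tid; i < len; i += totalThreads) {

        // unflatten the linear index into row-major coordinates
        long long rem = i;
        for (int d = rank - 1; d >= 0; --d) {
            coords[d] = rem % shape[d];
            rem /= shape[d];
        }

        // flatten back into an offset (a true broadcast would map coords onto x/y/z offsets separately)
        long long offset = 0;
        for (int d = 0; d < rank; ++d)
            offset = offset * shape[d] + coords[d];

        z[offset] = x[offset];
    }
}

int main() {

    const int rank = 2;
    const long long len = 6;
    const long long shapeHost[rank] = {2, 3};
    const float xHost[len] = {0, 1, 2, 3, 4, 5};
    float zHost[len];

    float *xDev, *zDev;
    long long *shapeDev;
    cudaMalloc(&xDev, len * sizeof(float));
    cudaMalloc(&zDev, len * sizeof(float));
    cudaMalloc(&shapeDev, rank * sizeof(long long));
    cudaMemcpy(xDev, xHost, len * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(shapeDev, shapeHost, rank * sizeof(long long), cudaMemcpyHostToDevice);

    // dynamic shared memory stays at 0 because coordinates are per-thread local arrays
    copyByCoordsCuda<<<1, 128, 0>>>(xDev, zDev, shapeDev, rank, len);

    cudaMemcpy(zHost, zDev, len * sizeof(float), cudaMemcpyDeviceToHost);
    for (long long i = 0; i < len; ++i)
        printf("%lld: %f\n", i, zHost[i]);

    cudaFree(xDev); cudaFree(zDev); cudaFree(shapeDev);
    return 0;
}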