- permute threadsPerBlock and blocksPerGrid in signature of launching of cuda kernel for trueBroadcast op (#120)

Signed-off-by: Yurii <iuriish@yahoo.com>
master
Yurii Shyrma 2019-12-09 19:08:36 +02:00 committed by raver119
parent 0175ace4c3
commit 425c747330
1 changed files with 37 additions and 27 deletions

View File

@ -42,9 +42,12 @@ __global__ static void trueBroadcastCuda(const void* vx, const Nd4jLong* xShapeI
auto z = reinterpret_cast<Z*>(vz); auto z = reinterpret_cast<Z*>(vz);
__shared__ int xRank, yRank, zRank; __shared__ int xRank, yRank, zRank;
__shared__ Nd4jLong zLen, totalThreads; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen __shared__ Nd4jLong zLen, totalThreads, *sharedMem; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
extern __shared__ unsigned char shmem[];
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
xRank = shape::rank(xShapeInfo); xRank = shape::rank(xShapeInfo);
yRank = shape::rank(yShapeInfo); yRank = shape::rank(yShapeInfo);
zRank = shape::rank(zShapeInfo); zRank = shape::rank(zShapeInfo);
@ -54,9 +57,9 @@ __global__ static void trueBroadcastCuda(const void* vx, const Nd4jLong* xShapeI
} }
__syncthreads(); __syncthreads();
Nd4jLong xCoords[MAX_RANK]; // = sharedMem + threadIdx.x * (xRank + yRank + zRank); auto xCoords = sharedMem + threadIdx.x * (xRank + yRank + zRank);
Nd4jLong yCoords[MAX_RANK]; // = xCoords + xRank; auto yCoords = xCoords + xRank;
Nd4jLong zCoords[MAX_RANK]; // = yCoords + yRank; auto zCoords = yCoords + yRank;
const auto tid = blockIdx.x * blockDim.x + threadIdx.x; const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
@ -66,19 +69,17 @@ __global__ static void trueBroadcastCuda(const void* vx, const Nd4jLong* xShapeI
for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) { for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
if(ix >= 0) { if(ix >= 0)
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) if(xShapeInfo[ix + 1] == zShapeInfo[iz + 1])
xCoords[ix--] = zCoords[iz]; xCoords[ix--] = zCoords[iz];
else else
xCoords[ix--] = 0; xCoords[ix--] = 0;
}
if(iy >= 0) { if(iy >= 0)
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) if(yShapeInfo[iy + 1] == zShapeInfo[iz + 1])
yCoords[iy--] = zCoords[iz]; yCoords[iy--] = zCoords[iz];
else else
yCoords[iy--] = 0; yCoords[iy--] = 0;
}
} }
const auto xOffset = shape::getOffset(xShapeInfo, xCoords); const auto xOffset = shape::getOffset(xShapeInfo, xCoords);
@ -93,6 +94,7 @@ __global__ static void trueBroadcastCuda(const void* vx, const Nd4jLong* xShapeI
template<typename X, typename Y, typename Z> template<typename X, typename Y, typename Z>
template <typename OpType> template <typename OpType>
void TrueBroadcastHelper<X,Y,Z>::execLauncher(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo) { void TrueBroadcastHelper<X,Y,Z>::execLauncher(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo) {
trueBroadcastCuda<X, Y, Z, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo); trueBroadcastCuda<X, Y, Z, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo);
} }
@ -102,9 +104,9 @@ void TrueBroadcastHelper<X,Y,Z>::exec(const nd4j::broadcast::Ops opNum, const ND
dim3 launchDims; dim3 launchDims;
launchDims.x = 128; //MAX_NUM_THREADS / 8; // threadsPerBlock launchDims.y = MAX_NUM_THREADS / 8; // threadsPerBlock
launchDims.y = 256; //(zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid launchDims.x = (zArr.lengthOf() + launchDims.y - 1) / launchDims.y; // blocksPerGrid
launchDims.z = 1024; // sharedMem launchDims.z = sizeof(Nd4jLong) * launchDims.y * (xArr.rankOf() + yArr.rankOf() + zArr.rankOf()) + 128; // sharedMe
PointersManager manager(xArr.getContext(), "TrueBroadcastHelper<X,Y,Z>::exec"); PointersManager manager(xArr.getContext(), "TrueBroadcastHelper<X,Y,Z>::exec");
@ -126,9 +128,12 @@ __global__ static void trueBroadcastBoolCuda(const void* vx, const Nd4jLong* xSh
auto z = reinterpret_cast<Z*>(vz); auto z = reinterpret_cast<Z*>(vz);
__shared__ int xRank, yRank, zRank; __shared__ int xRank, yRank, zRank;
__shared__ Nd4jLong zLen, totalThreads; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen __shared__ Nd4jLong zLen, totalThreads, *sharedMem; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
extern __shared__ unsigned char shmem[];
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
xRank = shape::rank(xShapeInfo); xRank = shape::rank(xShapeInfo);
yRank = shape::rank(yShapeInfo); yRank = shape::rank(yShapeInfo);
zRank = shape::rank(zShapeInfo); zRank = shape::rank(zShapeInfo);
@ -138,9 +143,9 @@ __global__ static void trueBroadcastBoolCuda(const void* vx, const Nd4jLong* xSh
} }
__syncthreads(); __syncthreads();
Nd4jLong xCoords[MAX_RANK]; // = sharedMem + threadIdx.x * (xRank + yRank + zRank); auto xCoords = sharedMem + threadIdx.x * (xRank + yRank + zRank);
Nd4jLong yCoords[MAX_RANK]; // = xCoords + xRank; auto yCoords = xCoords + xRank;
Nd4jLong zCoords[MAX_RANK]; // = yCoords + yRank; auto zCoords = yCoords + yRank;
const auto tid = blockIdx.x * blockDim.x + threadIdx.x; const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
@ -184,9 +189,10 @@ template<typename X, typename Y>
void TrueBroadcastBoolHelper<X,Y>::exec(const nd4j::broadcast::BoolOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) { void TrueBroadcastBoolHelper<X,Y>::exec(const nd4j::broadcast::BoolOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
dim3 launchDims; dim3 launchDims;
launchDims.x = 128; //MAX_NUM_THREADS / 8; // threadsPerBlock
launchDims.y = 256; //(zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid launchDims.y = MAX_NUM_THREADS / 8; // threadsPerBlock
launchDims.z = 1024; // sharedMem launchDims.x = (zArr.lengthOf() + launchDims.y - 1) / launchDims.y; // blocksPerGrid
launchDims.z = sizeof(Nd4jLong) * launchDims.y * (xArr.rankOf() + yArr.rankOf() + zArr.rankOf()) + 128; // sharedMe
PointersManager manager(xArr.getContext(), "TrueBroadcastBoolHelper<X,Y>::exec"); PointersManager manager(xArr.getContext(), "TrueBroadcastBoolHelper<X,Y>::exec");
@ -208,9 +214,12 @@ __global__ static void trueBroadcastIntCuda(const void* vx, const Nd4jLong* xSha
auto z = reinterpret_cast<X*>(vz); auto z = reinterpret_cast<X*>(vz);
__shared__ int xRank, yRank, zRank; __shared__ int xRank, yRank, zRank;
__shared__ Nd4jLong zLen, totalThreads; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen __shared__ Nd4jLong zLen, totalThreads, *sharedMem; // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
extern __shared__ unsigned char shmem[];
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
xRank = shape::rank(xShapeInfo); xRank = shape::rank(xShapeInfo);
yRank = shape::rank(yShapeInfo); yRank = shape::rank(yShapeInfo);
zRank = shape::rank(zShapeInfo); zRank = shape::rank(zShapeInfo);
@ -220,9 +229,9 @@ __global__ static void trueBroadcastIntCuda(const void* vx, const Nd4jLong* xSha
} }
__syncthreads(); __syncthreads();
Nd4jLong xCoords[MAX_RANK]; // = sharedMem + threadIdx.x * (xRank + yRank + zRank); auto xCoords = sharedMem + threadIdx.x * (xRank + yRank + zRank);
Nd4jLong yCoords[MAX_RANK]; // = xCoords + xRank; auto yCoords = xCoords + xRank;
Nd4jLong zCoords[MAX_RANK]; // = yCoords + yRank; auto zCoords = yCoords + yRank;
const auto tid = blockIdx.x * blockDim.x + threadIdx.x; const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
@ -266,9 +275,10 @@ template<typename X>
void TrueBroadcastIntHelper<X>::exec(const nd4j::broadcast::IntOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) { void TrueBroadcastIntHelper<X>::exec(const nd4j::broadcast::IntOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
dim3 launchDims; dim3 launchDims;
launchDims.x = 128; //MAX_NUM_THREADS / 8; // threadsPerBlock
launchDims.y = 256; //(zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid launchDims.y = MAX_NUM_THREADS / 8; // threadsPerBlock
launchDims.z = 1024; // sharedMem launchDims.x = (zArr.lengthOf() + launchDims.y - 1) / launchDims.y; // blocksPerGrid
launchDims.z = sizeof(Nd4jLong) * launchDims.y * (xArr.rankOf() + yArr.rankOf() + zArr.rankOf()) + 128; // sharedMe
PointersManager manager(xArr.getContext(), "TrueBroadcastIntHelper<X>::exec"); PointersManager manager(xArr.getContext(), "TrueBroadcastIntHelper<X>::exec");