cuda broadcast fix
Signed-off-by: raver119 <raver119@gmail.com>
This commit is contained in:
		
							parent
							
								
									ae7933a428
								
							
						
					
					
						commit
						cea68c18f1
					
				@ -42,12 +42,9 @@ __global__ static void trueBroadcastCuda(const void* vx, const Nd4jLong* xShapeI
 | 
			
		||||
          auto z = reinterpret_cast<Z*>(vz);
 | 
			
		||||
 | 
			
		||||
    __shared__ int xRank, yRank, zRank;
 | 
			
		||||
    __shared__ Nd4jLong zLen, totalThreads, *sharedMem;  // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen
 | 
			
		||||
    __shared__ Nd4jLong zLen, totalThreads;  // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen
 | 
			
		||||
 | 
			
		||||
    if (threadIdx.x == 0) {
 | 
			
		||||
        extern __shared__ unsigned char shmem[];
 | 
			
		||||
        sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
 | 
			
		||||
 | 
			
		||||
        xRank = shape::rank(xShapeInfo);
 | 
			
		||||
        yRank = shape::rank(yShapeInfo);
 | 
			
		||||
        zRank = shape::rank(zShapeInfo);
 | 
			
		||||
@ -57,9 +54,9 @@ __global__ static void trueBroadcastCuda(const void* vx, const Nd4jLong* xShapeI
 | 
			
		||||
    }
 | 
			
		||||
    __syncthreads();
 | 
			
		||||
 | 
			
		||||
    auto xCoords = sharedMem + threadIdx.x * (xRank + yRank + zRank);
 | 
			
		||||
    auto yCoords = xCoords + xRank;
 | 
			
		||||
    auto zCoords = yCoords + yRank;
 | 
			
		||||
    Nd4jLong xCoords[MAX_RANK]; // = sharedMem + threadIdx.x * (xRank + yRank + zRank);
 | 
			
		||||
    Nd4jLong yCoords[MAX_RANK]; // = xCoords + xRank;
 | 
			
		||||
    Nd4jLong zCoords[MAX_RANK]; // = yCoords + yRank;
 | 
			
		||||
 | 
			
		||||
    const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
 | 
			
		||||
 | 
			
		||||
@ -94,7 +91,6 @@ __global__ static void trueBroadcastCuda(const void* vx, const Nd4jLong* xShapeI
 | 
			
		||||
template<typename X, typename Y, typename Z>
 | 
			
		||||
template <typename OpType>
 | 
			
		||||
void TrueBroadcastHelper<X,Y,Z>::execLauncher(dim3 launchDims, cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const void *vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo) {
 | 
			
		||||
 | 
			
		||||
    trueBroadcastCuda<X, Y, Z, OpType><<<launchDims.x, launchDims.y, launchDims.z, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -106,7 +102,7 @@ void TrueBroadcastHelper<X,Y,Z>::exec(const nd4j::broadcast::Ops opNum, const ND
 | 
			
		||||
 | 
			
		||||
    launchDims.x = MAX_NUM_THREADS / 8;   // threadsPerBlock
 | 
			
		||||
    launchDims.y = (zArr.lengthOf() + launchDims.x - 1) / launchDims.x;  // blocksPerGrid
 | 
			
		||||
    launchDims.z = sizeof(Nd4jLong) * launchDims.x * (xArr.rankOf() + yArr.rankOf() + zArr.rankOf()) + 128; // sharedMem
 | 
			
		||||
    launchDims.z = 1024; // sharedMem
 | 
			
		||||
 | 
			
		||||
    PointersManager manager(xArr.getContext(), "TrueBroadcastHelper<X,Y,Z>::exec");
 | 
			
		||||
 | 
			
		||||
@ -128,12 +124,9 @@ __global__ static void trueBroadcastBoolCuda(const void* vx, const Nd4jLong* xSh
 | 
			
		||||
          auto z = reinterpret_cast<Z*>(vz);
 | 
			
		||||
 | 
			
		||||
    __shared__ int xRank, yRank, zRank;
 | 
			
		||||
    __shared__ Nd4jLong zLen, totalThreads, *sharedMem;  // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen
 | 
			
		||||
    __shared__ Nd4jLong zLen, totalThreads;  // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen
 | 
			
		||||
 | 
			
		||||
    if (threadIdx.x == 0) {
 | 
			
		||||
        extern __shared__ unsigned char shmem[];
 | 
			
		||||
        sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
 | 
			
		||||
 | 
			
		||||
        xRank = shape::rank(xShapeInfo);
 | 
			
		||||
        yRank = shape::rank(yShapeInfo);
 | 
			
		||||
        zRank = shape::rank(zShapeInfo);
 | 
			
		||||
@ -143,9 +136,9 @@ __global__ static void trueBroadcastBoolCuda(const void* vx, const Nd4jLong* xSh
 | 
			
		||||
    }
 | 
			
		||||
    __syncthreads();
 | 
			
		||||
 | 
			
		||||
    auto xCoords = sharedMem + threadIdx.x * (xRank + yRank + zRank);
 | 
			
		||||
    auto yCoords = xCoords + xRank;
 | 
			
		||||
    auto zCoords = yCoords + yRank;
 | 
			
		||||
    Nd4jLong xCoords[MAX_RANK]; // = sharedMem + threadIdx.x * (xRank + yRank + zRank);
 | 
			
		||||
    Nd4jLong yCoords[MAX_RANK]; // = xCoords + xRank;
 | 
			
		||||
    Nd4jLong zCoords[MAX_RANK]; // = yCoords + yRank;
 | 
			
		||||
 | 
			
		||||
    const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
 | 
			
		||||
 | 
			
		||||
@ -191,7 +184,7 @@ void TrueBroadcastBoolHelper<X,Y>::exec(const nd4j::broadcast::BoolOps opNum, co
 | 
			
		||||
    dim3 launchDims;
 | 
			
		||||
    launchDims.x = MAX_NUM_THREADS / 8;   // threadsPerBlock
 | 
			
		||||
    launchDims.y = (zArr.lengthOf() + launchDims.x - 1) / launchDims.x;  // blocksPerGrid
 | 
			
		||||
    launchDims.z = sizeof(Nd4jLong) * launchDims.x * (xArr.rankOf() + yArr.rankOf() + zArr.rankOf()) + 128; // sharedMem
 | 
			
		||||
    launchDims.z = 1024; // sharedMem
 | 
			
		||||
 | 
			
		||||
    PointersManager manager(xArr.getContext(), "TrueBroadcastBoolHelper<X,Y>::exec");
 | 
			
		||||
 | 
			
		||||
@ -213,12 +206,9 @@ __global__ static void trueBroadcastIntCuda(const void* vx, const Nd4jLong* xSha
 | 
			
		||||
          auto z = reinterpret_cast<X*>(vz);
 | 
			
		||||
 | 
			
		||||
    __shared__ int xRank, yRank, zRank;
 | 
			
		||||
    __shared__ Nd4jLong zLen, totalThreads, *sharedMem;  // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen
 | 
			
		||||
    __shared__ Nd4jLong zLen, totalThreads;  // xLen == zLen, except when xRank = 1, in this case zLen = 2*xLen
 | 
			
		||||
 | 
			
		||||
    if (threadIdx.x == 0) {
 | 
			
		||||
        extern __shared__ unsigned char shmem[];
 | 
			
		||||
        sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
 | 
			
		||||
 | 
			
		||||
        xRank = shape::rank(xShapeInfo);
 | 
			
		||||
        yRank = shape::rank(yShapeInfo);
 | 
			
		||||
        zRank = shape::rank(zShapeInfo);
 | 
			
		||||
@ -228,9 +218,9 @@ __global__ static void trueBroadcastIntCuda(const void* vx, const Nd4jLong* xSha
 | 
			
		||||
    }
 | 
			
		||||
    __syncthreads();
 | 
			
		||||
 | 
			
		||||
    auto xCoords = sharedMem + threadIdx.x * (xRank + yRank + zRank);
 | 
			
		||||
    auto yCoords = xCoords + xRank;
 | 
			
		||||
    auto zCoords = yCoords + yRank;
 | 
			
		||||
    Nd4jLong xCoords[MAX_RANK]; // = sharedMem + threadIdx.x * (xRank + yRank + zRank);
 | 
			
		||||
    Nd4jLong yCoords[MAX_RANK]; // = xCoords + xRank;
 | 
			
		||||
    Nd4jLong zCoords[MAX_RANK]; // = yCoords + yRank;
 | 
			
		||||
 | 
			
		||||
    const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
 | 
			
		||||
 | 
			
		||||
@ -276,7 +266,7 @@ void TrueBroadcastIntHelper<X>::exec(const nd4j::broadcast::IntOps opNum, const
 | 
			
		||||
    dim3 launchDims;
 | 
			
		||||
    launchDims.x = MAX_NUM_THREADS / 8;   // threadsPerBlock
 | 
			
		||||
    launchDims.y = (zArr.lengthOf() + launchDims.x - 1) / launchDims.x;  // blocksPerGrid
 | 
			
		||||
    launchDims.z = sizeof(Nd4jLong) * launchDims.x * (xArr.rankOf() + yArr.rankOf() + zArr.rankOf()) + 128; // sharedMem
 | 
			
		||||
    launchDims.z = 1024; // sharedMem
 | 
			
		||||
 | 
			
		||||
    PointersManager manager(xArr.getContext(), "TrueBroadcastIntHelper<X>::exec");
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user