[WIP] CUDA concat tweak (#148)
* one special test Signed-off-by: raver119 <raver119@gmail.com> * one special test Signed-off-by: raver119 <raver119@gmail.com> * local memory for concat Signed-off-by: raver119 <raver119@gmail.com> * fixed grid size for concat Signed-off-by: raver119 <raver119@gmail.com> * fixed grid size for concat Signed-off-by: raver119 <raver119@gmail.com> * test commented out Signed-off-by: raver119 <raver119@gmail.com>
This commit is contained in:
		
							parent
							
								
									39d43ca170
								
							
						
					
					
						commit
						d1e5e79c10
					
				| @ -39,13 +39,10 @@ template<typename T> | |||||||
| __global__ static void concatCuda(void* pVx,  void* pxShapeInfo, void* vz, Nd4jLong* zShapeInfo, const int axis) { | __global__ static void concatCuda(void* pVx,  void* pxShapeInfo, void* vz, Nd4jLong* zShapeInfo, const int axis) { | ||||||
| 
 | 
 | ||||||
|     T* z = reinterpret_cast<T*>(vz); |     T* z = reinterpret_cast<T*>(vz); | ||||||
|     __shared__ Nd4jLong zLen, totalThreads, *sharedMem; |     __shared__ Nd4jLong zLen, totalThreads; | ||||||
|     __shared__ int rank; |     __shared__ int rank; | ||||||
| 
 | 
 | ||||||
|     if (threadIdx.x == 0) { |     if (threadIdx.x == 0) { | ||||||
|         extern __shared__ unsigned char shmem[]; |  | ||||||
|         sharedMem = reinterpret_cast<Nd4jLong*>(shmem); |  | ||||||
| 
 |  | ||||||
|         zLen = shape::length(zShapeInfo); |         zLen = shape::length(zShapeInfo); | ||||||
|         rank = shape::rank(zShapeInfo); |         rank = shape::rank(zShapeInfo); | ||||||
|         totalThreads = gridDim.x * blockDim.x; |         totalThreads = gridDim.x * blockDim.x; | ||||||
| @ -54,27 +51,26 @@ __global__ static void concatCuda(void* pVx,  void* pxShapeInfo, void* vz, Nd4jL | |||||||
| 
 | 
 | ||||||
|     const auto tid = blockIdx.x * blockDim.x + threadIdx.x; |     const auto tid = blockIdx.x * blockDim.x + threadIdx.x; | ||||||
| 
 | 
 | ||||||
|     if(tid >= zLen) |     Nd4jLong coords[MAX_RANK]; | ||||||
|         return; |  | ||||||
| 
 | 
 | ||||||
|     auto coords = sharedMem + threadIdx.x * rank; |     for (uint64_t i = tid; i < zLen; i += totalThreads) { | ||||||
| 
 |         shape::index2coords(i, zShapeInfo, coords); | ||||||
|     shape::index2coords(tid, zShapeInfo, coords); |  | ||||||
| 
 | 
 | ||||||
|         const auto zOffset = shape::getOffset(zShapeInfo, coords); |         const auto zOffset = shape::getOffset(zShapeInfo, coords); | ||||||
| 
 | 
 | ||||||
|         int inArrIdx = 0; |         int inArrIdx = 0; | ||||||
|     Nd4jLong *xShapeInfo = reinterpret_cast<Nd4jLong**>(pxShapeInfo)[inArrIdx]; |         Nd4jLong *xShapeInfo = reinterpret_cast<Nd4jLong **>(pxShapeInfo)[inArrIdx]; | ||||||
| 
 | 
 | ||||||
|     while(coords[axis] >= xShapeInfo[axis + 1]) { |         while (coords[axis] >= xShapeInfo[axis + 1]) { | ||||||
|             coords[axis] -= xShapeInfo[axis + 1]; |             coords[axis] -= xShapeInfo[axis + 1]; | ||||||
|         xShapeInfo = reinterpret_cast<Nd4jLong**>(pxShapeInfo)[++inArrIdx]; |             xShapeInfo = reinterpret_cast<Nd4jLong **>(pxShapeInfo)[++inArrIdx]; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|     const auto* x      = reinterpret_cast<T*>(reinterpret_cast<void**>(pVx)[inArrIdx]); |         const auto *x = reinterpret_cast<T *>(reinterpret_cast<void **>(pVx)[inArrIdx]); | ||||||
|         const auto xOffset = shape::getOffset(xShapeInfo, coords); |         const auto xOffset = shape::getOffset(xShapeInfo, coords); | ||||||
| 
 | 
 | ||||||
|         z[zOffset] = x[xOffset]; |         z[zOffset] = x[xOffset]; | ||||||
|  |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /////////////////////////////////////////////////////////////////// | /////////////////////////////////////////////////////////////////// | ||||||
| @ -89,9 +85,9 @@ BUILD_SINGLE_TEMPLATE(template void concatCudaLauncher, (const int blocksPerGrid | |||||||
| ////////////////////////////////////////////////////////////////////////// | ////////////////////////////////////////////////////////////////////////// | ||||||
| void concat(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output, const int axis) { | void concat(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output, const int axis) { | ||||||
| 
 | 
 | ||||||
|     const int threadsPerBlock = MAX_NUM_THREADS / 4; |     const int threadsPerBlock = 256; | ||||||
|     const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; |     const int blocksPerGrid = 512; | ||||||
|     const int sharedMem = threadsPerBlock * sizeof(Nd4jLong) * output.rankOf() + 128; |     const int sharedMem = 512; | ||||||
| 
 | 
 | ||||||
|     const int numOfArrs = inArrs.size(); |     const int numOfArrs = inArrs.size(); | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -62,19 +62,31 @@ public: | |||||||
| 
 | 
 | ||||||
| /*
 | /*
 | ||||||
| TEST_F(PlaygroundTests, test_s_1) { | TEST_F(PlaygroundTests, test_s_1) { | ||||||
|     auto x = NDArrayFactory::create<float>('c', {32,112,112,16}); |     auto x0 = NDArrayFactory::create<float>('c', {32, 7, 7, 176}); | ||||||
|     auto y = NDArrayFactory::create<float>('c', {16}); |     auto x1 = x0.ulike(); | ||||||
|     auto z = x.ulike(); |     auto x2 = x0.ulike(); | ||||||
|  |     auto x3 = x0.ulike(); | ||||||
|  |     auto x4 = x0.ulike(); | ||||||
|  |     auto x5 = x0.ulike(); | ||||||
|  | 
 | ||||||
|  |     auto y = NDArrayFactory::create<int >(3); | ||||||
|  |     auto z = NDArrayFactory::create<float>('c', {32, 7, 7, 1056}); | ||||||
| 
 | 
 | ||||||
|     Context ctx(1); |     Context ctx(1); | ||||||
|     ctx.setInputArray(0, &x); |     ctx.setInputArray(0, &x0); | ||||||
|     ctx.setInputArray(1, &y); |     ctx.setInputArray(1, &x1); | ||||||
|  |     ctx.setInputArray(2, &x2); | ||||||
|  |     ctx.setInputArray(3, &x3); | ||||||
|  |     ctx.setInputArray(4, &x4); | ||||||
|  |     ctx.setInputArray(5, &x5); | ||||||
|  | 
 | ||||||
|  |     ctx.setInputArray(6, &y); | ||||||
|     ctx.setOutputArray(0, &z); |     ctx.setOutputArray(0, &z); | ||||||
|  |     ctx.setBArguments({true}); | ||||||
| 
 | 
 | ||||||
|     std::vector<Nd4jLong> values; |     std::vector<Nd4jLong> values; | ||||||
| 
 | 
 | ||||||
| 
 |     nd4j::ops::concat op; | ||||||
|     nd4j::ops::biasadd op; |  | ||||||
|     op.execute(&ctx); |     op.execute(&ctx); | ||||||
| 
 | 
 | ||||||
|     for (int e = 0; e < 1000; e++) { |     for (int e = 0; e < 1000; e++) { | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user