concat cuda helper
Signed-off-by: raver119 <raver119@gmail.com>
This commit is contained in:
		
							parent
							
								
									1a0a8b1497
								
							
						
					
					
						commit
						6ed03217b4
					
				| @ -73,6 +73,60 @@ namespace nd4j { | ||||
|                 concatCuda<T><<<512, 512, 512, *stream>>>(numOfArrs, pVx, pxShapeInfo, pVz, pzShapeInfo); | ||||
|             } | ||||
|             BUILD_SINGLE_TEMPLATE(template void concatCudaLauncher,  (const int numOfArrs, const cudaStream_t *stream, void* pVx, void* pxShapeInfo, void* pVz, void* pzShapeInfo), LIBND4J_TYPES); | ||||
| 
 | ||||
|             ////////////////////////////////////////////////////////////////////////// | ||||
|             void concat(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output, const int axis) { | ||||
| 
 | ||||
|                 const int numOfArrs = inArrs.size(); | ||||
|                 for(int i = 0; i < numOfArrs; ++i) | ||||
|                     if(!inArrs[i]->isActualOnDeviceSide()) inArrs[i]->syncToDevice(); | ||||
| 
 | ||||
|                 const int rank  = inArrs[0]->rankOf(); | ||||
|                 const int rank2 = 2*rank; | ||||
|                 std::vector<std::vector<Nd4jLong>> indices(numOfArrs, std::vector<Nd4jLong>(rank2,0)); | ||||
| 
 | ||||
|                 // take into account indices for first array | ||||
|                 indices[0][2 * axis + 1] = inArrs[0]->sizeAt(axis); | ||||
| 
 | ||||
|                 // loop through the rest of input arrays | ||||
|                 for(int i = 1; i < numOfArrs; ++i) { | ||||
|                     indices[i][2 * axis]     = indices[i-1][2 * axis + 1];                                // index start from | ||||
|                     indices[i][2 * axis + 1] = indices[i-1][2 * axis + 1] + inArrs[i]->sizeAt(axis);      // index end with (excluding) | ||||
|                 } | ||||
| 
 | ||||
|                 std::vector<NDArray*> outSubArrs(numOfArrs); | ||||
|                 for(int i = 0; i < numOfArrs; ++i) | ||||
|                     outSubArrs[i] = new NDArray(output(indices[i], true)); | ||||
| 
 | ||||
|                 // prepare arrays of pointers on buffers and shapes | ||||
|                 std::vector<void*>     hOutBuffers(numOfArrs), hInBuffers(numOfArrs); | ||||
|                 std::vector<Nd4jLong*> hOutShapeInfo(numOfArrs), hInShapeInfo(numOfArrs); | ||||
|                 for(int i = 0; i < numOfArrs; ++i) { | ||||
|                     hOutBuffers[i]   = outSubArrs[i]->getSpecialBuffer(); | ||||
|                     hInBuffers[i]    =     inArrs[i]->getSpecialBuffer(); | ||||
|                     hOutShapeInfo[i] = outSubArrs[i]->getSpecialShapeInfo(); | ||||
|                     hInShapeInfo[i]  =     inArrs[i]->getSpecialShapeInfo(); | ||||
|                 } | ||||
| 
 | ||||
|                 // allocate and copy all buffers and shapes arrays to global memory | ||||
|                 PointersManager manager(context, "helpers::concat"); | ||||
|                 void* dOutBuffers	= manager.replicatePointer(hOutBuffers.data(),   hOutBuffers.size() * sizeof(void*)); | ||||
|                 void* dInBuffers	= manager.replicatePointer(hInBuffers.data(),    hInBuffers.size() * sizeof(void*)); | ||||
|                 void* dInShapeInfo  = manager.replicatePointer(hInShapeInfo.data(),  hInShapeInfo.size() * sizeof(Nd4jLong*)); | ||||
|                 void* dOutShapeInfo = manager.replicatePointer(hOutShapeInfo.data(), hOutShapeInfo.size() * sizeof(Nd4jLong*)); | ||||
| 
 | ||||
|                 BUILD_SINGLE_SELECTOR(inArrs[0]->dataType(), concatCudaLauncher, (numOfArrs, context->getCudaStream(), dInBuffers, dInShapeInfo, dOutBuffers, dOutShapeInfo), LIBND4J_TYPES); | ||||
| 
 | ||||
|                 manager.synchronize(); | ||||
| 
 | ||||
|                 for(int i = 0; i < numOfArrs; ++i) | ||||
|                     delete outSubArrs[i]; | ||||
| 
 | ||||
|                 for(int i = 0; i < numOfArrs; ++i) | ||||
|                     inArrs[i]->tickReadHost(); | ||||
| 
 | ||||
|                 output.tickWriteDevice(); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user