concat cuda helper
Signed-off-by: raver119 <raver119@gmail.com>
This commit is contained in:
		
							parent
							
								
									1a0a8b1497
								
							
						
					
					
						commit
						6ed03217b4
					
				@ -73,6 +73,60 @@ namespace nd4j {
 | 
			
		||||
                concatCuda<T><<<512, 512, 512, *stream>>>(numOfArrs, pVx, pxShapeInfo, pVz, pzShapeInfo);
 | 
			
		||||
            }
 | 
			
		||||
            BUILD_SINGLE_TEMPLATE(template void concatCudaLauncher,  (const int numOfArrs, const cudaStream_t *stream, void* pVx, void* pxShapeInfo, void* pVz, void* pzShapeInfo), LIBND4J_TYPES);
 | 
			
		||||
 | 
			
		||||
            //////////////////////////////////////////////////////////////////////////
 | 
			
		||||
            void concat(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output, const int axis) {
 | 
			
		||||
 | 
			
		||||
                const int numOfArrs = inArrs.size();
 | 
			
		||||
                for(int i = 0; i < numOfArrs; ++i)
 | 
			
		||||
                    if(!inArrs[i]->isActualOnDeviceSide()) inArrs[i]->syncToDevice();
 | 
			
		||||
 | 
			
		||||
                const int rank  = inArrs[0]->rankOf();
 | 
			
		||||
                const int rank2 = 2*rank;
 | 
			
		||||
                std::vector<std::vector<Nd4jLong>> indices(numOfArrs, std::vector<Nd4jLong>(rank2,0));
 | 
			
		||||
 | 
			
		||||
                // take into account indices for first array
 | 
			
		||||
                indices[0][2 * axis + 1] = inArrs[0]->sizeAt(axis);
 | 
			
		||||
 | 
			
		||||
                // loop through the rest of input arrays
 | 
			
		||||
                for(int i = 1; i < numOfArrs; ++i) {
 | 
			
		||||
                    indices[i][2 * axis]     = indices[i-1][2 * axis + 1];                                // index start from
 | 
			
		||||
                    indices[i][2 * axis + 1] = indices[i-1][2 * axis + 1] + inArrs[i]->sizeAt(axis);      // index end with (excluding)
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                std::vector<NDArray*> outSubArrs(numOfArrs);
 | 
			
		||||
                for(int i = 0; i < numOfArrs; ++i)
 | 
			
		||||
                    outSubArrs[i] = new NDArray(output(indices[i], true));
 | 
			
		||||
 | 
			
		||||
                // prepare arrays of pointers on buffers and shapes
 | 
			
		||||
                std::vector<void*>     hOutBuffers(numOfArrs), hInBuffers(numOfArrs);
 | 
			
		||||
                std::vector<Nd4jLong*> hOutShapeInfo(numOfArrs), hInShapeInfo(numOfArrs);
 | 
			
		||||
                for(int i = 0; i < numOfArrs; ++i) {
 | 
			
		||||
                    hOutBuffers[i]   = outSubArrs[i]->getSpecialBuffer();
 | 
			
		||||
                    hInBuffers[i]    =     inArrs[i]->getSpecialBuffer();
 | 
			
		||||
                    hOutShapeInfo[i] = outSubArrs[i]->getSpecialShapeInfo();
 | 
			
		||||
                    hInShapeInfo[i]  =     inArrs[i]->getSpecialShapeInfo();
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                // allocate and copy all buffers and shapes arrays to global memory
 | 
			
		||||
                PointersManager manager(context, "helpers::concat");
 | 
			
		||||
                void* dOutBuffers	= manager.replicatePointer(hOutBuffers.data(),   hOutBuffers.size() * sizeof(void*));
 | 
			
		||||
                void* dInBuffers	= manager.replicatePointer(hInBuffers.data(),    hInBuffers.size() * sizeof(void*));
 | 
			
		||||
                void* dInShapeInfo  = manager.replicatePointer(hInShapeInfo.data(),  hInShapeInfo.size() * sizeof(Nd4jLong*));
 | 
			
		||||
                void* dOutShapeInfo = manager.replicatePointer(hOutShapeInfo.data(), hOutShapeInfo.size() * sizeof(Nd4jLong*));
 | 
			
		||||
 | 
			
		||||
                BUILD_SINGLE_SELECTOR(inArrs[0]->dataType(), concatCudaLauncher, (numOfArrs, context->getCudaStream(), dInBuffers, dInShapeInfo, dOutBuffers, dOutShapeInfo), LIBND4J_TYPES);
 | 
			
		||||
 | 
			
		||||
                manager.synchronize();
 | 
			
		||||
 | 
			
		||||
                for(int i = 0; i < numOfArrs; ++i)
 | 
			
		||||
                    delete outSubArrs[i];
 | 
			
		||||
 | 
			
		||||
                for(int i = 0; i < numOfArrs; ++i)
 | 
			
		||||
                    inArrs[i]->tickReadHost();
 | 
			
		||||
 | 
			
		||||
                output.tickWriteDevice();
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user