diff --git a/libnd4j/include/ops/declarable/helpers/cuda/concat.cu b/libnd4j/include/ops/declarable/helpers/cuda/concat.cu index 8f4a49905..3051de448 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/concat.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/concat.cu @@ -73,6 +73,60 @@ namespace nd4j { concatCuda<<<512, 512, 512, *stream>>>(numOfArrs, pVx, pxShapeInfo, pVz, pzShapeInfo); } BUILD_SINGLE_TEMPLATE(template void concatCudaLauncher, (const int numOfArrs, const cudaStream_t *stream, void* pVx, void* pxShapeInfo, void* pVz, void* pzShapeInfo), LIBND4J_TYPES); + + ////////////////////////////////////////////////////////////////////////// + void concat(nd4j::LaunchContext * context, const std::vector& inArrs, NDArray& output, const int axis) { + + const int numOfArrs = inArrs.size(); + for(int i = 0; i < numOfArrs; ++i) + if(!inArrs[i]->isActualOnDeviceSide()) inArrs[i]->syncToDevice(); + + const int rank = inArrs[0]->rankOf(); + const int rank2 = 2*rank; + std::vector> indices(numOfArrs, std::vector(rank2,0)); + + // take into account indices for first array + indices[0][2 * axis + 1] = inArrs[0]->sizeAt(axis); + + // loop through the rest of input arrays + for(int i = 1; i < numOfArrs; ++i) { + indices[i][2 * axis] = indices[i-1][2 * axis + 1]; // index start from + indices[i][2 * axis + 1] = indices[i-1][2 * axis + 1] + inArrs[i]->sizeAt(axis); // index end with (excluding) + } + + std::vector outSubArrs(numOfArrs); + for(int i = 0; i < numOfArrs; ++i) + outSubArrs[i] = new NDArray(output(indices[i], true)); + + // prepare arrays of pointers on buffers and shapes + std::vector hOutBuffers(numOfArrs), hInBuffers(numOfArrs); + std::vector hOutShapeInfo(numOfArrs), hInShapeInfo(numOfArrs); + for(int i = 0; i < numOfArrs; ++i) { + hOutBuffers[i] = outSubArrs[i]->getSpecialBuffer(); + hInBuffers[i] = inArrs[i]->getSpecialBuffer(); + hOutShapeInfo[i] = outSubArrs[i]->getSpecialShapeInfo(); + hInShapeInfo[i] = inArrs[i]->getSpecialShapeInfo(); + } + + // allocate and copy all buffers and shapes arrays to global memory + PointersManager manager(context, "helpers::concat"); + void* dOutBuffers = manager.replicatePointer(hOutBuffers.data(), hOutBuffers.size() * sizeof(void*)); + void* dInBuffers = manager.replicatePointer(hInBuffers.data(), hInBuffers.size() * sizeof(void*)); + void* dInShapeInfo = manager.replicatePointer(hInShapeInfo.data(), hInShapeInfo.size() * sizeof(Nd4jLong*)); + void* dOutShapeInfo = manager.replicatePointer(hOutShapeInfo.data(), hOutShapeInfo.size() * sizeof(Nd4jLong*)); + + BUILD_SINGLE_SELECTOR(inArrs[0]->dataType(), concatCudaLauncher, (numOfArrs, context->getCudaStream(), dInBuffers, dInShapeInfo, dOutBuffers, dOutShapeInfo), LIBND4J_TYPES); + + manager.synchronize(); + + for(int i = 0; i < numOfArrs; ++i) + delete outSubArrs[i]; + + for(int i = 0; i < numOfArrs; ++i) + inArrs[i]->tickReadHost(); + + output.tickWriteDevice(); + } } } } \ No newline at end of file