parent
1a0a8b1497
commit
6ed03217b4
|
@ -73,6 +73,60 @@ namespace nd4j {
|
|||
concatCuda<T><<<512, 512, 512, *stream>>>(numOfArrs, pVx, pxShapeInfo, pVz, pzShapeInfo);
|
||||
}
|
||||
BUILD_SINGLE_TEMPLATE(template void concatCudaLauncher, (const int numOfArrs, const cudaStream_t *stream, void* pVx, void* pxShapeInfo, void* pVz, void* pzShapeInfo), LIBND4J_TYPES);
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void concat(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output, const int axis) {
|
||||
|
||||
const int numOfArrs = inArrs.size();
|
||||
for(int i = 0; i < numOfArrs; ++i)
|
||||
if(!inArrs[i]->isActualOnDeviceSide()) inArrs[i]->syncToDevice();
|
||||
|
||||
const int rank = inArrs[0]->rankOf();
|
||||
const int rank2 = 2*rank;
|
||||
std::vector<std::vector<Nd4jLong>> indices(numOfArrs, std::vector<Nd4jLong>(rank2,0));
|
||||
|
||||
// take into account indices for first array
|
||||
indices[0][2 * axis + 1] = inArrs[0]->sizeAt(axis);
|
||||
|
||||
// loop through the rest of input arrays
|
||||
for(int i = 1; i < numOfArrs; ++i) {
|
||||
indices[i][2 * axis] = indices[i-1][2 * axis + 1]; // index start from
|
||||
indices[i][2 * axis + 1] = indices[i-1][2 * axis + 1] + inArrs[i]->sizeAt(axis); // index end with (excluding)
|
||||
}
|
||||
|
||||
std::vector<NDArray*> outSubArrs(numOfArrs);
|
||||
for(int i = 0; i < numOfArrs; ++i)
|
||||
outSubArrs[i] = new NDArray(output(indices[i], true));
|
||||
|
||||
// prepare arrays of pointers on buffers and shapes
|
||||
std::vector<void*> hOutBuffers(numOfArrs), hInBuffers(numOfArrs);
|
||||
std::vector<Nd4jLong*> hOutShapeInfo(numOfArrs), hInShapeInfo(numOfArrs);
|
||||
for(int i = 0; i < numOfArrs; ++i) {
|
||||
hOutBuffers[i] = outSubArrs[i]->getSpecialBuffer();
|
||||
hInBuffers[i] = inArrs[i]->getSpecialBuffer();
|
||||
hOutShapeInfo[i] = outSubArrs[i]->getSpecialShapeInfo();
|
||||
hInShapeInfo[i] = inArrs[i]->getSpecialShapeInfo();
|
||||
}
|
||||
|
||||
// allocate and copy all buffers and shapes arrays to global memory
|
||||
PointersManager manager(context, "helpers::concat");
|
||||
void* dOutBuffers = manager.replicatePointer(hOutBuffers.data(), hOutBuffers.size() * sizeof(void*));
|
||||
void* dInBuffers = manager.replicatePointer(hInBuffers.data(), hInBuffers.size() * sizeof(void*));
|
||||
void* dInShapeInfo = manager.replicatePointer(hInShapeInfo.data(), hInShapeInfo.size() * sizeof(Nd4jLong*));
|
||||
void* dOutShapeInfo = manager.replicatePointer(hOutShapeInfo.data(), hOutShapeInfo.size() * sizeof(Nd4jLong*));
|
||||
|
||||
BUILD_SINGLE_SELECTOR(inArrs[0]->dataType(), concatCudaLauncher, (numOfArrs, context->getCudaStream(), dInBuffers, dInShapeInfo, dOutBuffers, dOutShapeInfo), LIBND4J_TYPES);
|
||||
|
||||
manager.synchronize();
|
||||
|
||||
for(int i = 0; i < numOfArrs; ++i)
|
||||
delete outSubArrs[i];
|
||||
|
||||
for(int i = 0; i < numOfArrs; ++i)
|
||||
inArrs[i]->tickReadHost();
|
||||
|
||||
output.tickWriteDevice();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue