|
|
|
rollKernelFullAnyDimensionStage1<T><<<1, 256, 1024, *(output->getContext()->getCudaStream())>>>(output->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), output->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), numTads, tadLength, dim, sizeAt, theShift);
|
|
|
|
rollKernelFullAnyDimensionStage2<T><<<1, 256, 1024, *(output->getContext()->getCudaStream())>>>(output->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), output->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), numTads, tadLength, dim, sizeAt, theShift);
|
|
|
|
rollKernelLinearStage1<T><<<1, 1, 1024, *(output->getContext()->getCudaStream())>>>(output->specialBuffer(), output->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), fullLen, actualShift);
|
|
|
|
rollKernelLinearStage2<T><<<1, 1, 1024, *(output->getContext()->getCudaStream())>>>(output->specialBuffer(), output->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), fullLen, actualShift, shiftCount);
|
|
|
|
rollKernelLinearStage3<T><<<1, 1, 1024, *(output->getContext()->getCudaStream())>>>(output->specialBuffer(), output->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), fullLen, actualShift, remainShift);
|