/* ******************************************************************************
 *
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * See the NOTICE file distributed with this work for additional
 * information regarding copyright ownership.
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

#ifndef CUDA_LAMBDA_HELPER
#define CUDA_LAMBDA_HELPER

// NOTE(review): the five project/CUDA include targets below were missing from
// the file as found (bare `#include` directives); they are reconstructed and
// must be confirmed against the build before merging.
#include <system/pointercast.h>
#include <system/op_boilerplate.h>
#include <helpers/shape.h>
#include <cuda.h>
#include <cuda_runtime.h>

#include <stdexcept>  // std::runtime_error thrown by the launchers

/**
 * Non-inlined device wrapper that forwards to shape::getIndexOffset.
 *
 * @param index      linear element index
 * @param shapeInfo  shape descriptor passed through unchanged
 * @return whatever shape::getIndexOffset returns for (index, shapeInfo)
 */
static Nd4jLong __device__ __noinline__ getIndexOffset(Nd4jLong index, const Nd4jLong *shapeInfo) {
    return shape::getIndexOffset(index, shapeInfo);
}

/**
 * Non-inlined device wrapper that forwards to shape::length.
 *
 * @param shapeInfo  shape descriptor passed through unchanged
 * @return whatever shape::length returns for shapeInfo
 */
static Nd4jLong __device__ __noinline__ length(const Nd4jLong *shapeInfo) {
    return shape::length(shapeInfo);
}

// Forward declarations: the LambdaHelper launchers below reference the kernels
// before their definitions appear further down in this header.
template <typename T, typename Lambda>
static _CUDA_G void lambdaKernel(const void* vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda);

template <typename T, typename Lambda>
static _CUDA_G void lambdaIndexedKernel(const void* vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda);

template <typename T, typename Lambda>
static _CUDA_G void lambdaIndexedPairwiseKernel(const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda);

template <typename T, typename Lambda>
static _CUDA_G void lambdaPairwiseKernel(const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda);

template <typename T, typename Lambda>
static _CUDA_G void lambdaTriplewiseKernel(const void* vw, const Nd4jLong *wShapeInfo, const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda);

/**
 * Host-side launchers for the element-wise lambda kernels, parameterised on
 * the element type T.
 *
 * Every launcher:
 *   - launches the matching kernel with 256 blocks of 512 threads and 1024
 *     bytes of dynamic shared memory on `*stream` (the kernels in this file
 *     declare no shared arrays themselves);
 *   - then blocks on cudaStreamSynchronize(*stream) and throws
 *     std::runtime_error when it returns a non-zero status.
 *
 * All buffer and shape-info pointers are handed to the kernel untouched, so
 * they must be dereferenceable from device code.
 */
template <typename T>
class LambdaHelper {
 public:
    /**
     * Launches lambdaKernel: z[i] = lambda(x[i]).
     *
     * @throws std::runtime_error if the stream synchronisation reports an error
     */
    template <typename Lambda>
    FORCEINLINE static void lambdaLauncher(cudaStream_t *stream, const void* vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) {
        lambdaKernel<T, Lambda><<<256, 512, 1024, *stream>>>(vx, xShapeInfo, vz, zShapeInfo, lambda);
        auto err = cudaStreamSynchronize(*stream);
        if (err != 0)
            throw std::runtime_error("NDArray::applyLambda execution failed");
    }

    /**
     * Launches lambdaIndexedKernel: z[i] = lambda(i, x[i]).
     *
     * @throws std::runtime_error if the stream synchronisation reports an error
     */
    template <typename Lambda>
    FORCEINLINE static void lambdaIndexedLauncher(cudaStream_t *stream, const void* vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) {
        lambdaIndexedKernel<T, Lambda><<<256, 512, 1024, *stream>>>(vx, xShapeInfo, vz, zShapeInfo, lambda);
        auto err = cudaStreamSynchronize(*stream);
        if (err != 0)
            throw std::runtime_error("NDArray::applyIndexedLambda execution failed");
    }

    /**
     * Launches lambdaPairwiseKernel: z[i] = lambda(x[i], y[i]).
     *
     * @throws std::runtime_error if the stream synchronisation reports an error
     */
    template <typename Lambda>
    FORCEINLINE static void lambdaPairwiseLauncher(cudaStream_t *stream, const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) {
        lambdaPairwiseKernel<T, Lambda><<<256, 512, 1024, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, lambda);
        auto err = cudaStreamSynchronize(*stream);
        if (err != 0)
            throw std::runtime_error("NDArray::applyPairwiseLambda execution failed");
    }

    /**
     * Launches lambdaIndexedPairwiseKernel: z[i] = lambda(i, x[i], y[i]).
     *
     * @throws std::runtime_error if the stream synchronisation reports an error
     */
    template <typename Lambda>
    FORCEINLINE static void lambdaIndexedPairwiseLauncher(cudaStream_t *stream, const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) {
        lambdaIndexedPairwiseKernel<T, Lambda><<<256, 512, 1024, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, lambda);
        auto err = cudaStreamSynchronize(*stream);
        if (err != 0)
            throw std::runtime_error("NDArray::applyIndexedPairwiseLambda execution failed");
    }

    /**
     * Launches lambdaTriplewiseKernel: z[i] = lambda(w[i], x[i], y[i]).
     *
     * @throws std::runtime_error if the stream synchronisation reports an error
     */
    template <typename Lambda>
    FORCEINLINE static void lambdaTriplewiseLauncher(cudaStream_t *stream, const void* vw, const Nd4jLong *wShapeInfo, const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) {
        lambdaTriplewiseKernel<T, Lambda><<<256, 512, 1024, *stream>>>(vw, wShapeInfo, vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, lambda);
        auto err = cudaStreamSynchronize(*stream);
        if (err != 0)
            throw std::runtime_error("NDArray::applyTriplewiseLambda execution failed");
    }
};

////////////////////////////////////////////////////////////////////////
/**
 * Unary transform kernel: z[i] = lambda(x[i]) for i in [0, length(zShapeInfo)).
 *
 * Uses a grid-stride loop, so any 1D launch configuration covers the output.
 * When both element-wise strides are >= 1 and the orders match, elements are
 * addressed as i * ews; otherwise offsets come from getIndexOffset.
 *
 * The loop index is a 64-bit Nd4jLong: a 32-bit index would wrap before
 * reaching lengths above 2^32 and the loop would never terminate.
 */
template <typename T, typename Lambda>
static _CUDA_G void lambdaKernel(const void* vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) {
    auto x = reinterpret_cast<const T*>(vx);
    auto z = reinterpret_cast<T*>(vz);

    auto xEws = shape::elementWiseStride(xShapeInfo);
    auto zEws = shape::elementWiseStride(zShapeInfo);

    auto xOrder = shape::order(xShapeInfo);
    auto zOrder = shape::order(zShapeInfo);

    auto zLength = length(zShapeInfo);

    // Widen before multiplying so neither the start index nor the stride can
    // overflow 32 bits.
    const Nd4jLong tid = static_cast<Nd4jLong>(blockIdx.x) * blockDim.x + threadIdx.x;
    const Nd4jLong step = static_cast<Nd4jLong>(blockDim.x) * gridDim.x;

    if (xEws >= 1 && zEws >= 1 && xOrder == zOrder) {
        for (Nd4jLong e = tid; e < zLength; e += step)
            z[e * zEws] = lambda(x[e * xEws]);
    } else {
        for (Nd4jLong e = tid; e < zLength; e += step) {
            auto xOffset = getIndexOffset(e, xShapeInfo);
            auto zOffset = getIndexOffset(e, zShapeInfo);

            z[zOffset] = lambda(x[xOffset]);
        }
    }
}

////////////////////////////////////////////////////////////////////////
/**
 * Indexed unary transform kernel: z[i] = lambda(i, x[i]).
 *
 * Same addressing rules as lambdaKernel. The linear index handed to the
 * lambda is an Nd4jLong (64-bit) so it stays valid for lengths above 2^32.
 */
template <typename T, typename Lambda>
static _CUDA_G void lambdaIndexedKernel(const void* vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) {
    auto x = reinterpret_cast<const T*>(vx);
    auto z = reinterpret_cast<T*>(vz);

    auto xEws = shape::elementWiseStride(xShapeInfo);
    auto zEws = shape::elementWiseStride(zShapeInfo);

    auto xOrder = shape::order(xShapeInfo);
    auto zOrder = shape::order(zShapeInfo);

    auto zLength = length(zShapeInfo);

    const Nd4jLong tid = static_cast<Nd4jLong>(blockIdx.x) * blockDim.x + threadIdx.x;
    const Nd4jLong step = static_cast<Nd4jLong>(blockDim.x) * gridDim.x;

    if (xEws >= 1 && zEws >= 1 && xOrder == zOrder) {
        for (Nd4jLong e = tid; e < zLength; e += step)
            z[e * zEws] = lambda(e, x[e * xEws]);
    } else {
        for (Nd4jLong e = tid; e < zLength; e += step) {
            auto xOffset = getIndexOffset(e, xShapeInfo);
            auto zOffset = getIndexOffset(e, zShapeInfo);

            z[zOffset] = lambda(e, x[xOffset]);
        }
    }
}

////////////////////////////////////////////////////////////////////////
/**
 * Indexed pairwise transform kernel: z[i] = lambda(i, x[i], y[i]).
 *
 * The strided fast path requires all three element-wise strides to be >= 1
 * and all three orders to be equal; otherwise each operand's offset is
 * obtained from getIndexOffset.
 */
template <typename T, typename Lambda>
static _CUDA_G void lambdaIndexedPairwiseKernel(const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) {
    auto x = reinterpret_cast<const T*>(vx);
    auto y = reinterpret_cast<const T*>(vy);
    auto z = reinterpret_cast<T*>(vz);

    auto xEws = shape::elementWiseStride(xShapeInfo);
    auto yEws = shape::elementWiseStride(yShapeInfo);
    auto zEws = shape::elementWiseStride(zShapeInfo);

    auto xOrder = shape::order(xShapeInfo);
    auto yOrder = shape::order(yShapeInfo);
    auto zOrder = shape::order(zShapeInfo);

    auto zLength = length(zShapeInfo);

    const Nd4jLong tid = static_cast<Nd4jLong>(blockIdx.x) * blockDim.x + threadIdx.x;
    const Nd4jLong step = static_cast<Nd4jLong>(blockDim.x) * gridDim.x;

    if (xEws >= 1 && yEws >= 1 && zEws >= 1 && xOrder == zOrder && yOrder == xOrder) {
        for (Nd4jLong e = tid; e < zLength; e += step)
            z[e * zEws] = lambda(e, x[e * xEws], y[e * yEws]);
    } else {
        for (Nd4jLong e = tid; e < zLength; e += step) {
            auto xOffset = getIndexOffset(e, xShapeInfo);
            auto yOffset = getIndexOffset(e, yShapeInfo);
            auto zOffset = getIndexOffset(e, zShapeInfo);

            z[zOffset] = lambda(e, x[xOffset], y[yOffset]);
        }
    }
}

////////////////////////////////////////////////////////////////////////
/**
 * Pairwise transform kernel: z[i] = lambda(x[i], y[i]).
 *
 * The strided fast path requires all three element-wise strides to be >= 1
 * and all three orders to be equal; otherwise each operand's offset is
 * obtained from getIndexOffset.
 */
template <typename T, typename Lambda>
static _CUDA_G void lambdaPairwiseKernel(const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) {
    auto x = reinterpret_cast<const T*>(vx);
    auto y = reinterpret_cast<const T*>(vy);
    auto z = reinterpret_cast<T*>(vz);

    auto xEws = shape::elementWiseStride(xShapeInfo);
    auto yEws = shape::elementWiseStride(yShapeInfo);
    auto zEws = shape::elementWiseStride(zShapeInfo);

    auto xOrder = shape::order(xShapeInfo);
    auto yOrder = shape::order(yShapeInfo);
    auto zOrder = shape::order(zShapeInfo);

    auto zLength = length(zShapeInfo);

    const Nd4jLong tid = static_cast<Nd4jLong>(blockIdx.x) * blockDim.x + threadIdx.x;
    const Nd4jLong step = static_cast<Nd4jLong>(blockDim.x) * gridDim.x;

    if (xEws >= 1 && yEws >= 1 && zEws >= 1 && xOrder == zOrder && yOrder == xOrder) {
        for (Nd4jLong e = tid; e < zLength; e += step)
            z[e * zEws] = lambda(x[e * xEws], y[e * yEws]);
    } else {
        for (Nd4jLong e = tid; e < zLength; e += step) {
            auto xOffset = getIndexOffset(e, xShapeInfo);
            auto yOffset = getIndexOffset(e, yShapeInfo);
            auto zOffset = getIndexOffset(e, zShapeInfo);

            z[zOffset] = lambda(x[xOffset], y[yOffset]);
        }
    }
}

////////////////////////////////////////////////////////////////////////
/**
 * Triplewise transform kernel: z[i] = lambda(w[i], x[i], y[i]).
 *
 * The strided fast path requires all four element-wise strides to be >= 1
 * and all four orders to be equal; otherwise each operand's offset is
 * obtained from getIndexOffset.
 */
template <typename T, typename Lambda>
static _CUDA_G void lambdaTriplewiseKernel(const void* vw, const Nd4jLong *wShapeInfo, const void* vx, const Nd4jLong *xShapeInfo, const void* vy, const Nd4jLong *yShapeInfo, void *vz, const Nd4jLong *zShapeInfo, Lambda lambda) {
    auto w = reinterpret_cast<const T*>(vw);
    auto x = reinterpret_cast<const T*>(vx);
    auto y = reinterpret_cast<const T*>(vy);
    auto z = reinterpret_cast<T*>(vz);

    auto wEws = shape::elementWiseStride(wShapeInfo);
    auto xEws = shape::elementWiseStride(xShapeInfo);
    auto yEws = shape::elementWiseStride(yShapeInfo);
    auto zEws = shape::elementWiseStride(zShapeInfo);

    auto wOrder = shape::order(wShapeInfo);
    auto xOrder = shape::order(xShapeInfo);
    auto yOrder = shape::order(yShapeInfo);
    auto zOrder = shape::order(zShapeInfo);

    auto zLength = length(zShapeInfo);

    const Nd4jLong tid = static_cast<Nd4jLong>(blockIdx.x) * blockDim.x + threadIdx.x;
    const Nd4jLong step = static_cast<Nd4jLong>(blockDim.x) * gridDim.x;

    // wEws is tested with >= 1 like every other operand: a strict `> 1` test
    // would send a w with stride 1 down the per-element offset path for no
    // reason.
    if (wEws >= 1 && xEws >= 1 && yEws >= 1 && zEws >= 1 && xOrder == zOrder && yOrder == xOrder && wOrder == xOrder) {
        for (Nd4jLong e = tid; e < zLength; e += step)
            z[e * zEws] = lambda(w[e * wEws], x[e * xEws], y[e * yEws]);
    } else {
        for (Nd4jLong e = tid; e < zLength; e += step) {
            auto wOffset = getIndexOffset(e, wShapeInfo);
            auto xOffset = getIndexOffset(e, xShapeInfo);
            auto yOffset = getIndexOffset(e, yShapeInfo);
            auto zOffset = getIndexOffset(e, zShapeInfo);

            z[zOffset] = lambda(w[wOffset], x[xOffset], y[yOffset]);
        }
    }
}

#endif
//////////////////////////////////////////////////////////////////////////
/**
 * Applies `func` to every element of this array and stores the results in
 * `target` by dispatching to LambdaHelper<T>::lambdaLauncher for the runtime
 * data type.
 *
 * The launch is bracketed by prepareSpecialUse / registerSpecialUse with
 * `target` in the first list and `this` in the second.
 *
 * @param func    callable invoked on the device with one element value
 * @param target  output array; must have the same data type as this array
 * @throws std::runtime_error if the data types of this array and target differ
 */
template <typename Lambda>
void NDArray::applyLambda(Lambda func, NDArray& target) {
    auto dtype = this->dataType();

    if (dtype != target.dataType())
        throw std::runtime_error("NDArray::applyLambda X/Z data types must be the same");
        //throw datatype_exception::build("NDArray::applyLambda X/Z data types must be the same", dtype, target.dataType());

    prepareSpecialUse({&target}, {this});
    BUILD_SINGLE_SELECTOR(dtype, LambdaHelper, ::lambdaLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES);
    registerSpecialUse({&target}, {this});
}

//////////////////////////////////////////////////////////////////////////
/**
 * Applies `func` to matching elements of this array and `other`, storing the
 * results in `target`, via LambdaHelper<T>::lambdaPairwiseLauncher.
 *
 * @param other   second input array
 * @param func    callable invoked on the device with (thisValue, otherValue)
 * @param target  output array
 * @throws std::runtime_error if this array, other and target do not all share
 *         one data type
 */
template <typename Lambda>
void NDArray::applyPairwiseLambda(const NDArray& other, Lambda func, NDArray& target) {
    auto dtype = this->dataType();

    if (dtype != target.dataType() || dtype != other.dataType())
        throw std::runtime_error("NDArray::applyPairwiseLambda X/Y/Z data types must be the same");
        //throw datatype_exception::build("NDArray::applyLambda X/Z data types must be the same", dtype, target.dataType());

    prepareSpecialUse({&target}, {this, &other});
    BUILD_SINGLE_SELECTOR(dtype, LambdaHelper, ::lambdaPairwiseLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), other.specialBuffer(), other.specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES);
    registerSpecialUse({&target}, {this, &other});
}

//////////////////////////////////////////////////////////////////////////
/**
 * Applies `func` to every element of this array together with its linear
 * index and stores the results in `target`, via
 * LambdaHelper<T>::lambdaIndexedLauncher.
 *
 * @param func    callable invoked on the device with (index, value)
 * @param target  output array; must have the same data type as this array
 * @throws std::runtime_error if the data types of this array and target differ
 */
template <typename Lambda>
void NDArray::applyIndexedLambda(Lambda func, NDArray& target) {
    auto dtype = this->dataType();

    if (dtype != target.dataType())
        throw std::runtime_error("NDArray::applyIndexedLambda X/Z data types must be the same");

    prepareSpecialUse({&target}, {this});
    BUILD_SINGLE_SELECTOR(dtype, LambdaHelper, ::lambdaIndexedLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES);
    registerSpecialUse({&target}, {this});
}

//////////////////////////////////////////////////////////////////////////
/**
 * Applies `func` to matching elements of this array and `other` together with
 * their linear index, storing the results in `target`, via
 * LambdaHelper<T>::lambdaIndexedPairwiseLauncher.
 *
 * @param other   second input array
 * @param func    callable invoked on the device with (index, thisValue, otherValue)
 * @param target  output array
 * @throws std::runtime_error if this array, other and target do not all share
 *         one data type
 */
template <typename Lambda>
void NDArray::applyIndexedPairwiseLambda(NDArray& other, Lambda func, NDArray& target) {
    auto dtype = this->dataType();

    if (dtype != target.dataType() || dtype != other.dataType())
        throw std::runtime_error("NDArray::applyIndexedPairwiseLambda X/Y/Z data types must be the same");

    prepareSpecialUse({&target}, {this, &other});
    BUILD_SINGLE_SELECTOR(dtype, LambdaHelper, ::lambdaIndexedPairwiseLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), other.specialBuffer(), other.specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES);
    registerSpecialUse({&target}, {this, &other});
}

//////////////////////////////////////////////////////////////////////////
/**
 * Applies `func` to matching elements of this array, `second` and `third`,
 * storing the results in `target`, via
 * LambdaHelper<T>::lambdaTriplewiseLauncher.
 *
 * @param second  second input array
 * @param third   third input array
 * @param func    callable invoked on the device with
 *                (thisValue, secondValue, thirdValue)
 * @param target  output array
 * @throws std::runtime_error if this array, second, third and target do not
 *         all share one data type
 */
template <typename Lambda>
void NDArray::applyTriplewiseLambda(NDArray& second, NDArray& third, Lambda func, NDArray& target) {
    auto dtype = this->dataType();

    if (dtype != target.dataType() || dtype != second.dataType() || dtype != third.dataType())
        throw std::runtime_error("NDArray::applyTriplewiseLambda X/Y/Z data types must be the same");

    prepareSpecialUse({&target}, {this, &second, &third});
    BUILD_SINGLE_SELECTOR(dtype, LambdaHelper, ::lambdaTriplewiseLauncher(this->_context->getCudaStream(), this->specialBuffer(), this->specialShapeInfo(), second.specialBuffer(), second.specialShapeInfo(), third.specialBuffer(), third.specialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), func), LIBND4J_TYPES);
    registerSpecialUse({&target}, {this, &second, &third});
}