/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

// NOTE(review): the three include targets were lost in the copy this rewrite
// was made from; the names below are inferred from the symbols used in this
// file (lrn helpers, Status, ConstantTadHelper) - confirm against the tree.
#include <ops/declarable/helpers/lrn.h>
#include <Status.h>
#include <ConstantTadHelper.h>

namespace nd4j {
namespace ops {
namespace helpers {

    /**
     * Forward local response normalisation along one TAD:
     *
     *   z[t] = x[t] / pow(bias + alpha * sum_{s in window(t)} x[s]^2, beta)
     *
     * where window(t) = [max(0, t - depth), min(t + depth + 1, tadLength)).
     *
     * Launch layout: 1D grid, 1D block. Each block walks the TADs
     * blockIdx.x, blockIdx.x + gridDim.x, ...; thread t owns element t of the
     * current TAD, so the kernel must be launched with blockDim.x == tadLength.
     * Dynamic shared memory: at least blockDim.x * sizeof(T) bytes.
     *
     * NOTE(review): explicit template arguments (<T>, nd4j_max<int>,
     * nd4j_min<int>, nd4j_pow<T, T, T>) were reconstructed - confirm.
     *
     * @param vx            input buffer (device), interpreted as T*
     * @param xTadShapeInfo shape info of one input TAD (used for its element-wise stride)
     * @param xTadOffsets   per-TAD element offsets into vx
     * @param vz            output buffer (device), interpreted as T*
     * @param zTadShapeInfo shape info of one output TAD (used for its element-wise stride)
     * @param zTadOffsets   per-TAD element offsets into vz
     * @param numTads       number of TADs to process
     * @param tadLength     number of elements in one TAD
     * @param depth         half-width of the normalisation window
     * @param bias          additive constant inside the power
     * @param alpha         scale applied to the sum of squares
     * @param beta          exponent
     */
    template <typename T>
    static _CUDA_G void lrnKernel(void *vx, Nd4jLong *xTadShapeInfo, Nd4jLong *xTadOffsets, void *vz, Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets, Nd4jLong numTads, Nd4jLong tadLength, int depth, double bias, double alpha, double beta) {
        extern __shared__ char sharedChar[];
        T* shared = reinterpret_cast<T*>(sharedChar);

        const auto xEws = shape::elementWiseStride(xTadShapeInfo);
        const auto zEws = shape::elementWiseStride(zTadShapeInfo);

        const T tbias = static_cast<T>(bias);
        const T tbeta = static_cast<T>(beta);
        const T talpha = static_cast<T>(alpha);

        // The window only depends on the thread index, so it is computed once.
        // Everything is kept in signed int: the original mixed unsigned
        // threadIdx.x with signed depth and relied on wrap-around.
        const int tid = static_cast<int>(threadIdx.x);
        const int len = static_cast<int>(tadLength);
        // clamping depth to the TAD length keeps tid + radius + 1 from overflowing
        const int radius = nd4j::math::nd4j_min<int>(depth, len);
        const int begin = nd4j::math::nd4j_max<int>(0, tid - radius);
        const int end = nd4j::math::nd4j_min<int>(tid + radius + 1, len);

        // one block of threads processes 1 example within batch
        for (Nd4jLong i = blockIdx.x; i < numTads; i += gridDim.x) {
            auto x = reinterpret_cast<T*>(vx) + xTadOffsets[i];
            auto z = reinterpret_cast<T*>(vz) + zTadOffsets[i];

            // load everything into shared memory, so we'll operate on shared memory from now on
            shared[tid] = x[tid * xEws];
            __syncthreads();

            T sumSq = 0.;
            for (int s = begin; s < end; s++)
                sumSq = sumSq + shared[s] * shared[s];

            // talpha (not the double alpha) keeps the arithmetic in T, as the
            // backprop kernel already does
            z[tid * zEws] = shared[tid] / nd4j::math::nd4j_pow<T, T, T>(tbias + talpha * sumSq, tbeta);

            // Neighbours may still be reading shared[]; without this barrier the
            // next iteration's load would overwrite values they have not consumed.
            // The loop bounds are uniform across the block, so every thread reaches it.
            __syncthreads();
        }
    }

    /**
     * Backward helper for local response normalisation along one TAD. With
     *
     *   init(t)   = bias + alpha * sum_{s in window(t)} x[s]^2
     *   factor(t) = pow(init(t), -beta - 1)
     *
     * it writes
     *
     *   z[t] = factor(t) * init(t) - 2 * x[t] * alpha * beta * sum_{s in window(t)} x[s] * factor(s)
     *
     * window(t) is the same as in lrnKernel. The result is NOT yet multiplied
     * by the incoming gradient; lrnBP_ does that after the launch.
     *
     * Launch layout: 1D grid, 1D block, blockDim.x == tadLength; blocks stride
     * over TADs by gridDim.x. Dynamic shared memory: blockDim.x elements of X
     * followed by blockDim.x elements of Z, plus up to alignof(Z) - 1 bytes of
     * padding between the two sections.
     *
     * NOTE(review): explicit template arguments (<X, Z>, nd4j_pow<Z, Z, Z>, ...)
     * were reconstructed - confirm.
     *
     * Parameters have the same meaning as in lrnKernel; vx is read as X*,
     * vz is written as Z*.
     */
    template <typename X, typename Z>
    static _CUDA_G void lrnBPKernel(void *vx, Nd4jLong *xTadShapeInfo, Nd4jLong *xTadOffsets, void *vz, Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets, Nd4jLong numTads, Nd4jLong tadLength, int depth, double bias, double alpha, double beta) {
        extern __shared__ char sharedChar[];
        X* sharedX = reinterpret_cast<X*>(sharedChar);

        // Round the start of the Z section up to alignof(Z): placing it directly
        // after blockDim.x elements of X is misaligned when sizeof(X) < alignof(Z)
        // and blockDim.x is odd.
        const size_t zOffset = ((blockDim.x * sizeof(X) + alignof(Z) - 1) / alignof(Z)) * alignof(Z);
        Z* sharedFactor = reinterpret_cast<Z*>(sharedChar + zOffset);

        const auto xEws = shape::elementWiseStride(xTadShapeInfo);
        const auto zEws = shape::elementWiseStride(zTadShapeInfo);

        const Z tbias = static_cast<Z>(bias);
        const Z tbeta = static_cast<Z>(beta);
        const Z talpha = static_cast<Z>(alpha);
        const Z coeff = talpha * tbeta;
        const Z exponent = -tbeta - 1;

        // signed window bounds, loop-invariant (see lrnKernel)
        const int tid = static_cast<int>(threadIdx.x);
        const int len = static_cast<int>(tadLength);
        const int radius = nd4j::math::nd4j_min<int>(depth, len);
        const int begin = nd4j::math::nd4j_max<int>(0, tid - radius);
        const int end = nd4j::math::nd4j_min<int>(tid + radius + 1, len);

        for (Nd4jLong i = blockIdx.x; i < numTads; i += gridDim.x) {
            auto x = reinterpret_cast<X*>(vx) + xTadOffsets[i];
            auto z = reinterpret_cast<Z*>(vz) + zTadOffsets[i];

            // load everything into shared memory
            sharedX[tid] = x[tid * xEws];
            __syncthreads();

            // windowed sum of squares for this thread's own element
            Z sumSq = 0.f;
            for (int s = begin; s < end; s++)
                sumSq = sumSq + sharedX[s] * sharedX[s];

            // Each thread publishes only its own factor. The original kept a
            // 1024-element per-thread array and recomputed pow() for every
            // neighbour; the values are identical, this just computes each once.
            const Z init = tbias + talpha * sumSq;
            const Z selfFactor = nd4j::math::nd4j_pow<Z, Z, Z>(init, exponent);
            sharedFactor[tid] = selfFactor;
            __syncthreads();

            Z prev = 0.f;
            for (int s = begin; s < end; s++)
                prev = prev + sharedX[s] * sharedFactor[s];

            z[tid * zEws] = selfFactor * init - 2 * sharedX[tid] * coeff * prev;

            // keep the next iteration from overwriting sharedX / sharedFactor
            // while neighbours are still reading them
            __syncthreads();
        }
    }

    /**
     * Typed host launcher for lrnBPKernel. Splits input and gradI into TADs
     * along the last dimension, launches one thread per TAD element, then
     * multiplies gradI element-wise by gradO.
     *
     * NOTE(review): the launch configuration and the nd4j_min<Nd4jLong>
     * template argument were reconstructed - confirm.
     *
     * @throws std::runtime_error if the TAD length is outside [1, 1024]
     */
    template <typename X, typename Z>
    static void lrnBP_(nd4j::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int depth, const float bias, const float alpha, const float beta) {
        auto rank = input.rankOf();
        auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), {rank - 1});
        auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(gradI.getShapeInfo(), {rank - 1});

        const auto tadLength = shape::length(packX.primaryShapeInfo());

        // the kernel maps one thread to one TAD element, so the TAD has to fit in one block
        if (tadLength < 1)
            throw std::runtime_error("LRN: tadLength must be at least 1");
        if (tadLength > 1024)
            throw std::runtime_error("LRN: tadLength > 1024 isn't implemented yet");

        const int numBlocks = nd4j::math::nd4j_min<Nd4jLong>(1024, packX.numberOfTads());
        const int numThreads = tadLength;
        // X section + Z section; the extra 1024 bytes also cover the alignment
        // padding the kernel may insert between the two sections
        const size_t sharedBytes = numThreads * sizeof(X) + numThreads * sizeof(Z) + 1024;

        lrnBPKernel<X, Z><<<numBlocks, numThreads, sharedBytes, *block.launchContext()->getCudaStream()>>>(input.getSpecialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), gradI.specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), packX.numberOfTads(), tadLength, depth, bias, alpha, beta);

        // mark the device copy as the current one before it is used in the product
        gradI.tickWriteDevice();
        gradI *= gradO;
    }

    /**
     * LRN backprop entry point: syncs the inputs to the device and dispatches
     * lrnBP_ on the data types of input and gradO.
     */
    void lrnBP(nd4j::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int depth, const float bias, const float alpha, const float beta) {
        input.syncToDevice();
        gradO.syncToDevice();

        BUILD_DOUBLE_SELECTOR(input.dataType(), gradO.dataType(), lrnBP_, (block, input, gradO, gradI, depth, bias, alpha, beta), FLOAT_TYPES, FLOAT_TYPES);

        gradI.tickWriteDevice();
    }

    /**
     * Typed host launcher for lrnKernel. Splits input and output into TADs
     * along the last dimension and launches one thread per TAD element.
     *
     * NOTE(review): the launch configuration and the nd4j_min<Nd4jLong>
     * template argument were reconstructed - confirm.
     *
     * @throws std::runtime_error if the TAD length is outside [1, 1024]
     */
    template <typename T>
    static void lrnFunctor_(nd4j::graph::Context& block, NDArray* input, NDArray* output, int depth, double bias, double alpha, double beta) {
        auto rank = input->rankOf();
        auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), {rank - 1});
        auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {rank - 1});

        const auto tadLength = shape::length(packX.primaryShapeInfo());

        // the kernel maps one thread to one TAD element, so the TAD has to fit in one block
        if (tadLength < 1)
            throw std::runtime_error("LRN: tadLength must be at least 1");
        if (tadLength > 1024)
            throw std::runtime_error("LRN: tadLength > 1024 isn't implemented yet");

        const int numBlocks = nd4j::math::nd4j_min<Nd4jLong>(1024, packX.numberOfTads());
        const int numThreads = tadLength;
        const size_t sharedBytes = numThreads * sizeof(T) + 1024;

        lrnKernel<T><<<numBlocks, numThreads, sharedBytes, *block.launchContext()->getCudaStream()>>>(input->specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), output->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), packX.numberOfTads(), tadLength, depth, bias, alpha, beta);
    }

    /**
     * LRN forward entry point: syncs the input to the device, dispatches
     * lrnFunctor_ on the input data type and marks the output as written on
     * the device.
     *
     * @return Status::OK()
     */
    int lrnFunctor(nd4j::graph::Context& block, NDArray* input, NDArray* output, int depth, double bias, double alpha, double beta) {
        input->syncToDevice();

        BUILD_SINGLE_SELECTOR(input->dataType(), lrnFunctor_, (block, input, output, depth, bias, alpha, beta), FLOAT_TYPES);

        output->tickWriteDevice();

        return Status::OK();
    }

}
}
}