// File: cavis/libnd4j/include/ops/declarable/helpers/cuda/lrn.cu
// (repository viewer metadata: 184 lines, 7.7 KiB, shown as Plaintext)

/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <ops/declarable/helpers/lrn.h>
#include <Status.h>
#include <ConstantTadHelper.h>
#include <stdexcept>
#include <string>
namespace nd4j {
namespace ops {
namespace helpers {
/**
 * Forward LRN kernel. For every TAD handed in by the host, it writes
 *
 *     z[j] = x[j] / (bias + alpha * sum_{s = max(0, j - depth)}^{min(j + depth, tadLength - 1)} x[s]^2)^beta
 *
 * Launch layout expected (as set up by lrnFunctor_):
 *   - one thread per TAD element (blockDim.x == tadLength, at most 1024);
 *   - blocks iterate over TADs with i += gridDim.x.
 * Threads with threadIdx.x >= tadLength are tolerated; they only take part in the barriers.
 *
 * Dynamic shared memory: at least blockDim.x * sizeof(T) bytes.
 *
 * NOTE(review): elements are addressed through the TADs' element-wise stride. This assumes
 * shape::elementWiseStride() is non-zero for the TADs passed in; confirm for non-contiguous views.
 */
template <typename T>
static _CUDA_G void lrnKernel(void *vx, Nd4jLong *xTadShapeInfo, Nd4jLong *xTadOffsets, void *vz, Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets, Nd4jLong numTads, Nd4jLong tadLength, int depth, double bias, double alpha, double beta) {
    extern __shared__ char sharedChar[];
    // Every thread derives the same pointer, so no shared-memory broadcast or barrier is needed.
    auto shared = reinterpret_cast<T*>(sharedChar);

    const auto xEws = shape::elementWiseStride(xTadShapeInfo);
    const auto zEws = shape::elementWiseStride(zTadShapeInfo);

    const T tbias = static_cast<T>(bias);
    const T tbeta = static_cast<T>(beta);
    const T talpha = static_cast<T>(alpha);

    const int tid = static_cast<int>(threadIdx.x);
    // Work is predicated on this flag instead of early-returning, so every thread reaches the barriers.
    const bool active = tid < tadLength;

    // The window [begin, end) depends only on the thread index, so it is computed once.
    // Signed arithmetic avoids the unsigned wrap-around of (threadIdx.x - depth).
    const int begin = nd4j::math::nd4j_max<int>(0, tid - depth);
    const int end = static_cast<int>(nd4j::math::nd4j_min<Nd4jLong>(static_cast<Nd4jLong>(tid) + depth + 1, tadLength));

    for (Nd4jLong i = blockIdx.x; i < numTads; i += gridDim.x) {
        auto x = reinterpret_cast<T*>(vx) + xTadOffsets[i];
        auto z = reinterpret_cast<T*>(vz) + zTadOffsets[i];

        // Stage this TAD in shared memory: neighbouring elements are read by several threads.
        if (active)
            shared[tid] = x[tid * xEws];
        __syncthreads();

        if (active) {
            T prev = 0.;
            for (int s = begin; s < end; s++)
                prev = prev + shared[s] * shared[s];

            z[tid * zEws] = shared[tid] / nd4j::math::nd4j_pow<T, T, T>(tbias + talpha * prev, tbeta);
        }

        // All reads of this TAD's shared data must complete before the next iteration overwrites it.
        __syncthreads();
    }
}
/**
 * Backward LRN kernel (before the element-wise multiplication by gradO done on the host).
 *
 * For element j of each TAD, with S_j = windowed sum of squares over
 * [max(0, j - depth), min(j + depth, tadLength - 1)], it writes
 *
 *     z[j] = (bias + alpha*S_j)^(-beta-1) * (bias + alpha*S_j)
 *            - 2 * x[j] * (alpha*beta) * sum_{s in window(j)} x[s] * (bias + alpha*S_s)^(-beta-1)
 *
 * Launch layout expected (as set up by lrnBP_):
 *   - one thread per TAD element (blockDim.x == tadLength, at most 1024);
 *   - blocks iterate over TADs with i += gridDim.x.
 *
 * Dynamic shared memory layout:
 *   [blockDim.x * sizeof(X), rounded up to 16 bytes][blockDim.x * sizeof(Z)]
 * The padding keeps the Z section correctly aligned even when sizeof(X) < sizeof(Z),
 * e.g. half input with double gradients and an odd TAD length.
 *
 * NOTE(review): the host multiplies the result by gradO[j] element-wise afterwards. The exact
 * chain rule would weight each term of the inner sum by gradO[s] instead. This only matters
 * when gradO is non-uniform; confirm against the CPU helper before changing.
 * NOTE(review): addressing assumes non-zero element-wise strides for the TADs; confirm.
 */
template <typename X, typename Z>
static _CUDA_G void lrnBPKernel(void *vx, Nd4jLong *xTadShapeInfo, Nd4jLong *xTadOffsets, void *vz, Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets, Nd4jLong numTads, Nd4jLong tadLength, int depth, double bias, double alpha, double beta) {
    extern __shared__ char sharedChar[];
    auto sharedX = reinterpret_cast<X*>(sharedChar);
    // Pad the X section to 16 bytes so the Z section that follows is aligned for any float type.
    const size_t zSectionOffset = ((static_cast<size_t>(blockDim.x) * sizeof(X) + 15) / 16) * 16;
    auto sharedY = reinterpret_cast<Z*>(sharedChar + zSectionOffset);

    const auto xEws = shape::elementWiseStride(xTadShapeInfo);
    const auto zEws = shape::elementWiseStride(zTadShapeInfo);

    const Z tbias = static_cast<Z>(bias);
    const Z tbeta = static_cast<Z>(beta);
    const Z talpha = static_cast<Z>(alpha);
    const Z coeff = talpha * tbeta;
    const Z powExp = -tbeta - 1;

    const int tid = static_cast<int>(threadIdx.x);
    // Work is predicated on this flag instead of early-returning, so every thread reaches the barriers.
    const bool active = tid < tadLength;

    // Signed arithmetic avoids the unsigned wrap-around of (threadIdx.x - depth).
    const int begin = nd4j::math::nd4j_max<int>(0, tid - depth);
    const int end = static_cast<int>(nd4j::math::nd4j_min<Nd4jLong>(static_cast<Nd4jLong>(tid) + depth + 1, tadLength));

    for (Nd4jLong i = blockIdx.x; i < numTads; i += gridDim.x) {
        auto x = reinterpret_cast<X*>(vx) + xTadOffsets[i];
        auto z = reinterpret_cast<Z*>(vz) + zTadOffsets[i];

        // Stage this TAD in shared memory.
        if (active)
            sharedX[tid] = x[tid * xEws];
        __syncthreads();

        // Pass 1: windowed sum of squares for this element; the sums are shared across the block.
        if (active) {
            Z sum = 0.f;
            for (int s = begin; s < end; s++)
                sum = sum + sharedX[s] * sharedX[s];
            sharedY[tid] = sum;
        }
        __syncthreads();

        // Pass 2: derivative for this element, using the neighbours' sums from pass 1.
        if (active) {
            const Z init = tbias + talpha * sharedY[tid];
            // Own factor (bias + alpha*S_tid)^(-beta-1). A per-thread factor[] array is not needed:
            // only this element's value is read outside the loop.
            const Z ownFactor = nd4j::math::nd4j_pow<Z, Z, Z>(init, powExp);

            Z prev = 0.f;
            for (int s = begin; s < end; ++s) {
                const Z factor = nd4j::math::nd4j_pow<Z, Z, Z>(tbias + talpha * sharedY[s], powExp);
                prev = prev + sharedX[s] * factor;
            }

            z[tid * zEws] = ownFactor * init - 2 * sharedX[tid] * coeff * prev;
        }

        // All reads of this TAD's shared data must complete before the next iteration overwrites it.
        __syncthreads();
    }
}
/**
 * Host launcher for the LRN backward pass for input type X and gradient type Z.
 *
 * Splits input and gradI into TADs along the last dimension (rank - 1) and launches
 * lrnBPKernel<X, Z>:
 *   - one block per TAD, capped at 1024 blocks (remaining TADs are handled by striding);
 *   - one thread per TAD element.
 * Afterwards gradI is multiplied element-wise by gradO.
 *
 * Empty inputs are a no-op.
 *
 * @throws std::runtime_error if the last dimension is longer than 1024 (one thread per element
 *         would not fit in a block) or if the kernel launch fails.
 */
template <typename X, typename Z>
static void lrnBP_(nd4j::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int depth, const float bias, const float alpha, const float beta) {
    auto rank = input.rankOf();
    auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input.getShapeInfo(), {rank - 1});
    auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(gradI.getShapeInfo(), {rank - 1});
    const auto tadLength = shape::length(packX.primaryShapeInfo());
    const auto numTads = packX.numberOfTads();

    // Empty input: nothing to differentiate, and a zero-sized grid/block would be an invalid launch.
    if (tadLength < 1 || numTads < 1)
        return;
    if (tadLength > 1024)
        throw std::runtime_error("LRN: tadLength > 1024 isn't implemented yet");

    const int numBlocks = static_cast<int>(nd4j::math::nd4j_min<Nd4jLong>(1024, numTads));
    const int numThreads = static_cast<int>(tadLength);
    // Must match the kernel's layout: X section padded to 16 bytes, then the Z section.
    const size_t xBytes = ((static_cast<size_t>(numThreads) * sizeof(X) + 15) / 16) * 16;
    const size_t sharedBytes = xBytes + static_cast<size_t>(numThreads) * sizeof(Z);

    lrnBPKernel<X, Z><<<numBlocks, numThreads, sharedBytes, *block.launchContext()->getCudaStream()>>>(input.getSpecialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), gradI.specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), numTads, tadLength, depth, bias, alpha, beta);

    // Kernel launches do not return errors directly; surface configuration errors here.
    const auto launchErr = cudaGetLastError();
    if (launchErr != cudaSuccess)
        throw std::runtime_error(std::string("LRN: lrnBPKernel launch failed: ") + cudaGetErrorString(launchErr));

    gradI.tickWriteDevice();
    gradI *= gradO;
}
/**
 * CUDA entry point for the LRN backward pass.
 *
 * Steps:
 *   1. Make sure the device copies of gradO and input are up to date.
 *   2. Dispatch lrnBP_ on the (input dtype, gradO dtype) pair; both must be floating-point types.
 *   3. Mark gradI as modified on the device.
 */
void lrnBP(nd4j::graph::Context& block, const NDArray& input, const NDArray& gradO, NDArray& gradI, const int depth, const float bias, const float alpha, const float beta) {
    const auto inputType = input.dataType();
    const auto gradType = gradO.dataType();

    // The kernel reads both buffers on the device.
    gradO.syncToDevice();
    input.syncToDevice();

    BUILD_DOUBLE_SELECTOR(inputType, gradType, lrnBP_, (block, input, gradO, gradI, depth, bias, alpha, beta), FLOAT_TYPES, FLOAT_TYPES);

    gradI.tickWriteDevice();
}
/**
 * Host launcher for the LRN forward pass for floating-point type T.
 *
 * Splits input and output into TADs along the last dimension (rank - 1) and launches lrnKernel<T>:
 *   - one block per TAD, capped at 1024 blocks (remaining TADs are handled by striding);
 *   - one thread per TAD element;
 *   - tadLength * sizeof(T) bytes of dynamic shared memory.
 *
 * Empty inputs are a no-op.
 *
 * @throws std::runtime_error if the last dimension is longer than 1024 (one thread per element
 *         would not fit in a block) or if the kernel launch fails.
 */
template <typename T>
static void lrnFunctor_(nd4j::graph::Context& block, NDArray* input, NDArray* output, int depth, double bias, double alpha, double beta) {
    auto rank = input->rankOf();
    auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), {rank - 1});
    auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), {rank - 1});
    const auto tadLength = shape::length(packX.primaryShapeInfo());
    const auto numTads = packX.numberOfTads();

    // Empty input: nothing to normalize, and a zero-sized grid/block would be an invalid launch.
    if (tadLength < 1 || numTads < 1)
        return;
    if (tadLength > 1024)
        throw std::runtime_error("LRN: tadLength > 1024 isn't implemented yet");

    const int numBlocks = static_cast<int>(nd4j::math::nd4j_min<Nd4jLong>(1024, numTads));
    const int numThreads = static_cast<int>(tadLength);
    const size_t sharedBytes = static_cast<size_t>(numThreads) * sizeof(T);

    lrnKernel<T><<<numBlocks, numThreads, sharedBytes, *block.launchContext()->getCudaStream()>>>(input->specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), output->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), numTads, tadLength, depth, bias, alpha, beta);

    // Kernel launches do not return errors directly; surface configuration errors here.
    const auto launchErr = cudaGetLastError();
    if (launchErr != cudaSuccess)
        throw std::runtime_error(std::string("LRN: lrnKernel launch failed: ") + cudaGetErrorString(launchErr));
}
/**
 * CUDA entry point for the LRN forward pass.
 *
 * Syncs the input to the device, dispatches lrnFunctor_ on the input's floating-point
 * dtype, and marks the output as modified on the device.
 *
 * @return Status::OK(). Errors raised by lrnFunctor_ propagate as exceptions.
 */
int lrnFunctor(nd4j::graph::Context& block, NDArray* input, NDArray* output, int depth, double bias, double alpha, double beta) {
    const auto inputType = input->dataType();

    // The kernel reads the input buffer on the device.
    input->syncToDevice();

    BUILD_SINGLE_SELECTOR(inputType, lrnFunctor_, (block, input, output, depth, bias, alpha, beta), FLOAT_TYPES);

    output->tickWriteDevice();
    return Status::OK();
}
}
}
}