// cavis/libnd4j/include/ops/declarable/helpers/cuda/top_k.cu
//
// 282 lines
// 13 KiB
// Plaintext

/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include <ops/declarable/helpers/top_k.h>
#include <PointersManager.h>
#include <ConstantTadHelper.h>
namespace nd4j {
namespace ops {
namespace helpers {
//////////////////////////////////////////////////////////////////////////
/**
 * in_top_k kernel. Block `blockIdx.x` works on the TAD of `vx` that starts at
 * `xTadOffsets[blockIdx.x]`:
 *   1. thread 0 reads the target index from `vy` (logical position blockIdx.x) and
 *      fetches the TAD element stored at that index;
 *   2. all threads count the TAD elements that are strictly greater than that element;
 *   3. the per-thread counts are summed and thread 0 writes `count < k` into `vz`
 *      (logical position blockIdx.x).
 *
 * Launch requirements visible from the code:
 *   - there is no loop over blocks, so the grid needs one block per TAD;
 *   - dynamic shared memory must hold at least blockDim.x `uint` counters;
 *   - the halving reduction only folds every counter when blockDim.x is a power of two.
 *
 * `xShapeInfo` is accepted but not read by this kernel.
 */
template<typename X, typename Y>
__global__ static void inTopKCuda(const void* vx, const Nd4jLong* xShapeInfo,
                                  const void* vy, const Nd4jLong* yShapeInfo,
                                  void* vz, const Nd4jLong* zShapeInfo,
                                  const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets,
                                  const uint k) {

    // dynamic shared memory, viewed as one counter per thread
    extern __shared__ unsigned char shmem[];
    uint* counters = reinterpret_cast<uint*>(shmem);

    // block-wide state, filled in by thread 0
    __shared__ X target;
    __shared__ const X* row;
    __shared__ Nd4jLong rowLen;

    const uint tid = threadIdx.x;

    if (tid == 0) {
        const auto targets = reinterpret_cast<const Y*>(vy);

        rowLen = shape::length(xTadShapeInfo);
        row    = reinterpret_cast<const X*>(vx) + xTadOffsets[blockIdx.x];

        // shape::length(yShapeInfo) == numTads
        const Nd4jLong targetIdx = targets[shape::getIndexOffset(blockIdx.x, yShapeInfo, shape::length(yShapeInfo))];
        target = row[shape::getIndexOffset(targetIdx, xTadShapeInfo, rowLen)];
    }
    __syncthreads();

    // each thread counts, over its slice of the TAD, the elements that beat the target
    uint greater = 0;
    for (Nd4jLong pos = tid; pos < rowLen; pos += blockDim.x) {
        if (target < row[shape::getIndexOffset(pos, xTadShapeInfo, rowLen)])
            ++greater;
    }
    counters[tid] = greater;
    __syncthreads();

    // tree reduction: counters[0] ends up holding the block-wide total
    uint active = blockDim.x / 2;
    while (active > 0) {
        if (tid < active)
            counters[tid] += counters[tid + active];
        __syncthreads();
        active /= 2;
    }

    // the target is "in top k" when fewer than k elements are strictly greater
    if (tid == 0) {
        auto z = reinterpret_cast<bool*>(vz);
        z[shape::getIndexOffset(blockIdx.x, zShapeInfo, shape::length(zShapeInfo))] = counters[0] < k;
    }
}
///////////////////////////////////////////////////////////////////
/**
 * Host-side launcher for inTopKCuda<X, Y>.
 *
 * Forwards every buffer / shape pointer untouched and enqueues the kernel on `*stream`
 * with the given grid size, block size and dynamic shared memory size (bytes).
 * The launch is asynchronous; no error check or synchronization happens here.
 */
template<typename X, typename Y>
static void inTopKCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream,
                               const void *vx, const Nd4jLong *xShapeInfo,
                               const void *vy, const Nd4jLong *yShapeInfo,
                               void *vz, const Nd4jLong *zShapeInfo,
                               const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets,
                               const uint k) {

    const dim3 grid(blocksPerGrid);
    const dim3 block(threadsPerBlock);

    inTopKCuda<X, Y><<<grid, block, sharedMem, *stream>>>(
            vx, xShapeInfo,
            vy, yShapeInfo,
            vz, zShapeInfo,
            xTadShapeInfo, xTadOffsets,
            k);
}
///////////////////////////////////////////////////////////////////
/**
 * Host entry point of in_top_k.
 *
 * Splits `predictions` into TADs along dimension 1, launches one block of
 * MAX_NUM_THREADS threads per TAD, and lets the kernel write one boolean per TAD
 * into `output`. The X/Y template types are selected at runtime from the data types
 * of `predictions` (FLOAT_TYPES) and `targets` (INDEXING_TYPES).
 *
 * @return Status::OK() unconditionally.
 */
int inTopKFunctor(nd4j::LaunchContext * context, const NDArray* predictions, const NDArray* targets, NDArray* output, const uint k) {

    PointersManager manager(context, "in_top_k");

    const auto tads = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(predictions->getShapeInfo(), {1});

    // launch geometry: one block per TAD, one uint counter per thread (+128 bytes of slack)
    const int numThreads = MAX_NUM_THREADS;
    const int numBlocks  = static_cast<int>(tads.numberOfTads());
    const int shmemBytes = sizeof(uint) * numThreads + 128;

    const auto predictionsType = predictions->dataType();
    const auto targetsType     = targets->dataType();

    NDArray::prepareSpecialUse({output}, {predictions, targets});

    BUILD_DOUBLE_SELECTOR(predictionsType, targetsType, inTopKCudaLauncher,
                          (numBlocks, numThreads, shmemBytes, context->getCudaStream(),
                           predictions->getSpecialBuffer(), predictions->getSpecialShapeInfo(),
                           targets->getSpecialBuffer(), targets->getSpecialShapeInfo(),
                           output->getSpecialBuffer(), output->getSpecialShapeInfo(),
                           tads.specialShapeInfo(), tads.specialOffsets(),
                           k),
                          FLOAT_TYPES, INDEXING_TYPES);

    NDArray::registerSpecialUse({output}, {predictions, targets});

    manager.synchronize();

    return Status::OK();
}
/**
 * Gather kernel: for every TAD `t` and every position `e < k` it copies
 *     z_t[e] = x_t[ i_t[e] ]
 * i.e. picks the input values addressed by the already computed indices.
 *
 * Blocks stride over TADs (step gridDim.x) and threads stride over the k
 * positions (step blockDim.x), so any launch geometry covers all work.
 * No shared memory is used.
 */
template <typename X, typename Y>
static _CUDA_G void topValuesMover(void *vx, Nd4jLong *xTadShapeInfo, Nd4jLong *xTadOffsets, void *vi, Nd4jLong *iTadShapeInfo, Nd4jLong *iTadOffsets, void *vz, Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets, Nd4jLong tadLength, int numTads, int k) {

    const auto inputBase   = reinterpret_cast<X*>(vx);
    const auto indicesBase = reinterpret_cast<Y*>(vi);
    const auto outputBase  = reinterpret_cast<X*>(vz);

    for (int tad = blockIdx.x; tad < numTads; tad += gridDim.x) {
        auto src     = inputBase + xTadOffsets[tad];
        auto picks   = indicesBase + iTadOffsets[tad];
        auto dst     = outputBase + zTadOffsets[tad];

        for (int pos = threadIdx.x; pos < k; pos += blockDim.x) {
            // logical index inside the input TAD of the pos-th selected element
            auto srcIdx = picks[shape::getIndexOffset(pos, iTadShapeInfo, k)];

            const auto dstOffset = shape::getIndexOffset(pos, zTadShapeInfo, k);
            const auto srcOffset = shape::getIndexOffset(srcIdx, xTadShapeInfo, tadLength);

            dst[dstOffset] = src[srcOffset];
        }
    }
}
/**
 * Top-k selection kernel (values + indices) along TADs.
 *
 * For every TAD (blocks stride over TADs with step gridDim.x) the block performs
 * k / scanWidth passes. In each pass:
 *   1. every thread scans its block-strided share of the TAD and remembers the largest
 *      value (and its index) that is strictly smaller than the maximum emitted by the
 *      previous pass (no upper bound on the first pass);
 *   2. a shared-memory tree reduction merges the per-thread candidates into slot 0 of
 *      thread 0, preferring the lower index when values are equal;
 *   3. thread 0 stores the winner at output position `p` of `vz` / `vi` and publishes it
 *      as `localMaximum` for the next pass.
 * The passes therefore produce values in descending order. When `needSort` is false the
 * k results are afterwards re-ordered by ascending index using an odd-even transposition
 * sort.
 *
 * Launch requirements visible from the code:
 *   - dynamic shared memory of at least blockDim.x * scanWidth * (sizeof(X) + sizeof(Y)) bytes;
 *   - the halving reduction only visits every thread when blockDim.x is a power of two;
 *   - only slot 0 is merged and only output position `p` is written per pass, so the
 *     kernel fills all k outputs only for scanWidth == 1.
 *
 * NOTE(review): candidates must be strictly smaller than the previously emitted value,
 * so elements equal to an already emitted value are skipped - duplicates are never
 * repeated in the output. When no candidate is left, the sentinels
 * (-DataTypeUtils::max<X>(), DataTypeUtils::max<Y>()) are written instead. Confirm
 * whether callers rely on this.
 */
template <typename X, typename Y>
static _CUDA_G void indicesAlongDimension(void *vx, Nd4jLong *xTadShapeInfo, Nd4jLong *xTadOffsets, void *vi, Nd4jLong *iTadShapeInfo, Nd4jLong *iTadOffsets, void *vz, Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets, Nd4jLong tadLength, int numTads, int k, int scanWidth, bool needSort) {
    extern __shared__ char _shmem[];

    // shared memory layout: [blockDim.x * scanWidth values of X][blockDim.x * scanWidth indices of Y];
    // each thread owns `scanWidth` consecutive slots in both areas
    X* tempValues = reinterpret_cast<X*>(_shmem) + threadIdx.x * scanWidth;
    Y* tempIndices = reinterpret_cast<Y*>(reinterpret_cast<X*>(_shmem) + blockDim.x * scanWidth) + threadIdx.x * scanWidth;

    // value emitted by the previous pass; upper (exclusive) bound for the current pass
    __shared__ X localMaximum;
    if (threadIdx.x == 0)
        localMaximum = -DataTypeUtils::max<X>();
    __syncthreads();

    for (int t = blockIdx.x; t < numTads; t += gridDim.x) {
        auto x = reinterpret_cast<X *>(vx) + xTadOffsets[t];
        auto i = reinterpret_cast<Y *>(vi) + iTadOffsets[t];
        auto z = reinterpret_cast<X *>(vz) + zTadOffsets[t];

        // one pass per output position
        for (int p = 0; p < k; p += scanWidth) {

            // resetting this thread's temporary storage
            for (int s = 0; s < scanWidth; s++) {
                tempValues[s] = -DataTypeUtils::max<X>();
                tempIndices[s] = DataTypeUtils::max<Y>();
            }

            // per-thread candidate: threads split the TAD with a blockDim.x stride instead of
            // each rescanning [threadIdx.x, tadLength); the strict '>' keeps the first
            // (lowest-index) occurrence within the thread's share
            for (Nd4jLong e = threadIdx.x; e < tadLength; e += blockDim.x) {
                auto value = x[shape::getIndexOffset(e, xTadShapeInfo, tadLength)];

                // we'll compare this value to current stored ones
                for (int f = 0; f < scanWidth; f++) {
                    if (value > tempValues[f] && (p == 0 || value < localMaximum)) {
                        tempValues[f] = value;
                        tempIndices[f] = static_cast<Y>(e);
                    }
                }
            }
            __syncthreads();

            // merge per-thread candidates into thread 0 / slot 0: the larger value wins and,
            // for equal values, the lower index wins, so the block-wide result is the first
            // occurrence of the maximum regardless of which thread saw it
            for (uint activeThreads = blockDim.x / 2; activeThreads > 0; activeThreads /= 2) {
                if (threadIdx.x < activeThreads) {
                    const auto other = activeThreads * scanWidth;

                    const bool otherIsGreater = tempValues[0] < tempValues[other];
                    const bool tieWithLowerIndex = !(tempValues[other] < tempValues[0]) && tempIndices[other] < tempIndices[0];

                    if (otherIsGreater || tieWithLowerIndex) {
                        tempValues[0] = tempValues[other];
                        tempIndices[0] = tempIndices[other];
                    }
                }
                __syncthreads();
            }

            // thread 0 now holds the maximum of this pass: emit it and make it the
            // exclusive upper bound of the next pass
            if (threadIdx.x == 0) {
                localMaximum = tempValues[0];

                z[shape::getIndexOffset(p, zTadShapeInfo, k)] = tempValues[0];
                i[shape::getIndexOffset(p, iTadShapeInfo, k)] = tempIndices[0];
            }
            __syncthreads();
        }
        __syncthreads();

        if (!needSort) {
            // if we don't need sort, we need to return values based on their indices (ascending):
            // odd-even transposition sort keyed on the indices, k phases in total
            for (int m = 0; m < k; m++) {
                // even phases compare pairs (0,1), (2,3), ...; odd phases compare (1,2), (3,4), ...
                const int pairShift = (m % 2 == 0) ? 1 : 2;

                for (int tid = threadIdx.x; tid < k; tid += blockDim.x) {
                    const int top = 2 * tid + pairShift;
                    if (top < k) {
                        const auto i0 = shape::getIndexOffset(top - 1, iTadShapeInfo, k);
                        const auto i1 = shape::getIndexOffset(top, iTadShapeInfo, k);

                        if (i[i0] > i[i1]) {
                            // swap indices first
                            Y di0 = i[i0];
                            i[i0] = i[i1];
                            i[i1] = di0;

                            // swap values next; the values buffer is addressed through its own
                            // shape info, its layout may differ from the indices buffer
                            const auto z0 = shape::getIndexOffset(top - 1, zTadShapeInfo, k);
                            const auto z1 = shape::getIndexOffset(top, zTadShapeInfo, k);

                            X dz0 = z[z0];
                            z[z0] = z[z1];
                            z[z1] = dz0;
                        }
                    }
                }
                __syncthreads();
            }
        }
    }
}
/**
 * Typed implementation of top_k (X = value type, Y = index type).
 *
 * TADs are taken along the last dimension of `input`; the same dimension number is used
 * for `indices` and `values`.
 *   - k == 1: indices come from an IndexMax reduction along the last dimension, then
 *     topValuesMover gathers the matching values;
 *   - k  > 1: indicesAlongDimension produces both values and indices.
 * Both kernels are enqueued on the context's CUDA stream with a fixed 256 x 256 launch;
 * nothing is synchronized here.
 *
 * @return Status::OK() unconditionally.
 */
template <typename X, typename Y>
static int topKFunctor_(nd4j::LaunchContext * context, const NDArray* input, NDArray* values, NDArray* indices, const uint k, bool needSort) {

    const auto lastDim = input->rankOf() - 1;

    auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {lastDim});
    auto packI = ConstantTadHelper::getInstance()->tadForDimensions(indices->shapeInfo(), {lastDim});
    auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(values->shapeInfo(), {lastDim});

    auto tadLength = shape::length(packX.primaryShapeInfo());

    if (k == 1) {
        // arg-max gives the single top index per TAD ...
        input->applyIndexReduce(indexreduce::IndexMax, indices, {lastDim});

        // ... and the values are then picked at those indices
        topValuesMover<X,Y><<<256, 256, 1024, *context->getCudaStream()>>>(
                input->getSpecialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(),
                indices->specialBuffer(), packI.platformShapeInfo(), packI.platformOffsets(),
                values->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(),
                tadLength, packX.numberOfTads(), k);

        return Status::OK();
    }

    // general case: per thread `scanWidth` value slots + `scanWidth` index slots, plus 512 bytes of slack
    int scanWidth = 1;
    int numThreads = 256;
    int shMemSize = (numThreads * sizeof(X) * scanWidth) + (numThreads * sizeof(Y) * scanWidth) + 512;

    indicesAlongDimension<X,Y><<<256, numThreads, shMemSize, *context->getCudaStream()>>>(
            input->getSpecialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(),
            indices->specialBuffer(), packI.platformShapeInfo(), packI.platformOffsets(),
            values->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(),
            tadLength, packX.numberOfTads(), k, scanWidth, needSort);

    return Status::OK();
}
/**
 * Public entry point of top_k.
 *
 * Syncs `input` to the device, dispatches to topKFunctor_<X, Y> according to the data
 * types of `input` (LIBND4J_TYPES) and `indices` (INDEXING_TYPES), then marks `values`
 * and `indices` as written on the device side.
 *
 * The status returned by the typed implementation is not inspected.
 *
 * @return Status::OK() unconditionally.
 */
int topKFunctor(nd4j::LaunchContext * context, const NDArray* input, NDArray* values, NDArray* indices, const uint k, bool needSort) {

    input->syncToDevice();

    BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), topKFunctor_,
                          (context, input, values, indices, k, needSort),
                          LIBND4J_TYPES, INDEXING_TYPES);

    // both outputs were produced by device kernels
    values->tickWriteDevice();
    indices->tickWriteDevice();

    return Status::OK();
}
}
}
}