/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author Yurii Shyrma (iuriish@yahoo.com)
//

// NOTE(review): the original include targets are not legible in this copy of the file;
// the three project includes below are a reconstruction - confirm them against the build.
#include <ops/declarable/helpers/top_k.h>
#include <ConstantTadHelper.h>
#include <PointersManager.h>

#include <stdexcept>
#include <string>

namespace nd4j    {
namespace ops     {
namespace helpers {

//////////////////////////////////////////////////////////////////////////
// Throws std::runtime_error if the CUDA runtime reports an error right after a kernel
// launch (e.g. an invalid launch configuration). Faults that happen while the kernel
// executes are asynchronous and only surface at the next synchronizing call.
static void checkKernelLaunch(const char* kernelName) {
    const cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess)
        throw std::runtime_error(std::string(kernelName) + ": CUDA error after kernel launch: " + cudaGetErrorString(err));
}

//////////////////////////////////////////////////////////////////////////
// Launch layout: one block per TAD of the predictions array; blockIdx.x selects the TAD,
// the target element of y and the output element of z.
// Dynamic shared memory: at least sizeof(uint) * blockDim.x bytes.
// Output: z[blockIdx.x] = true when fewer than k elements of the TAD are strictly greater
// than the element addressed by the target index, so ties at the k-th position count as
// "in top k". A target index outside [0, TAD length) yields false instead of reading
// outside the TAD.
template <typename X, typename Y>
__global__ static void inTopKCuda(const void* vx, const Nd4jLong* xShapeInfo,
                                  const void* vy, const Nd4jLong* yShapeInfo,
                                        void* vz, const Nd4jLong* zShapeInfo,
                                  const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets,
                                  const uint k) {

    const auto y = reinterpret_cast<const Y*>(vy);
          auto z = reinterpret_cast<bool*>(vz);

    __shared__ uint* sharedMem;
    __shared__ X elemToCompare;
    __shared__ const X* xTad;
    __shared__ Nd4jLong idx, xTadLen;
    __shared__ bool validTarget;

    if (threadIdx.x == 0) {
        extern __shared__ unsigned char shmem[];
        sharedMem = reinterpret_cast<uint*>(shmem);

        xTadLen = shape::length(xTadShapeInfo);
        xTad    = reinterpret_cast<const X*>(vx) + xTadOffsets[blockIdx.x];
        idx     = y[shape::getIndexOffset(blockIdx.x, yShapeInfo, shape::length(yShapeInfo))]; // shape::length(yShapeInfo) == numTads

        // never dereference an out-of-range target index
        validTarget = idx >= 0 && idx < xTadLen;
        if (validTarget)
            elemToCompare = xTad[shape::getIndexOffset(idx, xTadShapeInfo, xTadLen)];
    }
    __syncthreads();

    // per-thread count of TAD elements strictly greater than the target's prediction
    uint greater = 0;
    if (validTarget)
        for (Nd4jLong i = threadIdx.x; i < xTadLen; i += blockDim.x)
            if (elemToCompare < xTad[shape::getIndexOffset(i, xTadShapeInfo, xTadLen)])
                ++greater;

    sharedMem[threadIdx.x] = greater;
    __syncthreads();

    // block sum; folding the upper (active - half) entries onto the lower half keeps this
    // correct for any blockDim.x, not only powers of two
    for (uint active = blockDim.x; active > 1; ) {
        const uint half = (active + 1) / 2;
        if (threadIdx.x < active - half)
            sharedMem[threadIdx.x] += sharedMem[threadIdx.x + half];
        __syncthreads();
        active = half;
    }

    if (threadIdx.x == 0)
        z[shape::getIndexOffset(blockIdx.x, zShapeInfo, shape::length(zShapeInfo))] = validTarget && sharedMem[0] < k;
}

///////////////////////////////////////////////////////////////////
// Host-side launcher for inTopKCuda on the given stream; throws if the launch fails.
template <typename X, typename Y>
static void inTopKCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream,
                               const void *vx, const Nd4jLong *xShapeInfo,
                               const void *vy, const Nd4jLong *yShapeInfo,
                                     void *vz, const Nd4jLong *zShapeInfo,
                               const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets,
                               const uint k) {

    inTopKCuda<X, Y><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, xTadShapeInfo, xTadOffsets, k);
    checkKernelLaunch("inTopKCuda");
}

///////////////////////////////////////////////////////////////////
// For every TAD along dimension 1 of predictions, writes into output whether the
// prediction at the corresponding target index is within the top k (see inTopKCuda).
int inTopKFunctor(nd4j::LaunchContext * context, const NDArray* predictions, const NDArray* targets, NDArray* output, const uint k) {

    PointersManager manager(context, "in_top_k");

    const auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(predictions->getShapeInfo(), {1});

    const int threadsPerBlock = MAX_NUM_THREADS;
    const int blocksPerGrid   = static_cast<int>(packX.numberOfTads());
    const int sharedMem       = sizeof(uint) * threadsPerBlock + 128;

    // a grid of zero blocks is an invalid launch configuration; with no TADs there is nothing to compute
    if (blocksPerGrid == 0)
        return Status::OK();

    const auto xType = predictions->dataType();
    const auto yType = targets->dataType();

    NDArray::prepareSpecialUse({output}, {predictions, targets});
    BUILD_DOUBLE_SELECTOR(xType, yType, inTopKCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), predictions->getSpecialBuffer(), predictions->getSpecialShapeInfo(), targets->getSpecialBuffer(), targets->getSpecialShapeInfo(), output->getSpecialBuffer(), output->getSpecialShapeInfo(), packX.specialShapeInfo(), packX.specialOffsets(), k), FLOAT_TYPES, INDEXING_TYPES);
    NDArray::registerSpecialUse({output}, {predictions, targets});

    manager.synchronize();

    return Status::OK();
}

//////////////////////////////////////////////////////////////////////////
// Processes TADs with stride gridDim.x (block t handles TADs t, t + gridDim.x, ...).
// For each TAD, copies x[i[e]] into z[e] for e in [0, k), i.e. gathers the values
// addressed by the first k indices. Uses no shared memory.
template <typename X, typename Y>
static _CUDA_G void topValuesMover(void *vx, Nd4jLong *xTadShapeInfo, Nd4jLong *xTadOffsets,
                                   void *vi, Nd4jLong *iTadShapeInfo, Nd4jLong *iTadOffsets,
                                   void *vz, Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets,
                                   Nd4jLong tadLength, int numTads, int k) {

    for (int t = blockIdx.x; t < numTads; t += gridDim.x) {
        auto x = reinterpret_cast<X*>(vx) + xTadOffsets[t];
        auto i = reinterpret_cast<Y*>(vi) + iTadOffsets[t];
        auto z = reinterpret_cast<X*>(vz) + zTadOffsets[t];

        for (int e = threadIdx.x; e < k; e += blockDim.x) {
            auto idx = i[shape::getIndexOffset(e, iTadShapeInfo, k)];

            z[shape::getIndexOffset(e, zTadShapeInfo, k)] = x[shape::getIndexOffset(idx, xTadShapeInfo, tadLength)];
        }
    }
}

//////////////////////////////////////////////////////////////////////////
// Processes TADs with stride gridDim.x; the whole block cooperates on one TAD at a time.
// For every TAD it writes the k largest values into z and their positions within the
// TAD into i, in order of value descending, ties broken by the lower position. Repeated
// values are therefore each reported once (e.g. [5, 5, 3], k = 2 -> [5, 5] at [0, 1]).
// If needSort is false, the k results are then reordered by ascending position.
// If the TAD has fewer than k elements, the missing slots receive -max<X> / max<Y>.
//
// Dynamic shared memory: blockDim.x * scanWidth * (sizeof(X) + sizeof(Y)) bytes.
// Only the first blockDim.x entries of each array are used; scanWidth is kept for
// interface compatibility and must be >= 1.
// NOTE(review): the relative ordering of NaN values is unspecified.
template <typename X, typename Y>
static _CUDA_G void indicesAlongDimension(void *vx, Nd4jLong *xTadShapeInfo, Nd4jLong *xTadOffsets,
                                          void *vi, Nd4jLong *iTadShapeInfo, Nd4jLong *iTadOffsets,
                                          void *vz, Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffsets,
                                          Nd4jLong tadLength, int numTads, int k, int scanWidth, bool needSort) {

    extern __shared__ char _shmem[];
    // one candidate (value, position) slot per thread, used by the block-wide arg-max
    X* candValues  = reinterpret_cast<X*>(_shmem);
    Y* candIndices = reinterpret_cast<Y*>(candValues + blockDim.x * scanWidth);

    // element selected by the previous pass; later passes only accept elements ranked after it
    __shared__ X        prevValue;
    __shared__ Nd4jLong prevIndex;

    for (int t = blockIdx.x; t < numTads; t += gridDim.x) {
        auto x = reinterpret_cast<X*>(vx) + xTadOffsets[t];
        auto i = reinterpret_cast<Y*>(vi) + iTadOffsets[t];
        auto z = reinterpret_cast<X*>(vz) + zTadOffsets[t];

        // pass p selects the p-th element in (value desc, position asc) order
        for (int p = 0; p < k; p++) {
            X        bestValue = -DataTypeUtils::max<X>();
            Nd4jLong bestIndex = -1;    // -1 == this thread has no candidate

            for (Nd4jLong e = threadIdx.x; e < tadLength; e += blockDim.x) {
                const X value = x[shape::getIndexOffset(e, xTadShapeInfo, tadLength)];

                // skip elements already emitted (ranked at or before the previous selection)
                if (p > 0 && !(value < prevValue || (value == prevValue && e > prevIndex)))
                    continue;

                if (bestIndex < 0 || value > bestValue || (value == bestValue && e < bestIndex)) {
                    bestValue = value;
                    bestIndex = e;
                }
            }

            candValues[threadIdx.x]  = bestValue;
            candIndices[threadIdx.x] = static_cast<Y>(bestIndex);
            __syncthreads();

            // block-wide arg-max with the same ordering; valid for any blockDim.x
            for (uint active = blockDim.x; active > 1; ) {
                const uint half = (active + 1) / 2;
                if (threadIdx.x < active - half) {
                    const uint     other      = threadIdx.x + half;
                    const Nd4jLong otherIndex = candIndices[other];
                    const Nd4jLong ownIndex   = candIndices[threadIdx.x];
                    const bool otherWins = otherIndex >= 0 &&
                                           (ownIndex < 0 ||
                                            candValues[other] > candValues[threadIdx.x] ||
                                            (candValues[other] == candValues[threadIdx.x] && otherIndex < ownIndex));
                    if (otherWins) {
                        candValues[threadIdx.x]  = candValues[other];
                        candIndices[threadIdx.x] = candIndices[other];
                    }
                }
                __syncthreads();
                active = half;
            }

            if (threadIdx.x == 0) {
                const auto zOffset = shape::getIndexOffset(p, zTadShapeInfo, k);
                const auto iOffset = shape::getIndexOffset(p, iTadShapeInfo, k);
                if (candIndices[0] >= 0) {
                    prevValue  = candValues[0];
                    prevIndex  = candIndices[0];
                    z[zOffset] = candValues[0];
                    i[iOffset] = candIndices[0];
                } else {
                    // TAD exhausted (k > tadLength): emit "empty" markers
                    z[zOffset] = -DataTypeUtils::max<X>();
                    i[iOffset] = DataTypeUtils::max<Y>();
                }
            }
            __syncthreads();
        }

        if (!needSort) {
            // odd-even transposition sort of the k results by ascending position; k phases suffice.
            // Even phases compare pairs (0,1), (2,3), ...; odd phases compare (1,2), (3,4), ...
            for (int m = 0; m < k; m++) {
                const int firstTop = (m % 2 == 0) ? 1 : 2;
                for (int tid = threadIdx.x; tid < k; tid += blockDim.x) {
                    const int top = 2 * tid + firstTop;
                    if (top < k) {
                        const auto i0 = shape::getIndexOffset(top - 1, iTadShapeInfo, k);
                        const auto i1 = shape::getIndexOffset(top, iTadShapeInfo, k);
                        if (i[i0] > i[i1]) {
                            // values may use a different layout than indices, so address them via their own TAD
                            const auto z0 = shape::getIndexOffset(top - 1, zTadShapeInfo, k);
                            const auto z1 = shape::getIndexOffset(top, zTadShapeInfo, k);

                            Y di0 = i[i0];
                            i[i0] = i[i1];
                            i[i1] = di0;

                            X dz0 = z[z0];
                            z[z0] = z[z1];
                            z[z1] = dz0;
                        }
                    }
                }
                __syncthreads();
            }
        }
    }
}

//////////////////////////////////////////////////////////////////////////
// Computes top-k along the last dimension of input: values receives the k largest values
// and indices their positions. k == 1 goes through IndexMax plus a gather kernel; other
// k use indicesAlongDimension. Kernels are enqueued on the context's stream without
// synchronizing; launch failures throw std::runtime_error.
template <typename X, typename Y>
static int topKFunctor_(nd4j::LaunchContext * context, const NDArray* input, NDArray* values, NDArray* indices, const uint k, bool needSort) {

    auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {input->rankOf() - 1});
    auto packI = ConstantTadHelper::getInstance()->tadForDimensions(indices->shapeInfo(), {input->rankOf() - 1});
    auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(values->shapeInfo(), {input->rankOf() - 1});

    auto tadLength = shape::length(packX.primaryShapeInfo());

    if (k == 1) {
        input->applyIndexReduce(indexreduce::IndexMax, indices, {input->rankOf() - 1});

        // gather the values at the indices just computed; the kernel needs no shared memory
        topValuesMover<X, Y><<<256, 256, 0, *context->getCudaStream()>>>(input->getSpecialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), indices->specialBuffer(), packI.platformShapeInfo(), packI.platformOffsets(), values->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), tadLength, packX.numberOfTads(), k);
        checkKernelLaunch("topValuesMover");
    } else {
        const int scanWidth  = 1;
        const int numThreads = 256;
        const int shMemSize  = (numThreads * sizeof(X) * scanWidth) + (numThreads * sizeof(Y) * scanWidth) + 512;

        indicesAlongDimension<X, Y><<<256, numThreads, shMemSize, *context->getCudaStream()>>>(input->getSpecialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), indices->specialBuffer(), packI.platformShapeInfo(), packI.platformOffsets(), values->specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), tadLength, packX.numberOfTads(), k, scanWidth, needSort);
        checkKernelLaunch("indicesAlongDimension");
    }

    return Status::OK();
}

//////////////////////////////////////////////////////////////////////////
// Type-dispatching entry point for top-k: syncs input to the device, runs topKFunctor_
// for the (input, indices) dtype pair and marks values/indices as written on device.
int topKFunctor(nd4j::LaunchContext * context, const NDArray* input, NDArray* values, NDArray* indices, const uint k, bool needSort) {
    input->syncToDevice();

    BUILD_DOUBLE_SELECTOR(input->dataType(), indices->dataType(), topKFunctor_, (context, input, values, indices, k, needSort), LIBND4J_TYPES, INDEXING_TYPES);

    values->tickWriteDevice();
    indices->tickWriteDevice();

    return Status::OK();
}

}
}
}