cavis/libnd4j/include/ops/declarable/helpers/cuda/nth_element.cu

90 lines
3.9 KiB
Plaintext

/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author sgazeos@gmail.com
//
#include <ops/declarable/helpers/nth_element.h>
#include <TAD.h>
#include <ShapeUtils.h>
#include <PointersManager.h>
#include <NativeOps.h>
#include <helpers/ConstantTadHelper.h>
#include <stdexcept>
#include <string>
namespace nd4j {
namespace ops {
namespace helpers {
/**
 * Gathers one element per TAD of the input into the output buffer.
 *
 * For every output position t in [0, length(outputShapeInfo)), the kernel reads the
 * element at logical index `n` of the t-th TAD (TAD start = inputBuffer + pTadOffsets[t],
 * in-TAD offset resolved through pTadShape) and stores it at logical index t of the output.
 *
 * Launch layout: 1D grid of 1D blocks; a grid-stride loop makes any grid/block size valid.
 * Shared memory: two statically declared Nd4jLong values, no dynamic shared memory is used.
 *
 * @param outputBuffer    device buffer receiving the gathered elements (interpreted as T*)
 * @param outputShapeInfo device shape info of the output; its length bounds the loop
 * @param inputBuffer     device buffer holding the source data (interpreted as T*); only read
 * @param inputShapeInfo  device shape info of the input; not read by this kernel
 * @param pTadShape       device shape info describing a single TAD
 * @param pTadOffsets     device array of per-TAD start offsets, indexed by output position
 * @param n               logical index, inside each TAD, of the element to pick
 */
template <typename T>
static __global__ void fillUpElementKernel(void* outputBuffer, Nd4jLong* outputShapeInfo, void* inputBuffer, Nd4jLong* inputShapeInfo, Nd4jLong* pTadShape, Nd4jLong* pTadOffsets, Nd4jLong n) {
    __shared__ Nd4jLong bufferLength;
    __shared__ Nd4jLong elementOffset;

    auto z = reinterpret_cast<T*>(outputBuffer);
    auto x = reinterpret_cast<const T*>(inputBuffer);

    // Both values are identical for every thread and every iteration, so one thread
    // per block computes them once instead of re-evaluating them inside the loop.
    if (threadIdx.x == 0) {
        bufferLength = shape::length(outputShapeInfo);
        elementOffset = shape::getIndexOffset(n, pTadShape);
    }
    __syncthreads();

    // 64-bit indexing: the loop bound is an Nd4jLong, so a 32-bit counter could
    // truncate or overflow for large outputs.
    const Nd4jLong tid = static_cast<Nd4jLong>(blockIdx.x) * blockDim.x + threadIdx.x;
    const Nd4jLong step = static_cast<Nd4jLong>(gridDim.x) * blockDim.x;

    for (Nd4jLong t = tid; t < bufferLength; t += step) {
        const T* tX = x + pTadOffsets[t];
        z[shape::getIndexOffset(t, outputShapeInfo)] = tX[elementOffset];
    }
}
/**
 * Typed implementation of the nth-element helper: sorts a copy of `input` and writes the
 * element found at position `n` of the sorted data into `output`.
 *
 * - Vector input: the whole copy is sorted and a single element is copied to `output`.
 * - Any other input: every TAD along the last dimension is sorted and the n-th element
 *   of each TAD is gathered into `output` by fillUpElementKernel.
 *
 * All device work is issued on the CUDA stream of `context`.
 *
 * @param context launch context providing the CUDA stream
 * @param input   source array; the sort is performed on a copy-constructed NDArray
 * @param n       index of the element to pick in the sorted sequence (not range-checked here)
 * @param output  destination array
 * @param reverse forwarded to sort/sortTad as the ordering flag
 * @throws std::runtime_error if a CUDA call issued here reports an error
 */
template <typename T>
void nthElementFunctor_(nd4j::LaunchContext * context, NDArray* input, Nd4jLong n, NDArray* output, bool reverse) {
    NDArray::prepareSpecialUse({output}, {input});

    // Working copy that gets sorted in place by the calls below.
    NDArray sortedVals(*input);
    auto stream = context->getCudaStream();

    // Argument pack expected by sort/sortTad: launch context followed by its stream.
    Nd4jPointer params[2];
    params[0] = context;
    params[1] = stream;

    // Basic algorithm: sort, then retrieve the n-th element of the sorted sequence.
    if (input->isVector()) {
        sort(params, nullptr, sortedVals.shapeInfo(), sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), reverse);

        auto dst = reinterpret_cast<T*>(output->specialBuffer());
        auto src = reinterpret_cast<T*>(sortedVals.specialBuffer()) + n;

        // Issue the copy on the same stream as the sort so it is ordered after it
        // without depending on default-stream semantics, and do not ignore failures.
        auto err = cudaMemcpyAsync(dst, src, sizeof(T), cudaMemcpyDeviceToDevice, *stream);
        if (err != cudaSuccess)
            throw std::runtime_error(std::string("nthElementFunctor_: device-to-device copy of the n-th element failed: ") + cudaGetErrorString(err));

        // Make sure the copy has finished before sortedVals (the source) goes out of scope.
        err = cudaStreamSynchronize(*stream);
        if (err != cudaSuccess)
            throw std::runtime_error(std::string("nthElementFunctor_: stream synchronization failed: ") + cudaGetErrorString(err));
    }
    else { // non-vector input: work on TADs along the last dimension
        std::vector<int> lastDims({input->rankOf() - 1});
        auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(sortedVals.getShapeInfo(), lastDims);
        auto pTadShape = packX.specialShapeInfo();
        auto pTadOffsets = packX.specialOffsets();

        sortTad(params, sortedVals.buffer(), sortedVals.shapeInfo(), sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), lastDims.data(), lastDims.size(), pTadShape, pTadOffsets, reverse);
        sortedVals.tickWriteDevice();
        // NOTE(review): the host copy is not read in this function; presumably kept for
        // its synchronizing effect - confirm before removing.
        sortedVals.syncToHost();

        // The kernel declares only static shared memory, so no dynamic bytes are requested.
        fillUpElementKernel<T><<<32, 64, 0, *stream>>>(output->specialBuffer(), output->specialShapeInfo(), sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), pTadShape, pTadOffsets, n);

        // Kernel launches do not return a status; query it explicitly.
        auto err = cudaGetLastError();
        if (err != cudaSuccess)
            throw std::runtime_error(std::string("nthElementFunctor_: CUDA error reported after launching fillUpElementKernel: ") + cudaGetErrorString(err));
    }

    NDArray::registerSpecialUse({output}, {input});
}
/**
 * Public entry point of the nth-element helper.
 *
 * Forwards all arguments unchanged to the typed implementation nthElementFunctor_,
 * selected from the data type of `input`.
 *
 * @param context launch context passed through to the typed implementation
 * @param input   source array; its dataType() drives the type selection
 * @param n       index of the element to pick in sorted order
 * @param output  destination array
 * @param reverse ordering flag passed through to the typed implementation
 */
void nthElementFunctor(nd4j::LaunchContext * context, NDArray* input, Nd4jLong n, NDArray* output, bool reverse) {
// Presumably expands to a switch over LIBND4J_TYPES that calls nthElementFunctor_<T>
// for the matching type - the macro is defined outside this file.
BUILD_SINGLE_SELECTOR(input->dataType(), nthElementFunctor_, (context, input, n, output, reverse), LIBND4J_TYPES);
}
}
}
}