/* ****************************************************************************** * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * See the NOTICE file distributed with this work for additional * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ // // @author sgazeos@gmail.com // #include <ops/declarable/helpers/nth_element.h> #include <helpers/TAD.h> #include <helpers/ShapeUtils.h> #include <helpers/ConstantTadHelper.h> #include <execution/Threads.h> namespace sd { namespace ops { namespace helpers { template <typename T> void nthElementFunctor_(NDArray* input, Nd4jLong n, NDArray* output, bool reverse) { NDArray sortedVals(*input); if (input->isVector()) { //std::vector<float> data(input->lengthOf()); //memcpy(&data[0], input->buffer(), sizeof(T) * data.size()); //size_t l = 0; //for (size_t l = 0; l < data.size(); ++l) // data[l] = input->e<float>(l); //auto nthPos = data.begin(); //nthPos += n; //std::nth_element(data.begin(), nthPos, data.end()); SpecialMethods<T>::sortGeneric(sortedVals.buffer(), sortedVals.shapeInfo(), reverse); output->p(0, sortedVals.e<T>(n)); } else { // rank greater than 1 std::vector<int> lastDims({input->rankOf() - 1});// = ShapeUtils::evalDimsToExclude(input->rankOf(), {input->rankOf() - 1}); auto pack = sd::ConstantTadHelper::getInstance().tadForDimensions(sortedVals.shapeInfo(), lastDims); SpecialMethods<T>::sortTadGeneric(sortedVals.buffer(), sortedVals.shapeInfo(), lastDims.data(), lastDims.size(), pack.primaryShapeInfo(), pack.primaryOffsets(), reverse); ResultSet rows = sortedVals.allTensorsAlongDimension(lastDims); Nd4jLong oL = output->lengthOf(); auto func = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e++) { auto row = rows.at(e); output->p(e, row->e<T>(n)); } }; samediff::Threads::parallel_for(func, 0, oL); } } void nthElementFunctor(sd::LaunchContext *launchContext, NDArray* input, Nd4jLong n, NDArray* output, bool reverse) { BUILD_SINGLE_SELECTOR(input->dataType(), nthElementFunctor_, (input, n, output, reverse), LIBND4J_TYPES); } BUILD_SINGLE_TEMPLATE(template void nthElementFunctor_, (NDArray* input, Nd4jLong n, NDArray* output, bool reverse), LIBND4J_TYPES); } } }