/******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ // // @author raver119@gmail.com // #include #include #include #include namespace sd { namespace ops { namespace helpers { template static int topKFunctor_(const NDArray* input, NDArray* values, NDArray* indices, const uint k, bool needSort) { Nd4jLong width = input->sizeAt(-1); int lastDim = input->rankOf() - 1; // ----------------------------------------------------------------------------------------------- // // this assumption is right: // if (values->lengthOf() != k * lastDimList->size()) { // nd4j_printf("top_k: something is wrong. %i expected, but %i given.\n", // values->lengthOf(), k * lastDimList->size()); // } // ----------------------------------------------------------------------------------------------- // std::vector dimsToExclude(input->rankOf() - 1); for (size_t d = 0; d < dimsToExclude.size(); ++d) dimsToExclude[d] = d; const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(input->shapeInfo(), dimsToExclude); if (k == 1) { for (Nd4jLong e = 0; e < numOfSubArrs; ++e) { auto trial = (*input)(e, dimsToExclude); //int maxPos = //lastDimList->at(e)->argMax(); Nd4jLong maxPos = 0; //trial.printIndexedBuffer("TRIAL:"); T maxVal = trial.e(0); for (Nd4jLong pos = 1; pos < trial.lengthOf(); pos++) if (maxVal < trial.e(pos)) { maxPos = pos; maxVal = trial.e(pos); } if (indices) indices->p(e, maxPos); //topIndex; if (values) values->p(e, maxVal); } } else { int nextPos = 0; for (Nd4jLong e = 0; e < numOfSubArrs; ++e) { auto trial = (*input)(e, dimsToExclude); // fill up the first k elements NDArray topValues = NDArrayFactory::create('c', {k}, input->getContext()); NDArray sortedVals = NDArrayFactory::create('c', {k}, input->getContext()); NDArray topIndices = NDArrayFactory::create('c', {k}, input->getContext()); for (uint pos = 0; pos < k; ++pos) { topIndices.r(pos) = pos; topValues.r(pos) = trial.t(pos); } //std::vector sortedVals(topValues); sortedVals.assign(topValues);// = NDArrayFactory::create('c', {k}); //std::sort(sortedVals.begin(), sortedVals.end()); // sorted in ascending order SpecialMethods::sortGeneric(sortedVals.buffer(), sortedVals.shapeInfo(), false); for (Nd4jLong i = static_cast(k); i < width; ++i) { T val = trial.e(i); T minTopVal = sortedVals.t(0); if (minTopVal < val) { // value should be inserted to top k // only if it is not contained in T* begin = reinterpret_cast(sortedVals.buffer()); T* end = begin + k; bool exists = std::binary_search(begin, end, val); if (!exists) { //exchangePos - a distance between begin and minimal existed to be suppressed by val T* topBegin = reinterpret_cast(topValues.buffer()); T* topEnd = topBegin + k; auto exchangePos = std::distance(topBegin, std::find(topBegin, topEnd, sortedVals.t(0))); topValues.r(exchangePos) = val; //*exchangeIt = val; topIndices.r(exchangePos) = i; sortedVals.r(0) = val; // suppress in sorted //std::sort(sortedVals.begin(), sortedVals.end()); // sorted in ascending order SpecialMethods::sortGeneric(sortedVals.buffer(), sortedVals.shapeInfo(), false); } } } if (needSort) { SpecialMethods::sortGeneric(topValues.buffer(), topValues.shapeInfo(), true); for (Nd4jLong j = 0; j < width; j++) for (uint pos = 0; pos < k; ++pos) if (topValues.t(pos) == trial.t(j)) topIndices.r(pos) = j; } else { // else sort by indices std::map sortValsMap; //std::vector> data(topValues.lengthOf()); for (Nd4jLong e = 0; e < topValues.lengthOf(); ++e) { sortValsMap[topIndices.t(e)] = topValues.t(e); } //std::sort(data.begin(), data.end(), [](std::pair const& a, std::pair const& b) { // return a.first < b.first; //}); Nd4jLong e = 0; for (auto it = sortValsMap.begin(); it != sortValsMap.end(); ++it, e++) { topIndices.r(e) = it->first; topValues.r(e) = it->second; } } if (values) (*values)(e, dimsToExclude).assign(topValues); if (indices) (*indices)(e, dimsToExclude).assign(topIndices); } //indices->printIndexedBuffer("Indices as is"); } return Status::OK(); } // ----------------------------------------------------------------------------------------------- // template static int inTopKFunctor_(sd::LaunchContext* context, const NDArray* input, const NDArray* target, NDArray* result, const uint k) { std::vector shapeI(input->rankOf()); for (int i = 0; i < input->rankOf() - 1; i++) shapeI[i] = input->sizeAt(i); shapeI[input->rankOf() - 1] = k; std::unique_ptr indices(NDArrayFactory::create_(input->ordering(), shapeI, context)); NDArray* values = nullptr; int status = topKFunctor(context, input, values, indices.get(), k, true); result->assign(0); if (status == ND4J_STATUS_OK) { auto func = PRAGMA_THREADS_FOR { for (auto e = start; e < stop; e++) { bool found = false; for (uint j = 0; j < k; j++) { if (target->e(e) == indices->e(e * k + j)) { found = true; break; } } if (found) result->p(e, true); } }; samediff::Threads::parallel_tad(func, 0, target->lengthOf()); } return status; } int topKFunctor(sd::LaunchContext * context, const NDArray* input, NDArray* values, NDArray* indices, const uint k, bool needSort) { BUILD_SINGLE_SELECTOR(input->dataType(), return topKFunctor_, (input, values, indices, k, needSort), NUMERIC_TYPES); } int inTopKFunctor(sd::LaunchContext * context, const NDArray* input, const NDArray* target, NDArray* result, const uint k) { BUILD_SINGLE_SELECTOR(input->dataType(), return inTopKFunctor_, (context, input, target, result, k), NUMERIC_TYPES); } BUILD_SINGLE_TEMPLATE(template int topKFunctor_, (const NDArray* input, NDArray* values, NDArray* indices, const uint k, bool needSort), NUMERIC_TYPES); BUILD_SINGLE_TEMPLATE(template int inTopKFunctor_, (sd::LaunchContext * context, const NDArray* input, const NDArray* target, NDArray* result, const uint k), NUMERIC_TYPES); } } }