/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018
//

#include <ops/declarable/helpers/gather.h>
#include <numeric>
#include <helpers/ShapeUtils.h>
#include <execution/Threads.h>

namespace sd {
namespace ops {
namespace helpers {

////////////////////////////////////////////////////////////////////////
template<typename X, typename Y>
static void gatherND_(NDArray& input, NDArray& indices, NDArray& output) {

    const X* x = reinterpret_cast<const X*>(input.buffer());
    const Y* y = reinterpret_cast<const Y*>(indices.buffer());
          X* z = reinterpret_cast<X*>(output.buffer());

    const int xRank   = input.rankOf();
    const int yRank   = indices.rankOf();
    const int zRank   = output.rankOf();
    const int maxRank = sd::math::nd4j_max<int>(yRank, sd::math::nd4j_max<int>(xRank, zRank));

    const Nd4jLong zLen = output.lengthOf();

    const uint yLastDim = indices.sizeAt(-1);

    const int diff = zRank - xRank;
    const bool bEqual = yLastDim == xRank;

    auto func = PRAGMA_THREADS_FOR {

        int xCoords[MAX_RANK], zCoords[MAX_RANK], temp;

        for (auto i = start; i < stop; i++) {

            shape::index2coordsCPU(start, i, output.shapeInfo(), zCoords);

            const auto zOffset = shape::getOffset(output.shapeInfo(), zCoords);

            temp = zCoords[yRank - 1];
            zCoords[yRank - 1] = 0;
            const auto yOffset = shape::getOffset(indices.shapeInfo(), zCoords);
            zCoords[yRank - 1] = temp;

            if(bEqual)
                memcpy(xCoords, zCoords, zRank * sizeof(int));
            else if(diff >= 0)
                memcpy(xCoords, zCoords + diff, xRank * sizeof(int));
            else
                memcpy(xCoords - diff, zCoords, zRank * sizeof(int));

            for (uint j = 0; j < yLastDim; ++j)
                xCoords[j] = y[yOffset + j * indices.stridesOf()[yRank - 1]];   // last stride

            const auto xOffset = shape::getOffset(input.shapeInfo(), xCoords);

            z[zOffset] = x[xOffset];
        }
    };

    samediff::Threads::parallel_tad(func, 0, zLen);
}

////////////////////////////////////////////////////////////////////////
void gatherND(sd::LaunchContext* context, NDArray& input, NDArray& indices, NDArray& output) {
    BUILD_DOUBLE_SELECTOR(input.dataType(), indices.dataType(), gatherND_, (input, indices, output), LIBND4J_TYPES, INDEXING_TYPES);
}
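
// Illustrative usage sketch for gatherND (hypothetical values, assuming NDArrayFactory
// as used in the libnd4j test suite; shown only to document the semantics, not part of
// the library itself): the last dimension of `indices` supplies coordinates into `input`,
// and any remaining input dimensions are copied whole.
//
//   auto input   = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});   // [[1,2],[3,4]]
//   auto indices = NDArrayFactory::create<int>('c', {2, 1}, {1, 0});           // pick row 1, then row 0
//   auto output  = NDArrayFactory::create<float>('c', {2, 2});
//   sd::ops::helpers::gatherND(nullptr, input, indices, output);               // the CPU helper above never dereferences context
//   // output now holds [[3,4],[1,2]]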

////////////////////////////////////////////////////////////////////////
template<typename T>
static void gather_(NDArray* input, const NDArray* indices, NDArray* output, const std::vector<int>& intArgs) {

    int axis = intArgs.size() > 0 ? intArgs[0] : 0;
    const int inputRank = input->rankOf();
    if(axis < 0)
        axis += inputRank;

    const int numOfIntArgs = intArgs.size();

    if (indices != nullptr) {

        for(Nd4jLong i = 0; i < indices->lengthOf(); ++i)
            if(indices->e<Nd4jLong>(i) >= input->sizeAt(axis))
                throw std::runtime_error("helpers::gather function: indices array contains wrong elements, each element must be smaller than the corresponding dimension of the input array!");

        // first case: indices consist of only one scalar
        if(indices->isScalar()) {
            if(input->rankOf() <= 1) {
                // for scalar indices and rank 0 or 1 input we can't take a tensor along dimension 0,
                // as this is the whole array; instead, we want to get a scalar
                auto idx = indices->e<Nd4jLong>(0);
                auto scalarNDArray = input->e(idx);
                output->assign(scalarNDArray);
            }
            else {
                auto dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {axis});
                auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), dimensions);

                auto tadArr = NDArray(reinterpret_cast<void*>(reinterpret_cast<T*>(input->buffer()) + tadPack.primaryOffsets()[indices->e<Nd4jLong>(0)]), tadPack.primaryShapeInfo(), output->getContext());
                output->assign(&tadArr);
            }
        }
        else if (input->rankOf() == 1 && indices->isVector()) {
            // special case: both input and indices are vectors
            auto func = PRAGMA_THREADS_FOR {
                for (auto e = start; e < stop; e++)
                    output->p(e, input->e<T>(indices->e<Nd4jLong>(e)));
            };

            samediff::Threads::parallel_for(func, 0, indices->lengthOf());
        }
        else {

            std::vector<int> dimsOut(indices->rankOf());
            std::iota(dimsOut.begin(), dimsOut.end(), axis);   // fill with axis, axis+1, ..., axis+indices->rankOf()-1

            const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(output->shapeInfo(), dimsOut);

            auto func = PRAGMA_THREADS_FOR {
                for (auto i = start; i < stop; i++) {
                    NDArray subArrOut = (*output)(i, dimsOut);
                    NDArray subArrIn = (*input)(indices->e<Nd4jLong>(i), {axis});
                    subArrOut.assign(subArrIn);
                }
            };

            samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
        }
    }
    else {

        for(int i = 1; i < numOfIntArgs; ++i)
            if(intArgs[i] >= input->sizeAt(axis))
                throw std::runtime_error("helpers::gather function: some of the input indices are larger than the corresponding dimension of the input array!");

        // we only allow scalar/vector case here
        if (numOfIntArgs == 2) {    // scalar case
            output->assign((*input)(intArgs[1], {axis}));
        }
        else {                      // vector case
            const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(output->shapeInfo(), {axis});

            auto func = PRAGMA_THREADS_FOR {
                for (auto i = start; i < stop; i++) {
                    NDArray subArrOut = (*output)(i, {axis});
                    NDArray subArrIn = (*input)(intArgs[i + 1], {axis});
                    subArrOut.assign(subArrIn);
                }
            };

            samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
        }
    }
}

void gather(NDArray* input, const NDArray* indices, NDArray* output, const std::vector<int>& intArgs) {
    BUILD_SINGLE_SELECTOR(input->dataType(), gather_, (input, indices, output, intArgs), LIBND4J_TYPES);
}

}
}
}
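
// Illustrative usage sketch for gather (hypothetical values, assuming NDArrayFactory
// as used in the libnd4j test suite; a documentation aid, not part of the library):
// when indices are given as an NDArray and the axis is passed through intArgs[0],
// each index selects a sub-array of `input` along that axis.
//
//   auto input   = NDArrayFactory::create<float>('c', {3, 2}, {1, 2, 3, 4, 5, 6});   // [[1,2],[3,4],[5,6]]
//   auto indices = NDArrayFactory::create<int>('c', {2}, {2, 0});                    // take rows 2 and 0
//   auto output  = NDArrayFactory::create<float>('c', {2, 2});
//   sd::ops::helpers::gather(&input, &indices, &output, {0});                        // intArgs = {axis}
//   // output now holds [[5,6],[1,2]]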