/* ******************************************************************************
 *
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 *  See the NOTICE file distributed with this work for additional
 *  information regarding copyright ownership.
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018
//



#include <ops/declarable/helpers/transforms.h>
#include <helpers/ShapeUtils.h>
#include <numeric>
#include <helpers/Loops.h>

namespace sd 	  {
namespace ops 	  {
namespace helpers {


////////////////////////////////////////////////////////////////////////
/**
 * Element-wise gatherND kernel.
 *
 * For every element of `output` the function builds a coordinate into `input`:
 * the first `indices.sizeAt(-1)` coordinates are read from `indices`, the remaining
 * ones are taken from the coordinates of the output element itself. The addressed
 * input element is then copied into the output element.
 *
 * @tparam X element type used to address the `input` and `output` buffers
 * @tparam Y element type used to address the `indices` buffer
 * @param input   array the values are read from
 * @param indices array whose last dimension holds the leading input coordinates
 * @param output  array the gathered values are written to
 *
 * NOTE(review): values read from `indices` are not range-checked here, and ranks are
 * assumed to fit into MAX_RANK-sized coordinate buffers - presumably guaranteed by the
 * caller; confirm.
 */
template<typename X, typename Y>
static void gatherND_(NDArray& input, NDArray& indices, NDArray& output) {

    const X* x = reinterpret_cast<X*>(input.buffer());
    const Y* y = reinterpret_cast<Y*>(indices.buffer());
          X* z = reinterpret_cast<X*>(output.buffer());

    const int xRank = input.rankOf();
    const int yRank = indices.rankOf();
    const int zRank = output.rankOf();

    const Nd4jLong zLen = output.lengthOf();

    // number of leading input coordinates supplied by every index tuple
    const uint yLastDim = indices.sizeAt(-1);
    // stride of the last indices dimension; loop-invariant, so it is read only once
    const auto yLastStride = indices.stridesOf()[yRank - 1];

    const int diff = zRank - xRank;
    const bool bEqual = yLastDim == xRank;

    auto func = PRAGMA_THREADS_FOR {

        // coordinate scratch buffers, declared once per invocation of the functor
        int xCoords[MAX_RANK], zCoords[MAX_RANK], temp;

        for (auto i = start; i < stop; i++) {

            shape::index2coordsCPU(start, i, output.shapeInfo(), zCoords);

            const auto zOffset = shape::getOffset(output.shapeInfo(), zCoords);

            // the coordinate at position yRank-1 is zeroed while the indices offset is computed,
            // so that yOffset addresses element 0 of the index tuple; it is restored right after
            temp = zCoords[yRank - 1];
            zCoords[yRank - 1] = 0;
            const auto yOffset = shape::getOffset(indices.shapeInfo(), zCoords);
            zCoords[yRank - 1] = temp;

            // seed xCoords with output coordinates; the first yLastDim entries are overwritten below
            if(bEqual)
                memcpy(xCoords, zCoords, zRank * sizeof(int));
            else if(diff >= 0)
                memcpy(xCoords, zCoords + diff, xRank * sizeof(int));       // last xRank output coordinates
            else
                memcpy(xCoords - diff, zCoords, zRank * sizeof(int));       // output coordinates shifted right by -diff

            // leading input coordinates come from the index tuple
            for (uint j = 0; j < yLastDim; ++j)
                xCoords[j] = y[yOffset + j * yLastStride];

            const auto xOffset = shape::getOffset(input.shapeInfo(), xCoords);

            z[zOffset] = x[xOffset];
        }
    };

    samediff::Threads::parallel_tad(func, 0, zLen);
}

////////////////////////////////////////////////////////////////////////
/**
 * Type-dispatching entry point for gatherND.
 *
 * Forwards (input, indices, output) to the gatherND_ template through
 * BUILD_DOUBLE_SELECTOR, keyed on input.dataType() (type list LIBND4J_TYPES)
 * and indices.dataType() (type list INDEXING_TYPES).
 *
 * @param context launch context; not referenced in this function body
 * @param input   array the values are gathered from
 * @param indices array holding the index tuples
 * @param output  array receiving the gathered values
 */
void gatherND(sd::LaunchContext * context, NDArray& input, NDArray& indices, NDArray& output) {
    BUILD_DOUBLE_SELECTOR(input.dataType(), indices.dataType(), gatherND_, (input, indices, output), LIBND4J_TYPES, INDEXING_TYPES);
}


////////////////////////////////////////////////////////////////////////
/**
 * Typed implementation of gather: copies slices of `input`, selected along one axis, into `output`.
 *
 * The axis is intArgs[0] when intArgs is not empty and 0 otherwise; a negative axis is shifted
 * by the rank of `input`. The positions to gather are taken from `indices` when it is not null,
 * otherwise from intArgs[1], intArgs[2], ...
 *
 * @tparam T element type used to address the buffer of `input`
 * @param input   array the slices are read from
 * @param indices positions along the axis, may be nullptr
 * @param output  array the gathered slices are written to
 * @param intArgs integer arguments laid out as {axis, idx0, idx1, ...}
 * @throws std::runtime_error when a position is negative or is not smaller than input->sizeAt(axis)
 */
template<typename T>
static void gather_(NDArray* input, const NDArray* indices, NDArray* output, const std::vector<int>& intArgs) {

    int axis = intArgs.size() > 0 ? intArgs[0] : 0;
    const int inputRank = input->rankOf();
    if(axis < 0)
        axis += inputRank;

    const int numOfIntArgs = intArgs.size();

    if (indices != nullptr) {

        // validate every position before it is used as an offset / sub-array index
        for(Nd4jLong i = 0; i < indices->lengthOf(); ++i) {
            const auto idx = indices->e<Nd4jLong>(i);
            if(idx >= input->sizeAt(axis))
                throw std::runtime_error("helpers::gather function: indices array contains wrong elements, each element must be smaller than corresponding dimension of input array !");
            // negative positions are not wrapped anywhere below, they would address memory out of bounds
            if(idx < 0)
                throw std::runtime_error("helpers::gather function: indices array contains negative elements, each element must be non-negative !");
        }

        // first case: indices consist of only one scalar
        if(indices->isScalar()) {
            if(inputRank <= 1) {
                // For scalar indices, rank 0 or 1 input: can't do tensor along dimension 0 as this is whole array... instead, we want to get a scalar
                auto idx = indices->e<Nd4jLong>(0);
                auto scalarNDArray = input->e(idx);
                output->assign(scalarNDArray);
            } else {
                auto dimensions = ShapeUtils::evalDimsToExclude(inputRank, {axis});
                auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(input->shapeInfo(), dimensions);

                // wrap the selected sub-array of the input buffer (no copy here) and assign it to the output
                auto tadArr = NDArray(reinterpret_cast<void *>(reinterpret_cast<T*>(input->buffer()) + tadPack.primaryOffsets()[indices->e<Nd4jLong>(0)]), tadPack.primaryShapeInfo(), output->getContext());
                output->assign(&tadArr);
            }
        }
        else if (inputRank == 1 && indices->isVector()) {
            // special case: rank-1 input and vector indices, plain element-by-element copy
            auto func = PRAGMA_THREADS_FOR {
                for (auto e = start; e < stop; e++)
                    output->p(e, input->e<T>(indices->e<Nd4jLong>(e)));
            };

            samediff::Threads::parallel_for(func, 0, indices->lengthOf());
        }
        else {

            std::vector<int> dimsOut(indices->rankOf());
            std::iota(dimsOut.begin(), dimsOut.end(), axis);   // fill with axis, axis+1, ... axis+indices->rankOf()-1
            const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(output->shapeInfo(), dimsOut);

            auto func = PRAGMA_THREADS_FOR {
                for (auto i = start; i < stop; i++) {
                    NDArray subArrOut = (*output)(i, dimsOut);
                    NDArray subArrIn = (*input)(indices->e<Nd4jLong>(i), {axis});
                    subArrOut.assign(subArrIn);
                }
            };

            samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
        }
    }
    else {

        // positions are passed through intArgs[1..]; intArgs[0] is the axis
        for(int i = 1; i < numOfIntArgs; ++i) {
            if(intArgs[i] >= input->sizeAt(axis))
                throw std::runtime_error("helpers::gather function: some of input indexes is larger than corresponding shape of input array !");
            // negative positions are not wrapped anywhere below, they would address memory out of bounds
            if(intArgs[i] < 0)
                throw std::runtime_error("helpers::gather function: some of input indexes is negative, each index must be non-negative !");
        }

        // we only allow scalar/vector case here
        if (numOfIntArgs == 2) { // scalar case
            output->assign((*input)(intArgs[1], {axis}));
        }
        else { // vector case
            const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(output->shapeInfo(), {axis});

            auto func = PRAGMA_THREADS_FOR {
                for (auto i = start; i < stop; i++) {
                    NDArray subArrOut = (*output)(i, {axis});
                    NDArray subArrIn = (*input)(intArgs[i + 1], {axis});
                    subArrOut.assign(subArrIn);
                }
            };

            samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
        }
    }
}

    /**
     * Type-dispatching entry point for gather.
     *
     * Forwards all arguments to the gather_ template through BUILD_SINGLE_SELECTOR,
     * keyed on input->dataType() (type list LIBND4J_TYPES).
     *
     * @param input   array the slices are gathered from; dereferenced unconditionally here
     * @param indices positions to gather, passed through unchanged
     * @param output  array receiving the gathered slices
     * @param intArgs integer arguments, passed through unchanged
     */
    void gather(NDArray* input, const NDArray* indices, NDArray* output, const std::vector<int>& intArgs) {
        BUILD_SINGLE_SELECTOR(input->dataType(), gather_, (input, indices, output, intArgs), LIBND4J_TYPES);
    }

}
}
}