/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 * Copyright (c) 2019 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
//  @author GS <sgazeos@gmail.com>
//

#include <ops/declarable/helpers/segment.h>
#include <helpers/ShapeUtils.h>
#include <execution/Threads.h>
#include <unordered_map>

namespace sd {
namespace ops {
namespace helpers {

    // segment max
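    // Sorted segment max: for each contiguous run of equal values in `indices`,
    // writes the element-wise maximum of the matching input rows to the output
    // row with that index. Indices are assumed to be sorted in ascending order.
    // E.g. input [1, 3, 2, 5, 4] with indices [0, 0, 1, 1, 2] -> output [3, 5, 4].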
    template <typename T>
    static void segmentMaxFunctor_(NDArray* input, NDArray* indices, NDArray* output) {
        // if input is a vector (as in the documentation sample):
        Nd4jLong idx = indices->e<Nd4jLong>(0);
        if (input->isVector()) {
            T val = input->e<T>(0);
            output->t<T>(idx) = val; // seed the first segment so single-element inputs are written

            for (Nd4jLong e = 1; e < indices->lengthOf(); e++) {
                if (idx == indices->e<Nd4jLong>(e)) {
                   // max
                   val = sd::math::nd4j_max<T>(val, input->t<T>(e));
                }
                else {
                    idx = indices->e<Nd4jLong>(e);
                    val = input->t<T>(e);
                }
                output->t<T>(idx) = val;
            }
        }
        else {
            std::vector<int> restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});
            auto listOfTensors = input->allTensorsAlongDimension(restDims);
            auto listOfOutTensors = output->allTensorsAlongDimension(restDims);

            auto maxT = listOfOutTensors.at(idx);

            maxT->assign(listOfTensors.at(0));

            for (Nd4jLong i = 1; i < indices->lengthOf(); i++) {
                if (indices->e<int>(i) == idx) {

                    for (Nd4jLong e = 0; e < maxT->lengthOf(); e++) {
                       maxT->t<T>(e) = sd::math::nd4j_max(maxT->t<T>(e), listOfTensors.at(i)->t<T>(e));
                    }
                }
                else {
                    idx = indices->e<Nd4jLong>(i);
                    maxT = listOfOutTensors.at(idx);
                    maxT->assign(listOfTensors.at(i));
                }

            }
        }
    }

    // segment min
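    // Sorted segment min: same contract as segmentMaxFunctor_, but keeps the
    // element-wise minimum per segment.
    // E.g. input [1, 3, 2, 5, 4] with indices [0, 0, 1, 1, 2] -> output [1, 2, 4].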
    template <typename T>
    static void segmentMinFunctor_(NDArray* input, NDArray* indices, NDArray* output) {
        // if input is a vector (as in the documentation sample):
        Nd4jLong idx = indices->e<Nd4jLong>(0);
        if (input->isVector()) {
            T val = input->e<T>(0);
            output->t<T>(idx) = val; // seed the first segment so single-element inputs are written

            for (Nd4jLong e = 1; e < indices->lengthOf(); e++) {
                if (idx == indices->e<Nd4jLong>(e)) {
                   // min
                   val = sd::math::nd4j_min<T>(val, input->t<T>(e));
                }
                else {
                    idx = indices->e<Nd4jLong>(e);
                    val = input->t<T>(e);
                }
                output->t<T>(idx) = val;
            }
        }
        else {
            auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});

            ResultSet listOfTensors = input->allTensorsAlongDimension(restDims);
            ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims);

            auto minT = listOfOutTensors.at(idx);

            minT->assign(listOfTensors.at(0));

            for (Nd4jLong i = 1; i < indices->lengthOf(); i++) {
                if (indices->e<Nd4jLong>(i) == idx) {

                    for (Nd4jLong e = 0; e < minT->lengthOf(); e++) {
                       minT->p(e, sd::math::nd4j_min(minT->e<T>(e), listOfTensors.at(i)->e<T>(e)));
                    }
                }
                else {
                    idx = indices->e<Nd4jLong>(i);
                    minT = listOfOutTensors.at(idx);
                    minT->assign(listOfTensors.at(i));
                }
            }
        }
    }

    // segment mean
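    // Sorted segment mean: averages the input rows of each contiguous segment.
    // E.g. input [1, 3, 2, 5, 4] with indices [0, 0, 1, 1, 2] -> output [2, 3.5, 4].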
    template <typename T>
    static void segmentMeanFunctor_(NDArray* input, NDArray* indices, NDArray* output) {
        // if input is a vector (as in the documentation sample):
        int idx = indices->e<int>(0);
        if (input->isVector()) {
            T val = T(0.f);
            int count = 0;

            for (Nd4jLong e = 0; e < indices->lengthOf(); e++) {
                if (idx == indices->e<int>(e)) {
                   // mean
                   val += input->e<T>(e);
                   count++;
                }
                else {
                    output->p<T>(idx, val / count); // flush the segment that just ended
                    idx = indices->e<int>(e);
                    val = input->e<T>(e);
                    count = 1;
                }
                output->p<T>(idx, val / count);
            }
        }
        else {
            auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});

            auto listOfTensors = input->allTensorsAlongDimension(restDims);
            auto listOfOutTensors = output->allTensorsAlongDimension(restDims);

            auto meanT = listOfOutTensors.at(idx);
            int count = 1;
            auto meanV = meanT->dup();
            meanV.assign(listOfTensors.at(0));

            for (Nd4jLong i = 1; i < indices->lengthOf(); i++) {
                if (indices->e<int>(i) == idx) {
                    auto func = PRAGMA_THREADS_FOR {
                        for (auto e = start; e < stop; e++) {
                            meanV.p<T>(e, meanV.e<T>(e) + listOfTensors.at(i)->e<T>(e));
                        }
                    };
                    samediff::Threads::parallel_for(func, 0, meanT->lengthOf());

                    count++;
                }
                else {
                    meanV.applyScalar(scalar::Divide, count, *meanT); // flush the finished segment
                    idx = indices->e<int>(i);
                    meanT = listOfOutTensors.at(idx);
                    meanV.assign(listOfTensors.at(i));
                    count = 1;
                }
                meanV.applyScalar(scalar::Divide, count, *meanT);
            }
        }
    }

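    // segment sum
    // Sorted segment sum: accumulates the input rows of each contiguous segment.
    // E.g. input [1, 3, 2, 5, 4] with indices [0, 0, 1, 1, 2] -> output [4, 7, 4].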
    template <typename T>
    static void segmentSumFunctor_(NDArray* input, NDArray* indices, NDArray* output) {
        // if input is a vector (as in the documentation sample):
        int idx = indices->e<int>(0);
        if (input->isVector()) {
            T val = T(0.f);
            for (Nd4jLong e = 0; e < indices->lengthOf(); e++) {
                if (idx == indices->e<int>(e)) {
                   // sum
                   val += input->t<T>(e);
                }
                else {
                    idx = indices->e<int>(e);
                    val = input->t<T>(e);
                }
                output->p(idx, val);
            }
        }
        else {
            auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});

            auto listOfTensors = input->allTensorsAlongDimension(restDims);
            auto listOfOutTensors = output->allTensorsAlongDimension(restDims);

            auto sumT = listOfOutTensors.at(idx);

            for (Nd4jLong i = 0; i < indices->lengthOf(); i++) {
                if (indices->e<int>(i) == idx) {
                    auto func = PRAGMA_THREADS_FOR {
                        for (auto e = start; e < stop; e++) {
                            sumT->p(e, sumT->e<T>(e) + listOfTensors.at(i)->e<T>(e));
                        }
                    };
                    samediff::Threads::parallel_for(func, 0, sumT->lengthOf());
                }
                else {
                    idx = indices->e<int>(i);
                    sumT = listOfOutTensors.at(idx);
                    sumT->assign(listOfTensors.at(i));
                }
            }
        }
    }

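    // segment prod
    // Sorted segment product: multiplies the input rows of each contiguous segment;
    // the output is preset to 1 so untouched classes stay at the identity.
    // E.g. input [1, 3, 2, 5, 4] with indices [0, 0, 1, 1, 2] -> output [3, 10, 4].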
    template <typename T>
    static void segmentProdFunctor_(NDArray* input, NDArray* indices, NDArray* output) {
        // if input is a vector (as in the documentation sample):
        int idx = indices->e<int>(0);
        output->assign(1.f);
        if (input->isVector()) {
            T val = input->e<T>(0);

            for (Nd4jLong e = 1; e < indices->lengthOf(); e++) {
                if (idx == indices->e<int>(e)) {
                   // product
                   val *= input->e<T>(e);
                }
                else {
                    idx = indices->e<int>(e);
                    val = input->e<T>(e);
                }
                output->p(idx, val);
            }
        }
        else {
            auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});

            auto listOfTensors = input->allTensorsAlongDimension(restDims);
            auto listOfOutTensors = output->allTensorsAlongDimension(restDims);

            auto prodT = listOfOutTensors.at(idx);
            prodT->assign(listOfTensors.at(0));
            for (Nd4jLong i = 1; i < indices->lengthOf(); i++) {
                if (indices->e<int>(i) == idx) {
                    auto func = PRAGMA_THREADS_FOR {
                        for (auto e = start; e < stop; e++) {
                            prodT->p(e, prodT->e<T>(e) * listOfTensors.at(i)->e<T>(e));
                        }
                    };
                    samediff::Threads::parallel_for(func, 0, prodT->lengthOf());
                }
                else {
                    idx = indices->e<int>(i);
                    prodT = listOfOutTensors.at(idx);
                    prodT->assign(listOfTensors.at(i));
                }
            }
        }
    }


    void segmentMaxFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* output) {
        BUILD_SINGLE_SELECTOR(input->dataType(), segmentMaxFunctor_, (input, indices, output), LIBND4J_TYPES);
    }

    void segmentMinFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* output) {
        BUILD_SINGLE_SELECTOR(input->dataType(), segmentMinFunctor_, (input, indices, output), LIBND4J_TYPES);
    }

    void segmentMeanFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* output) {
        BUILD_SINGLE_SELECTOR(input->dataType(), segmentMeanFunctor_, (input, indices, output), LIBND4J_TYPES);
    }

    void segmentSumFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* output) {
        BUILD_SINGLE_SELECTOR(input->dataType(), segmentSumFunctor_, (input, indices, output), LIBND4J_TYPES);
    }

    void segmentProdFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* output) {
        BUILD_SINGLE_SELECTOR(input->dataType(), segmentProdFunctor_, (input, indices, output), LIBND4J_TYPES);
    }

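    // Validates that segment indices are sorted in ascending order, as the sorted
    // segment ops above require; on failure `output` holds the offending index value.
    // E.g. indices [0, 0, 1, 2] validate; [0, 2, 1] do not (1 follows 2).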
    bool segmentIndicesValidate(sd::LaunchContext * context, NDArray* indices, NDArray& expected, NDArray& output) {
        auto val = indices->e(0);
        for (Nd4jLong e = 1; e < indices->lengthOf(); e++) {
            output = indices->e(e);
            if (val.e<Nd4jLong>(0) > output.e<Nd4jLong>(0))
                return false;
            val = indices->e(e);
        }

        return true;
    }

    BUILD_SINGLE_TEMPLATE(template void segmentProdFunctor_, (NDArray* input, NDArray* indices, NDArray* output), LIBND4J_TYPES);
    BUILD_SINGLE_TEMPLATE(template void segmentSumFunctor_, (NDArray* input, NDArray* indices, NDArray* output), LIBND4J_TYPES);
    BUILD_SINGLE_TEMPLATE(template void segmentMeanFunctor_, (NDArray* input, NDArray* indices, NDArray* output), LIBND4J_TYPES);
    BUILD_SINGLE_TEMPLATE(template void segmentMinFunctor_, (NDArray* input, NDArray* indices, NDArray* output), LIBND4J_TYPES);
    BUILD_SINGLE_TEMPLATE(template void segmentMaxFunctor_, (NDArray* input, NDArray* indices, NDArray* output), LIBND4J_TYPES);
    // -------------------------------------------------------------------------------------------------------------- //
    // Unsorted segment ops
    // -------------------------------------------------------------------------------------------------------------- //

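    // Validates that no index reaches `expected` (the declared number of classes);
    // `output` receives `expected` on success and the first index value on failure.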
    bool unsortedSegmentIndicesValidate(sd::LaunchContext * context, NDArray* indices, Nd4jLong expected, Nd4jLong& output) {
        Nd4jLong val = indices->e<Nd4jLong>(0);

        Nd4jLong maxInd = indices->argMax();
        if (indices->e<Nd4jLong>(maxInd) >= expected) {
                output = val;
                return false;
        }
        output = expected;
        return true;
    }

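    // Unsorted segment max: indices may appear in any order. Positions are first
    // bucketed per segment id, then each output row gets the element-wise max of
    // its bucket. Rows of classes that never occur stay at -DataTypeUtils::max<T>().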
    template <typename T>
    static void unsortedSegmentMaxFunctor_(NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {

        // bucket the positions of each segment id: idxs[segmentId] -> positions in input
        MAP_IMPL<Nd4jLong, std::vector<Nd4jLong>> idxs;
        for (Nd4jLong e = 0; e < indices->lengthOf(); ++e)
            idxs[indices->e<Nd4jLong>(e)].push_back(e);

        if (input->isVector()) { // 1D case
            T maxVal = DataTypeUtils::max<T>();
            output->assign(-maxVal);

            for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) {
                T val = input->e<T>(fi->second.at(0));
                for (Nd4jLong idx = 1; idx < static_cast<Nd4jLong>(fi->second.size()); ++idx) {
                    val = sd::math::nd4j_max(val, input->e<T>(fi->second.at(idx)));
                }
                output->p(fi->first, val);
            }
        }
        else {
            auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});

            ResultSet listOfTensors = input->allTensorsAlongDimension(restDims);
            ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims);

            T maxVal = DataTypeUtils::max<T>();
            output->assign(-maxVal);

            for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) {
                auto outputT = listOfOutTensors.at(fi->first);
                outputT->assign(listOfTensors.at(fi->second.at(0)));
                for (Nd4jLong idx = 1; idx < static_cast<Nd4jLong>(fi->second.size()); ++idx) {
                    auto maxT = listOfTensors.at(fi->second.at(idx));
                    for (Nd4jLong e = 0; e < outputT->lengthOf(); ++e) {
                        T val = sd::math::nd4j_max(maxT->e<T>(e), outputT->e<T>(e));

                        outputT->p(e, val);
                    }
                }
            }
        }
    }
    void unsortedSegmentMaxFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
        BUILD_SINGLE_SELECTOR(input->dataType(), unsortedSegmentMaxFunctor_, (input, indices, numOfClasses, output), NUMERIC_TYPES);
    }
    BUILD_SINGLE_TEMPLATE(template void unsortedSegmentMaxFunctor_, (NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES);

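    // Unsorted segment min: mirror of unsortedSegmentMaxFunctor_, with untouched
    // classes left at DataTypeUtils::max<T>().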
    template <typename T>
    static void unsortedSegmentMinFunctor_(NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
        // bucket the positions of each segment id: idxs[segmentId] -> positions in input
        MAP_IMPL<Nd4jLong, std::vector<Nd4jLong>> idxs;

        for (Nd4jLong e = 0; e < indices->lengthOf(); ++e)
            idxs[indices->e<Nd4jLong>(e)].push_back(e);

        if (input->isVector()) { // 1D case
            T maxVal = DataTypeUtils::max<T>();
            output->assign(maxVal);

            for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) {
                T val = input->t<T>(fi->second.at(0));

                for (size_t idx = 1; idx < fi->second.size(); ++idx) {
                    val = sd::math::nd4j_min(val, input->t<T>(fi->second.at(idx)));
                }
                output->t<T>(fi->first) = val;
            }
        }
        else {
            auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});

            ResultSet listOfTensors = input->allTensorsAlongDimension(restDims);
            ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims);

            T maxVal = DataTypeUtils::max<T>();
            output->assign(maxVal);

            for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) {
                auto outputT = listOfOutTensors.at(fi->first);
                outputT->assign(listOfTensors.at(fi->second.at(0)));
                for (size_t idx = 1; idx < fi->second.size(); ++idx) {
                    auto minT = listOfTensors.at(fi->second.at(idx));

                    for (Nd4jLong e = 0; e < outputT->lengthOf(); ++e) {
                        outputT->t<T>(e) = sd::math::nd4j_min(minT->t<T>(e), outputT->t<T>(e));
                    }
                }
            }
        }

    }
    void unsortedSegmentMinFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
        BUILD_SINGLE_SELECTOR(input->dataType(), unsortedSegmentMinFunctor_, (input, indices, numOfClasses, output),
                              NUMERIC_TYPES);
    }

    BUILD_SINGLE_TEMPLATE(template void unsortedSegmentMinFunctor_, (NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES);

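    // Unsorted segment mean: sums each bucket of rows, then divides by the bucket size.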
    void unsortedSegmentMeanFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
        MAP_IMPL<Nd4jLong, std::vector<Nd4jLong>> idxs; // segmentId -> positions in input
        for (Nd4jLong e = 0; e < indices->lengthOf(); ++e)
            idxs[indices->e<Nd4jLong>(e)].push_back(e);

        if (input->isVector()) { // 1D case

            for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) {
                double sumValue = input->e<double>(fi->second.at(0));
                size_t loop_size = fi->second.size();

                // FIXME: parallelism here?
                for (size_t idx = 1; idx < loop_size; ++idx) {
                    sumValue += input->e<double>(fi->second.at(idx));
                }

                output->p(fi->first, sumValue / fi->second.size());
            }
        }
        else {
            auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});

            ResultSet listOfTensors = input->allTensorsAlongDimension(restDims);
            ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims);

            // FIXME: parallelism here?
            for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) {
                auto outputT = listOfOutTensors.at(fi->first);
                outputT->assign(listOfTensors.at(fi->second.at(0)));
                Nd4jLong loopSize = fi->second.size();

                for (Nd4jLong idx = 1; idx < loopSize; ++idx) {
                    auto current = listOfTensors.at(fi->second.at(idx));
                    *outputT += *current;
                }
                (*outputT) /= double(fi->second.size());
            }
        }
    }

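    // Unsorted segment sum: accumulates each bucket of rows into its output row.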
    void unsortedSegmentSumFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
        MAP_IMPL<Nd4jLong, std::vector<Nd4jLong>> idxs; // segmentId -> positions in input
        for (Nd4jLong e = 0; e < indices->lengthOf(); ++e)
            idxs[indices->e<Nd4jLong>(e)].push_back(e);

        if (input->isVector()) { // 1D case

            for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) {
                double sumValue = input->e<double>(fi->second.at(0));
                Nd4jLong loop_size = fi->second.size();

                // FIXME: parallelism here?
                for (Nd4jLong idx = 1; idx < loop_size; ++idx) {
                    sumValue += input->e<double>(fi->second.at(idx));
                }
                output->p(fi->first, sumValue);
            }
        }
        else {
            auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});

            ResultSet listOfTensors = input->allTensorsAlongDimension(restDims);
            ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims);

            for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) {
                auto outputT = listOfOutTensors.at(fi->first);
                outputT->assign(listOfTensors.at(fi->second.at(0)));
                Nd4jLong loop_size = fi->second.size();

                // FIXME: parallelism here?
                for (Nd4jLong idx = 1; idx < loop_size; ++idx) {
                    auto current = listOfTensors.at(fi->second.at(idx));
                    *(outputT) += *current;
                }
            }
        }
    }

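    // Unsorted segment product: multiplies each bucket of rows; output is preset to 1.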
    template <typename T>
    void unsortedSegmentProdFunctor_(NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
        MAP_IMPL<Nd4jLong, std::vector<Nd4jLong>> idxs; // segmentId -> positions in input
        for (Nd4jLong e = 0; e < indices->lengthOf(); ++e)
            idxs[indices->e<Nd4jLong>(e)].push_back(e);

        output->assign(1.f);

        if (input->isVector()) { // 1D case
            for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) {
                T prodValue = input->e<T>(fi->second.at(0));
                for (size_t idx = 1; idx < fi->second.size(); ++idx) {
                    prodValue *= input->e<T>(fi->second.at(idx));
                }
                output->p(fi->first, prodValue);
            }
        }
        else {
            auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});

            ResultSet listOfTensors = input->allTensorsAlongDimension(restDims);
            ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims);

            for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) {
                auto outputT = listOfOutTensors.at(fi->first);
                outputT->assign(listOfTensors.at(fi->second.at(0)));
                for (size_t idx = 1; idx < fi->second.size(); ++idx) {
                    auto current = listOfTensors.at(fi->second.at(idx));

                    *outputT *= *current;
                }
            }
        }
    }

    void unsortedSegmentProdFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
        BUILD_SINGLE_SELECTOR(input->dataType(), unsortedSegmentProdFunctor_, (input, indices, numOfClasses, output), NUMERIC_TYPES);
    }
    BUILD_SINGLE_TEMPLATE(template void unsortedSegmentProdFunctor_, (NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES);

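    // Unsorted segment sqrt(n): sums each bucket of rows and divides by the square
    // root of the bucket size, i.e. output_k = sum(x_i for indices[i] == k) / sqrt(n_k).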
    void unsortedSegmentSqrtNFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
        MAP_IMPL<Nd4jLong, std::vector<Nd4jLong>> idxs; // segmentId -> positions in input
        for (Nd4jLong e = 0; e < indices->lengthOf(); ++e)
            idxs[indices->e<Nd4jLong>(e)].push_back(e);

        if (input->isVector()) { // 1D case
            for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) {
                double sumValue = input->e<double>(fi->second.at(0));
                for (size_t idx = 1; idx < fi->second.size(); ++idx) {
                    sumValue += input->e<double>(fi->second.at(idx));
                }
                output->p(fi->first, sumValue / sd::math::nd4j_sqrt<Nd4jLong, double>(fi->second.size()));
            }
        }
        else {
            auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});

            ResultSet listOfTensors = input->allTensorsAlongDimension(restDims);
            ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims);

            for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) {
                auto outputT = listOfOutTensors.at(fi->first);
                outputT->assign(listOfTensors.at(fi->second.at(0)));
                for (size_t idx = 1; idx < fi->second.size(); ++idx) {
                    auto current = listOfTensors.at(fi->second.at(idx));
                    *outputT += *current;
                }
                (*outputT) /= sd::math::nd4j_sqrt<size_t, double>(fi->second.size());
            }
        }
    }

    // -------------------------------------------------------------------------------------------------------------- //
    // Backpropagate ops helpers
    // -------------------------------------------------------------------------------------------------------------- //
    // Sorted backpropagate ops
    //
    // segment max
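    // Backprop for segment max: the forward result is recomputed, and the incoming
    // gradient is routed only to the positions whose value matches the segment max
    // (within a 1e-6 tolerance); all other positions receive no gradient.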
    template <typename T>
    int segmentMaxFunctorBP_(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
        // recompute the forward pass to find which elements produced the segment max
        auto tempRes = gradOut->dup();
        segmentMaxFunctor_<T>(input, indices, &tempRes);
        if (input->isVector()) {
            Nd4jLong loop_size = input->lengthOf();

            auto func = PRAGMA_THREADS_FOR {
                for (auto e = start; e < stop; e++) {
                    auto classNum = indices->e<Nd4jLong>(e);
                    if (sd::math::nd4j_abs(tempRes.e<T>(classNum) - input->e<T>(e)) <= T(1.e-6))
                        output->p(e, gradOut->e<T>(classNum));
                }
            };
            samediff::Threads::parallel_for(func, 0, loop_size);
        }
        else {
            std::vector<int> restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});

            ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims);
            ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims);
            ResultSet listOfTensors = input->allTensorsAlongDimension(restDims);
            ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims);

            auto func = PRAGMA_THREADS_FOR {
                for (auto i = start; i < stop; i++) {
                    auto classNum = indices->e<Nd4jLong>(i);
                    auto current = listOfTensors.at(i);
                    auto currentOut = listOfOutTensors.at(i);
                    auto currentGradOut = listOfGradOuts.at(classNum);

                    for (Nd4jLong e = 0; e < current->lengthOf(); e++) {
                        if (sd::math::nd4j_abs(listOfBPTensors.at(classNum)->e<T>(e) - current->e<T>(e)) <= T(1.e-6))
                            currentOut->p(e, currentGradOut->e<T>(e));
                    }
                }
            };

            samediff::Threads::parallel_tad(func, 0, indices->lengthOf());
        }

        return ND4J_STATUS_OK;
    }

    int segmentMaxFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
        BUILD_SINGLE_SELECTOR(output->dataType(), return segmentMaxFunctorBP_, (context, input, indices, gradOut, output), NUMERIC_TYPES);
    }
    BUILD_SINGLE_TEMPLATE(template int segmentMaxFunctorBP_, (sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output), NUMERIC_TYPES);

    // segment min
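    // Backprop for segment min: like segmentMaxFunctorBP, gradient flows only to
    // the positions that attained the segment minimum (within tolerance).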
    int segmentMinFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
        NDArray tempRes = gradOut->dup();
        segmentMinFunctor(context, input, indices, &tempRes);
        if (input->isVector()) {
            auto func = PRAGMA_THREADS_FOR {
                for (auto e = start; e < stop; e++) {
                    auto classNum = indices->e<Nd4jLong>(e);
                    if (sd::math::nd4j_abs(tempRes.e<double>(classNum) - input->e<double>(e)) < 1.e-5)
                        output->p(e, gradOut->e<double>(classNum));
                }
            };
            samediff::Threads::parallel_for(func, 0, input->lengthOf());
        }
        else {
            auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});

            ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims);
            ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims);
            ResultSet listOfTensors = input->allTensorsAlongDimension(restDims);
            ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims);

            output->assign(0.);

            auto func = PRAGMA_THREADS_FOR {
                for (auto i = start; i < stop; i++) {
                    auto classNum = indices->e<Nd4jLong>(i);
                    auto current = listOfTensors.at(i);
                    auto currentOut = listOfOutTensors.at(i);
                    auto currentGradOut = listOfGradOuts.at(classNum);

                    for (Nd4jLong e = 0; e < current->lengthOf(); e++) {
                        if (sd::math::nd4j_abs(listOfBPTensors.at(classNum)->e<double>(e) - current->e<double>(e)) <
                            1.e-5)
                            currentOut->p(e, currentGradOut->e<double>(e));
                    }
                }
            };

            samediff::Threads::parallel_tad(func, 0, indices->lengthOf());
        }
        return ND4J_STATUS_OK;
    }

    // segment mean
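    // Backprop for segment mean: each position receives its segment's gradient
    // divided by the segment size, dL/dx_i = gradOut[k] / n_k for indices[i] == k.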
    int segmentMeanFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
        int numClasses = output->sizeAt(0);
        MAP_IMPL<Nd4jLong, Nd4jLong> classCount; // segmentId -> number of elements in the segment

        for (Nd4jLong count = 0; count < numClasses; ++count) {
            classCount[count] = 0;
        }

        for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) {
            classCount[indices->e<Nd4jLong>(e)] ++;
        }

        // if input is a vector (as in the documentation sample):
        if (input->isVector()) {
            for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) {
                Nd4jLong classNum = indices->e<Nd4jLong>(e);
                output->p(e, gradOut->e<double>(classNum) / classCount[classNum]);
            }
        }
        else {
            auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});

            ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims);
            ResultSet listOfTensors = input->allTensorsAlongDimension(restDims);
            ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims);
            //auto func = [&](uint64_t thread_id, uint64_t start, uint64_t stop, uint64_t increment) -> void {
                for (Nd4jLong i = 0; i < indices->lengthOf(); i++) {
                    auto classNum = indices->e<Nd4jLong>(i);
                    auto current = listOfTensors.at(i);
                    auto currentOut = listOfOutTensors.at(i);
                    auto currentGradOut = listOfGradOuts.at(classNum);

                    for (Nd4jLong e = 0; e < current->lengthOf(); e++) {
                        currentOut->p(e, currentGradOut->e<double>(e) / classCount.at(classNum));
                    }
                }
            //};

            //samediff::Threads::parallel_for(func, 0, indices->lengthOf());
        }
        return ND4J_STATUS_OK;
    }

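    // Backprop for segment sum: the gradient of a sum is 1 per contributing element,
    // so every position simply copies its segment's incoming gradient.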
    int segmentSumFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
        // if input is a vector (as in the documentation sample):
        if (input->isVector()) {
            for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) {
                Nd4jLong classNum = indices->e<Nd4jLong>(e);
                output->p(e, gradOut->e<double>(classNum));
            }
        }
        else {
            auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});

            ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims);
            ResultSet listOfTensors = input->allTensorsAlongDimension(restDims);
            ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims);

            //auto func = PRAGMA_THREADS_FOR {
                for (Nd4jLong i = 0; i < indices->lengthOf(); i++) {
                    auto classNum = indices->e<Nd4jLong>(i);
                    auto current = listOfTensors.at(i);
                    auto currentOut = listOfOutTensors.at(i);
                    auto currentGradOut = listOfGradOuts.at(classNum);

                    currentOut->assign(currentGradOut);
                }
            //};

            //samediff::Threads::parallel_for(func, 0, indices->lengthOf());
        }
        return Status::OK();
    }

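    // Backprop for segment product: recomputes the forward product p_k and applies
    // dL/dx_i = gradOut[k] * p_k / x_i, which assumes no contributing element is zero.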
    int segmentProdFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
        auto tempRes = gradOut->dup();
        segmentProdFunctor(context, input, indices, &tempRes);
        if (input->isVector()) {
            for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) {
                Nd4jLong classNum = indices->e<Nd4jLong>(e);
                output->p(e, gradOut->e<double>(classNum) * tempRes.e<double>(classNum) / input->e<double>(e));
            }
        }
        else {
            auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});

            ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims);
            ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims);
            ResultSet listOfTensors = input->allTensorsAlongDimension(restDims);
            ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims);

            //int numOfClasses = tempRes.sizeAt(0); // number of classes
            //std::vector<std::pair<NDArray*, int>> outputs(numOfClasses);

            //auto func = PRAGMA_THREADS_FOR {
                for (Nd4jLong i = 0; i < indices->lengthOf(); i++) {
                    auto classNum = indices->e<Nd4jLong>(i);
                    auto current = listOfTensors.at(i);
                    auto currentOut = listOfOutTensors.at(i);
                    auto currentGradOut = listOfGradOuts.at(classNum);
                    auto currentFFOut = listOfBPTensors.at(classNum);

                    currentOut->assign((*currentFFOut) * (*currentGradOut) / (*current));
                }
            //};

            //samediff::Threads::parallel_for(func, 0, indices->lengthOf());
        }

        return ND4J_STATUS_OK;
    }

    // -------------------------------------------------------------------------------------------------------------- //
    // Unsorted backpropagate segment ops
    // -------------------------------------------------------------------------------------------------------------- //

    template <typename T>
    static int unsortedSegmentMaxFunctorBP_(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
        // recompute the forward pass to find which elements produced the segment max
        auto tempRes = gradOut->dup();
        unsortedSegmentMaxFunctor(context, input, indices, numOfClasses, &tempRes);
        if (input->isVector()) {

            for (Nd4jLong e = 0; e < input->lengthOf(); ++e) {
                Nd4jLong classNum = indices->e<Nd4jLong>(e);
                if (sd::math::nd4j_abs(tempRes.e<double>(classNum) - input->e<double>(e)) < 1.e-5)
                    output->p(e, gradOut->e<T>(classNum));
            }
        }
        else {
            auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});

            ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims);
            ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims);
            ResultSet listOfTensors = input->allTensorsAlongDimension(restDims);
            ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims);

            for (Nd4jLong i = 0; i < indices->lengthOf(); i++) {
                Nd4jLong classNum = indices->e<Nd4jLong>(i);
                NDArray* current = listOfTensors.at(i);
                NDArray* currentOut = listOfOutTensors.at(i);
                NDArray* currentGradOut = listOfGradOuts.at(classNum);
                for (Nd4jLong e = 0; e < current->lengthOf(); e++) {
                    if (sd::math::nd4j_abs(listOfBPTensors.at(classNum)->e<double>(e) - current->e<double>(e)) < 1.e-5)
                        currentOut->p(e, currentGradOut->e<T>(e));
                }
            }
        }

        return ND4J_STATUS_OK;
    }

    int unsortedSegmentMaxFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
        BUILD_SINGLE_SELECTOR(output->dataType(), return unsortedSegmentMaxFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), NUMERIC_TYPES);
    }
    BUILD_SINGLE_TEMPLATE(template int unsortedSegmentMaxFunctorBP_, (sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES);

    template <typename T>
    static int unsortedSegmentMinFunctorBP_(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
        auto tempRes = gradOut->dup();
        unsortedSegmentMinFunctor(context, input, indices, numOfClasses, &tempRes);
        if (input->isVector()) {

            auto func = PRAGMA_THREADS_FOR {
                for (auto e = start; e < stop; e++) {
                    auto classNum = indices->e<Nd4jLong>(e);
                    if (sd::math::nd4j_abs(tempRes.t<T>(classNum) - input->t<T>(e)) < 1.e-6)
                        output->t<T>(e) = gradOut->t<T>(classNum);
                }
            };

            samediff::Threads::parallel_for(func, 0, input->lengthOf());
        }
        else {
            auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});

            ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims);
            ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims);
            ResultSet listOfTensors = input->allTensorsAlongDimension(restDims);
            ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims);

            //auto func = PRAGMA_THREADS_FOR {
                for (Nd4jLong i = 0; i < indices->lengthOf(); i++) {
                    auto classNum = indices->e<Nd4jLong>(i);
                    auto current = listOfTensors.at(i);
                    auto currentOut = listOfOutTensors.at(i);
                    auto currentGradOut = listOfGradOuts.at(classNum);

                    for (Nd4jLong e = 0; e < current->lengthOf(); e++) {
                        if (sd::math::nd4j_abs(listOfBPTensors.at(classNum)->t<T>(e) - current->t<T>(e)) < 1.e-6)
                            currentOut->t<T>(e) = currentGradOut->t<T>(e);
                    }
                }
            //};

            //samediff::Threads::parallel_for(func, 0, indices->lengthOf());
        }

        return ND4J_STATUS_OK;
    }

    int unsortedSegmentMinFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
        BUILD_SINGLE_SELECTOR(output->dataType(), return unsortedSegmentMinFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), NUMERIC_TYPES);
    }
    BUILD_SINGLE_TEMPLATE(template int unsortedSegmentMinFunctorBP_, (sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES);

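    // Backprop for unsorted segment mean: gradient per position is the class
    // gradient divided by the class count, matching the sorted variant above.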
    int unsortedSegmentMeanFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {

        MAP_IMPL<Nd4jLong, Nd4jLong> classCount; // segmentId -> number of elements in the segment

        for (Nd4jLong count = 0; count < numOfClasses; ++count) {
            classCount[count] = 0;
        }

        for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) {
            classCount[indices->e<Nd4jLong>(e)]++;
        }

        // if input is a vector (as in the documentation sample):
        if (input->isVector()) {
            for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) {
                Nd4jLong classNum = indices->e<Nd4jLong>(e);
                output->p(e, gradOut->e<double>(classNum) / classCount[classNum]);
            }
        }
        else {
            auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});

            ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims);
            ResultSet listOfTensors = input->allTensorsAlongDimension(restDims);
            ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims);

            for (Nd4jLong i = 0; i < indices->lengthOf(); i++) {
                Nd4jLong classNum = indices->e<Nd4jLong>(i);
                NDArray* current = listOfTensors.at(i);
                NDArray* currentOut = listOfOutTensors.at(i);
                NDArray* currentGradOut = listOfGradOuts.at(classNum);
                currentOut->assign(*currentGradOut / double(classCount[classNum]));
            }
        }
        return ND4J_STATUS_OK;
    }

    int unsortedSegmentSumFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {

        // if input is a vector (as in the documentation sample):
        if (input->isVector()) {
            for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) {
                Nd4jLong classNum = indices->e<Nd4jLong>(e);
                output->p(e, gradOut->e<double>(classNum));
            }
        }
        else {
            auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});

            ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims);
            ResultSet listOfTensors = input->allTensorsAlongDimension(restDims);
            ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims);

            //auto func = PRAGMA_THREADS_FOR {
                for (Nd4jLong i = 0; i < indices->lengthOf(); i++) {
                    auto classNum = indices->e<Nd4jLong>(i);
                    auto currentOut = listOfOutTensors.at(i);
                    auto currentGradOut = listOfGradOuts.at(classNum);

                    currentOut->assign(currentGradOut);
                }
            //};

            //samediff::Threads::parallel_for(func, 0, indices->lengthOf());
        }
        return Status::OK();
    }

    int unsortedSegmentProdFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {

        auto tempRes = gradOut->dup();
        unsortedSegmentProdFunctor(context, input, indices, numOfClasses, &tempRes);
        if (input->isVector()) {
            auto func = PRAGMA_THREADS_FOR {
                for (auto e = start; e < stop; e++) {
                    auto classNum = indices->e<Nd4jLong>(e);
                    output->p<double>(e, gradOut->e<double>(classNum) * tempRes.e<double>(classNum) / input->e<double>(e));
                }
            };

            samediff::Threads::parallel_for(func, 0, indices->lengthOf());
        }
        else {
            auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});

            ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims);
            ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims);
            ResultSet listOfTensors = input->allTensorsAlongDimension(restDims);
            ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims);

            //auto func = PRAGMA_THREADS_FOR {
                for (Nd4jLong i = 0; i < indices->lengthOf(); i++) {
                    auto classNum = indices->e<Nd4jLong>(i);
                    auto current = listOfTensors.at(i);
                    auto currentOut = listOfOutTensors.at(i);
                    auto currentGradOut = listOfGradOuts.at(classNum);
                    auto currentFFOut = listOfBPTensors.at(classNum);

                    currentOut->assign((*currentFFOut) * (*currentGradOut) / (*current));
                }
            //};

            //samediff::Threads::parallel_for(func, 0, indices->lengthOf());
        }

        return Status::OK();
    }

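    // Backprop for unsorted segment sqrt(n): each position receives its class
    // gradient divided by sqrt(class count).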
    int unsortedSegmentSqrtNFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
        MAP_IMPL<Nd4jLong, Nd4jLong> classCount; // segmentId -> number of elements in the segment

        for (Nd4jLong count = 0; count < numOfClasses; ++count) {
            classCount[count] = 0;
        }

        for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) {
            classCount[indices->e<Nd4jLong>(e)]++;
        }

        // if input is a vector (as in the documentation sample):
        if (input->isVector()) {
            //auto func = PRAGMA_THREADS_FOR {
                for (Nd4jLong e = 0; e < indices->lengthOf(); e++) {
                    auto classNum = indices->e<Nd4jLong>(e);
                    output->p(e, gradOut->e<double>(classNum) / sd::math::nd4j_sqrt<double, double>(classCount[classNum]));
                }
            //};

            //samediff::Threads::parallel_for(func, 0, indices->lengthOf());
        }
        else {
            auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0});

            ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims);
            ResultSet listOfTensors = input->allTensorsAlongDimension(restDims);
            ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims);

            //auto func = PRAGMA_THREADS_FOR {
                for (Nd4jLong i = 0; i < indices->lengthOf(); i++) {
                    auto classNum = indices->e<Nd4jLong>(i);
                    auto current = listOfTensors.at(i);
                    auto currentOut = listOfOutTensors.at(i);
                    auto currentGradOut = listOfGradOuts.at(classNum);

                    for (Nd4jLong e = 0; e < current->lengthOf(); e++) {
                        currentOut->p<double>(e, currentGradOut->e<double>(e) / sd::math::nd4j_sqrt<double, double>(classCount[classNum]));
                    }
                }
            //};

            //samediff::Threads::parallel_for(func, 0, indices->lengthOf());
        }
        return Status::OK();
    }

} // namespace helpers
} // namespace ops
} // namespace sd