/* ******************************************************************************
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * See the NOTICE file distributed with this work for additional
 * information regarding copyright ownership.
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
//  @author raver119@gmail.com
//  @author Yurii Shyrma (iuriish@yahoo.com)
//

#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/transforms.h>
#include <array>

namespace sd {
namespace ops {


//////////////////////////////////////////////////////////////////////////
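// concat: joins a variable number of input arrays along a single axis.
//
// Inputs   : 1..N arrays to concatenate; optionally a trailing scalar array that
//            holds the axis (see BArgs[0]).
// IArgs[0] : the concatenation axis (may be negative); used when the axis is not
//            supplied as the trailing input array.
// BArgs[0] : if true, the axis is read from the last input array rather than
//            from IArgs[0].
// Output   : a single array whose extent along the axis is the sum of the
//            inputs' extents; all other dimensions must match across inputs.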
CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 0) {

    REQUIRE_TRUE(block.width() > 0, 0, "CONCAT op: No input arrays were provided");

    const bool isAxisInLastArr = block.getBArguments()->size() > 0 && B_ARG(0);

    const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width();

    // first, account for the possible presence of empty arrays;
    // any scalar input is promoted to a vector of length 1
    std::vector<const NDArray*> nonEmptyArrs;
    std::vector<int> arrsToDelete;
    int index = 0;
    bool allOfSameType = true;
    auto rankOfFirstArr = block.width() > 0 ? INPUT_VARIABLE(0)->rankOf() : 0;
    auto typeOfFirstArr = block.width() > 0 ? INPUT_VARIABLE(0)->dataType() : block.dataType();

    for(int i = 0; i < numOfInArrs; ++i) {
        auto input = INPUT_VARIABLE(i);
        auto currentRank = input->rankOf();

// TODO: the following two lines match the current tf.concat spec; they are commented out for compatibility with legacy behaviour
//        REQUIRE_TRUE(currentRank > 0, 0, "Rank of input variable %i must be greater than 0, but is %lld instead.", i, currentRank);
//        REQUIRE_TRUE(rankOfFirstArr == currentRank, 0, "Number of dimensions must match for all concat inputs, but input %i has rank %lld while the first input has rank %lld.", i, currentRank, rankOfFirstArr);
        if(!input->isEmpty()) {

            allOfSameType &= (typeOfFirstArr == input->dataType());

            if(input->rankOf() == 0) {
                auto vec = new NDArray('c', {1}, input->dataType(), block.launchContext());
                vec->assign(input);
                nonEmptyArrs.push_back(vec);
                arrsToDelete.push_back(index);
            }
            else {
                nonEmptyArrs.push_back(input);
            }
            ++index;
        }
    }

    const int numOfNonEmptyArrs = nonEmptyArrs.size();

    if(numOfNonEmptyArrs == 0) {
        // all inputs are empty arrays -> return an empty output; mainly for TF import compatibility (no-op)
        REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "CONCAT op: If all input variables are empty, output must be empty");
        return Status::OK();
    }

    const int rank = nonEmptyArrs[0]->rankOf();                     // deduced from the first non-empty array
    int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e<int>(0) : INT_ARG(0);
    if(axis < 0){
        axis += rank;
    }
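    // example: with rank == 3, a negative axis of -1 normalizes to axis = 2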

    // ******** input validation ******** //
    REQUIRE_TRUE(allOfSameType, 0, "CONCAT op: all input arrays must have the same type!");
    REQUIRE_TRUE(nonEmptyArrs[0]->dataType() == OUTPUT_VARIABLE(0)->dataType(), 0, "CONCAT op: the output array must have the same type as the input arrays!");
    REQUIRE_TRUE(0 <= axis && (axis < rank || (axis == 0 && rank == 0)), 0, "CONCAT op: input axis must be in range [0, %i], but got %i instead!", rank-1, axis);

    for(int i = 1; i < numOfNonEmptyArrs; ++i)
        REQUIRE_TRUE(nonEmptyArrs[i]->rankOf() == rank, 0, "CONCAT op: all input arrays must have the same rank!");

    for(int i = 1; i < numOfNonEmptyArrs; ++i) {
        for(int dim = 0; dim < rank; ++dim)
            if(dim != axis)
                REQUIRE_TRUE(nonEmptyArrs[i]->sizeAt(dim) == nonEmptyArrs[0]->sizeAt(dim), 0, "CONCAT op: all input arrays must have the same dimensions (except along the concatenation axis)!");
    }
    // ******** end of input validation ******** //

    auto output = OUTPUT_VARIABLE(0);

    if(numOfNonEmptyArrs == 1)
        output->assign(nonEmptyArrs[0]);
    else
        helpers::concat(block.launchContext(), nonEmptyArrs, *output, axis);

    // free the length-1 vectors that were allocated for scalar inputs
    for(int index : arrsToDelete)
        delete nonEmptyArrs[index];

    return Status::OK();
}
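
// A minimal usage sketch (as a comment; this mirrors how the test suite
// typically drives ops and assumes NDArrayFactory and DeclarableOp::evaluate):
//
//     auto x = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
//     auto y = NDArrayFactory::create<float>('c', {2, 2}, {5, 6, 7, 8});
//
//     sd::ops::concat op;
//     auto result = op.evaluate({&x, &y}, {}, {0});      // axis = 0 via IArgs[0]
//     // result.at(0) has shape {4, 2}: the rows of x followed by the rows of y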

DECLARE_SYN(ParallelConcat, concat);
DECLARE_SYN(concat_v2, concat);
DECLARE_SYN(concatv2, concat);

DECLARE_TYPES(concat) {
    getOpDescriptor()
            ->setAllowedInputTypes(sd::DataType::ANY);
            // ->setSameMode(true);
}

//////////////////////////////////////////////////////////////////////////
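// Shape inference: concatenating, e.g., shapes {2, 3} and {2, 5} along axis 1
// yields {2, 8}; every dimension other than the axis must already agree.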
DECLARE_SHAPE_FN(concat) {

    REQUIRE_TRUE(block.width() > 0, 0, "CONCAT op: No input arrays were provided");

    const bool isAxisInLastArr = block.getBArguments()->size() > 0 && B_ARG(0);

    const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width();

    // first, account for the possible presence of empty arrays;
    // for any scalar input, substitute the shape of a vector with length 1
    ShapeList arrShapes;
    for(int i = 0; i < numOfInArrs; ++i) {

        if(inputShape->at(i)[0] == 0) {     // rank 0 -> an empty array or a scalar
            if (shape::isEmpty(inputShape->at(i)))
                arrShapes.push_back(ConstantShapeHelper::getInstance().vectorShapeInfo(0, INPUT_VARIABLE(0)->dataType()));
            else
                arrShapes.push_back(ConstantShapeHelper::getInstance().vectorShapeInfo(1, INPUT_VARIABLE(0)->dataType()));
        }
        else {
            arrShapes.push_back(inputShape->at(i));
        }
    }

    const int numOfArrs = arrShapes.size();     // every input contributed a shape, including the empty ones

    const int rank = shape::rank(arrShapes.at(0));

    int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e<int>(0) : INT_ARG(0);
    if(axis < 0){
        axis += rank;
    }

    // ******** input validation ******** //
    REQUIRE_TRUE(0 <= axis && axis < rank, 0, "CONCAT op: input axis must be in range [0, %i], but got %i instead!", rank-1, axis);

    for(int i = 1; i < numOfArrs; ++i)
        REQUIRE_TRUE(shape::rank(arrShapes.at(i)) == rank, 0, "CONCAT op: all input arrays must have the same rank!");

    for(int i = 1; i < numOfArrs; ++i) {
        for(int dim = 0; dim < rank; ++dim)
            if(dim != axis)
                REQUIRE_TRUE(arrShapes.at(i)[dim+1] == arrShapes.at(0)[dim+1], 0, "CONCAT op: all input arrays must have the same dimensions (except along the concatenation axis)!");
    }
    // ******** end of input validation ******** //


    Nd4jLong* outShapeInfo = nullptr;
    COPY_SHAPE(arrShapes.at(0), outShapeInfo);

    // special case: a single input array
    if(numOfArrs == 1) {
        ShapeUtils::updateStridesAndType(outShapeInfo, arrShapes.at(0), shape::order(arrShapes.at(0)));
        return SHAPELIST(CONSTANT(outShapeInfo));
    }

    for(int i = 1; i < numOfArrs; ++i)
        outShapeInfo[axis + 1] += arrShapes.at(i)[axis + 1];    // sum the extents along the concatenation axis

    ShapeUtils::updateStridesAndType(outShapeInfo, arrShapes.at(0), shape::order(arrShapes.at(0)));

    auto result = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(outShapeInfo));
    RELEASE(outShapeInfo, block.getWorkspace());
    return SHAPELIST(result);
}

//////////////////////////////////////////////////////////////////////////
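// concat_bp: the gradient of concat. The incoming gradient (epsilonNext, the
// last data input) is split back into one chunk per forward input by slicing
// it along the concatenation axis.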
CUSTOM_OP_IMPL(concat_bp, -1, -1, false, 0, 0) {

    const bool isAxisInLastArr = block.getBArguments()->size() > 0 && B_ARG(0);

    const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width();

    auto epsilonNext = INPUT_VARIABLE(numOfInArrs - 1);     // the last data input is the gradient w.r.t. the concat output

    const int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e<int>(0) : (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + INPUT_VARIABLE(0)->rankOf());

    int startPos = 0;

    for (int e = 0; e < numOfInArrs - 1; e++) {
        auto originalChunk = INPUT_VARIABLE(e);
        auto epsilonChunk = OUTPUT_VARIABLE(e);
        std::vector<Nd4jLong> indices(2 * epsilonNext->rankOf());

        int width = originalChunk->sizeAt(axis);

        // select the window [startPos, startPos + width) along the concat axis;
        // the pair {0, 0} keeps the whole extent of every other dimension
        for (int d = 0; d < epsilonNext->rankOf(); d++) {
            if (d == axis)
                indices[2*d + 1] = (indices[2*d] = startPos) + width;
            else
                indices[2*d + 1] = indices[2*d] = 0;
        }

        auto subarray = (*epsilonNext)(indices, true);
        epsilonChunk->assign(subarray);

        startPos += width;
    }

    return Status::OK();
}
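
// Worked example: for forward inputs of shape {2, 3} and {2, 5} concatenated
// along axis 1, epsilonNext has shape {2, 8}; the first output gradient is the
// column window [0, 3) of epsilonNext and the second is the window [3, 8).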

DECLARE_TYPES(concat_bp) {
    getOpDescriptor()
            ->setAllowedInputTypes(sd::DataType::ANY)
            ->setAllowedOutputTypes({ALL_FLOATS});
}
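
// The gradient w.r.t. each forward input has exactly that input's shape, so the
// shape function below simply echoes the input shapes (skipping the trailing
// axis array when one is present).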

DECLARE_SHAPE_FN(concat_bp) {

    const bool isAxisInLastArr = block.getBArguments()->size() > 0 && B_ARG(0);

    const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width();

    auto shapeList = SHAPELIST();

    for (int e = 0; e < numOfInArrs - 1; e++) {
        auto inShape = inputShape->at(e);
        shapeList->push_back(ConstantShapeHelper::getInstance().createShapeInfo(
                ShapeDescriptor(ArrayOptions::dataType(inShape), shape::order(inShape), shape::shapeOf(inShape), shape::rank(inShape))));
    }

    return shapeList;
}


}  // namespace ops
}  // namespace sd