/*******************************************************************************
 * Copyright (c) 2020 Konduit, K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
//  @author GS <sgazeos@gmail.com>
//

#include <op_boilerplate.h>
#include <NDArray.h>
#include <execution/Threads.h>
#include <ConstantTadHelper.h>
#include "../triangular_solve.h"

namespace nd4j {
    namespace ops {
        namespace helpers {
            /*
             * lower triangular process for system of linear equations
             * x_1 = b_1/a_1,1
             * x_2 = (b_2 - a_2,1 * x_1) / a_2,2
             * x_3 = (b_3 - a_3,1 * x_1 - a_3,2 * x_2) / a_3,3
             * ...
             * x_M = (b_M - a_M,1 * x_1 - ... a_M,M-1 * x_M-1)/ a_M,M
             *
             * output == x
             * a == leftInput
             * b == rightInput
             *
             * */
            /*
             * Forward substitution for one lower triangular system a * x = b.
             * Every column of b is treated as an independent right-hand side, so all
             * columns of the output are written (not only column 0).
             *
             * leftInput / leftInputShape   - matrix a and the shape info used to address it as {row, col}
             * rightInput / rightInputShape - matrix b and its shape info; its last dimension is the column count
             * adjoint                      - not read by this function
             * output / outputShape         - matrix x and its shape info, addressed with the same {row, col} as b
             * rows                         - number of rows processed (rows of a, b and x)
             *
             * No check is made for a zero diagonal element; the division is performed as is.
             * */
            template <typename T>
            static __device__ void lowerTriangularSolve(T const* leftInput, Nd4jLong const* leftInputShape,
                                                        T const* rightInput, Nd4jLong const* rightInputShape,
                                                        bool const adjoint, T* output, Nd4jLong* outputShape,
                                                        Nd4jLong rows) {

                // number of right-hand-side columns, each one is solved independently
                const Nd4jLong cols = shape::sizeAt(rightInputShape, -1);

                for (Nd4jLong r = 0; r < rows; r++) {
                    // diagonal element a[r, r] is shared by every column of this row
                    Nd4jLong posX[] = {r, r};
                    auto xIndex = shape::getOffset(leftInputShape, posX, 0);

                    for (Nd4jLong j = 0; j < cols; j++) {
                        Nd4jLong posY[] = {r, j};
                        auto yIndex = shape::getOffset(rightInputShape, posY, 0);
                        auto zIndex = shape::getOffset(outputShape, posY, 0);

                        // b[r, j] minus the contribution of the already solved x[0..r-1, j]
                        auto sum = rightInput[yIndex];
                        for (Nd4jLong c = 0; c < r; c++) {
                            Nd4jLong posZ[] = {c, j};
                            Nd4jLong pos[] = {r, c};
                            auto xcIndex = shape::getOffset(leftInputShape, pos, 0);
                            auto zcIndex = shape::getOffset(outputShape, posZ, 0);
                            sum -= leftInput[xcIndex] * output[zcIndex];
                        }
                        output[zIndex] = sum / leftInput[xIndex];
                    }
                }
            }

            /*
             * upper triangular process for system of linear equations
             * x_M = b_M/a_M,M
             * x_M-1 = (b_M-1 - a_M-1,M * x_M) / a_M-1,M-1
             * x_M-2 = (b_M-2 - a_M-2,M-1 * x_M-1 - a_M-2,M * x_M) / a_M-2,M-2
             * ...
             * x_1 = (b_1 - a_1,2 * x_2 - ... a_1,M * x_M)/ a_1,1
             *
             * output == x
             * a == leftInput
             * b == rightInput
             *
             * */

            /*
             * Back substitution for one upper triangular system a * x = b.
             * Every column of b is treated as an independent right-hand side, so all
             * columns of the output are written (not only column 0).
             *
             * leftInput / leftInputShape   - matrix a and the shape info used to address it as {row, col}
             * rightInput / rightInputShape - matrix b and its shape info; its last dimension is the column count
             * adjoint                      - not read by this function
             * output / outputShape         - matrix x and its shape info, addressed with the same {row, col} as b
             * rows                         - number of rows processed (rows of a, b and x)
             *
             * No check is made for a zero diagonal element; the division is performed as is.
             * */
            template <typename T>
            static __device__ void upperTriangularSolve(T const* leftInput, Nd4jLong const* leftInputShape,
                    T const* rightInput, Nd4jLong const* rightInputShape, bool const adjoint, T* output,
                    Nd4jLong* outputShape, Nd4jLong rows) {

                // number of right-hand-side columns, each one is solved independently
                const Nd4jLong cols = shape::sizeAt(rightInputShape, -1);

                // r runs rows..1 so that r - 1 is the current row, walking from the last row upwards
                for (Nd4jLong r = rows; r > 0; r--) {
                    // diagonal element a[r-1, r-1] is shared by every column of this row
                    Nd4jLong posX[] = {r - 1, r - 1};
                    auto xIndex = shape::getOffset(leftInputShape, posX, 0);

                    for (Nd4jLong j = 0; j < cols; j++) {
                        Nd4jLong posY[] = {r - 1, j};
                        auto yIndex = shape::getOffset(rightInputShape, posY, 0);
                        auto zIndex = shape::getOffset(outputShape, posY, 0);

                        // b[r-1, j] minus the contribution of the already solved x[r..rows-1, j]
                        auto sum = rightInput[yIndex];
                        for (Nd4jLong c = r; c < rows; c++) {
                            Nd4jLong posZ[] = {c, j};
                            Nd4jLong pos[] = {r - 1, c};
                            auto zcIndex = shape::getOffset(outputShape, posZ, 0);
                            auto xcIndex = shape::getOffset(leftInputShape, pos, 0);
                            sum -= leftInput[xcIndex] * output[zcIndex];
                        }
                        output[zIndex] = sum / leftInput[xIndex];
                    }
                }
            }

            /*
             * Batched triangular solve kernel.
             * Threads are indexed in x only; each thread walks the batch with a grid-stride loop
             * and solves whole systems by itself, one system per iteration.
             *
             * leftInput / rightInput / output - base buffers of the batched operands
             * leftPartShape                   - shape info of the full left operand; only its size at dimension -2 is read
             * rightPartShape, outputShape     - not read by this kernel
             * lower                           - true selects lowerTriangularSolve, false selects upperTriangularSolve
             * adjoint                         - forwarded untouched to the selected solver
             * tad*Shape / tad*Offset          - per-matrix shape info and per-batch element offsets of each operand
             * batchNum                        - number of systems to solve
             *
             * Uses one statically declared __shared__ Nd4jLong; no dynamic shared memory is referenced.
             * */
            template <typename T>
            static __global__ void triangularSolveKernel(T const* leftInput, Nd4jLong const* leftPartShape,
                    T const* rightInput, Nd4jLong const* rightPartShape, bool const lower, bool const adjoint, T* output,
                    Nd4jLong* outputShape, Nd4jLong* tadLeftShape, Nd4jLong* tadLeftOffset, Nd4jLong* tadRightShape,
                    Nd4jLong* tadRightOffset, Nd4jLong* tadOutputShape, Nd4jLong* tadOutputOffset, Nd4jLong batchNum) {

                // system size: fetched once per block by thread 0, visible to the others after the barrier
                __shared__ Nd4jLong rows;
                if (threadIdx.x == 0)
                    rows = shape::sizeAt(leftPartShape, -2);
                __syncthreads();

                const auto firstBatch = blockIdx.x * blockDim.x + threadIdx.x;
                const auto gridStride = blockDim.x * gridDim.x;

                for (auto batch = firstBatch; batch < batchNum; batch += gridStride) {
                    T const* leftMatrix = leftInput + tadLeftOffset[batch];
                    T const* rightMatrix = rightInput + tadRightOffset[batch];
                    T* solution = output + tadOutputOffset[batch];

                    if (!lower) {
                        upperTriangularSolve<T>(leftMatrix, tadLeftShape, rightMatrix, tadRightShape, adjoint, solution, tadOutputShape, rows);
                        continue;
                    }
                    lowerTriangularSolve<T>(leftMatrix, tadLeftShape, rightMatrix, tadRightShape, adjoint, solution, tadOutputShape, rows);
                }
            }

            /*
             * Typed host-side launcher of triangularSolveKernel.
             * Splits all three operands into sub-arrays over their last two dimensions and
             * launches the kernel on the context's CUDA stream, bracketed by
             * NDArray::prepareSpecialUse / NDArray::registerSpecialUse.
             *
             * context    - launch context providing the CUDA stream
             * leftInput  - batched triangular matrices
             * rightInput - batched right-hand sides
             * lower      - forwarded to the kernel to pick the lower or upper solver
             * adjoint    - forwarded to the kernel
             * output     - batched solutions
             *
             * Returns Status::OK(); the launch result itself is not inspected here.
             * */
            template <typename T>
            static int triangularSolveFunctor_(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput,
                    bool lower, bool adjoint, NDArray* output) {
                NDArray::prepareSpecialUse({output}, {leftInput, rightInput});

                // one sub-array per matrix of the batch, taken over the two trailing dimensions
                auto tadHelper = ConstantTadHelper::getInstance();
                auto leftPack = tadHelper->tadForDimensions(leftInput->getShapeInfo(), {-2, -1});
                auto rightPack = tadHelper->tadForDimensions(rightInput->getShapeInfo(), {-2, -1});
                auto outputPack = tadHelper->tadForDimensions(output->shapeInfo(), {-2, -1});

                auto leftData = reinterpret_cast<T const*>(leftInput->getSpecialBuffer());
                auto rightData = reinterpret_cast<T const*>(rightInput->getSpecialBuffer());
                auto outputData = reinterpret_cast<T*>(output->specialBuffer());

                auto stream = context->getCudaStream();
                triangularSolveKernel<T><<<128, 128, 256, *stream>>>(
                        leftData, leftInput->getSpecialShapeInfo(),
                        rightData, rightInput->getSpecialShapeInfo(),
                        lower, adjoint,
                        outputData, output->specialShapeInfo(),
                        leftPack.specialShapeInfo(), leftPack.specialOffsets(),
                        rightPack.specialShapeInfo(), rightPack.specialOffsets(),
                        outputPack.specialShapeInfo(), outputPack.specialOffsets(),
                        leftPack.numberOfTads());

                NDArray::registerSpecialUse({output}, {leftInput, rightInput});

                return Status::OK();
            }

            /*
             * Public entry point of the triangular solve helper.
             * Selects the triangularSolveFunctor_ instantiation from leftInput->dataType(),
             * restricted to the FLOAT_NATIVE type list, and returns its status.
             *
             * context    - launch context handed to the typed implementation
             * leftInput  - triangular matrices; its data type drives the dispatch
             * rightInput - right-hand sides
             * lower      - true for lower triangular systems, false for upper triangular ones
             * adjoint    - forwarded to the typed implementation
             * output     - array receiving the solutions
             *
             * NOTE(review): behaviour for a data type outside FLOAT_NATIVE is decided by
             * BUILD_SINGLE_SELECTOR, whose expansion is not visible here - confirm.
             * */
            int triangularSolveFunctor(nd4j::LaunchContext * context, NDArray* leftInput, NDArray* rightInput, bool lower, bool adjoint, NDArray* output) {
                BUILD_SINGLE_SELECTOR(leftInput->dataType(), return triangularSolveFunctor_, (context, leftInput, rightInput, lower, adjoint, output), FLOAT_NATIVE);
            }

            /*
             * Transposes the upper triangle (diagonal included) of every matrix of the batch:
             * output[r, c] = input[c, r] for all c <= r. Elements of the output with c > r are not written.
             *
             * Blocks stride over the batch in x; inside a block threadIdx.x strides over rows
             * and threadIdx.y strides over columns.
             *
             * input / output             - base buffers of the batched matrices
             * batchSize                  - number of matrices
             * rows                       - number of rows to process
             * columns                    - not read by this kernel
             * inputTads / outputTads     - per-matrix shape info used to turn {row, col} into an offset
             * inputOffsets/outputOffsets - per-batch element offsets into the base buffers
             * */
            template <typename T>
            static __global__ void upperAdjointKernel(T const* input, T* output,
                    Nd4jLong batchSize, Nd4jLong rows, Nd4jLong columns,
                    Nd4jLong* inputTads, Nd4jLong* inputOffsets, Nd4jLong* outputTads, Nd4jLong* outputOffsets) {

                for (auto batch = blockIdx.x; batch < batchSize; batch += gridDim.x) {
                    T const* source = input + inputOffsets[batch];
                    T* target = output + outputOffsets[batch];

                    for (auto row = threadIdx.x; row < rows; row += blockDim.x) {
                        for (auto col = threadIdx.y; col <= row; col += blockDim.y) {
                            // target position and its mirrored source position
                            Nd4jLong targetPos[] = {row, col};
                            Nd4jLong sourcePos[] = {col, row};
                            target[shape::getOffset(outputTads, targetPos)] = source[shape::getOffset(inputTads, sourcePos)];
                        }
                    }
                }
            }

            /*
             * Transposes the lower triangle (diagonal included) of every matrix of the batch:
             * output[r, c] = input[c, r] for all r <= c < columns. Elements of the output with c < r are not written.
             *
             * Blocks stride over the batch in x; inside a block threadIdx.x strides over rows
             * and threadIdx.y strides over columns, starting at the diagonal.
             *
             * input / output             - base buffers of the batched matrices
             * batchSize                  - number of matrices
             * rows / columns             - extents of the row and column loops
             * inputTads / outputTads     - per-matrix shape info used to turn {row, col} into an offset
             * inputOffsets/outputOffsets - per-batch element offsets into the base buffers
             * */
            template <typename T>
            static __global__ void lowerAdjointKernel(T const* input, T* output,
                         Nd4jLong batchSize, Nd4jLong rows, Nd4jLong columns,
                         Nd4jLong* inputTads, Nd4jLong* inputOffsets, Nd4jLong* outputTads, Nd4jLong* outputOffsets) {

                for (auto batch = blockIdx.x; batch < batchSize; batch += gridDim.x) {
                    T const* source = input + inputOffsets[batch];
                    T* target = output + outputOffsets[batch];

                    for (auto row = threadIdx.x; row < rows; row += blockDim.x) {
                        for (auto col = row + threadIdx.y; col < columns; col += blockDim.y) {
                            // target position and its mirrored source position
                            Nd4jLong targetPos[] = {row, col};
                            Nd4jLong sourcePos[] = {col, row};
                            target[shape::getOffset(outputTads, targetPos)] = source[shape::getOffset(inputTads, sourcePos)];
                        }
                    }
                }
            }

            /*
             * Typed host-side launcher of the triangular adjoint kernels.
             * Splits input and output into sub-arrays over their last two dimensions and launches
             * lowerAdjointKernel when `lower` is true, upperAdjointKernel otherwise, on the
             * context's CUDA stream.
             *
             * context - launch context providing the CUDA stream
             * input   - batched triangular matrices to transpose
             * lower   - selects which triangle of the input is transposed
             * output  - batched result; only the transposed triangle is written by the kernels
             *
             * The launch is bracketed by NDArray::prepareSpecialUse / NDArray::registerSpecialUse,
             * the same way triangularSolveFunctor_ brackets its kernel, because this function also
             * reads and writes the arrays' special buffers.
             * NOTE(review): presumably this is what keeps host and device copies coherent - confirm
             * against the NDArray implementation.
             * */
            template <typename T>
            static void adjointTriangularMatrix_(nd4j::LaunchContext* context, NDArray const* input, bool const lower,
                    NDArray* output) {

                NDArray::prepareSpecialUse({output}, {input});

                auto inputTads = ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), {-2, -1});
                auto outputTads = ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), {-2, -1});
                auto stream = context->getCudaStream();
                auto inputBuf = reinterpret_cast<T const*>(input->getSpecialBuffer());
                auto outputBuf = reinterpret_cast<T*>(output->specialBuffer());
                auto rows = input->sizeAt(-2);
                auto columns = input->sizeAt(-1);

                if (lower) {
                    lowerAdjointKernel<T><<<128, 256, 256, *stream>>>(inputBuf, outputBuf, outputTads.numberOfTads(), rows, columns, inputTads.specialShapeInfo(), inputTads.specialOffsets(), outputTads.specialShapeInfo(), outputTads.specialOffsets());
                } else {
                    upperAdjointKernel<T><<<128, 256, 256, *stream>>>(inputBuf, outputBuf, outputTads.numberOfTads(), rows, columns, inputTads.specialShapeInfo(), inputTads.specialOffsets(), outputTads.specialShapeInfo(), outputTads.specialOffsets());
                }

                NDArray::registerSpecialUse({output}, {input});
            }

            /*
             * Public entry point of the triangular adjoint helper.
             * Selects the adjointTriangularMatrix_ instantiation from input->dataType(),
             * restricted to the FLOAT_NATIVE type list.
             *
             * context - launch context handed to the typed implementation
             * input   - triangular matrices; its data type drives the dispatch
             * lower   - forwarded to the typed implementation to choose the lower or upper kernel
             * output  - array receiving the transposed triangle
             *
             * NOTE(review): behaviour for a data type outside FLOAT_NATIVE is decided by
             * BUILD_SINGLE_SELECTOR, whose expansion is not visible here - confirm.
             * */
            void adjointMatrix(nd4j::LaunchContext* context, NDArray const* input, bool const lower, NDArray* output) {
                BUILD_SINGLE_SELECTOR(input->dataType(), adjointTriangularMatrix_, (context, input, lower, output), FLOAT_NATIVE);
            }

        }
    }
}