// implementation of operations for Simple Recurrent Unit: arXiv:1709.02755v2 [cs.CL] 12 Sep 2017
//  @author Yurii Shyrma, created on 05.12.2017

#include <array/NDArrayFactory.h>
#include <helpers/MmulHelper.h>
#include <execution/Threads.h>

namespace sd    {
namespace ops     {
namespace helpers {

// Returns a new array holding transform::Tanh applied to arr.
//   arr - input array; it is only used as the source of the transform, the output goes to a separate array
// The const_cast is required because applyTransform is called through a non-const reference here.
static FORCEINLINE NDArray activation(const NDArray& arr) {

    // output target built from arr and its launch context
    // NOTE(review): the `false` flag presumably means "do not copy arr's values" - confirm against the NDArray constructor
    auto result = NDArray(&arr, false, arr.getContext());
    (const_cast<NDArray&>(arr)).applyTransform(transform::Tanh, result);
    return result;
}

// Returns the result of NDArray::transform(transform::Sigmoid) called on arr.
//   arr - input array; the const_cast is only there because transform is invoked through a non-const reference
static FORCEINLINE NDArray sigmoid(const NDArray& arr) {
    return (const_cast<NDArray&>(arr)).transform(transform::Sigmoid);
}

// Evaluates one SRU step: computes the cell state c and the cell output h for a single time step.
//   context - launch context, not used inside this function
//   x   input [bS x inSize], bS - batch size, inSize - number of features
//   c0  previous cell state c  [bS x inSize], that is at previous time step t-1
//   w   weights [inSize x 3*inSize]
//   b   biases [2*inSize]
//   h   (output) current cell output [bS x inSize], that is at current time step t
//   c   (output) current cell state  [bS x inSize], that is at current time step t
void sruCell(sd::LaunchContext * context, const NDArray* x, const NDArray* c0, const NDArray* w, const NDArray* b, NDArray* h, NDArray* c) {

    const int inSize = x->sizeAt(1);           // inSize - number of features

    // single matrix product for all three projections, columns are laid out as [Wc | Wf | Wr]
    auto z = mmul(*x, *w);               //  [bS x 3*inSize]

    // forget gate = sigmoid(x*Wf + bf)
    auto f = sigmoid(z({0,0, inSize,   2*inSize}) + (*b)({0, inSize}));

    // reset gate = sigmoid(x*Wr + br)
    auto r = sigmoid(z({0,0, 2*inSize, 3*inSize}) + (*b)({inSize, 2*inSize}));

    // ◦ means element-wise product or so called Hadamard product
    // current cell state = f◦c0 + (1 - f)◦(x*Wc)
    // (algebraically the same as f◦(c0 - x*Wc) + x*Wc)
    c->assign(f * (*c0) + (1.f - f) * z({0, 0 ,0, inSize}) );

    // current cell output = r◦activation(c) + (1 - r)◦x
    // (algebraically the same as r◦(activation(c) - x) + x)
    h->assign( r * activation(*c) + (1.f - r) * (*x) );
}

// Runs sruCell over every time step of x, feeding the state produced at step t into step t+1.
//   context - launch context, forwarded to sruCell
//   x   input [bS x inSize x time]
//   c0  initial cell state  (at time step = 0) [bS x inSize],
//   w   weights, [3*inSize x inSize]
//   b   biases,  [2*inSize]
//   h   (output) cell outputs [bS x inSize x time]
//   c   (output) cell states  [bS x inSize x time]
void sruTimeLoop(sd::LaunchContext * context, const NDArray* x, const NDArray* c0, const NDArray* w, const NDArray* b, NDArray* h, NDArray* c) {

    auto wT = w->transpose();                             // [3*inSize x inSize] -> [inSize x 3*inSize]

    const int time  = x->sizeAt(2);

    // running "previous cell state": starts as a copy of c0 and is refreshed after every step
    NDArray ct_1(*c0);

    // loop through time steps
    for (int t = 0; t < time; ++t) {

        // slices of x, h and c at time step t
        // NOTE(review): ht and ct are presumably views, so that sruCell writes straight into h and c - confirm
        auto xt = (*x)({0,0, 0,0, t,t+1});
        auto ht = (*h)({0,0, 0,0, t,t+1});
        auto ct = (*c)({0,0, 0,0, t,t+1});

        helpers::sruCell(context, &xt, &ct_1, &wT, b,  &ht, &ct);

        // carry the state just computed over to the next time step;
        // without this every step would be evaluated against the initial state c0
        ct_1.assign(ct);
    }
}

// Forward pass of the bidirectional SRU for element type T.
// Columns [0, K) of the last dimension are processed from the first time step to the last one,
// columns [K, 2*K) are processed from the last time step to the first one.
//   x     input 3d tensor [time x bS x 2*K], time - number of time steps, bS - batch size, K - number of features;
//         when mask is given x is multiplied by it IN PLACE
//   w     2d tensor of weights [2*K x 6*K]
//   b     row of biases with twice length [4*K]
//   c0    2d tensor of initial state [bS x 2*K] at time t=0
//   mask  optional (may be nullptr), 2d tensor of dropout mask [bS x 2*K]
//   ht    (output) [time x bS x 2*K]
//   ct    (output) [time x bS x 2*K]
// NOTE(review): the raw-pointer indexing below appears to assume dense buffers laid out exactly as the shapes above - confirm with callers
template <typename T>
static void sruBI_(NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* mask, NDArray* ht, NDArray* ct) {

    const Nd4jLong time = x->sizeAt(0);                     // time - number of time steps
    const Nd4jLong bS   = x->sizeAt(1);                     // bS - batch size
    const Nd4jLong K    = x->sizeAt(2) / 2;                 // K - number of features

    //  x = x * mask, only when a mask was actually supplied (mask is optional and may be nullptr)
    if (mask)
        x->applyBroadcast(broadcast::Multiply, {1, 2}, *mask, *x);             // apply mask

    // U = x * w
    NDArray wi = mmul(*x, *w);                    //  U [time x bS x 6*K]

    const Nd4jLong d2      = 2*K;                 // size of the last dimension of x
    const Nd4jLong ncols   = bS*d2;               // number of elements in one time step of x / ht / ct
    const Nd4jLong ncolsWi = 3*ncols;             // number of elements in one time step of wi

    T* pI    = x->bufferAsT<T>();
    T* pWi   = wi.bufferAsT<T>();
    T* pBias = const_cast<NDArray*>(b)->bufferAsT<T>();
    T* pInit = const_cast<NDArray*>(c0)->bufferAsT<T>();
    T* pMask = mask ? const_cast<NDArray*>(mask)->bufferAsT<T>() : nullptr;
    T* pHt   = ht->bufferAsT<T>();
    T* pCt   = ct->bufferAsT<T>();

    // every (batch, feature) column is evaluated independently over the whole time range
    auto func = PRAGMA_THREADS_FOR {
        for (auto col = start; col < stop; col++) {
            const auto colNum = col % d2;
            bool flip = colNum >= K;                        // second half of the features runs backwards in time
            T maskVal = mask ? *(pMask + col) : T(1);
            T cur = *(pInit + col);                         // running cell state, starts from c0
            T bF = *(pBias + colNum);                       // forget gate bias
            T bR = *(pBias + colNum + d2);                  // reset gate bias
            T *pWiVal = pWi + 3 * col;                      // three consecutive wi values belong to one column
            T *pIVal = pI + col;
            T *pHtVal = pHt + col;
            T *pCtVal = pCt + col;

            // the reversed direction starts at the last time step
            if (flip) {
                const auto step = (time - 1) * ncols;
                pIVal += step;
                pHtVal += step;
                pCtVal += step;
                pWiVal += (time - 1) * ncolsWi;
            }

            // signed strides between two consecutive time steps for the chosen direction
            auto ncolsRev = flip ? -ncols : ncols;
            auto ncolsWiRev = flip ? -ncolsWi : ncolsWi;

            for (Nd4jLong t = 0; t < time; ++t) {
                // evaluate sigmoids
                T ft = (1.) / (1. + sd::math::nd4j_exp<T, T>(-(pWiVal[1] + bF)));
                T rt = (1.) / (1. + sd::math::nd4j_exp<T, T>(-(pWiVal[2] + bR)));

                // c_t = f_t * c_{t-1} + (1 - f_t) * wi0
                cur = (cur - *pWiVal) * ft + *pWiVal;
                *pCtVal = cur;
                T val = sd::math::nd4j_tanh<T, T>(cur);
                // h_t = r_t * (tanh(c_t) * mask) + (1 - r_t) * x_t
                *pHtVal = (val * maskVal - *pIVal) * rt + *pIVal;

                pIVal += ncolsRev;
                pWiVal += ncolsWiRev;
                pCtVal += ncolsRev;
                pHtVal += ncolsRev;
            }
        }
    };

    samediff::Threads::parallel_tad(func, 0, ncols);
}

// Backward pass of the bidirectional SRU for element type T.
// Each column walks the time axis in the direction opposite to the one used by the forward pass:
// columns [0, K) go from the last time step to the first one, columns [K, 2*K) from the first to the last.
//   x  input 3d tensor [time x bS x 2*K], time - number of time steps, bS - batch size, K - number of features;
//      when mask is given x is multiplied by it IN PLACE, and x is left permuted to [time x 2*K x bS] on return
//   w  2d tensor of weights [2*K x 6*K]
//   b  row of biases with twice length [4*K]
//   c0 2d tensor of initial state [bS x 2*K] at time t=0
//   ct [time x bS x 2*K]
//   inGradC0 [bS x 2*K]
//   inGradHt  [time x bS x 2*K]
//   mask optional (may be nullptr),  2d tensor of dropout mask [bS x 2*K]
//   gradI  (output) [time x bS x 2*K]
//   gradW  (output) [time x 2*K x 6*K]
//   gradB  (output) [4*K]
//   gradC0 (output) [bS x 2*K]
// NOTE(review): the raw-pointer indexing below appears to assume dense buffers laid out exactly as the shapes above - confirm with callers
template <typename T>
static void sruBIBP_(NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* ct, const NDArray* inGradC0, const NDArray* inGradHt, const NDArray* mask,
                     NDArray* gradI, NDArray* gradW, NDArray* gradB, NDArray* gradC0) {

    const Nd4jLong time   = x->sizeAt(0);                     // time - number of time steps
    const Nd4jLong bS     = x->sizeAt(1);                     // bS - batch size
    const Nd4jLong K = x->sizeAt(2) / 2;                      // K - number of features

    //  x = x * mask, only when a mask was actually supplied (mask is optional and may be nullptr)
    if (mask)
        x->applyBroadcast(broadcast::Multiply, {1, 2}, *mask, *x);             // apply mask

    // U = x * w
    NDArray wi = mmul(*x, *w);                    //  [time x bS x 2*K] * [2*K x 6*K] = [time x bS x 6*K]
    // per-batch bias gradients, summed over the batch dimension at the end
    NDArray gradBias(x->ordering(), {bS, 4*K}, x->dataType(), x->getContext());
    // gradient with respect to wi, later multiplied with x to obtain gradW
    NDArray gradWi  (x->ordering(), {time, bS, 6*K}, x->dataType(), x->getContext());

    const Nd4jLong d2      = 2*K;                 // size of the last dimension of x
    const Nd4jLong ncols   = bS*d2;               // number of elements in one time step of x / ct / gradI
    const Nd4jLong ncolsWi = 3*ncols;             // number of elements in one time step of wi / gradWi
    T* pInput     = x->bufferAsT<T>();
    T* pWi        = wi.bufferAsT<T>();
    T* pBias      = const_cast<NDArray*>(b)->bufferAsT<T>();
    T* pInit      = const_cast<NDArray*>(c0)->bufferAsT<T>();
    T* pMask      = mask ? const_cast<NDArray*>(mask)->bufferAsT<T>() : nullptr;
    T* pState     = const_cast<NDArray*>(ct)->bufferAsT<T>();
    T* pInGradCt  = const_cast<NDArray*>(inGradC0)->bufferAsT<T>();
    T* pInGradHt  = const_cast<NDArray*>(inGradHt)->bufferAsT<T>();
    T* pGradWi    = gradWi.bufferAsT<T>();
    T* pGradInput = gradI->bufferAsT<T>();
    T* pGradBias  = gradBias.bufferAsT<T>();
    T* pGradInit  = gradC0->bufferAsT<T>();

    // every (batch, feature) column is evaluated independently over the whole time range
    auto func = PRAGMA_THREADS_FOR {
        for (auto col = start; col < stop; col++) {
            T gbF = 0.f;                                    // accumulated gradient of the forget gate bias
            T gbR = 0.f;                                    // accumulated gradient of the reset gate bias
            const auto colNum = col % d2;
            const bool flip = colNum >= K;
            T maskVal = mask ? *(pMask + col) : T(1.);
            T cur = *(pInGradCt + col);                     // running gradient wrt the cell state
            T bF = *(pBias + colNum);
            T bR = *(pBias + colNum + d2);
            T *pWiVal = pWi + 3 * col;
            T *pInputVal = pInput + col;
            T *pStateVal = pState + col;
            T *pInGradHtVal = pInGradHt + col;
            T *pGradWiVal = pGradWi + 3 * col;
            T *pGradInputVal = pGradInput + col;

            // columns that ran forward in time during the forward pass start at the last time step here
            if (!flip) {
                const auto stepI = (time - 1) * ncols;
                const auto stepW = (time - 1) * ncolsWi;
                pInputVal += stepI;
                pStateVal += stepI;
                pInGradHtVal += stepI;
                pGradInputVal += stepI;
                pWiVal += stepW;
                pGradWiVal += stepW;
            }

            // strides of the forward direction; they are subtracted below to walk backwards
            Nd4jLong ncolsRev = flip ? -ncols : ncols;
            Nd4jLong ncolsWiRev = flip ? -ncolsWi : ncolsWi;

            for (Nd4jLong t = 0; t < time; ++t) {
                // evaluate sigmoids
                T ft = ((T) 1.) / ((T) 1. + sd::math::nd4j_exp<T, T>(-(*(pWiVal + 1) + bF)));
                T rt = ((T) 1.) / ((T) 1. + sd::math::nd4j_exp<T, T>(-(*(pWiVal + 2) + bR)));

                T val = sd::math::nd4j_tanh<T, T>(*pStateVal);
                // state of the preceding forward step; on the last iteration of this loop that is the initial state c0
                T prevVal = (t < time - 1) ? (*(pStateVal - ncolsRev)) : (*(pInit + col));
                // grad wrt input
                *pGradInputVal = *pInGradHtVal - (*pInGradHtVal) * rt;
                // grad wrt rt, wiR and bR
                T grt = (*pInGradHtVal) * (val * maskVal - *pInputVal) * (rt - rt * rt);
                *(pGradWiVal + 2) = grt;
                gbR += grt;
                // grad wrt state
                T gradStateVal = (*pInGradHtVal) * maskVal * (rt - rt * val * val) + cur;
                // grad wrt wi0
                *pGradWiVal = gradStateVal - gradStateVal * ft;
                // grad wrt ft, wi1, and bF
                T gft = gradStateVal * (prevVal - *pWiVal) * (ft - ft * ft);
                *(pGradWiVal + 1) = gft;
                gbF += gft;
                // grad wrt c_previous
                cur = gradStateVal * ft;

                pInputVal -= ncolsRev;
                pWiVal -= ncolsWiRev;
                pStateVal -= ncolsRev;
                pGradWiVal -= ncolsWiRev;
                pGradInputVal -= ncolsRev;
                pInGradHtVal -= ncolsRev;
            }
            // forget gate bias gradients fill the first ncols entries of gradBias, reset gate ones the next ncols
            *(pGradBias + col) = gbF;
            *(pGradBias + col + ncols) = gbR;
            *(pGradInit + col) = cur;
        }
    };

    samediff::Threads::parallel_tad(func, 0, ncols);

    // gradB
    gradBias.reduceAlongDimension(reduce::Sum, *gradB, {0});    // [4*K]

    // gradW
    // NOTE: x is permuted in place and not restored afterwards
    x->permutei({0, 2, 1});                                            // [time x bS x 2*K] -> [time x 2*K x bS]
    MmulHelper::mmul(x, &gradWi, gradW, 1., 0.);                       // [time x 2*K x bS ] * [time x bS x 6*K] = [time x 2*K x 6*K]
}

// Entry point of the bidirectional SRU forward pass.
// Selects the sruBI_ instantiation from x->dataType() among FLOAT_TYPES and forwards all array arguments to it;
// context is accepted for interface uniformity and is not forwarded.
void sruBI(sd::LaunchContext * context, NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* mask, NDArray* ht, NDArray* ct) {
    BUILD_SINGLE_SELECTOR(x->dataType(), sruBI_, (x, w, b, c0, mask, ht, ct), FLOAT_TYPES);
}
// Entry point of the bidirectional SRU backward pass.
// Selects the sruBIBP_ instantiation from x->dataType() among FLOAT_TYPES and forwards all array arguments to it;
// context is accepted for interface uniformity and is not forwarded.
void sruBIBP(sd::LaunchContext * context, NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* ct, const NDArray* inGradC0, const NDArray* inGradH, const NDArray* mask, NDArray* gradI, NDArray* gradW, NDArray* gradB, NDArray* gradC0) {
    BUILD_SINGLE_SELECTOR(x->dataType(), sruBIBP_, (x, w, b, c0, ct, inGradC0, inGradH, mask, gradI, gradW, gradB, gradC0), FLOAT_TYPES);
}

// NOTE(review): presumably these macros emit explicit instantiations of sruBI_ and sruBIBP_ for every type in FLOAT_TYPES,
// matching the set of types the BUILD_SINGLE_SELECTOR dispatchers above can pick - confirm against the macro definitions.
BUILD_SINGLE_TEMPLATE(template void sruBI_,   (NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* mask, NDArray* ht, NDArray* ct), FLOAT_TYPES);
BUILD_SINGLE_TEMPLATE(template void sruBIBP_, (NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* ct, const NDArray* inGradC0, const NDArray* inGradH, const NDArray* mask, NDArray* gradI, NDArray* gradW, NDArray* gradB, NDArray* gradC0), FLOAT_TYPES);


