/*******************************************************************************
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author Yurii Shyrma (iuriish@yahoo.com)
//

// implementation of LSTM layer helpers (single time step cell and time loop) with optional peephole connections, following:
// http://www.bioinf.jku.at/publications/older/2604.pdf
// S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.
// and
// https://research.google.com/pubs/archive/43905.pdf
// Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent neural network architectures for large scale acoustic modeling." INTERSPEECH, 2014.


#include <ops/declarable/helpers/lstmLayer.h>

#include <memory>

#include <helpers/ShapeUtils.h>
// #include <VariableSpace.h>
// #include <ops/declarable/CustomOperations.h>
// #include<ops/declarable/helpers/transforms.h>
// #include <ops/declarable/helpers/legacy_helpers.h>
// #include <array/NDArrayList.h>
// #include <iterator>
// #include <MmulHelper.h>

namespace nd4j 	  {
namespace ops 	  {
namespace helpers {


//////////////////////////////////////////////////////////////////////////
void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
                   const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
                   const std::vector<float>& params,
                   NDArray* h, NDArray* c) {


    /************************ THIS IS NOT OPTIMAZED CODE ***********************************/
    /** the objective is to provide math-readable code **/

    // equations (no peephole connections)
    // it  = σ(Wxi * xt  +  Wri * ht-1  +  bi)
    // ft  = σ(Wxf * xt  +  Wrf * ht-1  +  bf)
    // c't = tanh(Wxc * xt  +  Wrc * ht-1  +  bc)
    // ct  = ft ◦ ct-1 + it ◦ c't
    // ot  = σ(Wxo * xt  +  Wro * ht-1  +  bo)
    // ht  = ot ◦ tanh(ct)

    // equations (peephole connections are present)
    // it  = σ(Wxi * xt  +  Wri * ht-1  +  Wpi ◦ ct-1  +  bi)
    // ft  = σ(Wxf * xt  +  Wrf * ht-1  +  Wpf ◦ ct-1  +  bf)
    // c't = tanh(Wxc * xt  +  Wrc * ht-1  +  bc)
    // ct  = ft ◦ ct-1 + it ◦ c't
    // ot  = σ(Wxo * xt  +  Wro * ht-1  +  Wpo ◦ ct   +  bo)
    // ht  = ot ◦ tanh(ct)


    // IDs for activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus

    // params[0] - dataFormat, ignore
    // params[1] - directionMode, ignore
    // params[2] - cell clipping value, if it = 0 then do not apply clipping

    // params[3]  - activation ID for input (i), forget (f) and output (o) gates
    // params[4]  - alpha value for gates activation
    // params[5]  - beta value for gates activation

    // params[6]  - activation ID for cell state (c)
    // params[7]  - alpha value for cell state activation
    // params[8]  - beta value for cell state activation

    // params[9]  - activation ID for output (h)
    // params[10] - alpha value for output activation
    // params[11] - beta value for output activation

    // INPUTS:
    // x  - current input at time t, [bS, nIn] or [nIn] if seqLen != nullptr
    // Wx - input weights [nIn, 4*nOut]
    // Wr - recurrent weights [nOut, 4*nOut]
    // b  - biases [4*nOut], optional, may be nullptr
    // hI - previous (initial) output at time t-1, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr
    // cI - previous (initial) cell state at time t-1, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr
    // Wp - peephole weights [3*nOut], optional, may be nullptr

    // OUTPUTS:
    // h - current output, that is at current time step t, [bS, nOut] or [nOut] if seqLen != nullptr
    // c - current cell state,  that is at current time step t, [bS, nOut] or [nOut] if seqLen != nullptr

    // !!! dimension 4*nOut implies order it, ft, c't, ot
    // !!! dimension 3*nOut implies order it, ft, ot

    const Nd4jLong nOut = Wx->sizeAt(-1) / 4;

    auto z = mmul(*x, *Wx) + mmul(*hI, *Wr);    //   [bs, nIn] * [nIn, 4*nOut] + [bs, nOut] * [nOut, 4*nOut] = [bS, 4*nOut]
                                                //or [nIn] * [nIn, 4*nOut] + [nOut] * [nOut, 4*nOut] = [4*nOut]

    // add biases if they are given
    if(b != nullptr)
        z += *b;                                    // broadcast [bS, 4*nOut] + [4*nOut] = [bS, 4*nOut]

    auto zi = x->rankOf() == 1 ? z({0,        nOut}) : z({0,0, 0,        nOut});         // input gate it, [bS, nOut]
    auto zf = x->rankOf() == 1 ? z({nOut,   2*nOut}) : z({0,0, nOut,   2*nOut});         // forget gate ft, [bS, nOut]
    auto zc = x->rankOf() == 1 ? z({2*nOut, 3*nOut}) : z({0,0, 2*nOut, 3*nOut});         // cell gate c't, [bS, nOut]
    auto zo = x->rankOf() == 1 ? z({3*nOut, 4*nOut}) : z({0,0, 3*nOut, 4*nOut});         // output gate ot, [bS, nOut]

    // peephole connections for input and forget gates
    if(Wp != nullptr) {
        zi += *cI * (*Wp)({0,      nOut});      // broadcast: [bS, nOut] + [bS, nOut] ◦ [nOut] = [bS, nOut]
        zf += *cI * (*Wp)({nOut, 2*nOut});      // broadcast: [bS, nOut] + [bS, nOut] ◦ [nOut] = [bS, nOut]
    }

    applyActivation(zi, params[3], params[4], params[5], zi);   // inplace
    applyActivation(zf, params[3], params[4], params[5], zf);   // inplace
    applyActivation(zc, params[6], params[7], params[8], zc);   // inplace

    c->assign(zf * *cI + zi * zc);          // [bS, nOut] ◦ [bS, nOut] + [bS, nOut] ◦ [bS, nOut] = [bS, nOut]

    // if clipping value is non-zero then cell state is clipped by this value prior to the cell output activation
    if(params[2] != 0)
        c->applyScalar(scalar::LstmClip, params[2], *c);

    // peephole connections for output gate
    if(Wp != nullptr)
        zo += *c * (*Wp)({2*nOut, 3*nOut});    // broadcast: [bS, nOut] + [nOut] ◦ [bS, nOut] = [bS, nOut]

    applyActivation(zo, params[3], params[4], params[5], zo);

    applyActivation(*c, params[9], params[10], params[11], *h);
    *h *= zo;                               // [bS, nOut] ◦ [bS, nOut]
}



//////////////////////////////////////////////////////////////////////////
void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
                       const NDArray* b, const NDArray* seqLen, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
                       const std::vector<float>& params,
                       const bool forward,
                       NDArray* h, NDArray* hL, NDArray* cL) {

    // INPUTS:
    // x  - current input  [sL, bS, nIn],  [bS, sL, nIn],  [bS, nIn, sL],
    // Wx - input weights [nIn, 4*nOut]
    // Wr - recurrent weights [nOut, 4*nOut]
    // b  - biases [4*nOut], optional, may be nullptr
    // seqLen - [bS], optional, may be nullptr
    // hI - initial output  [bS, nOut], optional, may be nullptr
    // cI - initial cell state at time t-1 [bS, nOut], optional, may be nullptr
    // Wp - peephole weights [3*nOut], optional, may be nullptr

    // OUTPUTS:
    // h - output [sL, bS, nOut],  [bS, sL, nOut],  [bS, nOut, sL], optional, may be nullptr
    // hL - output at last step [bS, nOut], optional, may be nullptr
    // cL - cell state at last step [bS, nOut], optional, may be nullptr

    // params = {dataFormat, directionMode, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta};
    // dataFormat: 0,3 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL]

    const int dataFormat    = params[0];
    const int directionMode = params[1];

    const Nd4jLong sL   = x->sizeAt(dataFormat);
    const Nd4jLong bS   = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1);
    const Nd4jLong nOut = Wx->sizeAt(-1) / 4;

    const std::vector<Nd4jLong> shapeOut = {bS, nOut};

    auto h0 = const_cast<NDArray*>(hI);
    if(!hI) {
        h0 = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext());
        h0->nullify();
    }

    auto c0 = const_cast<NDArray*>(cI);
    if(!cI) {
        c0 = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext());
        c0->nullify();
    }

    auto ct = cL;
    if(!cL)
        cL = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext());

    auto ht = hL;
    if(!h && !hL)
        ht = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext());

    // create sets of required (depends on seqLen presence) sub-arrays
    std::vector<int> dims;
    ResultSet *xSet(nullptr), *hSet(nullptr), *h0Set(nullptr), *c0Set(nullptr), *htSet(nullptr), *ctSet(nullptr);

    if(!seqLen) {

        dims = ShapeUtils::evalDimsToExclude(x->rankOf(), {dataFormat < 3 ? dataFormat : 0});    // points on bS and nIn/nOut axes

        xSet = new ResultSet(x->allTensorsAlongDimension(dims));       // sub-arrays with shape [bS, nIn]
        if(h)
            hSet = new ResultSet(h->allTensorsAlongDimension(dims));   // sub-arrays with shape [bS, nOut]
    }
    else {

        dims = dataFormat == 2 ? std::vector<int>({1}) : std::vector<int>({2});    // points on nIn/nOut axis

        xSet  = new ResultSet(x->allTensorsAlongDimension(dims));               //  sub-arrays with shape [nIn]
        h0Set = new ResultSet(h0->allTensorsAlongDimension({1}));              //  sub-arrays with shape [nOut]
        c0Set = new ResultSet(c0->allTensorsAlongDimension({1}));              //  sub-arrays with shape [nOut]
        ctSet = new ResultSet(ct->allTensorsAlongDimension({1}));              //  sub-arrays with shape [nOut]
        if(h)
            hSet = new ResultSet(h->allTensorsAlongDimension(dims));            //  sub-arrays with shape [nOut]
        if(ht)
            htSet = new ResultSet(ht->allTensorsAlongDimension({1}));          //  sub-arrays with shape [nOut]
    }

    // loops
    if(forward) {

        if(!seqLen) {

            if(!h) {    // seqLen and h are absent

                lstmLayerCell(xSet->at(0), Wx, Wr, b, h0, c0, Wp, params, ht, ct); // first time step
                for (int t = 1; t < sL; ++t)
                    lstmLayerCell(xSet->at(t), Wx, Wr, b, ht, ct, Wp, params, ht, ct); // rest time steps
            }
            else {      // seqLen is absent and h is present

                lstmLayerCell(xSet->at(0), Wx, Wr, b, h0, c0, Wp, params, hSet->at(0), ct); // first time step
                for (int t = 1; t < sL; ++t)
                    lstmLayerCell(xSet->at(t), Wx, Wr, b, hSet->at(t - 1), ct, Wp, params, hSet->at(t), ct); // rest time steps

                if(hL)
                    hL->assign(hSet->at(sL - 1));     // assign last output to hL if it is not nullptr
            }
        }
        else {

            if(!h) {    // seqLen is present and h is absent

                for (int e = 0; e < bS; ++e) {

                    const int limit = seqLen->e<int>(e);

                    if(limit == 0) {
                        if(cL)
                            ctSet->at(e)->nullify();
                        if(hL)
                            htSet->at(e)->nullify();
                        continue;
                    }

                    auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, 0, e);
                    lstmLayerCell(xSet->at(ind), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // first time step

                    for (int t = 1; t < limit; ++t) {
                        ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
                        lstmLayerCell(xSet->at(ind), Wx, Wr, b, htSet->at(e), ctSet->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // rest time steps
                    }
                }
            }
            else { // seqLen and h are present

                for (int e = 0; e < bS; ++e) {

                    int limit = seqLen->e<int>(e);

                    if(limit == 0) {

                        tensorAlongTimeBatchDims(*h, dataFormat, 0,0, e,e+1).nullify();     // nullify for given e and whole time range

                        if(cL)
                            ctSet->at(e)->nullify();
                        if(hL)
                            htSet->at(e)->nullify();

                        continue;
                    }

                    auto indPrev = getBatchTimeTotalIndex(dataFormat, sL, bS, 0, e);
                    lstmLayerCell(xSet->at(indPrev), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, hSet->at(indPrev), ctSet->at(e)); // first time step

                    for (int t = 1; t < limit; ++t) {
                        auto indCurr = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
                        lstmLayerCell(xSet->at(indCurr), Wx, Wr, b, hSet->at(indPrev), ctSet->at(e), Wp, params, hSet->at(indCurr), ctSet->at(e)); // rest time steps
                        indPrev = indCurr;
                    }

                    if(hL)
                        htSet->at(e)->assign(hSet->at(indPrev));     // assign last output to hL if hL is not nullptr

                    tensorAlongTimeBatchDims(*h, dataFormat, limit,sL, e,e+1).nullify();     // nullify for given e and time range [limit, sL)
                }
            }
        }
    }
    else {   // backward

         if(!seqLen) {

            if(!h) {    // seqLen and h are absent

                lstmLayerCell(xSet->at(sL - 1), Wx, Wr, b, h0, c0, Wp, params, ht, ct); // first time step
                for (int t = sL - 2; t >= 0; --t)
                    lstmLayerCell(xSet->at(t), Wx, Wr, b, ht, ct, Wp, params, ht, ct); // rest time steps
            }
            else {  // seqLen is absent and h is present

                lstmLayerCell(xSet->at(sL - 1), Wx, Wr, b, h0, c0, Wp, params, hSet->at(sL - 1), ct); // first time step
                for (int t = sL - 2; t >= 0; --t)
                    lstmLayerCell(xSet->at(t), Wx, Wr, b, hSet->at(t + 1), ct, Wp, params, hSet->at(t), ct); // rest time steps

                if(hL)
                    hL->assign(hSet->at(0));     // assign last output to hL if it is not nullptr
            }
        }
        else if(directionMode == 1) {   // only backward, no bidirectional mode

            if(!h) {    // h is absent and seqLen is present

                for (int e = 0; e < bS; ++e) {

                    const int limit = seqLen->e<int>(e);

                    if(limit == 0) {
                        if(cL)
                            ctSet->at(e)->nullify();
                        if(hL)
                            htSet->at(e)->nullify();
                        continue;
                    }

                    auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, sL - 1, e);
                    lstmLayerCell(xSet->at(ind), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // first time step

                    for (int t = sL - 2; t >= sL - limit; --t) {
                        ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
                        lstmLayerCell(xSet->at(ind), Wx, Wr, b, htSet->at(e), ctSet->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // rest time steps
                    }
                }
            }
            else {  // seqLen and h are present

                for (int e = 0; e < bS; ++e) {

                    int limit = seqLen->e<int>(e);

                    if(limit == 0) {

                        tensorAlongTimeBatchDims(*h, dataFormat, 0,0, e,e+1).nullify();     // nullify for given e and whole time range

                        if(cL)
                            ctSet->at(e)->nullify();
                        if(hL)
                            htSet->at(e)->nullify();

                        continue;
                    }

                    auto indPrev = getBatchTimeTotalIndex(dataFormat, sL, bS, sL - 1, e);
                    lstmLayerCell(xSet->at(indPrev), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, hSet->at(indPrev), ctSet->at(e)); // first time step

                    for (int t = sL - 2; t >= sL - limit; --t) {
                        auto indCurr = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
                        lstmLayerCell(xSet->at(indCurr), Wx, Wr, b, hSet->at(indPrev), ctSet->at(e), Wp, params, hSet->at(indCurr), ctSet->at(e)); // rest time steps
                        indPrev = indCurr;
                    }

                    if(hL)
                        htSet->at(e)->assign(hSet->at(indPrev));     // assign last output to hL if it is not nullptr

                    tensorAlongTimeBatchDims(*h, dataFormat, 0,sL-limit, e,e+1).nullify();    // nullify for given e and time range [limit, sL)
                }
            }
        }
        else {   // backward in bidirectional mode

            if(!h) {    // h is absent and seqLen is present

                for (int e = 0; e < bS; ++e) {

                    const int limit = seqLen->e<int>(e);

                    if(limit == 0) {
                        if(cL)
                            ctSet->at(e)->nullify();
                        if(hL)
                            htSet->at(e)->nullify();
                        continue;
                    }

                    auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, limit - 1, e);
                    lstmLayerCell(xSet->at(ind), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // first time step

                    for (int t = limit - 2; t >= 0; --t) {
                        ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
                        lstmLayerCell(xSet->at(ind), Wx, Wr, b, htSet->at(e), ctSet->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // rest time steps
                    }
                }
            }
            else {  // seqLen and h are present

                for (int e = 0; e < bS; ++e) {

                    int limit = seqLen->e<int>(e);

                    if(limit == 0) {

                        tensorAlongTimeBatchDims(*h, dataFormat, 0,0, e,e+1).nullify();     // nullify for given e and whole time range

                        if(cL)
                            ctSet->at(e)->nullify();
                        if(hL)
                            htSet->at(e)->nullify();

                        continue;
                    }

                    auto indPrev = getBatchTimeTotalIndex(dataFormat, sL, bS, limit - 1, e);
                    lstmLayerCell(xSet->at(indPrev), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, hSet->at(indPrev), ctSet->at(e)); // first time step

                    for (int t = limit - 2; t >= 0; --t) {
                        auto indCurr = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
                        lstmLayerCell(xSet->at(indCurr), Wx, Wr, b, hSet->at(indPrev), ctSet->at(e), Wp, params, hSet->at(indCurr), ctSet->at(e)); // rest time steps
                        indPrev = indCurr;
                    }

                    if(hL)
                        htSet->at(e)->assign(hSet->at(indPrev));     // assign last output to hL if it is not nullptr

                    tensorAlongTimeBatchDims(*h, dataFormat, limit,sL, e,e+1).nullify();    // nullify for given e and time range [limit, sL)
                }
            }
        }
    }

    delete xSet;
    delete hSet;
    delete h0Set;
    delete c0Set;
    delete htSet;
    delete ctSet;
}



}
}
}