/* ******************************************************************************
 *
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 *  See the NOTICE file distributed with this work for additional
 *  information regarding copyright ownership.
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//

// implementation of the operation for an LSTM cell with peephole connections:
// http://www.bioinf.jku.at/publications/older/2604.pdf
// S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.
// and
// https://research.google.com/pubs/archive/43905.pdf
// Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent neural network architectures for large scale acoustic modeling." INTERSPEECH, 2014.


#include <ops/declarable/helpers/lstmLayer.h>
#include <execution/Threads.h>
#include <ops/declarable/helpers/activations.h>
#include <helpers/ShapeUtils.h>
#include <helpers/MmulHelper.h>
// #include <VariableSpace.h>
// #include <ops/declarable/CustomOperations.h>
// #include<ops/declarable/helpers/transforms.h>
// #include <ops/declarable/helpers/legacy_helpers.h>
// #include <array/NDArrayList.h>
// #include <iterator>


namespace sd 	  {
namespace ops 	  {
namespace helpers {

//////////////////////////////////////////////////////////////////////////
static void applyActivation(const NDArray& x, const int opId, const float alpha, const float beta, NDArray& z) {

    switch (opId) {
        case 0:
            (const_cast<NDArray&>(x)).applyTransform(transform::Tanh, z);
            break;
        case 1:
            (const_cast<NDArray&>(x)).applyScalar<float>(scalar::RELU, 0, z);
            break;
        case 2:
            (const_cast<NDArray&>(x)).applyTransform(transform::Sigmoid, z);
            break;
        case 3: {
            ExtraArguments args({ static_cast<double>(alpha), static_cast<double>(beta)});
            (const_cast<NDArray&>(x)).applyTransform(transform::Affine, z, &args);
            break;
        }
        case 4:
            (const_cast<NDArray&>(x)).applyScalar<float>(scalar::LeakyRELU, alpha, z);
            break;
        case 5:
            thresholdRelu(x.getContext(), x, alpha, z);
            break;
        case 6: {
            ExtraArguments args({ static_cast<double>(alpha), static_cast<double>(beta)});
            (const_cast<NDArray&>(x)).applyTransform(transform::ScaledTanh, z, &args);
            break;
        }
        case 7:
            (const_cast<NDArray&>(x)).applyTransform(transform::HardSigmoid, z);
            break;
        case 8:
            (const_cast<NDArray&>(x)).applyScalar<float>(scalar::ELU, alpha, z);
            break;
        case 9:
            (const_cast<NDArray&>(x)).applyTransform(transform::SoftSign, z);
            break;
        case 10:
            (const_cast<NDArray&>(x)).applyTransform(transform::SoftPlus, z);
            break;
        default:
            throw std::invalid_argument("LSTM_LAYER operation: wrong activation id number!");
    }
}
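
// For example, applyActivation(x, 2, 0.f, 0.f, z) writes sigmoid(x) into z (alpha and beta are
// ignored by parameter-free activations), while id 3 (affine) maps x to alpha*x + beta, which is
// consistent with activationDeriv below assigning the constant alpha for that id.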

//////////////////////////////////////////////////////////////////////////
static void activationDeriv(const NDArray& x, const int opId, const float alpha, const float beta, NDArray& z) {

    switch (opId) {
        case 0:
            (const_cast<NDArray&>(x)).applyTransform(transform::TanhDerivative, z);
            break;
        case 1:
            (const_cast<NDArray&>(x)).applyScalar<float>(scalar::RELUDerivative, 0, z);
            break;
        case 2:
            (const_cast<NDArray&>(x)).applyTransform(transform::SigmoidDerivative, z);
            break;
        case 3: {
            z = alpha;
            break;
        }
        case 4:
            (const_cast<NDArray&>(x)).applyScalar<float>(scalar::LeakyRELUDerivative, alpha, z);
            break;
        case 5:
            (const_cast<NDArray&>(x)).applyScalar<float>(scalar::RELUDerivative, alpha, z);
            break;
        case 6: {
            auto func = PRAGMA_THREADS_FOR {
                for(Nd4jLong i = start; i < stop; ++i) {
                    auto val = beta * x.e<float>(i);
                    z.p<float>(i, alpha * beta * (1.f - sd::math::nd4j_tanh<float,float>(val) * sd::math::nd4j_tanh<float,float>(val)));
                }
            };
            samediff::Threads::parallel_for(func, 0, x.lengthOf());
            break;
        }
        case 7:
            (const_cast<NDArray&>(x)).applyTransform(transform::HardSigmoidDerivative, z);
            break;
        case 8:
            (const_cast<NDArray&>(x)).applyScalar<float>(scalar::ELUDerivative, alpha, z);
            break;
        case 9:
            (const_cast<NDArray&>(x)).applyTransform(transform::SoftSignDerivative, z);
            break;
        case 10: {
            auto func = PRAGMA_THREADS_FOR {
                for(Nd4jLong i = start; i < stop; ++i) {
                    auto val = sd::math::nd4j_exp<float, float>(x.e<float>(i));
                    z.p<float>(i, val / (1.f + val));
                }
            };
            samediff::Threads::parallel_for(func, 0, x.lengthOf());
            break;
        }
        default:
            throw std::invalid_argument("LSTM_LAYER operation: wrong activation id number!");
    }
}

//////////////////////////////////////////////////////////////////////////
// FIXME - the derivative is undefined where the unclipped c has elements equal to -clipVal or clipVal
static void clipDeriv(const float clipVal, const NDArray& c, NDArray& z0, NDArray& z1, NDArray& z2, NDArray& z3) {

    if(clipVal == 0)
        return;

    auto func = PRAGMA_THREADS_FOR {
        for(Nd4jLong i = start; i < stop; ++i) {
            const auto val = c.e<float>(i);
            if(val == -clipVal || val == clipVal) {
                z0.p<float>(i, 0.f);
                z1.p<float>(i, 0.f);
                z2.p<float>(i, 0.f);
                z3.p<float>(i, 0.f);
            }
        }
    };
    samediff::Threads::parallel_for(func, 0, c.lengthOf());
}

//////////////////////////////////////////////////////////////////////////
static NDArray tensorAlongTimeBatchDims(const NDArray& arr, const int dataFormat, const int t1, const int t2, const int b1, const int b2) {

    if(dataFormat == 0 || dataFormat == 3)
        return arr({t1,t2, b1,b2, 0,0});    // TNS: [sL, bS, nIn]

    if(dataFormat == 1)
        return arr({b1,b2, t1,t2, 0,0});    // NTS: [bS, sL, nIn]

    return arr({b1,b2, 0,0, t1,t2});        // NST: [bS, nIn, sL]
}
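
// Worked example: for dataFormat == 1 (NTS, arr has shape [bS, sL, nIn]) the call
// tensorAlongTimeBatchDims(arr, 1, 2,5, 0,1) returns arr({0,1, 2,5, 0,0}), i.e. the sub-array
// covering batch element 0 and time steps 2..4 (intervals are half-open, {0,0} spans a whole axis).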

//////////////////////////////////////////////////////////////////////////
static FORCEINLINE int getBatchTimeTotalIndex(const int dataFormat, const int sL, const int bS, const int t, const int b) {

    if(dataFormat == 0 || dataFormat == 3)
        return t * bS + b;   // TNS: shape [sL, bS, nIn]

    return b * sL + t;       // NTS, NST: shape [bS, sL, nIn], [bS, nIn, sL]
}
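
// Worked example: with sL = 5 and bS = 2, time step t = 3 of batch element b = 1 maps to
// sub-array index 3*2 + 1 = 7 for TNS data (dataFormat 0 or 3) and to 1*5 + 3 = 8 for NTS/NST data.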


//////////////////////////////////////////////////////////////////////////
void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
                   const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
                   const std::vector<float>& params,
                   NDArray* h, NDArray* c) {

    // * -> means element-wise multiplication
    // × -> means matrix multiplication

    /************************ THIS IS NOT OPTIMIZED CODE ***********************************/
    /** the objective is to provide math-readable code **/

    // equations (no peephole connections)
    // it  = σ(Wxi × xt  +  Wri × ht-1  +  bi)
    // ft  = σ(Wxf × xt  +  Wrf × ht-1  +  bf)
    // c't = tanh(Wxc × xt  +  Wrc × ht-1  +  bc)
    // ct  = ft * ct-1 + it * c't
    // ot  = σ(Wxo × xt  +  Wro × ht-1  +  bo)
    // ht  = ot * tanh(ct)

    // equations (peephole connections are present)
    // it  = σ(Wxi × xt  +  Wri × ht-1  +  Wpi * ct-1  +  bi)
    // ft  = σ(Wxf × xt  +  Wrf × ht-1  +  Wpf * ct-1  +  bf)
    // c't = tanh(Wxc × xt  +  Wrc × ht-1  +  bc)
    // ct  = ft * ct-1 + it * c't
    // ot  = σ(Wxo × xt  +  Wro × ht-1  +  Wpo * ct   +  bo)
    // ht  = ot * tanh(ct)


    // IDs for activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5=thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus

    // params[0] - dataFormat, ignore
    // params[1] - directionMode, ignore
    // params[2] - cell clipping value; if it is 0, no clipping is applied

    // params[3]  - activation ID for input (i), forget (f) and output (o) gates
    // params[4]  - alpha value for gates activation
    // params[5]  - beta value for gates activation

    // params[6]  - activation ID for cell state (c)
    // params[7]  - alpha value for cell state activation
    // params[8]  - beta value for cell state activation

    // params[9]  - activation ID for output (h)
    // params[10] - alpha value for output activation
    // params[11] - beta value for output activation

    // INPUTS:
    // x  - current input at time t, [bS, nIn] or [nIn] if seqLen != nullptr
    // Wx - input weights [nIn, 4*nOut]
    // Wr - recurrent weights [nOut, 4*nOut]
    // b  - biases [4*nOut], optional, may be nullptr
    // hI - (ht-1) previous (initial) output at time t-1, optional, may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr
    // cI - (ct-1) previous (initial) cell state at time t-1, optional, may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr
    // Wp - peephole weights [3*nOut], optional, may be nullptr

    // OUTPUTS:
    // h - current output, that is at current time step t, [bS, nOut] or [nOut] if seqLen != nullptr
    // c - current cell state,  that is at current time step t, [bS, nOut] or [nOut] if seqLen != nullptr

    // !!! dimension 4*nOut implies order it, ft, c't, ot
    // !!! dimension 3*nOut implies order it, ft, ot
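
    // Purely illustrative configuration (numbers are not taken from any particular caller): with
    // nIn = 3, nOut = 4, bS = 2 the shapes are x [2,3], Wx [3,16], Wr [4,16], b [16], Wp [12],
    // hI/cI/h/c [2,4]; params = {0,0, 0, 2,0,0, 0,0,0, 0,0,0} selects no clipping, sigmoid (id 2)
    // for the gates and tanh (id 0) for both the cell state and the output activation.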

    const Nd4jLong nOut = Wx->sizeAt(-1) / 4;

    auto z = mmul(*x, *Wx) + mmul(*hI, *Wr);    //   [bS, nIn] * [nIn, 4*nOut] + [bS, nOut] * [nOut, 4*nOut] = [bS, 4*nOut]
                                                //or [nIn] * [nIn, 4*nOut] + [nOut] * [nOut, 4*nOut] = [4*nOut]

    // add biases if they are given
    if(b != nullptr)
        z += *b;                                    // broadcast [bS, 4*nOut](or[4*nOut]) + [4*nOut] = [bS, 4*nOut]

    auto zi = x->rankOf() == 1 ? z({0,        nOut}) : z({0,0, 0,        nOut});         // input gate it, [bS, nOut](or[nOut])
    auto zf = x->rankOf() == 1 ? z({nOut,   2*nOut}) : z({0,0, nOut,   2*nOut});         // forget gate ft, [bS, nOut](or[nOut])
    auto zg = x->rankOf() == 1 ? z({2*nOut, 3*nOut}) : z({0,0, 2*nOut, 3*nOut});         // cell gate c't, [bS, nOut](or[nOut])
    auto zo = x->rankOf() == 1 ? z({3*nOut, 4*nOut}) : z({0,0, 3*nOut, 4*nOut});         // output gate ot, [bS, nOut](or[nOut])

    // peephole connections for input and forget gates
    if(Wp != nullptr) {
        zi += *cI * (*Wp)({0,      nOut});      // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut])
        zf += *cI * (*Wp)({nOut, 2*nOut});      // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut])
    }

    applyActivation(zi, params[3], params[4], params[5], zi);   // inplace
    applyActivation(zf, params[3], params[4], params[5], zf);   // inplace
    applyActivation(zg, params[6], params[7], params[8], zg);   // inplace

    c->assign(zf * *cI + zi * zg);          // [bS, nOut] * [bS, nOut] + [bS, nOut] * [bS, nOut] = [bS, nOut](or[nOut])

    // if clipping value is non-zero then cell state is clipped by this value prior to the cell output activation
    if(params[2] != 0)
        c->applyScalar(scalar::LstmClip, params[2], *c);

    // peephole connections for output gate
    if(Wp != nullptr)
        zo += *c * (*Wp)({2*nOut, 3*nOut});    // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut])

    applyActivation(zo, params[3], params[4], params[5], zo);

    applyActivation(*c, params[9], params[10], params[11], *h);
    *h *= zo;                               // [bS, nOut] * [bS, nOut](or[nOut])
}
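
// Hypothetical single-step usage sketch (all names and sizes below are illustrative only, not taken
// from a real caller):
//     NDArray x ('c', {2, 3},  sd::DataType::FLOAT32);                   // [bS, nIn]
//     NDArray Wx('c', {3, 16}, sd::DataType::FLOAT32);                   // [nIn, 4*nOut]
//     NDArray Wr('c', {4, 16}, sd::DataType::FLOAT32);                   // [nOut, 4*nOut]
//     NDArray hI('c', {2, 4},  sd::DataType::FLOAT32), cI = hI.ulike();  // [bS, nOut]
//     NDArray h = hI.ulike(), c = hI.ulike();                            // outputs
//     const std::vector<float> params = {0,0, 0, 2,0,0, 0,0,0, 0,0,0};   // sigmoid gates, tanh cell/out, no clipping
//     lstmLayerCell(&x, &Wx, &Wr, nullptr, &hI, &cI, nullptr, params, &h, &c);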


//////////////////////////////////////////////////////////////////////////
// this auxiliary feed-forward pass should be run before backprop: it additionally stores the pre-activations (z) and gate activations (a) required by lstmLayerCellBp
void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
                   const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
                   const std::vector<float>& params,
                   NDArray* z, NDArray* a, NDArray* h, NDArray* c) {

    // z - zi, zf, zg, zo
    // a - i, f, g, o

    const Nd4jLong nOut = Wx->sizeAt(-1) / 4;

    z->assign(mmul(*x, *Wx) + mmul(*hI, *Wr));    //   [bS, nIn] * [nIn, 4*nOut] + [bS, nOut] * [nOut, 4*nOut] = [bS, 4*nOut]
                                                  //or [nIn] * [nIn, 4*nOut] + [nOut] * [nOut, 4*nOut] = [4*nOut]
    // add biases if they are given
    if(b != nullptr)
        *z += *b;                                    // broadcast [bS, 4*nOut](or[4*nOut]) + [4*nOut] = [bS, 4*nOut]

    auto zi = x->rankOf() == 1 ? (*z)({0,        nOut}) : (*z)({0,0, 0,        nOut});         // input gate it, [bS, nOut](or[nOut])
    auto zf = x->rankOf() == 1 ? (*z)({nOut,   2*nOut}) : (*z)({0,0, nOut,   2*nOut});         // forget gate ft, [bS, nOut](or[nOut])
    auto zg = x->rankOf() == 1 ? (*z)({2*nOut, 3*nOut}) : (*z)({0,0, 2*nOut, 3*nOut});         // cell gate c't, [bS, nOut](or[nOut])
    auto zo = x->rankOf() == 1 ? (*z)({3*nOut, 4*nOut}) : (*z)({0,0, 3*nOut, 4*nOut});         // output gate ot, [bS, nOut](or[nOut])

    auto i = x->rankOf() == 1 ? (*a)({0,        nOut}) : (*a)({0,0, 0,        nOut});         // input gate it, [bS, nOut](or[nOut])
    auto f = x->rankOf() == 1 ? (*a)({nOut,   2*nOut}) : (*a)({0,0, nOut,   2*nOut});         // forget gate ft, [bS, nOut](or[nOut])
    auto g = x->rankOf() == 1 ? (*a)({2*nOut, 3*nOut}) : (*a)({0,0, 2*nOut, 3*nOut});         // cell gate c't, [bS, nOut](or[nOut])
    auto o = x->rankOf() == 1 ? (*a)({3*nOut, 4*nOut}) : (*a)({0,0, 3*nOut, 4*nOut});         // output gate ot, [bS, nOut](or[nOut])

    // peephole connections for input and forget gates
    if(Wp != nullptr) {
        zi += *cI * (*Wp)({0,      nOut});      // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut])
        zf += *cI * (*Wp)({nOut, 2*nOut});      // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut])
    }

    applyActivation(zi, params[3], params[4], params[5], i);
    applyActivation(zf, params[3], params[4], params[5], f);
    applyActivation(zg, params[6], params[7], params[8], g);

    c->assign(f * *cI + i * g);          // [bS, nOut] * [bS, nOut] + [bS, nOut] * [bS, nOut] = [bS, nOut](or[nOut])

    // if clipping value is non-zero then cell state is clipped by this value prior to the cell output activation
    if(params[2] != 0)
        c->applyScalar(scalar::LstmClip, params[2], *c);

    // peephole connections for output gate
    if(Wp != nullptr)
        zo += *c * (*Wp)({2*nOut, 3*nOut});    // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut])

    applyActivation(zo, params[3], params[4], params[5], o);

    applyActivation(*c, params[9], params[10], params[11], *h);
    *h *= o;                               // [bS, nOut] * [bS, nOut](or[nOut])
}


//////////////////////////////////////////////////////////////////////////
void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
                     const NDArray* dLdh, const NDArray* dLdhL, const NDArray* dLdcL,
                     const NDArray* z, const NDArray* a, const NDArray* c, const std::vector<float>& params,
                     NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdb, NDArray* dLdWp) {

    /************************ THIS IS NOT OPTIMIZED CODE ***********************************/
    /** the objective is to provide math-readable code **/

    // equations (no peephole connections)
    // zi = x × Wxi + hI × Wri + bi
    // zf = x × Wxf + hI × Wrf + bf
    // zg = x × Wxg + hI × Wrg + bg
    // zo = x × Wxo + hI × Wro + bo
    // i = act(zi)
    // f = act(zf)
    // g = actC(zg)
    // o = act(zo)
    // c = clip(f * cI + i * g)
    // h = o * actH(c)

    // equations (peephole connections are present)
    // zi = x × Wxi + hI × Wri + cI * Wpi + bi
    // zf = x × Wxf + hI × Wrf + cI * Wpf + bf
    // zg = x × Wxg + hI × Wrg + bg
    // zo = x × Wxo + hI × Wro + c  * Wpo + bo
    // i = act(zi)
    // f = act(zf)
    // g = actC(zg)
    // o = act(zo)
    // c = clip(f * cI + i * g)
    // h = o * actH(c)

    // IDs for activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5=thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus

    // params[0] - dataFormat, ignore
    // params[1] - directionMode, ignore
    // params[2] - cell clipping value; if it is 0, no clipping is applied

    // params[3]  - activation ID for input (i), forget (f) and output (o) gates
    // params[4]  - alpha value for gates activation
    // params[5]  - beta value for gates activation

    // params[6]  - activation ID for cell state (c)
    // params[7]  - alpha value for cell state activation
    // params[8]  - beta value for cell state activation

    // params[9]  - activation ID for output (h)
    // params[10] - alpha value for output activation
    // params[11] - beta value for output activation

    // INPUTS:
    // x     - current input at time t, [bS, nIn] or [nIn] if seqLen != nullptr
    // Wx    - input weights [nIn, 4*nOut]
    // Wr    - recurrent weights [nOut, 4*nOut]
    // b     - biases [4*nOut], optional, may be nullptr
    // hI    - (ht-1) previous (initial) output at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr
    // cI    - (ct-1) previous (initial) cell state at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr
    // Wp    - peephole weights [3*nOut], optional, may be nullptr
    // dLdh  - loss derivative with respect to h at each time step, [bS, nOut] or [nOut] if seqLen != nullptr
    // dLdhL - loss derivative with respect to h at last time step, [bS, nOut] or [nOut] if seqLen != nullptr
    // dLdcL - loss derivative with respect to c at last time step, [bS, nOut] or [nOut] if seqLen != nullptr
    // z     - zi,zf,zg,zo taken from ff outputs to reduce amount of calculations in bp, [bS, 4*nOut]
    // a     - i,f,g,o taken from ff outputs to reduce amount of calculations in bp, [bS, 4*nOut]
    // c     - taken from ff outputs to reduce amount of calculations in bp, [bS, nOut]

    // OUTPUTS:
    // dLdx  - loss derivative with respect to x, [bS, nIn] or [nIn] if seqLen != nullptr
    // dLdWx - loss derivative with respect to Wx, [nIn, 4*nOut]
    // dLdWr - loss derivative with respect to Wr, [nOut, 4*nOut]
    // dLdb  - loss derivative with respect to b, optional, may be nullptr, [4*nOut]
    // dLdhI - loss derivative with respect to hI, optional, may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr
    // dLdcI - loss derivative with respect to cI, optional, may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr
    // dLdWp - loss derivative with respect to Wp, optional, may be nullptr,  [3*nOut]

    // !!! dimension 4*nOut implies order i, f, g, o
    // !!! dimension 3*nOut implies order i, f, o

    // dhdc  = o*tanhDeriv + Wp ? tanh(c)*dodzo*dzodc : 0       [bS, nOut]
    // dcdcI = f + Wp ? dcdzi*dzidcI + dcdzf*dzfdcI : 0         [bS, nOut]

    // dLdhI += dLdh;                       [bS, nOut]
    // dLdcI += dLdhI * dhdc;               [bS, nOut]

    // dLdzi = dLdcI*dcdi*didzi;            [bS, nOut](or[nOut])
    // dLdzf = dLdcI*dcdf*dfdzf;            [bS, nOut](or[nOut])
    // dLdzg = dLdcI*dcdg*dgdzg;            [bS, nOut](or[nOut])
    // dLdzo = dLdhI*dhdo*dodzo;            [bS, nOut](or[nOut])

    // dLdx  = dLdzi×WxiT + dLdzf×WxfT + dLdzg×WxgT + dLdzo×WxoT,   [bS, nIn]
    // dLdhI = dLdzi×WriT + dLdzf×WrfT + dLdzg×WrgT + dLdzo×WroT,   [bS, nOut]
    // dLdcI = dLdcI*dcdcI,                                         [bS, nOut]

    // dLdWxi = xT×dLdzi                            [nIn, bS] x [bS, nOut] = [nIn, nOut]
    // dLdWxf = xT×dLdzf                            [nIn, bS] x [bS, nOut] = [nIn, nOut]
    // dLdWxg = xT×dLdzg                            [nIn, bS] x [bS, nOut] = [nIn, nOut]
    // dLdWxo = xT×dLdzo                            [nIn, bS] x [bS, nOut] = [nIn, nOut]

    // dLdWri = hIT×dLdzi                           [nOut, bS] x [bS, nOut] = [nOut, nOut]
    // dLdWrf = hIT×dLdzf                           [nOut, bS] x [bS, nOut] = [nOut, nOut]
    // dLdWrg = hIT×dLdzg                           [nOut, bS] x [bS, nOut] = [nOut, nOut]
    // dLdWro = hIT×dLdzo                           [nOut, bS] x [bS, nOut] = [nOut, nOut]

    //  dLdbi = dLdzi.reduce_sum_along_0_axis       [bS, nOut] -> reduce -> [nOut]
    //  dLdbf = dLdzf.reduce_sum_along_0_axis       [bS, nOut] -> reduce -> [nOut]
    //  dLdbg = dLdzg.reduce_sum_along_0_axis       [bS, nOut] -> reduce -> [nOut]
    //  dLdbo = dLdzo.reduce_sum_along_0_axis       [bS, nOut] -> reduce -> [nOut]

    //  dLdWpi = (dLdzi*cI).reduce_sum_along_0_axis       [bS, nOut] -> reduce -> [nOut]
    //  dLdWpf = (dLdzf*cI).reduce_sum_along_0_axis       [bS, nOut] -> reduce -> [nOut]
    //  dLdWpo = (dLdzo*c) .reduce_sum_along_0_axis       [bS, nOut] -> reduce -> [nOut]
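
    // For the common choice of sigmoid gates (id 2) and tanh cell/output activations (id 0) the chain
    // rules above reduce to the familiar expressions (clipping contributions omitted):
    //  dLdzi = dLdcI * g  * i*(1-i)
    //  dLdzf = dLdcI * cI * f*(1-f)
    //  dLdzg = dLdcI * i  * (1-g*g)
    //  dLdzo = dLdhI * tanh(c) * o*(1-o)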

    const Nd4jLong nOut = Wx->sizeAt(-1) / 4;
    const Nd4jLong nIn  = x->sizeAt(-1);

    NDArray zi = x->rankOf() == 1 ? (*z)({0,        nOut}) : (*z)({0,0, 0,        nOut});         // input gate i, [bS, nOut](or[nOut])
    NDArray zf = x->rankOf() == 1 ? (*z)({nOut,   2*nOut}) : (*z)({0,0, nOut,   2*nOut});         // forget gate f, [bS, nOut](or[nOut])
    NDArray zg = x->rankOf() == 1 ? (*z)({2*nOut, 3*nOut}) : (*z)({0,0, 2*nOut, 3*nOut});         // cell gate g, [bS, nOut](or[nOut])
    NDArray zo = x->rankOf() == 1 ? (*z)({3*nOut, 4*nOut}) : (*z)({0,0, 3*nOut, 4*nOut});         // output gate o, [bS, nOut](or[nOut])

    NDArray i = x->rankOf() == 1 ? (*a)({0,        nOut}) : (*a)({0,0, 0,        nOut});         // input gate i, [bS, nOut](or[nOut])
    NDArray f = x->rankOf() == 1 ? (*a)({nOut,   2*nOut}) : (*a)({0,0, nOut,   2*nOut});         // forget gate f, [bS, nOut](or[nOut])
    NDArray g = x->rankOf() == 1 ? (*a)({2*nOut, 3*nOut}) : (*a)({0,0, 2*nOut, 3*nOut});         // cell gate g, [bS, nOut](or[nOut])
    NDArray o = x->rankOf() == 1 ? (*a)({3*nOut, 4*nOut}) : (*a)({0,0, 3*nOut, 4*nOut});         // output gate o, [bS, nOut](or[nOut])

    NDArray dLdz = z->ulike();                                                                  // [bS, 4*nOut](or[4*nOut])
    NDArray dLdzi = x->rankOf() == 1 ? dLdz({0,        nOut}) : dLdz({0,0, 0,        nOut});
    NDArray dLdzf = x->rankOf() == 1 ? dLdz({nOut,   2*nOut}) : dLdz({0,0, nOut,   2*nOut});
    NDArray dLdzg = x->rankOf() == 1 ? dLdz({2*nOut, 3*nOut}) : dLdz({0,0, 2*nOut, 3*nOut});
    NDArray dLdzo = x->rankOf() == 1 ? dLdz({3*nOut, 4*nOut}) : dLdz({0,0, 3*nOut, 4*nOut});

    // dcdzi = dcdi*didzi, [bS, nOut](or[nOut])
    activationDeriv(zi, params[3], params[4], params[5], dLdzi);     // didzi, inplace
    dLdzi *= g;                                                      // dcdi = g*clipDeriv

    // dcdzf = dcdf*dfdzf, [bS, nOut](or[nOut])
    activationDeriv(zf, params[3], params[4], params[5], dLdzf);     // dfdzf, inplace
    dLdzf *= *cI;                                                    // dcdf = cI*clipDeriv

    // dcdzg = dcdg*dgdzg, [bS, nOut](or[nOut])
    activationDeriv(zg, params[6], params[7], params[8], dLdzg);    // dgdzg, inplace
    dLdzg *= i;                                                     // dcdg = i*clipDeriv

    // dhdzo = dhdo*dodzo = actH(c)*dodzo, [bS, nOut](or[nOut])
    activationDeriv(zo, params[3], params[4], params[5], dLdzo);
    NDArray temp = dLdzo.ulike();
    applyActivation(*c, params[9], params[10], params[11], temp);     // temp = actH(c)
    dLdzo *= temp;

    // dcdcI
    NDArray dcdcI = f.dup();                                           // dcdcI = f*clipDeriv [bS, nOut](or[nOut])

    // take into account the possible contribution of the clipping derivative
    clipDeriv(params[2], *c, dLdzi, dLdzf, dLdzg, dcdcI);

    // dhdc
    NDArray dhdc = c->ulike();
    activationDeriv(*c, params[9], params[10], params[11], dhdc);    // [bS, nOut]
    dhdc *= o;

    if(Wp) {
        dhdc  += dLdzo*(*Wp)({2*nOut, 3*nOut});
        dcdcI += dLdzi*(*Wp)({0, nOut}) + dLdzf*(*Wp)({nOut, 2*nOut});     // broadcast [bS, nOut] * nOut + ...
    }

    if(dLdh)
        *dLdhI += *dLdh;
    if(dLdhL)
        *dLdhI += *dLdhL;
    if(dLdcL)
        *dLdcI += *dLdcL;

    *dLdcI += *dLdhI * dhdc;

    dLdzi *= *dLdcI;     // [bS, nOut](or[nOut])
    dLdzf *= *dLdcI;     // [bS, nOut](or[nOut])
    dLdzg *= *dLdcI;     // [bS, nOut](or[nOut])
    dLdzo *= *dLdhI;     // [bS, nOut](or[nOut])

    // dLdx
    NDArray WxT = Wx->transpose();
    MmulHelper::mmul(&dLdz, &WxT, dLdx);        // [bS, 4*nOut] x [4*nOut, nIn] (or [4*nOut] x [4*nOut, nIn]) = [bS, nIn] ( or[nIn] )

    // dLdhI
    NDArray WrT = Wr->transpose();
    MmulHelper::mmul(&dLdz, &WrT, dLdhI);       // [bS, 4*nOut] x [4*nOut, nOut] (or [4*nOut] x [4*nOut, nOut]) = [bS, nOut] ( or[nOut] )

    // dLdcI
    dLdcI->assign(*dLdcI*dcdcI);                                            // [bS, nOut](or[nOut])

    if(x->rankOf() == 1) {

        NDArray xT  = x->reshape(x->ordering(),{nIn, 1});               // [nIn]  -> [nIn, 1]
        NDArray hIT = hI->reshape(hI->ordering(),{nOut, 1});            // [nOut] -> [nOut, 1]
        NDArray dLdzR = dLdz.reshape(dLdz.ordering(), {1, 4*nOut});     // [4*nOut] -> [1, 4*nOut]

        // dLdWx
        *dLdWx += mmul(xT, dLdzR);        // [nIn, 1] x [1, 4*nOut] = [nIn, 4*nOut]

        // dLdWr
        *dLdWr += mmul(hIT, dLdzR);        // [nOut, 1] x [1, 4*nOut] = [nOut, 4*nOut]
    }
    else {

        // dLdWx
        *dLdWx += mmul(x->transpose(), dLdz);         // [nIn, bS] x [bS, 4*nOut] = [nIn, 4*nOut]

        // dLdWr
        *dLdWr += mmul(hI->transpose(), dLdz);        // [nOut, bS] x [bS, 4*nOut] = [nOut, 4*nOut]
    }

    // dLdb
    if(b && x->rankOf() == 1)
        *dLdb += dLdz;                                          // [4*nOut]
    else if(b)
        *dLdb += dLdz.reduceAlongDimension(reduce::Sum, {0});   // [bS, 4*nOut] -> reduce -> [4*nOut];

    // dLdWp
    if(Wp && x->rankOf() == 1) {
        (*dLdWp)({       0,nOut}) += std::move(dLdzi)*(*cI);          // [nOut]
        (*dLdWp)({  nOut,2*nOut}) += std::move(dLdzf)*(*cI);          // [nOut]
        (*dLdWp)({2*nOut,3*nOut}) += std::move(dLdzo)*(*c);           // [nOut]
    }
    else if(Wp) {
        NDArray temp(Wp->ordering(), {nOut}, Wp->dataType(), Wp->getContext());
        (std::move(dLdzi)*(*cI)).reduceAlongDimension(reduce::Sum, temp, {0});     // [bS, nOut] -> reduce -> [nOut]
        (*dLdWp)({0,nOut}) += temp;
        (std::move(dLdzf)*(*cI)).reduceAlongDimension(reduce::Sum, temp, {0});     // [bS, nOut] -> reduce -> [nOut]
        (*dLdWp)({nOut,2*nOut}) += temp;
        (std::move(dLdzo)*(*c)).reduceAlongDimension(reduce::Sum, temp, {0});     // [bS, nOut] -> reduce -> [nOut]
        (*dLdWp)({2*nOut,3*nOut}) += temp;
    }
}

//////////////////////////////////////////////////////////////////////////
void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
                       const NDArray* b, const NDArray* seqLen, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
                       const std::vector<float>& params,
                       const bool forward,
                       NDArray* h, NDArray* hL, NDArray* cL) {

    // INPUTS:
    // x  - current input  [sL, bS, nIn],  [bS, sL, nIn],  [bS, nIn, sL],
    // Wx - input weights [nIn, 4*nOut]
    // Wr - recurrent weights [nOut, 4*nOut]
    // b  - biases [4*nOut], optional, may be nullptr
    // seqLen - [bS], optional, may be nullptr
    // hI - initial output  [bS, nOut], optional, may be nullptr
    // cI - initial cell state at time t-1 [bS, nOut], optional, may be nullptr
    // Wp - peephole weights [3*nOut], optional, may be nullptr

    // OUTPUTS:
    // h - output [sL, bS, nOut],  [bS, sL, nOut],  [bS, nOut, sL], optional, may be nullptr
    // hL - output at last step [bS, nOut], optional, may be nullptr
    // cL - cell state at last step [bS, nOut], optional, may be nullptr

    // params = {dataFormat, directionMode, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta};
    // dataFormat: 0,3 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL]

    const int dataFormat    = params[0];
    const int directionMode = params[1];

    const Nd4jLong sL   = dataFormat == 3 ?  x->sizeAt(0) : x->sizeAt(dataFormat);
    const Nd4jLong bS   = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1);
    const Nd4jLong nOut = Wx->sizeAt(-1) / 4;

    const std::vector<Nd4jLong> shapeOut = {bS, nOut};

    const auto type = h ? h->dataType() : (hL ? hL->dataType() : cL->dataType());

    auto h0 = const_cast<NDArray*>(hI);
    if(!hI) {
        h0 = new NDArray(x->ordering(), shapeOut, type, x->getContext());
        h0->nullify();
    }

    auto c0 = const_cast<NDArray*>(cI);
    if(!cI) {
        c0 = new NDArray(x->ordering(), shapeOut, type, x->getContext());
        c0->nullify();
    }

    auto ct = cL;
    if(!cL)
        ct = new NDArray(x->ordering(), shapeOut, type, x->getContext());

    auto ht = hL;
    if(!h && !hL)
        ht = new NDArray(x->ordering(), shapeOut, type, x->getContext());

    // create the sets of required sub-arrays (which ones are needed depends on whether seqLen is present)
    std::vector<int> dims;
    ResultSet *xSet(nullptr), *hSet(nullptr), *h0Set(nullptr), *c0Set(nullptr), *htSet(nullptr), *ctSet(nullptr);

    if(!seqLen) {

        dims = ShapeUtils::evalDimsToExclude(x->rankOf(), {dataFormat < 3 ? dataFormat : 0});    // points to the bS and nIn/nOut axes

        xSet = new ResultSet(x->allTensorsAlongDimension(dims));       // sub-arrays with shape [bS, nIn]
        if(h)
            hSet = new ResultSet(h->allTensorsAlongDimension(dims));   // sub-arrays with shape [bS, nOut]
    }
    else {

        dims = dataFormat == 2 ? std::vector<int>({1}) : std::vector<int>({2});    // points to the nIn/nOut axis

        xSet  = new ResultSet(x->allTensorsAlongDimension(dims));               //  sub-arrays with shape [nIn]
        h0Set = new ResultSet(h0->allTensorsAlongDimension({1}));              //  sub-arrays with shape [nOut]
        c0Set = new ResultSet(c0->allTensorsAlongDimension({1}));              //  sub-arrays with shape [nOut]
        ctSet = new ResultSet(ct->allTensorsAlongDimension({1}));              //  sub-arrays with shape [nOut]
        if(h)
            hSet = new ResultSet(h->allTensorsAlongDimension(dims));            //  sub-arrays with shape [nOut]
        if(ht)
            htSet = new ResultSet(ht->allTensorsAlongDimension({1}));          //  sub-arrays with shape [nOut]
    }
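
    // For instance, with dataFormat == 0 (TNS, x is [sL, bS, nIn]) and no seqLen, evalDimsToExclude(3, {0})
    // yields dims = {1,2}, so xSet holds sL sub-arrays of shape [bS, nIn], one per time step; with seqLen
    // given, dims = {2} and xSet holds sL*bS sub-arrays of shape [nIn], addressed via getBatchTimeTotalIndex.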

    // loops
    if(forward) {

        if(!seqLen) {

            if(!h) {    // seqLen and h are absent

                lstmLayerCell(xSet->at(0), Wx, Wr, b, h0, c0, Wp, params, ht, ct); // first time step
                for (Nd4jLong t = 1; t < sL; ++t)
                    lstmLayerCell(xSet->at(t), Wx, Wr, b, ht, ct, Wp, params, ht, ct); // rest time steps
            }
            else {      // seqLen is absent and h is present

                lstmLayerCell(xSet->at(0), Wx, Wr, b, h0, c0, Wp, params, hSet->at(0), ct); // first time step
                for (Nd4jLong t = 1; t < sL; ++t)
                    lstmLayerCell(xSet->at(t), Wx, Wr, b, hSet->at(t - 1), ct, Wp, params, hSet->at(t), ct); // rest time steps

                if(hL)
                    hL->assign(hSet->at(sL - 1));     // assign last output to hL if it is not nullptr
            }
        }
        else {

            if(!h) {    // seqLen is present and h is absent

                for (Nd4jLong e = 0; e < bS; ++e) {

                    const int limit = seqLen->e<int>(e);

                    if(limit == 0) {
                        if(cL)
                            ctSet->at(e)->nullify();
                        if(hL)
                            htSet->at(e)->nullify();
                        continue;
                    }

                    auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, 0, e);
                    lstmLayerCell(xSet->at(ind), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // first time step

                    for (int t = 1; t < limit; ++t) {
                        ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
                        lstmLayerCell(xSet->at(ind), Wx, Wr, b, htSet->at(e), ctSet->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // rest time steps
                    }
                }
            }
            else { // seqLen and h are present

                for (Nd4jLong e = 0; e < bS; ++e) {

                    int limit = seqLen->e<int>(e);

                    if(limit == 0) {

                        tensorAlongTimeBatchDims(*h, dataFormat, 0,0, e,e+1).nullify();     // nullify for given e and whole time range

                        if(cL)
                            ctSet->at(e)->nullify();
                        if(hL)
                            htSet->at(e)->nullify();

                        continue;
                    }

                    auto indPrev = getBatchTimeTotalIndex(dataFormat, sL, bS, 0, e);
                    lstmLayerCell(xSet->at(indPrev), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, hSet->at(indPrev), ctSet->at(e)); // first time step

                    for (int t = 1; t < limit; ++t) {
                        auto indCurr = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
                        lstmLayerCell(xSet->at(indCurr), Wx, Wr, b, hSet->at(indPrev), ctSet->at(e), Wp, params, hSet->at(indCurr), ctSet->at(e)); // rest time steps
                        indPrev = indCurr;
                    }

                    if(hL)
                        htSet->at(e)->assign(hSet->at(indPrev));     // assign last output to hL if hL is not nullptr

                    if(limit != sL)
                        tensorAlongTimeBatchDims(*h, dataFormat, limit,sL, e,e+1).nullify();     // nullify for given e and time range [limit, sL)
                }
            }
        }
    }
    else {   // backward

         if(!seqLen) {

            if(!h) {    // seqLen and h are absent

                lstmLayerCell(xSet->at(sL - 1), Wx, Wr, b, h0, c0, Wp, params, ht, ct); // first time step
                for (Nd4jLong t = sL - 2; t >= 0; --t)
                    lstmLayerCell(xSet->at(t), Wx, Wr, b, ht, ct, Wp, params, ht, ct); // rest time steps
            }
            else {  // seqLen is absent and h is present

                lstmLayerCell(xSet->at(sL - 1), Wx, Wr, b, h0, c0, Wp, params, hSet->at(sL - 1), ct); // first time step
                for (Nd4jLong t = sL - 2; t >= 0; --t)
                    lstmLayerCell(xSet->at(t), Wx, Wr, b, hSet->at(t + 1), ct, Wp, params, hSet->at(t), ct); // rest time steps

                if(hL)
                    hL->assign(hSet->at(0));     // assign last output to hL if it is not nullptr
            }
        }
        else if(directionMode == 1) {   // strictly backward mode (not bidirectional), seqLen is present

            if(!h) {    // h is absent and seqLen is present

                for (Nd4jLong e = 0; e < bS; ++e) {

                    const int limit = seqLen->e<int>(e);

                    if(limit == 0) {
                        if(cL)
                            ctSet->at(e)->nullify();
                        if(hL)
                            htSet->at(e)->nullify();
                        continue;
                    }

                    auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, sL - 1, e);
                    lstmLayerCell(xSet->at(ind), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // first time step

                    for (Nd4jLong t = sL - 2; t >= sL - limit; --t) {
                        ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
                        lstmLayerCell(xSet->at(ind), Wx, Wr, b, htSet->at(e), ctSet->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // rest time steps
                    }
                }
            }
            else {  // seqLen and h are present

                for (Nd4jLong e = 0; e < bS; ++e) {

                    int limit = seqLen->e<int>(e);

                    if(limit == 0) {

                        tensorAlongTimeBatchDims(*h, dataFormat, 0,0, e,e+1).nullify();     // nullify for given e and whole time range

                        if(cL)
                            ctSet->at(e)->nullify();
                        if(hL)
                            htSet->at(e)->nullify();

                        continue;
                    }

                    auto indPrev = getBatchTimeTotalIndex(dataFormat, sL, bS, sL - 1, e);
                    lstmLayerCell(xSet->at(indPrev), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, hSet->at(indPrev), ctSet->at(e)); // first time step

                    for (Nd4jLong t = sL - 2; t >= sL - limit; --t) {
                        auto indCurr = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
                        lstmLayerCell(xSet->at(indCurr), Wx, Wr, b, hSet->at(indPrev), ctSet->at(e), Wp, params, hSet->at(indCurr), ctSet->at(e)); // rest time steps
                        indPrev = indCurr;
                    }

                    if(hL)
                        htSet->at(e)->assign(hSet->at(indPrev));     // assign last output to hL if it is not nullptr

                    if(limit != sL)
                        tensorAlongTimeBatchDims(*h, dataFormat, 0,sL-limit, e,e+1).nullify();    // nullify for given e and time range [0, sL-limit)
                }
            }
        }
        else {   // backward in bidirectional mode

            if(!h) {    // h is absent and seqLen is present

                for (Nd4jLong e = 0; e < bS; ++e) {

                    const int limit = seqLen->e<int>(e);

                    if(limit == 0) {
                        if(cL)
                            ctSet->at(e)->nullify();
                        if(hL)
                            htSet->at(e)->nullify();
                        continue;
                    }

                    auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, limit - 1, e);
                    lstmLayerCell(xSet->at(ind), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // first time step

                    for (int t = limit - 2; t >= 0; --t) {
                        ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
                        lstmLayerCell(xSet->at(ind), Wx, Wr, b, htSet->at(e), ctSet->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // rest time steps
                    }
                }
            }
            else {  // seqLen and h are present

                for (Nd4jLong e = 0; e < bS; ++e) {

                    int limit = seqLen->e<int>(e);

                    if(limit == 0) {

                        tensorAlongTimeBatchDims(*h, dataFormat, 0,0, e,e+1).nullify();     // nullify for given e and whole time range

                        if(cL)
                            ctSet->at(e)->nullify();
                        if(hL)
                            htSet->at(e)->nullify();

                        continue;
                    }

                    auto indPrev = getBatchTimeTotalIndex(dataFormat, sL, bS, limit - 1, e);
                    lstmLayerCell(xSet->at(indPrev), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, hSet->at(indPrev), ctSet->at(e)); // first time step

                    for (int t = limit - 2; t >= 0; --t) {
                        auto indCurr = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
                        lstmLayerCell(xSet->at(indCurr), Wx, Wr, b, hSet->at(indPrev), ctSet->at(e), Wp, params, hSet->at(indCurr), ctSet->at(e)); // rest time steps
                        indPrev = indCurr;
                    }

                    if(hL)
                        htSet->at(e)->assign(hSet->at(indPrev));     // assign last output to hL if it is not nullptr

                    if(limit != sL)
                        tensorAlongTimeBatchDims(*h, dataFormat, limit,sL, e,e+1).nullify();    // nullify for given e and time range [limit, sL)
                }
            }
        }
    }

    delete xSet;
    delete hSet;
    delete h0Set;
    delete c0Set;
    delete htSet;
    delete ctSet;

    if(!hI)
        delete h0;
    if(!cI)
        delete c0;
    if(!cL)
        delete ct;
    if(!h && !hL)
        delete ht;
}


//////////////////////////////////////////////////////////////////////////
void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
                         const NDArray* b, const NDArray* seqLen, NDArray* hI, NDArray* cI, const NDArray* Wp,
                         const NDArray* dLdh, const NDArray* dLdhL, const NDArray* dLdcL,
                         const std::vector<float>& params, const bool forward,
                         NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdb, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdWp) {

    // INPUTS:
    // x  - current input  [sL, bS, nIn],  [bS, sL, nIn],  [bS, nIn, sL],
    // Wx - input weights [nIn, 4*nOut]
    // Wr - recurrent weights [nOut, 4*nOut]
    // b  - biases [4*nOut], optional, may be nullptr
    // seqLen - [bS], optional, may be nullptr
    // hI - initial output  [bS, nOut], optional, may be nullptr
    // cI - initial cell state at time t-1 [bS, nOut], optional, may be nullptr
    // Wp - peephole weights [3*nOut], optional, may be nullptr
    // dLdh - gradient vs. output [sL, bS, nOut],  [bS, sL, nOut],  [bS, nOut, sL], optional, may be nullptr
    // dLdhL - gradient vs. output at last time step [bS, nOut], optional, may be nullptr
    // dLdcL - gradient vs. cell state at last time step [bS, nOut], optional, may be nullptr

    // OUTPUTS:
    // dLdx  - gradient vs. input [sL, bS, nIn],  [bS, sL, nIn],  [bS, nIn, sL]
    // dLdWx - gradient vs. input weights [nIn, 4*nOut]
    // dLdWr - gradient vs. recurrent weights [nOut, 4*nOut]
    // dLdb  - gradient vs. biases [4*nOut], optional, may be nullptr
    // dLdhI - gradient vs. initial output  [bS, nOut], optional, may be nullptr
    // dLdcI - gradient vs. initial cell state at time t-1 [bS, nOut], optional, may be nullptr
    // dLdWp - gradient vs. peephole weights [3*nOut], optional, may be nullptr

    // params = {dataFormat, directionMode, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta};
    // dataFormat: 0,3 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL]

    const int dataFormat    = params[0];
    const int directionMode = params[1];

    const int sL = dataFormat == 3 ?  x->sizeAt(0) : x->sizeAt(dataFormat);
    const int bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1);
    const int nOut = Wx->sizeAt(-1) / 4;

    const auto type = dLdh ? dLdh->dataType() : (dLdhL ? dLdhL->dataType() : dLdcL->dataType());

    auto dLdh0 = dLdhI;
    if(!hI)
        dLdh0 = new NDArray(x->ordering(), {bS, nOut}, type, x->getContext());     // this constructor nullifies array automatically

    auto dLdc0 = dLdcI;
    if(!cI)
        dLdc0 = new NDArray(x->ordering(), {bS, nOut}, type, x->getContext());     // this constructor nullifies array automatically

    NDArray z(x->ordering(), {sL, bS, 4*nOut}, type, x->getContext());
    NDArray a = z.ulike();
    NDArray h(x->ordering(), {sL+1, bS, nOut}, type, x->getContext());
    NDArray c = h.ulike();

    // create the sets of required sub-arrays (which ones are needed depends on whether seqLen is present)
    std::vector<int> dims;
    ResultSet *xSet(nullptr), *dLdxSet(nullptr), *hSet(nullptr), *cSet(nullptr), *zSet(nullptr), *aSet(nullptr), *dLdhSet(nullptr),
              *dLdh0Set(nullptr), *dLdc0Set(nullptr), *dLdhLSet(nullptr), *dLdcLSet(nullptr), *hISet(nullptr), *cISet(nullptr);

    if(!seqLen) {

        dims = ShapeUtils::evalDimsToExclude(x->rankOf(), {dataFormat < 3 ? dataFormat : 0});    // points to the bS and nIn/nOut axes

        xSet    = new ResultSet(x->allTensorsAlongDimension(dims));         // sub-arrays with shape [bS, nIn]
        dLdxSet = new ResultSet(dLdx->allTensorsAlongDimension(dims));      // sub-arrays with shape [bS, nIn]
        hSet    = new ResultSet(h.allTensorsAlongDimension({1, 2}));        // sub-arrays with shape [bS, nOut]
        cSet    = new ResultSet(c.allTensorsAlongDimension({1, 2}));        // sub-arrays with shape [bS, nOut]
        zSet    = new ResultSet(z.allTensorsAlongDimension({1, 2}));        // sub-arrays with shape [bS, 4*nOut]
        aSet    = new ResultSet(a.allTensorsAlongDimension({1, 2}));        // sub-arrays with shape [bS, 4*nOut]
        if(dLdh)
            dLdhSet = new ResultSet(dLdh->allTensorsAlongDimension(dims));  // sub-arrays with shape [bS, nOut]
    }
    else {

        dims = dataFormat == 2 ? std::vector<int>({1}) : std::vector<int>({2});    // points to the nIn/nOut axis

        xSet    = new ResultSet(x->allTensorsAlongDimension(dims));         // sub-arrays with shape [nIn]
        dLdxSet = new ResultSet(dLdx->allTensorsAlongDimension(dims));      // sub-arrays with shape [nIn]
        hSet    = new ResultSet(h.allTensorsAlongDimension({2}));           // sub-arrays with shape [nOut]
        cSet    = new ResultSet(c.allTensorsAlongDimension({2}));           // sub-arrays with shape [nOut]
        zSet    = new ResultSet(z.allTensorsAlongDimension({2}));           // sub-arrays with shape [4*nOut]
        aSet    = new ResultSet(a.allTensorsAlongDimension({2}));           // sub-arrays with shape [4*nOut]

        if(hI)
            hISet = new ResultSet(hI->allTensorsAlongDimension({1}));      // sub-arrays with shape [nOut]
        if(cI)
            cISet = new ResultSet(cI->allTensorsAlongDimension({1}));      // sub-arrays with shape [nOut]

        dLdh0Set = new ResultSet(dLdh0->allTensorsAlongDimension({1}));      // sub-arrays with shape [nOut]
        dLdc0Set = new ResultSet(dLdc0->allTensorsAlongDimension({1}));      // sub-arrays with shape [nOut]

        if(dLdh)
            dLdhSet = new ResultSet(dLdh->allTensorsAlongDimension(dims));  // sub-arrays with shape [nOut]
        if(dLdhL)
            dLdhLSet = new ResultSet(dLdhL->allTensorsAlongDimension({1}));  // sub-arrays with shape [nOut]
        if(dLdcL)
            dLdcLSet = new ResultSet(dLdcL->allTensorsAlongDimension({1}));  // sub-arrays with shape [nOut]
    }
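
    // Note on indexing in the seqLen branch below: h and c have shape [sL+1, bS, nOut] and their sub-array
    // sets are taken along the last axis only, so the [nOut] slice for time step t and batch element e sits
    // at position t*bS + e (and (t+1)*bS + e for the next state); z and a ([sL, bS, 4*nOut]) are addressed
    // the same way.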


    // loops
    if(forward) {

        if(!seqLen) {   // seqLen is absent

            if(hI)
                hSet->at(0)->assign(hI);
            else
                hSet->at(0)->nullify();
            if(cI)
                cSet->at(0)->assign(cI);
            else
                cSet->at(0)->nullify();

            // ff
            for (int t = 0; t < sL; ++t)
                lstmLayerCell(xSet->at(t), Wx, Wr, b, hSet->at(t), cSet->at(t), Wp, params, zSet->at(t), aSet->at(t), hSet->at(t+1), cSet->at(t+1));

            // bp
            for (int t = sL-1; t >= 0; --t) {
                const NDArray* dLdhh  = dLdh ? dLdhSet->at(t) : nullptr;
                const NDArray* dLdhhL = (t == sL-1 && dLdhL) ? dLdhL : nullptr;
                const NDArray* dLdccL = (t == sL-1 && dLdcL) ? dLdcL : nullptr;
                lstmLayerCellBp(xSet->at(t), Wx, Wr, b, hSet->at(t), cSet->at(t), Wp, dLdhh, dLdhhL, dLdccL,
                                zSet->at(t), aSet->at(t), cSet->at(t+1), params, dLdxSet->at(t), dLdWx, dLdWr, dLdh0, dLdc0, dLdb, dLdWp);
            }
        }
        else {  // seqLen is present

            for (int e = 0; e < bS; ++e) {

                const int limit = seqLen->e<int>(e);

                if(limit == 0) {
                    tensorAlongTimeBatchDims(*dLdx, dataFormat, 0,0, e,e+1).nullify();     // nullify for given e and whole time range
                    continue;
                }

                if(hI)
                    hSet->at(e)->assign(hISet->at(e));
                else
                    hSet->at(e)->nullify();
                if(cI)
                    cSet->at(e)->assign(cISet->at(e));
                else
                    cSet->at(e)->nullify();

                // ff
                for (int t = 0; t < limit; ++t)
                    lstmLayerCell(xSet->at(getBatchTimeTotalIndex(dataFormat, sL, bS, t, e)), Wx, Wr, b, hSet->at(t*bS + e), cSet->at(t*bS + e), Wp, params,
                                  zSet->at(t*bS + e), aSet->at(t*bS + e), hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e));

                // bp
                for (int t = limit-1; t >= 0; --t) {
                    const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
                    const NDArray* dLdhh  = dLdh ? dLdhSet->at(ind) : nullptr;
                    const NDArray* dLdhhL = (t == limit-1 && dLdhL) ? dLdhLSet->at(e) : nullptr;
                    const NDArray* dLdccL = (t == limit-1 && dLdcL) ? dLdcLSet->at(e) : nullptr;
                    lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at(t*bS + e), cSet->at(t*bS + e), Wp, dLdhh, dLdhhL, dLdccL,
                                    zSet->at(t*bS + e), aSet->at(t*bS + e), cSet->at((t+1)*bS + e), params, dLdxSet->at(ind), dLdWx, dLdWr,
                                    dLdh0Set->at(e), dLdc0Set->at(e), dLdb, dLdWp);
                }

                if(limit != sL)
                    tensorAlongTimeBatchDims(*dLdx, dataFormat, limit,sL, e,e+1).nullify();     // nullify for given e and time range [limit, sL)
            }
        }
    }
    else {   // backward or bidirectional

        if(!seqLen) {     // backward or bidirectional, seqLen is absent

            if(hI)
                hSet->at(sL)->assign(hI);
            else
                hSet->at(sL)->nullify();
            if(cI)
                cSet->at(sL)->assign(cI);
            else
                cSet->at(sL)->nullify();

            // ff
            for (int t = sL-1; t >= 0; --t)
                lstmLayerCell(xSet->at(t), Wx, Wr, b, hSet->at(t+1), cSet->at(t+1), Wp, params, zSet->at(t), aSet->at(t), hSet->at(t), cSet->at(t));

            // bp
            for (int t = 0; t < sL; ++t) {
                const NDArray* dLdhh  = dLdh ? dLdhSet->at(t) : nullptr;
                const NDArray* dLdhhL = (t == 0 && dLdhL) ? dLdhL : nullptr;
                const NDArray* dLdccL = (t == 0 && dLdcL) ? dLdcL : nullptr;
                lstmLayerCellBp(xSet->at(t), Wx, Wr, b, hSet->at(t+1), cSet->at(t+1), Wp,  dLdhh, dLdhhL, dLdccL,
                                zSet->at(t), aSet->at(t), cSet->at(t), params, dLdxSet->at(t), dLdWx, dLdWr, dLdh0, dLdc0, dLdb, dLdWp);
            }
        }
        else if(directionMode == 1) {   // backward, seqLen is present

            for (int e = 0; e < bS; ++e) {

                const int limit = seqLen->e<int>(e);

                if(limit == 0) {
                    tensorAlongTimeBatchDims(*dLdx, dataFormat, 0,0, e,e+1).nullify();     // nullify for given e and whole time range
                    continue;
                }

                if(hI)
                    hSet->at(sL*bS + e)->assign(hISet->at(e));
                else
                    hSet->at(sL*bS + e)->nullify();
                if(cI)
                    cSet->at(sL*bS + e)->assign(cISet->at(e));
                else
                    cSet->at(sL*bS + e)->nullify();

                // ff
                for (int t = sL - 1; t >= sL-limit; --t)
                    lstmLayerCell(xSet->at(getBatchTimeTotalIndex(dataFormat, sL, bS, t, e)), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, params,
                                  zSet->at(t*bS + e), aSet->at(t*bS + e), hSet->at(t*bS + e), cSet->at(t*bS + e));

                // bp
                for (int t = sL-limit; t < sL; ++t) {
                    const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
                    const NDArray* dLdhh  = dLdh ? dLdhSet->at(ind) : nullptr;
                    const NDArray* dLdhhL = (t == sL-limit && dLdhL) ? dLdhLSet->at(e) : nullptr;
                    const NDArray* dLdccL = (t == sL-limit && dLdcL) ? dLdcLSet->at(e) : nullptr;
                    lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, dLdhh, dLdhhL, dLdccL,
                                    zSet->at(t*bS + e), aSet->at(t*bS + e), cSet->at(t*bS + e), params, dLdxSet->at(ind), dLdWx, dLdWr,
                                    dLdh0Set->at(e), dLdc0Set->at(e), dLdb, dLdWp);
                }

                if(limit != sL)
                    tensorAlongTimeBatchDims(*dLdx, dataFormat, 0,sL-limit, e,e+1).nullify();    // nullify for given e and time range [0, sL-limit)
            }
        }
        else {   // bidirectional mode, seqLen is present

            for (int e = 0; e < bS; ++e) {

                const int limit = seqLen->e<int>(e);

                if(limit == 0) {
                    tensorAlongTimeBatchDims(*dLdx, dataFormat, 0,0, e,e+1).nullify();     // nullify for given e and whole time range
                    continue;
                }

                if(hI)
                    h({limit,limit+1, e,e+1, 0,0}).assign(hISet->at(e));
                else
                    h({limit,limit+1, e,e+1, 0,0}).nullify();
                if(cI)
                    c({limit,limit+1, e,e+1, 0,0}).assign(cISet->at(e));
                else
                    c({limit,limit+1, e,e+1, 0,0}).nullify();

                // ff
                for (int t = limit - 1; t >= 0; --t)
                    lstmLayerCell(xSet->at(getBatchTimeTotalIndex(dataFormat, sL, bS, t, e)), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, params,
                                  zSet->at(t*bS + e), aSet->at(t*bS + e), hSet->at(t*bS + e), cSet->at(t*bS + e));

                // bp
                for (int t = 0; t < limit; ++t) {
                    const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
                    const NDArray* dLdhh  = dLdh ? dLdhSet->at(ind) : nullptr;
                    const NDArray* dLdhhL = (t == 0 && dLdhL) ? dLdhLSet->at(e) : nullptr;
                    const NDArray* dLdccL = (t == 0 && dLdcL) ? dLdcLSet->at(e) : nullptr;
                    lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, dLdhh, dLdhhL, dLdccL,
                                    zSet->at(t*bS + e), aSet->at(t*bS + e), cSet->at(t*bS + e), params, dLdxSet->at(ind), dLdWx, dLdWr,
                                    dLdh0Set->at(e), dLdc0Set->at(e), dLdb, dLdWp);
                }

                if(limit != sL)
                    tensorAlongTimeBatchDims(*dLdx, dataFormat, limit,sL, e,e+1).nullify();    // nullify for given e and time range [limit, sL)
            }
        }
    }

    delete xSet; delete dLdxSet; delete hSet; delete cSet; delete aSet; delete zSet;
    delete dLdhSet; delete dLdh0Set; delete dLdc0Set; delete dLdhLSet; delete dLdcLSet; delete hISet; delete cISet;

    if(!hI)
        delete dLdh0;
    if(!cI)
        delete dLdc0;
}


}
}
}



//////////////////////////////////////////////////////////////////////////
// void lstmLayerCellBp(const NDArray* x,    const NDArray* Wx, const NDArray* Wr,
//                      const NDArray* b,          NDArray* hI,       NDArray* cI, const NDArray* Wp, const NDArray* dLdh,
//                      const std::vector<float>& params, const bool firstIter,

//                      NDArray* dhIdcI, NDArray* dhIdWx, NDArray* dcIdWx, NDArray* dhIdWr, NDArray* dcIdWr,
//                      NDArray* dhIdb, NDArray* dcIdb, NDArray* dhIdWp, NDArray* dcIdWp,
//                      NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdb, NDArray* dLdWp) {

//     /************************ THIS IS NOT OPTIMIZED CODE ***********************************/
//     /** the objective is to provide math-readable code **/

//     // equations (no peephole connections)
//     // zi = x × Wxi + hI × Wri + bi
//     // zf = x × Wxf + hI × Wrf + bf
//     // zg = x × Wxg + hI × Wrg + bg
//     // zo = x × Wxo + hI × Wro + bo
//     // i = act(zi)
//     // f = act(zf)
//     // g = actC(zg)
//     // o = act(zo)
//     // c = clip(f * cI + i * g)
//     // h = o * actH(c)

//     // equations (peephole connections are present)
//     // zi = x × Wxi + hI × Wri + cI * Wpi + bi
//     // zf = x × Wxf + hI × Wrf + cI * Wpf + bf
//     // zg = x × Wxg + hI × Wrg + bg
//     // zo = x × Wxo + hI × Wro + c  * Wpo + bo
//     // i = act(zi)
//     // f = act(zf)
//     // g = actC(zg)
//     // o = act(zo)
//     // c = clip(f * cI + i * g)
//     // h = o * actH(c)
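//     // illustration of the i, f, g, o ordering (batched case, mirrors the feed-forward code further below):
//     // the four pre-activations are column slices of one fused product of shape [bS, 4*nOut]
//     //     auto z  = mmul(*x, *Wx) + mmul(*hI, *Wr) + *b;   // [bS, 4*nOut]
//     //     auto zi = z({0,0, 0,        nOut});              // input  gate
//     //     auto zf = z({0,0, nOut,   2*nOut});              // forget gate
//     //     auto zg = z({0,0, 2*nOut, 3*nOut});              // cell   gate
//     //     auto zo = z({0,0, 3*nOut, 4*nOut});              // output gate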

//     // IDs for activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus

//     // params[0] - dataFormat, ignore
//     // params[1] - directionMode, ignore
//     // params[2] - cell clipping value, if it = 0 then do not apply clipping

//     // params[3]  - activation ID for input (i), forget (f) and output (o) gates
//     // params[4]  - alpha value for gates activation
//     // params[5]  - beta value for gates activation

//     // params[6]  - activation ID for cell state (c)
//     // params[7]  - alpha value for cell state activation
//     // params[8]  - beta value for cell state activation

//     // params[9]  - activation ID for output (h)
//     // params[10] - alpha value for output activation
//     // params[11] - beta value for output activation
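//     // example: a conventional LSTM cell (sigmoid gates, tanh cell and output activations, no clipping)
//     // corresponds to
//     //     params = {dataFormat, directionMode, 0.f,   // no cell clipping
//     //               2.f, 0.f, 0.f,                    // gates:  sigmoid (alpha/beta unused)
//     //               0.f, 0.f, 0.f,                    // cell:   tanh    (alpha/beta unused)
//     //               0.f, 0.f, 0.f};                   // output: tanh    (alpha/beta unused)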

//     // INPUTS:
//     // x    - current input at time t, [bS, nIn] or [nIn] if seqLen != nullptr
//     // Wx   - input weights [nIn, 4*nOut]
//     // Wr   - recurrent weights [nOut, 4*nOut]
//     // b    - biases [4*nOut], optional, may be nullptr
//     // hI   - (ht-1) previous (initial) output at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr
//     // cI   - (ct-1) previous (initial) cell state at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr
//     // Wp   - peephole weights [3*nOut], optional, may be nullptr
//     // dLdh - loss derivative with respect to h, [bS, nOut] or [nOut] if seqLen != nullptr
//     // dhIdcI - derivative from previous time step, [bS, nOut] or [nOut] if seqLen != nullptr
//     // dhIdWx - derivative from previous time step (Jacobian), [nIn, 4*nOut, bS, nOut] or [nIn, 4*nOut, nOut] if seqLen != nullptr
//     // dcIdWx - derivative from previous time step (Jacobian), [nIn, 4*nOut, bS, nOut] or [nIn, 4*nOut, nOut] if seqLen != nullptr
//     // dhIdWr - derivative from previous time step (Jacobian), [nOut, 4*nOut, bS, nOut] or [nOut, 4*nOut, nOut] if seqLen != nullptr
//     // dcIdWr - derivative from previous time step (Jacobian), [nOut, 4*nOut, bS, nOut] or [nOut, 4*nOut, nOut] if seqLen != nullptr
//     // dcIdWp - derivative from previous time step, [3*nOut], optional, may be nullptr
//     // dhIdWp - derivative from previous time step, [3*nOut], optional, may be nullptr
//     // dcIdb  - derivative from previous time step, [4*nOut], optional, may be nullptr
//     // dhIdb  - derivative from previous time step, [4*nOut], optional, may be nullptr

//     // OUTPUTS:
//     // dLdx  - loss derivative with respect to x, [bS, nIn] or [nIn] if seqLen != nullptr
//     // dLdWx - loss derivative with respect to Wx, [nIn, 4*nOut]
//     // dLdWr - loss derivative with respect to Wr, [nOut, 4*nOut]
//     // dLdb  - loss derivative with respect to b, optional, may be nullptr, [4*nOut]
//     // dLdhI - loss derivative with respect to hI, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr
//     // dLdcI - loss derivative with respect to cI, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr
//     // dLdWp - loss derivative with respect to Wp, optional, may be nullptr,  [3*nOut]

//     // !!! dimension 4*nOut implies order i, f, g, o
//     // !!! dimension 3*nOut implies order i, f, o
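//     // shape example: with nIn = 3, nOut = 2, bS = 4 the batched case works with
//     //     x [4,3], hI/cI [4,2], Wx [3,8], Wr [2,8], b [8], Wp [6], dLdh [4,2], dLdx [4,3]
//     // and the per-sample case (seqLen != nullptr) simply drops the bS dimension:
//     //     x [3], hI/cI [2], dLdh [2], dLdx [3]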

//     // dcdzi = dcdi*didzi
//     // dcdzf = dcdf*dfdzf
//     // dcdzg = dcdg*dgdzg
//     // dhdzo = dhdo*dodzo

//     // dhdc = dhdc + Wp ? dhdzo*dzodc : 0           [bS, nOut]
//     // factor = dLdh*dhdc                           [bS, nOut]
//     // iFactor = factor*dcdzi                       [bS, nOut]
//     // fFactor = factor*dcdzf                       [bS, nOut]
//     // eFactor = factor*dcdzg                       [bS, nOut]
//     // oFactor = dLdh*dhdzo                         [bS, nOut]

//     // tempC   = dcdcI + Wp ? dcdzi*dzidcI + dcdzf*dzfdcI : 0;
//     // tempIFE = dcdzi×WriT + dcdzf×WrfT + dcdzg×WrgT
//     // tempO   = dhdzo×WroT

//     // dhIdcI = dhdc_from_previous_time_step

//     // dLdx   = iFactor×WxiT + fFactor×WxfT + eFactor×WxgT + oFactor×WxoT,      [bS, nIn]
//     // dLdhI  = iFactor×WriT + fFactor×WrfT + eFactor×WrgT + oFactor×WroT,      [bS, nOut]
//     // dLdcI  = factor*tempC + dLdhI * dhIdcI, dhIdcI=0 if firstIter,           [bS, nOut]
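//     // worked special case: with sigmoid gates and tanh cell/output activations (and no clipping) the
//     // activation derivatives above become
//     //     didzi = i*(1-i),  dfdzf = f*(1-f),  dodzo = o*(1-o),  dgdzg = 1 - g*g,  actH'(c) = 1 - actH(c)*actH(c)
//     // so e.g. dcdzi = g*i*(1-i) and dhdc (without the peephole term) = o*(1 - actH(c)*actH(c))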

//     // dcdWxi(dcIdWxi) = dcdzi*dzidWxi + tempIFE*dhIdWxi + tempC*dcIdWxi,           dcIdWxi=dhIdWxi= 0 if firstIter, [nIn, nOut, bS, nOut]
//     // dcdWxf(dcIdWxf) = dcdzf*dzfdWxf + tempIFE*dhIdWxf + tempC*dcIdWxf,           dcIdWxf=dhIdWxf= 0 if firstIter, [nIn, nOut, bS, nOut]
//     // dcdWxg(dcIdWxg) = dcdzg*dzgdWxg + tempIFE*dhIdWxg + tempC*dcIdWxg,           dcIdWxg=dhIdWxg= 0 if firstIter, [nIn, nOut, bS, nOut]
//     // dcdWxo(dcIdWxo) =       0       + tempIFE*dhIdWxo + tempC*dcIdWxo;           dcIdWxo=dhIdWxo= 0 if firstIter, [nIn, nOut, bS, nOut]

//     // dhdWxi(dhIdWxi) =       0       + dhdc*dcdWxi + tempO*dhIdWxi,           dhIdWxi= 0 if firstIter, [nIn, nOut, bS, nOut]
//     // dhdWxf(dhIdWxf) =       0       + dhdc*dcdWxf + tempO*dhIdWxf,           dhIdWxf= 0 if firstIter, [nIn, nOut, bS, nOut]
//     // dhdWxg(dhIdWxg) =       0       + dhdc*dcdWxg + tempO*dhIdWxg,           dhIdWxg= 0 if firstIter, [nIn, nOut, bS, nOut]
//     // dhdWxo(dhIdWxo) = dhdzo*dzodWxo + dhdc*dcdWxo + tempO*dhIdWxo,           dhIdWxo= 0 if firstIter, [nIn, nOut, bS, nOut]

//     // dhdWri(dhIdWri) =       0       + dhdc*dcdWri + tempO*dhIdWri,           dhIdWri= 0 if firstIter, [nOut, nOut, bS, nOut]
//     // dhdWrf(dhIdWrf) =       0       + dhdc*dcdWrf + tempO*dhIdWrf,           dhIdWrf= 0 if firstIter, [nOut, nOut, bS, nOut]
//     // dhdWrg(dhIdWrg) =       0       + dhdc*dcdWrg + tempO*dhIdWrg,           dhIdWrg= 0 if firstIter, [nOut, nOut, bS, nOut]
//     // dhdWro(dhIdWro) = dhdzo*dzodWro + dhdc*dcdWro + tempO*dhIdWro,           dhIdWro= 0 if firstIter, [nOut, nOut, bS, nOut]

//     // dcdWri(dcIdWri) = dcdzi*dzidWri + tempIFE*dhIdWri + tempC*dcIdWri,           dcIdWri=dhIdWri= 0 if firstIter, [nOut, nOut, bS, nOut]
//     // dcdWrf(dcIdWrf) = dcdzf*dzfdWrf + tempIFE*dhIdWrf + tempC*dcIdWrf,           dcIdWrf=dhIdWrf= 0 if firstIter, [nOut, nOut, bS, nOut]
//     // dcdWrg(dcIdWrg) = dcdzg*dzgdWrg + tempIFE*dhIdWrg + tempC*dcIdWrg,           dcIdWrg=dhIdWrg= 0 if firstIter, [nOut, nOut, bS, nOut]
//     // dcdWro(dcIdWro) =       0       + tempIFE*dhIdWro + tempC*dcIdWro;           dcIdWro=dhIdWro= 0 if firstIter, [nOut, nOut, bS, nOut]

//     // dcIdWpi = (dcdzi*cI + tempIFE*dhIdWpi + tempC*dcIdWpi).reduceAlongFirstDim,  dcIdWpi=dhIdWpi= 0 if firstIter, [bS, nOut]->reduce->[nOut]
//     // dcIdWpf = (dcdzf*cI + tempIFE*dhIdWpf + tempC*dcIdWpf).reduceAlongFirstDim,  dcIdWpf=dhIdWpf= 0 if firstIter, [bS, nOut]->reduce->[nOut]
//     // dcIdWpo = (0        + tempIFE*dhIdWpo + tempC*dcIdWpo).reduceAlongFirstDim,  dcIdWpo=dhIdWpo= 0 if firstIter, [bS, nOut]->reduce->[nOut]

//     // dhdWpi(dhIdWpi) =(   0    + dhdc*dcdWpi + tempO*dhIdWpi).reduceAlongFirstDim,           dhIdWpi= 0 if firstIter, [bS, nOut]->reduce->[nOut]
//     // dhdWpf(dhIdWpf) =(   0    + dhdc*dcdWpf + tempO*dhIdWpf).reduceAlongFirstDim,           dhIdWpf= 0 if firstIter, [bS, nOut]->reduce->[nOut]
//     // dhdWpo(dhIdWpo) =(dhdzo*c + dhdc*dcdWpo + tempO*dhIdWpo).reduceAlongFirstDim,           dhIdWpo= 0 if firstIter, [bS, nOut]->reduce->[nOut]

//     // dcdbi(dcIdbi) = (dcdzi + tempIFE*dhIdbi + tempC*dcIdbi).reduceAlongFirstDim,           dcIdbi=dhIdbi= 0 if firstIter, [bS, nOut]->reduce->[nOut]
//     // dcdbf(dcIdbf) = (dcdzf + tempIFE*dhIdbf + tempC*dcIdbf).reduceAlongFirstDim,           dcIdbf=dhIdbf= 0 if firstIter, [bS, nOut]->reduce->[nOut]
//     // dcdbg(dcIdbg) = (dcdzg + tempIFE*dhIdbg + tempC*dcIdbg).reduceAlongFirstDim,           dcIdbg=dhIdbg= 0 if firstIter, [bS, nOut]->reduce->[nOut]
//     // dcdbo(dcIdbo) = (  0   + tempIFE*dhIdbo + tempC*dcIdbo).reduceAlongFirstDim;           dcIdbo=dhIdbo= 0 if firstIter, [bS, nOut]->reduce->[nOut]

//     // dhdbi(dhIdbi) = (  0   + dhdc*dcdbi + tempO*dhIdbi).reduceAlongFirstDim,           dhIdbi= 0 if firstIter, [bS, nOut]->reduce->[nOut]
//     // dhdbf(dhIdbf) = (  0   + dhdc*dcdbf + tempO*dhIdbf).reduceAlongFirstDim,           dhIdbf= 0 if firstIter, [bS, nOut]->reduce->[nOut]
//     // dhdbg(dhIdbg) = (  0   + dhdc*dcdbg + tempO*dhIdbg).reduceAlongFirstDim,           dhIdbg= 0 if firstIter, [bS, nOut]->reduce->[nOut]
//     // dhdbo(dhIdbo) = (dhdzo + dhdc*dcdbo + tempO*dhIdbo).reduceAlongFirstDim,           dhIdbo= 0 if firstIter, [bS, nOut]->reduce->[nOut]
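//     // note on the Jacobian layout used above (interpretation based on the shape comments): the first two
//     // dimensions index the weight entry and the trailing dimensions index the (batch, output) element it
//     // influences, i.e. dhIdWx(i, j, b, k) ~ d h(b, k) / d Wx(i, j); the reduceAlongFirstDim steps simply
//     // sum the bias/peephole contributions over the batch dimension bS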

//     const Nd4jLong nOut = Wx->sizeAt(-1) / 4;

//     NDArray *Wpi(nullptr), *Wpf(nullptr), *Wpo(nullptr), *dcIdWpi(nullptr), *dcIdWpf(nullptr), *dcIdWpo(nullptr), *dhIdWpi(nullptr), *dhIdWpf(nullptr), *dhIdWpo(nullptr);
//     if(Wp) {
//         Wpi = new NDArray((*Wp)({0, nOut}));
//         Wpf = new NDArray((*Wp)({nOut, 2*nOut}));
//         Wpo = new NDArray((*Wp)({2*nOut, 3*nOut}));
//         dhIdWpi = new NDArray((*dhIdWp)({0, nOut}));
//         dhIdWpf = new NDArray((*dhIdWp)({nOut, 2*nOut}));
//         dhIdWpo = new NDArray((*dhIdWp)({2*nOut, 3*nOut}));
//         dcIdWpi = new NDArray((*dcIdWp)({0, nOut}));
//         dcIdWpf = new NDArray((*dcIdWp)({nOut, 2*nOut}));
//         dcIdWpo = new NDArray((*dcIdWp)({2*nOut, 3*nOut}));
//     }

//     NDArray *dcIdbi(nullptr), *dcIdbf(nullptr), *dcIdbg(nullptr), *dcIdbo(nullptr), *dhIdbi(nullptr), *dhIdbf(nullptr), *dhIdbg(nullptr), *dhIdbo(nullptr);
//     if(b) {
//         dhIdbi = new NDArray((*dhIdb)({0, nOut}));
//         dhIdbf = new NDArray((*dhIdb)({nOut, 2*nOut}));
//         dhIdbg = new NDArray((*dhIdb)({2*nOut, 3*nOut}));
//         dhIdbo = new NDArray((*dhIdb)({3*nOut, 4*nOut}));
//         dcIdbi = new NDArray((*dcIdb)({0, nOut}));
//         dcIdbf = new NDArray((*dcIdb)({nOut, 2*nOut}));
//         dcIdbg = new NDArray((*dcIdb)({2*nOut, 3*nOut}));
//         dcIdbo = new NDArray((*dcIdb)({3*nOut, 4*nOut}));
//     }

//     NDArray dhIdWxi = x->rankOf() == 1 ? (*dhIdWx)({0,0,  0,nOut,        0,0}) : (*dhIdWx)({0,0,  0,nOut,        0,0, 0,0});    // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr
//     NDArray dhIdWxf = x->rankOf() == 1 ? (*dhIdWx)({0,0,  nOut,2*nOut,   0,0}) : (*dhIdWx)({0,0,  nOut,2*nOut,   0,0, 0,0});    // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr
//     NDArray dhIdWxg = x->rankOf() == 1 ? (*dhIdWx)({0,0,  2*nOut,3*nOut, 0,0}) : (*dhIdWx)({0,0,  2*nOut,3*nOut, 0,0, 0,0});    // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr
//     NDArray dhIdWxo = x->rankOf() == 1 ? (*dhIdWx)({0,0,  3*nOut,4*nOut, 0,0}) : (*dhIdWx)({0,0,  3*nOut,4*nOut, 0,0, 0,0});    // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr

//     NDArray dhIdWri = x->rankOf() == 1 ? (*dhIdWr)({0,0,  0,nOut,        0,0}) : (*dhIdWr)({0,0,  0,nOut,        0,0, 0,0});    // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr
//     NDArray dhIdWrf = x->rankOf() == 1 ? (*dhIdWr)({0,0,  nOut,2*nOut,   0,0}) : (*dhIdWr)({0,0,  nOut,2*nOut,   0,0, 0,0});    // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr
//     NDArray dhIdWrg = x->rankOf() == 1 ? (*dhIdWr)({0,0,  2*nOut,3*nOut, 0,0}) : (*dhIdWr)({0,0,  2*nOut,3*nOut, 0,0, 0,0});    // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr
//     NDArray dhIdWro = x->rankOf() == 1 ? (*dhIdWr)({0,0,  3*nOut,4*nOut, 0,0}) : (*dhIdWr)({0,0,  3*nOut,4*nOut, 0,0, 0,0});    // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr

//     NDArray dcIdWxi = x->rankOf() == 1 ? (*dcIdWx)({0,0,  0,nOut,        0,0}) : (*dcIdWx)({0,0,  0,nOut,        0,0, 0,0});    // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr
//     NDArray dcIdWxf = x->rankOf() == 1 ? (*dcIdWx)({0,0,  nOut,2*nOut,   0,0}) : (*dcIdWx)({0,0,  nOut,2*nOut,   0,0, 0,0});    // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr
//     NDArray dcIdWxg = x->rankOf() == 1 ? (*dcIdWx)({0,0,  2*nOut,3*nOut, 0,0}) : (*dcIdWx)({0,0,  2*nOut,3*nOut, 0,0, 0,0});    // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr
//     NDArray dcIdWxo = x->rankOf() == 1 ? (*dcIdWx)({0,0,  3*nOut,4*nOut, 0,0}) : (*dcIdWx)({0,0,  3*nOut,4*nOut, 0,0, 0,0});    // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr

//     NDArray dcIdWri = x->rankOf() == 1 ? (*dcIdWr)({0,0,  0,nOut,        0,0}) : (*dcIdWr)({0,0,  0,nOut,        0,0, 0,0});    // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr
//     NDArray dcIdWrf = x->rankOf() == 1 ? (*dcIdWr)({0,0,  nOut,2*nOut,   0,0}) : (*dcIdWr)({0,0,  nOut,2*nOut,   0,0, 0,0});    // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr
//     NDArray dcIdWrg = x->rankOf() == 1 ? (*dcIdWr)({0,0,  2*nOut,3*nOut, 0,0}) : (*dcIdWr)({0,0,  2*nOut,3*nOut, 0,0, 0,0});    // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr
//     NDArray dcIdWro = x->rankOf() == 1 ? (*dcIdWr)({0,0,  3*nOut,4*nOut, 0,0}) : (*dcIdWr)({0,0,  3*nOut,4*nOut, 0,0, 0,0});    // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr

//     NDArray WxiT = (*Wx)({0,0,  0,     nOut}).transpose();              // [nOut, nIn]
//     NDArray WxfT = (*Wx)({0,0,  nOut,  2*nOut}).transpose();            // [nOut, nIn]
//     NDArray WxgT = (*Wx)({0,0,  2*nOut,3*nOut}).transpose();            // [nOut, nIn]
//     NDArray WxoT = (*Wx)({0,0,  3*nOut,4*nOut}).transpose();            // [nOut, nIn]

//     NDArray WriT = (*Wr)({0,0,  0,     nOut}).transpose();              // [nOut, nOut]
//     NDArray WrfT = (*Wr)({0,0,  nOut,  2*nOut}).transpose();            // [nOut, nOut]
//     NDArray WrgT = (*Wr)({0,0,  2*nOut,3*nOut}).transpose();            // [nOut, nOut]
//     NDArray WroT = (*Wr)({0,0,  3*nOut,4*nOut}).transpose();            // [nOut, nOut]

//     // ***** feed forward step ***** //

//     auto z = mmul(*x, *Wx) + mmul(*hI, *Wr);    //   [bs, nIn] * [nIn, 4*nOut] + [bs, nOut] * [nOut, 4*nOut] = [bS, 4*nOut]
//                                                 //or [nIn] * [nIn, 4*nOut] + [nOut] * [nOut, 4*nOut] = [4*nOut]
//     // add biases if they are given
//     if(b)
//         z += *b;                                    // broadcast [bS, 4*nOut] + [4*nOut] = [bS, 4*nOut](or[4*nOut])

//     auto zi = x->rankOf() == 1 ? z({0,        nOut}) : z({0,0, 0,        nOut});         // input gate i, [bS, nOut](or[nOut])
//     auto zf = x->rankOf() == 1 ? z({nOut,   2*nOut}) : z({0,0, nOut,   2*nOut});         // forget gate f, [bS, nOut](or[nOut])
//     auto zg = x->rankOf() == 1 ? z({2*nOut, 3*nOut}) : z({0,0, 2*nOut, 3*nOut});         // cell gate g, [bS, nOut](or[nOut])
//     auto zo = x->rankOf() == 1 ? z({3*nOut, 4*nOut}) : z({0,0, 3*nOut, 4*nOut});         // output gate o, [bS, nOut](or[nOut])

//     // peephole connections for input and forget gates
//     if(Wp) {
//         zi += *cI * *Wpi;      // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut])
//         zf += *cI * *Wpf;      // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut])
//     }

//     NDArray i = zi.ulike();        // [bS, nOut]
//     NDArray f = zf.ulike();        // [bS, nOut]
//     NDArray g = zg.ulike();        // [bS, nOut]
//     applyActivation(zi, params[3], params[4], params[5], i);
//     applyActivation(zf, params[3], params[4], params[5], f);
//     applyActivation(zg, params[6], params[7], params[8], g);

//     NDArray c = f * *cI + i * g;          // [bS, nOut] * [bS, nOut] + [bS, nOut] * [bS, nOut] = [bS, nOut](or[nOut])

//     // if clipping value is non-zero then cell state is clipped by this value prior to the cell output activation
//     if(params[2] != 0)
//         c.applyScalar(scalar::LstmClip, params[2], c);

//     // peephole connections for output gate
//     if(Wp)
//         zo += c * *Wpo;    // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut])

//     NDArray o = zo.ulike();     // [bS, nOut](or[nOut])
//     applyActivation(zo, params[3], params[4], params[5], o);

//     // ***** back prop step ***** //

//     NDArray dWxJacobian = mmulJacobianWeightsDeriv(nOut, *x);       // [nIn,  nOut, bS, nOut] (or [nIn,  nOut, nOut])
//     NDArray dWrJacobian = mmulJacobianWeightsDeriv(nOut, *hI);      // [nOut, nOut, bS, nOut] (or [nOut, nOut, nOut])
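//     // sketch of what these Jacobians hold (assuming mmulJacobianWeightsDeriv matches the shape comments above):
//     //     d (x × W)(b, k) / d W(i, j) = x(b, i) * delta(j == k)
//     // i.e. every [nIn, nOut] (or [nOut, nOut]) weight slice sees the corresponding input replicated along
//     // the output dimension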

//     // dodzo
//     NDArray dodzo = zo.ulike();                                              // [bS, nOut](or[nOut])
//     activationDeriv(zo, params[3], params[4], params[5], dodzo);

//     // dhdzo = dhdo*dodzo = actH(c)*dodzo
//     NDArray dhdzo = zo.ulike();                                              // [bS, nOut](or[nOut])
//     applyActivation(c, params[9], params[10], params[11], dhdzo);            // actH(c)
//     hI->assign(o*dhdzo);
//     dhdzo *= dodzo;

//     // dcdzi = dcdi*didzi
//     NDArray dcdzi = zi.ulike();                                             // [bS, nOut](or[nOut])
//     activationDeriv(zi, params[3], params[4], params[5], dcdzi);            // didzi
//     dcdzi *= g;                                                             // dcdi = g*clipDeriv

//     // dcdzf = dcdf*dfdzf
//     NDArray dcdzf = zf.ulike();                                             // [bS, nOut](or[nOut])
//     activationDeriv(zf, params[3], params[4], params[5], dcdzf);            // dfdzf
//     dcdzf *= *cI;                                                           // dcdf = cI*clipDeriv

//     // dcdzg = dcdg*dgdzg
//     NDArray dcdzg = zg.ulike();                                             // [bS, nOut](or[nOut])
//     activationDeriv(zg, params[6], params[7], params[8], dcdzg);            // dgdzg
//     dcdzg *= i;                                                             // dcdg = i*clipDeriv

//     // dcdcI
//     NDArray dcdcI = f.dup();                                                // [bS, nOut](or[nOut])

//     // take into account possible deposit from clipping derivative
//     clipDeriv(params[2], c, dcdzi, dcdzf, dcdzg, dcdcI);

//     // dzodc
//     NDArray* dzodc = Wpo;                           // [nOut], should be [bS, nOut] actually, however it will be broadcast appropriately in subsequent calculations (element-wise multiplication)

//     // dzidcI
//     NDArray* dzidcI = Wpi;                          // [nOut], should be [bS, nOut] actually, however it will be broadcast appropriately in subsequent calculations (element-wise multiplication)

//     // dzfdcI
//     NDArray* dzfdcI = Wpf;                          // [nOut], should be [bS, nOut] actually, however it will be broadcast appropriately in subsequent calculations (element-wise multiplication)

//     // dhdc
//     NDArray dhdc = c.ulike();
//     activationDeriv(c, params[9], params[10], params[11], dhdc);    // [bS, nOut]
//     dhdc *= o;
//     if(Wp)
//         dhdc += dhdzo* *dzodc;

//     NDArray factor = *dLdh * dhdc;

//     NDArray iFactor = factor*dcdzi;     // [bS, nOut](or[nOut])
//     NDArray fFactor = factor*dcdzf;     // [bS, nOut](or[nOut])
//     NDArray eFactor = factor*dcdzg;     // [bS, nOut](or[nOut])
//     NDArray oFactor = *dLdh *dhdzo;     // [bS, nOut](or[nOut])

//     NDArray tempC = dcdcI;
//     if(Wp)
//         tempC += dcdzi*(*dzidcI) + dcdzf*(*dzfdcI);

//     // dLdx
//     dLdx->assign(mmul(iFactor, WxiT) + mmul(fFactor, WxfT) + mmul(eFactor, WxgT) + mmul(oFactor, WxoT));    // [bS, nIn](or[nOut])
//     // NDArray temp = c.ulike();
//     // applyActivation(c, params[9], params[10], params[11], temp);            // actH(c)
//     // dLdx->assign(mmul(o*(1-temp*temp)*g*i*(1-i), WxiT) + mmul(o*(1-temp*temp)*(*cI)*f*(1-f), WxfT) + mmul(o*(1-temp*temp)*i*g*(1-g), WxgT) + mmul(temp*o*(1-o), WxoT));    // [bS, nIn](or[nOut])

//     // dLdhI
//     NDArray* dLdhII = dLdhI;
//     if(dLdcI && !dLdhI)
//         dLdhII = new NDArray(dLdcI->ulike());
//     dLdhII->assign(mmul(iFactor, WriT) + mmul(fFactor, WrfT) + mmul(eFactor, WrgT) + mmul(oFactor, WroT));   // [bS, nOut](or[nOut])

//     if(firstIter) {

//         // dLdcI
//         if(dLdcI)
//             dLdcI->assign(factor*tempC);                        // [bS, nOut](or[nOut])

//         // dcIdWxi(dcdWxi)
//         dcIdWxi.assign(dcdzi*dWxJacobian);      // broadcast [bS, nOut] * [nIn, nOut, bS, nOut] (or [nOut] * [nIn, nOut, nOut]);
//         // dcIdWxf(dcdWxf)
//         dcIdWxf.assign(dcdzf*dWxJacobian);
//         // dcIdWxg(dcdWxg)
//         dcIdWxg.assign(dcdzg*dWxJacobian);
//         // dcIdWxo(dcdWxo) = 0
//         dcIdWxo.nullify();

//         // dhIdWxi
//         dhIdWxi.assign(dhdc*dcIdWxi);           // broadcast [bS, nOut] * [nIn, nOut, bS, nOut] (or [nOut] * [nIn, nOut, nOut]);
//         // dhIdWxf
//         dhIdWxf.assign(dhdc*dcIdWxf);
//         // dhIdWxg
//         dhIdWxg.assign(dhdc*dcIdWxg);
//         // dhIdWxo
//         dhIdWxo.assign(dhdzo*dWxJacobian /*+ 0 */);

//         // dcIdWri(dcdWri)
//         dcIdWri.assign(dcdzi*dWrJacobian);      // broadcast [bS, nOut] * [nOut, nOut, bS, nOut] (or [nOut] * [nOut, nOut, nOut]);
//         // dcIdWrf(dcdWrf)
//         dcIdWrf.assign(dcdzf*dWrJacobian);
//         // dcIdWrg(dcdWrg)
//         dcIdWrg.assign(dcdzg*dWrJacobian);
//         // dcIdWro(dcdWro) = 0
//         dcIdWro.nullify();

//         // dhIdWri
//         dhIdWri.assign(dhdc*dcIdWri);           // broadcast [bS, nOut] * [nOut, nOut, bS, nOut] (or [nOut] * [nOut, nOut, nOut]);
//         // dhIdWrf
//         dhIdWrf.assign(dhdc*dcIdWrf);
//         // dhIdWrg
//         dhIdWrg.assign(dhdc*dcIdWrg);
//         // dhIdWro
//         dhIdWro.assign(dhdzo*dWrJacobian /*+ 0 */);

//         if(Wp && x->rankOf() == 1) {
//             // dcIdWpi
//             dcIdWpi->assign(dcdzi*(*cI));       // [nOut] * [nOut]
//             // dcIdWpf
//             dcIdWpf->assign(dcdzf*(*cI));       // [nOut] * [nOut]
//             // dcIdWpo
//             dcIdWpo->nullify();                 // [nOut]

//             // dhdWpi
//             dhIdWpi->assign(dhdc*(*dcIdWpi));     // [nOut] * [nOut]
//             // dhdWpf
//             dhIdWpf->assign(dhdc*(*dcIdWpf));     // [nOut] * [nOut]
//             // dhdWpo
//             dhIdWpo->assign(dhdzo*c /* +0*/);     // [nOut] * [nOut]
//         }
//         else if(Wp) {
//             // dcIdWpi
//             (dcdzi*(*cI)).reduceAlongDimension(reduce::Sum, *dcIdWpi, {0});       // [bS, nOut]->reduce->[nOut]
//             // dcIdWpf
//             (dcdzf*(*cI)).reduceAlongDimension(reduce::Sum, *dcIdWpf, {0});       // [bS, nOut]->reduce->[nOut]
//             // dcIdWpo
//             dcIdWpo->nullify();                                                   // [nOut]

//             // dhIdWpi
//             (*dLdh*dhdc*(dcdzi*(*cI))).reduceAlongDimension(reduce::Sum, *dhIdWpi, {0});     // ([bS, nOut] * [nOut])->reduce->[nOut]
//             // dhIdWpf
//             (*dLdh*dhdc*(dcdzf*(*cI))).reduceAlongDimension(reduce::Sum, *dhIdWpf, {0});     // ([bS, nOut] * [nOut])->reduce->[nOut]
//             // dhIdWpo
//             (*dLdh*dhdzo*c /* +0*/).reduceAlongDimension(reduce::Sum, *dhIdWpo, {0});     // ([bS, nOut] * [nOut])->reduce->[nOut]
//         }

//         if(b && x->rankOf() == 1) {
//             // dcIdbi
//             dcIdbi->assign(dcdzi);           // [nOut]
//             // dcIdbf
//             dcIdbf->assign(dcdzf);           // [nOut]
//             // dcIdbg
//             dcIdbg->assign(dcdzg);           // [nOut]
//             // dcIdbo
//             dcIdbo->nullify();               // [nOut]

//             //dhIdbi
//             dhIdbi->assign(dhdc*(*dcIdbi));     // [nOut]
//             //dhIdbf
//             dhIdbf->assign(dhdc*(*dcIdbf));     // [nOut]
//             //dhIdbg
//             dhIdbg->assign(dhdc*(*dcIdbg));     // [nOut]
//             //dhIdbo
//             dhIdbo->assign(dhdzo);              // [nOut]

//         }
//         else if(b) {
//             // dcIdbi
//             dcdzi.reduceAlongDimension(reduce::Sum, *dcIdbi, {0});       // [bS, nOut]->reduce->[nOut]
//             // dcIdbf
//             dcdzf.reduceAlongDimension(reduce::Sum, *dcIdbf, {0});       // [bS, nOut]->reduce->[nOut]
//             // dcIdbg
//             dcdzg.reduceAlongDimension(reduce::Sum, *dcIdbg, {0});       // [bS, nOut]->reduce->[nOut]
//             // dcIdbo
//             dcIdbo->nullify();                  // [nOut]

//             //dhIdbi
//             (*dLdh*dhdc*dcdzi).reduceAlongDimension(reduce::Sum, *dhIdbi, {0});     // ([bS, nOut] * [nOut])->reduce->[nOut]
//             //dhIdbf
//             (*dLdh*dhdc*dcdzf).reduceAlongDimension(reduce::Sum, *dhIdbf, {0});     // ([bS, nOut] * [nOut])->reduce->[nOut]
//             //dhIdbg
//             (*dLdh*dhdc*dcdzg).reduceAlongDimension(reduce::Sum, *dhIdbg, {0});     // ([bS, nOut] * [nOut])->reduce->[nOut]
//             //dhIdbo
//             (*dLdh*dhdzo).reduceAlongDimension(reduce::Sum, *dhIdbo, {0});     // ([bS, nOut] * [nOut])->reduce->[nOut]

//         }
//     }
//     else {

//         NDArray tempIFE = mmul(dcdzi, WriT) + mmul(dcdzf, WrfT) + mmul(dcdzg, WrgT);
//         NDArray tempO   = mmul(dhdzo, WroT);

//         // dLdcI
//         if(dLdcI)
//             dLdcI->assign(factor*tempC + (*dLdhII)*(*dhIdcI));

//         // dcIdWxi(dcdWxi)
//         dcIdWxi.assign(dcdzi*dWxJacobian + tempIFE*dhIdWxi + tempC*dcIdWxi);       // broadcast [bS, nOut] * [nIn, nOut, bS, nOut](or [nOut] * [nIn, nOut, nOut]);
//         // dcIdWxf(dcdWxf)
//         dcIdWxf.assign(dcdzf*dWxJacobian + tempIFE*dhIdWxf + tempC*dcIdWxf);
//         // dcIdWxg(dcdWxg)
//         dcIdWxg.assign(dcdzg*dWxJacobian + tempIFE*dhIdWxg + tempC*dcIdWxg);
//         // dcIdWxo(dcdWxo)
//         dcIdWxo.assign(/* 0 + */tempIFE * dhIdWxo + tempC*dcIdWxo);

//          // dhIdWxi
//         dhIdWxi.assign(dhdc*dcIdWxi + tempO*dhIdWxi);                           // broadcast [bS, nOut] * [nIn, nOut, bS, nOut](or [nOut] * [nIn, nOut, nOut]);
//         // dhIdWxf
//         dhIdWxf.assign(dhdc*dcIdWxf + tempO*dhIdWxf);
//         // dhIdWxg
//         dhIdWxg.assign(dhdc*dcIdWxg + tempO*dhIdWxg);
//         // dhIdWxo
//         dhIdWxo.assign(dhdzo*dWxJacobian + dhdc*dcIdWxo + tempO*dhIdWxo);

//         // dcIdWri(dcdWri)
//         dcIdWri.assign(dcdzi*dWrJacobian + tempIFE*dhIdWri + tempC*dcIdWri);      // broadcast [bS, nOut] * [nOut, nOut, bS, nOut] (or [nOut] * [nOut, nOut, nOut]);
//         // dcIdWrf(dcdWrf)
//         dcIdWrf.assign(dcdzf*dWrJacobian + tempIFE*dhIdWrf + tempC*dcIdWrf);
//         // dcIdWrg(dcdWrg)
//         dcIdWrg.assign(dcdzg*dWrJacobian + tempIFE*dhIdWrg + tempC*dcIdWrg);
//         // dcIdWro(dcdWro)
//         dcIdWro.assign(/* 0 + */tempIFE * dhIdWro + tempC*dcIdWro);

//         // dhIdWri
//         dhIdWri.assign(dhdc*dcIdWri + tempO*dhIdWri);                           // broadcast [bS, nOut] * [nOut, nOut, bS, nOut] (or [nOut] * [nOut, nOut, nOut]);
//         // dhIdWrf
//         dhIdWrf.assign(dhdc*dcIdWrf + tempO*dhIdWrf);
//         // dhIdWrg
//         dhIdWrg.assign(dhdc*dcIdWrg + tempO*dhIdWrg);
//         // dhIdWro
//         dhIdWro.assign(dhdzo*dWrJacobian + dhdc*dcIdWro + tempO*dhIdWro);

//         if(Wp && x->rankOf() == 1) {
//             // dcIdWpi
//             dcIdWpi->assign(dcdzi*(*cI) + tempIFE*(*dhIdWpi) + tempC*(*dcIdWpi));       // [nOut] * [nOut]
//             // dcIdWpf
//             dcIdWpf->assign(dcdzf*(*cI) + tempIFE*(*dhIdWpf) + tempC*(*dcIdWpf));       // [nOut] * [nOut]
//             // dcIdWpo
//             dcIdWpo->assign(/*  0 +  */   tempIFE*(*dhIdWpo) + tempC*(*dcIdWpo));       // [nOut] * [nOut]

//             // dhdWpi
//             dhIdWpi->assign(dhdc*(*dcIdWpi) + tempO*(*dhIdWpi));                // [nOut] * [nOut]
//             // dhdWpf
//             dhIdWpf->assign(dhdc*(*dcIdWpf) + tempO*(*dhIdWpf));                // [nOut] * [nOut]
//             // dhdWpo
//             dhIdWpo->assign(dhdzo*c + dhdc*(*dcIdWpo) + tempO*(*dhIdWpo));      // [nOut] * [nOut]
//         }
//         else if(Wp) {
//             // dcIdWpi
//             (dcdzi*(*cI) + tempIFE*(*dhIdWpi) + tempC*(*dcIdWpi)).reduceAlongDimension(reduce::Sum, *dcIdWpi, {0});   // ([bS, nOut] * [nOut])->reduce->[nOut]
//             // dcIdWpf
//             (dcdzf*(*cI) + tempIFE*(*dhIdWpf) + tempC*(*dcIdWpf)).reduceAlongDimension(reduce::Sum, *dcIdWpf, {0});   // ([bS, nOut] * [nOut])->reduce->[nOut]
//             // dcIdWpo
//             (/*  0 +  */   tempIFE*(*dhIdWpo) + tempC*(*dcIdWpo)).reduceAlongDimension(reduce::Sum, *dcIdWpo, {0});   // ([bS, nOut] * [nOut])->reduce->[nOut]

//             // dhIdWpi
//             (dhdc*(*dcIdWpi) + tempO*(*dhIdWpi)).reduceAlongDimension(reduce::Sum, *dhIdWpi, {0});              // ([bS, nOut] * [nOut])->reduce->[nOut]
//             // dhIdWpf
//             (dhdc*(*dcIdWpf) + tempO*(*dhIdWpf)).reduceAlongDimension(reduce::Sum, *dhIdWpf, {0});              // ([bS, nOut] * [nOut])->reduce->[nOut]
//             // dhIdWpo
//             (dhdzo*c + dhdc*(*dcIdWpo) + tempO*(*dhIdWpo)).reduceAlongDimension(reduce::Sum, *dhIdWpo, {0});    // ([bS, nOut] * [nOut])->reduce->[nOut]
//         }

//         if(b && x->rankOf() == 1) {
//             // dcIdbi
//             dcIdbi->assign(dcdzi  + tempIFE*(*dhIdbi) + tempC*(*dcIdbi));           // [nOut]
//             // dcIdbf
//             dcIdbf->assign(dcdzf  + tempIFE*(*dhIdbf) + tempC*(*dcIdbf));           // [nOut]
//             // dcIdbg
//             dcIdbg->assign(dcdzg  + tempIFE*(*dhIdbg) + tempC*(*dcIdbg));           // [nOut]
//             // dcIdbo
//             dcIdbo->assign(/*0+*/   tempIFE*(*dhIdbo) + tempC*(*dcIdbo));           // [nOut]

//             //dhIdbi
//             dhIdbi->assign(dhdc*(*dcIdbi) + tempO*(*dhIdbi));           // [nOut]
//             //dhIdbf
//             dhIdbf->assign(dhdc*(*dcIdbf) + tempO*(*dhIdbf));           // [nOut]
//             //dhIdbg
//             dhIdbg->assign(dhdc*(*dcIdbg) + tempO*(*dhIdbg));           // [nOut]
//             //dhIdbo
//             dhIdbo->assign(dhdzo + dhdc*(*dcIdbo) + tempO*(*dhIdbo));   // [nOut]

//         }
//         else if(b) {
//             // dcIdbi
//             (dcdzi  + tempIFE*(*dhIdbi) + tempC*(*dcIdbi)).reduceAlongDimension(reduce::Sum, *dcIdbi, {0});       // [bS, nOut]->reduce->[nOut]
//             // dcIdbf
//             (dcdzf  + tempIFE*(*dhIdbf) + tempC*(*dcIdbf)).reduceAlongDimension(reduce::Sum, *dcIdbf, {0});       // [bS, nOut]->reduce->[nOut]
//             // dcIdbg
//             (dcdzg  + tempIFE*(*dhIdbg) + tempC*(*dcIdbg)).reduceAlongDimension(reduce::Sum, *dcIdbg, {0});       // [bS, nOut]->reduce->[nOut]
//             // dcIdbo
//             (/*0+*/   tempIFE*(*dhIdbo) + tempC*(*dcIdbo)).reduceAlongDimension(reduce::Sum, *dcIdbo, {0});       // [bS, nOut]->reduce->[nOut]

//             //dhIdbi
//             (dhdc*(*dcIdbi) + tempO*(*dhIdbi)).reduceAlongDimension(reduce::Sum, *dhIdbi, {0});         // ([bS, nOut] * [nOut])->reduce->[nOut]
//             //dhIdbf
//             (dhdc*(*dcIdbf) + tempO*(*dhIdbf)).reduceAlongDimension(reduce::Sum, *dhIdbf, {0});         // ([bS, nOut] * [nOut])->reduce->[nOut]
//             //dhIdbg
//             (dhdc*(*dcIdbg) + tempO*(*dhIdbg)).reduceAlongDimension(reduce::Sum, *dhIdbg, {0});         // ([bS, nOut] * [nOut])->reduce->[nOut]
//             //dhIdbo
//             (dhdzo + dhdc*(*dcIdbo) + tempO*(*dhIdbo)).reduceAlongDimension(reduce::Sum, *dhIdbo, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]

//         }
//     }

//     const std::vector<int> dimsToExclude = x->rankOf() == 1 ? std::vector<int>({2}) : std::vector<int>({2, 3});

//     // dLdWxi, dLdWxf, dLdWxg, dLdWxo
//     (*dLdh*(*dhIdWx)).reduceAlongDimension(reduce::Sum, *dLdWx, dimsToExclude);

//     // dLdWri, dLdWrf, dLdWrg, dLdWro
//     (*dLdh*(*dhIdWr)).reduceAlongDimension(reduce::Sum, *dLdWr, dimsToExclude);

//     // dLdWpi, dLdWpf, dLdWpo
//     if(Wp) {
//         if(x->rankOf() == 1) {
//             (*dLdWp)({0, nOut}).assign(*dLdh*(*dhIdWpi));          // [nOut] * [nOut]
//             (*dLdWp)({nOut, 2*nOut}).assign(*dLdh*(*dhIdWpf));     // [nOut] * [nOut]
//             (*dLdWp)({2*nOut, 3*nOut}).assign(*dLdh*(*dhIdWpo));   // [nOut] * [nOut]
//         }
//         else {
//             // NDArray temp1 = (*dLdWp)({0, nOut});
//             // NDArray temp2 = (*dLdWp)({nOut, 2*nOut});
//             // NDArray temp3 = (*dLdWp)({2*nOut, 3*nOut});
//             // dhIdWpi->reduceAlongDimension(reduce::Sum, temp1, {0});       // ([bS, nOut] * [nOut])->reduce->[nOut]
//             // dhIdWpf->reduceAlongDimension(reduce::Sum, temp2, {0});       // ([bS, nOut] * [nOut])->reduce->[nOut]
//             // dhIdWpo->reduceAlongDimension(reduce::Sum, temp3, {0});       // ([bS, nOut] * [nOut])->reduce->[nOut]
//             (*dLdWp)({0, nOut}).assign(dhIdWpi);
//             (*dLdWp)({nOut, 2*nOut}).assign(dhIdWpf);
//             (*dLdWp)({2*nOut, 3*nOut}).assign(dhIdWpo);
//         }
//     }

//     // dLdbi, dLdbf, dLdbg, dLdbo
//     if(b) {
//         if(x->rankOf() == 1) {
//             (*dLdb)({0, nOut}).assign(*dLdh*(*dhIdbi));          // [nOut] * [nOut]
//             (*dLdb)({nOut, 2*nOut}).assign(*dLdh*(*dhIdbf));     // [nOut] * [nOut]
//             (*dLdb)({2*nOut, 3*nOut}).assign(*dLdh*(*dhIdbg));   // [nOut] * [nOut]
//             (*dLdb)({3*nOut, 4*nOut}).assign(*dLdh*(*dhIdbo));   // [nOut] * [nOut]
//         }
//         else {
//             // NDArray temp1 = (*dLdb)({0, nOut});
//             // NDArray temp2 = (*dLdb)({nOut, 2*nOut});
//             // NDArray temp3 = (*dLdb)({2*nOut, 3*nOut});
//             // NDArray temp4 = (*dLdb)({3*nOut, 4*nOut});
//             // (*dLdh*(*dhIdbi)).reduceAlongDimension(reduce::Sum, temp1, {0});       // ([bS, nOut] * [nOut])->reduce->[nOut]
//             // (*dLdh*(*dhIdbf)).reduceAlongDimension(reduce::Sum, temp2, {0});       // ([bS, nOut] * [nOut])->reduce->[nOut]
//             // (*dLdh*(*dhIdbg)).reduceAlongDimension(reduce::Sum, temp3, {0});       // ([bS, nOut] * [nOut])->reduce->[nOut]
//             // (*dLdh*(*dhIdbo)).reduceAlongDimension(reduce::Sum, temp3, {0});       // ([bS, nOut] * [nOut])->reduce->[nOut]
//             (*dLdb)({0, nOut}).assign(dhIdbi);
//             (*dLdb)({nOut, 2*nOut}).assign(dhIdbf);
//             (*dLdb)({2*nOut, 3*nOut}).assign(dhIdbg);
//             (*dLdb)({3*nOut, 4*nOut}).assign(dhIdbo);
//         }
//     }

//     //dhIdcI
//     if(dLdcI)
//         dhIdcI->assign(dhdc);

//     cI->assign(c);

//     if(dLdcI && !dLdhI)
//         delete dLdhII;
//     if(Wp) {
//         delete Wpi; delete Wpf; delete Wpo; delete dcIdWpi; delete dcIdWpf; delete dcIdWpo; delete dhIdWpi; delete dhIdWpf; delete dhIdWpo;
//     }
//     if(b) {
//         delete dcIdbi; delete dcIdbf; delete dcIdbg; delete dcIdbo; delete dhIdbi; delete dhIdbf; delete dhIdbg; delete dhIdbo;
//     }
// }