/******************************************************************************* * Copyright (c) 2015-2019 Skymind, Inc. * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ // // @author Yurii Shyrma (iuriish@yahoo.com) // // implementation of operation for LSTM cell with peephole connections: // http://www.bioinf.jku.at/publications/older/2604.pdf // S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997. // and // https://research.google.com/pubs/archive/43905.pdf // Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent neural network architectures for large scale acoustic modeling." INTERSPEECH, 2014. 
#include #include #include #include #include // #include // #include // #include // #include // #include // #include namespace sd { namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// static void applyActivation(const NDArray& x, const int opId, const float alpha, const float beta, NDArray& z) { switch (opId) { case 0: (const_cast(x)).applyTransform(transform::Tanh, z); break; case 1: (const_cast(x)).applyScalar(scalar::RELU, 0, z); break; case 2: (const_cast(x)).applyTransform(transform::Sigmoid, z); break; case 3: { ExtraArguments args({ static_cast(alpha), static_cast(beta)}); (const_cast(x)).applyTransform(transform::Affine, z, &args); break; } case 4: (const_cast(x)).applyScalar(scalar::LeakyRELU, alpha, z); break; case 5: thresholdRelu(x.getContext(), x, alpha, z); break; case 6: { ExtraArguments args({ static_cast(alpha), static_cast(beta)}); (const_cast(x)).applyTransform(transform::ScaledTanh, z, &args); break; } case 7: (const_cast(x)).applyTransform(transform::HardSigmoid, z); break; case 8: (const_cast(x)).applyScalar(scalar::ELU, alpha, z); break; case 9: (const_cast(x)).applyTransform(transform::SoftSign, z); break; case 10: (const_cast(x)).applyTransform(transform::SoftPlus, z); break; default: throw std::invalid_argument("LSTM_LAYER operation: wrong id number of activation !"); } } ////////////////////////////////////////////////////////////////////////// static void activationDeriv(const NDArray& x, const int opId, const float alpha, const float beta, NDArray& z) { switch (opId) { case 0: (const_cast(x)).applyTransform(transform::TanhDerivative, z); break; case 1: (const_cast(x)).applyScalar(scalar::RELUDerivative, 0, z); break; case 2: (const_cast(x)).applyTransform(transform::SigmoidDerivative, z); break; case 3: { z = alpha; break; } case 4: (const_cast(x)).applyScalar(scalar::LeakyRELUDerivative, alpha, z); break; case 5: (const_cast(x)).applyScalar(scalar::RELUDerivative, alpha, z); 
break; case 6: { auto func = PRAGMA_THREADS_FOR { for(Nd4jLong i = start; i < stop; ++i) { auto val = beta * x.e(i); z.p(i, alpha * beta * (1.f - sd::math::nd4j_tanh(val) * sd::math::nd4j_tanh(val))); } }; samediff::Threads::parallel_for(func, 0, x.lengthOf()); break; } case 7: (const_cast(x)).applyTransform(transform::HardSigmoidDerivative, z); break; case 8: (const_cast(x)).applyScalar(scalar::ELUDerivative, alpha, z); break; case 9: (const_cast(x)).applyTransform(transform::SoftSignDerivative, z); break; case 10: { auto func = PRAGMA_THREADS_FOR { for(Nd4jLong i = start; i < stop; ++i) { auto val = sd::math::nd4j_exp(x.e(i)); z.p(i, val / (1.f + val)); } }; samediff::Threads::parallel_for(func, 0, x.lengthOf()); break; } default: throw std::invalid_argument("LSTM_LAYER operation: wrong id number of activation !"); } } ////////////////////////////////////////////////////////////////////////// // FIXME - derivative undefined when not-clipped c has element/elements equal to -clipVal or clipVal static void clipDeriv(const float clipVal, const NDArray& c, NDArray& z0, NDArray& z1, NDArray& z2, NDArray& z3) { if(clipVal == 0) return; auto func = PRAGMA_THREADS_FOR { for(Nd4jLong i = start; i < stop; ++i) { const auto val = c.e(i); if(val == -clipVal || val == clipVal) { z0.p(i, 0.f); z1.p(i, 0.f); z2.p(i, 0.f); z3.p(i, 0.f); } } }; samediff::Threads::parallel_for(func, 0, c.lengthOf()); } ////////////////////////////////////////////////////////////////////////// static NDArray tensorAlongTimeBatchDims(const NDArray& arr, const int dataFormat, const int t1, const int t2, const int b1, const int b2) { if(dataFormat == 0 || dataFormat == 3) return arr({t1,t2, b1,b2, 0,0}); // TNS: [sL, bS, nIn] if(dataFormat == 1) return arr({b1,b2, t1,t2, 0,0}); // NTS: [bS, sL ,nIn] return arr({b1,b2, 0,0, t1,t2}); // NST: [bS, nIn, sL] } ////////////////////////////////////////////////////////////////////////// static FORCEINLINE int getBatchTimeTotalIndex(const int dataFormat, const 
int sL, const int bS, const int t, const int b) { if(dataFormat == 0 || dataFormat == 3) return t * bS + b; // TNS: shape [sL, bS, nIn] return b * sL + t; // NTS, NST: shape [bS, sL, nIn], [bS, nIn, sL] } ////////////////////////////////////////////////////////////////////////// void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp, const std::vector& params, NDArray* h, NDArray* c) { // * -> means element-wise multiplication // × -> means matrix multiplication /************************ THIS IS NOT OPTIMAZED CODE ***********************************/ /** the objective is to provide math-readable code **/ // equations (no peephole connections) // it = σ(Wxi × xt + Wri × ht-1 + bi) // ft = σ(Wxf × xt + Wrf × ht-1 + bf) // c't = tanh(Wxc × xt + Wrc × ht-1 + bc) // ct = ft * ct-1 + it * c't // ot = σ(Wxo × xt + Wro × ht-1 + bo) // ht = ot * tanh(ct) // equations (peephole connections are present) // it = σ(Wxi × xt + Wri × ht-1 + Wpi * ct-1 + bi) // ft = σ(Wxf × xt + Wrf × ht-1 + Wpf * ct-1 + bf) // c't = tanh(Wxc × xt + Wrc × ht-1 + bc) // ct = ft * ct-1 + it * c't // ot = σ(Wxo × xt + Wro × ht-1 + Wpo * ct + bo) // ht = ot * tanh(ct) // IDs for activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus // params[0] - dataFormat, ignore // params[1] - directionMode, ignore // params[2] - cell clipping value, if it = 0 then do not apply clipping // params[3] - activation ID for input (i), forget (f) and output (o) gates // params[4] - alpha value for gates activation // params[5] - beta value for gates activation // params[6] - activation ID for cell state (c) // params[7] - alpha value for cell state activation // params[8] - beta value for cell state activation // params[9] - activation ID for output (h) // params[10] - alpha value for output activation // params[11] - beta value for 
output activation // INPUTS: // x - current input at time t, [bS, nIn] or [nIn] if seqLen != nullptr // Wx - input weights [nIn, 4*nOut] // Wr - recurrent weights [nOut, 4*nOut] // b - biases [4*nOut], optional, may be nullptr // hI - (ht-1) previous (initial) output at time t-1, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr // cI - (ct-1) previous (initial) cell state at time t-1, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr // Wp - peephole weights [3*nOut], optional, may be nullptr // OUTPUTS: // h - current output, that is at current time step t, [bS, nOut] or [nOut] if seqLen != nullptr // c - current cell state, that is at current time step t, [bS, nOut] or [nOut] if seqLen != nullptr // !!! dimension 4*nOut implies order it, ft, c't, ot // !!! dimension 3*nOut implies order it, ft, ot const Nd4jLong nOut = Wx->sizeAt(-1) / 4; auto z = mmul(*x, *Wx) + mmul(*hI, *Wr); // [bs, nIn] * [nIn, 4*nOut] + [bs, nOut] * [nOut, 4*nOut] = [bS, 4*nOut] //or [nIn] * [nIn, 4*nOut] + [nOut] * [nOut, 4*nOut] = [4*nOut] // add biases if they are given if(b != nullptr) z += *b; // broadcast [bS, 4*nOut](or[4*nOut]) + [4*nOut] = [bS, 4*nOut] auto zi = x->rankOf() == 1 ? z({0, nOut}) : z({0,0, 0, nOut}); // input gate it, [bS, nOut](or[nOut]) auto zf = x->rankOf() == 1 ? z({nOut, 2*nOut}) : z({0,0, nOut, 2*nOut}); // forget gate ft, [bS, nOut](or[nOut]) auto zg = x->rankOf() == 1 ? z({2*nOut, 3*nOut}) : z({0,0, 2*nOut, 3*nOut}); // cell gate c't, [bS, nOut](or[nOut]) auto zo = x->rankOf() == 1 ? 
z({3*nOut, 4*nOut}) : z({0,0, 3*nOut, 4*nOut}); // output gate ot, [bS, nOut](or[nOut]) // peephole connections for input and forget gates if(Wp != nullptr) { zi += *cI * (*Wp)({0, nOut}); // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut]) zf += *cI * (*Wp)({nOut, 2*nOut}); // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut]) } applyActivation(zi, params[3], params[4], params[5], zi); // inplace applyActivation(zf, params[3], params[4], params[5], zf); // inplace applyActivation(zg, params[6], params[7], params[8], zg); // inplace c->assign(zf * *cI + zi * zg); // [bS, nOut] * [bS, nOut] + [bS, nOut] * [bS, nOut] = [bS, nOut](or[nOut]) // if clipping value is non-zero then cell state is clipped by this value prior to the cell output activation if(params[2] != 0) c->applyScalar(scalar::LstmClip, params[2], *c); // peephole connections for output gate if(Wp != nullptr) zo += *c * (*Wp)({2*nOut, 3*nOut}); // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut]) applyActivation(zo, params[3], params[4], params[5], zo); applyActivation(*c, params[9], params[10], params[11], *h); *h *= zo; // [bS, nOut] * [bS, nOut](or[nOut]) } ////////////////////////////////////////////////////////////////////////// // this auxiliary ff should be running before backprop void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp, const std::vector& params, NDArray* z, NDArray* a, NDArray* h, NDArray* c) { // z - zi, zf, zg, zo // a - i, f, g, o const Nd4jLong nOut = Wx->sizeAt(-1) / 4; z->assign(mmul(*x, *Wx) + mmul(*hI, *Wr)); // [bs, nIn] * [nIn, 4*nOut] + [bs, nOut] * [nOut, 4*nOut] = [bS, 4*nOut] //or [nIn] * [nIn, 4*nOut] + [nOut] * [nOut, 4*nOut] = [4*nOut] // add biases if they are given if(b != nullptr) *z += *b; // broadcast [bS, 4*nOut](or[4*nOut]) + [4*nOut] = [bS, 4*nOut] auto zi = x->rankOf() == 1 ? 
(*z)({0, nOut}) : (*z)({0,0, 0, nOut}); // input gate it, [bS, nOut](or[nOut]) auto zf = x->rankOf() == 1 ? (*z)({nOut, 2*nOut}) : (*z)({0,0, nOut, 2*nOut}); // forget gate ft, [bS, nOut](or[nOut]) auto zg = x->rankOf() == 1 ? (*z)({2*nOut, 3*nOut}) : (*z)({0,0, 2*nOut, 3*nOut}); // cell gate c't, [bS, nOut](or[nOut]) auto zo = x->rankOf() == 1 ? (*z)({3*nOut, 4*nOut}) : (*z)({0,0, 3*nOut, 4*nOut}); // output gate ot, [bS, nOut](or[nOut]) auto i = x->rankOf() == 1 ? (*a)({0, nOut}) : (*a)({0,0, 0, nOut}); // input gate it, [bS, nOut](or[nOut]) auto f = x->rankOf() == 1 ? (*a)({nOut, 2*nOut}) : (*a)({0,0, nOut, 2*nOut}); // forget gate ft, [bS, nOut](or[nOut]) auto g = x->rankOf() == 1 ? (*a)({2*nOut, 3*nOut}) : (*a)({0,0, 2*nOut, 3*nOut}); // cell gate c't, [bS, nOut](or[nOut]) auto o = x->rankOf() == 1 ? (*a)({3*nOut, 4*nOut}) : (*a)({0,0, 3*nOut, 4*nOut}); // output gate ot, [bS, nOut](or[nOut]) // peephole connections for input and forget gates if(Wp != nullptr) { zi += *cI * (*Wp)({0, nOut}); // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut]) zf += *cI * (*Wp)({nOut, 2*nOut}); // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut]) } applyActivation(zi, params[3], params[4], params[5], i); applyActivation(zf, params[3], params[4], params[5], f); applyActivation(zg, params[6], params[7], params[8], g); c->assign(f * *cI + i * g); // [bS, nOut] * [bS, nOut] + [bS, nOut] * [bS, nOut] = [bS, nOut](or[nOut]) // if clipping value is non-zero then cell state is clipped by this value prior to the cell output activation if(params[2] != 0) c->applyScalar(scalar::LstmClip, params[2], *c); // peephole connections for output gate if(Wp != nullptr) zo += *c * (*Wp)({2*nOut, 3*nOut}); // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut]) applyActivation(zo, params[3], params[4], params[5], o); applyActivation(*c, params[9], params[10], params[11], *h); *h *= o; // [bS, nOut] * [bS, nOut](or[nOut]) } 
////////////////////////////////////////////////////////////////////////// void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp, const NDArray* dLdh, const NDArray* dLdhL, const NDArray* dLdcL, const NDArray* z, const NDArray* a, const NDArray* c, const std::vector& params, NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdb, NDArray* dLdWp) { /************************ THIS IS NOT OPTIMAZED CODE ***********************************/ /** the objective is to provide math-readable code **/ // equations (no peephole connections) // zi = x × Wxi + hI × Wri + bi // zf = x × Wxf + hI × Wrf + bf // zg = x × Wxg + hI × Wrg + bg // zo = x × Wxo + hI × Wro + bo // i = act(zi) // f = act(zf) // g = actC(zg) // o = act(zo) // c = clip(f * cI + i * g) // h = o * actH(c) // equations (peephole connections are present) // zi = x × Wxi + hI × Wri + cI * Wpi + bi // zf = x × Wxf + hI × Wrf + cI * Wpf + bf // zg = x × Wxg + hI × Wrg + bg // zo = x × Wxo + hI × Wro + c * Wpo + bo // i = act(zi) // f = act(zf) // g = actC(zg) // o = act(zo) // c = clip(f * cI + i * g) // h = o * actH(c) // IDs for activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus // params[0] - dataFormat, ignore // params[1] - directionMode, ignore // params[2] - cell clipping value, if it = 0 then do not apply clipping // params[3] - activation ID for input (i), forget (f) and output (o) gates // params[4] - alpha value for gates activation // params[5] - beta value for gates activation // params[6] - activation ID for cell state (c) // params[7] - alpha value for cell state activation // params[8] - beta value for cell state activation // params[9] - activation ID for output (h) // params[10] - alpha value for output activation // params[11] - beta value for output activation // 
INPUTS: // x - current input at time t, [bS, nIn] or [nIn] if seqLen != nullptr // Wx - input weights [nIn, 4*nOut] // Wr - recurrent weights [nOut, 4*nOut] // b - biases [4*nOut], optional, may be nullptr // hI - (ht-1) previous (initial) output at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr // cI - (ct-1) previous (initial) cell state at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr // Wp - peephole weights [3*nOut], optional, may be nullptr // dLdh - loss derivative with respect to h at each time step, [bS, nOut] or [nOut] if seqLen != nullptr // dLdhL - loss derivative with respect to h at last time step, [bS, nOut] or [nOut] if seqLen != nullptr // dLdcL - loss derivative with respect to c at last time step, [bS, nOut] or [nOut] if seqLen != nullptr // z - zi,zf,zg,zo taken from ff outputs to reduce amount of calculations in bp, [bS, 4*nOut] // a - i,f,g,o taken from ff outputs to reduce amount of calculations in bp, [bS, 4*nOut] // c - taken from ff outputs to reduce amount of calculations in bp, [bS, nOut] // OUTPUTS: // dLdx - loss derivative with respect to x, [bS, nIn] or [nIn] if seqLen != nullptr // dLdWx - loss derivative with respect to Wx, [nIn, 4*nOut] // dLdWr - loss derivative with respect to Wr, [nOut, 4*nOut] // dLdb - loss derivative with respect to b, optional, may be nullptr, [4*nOut] // dLdhI - loss derivative with respect to hI, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr // dLdcI - loss derivative with respect to cI, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr // dLdWp - loss derivative with respect to Wp, optional, may be nullptr, [3*nOut] // !!! dimension 4*nOut implies order i, f, g, o // !!! dimension 3*nOut implies order i, f, o // dhdc = o*tanhDeriv + Wp ? tanh(c)*dodzo*dzodc : 0 [bS, nOut] // dcdcI = f + Wp ? 
dcdzi*dzidcI + dcdzf*dzfdcI : 0 [bS, nOut] // dLdhI += dLdh; [bS, nOut] // dLdcI += dLdhI * dhdc; [bS, nOut] // dLdzi = dLdcI*dcdi*didzi; [bS, nOut](or[nOut]) // dLdzf = dLdcI*dcdf*dfdzf; [bS, nOut](or[nOut]) // dLdzg = dLdcI*dcdg*dgdzg; [bS, nOut](or[nOut]) // dLdzo = dLdhI*dhdo*dodzo; [bS, nOut](or[nOut]) // dLdx = dLdzi×WxiT + dLdzf×WxfT + dLdzg×WxgT + dLdzo×WxoT, [bS, nIn] // dLdhI = dLdzi×WriT + dLdzf×WrfT + dLdzg×WrgT + dLdzo×WroT, [bS, nOut] // dLdcI = dLdcI*dcdcI, [bS, nOut] // dLdWxi = xT×dLdzi [nIn, bS] x [bS, nOut] = [nIn, nOut] // dLdWxf = xT×dLdzf [nIn, bS] x [bS, nOut] = [nIn, nOut] // dLdWxg = xT×dLdzg [nIn, bS] x [bS, nOut] = [nIn, nOut] // dLdWxo = xT×dLdzo [nIn, bS] x [bS, nOut] = [nIn, nOut] // dLdWri = hIT×dLdzi [nOut, bS] x [bS, nOut] = [nOut, nOut] // dLdWrf = hIT×dLdzf [nOut, bS] x [bS, nOut] = [nOut, nOut] // dLdWrg = hIT×dLdzg [nOut, bS] x [bS, nOut] = [nOut, nOut] // dLdWro = hIT×dLdzo [nOut, bS] x [bS, nOut] = [nOut, nOut] // dLdbi = dLdzi.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] // dLdbf = dLdzf.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] // dLdbg = dLdzg.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] // dLdbo = dLdzo.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] // dLdWpi = (dLdzi*cI).reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] // dLdWpf = (dLdzf*cI).reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] // dLdWpo = (dLdzo*c) .reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] const Nd4jLong nOut = Wx->sizeAt(-1) / 4; const Nd4jLong nIn = x->sizeAt(-1); NDArray zi = x->rankOf() == 1 ? (*z)({0, nOut}) : (*z)({0,0, 0, nOut}); // input gate i, [bS, nOut](or[nOut]) NDArray zf = x->rankOf() == 1 ? (*z)({nOut, 2*nOut}) : (*z)({0,0, nOut, 2*nOut}); // forget gate f, [bS, nOut](or[nOut]) NDArray zg = x->rankOf() == 1 ? (*z)({2*nOut, 3*nOut}) : (*z)({0,0, 2*nOut, 3*nOut}); // cell gate g, [bS, nOut](or[nOut]) NDArray zo = x->rankOf() == 1 ? 
(*z)({3*nOut, 4*nOut}) : (*z)({0,0, 3*nOut, 4*nOut}); // output gate o, [bS, nOut](or[nOut]) NDArray i = x->rankOf() == 1 ? (*a)({0, nOut}) : (*a)({0,0, 0, nOut}); // input gate i, [bS, nOut](or[nOut]) NDArray f = x->rankOf() == 1 ? (*a)({nOut, 2*nOut}) : (*a)({0,0, nOut, 2*nOut}); // forget gate f, [bS, nOut](or[nOut]) NDArray g = x->rankOf() == 1 ? (*a)({2*nOut, 3*nOut}) : (*a)({0,0, 2*nOut, 3*nOut}); // cell gate g, [bS, nOut](or[nOut]) NDArray o = x->rankOf() == 1 ? (*a)({3*nOut, 4*nOut}) : (*a)({0,0, 3*nOut, 4*nOut}); // output gate o, [bS, nOut](or[nOut]) NDArray dLdz = z->ulike(); // [bS, 4*nOut](or[4*nOut]) NDArray dLdzi = x->rankOf() == 1 ? dLdz({0, nOut}) : dLdz({0,0, 0, nOut}); NDArray dLdzf = x->rankOf() == 1 ? dLdz({nOut, 2*nOut}) : dLdz({0,0, nOut, 2*nOut}); NDArray dLdzg = x->rankOf() == 1 ? dLdz({2*nOut, 3*nOut}) : dLdz({0,0, 2*nOut, 3*nOut}); NDArray dLdzo = x->rankOf() == 1 ? dLdz({3*nOut, 4*nOut}) : dLdz({0,0, 3*nOut, 4*nOut}); // dcdzi = dcdi*didzi, [bS, nOut](or[nOut]) activationDeriv(zi, params[3], params[4], params[5], dLdzi); // didzi, inplace dLdzi *= g; // dcdi = g*clipDeriv // dcdzf = dcdf*dfdzf, [bS, nOut](or[nOut]) activationDeriv(zf, params[3], params[4], params[5], dLdzf); // dfdzf, inplace dLdzf *= *cI; // dcdf = cI*clipDeriv // dcdzg = dcde*dedzg, [bS, nOut](or[nOut]) activationDeriv(zg, params[6], params[7], params[8], dLdzg); // dgdzg, inplace dLdzg *= i; // dcdf = i*clipDeriv // dhdzo = dhdo*dodzo = actH(c)*dodzo, [bS, nOut](or[nOut]) activationDeriv(zo, params[3], params[4], params[5], dLdzo); NDArray temp = dLdzo.ulike(); applyActivation(*c, params[9], params[10], params[11], temp); // actH(c), inplace dLdzo *= temp; // dcdcI NDArray dcdcI = f.dup(); // dcdcI = f*clipDeriv [bS, nOut](or[nOut]) // take into account possible deposit from clipping derivative clipDeriv(params[2], *c, dLdzi, dLdzf, dLdzg, dcdcI); // dhdc NDArray dhdc = c->ulike(); activationDeriv(*c, params[9], params[10], params[11], dhdc); // [bS, nOut] dhdc *= o; 
if(Wp) { dhdc += dLdzo*(*Wp)({2*nOut, 3*nOut}); dcdcI += dLdzi*(*Wp)({0, nOut}) + dLdzf*(*Wp)({nOut, 2*nOut}); // broadcast [bS, nOut] * nOut + ... } if(dLdh) *dLdhI += *dLdh; if(dLdhL) *dLdhI += *dLdhL; if(dLdcL) *dLdcI += *dLdcL; *dLdcI += *dLdhI * dhdc; dLdzi *= *dLdcI; // [bS, nOut](or[nOut]) dLdzf *= *dLdcI; // [bS, nOut](or[nOut]) dLdzg *= *dLdcI; // [bS, nOut](or[nOut]) dLdzo *= *dLdhI; // [bS, nOut](or[nOut]) // dLdx NDArray WxT = Wx->transpose(); MmulHelper::mmul(&dLdz, &WxT, dLdx); // [bS, 4*nOut] x [4*nOut, nIn] (or [4*nOut] x [4*nOut, nIn]) = [bS, nIn] ( or[nIn] ) // dLdhI NDArray WrT = Wr->transpose(); MmulHelper::mmul(&dLdz, &WrT, dLdhI); // [bS, 4*nOut] x [4*nOut, nOut] (or [4*nOut] x [4*nOut, nOut]) = [bS, nOut] ( or[nOut] ) // dLdcI dLdcI->assign(*dLdcI*dcdcI); // [bS, nOut](or[nOut]) if(x->rankOf() == 1) { NDArray xT = x->reshape(x->ordering(),{nIn, 1}); // [nIn] -> [nIn, 1] NDArray hIT = hI->reshape(hI->ordering(),{nOut, 1}); // [nOut] -> [nOut, 1] NDArray dLdzR = dLdz.reshape(dLdz.ordering(), {1, 4*nOut}); // [nOut] -> [1, 4*nOut] // dLdWx *dLdWx += mmul(xT, dLdzR); // [nIn, 1] x [1, 4*nOut] = [nIn, 4*nOut] // dLdWr *dLdWr += mmul(hIT, dLdzR); // [nOut, 1] x [1, 4*nOut] = [nOut, 4*nOut] } else { // dLdWx *dLdWx += mmul(x->transpose(), dLdz); // [nIn, bS] x [bS, 4*nOut] = [nIn, 4*nOut] // dLdWr *dLdWr += mmul(hI->transpose(), dLdz); // [nOut, bS] x [bS, 4*nOut] = [nOut, 4*nOut] } // dLdb if(b && x->rankOf() == 1) *dLdb += dLdz; // [4*nOut] else if(b) *dLdb += dLdz.reduceAlongDimension(reduce::Sum, {0}); // [bS, 4*nOut] -> reduce -> [4*nOut]; // dLdWp if(Wp && x->rankOf() == 1) { (*dLdWp)({ 0,nOut}) += std::move(dLdzi)*(*cI); // [nOut] (*dLdWp)({ nOut,2*nOut}) += std::move(dLdzf)*(*cI); // [nOut] (*dLdWp)({2*nOut,3*nOut}) += std::move(dLdzo)*(*c); // [nOut] } else if(Wp) { NDArray temp(Wp->ordering(), {nOut}, Wp->dataType(), Wp->getContext()); (std::move(dLdzi)*(*cI)).reduceAlongDimension(reduce::Sum, temp, {0}); // [bS, nOut] -> reduce -> [nOut] 
(*dLdWp)({0,nOut}) += temp; (std::move(dLdzf)*(*cI)).reduceAlongDimension(reduce::Sum, temp, {0}); // [bS, nOut] -> reduce -> [nOut] (*dLdWp)({nOut,2*nOut}) += temp; (std::move(dLdzo)*(*c)).reduceAlongDimension(reduce::Sum, temp, {0}); // [bS, nOut] -> reduce -> [nOut] (*dLdWp)({2*nOut,3*nOut}) += temp; } } ////////////////////////////////////////////////////////////////////////// void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const NDArray* b, const NDArray* seqLen, const NDArray* hI, const NDArray* cI, const NDArray* Wp, const std::vector& params, const bool forward, NDArray* h, NDArray* hL, NDArray* cL) { // INPUTS: // x - current input [sL, bS, nIn], [bS, sL, nIn], [bS, nIn, sL], // Wx - input weights [nIn, 4*nOut] // Wr - recurrent weights [nOut, 4*nOut] // b - biases [4*nOut], optional, may be nullptr // seqLen - [bS], optional, may be nullptr // hI - initial output [bS, nOut], optional, may be nullptr // cI - initial cell state at time t-1 [bS, nOut], optional, may be nullptr // Wp - peephole weights [3*nOut], optional, may be nullptr // OUTPUTS: // h - output [sL, bS, nOut], [bS, sL, nOut], [bS, nOut, sL], optional, may be nullptr // hL - output at last step [bS, nOut], optional, may be nullptr // cL - cell state at last step [bS, nOut], optional, may be nullptr // params = {dataFormat, directionMode, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta}; // dataFormat: 0,3 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL] const int dataFormat = params[0]; const int directionMode = params[1]; const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat); const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1); const Nd4jLong nOut = Wx->sizeAt(-1) / 4; const std::vector shapeOut = {bS, nOut}; const auto type = h ? h->dataType() : (hL ? 
hL->dataType() : cL->dataType()); auto h0 = const_cast(hI); if(!hI) { h0 = new NDArray(x->ordering(), shapeOut, type, x->getContext()); h0->nullify(); } auto c0 = const_cast(cI); if(!cI) { c0 = new NDArray(x->ordering(), shapeOut, type, x->getContext()); c0->nullify(); } auto ct = cL; if(!cL) ct = new NDArray(x->ordering(), shapeOut, type, x->getContext()); auto ht = hL; if(!h && !hL) ht = new NDArray(x->ordering(), shapeOut, type, x->getContext()); // create sets of required (depends on seqLen presence) sub-arrays std::vector dims; ResultSet *xSet(nullptr), *hSet(nullptr), *h0Set(nullptr), *c0Set(nullptr), *htSet(nullptr), *ctSet(nullptr); if(!seqLen) { dims = ShapeUtils::evalDimsToExclude(x->rankOf(), {dataFormat < 3 ? dataFormat : 0}); // points on bS and nIn/nOut axes xSet = new ResultSet(x->allTensorsAlongDimension(dims)); // sub-arrays with shape [bS, nIn] if(h) hSet = new ResultSet(h->allTensorsAlongDimension(dims)); // sub-arrays with shape [bS, nOut] } else { dims = dataFormat == 2 ? 
std::vector({1}) : std::vector({2}); // points on nIn/nOut axis xSet = new ResultSet(x->allTensorsAlongDimension(dims)); // sub-arrays with shape [nIn] h0Set = new ResultSet(h0->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] c0Set = new ResultSet(c0->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] ctSet = new ResultSet(ct->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] if(h) hSet = new ResultSet(h->allTensorsAlongDimension(dims)); // sub-arrays with shape [nOut] if(ht) htSet = new ResultSet(ht->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] } // loops if(forward) { if(!seqLen) { if(!h) { // seqLen and h are absent lstmLayerCell(xSet->at(0), Wx, Wr, b, h0, c0, Wp, params, ht, ct); // first time step for (Nd4jLong t = 1; t < sL; ++t) lstmLayerCell(xSet->at(t), Wx, Wr, b, ht, ct, Wp, params, ht, ct); // rest time steps } else { // seqLen is absent and h is present lstmLayerCell(xSet->at(0), Wx, Wr, b, h0, c0, Wp, params, hSet->at(0), ct); // first time step for (Nd4jLong t = 1; t < sL; ++t) lstmLayerCell(xSet->at(t), Wx, Wr, b, hSet->at(t - 1), ct, Wp, params, hSet->at(t), ct); // rest time steps if(hL) hL->assign(hSet->at(sL - 1)); // assign last output to hL if it is not nullptr } } else { if(!h) { // seqLen is present and h is absent for (Nd4jLong e = 0; e < bS; ++e) { const int limit = seqLen->e(e); if(limit == 0) { if(cL) ctSet->at(e)->nullify(); if(hL) htSet->at(e)->nullify(); continue; } auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, 0, e); lstmLayerCell(xSet->at(ind), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // first time step for (int t = 1; t < limit; ++t) { ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); lstmLayerCell(xSet->at(ind), Wx, Wr, b, htSet->at(e), ctSet->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // rest time steps } } } else { // seqLen and h are present for (Nd4jLong e = 0; e < bS; ++e) { int limit = seqLen->e(e); 
if(limit == 0) { tensorAlongTimeBatchDims(*h, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range if(cL) ctSet->at(e)->nullify(); if(hL) htSet->at(e)->nullify(); continue; } auto indPrev = getBatchTimeTotalIndex(dataFormat, sL, bS, 0, e); lstmLayerCell(xSet->at(indPrev), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, hSet->at(indPrev), ctSet->at(e)); // first time step for (int t = 1; t < limit; ++t) { auto indCurr = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); lstmLayerCell(xSet->at(indCurr), Wx, Wr, b, hSet->at(indPrev), ctSet->at(e), Wp, params, hSet->at(indCurr), ctSet->at(e)); // rest time steps indPrev = indCurr; } if(hL) htSet->at(e)->assign(hSet->at(indPrev)); // assign last output to hL if hL is not nullptr if(limit != sL) tensorAlongTimeBatchDims(*h, dataFormat, limit,sL, e,e+1).nullify(); // nullify for given e and time range [limit, sL) } } } } else { // backward if(!seqLen) { if(!h) { // seqLen and h are absent lstmLayerCell(xSet->at(sL - 1), Wx, Wr, b, h0, c0, Wp, params, ht, ct); // first time step for (Nd4jLong t = sL - 2; t >= 0; --t) lstmLayerCell(xSet->at(t), Wx, Wr, b, ht, ct, Wp, params, ht, ct); // rest time steps } else { // seqLen is absent and h is present lstmLayerCell(xSet->at(sL - 1), Wx, Wr, b, h0, c0, Wp, params, hSet->at(sL - 1), ct); // first time step for (Nd4jLong t = sL - 2; t >= 0; --t) lstmLayerCell(xSet->at(t), Wx, Wr, b, hSet->at(t + 1), ct, Wp, params, hSet->at(t), ct); // rest time steps if(hL) hL->assign(hSet->at(0)); // assign last output to hL if it is not nullptr } } else if(directionMode == 1) { // only backward, no bidirectional mode if(!h) { // h is absent and seqLen is present for (Nd4jLong e = 0; e < bS; ++e) { const int limit = seqLen->e(e); if(limit == 0) { if(cL) ctSet->at(e)->nullify(); if(hL) htSet->at(e)->nullify(); continue; } auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, sL - 1, e); lstmLayerCell(xSet->at(ind), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, 
htSet->at(e), ctSet->at(e)); // first time step for (Nd4jLong t = sL - 2; t >= sL - limit; --t) { ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); lstmLayerCell(xSet->at(ind), Wx, Wr, b, htSet->at(e), ctSet->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // rest time steps } } } else { // seqLen and h are present for (Nd4jLong e = 0; e < bS; ++e) { int limit = seqLen->e(e); if(limit == 0) { tensorAlongTimeBatchDims(*h, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range if(cL) ctSet->at(e)->nullify(); if(hL) htSet->at(e)->nullify(); continue; } auto indPrev = getBatchTimeTotalIndex(dataFormat, sL, bS, sL - 1, e); lstmLayerCell(xSet->at(indPrev), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, hSet->at(indPrev), ctSet->at(e)); // first time step for (Nd4jLong t = sL - 2; t >= sL - limit; --t) { auto indCurr = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); lstmLayerCell(xSet->at(indCurr), Wx, Wr, b, hSet->at(indPrev), ctSet->at(e), Wp, params, hSet->at(indCurr), ctSet->at(e)); // rest time steps indPrev = indCurr; } if(hL) htSet->at(e)->assign(hSet->at(indPrev)); // assign last output to hL if it is not nullptr if(limit != sL) tensorAlongTimeBatchDims(*h, dataFormat, 0,sL-limit, e,e+1).nullify(); // nullify for given e and time range [limit, sL) } } } else { // backward in bidirectional mode if(!h) { // h is absent and seqLen is present for (Nd4jLong e = 0; e < bS; ++e) { const int limit = seqLen->e(e); if(limit == 0) { if(cL) ctSet->at(e)->nullify(); if(hL) htSet->at(e)->nullify(); continue; } auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, limit - 1, e); lstmLayerCell(xSet->at(ind), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // first time step for (int t = limit - 2; t >= 0; --t) { ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); lstmLayerCell(xSet->at(ind), Wx, Wr, b, htSet->at(e), ctSet->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // rest time steps } } } else { // 
seqLen and h are present for (Nd4jLong e = 0; e < bS; ++e) { int limit = seqLen->e(e); if(limit == 0) { tensorAlongTimeBatchDims(*h, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range if(cL) ctSet->at(e)->nullify(); if(hL) htSet->at(e)->nullify(); continue; } auto indPrev = getBatchTimeTotalIndex(dataFormat, sL, bS, limit - 1, e); lstmLayerCell(xSet->at(indPrev), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, hSet->at(indPrev), ctSet->at(e)); // first time step for (int t = limit - 2; t >= 0; --t) { auto indCurr = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); lstmLayerCell(xSet->at(indCurr), Wx, Wr, b, hSet->at(indPrev), ctSet->at(e), Wp, params, hSet->at(indCurr), ctSet->at(e)); // rest time steps indPrev = indCurr; } if(hL) htSet->at(e)->assign(hSet->at(indPrev)); // assign last output to hL if it is not nullptr if(limit != sL) tensorAlongTimeBatchDims(*h, dataFormat, limit,sL, e,e+1).nullify(); // nullify for given e and time range [limit, sL) } } } } delete xSet; delete hSet; delete h0Set; delete c0Set; delete htSet; delete ctSet; if(!hI) delete h0; if(!cI) delete c0; if(!cL) delete ct; if(!h && !hL) delete ht; } ////////////////////////////////////////////////////////////////////////// void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const NDArray* b, const NDArray* seqLen, NDArray* hI, NDArray* cI, const NDArray* Wp, const NDArray* dLdh, const NDArray* dLdhL, const NDArray* dLdcL, const std::vector& params, const bool forward, NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdb, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdWp) { // INPUTS: // x - current input [sL, bS, nIn], [bS, sL, nIn], [bS, nIn, sL], // Wx - input weights [nIn, 4*nOut] // Wr - recurrent weights [nOut, 4*nOut] // b - biases [4*nOut], optional, may be nullptr // seqLen - [bS], optional, may be nullptr // hI - initial output [bS, nOut], optional, may be nullptr // cI - initial cell state at time t-1 [bS, nOut], 
optional, may be nullptr // Wp - peephole weights [3*nOut], optional, may be nullptr // dLdh - gradient vs. output [sL, bS, nOut], [bS, sL, nOut], [bS, nOut, sL], optional, may be nullptr // dLdhL - gradient vs. output at last time step [bS, nOut], optional, may be nullptr // dLdcL - gradient vs. cell state at last time step [bS, nOut], optional, may be nullptr // OUTPUTS: // dLdx - gradient vs. input [sL, bS, nIn], [bS, sL, nIn], [bS, nIn, sL] // dLdWx - gradient vs. input weights [nIn, 4*nOut] // dLdWr - gradient vs. recurrent weights [nOut, 4*nOut] // dLdb - gradient vs. biases [4*nOut], optional, may be nullptr // dLdhI - gradient vs. initial output [bS, nOut], optional, may be nullptr // dLdcI - gradient vs. initial cell state at time t-1 [bS, nOut], optional, may be nullptr // dLdWp - gradient vs. peephole weights [3*nOut], optional, may be nullptr // params = {dataFormat, directionMode, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta}; // dataFormat: 0,3 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL] const int dataFormat = params[0]; const int directionMode = params[1]; const int sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat); const int bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1); const int nOut = Wx->sizeAt(-1) / 4; const auto type = dLdh ? dLdh->dataType() : (dLdhL ? 
dLdhL->dataType() : dLdcL->dataType()); auto dLdh0 = dLdhI; if(!hI) dLdh0 = new NDArray(x->ordering(), {bS, nOut}, type, x->getContext()); // this constructor nullifies array automatically auto dLdc0 = dLdcI; if(!cI) dLdc0 = new NDArray(x->ordering(), {bS, nOut}, type, x->getContext()); // this constructor nullifies array automatically NDArray z(x->ordering(), {sL, bS, 4*nOut}, type, x->getContext()); NDArray a = z.ulike(); NDArray h(x->ordering(), {sL+1, bS, nOut}, type, x->getContext()); NDArray c = h.ulike(); // create sets of required (depends on seqLen presence) sub-arrays std::vector dims; ResultSet *xSet(nullptr), *dLdxSet(nullptr), *hSet(nullptr), *cSet(nullptr), *zSet(nullptr), *aSet(nullptr), *dLdhSet(nullptr), *dLdh0Set(nullptr), *dLdc0Set(nullptr), *dLdhLSet(nullptr), *dLdcLSet(nullptr), *hISet(nullptr), *cISet(nullptr); if(!seqLen) { dims = ShapeUtils::evalDimsToExclude(x->rankOf(), {dataFormat < 3 ? dataFormat : 0}); // points on [bS, nIn/nOut] xSet = new ResultSet(x->allTensorsAlongDimension(dims)); // sub-arrays with shape [bS, nIn] dLdxSet = new ResultSet(dLdx->allTensorsAlongDimension(dims)); // sub-arrays with shape [bS, nIn] hSet = new ResultSet(h.allTensorsAlongDimension({1, 2})); // sub-arrays with shape [bS, nOut] cSet = new ResultSet(c.allTensorsAlongDimension({1, 2})); // sub-arrays with shape [bS, nOut] zSet = new ResultSet(z.allTensorsAlongDimension({1, 2})); // sub-arrays with shape [bS, 4*nOut] aSet = new ResultSet(a.allTensorsAlongDimension({1, 2})); // sub-arrays with shape [bS, 4*nOut] if(dLdh) dLdhSet = new ResultSet(dLdh->allTensorsAlongDimension(dims)); // sub-arrays with shape [bS, nOut] } else { dims = dataFormat == 2 ? 
std::vector({1}) : std::vector({2}); // points on nIn/nOut axis xSet = new ResultSet(x->allTensorsAlongDimension(dims)); // sub-arrays with shape [nIn] dLdxSet = new ResultSet(dLdx->allTensorsAlongDimension(dims)); // sub-arrays with shape [nIn] hSet = new ResultSet(h.allTensorsAlongDimension({2})); // sub-arrays with shape [nOut] cSet = new ResultSet(c.allTensorsAlongDimension({2})); // sub-arrays with shape [nOut] zSet = new ResultSet(z.allTensorsAlongDimension({2})); // sub-arrays with shape [4*nOut] aSet = new ResultSet(a.allTensorsAlongDimension({2})); // sub-arrays with shape [4*nOut] if(hI) hISet = new ResultSet(hI->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] if(cI) cISet = new ResultSet(cI->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] dLdh0Set = new ResultSet(dLdh0->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] dLdc0Set = new ResultSet(dLdc0->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] if(dLdh) dLdhSet = new ResultSet(dLdh->allTensorsAlongDimension(dims)); // sub-arrays with shape [nOut] if(dLdhL) dLdhLSet = new ResultSet(dLdhL->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] if(dLdcL) dLdcLSet = new ResultSet(dLdcL->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] } // loops if(forward) { if(!seqLen) { // seqLen is absent if(hI) hSet->at(0)->assign(hI); else hSet->at(0)->nullify(); if(cI) cSet->at(0)->assign(cI); else cSet->at(0)->nullify(); // ff for (int t = 0; t < sL; ++t) lstmLayerCell(xSet->at(t), Wx, Wr, b, hSet->at(t), cSet->at(t), Wp, params, zSet->at(t), aSet->at(t), hSet->at(t+1), cSet->at(t+1)); // bp for (int t = sL-1; t >= 0; --t) { const NDArray* dLdhh = dLdh ? dLdhSet->at(t) : nullptr; const NDArray* dLdhhL = (t == sL-1 && dLdhL) ? dLdhL : nullptr; const NDArray* dLdccL = (t == sL-1 && dLdcL) ? 
dLdcL : nullptr; lstmLayerCellBp(xSet->at(t), Wx, Wr, b, hSet->at(t), cSet->at(t), Wp, dLdhh, dLdhhL, dLdccL, zSet->at(t), aSet->at(t), cSet->at(t+1), params, dLdxSet->at(t), dLdWx, dLdWr, dLdh0, dLdc0, dLdb, dLdWp); } } else { // seqLen is present for (int e = 0; e < bS; ++e) { const int limit = seqLen->e(e); if(limit == 0) { tensorAlongTimeBatchDims(*dLdx, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range continue; } if(hI) hSet->at(e)->assign(hISet->at(e)); else hSet->at(e)->nullify(); if(cI) cSet->at(e)->assign(cISet->at(e)); else cSet->at(e)->nullify(); // ff for (int t = 0; t < limit; ++t) lstmLayerCell(xSet->at(getBatchTimeTotalIndex(dataFormat, sL, bS, t, e)), Wx, Wr, b, hSet->at(t*bS + e), cSet->at(t*bS + e), Wp, params, zSet->at(t*bS + e), aSet->at(t*bS + e), hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e)); // bp for (int t = limit-1; t >= 0; --t) { const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : nullptr; const NDArray* dLdhhL = (t == limit-1 && dLdhL) ? dLdhLSet->at(e) : nullptr; const NDArray* dLdccL = (t == limit-1 && dLdcL) ? 
dLdcLSet->at(e) : nullptr; lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at(t*bS + e), cSet->at(t*bS + e), Wp, dLdhh, dLdhhL, dLdccL, zSet->at(t*bS + e), aSet->at(t*bS + e), cSet->at((t+1)*bS + e), params, dLdxSet->at(ind), dLdWx, dLdWr, dLdh0Set->at(e), dLdc0Set->at(e), dLdb, dLdWp); } if(limit != sL) tensorAlongTimeBatchDims(*dLdx, dataFormat, limit,sL, e,e+1).nullify(); // nullify for given e and time range [limit, sL) } } } else { // backward or bidirectional if(!seqLen) { // backward or bidirectional, seqLen is absent if(hI) hSet->at(sL)->assign(hI); else hSet->at(sL)->nullify(); if(cI) cSet->at(sL)->assign(cI); else cSet->at(sL)->nullify(); // ff for (int t = sL-1; t >= 0; --t) lstmLayerCell(xSet->at(t), Wx, Wr, b, hSet->at(t+1), cSet->at(t+1), Wp, params, zSet->at(t), aSet->at(t), hSet->at(t), cSet->at(t)); // bp for (int t = 0; t < sL; ++t) { const NDArray* dLdhh = dLdh ? dLdhSet->at(t) : nullptr; const NDArray* dLdhhL = (t == 0 && dLdhL) ? dLdhL : nullptr; const NDArray* dLdccL = (t == 0 && dLdcL) ? 
dLdcL : nullptr; lstmLayerCellBp(xSet->at(t), Wx, Wr, b, hSet->at(t+1), cSet->at(t+1), Wp, dLdhh, dLdhhL, dLdccL, zSet->at(t), aSet->at(t), cSet->at(t), params, dLdxSet->at(t), dLdWx, dLdWr, dLdh0, dLdc0, dLdb, dLdWp); } } else if(directionMode == 1) { // backward, seqLen is present for (int e = 0; e < bS; ++e) { const int limit = seqLen->e(e); if(limit == 0) { tensorAlongTimeBatchDims(*dLdx, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range continue; } if(hI) hSet->at(sL*bS + e)->assign(hISet->at(e)); else hSet->at(sL*bS + e)->nullify(); if(cI) cSet->at(sL*bS + e)->assign(cISet->at(e)); else cSet->at(sL*bS + e)->nullify(); // ff for (int t = sL - 1; t >= sL-limit; --t) lstmLayerCell(xSet->at(getBatchTimeTotalIndex(dataFormat, sL, bS, t, e)), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, params, zSet->at(t*bS + e), aSet->at(t*bS + e), hSet->at(t*bS + e), cSet->at(t*bS + e)); // bp for (int t = sL-limit; t < sL; ++t) { const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : nullptr; const NDArray* dLdhhL = (t == sL-limit && dLdhL) ? dLdhLSet->at(e) : nullptr; const NDArray* dLdccL = (t == sL-limit && dLdcL) ? 
dLdcLSet->at(e) : nullptr; lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, dLdhh, dLdhhL, dLdccL, zSet->at(t*bS + e), aSet->at(t*bS + e), cSet->at(t*bS + e), params, dLdxSet->at(ind), dLdWx, dLdWr, dLdh0Set->at(e), dLdc0Set->at(e), dLdb, dLdWp); } if(limit != sL) tensorAlongTimeBatchDims(*dLdx, dataFormat, 0,sL-limit, e,e+1).nullify(); // nullify for given e and time range [limit, sL) } } else { // bidirectional mode, seqLen is present for (int e = 0; e < bS; ++e) { const int limit = seqLen->e(e); if(limit == 0) { tensorAlongTimeBatchDims(*dLdx, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range continue; } if(hI) h({limit,limit+1, e,e+1, 0,0}).assign(hISet->at(e)); else h({limit,limit+1, e,e+1, 0,0}).nullify(); if(cI) c({limit,limit+1, e,e+1, 0,0}).assign(cISet->at(e)); else c({limit,limit+1, e,e+1, 0,0}).nullify(); // ff for (int t = limit - 1; t >= 0; --t) lstmLayerCell(xSet->at(getBatchTimeTotalIndex(dataFormat, sL, bS, t, e)), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, params, zSet->at(t*bS + e), aSet->at(t*bS + e), hSet->at(t*bS + e), cSet->at(t*bS + e)); // bp for (int t = 0; t < limit; ++t) { const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : nullptr; const NDArray* dLdhhL = (t == 0 && dLdhL) ? dLdhLSet->at(e) : nullptr; const NDArray* dLdccL = (t == 0 && dLdcL) ? 
dLdcLSet->at(e) : nullptr; lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, dLdhh, dLdhhL, dLdccL, zSet->at(t*bS + e), aSet->at(t*bS + e), cSet->at(t*bS + e), params, dLdxSet->at(ind), dLdWx, dLdWr, dLdh0Set->at(e), dLdc0Set->at(e), dLdb, dLdWp); } if(limit != sL) tensorAlongTimeBatchDims(*dLdx, dataFormat, limit,sL, e,e+1).nullify(); // nullify for given e and time range [limit, sL) } } } delete xSet; delete dLdxSet; delete hSet; delete cSet; delete aSet; delete zSet; delete dLdhSet; delete dLdh0Set; delete dLdc0Set; delete dLdhLSet; delete dLdcLSet; delete hISet; delete cISet; if(!hI) delete dLdh0; if(!cI) delete dLdc0; } } } } ////////////////////////////////////////////////////////////////////////// // void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // const NDArray* b, NDArray* hI, NDArray* cI, const NDArray* Wp, const NDArray* dLdh, // const std::vector& params, const bool firstIter, // NDArray* dhIdcI, NDArray* dhIdWx, NDArray* dcIdWx, NDArray* dhIdWr, NDArray* dcIdWr, // NDArray* dhIdb, NDArray* dcIdb, NDArray* dhIdWp, NDArray* dcIdWp, // NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdb, NDArray* dLdWp) { // /************************ THIS IS NOT OPTIMAZED CODE ***********************************/ // /** the objective is to provide math-readable code **/ // // equations (no peephole connections) // // zi = x × Wxi + hI × Wri + bi // // zf = x × Wxf + hI × Wrf + bf // // zg = x × Wxg + hI × Wrg + bg // // zo = x × Wxo + hI × Wro + bo // // i = act(zi) // // f = act(zf) // // g = actC(zg) // // o = act(zo) // // c = clip(f * cI + i * g) // // h = o * actH(c) // // equations (peephole connections are present) // // zi = x × Wxi + hI × Wri + cI * Wpi + bi // // zf = x × Wxf + hI × Wrf + cI * Wpf + bf // // zg = x × Wxg + hI × Wrg + bg // // zo = x × Wxo + hI × Wro + c * Wpo + bo // // i = act(zi) // // f = act(zf) // // g = actC(zg) // // o = 
act(zo) // // c = clip(f * cI + i * g) // // h = o * actH(c) // // IDs for activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus // // params[0] - dataFormat, ignore // // params[1] - directionMode, ignore // // params[2] - cell clipping value, if it = 0 then do not apply clipping // // params[3] - activation ID for input (i), forget (f) and output (o) gates // // params[4] - alpha value for gates activation // // params[5] - beta value for gates activation // // params[6] - activation ID for cell state (c) // // params[7] - alpha value for cell state activation // // params[8] - beta value for cell state activation // // params[9] - activation ID for output (h) // // params[10] - alpha value for output activation // // params[11] - beta value for output activation // // INPUTS: // // x - current input at time t, [bS, nIn] or [nIn] if seqLen != nullptr // // Wx - input weights [nIn, 4*nOut] // // Wr - recurrent weights [nOut, 4*nOut] // // b - biases [4*nOut], optional, may be nullptr // // hI - (ht-1) previous (initial) output at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr // // cI - (ct-1) previous (initial) cell state at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr // // Wp - peephole weights [3*nOut], optional, may be nullptr // // dLdh - loss derivative with respect to h, [bS, nOut] or [nOut] if seqLen != nullptr // // dhIdcI - derivative from previous time step, [bS, nOut] or [nOut] if seqLen != nullptr // // dhIdWx - derivative from previous time step (Jacobian), [nIn, 4*nOut, bS, nOut] or [nIn, 4*nOut, nOut] if seqLen != nullptr // // dcIdWx - derivative from previous time step (Jacobian), [nIn, 4*nOut, bS, nOut] or [nIn, 4*nOut, nOut] if seqLen != nullptr // // dhIdWr - derivative from previous time step (Jacobian), [nOut, 4*nOut, bS, nOut] or [nOut, 4*nOut, nOut] if seqLen != nullptr // // dcIdWr - derivative from previous time step (Jacobian), 
[nOut, 4*nOut, bS, nOut] or [nOut, 4*nOut, nOut] if seqLen != nullptr // // dcIdWp - derivative from previous time step, [3*nOut], optional, may be nullptr // // dhIdWp - derivative from previous time step, [3*nOut], optional, may be nullptr // // dcIdb - derivative from previous time step, [4*nOut], optional, may be nullptr // // dhIdb - derivative from previous time step, [4*nOut], optional, may be nullptr // // OUTPUTS: // // dLdx - loss derivative with respect to x, [bS, nIn] or [nIn] if seqLen != nullptr // // dLdWx - loss derivative with respect to Wx, [nIn, 4*nOut] // // dLdWr - loss derivative with respect to Wr, [nOut, 4*nOut] // // dLdb - loss derivative with respect to b, optional, may be nullptr, [4*nOut] // // dLdhI - loss derivative with respect to hI, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr // // dLdcI - loss derivative with respect to cI, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr // // dLdWp - loss derivative with respect to Wp, optional, may be nullptr, [3*nOut] // // !!! dimension 4*nOut implies order i, f, g, o // // !!! dimension 3*nOut implies order i, f, o // // dcdzi = dcdi*didzi // // dcdzf = dcdf*dfdzf // // dcdzg = dcdg*dgdzg // // dhdzo = dhdo*dodzo // // dhdc = dhdc + Wp ? dhdzo*dzodc : 0 [bS, nOut] // // factor = dLdh*dhdc [bS, nOut] // // iFactor = factor*dcdzi [bS, nOut] // // fFactor = factor*dcdzf [bS, nOut] // // eFactor = factor*dcdzg [bS, nOut] // // oFactor = *dLdh*dhdzo [bS, nOut] // // tempC = dcdcI + Wp ? 
dcdzi*dzidcI + dcdzf*dzfdcI : 0; // // tempIFE = dcdzi×WriT + dcdzf×WrfT + dcdzg×WrgT // // tempO = dhdzo×WroT // // dhIdcI = dhdc_from_previous_time_step // // dLdx = iFactor×WxiT + fFactor×WxfT + eFactor×WxgT + oFactor×WxoT, [bS, nIn] // // dLdhI = iFactor×WriT + fFactor×WrfT + eFactor×WrgT + oFactor×WroT, [bS, nOut] // // dLdcI = factor*tempC + dLdhI * dhIdcI, dhIdcI=0 if firstIter, [bS, nOut] // // dcdWxi(dcIdWxi) = dcdzi*dzidWxi + tempIFE*dhIdWxi + tempC*dcIdWxi, dcIdWxi=dhIdWxi= 0 if firstIter, [nIn, nOut, bS, nOut] // // dcdWxf(dcIdWxf) = dcdzf*dzfdWxf + tempIFE*dhIdWxf + tempC*dcIdWxf, dcIdWxf=dhIdWxf= 0 if firstIter, [nIn, nOut, bS, nOut] // // dcdWxg(dcIdWxg) = dcdzg*dzgdWxg + tempIFE*dhIdWxg + tempC*dcIdWxg, dcIdWxg=dhIdWxg= 0 if firstIter, [nIn, nOut, bS, nOut] // // dcdWxo(dcIdWxo) = 0 + tempIFE*dhIdWxo + tempC*dcIdWxo; dcIdWxo=dhIdWxo= 0 if firstIter, [nIn, nOut, bS, nOut] // // dhdWxi(dhIdWxi) = 0 + dhdc*dcdWxi + tempO*dhIdWxi, dhIdWxi= 0 if firstIter, [nIn, nOut, bS, nOut] // // dhdWxf(dhIdWxf) = 0 + dhdc*dcdWxf + tempO*dhIdWxf, dhIdWxf= 0 if firstIter, [nIn, nOut, bS, nOut] // // dhdWxg(dhIdWxg) = 0 + dhdc*dcdWxg + tempO*dhIdWxg, dhIdWxg= 0 if firstIter, [nIn, nOut, bS, nOut] // // dhdWxo(dhIdWxo) = dhdzo*dzodWxo + dhdc*dcdWxo + tempO*dhIdWxo, dhIdWxo= 0 if firstIter, [nIn, nOut, bS, nOut] // // dhdWri(dhIdWri) = 0 + dhdc*dcdWri + tempO*dhIdWri, dhIdWri= 0 if firstIter, [nOut, nOut, bS, nOut] // // dhdWrf(dhIdWrf) = 0 + dhdc*dcdWrf + tempO*dhIdWrf, dhIdWrf= 0 if firstIter, [nOut, nOut, bS, nOut] // // dhdWrg(dhIdWrg) = 0 + dhdc*dcdWrg + tempO*dhIdWrg, dhIdWrg= 0 if firstIter, [nOut, nOut, bS, nOut] // // dhdWro(dhIdWro) = dhdzo*dzodWro + dhdc*dcdWro + tempO*dhIdWro, dhIdWro= 0 if firstIter, [nOut, nOut, bS, nOut] // // dcdWri(dcIdWri) = dcdzi*dzidWri + tempIFE*dhIdWri + tempC*dcIdWri, dcIdWri=dhIdWri= 0 if firstIter, [nOut, nOut, bS, nOut] // // dcdWrf(dcIdWrf) = dcdzf*dzfdWrf + tempIFE*dhIdWrf + tempC*dcIdWrf, dcIdWri=dhIdWri= 0 if firstIter, 
[nOut, nOut, bS, nOut] // // dcdWrg(dcIdWrg) = dcdzg*dzgdWrg + tempIFE*dhIdWrg + tempC*dcIdWrg, dcIdWri=dhIdWri= 0 if firstIter, [nOut, nOut, bS, nOut] // // dcdWro(dcIdWro) = 0 + tempIFE*dhIdWro + tempC*dcIdWro; dcIdWro=dhIdWro= 0 if firstIter, [nOut, nOut, bS, nOut] // // dcIdWpi = (dcdzi*cI + tempIFE*dhIdWpi + tempC*dcIdWpi).reduceALongFirstDim, dcIdWpi=dhIdWpi= 0 if firstIter, [bS, nOut]->reduce->[bS] // // dcIdWpf = (dcdzf*cI + tempIFE*dhIdWpf + tempC*dcIdWpf).reduceALongFirstDim, dcIdWpf=dhIdWpf= 0 if firstIter, [bS, nOut]->reduce->[bS] // // dcIdWpo = (0 + tempIFE*dhIdWpo + tempC*dcIdWpo).reduceALongFirstDim, dcIdWpo=dhIdWpo= 0 if firstIter, [bS, nOut]->reduce->[bS] // // dhdWpi(dhIdWpi) =( 0 + dhdc*dcdWpi + tempO*dhIdWpi).reduceALongFirstDim, dhIdWpi= 0 if firstIter, [bS, nOut]->reduce->[bS] // // dhdWpf(dhIdWpf) =( 0 + dhdc*dcdWpf + tempO*dhIdWpf).reduceALongFirstDim, dhIdWpf= 0 if firstIter, [bS, nOut]->reduce->[bS] // // dhdWpo(dhIdWpo) =(dhdzo*c + dhdc*dcdWpo + tempO*dhIdWpo).reduceALongFirstDim, dhIdWpo= 0 if firstIter, [bS, nOut]->reduce->[bS] // // dcdbi(dcIdbi) = (dcdzi + tempIFE*dhIdbi + tempC*dcIdbi).reduceALongFirstDim, dcIdbi=dhIdbi= 0 if firstIter, [bS, nOut]->reduce->[bS] // // dcdbf(dcIdbf) = (dcdzf + tempIFE*dhIdbf + tempC*dcIdbf).reduceALongFirstDim, dcIdbf=dhIdbf= 0 if firstIter, [bS, nOut]->reduce->[bS] // // dcdbg(dcIdbg) = (dcdzg + tempIFE*dhIdbg + tempC*dcIdbg).reduceALongFirstDim, dcIdbg=dhIdbg= 0 if firstIter, [bS, nOut]->reduce->[bS] // // dcdbo(dcIdbo) = ( 0 + tempIFE*dhIdbo + tempC*dcIdbo).reduceALongFirstDim; dcIdbo=dhIdbo= 0 if firstIter, [bS, nOut]->reduce->[bS] // // dhdbi(dhIdbi) = ( 0 + dhdc*dcdbi + tempO*dhIdbi).reduceALongFirstDim, dhIdbi= 0 if firstIter, [bS, nOut]->reduce->[bS] // // dhdbf(dhIdbf) = ( 0 + dhdc*dcdbf + tempO*dhIdbf).reduceALongFirstDim, dhIdbf= 0 if firstIter, [bS, nOut]->reduce->[bS] // // dhdbg(dhIdbg) = ( 0 + dhdc*dcdbg + tempO*dhIdbg).reduceALongFirstDim, dhIdbg= 0 if firstIter, [bS, 
nOut]->reduce->[bS] // // dhdbo(dhIdbo) = (dhdzo + dhdc*dcdbo + tempO*dhIdbo).reduceALongFirstDim, dhIdbo= 0 if firstIter, [bS, nOut]->reduce->[bS] // const Nd4jLong nOut = Wx->sizeAt(-1) / 4; // NDArray *Wpi(nullptr), *Wpf(nullptr), *Wpo(nullptr), *dcIdWpi(nullptr), *dcIdWpf(nullptr), *dcIdWpo(nullptr), *dhIdWpi(nullptr), *dhIdWpf(nullptr), *dhIdWpo(nullptr); // if(Wp) { // Wpi = new NDArray((*Wp)({0, nOut})); // Wpf = new NDArray((*Wp)({nOut, 2*nOut})); // Wpo = new NDArray((*Wp)({2*nOut, 3*nOut})); // dhIdWpi = new NDArray((*dhIdWp)({0, nOut})); // dhIdWpf = new NDArray((*dhIdWp)({nOut, 2*nOut})); // dhIdWpo = new NDArray((*dhIdWp)({2*nOut, 3*nOut})); // dcIdWpi = new NDArray((*dcIdWp)({0, nOut})); // dcIdWpf = new NDArray((*dcIdWp)({nOut, 2*nOut})); // dcIdWpo = new NDArray((*dcIdWp)({2*nOut, 3*nOut})); // } // NDArray *dcIdbi(nullptr), *dcIdbf(nullptr), *dcIdbg(nullptr), *dcIdbo(nullptr), *dhIdbi(nullptr), *dhIdbf(nullptr), *dhIdbg(nullptr), *dhIdbo(nullptr); // if(b) { // dhIdbi = new NDArray((*dhIdb)({0, nOut})); // dhIdbf = new NDArray((*dhIdb)({nOut, 2*nOut})); // dhIdbg = new NDArray((*dhIdb)({2*nOut, 3*nOut})); // dhIdbo = new NDArray((*dhIdb)({3*nOut, 4*nOut})); // dcIdbi = new NDArray((*dcIdb)({0, nOut})); // dcIdbf = new NDArray((*dcIdb)({nOut, 2*nOut})); // dcIdbg = new NDArray((*dcIdb)({2*nOut, 3*nOut})); // dcIdbo = new NDArray((*dcIdb)({3*nOut, 4*nOut})); // } // NDArray dhIdWxi = x->rankOf() == 1 ? (*dhIdWx)({0,0, 0,nOut, 0,0}) : (*dhIdWx)({0,0, 0,nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr // NDArray dhIdWxf = x->rankOf() == 1 ? (*dhIdWx)({0,0, nOut,2*nOut, 0,0}) : (*dhIdWx)({0,0, nOut,2*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr // NDArray dhIdWxg = x->rankOf() == 1 ? (*dhIdWx)({0,0, 2*nOut,3*nOut, 0,0}) : (*dhIdWx)({0,0, 2*nOut,3*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr // NDArray dhIdWxo = x->rankOf() == 1 ? 
(*dhIdWx)({0,0, 3*nOut,4*nOut, 0,0}) : (*dhIdWx)({0,0, 3*nOut,4*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr // NDArray dhIdWri = x->rankOf() == 1 ? (*dhIdWr)({0,0, 0,nOut, 0,0}) : (*dhIdWr)({0,0, 0,nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr // NDArray dhIdWrf = x->rankOf() == 1 ? (*dhIdWr)({0,0, nOut,2*nOut, 0,0}) : (*dhIdWr)({0,0, nOut,2*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr // NDArray dhIdWrg = x->rankOf() == 1 ? (*dhIdWr)({0,0, 2*nOut,3*nOut, 0,0}) : (*dhIdWr)({0,0, 2*nOut,3*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr // NDArray dhIdWro = x->rankOf() == 1 ? (*dhIdWr)({0,0, 3*nOut,4*nOut, 0,0}) : (*dhIdWr)({0,0, 3*nOut,4*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr // NDArray dcIdWxi = x->rankOf() == 1 ? (*dcIdWx)({0,0, 0,nOut, 0,0}) : (*dcIdWx)({0,0, 0,nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr // NDArray dcIdWxf = x->rankOf() == 1 ? (*dcIdWx)({0,0, nOut,2*nOut, 0,0}) : (*dcIdWx)({0,0, nOut,2*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr // NDArray dcIdWxg = x->rankOf() == 1 ? (*dcIdWx)({0,0, 2*nOut,3*nOut, 0,0}) : (*dcIdWx)({0,0, 2*nOut,3*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr // NDArray dcIdWxo = x->rankOf() == 1 ? (*dcIdWx)({0,0, 3*nOut,4*nOut, 0,0}) : (*dcIdWx)({0,0, 3*nOut,4*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr // NDArray dcIdWri = x->rankOf() == 1 ? (*dcIdWr)({0,0, 0,nOut, 0,0}) : (*dcIdWr)({0,0, 0,nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr // NDArray dcIdWrf = x->rankOf() == 1 ? 
(*dcIdWr)({0,0, nOut,2*nOut, 0,0}) : (*dcIdWr)({0,0, nOut,2*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr // NDArray dcIdWrg = x->rankOf() == 1 ? (*dcIdWr)({0,0, 2*nOut,3*nOut, 0,0}) : (*dcIdWr)({0,0, 2*nOut,3*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr // NDArray dcIdWro = x->rankOf() == 1 ? (*dcIdWr)({0,0, 3*nOut,4*nOut, 0,0}) : (*dcIdWr)({0,0, 3*nOut,4*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr // NDArray WxiT = (*Wx)({0,0, 0, nOut}).transpose(); // [nOut, nIn] // NDArray WxfT = (*Wx)({0,0, nOut, 2*nOut}).transpose(); // [nOut, nIn] // NDArray WxgT = (*Wx)({0,0, 2*nOut,3*nOut}).transpose(); // [nOut, nIn] // NDArray WxoT = (*Wx)({0,0, 3*nOut,4*nOut}).transpose(); // [nOut, nIn] // NDArray WriT = (*Wr)({0,0, 0, nOut}).transpose(); // [nOut, nOut] // NDArray WrfT = (*Wr)({0,0, nOut, 2*nOut}).transpose(); // [nOut, nOut] // NDArray WrgT = (*Wr)({0,0, 2*nOut,3*nOut}).transpose(); // [nOut, nOut] // NDArray WroT = (*Wr)({0,0, 3*nOut,4*nOut}).transpose(); // [nOut, nOut] // // ***** feed forward step ***** // // auto z = mmul(*x, *Wx) + mmul(*hI, *Wr); // [bs, nIn] * [nIn, 4*nOut] + [bs, nOut] * [nOut, 4*nOut] = [bS, 4*nOut] // //or [nIn] * [nIn, 4*nOut] + [nOut] * [nOut, 4*nOut] = [4*nOut] // // add biases if they are given // if(b) // z += *b; // broadcast [bS, 4*nOut] + [4*nOut] = [bS, 4*nOut](or[4*nOut]) // auto zi = x->rankOf() == 1 ? z({0, nOut}) : z({0,0, 0, nOut}); // input gate i, [bS, nOut](or[nOut]) // auto zf = x->rankOf() == 1 ? z({nOut, 2*nOut}) : z({0,0, nOut, 2*nOut}); // forget gate f, [bS, nOut](or[nOut]) // auto zg = x->rankOf() == 1 ? z({2*nOut, 3*nOut}) : z({0,0, 2*nOut, 3*nOut}); // cell gate g, [bS, nOut](or[nOut]) // auto zo = x->rankOf() == 1 ? 
z({3*nOut, 4*nOut}) : z({0,0, 3*nOut, 4*nOut}); // output gate o, [bS, nOut](or[nOut]) // // peephole connections for input and forget gates // if(Wp) { // zi += *cI * *Wpi; // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut]) // zf += *cI * *Wpf; // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut]) // } // NDArray i = zi.ulike(); // [bS, nOut] // NDArray f = zf.ulike(); // [bS, nOut] // NDArray g = zg.ulike(); // [bS, nOut] // applyActivation(zi, params[3], params[4], params[5], i); // applyActivation(zf, params[3], params[4], params[5], f); // applyActivation(zg, params[6], params[7], params[8], g); // NDArray c = f * *cI + i * g; // [bS, nOut] * [bS, nOut] + [bS, nOut] * [bS, nOut] = [bS, nOut](or[nOut]) // // if clipping value is non-zero then cell state is clipped by this value prior to the cell output activation // if(params[2] != 0) // c.applyScalar(scalar::LstmClip, params[2], c); // // peephole connections for output gate // if(Wp) // zo += c * *Wpo; // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut]) // NDArray o = zo.ulike(); // [bS, nOut](or[nOut]) // applyActivation(zo, params[3], params[4], params[5], o); // // ***** back prop step ***** // // NDArray dWxJacobian = mmulJacobianWeightsDeriv(nOut, *x); // [nIn, nOut, bS, nOut] (or [nIn, nOut, nOut]) // NDArray dWrJacobian = mmulJacobianWeightsDeriv(nOut, *hI); // [nOut, nOut, bS, nOut] (or [nOut, nOut, nOut]) // // dodzo // NDArray dodzo = zo.ulike(); // [bS, nOut](or[nOut]) // activationDeriv(zo, params[3], params[4], params[5], dodzo); // // dhdzo = dhdo*dodzo = actH(c)*dodzo // NDArray dhdzo = zo.ulike(); // [bS, nOut](or[nOut]) // applyActivation(c, params[9], params[10], params[11], dhdzo); // actH(c) // hI->assign(o*dhdzo); // dhdzo *= dodzo; // // dcdzi = dcdi*didzi // NDArray dcdzi = zi.ulike(); // [bS, nOut](or[nOut]) // activationDeriv(zi, params[3], params[4], params[5], dcdzi); // didzi // dcdzi *= g; // dcdi = g*clipDeriv // // dcdzf = 
dcdf*dfdzf // NDArray dcdzf = zf.ulike(); // [bS, nOut](or[nOut]) // activationDeriv(zf, params[3], params[4], params[5], dcdzf); // dfdzf // dcdzf *= *cI; // dcdf = cI*clipDeriv // // dcdzg = dcde*dedzg // NDArray dcdzg = zg.ulike(); // [bS, nOut](or[nOut]) // activationDeriv(zg, params[6], params[7], params[8], dcdzg); // dedzg // dcdzg *= i; // dcdf = i*clipDeriv // // dcdcI // NDArray dcdcI = f.dup(); // [bS, nOut](or[nOut]) // // take into account possible deposit from clipping derivative // clipDeriv(params[2], c, dcdzi, dcdzf, dcdzg, dcdcI); // // dzodc // NDArray* dzodc = Wpo; // [nOut], should be [bS, nOut] actually, however it will be broadcasted appropriately in future calcus (element-wise multiplication) // // dzidcI // NDArray* dzidcI = Wpi; // [nOut], should be [bS, nOut] actually, however it will be broadcasted appropriately in future calcus (element-wise multiplication) // // dzfdcI // NDArray* dzfdcI = Wpf; // [nOut], should be [bS, nOut] actually, however it will be broadcasted appropriately in future calcus (element-wise multiplication) // // dhdc // NDArray dhdc = c.ulike(); // activationDeriv(c, params[9], params[10], params[11], dhdc); // [bS, nOut] // dhdc *= o; // if(Wp) // dhdc += dhdzo* *dzodc; // NDArray factor = *dLdh * dhdc; // NDArray iFactor = factor*dcdzi; // [bS, nOut](or[nOut]) // NDArray fFactor = factor*dcdzf; // [bS, nOut](or[nOut]) // NDArray eFactor = factor*dcdzg; // [bS, nOut](or[nOut]) // NDArray oFactor = *dLdh *dhdzo; // [bS, nOut](or[nOut]) // NDArray tempC = dcdcI; // if(Wp) // tempC += dcdzi*(*dzidcI) + dcdzf*(*dzfdcI); // // dLdx // dLdx->assign(mmul(iFactor, WxiT) + mmul(fFactor, WxfT) + mmul(eFactor, WxgT) + mmul(oFactor, WxoT)); // [bS, nIn](or[nOut]) // // NDArray temp = c.ulike(); // // applyActivation(c, params[9], params[10], params[11], temp); // actH(c) // // dLdx->assign(mmul(o*(1-temp*temp)*g*i*(1-i), WxiT) + mmul(o*(1-temp*temp)*(*cI)*f*(1-f), WxfT) + mmul(o*(1-temp*temp)*i*g*(1-g), WxgT) + 
mmul(temp*o*(1-o), WxoT)); // [bS, nIn](or[nOut]) // // dLdhI // NDArray* dLdhII = dLdhI; // if(dLdcI && !dLdhI) // dLdhII = new NDArray(dLdcI->ulike()); // dLdhII->assign(mmul(iFactor, WriT) + mmul(fFactor, WrfT) + mmul(eFactor, WrgT) + mmul(oFactor, WroT)); // [bS, nOut](or[nOut]) // if(firstIter) { // // dLdcI // if(dLdcI) // dLdcI->assign(factor*tempC); // [bS, nOut](or[nOut]) // // dcIdWxi(dcdWxi) // dcIdWxi.assign(dcdzi*dWxJacobian); // broadcast [bS, nOut] * [nIn, nOut, bS, nOut] (or [nOut] * [nIn, nOut, nOut]); // // dcIdWxf(dcdWxf) // dcIdWxf.assign(dcdzf*dWxJacobian); // // dcIdWxg(dcdWxg) // dcIdWxg.assign(dcdzg*dWxJacobian); // // dcIdWxo(dcdWxo) = 0 // dcIdWxo.nullify(); // // dhIdWxi // dhIdWxi.assign(dhdc*dcIdWxi); // broadcast [bS, nOut] * [nIn, nOut, bS, nOut] (or [nOut] * [nIn, nOut, nOut]); // // dhIdWxf // dhIdWxf.assign(dhdc*dcIdWxf); // // dhIdWxg // dhIdWxg.assign(dhdc*dcIdWxg); // // dhIdWxo // dhIdWxo.assign(dhdzo*dWxJacobian /*+ 0 */); // // dcIdWri(dcdWri) // dcIdWri.assign(dcdzi*dWrJacobian); // broadcast [bS, nOut] * [nOut, nOut, bS, nOut](or [nOut] * [nIn, nOut, nOut]);; // // dcIdWrf(dcdWrf) // dcIdWrf.assign(dcdzf*dWrJacobian); // // dcIdWrg(dcdWrg) // dcIdWrg.assign(dcdzg*dWrJacobian); // // dcIdWro(dcdWro) = 0 // dcIdWro.nullify(); // // dhIdWri // dhIdWri.assign(dhdc*dcIdWri); // broadcast [bS, nOut] * [nIn, nOut, bS, nOut] (or [nOut] * [nIn, nOut, nOut]); // // dhIdWrf // dhIdWrf.assign(dhdc*dcIdWrf); // // dhIdWrg // dhIdWrg.assign(dhdc*dcIdWrg); // // dhIdWro // dhIdWro.assign(dhdzo*dWrJacobian /*+ 0 */); // if(Wp && x->rankOf() == 1) { // // dcIdWpi // dcIdWpi->assign(dcdzi*(*cI)); // [nOut] * [nOut] // // dcIdWpf // dcIdWpf->assign(dcdzf*(*cI)); // [nOut] * [nOut] // // dcIdWpo // dcIdWpo->nullify(); // [nOut] // // dhdWpi // dhIdWpi->assign(dhdc*(*dcIdWpi)); // [nOut] * [nOut] // // dhdWpf // dhIdWpf->assign(dhdc*(*dcIdWpf)); // [nOut] * [nOut] // // dhdWpo // dhIdWpo->assign(dhdzo*c /* +0*/); // [nOut] * [nOut] // } // else 
if(Wp) { // // dcIdWpi // (dcdzi*(*cI)).reduceAlongDimension(reduce::Sum, *dcIdWpi, {0}); // [bS, nOut]->reduce->[nOut] // // dcIdWpf // (dcdzf*(*cI)).reduceAlongDimension(reduce::Sum, *dcIdWpf, {0}); // [bS, nOut]->reduce->[nOut] // // dcIdWpo // dcIdWpo->nullify(); // [nOut] // // dhIdWpi // (*dLdh*dhdc*(dcdzi*(*cI))).reduceAlongDimension(reduce::Sum, *dhIdWpi, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // // dhIdWpf // (*dLdh*dhdc*(dcdzf*(*cI))).reduceAlongDimension(reduce::Sum, *dhIdWpf, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // // dhIdWpo // (*dLdh*dhdzo*c /* +0*/).reduceAlongDimension(reduce::Sum, *dhIdWpo, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // } // if(b && x->rankOf() == 1) { // // dcIdbi // dcIdbi->assign(dcdzi); // [nOut] // // dcIdbf // dcIdbf->assign(dcdzf); // [nOut] // // dcIdbg // dcIdbg->assign(dcdzg); // [nOut] // // dcIdbo // dcIdbo->nullify(); // [nOut] // //dhIdbi // dhIdbi->assign(dhdc*(*dcIdbi)); // [nOut] // //dhIdbf // dhIdbf->assign(dhdc*(*dcIdbf)); // [nOut] // //dhIdbg // dhIdbg->assign(dhdc*(*dcIdbg)); // [nOut] // //dhIdbo // dhIdbo->assign(dhdzo); // [nOut] // } // else if(b) { // // dcIdbi // dcdzi.reduceAlongDimension(reduce::Sum, *dcIdbi, {0}); // [bS, nOut]->reduce->[nOut] // // dcIdbf // dcdzf.reduceAlongDimension(reduce::Sum, *dcIdbf, {0}); // [bS, nOut]->reduce->[nOut] // // dcIdbg // dcdzg.reduceAlongDimension(reduce::Sum, *dcIdbg, {0}); // [bS, nOut]->reduce->[nOut] // // dcIdbo // dcIdbo->nullify(); // [nOut] // //dhIdbi // (*dLdh*dhdc*dcdzi).reduceAlongDimension(reduce::Sum, *dhIdbi, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // //dhIdbf // (*dLdh*dhdc*dcdzf).reduceAlongDimension(reduce::Sum, *dhIdbf, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // //dhIdbg // (*dLdh*dhdc*(*dcIdbg)).reduceAlongDimension(reduce::Sum, *dhIdbg, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // //dhIdbo // (*dLdh*dhdzo).reduceAlongDimension(reduce::Sum, *dhIdbo, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // } // 
} // else { // NDArray tempIFE = mmul(dcdzi, WriT) + mmul(dcdzf, WrfT) + mmul(dcdzg, WrgT); // NDArray tempO = mmul(dhdzo, WroT); // // dLdcI // if(dLdcI) // dLdcI->assign(factor*tempC + (*dLdhII)*(*dhIdcI)); // // dcIdWxi(dcdWxi) // dcIdWxi.assign(dcdzi*dWxJacobian + tempIFE*dhIdWxi + tempC*dcIdWxi); // broadcast [bS, nOut] * [nIn, nOut, bS, nOut](or [nOut] * [nIn, nOut, nOut]); // // dcIdWxf(dcdWxf) // dcIdWxf.assign(dcdzf*dWxJacobian + tempIFE*dhIdWxf + tempC*dcIdWxf); // // dcIdWxg(dcdWxg) // dcIdWxg.assign(dcdzg*dWxJacobian + tempIFE*dhIdWxg + tempC*dcIdWxg); // // dcIdWxo(dcdWxo) // dcIdWxo.assign(/* 0 + */tempIFE * dhIdWxo + tempC*dcIdWxo); // // dhIdWxi // dhIdWxi.assign(dhdc*dcIdWxi + tempO*dhIdWxi); // broadcast [bS, nOut] * [nIn, nOut, bS, nOut](or [nOut] * [nIn, nOut, nOut]); // // dhIdWxf // dhIdWxf.assign(dhdc*dcIdWxf + tempO*dhIdWxf); // // dhIdWxg // dhIdWxg.assign(dhdc*dcIdWxg + tempO*dhIdWxg); // // dhIdWxo // dhIdWxo.assign(dhdzo*dWxJacobian + dhdc*dcIdWxo + tempO*dhIdWxo); // // dcIdWri(dcdWri) // dcIdWri.assign(dcdzi*dWrJacobian + tempIFE*dhIdWri + tempC*dcIdWri); // broadcast [bS, nOut] * [nOut, nOut, bS, nOut](or [nOut] * [nIn, nOut, nOut]); // // dcIdWrf(dcdWrf) // dcIdWrf.assign(dcdzf*dWrJacobian + tempIFE*dhIdWrf + tempC*dcIdWrf); // // dcIdWrg(dcdWrg) // dcIdWrg.assign(dcdzg*dWrJacobian + tempIFE*dhIdWrg + tempC*dcIdWrg); // // dcIdWro(dcdWro) // dcIdWro.assign(/* 0 + */tempIFE * dhIdWro + tempC*dcIdWro); // // dhIdWri // dhIdWri.assign(dhdc*dcIdWri + tempO*dhIdWri); // broadcast [bS, nOut] * [nOut, nOut, bS, nOut](or [nOut] * [nIn, nOut, nOut]); // // dhIdWrf // dhIdWrf.assign(dhdc*dcIdWrf + tempO*dhIdWrf); // // dhIdWrg // dhIdWrg.assign(dhdc*dcIdWrg + tempO*dhIdWrg); // // dhIdWro // dhIdWro.assign(dhdzo*dWrJacobian + dhdc*dcIdWro + tempO*dhIdWro); // if(Wp && x->rankOf() == 1) { // // dcIdWpi // dcIdWpi->assign(dcdzi*(*cI) + tempIFE*(*dhIdWpi) + tempC*(*dcIdWpi)); // [nOut] * [nOut] // // dcIdWpf // dcIdWpf->assign(dcdzf*(*cI) + 
tempIFE*(*dhIdWpf) + tempC*(*dcIdWpf)); // [nOut] * [nOut] // // dcIdWpo // dcIdWpo->assign(/* 0 + */ tempIFE*(*dhIdWpo) + tempC*(*dcIdWpo)); // [nOut] * [nOut] // // dhdWpi // dhIdWpi->assign(dhdc*(*dcIdWpi) + tempO*(*dhIdWpi)); // [nOut] * [nOut] // // dhdWpf // dhIdWpf->assign(dhdc*(*dcIdWpf) + tempO*(*dhIdWpf)); // [nOut] * [nOut] // // dhdWpo // dhIdWpo->assign(dhdzo*c + dhdc*(*dcIdWpo) + tempO*(*dhIdWpo)); // [nOut] * [nOut] // } // else if(Wp) { // // dcIdWpi // (dcdzi*(*cI) + tempIFE*(*dhIdWpi) + tempC*(*dcIdWpi)).reduceAlongDimension(reduce::Sum, *dcIdWpi, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // // dcIdWpf // (dcdzf*(*cI) + tempIFE*(*dhIdWpf) + tempC*(*dcIdWpf)).reduceAlongDimension(reduce::Sum, *dcIdWpf, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // // dcIdWpo // (/* 0 + */ tempIFE*(*dhIdWpo) + tempC*(*dcIdWpo)).reduceAlongDimension(reduce::Sum, *dcIdWpo, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // // dhIdWpi // (dhdc*(*dcIdWpi) + tempO*(*dhIdWpi)).reduceAlongDimension(reduce::Sum, *dhIdWpi, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // // dhIdWpf // (dhdc*(*dcIdWpf) + tempO*(*dhIdWpf)).reduceAlongDimension(reduce::Sum, *dhIdWpf, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // // dhIdWpo // (dhdzo*c + dhdc*(*dcIdWpo) + tempO*(*dhIdWpo)).reduceAlongDimension(reduce::Sum, *dhIdWpo, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // } // if(b && x->rankOf() == 1) { // // dcIdbi // dcIdbi->assign(dcdzi + tempIFE*(*dhIdbi) + tempC*(*dcIdbi)); // [nOut] // // dcIdbf // dcIdbf->assign(dcdzf + tempIFE*(*dhIdbf) + tempC*(*dcIdbf)); // [nOut] // // dcIdbg // dcIdbg->assign(dcdzg + tempIFE*(*dhIdbg) + tempC*(*dcIdbg)); // [nOut] // // dcIdbo // dcIdbo->assign(/*0+*/ tempIFE*(*dhIdbo) + tempC*(*dcIdbo)); // [nOut] // //dhIdbi // dhIdbi->assign(dhdc*(*dcIdbi) + tempO*(*dhIdbi)); // [nOut] // //dhIdbf // dhIdbf->assign(dhdc*(*dcIdbf) + tempO*(*dhIdbf)); // [nOut] // //dhIdbg // dhIdbg->assign(dhdc*(*dcIdbg) + tempO*(*dhIdbg)); // [nOut] // 
//dhIdbo // dhIdbo->assign(dhdzo + dhdc*(*dcIdbo) + tempO*(*dhIdbo)); // [nOut] // } // else if(b) { // // dcIdbi // (dcdzi + tempIFE*(*dhIdbi) + tempC*(*dcIdbi)).reduceAlongDimension(reduce::Sum, *dcIdbi, {0}); // [bS, nOut]->reduce->[nOut] // // dcIdbf // (dcdzf + tempIFE*(*dhIdbf) + tempC*(*dcIdbf)).reduceAlongDimension(reduce::Sum, *dcIdbf, {0}); // [bS, nOut]->reduce->[nOut] // // dcIdbg // (dcdzg + tempIFE*(*dhIdbg) + tempC*(*dcIdbg)).reduceAlongDimension(reduce::Sum, *dcIdbg, {0}); // [bS, nOut]->reduce->[nOut] // // dcIdbo // (/*0+*/ tempIFE*(*dhIdbo) + tempC*(*dcIdbo)).reduceAlongDimension(reduce::Sum, *dcIdbo, {0}); // [bS, nOut]->reduce->[nOut] // //dhIdbi // (dhdc*(*dcIdbi) + tempO*(*dhIdbi)).reduceAlongDimension(reduce::Sum, *dhIdbi, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // //dhIdbf // (dhdc*(*dcIdbf) + tempO*(*dhIdbf)).reduceAlongDimension(reduce::Sum, *dhIdbf, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // //dhIdbg // (dhdc*(*dcIdbg) + tempO*(*dhIdbg)).reduceAlongDimension(reduce::Sum, *dhIdbg, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // //dhIdbo // (dhdzo + dhdc*(*dcIdbo) + tempO*(*dhIdbo)).reduceAlongDimension(reduce::Sum, *dhIdbo, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // } // } // const std::vector dimsToExclude = x->rankOf() == 1 ? 
std::vector({2}) : std::vector({2, 3}); // // dLdWxi, dLdWxf, dLdWxg, dLdWxo // (*dLdh*(*dhIdWx)).reduceAlongDimension(reduce::Sum, *dLdWx, dimsToExclude); // // dLdWri, dLdWrf, dLdWrg, dLdWro // (*dLdh*(*dhIdWr)).reduceAlongDimension(reduce::Sum, *dLdWr, dimsToExclude); // // dLdWpi, dLdWpf, dLdWpo // if(Wp) { // if(x->rankOf() == 1) { // (*dLdWp)({0, nOut}).assign(*dLdh*(*dhIdWpi)); // [nOut] * [nOut] // (*dLdWp)({nOut, 2*nOut}).assign(*dLdh*(*dhIdWpf)); // [nOut] * [nOut] // (*dLdWp)({2*nOut, 3*nOut}).assign(*dLdh*(*dhIdWpo)); // [nOut] * [nOut] // } // else { // // NDArray temp1 = (*dLdWp)({0, nOut}); // // NDArray temp2 = (*dLdWp)({nOut, 2*nOut}); // // NDArray temp3 = (*dLdWp)({2*nOut, 3*nOut}); // // dhIdWpi->reduceAlongDimension(reduce::Sum, temp1, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // // dhIdWpf->reduceAlongDimension(reduce::Sum, temp2, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // // dhIdWpo->reduceAlongDimension(reduce::Sum, temp3, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // (*dLdWp)({0, nOut}).assign(dhIdWpi); // (*dLdWp)({nOut, 2*nOut}).assign(dhIdWpf); // (*dLdWp)({2*nOut, 3*nOut}).assign(dhIdWpo); // } // } // // dLdbi, dLdbf, dLdbg, dLdbo // if(b) { // if(x->rankOf() == 1) { // (*dLdb)({0, nOut}).assign(*dLdh*(*dhIdbi)); // [nOut] * [nOut] // (*dLdb)({nOut, 2*nOut}).assign(*dLdh*(*dhIdbf)); // [nOut] * [nOut] // (*dLdb)({2*nOut, 3*nOut}).assign(*dLdh*(*dhIdbg)); // [nOut] * [nOut] // (*dLdb)({3*nOut, 4*nOut}).assign(*dLdh*(*dhIdbo)); // [nOut] * [nOut] // } // else { // // NDArray temp1 = (*dLdb)({0, nOut}); // // NDArray temp2 = (*dLdb)({nOut, 2*nOut}); // // NDArray temp3 = (*dLdb)({2*nOut, 3*nOut}); // // NDArray temp4 = (*dLdb)({3*nOut, 4*nOut}); // // (*dLdh*(*dhIdbi)).reduceAlongDimension(reduce::Sum, temp1, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // // (*dLdh*(*dhIdbf)).reduceAlongDimension(reduce::Sum, temp2, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // // (*dLdh*(*dhIdbg)).reduceAlongDimension(reduce::Sum, 
temp3, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // // (*dLdh*(*dhIdbo)).reduceAlongDimension(reduce::Sum, temp3, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut] // (*dLdb)({0, nOut}).assign(dhIdbi); // (*dLdb)({nOut, 2*nOut}).assign(dhIdbf); // (*dLdb)({2*nOut, 3*nOut}).assign(dhIdbg); // (*dLdb)({3*nOut, 4*nOut}).assign(dhIdbo); // } // } // //dhIdcI // if(dLdcI) // dhIdcI->assign(dhdc); // cI->assign(c); // if(dLdcI && !dLdhI) // delete dLdhII; // if(Wp) { // delete Wpi; delete Wpf; delete Wpo; delete dcIdWpi; delete dcIdWpf; delete dcIdWpo; delete dhIdWpi; delete dhIdWpf; delete dhIdWpo; // } // if(b) { // delete dcIdbi; delete dcIdbf; delete dcIdbg; delete dcIdbo; delete dhIdbi; delete dhIdbf; delete dhIdbg; delete dhIdbo; // } // }