1817 lines
91 KiB
C++
1817 lines
91 KiB
C++
/*******************************************************************************
|
||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||
* Copyright (c) 2020 Konduit K.K.
|
||
*
|
||
* This program and the accompanying materials are made available under the
|
||
* terms of the Apache License, Version 2.0 which is available at
|
||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||
*
|
||
* Unless required by applicable law or agreed to in writing, software
|
||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||
* License for the specific language governing permissions and limitations
|
||
* under the License.
|
||
*
|
||
* SPDX-License-Identifier: Apache-2.0
|
||
******************************************************************************/
|
||
|
||
//
|
||
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||
//
|
||
|
||
// implementation of operation for LSTM cell with peep hole connections:
|
||
// http://www.bioinf.jku.at/publications/older/2604.pdf
|
||
// S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.
|
||
// and
|
||
// https://research.google.com/pubs/archive/43905.pdf
|
||
// Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent neural network architectures for large scale acoustic modeling." INTERSPEECH, 2014.
|
||
|
||
|
||
#include <ops/declarable/helpers/lstmLayer.h>
|
||
#include <execution/Threads.h>
|
||
#include <ops/declarable/helpers/activations.h>
|
||
#include <helpers/ShapeUtils.h>
|
||
#include <helpers/MmulHelper.h>
|
||
// #include <VariableSpace.h>
|
||
// #include <ops/declarable/CustomOperations.h>
|
||
// #include<ops/declarable/helpers/transforms.h>
|
||
// #include <ops/declarable/helpers/legacy_helpers.h>
|
||
// #include <array/NDArrayList.h>
|
||
// #include <iterator>
|
||
|
||
|
||
namespace sd {
|
||
namespace ops {
|
||
namespace helpers {
|
||
|
||
//////////////////////////////////////////////////////////////////////////
|
||
static void applyActivation(const NDArray& x, const int opId, const float alpha, const float beta, NDArray& z) {
|
||
|
||
switch (opId) {
|
||
case 0:
|
||
(const_cast<NDArray&>(x)).applyTransform(transform::Tanh, z);
|
||
break;
|
||
case 1:
|
||
(const_cast<NDArray&>(x)).applyScalar<float>(scalar::RELU, 0, z);
|
||
break;
|
||
case 2:
|
||
(const_cast<NDArray&>(x)).applyTransform(transform::Sigmoid, z);
|
||
break;
|
||
case 3: {
|
||
ExtraArguments args({ static_cast<double>(alpha), static_cast<double>(beta)});
|
||
(const_cast<NDArray&>(x)).applyTransform(transform::Affine, z, &args);
|
||
break;
|
||
}
|
||
case 4:
|
||
(const_cast<NDArray&>(x)).applyScalar<float>(scalar::LeakyRELU, alpha, z);
|
||
break;
|
||
case 5:
|
||
thresholdRelu(x.getContext(), x, alpha, z);
|
||
break;
|
||
case 6: {
|
||
ExtraArguments args({ static_cast<double>(alpha), static_cast<double>(beta)});
|
||
(const_cast<NDArray&>(x)).applyTransform(transform::ScaledTanh, z, &args);
|
||
break;
|
||
}
|
||
case 7:
|
||
(const_cast<NDArray&>(x)).applyTransform(transform::HardSigmoid, z);
|
||
break;
|
||
case 8:
|
||
(const_cast<NDArray&>(x)).applyScalar<float>(scalar::ELU, alpha, z);
|
||
break;
|
||
case 9:
|
||
(const_cast<NDArray&>(x)).applyTransform(transform::SoftSign, z);
|
||
break;
|
||
case 10:
|
||
(const_cast<NDArray&>(x)).applyTransform(transform::SoftPlus, z);
|
||
break;
|
||
default:
|
||
throw std::invalid_argument("LSTM_LAYER operation: wrong id number of activation !");
|
||
}
|
||
}
|
||
|
||
//////////////////////////////////////////////////////////////////////////
|
||
static void activationDeriv(const NDArray& x, const int opId, const float alpha, const float beta, NDArray& z) {
|
||
|
||
switch (opId) {
|
||
case 0:
|
||
(const_cast<NDArray&>(x)).applyTransform(transform::TanhDerivative, z);
|
||
break;
|
||
case 1:
|
||
(const_cast<NDArray&>(x)).applyScalar<float>(scalar::RELUDerivative, 0, z);
|
||
break;
|
||
case 2:
|
||
(const_cast<NDArray&>(x)).applyTransform(transform::SigmoidDerivative, z);
|
||
break;
|
||
case 3: {
|
||
z = alpha;
|
||
break;
|
||
}
|
||
case 4:
|
||
(const_cast<NDArray&>(x)).applyScalar<float>(scalar::LeakyRELUDerivative, alpha, z);
|
||
break;
|
||
case 5:
|
||
(const_cast<NDArray&>(x)).applyScalar<float>(scalar::RELUDerivative, alpha, z);
|
||
break;
|
||
case 6: {
|
||
auto func = PRAGMA_THREADS_FOR {
|
||
for(Nd4jLong i = start; i < stop; ++i) {
|
||
auto val = beta * x.e<float>(i);
|
||
z.p<float>(i, alpha * beta * (1.f - sd::math::nd4j_tanh<float,float>(val) * sd::math::nd4j_tanh<float,float>(val)));
|
||
}
|
||
};
|
||
samediff::Threads::parallel_for(func, 0, x.lengthOf());
|
||
break;
|
||
}
|
||
case 7:
|
||
(const_cast<NDArray&>(x)).applyTransform(transform::HardSigmoidDerivative, z);
|
||
break;
|
||
case 8:
|
||
(const_cast<NDArray&>(x)).applyScalar<float>(scalar::ELUDerivative, alpha, z);
|
||
break;
|
||
case 9:
|
||
(const_cast<NDArray&>(x)).applyTransform(transform::SoftSignDerivative, z);
|
||
break;
|
||
case 10: {
|
||
auto func = PRAGMA_THREADS_FOR {
|
||
for(Nd4jLong i = start; i < stop; ++i) {
|
||
auto val = sd::math::nd4j_exp<float, float>(x.e<float>(i));
|
||
z.p<float>(i, val / (1.f + val));
|
||
}
|
||
};
|
||
samediff::Threads::parallel_for(func, 0, x.lengthOf());
|
||
break;
|
||
}
|
||
default:
|
||
throw std::invalid_argument("LSTM_LAYER operation: wrong id number of activation !");
|
||
}
|
||
}
|
||
|
||
//////////////////////////////////////////////////////////////////////////
|
||
// FIXME - derivative undefined when not-clipped c has element/elements equal to -clipVal or clipVal
|
||
static void clipDeriv(const float clipVal, const NDArray& c, NDArray& z0, NDArray& z1, NDArray& z2, NDArray& z3) {
|
||
|
||
if(clipVal == 0)
|
||
return;
|
||
|
||
auto func = PRAGMA_THREADS_FOR {
|
||
for(Nd4jLong i = start; i < stop; ++i) {
|
||
const auto val = c.e<float>(i);
|
||
if(val == -clipVal || val == clipVal) {
|
||
z0.p<float>(i, 0.f);
|
||
z1.p<float>(i, 0.f);
|
||
z2.p<float>(i, 0.f);
|
||
z3.p<float>(i, 0.f);
|
||
}
|
||
}
|
||
};
|
||
samediff::Threads::parallel_for(func, 0, c.lengthOf());
|
||
}
|
||
|
||
//////////////////////////////////////////////////////////////////////////
|
||
/**
 * Returns the sub-array of a rank-3 array restricted to time interval [t1, t2) and batch interval [b1, b2),
 * keeping the whole feature axis. A "0,0" pair selects the entire axis (callers pass t1 = t2 = 0 to
 * address the whole time range).
 *
 * @param arr         array laid out according to dataFormat
 * @param dataFormat  0 or 3 = TNS [sL, bS, nIn], 1 = NTS [bS, sL, nIn], any other value = NST [bS, nIn, sL]
 * @param t1, t2      time interval
 * @param b1, b2      batch interval
 */
static NDArray tensorAlongTimeBatchDims(const NDArray& arr, const int dataFormat, const int t1, const int t2, const int b1, const int b2) {

    switch (dataFormat) {
        case 0:
        case 3:
            return arr({t1,t2, b1,b2, 0,0});    // TNS: [sL, bS, nIn]
        case 1:
            return arr({b1,b2, t1,t2, 0,0});    // NTS: [bS, sL ,nIn]
        default:
            return arr({b1,b2, 0,0, t1,t2});    // NST: [bS, nIn, sL]
    }
}
|
||
|
||
//////////////////////////////////////////////////////////////////////////
|
||
/**
 * Computes the flat index of the (time step t, batch item b) pair for the given data layout.
 *
 * @param dataFormat  0 or 3 = time-major layout [sL, bS, ...]; any other value = batch-major layout
 *                    ([bS, sL, nIn] or [bS, nIn, sL])
 * @param sL          sequence length
 * @param bS          batch size
 * @param t           time step
 * @param b           batch item
 * @return t * bS + b for the time-major layout, b * sL + t otherwise
 */
static FORCEINLINE int getBatchTimeTotalIndex(const int dataFormat, const int sL, const int bS, const int t, const int b) {

    const bool timeMajor = (dataFormat == 0) || (dataFormat == 3);

    return timeMajor ? (t * bS + b)     // TNS: shape [sL, bS, nIn]
                     : (b * sL + t);    // NTS, NST: shape [bS, sL, nIn], [bS, nIn, sL]
}
|
||
|
||
|
||
//////////////////////////////////////////////////////////////////////////
|
||
void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
|
||
const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
|
||
const std::vector<float>& params,
|
||
NDArray* h, NDArray* c) {
|
||
|
||
// * -> means element-wise multiplication
|
||
// × -> means matrix multiplication
|
||
|
||
/************************ THIS IS NOT OPTIMAZED CODE ***********************************/
|
||
/** the objective is to provide math-readable code **/
|
||
|
||
// equations (no peephole connections)
|
||
// it = σ(Wxi × xt + Wri × ht-1 + bi)
|
||
// ft = σ(Wxf × xt + Wrf × ht-1 + bf)
|
||
// c't = tanh(Wxc × xt + Wrc × ht-1 + bc)
|
||
// ct = ft * ct-1 + it * c't
|
||
// ot = σ(Wxo × xt + Wro × ht-1 + bo)
|
||
// ht = ot * tanh(ct)
|
||
|
||
// equations (peephole connections are present)
|
||
// it = σ(Wxi × xt + Wri × ht-1 + Wpi * ct-1 + bi)
|
||
// ft = σ(Wxf × xt + Wrf × ht-1 + Wpf * ct-1 + bf)
|
||
// c't = tanh(Wxc × xt + Wrc × ht-1 + bc)
|
||
// ct = ft * ct-1 + it * c't
|
||
// ot = σ(Wxo × xt + Wro × ht-1 + Wpo * ct + bo)
|
||
// ht = ot * tanh(ct)
|
||
|
||
|
||
// IDs for activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
|
||
|
||
// params[0] - dataFormat, ignore
|
||
// params[1] - directionMode, ignore
|
||
// params[2] - cell clipping value, if it = 0 then do not apply clipping
|
||
|
||
// params[3] - activation ID for input (i), forget (f) and output (o) gates
|
||
// params[4] - alpha value for gates activation
|
||
// params[5] - beta value for gates activation
|
||
|
||
// params[6] - activation ID for cell state (c)
|
||
// params[7] - alpha value for cell state activation
|
||
// params[8] - beta value for cell state activation
|
||
|
||
// params[9] - activation ID for output (h)
|
||
// params[10] - alpha value for output activation
|
||
// params[11] - beta value for output activation
|
||
|
||
// INPUTS:
|
||
// x - current input at time t, [bS, nIn] or [nIn] if seqLen != nullptr
|
||
// Wx - input weights [nIn, 4*nOut]
|
||
// Wr - recurrent weights [nOut, 4*nOut]
|
||
// b - biases [4*nOut], optional, may be nullptr
|
||
// hI - (ht-1) previous (initial) output at time t-1, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr
|
||
// cI - (ct-1) previous (initial) cell state at time t-1, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr
|
||
// Wp - peephole weights [3*nOut], optional, may be nullptr
|
||
|
||
// OUTPUTS:
|
||
// h - current output, that is at current time step t, [bS, nOut] or [nOut] if seqLen != nullptr
|
||
// c - current cell state, that is at current time step t, [bS, nOut] or [nOut] if seqLen != nullptr
|
||
|
||
// !!! dimension 4*nOut implies order it, ft, c't, ot
|
||
// !!! dimension 3*nOut implies order it, ft, ot
|
||
|
||
const Nd4jLong nOut = Wx->sizeAt(-1) / 4;
|
||
|
||
auto z = mmul(*x, *Wx) + mmul(*hI, *Wr); // [bs, nIn] * [nIn, 4*nOut] + [bs, nOut] * [nOut, 4*nOut] = [bS, 4*nOut]
|
||
//or [nIn] * [nIn, 4*nOut] + [nOut] * [nOut, 4*nOut] = [4*nOut]
|
||
|
||
// add biases if they are given
|
||
if(b != nullptr)
|
||
z += *b; // broadcast [bS, 4*nOut](or[4*nOut]) + [4*nOut] = [bS, 4*nOut]
|
||
|
||
auto zi = x->rankOf() == 1 ? z({0, nOut}) : z({0,0, 0, nOut}); // input gate it, [bS, nOut](or[nOut])
|
||
auto zf = x->rankOf() == 1 ? z({nOut, 2*nOut}) : z({0,0, nOut, 2*nOut}); // forget gate ft, [bS, nOut](or[nOut])
|
||
auto zg = x->rankOf() == 1 ? z({2*nOut, 3*nOut}) : z({0,0, 2*nOut, 3*nOut}); // cell gate c't, [bS, nOut](or[nOut])
|
||
auto zo = x->rankOf() == 1 ? z({3*nOut, 4*nOut}) : z({0,0, 3*nOut, 4*nOut}); // output gate ot, [bS, nOut](or[nOut])
|
||
|
||
// peephole connections for input and forget gates
|
||
if(Wp != nullptr) {
|
||
zi += *cI * (*Wp)({0, nOut}); // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut])
|
||
zf += *cI * (*Wp)({nOut, 2*nOut}); // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut])
|
||
}
|
||
|
||
applyActivation(zi, params[3], params[4], params[5], zi); // inplace
|
||
applyActivation(zf, params[3], params[4], params[5], zf); // inplace
|
||
applyActivation(zg, params[6], params[7], params[8], zg); // inplace
|
||
|
||
c->assign(zf * *cI + zi * zg); // [bS, nOut] * [bS, nOut] + [bS, nOut] * [bS, nOut] = [bS, nOut](or[nOut])
|
||
|
||
// if clipping value is non-zero then cell state is clipped by this value prior to the cell output activation
|
||
if(params[2] != 0)
|
||
c->applyScalar(scalar::LstmClip, params[2], *c);
|
||
|
||
// peephole connections for output gate
|
||
if(Wp != nullptr)
|
||
zo += *c * (*Wp)({2*nOut, 3*nOut}); // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut])
|
||
|
||
applyActivation(zo, params[3], params[4], params[5], zo);
|
||
|
||
applyActivation(*c, params[9], params[10], params[11], *h);
|
||
*h *= zo; // [bS, nOut] * [bS, nOut](or[nOut])
|
||
}
|
||
|
||
|
||
//////////////////////////////////////////////////////////////////////////
|
||
// this auxiliary feed-forward pass must be run before backprop: it stores z and a, which lstmLayerCellBp takes as inputs
|
||
void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
|
||
const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
|
||
const std::vector<float>& params,
|
||
NDArray* z, NDArray* a, NDArray* h, NDArray* c) {
|
||
|
||
// z - zi, zf, zg, zo
|
||
// a - i, f, g, o
|
||
|
||
const Nd4jLong nOut = Wx->sizeAt(-1) / 4;
|
||
|
||
z->assign(mmul(*x, *Wx) + mmul(*hI, *Wr)); // [bs, nIn] * [nIn, 4*nOut] + [bs, nOut] * [nOut, 4*nOut] = [bS, 4*nOut]
|
||
//or [nIn] * [nIn, 4*nOut] + [nOut] * [nOut, 4*nOut] = [4*nOut]
|
||
// add biases if they are given
|
||
if(b != nullptr)
|
||
*z += *b; // broadcast [bS, 4*nOut](or[4*nOut]) + [4*nOut] = [bS, 4*nOut]
|
||
|
||
auto zi = x->rankOf() == 1 ? (*z)({0, nOut}) : (*z)({0,0, 0, nOut}); // input gate it, [bS, nOut](or[nOut])
|
||
auto zf = x->rankOf() == 1 ? (*z)({nOut, 2*nOut}) : (*z)({0,0, nOut, 2*nOut}); // forget gate ft, [bS, nOut](or[nOut])
|
||
auto zg = x->rankOf() == 1 ? (*z)({2*nOut, 3*nOut}) : (*z)({0,0, 2*nOut, 3*nOut}); // cell gate c't, [bS, nOut](or[nOut])
|
||
auto zo = x->rankOf() == 1 ? (*z)({3*nOut, 4*nOut}) : (*z)({0,0, 3*nOut, 4*nOut}); // output gate ot, [bS, nOut](or[nOut])
|
||
|
||
auto i = x->rankOf() == 1 ? (*a)({0, nOut}) : (*a)({0,0, 0, nOut}); // input gate it, [bS, nOut](or[nOut])
|
||
auto f = x->rankOf() == 1 ? (*a)({nOut, 2*nOut}) : (*a)({0,0, nOut, 2*nOut}); // forget gate ft, [bS, nOut](or[nOut])
|
||
auto g = x->rankOf() == 1 ? (*a)({2*nOut, 3*nOut}) : (*a)({0,0, 2*nOut, 3*nOut}); // cell gate c't, [bS, nOut](or[nOut])
|
||
auto o = x->rankOf() == 1 ? (*a)({3*nOut, 4*nOut}) : (*a)({0,0, 3*nOut, 4*nOut}); // output gate ot, [bS, nOut](or[nOut])
|
||
|
||
// peephole connections for input and forget gates
|
||
if(Wp != nullptr) {
|
||
zi += *cI * (*Wp)({0, nOut}); // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut])
|
||
zf += *cI * (*Wp)({nOut, 2*nOut}); // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut])
|
||
}
|
||
|
||
applyActivation(zi, params[3], params[4], params[5], i);
|
||
applyActivation(zf, params[3], params[4], params[5], f);
|
||
applyActivation(zg, params[6], params[7], params[8], g);
|
||
|
||
c->assign(f * *cI + i * g); // [bS, nOut] * [bS, nOut] + [bS, nOut] * [bS, nOut] = [bS, nOut](or[nOut])
|
||
|
||
// if clipping value is non-zero then cell state is clipped by this value prior to the cell output activation
|
||
if(params[2] != 0)
|
||
c->applyScalar(scalar::LstmClip, params[2], *c);
|
||
|
||
// peephole connections for output gate
|
||
if(Wp != nullptr)
|
||
zo += *c * (*Wp)({2*nOut, 3*nOut}); // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut])
|
||
|
||
applyActivation(zo, params[3], params[4], params[5], o);
|
||
|
||
applyActivation(*c, params[9], params[10], params[11], *h);
|
||
*h *= o; // [bS, nOut] * [bS, nOut](or[nOut])
|
||
}
|
||
|
||
|
||
//////////////////////////////////////////////////////////////////////////
|
||
void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
|
||
const NDArray* dLdh, const NDArray* dLdhL, const NDArray* dLdcL,
|
||
const NDArray* z, const NDArray* a, const NDArray* c, const std::vector<float>& params,
|
||
NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdb, NDArray* dLdWp) {
|
||
|
||
/************************ THIS IS NOT OPTIMAZED CODE ***********************************/
|
||
/** the objective is to provide math-readable code **/
|
||
|
||
// equations (no peephole connections)
|
||
// zi = x × Wxi + hI × Wri + bi
|
||
// zf = x × Wxf + hI × Wrf + bf
|
||
// zg = x × Wxg + hI × Wrg + bg
|
||
// zo = x × Wxo + hI × Wro + bo
|
||
// i = act(zi)
|
||
// f = act(zf)
|
||
// g = actC(zg)
|
||
// o = act(zo)
|
||
// c = clip(f * cI + i * g)
|
||
// h = o * actH(c)
|
||
|
||
// equations (peephole connections are present)
|
||
// zi = x × Wxi + hI × Wri + cI * Wpi + bi
|
||
// zf = x × Wxf + hI × Wrf + cI * Wpf + bf
|
||
// zg = x × Wxg + hI × Wrg + bg
|
||
// zo = x × Wxo + hI × Wro + c * Wpo + bo
|
||
// i = act(zi)
|
||
// f = act(zf)
|
||
// g = actC(zg)
|
||
// o = act(zo)
|
||
// c = clip(f * cI + i * g)
|
||
// h = o * actH(c)
|
||
|
||
// IDs for activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
|
||
|
||
// params[0] - dataFormat, ignore
|
||
// params[1] - directionMode, ignore
|
||
// params[2] - cell clipping value, if it = 0 then do not apply clipping
|
||
|
||
// params[3] - activation ID for input (i), forget (f) and output (o) gates
|
||
// params[4] - alpha value for gates activation
|
||
// params[5] - beta value for gates activation
|
||
|
||
// params[6] - activation ID for cell state (c)
|
||
// params[7] - alpha value for cell state activation
|
||
// params[8] - beta value for cell state activation
|
||
|
||
// params[9] - activation ID for output (h)
|
||
// params[10] - alpha value for output activation
|
||
// params[11] - beta value for output activation
|
||
|
||
// INPUTS:
|
||
// x - current input at time t, [bS, nIn] or [nIn] if seqLen != nullptr
|
||
// Wx - input weights [nIn, 4*nOut]
|
||
// Wr - recurrent weights [nOut, 4*nOut]
|
||
// b - biases [4*nOut], optional, may be nullptr
|
||
// hI - (ht-1) previous (initial) output at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr
|
||
// cI - (ct-1) previous (initial) cell state at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr
|
||
// Wp - peephole weights [3*nOut], optional, may be nullptr
|
||
// dLdh - loss derivative with respect to h at each time step, [bS, nOut] or [nOut] if seqLen != nullptr
|
||
// dLdhL - loss derivative with respect to h at last time step, [bS, nOut] or [nOut] if seqLen != nullptr
|
||
// dLdcL - loss derivative with respect to c at last time step, [bS, nOut] or [nOut] if seqLen != nullptr
|
||
// z - zi,zf,zg,zo taken from ff outputs to reduce amount of calculations in bp, [bS, 4*nOut]
|
||
// a - i,f,g,o taken from ff outputs to reduce amount of calculations in bp, [bS, 4*nOut]
|
||
// c - taken from ff outputs to reduce amount of calculations in bp, [bS, nOut]
|
||
|
||
// OUTPUTS:
|
||
// dLdx - loss derivative with respect to x, [bS, nIn] or [nIn] if seqLen != nullptr
|
||
// dLdWx - loss derivative with respect to Wx, [nIn, 4*nOut]
|
||
// dLdWr - loss derivative with respect to Wr, [nOut, 4*nOut]
|
||
// dLdb - loss derivative with respect to b, optional, may be nullptr, [4*nOut]
|
||
// dLdhI - loss derivative with respect to hI, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr
|
||
// dLdcI - loss derivative with respect to cI, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr
|
||
// dLdWp - loss derivative with respect to Wp, optional, may be nullptr, [3*nOut]
|
||
|
||
// !!! dimension 4*nOut implies order i, f, g, o
|
||
// !!! dimension 3*nOut implies order i, f, o
|
||
|
||
// dhdc = o*tanhDeriv + Wp ? tanh(c)*dodzo*dzodc : 0 [bS, nOut]
|
||
// dcdcI = f + Wp ? dcdzi*dzidcI + dcdzf*dzfdcI : 0 [bS, nOut]
|
||
|
||
// dLdhI += dLdh; [bS, nOut]
|
||
// dLdcI += dLdhI * dhdc; [bS, nOut]
|
||
|
||
// dLdzi = dLdcI*dcdi*didzi; [bS, nOut](or[nOut])
|
||
// dLdzf = dLdcI*dcdf*dfdzf; [bS, nOut](or[nOut])
|
||
// dLdzg = dLdcI*dcdg*dgdzg; [bS, nOut](or[nOut])
|
||
// dLdzo = dLdhI*dhdo*dodzo; [bS, nOut](or[nOut])
|
||
|
||
// dLdx = dLdzi×WxiT + dLdzf×WxfT + dLdzg×WxgT + dLdzo×WxoT, [bS, nIn]
|
||
// dLdhI = dLdzi×WriT + dLdzf×WrfT + dLdzg×WrgT + dLdzo×WroT, [bS, nOut]
|
||
// dLdcI = dLdcI*dcdcI, [bS, nOut]
|
||
|
||
// dLdWxi = xT×dLdzi [nIn, bS] x [bS, nOut] = [nIn, nOut]
|
||
// dLdWxf = xT×dLdzf [nIn, bS] x [bS, nOut] = [nIn, nOut]
|
||
// dLdWxg = xT×dLdzg [nIn, bS] x [bS, nOut] = [nIn, nOut]
|
||
// dLdWxo = xT×dLdzo [nIn, bS] x [bS, nOut] = [nIn, nOut]
|
||
|
||
// dLdWri = hIT×dLdzi [nOut, bS] x [bS, nOut] = [nOut, nOut]
|
||
// dLdWrf = hIT×dLdzf [nOut, bS] x [bS, nOut] = [nOut, nOut]
|
||
// dLdWrg = hIT×dLdzg [nOut, bS] x [bS, nOut] = [nOut, nOut]
|
||
// dLdWro = hIT×dLdzo [nOut, bS] x [bS, nOut] = [nOut, nOut]
|
||
|
||
// dLdbi = dLdzi.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut]
|
||
// dLdbf = dLdzf.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut]
|
||
// dLdbg = dLdzg.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut]
|
||
// dLdbo = dLdzo.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut]
|
||
|
||
// dLdWpi = (dLdzi*cI).reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut]
|
||
// dLdWpf = (dLdzf*cI).reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut]
|
||
// dLdWpo = (dLdzo*c) .reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut]
|
||
|
||
const Nd4jLong nOut = Wx->sizeAt(-1) / 4;
|
||
const Nd4jLong nIn = x->sizeAt(-1);
|
||
|
||
NDArray zi = x->rankOf() == 1 ? (*z)({0, nOut}) : (*z)({0,0, 0, nOut}); // input gate i, [bS, nOut](or[nOut])
|
||
NDArray zf = x->rankOf() == 1 ? (*z)({nOut, 2*nOut}) : (*z)({0,0, nOut, 2*nOut}); // forget gate f, [bS, nOut](or[nOut])
|
||
NDArray zg = x->rankOf() == 1 ? (*z)({2*nOut, 3*nOut}) : (*z)({0,0, 2*nOut, 3*nOut}); // cell gate g, [bS, nOut](or[nOut])
|
||
NDArray zo = x->rankOf() == 1 ? (*z)({3*nOut, 4*nOut}) : (*z)({0,0, 3*nOut, 4*nOut}); // output gate o, [bS, nOut](or[nOut])
|
||
|
||
NDArray i = x->rankOf() == 1 ? (*a)({0, nOut}) : (*a)({0,0, 0, nOut}); // input gate i, [bS, nOut](or[nOut])
|
||
NDArray f = x->rankOf() == 1 ? (*a)({nOut, 2*nOut}) : (*a)({0,0, nOut, 2*nOut}); // forget gate f, [bS, nOut](or[nOut])
|
||
NDArray g = x->rankOf() == 1 ? (*a)({2*nOut, 3*nOut}) : (*a)({0,0, 2*nOut, 3*nOut}); // cell gate g, [bS, nOut](or[nOut])
|
||
NDArray o = x->rankOf() == 1 ? (*a)({3*nOut, 4*nOut}) : (*a)({0,0, 3*nOut, 4*nOut}); // output gate o, [bS, nOut](or[nOut])
|
||
|
||
NDArray dLdz = z->ulike(); // [bS, 4*nOut](or[4*nOut])
|
||
NDArray dLdzi = x->rankOf() == 1 ? dLdz({0, nOut}) : dLdz({0,0, 0, nOut});
|
||
NDArray dLdzf = x->rankOf() == 1 ? dLdz({nOut, 2*nOut}) : dLdz({0,0, nOut, 2*nOut});
|
||
NDArray dLdzg = x->rankOf() == 1 ? dLdz({2*nOut, 3*nOut}) : dLdz({0,0, 2*nOut, 3*nOut});
|
||
NDArray dLdzo = x->rankOf() == 1 ? dLdz({3*nOut, 4*nOut}) : dLdz({0,0, 3*nOut, 4*nOut});
|
||
|
||
// dcdzi = dcdi*didzi, [bS, nOut](or[nOut])
|
||
activationDeriv(zi, params[3], params[4], params[5], dLdzi); // didzi, inplace
|
||
dLdzi *= g; // dcdi = g*clipDeriv
|
||
|
||
// dcdzf = dcdf*dfdzf, [bS, nOut](or[nOut])
|
||
activationDeriv(zf, params[3], params[4], params[5], dLdzf); // dfdzf, inplace
|
||
dLdzf *= *cI; // dcdf = cI*clipDeriv
|
||
|
||
// dcdzg = dcde*dedzg, [bS, nOut](or[nOut])
|
||
activationDeriv(zg, params[6], params[7], params[8], dLdzg); // dgdzg, inplace
|
||
dLdzg *= i; // dcdf = i*clipDeriv
|
||
|
||
// dhdzo = dhdo*dodzo = actH(c)*dodzo, [bS, nOut](or[nOut])
|
||
activationDeriv(zo, params[3], params[4], params[5], dLdzo);
|
||
NDArray temp = dLdzo.ulike();
|
||
applyActivation(*c, params[9], params[10], params[11], temp); // actH(c), inplace
|
||
dLdzo *= temp;
|
||
|
||
// dcdcI
|
||
NDArray dcdcI = f.dup(); // dcdcI = f*clipDeriv [bS, nOut](or[nOut])
|
||
|
||
// take into account possible deposit from clipping derivative
|
||
clipDeriv(params[2], *c, dLdzi, dLdzf, dLdzg, dcdcI);
|
||
|
||
// dhdc
|
||
NDArray dhdc = c->ulike();
|
||
activationDeriv(*c, params[9], params[10], params[11], dhdc); // [bS, nOut]
|
||
dhdc *= o;
|
||
|
||
if(Wp) {
|
||
dhdc += dLdzo*(*Wp)({2*nOut, 3*nOut});
|
||
dcdcI += dLdzi*(*Wp)({0, nOut}) + dLdzf*(*Wp)({nOut, 2*nOut}); // broadcast [bS, nOut] * nOut + ...
|
||
}
|
||
|
||
if(dLdh)
|
||
*dLdhI += *dLdh;
|
||
if(dLdhL)
|
||
*dLdhI += *dLdhL;
|
||
if(dLdcL)
|
||
*dLdcI += *dLdcL;
|
||
|
||
*dLdcI += *dLdhI * dhdc;
|
||
|
||
dLdzi *= *dLdcI; // [bS, nOut](or[nOut])
|
||
dLdzf *= *dLdcI; // [bS, nOut](or[nOut])
|
||
dLdzg *= *dLdcI; // [bS, nOut](or[nOut])
|
||
dLdzo *= *dLdhI; // [bS, nOut](or[nOut])
|
||
|
||
// dLdx
|
||
NDArray WxT = Wx->transpose();
|
||
MmulHelper::mmul(&dLdz, &WxT, dLdx); // [bS, 4*nOut] x [4*nOut, nIn] (or [4*nOut] x [4*nOut, nIn]) = [bS, nIn] ( or[nIn] )
|
||
|
||
// dLdhI
|
||
NDArray WrT = Wr->transpose();
|
||
MmulHelper::mmul(&dLdz, &WrT, dLdhI); // [bS, 4*nOut] x [4*nOut, nOut] (or [4*nOut] x [4*nOut, nOut]) = [bS, nOut] ( or[nOut] )
|
||
|
||
// dLdcI
|
||
dLdcI->assign(*dLdcI*dcdcI); // [bS, nOut](or[nOut])
|
||
|
||
if(x->rankOf() == 1) {
|
||
|
||
NDArray xT = x->reshape(x->ordering(),{nIn, 1}); // [nIn] -> [nIn, 1]
|
||
NDArray hIT = hI->reshape(hI->ordering(),{nOut, 1}); // [nOut] -> [nOut, 1]
|
||
NDArray dLdzR = dLdz.reshape(dLdz.ordering(), {1, 4*nOut}); // [nOut] -> [1, 4*nOut]
|
||
|
||
// dLdWx
|
||
*dLdWx += mmul(xT, dLdzR); // [nIn, 1] x [1, 4*nOut] = [nIn, 4*nOut]
|
||
|
||
// dLdWr
|
||
*dLdWr += mmul(hIT, dLdzR); // [nOut, 1] x [1, 4*nOut] = [nOut, 4*nOut]
|
||
}
|
||
else {
|
||
|
||
// dLdWx
|
||
*dLdWx += mmul(x->transpose(), dLdz); // [nIn, bS] x [bS, 4*nOut] = [nIn, 4*nOut]
|
||
|
||
// dLdWr
|
||
*dLdWr += mmul(hI->transpose(), dLdz); // [nOut, bS] x [bS, 4*nOut] = [nOut, 4*nOut]
|
||
}
|
||
|
||
// dLdb
|
||
if(b && x->rankOf() == 1)
|
||
*dLdb += dLdz; // [4*nOut]
|
||
else if(b)
|
||
*dLdb += dLdz.reduceAlongDimension(reduce::Sum, {0}); // [bS, 4*nOut] -> reduce -> [4*nOut];
|
||
|
||
// dLdWp
|
||
if(Wp && x->rankOf() == 1) {
|
||
(*dLdWp)({ 0,nOut}) += std::move(dLdzi)*(*cI); // [nOut]
|
||
(*dLdWp)({ nOut,2*nOut}) += std::move(dLdzf)*(*cI); // [nOut]
|
||
(*dLdWp)({2*nOut,3*nOut}) += std::move(dLdzo)*(*c); // [nOut]
|
||
}
|
||
else if(Wp) {
|
||
NDArray temp(Wp->ordering(), {nOut}, Wp->dataType(), Wp->getContext());
|
||
(std::move(dLdzi)*(*cI)).reduceAlongDimension(reduce::Sum, temp, {0}); // [bS, nOut] -> reduce -> [nOut]
|
||
(*dLdWp)({0,nOut}) += temp;
|
||
(std::move(dLdzf)*(*cI)).reduceAlongDimension(reduce::Sum, temp, {0}); // [bS, nOut] -> reduce -> [nOut]
|
||
(*dLdWp)({nOut,2*nOut}) += temp;
|
||
(std::move(dLdzo)*(*c)).reduceAlongDimension(reduce::Sum, temp, {0}); // [bS, nOut] -> reduce -> [nOut]
|
||
(*dLdWp)({2*nOut,3*nOut}) += temp;
|
||
}
|
||
}
|
||
|
||
//////////////////////////////////////////////////////////////////////////
|
||
void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
|
||
const NDArray* b, const NDArray* seqLen, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
|
||
const std::vector<float>& params,
|
||
const bool forward,
|
||
NDArray* h, NDArray* hL, NDArray* cL) {
|
||
|
||
// INPUTS:
|
||
// x - current input [sL, bS, nIn], [bS, sL, nIn], [bS, nIn, sL],
|
||
// Wx - input weights [nIn, 4*nOut]
|
||
// Wr - recurrent weights [nOut, 4*nOut]
|
||
// b - biases [4*nOut], optional, may be nullptr
|
||
// seqLen - [bS], optional, may be nullptr
|
||
// hI - initial output [bS, nOut], optional, may be nullptr
|
||
// cI - initial cell state at time t-1 [bS, nOut], optional, may be nullptr
|
||
// Wp - peephole weights [3*nOut], optional, may be nullptr
|
||
|
||
// OUTPUTS:
|
||
// h - output [sL, bS, nOut], [bS, sL, nOut], [bS, nOut, sL], optional, may be nullptr
|
||
// hL - output at last step [bS, nOut], optional, may be nullptr
|
||
// cL - cell state at last step [bS, nOut], optional, may be nullptr
|
||
|
||
// params = {dataFormat, directionMode, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta};
|
||
// dataFormat: 0,3 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL]
|
||
|
||
const int dataFormat = params[0];
|
||
const int directionMode = params[1];
|
||
|
||
const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat);
|
||
const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1);
|
||
const Nd4jLong nOut = Wx->sizeAt(-1) / 4;
|
||
|
||
const std::vector<Nd4jLong> shapeOut = {bS, nOut};
|
||
|
||
const auto type = h ? h->dataType() : (hL ? hL->dataType() : cL->dataType());
|
||
|
||
auto h0 = const_cast<NDArray*>(hI);
|
||
if(!hI) {
|
||
h0 = new NDArray(x->ordering(), shapeOut, type, x->getContext());
|
||
h0->nullify();
|
||
}
|
||
|
||
auto c0 = const_cast<NDArray*>(cI);
|
||
if(!cI) {
|
||
c0 = new NDArray(x->ordering(), shapeOut, type, x->getContext());
|
||
c0->nullify();
|
||
}
|
||
|
||
auto ct = cL;
|
||
if(!cL)
|
||
ct = new NDArray(x->ordering(), shapeOut, type, x->getContext());
|
||
|
||
auto ht = hL;
|
||
if(!h && !hL)
|
||
ht = new NDArray(x->ordering(), shapeOut, type, x->getContext());
|
||
|
||
// create sets of required (depends on seqLen presence) sub-arrays
|
||
std::vector<int> dims;
|
||
ResultSet *xSet(nullptr), *hSet(nullptr), *h0Set(nullptr), *c0Set(nullptr), *htSet(nullptr), *ctSet(nullptr);
|
||
|
||
if(!seqLen) {
|
||
|
||
dims = ShapeUtils::evalDimsToExclude(x->rankOf(), {dataFormat < 3 ? dataFormat : 0}); // points on bS and nIn/nOut axes
|
||
|
||
xSet = new ResultSet(x->allTensorsAlongDimension(dims)); // sub-arrays with shape [bS, nIn]
|
||
if(h)
|
||
hSet = new ResultSet(h->allTensorsAlongDimension(dims)); // sub-arrays with shape [bS, nOut]
|
||
}
|
||
else {
|
||
|
||
dims = dataFormat == 2 ? std::vector<int>({1}) : std::vector<int>({2}); // points on nIn/nOut axis
|
||
|
||
xSet = new ResultSet(x->allTensorsAlongDimension(dims)); // sub-arrays with shape [nIn]
|
||
h0Set = new ResultSet(h0->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut]
|
||
c0Set = new ResultSet(c0->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut]
|
||
ctSet = new ResultSet(ct->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut]
|
||
if(h)
|
||
hSet = new ResultSet(h->allTensorsAlongDimension(dims)); // sub-arrays with shape [nOut]
|
||
if(ht)
|
||
htSet = new ResultSet(ht->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut]
|
||
}
|
||
|
||
// loops
|
||
if(forward) {
|
||
|
||
if(!seqLen) {
|
||
|
||
if(!h) { // seqLen and h are absent
|
||
|
||
lstmLayerCell(xSet->at(0), Wx, Wr, b, h0, c0, Wp, params, ht, ct); // first time step
|
||
for (Nd4jLong t = 1; t < sL; ++t)
|
||
lstmLayerCell(xSet->at(t), Wx, Wr, b, ht, ct, Wp, params, ht, ct); // rest time steps
|
||
}
|
||
else { // seqLen is absent and h is present
|
||
|
||
lstmLayerCell(xSet->at(0), Wx, Wr, b, h0, c0, Wp, params, hSet->at(0), ct); // first time step
|
||
for (Nd4jLong t = 1; t < sL; ++t)
|
||
lstmLayerCell(xSet->at(t), Wx, Wr, b, hSet->at(t - 1), ct, Wp, params, hSet->at(t), ct); // rest time steps
|
||
|
||
if(hL)
|
||
hL->assign(hSet->at(sL - 1)); // assign last output to hL if it is not nullptr
|
||
}
|
||
}
|
||
else {
|
||
|
||
if(!h) { // seqLen is present and h is absent
|
||
|
||
for (Nd4jLong e = 0; e < bS; ++e) {
|
||
|
||
const int limit = seqLen->e<int>(e);
|
||
|
||
if(limit == 0) {
|
||
if(cL)
|
||
ctSet->at(e)->nullify();
|
||
if(hL)
|
||
htSet->at(e)->nullify();
|
||
continue;
|
||
}
|
||
|
||
auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, 0, e);
|
||
lstmLayerCell(xSet->at(ind), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // first time step
|
||
|
||
for (int t = 1; t < limit; ++t) {
|
||
ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
|
||
lstmLayerCell(xSet->at(ind), Wx, Wr, b, htSet->at(e), ctSet->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // rest time steps
|
||
}
|
||
}
|
||
}
|
||
else { // seqLen and h are present
|
||
|
||
for (Nd4jLong e = 0; e < bS; ++e) {
|
||
|
||
int limit = seqLen->e<int>(e);
|
||
|
||
if(limit == 0) {
|
||
|
||
tensorAlongTimeBatchDims(*h, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range
|
||
|
||
if(cL)
|
||
ctSet->at(e)->nullify();
|
||
if(hL)
|
||
htSet->at(e)->nullify();
|
||
|
||
continue;
|
||
}
|
||
|
||
auto indPrev = getBatchTimeTotalIndex(dataFormat, sL, bS, 0, e);
|
||
lstmLayerCell(xSet->at(indPrev), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, hSet->at(indPrev), ctSet->at(e)); // first time step
|
||
|
||
for (int t = 1; t < limit; ++t) {
|
||
auto indCurr = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
|
||
lstmLayerCell(xSet->at(indCurr), Wx, Wr, b, hSet->at(indPrev), ctSet->at(e), Wp, params, hSet->at(indCurr), ctSet->at(e)); // rest time steps
|
||
indPrev = indCurr;
|
||
}
|
||
|
||
if(hL)
|
||
htSet->at(e)->assign(hSet->at(indPrev)); // assign last output to hL if hL is not nullptr
|
||
|
||
if(limit != sL)
|
||
tensorAlongTimeBatchDims(*h, dataFormat, limit,sL, e,e+1).nullify(); // nullify for given e and time range [limit, sL)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
else { // backward
|
||
|
||
if(!seqLen) {
|
||
|
||
if(!h) { // seqLen and h are absent
|
||
|
||
lstmLayerCell(xSet->at(sL - 1), Wx, Wr, b, h0, c0, Wp, params, ht, ct); // first time step
|
||
for (Nd4jLong t = sL - 2; t >= 0; --t)
|
||
lstmLayerCell(xSet->at(t), Wx, Wr, b, ht, ct, Wp, params, ht, ct); // rest time steps
|
||
}
|
||
else { // seqLen is absent and h is present
|
||
|
||
lstmLayerCell(xSet->at(sL - 1), Wx, Wr, b, h0, c0, Wp, params, hSet->at(sL - 1), ct); // first time step
|
||
for (Nd4jLong t = sL - 2; t >= 0; --t)
|
||
lstmLayerCell(xSet->at(t), Wx, Wr, b, hSet->at(t + 1), ct, Wp, params, hSet->at(t), ct); // rest time steps
|
||
|
||
if(hL)
|
||
hL->assign(hSet->at(0)); // assign last output to hL if it is not nullptr
|
||
}
|
||
}
|
||
else if(directionMode == 1) { // only backward, no bidirectional mode
|
||
|
||
if(!h) { // h is absent and seqLen is present
|
||
|
||
for (Nd4jLong e = 0; e < bS; ++e) {
|
||
|
||
const int limit = seqLen->e<int>(e);
|
||
|
||
if(limit == 0) {
|
||
if(cL)
|
||
ctSet->at(e)->nullify();
|
||
if(hL)
|
||
htSet->at(e)->nullify();
|
||
continue;
|
||
}
|
||
|
||
auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, sL - 1, e);
|
||
lstmLayerCell(xSet->at(ind), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // first time step
|
||
|
||
for (Nd4jLong t = sL - 2; t >= sL - limit; --t) {
|
||
ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
|
||
lstmLayerCell(xSet->at(ind), Wx, Wr, b, htSet->at(e), ctSet->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // rest time steps
|
||
}
|
||
}
|
||
}
|
||
else { // seqLen and h are present
|
||
|
||
for (Nd4jLong e = 0; e < bS; ++e) {
|
||
|
||
int limit = seqLen->e<int>(e);
|
||
|
||
if(limit == 0) {
|
||
|
||
tensorAlongTimeBatchDims(*h, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range
|
||
|
||
if(cL)
|
||
ctSet->at(e)->nullify();
|
||
if(hL)
|
||
htSet->at(e)->nullify();
|
||
|
||
continue;
|
||
}
|
||
|
||
auto indPrev = getBatchTimeTotalIndex(dataFormat, sL, bS, sL - 1, e);
|
||
lstmLayerCell(xSet->at(indPrev), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, hSet->at(indPrev), ctSet->at(e)); // first time step
|
||
|
||
for (Nd4jLong t = sL - 2; t >= sL - limit; --t) {
|
||
auto indCurr = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
|
||
lstmLayerCell(xSet->at(indCurr), Wx, Wr, b, hSet->at(indPrev), ctSet->at(e), Wp, params, hSet->at(indCurr), ctSet->at(e)); // rest time steps
|
||
indPrev = indCurr;
|
||
}
|
||
|
||
if(hL)
|
||
htSet->at(e)->assign(hSet->at(indPrev)); // assign last output to hL if it is not nullptr
|
||
|
||
if(limit != sL)
|
||
tensorAlongTimeBatchDims(*h, dataFormat, 0,sL-limit, e,e+1).nullify(); // nullify for given e and time range [limit, sL)
|
||
}
|
||
}
|
||
}
|
||
else { // backward in bidirectional mode
|
||
|
||
if(!h) { // h is absent and seqLen is present
|
||
|
||
for (Nd4jLong e = 0; e < bS; ++e) {
|
||
|
||
const int limit = seqLen->e<int>(e);
|
||
|
||
if(limit == 0) {
|
||
if(cL)
|
||
ctSet->at(e)->nullify();
|
||
if(hL)
|
||
htSet->at(e)->nullify();
|
||
continue;
|
||
}
|
||
|
||
auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, limit - 1, e);
|
||
lstmLayerCell(xSet->at(ind), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // first time step
|
||
|
||
for (int t = limit - 2; t >= 0; --t) {
|
||
ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
|
||
lstmLayerCell(xSet->at(ind), Wx, Wr, b, htSet->at(e), ctSet->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // rest time steps
|
||
}
|
||
}
|
||
}
|
||
else { // seqLen and h are present
|
||
|
||
for (Nd4jLong e = 0; e < bS; ++e) {
|
||
|
||
int limit = seqLen->e<int>(e);
|
||
|
||
if(limit == 0) {
|
||
|
||
tensorAlongTimeBatchDims(*h, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range
|
||
|
||
if(cL)
|
||
ctSet->at(e)->nullify();
|
||
if(hL)
|
||
htSet->at(e)->nullify();
|
||
|
||
continue;
|
||
}
|
||
|
||
auto indPrev = getBatchTimeTotalIndex(dataFormat, sL, bS, limit - 1, e);
|
||
lstmLayerCell(xSet->at(indPrev), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, hSet->at(indPrev), ctSet->at(e)); // first time step
|
||
|
||
for (int t = limit - 2; t >= 0; --t) {
|
||
auto indCurr = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
|
||
lstmLayerCell(xSet->at(indCurr), Wx, Wr, b, hSet->at(indPrev), ctSet->at(e), Wp, params, hSet->at(indCurr), ctSet->at(e)); // rest time steps
|
||
indPrev = indCurr;
|
||
}
|
||
|
||
if(hL)
|
||
htSet->at(e)->assign(hSet->at(indPrev)); // assign last output to hL if it is not nullptr
|
||
|
||
if(limit != sL)
|
||
tensorAlongTimeBatchDims(*h, dataFormat, limit,sL, e,e+1).nullify(); // nullify for given e and time range [limit, sL)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
delete xSet;
|
||
delete hSet;
|
||
delete h0Set;
|
||
delete c0Set;
|
||
delete htSet;
|
||
delete ctSet;
|
||
|
||
if(!hI)
|
||
delete h0;
|
||
if(!cI)
|
||
delete c0;
|
||
if(!cL)
|
||
delete ct;
|
||
if(!h && !hL)
|
||
delete ht;
|
||
}
|
||
|
||
|
||
//////////////////////////////////////////////////////////////////////////
|
||
void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
|
||
const NDArray* b, const NDArray* seqLen, NDArray* hI, NDArray* cI, const NDArray* Wp,
|
||
const NDArray* dLdh, const NDArray* dLdhL, const NDArray* dLdcL,
|
||
const std::vector<float>& params, const bool forward,
|
||
NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdb, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdWp) {
|
||
|
||
// INPUTS:
|
||
// x - current input [sL, bS, nIn], [bS, sL, nIn], [bS, nIn, sL],
|
||
// Wx - input weights [nIn, 4*nOut]
|
||
// Wr - recurrent weights [nOut, 4*nOut]
|
||
// b - biases [4*nOut], optional, may be nullptr
|
||
// seqLen - [bS], optional, may be nullptr
|
||
// hI - initial output [bS, nOut], optional, may be nullptr
|
||
// cI - initial cell state at time t-1 [bS, nOut], optional, may be nullptr
|
||
// Wp - peephole weights [3*nOut], optional, may be nullptr
|
||
// dLdh - gradient vs. output [sL, bS, nOut], [bS, sL, nOut], [bS, nOut, sL], optional, may be nullptr
|
||
// dLdhL - gradient vs. output at last time step [bS, nOut], optional, may be nullptr
|
||
// dLdcL - gradient vs. cell state at last time step [bS, nOut], optional, may be nullptr
|
||
|
||
// OUTPUTS:
|
||
// dLdx - gradient vs. input [sL, bS, nIn], [bS, sL, nIn], [bS, nIn, sL]
|
||
// dLdWx - gradient vs. input weights [nIn, 4*nOut]
|
||
// dLdWr - gradient vs. recurrent weights [nOut, 4*nOut]
|
||
// dLdb - gradient vs. biases [4*nOut], optional, may be nullptr
|
||
// dLdhI - gradient vs. initial output [bS, nOut], optional, may be nullptr
|
||
// dLdcI - gradient vs. initial cell state at time t-1 [bS, nOut], optional, may be nullptr
|
||
// dLdWp - gradient vs. peephole weights [3*nOut], optional, may be nullptr
|
||
|
||
// params = {dataFormat, directionMode, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta};
|
||
// dataFormat: 0,3 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL]
|
||
|
||
const int dataFormat = params[0];
|
||
const int directionMode = params[1];
|
||
|
||
const int sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat);
|
||
const int bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1);
|
||
const int nOut = Wx->sizeAt(-1) / 4;
|
||
|
||
const auto type = dLdh ? dLdh->dataType() : (dLdhL ? dLdhL->dataType() : dLdcL->dataType());
|
||
|
||
auto dLdh0 = dLdhI;
|
||
if(!hI)
|
||
dLdh0 = new NDArray(x->ordering(), {bS, nOut}, type, x->getContext()); // this constructor nullifies array automatically
|
||
|
||
auto dLdc0 = dLdcI;
|
||
if(!cI)
|
||
dLdc0 = new NDArray(x->ordering(), {bS, nOut}, type, x->getContext()); // this constructor nullifies array automatically
|
||
|
||
NDArray z(x->ordering(), {sL, bS, 4*nOut}, type, x->getContext());
|
||
NDArray a = z.ulike();
|
||
NDArray h(x->ordering(), {sL+1, bS, nOut}, type, x->getContext());
|
||
NDArray c = h.ulike();
|
||
|
||
// create sets of required (depends on seqLen presence) sub-arrays
|
||
std::vector<int> dims;
|
||
ResultSet *xSet(nullptr), *dLdxSet(nullptr), *hSet(nullptr), *cSet(nullptr), *zSet(nullptr), *aSet(nullptr), *dLdhSet(nullptr),
|
||
*dLdh0Set(nullptr), *dLdc0Set(nullptr), *dLdhLSet(nullptr), *dLdcLSet(nullptr), *hISet(nullptr), *cISet(nullptr);
|
||
|
||
if(!seqLen) {
|
||
|
||
dims = ShapeUtils::evalDimsToExclude(x->rankOf(), {dataFormat < 3 ? dataFormat : 0}); // points on [bS, nIn/nOut]
|
||
|
||
xSet = new ResultSet(x->allTensorsAlongDimension(dims)); // sub-arrays with shape [bS, nIn]
|
||
dLdxSet = new ResultSet(dLdx->allTensorsAlongDimension(dims)); // sub-arrays with shape [bS, nIn]
|
||
hSet = new ResultSet(h.allTensorsAlongDimension({1, 2})); // sub-arrays with shape [bS, nOut]
|
||
cSet = new ResultSet(c.allTensorsAlongDimension({1, 2})); // sub-arrays with shape [bS, nOut]
|
||
zSet = new ResultSet(z.allTensorsAlongDimension({1, 2})); // sub-arrays with shape [bS, 4*nOut]
|
||
aSet = new ResultSet(a.allTensorsAlongDimension({1, 2})); // sub-arrays with shape [bS, 4*nOut]
|
||
if(dLdh)
|
||
dLdhSet = new ResultSet(dLdh->allTensorsAlongDimension(dims)); // sub-arrays with shape [bS, nOut]
|
||
}
|
||
else {
|
||
|
||
dims = dataFormat == 2 ? std::vector<int>({1}) : std::vector<int>({2}); // points on nIn/nOut axis
|
||
|
||
xSet = new ResultSet(x->allTensorsAlongDimension(dims)); // sub-arrays with shape [nIn]
|
||
dLdxSet = new ResultSet(dLdx->allTensorsAlongDimension(dims)); // sub-arrays with shape [nIn]
|
||
hSet = new ResultSet(h.allTensorsAlongDimension({2})); // sub-arrays with shape [nOut]
|
||
cSet = new ResultSet(c.allTensorsAlongDimension({2})); // sub-arrays with shape [nOut]
|
||
zSet = new ResultSet(z.allTensorsAlongDimension({2})); // sub-arrays with shape [4*nOut]
|
||
aSet = new ResultSet(a.allTensorsAlongDimension({2})); // sub-arrays with shape [4*nOut]
|
||
|
||
if(hI)
|
||
hISet = new ResultSet(hI->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut]
|
||
if(cI)
|
||
cISet = new ResultSet(cI->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut]
|
||
|
||
dLdh0Set = new ResultSet(dLdh0->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut]
|
||
dLdc0Set = new ResultSet(dLdc0->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut]
|
||
|
||
if(dLdh)
|
||
dLdhSet = new ResultSet(dLdh->allTensorsAlongDimension(dims)); // sub-arrays with shape [nOut]
|
||
if(dLdhL)
|
||
dLdhLSet = new ResultSet(dLdhL->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut]
|
||
if(dLdcL)
|
||
dLdcLSet = new ResultSet(dLdcL->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut]
|
||
}
|
||
|
||
|
||
// loops
|
||
if(forward) {
|
||
|
||
if(!seqLen) { // seqLen is absent
|
||
|
||
if(hI)
|
||
hSet->at(0)->assign(hI);
|
||
else
|
||
hSet->at(0)->nullify();
|
||
if(cI)
|
||
cSet->at(0)->assign(cI);
|
||
else
|
||
cSet->at(0)->nullify();
|
||
|
||
// ff
|
||
for (int t = 0; t < sL; ++t)
|
||
lstmLayerCell(xSet->at(t), Wx, Wr, b, hSet->at(t), cSet->at(t), Wp, params, zSet->at(t), aSet->at(t), hSet->at(t+1), cSet->at(t+1));
|
||
|
||
// bp
|
||
for (int t = sL-1; t >= 0; --t) {
|
||
const NDArray* dLdhh = dLdh ? dLdhSet->at(t) : nullptr;
|
||
const NDArray* dLdhhL = (t == sL-1 && dLdhL) ? dLdhL : nullptr;
|
||
const NDArray* dLdccL = (t == sL-1 && dLdcL) ? dLdcL : nullptr;
|
||
lstmLayerCellBp(xSet->at(t), Wx, Wr, b, hSet->at(t), cSet->at(t), Wp, dLdhh, dLdhhL, dLdccL,
|
||
zSet->at(t), aSet->at(t), cSet->at(t+1), params, dLdxSet->at(t), dLdWx, dLdWr, dLdh0, dLdc0, dLdb, dLdWp);
|
||
}
|
||
}
|
||
else { // seqLen is present
|
||
|
||
for (int e = 0; e < bS; ++e) {
|
||
|
||
const int limit = seqLen->e<int>(e);
|
||
|
||
if(limit == 0) {
|
||
tensorAlongTimeBatchDims(*dLdx, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range
|
||
continue;
|
||
}
|
||
|
||
if(hI)
|
||
hSet->at(e)->assign(hISet->at(e));
|
||
else
|
||
hSet->at(e)->nullify();
|
||
if(cI)
|
||
cSet->at(e)->assign(cISet->at(e));
|
||
else
|
||
cSet->at(e)->nullify();
|
||
|
||
// ff
|
||
for (int t = 0; t < limit; ++t)
|
||
lstmLayerCell(xSet->at(getBatchTimeTotalIndex(dataFormat, sL, bS, t, e)), Wx, Wr, b, hSet->at(t*bS + e), cSet->at(t*bS + e), Wp, params,
|
||
zSet->at(t*bS + e), aSet->at(t*bS + e), hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e));
|
||
|
||
// bp
|
||
for (int t = limit-1; t >= 0; --t) {
|
||
const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
|
||
const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : nullptr;
|
||
const NDArray* dLdhhL = (t == limit-1 && dLdhL) ? dLdhLSet->at(e) : nullptr;
|
||
const NDArray* dLdccL = (t == limit-1 && dLdcL) ? dLdcLSet->at(e) : nullptr;
|
||
lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at(t*bS + e), cSet->at(t*bS + e), Wp, dLdhh, dLdhhL, dLdccL,
|
||
zSet->at(t*bS + e), aSet->at(t*bS + e), cSet->at((t+1)*bS + e), params, dLdxSet->at(ind), dLdWx, dLdWr,
|
||
dLdh0Set->at(e), dLdc0Set->at(e), dLdb, dLdWp);
|
||
}
|
||
|
||
if(limit != sL)
|
||
tensorAlongTimeBatchDims(*dLdx, dataFormat, limit,sL, e,e+1).nullify(); // nullify for given e and time range [limit, sL)
|
||
}
|
||
}
|
||
}
|
||
else { // backward or bidirectional
|
||
|
||
if(!seqLen) { // backward or bidirectional, seqLen is absent
|
||
|
||
if(hI)
|
||
hSet->at(sL)->assign(hI);
|
||
else
|
||
hSet->at(sL)->nullify();
|
||
if(cI)
|
||
cSet->at(sL)->assign(cI);
|
||
else
|
||
cSet->at(sL)->nullify();
|
||
|
||
// ff
|
||
for (int t = sL-1; t >= 0; --t)
|
||
lstmLayerCell(xSet->at(t), Wx, Wr, b, hSet->at(t+1), cSet->at(t+1), Wp, params, zSet->at(t), aSet->at(t), hSet->at(t), cSet->at(t));
|
||
|
||
// bp
|
||
for (int t = 0; t < sL; ++t) {
|
||
const NDArray* dLdhh = dLdh ? dLdhSet->at(t) : nullptr;
|
||
const NDArray* dLdhhL = (t == 0 && dLdhL) ? dLdhL : nullptr;
|
||
const NDArray* dLdccL = (t == 0 && dLdcL) ? dLdcL : nullptr;
|
||
lstmLayerCellBp(xSet->at(t), Wx, Wr, b, hSet->at(t+1), cSet->at(t+1), Wp, dLdhh, dLdhhL, dLdccL,
|
||
zSet->at(t), aSet->at(t), cSet->at(t), params, dLdxSet->at(t), dLdWx, dLdWr, dLdh0, dLdc0, dLdb, dLdWp);
|
||
}
|
||
}
|
||
else if(directionMode == 1) { // backward, seqLen is present
|
||
|
||
for (int e = 0; e < bS; ++e) {
|
||
|
||
const int limit = seqLen->e<int>(e);
|
||
|
||
if(limit == 0) {
|
||
tensorAlongTimeBatchDims(*dLdx, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range
|
||
continue;
|
||
}
|
||
|
||
if(hI)
|
||
hSet->at(sL*bS + e)->assign(hISet->at(e));
|
||
else
|
||
hSet->at(sL*bS + e)->nullify();
|
||
if(cI)
|
||
cSet->at(sL*bS + e)->assign(cISet->at(e));
|
||
else
|
||
cSet->at(sL*bS + e)->nullify();
|
||
|
||
// ff
|
||
for (int t = sL - 1; t >= sL-limit; --t)
|
||
lstmLayerCell(xSet->at(getBatchTimeTotalIndex(dataFormat, sL, bS, t, e)), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, params,
|
||
zSet->at(t*bS + e), aSet->at(t*bS + e), hSet->at(t*bS + e), cSet->at(t*bS + e));
|
||
|
||
// bp
|
||
for (int t = sL-limit; t < sL; ++t) {
|
||
const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
|
||
const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : nullptr;
|
||
const NDArray* dLdhhL = (t == sL-limit && dLdhL) ? dLdhLSet->at(e) : nullptr;
|
||
const NDArray* dLdccL = (t == sL-limit && dLdcL) ? dLdcLSet->at(e) : nullptr;
|
||
lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, dLdhh, dLdhhL, dLdccL,
|
||
zSet->at(t*bS + e), aSet->at(t*bS + e), cSet->at(t*bS + e), params, dLdxSet->at(ind), dLdWx, dLdWr,
|
||
dLdh0Set->at(e), dLdc0Set->at(e), dLdb, dLdWp);
|
||
}
|
||
|
||
if(limit != sL)
|
||
tensorAlongTimeBatchDims(*dLdx, dataFormat, 0,sL-limit, e,e+1).nullify(); // nullify for given e and time range [limit, sL)
|
||
}
|
||
}
|
||
else { // bidirectional mode, seqLen is present
|
||
|
||
for (int e = 0; e < bS; ++e) {
|
||
|
||
const int limit = seqLen->e<int>(e);
|
||
|
||
if(limit == 0) {
|
||
tensorAlongTimeBatchDims(*dLdx, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range
|
||
continue;
|
||
}
|
||
|
||
if(hI)
|
||
h({limit,limit+1, e,e+1, 0,0}).assign(hISet->at(e));
|
||
else
|
||
h({limit,limit+1, e,e+1, 0,0}).nullify();
|
||
if(cI)
|
||
c({limit,limit+1, e,e+1, 0,0}).assign(cISet->at(e));
|
||
else
|
||
c({limit,limit+1, e,e+1, 0,0}).nullify();
|
||
|
||
// ff
|
||
for (int t = limit - 1; t >= 0; --t)
|
||
lstmLayerCell(xSet->at(getBatchTimeTotalIndex(dataFormat, sL, bS, t, e)), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, params,
|
||
zSet->at(t*bS + e), aSet->at(t*bS + e), hSet->at(t*bS + e), cSet->at(t*bS + e));
|
||
|
||
// bp
|
||
for (int t = 0; t < limit; ++t) {
|
||
const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
|
||
const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : nullptr;
|
||
const NDArray* dLdhhL = (t == 0 && dLdhL) ? dLdhLSet->at(e) : nullptr;
|
||
const NDArray* dLdccL = (t == 0 && dLdcL) ? dLdcLSet->at(e) : nullptr;
|
||
lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, dLdhh, dLdhhL, dLdccL,
|
||
zSet->at(t*bS + e), aSet->at(t*bS + e), cSet->at(t*bS + e), params, dLdxSet->at(ind), dLdWx, dLdWr,
|
||
dLdh0Set->at(e), dLdc0Set->at(e), dLdb, dLdWp);
|
||
}
|
||
|
||
if(limit != sL)
|
||
tensorAlongTimeBatchDims(*dLdx, dataFormat, limit,sL, e,e+1).nullify(); // nullify for given e and time range [limit, sL)
|
||
}
|
||
}
|
||
}
|
||
|
||
delete xSet; delete dLdxSet; delete hSet; delete cSet; delete aSet; delete zSet;
|
||
delete dLdhSet; delete dLdh0Set; delete dLdc0Set; delete dLdhLSet; delete dLdcLSet; delete hISet; delete cISet;
|
||
|
||
if(!hI)
|
||
delete dLdh0;
|
||
if(!cI)
|
||
delete dLdc0;
|
||
}
|
||
|
||
|
||
}
|
||
}
|
||
}
|
||
|
||
|
||
|
||
//////////////////////////////////////////////////////////////////////////
|
||
// void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
|
||
// const NDArray* b, NDArray* hI, NDArray* cI, const NDArray* Wp, const NDArray* dLdh,
|
||
// const std::vector<float>& params, const bool firstIter,
|
||
|
||
// NDArray* dhIdcI, NDArray* dhIdWx, NDArray* dcIdWx, NDArray* dhIdWr, NDArray* dcIdWr,
|
||
// NDArray* dhIdb, NDArray* dcIdb, NDArray* dhIdWp, NDArray* dcIdWp,
|
||
// NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdb, NDArray* dLdWp) {
|
||
|
||
// /************************ THIS IS NOT OPTIMIZED CODE ***********************************/
|
||
// /** the objective is to provide math-readable code **/
|
||
|
||
// // equations (no peephole connections)
|
||
// // zi = x × Wxi + hI × Wri + bi
|
||
// // zf = x × Wxf + hI × Wrf + bf
|
||
// // zg = x × Wxg + hI × Wrg + bg
|
||
// // zo = x × Wxo + hI × Wro + bo
|
||
// // i = act(zi)
|
||
// // f = act(zf)
|
||
// // g = actC(zg)
|
||
// // o = act(zo)
|
||
// // c = clip(f * cI + i * g)
|
||
// // h = o * actH(c)
|
||
|
||
// // equations (peephole connections are present)
|
||
// // zi = x × Wxi + hI × Wri + cI * Wpi + bi
|
||
// // zf = x × Wxf + hI × Wrf + cI * Wpf + bf
|
||
// // zg = x × Wxg + hI × Wrg + bg
|
||
// // zo = x × Wxo + hI × Wro + c * Wpo + bo
|
||
// // i = act(zi)
|
||
// // f = act(zf)
|
||
// // g = actC(zg)
|
||
// // o = act(zo)
|
||
// // c = clip(f * cI + i * g)
|
||
// // h = o * actH(c)
|
||
|
||
// // IDs for activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
|
||
|
||
// // params[0] - dataFormat, ignore
|
||
// // params[1] - directionMode, ignore
|
||
// // params[2] - cell clipping value, if it = 0 then do not apply clipping
|
||
|
||
// // params[3] - activation ID for input (i), forget (f) and output (o) gates
|
||
// // params[4] - alpha value for gates activation
|
||
// // params[5] - beta value for gates activation
|
||
|
||
// // params[6] - activation ID for cell state (c)
|
||
// // params[7] - alpha value for cell state activation
|
||
// // params[8] - beta value for cell state activation
|
||
|
||
// // params[9] - activation ID for output (h)
|
||
// // params[10] - alpha value for output activation
|
||
// // params[11] - beta value for output activation
|
||
|
||
// // INPUTS:
|
||
// // x - current input at time t, [bS, nIn] or [nIn] if seqLen != nullptr
|
||
// // Wx - input weights [nIn, 4*nOut]
|
||
// // Wr - recurrent weights [nOut, 4*nOut]
|
||
// // b - biases [4*nOut], optional, may be nullptr
|
||
// // hI - (ht-1) previous (initial) output at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr
|
||
// // cI - (ct-1) previous (initial) cell state at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr
|
||
// // Wp - peephole weights [3*nOut], optional, may be nullptr
|
||
// // dLdh - loss derivative with respect to h, [bS, nOut] or [nOut] if seqLen != nullptr
|
||
// // dhIdcI - derivative from previous time step, [bS, nOut] or [nOut] if seqLen != nullptr
|
||
// // dhIdWx - derivative from previous time step (Jacobian), [nIn, 4*nOut, bS, nOut] or [nIn, 4*nOut, nOut] if seqLen != nullptr
|
||
// // dcIdWx - derivative from previous time step (Jacobian), [nIn, 4*nOut, bS, nOut] or [nIn, 4*nOut, nOut] if seqLen != nullptr
|
||
// // dhIdWr - derivative from previous time step (Jacobian), [nOut, 4*nOut, bS, nOut] or [nOut, 4*nOut, nOut] if seqLen != nullptr
|
||
// // dcIdWr - derivative from previous time step (Jacobian), [nOut, 4*nOut, bS, nOut] or [nOut, 4*nOut, nOut] if seqLen != nullptr
|
||
// // dcIdWp - derivative from previous time step, [3*nOut], optional, may be nullptr
|
||
// // dhIdWp - derivative from previous time step, [3*nOut], optional, may be nullptr
|
||
// // dcIdb - derivative from previous time step, [4*nOut], optional, may be nullptr
|
||
// // dhIdb - derivative from previous time step, [4*nOut], optional, may be nullptr
|
||
|
||
// // OUTPUTS:
|
||
// // dLdx - loss derivative with respect to x, [bS, nIn] or [nIn] if seqLen != nullptr
|
||
// // dLdWx - loss derivative with respect to Wx, [nIn, 4*nOut]
|
||
// // dLdWr - loss derivative with respect to Wr, [nOut, 4*nOut]
|
||
// // dLdb - loss derivative with respect to b, optional, may be nullptr, [4*nOut]
|
||
// // dLdhI - loss derivative with respect to hI, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr
|
||
// // dLdcI - loss derivative with respect to cI, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr
|
||
// // dLdWp - loss derivative with respect to Wp, optional, may be nullptr, [3*nOut]
|
||
|
||
// // !!! dimension 4*nOut implies order i, f, g, o
|
||
// // !!! dimension 3*nOut implies order i, f, o
|
||
|
||
// // dcdzi = dcdi*didzi
|
||
// // dcdzf = dcdf*dfdzf
|
||
// // dcdzg = dcdg*dgdzg
|
||
// // dhdzo = dhdo*dodzo
|
||
|
||
// // dhdc = dhdc + Wp ? dhdzo*dzodc : 0 [bS, nOut]
|
||
// // factor = dLdh*dhdc [bS, nOut]
|
||
// // iFactor = factor*dcdzi [bS, nOut]
|
||
// // fFactor = factor*dcdzf [bS, nOut]
|
||
// // eFactor = factor*dcdzg [bS, nOut]
|
||
// // oFactor = dLdh*dhdzo [bS, nOut]
|
||
|
||
// // tempC = dcdcI + Wp ? dcdzi*dzidcI + dcdzf*dzfdcI : 0;
|
||
// // tempIFE = dcdzi×WriT + dcdzf×WrfT + dcdzg×WrgT
|
||
// // tempO = dhdzo×WroT
|
||
|
||
// // dhIdcI = dhdc_from_previous_time_step
|
||
|
||
// // dLdx = iFactor×WxiT + fFactor×WxfT + eFactor×WxgT + oFactor×WxoT, [bS, nIn]
|
||
// // dLdhI = iFactor×WriT + fFactor×WrfT + eFactor×WrgT + oFactor×WroT, [bS, nOut]
|
||
// // dLdcI = factor*tempC + dLdhI * dhIdcI, dhIdcI=0 if firstIter, [bS, nOut]
|
||
|
||
// // dcdWxi(dcIdWxi) = dcdzi*dzidWxi + tempIFE*dhIdWxi + tempC*dcIdWxi, dcIdWxi=dhIdWxi= 0 if firstIter, [nIn, nOut, bS, nOut]
|
||
// // dcdWxf(dcIdWxf) = dcdzf*dzfdWxf + tempIFE*dhIdWxf + tempC*dcIdWxf, dcIdWxf=dhIdWxf= 0 if firstIter, [nIn, nOut, bS, nOut]
|
||
// // dcdWxg(dcIdWxg) = dcdzg*dzgdWxg + tempIFE*dhIdWxg + tempC*dcIdWxg, dcIdWxg=dhIdWxg= 0 if firstIter, [nIn, nOut, bS, nOut]
|
||
// // dcdWxo(dcIdWxo) = 0 + tempIFE*dhIdWxo + tempC*dcIdWxo; dcIdWxo=dhIdWxo= 0 if firstIter, [nIn, nOut, bS, nOut]
|
||
|
||
// // dhdWxi(dhIdWxi) = 0 + dhdc*dcdWxi + tempO*dhIdWxi, dhIdWxi= 0 if firstIter, [nIn, nOut, bS, nOut]
|
||
// // dhdWxf(dhIdWxf) = 0 + dhdc*dcdWxf + tempO*dhIdWxf, dhIdWxf= 0 if firstIter, [nIn, nOut, bS, nOut]
|
||
// // dhdWxg(dhIdWxg) = 0 + dhdc*dcdWxg + tempO*dhIdWxg, dhIdWxg= 0 if firstIter, [nIn, nOut, bS, nOut]
|
||
// // dhdWxo(dhIdWxo) = dhdzo*dzodWxo + dhdc*dcdWxo + tempO*dhIdWxo, dhIdWxo= 0 if firstIter, [nIn, nOut, bS, nOut]
|
||
|
||
// // dhdWri(dhIdWri) = 0 + dhdc*dcdWri + tempO*dhIdWri, dhIdWri= 0 if firstIter, [nOut, nOut, bS, nOut]
|
||
// // dhdWrf(dhIdWrf) = 0 + dhdc*dcdWrf + tempO*dhIdWrf, dhIdWrf= 0 if firstIter, [nOut, nOut, bS, nOut]
|
||
// // dhdWrg(dhIdWrg) = 0 + dhdc*dcdWrg + tempO*dhIdWrg, dhIdWrg= 0 if firstIter, [nOut, nOut, bS, nOut]
|
||
// // dhdWro(dhIdWro) = dhdzo*dzodWro + dhdc*dcdWro + tempO*dhIdWro, dhIdWro= 0 if firstIter, [nOut, nOut, bS, nOut]
|
||
|
||
// // dcdWri(dcIdWri) = dcdzi*dzidWri + tempIFE*dhIdWri + tempC*dcIdWri, dcIdWri=dhIdWri= 0 if firstIter, [nOut, nOut, bS, nOut]
|
||
// // dcdWrf(dcIdWrf) = dcdzf*dzfdWrf + tempIFE*dhIdWrf + tempC*dcIdWrf, dcIdWrf=dhIdWrf= 0 if firstIter, [nOut, nOut, bS, nOut]
|
||
// // dcdWrg(dcIdWrg) = dcdzg*dzgdWrg + tempIFE*dhIdWrg + tempC*dcIdWrg, dcIdWrg=dhIdWrg= 0 if firstIter, [nOut, nOut, bS, nOut]
|
||
// // dcdWro(dcIdWro) = 0 + tempIFE*dhIdWro + tempC*dcIdWro; dcIdWro=dhIdWro= 0 if firstIter, [nOut, nOut, bS, nOut]
|
||
|
||
// // dcIdWpi = (dcdzi*cI + tempIFE*dhIdWpi + tempC*dcIdWpi).reduceALongFirstDim, dcIdWpi=dhIdWpi= 0 if firstIter, [bS, nOut]->reduce->[bS]
|
||
// // dcIdWpf = (dcdzf*cI + tempIFE*dhIdWpf + tempC*dcIdWpf).reduceALongFirstDim, dcIdWpf=dhIdWpf= 0 if firstIter, [bS, nOut]->reduce->[bS]
|
||
// // dcIdWpo = (0 + tempIFE*dhIdWpo + tempC*dcIdWpo).reduceALongFirstDim, dcIdWpo=dhIdWpo= 0 if firstIter, [bS, nOut]->reduce->[bS]
|
||
|
||
// // dhdWpi(dhIdWpi) =( 0 + dhdc*dcdWpi + tempO*dhIdWpi).reduceALongFirstDim, dhIdWpi= 0 if firstIter, [bS, nOut]->reduce->[bS]
|
||
// // dhdWpf(dhIdWpf) =( 0 + dhdc*dcdWpf + tempO*dhIdWpf).reduceALongFirstDim, dhIdWpf= 0 if firstIter, [bS, nOut]->reduce->[bS]
|
||
// // dhdWpo(dhIdWpo) =(dhdzo*c + dhdc*dcdWpo + tempO*dhIdWpo).reduceALongFirstDim, dhIdWpo= 0 if firstIter, [bS, nOut]->reduce->[bS]
|
||
|
||
// // dcdbi(dcIdbi) = (dcdzi + tempIFE*dhIdbi + tempC*dcIdbi).reduceALongFirstDim, dcIdbi=dhIdbi= 0 if firstIter, [bS, nOut]->reduce->[bS]
|
||
// // dcdbf(dcIdbf) = (dcdzf + tempIFE*dhIdbf + tempC*dcIdbf).reduceALongFirstDim, dcIdbf=dhIdbf= 0 if firstIter, [bS, nOut]->reduce->[bS]
|
||
// // dcdbg(dcIdbg) = (dcdzg + tempIFE*dhIdbg + tempC*dcIdbg).reduceALongFirstDim, dcIdbg=dhIdbg= 0 if firstIter, [bS, nOut]->reduce->[bS]
|
||
// // dcdbo(dcIdbo) = ( 0 + tempIFE*dhIdbo + tempC*dcIdbo).reduceALongFirstDim; dcIdbo=dhIdbo= 0 if firstIter, [bS, nOut]->reduce->[bS]
|
||
|
||
// // dhdbi(dhIdbi) = ( 0 + dhdc*dcdbi + tempO*dhIdbi).reduceALongFirstDim, dhIdbi= 0 if firstIter, [bS, nOut]->reduce->[bS]
|
||
// // dhdbf(dhIdbf) = ( 0 + dhdc*dcdbf + tempO*dhIdbf).reduceALongFirstDim, dhIdbf= 0 if firstIter, [bS, nOut]->reduce->[bS]
|
||
// // dhdbg(dhIdbg) = ( 0 + dhdc*dcdbg + tempO*dhIdbg).reduceALongFirstDim, dhIdbg= 0 if firstIter, [bS, nOut]->reduce->[bS]
|
||
// // dhdbo(dhIdbo) = (dhdzo + dhdc*dcdbo + tempO*dhIdbo).reduceALongFirstDim, dhIdbo= 0 if firstIter, [bS, nOut]->reduce->[bS]
|
||
|
||
// const Nd4jLong nOut = Wx->sizeAt(-1) / 4;
|
||
|
||
// NDArray *Wpi(nullptr), *Wpf(nullptr), *Wpo(nullptr), *dcIdWpi(nullptr), *dcIdWpf(nullptr), *dcIdWpo(nullptr), *dhIdWpi(nullptr), *dhIdWpf(nullptr), *dhIdWpo(nullptr);
|
||
// if(Wp) {
|
||
// Wpi = new NDArray((*Wp)({0, nOut}));
|
||
// Wpf = new NDArray((*Wp)({nOut, 2*nOut}));
|
||
// Wpo = new NDArray((*Wp)({2*nOut, 3*nOut}));
|
||
// dhIdWpi = new NDArray((*dhIdWp)({0, nOut}));
|
||
// dhIdWpf = new NDArray((*dhIdWp)({nOut, 2*nOut}));
|
||
// dhIdWpo = new NDArray((*dhIdWp)({2*nOut, 3*nOut}));
|
||
// dcIdWpi = new NDArray((*dcIdWp)({0, nOut}));
|
||
// dcIdWpf = new NDArray((*dcIdWp)({nOut, 2*nOut}));
|
||
// dcIdWpo = new NDArray((*dcIdWp)({2*nOut, 3*nOut}));
|
||
// }
|
||
|
||
// NDArray *dcIdbi(nullptr), *dcIdbf(nullptr), *dcIdbg(nullptr), *dcIdbo(nullptr), *dhIdbi(nullptr), *dhIdbf(nullptr), *dhIdbg(nullptr), *dhIdbo(nullptr);
|
||
// if(b) {
|
||
// dhIdbi = new NDArray((*dhIdb)({0, nOut}));
|
||
// dhIdbf = new NDArray((*dhIdb)({nOut, 2*nOut}));
|
||
// dhIdbg = new NDArray((*dhIdb)({2*nOut, 3*nOut}));
|
||
// dhIdbo = new NDArray((*dhIdb)({3*nOut, 4*nOut}));
|
||
// dcIdbi = new NDArray((*dcIdb)({0, nOut}));
|
||
// dcIdbf = new NDArray((*dcIdb)({nOut, 2*nOut}));
|
||
// dcIdbg = new NDArray((*dcIdb)({2*nOut, 3*nOut}));
|
||
// dcIdbo = new NDArray((*dcIdb)({3*nOut, 4*nOut}));
|
||
// }
|
||
|
||
// NDArray dhIdWxi = x->rankOf() == 1 ? (*dhIdWx)({0,0, 0,nOut, 0,0}) : (*dhIdWx)({0,0, 0,nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr
|
||
// NDArray dhIdWxf = x->rankOf() == 1 ? (*dhIdWx)({0,0, nOut,2*nOut, 0,0}) : (*dhIdWx)({0,0, nOut,2*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr
|
||
// NDArray dhIdWxg = x->rankOf() == 1 ? (*dhIdWx)({0,0, 2*nOut,3*nOut, 0,0}) : (*dhIdWx)({0,0, 2*nOut,3*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr
|
||
// NDArray dhIdWxo = x->rankOf() == 1 ? (*dhIdWx)({0,0, 3*nOut,4*nOut, 0,0}) : (*dhIdWx)({0,0, 3*nOut,4*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr
|
||
|
||
// NDArray dhIdWri = x->rankOf() == 1 ? (*dhIdWr)({0,0, 0,nOut, 0,0}) : (*dhIdWr)({0,0, 0,nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr
|
||
// NDArray dhIdWrf = x->rankOf() == 1 ? (*dhIdWr)({0,0, nOut,2*nOut, 0,0}) : (*dhIdWr)({0,0, nOut,2*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr
|
||
// NDArray dhIdWrg = x->rankOf() == 1 ? (*dhIdWr)({0,0, 2*nOut,3*nOut, 0,0}) : (*dhIdWr)({0,0, 2*nOut,3*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr
|
||
// NDArray dhIdWro = x->rankOf() == 1 ? (*dhIdWr)({0,0, 3*nOut,4*nOut, 0,0}) : (*dhIdWr)({0,0, 3*nOut,4*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr
|
||
|
||
// NDArray dcIdWxi = x->rankOf() == 1 ? (*dcIdWx)({0,0, 0,nOut, 0,0}) : (*dcIdWx)({0,0, 0,nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr
|
||
// NDArray dcIdWxf = x->rankOf() == 1 ? (*dcIdWx)({0,0, nOut,2*nOut, 0,0}) : (*dcIdWx)({0,0, nOut,2*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr
|
||
// NDArray dcIdWxg = x->rankOf() == 1 ? (*dcIdWx)({0,0, 2*nOut,3*nOut, 0,0}) : (*dcIdWx)({0,0, 2*nOut,3*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr
|
||
// NDArray dcIdWxo = x->rankOf() == 1 ? (*dcIdWx)({0,0, 3*nOut,4*nOut, 0,0}) : (*dcIdWx)({0,0, 3*nOut,4*nOut, 0,0, 0,0}); // [nIn, nOut, bS, nOut] or [nIn, nOut, nOut] if seqLen != nullptr
|
||
|
||
// NDArray dcIdWri = x->rankOf() == 1 ? (*dcIdWr)({0,0, 0,nOut, 0,0}) : (*dcIdWr)({0,0, 0,nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr
|
||
// NDArray dcIdWrf = x->rankOf() == 1 ? (*dcIdWr)({0,0, nOut,2*nOut, 0,0}) : (*dcIdWr)({0,0, nOut,2*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr
|
||
// NDArray dcIdWrg = x->rankOf() == 1 ? (*dcIdWr)({0,0, 2*nOut,3*nOut, 0,0}) : (*dcIdWr)({0,0, 2*nOut,3*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr
|
||
// NDArray dcIdWro = x->rankOf() == 1 ? (*dcIdWr)({0,0, 3*nOut,4*nOut, 0,0}) : (*dcIdWr)({0,0, 3*nOut,4*nOut, 0,0, 0,0}); // [nOut, nOut, bS, nOut] or [nOut, nOut, nOut] if seqLen != nullptr
|
||
|
||
// NDArray WxiT = (*Wx)({0,0, 0, nOut}).transpose(); // [nOut, nIn]
|
||
// NDArray WxfT = (*Wx)({0,0, nOut, 2*nOut}).transpose(); // [nOut, nIn]
|
||
// NDArray WxgT = (*Wx)({0,0, 2*nOut,3*nOut}).transpose(); // [nOut, nIn]
|
||
// NDArray WxoT = (*Wx)({0,0, 3*nOut,4*nOut}).transpose(); // [nOut, nIn]
|
||
|
||
// NDArray WriT = (*Wr)({0,0, 0, nOut}).transpose(); // [nOut, nOut]
|
||
// NDArray WrfT = (*Wr)({0,0, nOut, 2*nOut}).transpose(); // [nOut, nOut]
|
||
// NDArray WrgT = (*Wr)({0,0, 2*nOut,3*nOut}).transpose(); // [nOut, nOut]
|
||
// NDArray WroT = (*Wr)({0,0, 3*nOut,4*nOut}).transpose(); // [nOut, nOut]
|
||
|
||
// // ***** feed forward step ***** //
|
||
|
||
// auto z = mmul(*x, *Wx) + mmul(*hI, *Wr); // [bs, nIn] * [nIn, 4*nOut] + [bs, nOut] * [nOut, 4*nOut] = [bS, 4*nOut]
|
||
// //or [nIn] * [nIn, 4*nOut] + [nOut] * [nOut, 4*nOut] = [4*nOut]
|
||
// // add biases if they are given
|
||
// if(b)
|
||
// z += *b; // broadcast [bS, 4*nOut] + [4*nOut] = [bS, 4*nOut](or[4*nOut])
|
||
|
||
// auto zi = x->rankOf() == 1 ? z({0, nOut}) : z({0,0, 0, nOut}); // input gate i, [bS, nOut](or[nOut])
|
||
// auto zf = x->rankOf() == 1 ? z({nOut, 2*nOut}) : z({0,0, nOut, 2*nOut}); // forget gate f, [bS, nOut](or[nOut])
|
||
// auto zg = x->rankOf() == 1 ? z({2*nOut, 3*nOut}) : z({0,0, 2*nOut, 3*nOut}); // cell gate g, [bS, nOut](or[nOut])
|
||
// auto zo = x->rankOf() == 1 ? z({3*nOut, 4*nOut}) : z({0,0, 3*nOut, 4*nOut}); // output gate o, [bS, nOut](or[nOut])
|
||
|
||
// // peephole connections for input and forget gates
|
||
// if(Wp) {
|
||
// zi += *cI * *Wpi; // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut])
|
||
// zf += *cI * *Wpf; // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut])
|
||
// }
|
||
|
||
// NDArray i = zi.ulike(); // [bS, nOut]
|
||
// NDArray f = zf.ulike(); // [bS, nOut]
|
||
// NDArray g = zg.ulike(); // [bS, nOut]
|
||
// applyActivation(zi, params[3], params[4], params[5], i);
|
||
// applyActivation(zf, params[3], params[4], params[5], f);
|
||
// applyActivation(zg, params[6], params[7], params[8], g);
|
||
|
||
// NDArray c = f * *cI + i * g; // [bS, nOut] * [bS, nOut] + [bS, nOut] * [bS, nOut] = [bS, nOut](or[nOut])
|
||
|
||
// // if clipping value is non-zero then cell state is clipped by this value prior to the cell output activation
|
||
// if(params[2] != 0)
|
||
// c.applyScalar(scalar::LstmClip, params[2], c);
|
||
|
||
// // peephole connections for output gate
|
||
// if(Wp)
|
||
// zo += c * *Wpo; // broadcast: [bS, nOut] + [bS, nOut] * [nOut] = [bS, nOut](or[nOut])
|
||
|
||
// NDArray o = zo.ulike(); // [bS, nOut](or[nOut])
|
||
// applyActivation(zo, params[3], params[4], params[5], o);
|
||
|
||
// // ***** back prop step ***** //
|
||
|
||
// NDArray dWxJacobian = mmulJacobianWeightsDeriv(nOut, *x); // [nIn, nOut, bS, nOut] (or [nIn, nOut, nOut])
|
||
// NDArray dWrJacobian = mmulJacobianWeightsDeriv(nOut, *hI); // [nOut, nOut, bS, nOut] (or [nOut, nOut, nOut])
|
||
|
||
// // dodzo
|
||
// NDArray dodzo = zo.ulike(); // [bS, nOut](or[nOut])
|
||
// activationDeriv(zo, params[3], params[4], params[5], dodzo);
|
||
|
||
// // dhdzo = dhdo*dodzo = actH(c)*dodzo
|
||
// NDArray dhdzo = zo.ulike(); // [bS, nOut](or[nOut])
|
||
// applyActivation(c, params[9], params[10], params[11], dhdzo); // actH(c)
|
||
// hI->assign(o*dhdzo);
|
||
// dhdzo *= dodzo;
|
||
|
||
// // dcdzi = dcdi*didzi
|
||
// NDArray dcdzi = zi.ulike(); // [bS, nOut](or[nOut])
|
||
// activationDeriv(zi, params[3], params[4], params[5], dcdzi); // didzi
|
||
// dcdzi *= g; // dcdi = g*clipDeriv
|
||
|
||
// // dcdzf = dcdf*dfdzf
|
||
// NDArray dcdzf = zf.ulike(); // [bS, nOut](or[nOut])
|
||
// activationDeriv(zf, params[3], params[4], params[5], dcdzf); // dfdzf
|
||
// dcdzf *= *cI; // dcdf = cI*clipDeriv
|
||
|
||
// // dcdzg = dcde*dedzg
|
||
// NDArray dcdzg = zg.ulike(); // [bS, nOut](or[nOut])
|
||
// activationDeriv(zg, params[6], params[7], params[8], dcdzg); // dedzg
|
||
// dcdzg *= i; // dcdg = i*clipDeriv
|
||
|
||
// // dcdcI
|
||
// NDArray dcdcI = f.dup(); // [bS, nOut](or[nOut])
|
||
|
||
// // take into account the possible contribution of the clipping derivative
|
||
// clipDeriv(params[2], c, dcdzi, dcdzf, dcdzg, dcdcI);
|
||
|
||
// // dzodc
|
||
// NDArray* dzodc = Wpo; // [nOut]; strictly it should be [bS, nOut], but it will be broadcast appropriately in later calculations (element-wise multiplication)
|
||
|
||
// // dzidcI
|
||
// NDArray* dzidcI = Wpi; // [nOut]; strictly it should be [bS, nOut], but it will be broadcast appropriately in later calculations (element-wise multiplication)
|
||
|
||
// // dzfdcI
|
||
// NDArray* dzfdcI = Wpf; // [nOut]; strictly it should be [bS, nOut], but it will be broadcast appropriately in later calculations (element-wise multiplication)
|
||
|
||
// // dhdc
|
||
// NDArray dhdc = c.ulike();
|
||
// activationDeriv(c, params[9], params[10], params[11], dhdc); // [bS, nOut]
|
||
// dhdc *= o;
|
||
// if(Wp)
|
||
// dhdc += dhdzo* *dzodc;
|
||
|
||
// NDArray factor = *dLdh * dhdc;
|
||
|
||
// NDArray iFactor = factor*dcdzi; // [bS, nOut](or[nOut])
|
||
// NDArray fFactor = factor*dcdzf; // [bS, nOut](or[nOut])
|
||
// NDArray eFactor = factor*dcdzg; // [bS, nOut](or[nOut])
|
||
// NDArray oFactor = *dLdh *dhdzo; // [bS, nOut](or[nOut])
|
||
|
||
// NDArray tempC = dcdcI;
|
||
// if(Wp)
|
||
// tempC += dcdzi*(*dzidcI) + dcdzf*(*dzfdcI);
|
||
|
||
// // dLdx
|
||
// dLdx->assign(mmul(iFactor, WxiT) + mmul(fFactor, WxfT) + mmul(eFactor, WxgT) + mmul(oFactor, WxoT)); // [bS, nIn](or[nIn])
|
||
// // NDArray temp = c.ulike();
|
||
// // applyActivation(c, params[9], params[10], params[11], temp); // actH(c)
|
||
// // dLdx->assign(mmul(o*(1-temp*temp)*g*i*(1-i), WxiT) + mmul(o*(1-temp*temp)*(*cI)*f*(1-f), WxfT) + mmul(o*(1-temp*temp)*i*g*(1-g), WxgT) + mmul(temp*o*(1-o), WxoT)); // [bS, nIn](or[nOut])
|
||
|
||
// // dLdhI
|
||
// NDArray* dLdhII = dLdhI;
|
||
// if(dLdcI && !dLdhI)
|
||
// dLdhII = new NDArray(dLdcI->ulike());
|
||
// dLdhII->assign(mmul(iFactor, WriT) + mmul(fFactor, WrfT) + mmul(eFactor, WrgT) + mmul(oFactor, WroT)); // [bS, nOut](or[nOut])
|
||
|
||
// if(firstIter) {
|
||
|
||
// // dLdcI
|
||
// if(dLdcI)
|
||
// dLdcI->assign(factor*tempC); // [bS, nOut](or[nOut])
|
||
|
||
// // dcIdWxi(dcdWxi)
|
||
// dcIdWxi.assign(dcdzi*dWxJacobian); // broadcast [bS, nOut] * [nIn, nOut, bS, nOut] (or [nOut] * [nIn, nOut, nOut]);
|
||
// // dcIdWxf(dcdWxf)
|
||
// dcIdWxf.assign(dcdzf*dWxJacobian);
|
||
// // dcIdWxg(dcdWxg)
|
||
// dcIdWxg.assign(dcdzg*dWxJacobian);
|
||
// // dcIdWxo(dcdWxo) = 0
|
||
// dcIdWxo.nullify();
|
||
|
||
// // dhIdWxi
|
||
// dhIdWxi.assign(dhdc*dcIdWxi); // broadcast [bS, nOut] * [nIn, nOut, bS, nOut] (or [nOut] * [nIn, nOut, nOut]);
|
||
// // dhIdWxf
|
||
// dhIdWxf.assign(dhdc*dcIdWxf);
|
||
// // dhIdWxg
|
||
// dhIdWxg.assign(dhdc*dcIdWxg);
|
||
// // dhIdWxo
|
||
// dhIdWxo.assign(dhdzo*dWxJacobian /*+ 0 */);
|
||
|
||
// // dcIdWri(dcdWri)
|
||
// dcIdWri.assign(dcdzi*dWrJacobian); // broadcast [bS, nOut] * [nOut, nOut, bS, nOut](or [nOut] * [nOut, nOut, nOut]);
|
||
// // dcIdWrf(dcdWrf)
|
||
// dcIdWrf.assign(dcdzf*dWrJacobian);
|
||
// // dcIdWrg(dcdWrg)
|
||
// dcIdWrg.assign(dcdzg*dWrJacobian);
|
||
// // dcIdWro(dcdWro) = 0
|
||
// dcIdWro.nullify();
|
||
|
||
// // dhIdWri
|
||
// dhIdWri.assign(dhdc*dcIdWri); // broadcast [bS, nOut] * [nOut, nOut, bS, nOut] (or [nOut] * [nOut, nOut, nOut]);
|
||
// // dhIdWrf
|
||
// dhIdWrf.assign(dhdc*dcIdWrf);
|
||
// // dhIdWrg
|
||
// dhIdWrg.assign(dhdc*dcIdWrg);
|
||
// // dhIdWro
|
||
// dhIdWro.assign(dhdzo*dWrJacobian /*+ 0 */);
|
||
|
||
// if(Wp && x->rankOf() == 1) {
|
||
// // dcIdWpi
|
||
// dcIdWpi->assign(dcdzi*(*cI)); // [nOut] * [nOut]
|
||
// // dcIdWpf
|
||
// dcIdWpf->assign(dcdzf*(*cI)); // [nOut] * [nOut]
|
||
// // dcIdWpo
|
||
// dcIdWpo->nullify(); // [nOut]
|
||
|
||
// // dhdWpi
|
||
// dhIdWpi->assign(dhdc*(*dcIdWpi)); // [nOut] * [nOut]
|
||
// // dhdWpf
|
||
// dhIdWpf->assign(dhdc*(*dcIdWpf)); // [nOut] * [nOut]
|
||
// // dhdWpo
|
||
// dhIdWpo->assign(dhdzo*c /* +0*/); // [nOut] * [nOut]
|
||
// }
|
||
// else if(Wp) {
|
||
// // dcIdWpi
|
||
// (dcdzi*(*cI)).reduceAlongDimension(reduce::Sum, *dcIdWpi, {0}); // [bS, nOut]->reduce->[nOut]
|
||
// // dcIdWpf
|
||
// (dcdzf*(*cI)).reduceAlongDimension(reduce::Sum, *dcIdWpf, {0}); // [bS, nOut]->reduce->[nOut]
|
||
// // dcIdWpo
|
||
// dcIdWpo->nullify(); // [nOut]
|
||
|
||
// // dhIdWpi
|
||
// (*dLdh*dhdc*(dcdzi*(*cI))).reduceAlongDimension(reduce::Sum, *dhIdWpi, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
|
||
// // dhIdWpf
|
||
// (*dLdh*dhdc*(dcdzf*(*cI))).reduceAlongDimension(reduce::Sum, *dhIdWpf, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
|
||
// // dhIdWpo
|
||
// (*dLdh*dhdzo*c /* +0*/).reduceAlongDimension(reduce::Sum, *dhIdWpo, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
|
||
// }
|
||
|
||
// if(b && x->rankOf() == 1) {
|
||
// // dcIdbi
|
||
// dcIdbi->assign(dcdzi); // [nOut]
|
||
// // dcIdbf
|
||
// dcIdbf->assign(dcdzf); // [nOut]
|
||
// // dcIdbg
|
||
// dcIdbg->assign(dcdzg); // [nOut]
|
||
// // dcIdbo
|
||
// dcIdbo->nullify(); // [nOut]
|
||
|
||
// //dhIdbi
|
||
// dhIdbi->assign(dhdc*(*dcIdbi)); // [nOut]
|
||
// //dhIdbf
|
||
// dhIdbf->assign(dhdc*(*dcIdbf)); // [nOut]
|
||
// //dhIdbg
|
||
// dhIdbg->assign(dhdc*(*dcIdbg)); // [nOut]
|
||
// //dhIdbo
|
||
// dhIdbo->assign(dhdzo); // [nOut]
|
||
|
||
// }
|
||
// else if(b) {
|
||
// // dcIdbi
|
||
// dcdzi.reduceAlongDimension(reduce::Sum, *dcIdbi, {0}); // [bS, nOut]->reduce->[nOut]
|
||
// // dcIdbf
|
||
// dcdzf.reduceAlongDimension(reduce::Sum, *dcIdbf, {0}); // [bS, nOut]->reduce->[nOut]
|
||
// // dcIdbg
|
||
// dcdzg.reduceAlongDimension(reduce::Sum, *dcIdbg, {0}); // [bS, nOut]->reduce->[nOut]
|
||
// // dcIdbo
|
||
// dcIdbo->nullify(); // [nOut]
|
||
|
||
// //dhIdbi
|
||
// (*dLdh*dhdc*dcdzi).reduceAlongDimension(reduce::Sum, *dhIdbi, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
|
||
// //dhIdbf
|
||
// (*dLdh*dhdc*dcdzf).reduceAlongDimension(reduce::Sum, *dhIdbf, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
|
||
// //dhIdbg
|
||
// (*dLdh*dhdc*(*dcIdbg)).reduceAlongDimension(reduce::Sum, *dhIdbg, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
|
||
// //dhIdbo
|
||
// (*dLdh*dhdzo).reduceAlongDimension(reduce::Sum, *dhIdbo, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
|
||
|
||
// }
|
||
// }
|
||
// else {
|
||
|
||
// NDArray tempIFE = mmul(dcdzi, WriT) + mmul(dcdzf, WrfT) + mmul(dcdzg, WrgT);
|
||
// NDArray tempO = mmul(dhdzo, WroT);
|
||
|
||
// // dLdcI
|
||
// if(dLdcI)
|
||
// dLdcI->assign(factor*tempC + (*dLdhII)*(*dhIdcI));
|
||
|
||
// // dcIdWxi(dcdWxi)
|
||
// dcIdWxi.assign(dcdzi*dWxJacobian + tempIFE*dhIdWxi + tempC*dcIdWxi); // broadcast [bS, nOut] * [nIn, nOut, bS, nOut](or [nOut] * [nIn, nOut, nOut]);
|
||
// // dcIdWxf(dcdWxf)
|
||
// dcIdWxf.assign(dcdzf*dWxJacobian + tempIFE*dhIdWxf + tempC*dcIdWxf);
|
||
// // dcIdWxg(dcdWxg)
|
||
// dcIdWxg.assign(dcdzg*dWxJacobian + tempIFE*dhIdWxg + tempC*dcIdWxg);
|
||
// // dcIdWxo(dcdWxo)
|
||
// dcIdWxo.assign(/* 0 + */tempIFE * dhIdWxo + tempC*dcIdWxo);
|
||
|
||
// // dhIdWxi
|
||
// dhIdWxi.assign(dhdc*dcIdWxi + tempO*dhIdWxi); // broadcast [bS, nOut] * [nIn, nOut, bS, nOut](or [nOut] * [nIn, nOut, nOut]);
|
||
// // dhIdWxf
|
||
// dhIdWxf.assign(dhdc*dcIdWxf + tempO*dhIdWxf);
|
||
// // dhIdWxg
|
||
// dhIdWxg.assign(dhdc*dcIdWxg + tempO*dhIdWxg);
|
||
// // dhIdWxo
|
||
// dhIdWxo.assign(dhdzo*dWxJacobian + dhdc*dcIdWxo + tempO*dhIdWxo);
|
||
|
||
// // dcIdWri(dcdWri)
|
||
// dcIdWri.assign(dcdzi*dWrJacobian + tempIFE*dhIdWri + tempC*dcIdWri); // broadcast [bS, nOut] * [nOut, nOut, bS, nOut](or [nOut] * [nOut, nOut, nOut]);
|
||
// // dcIdWrf(dcdWrf)
|
||
// dcIdWrf.assign(dcdzf*dWrJacobian + tempIFE*dhIdWrf + tempC*dcIdWrf);
|
||
// // dcIdWrg(dcdWrg)
|
||
// dcIdWrg.assign(dcdzg*dWrJacobian + tempIFE*dhIdWrg + tempC*dcIdWrg);
|
||
// // dcIdWro(dcdWro)
|
||
// dcIdWro.assign(/* 0 + */tempIFE * dhIdWro + tempC*dcIdWro);
|
||
|
||
// // dhIdWri
|
||
// dhIdWri.assign(dhdc*dcIdWri + tempO*dhIdWri); // broadcast [bS, nOut] * [nOut, nOut, bS, nOut](or [nOut] * [nOut, nOut, nOut]);
|
||
// // dhIdWrf
|
||
// dhIdWrf.assign(dhdc*dcIdWrf + tempO*dhIdWrf);
|
||
// // dhIdWrg
|
||
// dhIdWrg.assign(dhdc*dcIdWrg + tempO*dhIdWrg);
|
||
// // dhIdWro
|
||
// dhIdWro.assign(dhdzo*dWrJacobian + dhdc*dcIdWro + tempO*dhIdWro);
|
||
|
||
// if(Wp && x->rankOf() == 1) {
|
||
// // dcIdWpi
|
||
// dcIdWpi->assign(dcdzi*(*cI) + tempIFE*(*dhIdWpi) + tempC*(*dcIdWpi)); // [nOut] * [nOut]
|
||
// // dcIdWpf
|
||
// dcIdWpf->assign(dcdzf*(*cI) + tempIFE*(*dhIdWpf) + tempC*(*dcIdWpf)); // [nOut] * [nOut]
|
||
// // dcIdWpo
|
||
// dcIdWpo->assign(/* 0 + */ tempIFE*(*dhIdWpo) + tempC*(*dcIdWpo)); // [nOut] * [nOut]
|
||
|
||
// // dhdWpi
|
||
// dhIdWpi->assign(dhdc*(*dcIdWpi) + tempO*(*dhIdWpi)); // [nOut] * [nOut]
|
||
// // dhdWpf
|
||
// dhIdWpf->assign(dhdc*(*dcIdWpf) + tempO*(*dhIdWpf)); // [nOut] * [nOut]
|
||
// // dhdWpo
|
||
// dhIdWpo->assign(dhdzo*c + dhdc*(*dcIdWpo) + tempO*(*dhIdWpo)); // [nOut] * [nOut]
|
||
// }
|
||
// else if(Wp) {
|
||
// // dcIdWpi
|
||
// (dcdzi*(*cI) + tempIFE*(*dhIdWpi) + tempC*(*dcIdWpi)).reduceAlongDimension(reduce::Sum, *dcIdWpi, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
|
||
// // dcIdWpf
|
||
// (dcdzf*(*cI) + tempIFE*(*dhIdWpf) + tempC*(*dcIdWpf)).reduceAlongDimension(reduce::Sum, *dcIdWpf, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
|
||
// // dcIdWpo
|
||
// (/* 0 + */ tempIFE*(*dhIdWpo) + tempC*(*dcIdWpo)).reduceAlongDimension(reduce::Sum, *dcIdWpo, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
|
||
|
||
// // dhIdWpi
|
||
// (dhdc*(*dcIdWpi) + tempO*(*dhIdWpi)).reduceAlongDimension(reduce::Sum, *dhIdWpi, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
|
||
// // dhIdWpf
|
||
// (dhdc*(*dcIdWpf) + tempO*(*dhIdWpf)).reduceAlongDimension(reduce::Sum, *dhIdWpf, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
|
||
// // dhIdWpo
|
||
// (dhdzo*c + dhdc*(*dcIdWpo) + tempO*(*dhIdWpo)).reduceAlongDimension(reduce::Sum, *dhIdWpo, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
|
||
// }
|
||
|
||
// if(b && x->rankOf() == 1) {
|
||
// // dcIdbi
|
||
// dcIdbi->assign(dcdzi + tempIFE*(*dhIdbi) + tempC*(*dcIdbi)); // [nOut]
|
||
// // dcIdbf
|
||
// dcIdbf->assign(dcdzf + tempIFE*(*dhIdbf) + tempC*(*dcIdbf)); // [nOut]
|
||
// // dcIdbg
|
||
// dcIdbg->assign(dcdzg + tempIFE*(*dhIdbg) + tempC*(*dcIdbg)); // [nOut]
|
||
// // dcIdbo
|
||
// dcIdbo->assign(/*0+*/ tempIFE*(*dhIdbo) + tempC*(*dcIdbo)); // [nOut]
|
||
|
||
// //dhIdbi
|
||
// dhIdbi->assign(dhdc*(*dcIdbi) + tempO*(*dhIdbi)); // [nOut]
|
||
// //dhIdbf
|
||
// dhIdbf->assign(dhdc*(*dcIdbf) + tempO*(*dhIdbf)); // [nOut]
|
||
// //dhIdbg
|
||
// dhIdbg->assign(dhdc*(*dcIdbg) + tempO*(*dhIdbg)); // [nOut]
|
||
// //dhIdbo
|
||
// dhIdbo->assign(dhdzo + dhdc*(*dcIdbo) + tempO*(*dhIdbo)); // [nOut]
|
||
|
||
// }
|
||
// else if(b) {
|
||
// // dcIdbi
|
||
// (dcdzi + tempIFE*(*dhIdbi) + tempC*(*dcIdbi)).reduceAlongDimension(reduce::Sum, *dcIdbi, {0}); // [bS, nOut]->reduce->[nOut]
|
||
// // dcIdbf
|
||
// (dcdzf + tempIFE*(*dhIdbf) + tempC*(*dcIdbf)).reduceAlongDimension(reduce::Sum, *dcIdbf, {0}); // [bS, nOut]->reduce->[nOut]
|
||
// // dcIdbg
|
||
// (dcdzg + tempIFE*(*dhIdbg) + tempC*(*dcIdbg)).reduceAlongDimension(reduce::Sum, *dcIdbg, {0}); // [bS, nOut]->reduce->[nOut]
|
||
// // dcIdbo
|
||
// (/*0+*/ tempIFE*(*dhIdbo) + tempC*(*dcIdbo)).reduceAlongDimension(reduce::Sum, *dcIdbo, {0}); // [bS, nOut]->reduce->[nOut]
|
||
|
||
// //dhIdbi
|
||
// (dhdc*(*dcIdbi) + tempO*(*dhIdbi)).reduceAlongDimension(reduce::Sum, *dhIdbi, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
|
||
// //dhIdbf
|
||
// (dhdc*(*dcIdbf) + tempO*(*dhIdbf)).reduceAlongDimension(reduce::Sum, *dhIdbf, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
|
||
// //dhIdbg
|
||
// (dhdc*(*dcIdbg) + tempO*(*dhIdbg)).reduceAlongDimension(reduce::Sum, *dhIdbg, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
|
||
// //dhIdbo
|
||
// (dhdzo + dhdc*(*dcIdbo) + tempO*(*dhIdbo)).reduceAlongDimension(reduce::Sum, *dhIdbo, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
|
||
|
||
// }
|
||
// }
|
||
|
||
// const std::vector<int> dimsToExclude = x->rankOf() == 1 ? std::vector<int>({2}) : std::vector<int>({2, 3});
|
||
|
||
// // dLdWxi, dLdWxf, dLdWxg, dLdWxo
|
||
// (*dLdh*(*dhIdWx)).reduceAlongDimension(reduce::Sum, *dLdWx, dimsToExclude);
|
||
|
||
// // dLdWri, dLdWrf, dLdWrg, dLdWro
|
||
// (*dLdh*(*dhIdWr)).reduceAlongDimension(reduce::Sum, *dLdWr, dimsToExclude);
|
||
|
||
// // dLdWpi, dLdWpf, dLdWpo
|
||
// if(Wp) {
|
||
// if(x->rankOf() == 1) {
|
||
// (*dLdWp)({0, nOut}).assign(*dLdh*(*dhIdWpi)); // [nOut] * [nOut]
|
||
// (*dLdWp)({nOut, 2*nOut}).assign(*dLdh*(*dhIdWpf)); // [nOut] * [nOut]
|
||
// (*dLdWp)({2*nOut, 3*nOut}).assign(*dLdh*(*dhIdWpo)); // [nOut] * [nOut]
|
||
// }
|
||
// else {
|
||
// // NDArray temp1 = (*dLdWp)({0, nOut});
|
||
// // NDArray temp2 = (*dLdWp)({nOut, 2*nOut});
|
||
// // NDArray temp3 = (*dLdWp)({2*nOut, 3*nOut});
|
||
// // dhIdWpi->reduceAlongDimension(reduce::Sum, temp1, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
|
||
// // dhIdWpf->reduceAlongDimension(reduce::Sum, temp2, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
|
||
// // dhIdWpo->reduceAlongDimension(reduce::Sum, temp3, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
|
||
// (*dLdWp)({0, nOut}).assign(dhIdWpi);
|
||
// (*dLdWp)({nOut, 2*nOut}).assign(dhIdWpf);
|
||
// (*dLdWp)({2*nOut, 3*nOut}).assign(dhIdWpo);
|
||
// }
|
||
// }
|
||
|
||
// // dLdbi, dLdbf, dLdbg, dLdbo
|
||
// if(b) {
|
||
// if(x->rankOf() == 1) {
|
||
// (*dLdb)({0, nOut}).assign(*dLdh*(*dhIdbi)); // [nOut] * [nOut]
|
||
// (*dLdb)({nOut, 2*nOut}).assign(*dLdh*(*dhIdbf)); // [nOut] * [nOut]
|
||
// (*dLdb)({2*nOut, 3*nOut}).assign(*dLdh*(*dhIdbg)); // [nOut] * [nOut]
|
||
// (*dLdb)({3*nOut, 4*nOut}).assign(*dLdh*(*dhIdbo)); // [nOut] * [nOut]
|
||
// }
|
||
// else {
|
||
// // NDArray temp1 = (*dLdb)({0, nOut});
|
||
// // NDArray temp2 = (*dLdb)({nOut, 2*nOut});
|
||
// // NDArray temp3 = (*dLdb)({2*nOut, 3*nOut});
|
||
// // NDArray temp4 = (*dLdb)({3*nOut, 4*nOut});
|
||
// // (*dLdh*(*dhIdbi)).reduceAlongDimension(reduce::Sum, temp1, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
|
||
// // (*dLdh*(*dhIdbf)).reduceAlongDimension(reduce::Sum, temp2, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
|
||
// // (*dLdh*(*dhIdbg)).reduceAlongDimension(reduce::Sum, temp3, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
|
||
// // (*dLdh*(*dhIdbo)).reduceAlongDimension(reduce::Sum, temp3, {0}); // ([bS, nOut] * [nOut])->reduce->[nOut]
|
||
// (*dLdb)({0, nOut}).assign(dhIdbi);
|
||
// (*dLdb)({nOut, 2*nOut}).assign(dhIdbf);
|
||
// (*dLdb)({2*nOut, 3*nOut}).assign(dhIdbg);
|
||
// (*dLdb)({3*nOut, 4*nOut}).assign(dhIdbo);
|
||
// }
|
||
// }
|
||
|
||
// //dhIdcI
|
||
// if(dLdcI)
|
||
// dhIdcI->assign(dhdc);
|
||
|
||
// cI->assign(c);
|
||
|
||
// if(dLdcI && !dLdhI)
|
||
// delete dLdhII;
|
||
// if(Wp) {
|
||
// delete Wpi; delete Wpf; delete Wpo; delete dcIdWpi; delete dcIdWpf; delete dcIdWpo; delete dhIdWpi; delete dhIdWpf; delete dhIdWpo;
|
||
// }
|
||
// if(b) {
|
||
// delete dcIdbi; delete dcIdbf; delete dcIdbg; delete dcIdbo; delete dhIdbi; delete dhIdbf; delete dhIdbg; delete dhIdbo;
|
||
// }
|
||
// }
|