547 lines
24 KiB
C++
547 lines
24 KiB
C++
|
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 * Copyright (c) 2020 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
|
|||
|
|
|||
|
//
// @author Yurii Shyrma (iuriish@yahoo.com), created on 15.02.2018, Alex Black
//

// Implementation of the Gated Recurrent Unit (GRU) cell
// (cf. https://arxiv.org/abs/1406.1078).
// Kyunghyun Cho, Bart van Merrienboer, Caglar Gulcehre, Dzmitry Bahdanau, Fethi Bougares, Holger Schwenk, Yoshua Bengio
// "Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation"
|
|||
|
|
|||
|
|
|||
|
#include <ops/declarable/helpers/gru.h>
|
|||
|
#include <ops/declarable/CustomOperations.h>
|
|||
|
#include <ops/declarable/helpers/transforms.h>
|
|||
|
#include <helpers/MmulHelper.h>
|
|||
|
|
|||
|
namespace sd {
|
|||
|
namespace ops {
|
|||
|
namespace helpers {
|
|||
|
|
|||
|
//////////////////////////////////////////////////////////////////////////
|
|||
|
void gruCell(sd::LaunchContext * context, const NDArray* x, const NDArray* hI, const NDArray* W, const NDArray* Wc,
             const NDArray* b, const NDArray* bc,
             NDArray* r, NDArray* u, NDArray* c, NDArray* h) {

    // Single GRU time step (feed forward) using one weight matrix for the reset/update gates (W)
    // and a separate one for the cell gate (Wc). Results are written into r, u, c and h.
    // The context argument is not referenced by this implementation.

    //Inputs:
    // x        input [bS, nIn], nIn - input size
    // hI       previous cell output [bS, nOut], that is at previous time step t-1, nOut - number of units
    // W        RU weights - [nIn+nOut, 2*nOut] - reset and update gates
    // Wc       C weights - [nIn+nOut, nOut] - cell gate
    // b        r and u biases, [2*nOut] - reset and update gates
    // bc       c biases, [nOut] - cell gate

    //Outputs:
    // r        Reset gate output [bS, nOut]
    // u        Update gate output [bS, nOut]
    // c        Cell gate output [bS, nOut]
    // h        current cell output [bS, nOut]

    /***************************************************************************************/
    /************************ THIS IS NOT OPTIMIZED CODE ***********************************/
    /** however it is more math-friendly and convenient for backprop formulas derivation) **/

    const int nIn  = x->sizeAt(1);
    const int nOut = hI->sizeAt(1);

    // W is split by rows into the part applied to x (first nIn rows) and the part applied to hI (next nOut rows),
    // and by columns into the reset gate (first nOut columns) and the update gate (next nOut columns)
    NDArray Wrx = (*W)({0,nIn,        0,nOut});         // reset gate, input weights   [nIn, nOut]
    NDArray Wux = (*W)({0,nIn,        nOut,2*nOut});    // update gate, input weights  [nIn, nOut]
    NDArray Wrh = (*W)({nIn,nIn+nOut, 0,nOut});         // reset gate, hidden weights  [nOut, nOut]
    NDArray Wuh = (*W)({nIn,nIn+nOut, nOut,2*nOut});    // update gate, hidden weights [nOut, nOut]

    // Wc is split by rows only; the {0,0} pair takes the full column range (see resulting shapes)
    NDArray Wcx = (*Wc)({0,nIn,        0,0});           // cell gate, input weights    [nIn, nOut]
    NDArray Wch = (*Wc)({nIn,nIn+nOut, 0,0});           // cell gate, hidden weights   [nOut, nOut]

    NDArray br = (*b)({0,    nOut});                    // reset gate bias  [nOut]
    NDArray bu = (*b)({nOut, 2*nOut});                  // update gate bias [nOut]

    // × means matrix multiplication
    // * means element-wise product or so called Hadamard product

    // reset gate: r = sigmoid(x × Wrx + hI × Wrh + br)
    r->assign(mmul(*x, Wrx) + mmul(*hI, Wrh) + br);         // [bS, nIn] × [nIn, nOut] + [bS, nOut] × [nOut, nOut] + [nOut] = [bS, nOut]
    r->applyTransform(transform::Sigmoid, *r);

    // update gate: u = sigmoid(x × Wux + hI × Wuh + bu)
    u->assign(mmul(*x, Wux) + mmul(*hI, Wuh) + bu);         // [bS, nIn] × [nIn, nOut] + [bS, nOut] × [nOut, nOut] + [nOut] = [bS, nOut]
    u->applyTransform(transform::Sigmoid, *u);

    // cell gate: c = tanh(x × Wcx + (r * hI) × Wch + bc)
    c->assign(mmul(*x, Wcx) + mmul(*r * *hI, Wch) + *bc);   // [bS, nIn] × [nIn, nOut] + [bS, nOut] × [nOut, nOut] + [nOut] = [bS, nOut]
    c->applyTransform(transform::Tanh, *c);

    // cell output: h = u * hI + (1 - u) * c
    h->assign(*u * *hI + (1.f - *u) * *c);
}
|
|||
|
|
|||
|
//////////////////////////////////////////////////////////////////////////
|
|||
|
void gruCell(sd::LaunchContext * context, const NDArray* x, const NDArray* hI, const NDArray* Wx, const NDArray* Wh, const NDArray* b,
             NDArray* gates, NDArray* h) {

    // Single GRU time step (feed forward) using fused weight matrices; the three gates are stored
    // side by side along the last dimension. Results are written into gates and h.
    // The context argument is not referenced by this implementation.

    //Inputs:
    // x        input [bS, nIn]
    // hI       previous cell output [bS, nOut], that is at previous time step t-1
    // Wx       weights for x - [nIn, 3*nOut]
    // Wh       weights for h - [nOut, 3*nOut]
    // b        biases [3*nOut]

    // 3*nOut means following sequence: reset, update, cell

    //Outputs:
    // gates    [bS, 3*nOut] = reset gate [bS, nOut] + update gate [bS, nOut] + cell gate [bS, nOut]
    // h        current cell output [bS, nOut]

    // × means matrix multiplication
    // * means element-wise product or so called Hadamard product

    // formulas (as computed by the code below):
    // zr = x × Wxr + hI × Whr + br
    // zu = x × Wxu + hI × Whu + bu
    // r = sigmoid(zr)
    // u = sigmoid(zu)
    // zc = x × Wxc + r * (hI × Whc) + bc      <- the reset gate scales the already multiplied hidden term
    // c = tanh(zc)
    // h = (1-u)*c + u*hI

    const int nOut = hI->sizeAt(1);

    // input contribution for all three gates at once, bias included
    NDArray temp = gates->ulike();
    MmulHelper::mmul(x, Wx, &temp);                 // [bS, nIn] × [nIn, 3*nOut] = [bS, 3*nOut]
    temp += *b;

    // hidden contribution for all three gates at once, written straight into the output buffer
    MmulHelper::mmul(hI, Wh, gates);                // [bS, nOut] × [nOut, 3*nOut] = [bS, 3*nOut]

    // sub-arrays of gates; the code below relies on writes through them landing in gates
    NDArray ru = (*gates)({0,0, 0,2*nOut});         // [bS, 2*nOut]

    NDArray r  = (*gates)({0,0, 0,nOut});           // [bS, nOut]
    NDArray u  = (*gates)({0,0, nOut,2*nOut});      // [bS, nOut]
    NDArray c  = (*gates)({0,0, 2*nOut,3*nOut});    // [bS, nOut]

    // reset and update gates: add the input contribution, then squash both in one pass
    ru += temp({0,0, 0,2*nOut});
    ru.applyTransform(transform::Sigmoid, ru);

    // cell gate: at this point c still holds hI × Whc, so this is r * (hI × Whc) + (x × Wxc + bc)
    c.assign(c*r + temp({0,0, 2*nOut, 3*nOut}));
    c.applyTransform(transform::Tanh, c);

    // cell output
    h->assign(u * *hI + (1.f - u) * c);
}
|
|||
|
|
|||
|
//////////////////////////////////////////////////////////////////////////
|
|||
|
void gruTimeLoop(sd::LaunchContext * context, const NDArray* x, const NDArray* hI, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h) {

    // Runs the fused-weights GRU cell over every time step of a sequence (sL = number of time steps).
    //
    // Inputs:
    // x   input [sL, bS, nIn]
    // hI  initial cell output (at time step = 0) [bS, nOut]
    // Wx  input-to-hidden weights, [nIn, 3*nOut]
    // Wh  hidden-to-hidden weights, [nOut, 3*nOut]
    // b   biases, [3*nOut]
    //
    // Output:
    // h   cell outputs at each time step [sL, bS, nOut]

    const int numSteps = x->sizeAt(0);
    const int batch    = x->sizeAt(1);
    const int units    = hI->sizeAt(1);

    // scratch buffer for the gate activations; it is overwritten on every step and not returned
    NDArray gateBuf(h->ordering(), {batch, 3*units}, h->dataType(), context);

    auto xSteps = x->allTensorsAlongDimension({1,2});       // per-step inputs,  each [bS, nIn]
    auto hSteps = h->allTensorsAlongDimension({1,2});       // per-step outputs, each [bS, nOut]

    // the first step starts from hI, every later step from the output just produced
    const NDArray* hPrev = hI;

    for (int step = 0; step < numSteps; ++step) {

        NDArray* hCurr = hSteps.at(step);

        gruCell(context, xSteps.at(step), hPrev, Wx, Wh, b, &gateBuf, hCurr);

        hPrev = hCurr;
    }
}
|
|||
|
|
|||
|
//////////////////////////////////////////////////////////////////////////
|
|||
|
void gruCellBp(sd::LaunchContext* context,
               const NDArray* x,    const NDArray* hLast,
               const NDArray* W,    const NDArray* Wc,        const NDArray* b,    const NDArray* bc,
               const NDArray* dLdr, const NDArray* dLdu,      const NDArray* dLdc, const NDArray* dLdh,
               NDArray* dLdx,       NDArray* dLdhLast,
               NDArray* dLdW,       NDArray* dLdWc,
               NDArray* dLdb,       NDArray* dLdbc) {

    // Backprop through a single GRU time step for the W/Wc weight layout.
    // The gate activations r, u, c are recomputed here from x, hLast and the weights (feed forward step),
    // then combined with the incoming gradients to fill every output gradient array.
    // The context argument is not referenced by this implementation.

    //Inputs:
    // x              input [bS, iS]
    // hLast          previous cell output [bS, nU],  that is at previous time step t-1
    // W              weights - [iS+nU, 2*nU] - reset and update gates
    // Wc             C weights - [iS+nU, nU] - cell gate
    // b              r and u biases, [2*nU] - reset and update gates
    // bc             c biases, [nU] - cell gate
    // dLdr           gradient wrt reset gate, [bS, nU]
    // dLdu           gradient wrt update gate, [bS, nU]
    // dLdc           gradient wrt cell state, [bS, nU]
    // dLdh           gradient wrt current cell output, [bS, nU]

    //Outputs:
    // dLdx           gradient wrt x,  [bS, iS],
    // dLdhLast       gradient wrt hLast, [bS, nU]
    // dLdW           gradient wrt W,  [iS+nU, 2*nU]
    // dLdWc          gradient wrt Wc, [iS+nU, nU]
    // dLdb           gradient wrt bru [2*nU]
    // dLdbc          gradient wrt bc  [nU]

    // * means element-wise product or so called Hadamard product
    // × means matrix multiplication

    /************************************************************************************************/
    /******************************* THIS IS NOT OPTIMIZED CODE *************************************/
    /*** aim is to have math-readable code in order to keep track of backprop formulas derivation ***/

    const int iS = x->sizeAt(1);
    const int nU = hLast->sizeAt(1);

    NDArray xT     = x->transpose();                // [iS, bS]
    NDArray hLastT = hLast->transpose();            // [nU, bS]

    NDArray Wrx = (*W)({0,iS,     0,nU});           // [iS, nU]
    NDArray Wux = (*W)({0,iS,     nU,2*nU});        // [iS, nU]
    NDArray Wrh = (*W)({iS,iS+nU, 0,nU});           // [nU, nU]
    NDArray Wuh = (*W)({iS,iS+nU, nU,2*nU});        // [nU, nU]

    NDArray Wcx = (*Wc)({0,iS,     0,0});           // cell gate, input weights  [iS, nU]
    NDArray Wch = (*Wc)({iS,iS+nU, 0,0});           // cell gate, hidden weights [nU, nU]

    NDArray br  = (*b)({0,  nU});                   // [nU]
    NDArray bu  = (*b)({nU, 2*nU});                 // [nU]

    NDArray WrxT = Wrx.transpose();                 // [nU, iS]
    NDArray WuxT = Wux.transpose();                 // [nU, iS]
    NDArray WrhT = Wrh.transpose();                 // [nU, nU]
    NDArray WuhT = Wuh.transpose();                 // [nU, nU]

    NDArray WcxT = Wcx.transpose();                 // [nU, iS]
    NDArray WchT = Wch.transpose();                 // [nU, nU]

    // sub-arrays of the output gradient buffers, laid out exactly like the weights/biases above
    NDArray dLdWrx = (*dLdW)({0,iS,     0,nU});     // [iS, nU]
    NDArray dLdWux = (*dLdW)({0,iS,     nU,2*nU});  // [iS, nU]
    NDArray dLdWrh = (*dLdW)({iS,iS+nU, 0,nU});     // [nU, nU]
    NDArray dLdWuh = (*dLdW)({iS,iS+nU, nU,2*nU});  // [nU, nU]

    NDArray dLdWcx = (*dLdWc)({0,iS,     0,0});     // [iS, nU]
    NDArray dLdWch = (*dLdWc)({iS,iS+nU, 0,0});     // [nU, nU]

    NDArray dLdbr = (*dLdb)({0,  nU});              // [nU]
    NDArray dLdbu = (*dLdb)({nU, 2*nU});            // [nU]


    // ***** feed forward step ***** //

    // reset gate
    NDArray r = mmul(*x, Wrx) + mmul(*hLast, Wrh) + br;         // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
    r.applyTransform(transform::Sigmoid, r);

    // update gate
    NDArray u = mmul(*x, Wux) + mmul(*hLast, Wuh) + bu;         // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
    u.applyTransform(transform::Sigmoid, u);

    // cell gate c = activation(x×Wcx + (r*hlast)×Wch + bc)
    NDArray c = mmul(*x, Wcx) + mmul(r * *hLast, Wch) + *bc;    // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
    c.applyTransform(transform::Tanh, c);

    // h = (1 - u) * c + u * hPrev


    // ***** back prop step ***** //

    // notations:
    // Zr = x × Wrx + hLast × Wrh + br
    // Zu = x × Wux + hLast × Wuh + bu
    // Sr = sigmoid(Zr)
    // Su = sigmoid(Zu)
    // Zc = x × Wcx + (r * hlast) × Wch + bc


    // dLdx = dLdh * dhdx = dLdh * (dhdu * dudx + dhdc * dcdx) = (dLdh * dhdu) * dudx + (dLdh * dhdc) * dcdx = dLdu * dudx + dLdc * dcdx
    //      = dLdx_u + dLdx_c
    // dLdx_u = dLdu * dudx = dLdu * dudZu * dZudx = |dZudx = ... × WuxT| = (dLdu * dudZu) × WuxT
    // dLdx_c = dLdc * dcdx = dLdc * dcdZc * (dZcdx + dZcdr * drdx) = dLdc * dcdZc * dZcdx + dLdc * dcdZc * dZcdr * drdx = dLdx_c0 + dLdx_c1
    // dLdx_c0 = dLdc * dcdZc * dZcdx = |dZcdx = ... × WcxT| = (dLdc * dcdZc) × WcxT
    // dZcdr = (... * hLast) × WchT
    // dLdc * dcdZc * dZcdr = dLdr = (dLdc * dcdZc * hLast) × WchT
    // drdx = drdZr * dZrdx
    // dZrdx = ... × WrxT
    // dLdx_c1 = dLdc * dcdZc * dZcdr * drdx = dLdr * drdx = (dLdr * drdZr) × WrxT
    // finally dLdx = dLdx_u + dLdx_c0 + dLdx_c1 = (dLdu * dudZu) × WuxT + (dLdc * dcdZc) × WcxT + (dLdr * drdZr) × WrxT


    // dLdhLast    = dLdh * (dhdhLast + dhdu * dudhLast + dhdc * dcdhLast) = dLdh * dhdhLast + dLdu * dudhLast + dLdc * dcdhLast
    //             = dLdhLast_h + dLdhLast_u + dLdhLast_c
    // dLdhLast_h  = dLdh * dhdhLast = dLdh * u
    // dLdhLast_u  = dLdu * dudhLast = |dudhLast = dudZu * dZudhLast , dZudhLast = ... × WuhT| = (dLdu * dudZu) × WuhT
    // dLdhLast_c  = dLdc * dcdhLast  = dLdc * (dcdZc * dZcdhLast + dcdZc * dZcdr * drdhLast) =
    //             = dLdc * dcdZc * dZcdhLast + dLdc * dcdZc * dZcdr * drdhLast =
    //             = dLdc * dcdZc * dZcdhLast + dLdr * drdhLast = dLdhLast_c0 + dLdhLast_c1
    // dLdhLast_c0 = dLdc * dcdZc * dZcdhLast = |dZcdhLast = (... * r) × WchT| = (dLdc * dcdZc * r) × WchT
    // dLdhLast_c1 = dLdr * drdhLast = |drdhLast  = drdZr * dZrdhLast, dZrdhLast = ... × WrhT| = (dLdr * drdZr) × WrhT
    // finally dLdhLast = dLdhLast_h + dLdhLast_u + dLdhLast_c0 + dLdhLast_c1 =
    //                  = dLdh * u + (dLdu * dudZu) × WuhT + (dLdc * dcdZc * r) × WchT + (dLdr * drdZr) × WrhT


    // dLdWrx = dLdh * dhdWrx = (dLdh * dhdc) * dcdWrx = dLdc * dcdZc * dZcdWrx = dLdc * dcdZc * dZcdr * drdWrx =
    //        = dLdc * dcdZc * dZcdr * drdZr * dZrdWrx = dLdr * drdZr * dZrdWrx
    // dZrdWrx = xT × ...
    // finally dLdWrx = xT × (dLdr * drdZr)


    // dLdWrh = dLdh * dhdWrh = (dLdh * dhdc) * dcdWrh = dLdc * dcdZc * dZcdWrh = dLdc * dcdZc * dZcdr * drdWrh =
    //        = dLdc * dcdZc * dZcdr * drdZr * dZrdWrh = dLdr * drdZr * dZrdWrh
    // dZrdWrh = hLastT × ...
    // finally dLdWrh = hLastT × (dLdr * drdZr)


    // dLdWux = dLdh * dhdWux = (dLdh * dhdu) * dudWux = dLdu * dudZu * dZudWux
    // dZudWux = xT × ...
    // dLdu * dudZu * dZudWux = xT × (dLdu * dudZu)


    // dLdWuh = dLdh * dhdWuh = (dLdh * dhdu) * dudWuh = dLdh * dhdu * dudZu * dZudWuh = dLdu * dudZu * dZudWuh
    // dZudWuh = hLastT × ...
    // finally dLdWuh = hLastT × (dLdu * dudZu)


    // dLdWcx = dLdh * dhdWcx = dLdh * dhdc * dcdWcx = (dLdh * dhdc) * dcdZc * dZcdWcx = dLdc * dcdZc * dZcdWcx
    // dZcdWcx = xT × ...
    // finally dLdWcx = xT × (dLdc * dcdZc)


    // dLdWch = dLdh * dhdWch = dLdh * dhdc * dcdWch = (dLdh * dhdc) * dcdZc * dZcdWch = dLdc * dcdZc * dZcdWch
    // dZcdWch = (r*hLast)^T × ...
    // finally dLdWch = (r*hLast)^T × (dLdc * dcdZc)


    // dLdbr = dLdh * dhdbr = (dLdh * dhdc) * dcdbr = dLdc * dcdbr = dLdc * dcdZc * dZcdbr = dLdc * dcdZc * dZcdr * drdbr =
    //       = dLdr * drdZr * dZrdbr
    // dZrdbr = 1
    // finally dLdbr = dLdr * drdZr


    // dLdbu = dLdh * dhdbu = (dLdh * dhdu) * dudbu = dLdu * dudZu * dZudbu
    // dZudbu = 1
    // finally dLdbu = dLdu * dudZu


    // dLdbc = dLdh * dhdbc = (dLdh * dhdc) * dcdbc = dLdc * dcdZc * dZcdbc
    // dZcdbc = 1
    // finally dLdbc = dLdc * dcdZc

    // local derivatives of the activations, all [bS, nU]
    NDArray dhdc  = 1.f - u;            // dh/dc; also the (1 - u) factor of the sigmoid derivative below
    NDArray dudZu = u * dhdc;           // sigmoid' = u * (1 - u)
    NDArray drdZr = r * (1.f - r);      // sigmoid' = r * (1 - r)
    NDArray dcdZc = 1.f - c * c;        // tanh'    = 1 - c^2

    // gradients wrt the pre-activations, all [bS, nU]
    NDArray dLdZc = *dLdc * dcdZc;
    NDArray dLdZu = *dLdu * dudZu;
    NDArray dLdZr = *dLdr * drdZr;

    // The gate gradients dLdc, dLdu, dLdr are taken from the function arguments;
    // the expressions below (kept for reference) are not evaluated here:
    // NDArray dLdc  = *dLdh * dhdc;                       // [bS, nU]
    // NDArray dLdu  = *dLdh * dhdu;                       // [bS, nU]
    // NDArray dLdr  = mmul(dLdc * dcdZc * *hLast, WchT);  // [bS, nU]

    dLdx->assign(mmul(dLdZu, WuxT) + mmul(dLdZc, WcxT) + mmul(dLdZr, WrxT));                        // [bS, iS]

    dLdhLast->assign(*dLdh * u + mmul(dLdZu, WuhT) + mmul(dLdZc * r, WchT) + mmul(dLdZr, WrhT));    // [bS, nU]

    dLdWrx.assign(mmul(xT,     dLdZr));     // [iS, bS] × [bS, nU] = [iS, nU]
    dLdWrh.assign(mmul(hLastT, dLdZr));     // [nU, bS] × [bS, nU] = [nU, nU]
    dLdWux.assign(mmul(xT,     dLdZu));     // [iS, bS] × [bS, nU] = [iS, nU]
    dLdWuh.assign(mmul(hLastT, dLdZu));     // [nU, bS] × [bS, nU] = [nU, nU]

    dLdWcx.assign(mmul(xT, dLdZc));                             // [iS, bS] × [bS, nU] = [iS, nU]
    dLdWch.assign(mmul((r * *hLast).transpose(), dLdZc));       // [nU, bS] × [bS, nU] = [nU, nU]

    // bias gradients: sum the pre-activation gradients over the batch dimension
    dLdbr.assign(dLdZr.reduceAlongDimension(reduce::Sum, {0}));     // [nU]
    dLdbu.assign(dLdZu.reduceAlongDimension(reduce::Sum, {0}));     // [nU]

    dLdbc->assign(dLdZc.reduceAlongDimension(reduce::Sum, {0}));    // [nU]
}
|
|||
|
|
|||
|
|
|||
|
//////////////////////////////////////////////////////////////////////////
|
|||
|
void gruCellBp(sd::LaunchContext* context,
               const NDArray* x, const NDArray* hI, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh, const NDArray* gates,
               NDArray* dLdx, NDArray* dLdhI, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb) {

    // Backprop through a single GRU time step for the fused Wx/Wh weight layout.
    // The gate activations are not recomputed: they are read from gates.
    // dLdx is overwritten, dLdhI is read and then overwritten, while dLdWx, dLdWh and dLdb are accumulated into (+=).
    // The context and b arguments are not referenced by this implementation.

    //Inputs:
    // x              input [bS, nIn]
    // hI             previous cell output [bS, nOut],  that is at previous time step t-1
    // Wx             input-to-hidden  weights - [nIn, 3*nOut]
    // Wh             hidden-to-hidden weights - [nOut, 3*nOut]
    // b              biases, [3*nOut] - reset, update and cell gates
    // dLdh           gradient vs. ff output, [bS, nOut]; may be null, in which case nothing is added to dLdhI
    // gates          gate activations [bS, 3*nOut], read as reset, update, cell (nOut columns each)

    //Outputs:
    // dLdx           gradient vs. x,  [bS, nIn],
    // dLdhI          gradient vs. hI, [bS, nOut]; on entry it holds the gradient arriving from the next time step
    // dLdWx          gradient vs. Wx, [nIn, 3*nOut]
    // dLdWh          gradient vs. Wh, [nOut, 3*nOut]
    // dLdb           gradient vs. b   [3*nOut]

    // 3*nOut means following sequence: reset, update, cell

    // * means element-wise product or so called Hadamard product
    // × means matrix multiplication

    // formulas:
    // zr = x × Wxr + hI × Whr + br
    // zu = x × Wxu + hI × Whu + bu
    // r = sigmoid(zr)
    // u = sigmoid(zu)
    // zc = x × Wxc + (r * hI) × Whc + bc
    // c = tanh(zc)
    // h = (1-u)*c + u*hI

    // dLdhI += dLdh;                       [bS, nOut]


    // dhdc  = 1 - u                        [bS, nOut]
    // dhdu  = -c + hI                      [bS, nOut]

    // dcdzc = 1 - c*c;                     [bS, nOut]
    // dudzu = u*(1-u)                      [bS, nOut]
    // drdzr = r(1-r)                       [bS, nOut]

    // dzcdr = (...*hI × WhcT)              [bS, nOut]

    // dLdzr = dLdh*dhdc*dcdzc*dzcdr*drdzr = (dLdzc*hI*r(1-r) × WhcT);    [bS, nOut]
    // dLdzu = dLdh*dhdu*dudzu = dLdh*(hI-c)*u*(1-u)                       [bS, nOut]
    // dLdzc = dLdh*dhdc*dcdzc = dLdh*(1-u)*(1-c*c)                        [bS, nOut]

    // dLdx   = dLdzr × WxrT + dLdzu × WxuT + dLdzc × WxcT,     [bs, nOut] × [nOut, nIn] + ... = [bS, nIn]

    // dLdhI  = dLdzr × WhrT + dLdzu × WhuT + dLdzc × WhcT,     [bs, nOut] × [nOut, nOut] + ... = [bS, nOut]

    // dLdWxr = xT × dLdzr                                      [nIn, bS] x [bS, nOut] = [nIn, nOut]
    // dLdWxu = xT × dLdzu                                      [nIn, bS] x [bS, nOut] = [nIn, nOut]
    // dLdWxc = xT × dLdzc                                      [nIn, bS] x [bS, nOut] = [nIn, nOut]

    // dLdWhr = hIT × dLdzr                                     [nOut, bS] x [bS, nOut] = [nOut, nOut]
    // dLdWhu = hIT × dLdzu                                     [nOut, bS] x [bS, nOut] = [nOut, nOut]
    // dLdWhc = hIT × (dLdzc * r)                               [nOut, bS] x [bS, nOut] = [nOut, nOut]   (as computed below)

    // dLdbr  = dLdzr.reduce_sum_along_0_axis                   [bS, nOut] -> reduce -> [nOut]
    // dLdbu  = dLdzu.reduce_sum_along_0_axis                   [bS, nOut] -> reduce -> [nOut]
    // dLdbc  = dLdzc.reduce_sum_along_0_axis                   [bS, nOut] -> reduce -> [nOut]

    // NOTE(review): the gates/h overload of gruCell in this file computes zc = x × Wxc + r * (hI × Whc) + bc,
    // not the (r * hI) × Whc form quoted above. The dLdhI and dLdWh computations below are consistent with the
    // r * (hI × Whc) form, but for that form one would expect dLdzr = dLdzc * (hI × Whc) * r*(1-r), which is not
    // what the dLdzr code below evaluates - please confirm against a numerical gradient check.

    const int nOut = hI->sizeAt(1);

    // buffer for the gradients wrt all three pre-activations, filled through the sub-arrays below
    NDArray dLdz = gates->ulike();                      // [bS, 3*nOut]

    NDArray dLdzr = dLdz({0,0, 0,nOut});                // [bS, nOut]
    NDArray dLdzu = dLdz({0,0, nOut,2*nOut});           // [bS, nOut]
    NDArray dLdzc = dLdz({0,0, 2*nOut,3*nOut});         // [bS, nOut]

    // forward gate activations
    NDArray r = (*gates)({0,0, 0,nOut});                // [bS, nOut]
    NDArray u = (*gates)({0,0, nOut,2*nOut});           // [bS, nOut]
    NDArray c = (*gates)({0,0, 2*nOut,3*nOut});         // [bS, nOut]

    NDArray WhcT = (*Wh)({0,0, 2*nOut,3*nOut}).transpose();

    // total gradient wrt this step's output = gradient carried in dLdhI + this step's dLdh (when given)
    if(dLdh)
        *dLdhI += *dLdh;

    NDArray temp1 = 1 - u;                              // [bS, nOut]

    // dLdzc
    dLdzc.assign(*dLdhI * temp1 * (1-c*c));             // [bS, nOut]

    // dLdzu
    dLdzu.assign(*dLdhI * (*hI - c) * u * temp1);       // [bS, nOut]

    // dLdzr (see NOTE(review) above)
    NDArray temp2 = dLdzc * (*hI) * r *(1-r);
    MmulHelper::mmul(&temp2, &WhcT, &dLdzr);            // [bS, nOut] x [nOut, nOut] = [bS, nOut]

    // dLdx
    NDArray WxT = Wx->transpose();
    MmulHelper::mmul(&dLdz, &WxT, dLdx);                // [bS, 3*nOut] x [3*nOut, nIn] = [bS, nIn]

    // dLdWx
    *dLdWx += mmul(x->transpose(), dLdz);               // [nIn, bS] x [bS, 3*nOut] = [nIn, 3*nOut]

    // dLdb
    *dLdb += dLdz.reduceAlongDimension(reduce::Sum, {0});   // [bS, 3*nOut] -> reduce -> [3*nOut];

    // From here on the cell part of dLdz is scaled by r; this must stay AFTER dLdx, dLdWx and dLdb,
    // which use the unscaled dLdzc, and BEFORE dLdhI and dLdWh, which use the scaled one.
    dLdzc *= r;

    // dLdhI
    NDArray WhT = Wh->transpose();
    dLdhI->assign(*dLdhI*u + mmul(dLdz, WhT));          // [bS, 3*nOut] x [3*nOut, nOut] = [bS, nOut]

    // dLdWh
    *dLdWh += mmul(hI->transpose(), dLdz);              // [nOut, bS] x [bS, 3*nOut] = [nOut, 3*nOut]
}
|
|||
|
|
|||
|
|
|||
|
//////////////////////////////////////////////////////////////////////////
|
|||
|
void gruTimeLoopBp(sd::LaunchContext * context,
                   const NDArray* x, const NDArray* hI, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh,
                   NDArray* dLdx, NDArray* dLdhI, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb) {

    // Backprop through a whole GRU sequence (sL = number of time steps): the forward pass is replayed to
    // cache the gates and hidden states of every step, then the cell backprop is run from the last step to the first.
    //
    // Inputs:
    // x      input [sL, bS, nIn]
    // hI     initial cell output (at time step = 0) [bS, nOut]
    // Wx     input-to-hidden  weights, [nIn, 3*nOut]
    // Wh     hidden-to-hidden weights, [nOut, 3*nOut]
    // b      biases, [3*nOut]
    // dLdh   gradient vs. ff output, [sL, bS, nOut]
    //
    // Outputs:
    // dLdx   gradient vs. x,  [sL, bS, nIn],
    // dLdhI  gradient vs. hI, [bS, nOut]
    // dLdWx  gradient vs. Wx, [nIn, 3*nOut]
    // dLdWh  gradient vs. Wh, [nOut, 3*nOut]
    // dLdb   gradient vs. b   [3*nOut]

    const int numSteps = x->sizeAt(0);
    const int batch    = x->sizeAt(1);
    const int units    = hI->sizeAt(1);

    // the caches share x's ordering and context and dLdh's data type
    const auto order = x->ordering();
    const auto dtype = dLdh->dataType();
    auto ctx         = x->getContext();

    NDArray gateCache(order, {numSteps,   batch, 3*units}, dtype, ctx);    // gates of every step
    NDArray hCache   (order, {numSteps+1, batch, units},   dtype, ctx);    // slot 0 = hI, slot t+1 = output of step t

    auto xSteps    = x->allTensorsAlongDimension({1,2});            // sub-arrays with shape [bS, nIn]
    auto dLdhSteps = dLdh->allTensorsAlongDimension({1,2});         // sub-arrays with shape [bS, nOut]
    auto hSteps    = hCache.allTensorsAlongDimension({1,2});        // sub-arrays with shape [bS, nOut]
    auto gateSteps = gateCache.allTensorsAlongDimension({1,2});     // sub-arrays with shape [bS, 3*nOut]
    auto dLdxSteps = dLdx->allTensorsAlongDimension({1,2});         // sub-arrays with shape [bS, nIn]

    // seed the hidden state cache with the initial state
    hSteps.at(0)->assign(hI);

    // forward time loop: step t reads hidden slot t and writes slot t+1
    for (int step = 0; step < numSteps; ++step) {
        gruCell(context, xSteps.at(step), hSteps.at(step), Wx, Wh, b, gateSteps.at(step), hSteps.at(step + 1));
    }

    // backward time loop: last step first; dLdhI, dLdWx, dLdWh and dLdb are passed to every step
    for (int step = numSteps; step-- > 0; ) {
        gruCellBp(context, xSteps.at(step), hSteps.at(step), Wx, Wh, b, dLdhSteps.at(step), gateSteps.at(step),
                  dLdxSteps.at(step), dLdhI, dLdWx, dLdWh, dLdb);
    }
}
|
|||
|
|
|||
|
|
|||
|
}
|
|||
|
}
|
|||
|
}
|