// File: cavis/libnd4j/include/ops/declarable/helpers/cpu/sru.cpp
// (401 lines, 17 KiB, C++)
//
// Note: the source viewer warned that this file contains ambiguous Unicode characters
// that may be confused with others in your current locale (e.g. '◦', used in comments
// below to denote the element-wise / Hadamard product).

/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// implementation of operations for Simple Recurrent Unit: arXiv:1709.02755v2 [cs.CL] 12 Sep 2017
//
// @author Yurii Shyrma, created on 05.12.2017
//
#include<ops/declarable/helpers/sru.h>
#include <NDArrayFactory.h>
#include <memory>
namespace nd4j {
namespace ops {
namespace helpers {
//////////////////////////////////////////////////////////////////////////
static FORCEINLINE NDArray activation(const NDArray& arr) {
    // Element-wise tanh of `arr`, returned as a new array.
    // The output is allocated from `arr` as a template, then the Tanh transform is
    // applied from `arr` into it. applyTransform is not a const member, hence the cast.
    NDArray tanhOut(&arr, false, arr.getContext());
    NDArray& source = const_cast<NDArray&>(arr);
    source.applyTransform(transform::Tanh, &tanhOut);
    return tanhOut;
}
//////////////////////////////////////////////////////////////////////////
static FORCEINLINE NDArray sigmoid(const NDArray& arr) {
    // Element-wise logistic sigmoid of `arr`, returned by value.
    // transform() is a non-const member, which is the only reason for the cast.
    NDArray& source = const_cast<NDArray&>(arr);
    return source.transform(transform::Sigmoid);
}
//////////////////////////////////////////////////////////////////////////
void sruCell(nd4j::LaunchContext * context, const NDArray* x, const NDArray* c0, const NDArray* w, const NDArray* b, NDArray* h, NDArray* c) {
    // One step of the Simple Recurrent Unit (arXiv:1709.02755).
    //   x  - input                          [bS x inSize]  (bS - batch size, inSize - number of features)
    //   c0 - cell state at step t-1         [bS x inSize]
    //   w  - weights                        [inSize x 3*inSize], column blocks: candidate | forget | reset
    //   b  - biases                         [2*inSize], halves: forget | reset
    //   h  - output at step t (written)     [bS x inSize]
    //   c  - cell state at step t (written) [bS x inSize]
    // `context` is not used by this CPU implementation.
    const int nIn = x->sizeAt(1);

    // A single matmul yields all three projections side by side: [bS x 3*inSize]
    auto proj = mmul(*x, *w);

    auto projCand   = proj({0,0, 0,     nIn});
    auto projForget = proj({0,0, nIn,   2*nIn});
    auto projReset  = proj({0,0, 2*nIn, 3*nIn});

    // forget gate f = sigmoid(x*Wf + bf), reset gate r = sigmoid(x*Wr + br)
    auto forgetGate = sigmoid(projForget + (*b)({0, nIn}));
    auto resetGate  = sigmoid(projReset  + (*b)({nIn, 2*nIn}));

    // cell state: c = f * c0 + (1 - f) * (x*Wc), all products element-wise
    c->assign(forgetGate * (*c0) + (1.f - forgetGate) * projCand);

    // output: h = r * tanh(c) + (1 - r) * x, element-wise (highway connection to x)
    h->assign(resetGate * activation(*c) + (1.f - resetGate) * (*x));
}
//////////////////////////////////////////////////////////////////////////
void sruTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray* c0, const NDArray* w, const NDArray* b, NDArray* h, NDArray* c) {
    // Runs sruCell over every time step of a sequence, carrying the cell state forward.
    //   x  - input                          [bS x inSize x time]
    //   c0 - initial cell state (step 0)    [bS x inSize]
    //   w  - weights                        [3*inSize x inSize]
    //   b  - biases                         [2*inSize]
    //   h  - cell outputs (written)         [bS x inSize x time]
    //   c  - cell states (written)          [bS x inSize x time]

    // sruCell expects weights as [inSize x 3*inSize]. transpose() hands back a
    // heap-allocated array that this function owns; holding it in a unique_ptr
    // releases it on every exit path, including when a step throws.
    const std::unique_ptr<NDArray> wT(w->transpose());

    const int time = x->sizeAt(2);

    // cell state of the previous step, seeded with c0
    NDArray ct_1(*c0);

    for (int t = 0; t < time; ++t) {
        auto xt = (*x)({0,0, 0,0, t,t+1});
        auto ht = (*h)({0,0, 0,0, t,t+1});
        auto ct = (*c)({0,0, 0,0, t,t+1});
        helpers::sruCell(context, &xt, &ct_1, wT.get(), b, &ht, &ct);
        ct_1.assign(ct);
    }
}
//////////////////////////////////////////////////////////////////////////
// Bidirectional SRU forward pass over raw buffers, processed one (batch, feature) column at a time.
// Within each 2*inSize feature row, features [0, inSize) are stepped forward in time (t = 0 .. time-1)
// and features [inSize, 2*inSize) backward in time (t = time-1 .. 0) -- see `flip`.
// NOTE(review): all pointer arithmetic below assumes x, wi, b, c0, mask, ht and ct are dense
// c-order buffers; no ordering/stride check is done here -- confirm with callers.
template <typename T>
static void sruBI_(NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* mask, NDArray* ht, NDArray* ct) {
// x input 3d tensor [time x bS x 2*inSize], time - number of time steps, bS - batch size, inSize - number of features
// w 2d tensor of weights [2*inSize x 6*inSize]
// b row of biases with twice length [1 x 4*inSize]: forget biases at [0, 2*inSize), reset biases at [2*inSize, 4*inSize)
// c0 2d tensor of initial state [bS x 2*inSize] at time t=0
// mask optional, 2d tensor of dropout mask [bS x 2*inSize]
// ht [time x bS x 2*inSize], written: cell outputs
// ct [time x bS x 2*inSize], written: cell states
const Nd4jLong time = x->sizeAt(0); // time - number of time steps
const Nd4jLong bS = x->sizeAt(1); // bS - batch size
const Nd4jLong inSize = x->sizeAt(2) / 2; // inSize - number of features
// x = x * mask, mask matched to dims {1, 2}; done in place, so the caller's x is modified
if(mask)
x->applyBroadcast(broadcast::Multiply, {1, 2}, mask, x, nullptr); // apply mask
// U = x * w
NDArray wi = mmul(*x, *w); // U [time x bS x 6*inSize]
const Nd4jLong d2 = 2*inSize; // features per row (both directions)
const Nd4jLong ncols = bS*d2; // elements per time step of x / ht / ct
const Nd4jLong ncolsWi = 3*ncols; // elements per time step of wi
T* pI = x->bufferAsT<T>();
T* pWi = wi.bufferAsT<T>();
T* pBias = const_cast<NDArray*>(b)->bufferAsT<T>();
T* pInit = const_cast<NDArray*>(c0)->bufferAsT<T>();
T* pMask = mask ? const_cast<NDArray*>(mask)->bufferAsT<T>() : nullptr;
T* pHt = ht->bufferAsT<T>();
T* pCt = ct->bufferAsT<T>();
Nd4jLong ncolsRev, ncolsWiRev; // for reverse direction
T maskVal, cur, bF, bR, ft, rt, val;
T *pIVal(nullptr), *pWiVal(nullptr), *pHtVal(nullptr), *pCtVal(nullptr);
bool flip = false;
// each column col = batch*d2 + feature is an independent recurrence over time
for (Nd4jLong col = 0; col < ncols; ++col) {
const auto colNum = col % d2; // feature index within the row
flip = colNum >= inSize; // second half of the row: backward direction
maskVal = mask ? *(pMask + col) : T(1);
cur = *(pInit + col); // running cell state, seeded from c0
bF = *(pBias + colNum);
bR = *(pBias + colNum + d2);
// wi holds 3 consecutive values per column: [candidate, forget pre-activation, reset pre-activation]
pWiVal = pWi + 3*col;
pIVal = pI + col;
pHtVal = pHt + col;
pCtVal = pCt + col;
// backward-direction features start at the last time step
if (flip) {
pIVal += (time-1)*ncols;
pWiVal += (time-1)*ncolsWi;
pHtVal += (time-1)*ncols;
pCtVal += (time-1)*ncols;
}
// signed per-step strides: negative when walking backward in time
ncolsRev = flip ? -ncols : ncols;
ncolsWiRev = flip ? -ncolsWi : ncolsWi;
for (Nd4jLong t = 0; t < time; ++t) {
// evaluate sigmoids
ft = (1.)/(1. + nd4j::math::nd4j_exp<T, T>(-(*(pWiVal + 1) + bF)));
rt = (1.)/(1. + nd4j::math::nd4j_exp<T, T>(-(*(pWiVal + 2) + bR)));
// c_t = ft*c_{t-1} + (1 - ft)*candidate, computed as (c_{t-1} - candidate)*ft + candidate
cur = (cur - *pWiVal)*ft + *pWiVal;
*pCtVal = cur;
val = nd4j::math::nd4j_tanh<T, T>(cur);
// h_t = rt*(tanh(c_t)*mask) + (1 - rt)*x_t, computed in factored form
*pHtVal = (val*maskVal - *pIVal)*rt + *pIVal;
pIVal += ncolsRev;
pWiVal += ncolsWiRev;
pCtVal += ncolsRev;
pHtVal += ncolsRev;
}
}
}
//////////////////////////////////////////////////////////////////////////
// Bidirectional SRU backward pass (gradients of sruBI_), processed one (batch, feature) column at a time.
// Each column walks time in the opposite order to its forward pass: forward-direction features
// (first half of each 2*inSize row) from t = time-1 down to 0, backward-direction features
// (second half) from t = 0 up to time-1.
// NOTE(review): the pointer arithmetic assumes x, wi, b, c0, ct, inGradC0, inGradHt, mask, gradI,
// gradC0 and the local scratch arrays are dense c-order buffers -- confirm with callers.
template <typename T>
static void sruBIBP_(NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* ct, const NDArray* inGradC0, const NDArray* inGradHt, const NDArray* mask,
                     NDArray* gradI, NDArray* gradW, NDArray* gradB, NDArray* gradC0) {
    // x        input 3d tensor [time x bS x 2*inSize], bS - batch size, inSize - number of features
    // w        weights [2*inSize x 6*inSize]
    // b        biases [1 x 4*inSize]: forget biases [0, 2*inSize), reset biases [2*inSize, 4*inSize)
    // c0       initial cell state [bS x 2*inSize]
    // ct       cell states from the forward pass [time x bS x 2*inSize]
    // inGradC0 [bS x 2*inSize], seeds the running cell-state gradient of every column
    // inGradHt incoming gradient wrt the outputs [time x bS x 2*inSize]
    // mask     optional dropout mask [bS x 2*inSize]
    // gradI    (written) [time x bS x 2*inSize]
    // gradW    (written) [time x 2*inSize x 6*inSize]
    // gradB    (written) [1 x 4*inSize]
    // gradC0   (written) gradient wrt c0 [bS x 2*inSize]
    const Nd4jLong time   = x->sizeAt(0);      // number of time steps
    const Nd4jLong bS     = x->sizeAt(1);
    const Nd4jLong inSize = x->sizeAt(2) / 2;

    // x = x * mask (in place: the caller's x is modified)
    if (mask)
        x->applyBroadcast(broadcast::Multiply, {1, 2}, mask, x, nullptr);

    // U = x * w : [time x bS x 2*inSize] * [2*inSize x 6*inSize] = [time x bS x 6*inSize]
    NDArray wi = mmul(*x, *w);

    // Per-batch bias gradients: row `batch` holds [gbF (2*inSize) | gbR (2*inSize)] for that
    // batch entry, so summing over dim 0 below gives the [1 x 4*inSize] bias gradient.
    // Every element is written exactly once in the column loop.
    NDArray gradBias(x->ordering(), {bS, 4*inSize}, x->dataType(), x->getContext());
    NDArray gradWi  (x->ordering(), {time, bS, 6*inSize}, x->dataType(), x->getContext());

    const Nd4jLong d2      = 2*inSize;   // features per row (both directions)
    const Nd4jLong ncols   = bS*d2;      // elements per time step of x / ct / gradI
    const Nd4jLong ncolsWi = 3*ncols;    // elements per time step of wi / gradWi

    T* pInput     = x->bufferAsT<T>();
    T* pWi        = wi.bufferAsT<T>();
    T* pBias      = const_cast<NDArray*>(b)->bufferAsT<T>();
    T* pInit      = const_cast<NDArray*>(c0)->bufferAsT<T>();
    T* pMask      = mask ? const_cast<NDArray*>(mask)->bufferAsT<T>() : nullptr;
    T* pState     = const_cast<NDArray*>(ct)->bufferAsT<T>();
    T* pInGradCt  = const_cast<NDArray*>(inGradC0)->bufferAsT<T>();
    T* pInGradHt  = const_cast<NDArray*>(inGradHt)->bufferAsT<T>();
    T* pGradWi    = gradWi.bufferAsT<T>();
    T* pGradInput = gradI->bufferAsT<T>();
    T* pGradBias  = gradBias.bufferAsT<T>();
    T* pGradInit  = gradC0->bufferAsT<T>();

    const T one  = static_cast<T>(1.);
    const T zero = static_cast<T>(0.);

    // each column col = batch*d2 + feature is an independent recurrence over time
    for (Nd4jLong col = 0; col < ncols; ++col) {
        const Nd4jLong batchIdx = col / d2;   // batch entry of this column
        const Nd4jLong colNum   = col % d2;   // feature index within the row
        const bool flip = colNum >= inSize;   // second half of the row: backward direction

        const T maskVal = mask ? pMask[col] : one;
        const T bF = pBias[colNum];
        const T bR = pBias[colNum + d2];

        T gbF = zero;
        T gbR = zero;
        T cur = pInGradCt[col];               // running gradient wrt the cell state

        // wi / gradWi hold 3 consecutive values per column: [candidate, forget, reset]
        T* pWiVal        = pWi + 3*col;
        T* pInputVal     = pInput + col;
        T* pStateVal     = pState + col;
        T* pInGradHtVal  = pInGradHt + col;
        T* pGradWiVal    = pGradWi + 3*col;
        T* pGradInputVal = pGradInput + col;

        // forward-direction features are back-propagated starting from the last time step
        if (!flip) {
            pInputVal     += (time-1)*ncols;
            pWiVal        += (time-1)*ncolsWi;
            pStateVal     += (time-1)*ncols;
            pInGradHtVal  += (time-1)*ncols;
            pGradWiVal    += (time-1)*ncolsWi;
            pGradInputVal += (time-1)*ncols;
        }
        // pointers are advanced by subtracting these: backwards in time for !flip, forwards for flip
        const Nd4jLong ncolsRev   = flip ? -ncols   : ncols;
        const Nd4jLong ncolsWiRev = flip ? -ncolsWi : ncolsWi;

        for (Nd4jLong t = 0; t < time; ++t) {
            // recompute the gates of this step
            const T ft  = one/(one + nd4j::math::nd4j_exp<T,T>(-(*(pWiVal + 1) + bF)));
            const T rt  = one/(one + nd4j::math::nd4j_exp<T,T>(-(*(pWiVal + 2) + bR)));
            const T val = nd4j::math::nd4j_tanh<T,T>(*pStateVal);
            // cell state this step started from in the forward pass (c0 for the first step)
            const T prevVal = (t < time-1) ? *(pStateVal - ncolsRev) : pInit[col];

            // grad wrt input through the highway path: dh/dx = 1 - rt
            // NOTE(review): the path through U = x*w (gradWi * w^T) is not added to gradI here -- confirm intended.
            *pGradInputVal = *pInGradHtVal - (*pInGradHtVal)*rt;

            // grad wrt rt, wi2 and bR
            const T grt = (*pInGradHtVal) * (val*maskVal - *pInputVal) * (rt - rt*rt);
            *(pGradWiVal + 2) = grt;
            gbR += grt;

            // total grad wrt this step's cell state: via h (dh/dc = rt*mask*(1 - tanh^2)) plus carried gradient
            const T gradStateVal = (*pInGradHtVal) * maskVal * (rt - rt*val*val) + cur;

            // grad wrt wi0 (candidate): dc/dwi0 = 1 - ft
            *pGradWiVal = gradStateVal - gradStateVal*ft;

            // grad wrt ft, wi1 and bF
            const T gft = gradStateVal * (prevVal - *pWiVal) * (ft - ft*ft);
            *(pGradWiVal + 1) = gft;
            gbF += gft;

            // grad wrt the previous cell state
            cur = gradStateVal * ft;

            pInputVal     -= ncolsRev;
            pWiVal        -= ncolsWiRev;
            pStateVal     -= ncolsRev;
            pGradWiVal    -= ncolsWiRev;
            pGradInputVal -= ncolsRev;
            pInGradHtVal  -= ncolsRev;
        }

        // row `batchIdx` of gradBias: forget-bias gradient at colNum, reset-bias gradient at d2 + colNum
        T* pGradBiasRow = pGradBias + batchIdx * 2*d2;
        pGradBiasRow[colNum]      = gbF;
        pGradBiasRow[d2 + colNum] = gbR;

        pGradInit[col] = cur;
    }

    // gradB: sum over the batch axis, [bS x 4*inSize] -> [1 x 4*inSize]
    gradBias.reduceAlongDimension(reduce::Sum, gradB, {0}, false, true);

    // gradW
    x->permutei({0, 2, 1});       // [time x bS x 2*inSize] -> [time x 2*inSize x bS] (permutes the caller's x)
    *gradW = mmul(*x, gradWi);    // [time x 2*inSize x bS] * [time x bS x 6*inSize] = [time x 2*inSize x 6*inSize]
}
// Type-dispatching entry point for the bidirectional SRU forward pass: calls sruBI_<T>
// with T selected from FLOAT_TYPES by x->dataType(). `context` is not forwarded to sruBI_.
void sruBI(nd4j::LaunchContext * context, NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* mask, NDArray* ht, NDArray* ct) {
BUILD_SINGLE_SELECTOR(x->dataType(), sruBI_, (x, w, b, c0, mask, ht, ct), FLOAT_TYPES);
}
// Type-dispatching entry point for the bidirectional SRU backward pass: calls sruBIBP_<T>
// with T selected from FLOAT_TYPES by x->dataType(). `context` is not forwarded to sruBIBP_.
void sruBIBP(nd4j::LaunchContext * context, NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* ct, const NDArray* inGradC0, const NDArray* inGradH, const NDArray* mask, NDArray* gradI, NDArray* gradW, NDArray* gradB, NDArray* gradC0) {
BUILD_SINGLE_SELECTOR(x->dataType(), sruBIBP_, (x, w, b, c0, ct, inGradC0, inGradH, mask, gradI, gradW, gradB, gradC0), FLOAT_TYPES);
}
// Explicit instantiations of the typed forward/backward kernels for the types listed in FLOAT_TYPES;
// the signatures here must match sruBI_ and sruBIBP_ exactly.
BUILD_SINGLE_TEMPLATE(template void sruBI_, (NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* mask, NDArray* ht, NDArray* ct), FLOAT_TYPES);
BUILD_SINGLE_TEMPLATE(template void sruBIBP_, (NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* ct, const NDArray* inGradC0, const NDArray* inGradH, const NDArray* mask, NDArray* gradI, NDArray* gradW, NDArray* gradB, NDArray* gradC0), FLOAT_TYPES);
}
}
}
//////////////////////////////////////////////////////////////////////////
// template <typename T>
// void sruCellBP(const std::vector<NDArray<T>*>& inArrs, const std::vector<NDArray<T>*>& outArrs) {
// NDArray<T>* x = inArrs[0]; // input [bS x inSize], bS - batch size, inSize - number of features
// NDArray<T>* c0 = inArrs[1]; // previous cell state c [bS x inSize], that is at previous time step t-1
// NDArray<T>* w = inArrs[2]; // weights [inSize x 3*inSize]
// NDArray<T>* b = inArrs[3]; // biases [2*inSize]
// NDArray<T>* dLdC = inArrs[4]; // gradient of the loss func with respect to cell state [bS x inSize]
// NDArray<T>* dLdH = inArrs[5]; // gradient of the loss func with respect to cell output [bS x inSize]
// NDArray<T>* dLdX = outArrs[0]; // gradient of the loss func with respect to input [bS x inSize], so called epsilon
// NDArray<T>* dLdW = outArrs[1]; // gradient of the loss func with respect to weights [inSize x 3*inSize]
// NDArray<T>* dLdB = outArrs[2]; // gradient of the loss func with respect to biases [2*inSize]
// NDArray<T>* dLdC0 = outArrs[3]; // gradient of the loss func with respect to previous cell state [bS, inSize]
// const int inSize = x->sizeAt(1); // inSize - number of features
// //*********** feed forward ***********//
// NDArray<T> z = mmul(*x, *w); // [bS x 3*inSize]
// // forget gate = sigmoid(x*Wf + bf)
// NDArray<T> f = sigmoid<T>(z({{},{inSize, 2*inSize}}) + (*b)({{0, inSize}})); // [bS, inSize]
// NDArray<T> oneMinusF = 1. - f;
// // reset gate = sigmoid(x*Wr + br)
// NDArray<T> r = sigmoid<T>(z({{},{2*inSize, 3*inSize}}) + (*b)({{inSize, 2*inSize}})); // [bS, inSize]
// NDArray<T> oneMinusR = 1. - r;
// // current cell state = f◦c0 + (1 - f)◦(x*Wc) ---> c->assign( f*(*c0) + ((T)1. - f) * z({{},{0, inSize}}) );
// // current cell output = r◦activation(c) + (1 - r)◦x ---> h->assign( r*activation<T>(*c) + ((T)1. - r) * (*x) );
// //*********** back propagation ***********//
// // dCdC0 = f;
// // dFdX = Wf
// // dRdX = Wr
// NDArray<T> tanh = activation<T>(*c);
// NDArray<T> dFdBf = f * oneMinusF;
// NDArray<T> dRdBr = r * oneMinusR;
// NDArray<T> dHdR = tanh - *x;
// // dCdF = c0 - x*Wc;
// NDArray<T> dCdF = *c0 - z({{},{0, inSize}});
// // dHdC = r * (1 - tanh*tanh)
// NDArray<T> dHdC = r * (1. - tanh * tanh);
// // dCdX = dCdX + dCdF*dFdX = (1-f)*Wc + dCdF*Wf
// NDArray<T> dCdX = oneMinusF * (*w)({{},{0, inSize}}) + dCdF * (*w)({{},{inSize, 2*inSize}});
// // dLdC0 = dLdC * dCdC0 = dLdC * f
// dLdC0->assign((*dLdC) * f);
// // dLdBf = dLdH*dHdBf + dLdC*dCdBf = dLdH*dHdC*dCdBf + dLdC*dCdF*dFdBf = dLdH*dHdC*dCdF*dFdBf + dLdC*dCdF*dFdBf = (dLdH*dHdC + dLdC)*dCdF*dFdBf
// (*dLdB)({{0, inSize}}).assign(((*dLdH) * dHdC + *dLdC) * dCdF * dFdBf);
// // dLdBr = dLdH * dHdR * dRdBr
// (*dLdB)({{inSize, 2*inSize}}).assign((*dLdH) * dHdR * dRdBr)
// // dLdWc = dLdH*dHdWc + dLdC*dCdWc = dLdH*dHdC*dCdWc + dLdC*dCdWc = (dLdH*dHdC + dLdC) * dCdWc = (dLdH*dHdC + dLdC) * (1-f)*x
// (*dLdW)({{}, {0, inSize}}).assign(((*dLdH) * dHdC + *dLdC) * oneMinusF * (*x));
// // dLdWf = dLdBf * x
// (*dLdW)({{}, {inSize, 2*inSize}}).assign((*dLdB)({{0, inSize}}) * (*x));
// // dLdWr = dLdBr * x
// (*dLdW)({{}, {2*inSize, 3*inSize}}).assign((*dLdB)({{inSize, 2*inSize}}) * (*x));
// // dLdX = dLdH*dHdX + dLdC*dCdX = dLdH*(dHdX + dHdR*dRdX + dHdC*dCdX) + dLdC*dCdF*dFdX = dLdH*(1 - r + dHdR*dRdX + dHdC*dCdX) + dLdC*dCdX
// dLdX->assign((*dLdH) * (oneMinusR + dHdR * (*w)({{},{2*inSize, 3*inSize}}) + dHdC * dCdX) + (*dLdC) * dCdX);
// }