/******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ // // implementation of operations for Simple Recurrent Unit: arXiv:1709.02755v2 [cs.CL] 12 Sep 2017 // // @author Yurii Shyrma, created on 05.12.2017 // #include #include namespace nd4j { namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// static FORCEINLINE NDArray activation(const NDArray& arr) { // return (const_cast&>(arr)).template transform>(); auto result = NDArray(&arr, false, arr.getContext()); (const_cast(arr)).applyTransform(transform::Tanh, &result); return result; } ////////////////////////////////////////////////////////////////////////// static FORCEINLINE NDArray sigmoid(const NDArray& arr) { return (const_cast(arr)).transform(transform::Sigmoid); } ////////////////////////////////////////////////////////////////////////// void sruCell(nd4j::LaunchContext * context, const NDArray* x, const NDArray* c0, const NDArray* w, const NDArray* b, NDArray* h, NDArray* c) { // x input [bS x inSize], bS - batch size, inSize - number of features // c0 previous cell state c [bS x inSize], that is at previous time step t-1 // w weights [inSize x 3*inSize] // b biases [2*inSize] // h current cell output [bS x inSize], that is at current time step t // c current cell state [bS x inSize], that is at current time step t const int inSize = x->sizeAt(1); // inSize - number of features auto z = mmul(*x, *w); // [bS x 3*inSize] // forget gate = sigmoid(x*Wf + bf) auto f = sigmoid(z({0,0, inSize, 2*inSize}) + (*b)({0, inSize})); // reset gate = sigmoid(x*Wr + br) auto r = sigmoid(z({0,0, 2*inSize, 3*inSize}) + (*b)({inSize, 2*inSize})); // ◦ means element-wise product or so called Hadamard product // current sell state = f◦c0 + (1 - f)◦(x*Wc) c->assign(f * (*c0) + (1.f - f) * z({0, 0 ,0, inSize}) ); // *c = f*(*c0 - z({},{0, inSize})) + z({{},{0, inSize}}); // current cell output = r◦activation(c) + (1 - r)◦x h->assign( r * activation(*c) + (1.f - r) * (*x) ); // *h = r * (activation(c) - *x) + *x; } ////////////////////////////////////////////////////////////////////////// void sruTimeLoop(nd4j::LaunchContext * context, const NDArray* x, const NDArray* c0, const NDArray* w, const NDArray* b, NDArray* h, NDArray* c) { // x input [bS x inSize x time] // c0 initial cell state (at time step = 0) [bS x inSize], // w weights, [3*inSize x inSize] // b biases, [2*inSize] // h cell outputs [bS x inSize x time] // c cell states [bS x inSize x time] auto wT = w->transpose(); // [3*inSize x inSize] -> [inSize x 3*inSize] const int time = x->sizeAt(2); NDArray ct_1(*c0); // loop through time steps for (int t = 0; t < time; ++t) { auto xt = (*x)({0,0, 0,0, t,t+1}); auto ht = (*h)({0,0, 0,0, t,t+1}); auto ct = (*c)({0,0, 0,0, t,t+1}); helpers::sruCell(context, &xt, &ct_1, &wT, b, &ht, &ct); ct_1.assign(ct); } } ////////////////////////////////////////////////////////////////////////// template static void sruBI_(NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* mask, NDArray* ht, NDArray* ct) { // x input 3d tensor [time x bS x 2*inSize], time - number of time steps, bS - batch size, inSize - number of features // w 2d tensor of weights [2*inSize x 6*inSize] // b row of biases with twice length [1 × 4*inSize] // c0 2d tensor of initial state [bS x 2*inSize] at time t=0 // mask optional, 2d tensor of dropout mask [bS x 2*inSize] // ht [time x bS x 2*inSize] // ct [time x bS x 2*inSize] const Nd4jLong time = x->sizeAt(0); // time - number of time steps const Nd4jLong bS = x->sizeAt(1); // bS - batch size const Nd4jLong inSize = x->sizeAt(2) / 2; // inSize - number of features // x = x * mask if(mask) x->applyBroadcast(broadcast::Multiply, {1, 2}, mask, x, nullptr); // apply mask // U = x * w NDArray wi = mmul(*x, *w); // U [time x bS x 6*inSize] const Nd4jLong d2 = 2*inSize; const Nd4jLong ncols = bS*d2; const Nd4jLong ncolsWi = 3*ncols; T* pI = x->bufferAsT(); T* pWi = wi.bufferAsT(); T* pBias = const_cast(b)->bufferAsT(); T* pInit = const_cast(c0)->bufferAsT(); T* pMask = mask ? const_cast(mask)->bufferAsT() : nullptr; T* pHt = ht->bufferAsT(); T* pCt = ct->bufferAsT(); Nd4jLong ncolsRev, ncolsWiRev; // for reverse direction T maskVal, cur, bF, bR, ft, rt, val; T *pIVal(nullptr), *pWiVal(nullptr), *pHtVal(nullptr), *pCtVal(nullptr); bool flip = false; for (Nd4jLong col = 0; col < ncols; ++col) { const auto colNum = col % d2; flip = colNum >= inSize; maskVal = mask ? *(pMask + col) : T(1); cur = *(pInit + col); bF = *(pBias + colNum); bR = *(pBias + colNum + d2); pWiVal = pWi + 3*col; pIVal = pI + col; pHtVal = pHt + col; pCtVal = pCt + col; if (flip) { pIVal += (time-1)*ncols; pWiVal += (time-1)*ncolsWi; pHtVal += (time-1)*ncols; pCtVal += (time-1)*ncols; } ncolsRev = flip ? -ncols : ncols; ncolsWiRev = flip ? -ncolsWi : ncolsWi; for (Nd4jLong t = 0; t < time; ++t) { // evaluate sigmoids ft = (1.)/(1. + nd4j::math::nd4j_exp(-(*(pWiVal + 1) + bF))); rt = (1.)/(1. + nd4j::math::nd4j_exp(-(*(pWiVal + 2) + bR))); cur = (cur - *pWiVal)*ft + *pWiVal; *pCtVal = cur; val = nd4j::math::nd4j_tanh(cur); *pHtVal = (val*maskVal - *pIVal)*rt + *pIVal; pIVal += ncolsRev; pWiVal += ncolsWiRev; pCtVal += ncolsRev; pHtVal += ncolsRev; } } } ////////////////////////////////////////////////////////////////////////// template static void sruBIBP_(NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* ct, const NDArray* inGradC0, const NDArray* inGradHt, const NDArray* mask, NDArray* gradI, NDArray* gradW, NDArray* gradB, NDArray* gradC0) { // x input 3d tensor [time x bS x 2*inSize], time - number of time steps, bS - batch size, inSize - number of features // w 2d tensor of weights [2*inSize x 6*inSize] // b row of biases with twice length [1 × 4*inSize] // c0 2d tensor of initial state [bS x 2*inSize] at time t=0 // ct [time x bS x 2*inSize] // inGradC0 [bS x 2*inSize] // inGradHt [time x bS x 2*inSize] // mask optional, 2d tensor of dropout mask [bS x 2*inSize] // gradI [time x bS x 2*inSize] // gradW [time x 2*inSize x 6*inSize] // gradB [1 x 4*inSize] // gradC0 [bS x 2*inSize] const Nd4jLong time = x->sizeAt(0); // time - number of time steps const Nd4jLong bS = x->sizeAt(1); const Nd4jLong inSize = x->sizeAt(2) / 2; // x = x * mask if(mask) x->applyBroadcast(broadcast::Multiply, {1, 2}, mask, x, nullptr); // apply mask // U = x * w NDArray wi = mmul(*x, *w); // [time x bS x 2*inSize] * [2*inSize x 6*inSize] = [time x bS x 6*inSize] NDArray gradBias(x->ordering(), {bS, 4*inSize}, x->dataType(), x->getContext()); NDArray gradWi (x->ordering(), {time, bS, 6*inSize}, x->dataType(), x->getContext()); const Nd4jLong d2 = 2*inSize; const Nd4jLong ncols = bS*d2; const Nd4jLong ncolsWi = 3*ncols; T* pInput = x->bufferAsT(); T* pWi = wi.bufferAsT(); T* pBias = const_cast(b)->bufferAsT(); T* pInit = const_cast(c0)->bufferAsT(); T* pMask = mask ? const_cast(mask)->bufferAsT() : nullptr; T* pState = const_cast(ct)->bufferAsT(); T* pInGradCt = const_cast(inGradC0)->bufferAsT(); T* pInGradHt = const_cast(inGradHt)->bufferAsT(); T* pGradWi = gradWi.bufferAsT(); T* pGradInput = gradI->bufferAsT(); T* pGradBias = gradBias.bufferAsT(); T* pGradInit = gradC0->bufferAsT(); Nd4jLong ncolsRev, ncolsWiRev; // for reverse direction T gbF, gbR, cur, maskVal, bF, bR, ft, rt, val, prevVal, gft, grt, gradSateVal; bool flip = false; T *pInputVal(nullptr), *pWiVal(nullptr), *pStateVal(nullptr), *pInGradHtVal(nullptr), *pGradWiVal(nullptr), *pGradInputVal(nullptr); for (Nd4jLong col = 0; col < ncols; ++col) { gbF = gbR = (T)0.; const auto colNum = col % d2; flip = colNum >= inSize; maskVal = mask ? *(pMask + col) : T(1.); cur = *(pInGradCt + col); bF = *(pBias + colNum); bR = *(pBias + colNum + d2); pWiVal = pWi + 3*col; pInputVal = pInput + col; pStateVal = pState + col; pInGradHtVal = pInGradHt + col; pGradWiVal = pGradWi + 3*col; pGradInputVal = pGradInput + col; if (!flip) { pInputVal += (time-1)*ncols; pWiVal += (time-1)*ncolsWi; pStateVal += (time-1)*ncols; pInGradHtVal += (time-1)*ncols; pGradWiVal += (time-1)*ncolsWi; pGradInputVal += (time-1)*ncols; } ncolsRev = flip ? -ncols : ncols; ncolsWiRev = flip ? -ncolsWi : ncolsWi; for (Nd4jLong t = 0; t < time; ++t) { // evaluate sigmoids ft = ((T)1.)/((T)1. + nd4j::math::nd4j_exp(-(*(pWiVal + 1) + bF))); rt = ((T)1.)/((T)1. + nd4j::math::nd4j_exp(-(*(pWiVal + 2) + bR))); val = nd4j::math::nd4j_tanh(*pStateVal); prevVal = (t < time-1) ? (*(pStateVal - ncolsRev)) : (*(pInit + col)); // grad wrt input *pGradInputVal = *pInGradHtVal - (*pInGradHtVal)*rt ; // grad wrt rt, wiR and bR grt = (*pInGradHtVal) * (val*maskVal - *pInputVal) * (rt - rt*rt); *(pGradWiVal + 2) = grt; gbR += grt; // grad wrt state gradSateVal = (*pInGradHtVal) * maskVal * (rt - rt*val*val) + cur; // grad wrt wi0 *pGradWiVal = gradSateVal - gradSateVal*ft; // grad wrt ft, wi1, and bF gft = gradSateVal * (prevVal - *pWiVal) * (ft - ft*ft); *(pGradWiVal + 1) = gft; gbF += gft; // grad wrt c_previous cur = gradSateVal * ft; pInputVal -= ncolsRev; pWiVal -= ncolsWiRev; pStateVal -= ncolsRev; pGradWiVal -= ncolsWiRev; pGradInputVal -= ncolsRev; pInGradHtVal -= ncolsRev; } *(pGradBias + col) = gbF; *(pGradBias + col + ncols) = gbR; *(pGradInit + col) = cur; } // gradB gradBias.reduceAlongDimension(reduce::Sum, gradB, {0}, false, true); // [1 x 4*inSize] // gradW x->permutei({0, 2, 1}); // [time x bS x 2*inSize] -> [time x 2*inSize x bS] *gradW = mmul(*x, gradWi); // [time x 2*inSize x bS ] * [time x bS x 6*inSize] = [time x 2*inSize x 6*inSize] } void sruBI(nd4j::LaunchContext * context, NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* mask, NDArray* ht, NDArray* ct) { BUILD_SINGLE_SELECTOR(x->dataType(), sruBI_, (x, w, b, c0, mask, ht, ct), FLOAT_TYPES); } void sruBIBP(nd4j::LaunchContext * context, NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* ct, const NDArray* inGradC0, const NDArray* inGradH, const NDArray* mask, NDArray* gradI, NDArray* gradW, NDArray* gradB, NDArray* gradC0) { BUILD_SINGLE_SELECTOR(x->dataType(), sruBIBP_, (x, w, b, c0, ct, inGradC0, inGradH, mask, gradI, gradW, gradB, gradC0), FLOAT_TYPES); } BUILD_SINGLE_TEMPLATE(template void sruBI_, (NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* mask, NDArray* ht, NDArray* ct), FLOAT_TYPES); BUILD_SINGLE_TEMPLATE(template void sruBIBP_, (NDArray* x, const NDArray* w, const NDArray* b, const NDArray* c0, const NDArray* ct, const NDArray* inGradC0, const NDArray* inGradH, const NDArray* mask, NDArray* gradI, NDArray* gradW, NDArray* gradB, NDArray* gradC0), FLOAT_TYPES); } } } ////////////////////////////////////////////////////////////////////////// // template // void sruCellBP(const std::vector*>& inArrs, const std::vector*>& outArrs) { // NDArray* x = inArrs[0]; // input [bS x inSize], bS - batch size, inSize - number of features // NDArray* c0 = inArrs[1]; // previous cell state c [bS x inSize], that is at previous time step t-1 // NDArray* w = inArrs[2]; // weights [inSize x 3*inSize] // NDArray* b = inArrs[3]; // biases [2*inSize] // NDArray* dLdC = inArrs[4]; // gradient of the loss func with respect to cell output [bS x inSize] // NDArray* dLdH = inArrs[5]; // gradient of the loss func with respect to cell state [bS x inSize] // NDArray* dLdX = outArrs[0]; // gradient of the loss func with respect to input [bS x inSize], so called epsilon // NDArray* dLdW = outArrs[1]; // gradient of the loss func with respect to weights [inSize x 3*inSize] // NDArray* dLdB = outArrs[2]; // gradient of the loss func with respect to biases [2*inSize] // NDArray* dLdC0 = outArrs[3]; // gradient of the loss func with respect to previous cell state [bS, inSize] // const int inSize = x->sizeAt(1); // inSize - number of features // //*********** feed forward ***********// // NDArray z = mmul(*x, *w); // [bS x 3*inSize] // // forget gate = sigmoid(x*Wf + bf) // NDArray f = sigmoid(z({{},{inSize, 2*inSize}}) + (*b)({{0, inSize}})); // [bS, inSize] // NDArray oneMinusF = 1. - f; // // reset gate = sigmoid(x*Wr + br) // NDArray r = sigmoid(z({{},{2*inSize, 3*inSize}}) + (*b)({{inSize, 2*inSize}})); // [bS, inSize] // NDArray oneMinusR = 1. - r; // // current sell state = f◦c0 + (1 - f)◦(x*Wc) ---> c->assign( f*(*c0) + ((T)1. - f) * z({{},{0, inSize}}) ); // // current cell output = r◦activation(c) + (1 - r)◦x ---> h->assign( r*activation(*c) + ((T)1. - r) * (*x) ); // //*********** back propagation ***********// // // dCdC0 = f; // // dFdX = Wf // // dRdX = Wr // NDArray tanh = activation(*c); // NDArray dFdBf = f * oneMinusF; // NDArray dRdBr = r * oneMinusR; // NDArray dHdR = tanh - *x; // // dCdF = c0 - x*Wc; // NDArray dCdF = *c0 - z({{},{0, inSize}}); // // dHdC = r * (1 - tanh*tanh) // NDArray dHdC = r * (1. - tanh * tanh); // // dCdX = dCdX + dCdF*dFdX = (1-f)*Wc + dCdF*Wf // NDArray dCdX = oneMinusF * (*w)({{},{0, inSize}}) + dCdF * (*w)({{},{inSize, 2*inSize}}); // // dLdC0 = dLdC * dCdC0 = dLdC * f // dLdC0->assign((*dLdC) * f); // // dLdBf = dLdH*dHdBf + dLdC*dCdBf = dLdH*dHdC*dCdBf + dLdC*dCdF*dFdBf = dLdH*dHdC*dCdF*dFdBf + dLdC*dCdF*dFdBf = (dLdH*dHdC + dLdC)*dCdF*dFdBf // (*dLdB)({{0, inSize}}).assign(((*dLdH) * dHdC + *dLdC) * dCdF * dFdBf); // // dLdBr = dLdH * dHdR * dRdBr // (*dLdB)({{inSize, 2*inSize}}).assign((*dLdH) * dHdR * dRdBr) // // dLdWc = dLdH*dHdWc + dLdC*dCdWc = dLdH*dHdC*dCdWc + dLdC*dCdWc = (dLdH*dHdC + dLdC) * dCdWc = (dLdH*dHdC + dLdC) * (1-f)*x // (*dLdW)({{}, {0, inSize}}).assign(((*dLdH) * dHdC + *dLdC) * oneMinusF * (*x)); // // dLdWf = dLdBf * x // (*dLdW)({{}, {inSize, 2*inSize}}).assign((*dLdB)({{0, inSize}}) * (*x)); // // dLdWr = dLdBr * x // (*dLdW)({{}, {2*inSize, 3*inSize}}).assign((*dLdB)({{inSize, 2*inSize}}) * (*x)); // // dLdX = dLdH*dHdX + dLdC*dCdX = dLdH*(dHdX + dHdR*dRdX + dHdC*dCdX) + dLdC*dCdF*dFdX = dLdH*(1 - r + dHdR*dRdX + dHdC*dCdX) + dLdC*dCdX // dLdX->assign((*dLdH) * (oneMinusR + dHdR * (*w)({{},{2*inSize, 3*inSize}}) + dHdC * dCdX) + (*dLdC) * dCdX); // }