cavis/libnd4j/include/ops/declarable/generic/recurrent/lstmLayer.cpp

404 lines
20 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_lstmLayer)
#include <memory>
#include <ops/declarable/CustomOperations.h>
#include<ops/declarable/helpers/lstmLayer.h>
namespace nd4j {
namespace ops {
//////////////////////////////////////////////////////////////////////////
/**
 * lstmLayer: runs an LSTM over a whole input sequence.
 *
 * The op reads its configuration from integer / boolean / floating point arguments,
 * validates the shapes of the weight, bias and initial-state arrays against the
 * dimensions derived from x and Wx, and then delegates the actual recurrence to
 * helpers::lstmLayerTimeLoop (once for unidirectional modes, twice - forward and
 * backward - for bidirectional modes).
 *
 * Integer args : 0 = dataFormat, 1 = directionMode, 2 = gateAct, 3 = cellAct, 4 = outAct
 * Boolean args : 0 = hasBiases, 1 = hasSeqLen, 2 = hasInitH, 3 = hasInitC, 4 = hasPH,
 *                5 = retFullSeq, 6 = retLastH, 7 = retLastC
 * T args       : 0 = cellClip, followed by the alpha/beta values of those activations that need them
 *                (order: gateAlpha, gateBeta, cellAlpha, cellBeta, outAlpha, outBeta; absent ones are skipped)
 *
 * Returns Status::OK() on success; argument/shape violations are reported through REQUIRE_TRUE.
 */
CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) {

    // equations (no peephole connections)
    // it  = σ(Wxi * xt  +  Wri * ht-1  +  bi)
    // ft  = σ(Wxf * xt  +  Wrf * ht-1  +  bf)
    // c't = tanh(Wxc * xt  +  Wrc * ht-1  +  bc)
    // ct  = ft ◦ ct-1 + it ◦ c't
    // ot  = σ(Wxo * xt  +  Wro * ht-1  +  bo)
    // ht  = ot ◦ tanh(ct)

    // equations (peephole connections are present)
    // it  = σ(Wxi * xt  +  Wri * ht-1  +  Wpi ◦ ct-1  +  bi)
    // ft  = σ(Wxf * xt  +  Wrf * ht-1  +  Wpf ◦ ct-1  +  bf)
    // c't = tanh(Wxc * xt  +  Wrc * ht-1  +  bc)
    // ct  = ft ◦ ct-1 + it ◦ c't
    // ot  = σ(Wxo * xt  +  Wro * ht-1  +  Wpo ◦ ct  +  bo)
    // ht  = ot ◦ tanh(ct)

    // notations:
    // bS   - batch size
    // sL   - sequence length, number of time steps
    // nIn  - input size
    // nOut - output size (hidden size)

    // INPUTS:
    // *******
    // input x:
    // 1) [sL, bS, nIn]  when dataFormat == 0
    // 2) [bS, sL, nIn]  when dataFormat == 1
    // 3) [bS, nIn, sL]  when dataFormat == 2
    // *******
    // input weights Wx:
    // 1) [nIn, 4*nOut]    when directionMode <  2
    // 2) [2, nIn, 4*nOut] when directionMode >= 2
    // *******
    // recurrent weights Wr:
    // 1) [nOut, 4*nOut]    when directionMode <  2
    // 2) [2, nOut, 4*nOut] when directionMode >= 2
    // *******
    // peephole weights Wp:
    // 1) [3*nOut]    when directionMode <  2
    // 2) [2, 3*nOut] when directionMode >= 2
    // *******
    // biases b:
    // 1) [4*nOut]    when directionMode <  2
    // 2) [2, 4*nOut] when directionMode >= 2
    // *******
    // sequence length array seqLen:
    // 1) [bS] always
    // NOTE(review): the shape of seqLen is not validated in this function - confirm the helper checks it
    // *******
    // initial output hI:
    // 1) [bS, nOut]    when directionMode <  2
    // 2) [2, bS, nOut] when directionMode >= 2
    // *******
    // initial cell state cI (same shape as in hI):
    // 1) [bS, nOut]    when directionMode <  2
    // 2) [2, bS, nOut] when directionMode >= 2

    // OUTPUTS:
    // *******
    // output h:
    // 1) [sL, bS, nOut]    when directionMode <= 2 && dataFormat == 0
    // 2) [bS, sL, nOut]    when directionMode <= 2 && dataFormat == 1
    // 3) [bS, nOut, sL]    when directionMode <= 2 && dataFormat == 2
    // 4) [sL, bS, 2*nOut]  when directionMode == 3 && dataFormat == 0
    // 5) [bS, sL, 2*nOut]  when directionMode == 3 && dataFormat == 1
    // 6) [bS, 2*nOut, sL]  when directionMode == 3 && dataFormat == 2
    // 7) [sL, 2, bS, nOut] when directionMode == 4 && dataFormat == 3
    // *******
    // output at last step hL:
    // 1) [bS, nOut]    when directionMode <  2
    // 2) [2, bS, nOut] when directionMode >= 2
    // *******
    // cell state at last step cL (same shape as in hL):
    // 1) [bS, nOut]    when directionMode <  2
    // 2) [2, bS, nOut] when directionMode >= 2

    // !!! dimension 4*nOut implies order it, ft, c't, ot
    // !!! dimension 3*nOut implies order it, ft, ot

    // NOTE(review): the layout of x for dataFormat == 3 is described with nOut here but with nIn in the shape function - confirm which is intended
    const auto dataFormat    = INT_ARG(0);    // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, 2, bS, nOut] (for ONNX)
    const auto directionMode = INT_ARG(1);    // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3)

    // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
    const auto gateAct = INT_ARG(2);    // activation for input (i), forget (f) and output (o) gates
    const auto cellAct = INT_ARG(3);    // activation for cell state (c)
    const auto outAct  = INT_ARG(4);    // activation for output (h)

    const auto hasBiases  = B_ARG(0);   // indicates whether biases array is provided
    const auto hasSeqLen  = B_ARG(1);   // indicates whether seqLen array is provided
    const auto hasInitH   = B_ARG(2);   // indicates whether initial output is provided
    const auto hasInitC   = B_ARG(3);   // indicates whether initial cell state is provided
    const auto hasPH      = B_ARG(4);   // indicates whether peephole connections are present
    const auto retFullSeq = B_ARG(5);   // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1}
    const auto retLastH   = B_ARG(6);   // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)
    const auto retLastC   = B_ARG(7);   // indicates whether to return cells state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)

    // activations 3, 4, 5, 6, 8 take an alpha parameter; activations 3 and 6 additionally take beta
    const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8;
    const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8;
    const auto outActHasAlpha  = outAct  == 3 || outAct  == 4 || outAct  == 5 || outAct  == 6 || outAct  == 8;
    const auto gateActHasBeta  = gateAct == 3 || gateAct == 6;
    const auto cellActHasBeta  = cellAct == 3 || cellAct == 6;
    const auto outActHasBeta   = outAct  == 3 || outAct  == 6;

    // T args are packed: index 0 is always cellClip, the remaining ones are consumed in order only when needed
    uint count = 1;
    const auto cellClip  = T_ARG(0);                                     // cell clipping value, if it = 0 then do not apply clipping
    const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0;
    const auto gateBeta  = gateActHasBeta  ? T_ARG(count++) : 0;
    const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0;
    const auto cellBeta  = cellActHasBeta  ? T_ARG(count++) : 0;
    const auto outAlpha  = outActHasAlpha  ? T_ARG(count++) : 0;
    const auto outBeta   = outActHasBeta   ? T_ARG(count++) : 0;

    const auto x  = INPUT_VARIABLE(0);          // input
    const auto Wx = INPUT_VARIABLE(1);          // input weights
    const auto Wr = INPUT_VARIABLE(2);          // recurrent weights

    // optional inputs are packed as well: they follow the three mandatory ones in this fixed order
    count = 3;
    const auto b      = hasBiases ? INPUT_VARIABLE(count++) : nullptr;  // biases
    const auto seqLen = hasSeqLen ? INPUT_VARIABLE(count++) : nullptr;  // seqLen vector
    const auto hI     = hasInitH  ? INPUT_VARIABLE(count++) : nullptr;  // initial output
    const auto cI     = hasInitC  ? INPUT_VARIABLE(count++) : nullptr;  // initial cell state
    const auto Wp     = hasPH     ? INPUT_VARIABLE(count++) : nullptr;  // peephole weights

    // reject out-of-range modes up front: a negative directionMode would otherwise be validated as
    // unidirectional (2D weights) and then be executed by the bidirectional branch below (3D slicing)
    REQUIRE_TRUE(dataFormat >= 0 && dataFormat <= 3, 0, "LSTM_LAYER operation: argument dataFormat should be in range [0, 3], but got %i instead !", dataFormat);
    REQUIRE_TRUE(directionMode >= 0 && directionMode <= 4, 0, "LSTM_LAYER operation: argument directionMode should be in range [0, 4], but got %i instead !", directionMode);
    REQUIRE_TRUE(dataFormat < 3 || (dataFormat == 3 && directionMode == 4), 0, "LSTM_LAYER operation: if argument dataFormat = 3, then directionMode = 4, but got dataFormat = %i and directionMode = %i instead !", dataFormat, directionMode);
    REQUIRE_TRUE(cellClip >= 0 , 0, "LSTM_LAYER operation: cell clipping value should be nonnegative (>=0) !");
    REQUIRE_TRUE(retFullSeq || retLastH || retLastC, 0, "LSTM_LAYER operation: please specify what output arrays to produce !");

    // outputs are packed in the order h, hL, cL; only the requested ones exist
    count = 0;
    auto h  = retFullSeq ? OUTPUT_VARIABLE(count++) : nullptr;           // output
    auto hL = retLastH   ? OUTPUT_VARIABLE(count++) : nullptr;           // output at last step
    auto cL = retLastC   ? OUTPUT_VARIABLE(count++) : nullptr;           // cell state at last step

    // evaluate dimensions
    const Nd4jLong sL   = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat);   // not referenced further in this function
    const Nd4jLong bS   = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(-2);
    const Nd4jLong nIn  = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(-1);
    const Nd4jLong nOut = Wx->sizeAt(-1) / 4;   // integer division: the Wx checks below reject a last dimension that is not a multiple of 4

    // inputs validations
    // (the "if(cond) REQUIRE_TRUE(false, ...)" form is kept from the original - presumably so that the shape strings are only built on failure)
    if(directionMode < 2) {     // no bidirectional

        // Wx validation
        if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn || Wx->sizeAt(1) != 4*nOut)
            REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
        // Wr validation
        if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut)
            REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str());
        // biases validation
        if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut))
            REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str());
        // initial output validation
        if(hI != nullptr && (hI->rankOf() != 2 || hI->sizeAt(0) != bS || hI->sizeAt(1) != nOut))
            REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str());
        // initial cell  validation
        if(cI != nullptr && (cI->rankOf() != 2 || cI->sizeAt(0) != bS || cI->sizeAt(1) != nOut))
            REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str());
        // peephole weights validation
        if(Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3*nOut))
            REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str());
    }
    else {                      // bidirectional

        // Wx validation
        if(Wx->rankOf() != 3 || Wx->sizeAt(0) != 2 || Wx->sizeAt(1) != nIn || Wx->sizeAt(2) != 4*nOut)
            REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
        // Wr validation
        if(Wr->rankOf() != 3 || Wr->sizeAt(0) != 2 || Wr->sizeAt(1) != nOut || Wr->sizeAt(2) != 4*nOut)
            REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str());
        // biases validation
        if(b != nullptr && (b->rankOf() != 2 || b->sizeAt(0) != 2 || b->sizeAt(1) != 4*nOut))
            REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str());
        // initial output validation
        if(hI != nullptr && (hI->rankOf() != 3 || hI->sizeAt(0) != 2 || hI->sizeAt(1) != bS || hI->sizeAt(2) != nOut))
            REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str());
        // initial cell  validation
        if(cI != nullptr && (cI->rankOf() != 3 || cI->sizeAt(0) != 2 || cI->sizeAt(1) != bS || cI->sizeAt(2) != nOut))
            REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str());
        // peephole weights validation
        if(Wp != nullptr && (Wp->rankOf() != 2 || Wp->sizeAt(0) != 2 || Wp->sizeAt(1) != 3*nOut))
            REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str());
    }

    // scalar configuration is handed to the helper as a flat float vector in this fixed order
    std::vector<float> params = {static_cast<float>(dataFormat), static_cast<float>(directionMode), static_cast<float>(cellClip),
                                 static_cast<float>(gateAct), static_cast<float>(gateAlpha), static_cast<float>(gateBeta),
                                 static_cast<float>(cellAct), static_cast<float>(cellAlpha), static_cast<float>(cellBeta),
                                 static_cast<float>(outAct), static_cast<float>(outAlpha), static_cast<float>(outBeta)};

    if(directionMode == 0) {    // forward

        helpers::lstmLayerTimeLoop(x, Wx, Wr, b, seqLen, hI, cI, Wp, params, true, h, hL, cL);
    }
    else if(directionMode == 1) {  // backward

        helpers::lstmLayerTimeLoop(x, Wx, Wr, b, seqLen, hI, cI, Wp, params, false, h, hL, cL);
    }
    else {  // bidirectional

        // index 0 along the leading dimension holds the forward-direction data, index 1 the backward-direction data
        NDArray WxFwd = (*Wx)({0,1, 0,0, 0,0});
        NDArray WxBwd = (*Wx)({1,2, 0,0, 0,0});
        NDArray WrFwd = (*Wr)({0,1, 0,0, 0,0});
        NDArray WrBwd = (*Wr)({1,2, 0,0, 0,0});

        // per-direction sub-arrays of the optional inputs/outputs; an empty pointer means "not provided".
        // They are owned by unique_ptr so that nothing leaks if a time-loop call below throws.
        std::unique_ptr<NDArray> WpFwd, WpBwd;
        std::unique_ptr<NDArray> bFwd,  bBwd;
        std::unique_ptr<NDArray> hIFwd, hIBwd;
        std::unique_ptr<NDArray> cIFwd, cIBwd;
        std::unique_ptr<NDArray> hLFwd, hLBwd;
        std::unique_ptr<NDArray> cLFwd, cLBwd;
        std::unique_ptr<NDArray> hFwdOwned, hBwd;
        NDArray* hFwd = nullptr;    // non-owning: aliases h itself in "sum" mode, hFwdOwned otherwise

        if(Wp) {
            WpFwd.reset(new NDArray((*Wp)({0,1, 0,0})));
            WpBwd.reset(new NDArray((*Wp)({1,2, 0,0})));
        }
        if(b) {
            bFwd.reset(new NDArray((*b)({0,1, 0,0})));
            bBwd.reset(new NDArray((*b)({1,2, 0,0})));
        }
        if(hI) {
            hIFwd.reset(new NDArray((*hI)({0,1, 0,0, 0,0})));
            hIBwd.reset(new NDArray((*hI)({1,2, 0,0, 0,0})));
        }
        if(cI) {
            cIFwd.reset(new NDArray((*cI)({0,1, 0,0, 0,0})));
            cIBwd.reset(new NDArray((*cI)({1,2, 0,0, 0,0})));
        }
        if(hL) {
            hLFwd.reset(new NDArray((*hL)({0,1, 0,0, 0,0})));
            hLBwd.reset(new NDArray((*hL)({1,2, 0,0, 0,0})));
        }
        if(cL) {
            cLFwd.reset(new NDArray((*cL)({0,1, 0,0, 0,0})));
            cLBwd.reset(new NDArray((*cL)({1,2, 0,0, 0,0})));
        }

        if(h) {
            if(directionMode == 2) {        // sum
                // forward pass writes straight into h, backward pass into a separate array that is added afterwards
                hFwd = h;
                hBwd.reset(new NDArray(h, false, h->getContext()));
            }
            else if(directionMode == 3) {   // concat
                // the two directions occupy the two halves of the 2*nOut axis of h
                hFwdOwned.reset(new NDArray(dataFormat <= 1 ? (*h)({0,0, 0,0, 0,nOut})      : (*h)({0,0, 0,nOut,      0,0})));
                hBwd.reset(     new NDArray(dataFormat <= 1 ? (*h)({0,0, 0,0, nOut,2*nOut}) : (*h)({0,0, nOut,2*nOut, 0,0})));
                hFwd = hFwdOwned.get();
            }
            else {                          // directionMode == 4
                // the two directions occupy index 0 and 1 of the second dimension of h
                hFwdOwned.reset(new NDArray((*h)({0,0, 0,1, 0,0, 0,0})));
                hBwd.reset(     new NDArray((*h)({0,0, 1,2, 0,0, 0,0})));
                hFwd = hFwdOwned.get();
            }
        }

        // FIXME - following two calls are independent and may run in different streams
        helpers::lstmLayerTimeLoop(x, &WxFwd, &WrFwd, bFwd.get(), seqLen, hIFwd.get(), cIFwd.get(), WpFwd.get(), params, true,  hFwd,       hLFwd.get(), cLFwd.get());
        helpers::lstmLayerTimeLoop(x, &WxBwd, &WrBwd, bBwd.get(), seqLen, hIBwd.get(), cIBwd.get(), WpBwd.get(), params, false, hBwd.get(), hLBwd.get(), cLBwd.get());

        if(h && directionMode == 2)
            *h += *hBwd;
    }

    return Status::OK();
}
// Data type constraints of the lstmLayer op:
// inputs are allowed to be of any data type (nd4j::DataType::ANY),
// outputs are restricted to the types listed by ALL_FLOATS.
DECLARE_TYPES(lstmLayer) {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes({ALL_FLOATS});
}
// Shape function of the lstmLayer op.
// Produces one shape per requested output, in the order h (full sequence), hL (last output), cL (last cell state).
// The element type is the type of x when x is a real (floating point) array, FLOAT32 otherwise;
// the ordering of every output follows the ordering of x.
DECLARE_SHAPE_FN(lstmLayer) {

    // NOTE(review): the layout of x for dataFormat == 3 is described with nIn here but with nOut in the op body - confirm which is intended
    const auto dataFormat    = INT_ARG(0);    // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, 2, bS, nIn] (for ONNX)
    const auto directionMode = INT_ARG(1);    // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim

    const auto retFullSeq = B_ARG(5);         // true -> whole sequence h {h_0, h_1, ... , h_sL-1} is returned, e.g. [sL,bS,nOut] (exact shape depends on dataFormat argument)
    const auto retLastH   = B_ARG(6);         // true -> output at last time step is returned, e.g. [bS, nOut]
    const auto retLastC   = B_ARG(7);         // true -> cell state at last time step is returned, e.g. [bS, nOut]

    const auto x  = INPUT_VARIABLE(0);        // input
    const auto Wx = INPUT_VARIABLE(1);        // input weights
    const auto Wr = INPUT_VARIABLE(2);        // recurrent weights (fetched but not referenced below)

    // evaluate dimensions: pick the axis of x that carries each quantity for the given dataFormat
    int timeAxis = 2;
    if(dataFormat == 0 || dataFormat == 3)
        timeAxis = 0;
    else if(dataFormat == 1)
        timeAxis = 1;
    const int batchAxis   = (dataFormat == 1 || dataFormat == 2) ? 0 : -2;
    const int featureAxis = (dataFormat == 2) ? 1 : -1;

    const Nd4jLong sL   = x->sizeAt(timeAxis);
    const Nd4jLong bS   = x->sizeAt(batchAxis);
    const Nd4jLong nIn  = x->sizeAt(featureAxis);     // evaluated but not referenced below
    const Nd4jLong nOut = Wx->sizeAt(-1) / 4;

    // floating point inputs keep their own type, everything else falls back to FLOAT32
    const DataType type = x->isR() ? x->dataType() : nd4j::DataType::FLOAT32;

    std::vector<Nd4jLong*> shapes;

    // h: full output sequence
    if(retFullSeq) {

        std::vector<Nd4jLong> hShape;

        if(directionMode <= 3) {
            // single direction / bidirectional sum keep nOut features, bidirectional concat doubles them
            const Nd4jLong hFeatures = directionMode == 3 ? 2*nOut : nOut;

            switch(dataFormat) {
                case 0:  hShape = {sL, bS, hFeatures}; break;
                case 1:  hShape = {bS, sL, hFeatures}; break;
                case 2:  hShape = {bS, hFeatures, sL}; break;
                default: break;     // any other dataFormat leaves hShape empty
            }
        }
        else {
            // bidirectional with extra output dimension equal to 2
            hShape = {sL, 2, bS, nOut};
        }

        shapes.push_back(ConstantShapeHelper::getInstance()->createShapeInfo(type, x->ordering(), hShape));
    }

    // hL and cL: state at last step, both have the same shape
    if(retLastH || retLastC) {

        const std::vector<Nd4jLong> lastStepShape = directionMode < 2 ? std::vector<Nd4jLong>{bS, nOut}
                                                                      : std::vector<Nd4jLong>{2, bS, nOut};

        auto lastStepInfo = ConstantShapeHelper::getInstance()->createShapeInfo(type, x->ordering(), lastStepShape);

        shapes.push_back(lastStepInfo);
        if(retLastH && retLastC)        // both requested -> the same shape info is listed twice
            shapes.push_back(lastStepInfo);
    }

    return new ShapeList(shapes);
}
}
}
#endif