- write 2 versions of new lstmLayer: one is based on own code, second uses mkl dnn api

master
Yurii 2019-10-17 20:44:52 +03:00
parent 630bb3c9b6
commit 70bd925abd
13 changed files with 2678 additions and 8 deletions

View File

@ -4313,17 +4313,16 @@ Nd4jLong NDArray::getOffset(const Nd4jLong i) const {
return shape::getIndexOffset(i, _shapeInfo);
}
////////////////////////////////////////////////////////////////////////
NDArray NDArray::like() {
NDArray res(this->shapeInfo(), this->dataType(), false, this->getContext());
return res;
return NDArray(shapeInfo(), this->dataType(), false, getContext());
}
////////////////////////////////////////////////////////////////////////
NDArray NDArray::ulike() {
// FIXME: it should be non-memset array
NDArray res(this->shapeInfo(), this->dataType(), false, this->getContext());
return res;
return NDArray(this, false, getContext());
}
////////////////////////////////////////////////////////////////////////

View File

@ -268,6 +268,21 @@ nd4j::NDArray* MmulHelper::mmul(const nd4j::NDArray* A, const nd4j::NDArray* B,
if(aRank == 2 && isBVector)
return mmulMxV(A, B, C, alpha, beta, outOrder);
// vector x matrix, A{M} x B{M,N} = C{N} -> reduce to matrix x matrix A2{1,M} x B{M,N} = C2{1,N}, since there is no corresponding blas operation sgevm
if(isAVector && bRank == 2) {
NDArray* A2 = new NDArray(A->reshape(A->ordering(), {1, A->lengthOf()})); // A{M} -> A2{1,M}
NDArray* C2 = C ? new NDArray(C->reshape(C->ordering(), {1, C->lengthOf()})) : nullptr; // C{N} -> C2{1,N}
auto result = mmulMxM(A2, B, C2, alpha, beta, outOrder); // result{1,N}
delete A2;
delete C2;
if(!C) {
result->reshapei({result->lengthOf()}); // result{1,N} -> result{N}
return result;
}
return C;
}
// batched matrix multiplication
return mmulNxN(A, B, C, alpha, beta, outOrder);
}

View File

@ -119,6 +119,8 @@
#define TRANSFORM_STRICT_OPS \
(2, ScaledTanh), \
(3, Affine), \
(4, TanhDerivative), \
(5, HardTanhDerivative), \
(6, SigmoidDerivative), \

View File

@ -0,0 +1,404 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_lstmLayer)
#include <ops/declarable/CustomOperations.h>
#include<ops/declarable/helpers/lstmLayer.h>
namespace nd4j {
namespace ops {
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) {
// equations (no peephole connections)
// it = σ(Wxi * xt + Wri * ht-1 + bi)
// ft = σ(Wxf * xt + Wrf * ht-1 + bf)
// c't = tanh(Wxc * xt + Wrc * ht-1 + bc)
// ct = ft ◦ ct-1 + it ◦ c't
// ot = σ(Wxo * xt + Wro * ht-1 + bo)
// ht = ot ◦ tanh(ct)
// equations (peephole connections are present)
// it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi)
// ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf)
// c't = tanh(Wxc * xt + Wrc * ht-1 + bc)
// ct = ft ◦ ct-1 + it ◦ c't
// ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo)
// ht = ot ◦ tanh(ct)
// notations:
// bS - batch size
// sL - sequence length, number of time steps
// nIn - input size
// nOut - output size (hidden size)
// INPUTS:
// *******
// input x:
// 1) [sL, bS, nIn] when dataFormat == 0
// 2) [bS, sL, nIn] when dataFormat == 1
// 3) [bS, nIn, sL] when dataFormat == 2
// *******
// input weights Wx:
// 1) [nIn, 4*nOut] when directionMode < 2
// 2) [2, nIn, 4*nOut] when directionMode >= 2
// *******
// recurrent weights Wr:
// 1) [nOut, 4*nOut] when directionMode < 2
// 2) [2, nOut, 4*nOut] when directionMode >= 2
// *******
// peephole weights Wp:
// 1) [3*nOut] when directionMode < 2
// 2) [2, 3*nOut] when directionMode >= 2
// *******
// biases b:
// 1) [4*nOut] when directionMode < 2
// 2) [2, 4*nOut] when directionMode >= 2
// *******
// sequence length array seqLen:
// 1) [bS] always
// *******
// initial output hI:
// 1) [bS, nOut] when directionMode < 2
// 2) [2, bS, nOut] when directionMode >= 2
// *******
// initial cell state cI (same shape as in hI):
// 1) [bS, nOut] when directionMode < 2
// 2) [2, bS, nOut] when directionMode >= 2
// OUTPUTS:
// *******
// output h:
// 1) [sL, bS, nOut] when directionMode <= 2 && dataFormat == 0
// 2) [bS, sL, nOut] when directionMode <= 2 && dataFormat == 1
// 3) [bS, nOut, sL] when directionMode <= 2 && dataFormat == 2
// 4) [sL, bS, 2*nOut] when directionMode == 3 && dataFormat == 0
// 5) [bS, sL, 2*nOut] when directionMode == 3 && dataFormat == 1
// 6) [bS, 2*nOut, sL] when directionMode == 3 && dataFormat == 2
// 7) [sL, 2, bS, nOut] when directionMode == 4 && dataFormat == 3
// *******
// output at last step hL:
// 1) [bS, nOut] when directionMode < 2
// 2) [2, bS, nOut] when directionMode >= 2
// *******
// cell state at last step cL (same shape as in hL):
// 1) [bS, nOut] when directionMode < 2
// 2) [2, bS, nOut] when directionMode >= 2
// !!! dimension 4*nOut implies order it, ft, c't, ot
// !!! dimension 3*nOut implies order it, ft, ot
const auto dataFormat = INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, 2, bS, nOut] (for ONNX)
const auto directionMode = INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3)
// integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
const auto gateAct = INT_ARG(2); // activation for input (i), forget (f) and output (o) gates
const auto cellAct = INT_ARG(3); // activation for cell state (c)
const auto outAct = INT_ARG(4); // activation for output (h)
const auto hasBiases = B_ARG(0); // indicates whether biases array is provided
const auto hasSeqLen = B_ARG(1); // indicates whether seqLen array is provided
const auto hasInitH = B_ARG(2); // indicates whether initial output is provided
const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided
const auto hasPH = B_ARG(4); // indicates whether peephole connections are present
const auto retFullSeq = B_ARG(5); // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1}
const auto retLastH = B_ARG(6); // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)
const auto retLastC = B_ARG(7); // indicates whether to return cells state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)
const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8;
const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8;
const auto outActHasAlpha = outAct == 3 || outAct == 4 || outAct == 5 || outAct == 6 || outAct == 8;
const auto gateActHasBeta = gateAct == 3 || gateAct == 6;
const auto cellActHasBeta = cellAct == 3 || cellAct == 6;
const auto outActHasBeta = outAct == 3 || outAct == 6;
uint count = 1;
const auto cellClip = T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping
const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0;
const auto gateBeta = gateActHasBeta ? T_ARG(count++) : 0;
const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0;
const auto cellBeta = cellActHasBeta ? T_ARG(count++) : 0;
const auto outAlpha = outActHasAlpha ? T_ARG(count++) : 0;
const auto outBeta = outActHasBeta ? T_ARG(count++) : 0;
const auto x = INPUT_VARIABLE(0); // input
const auto Wx = INPUT_VARIABLE(1); // input weights
const auto Wr = INPUT_VARIABLE(2); // recurrent weights
count = 3;
const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases
const auto seqLen = hasSeqLen ? INPUT_VARIABLE(count++) : nullptr; // seqLen vector
const auto hI = hasInitH ? INPUT_VARIABLE(count++) : nullptr; // initial output
const auto cI = hasInitC ? INPUT_VARIABLE(count++) : nullptr; // initial cell state
const auto Wp = hasPH ? INPUT_VARIABLE(count++) : nullptr; // peephole weights
REQUIRE_TRUE(dataFormat < 3 || (dataFormat == 3 && directionMode == 4), 0, "LSTM_LAYER operation: if argument dataFormat = 3, then directionMode = 4, but got dataFormat = %i and directionMode = %i instead !", dataFormat, directionMode);
REQUIRE_TRUE(cellClip >= 0 , 0, "LSTM_LAYER operation: cell clipping value should be nonnegative (>=0) !");
REQUIRE_TRUE(retFullSeq || retLastH || retLastC, 0, "LSTM_LAYER operation: please specify what output arrays to produce !");
count = 0;
auto h = retFullSeq ? OUTPUT_VARIABLE(count++) : nullptr; // output
auto hL = retLastH ? OUTPUT_VARIABLE(count++) : nullptr; // output at last step
auto cL = retLastC ? OUTPUT_VARIABLE(count++) : nullptr; // cell state at last step
// evaluate dimensions
const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat);
const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(-2);
const Nd4jLong nIn = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(-1);
const Nd4jLong nOut = Wx->sizeAt(-1) / 4;
// inputs validations
if(directionMode < 2) { // no bidirectional
// Wx validation
if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn)
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx));
// Wr validation
if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut)
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr));
// biases validation
if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b));
// initial output validation
if(hI != nullptr && (hI->rankOf() != 2 || hI->sizeAt(0) != bS || hI->sizeAt(1) != nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI));
// initial cell validation
if(cI != nullptr && (cI->rankOf() != 2 || cI->sizeAt(0) != bS || cI->sizeAt(1) != nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI));
// peephole weights validation
if(Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3*nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp));
}
else { // bidirectional
// Wx validation
if(Wx->rankOf() != 3 || Wx->sizeAt(0) != 2 || Wx->sizeAt(1) != nIn)
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx));
// Wr validation
if(Wr->rankOf() != 3 || Wr->sizeAt(0) != 2 || Wr->sizeAt(1) != nOut || Wr->sizeAt(2) != 4*nOut)
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr));
// biases validation
if(b != nullptr && (b->rankOf() != 2 || b->sizeAt(0) != 2 || b->sizeAt(1) != 4*nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 4*nOut}).c_str(), ShapeUtils::shapeAsString(b));
// initial output validation
if(hI != nullptr && (hI->rankOf() != 3 || hI->sizeAt(0) != 2 || hI->sizeAt(1) != bS || hI->sizeAt(2) != nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI));
// initial cell validation
if(cI != nullptr && (cI->rankOf() != 3 || cI->sizeAt(0) != 2 || cI->sizeAt(1) != bS || cI->sizeAt(2) != nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI));
// peephole weights validation
if(Wp != nullptr && (Wp->rankOf() != 2 || Wp->sizeAt(0) != 2 || Wp->sizeAt(1) != 3*nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp));
}
std::vector<float> params = {static_cast<float>(dataFormat), static_cast<float>(directionMode), static_cast<float>(cellClip),
static_cast<float>(gateAct), static_cast<float>(gateAlpha), static_cast<float>(gateBeta),
static_cast<float>(cellAct), static_cast<float>(cellAlpha), static_cast<float>(cellBeta),
static_cast<float>(outAct), static_cast<float>(outAlpha), static_cast<float>(outBeta)};
if(directionMode == 0) { // forward
helpers::lstmLayerTimeLoop(x, Wx, Wr, b, seqLen, hI, cI, Wp, params, true, h, hL, cL);
}
else if(directionMode == 1) { // backward
helpers::lstmLayerTimeLoop(x, Wx, Wr, b, seqLen, hI, cI, Wp, params, false, h, hL, cL);
}
else { // bidirectional
NDArray WxFwd = (*Wx)({0,1, 0,0, 0,0});
NDArray WxBwd = (*Wx)({1,2, 0,0, 0,0});
NDArray WrFwd = (*Wr)({0,1, 0,0, 0,0});
NDArray WrBwd = (*Wr)({1,2, 0,0, 0,0});
NDArray *WpFwd(nullptr), *WpBwd(nullptr), *bFwd(nullptr), *bBwd(nullptr), *hIFwd(nullptr), *hIBwd(nullptr), *cIFwd(nullptr), *cIBwd(nullptr),
*hLFwd(nullptr), *hLBwd(nullptr), *cLFwd(nullptr), *cLBwd(nullptr), *hFwd(nullptr), *hBwd(nullptr);
if(Wp) {
WpFwd = new NDArray((*Wp)({0,1, 0,0}));
WpBwd = new NDArray((*Wp)({1,2, 0,0}));
}
if(b) {
bFwd = new NDArray((*b)({0,1, 0,0}));
bBwd = new NDArray((*b)({1,2, 0,0}));
}
if(hI) {
hIFwd = new NDArray((*hI)({0,1, 0,0, 0,0}));
hIBwd = new NDArray((*hI)({1,2, 0,0, 0,0}));
}
if(cI) {
cIFwd = new NDArray((*cI)({0,1, 0,0, 0,0}));
cIBwd = new NDArray((*cI)({1,2, 0,0, 0,0}));
}
if(hL) {
hLFwd = new NDArray((*hL)({0,1, 0,0, 0,0}));
hLBwd = new NDArray((*hL)({1,2, 0,0, 0,0}));
}
if(cL) {
cLFwd = new NDArray((*cL)({0,1, 0,0, 0,0}));
cLBwd = new NDArray((*cL)({1,2, 0,0, 0,0}));
}
if(h) {
if(directionMode == 2) { // sum
hFwd = h;
hBwd = new NDArray(h, false, h->getContext());
}
else if(directionMode == 3) { // concat
hFwd = new NDArray(dataFormat <= 1 ? (*h)({0,0, 0,0, 0,nOut}) : (*h)({0,0, 0,nOut, 0,0}));
hBwd = new NDArray(dataFormat <= 1 ? (*h)({0,0, 0,0, nOut,2*nOut}) : (*h)({0,0, nOut,2*nOut, 0,0}));
}
else { // directionMode == 4
hFwd = new NDArray((*h)({0,0, 0,1, 0,0, 0,0}));
hBwd = new NDArray((*h)({0,0, 1,2, 0,0, 0,0}));
}
}
// FIXME - following two calls are independent and may run in different streams
helpers::lstmLayerTimeLoop(x, &WxFwd, &WrFwd, bFwd, seqLen, hIFwd, cIFwd, WpFwd, params, true, hFwd, hLFwd, cLFwd);
helpers::lstmLayerTimeLoop(x, &WxBwd, &WrBwd, bBwd, seqLen, hIBwd, cIBwd, WpBwd, params, false, hBwd, hLBwd, cLBwd);
if(h && directionMode == 2)
*h += *hBwd;
delete WpFwd; delete WpBwd; delete bFwd; delete bBwd; delete hIFwd; delete hIBwd; delete cIFwd;
delete cIBwd; delete hLFwd; delete hLBwd; delete cLFwd; delete cLBwd; delete hBwd;
if(hFwd != h)
delete hFwd;
}
return Status::OK();
}
DECLARE_TYPES(lstmLayer) {
getOpDescriptor()
->setAllowedInputTypes(nd4j::DataType::ANY)
->setAllowedOutputTypes({ALL_FLOATS});
}
DECLARE_SHAPE_FN(lstmLayer) {
const auto dataFormat = INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, 2, bS, nIn] (for ONNX)
const auto directionMode = INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim
const auto retFullSeq = B_ARG(5); // indicates whether to return whole h {h_0, h_1, ... , h_sL-1}, if true, format would be [sL,bS,nOut] (exact shape depends on dataFormat argument)
const auto retLastH = B_ARG(6); // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)
const auto retLastC = B_ARG(7); // indicates whether to return cells state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)
const auto x = INPUT_VARIABLE(0); // input
const auto Wx = INPUT_VARIABLE(1); // input weights
const auto Wr = INPUT_VARIABLE(2); // recurrent weights
// evaluate dimensions
const Nd4jLong sL = dataFormat == 0 || dataFormat == 3 ? x->sizeAt(0) : ( dataFormat == 1 ? x->sizeAt(1) : x->sizeAt(2) );
const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(-2);
const Nd4jLong nIn = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(-1);
const Nd4jLong nOut = Wx->sizeAt(-1) / 4;
DataType type;
if(x->isR())
type = x->dataType();
else
type = nd4j::DataType::FLOAT32;
std::vector<Nd4jLong*> shapes;
// evaluate h shape (output)
if(retFullSeq) {
std::vector<Nd4jLong> hShape;
if(directionMode <= 2) { // single direction or bidirectional with sum
if(dataFormat == 0)
hShape = {sL, bS, nOut};
else if(dataFormat == 1)
hShape = {bS, sL, nOut};
else if(dataFormat == 2)
hShape = {bS, nOut, sL};
}
else if(directionMode == 3) { // bidirectional with concat
if(dataFormat == 0)
hShape = {sL, bS, 2*nOut};
else if(dataFormat == 1)
hShape = {bS, sL, 2*nOut};
else if(dataFormat == 2)
hShape = {bS, 2*nOut, sL};
}
else { // bidirectional with extra output dimension equal to 2
hShape = {sL, 2, bS, nOut};
}
shapes.push_back(ConstantShapeHelper::getInstance()->createShapeInfo(type, x->ordering(), hShape));
}
// evaluate hL shape (output at last step)
if(retLastH) {
std::vector<Nd4jLong> hLShape;
if(directionMode < 2)
hLShape = {bS, nOut};
else
hLShape = {2, bS, nOut};
shapes.push_back(ConstantShapeHelper::getInstance()->createShapeInfo(type, x->ordering(), hLShape));
if(retLastC) // cL and hL have same shapes
shapes.push_back(shapes.back());
}
// evaluate cL shape (cell state at last step)
if(retLastC && !retLastH) {
std::vector<Nd4jLong> cLShape;
if(directionMode < 2)
cLShape = {bS, nOut};
else
cLShape = {2, bS, nOut};
shapes.push_back(ConstantShapeHelper::getInstance()->createShapeInfo(type, x->ordering(), cLShape));
}
return new ShapeList(shapes);
}
}
}
#endif

View File

@ -231,6 +231,11 @@ namespace ops {
DECLARE_CUSTOM_OP(lstmBlock, 9, 7, false, 2, 2);
#endif
//////////////////////////////////////////////////////////////////////////
#if NOT_EXCLUDED(OP_lstmLayer)
DECLARE_CUSTOM_OP(lstmLayer, 3, 1, false, 1, 5);
#endif
//////////////////////////////////////////////////////////////////////////
/**
* Implementation of operations for Simple Recurrent Unit cell: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi

View File

@ -0,0 +1,460 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//
// implementation of operation for LSTM cell with peep hole connections:
// http://www.bioinf.jku.at/publications/older/2604.pdf
// S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.
// and
// https://research.google.com/pubs/archive/43905.pdf
// Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent neural network architectures for large scale acoustic modeling." INTERSPEECH, 2014.
#include <ops/declarable/helpers/lstmLayer.h>
#include <helpers/ShapeUtils.h>
// #include <VariableSpace.h>
// #include <ops/declarable/CustomOperations.h>
// #include<ops/declarable/helpers/transforms.h>
// #include <ops/declarable/helpers/legacy_helpers.h>
// #include <array/NDArrayList.h>
// #include <iterator>
// #include <MmulHelper.h>
namespace nd4j {
namespace ops {
namespace helpers {
//////////////////////////////////////////////////////////////////////////
void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
const std::vector<float>& params,
NDArray* h, NDArray* c) {
/************************ THIS IS NOT OPTIMAZED CODE ***********************************/
/** the objective is to provide math-readable code **/
// equations (no peephole connections)
// it = σ(Wxi * xt + Wri * ht-1 + bi)
// ft = σ(Wxf * xt + Wrf * ht-1 + bf)
// c't = tanh(Wxc * xt + Wrc * ht-1 + bc)
// ct = ft ◦ ct-1 + it ◦ c't
// ot = σ(Wxo * xt + Wro * ht-1 + bo)
// ht = ot ◦ tanh(ct)
// equations (peephole connections are present)
// it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi)
// ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf)
// c't = tanh(Wxc * xt + Wrc * ht-1 + bc)
// ct = ft ◦ ct-1 + it ◦ c't
// ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo)
// ht = ot ◦ tanh(ct)
// IDs for activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
// params[0] - dataFormat, ignore
// params[1] - directionMode, ignore
// params[2] - cell clipping value, if it = 0 then do not apply clipping
// params[3] - activation ID for input (i), forget (f) and output (o) gates
// params[4] - alpha value for gates activation
// params[5] - beta value for gates activation
// params[6] - activation ID for cell state (c)
// params[7] - alpha value for cell state activation
// params[8] - beta value for cell state activation
// params[9] - activation ID for output (h)
// params[10] - alpha value for output activation
// params[11] - beta value for output activation
// INPUTS:
// x - current input at time t, [bS, nIn] or [nIn] if seqLen != nullptr
// Wx - input weights [nIn, 4*nOut]
// Wr - recurrent weights [nOut, 4*nOut]
// b - biases [4*nOut], optional, may be nullptr
// hI - previous (initial) output at time t-1, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr
// cI - previous (initial) cell state at time t-1, optional may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr
// Wp - peephole weights [3*nOut], optional, may be nullptr
// OUTPUTS:
// h - current output, that is at current time step t, [bS, nOut] or [nOut] if seqLen != nullptr
// c - current cell state, that is at current time step t, [bS, nOut] or [nOut] if seqLen != nullptr
// !!! dimension 4*nOut implies order it, ft, c't, ot
// !!! dimension 3*nOut implies order it, ft, ot
const Nd4jLong nOut = Wx->sizeAt(-1) / 4;
auto z = mmul(*x, *Wx) + mmul(*hI, *Wr); // [bs, nIn] * [nIn, 4*nOut] + [bs, nOut] * [nOut, 4*nOut] = [bS, 4*nOut]
//or [nIn] * [nIn, 4*nOut] + [nOut] * [nOut, 4*nOut] = [4*nOut]
// add biases if they are given
if(b != nullptr)
z += *b; // broadcast [bS, 4*nOut] + [4*nOut] = [bS, 4*nOut]
auto zi = x->rankOf() == 1 ? z({0, nOut}) : z({0,0, 0, nOut}); // input gate it, [bS, nOut]
auto zf = x->rankOf() == 1 ? z({nOut, 2*nOut}) : z({0,0, nOut, 2*nOut}); // forget gate ft, [bS, nOut]
auto zc = x->rankOf() == 1 ? z({2*nOut, 3*nOut}) : z({0,0, 2*nOut, 3*nOut}); // cell gate c't, [bS, nOut]
auto zo = x->rankOf() == 1 ? z({3*nOut, 4*nOut}) : z({0,0, 3*nOut, 4*nOut}); // output gate ot, [bS, nOut]
// peephole connections for input and forget gates
if(Wp != nullptr) {
zi += *cI * (*Wp)({0, nOut}); // broadcast: [bS, nOut] + [bS, nOut] ◦ [nOut] = [bS, nOut]
zf += *cI * (*Wp)({nOut, 2*nOut}); // broadcast: [bS, nOut] + [bS, nOut] ◦ [nOut] = [bS, nOut]
}
applyActivation(zi, params[3], params[4], params[5], zi); // inplace
applyActivation(zf, params[3], params[4], params[5], zf); // inplace
applyActivation(zc, params[6], params[7], params[8], zc); // inplace
c->assign(zf * *cI + zi * zc); // [bS, nOut] ◦ [bS, nOut] + [bS, nOut] ◦ [bS, nOut] = [bS, nOut]
// if clipping value is non-zero then cell state is clipped by this value prior to the cell output activation
if(params[2] != 0)
c->applyScalar(scalar::LstmClip, params[2]);
// peephole connections for output gate
if(Wp != nullptr)
zo += *c * (*Wp)({2*nOut, 3*nOut}); // broadcast: [bS, nOut] + [nOut] ◦ [bS, nOut] = [bS, nOut]
applyActivation(zo, params[3], params[4], params[5], zo);
applyActivation(*c, params[9], params[10], params[11], *h);
*h *= zo; // [bS, nOut] ◦ [bS, nOut]
}
//////////////////////////////////////////////////////////////////////////
void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
const NDArray* b, const NDArray* seqLen, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
const std::vector<float>& params,
const bool forward,
NDArray* h, NDArray* hL, NDArray* cL) {
// INPUTS:
// x - current input [sL, bS, nIn], [bS, sL, nIn], [bS, nIn, sL],
// Wx - input weights [nIn, 4*nOut]
// Wr - recurrent weights [nOut, 4*nOut]
// b - biases [4*nOut], optional, may be nullptr
// seqLen - [bS], optional, may be nullptr
// hI - initial output [bS, nOut], optional, may be nullptr
// cI - initial cell state at time t-1 [bS, nOut], optional, may be nullptr
// Wp - peephole weights [3*nOut], optional, may be nullptr
// OUTPUTS:
// h - output [sL, bS, nOut], [bS, sL, nOut], [bS, nOut, sL], optional, may be nullptr
// hL - output at last step [bS, nOut], optional, may be nullptr
// cL - cell state at last step [bS, nOut], optional, may be nullptr
// params = {dataFormat, directionMode, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta};
// dataFormat: 0,3 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL]
const int dataFormat = params[0];
const int directionMode = params[1];
const Nd4jLong sL = x->sizeAt(dataFormat);
const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1);
const Nd4jLong nOut = Wx->sizeAt(-1) / 4;
const std::vector<Nd4jLong> shapeOut = {bS, nOut};
auto h0 = const_cast<NDArray*>(hI);
if(!hI) {
h0 = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext());
h0->nullify();
}
auto c0 = const_cast<NDArray*>(cI);
if(!cI) {
c0 = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext());
c0->nullify();
}
auto ct = cL;
if(!cL)
cL = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext());
auto ht = hL;
if(!h && !hL)
ht = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext());
// create sets of required (depends on seqLen presence) sub-arrays
std::vector<int> dims;
ResultSet *xSet(nullptr), *hSet(nullptr), *h0Set(nullptr), *c0Set(nullptr), *htSet(nullptr), *ctSet(nullptr);
if(!seqLen) {
dims = ShapeUtils::evalDimsToExclude(x->rankOf(), {dataFormat < 3 ? dataFormat : 0}); // points on bS and nIn/nOut axes
xSet = x->allTensorsAlongDimension(dims); // sub-arrays with shape [bS, nIn]
if(h)
hSet = h->allTensorsAlongDimension(dims); // sub-arrays with shape [bS, nOut]
}
else {
dims = dataFormat == 2 ? std::vector<int>({1}) : std::vector<int>({2}); // points on nIn/nOut axis
xSet = x->allTensorsAlongDimension(dims); // sub-arrays with shape [nIn]
h0Set = h0->allTensorsAlongDimension({1}); // sub-arrays with shape [nOut]
c0Set = c0->allTensorsAlongDimension({1}); // sub-arrays with shape [nOut]
ctSet = ct->allTensorsAlongDimension({1}); // sub-arrays with shape [nOut]
if(h)
hSet = h->allTensorsAlongDimension(dims); // sub-arrays with shape [nOut]
if(ht)
htSet = ht->allTensorsAlongDimension({1}); // sub-arrays with shape [nOut]
}
// loops
if(forward) {
if(!seqLen) {
if(!h) { // seqLen and h are absent
lstmLayerCell(xSet->at(0), Wx, Wr, b, h0, c0, Wp, params, ht, ct); // first time step
for (int t = 1; t < sL; ++t)
lstmLayerCell(xSet->at(t), Wx, Wr, b, ht, ct, Wp, params, ht, ct); // rest time steps
}
else { // seqLen is absent and h is present
lstmLayerCell(xSet->at(0), Wx, Wr, b, h0, c0, Wp, params, hSet->at(0), ct); // first time step
for (int t = 1; t < sL; ++t)
lstmLayerCell(xSet->at(t), Wx, Wr, b, hSet->at(t - 1), ct, Wp, params, hSet->at(t), ct); // rest time steps
if(hL)
hL->assign(hSet->at(sL - 1)); // assign last output to hL if it is not nullptr
}
}
else {
if(!h) { // seqLen is present and h is absent
for (int e = 0; e < bS; ++e) {
const int limit = seqLen->e<int>(e);
if(limit == 0) {
if(cL)
ctSet->at(e)->nullify();
if(hL)
htSet->at(e)->nullify();
continue;
}
auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, 0, e);
lstmLayerCell(xSet->at(ind), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // first time step
for (int t = 1; t < limit; ++t) {
ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
lstmLayerCell(xSet->at(ind), Wx, Wr, b, htSet->at(e), ctSet->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // rest time steps
}
}
}
else { // seqLen and h are present
for (int e = 0; e < bS; ++e) {
int limit = seqLen->e<int>(e);
if(limit == 0) {
tensorAlongTimeBatchDims(*h, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range
if(cL)
ctSet->at(e)->nullify();
if(hL)
htSet->at(e)->nullify();
continue;
}
auto indPrev = getBatchTimeTotalIndex(dataFormat, sL, bS, 0, e);
lstmLayerCell(xSet->at(indPrev), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, hSet->at(indPrev), ctSet->at(e)); // first time step
for (int t = 1; t < limit; ++t) {
auto indCurr = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
lstmLayerCell(xSet->at(indCurr), Wx, Wr, b, hSet->at(indPrev), ctSet->at(e), Wp, params, hSet->at(indCurr), ctSet->at(e)); // rest time steps
indPrev = indCurr;
}
if(hL)
htSet->at(e)->assign(hSet->at(indPrev)); // assign last output to hL if hL is not nullptr
tensorAlongTimeBatchDims(*h, dataFormat, limit,sL, e,e+1).nullify(); // nullify for given e and time range [limit, sL)
}
}
}
}
else { // backward
if(!seqLen) {
if(!h) { // seqLen and h are absent
lstmLayerCell(xSet->at(sL - 1), Wx, Wr, b, h0, c0, Wp, params, ht, ct); // first time step
for (int t = sL - 2; t >= 0; --t)
lstmLayerCell(xSet->at(t), Wx, Wr, b, ht, ct, Wp, params, ht, ct); // rest time steps
}
else { // seqLen is absent and h is present
lstmLayerCell(xSet->at(sL - 1), Wx, Wr, b, h0, c0, Wp, params, hSet->at(sL - 1), ct); // first time step
for (int t = sL - 2; t >= 0; --t)
lstmLayerCell(xSet->at(t), Wx, Wr, b, hSet->at(t + 1), ct, Wp, params, hSet->at(t), ct); // rest time steps
if(hL)
hL->assign(hSet->at(0)); // assign last output to hL if it is not nullptr
}
}
else if(directionMode == 1) { // only backward, no bidirectional mode
if(!h) { // h is absent and seqLen is present
for (int e = 0; e < bS; ++e) {
const int limit = seqLen->e<int>(e);
if(limit == 0) {
if(cL)
ctSet->at(e)->nullify();
if(hL)
htSet->at(e)->nullify();
continue;
}
auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, sL - 1, e);
lstmLayerCell(xSet->at(ind), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // first time step
for (int t = sL - 2; t >= sL - limit; --t) {
ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
lstmLayerCell(xSet->at(ind), Wx, Wr, b, htSet->at(e), ctSet->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // rest time steps
}
}
}
else { // seqLen and h are present
for (int e = 0; e < bS; ++e) {
int limit = seqLen->e<int>(e);
if(limit == 0) {
tensorAlongTimeBatchDims(*h, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range
if(cL)
ctSet->at(e)->nullify();
if(hL)
htSet->at(e)->nullify();
continue;
}
auto indPrev = getBatchTimeTotalIndex(dataFormat, sL, bS, sL - 1, e);
lstmLayerCell(xSet->at(indPrev), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, hSet->at(indPrev), ctSet->at(e)); // first time step
for (int t = sL - 2; t >= sL - limit; --t) {
auto indCurr = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
lstmLayerCell(xSet->at(indCurr), Wx, Wr, b, hSet->at(indPrev), ctSet->at(e), Wp, params, hSet->at(indCurr), ctSet->at(e)); // rest time steps
indPrev = indCurr;
}
if(hL)
htSet->at(e)->assign(hSet->at(indPrev)); // assign last output to hL if it is not nullptr
tensorAlongTimeBatchDims(*h, dataFormat, 0,sL-limit, e,e+1).nullify(); // nullify for given e and time range [limit, sL)
}
}
}
else { // backward in bidirectional mode
if(!h) { // h is absent and seqLen is present
for (int e = 0; e < bS; ++e) {
const int limit = seqLen->e<int>(e);
if(limit == 0) {
if(cL)
ctSet->at(e)->nullify();
if(hL)
htSet->at(e)->nullify();
continue;
}
auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, limit - 1, e);
lstmLayerCell(xSet->at(ind), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // first time step
for (int t = limit - 2; t >= 0; --t) {
ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
lstmLayerCell(xSet->at(ind), Wx, Wr, b, htSet->at(e), ctSet->at(e), Wp, params, htSet->at(e), ctSet->at(e)); // rest time steps
}
}
}
else { // seqLen and h are present
for (int e = 0; e < bS; ++e) {
int limit = seqLen->e<int>(e);
if(limit == 0) {
tensorAlongTimeBatchDims(*h, dataFormat, 0,0, e,e+1).nullify(); // nullify for given e and whole time range
if(cL)
ctSet->at(e)->nullify();
if(hL)
htSet->at(e)->nullify();
continue;
}
auto indPrev = getBatchTimeTotalIndex(dataFormat, sL, bS, limit - 1, e);
lstmLayerCell(xSet->at(indPrev), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, hSet->at(indPrev), ctSet->at(e)); // first time step
for (int t = limit - 2; t >= 0; --t) {
auto indCurr = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
lstmLayerCell(xSet->at(indCurr), Wx, Wr, b, hSet->at(indPrev), ctSet->at(e), Wp, params, hSet->at(indCurr), ctSet->at(e)); // rest time steps
indPrev = indCurr;
}
if(hL)
htSet->at(e)->assign(hSet->at(indPrev)); // assign last output to hL if it is not nullptr
tensorAlongTimeBatchDims(*h, dataFormat, limit,sL, e,e+1).nullify(); // nullify for given e and time range [limit, sL)
}
}
}
}
delete xSet;
delete hSet;
delete h0Set;
delete c0Set;
delete htSet;
delete ctSet;
}
}
}
}

View File

@ -0,0 +1,117 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#ifndef LIBND4J_LSTMLAYER_H
#define LIBND4J_LSTMLAYER_H
#include <ops/declarable/helpers/helpers.h>
#include <ops/declarable/helpers/activations.h>
namespace nd4j {
namespace ops {
namespace helpers {
//////////////////////////////////////////////////////////////////////////
void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
const std::vector<float>& params,
NDArray* h, NDArray* c);
//////////////////////////////////////////////////////////////////////////
void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
const NDArray* b, const NDArray* seqLen, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
const std::vector<float>& params,
const bool forward,
NDArray* h, NDArray* hL, NDArray* cL);
//////////////////////////////////////////////////////////////////////////
static FORCEINLINE void applyActivation(NDArray& x, const int opId, const float alpha, const float beta, NDArray& z) {
switch (opId) {
case 0:
(const_cast<NDArray&>(x)).applyTransform(transform::Tanh, &z);
break;
case 1:
(const_cast<NDArray&>(x)).applyScalar<float>(scalar::RELU, 0, &z);
break;
case 2:
(const_cast<NDArray&>(x)).applyTransform(transform::Sigmoid, &z);
break;
case 3: {
ExtraArguments args({ static_cast<double>(alpha), static_cast<double>(beta)});
(const_cast<NDArray&>(x)).applyTransform(transform::Affine, &z, &args);
break;
}
case 4:
(const_cast<NDArray&>(x)).applyScalar<float>(scalar::LeakyRELU, alpha, &z);
break;
case 5:
helpers::thresholdRelu(x.getContext(), x, alpha, z);
break;
case 6: {
ExtraArguments args({ static_cast<double>(alpha), static_cast<double>(beta)});
(const_cast<NDArray&>(x)).applyTransform(transform::ScaledTanh, &z, &args);
break;
}
case 7:
(const_cast<NDArray&>(x)).applyTransform(transform::HardSigmoid, &z);
break;
case 8:
(const_cast<NDArray&>(x)).applyScalar<float>(scalar::ELU, alpha, &z);
break;
case 9:
(const_cast<NDArray&>(x)).applyTransform(transform::SoftSign, &z);
break;
case 10:
(const_cast<NDArray&>(x)).applyTransform(transform::SoftPlus, &z);
break;
default:
throw std::invalid_argument("LSTM_LAYER operation: wrong id number of activation !");
}
}
//////////////////////////////////////////////////////////////////////////
static FORCEINLINE NDArray tensorAlongTimeBatchDims(const NDArray& arr, const int dataFormat, const int t1, const int t2, const int b1, const int b2) {
if(dataFormat == 0 || dataFormat == 3)
return arr({t1,t2, b1,b2, 0,0}); // TNS: [sL, bS, nIn]
if(dataFormat == 1)
return arr({b1,b2, t1,t2, 0,0}); // NTS: [bS, sL ,nIn]
return arr({b1,b2, 0,0, t1,t2}); // NST: [bS, nIn, sL]
}
//////////////////////////////////////////////////////////////////////////
static FORCEINLINE int getBatchTimeTotalIndex(const int dataFormat, const int sL, const int bS, const int t, const int b) {
if(dataFormat == 0 || dataFormat == 3)
return t * bS + b; // TNS: shape [sL, bS, nIn]
return b * sL + t; // NTS, NST: shape [bS, sL, nIn], [bS, nIn, sL]
}
}
}
}
#endif //LIBND4J_LSTMLAYER_H

View File

@ -0,0 +1,546 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include <ops/declarable/OpRegistrator.h>
#include "mkldnnUtils.h"
using namespace mkldnn;
namespace nd4j {
namespace ops {
namespace platforms {
static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
const NDArray* b, const NDArray* hI, const NDArray* cI,
const std::vector<float>& params,
NDArray* h, NDArray* hL, NDArray* cL) {
// equations (no peephole connections)
// it = σ(Wxi * xt + Wri * ht-1 + bi)
// ft = σ(Wxf * xt + Wrf * ht-1 + bf)
// c't = tanh(Wxc * xt + Wrc * ht-1 + bc)
// ct = ft ◦ ct-1 + it ◦ c't
// ot = σ(Wxo * xt + Wro * ht-1 + bo)
// ht = ot ◦ tanh(ct)
// notations:
// bS - batch size
// sL - sequence length, number of time steps
// nIn - input size
// nOut - output size (hidden size)
// INPUTS:
// *******
// input x:
// 1) [sL, bS, nIn] when dataFormat == 0
// *******
// input weights Wx:
// 1) [1, 1, nIn, 4*nOut] when directionMode < 2
// 2) [1, 2, nIn, 4*nOut] when directionMode >= 2
// *******
// recurrent weights Wr:
// 1) [1, 1, nOut, 4*nOut] when directionMode < 2
// 2) [1, 2, nOut, 4*nOut] when directionMode >= 2
// *******
// biases b:
// 1) [1, 1, 4*nOut] when directionMode < 2
// 2) [1, 2, 4*nOut] when directionMode >= 2
// *******
// initial output hI:
// 1) [1, 1, bS, nOut] when directionMode < 2
// 2) [1, 2, bS, nOut] when directionMode >= 2
// *******
// initial cell state cI (same shape as in hI):
// 1) [1, 1, bS, nOut] when directionMode < 2
// 2) [1, 2, bS, nOut] when directionMode >= 2
// OUTPUTS:
// *******
// output h:
// 1) [sL, bS, nOut] when directionMode <= 2 && dataFormat == 0
// 2) [sL, bS, 2*nOut] when directionMode == 3 && dataFormat == 0
// *******
// output at last step hL:
// 1) [1, 1, bS, nOut] when directionMode < 2
// 2) [1, 2, bS, nOut] when directionMode >= 2
// *******
// cell state at last step cL (same shape as in hL):
// 1) [1, 1, bS, nOut] when directionMode < 2
// 2) [1, 2, bS, nOut] when directionMode >= 2
// !!! dimension 4*nOut implies order it, ft, c't, ot
// !!! dimension 3*nOut implies order it, ft, ot
// params = {dataFormat, directionMode, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta};
// dataFormat: 0 = [sL, bS, nIn]
// directionMode: 0 = forward, 1 = backward, 2 = bidirectional sum, 3 = bidirectional concat
const int dataFormat = params[0];
const int directionMode = params[1];
const int sL = x->sizeAt(0); // dataFormat == 0 ? x->sizeAt(0) : x->sizeAt(1);
const int bS = x->sizeAt(1); // dataFormat == 0 ? x->sizeAt(1) : x->sizeAt(0);
const int nIn = x->sizeAt(-1);
const int nOut = Wx->sizeAt(-1);
const int dirDim = directionMode < 2 ? 1 : 2; // number of dimensionss, 1 unidirectional, 2 for bidirectional
const int hDirDim = directionMode <= 2 ? 1 : 2; // for h array, take into account bidirectional_sum mode (directionMode == 2)
// evaluate direction
rnn_direction direction;
switch (directionMode) {
case 0:
direction = rnn_direction::unidirectional_left2right;
break;
case 1:
direction = rnn_direction::unidirectional_right2left;
break;
case 2:
direction = rnn_direction::bidirectional_sum;
break;
default:
direction = rnn_direction::bidirectional_concat;
}
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
mkldnn_memory_desc_t empty;
mkldnn::memory::desc x_user_md, wx_user_md, wr_user_md, b_user_md, hI_user_md, cI_user_md, h_user_md, hL_user_md, cL_user_md,
x_lstm_md, wx_lstm_md, wr_lstm_md, b_lstm_md, hI_lstm_md, cI_lstm_md, h_lstm_md, hL_lstm_md, cL_lstm_md;
// input type
mkldnn::memory::data_type xType;
if(x->dataType() == DataType::FLOAT32)
xType = mkldnn::memory::data_type::f32;
else if(x->dataType() == DataType::HALF)
xType = mkldnn::memory::data_type::f16;
else
xType = mkldnn::memory::data_type::u8;
// weights type
mkldnn::memory::data_type wType = xType;
if(xType == mkldnn::memory::data_type::u8)
wType = mkldnn::memory::data_type::s8;
// bias type
mkldnn::memory::data_type bType = xType;
if(xType == mkldnn::memory::data_type::u8)
bType = mkldnn::memory::data_type::f32;
// output type
mkldnn::memory::data_type hType;
if(h->dataType() == DataType::FLOAT32)
hType = mkldnn::memory::data_type::f32;
else if(h->dataType() == DataType::HALF)
hType = mkldnn::memory::data_type::f16;
else
hType = mkldnn::memory::data_type::u8;
// memory descriptors for arrays
// x
x_lstm_md = mkldnn::memory::desc({sL, bS, nIn}, xType, mkldnn::memory::format_tag::any);
// x_user_md = dataFormat == 0 ? mkldnn::memory::desc({sL, bS, nIn}, type, mkldnn::memory::format_tag::tnc) : mkldnn::memory::desc({bS, sL, nIn}, type, mkldnn::memory::format_tag::ntc);
x_user_md = mkldnn::memory::desc({sL, bS, nIn}, xType, mkldnn::memory::format_tag::tnc);
x_user_md.data.format_kind = mkldnn_blocked; // overrides format
x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0];
x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1];
x_user_md.data.format_desc.blocking.strides[2] = x->stridesOf()[2];
// wx
wx_lstm_md = mkldnn::memory::desc({1,dirDim,nIn,4,nOut}, wType, mkldnn::memory::format_tag::any);
wx_user_md = mkldnn::memory::desc({1,dirDim,nIn,4,nOut}, wType, mkldnn::memory::format_tag::ldigo);
wx_user_md.data.format_kind = mkldnn_blocked; // overrides format
wx_user_md.data.format_desc.blocking.strides[0] = Wx->stridesOf()[0];
wx_user_md.data.format_desc.blocking.strides[1] = Wx->stridesOf()[1];
wx_user_md.data.format_desc.blocking.strides[2] = Wx->stridesOf()[2];
wx_user_md.data.format_desc.blocking.strides[3] = Wx->stridesOf()[3];
wx_user_md.data.format_desc.blocking.strides[4] = Wx->stridesOf()[4];
// wr
wr_lstm_md = mkldnn::memory::desc({1,dirDim,nOut,4,nOut}, wType, mkldnn::memory::format_tag::any);
wr_user_md = mkldnn::memory::desc({1,dirDim,nOut,4,nOut}, wType, mkldnn::memory::format_tag::ldigo);
wr_user_md.data.format_kind = mkldnn_blocked; // overrides format
wr_user_md.data.format_desc.blocking.strides[0] = Wr->stridesOf()[0];
wr_user_md.data.format_desc.blocking.strides[1] = Wr->stridesOf()[1];
wr_user_md.data.format_desc.blocking.strides[2] = Wr->stridesOf()[2];
wr_user_md.data.format_desc.blocking.strides[3] = Wr->stridesOf()[3];
wr_user_md.data.format_desc.blocking.strides[4] = Wr->stridesOf()[4];
// h
h_lstm_md = mkldnn::memory::desc({sL, bS, hDirDim*nOut}, hType, mkldnn::memory::format_tag::any);
// h_user_md = dataFormat == 0 ? mkldnn::memory::desc({sL, bS, hDirDim*nOut}, type, mkldnn::memory::format_tag::tnc) : mkldnn::memory::desc({bS, sL, hDirDim*nOut}, type, mkldnn::memory::format_tag::ntc);
h_user_md = mkldnn::memory::desc({sL, bS, hDirDim*nOut}, hType, mkldnn::memory::format_tag::tnc);
h_user_md.data.format_kind = mkldnn_blocked; // overrides format
h_user_md.data.format_desc.blocking.strides[0] = h->stridesOf()[0];
h_user_md.data.format_desc.blocking.strides[1] = h->stridesOf()[1];
h_user_md.data.format_desc.blocking.strides[2] = h->stridesOf()[2];
// b
if(b) {
b_lstm_md = mkldnn::memory::desc({1,dirDim,4,nOut}, bType, mkldnn::memory::format_tag::any);
b_user_md = mkldnn::memory::desc({1,dirDim,4,nOut}, bType, mkldnn::memory::format_tag::ldgo);
b_user_md.data.format_kind = mkldnn_blocked; // overrides format
b_user_md.data.format_desc.blocking.strides[0] = b->stridesOf()[0];
b_user_md.data.format_desc.blocking.strides[1] = b->stridesOf()[1];
b_user_md.data.format_desc.blocking.strides[2] = b->stridesOf()[2];
b_user_md.data.format_desc.blocking.strides[3] = b->stridesOf()[3];
}
// hI
if(hI) {
hI_lstm_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, xType, mkldnn::memory::format_tag::any);
hI_user_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, xType, mkldnn::memory::format_tag::ldnc);
hI_user_md.data.format_kind = mkldnn_blocked; // overrides format
hI_user_md.data.format_desc.blocking.strides[0] = hI->stridesOf()[0];
hI_user_md.data.format_desc.blocking.strides[1] = hI->stridesOf()[1];
hI_user_md.data.format_desc.blocking.strides[2] = hI->stridesOf()[2];
hI_user_md.data.format_desc.blocking.strides[3] = hI->stridesOf()[3];
}
// cI
if(cI) {
cI_lstm_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, xType, mkldnn::memory::format_tag::any);
cI_user_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, xType, mkldnn::memory::format_tag::ldnc);
cI_user_md.data.format_kind = mkldnn_blocked; // overrides format
cI_user_md.data.format_desc.blocking.strides[0] = cI->stridesOf()[0];
cI_user_md.data.format_desc.blocking.strides[1] = cI->stridesOf()[1];
cI_user_md.data.format_desc.blocking.strides[2] = cI->stridesOf()[2];
cI_user_md.data.format_desc.blocking.strides[2] = cI->stridesOf()[3];
}
// hL
if(hL) {
hL_lstm_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, hType, mkldnn::memory::format_tag::any);
hL_user_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, hType, mkldnn::memory::format_tag::ldnc);
hL_user_md.data.format_kind = mkldnn_blocked; // overrides format
hL_user_md.data.format_desc.blocking.strides[0] = hL->stridesOf()[0];
hL_user_md.data.format_desc.blocking.strides[1] = hL->stridesOf()[1];
hL_user_md.data.format_desc.blocking.strides[2] = hL->stridesOf()[2];
hL_user_md.data.format_desc.blocking.strides[3] = hL->stridesOf()[3];
}
if(cL) {
cL_lstm_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, hType, mkldnn::memory::format_tag::ldnc);
cL_user_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, hType, mkldnn::memory::format_tag::ldnc);
cL_user_md.data.format_kind = mkldnn_blocked; // overrides format
cL_user_md.data.format_desc.blocking.strides[0] = cL->stridesOf()[0];
cL_user_md.data.format_desc.blocking.strides[1] = cL->stridesOf()[1];
cL_user_md.data.format_desc.blocking.strides[2] = cL->stridesOf()[2];
cL_user_md.data.format_desc.blocking.strides[3] = cL->stridesOf()[3];
}
// lstm memory description
lstm_forward::desc lstm_desc(prop_kind::forward_inference, direction,
x_lstm_md, hI_lstm_md, cI_lstm_md, wx_lstm_md, wr_lstm_md, b_lstm_md,
h_lstm_md, hL_lstm_md, cL_lstm_md);
mkldnn::stream stream(engine);
// lstm primitive description
lstm_forward::primitive_desc lstm_prim_desc(lstm_desc, engine);
// arguments (memory buffers) necessary for calculations
std::unordered_map<int, mkldnn::memory> args;
// provide memory and check whether reorder is required
// x
auto x_user_mem = mkldnn::memory(x_user_md, engine, x->getBuffer());
const bool xReorder = lstm_prim_desc.src_layer_desc() != x_user_mem.get_desc();
auto x_lstm_mem = xReorder ? mkldnn::memory(lstm_prim_desc.src_layer_desc(), engine) : x_user_mem;
if (xReorder)
reorder(x_user_mem, x_lstm_mem).execute(stream, x_user_mem, x_lstm_mem);
args[MKLDNN_ARG_SRC_LAYER] = x_lstm_mem;
// wx
auto wx_user_mem = mkldnn::memory(wx_user_md, engine, Wx->getBuffer());
const bool wxReorder = lstm_prim_desc.weights_layer_desc()!= wx_user_mem.get_desc();
auto wx_lstm_mem = wxReorder ? mkldnn::memory(lstm_prim_desc.weights_layer_desc(), engine) : wx_user_mem;
if (wxReorder)
reorder(wx_user_mem, wx_lstm_mem).execute(stream, wx_user_mem, wx_lstm_mem);
args[MKLDNN_ARG_WEIGHTS_LAYER] = wx_lstm_mem;
// wr
auto wr_user_mem = mkldnn::memory(wr_user_md, engine, Wr->getBuffer());
const bool wrReorder = lstm_prim_desc.weights_iter_desc() != wr_user_mem.get_desc();
auto wr_lstm_mem = wxReorder ? mkldnn::memory(lstm_prim_desc.weights_iter_desc(), engine) : wr_user_mem;
if (wrReorder)
reorder(wr_user_mem, wr_lstm_mem).execute(stream, wr_user_mem, wr_lstm_mem);
args[MKLDNN_ARG_WEIGHTS_ITER] = wr_lstm_mem;
// h
auto h_user_mem = mkldnn::memory(h_user_md, engine, h->getBuffer());
const bool hReorder = lstm_prim_desc.dst_layer_desc() != h_user_mem.get_desc();
auto h_lstm_mem = hReorder ? mkldnn::memory(lstm_prim_desc.dst_layer_desc(), engine) : h_user_mem;
args[MKLDNN_ARG_DST_LAYER] = h_lstm_mem;
// b
if(b) {
auto b_user_mem = mkldnn::memory(b_user_md, engine, b->getBuffer());
const bool bReorder = lstm_prim_desc.bias_desc() != b_user_mem.get_desc();
auto b_lstm_mem = bReorder ? mkldnn::memory(lstm_prim_desc.bias_desc(), engine) : b_user_mem;
if (bReorder)
reorder(b_user_mem, b_lstm_mem).execute(stream, b_user_mem, b_lstm_mem);
args[MKLDNN_ARG_BIAS] = b_lstm_mem;
}
// hI
if(hI) {
auto hI_user_mem = mkldnn::memory(hI_user_md, engine, hI->getBuffer());
const bool hIReorder = lstm_prim_desc.src_iter_desc() != hI_user_mem.get_desc();
auto hI_lstm_mem = hIReorder ? mkldnn::memory(lstm_prim_desc.src_iter_desc(), engine) : hI_user_mem;
if (hIReorder)
reorder(hI_user_mem, hI_lstm_mem).execute(stream, hI_user_mem, hI_lstm_mem);
args[MKLDNN_ARG_SRC_ITER] = hI_lstm_mem;
}
// cI
if(cI) {
auto cI_user_mem = mkldnn::memory(cI_user_md, engine, cI->getBuffer());
const bool cIReorder = lstm_prim_desc.src_iter_c_desc() != cI_user_mem.get_desc();
auto cI_lstm_mem = cIReorder ? mkldnn::memory(lstm_prim_desc.src_iter_c_desc(), engine) : cI_user_mem;
if (cIReorder)
reorder(cI_user_mem, cI_lstm_mem).execute(stream, cI_user_mem, cI_lstm_mem);
args[MKLDNN_ARG_SRC_ITER_C] = cI_lstm_mem;
}
bool hLReorder(false), cLReorder(false);
mkldnn::memory hL_user_mem, cL_user_mem, hL_lstm_mem, cL_lstm_mem;
// hL
if(hL) {
hL_user_mem = mkldnn::memory(hL_user_md, engine, hL->getBuffer());
hLReorder = lstm_prim_desc.dst_iter_desc() != hL_user_mem.get_desc();
hL_lstm_mem = hLReorder ? mkldnn::memory(lstm_prim_desc.dst_iter_desc(), engine) : hL_user_mem;
args[MKLDNN_ARG_DST_ITER] = hL_lstm_mem;
}
// cL
if(cL) {
cL_user_mem = mkldnn::memory(cL_user_md, engine, cL->getBuffer());
cLReorder = lstm_prim_desc.dst_iter_c_desc() != cL_user_mem.get_desc();
cL_lstm_mem = cLReorder ? mkldnn::memory(lstm_prim_desc.dst_iter_c_desc(), engine) : cL_user_mem;
args[MKLDNN_ARG_DST_ITER_C] = cL_lstm_mem;
}
// run calculations
lstm_forward(lstm_prim_desc).execute(stream, args);
// reorder outputs if necessary
if (hReorder)
reorder(h_lstm_mem, h_user_mem).execute(stream, h_lstm_mem, h_user_mem);
if(hLReorder)
reorder(hL_lstm_mem, hL_user_mem).execute(stream, hL_lstm_mem, hL_user_mem);
if(cLReorder)
reorder(cL_lstm_mem, cL_user_mem).execute(stream, cL_lstm_mem, cL_user_mem);
stream.wait();
}
//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(lstmLayer) {
const auto dataFormat = INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, 2, bS, nOut] (for ONNX)
const auto directionMode = INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3)
const auto hasBiases = B_ARG(0); // indicates whether biases array is provided
const auto hasSeqLen = B_ARG(1); // indicates whether seqLen array is provided
const auto hasInitH = B_ARG(2); // indicates whether initial output is provided
const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided
const auto hasPH = B_ARG(4); // indicates whether peephole connections are present
const auto retFullSeq = B_ARG(5); // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1}
const auto retLastH = B_ARG(6); // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)
const auto retLastC = B_ARG(7); // indicates whether to return cells state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)
const auto cellClip = T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping
const auto x = INPUT_VARIABLE(0); // input
const auto Wx = INPUT_VARIABLE(1); // input weights
const auto Wr = INPUT_VARIABLE(2); // recurrent weights
int count = 3;
const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases
const auto seqLen = hasSeqLen ? INPUT_VARIABLE(count++) : nullptr; // seqLen vector
const auto hI = hasInitH ? INPUT_VARIABLE(count++) : nullptr; // initial output
const auto cI = hasInitC ? INPUT_VARIABLE(count++) : nullptr; // initial cell state
const auto Wp = hasPH ? INPUT_VARIABLE(count++) : nullptr; // peephole weights
REQUIRE_TRUE(cellClip == 0 , 0, "LSTM_LAYER_MKLDNN operation: cell clipping is not supported currently !");
REQUIRE_TRUE(retFullSeq, 0, "LSTM_LAYER_MKLDNN operation: option to calculate full time sequence output h should be always true in case of mkl dnn library !");
REQUIRE_TRUE(hasPH == false , 0, "LSTM_LAYER_MKLDNN operation: mkl dnn library doesn't support peephole connections !");
REQUIRE_TRUE(hasSeqLen == false, 0, "LSTM_LAYER_MKLDNN operation: mkl dnn library doesn't support array specifying max time step per each example in batch !");
REQUIRE_TRUE(dataFormat < 2, 0, "LSTM_LAYER_MKLDNN operation: wrong data format, only two formats are allowed for input/output tensors in mkl dnn library: TNC and NTC!");
REQUIRE_TRUE(directionMode < 4, 0, "LSTM_LAYER_MKLDNN operation: option for bidirectional extra output dimension is not valid in mkl dnn library !");
REQUIRE_TRUE((retLastH && retLastC) || (!retLastH && !retLastC), 0, "LSTM_LAYER_MKLDNN operation: only two options are present: 1) calculate both output at last time and cell state at last time; 2) do not calculate both !");
count = 0;
auto h = retFullSeq ? OUTPUT_VARIABLE(count++) : nullptr; // output
auto hL = retLastH ? OUTPUT_VARIABLE(count++) : nullptr; // output at last step
auto cL = retLastC ? OUTPUT_VARIABLE(count++) : nullptr; // cell state at last step
// evaluate dimensions
const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat);
const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(-2);
const Nd4jLong nIn = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(-1);
const Nd4jLong nOut = Wx->sizeAt(-1) / 4;
// inputs validations
if(directionMode < 2) { // no bidirectional
// Wx validation
if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn)
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx));
// Wr validation
if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut)
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr));
// biases validation
if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b));
// initial output validation
if(hI != nullptr && (hI->rankOf() != 2 || hI->sizeAt(0) != bS || hI->sizeAt(1) != nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI));
// initial cell validation
if(cI != nullptr && (cI->rankOf() != 2 || cI->sizeAt(0) != bS || cI->sizeAt(1) != nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI));
}
else { // bidirectional
// Wx validation
if(Wx->rankOf() != 3 || Wx->sizeAt(0) != 2 || Wx->sizeAt(1) != nIn)
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx));
// Wr validation
if(Wr->rankOf() != 3 || Wr->sizeAt(0) != 2 || Wr->sizeAt(1) != nOut || Wr->sizeAt(2) != 4*nOut)
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr));
// biases validation
if(b != nullptr && (b->rankOf() != 2 || b->sizeAt(0) != 2 || b->sizeAt(1) != 4*nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 4*nOut}).c_str(), ShapeUtils::shapeAsString(b));
// initial output validation
if(hI != nullptr && (hI->rankOf() != 3 || hI->sizeAt(0) != 2 || hI->sizeAt(1) != bS || hI->sizeAt(2) != nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI));
// initial cell validation
if(cI != nullptr && (cI->rankOf() != 3 || cI->sizeAt(0) != 2 || cI->sizeAt(1) != bS || cI->sizeAt(2) != nOut))
REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI));
}
std::vector<float> params = {static_cast<float>(dataFormat), static_cast<float>(directionMode), static_cast<float>(cellClip)};
const int dirDim = directionMode < 2 ? 1 : 2; // number of dimensions, 1 unidirectional, 2 for bidirectional
// permut x and h to tnc format if they have ntc format
NDArray* xP(const_cast<NDArray*>(x)), *hP(h);
if(dataFormat == 1) {
xP = new NDArray(x->permute({1,0,2})); // [bS, sL, nIn] -> [sL, bS, nIn]
hP = new NDArray(h->permute({1,0,2})); // [bS, sL, dirDim*nOn] -> [sL, bS, dirDim*nOn]
}
// reshape arrays in accordance to mkl allowed formats
NDArray *WxR(nullptr), *WrR(nullptr), *bR(nullptr), *hIR(nullptr), *cIR(nullptr), *hLR(nullptr), *cLR(nullptr);
WxR = new NDArray(Wx->reshape(Wx->ordering(), {1,dirDim,nIn,4,nOut}));
WrR = new NDArray(Wr->reshape(Wr->ordering(), {1,dirDim,nOut,4,nOut}));
if(b)
bR = new NDArray(b->reshape(b->ordering(), {1,dirDim,4,nOut}));
if(hI)
hIR = new NDArray(hI->reshape(hI->ordering(), {1,dirDim,bS,nOut}));
if(cI)
cIR = new NDArray(cI->reshape(cI->ordering(), {1,dirDim,bS,nOut}));
if(hL)
hLR = new NDArray(hL->reshape(hL->ordering(), {1,dirDim,bS,nOut}));
if(cL)
cLR = new NDArray(cL->reshape(cL->ordering(), {1,dirDim,bS,nOut}));
lstmLayerMKLDNN(xP, WxR, WrR, bR, hIR, cIR, params, hP, hLR, cLR);
delete WxR;
delete WrR;
delete bR;
delete hIR;
delete cIR;
delete hLR;
delete cLR;
if(dataFormat == 1) {
delete xP;
delete hP;
}
return Status::OK();
}
PLATFORM_CHECK(lstmLayer) {
// we don't want to use mkldnn if cpu doesn't support avx/avx2
// if (::optimalLevel() < 2) {
// return false;
// }
const auto hasBiases = B_ARG(0); // indicates whether biases array is provided
const auto hasInitH = B_ARG(2); // indicates whether initial output is provided
const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided
const auto retFullSeq = B_ARG(5); // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1}
const auto retLastH = B_ARG(6); // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)
const auto retLastC = B_ARG(7); // indicates whether to return cells state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)
const auto x = INPUT_VARIABLE(0); // input
const auto Wx = INPUT_VARIABLE(1); // input weights
const auto Wr = INPUT_VARIABLE(2); // recurrent weights
int count = 3;
const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases
const auto hI = hasInitH ? INPUT_VARIABLE(count++) : nullptr; // initial output
const auto cI = hasInitC ? INPUT_VARIABLE(count++) : nullptr; // initial cell state
count = 0;
auto h = retFullSeq ? OUTPUT_VARIABLE(count++) : nullptr; // output
auto hL = retLastH ? OUTPUT_VARIABLE(count++) : nullptr; // output at last step
auto cL = retLastC ? OUTPUT_VARIABLE(count++) : nullptr; // cell state at last step
DataType xType = x->dataType();
DataType WxType = Wx->dataType();
DataType WrType = Wr->dataType();
DataType bType = b != nullptr ? b->dataType() : (xType == DataType::HALF ? xType : DataType::FLOAT32);
DataType hIType = hI != nullptr ? hI->dataType() : xType;
DataType cIType = cI != nullptr ? hI->dataType() : xType;
DataType hType = h != nullptr ? h->dataType() : xType;
DataType hLType = hL != nullptr ? hL->dataType() : xType;
DataType cLType = cL != nullptr ? cL->dataType() : xType;
return block.isUseMKLDNN() && (
(xType==DataType::FLOAT32 && WxType==DataType::FLOAT32 && WrType==DataType::FLOAT32 && bType==DataType::FLOAT32 && hIType==DataType::FLOAT32 && cIType==DataType::FLOAT32 && hType==DataType::FLOAT32 && hLType==DataType::FLOAT32 && cLType==DataType::FLOAT32) ||
(xType==DataType::HALF && WxType==DataType::HALF && WrType==DataType::HALF && bType==DataType::HALF && hIType==DataType::HALF && cIType==DataType::HALF && hType==DataType::HALF && hLType==DataType::HALF && cLType==DataType::HALF) ||
(xType==DataType::UINT8 && WxType==DataType::INT8 && WrType==DataType::INT8 && bType==DataType::FLOAT32 && hIType==DataType::UINT8 && cIType==DataType::UINT8 && (hType==DataType::FLOAT32 && hLType==DataType::FLOAT32 && cLType==DataType::FLOAT32 || hType==DataType::UINT8 && hLType==DataType::UINT8 && cLType==DataType::UINT8))
);
}
}
}
}

View File

@ -63,6 +63,8 @@ namespace nd4j{
DECLARE_PLATFORM(lrn);
DECLARE_PLATFORM(batchnorm_new);
DECLARE_PLATFORM(lstmLayer);
}
}

View File

@ -1857,6 +1857,17 @@ namespace simdOps {
}
};
template <typename X>
class Affine {
public:
no_op_exec_special_same
no_op_exec_special_same_cuda
op_def static X op(X d1, X *params) {
return params[0] * d1 + params[1];
}
};
template <typename X>
class SigmoidDerivative {
public:
@ -2051,6 +2062,17 @@ namespace simdOps {
}
};
template <typename X>
class ScaledTanh {
public:
no_op_exec_special_same
no_op_exec_special_same_cuda
op_def static X op(X d1, X *params) {
return params[0] * nd4j::math::nd4j_tanh<X, X>(params[1] * d1);
}
};
template <typename X>
class RectifiedTanh {
public:

View File

@ -983,5 +983,952 @@ TEST_F(DeclarableOpsTests13, mergemax_2) {
ASSERT_EQ(20, status);
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_1) {
const int sL = 5;
const int bS = 3;
const int nIn = 3;
const int nOut = 3;
// input arguments
const int dataFormat = 0; // [sL,bS,nIn]
const int directionMode = 0; // forward
const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
const int cellAct = 0; // tanh activation for cell state
const int outAct = 0; // tanh activation for output
const bool hasBiases = true; // biases array is provided
const bool hasSeqLen = false; // seqLen array is not provided
const auto hasInitH = true; // initial output is provided
const auto hasInitC = true; // initial cell state is provided
const auto hasPH = false; // peephole connections are absent
const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut]
const auto retLastH = true; // do not return output at last time step
const auto retLastC = true; // return cells state at last time step
const double cellClip = 0; // do not apply clipping
NDArray x('c', {sL, bS, nIn}, nd4j::DataType::FLOAT32);
NDArray Wx('c', {nIn, 4*nOut}, nd4j::DataType::FLOAT32);
NDArray Wr('c', {nOut, 4*nOut}, nd4j::DataType::FLOAT32);
NDArray b('c', {4*nOut}, nd4j::DataType::FLOAT32);
NDArray hI('c', {bS, nOut}, nd4j::DataType::FLOAT32);
NDArray cI('c', {bS, nOut}, nd4j::DataType::FLOAT32);
x.linspace(0.5, 0.5);
Wx = 0.003;
Wr = 0.006;
b = 0.5;
hI = 1.;
cI = 2.;
std::initializer_list<double> tArgs = {cellClip};
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
auto expH = NDArrayFactory::create<float>('c', {sL, bS, nOut}, {0.57574,0.57574,0.57574,0.58006,0.58006,0.58006,0.58434,0.58434,0.58434,
0.55114,0.55114,0.55114,0.55732,0.55732,0.55732,0.56338,0.56338,0.56338,
0.53763,0.53763,0.53763,0.54534,0.54534,0.54534,0.55287,0.55287,0.55287,
0.53626,0.53626,0.53626,0.54487,0.54487,0.54487,0.55327,0.55327,0.55327,
0.54484,0.54484,0.54484,0.55379,0.55379,0.55379,0.5625 ,0.5625 ,0.5625});
auto expClast = NDArrayFactory::create<float>('c', {bS, nOut}, {1.1589154,1.1589154,1.1589154,1.1892855,1.1892855,1.1892855,1.219861 ,1.219861 ,1.219861});
nd4j::ops::lstmLayer op;
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *h = results->at(0);
auto *cL = results->at(2);
ASSERT_TRUE(expH.isSameShape(h));
ASSERT_TRUE(expH.equalsTo(h));
ASSERT_TRUE(expClast.isSameShape(cL));
ASSERT_TRUE(expClast.equalsTo(cL));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_2) {
const int sL = 5;
const int bS = 3;
const int nIn = 3;
const int nOut = 3;
// input arguments
const int dataFormat = 1; // [bS,sL,nIn]
const int directionMode = 0; // forward
const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
const int cellAct = 0; // tanh activation for cell state
const int outAct = 0; // tanh activation for output
const bool hasBiases = true; // biases array is provided
const bool hasSeqLen = false; // seqLen array is not provided
const auto hasInitH = true; // initial output is provided
const auto hasInitC = true; // initial cell state is provided
const auto hasPH = false; // peephole connections are absent
const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut]
const auto retLastH = true; // do not return output at last time step
const auto retLastC = true; // return cells state at last time step
const double cellClip = 0; // do not apply clipping
NDArray x('c', {bS, sL, nIn}, nd4j::DataType::FLOAT32);
NDArray Wx('c', {nIn, 4*nOut}, nd4j::DataType::FLOAT32);
NDArray Wr('c', {nOut, 4*nOut}, nd4j::DataType::FLOAT32);
NDArray b('c', {4*nOut}, nd4j::DataType::FLOAT32);
NDArray hI('c', {bS, nOut}, nd4j::DataType::FLOAT32);
NDArray cI('c', {bS, nOut}, nd4j::DataType::FLOAT32);
x.linspace(0.5, 0.5);
Wx = 0.003;
Wr = 0.006;
b = 0.5;
hI = 1.;
cI = 2.;
std::initializer_list<double> tArgs = {cellClip};
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
auto expH = NDArrayFactory::create<float>('c', {bS, sL, nOut}, {0.575735, 0.575735, 0.575735, 0.541562, 0.541562, 0.541562, 0.514003, 0.514003, 0.514003, 0.495597, 0.495597, 0.495597, 0.485999, 0.485999, 0.485999,
0.596965, 0.596965, 0.596965, 0.571978, 0.571978, 0.571978, 0.552888, 0.552888, 0.552888, 0.540606, 0.540606, 0.540606, 0.534764, 0.534764, 0.534764,
0.61725 , 0.61725 , 0.61725 , 0.599828, 0.599828, 0.599828, 0.587627, 0.587627, 0.587627, 0.580408, 0.580408, 0.580408, 0.577735, 0.577735, 0.577735});
auto expClast = NDArrayFactory::create<float>('c', {bS, nOut}, {0.996965, 0.996965, 0.996965, 1.146756, 1.146756, 1.146756, 1.301922, 1.301922, 1.301922});
nd4j::ops::lstmLayer op;
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *h = results->at(0);
auto *cL = results->at(2);
ASSERT_TRUE(expH.isSameShape(h));
ASSERT_TRUE(expH.equalsTo(h));
ASSERT_TRUE(expClast.isSameShape(cL));
ASSERT_TRUE(expClast.equalsTo(cL));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_3) {
const int sL = 5;
const int bS = 2;
const int nIn = 4;
const int nOut = 3;
// input arguments
const int dataFormat = 0; // [sL,bS,nIn]
const int directionMode = 1; // backward
const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
const int cellAct = 0; // tanh activation for cell state
const int outAct = 0; // tanh activation for output
const bool hasBiases = true; // biases array is provided
const bool hasSeqLen = false; // seqLen array is not provided
const auto hasInitH = true; // initial output is provided
const auto hasInitC = true; // initial cell state is provided
const auto hasPH = false; // peephole connections are absent
const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut]
const auto retLastH = true; // do not return output at last time step
const auto retLastC = true; // return cells state at last time step
const double cellClip = 0; // do not apply clipping
NDArray x('c', {sL,bS, nIn}, nd4j::DataType::FLOAT32);
NDArray Wx('c', {nIn, 4*nOut}, nd4j::DataType::FLOAT32);
NDArray Wr('c', {nOut, 4*nOut}, nd4j::DataType::FLOAT32);
NDArray b('c', {4*nOut}, nd4j::DataType::FLOAT32);
NDArray hI('c', {bS, nOut}, nd4j::DataType::FLOAT32);
NDArray cI('c', {bS, nOut}, nd4j::DataType::FLOAT32);
x.linspace(0.5, 0.5);
Wx = 0.003;
Wr = 0.006;
b = 0.5;
hI = 1.;
cI = 2.;
std::initializer_list<double> tArgs = {cellClip};
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
NDArray expH('c', {sL, bS, nOut}, {0.493883, 0.493883, 0.493883, 0.510990, 0.510990, 0.510990, 0.534701, 0.534701, 0.534701, 0.549139,
0.549139, 0.549139, 0.571900, 0.571900, 0.571900, 0.583561, 0.583561, 0.583561, 0.605106, 0.605106,
0.605106, 0.614114, 0.614114, 0.614114, 0.635354, 0.635354, 0.635354, 0.642045, 0.642045, 0.642045}, nd4j::DataType::FLOAT32);
NDArray expHL('c', {bS, nOut}, {0.493883, 0.493883, 0.493883, 0.510990, 0.510990, 0.510990}, nd4j::DataType::FLOAT32);
NDArray expCL('c', {bS, nOut}, {1.061274, 1.061274, 1.061274, 1.115888, 1.115888, 1.115888}, nd4j::DataType::FLOAT32);
nd4j::ops::lstmLayer op;
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto h = results->at(0);
auto hL = results->at(1);
auto cL = results->at(2);
ASSERT_TRUE(expH.isSameShape(h));
ASSERT_TRUE(expH.equalsTo(h));
ASSERT_TRUE(expHL.isSameShape(hL));
ASSERT_TRUE(expHL.equalsTo(hL));
ASSERT_TRUE(expCL.isSameShape(cL));
ASSERT_TRUE(expCL.equalsTo(cL));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_4) {
const int sL = 5;
const int bS = 2;
const int nIn = 4;
const int nOut = 3;
// input arguments
const int dataFormat = 0; // [sL,bS,nIn]
const int directionMode = 3; // bidirectional concat
const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
const int cellAct = 0; // tanh activation for cell state
const int outAct = 0; // tanh activation for output
const bool hasBiases = true; // biases array is provided
const bool hasSeqLen = false; // seqLen array is not provided
const auto hasInitH = true; // initial output is provided
const auto hasInitC = true; // initial cell state is provided
const auto hasPH = false; // peephole connections are absent
const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut]
const auto retLastH = true; // do not return output at last time step
const auto retLastC = true; // return cells state at last time step
const double cellClip = 0; // do not apply clipping
NDArray x('c', {sL, bS, nIn}, nd4j::DataType::FLOAT32);
NDArray Wx('c', {2,nIn, 4*nOut}, nd4j::DataType::FLOAT32);
NDArray Wr('c', {2,nOut, 4*nOut}, nd4j::DataType::FLOAT32);
NDArray b('c', {2,4*nOut}, nd4j::DataType::FLOAT32);
NDArray hI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32);
NDArray cI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32);
x.linspace(0.5, 0.5);
Wx({0,1, 0,0, 0,0}) = 0.003;
Wx({1,2, 0,0, 0,0}) = -0.003;
Wr({0,1, 0,0, 0,0}) = 0.006;
Wr({1,2, 0,0, 0,0}) = -0.006;
b({0,1, 0,0}) = 0.5;
b({1,2, 0,0}) = -0.5;
hI({0,1, 0,0, 0,0}) = 1;
hI({1,2, 0,0, 0,0}) = -1;
cI({0,1, 0,0, 0,0}) = 2;
cI({1,2, 0,0, 0,0}) = -2;
std::initializer_list<double> tArgs = {cellClip};
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
NDArray expH('c', {sL, bS, 2*nOut}, {0.577661, 0.577661, 0.577661, -0.107642, -0.107642, -0.107642, 0.585289, 0.585289, 0.585289,
-0.106937, -0.106937, -0.106937, 0.556517, 0.556517, 0.556517, -0.111647, -0.111647, -0.111647,
0.567274, 0.567274, 0.567274, -0.110214, -0.110214, -0.110214, 0.547395, 0.547395, 0.547395,
-0.123305, -0.123305, -0.123305, 0.560640, 0.560640, 0.560640, -0.120862, -0.120862, -0.120862,
0.550714, 0.550714, 0.550714, -0.156223, -0.156223, -0.156223, 0.565308, 0.565308, 0.565308,
-0.152313, -0.152313, -0.152313, 0.563741, 0.563741, 0.563741, -0.234128, -0.234128, -0.234128,
0.578676, 0.578676, 0.578676, -0.228917, -0.228917, -0.228917}, nd4j::DataType::FLOAT32);
NDArray expHL('c', {2,bS, nOut}, {0.563741, 0.563741, 0.563741, 0.578676, 0.578676, 0.578676, -0.107642,
-0.107642, -0.107642, -0.106937, -0.106937, -0.106937}, nd4j::DataType::FLOAT32);
NDArray expCL('c', {2,bS, nOut}, {1.217757, 1.217757, 1.217757, 1.272398, 1.272398, 1.272398, -0.295768,
-0.295768, -0.295768, -0.298453, -0.298453, -0.298453}, nd4j::DataType::FLOAT32);
nd4j::ops::lstmLayer op;
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto h = results->at(0);
auto hL = results->at(1);
auto cL = results->at(2);
ASSERT_TRUE(expH.isSameShape(h));
ASSERT_TRUE(expH.equalsTo(h));
ASSERT_TRUE(expHL.isSameShape(hL));
ASSERT_TRUE(expHL.equalsTo(hL));
ASSERT_TRUE(expCL.isSameShape(cL));
ASSERT_TRUE(expCL.equalsTo(cL));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_5) {
const int sL = 5;
const int bS = 2;
const int nIn = 4;
const int nOut = 3;
// input arguments
const int dataFormat = 1; // [bS,sL,nIn]
const int directionMode = 3; // bidirectional concat
const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
const int cellAct = 0; // tanh activation for cell state
const int outAct = 0; // tanh activation for output
const bool hasBiases = true; // biases array is provided
const bool hasSeqLen = false; // seqLen array is not provided
const auto hasInitH = true; // initial output is provided
const auto hasInitC = true; // initial cell state is provided
const auto hasPH = false; // peephole connections are absent
const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut]
const auto retLastH = true; // do not return output at last time step
const auto retLastC = true; // return cells state at last time step
const double cellClip = 0; // do not apply clipping
NDArray x('c', {bS, sL, nIn}, nd4j::DataType::FLOAT32);
NDArray Wx('c', {2,nIn, 4*nOut}, nd4j::DataType::FLOAT32);
NDArray Wr('c', {2,nOut, 4*nOut}, nd4j::DataType::FLOAT32);
NDArray b('c', {2,4*nOut}, nd4j::DataType::FLOAT32);
NDArray hI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32);
NDArray cI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32);
x.linspace(0.5, 0.5);
Wx({0,1, 0,0, 0,0}) = 0.003;
Wx({1,2, 0,0, 0,0}) = -0.003;
Wr({0,1, 0,0, 0,0}) = 0.006;
Wr({1,2, 0,0, 0,0}) = -0.006;
b({0,1, 0,0}) = 0.5;
b({1,2, 0,0}) = -0.5;
hI({0,1, 0,0, 0,0}) = 1;
hI({1,2, 0,0, 0,0}) = -1;
cI({0,1, 0,0, 0,0}) = 2;
cI({1,2, 0,0, 0,0}) = -2;
std::initializer_list<double> tArgs = {cellClip};
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
NDArray expH('c', {bS, sL, 2*nOut}, {0.577661, 0.577661, 0.577661, -0.107659, -0.107659, -0.107659, 0.548099, 0.548099, 0.548099, -0.113406, -0.113406, -0.113406,
0.526881, 0.526881, 0.526881, -0.12883 , -0.12883 , -0.12883 , 0.515882, 0.515882, 0.515882, -0.16868 , -0.16868 , -0.16868 ,
0.51409 , 0.51409 , 0.51409 , -0.255185, -0.255185, -0.255185, 0.614599, 0.614599, 0.614599, -0.102739, -0.102739, -0.102739,
0.599572, 0.599572, 0.599572, -0.105802, -0.105802, -0.105802,0.591089, 0.591089, 0.591089, -0.116681, -0.116681, -0.116681,
0.588694, 0.588694, 0.588694, -0.149201, -0.149201, -0.149201,0.591492, 0.591492, 0.591492, -0.228917, -0.228917, -0.228917}, nd4j::DataType::FLOAT32);
NDArray expHL('c', {2,bS, nOut}, {0.51409 , 0.51409 , 0.51409 , 0.591492, 0.591492, 0.591492,
-0.107659, -0.107659, -0.107659, -0.102739, -0.102739, -0.102739}, nd4j::DataType::FLOAT32);
NDArray expCL('c', {2,bS, nOut}, {1.07293 , 1.07293 , 1.07293,1.346609, 1.346609, 1.346609,
-0.295811, -0.295811, -0.295811,-0.305394, -0.305394, -0.305394}, nd4j::DataType::FLOAT32);
nd4j::ops::lstmLayer op;
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto h = results->at(0);
auto hL = results->at(1);
auto cL = results->at(2);
// h->printBuffer();
// hL->printBuffer();
// cL->printBuffer();
ASSERT_TRUE(expH.isSameShape(h));
ASSERT_TRUE(expH.equalsTo(h));
ASSERT_TRUE(expHL.isSameShape(hL));
ASSERT_TRUE(expHL.equalsTo(hL));
ASSERT_TRUE(expCL.isSameShape(cL));
ASSERT_TRUE(expCL.equalsTo(cL));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_6) {
const int sL = 5;
const int bS = 2;
const int nIn = 4;
const int nOut = 3;
// input arguments
const int dataFormat = 0; // [sL,bS,nIn]
const int directionMode = 2; // bidirectional sum
const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
const int cellAct = 0; // tanh activation for cell state
const int outAct = 0; // tanh activation for output
const bool hasBiases = true; // biases array is provided
const bool hasSeqLen = false; // seqLen array is not provided
const auto hasInitH = true; // initial output is provided
const auto hasInitC = true; // initial cell state is provided
const auto hasPH = false; // peephole connections are absent
const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut]
const auto retLastH = true; // do not return output at last time step
const auto retLastC = true; // return cells state at last time step
const double cellClip = 0; // do not apply clipping
NDArray x('c', {sL, bS, nIn}, nd4j::DataType::FLOAT32);
NDArray Wx('c', {2,nIn, 4*nOut}, nd4j::DataType::FLOAT32);
NDArray Wr('c', {2,nOut, 4*nOut}, nd4j::DataType::FLOAT32);
NDArray b('c', {2,4*nOut}, nd4j::DataType::FLOAT32);
NDArray hI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32);
NDArray cI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32);
x.linspace(0.5, 0.5);
Wx({0,1, 0,0, 0,0}) = 0.003;
Wx({1,2, 0,0, 0,0}) = -0.003;
Wr({0,1, 0,0, 0,0}) = 0.006;
Wr({1,2, 0,0, 0,0}) = -0.006;
b({0,1, 0,0}) = 0.5;
b({1,2, 0,0}) = -0.5;
hI({0,1, 0,0, 0,0}) = 1;
hI({1,2, 0,0, 0,0}) = -1;
cI({0,1, 0,0, 0,0}) = 2;
cI({1,2, 0,0, 0,0}) = -2;
std::initializer_list<double> tArgs = {cellClip};
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
NDArray expH('c', {sL, bS, nOut}, {0.470019, 0.470019, 0.470019, 0.478352, 0.478352, 0.478352, 0.444871, 0.444871, 0.444871, 0.457060,
0.457060, 0.457060, 0.424090, 0.424090, 0.424090, 0.439778, 0.439778, 0.439778, 0.394491, 0.394491,
0.394491, 0.412995, 0.412995, 0.412995, 0.329613, 0.329613, 0.329613, 0.349760, 0.349760, 0.349760}, nd4j::DataType::FLOAT32);
NDArray expHL('c', {2,bS, nOut}, {0.563741, 0.563741, 0.563741, 0.578676, 0.578676, 0.578676, -0.107642,
-0.107642, -0.107642, -0.106937, -0.106937, -0.106937}, nd4j::DataType::FLOAT32);
NDArray expCL('c', {2,bS, nOut}, {1.217757, 1.217757, 1.217757, 1.272398, 1.272398, 1.272398, -0.295768,
-0.295768, -0.295768, -0.298453, -0.298453, -0.298453}, nd4j::DataType::FLOAT32);
nd4j::ops::lstmLayer op;
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto h = results->at(0);
auto hL = results->at(1);
auto cL = results->at(2);
ASSERT_TRUE(expH.isSameShape(h));
ASSERT_TRUE(expH.equalsTo(h));
ASSERT_TRUE(expHL.isSameShape(hL));
ASSERT_TRUE(expHL.equalsTo(hL));
ASSERT_TRUE(expCL.isSameShape(cL));
ASSERT_TRUE(expCL.equalsTo(cL));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_7) {
#ifndef HAVE_MKLDNN
const int sL = 5;
const int bS = 2;
const int nIn = 4;
const int nOut = 3;
// input arguments
const int dataFormat = 0; // [sL,bS,nIn]
const int directionMode = 0; // forward
const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
const int cellAct = 0; // tanh activation for cell state
const int outAct = 0; // tanh activation for output
const bool hasBiases = true; // biases array is provided
const bool hasSeqLen = false; // seqLen array is not provided
const auto hasInitH = true; // initial output is provided
const auto hasInitC = true; // initial cell state is provided
const auto hasPH = true; // peephole connections are absent
const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut]
const auto retLastH = true; // do not return output at last time step
const auto retLastC = true; // return cells state at last time step
const double cellClip = 0; // do not apply clipping
NDArray x('c', {sL, bS, nIn}, nd4j::DataType::FLOAT32);
NDArray Wx('c', {nIn, 4*nOut}, nd4j::DataType::FLOAT32);
NDArray Wr('c', {nOut, 4*nOut}, nd4j::DataType::FLOAT32);
NDArray b('c', {4*nOut}, nd4j::DataType::FLOAT32);
NDArray hI('c', {bS, nOut}, nd4j::DataType::FLOAT32);
NDArray cI('c', {bS, nOut}, nd4j::DataType::FLOAT32);
NDArray Wp('c', {3*nOut}, nd4j::DataType::FLOAT32);
x.linspace(0.5, 0.5);
Wx = 0.003;
Wr = 0.006;
b = 0.5;
hI = 1.;
cI = 2.;
Wp = -0.05;
std::initializer_list<double> tArgs = {cellClip};
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
NDArray expH('c', {sL, bS, nOut}, {0.55533 , 0.55533 , 0.55533 , 0.562925, 0.562925, 0.562925, 0.531795, 0.531795, 0.531795, 0.542556,
0.542556, 0.542556, 0.521466, 0.521466, 0.521466, 0.534638, 0.534638, 0.534638, 0.524805, 0.524805,
0.524805, 0.539187, 0.539187, 0.539187, 0.538309, 0.538309, 0.538309, 0.552923, 0.552923, 0.552923}, nd4j::DataType::FLOAT32);
NDArray expHL('c', {bS, nOut}, {0.538309, 0.538309, 0.538309,0.552923, 0.552923, 0.552923}, nd4j::DataType::FLOAT32);
NDArray expCL('c', {bS, nOut}, {1.147089, 1.147089, 1.147089,1.197228, 1.197228, 1.197228}, nd4j::DataType::FLOAT32);
nd4j::ops::lstmLayer op;
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto h = results->at(0);
auto hL = results->at(1);
auto cL = results->at(2);
ASSERT_TRUE(expH.isSameShape(h));
ASSERT_TRUE(expH.equalsTo(h));
ASSERT_TRUE(expHL.isSameShape(hL));
ASSERT_TRUE(expHL.equalsTo(hL));
ASSERT_TRUE(expCL.isSameShape(cL));
ASSERT_TRUE(expCL.equalsTo(cL));
delete results;
#endif
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_8) {
#ifndef HAVE_MKLDNN
const int sL = 5;
const int bS = 2;
const int nIn = 4;
const int nOut = 3;
// input arguments
const int dataFormat = 0; // [sL,bS,nIn]
const int directionMode = 1; // backward
const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
const int cellAct = 0; // tanh activation for cell state
const int outAct = 0; // tanh activation for output
const bool hasBiases = true; // biases array is provided
const bool hasSeqLen = false; // seqLen array is not provided
const auto hasInitH = true; // initial output is provided
const auto hasInitC = true; // initial cell state is provided
const auto hasPH = true; // peephole connections are absent
const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut]
const auto retLastH = true; // do not return output at last time step
const auto retLastC = true; // return cells state at last time step
const double cellClip = 1.; // do not apply clipping
NDArray x('c', {sL, bS, nIn}, nd4j::DataType::FLOAT32);
NDArray Wx('c', {nIn, 4*nOut}, nd4j::DataType::FLOAT32);
NDArray Wr('c', {nOut, 4*nOut}, nd4j::DataType::FLOAT32);
NDArray b('c', {4*nOut}, nd4j::DataType::FLOAT32);
NDArray hI('c', {bS, nOut}, nd4j::DataType::FLOAT32);
NDArray cI('c', {bS, nOut}, nd4j::DataType::FLOAT32);
NDArray Wp('c', {3*nOut}, nd4j::DataType::FLOAT32);
x.linspace(0.5, 0.5);
Wx = 0.003;
Wr = 0.006;
b = 0.5;
hI = 1.;
cI = 2.;
Wp = -0.05;
std::initializer_list<double> tArgs = {cellClip};
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
NDArray expH('c', {sL, bS, nOut}, {0.436221, 0.436221, 0.436221,0.450573, 0.450573, 0.450573,0.463602, 0.463602, 0.463602, 0.474674, 0.474674, 0.474674,
0.484039, 0.484039, 0.484039,0.490679, 0.490679, 0.490679, 0.494871, 0.494871, 0.494871, 0.499028, 0.499028, 0.499028,
0.504649, 0.504649, 0.504649, 0.508719, 0.508719, 0.508719}, nd4j::DataType::FLOAT32);
NDArray expHL('c', {bS, nOut}, {0.436221, 0.436221, 0.436221, 0.450573, 0.450573, 0.450573}, nd4j::DataType::FLOAT32);
NDArray expCL('c', {bS, nOut}, {0.879804, 0.879804, 0.879804,0.914666, 0.914666, 0.914666}, nd4j::DataType::FLOAT32);
nd4j::ops::lstmLayer op;
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto h = results->at(0);
auto hL = results->at(1);
auto cL = results->at(2);
ASSERT_TRUE(expH.isSameShape(h));
ASSERT_TRUE(expH.equalsTo(h));
ASSERT_TRUE(expHL.isSameShape(hL));
ASSERT_TRUE(expHL.equalsTo(hL));
ASSERT_TRUE(expCL.isSameShape(cL));
ASSERT_TRUE(expCL.equalsTo(cL));
delete results;
#endif
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_9) {
#ifndef HAVE_MKLDNN
const int sL = 5;
const int bS = 2;
const int nIn = 4;
const int nOut = 3;
// input arguments
const int dataFormat = 0; // [sL,bS,nIn]
const int directionMode = 3; // bidirectional concat
const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
const int cellAct = 0; // tanh activation for cell state
const int outAct = 0; // tanh activation for output
const bool hasBiases = true; // biases array is provided
const bool hasSeqLen = false; // seqLen array is not provided
const auto hasInitH = true; // initial output is provided
const auto hasInitC = true; // initial cell state is provided
const auto hasPH = true; // peephole connections are absent
const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut]
const auto retLastH = true; // do not return output at last time step
const auto retLastC = true; // return cells state at last time step
const double cellClip = 0; // do not apply clipping
NDArray x('c', {sL, bS, nIn}, nd4j::DataType::FLOAT32);
NDArray Wx('c', {2,nIn, 4*nOut}, nd4j::DataType::FLOAT32);
NDArray Wr('c', {2,nOut, 4*nOut}, nd4j::DataType::FLOAT32);
NDArray b('c', {2,4*nOut}, nd4j::DataType::FLOAT32);
NDArray hI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32);
NDArray cI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32);
NDArray Wp('c', {2,3*nOut}, nd4j::DataType::FLOAT32);
x.linspace(0.5, 0.5);
Wx({0,1, 0,0, 0,0}) = 0.003;
Wx({1,2, 0,0, 0,0}) = -0.003;
Wr({0,1, 0,0, 0,0}) = 0.006;
Wr({1,2, 0,0, 0,0}) = -0.006;
b({0,1, 0,0}) = 0.5;
b({1,2, 0,0}) = -0.5;
hI({0,1, 0,0, 0,0}) = 1;
hI({1,2, 0,0, 0,0}) = -1;
cI({0,1, 0,0, 0,0}) = 2;
cI({1,2, 0,0, 0,0}) = -2;
Wp({0,1, 0,0}) = -0.05;
Wp({1,2, 0,0}) = 0.05;
std::initializer_list<double> tArgs = {cellClip};
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
NDArray expH('c', {sL, bS, 2*nOut}, { 0.55533 , 0.55533 , 0.55533 , -0.104502, -0.104502, -0.104502, 0.562925, 0.562925, 0.562925, -0.103843, -0.103843, -0.103843,
0.531795, 0.531795, 0.531795, -0.107456, -0.107456, -0.107456,0.542556, 0.542556, 0.542556, -0.106139, -0.106139, -0.106139,
0.521466, 0.521466, 0.521466, -0.11681 , -0.11681 , -0.11681 , 0.534638, 0.534638, 0.534638, -0.11458 , -0.11458 , -0.11458 ,
0.524805, 0.524805, 0.524805, -0.145177, -0.145177, -0.145177,0.539187, 0.539187, 0.539187, -0.14157 , -0.14157 , -0.14157 ,
0.538309, 0.538309, 0.538309, -0.218056, -0.218056, -0.218056,0.552923, 0.552923, 0.552923, -0.213068, -0.213068, -0.213068}, nd4j::DataType::FLOAT32);
NDArray expHL('c', {2,bS, nOut}, {0.538309, 0.538309, 0.538309, 0.552923, 0.552923, 0.552923, -0.104502, -0.104502, -0.104502,
-0.103843, -0.103843, -0.103843}, nd4j::DataType::FLOAT32);
NDArray expCL('c', {2,bS, nOut}, {1.147089, 1.147089, 1.147089, 1.197228, 1.197228, 1.197228, -0.289425, -0.289425, -0.289425,
-0.292174, -0.292174, -0.292174}, nd4j::DataType::FLOAT32);
nd4j::ops::lstmLayer op;
auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto h = results->at(0);
auto hL = results->at(1);
auto cL = results->at(2);
ASSERT_TRUE(expH.isSameShape(h));
ASSERT_TRUE(expH.equalsTo(h));
ASSERT_TRUE(expHL.isSameShape(hL));
ASSERT_TRUE(expHL.equalsTo(hL));
ASSERT_TRUE(expCL.isSameShape(cL));
ASSERT_TRUE(expCL.equalsTo(cL));
delete results;
#endif
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_10) {
#ifndef HAVE_MKLDNN
const int sL = 6;
const int bS = 5;
const int nIn = 4;
const int nOut = 3;
// input arguments
const int dataFormat = 0; // [sL,bS,nIn]
const int directionMode = 0; // forward
const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
const int cellAct = 0; // tanh activation for cell state
const int outAct = 0; // tanh activation for output
const bool hasBiases = true; // biases array is provided
const bool hasSeqLen = true; // seqLen array is not provided
const auto hasInitH = true; // initial output is provided
const auto hasInitC = true; // initial cell state is provided
const auto hasPH = true; // peephole connections are absent
const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut]
const auto retLastH = true; // do not return output at last time step
const auto retLastC = true; // return cells state at last time step
const double cellClip = 0; // do not apply clipping
NDArray x('c', {sL, bS, nIn}, nd4j::DataType::FLOAT32);
NDArray Wx('c', {nIn, 4*nOut}, nd4j::DataType::FLOAT32);
NDArray Wr('c', {nOut, 4*nOut}, nd4j::DataType::FLOAT32);
NDArray b('c', {4*nOut}, nd4j::DataType::FLOAT32);
NDArray hI('c', {bS, nOut}, nd4j::DataType::FLOAT32);
NDArray cI('c', {bS, nOut}, nd4j::DataType::FLOAT32);
NDArray seqLen('c', {bS}, {0,1,2,3,5}, nd4j::DataType::FLOAT32);
NDArray Wp('c', {3*nOut}, nd4j::DataType::FLOAT32);
x.linspace(0.5, 0.5);
Wx = 0.003;
Wr = 0.006;
b = 0.5;
hI = 1.;
cI = 2.;
Wp = -0.05;
std::initializer_list<double> tArgs = {cellClip};
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
NDArray expH('c', {sL, bS, nOut}, {0., 0., 0., 0.562925, 0.562925, 0.562925, 0.570404, 0.570404, 0.570404, 0.57777 , 0.57777 , 0.57777 , 0.585023, 0.585023, 0.585023,
0., 0., 0., 0., 0., 0., 0.576568, 0.576568, 0.576568, 0.586163, 0.586163, 0.586163, 0.595462, 0.595462, 0.595462, 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0.611224, 0.611224, 0.611224, 0.621298, 0.621298, 0.621298, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0.655858, 0.655858, 0.655858, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.692315, 0.692315, 0.692315, 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0.}, nd4j::DataType::FLOAT32);
NDArray expHL('c', {bS, nOut}, {0., 0., 0., 0.562925, 0.562925, 0.562925, 0.576568, 0.576568, 0.576568, 0.611224, 0.611224, 0.611224, 0.692315, 0.692315, 0.692315}, nd4j::DataType::FLOAT32);
NDArray expCL('c', {bS, nOut}, {0., 0., 0., 1.534275, 1.534275, 1.534275, 1.40183, 1.40183, 1.40183, 1.449675, 1.449675, 1.449675, 1.767702, 1.767702, 1.767702}, nd4j::DataType::FLOAT32);
nd4j::ops::lstmLayer op;
auto results = op.execute({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto h = results->at(0);
auto hL = results->at(1);
auto cL = results->at(2);
ASSERT_TRUE(expH.isSameShape(h));
ASSERT_TRUE(expH.equalsTo(h));
ASSERT_TRUE(expHL.isSameShape(hL));
ASSERT_TRUE(expHL.equalsTo(hL));
ASSERT_TRUE(expCL.isSameShape(cL));
ASSERT_TRUE(expCL.equalsTo(cL));
delete results;
#endif
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_11) {
#ifndef HAVE_MKLDNN
const int sL = 6;
const int bS = 5;
const int nIn = 4;
const int nOut = 3;
// input arguments
const int dataFormat = 0; // [sL,bS,nIn]
const int directionMode = 1; // backward
const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
const int cellAct = 0; // tanh activation for cell state
const int outAct = 0; // tanh activation for output
const bool hasBiases = true; // biases array is provided
const bool hasSeqLen = true; // seqLen array is not provided
const auto hasInitH = true; // initial output is provided
const auto hasInitC = true; // initial cell state is provided
const auto hasPH = true; // peephole connections are absent
const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut]
const auto retLastH = true; // do not return output at last time step
const auto retLastC = true; // return cells state at last time step
const double cellClip = 0; // do not apply clipping
NDArray x('c', {sL, bS, nIn}, nd4j::DataType::FLOAT32);
NDArray Wx('c', {nIn, 4*nOut}, nd4j::DataType::FLOAT32);
NDArray Wr('c', {nOut, 4*nOut}, nd4j::DataType::FLOAT32);
NDArray b('c', {4*nOut}, nd4j::DataType::FLOAT32);
NDArray hI('c', {bS, nOut}, nd4j::DataType::FLOAT32);
NDArray cI('c', {bS, nOut}, nd4j::DataType::FLOAT32);
NDArray seqLen('c', {bS}, {0,1,2,3,5}, nd4j::DataType::FLOAT32);
NDArray Wp('c', {3*nOut}, nd4j::DataType::FLOAT32);
x.linspace(0.5, 0.5);
Wx = 0.003;
Wr = 0.006;
b = 0.5;
hI = 1.;
cI = 2.;
Wp = -0.05;
std::initializer_list<double> tArgs = {cellClip};
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
NDArray expH('c', {sL, bS, nOut}, {0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.61209,
0.61209, 0.61209,0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.652042, 0.652042, 0.652042, 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0.677708, 0.677708, 0.677708, 0.684177, 0.684177, 0.684177, 0., 0., 0.,0., 0., 0.,0.699627, 0.699627,
0.699627,0.705371, 0.705371, 0.705371,0.710989, 0.710989, 0.710989, 0., 0., 0., 0.719014, 0.719014, 0.719014, 0.724087,
0.724087, 0.724087, 0.729084, 0.729084, 0.729084, 0.734004, 0.734004, 0.734004 }, nd4j::DataType::FLOAT32);
NDArray expHL('c', {bS, nOut}, {0., 0., 0., 0.719014, 0.719014, 0.719014, 0.699627, 0.699627, 0.699627, 0.677708, 0.677708, 0.677708, 0.61209, 0.61209, 0.61209}, nd4j::DataType::FLOAT32);
NDArray expCL('c', {bS, nOut}, {0., 0., 0., 2.092814, 2.092814, 2.092814, 2.08832, 2.08832, 2.08832, 2.009851, 2.009851, 2.009851, 1.646034, 1.646034, 1.646034}, nd4j::DataType::FLOAT32);
nd4j::ops::lstmLayer op;
auto results = op.execute({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto h = results->at(0);
auto hL = results->at(1);
auto cL = results->at(2);
ASSERT_TRUE(expH.isSameShape(h));
ASSERT_TRUE(expH.equalsTo(h));
ASSERT_TRUE(expHL.isSameShape(hL));
ASSERT_TRUE(expHL.equalsTo(hL));
ASSERT_TRUE(expCL.isSameShape(cL));
ASSERT_TRUE(expCL.equalsTo(cL));
delete results;
#endif
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_12) {
#ifndef HAVE_MKLDNN
const int sL = 6;
const int bS = 5;
const int nIn = 4;
const int nOut = 3;
// input arguments
const int dataFormat = 0; // [sL,bS,nIn]
const int directionMode = 3; // bidirectional concat
const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
const int cellAct = 0; // tanh activation for cell state
const int outAct = 0; // tanh activation for output
const bool hasBiases = true; // biases array is provided
const bool hasSeqLen = true; // seqLen array is not provided
const auto hasInitH = true; // initial output is provided
const auto hasInitC = true; // initial cell state is provided
const auto hasPH = true; // peephole connections are absent
const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut]
const auto retLastH = true; // do not return output at last time step
const auto retLastC = true; // return cells state at last time step
const double cellClip = 0; // do not apply clipping
NDArray x('c', {sL, bS, nIn}, nd4j::DataType::FLOAT32);
NDArray Wx('c', {2,nIn, 4*nOut}, nd4j::DataType::FLOAT32);
NDArray Wr('c', {2,nOut, 4*nOut}, nd4j::DataType::FLOAT32);
NDArray b('c', {2,4*nOut}, nd4j::DataType::FLOAT32);
NDArray hI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32);
NDArray cI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32);
NDArray seqLen('c', {bS}, {0,1,2,3,5}, nd4j::DataType::FLOAT32);
NDArray Wp('c', {2,3*nOut}, nd4j::DataType::FLOAT32);
x.linspace(0.5, 0.5);
Wx({0,1, 0,0, 0,0}) = 0.003;
Wx({1,2, 0,0, 0,0}) = -0.003;
Wr({0,1, 0,0, 0,0}) = 0.006;
Wr({1,2, 0,0, 0,0}) = -0.006;
b({0,1, 0,0}) = 0.5;
b({1,2, 0,0}) = -0.5;
hI({0,1, 0,0, 0,0}) = 1;
hI({1,2, 0,0, 0,0}) = -1;
cI({0,1, 0,0, 0,0}) = 2;
cI({1,2, 0,0, 0,0}) = -2;
Wp({0,1, 0,0}) = -0.05;
Wp({1,2, 0,0}) = 0.05;
std::initializer_list<double> tArgs = {cellClip};
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
NDArray expH('c', {sL, bS, 2*nOut}, {0., 0., 0., 0., 0., 0., 0.562925, 0.562925, 0.562925, -0.25361 , -0.25361 , -0.25361 , 0.570404, 0.570404, 0.570404, -0.157103,
-0.157103, -0.157103, 0.57777 , 0.57777 , 0.57777 , -0.116502, -0.116502, -0.116502,0.585023, 0.585023, 0.585023, -0.100025,
-0.100025, -0.100025, 0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0., 0.576568, 0.576568, 0.576568, -0.223072, -0.223072, -0.223072,
0.586163, 0.586163, 0.586163, -0.135714, -0.135714, -0.135714,0.595462, 0.595462, 0.595462, -0.094438, -0.094438, -0.094438,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.611224, 0.611224, 0.611224, -0.193473, -0.193473, -0.193473,
0.621298, 0.621298, 0.621298, -0.090626, -0.090626, -0.090626, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0.655858, 0.655858, 0.655858, -0.098015, -0.098015, -0.098015, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.692315, 0.692315, 0.692315, -0.143704, -0.143704, -0.143704, 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}, nd4j::DataType::FLOAT32);
NDArray expHL('c', {2,bS, nOut}, {0., 0., 0., 0.562925, 0.562925, 0.562925, 0.576568, 0.576568, 0.576568, 0.611224, 0.611224, 0.611224, 0.692315, 0.692315, 0.692315,
0., 0., 0., -0.25361 , -0.25361 , -0.25361 , -0.157103, -0.157103, -0.157103,-0.116502, -0.116502, -0.116502, -0.100025, -0.100025, -0.100025}, nd4j::DataType::FLOAT32);
NDArray expCL('c', {2,bS, nOut}, {0., 0., 0.,1.534275, 1.534275, 1.534275,1.40183 , 1.40183 , 1.40183 ,1.449675, 1.449675, 1.449675,1.767702, 1.767702, 1.767702,
0., 0., 0.,-0.86636 , -0.86636 , -0.86636 ,-0.470245, -0.470245, -0.470245,-0.341856, -0.341856, -0.341856,-0.294986, -0.294986, -0.294986}, nd4j::DataType::FLOAT32);
nd4j::ops::lstmLayer op;
auto results = op.execute({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto h = results->at(0);
auto hL = results->at(1);
auto cL = results->at(2);
ASSERT_TRUE(expH.isSameShape(h));
ASSERT_TRUE(expH.equalsTo(h));
ASSERT_TRUE(expHL.isSameShape(hL));
ASSERT_TRUE(expHL.equalsTo(hL));
ASSERT_TRUE(expCL.isSameShape(cL));
ASSERT_TRUE(expCL.equalsTo(cL));
delete results;
#endif
}

View File

@ -505,9 +505,9 @@ TEST_F(DeclarableOpsTests15, test_lstmBlock_1) {
}
TEST_F(DeclarableOpsTests15, test_lstmBlock_2) {
int seqLen = 32;
int bS = 64;
int nIn = 32;
int seqLen = 8;
int bS = 16;
int nIn = 8;
auto x0 = NDArrayFactory::create<Nd4jLong>(5);
auto x1 = NDArrayFactory::create<float>('f', {bS, nIn, seqLen});

View File

@ -28,6 +28,7 @@
#include <MmulHelper.h>
#include <GradCheck.h>
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/lstmLayer.h>
using namespace nd4j;
@ -2342,5 +2343,155 @@ TEST_F(HelpersTests1, softmaxDerivative_3) {
ASSERT_TRUE(expOutput.equalsTo(output));
}
///////////////////////////////////////////////////////////////////
TEST_F(HelpersTests1, lstmLayerCell_1) {
const int bS = 2;
const int nIn = 10;
const int nOut = 4;
const float dataFormat = 0; // is ignored in cell op
const float cellClip = 5; // clipping value
const float gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
const float gateAlpha = 0; // alpha value for activation for gates, not required for sigmoid
const float gateBeta = 0; // beta value for activation for gates, not required for sigmoid
const float cellAct = 0; // tanh activation for cell state
const float cellAlpha = 0; // alpha value for cell state activation, not required for tanh
const float cellBeta = 0; // beta value for cell state activation, not required for tanh
const float outAct = 0; // tanh activation for output
const float outAlpha = 0; // alpha value for output activation, not required for tanh
const float outBeta = 0; // beta value for output activation, not required for tanh
NDArray x ('c', {bS, nIn}, nd4j::DataType::FLOAT32);
NDArray Wx('c', {nIn, 4*nOut}, nd4j::DataType::FLOAT32);
NDArray Wr('c', {nOut, 4*nOut}, nd4j::DataType::FLOAT32);
NDArray b ('c', {4*nOut}, nd4j::DataType::FLOAT32);
NDArray hI('c', {bS, nOut}, nd4j::DataType::FLOAT32);
NDArray cI('c', {bS, nOut}, nd4j::DataType::FLOAT32);
NDArray Wp('c', {3*nOut}, nd4j::DataType::FLOAT32);
NDArray h('c', {bS, nOut}, nd4j::DataType::FLOAT32);
NDArray c('c', {bS, nOut}, nd4j::DataType::FLOAT32);
NDArray expH('c', {bS, nOut}, {0.999288, 0.999288, 0.999288, 0.999288, 0.999288, 0.999288, 0.999288, 0.999288}, nd4j::DataType::FLOAT32);
NDArray expC('c', {bS, nOut}, {3.999778, 3.999778, 3.999778, 3.999778, 3.999778, 3.999778, 3.999778, 3.999778}, nd4j::DataType::FLOAT32);
std::vector<float> params = {dataFormat, 0, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta};
x = 1.;
hI = 2.;
cI = 3.;
Wx = 0.5;
Wr = 0.4;
Wp = 0.3;
b = 0.7;
nd4j::ops::helpers::lstmLayerCell(&x, &Wx, &Wr, &b, &hI, &cI, &Wp, params, &h, &c);
ASSERT_TRUE(expH.isSameShape(h));
ASSERT_TRUE(expH.equalsTo(h));
ASSERT_TRUE(expC.isSameShape(c));
ASSERT_TRUE(expC.equalsTo(c));
}
///////////////////////////////////////////////////////////////////
TEST_F(HelpersTests1, lstmLayerCell_2) {
const int bS = 2;
const int nIn = 10;
const int nOut = 4;
const float dataFormat = 0; // is ignored in cell op
const float cellClip = 3; // clipping value
const float gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
const float gateAlpha = 0; // alpha value for activation for gates, not required for sigmoid
const float gateBeta = 0; // beta value for activation for gates, not required for sigmoid
const float cellAct = 0; // tanh activation for cell state
const float cellAlpha = 0; // alpha value for cell state activation, not required for tanh
const float cellBeta = 0; // beta value for cell state activation, not required for tanh
const float outAct = 0; // tanh activation for output
const float outAlpha = 0; // alpha value for output activation, not required for tanh
const float outBeta = 0; // beta value for output activation, not required for tanh
NDArray x ('c', {bS, nIn}, nd4j::DataType::FLOAT32);
NDArray Wx('c', {nIn, 4*nOut}, nd4j::DataType::FLOAT32);
NDArray Wr('c', {nOut, 4*nOut}, nd4j::DataType::FLOAT32);
NDArray b ('c', {4*nOut}, nd4j::DataType::FLOAT32);
NDArray hI('c', {bS, nOut}, nd4j::DataType::FLOAT32);
NDArray cI('c', {bS, nOut}, nd4j::DataType::FLOAT32);
NDArray Wp('c', {3*nOut}, nd4j::DataType::FLOAT32);
NDArray h('c', {bS, nOut}, nd4j::DataType::FLOAT32);
NDArray c('c', {bS, nOut}, nd4j::DataType::FLOAT32);
NDArray expH('c', {bS, nOut}, {0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995}, nd4j::DataType::FLOAT32);
NDArray expC('c', {bS, nOut}, {3., 3., 3., 3., 3., 3., 3., 3.}, nd4j::DataType::FLOAT32);
std::vector<float> params = {dataFormat, 0, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta};
x = 1.;
hI = 2.;
cI = 3.;
Wx = 0.5;
Wr = 0.4;
Wp = 0.3;
b = 0.7;
nd4j::ops::helpers::lstmLayerCell(&x, &Wx, &Wr, &b, &hI, &cI, &Wp, params, &h, &c);
ASSERT_TRUE(expH.isSameShape(h));
ASSERT_TRUE(expH.equalsTo(h));
ASSERT_TRUE(expC.isSameShape(c));
ASSERT_TRUE(expC.equalsTo(c));
}
///////////////////////////////////////////////////////////////////
TEST_F(HelpersTests1, lstmLayerCell_3) {
const int nIn = 10;
const int nOut = 4;
const float dataFormat = 0; // is ignored in cell op
const float cellClip = 5; // clipping value
const float gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
const float gateAlpha = 0; // alpha value for activation for gates, not required for sigmoid
const float gateBeta = 0; // beta value for activation for gates, not required for sigmoid
const float cellAct = 0; // tanh activation for cell state
const float cellAlpha = 0; // alpha value for cell state activation, not required for tanh
const float cellBeta = 0; // beta value for cell state activation, not required for tanh
const float outAct = 0; // tanh activation for output
const float outAlpha = 0; // alpha value for output activation, not required for tanh
const float outBeta = 0; // beta value for output activation, not required for tanh
NDArray x ('c', {nIn}, nd4j::DataType::FLOAT32);
NDArray Wx('c', {nIn, 4*nOut}, nd4j::DataType::FLOAT32);
NDArray Wr('c', {nOut, 4*nOut}, nd4j::DataType::FLOAT32);
NDArray b ('c', {4*nOut}, nd4j::DataType::FLOAT32);
NDArray hI('c', {nOut}, nd4j::DataType::FLOAT32);
NDArray cI('c', {nOut}, nd4j::DataType::FLOAT32);
NDArray Wp('c', {3*nOut}, nd4j::DataType::FLOAT32);
NDArray h('c', {nOut}, nd4j::DataType::FLOAT32);
NDArray c('c', {nOut}, nd4j::DataType::FLOAT32);
NDArray expH('c', {nOut}, {0.999288, 0.999288, 0.999288, 0.999288}, nd4j::DataType::FLOAT32);
NDArray expC('c', {nOut}, {3.999778, 3.999778, 3.999778, 3.999778}, nd4j::DataType::FLOAT32);
std::vector<float> params = {dataFormat, 0, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta};
x = 1.;
hI = 2.;
cI = 3.;
Wx = 0.5;
Wr = 0.4;
Wp = 0.3;
b = 0.7;
nd4j::ops::helpers::lstmLayerCell(&x, &Wx, &Wr, &b, &hI, &cI, &Wp, params, &h, &c);
ASSERT_TRUE(expH.isSameShape(h));
ASSERT_TRUE(expH.equalsTo(h));
ASSERT_TRUE(expC.isSameShape(c));
ASSERT_TRUE(expC.equalsTo(c));
}