From 70bd925abd102eeef60365e666337e1fe91bb7e0 Mon Sep 17 00:00:00 2001
From: Yurii
Date: Thu, 17 Oct 2019 20:44:52 +0300
Subject: [PATCH] Add two versions of the new lstmLayer op: one based on our own code, the other using the MKL-DNN API

---
 libnd4j/blas/NDArray.hpp                      |   9 +-
 libnd4j/include/helpers/impl/MmulHelper.cpp   |  15 +
 libnd4j/include/loops/legacy_ops.h            |   2 +
 .../generic/recurrent/lstmLayer.cpp           | 404 ++++++++
 .../ops/declarable/headers/recurrent.h        |   5 +
 .../ops/declarable/helpers/impl/lstmLayer.cpp | 460 +++++++++
 .../ops/declarable/helpers/lstmLayer.h        | 117 +++
 .../declarable/platform/mkldnn/lstmLayer.cpp  | 546 ++++++++++
 .../declarable/platform/mkldnn/mkldnnUtils.h  |   2 +
 libnd4j/include/ops/ops.h                     |  22 +
 .../layers_tests/DeclarableOpsTests13.cpp     | 947 ++++++++++++++++++
 .../layers_tests/DeclarableOpsTests15.cpp     |   6 +-
 .../tests_cpu/layers_tests/HelpersTests1.cpp  | 151 +++
 13 files changed, 2678 insertions(+), 8 deletions(-)
 create mode 100644 libnd4j/include/ops/declarable/generic/recurrent/lstmLayer.cpp
 create mode 100644 libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp
 create mode 100644 libnd4j/include/ops/declarable/helpers/lstmLayer.h
 create mode 100644 libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp

diff --git a/libnd4j/blas/NDArray.hpp b/libnd4j/blas/NDArray.hpp
index 917630bce..ba0c34f6c 100644
--- a/libnd4j/blas/NDArray.hpp
+++ b/libnd4j/blas/NDArray.hpp
@@ -4313,17 +4313,16 @@ Nd4jLong NDArray::getOffset(const Nd4jLong i) const {
     return shape::getIndexOffset(i, _shapeInfo);
 }
 
+////////////////////////////////////////////////////////////////////////
 NDArray NDArray::like() {
-    NDArray res(this->shapeInfo(), this->dataType(), false, this->getContext());
-    return res;
+    return NDArray(shapeInfo(), this->dataType(), false, getContext());
 }
 
+////////////////////////////////////////////////////////////////////////
 NDArray NDArray::ulike() {
-    // FIXME: it should be non-memset array
-    NDArray res(this->shapeInfo(), this->dataType(), false, this->getContext());
-    return res;
+    return NDArray(this, false, getContext());
 }
 
 ////////////////////////////////////////////////////////////////////////
diff --git a/libnd4j/include/helpers/impl/MmulHelper.cpp b/libnd4j/include/helpers/impl/MmulHelper.cpp
index ef84cc077..b50104bee 100644
--- a/libnd4j/include/helpers/impl/MmulHelper.cpp
+++ b/libnd4j/include/helpers/impl/MmulHelper.cpp
@@ -268,6 +268,21 @@ nd4j::NDArray* MmulHelper::mmul(const nd4j::NDArray* A, const nd4j::NDArray* B,
     if(aRank == 2 && isBVector)
         return mmulMxV(A, B, C, alpha, beta, outOrder);
 
+    // vector x matrix, A{M} x B{M,N} = C{N} -> reduce to matrix x matrix A2{1,M} x B{M,N} = C2{1,N}, since there is no corresponding blas operation sgevm
+    if(isAVector && bRank == 2) {
+        NDArray* A2 = new NDArray(A->reshape(A->ordering(), {1, A->lengthOf()}));    // A{M} -> A2{1,M}
+        NDArray* C2 = C ?
new NDArray(C->reshape(C->ordering(), {1, C->lengthOf()})) : nullptr; // C{N} -> C2{1,N} + auto result = mmulMxM(A2, B, C2, alpha, beta, outOrder); // result{1,N} + delete A2; + delete C2; + + if(!C) { + result->reshapei({result->lengthOf()}); // result{1,N} -> result{N} + return result; + } + return C; + } + // batched matrix multiplication return mmulNxN(A, B, C, alpha, beta, outOrder); } diff --git a/libnd4j/include/loops/legacy_ops.h b/libnd4j/include/loops/legacy_ops.h index 1fbd06f2b..4b1f3448f 100644 --- a/libnd4j/include/loops/legacy_ops.h +++ b/libnd4j/include/loops/legacy_ops.h @@ -119,6 +119,8 @@ #define TRANSFORM_STRICT_OPS \ + (2, ScaledTanh), \ + (3, Affine), \ (4, TanhDerivative), \ (5, HardTanhDerivative), \ (6, SigmoidDerivative), \ diff --git a/libnd4j/include/ops/declarable/generic/recurrent/lstmLayer.cpp b/libnd4j/include/ops/declarable/generic/recurrent/lstmLayer.cpp new file mode 100644 index 000000000..ed1e9e0f3 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/recurrent/lstmLayer.cpp @@ -0,0 +1,404 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// + +#include +#if NOT_EXCLUDED(OP_lstmLayer) + +#include +#include + +namespace nd4j { +namespace ops { + + +////////////////////////////////////////////////////////////////////////// +CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) { + + // equations (no peephole connections) + // it = σ(Wxi * xt + Wri * ht-1 + bi) + // ft = σ(Wxf * xt + Wrf * ht-1 + bf) + // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) + // ct = ft ◦ ct-1 + it ◦ c't + // ot = σ(Wxo * xt + Wro * ht-1 + bo) + // ht = ot ◦ tanh(ct) + + // equations (peephole connections are present) + // it = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi) + // ft = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf) + // c't = tanh(Wxc * xt + Wrc * ht-1 + bc) + // ct = ft ◦ ct-1 + it ◦ c't + // ot = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo) + // ht = ot ◦ tanh(ct) + + // notations: + // bS - batch size + // sL - sequence length, number of time steps + // nIn - input size + // nOut - output size (hidden size) + + // INPUTS: + + // ******* + // input x: + // 1) [sL, bS, nIn] when dataFormat == 0 + // 2) [bS, sL, nIn] when dataFormat == 1 + // 3) [bS, nIn, sL] when dataFormat == 2 + + // ******* + // input weights Wx: + // 1) [nIn, 4*nOut] when directionMode < 2 + // 2) [2, nIn, 4*nOut] when directionMode >= 2 + + // ******* + // recurrent weights Wr: + // 1) [nOut, 4*nOut] when directionMode < 2 + // 2) [2, nOut, 4*nOut] when directionMode >= 2 + + // ******* + // peephole weights Wp: + // 1) [3*nOut] when directionMode < 2 + // 2) [2, 3*nOut] when directionMode >= 2 + + // ******* + // biases b: + // 1) [4*nOut] when directionMode < 2 + // 2) [2, 4*nOut] when directionMode >= 2 + + // ******* + // sequence length array 
seqLen: + // 1) [bS] always + + // ******* + // initial output hI: + // 1) [bS, nOut] when directionMode < 2 + // 2) [2, bS, nOut] when directionMode >= 2 + + // ******* + // initial cell state cI (same shape as in hI): + // 1) [bS, nOut] when directionMode < 2 + // 2) [2, bS, nOut] when directionMode >= 2 + + + // OUTPUTS: + + // ******* + // output h: + // 1) [sL, bS, nOut] when directionMode <= 2 && dataFormat == 0 + // 2) [bS, sL, nOut] when directionMode <= 2 && dataFormat == 1 + // 3) [bS, nOut, sL] when directionMode <= 2 && dataFormat == 2 + // 4) [sL, bS, 2*nOut] when directionMode == 3 && dataFormat == 0 + // 5) [bS, sL, 2*nOut] when directionMode == 3 && dataFormat == 1 + // 6) [bS, 2*nOut, sL] when directionMode == 3 && dataFormat == 2 + // 7) [sL, 2, bS, nOut] when directionMode == 4 && dataFormat == 3 + + // ******* + // output at last step hL: + // 1) [bS, nOut] when directionMode < 2 + // 2) [2, bS, nOut] when directionMode >= 2 + + // ******* + // cell state at last step cL (same shape as in hL): + // 1) [bS, nOut] when directionMode < 2 + // 2) [2, bS, nOut] when directionMode >= 2 + + // !!! dimension 4*nOut implies order it, ft, c't, ot + // !!! dimension 3*nOut implies order it, ft, ot + + const auto dataFormat = INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, 2, bS, nOut] (for ONNX) + const auto directionMode = INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3) + + // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus + const auto gateAct = INT_ARG(2); // activation for input (i), forget (f) and output (o) gates + const auto cellAct = INT_ARG(3); // activation for cell state (c) + const auto outAct = INT_ARG(4); // activation for output (h) + + const auto hasBiases = B_ARG(0); // indicates whether biases array is provided + const auto hasSeqLen = B_ARG(1); // indicates whether seqLen array is provided + const auto hasInitH = B_ARG(2); // indicates whether initial output is provided + const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided + const auto hasPH = B_ARG(4); // indicates whether peephole connections are present + const auto retFullSeq = B_ARG(5); // indicates whether to return whole time sequence h {h_0, h_1, ... 
, h_sL-1} + const auto retLastH = B_ARG(6); // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) + const auto retLastC = B_ARG(7); // indicates whether to return cells state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) + + const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8; + const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8; + const auto outActHasAlpha = outAct == 3 || outAct == 4 || outAct == 5 || outAct == 6 || outAct == 8; + const auto gateActHasBeta = gateAct == 3 || gateAct == 6; + const auto cellActHasBeta = cellAct == 3 || cellAct == 6; + const auto outActHasBeta = outAct == 3 || outAct == 6; + + uint count = 1; + const auto cellClip = T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping + const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0; + const auto gateBeta = gateActHasBeta ? T_ARG(count++) : 0; + const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0; + const auto cellBeta = cellActHasBeta ? T_ARG(count++) : 0; + const auto outAlpha = outActHasAlpha ? T_ARG(count++) : 0; + const auto outBeta = outActHasBeta ? T_ARG(count++) : 0; + + const auto x = INPUT_VARIABLE(0); // input + const auto Wx = INPUT_VARIABLE(1); // input weights + const auto Wr = INPUT_VARIABLE(2); // recurrent weights + + count = 3; + const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases + const auto seqLen = hasSeqLen ? INPUT_VARIABLE(count++) : nullptr; // seqLen vector + const auto hI = hasInitH ? INPUT_VARIABLE(count++) : nullptr; // initial output + const auto cI = hasInitC ? INPUT_VARIABLE(count++) : nullptr; // initial cell state + const auto Wp = hasPH ? INPUT_VARIABLE(count++) : nullptr; // peephole weights + + REQUIRE_TRUE(dataFormat < 3 || (dataFormat == 3 && directionMode == 4), 0, "LSTM_LAYER operation: if argument dataFormat = 3, then directionMode = 4, but got dataFormat = %i and directionMode = %i instead !", dataFormat, directionMode); + REQUIRE_TRUE(cellClip >= 0 , 0, "LSTM_LAYER operation: cell clipping value should be nonnegative (>=0) !"); + REQUIRE_TRUE(retFullSeq || retLastH || retLastC, 0, "LSTM_LAYER operation: please specify what output arrays to produce !"); + + count = 0; + auto h = retFullSeq ? OUTPUT_VARIABLE(count++) : nullptr; // output + auto hL = retLastH ? OUTPUT_VARIABLE(count++) : nullptr; // output at last step + auto cL = retLastC ? OUTPUT_VARIABLE(count++) : nullptr; // cell state at last step + + // evaluate dimensions + const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat); + const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(-2); + const Nd4jLong nIn = dataFormat == 2 ? 
x->sizeAt(1) : x->sizeAt(-1); + const Nd4jLong nOut = Wx->sizeAt(-1) / 4; + + // inputs validations + if(directionMode < 2) { // no bidirectional + + // Wx validation + if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn) + REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx)); + // Wr validation + if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut) + REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr)); + // biases validation + if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b)); + // initial output validation + if(hI != nullptr && (hI->rankOf() != 2 || hI->sizeAt(0) != bS || hI->sizeAt(1) != nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI)); + // initial cell validation + if(cI != nullptr && (cI->rankOf() != 2 || cI->sizeAt(0) != bS || cI->sizeAt(1) != nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI)); + // peephole weights validation + if(Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3*nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp)); + } + else { // bidirectional + // Wx validation + if(Wx->rankOf() != 3 || Wx->sizeAt(0) != 2 || Wx->sizeAt(1) != nIn) + REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx)); + // Wr validation + if(Wr->rankOf() != 3 || Wr->sizeAt(0) != 2 || Wr->sizeAt(1) != nOut || Wr->sizeAt(2) != 4*nOut) + REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr)); + // biases validation + if(b != nullptr && (b->rankOf() != 2 || b->sizeAt(0) != 2 || b->sizeAt(1) != 4*nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 4*nOut}).c_str(), ShapeUtils::shapeAsString(b)); + // initial output validation + if(hI != nullptr && (hI->rankOf() != 3 || hI->sizeAt(0) != 2 || hI->sizeAt(1) != bS || hI->sizeAt(2) != nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI)); + // initial cell validation + if(cI != nullptr && (cI->rankOf() != 3 || cI->sizeAt(0) != 2 || cI->sizeAt(1) != bS || cI->sizeAt(2) != nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI)); 
+ // peephole weights validation + if(Wp != nullptr && (Wp->rankOf() != 2 || Wp->sizeAt(0) != 2 || Wp->sizeAt(1) != 3*nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp)); + } + + std::vector params = {static_cast(dataFormat), static_cast(directionMode), static_cast(cellClip), + static_cast(gateAct), static_cast(gateAlpha), static_cast(gateBeta), + static_cast(cellAct), static_cast(cellAlpha), static_cast(cellBeta), + static_cast(outAct), static_cast(outAlpha), static_cast(outBeta)}; + + if(directionMode == 0) { // forward + + helpers::lstmLayerTimeLoop(x, Wx, Wr, b, seqLen, hI, cI, Wp, params, true, h, hL, cL); + } + else if(directionMode == 1) { // backward + + helpers::lstmLayerTimeLoop(x, Wx, Wr, b, seqLen, hI, cI, Wp, params, false, h, hL, cL); + } + else { // bidirectional + + NDArray WxFwd = (*Wx)({0,1, 0,0, 0,0}); + NDArray WxBwd = (*Wx)({1,2, 0,0, 0,0}); + NDArray WrFwd = (*Wr)({0,1, 0,0, 0,0}); + NDArray WrBwd = (*Wr)({1,2, 0,0, 0,0}); + + NDArray *WpFwd(nullptr), *WpBwd(nullptr), *bFwd(nullptr), *bBwd(nullptr), *hIFwd(nullptr), *hIBwd(nullptr), *cIFwd(nullptr), *cIBwd(nullptr), + *hLFwd(nullptr), *hLBwd(nullptr), *cLFwd(nullptr), *cLBwd(nullptr), *hFwd(nullptr), *hBwd(nullptr); + + if(Wp) { + WpFwd = new NDArray((*Wp)({0,1, 0,0})); + WpBwd = new NDArray((*Wp)({1,2, 0,0})); + } + if(b) { + bFwd = new NDArray((*b)({0,1, 0,0})); + bBwd = new NDArray((*b)({1,2, 0,0})); + } + if(hI) { + hIFwd = new NDArray((*hI)({0,1, 0,0, 0,0})); + hIBwd = new NDArray((*hI)({1,2, 0,0, 0,0})); + } + if(cI) { + cIFwd = new NDArray((*cI)({0,1, 0,0, 0,0})); + cIBwd = new NDArray((*cI)({1,2, 0,0, 0,0})); + } + if(hL) { + hLFwd = new NDArray((*hL)({0,1, 0,0, 0,0})); + hLBwd = new NDArray((*hL)({1,2, 0,0, 0,0})); + } + if(cL) { + cLFwd = new NDArray((*cL)({0,1, 0,0, 0,0})); + cLBwd = new NDArray((*cL)({1,2, 0,0, 0,0})); + } + + if(h) { + if(directionMode == 2) { // sum + hFwd = h; + hBwd = new NDArray(h, false, h->getContext()); + } + else if(directionMode == 3) { // concat + hFwd = new NDArray(dataFormat <= 1 ? (*h)({0,0, 0,0, 0,nOut}) : (*h)({0,0, 0,nOut, 0,0})); + hBwd = new NDArray(dataFormat <= 1 ? 
(*h)({0,0, 0,0, nOut,2*nOut}) : (*h)({0,0, nOut,2*nOut, 0,0})); + } + else { // directionMode == 4 + hFwd = new NDArray((*h)({0,0, 0,1, 0,0, 0,0})); + hBwd = new NDArray((*h)({0,0, 1,2, 0,0, 0,0})); + } + } + + // FIXME - following two calls are independent and may run in different streams + helpers::lstmLayerTimeLoop(x, &WxFwd, &WrFwd, bFwd, seqLen, hIFwd, cIFwd, WpFwd, params, true, hFwd, hLFwd, cLFwd); + helpers::lstmLayerTimeLoop(x, &WxBwd, &WrBwd, bBwd, seqLen, hIBwd, cIBwd, WpBwd, params, false, hBwd, hLBwd, cLBwd); + + if(h && directionMode == 2) + *h += *hBwd; + + delete WpFwd; delete WpBwd; delete bFwd; delete bBwd; delete hIFwd; delete hIBwd; delete cIFwd; + delete cIBwd; delete hLFwd; delete hLBwd; delete cLFwd; delete cLBwd; delete hBwd; + if(hFwd != h) + delete hFwd; + } + + return Status::OK(); +} + +DECLARE_TYPES(lstmLayer) { + getOpDescriptor() + ->setAllowedInputTypes(nd4j::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + + +DECLARE_SHAPE_FN(lstmLayer) { + + const auto dataFormat = INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, 2, bS, nIn] (for ONNX) + const auto directionMode = INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim + + const auto retFullSeq = B_ARG(5); // indicates whether to return whole h {h_0, h_1, ... , h_sL-1}, if true, format would be [sL,bS,nOut] (exact shape depends on dataFormat argument) + const auto retLastH = B_ARG(6); // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) + const auto retLastC = B_ARG(7); // indicates whether to return cells state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) + + const auto x = INPUT_VARIABLE(0); // input + const auto Wx = INPUT_VARIABLE(1); // input weights + const auto Wr = INPUT_VARIABLE(2); // recurrent weights + + // evaluate dimensions + const Nd4jLong sL = dataFormat == 0 || dataFormat == 3 ? x->sizeAt(0) : ( dataFormat == 1 ? x->sizeAt(1) : x->sizeAt(2) ); + const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(-2); + const Nd4jLong nIn = dataFormat == 2 ? 
x->sizeAt(1) : x->sizeAt(-1);
+    const Nd4jLong nOut = Wx->sizeAt(-1) / 4;
+
+    DataType type;
+    if(x->isR())
+        type = x->dataType();
+    else
+        type = nd4j::DataType::FLOAT32;
+
+    std::vector<Nd4jLong*> shapes;
+
+    // evaluate h shape (output)
+    if(retFullSeq) {
+
+        std::vector<Nd4jLong> hShape;
+
+        if(directionMode <= 2) {        // single direction or bidirectional with sum
+            if(dataFormat == 0)
+                hShape = {sL, bS, nOut};
+            else if(dataFormat == 1)
+                hShape = {bS, sL, nOut};
+            else if(dataFormat == 2)
+                hShape = {bS, nOut, sL};
+        }
+        else if(directionMode == 3) {   // bidirectional with concat
+            if(dataFormat == 0)
+                hShape = {sL, bS, 2*nOut};
+            else if(dataFormat == 1)
+                hShape = {bS, sL, 2*nOut};
+            else if(dataFormat == 2)
+                hShape = {bS, 2*nOut, sL};
+        }
+        else {                          // bidirectional with extra output dimension equal to 2
+            hShape = {sL, 2, bS, nOut};
+        }
+
+        shapes.push_back(ConstantShapeHelper::getInstance()->createShapeInfo(type, x->ordering(), hShape));
+    }
+
+    // evaluate hL shape (output at last step)
+    if(retLastH) {
+
+        std::vector<Nd4jLong> hLShape;
+
+        if(directionMode < 2)
+            hLShape = {bS, nOut};
+        else
+            hLShape = {2, bS, nOut};
+
+        shapes.push_back(ConstantShapeHelper::getInstance()->createShapeInfo(type, x->ordering(), hLShape));
+
+        if(retLastC)            // cL and hL have same shapes
+            shapes.push_back(shapes.back());
+    }
+
+    // evaluate cL shape (cell state at last step)
+    if(retLastC && !retLastH) {
+
+        std::vector<Nd4jLong> cLShape;
+
+        if(directionMode < 2)
+            cLShape = {bS, nOut};
+        else
+            cLShape = {2, bS, nOut};
+
+        shapes.push_back(ConstantShapeHelper::getInstance()->createShapeInfo(type, x->ordering(), cLShape));
+    }
+
+    return new ShapeList(shapes);
+}
+
+
+}
+}
+
+#endif
\ No newline at end of file
diff --git a/libnd4j/include/ops/declarable/headers/recurrent.h b/libnd4j/include/ops/declarable/headers/recurrent.h
index 4b2eddc57..a17db7ec6 100644
--- a/libnd4j/include/ops/declarable/headers/recurrent.h
+++ b/libnd4j/include/ops/declarable/headers/recurrent.h
@@ -231,6 +231,11 @@ namespace ops {
         DECLARE_CUSTOM_OP(lstmBlock, 9, 7, false, 2, 2);
     #endif
 
+    //////////////////////////////////////////////////////////////////////////
+    #if NOT_EXCLUDED(OP_lstmLayer)
+        DECLARE_CUSTOM_OP(lstmLayer, 3, 1, false, 1, 5);
+    #endif
+
     //////////////////////////////////////////////////////////////////////////
     /**
     * Implementation of operations for Simple Recurrent Unit cell: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi
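For orientation, the freshly registered op is exercised by the tests added in this patch (DeclarableOpsTests13.cpp in the diffstat). A minimal usage sketch, assuming the test-style DeclarableOp::execute overload that takes inputs, T-args, I-args and B-args; shapes follow the doc block of the op above:

    #include <NDArrayFactory.h>
    #include <ops/declarable/CustomOperations.h>

    void lstmLayerUsageSketch() {
        const int sL = 5, bS = 2, nIn = 3, nOut = 4;

        auto x  = NDArrayFactory::create<float>('c', {sL, bS, nIn});   // dataFormat = 0 -> [sL, bS, nIn]
        auto Wx = NDArrayFactory::create<float>('c', {nIn,  4*nOut});
        auto Wr = NDArrayFactory::create<float>('c', {nOut, 4*nOut});
        auto b  = NDArrayFactory::create<float>('c', {4*nOut});
        x.assign(0.5f); Wx.assign(0.1f); Wr.assign(0.1f); b.assign(0.2f);

        nd4j::ops::lstmLayer op;
        // I-args: dataFormat=0 (TNS), directionMode=0 (fwd), gateAct=2 (sigmoid), cellAct=0 (tanh), outAct=0 (tanh)
        // B-args: hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC
        auto results = op.execute({&x, &Wx, &Wr, &b}, {0.}, {0, 0, 2, 0, 0},
                                  {true, false, false, false, false, true, true, true});
        auto h  = results->at(0);   // [sL, bS, nOut]
        auto hL = results->at(1);   // [bS, nOut]
        auto cL = results->at(2);   // [bS, nOut]
        delete results;
    }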
diff --git a/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp b/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp
new file mode 100644
index 000000000..528642bb6
--- /dev/null
+++ b/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp
@@ -0,0 +1,460 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author Yurii Shyrma (iuriish@yahoo.com)
+//
+
+// implementation of operation for LSTM cell with peephole connections:
+// http://www.bioinf.jku.at/publications/older/2604.pdf
+// S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.
+// and
+// https://research.google.com/pubs/archive/43905.pdf
+// Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent neural network architectures for large scale acoustic modeling." INTERSPEECH, 2014.
+
+
+#include
+#include
+// #include
+// #include
+// #include
+// #include
+// #include
+// #include
+// #include
+
+namespace nd4j {
+namespace ops {
+namespace helpers {
+
+
+//////////////////////////////////////////////////////////////////////////
+void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
+                   const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
+                   const std::vector<float>& params,
+                   NDArray* h, NDArray* c) {
+
+
+    /************************ THIS IS NOT OPTIMIZED CODE ***********************************/
+    /** the objective is to provide math-readable code **/
+
+    // equations (no peephole connections)
+    // it  = σ(Wxi * xt + Wri * ht-1 + bi)
+    // ft  = σ(Wxf * xt + Wrf * ht-1 + bf)
+    // c't = tanh(Wxc * xt + Wrc * ht-1 + bc)
+    // ct  = ft ◦ ct-1 + it ◦ c't
+    // ot  = σ(Wxo * xt + Wro * ht-1 + bo)
+    // ht  = ot ◦ tanh(ct)
+
+    // equations (peephole connections are present)
+    // it  = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi)
+    // ft  = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf)
+    // c't = tanh(Wxc * xt + Wrc * ht-1 + bc)
+    // ct  = ft ◦ ct-1 + it ◦ c't
+    // ot  = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo)
+    // ht  = ot ◦ tanh(ct)
+
+
+    // IDs for activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5=thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
+
+    // params[0] - dataFormat, ignore
+    // params[1] - directionMode, ignore
+    // params[2] - cell clipping value, if it = 0 then do not apply clipping
+
+    // params[3]  - activation ID for input (i), forget (f) and output (o) gates
+    // params[4]  - alpha value for gates activation
+    // params[5]  - beta value for gates activation
+
+    // params[6]  - activation ID for cell state (c)
+    // params[7]  - alpha value for cell state activation
+    // params[8]  - beta value for cell state activation
+
+    // params[9]  - activation ID for output (h)
+    // params[10] - alpha value for output activation
+    // params[11] - beta value for output activation
+
+    // INPUTS:
+    // x  - current input at time t, [bS, nIn] or [nIn] if seqLen != nullptr
+    // Wx - input weights [nIn, 4*nOut]
+    // Wr - recurrent weights [nOut, 4*nOut]
+    // b  - biases [4*nOut], optional, may be nullptr
+    // hI - previous (initial) output at time t-1, optional, may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr
+    // cI - previous (initial) cell state at time t-1, optional, may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr
+    // Wp - peephole weights [3*nOut], optional, may be nullptr
+
+    // OUTPUTS:
+    // h - current output, that is at current time step t, [bS, nOut] or [nOut] if seqLen != nullptr
+    // c - current cell state, that is at current time step t, [bS, nOut] or [nOut] if seqLen != nullptr
+
+    // !!! dimension 4*nOut implies order it, ft, c't, ot
+    // !!! dimension 3*nOut implies order it, ft, ot
+
+    const Nd4jLong nOut = Wx->sizeAt(-1) / 4;
+
+    auto z = mmul(*x, *Wx) + mmul(*hI, *Wr);    //   [bs, nIn] * [nIn, 4*nOut] + [bs, nOut] * [nOut, 4*nOut] = [bS, 4*nOut]
+                                                //or [nIn] * [nIn, 4*nOut] + [nOut] * [nOut, 4*nOut] = [4*nOut]
+
+    // add biases if they are given
+    if(b != nullptr)
+        z += *b;                                // broadcast [bS, 4*nOut] + [4*nOut] = [bS, 4*nOut]
+
+    auto zi = x->rankOf() == 1 ? z({0,        nOut}) : z({0,0, 0,        nOut});    // input gate it, [bS, nOut]
+    auto zf = x->rankOf() == 1 ? z({nOut,   2*nOut}) : z({0,0, nOut,   2*nOut});    // forget gate ft, [bS, nOut]
+    auto zc = x->rankOf() == 1 ? z({2*nOut, 3*nOut}) : z({0,0, 2*nOut, 3*nOut});    // cell gate c't, [bS, nOut]
+    auto zo = x->rankOf() == 1 ? z({3*nOut, 4*nOut}) : z({0,0, 3*nOut, 4*nOut});    // output gate ot, [bS, nOut]
+
+    // peephole connections for input and forget gates
+    if(Wp != nullptr) {
+        zi += *cI * (*Wp)({0,      nOut});      // broadcast: [bS, nOut] + [bS, nOut] ◦ [nOut] = [bS, nOut]
+        zf += *cI * (*Wp)({nOut, 2*nOut});      // broadcast: [bS, nOut] + [bS, nOut] ◦ [nOut] = [bS, nOut]
+    }
+
+    applyActivation(zi, params[3], params[4], params[5], zi);   // inplace
+    applyActivation(zf, params[3], params[4], params[5], zf);   // inplace
+    applyActivation(zc, params[6], params[7], params[8], zc);   // inplace
+
+    c->assign(zf * *cI + zi * zc);              // [bS, nOut] ◦ [bS, nOut] + [bS, nOut] ◦ [bS, nOut] = [bS, nOut]
+
+    // if clipping value is non-zero then cell state is clipped by this value prior to the cell output activation
+    if(params[2] != 0)
+        c->applyScalar(scalar::LstmClip, params[2]);
+
+    // peephole connections for output gate
+    if(Wp != nullptr)
+        zo += *c * (*Wp)({2*nOut, 3*nOut});     // broadcast: [bS, nOut] + [nOut] ◦ [bS, nOut] = [bS, nOut]
+
+    applyActivation(zo, params[3], params[4], params[5], zo);
+
+    applyActivation(*c, params[9], params[10], params[11], *h);
+    *h *= zo;                                   // [bS, nOut] ◦ [bS, nOut]
+}
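Since the stated objective above is math-readable rather than fast code, the same single-step math is easy to mirror in plain scalar C++. A hypothetical stand-alone reference (not part of the patch) for bS = 1, sigmoid gates and tanh cell/output activations, without peepholes or clipping:

    #include <cmath>
    #include <vector>

    static float sigmoidf(float v) { return 1.f / (1.f + std::exp(-v)); }

    // x[nIn], h[nOut], c[nOut] (updated in place); Wx[nIn][4*nOut], Wr[nOut][4*nOut],
    // b[4*nOut]; gate order along the 4*nOut axis is {it, ft, c't, ot}, as documented above
    void lstmCellScalar(const std::vector<float>& x, std::vector<float>& h, std::vector<float>& c,
                        const std::vector<std::vector<float>>& Wx,
                        const std::vector<std::vector<float>>& Wr,
                        const std::vector<float>& b, const int nIn, const int nOut) {
        std::vector<float> z(4 * nOut, 0.f);
        for (int j = 0; j < 4 * nOut; ++j) {                 // z = x*Wx + h*Wr + b
            for (int k = 0; k < nIn;  ++k) z[j] += x[k] * Wx[k][j];
            for (int k = 0; k < nOut; ++k) z[j] += h[k] * Wr[k][j];
            z[j] += b[j];
        }
        for (int j = 0; j < nOut; ++j) {
            const float it = sigmoidf(z[j]);                 // input gate
            const float ft = sigmoidf(z[nOut + j]);          // forget gate
            const float cg = std::tanh(z[2 * nOut + j]);     // cell gate c't
            const float ot = sigmoidf(z[3 * nOut + j]);      // output gate
            c[j] = ft * c[j] + it * cg;                      // ct = ft ◦ ct-1 + it ◦ c't
            h[j] = ot * std::tanh(c[j]);                     // ht = ot ◦ tanh(ct)
        }
    }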
+
+
+
+//////////////////////////////////////////////////////////////////////////
+void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
+                       const NDArray* b, const NDArray* seqLen, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
+                       const std::vector<float>& params,
+                       const bool forward,
+                       NDArray* h, NDArray* hL, NDArray* cL) {
+
+    // INPUTS:
+    // x  - current input [sL, bS, nIn], [bS, sL, nIn], [bS, nIn, sL],
+    // Wx - input weights [nIn, 4*nOut]
+    // Wr - recurrent weights [nOut, 4*nOut]
+    // b  - biases [4*nOut], optional, may be nullptr
+    // seqLen - [bS], optional, may be nullptr
+    // hI - initial output [bS, nOut], optional, may be nullptr
+    // cI - initial cell state at time t-1 [bS, nOut], optional, may be nullptr
+    // Wp - peephole weights [3*nOut], optional, may be nullptr
+
+    // OUTPUTS:
+    // h  - output [sL, bS, nOut], [bS, sL, nOut], [bS, nOut, sL], optional, may be nullptr
+    // hL - output at last step [bS, nOut], optional, may be nullptr
+    // cL - cell state at last step [bS, nOut], optional, may be nullptr
+
+    // params = {dataFormat, directionMode, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta};
+    // dataFormat: 0,3 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL]
+
+    const int dataFormat    = params[0];
+    const int directionMode = params[1];
+
+    const Nd4jLong sL   = x->sizeAt(dataFormat);
+    const Nd4jLong bS   = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1);
+    const Nd4jLong nOut = Wx->sizeAt(-1) / 4;
+
+    const std::vector<Nd4jLong> shapeOut = {bS, nOut};
+
+    auto h0 = const_cast<NDArray*>(hI);
+    if(!hI) {
+        h0 = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext());
+        h0->nullify();
+    }
+
+    auto c0 = const_cast<NDArray*>(cI);
+    if(!cI) {
+        c0 = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext());
+        c0->nullify();
+    }
+
+    auto ct = cL;
+    if(!cL)
+        ct = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext());
+
+    auto ht = hL;
+    if(!h && !hL)
+        ht = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext());
+
+    // create sets of required (depends on seqLen presence) sub-arrays
+    std::vector<int> dims;
+    ResultSet *xSet(nullptr), *hSet(nullptr), *h0Set(nullptr), *c0Set(nullptr), *htSet(nullptr), *ctSet(nullptr);
+
+    if(!seqLen) {
+
+        dims = ShapeUtils::evalDimsToExclude(x->rankOf(), {dataFormat < 3 ? dataFormat : 0});   // points on bS and nIn/nOut axes
+
+        xSet = x->allTensorsAlongDimension(dims);           // sub-arrays with shape [bS, nIn]
+        if(h)
+            hSet = h->allTensorsAlongDimension(dims);       // sub-arrays with shape [bS, nOut]
+    }
+    else {
+
+        dims = dataFormat == 2 ? std::vector<int>({1}) : std::vector<int>({2});     // points on nIn/nOut axis
+
+        xSet  = x->allTensorsAlongDimension(dims);          // sub-arrays with shape [nIn]
+        h0Set = h0->allTensorsAlongDimension({1});          // sub-arrays with shape [nOut]
+        c0Set = c0->allTensorsAlongDimension({1});          // sub-arrays with shape [nOut]
+        ctSet = ct->allTensorsAlongDimension({1});          // sub-arrays with shape [nOut]
+        if(h)
+            hSet = h->allTensorsAlongDimension(dims);       // sub-arrays with shape [nOut]
+        if(ht)
+            htSet = ht->allTensorsAlongDimension({1});      // sub-arrays with shape [nOut]
+    }
+
+    // loops
+    if(forward) {
+
+        if(!seqLen) {
+
+            if(!h) {    // seqLen and h are absent
+
+                lstmLayerCell(xSet->at(0), Wx, Wr, b, h0, c0, Wp, params, ht, ct);          // first time step
+                for (int t = 1; t < sL; ++t)
+                    lstmLayerCell(xSet->at(t), Wx, Wr, b, ht, ct, Wp, params, ht, ct);      // rest time steps
+            }
+            else {      // seqLen is absent and h is present
+
+                lstmLayerCell(xSet->at(0), Wx, Wr, b, h0, c0, Wp, params, hSet->at(0), ct); // first time step
+                for (int t = 1; t < sL; ++t)
+                    lstmLayerCell(xSet->at(t), Wx, Wr, b, hSet->at(t - 1), ct, Wp, params, hSet->at(t), ct);    // rest time steps
+
+                if(hL)
+                    hL->assign(hSet->at(sL - 1));   // assign last output to hL if it is not nullptr
+            }
+        }
+        else {
+
+            if(!h) {    // seqLen is present and h is absent
+
+                for (int e = 0; e < bS; ++e) {
+
+                    const int limit = seqLen->e<int>(e);
+
+                    if(limit == 0) {
+                        if(cL)
+                            ctSet->at(e)->nullify();
+                        if(hL)
+                            htSet->at(e)->nullify();
+                        continue;
+                    }
+
+                    auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, 0, e);
+                    lstmLayerCell(xSet->at(ind), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, htSet->at(e), ctSet->at(e));    // first time step
+
+                    for (int t = 1; t < limit; ++t) {
+                        ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
+                        lstmLayerCell(xSet->at(ind), Wx, Wr, b, htSet->at(e), ctSet->at(e), Wp, params, htSet->at(e), ctSet->at(e));    // rest time steps
+                    }
+                }
+            }
+            else {      // seqLen and h are present
+
+                for (int e = 0; e < bS; ++e) {
+
+                    int limit = seqLen->e<int>(e);
+
+                    if(limit == 0) {
+
+                        tensorAlongTimeBatchDims(*h, dataFormat, 0,0, e,e+1).nullify();     // nullify for given e and whole time range
+
+                        if(cL)
+                            ctSet->at(e)->nullify();
+                        if(hL)
+                            htSet->at(e)->nullify();
+
+                        continue;
+                    }
+
+                    auto indPrev = getBatchTimeTotalIndex(dataFormat, sL, bS, 0, e);
+                    lstmLayerCell(xSet->at(indPrev), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, hSet->at(indPrev), ctSet->at(e));   // first time step
+
+                    for (int t = 1; t < limit; ++t) {
+                        auto indCurr = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
+                        lstmLayerCell(xSet->at(indCurr), Wx, Wr, b, hSet->at(indPrev), ctSet->at(e), Wp, params, hSet->at(indCurr), ctSet->at(e));  // rest time steps
+                        indPrev = indCurr;
+                    }
+
+                    if(hL)
+                        htSet->at(e)->assign(hSet->at(indPrev));    // assign last output to hL if hL is not nullptr
+
+                    tensorAlongTimeBatchDims(*h, dataFormat, limit,sL, e,e+1).nullify();    // nullify for given e and time range [limit, sL)
+                }
+            }
+        }
+    }
+    else {   // backward
+
+        if(!seqLen) {
+
+            if(!h) {    // seqLen and h are absent
+
+                lstmLayerCell(xSet->at(sL - 1), Wx, Wr, b, h0, c0, Wp, params, ht, ct);     // first time step
+                for (int t = sL - 2; t >= 0; --t)
+                    lstmLayerCell(xSet->at(t), Wx, Wr, b, ht, ct, Wp, params, ht, ct);      // rest time steps
+            }
+            else {      // seqLen is absent and h is present
+
+                lstmLayerCell(xSet->at(sL - 1), Wx, Wr, b, h0, c0, Wp, params, hSet->at(sL - 1), ct);   // first time step
+                for (int t = sL - 2; t >= 0; --t)
+                    lstmLayerCell(xSet->at(t), Wx, Wr, b, hSet->at(t + 1), ct, Wp, params, hSet->at(t), ct);    // rest time steps
+
+                if(hL)
+                    hL->assign(hSet->at(0));    // assign last output to hL if it is not nullptr
+            }
+        }
+        else if(directionMode == 1) {   // only backward, no bidirectional mode
+
+            if(!h) {    // h is absent and seqLen is present
+
+                for (int e = 0; e < bS; ++e) {
+
+                    const int limit = seqLen->e<int>(e);
+
+                    if(limit == 0) {
+                        if(cL)
+                            ctSet->at(e)->nullify();
+                        if(hL)
+                            htSet->at(e)->nullify();
+                        continue;
+                    }
+
+                    auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, sL - 1, e);
+                    lstmLayerCell(xSet->at(ind), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, htSet->at(e), ctSet->at(e));    // first time step
+
+                    for (int t = sL - 2; t >= sL - limit; --t) {
+                        ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
+                        lstmLayerCell(xSet->at(ind), Wx, Wr, b, htSet->at(e), ctSet->at(e), Wp, params, htSet->at(e), ctSet->at(e));    // rest time steps
+                    }
+                }
+            }
+            else {      // seqLen and h are present
+
+                for (int e = 0; e < bS; ++e) {
+
+                    int limit = seqLen->e<int>(e);
+
+                    if(limit == 0) {
+
+                        tensorAlongTimeBatchDims(*h, dataFormat, 0,0, e,e+1).nullify();     // nullify for given e and whole time range
+
+                        if(cL)
+                            ctSet->at(e)->nullify();
+                        if(hL)
+                            htSet->at(e)->nullify();
+
+                        continue;
+                    }
+
+                    auto indPrev = getBatchTimeTotalIndex(dataFormat, sL, bS, sL - 1, e);
+                    lstmLayerCell(xSet->at(indPrev), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, hSet->at(indPrev), ctSet->at(e));   // first time step
+
+                    for (int t = sL - 2; t >= sL - limit; --t) {
+                        auto indCurr = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
+                        lstmLayerCell(xSet->at(indCurr), Wx, Wr, b, hSet->at(indPrev), ctSet->at(e), Wp, params, hSet->at(indCurr), ctSet->at(e));  // rest time steps
+                        indPrev = indCurr;
+                    }
+
+                    if(hL)
+                        htSet->at(e)->assign(hSet->at(indPrev));    // assign last output to hL if it is not nullptr
+
+                    tensorAlongTimeBatchDims(*h, dataFormat, 0,sL-limit, e,e+1).nullify();  // nullify for given e and time range [0, sL-limit)
+                }
+            }
+        }
+        else {   // backward in bidirectional mode
+
+            if(!h) {    // h is absent and seqLen is present
+
+                for (int e = 0; e < bS; ++e) {
+
+                    const int limit = seqLen->e<int>(e);
+
+                    if(limit == 0) {
+                        if(cL)
+                            ctSet->at(e)->nullify();
+                        if(hL)
+                            htSet->at(e)->nullify();
+                        continue;
+                    }
+
+                    auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, limit - 1, e);
+                    lstmLayerCell(xSet->at(ind), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, htSet->at(e), ctSet->at(e));    // first time step
+
+                    for (int t = limit - 2; t >= 0; --t) {
+                        ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
+                        lstmLayerCell(xSet->at(ind), Wx, Wr, b, htSet->at(e), ctSet->at(e), Wp, params, htSet->at(e), ctSet->at(e));    // rest time steps
+                    }
+                }
+            }
+            else {      // seqLen and h are present
+
+                for (int e = 0; e < bS; ++e) {
+
+                    int limit = seqLen->e<int>(e);
+
+                    if(limit == 0) {
+
+                        tensorAlongTimeBatchDims(*h, dataFormat, 0,0, e,e+1).nullify();     // nullify for given e and whole time range
+
+                        if(cL)
+                            ctSet->at(e)->nullify();
+                        if(hL)
+                            htSet->at(e)->nullify();
+
+                        continue;
+                    }
+
+                    auto indPrev = getBatchTimeTotalIndex(dataFormat, sL, bS, limit - 1, e);
+                    lstmLayerCell(xSet->at(indPrev), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, hSet->at(indPrev), ctSet->at(e));   // first time step
+
+                    for (int t = limit - 2; t >= 0; --t) {
+                        auto indCurr = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
+                        lstmLayerCell(xSet->at(indCurr), Wx, Wr, b, hSet->at(indPrev), ctSet->at(e), Wp, params, hSet->at(indCurr), ctSet->at(e));  // rest time steps
+                        indPrev = indCurr;
+                    }
+
+                    if(hL)
+                        htSet->at(e)->assign(hSet->at(indPrev));    // assign last output to hL if it is not nullptr
+
+                    tensorAlongTimeBatchDims(*h, dataFormat, limit,sL, e,e+1).nullify();    // nullify for given e and time range [limit, sL)
+                }
+            }
+        }
+    }
+
+    delete xSet;
+    delete hSet;
+    delete h0Set;
+    delete c0Set;
+    delete htSet;
+    delete ctSet;
+}
+
+
+
+}
+}
+}
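A minimal sketch of driving this time loop directly, assuming the params layout documented above ({dataFormat, directionMode, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta}) and omitting all optional arrays:

    #include <vector>
    #include <NDArrayFactory.h>
    #include <ops/declarable/helpers/lstmLayer.h>

    void timeLoopSketch() {
        const int sL = 4, bS = 2, nIn = 3, nOut = 5;

        auto x  = NDArrayFactory::create<float>('c', {sL, bS, nIn});    // dataFormat = 0
        auto Wx = NDArrayFactory::create<float>('c', {nIn,  4*nOut});
        auto Wr = NDArrayFactory::create<float>('c', {nOut, 4*nOut});
        auto h  = NDArrayFactory::create<float>('c', {sL, bS, nOut});   // full output sequence
        x.assign(0.5f); Wx.assign(0.1f); Wr.assign(0.1f);

        // TNS format, no clipping, sigmoid (2) gates, tanh (0) cell and output activations
        std::vector<float> params = {0, 0, 0,  2, 0, 0,  0, 0, 0,  0, 0, 0};

        // null b/seqLen/hI/cI/Wp: biases and peepholes are skipped, initial state is zero
        nd4j::ops::helpers::lstmLayerTimeLoop(&x, &Wx, &Wr, nullptr, nullptr, nullptr, nullptr, nullptr,
                                              params, true /*forward*/, &h, nullptr, nullptr);
    }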
diff --git a/libnd4j/include/ops/declarable/helpers/lstmLayer.h b/libnd4j/include/ops/declarable/helpers/lstmLayer.h
new file mode 100644
index 000000000..7d94c32e0
--- /dev/null
+++ b/libnd4j/include/ops/declarable/helpers/lstmLayer.h
@@ -0,0 +1,117 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author Yurii Shyrma (iuriish@yahoo.com)
+//
+
+#ifndef LIBND4J_LSTMLAYER_H
+#define LIBND4J_LSTMLAYER_H
+
+#include
+#include
+
+namespace nd4j {
+namespace ops {
+namespace helpers {
+
+//////////////////////////////////////////////////////////////////////////
+void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
+                   const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
+                   const std::vector<float>& params,
+                   NDArray* h, NDArray* c);
+
+//////////////////////////////////////////////////////////////////////////
+void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
+                       const NDArray* b, const NDArray* seqLen, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
+                       const std::vector<float>& params,
+                       const bool forward,
+                       NDArray* h, NDArray* hL, NDArray* cL);
+
+//////////////////////////////////////////////////////////////////////////
+static FORCEINLINE void applyActivation(NDArray& x, const int opId, const float alpha, const float beta, NDArray& z) {
+
+    switch (opId) {
+        case 0:
+            (const_cast<NDArray&>(x)).applyTransform(transform::Tanh, &z);
+            break;
+        case 1:
+            (const_cast<NDArray&>(x)).applyScalar(scalar::RELU, 0, &z);
+            break;
+        case 2:
+            (const_cast<NDArray&>(x)).applyTransform(transform::Sigmoid, &z);
+            break;
+        case 3: {
+            ExtraArguments args({ static_cast<double>(alpha), static_cast<double>(beta)});
+            (const_cast<NDArray&>(x)).applyTransform(transform::Affine, &z, &args);
+            break;
+        }
+        case 4:
+            (const_cast<NDArray&>(x)).applyScalar(scalar::LeakyRELU, alpha, &z);
+            break;
+        case 5:
+            helpers::thresholdRelu(x.getContext(), x, alpha, z);
+            break;
+        case 6: {
+            ExtraArguments args({ static_cast<double>(alpha), static_cast<double>(beta)});
+            (const_cast<NDArray&>(x)).applyTransform(transform::ScaledTanh, &z, &args);
+            break;
+        }
+        case 7:
+            (const_cast<NDArray&>(x)).applyTransform(transform::HardSigmoid, &z);
+            break;
+        case 8:
+            (const_cast<NDArray&>(x)).applyScalar(scalar::ELU, alpha, &z);
+            break;
+        case 9:
+            (const_cast<NDArray&>(x)).applyTransform(transform::SoftSign, &z);
+            break;
+        case 10:
+            (const_cast<NDArray&>(x)).applyTransform(transform::SoftPlus, &z);
+            break;
+        default:
+            throw std::invalid_argument("LSTM_LAYER operation: wrong id number of activation !");
+    }
+}
+
+//////////////////////////////////////////////////////////////////////////
+static FORCEINLINE NDArray tensorAlongTimeBatchDims(const NDArray& arr, const int dataFormat, const int t1, const int t2, const int b1, const int b2) {
+
+    if(dataFormat == 0 || dataFormat == 3)
+        return arr({t1,t2, b1,b2, 0,0});    // TNS: [sL, bS, nIn]
+
+    if(dataFormat == 1)
+        return arr({b1,b2, t1,t2, 0,0});    // NTS: [bS, sL ,nIn]
+
+    return arr({b1,b2, 0,0, t1,t2});        // NST: [bS, nIn, sL]
+}
+
+//////////////////////////////////////////////////////////////////////////
+static FORCEINLINE int getBatchTimeTotalIndex(const int dataFormat, const int sL, const int bS, const int t, const int b) {
+
+    if(dataFormat == 0 || dataFormat == 3)
+        return t * bS + b;      // TNS: shape [sL, bS, nIn]
+
+    return b * sL + t;          // NTS, NST: shape [bS, sL, nIn], [bS, nIn, sL]
+}
+
+
+}
+}
+}
+
+
+#endif //LIBND4J_LSTMLAYER_H
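The two inline helpers above define how a (time, batch) pair maps onto the flat index of the sub-array sets used by lstmLayerTimeLoop. A quick hypothetical sanity check of that mapping:

    #include <cassert>
    #include <ops/declarable/helpers/lstmLayer.h>

    int main() {
        using namespace nd4j::ops::helpers;
        const int sL = 4, bS = 3;
        // TNS [sL, bS, nIn] (dataFormat 0 or 3): sub-arrays are time-major
        assert(getBatchTimeTotalIndex(0, sL, bS, /*t*/2, /*b*/1) == 2 * bS + 1);   // == 7
        // NTS [bS, sL, nIn] and NST [bS, nIn, sL]: sub-arrays are batch-major
        assert(getBatchTimeTotalIndex(1, sL, bS, /*t*/2, /*b*/1) == 1 * sL + 2);   // == 6
        assert(getBatchTimeTotalIndex(2, sL, bS, /*t*/2, /*b*/1) == 1 * sL + 2);   // == 6
        return 0;
    }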
diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp
new file mode 100644
index 000000000..e22487f43
--- /dev/null
+++ b/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp
@@ -0,0 +1,546 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author Yurii Shyrma (iuriish@yahoo.com)
+//
+
+#include
+#include "mkldnnUtils.h"
+
+using namespace mkldnn;
+
+namespace nd4j {
+namespace ops {
+namespace platforms {
+
+static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
+                            const NDArray* b, const NDArray* hI, const NDArray* cI,
+                            const std::vector<float>& params,
+                            NDArray* h, NDArray* hL, NDArray* cL) {
+
+    // equations (no peephole connections)
+    // it  = σ(Wxi * xt + Wri * ht-1 + bi)
+    // ft  = σ(Wxf * xt + Wrf * ht-1 + bf)
+    // c't = tanh(Wxc * xt + Wrc * ht-1 + bc)
+    // ct  = ft ◦ ct-1 + it ◦ c't
+    // ot  = σ(Wxo * xt + Wro * ht-1 + bo)
+    // ht  = ot ◦ tanh(ct)
+
+    // notations:
+    // bS   - batch size
+    // sL   - sequence length, number of time steps
+    // nIn  - input size
+    // nOut - output size (hidden size)
+
+    // INPUTS:
+
+    // *******
+    // input x:
+    // 1) [sL, bS, nIn] when dataFormat == 0
+
+    // *******
+    // input weights Wx:
+    // 1) [1, 1, nIn, 4*nOut] when directionMode <  2
+    // 2) [1, 2, nIn, 4*nOut] when directionMode >= 2
+
+    // *******
+    // recurrent weights Wr:
+    // 1) [1, 1, nOut, 4*nOut] when directionMode <  2
+    // 2) [1, 2, nOut, 4*nOut] when directionMode >= 2
+
+    // *******
+    // biases b:
+    // 1) [1, 1, 4*nOut] when directionMode <  2
+    // 2) [1, 2, 4*nOut] when directionMode >= 2
+
+    // *******
+    // initial output hI:
+    // 1) [1, 1, bS, nOut] when directionMode <  2
+    // 2) [1, 2, bS, nOut] when directionMode >= 2
+
+    // *******
+    // initial cell state cI (same shape as in hI):
+    // 1) [1, 1, bS, nOut] when directionMode <  2
+    // 2) [1, 2, bS, nOut] when directionMode >= 2
+
+
+    // OUTPUTS:
+
+    // *******
+    // output h:
+    // 1) [sL, bS, nOut]   when directionMode <= 2 && dataFormat == 0
+    // 2) [sL, bS, 2*nOut] when directionMode == 3 && dataFormat == 0
+
+    // *******
+    // output at last step hL:
+    // 1) [1, 1, bS, nOut] when directionMode <  2
+    // 2) [1, 2, bS, nOut] when directionMode >= 2
+
+    // *******
+    // cell state at last step cL (same shape as in hL):
+    // 1) [1, 1, bS, nOut] when directionMode <  2
+    // 2) [1, 2, bS, nOut] when directionMode >= 2
+
+    // !!! dimension 4*nOut implies order it, ft, c't, ot
+    // !!! dimension 3*nOut implies order it, ft, ot
+
+    // params = {dataFormat, directionMode, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta};
+
+    // dataFormat: 0 = [sL, bS, nIn]
+    // directionMode: 0 = forward, 1 = backward, 2 = bidirectional sum, 3 = bidirectional concat
+
+    const int dataFormat    = params[0];
+    const int directionMode = params[1];
+
+    const int sL   = x->sizeAt(0);  // dataFormat == 0 ?  x->sizeAt(0) : x->sizeAt(1);
+    const int bS   = x->sizeAt(1);  // dataFormat == 0 ?  x->sizeAt(1) : x->sizeAt(0);
+    const int nIn  = x->sizeAt(-1);
+    const int nOut = Wx->sizeAt(-1);
+
+    const int dirDim  = directionMode <  2 ? 1 : 2;     // number of directions: 1 for unidirectional, 2 for bidirectional
+    const int hDirDim = directionMode <= 2 ? 1 : 2;     // for h array, take into account bidirectional_sum mode (directionMode == 2)
+
+    // evaluate direction
+    rnn_direction direction;
+    switch (directionMode) {
+        case 0:
+            direction = rnn_direction::unidirectional_left2right;
+            break;
+        case 1:
+            direction = rnn_direction::unidirectional_right2left;
+            break;
+        case 2:
+            direction = rnn_direction::bidirectional_sum;
+            break;
+        default:
+            direction = rnn_direction::bidirectional_concat;
+    }
+
+    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
+
+    mkldnn_memory_desc_t empty;
+
+    mkldnn::memory::desc x_user_md, wx_user_md, wr_user_md, b_user_md, hI_user_md, cI_user_md, h_user_md, hL_user_md, cL_user_md,
+                         x_lstm_md, wx_lstm_md, wr_lstm_md, b_lstm_md, hI_lstm_md, cI_lstm_md, h_lstm_md, hL_lstm_md, cL_lstm_md;
+
+    // input type
+    mkldnn::memory::data_type xType;
+    if(x->dataType() == DataType::FLOAT32)
+        xType = mkldnn::memory::data_type::f32;
+    else if(x->dataType() == DataType::HALF)
+        xType = mkldnn::memory::data_type::f16;
+    else
+        xType = mkldnn::memory::data_type::u8;
+
+    // weights type
+    mkldnn::memory::data_type wType = xType;
+    if(xType == mkldnn::memory::data_type::u8)
+        wType = mkldnn::memory::data_type::s8;
+
+    // bias type
+    mkldnn::memory::data_type bType = xType;
+    if(xType == mkldnn::memory::data_type::u8)
+        bType = mkldnn::memory::data_type::f32;
+
+    // output type
+    mkldnn::memory::data_type hType;
+    if(h->dataType() == DataType::FLOAT32)
+        hType = mkldnn::memory::data_type::f32;
+    else if(h->dataType() == DataType::HALF)
+        hType = mkldnn::memory::data_type::f16;
+    else
+        hType = mkldnn::memory::data_type::u8;
+
+
+    // memory descriptors for arrays
+
+    // x
+    x_lstm_md = mkldnn::memory::desc({sL, bS, nIn}, xType, mkldnn::memory::format_tag::any);
+    // x_user_md = dataFormat == 0 ? mkldnn::memory::desc({sL, bS, nIn}, type, mkldnn::memory::format_tag::tnc) : mkldnn::memory::desc({bS, sL, nIn}, type, mkldnn::memory::format_tag::ntc);
+    x_user_md = mkldnn::memory::desc({sL, bS, nIn}, xType, mkldnn::memory::format_tag::tnc);
+    x_user_md.data.format_kind = mkldnn_blocked;    // overrides format
+    x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0];
+    x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1];
+    x_user_md.data.format_desc.blocking.strides[2] = x->stridesOf()[2];
+
+    // wx
+    wx_lstm_md = mkldnn::memory::desc({1,dirDim,nIn,4,nOut}, wType, mkldnn::memory::format_tag::any);
+    wx_user_md = mkldnn::memory::desc({1,dirDim,nIn,4,nOut}, wType, mkldnn::memory::format_tag::ldigo);
+    wx_user_md.data.format_kind = mkldnn_blocked;    // overrides format
+    wx_user_md.data.format_desc.blocking.strides[0] = Wx->stridesOf()[0];
+    wx_user_md.data.format_desc.blocking.strides[1] = Wx->stridesOf()[1];
+    wx_user_md.data.format_desc.blocking.strides[2] = Wx->stridesOf()[2];
+    wx_user_md.data.format_desc.blocking.strides[3] = Wx->stridesOf()[3];
+    wx_user_md.data.format_desc.blocking.strides[4] = Wx->stridesOf()[4];
+
+    // wr
+    wr_lstm_md = mkldnn::memory::desc({1,dirDim,nOut,4,nOut}, wType, mkldnn::memory::format_tag::any);
+    wr_user_md = mkldnn::memory::desc({1,dirDim,nOut,4,nOut}, wType, mkldnn::memory::format_tag::ldigo);
+    wr_user_md.data.format_kind = mkldnn_blocked;    // overrides format
+    wr_user_md.data.format_desc.blocking.strides[0] = Wr->stridesOf()[0];
+    wr_user_md.data.format_desc.blocking.strides[1] = Wr->stridesOf()[1];
+    wr_user_md.data.format_desc.blocking.strides[2] = Wr->stridesOf()[2];
+    wr_user_md.data.format_desc.blocking.strides[3] = Wr->stridesOf()[3];
+    wr_user_md.data.format_desc.blocking.strides[4] = Wr->stridesOf()[4];
+
+    // h
+    h_lstm_md = mkldnn::memory::desc({sL, bS, hDirDim*nOut}, hType, mkldnn::memory::format_tag::any);
+    // h_user_md = dataFormat == 0 ? mkldnn::memory::desc({sL, bS, hDirDim*nOut}, type, mkldnn::memory::format_tag::tnc) : mkldnn::memory::desc({bS, sL, hDirDim*nOut}, type, mkldnn::memory::format_tag::ntc);
+    h_user_md = mkldnn::memory::desc({sL, bS, hDirDim*nOut}, hType, mkldnn::memory::format_tag::tnc);
+    h_user_md.data.format_kind = mkldnn_blocked;    // overrides format
+    h_user_md.data.format_desc.blocking.strides[0] = h->stridesOf()[0];
+    h_user_md.data.format_desc.blocking.strides[1] = h->stridesOf()[1];
+    h_user_md.data.format_desc.blocking.strides[2] = h->stridesOf()[2];
+
+    // b
+    if(b) {
+        b_lstm_md = mkldnn::memory::desc({1,dirDim,4,nOut}, bType, mkldnn::memory::format_tag::any);
+        b_user_md = mkldnn::memory::desc({1,dirDim,4,nOut}, bType, mkldnn::memory::format_tag::ldgo);
+        b_user_md.data.format_kind = mkldnn_blocked;    // overrides format
+        b_user_md.data.format_desc.blocking.strides[0] = b->stridesOf()[0];
+        b_user_md.data.format_desc.blocking.strides[1] = b->stridesOf()[1];
+        b_user_md.data.format_desc.blocking.strides[2] = b->stridesOf()[2];
+        b_user_md.data.format_desc.blocking.strides[3] = b->stridesOf()[3];
+    }
+
+    // hI
+    if(hI) {
+        hI_lstm_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, xType, mkldnn::memory::format_tag::any);
+        hI_user_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, xType, mkldnn::memory::format_tag::ldnc);
+        hI_user_md.data.format_kind = mkldnn_blocked;    // overrides format
+        hI_user_md.data.format_desc.blocking.strides[0] = hI->stridesOf()[0];
+        hI_user_md.data.format_desc.blocking.strides[1] = hI->stridesOf()[1];
+        hI_user_md.data.format_desc.blocking.strides[2] = hI->stridesOf()[2];
+        hI_user_md.data.format_desc.blocking.strides[3] = hI->stridesOf()[3];
+    }
+
+    // cI
+    if(cI) {
+        cI_lstm_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, xType, mkldnn::memory::format_tag::any);
+        cI_user_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, xType, mkldnn::memory::format_tag::ldnc);
+        cI_user_md.data.format_kind = mkldnn_blocked;    // overrides format
+        cI_user_md.data.format_desc.blocking.strides[0] = cI->stridesOf()[0];
+        cI_user_md.data.format_desc.blocking.strides[1] = cI->stridesOf()[1];
+        cI_user_md.data.format_desc.blocking.strides[2] = cI->stridesOf()[2];
+        cI_user_md.data.format_desc.blocking.strides[3] = cI->stridesOf()[3];
+    }
+
+    // hL
+    if(hL) {
+        hL_lstm_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, hType, mkldnn::memory::format_tag::any);
+        hL_user_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, hType, mkldnn::memory::format_tag::ldnc);
+        hL_user_md.data.format_kind = mkldnn_blocked;    // overrides format
+        hL_user_md.data.format_desc.blocking.strides[0] = hL->stridesOf()[0];
+        hL_user_md.data.format_desc.blocking.strides[1] = hL->stridesOf()[1];
+        hL_user_md.data.format_desc.blocking.strides[2] = hL->stridesOf()[2];
+        hL_user_md.data.format_desc.blocking.strides[3] = hL->stridesOf()[3];
+    }
+
+    // cL
+    if(cL) {
+        cL_lstm_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, hType, mkldnn::memory::format_tag::ldnc);
+        cL_user_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, hType, mkldnn::memory::format_tag::ldnc);
+        cL_user_md.data.format_kind = mkldnn_blocked;    // overrides format
+        cL_user_md.data.format_desc.blocking.strides[0] = cL->stridesOf()[0];
+        cL_user_md.data.format_desc.blocking.strides[1] = cL->stridesOf()[1];
+        cL_user_md.data.format_desc.blocking.strides[2] = cL->stridesOf()[2];
+        cL_user_md.data.format_desc.blocking.strides[3] = cL->stridesOf()[3];
+    }
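All descriptor pairs above follow the same idea: the *_lstm_md descriptors use format_tag::any so the primitive can pick its preferred layout, while the *_user_md descriptors are forced to mkldnn_blocked with the NDArray's real strides so arbitrarily strided nd4j buffers can be wrapped without copying. A distilled sketch of that override for one hypothetical 3D array:

    #include <mkldnn.hpp>

    // desc for a user buffer of logical shape {d0, d1, d2} with arbitrary element strides;
    // strides[] would come from NDArray::stridesOf() in the code above
    mkldnn::memory::desc stridedUserDesc(int d0, int d1, int d2, const long long* strides) {
        auto md = mkldnn::memory::desc({d0, d1, d2}, mkldnn::memory::data_type::f32,
                                       mkldnn::memory::format_tag::tnc);
        md.data.format_kind = mkldnn_blocked;               // override the fixed tag
        md.data.format_desc.blocking.strides[0] = strides[0];
        md.data.format_desc.blocking.strides[1] = strides[1];
        md.data.format_desc.blocking.strides[2] = strides[2];
        return md;
    }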
+
+    // lstm memory description
+    lstm_forward::desc lstm_desc(prop_kind::forward_inference, direction,
+                                 x_lstm_md, hI_lstm_md, cI_lstm_md, wx_lstm_md, wr_lstm_md, b_lstm_md,
+                                 h_lstm_md, hL_lstm_md, cL_lstm_md);
+
+    mkldnn::stream stream(engine);
+
+    // lstm primitive description
+    lstm_forward::primitive_desc lstm_prim_desc(lstm_desc, engine);
+
+    // arguments (memory buffers) necessary for calculations
+    std::unordered_map<int, mkldnn::memory> args;
+
+    // provide memory and check whether reorder is required
+
+    // x
+    auto x_user_mem = mkldnn::memory(x_user_md, engine, x->getBuffer());
+    const bool xReorder = lstm_prim_desc.src_layer_desc() != x_user_mem.get_desc();
+    auto x_lstm_mem = xReorder ? mkldnn::memory(lstm_prim_desc.src_layer_desc(), engine) : x_user_mem;
+    if (xReorder)
+        reorder(x_user_mem, x_lstm_mem).execute(stream, x_user_mem, x_lstm_mem);
+    args[MKLDNN_ARG_SRC_LAYER] = x_lstm_mem;
+
+    // wx
+    auto wx_user_mem = mkldnn::memory(wx_user_md, engine, Wx->getBuffer());
+    const bool wxReorder = lstm_prim_desc.weights_layer_desc() != wx_user_mem.get_desc();
+    auto wx_lstm_mem = wxReorder ? mkldnn::memory(lstm_prim_desc.weights_layer_desc(), engine) : wx_user_mem;
+    if (wxReorder)
+        reorder(wx_user_mem, wx_lstm_mem).execute(stream, wx_user_mem, wx_lstm_mem);
+    args[MKLDNN_ARG_WEIGHTS_LAYER] = wx_lstm_mem;
+
+    // wr
+    auto wr_user_mem = mkldnn::memory(wr_user_md, engine, Wr->getBuffer());
+    const bool wrReorder = lstm_prim_desc.weights_iter_desc() != wr_user_mem.get_desc();
+    auto wr_lstm_mem = wrReorder ? mkldnn::memory(lstm_prim_desc.weights_iter_desc(), engine) : wr_user_mem;
+    if (wrReorder)
+        reorder(wr_user_mem, wr_lstm_mem).execute(stream, wr_user_mem, wr_lstm_mem);
+    args[MKLDNN_ARG_WEIGHTS_ITER] = wr_lstm_mem;
+
+    // h
+    auto h_user_mem = mkldnn::memory(h_user_md, engine, h->getBuffer());
+    const bool hReorder = lstm_prim_desc.dst_layer_desc() != h_user_mem.get_desc();
+    auto h_lstm_mem = hReorder ? mkldnn::memory(lstm_prim_desc.dst_layer_desc(), engine) : h_user_mem;
+    args[MKLDNN_ARG_DST_LAYER] = h_lstm_mem;
+
+    // b
+    if(b) {
+        auto b_user_mem = mkldnn::memory(b_user_md, engine, b->getBuffer());
+        const bool bReorder = lstm_prim_desc.bias_desc() != b_user_mem.get_desc();
+        auto b_lstm_mem = bReorder ? mkldnn::memory(lstm_prim_desc.bias_desc(), engine) : b_user_mem;
+        if (bReorder)
+            reorder(b_user_mem, b_lstm_mem).execute(stream, b_user_mem, b_lstm_mem);
+        args[MKLDNN_ARG_BIAS] = b_lstm_mem;
+    }
+
+    // hI
+    if(hI) {
+        auto hI_user_mem = mkldnn::memory(hI_user_md, engine, hI->getBuffer());
+        const bool hIReorder = lstm_prim_desc.src_iter_desc() != hI_user_mem.get_desc();
+        auto hI_lstm_mem = hIReorder ? mkldnn::memory(lstm_prim_desc.src_iter_desc(), engine) : hI_user_mem;
+        if (hIReorder)
+            reorder(hI_user_mem, hI_lstm_mem).execute(stream, hI_user_mem, hI_lstm_mem);
+        args[MKLDNN_ARG_SRC_ITER] = hI_lstm_mem;
+    }
+
+    // cI
+    if(cI) {
+        auto cI_user_mem = mkldnn::memory(cI_user_md, engine, cI->getBuffer());
+        const bool cIReorder = lstm_prim_desc.src_iter_c_desc() != cI_user_mem.get_desc();
+        auto cI_lstm_mem = cIReorder ? mkldnn::memory(lstm_prim_desc.src_iter_c_desc(), engine) : cI_user_mem;
+        if (cIReorder)
+            reorder(cI_user_mem, cI_lstm_mem).execute(stream, cI_user_mem, cI_lstm_mem);
+        args[MKLDNN_ARG_SRC_ITER_C] = cI_lstm_mem;
+    }
+
+    bool hLReorder(false), cLReorder(false);
+    mkldnn::memory hL_user_mem, cL_user_mem, hL_lstm_mem, cL_lstm_mem;
+
+    // hL
+    if(hL) {
+        hL_user_mem = mkldnn::memory(hL_user_md, engine, hL->getBuffer());
+        hLReorder   = lstm_prim_desc.dst_iter_desc() != hL_user_mem.get_desc();
+        hL_lstm_mem = hLReorder ? mkldnn::memory(lstm_prim_desc.dst_iter_desc(), engine) : hL_user_mem;
+        args[MKLDNN_ARG_DST_ITER] = hL_lstm_mem;
+    }
+
+    // cL
+    if(cL) {
+        cL_user_mem = mkldnn::memory(cL_user_md, engine, cL->getBuffer());
+        cLReorder   = lstm_prim_desc.dst_iter_c_desc() != cL_user_mem.get_desc();
+        cL_lstm_mem = cLReorder ? mkldnn::memory(lstm_prim_desc.dst_iter_c_desc(), engine) : cL_user_mem;
+        args[MKLDNN_ARG_DST_ITER_C] = cL_lstm_mem;
+    }
+
+    // run calculations
+    lstm_forward(lstm_prim_desc).execute(stream, args);
+
+    // reorder outputs if necessary
+    if (hReorder)
+        reorder(h_lstm_mem, h_user_mem).execute(stream, h_lstm_mem, h_user_mem);
+    if(hLReorder)
+        reorder(hL_lstm_mem, hL_user_mem).execute(stream, hL_lstm_mem, hL_user_mem);
+    if(cLReorder)
+        reorder(cL_lstm_mem, cL_user_mem).execute(stream, cL_lstm_mem, cL_user_mem);
+
+    stream.wait();
+}
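The function above follows the usual MKL-DNN (v1.x) inference recipe: describe the user buffers with their real strides, let the primitive choose its preferred layouts via format_tag::any, and reorder only where the two descriptors disagree. The skeleton of that pattern for a single input, with the LSTM-specific wiring elided (a sketch, not part of the patch):

    #include <mkldnn.hpp>

    void reorderIfNeededSketch(mkldnn::engine& engine, mkldnn::stream& stream,
                               const mkldnn::memory::desc& user_md, void* userBuffer,
                               const mkldnn::lstm_forward::primitive_desc& pd) {
        // 1) wrap the user buffer with its actual (possibly strided) layout
        auto user_mem = mkldnn::memory(user_md, engine, userBuffer);
        // 2) if the primitive picked a different layout for format_tag::any, reorder into it
        const bool needReorder = pd.src_layer_desc() != user_mem.get_desc();
        auto prim_mem = needReorder ? mkldnn::memory(pd.src_layer_desc(), engine) : user_mem;
        if (needReorder)
            mkldnn::reorder(user_mem, prim_mem).execute(stream, user_mem, prim_mem);
        // 3) prim_mem then goes into args[MKLDNN_ARG_SRC_LAYER] for lstm_forward(pd).execute(...),
        //    and outputs are reordered back to the user layout after execution, as above
    }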
mkldnn::memory(lstm_prim_desc.dst_iter_desc(), engine) : hL_user_mem; + args[MKLDNN_ARG_DST_ITER] = hL_lstm_mem; + } + + // cL + if(cL) { + cL_user_mem = mkldnn::memory(cL_user_md, engine, cL->getBuffer()); + cLReorder = lstm_prim_desc.dst_iter_c_desc() != cL_user_mem.get_desc(); + cL_lstm_mem = cLReorder ? mkldnn::memory(lstm_prim_desc.dst_iter_c_desc(), engine) : cL_user_mem; + args[MKLDNN_ARG_DST_ITER_C] = cL_lstm_mem; + } + + // run calculations + lstm_forward(lstm_prim_desc).execute(stream, args); + + // reorder outputs if necessary + if (hReorder) + reorder(h_lstm_mem, h_user_mem).execute(stream, h_lstm_mem, h_user_mem); + if(hLReorder) + reorder(hL_lstm_mem, hL_user_mem).execute(stream, hL_lstm_mem, hL_user_mem); + if(cLReorder) + reorder(cL_lstm_mem, cL_user_mem).execute(stream, cL_lstm_mem, cL_user_mem); + + stream.wait(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(lstmLayer) { + + const auto dataFormat = INT_ARG(0); // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, 2, bS, nOut] (for ONNX) + const auto directionMode = INT_ARG(1); // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3) + + const auto hasBiases = B_ARG(0); // indicates whether biases array is provided + const auto hasSeqLen = B_ARG(1); // indicates whether seqLen array is provided + const auto hasInitH = B_ARG(2); // indicates whether initial output is provided + const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided + const auto hasPH = B_ARG(4); // indicates whether peephole connections are present + const auto retFullSeq = B_ARG(5); // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1} + const auto retLastH = B_ARG(6); // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) + const auto retLastC = B_ARG(7); // indicates whether to return cells state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument) + + const auto cellClip = T_ARG(0); // cell clipping value, if it = 0 then do not apply clipping + + const auto x = INPUT_VARIABLE(0); // input + const auto Wx = INPUT_VARIABLE(1); // input weights + const auto Wr = INPUT_VARIABLE(2); // recurrent weights + + int count = 3; + const auto b = hasBiases ? INPUT_VARIABLE(count++) : nullptr; // biases + const auto seqLen = hasSeqLen ? INPUT_VARIABLE(count++) : nullptr; // seqLen vector + const auto hI = hasInitH ? INPUT_VARIABLE(count++) : nullptr; // initial output + const auto cI = hasInitC ? INPUT_VARIABLE(count++) : nullptr; // initial cell state + const auto Wp = hasPH ? 
INPUT_VARIABLE(count++) : nullptr;    // peephole weights
+
+    REQUIRE_TRUE(cellClip == 0, 0, "LSTM_LAYER_MKLDNN operation: cell clipping is not supported currently !");
+    REQUIRE_TRUE(retFullSeq, 0, "LSTM_LAYER_MKLDNN operation: option to return the full time sequence output h must always be true for the mkl dnn library !");
+    REQUIRE_TRUE(hasPH == false, 0, "LSTM_LAYER_MKLDNN operation: mkl dnn library doesn't support peephole connections !");
+    REQUIRE_TRUE(hasSeqLen == false, 0, "LSTM_LAYER_MKLDNN operation: mkl dnn library doesn't support an array specifying the max time step for each example in the batch !");
+    REQUIRE_TRUE(dataFormat < 2, 0, "LSTM_LAYER_MKLDNN operation: wrong data format, only two formats are allowed for input/output tensors in mkl dnn library: TNC and NTC !");
+    REQUIRE_TRUE(directionMode < 4, 0, "LSTM_LAYER_MKLDNN operation: option for bidirectional extra output dimension is not valid in mkl dnn library !");
+    REQUIRE_TRUE((retLastH && retLastC) || (!retLastH && !retLastC), 0, "LSTM_LAYER_MKLDNN operation: only two options are supported: 1) return both the output and the cell state at the last time step; 2) return neither !");
+
+    count = 0;
+    auto h  = retFullSeq ? OUTPUT_VARIABLE(count++) : nullptr;   // output
+    auto hL = retLastH   ? OUTPUT_VARIABLE(count++) : nullptr;   // output at last step
+    auto cL = retLastC   ? OUTPUT_VARIABLE(count++) : nullptr;   // cell state at last step
+
+    // evaluate dimensions
+    const Nd4jLong sL   = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat);
+    const Nd4jLong bS   = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(-2);
+    const Nd4jLong nIn  = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(-1);
+    const Nd4jLong nOut = Wx->sizeAt(-1) / 4;
+
+    // inputs validations
+    if(directionMode < 2) {     // no bidirectional
+
+        // Wx validation
+        if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn)
+            REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
+        // Wr validation
+        if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut)
+            REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str());
+        // biases validation
+        if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut))
+            REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str());
+        // initial output validation
+        if(hI != nullptr && (hI->rankOf() != 2 || hI->sizeAt(0) != bS || hI->sizeAt(1) != nOut))
+            REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str());
+        // initial cell validation
+        if(cI != nullptr && (cI->rankOf() != 2 || cI->sizeAt(0) != bS || cI->sizeAt(1) != nOut))
+            REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str());
+    }
+    else {     // bidirectional
+        // Wx validation
+        if(Wx->rankOf() != 3 || Wx->sizeAt(0) != 2 || Wx->sizeAt(1) != nIn)
+            REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
instead !", ShapeUtils::shapeAsString({2, nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx)); + // Wr validation + if(Wr->rankOf() != 3 || Wr->sizeAt(0) != 2 || Wr->sizeAt(1) != nOut || Wr->sizeAt(2) != 4*nOut) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr)); + // biases validation + if(b != nullptr && (b->rankOf() != 2 || b->sizeAt(0) != 2 || b->sizeAt(1) != 4*nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 4*nOut}).c_str(), ShapeUtils::shapeAsString(b)); + // initial output validation + if(hI != nullptr && (hI->rankOf() != 3 || hI->sizeAt(0) != 2 || hI->sizeAt(1) != bS || hI->sizeAt(2) != nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI)); + // initial cell validation + if(cI != nullptr && (cI->rankOf() != 3 || cI->sizeAt(0) != 2 || cI->sizeAt(1) != bS || cI->sizeAt(2) != nOut)) + REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI)); + } + + std::vector params = {static_cast(dataFormat), static_cast(directionMode), static_cast(cellClip)}; + + const int dirDim = directionMode < 2 ? 1 : 2; // number of dimensions, 1 unidirectional, 2 for bidirectional + + // permut x and h to tnc format if they have ntc format + NDArray* xP(const_cast(x)), *hP(h); + if(dataFormat == 1) { + xP = new NDArray(x->permute({1,0,2})); // [bS, sL, nIn] -> [sL, bS, nIn] + hP = new NDArray(h->permute({1,0,2})); // [bS, sL, dirDim*nOn] -> [sL, bS, dirDim*nOn] + } + + // reshape arrays in accordance to mkl allowed formats + NDArray *WxR(nullptr), *WrR(nullptr), *bR(nullptr), *hIR(nullptr), *cIR(nullptr), *hLR(nullptr), *cLR(nullptr); + + WxR = new NDArray(Wx->reshape(Wx->ordering(), {1,dirDim,nIn,4,nOut})); + WrR = new NDArray(Wr->reshape(Wr->ordering(), {1,dirDim,nOut,4,nOut})); + if(b) + bR = new NDArray(b->reshape(b->ordering(), {1,dirDim,4,nOut})); + if(hI) + hIR = new NDArray(hI->reshape(hI->ordering(), {1,dirDim,bS,nOut})); + if(cI) + cIR = new NDArray(cI->reshape(cI->ordering(), {1,dirDim,bS,nOut})); + if(hL) + hLR = new NDArray(hL->reshape(hL->ordering(), {1,dirDim,bS,nOut})); + if(cL) + cLR = new NDArray(cL->reshape(cL->ordering(), {1,dirDim,bS,nOut})); + + lstmLayerMKLDNN(xP, WxR, WrR, bR, hIR, cIR, params, hP, hLR, cLR); + + delete WxR; + delete WrR; + delete bR; + delete hIR; + delete cIR; + delete hLR; + delete cLR; + + if(dataFormat == 1) { + delete xP; + delete hP; + } + + return Status::OK(); +} + +PLATFORM_CHECK(lstmLayer) { + // we don't want to use mkldnn if cpu doesn't support avx/avx2 + // if (::optimalLevel() < 2) { + // return false; + // } + + const auto hasBiases = B_ARG(0); // indicates whether biases array is provided + const auto hasInitH = B_ARG(2); // indicates whether initial output is provided + const auto hasInitC = B_ARG(3); // indicates whether initial cell state is provided + const auto retFullSeq = B_ARG(5); // indicates whether to return whole time sequence h {h_0, h_1, ... 
+    const auto retLastH   = B_ARG(6);   // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)
+    const auto retLastC   = B_ARG(7);   // indicates whether to return cell state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)
+
+    const auto x  = INPUT_VARIABLE(0);   // input
+    const auto Wx = INPUT_VARIABLE(1);   // input weights
+    const auto Wr = INPUT_VARIABLE(2);   // recurrent weights
+
+    int count = 3;
+    const auto b  = hasBiases ? INPUT_VARIABLE(count++) : nullptr;   // biases
+    const auto hI = hasInitH  ? INPUT_VARIABLE(count++) : nullptr;   // initial output
+    const auto cI = hasInitC  ? INPUT_VARIABLE(count++) : nullptr;   // initial cell state
+
+    count = 0;
+    auto h  = retFullSeq ? OUTPUT_VARIABLE(count++) : nullptr;   // output
+    auto hL = retLastH   ? OUTPUT_VARIABLE(count++) : nullptr;   // output at last step
+    auto cL = retLastC   ? OUTPUT_VARIABLE(count++) : nullptr;   // cell state at last step
+
+    DataType xType  = x->dataType();
+    DataType WxType = Wx->dataType();
+    DataType WrType = Wr->dataType();
+    DataType bType  = b  != nullptr ? b->dataType()  : (xType == DataType::HALF ? xType : DataType::FLOAT32);
+    DataType hIType = hI != nullptr ? hI->dataType() : xType;
+    DataType cIType = cI != nullptr ? cI->dataType() : xType;
+    DataType hType  = h  != nullptr ? h->dataType()  : xType;
+    DataType hLType = hL != nullptr ? hL->dataType() : xType;
+    DataType cLType = cL != nullptr ? cL->dataType() : xType;
+
+    return block.isUseMKLDNN() && (
+        (xType==DataType::FLOAT32 && WxType==DataType::FLOAT32 && WrType==DataType::FLOAT32 && bType==DataType::FLOAT32 && hIType==DataType::FLOAT32 && cIType==DataType::FLOAT32 && hType==DataType::FLOAT32 && hLType==DataType::FLOAT32 && cLType==DataType::FLOAT32) ||
+        (xType==DataType::HALF && WxType==DataType::HALF && WrType==DataType::HALF && bType==DataType::HALF && hIType==DataType::HALF && cIType==DataType::HALF && hType==DataType::HALF && hLType==DataType::HALF && cLType==DataType::HALF) ||
+        (xType==DataType::UINT8 && WxType==DataType::INT8 && WrType==DataType::INT8 && bType==DataType::FLOAT32 && hIType==DataType::UINT8 && cIType==DataType::UINT8 && ((hType==DataType::FLOAT32 && hLType==DataType::FLOAT32 && cLType==DataType::FLOAT32) || (hType==DataType::UINT8 && hLType==DataType::UINT8 && cLType==DataType::UINT8)))
+    );
+}
+
+
+}
+}
+}
diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h
index 4e79974a5..8e09624e9 100644
--- a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h
+++ b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h
@@ -63,6 +63,8 @@ namespace nd4j{
         DECLARE_PLATFORM(lrn);
 
         DECLARE_PLATFORM(batchnorm_new);
+
+        DECLARE_PLATFORM(lstmLayer);
     }
 }
diff --git a/libnd4j/include/ops/ops.h b/libnd4j/include/ops/ops.h
index 132b58033..601481b21 100644
--- a/libnd4j/include/ops/ops.h
+++ b/libnd4j/include/ops/ops.h
@@ -1857,6 +1857,17 @@ namespace simdOps {
         }
     };
 
+    template <typename X>
+    class Affine {
+    public:
+        no_op_exec_special_same
+        no_op_exec_special_same_cuda
+
+        op_def static X op(X d1, X *params) {
+            return params[0] * d1 + params[1];
+        }
+    };
+
     template <typename X>
     class SigmoidDerivative {
     public:
@@ -2051,6 +2062,17 @@ namespace simdOps {
         }
     };
 
+    template <typename X>
+    class ScaledTanh {
+    public:
+        no_op_exec_special_same
+        no_op_exec_special_same_cuda
+
+        op_def static X op(X d1, X *params) {
+            return params[0] * nd4j::math::nd4j_tanh<X>(params[1] * d1);
+        }
+    };
+
     template <typename X>
     class RectifiedTanh {
     public:
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp
index 2ef9e2309..9d460f152 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp
@@ -983,5 +983,952 @@ TEST_F(DeclarableOpsTests13, mergemax_2) {
     ASSERT_EQ(20, status);
 }
 
+///////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests13, lstmLayer_1) {
+
+    const int sL   = 5;
+    const int bS   = 3;
+    const int nIn  = 3;
+    const int nOut = 3;
+
+    // input arguments
+
+    const int dataFormat    = 0;    // [sL,bS,nIn]
+    const int directionMode = 0;    // forward
+    const int gateAct       = 2;    // sigmoid activation for input (i), forget (f) and output (o) gates
+    const int cellAct       = 0;    // tanh activation for cell state
+    const int outAct        = 0;    // tanh activation for output
+
+    const bool hasBiases  = true;   // biases array is provided
+    const bool hasSeqLen  = false;  // seqLen array is not provided
+    const auto hasInitH   = true;   // initial output is provided
+    const auto hasInitC   = true;   // initial cell state is provided
+    const auto hasPH      = false;  // peephole connections are absent
+    const auto retFullSeq = true;   // return whole h {h_0, h_1, ..., h_sL-1}, [sL,bS,nOut]
+    const auto retLastH   = true;   // return output at last time step, [bS,nOut]
+    const auto retLastC   = true;   // return cell state at last time step, [bS,nOut]
+
+    const double cellClip = 0;      // do not apply clipping
+
+    NDArray x('c', {sL, bS, nIn}, nd4j::DataType::FLOAT32);
+    NDArray Wx('c', {nIn, 4*nOut}, nd4j::DataType::FLOAT32);
+    NDArray Wr('c', {nOut, 4*nOut}, nd4j::DataType::FLOAT32);
+    NDArray b('c', {4*nOut}, nd4j::DataType::FLOAT32);
+    NDArray hI('c', {bS, nOut}, nd4j::DataType::FLOAT32);
+    NDArray cI('c', {bS, nOut}, nd4j::DataType::FLOAT32);
+
+    x.linspace(0.5, 0.5);
+    Wx = 0.003;
+    Wr = 0.006;
+    b  = 0.5;
+    hI = 1.;
+    cI = 2.;
+
+    std::initializer_list<double>   tArgs = {cellClip};
+    std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
+    std::initializer_list<bool>     bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
+
+    auto expH = NDArrayFactory::create<float>('c', {sL, bS, nOut}, {0.57574,0.57574,0.57574,0.58006,0.58006,0.58006,0.58434,0.58434,0.58434,
+                                                                    0.55114,0.55114,0.55114,0.55732,0.55732,0.55732,0.56338,0.56338,0.56338,
+                                                                    0.53763,0.53763,0.53763,0.54534,0.54534,0.54534,0.55287,0.55287,0.55287,
+                                                                    0.53626,0.53626,0.53626,0.54487,0.54487,0.54487,0.55327,0.55327,0.55327,
+                                                                    0.54484,0.54484,0.54484,0.55379,0.55379,0.55379,0.5625 ,0.5625 ,0.5625});
+
+    auto expClast = NDArrayFactory::create<float>('c', {bS, nOut}, {1.1589154,1.1589154,1.1589154,1.1892855,1.1892855,1.1892855,1.219861 ,1.219861 ,1.219861});
+
+    nd4j::ops::lstmLayer op;
+    auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);
+
+    ASSERT_EQ(ND4J_STATUS_OK, results->status());
+
+    auto *h  = results->at(0);
+    auto *cL = results->at(2);
+
+    ASSERT_TRUE(expH.isSameShape(h));
+    ASSERT_TRUE(expH.equalsTo(h));
+
+    ASSERT_TRUE(expClast.isSameShape(cL));
+    ASSERT_TRUE(expClast.equalsTo(cL));
+
+    delete results;
+}
+
+///////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests13, lstmLayer_2) {
+
+    const int sL   = 5;
+    const int bS   = 3;
+    const int nIn  = 3;
+    const int nOut = 3;
+
+    // input arguments
+
+    const int dataFormat    = 1;    // [bS,sL,nIn]
+    const int directionMode = 0;    // forward
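+    // The argument packing below follows the convention used throughout these tests (a summary
+    // inferred from the op implementation above rather than from separate documentation):
+    //   tArgs = {cellClip};                                                                T_ARG(0)
+    //   iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};                     INT_ARG(0..4)
+    //   bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
+    // optional inputs follow x, Wx, Wr in the fixed order b, seqLen, hI, cI, Wp, but only those
+    // whose bArgs flag is true are actually passed; outputs h, hL, cL obey the same skip rule.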
const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = false; // peephole connections are absent + const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {bS, sL, nIn}, nd4j::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray b('c', {4*nOut}, nd4j::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, nd4j::DataType::FLOAT32); + NDArray cI('c', {bS, nOut}, nd4j::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx = 0.003; + Wr = 0.006; + b = 0.5; + hI = 1.; + cI = 2.; + + std::initializer_list tArgs = {cellClip}; + std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + auto expH = NDArrayFactory::create('c', {bS, sL, nOut}, {0.575735, 0.575735, 0.575735, 0.541562, 0.541562, 0.541562, 0.514003, 0.514003, 0.514003, 0.495597, 0.495597, 0.495597, 0.485999, 0.485999, 0.485999, + 0.596965, 0.596965, 0.596965, 0.571978, 0.571978, 0.571978, 0.552888, 0.552888, 0.552888, 0.540606, 0.540606, 0.540606, 0.534764, 0.534764, 0.534764, + 0.61725 , 0.61725 , 0.61725 , 0.599828, 0.599828, 0.599828, 0.587627, 0.587627, 0.587627, 0.580408, 0.580408, 0.580408, 0.577735, 0.577735, 0.577735}); + + auto expClast = NDArrayFactory::create('c', {bS, nOut}, {0.996965, 0.996965, 0.996965, 1.146756, 1.146756, 1.146756, 1.301922, 1.301922, 1.301922}); + + nd4j::ops::lstmLayer op; + auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto *h = results->at(0); + auto *cL = results->at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expClast.isSameShape(cL)); + ASSERT_TRUE(expClast.equalsTo(cL)); + + delete results; +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_3) { + + const int sL = 5; + const int bS = 2; + const int nIn = 4; + const int nOut = 3; + + // input arguments + + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 1; // backward + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = false; // peephole connections are absent + const auto retFullSeq = true; // return whole h {h_0, h_1, ... 
, h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL,bS, nIn}, nd4j::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray b('c', {4*nOut}, nd4j::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, nd4j::DataType::FLOAT32); + NDArray cI('c', {bS, nOut}, nd4j::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx = 0.003; + Wr = 0.006; + b = 0.5; + hI = 1.; + cI = 2.; + + std::initializer_list tArgs = {cellClip}; + std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + NDArray expH('c', {sL, bS, nOut}, {0.493883, 0.493883, 0.493883, 0.510990, 0.510990, 0.510990, 0.534701, 0.534701, 0.534701, 0.549139, + 0.549139, 0.549139, 0.571900, 0.571900, 0.571900, 0.583561, 0.583561, 0.583561, 0.605106, 0.605106, + 0.605106, 0.614114, 0.614114, 0.614114, 0.635354, 0.635354, 0.635354, 0.642045, 0.642045, 0.642045}, nd4j::DataType::FLOAT32); + + NDArray expHL('c', {bS, nOut}, {0.493883, 0.493883, 0.493883, 0.510990, 0.510990, 0.510990}, nd4j::DataType::FLOAT32); + NDArray expCL('c', {bS, nOut}, {1.061274, 1.061274, 1.061274, 1.115888, 1.115888, 1.115888}, nd4j::DataType::FLOAT32); + + nd4j::ops::lstmLayer op; + auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto h = results->at(0); + auto hL = results->at(1); + auto cL = results->at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); + + delete results; +} + + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_4) { + + const int sL = 5; + const int bS = 2; + const int nIn = 4; + const int nOut = 3; + + // input arguments + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 3; // bidirectional concat + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = false; // peephole connections are absent + const auto retFullSeq = true; // return whole h {h_0, h_1, ... 
, h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, nd4j::DataType::FLOAT32); + NDArray Wx('c', {2,nIn, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray Wr('c', {2,nOut, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray b('c', {2,4*nOut}, nd4j::DataType::FLOAT32); + NDArray hI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32); + NDArray cI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx({0,1, 0,0, 0,0}) = 0.003; + Wx({1,2, 0,0, 0,0}) = -0.003; + Wr({0,1, 0,0, 0,0}) = 0.006; + Wr({1,2, 0,0, 0,0}) = -0.006; + b({0,1, 0,0}) = 0.5; + b({1,2, 0,0}) = -0.5; + hI({0,1, 0,0, 0,0}) = 1; + hI({1,2, 0,0, 0,0}) = -1; + cI({0,1, 0,0, 0,0}) = 2; + cI({1,2, 0,0, 0,0}) = -2; + + std::initializer_list tArgs = {cellClip}; + std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + NDArray expH('c', {sL, bS, 2*nOut}, {0.577661, 0.577661, 0.577661, -0.107642, -0.107642, -0.107642, 0.585289, 0.585289, 0.585289, + -0.106937, -0.106937, -0.106937, 0.556517, 0.556517, 0.556517, -0.111647, -0.111647, -0.111647, + 0.567274, 0.567274, 0.567274, -0.110214, -0.110214, -0.110214, 0.547395, 0.547395, 0.547395, + -0.123305, -0.123305, -0.123305, 0.560640, 0.560640, 0.560640, -0.120862, -0.120862, -0.120862, + 0.550714, 0.550714, 0.550714, -0.156223, -0.156223, -0.156223, 0.565308, 0.565308, 0.565308, + -0.152313, -0.152313, -0.152313, 0.563741, 0.563741, 0.563741, -0.234128, -0.234128, -0.234128, + 0.578676, 0.578676, 0.578676, -0.228917, -0.228917, -0.228917}, nd4j::DataType::FLOAT32); + + NDArray expHL('c', {2,bS, nOut}, {0.563741, 0.563741, 0.563741, 0.578676, 0.578676, 0.578676, -0.107642, + -0.107642, -0.107642, -0.106937, -0.106937, -0.106937}, nd4j::DataType::FLOAT32); + NDArray expCL('c', {2,bS, nOut}, {1.217757, 1.217757, 1.217757, 1.272398, 1.272398, 1.272398, -0.295768, + -0.295768, -0.295768, -0.298453, -0.298453, -0.298453}, nd4j::DataType::FLOAT32); + + nd4j::ops::lstmLayer op; + auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto h = results->at(0); + auto hL = results->at(1); + auto cL = results->at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); + + delete results; +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_5) { + + const int sL = 5; + const int bS = 2; + const int nIn = 4; + const int nOut = 3; + + // input arguments + const int dataFormat = 1; // [bS,sL,nIn] + const int directionMode = 3; // bidirectional concat + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = false; // peephole connections are 
absent + const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {bS, sL, nIn}, nd4j::DataType::FLOAT32); + NDArray Wx('c', {2,nIn, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray Wr('c', {2,nOut, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray b('c', {2,4*nOut}, nd4j::DataType::FLOAT32); + NDArray hI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32); + NDArray cI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx({0,1, 0,0, 0,0}) = 0.003; + Wx({1,2, 0,0, 0,0}) = -0.003; + Wr({0,1, 0,0, 0,0}) = 0.006; + Wr({1,2, 0,0, 0,0}) = -0.006; + b({0,1, 0,0}) = 0.5; + b({1,2, 0,0}) = -0.5; + hI({0,1, 0,0, 0,0}) = 1; + hI({1,2, 0,0, 0,0}) = -1; + cI({0,1, 0,0, 0,0}) = 2; + cI({1,2, 0,0, 0,0}) = -2; + + std::initializer_list tArgs = {cellClip}; + std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + NDArray expH('c', {bS, sL, 2*nOut}, {0.577661, 0.577661, 0.577661, -0.107659, -0.107659, -0.107659, 0.548099, 0.548099, 0.548099, -0.113406, -0.113406, -0.113406, + 0.526881, 0.526881, 0.526881, -0.12883 , -0.12883 , -0.12883 , 0.515882, 0.515882, 0.515882, -0.16868 , -0.16868 , -0.16868 , + 0.51409 , 0.51409 , 0.51409 , -0.255185, -0.255185, -0.255185, 0.614599, 0.614599, 0.614599, -0.102739, -0.102739, -0.102739, + 0.599572, 0.599572, 0.599572, -0.105802, -0.105802, -0.105802,0.591089, 0.591089, 0.591089, -0.116681, -0.116681, -0.116681, + 0.588694, 0.588694, 0.588694, -0.149201, -0.149201, -0.149201,0.591492, 0.591492, 0.591492, -0.228917, -0.228917, -0.228917}, nd4j::DataType::FLOAT32); + + NDArray expHL('c', {2,bS, nOut}, {0.51409 , 0.51409 , 0.51409 , 0.591492, 0.591492, 0.591492, + -0.107659, -0.107659, -0.107659, -0.102739, -0.102739, -0.102739}, nd4j::DataType::FLOAT32); + NDArray expCL('c', {2,bS, nOut}, {1.07293 , 1.07293 , 1.07293,1.346609, 1.346609, 1.346609, + -0.295811, -0.295811, -0.295811,-0.305394, -0.305394, -0.305394}, nd4j::DataType::FLOAT32); + + nd4j::ops::lstmLayer op; + auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto h = results->at(0); + auto hL = results->at(1); + auto cL = results->at(2); + + // h->printBuffer(); + // hL->printBuffer(); + // cL->printBuffer(); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); + + delete results; +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_6) { + + const int sL = 5; + const int bS = 2; + const int nIn = 4; + const int nOut = 3; + + // input arguments + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 2; // bidirectional sum + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is 
provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = false; // peephole connections are absent + const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, nd4j::DataType::FLOAT32); + NDArray Wx('c', {2,nIn, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray Wr('c', {2,nOut, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray b('c', {2,4*nOut}, nd4j::DataType::FLOAT32); + NDArray hI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32); + NDArray cI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx({0,1, 0,0, 0,0}) = 0.003; + Wx({1,2, 0,0, 0,0}) = -0.003; + Wr({0,1, 0,0, 0,0}) = 0.006; + Wr({1,2, 0,0, 0,0}) = -0.006; + b({0,1, 0,0}) = 0.5; + b({1,2, 0,0}) = -0.5; + hI({0,1, 0,0, 0,0}) = 1; + hI({1,2, 0,0, 0,0}) = -1; + cI({0,1, 0,0, 0,0}) = 2; + cI({1,2, 0,0, 0,0}) = -2; + + std::initializer_list tArgs = {cellClip}; + std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + NDArray expH('c', {sL, bS, nOut}, {0.470019, 0.470019, 0.470019, 0.478352, 0.478352, 0.478352, 0.444871, 0.444871, 0.444871, 0.457060, + 0.457060, 0.457060, 0.424090, 0.424090, 0.424090, 0.439778, 0.439778, 0.439778, 0.394491, 0.394491, + 0.394491, 0.412995, 0.412995, 0.412995, 0.329613, 0.329613, 0.329613, 0.349760, 0.349760, 0.349760}, nd4j::DataType::FLOAT32); + + NDArray expHL('c', {2,bS, nOut}, {0.563741, 0.563741, 0.563741, 0.578676, 0.578676, 0.578676, -0.107642, + -0.107642, -0.107642, -0.106937, -0.106937, -0.106937}, nd4j::DataType::FLOAT32); + NDArray expCL('c', {2,bS, nOut}, {1.217757, 1.217757, 1.217757, 1.272398, 1.272398, 1.272398, -0.295768, + -0.295768, -0.295768, -0.298453, -0.298453, -0.298453}, nd4j::DataType::FLOAT32); + + nd4j::ops::lstmLayer op; + auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto h = results->at(0); + auto hL = results->at(1); + auto cL = results->at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); + + delete results; +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_7) { + #ifndef HAVE_MKLDNN + + const int sL = 5; + const int bS = 2; + const int nIn = 4; + const int nOut = 3; + + // input arguments + + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 0; // forward + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = true; // return whole h {h_0, h_1, ... 
, h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, nd4j::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray b('c', {4*nOut}, nd4j::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, nd4j::DataType::FLOAT32); + NDArray cI('c', {bS, nOut}, nd4j::DataType::FLOAT32); + NDArray Wp('c', {3*nOut}, nd4j::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx = 0.003; + Wr = 0.006; + b = 0.5; + hI = 1.; + cI = 2.; + Wp = -0.05; + + std::initializer_list tArgs = {cellClip}; + std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + NDArray expH('c', {sL, bS, nOut}, {0.55533 , 0.55533 , 0.55533 , 0.562925, 0.562925, 0.562925, 0.531795, 0.531795, 0.531795, 0.542556, + 0.542556, 0.542556, 0.521466, 0.521466, 0.521466, 0.534638, 0.534638, 0.534638, 0.524805, 0.524805, + 0.524805, 0.539187, 0.539187, 0.539187, 0.538309, 0.538309, 0.538309, 0.552923, 0.552923, 0.552923}, nd4j::DataType::FLOAT32); + + NDArray expHL('c', {bS, nOut}, {0.538309, 0.538309, 0.538309,0.552923, 0.552923, 0.552923}, nd4j::DataType::FLOAT32); + NDArray expCL('c', {bS, nOut}, {1.147089, 1.147089, 1.147089,1.197228, 1.197228, 1.197228}, nd4j::DataType::FLOAT32); + + nd4j::ops::lstmLayer op; + auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto h = results->at(0); + auto hL = results->at(1); + auto cL = results->at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); + + delete results; + #endif +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_8) { + #ifndef HAVE_MKLDNN + + const int sL = 5; + const int bS = 2; + const int nIn = 4; + const int nOut = 3; + + // input arguments + + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 1; // backward + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = true; // return whole h {h_0, h_1, ... 
, h_sL-1}, [sL,bS,nOut]
+    const auto retLastH   = true;   // return output at last time step, [bS,nOut]
+    const auto retLastC   = true;   // return cell state at last time step, [bS,nOut]
+
+    const double cellClip = 1.;     // clip cell state at this value
+
+    NDArray x('c', {sL, bS, nIn}, nd4j::DataType::FLOAT32);
+    NDArray Wx('c', {nIn, 4*nOut}, nd4j::DataType::FLOAT32);
+    NDArray Wr('c', {nOut, 4*nOut}, nd4j::DataType::FLOAT32);
+    NDArray b('c', {4*nOut}, nd4j::DataType::FLOAT32);
+    NDArray hI('c', {bS, nOut}, nd4j::DataType::FLOAT32);
+    NDArray cI('c', {bS, nOut}, nd4j::DataType::FLOAT32);
+    NDArray Wp('c', {3*nOut}, nd4j::DataType::FLOAT32);
+
+    x.linspace(0.5, 0.5);
+    Wx = 0.003;
+    Wr = 0.006;
+    b  = 0.5;
+    hI = 1.;
+    cI = 2.;
+    Wp = -0.05;
+
+    std::initializer_list<double>   tArgs = {cellClip};
+    std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
+    std::initializer_list<bool>     bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
+
+    NDArray expH('c', {sL, bS, nOut}, {0.436221, 0.436221, 0.436221, 0.450573, 0.450573, 0.450573, 0.463602, 0.463602, 0.463602, 0.474674, 0.474674, 0.474674,
+                                       0.484039, 0.484039, 0.484039, 0.490679, 0.490679, 0.490679, 0.494871, 0.494871, 0.494871, 0.499028, 0.499028, 0.499028,
+                                       0.504649, 0.504649, 0.504649, 0.508719, 0.508719, 0.508719}, nd4j::DataType::FLOAT32);
+
+    NDArray expHL('c', {bS, nOut}, {0.436221, 0.436221, 0.436221, 0.450573, 0.450573, 0.450573}, nd4j::DataType::FLOAT32);
+    NDArray expCL('c', {bS, nOut}, {0.879804, 0.879804, 0.879804, 0.914666, 0.914666, 0.914666}, nd4j::DataType::FLOAT32);
+
+    nd4j::ops::lstmLayer op;
+    auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);
+
+    ASSERT_EQ(ND4J_STATUS_OK, results->status());
+
+    auto h  = results->at(0);
+    auto hL = results->at(1);
+    auto cL = results->at(2);
+
+    ASSERT_TRUE(expH.isSameShape(h));
+    ASSERT_TRUE(expH.equalsTo(h));
+
+    ASSERT_TRUE(expHL.isSameShape(hL));
+    ASSERT_TRUE(expHL.equalsTo(hL));
+
+    ASSERT_TRUE(expCL.isSameShape(cL));
+    ASSERT_TRUE(expCL.equalsTo(cL));
+
+    delete results;
+    #endif
+}
+
+///////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests13, lstmLayer_9) {
+    #ifndef HAVE_MKLDNN
+
+    const int sL   = 5;
+    const int bS   = 2;
+    const int nIn  = 4;
+    const int nOut = 3;
+
+    // input arguments
+    const int dataFormat    = 0;    // [sL,bS,nIn]
+    const int directionMode = 3;    // bidirectional concat
+    const int gateAct       = 2;    // sigmoid activation for input (i), forget (f) and output (o) gates
+    const int cellAct       = 0;    // tanh activation for cell state
+    const int outAct        = 0;    // tanh activation for output
+
+    const bool hasBiases  = true;   // biases array is provided
+    const bool hasSeqLen  = false;  // seqLen array is not provided
+    const auto hasInitH   = true;   // initial output is provided
+    const auto hasInitC   = true;   // initial cell state is provided
+    const auto hasPH      = true;   // peephole connections are present
+    const auto retFullSeq = true;   // return whole h {h_0, h_1, ... 
, h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, nd4j::DataType::FLOAT32); + NDArray Wx('c', {2,nIn, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray Wr('c', {2,nOut, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray b('c', {2,4*nOut}, nd4j::DataType::FLOAT32); + NDArray hI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32); + NDArray cI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32); + NDArray Wp('c', {2,3*nOut}, nd4j::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx({0,1, 0,0, 0,0}) = 0.003; + Wx({1,2, 0,0, 0,0}) = -0.003; + Wr({0,1, 0,0, 0,0}) = 0.006; + Wr({1,2, 0,0, 0,0}) = -0.006; + b({0,1, 0,0}) = 0.5; + b({1,2, 0,0}) = -0.5; + hI({0,1, 0,0, 0,0}) = 1; + hI({1,2, 0,0, 0,0}) = -1; + cI({0,1, 0,0, 0,0}) = 2; + cI({1,2, 0,0, 0,0}) = -2; + Wp({0,1, 0,0}) = -0.05; + Wp({1,2, 0,0}) = 0.05; + + std::initializer_list tArgs = {cellClip}; + std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + NDArray expH('c', {sL, bS, 2*nOut}, { 0.55533 , 0.55533 , 0.55533 , -0.104502, -0.104502, -0.104502, 0.562925, 0.562925, 0.562925, -0.103843, -0.103843, -0.103843, + 0.531795, 0.531795, 0.531795, -0.107456, -0.107456, -0.107456,0.542556, 0.542556, 0.542556, -0.106139, -0.106139, -0.106139, + 0.521466, 0.521466, 0.521466, -0.11681 , -0.11681 , -0.11681 , 0.534638, 0.534638, 0.534638, -0.11458 , -0.11458 , -0.11458 , + 0.524805, 0.524805, 0.524805, -0.145177, -0.145177, -0.145177,0.539187, 0.539187, 0.539187, -0.14157 , -0.14157 , -0.14157 , + 0.538309, 0.538309, 0.538309, -0.218056, -0.218056, -0.218056,0.552923, 0.552923, 0.552923, -0.213068, -0.213068, -0.213068}, nd4j::DataType::FLOAT32); + + NDArray expHL('c', {2,bS, nOut}, {0.538309, 0.538309, 0.538309, 0.552923, 0.552923, 0.552923, -0.104502, -0.104502, -0.104502, + -0.103843, -0.103843, -0.103843}, nd4j::DataType::FLOAT32); + NDArray expCL('c', {2,bS, nOut}, {1.147089, 1.147089, 1.147089, 1.197228, 1.197228, 1.197228, -0.289425, -0.289425, -0.289425, + -0.292174, -0.292174, -0.292174}, nd4j::DataType::FLOAT32); + + nd4j::ops::lstmLayer op; + auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto h = results->at(0); + auto hL = results->at(1); + auto cL = results->at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); + + delete results; + #endif +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_10) { + #ifndef HAVE_MKLDNN + + const int sL = 6; + const int bS = 5; + const int nIn = 4; + const int nOut = 3; + + // input arguments + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 0; // forward + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = true; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + 
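+    // A note on the seqLen input used by this and the following tests (a reading of the expected
+    // values below, not a statement from separate docs): seqLen holds one entry per batch example,
+    // here {0, 1, 2, 3, 5}, and example k is processed for seqLen[k] time steps only (starting
+    // from t = 0 in this forward test). Rows of h past that step are left at 0 (example 0 stays
+    // all-zero), while hL and cL keep the state from the last processed step of each example.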
const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, nd4j::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray b('c', {4*nOut}, nd4j::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, nd4j::DataType::FLOAT32); + NDArray cI('c', {bS, nOut}, nd4j::DataType::FLOAT32); + NDArray seqLen('c', {bS}, {0,1,2,3,5}, nd4j::DataType::FLOAT32); + NDArray Wp('c', {3*nOut}, nd4j::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx = 0.003; + Wr = 0.006; + b = 0.5; + hI = 1.; + cI = 2.; + Wp = -0.05; + + std::initializer_list tArgs = {cellClip}; + std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + NDArray expH('c', {sL, bS, nOut}, {0., 0., 0., 0.562925, 0.562925, 0.562925, 0.570404, 0.570404, 0.570404, 0.57777 , 0.57777 , 0.57777 , 0.585023, 0.585023, 0.585023, + 0., 0., 0., 0., 0., 0., 0.576568, 0.576568, 0.576568, 0.586163, 0.586163, 0.586163, 0.595462, 0.595462, 0.595462, 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0.611224, 0.611224, 0.611224, 0.621298, 0.621298, 0.621298, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0.655858, 0.655858, 0.655858, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.692315, 0.692315, 0.692315, 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0.}, nd4j::DataType::FLOAT32); + + NDArray expHL('c', {bS, nOut}, {0., 0., 0., 0.562925, 0.562925, 0.562925, 0.576568, 0.576568, 0.576568, 0.611224, 0.611224, 0.611224, 0.692315, 0.692315, 0.692315}, nd4j::DataType::FLOAT32); + NDArray expCL('c', {bS, nOut}, {0., 0., 0., 1.534275, 1.534275, 1.534275, 1.40183, 1.40183, 1.40183, 1.449675, 1.449675, 1.449675, 1.767702, 1.767702, 1.767702}, nd4j::DataType::FLOAT32); + + nd4j::ops::lstmLayer op; + auto results = op.execute({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto h = results->at(0); + auto hL = results->at(1); + auto cL = results->at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); + + delete results; + #endif +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_11) { + #ifndef HAVE_MKLDNN + + const int sL = 6; + const int bS = 5; + const int nIn = 4; + const int nOut = 3; + + // input arguments + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 1; // backward + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = true; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto 
hasPH = true; // peephole connections are absent + const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, nd4j::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray b('c', {4*nOut}, nd4j::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, nd4j::DataType::FLOAT32); + NDArray cI('c', {bS, nOut}, nd4j::DataType::FLOAT32); + NDArray seqLen('c', {bS}, {0,1,2,3,5}, nd4j::DataType::FLOAT32); + NDArray Wp('c', {3*nOut}, nd4j::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx = 0.003; + Wr = 0.006; + b = 0.5; + hI = 1.; + cI = 2.; + Wp = -0.05; + + std::initializer_list tArgs = {cellClip}; + std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + NDArray expH('c', {sL, bS, nOut}, {0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.61209, + 0.61209, 0.61209,0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.652042, 0.652042, 0.652042, 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0.677708, 0.677708, 0.677708, 0.684177, 0.684177, 0.684177, 0., 0., 0.,0., 0., 0.,0.699627, 0.699627, + 0.699627,0.705371, 0.705371, 0.705371,0.710989, 0.710989, 0.710989, 0., 0., 0., 0.719014, 0.719014, 0.719014, 0.724087, + 0.724087, 0.724087, 0.729084, 0.729084, 0.729084, 0.734004, 0.734004, 0.734004 }, nd4j::DataType::FLOAT32); + + NDArray expHL('c', {bS, nOut}, {0., 0., 0., 0.719014, 0.719014, 0.719014, 0.699627, 0.699627, 0.699627, 0.677708, 0.677708, 0.677708, 0.61209, 0.61209, 0.61209}, nd4j::DataType::FLOAT32); + NDArray expCL('c', {bS, nOut}, {0., 0., 0., 2.092814, 2.092814, 2.092814, 2.08832, 2.08832, 2.08832, 2.009851, 2.009851, 2.009851, 1.646034, 1.646034, 1.646034}, nd4j::DataType::FLOAT32); + + nd4j::ops::lstmLayer op; + auto results = op.execute({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto h = results->at(0); + auto hL = results->at(1); + auto cL = results->at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); + + delete results; + #endif +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_12) { + #ifndef HAVE_MKLDNN + + const int sL = 6; + const int bS = 5; + const int nIn = 4; + const int nOut = 3; + + // input arguments + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 3; // bidirectional concat + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = true; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq 
= true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, nd4j::DataType::FLOAT32); + NDArray Wx('c', {2,nIn, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray Wr('c', {2,nOut, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray b('c', {2,4*nOut}, nd4j::DataType::FLOAT32); + NDArray hI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32); + NDArray cI('c', {2,bS, nOut}, nd4j::DataType::FLOAT32); + NDArray seqLen('c', {bS}, {0,1,2,3,5}, nd4j::DataType::FLOAT32); + NDArray Wp('c', {2,3*nOut}, nd4j::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx({0,1, 0,0, 0,0}) = 0.003; + Wx({1,2, 0,0, 0,0}) = -0.003; + Wr({0,1, 0,0, 0,0}) = 0.006; + Wr({1,2, 0,0, 0,0}) = -0.006; + b({0,1, 0,0}) = 0.5; + b({1,2, 0,0}) = -0.5; + hI({0,1, 0,0, 0,0}) = 1; + hI({1,2, 0,0, 0,0}) = -1; + cI({0,1, 0,0, 0,0}) = 2; + cI({1,2, 0,0, 0,0}) = -2; + Wp({0,1, 0,0}) = -0.05; + Wp({1,2, 0,0}) = 0.05; + + std::initializer_list tArgs = {cellClip}; + std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + NDArray expH('c', {sL, bS, 2*nOut}, {0., 0., 0., 0., 0., 0., 0.562925, 0.562925, 0.562925, -0.25361 , -0.25361 , -0.25361 , 0.570404, 0.570404, 0.570404, -0.157103, + -0.157103, -0.157103, 0.57777 , 0.57777 , 0.57777 , -0.116502, -0.116502, -0.116502,0.585023, 0.585023, 0.585023, -0.100025, + -0.100025, -0.100025, 0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0., 0.576568, 0.576568, 0.576568, -0.223072, -0.223072, -0.223072, + 0.586163, 0.586163, 0.586163, -0.135714, -0.135714, -0.135714,0.595462, 0.595462, 0.595462, -0.094438, -0.094438, -0.094438, + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.611224, 0.611224, 0.611224, -0.193473, -0.193473, -0.193473, + 0.621298, 0.621298, 0.621298, -0.090626, -0.090626, -0.090626, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0.655858, 0.655858, 0.655858, -0.098015, -0.098015, -0.098015, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.692315, 0.692315, 0.692315, -0.143704, -0.143704, -0.143704, 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}, nd4j::DataType::FLOAT32); + + NDArray expHL('c', {2,bS, nOut}, {0., 0., 0., 0.562925, 0.562925, 0.562925, 0.576568, 0.576568, 0.576568, 0.611224, 0.611224, 0.611224, 0.692315, 0.692315, 0.692315, + 0., 0., 0., -0.25361 , -0.25361 , -0.25361 , -0.157103, -0.157103, -0.157103,-0.116502, -0.116502, -0.116502, -0.100025, -0.100025, -0.100025}, nd4j::DataType::FLOAT32); + NDArray expCL('c', {2,bS, nOut}, {0., 0., 0.,1.534275, 1.534275, 1.534275,1.40183 , 1.40183 , 1.40183 ,1.449675, 1.449675, 1.449675,1.767702, 1.767702, 1.767702, + 0., 0., 0.,-0.86636 , -0.86636 , -0.86636 ,-0.470245, -0.470245, -0.470245,-0.341856, -0.341856, -0.341856,-0.294986, -0.294986, -0.294986}, nd4j::DataType::FLOAT32); + + nd4j::ops::lstmLayer op; + auto results = op.execute({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto h = results->at(0); + auto hL = results->at(1); + auto cL = results->at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + 
ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); + + delete results; + #endif +} diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp index fc7f29e3c..6eabc964a 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp @@ -505,9 +505,9 @@ TEST_F(DeclarableOpsTests15, test_lstmBlock_1) { } TEST_F(DeclarableOpsTests15, test_lstmBlock_2) { - int seqLen = 32; - int bS = 64; - int nIn = 32; + int seqLen = 8; + int bS = 16; + int nIn = 8; auto x0 = NDArrayFactory::create(5); auto x1 = NDArrayFactory::create('f', {bS, nIn, seqLen}); diff --git a/libnd4j/tests_cpu/layers_tests/HelpersTests1.cpp b/libnd4j/tests_cpu/layers_tests/HelpersTests1.cpp index 9db8a5f06..2ed43d08a 100644 --- a/libnd4j/tests_cpu/layers_tests/HelpersTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/HelpersTests1.cpp @@ -28,6 +28,7 @@ #include #include #include +#include using namespace nd4j; @@ -2342,5 +2343,155 @@ TEST_F(HelpersTests1, softmaxDerivative_3) { ASSERT_TRUE(expOutput.equalsTo(output)); } +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, lstmLayerCell_1) { + const int bS = 2; + const int nIn = 10; + const int nOut = 4; + + const float dataFormat = 0; // is ignored in cell op + const float cellClip = 5; // clipping value + const float gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const float gateAlpha = 0; // alpha value for activation for gates, not required for sigmoid + const float gateBeta = 0; // beta value for activation for gates, not required for sigmoid + const float cellAct = 0; // tanh activation for cell state + const float cellAlpha = 0; // alpha value for cell state activation, not required for tanh + const float cellBeta = 0; // beta value for cell state activation, not required for tanh + const float outAct = 0; // tanh activation for output + const float outAlpha = 0; // alpha value for output activation, not required for tanh + const float outBeta = 0; // beta value for output activation, not required for tanh + + NDArray x ('c', {bS, nIn}, nd4j::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4*nOut}, nd4j::DataType::FLOAT32); + NDArray b ('c', {4*nOut}, nd4j::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, nd4j::DataType::FLOAT32); + NDArray cI('c', {bS, nOut}, nd4j::DataType::FLOAT32); + NDArray Wp('c', {3*nOut}, nd4j::DataType::FLOAT32); + + NDArray h('c', {bS, nOut}, nd4j::DataType::FLOAT32); + NDArray c('c', {bS, nOut}, nd4j::DataType::FLOAT32); + + NDArray expH('c', {bS, nOut}, {0.999288, 0.999288, 0.999288, 0.999288, 0.999288, 0.999288, 0.999288, 0.999288}, nd4j::DataType::FLOAT32); + NDArray expC('c', {bS, nOut}, {3.999778, 3.999778, 3.999778, 3.999778, 3.999778, 3.999778, 3.999778, 3.999778}, nd4j::DataType::FLOAT32); + + std::vector params = {dataFormat, 0, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta}; + + x = 1.; + hI = 2.; + cI = 3.; + Wx = 0.5; + Wr = 0.4; + Wp = 0.3; + b = 0.7; + + nd4j::ops::helpers::lstmLayerCell(&x, &Wx, &Wr, &b, &hI, &cI, &Wp, params, &h, &c); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expC.isSameShape(c)); + ASSERT_TRUE(expC.equalsTo(c)); +} + 
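+// A hand check of where lstmLayerCell_1's expected values come from (a sketch worked from the
+// standard LSTM gate equations, using the constant inputs above; numbers are rounded):
+//   x*Wx  contributes nIn*1.0*0.5  = 5.0 to every gate's pre-activation,
+//   hI*Wr contributes nOut*2.0*0.4 = 3.2, the bias adds 0.7,
+//   and the peephole term Wp*cI = 0.3*3 = 0.9 enters the i and f gates;
+//   i = f = sigmoid(5.0 + 3.2 + 0.9 + 0.7) = sigmoid(9.8) ~ 0.99994
+//   c' = tanh(5.0 + 3.2 + 0.7) = tanh(8.9) ~ 1.0
+//   c  = f*cI + i*c' ~ 3*0.99994 + 0.99994 ~ 3.999778   (below cellClip = 5, so left unclipped)
+//   o  = sigmoid(8.9 + 0.3*c) = sigmoid(10.1) ~ 0.99996
+//   h  = o*tanh(c) ~ 0.99996*0.99933 ~ 0.999288, matching expH and expC above.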
+///////////////////////////////////////////////////////////////////
+TEST_F(HelpersTests1, lstmLayerCell_2) {
+
+    const int bS   = 2;
+    const int nIn  = 10;
+    const int nOut = 4;
+
+    const float dataFormat = 0; // ignored by the cell op
+    const float cellClip   = 3; // clipping value for the cell state
+    const float gateAct    = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
+    const float gateAlpha  = 0; // alpha value for gate activation, not required for sigmoid
+    const float gateBeta   = 0; // beta value for gate activation, not required for sigmoid
+    const float cellAct    = 0; // tanh activation for cell state
+    const float cellAlpha  = 0; // alpha value for cell state activation, not required for tanh
+    const float cellBeta   = 0; // beta value for cell state activation, not required for tanh
+    const float outAct     = 0; // tanh activation for output
+    const float outAlpha   = 0; // alpha value for output activation, not required for tanh
+    const float outBeta    = 0; // beta value for output activation, not required for tanh
+
+    NDArray x ('c', {bS, nIn},      nd4j::DataType::FLOAT32);
+    NDArray Wx('c', {nIn, 4*nOut},  nd4j::DataType::FLOAT32);
+    NDArray Wr('c', {nOut, 4*nOut}, nd4j::DataType::FLOAT32);
+    NDArray b ('c', {4*nOut},       nd4j::DataType::FLOAT32);
+    NDArray hI('c', {bS, nOut},     nd4j::DataType::FLOAT32);
+    NDArray cI('c', {bS, nOut},     nd4j::DataType::FLOAT32);
+    NDArray Wp('c', {3*nOut},       nd4j::DataType::FLOAT32);
+
+    NDArray h('c', {bS, nOut}, nd4j::DataType::FLOAT32);
+    NDArray c('c', {bS, nOut}, nd4j::DataType::FLOAT32);
+
+    NDArray expH('c', {bS, nOut}, {0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995}, nd4j::DataType::FLOAT32);
+    NDArray expC('c', {bS, nOut}, {3., 3., 3., 3., 3., 3., 3., 3.}, nd4j::DataType::FLOAT32);
+
+    std::vector<float> params = {dataFormat, 0, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta};
+
+    x  = 1.;
+    hI = 2.;
+    cI = 3.;
+    Wx = 0.5;
+    Wr = 0.4;
+    Wp = 0.3;
+    b  = 0.7;
+
+    nd4j::ops::helpers::lstmLayerCell(&x, &Wx, &Wr, &b, &hI, &cI, &Wp, params, &h, &c);
+
+    ASSERT_TRUE(expH.isSameShape(h));
+    ASSERT_TRUE(expH.equalsTo(h));
+    ASSERT_TRUE(expC.isSameShape(c));
+    ASSERT_TRUE(expC.equalsTo(c));
+}
+
+///////////////////////////////////////////////////////////////////
+// same inputs as lstmLayerCell_1, but with rank-1 (no batch dimension) arrays
+TEST_F(HelpersTests1, lstmLayerCell_3) {
+
+    const int nIn  = 10;
+    const int nOut = 4;
+
+    const float dataFormat = 0; // ignored by the cell op
+    const float cellClip   = 5; // clipping value for the cell state
+    const float gateAct    = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
+    const float gateAlpha  = 0; // alpha value for gate activation, not required for sigmoid
+    const float gateBeta   = 0; // beta value for gate activation, not required for sigmoid
+    const float cellAct    = 0; // tanh activation for cell state
+    const float cellAlpha  = 0; // alpha value for cell state activation, not required for tanh
+    const float cellBeta   = 0; // beta value for cell state activation, not required for tanh
+    const float outAct     = 0; // tanh activation for output
+    const float outAlpha   = 0; // alpha value for output activation, not required for tanh
+    const float outBeta    = 0; // beta value for output activation, not required for tanh
+
+    NDArray x ('c', {nIn},          nd4j::DataType::FLOAT32);
+    NDArray Wx('c', {nIn, 4*nOut},  nd4j::DataType::FLOAT32);
+    NDArray Wr('c', {nOut, 4*nOut}, nd4j::DataType::FLOAT32);
+    NDArray b ('c', {4*nOut},       nd4j::DataType::FLOAT32);
+    NDArray hI('c', {nOut},         nd4j::DataType::FLOAT32);
+    NDArray cI('c', {nOut},         nd4j::DataType::FLOAT32);
+    NDArray Wp('c', {3*nOut},       nd4j::DataType::FLOAT32);
+
+    NDArray h('c', {nOut}, nd4j::DataType::FLOAT32);
+    NDArray c('c', {nOut}, nd4j::DataType::FLOAT32);
+
+    NDArray expH('c', {nOut}, {0.999288, 0.999288, 0.999288, 0.999288}, nd4j::DataType::FLOAT32);
+    NDArray expC('c', {nOut}, {3.999778, 3.999778, 3.999778, 3.999778}, nd4j::DataType::FLOAT32);
+
+    std::vector<float> params = {dataFormat, 0, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta};
+
+    x  = 1.;
+    hI = 2.;
+    cI = 3.;
+    Wx = 0.5;
+    Wr = 0.4;
+    Wp = 0.3;
+    b  = 0.7;
+
+    nd4j::ops::helpers::lstmLayerCell(&x, &Wx, &Wr, &b, &hI, &cI, &Wp, params, &h, &c);
+
+    ASSERT_TRUE(expH.isSameShape(h));
+    ASSERT_TRUE(expH.equalsTo(h));
+    ASSERT_TRUE(expC.isSameShape(c));
+    ASSERT_TRUE(expC.equalsTo(c));
+}
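+
+///////////////////////////////////////////////////////////////////
+// Hand check for lstmLayerCell_2 above (a sketch only): the inputs are
+// identical to lstmLayerCell_1, so the un-clipped cell state is again
+// ~3.999778; with cellClip = 3 it is clipped to 3, hence expC = 3, and
+//   ht = ot * tanh(3) ~ 1 * 0.995055 ~ 0.995
+// which reproduces expH. Note ot ~ 1 whether the output-gate peephole reads
+// the clipped or the un-clipped cell state, so this check does not pin down
+// that ordering.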