- write two versions of new lstmLayer: one based on own code, the second using the MKL-DNN API
parent 630bb3c9b6
commit 70bd925abd
@ -4313,17 +4313,16 @@ Nd4jLong NDArray::getOffset(const Nd4jLong i) const {
     return shape::getIndexOffset(i, _shapeInfo);
 }

 ////////////////////////////////////////////////////////////////////////
 NDArray NDArray::like() {
-    NDArray res(this->shapeInfo(), this->dataType(), false, this->getContext());
-
-    return res;
+    return NDArray(shapeInfo(), this->dataType(), false, getContext());
 }

 ////////////////////////////////////////////////////////////////////////
 NDArray NDArray::ulike() {
-    // FIXME: it should be non-memset array
-    NDArray res(this->shapeInfo(), this->dataType(), false, this->getContext());
-
-    return res;
+    return NDArray(this, false, getContext());
 }

 ////////////////////////////////////////////////////////////////////////

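A minimal usage sketch of the refactored factories (editorial addition, assuming the usual order/shape/dtype NDArray constructor; per the removed FIXME, ulike() is intended to eventually skip the memset):

    NDArray a('c', {2, 3}, nd4j::DataType::FLOAT32);
    NDArray b = a.like();    // fresh buffer, same shape/type/context
    NDArray c = a.ulike();   // same shape/type/context, no value copy
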
@ -268,6 +268,21 @@ nd4j::NDArray* MmulHelper::mmul(const nd4j::NDArray* A, const nd4j::NDArray* B,
     if(aRank == 2 && isBVector)
         return mmulMxV(A, B, C, alpha, beta, outOrder);

+    // vector x matrix, A{M} x B{M,N} = C{N} -> reduce to matrix x matrix A2{1,M} x B{M,N} = C2{1,N}, since there is no corresponding blas operation sgevm
+    if(isAVector && bRank == 2) {
+        NDArray* A2 = new NDArray(A->reshape(A->ordering(), {1, A->lengthOf()}));                  // A{M} -> A2{1,M}
+        NDArray* C2 = C ? new NDArray(C->reshape(C->ordering(), {1, C->lengthOf()})) : nullptr;    // C{N} -> C2{1,N}
+        auto result = mmulMxM(A2, B, C2, alpha, beta, outOrder);                                   // result{1,N}
+        delete A2;
+        delete C2;
+
+        if(!C) {
+            result->reshapei({result->lengthOf()});    // result{1,N} -> result{N}
+            return result;
+        }
+        return C;
+    }
+
     // batched matrix multiplication
     return mmulNxN(A, B, C, alpha, beta, outOrder);
 }

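To make the reshape trick concrete, here is a standalone sketch (editorial, plain C++, independent of the NDArray API) of the arithmetic the hunk above reduces to: a vector-times-matrix product A{M} x B{M,N} computed as the 1xM-times-MxN GEMM that replaces the missing "sgevm" BLAS call.

    #include <cstdio>
    #include <vector>

    // C2{1,N} = A2{1,M} * B{M,N}: the single-output-row GEMM that stands in for "sgevm"
    void gevm(const std::vector<double>& A, const std::vector<std::vector<double>>& B,
              std::vector<double>& C) {
        const size_t M = A.size(), N = B[0].size();
        C.assign(N, 0.0);
        for (size_t m = 0; m < M; ++m)          // rows of B pair with elements of A
            for (size_t n = 0; n < N; ++n)
                C[n] += A[m] * B[m][n];
    }

    int main() {
        std::vector<double> A = {1, 2};                     // A{M=2}, viewed as A2{1,2}
        std::vector<std::vector<double>> B = {{3, 4, 5},    // B{2,3}
                                              {6, 7, 8}};
        std::vector<double> C;
        gevm(A, B, C);                                      // C2{1,3}, reshaped back to C{3}
        printf("%g %g %g\n", C[0], C[1], C[2]);             // prints: 15 18 21
    }
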
@ -119,6 +119,8 @@

 #define TRANSFORM_STRICT_OPS \
+        (2, ScaledTanh), \
+        (3, Affine), \
         (4, TanhDerivative), \
         (5, HardTanhDerivative), \
         (6, SigmoidDerivative), \

@ -0,0 +1,404 @@
/*******************************************************************************
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author Yurii Shyrma (iuriish@yahoo.com)
//

#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_lstmLayer)

#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/lstmLayer.h>

namespace nd4j {
namespace ops {


//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) {

    // equations (no peephole connections)
    // it  = σ(Wxi * xt + Wri * ht-1 + bi)
    // ft  = σ(Wxf * xt + Wrf * ht-1 + bf)
    // c't = tanh(Wxc * xt + Wrc * ht-1 + bc)
    // ct  = ft ◦ ct-1 + it ◦ c't
    // ot  = σ(Wxo * xt + Wro * ht-1 + bo)
    // ht  = ot ◦ tanh(ct)

    // equations (peephole connections are present)
    // it  = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi)
    // ft  = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf)
    // c't = tanh(Wxc * xt + Wrc * ht-1 + bc)
    // ct  = ft ◦ ct-1 + it ◦ c't
    // ot  = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo)
    // ht  = ot ◦ tanh(ct)

    // notations:
    // bS   - batch size
    // sL   - sequence length, number of time steps
    // nIn  - input size
    // nOut - output size (hidden size)

    // INPUTS:

    // *******
    // input x:
    // 1) [sL, bS, nIn] when dataFormat == 0
    // 2) [bS, sL, nIn] when dataFormat == 1
    // 3) [bS, nIn, sL] when dataFormat == 2

    // *******
    // input weights Wx:
    // 1) [nIn, 4*nOut]    when directionMode <  2
    // 2) [2, nIn, 4*nOut] when directionMode >= 2

    // *******
    // recurrent weights Wr:
    // 1) [nOut, 4*nOut]    when directionMode <  2
    // 2) [2, nOut, 4*nOut] when directionMode >= 2

    // *******
    // peephole weights Wp:
    // 1) [3*nOut]    when directionMode <  2
    // 2) [2, 3*nOut] when directionMode >= 2

    // *******
    // biases b:
    // 1) [4*nOut]    when directionMode <  2
    // 2) [2, 4*nOut] when directionMode >= 2

    // *******
    // sequence length array seqLen:
    // 1) [bS] always

    // *******
    // initial output hI:
    // 1) [bS, nOut]    when directionMode <  2
    // 2) [2, bS, nOut] when directionMode >= 2

    // *******
    // initial cell state cI (same shape as in hI):
    // 1) [bS, nOut]    when directionMode <  2
    // 2) [2, bS, nOut] when directionMode >= 2


    // OUTPUTS:

    // *******
    // output h:
    // 1) [sL, bS, nOut]    when directionMode <= 2 && dataFormat == 0
    // 2) [bS, sL, nOut]    when directionMode <= 2 && dataFormat == 1
    // 3) [bS, nOut, sL]    when directionMode <= 2 && dataFormat == 2
    // 4) [sL, bS, 2*nOut]  when directionMode == 3 && dataFormat == 0
    // 5) [bS, sL, 2*nOut]  when directionMode == 3 && dataFormat == 1
    // 6) [bS, 2*nOut, sL]  when directionMode == 3 && dataFormat == 2
    // 7) [sL, 2, bS, nOut] when directionMode == 4 && dataFormat == 3

    // *******
    // output at last step hL:
    // 1) [bS, nOut]    when directionMode <  2
    // 2) [2, bS, nOut] when directionMode >= 2

    // *******
    // cell state at last step cL (same shape as in hL):
    // 1) [bS, nOut]    when directionMode <  2
    // 2) [2, bS, nOut] when directionMode >= 2

    // !!! dimension 4*nOut implies order it, ft, c't, ot
    // !!! dimension 3*nOut implies order it, ft, ot

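    // Illustrative call sketch (editorial addition, not part of this commit): assuming
    // the usual DeclarableOp::execute(inputs, tArgs, iArgs, bArgs) wrapper, a
    // unidirectional forward layer with sigmoid gates, tanh activations and biases
    // only could be invoked roughly as:
    //
    //     nd4j::ops::lstmLayer op;
    //     auto res = op.execute({&x, &Wx, &Wr, &b},
    //                           {0.0},                      // tArgs: cellClip = 0 (no clipping); tanh/sigmoid need no alpha/beta
    //                           {0, 0, 2, 0, 0},            // iArgs: dataFormat=0, directionMode=0 (fwd), gateAct=sigmoid, cellAct=tanh, outAct=tanh
    //                           {true, false, false, false, false, true, true, true});  // bArgs, in the B_ARG(0..7) order below
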
    const auto dataFormat    = INT_ARG(0);    // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL, nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, 2, bS, nOut] (for ONNX)
    const auto directionMode = INT_ARG(1);    // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3)

    // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5=thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
    const auto gateAct = INT_ARG(2);    // activation for input (i), forget (f) and output (o) gates
    const auto cellAct = INT_ARG(3);    // activation for cell state (c)
    const auto outAct  = INT_ARG(4);    // activation for output (h)

    const auto hasBiases  = B_ARG(0);   // indicates whether biases array is provided
    const auto hasSeqLen  = B_ARG(1);   // indicates whether seqLen array is provided
    const auto hasInitH   = B_ARG(2);   // indicates whether initial output is provided
    const auto hasInitC   = B_ARG(3);   // indicates whether initial cell state is provided
    const auto hasPH      = B_ARG(4);   // indicates whether peephole connections are present
    const auto retFullSeq = B_ARG(5);   // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1}
    const auto retLastH   = B_ARG(6);   // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)
    const auto retLastC   = B_ARG(7);   // indicates whether to return cell state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)

    const auto gateActHasAlpha = gateAct == 3 || gateAct == 4 || gateAct == 5 || gateAct == 6 || gateAct == 8;
    const auto cellActHasAlpha = cellAct == 3 || cellAct == 4 || cellAct == 5 || cellAct == 6 || cellAct == 8;
    const auto outActHasAlpha  = outAct  == 3 || outAct  == 4 || outAct  == 5 || outAct  == 6 || outAct  == 8;
    const auto gateActHasBeta  = gateAct == 3 || gateAct == 6;
    const auto cellActHasBeta  = cellAct == 3 || cellAct == 6;
    const auto outActHasBeta   = outAct  == 3 || outAct  == 6;

    uint count = 1;
    const auto cellClip  = T_ARG(0);                                     // cell clipping value, if it = 0 then do not apply clipping
    const auto gateAlpha = gateActHasAlpha ? T_ARG(count++) : 0;
    const auto gateBeta  = gateActHasBeta  ? T_ARG(count++) : 0;
    const auto cellAlpha = cellActHasAlpha ? T_ARG(count++) : 0;
    const auto cellBeta  = cellActHasBeta  ? T_ARG(count++) : 0;
    const auto outAlpha  = outActHasAlpha  ? T_ARG(count++) : 0;
    const auto outBeta   = outActHasBeta   ? T_ARG(count++) : 0;

    const auto x  = INPUT_VARIABLE(0);    // input
    const auto Wx = INPUT_VARIABLE(1);    // input weights
    const auto Wr = INPUT_VARIABLE(2);    // recurrent weights

    count = 3;
    const auto b      = hasBiases ? INPUT_VARIABLE(count++) : nullptr;    // biases
    const auto seqLen = hasSeqLen ? INPUT_VARIABLE(count++) : nullptr;    // seqLen vector
    const auto hI     = hasInitH  ? INPUT_VARIABLE(count++) : nullptr;    // initial output
    const auto cI     = hasInitC  ? INPUT_VARIABLE(count++) : nullptr;    // initial cell state
    const auto Wp     = hasPH     ? INPUT_VARIABLE(count++) : nullptr;    // peephole weights

    REQUIRE_TRUE(dataFormat < 3 || (dataFormat == 3 && directionMode == 4), 0, "LSTM_LAYER operation: if argument dataFormat = 3, then directionMode = 4, but got dataFormat = %i and directionMode = %i instead !", dataFormat, directionMode);
    REQUIRE_TRUE(cellClip >= 0 , 0, "LSTM_LAYER operation: cell clipping value should be nonnegative (>=0) !");
    REQUIRE_TRUE(retFullSeq || retLastH || retLastC, 0, "LSTM_LAYER operation: please specify what output arrays to produce !");

    count = 0;
    auto h  = retFullSeq ? OUTPUT_VARIABLE(count++) : nullptr;    // output
    auto hL = retLastH   ? OUTPUT_VARIABLE(count++) : nullptr;    // output at last step
    auto cL = retLastC   ? OUTPUT_VARIABLE(count++) : nullptr;    // cell state at last step

    // evaluate dimensions
    const Nd4jLong sL   = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat);
    const Nd4jLong bS   = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(-2);
    const Nd4jLong nIn  = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(-1);
    const Nd4jLong nOut = Wx->sizeAt(-1) / 4;

    // inputs validations
    if(directionMode < 2) {     // no bidirectional

        // Wx validation
        if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn)
            REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx));
        // Wr validation
        if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut)
            REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr));
        // biases validation
        if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut))
            REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b));
        // initial output validation
        if(hI != nullptr && (hI->rankOf() != 2 || hI->sizeAt(0) != bS || hI->sizeAt(1) != nOut))
            REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI));
        // initial cell validation
        if(cI != nullptr && (cI->rankOf() != 2 || cI->sizeAt(0) != bS || cI->sizeAt(1) != nOut))
            REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI));
        // peephole weights validation
        if(Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3*nOut))
            REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp));
    }
    else {                      // bidirectional
        // Wx validation
        if(Wx->rankOf() != 3 || Wx->sizeAt(0) != 2 || Wx->sizeAt(1) != nIn)
            REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx));
        // Wr validation
        if(Wr->rankOf() != 3 || Wr->sizeAt(0) != 2 || Wr->sizeAt(1) != nOut || Wr->sizeAt(2) != 4*nOut)
            REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr));
        // biases validation
        if(b != nullptr && (b->rankOf() != 2 || b->sizeAt(0) != 2 || b->sizeAt(1) != 4*nOut))
            REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 4*nOut}).c_str(), ShapeUtils::shapeAsString(b));
        // initial output validation
        if(hI != nullptr && (hI->rankOf() != 3 || hI->sizeAt(0) != 2 || hI->sizeAt(1) != bS || hI->sizeAt(2) != nOut))
            REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI));
        // initial cell validation
        if(cI != nullptr && (cI->rankOf() != 3 || cI->sizeAt(0) != 2 || cI->sizeAt(1) != bS || cI->sizeAt(2) != nOut))
            REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI));
        // peephole weights validation
        if(Wp != nullptr && (Wp->rankOf() != 2 || Wp->sizeAt(0) != 2 || Wp->sizeAt(1) != 3*nOut))
            REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 3*nOut}).c_str(), ShapeUtils::shapeAsString(Wp));
    }

    std::vector<float> params = {static_cast<float>(dataFormat), static_cast<float>(directionMode), static_cast<float>(cellClip),
                                 static_cast<float>(gateAct), static_cast<float>(gateAlpha), static_cast<float>(gateBeta),
                                 static_cast<float>(cellAct), static_cast<float>(cellAlpha), static_cast<float>(cellBeta),
                                 static_cast<float>(outAct), static_cast<float>(outAlpha), static_cast<float>(outBeta)};

    if(directionMode == 0) {        // forward

        helpers::lstmLayerTimeLoop(x, Wx, Wr, b, seqLen, hI, cI, Wp, params, true, h, hL, cL);
    }
    else if(directionMode == 1) {   // backward

        helpers::lstmLayerTimeLoop(x, Wx, Wr, b, seqLen, hI, cI, Wp, params, false, h, hL, cL);
    }
    else {                          // bidirectional

        NDArray WxFwd = (*Wx)({0,1, 0,0, 0,0});
        NDArray WxBwd = (*Wx)({1,2, 0,0, 0,0});
        NDArray WrFwd = (*Wr)({0,1, 0,0, 0,0});
        NDArray WrBwd = (*Wr)({1,2, 0,0, 0,0});

        NDArray *WpFwd(nullptr), *WpBwd(nullptr), *bFwd(nullptr), *bBwd(nullptr), *hIFwd(nullptr), *hIBwd(nullptr), *cIFwd(nullptr), *cIBwd(nullptr),
                *hLFwd(nullptr), *hLBwd(nullptr), *cLFwd(nullptr), *cLBwd(nullptr), *hFwd(nullptr), *hBwd(nullptr);

        if(Wp) {
            WpFwd = new NDArray((*Wp)({0,1, 0,0}));
            WpBwd = new NDArray((*Wp)({1,2, 0,0}));
        }
        if(b) {
            bFwd = new NDArray((*b)({0,1, 0,0}));
            bBwd = new NDArray((*b)({1,2, 0,0}));
        }
        if(hI) {
            hIFwd = new NDArray((*hI)({0,1, 0,0, 0,0}));
            hIBwd = new NDArray((*hI)({1,2, 0,0, 0,0}));
        }
        if(cI) {
            cIFwd = new NDArray((*cI)({0,1, 0,0, 0,0}));
            cIBwd = new NDArray((*cI)({1,2, 0,0, 0,0}));
        }
        if(hL) {
            hLFwd = new NDArray((*hL)({0,1, 0,0, 0,0}));
            hLBwd = new NDArray((*hL)({1,2, 0,0, 0,0}));
        }
        if(cL) {
            cLFwd = new NDArray((*cL)({0,1, 0,0, 0,0}));
            cLBwd = new NDArray((*cL)({1,2, 0,0, 0,0}));
        }

        if(h) {
            if(directionMode == 2) {        // sum
                hFwd = h;
                hBwd = new NDArray(h, false, h->getContext());
            }
            else if(directionMode == 3) {   // concat
                hFwd = new NDArray(dataFormat <= 1 ? (*h)({0,0, 0,0, 0,nOut})        : (*h)({0,0, 0,nOut, 0,0}));
                hBwd = new NDArray(dataFormat <= 1 ? (*h)({0,0, 0,0, nOut,2*nOut})   : (*h)({0,0, nOut,2*nOut, 0,0}));
            }
            else {                          // directionMode == 4
                hFwd = new NDArray((*h)({0,0, 0,1, 0,0, 0,0}));
                hBwd = new NDArray((*h)({0,0, 1,2, 0,0, 0,0}));
            }
        }

        // FIXME - following two calls are independent and may run in different streams
        helpers::lstmLayerTimeLoop(x, &WxFwd, &WrFwd, bFwd, seqLen, hIFwd, cIFwd, WpFwd, params, true,  hFwd, hLFwd, cLFwd);
        helpers::lstmLayerTimeLoop(x, &WxBwd, &WrBwd, bBwd, seqLen, hIBwd, cIBwd, WpBwd, params, false, hBwd, hLBwd, cLBwd);

        if(h && directionMode == 2)
            *h += *hBwd;

        delete WpFwd; delete WpBwd; delete bFwd;  delete bBwd;  delete hIFwd; delete hIBwd; delete cIFwd;
        delete cIBwd; delete hLFwd; delete hLBwd; delete cLFwd; delete cLBwd; delete hBwd;
        if(hFwd != h)
            delete hFwd;
    }

    return Status::OK();
}
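// Note on the interval indexing used above (editorial clarification): for a rank-3
// NDArray, arr({i1,i2, j1,j2, k1,k2}) selects the half-open sub-array
// [i1,i2) x [j1,j2) x [k1,k2), where the pair 0,0 keeps the whole axis.
// E.g. for Wx with shape [2, nIn, 4*nOut]:
//     (*Wx)({0,1, 0,0, 0,0})    // forward-direction slice
//     (*Wx)({1,2, 0,0, 0,0})    // backward-direction slice
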
DECLARE_TYPES(lstmLayer) {
    getOpDescriptor()
        ->setAllowedInputTypes(nd4j::DataType::ANY)
        ->setAllowedOutputTypes({ALL_FLOATS});
}


DECLARE_SHAPE_FN(lstmLayer) {

    const auto dataFormat    = INT_ARG(0);    // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL, nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, 2, bS, nIn] (for ONNX)
    const auto directionMode = INT_ARG(1);    // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim

    const auto retFullSeq = B_ARG(5);    // indicates whether to return whole h {h_0, h_1, ... , h_sL-1}, if true, format would be [sL,bS,nOut] (exact shape depends on dataFormat argument)
    const auto retLastH   = B_ARG(6);    // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)
    const auto retLastC   = B_ARG(7);    // indicates whether to return cell state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)

    const auto x  = INPUT_VARIABLE(0);    // input
    const auto Wx = INPUT_VARIABLE(1);    // input weights
    const auto Wr = INPUT_VARIABLE(2);    // recurrent weights

    // evaluate dimensions
    const Nd4jLong sL   = dataFormat == 0 || dataFormat == 3 ? x->sizeAt(0) : ( dataFormat == 1 ? x->sizeAt(1) : x->sizeAt(2) );
    const Nd4jLong bS   = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(-2);
    const Nd4jLong nIn  = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(-1);
    const Nd4jLong nOut = Wx->sizeAt(-1) / 4;

    DataType type;
    if(x->isR())
        type = x->dataType();
    else
        type = nd4j::DataType::FLOAT32;

    std::vector<Nd4jLong*> shapes;

    // evaluate h shape (output)
    if(retFullSeq) {

        std::vector<Nd4jLong> hShape;

        if(directionMode <= 2) {            // single direction or bidirectional with sum
            if(dataFormat == 0)
                hShape = {sL, bS, nOut};
            else if(dataFormat == 1)
                hShape = {bS, sL, nOut};
            else if(dataFormat == 2)
                hShape = {bS, nOut, sL};
        }
        else if(directionMode == 3) {       // bidirectional with concat

            if(dataFormat == 0)
                hShape = {sL, bS, 2*nOut};
            else if(dataFormat == 1)
                hShape = {bS, sL, 2*nOut};
            else if(dataFormat == 2)
                hShape = {bS, 2*nOut, sL};
        }
        else {                              // bidirectional with extra output dimension equal to 2
            hShape = {sL, 2, bS, nOut};
        }

        shapes.push_back(ConstantShapeHelper::getInstance()->createShapeInfo(type, x->ordering(), hShape));
    }

    // evaluate hL shape (output at last step)
    if(retLastH) {

        std::vector<Nd4jLong> hLShape;

        if(directionMode < 2)
            hLShape = {bS, nOut};
        else
            hLShape = {2, bS, nOut};

        shapes.push_back(ConstantShapeHelper::getInstance()->createShapeInfo(type, x->ordering(), hLShape));

        if(retLastC)        // cL and hL have same shapes
            shapes.push_back(shapes.back());
    }

    // evaluate cL shape (cell state at last step)
    if(retLastC && !retLastH) {

        std::vector<Nd4jLong> cLShape;

        if(directionMode < 2)
            cLShape = {bS, nOut};
        else
            cLShape = {2, bS, nOut};

        shapes.push_back(ConstantShapeHelper::getInstance()->createShapeInfo(type, x->ordering(), cLShape));
    }

    return new ShapeList(shapes);
}
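// Worked example for the shape function above (editorial): with dataFormat = 1,
// directionMode = 3 (bidirectional concat), x of shape [bS=2, sL=5, nIn=4] and
// Wx of shape [2, 4, 4*nOut] where nOut = 3:
//     h  -> [bS, sL, 2*nOut] = [2, 5, 6]
//     hL -> [2, bS, nOut]    = [2, 2, 3]
//     cL -> same as hL       = [2, 2, 3]
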

}
}

#endif

@ -231,6 +231,11 @@ namespace ops {
         DECLARE_CUSTOM_OP(lstmBlock, 9, 7, false, 2, 2);
     #endif

+    //////////////////////////////////////////////////////////////////////////
+    #if NOT_EXCLUDED(OP_lstmLayer)
+        DECLARE_CUSTOM_OP(lstmLayer, 3, 1, false, 1, 5);
+    #endif
+
     //////////////////////////////////////////////////////////////////////////
     /**
      * Implementation of operations for Simple Recurrent Unit cell: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi

@ -0,0 +1,460 @@
/*******************************************************************************
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author Yurii Shyrma (iuriish@yahoo.com)
//

// implementation of operation for LSTM cell with peephole connections:
// http://www.bioinf.jku.at/publications/older/2604.pdf
// S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.
// and
// https://research.google.com/pubs/archive/43905.pdf
// Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent neural network architectures for large scale acoustic modeling." INTERSPEECH, 2014.


#include <ops/declarable/helpers/lstmLayer.h>
#include <helpers/ShapeUtils.h>
// #include <VariableSpace.h>
// #include <ops/declarable/CustomOperations.h>
// #include <ops/declarable/helpers/transforms.h>
// #include <ops/declarable/helpers/legacy_helpers.h>
// #include <array/NDArrayList.h>
// #include <iterator>
// #include <MmulHelper.h>

namespace nd4j {
namespace ops {
namespace helpers {


//////////////////////////////////////////////////////////////////////////
void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
                   const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
                   const std::vector<float>& params,
                   NDArray* h, NDArray* c) {

    /************************ THIS IS NOT OPTIMIZED CODE ***********************************/
    /** the objective is to provide math-readable code **/

    // equations (no peephole connections)
    // it  = σ(Wxi * xt + Wri * ht-1 + bi)
    // ft  = σ(Wxf * xt + Wrf * ht-1 + bf)
    // c't = tanh(Wxc * xt + Wrc * ht-1 + bc)
    // ct  = ft ◦ ct-1 + it ◦ c't
    // ot  = σ(Wxo * xt + Wro * ht-1 + bo)
    // ht  = ot ◦ tanh(ct)

    // equations (peephole connections are present)
    // it  = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi)
    // ft  = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf)
    // c't = tanh(Wxc * xt + Wrc * ht-1 + bc)
    // ct  = ft ◦ ct-1 + it ◦ c't
    // ot  = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo)
    // ht  = ot ◦ tanh(ct)


    // IDs for activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5=thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus

    // params[0]  - dataFormat, ignore
    // params[1]  - directionMode, ignore
    // params[2]  - cell clipping value, if it = 0 then do not apply clipping

    // params[3]  - activation ID for input (i), forget (f) and output (o) gates
    // params[4]  - alpha value for gates activation
    // params[5]  - beta value for gates activation

    // params[6]  - activation ID for cell state (c)
    // params[7]  - alpha value for cell state activation
    // params[8]  - beta value for cell state activation

    // params[9]  - activation ID for output (h)
    // params[10] - alpha value for output activation
    // params[11] - beta value for output activation

    // INPUTS:
    // x  - current input at time t, [bS, nIn] or [nIn] if seqLen != nullptr
    // Wx - input weights [nIn, 4*nOut]
    // Wr - recurrent weights [nOut, 4*nOut]
    // b  - biases [4*nOut], optional, may be nullptr
    // hI - previous (initial) output at time t-1, optional, may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr
    // cI - previous (initial) cell state at time t-1, optional, may be nullptr, [bS, nOut] or [nOut] if seqLen != nullptr
    // Wp - peephole weights [3*nOut], optional, may be nullptr

    // OUTPUTS:
    // h - current output, that is at current time step t, [bS, nOut] or [nOut] if seqLen != nullptr
    // c - current cell state, that is at current time step t, [bS, nOut] or [nOut] if seqLen != nullptr

    // !!! dimension 4*nOut implies order it, ft, c't, ot
    // !!! dimension 3*nOut implies order it, ft, ot

    const Nd4jLong nOut = Wx->sizeAt(-1) / 4;

    auto z = mmul(*x, *Wx) + mmul(*hI, *Wr);    //    [bs, nIn] * [nIn, 4*nOut] + [bs, nOut] * [nOut, 4*nOut] = [bS, 4*nOut]
                                                // or [nIn] * [nIn, 4*nOut] + [nOut] * [nOut, 4*nOut] = [4*nOut]

    // add biases if they are given
    if(b != nullptr)
        z += *b;                                // broadcast [bS, 4*nOut] + [4*nOut] = [bS, 4*nOut]

    auto zi = x->rankOf() == 1 ? z({0,        nOut}) : z({0,0, 0,        nOut});      // input gate it,  [bS, nOut]
    auto zf = x->rankOf() == 1 ? z({nOut,   2*nOut}) : z({0,0, nOut,   2*nOut});      // forget gate ft, [bS, nOut]
    auto zc = x->rankOf() == 1 ? z({2*nOut, 3*nOut}) : z({0,0, 2*nOut, 3*nOut});      // cell gate c't,  [bS, nOut]
    auto zo = x->rankOf() == 1 ? z({3*nOut, 4*nOut}) : z({0,0, 3*nOut, 4*nOut});      // output gate ot, [bS, nOut]

    // peephole connections for input and forget gates
    if(Wp != nullptr) {
        zi += *cI * (*Wp)({0,    nOut});      // broadcast: [bS, nOut] + [bS, nOut] ◦ [nOut] = [bS, nOut]
        zf += *cI * (*Wp)({nOut, 2*nOut});    // broadcast: [bS, nOut] + [bS, nOut] ◦ [nOut] = [bS, nOut]
    }

    applyActivation(zi, params[3], params[4], params[5], zi);   // inplace
    applyActivation(zf, params[3], params[4], params[5], zf);   // inplace
    applyActivation(zc, params[6], params[7], params[8], zc);   // inplace

    c->assign(zf * *cI + zi * zc);    // [bS, nOut] ◦ [bS, nOut] + [bS, nOut] ◦ [bS, nOut] = [bS, nOut]

    // if clipping value is non-zero then cell state is clipped by this value prior to the cell output activation
    if(params[2] != 0)
        c->applyScalar(scalar::LstmClip, params[2]);

    // peephole connections for output gate
    if(Wp != nullptr)
        zo += *c * (*Wp)({2*nOut, 3*nOut});   // broadcast: [bS, nOut] + [nOut] ◦ [bS, nOut] = [bS, nOut]

    applyActivation(zo, params[3], params[4], params[5], zo);

    applyActivation(*c, params[9], params[10], params[11], *h);
    *h *= zo;                                 // [bS, nOut] ◦ [bS, nOut]
}

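// Worked shape trace for lstmLayerCell (editorial): with bS = 1, nIn = 2, nOut = 3,
// x is [1,2], Wx is [2,12] and Wr is [3,12], so z = x*Wx + hI*Wr is [1,12]; the four
// slices zi = z[:, 0:3], zf = z[:, 3:6], zc = z[:, 6:9], zo = z[:, 9:12] are each
// [1,3], matching the gate order it, ft, c't, ot noted above.
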
//////////////////////////////////////////////////////////////////////////
void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
                       const NDArray* b, const NDArray* seqLen, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
                       const std::vector<float>& params,
                       const bool forward,
                       NDArray* h, NDArray* hL, NDArray* cL) {

    // INPUTS:
    // x      - current input [sL, bS, nIn], [bS, sL, nIn], [bS, nIn, sL]
    // Wx     - input weights [nIn, 4*nOut]
    // Wr     - recurrent weights [nOut, 4*nOut]
    // b      - biases [4*nOut], optional, may be nullptr
    // seqLen - [bS], optional, may be nullptr
    // hI     - initial output [bS, nOut], optional, may be nullptr
    // cI     - initial cell state at time t-1 [bS, nOut], optional, may be nullptr
    // Wp     - peephole weights [3*nOut], optional, may be nullptr

    // OUTPUTS:
    // h  - output [sL, bS, nOut], [bS, sL, nOut], [bS, nOut, sL], optional, may be nullptr
    // hL - output at last step [bS, nOut], optional, may be nullptr
    // cL - cell state at last step [bS, nOut], optional, may be nullptr

    // params = {dataFormat, directionMode, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta};
    // dataFormat: 0,3 = [sL, bS, nIn], 1 = [bS, sL, nIn], 2 = [bS, nIn, sL]

    const int dataFormat    = params[0];
    const int directionMode = params[1];

    const Nd4jLong sL   = x->sizeAt(dataFormat);
    const Nd4jLong bS   = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1);
    const Nd4jLong nOut = Wx->sizeAt(-1) / 4;

    const std::vector<Nd4jLong> shapeOut = {bS, nOut};

    auto h0 = const_cast<NDArray*>(hI);
    if(!hI) {
        h0 = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext());
        h0->nullify();
    }

    auto c0 = const_cast<NDArray*>(cI);
    if(!cI) {
        c0 = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext());
        c0->nullify();
    }

    auto ct = cL;
    if(!cL)
        ct = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext());

    auto ht = hL;
    if(!h && !hL)
        ht = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext());

    // create sets of required (depends on seqLen presence) sub-arrays
    std::vector<int> dims;
    ResultSet *xSet(nullptr), *hSet(nullptr), *h0Set(nullptr), *c0Set(nullptr), *htSet(nullptr), *ctSet(nullptr);

    if(!seqLen) {

        dims = ShapeUtils::evalDimsToExclude(x->rankOf(), {dataFormat < 3 ? dataFormat : 0});   // points on bS and nIn/nOut axes

        xSet = x->allTensorsAlongDimension(dims);           // sub-arrays with shape [bS, nIn]
        if(h)
            hSet = h->allTensorsAlongDimension(dims);       // sub-arrays with shape [bS, nOut]
    }
    else {

        dims = dataFormat == 2 ? std::vector<int>({1}) : std::vector<int>({2});                 // points on nIn/nOut axis

        xSet  = x->allTensorsAlongDimension(dims);          // sub-arrays with shape [nIn]
        h0Set = h0->allTensorsAlongDimension({1});          // sub-arrays with shape [nOut]
        c0Set = c0->allTensorsAlongDimension({1});          // sub-arrays with shape [nOut]
        ctSet = ct->allTensorsAlongDimension({1});          // sub-arrays with shape [nOut]
        if(h)
            hSet = h->allTensorsAlongDimension(dims);       // sub-arrays with shape [nOut]
        if(ht)
            htSet = ht->allTensorsAlongDimension({1});      // sub-arrays with shape [nOut]
    }

    // loops
    if(forward) {

        if(!seqLen) {

            if(!h) {    // seqLen and h are absent

                lstmLayerCell(xSet->at(0), Wx, Wr, b, h0, c0, Wp, params, ht, ct);          // first time step
                for (int t = 1; t < sL; ++t)
                    lstmLayerCell(xSet->at(t), Wx, Wr, b, ht, ct, Wp, params, ht, ct);      // rest time steps
            }
            else {      // seqLen is absent and h is present

                lstmLayerCell(xSet->at(0), Wx, Wr, b, h0, c0, Wp, params, hSet->at(0), ct);                 // first time step
                for (int t = 1; t < sL; ++t)
                    lstmLayerCell(xSet->at(t), Wx, Wr, b, hSet->at(t - 1), ct, Wp, params, hSet->at(t), ct);    // rest time steps

                if(hL)
                    hL->assign(hSet->at(sL - 1));           // assign last output to hL if it is not nullptr
            }
        }
        else {

            if(!h) {    // seqLen is present and h is absent

                for (int e = 0; e < bS; ++e) {

                    const int limit = seqLen->e<int>(e);

                    if(limit == 0) {
                        if(cL)
                            ctSet->at(e)->nullify();
                        if(hL)
                            htSet->at(e)->nullify();
                        continue;
                    }

                    auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, 0, e);
                    lstmLayerCell(xSet->at(ind), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, htSet->at(e), ctSet->at(e));    // first time step

                    for (int t = 1; t < limit; ++t) {
                        ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
                        lstmLayerCell(xSet->at(ind), Wx, Wr, b, htSet->at(e), ctSet->at(e), Wp, params, htSet->at(e), ctSet->at(e));    // rest time steps
                    }
                }
            }
            else {      // seqLen and h are present

                for (int e = 0; e < bS; ++e) {

                    int limit = seqLen->e<int>(e);

                    if(limit == 0) {

                        tensorAlongTimeBatchDims(*h, dataFormat, 0,0, e,e+1).nullify();     // nullify for given e and whole time range

                        if(cL)
                            ctSet->at(e)->nullify();
                        if(hL)
                            htSet->at(e)->nullify();

                        continue;
                    }

                    auto indPrev = getBatchTimeTotalIndex(dataFormat, sL, bS, 0, e);
                    lstmLayerCell(xSet->at(indPrev), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, hSet->at(indPrev), ctSet->at(e));   // first time step

                    for (int t = 1; t < limit; ++t) {
                        auto indCurr = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
                        lstmLayerCell(xSet->at(indCurr), Wx, Wr, b, hSet->at(indPrev), ctSet->at(e), Wp, params, hSet->at(indCurr), ctSet->at(e));  // rest time steps
                        indPrev = indCurr;
                    }

                    if(hL)
                        htSet->at(e)->assign(hSet->at(indPrev));    // assign last output to hL if hL is not nullptr

                    tensorAlongTimeBatchDims(*h, dataFormat, limit,sL, e,e+1).nullify();    // nullify for given e and time range [limit, sL)
                }
            }
        }
    }
    else {   // backward

        if(!seqLen) {

            if(!h) {    // seqLen and h are absent

                lstmLayerCell(xSet->at(sL - 1), Wx, Wr, b, h0, c0, Wp, params, ht, ct);     // first time step
                for (int t = sL - 2; t >= 0; --t)
                    lstmLayerCell(xSet->at(t), Wx, Wr, b, ht, ct, Wp, params, ht, ct);      // rest time steps
            }
            else {      // seqLen is absent and h is present

                lstmLayerCell(xSet->at(sL - 1), Wx, Wr, b, h0, c0, Wp, params, hSet->at(sL - 1), ct);           // first time step
                for (int t = sL - 2; t >= 0; --t)
                    lstmLayerCell(xSet->at(t), Wx, Wr, b, hSet->at(t + 1), ct, Wp, params, hSet->at(t), ct);    // rest time steps

                if(hL)
                    hL->assign(hSet->at(0));    // assign last output to hL if it is not nullptr
            }
        }
        else if(directionMode == 1) {   // only backward, no bidirectional mode

            if(!h) {    // h is absent and seqLen is present

                for (int e = 0; e < bS; ++e) {

                    const int limit = seqLen->e<int>(e);

                    if(limit == 0) {
                        if(cL)
                            ctSet->at(e)->nullify();
                        if(hL)
                            htSet->at(e)->nullify();
                        continue;
                    }

                    auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, sL - 1, e);
                    lstmLayerCell(xSet->at(ind), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, htSet->at(e), ctSet->at(e));    // first time step

                    for (int t = sL - 2; t >= sL - limit; --t) {
                        ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
                        lstmLayerCell(xSet->at(ind), Wx, Wr, b, htSet->at(e), ctSet->at(e), Wp, params, htSet->at(e), ctSet->at(e));    // rest time steps
                    }
                }
            }
            else {      // seqLen and h are present

                for (int e = 0; e < bS; ++e) {

                    int limit = seqLen->e<int>(e);

                    if(limit == 0) {

                        tensorAlongTimeBatchDims(*h, dataFormat, 0,0, e,e+1).nullify();     // nullify for given e and whole time range

                        if(cL)
                            ctSet->at(e)->nullify();
                        if(hL)
                            htSet->at(e)->nullify();

                        continue;
                    }

                    auto indPrev = getBatchTimeTotalIndex(dataFormat, sL, bS, sL - 1, e);
                    lstmLayerCell(xSet->at(indPrev), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, hSet->at(indPrev), ctSet->at(e));   // first time step

                    for (int t = sL - 2; t >= sL - limit; --t) {
                        auto indCurr = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
                        lstmLayerCell(xSet->at(indCurr), Wx, Wr, b, hSet->at(indPrev), ctSet->at(e), Wp, params, hSet->at(indCurr), ctSet->at(e));  // rest time steps
                        indPrev = indCurr;
                    }

                    if(hL)
                        htSet->at(e)->assign(hSet->at(indPrev));    // assign last output to hL if it is not nullptr

                    tensorAlongTimeBatchDims(*h, dataFormat, 0,sL-limit, e,e+1).nullify();  // nullify for given e and time range [0, sL-limit)
                }
            }
        }
        else {   // backward in bidirectional mode

            if(!h) {    // h is absent and seqLen is present

                for (int e = 0; e < bS; ++e) {

                    const int limit = seqLen->e<int>(e);

                    if(limit == 0) {
                        if(cL)
                            ctSet->at(e)->nullify();
                        if(hL)
                            htSet->at(e)->nullify();
                        continue;
                    }

                    auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, limit - 1, e);
                    lstmLayerCell(xSet->at(ind), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, htSet->at(e), ctSet->at(e));    // first time step

                    for (int t = limit - 2; t >= 0; --t) {
                        ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
                        lstmLayerCell(xSet->at(ind), Wx, Wr, b, htSet->at(e), ctSet->at(e), Wp, params, htSet->at(e), ctSet->at(e));    // rest time steps
                    }
                }
            }
            else {      // seqLen and h are present

                for (int e = 0; e < bS; ++e) {

                    int limit = seqLen->e<int>(e);

                    if(limit == 0) {

                        tensorAlongTimeBatchDims(*h, dataFormat, 0,0, e,e+1).nullify();     // nullify for given e and whole time range

                        if(cL)
                            ctSet->at(e)->nullify();
                        if(hL)
                            htSet->at(e)->nullify();

                        continue;
                    }

                    auto indPrev = getBatchTimeTotalIndex(dataFormat, sL, bS, limit - 1, e);
                    lstmLayerCell(xSet->at(indPrev), Wx, Wr, b, h0Set->at(e), c0Set->at(e), Wp, params, hSet->at(indPrev), ctSet->at(e));   // first time step

                    for (int t = limit - 2; t >= 0; --t) {
                        auto indCurr = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e);
                        lstmLayerCell(xSet->at(indCurr), Wx, Wr, b, hSet->at(indPrev), ctSet->at(e), Wp, params, hSet->at(indCurr), ctSet->at(e));  // rest time steps
                        indPrev = indCurr;
                    }

                    if(hL)
                        htSet->at(e)->assign(hSet->at(indPrev));    // assign last output to hL if it is not nullptr

                    tensorAlongTimeBatchDims(*h, dataFormat, limit,sL, e,e+1).nullify();    // nullify for given e and time range [limit, sL)
                }
            }
        }
    }

    delete xSet;
    delete hSet;
    delete h0Set;
    delete c0Set;
    delete htSet;
    delete ctSet;
}


}
}
}

@ -0,0 +1,117 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author Yurii Shyrma (iuriish@yahoo.com)
//

#ifndef LIBND4J_LSTMLAYER_H
#define LIBND4J_LSTMLAYER_H

#include <ops/declarable/helpers/helpers.h>
#include <ops/declarable/helpers/activations.h>

namespace nd4j {
namespace ops {
namespace helpers {

//////////////////////////////////////////////////////////////////////////
void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
                   const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
                   const std::vector<float>& params,
                   NDArray* h, NDArray* c);

//////////////////////////////////////////////////////////////////////////
void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
                       const NDArray* b, const NDArray* seqLen, const NDArray* hI, const NDArray* cI, const NDArray* Wp,
                       const std::vector<float>& params,
                       const bool forward,
                       NDArray* h, NDArray* hL, NDArray* cL);

//////////////////////////////////////////////////////////////////////////
static FORCEINLINE void applyActivation(NDArray& x, const int opId, const float alpha, const float beta, NDArray& z) {

    switch (opId) {
        case 0:
            (const_cast<NDArray&>(x)).applyTransform(transform::Tanh, &z);
            break;
        case 1:
            (const_cast<NDArray&>(x)).applyScalar<float>(scalar::RELU, 0, &z);
            break;
        case 2:
            (const_cast<NDArray&>(x)).applyTransform(transform::Sigmoid, &z);
            break;
        case 3: {
            ExtraArguments args({ static_cast<double>(alpha), static_cast<double>(beta)});
            (const_cast<NDArray&>(x)).applyTransform(transform::Affine, &z, &args);
            break;
        }
        case 4:
            (const_cast<NDArray&>(x)).applyScalar<float>(scalar::LeakyRELU, alpha, &z);
            break;
        case 5:
            helpers::thresholdRelu(x.getContext(), x, alpha, z);
            break;
        case 6: {
            ExtraArguments args({ static_cast<double>(alpha), static_cast<double>(beta)});
            (const_cast<NDArray&>(x)).applyTransform(transform::ScaledTanh, &z, &args);
            break;
        }
        case 7:
            (const_cast<NDArray&>(x)).applyTransform(transform::HardSigmoid, &z);
            break;
        case 8:
            (const_cast<NDArray&>(x)).applyScalar<float>(scalar::ELU, alpha, &z);
            break;
        case 9:
            (const_cast<NDArray&>(x)).applyTransform(transform::SoftSign, &z);
            break;
        case 10:
            (const_cast<NDArray&>(x)).applyTransform(transform::SoftPlus, &z);
            break;
        default:
            throw std::invalid_argument("LSTM_LAYER operation: wrong activation id !");
    }
}

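// Usage sketch for applyActivation (editorial); ids follow the activation table in lstmLayer.cpp:
//     applyActivation(z, 0, 0.f, 0.f, z);        // tanh, in place
//     applyActivation(z, 3, alpha, beta, z);     // affine, alpha/beta passed through ExtraArguments
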
//////////////////////////////////////////////////////////////////////////
static FORCEINLINE NDArray tensorAlongTimeBatchDims(const NDArray& arr, const int dataFormat, const int t1, const int t2, const int b1, const int b2) {

    if(dataFormat == 0 || dataFormat == 3)
        return arr({t1,t2, b1,b2, 0,0});    // TNS: [sL, bS, nIn]

    if(dataFormat == 1)
        return arr({b1,b2, t1,t2, 0,0});    // NTS: [bS, sL, nIn]

    return arr({b1,b2, 0,0, t1,t2});        // NST: [bS, nIn, sL]
}

//////////////////////////////////////////////////////////////////////////
static FORCEINLINE int getBatchTimeTotalIndex(const int dataFormat, const int sL, const int bS, const int t, const int b) {

    if(dataFormat == 0 || dataFormat == 3)
        return t * bS + b;      // TNS: shape [sL, bS, nIn]

    return b * sL + t;          // NTS, NST: shape [bS, sL, nIn], [bS, nIn, sL]
}

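// Index example (editorial): for dataFormat 0 (TNS, [sL, bS, nIn]) with sL = 4 and bS = 3,
// getBatchTimeTotalIndex(0, 4, 3, /*t=*/2, /*b=*/1) = 2*3 + 1 = 7, i.e. the 8th [nIn]
// sub-array in the ResultSet; for NTS/NST layouts the same (t, b) maps to b*sL + t = 1*4 + 2 = 6.
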

}
}
}


#endif //LIBND4J_LSTMLAYER_H

@ -0,0 +1,546 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author Yurii Shyrma (iuriish@yahoo.com)
//

#include <ops/declarable/OpRegistrator.h>
#include "mkldnnUtils.h"

using namespace mkldnn;

namespace nd4j {
namespace ops {
namespace platforms {

static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray* Wr,
                            const NDArray* b, const NDArray* hI, const NDArray* cI,
                            const std::vector<float>& params,
                            NDArray* h, NDArray* hL, NDArray* cL) {

    // equations (no peephole connections)
    // it  = σ(Wxi * xt + Wri * ht-1 + bi)
    // ft  = σ(Wxf * xt + Wrf * ht-1 + bf)
    // c't = tanh(Wxc * xt + Wrc * ht-1 + bc)
    // ct  = ft ◦ ct-1 + it ◦ c't
    // ot  = σ(Wxo * xt + Wro * ht-1 + bo)
    // ht  = ot ◦ tanh(ct)

    // notations:
    // bS   - batch size
    // sL   - sequence length, number of time steps
    // nIn  - input size
    // nOut - output size (hidden size)

    // INPUTS:

    // *******
    // input x:
    // 1) [sL, bS, nIn] when dataFormat == 0

    // *******
    // input weights Wx:
    // 1) [1, 1, nIn, 4*nOut] when directionMode <  2
    // 2) [1, 2, nIn, 4*nOut] when directionMode >= 2

    // *******
    // recurrent weights Wr:
    // 1) [1, 1, nOut, 4*nOut] when directionMode <  2
    // 2) [1, 2, nOut, 4*nOut] when directionMode >= 2

    // *******
    // biases b:
    // 1) [1, 1, 4*nOut] when directionMode <  2
    // 2) [1, 2, 4*nOut] when directionMode >= 2

    // *******
    // initial output hI:
    // 1) [1, 1, bS, nOut] when directionMode <  2
    // 2) [1, 2, bS, nOut] when directionMode >= 2

    // *******
    // initial cell state cI (same shape as in hI):
    // 1) [1, 1, bS, nOut] when directionMode <  2
    // 2) [1, 2, bS, nOut] when directionMode >= 2


    // OUTPUTS:

    // *******
    // output h:
    // 1) [sL, bS, nOut]   when directionMode <= 2 && dataFormat == 0
    // 2) [sL, bS, 2*nOut] when directionMode == 3 && dataFormat == 0

    // *******
    // output at last step hL:
    // 1) [1, 1, bS, nOut] when directionMode <  2
    // 2) [1, 2, bS, nOut] when directionMode >= 2

    // *******
    // cell state at last step cL (same shape as in hL):
    // 1) [1, 1, bS, nOut] when directionMode <  2
    // 2) [1, 2, bS, nOut] when directionMode >= 2

    // !!! dimension 4*nOut implies order it, ft, c't, ot
    // !!! dimension 3*nOut implies order it, ft, ot

    // params = {dataFormat, directionMode, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta};

    // dataFormat:    0 = [sL, bS, nIn]
    // directionMode: 0 = forward, 1 = backward, 2 = bidirectional sum, 3 = bidirectional concat

    const int dataFormat    = params[0];
    const int directionMode = params[1];

    const int sL   = x->sizeAt(0);     // dataFormat == 0 ? x->sizeAt(0) : x->sizeAt(1);
    const int bS   = x->sizeAt(1);     // dataFormat == 0 ? x->sizeAt(1) : x->sizeAt(0);
    const int nIn  = x->sizeAt(-1);
    const int nOut = Wx->sizeAt(-1);

    const int dirDim  = directionMode < 2  ? 1 : 2;     // number of directions: 1 for unidirectional, 2 for bidirectional
    const int hDirDim = directionMode <= 2 ? 1 : 2;     // for h array, take into account bidirectional_sum mode (directionMode == 2)

    // evaluate direction
    rnn_direction direction;
    switch (directionMode) {
        case 0:
            direction = rnn_direction::unidirectional_left2right;
            break;
        case 1:
            direction = rnn_direction::unidirectional_right2left;
            break;
        case 2:
            direction = rnn_direction::bidirectional_sum;
            break;
        default:
            direction = rnn_direction::bidirectional_concat;
    }

auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
|
||||
mkldnn_memory_desc_t empty;
|
||||
|
||||
mkldnn::memory::desc x_user_md, wx_user_md, wr_user_md, b_user_md, hI_user_md, cI_user_md, h_user_md, hL_user_md, cL_user_md,
|
||||
x_lstm_md, wx_lstm_md, wr_lstm_md, b_lstm_md, hI_lstm_md, cI_lstm_md, h_lstm_md, hL_lstm_md, cL_lstm_md;
|
||||
|
||||
// input type
|
||||
mkldnn::memory::data_type xType;
|
||||
if(x->dataType() == DataType::FLOAT32)
|
||||
xType = mkldnn::memory::data_type::f32;
|
||||
else if(x->dataType() == DataType::HALF)
|
||||
xType = mkldnn::memory::data_type::f16;
|
||||
else
|
||||
xType = mkldnn::memory::data_type::u8;
|
||||
|
||||
// weights type
|
||||
mkldnn::memory::data_type wType = xType;
|
||||
if(xType == mkldnn::memory::data_type::u8)
|
||||
wType = mkldnn::memory::data_type::s8;
|
||||
|
||||
// bias type
|
||||
mkldnn::memory::data_type bType = xType;
|
||||
if(xType == mkldnn::memory::data_type::u8)
|
||||
bType = mkldnn::memory::data_type::f32;
|
||||
|
||||
// output type
|
||||
mkldnn::memory::data_type hType;
|
||||
if(h->dataType() == DataType::FLOAT32)
|
||||
hType = mkldnn::memory::data_type::f32;
|
||||
else if(h->dataType() == DataType::HALF)
|
||||
hType = mkldnn::memory::data_type::f16;
|
||||
else
|
||||
hType = mkldnn::memory::data_type::u8;
|
||||
|
||||
|
||||
// memory descriptors for arrays
|
||||
// x
|
||||
x_lstm_md = mkldnn::memory::desc({sL, bS, nIn}, xType, mkldnn::memory::format_tag::any);
|
||||
// x_user_md = dataFormat == 0 ? mkldnn::memory::desc({sL, bS, nIn}, type, mkldnn::memory::format_tag::tnc) : mkldnn::memory::desc({bS, sL, nIn}, type, mkldnn::memory::format_tag::ntc);
|
||||
x_user_md = mkldnn::memory::desc({sL, bS, nIn}, xType, mkldnn::memory::format_tag::tnc);
|
||||
x_user_md.data.format_kind = mkldnn_blocked; // overrides format
|
||||
x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0];
|
||||
x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1];
|
||||
x_user_md.data.format_desc.blocking.strides[2] = x->stridesOf()[2];
|
||||
|
||||
// wx
|
||||
wx_lstm_md = mkldnn::memory::desc({1,dirDim,nIn,4,nOut}, wType, mkldnn::memory::format_tag::any);
|
||||
    wx_user_md = mkldnn::memory::desc({1,dirDim,nIn,4,nOut}, wType, mkldnn::memory::format_tag::ldigo);
    wx_user_md.data.format_kind = mkldnn_blocked;    // overrides format
    wx_user_md.data.format_desc.blocking.strides[0] = Wx->stridesOf()[0];
    wx_user_md.data.format_desc.blocking.strides[1] = Wx->stridesOf()[1];
    wx_user_md.data.format_desc.blocking.strides[2] = Wx->stridesOf()[2];
    wx_user_md.data.format_desc.blocking.strides[3] = Wx->stridesOf()[3];
    wx_user_md.data.format_desc.blocking.strides[4] = Wx->stridesOf()[4];

    // wr
    wr_lstm_md = mkldnn::memory::desc({1,dirDim,nOut,4,nOut}, wType, mkldnn::memory::format_tag::any);
    wr_user_md = mkldnn::memory::desc({1,dirDim,nOut,4,nOut}, wType, mkldnn::memory::format_tag::ldigo);
    wr_user_md.data.format_kind = mkldnn_blocked;    // overrides format
    wr_user_md.data.format_desc.blocking.strides[0] = Wr->stridesOf()[0];
    wr_user_md.data.format_desc.blocking.strides[1] = Wr->stridesOf()[1];
    wr_user_md.data.format_desc.blocking.strides[2] = Wr->stridesOf()[2];
    wr_user_md.data.format_desc.blocking.strides[3] = Wr->stridesOf()[3];
    wr_user_md.data.format_desc.blocking.strides[4] = Wr->stridesOf()[4];

    // h
    h_lstm_md = mkldnn::memory::desc({sL, bS, hDirDim*nOut}, hType, mkldnn::memory::format_tag::any);
    // h_user_md = dataFormat == 0 ? mkldnn::memory::desc({sL, bS, hDirDim*nOut}, type, mkldnn::memory::format_tag::tnc) : mkldnn::memory::desc({bS, sL, hDirDim*nOut}, type, mkldnn::memory::format_tag::ntc);
    h_user_md = mkldnn::memory::desc({sL, bS, hDirDim*nOut}, hType, mkldnn::memory::format_tag::tnc);
    h_user_md.data.format_kind = mkldnn_blocked;    // overrides format
    h_user_md.data.format_desc.blocking.strides[0] = h->stridesOf()[0];
    h_user_md.data.format_desc.blocking.strides[1] = h->stridesOf()[1];
    h_user_md.data.format_desc.blocking.strides[2] = h->stridesOf()[2];

    // b
    if(b) {
        b_lstm_md = mkldnn::memory::desc({1,dirDim,4,nOut}, bType, mkldnn::memory::format_tag::any);
        b_user_md = mkldnn::memory::desc({1,dirDim,4,nOut}, bType, mkldnn::memory::format_tag::ldgo);
        b_user_md.data.format_kind = mkldnn_blocked;    // overrides format
        b_user_md.data.format_desc.blocking.strides[0] = b->stridesOf()[0];
        b_user_md.data.format_desc.blocking.strides[1] = b->stridesOf()[1];
        b_user_md.data.format_desc.blocking.strides[2] = b->stridesOf()[2];
        b_user_md.data.format_desc.blocking.strides[3] = b->stridesOf()[3];
    }

    // hI
    if(hI) {
        hI_lstm_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, xType, mkldnn::memory::format_tag::any);
        hI_user_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, xType, mkldnn::memory::format_tag::ldnc);
        hI_user_md.data.format_kind = mkldnn_blocked;    // overrides format
        hI_user_md.data.format_desc.blocking.strides[0] = hI->stridesOf()[0];
        hI_user_md.data.format_desc.blocking.strides[1] = hI->stridesOf()[1];
        hI_user_md.data.format_desc.blocking.strides[2] = hI->stridesOf()[2];
        hI_user_md.data.format_desc.blocking.strides[3] = hI->stridesOf()[3];
    }

    // cI
    if(cI) {
        cI_lstm_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, xType, mkldnn::memory::format_tag::any);
        cI_user_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, xType, mkldnn::memory::format_tag::ldnc);
        cI_user_md.data.format_kind = mkldnn_blocked;    // overrides format
        cI_user_md.data.format_desc.blocking.strides[0] = cI->stridesOf()[0];
        cI_user_md.data.format_desc.blocking.strides[1] = cI->stridesOf()[1];
        cI_user_md.data.format_desc.blocking.strides[2] = cI->stridesOf()[2];
        cI_user_md.data.format_desc.blocking.strides[3] = cI->stridesOf()[3];
    }

    // hL
    if(hL) {
        hL_lstm_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, hType, mkldnn::memory::format_tag::any);
        hL_user_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, hType, mkldnn::memory::format_tag::ldnc);
        hL_user_md.data.format_kind = mkldnn_blocked;    // overrides format
        hL_user_md.data.format_desc.blocking.strides[0] = hL->stridesOf()[0];
        hL_user_md.data.format_desc.blocking.strides[1] = hL->stridesOf()[1];
        hL_user_md.data.format_desc.blocking.strides[2] = hL->stridesOf()[2];
        hL_user_md.data.format_desc.blocking.strides[3] = hL->stridesOf()[3];
    }

    // cL
    if(cL) {
        cL_lstm_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, hType, mkldnn::memory::format_tag::any);
        cL_user_md = mkldnn::memory::desc({1,dirDim,bS,nOut}, hType, mkldnn::memory::format_tag::ldnc);
        cL_user_md.data.format_kind = mkldnn_blocked;    // overrides format
        cL_user_md.data.format_desc.blocking.strides[0] = cL->stridesOf()[0];
        cL_user_md.data.format_desc.blocking.strides[1] = cL->stridesOf()[1];
        cL_user_md.data.format_desc.blocking.strides[2] = cL->stridesOf()[2];
        cL_user_md.data.format_desc.blocking.strides[3] = cL->stridesOf()[3];
    }

    // lstm memory description
    lstm_forward::desc lstm_desc(prop_kind::forward_inference, direction,
                                 x_lstm_md, hI_lstm_md, cI_lstm_md, wx_lstm_md, wr_lstm_md, b_lstm_md,
                                 h_lstm_md, hL_lstm_md, cL_lstm_md);

    mkldnn::stream stream(engine);

    // lstm primitive description
    lstm_forward::primitive_desc lstm_prim_desc(lstm_desc, engine);

    // arguments (memory buffers) necessary for calculations
    std::unordered_map<int, mkldnn::memory> args;

    // provide memory and check whether reorder is required

    // x
    auto x_user_mem = mkldnn::memory(x_user_md, engine, x->getBuffer());
    const bool xReorder = lstm_prim_desc.src_layer_desc() != x_user_mem.get_desc();
    auto x_lstm_mem = xReorder ? mkldnn::memory(lstm_prim_desc.src_layer_desc(), engine) : x_user_mem;
    if (xReorder)
        reorder(x_user_mem, x_lstm_mem).execute(stream, x_user_mem, x_lstm_mem);
    args[MKLDNN_ARG_SRC_LAYER] = x_lstm_mem;
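
    // Note (illustration only, not part of this patch): the provide-and-reorder
    // idiom above repeats verbatim for every operand below. A hypothetical helper
    // could factor it out; the name provideMklMemory is an assumption, not an
    // existing API:
    //
    //   static mkldnn::memory provideMklMemory(const mkldnn::memory::desc& primDesc,
    //                                          const mkldnn::memory::desc& userDesc,
    //                                          const mkldnn::engine& engine,
    //                                          mkldnn::stream& stream,
    //                                          void* buffer) {
    //       auto userMem = mkldnn::memory(userDesc, engine, buffer);
    //       if (primDesc == userMem.get_desc())
    //           return userMem;                                 // layouts match, no reorder
    //       auto primMem = mkldnn::memory(primDesc, engine);    // opaque mkl-dnn layout
    //       mkldnn::reorder(userMem, primMem).execute(stream, userMem, primMem);
    //       return primMem;
    //   }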

    // wx
    auto wx_user_mem = mkldnn::memory(wx_user_md, engine, Wx->getBuffer());
    const bool wxReorder = lstm_prim_desc.weights_layer_desc() != wx_user_mem.get_desc();
    auto wx_lstm_mem = wxReorder ? mkldnn::memory(lstm_prim_desc.weights_layer_desc(), engine) : wx_user_mem;
    if (wxReorder)
        reorder(wx_user_mem, wx_lstm_mem).execute(stream, wx_user_mem, wx_lstm_mem);
    args[MKLDNN_ARG_WEIGHTS_LAYER] = wx_lstm_mem;

    // wr
    auto wr_user_mem = mkldnn::memory(wr_user_md, engine, Wr->getBuffer());
    const bool wrReorder = lstm_prim_desc.weights_iter_desc() != wr_user_mem.get_desc();
    auto wr_lstm_mem = wrReorder ? mkldnn::memory(lstm_prim_desc.weights_iter_desc(), engine) : wr_user_mem;
    if (wrReorder)
        reorder(wr_user_mem, wr_lstm_mem).execute(stream, wr_user_mem, wr_lstm_mem);
    args[MKLDNN_ARG_WEIGHTS_ITER] = wr_lstm_mem;

    // h
    auto h_user_mem = mkldnn::memory(h_user_md, engine, h->getBuffer());
    const bool hReorder = lstm_prim_desc.dst_layer_desc() != h_user_mem.get_desc();
    auto h_lstm_mem = hReorder ? mkldnn::memory(lstm_prim_desc.dst_layer_desc(), engine) : h_user_mem;
    args[MKLDNN_ARG_DST_LAYER] = h_lstm_mem;

    // b
    if(b) {
        auto b_user_mem = mkldnn::memory(b_user_md, engine, b->getBuffer());
        const bool bReorder = lstm_prim_desc.bias_desc() != b_user_mem.get_desc();
        auto b_lstm_mem = bReorder ? mkldnn::memory(lstm_prim_desc.bias_desc(), engine) : b_user_mem;
        if (bReorder)
            reorder(b_user_mem, b_lstm_mem).execute(stream, b_user_mem, b_lstm_mem);
        args[MKLDNN_ARG_BIAS] = b_lstm_mem;
    }

    // hI
    if(hI) {
        auto hI_user_mem = mkldnn::memory(hI_user_md, engine, hI->getBuffer());
        const bool hIReorder = lstm_prim_desc.src_iter_desc() != hI_user_mem.get_desc();
        auto hI_lstm_mem = hIReorder ? mkldnn::memory(lstm_prim_desc.src_iter_desc(), engine) : hI_user_mem;
        if (hIReorder)
            reorder(hI_user_mem, hI_lstm_mem).execute(stream, hI_user_mem, hI_lstm_mem);
        args[MKLDNN_ARG_SRC_ITER] = hI_lstm_mem;
    }

    // cI
    if(cI) {
        auto cI_user_mem = mkldnn::memory(cI_user_md, engine, cI->getBuffer());
        const bool cIReorder = lstm_prim_desc.src_iter_c_desc() != cI_user_mem.get_desc();
        auto cI_lstm_mem = cIReorder ? mkldnn::memory(lstm_prim_desc.src_iter_c_desc(), engine) : cI_user_mem;
        if (cIReorder)
            reorder(cI_user_mem, cI_lstm_mem).execute(stream, cI_user_mem, cI_lstm_mem);
        args[MKLDNN_ARG_SRC_ITER_C] = cI_lstm_mem;
    }

    bool hLReorder(false), cLReorder(false);
    mkldnn::memory hL_user_mem, cL_user_mem, hL_lstm_mem, cL_lstm_mem;

    // hL
    if(hL) {
        hL_user_mem = mkldnn::memory(hL_user_md, engine, hL->getBuffer());
        hLReorder = lstm_prim_desc.dst_iter_desc() != hL_user_mem.get_desc();
        hL_lstm_mem = hLReorder ? mkldnn::memory(lstm_prim_desc.dst_iter_desc(), engine) : hL_user_mem;
        args[MKLDNN_ARG_DST_ITER] = hL_lstm_mem;
    }

    // cL
    if(cL) {
        cL_user_mem = mkldnn::memory(cL_user_md, engine, cL->getBuffer());
        cLReorder = lstm_prim_desc.dst_iter_c_desc() != cL_user_mem.get_desc();
        cL_lstm_mem = cLReorder ? mkldnn::memory(lstm_prim_desc.dst_iter_c_desc(), engine) : cL_user_mem;
        args[MKLDNN_ARG_DST_ITER_C] = cL_lstm_mem;
    }

    // run calculations
    lstm_forward(lstm_prim_desc).execute(stream, args);

    // reorder outputs if necessary
    if (hReorder)
        reorder(h_lstm_mem, h_user_mem).execute(stream, h_lstm_mem, h_user_mem);
    if (hLReorder)
        reorder(hL_lstm_mem, hL_user_mem).execute(stream, hL_lstm_mem, hL_user_mem);
    if (cLReorder)
        reorder(cL_lstm_mem, cL_user_mem).execute(stream, cL_lstm_mem, cL_user_mem);

    stream.wait();
}

//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(lstmLayer) {

    const auto dataFormat    = INT_ARG(0);    // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL, nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, 2, bS, nOut] (for ONNX)
    const auto directionMode = INT_ARG(1);    // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3)

    const auto hasBiases  = B_ARG(0);    // indicates whether biases array is provided
    const auto hasSeqLen  = B_ARG(1);    // indicates whether seqLen array is provided
    const auto hasInitH   = B_ARG(2);    // indicates whether initial output is provided
    const auto hasInitC   = B_ARG(3);    // indicates whether initial cell state is provided
    const auto hasPH      = B_ARG(4);    // indicates whether peephole connections are present
    const auto retFullSeq = B_ARG(5);    // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1}
    const auto retLastH   = B_ARG(6);    // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)
    const auto retLastC   = B_ARG(7);    // indicates whether to return cell state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)

    const auto cellClip = T_ARG(0);    // cell clipping value; if it is 0 then no clipping is applied

    const auto x  = INPUT_VARIABLE(0);    // input
    const auto Wx = INPUT_VARIABLE(1);    // input weights
    const auto Wr = INPUT_VARIABLE(2);    // recurrent weights

    int count = 3;
    const auto b      = hasBiases ? INPUT_VARIABLE(count++) : nullptr;    // biases
    const auto seqLen = hasSeqLen ? INPUT_VARIABLE(count++) : nullptr;    // seqLen vector
    const auto hI     = hasInitH  ? INPUT_VARIABLE(count++) : nullptr;    // initial output
    const auto cI     = hasInitC  ? INPUT_VARIABLE(count++) : nullptr;    // initial cell state
    const auto Wp     = hasPH     ? INPUT_VARIABLE(count++) : nullptr;    // peephole weights

    REQUIRE_TRUE(cellClip == 0, 0, "LSTM_LAYER_MKLDNN operation: cell clipping is not supported currently !");
    REQUIRE_TRUE(retFullSeq, 0, "LSTM_LAYER_MKLDNN operation: option to calculate full time sequence output h should always be true in case of mkl dnn library !");
    REQUIRE_TRUE(hasPH == false, 0, "LSTM_LAYER_MKLDNN operation: mkl dnn library doesn't support peephole connections !");
    REQUIRE_TRUE(hasSeqLen == false, 0, "LSTM_LAYER_MKLDNN operation: mkl dnn library doesn't support array specifying max time step per each example in batch !");
    REQUIRE_TRUE(dataFormat < 2, 0, "LSTM_LAYER_MKLDNN operation: wrong data format, only two formats are allowed for input/output tensors in mkl dnn library: TNC and NTC !");
    REQUIRE_TRUE(directionMode < 4, 0, "LSTM_LAYER_MKLDNN operation: option for bidirectional extra output dimension is not valid in mkl dnn library !");
    REQUIRE_TRUE((retLastH && retLastC) || (!retLastH && !retLastC), 0, "LSTM_LAYER_MKLDNN operation: only two options are present: 1) calculate both output at last time and cell state at last time; 2) calculate neither !");

    count = 0;
    auto h  = retFullSeq ? OUTPUT_VARIABLE(count++) : nullptr;    // output
    auto hL = retLastH   ? OUTPUT_VARIABLE(count++) : nullptr;    // output at last step
    auto cL = retLastC   ? OUTPUT_VARIABLE(count++) : nullptr;    // cell state at last step

    // evaluate dimensions
    const Nd4jLong sL   = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat);
    const Nd4jLong bS   = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(-2);
    const Nd4jLong nIn  = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(-1);
    const Nd4jLong nOut = Wx->sizeAt(-1) / 4;
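
    // Illustration (not part of the patch): the sizeAt arithmetic above decodes
    // the layouts as follows (negative indices count from the last axis):
    //   dataFormat == 0, x = [sL, bS, nIn] : sL = sizeAt(0), bS = sizeAt(-2), nIn = sizeAt(-1)
    //   dataFormat == 1, x = [bS, sL, nIn] : sL = sizeAt(1), bS = sizeAt(0),  nIn = sizeAt(-1)
    //   dataFormat == 2, x = [bS, nIn, sL] : sL = sizeAt(2), bS = sizeAt(0),  nIn = sizeAt(1)
    // and nOut falls out of the gate-concatenated last axis of Wx, which stacks
    // the four gates: sizeAt(-1) == 4*nOut.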

    // inputs validations
    if(directionMode < 2) {    // unidirectional

        // Wx validation
        if(Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn)
            REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
        // Wr validation
        if(Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4*nOut)
            REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str());
        // biases validation
        if(b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4*nOut))
            REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str());
        // initial output validation
        if(hI != nullptr && (hI->rankOf() != 2 || hI->sizeAt(0) != bS || hI->sizeAt(1) != nOut))
            REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str());
        // initial cell validation
        if(cI != nullptr && (cI->rankOf() != 2 || cI->sizeAt(0) != bS || cI->sizeAt(1) != nOut))
            REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str());
    }
    else {    // bidirectional
        // Wx validation
        if(Wx->rankOf() != 3 || Wx->sizeAt(0) != 2 || Wx->sizeAt(1) != nIn)
            REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nIn, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
        // Wr validation
        if(Wr->rankOf() != 3 || Wr->sizeAt(0) != 2 || Wr->sizeAt(1) != nOut || Wr->sizeAt(2) != 4*nOut)
            REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nOut, 4*nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str());
        // biases validation
        if(b != nullptr && (b->rankOf() != 2 || b->sizeAt(0) != 2 || b->sizeAt(1) != 4*nOut))
            REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 4*nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str());
        // initial output validation
        if(hI != nullptr && (hI->rankOf() != 3 || hI->sizeAt(0) != 2 || hI->sizeAt(1) != bS || hI->sizeAt(2) != nOut))
            REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str());
        // initial cell validation
        if(cI != nullptr && (cI->rankOf() != 3 || cI->sizeAt(0) != 2 || cI->sizeAt(1) != bS || cI->sizeAt(2) != nOut))
            REQUIRE_TRUE(false, 0, "LSTM_LAYER_MKLDNN operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str());
    }

    std::vector<float> params = {static_cast<float>(dataFormat), static_cast<float>(directionMode), static_cast<float>(cellClip)};

    const int dirDim = directionMode < 2 ? 1 : 2;    // number of directions: 1 for unidirectional, 2 for bidirectional

    // permute x and h to tnc format if they have ntc format
    NDArray* xP(const_cast<NDArray*>(x)), *hP(h);
    if(dataFormat == 1) {
        xP = new NDArray(x->permute({1,0,2}));    // [bS, sL, nIn] -> [sL, bS, nIn]
        hP = new NDArray(h->permute({1,0,2}));    // [bS, sL, dirDim*nOut] -> [sL, bS, dirDim*nOut]
    }

    // reshape arrays in accordance with mkl allowed formats
    NDArray *WxR(nullptr), *WrR(nullptr), *bR(nullptr), *hIR(nullptr), *cIR(nullptr), *hLR(nullptr), *cLR(nullptr);

    WxR = new NDArray(Wx->reshape(Wx->ordering(), {1,dirDim,nIn,4,nOut}));
    WrR = new NDArray(Wr->reshape(Wr->ordering(), {1,dirDim,nOut,4,nOut}));
    if(b)
        bR = new NDArray(b->reshape(b->ordering(), {1,dirDim,4,nOut}));
    if(hI)
        hIR = new NDArray(hI->reshape(hI->ordering(), {1,dirDim,bS,nOut}));
    if(cI)
        cIR = new NDArray(cI->reshape(cI->ordering(), {1,dirDim,bS,nOut}));
    if(hL)
        hLR = new NDArray(hL->reshape(hL->ordering(), {1,dirDim,bS,nOut}));
    if(cL)
        cLR = new NDArray(cL->reshape(cL->ordering(), {1,dirDim,bS,nOut}));

    lstmLayerMKLDNN(xP, WxR, WrR, bR, hIR, cIR, params, hP, hLR, cLR);

    delete WxR;
    delete WrR;
    delete bR;
    delete hIR;
    delete cIR;
    delete hLR;
    delete cLR;

    if(dataFormat == 1) {
        delete xP;
        delete hP;
    }

    return Status::OK();
}

PLATFORM_CHECK(lstmLayer) {
    // we don't want to use mkldnn if cpu doesn't support avx/avx2
    // if (::optimalLevel() < 2) {
    //     return false;
    // }

    const auto hasBiases  = B_ARG(0);    // indicates whether biases array is provided
    const auto hasInitH   = B_ARG(2);    // indicates whether initial output is provided
    const auto hasInitC   = B_ARG(3);    // indicates whether initial cell state is provided
    const auto retFullSeq = B_ARG(5);    // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1}
    const auto retLastH   = B_ARG(6);    // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)
    const auto retLastC   = B_ARG(7);    // indicates whether to return cell state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on dataFormat argument)

    const auto x  = INPUT_VARIABLE(0);    // input
    const auto Wx = INPUT_VARIABLE(1);    // input weights
    const auto Wr = INPUT_VARIABLE(2);    // recurrent weights

    int count = 3;
    const auto b  = hasBiases ? INPUT_VARIABLE(count++) : nullptr;    // biases
    const auto hI = hasInitH  ? INPUT_VARIABLE(count++) : nullptr;    // initial output
    const auto cI = hasInitC  ? INPUT_VARIABLE(count++) : nullptr;    // initial cell state

    count = 0;
    auto h  = retFullSeq ? OUTPUT_VARIABLE(count++) : nullptr;    // output
    auto hL = retLastH   ? OUTPUT_VARIABLE(count++) : nullptr;    // output at last step
    auto cL = retLastC   ? OUTPUT_VARIABLE(count++) : nullptr;    // cell state at last step

    DataType xType  = x->dataType();
    DataType WxType = Wx->dataType();
    DataType WrType = Wr->dataType();
    DataType bType  = b  != nullptr ? b->dataType()  : (xType == DataType::HALF ? xType : DataType::FLOAT32);
    DataType hIType = hI != nullptr ? hI->dataType() : xType;
    DataType cIType = cI != nullptr ? cI->dataType() : xType;
    DataType hType  = h  != nullptr ? h->dataType()  : xType;
    DataType hLType = hL != nullptr ? hL->dataType() : xType;
    DataType cLType = cL != nullptr ? cL->dataType() : xType;

    return block.isUseMKLDNN() && (
        (xType==DataType::FLOAT32 && WxType==DataType::FLOAT32 && WrType==DataType::FLOAT32 && bType==DataType::FLOAT32 && hIType==DataType::FLOAT32 && cIType==DataType::FLOAT32 && hType==DataType::FLOAT32 && hLType==DataType::FLOAT32 && cLType==DataType::FLOAT32) ||
        (xType==DataType::HALF && WxType==DataType::HALF && WrType==DataType::HALF && bType==DataType::HALF && hIType==DataType::HALF && cIType==DataType::HALF && hType==DataType::HALF && hLType==DataType::HALF && cLType==DataType::HALF) ||
        (xType==DataType::UINT8 && WxType==DataType::INT8 && WrType==DataType::INT8 && bType==DataType::FLOAT32 && hIType==DataType::UINT8 && cIType==DataType::UINT8 && ((hType==DataType::FLOAT32 && hLType==DataType::FLOAT32 && cLType==DataType::FLOAT32) || (hType==DataType::UINT8 && hLType==DataType::UINT8 && cLType==DataType::UINT8)))
    );
}

}
}
}

@@ -63,6 +63,8 @@ namespace nd4j{
        DECLARE_PLATFORM(lrn);

        DECLARE_PLATFORM(batchnorm_new);

        DECLARE_PLATFORM(lstmLayer);
    }
}

@@ -1857,6 +1857,17 @@ namespace simdOps {
        }
    };

    template <typename X>
    class Affine {
    public:
        no_op_exec_special_same
        no_op_exec_special_same_cuda

        op_def static X op(X d1, X *params) {
            return params[0] * d1 + params[1];
        }
    };
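
    // Illustration (not part of the patch): Affine computes f(x) = a*x + b with
    // a = params[0] and b = params[1]; e.g. for params = {3, 0.5}:
    //   Affine<float>::op(2.0f, params) == 3.0f * 2.0f + 0.5f == 6.5f
    // Presumably added here as one of lstmLayer's configurable activation options.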

    template <typename X>
    class SigmoidDerivative {
    public:

@@ -2051,6 +2062,17 @@ namespace simdOps {
        }
    };

    template <typename X>
    class ScaledTanh {
    public:
        no_op_exec_special_same
        no_op_exec_special_same_cuda

        op_def static X op(X d1, X *params) {
            return params[0] * nd4j::math::nd4j_tanh<X, X>(params[1] * d1);
        }
    };
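
    // Illustration (not part of the patch): ScaledTanh computes f(x) = a*tanh(b*x)
    // with a = params[0] and b = params[1]; for a = b = 1 it degenerates to plain
    // tanh. Presumably the second of lstmLayer's new configurable activations.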

    template <typename X>
    class RectifiedTanh {
    public:

@@ -983,5 +983,952 @@ TEST_F(DeclarableOpsTests13, mergemax_2) {
    ASSERT_EQ(20, status);
}

///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_1) {

    const int sL = 5;
    const int bS = 3;
    const int nIn = 3;
    const int nOut = 3;

    // input arguments

    const int dataFormat = 0;       // [sL,bS,nIn]
    const int directionMode = 0;    // forward
    const int gateAct = 2;          // sigmoid activation for input (i), forget (f) and output (o) gates
    const int cellAct = 0;          // tanh activation for cell state
    const int outAct = 0;           // tanh activation for output

    const bool hasBiases = true;    // biases array is provided
    const bool hasSeqLen = false;   // seqLen array is not provided
    const auto hasInitH = true;     // initial output is provided
    const auto hasInitC = true;     // initial cell state is provided
    const auto hasPH = false;       // peephole connections are absent
    const auto retFullSeq = true;   // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut]
    const auto retLastH = true;     // return output at last time step
    const auto retLastC = true;     // return cell state at last time step

    const double cellClip = 0;      // do not apply clipping

    NDArray x('c', {sL, bS, nIn}, nd4j::DataType::FLOAT32);
    NDArray Wx('c', {nIn, 4*nOut}, nd4j::DataType::FLOAT32);
    NDArray Wr('c', {nOut, 4*nOut}, nd4j::DataType::FLOAT32);
    NDArray b('c', {4*nOut}, nd4j::DataType::FLOAT32);
    NDArray hI('c', {bS, nOut}, nd4j::DataType::FLOAT32);
    NDArray cI('c', {bS, nOut}, nd4j::DataType::FLOAT32);

    x.linspace(0.5, 0.5);
    Wx = 0.003;
    Wr = 0.006;
    b = 0.5;
    hI = 1.;
    cI = 2.;

    std::initializer_list<double> tArgs = {cellClip};
    std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
    std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};

    auto expH = NDArrayFactory::create<float>('c', {sL, bS, nOut}, {0.57574,0.57574,0.57574,0.58006,0.58006,0.58006,0.58434,0.58434,0.58434,
                                                                    0.55114,0.55114,0.55114,0.55732,0.55732,0.55732,0.56338,0.56338,0.56338,
                                                                    0.53763,0.53763,0.53763,0.54534,0.54534,0.54534,0.55287,0.55287,0.55287,
                                                                    0.53626,0.53626,0.53626,0.54487,0.54487,0.54487,0.55327,0.55327,0.55327,
                                                                    0.54484,0.54484,0.54484,0.55379,0.55379,0.55379,0.5625,0.5625,0.5625});

    auto expClast = NDArrayFactory::create<float>('c', {bS, nOut}, {1.1589154,1.1589154,1.1589154,1.1892855,1.1892855,1.1892855,1.219861,1.219861,1.219861});

    nd4j::ops::lstmLayer op;
    auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *h = results->at(0);
    auto *cL = results->at(2);

    ASSERT_TRUE(expH.isSameShape(h));
    ASSERT_TRUE(expH.equalsTo(h));

    ASSERT_TRUE(expClast.isSameShape(cL));
    ASSERT_TRUE(expClast.equalsTo(cL));

    delete results;
}
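
// Sanity check of the expected values (illustration only, assuming the standard
// LSTM cell equations with the sigmoid/tanh activations selected above). At t = 0,
// first batch entry: x = {0.5, 1.0, 1.5}, so every gate pre-activation is
//   z = 0.003*(0.5 + 1.0 + 1.5) + 0.006*(1 + 1 + 1) + 0.5 = 0.527
// giving i = f = o = sigmoid(0.527) ~ 0.6288 and c~ = tanh(0.527) ~ 0.4831, hence
//   c_0 = f*cI + i*c~ ~ 0.6288*2 + 0.6288*0.4831 ~ 1.5614
//   h_0 = o*tanh(c_0) ~ 0.6288*0.9157        ~ 0.5757
// which matches expH's leading entries (0.57574).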

///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_2) {

    const int sL = 5;
    const int bS = 3;
    const int nIn = 3;
    const int nOut = 3;

    // input arguments

    const int dataFormat = 1;       // [bS,sL,nIn]
    const int directionMode = 0;    // forward
    const int gateAct = 2;          // sigmoid activation for input (i), forget (f) and output (o) gates
    const int cellAct = 0;          // tanh activation for cell state
    const int outAct = 0;           // tanh activation for output

    const bool hasBiases = true;    // biases array is provided
    const bool hasSeqLen = false;   // seqLen array is not provided
    const auto hasInitH = true;     // initial output is provided
    const auto hasInitC = true;     // initial cell state is provided
    const auto hasPH = false;       // peephole connections are absent
    const auto retFullSeq = true;   // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut]
    const auto retLastH = true;     // return output at last time step
    const auto retLastC = true;     // return cell state at last time step

    const double cellClip = 0;      // do not apply clipping

    NDArray x('c', {bS, sL, nIn}, nd4j::DataType::FLOAT32);
    NDArray Wx('c', {nIn, 4*nOut}, nd4j::DataType::FLOAT32);
    NDArray Wr('c', {nOut, 4*nOut}, nd4j::DataType::FLOAT32);
    NDArray b('c', {4*nOut}, nd4j::DataType::FLOAT32);
    NDArray hI('c', {bS, nOut}, nd4j::DataType::FLOAT32);
    NDArray cI('c', {bS, nOut}, nd4j::DataType::FLOAT32);

    x.linspace(0.5, 0.5);
    Wx = 0.003;
    Wr = 0.006;
    b = 0.5;
    hI = 1.;
    cI = 2.;

    std::initializer_list<double> tArgs = {cellClip};
    std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
    std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};

    auto expH = NDArrayFactory::create<float>('c', {bS, sL, nOut}, {0.575735, 0.575735, 0.575735, 0.541562, 0.541562, 0.541562, 0.514003, 0.514003, 0.514003, 0.495597, 0.495597, 0.495597, 0.485999, 0.485999, 0.485999,
                                                                    0.596965, 0.596965, 0.596965, 0.571978, 0.571978, 0.571978, 0.552888, 0.552888, 0.552888, 0.540606, 0.540606, 0.540606, 0.534764, 0.534764, 0.534764,
                                                                    0.61725, 0.61725, 0.61725, 0.599828, 0.599828, 0.599828, 0.587627, 0.587627, 0.587627, 0.580408, 0.580408, 0.580408, 0.577735, 0.577735, 0.577735});

    auto expClast = NDArrayFactory::create<float>('c', {bS, nOut}, {0.996965, 0.996965, 0.996965, 1.146756, 1.146756, 1.146756, 1.301922, 1.301922, 1.301922});

    nd4j::ops::lstmLayer op;
    auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *h = results->at(0);
    auto *cL = results->at(2);

    ASSERT_TRUE(expH.isSameShape(h));
    ASSERT_TRUE(expH.equalsTo(h));

    ASSERT_TRUE(expClast.isSameShape(cL));
    ASSERT_TRUE(expClast.equalsTo(cL));

    delete results;
}

///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_3) {

    const int sL = 5;
    const int bS = 2;
    const int nIn = 4;
    const int nOut = 3;

    // input arguments

    const int dataFormat = 0;       // [sL,bS,nIn]
    const int directionMode = 1;    // backward
    const int gateAct = 2;          // sigmoid activation for input (i), forget (f) and output (o) gates
    const int cellAct = 0;          // tanh activation for cell state
    const int outAct = 0;           // tanh activation for output

    const bool hasBiases = true;    // biases array is provided
    const bool hasSeqLen = false;   // seqLen array is not provided
    const auto hasInitH = true;     // initial output is provided
    const auto hasInitC = true;     // initial cell state is provided
    const auto hasPH = false;       // peephole connections are absent
    const auto retFullSeq = true;   // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut]
    const auto retLastH = true;     // return output at last time step
    const auto retLastC = true;     // return cell state at last time step

    const double cellClip = 0;      // do not apply clipping

    NDArray x('c', {sL, bS, nIn}, nd4j::DataType::FLOAT32);
    NDArray Wx('c', {nIn, 4*nOut}, nd4j::DataType::FLOAT32);
    NDArray Wr('c', {nOut, 4*nOut}, nd4j::DataType::FLOAT32);
    NDArray b('c', {4*nOut}, nd4j::DataType::FLOAT32);
    NDArray hI('c', {bS, nOut}, nd4j::DataType::FLOAT32);
    NDArray cI('c', {bS, nOut}, nd4j::DataType::FLOAT32);

    x.linspace(0.5, 0.5);
    Wx = 0.003;
    Wr = 0.006;
    b = 0.5;
    hI = 1.;
    cI = 2.;

    std::initializer_list<double> tArgs = {cellClip};
    std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
    std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};

    NDArray expH('c', {sL, bS, nOut}, {0.493883, 0.493883, 0.493883, 0.510990, 0.510990, 0.510990, 0.534701, 0.534701, 0.534701, 0.549139,
                                       0.549139, 0.549139, 0.571900, 0.571900, 0.571900, 0.583561, 0.583561, 0.583561, 0.605106, 0.605106,
                                       0.605106, 0.614114, 0.614114, 0.614114, 0.635354, 0.635354, 0.635354, 0.642045, 0.642045, 0.642045}, nd4j::DataType::FLOAT32);

    NDArray expHL('c', {bS, nOut}, {0.493883, 0.493883, 0.493883, 0.510990, 0.510990, 0.510990}, nd4j::DataType::FLOAT32);
    NDArray expCL('c', {bS, nOut}, {1.061274, 1.061274, 1.061274, 1.115888, 1.115888, 1.115888}, nd4j::DataType::FLOAT32);

    nd4j::ops::lstmLayer op;
    auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto h = results->at(0);
    auto hL = results->at(1);
    auto cL = results->at(2);

    ASSERT_TRUE(expH.isSameShape(h));
    ASSERT_TRUE(expH.equalsTo(h));

    ASSERT_TRUE(expHL.isSameShape(hL));
    ASSERT_TRUE(expHL.equalsTo(hL));

    ASSERT_TRUE(expCL.isSameShape(cL));
    ASSERT_TRUE(expCL.equalsTo(cL));

    delete results;
}

///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_4) {

    const int sL = 5;
    const int bS = 2;
    const int nIn = 4;
    const int nOut = 3;

    // input arguments
    const int dataFormat = 0;       // [sL,bS,nIn]
    const int directionMode = 3;    // bidirectional concat
    const int gateAct = 2;          // sigmoid activation for input (i), forget (f) and output (o) gates
    const int cellAct = 0;          // tanh activation for cell state
    const int outAct = 0;           // tanh activation for output

    const bool hasBiases = true;    // biases array is provided
    const bool hasSeqLen = false;   // seqLen array is not provided
    const auto hasInitH = true;     // initial output is provided
    const auto hasInitC = true;     // initial cell state is provided
    const auto hasPH = false;       // peephole connections are absent
    const auto retFullSeq = true;   // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut]
    const auto retLastH = true;     // return output at last time step
    const auto retLastC = true;     // return cell state at last time step

    const double cellClip = 0;      // do not apply clipping

    NDArray x('c', {sL, bS, nIn}, nd4j::DataType::FLOAT32);
    NDArray Wx('c', {2, nIn, 4*nOut}, nd4j::DataType::FLOAT32);
    NDArray Wr('c', {2, nOut, 4*nOut}, nd4j::DataType::FLOAT32);
    NDArray b('c', {2, 4*nOut}, nd4j::DataType::FLOAT32);
    NDArray hI('c', {2, bS, nOut}, nd4j::DataType::FLOAT32);
    NDArray cI('c', {2, bS, nOut}, nd4j::DataType::FLOAT32);

    x.linspace(0.5, 0.5);
    Wx({0,1, 0,0, 0,0}) = 0.003;
    Wx({1,2, 0,0, 0,0}) = -0.003;
    Wr({0,1, 0,0, 0,0}) = 0.006;
    Wr({1,2, 0,0, 0,0}) = -0.006;
    b({0,1, 0,0}) = 0.5;
    b({1,2, 0,0}) = -0.5;
    hI({0,1, 0,0, 0,0}) = 1;
    hI({1,2, 0,0, 0,0}) = -1;
    cI({0,1, 0,0, 0,0}) = 2;
    cI({1,2, 0,0, 0,0}) = -2;

    std::initializer_list<double> tArgs = {cellClip};
    std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
    std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};

    NDArray expH('c', {sL, bS, 2*nOut}, {0.577661, 0.577661, 0.577661, -0.107642, -0.107642, -0.107642, 0.585289, 0.585289, 0.585289,
                                         -0.106937, -0.106937, -0.106937, 0.556517, 0.556517, 0.556517, -0.111647, -0.111647, -0.111647,
                                         0.567274, 0.567274, 0.567274, -0.110214, -0.110214, -0.110214, 0.547395, 0.547395, 0.547395,
                                         -0.123305, -0.123305, -0.123305, 0.560640, 0.560640, 0.560640, -0.120862, -0.120862, -0.120862,
                                         0.550714, 0.550714, 0.550714, -0.156223, -0.156223, -0.156223, 0.565308, 0.565308, 0.565308,
                                         -0.152313, -0.152313, -0.152313, 0.563741, 0.563741, 0.563741, -0.234128, -0.234128, -0.234128,
                                         0.578676, 0.578676, 0.578676, -0.228917, -0.228917, -0.228917}, nd4j::DataType::FLOAT32);

    NDArray expHL('c', {2, bS, nOut}, {0.563741, 0.563741, 0.563741, 0.578676, 0.578676, 0.578676, -0.107642,
                                       -0.107642, -0.107642, -0.106937, -0.106937, -0.106937}, nd4j::DataType::FLOAT32);
    NDArray expCL('c', {2, bS, nOut}, {1.217757, 1.217757, 1.217757, 1.272398, 1.272398, 1.272398, -0.295768,
                                       -0.295768, -0.295768, -0.298453, -0.298453, -0.298453}, nd4j::DataType::FLOAT32);

    nd4j::ops::lstmLayer op;
    auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto h = results->at(0);
    auto hL = results->at(1);
    auto cL = results->at(2);

    ASSERT_TRUE(expH.isSameShape(h));
    ASSERT_TRUE(expH.equalsTo(h));

    ASSERT_TRUE(expHL.isSameShape(hL));
    ASSERT_TRUE(expHL.equalsTo(hL));

    ASSERT_TRUE(expCL.isSameShape(cL));
    ASSERT_TRUE(expCL.equalsTo(cL));

    delete results;
}

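// Note (illustration): the interval-style assignments above, e.g.
// Wx({0,1, 0,0, 0,0}) = 0.003, write a constant into a sub-array selected by
// {start,end} pairs per dimension ({0,0} meaning the whole axis), so the
// forward-direction slice Wx[0] receives 0.003 and the backward slice Wx[1]
// receives -0.003.
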
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_5) {

    const int sL = 5;
    const int bS = 2;
    const int nIn = 4;
    const int nOut = 3;

    // input arguments
    const int dataFormat = 1;       // [bS,sL,nIn]
    const int directionMode = 3;    // bidirectional concat
    const int gateAct = 2;          // sigmoid activation for input (i), forget (f) and output (o) gates
    const int cellAct = 0;          // tanh activation for cell state
    const int outAct = 0;           // tanh activation for output

    const bool hasBiases = true;    // biases array is provided
    const bool hasSeqLen = false;   // seqLen array is not provided
    const auto hasInitH = true;     // initial output is provided
    const auto hasInitC = true;     // initial cell state is provided
    const auto hasPH = false;       // peephole connections are absent
    const auto retFullSeq = true;   // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut]
    const auto retLastH = true;     // return output at last time step
    const auto retLastC = true;     // return cell state at last time step

    const double cellClip = 0;      // do not apply clipping

    NDArray x('c', {bS, sL, nIn}, nd4j::DataType::FLOAT32);
    NDArray Wx('c', {2, nIn, 4*nOut}, nd4j::DataType::FLOAT32);
    NDArray Wr('c', {2, nOut, 4*nOut}, nd4j::DataType::FLOAT32);
    NDArray b('c', {2, 4*nOut}, nd4j::DataType::FLOAT32);
    NDArray hI('c', {2, bS, nOut}, nd4j::DataType::FLOAT32);
    NDArray cI('c', {2, bS, nOut}, nd4j::DataType::FLOAT32);

    x.linspace(0.5, 0.5);
    Wx({0,1, 0,0, 0,0}) = 0.003;
    Wx({1,2, 0,0, 0,0}) = -0.003;
    Wr({0,1, 0,0, 0,0}) = 0.006;
    Wr({1,2, 0,0, 0,0}) = -0.006;
    b({0,1, 0,0}) = 0.5;
    b({1,2, 0,0}) = -0.5;
    hI({0,1, 0,0, 0,0}) = 1;
    hI({1,2, 0,0, 0,0}) = -1;
    cI({0,1, 0,0, 0,0}) = 2;
    cI({1,2, 0,0, 0,0}) = -2;

    std::initializer_list<double> tArgs = {cellClip};
    std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
    std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};

    NDArray expH('c', {bS, sL, 2*nOut}, {0.577661, 0.577661, 0.577661, -0.107659, -0.107659, -0.107659, 0.548099, 0.548099, 0.548099, -0.113406, -0.113406, -0.113406,
                                         0.526881, 0.526881, 0.526881, -0.12883, -0.12883, -0.12883, 0.515882, 0.515882, 0.515882, -0.16868, -0.16868, -0.16868,
                                         0.51409, 0.51409, 0.51409, -0.255185, -0.255185, -0.255185, 0.614599, 0.614599, 0.614599, -0.102739, -0.102739, -0.102739,
                                         0.599572, 0.599572, 0.599572, -0.105802, -0.105802, -0.105802, 0.591089, 0.591089, 0.591089, -0.116681, -0.116681, -0.116681,
                                         0.588694, 0.588694, 0.588694, -0.149201, -0.149201, -0.149201, 0.591492, 0.591492, 0.591492, -0.228917, -0.228917, -0.228917}, nd4j::DataType::FLOAT32);

    NDArray expHL('c', {2, bS, nOut}, {0.51409, 0.51409, 0.51409, 0.591492, 0.591492, 0.591492,
                                       -0.107659, -0.107659, -0.107659, -0.102739, -0.102739, -0.102739}, nd4j::DataType::FLOAT32);
    NDArray expCL('c', {2, bS, nOut}, {1.07293, 1.07293, 1.07293, 1.346609, 1.346609, 1.346609,
                                       -0.295811, -0.295811, -0.295811, -0.305394, -0.305394, -0.305394}, nd4j::DataType::FLOAT32);

    nd4j::ops::lstmLayer op;
    auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto h = results->at(0);
    auto hL = results->at(1);
    auto cL = results->at(2);

    // h->printBuffer();
    // hL->printBuffer();
    // cL->printBuffer();

    ASSERT_TRUE(expH.isSameShape(h));
    ASSERT_TRUE(expH.equalsTo(h));

    ASSERT_TRUE(expHL.isSameShape(hL));
    ASSERT_TRUE(expHL.equalsTo(hL));

    ASSERT_TRUE(expCL.isSameShape(cL));
    ASSERT_TRUE(expCL.equalsTo(cL));

    delete results;
}

///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_6) {

    const int sL = 5;
    const int bS = 2;
    const int nIn = 4;
    const int nOut = 3;

    // input arguments
    const int dataFormat = 0;       // [sL,bS,nIn]
    const int directionMode = 2;    // bidirectional sum
    const int gateAct = 2;          // sigmoid activation for input (i), forget (f) and output (o) gates
    const int cellAct = 0;          // tanh activation for cell state
    const int outAct = 0;           // tanh activation for output

    const bool hasBiases = true;    // biases array is provided
    const bool hasSeqLen = false;   // seqLen array is not provided
    const auto hasInitH = true;     // initial output is provided
    const auto hasInitC = true;     // initial cell state is provided
    const auto hasPH = false;       // peephole connections are absent
    const auto retFullSeq = true;   // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut]
    const auto retLastH = true;     // return output at last time step
    const auto retLastC = true;     // return cell state at last time step

    const double cellClip = 0;      // do not apply clipping

    NDArray x('c', {sL, bS, nIn}, nd4j::DataType::FLOAT32);
    NDArray Wx('c', {2, nIn, 4*nOut}, nd4j::DataType::FLOAT32);
    NDArray Wr('c', {2, nOut, 4*nOut}, nd4j::DataType::FLOAT32);
    NDArray b('c', {2, 4*nOut}, nd4j::DataType::FLOAT32);
    NDArray hI('c', {2, bS, nOut}, nd4j::DataType::FLOAT32);
    NDArray cI('c', {2, bS, nOut}, nd4j::DataType::FLOAT32);

    x.linspace(0.5, 0.5);
    Wx({0,1, 0,0, 0,0}) = 0.003;
    Wx({1,2, 0,0, 0,0}) = -0.003;
    Wr({0,1, 0,0, 0,0}) = 0.006;
    Wr({1,2, 0,0, 0,0}) = -0.006;
    b({0,1, 0,0}) = 0.5;
    b({1,2, 0,0}) = -0.5;
    hI({0,1, 0,0, 0,0}) = 1;
    hI({1,2, 0,0, 0,0}) = -1;
    cI({0,1, 0,0, 0,0}) = 2;
    cI({1,2, 0,0, 0,0}) = -2;

    std::initializer_list<double> tArgs = {cellClip};
    std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
    std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};

    NDArray expH('c', {sL, bS, nOut}, {0.470019, 0.470019, 0.470019, 0.478352, 0.478352, 0.478352, 0.444871, 0.444871, 0.444871, 0.457060,
                                       0.457060, 0.457060, 0.424090, 0.424090, 0.424090, 0.439778, 0.439778, 0.439778, 0.394491, 0.394491,
                                       0.394491, 0.412995, 0.412995, 0.412995, 0.329613, 0.329613, 0.329613, 0.349760, 0.349760, 0.349760}, nd4j::DataType::FLOAT32);

    NDArray expHL('c', {2, bS, nOut}, {0.563741, 0.563741, 0.563741, 0.578676, 0.578676, 0.578676, -0.107642,
                                       -0.107642, -0.107642, -0.106937, -0.106937, -0.106937}, nd4j::DataType::FLOAT32);
    NDArray expCL('c', {2, bS, nOut}, {1.217757, 1.217757, 1.217757, 1.272398, 1.272398, 1.272398, -0.295768,
                                       -0.295768, -0.295768, -0.298453, -0.298453, -0.298453}, nd4j::DataType::FLOAT32);

    nd4j::ops::lstmLayer op;
    auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs);

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto h = results->at(0);
    auto hL = results->at(1);
    auto cL = results->at(2);

    ASSERT_TRUE(expH.isSameShape(h));
    ASSERT_TRUE(expH.equalsTo(h));

    ASSERT_TRUE(expHL.isSameShape(hL));
    ASSERT_TRUE(expHL.equalsTo(hL));

    ASSERT_TRUE(expCL.isSameShape(cL));
    ASSERT_TRUE(expCL.equalsTo(cL));

    delete results;
}

///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_7) {
#ifndef HAVE_MKLDNN

    const int sL = 5;
    const int bS = 2;
    const int nIn = 4;
    const int nOut = 3;

    // input arguments

    const int dataFormat = 0;       // [sL,bS,nIn]
    const int directionMode = 0;    // forward
    const int gateAct = 2;          // sigmoid activation for input (i), forget (f) and output (o) gates
    const int cellAct = 0;          // tanh activation for cell state
    const int outAct = 0;           // tanh activation for output

    const bool hasBiases = true;    // biases array is provided
    const bool hasSeqLen = false;   // seqLen array is not provided
    const auto hasInitH = true;     // initial output is provided
    const auto hasInitC = true;     // initial cell state is provided
    const auto hasPH = true;        // peephole connections are provided
    const auto retFullSeq = true;   // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut]
    const auto retLastH = true;     // return output at last time step
    const auto retLastC = true;     // return cell state at last time step

    const double cellClip = 0;      // do not apply clipping

    NDArray x('c', {sL, bS, nIn}, nd4j::DataType::FLOAT32);
    NDArray Wx('c', {nIn, 4*nOut}, nd4j::DataType::FLOAT32);
    NDArray Wr('c', {nOut, 4*nOut}, nd4j::DataType::FLOAT32);
    NDArray b('c', {4*nOut}, nd4j::DataType::FLOAT32);
    NDArray hI('c', {bS, nOut}, nd4j::DataType::FLOAT32);
    NDArray cI('c', {bS, nOut}, nd4j::DataType::FLOAT32);
    NDArray Wp('c', {3*nOut}, nd4j::DataType::FLOAT32);

    x.linspace(0.5, 0.5);
    Wx = 0.003;
    Wr = 0.006;
    b = 0.5;
    hI = 1.;
    cI = 2.;
    Wp = -0.05;

    std::initializer_list<double> tArgs = {cellClip};
    std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
    std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};

    NDArray expH('c', {sL, bS, nOut}, {0.55533, 0.55533, 0.55533, 0.562925, 0.562925, 0.562925, 0.531795, 0.531795, 0.531795, 0.542556,
                                       0.542556, 0.542556, 0.521466, 0.521466, 0.521466, 0.534638, 0.534638, 0.534638, 0.524805, 0.524805,
                                       0.524805, 0.539187, 0.539187, 0.539187, 0.538309, 0.538309, 0.538309, 0.552923, 0.552923, 0.552923}, nd4j::DataType::FLOAT32);

    NDArray expHL('c', {bS, nOut}, {0.538309, 0.538309, 0.538309, 0.552923, 0.552923, 0.552923}, nd4j::DataType::FLOAT32);
    NDArray expCL('c', {bS, nOut}, {1.147089, 1.147089, 1.147089, 1.197228, 1.197228, 1.197228}, nd4j::DataType::FLOAT32);

    nd4j::ops::lstmLayer op;
    auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto h = results->at(0);
    auto hL = results->at(1);
    auto cL = results->at(2);

    ASSERT_TRUE(expH.isSameShape(h));
    ASSERT_TRUE(expH.equalsTo(h));

    ASSERT_TRUE(expHL.isSameShape(hL));
    ASSERT_TRUE(expHL.equalsTo(hL));

    ASSERT_TRUE(expCL.isSameShape(cL));
    ASSERT_TRUE(expCL.equalsTo(cL));

    delete results;
#endif
}
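
// Note: lstmLayer_7 through lstmLayer_11 are wrapped in #ifndef HAVE_MKLDNN
// because they exercise peephole weights and per-example sequence lengths,
// which the mkl-dnn platform implementation above explicitly rejects via its
// REQUIRE_TRUE(hasPH == false, ...) and REQUIRE_TRUE(hasSeqLen == false, ...)
// checks; only the own-code path can run them.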

///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_8) {
#ifndef HAVE_MKLDNN

    const int sL = 5;
    const int bS = 2;
    const int nIn = 4;
    const int nOut = 3;

    // input arguments

    const int dataFormat = 0;       // [sL,bS,nIn]
    const int directionMode = 1;    // backward
    const int gateAct = 2;          // sigmoid activation for input (i), forget (f) and output (o) gates
    const int cellAct = 0;          // tanh activation for cell state
    const int outAct = 0;           // tanh activation for output

    const bool hasBiases = true;    // biases array is provided
    const bool hasSeqLen = false;   // seqLen array is not provided
    const auto hasInitH = true;     // initial output is provided
    const auto hasInitC = true;     // initial cell state is provided
    const auto hasPH = true;        // peephole connections are provided
    const auto retFullSeq = true;   // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut]
    const auto retLastH = true;     // return output at last time step
    const auto retLastC = true;     // return cell state at last time step

    const double cellClip = 1.;     // apply clipping with value 1

    NDArray x('c', {sL, bS, nIn}, nd4j::DataType::FLOAT32);
    NDArray Wx('c', {nIn, 4*nOut}, nd4j::DataType::FLOAT32);
    NDArray Wr('c', {nOut, 4*nOut}, nd4j::DataType::FLOAT32);
    NDArray b('c', {4*nOut}, nd4j::DataType::FLOAT32);
    NDArray hI('c', {bS, nOut}, nd4j::DataType::FLOAT32);
    NDArray cI('c', {bS, nOut}, nd4j::DataType::FLOAT32);
    NDArray Wp('c', {3*nOut}, nd4j::DataType::FLOAT32);

    x.linspace(0.5, 0.5);
    Wx = 0.003;
    Wr = 0.006;
    b = 0.5;
    hI = 1.;
    cI = 2.;
    Wp = -0.05;

    std::initializer_list<double> tArgs = {cellClip};
    std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
    std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};

    NDArray expH('c', {sL, bS, nOut}, {0.436221, 0.436221, 0.436221, 0.450573, 0.450573, 0.450573, 0.463602, 0.463602, 0.463602, 0.474674, 0.474674, 0.474674,
                                       0.484039, 0.484039, 0.484039, 0.490679, 0.490679, 0.490679, 0.494871, 0.494871, 0.494871, 0.499028, 0.499028, 0.499028,
                                       0.504649, 0.504649, 0.504649, 0.508719, 0.508719, 0.508719}, nd4j::DataType::FLOAT32);

    NDArray expHL('c', {bS, nOut}, {0.436221, 0.436221, 0.436221, 0.450573, 0.450573, 0.450573}, nd4j::DataType::FLOAT32);
    NDArray expCL('c', {bS, nOut}, {0.879804, 0.879804, 0.879804, 0.914666, 0.914666, 0.914666}, nd4j::DataType::FLOAT32);

    nd4j::ops::lstmLayer op;
    auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto h = results->at(0);
    auto hL = results->at(1);
    auto cL = results->at(2);

    ASSERT_TRUE(expH.isSameShape(h));
    ASSERT_TRUE(expH.equalsTo(h));

    ASSERT_TRUE(expHL.isSameShape(hL));
    ASSERT_TRUE(expHL.equalsTo(hL));

    ASSERT_TRUE(expCL.isSameShape(cL));
    ASSERT_TRUE(expCL.equalsTo(cL));

    delete results;
#endif
}

///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_9) {
#ifndef HAVE_MKLDNN

    const int sL = 5;
    const int bS = 2;
    const int nIn = 4;
    const int nOut = 3;

    // input arguments
    const int dataFormat = 0;       // [sL,bS,nIn]
    const int directionMode = 3;    // bidirectional concat
    const int gateAct = 2;          // sigmoid activation for input (i), forget (f) and output (o) gates
    const int cellAct = 0;          // tanh activation for cell state
    const int outAct = 0;           // tanh activation for output

    const bool hasBiases = true;    // biases array is provided
    const bool hasSeqLen = false;   // seqLen array is not provided
    const auto hasInitH = true;     // initial output is provided
    const auto hasInitC = true;     // initial cell state is provided
    const auto hasPH = true;        // peephole connections are provided
    const auto retFullSeq = true;   // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut]
    const auto retLastH = true;     // return output at last time step
    const auto retLastC = true;     // return cell state at last time step

    const double cellClip = 0;      // do not apply clipping

    NDArray x('c', {sL, bS, nIn}, nd4j::DataType::FLOAT32);
    NDArray Wx('c', {2, nIn, 4*nOut}, nd4j::DataType::FLOAT32);
    NDArray Wr('c', {2, nOut, 4*nOut}, nd4j::DataType::FLOAT32);
    NDArray b('c', {2, 4*nOut}, nd4j::DataType::FLOAT32);
    NDArray hI('c', {2, bS, nOut}, nd4j::DataType::FLOAT32);
    NDArray cI('c', {2, bS, nOut}, nd4j::DataType::FLOAT32);
    NDArray Wp('c', {2, 3*nOut}, nd4j::DataType::FLOAT32);

    x.linspace(0.5, 0.5);
    Wx({0,1, 0,0, 0,0}) = 0.003;
    Wx({1,2, 0,0, 0,0}) = -0.003;
    Wr({0,1, 0,0, 0,0}) = 0.006;
    Wr({1,2, 0,0, 0,0}) = -0.006;
    b({0,1, 0,0}) = 0.5;
    b({1,2, 0,0}) = -0.5;
    hI({0,1, 0,0, 0,0}) = 1;
    hI({1,2, 0,0, 0,0}) = -1;
    cI({0,1, 0,0, 0,0}) = 2;
    cI({1,2, 0,0, 0,0}) = -2;
    Wp({0,1, 0,0}) = -0.05;
    Wp({1,2, 0,0}) = 0.05;

    std::initializer_list<double> tArgs = {cellClip};
    std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
    std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};

    NDArray expH('c', {sL, bS, 2*nOut}, {0.55533, 0.55533, 0.55533, -0.104502, -0.104502, -0.104502, 0.562925, 0.562925, 0.562925, -0.103843, -0.103843, -0.103843,
                                         0.531795, 0.531795, 0.531795, -0.107456, -0.107456, -0.107456, 0.542556, 0.542556, 0.542556, -0.106139, -0.106139, -0.106139,
                                         0.521466, 0.521466, 0.521466, -0.11681, -0.11681, -0.11681, 0.534638, 0.534638, 0.534638, -0.11458, -0.11458, -0.11458,
                                         0.524805, 0.524805, 0.524805, -0.145177, -0.145177, -0.145177, 0.539187, 0.539187, 0.539187, -0.14157, -0.14157, -0.14157,
                                         0.538309, 0.538309, 0.538309, -0.218056, -0.218056, -0.218056, 0.552923, 0.552923, 0.552923, -0.213068, -0.213068, -0.213068}, nd4j::DataType::FLOAT32);

    NDArray expHL('c', {2, bS, nOut}, {0.538309, 0.538309, 0.538309, 0.552923, 0.552923, 0.552923, -0.104502, -0.104502, -0.104502,
                                       -0.103843, -0.103843, -0.103843}, nd4j::DataType::FLOAT32);
    NDArray expCL('c', {2, bS, nOut}, {1.147089, 1.147089, 1.147089, 1.197228, 1.197228, 1.197228, -0.289425, -0.289425, -0.289425,
                                       -0.292174, -0.292174, -0.292174}, nd4j::DataType::FLOAT32);

    nd4j::ops::lstmLayer op;
    auto results = op.execute({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto h = results->at(0);
    auto hL = results->at(1);
    auto cL = results->at(2);

    ASSERT_TRUE(expH.isSameShape(h));
    ASSERT_TRUE(expH.equalsTo(h));

    ASSERT_TRUE(expHL.isSameShape(hL));
    ASSERT_TRUE(expHL.equalsTo(hL));

    ASSERT_TRUE(expCL.isSameShape(cL));
    ASSERT_TRUE(expCL.equalsTo(cL));

    delete results;
#endif
}

///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_10) {
#ifndef HAVE_MKLDNN

    const int sL = 6;
    const int bS = 5;
    const int nIn = 4;
    const int nOut = 3;

    // input arguments
    const int dataFormat = 0;       // [sL,bS,nIn]
    const int directionMode = 0;    // forward
    const int gateAct = 2;          // sigmoid activation for input (i), forget (f) and output (o) gates
    const int cellAct = 0;          // tanh activation for cell state
    const int outAct = 0;           // tanh activation for output

    const bool hasBiases = true;    // biases array is provided
    const bool hasSeqLen = true;    // seqLen array is provided
    const auto hasInitH = true;     // initial output is provided
    const auto hasInitC = true;     // initial cell state is provided
    const auto hasPH = true;        // peephole connections are provided
    const auto retFullSeq = true;   // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut]
    const auto retLastH = true;     // return output at last time step
    const auto retLastC = true;     // return cell state at last time step

    const double cellClip = 0;      // do not apply clipping

    NDArray x('c', {sL, bS, nIn}, nd4j::DataType::FLOAT32);
    NDArray Wx('c', {nIn, 4*nOut}, nd4j::DataType::FLOAT32);
    NDArray Wr('c', {nOut, 4*nOut}, nd4j::DataType::FLOAT32);
    NDArray b('c', {4*nOut}, nd4j::DataType::FLOAT32);
    NDArray hI('c', {bS, nOut}, nd4j::DataType::FLOAT32);
    NDArray cI('c', {bS, nOut}, nd4j::DataType::FLOAT32);
    NDArray seqLen('c', {bS}, {0,1,2,3,5}, nd4j::DataType::FLOAT32);
    NDArray Wp('c', {3*nOut}, nd4j::DataType::FLOAT32);

    x.linspace(0.5, 0.5);
    Wx = 0.003;
    Wr = 0.006;
    b = 0.5;
    hI = 1.;
    cI = 2.;
    Wp = -0.05;

    std::initializer_list<double> tArgs = {cellClip};
    std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
    std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};

    NDArray expH('c', {sL, bS, nOut}, {0., 0., 0., 0.562925, 0.562925, 0.562925, 0.570404, 0.570404, 0.570404, 0.57777, 0.57777, 0.57777, 0.585023, 0.585023, 0.585023,
                                       0., 0., 0., 0., 0., 0., 0.576568, 0.576568, 0.576568, 0.586163, 0.586163, 0.586163, 0.595462, 0.595462, 0.595462,
                                       0., 0., 0., 0., 0., 0., 0., 0., 0., 0.611224, 0.611224, 0.611224, 0.621298, 0.621298, 0.621298,
                                       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.655858, 0.655858, 0.655858,
                                       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.692315, 0.692315, 0.692315,
                                       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}, nd4j::DataType::FLOAT32);

    NDArray expHL('c', {bS, nOut}, {0., 0., 0., 0.562925, 0.562925, 0.562925, 0.576568, 0.576568, 0.576568, 0.611224, 0.611224, 0.611224, 0.692315, 0.692315, 0.692315}, nd4j::DataType::FLOAT32);
    NDArray expCL('c', {bS, nOut}, {0., 0., 0., 1.534275, 1.534275, 1.534275, 1.40183, 1.40183, 1.40183, 1.449675, 1.449675, 1.449675, 1.767702, 1.767702, 1.767702}, nd4j::DataType::FLOAT32);

    nd4j::ops::lstmLayer op;
    auto results = op.execute({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto h = results->at(0);
    auto hL = results->at(1);
    auto cL = results->at(2);

    ASSERT_TRUE(expH.isSameShape(h));
    ASSERT_TRUE(expH.equalsTo(h));

    ASSERT_TRUE(expHL.isSameShape(hL));
    ASSERT_TRUE(expHL.equalsTo(hL));

    ASSERT_TRUE(expCL.isSameShape(cL));
    ASSERT_TRUE(expCL.equalsTo(cL));

    delete results;
#endif
}
|
||||
///////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests13, lstmLayer_11) {
|
||||
#ifndef HAVE_MKLDNN
|
||||
|
||||
const int sL = 6;
|
||||
const int bS = 5;
|
||||
const int nIn = 4;
|
||||
const int nOut = 3;
|
||||
|
||||
// input arguments
|
||||
const int dataFormat = 0; // [sL,bS,nIn]
|
||||
const int directionMode = 1; // backward
|
||||
const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
|
||||
const int cellAct = 0; // tanh activation for cell state
|
||||
const int outAct = 0; // tanh activation for output
|
||||
|
const bool hasBiases  = true; // biases array is provided
const bool hasSeqLen  = true; // seqLen array is provided
const auto hasInitH   = true; // initial output is provided
const auto hasInitC   = true; // initial cell state is provided
const auto hasPH      = true; // peephole connections are provided
const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut]
const auto retLastH   = true; // return output at last time step
const auto retLastC   = true; // return cell state at last time step

const double cellClip = 0; // do not apply clipping

NDArray x('c', {sL, bS, nIn}, nd4j::DataType::FLOAT32);
NDArray Wx('c', {nIn, 4*nOut}, nd4j::DataType::FLOAT32);
NDArray Wr('c', {nOut, 4*nOut}, nd4j::DataType::FLOAT32);
NDArray b('c', {4*nOut}, nd4j::DataType::FLOAT32);
NDArray hI('c', {bS, nOut}, nd4j::DataType::FLOAT32);
NDArray cI('c', {bS, nOut}, nd4j::DataType::FLOAT32);
NDArray seqLen('c', {bS}, {0,1,2,3,5}, nd4j::DataType::FLOAT32);
NDArray Wp('c', {3*nOut}, nd4j::DataType::FLOAT32);

x.linspace(0.5, 0.5);
Wx = 0.003;
Wr = 0.006;
b = 0.5;
hI = 1.;
cI = 2.;
Wp = -0.05;

std::initializer_list<double> tArgs = {cellClip};
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
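// directionMode = 1 (backward): the layer traverses the time steps in reverse order,
// which is why the expected outputs below differ from the forward-mode test above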

NDArray expH('c', {sL, bS, nOut}, {0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.61209,
                                   0.61209, 0.61209, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.652042, 0.652042, 0.652042, 0., 0., 0., 0., 0.,
                                   0., 0., 0., 0., 0.677708, 0.677708, 0.677708, 0.684177, 0.684177, 0.684177, 0., 0., 0., 0., 0., 0., 0.699627, 0.699627,
                                   0.699627, 0.705371, 0.705371, 0.705371, 0.710989, 0.710989, 0.710989, 0., 0., 0., 0.719014, 0.719014, 0.719014, 0.724087,
                                   0.724087, 0.724087, 0.729084, 0.729084, 0.729084, 0.734004, 0.734004, 0.734004}, nd4j::DataType::FLOAT32);

NDArray expHL('c', {bS, nOut}, {0., 0., 0., 0.719014, 0.719014, 0.719014, 0.699627, 0.699627, 0.699627, 0.677708, 0.677708, 0.677708, 0.61209, 0.61209, 0.61209}, nd4j::DataType::FLOAT32);
NDArray expCL('c', {bS, nOut}, {0., 0., 0., 2.092814, 2.092814, 2.092814, 2.08832, 2.08832, 2.08832, 2.009851, 2.009851, 2.009851, 1.646034, 1.646034, 1.646034}, nd4j::DataType::FLOAT32);

nd4j::ops::lstmLayer op;
auto results = op.execute({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);

ASSERT_EQ(ND4J_STATUS_OK, results->status());

auto h = results->at(0);
auto hL = results->at(1);
auto cL = results->at(2);

ASSERT_TRUE(expH.isSameShape(h));
ASSERT_TRUE(expH.equalsTo(h));

ASSERT_TRUE(expHL.isSameShape(hL));
ASSERT_TRUE(expHL.equalsTo(hL));

ASSERT_TRUE(expCL.isSameShape(cL));
ASSERT_TRUE(expCL.equalsTo(cL));

delete results;
#endif
}

///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests13, lstmLayer_12) {
#ifndef HAVE_MKLDNN

const int sL = 6;
const int bS = 5;
const int nIn = 4;
const int nOut = 3;

// input arguments
const int dataFormat = 0; // [sL,bS,nIn]
const int directionMode = 3; // bidirectional concat
const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
const int cellAct = 0; // tanh activation for cell state
const int outAct = 0; // tanh activation for output

const bool hasBiases  = true; // biases array is provided
const bool hasSeqLen  = true; // seqLen array is provided
const auto hasInitH   = true; // initial output is provided
const auto hasInitC   = true; // initial cell state is provided
const auto hasPH      = true; // peephole connections are provided
const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut]
const auto retLastH   = true; // return output at last time step
const auto retLastC   = true; // return cell state at last time step

const double cellClip = 0; // do not apply clipping

NDArray x('c', {sL, bS, nIn}, nd4j::DataType::FLOAT32);
NDArray Wx('c', {2, nIn, 4*nOut}, nd4j::DataType::FLOAT32);
NDArray Wr('c', {2, nOut, 4*nOut}, nd4j::DataType::FLOAT32);
NDArray b('c', {2, 4*nOut}, nd4j::DataType::FLOAT32);
NDArray hI('c', {2, bS, nOut}, nd4j::DataType::FLOAT32);
NDArray cI('c', {2, bS, nOut}, nd4j::DataType::FLOAT32);
NDArray seqLen('c', {bS}, {0,1,2,3,5}, nd4j::DataType::FLOAT32);
NDArray Wp('c', {2, 3*nOut}, nd4j::DataType::FLOAT32);

x.linspace(0.5, 0.5);
Wx({0,1, 0,0, 0,0}) = 0.003;
Wx({1,2, 0,0, 0,0}) = -0.003;
Wr({0,1, 0,0, 0,0}) = 0.006;
Wr({1,2, 0,0, 0,0}) = -0.006;
b({0,1, 0,0}) = 0.5;
b({1,2, 0,0}) = -0.5;
hI({0,1, 0,0, 0,0}) = 1;
hI({1,2, 0,0, 0,0}) = -1;
cI({0,1, 0,0, 0,0}) = 2;
cI({1,2, 0,0, 0,0}) = -2;
Wp({0,1, 0,0}) = -0.05;
Wp({1,2, 0,0}) = 0.05;

std::initializer_list<double> tArgs = {cellClip};
std::initializer_list<Nd4jLong> iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct};
std::initializer_list<bool> bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC};
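// directionMode = 3 (bidirectional concat): the leading dimension of size 2 in Wx, Wr, b, hI, cI and Wp
// holds separate forward and backward parameter sets, and h concatenates both directions to [sL,bS,2*nOut]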

NDArray expH('c', {sL, bS, 2*nOut}, {0., 0., 0., 0., 0., 0., 0.562925, 0.562925, 0.562925, -0.25361, -0.25361, -0.25361, 0.570404, 0.570404, 0.570404, -0.157103,
                                     -0.157103, -0.157103, 0.57777, 0.57777, 0.57777, -0.116502, -0.116502, -0.116502, 0.585023, 0.585023, 0.585023, -0.100025,
                                     -0.100025, -0.100025, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.576568, 0.576568, 0.576568, -0.223072, -0.223072, -0.223072,
                                     0.586163, 0.586163, 0.586163, -0.135714, -0.135714, -0.135714, 0.595462, 0.595462, 0.595462, -0.094438, -0.094438, -0.094438,
                                     0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.611224, 0.611224, 0.611224, -0.193473, -0.193473, -0.193473,
                                     0.621298, 0.621298, 0.621298, -0.090626, -0.090626, -0.090626, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                                     0., 0., 0., 0., 0., 0., 0.655858, 0.655858, 0.655858, -0.098015, -0.098015, -0.098015, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                                     0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.692315, 0.692315, 0.692315, -0.143704, -0.143704, -0.143704, 0., 0., 0., 0., 0., 0.,
                                     0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}, nd4j::DataType::FLOAT32);

NDArray expHL('c', {2, bS, nOut}, {0., 0., 0., 0.562925, 0.562925, 0.562925, 0.576568, 0.576568, 0.576568, 0.611224, 0.611224, 0.611224, 0.692315, 0.692315, 0.692315,
                                   0., 0., 0., -0.25361, -0.25361, -0.25361, -0.157103, -0.157103, -0.157103, -0.116502, -0.116502, -0.116502, -0.100025, -0.100025, -0.100025}, nd4j::DataType::FLOAT32);
NDArray expCL('c', {2, bS, nOut}, {0., 0., 0., 1.534275, 1.534275, 1.534275, 1.40183, 1.40183, 1.40183, 1.449675, 1.449675, 1.449675, 1.767702, 1.767702, 1.767702,
                                   0., 0., 0., -0.86636, -0.86636, -0.86636, -0.470245, -0.470245, -0.470245, -0.341856, -0.341856, -0.341856, -0.294986, -0.294986, -0.294986}, nd4j::DataType::FLOAT32);

nd4j::ops::lstmLayer op;
auto results = op.execute({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs);

ASSERT_EQ(ND4J_STATUS_OK, results->status());

auto h = results->at(0);
auto hL = results->at(1);
auto cL = results->at(2);

ASSERT_TRUE(expH.isSameShape(h));
ASSERT_TRUE(expH.equalsTo(h));

ASSERT_TRUE(expHL.isSameShape(hL));
ASSERT_TRUE(expHL.equalsTo(hL));

ASSERT_TRUE(expCL.isSameShape(cL));
ASSERT_TRUE(expCL.equalsTo(cL));

delete results;
#endif
}

@@ -505,9 +505,9 @@ TEST_F(DeclarableOpsTests15, test_lstmBlock_1) {
}

TEST_F(DeclarableOpsTests15, test_lstmBlock_2) {
int seqLen = 32;
int bS = 64;
int nIn = 32;
int seqLen = 8;
int bS = 16;
int nIn = 8;

auto x0 = NDArrayFactory::create<Nd4jLong>(5);
auto x1 = NDArrayFactory::create<float>('f', {bS, nIn, seqLen});

@@ -28,6 +28,7 @@
#include <MmulHelper.h>
#include <GradCheck.h>
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/lstmLayer.h>

using namespace nd4j;

@@ -2342,5 +2343,155 @@ TEST_F(HelpersTests1, softmaxDerivative_3) {
ASSERT_TRUE(expOutput.equalsTo(output));
}

///////////////////////////////////////////////////////////////////
TEST_F(HelpersTests1, lstmLayerCell_1) {

const int bS = 2;
const int nIn = 10;
const int nOut = 4;

const float dataFormat = 0; // is ignored in cell op
const float cellClip = 5; // clipping value
const float gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
const float gateAlpha = 0; // alpha value for gate activation, not required for sigmoid
const float gateBeta = 0; // beta value for gate activation, not required for sigmoid
const float cellAct = 0; // tanh activation for cell state
const float cellAlpha = 0; // alpha value for cell state activation, not required for tanh
const float cellBeta = 0; // beta value for cell state activation, not required for tanh
const float outAct = 0; // tanh activation for output
const float outAlpha = 0; // alpha value for output activation, not required for tanh
const float outBeta = 0; // beta value for output activation, not required for tanh

NDArray x ('c', {bS, nIn}, nd4j::DataType::FLOAT32);
NDArray Wx('c', {nIn, 4*nOut}, nd4j::DataType::FLOAT32);
NDArray Wr('c', {nOut, 4*nOut}, nd4j::DataType::FLOAT32);
NDArray b ('c', {4*nOut}, nd4j::DataType::FLOAT32);
NDArray hI('c', {bS, nOut}, nd4j::DataType::FLOAT32);
NDArray cI('c', {bS, nOut}, nd4j::DataType::FLOAT32);
NDArray Wp('c', {3*nOut}, nd4j::DataType::FLOAT32);

NDArray h('c', {bS, nOut}, nd4j::DataType::FLOAT32);
NDArray c('c', {bS, nOut}, nd4j::DataType::FLOAT32);

NDArray expH('c', {bS, nOut}, {0.999288, 0.999288, 0.999288, 0.999288, 0.999288, 0.999288, 0.999288, 0.999288}, nd4j::DataType::FLOAT32);
NDArray expC('c', {bS, nOut}, {3.999778, 3.999778, 3.999778, 3.999778, 3.999778, 3.999778, 3.999778, 3.999778}, nd4j::DataType::FLOAT32);
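// rough sanity check of the expected values: with these inputs all gate pre-activations are large,
// so the gates are ~1 and the candidate ~ tanh(large) ~ 1, giving c ~ cI*f + candidate*i ~ 3 + 1 = 4
// (below cellClip = 5) and h ~ tanh(4)*o ~ 0.9993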

std::vector<float> params = {dataFormat, 0, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta};
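// note: the second params entry is hardcoded to 0 here; judging by the layer-level iArgs ordering it
// presumably occupies the directionMode slot, which a single-step cell helper has no use for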

x = 1.;
hI = 2.;
cI = 3.;
Wx = 0.5;
Wr = 0.4;
Wp = 0.3;
b = 0.7;

nd4j::ops::helpers::lstmLayerCell(&x, &Wx, &Wr, &b, &hI, &cI, &Wp, params, &h, &c);

ASSERT_TRUE(expH.isSameShape(h));
ASSERT_TRUE(expH.equalsTo(h));
ASSERT_TRUE(expC.isSameShape(c));
ASSERT_TRUE(expC.equalsTo(c));
}

///////////////////////////////////////////////////////////////////
TEST_F(HelpersTests1, lstmLayerCell_2) {

const int bS = 2;
const int nIn = 10;
const int nOut = 4;

const float dataFormat = 0; // is ignored in cell op
const float cellClip = 3; // clipping value
const float gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
const float gateAlpha = 0; // alpha value for gate activation, not required for sigmoid
const float gateBeta = 0; // beta value for gate activation, not required for sigmoid
const float cellAct = 0; // tanh activation for cell state
const float cellAlpha = 0; // alpha value for cell state activation, not required for tanh
const float cellBeta = 0; // beta value for cell state activation, not required for tanh
const float outAct = 0; // tanh activation for output
const float outAlpha = 0; // alpha value for output activation, not required for tanh
const float outBeta = 0; // beta value for output activation, not required for tanh

NDArray x ('c', {bS, nIn}, nd4j::DataType::FLOAT32);
NDArray Wx('c', {nIn, 4*nOut}, nd4j::DataType::FLOAT32);
NDArray Wr('c', {nOut, 4*nOut}, nd4j::DataType::FLOAT32);
NDArray b ('c', {4*nOut}, nd4j::DataType::FLOAT32);
NDArray hI('c', {bS, nOut}, nd4j::DataType::FLOAT32);
NDArray cI('c', {bS, nOut}, nd4j::DataType::FLOAT32);
NDArray Wp('c', {3*nOut}, nd4j::DataType::FLOAT32);

NDArray h('c', {bS, nOut}, nd4j::DataType::FLOAT32);
NDArray c('c', {bS, nOut}, nd4j::DataType::FLOAT32);

NDArray expH('c', {bS, nOut}, {0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995}, nd4j::DataType::FLOAT32);
NDArray expC('c', {bS, nOut}, {3., 3., 3., 3., 3., 3., 3., 3.}, nd4j::DataType::FLOAT32);
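// here cellClip = 3 while the unclipped cell state would again approach 4 (see lstmLayerCell_1),
// so c is clamped to exactly 3 and h ~ tanh(3) ~ 0.995 with a nearly saturated output gate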

std::vector<float> params = {dataFormat, 0, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta};

x = 1.;
hI = 2.;
cI = 3.;
Wx = 0.5;
Wr = 0.4;
Wp = 0.3;
b = 0.7;

nd4j::ops::helpers::lstmLayerCell(&x, &Wx, &Wr, &b, &hI, &cI, &Wp, params, &h, &c);

ASSERT_TRUE(expH.isSameShape(h));
ASSERT_TRUE(expH.equalsTo(h));
ASSERT_TRUE(expC.isSameShape(c));
ASSERT_TRUE(expC.equalsTo(c));
}

///////////////////////////////////////////////////////////////////
TEST_F(HelpersTests1, lstmLayerCell_3) {

const int nIn = 10;
const int nOut = 4;

const float dataFormat = 0; // is ignored in cell op
const float cellClip = 5; // clipping value
const float gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates
const float gateAlpha = 0; // alpha value for gate activation, not required for sigmoid
const float gateBeta = 0; // beta value for gate activation, not required for sigmoid
const float cellAct = 0; // tanh activation for cell state
const float cellAlpha = 0; // alpha value for cell state activation, not required for tanh
const float cellBeta = 0; // beta value for cell state activation, not required for tanh
const float outAct = 0; // tanh activation for output
const float outAlpha = 0; // alpha value for output activation, not required for tanh
const float outBeta = 0; // beta value for output activation, not required for tanh

NDArray x ('c', {nIn}, nd4j::DataType::FLOAT32);
NDArray Wx('c', {nIn, 4*nOut}, nd4j::DataType::FLOAT32);
NDArray Wr('c', {nOut, 4*nOut}, nd4j::DataType::FLOAT32);
NDArray b ('c', {4*nOut}, nd4j::DataType::FLOAT32);
NDArray hI('c', {nOut}, nd4j::DataType::FLOAT32);
NDArray cI('c', {nOut}, nd4j::DataType::FLOAT32);
NDArray Wp('c', {3*nOut}, nd4j::DataType::FLOAT32);

NDArray h('c', {nOut}, nd4j::DataType::FLOAT32);
NDArray c('c', {nOut}, nd4j::DataType::FLOAT32);

NDArray expH('c', {nOut}, {0.999288, 0.999288, 0.999288, 0.999288}, nd4j::DataType::FLOAT32);
NDArray expC('c', {nOut}, {3.999778, 3.999778, 3.999778, 3.999778}, nd4j::DataType::FLOAT32);
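// the expected values match lstmLayerCell_1 exactly since the inputs are the same; this test only
// checks that the helper also accepts rank-1 arrays (a single sample, no batch dimension)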

std::vector<float> params = {dataFormat, 0, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta};

x = 1.;
hI = 2.;
cI = 3.;
Wx = 0.5;
Wr = 0.4;
Wp = 0.3;
b = 0.7;

nd4j::ops::helpers::lstmLayerCell(&x, &Wx, &Wr, &b, &hI, &cI, &Wp, params, &h, &c);

ASSERT_TRUE(expH.isSameShape(h));
ASSERT_TRUE(expH.equalsTo(h));
ASSERT_TRUE(expC.isSameShape(c));
ASSERT_TRUE(expC.equalsTo(c));
}