2021-02-09 05:16:31 +01:00
/*
* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
*/
2019-10-17 19:44:52 +02:00
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//
2020-03-02 10:49:41 +01:00
# include <system/op_boilerplate.h>
2019-10-17 19:44:52 +02:00
# if NOT_EXCLUDED(OP_lstmLayer)
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/lstmLayer.h>

#include <memory>
2020-04-13 12:21:51 +02:00
2020-03-02 10:49:41 +01:00
namespace sd {
2019-10-17 19:44:52 +02:00
namespace ops {
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(lstmLayer, 3, 1, false, 1, 5) {
    // Forward pass of an LSTM layer over a whole time sequence, either in a single
    // direction (fwd/bwd) or bidirectionally (sum / concat / extra output dimension).
    //
    // In the equations below σ, tanh stand for the configurable activations:
    // gateAct (i, f, o gates), cellAct (cell candidate), outAct (applied to ct for ht).
    //
    // equations (no peephole connections)
    // it  = σ(Wxi * xt + Wri * ht-1 + bi)
    // ft  = σ(Wxf * xt + Wrf * ht-1 + bf)
    // c't = tanh(Wxc * xt + Wrc * ht-1 + bc)
    // ct  = ft ◦ ct-1 + it ◦ c't
    // ot  = σ(Wxo * xt + Wro * ht-1 + bo)
    // ht  = ot ◦ tanh(ct)
    //
    // equations (peephole connections are present)
    // it  = σ(Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi)
    // ft  = σ(Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf)
    // c't = tanh(Wxc * xt + Wrc * ht-1 + bc)
    // ct  = clip(ft ◦ ct-1 + it ◦ c't)
    // ot  = σ(Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo)
    // ht  = ot ◦ tanh(ct)
    // (clip is applied only when cellClip > 0)
    //
    // notations:
    // bS   - batch size
    // sL   - sequence length, number of time steps
    // nIn  - input size
    // nOut - output size (hidden size)
    //
    // INPUTS (read in this order: x, Wx, Wr, then the optional b, seqLen, hI, cI, Wp,
    //         each one present only when its B_ARG flag is set):
    // *******
    // input x:
    //  1) [sL, bS, nIn] when dataFormat == 0
    //  2) [bS, sL, nIn] when dataFormat == 1
    //  3) [bS, nIn, sL] when dataFormat == 2
    //  4) [sL, bS, nIn] when dataFormat == 3
    // *******
    // input weights Wx:
    //  1) [nIn, 4*nOut]    when directionMode <  2
    //  2) [2, nIn, 4*nOut] when directionMode >= 2
    // *******
    // recurrent weights Wr:
    //  1) [nOut, 4*nOut]    when directionMode <  2
    //  2) [2, nOut, 4*nOut] when directionMode >= 2
    // *******
    // biases b, optional:
    //  1) [4*nOut]    when directionMode <  2
    //  2) [2, 4*nOut] when directionMode >= 2
    // *******
    // sequence length array seqLen, optional:
    //  1) [bS]
    // *******
    // initial output hI, optional:
    //  1) [bS, nOut]    when directionMode <  2
    //  2) [2, bS, nOut] when directionMode >= 2
    // *******
    // initial cell state cI (same shape as in hI), optional:
    //  1) [bS, nOut]    when directionMode <  2
    //  2) [2, bS, nOut] when directionMode >= 2
    // *******
    // peephole weights Wp, optional:
    //  1) [3*nOut]    when directionMode <  2
    //  2) [2, 3*nOut] when directionMode >= 2
    //
    // OUTPUTS (emitted in this order, each only when its B_ARG flag is set):
    // *******
    // output h, optional:
    //  1) [sL, bS, nOut]    when directionMode <= 2 && dataFormat == 0
    //  2) [bS, sL, nOut]    when directionMode <= 2 && dataFormat == 1
    //  3) [bS, nOut, sL]    when directionMode <= 2 && dataFormat == 2
    //  4) [sL, bS, 2*nOut]  when directionMode == 3 && dataFormat == 0
    //  5) [bS, sL, 2*nOut]  when directionMode == 3 && dataFormat == 1
    //  6) [bS, 2*nOut, sL]  when directionMode == 3 && dataFormat == 2
    //  7) [sL, 2, bS, nOut] when directionMode == 4 && dataFormat == 3
    // *******
    // output at last step hL, optional:
    //  1) [bS, nOut]    when directionMode <  2
    //  2) [2, bS, nOut] when directionMode >= 2
    // *******
    // cell state at last step cL (same shape as in hL), optional:
    //  1) [bS, nOut]    when directionMode <  2
    //  2) [2, bS, nOut] when directionMode >= 2
    //
    // !!! dimension 4*nOut implies order it, ft, c't, ot
    // !!! dimension 3*nOut implies order it, ft, ot

    const auto dataFormat    = INT_ARG(0);  // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, bS, nIn] && [sL, 2, bS, nOut] (for ONNX)
    const auto directionMode = INT_ARG(1);  // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3)

    // integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
    const auto gateAct = INT_ARG(2);  // activation for input (i), forget (f) and output (o) gates
    const auto cellAct = INT_ARG(3);  // activation for cell state (c)
    const auto outAct  = INT_ARG(4);  // activation for output (h)

    const auto hasBiases  = B_ARG(0);  // indicates whether biases array is provided
    const auto hasSeqLen  = B_ARG(1);  // indicates whether seqLen array is provided
    const auto hasInitH   = B_ARG(2);  // indicates whether initial output is provided
    const auto hasInitC   = B_ARG(3);  // indicates whether initial cell state is provided
    const auto hasPH      = B_ARG(4);  // indicates whether peephole connections are present
    const auto retFullSeq = B_ARG(5);  // indicates whether to return whole time sequence h {h_0, h_1, ... , h_sL-1}
    const auto retLastH   = B_ARG(6);  // indicates whether to return output at last time step only
    const auto retLastC   = B_ARG(7);  // indicates whether to return cells state at last time step only

    // which activation codes consume an alpha and/or beta T-argument
    const auto actHasAlpha = [](const Nd4jLong act) { return act == 3 || act == 4 || act == 5 || act == 6 || act == 8; };
    const auto actHasBeta  = [](const Nd4jLong act) { return act == 3 || act == 6; };

    const auto gateActHasAlpha = actHasAlpha(gateAct);
    const auto cellActHasAlpha = actHasAlpha(cellAct);
    const auto outActHasAlpha  = actHasAlpha(outAct);
    const auto gateActHasBeta  = actHasBeta(gateAct);
    const auto cellActHasBeta  = actHasBeta(cellAct);
    const auto outActHasBeta   = actHasBeta(outAct);

    // T-args: cellClip first, then only the alpha/beta values the chosen activations need,
    // in the fixed order gate(alpha, beta), cell(alpha, beta), out(alpha, beta)
    int tArgIdx = 1;
    const auto cellClip  = T_ARG(0);  // cell clipping value, if it = 0 then do not apply clipping
    const auto gateAlpha = gateActHasAlpha ? T_ARG(tArgIdx++) : 0;
    const auto gateBeta  = gateActHasBeta  ? T_ARG(tArgIdx++) : 0;
    const auto cellAlpha = cellActHasAlpha ? T_ARG(tArgIdx++) : 0;
    const auto cellBeta  = cellActHasBeta  ? T_ARG(tArgIdx++) : 0;
    const auto outAlpha  = outActHasAlpha  ? T_ARG(tArgIdx++) : 0;
    const auto outBeta   = outActHasBeta   ? T_ARG(tArgIdx++) : 0;

    const auto x  = INPUT_VARIABLE(0);  // input
    const auto Wx = INPUT_VARIABLE(1);  // input weights
    const auto Wr = INPUT_VARIABLE(2);  // recurrent weights

    // optional inputs are packed densely after Wr, so the index advances only for present ones
    int inIdx = 3;
    const auto b      = hasBiases ? INPUT_VARIABLE(inIdx++) : nullptr;  // biases
    const auto seqLen = hasSeqLen ? INPUT_VARIABLE(inIdx++) : nullptr;  // seqLen vector
    const auto hI     = hasInitH  ? INPUT_VARIABLE(inIdx++) : nullptr;  // initial output
    const auto cI     = hasInitC  ? INPUT_VARIABLE(inIdx++) : nullptr;  // initial cell state
    const auto Wp     = hasPH     ? INPUT_VARIABLE(inIdx++) : nullptr;  // peephole weights

    // arguments validation (casts to int match the %i format specifier)
    REQUIRE_TRUE(dataFormat >= 0 && dataFormat <= 3, 0, "LSTM_LAYER operation: argument dataFormat must be in range [0, 3], but got %i instead !", static_cast<int>(dataFormat));
    REQUIRE_TRUE(directionMode >= 0 && directionMode <= 4, 0, "LSTM_LAYER operation: argument directionMode must be in range [0, 4], but got %i instead !", static_cast<int>(directionMode));
    REQUIRE_TRUE(dataFormat < 3 || (dataFormat == 3 && directionMode == 4), 0, "LSTM_LAYER operation: if argument dataFormat = 3, then directionMode = 4, but got dataFormat = %i and directionMode = %i instead !", static_cast<int>(dataFormat), static_cast<int>(directionMode));
    REQUIRE_TRUE(cellClip >= 0, 0, "LSTM_LAYER operation: cell clipping value should be nonnegative (>=0) !");
    REQUIRE_TRUE(retFullSeq || retLastH || retLastC, 0, "LSTM_LAYER operation: please specify what output arrays to produce !");
    REQUIRE_TRUE(x->rankOf() == 3, 0, "LSTM_LAYER operation: input array x must have rank 3, but got rank %i instead !", x->rankOf());

    int outIdx = 0;
    auto h  = retFullSeq ? OUTPUT_VARIABLE(outIdx++) : nullptr;  // output
    auto hL = retLastH   ? OUTPUT_VARIABLE(outIdx++) : nullptr;  // output at last step
    auto cL = retLastC   ? OUTPUT_VARIABLE(outIdx++) : nullptr;  // cell state at last step

    // evaluate dimensions
    const Nd4jLong bS   = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1);
    const Nd4jLong nIn  = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(2);
    const Nd4jLong nOut = Wx->sizeAt(-1) / 4;

    // inputs validations
    if (directionMode < 2) {  // no bidirectional
        // Wx validation (last dim must be exactly 4*nOut, i.e. divisible by 4)
        if (Wx->rankOf() != 2 || Wx->sizeAt(0) != nIn || Wx->sizeAt(1) != 4 * nOut) {
            REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nIn, 4 * nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
        }
        // Wr validation
        if (Wr->rankOf() != 2 || Wr->sizeAt(0) != nOut || Wr->sizeAt(1) != 4 * nOut) {
            REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({nOut, 4 * nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str());
        }
        // biases validation
        if (b != nullptr && (b->rankOf() != 1 || b->sizeAt(0) != 4 * nOut)) {
            REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({4 * nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str());
        }
        // initial output validation
        if (hI != nullptr && (hI->rankOf() != 2 || hI->sizeAt(0) != bS || hI->sizeAt(1) != nOut)) {
            REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str());
        }
        // initial cell validation
        if (cI != nullptr && (cI->rankOf() != 2 || cI->sizeAt(0) != bS || cI->sizeAt(1) != nOut)) {
            REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str());
        }
        // peephole weights validation
        if (Wp != nullptr && (Wp->rankOf() != 1 || Wp->sizeAt(0) != 3 * nOut)) {
            REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({3 * nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str());
        }
    }
    else {  // bidirectional
        // Wx validation (last dim must be exactly 4*nOut, i.e. divisible by 4)
        if (Wx->rankOf() != 3 || Wx->sizeAt(0) != 2 || Wx->sizeAt(1) != nIn || Wx->sizeAt(2) != 4 * nOut) {
            REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of input weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nIn, 4 * nOut}).c_str(), ShapeUtils::shapeAsString(Wx).c_str());
        }
        // Wr validation
        if (Wr->rankOf() != 3 || Wr->sizeAt(0) != 2 || Wr->sizeAt(1) != nOut || Wr->sizeAt(2) != 4 * nOut) {
            REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of recurrent weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, nOut, 4 * nOut}).c_str(), ShapeUtils::shapeAsString(Wr).c_str());
        }
        // biases validation
        if (b != nullptr && (b->rankOf() != 2 || b->sizeAt(0) != 2 || b->sizeAt(1) != 4 * nOut)) {
            REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of biases, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 4 * nOut}).c_str(), ShapeUtils::shapeAsString(b).c_str());
        }
        // initial output validation
        if (hI != nullptr && (hI->rankOf() != 3 || hI->sizeAt(0) != 2 || hI->sizeAt(1) != bS || hI->sizeAt(2) != nOut)) {
            REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(hI).c_str());
        }
        // initial cell validation
        if (cI != nullptr && (cI->rankOf() != 3 || cI->sizeAt(0) != 2 || cI->sizeAt(1) != bS || cI->sizeAt(2) != nOut)) {
            REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong shape of initial cell state, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, bS, nOut}).c_str(), ShapeUtils::shapeAsString(cI).c_str());
        }
        // peephole weights validation
        if (Wp != nullptr && (Wp->rankOf() != 2 || Wp->sizeAt(0) != 2 || Wp->sizeAt(1) != 3 * nOut)) {
            REQUIRE_TRUE(false, 0, "LSTM_LAYER operation: wrong peephole weights, expected is %s, but got %s instead !", ShapeUtils::shapeAsString({2, 3 * nOut}).c_str(), ShapeUtils::shapeAsString(Wp).c_str());
        }
    }

    // parameter pack forwarded to the time-loop helper
    std::vector<float> params = {static_cast<float>(dataFormat), static_cast<float>(directionMode), static_cast<float>(cellClip),
                                 static_cast<float>(gateAct), static_cast<float>(gateAlpha), static_cast<float>(gateBeta),
                                 static_cast<float>(cellAct), static_cast<float>(cellAlpha), static_cast<float>(cellBeta),
                                 static_cast<float>(outAct), static_cast<float>(outAlpha), static_cast<float>(outBeta)};

    if (directionMode == 0) {  // forward
        helpers::lstmLayerTimeLoop(x, Wx, Wr, b, seqLen, hI, cI, Wp, params, true, h, hL, cL);
    }
    else if (directionMode == 1) {  // backward
        helpers::lstmLayerTimeLoop(x, Wx, Wr, b, seqLen, hI, cI, Wp, params, false, h, hL, cL);
    }
    else {  // bidirectional
        // index 0 along the leading dimension holds forward-direction data, index 1 backward-direction data
        NDArray WxFwd = (*Wx)({0,1, 0,0, 0,0});
        NDArray WxBwd = (*Wx)({1,2, 0,0, 0,0});
        NDArray WrFwd = (*Wr)({0,1, 0,0, 0,0});
        NDArray WrBwd = (*Wr)({1,2, 0,0, 0,0});

        // Owning wrappers: every heap-allocated sub-array is released on every exit path,
        // including when a helper call below throws (the old manual delete leaked in that case).
        std::unique_ptr<NDArray> WpFwd, WpBwd, bFwd, bBwd, hIFwd, hIBwd, cIFwd, cIBwd, hLFwd, hLBwd, cLFwd, cLBwd;

        if (Wp) {
            WpFwd.reset(new NDArray((*Wp)({0,1, 0,0})));
            WpBwd.reset(new NDArray((*Wp)({1,2, 0,0})));
        }
        if (b) {
            bFwd.reset(new NDArray((*b)({0,1, 0,0})));
            bBwd.reset(new NDArray((*b)({1,2, 0,0})));
        }
        if (hI) {
            hIFwd.reset(new NDArray((*hI)({0,1, 0,0, 0,0})));
            hIBwd.reset(new NDArray((*hI)({1,2, 0,0, 0,0})));
        }
        if (cI) {
            cIFwd.reset(new NDArray((*cI)({0,1, 0,0, 0,0})));
            cIBwd.reset(new NDArray((*cI)({1,2, 0,0, 0,0})));
        }
        if (hL) {
            hLFwd.reset(new NDArray((*hL)({0,1, 0,0, 0,0})));
            hLBwd.reset(new NDArray((*hL)({1,2, 0,0, 0,0})));
        }
        if (cL) {
            cLFwd.reset(new NDArray((*cL)({0,1, 0,0, 0,0})));
            cLBwd.reset(new NDArray((*cL)({1,2, 0,0, 0,0})));
        }

        // hFwd is non-owning: it either aliases h itself (sum mode) or points into hFwdOwned
        NDArray* hFwd = nullptr;
        std::unique_ptr<NDArray> hFwdOwned, hBwd;

        if (h) {
            if (directionMode == 2) {  // sum
                // forward pass writes straight into h; backward pass gets a standalone buffer built
                // from h, which is added into h after both passes
                hFwd = h;
                hBwd.reset(new NDArray(h, false, h->getContext()));
            }
            else if (directionMode == 3) {  // concat: first nOut features forward, next nOut backward
                hFwdOwned.reset(new NDArray(dataFormat <= 1 ? (*h)({0,0, 0,0, 0,nOut}) : (*h)({0,0, 0,nOut, 0,0})));
                hBwd.reset(new NDArray(dataFormat <= 1 ? (*h)({0,0, 0,0, nOut,2*nOut}) : (*h)({0,0, nOut,2*nOut, 0,0})));
                hFwd = hFwdOwned.get();
            }
            else {  // directionMode == 4: h is [sL, 2, bS, nOut], split along dimension 1
                hFwdOwned.reset(new NDArray((*h)({0,0, 0,1, 0,0, 0,0})));
                hBwd.reset(new NDArray((*h)({0,0, 1,2, 0,0, 0,0})));
                hFwd = hFwdOwned.get();
            }
        }

        // FIXME - following two calls are independent and may run in different streams
        helpers::lstmLayerTimeLoop(x, &WxFwd, &WrFwd, bFwd.get(), seqLen, hIFwd.get(), cIFwd.get(), WpFwd.get(), params, true,  hFwd,       hLFwd.get(), cLFwd.get());
        helpers::lstmLayerTimeLoop(x, &WxBwd, &WrBwd, bBwd.get(), seqLen, hIBwd.get(), cIBwd.get(), WpBwd.get(), params, false, hBwd.get(), hLBwd.get(), cLBwd.get());

        if (h && directionMode == 2)
            *h += *hBwd;
    }

    return Status::OK();
}
DECLARE_TYPES(lstmLayer) {
    // Type contract of lstmLayer: inputs of any data type are accepted,
    // while every output is restricted to the floating-point types.
    auto descriptor = getOpDescriptor();
    descriptor->setAllowedInputTypes(sd::DataType::ANY);
    descriptor->setAllowedOutputTypes({ALL_FLOATS});
}
DECLARE_SHAPE_FN(lstmLayer) {
    // Computes the output shapes of lstmLayer. Shapes are appended in the same order the op
    // fetches its outputs: h (if retFullSeq), then hL (if retLastH), then cL (if retLastC).

    const auto dataFormat    = INT_ARG(0);  // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, bS, nIn] && [sL, 2, bS, nOut] (for ONNX)
    const auto directionMode = INT_ARG(1);  // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim
    const auto retFullSeq = B_ARG(5);       // indicates whether to return whole h {h_0, h_1, ... , h_sL-1}, if true, format would be [sL,bS,nOut] (exact shape depends on dataFormat argument)
    const auto retLastH   = B_ARG(6);       // indicates whether to return output at last time step only, in this case shape would be [bS, nOut] (exact shape depends on directionMode argument)
    const auto retLastC   = B_ARG(7);       // indicates whether to return cells state at last time step only, in this case shape would be [bS, nOut] (exact shape depends on directionMode argument)

    const auto x  = INPUT_VARIABLE(0);  // input
    const auto Wx = INPUT_VARIABLE(1);  // input weights

    // Reject argument combinations for which no output shape is defined; previously e.g.
    // dataFormat == 3 with directionMode < 4 left the h shape empty.
    // (casts to int match the %i format specifier)
    REQUIRE_TRUE(x->rankOf() == 3, 0, "LSTM_LAYER operation: input array x must have rank 3, but got rank %i instead !", x->rankOf());
    REQUIRE_TRUE(dataFormat >= 0 && dataFormat <= 3, 0, "LSTM_LAYER operation: argument dataFormat must be in range [0, 3], but got %i instead !", static_cast<int>(dataFormat));
    REQUIRE_TRUE(directionMode >= 0 && directionMode <= 4, 0, "LSTM_LAYER operation: argument directionMode must be in range [0, 4], but got %i instead !", static_cast<int>(directionMode));
    REQUIRE_TRUE(dataFormat < 3 || directionMode == 4, 0, "LSTM_LAYER operation: if argument dataFormat = 3, then directionMode = 4, but got dataFormat = %i and directionMode = %i instead !", static_cast<int>(dataFormat), static_cast<int>(directionMode));

    // evaluate dimensions
    const Nd4jLong sL   = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat);
    const Nd4jLong bS   = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1);
    const Nd4jLong nOut = Wx->sizeAt(-1) / 4;

    // real-typed input keeps its type, anything else produces FLOAT32 outputs
    const DataType type = x->isR() ? x->dataType() : sd::DataType::FLOAT32;

    auto shapes = SHAPELIST();

    // h shape (full output sequence)
    if (retFullSeq) {
        // concat mode stacks forward and backward features, doubling the feature dimension
        const Nd4jLong nOutTotal = directionMode == 3 ? 2 * nOut : nOut;

        std::vector<Nd4jLong> hShape;
        if (directionMode == 4)       // bidirectional with extra output dimension equal to 2
            hShape = {sL, 2, bS, nOut};
        else if (dataFormat == 0)
            hShape = {sL, bS, nOutTotal};
        else if (dataFormat == 1)
            hShape = {bS, sL, nOutTotal};
        else                          // dataFormat == 2 (dataFormat == 3 implies directionMode == 4)
            hShape = {bS, nOutTotal, sL};

        shapes->push_back(ConstantShapeHelper::getInstance().createShapeInfo(type, x->ordering(), hShape));
    }

    // hL and cL (last time step output / cell state) share one shape:
    // [bS, nOut] for a single direction, [2, bS, nOut] for bidirectional modes
    if (retLastH || retLastC) {
        std::vector<Nd4jLong> lastShape;
        if (directionMode < 2)
            lastShape = {bS, nOut};
        else
            lastShape = {2, bS, nOut};

        auto lastShapeInfo = ConstantShapeHelper::getInstance().createShapeInfo(type, x->ordering(), lastShape);
        if (retLastH)
            shapes->push_back(lastShapeInfo);
        if (retLastC)
            shapes->push_back(lastShapeInfo);
    }

    return shapes;
}
2020-04-13 12:21:51 +02:00
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL ( lstmLayer_bp , 4 , 1 , false , 1 , 5 ) {
// equations (no peephole connections)
// it = σ (Wxi * xt + Wri * ht-1 + bi)
// ft = σ (Wxf * xt + Wrf * ht-1 + bf)
// c't = tanh(Wxc * xt + Wrc * ht-1 + bc)
// ct = ft ◦ ct-1 + it ◦ c't
// ot = σ (Wxo * xt + Wro * ht-1 + bo)
// ht = ot ◦ tanh(ct)
// equations (peephole connections are present)
// it = σ (Wxi * xt + Wri * ht-1 + Wpi ◦ ct-1 + bi)
// ft = σ (Wxf * xt + Wrf * ht-1 + Wpf ◦ ct-1 + bf)
// c't = tanh(Wxc * xt + Wrc * ht-1 + bc)
// ct = clip(ft ◦ ct-1 + it ◦ c't)
// ot = σ (Wxo * xt + Wro * ht-1 + Wpo ◦ ct + bo)
// ht = ot ◦ tanh(ct)
// notations:
// bS - batch size
// sL - sequence length, number of time steps
// nIn - input size
// nOut - output size (hidden size)
// INPUTS:
// *******
// input x:
// 1) [sL, bS, nIn] when dataFormat == 0
// 2) [bS, sL, nIn] when dataFormat == 1
// 3) [bS, nIn, sL] when dataFormat == 2
// *******
// input weights Wx:
// 1) [nIn, 4*nOut] when directionMode < 2
// 2) [2, nIn, 4*nOut] when directionMode >= 2
// *******
// recurrent weights Wr:
// 1) [nOut, 4*nOut] when directionMode < 2
// 2) [2, nOut, 4*nOut] when directionMode >= 2
// *******
// peephole weights Wp, optional:
// 1) [3*nOut] when directionMode < 2
// 2) [2, 3*nOut] when directionMode >= 2
// *******
// biases b, optional:
// 1) [4*nOut] when directionMode < 2
// 2) [2, 4*nOut] when directionMode >= 2
// *******
// sequence length array seqLen, optional:
// 1) [bS]
// *******
// initial output hI, optional:
// 1) [bS, nOut] when directionMode < 2
// 2) [2, bS, nOut] when directionMode >= 2
// *******
// initial cell state cI (same shape as in hI), optional:
// 1) [bS, nOut] when directionMode < 2
// 2) [2, bS, nOut] when directionMode >= 2
// *******
// gradient vs. output dLdh, optional:
// 1) [sL, bS, nOut] when directionMode <= 2 && dataFormat == 0
// 2) [bS, sL, nOut] when directionMode <= 2 && dataFormat == 1
// 3) [bS, nOut, sL] when directionMode <= 2 && dataFormat == 2
// 4) [sL, bS, 2*nOut] when directionMode == 3 && dataFormat == 0
// 5) [bS, sL, 2*nOut] when directionMode == 3 && dataFormat == 1
// 6) [bS, 2*nOut, sL] when directionMode == 3 && dataFormat == 2
// 7) [sL, 2, bS, nOut] when directionMode == 4 && dataFormat == 3
// *******
// gradient vs output at last time step dLdhL, optional:
// 1) [bS, nOut] when directionMode < 2
// 2) [2, bS, nOut] when directionMode >= 2
// *******
// gradient vs cell state at last time step dLdcL(same shape as in dLdhL), optional:
// 1) [bS, nOut] when directionMode < 2
// 2) [2, bS, nOut] when directionMode >= 2
// OUTPUTS:
// *******
// gradient vs. input dLdx:
// 1) [sL, bS, nIn] when dataFormat == 0
// 2) [bS, sL, nIn] when dataFormat == 1
// 3) [bS, nIn, sL] when dataFormat == 2
// *******
// gradient vs. input weights dLdWx:
// 1) [nIn, 4*nOut] when directionMode < 2
// 2) [2, nIn, 4*nOut] when directionMode >= 2
// *******
// gradient vs. recurrent weights dLdWr:
// 1) [nOut, 4*nOut] when directionMode < 2
// 2) [2, nOut, 4*nOut] when directionMode >= 2
// *******
// gradient vs. peephole weights dLdWp, optional:
// 1) [3*nOut] when directionMode < 2
// 2) [2, 3*nOut] when directionMode >= 2
// *******
// gradient vs. biases dLdb, optional:
// 1) [4*nOut] when directionMode < 2
// 2) [2, 4*nOut] when directionMode >= 2
// gradient vs. sequence length array dLdsL, optional (do not calculate it!!!):
// 1) [bS] always
// *******
// gradient vs. initial output dLdhI, optional:
// 1) [bS, nOut] when directionMode < 2
// 2) [2, bS, nOut] when directionMode >= 2
// *******
// gradient vs. initial cell state dLdcI (same shape as in dLdhI), optional:
// 1) [bS, nOut] when directionMode < 2
// 2) [2, bS, nOut] when directionMode >= 2
// !!! dimension 4*nOut implies order it, ft, c't, ot
// !!! dimension 3*nOut implies order it, ft, ot
const auto dataFormat = INT_ARG ( 0 ) ; // for unidirectional: 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn], 2 = [bS, nIn, sL], for bidirectional: 3 = [sL, bS, nIn] && [sL, 2, bS, nOut] (for ONNX)
const auto directionMode = INT_ARG ( 1 ) ; // direction: 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat, 4 = bidirectional extra output dim (in conjunction with format dataFormat = 3)
// integer numbers corresponding to activations: 0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus
const auto gateAct = INT_ARG ( 2 ) ; // activation for input (i), forget (f) and output (o) gates
const auto cellAct = INT_ARG ( 3 ) ; // activation for cell state (c)
const auto outAct = INT_ARG ( 4 ) ; // activation for output (h)
const auto hasBiases = B_ARG ( 0 ) ; // indicates whether biases array is provided
const auto hasSeqLen = B_ARG ( 1 ) ; // indicates whether seqLen array is provided
const auto hasInitH = B_ARG ( 2 ) ; // indicates whether initial output is provided
const auto hasInitC = B_ARG ( 3 ) ; // indicates whether initial cell state is provided
const auto hasPH = B_ARG ( 4 ) ; // indicates whether peephole connections are present
const auto retFullSeq = B_ARG ( 5 ) ; // indicates whether gradient vs. outputs is given for whole time sequence dLdh {dLdh_0, dLdh_1, ... , dLdh_sL-1}
const auto retLastH = B_ARG ( 6 ) ; // indicates whether gradient vs. output at last time step (dLdhL) is given
const auto retLastC = B_ARG ( 7 ) ; // indicates whether gradient vs. cell state at last time step (dLdcL) is given
const auto gateActHasAlpha = gateAct = = 3 | | gateAct = = 4 | | gateAct = = 5 | | gateAct = = 6 | | gateAct = = 8 ;
const auto cellActHasAlpha = cellAct = = 3 | | cellAct = = 4 | | cellAct = = 5 | | cellAct = = 6 | | cellAct = = 8 ;
const auto outActHasAlpha = outAct = = 3 | | outAct = = 4 | | outAct = = 5 | | outAct = = 6 | | outAct = = 8 ;
const auto gateActHasBeta = gateAct = = 3 | | gateAct = = 6 ;
const auto cellActHasBeta = cellAct = = 3 | | cellAct = = 6 ;
const auto outActHasBeta = outAct = = 3 | | outAct = = 6 ;
uint count = 1 ;
const auto cellClip = T_ARG ( 0 ) ; // cell clipping value, if it = 0 then do not apply clipping
const auto gateAlpha = gateActHasAlpha ? T_ARG ( count + + ) : 0 ;
const auto gateBeta = gateActHasBeta ? T_ARG ( count + + ) : 0 ;
const auto cellAlpha = cellActHasAlpha ? T_ARG ( count + + ) : 0 ;
const auto cellBeta = cellActHasBeta ? T_ARG ( count + + ) : 0 ;
const auto outAlpha = outActHasAlpha ? T_ARG ( count + + ) : 0 ;
const auto outBeta = outActHasBeta ? T_ARG ( count + + ) : 0 ;
REQUIRE_TRUE ( dataFormat < 3 | | ( dataFormat = = 3 & & directionMode = = 4 ) , 0 , " LSTM_LAYER_BP operation: if argument dataFormat = 3, then directionMode = 4, but got dataFormat = %i and directionMode = %i instead ! " , dataFormat , directionMode ) ;
REQUIRE_TRUE ( cellClip > = 0 , 0 , " LSTM_LAYER_BP operation: cell clipping value should be nonnegative (>=0) ! " ) ;
REQUIRE_TRUE ( retFullSeq | | retLastH | | retLastC , 0 , " LSTM_LAYER_BP operation: please specify at least one of three input gradient arrays: dLdh, dLdhL or dLdcL ! " ) ;
const auto x = INPUT_VARIABLE ( 0 ) ; // input
const auto Wx = INPUT_VARIABLE ( 1 ) ; // input weights
const auto Wr = INPUT_VARIABLE ( 2 ) ; // recurrent weights
count = 3 ;
const auto b = hasBiases ? INPUT_VARIABLE ( count + + ) : nullptr ; // biases
const auto seqLen = hasSeqLen ? INPUT_VARIABLE ( count + + ) : nullptr ; // seqLen vector
const auto hI = hasInitH ? INPUT_VARIABLE ( count + + ) : nullptr ; // initial output
const auto cI = hasInitC ? INPUT_VARIABLE ( count + + ) : nullptr ; // initial cell state
const auto Wp = hasPH ? INPUT_VARIABLE ( count + + ) : nullptr ; // peephole weights
const auto dLdh = retFullSeq ? INPUT_VARIABLE ( count + + ) : nullptr ; // gradient vs. output
const auto dLdhL = retLastH ? INPUT_VARIABLE ( count + + ) : nullptr ; // gradient vs. output at last time step
const auto dLdcL = retLastC ? INPUT_VARIABLE ( count + + ) : nullptr ; // gradient vs. cell state at last time step
count = 3 ;
auto dLdx = OUTPUT_VARIABLE ( 0 ) ; // gradient vs. input
auto dLdWx = OUTPUT_NULLIFIED ( 1 ) ; // gradient vs. input weights
auto dLdWr = OUTPUT_NULLIFIED ( 2 ) ; // gradient vs. recurrent weights
auto dLdb = hasBiases ? OUTPUT_NULLIFIED ( count + + ) : nullptr ; // gradient vs. biases
auto dLdsL = hasSeqLen ? INPUT_VARIABLE ( count + + ) : nullptr ; // gradient vs. seqLen vector, we don't calculate it !!!
auto dLdhI = hasInitH ? OUTPUT_NULLIFIED ( count + + ) : nullptr ; // gradient vs. initial output
auto dLdcI = hasInitC ? OUTPUT_NULLIFIED ( count + + ) : nullptr ; // gradient vs. initial cell state
auto dLdWp = hasPH ? OUTPUT_NULLIFIED ( count ) : nullptr ; // gradient vs. peephole weights
// evaluate dimensions
const Nd4jLong sL = dataFormat = = 3 ? x - > sizeAt ( 0 ) : x - > sizeAt ( dataFormat ) ;
const Nd4jLong bS = dataFormat = = 1 | | dataFormat = = 2 ? x - > sizeAt ( 0 ) : x - > sizeAt ( 1 ) ;
const Nd4jLong nIn = dataFormat = = 2 ? x - > sizeAt ( 1 ) : x - > sizeAt ( 2 ) ;
const Nd4jLong nOut = Wx - > sizeAt ( - 1 ) / 4 ;
// inputs validations
if ( directionMode < 2 ) { // no bidirectional
// Wx validation
if ( Wx - > rankOf ( ) ! = 2 | | Wx - > sizeAt ( 0 ) ! = nIn )
REQUIRE_TRUE ( false , 0 , " LSTM_LAYER_BP operation: wrong shape of input weights, expected is %s, but got %s instead ! " , ShapeUtils : : shapeAsString ( { nIn , 4 * nOut } ) . c_str ( ) , ShapeUtils : : shapeAsString ( Wx ) . c_str ( ) ) ;
// Wr validation
if ( Wr - > rankOf ( ) ! = 2 | | Wr - > sizeAt ( 0 ) ! = nOut | | Wr - > sizeAt ( 1 ) ! = 4 * nOut )
REQUIRE_TRUE ( false , 0 , " LSTM_LAYER_BP operation: wrong shape of recurrent weights, expected is %s, but got %s instead ! " , ShapeUtils : : shapeAsString ( { nOut , 4 * nOut } ) . c_str ( ) , ShapeUtils : : shapeAsString ( Wr ) . c_str ( ) ) ;
// biases validation
if ( b ! = nullptr & & ( b - > rankOf ( ) ! = 1 | | b - > sizeAt ( 0 ) ! = 4 * nOut ) )
REQUIRE_TRUE ( false , 0 , " LSTM_LAYER_BP operation: wrong shape of biases, expected is %s, but got %s instead ! " , ShapeUtils : : shapeAsString ( { 4 * nOut } ) . c_str ( ) , ShapeUtils : : shapeAsString ( b ) . c_str ( ) ) ;
// initial output validation
if ( hI ! = nullptr & & ( hI - > rankOf ( ) ! = 2 | | hI - > sizeAt ( 0 ) ! = bS | | hI - > sizeAt ( 1 ) ! = nOut ) )
REQUIRE_TRUE ( false , 0 , " LSTM_LAYER_BP operation: wrong shape of initial output, expected is %s, but got %s instead ! " , ShapeUtils : : shapeAsString ( { bS , nOut } ) . c_str ( ) , ShapeUtils : : shapeAsString ( hI ) . c_str ( ) ) ;
// initial cell validation
if ( cI ! = nullptr & & ( cI - > rankOf ( ) ! = 2 | | cI - > sizeAt ( 0 ) ! = bS | | cI - > sizeAt ( 1 ) ! = nOut ) )
REQUIRE_TRUE ( false , 0 , " LSTM_LAYER_BP operation: wrong shape of initial cell state, expected is %s, but got %s instead ! " , ShapeUtils : : shapeAsString ( { bS , nOut } ) . c_str ( ) , ShapeUtils : : shapeAsString ( cI ) . c_str ( ) ) ;
// peephole weights validation
if ( Wp ! = nullptr & & ( Wp - > rankOf ( ) ! = 1 | | Wp - > sizeAt ( 0 ) ! = 3 * nOut ) )
REQUIRE_TRUE ( false , 0 , " LSTM_LAYER_BP operation: wrong peephole weights, expected is %s, but got %s instead ! " , ShapeUtils : : shapeAsString ( { 3 * nOut } ) . c_str ( ) , ShapeUtils : : shapeAsString ( Wp ) . c_str ( ) ) ;
// gradient vs. output at last time step validation
if ( dLdhL ! = nullptr & & ( dLdhL - > rankOf ( ) ! = 2 | | dLdhL - > sizeAt ( 0 ) ! = bS | | dLdhL - > sizeAt ( 1 ) ! = nOut ) )
REQUIRE_TRUE ( false , 0 , " LSTM_LAYER_BP operation: wrong shape of gradient vs. output at last time step, expected is %s, but got %s instead ! " , ShapeUtils : : shapeAsString ( { bS , nOut } ) . c_str ( ) , ShapeUtils : : shapeAsString ( dLdhL ) . c_str ( ) ) ;
// gradient vs. cell state at last time step validation
if ( dLdcL ! = nullptr & & ( dLdcL - > rankOf ( ) ! = 2 | | dLdcL - > sizeAt ( 0 ) ! = bS | | dLdcL - > sizeAt ( 1 ) ! = nOut ) )
REQUIRE_TRUE ( false , 0 , " LSTM_LAYER_BP operation: wrong shape of gradient vs. cell state at last time, expected is %s, but got %s instead ! " , ShapeUtils : : shapeAsString ( { bS , nOut } ) . c_str ( ) , ShapeUtils : : shapeAsString ( dLdcL ) . c_str ( ) ) ;
}
else { // bidirectional
// Wx validation
if ( Wx - > rankOf ( ) ! = 3 | | Wx - > sizeAt ( 0 ) ! = 2 | | Wx - > sizeAt ( 1 ) ! = nIn )
REQUIRE_TRUE ( false , 0 , " LSTM_LAYER_BP operation: wrong shape of input weights, expected is %s, but got %s instead ! " , ShapeUtils : : shapeAsString ( { 2 , nIn , 4 * nOut } ) . c_str ( ) , ShapeUtils : : shapeAsString ( Wx ) . c_str ( ) ) ;
// Wr validation
if ( Wr - > rankOf ( ) ! = 3 | | Wr - > sizeAt ( 0 ) ! = 2 | | Wr - > sizeAt ( 1 ) ! = nOut | | Wr - > sizeAt ( 2 ) ! = 4 * nOut )
REQUIRE_TRUE ( false , 0 , " LSTM_LAYER_BP operation: wrong shape of recurrent weights, expected is %s, but got %s instead ! " , ShapeUtils : : shapeAsString ( { 2 , nOut , 4 * nOut } ) . c_str ( ) , ShapeUtils : : shapeAsString ( Wr ) . c_str ( ) ) ;
// biases validation
if ( b ! = nullptr & & ( b - > rankOf ( ) ! = 2 | | b - > sizeAt ( 0 ) ! = 2 | | b - > sizeAt ( 1 ) ! = 4 * nOut ) )
REQUIRE_TRUE ( false , 0 , " LSTM_LAYER_BP operation: wrong shape of biases, expected is %s, but got %s instead ! " , ShapeUtils : : shapeAsString ( { 2 , 4 * nOut } ) . c_str ( ) , ShapeUtils : : shapeAsString ( b ) . c_str ( ) ) ;
// initial output validation
if ( hI ! = nullptr & & ( hI - > rankOf ( ) ! = 3 | | hI - > sizeAt ( 0 ) ! = 2 | | hI - > sizeAt ( 1 ) ! = bS | | hI - > sizeAt ( 2 ) ! = nOut ) )
REQUIRE_TRUE ( false , 0 , " LSTM_LAYER_BP operation: wrong shape of initial output, expected is %s, but got %s instead ! " , ShapeUtils : : shapeAsString ( { 2 , bS , nOut } ) . c_str ( ) , ShapeUtils : : shapeAsString ( hI ) . c_str ( ) ) ;
// initial cell validation
if ( cI ! = nullptr & & ( cI - > rankOf ( ) ! = 3 | | cI - > sizeAt ( 0 ) ! = 2 | | cI - > sizeAt ( 1 ) ! = bS | | cI - > sizeAt ( 2 ) ! = nOut ) )
REQUIRE_TRUE ( false , 0 , " LSTM_LAYER_BP operation: wrong shape of initial cell state, expected is %s, but got %s instead ! " , ShapeUtils : : shapeAsString ( { 2 , bS , nOut } ) . c_str ( ) , ShapeUtils : : shapeAsString ( cI ) . c_str ( ) ) ;
// peephole weights validation
if ( Wp ! = nullptr & & ( Wp - > rankOf ( ) ! = 2 | | Wp - > sizeAt ( 0 ) ! = 2 | | Wp - > sizeAt ( 1 ) ! = 3 * nOut ) )
REQUIRE_TRUE ( false , 0 , " LSTM_LAYER_BP operation: wrong peephole weights, expected is %s, but got %s instead ! " , ShapeUtils : : shapeAsString ( { 2 , 3 * nOut } ) . c_str ( ) , ShapeUtils : : shapeAsString ( Wp ) . c_str ( ) ) ;
// gradient vs. output at last time step validation
if ( dLdhL ! = nullptr & & ( dLdhL - > rankOf ( ) ! = 3 | | dLdhL - > sizeAt ( 0 ) ! = 2 | | dLdhL - > sizeAt ( 1 ) ! = bS | | dLdhL - > sizeAt ( 2 ) ! = nOut ) )
REQUIRE_TRUE ( false , 0 , " LSTM_LAYER_BP operation: wrong shape of gradient vs. output at last time step, expected is %s, but got %s instead ! " , ShapeUtils : : shapeAsString ( { 2 , bS , nOut } ) . c_str ( ) , ShapeUtils : : shapeAsString ( dLdhL ) . c_str ( ) ) ;
// gradient vs. cell state at last time step validation
if ( dLdcL ! = nullptr & & ( dLdcL - > rankOf ( ) ! = 3 | | dLdcL - > sizeAt ( 0 ) ! = 2 | | dLdcL - > sizeAt ( 1 ) ! = bS | | dLdcL - > sizeAt ( 2 ) ! = nOut ) )
REQUIRE_TRUE ( false , 0 , " LSTM_LAYER_BP operation: wrong shape of gradient vs. cell state at last time, expected is %s, but got %s instead ! " , ShapeUtils : : shapeAsString ( { 2 , bS , nOut } ) . c_str ( ) , ShapeUtils : : shapeAsString ( dLdcL ) . c_str ( ) ) ;
}
// gradient vs. output validation
if ( dLdh ) {
int factor = directionMode < = 2 ? 1 : 2 ;
std : : vector < Nd4jLong > expdLdhShape ;
if ( dataFormat = = 0 ) expdLdhShape = std : : vector < Nd4jLong > { sL , bS , factor * nOut } ;
else if ( dataFormat = = 1 ) expdLdhShape = std : : vector < Nd4jLong > { bS , sL , factor * nOut } ;
else if ( dataFormat = = 2 ) expdLdhShape = std : : vector < Nd4jLong > { bS , factor * nOut , sL } ;
else expdLdhShape = std : : vector < Nd4jLong > { sL , 2 , bS , nOut } ;
REQUIRE_TRUE ( dLdh - > isSameShape ( expdLdhShape ) , 0 , " LSTM_LAYER_CELL_BP operation: wrong shape of gradient vs. output, expected is %s, but got %s instead ! " , ShapeUtils : : shapeAsString ( expdLdhShape ) . c_str ( ) , ShapeUtils : : shapeAsString ( dLdh ) . c_str ( ) ) ;
}
std : : vector < float > params = { static_cast < float > ( dataFormat ) , static_cast < float > ( directionMode ) , static_cast < float > ( cellClip ) ,
static_cast < float > ( gateAct ) , static_cast < float > ( gateAlpha ) , static_cast < float > ( gateBeta ) ,
static_cast < float > ( cellAct ) , static_cast < float > ( cellAlpha ) , static_cast < float > ( cellBeta ) ,
static_cast < float > ( outAct ) , static_cast < float > ( outAlpha ) , static_cast < float > ( outBeta ) } ;
if ( directionMode = = 0 ) { // forward
helpers : : lstmLayerTimeLoopBp ( x , Wx , Wr , b , seqLen , hI , cI , Wp , dLdh , dLdhL , dLdcL , params , true , dLdx , dLdWx , dLdWr , dLdb , dLdhI , dLdcI , dLdWp ) ;
}
else if ( directionMode = = 1 ) { // backward
helpers : : lstmLayerTimeLoopBp ( x , Wx , Wr , b , seqLen , hI , cI , Wp , dLdh , dLdhL , dLdcL , params , false , dLdx , dLdWx , dLdWr , dLdb , dLdhI , dLdcI , dLdWp ) ;
}
else { // bidirectional
NDArray WxFwd = ( * Wx ) ( { 0 , 1 , 0 , 0 , 0 , 0 } ) ;
NDArray WxBwd = ( * Wx ) ( { 1 , 2 , 0 , 0 , 0 , 0 } ) ;
NDArray dLdWxFwd = ( * dLdWx ) ( { 0 , 1 , 0 , 0 , 0 , 0 } ) ;
NDArray dLdWxBwd = ( * dLdWx ) ( { 1 , 2 , 0 , 0 , 0 , 0 } ) ;
NDArray WrFwd = ( * Wr ) ( { 0 , 1 , 0 , 0 , 0 , 0 } ) ;
NDArray WrBwd = ( * Wr ) ( { 1 , 2 , 0 , 0 , 0 , 0 } ) ;
NDArray dLdWrFwd = ( * dLdWr ) ( { 0 , 1 , 0 , 0 , 0 , 0 } ) ;
NDArray dLdWrBwd = ( * dLdWr ) ( { 1 , 2 , 0 , 0 , 0 , 0 } ) ;
NDArray * WpFwd ( nullptr ) , * WpBwd ( nullptr ) , * bFwd ( nullptr ) , * bBwd ( nullptr ) , * hIFwd ( nullptr ) , * hIBwd ( nullptr ) , * cIFwd ( nullptr ) , * cIBwd ( nullptr ) ,
* dLdhFwd ( nullptr ) , * dLdhBwd ( nullptr ) , * dLdhLFwd ( nullptr ) , * dLdhLBwd ( nullptr ) , * dLdcLFwd ( nullptr ) , * dLdcLBwd ( nullptr ) ,
* dLdWpFwd ( nullptr ) , * dLdWpBwd ( nullptr ) , * dLdbFwd ( nullptr ) , * dLdbBwd ( nullptr ) ,
* dLdhIFwd ( nullptr ) , * dLdhIBwd ( nullptr ) , * dLdcIFwd ( nullptr ) , * dLdcIBwd ( nullptr ) ;
if ( Wp ) {
WpFwd = new NDArray ( ( * Wp ) ( { 0 , 1 , 0 , 0 } ) ) ;
WpBwd = new NDArray ( ( * Wp ) ( { 1 , 2 , 0 , 0 } ) ) ;
dLdWpFwd = new NDArray ( ( * dLdWp ) ( { 0 , 1 , 0 , 0 } ) ) ;
dLdWpBwd = new NDArray ( ( * dLdWp ) ( { 1 , 2 , 0 , 0 } ) ) ;
}
if ( b ) {
bFwd = new NDArray ( ( * b ) ( { 0 , 1 , 0 , 0 } ) ) ;
bBwd = new NDArray ( ( * b ) ( { 1 , 2 , 0 , 0 } ) ) ;
dLdbFwd = new NDArray ( ( * dLdb ) ( { 0 , 1 , 0 , 0 } ) ) ;
dLdbBwd = new NDArray ( ( * dLdb ) ( { 1 , 2 , 0 , 0 } ) ) ;
}
if ( hI ) {
hIFwd = new NDArray ( ( * hI ) ( { 0 , 1 , 0 , 0 , 0 , 0 } ) ) ;
hIBwd = new NDArray ( ( * hI ) ( { 1 , 2 , 0 , 0 , 0 , 0 } ) ) ;
dLdhIFwd = new NDArray ( ( * dLdhI ) ( { 0 , 1 , 0 , 0 , 0 , 0 } ) ) ;
dLdhIBwd = new NDArray ( ( * dLdhI ) ( { 1 , 2 , 0 , 0 , 0 , 0 } ) ) ;
}
if ( cI ) {
cIFwd = new NDArray ( ( * cI ) ( { 0 , 1 , 0 , 0 , 0 , 0 } ) ) ;
cIBwd = new NDArray ( ( * cI ) ( { 1 , 2 , 0 , 0 , 0 , 0 } ) ) ;
dLdcIFwd = new NDArray ( ( * dLdcI ) ( { 0 , 1 , 0 , 0 , 0 , 0 } ) ) ;
dLdcIBwd = new NDArray ( ( * dLdcI ) ( { 1 , 2 , 0 , 0 , 0 , 0 } ) ) ;
}
if ( dLdhL ) {
dLdhLFwd = new NDArray ( ( * dLdhL ) ( { 0 , 1 , 0 , 0 , 0 , 0 } ) ) ;
dLdhLBwd = new NDArray ( ( * dLdhL ) ( { 1 , 2 , 0 , 0 , 0 , 0 } ) ) ;
}
if ( dLdcL ) {
dLdcLFwd = new NDArray ( ( * dLdcL ) ( { 0 , 1 , 0 , 0 , 0 , 0 } ) ) ;
dLdcLBwd = new NDArray ( ( * dLdcL ) ( { 1 , 2 , 0 , 0 , 0 , 0 } ) ) ;
}
if ( dLdh ) {
if ( directionMode = = 2 ) { // sum
2020-04-16 07:09:04 +02:00
dLdhFwd = dLdh ;
dLdhBwd = dLdh ;
2020-04-13 12:21:51 +02:00
}
else if ( directionMode = = 3 ) { // concat
dLdhFwd = new NDArray ( dataFormat < = 1 ? ( * dLdh ) ( { 0 , 0 , 0 , 0 , 0 , nOut } ) : ( * dLdh ) ( { 0 , 0 , 0 , nOut , 0 , 0 } ) ) ;
dLdhBwd = new NDArray ( dataFormat < = 1 ? ( * dLdh ) ( { 0 , 0 , 0 , 0 , nOut , 2 * nOut } ) : ( * dLdh ) ( { 0 , 0 , nOut , 2 * nOut , 0 , 0 } ) ) ;
}
else { // directionMode == 4
dLdhFwd = new NDArray ( ( * dLdh ) ( { 0 , 0 , 0 , 1 , 0 , 0 , 0 , 0 } ) ) ;
dLdhBwd = new NDArray ( ( * dLdh ) ( { 0 , 0 , 1 , 2 , 0 , 0 , 0 , 0 } ) ) ;
}
}
2020-04-16 07:09:04 +02:00
NDArray dLdxBwd = dLdx - > ulike ( ) ;
2020-04-13 12:21:51 +02:00
2020-04-16 07:09:04 +02:00
// FIXME - following two calls are independent and may run in different streams
2020-04-13 12:21:51 +02:00
helpers : : lstmLayerTimeLoopBp ( x , & WxFwd , & WrFwd , bFwd , seqLen , hIFwd , cIFwd , WpFwd , dLdhFwd , dLdhLFwd , dLdcLFwd , params , true , dLdx , & dLdWxFwd , & dLdWrFwd , dLdbFwd , dLdhIFwd , dLdcIFwd , dLdWpFwd ) ;
helpers : : lstmLayerTimeLoopBp ( x , & WxBwd , & WrBwd , bBwd , seqLen , hIBwd , cIBwd , WpBwd , dLdhBwd , dLdhLBwd , dLdcLBwd , params , false , & dLdxBwd , & dLdWxBwd , & dLdWrBwd , dLdbBwd , dLdhIBwd , dLdcIBwd , dLdWpBwd ) ;
* dLdx + = dLdxBwd ;
delete WpFwd ; delete WpBwd ; delete bFwd ; delete bBwd ; delete hIFwd ; delete hIBwd ; delete cIFwd ; delete cIBwd ;
2020-04-16 07:09:04 +02:00
delete dLdhLFwd ; delete dLdhLBwd ; delete dLdcLFwd ; delete dLdcLBwd ;
2020-04-13 12:21:51 +02:00
delete dLdWpFwd ; delete dLdWpBwd ; delete dLdbFwd ; delete dLdbBwd ;
delete dLdhIFwd ; delete dLdhIBwd ; delete dLdcIFwd ; delete dLdcIBwd ;
2020-04-16 07:09:04 +02:00
if ( ! ( dLdh & & directionMode = = 2 ) ) { delete dLdhFwd ; delete dLdhBwd ; }
2020-04-13 12:21:51 +02:00
}
return Status : : OK ( ) ;
}
// Data-type contract of lstmLayer_bp: inputs are accepted in any data type,
// while every produced gradient must be of a floating-point type.
DECLARE_TYPES(lstmLayer_bp) {
    auto descriptor = getOpDescriptor();
    descriptor->setAllowedInputTypes(sd::DataType::ANY);
    descriptor->setAllowedOutputTypes({ALL_FLOATS});
}
// Output-shape function of lstmLayer_bp.
// Each output is the gradient with respect to one supplied input, so the shape list
// simply repeats the shapes of the supplied inputs, in input order:
//   x, Wx, Wr, then whichever of {b, seqLen, hI, cI, Wp} are flagged as present.
DECLARE_SHAPE_FN(lstmLayer_bp) {
    // B_ARG(0..4) flag, in order, the presence of: biases, seqLen vector,
    // initial output, initial cell state, peephole weights
    const bool optionalFlags[] = {B_ARG(0), B_ARG(1), B_ARG(2), B_ARG(3), B_ARG(4)};

    // mandatory inputs occupy slots 0..2: input, input weights, recurrent weights
    std::vector<NDArray*> suppliedInputs = {INPUT_VARIABLE(0), INPUT_VARIABLE(1), INPUT_VARIABLE(2)};

    // optional inputs are packed consecutively starting at slot 3, skipping absent ones
    int nextSlot = 3;
    for (const bool isPresent : optionalFlags) {
        if (isPresent)
            suppliedInputs.push_back(INPUT_VARIABLE(nextSlot++));
    }

    // all inputs are fetched before the list is allocated, then mirrored one-to-one
    auto gradShapes = SHAPELIST(suppliedInputs[0]->shapeInfo(), suppliedInputs[1]->shapeInfo(), suppliedInputs[2]->shapeInfo());
    for (size_t i = 3; i < suppliedInputs.size(); ++i)
        gradShapes->push_back(suppliedInputs[i]->shapeInfo());

    return gradShapes;
}
2019-10-17 19:44:52 +02:00
}
}
# endif