/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// implementation of operations for Simple Recurrent Unit: arXiv:1709.02755v2 [cs.CL] 12 Sep 2017
//
// @author Yurii Shyrma
//
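//
// For reference, the forward recurrence these ops implement (following the
// paper cited above, with g = tanh; a sketch of the math, not part of the build;
// all products below are element-wise):
//
//     z_t = W * x_t                          (candidate / highway input)
//     f_t = sigmoid(W_f * x_t + b_f)         (forget gate)
//     r_t = sigmoid(W_r * x_t + b_r)         (reset gate)
//     c_t = f_t * c_{t-1} + (1 - f_t) * z_t  (cell state)
//     h_t = r_t * g(c_t) + (1 - r_t) * x_t   (cell output, highway connection)
//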

#include <system/op_boilerplate.h>
#if NOT_EXCLUDED(OP_sru)

#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/sru.h>
#include <helpers/MmulHelper.h>
#include <helpers/PointersManager.h>

namespace sd {
namespace ops {

//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(sru, 5, 2, false, 0, 0) {

    auto x    = INPUT_VARIABLE(0);                                // X, input 3d tensor [bS x inSize x time], time - number of time steps, bS - batch size, inSize - number of features
    auto w    = INPUT_VARIABLE(1);                                // W, 2d tensor of weights [3*inSize x inSize]
    auto b    = INPUT_VARIABLE(2);                                // B, row of biases of length 2*inSize
    auto c0   = INPUT_VARIABLE(3);                                // C_{0}, 2d tensor of initial state [bS x inSize] at time t=0
    auto mask = block.width() > 4 ? INPUT_VARIABLE(4) : nullptr;  // optional, 2d tensor of dropout mask [bS x inSize]

    auto h = OUTPUT_VARIABLE(0);                                  // cell outputs, [bS x inSize x time]
    auto c = OUTPUT_VARIABLE(1);                                  // cell states,  [bS x inSize x time]

    const int  rank   = x->rankOf();              // = 3
    const auto bS     = x->sizeAt(0);
    const auto inSize = x->sizeAt(1);
    const auto time   = x->sizeAt(2);

    // input shapes validation
    REQUIRE_TRUE(w->rankOf()  == rank-1, 0, "SRU operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, w->rankOf());
    REQUIRE_TRUE(b->rankOf()  == 1,      0, "SRU operation: wrong rank of biases array, expected is %i, but got %i instead !", 1, b->rankOf());
    REQUIRE_TRUE(c0->rankOf() == rank-1, 0, "SRU operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0->rankOf());
    if(mask)
        REQUIRE_TRUE(mask->rankOf() == rank-1, 0, "SRU operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, mask->rankOf());

    const std::vector<Nd4jLong> wCorrectShape  = {3*inSize, inSize};
    const std::vector<Nd4jLong> bCorrectShape  = {2*inSize};
    const std::vector<Nd4jLong> c0CorrectShape = {bS, inSize};

    REQUIRE_TRUE(w->isSameShape(wCorrectShape),   0, "SRU operation: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wCorrectShape).c_str(), ShapeUtils::shapeAsString(w).c_str());
    REQUIRE_TRUE(b->isSameShape(bCorrectShape),   0, "SRU operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str());
    REQUIRE_TRUE(c0->isSameShape(c0CorrectShape), 0, "SRU operation: wrong shape of initial state array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(c0).c_str());
    if(mask)
        REQUIRE_TRUE(mask->isSameShape(c0CorrectShape), 0, "SRU operation: wrong shape of mask array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(mask).c_str());

    //  xm = x * mask
    auto xm = x;
    if(mask) {
        xm = new NDArray(x->shapeInfo(), true, block.launchContext());
        x->applyBroadcast(broadcast::Multiply, {0, 1}, *mask, *xm);
    }

    // time loop
    helpers::sruTimeLoop(block.launchContext(), xm, c0, w, b, h, c);

    if(mask)
        delete xm;

    return Status::OK();
}
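
// A minimal usage sketch for the forward op. sruForwardExample is a hypothetical
// helper, written assuming the usual sd::ops::DeclarableOp::evaluate API that
// returns a ResultSet; kept under #if 0 since this file defines the op:
#if 0
void sruForwardExample() {
    const int bS = 2, inSize = 4, time = 3;
    auto x  = NDArrayFactory::create<float>('c', {bS, inSize, time});   // input sequence
    auto w  = NDArrayFactory::create<float>('c', {3*inSize, inSize});   // stacked weights [W; W_f; W_r]
    auto b  = NDArrayFactory::create<float>('c', {2*inSize});           // biases: b_f followed by b_r
    auto c0 = NDArrayFactory::create<float>('c', {bS, inSize});         // initial cell state
    x.linspace(0.01);
    w.linspace(0.005);

    sd::ops::sru op;
    auto results = op.evaluate({&x, &w, &b, &c0});                      // optional 5th input: dropout mask
    auto h = results.at(0);                                             // cell outputs [bS x inSize x time]
    auto c = results.at(1);                                             // cell states  [bS x inSize x time]
}
#endif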

    DECLARE_TYPES(sru) {
        getOpDescriptor()
                ->setAllowedInputTypes(sd::DataType::ANY)
                ->setAllowedOutputTypes({ALL_FLOATS});
    }

DECLARE_SHAPE_FN(sru) {

    auto xShapeInfo    = inputShape->at(0);                                  // X, input 3d tensor [bS x inSize x time], time - number of time steps, bS - batch size, inSize - number of features
    auto wShapeInfo    = inputShape->at(1);                                  // W, 2d tensor of weights [3*inSize x inSize]
    auto bShapeInfo    = inputShape->at(2);                                  // B, row of biases of length 2*inSize
    auto c0ShapeInfo   = inputShape->at(3);                                  // C_{0}, 2d tensor of initial state [bS x inSize] at time t=0
    auto maskShapeInfo = block.width() > 4 ? inputShape->at(4) : nullptr;    // optional, 2d tensor of dropout mask [bS x inSize]
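
    // note: the raw shapeInfo buffers are indexed below using the usual libnd4j
    // layout: shapeInfo[0] holds the rank and shapeInfo[1..rank] the dimensions
    // (strides and the type/ews/order tail follow after)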
    const int rank   = xShapeInfo[0];               // = 3
    const int bS     = xShapeInfo[1];
    const int inSize = xShapeInfo[2];
    const int time   = xShapeInfo[3];

    // input shapes validation
    REQUIRE_TRUE(wShapeInfo[0]  == rank-1, 0, "SRU operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, wShapeInfo[0]);
    REQUIRE_TRUE(bShapeInfo[0]  == 1,      0, "SRU operation: wrong rank of biases array, expected is %i, but got %i instead !", 1, bShapeInfo[0]);
    REQUIRE_TRUE(c0ShapeInfo[0] == rank-1, 0, "SRU operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0ShapeInfo[0]);
    if(maskShapeInfo)
        REQUIRE_TRUE(maskShapeInfo[0] == rank-1, 0, "SRU operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, maskShapeInfo[0]);

    const std::vector<Nd4jLong> wCorrectShape  = {3*inSize, inSize};
    const std::vector<Nd4jLong> bCorrectShape  = {2*inSize};
    const std::vector<Nd4jLong> c0CorrectShape = {bS, inSize};

    REQUIRE_TRUE(ShapeUtils::areShapesEqual(wShapeInfo,  wCorrectShape),  0, "SRU operation: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wCorrectShape).c_str(), ShapeUtils::shapeAsString(wShapeInfo).c_str());
    REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo,  bCorrectShape),  0, "SRU operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str());
    REQUIRE_TRUE(ShapeUtils::areShapesEqual(c0ShapeInfo, c0CorrectShape), 0, "SRU operation: wrong shape of initial state array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(c0ShapeInfo).c_str());
    if(maskShapeInfo)
        REQUIRE_TRUE(ShapeUtils::areShapesEqual(maskShapeInfo, c0CorrectShape), 0, "SRU operation: wrong shape of mask array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(maskShapeInfo).c_str());

    Nd4jLong* newShapeInfo1 = nullptr;
    ALLOCATE(newShapeInfo1, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong);       // [bS x inSize x time]

    newShapeInfo1[0] = rank;
    newShapeInfo1[1] = bS;
    newShapeInfo1[2] = inSize;
    newShapeInfo1[3] = time;

    ShapeUtils::updateStridesAndType(newShapeInfo1, xShapeInfo, shape::order(xShapeInfo));
    ShapeDescriptor descriptor(newShapeInfo1);
    RELEASE(newShapeInfo1, block.getWorkspace());

    auto result = ConstantShapeHelper::getInstance().createShapeInfo(descriptor);
    return SHAPELIST(result, result);
}

//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(sru_bp, 8, 4, true, 0, 0) {

    auto x        = INPUT_VARIABLE(0);                // X, input 3d tensor [bS x K x N], N - number of time steps, bS - batch size, K - number of features
    auto w        = INPUT_VARIABLE(1);                // W, 2d tensor of weights [3K x K]
    auto b        = INPUT_VARIABLE(2);                // B, row of biases [1 x 2*K], forget-gate biases followed by reset-gate biases
    auto c0       = INPUT_VARIABLE(3);                // C_{0}, 2d tensor of initial state [bS x K] at time t=0
    auto c        = INPUT_VARIABLE(4);                // C, [bS x K x N]
    auto inGradCt = INPUT_VARIABLE(5);                // [bS x K]
    auto inGradH  = INPUT_VARIABLE(6);                // [bS x K x N]
    NDArray* mask = nullptr;                          // optional, 2d tensor of dropout mask [bS x K]

    bool applyMask = false;
    if (block.width() > 7) {
        mask = INPUT_VARIABLE(7);
        applyMask = true;
    }

    auto gradX    = OUTPUT_VARIABLE(0);               // [bS x K x N]
    auto gradW    = OUTPUT_VARIABLE(1);               // [bS x 3K x K]
    auto gradB    = OUTPUT_VARIABLE(2);               // [1 x 2K]
    auto gradInit = OUTPUT_VARIABLE(3);               // [bS x K]

    const int bS = x->shapeOf()[0];
    const int K  = x->shapeOf()[1];
    const int N  = x->shapeOf()[2];                   // N - number of time steps

    auto gradBias = NDArrayFactory::create_(x->ordering(), {bS, 2*K, N}, gradX->dataType(), block.launchContext());
    auto gradU    = NDArrayFactory::create_(x->ordering(), {bS, 3*K, N}, gradX->dataType(), block.launchContext());
    auto gradHX   = NDArrayFactory::create_(x->ordering(), {bS, K, N},   gradX->dataType(), block.launchContext());
    auto gct      = NDArrayFactory::create_(c->ordering(), {bS, K}, gradX->dataType(), block.launchContext());
    auto gradTanh = NDArrayFactory::create_(c->ordering(), {bS, K}, gradX->dataType(), block.launchContext());
    auto gradCt   = NDArrayFactory::create_(c->ordering(), {bS, K}, gradX->dataType(), block.launchContext());
    auto ftMinus  = NDArrayFactory::create_(c->ordering(), {bS, K}, gradX->dataType(), block.launchContext());
    auto rtMinus  = NDArrayFactory::create_(c->ordering(), {bS, K}, gradX->dataType(), block.launchContext());
    auto temp1    = NDArrayFactory::create_(c->ordering(), {bS, K}, gradX->dataType(), block.launchContext());
    auto temp2    = NDArrayFactory::create_(c->ordering(), {bS, K}, gradX->dataType(), block.launchContext());
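
    // note: NDArrayFactory::create_ returns a raw heap-allocated NDArray*, so the
    // temporaries above are owned by this op and must be deleted before it returns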

    //  x = x * mask
    if(applyMask)
        x->applyBroadcast(broadcast::Multiply, {0, 1}, *mask, *x);            // apply mask

    // multiplication matrix wi = matmul(w,x), U = WX
    auto wi = MmulHelper::mmul(w, x, nullptr, 1., 0.);                        // U [bS x 3K x N]
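
    // the {start,end, start,end, ...} arguments below use the NDArray subarray
    // convention where a {0,0} pair keeps the whole corresponding dimension;
    // wi stacks the three projections [z; f; r] along dimension 1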
    auto wiZ = (*wi)({0,0, 0,K,     0,0}, true);             // [bS x K x N]
    auto wiF = (*wi)({0,0, K,2*K,   0,0}, true);             // forget gate [bS x K x N]
    auto wiR = (*wi)({0,0, 2*K,3*K, 0,0}, true);             // reset gate [bS x K x N]
    auto bF  = (*b) ({0,0, 0,K  }, true);                    // biases for forget gate [1 x K]
    auto bR  = (*b) ({0,0, K,2*K}, true);                    // biases for reset gate [1 x K]
    auto gradBF = (*gradBias)({0,0, 0,K,     0,0}, true);    // [bS x K x N]
    auto gradBR = (*gradBias)({0,0, K,2*K,   0,0}, true);    // [bS x K x N]
    auto gradUZ = (*gradU)   ({0,0, 0,K,     0,0}, true);    // [bS x K x N]
    auto gradUF = (*gradU)   ({0,0, K,2*K,   0,0}, true);    // [bS x K x N]
    auto gradUR = (*gradU)   ({0,0, 2*K,3*K, 0,0}, true);    // [bS x K x N]

    NDArray* ct_1 = nullptr;

    std::vector<Nd4jLong> idx = {0,0, 0,0, 0,0};
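
    // the loop below walks backwards through time; per step it backpropagates
    // through h_t = r_t*tanh(c_t) + (1-r_t)*x_t and c_t = f_t*c_{t-1} + (1-f_t)*z_t,
    // i.e. (a sketch of the chain rule the code spells out element-wise):
    //     dL/dr_t = inGradH_t * (tanh(c_t) - x_t) * r_t * (1 - r_t)
    //     dL/dc_t = inGradH_t * r_t * (1 - tanh(c_t)^2) + inGradC_t
    //     dL/df_t = dL/dc_t * (c_{t-1} - z_t) * f_t * (1 - f_t)
    //     dL/dz_t = dL/dc_t * (1 - f_t)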
    for (int t = N-1; t >= 0; --t) {

        // initialization
        idx[4] = t;
        idx[5] = t + 1;
        auto xt = (*x)(idx);                 // [bS x K x N] -> [bS x K x 1] -> [bS x K]
        auto zt = wiZ(idx);                  // [bS x K x N] -> [bS x K x 1] -> [bS x K]
        auto ft = wiF(idx);                  // [bS x K x N] -> [bS x K x 1] -> [bS x K]
        auto rt = wiR(idx);                  // [bS x K x N] -> [bS x K x 1] -> [bS x K]
        auto ct = (*c)(idx);                 // [bS x K x N] -> [bS x K x 1] -> [bS x K]
        auto inGradHt = (*inGradH)(idx);     // [bS x K x N] -> [bS x K x 1] -> [bS x K]
        auto gradBRt  = gradBR(idx);         // [bS x K x N] -> [bS x K x 1] -> [bS x K]
        auto gradBFt  = gradBF(idx);         // [bS x K x N] -> [bS x K x 1] -> [bS x K]
        auto gradHXt  = (*gradHX)(idx);      // [bS x K x N] -> [bS x K x 1] -> [bS x K]
        auto gradUZt  = gradUZ(idx);         // [bS x K x N] -> [bS x K x 1] -> [bS x K]
        auto gradUFt  = gradUF(idx);         // [bS x K x N] -> [bS x K x 1] -> [bS x K]
        auto gradURt  = gradUR(idx);         // [bS x K x N] -> [bS x K x 1] -> [bS x K]

        if(t != 0) {
            idx[4] = t - 1;
            idx[5] = t;
            ct_1 = new NDArray((*c)(idx));   // previous c_{t-1}
        }
        else
            ct_1 = c0;

        ///////////////// forward
        // ft = sigmoid(ft + bF), rt = sigmoid(rt + bR)
        ft.addRowVector(bF, ft);
        rt.addRowVector(bR, rt);
        ft.applyTransform(transform::Sigmoid, ft);
        rt.applyTransform(transform::Sigmoid, rt);

        // TODO T val = (activation_type == 1) ? tanh(cur) : ((activation_type == 2) ? reluf(cur) : cur );
        ct.applyTransform(transform::Tanh, *gct);

        // ftMinus = 1-ft,  rtMinus = 1-rt
        ft.applyTransform(transform::OneMinus, *ftMinus);
        rt.applyTransform(transform::OneMinus, *rtMinus);

        ///////////////// backward
        // bR, *grad_brt_ptr = inGradHt * (g_ct - xt) * (1.0f - rt) * rt;
        gct->applyPairwiseTransform(pairwise::Subtract, xt, *temp1);              // temp1 = (g_ct - xt)
        rtMinus->applyPairwiseTransform(pairwise::Multiply, rt, *temp2);          // temp2 = (1.0f - rt) * rt
        temp1->applyPairwiseTransform(pairwise::Multiply, *temp2);                // temp1 = (g_ct - xt) * (1.0f - rt) * rt
        inGradHt.applyPairwiseTransform(pairwise::Multiply, *temp1, gradBRt);     // gradBRt = inGradHt * (g_ct - xt) * (1.0f - rt) * rt

        // bF, TODO - tanh
        // gradTanh = (1.0f - g_ct * g_ct);
        gct->applyPairwiseTransform(pairwise::Multiply, *gct, *gradTanh);         // gradTanh = g_ct * g_ct
        gradTanh->applyTransform(transform::OneMinus, *gradTanh);                 // gradTanh = (1.0f - g_ct * g_ct)
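
        // d/dc tanh(c) = 1 - tanh(c)^2, which is why gradTanh can be built from
        // the already-computed g_ct rather than re-evaluating the activation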
								        // gradCt  = inGradHt * rt * gradTanh
 
							 
						 
					
						
							
								
									
										
										
										
											2019-12-20 21:35:39 +02:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								        rt . applyPairwiseTransform ( pairwise : : Multiply ,  * gradTanh ,  * gradCt ) ;            // gradCt = rt * gradTanh
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        inGradHt . applyPairwiseTransform ( pairwise : : Multiply ,  * gradCt ,  * gradCt ) ;        // gradCt = inGradHt * rt * gradTanh

        // gradBFt = (gradCt + inGradCt) * (ct_1 - zt) * (1 - ft) * ft;
        gradCt->applyPairwiseTransform(pairwise::Add, *inGradCt, *temp1);              // temp1 = (gradCt + inGradCt)
        ct_1->applyPairwiseTransform(pairwise::Subtract, zt, *temp2);                  // temp2 = (ct_1 - zt)
        temp1->applyPairwiseTransform(pairwise::Multiply, *ftMinus, *temp1);           // temp1 = (gradCt + inGradCt)*(1-ft)
        temp1->applyPairwiseTransform(pairwise::Multiply, ft, *temp1);                 // temp1 = (gradCt + inGradCt)*(1-ft)*ft
        temp1->applyPairwiseTransform(pairwise::Multiply, *temp2, gradBFt);            // gradBFt = (gradCt + inGradCt) * (ct_1 - zt) * (1 - ft) * ft;
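        // note: the forget-gate gradient mirrors the reset-gate one: from c_t = ft * c_{t-1} + (1 - ft) * zt,
        // the total cell gradient (gradCt + inGradCt) gets scaled by (ct_1 - zt) * ft * (1 - ft)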

        // x_t (highway connection), gradHXt = inGradHt * (1.0f - rt);
        inGradHt.applyPairwiseTransform(pairwise::Multiply, *rtMinus, gradHXt);
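        // note: this is the direct (highway) path of x_t into h_t; gradHXt is accumulated per time step
        // and added into gradX after the weight-transpose matmul below (see "+ grad_highway_x")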

        // U_t, gradUZt = (inGradHt * rt * grad_tanh + inGradCt) * (1.0f - ft);
        rt.applyPairwiseTransform(pairwise::Multiply, *gradTanh, *temp1);              // temp1 = rt * grad_tanh
        inGradHt.applyPairwiseTransform(pairwise::Multiply, *temp1, *temp1);           // temp1 = inGradHt * rt * grad_tanh
        temp1->applyPairwiseTransform(pairwise::Add, *inGradCt, *temp1);               // temp1 = inGradHt * rt * grad_tanh + inGradCt
        temp1->applyPairwiseTransform(pairwise::Multiply, *ftMinus, gradUZt);          // gradUZt = (inGradHt * rt * grad_tanh + inGradCt) * (1.0f - ft);

        gradUFt.assign(&gradBFt);
        gradURt.assign(&gradBRt);
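        // note: the f and r pre-activation gradients serve double duty: as the per-step bias gradients
        // (reduced into gradB below) and as rows of gradU, from which gradW is obtained by matmul with x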

        // c_{t-1}, inGradCt = (gradCt + inGradCt) * ft;
        gradCt->applyPairwiseTransform(pairwise::Add, *inGradCt, *temp1);              // temp1 = (gradCt + inGradCt)
        temp1->applyPairwiseTransform(pairwise::Multiply, ft, *inGradCt);              // inGradCt = (gradCt + inGradCt) * ft;
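        // note: multiplying the summed cell gradient by ft carries it one step back in time,
        // since c_{t-1} enters c_t only through the term ft * c_{t-1}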

        if(t != 0)
            delete ct_1;
    }

    // gradInit
    gradInit->assign(inGradCt);
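    // note: after the backward time loop inGradCt holds dE/dc_0, the gradient w.r.t. the initial
    // cell state, which is exactly what gradInit must return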

    // gradX
    auto weightsT = w->transpose();                                     // [K x 3K]
    MmulHelper::mmul(&weightsT, gradU, gradX, 1., 0.);                  // [bS x K x N]
    gradX->applyPairwiseTransform(pairwise::Add, *gradHX, *gradX);      // + grad_highway_x
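    // note: the total input gradient sums the two paths by which x reaches the loss:
    // the linear projection (weightsT x gradU) and the per-step highway contribution in gradHX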

    if(applyMask)
        gradX->applyBroadcast(broadcast::Multiply, {0, 1}, *mask, *gradX);  // apply mask

    // gradB
    auto gradB2 = gradB->reshape(gradB->ordering(), {2*K});
    gradBias->reduceAlongDimension(reduce::Sum, gradB2, {0, 2});        // [1 x 2K]
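    // note: gradBias holds the per-step f/r bias gradients (by the shapes above presumably [bS x 2K x N]);
    // summing over batch (dim 0) and time (dim 2) leaves the [2K] bias gradient in gradB2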

    // gradW [bS x 3K x K]
    x->permutei({0, 2, 1});                                             // [bS x N x K]
    MmulHelper::mmul(gradU, x, gradW, 1., 0.);                          // [bS x 3K x K]
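    // note: judging by the trailing shape comments, gradU is [bS x 3K x N], so this batched matmul
    // with the permuted input yields per-batch weight gradients of shape [bS x 3K x K]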

    delete gct;   delete gradU; delete gradHX;
    delete temp1; delete temp2; delete gradCt; delete wi;
    delete gradTanh; delete ftMinus; delete rtMinus; delete gradBias;

    return Status::OK();
}

        DECLARE_TYPES(sru_bp) {
            getOpDescriptor()
                    ->setAllowedInputTypes(sd::DataType::ANY)
                    ->setAllowedOutputTypes({ALL_FLOATS});
        }

DECLARE_SHAPE_FN(sru_bp) {

    auto inShape = inputShape->at(0);   // [bS x inSize x time]
    auto bS      = inShape[1];
    auto inSize  = inShape[2];
    auto time    = inShape[3];
    char order   = (char)(inShape[9]);
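    // note: inShape is a raw shape-info buffer rather than an NDArray: index 0 holds the rank,
    // indices 1..rank the dimensions, and the last slot (index 2*rank + 3 = 9 here) the ordering char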

    ShapeDescriptor descriptor1(ArrayOptions::dataType(inShape), order, {bS, inSize, time});
    ShapeDescriptor descriptor2(ArrayOptions::dataType(inShape), order, {bS, 3 * inSize, inSize});
    ShapeDescriptor descriptor3(ArrayOptions::dataType(inShape), order, {1, 2 * inSize});
    ShapeDescriptor descriptor4(ArrayOptions::dataType(inShape), order, {bS, inSize});
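    // note: the four descriptors correspond, in order, to the four sru_bp outputs computed above:
    // gradX [bS x inSize x time], gradW [bS x 3*inSize x inSize], gradB [1 x 2*inSize], gradInit [bS x inSize]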

    return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(descriptor1), ConstantShapeHelper::getInstance().createShapeInfo(descriptor2), ConstantShapeHelper::getInstance().createShapeInfo(descriptor3), ConstantShapeHelper::getInstance().createShapeInfo(descriptor4));
}

//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(sru_bi, 5, 2, true, 0, 0) {

    auto x  = INPUT_VARIABLE(0);                                        // X, input 3d tensor [time x bS x 2*inSize], time - number of time steps, bS - batch size, inSize - number of features
    auto w  = INPUT_VARIABLE(1);                                        // W, 2d tensor of weights [2*inSize x 6*inSize]
    auto b  = INPUT_VARIABLE(2);                                        // B, row of biases with twice length [1 x 4*inSize]
    auto c0 = INPUT_VARIABLE(3);                                        // C_{0}, 2d tensor of initial state [bS x 2*inSize] at time t=0
    NDArray* mask = block.width() > 4 ? INPUT_VARIABLE(4) : nullptr;    // optional, 2d tensor of dropout mask [bS x 2*inSize]

    auto ht = OUTPUT_VARIABLE(0);               // h_t, [time x bS x 2*inSize]
    auto ct = OUTPUT_VARIABLE(1);               // c_t, [time x bS x 2*inSize]
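    // note: sru_bi is the bidirectional variant, hence the doubled feature dimension everywhere:
    // forward and backward directions are concatenated, giving 2*inSize features per time step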

    // input shapes validation
    const int rank = 3;                         // x must be a 3d tensor [time x bS x 2*inSize]
    const Nd4jLong bS     = x->sizeAt(1);
    const Nd4jLong inSize = x->sizeAt(2) / 2;   // x concatenates both directions, hence the halving

    REQUIRE_TRUE(x->rankOf()  == rank,   0, "SRU_BI operation: wrong rank of input array, expected is %i, but got %i instead !", rank, x->rankOf());
    REQUIRE_TRUE(w->rankOf()  == rank-1, 0, "SRU_BI operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, w->rankOf());
    REQUIRE_TRUE(b->rankOf()  == 1,      0, "SRU_BI operation: wrong rank of biases array, expected is 1, but got %i instead !", b->rankOf());
    REQUIRE_TRUE(c0->rankOf() == rank-1, 0, "SRU_BI operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0->rankOf());
    if(mask)
        REQUIRE_TRUE(mask->rankOf() == rank-1, 0, "SRU_BI operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, mask->rankOf());
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2020-03-03 06:32:37 +02:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								    const  std : : vector < Nd4jLong >  wCorrectShape   =  { 2 * inSize ,  6 * inSize } ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    const  std : : vector < Nd4jLong >  bCorrectShape   =  { 4 * inSize } ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    const  std : : vector < Nd4jLong >  c0CorrectShape  =  { bS ,  2 * inSize } ; 
							 
						 
					
						
							
								
									
										
										
										
											2019-06-06 15:21:15 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2020-03-03 06:32:37 +02:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								    REQUIRE_TRUE ( w - > isSameShape ( wCorrectShape ) ,   0 ,  " SRU_BI operation: wrong shape of weights array, expected is %s, but got %s instead ! " ,  ShapeUtils : : shapeAsString ( wCorrectShape ) . c_str ( ) ,  ShapeUtils : : shapeAsString ( w ) . c_str ( ) ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    REQUIRE_TRUE ( b - > isSameShape ( bCorrectShape ) ,   0 ,  " SRU_BI operation: wrong shape of biases array, expected is %s, but got %s instead ! " ,  ShapeUtils : : shapeAsString ( bCorrectShape ) . c_str ( ) ,  ShapeUtils : : shapeAsString ( b ) . c_str ( ) ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    REQUIRE_TRUE ( c0 - > isSameShape ( c0CorrectShape ) ,  0 ,  " SRU_BI operation: wrong shape of initial state array, expected is %s, but got %s instead ! " ,  ShapeUtils : : shapeAsString ( c0CorrectShape ) . c_str ( ) ,  ShapeUtils : : shapeAsString ( c0 ) . c_str ( ) ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    if ( mask ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								        REQUIRE_TRUE ( mask - > isSameShape ( c0CorrectShape ) ,  0 ,  " SRU_BI operation: wrong shape of mask array, expected is %s, but got %s instead ! " ,  ShapeUtils : : shapeAsString ( c0CorrectShape ) . c_str ( ) ,  ShapeUtils : : shapeAsString ( mask ) . c_str ( ) ) ; 
							 
						 
					
						
							
								
									
										
										
										
											2019-06-06 15:21:15 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    helpers : : sruBI ( block . launchContext ( ) ,  x ,  w ,  b ,  c0 ,  mask ,  ht ,  ct ) ; 
							 
						 
					
						
							
								
									
										
										
										
											2019-07-12 11:51:51 +03:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2019-06-06 15:21:15 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
							
								    return  Status : : OK ( ) ; 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								} 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
        DECLARE_TYPES(sru_bi) {
            getOpDescriptor()
                    ->setAllowedInputTypes(sd::DataType::ANY)
                    ->setAllowedOutputTypes({ALL_FLOATS});
        }

DECLARE_SHAPE_FN(sru_bi) {

    auto xShapeInfo    = inputShape->at(0);         // [time x bS x 2K]
    auto wShapeInfo    = inputShape->at(1);
    auto bShapeInfo    = inputShape->at(2);
    auto c0ShapeInfo   = inputShape->at(3);
    auto maskShapeInfo = block.width() > 4 ? inputShape->at(4) : nullptr;      // optional, 2d tensor of dropout mask [bS x 2*inSize]

    const int      rank   = xShapeInfo[0];          // = 3
    const Nd4jLong time   = xShapeInfo[1];
    const Nd4jLong bS     = xShapeInfo[2];
    const Nd4jLong inSize = xShapeInfo[3] / 2;

    // input shapes validation
    REQUIRE_TRUE(wShapeInfo[0]  == rank-1, 0, "SRU_BI operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, wShapeInfo[0]);
    REQUIRE_TRUE(bShapeInfo[0]  == 1,      0, "SRU_BI operation: wrong rank of biases array, expected is 1, but got %i instead !", bShapeInfo[0]);
    REQUIRE_TRUE(c0ShapeInfo[0] == rank-1, 0, "SRU_BI operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0ShapeInfo[0]);
    if(maskShapeInfo)
        REQUIRE_TRUE(maskShapeInfo[0] == rank-1, 0, "SRU_BI operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, maskShapeInfo[0]);

    const std::vector<Nd4jLong> wCorrectShape  = {2*inSize, 6*inSize};
    const std::vector<Nd4jLong> bCorrectShape  = {4*inSize};
    const std::vector<Nd4jLong> c0CorrectShape = {bS, 2*inSize};

    REQUIRE_TRUE(ShapeUtils::areShapesEqual(wShapeInfo, wCorrectShape),   0, "SRU_BI operation: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wCorrectShape).c_str(), ShapeUtils::shapeAsString(wShapeInfo).c_str());
    REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, bCorrectShape),   0, "SRU_BI operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str());
    REQUIRE_TRUE(ShapeUtils::areShapesEqual(c0ShapeInfo, c0CorrectShape), 0, "SRU_BI operation: wrong shape of initial state array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(c0ShapeInfo).c_str());
    if(maskShapeInfo)
        REQUIRE_TRUE(ShapeUtils::areShapesEqual(maskShapeInfo, c0CorrectShape), 0, "SRU_BI operation: wrong shape of mask array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(maskShapeInfo).c_str());

    char order = shape::order(xShapeInfo);

    ShapeDescriptor descriptor(ArrayOptions::dataType(xShapeInfo), order, {time, bS, 2 * inSize});
    auto result = ConstantShapeHelper::getInstance().createShapeInfo(descriptor);
    return SHAPELIST(result, result);
}
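
//////////////////////////////////////////////////////////////////////////
// Illustrative only: a minimal sketch of driving sru_bi through the generic
// DeclarableOp interface, in the style of this repository's tests. It assumes
// NDArrayFactory and DeclarableOp::evaluate() are available in the build; the
// concrete sizes below are made up. Shapes follow the comments above.
//
//     const int time = 4, bS = 2, inSize = 3;
//     auto x  = NDArrayFactory::create<float>('c', {time, bS, 2*inSize});   // input sequence
//     auto w  = NDArrayFactory::create<float>('c', {2*inSize, 6*inSize});   // packed weights
//     auto b  = NDArrayFactory::create<float>('c', {4*inSize});             // packed biases
//     auto c0 = NDArrayFactory::create<float>('c', {bS, 2*inSize});         // initial cell state
//
//     sd::ops::sru_bi op;
//     auto results = op.evaluate({&x, &w, &b, &c0});                        // dropout mask omitted
//     auto ht = results.at(0);                                              // [time, bS, 2*inSize]
//     auto ct = results.at(1);                                              // [time, bS, 2*inSize]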
        DECLARE_TYPES(sru_bi_bp) {
            getOpDescriptor()
                    ->setAllowedInputTypes(sd::DataType::ANY)
                    ->setAllowedOutputTypes({ALL_FLOATS});
        }

//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(sru_bi_bp, 8, 4, true, 0, 0) {

    auto x        = INPUT_VARIABLE(0);                 // X, input 3d tensor [time x bS x 2*inSize], time - number of time steps, bS - batch size, inSize - number of features
    auto w        = INPUT_VARIABLE(1);                 // W, 2d tensor of weights [2*inSize x 6*inSize]
    auto b        = INPUT_VARIABLE(2);                 // B, row of biases with twice length [4*inSize]
    auto c0       = INPUT_VARIABLE(3);                 // C_{0}, 2d tensor of initial state [bS x 2*inSize] at time t=0
    auto ct       = INPUT_VARIABLE(4);                 // C, [time x bS x 2*inSize]
    auto inGradC0 = INPUT_VARIABLE(5);                 // [bS x 2*inSize]
    auto inGradHt = INPUT_VARIABLE(6);                 // [time x bS x 2*inSize]
    NDArray* mask = block.width() > 7 ? INPUT_VARIABLE(7) : nullptr;   // optional, 2d tensor of dropout mask [bS x 2*inSize]

    // input shapes validation
    const int rank        = x->rankOf();
    const Nd4jLong time   = x->sizeAt(0);
    const Nd4jLong bS     = x->sizeAt(1);
    const Nd4jLong inSize = x->sizeAt(2) / 2;

    REQUIRE_TRUE(w->rankOf()        == rank-1, 0, "SRU_BI_BP operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, w->rankOf());
    REQUIRE_TRUE(b->rankOf()        == 1,      0, "SRU_BI_BP operation: wrong rank of biases array, expected is 1, but got %i instead !", b->rankOf());
    REQUIRE_TRUE(c0->rankOf()       == rank-1, 0, "SRU_BI_BP operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0->rankOf());
    REQUIRE_TRUE(ct->rankOf()       == rank,   0, "SRU_BI_BP operation: wrong rank of state array, expected is %i, but got %i instead !", rank, ct->rankOf());
    REQUIRE_TRUE(inGradC0->rankOf() == rank-1, 0, "SRU_BI_BP operation: wrong rank of gradient c0, expected is %i, but got %i instead !", rank-1, inGradC0->rankOf());
    REQUIRE_TRUE(inGradHt->rankOf() == rank,   0, "SRU_BI_BP operation: wrong rank of gradient ht, expected is %i, but got %i instead !", rank, inGradHt->rankOf());
    if(mask)
        REQUIRE_TRUE(mask->rankOf() == rank-1, 0, "SRU_BI_BP operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, mask->rankOf());

    const std::vector<Nd4jLong> wCorrectShape  = {2*inSize, 6*inSize};
    const std::vector<Nd4jLong> bCorrectShape  = {4*inSize};
    const std::vector<Nd4jLong> c0CorrectShape = {bS, 2*inSize};
    const std::vector<Nd4jLong> ctCorrectShape = {time, bS, 2*inSize};

    REQUIRE_TRUE(w->isSameShape(wCorrectShape),   0, "SRU_BI_BP operation: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wCorrectShape).c_str(), ShapeUtils::shapeAsString(w).c_str());
    REQUIRE_TRUE(b->isSameShape(bCorrectShape),   0, "SRU_BI_BP operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str());
    REQUIRE_TRUE(c0->isSameShape(c0CorrectShape), 0, "SRU_BI_BP operation: wrong shape of initial state array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(c0).c_str());
    REQUIRE_TRUE(ct->isSameShape(ctCorrectShape), 0, "SRU_BI_BP operation: wrong shape of state array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(ctCorrectShape).c_str(), ShapeUtils::shapeAsString(ct).c_str());
    if(mask)
        REQUIRE_TRUE(mask->isSameShape(c0CorrectShape), 0, "SRU_BI_BP operation: wrong shape of mask array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(mask).c_str());

    auto gradI  = OUTPUT_VARIABLE(0);               // [time x bS x 2*inSize]
    auto gradW  = OUTPUT_VARIABLE(1);               // [time x 2*inSize x 6*inSize]
    auto gradB  = OUTPUT_VARIABLE(2);               // [1 x 4*inSize]
    auto gradC0 = OUTPUT_VARIABLE(3);               // [bS x 2*inSize]

    helpers::sruBIBP(block.launchContext(), x, w, b, c0, ct, inGradC0, inGradHt, mask, gradI, gradW, gradB, gradC0);

    return Status::OK();
}

DECLARE_SHAPE_FN(sru_bi_bp) {

    auto xShapeInfo        = inputShape->at(0);          // [time x bS x 2K]
    auto wShapeInfo        = inputShape->at(1);
    auto bShapeInfo        = inputShape->at(2);
    auto c0ShapeInfo       = inputShape->at(3);
    auto ctShapeInfo       = inputShape->at(4);
    auto inGradC0ShapeInfo = inputShape->at(5);
    auto inGradHtShapeInfo = inputShape->at(6);
    auto maskShapeInfo     = block.width() > 7 ? inputShape->at(7) : nullptr;      // optional, 2d tensor of dropout mask [bS x 2*inSize]

    // input shapes validation
    const int      rank   = xShapeInfo[0];
    const Nd4jLong time   = xShapeInfo[1];
    const Nd4jLong bS     = xShapeInfo[2];
    const Nd4jLong inSize = xShapeInfo[3] / 2;

    REQUIRE_TRUE(wShapeInfo[0]        == rank-1, 0, "SRU_BI_BP operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, wShapeInfo[0]);
    REQUIRE_TRUE(bShapeInfo[0]        == 1,      0, "SRU_BI_BP operation: wrong rank of biases array, expected is 1, but got %i instead !", bShapeInfo[0]);
    REQUIRE_TRUE(c0ShapeInfo[0]       == rank-1, 0, "SRU_BI_BP operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0ShapeInfo[0]);
    REQUIRE_TRUE(ctShapeInfo[0]       == rank,   0, "SRU_BI_BP operation: wrong rank of state array, expected is %i, but got %i instead !", rank, ctShapeInfo[0]);
    REQUIRE_TRUE(inGradC0ShapeInfo[0] == rank-1, 0, "SRU_BI_BP operation: wrong rank of gradient c0, expected is %i, but got %i instead !", rank-1, inGradC0ShapeInfo[0]);
    REQUIRE_TRUE(inGradHtShapeInfo[0] == rank,   0, "SRU_BI_BP operation: wrong rank of gradient ht, expected is %i, but got %i instead !", rank, inGradHtShapeInfo[0]);
    if(maskShapeInfo)
        REQUIRE_TRUE(maskShapeInfo[0] == rank-1, 0, "SRU_BI_BP operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, maskShapeInfo[0]);

    const std::vector<Nd4jLong> wCorrectShape        = {2*inSize, 6*inSize};
    const std::vector<Nd4jLong> bCorrectShape        = {4*inSize};
    const std::vector<Nd4jLong> c0CorrectShape       = {bS, 2*inSize};
    const std::vector<Nd4jLong> ctCorrectShape       = {time, bS, 2*inSize};
    const std::vector<Nd4jLong> inGradC0CorrectShape = {bS, 2*inSize};
    const std::vector<Nd4jLong> inGradHtCorrectShape = {time, bS, 2*inSize};

    REQUIRE_TRUE(ShapeUtils::areShapesEqual(wShapeInfo, wCorrectShape), 0, "SRU_BI_BP operation: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wCorrectShape).c_str(), ShapeUtils::shapeAsString(wShapeInfo).c_str());
    REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, bCorrectShape), 0, "SRU_BI_BP operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str());
    REQUIRE_TRUE(ShapeUtils::areShapesEqual(c0ShapeInfo, c0CorrectShape), 0, "SRU_BI_BP operation: wrong shape of initial state array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(c0ShapeInfo).c_str());
    REQUIRE_TRUE(ShapeUtils::areShapesEqual(ctShapeInfo, ctCorrectShape), 0, "SRU_BI_BP operation: wrong shape of state array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(ctCorrectShape).c_str(), ShapeUtils::shapeAsString(ctShapeInfo).c_str());
    REQUIRE_TRUE(ShapeUtils::areShapesEqual(inGradC0ShapeInfo, inGradC0CorrectShape), 0, "SRU_BI_BP operation: wrong shape of gradient c0 array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(inGradC0CorrectShape).c_str(), ShapeUtils::shapeAsString(inGradC0ShapeInfo).c_str());
    REQUIRE_TRUE(ShapeUtils::areShapesEqual(inGradHtShapeInfo, inGradHtCorrectShape), 0, "SRU_BI_BP operation: wrong shape of gradient ht array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(inGradHtCorrectShape).c_str(), ShapeUtils::shapeAsString(inGradHtShapeInfo).c_str());
    if(maskShapeInfo)
        REQUIRE_TRUE(ShapeUtils::areShapesEqual(maskShapeInfo, c0CorrectShape), 0, "SRU_BI_BP operation: wrong shape of mask array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(c0CorrectShape).c_str(), ShapeUtils::shapeAsString(maskShapeInfo).c_str());

    const char order = shape::order(xShapeInfo);

    ShapeDescriptor descriptor1(ArrayOptions::dataType(xShapeInfo), order, {time, bS, 2 * inSize});
    ShapeDescriptor descriptor2(ArrayOptions::dataType(xShapeInfo), order, {time, 2 * inSize, 6 * inSize});
    ShapeDescriptor descriptor3(ArrayOptions::dataType(xShapeInfo), order, {4 * inSize});
    ShapeDescriptor descriptor4(ArrayOptions::dataType(xShapeInfo), order, {bS, 2 * inSize});

    return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(descriptor1), ConstantShapeHelper::getInstance().createShapeInfo(descriptor2), ConstantShapeHelper::getInstance().createShapeInfo(descriptor3), ConstantShapeHelper::getInstance().createShapeInfo(descriptor4));
}
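
//////////////////////////////////////////////////////////////////////////
// Illustrative only: a matching backward-pass sketch under the same assumptions
// as the sru_bi example above. The inputs mirror the validation in
// CUSTOM_OP_IMPL(sru_bi_bp, ...): ct is the forward op's second output, and the
// two gradient inputs would normally come from whatever consumes ht and the
// final cell state (the optional dropout mask is omitted here).
//
//     sd::ops::sru_bi_bp opBP;
//     auto results = opBP.evaluate({&x, &w, &b, &c0, &ct, &inGradC0, &inGradHt});
//     auto gradI  = results.at(0);   // [time, bS, 2*inSize]
//     auto gradW  = results.at(1);   // [time, 2*inSize, 6*inSize]
//     auto gradB  = results.at(2);   // [4*inSize]
//     auto gradC0 = results.at(3);   // [bS, 2*inSize]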
}
}

#endif

//////////////////////////////////////////////////////////////////////////
    /**
       * Implementation of operations for Simple Recurrent Unit: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi
       *
       * Input arrays:
       *    0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - batch size, K - number of features
       *    1: 2d tensor of weights [3K x K]
       *    2: row of biases with twice length [1 x 2K]
       *    3: 2d tensor of previous cell state [bS x K]
       *    4: optional, 2d tensor of dropout mask [bS x K]
       *
       * Output arrays:
       *    0: 3d tensor of cell output [bS x K x N]
       *    1: 3d tensor of cell state [bS x K x N]
       */
        // #if NOT_EXCLUDED(OP_sru)
        // DECLARE_CUSTOM_OP(sru_old,       5, 2, false, 0, 0);
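
        // For reference, the per-time-step cell from the paper cited above
        // (all '*' products element-wise, 'W.x_t' a matrix product, g = tanh):
        //
        //     f_t = sigmoid(W_f.x_t + b_f)                  // forget gate
        //     r_t = sigmoid(W_r.x_t + b_r)                  // reset gate
        //     c_t = f_t * c_{t-1} + (1 - f_t) * (W.x_t)     // cell state
        //     h_t = r_t * g(c_t) + (1 - r_t) * x_t          // output with highway connection
        //
        // which is why the weights pack three matrices ([3K x K]) and the biases
        // two rows ([1 x 2K]).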

    //////////////////////////////////////////////////////////////////////////
    /**
       * Implementation of operation for Simple Recurrent Unit: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi
       *
       * Input arrays:
       *    0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - batch size, K - number of features
       *    1: 2d tensor of weights [3K x K]
       *    2: row of biases with twice length [1 x 2K]
       *    3: 2d tensor of previous cell state [bS x K]
       *    4: optional, 2d tensor of dropout mask [bS x K]
       *
       * Output arrays:
       *    0: 3d tensor of cell output [bS x K x N]
       *    1: 3d tensor of cell state [bS x K x N]
       */
        // #if NOT_EXCLUDED(OP_sru_logic)
        // DECLARE_CUSTOM_OP(sru_logic,   5, 2, false, 0, 0);
        // #endif

//////////////////////////////////////////////////////////////////////////
    /**
       * Implementation of operation for back propagation in Simple Recurrent Unit: "Training RNNs as Fast as CNNs" Tao Lei, Yu Zhang, Yoav Artzi
       *
       * Input arrays:
       *     0: input 3d tensor with shape [bS x K x N], N - number of time steps, bS - batch size, K - number of features
       *     1: 2d tensor of weights [3K x K]
       *     2: row of biases with twice the length [1 x 2K]
       *     3: 2d tensor of previous cell state [bS x K]
       *     4: 3d tensor of cell state [bS x K x N]
       *     5: 2d tensor of cell state gradients [bS x K]
       *     6: 3d tensor of state output gradients [bS x K x N]
       *     7: optional, 2d tensor of dropout mask [bS x K]
       *
       * Output arrays:
       *     0: 3d tensor of input gradients [bS x K x N]
       *     1: 3d tensor of weights gradients [bS x 3K x K]
       *     2: 2d, row of biases gradients [1 x 2K]
       *     3: 2d, tensor of state gradients [bS x K]
       */
        // #if NOT_EXCLUDED(OP_sru_logic)
        // DECLARE_CUSTOM_OP(sru_bp_logic,8, 4, true,  0, 0);
        // #endif

// return 2d array taken from the last-dimension interval [t1, t2)
// static NDArray* timestep(const NDArray* arr, const int t1, const int t2) {
//         NDArray* result = new NDArray((*arr)({0,0, 0,0, t1,t2}, true));
//         result->reshapei(result->ordering(), {arr->shapeOf()[0], arr->shapeOf()[1]} );
//         return result;
// }
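
// note: timestep() returns a freshly allocated slice copy, so callers own the result and must
// delete it, as sru_old below does at the end of every loop iteration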
/////////////////////////////////////////////////////////////////////////
// CUSTOM_OP_IMPL(sru_logic, 5, 2, false, 0, 0) {
//     auto input   = INPUT_VARIABLE(0);                // X, input 3d tensor [bS x K x N], N - number of time steps, bS - batch size, K - number of features
//     auto weights = INPUT_VARIABLE(1);                // W, 2d tensor of weights [3K x K]
//     auto bias    = INPUT_VARIABLE(2);                // B, row of biases with twice length [1 x 2*K]
//     auto init    = INPUT_VARIABLE(3);                // C_{0}, 2d tensor of initial state [bS x K] at time t=0
//     NDArray* mask    = nullptr;                          // optional,  2d tensor of dropout mask [bS x K]
//     bool applyMask = false;
//     if (block.width() > 4) {
//         mask = INPUT_VARIABLE(4);
//         applyMask = true;
//     }
//     auto output = OUTPUT_VARIABLE(0);                // h_t, [bS x K x N]
//     auto state  = OUTPUT_VARIABLE(1);                // c_t, [bS x K x N]
//     const int bS     = input->shapeOf()[0];                     // bS - batch size
//     const int K      = input->shapeOf()[1];                     // K - number of features
//     const int N      = input->shapeOf()[2];                     // N - number of time steps
//     const auto wi = mmul(*weights, *input);                    //  U [bS x 3K x N]
//     const auto bF = (*bias)({0,0,  0,  K});                       // biases for forget gate [1 x K]
//     const auto bR = (*bias)({0,0,  K,2*K});                       // biases for reset  gate [1 x K]
//     NDArray xt(input->dataType(), block.launchContext());
//     NDArray zt(input->dataType(), block.launchContext());
//     NDArray ft(input->dataType(), block.launchContext());
//     NDArray rt(input->dataType(), block.launchContext());
//     NDArray ht(input->dataType(), block.launchContext());
//     NDArray ct = *init;
//     NDArray gct(state->ordering(), {bS, K}, input->dataType(), block.launchContext());
//     NDArray xmt = *input;
//     //  input = input * mask
//     if(applyMask)
//         xmt.applyBroadcast(broadcast::Multiply, {0, 1}, mask, &xmt, nullptr);
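//         // note: broadcasting over dims {0, 1} tiles the [bS x K] mask across all N time steps,
//         // so a dropped feature stays dropped for the whole sequence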
//     for (int t = 0; t < N; ++t) {
//         xt = xmt({0,0, 0,0,     t,t+1}); xt.reshapei(xt.ordering(), {bS, K});       // [bS x  K x N] -> [bS x K x 1] -> [bS x K]
//         zt =  wi({0,0, 0,    K, t,t+1}); zt.reshapei(zt.ordering(), {bS, K});       // [bS x 3K x N] -> [bS x K x 1] -> [bS x K]
//         ft =  wi({0,0, K,  2*K, t,t+1}); ft.reshapei(ft.ordering(), {bS, K});       // [bS x 3K x N] -> [bS x K x 1] -> [bS x K]
//         rt =  wi({0,0, 2*K,3*K, t,t+1}); rt.reshapei(rt.ordering(), {bS, K});       // [bS x 3K x N] -> [bS x K x 1] -> [bS x K]
//         ft = sigmoid_(ft + bF);
//         rt = sigmoid_(rt + bR);
//         ct = ft * (ct - zt) + zt;
//         // TODO T val = (activation_type == 1) ? tanh(cur) : ((activation_type == 2) ? reluf(cur) : cur );
//         ct.applyTransform(transform::Tanh, &gct);
//         ht = rt * (gct - xt) + xt;
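//         // note: ft * (ct - zt) + zt is algebraically ft * c_{t-1} + (1 - ft) * zt, and
//         // rt * (gct - xt) + xt is rt * tanh(ct) + (1 - rt) * xt; the rearranged form
//         // avoids materializing the (1 - gate) temporaries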
//         // save results
//         (*output)({0,0, 0,0, t,t+1}, true).assign(ht);
//         (*state)({0,0, 0,0, t,t+1}, true).assign(ct);
//     }
//     return Status::OK();
// }
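
// a minimal invocation sketch, assuming the op were compiled back in (names and shapes illustrative):
//     sd::ops::sru_logic op;
//     // input [bS x K x N], weights [3K x K], bias [1 x 2K], init [bS x K]
//     auto results = op.evaluate({&input, &weights, &bias, &init});
//     auto h = results.at(0);     // cell outputs, [bS x K x N]
//     auto c = results.at(1);     // cell states,  [bS x K x N]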
//         DECLARE_TYPES(sru_logic) {
//             getOpDescriptor()
//                     ->setAllowedInputTypes(sd::DataType::ANY)
//                     ->setAllowedOutputTypes({ALL_FLOATS});
//         }

// DECLARE_SHAPE_FN(sru_logic) {
//     auto inShape = inputShape->at(0);   // [bS x K x N]
//     int rank = inShape[0];              // = 3
//     int size = rank*2 + 4;
//     int bS   = inShape[1];
//     int K    = inShape[2];
//     int N    = inShape[3];
//     char order = (char)(inShape[size-1]);
//     Nd4jLong* newShapeInfo1 = nullptr;
//     ALLOCATE(newShapeInfo1, block.getWorkspace(), size, Nd4jLong);
//     newShapeInfo1[0] = rank;
//     newShapeInfo1[1] = bS;
//     newShapeInfo1[2] = K;
//     newShapeInfo1[3] = N;
//     ShapeUtils::updateStridesAndType(newShapeInfo1, inShape, order);
//     auto result = CONSTANT(newShapeInfo1);
//     return SHAPELIST(result, result);
// }
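
// note: both outputs of sru_logic (h and c) have the same [bS x K x N] shape as the input,
// which is why the shape function returns the same shape info twice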
// //////////////////////////////////////////////////////////////////////////
// CUSTOM_OP_IMPL(sru_old, 5, 2, false, 0, 0) {
//     auto x   = INPUT_VARIABLE(0);                // X, input 3d tensor [bS x inSize x time], time - number of time steps, bS - batch size, inSize - number of features
//     auto w = INPUT_VARIABLE(1);                // W, 2d tensor of weights [3K x inSize]
//     auto b    = INPUT_VARIABLE(2);                // B, row of biases with twice length [1 x 2*inSize]
//     auto c0    = INPUT_VARIABLE(3);                // C_{0}, 2d tensor of initial state [bS x inSize] at time t=0
//     NDArray* mask    = nullptr;                          // optional,  2d tensor of dropout mask [bS x inSize]
//     bool applyMask = false;
//     if (block.width() > 4) {
//         mask = INPUT_VARIABLE(4);
//         applyMask = true;
//     }
//     auto h = OUTPUT_VARIABLE(0);                // h_t, [bS x inSize x time]
//     auto state  = OUTPUT_VARIABLE(1);                // c_t, [bS x inSize x time]
//     const int bS     = x->shapeOf()[0];                     // bS - batch size
//     const int inSize      = x->shapeOf()[1];                     // inSize - number of features
//     const int time      = x->shapeOf()[2];                     // time - number of time steps
//       // multiplication matrix = matmul(w,x)
//     auto wi = MmulHelper::mmul(w, x, nullptr, 1., 0.);            // U [bS x 3K x time]
//     auto wiZ = (*wi)({0,0,  0,inSize,          0,0}, true);       // [bS x inSize x time]
//     auto wiF = (*wi)({0,0,  inSize,2*inSize,   0,0}, true);       // forget gate [bS x inSize x time]
//     auto wiR = (*wi)({0,0,  2*inSize,3*inSize, 0,0}, true);       // reset gate [bS x inSize x time]
//     auto bF  = (*b) ({0,0,  0,inSize       }, true);              // biases for forget gate [1 x inSize]
//     auto bR  = (*b) ({0,0,  inSize,2*inSize}, true);              // biases for reset gate [1 x inSize]
//     NDArray* xt(nullptr), *zt(nullptr), *ft(nullptr), *rt(nullptr), *ct(nullptr), *ht(nullptr);
//     auto ct_1 = c0->dup(c0->ordering());
//     auto gct  = NDArrayFactory::create_(state->ordering(), {bS, inSize}, state->dataType(), state->getContext());
//     auto xmt  = x->dup(x->ordering());
//     //  x = x * mask
//     if(applyMask)
//         xmt->applyBroadcast(broadcast::Multiply, {0, 1}, mask, xmt, nullptr);            // apply mask
//     for (int t = 0; t < time; ++t) {
//         xt = timestep(xmt, t, t+1);         // [bS x inSize x time] -> [bS x inSize x 1] -> [bS x inSize]
//         zt = timestep(&wiZ, t, t+1);        // [bS x inSize x time] -> [bS x inSize x 1] -> [bS x inSize]
//         ft = timestep(&wiF, t, t+1);        // [bS x inSize x time] -> [bS x inSize x 1] -> [bS x inSize]
//         rt = timestep(&wiR, t, t+1);        // [bS x inSize x time] -> [bS x inSize x 1] -> [bS x inSize]
//         ct = timestep(state, t, t+1);       // [bS x inSize x time] -> [bS x inSize x 1] -> [bS x inSize]
//         ht = timestep(h, t, t+1);           // [bS x inSize x time] -> [bS x inSize x 1] -> [bS x inSize]
//         // ft = sigmoid(ft + bf), rt = sigmoid(rt + bR)
//         ft->addRowVector(&bF, ft);
//         rt->addRowVector(&bR, rt);
//         ft->applyTransform(transform::Sigmoid, ft, nullptr);
//         rt->applyTransform(transform::Sigmoid, rt, nullptr);
//         // ct = ft * c_t-1 + (1 - ft) * zt,
//         ft->applyPairwiseTransform(pairwise::Multiply, ct_1, ct, nullptr);
//         ft->applyTransform(transform::OneMinus, ft);
//         ft->applyPairwiseTransform(pairwise::Multiply, *zt, nullptr);
//         ct->applyPairwiseTransform(pairwise::Add, *ft, nullptr);
//         // TODO T val = (activation_type == 1) ? tanh(cur) : ((activation_type == 2) ? reluf(cur) : cur );
//         ct->applyTransform(transform::Tanh, gct);
//         // ht = rt * gct + (1 - rt) * xt
//         rt->applyPairwiseTransform(pairwise::Multiply, gct, ht, nullptr);
//         rt->applyTransform(transform::OneMinus, rt);
//         rt->applyPairwiseTransform(pairwise::Multiply, *xt, nullptr);
//         ht->applyPairwiseTransform(pairwise::Add, *rt, nullptr);
//         delete xt; delete zt; delete ft; delete rt; delete ht; delete ct_1;
//         ct_1 = ct;
//     }
//     delete wi; delete ct_1; delete gct; delete xmt;
//     return Status::OK();
// }
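
// note: unlike sru_logic above, which reuses preallocated xt/zt/ft/rt/ht buffers and slice views,
// this older variant allocates fresh copies through timestep() on every iteration and frees them
// at the end of each step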
//         DECLARE_TYPES(sru_old) {
//             getOpDescriptor()
//                     ->setAllowedInputTypes(sd::DataType::ANY)
//                     ->setAllowedOutputTypes({ALL_FLOATS});
//         }

// DECLARE_SHAPE_FN(sru_old) {
//     auto inShape = inputShape->at(0);   // [bS x inSize x time]
//     int rank = inShape[0];              // = 3
//     int size = rank*2 + 4;
//     auto bS   = inShape[1];
//     auto inSize    = inShape[2];
//     int time    = inShape[3];
//     char order = (char)(inShape[size-1]);
//     Nd4jLong *newShapeInfo1 = nullptr;
//     ALLOCATE(newShapeInfo1, block.getWorkspace(), size, Nd4jLong);
//     newShapeInfo1[0] = rank;
//     newShapeInfo1[1] = bS;
//     newShapeInfo1[2] = inSize;
//     newShapeInfo1[3] = time;
//     ShapeUtils::updateStridesAndType(newShapeInfo1, inShape, order);
//     auto result = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(newShapeInfo1));
//     RELEASE(newShapeInfo1, block.getWorkspace());
//     return SHAPELIST(result, result);
// }

// static NDArray sigmoid_(const NDArray& arr) {
//     NDArray result(arr.shapeInfo(), false, arr.getContext());
//     (const_cast<NDArray&>(arr)).applyTransform(transform::Sigmoid, &result);
//     return result;
// }
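
// note: this out-of-place helper is the sigmoid_ used by the `ft = sigmoid_(ft + bF)` calls
// in sru_logic above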
								//////////////////////////////////////////////////////////////////////////
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								// CUSTOM_OP_IMPL(sru_bp_logic, 8, 4, true, 0, 0) {
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								//     auto x        = INPUT_VARIABLE(0);                                   // X, input 3d tensor [bS x inSize x time], time - number of time steps, bS - batch size, inSize - number of features
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								//     auto w        = INPUT_VARIABLE(1);                                   // W, 2d tensor of weights [3*inSize x inSize]
 
							 
						 
					
						
							
								
									
										
										
										
											2019-07-08 17:58:48 +10:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								//     auto b        = INPUT_VARIABLE(2);                                   // B, row of biases with twice length [1 x 2*inSize]
 
							 
						 
					
						
							
								
									
										
										
										
											2019-06-06 15:21:15 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
							
								//     auto c0       = INPUT_VARIABLE(3);                                   // C_{0}, 2d tensor of initial state [bS x inSize] at time t=0
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								//     auto c        = INPUT_VARIABLE(4);                                   // C, [bS x inSize x time]
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								//     auto inGradCt = INPUT_VARIABLE(5);                                   // [bS x inSize]
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								//     auto inGradH  = INPUT_VARIABLE(6);                                   // [bS x inSize x time]
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								//     auto mask     = block.width() > 7 ? INPUT_VARIABLE(7) : nullptr;     // optional,  2d tensor of dropout mask [bS x inSize]
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								//     auto gradX    = OUTPUT_VARIABLE(0);              // [bS x inSize x time]
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								//     auto gradW    = OUTPUT_VARIABLE(1);              // [bS x 3*inSize x inSize]
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								//     auto gradB    = OUTPUT_VARIABLE(2);              // [2*inSize]
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								//     auto gradInit = OUTPUT_VARIABLE(3);              // [bS x inSize]
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								//     // input shapes validation
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								//     const int rank = 3;
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								//     REQUIRE_TRUE(x->rankOf()  == rank,   0, "SRU_BP operation: wrong rank of input array, expected is %i, but got %i instead !", rank, x->rankOf());
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								//     REQUIRE_TRUE(w->rankOf()  == rank-1, 0, "SRU_BP operation: wrong rank of weights array, expected is %i, but got %i instead !", rank-1, w->rankOf());
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								//     REQUIRE_TRUE(b->rankOf()  <= 2,      0, "SRU_BP operation: wrong rank of biases  array, expected is <=2, but got %i instead !", b->rankOf());
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								//     REQUIRE_TRUE(c0->rankOf() == rank-1, 0, "SRU_BP operation: wrong rank of initial state array, expected is %i, but got %i instead !", rank-1, c0->rankOf());
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								//     REQUIRE_TRUE(c->rankOf()  == rank,   0, "SRU_BP operation: wrong rank of cell states array, expected is %i, but got %i instead !", rank, c->rankOf());
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								//     REQUIRE_TRUE(inGradCt->rankOf() == rank-1, 0, "SRU_BP operation: wrong rank of array of cell state gradient, expected is %i, but got %i instead !", rank-1, inGradCt->rankOf());
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								//     REQUIRE_TRUE(inGradH->rankOf()  == rank,   0, "SRU_BP operation: wrong rank of array of cell outputs gradients, expected is %i, but got %i instead !", rank, inGradH->rankOf());
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								//     if(mask)
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								//         REQUIRE_TRUE(mask->rankOf() == rank-1, 0, "SRU_BP operation: wrong rank of mask array, expected is %i, but got %i instead !", rank-1, mask->rankOf());
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								//     const int bS      = x->shapeOf()[0];
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								//     const int inSize  = x->shapeOf()[1];
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								//     const int time    = x->shapeOf()[2];                     // time - number of time steps
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
//     const std::string wShape               = ShapeUtils::shapeAsString(w);
//     const std::string wCorrectShape        = ShapeUtils::shapeAsString({3*inSize, inSize});
//     // const std::string bShape               = ShapeUtils::shapeAsString(b);
//     // const std::string bCorrectShape        = ShapeUtils::shapeAsString({2*inSize});
//     const std::string c0Shape              = ShapeUtils::shapeAsString(c0);
//     const std::string c0CorrectShape       = ShapeUtils::shapeAsString({bS, inSize});
//     const std::string cShape               = ShapeUtils::shapeAsString(c);
//     const std::string cCorrectShape        = ShapeUtils::shapeAsString({bS, inSize, time});
//     const std::string inGradCtShape        = ShapeUtils::shapeAsString(inGradCt);
//     const std::string inGradCtCorrectShape = ShapeUtils::shapeAsString({bS, inSize});
//     const std::string inGradHShape         = ShapeUtils::shapeAsString(inGradH);
//     const std::string inGradHCorrectShape  = ShapeUtils::shapeAsString({bS, inSize, time});
//     REQUIRE_TRUE(wShape  == wCorrectShape,  0, "SRU_BP operation: wrong shape of weights array, expected is %s, but got %s instead !", wCorrectShape.c_str(), wShape.c_str());
//     // REQUIRE_TRUE(bShape  == bCorrectShape,  0, "SRU_BP operation: wrong shape of biases array, expected is %s, but got %s instead !", bCorrectShape.c_str(), bShape.c_str());
//     REQUIRE_TRUE(c0Shape == c0CorrectShape, 0, "SRU_BP operation: wrong shape of initial state array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), c0Shape.c_str());
//     REQUIRE_TRUE(cShape == cCorrectShape, 0, "SRU_BP operation: wrong shape of cell states array, expected is %s, but got %s instead !", cCorrectShape.c_str(), cShape.c_str());
//     REQUIRE_TRUE(inGradCtShape == inGradCtCorrectShape, 0, "SRU_BP operation: wrong shape of array of cell state gradient, expected is %s, but got %s instead !", inGradCtCorrectShape.c_str(), inGradCtShape.c_str());
//     REQUIRE_TRUE(inGradHShape == inGradHCorrectShape, 0, "SRU_BP operation: wrong shape of array of cell outputs gradients, expected is %s, but got %s instead !", inGradHCorrectShape.c_str(), inGradHShape.c_str());
//     if(mask) {
//         const std::string maskShape = ShapeUtils::shapeAsString(mask);
//         REQUIRE_TRUE(maskShape == c0CorrectShape, 0, "SRU_BP operation: wrong shape of mask array, expected is %s, but got %s instead !", c0CorrectShape.c_str(), maskShape.c_str());
//     }
//     const auto bF = (*b)({0,0,  0,       inSize});                                 // biases for forget gate [1 x inSize]
//     const auto bR = (*b)({0,0,  inSize,2*inSize});                                 // biases for reset  gate [1 x inSize]
//     NDArray gradBias(x->ordering(),   {bS, 2*inSize, time}, x->dataType(), block.launchContext());
//     NDArray gradU   (x->ordering(),   {bS, 3*inSize, time}, x->dataType(), block.launchContext());
//     NDArray gradHX  (x->ordering(),   {bS,   inSize, time}, x->dataType(), block.launchContext());
//     NDArray gct     (c->ordering(),   {bS, inSize},         x->dataType(), block.launchContext());
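//     // Note: gradBias and gradU buffer the per-time-step gradients and are only
//     // reduced/multiplied out after the time loop below; gct holds tanh(c_t)
//     // recomputed for the current step.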
//     //  x = x * mask
//     if(mask)
//         x->applyBroadcast(broadcast::Multiply, {0, 1}, mask, x, nullptr);             // apply mask
//     // multiplication matrix wi = matmul(w,x), U = WX
//     const auto wi = mmul(*w, *x);                                                   //  U [bS x 3K x time]
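//     // The weights of all three projections (z, f, r) are stacked row-wise in w
//     // ([3*inSize x inSize], checked above), so a single matmul produces the
//     // pre-activations of every gate at every time step at once; zt/ft/rt below
//     // are views into this [bS x 3*inSize x time] result.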
//     for (int t = time-1; t >= 0; --t) {
//         // initialization
//         auto xt =         (*x)({0,0, 0,0,                   t,t+1});    // [bS x inSize x time] -> [bS x inSize]
//         auto zt =               wi({0,0, 0,         inSize, t,t+1});    // [bS x 3K x time]     -> [bS x inSize]
//         auto ft =               wi({0,0, inSize,  2*inSize, t,t+1});    // [bS x 3K x time]     -> [bS x inSize]
//         auto rt =               wi({0,0, 2*inSize,3*inSize, t,t+1});    // [bS x 3K x time]     -> [bS x inSize]
//         auto ct =         (*c)({0,0, 0,0,                   t,t+1});    // [bS x inSize x time] -> [bS x inSize]
//         auto inGradHt = (*inGradH)({0,0, 0,0,               t,t+1});    // [bS x inSize x time] -> [bS x inSize]
//         auto ct_1 = t ? (*c)({0,0, 0,0, t-1,t}) : *c0;                  // previous c_{t-1}
//         ///////////////// forward
//         // ft = sigmoid(ft + bF), rt = sigmoid(rt + bR)
//         ft = sigmoid_(ft + bF);
//         rt = sigmoid_(rt + bR);
//         // TODO T val = (activation_type == 1) ? tanh(cur) : ((activation_type == 2) ? reluf(cur) : cur);
//         ct.applyTransform(transform::Tanh, &gct);
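//         // Recomputed forward quantities, following the SRU definition (arXiv:1709.02755):
//         //     f_t = sigmoid(W_f x_t + bF),   r_t = sigmoid(W_r x_t + bR),
//         //     c_t = f_t * c_{t-1} + (1 - f_t) * z_t,   h_t = r_t * tanh(c_t) + (1 - r_t) * x_t
//         // (products elementwise). c_t itself is not recomputed: it was saved during
//         // the forward pass and is read back from c above.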
//         ///////////////// backward
//         // bR, *grad_brt_ptr = inGradHt * (g_ct - xt) * (1.0f - rt) * rt;
//         // ftMinus = -ft + (T)1.;
//         NDArray ftMinus = 1. - ft;
//         NDArray rtMinus = 1. - rt;
//         NDArray gradBRt = inGradHt * (gct - xt) * rtMinus * rt;
//         // bF, TODO - tanh
//         NDArray gradTanh = 1. - gct * gct;
//         NDArray gradCt = inGradHt * rt * gradTanh;
//         NDArray gradBFt = (gradCt + *inGradCt) * (ct_1 - zt) * ftMinus * ft;
//         // x_t (highway connection), gradHXt = inGradHt * (1.0f - rt);
//         NDArray gradHXt = inGradHt * rtMinus;
//         // U_t, gradUZt = (inGradHt * rt * grad_tanh + inGradCt) * (1.0f - ft);
//         NDArray gradUZt = (inGradHt * rt * gradTanh + *inGradCt) * ftMinus;
//         // c_{t-1}, inGradCt = (gradCt + inGradCt) * ft;
//         *inGradCt = (gradCt + *inGradCt) * ft;
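//         // The expressions above are the chain rule applied to the step equations:
//         //     dh_t/dr_t = tanh(c_t) - x_t          -> gradBRt, via dr/dpre = r*(1-r)
//         //     dh_t/dc_t = r_t * (1 - tanh^2(c_t))  -> gradCt, via gradTanh
//         //     dc_t/df_t = c_{t-1} - z_t            -> gradBFt, via df/dpre = f*(1-f)
//         //     dc_t/dz_t = 1 - f_t                  -> gradUZt
//         //     dh_t/dx_t = 1 - r_t                  -> gradHXt (highway term)
//         //     dc_t/dc_{t-1} = f_t                  -> running inGradCt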
//         // save results
//         gradBias({0,0, 0,inSize, t,t+1}, true).assign(gradBFt);
//         gradBias({0,0, inSize,2*inSize, t,t+1}, true).assign(gradBRt);
//         gradU({0,0, 0,inSize, t,t+1}, true).assign(gradUZt);
//         gradU({0,0, inSize,2*inSize, t,t+1}, true).assign(gradBFt);
//         gradU({0,0, 2*inSize,3*inSize, t,t+1}, true).assign(gradBRt);
//         gradHX({0,0, 0,0, t,t+1}, true).assign(gradHXt);
//     }
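//     // Illustrative sketch (hypothetical, standalone -- not part of this op): checks
//     // the gradBRt formula above against a central finite difference on scalars,
//     // which is handy when the gradient formulas are modified.
//     //
//     //   #include <cmath>
//     //   #include <cstdio>
//     //   int main() {
//     //       double x = 0.3, z = 0.5, rPre = 0.7, f = 0.45, cPrev = 0.1;  // one unit, one time step
//     //       auto sig = [](double v) { return 1. / (1. + std::exp(-v)); };
//     //       auto h = [&](double rp) {                                    // h_t as a function of the reset pre-activation
//     //           double r = sig(rp);
//     //           double c = f * cPrev + (1. - f) * z;
//     //           return r * std::tanh(c) + (1. - r) * x;
//     //       };
//     //       double r = sig(rPre), c = f * cPrev + (1. - f) * z;
//     //       double analytic = (std::tanh(c) - x) * r * (1. - r);         // gradBRt with inGradHt = 1
//     //       double eps = 1e-6;
//     //       double numeric = (h(rPre + eps) - h(rPre - eps)) / (2. * eps);
//     //       std::printf("analytic %.9f vs numeric %.9f\n", analytic, numeric);
//     //   }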
//     // gradInit
//     gradInit->assign(inGradCt);
//     // gradX
//     w->transposei();                                                               // [inSize x 3K]
//     gradX->assign( mmul(*w, gradU) + gradHX);
//     if(mask)
//         gradX->applyBroadcast(broadcast::Multiply, {0,1}, mask, gradX, nullptr);       // apply mask
//     // gradB
//     gradBias.reduceAlongDimension(reduce::Sum, *gradB, {0,2}, false, true);    // [1 x 2K]
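//     // The per-step bias gradients accumulated in gradBias [bS x 2*inSize x time]
//     // are summed over the batch (dim 0) and time (dim 2) axes, leaving the
//     // [1 x 2*inSize] bias gradient expected in gradB.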
//     // gradW [bS x 3K x inSize]
//     x->permutei({0, 2, 1});                                               // [bS x time x inSize]
//     gradW->assign( mmul(gradU, *x) );
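//     // With x permuted to [bS x time x inSize], this is a batched matmul of
//     // gradU [bS x 3*inSize x time] against x, i.e. the outer product of the gate
//     // pre-activation gradients with the inputs, giving gradW [bS x 3*inSize x inSize].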
//     return Status::OK();
// }

// DECLARE_TYPES(sru_bp_logic) {
//     getOpDescriptor()
//             ->setAllowedInputTypes(sd::DataType::ANY)
//             ->setAllowedOutputTypes({ALL_FLOATS});
// }
// DECLARE_SHAPE_FN(sru_bp_logic) {
//     auto inShape = inputShape->at(0);   // [bS x inSize x time]
//     auto bS     = inShape[1];
//     auto inSize = inShape[2];
//     auto time   = inShape[3];
//     char order = shape::order(inShape);
//
//     ShapeDescriptor descriptor1(ArrayOptions::dataType(inShape), order, {bS, inSize, time});
//     ShapeDescriptor descriptor2(ArrayOptions::dataType(inShape), order, {bS, 3 * inSize, inSize});
//     ShapeDescriptor descriptor3(ArrayOptions::dataType(inShape), order, {1, 2 * inSize});
//     ShapeDescriptor descriptor4(ArrayOptions::dataType(inShape), order, {bS, inSize});
//
//     return SHAPELIST(ConstantShapeHelper::getInstance().createShapeInfo(descriptor1), ConstantShapeHelper::getInstance().createShapeInfo(descriptor2), ConstantShapeHelper::getInstance().createShapeInfo(descriptor3), ConstantShapeHelper::getInstance().createShapeInfo(descriptor4));
// }
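// The four shape descriptors above correspond, in order, to the op outputs:
// gradX [bS x inSize x time], gradW [bS x 3*inSize x inSize], gradB [1 x 2*inSize]
// and gradInit [bS x inSize], matching the arrays filled in by the backward pass.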