2021-02-01 13:31:45 +01:00
/* ******************************************************************************
*
2019-06-06 14:21:15 +02:00
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
2021-02-01 13:31:45 +01:00
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership .
2019-06-06 14:21:15 +02:00
* Unless required by applicable law or agreed to in writing , software
* distributed under the License is distributed on an " AS IS " BASIS , WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND , either express or implied . See the
* License for the specific language governing permissions and limitations
* under the License .
*
* SPDX-License-Identifier: Apache-2.0
* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
//
// @author Yurii Shyrma, created on 02.04.2018
//
# include <ops/declarable/CustomOperations.h>
# include <ops/declarable/helpers/rnn.h>
2020-03-02 10:49:41 +01:00
namespace sd {
2019-06-06 14:21:15 +02:00
namespace ops {
//////////////////////////////////////////////////////////////////////////
// static_rnn: runs helpers::rnnTimeLoop over all time steps of x.
// Inputs:  x [time x bS x inSize], Wx [inSize x numUnits], Wh [numUnits x numUnits], b [2*numUnits],
//          optional h0 [bS x numUnits] and/or optional maxTimeStep [bS].
// Outputs: h [time x bS x numUnits], hFinal [bS x numUnits].
// All input shapes are validated here before the helper is invoked.
CUSTOM_OP_IMPL(static_rnn, 4, 2, false, 0, 0) {
    auto x  = INPUT_VARIABLE(0);  // input [time x bS x inSize]
    auto Wx = INPUT_VARIABLE(1);  // input-to-hidden weights, [inSize x numUnits]
    auto Wh = INPUT_VARIABLE(2);  // hidden-to-hidden weights, [numUnits x numUnits]
    auto b  = INPUT_VARIABLE(3);  // biases for, [2*numUnits]

    NDArray* h0 = nullptr;           // initial cell output (at time step = 0) [bS x numUnits]
    NDArray* maxTimeStep = nullptr;  // vector [bS] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this means there are no calculations for time >= maxTimeStep

    // With exactly 5 inputs the 5th one is ambiguous: a rank-2 array is treated as h0,
    // anything else as maxTimeStep. With 6 inputs both are supplied in fixed order.
    if (block.width() == 5) {
        if (INPUT_VARIABLE(4)->rankOf() == 2)
            h0 = INPUT_VARIABLE(4);
        else
            maxTimeStep = INPUT_VARIABLE(4);
    } else if (block.width() == 6) {
        h0 = INPUT_VARIABLE(4);
        maxTimeStep = INPUT_VARIABLE(5);
    }

    auto h      = OUTPUT_VARIABLE(0);  // cell outputs [time x bS x numUnits]
    auto hFinal = OUTPUT_VARIABLE(1);  // at the end it will store cell final non-zero output [bS x numUnits]

    REQUIRE_TRUE(x->rankOf() == 3, 0, "STATIC_RNN custom operation: input array x must have rank = 3, but got %i instead !", x->rankOf());
    REQUIRE_TRUE(Wx->rankOf() == 2, 0, "STATIC_RNN custom operation: input-to-hidden weights array must have rank = 2, but got %i instead !", Wx->rankOf());

    // keep dimensions in Nd4jLong (the type sizeAt returns) to avoid narrowing
    const Nd4jLong bS       = x->sizeAt(1);
    const Nd4jLong inSize   = x->sizeAt(2);
    const Nd4jLong numUnits = Wx->sizeAt(1);

    // Wx's first dimension must match the feature size of x, otherwise x * Wx is ill-formed
    const std::vector<Nd4jLong> expectedWxShape = {inSize, numUnits};
    REQUIRE_TRUE(Wx->isSameShape(expectedWxShape), 0, "STATIC_RNN custom operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWxShape).c_str(), ShapeUtils::shapeAsString(Wx).c_str());

    const std::vector<Nd4jLong> expectedWhShape = {numUnits, numUnits};
    const std::vector<Nd4jLong> expectedbShape  = {2 * numUnits};
    REQUIRE_TRUE(Wh->isSameShape(expectedWhShape), 0, "STATIC_RNN custom operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWhShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str());
    REQUIRE_TRUE(b->isSameShape(expectedbShape), 0, "STATIC_RNN custom operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedbShape).c_str(), ShapeUtils::shapeAsString(b).c_str());

    if (h0) {
        const std::vector<Nd4jLong> expectedh0Shape = {bS, numUnits};
        REQUIRE_TRUE(h0->isSameShape(expectedh0Shape), 0, "STATIC_RNN custom operation: wrong shape of initial cell output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0Shape).c_str(), ShapeUtils::shapeAsString(h0).c_str());
    }

    if (maxTimeStep) {
        const std::vector<Nd4jLong> expectedMaxTimeStepShape = {bS};
        REQUIRE_TRUE(maxTimeStep->isSameShape(expectedMaxTimeStepShape), 0, "STATIC_RNN custom operation: wrong shape of maxTimeStep array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedMaxTimeStepShape).c_str(), ShapeUtils::shapeAsString(maxTimeStep).c_str());
    }

    helpers::rnnTimeLoop(block.launchContext(), x, Wx, Wh, b, h0, maxTimeStep, h, hFinal);

    return Status::OK();
}
2019-07-20 07:58:44 +02:00
// Data-type contract of static_rnn: inputs may be of any data type,
// outputs are restricted to the floating-point types.
DECLARE_TYPES(static_rnn) {
    auto descriptor = getOpDescriptor();
    descriptor->setAllowedInputTypes(sd::DataType::ANY);
    descriptor->setAllowedOutputTypes({ALL_FLOATS});
}
2019-06-06 14:21:15 +02:00
2019-07-20 07:58:44 +02:00
// Shape function of static_rnn: validates input shapes and produces two output shapeInfos,
// h [time x bS x numUnits] and hFinal [bS x numUnits], both taking strides, data type and
// order from x (via ShapeUtils::updateStridesAndType).
DECLARE_SHAPE_FN(static_rnn) {
    auto xShapeInfo  = inputShape->at(0);  // input [time x bS x inSize]
    auto WxShapeInfo = inputShape->at(1);  // input-to-hidden weights, [inSize x numUnits]
    auto WhShapeInfo = inputShape->at(2);  // hidden-to-hidden weights, [numUnits x numUnits]
    auto bShapeInfo  = inputShape->at(3);  // biases for, [2*numUnits]

    const Nd4jLong* h0ShapeInfo = nullptr;           // initial cell output (at time step = 0) [bS x numUnits]
    const Nd4jLong* maxTimeStepShapeInfo = nullptr;  // vector [bS] containing integer values within [0,time), each element of this vector set max time step per each input in batch, this means there are no calculations for time >= maxTimeStep

    // With exactly 5 inputs the 5th one is ambiguous: rank 2 (shapeInfo[0] == 2) means h0,
    // anything else means maxTimeStep. With 6 inputs both are supplied in fixed order.
    if (block.width() == 5) {
        if (inputShape->at(4)[0] == 2)
            h0ShapeInfo = inputShape->at(4);
        else
            maxTimeStepShapeInfo = inputShape->at(4);
    } else if (block.width() == 6) {
        h0ShapeInfo = inputShape->at(4);
        maxTimeStepShapeInfo = inputShape->at(5);
    }

    // shapeInfo[0] is the rank stored as Nd4jLong; cast it so it matches the %i specifier
    REQUIRE_TRUE(xShapeInfo[0] == 3, 0, "STATIC_RNN custom operation: input array x must have rank = 3, but got %i instead !", static_cast<int>(xShapeInfo[0]));
    REQUIRE_TRUE(WxShapeInfo[0] == 2, 0, "STATIC_RNN custom operation: input-to-hidden weights array must have rank = 2, but got %i instead !", static_cast<int>(WxShapeInfo[0]));

    // shapeInfo layout is [rank, dim0, dim1, ...]; keep dimensions in Nd4jLong to avoid narrowing
    const int      inRank   = static_cast<int>(xShapeInfo[0]);
    const Nd4jLong time     = xShapeInfo[1];
    const Nd4jLong bS       = xShapeInfo[2];
    const Nd4jLong inSize   = xShapeInfo[3];
    const Nd4jLong numUnits = WxShapeInfo[2];

    // Wx's first dimension must match the feature size of x
    const std::vector<Nd4jLong> expectedWxShape = {inSize, numUnits};
    REQUIRE_TRUE(ShapeUtils::areShapesEqual(WxShapeInfo, expectedWxShape), 0, "STATIC_RNN custom operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWxShape).c_str(), ShapeUtils::shapeAsString(WxShapeInfo).c_str());

    const std::vector<Nd4jLong> expectedWhShape = {numUnits, numUnits};
    const std::vector<Nd4jLong> expectedbShape  = {2 * numUnits};
    REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhShapeInfo, expectedWhShape), 0, "STATIC_RNN custom operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWhShape).c_str(), ShapeUtils::shapeAsString(WhShapeInfo).c_str());
    REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, expectedbShape), 0, "STATIC_RNN custom operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedbShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str());

    if (h0ShapeInfo) {
        const std::vector<Nd4jLong> expectedh0Shape = {bS, numUnits};
        REQUIRE_TRUE(ShapeUtils::areShapesEqual(h0ShapeInfo, expectedh0Shape), 0, "STATIC_RNN custom operation: wrong shape of initial cell output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString(expectedh0Shape).c_str(), ShapeUtils::shapeAsString(h0ShapeInfo).c_str());
    }

    if (maxTimeStepShapeInfo) {
        const std::vector<Nd4jLong> expectedMaxTimeStepShape = {bS};
        REQUIRE_TRUE(ShapeUtils::areShapesEqual(maxTimeStepShapeInfo, expectedMaxTimeStepShape), 0, "STATIC_RNN custom operation: wrong shape of maxTimeStep array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedMaxTimeStepShape).c_str(), ShapeUtils::shapeAsString(maxTimeStepShapeInfo).c_str());
    }

    // evaluate output shapeInfos
    Nd4jLong* hShapeInfo(nullptr);
    Nd4jLong* hPrevShapeInfo(nullptr);
    ALLOCATE(hShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inRank), Nd4jLong);
    ALLOCATE(hPrevShapeInfo, block.getWorkspace(), shape::shapeInfoLength(inRank - 1), Nd4jLong);

    hShapeInfo[0]     = inRank;
    hPrevShapeInfo[0] = inRank - 1;
    hShapeInfo[1]     = time;
    hShapeInfo[2]     = hPrevShapeInfo[1] = bS;
    hShapeInfo[3]     = hPrevShapeInfo[2] = numUnits;

    // strides, data type and order are all taken from x
    ShapeUtils::updateStridesAndType(hShapeInfo, xShapeInfo, shape::order(xShapeInfo));
    ShapeUtils::updateStridesAndType(hPrevShapeInfo, xShapeInfo, shape::order(xShapeInfo));

    return SHAPELIST(CONSTANT(hShapeInfo), CONSTANT(hPrevShapeInfo));
}
2019-06-06 14:21:15 +02:00
}
}