2019-06-06 14:21:15 +02:00
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
//
// @author Yurii Shyrma, created on 05.12.2017
//
2020-03-02 10:49:41 +01:00
# include <system/op_boilerplate.h>
2019-06-06 14:21:15 +02:00
# if NOT_EXCLUDED(OP_sruCell)
# include <ops/declarable/CustomOperations.h>
# include <ops/declarable/helpers/sru.h>
2020-03-02 10:49:41 +01:00
namespace sd {
2019-06-06 14:21:15 +02:00
namespace ops {
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL ( sruCell , 4 , 2 , false , 0 , 0 ) {
2020-03-03 05:32:37 +01:00
auto xt = INPUT_VARIABLE ( 0 ) ; // input [bS x inSize], bS - batch size, inSize - number of features
2019-06-06 14:21:15 +02:00
auto ct_1 = INPUT_VARIABLE ( 1 ) ; // previous cell state ct [bS x inSize], that is at previous time step t-1
auto w = INPUT_VARIABLE ( 2 ) ; // weights [inSize x 3*inSize]
2019-09-05 04:25:03 +02:00
auto b = INPUT_VARIABLE ( 3 ) ; // biases [2*inSize]
2019-06-06 14:21:15 +02:00
auto ht = OUTPUT_VARIABLE ( 0 ) ; // current cell output [bS x inSize], that is at current time step t
auto ct = OUTPUT_VARIABLE ( 1 ) ; // current cell state [bS x inSize], that is at current time step t
const int rank = xt - > rankOf ( ) ;
2020-03-03 05:32:37 +01:00
const int bS = xt - > sizeAt ( 0 ) ;
2019-06-06 14:21:15 +02:00
const int inSize = xt - > sizeAt ( 1 ) ; // inSize - number of features
// input shapes validation
2020-03-03 05:32:37 +01:00
const std : : vector < Nd4jLong > correctCt_1Shape = { bS , inSize } ;
const std : : vector < Nd4jLong > correctWShape = { inSize , 3 * inSize } ;
const std : : vector < Nd4jLong > correctBShape = { 2 * inSize } ;
2019-06-06 14:21:15 +02:00
2020-03-03 05:32:37 +01:00
REQUIRE_TRUE ( ct_1 - > isSameShape ( correctCt_1Shape ) , 0 , " SRUCELL operation: wrong shape of previous cell state, expected is %s, but got %s instead ! " , ShapeUtils : : shapeAsString ( correctCt_1Shape ) . c_str ( ) , ShapeUtils : : shapeAsString ( ct_1 ) . c_str ( ) ) ;
REQUIRE_TRUE ( w - > isSameShape ( correctWShape ) , 0 , " SRUCELL operation: wrong shape of weights, expected is %s, but got %s instead ! " , ShapeUtils : : shapeAsString ( correctWShape ) . c_str ( ) , ShapeUtils : : shapeAsString ( w ) . c_str ( ) ) ;
REQUIRE_TRUE ( b - > isSameShape ( correctBShape ) , 0 , " SRUCELL operation: wrong shape of biases, expected is %s, but got %s instead ! " , ShapeUtils : : shapeAsString ( correctBShape ) . c_str ( ) , ShapeUtils : : shapeAsString ( b ) . c_str ( ) ) ;
2019-06-06 14:21:15 +02:00
// fixme: shitty initializer lists
helpers : : sruCell ( block . launchContext ( ) , xt , ct_1 , w , b , ht , ct ) ;
2020-03-03 05:32:37 +01:00
2019-06-06 14:21:15 +02:00
return Status : : OK ( ) ;
}
DECLARE_TYPES ( sruCell ) {
getOpDescriptor ( )
2020-03-02 10:49:41 +01:00
- > setAllowedInputTypes ( sd : : DataType : : ANY )
2019-06-06 14:21:15 +02:00
- > setAllowedOutputTypes ( { ALL_FLOATS } ) ;
}
DECLARE_SHAPE_FN ( sruCell ) {
auto xtShapeInfo = inputShape - > at ( 0 ) ; // input [bS x inSize], bS - batch size, inSize - number of features
2020-03-03 05:32:37 +01:00
auto ct_1ShapeInfo = inputShape - > at ( 1 ) ; // previous cell state ct [bS x inSize], that is at previous time step t-1
2019-06-06 14:21:15 +02:00
auto wShapeInfo = inputShape - > at ( 2 ) ; // weights [inSize x 3*inSize]
auto bShapeInfo = inputShape - > at ( 3 ) ; // biases [2*inSize]
const int rank = xtShapeInfo [ 0 ] ;
2020-03-03 05:32:37 +01:00
const int bS = xtShapeInfo [ 1 ] ;
2019-06-06 14:21:15 +02:00
const int inSize = xtShapeInfo [ 2 ] ; // inSize - number of features
// input shapes validation
2020-03-03 05:32:37 +01:00
const std : : vector < Nd4jLong > correctCt_1Shape = { bS , inSize } ;
const std : : vector < Nd4jLong > correctWShape = { inSize , 3 * inSize } ;
const std : : vector < Nd4jLong > correctBShape = { 2 * inSize } ;
REQUIRE_TRUE ( ShapeUtils : : areShapesEqual ( ct_1ShapeInfo , correctCt_1Shape ) , 0 , " SRUCELL operation: wrong shape of previous cell state, expected is %s, but got %s instead ! " , ShapeUtils : : shapeAsString ( correctCt_1Shape ) . c_str ( ) , ShapeUtils : : shapeAsString ( ct_1ShapeInfo ) . c_str ( ) ) ;
REQUIRE_TRUE ( ShapeUtils : : areShapesEqual ( wShapeInfo , correctWShape ) , 0 , " SRUCELL operation: wrong shape of weights, expected is %s, but got %s instead ! " , ShapeUtils : : shapeAsString ( correctWShape ) . c_str ( ) , ShapeUtils : : shapeAsString ( wShapeInfo ) . c_str ( ) ) ;
REQUIRE_TRUE ( ShapeUtils : : areShapesEqual ( bShapeInfo , correctBShape ) , 0 , " SRUCELL operation: wrong shape of biases, expected is %s, but got %s instead ! " , ShapeUtils : : shapeAsString ( correctBShape ) . c_str ( ) , ShapeUtils : : shapeAsString ( bShapeInfo ) . c_str ( ) ) ;
2019-06-06 14:21:15 +02:00
// evaluate output shapeInfos
Nd4jLong * hShapeInfo ( nullptr ) , * cShapeInfo ( nullptr ) ;
ALLOCATE ( hShapeInfo , block . getWorkspace ( ) , shape : : shapeInfoLength ( rank ) , Nd4jLong ) ; // [bS x numProj]
ALLOCATE ( cShapeInfo , block . getWorkspace ( ) , shape : : shapeInfoLength ( rank ) , Nd4jLong ) ; // [bS x numUnits]
2020-03-03 05:32:37 +01:00
2019-06-06 14:21:15 +02:00
hShapeInfo [ 0 ] = cShapeInfo [ 0 ] = rank ;
hShapeInfo [ 1 ] = cShapeInfo [ 1 ] = bS ;
hShapeInfo [ 2 ] = cShapeInfo [ 2 ] = inSize ;
2020-03-03 05:32:37 +01:00
2019-06-06 14:21:15 +02:00
ShapeUtils : : updateStridesAndType ( hShapeInfo , ct_1ShapeInfo , shape : : order ( ct_1ShapeInfo ) ) ;
ShapeUtils : : updateStridesAndType ( cShapeInfo , ct_1ShapeInfo , shape : : order ( ct_1ShapeInfo ) ) ;
2020-03-03 05:32:37 +01:00
2019-06-06 14:21:15 +02:00
return SHAPELIST ( ConstantShapeHelper : : getInstance ( ) - > createFromExisting ( hShapeInfo , block . workspace ( ) ) , ConstantShapeHelper : : getInstance ( ) - > createFromExisting ( cShapeInfo , block . workspace ( ) ) ) ;
2020-03-03 05:32:37 +01:00
}
2019-06-06 14:21:15 +02:00
}
}
# endif