2021-02-01 13:31:45 +01:00
/* ******************************************************************************
*
2019-06-06 14:21:15 +02:00
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
2021-02-01 13:31:45 +01:00
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership.
2019-06-06 14:21:15 +02:00
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
//
// Created by raver119 on 29/10/17.
//
2020-03-02 10:49:41 +01:00
# include <system/op_boilerplate.h>
2019-06-06 14:21:15 +02:00
# if NOT_EXCLUDED(OP_fused_batch_norm)
#include <ops/declarable/CustomOperations.h>

#include <memory>
2020-03-02 10:49:41 +01:00
namespace sd {
2019-06-06 14:21:15 +02:00
namespace ops {
DECLARE_TYPES(fused_batch_norm) {
    // Inputs may be of any data type; every output must be a floating-point type.
    auto descriptor = getOpDescriptor();
    descriptor->setAllowedInputTypes(sd::DataType::ANY);
    descriptor->setAllowedOutputTypes({ALL_FLOATS});
}
2019-12-19 10:15:48 +01:00
/**
 * fused_batch_norm: batch normalization of a rank-4 input over every axis except the channel axis.
 *
 * Inputs:
 *   0: x        - [bS,iH,iW,iD] (NHWC) or [bS,iD,iH,iW] (NCHW)
 *   1: scale    - [iD]
 *   2: offset   - [iD]
 *   3: mean     - [iD], read only when isTraining == 0
 *   4: variance - [iD], read only when isTraining == 0
 * Integer args:
 *   0: data format, 0 -> NHWC, 1 -> NCHW
 *   1: isTraining; when non-zero, mean/variance are computed from x instead of read from inputs 3/4
 * Float args (optional):
 *   0: epsilon, clamped from below to 1.001e-5; 0.001 is used when it is not supplied
 * Outputs:
 *   0: y = (x - mean) * rsqrt(variance + epsilon) * scale + offset, same shape/layout as x
 *   1: batchMean - per-channel mean of x in training mode, zeros otherwise
 *   2: batchVar  - per-channel variance of x with the n/(n-1) correction in training mode, zeros otherwise
 */
CUSTOM_OP_IMPL(fused_batch_norm, 3, 3, false, 0, 2) {
    auto x = INPUT_VARIABLE(0);                 // [bS,iH,iW,iD] (NHWC) or [bS,iD,iH,iW] (NCHW)
    auto scale = INPUT_VARIABLE(1);             // [iD]
    auto offset = INPUT_VARIABLE(2);            // [iD]

    auto y = OUTPUT_VARIABLE(0);                // [bS,iH,iW,iD] (NHWC) or [bS,iD,iH,iW] (NCHW)
    auto batchMean = OUTPUT_VARIABLE(1);        // [iD]
    auto batchVar = OUTPUT_VARIABLE(2);         // [iD]

    const bool dataFormat = (bool)INT_ARG(0);   // 0->NHWC, 1->NCHW
    const bool isTraining = (bool)INT_ARG(1);

    REQUIRE_TRUE(x->rankOf() == 4, 0, "CUSTOM_OP fused_batch_norm: the rank of input x array must be equal to 4, but got %i instead !", x->rankOf());

    const int bS = x->sizeAt(0);                // batch size
    int iH, iW, iD;                             // input height, input width, input depth (number of channels)
    if (dataFormat) {
        iD = x->sizeAt(1);
        iH = x->sizeAt(2);
        iW = x->sizeAt(3);
    }
    else {
        iD = x->sizeAt(3);
        iH = x->sizeAt(1);
        iW = x->sizeAt(2);
    }

    REQUIRE_TRUE(scale->rankOf() == 1 && scale->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input scale array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(scale).c_str());
    REQUIRE_TRUE(offset->rankOf() == 1 && offset->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input offset array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(offset).c_str());

    NDArray* mean = nullptr;
    NDArray* variance = nullptr;
    // Own the statistics buffers allocated in training mode so they are released on every exit path
    // (including exceptions thrown by the array operations below).
    std::unique_ptr<NDArray> meanOwner;
    std::unique_ptr<NDArray> varianceOwner;

    if (!isTraining) {
        REQUIRE_TRUE(block.width() >= 5, 0, "CUSTOM_OP fused_batch_norm: when isTraining=false then mean and variance input arrays must be provided, but got %i input arrays instead !", (int)block.width());
        mean = INPUT_VARIABLE(3);
        variance = INPUT_VARIABLE(4);
        REQUIRE_TRUE(mean->rankOf() == 1 && mean->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input mean array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(mean).c_str());
        REQUIRE_TRUE(variance->rankOf() == 1 && variance->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input variance array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(variance).c_str());
    }
    else {
        std::vector<Nd4jLong> shape = {iD};
        meanOwner.reset(NDArrayFactory::create_(scale->ordering(), shape, scale->dataType(), block.launchContext()));
        varianceOwner.reset(NDArrayFactory::create_(scale->ordering(), shape, scale->dataType(), block.launchContext()));
        mean = meanOwner.get();
        variance = varianceOwner.get();
    }

    // Epsilon from the first float argument, clamped from below to 1.001e-5; 0.001 when absent.
    double epsilon = 0.001;
    if (block.getTArguments()->size() > 0)
        epsilon = T_ARG(0) > 1.001e-5 ? T_ARG(0) : 1.001e-5;

    // Flatten x into a [restSize, iD] matrix whose columns are channels. For NCHW the channel axis
    // must first be moved last; a plain linear copy would interleave different channels.
    const int restSize = x->lengthOf() / iD;
    auto xAffected = NDArrayFactory::create(x->ordering(), {restSize, iD}, mean->dataType(), block.launchContext());
    if (dataFormat)
        xAffected.assign(x->permute({0, 2, 3, 1}));
    else
        xAffected.assign(x);

    const int restSizeMinusOne = (restSize > 1) ? (restSize - 1) : 1;
    const double restSizeInv = 1.0 / restSize;
    const double restSizeAdjust = (double)restSize / restSizeMinusOne;

    if (isTraining) {
        auto sum = xAffected.reduceAlongDimension(reduce::Sum, {0});
        sum *= restSizeInv;
        mean->assign(sum);
        batchMean->assign(mean);
    }
    else
        batchMean->assign(0.);

    // Center the data in place; the centered values are reused for the normalized output below.
    xAffected -= *mean;

    if (isTraining) {
        // Square into a separate array so xAffected keeps the centered (not squared) values.
        auto squared = xAffected * xAffected;
        auto sum = squared.reduceAlongDimension(reduce::Sum, {0});
        sum *= restSizeInv;
        variance->assign(sum);                               // biased variance, used for normalization
        batchVar->assign((*variance) * restSizeAdjust);      // reported with the n/(n-1) correction
    }
    else
        batchVar->assign(0.);

    // y = (x - mean) * rsqrt(variance + epsilon) * scale + offset  (offset is added, not multiplied)
    auto normFactor = (*variance + epsilon).transform(transform::RSqrt) * (*scale);
    xAffected *= normFactor;
    xAffected += *offset;

    if (dataFormat) {
        // Undo the NCHW -> NHWC move made above before writing into y.
        auto yNHWC = xAffected.reshape(xAffected.ordering(), {bS, iH, iW, iD});
        yNHWC.permutei({0, 3, 1, 2});
        y->assign(yNHWC);
    }
    else
        y->assign(xAffected);

    return Status::OK();
}
/**
 * Output shapes for fused_batch_norm:
 *   y         - same shape as x
 *   batchMean - same shape as scale ([iD])
 *   batchVar  - same shape as scale ([iD])
 * Validates x's rank and scale's shape before computing them, because this function runs before the
 * op implementation and indexes x's shape info directly.
 */
DECLARE_SHAPE_FN(fused_batch_norm) {
    auto xShapeInfo = inputShape->at(0);
    auto scaleShapeInfo = inputShape->at(1);

    // Shape info element [0] holds the rank and [1..rank] hold the dimensions, so the rank must be
    // checked before reading [2]/[4] as the channel dimension.
    REQUIRE_TRUE(xShapeInfo[0] == 4, 0, "CUSTOM_OP fused_batch_norm: the rank of input x array must be equal to 4, but got %i instead !", (int)xShapeInfo[0]);

    const bool dataFormat = (bool)INT_ARG(0);   // 0->NHWC, 1->NCHW
    const int iD = dataFormat ? xShapeInfo[2] : xShapeInfo[4];
    REQUIRE_TRUE(scaleShapeInfo[0] == 1 && scaleShapeInfo[1] == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input scale array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(scaleShapeInfo).c_str());

    Nd4jLong *outShapeInfo(nullptr), *batchMeanShapeInfo(nullptr), *batchVarShapeInfo(nullptr);

    COPY_SHAPE(xShapeInfo, outShapeInfo);
    COPY_SHAPE(scaleShapeInfo, batchMeanShapeInfo);
    COPY_SHAPE(scaleShapeInfo, batchVarShapeInfo);

    return SHAPELIST(CONSTANT(outShapeInfo), CONSTANT(batchMeanShapeInfo), CONSTANT(batchVarShapeInfo));
}
}
}
# endif