2019-06-06 15:21:15 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
							
								/*******************************************************************************
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
 * Copyright (c) 2015-2018 Skymind, Inc.
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								 * 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								 *  This  program  and  the  accompanying  materials  are  made  available  under  the 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								 *  terms  of  the  Apache  License ,  Version  2.0  which  is  available  at 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
 * https://www.apache.org/licenses/LICENSE-2.0.
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								 * 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								 *  Unless  required  by  applicable  law  or  agreed  to  in  writing ,  software 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								 *  WARRANTIES  OR  CONDITIONS  OF  ANY  KIND ,  either  express  or  implied .  See  the 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								 *  License  for  the  specific  language  governing  permissions  and  limitations 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								 *  under  the  License . 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								 * 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
 * SPDX-License-Identifier: Apache-2.0
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								//
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								// Created by raver119 on 29/10/17.
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								//
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2020-03-02 12:49:41 +03:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								# include  <system/op_boilerplate.h> 
 
							 
						 
					
						
							
								
									
										
										
										
											2019-06-06 15:21:15 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
							
								# if NOT_EXCLUDED(OP_fused_batch_norm) 
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								# include  <ops/declarable/CustomOperations.h> 
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2020-03-02 12:49:41 +03:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								namespace  sd  { 
							 
						 
					
						
							
								
									
										
										
										
											2019-06-06 15:21:15 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							 
							
							
								namespace  ops  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								    DECLARE_TYPES(fused_batch_norm) {
        // Inputs may be of any data type; every output is restricted to floating-point types.
        auto descriptor = getOpDescriptor();
        descriptor->setAllowedInputTypes(sd::DataType::ANY);
        descriptor->setAllowedOutputTypes({ALL_FLOATS});
    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2019-12-19 11:15:48 +02:00 
										
									 
								 
							 
							
								
									
										 
									 
								
							 
							
								 
							 
							
							
								//////////////////////////////////////////////////////////////////////////
// Fused batch normalization of a rank-4 input, per channel:
//     y = (x - mean) * rsqrt(variance + epsilon) * scale + offset
//
// Inputs:
//   0: x        - [bS,iH,iW,iD] (NHWC) or [bS,iD,iH,iW] (NCHW)
//   1: scale    - [iD]
//   2: offset   - [iD]
//   3: mean     - [iD], required and used only when isTraining == false
//   4: variance - [iD], required and used only when isTraining == false
// Outputs:
//   0: y         - same shape as x
//   1: batchMean - [iD]; per-channel mean of x when training, zero-filled otherwise
//   2: batchVar  - [iD]; per-channel variance of x scaled by n/(n-1) when training
//                  (n = number of elements per channel), zero-filled otherwise
// Integer args:
//   0: data format, 0 -> NHWC, 1 -> NCHW
//   1: isTraining; when non-zero, mean/variance are computed from x itself
// Float args (optional):
//   0: epsilon, clamped to at least 1.001e-5; 0.001 is used when absent
CUSTOM_OP_IMPL(fused_batch_norm, 3, 3, false, 0, 2) {
    auto x      = INPUT_VARIABLE(0);                 // [bS,iH,iW,iD] (NHWC) or [bS,iD,iH,iW] (NCHW)
    auto scale  = INPUT_VARIABLE(1);                 // [iD]
    auto offset = INPUT_VARIABLE(2);                 // [iD]

    auto y         = OUTPUT_VARIABLE(0);             // same shape as x
    auto batchMean = OUTPUT_VARIABLE(1);             // [iD]
    auto batchVar  = OUTPUT_VARIABLE(2);             // [iD]

    const bool dataFormat = (bool)INT_ARG(0);        // 0->NHWC, 1->NCHW
    const bool isTraining = (bool)INT_ARG(1);

    REQUIRE_TRUE(x->rankOf() == 4, 0, "CUSTOM_OP fused_batch_norm: the rank of input x array must be equal to 4, but got %i instead !", x->rankOf());

    // number of channels: dimension 1 for NCHW, dimension 3 for NHWC
    const int iD = dataFormat ? x->sizeAt(1) : x->sizeAt(3);
    // guards the integer division x->lengthOf() / iD below
    REQUIRE_TRUE(iD > 0, 0, "CUSTOM_OP fused_batch_norm: number of input channels must be positive, but got %i instead !", iD);

    REQUIRE_TRUE(scale->rankOf() == 1 && scale->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input scale array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(scale).c_str());
    REQUIRE_TRUE(offset->rankOf() == 1 && offset->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input offset array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(offset).c_str());

    NDArray *mean(nullptr), *variance(nullptr);
    if (!isTraining) {
        // the op declares only 3 mandatory inputs, but inference needs mean and variance as well
        REQUIRE_TRUE(block.width() >= 5, 0, "CUSTOM_OP fused_batch_norm: when isTraining=false then number of input arrays must be at least 5, but got %i instead !", (int)block.width());
        mean     = INPUT_VARIABLE(3);
        variance = INPUT_VARIABLE(4);
        REQUIRE_TRUE(mean->rankOf() == 1 && mean->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input mean array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(mean).c_str());
        REQUIRE_TRUE(variance->rankOf() == 1 && variance->sizeAt(0) == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input variance array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(variance).c_str());
    }
    else {
        // temporary per-channel statistics, released at the end of this function
        std::vector<Nd4jLong> shape = {iD};
        mean     = NDArrayFactory::create_(scale->ordering(), shape, scale->dataType(), block.launchContext());
        variance = NDArrayFactory::create_(scale->ordering(), shape, scale->dataType(), block.launchContext());
    }

    double epsilon = 0.001;
    if (block.getTArguments()->size() > 0)
        epsilon = T_ARG(0) > 1.001e-5 ? T_ARG(0) : 1.001e-5;

    // Work on a [restSize, iD] matrix where column j holds every element of channel j.
    // Copying element-by-element into that matrix only groups channels into columns when the
    // channel axis is innermost, so NCHW input is viewed as NHWC (permute {0,2,3,1}) first.
    const Nd4jLong restSize = x->lengthOf() / iD;
    auto xAffected = NDArrayFactory::create(x->ordering(), {restSize, iD}, mean->dataType(), block.launchContext());
    if (dataFormat)
        xAffected.assign(x->permute({0, 2, 3, 1}));
    else
        xAffected.assign(x);

    const Nd4jLong restSizeMinusOne = (restSize > 1) ? (restSize - 1) : 1;
    const double restSizeInv    = 1.0 / restSize;
    const double restSizeAdjust = (double)restSize / restSizeMinusOne;   // biased -> unbiased variance factor

    if (isTraining) {
        auto sum = xAffected.reduceAlongDimension(reduce::Sum, {0});
        sum *= restSizeInv;
        mean->assign(sum);
        // assign() writes into the output's existing buffer instead of rebinding it
        batchMean->assign(*mean);
    }
    else
        batchMean->assign(0.);

    xAffected -= *mean;   // center each channel

    if (isTraining) {
        // Square into a separate array: xAffected must keep the centered values,
        // which are normalized below.
        NDArray sqDeviations(xAffected);
        int power = 2;
        sqDeviations.applyScalar(scalar::Pow, power, sqDeviations);
        auto sum = sqDeviations.reduceAlongDimension(reduce::Sum, {0});
        sum *= restSizeInv;
        variance->assign(sum);                            // biased variance, used for normalization
        batchVar->assign((*variance) * restSizeAdjust);   // unbiased variance, reported as output
    }
    else
        batchVar->assign(0.);

    // (x - mean) * rsqrt(variance + epsilon) * scale, then + offset as a separate step:
    // folding offset into the multiplier would scale the centered data by (scale/sigma + offset).
    xAffected *= (*variance + epsilon).transform(transform::RSqrt) * (*scale);
    xAffected += *offset;

    if (dataFormat) {
        // write back through an NHWC-ordered view of y so channels land on axis 1 of the NCHW output
        auto yNHWC = y->permute({0, 2, 3, 1});
        yNHWC.assign(xAffected);
    }
    else
        y->assign(xAffected);

    if (isTraining) {
        delete mean;
        delete variance;
    }

    return Status::OK();
}
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								//////////////////////////////////////////////////////////////////////////
// Output shapes: y copies the shape info of x; batchMean and batchVar copy the shape info of scale.
DECLARE_SHAPE_FN(fused_batch_norm) {
    auto xShapeInfo     = inputShape->at(0);
    auto scaleShapeInfo = inputShape->at(1);

    const bool dataFormat = (bool)INT_ARG(0);               // 0->NHWC, 1->NCHW

    // shapeInfo[0] holds the rank. Validate it before reading xShapeInfo[2] / xShapeInfo[4] as the
    // channel dimension: for other ranks those entries are not the channel size (and may lie past
    // the dimension entries). The rank check in the op body runs only after shape inference.
    REQUIRE_TRUE(xShapeInfo[0] == 4, 0, "CUSTOM_OP fused_batch_norm: the rank of input x array must be equal to 4, but got %i instead !", (int)xShapeInfo[0]);

    const int iD = dataFormat ? xShapeInfo[2] : xShapeInfo[4];

    REQUIRE_TRUE(scaleShapeInfo[0] == 1 && scaleShapeInfo[1] == iD, 0, "CUSTOM_OP fused_batch_norm: wrong shape of input scale array, expected is [%i], but got %s instead", iD, ShapeUtils::shapeAsString(scaleShapeInfo).c_str());

    Nd4jLong* outShapeInfo(nullptr), *batchMeanShapeInfo(nullptr), *batchVarShapeInfo(nullptr);

    COPY_SHAPE(xShapeInfo, outShapeInfo);
    COPY_SHAPE(scaleShapeInfo, batchMeanShapeInfo);
    COPY_SHAPE(scaleShapeInfo, batchVarShapeInfo);

    return SHAPELIST(CONSTANT(outShapeInfo), CONSTANT(batchMeanShapeInfo), CONSTANT(batchVarShapeInfo));
}
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								} 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								} 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							 
							
							
								# endif