2019-06-06 15:21:15 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								/*******************************************************************************
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 * Copyright (c) 2015-2018 Skymind, Inc.
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 * 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 *  This  program  and  the  accompanying  materials  are  made  available  under  the 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 * terms of the Apache License, Version 2.0 which is available at
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 * https://www.apache.org/licenses/LICENSE-2.0.
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 * 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 * Unless required by applicable law or agreed to in writing, software
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 *  License  for  the  specific  language  governing  permissions  and  limitations 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 * under the License.
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 * 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 * SPDX-License-Identifier: Apache-2.0
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 ******************************************************************************/
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								//
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								// @author Yurii Shyrma, created on 20.03.2018
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								//
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								# include  <ops/declarable/CustomOperations.h> 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								# include  <ops/declarable/helpers/convolutions.h> 
  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2020-03-02 12:49:41 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								namespace  sd  {  
						 
					
						
							
								
									
										
										
										
											2019-06-06 15:21:15 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								namespace  ops   {  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								//////////////////////////////////////////////////////////////////////////
								// Pointwise (1x1) 2D convolution.
								//
								// Inputs:
								//   0 - input   : [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
								//   1 - weights : [1, 1, iC, oC], [oC, iC, 1, 1] or [oC, 1, 1, iC] (selected by wFormat)
								//   2 - bias    : optional, [oC]
								// Output:
								//   0 - output  : [bS, iH, iW, oC] (NHWC) or [bS, oC, iH, iW] (NCHW)
								// Integer arguments (both optional):
								//   0 - data format: 0-NCHW, 1-NHWC (NCHW when absent)
								//   1 - weights format: 0 - [1, 1, iC, oC], 1 - [oC, iC, 1, 1], 2 - [oC, 1, 1, iC] (0 when absent)
								//
								// After validating ranks and shapes the work is delegated to
								// ConvolutionUtils::conv2d with a fixed 1x1 kernel, unit strides, zero
								// paddings and unit dilations.
								CUSTOM_OP_IMPL(pointwise_conv2d, 2, 1, false, 0, 0) {

								    auto input   = INPUT_VARIABLE(0);                                    // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
								    auto weights = INPUT_VARIABLE(1);                                    // [1, 1, iC, oC], [oC, iC, 1, 1], [oC, 1, 1, iC]
								    auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;      // [oC]

								    auto output  = OUTPUT_VARIABLE(0);                                   // [bS, iH, iW, oC] (NHWC) or [bS, oC, iH, iW] (NCHW)

								    REQUIRE_TRUE(input->rankOf() == 4, 0, "CUSTOM POINTWISECONV2D OP: rank of input array must be equal to 4, but got %i instead !", input->rankOf());
								    REQUIRE_TRUE(weights->rankOf() == 4, 0, "CUSTOM POINTWISECONV2D OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf());
								    if (bias) {
								        // braces keep the macro expansion unambiguously inside this branch
								        REQUIRE_TRUE(bias->rankOf() <= 2, 0, "CUSTOM POINTWISECONV2D OP: rank of biases array must be equal <= 2, but got %i instead !", bias->rankOf());
								    }

								    // a pointwise convolution is a regular conv2d with these fixed hyper-parameters
								    const int kH = 1;                                                    // filter(kernel) height
								    const int kW = 1;                                                    // filter(kernel) width
								    const int sH = 1;                                                    // strides height
								    const int sW = 1;                                                    // strides width
								    const int pH = 0;                                                    // paddings height
								    const int pW = 0;                                                    // paddings width
								    const int dH = 1;                                                    // dilations height
								    const int dW = 1;                                                    // dilations width

								    const int isNCHW  = block.getIArguments()->size() > 0 ? !INT_ARG(0) : 1;   // INT_ARG(0): 0-NCHW, 1-NHWC
								    const int wFormat = block.getIArguments()->size() > 1 ? INT_ARG(1) : 0;    // 0 - [1, 1, iC, oC], 1 - [oC, iC, 1, 1], 2 - [oC, 1, 1, iC]

								    int bS, iC, iH, iW, oC, oH, oW;                                      // batch size, input channels, input height/width, output channels, output height/width;
								    int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;                // corresponding indexes
								    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, wFormat, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

								    std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC, oC);
								    REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM POINTWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
								    if (bias) {
								        // The array length is a wide integer while the message is printf-style:
								        // print it with %lld and an explicit cast so the conversion specifier
								        // always matches the argument actually passed through the varargs.
								        REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM POINTWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %lld instead !", oC, bias->rankOf(), static_cast<long long>(bias->lengthOf()));
								    }

								    ConvolutionUtils::conv2d(block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, 1 /*isSameMode*/, isNCHW, wFormat);

								    return Status::OK();
								}
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    // Data type constraints of pointwise_conv2d: inputs are declared with
								    // sd::DataType::ANY, outputs are restricted to the ALL_FLOATS set.
								    DECLARE_TYPES(pointwise_conv2d) {
								        getOpDescriptor()->setAllowedInputTypes(sd::DataType::ANY)
								                         ->setAllowedOutputTypes({ALL_FLOATS});
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								//////////////////////////////////////////////////////////////////////////
								// Shape function of pointwise_conv2d.
								//
								// Validates the ranks of input and weights, the weights shape expected for
								// the selected weights format and (when present) the bias shape. The output
								// shape info is a copy of the input shape info in which the channel
								// dimension is replaced by oC; strides are then recomputed using the order
								// of the input.
								DECLARE_SHAPE_FN(pointwise_conv2d) {

								    Nd4jLong* inputShapeInfo   = inputShape->at(0);                                  // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
								    Nd4jLong* weightsShapeInfo = inputShape->at(1);                                  // [1, 1, iC, oC], [oC, iC, 1, 1], [oC, 1, 1, iC]
								    Nd4jLong* biasShapeInfo    = block.width() > 2 ? inputShape->at(2) : nullptr;    // [oC]

								    const int rank = 4;
								    // Shape-info entries are Nd4jLong while the messages use %i: cast the
								    // ranks to int so the printf-style conversion matches its argument.
								    REQUIRE_TRUE(inputShapeInfo[0] == rank, 0, "CUSTOM POINTWISECONV2D OP: rank of input array must be equal to %i, but got %i instead !", rank, static_cast<int>(inputShapeInfo[0]));
								    REQUIRE_TRUE(weightsShapeInfo[0] == rank, 0, "CUSTOM POINTWISECONV2D OP: rank of weights array must be equal to %i, but got %i instead !", rank, static_cast<int>(weightsShapeInfo[0]));

								    const int isNCHW  = block.getIArguments()->size() > 0 ? !INT_ARG(0) : 1;         // INT_ARG(0): 0-NCHW, 1-NHWC
								    const int wFormat = block.getIArguments()->size() > 1 ? INT_ARG(1) : 0;          // 0 - [1, 1, iC, oC], 1 - [oC, iC, 1, 1], 2 - [oC, 1, 1, iC]

								    // positions of the channel dimensions inside the input/output and weights shapes
								    const int indIOioC = isNCHW ? 1 : 3;
								    const int indWoC   = (0 == wFormat) ? 3 : 0;

								    // the "+ 1" skips the leading rank entry of a shape-info buffer
								    const int iC = inputShapeInfo[indIOioC + 1];                                     // input channels
								    const int oC = weightsShapeInfo[indWoC + 1];                                     // output channels

								    std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, 1, 1, iC, oC);
								    REQUIRE_TRUE(ShapeUtils::areShapesEqual(weightsShapeInfo, expectedWeightsShape), 0, "POINTWISECONV2D OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weightsShapeInfo).c_str());
								    if (biasShapeInfo) {
								        // rank is printed as int, the (wide) length as long long
								        REQUIRE_TRUE(biasShapeInfo[0] <= 2 && oC == shape::length(biasShapeInfo), 0, "POINTWISECONV2D OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %lld instead !", oC, static_cast<int>(biasShapeInfo[0]), static_cast<long long>(shape::length(biasShapeInfo)));
								    }

								    auto outputShapeInfo = ShapeBuilders::copyShapeInfoAndType(inputShapeInfo, weightsShapeInfo, true, block.getWorkspace());

								    // do not forget to put oC instead of iC in outputShapeInfo
								    outputShapeInfo[indIOioC + 1] = oC;

								    shape::updateStrides(outputShapeInfo, shape::order(inputShapeInfo));

								    return SHAPELIST(CONSTANT(outputShapeInfo));
								}
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								}  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								}