2021-02-01 13:31:45 +01:00
/* ******************************************************************************
*
2019-06-06 14:21:15 +02:00
*
* This program and the accompanying materials are made available under the
* terms of the Apache License , Version 2.0 which is available at
* https : //www.apache.org/licenses/LICENSE-2.0.
*
2021-02-01 13:31:45 +01:00
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership .
2019-06-06 14:21:15 +02:00
* Unless required by applicable law or agreed to in writing , software
* distributed under the License is distributed on an " AS IS " BASIS , WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND , either express or implied . See the
* License for the specific language governing permissions and limitations
* under the License .
*
* SPDX - License - Identifier : Apache - 2.0
* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
//
// Based on PyTorch - https://github.com/pytorch/pytorch
//
# ifndef LIBND4J_CONVOLUTIONS_H
# define LIBND4J_CONVOLUTIONS_H
2020-03-02 10:49:41 +01:00
# include <array/NDArray.h>
2019-06-06 14:21:15 +02:00
# include <graph/Context.h>
2020-03-02 10:49:41 +01:00
# include <system/dll.h>
2019-06-06 14:21:15 +02:00
# include <execution/LaunchContext.h>
2020-03-02 10:49:41 +01:00
namespace sd {
2019-06-06 14:21:15 +02:00
namespace ops {
/**
 * Selector for the pooling reduction; pooling2d() in ConvolutionUtils takes it as `poolingMode`.
 * The numeric values are spelled out explicitly, so they can be exchanged as plain ints.
 * Enumerator names suggest max-, average- and p-norm pooling respectively.
 */
enum PoolingType {
    MAX_POOL   = 0,
    AVG_POOL   = 1,
    PNORM_POOL = 2
};
2019-12-02 19:37:21 +01:00
class ND4J_EXPORT ConvolutionUtils {
2019-06-06 14:21:15 +02:00
public :
2019-11-29 12:14:30 +01:00
static inline void calcOutSizePool2D ( int & oH , int & oW , const int kH , const int kW , const int sH , const int sW , const int pH , const int pW , const int dH , const int dW , const int iH , const int iW , const int paddingMode ) {
if ( paddingMode = = 0 ) { // valid
2020-01-20 19:32:46 +01:00
// oH = (iH - (kH + (kH-1)*(dH-1)) + 2*pH)/sH + 1;
// oW = (iW - (kW + (kW-1)*(dW-1)) + 2*pW)/sW + 1;
oH = ( iH - ( ( kH - 1 ) * dH + 1 ) + 2 * pH ) / sH + 1 ;
oW = ( iW - ( ( kW - 1 ) * dW + 1 ) + 2 * pW ) / sW + 1 ;
2019-11-29 12:14:30 +01:00
}
else if ( paddingMode = = 1 ) { // same
2019-06-06 14:21:15 +02:00
oH = ( int ) math : : nd4j_ceil < double , double > ( iH * 1. / sH ) ;
oW = ( int ) math : : nd4j_ceil < double , double > ( iW * 1. / sW ) ;
}
2019-11-29 12:14:30 +01:00
else { // causal
oH = ( iH - 1 ) / sH + 1 ; // 2*pH = (kH-1)*dH
oW = ( iW - 1 ) / sW + 1 ;
2019-06-06 14:21:15 +02:00
}
}
2019-11-29 12:14:30 +01:00
static inline void calcOutSizePool3D ( int & oD , int & oH , int & oW , const int kD , const int kH , const int kW , const int sD , const int sH , const int sW , const int pD , const int pH , const int pW , const int dD , const int dH , const int dW , const int iD , const int iH , const int iW , const int paddingMode ) {
2019-06-06 14:21:15 +02:00
2019-11-29 12:14:30 +01:00
if ( paddingMode = = 0 ) { // valid
2020-01-20 19:32:46 +01:00
oD = ( iD - ( ( kD - 1 ) * dD + 1 ) + 2 * pD ) / sD + 1 ;
oH = ( iH - ( ( kH - 1 ) * dH + 1 ) + 2 * pH ) / sH + 1 ;
oW = ( iW - ( ( kW - 1 ) * dW + 1 ) + 2 * pW ) / sW + 1 ;
2019-06-06 14:21:15 +02:00
}
2019-11-29 12:14:30 +01:00
else if ( paddingMode = = 1 ) { // same
2020-03-02 10:49:41 +01:00
oD = ( int ) sd : : math : : nd4j_ceil < double , double > ( iD * 1. / sD ) ;
oH = ( int ) sd : : math : : nd4j_ceil < double , double > ( iH * 1. / sH ) ;
oW = ( int ) sd : : math : : nd4j_ceil < double , double > ( iW * 1. / sW ) ;
2019-11-29 12:14:30 +01:00
}
else { // causal
oD = ( iD - 1 ) / sD + 1 ;
oH = ( iH - 1 ) / sH + 1 ; // 2*pH = (kH-1)*dH
oW = ( iW - 1 ) / sW + 1 ;
2019-06-06 14:21:15 +02:00
}
}
2019-11-29 12:14:30 +01:00
/**
 * Derives the padding pH/pW for the requested padding mode.
 *
 *   paddingMode == 0 (valid):  pH/pW are left exactly as passed in
 *   paddingMode == 1 (same):   p = ((o - 1) * s + ((k - 1) * d + 1) - i) / 2
 *   any other value (causal):  p = (k - 1) * d
 *
 * In same mode the division by two is an integer division: when the total padding is odd,
 * the bottom/right padding is one larger than the value stored here. The result is not
 * clamped, so it can be negative when the strided window span is smaller than the input.
 *
 * @param pH, pW       [in/out] padding; overwritten unless paddingMode == 0
 * @param paddingMode  defaults to 1 (same)
 */
static inline void calcPadding2D(int& pH, int& pW, int oH, int oW, int iH, int iW, int kH, int kW, int sH, int sW, int dH, int dW, const int paddingMode = 1 /* default is same mode*/) {
    switch (paddingMode) {
        case 0:  // valid: nothing to compute
            break;
        case 1: {  // same
            const int effKernelH = (kH - 1) * dH + 1;
            const int effKernelW = (kW - 1) * dW + 1;
            const int totalPadH = (oH - 1) * sH + effKernelH - iH;
            const int totalPadW = (oW - 1) * sW + effKernelW - iW;
            pH = totalPadH / 2;
            pW = totalPadW / 2;
            break;
        }
        default:  // causal
            pH = (kH - 1) * dH;
            pW = (kW - 1) * dW;
            break;
    }
}
/**
 * 3D counterpart of calcPadding2D: derives the padding pD/pH/pW for the requested mode.
 *
 *   paddingMode == 0 (valid):  pD/pH/pW are left exactly as passed in
 *   paddingMode == 1 (same):   p = ((o - 1) * s + ((k - 1) * d + 1) - i) / 2
 *   any other value (causal):  p = (k - 1) * d
 *
 * In same mode the division by two is an integer division: when the total padding is odd,
 * the far-side padding is one larger than the value stored here. The result is not clamped.
 *
 * @param pD, pH, pW   [in/out] padding; overwritten unless paddingMode == 0
 * @param paddingMode  defaults to 1 (same)
 */
static inline void calcPadding3D(int& pD, int& pH, int& pW, const int oD, const int oH, const int oW, const int iD, const int iH, const int iW, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int dD, const int dH, const int dW, const int paddingMode = 1 /* default is same mode*/) {
    switch (paddingMode) {
        case 0:  // valid: nothing to compute
            break;
        case 1: {  // same
            const int effKernelD = (kD - 1) * dD + 1;
            const int effKernelH = (kH - 1) * dH + 1;
            const int effKernelW = (kW - 1) * dW + 1;
            const int totalPadD = (oD - 1) * sD + effKernelD - iD;
            const int totalPadH = (oH - 1) * sH + effKernelH - iH;
            const int totalPadW = (oW - 1) * sW + effKernelW - iW;
            pD = totalPadD / 2;
            pH = totalPadH / 2;
            pW = totalPadW / 2;
            break;
        }
        default:  // causal
            pD = (kD - 1) * dD;
            pH = (kH - 1) * dH;
            pW = (kW - 1) * dW;
            break;
    }
}
/**
 * Computes the output height/width of a 2D deconvolution.
 *
 *   paddingMode != 0:  o = s * i                 (every non-zero mode takes this branch)
 *   paddingMode == 0:  o = s * (i - 1) + ((k - 1) * d + 1) - 2 * p
 *
 * k*, p* and d* are only read when paddingMode == 0.
 *
 * @param oH, oW  [out] receive the computed output height and width
 */
static inline void calcOutSizeDeconv2D(int& oH, int& oW, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, const int iH, const int iW, const int paddingMode) {
    if (paddingMode != 0) {
        oH = sH * iH;
        oW = sW * iW;
        return;
    }

    // effective (dilated) kernel extent
    const int effKernelH = (kH - 1) * dH + 1;
    const int effKernelW = (kW - 1) * dW + 1;

    oH = sH * (iH - 1) + effKernelH - 2 * pH;
    oW = sW * (iW - 1) + effKernelW - 2 * pW;
}
/**
 * Computes the output depth/height/width of a 3D deconvolution.
 *
 *   paddingMode != 0:  o = s * i                 (every non-zero mode takes this branch)
 *   paddingMode == 0:  o = s * (i - 1) + ((k - 1) * d + 1) - 2 * p
 *
 * k*, p* and d* are only read when paddingMode == 0.
 *
 * @param oD, oH, oW  [out] receive the computed output depth, height and width
 */
static inline void calcOutSizeDeconv3D(int& oD, int& oH, int& oW, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, const int iD, const int iH, const int iW, const int paddingMode) {
    if (paddingMode != 0) {
        oD = sD * iD;
        oH = sH * iH;
        oW = sW * iW;
        return;
    }

    // effective (dilated) kernel extent per axis
    const int effKernelD = (kD - 1) * dD + 1;
    const int effKernelH = (kH - 1) * dH + 1;
    const int effKernelW = (kW - 1) * dW + 1;

    oD = sD * (iD - 1) + effKernelD - 2 * pD;
    oH = sH * (iH - 1) + effKernelH - 2 * pH;
    oW = sW * (iW - 1) + effKernelW - 2 * pW;
}
// evaluates sizes values and indexes using input and output arrays depending on data format
2020-03-20 10:11:27 +01:00
static inline void getSizesAndIndexesConv2d ( const bool isNCHW , const int wFormat , const NDArray & input , const NDArray & output , int & bS , int & iC , int & iH , int & iW , int & oC , int & oH , int & oW , int & indIOioC , int & indIiH , int & indWiC , int & indWoC , int & indWkH , int & indOoH ) {
2020-05-09 07:06:14 +02:00
getSizesAndIndexesConv2d ( isNCHW , wFormat , input . shapeInfo ( ) , output . shapeInfo ( ) , bS , iC , iH , iW , oC , oH , oW , indIOioC , indIiH , indWiC , indWoC , indWkH , indOoH ) ;
2019-06-06 14:21:15 +02:00
}
2020-03-20 10:11:27 +01:00
static inline void getSizesAndIndexesConv2d ( const bool isNCHW , const int wFormat , const Nd4jLong * inShapeInfo , const Nd4jLong * outShapeInfo , int & bS , int & iC , int & iH , int & iW , int & oC , int & oH , int & oW , int & indIOioC , int & indIiH , int & indWiC , int & indWoC , int & indWkH , int & indOoH ) {
2019-06-06 14:21:15 +02:00
// input [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
2020-03-20 10:11:27 +01:00
// weights [kH, kW, iC, oC] (wFormat = 0), [oC, iC, kH, kW] (wFormat = 1), [oC, kH, kW, iC] (wFormat = 2)
2019-06-06 14:21:15 +02:00
// output [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
2020-03-20 10:11:27 +01:00
if ( 0 = = wFormat ) {
indWkH = 0 ; indWiC = 2 ; indWoC = 3 ;
}
else if ( 1 = = wFormat ) {
indWkH = 2 ; indWiC = 1 ; indWoC = 0 ;
}
else {
indWkH = 1 ; indWiC = 3 ; indWoC = 0 ;
}
2019-06-06 14:21:15 +02:00
if ( ! isNCHW ) {
indIOioC = 3 ; indIiH = 1 ; indOoH = 1 ;
}
else {
indIOioC = 1 ; indIiH = 2 ; indOoH = 2 ;
}
bS = inShapeInfo [ 1 ] ; // batch size
iC = inShapeInfo [ indIOioC + 1 ] ; // input channels
iH = inShapeInfo [ indIiH + 1 ] ; // input height
iW = inShapeInfo [ indIiH + 2 ] ; // input width
oC = outShapeInfo [ indIOioC + 1 ] ; // output channels
oH = outShapeInfo [ indOoH + 1 ] ; // output height
oW = outShapeInfo [ indOoH + 2 ] ; // output width
}
// evaluates sizes values and indexes using input and output arrays depending on data format
2020-03-20 10:11:27 +01:00
static inline void getSizesAndIndexesConv3d ( const bool isNCDHW , const int wFormat , const NDArray & input , const NDArray & output , int & bS , int & iC , int & iD , int & iH , int & iW , int & oC , int & oD , int & oH , int & oW , int & indIOioC , int & indIOioD , int & indWiC , int & indWoC , int & indWkD ) {
2019-06-06 14:21:15 +02:00
// input [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
2020-03-20 10:11:27 +01:00
// weights [kD, kH, kW, iC, oC] (wFormat = 0), [oC, iC, kD, kH, kW] (wFormat = 1), [oC, kD, kH, kW, iC] (wFormat = 2)
2019-06-06 14:21:15 +02:00
// output [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
2020-03-20 10:11:27 +01:00
if ( 0 = = wFormat ) {
indWkD = 0 ; indWiC = 3 ; indWoC = 4 ;
}
else if ( 1 = = wFormat ) {
indWkD = 2 ; indWiC = 1 ; indWoC = 0 ;
}
else {
indWkD = 1 ; indWiC = 4 ; indWoC = 0 ;
}
2019-06-06 14:21:15 +02:00
if ( ! isNCDHW ) {
indIOioC = 4 ; indIOioD = 1 ;
}
else {
indIOioC = 1 ; indIOioD = 2 ;
}
bS = input . sizeAt ( 0 ) ; // batch size
iC = input . sizeAt ( indIOioC ) ; // input channels
iD = input . sizeAt ( indIOioD ) ; // input depth
iH = input . sizeAt ( indIOioD + 1 ) ; // input height
iW = input . sizeAt ( indIOioD + 2 ) ; // input width
oC = output . sizeAt ( indIOioC ) ; // output channels
oD = output . sizeAt ( indIOioD ) ; // output depth
oH = output . sizeAt ( indIOioD + 1 ) ; // output height
oW = output . sizeAt ( indIOioD + 2 ) ; // output width
}
2019-11-29 12:14:30 +01:00
// static inline void calcPaddingAndDilationForConv2DMKL(const int iH, const int iW, const int oH, const int oW, const int kH, const int kW, const int sH, const int sW, const int paddingMode, int& pH, int& pW, int& dH, int& dW) {
2019-11-21 20:17:30 +01:00
// if(kH != 1) {
2019-11-29 12:14:30 +01:00
// if(paddingMode) {
2019-11-21 20:17:30 +01:00
// pH = (oH - 1) * sH - iH + kH - pH;
// dH = dH - 1;
// }
// else
// dH = (iH + 2*pH - (oH - 1) * sH - kH) / (kH - 1);
// }
// if(kW != 1) {
2019-11-29 12:14:30 +01:00
// if(paddingMode) {
2019-11-21 20:17:30 +01:00
// pW = (oW - 1) * sW - iW + kW - pW;
// dW = dW - 1;
// }
// else
// dW = (iW + 2*pW - (oW - 1) * sW - kW) / (kW - 1);
// }
// }
2019-11-29 12:14:30 +01:00
// static inline void calcPaddingAndDilationForConv3DMKL(const int iD, const int iH, const int iW, const int oD, const int oH, const int oW, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int paddingMode, int& pD, int& pH, int& pW, int& dD, int& dH, int& dW) {
2019-11-21 20:17:30 +01:00
// if(kD != 1) {
2019-11-29 12:14:30 +01:00
// if(paddingMode) {
2019-11-21 20:17:30 +01:00
// pD = (oD - 1) * sD - iD + kD - pD;
// dD = dD - 1;
// }
// else
// dD = (iD + 2*pD - (oD - 1) * sD - kD) / (kD - 1);
// }
// if(kH != 1) {
2019-11-29 12:14:30 +01:00
// if(paddingMode) {
2019-11-21 20:17:30 +01:00
// pH = (oH - 1) * sH - iH + kH - pH;
// dH = dH - 1;
// }
// else
// dH = (iH + 2*pH - (oH - 1) * sH - kH) / (kH - 1);
// }
// if(kW != 1) {
2019-11-29 12:14:30 +01:00
// if(paddingMode) {
2019-11-21 20:17:30 +01:00
// pW = (oW - 1) * sW - iW + kW - pW;
// dW = dW - 1;
// }
// else
// dW = (iW + 2*pW - (oW - 1) * sW - kW) / (kW - 1);
// }
// }
2019-11-14 18:21:22 +01:00
2020-03-20 10:11:27 +01:00
static std : : vector < Nd4jLong > expectWeightsShape ( const int wFormat , const int kH , const int kW , const int iC , const int oC ) {
if ( 0 = = wFormat )
return std : : vector < Nd4jLong > ( { kH , kW , iC , oC } ) ;
if ( 1 = = wFormat )
return std : : vector < Nd4jLong > ( { oC , iC , kH , kW } ) ;
return std : : vector < Nd4jLong > ( { oC , kH , kW , iC } ) ;
}
static std : : vector < Nd4jLong > expectWeightsShape ( const int wFormat , const int kD , const int kH , const int kW , const int iC , const int oC ) {
if ( 0 = = wFormat )
return std : : vector < Nd4jLong > ( { kD , kH , kW , iC , oC } ) ;
if ( 1 = = wFormat )
return std : : vector < Nd4jLong > ( { oC , iC , kD , kH , kW } ) ;
return std : : vector < Nd4jLong > ( { oC , kD , kH , kW , iC } ) ;
}
static void conv2d ( sd : : graph : : Context & context , const NDArray * input , const NDArray * weights , const NDArray * bias , NDArray * output , const int kH , const int kW , const int sH , const int sW , int pH , int pW , const int dH , const int dW , const int paddingMode , const int isNCHW , const int wFormat ) ;
2019-06-06 14:21:15 +02:00
2020-03-02 10:49:41 +01:00
// static void conv2d(sd::graph::Context & block, const std::vector<NDArray*>& inArrs, NDArray* output, const std::vector<int>& intArgs);
2019-06-06 14:21:15 +02:00
2020-03-02 10:49:41 +01:00
// static void conv2dBP(sd::graph::Context & block, const std::vector<NDArray*>& inArrs, const std::vector<NDArray*>& outArrs, const std::vector<int>& intArgs);
2019-06-06 14:21:15 +02:00
2020-03-20 10:11:27 +01:00
static void conv2dBP ( sd : : graph : : Context & block , const NDArray * input , const NDArray * weights , const NDArray * bias , const NDArray * gradO , NDArray * gradI , NDArray * gradW , NDArray * gradB , const int kH , const int kW , const int sH , const int sW , int pH , int pW , const int dH , const int dW , const int paddingMode , const int isNCHW , const int wFormat ) ;
2019-06-06 14:21:15 +02:00
2020-03-20 10:11:27 +01:00
static void depthwiseConv2d ( sd : : graph : : Context & block , const NDArray * input , const NDArray * weights , const NDArray * bias , NDArray * output , const int kH , const int kW , const int sH , const int sW , int pH , int pW , const int dH , const int dW , const int paddingMode , const int isNCHW , const int wFormat ) ;
2019-06-06 14:21:15 +02:00
2020-03-20 10:11:27 +01:00
static void depthwiseConv2dBP ( sd : : graph : : Context & block , const NDArray * input , const NDArray * weights , const NDArray * bias , const NDArray * gradO , NDArray * gradI , NDArray * gradW , NDArray * gradB , const int kH , const int kW , const int sH , const int sW , int pH , int pW , const int dH , const int dW , const int paddingMode , const int isNCHW , const int wFormat ) ;
2019-06-06 14:21:15 +02:00
2020-03-20 10:11:27 +01:00
static void sconv2d ( sd : : graph : : Context & block , const NDArray * input , const NDArray * weightsDepth , const NDArray * weightsPoint , const NDArray * bias , NDArray * output , const int kH , const int kW , const int sH , const int sW , int pH , int pW , const int dH , const int dW , const int paddingMode , const int isNCHW , const int wFormat ) ;
2019-06-06 14:21:15 +02:00
2020-03-02 10:49:41 +01:00
static void vol2col ( sd : : graph : : Context & block , const NDArray & vol , NDArray & col , const int sD , const int sH , const int sW , const int pD , const int pH , const int pW , const int dD , const int dH , const int dW ) ;
2019-06-06 14:21:15 +02:00
2020-03-02 10:49:41 +01:00
static void col2vol ( sd : : graph : : Context & block , const NDArray & col , NDArray & vol , const int sD , const int sH , const int sW , const int pD , const int pH , const int pW , const int dD , const int dH , const int dW ) ;
2019-06-06 14:21:15 +02:00
2020-03-02 10:49:41 +01:00
static void upsampling2d ( sd : : graph : : Context & block , const NDArray & input , NDArray & output , const int factorH , const int factorW , const bool isNCHW ) ;
2019-06-06 14:21:15 +02:00
2020-03-02 10:49:41 +01:00
static void upsampling3d ( sd : : graph : : Context & block , const NDArray & input , NDArray & output , const int factorD , const int factorH , const int factorW , const bool isNCDHW ) ;
2019-06-06 14:21:15 +02:00
2020-03-02 10:49:41 +01:00
static void upsampling2dBP ( sd : : graph : : Context & block , const NDArray & gradO , NDArray & gradI , const bool isNCHW ) ;
2019-06-06 14:21:15 +02:00
2020-03-02 10:49:41 +01:00
static void upsampling3dBP ( sd : : graph : : Context & block , const NDArray & gradO , NDArray & gradI , const bool isNCDHW ) ;
2019-06-06 14:21:15 +02:00
2020-03-02 10:49:41 +01:00
static void pooling2d ( sd : : graph : : Context & block , const NDArray & input , NDArray & output , const int kH , const int kW , const int sH , const int sW , const int pH , const int pW , const int dH , const int dW , const PoolingType poolingMode , const int extraParam0 ) ;
2019-06-06 14:21:15 +02:00
2020-03-02 10:49:41 +01:00
static void pooling3d ( sd : : graph : : Context & block , const NDArray & input , NDArray & output , const int kD , const int kH , const int kW , const int sD , const int sH , const int sW , const int pD , const int pH , const int pW , const int dD , const int dH , const int dW , const int poolingMode , const int extraParam0 ) ;
2019-06-06 14:21:15 +02:00
2020-03-02 10:49:41 +01:00
static void pooling2dBP ( sd : : graph : : Context & block , const NDArray & input , const NDArray & gradO , NDArray & gradI , const int kH , const int kW , const int sH , const int sW , const int pH , const int pW , const int dH , const int dW , const int poolingMode , const int extraParam0 ) ;
2019-06-06 14:21:15 +02:00
2020-03-02 10:49:41 +01:00
static void pooling3dBP ( sd : : graph : : Context & block , const NDArray & input , const NDArray & gradO , NDArray & gradI , const int kD , const int kH , const int kW , const int sD , const int sH , const int sW , const int pD , const int pH , const int pW , const int dD , const int dH , const int dW , const int poolingMode , const int extraParam0 ) ;
2019-06-06 14:21:15 +02:00
} ;
}
}
# endif //LIBND4J_CONVOLUTIONS_H