2019-09-11 20:50:28 +02:00
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
//
// @author saudet
// @author raver119@gmail.com
//
# include <ops/declarable/PlatformHelper.h>
# include <ops/declarable/OpRegistrator.h>
2020-03-02 10:49:41 +01:00
# include <system/platform_boilerplate.h>
2019-09-11 20:50:28 +02:00
# include <helpers/MKLDNNStream.h>
# include "mkldnnUtils.h"
# include <ops/declarable/helpers/convolutions.h>
2019-11-20 11:23:08 +01:00
using namespace dnnl ;
2019-09-11 20:50:28 +02:00
2020-03-02 10:49:41 +01:00
namespace sd {
2019-11-03 11:37:19 +01:00
namespace ops {
namespace platforms {
2020-02-06 19:12:54 +01:00
//////////////////////////////////////////////////////////////////////
static void conv3dMKLDNN ( const NDArray * input , const NDArray * weights ,
const NDArray * bias , NDArray * output ,
const int kD , const int kH , const int kW ,
const int sD , const int sH , const int sW ,
const int pD , const int pH , const int pW ,
const int dD , const int dH , const int dW ,
2020-03-20 10:11:27 +01:00
const int paddingMode , const int isNCDHW , const int wFormat ) {
2020-02-06 19:12:54 +01:00
2020-03-20 10:11:27 +01:00
// mkl support weights in [oC, iC, kD, kH, kW] format only
2020-02-06 19:12:54 +01:00
int bS , iC , iD , iH , iW , oC , oD , oH , oW ; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC , indIOioD , indWoC , indWiC , indWkD ; // corresponding indexes
2020-03-20 10:11:27 +01:00
ConvolutionUtils : : getSizesAndIndexesConv3d ( isNCDHW , wFormat , * input , * output , bS , iC , iD , iH , iW , oC , oD , oH , oW , indIOioC , indIOioD , indWiC , indWoC , indWkD ) ;
2020-02-06 19:12:54 +01:00
// const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d
dnnl : : memory : : dims strides = { sD , sH , sW } ;
dnnl : : memory : : dims padding = { pD , pH , pW } ;
// dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame };
dnnl : : memory : : dims padding_r = { ( oD - 1 ) * sD - iD + kD - pD , ( oH - 1 ) * sH - iH + kH - pH , ( oW - 1 ) * sW - iW + kW - pW } ;
dnnl : : memory : : dims dilation = { dD - 1 , dH - 1 , dW - 1 } ;
2020-03-20 10:11:27 +01:00
auto xzFormatMkl = isNCDHW ? dnnl : : memory : : format_tag : : ncdhw : dnnl : : memory : : format_tag : : ndhwc ;
dnnl : : memory : : format_tag wFormatMkl = dnnl : : memory : : format_tag : : oidhw ;
2020-02-06 19:12:54 +01:00
dnnl : : memory : : dims xDims = { bS , iC , iD , iH , iW } ;
dnnl : : memory : : dims wDims = { oC , iC , kD , kH , kW } ;
dnnl : : memory : : dims zDims = { bS , oC , oD , oH , oW } ;
2020-05-12 06:47:09 +02:00
std : : vector < int > permut ;
if ( 0 = = wFormat )
permut = { 4 , 3 , 0 , 1 , 2 } ; // [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW]
else if ( 2 = = wFormat )
permut = { 0 , 4 , 1 , 2 , 3 } ; // [oC, kD, kH, kW, iC] -> [oC, iC, kD, kH, kW]
2020-02-06 19:12:54 +01:00
auto type = dnnl : : memory : : data_type : : f32 ;
// memory descriptors for arrays
// input
dnnl : : memory : : desc x_mkl_md = dnnl : : memory : : desc ( xDims , type , dnnl : : memory : : format_tag : : any ) ;
2020-03-20 10:11:27 +01:00
dnnl : : memory : : desc x_user_md = dnnl : : memory : : desc ( xDims , type , xzFormatMkl ) ;
2020-05-12 06:47:09 +02:00
mkldnnUtils : : setBlockStrides ( * input , x_user_md ) ;
2020-02-06 19:12:54 +01:00
// weights
dnnl : : memory : : desc w_mkl_md = dnnl : : memory : : desc ( wDims , type , dnnl : : memory : : format_tag : : any ) ;
2020-03-20 10:11:27 +01:00
dnnl : : memory : : desc w_user_md = dnnl : : memory : : desc ( wDims , type , wFormatMkl ) ;
2020-05-12 06:47:09 +02:00
mkldnnUtils : : setBlockStrides ( * weights , w_user_md , permut ) ;
2020-02-06 19:12:54 +01:00
// bias
dnnl : : memory : : desc b_mkl_md ;
if ( bias ! = nullptr )
b_mkl_md = dnnl : : memory : : desc ( { oC } , type , dnnl : : memory : : format_tag : : x ) ;
// output
dnnl : : memory : : desc z_mkl_md = dnnl : : memory : : desc ( zDims , type , dnnl : : memory : : format_tag : : any ) ;
2020-03-20 10:11:27 +01:00
dnnl : : memory : : desc z_user_md = dnnl : : memory : : desc ( zDims , type , xzFormatMkl ) ;
2020-05-12 06:47:09 +02:00
mkldnnUtils : : setBlockStrides ( * output , z_user_md ) ;
2020-02-06 19:12:54 +01:00
auto engine = mkldnnUtils : : getEngine ( LaunchContext : : defaultContext ( ) - > engine ( ) ) ;
// operation primitive description
dnnl : : convolution_forward : : desc op_desc ( dnnl : : prop_kind : : forward_inference , dnnl : : algorithm : : convolution_auto , x_mkl_md , w_mkl_md , b_mkl_md , z_mkl_md , strides , dilation , padding , padding_r ) ;
dnnl : : convolution_forward : : primitive_desc op_prim_desc ( op_desc , engine ) ;
// arguments (memory buffers) necessary for calculations
std : : unordered_map < int , dnnl : : memory > args ;
dnnl : : stream stream ( engine ) ;
// provide memory buffers and check whether reorder is required
// input
2020-05-12 06:47:09 +02:00
mkldnnUtils : : loadDataToMklStream ( * input , engine , stream , x_user_md , op_prim_desc . src_desc ( ) , args [ DNNL_ARG_SRC ] ) ;
2020-02-06 19:12:54 +01:00
// weights
2020-05-12 06:47:09 +02:00
mkldnnUtils : : loadDataToMklStream ( * weights , engine , stream , w_user_md , op_prim_desc . weights_desc ( ) , args [ DNNL_ARG_WEIGHTS ] ) ;
2020-03-20 10:11:27 +01:00
2020-02-06 19:12:54 +01:00
// bias
if ( bias ! = nullptr ) {
2020-05-09 07:06:14 +02:00
auto b_mkl_mem = dnnl : : memory ( b_mkl_md , engine , const_cast < void * > ( bias - > buffer ( ) ) ) ;
2020-02-06 19:12:54 +01:00
args [ DNNL_ARG_BIAS ] = b_mkl_mem ;
}
// output
2020-05-12 06:47:09 +02:00
auto z_user_mem = mkldnnUtils : : loadDataToMklStream ( * output , engine , stream , z_user_md , op_prim_desc . dst_desc ( ) , args [ DNNL_ARG_DST ] ) ;
2020-02-06 19:12:54 +01:00
// run calculations
dnnl : : convolution_forward ( op_prim_desc ) . execute ( stream , args ) ;
// reorder outputs if necessary
2020-05-12 06:47:09 +02:00
if ( op_prim_desc . dst_desc ( ) ! = z_user_mem . get_desc ( ) )
dnnl : : reorder ( args [ DNNL_ARG_DST ] , z_user_mem ) . execute ( stream , args [ DNNL_ARG_DST ] , z_user_mem ) ;
2020-02-06 19:12:54 +01:00
stream . wait ( ) ;
}
//////////////////////////////////////////////////////////////////////
static void conv3dBpMKLDNN ( const NDArray * input , const NDArray * weights , const NDArray * bias , const NDArray * gradO ,
NDArray * gradI , NDArray * gradW , NDArray * gradB ,
const int kD , const int kH , const int kW ,
const int sD , const int sH , const int sW ,
const int pD , const int pH , const int pW ,
const int dD , const int dH , const int dW ,
2020-03-20 10:11:27 +01:00
const int paddingMode , const int isNCDHW , const int wFormat ) {
2020-02-06 19:12:54 +01:00
2020-03-20 10:11:27 +01:00
// mkl support weights/gradW in [oC, iC, kD, kH, kW] format only
2020-02-06 19:12:54 +01:00
int bS , iC , iD , iH , iW , oC , oD , oH , oW ; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC , indIOioD , indWoC , indWiC , indWkD ; // corresponding indexes
2020-03-20 10:11:27 +01:00
ConvolutionUtils : : getSizesAndIndexesConv3d ( isNCDHW , wFormat , * input , * gradO , bS , iC , iD , iH , iW , oC , oD , oH , oW , indIOioC , indIOioD , indWiC , indWoC , indWkD ) ;
2020-02-06 19:12:54 +01:00
// const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d
dnnl : : memory : : dims strides = { sD , sH , sW } ;
dnnl : : memory : : dims padding = { pD , pH , pW } ;
// dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame };
dnnl : : memory : : dims padding_r = { ( oD - 1 ) * sD - iD + kD - pD , ( oH - 1 ) * sH - iH + kH - pH , ( oW - 1 ) * sW - iW + kW - pW } ;
dnnl : : memory : : dims dilation = { dD - 1 , dH - 1 , dW - 1 } ;
2020-03-20 10:11:27 +01:00
auto xzFormatMkl = isNCDHW ? dnnl : : memory : : format_tag : : ncdhw : dnnl : : memory : : format_tag : : ndhwc ;
dnnl : : memory : : format_tag wFormatMkl = dnnl : : memory : : format_tag : : oidhw ;
2020-02-06 19:12:54 +01:00
dnnl : : memory : : dims xDims = { bS , iC , iD , iH , iW } ;
dnnl : : memory : : dims wDims = { oC , iC , kD , kH , kW } ;
dnnl : : memory : : dims zDims = { bS , oC , oD , oH , oW } ;
auto type = dnnl : : memory : : data_type : : f32 ;
2020-05-12 06:47:09 +02:00
std : : vector < int > permut ;
if ( 0 = = wFormat )
permut = { 4 , 3 , 0 , 1 , 2 } ; // [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW]
else if ( 2 = = wFormat )
permut = { 0 , 4 , 1 , 2 , 3 } ; // [oC, kD, kH, kW, iC] -> [oC, iC, kD, kH, kW]
2020-02-06 19:12:54 +01:00
// memory descriptors for arrays
// input
dnnl : : memory : : desc x_mkl_md = dnnl : : memory : : desc ( xDims , type , dnnl : : memory : : format_tag : : any ) ;
2020-03-20 10:11:27 +01:00
dnnl : : memory : : desc x_user_md = dnnl : : memory : : desc ( xDims , type , xzFormatMkl ) ;
2020-05-12 06:47:09 +02:00
mkldnnUtils : : setBlockStrides ( * input , x_user_md ) ;
2020-02-06 19:12:54 +01:00
// weights
dnnl : : memory : : desc w_mkl_md = dnnl : : memory : : desc ( wDims , type , dnnl : : memory : : format_tag : : any ) ;
2020-03-20 10:11:27 +01:00
dnnl : : memory : : desc w_user_md = dnnl : : memory : : desc ( wDims , type , wFormatMkl ) ;
2020-05-12 06:47:09 +02:00
mkldnnUtils : : setBlockStrides ( * weights , w_user_md , permut ) ;
2020-02-06 19:12:54 +01:00
// gradO
dnnl : : memory : : desc gradO_mkl_md = dnnl : : memory : : desc ( zDims , type , dnnl : : memory : : format_tag : : any ) ;
2020-03-20 10:11:27 +01:00
dnnl : : memory : : desc gradO_user_md = dnnl : : memory : : desc ( zDims , type , xzFormatMkl ) ;
2020-03-12 16:25:29 +01:00
2020-05-12 06:47:09 +02:00
mkldnnUtils : : setBlockStrides ( * gradO , gradO_user_md ) ;
2020-02-06 19:12:54 +01:00
// gradI
dnnl : : memory : : desc gradI_mkl_md = dnnl : : memory : : desc ( xDims , type , dnnl : : memory : : format_tag : : any ) ;
2020-03-20 10:11:27 +01:00
dnnl : : memory : : desc gradI_user_md = dnnl : : memory : : desc ( xDims , type , xzFormatMkl ) ;
2020-03-12 16:25:29 +01:00
2020-05-12 06:47:09 +02:00
mkldnnUtils : : setBlockStrides ( * gradI , gradI_user_md ) ;
2020-02-06 19:12:54 +01:00
// gradW
dnnl : : memory : : desc gradW_mkl_md = dnnl : : memory : : desc ( wDims , type , dnnl : : memory : : format_tag : : any ) ;
2020-03-20 10:11:27 +01:00
dnnl : : memory : : desc gradW_user_md = dnnl : : memory : : desc ( wDims , type , wFormatMkl ) ;
2020-05-12 06:47:09 +02:00
mkldnnUtils : : setBlockStrides ( * gradW , gradW_user_md , permut ) ;
2020-02-06 19:12:54 +01:00
// gradB
dnnl : : memory : : desc gradB_mkl_md ;
if ( gradB ! = nullptr )
gradB_mkl_md = dnnl : : memory : : desc ( { oC } , type , dnnl : : memory : : format_tag : : x ) ;
auto engine = mkldnnUtils : : getEngine ( LaunchContext : : defaultContext ( ) - > engine ( ) ) ;
// forward primitive description
dnnl : : convolution_forward : : desc op_ff_desc ( dnnl : : prop_kind : : forward_inference , dnnl : : algorithm : : convolution_auto , x_mkl_md , w_mkl_md , gradB_mkl_md , gradO_mkl_md , strides , dilation , padding , padding_r ) ;
dnnl : : convolution_forward : : primitive_desc op_ff_prim_desc ( op_ff_desc , engine ) ;
// backward data primitive description
dnnl : : convolution_backward_data : : desc op_data_bp_desc ( dnnl : : algorithm : : convolution_auto , gradI_mkl_md , w_mkl_md , gradO_mkl_md , strides , dilation , padding , padding_r ) ;
dnnl : : convolution_backward_data : : primitive_desc op_data_bp_prim_desc ( op_data_bp_desc , engine , op_ff_prim_desc ) ;
// backward weights primitive description
dnnl : : convolution_backward_weights : : desc op_weights_bp_desc ( dnnl : : algorithm : : convolution_auto , x_mkl_md , gradW_mkl_md , gradB_mkl_md , gradO_mkl_md , strides , dilation , padding , padding_r ) ;
dnnl : : convolution_backward_weights : : primitive_desc op_weights_bp_prim_desc ( op_weights_bp_desc , engine , op_ff_prim_desc ) ;
// arguments (memory buffers) necessary for calculations
std : : unordered_map < int , dnnl : : memory > args ;
dnnl : : stream stream ( engine ) ;
// provide memory buffers and check whether reorder is required
// input
2020-05-12 06:47:09 +02:00
mkldnnUtils : : loadDataToMklStream ( * input , engine , stream , x_user_md , op_weights_bp_prim_desc . src_desc ( ) , args [ DNNL_ARG_SRC ] ) ;
2020-02-06 19:12:54 +01:00
// weights
2020-05-12 06:47:09 +02:00
mkldnnUtils : : loadDataToMklStream ( * weights , engine , stream , w_user_md , op_data_bp_prim_desc . weights_desc ( ) , args [ DNNL_ARG_WEIGHTS ] ) ;
2020-02-06 19:12:54 +01:00
// gradO
2020-05-09 07:06:14 +02:00
auto gradO_user_mem = dnnl : : memory ( gradO_user_md , engine , const_cast < void * > ( gradO - > buffer ( ) ) ) ;
2020-02-06 19:12:54 +01:00
const bool gradOReorderW = op_weights_bp_prim_desc . diff_dst_desc ( ) ! = gradO_user_mem . get_desc ( ) ;
const bool gradOReorderD = op_data_bp_prim_desc . diff_dst_desc ( ) ! = gradO_user_mem . get_desc ( ) ;
auto gradO_mkl_memW = gradOReorderW ? dnnl : : memory ( op_weights_bp_prim_desc . diff_dst_desc ( ) , engine ) : gradO_user_mem ;
auto gradO_mkl_memD = gradOReorderD ? dnnl : : memory ( op_data_bp_prim_desc . diff_dst_desc ( ) , engine ) : gradO_user_mem ;
if ( gradOReorderW )
dnnl : : reorder ( gradO_user_mem , gradO_mkl_memW ) . execute ( stream , gradO_user_mem , gradO_mkl_memW ) ;
if ( gradOReorderD )
dnnl : : reorder ( gradO_user_mem , gradO_mkl_memD ) . execute ( stream , gradO_user_mem , gradO_mkl_memD ) ;
args [ DNNL_ARG_DIFF_DST ] = gradO_mkl_memD ;
// gradI
2020-05-12 06:47:09 +02:00
auto gradI_user_mem = mkldnnUtils : : loadDataToMklStream ( * gradI , engine , stream , gradI_user_md , op_data_bp_prim_desc . diff_src_desc ( ) , args [ DNNL_ARG_DIFF_SRC ] ) ;
2020-02-06 19:12:54 +01:00
// gradW
2020-05-12 06:47:09 +02:00
auto gradW_user_mem = mkldnnUtils : : loadDataToMklStream ( * gradW , engine , stream , gradW_user_md , op_weights_bp_prim_desc . diff_weights_desc ( ) , args [ DNNL_ARG_DIFF_WEIGHTS ] ) ;
2020-02-06 19:12:54 +01:00
// gradB
if ( gradB ! = nullptr ) {
2020-05-09 07:06:14 +02:00
auto gradB_mkl_mem = dnnl : : memory ( gradB_mkl_md , engine , gradB - > buffer ( ) ) ;
2020-02-06 19:12:54 +01:00
args [ DNNL_ARG_DIFF_BIAS ] = gradB_mkl_mem ;
}
// run backward data calculations
dnnl : : convolution_backward_data ( op_data_bp_prim_desc ) . execute ( stream , args ) ;
if ( gradOReorderW | | gradOReorderD )
args [ DNNL_ARG_DIFF_DST ] = gradO_mkl_memW ;
// run backward weights calculations
dnnl : : convolution_backward_weights ( op_weights_bp_prim_desc ) . execute ( stream , args ) ;
// reorder gradI if necessary
2020-05-12 06:47:09 +02:00
if ( op_data_bp_prim_desc . diff_src_desc ( ) ! = gradI_user_mem . get_desc ( ) )
dnnl : : reorder ( args [ DNNL_ARG_DIFF_SRC ] , gradI_user_mem ) . execute ( stream , args [ DNNL_ARG_DIFF_SRC ] , gradI_user_mem ) ;
if ( op_weights_bp_prim_desc . diff_weights_desc ( ) ! = gradW_user_mem . get_desc ( ) )
dnnl : : reorder ( args [ DNNL_ARG_DIFF_WEIGHTS ] , gradW_user_mem ) . execute ( stream , args [ DNNL_ARG_DIFF_WEIGHTS ] , gradW_user_mem ) ;
2020-02-06 19:12:54 +01:00
stream . wait ( ) ;
// shape::printArray(z_mkl_mem.map_data<float>(),8);
}
/*
2019-11-03 11:37:19 +01:00
//////////////////////////////////////////////////////////////////////
2020-03-02 10:49:41 +01:00
static void conv3dMKLDNN ( sd : : graph : : Context & block ,
2020-01-28 16:23:07 +01:00
const NDArray * input , const NDArray * weights , const NDArray * bias ,
NDArray * output ,
const int kD , const int kH , const int kW , const int sD , const int sH , const int sW , int pD , int pH , int pW , const int dD , const int dH , const int dW ,
const int paddingMode , const int isNCDHW ) {
2019-11-03 11:37:19 +01:00
int bS , iC , iD , iH , iW , oC , oD , oH , oW ; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC , indIOioD , indWoC , indWiC , indWkD ; // corresponding indexes
2020-01-28 16:23:07 +01:00
ConvolutionUtils : : getSizesAndIndexesConv3d ( isNCDHW , * input , * output , bS , iC , iD , iH , iW , oC , oD , oH , oW , indIOioC , indIOioD , indWiC , indWoC , indWkD ) ;
2019-11-03 11:37:19 +01:00
2019-11-20 11:23:08 +01:00
dnnl_memory_desc_t empty ;
2020-01-28 16:23:07 +01:00
dnnl : : memory : : desc conv_src_md ( empty ) , conv_weights_md ( empty ) , conv_bias_md ( empty ) , conv_dst_md ( empty ) ;
dnnl : : memory : : desc user_src_md ( empty ) , user_weights_md ( empty ) , user_bias_md ( empty ) , user_dst_md ( empty ) ;
2019-11-20 11:23:08 +01:00
dnnl : : memory : : dims conv_strides , conv_padding , conv_padding_r , conv_dilation ;
2020-01-28 16:23:07 +01:00
mkldnnUtils : : getMKLDNNMemoryDescConv3d ( kD , kH , kW , sD , sH , sW , pD , pH , pW , dD , dH , dW , paddingMode ,
2019-11-03 11:37:19 +01:00
isNCDHW ,
bS , iC , iD , iH , iW , oC , oD , oH , oW , input , nullptr , weights ,
nullptr , bias , output ,
& conv_src_md , nullptr , & conv_weights_md , nullptr ,
& conv_bias_md , & conv_dst_md ,
& user_src_md , nullptr , & user_weights_md , nullptr ,
& user_bias_md , & user_dst_md ,
conv_strides , conv_padding , conv_padding_r , conv_dilation ) ;
2020-01-28 16:23:07 +01:00
auto conv_desc = bias ! = nullptr ? convolution_forward : : desc ( prop_kind : : forward , algorithm : : convolution_auto , conv_src_md , conv_weights_md , conv_bias_md , conv_dst_md , conv_strides , conv_dilation , conv_padding , conv_padding_r )
: convolution_forward : : desc ( prop_kind : : forward , algorithm : : convolution_auto , conv_src_md , conv_weights_md , conv_dst_md , conv_strides , conv_dilation , conv_padding , conv_padding_r ) ;
2019-11-03 11:37:19 +01:00
auto engine = mkldnnUtils : : getEngine ( LaunchContext : : defaultContext ( ) - > engine ( ) ) ;
2019-11-20 11:23:08 +01:00
dnnl : : stream stream ( engine ) ;
2020-01-28 16:23:07 +01:00
2019-11-03 11:37:19 +01:00
auto conv_prim_desc = convolution_forward : : primitive_desc ( conv_desc , engine ) ;
2019-11-20 11:23:08 +01:00
auto user_src_memory = dnnl : : memory ( user_src_md , engine , const_cast < NDArray * > ( input ) - > buffer ( ) ) ;
2020-01-28 16:23:07 +01:00
auto user_weights_memory = dnnl : : memory ( user_weights_md , engine , const_cast < NDArray * > ( weights ) - > buffer ( ) ) ;
2019-11-20 11:23:08 +01:00
auto user_dst_memory = dnnl : : memory ( user_dst_md , engine , output - > buffer ( ) ) ;
2020-01-28 16:23:07 +01:00
2019-11-03 11:37:19 +01:00
auto conv_src_memory = user_src_memory ;
if ( conv_prim_desc . src_desc ( ) ! = user_src_memory . get_desc ( ) ) {
2019-11-20 11:23:08 +01:00
conv_src_memory = dnnl : : memory ( conv_prim_desc . src_desc ( ) , engine ) ;
2019-11-03 11:37:19 +01:00
reorder ( user_src_memory , conv_src_memory ) . execute ( stream , user_src_memory , conv_src_memory ) ;
}
2020-01-28 16:23:07 +01:00
2019-11-03 11:37:19 +01:00
auto conv_weights_memory = user_weights_memory ;
if ( conv_prim_desc . weights_desc ( ) ! = user_weights_memory . get_desc ( ) ) {
2019-11-20 11:23:08 +01:00
conv_weights_memory = dnnl : : memory ( conv_prim_desc . weights_desc ( ) , engine ) ;
2020-01-28 16:23:07 +01:00
reorder ( user_weights_memory , conv_weights_memory ) . execute ( stream , user_weights_memory , conv_weights_memory ) ;
2019-11-03 11:37:19 +01:00
}
2020-01-28 16:23:07 +01:00
2019-11-03 11:37:19 +01:00
auto conv_dst_memory = user_dst_memory ;
if ( conv_prim_desc . dst_desc ( ) ! = user_dst_memory . get_desc ( ) ) {
2019-11-20 11:23:08 +01:00
conv_dst_memory = dnnl : : memory ( conv_prim_desc . dst_desc ( ) , engine ) ;
2019-11-03 11:37:19 +01:00
}
2020-01-28 16:23:07 +01:00
2019-11-03 11:37:19 +01:00
if ( bias ! = nullptr ) {
2020-05-09 07:06:14 +02:00
auto conv_bias_memory = dnnl : : memory ( conv_prim_desc . bias_desc ( ) , engine , bias - > buffer ( ) ) ;
2019-11-20 11:23:08 +01:00
convolution_forward ( conv_prim_desc ) . execute ( stream , { { DNNL_ARG_SRC , conv_src_memory } ,
{ DNNL_ARG_WEIGHTS , conv_weights_memory } ,
{ DNNL_ARG_BIAS , conv_bias_memory } ,
{ DNNL_ARG_DST , conv_dst_memory } } ) ;
2020-01-28 16:23:07 +01:00
}
else {
2019-11-20 11:23:08 +01:00
convolution_forward ( conv_prim_desc ) . execute ( stream , { { DNNL_ARG_SRC , conv_src_memory } ,
{ DNNL_ARG_WEIGHTS , conv_weights_memory } ,
{ DNNL_ARG_DST , conv_dst_memory } } ) ;
2019-11-03 11:37:19 +01:00
}
2020-01-28 16:23:07 +01:00
if ( conv_prim_desc . dst_desc ( ) ! = user_dst_memory . get_desc ( ) )
reorder ( conv_dst_memory , user_dst_memory ) . execute ( stream , conv_dst_memory , user_dst_memory ) ;
2019-11-03 11:37:19 +01:00
2020-01-28 16:23:07 +01:00
stream . wait ( ) ;
2019-11-03 11:37:19 +01:00
}
//////////////////////////////////////////////////////////////////////
2020-03-02 10:49:41 +01:00
static void conv3dBpMKLDNN ( sd : : graph : : Context & block ,
2020-01-28 16:23:07 +01:00
const NDArray * input , const NDArray * weights , const NDArray * bias , const NDArray * gradO ,
NDArray * gradI , NDArray * gradW , NDArray * gradB ,
const int kD , const int kH , const int kW , const int sD , const int sH , const int sW , int pD , int pH , int pW , const int dD , const int dH , const int dW ,
const int paddingMode , const int isNCDHW ) {
2019-11-03 11:37:19 +01:00
int bS , iC , iD , iH , iW , oC , oD , oH , oW ; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC , indIOioD , indWoC , indWiC , indWkD ; // corresponding indexes
2020-01-28 16:23:07 +01:00
ConvolutionUtils : : getSizesAndIndexesConv3d ( isNCDHW , * input , * gradO , bS , iC , iD , iH , iW , oC , oD , oH , oW , indIOioC , indIOioD , indWiC , indWoC , indWkD ) ;
2019-11-03 11:37:19 +01:00
2019-11-20 11:23:08 +01:00
dnnl_memory_desc_t empty ;
2020-01-28 16:23:07 +01:00
dnnl : : memory : : desc conv_src_md ( empty ) , conv_diff_src_md ( empty ) , conv_weights_md ( empty ) , conv_diff_weights_md ( empty ) , conv_bias_md ( empty ) , conv_dst_md ( empty ) ;
dnnl : : memory : : desc user_src_md ( empty ) , user_diff_src_md ( empty ) , user_weights_md ( empty ) , user_diff_weights_md ( empty ) , user_bias_md ( empty ) , user_dst_md ( empty ) ;
2019-11-20 11:23:08 +01:00
dnnl : : memory : : dims conv_strides , conv_padding , conv_padding_r , conv_dilation ;
2020-01-28 16:23:07 +01:00
mkldnnUtils : : getMKLDNNMemoryDescConv3d ( kD , kH , kW , sD , sH , sW , pD , pH , pW , dD , dH , dW , paddingMode ,
isNCDHW ,
2019-11-03 11:37:19 +01:00
bS , iC , iD , iH , iW , oC , oD , oH , oW , input , gradI , weights ,
gradW , gradB , gradO ,
& conv_src_md , & conv_diff_src_md , & conv_weights_md ,
& conv_diff_weights_md , & conv_bias_md , & conv_dst_md ,
& user_src_md , & user_diff_src_md , & user_weights_md ,
& user_diff_weights_md , & user_bias_md , & user_dst_md ,
conv_strides , conv_padding , conv_padding_r , conv_dilation ) ;
2020-01-28 16:23:07 +01:00
auto conv_desc = gradB ! = nullptr ? convolution_forward : : desc ( prop_kind : : forward , algorithm : : convolution_auto , conv_src_md , conv_weights_md , conv_bias_md , conv_dst_md , conv_strides , conv_dilation , conv_padding , conv_padding_r )
: convolution_forward : : desc ( prop_kind : : forward , algorithm : : convolution_auto , conv_src_md , conv_weights_md , conv_dst_md , conv_strides , conv_dilation , conv_padding , conv_padding_r ) ;
auto conv_prim_desc = convolution_forward : : primitive_desc ( conv_desc , mkldnnUtils : : getEngine ( LaunchContext : : defaultContext ( ) - > engine ( ) ) ) ;
auto engine = mkldnnUtils : : getEngine ( LaunchContext : : defaultContext ( ) - > engine ( ) ) ;
dnnl : : stream stream ( engine ) ;
2019-11-03 11:37:19 +01:00
if ( gradW ! = nullptr ) {
2020-01-28 16:23:07 +01:00
auto convW_desc = gradB ! = nullptr ? convolution_backward_weights : : desc ( algorithm : : convolution_auto , conv_src_md , conv_diff_weights_md , conv_bias_md , conv_dst_md , conv_strides , conv_dilation , conv_padding , conv_padding_r )
: convolution_backward_weights : : desc ( algorithm : : convolution_auto , conv_src_md , conv_diff_weights_md , conv_dst_md , conv_strides , conv_dilation , conv_padding , conv_padding_r ) ; auto engine = mkldnnUtils : : getEngine ( LaunchContext : : defaultContext ( ) - > engine ( ) ) ;
auto convW_prim_desc = convolution_backward_weights : : primitive_desc ( convW_desc , engine , conv_prim_desc ) ;
auto userW_src_memory = dnnl : : memory ( user_src_md , engine , const_cast < NDArray * > ( input ) - > buffer ( ) ) ;
2019-11-20 11:23:08 +01:00
auto userW_weights_memory = dnnl : : memory ( user_diff_weights_md , engine , gradW - > buffer ( ) ) ;
2020-01-28 16:23:07 +01:00
auto userW_dst_memory = dnnl : : memory ( user_dst_md , engine , const_cast < NDArray * > ( gradO ) - > buffer ( ) ) ;
2019-11-03 11:37:19 +01:00
auto convW_src_memory = userW_src_memory ;
if ( convW_prim_desc . src_desc ( ) ! = userW_src_memory . get_desc ( ) ) {
2019-11-20 11:23:08 +01:00
convW_src_memory = dnnl : : memory ( convW_prim_desc . src_desc ( ) , engine ) ;
2020-01-28 16:23:07 +01:00
reorder ( userW_src_memory , convW_src_memory ) . execute ( stream , userW_src_memory , convW_src_memory ) ;
2019-09-11 20:50:28 +02:00
}
2019-11-03 11:37:19 +01:00
auto convW_weights_memory = userW_weights_memory ;
if ( convW_prim_desc . diff_weights_desc ( ) ! = userW_weights_memory . get_desc ( ) ) {
2019-11-20 11:23:08 +01:00
convW_weights_memory = dnnl : : memory ( convW_prim_desc . diff_weights_desc ( ) , engine ) ;
2019-11-03 11:37:19 +01:00
}
auto convW_dst_memory = userW_dst_memory ;
if ( convW_prim_desc . diff_dst_desc ( ) ! = userW_dst_memory . get_desc ( ) ) {
2019-11-20 11:23:08 +01:00
convW_dst_memory = dnnl : : memory ( convW_prim_desc . diff_dst_desc ( ) , engine ) ;
2020-01-28 16:23:07 +01:00
reorder ( userW_dst_memory , convW_dst_memory ) . execute ( stream , userW_dst_memory , convW_dst_memory ) ;
2019-11-03 11:37:19 +01:00
}
if ( gradB ! = nullptr ) {
2020-01-28 16:23:07 +01:00
auto convW_bias_memory = dnnl : : memory ( convW_prim_desc . diff_bias_desc ( ) , engine , gradB - > buffer ( ) ) ;
2019-11-03 11:37:19 +01:00
convolution_backward_weights ( convW_prim_desc ) . execute ( stream ,
2019-11-20 11:23:08 +01:00
{ { DNNL_ARG_SRC , convW_src_memory } ,
{ DNNL_ARG_DIFF_DST , convW_dst_memory } ,
{ DNNL_ARG_DIFF_WEIGHTS , convW_weights_memory } ,
{ DNNL_ARG_DIFF_BIAS , convW_bias_memory } } ) ;
2020-01-28 16:23:07 +01:00
}
else {
2019-11-03 11:37:19 +01:00
convolution_backward_weights ( convW_prim_desc ) . execute ( stream ,
2019-11-20 11:23:08 +01:00
{ { DNNL_ARG_SRC , convW_src_memory } ,
{ DNNL_ARG_DIFF_DST , convW_dst_memory } ,
{ DNNL_ARG_DIFF_WEIGHTS , convW_weights_memory } } ) ;
2019-11-03 11:37:19 +01:00
}
2020-01-28 16:23:07 +01:00
if ( convW_prim_desc . diff_weights_desc ( ) ! = userW_weights_memory . get_desc ( ) )
reorder ( convW_weights_memory , userW_weights_memory ) . execute ( stream , convW_weights_memory , userW_weights_memory ) ;
2019-11-03 11:37:19 +01:00
stream . wait ( ) ;
2019-09-11 20:50:28 +02:00
}
2019-11-03 11:37:19 +01:00
if ( gradI ! = nullptr ) {
2020-01-28 16:23:07 +01:00
auto convI_desc = convolution_backward_data : : desc ( algorithm : : convolution_auto , conv_diff_src_md , conv_weights_md , conv_dst_md , conv_strides , conv_dilation , conv_padding , conv_padding_r ) ;
auto convI_prim_desc = convolution_backward_data : : primitive_desc ( convI_desc , engine , conv_prim_desc ) ;
2019-11-20 11:23:08 +01:00
auto userI_src_memory = dnnl : : memory ( user_diff_src_md , engine , gradI - > buffer ( ) ) ;
2020-01-28 16:23:07 +01:00
auto userI_weights_memory = dnnl : : memory ( user_weights_md , engine , const_cast < NDArray * > ( weights ) - > buffer ( ) ) ;
auto userI_dst_memory = dnnl : : memory ( user_dst_md , engine , const_cast < NDArray * > ( gradO ) - > buffer ( ) ) ;
2019-11-03 11:37:19 +01:00
auto convI_src_memory = userI_src_memory ;
2020-01-28 16:23:07 +01:00
if ( convI_prim_desc . diff_src_desc ( ) ! = userI_src_memory . get_desc ( ) )
2019-11-20 11:23:08 +01:00
convI_src_memory = dnnl : : memory ( convI_prim_desc . diff_src_desc ( ) , engine ) ;
2019-11-03 11:37:19 +01:00
auto convI_weights_memory = userI_weights_memory ;
if ( convI_prim_desc . weights_desc ( ) ! = userI_weights_memory . get_desc ( ) ) {
2019-11-20 11:23:08 +01:00
convI_weights_memory = dnnl : : memory ( convI_prim_desc . weights_desc ( ) , engine ) ;
2020-01-28 16:23:07 +01:00
reorder ( userI_weights_memory , convI_weights_memory ) . execute ( stream , userI_weights_memory , convI_weights_memory ) ;
2019-11-03 11:37:19 +01:00
}
auto convI_dst_memory = userI_dst_memory ;
if ( convI_prim_desc . diff_dst_desc ( ) ! = userI_dst_memory . get_desc ( ) ) {
2019-11-20 11:23:08 +01:00
convI_dst_memory = dnnl : : memory ( convI_prim_desc . diff_dst_desc ( ) , engine ) ;
2020-01-28 16:23:07 +01:00
reorder ( userI_dst_memory , convI_dst_memory ) . execute ( stream , userI_dst_memory , convI_dst_memory ) ;
2019-11-03 11:37:19 +01:00
}
convolution_backward_data ( convI_prim_desc ) . execute ( stream ,
2019-11-20 11:23:08 +01:00
{ { DNNL_ARG_DIFF_DST , convI_dst_memory } ,
{ DNNL_ARG_WEIGHTS , convI_weights_memory } ,
{ DNNL_ARG_DIFF_SRC , convI_src_memory } } ) ;
2019-11-03 11:37:19 +01:00
2020-01-28 16:23:07 +01:00
if ( convI_prim_desc . diff_src_desc ( ) ! = userI_src_memory . get_desc ( ) )
reorder ( convI_src_memory , userI_src_memory ) . execute ( stream , convI_src_memory , userI_src_memory ) ;
2019-11-03 11:37:19 +01:00
}
2020-01-28 16:23:07 +01:00
}
2020-02-06 19:12:54 +01:00
*/
2020-01-28 16:23:07 +01:00
//////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(conv3dnew, ENGINE_CPU) {
    // Forward 3D convolution executed through MKL-DNN (conv3dMKLDNN).
    //
    // Inputs:
    //   0: input   - [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
    //   1: weights - layout chosen by wFormat (IArg 14):
    //                0 - [kD, kH, kW, iC, oC], 1 - [oC, iC, kD, kH, kW], 2 - [oC, kD, kH, kW, iC]
    //   2: bias    - optional, [oC]
    // Output:
    //   0: output  - [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)
    // Integer args:
    //   0..2 kernel sizes (<= 0 means "take from weights shape"), 3..5 strides, 6..8 paddings,
    //   9..11 dilations, 12 paddingMode (non-zero = SAME, 0 = VALID),
    //   13 optional data format (1 - NDHWC, 0 - NCDHW; default NCDHW), 14 optional wFormat (default 0).

    auto input   = INPUT_VARIABLE(0);                                  // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
    auto weights = INPUT_VARIABLE(1);                                  // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
    auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;    // [oC]
    auto output  = OUTPUT_VARIABLE(0);                                 // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)

    REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM CONV3D MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
    REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM CONV3D MKLDNN OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf());

    int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1;   // INT_ARG(13): 1-NDHWC, 0-NCDHW
    int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0;    // 0 - [kD, kH, kW, iC, oC], 1 - [oC, iC, kD, kH, kW], 2 - [oC, kD, kH, kW, iC]

    // The kernel dims are not always the leading axes of weights: locate kD for the
    // selected layout so that default kernel sizes are read from the right axes.
    // Unknown wFormat values fall back to layout 0, as before.
    const int kDAxis = wFormat == 1 ? 2 : (wFormat == 2 ? 1 : 0);

    int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(kDAxis));       // filter(kernel) depth
    int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(kDAxis + 1));   // filter(kernel) height
    int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast<int>(weights->sizeAt(kDAxis + 2));   // filter(kernel) width
    int sD = INT_ARG(3);            // strides depth
    int sH = INT_ARG(4);            // strides height
    int sW = INT_ARG(5);            // strides width
    int pD = INT_ARG(6);            // paddings depth
    int pH = INT_ARG(7);            // paddings height
    int pW = INT_ARG(8);            // paddings width
    int dD = INT_ARG(9);            // dilations depth
    int dH = INT_ARG(10);           // dilations height
    int dW = INT_ARG(11);           // dilations width
    int paddingMode = INT_ARG(12);  // 1-SAME, 0-VALID

    int bS, iC, iD, iH, iW, oC, oD, oH, oW;           // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
    int indIOioC, indIOioD, indWoC, indWiC, indWkD;   // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

    std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC);
    REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
    if (bias)
        REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());

    // SAME padding: recompute pD/pH/pW from the input/output geometry.
    if (paddingMode)
        ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);

    conv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW, wFormat);

    return Status::OK();
}
PLATFORM_CHECK(conv3dnew, ENGINE_CPU) {
    // Decides whether the MKL-DNN forward conv3d implementation may handle this block.
    // It requires that MKL-DNN use is enabled for the block and that every participating
    // array (the optional bias may be nullptr) is accepted by MKLDNNStream::isSupported.
    // Weights may be in any of the wFormat layouts:
    // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW] or [oC, kD, kH, kW, iC].
    auto x = INPUT_VARIABLE(0);                                  // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
    auto w = INPUT_VARIABLE(1);                                  // weights, layout depends on wFormat
    auto b = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;    // [oC] or absent
    auto z = OUTPUT_VARIABLE(0);                                 // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)

    if (!block.isUseMKLDNN())
        return false;

    return sd::MKLDNNStream::isSupported({x, w, b, z});
}
//////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) {
    // Backward pass of 3D convolution executed through MKL-DNN (conv3dBpMKLDNN).
    //
    // Inputs:
    //   0:   input   - [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
    //   1:   weights - layout chosen by wFormat (IArg 14):
    //                  0 - [kD, kH, kW, iC, oC], 1 - [oC, iC, kD, kH, kW], 2 - [oC, kD, kH, kW, iC]
    //   2:   bias    - [oC], present only when the block has more than 3 inputs
    //   2/3: gradO   - [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW); index 3 when bias is present
    // Outputs (fetched with OUTPUT_NULLIFIED, presumably zero-initialised):
    //   0: gradI - same shape as input
    //   1: gradW - same shape as weights
    //   2: gradB - [oC], only when bias is present
    // Integer args: same meaning as for the forward conv3dnew op.

    auto input   = INPUT_VARIABLE(0);                                                // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
    auto weights = INPUT_VARIABLE(1);                                                // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
    auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;                  // [oC]
    auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);        // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next

    auto gradI = OUTPUT_NULLIFIED(0);                                                // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
    auto gradW = OUTPUT_NULLIFIED(1);                                                // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
    auto gradB = block.width() > 3 ? OUTPUT_NULLIFIED(2) : nullptr;                  // [oC]

    REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM CONV3D_BP MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
    REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM CONV3D_BP MKLDNN OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf());
    REQUIRE_TRUE(gradO->rankOf() == 5, 0, "CUSTOM CONV3D_BP MKLDNN OP: rank of output gradients (next epsilon) array must be equal to 5, but got %i instead !", gradO->rankOf());

    int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1;   // INT_ARG(13): 1-NDHWC, 0-NCDHW
    int wFormat = block.getIArguments()->size() > 14 ? INT_ARG(14) : 0;    // 0 - [kD, kH, kW, iC, oC], 1 - [oC, iC, kD, kH, kW], 2 - [oC, kD, kH, kW, iC]

    // The kernel dims are not always the leading axes of weights: locate kD for the
    // selected layout so that default kernel sizes are read from the right axes.
    // Unknown wFormat values fall back to layout 0, as before.
    const int kDAxis = wFormat == 1 ? 2 : (wFormat == 2 ? 1 : 0);

    int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(kDAxis));       // filter(kernel) depth
    int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(kDAxis + 1));   // filter(kernel) height
    int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast<int>(weights->sizeAt(kDAxis + 2));   // filter(kernel) width
    int sD = INT_ARG(3);            // strides depth
    int sH = INT_ARG(4);            // strides height
    int sW = INT_ARG(5);            // strides width
    int pD = INT_ARG(6);            // paddings depth
    int pH = INT_ARG(7);            // paddings height
    int pW = INT_ARG(8);            // paddings width
    int dD = INT_ARG(9);            // dilations depth
    int dH = INT_ARG(10);           // dilations height
    int dW = INT_ARG(11);           // dilations width
    int paddingMode = INT_ARG(12);  // 1-SAME, 0-VALID

    int bS, iC, iD, iH, iW, oC, oD, oH, oW;           // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
    int indIOioC, indIOioD, indWoC, indWiC, indWkD;   // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, wFormat, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

    // SAME padding: recompute pD/pH/pW from the input/output geometry.
    if (paddingMode)
        ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);

    // Output spatial sizes implied by the conv parameters; gradO must match them.
    int trueoD, trueoH, trueoW;     // true output depth/height/width
    ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode);

    std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS, oC, trueoD, trueoH, trueoW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2});
    std::vector<Nd4jLong> expectedWeightsShape = ConvolutionUtils::expectWeightsShape(wFormat, kD, kH, kW, iC, oC);
    REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
    REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
    if (bias)
        REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());

    conv3dBpMKLDNN(input, weights, bias, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW, wFormat);

    return Status::OK();
}
2020-01-20 19:32:46 +01:00
PLATFORM_CHECK(conv3dnew_bp, ENGINE_CPU) {
    // Decides whether the MKL-DNN conv3d backward implementation may handle this block.
    // When the block has more than 3 inputs, input 2 is the bias and input 3 is gradO;
    // otherwise there is no bias (nor gradB) and gradO sits at index 2.
    // Optional arrays are passed as nullptr to MKLDNNStream::isSupported.
    const bool withBias = block.width() > 3;

    auto x  = INPUT_VARIABLE(0);                                  // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
    auto w  = INPUT_VARIABLE(1);                                  // [kD, kH, kW, iC, oC], [oC, iC, kD, kH, kW], [oC, kD, kH, kW, iC]
    auto b  = withBias ? INPUT_VARIABLE(2) : nullptr;             // [oC]
    auto dz = withBias ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);   // epsilon_next, same layout as the forward output
    auto dx = OUTPUT_VARIABLE(0);                                 // epsilon, same layout as x
    auto dw = OUTPUT_VARIABLE(1);                                 // same layout as w
    auto db = withBias ? OUTPUT_VARIABLE(2) : nullptr;            // [oC]

    if (!block.isUseMKLDNN())
        return false;

    return sd::MKLDNNStream::isSupported({x, w, b, dz, dx, dw, db});
}
}
}
2019-09-11 20:50:28 +02:00
}