2019-11-03 11:37:19 +01:00
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <platform_boilerplate.h>
#include <helpers/MKLDNNStream.h>
#include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>

#include <memory>
namespace nd4j {
namespace ops {
namespace platforms {
//////////////////////////////////////////////////////////////////////////
static void deconv3dMKLDNN ( const NDArray * input , const NDArray * weights , const NDArray * bias , NDArray * output ,
const int kD , const int kH , const int kW , const int sD , const int sH , const int sW ,
2020-01-11 05:36:40 +01:00
const int pD , const int pH , const int pW , const int dD , const int dH , const int dW ) {
2019-11-03 11:37:19 +01:00
// input [bS, iD, iH, iW, iC] ncdhw, mkl doesn't support format ndhwc
// weights [oC, iC, kD, kH, kW] always, mkl doesn't support weights format [kD, kH, kW, oC, iC]
// bias [oC], may be nullptr
// output [bS, oD, oH, oW, oC] ncdhw, mkl doesn't support format ndhwc
int bS , iC , iD , iH , iW , oC , oD , oH , oW ; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC , indIOioD , indWoC , indWiC , indWkD ; // corresponding indexes
ConvolutionUtils : : getSizesAndIndexesConv3d ( true , * input , * output , bS , iC , iD , iH , iW , oC , oD , oH , oW , indIOioC , indIOioD , indWoC , indWiC , indWkD ) ;
2019-11-20 11:23:08 +01:00
dnnl : : memory : : dims strides = { sD , sH , sW } ;
dnnl : : memory : : dims padding = { pD , pH , pW } ;
2019-11-21 20:17:30 +01:00
dnnl : : memory : : dims padding_r = { ( iD - 1 ) * sD - oD + kD - pD , ( iH - 1 ) * sH - oH + kH - pH , ( iW - 1 ) * sW - oW + kW - pW } ;
dnnl : : memory : : dims dilation = { dD - 1 , dH - 1 , dW - 1 } ;
2019-11-03 11:37:19 +01:00
// input type
2019-11-20 11:23:08 +01:00
dnnl : : memory : : data_type xType ;
2019-11-03 11:37:19 +01:00
if ( input - > dataType ( ) = = DataType : : FLOAT32 )
2019-11-20 11:23:08 +01:00
xType = dnnl : : memory : : data_type : : f32 ;
2019-11-03 11:37:19 +01:00
else if ( input - > dataType ( ) = = DataType : : HALF )
2019-11-20 11:23:08 +01:00
xType = dnnl : : memory : : data_type : : f16 ;
2019-11-03 11:37:19 +01:00
else if ( input - > dataType ( ) = = DataType : : UINT8 )
2019-11-20 11:23:08 +01:00
xType = dnnl : : memory : : data_type : : u8 ;
2019-11-03 11:37:19 +01:00
else
2019-11-20 11:23:08 +01:00
xType = dnnl : : memory : : data_type : : s8 ;
2019-11-03 11:37:19 +01:00
// weights type
2019-11-20 11:23:08 +01:00
dnnl : : memory : : data_type wType = xType ;
if ( xType = = dnnl : : memory : : data_type : : u8 )
wType = dnnl : : memory : : data_type : : s8 ;
2019-11-03 11:37:19 +01:00
// output and bias type (have the same types)
2019-11-20 11:23:08 +01:00
dnnl : : memory : : data_type zType ;
2019-11-03 11:37:19 +01:00
if ( output - > dataType ( ) = = DataType : : FLOAT32 )
2019-11-20 11:23:08 +01:00
zType = dnnl : : memory : : data_type : : f32 ;
2019-11-03 11:37:19 +01:00
else if ( output - > dataType ( ) = = DataType : : HALF )
2019-11-20 11:23:08 +01:00
zType = dnnl : : memory : : data_type : : f16 ;
2019-11-03 11:37:19 +01:00
else if ( output - > dataType ( ) = = DataType : : UINT8 )
2019-11-20 11:23:08 +01:00
zType = dnnl : : memory : : data_type : : u8 ;
2019-11-03 11:37:19 +01:00
else if ( output - > dataType ( ) = = DataType : : INT8 )
2019-11-20 11:23:08 +01:00
zType = dnnl : : memory : : data_type : : s8 ;
2019-11-03 11:37:19 +01:00
else
2019-11-20 11:23:08 +01:00
zType = dnnl : : memory : : data_type : : s32 ;
2019-11-03 11:37:19 +01:00
2019-11-20 11:23:08 +01:00
dnnl : : memory : : format_tag xFormat = dnnl : : memory : : format_tag : : ncdhw ;
dnnl : : memory : : format_tag wFormat = dnnl : : memory : : format_tag : : oidhw ;
2019-11-03 11:37:19 +01:00
2019-11-20 11:23:08 +01:00
dnnl : : memory : : dims xDims = { bS , iC , iD , iH , iW } ;
dnnl : : memory : : dims wDims = { oC , iC , kD , kH , kW } ;
dnnl : : memory : : dims zDims = { bS , oC , oD , oH , oW } ;
2019-11-03 11:37:19 +01:00
// memory descriptors for arrays
// input
2019-11-20 11:23:08 +01:00
dnnl : : memory : : desc x_mkl_md = dnnl : : memory : : desc ( xDims , xType , dnnl : : memory : : format_tag : : any ) ;
dnnl : : memory : : desc x_user_md = dnnl : : memory : : desc ( xDims , xType , xFormat ) ;
x_user_md . data . format_kind = dnnl_blocked ; // overrides format
2019-11-03 11:37:19 +01:00
x_user_md . data . format_desc . blocking . strides [ 0 ] = input - > stridesOf ( ) [ 0 ] ;
x_user_md . data . format_desc . blocking . strides [ 1 ] = input - > stridesOf ( ) [ 1 ] ;
x_user_md . data . format_desc . blocking . strides [ 2 ] = input - > stridesOf ( ) [ 2 ] ;
x_user_md . data . format_desc . blocking . strides [ 3 ] = input - > stridesOf ( ) [ 3 ] ;
x_user_md . data . format_desc . blocking . strides [ 4 ] = input - > stridesOf ( ) [ 4 ] ;
// weights
2019-11-20 11:23:08 +01:00
dnnl : : memory : : desc w_mkl_md = dnnl : : memory : : desc ( wDims , wType , dnnl : : memory : : format_tag : : any ) ;
dnnl : : memory : : desc w_user_md = dnnl : : memory : : desc ( wDims , wType , wFormat ) ;
w_user_md . data . format_kind = dnnl_blocked ; // overrides format
2019-11-03 11:37:19 +01:00
w_user_md . data . format_desc . blocking . strides [ 0 ] = weights - > stridesOf ( ) [ 0 ] ;
w_user_md . data . format_desc . blocking . strides [ 1 ] = weights - > stridesOf ( ) [ 1 ] ;
w_user_md . data . format_desc . blocking . strides [ 2 ] = weights - > stridesOf ( ) [ 2 ] ;
w_user_md . data . format_desc . blocking . strides [ 3 ] = weights - > stridesOf ( ) [ 3 ] ;
w_user_md . data . format_desc . blocking . strides [ 4 ] = weights - > stridesOf ( ) [ 4 ] ;
// bias
2019-11-20 11:23:08 +01:00
dnnl : : memory : : desc b_mkl_md ;
2019-11-03 11:37:19 +01:00
if ( bias ! = nullptr )
2019-11-20 11:23:08 +01:00
b_mkl_md = dnnl : : memory : : desc ( { oC } , zType , dnnl : : memory : : format_tag : : x ) ;
2019-11-03 11:37:19 +01:00
// output
2019-11-20 11:23:08 +01:00
dnnl : : memory : : desc z_mkl_md = dnnl : : memory : : desc ( zDims , zType , dnnl : : memory : : format_tag : : any ) ;
dnnl : : memory : : desc z_user_md = dnnl : : memory : : desc ( zDims , zType , xFormat ) ;
z_user_md . data . format_kind = dnnl_blocked ; // overrides format
2019-11-03 11:37:19 +01:00
z_user_md . data . format_desc . blocking . strides [ 0 ] = output - > stridesOf ( ) [ 0 ] ;
z_user_md . data . format_desc . blocking . strides [ 1 ] = output - > stridesOf ( ) [ 1 ] ;
z_user_md . data . format_desc . blocking . strides [ 2 ] = output - > stridesOf ( ) [ 2 ] ;
z_user_md . data . format_desc . blocking . strides [ 3 ] = output - > stridesOf ( ) [ 3 ] ;
z_user_md . data . format_desc . blocking . strides [ 4 ] = output - > stridesOf ( ) [ 4 ] ;
auto engine = mkldnnUtils : : getEngine ( LaunchContext : : defaultContext ( ) - > engine ( ) ) ;
// operation primitive description
2019-11-20 11:23:08 +01:00
dnnl : : deconvolution_forward : : desc op_desc ( dnnl : : prop_kind : : forward_inference , dnnl : : algorithm : : deconvolution_direct ,
2019-11-03 11:37:19 +01:00
x_mkl_md , w_mkl_md , b_mkl_md , z_mkl_md , strides , dilation , padding , padding_r ) ;
2019-11-20 11:23:08 +01:00
dnnl : : deconvolution_forward : : primitive_desc op_prim_desc ( op_desc , engine ) ;
2019-11-03 11:37:19 +01:00
// arguments (memory buffers) necessary for calculations
2019-11-20 11:23:08 +01:00
std : : unordered_map < int , dnnl : : memory > args ;
2019-11-03 11:37:19 +01:00
2019-11-20 11:23:08 +01:00
dnnl : : stream stream ( engine ) ;
2019-11-03 11:37:19 +01:00
// provide memory buffers and check whether reorder is required
// input
2019-11-20 11:23:08 +01:00
auto x_user_mem = dnnl : : memory ( x_user_md , engine , input - > getBuffer ( ) ) ;
2019-11-03 11:37:19 +01:00
const bool xReorder = op_prim_desc . src_desc ( ) ! = x_user_mem . get_desc ( ) ;
2019-11-20 11:23:08 +01:00
auto x_mkl_mem = xReorder ? dnnl : : memory ( op_prim_desc . src_desc ( ) , engine ) : x_user_mem ;
2019-11-03 11:37:19 +01:00
if ( xReorder )
2019-11-20 11:23:08 +01:00
dnnl : : reorder ( x_user_mem , x_mkl_mem ) . execute ( stream , x_user_mem , x_mkl_mem ) ;
args [ DNNL_ARG_SRC ] = x_mkl_mem ;
2019-11-03 11:37:19 +01:00
// weights
2019-11-20 11:23:08 +01:00
auto w_user_mem = dnnl : : memory ( w_user_md , engine , weights - > getBuffer ( ) ) ;
2019-11-03 11:37:19 +01:00
const bool wReorder = op_prim_desc . weights_desc ( ) ! = w_user_mem . get_desc ( ) ;
2019-11-20 11:23:08 +01:00
auto w_mkl_mem = wReorder ? dnnl : : memory ( op_prim_desc . weights_desc ( ) , engine ) : w_user_mem ;
2019-11-03 11:37:19 +01:00
if ( wReorder )
2019-11-20 11:23:08 +01:00
dnnl : : reorder ( w_user_mem , w_mkl_mem ) . execute ( stream , w_user_mem , w_mkl_mem ) ;
args [ DNNL_ARG_WEIGHTS ] = w_mkl_mem ;
2019-11-03 11:37:19 +01:00
// bias
if ( bias ! = nullptr ) {
2019-11-20 11:23:08 +01:00
auto b_mkl_mem = dnnl : : memory ( b_mkl_md , engine , bias - > getBuffer ( ) ) ;
args [ DNNL_ARG_BIAS ] = b_mkl_mem ;
2019-11-03 11:37:19 +01:00
}
// output
2019-11-20 11:23:08 +01:00
auto z_user_mem = dnnl : : memory ( z_user_md , engine , output - > getBuffer ( ) ) ;
2019-11-03 11:37:19 +01:00
const bool zReorder = op_prim_desc . dst_desc ( ) ! = z_user_mem . get_desc ( ) ;
2019-11-20 11:23:08 +01:00
auto z_mkl_mem = zReorder ? dnnl : : memory ( op_prim_desc . dst_desc ( ) , engine ) : z_user_mem ;
args [ DNNL_ARG_DST ] = z_mkl_mem ;
2019-11-03 11:37:19 +01:00
// run calculations
2019-11-20 11:23:08 +01:00
dnnl : : deconvolution_forward ( op_prim_desc ) . execute ( stream , args ) ;
2019-11-03 11:37:19 +01:00
// reorder outputs if necessary
if ( zReorder )
2019-11-20 11:23:08 +01:00
dnnl : : reorder ( z_mkl_mem , z_user_mem ) . execute ( stream , z_mkl_mem , z_user_mem ) ;
2019-11-03 11:37:19 +01:00
stream . wait ( ) ;
// shape::printArray(z_mkl_mem.map_data<float>(),8);
}
//////////////////////////////////////////////////////////////////////////
static void deconv3dBackPropMKLDNN ( const NDArray * input , const NDArray * weights , const NDArray * gradO , NDArray * gradI , NDArray * gradW , NDArray * gradB ,
2020-01-11 05:36:40 +01:00
const int kD , const int kH , const int kW ,
const int sD , const int sH , const int sW ,
const int pD , const int pH , const int pW ,
const int dD , const int dH , const int dW ) {
2019-11-03 11:37:19 +01:00
// input and gradI [bS, iD, iH, iW, iC], mkl doesn't support ndhwc format
// weights and gradW [oC, iC, kD, kH, kW] always, mkl doesn't support weights format [kD, kH, kW, oC, iC]
// gradB [oC], may be nullptr
// gradO [bS, oD, oH, oW, oC]
int bS , iC , iD , iH , iW , oC , oD , oH , oW ; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC , indIOioD , indWoC , indWiC , indWkD ; // corresponding indexes
ConvolutionUtils : : getSizesAndIndexesConv3d ( true , * input , * gradO , bS , iC , iD , iH , iW , oC , oD , oH , oW , indIOioC , indIOioD , indWoC , indWiC , indWkD ) ;
2019-11-20 11:23:08 +01:00
dnnl : : memory : : dims strides = { sD , sH , sW } ;
dnnl : : memory : : dims padding = { pD , pH , pW } ;
2019-11-21 20:17:30 +01:00
dnnl : : memory : : dims padding_r = { ( iD - 1 ) * sD - oD + kD - pD , ( iH - 1 ) * sH - oH + kH - pH , ( iW - 1 ) * sW - oW + kW - pW } ;
dnnl : : memory : : dims dilation = { dD - 1 , dH - 1 , dW - 1 } ;
2019-11-03 11:37:19 +01:00
// input type
2019-11-20 11:23:08 +01:00
dnnl : : memory : : data_type xType = input - > dataType ( ) = = DataType : : FLOAT32 ? dnnl : : memory : : data_type : : f32 : dnnl : : memory : : data_type : : bf16 ;
2019-11-03 11:37:19 +01:00
// weights type
2019-11-20 11:23:08 +01:00
dnnl : : memory : : data_type wType = weights - > dataType ( ) = = DataType : : FLOAT32 ? dnnl : : memory : : data_type : : f32 : dnnl : : memory : : data_type : : bf16 ;
2019-11-03 11:37:19 +01:00
// gradO type
2019-11-20 11:23:08 +01:00
dnnl : : memory : : data_type gradOType = gradO - > dataType ( ) = = DataType : : FLOAT32 ? dnnl : : memory : : data_type : : f32 : dnnl : : memory : : data_type : : bf16 ;
2019-11-03 11:37:19 +01:00
// gradI type
2019-11-20 11:23:08 +01:00
dnnl : : memory : : data_type gradIType = gradI - > dataType ( ) = = DataType : : FLOAT32 ? dnnl : : memory : : data_type : : f32 : dnnl : : memory : : data_type : : bf16 ;
2019-11-03 11:37:19 +01:00
// gradW type
2019-11-20 11:23:08 +01:00
dnnl : : memory : : data_type gradWType = gradW - > dataType ( ) = = DataType : : FLOAT32 ? dnnl : : memory : : data_type : : f32 : dnnl : : memory : : data_type : : bf16 ;
2019-11-03 11:37:19 +01:00
// gradB type
2019-11-20 11:23:08 +01:00
dnnl : : memory : : data_type gradBType = gradB ! = nullptr ? ( gradB - > dataType ( ) = = DataType : : FLOAT32 ? dnnl : : memory : : data_type : : f32 : dnnl : : memory : : data_type : : bf16 ) : dnnl : : memory : : data_type : : f32 ;
2019-11-03 11:37:19 +01:00
2019-11-20 11:23:08 +01:00
dnnl : : memory : : format_tag xFormat = dnnl : : memory : : format_tag : : ncdhw ; // isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
dnnl : : memory : : format_tag wFormat = dnnl : : memory : : format_tag : : oidhw ;
2019-11-03 11:37:19 +01:00
2019-11-20 11:23:08 +01:00
dnnl : : memory : : dims xDims = { bS , iC , iD , iH , iW } ;
dnnl : : memory : : dims wDims = { oC , iC , kD , kH , kW } ;
dnnl : : memory : : dims zDims = { bS , oC , oD , oH , oW } ;
2019-11-03 11:37:19 +01:00
// memory descriptors for arrays
// input
2019-11-20 11:23:08 +01:00
dnnl : : memory : : desc x_mkl_md = dnnl : : memory : : desc ( xDims , xType , dnnl : : memory : : format_tag : : any ) ;
dnnl : : memory : : desc x_user_md = dnnl : : memory : : desc ( xDims , xType , xFormat ) ;
x_user_md . data . format_kind = dnnl_blocked ; // overrides format
2019-11-03 11:37:19 +01:00
x_user_md . data . format_desc . blocking . strides [ 0 ] = input - > stridesOf ( ) [ 0 ] ;
x_user_md . data . format_desc . blocking . strides [ 1 ] = input - > stridesOf ( ) [ 1 ] ;
x_user_md . data . format_desc . blocking . strides [ 2 ] = input - > stridesOf ( ) [ 2 ] ;
x_user_md . data . format_desc . blocking . strides [ 3 ] = input - > stridesOf ( ) [ 3 ] ;
x_user_md . data . format_desc . blocking . strides [ 4 ] = input - > stridesOf ( ) [ 4 ] ;
// weights
2019-11-20 11:23:08 +01:00
dnnl : : memory : : desc w_mkl_md = dnnl : : memory : : desc ( wDims , wType , dnnl : : memory : : format_tag : : any ) ;
dnnl : : memory : : desc w_user_md = dnnl : : memory : : desc ( wDims , wType , wFormat ) ;
w_user_md . data . format_kind = dnnl_blocked ; // overrides format
2019-11-03 11:37:19 +01:00
w_user_md . data . format_desc . blocking . strides [ 0 ] = weights - > stridesOf ( ) [ 0 ] ;
w_user_md . data . format_desc . blocking . strides [ 1 ] = weights - > stridesOf ( ) [ 1 ] ;
w_user_md . data . format_desc . blocking . strides [ 2 ] = weights - > stridesOf ( ) [ 2 ] ;
w_user_md . data . format_desc . blocking . strides [ 3 ] = weights - > stridesOf ( ) [ 3 ] ;
w_user_md . data . format_desc . blocking . strides [ 4 ] = weights - > stridesOf ( ) [ 4 ] ;
// gradO
2019-11-20 11:23:08 +01:00
dnnl : : memory : : desc gradO_mkl_md = dnnl : : memory : : desc ( zDims , gradOType , dnnl : : memory : : format_tag : : any ) ;
dnnl : : memory : : desc gradO_user_md = dnnl : : memory : : desc ( zDims , gradOType , xFormat ) ;
gradO_user_md . data . format_kind = dnnl_blocked ; // overrides format
2019-11-03 11:37:19 +01:00
gradO_user_md . data . format_desc . blocking . strides [ 0 ] = gradO - > stridesOf ( ) [ 0 ] ;
gradO_user_md . data . format_desc . blocking . strides [ 1 ] = gradO - > stridesOf ( ) [ 1 ] ;
gradO_user_md . data . format_desc . blocking . strides [ 2 ] = gradO - > stridesOf ( ) [ 2 ] ;
gradO_user_md . data . format_desc . blocking . strides [ 3 ] = gradO - > stridesOf ( ) [ 3 ] ;
gradO_user_md . data . format_desc . blocking . strides [ 4 ] = gradO - > stridesOf ( ) [ 4 ] ;
// gradI
2019-11-20 11:23:08 +01:00
dnnl : : memory : : desc gradI_mkl_md = dnnl : : memory : : desc ( xDims , gradIType , dnnl : : memory : : format_tag : : any ) ;
dnnl : : memory : : desc gradI_user_md = dnnl : : memory : : desc ( xDims , gradIType , xFormat ) ;
gradI_user_md . data . format_kind = dnnl_blocked ; // overrides format
2019-11-03 11:37:19 +01:00
gradI_user_md . data . format_desc . blocking . strides [ 0 ] = gradI - > stridesOf ( ) [ 0 ] ;
gradI_user_md . data . format_desc . blocking . strides [ 1 ] = gradI - > stridesOf ( ) [ 1 ] ;
gradI_user_md . data . format_desc . blocking . strides [ 2 ] = gradI - > stridesOf ( ) [ 2 ] ;
gradI_user_md . data . format_desc . blocking . strides [ 3 ] = gradI - > stridesOf ( ) [ 3 ] ;
gradI_user_md . data . format_desc . blocking . strides [ 4 ] = gradI - > stridesOf ( ) [ 4 ] ;
// gradW
2019-11-20 11:23:08 +01:00
dnnl : : memory : : desc gradW_mkl_md = dnnl : : memory : : desc ( wDims , gradWType , wFormat ) ;
dnnl : : memory : : desc gradW_user_md = dnnl : : memory : : desc ( wDims , gradWType , wFormat ) ;
gradW_user_md . data . format_kind = dnnl_blocked ; // overrides format
2019-11-03 11:37:19 +01:00
gradW_user_md . data . format_desc . blocking . strides [ 0 ] = gradW - > stridesOf ( ) [ 0 ] ;
gradW_user_md . data . format_desc . blocking . strides [ 1 ] = gradW - > stridesOf ( ) [ 1 ] ;
gradW_user_md . data . format_desc . blocking . strides [ 2 ] = gradW - > stridesOf ( ) [ 2 ] ;
gradW_user_md . data . format_desc . blocking . strides [ 3 ] = gradW - > stridesOf ( ) [ 3 ] ;
gradW_user_md . data . format_desc . blocking . strides [ 4 ] = gradW - > stridesOf ( ) [ 4 ] ;
// gradB
2019-11-20 11:23:08 +01:00
dnnl : : memory : : desc gradB_mkl_md ;
2019-11-03 11:37:19 +01:00
if ( gradB ! = nullptr )
2019-11-20 11:23:08 +01:00
gradB_mkl_md = dnnl : : memory : : desc ( { oC } , gradBType , dnnl : : memory : : format_tag : : x ) ;
2019-11-03 11:37:19 +01:00
auto engine = mkldnnUtils : : getEngine ( LaunchContext : : defaultContext ( ) - > engine ( ) ) ;
// forward primitive description
2019-11-20 11:23:08 +01:00
dnnl : : deconvolution_forward : : desc op_ff_desc ( dnnl : : prop_kind : : forward_inference , dnnl : : algorithm : : deconvolution_direct , x_mkl_md , w_mkl_md , gradB_mkl_md , gradO_mkl_md , strides , dilation , padding , padding_r ) ;
dnnl : : deconvolution_forward : : primitive_desc op_ff_prim_desc ( op_ff_desc , engine ) ;
2019-11-03 11:37:19 +01:00
// backward data primitive description
2019-11-20 11:23:08 +01:00
dnnl : : deconvolution_backward_data : : desc op_data_bp_desc ( dnnl : : algorithm : : deconvolution_direct , gradI_mkl_md , w_mkl_md , gradO_mkl_md , strides , dilation , padding , padding_r ) ;
dnnl : : deconvolution_backward_data : : primitive_desc op_data_bp_prim_desc ( op_data_bp_desc , engine , op_ff_prim_desc ) ;
2019-11-03 11:37:19 +01:00
// backward weights primitive description
2019-11-20 11:23:08 +01:00
dnnl : : deconvolution_backward_weights : : desc op_weights_bp_desc ( dnnl : : algorithm : : deconvolution_direct , x_mkl_md , gradW_mkl_md , gradB_mkl_md , gradO_mkl_md , strides , dilation , padding , padding_r ) ;
dnnl : : deconvolution_backward_weights : : primitive_desc op_weights_bp_prim_desc ( op_weights_bp_desc , engine , op_ff_prim_desc ) ;
2019-11-03 11:37:19 +01:00
// arguments (memory buffers) necessary for calculations
2019-11-20 11:23:08 +01:00
std : : unordered_map < int , dnnl : : memory > args ;
2019-11-03 11:37:19 +01:00
2019-11-20 11:23:08 +01:00
dnnl : : stream stream ( engine ) ;
2019-11-03 11:37:19 +01:00
// provide memory buffers and check whether reorder is required
// input
2019-11-20 11:23:08 +01:00
auto x_user_mem = dnnl : : memory ( x_user_md , engine , input - > getBuffer ( ) ) ;
2019-11-03 11:37:19 +01:00
const bool xReorder = op_weights_bp_prim_desc . src_desc ( ) ! = x_user_mem . get_desc ( ) ;
2019-11-20 11:23:08 +01:00
auto x_mkl_mem = xReorder ? dnnl : : memory ( op_weights_bp_prim_desc . src_desc ( ) , engine ) : x_user_mem ;
2019-11-03 11:37:19 +01:00
if ( xReorder )
2019-11-20 11:23:08 +01:00
dnnl : : reorder ( x_user_mem , x_mkl_mem ) . execute ( stream , x_user_mem , x_mkl_mem ) ;
args [ DNNL_ARG_SRC ] = x_mkl_mem ;
2019-11-03 11:37:19 +01:00
// weights
2019-11-20 11:23:08 +01:00
auto w_user_mem = dnnl : : memory ( w_user_md , engine , weights - > getBuffer ( ) ) ;
2019-11-03 11:37:19 +01:00
const bool wReorder = op_data_bp_prim_desc . weights_desc ( ) ! = w_user_mem . get_desc ( ) ;
2019-11-20 11:23:08 +01:00
auto w_mkl_mem = wReorder ? dnnl : : memory ( op_data_bp_prim_desc . weights_desc ( ) , engine ) : w_user_mem ;
2019-11-03 11:37:19 +01:00
if ( wReorder )
2019-11-20 11:23:08 +01:00
dnnl : : reorder ( w_user_mem , w_mkl_mem ) . execute ( stream , w_user_mem , w_mkl_mem ) ;
args [ DNNL_ARG_WEIGHTS ] = w_mkl_mem ;
2019-11-03 11:37:19 +01:00
// gradO
2019-11-20 11:23:08 +01:00
auto gradO_user_mem = dnnl : : memory ( gradO_user_md , engine , gradO - > getBuffer ( ) ) ;
2019-11-03 11:37:19 +01:00
const bool gradOReorder = op_data_bp_prim_desc . diff_dst_desc ( ) ! = gradO_user_mem . get_desc ( ) ;
2019-11-20 11:23:08 +01:00
auto gradO_mkl_mem = gradOReorder ? dnnl : : memory ( op_data_bp_prim_desc . diff_dst_desc ( ) , engine ) : gradO_user_mem ;
2019-11-03 11:37:19 +01:00
if ( gradOReorder )
2019-11-20 11:23:08 +01:00
dnnl : : reorder ( gradO_user_mem , gradO_mkl_mem ) . execute ( stream , gradO_user_mem , gradO_mkl_mem ) ;
args [ DNNL_ARG_DIFF_DST ] = gradO_mkl_mem ;
2019-11-03 11:37:19 +01:00
// gradI
2019-11-20 11:23:08 +01:00
auto gradI_user_mem = dnnl : : memory ( gradI_user_md , engine , gradI - > getBuffer ( ) ) ;
2019-11-03 11:37:19 +01:00
const bool gradIReorder = op_data_bp_prim_desc . diff_src_desc ( ) ! = gradI_user_mem . get_desc ( ) ;
2019-11-20 11:23:08 +01:00
auto gradI_mkl_mem = gradIReorder ? dnnl : : memory ( op_data_bp_prim_desc . diff_src_desc ( ) , engine ) : gradI_user_mem ;
args [ DNNL_ARG_DIFF_SRC ] = gradI_mkl_mem ;
2019-11-03 11:37:19 +01:00
// gradW
2019-11-20 11:23:08 +01:00
auto gradW_user_mem = dnnl : : memory ( gradW_user_md , engine , gradW - > getBuffer ( ) ) ;
2019-11-03 11:37:19 +01:00
const bool gradWReorder = op_weights_bp_prim_desc . diff_weights_desc ( ) ! = gradW_user_mem . get_desc ( ) ;
2019-11-20 11:23:08 +01:00
auto gradW_mkl_mem = gradWReorder ? dnnl : : memory ( op_weights_bp_prim_desc . diff_weights_desc ( ) , engine ) : gradW_user_mem ;
args [ DNNL_ARG_DIFF_WEIGHTS ] = gradW_mkl_mem ;
2019-11-03 11:37:19 +01:00
// gradB
if ( gradB ! = nullptr ) {
2019-11-20 11:23:08 +01:00
auto gradB_mkl_mem = dnnl : : memory ( gradB_mkl_md , engine , gradB - > getBuffer ( ) ) ;
args [ DNNL_ARG_DIFF_BIAS ] = gradB_mkl_mem ;
2019-11-03 11:37:19 +01:00
}
// run backward data calculations
2019-11-20 11:23:08 +01:00
dnnl : : deconvolution_backward_data ( op_data_bp_prim_desc ) . execute ( stream , args ) ;
2019-11-03 11:37:19 +01:00
// run backward weights calculations
2019-11-20 11:23:08 +01:00
dnnl : : deconvolution_backward_weights ( op_weights_bp_prim_desc ) . execute ( stream , args ) ;
2019-11-03 11:37:19 +01:00
// reorder gradI if necessary
if ( gradIReorder )
2019-11-20 11:23:08 +01:00
dnnl : : reorder ( gradI_mkl_mem , gradI_user_mem ) . execute ( stream , gradI_mkl_mem , gradI_user_mem ) ;
2019-11-03 11:37:19 +01:00
if ( gradWReorder )
2019-11-20 11:23:08 +01:00
dnnl : : reorder ( gradW_mkl_mem , gradW_user_mem ) . execute ( stream , gradW_mkl_mem , gradW_user_mem ) ;
2019-11-03 11:37:19 +01:00
stream . wait ( ) ;
// shape::printArray(z_mkl_mem.map_data<float>(),8);
}
//////////////////////////////////////////////////////////////////////////
2020-01-20 19:32:46 +01:00
PLATFORM_IMPL(deconv3d, ENGINE_CPU) {

    // oneDNN-backed deconv3d: validates arguments, permutes arrays into the layouts mkl supports and delegates to deconv3dMKLDNN.

    auto input   = INPUT_VARIABLE(0);                                // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
    auto weights = INPUT_VARIABLE(1);                                // [kD, kH, kW, oC, iC] always
    auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;  // [oC]
    auto output  = OUTPUT_VARIABLE(0);                               // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW)

    REQUIRE_TRUE(input->rankOf() == 5, 0, "CUSTOM DECONV3D_MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
    REQUIRE_TRUE(weights->rankOf() == 5, 0, "CUSTOM DECONV3D_MKLDNN OP: rank of weights array must be equal to 5, but got %i instead !", weights->rankOf());

    int kD = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));  // filter(kernel) depth
    int kH = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));  // filter(kernel) height
    int kW = INT_ARG(2) > 0 ? INT_ARG(2) : static_cast<int>(weights->sizeAt(2));  // filter(kernel) width
    int sD = INT_ARG(3);                                                           // strides depth
    int sH = INT_ARG(4);                                                           // strides height
    int sW = INT_ARG(5);                                                           // strides width
    int pD = INT_ARG(6);                                                           // paddings depth
    int pH = INT_ARG(7);                                                           // paddings height
    int pW = INT_ARG(8);                                                           // paddings width
    int dD = INT_ARG(9);                                                           // dilations depth
    int dH = INT_ARG(10);                                                          // dilations height
    int dW = INT_ARG(11);                                                          // dilations width
    int isSameMode = INT_ARG(12);                                                  // 1-SAME, 0-VALID
    int isNCDHW = block.getIArguments()->size() > 13 ? !INT_ARG(13) : 1;          // INT_ARG(13): 1-NDHWC, 0-NCDHW

    int bS, iC, iD, iH, iW, oC, oD, oH, oW;          // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
    int indIOioC, indIOioD, indWoC, indWiC, indWkD;  // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);

    std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, oC, iC};
    REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV3D_MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
    if (bias)
        REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DECONV3D_MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());

    if (isSameMode) {  // SAME
        // Note: we're intentionally swapping iH and oH, to calculate the padding for a "normal" conv (not deconv) forward pass
        ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
    }

    // The permuted views below are owned by unique_ptr so they are released even if deconv3dMKLDNN throws
    // (e.g. on a oneDNN error); the previous manual delete leaked them on that path.

    // mkl supports only [oC, iC, kD, kH, kW] format for weights
    std::unique_ptr<NDArray> weightsMkl(new NDArray(weights->permute({3, 4, 0, 1, 2})));  // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]

    // mkl supports only NCDHW
    std::unique_ptr<NDArray> inputMkl, outputMkl;
    if (!isNCDHW) {
        inputMkl.reset(new NDArray(input->permute({0, 4, 1, 2, 3})));    // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
        outputMkl.reset(new NDArray(output->permute({0, 4, 1, 2, 3})));  // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]
    }

    deconv3dMKLDNN(isNCDHW ? input : inputMkl.get(), weightsMkl.get(), bias, isNCDHW ? output : outputMkl.get(),
                   kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW);

    return Status::OK();
}
2020-01-20 19:32:46 +01:00
PLATFORM_CHECK(deconv3d, ENGINE_CPU) {

    // Decides whether the oneDNN deconv3d implementation can handle this call:
    // mkl must be enabled for the block, dilations must be <= 1, SAME mode is not supported,
    // and the input/weights/bias/output dtypes must form one of the combinations listed below.

    auto input   = INPUT_VARIABLE(0);
    auto weights = INPUT_VARIABLE(1);
    auto bias    = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;
    auto output  = OUTPUT_VARIABLE(0);   // previously INPUT_VARIABLE(0), which made zType mirror the input dtype

    int dD = INT_ARG(9);           // dilations depth
    int dH = INT_ARG(10);          // dilations height
    int dW = INT_ARG(11);          // dilations width
    int isSameMode = INT_ARG(12);  // 1-SAME, 0-VALID

    const DataType xType = input->dataType();
    const DataType wType = weights->dataType();
    const DataType zType = output->dataType();
    const DataType bType = bias != nullptr ? bias->dataType() : zType;

    return block.isUseMKLDNN() && (dD <= 1 && dH <= 1 && dW <= 1 && !isSameMode) &&
           (
             (xType == DataType::FLOAT32 && wType == DataType::FLOAT32 && bType == DataType::FLOAT32 && zType == DataType::FLOAT32) ||
             ((xType == DataType::UINT8 || xType == DataType::INT8) && wType == DataType::INT8 && (zType == DataType::UINT8 || zType == DataType::INT8 || zType == DataType::INT32 || zType == DataType::FLOAT32) && bType == zType)
           );
}
//////////////////////////////////////////////////////////////////////////
2020-01-20 19:32:46 +01:00
PLATFORM_IMPL ( deconv3d_bp , ENGINE_CPU ) {
2019-11-03 11:37:19 +01:00
auto input = INPUT_VARIABLE ( 0 ) ; // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
auto weights = INPUT_VARIABLE ( 1 ) ; // [kD, kH, kW, oC, iC] always
auto bias = block . width ( ) > 3 ? INPUT_VARIABLE ( 2 ) : nullptr ; // [oC]
auto gradO = block . width ( ) > 3 ? INPUT_VARIABLE ( 3 ) : INPUT_VARIABLE ( 2 ) ; // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
auto gradI = OUTPUT_VARIABLE ( 0 ) ; // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), gradI
auto gradW = OUTPUT_VARIABLE ( 1 ) ; // [kD, kH, kW, oC, iC] always
auto gradB = block . width ( ) > 3 ? OUTPUT_VARIABLE ( 2 ) : nullptr ; // [oC]
REQUIRE_TRUE ( input - > rankOf ( ) = = 5 , 0 , " CUSTOM DECONV3D_MKLDNN_BP OP: rank of input array must be equal to 5, but got %i instead ! " , input - > rankOf ( ) ) ;
REQUIRE_TRUE ( weights - > rankOf ( ) = = 5 , 0 , " CUSTOM DECONV3D_MKLDNN_BP OP: rank of weights array must be equal to 5 , but got %i instead ! " , weights - > rankOf ( ) ) ;
REQUIRE_TRUE ( gradO - > rankOf ( ) = = 5 , 0 , " CUSTOM DECONV3D_MKLDNN_BP OP: rank of output gradients (next epsilon) array must be equal to 5, but got %i instead ! " , gradO - > rankOf ( ) ) ;
int kD = INT_ARG ( 0 ) > 0 ? INT_ARG ( 0 ) : static_cast < int > ( weights - > sizeAt ( 0 ) ) ; // filter(kernel) depth
int kH = INT_ARG ( 1 ) > 0 ? INT_ARG ( 1 ) : static_cast < int > ( weights - > sizeAt ( 1 ) ) ; // filter(kernel) height
int kW = INT_ARG ( 2 ) > 0 ? INT_ARG ( 2 ) : static_cast < int > ( weights - > sizeAt ( 2 ) ) ; // filter(kernel) width
int sD = INT_ARG ( 3 ) ; // strides depth
int sH = INT_ARG ( 4 ) ; // strides height
int sW = INT_ARG ( 5 ) ; // strides width
int pD = INT_ARG ( 6 ) ; // paddings depth
int pH = INT_ARG ( 7 ) ; // paddings height
int pW = INT_ARG ( 8 ) ; // paddings width
int dD = INT_ARG ( 9 ) ; // dilations depth
int dH = INT_ARG ( 10 ) ; // dilations height
int dW = INT_ARG ( 11 ) ; // dilations width
int isSameMode = INT_ARG ( 12 ) ; // 0-SAME, 1-VALID
int isNCDHW = block . getIArguments ( ) - > size ( ) > 13 ? ! INT_ARG ( 13 ) : 1 ; // INT_ARG(13): 1-NDHWC, 0-NCDHW
int bS , iC , iD , iH , iW , oC , oD , oH , oW ; // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
int indIOioC , indIOioD , indWoC , indWiC , indWkD ; // corresponding indexes
ConvolutionUtils : : getSizesAndIndexesConv3d ( isNCDHW , * input , * gradO , bS , iC , iD , iH , iW , oC , oD , oH , oW , indIOioC , indIOioD , indWoC , indWiC , indWkD ) ;
int trueoD , trueoH , trueoW ; // true output height, width
ConvolutionUtils : : calcOutSizeDeconv3D ( trueoD , trueoH , trueoW , kD , kH , kW , sD , sH , sW , pD , pH , pW , dD , dH , dW , iD , iH , iW , isSameMode ) ;
std : : vector < Nd4jLong > expectedGradOShape = ShapeUtils : : composeShapeUsingDimsAndIdx ( { bS , oC , trueoD , trueoH , trueoW , 0 , indIOioC , indIOioD , indIOioD + 1 , indIOioD + 2 } ) ;
std : : vector < Nd4jLong > expectedWeightsShape = { kD , kH , kW , oC , iC } ;
REQUIRE_TRUE ( gradO - > isSameShape ( expectedGradOShape ) , 0 , " CUSTOM DECONV3D_MKLDNN_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead ! " , ShapeUtils : : shapeAsString ( expectedGradOShape ) . c_str ( ) , ShapeUtils : : shapeAsString ( gradO ) . c_str ( ) ) ;
REQUIRE_TRUE ( weights - > isSameShape ( expectedWeightsShape ) , 0 , " CUSTOM DECONV3D_MKLDNN_BP OP: wrong shape of weights array, expected is %s, but got %s instead ! " , ShapeUtils : : shapeAsString ( expectedWeightsShape ) . c_str ( ) , ShapeUtils : : shapeAsString ( weights ) . c_str ( ) ) ;
if ( bias )
REQUIRE_TRUE ( bias - > rankOf ( ) < = 2 & & oC = = bias - > lengthOf ( ) , 0 , " CUSTOM DECONV3D_MKLDNN_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead ! " , oC , bias - > rankOf ( ) , bias - > lengthOf ( ) ) ;
if ( isSameMode ) // Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass
ConvolutionUtils : : calcPadding3D ( pD , pH , pW , iD , iH , iW , oD , oH , oW , kD , kH , kW , sD , sH , sW , dD , dH , dW ) ;
// mkl supports only [oC, iC, kD, kH, kW] for weights
weights = new NDArray ( weights - > permute ( { 3 , 4 , 0 , 1 , 2 } ) ) ; // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
gradW = new NDArray ( gradW - > permute ( { 3 , 4 , 0 , 1 , 2 } ) ) ; // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
// mkl supports NCDHW format only
if ( ! isNCDHW ) {
input = new NDArray ( input - > permute ( { 0 , 4 , 1 , 2 , 3 } ) ) ; // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
gradI = new NDArray ( gradI - > permute ( { 0 , 4 , 1 , 2 , 3 } ) ) ; // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
gradO = new NDArray ( gradO - > permute ( { 0 , 4 , 1 , 2 , 3 } ) ) ; // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]
}
2020-01-11 05:36:40 +01:00
deconv3dBackPropMKLDNN ( input , weights , gradO , gradI , gradW , gradB , kD , kH , kW , sD , sH , sW , pD , pH , pW , dD , dH , dW ) ;
2019-11-03 11:37:19 +01:00
delete weights ;
delete gradW ;
if ( ! isNCDHW ) {
delete input ;
delete gradI ;
delete gradO ;
}
return Status : : OK ( ) ;
}
2020-01-20 19:32:46 +01:00
// Decides whether deconv3d_bp may be dispatched to the MKL-DNN implementation.
// Returns true only when all of the following hold:
//   - MKL-DNN use is enabled for this block (block.isUseMKLDNN()),
//   - every dilation (dD, dH, dW) is <= 1,
//   - SAME padding is not requested (isSameMode == 0),
//   - input, weights, gradO, gradI, gradW and, when present, gradB are FLOAT32 or BFLOAT16.
// The optional bias input (present when block.width() > 3) is not inspected here.
PLATFORM_CHECK(deconv3d_bp, ENGINE_CPU) {

    auto input   = INPUT_VARIABLE(0);                                          // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
    auto weights = INPUT_VARIABLE(1);                                          // [kD, kH, kW, oC, iC] always
    auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);  // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
    auto gradI   = OUTPUT_VARIABLE(0);                                         // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), gradI
    auto gradW   = OUTPUT_VARIABLE(1);                                         // [kD, kH, kW, oC, iC] always
    auto gradB   = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;           // [oC], nullptr when there is no bias input

    const int dD = INT_ARG(9);            // dilations depth
    const int dH = INT_ARG(10);           // dilations height
    const int dW = INT_ARG(11);           // dilations width
    const int isSameMode = INT_ARG(12);   // nonzero - SAME, 0 - VALID

    // Only non-dilated, VALID-padding configurations are routed to MKL-DNN.
    const bool supportedGeometry = dD <= 1 && dH <= 1 && dW <= 1 && !isSameMode;

    // Every participating array must be FLOAT32 or BFLOAT16.
    auto isSupportedType = [](const DataType type) {
        return type == DataType::FLOAT32 || type == DataType::BFLOAT16;
    };

    // A missing gradB is treated as FLOAT32 so that its absence never vetoes the MKL-DNN path.
    const DataType gradBType = gradB != nullptr ? gradB->dataType() : DataType::FLOAT32;

    const bool supportedTypes = isSupportedType(input->dataType())   &&
                                isSupportedType(weights->dataType()) &&
                                isSupportedType(gradO->dataType())   &&
                                isSupportedType(gradI->dataType())   &&
                                isSupportedType(gradW->dataType())   &&
                                isSupportedType(gradBType);

    return block.isUseMKLDNN() && supportedGeometry && supportedTypes;
}
}
}
}