/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 * Copyright (c) 2019 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
//
// @author saudet
// @author raver119@gmail.com
// @author Yurii Shyrma (iuriish@yahoo.com)
//

#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <platform_boilerplate.h>
#include <helpers/MKLDNNStream.h>
#include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>
#include <NDArrayFactory.h>

namespace nd4j {
namespace ops {
namespace platforms {

//////////////////////////////////////////////////////////////////////////
static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray* weights, const float epsilon, NDArray* z) {

    // unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any),
    // and it gives wrong results for the nhwc and ndhwc formats
    // x        -> 2D:nc, 4D:nchw, 5D:ncdhw
    // mean     -> 1D [c]
    // variance -> 1D [c]
    // weights  -> 2D [2, c]; weights({0,1, 0,0}) contains gamma and weights({1,2, 0,0}) contains beta
    // z (output) - same shape as x
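    // e.g. for c == 3 the packed weights layout described above is:
    //      [ [gamma_0, gamma_1, gamma_2],
    //        [beta_0,  beta_1,  beta_2 ] ]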
    const int xRank = x->rankOf();

    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());

    // input type
    dnnl::memory::data_type type = dnnl::memory::data_type::f32;

    // indicates whether gamma and/or beta are given
    auto flags = dnnl::normalization_flags::use_global_stats; // don't calculate the mean and variance for each mini-batch
    if (weights != nullptr)
        flags |= dnnl::normalization_flags::use_scale_shift;
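
    // with use_global_stats the primitive consumes the externally supplied mean and variance
    // (passed below via DNNL_ARG_MEAN / DNNL_ARG_VARIANCE) rather than computing batch statistics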

    dnnl::memory::dims dims;
    dnnl::memory::format_tag format;

    if (xRank == 2) {
        dims = {x->sizeAt(0), x->sizeAt(1)};
        format = dnnl::memory::format_tag::nc;
    }
    else if (xRank == 4) {
        dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3)};
        format = dnnl::memory::format_tag::nchw;
    }
    else { // xRank = 5
        dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3), x->sizeAt(4)};
        format = dnnl::memory::format_tag::ncdhw;
    }

    // memory descriptors for arrays

    // x
    dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(dims, type, format);
    dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format);
    x_user_md.data.format_kind = dnnl_blocked; // overrides format
    x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0];
    x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1];
    if (xRank > 2) {
        x_user_md.data.format_desc.blocking.strides[2] = x->stridesOf()[2];
        x_user_md.data.format_desc.blocking.strides[3] = x->stridesOf()[3];
    }
    if (xRank > 4)
        x_user_md.data.format_desc.blocking.strides[4] = x->stridesOf()[4];

    // z, output
    dnnl::memory::desc z_mkl_md  = dnnl::memory::desc(dims, type, format);
    dnnl::memory::desc z_user_md = dnnl::memory::desc(dims, type, format);
    z_user_md.data.format_kind = dnnl_blocked; // overrides format
    z_user_md.data.format_desc.blocking.strides[0] = z->stridesOf()[0];
    z_user_md.data.format_desc.blocking.strides[1] = z->stridesOf()[1];
    if (xRank > 2) {
        z_user_md.data.format_desc.blocking.strides[2] = z->stridesOf()[2];
        z_user_md.data.format_desc.blocking.strides[3] = z->stridesOf()[3];
    }
    if (xRank > 4)
        z_user_md.data.format_desc.blocking.strides[4] = z->stridesOf()[4];

    // batchnorm forward description
    dnnl::batch_normalization_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, x_mkl_md, epsilon, flags);
    dnnl::batch_normalization_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
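
    // note: the primitive descriptor is free to pick an implementation-preferred (possibly blocked)
    // layout for src/dst, which is why each user memory below is compared against the primitive's
    // descriptor and reordered when the two differ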

    // arguments (memory buffers) necessary for calculations
    std::unordered_map<int, dnnl::memory> args;

    dnnl::stream stream(engine);

    // provide memory buffers and check whether reorder is required

    // x
    auto x_user_mem = dnnl::memory(x_user_md, engine, x->getBuffer());
    const bool xReorder = op_ff_prim_desc.src_desc() != x_user_mem.get_desc();
    auto x_mkl_mem = xReorder ? dnnl::memory(op_ff_prim_desc.src_desc(), engine) : x_user_mem;
    if (xReorder)
        dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
    args[DNNL_ARG_SRC] = x_mkl_mem;

    // z
    auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer());
    const bool zReorder = op_ff_prim_desc.dst_desc() != z_user_mem.get_desc();
    auto z_mkl_mem = zReorder ? dnnl::memory(op_ff_prim_desc.dst_desc(), engine) : z_user_mem;
    if (zReorder)
        dnnl::reorder(z_user_mem, z_mkl_mem).execute(stream, z_user_mem, z_mkl_mem);
    args[DNNL_ARG_DST] = z_mkl_mem;

    // mean
    auto mean_mkl_mem = dnnl::memory(op_ff_prim_desc.mean_desc(), engine, mean->getBuffer());
    args[DNNL_ARG_MEAN] = mean_mkl_mem;

    // variance
    auto var_mkl_mem = dnnl::memory(op_ff_prim_desc.variance_desc(), engine, variance->getBuffer());
    args[DNNL_ARG_VARIANCE] = var_mkl_mem;

    // gamma and beta if they are present
    if (weights != nullptr) {
        auto w_mkl_mem = dnnl::memory(op_ff_prim_desc.weights_desc(), engine, weights->getBuffer());
        args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
    }

    // run calculations
    dnnl::batch_normalization_forward(op_ff_prim_desc).execute(stream, args);

    // reorder the output if necessary
    if (zReorder)
        dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);

    stream.wait();

    // shape::printArray(z_mkl_mem.map_data<float>(),8);
}
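
// A minimal usage sketch for the helper above (illustrative only, not part of the op; the
// NDArrayFactory call mirrors its use elsewhere in this file, and shapes follow the layout notes):
//
//     NDArray x    = NDArrayFactory::create<float>('c', {2, 3, 4, 4}); // 4D nchw input
//     NDArray mean = NDArrayFactory::create<float>('c', {3});          // per-channel mean
//     NDArray var  = NDArrayFactory::create<float>('c', {3});          // per-channel variance
//     NDArray z    = NDArrayFactory::create<float>('c', {2, 3, 4, 4}); // output, same shape as x
//     batchnormMKLDNN(&x, &mean, &var, nullptr /*no gamma/beta*/, 1e-5f, &z);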

//////////////////////////////////////////////////////////////////////////
static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray* dLdO, const NDArray* weights,
                                    const float epsilon, NDArray* dLdI, NDArray* dLdW) {

    // unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any),
    // and it gives wrong results for the nhwc and ndhwc formats
    // x        -> 2D:nc, 4D:nchw, 5D:ncdhw
    // mean     -> 1D [c]
    // variance -> 1D [c]
    // dLdO     -> same shape as x
    // weights  -> 2D [2, c]; weights({0,1, 0,0}) contains gamma and weights({1,2, 0,0}) contains beta
    // dLdI     -> same shape as x
    // dLdW     -> same shape as weights; dLdW({0,1, 0,0}) contains grad_gamma and dLdW({1,2, 0,0}) contains grad_beta
    const int xRank = x->rankOf();

    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());

    // input type
    dnnl::memory::data_type type = dnnl::memory::data_type::f32;

    // indicates whether gamma and/or beta are given
    auto flags = dnnl::normalization_flags::use_global_stats; // don't calculate the mean and variance for each mini-batch
    if (weights != nullptr)
        flags |= dnnl::normalization_flags::use_scale_shift;

    dnnl::memory::dims dims;
    dnnl::memory::format_tag format;

    if (xRank == 2) {
        dims = {x->sizeAt(0), x->sizeAt(1)};
        format = dnnl::memory::format_tag::nc;
    }
    else if (xRank == 4) {
        dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3)};
        format = dnnl::memory::format_tag::nchw;
    }
    else { // xRank = 5
        dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3), x->sizeAt(4)};
        format = dnnl::memory::format_tag::ncdhw;
    }

    // memory descriptors for arrays

    // x
    dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(dims, type, format);
    dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format);
    x_user_md.data.format_kind = dnnl_blocked; // overrides format
    x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0];
    x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1];
    if (xRank > 2) {
        x_user_md.data.format_desc.blocking.strides[2] = x->stridesOf()[2];
        x_user_md.data.format_desc.blocking.strides[3] = x->stridesOf()[3];
    }
    if (xRank > 4)
        x_user_md.data.format_desc.blocking.strides[4] = x->stridesOf()[4];
    // dLdO
    dnnl::memory::desc dLdO_mkl_md  = dnnl::memory::desc(dims, type, format);
    dnnl::memory::desc dLdO_user_md = dnnl::memory::desc(dims, type, format);
    dLdO_user_md.data.format_kind = dnnl_blocked; // overrides format
    dLdO_user_md.data.format_desc.blocking.strides[0] = dLdO->stridesOf()[0];
    dLdO_user_md.data.format_desc.blocking.strides[1] = dLdO->stridesOf()[1];
    if (xRank > 2) {
        dLdO_user_md.data.format_desc.blocking.strides[2] = dLdO->stridesOf()[2];
        dLdO_user_md.data.format_desc.blocking.strides[3] = dLdO->stridesOf()[3];
    }
    if (xRank > 4)
        dLdO_user_md.data.format_desc.blocking.strides[4] = dLdO->stridesOf()[4];

    // dLdI
    dnnl::memory::desc dLdI_mkl_md  = dnnl::memory::desc(dims, type, format);
    dnnl::memory::desc dLdI_user_md = dnnl::memory::desc(dims, type, format);
    dLdI_user_md.data.format_kind = dnnl_blocked; // overrides format
    dLdI_user_md.data.format_desc.blocking.strides[0] = dLdI->stridesOf()[0];
    dLdI_user_md.data.format_desc.blocking.strides[1] = dLdI->stridesOf()[1];
    if (xRank > 2) {
        dLdI_user_md.data.format_desc.blocking.strides[2] = dLdI->stridesOf()[2];
        dLdI_user_md.data.format_desc.blocking.strides[3] = dLdI->stridesOf()[3];
    }
    if (xRank > 4)
        dLdI_user_md.data.format_desc.blocking.strides[4] = dLdI->stridesOf()[4];

    // batchnorm forward description
    dnnl::batch_normalization_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, x_mkl_md, epsilon, flags);
    dnnl::batch_normalization_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);

    // batchnorm backprop description
    dnnl::batch_normalization_backward::desc op_bp_desc(dnnl::prop_kind::backward, dLdO_mkl_md, x_mkl_md, epsilon, flags);
    dnnl::batch_normalization_backward::primitive_desc op_bp_prim_desc(op_bp_desc, engine, op_ff_prim_desc);
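
    // note: the forward primitive descriptor is passed to the backward one as a hint, so that
    // both agree on the memory layouts chosen by the library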

    // arguments (memory buffers) necessary for calculations
    std::unordered_map<int, dnnl::memory> args;

    dnnl::stream stream(engine);

    // provide memory buffers and check whether reorder is required

    // x
    auto x_user_mem = dnnl::memory(x_user_md, engine, x->getBuffer());
    const bool xReorder = op_bp_prim_desc.src_desc() != x_user_mem.get_desc();
    auto x_mkl_mem = xReorder ? dnnl::memory(op_bp_prim_desc.src_desc(), engine) : x_user_mem;
    if (xReorder)
        dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
    args[DNNL_ARG_SRC] = x_mkl_mem;

    // dLdO
    auto dLdO_user_mem = dnnl::memory(dLdO_user_md, engine, dLdO->getBuffer());
    const bool dLdOReorder = op_bp_prim_desc.diff_dst_desc() != dLdO_user_mem.get_desc();
    auto dLdO_mkl_mem = dLdOReorder ? dnnl::memory(op_bp_prim_desc.diff_dst_desc(), engine) : dLdO_user_mem;
    if (dLdOReorder)
        dnnl::reorder(dLdO_user_mem, dLdO_mkl_mem).execute(stream, dLdO_user_mem, dLdO_mkl_mem);
    args[DNNL_ARG_DIFF_DST] = dLdO_mkl_mem;

    // mean
    auto mean_mkl_mem = dnnl::memory(op_bp_prim_desc.mean_desc(), engine, mean->getBuffer());
    args[DNNL_ARG_MEAN] = mean_mkl_mem;

    // variance
    auto var_mkl_mem = dnnl::memory(op_bp_prim_desc.variance_desc(), engine, variance->getBuffer());
    args[DNNL_ARG_VARIANCE] = var_mkl_mem;

    // dLdI
    auto dLdI_user_mem = dnnl::memory(dLdI_user_md, engine, dLdI->getBuffer());
    const bool dLdIReorder = op_bp_prim_desc.diff_src_desc() != dLdI_user_mem.get_desc();
    auto dLdI_mkl_mem = dLdIReorder ? dnnl::memory(op_bp_prim_desc.diff_src_desc(), engine) : dLdI_user_mem;
    args[DNNL_ARG_DIFF_SRC] = dLdI_mkl_mem;

    // gamma and beta (and their gradients) if they are present
    if (weights != nullptr) {
        auto w_mkl_mem = dnnl::memory(op_bp_prim_desc.weights_desc(), engine, weights->getBuffer());
        args[DNNL_ARG_WEIGHTS] = w_mkl_mem;

        auto dLdW_mkl_mem = dnnl::memory(op_bp_prim_desc.weights_desc(), engine, dLdW->getBuffer());
        args[DNNL_ARG_DIFF_WEIGHTS] = dLdW_mkl_mem;
    }

    // run calculations
    dnnl::batch_normalization_backward(op_bp_prim_desc).execute(stream, args);

    // reorder the output if necessary
    if (dLdIReorder)
        dnnl::reorder(dLdI_mkl_mem, dLdI_user_mem).execute(stream, dLdI_mkl_mem, dLdI_user_mem);

    stream.wait();

    // shape::printArray(dLdI_mkl_mem.map_data<float>(),8);

    // notations:
    // f = g * (gamma * ((x - m) / (v + eps)^0.5) + beta) -> means dLdO * ff_output
    // g = dLdO
    // stdInv = 1 / (v + eps)^0.5
    // N = batch size (product of spatial dimensions)

    // formula for the full derivative with respect to the input (x):
    // dLdI = dfdx + dfdm*dmdx + dfdv*(dvdm*dmdx + dvdx)

    // !!! MKL CALCULATES ONLY THE FIRST TERM dfdx, SO WE HAVE TO CALCULATE THE REMAINING TERM (dfdm*dmdx + dfdv*(dvdm*dmdx + dvdx)) OURSELVES !!!

    // dfdm = -gamma*stdInv*g_sum
    // dmdx = 1/N
    // dvdx = 2*(x - m)/N
    // dvdm = -2*[(x - m)]_sum/N
    // dfdv = -0.5*[g*(x - m)]_sum*stdInv^3; gamma is dropped here for convenience and multiplied back in at the end

    // finally:
    // dLdI = dfdm/N + (2/N)*dfdv*(dvdm/2 + (x - m))
    // dLdI = gamma * ( stdInv*(-g_sum/N) + (2/N)*dfdv*(dvdm/2 + (x - m)) )
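
    // note on the scaling conventions in the code below: the variable dfdm already carries the
    // extra 1/N factor, dvdm actually holds dvdm/2, and dfdv holds (2/N)*dfdv, so the three
    // pieces combine exactly as in the last formula above; xMinusMean is reused as the
    // accumulator for the whole correction term, which is finally added to the mkl-produced dLdI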
    std::vector<int> axes = {1};
    const auto excludedAxes = ShapeUtils::evalDimsToExclude(x->rankOf(), axes);

    // inverse batch size 1/N
    const auto Ninv = 1.f * mean->lengthOf() / x->lengthOf();

    // x - mean
    NDArray xMinusMean(x); // empty array with same shape as x
    const_cast<NDArray*>(x)->applyBroadcast(nd4j::broadcast::Subtract, axes, *mean, xMinusMean);

    // stdInv
    NDArray stdInv = *variance + epsilon;
    stdInv.applyTransform(transform::Reciprocal, stdInv); // 1 / (variance + epsilon)
    stdInv.applyTransform(transform::Sqrt, stdInv);       // 1 / (variance + epsilon)^0.5

    // dfdm / N
    auto dfdm = dLdO->reduceAlongDimension(nd4j::reduce::Sum, excludedAxes);
    dfdm *= stdInv;
    dfdm *= -Ninv;

    // dvdm / 2
    NDArray dvdm(mean); // empty array with same shape as mean
    xMinusMean.reduceAlongDimension(nd4j::reduce::Sum, dvdm, excludedAxes);
    dvdm *= -Ninv;

    // (2/N)*dfdv
    NDArray dfdv(variance); // empty array with same shape as variance
    (xMinusMean * *dLdO).reduceAlongDimension(nd4j::reduce::Sum, dfdv, excludedAxes);
    dfdv *= stdInv * stdInv * stdInv;
    dfdv *= -Ninv;

    // dvdm/2 + (x - m)
    xMinusMean.applyBroadcast(nd4j::broadcast::Add, axes, dvdm, xMinusMean);
    // dfdv * (dvdm/2 + (x - m))
    xMinusMean.applyBroadcast(nd4j::broadcast::Multiply, axes, dfdv, xMinusMean);
    // add dfdm / N
    xMinusMean.applyBroadcast(nd4j::broadcast::Add, axes, dfdm, xMinusMean);

    // * gamma
    auto gamma = (*weights)({0,1, 0,0});
    xMinusMean.applyBroadcast(nd4j::broadcast::Multiply, axes, gamma, xMinusMean);

    *dLdI += xMinusMean;
}

//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(batchnorm) {

    auto input    = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw, 5D:ncdhw
    auto mean     = INPUT_VARIABLE(1); // [c]
    auto variance = INPUT_VARIABLE(2); // [c]
    NDArray* gamma = nullptr;          // [c]
    NDArray* beta  = nullptr;          // [c]

    auto output = OUTPUT_VARIABLE(0);  // same shape as input

    const bool applyScale  = (bool) INT_ARG(0);
    const bool applyOffset = (bool) INT_ARG(1);
    const double epsilon   = T_ARG(0);

    if (applyScale)
        gamma = INPUT_VARIABLE(3);
    if (applyOffset)
        beta = INPUT_VARIABLE(3 + (int)applyScale);

    const int numOfIntArgs = block.getIArguments()->size();
    const int inRank = input->rankOf();

    // get axes args to normalize input array over
    std::vector<int> axes;
    if (numOfIntArgs > 2)
        for (int i = 2; i < numOfIntArgs; ++i)
            axes.push_back(INT_ARG(i));
    else
        axes.push_back(inRank - 1); // default dimension to reduce along is the last one

    const int numOfAxes = axes.size();

    REQUIRE_TRUE(numOfAxes == 1, 0, "BATCHNORM_MKLDNN op: mkl dnn library supports only one axis which represents channel dimension, but got %i axes instead!", numOfAxes);
    REQUIRE_TRUE(inRank == 2 || inRank == 4 || inRank == 5, 0, "BATCHNORM_MKLDNN op: possible values for rank of input array are 2, 4 or 5, but got %i instead!", inRank);
    REQUIRE_TRUE(mean->rankOf() == 1 && mean->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_MKLDNN op: wrong shape of mean array, expected is [%lld], but got %s instead!", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(mean).c_str());
    REQUIRE_TRUE(variance->rankOf() == 1 && variance->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_MKLDNN op: wrong shape of variance array, expected is [%lld], but got %s instead!", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(variance).c_str());
    if (gamma != nullptr)
        REQUIRE_TRUE(gamma->rankOf() == 1 && gamma->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_MKLDNN op: wrong shape of gamma array, expected is [%lld], but got %s instead!", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(gamma).c_str());
    if (beta != nullptr)
        REQUIRE_TRUE(beta->rankOf() == 1 && beta->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_MKLDNN op: wrong shape of beta array, expected is [%lld], but got %s instead!", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(beta).c_str());

    // types of all input arrays should be the same
    for (int i = 1; i < block.width() - 1; ++i)
        REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_MKLDNN op: types of all input arrays should be the same!");

    NDArray* weights = nullptr;

    if (applyScale || applyOffset) {
        weights = new NDArray(input->ordering(), {2, input->sizeAt(axes[0])}, input->dataType());
        if (applyScale)
            (*weights)({0,1, 0,0}).assign(gamma);
        else
            (*weights)({0,1, 0,0}).assign(1);
        if (applyOffset)
            (*weights)({1,2, 0,0}).assign(beta);
        else
            (*weights)({1,2, 0,0}).assign(0);
    }
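
    // when only one of scale/offset is requested, the other row of the packed array is filled with
    // its identity value (gamma = 1, beta = 0), so mkl dnn always sees a complete [2, c] weights array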

    if (axes[0] == inRank - 1 && inRank > 2) { // if nhwc or ndhwc
        std::vector<int> permut = inRank == 4 ? std::vector<int>({0,3,1,2}) : std::vector<int>({0,4,1,2,3});
        input  = new NDArray(input->permute(permut));
        output = new NDArray(output->permute(permut));
    }
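    // permute produces a view over the same buffer with permuted strides (no copy), so the
    // helper below sees nchw/ncdhw data; the temporary NDArray wrappers are deleted afterwards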

    batchnormMKLDNN(input, mean, variance, weights, epsilon, output);

    delete weights;

    if (axes[0] == inRank - 1 && inRank > 2) {
        delete input;
        delete output;
    }

    return Status::OK();
}

//////////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(batchnorm) {
    // we don't want to use mkldnn if cpu doesn't support avx/avx2
    // if (::optimalLevel() < 2)
    //     return false;

    auto input    = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw, 5D:ncdhw
    auto mean     = INPUT_VARIABLE(1); // [c]
    auto variance = INPUT_VARIABLE(2); // [c]
    NDArray* gamma = nullptr;          // [c]
    NDArray* beta  = nullptr;          // [c]

    auto output = OUTPUT_VARIABLE(0);  // same shape as input

    const bool applyScale  = (bool) INT_ARG(0);
    const bool applyOffset = (bool) INT_ARG(1);

    if (applyScale)
        gamma = INPUT_VARIABLE(3);
    if (applyOffset)
        beta = INPUT_VARIABLE(3 + (int)applyScale);

    const int numOfIntArgs = block.getIArguments()->size();

    std::vector<int> axes;
    if (numOfIntArgs > 2)
        for (int i = 2; i < numOfIntArgs; ++i)
            axes.push_back(INT_ARG(i));
    else
        axes.push_back(input->rankOf() - 1); // default dimension to reduce along is the last one

    DataType inputType = input->dataType();
    DataType meanType  = mean->dataType();
    DataType varType   = variance->dataType();
    DataType gammaType = gamma != nullptr ? gamma->dataType() : DataType::FLOAT32;
    DataType betaType  = beta  != nullptr ? beta->dataType()  : DataType::FLOAT32;
    DataType outType   = output->dataType();

    const int inRank = input->rankOf();

    return block.isUseMKLDNN() && axes.size() == 1 && (axes[0] == 1 || axes[0] == inRank - 1) && (inRank == 2 || inRank == 4 || inRank == 5) &&
           (inputType == DataType::FLOAT32 && meanType == DataType::FLOAT32 && varType == DataType::FLOAT32 &&
            gammaType == DataType::FLOAT32 && betaType == DataType::FLOAT32 && outType == DataType::FLOAT32);
}
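
// note: the check above accepts both channel-first (axes[0] == 1) and channel-last
// (axes[0] == inRank - 1) setups; the implementation handles the channel-last case by
// permuting the arrays to nchw/ncdhw before calling the mkl dnn helper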
//////////////////////////////////////////////////////////////////////////
// PLATFORM_IMPL(batchnorm) {
// auto input = INPUT_VARIABLE(0);
// auto mean = INPUT_VARIABLE(1);
// auto variance = INPUT_VARIABLE(2);
// NDArray *gamma = nullptr;
// NDArray *beta = nullptr;
// auto output = OUTPUT_VARIABLE(0);
// const bool applyScale = (bool) INT_ARG(0);
// const bool applyOffset = (bool) INT_ARG(1);
// const double epsilon = T_ARG(0);
// if (applyScale)
// gamma = INPUT_VARIABLE(3);
// if (applyOffset)
// beta = INPUT_VARIABLE(3 + static_cast<int>(applyScale));
// std::vector<int> axes;
// if (block.numI() > 2)
// for (int i = 2; i < block.numI(); ++i)
// axes.push_back(INT_ARG(i));
// else
// axes.push_back(input->rankOf() - 1);
// std::vector<Nd4jLong> shape({2, mean->lengthOf()});
// NDArray weights = NDArrayFactory::create<float>('c', shape, block.launchContext());
// weights({0, 1, 0, 0}).assign(1.0f);
// weights({1, 2, 0, 0}).assign(0.0f);
// mkldnn_memory_desc_t empty;
// dnnl::memory::desc batchnorm_src_md(empty), batchnorm_dst_md(empty), user_src_md(empty), user_dst_md(empty);
// auto flag = dnnl::normalization_flags::use_global_stats;
// if (applyScale || applyOffset)
// flag |= dnnl::normalization_flags::use_scale_shift;
// mkldnnUtils::getMKLDNNMemoryDescBatchNorm(input, nullptr, output,
// &batchnorm_src_md, nullptr, &batchnorm_dst_md,
// &user_src_md, nullptr, &user_dst_md, axes[0]);
// auto batchnorm_desc = dnnl::batch_normalization_forward::desc(dnnl::prop_kind::forward_inference, batchnorm_src_md, epsilon, flag);
// auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
// dnnl::stream stream(engine);
// auto batchnorm_prim_desc = dnnl::batch_normalization_forward::primitive_desc(batchnorm_desc, engine);
// auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
// auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
// auto batchnorm_mean_memory = dnnl::memory(batchnorm_prim_desc.mean_desc(), engine,
// mean->buffer());
// auto batchnorm_variance_memory = dnnl::memory(batchnorm_prim_desc.variance_desc(), engine,
// variance->buffer());
// auto batchnorm_src_memory = user_src_memory;
// dnnl::memory m(batchnorm_src_md, engine);
// if (m.get_desc() != user_src_memory.get_desc()) {
// batchnorm_src_memory = dnnl::memory(batchnorm_src_md, engine);
// dnnl::reorder(user_src_memory, batchnorm_src_memory).execute(stream, user_src_memory,
// batchnorm_src_memory);
// }
// auto batchnorm_dst_memory = user_dst_memory;
// if (batchnorm_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
// batchnorm_dst_memory = dnnl::memory(batchnorm_prim_desc.dst_desc(), engine);
// }
// if (applyScale || applyOffset) {
// if (gamma != nullptr) {
// weights({0, 1, 0, 0}).assign(gamma);
// }
// if (beta != nullptr) {
// weights({1, 2, 0, 0}).assign(beta);
// }
// auto batchnorm_weights_memory = dnnl::memory(batchnorm_prim_desc.weights_desc(), engine, weights.buffer());
// dnnl::batch_normalization_forward(batchnorm_prim_desc).execute(stream,
// {{MKLDNN_ARG_SRC, batchnorm_src_memory},
// {MKLDNN_ARG_MEAN, batchnorm_mean_memory},
// {MKLDNN_ARG_VARIANCE, batchnorm_variance_memory},
// {MKLDNN_ARG_WEIGHTS, batchnorm_weights_memory},
// {MKLDNN_ARG_DST, batchnorm_dst_memory}});
// } else {
// dnnl::batch_normalization_forward(batchnorm_prim_desc).execute(stream,
// {{MKLDNN_ARG_SRC, batchnorm_src_memory},
// {MKLDNN_ARG_MEAN, batchnorm_mean_memory},
// {MKLDNN_ARG_VARIANCE, batchnorm_variance_memory},
// {MKLDNN_ARG_DST, batchnorm_dst_memory}});
// }
// if (batchnorm_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
// dnnl::reorder(batchnorm_dst_memory, user_dst_memory).execute(stream, batchnorm_dst_memory,
// user_dst_memory);
// }
// stream.wait();
// return Status::OK();
// }
//////////////////////////////////////////////////////////////////////////
// PLATFORM_CHECK(batchnorm) {
// // we don't want to use mkldnn if cpu doesn't support avx/avx2
// if (::optimalLevel() < 2)
// return false;
// auto input = INPUT_VARIABLE(0);
// auto mean = INPUT_VARIABLE(1);
// auto variance = INPUT_VARIABLE(2);
// NDArray *gamma = nullptr;
// NDArray *beta = nullptr;
// auto output = OUTPUT_VARIABLE(0);
// const bool applyScale = (bool) INT_ARG(0);
// const bool applyOffset = (bool) INT_ARG(1);
// const double epsilon = T_ARG(0);
// if (applyScale)
// gamma = INPUT_VARIABLE(3);
// if (applyOffset)
// beta = INPUT_VARIABLE(3 + static_cast<int>(applyScale));
// std::vector<int> axes;
// if (block.numI() > 2)
// for (int i = 2; i < block.numI(); ++i)
// axes.push_back(INT_ARG(i));
// else
// axes.push_back(input->rankOf() - 1);
// return block.isUseMKLDNN() &&
// nd4j::MKLDNNStream::isSupported({input, mean, variance, gamma, beta, output}) &&
// axes.size() == 1;
// }
//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(batchnorm_bp) {

    NDArray* input    = INPUT_VARIABLE(0);                 // 2D:nc, 4D:nchw, 5D:ncdhw
    NDArray* mean     = INPUT_VARIABLE(1);                 // [c]
    NDArray* variance = INPUT_VARIABLE(2);                 // [c]
    NDArray* gamma    = nullptr;                           // [c]
    NDArray* beta     = nullptr;                           // [c]
    NDArray* dLdO     = INPUT_VARIABLE(block.width() - 1); // same shape as input

    NDArray* dLdI = OUTPUT_VARIABLE(0);                    // same shape as input
    NDArray* dLdM = OUTPUT_VARIABLE(1);                    // [c]
    NDArray* dLdV = OUTPUT_VARIABLE(2);                    // [c]
    NDArray* dLdG = nullptr;                               // [c]
    NDArray* dLdB = nullptr;                               // [c]

    const bool applyScale  = (bool) INT_ARG(0);
    const bool applyOffset = (bool) INT_ARG(1);
    const float epsilon    = T_ARG(0);

    if (applyScale) {
        gamma = INPUT_VARIABLE(3);
        dLdG  = OUTPUT_VARIABLE(3);
    }
    if (applyOffset) {
        beta = INPUT_VARIABLE(3 + (int)applyScale);
        dLdB = OUTPUT_VARIABLE(3 + (int)applyScale);
    }

    const int numOfIntArgs = block.getIArguments()->size();
    const int inRank = input->rankOf();

    // get axes args to normalize input array over
    std::vector<int> axes;
    if (numOfIntArgs > 2)
        for (int i = 2; i < numOfIntArgs; ++i)
            axes.push_back(INT_ARG(i));
    else
        axes.push_back(inRank - 1); // default dimension to reduce along is the last one

    const int numOfAxes = axes.size();

    REQUIRE_TRUE(numOfAxes == 1, 0, "BATCHNORM_BP_MKLDNN op: mkl dnn library supports only one axis which represents channel dimension, but got %i axes instead!", numOfAxes);
    REQUIRE_TRUE(inRank == 2 || inRank == 4 || inRank == 5, 0, "BATCHNORM_BP_MKLDNN op: possible values for rank of input array are 2, 4 or 5, but got %i instead!", inRank);
    REQUIRE_TRUE(input->isSameShape(dLdO), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of gradients array, expected is %s, but got %s instead!", ShapeUtils::shapeAsString(input).c_str(), ShapeUtils::shapeAsString(dLdO).c_str());
    REQUIRE_TRUE(mean->rankOf() == 1 && mean->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of mean array, expected is [%lld], but got %s instead!", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(mean).c_str());
    REQUIRE_TRUE(variance->rankOf() == 1 && variance->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of variance array, expected is [%lld], but got %s instead!", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(variance).c_str());
    if (gamma != nullptr)
        REQUIRE_TRUE(gamma->rankOf() == 1 && gamma->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of gamma array, expected is [%lld], but got %s instead!", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(gamma).c_str());
    if (beta != nullptr)
        REQUIRE_TRUE(beta->rankOf() == 1 && beta->sizeAt(0) == input->sizeAt(axes[0]), 0, "BATCHNORM_BP_MKLDNN op: wrong shape of beta array, expected is [%lld], but got %s instead!", input->sizeAt(axes[0]), ShapeUtils::shapeAsString(beta).c_str());

    // types of all input arrays should be the same
    for (int i = 1; i < block.width() - 1; ++i)
        REQUIRE_TRUE(INPUT_VARIABLE(0)->dataType() == INPUT_VARIABLE(i)->dataType(), 0, "BATCHNORM_BP_MKLDNN op: types of all input arrays should be the same!");

    NDArray *weights = nullptr, *dLdW = nullptr;

    if (applyScale || applyOffset) {
        weights = new NDArray(input->ordering(), {2, input->sizeAt(axes[0])}, input->dataType());
        dLdW    = new NDArray(input->ordering(), {2, input->sizeAt(axes[0])}, input->dataType());
        if (applyScale)
            (*weights)({0,1, 0,0}).assign(gamma);
        else
            (*weights)({0,1, 0,0}).assign(1);
        if (applyOffset)
            (*weights)({1,2, 0,0}).assign(beta);
        else
            (*weights)({1,2, 0,0}).assign(0);
    }

    if (axes[0] == inRank - 1 && inRank > 2) { // if nhwc or ndhwc
        std::vector<int> permut = inRank == 4 ? std::vector<int>({0,3,1,2}) : std::vector<int>({0,4,1,2,3});
        input = new NDArray(input->permute(permut));
        dLdO  = new NDArray(dLdO->permute(permut));
        dLdI  = new NDArray(dLdI->permute(permut));
    }

    batchnormBackPropMKLDNN(input, mean, variance, dLdO, weights, epsilon, dLdI, dLdW);
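
    // gradients with respect to mean and variance are not produced by this path, so both
    // outputs are simply zeroed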
    *dLdM = 0;
    *dLdV = 0;

    if (applyScale || applyOffset) {
        if (applyScale)
            dLdG->assign((*dLdW)({0,1, 0,0}));
        if (applyOffset)
            dLdB->assign((*dLdW)({1,2, 0,0}));
        delete weights;
        delete dLdW;
    }

    if (axes[0] == inRank - 1 && inRank > 2) {
        delete input;
        delete dLdO;
        delete dLdI;
    }

    return Status::OK();
}

//////////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(batchnorm_bp) {

    NDArray* input    = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw, 5D:ncdhw
    NDArray* mean     = INPUT_VARIABLE(1); // [c]
    NDArray* variance = INPUT_VARIABLE(2); // [c]
    NDArray* dLdO     = INPUT_VARIABLE(3); // same shape as input
    NDArray* gamma    = nullptr;           // [c]
    NDArray* beta     = nullptr;           // [c]

    NDArray* dLdI = OUTPUT_VARIABLE(0);    // same shape as input
    NDArray* dLdM = OUTPUT_VARIABLE(1);    // [c]
    NDArray* dLdV = OUTPUT_VARIABLE(2);    // [c]
    NDArray* dLdG = nullptr;               // [c]
    NDArray* dLdB = nullptr;               // [c]

    const bool applyScale  = (bool) INT_ARG(0);
    const bool applyOffset = (bool) INT_ARG(1);

    if (applyScale) {
        gamma = INPUT_VARIABLE(4);
        dLdG  = OUTPUT_VARIABLE(3);
    }
    if (applyOffset) {
        beta = INPUT_VARIABLE(4 + (int)applyScale);
        dLdB = OUTPUT_VARIABLE(3 + (int)applyScale);
    }

    const int numOfIntArgs = block.getIArguments()->size();

    std::vector<int> axes;
    if (numOfIntArgs > 2)
        for (int i = 2; i < numOfIntArgs; ++i)
            axes.push_back(INT_ARG(i));
    else
        axes.push_back(input->rankOf() - 1); // default dimension to reduce along is the last one

    DataType inputType = input->dataType();
    DataType meanType  = mean->dataType();
    DataType varType   = variance->dataType();
    DataType dLdOType  = dLdO->dataType();
    DataType gammaType = gamma != nullptr ? gamma->dataType() : DataType::FLOAT32;
    DataType betaType  = beta  != nullptr ? beta->dataType()  : DataType::FLOAT32;
    DataType dLdIType  = dLdI->dataType();
    DataType dLdGType  = gamma != nullptr ? dLdG->dataType() : DataType::FLOAT32;
    DataType dLdBType  = beta  != nullptr ? dLdB->dataType() : DataType::FLOAT32;

    const int inRank = input->rankOf();

    return block.isUseMKLDNN() && axes.size() == 1 && (axes[0] == 1 || axes[0] == inRank - 1) && (inRank == 2 || inRank == 4 || inRank == 5) &&
           (inputType == DataType::FLOAT32 && meanType == DataType::FLOAT32 && varType == DataType::FLOAT32 &&
            dLdOType == DataType::FLOAT32 && gammaType == DataType::FLOAT32 && betaType == DataType::FLOAT32 &&
            dLdIType == DataType::FLOAT32 && dLdGType == DataType::FLOAT32 && dLdBType == DataType::FLOAT32);
}
}
}
}