/*******************************************************************************
 * Copyright (c) 2019-2020 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
//
// @author Oleg Semeniv <oleg.semeniv@gmail.com>
//

#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <system/platform_boilerplate.h>
#include <helpers/MKLDNNStream.h>
#include "mkldnnUtils.h"

using namespace dnnl;

namespace sd {
namespace ops {
namespace platforms {
//////////////////////////////////////////////////////////////////////
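// For reference, a minimal statement of what this kernel computes: softmax along
// dimension `axis`,
//
//     z_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x)),
//
// where max and sum run over the chosen axis; oneDNN subtracts the per-axis
// maximum internally, so the computation is numerically stable.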
static void softmaxMKLDNN(const NDArray* x, NDArray* z, const int axis) {

    const auto xRank = x->rankOf();

    dnnl::memory::dims xShape, zShape;

    mkldnnUtils::getDims(x, xRank, xShape);
    mkldnnUtils::getDims(z, xRank, zShape);

    dnnl::memory::format_tag format = mkldnnUtils::getFormat(xRank);
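
    // the permuted tags below (ba = transposed 2D, acdb = NHWC-like 4D) only change
    // the layout oneDNN assumes, not the numerical result; they appear to be chosen
    // so the softmax axis maps onto layouts oneDNN has optimized kernels for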
    // optimized cases
    if (2 == xRank && 0 == axis) {
        format = dnnl::memory::format_tag::ba;
    }
    else if (4 == xRank && 1 == axis && (x->sizeAt(2) * x->sizeAt(3)) > 1) {
        format = dnnl::memory::format_tag::acdb;
    }

    dnnl::memory::data_type xType = dnnl::memory::data_type::f32;

    // x
    dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, xType, format);
    dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, format);
    mkldnnUtils::setBlockStrides(x, xRank, x_user_md);

    // z
    dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, xType, format);
    dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, xType, format);
    mkldnnUtils::setBlockStrides(z, xRank, z_user_md);

    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());

    // create attributes (to handle alpha and beta if necessary)
    dnnl::primitive_attr attr; // empty, since we use the default alpha (=1) and beta (=0)

    // operation primitive description
    dnnl::softmax_forward::desc op_desc(dnnl::prop_kind::forward_inference, x_mkl_md, axis);
    dnnl::softmax_forward::primitive_desc op_prim_desc(op_desc, attr, engine);

    // arguments (memory buffers) necessary for calculations
    std::unordered_map<int, dnnl::memory> args;

    dnnl::stream stream(engine);
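
    // the primitive may prefer a (possibly blocked) layout different from the user's
    // plain layout; when the descriptors disagree, data is staged through a temporary
    // oneDNN memory with a reorder primitive: mkldnnUtils::loadDataToMklStream does
    // this for the input, and the zReorder branch below does it for the output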
    // provide memory buffers and check whether reorder is required
    // input
    mkldnnUtils::loadDataToMklStream(x, engine, stream, args, x_user_md, op_prim_desc.src_desc(), DNNL_ARG_SRC);

    // z
    auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer());
    const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
    auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
    args[DNNL_ARG_DST] = z_mkl_mem;

    // run calculations
    dnnl::softmax_forward(op_prim_desc).execute(stream, args);

    // reorder output if necessary
    if (zReorder)
        dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);

    stream.wait();
}
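
// PLATFORM_IMPL registers the MKLDNN execution path for the softmax op; the
// companion PLATFORM_CHECK below is queried first, and when it returns false the
// runtime falls back to the generic libnd4j implementation.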
PLATFORM_IMPL(softmax, ENGINE_CPU) {

    auto input = INPUT_VARIABLE(0);
    auto output = OUTPUT_VARIABLE(0);

    const int rank = input->rankOf();
    int dim = block.getIArguments()->size() > 0 ? INT_ARG(0) : rank - 1;

    if (dim < 0) {
        dim += rank;
    }

    REQUIRE_TRUE(dim < rank && dim >= 0, 0, "SOFTMAX_MKLDNN OP: the value of input integer parameter (dimension) must be less than input array rank %i, but got dimension = %i instead !", rank, dim);
    REQUIRE_TRUE(rank <= 6, 0, "SOFTMAX_MKLDNN OP: the rank of input must be less than or equal to 6, but got rank = %i instead !", rank);

    // mkldnnSoftMax
    softmaxMKLDNN(input, output, dim);

    return Status::OK();
}

PLATFORM_CHECK(softmax, ENGINE_CPU) {

    auto x = INPUT_VARIABLE(0);
    auto z = OUTPUT_VARIABLE(0);

    const DataType xType = x->dataType();
    const DataType zType = z->dataType();

    const int xRank = x->rankOf();
    bool bSupportedRanks = (xRank > 2 && xRank < 7);
    /*
    Source    Destination
    f32       f32
    */
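    // i.e. the MKLDNN path is taken only for non-empty 3D..6D f32 inputs; rank-1/2
    // arrays fall back to the generic libnd4j softmax (which also makes the 2D fast
    // path in softmaxMKLDNN unreachable through this op)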
    return !x->isEmpty() && block.isUseMKLDNN() && bSupportedRanks && (xType == DataType::FLOAT32 && zType == DataType::FLOAT32);
}

//////////////////////////////////////////////////////////////////////
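// For reference: with y = softmax(x) along `axis`, the gradient is
//
//     dL/dx_i = y_i * (dL/dy_i - sum_j(dL/dy_j * y_j));
//
// oneDNN's softmax_backward consumes y (DNNL_ARG_DST) and dL/dy (DNNL_ARG_DIFF_DST)
// and writes dL/dx (DNNL_ARG_DIFF_SRC). Since only x is available here, the forward
// primitive is executed first to materialize y.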
static void softmaxBpMKLDNN(const NDArray* x, const NDArray* dLdz, NDArray* dLdx, const int axis) {

    const auto xRank = x->rankOf();
    const auto dLdzRank = dLdz->rankOf();

    dnnl::memory::dims xShape, dLdxShape, dLdzShape;

    mkldnnUtils::getDims(x, xRank, xShape);
    mkldnnUtils::getDims(dLdx, xRank, dLdxShape);
    mkldnnUtils::getDims(dLdz, dLdzRank, dLdzShape);

    dnnl::memory::format_tag format = mkldnnUtils::getFormat(xRank);

    // x
    dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
    dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
    mkldnnUtils::setBlockStrides(x, xRank, x_user_md);

    // dLdx
    dnnl::memory::desc dLdx_mkl_md = dnnl::memory::desc(dLdxShape, dnnl::memory::data_type::f32, format);
    dnnl::memory::desc dLdx_user_md = dnnl::memory::desc(dLdxShape, dnnl::memory::data_type::f32, format);
    mkldnnUtils::setBlockStrides(dLdx, xRank, dLdx_user_md);

    // todo: if mkl does not support broadcast we can remove this
    format = mkldnnUtils::getFormat(dLdzRank);

    // dLdz
    dnnl::memory::desc dLdz_mkl_md = dnnl::memory::desc(dLdzShape, dnnl::memory::data_type::f32, format);
    dnnl::memory::desc dLdz_user_md = dnnl::memory::desc(dLdzShape, dnnl::memory::data_type::f32, format);
    mkldnnUtils::setBlockStrides(dLdz, dLdzRank, dLdz_user_md);

    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());

    // operation primitive descriptions
    // forward
    dnnl::softmax_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, x_mkl_md, axis);
    dnnl::softmax_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);

    // backward, created with the forward primitive descriptor as a hint
    dnnl::softmax_backward::desc op_bp_desc(dLdz_mkl_md, dLdx_mkl_md, axis);
    dnnl::softmax_backward::primitive_desc op_bp_prim_desc(op_bp_desc, engine, op_ff_prim_desc);

    // arguments (memory buffers) necessary for calculations
    std::unordered_map<int, dnnl::memory> argsbp, argsff;

    dnnl::stream stream(engine);

    // provide memory buffers and check whether reorder is required
    // input
    mkldnnUtils::loadDataToMklStream(x, engine, stream, argsff, x_user_md, op_ff_prim_desc.src_desc(), DNNL_ARG_SRC);

    // dLdx acts as the forward destination and is then reused by the backward step
    auto dLdx_user_mem = dnnl::memory(dLdx_user_md, engine, dLdx->getBuffer());
    const bool dLdxReorder = op_ff_prim_desc.dst_desc() != dLdx_user_mem.get_desc();
    auto dLdx_mkl_mem = dLdxReorder ? dnnl::memory(op_ff_prim_desc.dst_desc(), engine) : dLdx_user_mem;
    argsff[DNNL_ARG_DST] = dLdx_mkl_mem;

    // set arguments for backprop: the softmax output is consumed as DST and
    // overwritten in place as DIFF_SRC
    argsbp[DNNL_ARG_DIFF_SRC] = dLdx_mkl_mem;
    argsbp[DNNL_ARG_DST] = dLdx_mkl_mem;

    // dLdz
    mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, argsbp, dLdz_user_md, op_bp_prim_desc.diff_dst_desc(), DNNL_ARG_DIFF_DST);

    // run forward calculation (produces y = softmax(x) in dLdx_mkl_mem)
    dnnl::softmax_forward(op_ff_prim_desc).execute(stream, argsff);

    // run backward calculation
    dnnl::softmax_backward(op_bp_prim_desc).execute(stream, argsbp);

    // reorder output if necessary
    if (dLdxReorder)
        dnnl::reorder(dLdx_mkl_mem, dLdx_user_mem).execute(stream, dLdx_mkl_mem, dLdx_user_mem);

    stream.wait();
}

PLATFORM_IMPL(softmax_bp, ENGINE_CPU) {

    auto input = INPUT_VARIABLE(0);
    auto dLdz = INPUT_VARIABLE(1);
    auto dLdx = OUTPUT_VARIABLE(0);

    const int rank = input->rankOf();
    const int dLdzRank = dLdz->rankOf();

    int dim = block.getIArguments()->size() > 0 ? INT_ARG(0) : rank - 1;
    if (dim < 0) {
        dim += rank;
    }

    REQUIRE_TRUE(dim < rank && dim >= 0, 0, "SOFTMAX_MKLDNN_BP OP: the value of input integer parameter (dimension) must be less than input array rank %i, but got dimension = %i instead !", rank, dim);
    REQUIRE_TRUE(rank <= 6 && dLdzRank <= 6, 0, "SOFTMAX_MKLDNN_BP OP: the rank of input and dLdz must be less than or equal to 6, but got input rank = %i and dLdz rank = %i instead !", rank, dLdzRank);

    // mkldnnSoftMax
    softmaxBpMKLDNN(input, dLdz, dLdx, dim);

    return Status::OK();
}

PLATFORM_CHECK(softmax_bp, ENGINE_CPU) {

    auto x = INPUT_VARIABLE(0);
    auto dLdz = INPUT_VARIABLE(1);
    auto dLdx = OUTPUT_VARIABLE(0);

    const DataType xType = x->dataType();
    const DataType dLdzType = dLdz->dataType();
    const DataType dLdxType = dLdx->dataType();

    const int xRank = x->rankOf();
    const int dLdzRank = dLdz->rankOf();

    // ranks must match, be at most 6, and both inputs must be non-empty
    bool bSupportedRanks = xRank < 7 && dLdzRank == xRank && (!x->isEmpty() && !dLdz->isEmpty());

    // shapes of x and dLdz must match exactly
    if (bSupportedRanks) {
        for (int i = 0; i < xRank; i++) {
            if (x->sizeAt(i) != dLdz->sizeAt(i)) {
                bSupportedRanks = false;
                break;
            }
        }
    }

    // Source    Destination
    // f32       f32
    return block.isUseMKLDNN() && bSupportedRanks && (xType == DataType::FLOAT32 && dLdzType == DataType::FLOAT32 && dLdxType == DataType::FLOAT32);
}
}
}
}