/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  * See the NOTICE file distributed with this work for additional
 *  * information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */
//
// @author Oleg Semeniv <oleg.semeniv@gmail.com>
//
//
# include <ops/declarable/PlatformHelper.h>
# include <ops/declarable/OpRegistrator.h>
# include <system/platform_boilerplate.h>
# include <helpers/MKLDNNStream.h>
# include "mkldnnUtils.h"
using namespace dnnl ;
namespace sd {
namespace ops {
namespace platforms {
2020-03-12 16:25:29 +01:00
2020-03-04 17:36:42 +01:00
//////////////////////////////////////////////////////////////////////
static void softmaxMKLDNN ( const NDArray * x , NDArray * z , const int axis ) {
2020-05-12 06:47:09 +02:00
dnnl : : memory : : dims shape = x - > getShapeAsFlatVector ( ) ;
2020-03-04 17:36:42 +01:00
2020-05-12 06:47:09 +02:00
const int xRank = x - > rankOf ( ) ;
2020-03-04 17:36:42 +01:00
2020-05-12 06:47:09 +02:00
dnnl : : memory : : format_tag xFormat = mkldnnUtils : : getFormat ( * x ) ;
dnnl : : memory : : format_tag zFormat = mkldnnUtils : : getFormat ( * z ) ;
2020-03-04 17:36:42 +01:00
2020-03-12 16:25:29 +01:00
// optimized cases
if ( 2 = = xRank & & 0 = = axis ) {
2020-05-12 06:47:09 +02:00
if ( x - > ews ( ) = = 1 )
xFormat = dnnl : : memory : : format_tag : : ba ;
if ( z - > ews ( ) = = 1 )
zFormat = dnnl : : memory : : format_tag : : ba ;
2020-03-04 17:36:42 +01:00
}
2020-03-12 16:25:29 +01:00
else if ( 4 = = xRank & & 1 = = axis & & ( x - > sizeAt ( 2 ) * x - > sizeAt ( 3 ) ) > 1 ) {
2020-05-12 06:47:09 +02:00
if ( x - > ews ( ) = = 1 )
xFormat = dnnl : : memory : : format_tag : : acdb ;
if ( z - > ews ( ) = = 1 )
zFormat = dnnl : : memory : : format_tag : : acdb ;
2020-03-04 17:36:42 +01:00
}
dnnl : : memory : : data_type xType = dnnl : : memory : : data_type : : f32 ;
2020-05-12 06:47:09 +02:00
dnnl : : memory : : desc x_mkl_md , x_user_md , z_mkl_md , z_user_md ;
x_user_md = x_mkl_md = dnnl : : memory : : desc ( shape , xType , xFormat ) ;
mkldnnUtils : : setBlockStrides ( * x , x_user_md ) ;
2020-03-04 17:36:42 +01:00
// z
2020-05-12 06:47:09 +02:00
z_user_md = z_mkl_md = dnnl : : memory : : desc ( shape , xType , zFormat ) ;
mkldnnUtils : : setBlockStrides ( * z , z_user_md ) ;
2020-03-04 17:36:42 +01:00
auto engine = mkldnnUtils : : getEngine ( LaunchContext : : defaultContext ( ) - > engine ( ) ) ;
// Create attributes (to handle alpha and beta if necessary)
dnnl : : primitive_attr attr ; // it is empty since we have usual values for alpha (=1) and beta (=0)
// operation primitive description
dnnl : : softmax_forward : : desc op_desc ( dnnl : : prop_kind : : forward_inference , x_mkl_md , axis ) ;
dnnl : : softmax_forward : : primitive_desc op_prim_desc ( op_desc , attr , engine ) ;
// arguments (memory buffers) necessary for calculations
std : : unordered_map < int , dnnl : : memory > args ;
dnnl : : stream stream ( engine ) ;
// provide memory buffers and check whether reorder is required
// input
2020-05-12 06:47:09 +02:00
mkldnnUtils : : loadDataToMklStream ( * x , engine , stream , x_user_md , op_prim_desc . src_desc ( ) , args [ DNNL_ARG_SRC ] ) ;
2020-03-04 17:36:42 +01:00
// z
2020-05-12 06:47:09 +02:00
auto z_user_mem = mkldnnUtils : : loadDataToMklStream ( * z , engine , stream , z_user_md , op_prim_desc . dst_desc ( ) , args [ DNNL_ARG_DST ] ) ;
2020-03-04 17:36:42 +01:00
// run calculations
dnnl : : softmax_forward ( op_prim_desc ) . execute ( stream , args ) ;
// reorder outputs if necessary
2020-05-12 06:47:09 +02:00
if ( op_prim_desc . dst_desc ( ) ! = z_user_mem . get_desc ( ) )
dnnl : : reorder ( args [ DNNL_ARG_DST ] , z_user_mem ) . execute ( stream , args [ DNNL_ARG_DST ] , z_user_mem ) ;
2020-03-04 17:36:42 +01:00
stream . wait ( ) ;
}
//////////////////////////////////////////////////////////////////////
// CPU (oneDNN) platform implementation of the softmax op.
// Optional integer argument 0 selects the axis (defaults to the last axis;
// negative values are wrapped to the valid range).
PLATFORM_IMPL(softmax, ENGINE_CPU) {

    auto input = INPUT_VARIABLE(0);
    auto output = OUTPUT_VARIABLE(0);

    const int rank = input->rankOf();
    int dim = block.getIArguments()->size() > 0 ? INT_ARG(0) : rank - 1;

    // normalize a negative axis to [0, rank)
    if (dim < 0) {
        dim += rank;
    }

    REQUIRE_TRUE(dim < rank && dim >= 0, 0, "SOFTMAX_MKLDNN OP: the value of input integer parameter (dimension) must be less than input array rank %i, but got dimension = %i instead !", rank, dim);

    REQUIRE_TRUE(rank <= 6, 0, "SOFTMAX_MKLDNN OP: the rank of input must be less or equal 6, but got rank = %i instead !", rank);

    // mkldnnSoftMax
    softmaxMKLDNN(input, output, dim);

    return Status::OK();
}
//////////////////////////////////////////////////////////////////////
// Returns true when the oneDNN softmax path can be taken: MKL-DNN enabled,
// non-empty f32 input/output, and rank in [3, 6] (the generic implementation
// handles the remaining cases).
PLATFORM_CHECK(softmax, ENGINE_CPU) {

    auto x = INPUT_VARIABLE(0);
    auto z = OUTPUT_VARIABLE(0);

    const DataType xType = x->dataType();
    const DataType zType = z->dataType();

    const int xRank = x->rankOf();
    bool bSupportedRanks = (xRank > 2 && xRank < 7);
    /*
    Source     Destination
    f32        f32
    */
    return !x->isEmpty() && block.isUseMKLDNN() && bSupportedRanks && (xType == DataType::FLOAT32 && zType == DataType::FLOAT32);
}
2020-03-12 16:25:29 +01:00
//////////////////////////////////////////////////////////////////////
static void softmaxBpMKLDNN ( const NDArray * x , const NDArray * dLdz , NDArray * dLdx , const int axis ) {
2020-05-12 06:47:09 +02:00
dnnl : : memory : : desc x_user_md , x_mkl_md , dLdx_mkl_md , dLdx_user_md , dLdz_mkl_md , dLdz_user_md ;
2020-03-12 16:25:29 +01:00
// x
2020-05-12 06:47:09 +02:00
x_mkl_md = x_user_md = dnnl : : memory : : desc ( x - > getShapeAsFlatVector ( ) , dnnl : : memory : : data_type : : f32 , mkldnnUtils : : getFormat ( * x ) ) ;
mkldnnUtils : : setBlockStrides ( * x , x_user_md ) ;
2020-03-12 16:25:29 +01:00
// dLdx
2020-05-12 06:47:09 +02:00
dLdx_mkl_md = dLdx_user_md = dnnl : : memory : : desc ( dLdx - > getShapeAsFlatVector ( ) , dnnl : : memory : : data_type : : f32 , mkldnnUtils : : getFormat ( * dLdx ) ) ;
mkldnnUtils : : setBlockStrides ( * dLdx , dLdx_user_md ) ;
2020-03-12 16:25:29 +01:00
// dLdz
2020-05-12 06:47:09 +02:00
dLdz_mkl_md = dLdz_user_md = dnnl : : memory : : desc ( dLdz - > getShapeAsFlatVector ( ) , dnnl : : memory : : data_type : : f32 , mkldnnUtils : : getFormat ( * dLdz ) ) ;
mkldnnUtils : : setBlockStrides ( * dLdz , dLdz_user_md ) ;
2020-03-12 16:25:29 +01:00
auto engine = mkldnnUtils : : getEngine ( LaunchContext : : defaultContext ( ) - > engine ( ) ) ;
// operation primitive description
// forward description
dnnl : : softmax_forward : : desc op_ff_desc ( dnnl : : prop_kind : : forward_inference , x_mkl_md , axis ) ;
dnnl : : softmax_forward : : primitive_desc op_ff_prim_desc ( op_ff_desc , engine ) ;
// backward description
dnnl : : softmax_backward : : desc op_bp_desc ( dLdz_mkl_md , dLdx_mkl_md , axis ) ;
dnnl : : softmax_backward : : primitive_desc op_bp_prim_desc ( op_bp_desc , engine , op_ff_prim_desc ) ;
// arguments (memory buffers) necessary for calculations
std : : unordered_map < int , dnnl : : memory > argsbp , argsff ;
dnnl : : stream stream ( engine ) ;
// provide memory buffers and check whether reorder is required for forward
// input
2020-05-12 06:47:09 +02:00
mkldnnUtils : : loadDataToMklStream ( * x , engine , stream , x_user_md , op_ff_prim_desc . src_desc ( ) , argsff [ DNNL_ARG_SRC ] ) ;
// dLdz
mkldnnUtils : : loadDataToMklStream ( * dLdz , engine , stream , dLdz_user_md , op_bp_prim_desc . diff_dst_desc ( ) , argsbp [ DNNL_ARG_DIFF_DST ] ) ;
2020-03-12 16:25:29 +01:00
// dLdx
2020-05-12 06:47:09 +02:00
auto dLdx_user_mem = mkldnnUtils : : loadDataToMklStream ( * dLdx , engine , stream , dLdx_user_md , op_ff_prim_desc . src_desc ( ) , argsff [ DNNL_ARG_DST ] ) ;
2020-03-12 16:25:29 +01:00
// check and arg set for backprob
2020-05-12 06:47:09 +02:00
argsbp [ DNNL_ARG_DIFF_SRC ] = argsff [ DNNL_ARG_DST ] ;
argsbp [ DNNL_ARG_DST ] = argsff [ DNNL_ARG_DST ] ;
2020-03-12 16:25:29 +01:00
// run calculations forward
dnnl : : softmax_forward ( op_ff_prim_desc ) . execute ( stream , argsff ) ;
// run calculations backward
dnnl : : softmax_backward ( op_bp_prim_desc ) . execute ( stream , argsbp ) ;
// reorder outputs if necessary
2020-05-12 06:47:09 +02:00
if ( op_ff_prim_desc . dst_desc ( ) ! = dLdx_user_mem . get_desc ( ) )
dnnl : : reorder ( argsff [ DNNL_ARG_DST ] , dLdx_user_mem ) . execute ( stream , argsff [ DNNL_ARG_DST ] , dLdx_user_mem ) ;
2020-03-12 16:25:29 +01:00
stream . wait ( ) ;
}
//////////////////////////////////////////////////////////////////////
// CPU (oneDNN) platform implementation of the softmax_bp op.
// Optional integer argument 0 selects the axis (defaults to the last axis;
// negative values are wrapped to the valid range).
PLATFORM_IMPL(softmax_bp, ENGINE_CPU) {

    auto input = INPUT_VARIABLE(0);
    auto dLdz = INPUT_VARIABLE(1);
    auto dLdx = OUTPUT_VARIABLE(0);

    const int rank = input->rankOf();
    const int dLdzRank = dLdz->rankOf();
    int dim = block.getIArguments()->size() > 0 ? INT_ARG(0) : rank - 1;

    // normalize a negative axis to [0, rank)
    if (dim < 0) {
        dim += rank;
    }

    REQUIRE_TRUE(dim < rank && dim >= 0, 0, "SOFTMAX_MKLDNN_BP OP: the value of input integer parameter (dimension) must be less than input array rank %i, but got dimension = %i instead !", rank, dim);

    REQUIRE_TRUE(rank <= 6 && dLdzRank <= 6, 0, "SOFTMAX_MKLDNN_BP OP: the rank of input and dLdz must be less or equal 6, but got input rank = %i and dLdz rank = %i instead !", rank, dLdzRank);

    // mkldnnSoftMax
    softmaxBpMKLDNN(input, dLdz, dLdx, dim);

    return Status::OK();
}
//////////////////////////////////////////////////////////////////////
// Returns true when the oneDNN softmax_bp path can be taken: MKL-DNN enabled,
// all arrays f32, input and dLdz non-empty with matching rank (< 7) and shape.
PLATFORM_CHECK(softmax_bp, ENGINE_CPU) {

    auto x = INPUT_VARIABLE(0);
    auto dLdz = INPUT_VARIABLE(1);
    auto dLdx = OUTPUT_VARIABLE(0);

    const DataType xType = x->dataType();
    const DataType dLdzType = dLdz->dataType();
    const DataType dLdxType = dLdx->dataType();

    const int xRank = x->rankOf();
    const int dLdzRank = dLdz->rankOf();

    bool bSupportedRanks = xRank < 7 && dLdzRank == xRank && (!x->isEmpty() && !dLdz->isEmpty());
    // shapes of x and dLdz must match exactly, dimension by dimension
    if (bSupportedRanks) {
        for (int i = 0; i < xRank; i++) {
            if (x->sizeAt(i) != dLdz->sizeAt(i)) {
                bSupportedRanks = false;
                break;
            }
        }
    }

    //Source     Destination
    //f32        f32
    return block.isUseMKLDNN() && bSupportedRanks && (xType == DataType::FLOAT32 && dLdzType == DataType::FLOAT32 && dLdxType == DataType::FLOAT32);
}
}
}
}