2020-03-04 17:36:42 +01:00
/*******************************************************************************
* Copyright ( c ) 2019 - 2020 Konduit K . K .
*
* This program and the accompanying materials are made available under the
* terms of the Apache License , Version 2.0 which is available at
* https : //www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing , software
* distributed under the License is distributed on an " AS IS " BASIS , WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND , either express or implied . See the
* License for the specific language governing permissions and limitations
* under the License .
*
* SPDX - License - Identifier : Apache - 2.0
* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
//
// @author Oleg Semeniv <oleg.semeniv@gmail.com>
//
//
# include <ops/declarable/PlatformHelper.h>
# include <ops/declarable/OpRegistrator.h>
# include <system/platform_boilerplate.h>
# include <helpers/MKLDNNStream.h>
# include "mkldnnUtils.h"
using namespace dnnl ;
namespace sd {
namespace ops {
namespace platforms {
//////////////////////////////////////////////////////////////////////
static void softmaxMKLDNN ( const NDArray * x , NDArray * z , const int axis ) {
const auto xRank = x - > rankOf ( ) ;
const auto zRank = z - > rankOf ( ) ;
std : : vector < int64_t > dimsX ( xRank ) , dimsZ ( zRank ) ;
for ( auto i = 0 ; i < xRank ; i + + ) {
dimsX [ i ] = x - > sizeAt ( i ) ;
dimsZ [ i ] = z - > sizeAt ( i ) ;
}
dnnl : : memory : : dims xShape = dnnl : : memory : : dims ( dimsX ) ;
dnnl : : memory : : dims zShape = dnnl : : memory : : dims ( dimsZ ) ;
dnnl : : memory : : format_tag format = dnnl : : memory : : format_tag : : a ; // 1 == xRank
if ( 2 = = xRank & & 1 = = axis ) {
format = dnnl : : memory : : format_tag : : ab ;
}
else if ( 2 = = xRank & & 0 = = axis ) {
format = dnnl : : memory : : format_tag : : ba ;
}
else if ( 3 = = xRank ) {
format = dnnl : : memory : : format_tag : : abc ;
}
else if ( 4 = = xRank & & 3 = = axis ) {
format = dnnl : : memory : : format_tag : : abcd ;
}
else if ( 4 = = xRank & & 1 = = axis & & dimsX [ 2 ] * dimsX [ 3 ] > 1 ) {
format = dnnl : : memory : : format_tag : : acdb ;
}
else if ( 4 = = xRank ) {
format = dnnl : : memory : : format_tag : : abcd ;
}
else if ( 5 = = xRank ) {
format = dnnl : : memory : : format_tag : : abcde ;
}
else if ( 6 = = xRank ) {
format = dnnl : : memory : : format_tag : : abcdef ;
}
dnnl : : memory : : data_type xType = dnnl : : memory : : data_type : : f32 ;
dnnl : : memory : : data_type zType = dnnl : : memory : : data_type : : f32 ;
dnnl : : memory : : desc x_mkl_md = dnnl : : memory : : desc ( xShape , xType , format ) ;
dnnl : : memory : : desc x_user_md = dnnl : : memory : : desc ( xShape , xType , format ) ;
if ( x - > ews ( ) ! = 1 | | x - > ordering ( ) ! = ' c ' ) {
x_user_md . data . format_kind = dnnl_blocked ; // overrides format
for ( auto i = 0 ; i < xRank ; + + i ) {
x_user_md . data . format_desc . blocking . strides [ i ] = x - > strideAt ( i ) ;
}
}
// z
dnnl : : memory : : desc z_mkl_md = dnnl : : memory : : desc ( zShape , zType , dnnl : : memory : : format_tag : : any ) ;
dnnl : : memory : : desc z_user_md = dnnl : : memory : : desc ( zShape , zType , format ) ;
if ( z - > ews ( ) ! = 1 | | z - > ordering ( ) ! = ' c ' ) {
z_user_md . data . format_kind = dnnl_blocked ; // overrides format
for ( auto i = 0 ; i < xRank ; + + i ) {
z_user_md . data . format_desc . blocking . strides [ i ] = z - > strideAt ( i ) ;
}
}
auto engine = mkldnnUtils : : getEngine ( LaunchContext : : defaultContext ( ) - > engine ( ) ) ;
// Create attributes (to handle alpha and beta if necessary)
dnnl : : primitive_attr attr ; // it is empty since we have usual values for alpha (=1) and beta (=0)
// operation primitive description
// todo check this
dnnl : : softmax_forward : : desc op_desc ( dnnl : : prop_kind : : forward_inference , x_mkl_md , axis ) ;
dnnl : : softmax_forward : : primitive_desc op_prim_desc ( op_desc , attr , engine ) ;
// arguments (memory buffers) necessary for calculations
std : : unordered_map < int , dnnl : : memory > args ;
dnnl : : stream stream ( engine ) ;
// provide memory buffers and check whether reorder is required
// input
auto x_user_mem = dnnl : : memory ( x_user_md , engine , x - > getBuffer ( ) ) ;
const bool xReorder = op_prim_desc . src_desc ( ) ! = x_user_mem . get_desc ( ) ;
auto x_mkl_mem = xReorder ? dnnl : : memory ( op_prim_desc . src_desc ( ) , engine ) : x_user_mem ;
if ( xReorder )
dnnl : : reorder ( x_user_mem , x_mkl_mem ) . execute ( stream , x_user_mem , x_mkl_mem ) ;
args [ DNNL_ARG_SRC ] = x_mkl_mem ;
// z
auto z_user_mem = dnnl : : memory ( z_user_md , engine , z - > getBuffer ( ) ) ;
const bool zReorder = op_prim_desc . dst_desc ( ) ! = z_user_mem . get_desc ( ) ;
auto z_mkl_mem = zReorder ? dnnl : : memory ( op_prim_desc . dst_desc ( ) , engine ) : z_user_mem ;
args [ DNNL_ARG_DST ] = z_mkl_mem ;
// run calculations
dnnl : : softmax_forward ( op_prim_desc ) . execute ( stream , args ) ;
// reorder outputs if necessary
if ( zReorder )
dnnl : : reorder ( z_mkl_mem , z_user_mem ) . execute ( stream , z_mkl_mem , z_user_mem ) ;
stream . wait ( ) ;
}
/**
 * CPU-engine platform implementation of the softmax op, delegating to softmaxMKLDNN.
 *
 * The softmax dimension is taken from the first integer argument when one is supplied,
 * otherwise the last dimension of the input is used. A negative dimension is shifted by
 * the input rank before validation.
 *
 * Requirements (checked with REQUIRE_TRUE):
 *   - 0 <= dimension < input rank
 *   - input rank <= 6
 *
 * Returns Status::OK() after the helper has been called.
 */
PLATFORM_IMPL(softmax, ENGINE_CPU) {

    auto input  = INPUT_VARIABLE(0);
    auto output = OUTPUT_VARIABLE(0);

    const int rank = input->rankOf();
    int dim = block.getIArguments()->size() > 0 ? INT_ARG(0) : rank - 1;

    // a negative dimension counts from the end of the shape
    if (dim < 0) {
        dim += rank;
    }

    REQUIRE_TRUE(dim < rank && dim >= 0, 0, "SOFTMAX_MKLDNN OP: the value of input integer parameter (dimension) must be less than input array rank %i, but got dimension = %i instead !", rank, dim);

    REQUIRE_TRUE(rank <= 6, 0, "SOFTMAX_MKLDNN OP: the rank of input must be less or equal 6, but got rank = %i instead !", rank);

    // mkldnnSoftMax
    softmaxMKLDNN(input, output, dim);

    return Status::OK();
}
/**
 * Decides whether the MKL-DNN softmax implementation may be used for the current call.
 *
 * The helper is selected only when all of the following hold:
 *   - the input is not empty
 *   - block.isUseMKLDNN() is true
 *   - the input rank is 3, 4, 5 or 6
 *   - both input and output are FLOAT32
 */
PLATFORM_CHECK(softmax, ENGINE_CPU) {

    auto x = INPUT_VARIABLE(0);
    auto z = OUTPUT_VARIABLE(0);

    const DataType xType = x->dataType();
    const DataType zType = z->dataType();

    const int xRank = x->rankOf();
    const bool bSupportedRanks = (xRank > 2 && xRank < 7);

    /*
    Supported data type combinations:
    Source     Destination
    f32        f32
    */
    return !x->isEmpty() && block.isUseMKLDNN() && bSupportedRanks && (xType == DataType::FLOAT32 && zType == DataType::FLOAT32);
}
}
}
}