/* ******************************************************************************
 *
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * See the NOTICE file distributed with this work for additional
 * information regarding copyright ownership.
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <system/platform_boilerplate.h>
#include <helpers/MKLDNNStream.h>
#include "mkldnnUtils.h"
#include <memory>
#include <numeric>
namespace sd {
namespace ops {
namespace platforms {
//////////////////////////////////////////////////////////////////////////
static void concatMKLDNN ( const std : : vector < const NDArray * > & inArrs , NDArray & output , const int axis ) {
// data type
dnnl : : memory : : data_type type ;
if ( output . dataType ( ) = = DataType : : FLOAT32 )
type = dnnl : : memory : : data_type : : f32 ;
else if ( output . dataType ( ) = = DataType : : HALF )
type = dnnl : : memory : : data_type : : f16 ;
else if ( output . dataType ( ) = = DataType : : BFLOAT16 )
type = dnnl : : memory : : data_type : : bf16 ;
else if ( output . dataType ( ) = = DataType : : UINT8 )
type = dnnl : : memory : : data_type : : u8 ;
else
type = dnnl : : memory : : data_type : : s8 ;
std : : vector < dnnl : : memory : : desc > x_user_md ( inArrs . size ( ) ) , x_mkl_md ( inArrs . size ( ) ) ;
// inputs
for ( int i = 0 ; i < inArrs . size ( ) ; + + i ) {
dnnl : : memory : : dims dims = inArrs [ i ] - > getShapeAsFlatVector ( ) ;
x_user_md [ i ] = x_mkl_md [ i ] = dnnl : : memory : : desc ( dims , type , mkldnnUtils : : getFormat ( * inArrs [ i ] ) ) ;
mkldnnUtils : : setBlockStrides ( * inArrs [ i ] , x_user_md [ i ] ) ;
}
// output
dnnl : : memory : : dims dims = output . getShapeAsFlatVector ( ) ;
dnnl : : memory : : desc z_mkl_md = dnnl : : memory : : desc ( dims , type , dnnl : : memory : : format_tag : : any ) ;
dnnl : : memory : : desc z_user_md = dnnl : : memory : : desc ( dims , type , mkldnnUtils : : getFormat ( output ) ) ;
mkldnnUtils : : setBlockStrides ( output , z_user_md ) ;
std : : unordered_map < int , dnnl : : memory > args ;
auto engine = mkldnnUtils : : getEngine ( LaunchContext : : defaultContext ( ) - > engine ( ) ) ;
dnnl : : concat : : primitive_desc op_prim_desc ( axis , x_mkl_md , engine ) ;
dnnl : : stream stream ( engine ) ;
// inputs
for ( int i = 0 ; i < inArrs . size ( ) ; + + i )
mkldnnUtils : : loadDataToMklStream ( * inArrs [ i ] , engine , stream , x_user_md [ i ] , op_prim_desc . src_desc ( i ) , args [ DNNL_ARG_MULTIPLE_SRC + i ] ) ;
// outputs
auto z_user_mem = mkldnnUtils : : loadDataToMklStream ( output , engine , stream , z_user_md , op_prim_desc . dst_desc ( ) , args [ DNNL_ARG_DST ] ) ;
// primitive execution
dnnl : : concat ( op_prim_desc ) . execute ( stream , args ) ;
// reorder output if necessary
if ( op_prim_desc . dst_desc ( ) ! = z_user_mem . get_desc ( ) )
dnnl : : reorder ( args [ DNNL_ARG_DST ] , z_user_mem ) . execute ( stream , args [ DNNL_ARG_DST ] , z_user_mem ) ;
stream . wait ( ) ;
}
//////////////////////////////////////////////////////////////////////////
// MKLDNN implementation of the concat op.
//
// Inputs:  arrays to concatenate; when the first boolean argument is true the last input holds the axis,
//          otherwise the axis is taken from the first integer argument. Negative axis counts from the end.
// Output:  single array receiving the concatenation.
// Empty inputs are skipped, scalar inputs are treated as vectors of length 1.
// Returns Status::OK() on success; validation is done through REQUIRE_TRUE.
PLATFORM_IMPL(concat, ENGINE_CPU) {

    REQUIRE_TRUE(block.width() > 0, 0, " CONCAT MKLDNN op: No input arrays were provided ");

    const bool isAxisInLastArr = block.getBArguments()->size() == 0 ? false : B_ARG(0);

    const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width();

    // first of all take into account possible presence of empty arrays
    // also if scalar is present -> copy its value to vector with length=1
    std::vector<const NDArray*> nonEmptyArrs;

    // owns the length-1 vectors created for scalar inputs, so they are released on every exit path
    // (a failed validation below or an error inside concatMKLDNN included)
    std::vector<std::unique_ptr<NDArray>> scalarsAsVectors;

    bool allOfSameType = true;
    // block.width() > 0 is guaranteed by the check above, so input 0 always exists
    const auto typeOfFirstArr = INPUT_VARIABLE(0)->dataType();

    for (int i = 0; i < numOfInArrs; ++i) {

        auto input = INPUT_VARIABLE(i);

        if (input->isEmpty())
            continue;

        allOfSameType &= (typeOfFirstArr == input->dataType());

        if (input->rankOf() == 0) {
            std::unique_ptr<NDArray> vec(new NDArray('c', {1}, input->dataType(), block.launchContext()));
            vec->assign(input);
            nonEmptyArrs.push_back(vec.get());
            scalarsAsVectors.push_back(std::move(vec));
        }
        else {
            nonEmptyArrs.push_back(input);
        }
    }

    const int numOfNonEmptyArrs = nonEmptyArrs.size();

    if (numOfNonEmptyArrs == 0) {
        // All inputs are empty arrays -> return empty, mainly for TF import compatibility (no op)
        REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, " CONCAT MKLDNN op: If all input variables are empty, output must be empty ");
        return Status::OK();
    }

    const int rank = nonEmptyArrs[0]->rankOf();     // look up to first non-empty array
    int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e<int>(0) : INT_ARG(0);
    if (axis < 0) {
        axis += rank;
    }

    // ******** input validation ******** //
    REQUIRE_TRUE(allOfSameType, 0, " CONCAT MKLDNN op: all of input arrays must have same type ! ");
    REQUIRE_TRUE(nonEmptyArrs[0]->dataType() == OUTPUT_VARIABLE(0)->dataType(), 0, " CONCAT MKLDNN op: output array should have the same type as inputs arrays ! ");
    REQUIRE_TRUE(0 <= axis && (axis < rank || (axis == 0 && rank == 0)), 0, " CONCAT MKLDNN op: input axis must be in range [0, %i], but got %i instead! ", rank - 1, axis);

    for (int i = 1; i < numOfNonEmptyArrs; ++i)
        REQUIRE_TRUE(nonEmptyArrs[i]->rankOf() == rank, 0, " CONCAT MKLDNN op: all input arrays must have the same rank ! ");

    for (int i = 1; i < numOfNonEmptyArrs; ++i) {
        for (int dim = 0; dim < rank; ++dim)
            if (dim != axis)
                REQUIRE_TRUE(nonEmptyArrs[i]->sizeAt(dim) == nonEmptyArrs[0]->sizeAt(dim), 0, " CONCAT MKLDNN op: all input arrays must have the same dimensions (except those on input axis) ! ");
    }
    // ******** end of input validation ******** //

    auto output = OUTPUT_VARIABLE(0);

    // a single non-empty input is simply copied, otherwise the dnnl primitive does the work
    if (numOfNonEmptyArrs == 1)
        output->assign(nonEmptyArrs[0]);
    else
        concatMKLDNN(nonEmptyArrs, *output, axis);

    // temporary length-1 vectors are destroyed together with scalarsAsVectors
    return Status::OK();
}
//////////////////////////////////////////////////////////////////////////
// Decides whether the MKLDNN concat helper may be used for the given block.
// Accepts only outputs of rank below 7, at most 3072 arrays to concatenate and
// output type FLOAT32 / HALF / BFLOAT16 / UINT8 / INT8.
// NOTE(review): the limits 7 and 3072 presumably mirror restrictions of the dnnl concat primitive - confirm.
PLATFORM_CHECK(concat, ENGINE_CPU) {

    auto outArr = OUTPUT_VARIABLE(0);
    const auto outType = outArr->dataType();

    // with no boolean arguments the flag is false, otherwise it is the first boolean argument
    bool axisIsLastInput = false;
    if (block.getBArguments()->size() != 0)
        axisIsLastInput = B_ARG(0);

    // when the flag is set the last input is not counted among the arrays to concatenate
    int arrsToConcat = block.width();
    if (axisIsLastInput)
        arrsToConcat -= 1;

    const bool typeIsSupported = outType == DataType::FLOAT32
                              || outType == DataType::HALF
                              || outType == DataType::BFLOAT16
                              || outType == DataType::UINT8
                              || outType == DataType::INT8;

    return outArr->rankOf() < 7 && arrsToConcat <= 3072 && typeIsSupported;
}
}
}
}