2019-06-06 14:21:15 +02:00
/*******************************************************************************
* Copyright ( c ) 2015 - 2018 Skymind , Inc .
*
* This program and the accompanying materials are made available under the
* terms of the Apache License , Version 2.0 which is available at
* https : //www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing , software
* distributed under the License is distributed on an " AS IS " BASIS , WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND , either express or implied . See the
* License for the specific language governing permissions and limitations
* under the License .
*
* SPDX - License - Identifier : Apache - 2.0
* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
//
// @author Yurii Shyrma
//
# include <algorithm>
# include <helpers/ShapeUtils.h>
# include <climits>
# include <numeric>
# include <algorithm>
# include <set>
# include <flatbuffers/util.h>
namespace nd4j {
//////////////////////////////////////////////////////////////////////////
// evaluate shape for array resulting from tensorDot operation, also evaluate shapes and dimensions permutations for transposition of two input arrays
std : : vector < Nd4jLong > ShapeUtils : : evalShapeForTensorDot ( const Nd4jLong * aShapeInfo , const Nd4jLong * bShapeInfo , std : : vector < int > axesA , std : : vector < int > axesB , std : : vector < int > & permutAt , std : : vector < int > & permutBt , std : : vector < Nd4jLong > & shapeAt , std : : vector < Nd4jLong > & shapeBt ) {
int axeAsize = ( int ) axesA . size ( ) ;
int axeBsize = ( int ) axesB . size ( ) ;
int aRank = aShapeInfo [ 0 ] ;
int bRank = bShapeInfo [ 0 ] ;
if ( axeAsize ! = axeBsize )
throw std : : runtime_error ( " ShapeUtils::evalShapeForTensorDot method: the numbers of a axes and b axes to make dot product along must have identical values ! " ) ;
if ( axeAsize > aRank | | axeBsize > bRank )
throw std : : runtime_error ( " ShapeUtils::evalShapeForTensorDot method: the length of vector of a or b axes is larger than array rank ! " ) ;
// axes validation
for ( int i = 0 ; i < axeBsize ; i + + ) {
if ( axesA [ i ] < 0 )
axesA [ i ] + = aRank ;
if ( axesB [ i ] < 0 )
axesB [ i ] + = bRank ;
if ( aShapeInfo [ axesA [ i ] + 1 ] ! = bShapeInfo [ axesB [ i ] + 1 ] )
throw std : : runtime_error ( " ShapeUtils::evalShapeForTensorDot method: the dimensions at given axes for both input arrays must be the same ! " ) ;
}
// check whether axesA and axesB contain only unique numbers
std : : set < Nd4jLong > uniqueElems ( axesA . begin ( ) , axesA . end ( ) ) ;
if ( ( int ) uniqueElems . size ( ) ! = axeAsize )
throw std : : runtime_error ( " ShapeUtils::evalShapeForTensorDot method: the vector of a axes contains duplicates ! " ) ;
uniqueElems . clear ( ) ;
uniqueElems = std : : set < Nd4jLong > ( axesB . begin ( ) , axesB . end ( ) ) ;
if ( ( int ) uniqueElems . size ( ) ! = axeBsize )
throw std : : runtime_error ( " ShapeUtils::evalShapeForTensorDot method: the vector of b axes contains duplicates ! " ) ;
std : : vector < int > list_A , list_B ;
for ( int i = 0 ; i < aRank ; i + + )
if ( std : : find ( axesA . begin ( ) , axesA . end ( ) , i ) = = axesA . end ( ) )
list_A . emplace_back ( i ) ;
for ( int i = 0 ; i < bRank ; i + + )
if ( std : : find ( axesB . begin ( ) , axesB . end ( ) , i ) = = axesB . end ( ) )
list_B . emplace_back ( i ) ;
permutAt = list_A ;
permutAt . insert ( permutAt . end ( ) , axesA . begin ( ) , axesA . end ( ) ) ;
permutBt = axesB ;
permutBt . insert ( permutBt . end ( ) , list_B . begin ( ) , list_B . end ( ) ) ;
int n2 = 1 ;
for ( int i = 0 ; i < axeAsize ; i + + )
n2 * = aShapeInfo [ axesA [ i ] + 1 ] ;
shapeAt = { - 1 , n2 } ;
std : : vector < Nd4jLong > oldShapeA ;
oldShapeA . resize ( list_A . size ( ) ) ;
for ( int i = 0 ; i < oldShapeA . size ( ) ; + + i )
oldShapeA [ i ] = aShapeInfo [ list_A [ i ] + 1 ] ;
int n3 = 1 ;
for ( int i = 0 ; i < axeBsize ; i + + )
n3 * = bShapeInfo [ axesB [ i ] + 1 ] ;
shapeBt = { n3 , - 1 } ;
std : : vector < Nd4jLong > oldShapeB ;
oldShapeB . resize ( list_B . size ( ) ) ;
for ( int i = 0 ; i < oldShapeB . size ( ) ; i + + )
oldShapeB [ i ] = bShapeInfo [ list_B [ i ] + 1 ] ;
std : : vector < Nd4jLong > aPlusB ( oldShapeA ) ;
aPlusB . insert ( aPlusB . end ( ) , oldShapeB . begin ( ) , oldShapeB . end ( ) ) ;
return aPlusB ;
}
//////////////////////////////////////////////////////////////////////////
std : : vector < Nd4jLong > ShapeUtils : : evalShapeForTensorDot ( const NDArray * a , const NDArray * b , const std : : vector < int > & axesA , const std : : vector < int > & axesB , std : : vector < int > & permutAt , std : : vector < int > & permutBt , std : : vector < Nd4jLong > & shapeAt , std : : vector < Nd4jLong > & shapeBt ) {
return evalShapeForTensorDot ( a - > getShapeInfo ( ) , b - > getShapeInfo ( ) , axesA , axesB , permutAt , permutBt , shapeAt , shapeBt ) ;
}
2019-06-15 13:34:34 +02:00
//////////////////////////////////////////////////////////////////////////
// evaluate output shape for reduce operation when input shape is empty
// Evaluates output shapeInfo of a reduce operation, this variant is the one evalReduceShapeInfo delegates to for EMPTY arrays.
//  - dimsToExclude is empty                  -> copy of shapeInfo made by ShapeBuilders::copyShapeInfoAndType with dataType
//  - dimsToExclude has as many items as rank -> scalar shapeInfo, or shape consisting of rank unities when keepDims is true
//  - otherwise                               -> axes listed in dimsToExclude are set to 1 (keepDims) or removed from the shape
// The temporary buffer is released, the returned pointer is the one provided by ConstantShapeHelper.
Nd4jLong* ShapeUtils::evalReduceShapeInfoEmpty(const char order, std::vector<int>& dimsToExclude, const Nd4jLong* shapeInfo, const nd4j::DataType dataType, const bool keepDims, nd4j::memory::Workspace* workspace) {

    Nd4jLong* reducedShapeInfo = nullptr;

    if (dimsToExclude.size() == 0) {
        // return copy of input shape
        reducedShapeInfo = ShapeBuilders::copyShapeInfoAndType(shapeInfo, dataType, true, workspace);
    }
    else {
        const int rank = shape::rank(shapeInfo);

        if (dimsToExclude.size() == rank) {
            // return shape filled with unities or scalar
            if (keepDims)
                reducedShapeInfo = ShapeBuilders::createShapeInfo(dataType, order, std::vector<Nd4jLong>(rank, 1), workspace);
            else
                reducedShapeInfo = ShapeBuilders::createScalarShapeInfo(dataType, workspace);
        }
        else {
            shape::checkDimensions(rank, dimsToExclude);

            std::vector<Nd4jLong> reducedShape;

            if (keepDims) {
                reducedShape.assign(shapeInfo + 1, shapeInfo + 1 + rank);
                for (const auto& axis : dimsToExclude)
                    reducedShape[axis] = 1;
            }
            else {
                // walk over all axes, 'next' points to the entry of dimsToExclude which is expected to be met next
                uint next = 0;
                for (uint axis = 0; axis < rank; ++axis) {
                    const bool listed = next < dimsToExclude.size() && axis == dimsToExclude[next];
                    if (listed)
                        ++next;
                    else
                        reducedShape.emplace_back(shapeInfo[axis + 1]);
                }
            }

            reducedShapeInfo = ShapeBuilders::createShapeInfo(dataType, order, reducedShape, workspace);
        }
    }

    ShapeDescriptor descriptor(reducedShapeInfo, dataType);
    RELEASE(reducedShapeInfo, workspace);
    return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
}
// evaluate shape resulting from reduce operation, the data type of the result is the data type of arr
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const NDArray& arr, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) {

    const auto arrDataType = arr.dataType();
    return evalReduceShapeInfo(order, dimsToExclude, arr, arrDataType, keepDims, supportOldShapes, workspace);
}
2019-06-15 13:34:34 +02:00
// evaluate shape resulting from reduce operation, the data type of the result is the one stored in shapeInfo
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const Nd4jLong* shapeInfo, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) {

    const auto storedDataType = ArrayOptions::dataType(shapeInfo);
    return evalReduceShapeInfo(order, dimsToExclude, shapeInfo, storedDataType, keepDims, supportOldShapes, workspace);
}
//////////////////////////////////////////////////////////////////////////
2019-06-15 13:34:34 +02:00
// evaluate shape resulting from reduce operation, shapeInfo is taken from arr and the data type of the result is given explicitly
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const NDArray& arr, const nd4j::DataType dataType, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) {

    auto arrShapeInfo = arr.getShapeInfo();
    return evalReduceShapeInfo(order, dimsToExclude, arrShapeInfo, dataType, keepDims, supportOldShapes, workspace);
}
//////////////////////////////////////////////////////////////////////////
// evaluate shape resulting from reduce operation
2019-06-15 13:34:34 +02:00
// Evaluates shapeInfo of the array resulting from reduce operation.
//
// order            : order handed over to updateStridesAndType for the resulting shapeInfo
// dimsToExclude    : axes which are reduced (set to 1 when keepDims is true, removed otherwise),
//                    the vector is passed through shape::checkDimensions and may be modified by it
// shapeInfo        : shapeInfo of the input array, EMPTY arrays are forwarded to evalReduceShapeInfoEmpty
// dataType         : data type of the result
// keepDims         : if true the rank of the input is kept
// supportOldShapes : if true scalars and vectors are produced as rank 2 shapes
// workspace        : workspace used for allocation of the temporary shapeInfo buffer
//
// The temporary buffer is always released, the returned pointer is the one provided by ConstantShapeHelper.
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const Nd4jLong* shapeInfo, const nd4j::DataType dataType, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) {

    if (ArrayOptions::arrayType(shapeInfo) == ArrayType::EMPTY)
        return ShapeUtils::evalReduceShapeInfoEmpty(order, dimsToExclude, shapeInfo, dataType, keepDims, workspace);

    Nd4jLong* newShapeInfo = nullptr;
    const int rank = shape::rank(const_cast<Nd4jLong*>(shapeInfo));

    // common tail of every branch: build descriptor from temporary buffer and dataType, release the buffer,
    // return the pointer obtained from ConstantShapeHelper
    auto cacheAndRelease = [&]() -> Nd4jLong* {
        ShapeDescriptor descriptor(newShapeInfo, dataType);
        RELEASE(newShapeInfo, workspace);
        return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
    };

    if (dimsToExclude.size() == 0) {                                            // return scalar or array with len=1 in this case

        if (keepDims && rank > 1) {
            ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(rank), Nd4jLong);
            newShapeInfo[0] = rank;
            for (int i = 0; i < rank; ++i)
                newShapeInfo[i + 1] = 1;
            ShapeUtils::updateStridesAndType(newShapeInfo, shapeInfo, order);
            ArrayOptions::setDataType(newShapeInfo, dataType);
        }
        else if (supportOldShapes) {
            ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(2), Nd4jLong);
            shape::shapeOldScalar(dataType, newShapeInfo, 'c');
        }
        else {
            newShapeInfo = ShapeBuilders::createScalarShapeInfo(dataType, workspace);
        }

        return cacheAndRelease();
    }

    shape::checkDimensions(rank, dimsToExclude);

    const int dimSize = dimsToExclude.size();

    if (keepDims) {

        ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(rank), Nd4jLong);
        newShapeInfo[0] = rank;
        for (int i = 0; i < rank; ++i) {
            if (std::binary_search(dimsToExclude.begin(), dimsToExclude.end(), i))   // dimsToExclude is already sorted after shape::checkDimensions() has been applied
                newShapeInfo[i + 1] = 1;
            else
                newShapeInfo[i + 1] = shapeInfo[i + 1];
        }
        ShapeUtils::updateStridesAndType(newShapeInfo, shapeInfo, order);

        return cacheAndRelease();
    }

    const int newRank = rank - dimSize;

    if (newRank == 0 || (dimSize == 1 && dimsToExclude[0] == INT_MAX)) {        // check whether given dimension is meant for the whole dimension

        // the temporary scalar is created with the requested dataType, exactly as in the branch for empty dimsToExclude above
        if (supportOldShapes) {
            ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(2), Nd4jLong);
            shape::shapeOldScalar(dataType, newShapeInfo, 'c');
        }
        else {
            newShapeInfo = ShapeBuilders::createScalarShapeInfo(dataType, workspace);
        }

        return cacheAndRelease();
    }

    ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(newRank), Nd4jLong);
    newShapeInfo[0] = newRank;                                                  // set rank
    int j = 1;
    for (int i = 0; i < rank; ++i)
        if (!std::binary_search(dimsToExclude.begin(), dimsToExclude.end(), i))     // dimsToExclude is already sorted after shape::checkDimensions() has been applied
            newShapeInfo[j++] = shapeInfo[i + 1];

    // ensure whether vector has proper shape for old shape type
    if (newRank == 1 && supportOldShapes) {
        const Nd4jLong oldValue = newShapeInfo[1];                              // kept as Nd4jLong, int would truncate large dimensions
        RELEASE(newShapeInfo, workspace);
        ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(2), Nd4jLong);     // set newRank = 2
        newShapeInfo[0] = 2;
        if (dimsToExclude[0] == 0) {
            newShapeInfo[1] = 1;
            newShapeInfo[2] = oldValue;
        }
        else {
            newShapeInfo[1] = oldValue;
            newShapeInfo[2] = 1;
        }
    }

    ShapeUtils::updateStridesAndType(newShapeInfo, shapeInfo, order);

    return cacheAndRelease();
}
//////////////////////////////////////////////////////////////////////////
// evaluate shape for array which is result of repeat operation applied to arr
std : : vector < Nd4jLong > ShapeUtils : : evalRepeatShape ( int dimension , const std : : vector < Nd4jLong > & repeats , const NDArray & arr ) {
int rank = arr . rankOf ( ) ;
if ( dimension < 0 )
dimension + = rank ;
std : : vector < Nd4jLong > reps ;
if ( ( int ) reps . size ( ) < rank ) {
if ( dimension > 0 ) {
for ( int e = 0 ; e < rank - ( int ) repeats . size ( ) ; e + + )
reps . push_back ( 1 ) ;
for ( auto r : repeats )
reps . push_back ( r ) ;
} else {
for ( auto r : repeats )
reps . push_back ( r ) ;
for ( int e = 0 ; e < rank - ( int ) repeats . size ( ) ; e + + )
reps . push_back ( 1 ) ;
}
} /* else {
for ( auto r : repeats )
reps . push_back ( r ) ;
} */
std : : vector < Nd4jLong > outShape ( rank ) ;
for ( int i = 0 ; i < rank ; i + + )
outShape [ i ] = arr . sizeAt ( i ) * reps . at ( i ) ;
return outShape ;
}
//////////////////////////////////////////////////////////////////////////
// evaluate shapeInfo of permuted array
// Evaluates shapeInfo of arr permuted according to dimensions.
// dimensions : array of permuted axes which is handed over to shape::doPermuteShapeInfo
// rank       : has to be equal to the rank of arr, otherwise std::runtime_error is thrown
// The temporary buffer is released, the returned pointer is the one provided by ConstantShapeHelper.
Nd4jLong* ShapeUtils::evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace) {

    if (!arr.nonNull())
        throw std::runtime_error(" ShapeUtils::evalPermShapeInfo static method: wrong arguments in pn/termute method: either array is nullptr! ");

    if (rank != arr.rankOf())
        throw std::runtime_error(" ShapeUtils::evalPermShapeInfo static method: wrong arguments in pn/termute method: rank is not suitable! ");

    // temporary buffer, it receives a copy of shapeInfo of arr and is permuted in place
    Nd4jLong* permutedShapeInfo = nullptr;
    const auto length = shape::shapeInfoLength(rank);
    ALLOCATE(permutedShapeInfo, workspace, length, Nd4jLong);

    memcpy(permutedShapeInfo, arr.getShapeInfo(), shape::shapeInfoByteLength(rank));
    shape::doPermuteShapeInfo(permutedShapeInfo, dimensions);

    ShapeDescriptor descriptor(permutedShapeInfo);
    RELEASE(permutedShapeInfo, workspace);

    return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
}
//////////////////////////////////////////////////////////////////////////
// evaluate shapeInfo of permuted array
// same as evalPermShapeInfo above, but permuted axes are given as Nd4jLong values:
// the first rank values are converted to int and passed on to the int based overload
Nd4jLong* ShapeUtils::evalPermShapeInfo(const Nd4jLong* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace) {

    std::vector<int> intDimensions;
    for (int i = 0; i < rank; ++i)
        intDimensions.push_back(static_cast<int>(dimensions[i]));

    return evalPermShapeInfo(intDimensions.data(), rank, arr, workspace);
}
//////////////////////////////////////////////////////////////////////////
// evaluate shapeInfo of transposed array
// transposition is evaluated as permutation with reversed order of axes: rank-1, rank-2, ..., 0
Nd4jLong* ShapeUtils::evalTranspShapeInfo(const NDArray& arr, nd4j::memory::Workspace* workspace) {

    const int rank = arr.rankOf();

    std::vector<int> reversedAxes(rank);
    std::iota(reversedAxes.rbegin(), reversedAxes.rend(), 0);       // filling from the back gives rank-1, ..., 1, 0

    return evalPermShapeInfo(reversedAxes.data(), rank, arr, workspace);
}
//////////////////////////////////////////////////////////////////////////
bool ShapeUtils : : copyVectorPart ( std : : vector < int > & target , std : : vector < int > & source , int rank , int offset ) {
if ( source . size ( ) < offset + rank )
return false ;
for ( int e = offset ; e < offset + rank ; e + + )
target . push_back ( source [ e ] ) ;
return true ;
}
//////////////////////////////////////////////////////////////////////////
// return new (shorter) sorted dimensions array without dimensions that are present in input vector
std : : vector < int > ShapeUtils : : evalDimsToExclude ( const int rank , const int dimsLen , const int * dimensions ) {
std : : vector < int > newDimensions ;
if ( dimsLen = = 0 ) { // if input vector is empty then return whole shape range
newDimensions . resize ( rank ) ;
std : : iota ( newDimensions . begin ( ) , newDimensions . end ( ) , 0 ) ; // fill with 0, 1, ... rank-1
}
else {
bool isAbsent ;
for ( int i = 0 ; i < rank ; + + i ) {
isAbsent = true ;
for ( int j = 0 ; j < dimsLen ; + + j ) {
int dim = dimensions [ j ] > = 0 ? dimensions [ j ] : dimensions [ j ] + rank ;
if ( i = = dim ) {
isAbsent = false ;
break ;
}
}
if ( isAbsent )
newDimensions . emplace_back ( i ) ;
}
}
return newDimensions ;
}
//////////////////////////////////////////////////////////////////////////
std : : vector < int > ShapeUtils : : evalDimsToExclude ( const int rank , const std : : vector < int > & dimensions ) {
return ShapeUtils : : evalDimsToExclude ( rank , dimensions . size ( ) , dimensions . data ( ) ) ;
}
//////////////////////////////////////////////////////////////////////////
// check whether 2 arrays have mutually broadcastable shapes
// shape comparison starts from the end
bool ShapeUtils : : areShapesBroadcastable ( const NDArray & arr1 , const NDArray & arr2 ) {
return areShapesBroadcastable ( arr1 . getShapeInfo ( ) , arr2 . getShapeInfo ( ) ) ;
}
bool ShapeUtils : : areShapesBroadcastable ( Nd4jLong * shapeInfo1 , Nd4jLong * shapeInfo2 ) {
int minRank = shape : : rank ( shapeInfo1 ) < shape : : rank ( shapeInfo2 ) ? shape : : rank ( shapeInfo1 ) : shape : : rank ( shapeInfo2 ) ;
for ( int i = - 1 ; i > = - minRank ; - - i )
if ( shape : : sizeAt ( shapeInfo1 , i ) ! = shape : : sizeAt ( shapeInfo2 , i ) & & shape : : sizeAt ( shapeInfo1 , i ) ! = 1 & & shape : : sizeAt ( shapeInfo2 , i ) ! = 1 )
return false ;
return true ;
}
bool ShapeUtils : : areShapesBroadcastable ( const std : : vector < Nd4jLong > & shape1 , const std : : vector < Nd4jLong > & shape2 ) {
const auto rank1 = shape1 . size ( ) ;
const auto rank2 = shape2 . size ( ) ;
const int minRank = rank1 < rank2 ? rank1 : rank2 ;
for ( int i = 1 ; i < = minRank ; + + i )
if ( shape1 [ rank1 - i ] ! = shape2 [ rank2 - i ] & & shape1 [ rank1 - i ] ! = 1 & & shape2 [ rank2 - i ] ! = 1 )
return false ;
return true ;
}
//////////////////////////////////////////////////////////////////////////
// check the possibility of broadcast operation, if true then return shapeInfo of resulting array
// if evalMinMax == false the array with larger rank has to be passed as first argument
bool ShapeUtils : : evalBroadcastShapeInfo ( const NDArray & max , const NDArray & min , const bool evalMinMax , Nd4jLong * & resultShapeInfo , nd4j : : memory : : Workspace * workspace ) {
return evalBroadcastShapeInfo ( max . getShapeInfo ( ) , min . getShapeInfo ( ) , evalMinMax , resultShapeInfo , workspace ) ;
}
bool ShapeUtils : : evalBroadcastShapeInfo ( Nd4jLong * max , Nd4jLong * min , const bool evalMinMax , Nd4jLong * & resultShapeInfo , nd4j : : memory : : Workspace * workspace ) {
// check whether broadcast operation is possible for input arrays
if ( ! areShapesBroadcastable ( max , min ) )
return false ;
auto maxShapeInfo = max ; //max.getShapeInfo();
auto minShapeInfo = min ; //min.getShapeInfo();
if ( evalMinMax & & ( shape : : rank ( max ) < shape : : rank ( min ) ) ) {
maxShapeInfo = min ;
minShapeInfo = max ;
}
const auto maxRank = shape : : rank ( maxShapeInfo ) ;
const auto minRank = shape : : rank ( minShapeInfo ) ;
// evaluate shapeInfo for resulting array
if ( resultShapeInfo ! = nullptr )
throw std : : runtime_error ( " std::runtime_error(ShapeUtils::evalBroadcastShapeInfo method: the input pointer on shapeInfo must be empty (=nullptr) ! " ) ;
Nd4jLong * tmpShapeInfo = nullptr ;
ALLOCATE ( tmpShapeInfo , workspace , shape : : shapeInfoLength ( maxRank ) , Nd4jLong ) ;
// FIXME: get rid of memcpy here
memcpy ( tmpShapeInfo , maxShapeInfo , shape : : shapeInfoByteLength ( maxRank ) ) ;
for ( int i = 0 ; i < minRank ; + + i )
if ( maxShapeInfo [ maxRank - i ] < minShapeInfo [ minRank - i ] )
tmpShapeInfo [ maxRank - i ] = minShapeInfo [ minRank - i ] ;
2019-06-15 13:34:34 +02:00
// nullify zero axis
for ( int e = 0 ; e < maxRank ; e + + )
if ( maxShapeInfo [ e + 1 ] = = 0 )
tmpShapeInfo [ e + 1 ] = 0 ;
int delta = maxRank - minRank ;
for ( int e = minRank - 1 ; e > = 0 ; e - - )
if ( minShapeInfo [ e + 1 ] = = 0 )
tmpShapeInfo [ e + 1 + delta ] = 0 ;
2019-06-06 14:21:15 +02:00
ShapeUtils : : updateStridesAndType ( tmpShapeInfo , DataTypeUtils : : pickPairwiseResultType ( maxShapeInfo , minShapeInfo ) , shape : : order ( maxShapeInfo ) ) ;
2019-06-15 13:34:34 +02:00
if ( shape : : isEmpty ( max ) | | shape : : isEmpty ( min ) ) {
ArrayOptions : : setPropertyBit ( tmpShapeInfo , ARRAY_EMPTY ) ;
memset ( shape : : stride ( tmpShapeInfo ) , 0 , shape : : rank ( tmpShapeInfo ) * sizeof ( Nd4jLong ) ) ;
}
2019-06-06 14:21:15 +02:00
ShapeDescriptor descriptor ( tmpShapeInfo ) ;
RELEASE ( tmpShapeInfo , workspace ) ;
resultShapeInfo = ConstantShapeHelper : : getInstance ( ) - > bufferForShapeInfo ( descriptor ) . primaryAsT < Nd4jLong > ( ) ;
return true ;
}
//////////////////////////////////////////////////////////////////////////
// check the possibility of broadcast operation for set of arrays, if true then return resulting broadcasted shapeInfo
bool ShapeUtils : : evalCommonBroadcastShapeInfo ( const std : : vector < const NDArray * > & arrays , Nd4jLong * & resultShapeInfo , memory : : Workspace * workspace ) {
if ( resultShapeInfo ! = nullptr )
throw std : : runtime_error ( " ShapeUtils::evalCommonBroadcastShapeInfo method: the input pointer on shapeInfo must be empty (=nullptr) ! " ) ;
int size = arrays . size ( ) ;
int maxRank = arrays [ size - 1 ] - > rankOf ( ) ;
for ( int i = 0 ; i < size - 1 ; + + i ) {
if ( arrays [ i ] - > rankOf ( ) > maxRank )
maxRank = arrays [ i ] - > rankOf ( ) ;
for ( int j = i + 1 ; j < size ; + + j )
if ( ! areShapesBroadcastable ( * arrays [ i ] , * arrays [ j ] ) )
return false ;
}
Nd4jLong * tmpShapeInfo = nullptr ;
ALLOCATE ( tmpShapeInfo , workspace , shape : : shapeInfoLength ( maxRank ) , Nd4jLong ) ;
memset ( tmpShapeInfo , 0 , shape : : shapeInfoByteLength ( maxRank ) ) ;
tmpShapeInfo [ 0 ] = maxRank ;
for ( const auto & item : arrays ) {
for ( int i = - 1 ; i > = - item - > rankOf ( ) ; - - i )
if ( tmpShapeInfo [ i + 1 + maxRank ] < item - > sizeAt ( i ) )
tmpShapeInfo [ i + 1 + maxRank ] = item - > sizeAt ( i ) ;
}
shape : : updateStrides ( tmpShapeInfo , arrays [ 0 ] - > ordering ( ) ) ;
ArrayOptions : : setDataType ( tmpShapeInfo , arrays [ 0 ] - > dataType ( ) ) ;
ShapeDescriptor descriptor ( tmpShapeInfo ) ;
RELEASE ( tmpShapeInfo , workspace ) ;
resultShapeInfo = ConstantShapeHelper : : getInstance ( ) - > createShapeInfo ( descriptor ) ;
return true ;
}
//////////////////////////////////////////////////////////////////////////
// return sorted vector of dimensions of array with larger dimensions number along which two input arrays have same shape
// the array with larger dimensions number has to be passed as first argument
std : : vector < int > ShapeUtils : : getDimsWithSameShape ( const NDArray & max , const NDArray & min ) {
std : : vector < int > result ;
auto maxShapeInfo = max . getShapeInfo ( ) ;
auto minShapeInfo = min . getShapeInfo ( ) ;
int maxRank = maxShapeInfo [ 0 ] ;
int minRank = minShapeInfo [ 0 ] ;
for ( int i = 1 ; i < = minRank ; + + i )
if ( minShapeInfo [ i ] = = maxShapeInfo [ maxRank - minRank + i ] )
result . emplace_back ( maxRank - minRank + i - 1 ) ;
return result ;
}
//////////////////////////////////////////////////////////////////////////
// evaluate shapeInfo for resulting array from tile operation
// Evaluates shapeInfo of the array resulting from tile operation applied to arr.
// reps : numbers of repetitions, they are aligned with the last axes of arr; if reps is longer than the rank of arr
//        the shape of arr is extended with unities on the left-hand side up to the length of reps
// Throws std::runtime_error if reps contains zero.
// The temporary buffer is released, the returned pointer is the one provided by ConstantShapeHelper.
Nd4jLong* ShapeUtils::evalTileShapeInfo(const NDArray& arr, const std::vector<Nd4jLong>& reps, nd4j::memory::Workspace* workspace) {

    // check whether reps contains at least one zero (then throw exception);
    // elements are tested one by one, since a product of all of them kept in int may be truncated to zero
    const bool hasZero = std::any_of(reps.begin(), reps.end(), [](const Nd4jLong rep) { return rep == 0; });
    if (hasZero)
        throw std::runtime_error(" NDArray::tile method: one of the elements in reps array is zero ! ");

    const int dim = reps.size();
    const int rankOld = arr.rankOf();
    const int diff = rankOld - dim;

    // evaluate new shapeInfo
    Nd4jLong* newShapeInfo = nullptr;

    if (diff < 0) {
        ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(dim), Nd4jLong);
        newShapeInfo[0] = dim;                      // set new rank
        for (int i = 1; i <= -diff; ++i)
            newShapeInfo[i] = 1;                    // set unities to be new dimensions at left-hand side of newShapeInfo shape place
        memcpy(newShapeInfo + 1 - diff, arr.getShapeInfo() + 1, rankOld * sizeof(Nd4jLong));      // copy old dimensions to the right-hand side of newShapeInfo shape place
        for (int i = 1; i <= dim; ++i)
            newShapeInfo[i] *= reps[i - 1];         // set new shape by multiplying old dimensions by corresponding numbers from reps
    }
    else {
        ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(rankOld), Nd4jLong);
        memcpy(newShapeInfo, arr.getShapeInfo(), shape::shapeInfoByteLength(rankOld));      // copy all elements of _shapeInfo to newShapeInfo
        for (int i = 1; i <= dim; ++i)
            newShapeInfo[rankOld + 1 - i] *= reps[dim - i];     // set new shape by multiplying old dimensions by corresponding numbers from reps
    }

    shape::updateStrides(newShapeInfo, arr.ordering());
    ArrayOptions::setDataType(newShapeInfo, arr.dataType());

    ShapeDescriptor descriptor(newShapeInfo);
    RELEASE(newShapeInfo, workspace);

    return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
}
std : : vector < Nd4jLong > ShapeUtils : : pullShapeFromShapeInfo ( Nd4jLong * shapeInfo ) {
std : : vector < Nd4jLong > shape ( shape : : rank ( shapeInfo ) ) ;
int shapeSize = shape . size ( ) ;
for ( int e = 0 ; e < shapeSize ; e + + )
shape [ e ] = shape : : shapeOf ( shapeInfo ) [ e ] ;
return shape ;
}
std : : string ShapeUtils : : shapeAsString ( const NDArray * array ) {
std : : string result ;
result . append ( " [ " ) ;
for ( int e = 0 ; e < array - > rankOf ( ) ; e + + ) {
result + = flatbuffers : : NumToString ( array - > sizeAt ( e ) ) ;
if ( e < array - > rankOf ( ) - 1 )
result . append ( " , " ) ;
}
result . append ( " ] " ) ;
return result ;
}
std : : string ShapeUtils : : strideAsString ( const NDArray * array ) {
std : : string result ;
auto shapeBuffer = array - > getShapeInfo ( ) ; //Nd4jLong*
int rank = ( int ) * shapeBuffer ;
result . append ( " [ " ) ;
for ( int e = 0 ; e < rank ; e + + ) {
if ( e > 0 )
result . append ( " , " ) ;
Nd4jLong stride = * ( shapeBuffer + rank + 1 + e ) ;
result + = flatbuffers : : NumToString ( stride ) ;
}
result . append ( " ] " ) ;
return result ;
}
std : : string ShapeUtils : : shapeAsString ( const std : : vector < Nd4jLong > & shape ) {
std : : string result ;
result . append ( " [ " ) ;
for ( int e = 0 ; e < shape . size ( ) ; e + + ) {
result + = flatbuffers : : NumToString ( shape . at ( e ) ) ;
if ( e < shape . size ( ) - 1 )
result . append ( " , " ) ;
}
result . append ( " ] " ) ;
return result ;
}
std : : string ShapeUtils : : shapeAsString ( const Nd4jLong * shapeInfo ) {
if ( ! shapeInfo )
throw std : : runtime_error ( " ShapeUtils::shapeAsString method: input shapeInfo must not be nullptr ! " ) ;
std : : string result ;
result . append ( " [ " ) ;
for ( int e = 0 ; e < shapeInfo [ 0 ] ; e + + ) {
result + = flatbuffers : : NumToString ( shapeInfo [ e + 1 ] ) ;
if ( e < shapeInfo [ 0 ] - 1 )
result . append ( " , " ) ;
}
result . append ( " ] " ) ;
return result ;
}
std : : string ShapeUtils : : shapeAsString ( const int rank , const Nd4jLong * shapeInfo ) {
if ( ! shapeInfo )
throw std : : runtime_error ( " ShapeUtils::shapeAsString method: input shapeInfo must not be nullptr ! " ) ;
std : : string result ;
result . append ( " [ " ) ;
for ( int e = 0 ; e < rank ; e + + ) {
result + = flatbuffers : : NumToString ( shapeInfo [ e ] ) ;
if ( e < rank - 1 )
result . append ( " , " ) ;
}
result . append ( " ] " ) ;
return result ;
}
//////////////////////////////////////////////////////////////////////////
// evaluate shapeInfo for diagonal array which is made using input arr elements as diagonal
// Evaluates shapeInfo of the diagonal array built from the input array:
//  - input is vector or scalar -> rank 2, both dimensions are equal to the length of the input
//  - otherwise                 -> rank is doubled, the shape of the input is written twice one after another
// Strides and type are set by updateStridesAndType using the input shapeInfo and its order.
Nd4jLong* ShapeUtils::evalDiagShapeInfo(const Nd4jLong* shapeInfoConst, nd4j::memory::Workspace* workspace) {

    auto inShapeInfo = const_cast<Nd4jLong*>(shapeInfoConst);

    const auto rank = shape::rank(inShapeInfo);
    const bool isFlat = shape::isVector(inShapeInfo) || shape::isScalar(inShapeInfo);
    const auto diagRank = isFlat ? 2 : 2 * rank;

    Nd4jLong* diagShapeInfo = nullptr;
    ALLOCATE(diagShapeInfo, workspace, shape::shapeInfoLength(diagRank), Nd4jLong);
    diagShapeInfo[0] = diagRank;

    if (isFlat) {
        diagShapeInfo[1] = diagShapeInfo[2] = shape::length(inShapeInfo);
    }
    else {
        for (int i = 1; i <= rank; ++i)
            diagShapeInfo[i] = diagShapeInfo[i + rank] = inShapeInfo[i];
    }

    ShapeUtils::updateStridesAndType(diagShapeInfo, inShapeInfo, shape::order(inShapeInfo));

    auto cachedShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(diagShapeInfo);
    RELEASE(diagShapeInfo, workspace);

    return cachedShapeInfo;
}
std : : vector < int > ShapeUtils : : evalBroadcastBackwardAxis ( const Nd4jLong * operandShapeInfo , const Nd4jLong * resultShapeInfo ) {
// rRank >= oRank always !!
const auto oRank = shape : : rank ( operandShapeInfo ) ;
const auto rRank = shape : : rank ( resultShapeInfo ) ;
const auto diff = rRank - oRank ;
std : : vector < int > axis ;
for ( int i = 0 ; i < rRank ; + + i )
if ( i < diff | | shape : : sizeAt ( operandShapeInfo , i - diff ) ! = shape : : sizeAt ( resultShapeInfo , i ) )
axis . push_back ( i ) ;
return axis ;
}
////////////////////////////////////////////////////////////////////////////////
// evaluate shapeInfo of the product of two arrays described by theFirstShape and theSecondShape
//  - both inputs are copied first, and the copies are transposed in place when the matching flag is set,
//    so the caller's buffers are never modified
//  - matrix x rank-1 vector yields a rank-1 shapeInfo (createVectorShapeInfo) of length tmpA[1]
//  - every other supported combination yields a rank-2 shapeInfo built with order 'f'
//  - throws std::runtime_error when the (possibly transposed) shapes match none of the supported cases;
//    previously the two dimensions were left uninitialized in that situation
// all temporary buffers are released before returning or throwing
Nd4jLong* ShapeUtils::matrixProductShape(Nd4jLong* theFirstShape, Nd4jLong* theSecondShape, bool shouldTranspondFirst, bool shouldTranspondSecond, nd4j::DataType dtype, nd4j::memory::Workspace* workspace) {

    // scratch buffer that receives the two output dimensions
    Nd4jLong* dims = nullptr;
    ALLOCATE(dims, workspace, shape::shapeInfoLength(2), Nd4jLong);

    Nd4jLong* tmpA = ShapeBuilders::copyShapeInfo(theFirstShape, true, workspace);
    Nd4jLong* tmpB = ShapeBuilders::copyShapeInfo(theSecondShape, true, workspace);

    if (shouldTranspondFirst)
        shape::transposeInplace(tmpA);

    if (shouldTranspondSecond)
        shape::transposeInplace(tmpB);

    bool supported = true;         // becomes false when no case below applies
    bool vectorResult = false;     // true for matrix x rank-1 vector, where the output is rank 1

    if (shape::rank(tmpA) == 1 && shape::isMatrix(tmpB)) {
        // special case here
        dims[0] = 1;
        dims[1] = tmpB[2];
    }
    else if (shape::isScalar(tmpA) && shape::isScalar(tmpB)) {
        // just scalar vs scalar
        dims[0] = 1;
        dims[1] = 1;
    }
    else if (shape::isMatrix(tmpA) && shape::isVector(tmpB)) {
        // gemv case
        if (shape::rank(tmpB) == 2) {
            dims[0] = tmpA[1];
            dims[1] = tmpB[2];
        }
        else {
            // we have new 1D shape here
            vectorResult = true;
        }
    }
    else if ((shape::isMatrix(tmpA) && shape::isMatrix(tmpB)) ||
             (shape::isVector(tmpA) && shape::isMatrix(tmpB)) ||
             (shape::isColumnVector(tmpA) && shape::isVector(tmpB))) {
        // gemm case
        dims[0] = tmpA[1];
        dims[1] = tmpB[2];
    }
    else if ((shape::isVector(tmpA) && shape::isScalar(tmpB)) ||
             (shape::isScalar(tmpA) && shape::isVector(tmpB))) {
        // element-wise, the length is kept as Nd4jLong (no narrowing to int)
        dims[0] = 1;
        dims[1] = nd4j::math::nd4j_max<Nd4jLong>(shape::length(tmpA), shape::length(tmpB));
    }
    else if (shape::isRowVector(tmpA) && (shape::isRowVector(tmpB) || shape::isColumnVector(tmpB))) {
        // dot case
        dims[0] = 1;
        dims[1] = 1;
    }
    else {
        supported = false;
    }

    Nd4jLong* newShape = nullptr;
    if (supported)
        newShape = vectorResult ? ShapeBuilders::createVectorShapeInfo(dtype, tmpA[1], workspace)
                                : ShapeBuilders::createShapeInfo(dtype, 'f', 2, dims, workspace);

    // single clean-up point shared by all paths
    RELEASE(dims, workspace);
    RELEASE(tmpA, workspace);
    RELEASE(tmpB, workspace);

    if (!supported)
        throw std::runtime_error("ShapeUtils::matrixProductShape method: the input shapes are not suitable for matrix product !");

    return newShape;
}
////////////////////////////////////////////////////////////////////////////////
std : : vector < int > ShapeUtils : : evalPermutFromTo ( const std : : vector < Nd4jLong > & shapeFrom , const std : : vector < Nd4jLong > & shapeTo ) {
auto rank = shapeFrom . size ( ) ;
if ( rank ! = shapeTo . size ( ) )
throw std : : runtime_error ( " ShapeUtils::evalPermutFromTo static method: the input shapes are not suitable for mutual permutation ! " ) ;
if ( std : : equal ( begin ( shapeFrom ) , end ( shapeFrom ) , begin ( shapeTo ) ) ) // if shapes are identical (permutation is unnecessary) then return empty vector
return std : : vector < int > ( ) ;
std : : vector < int > permutation ( rank , - 2 ) ; // vector to be returned
std : : vector < Nd4jLong > shapeTo2 ( shapeTo ) ; // make copy of const vector since we will change the content of shapeTo
for ( int i = 0 ; i < rank ; + + i )
for ( int j = 0 ; j < rank ; + + j )
if ( shapeFrom [ i ] = = shapeTo2 [ j ] ) {
permutation [ j ] = i ;
shapeTo2 [ j ] = - 2 ; // mark coincidence as -2 in order to not account index of shapeTo twice
break ;
}
if ( std : : find ( begin ( permutation ) , end ( permutation ) , - 2 ) ! = end ( permutation ) ) // if -2 is still present in vector then permutation is impossible
throw std : : runtime_error ( " ShapeUtils::evalPermutFromTo static method: the input shapes are not suitable for mutual permutation ! " ) ;
return permutation ;
}
////////////////////////////////////////////////////////////////////////////////
std : : vector < Nd4jLong > ShapeUtils : : composeShapeUsingDimsAndIdx ( const std : : vector < int > & dimsAndIdx ) {
auto size = dimsAndIdx . size ( ) ;
if ( size % 2 ! = 0 )
throw std : : runtime_error ( " ShapeUtils::composeShapeUsingDimsAndIdx static method: the size of input vector must be even ! " ) ;
size / = 2 ;
std : : vector < Nd4jLong > shape ( size ) ;
int index ;
for ( int i = 0 ; i < size ; + + i ) {
index = dimsAndIdx [ i + size ] ;
if ( index > size - 1 )
throw std : : runtime_error ( " ShapeUtils::composeShapeUsingDimsAndIdx static method: input index is too large ! " ) ;
shape [ index ] = dimsAndIdx [ i ] ;
}
return shape ;
}
////////////////////////////////////////////////////////////////////////////////
std : : vector < Nd4jLong > ShapeUtils : : evalShapeForMatmul ( const Nd4jLong * xShapeInfo , const Nd4jLong * yShapeInfo , const bool transX , const bool transY ) {
const auto xRank = xShapeInfo [ 0 ] ;
const auto yRank = yShapeInfo [ 0 ] ;
const Nd4jLong x0Dim = transX ? xShapeInfo [ xRank ] : xShapeInfo [ xRank - 1 ] ;
const Nd4jLong y0Dim = transY ? yShapeInfo [ yRank ] : yShapeInfo [ yRank - 1 ] ;
const Nd4jLong x1Dim = transX ? xShapeInfo [ xRank - 1 ] : xShapeInfo [ xRank ] ;
const Nd4jLong y1Dim = transY ? yShapeInfo [ yRank - 1 ] : yShapeInfo [ yRank ] ;
if ( xRank = = 1 & & yRank = = 1 ) { // dot case, output is scalar
if ( xShapeInfo [ 1 ] ! = yShapeInfo [ 1 ] ) {
nd4j_printf ( " ShapeUtils::evalShapeForMatmul method: since input arrays are vectors they must have the same length, but got x length = %i, y length = %i ! " , xShapeInfo [ 1 ] , yShapeInfo [ 1 ] ) ;
throw std : : invalid_argument ( " " ) ;
}
2019-06-15 13:34:34 +02:00
return std : : vector < Nd4jLong > ( { } ) ;
2019-06-06 14:21:15 +02:00
}
if ( xRank = = 1 & & yRank = = 2 ) { // vector x matrix, i.e. [4] x [4,5] = [5], output is vector
if ( xShapeInfo [ 1 ] ! = y0Dim ) {
nd4j_printf ( " ShapeUtils::evalShapeForMatmul method: input arrays have inconsistent shapes for vector-matrix product: x %s, y %s ! " , ShapeUtils : : shapeAsString ( xShapeInfo ) . c_str ( ) , ShapeUtils : : shapeAsString ( yShapeInfo ) . c_str ( ) ) ;
throw std : : invalid_argument ( " " ) ;
}
return std : : vector < Nd4jLong > ( { y1Dim } ) ;
}
if ( xRank = = 2 & & yRank = = 1 ) { // matrix x vector , i.e. [4,5] x [5] = [4], output is vector
if ( x1Dim ! = yShapeInfo [ 1 ] ) {
nd4j_printf ( " ShapeUtils::evalShapeForMatmul method: input arrays have inconsistent shapes for vector-matrix product: x %s, y %s ! " , ShapeUtils : : shapeAsString ( xShapeInfo ) . c_str ( ) , ShapeUtils : : shapeAsString ( yShapeInfo ) . c_str ( ) ) ;
throw std : : invalid_argument ( " " ) ;
}
return std : : vector < Nd4jLong > ( { x0Dim } ) ;
}
// rest cases - usual 2Dx2D or batched mmul
if ( xRank ! = yRank ) {
nd4j_printf ( " ShapeUtils::evalShapeForMatmul static method: the ranks of arrays must be the same, but got xRank = %i and yRank = %i ! \n " , xRank , yRank ) ;
throw std : : invalid_argument ( " " ) ;
}
if ( x1Dim ! = y0Dim ) {
nd4j_printf ( " ShapeUtils::evalShapeForMatmul static method: input shapes are inconsistent: xDim %i != yDim %i \n " , x1Dim , y0Dim ) ;
throw std : : invalid_argument ( " " ) ;
}
for ( int i = 0 ; i < xRank - 2 ; + + i )
if ( xShapeInfo [ i + 1 ] ! = yShapeInfo [ i + 1 ] ) {
nd4j_printf ( " ShapeUtils::evalShapeForMatmul static method: input shapes are inconsistent: xShape = %s, yShape = %s ! \n " , ShapeUtils : : shapeAsString ( xShapeInfo ) . c_str ( ) , ShapeUtils : : shapeAsString ( yShapeInfo ) . c_str ( ) ) ;
throw std : : invalid_argument ( " " ) ;
}
std : : vector < Nd4jLong > cShape ( xRank ) ;
// copy batch part of shape (if present)
for ( int i = 0 ; i < xRank - 2 ; + + i )
cShape [ i ] = xShapeInfo [ i + 1 ] ;
// copy rest part of shape (two dims: multiplication part)
cShape [ xRank - 2 ] = x0Dim ;
cShape [ xRank - 1 ] = y1Dim ;
return cShape ;
}
////////////////////////////////////////////////////////////////////////////////
// evaluate number of sub-arrays: the product of the sizes of all dimensions listed in dimsToExclude
// returns 1 when dimsToExclude is empty or holds as many entries as the array has dimensions,
// in both cases there is only one sub-array and it coincides with the whole array
Nd4jLong ShapeUtils::getNumOfSubArrs(const Nd4jLong* shapeInfo, const std::vector<int>& dimsToExclude) {

    const auto numExcluded = dimsToExclude.size();

    if (numExcluded == shape::rank(shapeInfo) || numExcluded == 0)
        return 1;

    // shapeInfo[dim + 1] is the size of dimension dim (element 0 holds the rank)
    return std::accumulate(dimsToExclude.begin(), dimsToExclude.end(), Nd4jLong(1),
                           [shapeInfo](Nd4jLong product, int dim) { return product * shapeInfo[dim + 1]; });
}
////////////////////////////////////////////////////////////////////////////////
void ShapeUtils : : evalIdxRangesForSubArr ( const Nd4jLong subArrIdx , const Nd4jLong * shapeInfo , const std : : vector < int > & dimsToExclude , Nd4jLong * idxRanges ) {
const auto rank = shape : : rank ( shapeInfo ) ;
const auto subArrRank = static_cast < int > ( dimsToExclude . size ( ) ) ;
if ( subArrRank > rank )
throw std : : invalid_argument ( " ShapeUtils::evalIdxRangesForSubArr static method: dimsToExclude is empty or has size > rank of array ! " ) ;
if ( subArrRank = = 0 ) { // means whole array
memset ( idxRanges , 0 , 2 * rank * sizeof ( Nd4jLong ) ) ;
return ;
}
std : : vector < Nd4jLong > shapeOfSubArr ( subArrRank ) , indexes ( subArrRank ) ;
for ( int i = 0 ; i < subArrRank ; + + i )
shapeOfSubArr [ i ] = shapeInfo [ dimsToExclude [ i ] + 1 ] ;
shape : : index2coords ( subArrRank , shapeOfSubArr . data ( ) , subArrIdx , indexes . data ( ) ) ;
memset ( idxRanges , 0 , 2 * rank * sizeof ( Nd4jLong ) ) ;
for ( int i = 0 ; i < subArrRank ; + + i ) {
int currIdx = 2 * dimsToExclude [ i ] ;
idxRanges [ currIdx ] = indexes [ i ] ;
idxRanges [ currIdx + 1 ] = indexes [ i ] + 1 ;
}
}
////////////////////////////////////////////////////////////////////////////////
std : : vector < Nd4jLong > ShapeUtils : : evalDimsWithoutUnities ( const Nd4jLong * shapeInfo ) {
std : : vector < Nd4jLong > result ;
for ( int i = 1 ; i < = shapeInfo [ 0 ] ; + + i )
if ( shapeInfo [ i ] ! = 1 )
result . push_back ( shapeInfo [ i ] ) ;
return result ;
}
////////////////////////////////////////////////////////////////////////////////
void ShapeUtils : : updateStridesAndType ( Nd4jLong * dest , const Nd4jLong * source , const char order ) {
shape : : updateStrides ( dest , order ) ;
ArrayOptions : : copyDataType ( dest , source ) ;
}
////////////////////////////////////////////////////////////////////////////////
void ShapeUtils : : updateStridesAndType ( Nd4jLong * dest , const DataType dtype , const char order ) {
shape : : updateStrides ( dest , order ) ;
ArrayOptions : : setDataType ( dest , dtype ) ;
}
////////////////////////////////////////////////////////////////////////////////
std : : vector < int > ShapeUtils : : tadAxesForSimpleBroadcast ( const NDArray & max , const NDArray & min ) {
const int maxRank = max . rankOf ( ) ;
const int minRank = min . rankOf ( ) ;
const int diff = maxRank - minRank ;
Nd4jLong numOfMinTads ( 1 ) , numOfMaxTads ( 1 ) ;
std : : vector < int > maxTadDims ;
for ( int i = 0 ; i < minRank ; + + i ) {
if ( min . sizeAt ( i ) = = max . sizeAt ( diff + i ) )
maxTadDims . push_back ( diff + i ) ;
else {
numOfMinTads * = min . sizeAt ( i ) ;
numOfMaxTads * = max . sizeAt ( i ) ;
}
}
if ( min . lengthOf ( ) > max . lengthOf ( ) ) { // in this case tad is max array
for ( int i = 0 ; i < diff ; + + i )
numOfMaxTads * = max . sizeAt ( i ) ;
return numOfMaxTads = = 1 ? maxTadDims : std : : vector < int > ( ) ;
}
return numOfMinTads = = 1 ? maxTadDims : std : : vector < int > ( ) ;
}
// evaluate the size in bytes of the header of a string buffer holding numStrings strings:
// one Nd4jLong entry per string plus one extra entry (we store +1 offset)
Nd4jLong ShapeUtils::stringBufferHeaderRequirements(Nd4jLong numStrings) {

    const auto numEntries = numStrings + 1;
    return numEntries * sizeof(Nd4jLong);
}
}