2019-06-06 14:21:15 +02:00
/*******************************************************************************
* Copyright ( c ) 2015 - 2018 Skymind , Inc .
*
* This program and the accompanying materials are made available under the
* terms of the Apache License , Version 2.0 which is available at
* https : //www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing , software
* distributed under the License is distributed on an " AS IS " BASIS , WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND , either express or implied . See the
* License for the specific language governing permissions and limitations
* under the License .
*
* SPDX - License - Identifier : Apache - 2.0
* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
//
2019-09-11 19:12:09 +02:00
// @author Yurii Shyrma (iuriish@yahoo.com)
2019-06-06 14:21:15 +02:00
//
# include <algorithm>
# include <helpers/ShapeUtils.h>
# include <climits>
# include <numeric>
# include <algorithm>
# include <set>
# include <flatbuffers/util.h>
namespace nd4j {
//////////////////////////////////////////////////////////////////////////
// evaluate shape for array resulting from tensorDot operation, also evaluate shapes and dimensions permutations for transposition of two input arrays
std : : vector < Nd4jLong > ShapeUtils : : evalShapeForTensorDot ( const Nd4jLong * aShapeInfo , const Nd4jLong * bShapeInfo , std : : vector < int > axesA , std : : vector < int > axesB , std : : vector < int > & permutAt , std : : vector < int > & permutBt , std : : vector < Nd4jLong > & shapeAt , std : : vector < Nd4jLong > & shapeBt ) {
int axeAsize = ( int ) axesA . size ( ) ;
int axeBsize = ( int ) axesB . size ( ) ;
int aRank = aShapeInfo [ 0 ] ;
int bRank = bShapeInfo [ 0 ] ;
if ( axeAsize ! = axeBsize )
throw std : : runtime_error ( " ShapeUtils::evalShapeForTensorDot method: the numbers of a axes and b axes to make dot product along must have identical values ! " ) ;
if ( axeAsize > aRank | | axeBsize > bRank )
throw std : : runtime_error ( " ShapeUtils::evalShapeForTensorDot method: the length of vector of a or b axes is larger than array rank ! " ) ;
// axes validation
for ( int i = 0 ; i < axeBsize ; i + + ) {
if ( axesA [ i ] < 0 )
axesA [ i ] + = aRank ;
if ( axesB [ i ] < 0 )
axesB [ i ] + = bRank ;
if ( aShapeInfo [ axesA [ i ] + 1 ] ! = bShapeInfo [ axesB [ i ] + 1 ] )
throw std : : runtime_error ( " ShapeUtils::evalShapeForTensorDot method: the dimensions at given axes for both input arrays must be the same ! " ) ;
}
// check whether axesA and axesB contain only unique numbers
std : : set < Nd4jLong > uniqueElems ( axesA . begin ( ) , axesA . end ( ) ) ;
if ( ( int ) uniqueElems . size ( ) ! = axeAsize )
throw std : : runtime_error ( " ShapeUtils::evalShapeForTensorDot method: the vector of a axes contains duplicates ! " ) ;
uniqueElems . clear ( ) ;
uniqueElems = std : : set < Nd4jLong > ( axesB . begin ( ) , axesB . end ( ) ) ;
if ( ( int ) uniqueElems . size ( ) ! = axeBsize )
throw std : : runtime_error ( " ShapeUtils::evalShapeForTensorDot method: the vector of b axes contains duplicates ! " ) ;
std : : vector < int > list_A , list_B ;
for ( int i = 0 ; i < aRank ; i + + )
if ( std : : find ( axesA . begin ( ) , axesA . end ( ) , i ) = = axesA . end ( ) )
list_A . emplace_back ( i ) ;
for ( int i = 0 ; i < bRank ; i + + )
if ( std : : find ( axesB . begin ( ) , axesB . end ( ) , i ) = = axesB . end ( ) )
list_B . emplace_back ( i ) ;
permutAt = list_A ;
permutAt . insert ( permutAt . end ( ) , axesA . begin ( ) , axesA . end ( ) ) ;
permutBt = axesB ;
permutBt . insert ( permutBt . end ( ) , list_B . begin ( ) , list_B . end ( ) ) ;
2019-08-10 08:14:18 +02:00
Nd4jLong n2 = 1 ;
2019-06-06 14:21:15 +02:00
for ( int i = 0 ; i < axeAsize ; i + + )
n2 * = aShapeInfo [ axesA [ i ] + 1 ] ;
shapeAt = { - 1 , n2 } ;
std : : vector < Nd4jLong > oldShapeA ;
oldShapeA . resize ( list_A . size ( ) ) ;
for ( int i = 0 ; i < oldShapeA . size ( ) ; + + i )
oldShapeA [ i ] = aShapeInfo [ list_A [ i ] + 1 ] ;
2019-08-10 08:14:18 +02:00
Nd4jLong n3 = 1 ;
2019-06-06 14:21:15 +02:00
for ( int i = 0 ; i < axeBsize ; i + + )
n3 * = bShapeInfo [ axesB [ i ] + 1 ] ;
shapeBt = { n3 , - 1 } ;
std : : vector < Nd4jLong > oldShapeB ;
oldShapeB . resize ( list_B . size ( ) ) ;
for ( int i = 0 ; i < oldShapeB . size ( ) ; i + + )
oldShapeB [ i ] = bShapeInfo [ list_B [ i ] + 1 ] ;
std : : vector < Nd4jLong > aPlusB ( oldShapeA ) ;
aPlusB . insert ( aPlusB . end ( ) , oldShapeB . begin ( ) , oldShapeB . end ( ) ) ;
return aPlusB ;
}
//////////////////////////////////////////////////////////////////////////
std : : vector < Nd4jLong > ShapeUtils : : evalShapeForTensorDot ( const NDArray * a , const NDArray * b , const std : : vector < int > & axesA , const std : : vector < int > & axesB , std : : vector < int > & permutAt , std : : vector < int > & permutBt , std : : vector < Nd4jLong > & shapeAt , std : : vector < Nd4jLong > & shapeBt ) {
return evalShapeForTensorDot ( a - > getShapeInfo ( ) , b - > getShapeInfo ( ) , axesA , axesB , permutAt , permutBt , shapeAt , shapeBt ) ;
}
2019-06-15 13:34:34 +02:00
//////////////////////////////////////////////////////////////////////////
// Evaluate the output shape of a reduce operation when the input shape is empty.
//
// order         - ordering passed to ShapeBuilders::createShapeInfo for the result
// dimsToExclude - dimensions being reduced. When it is neither empty nor covering every dimension,
//                 it is validated (and sorted in place) by shape::checkDimensions().
// shapeInfo     - shapeInfo of the (empty) input
// dataType      - data type of the result
// keepDims      - keep reduced dimensions as size-1 dimensions
// workspace     - workspace used for the temporary shapeInfo buffer
// Returns a pointer into the ConstantShapeHelper cache; the temporary buffer is released here.
Nd4jLong* ShapeUtils::evalReduceShapeInfoEmpty(const char order, std::vector<int>& dimsToExclude, const Nd4jLong* shapeInfo, const nd4j::DataType dataType, const bool keepDims, nd4j::memory::Workspace* workspace) {

    if (dimsToExclude.empty()) {     // return copy of input shape
        Nd4jLong* outShapeInfo = ShapeBuilders::copyShapeInfoAndType(shapeInfo, dataType, true, workspace);
        ShapeDescriptor descriptor(outShapeInfo, dataType);
        RELEASE(outShapeInfo, workspace);
        return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
    }

    const int rank = shape::rank(shapeInfo);
    Nd4jLong* outShapeInfo = nullptr;

    if (static_cast<int>(dimsToExclude.size()) == rank) {     // return scalar or shape filled with unities
        if (!keepDims)
            outShapeInfo = ShapeBuilders::createScalarShapeInfo(dataType, workspace);
        else
            outShapeInfo = ShapeBuilders::createShapeInfo(dataType, order, std::vector<Nd4jLong>(rank, 1), workspace);
    }
    else {
        shape::checkDimensions(rank, dimsToExclude);

        std::vector<Nd4jLong> outShape;

        if (keepDims) {
            outShape.assign(shapeInfo + 1, shapeInfo + 1 + rank);
            for (const auto& dim : dimsToExclude)
                outShape[dim] = 1;
        }
        else {
            // dimsToExclude is sorted after shape::checkDimensions(), so a single forward cursor j
            // walks the excluded dims in step with i (signed int index: no uint / sign-compare mix)
            std::size_t j = 0;
            for (int i = 0; i < rank; ++i) {
                if (j < dimsToExclude.size() && i == dimsToExclude[j])
                    ++j;
                else
                    outShape.emplace_back(shapeInfo[i + 1]);
            }
        }

        outShapeInfo = ShapeBuilders::createShapeInfo(dataType, order, outShape, workspace);
    }

    ShapeDescriptor descriptor(outShapeInfo, dataType);
    RELEASE(outShapeInfo, workspace);
    return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
}
// Reduce-shape evaluation for an NDArray where the result keeps the array's own data type.
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const NDArray& arr, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) {
    const auto resultType = arr.dataType();
    return evalReduceShapeInfo(order, dimsToExclude, arr, resultType, keepDims, supportOldShapes, workspace);
}
2019-06-15 13:34:34 +02:00
// Reduce-shape evaluation for a raw shapeInfo where the result keeps the data type recorded in shapeInfo.
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const Nd4jLong* shapeInfo, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) {
    const auto resultType = ArrayOptions::dataType(shapeInfo);
    return evalReduceShapeInfo(order, dimsToExclude, shapeInfo, resultType, keepDims, supportOldShapes, workspace);
}
//////////////////////////////////////////////////////////////////////////
2019-06-15 13:34:34 +02:00
// Reduce-shape evaluation for an NDArray with an explicit result data type; forwards the array's shapeInfo.
Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const NDArray& arr, const nd4j::DataType dataType, const bool keepDims, const bool supportOldShapes, nd4j::memory::Workspace* workspace) {
    auto arrShapeInfo = arr.getShapeInfo();
    return evalReduceShapeInfo(order, dimsToExclude, arrShapeInfo, dataType, keepDims, supportOldShapes, workspace);
}
//////////////////////////////////////////////////////////////////////////
// evaluate shape resulting from reduce operation
2019-06-15 13:34:34 +02:00
Nd4jLong * ShapeUtils : : evalReduceShapeInfo ( const char order , std : : vector < int > & dimsToExclude , const Nd4jLong * shapeInfo , const nd4j : : DataType dataType , const bool keepDims , const bool supportOldShapes , nd4j : : memory : : Workspace * workspace ) {
if ( ArrayOptions : : arrayType ( shapeInfo ) = = ArrayType : : EMPTY )
return ShapeUtils : : evalReduceShapeInfoEmpty ( order , dimsToExclude , shapeInfo , dataType , keepDims , workspace ) ;
2019-06-06 14:21:15 +02:00
Nd4jLong * newShapeInfo = nullptr ;
int rank = shape : : rank ( const_cast < Nd4jLong * > ( shapeInfo ) ) ;
2019-06-15 13:34:34 +02:00
if ( dimsToExclude . size ( ) = = 0 ) { // return scalar or array with len=1 in this case
2019-06-06 14:21:15 +02:00
if ( keepDims & & rank > 1 ) {
ALLOCATE ( newShapeInfo , workspace , shape : : shapeInfoLength ( rank ) , Nd4jLong ) ;
newShapeInfo [ 0 ] = rank ;
for ( int i = 0 ; i < rank ; + + i )
newShapeInfo [ i + 1 ] = 1 ;
ShapeUtils : : updateStridesAndType ( newShapeInfo , shapeInfo , order ) ;
ArrayOptions : : setDataType ( newShapeInfo , dataType ) ;
ShapeDescriptor descriptor ( newShapeInfo , dataType ) ;
RELEASE ( newShapeInfo , workspace ) ;
return ConstantShapeHelper : : getInstance ( ) - > bufferForShapeInfo ( descriptor ) . primaryAsT < Nd4jLong > ( ) ;
}
else if ( supportOldShapes ) {
ALLOCATE ( newShapeInfo , workspace , shape : : shapeInfoLength ( 2 ) , Nd4jLong ) ;
shape : : shapeOldScalar ( dataType , newShapeInfo , ' c ' ) ;
ShapeDescriptor descriptor ( newShapeInfo , dataType ) ;
RELEASE ( newShapeInfo , workspace ) ;
return ConstantShapeHelper : : getInstance ( ) - > bufferForShapeInfo ( descriptor ) . primaryAsT < Nd4jLong > ( ) ;
}
else {
newShapeInfo = ShapeBuilders : : createScalarShapeInfo ( dataType , workspace ) ;
ShapeDescriptor descriptor ( newShapeInfo , dataType ) ;
RELEASE ( newShapeInfo , workspace ) ;
return ConstantShapeHelper : : getInstance ( ) - > bufferForShapeInfo ( descriptor ) . primaryAsT < Nd4jLong > ( ) ;
}
}
2019-06-15 13:34:34 +02:00
shape : : checkDimensions ( rank , dimsToExclude ) ;
2019-06-06 14:21:15 +02:00
2019-06-15 13:34:34 +02:00
int dimSize = dimsToExclude . size ( ) ;
2019-06-06 14:21:15 +02:00
if ( keepDims ) {
ALLOCATE ( newShapeInfo , workspace , shape : : shapeInfoLength ( rank ) , Nd4jLong ) ;
newShapeInfo [ 0 ] = rank ;
for ( int i = 0 ; i < rank ; + + i )
2019-06-15 13:34:34 +02:00
if ( std : : binary_search ( dimsToExclude . begin ( ) , dimsToExclude . end ( ) , i ) ) // dimsToExclude is already sorted after shape::checkDimensions() has been applied
2019-06-06 14:21:15 +02:00
newShapeInfo [ i + 1 ] = 1 ;
else
newShapeInfo [ i + 1 ] = shapeInfo [ i + 1 ] ;
ShapeUtils : : updateStridesAndType ( newShapeInfo , shapeInfo , order ) ;
ShapeDescriptor descriptor ( newShapeInfo , dataType ) ;
RELEASE ( newShapeInfo , workspace ) ;
return ConstantShapeHelper : : getInstance ( ) - > bufferForShapeInfo ( descriptor ) . primaryAsT < Nd4jLong > ( ) ;
}
int newRank = rank - dimSize ;
2019-06-15 13:34:34 +02:00
if ( newRank = = 0 | | ( dimSize = = 1 & & dimsToExclude [ 0 ] = = INT_MAX ) ) { // check whether given dimension is meant for the whole dimension
2019-06-06 14:21:15 +02:00
if ( supportOldShapes ) {
ALLOCATE ( newShapeInfo , workspace , shape : : shapeInfoLength ( 2 ) , Nd4jLong ) ;
shape : : shapeOldScalar ( ArrayOptions : : dataType ( shapeInfo ) , newShapeInfo , ' c ' ) ;
ShapeDescriptor descriptor ( newShapeInfo , dataType ) ;
RELEASE ( newShapeInfo , workspace ) ;
return ConstantShapeHelper : : getInstance ( ) - > bufferForShapeInfo ( descriptor ) . primaryAsT < Nd4jLong > ( ) ;
}
else {
newShapeInfo = ShapeBuilders : : createScalarShapeInfo ( ArrayOptions : : dataType ( shapeInfo ) , workspace ) ;
ShapeDescriptor descriptor ( newShapeInfo , dataType ) ;
RELEASE ( newShapeInfo , workspace ) ;
return ConstantShapeHelper : : getInstance ( ) - > bufferForShapeInfo ( descriptor ) . primaryAsT < Nd4jLong > ( ) ;
}
}
ALLOCATE ( newShapeInfo , workspace , shape : : shapeInfoLength ( newRank ) , Nd4jLong ) ;
newShapeInfo [ 0 ] = newRank ; // set rank
int j = 1 ;
for ( int i = 0 ; i < rank ; + + i )
2019-06-15 13:34:34 +02:00
if ( ! std : : binary_search ( dimsToExclude . begin ( ) , dimsToExclude . end ( ) , i ) ) // dimsToExclude is already sorted after shape::checkDimensions() has been applied
2019-06-06 14:21:15 +02:00
newShapeInfo [ j + + ] = shapeInfo [ i + 1 ] ;
//ensure whether vector has proper shape for old shape type
if ( newRank = = 1 & & supportOldShapes ) {
int oldValue = newShapeInfo [ 1 ] ;
RELEASE ( newShapeInfo , workspace ) ;
ALLOCATE ( newShapeInfo , workspace , shape : : shapeInfoLength ( 2 ) , Nd4jLong ) ; // set newRank = 2
newShapeInfo [ 0 ] = 2 ;
2019-06-15 13:34:34 +02:00
if ( dimsToExclude [ 0 ] = = 0 ) {
2019-06-06 14:21:15 +02:00
newShapeInfo [ 1 ] = 1 ;
newShapeInfo [ 2 ] = oldValue ;
}
else {
newShapeInfo [ 1 ] = oldValue ;
newShapeInfo [ 2 ] = 1 ;
}
}
ShapeUtils : : updateStridesAndType ( newShapeInfo , shapeInfo , order ) ;
ShapeDescriptor descriptor ( newShapeInfo , dataType ) ;
RELEASE ( newShapeInfo , workspace ) ;
return ConstantShapeHelper : : getInstance ( ) - > bufferForShapeInfo ( descriptor ) . primaryAsT < Nd4jLong > ( ) ;
}
//////////////////////////////////////////////////////////////////////////
// evaluate shape for array which is result of repeat operation applied to arr
2019-08-21 20:10:29 +02:00
std : : vector < Nd4jLong > ShapeUtils : : evalRepeatShape ( int axis , const std : : vector < int > & repeats , const NDArray & arr ) {
2019-06-06 14:21:15 +02:00
2019-08-21 20:10:29 +02:00
if ( axis < 0 )
axis + = arr . rankOf ( ) ;
2019-06-06 14:21:15 +02:00
2019-08-21 20:10:29 +02:00
if ( repeats . size ( ) ! = 1 & & repeats . size ( ) ! = arr . sizeAt ( axis ) )
throw std : : invalid_argument ( " ShapeUtils::evalRepeatShape: size of repeats vector must be 1 or equal to dimension at given axis ! " ) ;
2019-06-06 14:21:15 +02:00
2019-08-21 20:10:29 +02:00
std : : vector < Nd4jLong > outShape = arr . getShapeAsVector ( ) ;
2019-06-06 14:21:15 +02:00
2019-08-21 20:10:29 +02:00
if ( repeats . size ( ) = = 1 )
outShape [ axis ] * = repeats [ 0 ] ;
2019-06-06 14:21:15 +02:00
2019-08-21 20:10:29 +02:00
else
outShape [ axis ] = std : : accumulate ( repeats . begin ( ) , repeats . end ( ) , 0 ) ;
2019-06-06 14:21:15 +02:00
return outShape ;
}
//////////////////////////////////////////////////////////////////////////
// Evaluate the shapeInfo of a permuted array.
//
// dimensions - permutation to apply (rank entries)
// rank       - must equal arr.rankOf()
// arr        - source array; its shapeInfo is copied, not modified
// workspace  - workspace used for the temporary copy
// Returns a pointer into the ConstantShapeHelper cache. Throws std::runtime_error on a null array or rank mismatch.
Nd4jLong* ShapeUtils::evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace) {

    if (!arr.nonNull())
        throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments in pn/termute method: either array is nullptr!");

    if (rank != arr.rankOf())
        throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments in pn/termute method: rank is not suitable!");

    // permute a private copy of arr's shapeInfo
    Nd4jLong* permuted = nullptr;
    ALLOCATE(permuted, workspace, shape::shapeInfoLength(rank), Nd4jLong);
    memcpy(permuted, arr.getShapeInfo(), shape::shapeInfoByteLength(rank));
    shape::doPermuteShapeInfo(permuted, dimensions);

    ShapeDescriptor descriptor(permuted);
    RELEASE(permuted, workspace);
    return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
}
//////////////////////////////////////////////////////////////////////////
// Evaluate the shapeInfo of a permuted array from an Nd4jLong permutation: narrows the rank entries
// to int and forwards to the int-based overload.
Nd4jLong* ShapeUtils::evalPermShapeInfo(const Nd4jLong* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace) {
    std::vector<int> intDims;
    intDims.assign(dimensions, dimensions + rank);
    const int* permutation = intDims.data();
    return evalPermShapeInfo(permutation, rank, arr, workspace);
}
//////////////////////////////////////////////////////////////////////////
// Evaluate the shapeInfo of a transposed array: the permutation reverses the axis order
// (rank-1, ..., 1, 0) and is forwarded to evalPermShapeInfo.
Nd4jLong* ShapeUtils::evalTranspShapeInfo(const NDArray& arr, nd4j::memory::Workspace* workspace) {
    const int rank = arr.rankOf();
    std::vector<int> reversedAxes(rank);
    std::iota(reversedAxes.rbegin(), reversedAxes.rend(), 0);     // last slot gets 0, first slot gets rank-1
    return evalPermShapeInfo(reversedAxes.data(), rank, arr, workspace);
}
//////////////////////////////////////////////////////////////////////////
bool ShapeUtils : : copyVectorPart ( std : : vector < int > & target , std : : vector < int > & source , int rank , int offset ) {
if ( source . size ( ) < offset + rank )
return false ;
for ( int e = offset ; e < offset + rank ; e + + )
target . push_back ( source [ e ] ) ;
return true ;
}
//////////////////////////////////////////////////////////////////////////
// return new (shorter) sorted dimensions array without dimensions that are present in input vector
std : : vector < int > ShapeUtils : : evalDimsToExclude ( const int rank , const int dimsLen , const int * dimensions ) {
std : : vector < int > newDimensions ;
if ( dimsLen = = 0 ) { // if input vector is empty then return whole shape range
newDimensions . resize ( rank ) ;
std : : iota ( newDimensions . begin ( ) , newDimensions . end ( ) , 0 ) ; // fill with 0, 1, ... rank-1
}
else {
bool isAbsent ;
for ( int i = 0 ; i < rank ; + + i ) {
isAbsent = true ;
for ( int j = 0 ; j < dimsLen ; + + j ) {
int dim = dimensions [ j ] > = 0 ? dimensions [ j ] : dimensions [ j ] + rank ;
if ( i = = dim ) {
isAbsent = false ;
break ;
}
}
if ( isAbsent )
newDimensions . emplace_back ( i ) ;
}
}
return newDimensions ;
}
//////////////////////////////////////////////////////////////////////////
std : : vector < int > ShapeUtils : : evalDimsToExclude ( const int rank , const std : : vector < int > & dimensions ) {
return ShapeUtils : : evalDimsToExclude ( rank , dimensions . size ( ) , dimensions . data ( ) ) ;
}
//////////////////////////////////////////////////////////////////////////
// check whether 2 arrays have mutually broadcastable shapes
// shape comparison starts from the end
bool ShapeUtils : : areShapesBroadcastable ( const NDArray & arr1 , const NDArray & arr2 ) {
return areShapesBroadcastable ( arr1 . getShapeInfo ( ) , arr2 . getShapeInfo ( ) ) ;
}
bool ShapeUtils : : areShapesBroadcastable ( Nd4jLong * shapeInfo1 , Nd4jLong * shapeInfo2 ) {
int minRank = shape : : rank ( shapeInfo1 ) < shape : : rank ( shapeInfo2 ) ? shape : : rank ( shapeInfo1 ) : shape : : rank ( shapeInfo2 ) ;
for ( int i = - 1 ; i > = - minRank ; - - i )
if ( shape : : sizeAt ( shapeInfo1 , i ) ! = shape : : sizeAt ( shapeInfo2 , i ) & & shape : : sizeAt ( shapeInfo1 , i ) ! = 1 & & shape : : sizeAt ( shapeInfo2 , i ) ! = 1 )
return false ;
return true ;
}
bool ShapeUtils : : areShapesBroadcastable ( const std : : vector < Nd4jLong > & shape1 , const std : : vector < Nd4jLong > & shape2 ) {
const auto rank1 = shape1 . size ( ) ;
const auto rank2 = shape2 . size ( ) ;
const int minRank = rank1 < rank2 ? rank1 : rank2 ;
for ( int i = 1 ; i < = minRank ; + + i )
if ( shape1 [ rank1 - i ] ! = shape2 [ rank2 - i ] & & shape1 [ rank1 - i ] ! = 1 & & shape2 [ rank2 - i ] ! = 1 )
return false ;
return true ;
}
//////////////////////////////////////////////////////////////////////////
// check the possibility of broadcast operation, if true then return shapeInfo of resulting array
// if evalMinMax == false the array with larger rank has to be passed as first argument
bool ShapeUtils : : evalBroadcastShapeInfo ( const NDArray & max , const NDArray & min , const bool evalMinMax , Nd4jLong * & resultShapeInfo , nd4j : : memory : : Workspace * workspace ) {
return evalBroadcastShapeInfo ( max . getShapeInfo ( ) , min . getShapeInfo ( ) , evalMinMax , resultShapeInfo , workspace ) ;
}
bool ShapeUtils : : evalBroadcastShapeInfo ( Nd4jLong * max , Nd4jLong * min , const bool evalMinMax , Nd4jLong * & resultShapeInfo , nd4j : : memory : : Workspace * workspace ) {
// check whether broadcast operation is possible for input arrays
if ( ! areShapesBroadcastable ( max , min ) )
return false ;
auto maxShapeInfo = max ; //max.getShapeInfo();
auto minShapeInfo = min ; //min.getShapeInfo();
if ( evalMinMax & & ( shape : : rank ( max ) < shape : : rank ( min ) ) ) {
maxShapeInfo = min ;
minShapeInfo = max ;
}
const auto maxRank = shape : : rank ( maxShapeInfo ) ;
const auto minRank = shape : : rank ( minShapeInfo ) ;
// evaluate shapeInfo for resulting array
if ( resultShapeInfo ! = nullptr )
throw std : : runtime_error ( " std::runtime_error(ShapeUtils::evalBroadcastShapeInfo method: the input pointer on shapeInfo must be empty (=nullptr) ! " ) ;
Nd4jLong * tmpShapeInfo = nullptr ;
ALLOCATE ( tmpShapeInfo , workspace , shape : : shapeInfoLength ( maxRank ) , Nd4jLong ) ;
// FIXME: get rid of memcpy here
memcpy ( tmpShapeInfo , maxShapeInfo , shape : : shapeInfoByteLength ( maxRank ) ) ;
for ( int i = 0 ; i < minRank ; + + i )
Merge master to upstream (#7945)
* Shugeo strided slice zeros (#14)
* Modified strided_slice op to properly work with empty-like shapes.
* Fixed test for reduce_mean with empty-like input.
* [WIP] Last merge (#15)
* correct logsoftmax loss (#2)
* Small SameDiff listener fix (#4)
* Various fixes (#6)
* #7839 Fix for asXMatrix and tests
* #7866 EmbeddingSequenceLayer dtype fix + test
* #7856 SameDiff save/load stream methods
* #7859 RegressionEvaluation rank 4 fix + tests + axis configuration
* EvaluationBinary 3d/4d
* More evaluation 3d/4d tests
* #7847 Evaluation empty checks
* Small test fix
* #7848 Fix median edge case
* Improve DL4J samediff layer tests
* [WIP] FastText wrapper implemented (#8)
* FastText implemented
* Some fixes
* Fix shapes for wordsNearest
* Validation of input vectors
* Fixes
* Fixed test
* Thread tagged
* Some tweaks
* setContextClassLoader for DeallocatorServiceThread
* Numpy format tests (#1)
* Various fixes (#11)
* #7852 SameDiff gather fix
* #7892 SameDiff placeholder to constant conversion
* #7890 validate input rank for MLN/CG init methods
* Fix broken permute shape calculation
* Permute and gather fixes
* Tests
* #7850 LogSumExp fix + test
* Handful of test fixes
* Empty arrays with non-scalar shapes (#10)
* minor rearrangements for lambdas
* empty tensors with non-scalar shapes
* numpy empty tensors with non-scalar shapes
* few more empty tweaks
* Small fixes
* conv3d signature update
* micro fix in batchnorm mkldnn
* Import fixes
* Fix
* MKL-DNN update
* Small fill fix
* fill with empty input + test
* Fixes
* Small error improvement
* Fix
* one special test
* couple of fixes for lstm
* Rewrite TFGraphMapper.getNDArrayFromTensor to be maintainable and less error prone
* Fixes
* FP16
* Unsigned
* BFloat16
* Fill op - empty tweaks
* - couple of fixes for empty arrays construction
- stack updated
* strided slice fix
* one transform test
* provide method for reducing shapeInfo in case of input array is empty
* Fixed reduceAlongDimensions to use empty input properly.
* couple of broadcast tests
* couple of tests broadcast tests + tweak to make them pass
* add check of non-empty to methods producing sub-arrays
* Fixed reshapeC with zeros in shape.
* complete empty check in reduce_... legacy ops
* Concat and cumsum/prod
* Tweak to empty shape inference on import
* add empty check to the rest of reduce legacy ops
* one more test
* correct typo in evalReduceShapeInfoEmpty
* Added tests for reduce_* ops to tests with zero shapes.
* few more tests for empty reductions
* Fixed strided_slice op with empty case and tests.
* one more empty reduction test
* Fixed strided_slice test.
* add empty check to NDArray::reshapei
* infOrMax
* empty min/max with infinity tests
* made unstack working correctly with empty arrays
* few IndexReduce tests + tweaks for empty shapes
* add test for empty concat
* few tests fixed
* Validation fix for reductions on empty shapes
* Reverse fix
* Reduction shape calc fixes
* SameDiff.generateOutputVariable: don't use shape function to determine number of outputs
* Range fix
* - NDArray constructor updated for scalars/empty arrays
- few tests fixed
* More fixes
* Empty creator fixes
* concat fix
* concat fix
* TF import tests: allow 'both all NaN' and 'both all inf' to pass
* Slice, zero fraction, and reshape fixes
* transpose, gather
* Zero fraction
* scalar cast fix
* Empty reduction axis support
* few more tests fixed
* Fixed input checks conforming with TF for concat op and tests.
* few tests fixed
* matmul scalar shape fix
* Fixed checkout for data type and scalarity with concat to allow non-empty scalars with vector concats.
* broadcast bool fix
* few more tests
* few more tests
* correct evalReduceShapeInfoEmpty
* argmax/argmin + tests
* one more empty edge case + one more test
* argmax/argmin/realdiv_bp tweaks
* empty reshape test + fix
* Helper fixes
* Small fixes
* Gather test fix
* Gather test fix
* Small fixes
* reduce scalar zero values
* scalar mean workaround
* Remove debug code
* along dim mean workaround
* one more test
* - equalsTo() tweak for empty arrays
- one more test
* broadcast tweaks
* [WIP] Fixing outstanding issues for NLP (#9)
* Avoid using not-inited objects
* Test fixed.
* Redundant method avoided for models like FastText
* KMeans++ implementation
* KMeans++ implementation
* Disable parallel execution
* KMeans++
* Tests
* Dev branch merge (#16)
* SameDiff: convertDataType and gradient check util improvements (#12)
* GradCheck util improvements
* StopGradient constructor + test
* SameDiff: Add datatype conversion
* Javadoc and add DataType.isNumerical()
* Small fix
* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)
* TFGraphTestAllHelper: check intermediates in execution order
* Add missing debug listener
* [WIP] lstmBlock fix + other changes (#13)
- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite
* Small test fix
* CheckNumerics op wrapper
* Fix some issues on master (#17)
* Fix DataVec test issue
* Fix issue with dl4j SameDiff output layer
* Dtype fix for lambda layers
* #7912 BertIterator dtype fix (use float32 not global default)
* [WIP] Next set of CUDA stuff (#7)
New CUDA implementations and improvements
* bad file
* Dev branch master merge (#23)
* SameDiff: convertDataType and gradient check util improvements (#12)
* GradCheck util improvements
* StopGradient constructor + test
* SameDiff: Add datatype conversion
* Javadoc and add DataType.isNumerical()
* Small fix
* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)
* TFGraphTestAllHelper: check intermediates in execution order
* Add missing debug listener
* [WIP] lstmBlock fix + other changes (#13)
- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite
* Small test fix
* CheckNumerics op wrapper
* Compatibility of deserialization (#18)
Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* SameDiff: add activation gradient checking support for debugging (#19)
* SameDiff gradient checker: first pass on activation gradient checks
* Fixes + tests for activation gradient checking
* Javadoc
* [WIP] Some nd4j data type corrections (#20)
* Adjust data type
* Set correct Data type.
* Size of proper data type.
* fix averaged cpu load (#22)
* SameDiff ops, TF import and fixes (#24)
* CheckNumerics tests + fixes + misc fixes
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fake quant
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fixes
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* FakeQuantWithMinMaxArgs
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* CheckNumerics fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix libnd4j ALL_INTS and ALL_FLOATS declaration (uint and bfloat types)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Small fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Javadoc
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Exception tweak
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix for out of scope stack allocated var use
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Ignores
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Ignore for known failing test (already logged issue)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Merge upstream to fork (#25)
* Add thousand-separator commas to TotalParams (#7915)
* Add thousand-separator commas to TotalParams
The number of parameters can be quite large, and it would help the reading of the summary printout to have the TotalParams column & values at the bottom have thousand-separator-commas in them.
* Add thousand-separator commas to MultiLayerNetwork
Corresponding change to MultiLayerNetwork
Signed-off-by: Jxtps Jxtps <jxtps435@gmail.com>
* Update contributing and issue/PR templates (#7934)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix link to AdaDelta paper (#7942)
Fix link to AdaDelta paper hosted on matthewzeiler.com
Signed-off-by: Jxtps
* Fixes, and ignores for known/logged failing issues (#7943)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* SameDiff + DL4J/SameDiff: Multiple fixes (#28)
* #7919 HDF5 attribute buffer length fix
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7909 Arbiter constructor exception ux improvements
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7925 RNN output layer length checks
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7939 Add listener for validating inputs are not incorrectly modified
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* #7939 Integrate NonInplaceValidationListener into tests
* #7844 DL4J SameDiff fixes for variable minibatch size
* DL4J SameDiff fixes - ensure gradient for input placeholder is available
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Tweaks to ExternalErrorsFunction - use placeholders, make more robust
* Another fix
* More fixes
* More SameDiff/DL4J fixes
* Scope out scalar array creation in BaseScalarOp
* Remove debug code
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* [WIP] Final dev branch merge (#29)
* SameDiff: convertDataType and gradient check util improvements (#12)
* GradCheck util improvements
* StopGradient constructor + test
* SameDiff: Add datatype conversion
* Javadoc and add DataType.isNumerical()
* Small fix
* Fix SameDiff TF import test cases intermediate naming (workaround for bad default)
* TFGraphTestAllHelper: check intermediates in execution order
* Add missing debug listener
* [WIP] lstmBlock fix + other changes (#13)
- fixes lstmBlock issue
- changes NDArray method reshape(), permute(), transpose() by making them return instance instead of pointer
- CheckNumerics op
- fixes for ReduceBool IsInfOrNan & IsFinite
* Small test fix
* CheckNumerics op wrapper
* Compatibility of deserialization (#18)
Signed-off-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
* SameDiff: add activation gradient checking support for debugging (#19)
* SameDiff gradient checker: first pass on activation gradient checks
* Fixes + tests for activation gradient checking
* Javadoc
* [WIP] Some nd4j data type corrections (#20)
* Adjust data type
* Set correct Data type.
* Size of proper data type.
* fix averaged cpu load (#22)
* [WIP] Multiple dataset iterators (#27)
* Splitting dataset into arbitrary number
* Fixes
* Multiple split of iterator
* Test
* Test
* Some fixes
* signature change
* one more tweak
Signed-off-by: raver119 <raver119@gmail.com>
* one more test for sequential use of DataSetIteratorSplitter
Signed-off-by: raver119 <raver119@gmail.com>
* Fixes
* Fixes
* one more test for Alexander
Signed-off-by: raver119 <raver119@gmail.com>
* Some fixes
* Some fixes
* one more test for Alexander
Signed-off-by: raver119 <raver119@gmail.com>
* minor test fix
Signed-off-by: raver119 <raver119@gmail.com>
* Some fixes
* Some fixes
* couple of assertions tweaked
Signed-off-by: raver119 <raver119@gmail.com>
* MDS splitter test :/
Signed-off-by: raver119 <raver119@gmail.com>
* Minor refactoring
* Multi dataset
* Some fixes
* More tests
* Small number of test fixes/improvements (failures on CI) (#31)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* [WIP] More CUDA stuff (#26)
* initial commit
Signed-off-by: raver119 <raver119@gmail.com>
* LRN BP CUDA
Signed-off-by: raver119 <raver119@gmail.com>
* less memory
Signed-off-by: raver119 <raver119@gmail.com>
* Fixed bug with crop_and_resize op helper.
* get rid of unnecessary index-calculation function
Signed-off-by: Yurii <yurii@skymind.io>
* Fixed sort with nth_element cuda-based helper.
* Refactored nth_element.
* Refactored nth_element op and tests.
* Modified usage of dim array with sortTad routine.
* Refactored main routine of helper for non_max_image_suppression op.
* non_max_image_suppression op helper with cuda kernel implementation. Initial revision.
* fix vol2col cuda kernel
* meh
Signed-off-by: raver119 <raver119@gmail.com>
* topK concept
Signed-off-by: raver119 <raver119@gmail.com>
* unsorted topK with scanWidth of 1
Signed-off-by: raver119 <raver119@gmail.com>
* correct vol2col tests
* sorted/unsorted topK
Signed-off-by: raver119 <raver119@gmail.com>
* implementation and fixing col2im/col2vol
* Corrected usage flags with input/output with reverse op.
* dup is const now
Signed-off-by: raver119 <raver119@gmail.com>
* percentile op
Signed-off-by: raver119 <raver119@gmail.com>
* group tests for maxpool2d
Signed-off-by: Yurii <yurii@skymind.io>
* special test for george
Signed-off-by: raver119 <raver119@gmail.com>
* less threads for sortTad
Signed-off-by: raver119 <raver119@gmail.com>
* provide conv2d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* remove author in sort tad kernel code
Signed-off-by: Yurii <yurii@skymind.io>
* provide depthwise_conv2d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* - max_pooling_with_argmax
- null check for special use
Signed-off-by: raver119 <raver119@gmail.com>
* dts cuda
Signed-off-by: raver119 <raver119@gmail.com>
* provide sconv2d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* std cuda
Signed-off-by: raver119 <raver119@gmail.com>
* Refactored non_max_suppression op to conform to the TF implementation.
* Improved suppression helper.
* provide pooling3d for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* minor lstm rearrangements
Signed-off-by: raver119 <raver119@gmail.com>
* more of minor lstm rearrangements
Signed-off-by: raver119 <raver119@gmail.com>
* (bi)dynamic_rnn
Signed-off-by: raver119 <raver119@gmail.com>
* templates init order
Signed-off-by: raver119 <raver119@gmail.com>
* Refactored non_max_suppression op.
* Added cuda kernel for non_max_suppression.
* CPU sort by key/value
Signed-off-by: raver119 <raver119@gmail.com>
* CPU sort TAD by key/value
Signed-off-by: raver119 <raver119@gmail.com>
* CPU sort TAD by key/value tests
Signed-off-by: raver119 <raver119@gmail.com>
* Eliminate compiler error with cuda implementation.
* - repaired gradCheck in cuda
- provide conv2d_bp for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* missed signature
Signed-off-by: raver119 <raver119@gmail.com>
* provide depthwise_conv2d_bp for cuda
Signed-off-by: Yurii <yurii@skymind.io>
* Implementation of lup helper with cuda kernel. Initial commit.
* further work on backprops for convolutions
Signed-off-by: Yurii <yurii@skymind.io>
* CUDA linear sort by key/val
Signed-off-by: raver119 <raver119@gmail.com>
* CUDA tad sort by key/val
Signed-off-by: raver119 <raver119@gmail.com>
* start providing of backprop for pooling2d/3d
Signed-off-by: Yurii <yurii@skymind.io>
* Added atomicAdd for bool datatype.
* dynamic partition concept
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic partition concept
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic partition scalar CUDA
Signed-off-by: raver119 <raver119@gmail.com>
* important comment
Signed-off-by: raver119 <raver119@gmail.com>
* fix pooling2d/3d backprop helpers
Signed-off-by: Yurii <yurii@skymind.io>
* Added non-linear test with dynamic_partition.
* Improved test for dynamic_partition.
* dynamic_partition TAD concept
Signed-off-by: raver119 <raver119@gmail.com>
* - dynamic_partition TAD CUDA impl
- dynamic_partition TAD CPU fix
Signed-off-by: raver119 <raver119@gmail.com>
* - rewrite cpu code for upsampling2d/3d
- write cuda code for upsampling2d/3d
Signed-off-by: Yurii <yurii@skymind.io>
* dynamic_stitch CUDA vector case
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic_stitch CUDA TAD case concept
Signed-off-by: raver119 <raver119@gmail.com>
* dynamic_stitch CUDA TAD case impl
Signed-off-by: raver119 <raver119@gmail.com>
* Added tests for dynamic_stitch 3D-4D cases.
* minor tests tweaks
Signed-off-by: raver119 <raver119@gmail.com>
* Fixed type check for dynamic stitch.
* min/max bp
Signed-off-by: raver119 <raver119@gmail.com>
* rewrite code for upsampling2d/3d cpu
Signed-off-by: Yurii <yurii@skymind.io>
* reduce min/max/norm_max bp
Signed-off-by: raver119 <raver119@gmail.com>
* lup implementation. Additional enhancements.
* provide code for upsampling2d/3d backprop
Signed-off-by: Yurii <yurii@skymind.io>
* weightedCrossEntropyWithLogits
Signed-off-by: raver119 <raver119@gmail.com>
* Fixed template math atomicMul for 64bit ints.
* Refactored dynamic_partition_bp op.
* inverseBroadcast fix
Signed-off-by: raver119 <raver119@gmail.com>
* DynamicPartitionBP test datatype fixed.
* - nd4j_atomicMul Windows fix
- cpu/NDArrayLambda.hpp excluded from CUDA
Signed-off-by: raver119 <raver119@gmail.com>
2019-06-27 17:37:04 +02:00
if ( ( maxShapeInfo [ maxRank - i ] ! = 0 & & maxShapeInfo [ maxRank - i ] < minShapeInfo [ minRank - i ] ) | | minShapeInfo [ minRank - i ] = = 0 )
2019-06-06 14:21:15 +02:00
tmpShapeInfo [ maxRank - i ] = minShapeInfo [ minRank - i ] ;
ShapeUtils : : updateStridesAndType ( tmpShapeInfo , DataTypeUtils : : pickPairwiseResultType ( maxShapeInfo , minShapeInfo ) , shape : : order ( maxShapeInfo ) ) ;
2019-06-15 13:34:34 +02:00
if ( shape : : isEmpty ( max ) | | shape : : isEmpty ( min ) ) {
ArrayOptions : : setPropertyBit ( tmpShapeInfo , ARRAY_EMPTY ) ;
memset ( shape : : stride ( tmpShapeInfo ) , 0 , shape : : rank ( tmpShapeInfo ) * sizeof ( Nd4jLong ) ) ;
}
2019-06-06 14:21:15 +02:00
ShapeDescriptor descriptor ( tmpShapeInfo ) ;
RELEASE ( tmpShapeInfo , workspace ) ;
resultShapeInfo = ConstantShapeHelper : : getInstance ( ) - > bufferForShapeInfo ( descriptor ) . primaryAsT < Nd4jLong > ( ) ;
return true ;
}
//////////////////////////////////////////////////////////////////////////
// check the possibility of broadcast operation for set of arrays, if true then return resulting broadcasted shapeInfo
bool ShapeUtils : : evalCommonBroadcastShapeInfo ( const std : : vector < const NDArray * > & arrays , Nd4jLong * & resultShapeInfo , memory : : Workspace * workspace ) {
if ( resultShapeInfo ! = nullptr )
throw std : : runtime_error ( " ShapeUtils::evalCommonBroadcastShapeInfo method: the input pointer on shapeInfo must be empty (=nullptr) ! " ) ;
int size = arrays . size ( ) ;
int maxRank = arrays [ size - 1 ] - > rankOf ( ) ;
for ( int i = 0 ; i < size - 1 ; + + i ) {
if ( arrays [ i ] - > rankOf ( ) > maxRank )
maxRank = arrays [ i ] - > rankOf ( ) ;
for ( int j = i + 1 ; j < size ; + + j )
if ( ! areShapesBroadcastable ( * arrays [ i ] , * arrays [ j ] ) )
return false ;
}
Nd4jLong * tmpShapeInfo = nullptr ;
ALLOCATE ( tmpShapeInfo , workspace , shape : : shapeInfoLength ( maxRank ) , Nd4jLong ) ;
memset ( tmpShapeInfo , 0 , shape : : shapeInfoByteLength ( maxRank ) ) ;
tmpShapeInfo [ 0 ] = maxRank ;
for ( const auto & item : arrays ) {
for ( int i = - 1 ; i > = - item - > rankOf ( ) ; - - i )
if ( tmpShapeInfo [ i + 1 + maxRank ] < item - > sizeAt ( i ) )
tmpShapeInfo [ i + 1 + maxRank ] = item - > sizeAt ( i ) ;
}
shape : : updateStrides ( tmpShapeInfo , arrays [ 0 ] - > ordering ( ) ) ;
ArrayOptions : : setDataType ( tmpShapeInfo , arrays [ 0 ] - > dataType ( ) ) ;
ShapeDescriptor descriptor ( tmpShapeInfo ) ;
RELEASE ( tmpShapeInfo , workspace ) ;
resultShapeInfo = ConstantShapeHelper : : getInstance ( ) - > createShapeInfo ( descriptor ) ;
return true ;
}
//////////////////////////////////////////////////////////////////////////
2019-10-01 08:10:19 +02:00
// return sorted vector of dimensions common (same) for two arrays, dimensions values corresponds to array with bigger rank
// for example if arr1{2,7}, arr2{2,5,4,7} then vector = {0,3}
std : : vector < int > ShapeUtils : : getDimsWithSameShape ( const NDArray & arr1 , const NDArray & arr2 ) {
2019-06-06 14:21:15 +02:00
2019-10-01 08:10:19 +02:00
const NDArray * min , * max ;
2019-06-06 14:21:15 +02:00
2019-10-01 08:10:19 +02:00
if ( arr1 . rankOf ( ) > = arr2 . rankOf ( ) ) {
max = & arr1 ;
min = & arr2 ;
}
else {
max = & arr2 ;
min = & arr1 ;
}
2019-06-06 14:21:15 +02:00
2019-10-01 08:10:19 +02:00
const int rankDiff = max - > rankOf ( ) - min - > rankOf ( ) ;
std : : vector < int > dims ;
for ( int i = 0 ; i < min - > rankOf ( ) ; + + i )
if ( min - > sizeAt ( i ) = = max - > sizeAt ( rankDiff + i ) )
dims . emplace_back ( rankDiff + i ) ;
return dims ;
2019-06-06 14:21:15 +02:00
}
//////////////////////////////////////////////////////////////////////////
// evaluate shapeInfo for resulting array from tile operation
// arr       - array to be tiled
// reps      - repetition counts, aligned to the trailing dimensions of arr; when reps is longer than
//             arr's rank, unit dimensions are prepended to arr's shape before multiplying
// workspace - workspace used for the temporary shapeInfo buffer
// returns shapeInfo obtained from ConstantShapeHelper, with arr's order and data type
// throws std::runtime_error if some element of reps is zero
Nd4jLong* ShapeUtils::evalTileShapeInfo(const NDArray& arr, const std::vector<Nd4jLong>& reps, nd4j::memory::Workspace* workspace) {

    // look for a zero directly instead of multiplying all reps: the product of large reps can overflow
    if (std::find(reps.begin(), reps.end(), 0) != reps.end())
        throw std::runtime_error(" NDArray::tile method: one of the elements in reps array is zero ! ");

    const int repsSize = static_cast<int>(reps.size());
    const int rankOld  = arr.rankOf();
    const int diff     = rankOld - repsSize;

    // evaluate new shapeInfo
    Nd4jLong* newShapeInfo = nullptr;

    if (diff < 0) {
        ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(repsSize), Nd4jLong);
        // this branch writes only rank and shape explicitly; zero the buffer so the remaining fields
        // (in particular the extra/properties word that setDataType only partially updates) do not
        // keep leftover allocation contents - ALLOCATE is not relied upon to zero memory (compare the memset after ALLOCATE in evalCommonBroadcastShapeInfo)
        memset(newShapeInfo, 0, shape::shapeInfoByteLength(repsSize));
        newShapeInfo[0] = repsSize;                                                                 // set new rank
        for (int i = 1; i <= -diff; ++i)
            newShapeInfo[i] = 1;                                                                    // set unities to be new dimensions at left-hand side of newShapeInfo shape place
        memcpy(newShapeInfo + 1 - diff, arr.getShapeInfo() + 1, rankOld * sizeof(Nd4jLong));       // copy old dimensions to the right-hand side of newShapeInfo shape place
        for (int i = 1; i <= repsSize; ++i)
            newShapeInfo[i] *= reps[i - 1];                                                         // set new shape by multiplying old dimensions by corresponding numbers from reps
    }
    else {
        ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(rankOld), Nd4jLong);
        memcpy(newShapeInfo, arr.getShapeInfo(), shape::shapeInfoByteLength(rankOld));             // copy all elements of _shapeInfo to newShapeInfo
        for (int i = 1; i <= repsSize; ++i)
            newShapeInfo[rankOld + 1 - i] *= reps[repsSize - i];                                    // set new shape by multiplying old dimensions by corresponding numbers from reps
    }

    shape::updateStrides(newShapeInfo, arr.ordering());
    ArrayOptions::setDataType(newShapeInfo, arr.dataType());

    ShapeDescriptor descriptor(newShapeInfo);
    RELEASE(newShapeInfo, workspace);
    return ConstantShapeHelper::getInstance()->bufferForShapeInfo(descriptor).primaryAsT<Nd4jLong>();
}
std : : vector < Nd4jLong > ShapeUtils : : pullShapeFromShapeInfo ( Nd4jLong * shapeInfo ) {
std : : vector < Nd4jLong > shape ( shape : : rank ( shapeInfo ) ) ;
int shapeSize = shape . size ( ) ;
for ( int e = 0 ; e < shapeSize ; e + + )
shape [ e ] = shape : : shapeOf ( shapeInfo ) [ e ] ;
return shape ;
}
std : : string ShapeUtils : : shapeAsString ( const NDArray * array ) {
std : : string result ;
result . append ( " [ " ) ;
for ( int e = 0 ; e < array - > rankOf ( ) ; e + + ) {
result + = flatbuffers : : NumToString ( array - > sizeAt ( e ) ) ;
if ( e < array - > rankOf ( ) - 1 )
result . append ( " , " ) ;
}
result . append ( " ] " ) ;
return result ;
}
std : : string ShapeUtils : : strideAsString ( const NDArray * array ) {
std : : string result ;
auto shapeBuffer = array - > getShapeInfo ( ) ; //Nd4jLong*
int rank = ( int ) * shapeBuffer ;
result . append ( " [ " ) ;
for ( int e = 0 ; e < rank ; e + + ) {
if ( e > 0 )
result . append ( " , " ) ;
Nd4jLong stride = * ( shapeBuffer + rank + 1 + e ) ;
result + = flatbuffers : : NumToString ( stride ) ;
}
result . append ( " ] " ) ;
return result ;
}
std : : string ShapeUtils : : shapeAsString ( const std : : vector < Nd4jLong > & shape ) {
std : : string result ;
result . append ( " [ " ) ;
for ( int e = 0 ; e < shape . size ( ) ; e + + ) {
result + = flatbuffers : : NumToString ( shape . at ( e ) ) ;
if ( e < shape . size ( ) - 1 )
result . append ( " , " ) ;
}
result . append ( " ] " ) ;
return result ;
}
std : : string ShapeUtils : : shapeAsString ( const Nd4jLong * shapeInfo ) {
if ( ! shapeInfo )
throw std : : runtime_error ( " ShapeUtils::shapeAsString method: input shapeInfo must not be nullptr ! " ) ;
std : : string result ;
result . append ( " [ " ) ;
for ( int e = 0 ; e < shapeInfo [ 0 ] ; e + + ) {
result + = flatbuffers : : NumToString ( shapeInfo [ e + 1 ] ) ;
if ( e < shapeInfo [ 0 ] - 1 )
result . append ( " , " ) ;
}
result . append ( " ] " ) ;
return result ;
}
std : : string ShapeUtils : : shapeAsString ( const int rank , const Nd4jLong * shapeInfo ) {
if ( ! shapeInfo )
throw std : : runtime_error ( " ShapeUtils::shapeAsString method: input shapeInfo must not be nullptr ! " ) ;
std : : string result ;
result . append ( " [ " ) ;
for ( int e = 0 ; e < rank ; e + + ) {
result + = flatbuffers : : NumToString ( shapeInfo [ e ] ) ;
if ( e < rank - 1 )
result . append ( " , " ) ;
}
result . append ( " ] " ) ;
return result ;
}
//////////////////////////////////////////////////////////////////////////
// evaluate shapeInfo for diagonal array which is made using input arr elements as diagonal
// - vector or scalar input of length N gives shape [N, N]
// - any other input of rank r gives rank 2r, its dimensions repeated twice
// the result takes the data type and order of the input; returned shapeInfo comes from ConstantShapeHelper
Nd4jLong* ShapeUtils::evalDiagShapeInfo(const Nd4jLong* shapeInfoConst, nd4j::memory::Workspace* workspace) {

    auto shapeInfo = const_cast<Nd4jLong*>(shapeInfoConst);

    const auto rank = shape::rank(shapeInfo);

    Nd4jLong* outputShapeInfo = nullptr;

    // only rank and shape are written explicitly below and the data type is copied afterwards, so the
    // buffer is zeroed first to keep leftover allocation contents out of the remaining fields
    // (ALLOCATE is not relied upon to zero memory)
    if (shape::isVector(shapeInfo) || shape::isScalar(shapeInfo)) {
        ALLOCATE(outputShapeInfo, workspace, shape::shapeInfoLength(2), Nd4jLong);
        memset(outputShapeInfo, 0, shape::shapeInfoByteLength(2));
        outputShapeInfo[0] = 2;
        outputShapeInfo[1] = outputShapeInfo[2] = shape::length(shapeInfo);
    }
    else {
        ALLOCATE(outputShapeInfo, workspace, shape::shapeInfoLength(2 * rank), Nd4jLong);
        memset(outputShapeInfo, 0, shape::shapeInfoByteLength(2 * rank));
        outputShapeInfo[0] = 2 * rank;
        for (int i = 1; i <= rank; ++i)
            outputShapeInfo[i] = outputShapeInfo[i + rank] = shapeInfo[i];
    }

    ShapeUtils::updateStridesAndType(outputShapeInfo, shapeInfo, shape::order(shapeInfo));

    auto result = ConstantShapeHelper::getInstance()->createShapeInfo(outputShapeInfo);
    RELEASE(outputShapeInfo, workspace);
    return result;
}
std : : vector < int > ShapeUtils : : evalBroadcastBackwardAxis ( const Nd4jLong * operandShapeInfo , const Nd4jLong * resultShapeInfo ) {
// rRank >= oRank always !!
const auto oRank = shape : : rank ( operandShapeInfo ) ;
const auto rRank = shape : : rank ( resultShapeInfo ) ;
const auto diff = rRank - oRank ;
std : : vector < int > axis ;
for ( int i = 0 ; i < rRank ; + + i )
if ( i < diff | | shape : : sizeAt ( operandShapeInfo , i - diff ) ! = shape : : sizeAt ( resultShapeInfo , i ) )
axis . push_back ( i ) ;
return axis ;
}
////////////////////////////////////////////////////////////////////////////////
// evaluate shapeInfo of the result of a matrix product of two shapes, optionally transposing either operand first
// - gemv with a rank-1 second operand yields a rank-1 result of length rows(A)
// - every other supported combination yields a 2D result in 'f' order
// the result has data type dtype and is created with ShapeBuilders (ownership as in ShapeBuilders::createShapeInfo)
// throws std::runtime_error for operand combinations not handled below (e.g. rank > 2) instead of
// building a shape from unset dimensions
Nd4jLong* ShapeUtils::matrixProductShape(Nd4jLong* theFirstShape, Nd4jLong* theSecondShape, bool shouldTranspondFirst, bool shouldTranspondSecond, nd4j::DataType dtype, nd4j::memory::Workspace* workspace) {

    Nd4jLong* tmpA = ShapeBuilders::copyShapeInfo(theFirstShape, true, workspace);
    Nd4jLong* tmpB = ShapeBuilders::copyShapeInfo(theSecondShape, true, workspace);

    if (shouldTranspondFirst)
        shape::transposeInplace(tmpA);
    if (shouldTranspondSecond)
        shape::transposeInplace(tmpB);

    Nd4jLong outShape[2] = {0, 0};     // dimensions of the 2D result, set by the matching case below
    bool vectorResult = false;          // gemv with a 1D second operand produces a 1D result
    bool supported = true;

    if (shape::rank(tmpA) == 1 && shape::isMatrix(tmpB)) {
        // special case here
        outShape[0] = 1;
        outShape[1] = tmpB[2];
    } else if (shape::isScalar(tmpA) && shape::isScalar(tmpB)) {
        // just scalar vs scalar
        outShape[0] = 1;
        outShape[1] = 1;
    } else if (shape::isMatrix(tmpA) && shape::isVector(tmpB)) {
        // gemv case
        if (shape::rank(tmpB) == 2) {
            outShape[0] = tmpA[1];
            outShape[1] = tmpB[2];
        } else {
            vectorResult = true;
        }
    } else if ((shape::isMatrix(tmpA) && shape::isMatrix(tmpB)) ||
               (shape::isVector(tmpA) && shape::isMatrix(tmpB)) ||
               (shape::isColumnVector(tmpA) && shape::isVector(tmpB))) {
        // gemm case
        outShape[0] = tmpA[1];
        outShape[1] = tmpB[2];
    } else if ((shape::isVector(tmpA) && shape::isScalar(tmpB)) ||
               (shape::isScalar(tmpA) && shape::isVector(tmpB))) {
        // element-wise; keep the full Nd4jLong length (no narrowing to int)
        outShape[0] = 1;
        outShape[1] = nd4j::math::nd4j_max<Nd4jLong>(shape::length(tmpA), shape::length(tmpB));
    } else if (shape::isRowVector(tmpA) && shape::isRowVector(tmpB)) {
        // dot case
        outShape[0] = 1;
        outShape[1] = 1;
    } else if (shape::isRowVector(tmpA) && shape::isColumnVector(tmpB)) {
        // dot case
        outShape[0] = 1;
        outShape[1] = 1;
    } else {
        supported = false;
    }

    Nd4jLong* newShape = nullptr;
    if (vectorResult)
        newShape = ShapeBuilders::createVectorShapeInfo(dtype, tmpA[1], workspace);   // we have new 1D shape here
    else if (supported)
        newShape = ShapeBuilders::createShapeInfo(dtype, 'f', 2, outShape, workspace);

    RELEASE(tmpA, workspace);
    RELEASE(tmpB, workspace);

    if (!supported)
        throw std::runtime_error("ShapeUtils::matrixProductShape method: unsupported combination of input shapes !");

    return newShape;
}
////////////////////////////////////////////////////////////////////////////////
std : : vector < int > ShapeUtils : : evalPermutFromTo ( const std : : vector < Nd4jLong > & shapeFrom , const std : : vector < Nd4jLong > & shapeTo ) {
auto rank = shapeFrom . size ( ) ;
if ( rank ! = shapeTo . size ( ) )
throw std : : runtime_error ( " ShapeUtils::evalPermutFromTo static method: the input shapes are not suitable for mutual permutation ! " ) ;
if ( std : : equal ( begin ( shapeFrom ) , end ( shapeFrom ) , begin ( shapeTo ) ) ) // if shapes are identical (permutation is unnecessary) then return empty vector
return std : : vector < int > ( ) ;
std : : vector < int > permutation ( rank , - 2 ) ; // vector to be returned
std : : vector < Nd4jLong > shapeTo2 ( shapeTo ) ; // make copy of const vector since we will change the content of shapeTo
for ( int i = 0 ; i < rank ; + + i )
for ( int j = 0 ; j < rank ; + + j )
if ( shapeFrom [ i ] = = shapeTo2 [ j ] ) {
permutation [ j ] = i ;
shapeTo2 [ j ] = - 2 ; // mark coincidence as -2 in order to not account index of shapeTo twice
break ;
}
if ( std : : find ( begin ( permutation ) , end ( permutation ) , - 2 ) ! = end ( permutation ) ) // if -2 is still present in vector then permutation is impossible
throw std : : runtime_error ( " ShapeUtils::evalPermutFromTo static method: the input shapes are not suitable for mutual permutation ! " ) ;
return permutation ;
}
////////////////////////////////////////////////////////////////////////////////
std : : vector < Nd4jLong > ShapeUtils : : composeShapeUsingDimsAndIdx ( const std : : vector < int > & dimsAndIdx ) {
auto size = dimsAndIdx . size ( ) ;
if ( size % 2 ! = 0 )
throw std : : runtime_error ( " ShapeUtils::composeShapeUsingDimsAndIdx static method: the size of input vector must be even ! " ) ;
size / = 2 ;
std : : vector < Nd4jLong > shape ( size ) ;
int index ;
for ( int i = 0 ; i < size ; + + i ) {
index = dimsAndIdx [ i + size ] ;
if ( index > size - 1 )
throw std : : runtime_error ( " ShapeUtils::composeShapeUsingDimsAndIdx static method: input index is too large ! " ) ;
shape [ index ] = dimsAndIdx [ i ] ;
}
return shape ;
}
////////////////////////////////////////////////////////////////////////////////
std : : vector < Nd4jLong > ShapeUtils : : evalShapeForMatmul ( const Nd4jLong * xShapeInfo , const Nd4jLong * yShapeInfo , const bool transX , const bool transY ) {
const auto xRank = xShapeInfo [ 0 ] ;
const auto yRank = yShapeInfo [ 0 ] ;
const Nd4jLong x0Dim = transX ? xShapeInfo [ xRank ] : xShapeInfo [ xRank - 1 ] ;
const Nd4jLong y0Dim = transY ? yShapeInfo [ yRank ] : yShapeInfo [ yRank - 1 ] ;
const Nd4jLong x1Dim = transX ? xShapeInfo [ xRank - 1 ] : xShapeInfo [ xRank ] ;
const Nd4jLong y1Dim = transY ? yShapeInfo [ yRank - 1 ] : yShapeInfo [ yRank ] ;
if ( xRank = = 1 & & yRank = = 1 ) { // dot case, output is scalar
if ( xShapeInfo [ 1 ] ! = yShapeInfo [ 1 ] ) {
nd4j_printf ( " ShapeUtils::evalShapeForMatmul method: since input arrays are vectors they must have the same length, but got x length = %i, y length = %i ! " , xShapeInfo [ 1 ] , yShapeInfo [ 1 ] ) ;
throw std : : invalid_argument ( " " ) ;
}
2019-06-15 13:34:34 +02:00
return std : : vector < Nd4jLong > ( { } ) ;
2019-06-06 14:21:15 +02:00
}
if ( xRank = = 1 & & yRank = = 2 ) { // vector x matrix, i.e. [4] x [4,5] = [5], output is vector
if ( xShapeInfo [ 1 ] ! = y0Dim ) {
nd4j_printf ( " ShapeUtils::evalShapeForMatmul method: input arrays have inconsistent shapes for vector-matrix product: x %s, y %s ! " , ShapeUtils : : shapeAsString ( xShapeInfo ) . c_str ( ) , ShapeUtils : : shapeAsString ( yShapeInfo ) . c_str ( ) ) ;
throw std : : invalid_argument ( " " ) ;
}
return std : : vector < Nd4jLong > ( { y1Dim } ) ;
}
if ( xRank = = 2 & & yRank = = 1 ) { // matrix x vector , i.e. [4,5] x [5] = [4], output is vector
if ( x1Dim ! = yShapeInfo [ 1 ] ) {
nd4j_printf ( " ShapeUtils::evalShapeForMatmul method: input arrays have inconsistent shapes for vector-matrix product: x %s, y %s ! " , ShapeUtils : : shapeAsString ( xShapeInfo ) . c_str ( ) , ShapeUtils : : shapeAsString ( yShapeInfo ) . c_str ( ) ) ;
throw std : : invalid_argument ( " " ) ;
}
return std : : vector < Nd4jLong > ( { x0Dim } ) ;
}
// rest cases - usual 2Dx2D or batched mmul
if ( xRank ! = yRank ) {
nd4j_printf ( " ShapeUtils::evalShapeForMatmul static method: the ranks of arrays must be the same, but got xRank = %i and yRank = %i ! \n " , xRank , yRank ) ;
throw std : : invalid_argument ( " " ) ;
}
if ( x1Dim ! = y0Dim ) {
nd4j_printf ( " ShapeUtils::evalShapeForMatmul static method: input shapes are inconsistent: xDim %i != yDim %i \n " , x1Dim , y0Dim ) ;
throw std : : invalid_argument ( " " ) ;
}
for ( int i = 0 ; i < xRank - 2 ; + + i )
if ( xShapeInfo [ i + 1 ] ! = yShapeInfo [ i + 1 ] ) {
nd4j_printf ( " ShapeUtils::evalShapeForMatmul static method: input shapes are inconsistent: xShape = %s, yShape = %s ! \n " , ShapeUtils : : shapeAsString ( xShapeInfo ) . c_str ( ) , ShapeUtils : : shapeAsString ( yShapeInfo ) . c_str ( ) ) ;
throw std : : invalid_argument ( " " ) ;
}
std : : vector < Nd4jLong > cShape ( xRank ) ;
// copy batch part of shape (if present)
for ( int i = 0 ; i < xRank - 2 ; + + i )
cShape [ i ] = xShapeInfo [ i + 1 ] ;
// copy rest part of shape (two dims: multiplication part)
cShape [ xRank - 2 ] = x0Dim ;
cShape [ xRank - 1 ] = y1Dim ;
return cShape ;
}
////////////////////////////////////////////////////////////////////////////////
// number of sub-arrays obtained by splitting the array along dimsToExclude (product of their sizes);
// an empty dimsToExclude and one listing every dimension are both treated as the whole array, i.e. 1
Nd4jLong ShapeUtils::getNumOfSubArrs(const Nd4jLong* shapeInfo, const std::vector<int>& dimsToExclude) {

    const bool wholeArray = dimsToExclude.empty() || dimsToExclude.size() == shape::rank(shapeInfo);
    if (wholeArray)
        return 1;

    Nd4jLong count = 1;
    for (const int dim : dimsToExclude)
        count *= shapeInfo[dim + 1];      // dimension sizes follow the rank stored at shapeInfo[0]

    return count;
}
////////////////////////////////////////////////////////////////////////////////
void ShapeUtils : : evalIdxRangesForSubArr ( const Nd4jLong subArrIdx , const Nd4jLong * shapeInfo , const std : : vector < int > & dimsToExclude , Nd4jLong * idxRanges ) {
const auto rank = shape : : rank ( shapeInfo ) ;
const auto subArrRank = static_cast < int > ( dimsToExclude . size ( ) ) ;
if ( subArrRank > rank )
throw std : : invalid_argument ( " ShapeUtils::evalIdxRangesForSubArr static method: dimsToExclude is empty or has size > rank of array ! " ) ;
if ( subArrRank = = 0 ) { // means whole array
memset ( idxRanges , 0 , 2 * rank * sizeof ( Nd4jLong ) ) ;
return ;
}
std : : vector < Nd4jLong > shapeOfSubArr ( subArrRank ) , indexes ( subArrRank ) ;
for ( int i = 0 ; i < subArrRank ; + + i )
shapeOfSubArr [ i ] = shapeInfo [ dimsToExclude [ i ] + 1 ] ;
2019-09-11 19:12:09 +02:00
shape : : index2coords ( subArrIdx , subArrRank , shapeOfSubArr . data ( ) , indexes . data ( ) ) ;
2019-06-06 14:21:15 +02:00
memset ( idxRanges , 0 , 2 * rank * sizeof ( Nd4jLong ) ) ;
for ( int i = 0 ; i < subArrRank ; + + i ) {
int currIdx = 2 * dimsToExclude [ i ] ;
idxRanges [ currIdx ] = indexes [ i ] ;
idxRanges [ currIdx + 1 ] = indexes [ i ] + 1 ;
}
}
////////////////////////////////////////////////////////////////////////////////
std : : vector < Nd4jLong > ShapeUtils : : evalDimsWithoutUnities ( const Nd4jLong * shapeInfo ) {
std : : vector < Nd4jLong > result ;
for ( int i = 1 ; i < = shapeInfo [ 0 ] ; + + i )
if ( shapeInfo [ i ] ! = 1 )
result . push_back ( shapeInfo [ i ] ) ;
return result ;
}
////////////////////////////////////////////////////////////////////////////////
void ShapeUtils : : updateStridesAndType ( Nd4jLong * dest , const Nd4jLong * source , const char order ) {
shape : : updateStrides ( dest , order ) ;
ArrayOptions : : copyDataType ( dest , source ) ;
}
////////////////////////////////////////////////////////////////////////////////
void ShapeUtils : : updateStridesAndType ( Nd4jLong * dest , const DataType dtype , const char order ) {
shape : : updateStrides ( dest , order ) ;
ArrayOptions : : setDataType ( dest , dtype ) ;
}
////////////////////////////////////////////////////////////////////////////////
std : : vector < int > ShapeUtils : : tadAxesForSimpleBroadcast ( const NDArray & max , const NDArray & min ) {
const int maxRank = max . rankOf ( ) ;
const int minRank = min . rankOf ( ) ;
const int diff = maxRank - minRank ;
Nd4jLong numOfMinTads ( 1 ) , numOfMaxTads ( 1 ) ;
std : : vector < int > maxTadDims ;
for ( int i = 0 ; i < minRank ; + + i ) {
if ( min . sizeAt ( i ) = = max . sizeAt ( diff + i ) )
maxTadDims . push_back ( diff + i ) ;
else {
numOfMinTads * = min . sizeAt ( i ) ;
numOfMaxTads * = max . sizeAt ( i ) ;
}
}
if ( min . lengthOf ( ) > max . lengthOf ( ) ) { // in this case tad is max array
for ( int i = 0 ; i < diff ; + + i )
numOfMaxTads * = max . sizeAt ( i ) ;
return numOfMaxTads = = 1 ? maxTadDims : std : : vector < int > ( ) ;
}
return numOfMinTads = = 1 ? maxTadDims : std : : vector < int > ( ) ;
}
2019-10-01 08:10:19 +02:00
// number of bytes occupied by the header of a buffer holding numStrings strings:
// numStrings + 1 offsets, each stored as an Nd4jLong
Nd4jLong ShapeUtils::stringBufferHeaderRequirements(Nd4jLong numStrings) {
    const Nd4jLong numOffsets = numStrings + 1;    // one extra offset is stored
    return numOffsets * sizeof(Nd4jLong);
}
////////////////////////////////////////////////////////////////////////////////
/*
bool ShapeUtils : : isSubArrayCase ( const NDArray & arr1 , const NDArray & arr2 , std : : vector < int > & sameDims ) {
if ( ! sameDims . empty ( ) )
sameDims . clear ( ) ;
const NDArray * max = & arr1 ;
const NDArray * min = & arr2 ;
2019-06-06 14:21:15 +02:00
2019-10-01 08:10:19 +02:00
if ( arr1 . lengthOf ( ) < arr2 . lengthOf ( ) ) {
max = & arr2 ;
min = & arr1 ;
2019-06-06 14:21:15 +02:00
}
2019-10-01 08:10:19 +02:00
int numUnitiesInMin = 0 ;
for ( int iMax = - 1 , iMin = - 1 ; iMax > = - max - > rankOf ( ) & & iMin > = - min - > rankOf ( ) ; ) {
if ( max - > sizeAt ( iMax ) = = 1 ) { // ignore unities in shape
- - iMax ;
continue ;
}
if ( min - > sizeAt ( iMin ) = = 1 ) { // ignore unities in shape
+ + numUnitiesInMin ;
- - iMin ;
continue ;
}
if ( max - > sizeAt ( iMax ) = = min - > sizeAt ( iMin ) ) {
sameDims . insert ( sameDims . begin ( ) , iMax + max - > rankOf ( ) ) ;
- - iMin ;
}
- - iMax ;
}
return sameDims . size ( ) + numUnitiesInMin = = min - > rankOf ( ) ;
}
*/
2019-06-06 14:21:15 +02:00
}