/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018
//
# include <ops/declarable/helpers/transforms.h>
# include <helpers/ShapeUtils.h>
# include <numeric>
# include <helpers/Loops.h>
namespace sd {
namespace ops {
namespace helpers {
////////////////////////////////////////////////////////////////////////
/**
 * Element-wise gatherND kernel.
 *
 * For every element of `output` the kernel
 *   1. converts the linear output index into output coordinates,
 *   2. locates, inside `indices`, the tuple that belongs to that output element
 *      (the tuple runs along the last dimension of `indices`),
 *   3. builds input coordinates whose first `yLastDim` entries are the values of
 *      that tuple and whose remaining entries are taken from the output coordinates,
 *   4. copies the addressed input element into the output element.
 *
 * No range check is performed on the values read from `indices`.
 *
 * @tparam X element type of `input` and `output`
 * @tparam Y element type of `indices`
 * @param input   array to read elements from
 * @param indices array whose last dimension holds the leading input coordinates
 * @param output  array to write the gathered elements to
 */
template <typename X, typename Y>
static void gatherND_(NDArray& input, NDArray& indices, NDArray& output) {

    const X* x = reinterpret_cast<X*>(input.buffer());
    const Y* y = reinterpret_cast<Y*>(indices.buffer());
    X* z = reinterpret_cast<X*>(output.buffer());

    const int xRank = input.rankOf();
    const int yRank = indices.rankOf();
    const int zRank = output.rankOf();

    const Nd4jLong zLen = output.lengthOf();

    // length of one index tuple
    const uint yLastDim = indices.sizeAt(-1);

    // stride between consecutive members of one index tuple (loop invariant)
    const auto yLastStride = indices.stridesOf()[yRank - 1];

    const int diff = zRank - xRank;
    const bool bEqual = yLastDim == xRank;

    auto func = PRAGMA_THREADS_FOR {

        int xCoords[MAX_RANK], zCoords[MAX_RANK];

        for (auto i = start; i < stop; i++) {

            shape::index2coordsCPU(start, i, output.shapeInfo(), zCoords);

            const auto zOffset = shape::getOffset(output.shapeInfo(), zCoords);

            // Temporarily zero the coordinate at the last indices-dimension so that the
            // offset computed against the indices shape points at the first member of the
            // tuple (presumably getOffset only reads the first yRank coordinates - TODO confirm).
            const int savedCoord = zCoords[yRank - 1];
            zCoords[yRank - 1] = 0;
            const auto yOffset = shape::getOffset(indices.shapeInfo(), zCoords);
            zCoords[yRank - 1] = savedCoord;

            // Seed the input coordinates from the output coordinates; the leading
            // yLastDim entries are overwritten with the index tuple right below.
            if (bEqual)
                memcpy(xCoords, zCoords, zRank * sizeof(int));
            else if (diff >= 0)
                memcpy(xCoords, zCoords + diff, xRank * sizeof(int));
            else
                memcpy(xCoords - diff, zCoords, zRank * sizeof(int));

            for (uint j = 0; j < yLastDim; ++j)
                xCoords[j] = y[yOffset + j * yLastStride];

            const auto xOffset = shape::getOffset(input.shapeInfo(), xCoords);

            z[zOffset] = x[xOffset];
        }
    };

    samediff::Threads::parallel_tad(func, 0, zLen);
}
////////////////////////////////////////////////////////////////////////
void gatherND ( sd : : LaunchContext * context , NDArray & input , NDArray & indices , NDArray & output ) {
BUILD_DOUBLE_SELECTOR ( input . dataType ( ) , indices . dataType ( ) , gatherND_ , ( input , indices , output ) , LIBND4J_TYPES , INDEXING_TYPES ) ;
}
////////////////////////////////////////////////////////////////////////
/**
 * Gathers slices of `input` taken along `axis` into `output`.
 *
 * The axis is intArgs[0] (0 when intArgs is empty); a negative axis is shifted by
 * the rank of `input`. The positions to gather are read from `indices` when it is
 * non-null, otherwise from intArgs[1], intArgs[2], ...
 *
 * @tparam T element type of `input`
 * @param input   array to gather from
 * @param indices optional array of positions along `axis`; may be nullptr
 * @param output  array receiving the gathered data
 * @param intArgs axis followed (when `indices` is nullptr) by the positions to gather
 * @throws std::runtime_error if a requested position is negative or is not smaller
 *         than input->sizeAt(axis)
 */
template <typename T>
static void gather_(NDArray* input, const NDArray* indices, NDArray* output, const std::vector<int>& intArgs) {

    int axis = intArgs.size() > 0 ? intArgs[0] : 0;
    const int inputRank = input->rankOf();
    if (axis < 0)
        axis += inputRank;

    const int numOfIntArgs = intArgs.size();

    if (indices != nullptr) {

        // validate up front: the worker lambdas below must not throw
        for (Nd4jLong i = 0; i < indices->lengthOf(); ++i) {
            const auto position = indices->e<Nd4jLong>(i);
            // a negative position would otherwise be used unchecked in offset arithmetic
            if (position < 0)
                throw std::runtime_error("helpers::gather function: indices array contains negative elements, each element must be non-negative !");
            if (position >= input->sizeAt(axis))
                throw std::runtime_error(" helpers::gather function: indices array contains wrong elements, each element must be smaller than corresponding dimension of input array ! ");
        }

        // first case: indices consist of only one scalar
        if (indices->isScalar()) {

            const auto idx = indices->e<Nd4jLong>(0);

            if (input->rankOf() <= 1) {
                // For scalar indices, rank 0 or 1 input: can't do tensor along dimension 0
                // as this is whole array... instead, we want to get a scalar
                auto scalarNDArray = input->e(idx);
                output->assign(scalarNDArray);
            }
            else {
                auto dimensions = ShapeUtils::evalDimsToExclude(input->rankOf(), {axis});
                auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), dimensions);

                // wrap the idx-th TAD of the input buffer (presumably a non-owning view - TODO confirm)
                auto tadArr = NDArray(reinterpret_cast<void*>(reinterpret_cast<T*>(input->buffer()) + tadPack.primaryOffsets()[idx]), tadPack.primaryShapeInfo(), output->getContext());
                output->assign(&tadArr);
            }
        }
        else if (input->rankOf() == 1 && indices->isVector()) {
            // special case: rank-1 input with vector indices, copy element by element
            auto func = PRAGMA_THREADS_FOR {
                for (auto e = start; e < stop; e++)
                    output->p(e, input->e<T>(indices->e<Nd4jLong>(e)));
            };

            samediff::Threads::parallel_for(func, 0, indices->lengthOf());
        }
        else {

            std::vector<int> dimsOut(indices->rankOf());
            std::iota(dimsOut.begin(), dimsOut.end(), axis);   // fill with axis, axis+1, ..., axis + indices->rankOf() - 1

            const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(output->shapeInfo(), dimsOut);

            auto func = PRAGMA_THREADS_FOR {
                for (auto i = start; i < stop; i++) {
                    NDArray subArrOut = (*output)(i, dimsOut);
                    NDArray subArrIn = (*input)(indices->e<Nd4jLong>(i), {axis});
                    subArrOut.assign(subArrIn);
                }
            };

            samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
        }
    }
    else {

        // positions are passed through intArgs[1..]; intArgs[0] is the axis
        for (int i = 1; i < numOfIntArgs; ++i) {
            // a negative position would otherwise be used unchecked as a sub-array index
            if (intArgs[i] < 0)
                throw std::runtime_error("helpers::gather function: some of input indexes is negative, each index must be non-negative !");
            if (intArgs[i] >= input->sizeAt(axis))
                throw std::runtime_error(" helpers::gather function: some of input indexes is larger than corresponding shape of input array ! ");
        }

        // we only allow scalar/vector case here
        if (numOfIntArgs == 2) {   // scalar case
            output->assign((*input)(intArgs[1], {axis}));
        }
        else {   // vector case

            const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(output->shapeInfo(), {axis});

            auto func = PRAGMA_THREADS_FOR {
                for (auto i = start; i < stop; i++) {
                    NDArray subArrOut = (*output)(i, {axis});
                    NDArray subArrIn = (*input)(intArgs[i + 1], {axis});
                    subArrOut.assign(subArrIn);
                }
            };

            samediff::Threads::parallel_tad(func, 0, numOfSubArrs);
        }
    }
}
/**
 * Public gather entry point: forwards to gather_ with the template argument
 * selected from the runtime data type of `input`.
 *
 * @param input   array to gather from
 * @param indices optional array of positions; may be nullptr, in which case the
 *                positions are taken from `intArgs`
 * @param output  array receiving the gathered data
 * @param intArgs integer arguments forwarded unchanged to gather_
 */
void gather(NDArray* input, const NDArray* indices, NDArray* output, const std::vector<int>& intArgs) {
    BUILD_SINGLE_SELECTOR(input->dataType(), gather_, (input, indices, output, intArgs), LIBND4J_TYPES);
}
}
}
}