2019-06-06 14:21:15 +02:00
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
//
// @author Shyrma Yurii (iuriish@yahoo.com), created on 16.11.2017
//
2020-03-02 10:49:41 +01:00
# include <system/op_boilerplate.h>
2019-06-06 14:21:15 +02:00
# if NOT_EXCLUDED(OP_gather)
# include <ops/declarable/CustomOperations.h>
2019-11-26 18:29:09 +01:00
# include <ops/declarable/helpers/gather.h>
# include <ops/declarable/helpers/scatter.h>
2019-06-06 14:21:15 +02:00
2020-03-02 10:49:41 +01:00
namespace sd {
2019-06-06 14:21:15 +02:00
namespace ops {
//////////////////////////////////////////////////////////////////////////
/**
 * gather: copies slices of `input` selected by a set of indices along one axis.
 *
 * Inputs:
 *   0 - input array
 *   1 - (optional) indices array
 *   2 - (optional) array whose elements replace the integer arguments; element 0 is the axis
 * Integer arguments (read only when input 2 is absent):
 *   IArg 0    - axis (0 when no IArgs are given); a negative axis has the input rank added to it
 *   IArg 1..n - indices, used when input 1 is absent
 * Boolean arguments:
 *   BArg 0    - when true, indices are validated with helpers::checkIndices before gathering
 */
CUSTOM_OP_IMPL(gather, 1, 1, false, 0, -2) {
    auto input   = INPUT_VARIABLE(0);
    auto indices = block.width() > 1 ? INPUT_VARIABLE(1) : nullptr;
    auto output  = OUTPUT_VARIABLE(0);

    const bool checkIndices = block.getBArguments()->empty() ? false : B_ARG(0);

    // Edge case: empty indices -> empty output, nothing to compute
    if (indices != nullptr && indices->isEmpty()) {
        REQUIRE_TRUE(output->isEmpty(), 0, "Gather op: If indices are empty, output must also be empty");
        return Status::OK(); // No op
    }

    const int numOfIntArgs = block.numI();
    std::vector<int> intArgs;
    if (block.width() > 2) {
        intArgs = INPUT_VARIABLE(2)->template asVectorT<int>();
    } else if (numOfIntArgs == 0) {
        intArgs.emplace_back(0);
    } else {
        for (int i = 0; i < numOfIntArgs; ++i)
            intArgs.emplace_back(block.getIArguments()->at(i));
    }

    // intArgs[0] is read unconditionally below; an empty axis input would be out-of-bounds access
    REQUIRE_TRUE(!intArgs.empty(), 0, "GATHER op: axis array (input 2) must not be empty!");

    const int inputRank = input->rankOf();
    if (intArgs[0] < 0)
        intArgs[0] += inputRank;

    // input validation: after wrapping, the axis must lie in [0, inputRank)
    REQUIRE_TRUE(intArgs[0] < inputRank, 0, "GATHER op: input axis must be smaller than input array rank, but got %i and %i correspondingly!", intArgs[0], inputRank);
    REQUIRE_TRUE(intArgs[0] >= 0, 0, "GATHER op: input axis must not be smaller than -rank, but got %i (after adding rank) for input array of rank %i!", intArgs[0], inputRank);
    REQUIRE_TRUE(indices != nullptr || numOfIntArgs > 1, 0, "GATHER op: indices should be provided either as additional input array or as IntArguments !");

    if (checkIndices) {
        Nd4jLong numOfBadIndx = 0;
        if (indices != nullptr) {
            numOfBadIndx = helpers::checkIndices(block.launchContext(), *indices, *input, intArgs[0]);
        } else {
            // Indices were passed as IArgs 1..n: wrap them in a temporary INT64 array.
            // It has automatic storage so it is released even if validation throws below.
            NDArray iArgIndices(input->ordering(), {static_cast<int>(intArgs.size()) - 1}, std::vector<double>(intArgs.begin() + 1, intArgs.end()), DataType::INT64, block.launchContext());
            numOfBadIndx = helpers::checkIndices(block.launchContext(), iArgIndices, *input, intArgs[0]);
        }
        REQUIRE_TRUE(numOfBadIndx == 0, 0, "GATHER OP: please check elements of indices-array, total number of wrong elements is %lld!", numOfBadIndx);
    }

    helpers::gather(block.launchContext(), input, indices, output, intArgs);

    return Status::OK();
}
DECLARE_TYPES(gather) {
    // Data input 0 and output 0 accept any integer or floating type;
    // the indices input 1 is restricted to integer types.
    getOpDescriptor()
        ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS})
        ->setAllowedInputTypes(1, {ALL_INTS})
        ->setAllowedOutputTypes(0, {ALL_INTS, ALL_FLOATS});
}
/**
 * Shape function for gather.
 *
 * Output shape = input dims before `axis`, followed by the indices dims,
 * followed by input dims after `axis`.
 *  - Indices given as input 1: their full shape is inserted.
 *  - Indices given as IArgs 1..n: exactly one index (numI == 2) removes the axis
 *    dimension; several indices replace it with a dimension of length numI - 1.
 * The axis comes from element 0 of input 2 if present, else IArg 0 (default 0);
 * a negative axis has the input rank added to it.
 * When the indices array is empty, the output shape is flagged empty, matching
 * the op implementation's requirement that empty indices yield an empty output.
 */
DECLARE_SHAPE_FN(gather) {

    auto inputShapeInfo = inputShape->at(0);
    Nd4jLong* outputShapeInfo = nullptr;

    int axis = 0;
    if (block.width() > 2) {
        axis = INPUT_VARIABLE(2)->e<int>(0);
    } else
        axis = block.numI() > 0 ? block.getIArguments()->at(0) : 0;

    const int inputRank = shape::rank(inputShapeInfo);
    if (axis < 0)
        axis += inputRank;

    // after wrapping, the axis must lie in [0, inputRank)
    REQUIRE_TRUE(axis < inputRank, 0, "GATHER op: input axis must be smaller than input array rank, but got %i and %i correspondingly!", axis, inputRank);
    REQUIRE_TRUE(axis >= 0, 0, "GATHER op: input axis must not be smaller than -rank, but got %i (after adding rank) for input array of rank %i!", axis, inputRank);

    bool isEmpty = false;

    if (block.width() > 1) {
        auto indicesShapeInfo = inputShape->at(1);

        // empty indices must produce an empty output (enforced by the op implementation)
        isEmpty = shape::isEmpty(indicesShapeInfo);

        const int indicesRank = shape::rank(indicesShapeInfo);
        const int outputRank  = inputRank + indicesRank - 1;

        ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(outputRank), Nd4jLong);

        // fill output shapeInfo: dims before axis, indices dims, dims after axis
        outputShapeInfo[0] = outputRank;
        int shapeIdx = 1;
        for (int i = 0; i < axis; ++i)
            outputShapeInfo[shapeIdx++] = inputShapeInfo[i + 1];
        for (int i = 0; i < indicesRank; ++i)
            outputShapeInfo[shapeIdx++] = indicesShapeInfo[i + 1];
        for (int i = axis + 1; i < inputRank; ++i)
            outputShapeInfo[shapeIdx++] = inputShapeInfo[i + 1];
    }
    else if (block.numI() > 1) {
        // indices passed as IArgs 1..n: a single index acts as a rank-0 index, several as a vector
        const int indicesRank = block.numI() == 2 ? 0 : 1;
        const int outputRank  = inputRank + indicesRank - 1;

        ALLOCATE(outputShapeInfo, block.getWorkspace(), shape::shapeInfoLength(outputRank), Nd4jLong);

        // building shape manually
        outputShapeInfo[0] = outputRank;
        int shapeIdx = 1;
        for (int i = 0; i < axis; ++i)
            outputShapeInfo[shapeIdx++] = inputShapeInfo[i + 1];
        if (block.numI() > 2)
            outputShapeInfo[shapeIdx++] = block.numI() - 1;
        for (int i = axis + 1; i < inputRank; ++i)
            outputShapeInfo[shapeIdx++] = inputShapeInfo[i + 1];
    }
    else
        REQUIRE_TRUE(false, 0, "GATHER op: indices should be provided either as additional input array or as IntArguments !");

    ShapeUtils::updateStridesAndType(outputShapeInfo, inputShapeInfo, shape::order(inputShapeInfo));

    if (isEmpty) {
        ArrayOptions::setPropertyBit(outputShapeInfo, ARRAY_EMPTY);
    }

    auto result = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outputShapeInfo));
    RELEASE(outputShapeInfo, block.getWorkspace());
    return SHAPELIST(result);
}
}
}
# endif