2019-06-06 14:21:15 +02:00
/*******************************************************************************
* Copyright ( c ) 2015 - 2018 Skymind , Inc .
*
* This program and the accompanying materials are made available under the
* terms of the Apache License , Version 2.0 which is available at
* https : //www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing , software
* distributed under the License is distributed on an " AS IS " BASIS , WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND , either express or implied . See the
* License for the specific language governing permissions and limitations
* under the License .
*
* SPDX - License - Identifier : Apache - 2.0
* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
//
// @author raver119@gmail.com
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/transforms.h>
#include <array>
#include <memory>
#include <utility>
#include <vector>
2020-03-02 10:49:41 +01:00
namespace sd {
2019-11-03 11:37:19 +01:00
namespace ops {
2019-06-06 14:21:15 +02:00
//////////////////////////////////////////////////////////////////////////
2019-11-03 11:37:19 +01:00
/**
 * concat: joins all non-empty input arrays along a single axis and writes the result to output 0.
 *
 * Inputs  : the arrays to join. When B_ARG(0) is present and true, the LAST input is not joined
 *           but holds the axis (read as int from element 0).
 * Int args: INT_ARG(0) is the axis when it is not supplied through the last input.
 *
 * A negative axis is counted from the end (rank is added to it).
 * Empty inputs are skipped; rank-0 inputs are copied into vectors of length 1 before joining.
 * When every input is empty, the output is required to be empty as well and nothing is written.
 */
CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 0) {

    REQUIRE_TRUE(block.width() > 0, 0, "CONCAT op: No input arrays were provided");

    const bool isAxisInLastArr = block.getBArguments()->size() == 0 ? false : B_ARG(0);

    const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width();

    // first of all take into account possible presence of empty arrays
    // also if scalar is present -> copy its value to vector with length=1
    std::vector<const NDArray*> nonEmptyArrs;

    // Owns the temporary length-1 vectors created for rank-0 inputs. Holding them in
    // unique_ptr guarantees they are released on every exit path, including the early
    // exits taken when one of the validation checks below fails.
    std::vector<std::unique_ptr<NDArray>> promotedScalars;

    bool allOfSameType = true;
    const auto typeOfFirstArr = INPUT_VARIABLE(0)->dataType();

    for (int i = 0; i < numOfInArrs; ++i) {

        auto input = INPUT_VARIABLE(i);

        // TODO: the current tf.concat spec additionally requires every input to have rank > 0 and the same
        // rank as the first input. Those two checks are not enforced here for compatibility with legacy behaviour.

        if (input->isEmpty())
            continue;

        allOfSameType &= (typeOfFirstArr == input->dataType());

        if (input->rankOf() == 0) {
            std::unique_ptr<NDArray> vec(new NDArray('c', {1}, input->dataType(), block.launchContext()));
            vec->assign(input);
            nonEmptyArrs.push_back(vec.get());
            promotedScalars.push_back(std::move(vec));
        }
        else {
            nonEmptyArrs.push_back(input);
        }
    }

    const int numOfNonEmptyArrs = nonEmptyArrs.size();

    if (numOfNonEmptyArrs == 0) {
        // All inputs are empty arrays -> return empty, mainly for TF import compatibility (no op)
        REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "CONCAT op: If all input variables are empty, output must be empty");
        return Status::OK();
    }

    const int rank = nonEmptyArrs[0]->rankOf();     // look up to first non-empty array

    int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e<int>(0) : INT_ARG(0);
    if (axis < 0) {
        axis += rank;
    }

    // ******** input validation ******** //
    REQUIRE_TRUE(allOfSameType, 0, "CONCAT op: all of input arrays must have same type !");
    REQUIRE_TRUE(0 <= axis && (axis < rank || (axis == 0 && rank == 0)), 0, "CONCAT op: input axis must be in range [0, %i], but got %i instead!", rank - 1, axis);

    // ranks are checked for all arrays first, only then the individual dimensions
    for (int i = 1; i < numOfNonEmptyArrs; ++i) {
        REQUIRE_TRUE(nonEmptyArrs[i]->rankOf() == rank, 0, "CONCAT op: all input arrays must have the same rank !");
    }

    for (int i = 1; i < numOfNonEmptyArrs; ++i) {
        for (int dim = 0; dim < rank; ++dim) {
            if (dim != axis) {
                REQUIRE_TRUE(nonEmptyArrs[i]->sizeAt(dim) == nonEmptyArrs[0]->sizeAt(dim), 0, "CONCAT op: all input arrays must have the same dimensions (except those on input axis) !");
            }
        }
    }
    // ******** end of input validation ******** //

    auto output = OUTPUT_VARIABLE(0);

    // a single non-empty input is simply copied, otherwise the helper performs the actual join
    if (numOfNonEmptyArrs == 1)
        output->assign(nonEmptyArrs[0]);
    else
        helpers::concat(block.launchContext(), nonEmptyArrs, *output, axis);

    // the temporary length-1 vectors are released automatically by promotedScalars
    return Status::OK();
}
// Additional op names declared for concat: ParallelConcat, concat_v2 and concatv2.
// NOTE(review): presumably DECLARE_SYN makes each name resolve to the concat implementation above - confirm against the macro definition.
DECLARE_SYN ( ParallelConcat , concat ) ;
DECLARE_SYN ( concat_v2 , concat ) ;
DECLARE_SYN ( concatv2 , concat ) ;
// Type constraints of concat: inputs of any data type are allowed.
DECLARE_TYPES(concat) {
    auto descriptor = getOpDescriptor();
    descriptor->setAllowedInputTypes(sd::DataType::ANY);
    // the same-mode setting is kept disabled:
    // descriptor->setSameMode(true);
}
2019-11-03 11:37:19 +01:00
//////////////////////////////////////////////////////////////////////////
2019-06-06 14:21:15 +02:00
/**
 * Shape function of concat.
 *
 * The output shape is a copy of the first collected input shape in which the extent along the
 * concat axis is replaced by the sum of the extents of all collected input shapes.
 *
 * The axis is taken from the last input when B_ARG(0) is present and true, otherwise from INT_ARG(0);
 * a negative axis is counted from the end (rank is added to it).
 * Rank-0 input shapes are replaced by vector shapes: length 0 when the input is empty, length 1 otherwise.
 */
DECLARE_SHAPE_FN(concat) {

    REQUIRE_TRUE(block.width() > 0, 0, "CONCAT op: No input arrays were provided");

    const bool isAxisInLastArr = block.getBArguments()->size() == 0 ? false : B_ARG(0);

    const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width();

    // when the only input is the axis array there is no shape to start from
    REQUIRE_TRUE(numOfInArrs > 0, 0, "CONCAT op: No input arrays were provided");

    // first of all take into account possible presence of empty arrays
    // also if scalar is present -> use the shape of vector with length=1 instead
    ShapeList arrShapes;

    for (int i = 0; i < numOfInArrs; ++i) {

        const auto inShape = inputShape->at(i);

        if (inShape[0] == 0) {
            // rank-0 input: substitute a vector shape. The data type of the first input is used for
            // every substituted shape (the op itself requires all joined inputs to share one type).
            const int length = shape::isEmpty(inShape) ? 0 : 1;
            arrShapes.push_back(ConstantShapeHelper::getInstance()->vectorShapeInfo(length, INPUT_VARIABLE(0)->dataType()));
        }
        else {
            arrShapes.push_back(inShape);
        }
    }

    const int numOfArrs = arrShapes.size();

    const auto firstShape = arrShapes.at(0);
    const int rank = shape::rank(firstShape);

    int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e<int>(0) : INT_ARG(0);
    if (axis < 0) {
        axis += rank;
    }

    // ******** input validation ******** //
    REQUIRE_TRUE(0 <= axis && axis < rank, 0, "CONCAT op: input axis must be in range [0, %i], but got %i instead!", rank - 1, axis);

    // ranks are checked for all shapes first, only then the individual dimensions
    for (int i = 1; i < numOfArrs; ++i) {
        REQUIRE_TRUE(shape::rank(arrShapes.at(i)) == rank, 0, "CONCAT op: all input arrays must have the same rank !");
    }

    for (int i = 1; i < numOfArrs; ++i) {
        for (int dim = 0; dim < rank; ++dim) {
            if (dim != axis) {
                // shape info layout: element 0 is the rank, elements 1..rank are the dimensions
                REQUIRE_TRUE(arrShapes.at(i)[dim + 1] == arrShapes.at(0)[dim + 1], 0, "CONCAT op: all input arrays must have the same dimensions (except those on input axis) !");
            }
        }
    }
    // ******** end of input validation ******** //

    Nd4jLong* outShapeInfo(nullptr);
    COPY_SHAPE(firstShape, outShapeInfo);

    // case when we have only one input array
    if (numOfArrs == 1) {
        ShapeUtils::updateStridesAndType(outShapeInfo, firstShape, shape::order(firstShape));
        return SHAPELIST(CONSTANT(outShapeInfo));
    }

    // accumulate the extents of the remaining inputs along the concat axis
    for (int i = 1; i < numOfArrs; ++i)
        outShapeInfo[axis + 1] += arrShapes.at(i)[axis + 1];

    ShapeUtils::updateStridesAndType(outShapeInfo, firstShape, shape::order(firstShape));

    // the working copy is released once the shape info returned by the helper has been obtained
    auto result = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outShapeInfo));
    RELEASE(outShapeInfo, block.getWorkspace());
    return SHAPELIST(result);
}
// //////////////////////////////////////////////////////////////////////////
// CUSTOM_OP_IMPL(concat, -1, 1, false, 0, -2){
// // do something here{
// NDArray<T> *last = INPUT_VARIABLE((int) block.width() - 1);
// int _dimension = 0;
// if (block.numI() > 0)
// _dimension = INT_ARG(0);
// else {
// _dimension = (int) last->e(0);
// }
// // we want to ensure that all
// NDArray<T> *first = nullptr;
// auto output = OUTPUT_VARIABLE(0);
// int elements = 0;
// for (int e = 0; e < block.width(); e++) {
// auto arr = INPUT_VARIABLE(e);
// if (!arr->isEmpty())
// elements++;
// // we must find first non-empty element here
// if (!arr->isEmpty() && first == nullptr)
// first = arr;
// }
// REQUIRE_TRUE(first != nullptr, 0, "Concat: at least 1 non-empty input required!");
// // it's possible to get into situation when your input has only 1 input. That's just assign
// if (elements == 1) {
// output->assign(first);
// return Status::OK();
// }
// bool oldScalars = first->rankOf() == 2 && first->isScalar();
// auto buffers = new Nd4jPointer[elements];
// auto shapes = new Nd4jPointer[elements];
2020-05-09 07:06:14 +02:00
// buffers[0] = (Nd4jPointer) first->buffer();
// shapes[0] = (Nd4jPointer) first->shapeInfo();
2019-06-06 14:21:15 +02:00
// if (_dimension < 0)
// _dimension += first->rankOf();
2020-03-02 10:49:41 +01:00
// if (sd::Environment::getInstance()->isDebugAndVerbose()) {
2019-06-06 14:21:15 +02:00
// printf("Shape %i: ", 0);
// shape::printShapeInfoLinear((Nd4jLong *) shapes[0]);
// }
// int er = 0;
// for (int e = 0; e < block.width(); e++) {
// Variable<T> *var = block.variable(e);
// auto array = var->getNDArray();
// if (array->isEmpty())
// continue;
2020-05-09 07:06:14 +02:00
// buffers[er] = reinterpret_cast<Nd4jPointer>(array->buffer());
// shapes[er++] = reinterpret_cast<Nd4jPointer>(array->shapeInfo());
2019-06-06 14:21:15 +02:00
// oldScalars &= array->rankOf() == 2 && array->isScalar();
2020-03-02 10:49:41 +01:00
// if (sd::Environment::getInstance()->isDebugAndVerbose()) {
2019-06-06 14:21:15 +02:00
// printf("Shape %i: ", e);
// shape::printShapeInfoLinear(array->shapeInfo());
// }
// }
2020-03-02 10:49:41 +01:00
// if (sd::Environment::getInstance()->isDebugAndVerbose())
2019-06-06 14:21:15 +02:00
// fflush(stdout);
// if (oldScalars) {
// nd4j_debug("OLD_SCALARS!\n","");
// _dimension = 1;
// }
2020-05-09 07:06:14 +02:00
// sd::SpecialMethods<T>::concatCpuGeneric(_dimension, elements, buffers, shapes, output->buffer(), output->shapeInfo());
2019-06-06 14:21:15 +02:00
// STORE_RESULT(*output);
2020-03-02 10:49:41 +01:00
// if (sd::Environment::getInstance()->isDebugAndVerbose())
2019-06-06 14:21:15 +02:00
// output->printShapeInfo("Concat result shape");
// delete[] buffers;
// delete[] shapes;
// return ND4J_STATUS_OK;
// }
// DECLARE_SYN(ParallelConcat, concat);
// DECLARE_SYN(concat_v2, concat);
// DECLARE_SYN(concatv2, concat);
2019-06-15 13:34:34 +02:00
2019-06-06 14:21:15 +02:00
// DECLARE_SHAPE_FN(concat) {
// auto inp = inputShape->at(0);
// int _dimension = INT_ARG(0);
// NDArray<T>* first = nullptr;
// auto last = inputShape->at(inputShape->size() - 1);
// Nd4jLong elements = 0;
// Nd4jLong *newShape;
// for (int e = 0; e < inputShape->size(); e++) {
// auto s = INPUT_VARIABLE(e);
// if (!s->isEmpty()) {
// elements++;
// if (first == nullptr)
// first = s;
// }
// }
// { // special cases for 0D concat
// bool allScalars = true;
// bool hasScalars = false;
// for (int e = 0; e < block.width(); e++) {
// auto c = INPUT_VARIABLE(e);
// if (c->isEmpty())
// continue;
// allScalars &= c->rankOf() == 0;
// hasScalars |= c->rankOf() == 0;
// }
// // all scalars
// if (allScalars) {
// ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(1), Nd4jLong);
// shape::shapeBuffer(1, &elements, newShape);
// return SHAPELIST(newShape);
// }
// // any scalar
// if (hasScalars) {
// ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(1), Nd4jLong);
// Nd4jLong length = shape::length(inp);
// for (int i = 1; i < block.width(); i++) {
// auto c = INPUT_VARIABLE(i);
// if (c->isEmpty())
// continue;
// length += c->lengthOf();
// }
// shape::shapeBuffer(1, &length, newShape);
// return SHAPELIST(newShape);
// }
// }
2019-06-15 13:34:34 +02:00
2019-06-06 14:21:15 +02:00
// ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(first->shapeInfo()), Nd4jLong);
// if (_dimension < 0)
// _dimension += first->rankOf();
// std::memcpy(newShape, first->shapeInfo(), shape::shapeInfoByteLength(first->shapeInfo()));
// for (int i = 0; i < inputShape->size(); i++) {
// auto s = INPUT_VARIABLE(i);
// // FIXME: s == first is bad, but fast. alternatively we can subtract first size out of result
// if (s->isEmpty() || s == first)
// continue;
// newShape[_dimension + 1] += shape::shapeOf(inputShape->at(i))[_dimension];
// }
// shape::updateStrides(newShape, first->ordering());
// return SHAPELIST(newShape);
// }
2019-11-03 11:37:19 +01:00
//////////////////////////////////////////////////////////////////////////
/**
 * concat_bp: backprop of concat. Cuts the incoming gradient into consecutive slices along the
 * concat axis and assigns one slice to each output.
 *
 * Inputs  : the arrays originally passed to concat, followed by the gradient array (epsilonNext).
 *           When B_ARG(0) is present and true, one more trailing input holds the axis
 *           (read as int from element 0).
 * Int args: INT_ARG(0) is the axis when it is not supplied through the last input.
 * Outputs : output e receives the slice of epsilonNext whose extent along the axis equals
 *           the extent of input e along that axis.
 *
 * A negative axis is counted from the end, using the rank of input 0, regardless of where
 * the axis value comes from.
 */
CUSTOM_OP_IMPL(concat_bp, -1, -1, false, 0, 0) {

    const bool isAxisInLastArr = block.getBArguments()->size() == 0 ? false : B_ARG(0);

    const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width();

    // epsilonNext is read from index numOfInArrs - 1, so at least one such input has to exist
    REQUIRE_TRUE(numOfInArrs > 0, 0, "CONCAT_BP op: gradient array was not provided");

    auto epsilonNext = INPUT_VARIABLE(numOfInArrs - 1);

    // wrap a negative axis for BOTH sources of the axis value (int argument and last input array),
    // the same way the forward concat op does
    int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e<int>(0) : INT_ARG(0);
    if (axis < 0) {
        axis += INPUT_VARIABLE(0)->rankOf();
    }

    const int epsRank = epsilonNext->rankOf();

    int startPos = 0;

    for (int e = 0; e < numOfInArrs - 1; ++e) {

        auto originalChunk = INPUT_VARIABLE(e);
        auto epsilonChunk  = OUTPUT_VARIABLE(e);

        const int width = originalChunk->sizeAt(axis);

        // [start, end) pair per dimension of epsilonNext
        // NOTE(review): a {0, 0} pair presumably selects the whole dimension - confirm against NDArray::operator()
        std::vector<Nd4jLong> indices(2 * epsRank);

        for (int dim = 0; dim < epsRank; ++dim) {
            if (dim == axis) {
                indices[2 * dim]     = startPos;
                indices[2 * dim + 1] = startPos + width;
            }
            else {
                indices[2 * dim]     = 0;
                indices[2 * dim + 1] = 0;
            }
        }

        auto subarray = (*epsilonNext)(indices, true);
        epsilonChunk->assign(subarray);

        startPos += width;
    }

    return Status::OK();
}
// Type constraints of concat_bp: inputs of any data type are allowed, outputs are restricted to ALL_FLOATS.
DECLARE_TYPES(concat_bp) {
    getOpDescriptor()
        ->setAllowedInputTypes(sd::DataType::ANY)
        ->setAllowedOutputTypes({ALL_FLOATS});
}
/**
 * Shape function of concat_bp: produces one output shape per input except the last counted one
 * (which the op implementation reads as the incoming gradient). Each output shape is built from
 * the data type, order, dimensions and rank of the corresponding input shape.
 * When B_ARG(0) is present and true, the trailing axis input is excluded from the count first.
 */
DECLARE_SHAPE_FN(concat_bp) {

    const bool axisIsLastInput = block.getBArguments()->size() != 0 && B_ARG(0);

    int inputsCount = block.width();
    if (axisIsLastInput)
        inputsCount -= 1;

    // the last remaining input is not mirrored into an output shape
    const int gradientsCount = inputsCount - 1;

    auto shapes = SHAPELIST();

    for (int i = 0; i < gradientsCount; ++i) {
        auto srcShape = inputShape->at(i);

        ShapeDescriptor descriptor(ArrayOptions::dataType(srcShape), shape::order(srcShape), shape::shapeOf(srcShape), shape::rank(srcShape));
        shapes->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(descriptor));
    }

    return shapes;
}
}
2019-06-06 14:21:15 +02:00
}