* initial commit Signed-off-by: raver119@gmail.com <raver119@gmail.com> * another initial commit Signed-off-by: raver119@gmail.com <raver119@gmail.com> * another initial commit Signed-off-by: raver119@gmail.com <raver119@gmail.com> * one more initial commit Signed-off-by: raver119@gmail.com <raver119@gmail.com> * next step Signed-off-by: raver119@gmail.com <raver119@gmail.com> * next step Signed-off-by: raver119@gmail.com <raver119@gmail.com> * next step Signed-off-by: raver119@gmail.com <raver119@gmail.com> * next step Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Refactored buffer() and shapeInfo() methods usage with NDArray class. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt Graph class methods to use const shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt choose op to use constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt where op shape method to use constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt lstsq op to use constant empty shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt matrix_diag_part op shape routine to use constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt determinant ops to use constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt mean_pairwssqerr_loss ops to use constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt ops shape methods. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt shape methods for loss ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt log_loss op shape method. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt shape methods for ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt dilation2d ops shape methods. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted deconv2d ops shape methods. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted dynamicRNN op shape method. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted shape methods for ops. 
Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted shape methods for lstm layer ops. Signed-off-by: shugeo <sgazeos@gmail.com> * few updates Signed-off-by: raver119@gmail.com <raver119@gmail.com> * first cuda tweak Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Adopt constant shapes for sconv2d ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt constant shapes for gru ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt constant shapes with shape methods for segment ops and so on. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted constant shapes with unsorted_segment_* ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted constant shapes with gamma op shape method. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted shape methods of reduce_stddev ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted shape methods for reduce_* ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt shape method for squeeze op. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt strided_slice shape method. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored concat op shape method to adopt constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted shape method for mirror_pad op. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted split op shape method. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted tile ops shape methods. Signed-off-by: shugeo <sgazeos@gmail.com> * Added const cast for mkldnn routines handles. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored logSoftMaxForVector_ routine to conform with proper data and shape pointer casts. Signed-off-by: shugeo <sgazeos@gmail.com> * Cosmetic changes to proper usage of constant pointers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored a couple shape comparators for strides and addBias helpers to proper use data pointers with inplace option. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored depthToSpace helpers. 
Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored histogram helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored im2col helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored gather and gatherND helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage on percentile helper. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed gather shape with helpers and range buffer usage. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage with space to depth helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage and constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage with LUP decomposition> Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored onehot_ helper. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored pad and prefix to use constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactoed softmax helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed space to batch helpers to use buffers properly. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed stack and split helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage with sparse to dense helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage with mindistance_ helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage with tile helper. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed constant shape usage. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed constant shape usage with legacy pairwise bool ops. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored a couple of methods to adopt constant shape usage. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed broadcasting with constant shape." Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed const usage with inplace reverse and constant shapes with legacy reduction. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored legacy ops with const shapes. 
Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored sort to adopt constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected sort for constant shape usage. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed constant shape usage with special methods. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored Context to conform with constant shape usage. Signed-off-by: shugeo <sgazeos@gmail.com> * CUDA broadcasting headers Signed-off-by: raver119@gmail.com <raver119@gmail.com> * pairwise/indexreduce/random headers Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Refactored native ops to adopt constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * legacy reduce3/scalar headers Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Corrected pullRow signature and tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected routines to proper use of constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored tests to use constant shapes properly. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored legacy ops tests to use constant shapes properly. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored buffer usage with NDArray tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed native ops tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed special concat routine. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage with test. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed buffer usage with a test. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored TAD.h and tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored calcStrides* routines to use constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed miscelaneous errors with constant shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * NativeOps const changes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Corrected definitions for declared functions. 
Signed-off-by: shugeo <sgazeos@gmail.com> * NativeOps const changes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * few more const changes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Fixed const shapes with shape routines. Signed-off-by: shugeo <sgazeos@gmail.com> * few more const changes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Fixed shape method for broadcastable case. Signed-off-by: shugeo <sgazeos@gmail.com> * few more const changes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * xw_plus_b BP shape fn restored Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Fixed signatures with broadcasting. Signed-off-by: shugeo <sgazeos@gmail.com> * Repaired backprops shape methods for a set of operations. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored broadcast bool for cuda. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored methods for 3 args with const qualifier. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed a couple of kernel signatures for broadcasting. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed kernels signatures for const buffers and shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored pairwise methods to persistent buffers and shapes usage. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt const to buffers and shapes with kernels. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopt const to buffers and shapes with scalar kernels. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored indexreduce kernels signatures to use const buffers and shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored pairwise kernels to adopt cons shapes and buffers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored pairwise bool kernels to adopt cons shapes and buffers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored random special ops to conform with const shapes and buffers. 
Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored native ops to conform with const shapes and buffers under cuda platform. Signed-off-by: shugeo <sgazeos@gmail.com> * Cosmetical changes only. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed const shapes and buffers error. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected start pos routine. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored methods to conform with const shapes and buffers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored helpers to use proper methods instead. Signed-off-by: shugeo <sgazeos@gmail.com> * bunch of changes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * next bunch of changes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * next bunch of changes Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Fixed execScalar declaration. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed execScalar declaration. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected const shape cases with sort and so on. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed const shapes for sort. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored kernel declarations to adopt const shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed kernels declarations to adopt const shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected kernel declarations to adopt const shapes and buffers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed kernels declarations to adopt const shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed segment helpers kernels declarations and so on to adopt const shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed const shape usage with segment and solve helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed kernel declaration with adjustWeight helper. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed cuda implementations for constant shape helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted const shape usage with kernels. 
Signed-off-by: shugeo <sgazeos@gmail.com> * Adopted top_k kernels to use const shapes and buffers. Signed-off-by: shugeo <sgazeos@gmail.com> * Corrected kernels declarations to adopt const shapes with helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored NDArray definitions to adopt const shapes and buffers. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed const shapes with image suppression helpers. Signed-off-by: shugeo <sgazeos@gmail.com> * Slight improvement with buffers. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored buffer usage. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored buffer usage with tests. Signed-off-by: shugeo <sgazeos@gmail.com> * Fixed const shape usage with definitions. Signed-off-by: shugeo <sgazeos@gmail.com> * minor updates on cpu side Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Refactored const shape usage with ConstantDescritor and native ops with cuda platform. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored tear and tile kernels to adopt with const shapes. Signed-off-by: shugeo <sgazeos@gmail.com> * softmax_loop fix Signed-off-by: raver119 <raver119@gmail.com> * update missing signature Signed-off-by: raver119@gmail.com <raver119@gmail.com> * softmax again Signed-off-by: raver119@gmail.com <raver119@gmail.com> * few more missing consts Signed-off-by: raver119 <raver119@gmail.com> * new methods updated Signed-off-by: raver119@gmail.com <raver119@gmail.com> Co-authored-by: shugeo <sgazeos@gmail.com>
439 lines
16 KiB
C++
439 lines
16 KiB
C++
/*******************************************************************************
|
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
*
|
|
* This program and the accompanying materials are made available under the
|
|
* terms of the Apache License, Version 2.0 which is available at
|
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
* License for the specific language governing permissions and limitations
|
|
* under the License.
|
|
*
|
|
* SPDX-License-Identifier: Apache-2.0
|
|
******************************************************************************/
|
|
|
|
//
|
|
// @author raver119@gmail.com
|
|
// @author Yurii Shyrma (iuriish@yahoo.com)
|
|
//
|
|
|
|
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/transforms.h>

#include <array>
#include <memory>
#include <vector>
|
|
|
|
namespace sd {
|
|
namespace ops {
|
|
|
|
|
|
//////////////////////////////////////////////////////////////////////////
|
|
// Concatenates all non-empty input arrays along a given axis.
// The axis comes from INT_ARG(0), or — when B_ARG(0) is true — from the last
// input array. Scalar (rank-0) inputs are promoted to length-1 vectors before
// concatenation; empty inputs are skipped entirely.
CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 0) {

    REQUIRE_TRUE(block.width() > 0, 0, "CONCAT op: No input arrays were provided");

    const bool isAxisInLastArr = block.getBArguments()->size() == 0 ? false : B_ARG(0);

    // when the axis is supplied as the last input array, it is not part of the data to concatenate
    const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width();

    // first of all take into account possible presence of empty arrays
    // also if scalar is present -> copy its value to vector with length=1
    std::vector<const NDArray*> nonEmptyArrs;

    // owns the temporary length-1 vectors created from scalar inputs, so they are
    // released automatically even when a REQUIRE_TRUE check below returns early
    // (the previous raw-pointer + index-list scheme leaked on those early returns)
    std::vector<std::unique_ptr<NDArray>> scalarsAsVectors;

    bool allOfSameType = true;
    auto rankOfFirstArr = block.width() > 0 ? INPUT_VARIABLE(0)->rankOf() : 0;   // referenced only by the commented-out tf.concat checks below
    auto typeOfFirstArr = block.width() > 0 ? INPUT_VARIABLE(0)->dataType() : block.dataType();

    for(int i = 0; i < numOfInArrs; ++i) {
        auto input = INPUT_VARIABLE(i);
        auto currentRank = input->rankOf();

        // TODO: follow two lines are in accordance to current tf.concat spec. Commented for compatibility with legacy
        // REQUIRE_TRUE(currentRank > 0, 0, "Rank of input variable %i must be greater 0, but is %lld instead.", i, currentRank);
        // REQUIRE_TRUE(rankOfFirstArr == currentRank, 0, "Number of dimensions in concat should be equals, but for %i input variable %lld != %lld appears.", i, currentRank, rankOfFirstArr);

        if(!input->isEmpty()) {

            allOfSameType &= (typeOfFirstArr == input->dataType());

            if(input->rankOf() == 0) {
                // promote the scalar to a length-1 vector so it can be concatenated
                auto vec = std::make_unique<NDArray>('c', std::vector<Nd4jLong>({1}), input->dataType(), block.launchContext());
                vec->assign(input);
                nonEmptyArrs.push_back(vec.get());
                scalarsAsVectors.push_back(std::move(vec));
            }
            else {
                nonEmptyArrs.push_back(input);
            }
        }
    }

    const int numOfNonEmptyArrs = nonEmptyArrs.size();

    if(numOfNonEmptyArrs == 0) {
        // All inputs are empty arrays -> return empty, mainly for TF import compatibility (no op)
        REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "CONCAT op: If all input variables are empty, output must be empty");
        return Status::OK();
    }

    const int rank = nonEmptyArrs[0]->rankOf(); // look up to first non-empty array

    int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e<int>(0) : INT_ARG(0);
    if(axis < 0) {
        axis += rank;
    }

    // ******** input validation ******** //
    REQUIRE_TRUE(allOfSameType, 0, "CONCAT op: all of input arrays must have same type !");
    REQUIRE_TRUE(0 <= axis && (axis < rank || (axis == 0 && rank == 0)), 0, "CONCAT op: input axis must be in range [0, %i], but got %i instead!", rank-1, axis);

    for(int i = 1; i < numOfNonEmptyArrs; ++i)
        REQUIRE_TRUE(nonEmptyArrs[i]->rankOf() == rank, 0, "CONCAT op: all input arrays must have the same rank !");

    for(int i = 1; i < numOfNonEmptyArrs; ++i) {
        for(int dim = 0; dim < rank; ++dim)
            if(dim != axis)
                REQUIRE_TRUE(nonEmptyArrs[i]->sizeAt(dim) == nonEmptyArrs[0]->sizeAt(dim), 0, "CONCAT op: all input arrays must have the same dimensions (except those on input axis) !");
    }
    // ******** end of input validation ******** //

    auto output = OUTPUT_VARIABLE(0);

    if(numOfNonEmptyArrs == 1)
        output->assign(nonEmptyArrs[0]);
    else
        helpers::concat(block.launchContext(), nonEmptyArrs, *output, axis);

    // scalar-to-vector temporaries are freed by scalarsAsVectors' destructor

    return Status::OK();
}
|
|
|
|
// alternative op names that resolve to the same concat implementation
DECLARE_SYN(ParallelConcat, concat);
DECLARE_SYN(concat_v2, concat);
DECLARE_SYN(concatv2, concat);
|
|
|
|
// concat accepts inputs of any data type; the op implementation itself checks
// that all non-empty inputs share a single type (allOfSameType)
DECLARE_TYPES(concat) {
    getOpDescriptor()
        ->setAllowedInputTypes(sd::DataType::ANY);
    // ->setSameMode(true);
}
|
|
|
|
//////////////////////////////////////////////////////////////////////////
|
|
// Shape function for concat: the output shape equals the first input's shape
// except along the concatenation axis, where extents of all inputs are summed.
// Rank-0 inputs are treated as length-1 vectors (length-0 when empty).
DECLARE_SHAPE_FN(concat) {

    REQUIRE_TRUE(block.width() > 0, 0, "CONCAT op: No input arrays were provided");

    const bool isAxisInLastArr = block.getBArguments()->size() == 0 ? false : B_ARG(0);

    // when the axis is supplied as the last input array, it is not part of the data to concatenate
    const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width();

    // first of all take into account possible presence of empty arrays
    // also if scalar is present -> use the shape of vector with length=1 instead
    ShapeList arrShapes;

    for(int i = 0; i < numOfInArrs; ++i) {

        if(inputShape->at(i)[0] == 0) {     // shapeInfo[0] is the rank -> rank-0 input (scalar or empty scalar)
            if (shape::isEmpty(inputShape->at(i)))
                arrShapes.push_back(ConstantShapeHelper::getInstance()->vectorShapeInfo(0, INPUT_VARIABLE(0)->dataType()));
            else
                arrShapes.push_back(ConstantShapeHelper::getInstance()->vectorShapeInfo(1, INPUT_VARIABLE(0)->dataType()));
        }
        else {
            arrShapes.push_back(inputShape->at(i));
        }
    }

    // NOTE: arrShapes holds one (possibly substituted) shape per data input,
    // so this counts all inputs, not just the non-empty ones
    const int numOfArrs = arrShapes.size();

    const int rank = shape::rank(arrShapes.at(0));

    int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e<int>(0) : INT_ARG(0);
    if(axis < 0) {
        axis += rank;
    }

    // ******** input validation ******** //
    REQUIRE_TRUE(0 <= axis && axis < rank, 0, "CONCAT op: input axis must be in range [0, %i], but got %i instead!", rank-1, axis);

    for(int i = 1; i < numOfArrs; ++i)
        REQUIRE_TRUE(shape::rank(arrShapes.at(i)) == rank, 0, "CONCAT op: all input arrays must have the same rank !");

    for(int i = 1; i < numOfArrs; ++i) {
        for(int dim = 0; dim < rank; ++dim)
            if(dim != axis)
                // dimensions start at shapeInfo[1], hence dim+1
                REQUIRE_TRUE(arrShapes.at(i)[dim+1] == arrShapes.at(0)[dim+1], 0, "CONCAT op: all input arrays must have the same dimensions (except those on input axis) !");
    }
    // ******** end of input validation ******** //

    Nd4jLong* outShapeInfo(nullptr);
    COPY_SHAPE(arrShapes.at(0), outShapeInfo);

    // case when we have only one input array
    if(numOfArrs == 1) {
        ShapeUtils::updateStridesAndType(outShapeInfo, arrShapes.at(0), shape::order(arrShapes.at(0)));
        return SHAPELIST(CONSTANT(outShapeInfo));
    }

    // sum up the extents along the concatenation axis
    for(int i = 1; i < numOfArrs; ++i)
        outShapeInfo[axis + 1] += arrShapes.at(i)[axis + 1];

    ShapeUtils::updateStridesAndType(outShapeInfo, arrShapes.at(0), shape::order(arrShapes.at(0)));

    auto result = ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(outShapeInfo));
    RELEASE(outShapeInfo, block.getWorkspace());
    return SHAPELIST(result);
}
|
|
|
|
|
|
// //////////////////////////////////////////////////////////////////////////
|
|
// CUSTOM_OP_IMPL(concat, -1, 1, false, 0, -2){
|
|
// // do something here{
|
|
// NDArray<T> *last = INPUT_VARIABLE((int) block.width() - 1);
|
|
|
|
// int _dimension = 0;
|
|
// if (block.numI() > 0)
|
|
// _dimension = INT_ARG(0);
|
|
// else {
|
|
// _dimension = (int) last->e(0);
|
|
// }
|
|
|
|
// // we want to ensure that all
|
|
// NDArray<T> *first = nullptr;
|
|
// auto output = OUTPUT_VARIABLE(0);
|
|
|
|
// int elements = 0;
|
|
|
|
// for (int e = 0; e < block.width(); e++) {
|
|
// auto arr = INPUT_VARIABLE(e);
|
|
// if (!arr->isEmpty())
|
|
// elements++;
|
|
|
|
// // we must find first non-empty element here
|
|
// if (!arr->isEmpty() && first == nullptr)
|
|
// first = arr;
|
|
// }
|
|
|
|
// REQUIRE_TRUE(first != nullptr, 0, "Concat: at least 1 non-empty input required!");
|
|
|
|
// // it's possible to get into situation when your input has only 1 input. That's just assign
|
|
// if (elements == 1) {
|
|
// output->assign(first);
|
|
// return Status::OK();
|
|
// }
|
|
|
|
// bool oldScalars = first->rankOf() == 2 && first->isScalar();
|
|
|
|
// auto buffers = new Nd4jPointer[elements];
|
|
// auto shapes = new Nd4jPointer[elements];
|
|
|
|
// buffers[0] = (Nd4jPointer) first->buffer();
|
|
// shapes[0] = (Nd4jPointer) first->shapeInfo();
|
|
|
|
// if (_dimension < 0)
|
|
// _dimension += first->rankOf();
|
|
|
|
// if (sd::Environment::getInstance()->isDebugAndVerbose()) {
|
|
// printf("Shape %i: ", 0);
|
|
// shape::printShapeInfoLinear((Nd4jLong *) shapes[0]);
|
|
// }
|
|
|
|
// int er = 0;
|
|
// for (int e = 0; e < block.width(); e++) {
|
|
// Variable<T> *var = block.variable(e);
|
|
// auto array = var->getNDArray();
|
|
|
|
// if (array->isEmpty())
|
|
// continue;
|
|
|
|
// buffers[er] = reinterpret_cast<Nd4jPointer>(array->buffer());
|
|
// shapes[er++] = reinterpret_cast<Nd4jPointer>(array->shapeInfo());
|
|
|
|
// oldScalars &= array->rankOf() == 2 && array->isScalar();
|
|
|
|
// if (sd::Environment::getInstance()->isDebugAndVerbose()) {
|
|
// printf("Shape %i: ", e);
|
|
// shape::printShapeInfoLinear(array->shapeInfo());
|
|
// }
|
|
// }
|
|
// if (sd::Environment::getInstance()->isDebugAndVerbose())
|
|
// fflush(stdout);
|
|
|
|
// if (oldScalars) {
|
|
// nd4j_debug("OLD_SCALARS!\n","");
|
|
// _dimension = 1;
|
|
// }
|
|
|
|
// sd::SpecialMethods<T>::concatCpuGeneric(_dimension, elements, buffers, shapes, output->buffer(), output->shapeInfo());
|
|
|
|
// STORE_RESULT(*output);
|
|
|
|
// if (sd::Environment::getInstance()->isDebugAndVerbose())
|
|
// output->printShapeInfo("Concat result shape");
|
|
|
|
// delete[] buffers;
|
|
// delete[] shapes;
|
|
|
|
// return ND4J_STATUS_OK;
|
|
// }
|
|
|
|
// DECLARE_SYN(ParallelConcat, concat);
|
|
// DECLARE_SYN(concat_v2, concat);
|
|
// DECLARE_SYN(concatv2, concat);
|
|
|
|
// DECLARE_SHAPE_FN(concat) {
|
|
// auto inp = inputShape->at(0);
|
|
// int _dimension = INT_ARG(0);
|
|
|
|
// NDArray<T>* first = nullptr;
|
|
// auto last = inputShape->at(inputShape->size() - 1);
|
|
|
|
// Nd4jLong elements = 0;
|
|
// Nd4jLong *newShape;
|
|
|
|
// for (int e = 0; e < inputShape->size(); e++) {
|
|
// auto s = INPUT_VARIABLE(e);
|
|
|
|
// if (!s->isEmpty()) {
|
|
// elements++;
|
|
|
|
// if (first == nullptr)
|
|
// first = s;
|
|
// }
|
|
// }
|
|
|
|
|
|
// { // special cases for 0D concat
|
|
// bool allScalars = true;
|
|
// bool hasScalars = false;
|
|
// for (int e = 0; e < block.width(); e++) {
|
|
// auto c = INPUT_VARIABLE(e);
|
|
|
|
// if (c->isEmpty())
|
|
// continue;
|
|
|
|
// allScalars &= c->rankOf() == 0;
|
|
// hasScalars |= c->rankOf() == 0;
|
|
// }
|
|
|
|
// // all scalars
|
|
// if (allScalars) {
|
|
// ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(1), Nd4jLong);
|
|
|
|
// shape::shapeBuffer(1, &elements, newShape);
|
|
// return SHAPELIST(newShape);
|
|
// }
|
|
|
|
// // any scalar
|
|
// if (hasScalars) {
|
|
// ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(1), Nd4jLong);
|
|
// Nd4jLong length = shape::length(inp);
|
|
// for (int i = 1; i < block.width(); i++) {
|
|
// auto c = INPUT_VARIABLE(i);
|
|
// if (c->isEmpty())
|
|
// continue;
|
|
|
|
// length += c->lengthOf();
|
|
// }
|
|
|
|
// shape::shapeBuffer(1, &length, newShape);
|
|
// return SHAPELIST(newShape);
|
|
// }
|
|
// }
|
|
|
|
|
|
// ALLOCATE(newShape, block.getWorkspace(), shape::shapeInfoLength(first->shapeInfo()), Nd4jLong);
|
|
|
|
// if (_dimension < 0)
|
|
// _dimension += first->rankOf();
|
|
|
|
// std::memcpy(newShape, first->shapeInfo(), shape::shapeInfoByteLength(first->shapeInfo()));
|
|
// for (int i = 0; i < inputShape->size(); i++) {
|
|
// auto s = INPUT_VARIABLE(i);
|
|
|
|
// // FIXME: s == first is bad, but fast. alternatively we can subtract first size out of result
|
|
// if (s->isEmpty() || s == first)
|
|
// continue;
|
|
|
|
// newShape[_dimension + 1] += shape::shapeOf(inputShape->at(i))[_dimension];
|
|
// }
|
|
|
|
// shape::updateStrides(newShape, first->ordering());
|
|
|
|
// return SHAPELIST(newShape);
|
|
// }
|
|
|
|
//////////////////////////////////////////////////////////////////////////
|
|
// Backprop for concat: slices the incoming gradient (the last data input) back
// into one gradient chunk per forward-pass input, along the concatenation axis.
CUSTOM_OP_IMPL(concat_bp, -1, -1, false, 0, 0) {

    const bool isAxisInLastArr = block.getBArguments()->size() == 0 ? false : B_ARG(0);

    // when the axis is supplied as the last input array, it is not a data input
    const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width();

    // gradient wrt the concat output is passed as the last data input
    auto epsilonNext = INPUT_VARIABLE(numOfInArrs - 1);

    auto first = INPUT_VARIABLE(0);

    const int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e<int>(0) : (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + INPUT_VARIABLE(0)->rankOf());

    // running offset of the current chunk along the concatenation axis
    int startPos = 0;

    for (int e = 0; e < numOfInArrs - 1; e++) {
        auto originalChunk = INPUT_VARIABLE(e);
        auto epsilonChunk = OUTPUT_VARIABLE(e);

        // interval pairs [start, end) per dimension for the subarray view
        std::vector<Nd4jLong> indices(2 * epsilonNext->rankOf());

        int width = originalChunk->sizeAt(axis);

        // 'd' instead of 'e': the original code shadowed the outer loop variable here
        for (int d = 0; d < epsilonNext->rankOf(); d++) {
            if (d == axis)
                indices[2*d + 1] = (indices[2*d] = startPos) + width;
            else
                indices[2*d + 1] = indices[2*d] = 0;   // 0:0 pair — presumably selects the whole dimension; matches nd4j interval convention
        }

        auto subarray = (*epsilonNext)(indices, true);
        epsilonChunk->assign(subarray);

        startPos += width;
    }

    return ND4J_STATUS_OK;
}
|
|
|
|
// concat_bp accepts inputs of any type but only emits floating-point gradients
DECLARE_TYPES(concat_bp) {
    getOpDescriptor()
        ->setAllowedInputTypes(sd::DataType::ANY)
        ->setAllowedOutputTypes({ALL_FLOATS});
}
|
|
|
|
// Shape function for concat_bp: each produced gradient has exactly the shape
// (dtype, order, dims) of the corresponding forward-pass data input.
DECLARE_SHAPE_FN(concat_bp) {

    const bool axisComesFromLastInput = !(block.getBArguments()->size() == 0) && B_ARG(0);

    // when the axis arrives as an extra input array, it is not a data input
    const int dataInputs = axisComesFromLastInput ? block.width() - 1 : block.width();

    auto outShapes = SHAPELIST();

    // the last data input is the incoming gradient; every input before it gets a gradient output
    for (int i = 0; i < dataInputs - 1; ++i) {
        auto srcShape = inputShape->at(i);
        const auto dtype = ArrayOptions::dataType(srcShape);
        const auto ord = shape::order(srcShape);
        ShapeDescriptor descriptor(dtype, ord, shape::shapeOf(srcShape), shape::rank(srcShape));
        outShapes->push_back(ConstantShapeHelper::getInstance()->createShapeInfo(descriptor));
    }

    return outShapes;
}
|
|
|
|
|
|
}
|
|
}
|