/* ******************************************************************************
 *
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 *  See the NOTICE file distributed with this work for additional
 *  information regarding copyright ownership.
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author Yurii Shyrma (iuriish@yahoo.com)
//

#include <algorithm>
#include <helpers/ShapeUtils.h>
#include <climits>
#include <numeric>
#include <algorithm>
#include <set>
#include <flatbuffers/util.h>


namespace sd {

//////////////////////////////////////////////////////////////////////////
// evaluate shape for array resulting from tensorDot operation, also evaluate shapes and dimensions permutations for transposition of two input arrays
std::vector<Nd4jLong> ShapeUtils::evalShapeForTensorDot(const Nd4jLong* aShapeInfo, const Nd4jLong* bShapeInfo, std::vector<int> axesA, std::vector<int> axesB, std::vector<int>& permutAt, std::vector<int>& permutBt, std::vector<Nd4jLong>& shapeAt, std::vector<Nd4jLong>& shapeBt) {

    int axeAsize = (int) axesA.size();
    int axeBsize = (int) axesB.size();
    int aRank = aShapeInfo[0];
    int bRank = bShapeInfo[0];

    if(axeAsize != axeBsize)
        throw std::runtime_error("ShapeUtils::evalShapeForTensorDot method: the numbers of a axes and b axes to make dot product along must have identical values !");
    if(axeAsize > aRank || axeBsize > bRank)
        throw std::runtime_error("ShapeUtils::evalShapeForTensorDot method: the length of vector of a or b axes is larger than array rank !");

    // axes validation
    for (int i = 0; i < axeBsize; i++) {
        if (axesA[i] < 0)
            axesA[i] += aRank;
        if (axesB[i] < 0)
            axesB[i] += bRank;
        if (aShapeInfo[axesA[i] + 1] != bShapeInfo[axesB[i] + 1])
            throw std::runtime_error("ShapeUtils::evalShapeForTensorDot method: the dimensions at given axes for both input arrays must be the same !");
    }

    // check whether axesA and axesB contain only unique numbers
    std::set<Nd4jLong> uniqueElems(axesA.begin(), axesA.end());
    if((int)uniqueElems.size() != axeAsize)
        throw std::runtime_error("ShapeUtils::evalShapeForTensorDot method: the vector of a axes contains duplicates !");
    uniqueElems.clear();
    uniqueElems = std::set<Nd4jLong>(axesB.begin(), axesB.end());
    if((int)uniqueElems.size() != axeBsize)
        throw std::runtime_error("ShapeUtils::evalShapeForTensorDot method: the vector of b axes contains duplicates !");

    std::vector<int> list_A, list_B;
    for (int i = 0; i < aRank; i++)
        if (std::find(axesA.begin(), axesA.end(), i) == axesA.end())
            list_A.emplace_back(i);
    for (int i = 0; i < bRank; i++)
        if (std::find(axesB.begin(), axesB.end(), i) == axesB.end())
            list_B.emplace_back(i);

    permutAt = list_A;
    permutAt.insert(permutAt.end(), axesA.begin(), axesA.end());
    permutBt = axesB;
    permutBt.insert(permutBt.end(), list_B.begin(), list_B.end());

    // if permut contains something like {0,1,2,..rank-1}, then there is no need to make permutation and we return empty vector in this case
    uint i1, i2;
    for(i1 = 0; i1 < aRank; ++i1)
        if(permutAt[i1] != i1)
            break;
    if(i1 == aRank)
        permutAt = {};
    for(i2 = 0; i2 < bRank; ++i2)
        if(permutBt[i2] != i2)
            break;
    if(i2 == bRank)
        permutBt = {};

    Nd4jLong n2 = 1;
    for (int i = 0; i < axeAsize; i++)
        n2 *= aShapeInfo[axesA[i] + 1];
    shapeAt = {shape::length(aShapeInfo) / n2, n2};

    std::vector<Nd4jLong> oldShapeA;
    oldShapeA.resize(list_A.size());
    for (int i = 0; i < oldShapeA.size(); ++i)
        oldShapeA[i] = aShapeInfo[list_A[i] + 1];


    Nd4jLong n3 = 1;
    for (int i = 0; i < axeBsize; i++)
        n3 *= bShapeInfo[axesB[i] + 1];
    shapeBt = {n3, shape::length(bShapeInfo) / n3};

    std::vector<Nd4jLong> oldShapeB;
    oldShapeB.resize(list_B.size());
    for (int i = 0; i < oldShapeB.size(); i++)
        oldShapeB[i] = bShapeInfo[list_B[i] + 1];

    std::vector<Nd4jLong> aPlusB(oldShapeA);
    aPlusB.insert(aPlusB.end(), oldShapeB.begin(), oldShapeB.end());

    return aPlusB;
}

//////////////////////////////////////////////////////////////////////////
std::vector<Nd4jLong> ShapeUtils::evalShapeForTensorDot(const NDArray* a,   const NDArray* b,  const std::vector<int>& axesA, const std::vector<int>& axesB, std::vector<int>& permutAt, std::vector<int>& permutBt, std::vector<Nd4jLong>& shapeAt, std::vector<Nd4jLong>& shapeBt) {

    return evalShapeForTensorDot(a->shapeInfo(), b->shapeInfo(), axesA, axesB, permutAt, permutBt, shapeAt, shapeBt);
}


//////////////////////////////////////////////////////////////////////////
// evaluate output shape for reduce operation when input shape is empty
    const Nd4jLong* ShapeUtils::evalReduceShapeInfoEmpty(const char order, std::vector<int>& dimsToExclude, const Nd4jLong *shapeInfo, const sd::DataType dataType, const bool keepDims, sd::memory::Workspace* workspace) {

    if (dimsToExclude.size() == 0) {   // return copy of input shape
        Nd4jLong* outShapeInfo = ShapeBuilders::copyShapeInfoAndType(shapeInfo, dataType, true, workspace);
        ShapeDescriptor descriptor(outShapeInfo, dataType);
        RELEASE(outShapeInfo, workspace);
        return ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor).primary();
    }

    const int rank = shape::rank(shapeInfo);
    Nd4jLong* outShapeInfo = nullptr;

    if (dimsToExclude.size() == rank) {  // return scalar or shape filled with unities

        if(!keepDims)
            outShapeInfo = ShapeBuilders::createScalarShapeInfo(dataType, workspace);
        else
            outShapeInfo = ShapeBuilders::createShapeInfo(dataType, order, std::vector<Nd4jLong>(rank, 1), workspace);
    }
    else {

        shape::checkDimensions(rank, dimsToExclude);

        std::vector<Nd4jLong> outShape;

        if(keepDims) {
            outShape.assign(shapeInfo + 1, shapeInfo + 1 + rank);
            for(const auto& dim : dimsToExclude)
                outShape[dim] = 1;
        }
        else {
            for (uint i = 0, j = 0; i < rank; ++i) {
                if(j < dimsToExclude.size() && i == dimsToExclude[j])
                    ++j;
                else
                    outShape.emplace_back(shapeInfo[i + 1]);
            }
        }

        outShapeInfo = ShapeBuilders::createShapeInfo(dataType, order, outShape, workspace);
    }

    ShapeDescriptor descriptor(outShapeInfo, dataType);
    RELEASE(outShapeInfo, workspace);
    return ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor).primary();
}

    const Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const NDArray& arr, const bool keepDims, const bool supportOldShapes, sd::memory::Workspace* workspace) {
        return evalReduceShapeInfo(order, dimsToExclude, arr, arr.dataType(), keepDims, supportOldShapes, workspace);
    }

    const Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const Nd4jLong* shapeInfo, const bool keepDims, const bool supportOldShapes, sd::memory::Workspace* workspace) {
        return evalReduceShapeInfo(order, dimsToExclude, shapeInfo, ArrayOptions::dataType(shapeInfo), keepDims, supportOldShapes, workspace);
    }

//////////////////////////////////////////////////////////////////////////
    const Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const NDArray& arr, const sd::DataType dataType, const bool keepDims, const bool supportOldShapes, sd::memory::Workspace* workspace) {
        return evalReduceShapeInfo(order, dimsToExclude, arr.shapeInfo(), dataType, keepDims, supportOldShapes, workspace);
    }

//////////////////////////////////////////////////////////////////////////
// evaluate shape resulting from reduce operation
    const Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& dimsToExclude, const Nd4jLong *shapeInfo, const sd::DataType dataType, const bool keepDims, const bool supportOldShapes, sd::memory::Workspace* workspace) {

    if(ArrayOptions::arrayType(shapeInfo) == ArrayType::EMPTY)
        return ShapeUtils::evalReduceShapeInfoEmpty(order, dimsToExclude, shapeInfo, dataType, keepDims, workspace);

    Nd4jLong* newShapeInfo = nullptr;

    int rank = shape::rank(const_cast<Nd4jLong*>(shapeInfo));

    if (dimsToExclude.size() == 0) {                                               // return scalar or array with len=1 in this case

        if(keepDims && rank > 1) {
            ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(rank), Nd4jLong);
            newShapeInfo[0] = rank;
            for(int i = 0; i < rank; ++i)
                newShapeInfo[i+1] = 1;
            ShapeUtils::updateStridesAndType(newShapeInfo, shapeInfo, order);
            ArrayOptions::setDataType(newShapeInfo, dataType);

            ShapeDescriptor descriptor(newShapeInfo, dataType);
            RELEASE(newShapeInfo, workspace);
            return ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor).primary();
        }
        else if(supportOldShapes) {
            ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(2), Nd4jLong);
            shape::shapeOldScalar(dataType, newShapeInfo, 'c');
            ShapeDescriptor descriptor(newShapeInfo, dataType);
            RELEASE(newShapeInfo, workspace);
            return ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor).primary();
        }
        else {
            newShapeInfo = ShapeBuilders::createScalarShapeInfo(dataType, workspace);
            ShapeDescriptor descriptor(newShapeInfo, dataType);
            RELEASE(newShapeInfo, workspace);
            return ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor).primary();
        }
    }

    shape::checkDimensions(rank, dimsToExclude);

    int dimSize = dimsToExclude.size();

    if(keepDims) {

        ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(rank), Nd4jLong);
        newShapeInfo[0] = rank;
        for(int i = 0; i < rank; ++i)
            if (std::binary_search(dimsToExclude.begin(), dimsToExclude.end(), i))                       // dimsToExclude is already sorted after shape::checkDimensions() has been applied
                newShapeInfo[i+1] = 1;
            else
                newShapeInfo[i+1] = shapeInfo[i+1];

        ShapeUtils::updateStridesAndType(newShapeInfo, shapeInfo, order);
        ShapeDescriptor descriptor(newShapeInfo, dataType);
        RELEASE(newShapeInfo, workspace);
        return ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor).primary();
    }

	int newRank = rank - dimSize;
	if (newRank==0 || (dimSize==1 && dimsToExclude[0]==INT_MAX)) { 			// check whether given dimension is meant for the whole dimension

        if(supportOldShapes) {
            ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(2), Nd4jLong);
            shape::shapeOldScalar(ArrayOptions::dataType(shapeInfo), newShapeInfo, 'c');
            ShapeDescriptor descriptor(newShapeInfo, dataType);
            RELEASE(newShapeInfo, workspace);
            return ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor).primary();
        }
        else {
            newShapeInfo = ShapeBuilders::createScalarShapeInfo(ArrayOptions::dataType(shapeInfo), workspace);
            ShapeDescriptor descriptor(newShapeInfo, dataType);
            RELEASE(newShapeInfo, workspace);
            return ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor).primary();
        }
	}

    ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(newRank), Nd4jLong);
    newShapeInfo[0] = newRank;                      // set rank
    int j=1;
    for(int i = 0; i < rank; ++i)
        if (!std::binary_search(dimsToExclude.begin(), dimsToExclude.end(), i))                       // dimsToExclude is already sorted after shape::checkDimensions() has been applied
            newShapeInfo[j++] = shapeInfo[i+1];

	//ensure whether vector has proper shape for old shape type
	if (newRank == 1 && supportOldShapes) {
        int oldValue = newShapeInfo[1];
        RELEASE(newShapeInfo, workspace);
        ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(2), Nd4jLong);		// set newRank = 2
        newShapeInfo[0] = 2;
        if (dimsToExclude[0] == 0) {
            newShapeInfo[1] = 1;
            newShapeInfo[2] = oldValue;
        }
        else {
            newShapeInfo[1] = oldValue;
            newShapeInfo[2] = 1;
        }
    }

	ShapeUtils::updateStridesAndType(newShapeInfo, shapeInfo, order);

	ShapeDescriptor descriptor(newShapeInfo, dataType);
	RELEASE(newShapeInfo, workspace);
	return ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor).primary();
}

//////////////////////////////////////////////////////////////////////////
// evaluate shape for array which is result of repeat operation applied to arr
std::vector<Nd4jLong> ShapeUtils::evalRepeatShape(int axis, const std::vector<int>& repeats, const NDArray& arr) {

    if (axis < 0)
        axis += arr.rankOf();

    if(repeats.size() != 1 && repeats.size() != arr.sizeAt(axis))
        throw std::invalid_argument("ShapeUtils::evalRepeatShape: size of repeats vector must be 1 or equal to dimension at given axis !");

    std::vector<Nd4jLong> outShape = arr.getShapeAsVector();

    if(repeats.size() == 1)
        outShape[axis] *= repeats[0];
    else
        outShape[axis] = std::accumulate(repeats.begin(), repeats.end(), 0);

    return outShape;
}

//////////////////////////////////////////////////////////////////////////
// evaluate shapeInfo of permuted array
    const Nd4jLong* ShapeUtils::evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, sd::memory::Workspace* workspace, const bool setContigStrides) {

        if (!arr.nonNull())
            throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: array is nullptr!");

        if (rank != arr.rankOf())
            throw std::runtime_error("ShapeUtils::evalPermShapeInfo static method: wrong arguments: rank is not suitable!");

        auto shapeInfoLength = shape::shapeInfoLength(rank);

        // allocate memory for new array - shapeInfo
        Nd4jLong *shapeInfoNew = nullptr;
        ALLOCATE(shapeInfoNew, workspace, shapeInfoLength, Nd4jLong);

        // copy arr _shapeInfo into new array
        memcpy(shapeInfoNew, arr.shapeInfo(), shape::shapeInfoByteLength(rank));

        // perform buffer permutation
        shape::doPermuteShapeInfo(shapeInfoNew, dimensions, arr.lengthOf());

        if(setContigStrides)
            shape::updateStrides(shapeInfoNew, arr.ordering());

        ShapeDescriptor descriptor(shapeInfoNew);

        RELEASE(shapeInfoNew, workspace);

        return ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor).primary();
    }

    //////////////////////////////////////////////////////////////////////////
    // evaluate shapeInfo of permuted array
    const Nd4jLong* ShapeUtils::evalPermShapeInfo(const Nd4jLong *dimensions, const int rank, const NDArray& arr, sd::memory::Workspace* workspace) {

        std::vector<int> dims(dimensions, dimensions + rank);
        return evalPermShapeInfo(dims.data(), rank, arr, workspace);
    }

//////////////////////////////////////////////////////////////////////////
// evaluate shapeInfo of transposed array
    const Nd4jLong* ShapeUtils::evalTranspShapeInfo(const NDArray& arr, sd::memory::Workspace* workspace, const bool setContigStrides) {

        int rank = arr.rankOf();
        std::vector<int> dimensions(rank);
        for (int i = 0; i < rank; ++i)
            dimensions[i] = rank - 1 - i;

        return evalPermShapeInfo(dimensions.data(), dimensions.size(), arr, workspace, setContigStrides);
    }

//////////////////////////////////////////////////////////////////////////
    bool ShapeUtils::copyVectorPart(std::vector<int>& target, std::vector<int>& source, int rank, int offset) {
        if (source.size() < offset + rank)
            return false;

        for (int e = offset; e < offset + rank; e++)
            target.push_back(source[e]);

        return true;
    }


//////////////////////////////////////////////////////////////////////////
// return new (shorter) sorted dimensions array without dimensions that are present in input vector
std::vector<int> ShapeUtils::evalDimsToExclude(const int rank, const int dimsLen, const int* dimensions) {

    std::vector<int> newDimensions;
    if(dimsLen == 0) {                          // if input vector is empty then return whole shape range
        newDimensions.resize(rank);
        std::iota(newDimensions.begin(), newDimensions.end(), 0);   // fill with 0, 1, ... rank-1
    }
    else {
        bool isAbsent;
        for(int i=0; i<rank; ++i) {
            isAbsent = true;
            for(int j = 0; j < dimsLen; ++j) {
                int dim = dimensions[j] >= 0 ? dimensions[j] : dimensions[j] + rank;
                if(i == dim) {
                    isAbsent = false;
                    break;
                }
            }
            if(isAbsent)
                newDimensions.emplace_back(i);
        }
    }

    return newDimensions;
}

//////////////////////////////////////////////////////////////////////////
std::vector<int> ShapeUtils::evalDimsToExclude(const int rank, const std::vector<int>& dimensions) {

    return ShapeUtils::evalDimsToExclude(rank, dimensions.size(), dimensions.data());
}

//////////////////////////////////////////////////////////////////////////
// check whether 2 arrays have mutually broadcastable shapes
// shape comparison starts from the end
bool ShapeUtils::areShapesBroadcastable(const NDArray &arr1, const NDArray &arr2) {
    return areShapesBroadcastable(arr1.shapeInfo(), arr2.shapeInfo());
}

bool ShapeUtils::areShapesBroadcastable(const Nd4jLong *shapeInfo1, const Nd4jLong *shapeInfo2) {
    int minRank = shape::rank(shapeInfo1) < shape::rank(shapeInfo2) ? shape::rank(shapeInfo1) : shape::rank(shapeInfo2);

    for (int i = -1; i >= -minRank; --i)
        if (shape::sizeAt(shapeInfo1, i) != shape::sizeAt(shapeInfo2, i) && shape::sizeAt(shapeInfo1, i) != 1 && shape::sizeAt(shapeInfo2, i) != 1)
            return false;

    return true;
}

    bool ShapeUtils::areShapesBroadcastable(const std::vector<Nd4jLong>& shape1, const std::vector<Nd4jLong>& shape2) {

        const auto rank1 = shape1.size();
        const auto rank2 = shape2.size();
        const int minRank = rank1 < rank2 ? rank1 : rank2;

        for (int i = 1; i <= minRank; ++i)
            if (shape1[rank1-i] != shape2[rank2-i] && shape1[rank1-i] != 1 && shape2[rank2-i] != 1)
                return false;

        return true;
    }

    //////////////////////////////////////////////////////////////////////////
    // check the possibility of broadcast operation, if true then return shapeInfo of resulting array
    // if evalMinMax == false the array with larger rank has to be passed as first argument
    bool ShapeUtils::evalBroadcastShapeInfo(const NDArray &max, const NDArray &min, const bool evalMinMax, const Nd4jLong*& resultShapeInfo, sd::memory::Workspace* workspace) {
        return evalBroadcastShapeInfo(max.shapeInfo(), min.shapeInfo(), evalMinMax, resultShapeInfo, workspace);
    }

    bool ShapeUtils::evalBroadcastShapeInfo(const Nd4jLong *max, const Nd4jLong *min, const bool evalMinMax, const Nd4jLong*& resultShapeInfo, sd::memory::Workspace* workspace) {

        // check whether broadcast operation is possible for input arrays
        if(!areShapesBroadcastable(max, min))
            return false;

        auto maxShapeInfo = max; //max.shapeInfo();
        auto minShapeInfo = min; //min.shapeInfo();

        if(evalMinMax && (shape::rank(max) < shape::rank(min))) {
            maxShapeInfo = min;
            minShapeInfo = max;
        }

        const auto maxRank = shape::rank(maxShapeInfo);
        const auto minRank = shape::rank(minShapeInfo);

        // evaluate shapeInfo for resulting array
        if(resultShapeInfo != nullptr)
            throw std::runtime_error("std::runtime_error(ShapeUtils::evalBroadcastShapeInfo method: the input pointer on shapeInfo must be empty (=nullptr) !");

        Nd4jLong *tmpShapeInfo = nullptr;
        ALLOCATE(tmpShapeInfo, workspace, shape::shapeInfoLength(maxRank), Nd4jLong);

        // FIXME: get rid of memcpy here
        memcpy(tmpShapeInfo, maxShapeInfo, shape::shapeInfoByteLength(maxRank));
        for (int i = 0; i < minRank; ++i)
            if((maxShapeInfo[maxRank-i] != 0 && maxShapeInfo[maxRank-i] < minShapeInfo[minRank-i]) || minShapeInfo[minRank-i] == 0)
                tmpShapeInfo[maxRank - i] = minShapeInfo[minRank-i];

        ShapeUtils::updateStridesAndType(tmpShapeInfo, DataTypeUtils::pickPairwiseResultType(maxShapeInfo, minShapeInfo), shape::order(maxShapeInfo));

        if (shape::isEmpty(max) || shape::isEmpty(min)) {
            ArrayOptions::setPropertyBit(tmpShapeInfo, ARRAY_EMPTY);
            memset(shape::stride(tmpShapeInfo), 0, shape::rank(tmpShapeInfo) * sizeof(Nd4jLong));
        }

        ShapeDescriptor descriptor(tmpShapeInfo);
        RELEASE(tmpShapeInfo, workspace);
        resultShapeInfo = ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor).primary();

        return true;
    }

    //////////////////////////////////////////////////////////////////////////
    // check the possibility of broadcast operation for set of arrays, if true then return resulting broadcasted shapeInfo
    bool ShapeUtils::evalCommonBroadcastShapeInfo(const std::vector<const NDArray*>& arrays, Nd4jLong*& resultShapeInfo, memory::Workspace* workspace) {

        if(resultShapeInfo != nullptr)
            throw std::runtime_error("ShapeUtils::evalCommonBroadcastShapeInfo method: the input pointer on shapeInfo must be empty (=nullptr) !");

        int size = arrays.size();
        int maxRank = arrays[size - 1]->rankOf();

        for(int i = 0; i < size - 1; ++i) {
            if(arrays[i]->rankOf() > maxRank)
                maxRank = arrays[i]->rankOf();
            for(int j = i + 1; j < size; ++j)
                if(!areShapesBroadcastable(*arrays[i], *arrays[j]))
                    return false;
        }

        Nd4jLong *tmpShapeInfo = nullptr;
        ALLOCATE(tmpShapeInfo, workspace, shape::shapeInfoLength(maxRank), Nd4jLong);
        memset(tmpShapeInfo, 0, shape::shapeInfoByteLength(maxRank));
        tmpShapeInfo[0] = maxRank;

        for(const auto& item : arrays ) {
            for(int i = -1; i >= -item->rankOf(); --i)
                if(tmpShapeInfo[i + 1 + maxRank] < item->sizeAt(i))
                    tmpShapeInfo[i + 1 + maxRank] = item->sizeAt(i);
        }

        shape::updateStrides(tmpShapeInfo, arrays[0]->ordering());
        ArrayOptions::setDataType(tmpShapeInfo, arrays[0]->dataType());

        ShapeDescriptor descriptor(tmpShapeInfo);
        RELEASE(tmpShapeInfo, workspace);
        resultShapeInfo = const_cast<Nd4jLong*>(ConstantShapeHelper::getInstance().createShapeInfo(descriptor));

        return true;
    }


    //////////////////////////////////////////////////////////////////////////
    // return sorted vector of dimensions common (same) for two arrays, dimensions values corresponds to array with bigger rank
    // for example if arr1{2,7}, arr2{2,5,4,7} then vector = {0,3}
    std::vector<int> ShapeUtils::getDimsWithSameShape(const NDArray& arr1, const NDArray& arr2) {

        const NDArray *min, *max;

        if(arr1.rankOf() >= arr2.rankOf()) {
            max = &arr1;
            min = &arr2;
        }
        else {
            max = &arr2;
            min = &arr1;
        }

        const int rankDiff = max->rankOf() - min->rankOf();

        std::vector<int> dims;

        for (int i = 0; i < min->rankOf(); ++i)
            if (min->sizeAt(i) == max->sizeAt(rankDiff + i))
                dims.emplace_back(rankDiff + i);

        return dims;
    }

    //////////////////////////////////////////////////////////////////////////
    // evaluate shapeInfo for resulting array from tile operation
    const Nd4jLong* ShapeUtils::evalTileShapeInfo(const NDArray& arr, const std::vector<Nd4jLong>& reps, sd::memory::Workspace* workspace) {
        // check whether reps contains at least one zero (then throw exception) or whether all elements in reps are unities (then simply reshape or do nothing)
        int repsSize = reps.size();
        Nd4jLong product = 1;
        for(const auto& item : reps)
            product *= item;
        if(product == 0)
            throw std::runtime_error("NDArray::tile method: one of the elements in reps array is zero !");

        int rankOld = arr.rankOf();
        int diff = rankOld - repsSize;

        // evaluate new shapeInfo
        Nd4jLong* newShapeInfo = nullptr;
        if(diff < 0) {
            ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(repsSize), Nd4jLong);
            newShapeInfo[0] = repsSize;                  // set new rank
            for(int i=1; i <= -diff; ++i)
                newShapeInfo[i] = 1;                // set unities to be new dimensions at left-hand side of newShapeInfo shape place
            memcpy(newShapeInfo + 1 - diff, arr.shapeInfo() + 1, rankOld*sizeof(Nd4jLong));       // copy old dimensions to the right-hand side of newShapeInfo shape place
            for(int i=1; i <= repsSize; ++i)
                newShapeInfo[i] *= reps[i - 1];     // set new shape by multiplying old dimensions by corresponding numbers from reps
        }
        else {
            ALLOCATE(newShapeInfo, workspace, shape::shapeInfoLength(rankOld), Nd4jLong);
            memcpy(newShapeInfo, arr.shapeInfo(), shape::shapeInfoByteLength(rankOld));      // copy all elements of _shapeInfo to newShapeInfo
            for(int i=1; i <= repsSize; ++i)
                newShapeInfo[rankOld + 1 - i] *= reps[repsSize - i];     // set new shape by multiplying old dimensions by corresponding numbers from reps
        }
        shape::updateStrides(newShapeInfo, arr.ordering());
        ArrayOptions::setDataType(newShapeInfo, arr.dataType());

        ShapeDescriptor descriptor(newShapeInfo);
        RELEASE(newShapeInfo, workspace);
        return ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor).primary();
    }

    std::vector<Nd4jLong> ShapeUtils::pullShapeFromShapeInfo(const Nd4jLong *shapeInfo) {
        std::vector<Nd4jLong> shape(shape::rank(shapeInfo));
        int shapeSize = shape.size();

        for (int e = 0; e < shapeSize; e++)
            shape[e] = shape::shapeOf(shapeInfo)[e];

        return shape;
    }

    std::string ShapeUtils::shapeAsString(const NDArray* array) {
        std::string result;

        result.append("[");
        for (int e = 0; e < array->rankOf(); e++) {
            result += flatbuffers::NumToString(array->sizeAt(e));
            if (e < array->rankOf() - 1)
                result.append(", ");
        }
        result.append("]");

        return result;
    }

    std::string ShapeUtils::strideAsString(const NDArray* array) {
        std::string result;

        auto shapeBuffer = array->shapeInfo();   //Nd4jLong*
        int rank = (int)*shapeBuffer;
        result.append("[");
        for (int e = 0; e < rank; e++) {
            if (e > 0)
                result.append(",");
            Nd4jLong stride = *(shapeBuffer + rank+1+e);
            result += flatbuffers::NumToString(stride);
        }
        result.append("]");

        return result;
    }

    std::string ShapeUtils::shapeAsString(const std::vector<Nd4jLong>& shape) {
        std::string result;

        result.append("[");
        for (int e = 0; e < shape.size(); e++) {
            result += flatbuffers::NumToString(shape.at(e));
            if (e < shape.size() - 1)
                result.append(", ");
        }
        result.append("]");

        return result;
    }

    std::string ShapeUtils::shapeAsString(const Nd4jLong* shapeInfo) {

        if(!shapeInfo)
            throw std::runtime_error("ShapeUtils::shapeAsString method: input shapeInfo must not be nullptr !");

        std::string result;

        result.append("[");
        for (int e = 0; e < shapeInfo[0]; e++) {
            result += flatbuffers::NumToString(shapeInfo[e+1]);
            if (e < shapeInfo[0] - 1)
                result.append(", ");
        }
        result.append("]");

        return result;
    }

    std::string ShapeUtils::shapeInfoAsString(const Nd4jLong* shapeInfo) {

        if(!shapeInfo)
            throw std::runtime_error("ShapeUtils::shapeAsString method: input shapeInfo must not be nullptr !");

        std::string result;

        int len = shape::shapeInfoLength(shapeInfo[0]);

        result.append("[");
        for (int e = 0; e < len; e++) {
            result += flatbuffers::NumToString(shapeInfo[e]);
            if (e < len - 1)
                result.append(", ");
        }
        result.append("]");

        return result;
    }


    std::string ShapeUtils::shapeAsString(const int rank, const Nd4jLong* shapeInfo) {
        if(!shapeInfo)
            throw std::runtime_error("ShapeUtils::shapeAsString method: input shapeInfo must not be nullptr !");

        std::string result;

        result.append("[");
        for (int e = 0; e < rank; e++) {
            result += flatbuffers::NumToString(shapeInfo[e]);
            if (e < rank - 1)
                result.append(", ");
        }
        result.append("]");

        return result;
    }

//////////////////////////////////////////////////////////////////////////
std::vector<Nd4jLong> ShapeUtils::shapeAsVector(const Nd4jLong* shapeInfo) {

    if(!shapeInfo)
        throw std::runtime_error("ShapeUtils::shapeAsVector method: input shapeInfo must not be nullptr !");

    std::vector<Nd4jLong> vector(shapeInfo[0]);

    for (uint e = 0; e < shapeInfo[0]; e++)
        vector[e] = shapeInfo[e + 1];

    return vector;
}

//////////////////////////////////////////////////////////////////////////
// evaluate shapeInfo for diagonal array which is made using input arr elements as diagonal
    const Nd4jLong* ShapeUtils::evalDiagShapeInfo(const Nd4jLong* shapeInfoConst, sd::memory::Workspace* workspace){
        auto shapeInfo = const_cast<Nd4jLong*>(shapeInfoConst);

        const auto rank = shape::rank(shapeInfo);

        Nd4jLong* outputShapeInfo = nullptr;

        if(shape::isVector(shapeInfo) || shape::isScalar(shapeInfo)) {
            ALLOCATE(outputShapeInfo, workspace, shape::shapeInfoLength(2), Nd4jLong);
            outputShapeInfo[0] = 2;
            outputShapeInfo[1] = outputShapeInfo[2] = shape::length(shapeInfo);
        }
        else {
            ALLOCATE(outputShapeInfo, workspace, shape::shapeInfoLength(2*rank), Nd4jLong);
            outputShapeInfo[0] = 2*rank;
            for(int i = 1; i <= rank; ++i)
                outputShapeInfo[i] = outputShapeInfo[i + rank] = shapeInfo[i];
        }

        ShapeUtils::updateStridesAndType(outputShapeInfo, shapeInfo, shape::order(shapeInfo));

        auto result = ConstantShapeHelper::getInstance().createShapeInfo(outputShapeInfo);
        RELEASE(outputShapeInfo, workspace);
        return result;
    }

std::vector<int> ShapeUtils::evalBroadcastBackwardAxis(const Nd4jLong *operandShapeInfo, const Nd4jLong *resultShapeInfo) {
    // rRank >= oRank always  !!
    const auto oRank = shape::rank(operandShapeInfo);
    const auto rRank = shape::rank(resultShapeInfo);
    const auto diff  = rRank - oRank;
    std::vector<int> axis;

    for(int i = 0; i < rRank; ++i)
        if(i < diff || shape::sizeAt(operandShapeInfo, i - diff) != shape::sizeAt(resultShapeInfo, i))
            axis.push_back(i);

    return axis;
}

////////////////////////////////////////////////////////////////////////////////
    const Nd4jLong* ShapeUtils::matrixProductShape(const Nd4jLong* theFirstShape, const Nd4jLong* theSecondShape, bool shouldTranspondFirst, bool shouldTranspondSecond, sd::DataType  dtype, sd::memory::Workspace* workspace) {
        auto inA = theFirstShape;
        auto inB = theSecondShape;
        Nd4jLong *shape;
        ALLOCATE(shape, workspace, shape::shapeInfoLength(2), Nd4jLong);

        Nd4jLong* tmpA = ShapeBuilders::copyShapeInfo(inA, true, workspace);
        Nd4jLong* tmpB = ShapeBuilders::copyShapeInfo(inB, true, workspace);

        if (shouldTranspondFirst)
            shape::transposeInplace(tmpA);

        if (shouldTranspondSecond)
            shape::transposeInplace(tmpB);


        if (shape::rank(tmpA) == 1 && shape::isMatrix(tmpB)) {
            // special case here
            shape[0] = 1;
            shape[1] = tmpB[2];
            Nd4jLong *newShape = ShapeBuilders::createShapeInfo(dtype, 'f', 2, shape, workspace);

            RELEASE(shape, workspace);
            RELEASE(tmpA, workspace);
            RELEASE(tmpB, workspace);

            return newShape;
        } else if (shape::isScalar(tmpA) && shape::isScalar(tmpB)) {
            // just scalar vs scalar
            shape[0] = 1;
            shape[1] = 1;
        }  else if (shape::isMatrix(tmpA) && shape::isVector(tmpB)) {
            // gemv case
            if (shape::rank(tmpB) == 2) {
                shape[0] = tmpA[1];
                shape[1] = tmpB[2];
            } else {
                // we have new 1D shape here
                auto newShape = ShapeBuilders::createVectorShapeInfo(dtype, tmpA[1], workspace);

                RELEASE(shape, workspace);
                RELEASE(tmpA, workspace);
                RELEASE(tmpB, workspace);

                return newShape;
            }
        } else if ((shape::isMatrix(tmpA) && shape::isMatrix(tmpB)) ||
                   (shape::isVector(tmpA) && shape::isMatrix(tmpB)) ||
                   (shape::isColumnVector(tmpA) && shape::isVector(tmpB))) {
            // gemm case
            shape[0] = tmpA[1];
            shape[1] = tmpB[2];
        } else if ((shape::isVector(tmpA) && shape::isScalar(tmpB)) ||
            (shape::isScalar(tmpA) && shape::isVector(tmpB))) {
            // element-wise
            shape[0] = 1;
            shape[1] = (int) sd::math::nd4j_max<Nd4jLong>(shape::length(tmpA), shape::length(tmpB));
        } else if (shape::isRowVector(tmpA) && shape::isRowVector(tmpB)) {
            // dot case
            shape[0] = 1;
            shape[1] = 1;
        } else if (shape::isRowVector(tmpA) && shape::isColumnVector(tmpB)) {
            // dot case
            shape[0] = 1;
            shape[1] = 1;
        }

        auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(dtype, 'f', 2, shape);

        RELEASE(shape, workspace);

        RELEASE(tmpA, workspace);
        RELEASE(tmpB, workspace);
        return newShape;
    }

////////////////////////////////////////////////////////////////////////////////
std::vector<int> ShapeUtils::evalPermutFromTo(const std::vector<Nd4jLong>& shapeFrom, const std::vector<Nd4jLong>& shapeTo) {
    auto rank = shapeFrom.size();
    if(rank != shapeTo.size())
        throw std::runtime_error("ShapeUtils::evalPermutFromTo static method: the input shapes are not suitable for mutual permutation !");

    if (std::equal(begin(shapeFrom), end(shapeFrom), begin(shapeTo)))       // if shapes are identical (permutation is unnecessary) then return empty vector
        return std::vector<int>();

    std::vector<int> permutation(rank, -2);                                 // vector to be returned
    std::vector<Nd4jLong> shapeTo2(shapeTo);                                     // make copy of const vector since we will change the content of shapeTo

    for(int i=0; i<rank; ++i)
        for(int j=0; j<rank; ++j)
            if(shapeFrom[i] == shapeTo2[j]) {
                permutation[j] = i;
                shapeTo2[j] = -2;                                           // mark coincidence as -2 in order to not account index of shapeTo twice
                break;
            }

    if(std::find(begin(permutation), end(permutation), -2) != end(permutation))      // if -2 is still present in vector then permutation is impossible
        throw std::runtime_error("ShapeUtils::evalPermutFromTo static method: the input shapes are not suitable for mutual permutation !");

    return permutation;
}


////////////////////////////////////////////////////////////////////////////////
std::vector<Nd4jLong> ShapeUtils::composeShapeUsingDimsAndIdx(const std::vector<int>& dimsAndIdx) {
    auto size = dimsAndIdx.size();
    if(size % 2 != 0)
        throw std::runtime_error("ShapeUtils::composeShapeUsingDimsAndIdx static method: the size of input vector must be even !");

    size /= 2;

    std::vector<Nd4jLong> shape(size);
    int index;

    for(int i = 0; i < size; ++i) {
        index = dimsAndIdx[i + size];
        if(index > size-1)
            throw std::runtime_error("ShapeUtils::composeShapeUsingDimsAndIdx static method: input index is too large !");
        shape[index] = dimsAndIdx[i];
    }

    return shape;
}


////////////////////////////////////////////////////////////////////////////////
std::vector<Nd4jLong> ShapeUtils::evalShapeForMatmul(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const bool transX, const bool transY) {

    const auto xRank = xShapeInfo[0];
    const auto yRank = yShapeInfo[0];

    const Nd4jLong x0Dim = transX ? xShapeInfo[xRank]   : xShapeInfo[xRank-1];
    const Nd4jLong y0Dim = transY ? yShapeInfo[yRank]   : yShapeInfo[yRank-1];
    const Nd4jLong x1Dim = transX ? xShapeInfo[xRank-1] : xShapeInfo[xRank];
    const Nd4jLong y1Dim = transY ? yShapeInfo[yRank-1] : yShapeInfo[yRank];


    if(xRank == 1 && yRank == 1) {   // dot case, output is scalar
        if(xShapeInfo[1] != yShapeInfo[1]) {
            nd4j_printf("ShapeUtils::evalShapeForMatmul method: since input arrays are vectors they must have the same length, but got x length = %i, y length = %i !", xShapeInfo[1], yShapeInfo[1]);
            throw std::invalid_argument("");
        }
        return std::vector<Nd4jLong>({});
    }


    if(xRank == 1 && yRank == 2) {  // vector x matrix, i.e. [4] x [4,5] = [5], output is vector
        if(xShapeInfo[1] != y0Dim) {
            nd4j_printf("ShapeUtils::evalShapeForMatmul method: input arrays have inconsistent shapes for vector-matrix product: x %s, y %s !", ShapeUtils::shapeAsString(xShapeInfo).c_str(), ShapeUtils::shapeAsString(yShapeInfo).c_str());
            throw std::invalid_argument("");
        }
        return std::vector<Nd4jLong>({y1Dim});
    }


    if(xRank == 2 && yRank == 1) {  // matrix x vector , i.e. [4,5] x [5] = [4], output is vector
        if(x1Dim != yShapeInfo[1]) {
            nd4j_printf("ShapeUtils::evalShapeForMatmul method: input arrays have inconsistent shapes for vector-matrix product: x %s, y %s !", ShapeUtils::shapeAsString(xShapeInfo).c_str(), ShapeUtils::shapeAsString(yShapeInfo).c_str());
            throw std::invalid_argument("");
        }
        return std::vector<Nd4jLong>({x0Dim});
    }


    // rest cases - usual 2Dx2D or batched mmul
    if(xRank != yRank) {
        nd4j_printf("ShapeUtils::evalShapeForMatmul static method: the ranks of arrays must be the same, but got xRank = %i and yRank = %i ! \n", xRank, yRank);
        throw std::invalid_argument("");
    }

    if(x1Dim != y0Dim) {
        nd4j_printf("ShapeUtils::evalShapeForMatmul static method: input shapes are inconsistent: xDim %i != yDim %i \n", x1Dim, y0Dim);
        throw std::invalid_argument("");
    }

    for(int i = 0; i < xRank - 2; ++i)
        if(xShapeInfo[i+1] != yShapeInfo[i+1]) {
            nd4j_printf("ShapeUtils::evalShapeForMatmul static method: input shapes are inconsistent: xShape = %s, yShape = %s ! \n", ShapeUtils::shapeAsString(xShapeInfo).c_str(), ShapeUtils::shapeAsString(yShapeInfo).c_str());
            throw std::invalid_argument("");
        }

    std::vector<Nd4jLong> cShape(xRank);

    // copy batch part of shape (if present)
    for(int i = 0; i < xRank - 2; ++i)
        cShape[i] = xShapeInfo[i+1];
    // copy rest part of shape (two dims: multiplication part)
    cShape[xRank-2] = x0Dim;
    cShape[xRank-1] = y1Dim;

    return cShape;
}

////////////////////////////////////////////////////////////////////////////////
Nd4jLong ShapeUtils::getNumOfSubArrs(const Nd4jLong* shapeInfo, const std::vector<int>& dimsToExclude) {

    Nd4jLong numOfSubArrs = 1;

    if(dimsToExclude.size() == shape::rank(shapeInfo) || dimsToExclude.size() == 0)     // means there is only one sub-array and it coincides with whole array
        return numOfSubArrs;

    for(const auto& dim : dimsToExclude)
        numOfSubArrs *= shapeInfo[dim + 1];

    return numOfSubArrs;
}

////////////////////////////////////////////////////////////////////////////////
std::vector<Nd4jLong> ShapeUtils::evalDimsWithoutUnities(const Nd4jLong* shapeInfo) {

    std::vector<Nd4jLong> result;
    for(int i = 1; i <= shapeInfo[0]; ++i)
        if(shapeInfo[i] != 1)
            result.push_back(shapeInfo[i]);

    return result;
}

////////////////////////////////////////////////////////////////////////////////
void ShapeUtils::updateStridesAndType(Nd4jLong* dest, const Nd4jLong* source, const char order) {

    shape::updateStrides(dest, order);
    dest[2 * dest[0] +1] = 0; //zero extra
    ArrayOptions::copyDataType(dest, source);
}

////////////////////////////////////////////////////////////////////////////////
void ShapeUtils::updateStridesAndType(Nd4jLong* dest, const DataType dtype, const char order) {

    shape::updateStrides(dest, order);
    ArrayOptions::setDataType(dest, dtype);
}

////////////////////////////////////////////////////////////////////////////////
std::vector<int> ShapeUtils::tadAxesForSimpleBroadcast(const NDArray& max, const NDArray& min) {

    const int maxRank = max.rankOf();
    const int minRank = min.rankOf();
    const int diff    = maxRank - minRank;

    Nd4jLong  numOfMinTads(1), numOfMaxTads(1);
    std::vector<int> maxTadDims;

    for(int i = 0; i < minRank; ++i) {
        if(min.sizeAt(i) == max.sizeAt(diff + i))
            maxTadDims.push_back(diff + i);
        else {
            numOfMinTads *= min.sizeAt(i);
            numOfMaxTads *= max.sizeAt(i);
        }
    }

    if(min.lengthOf() > max.lengthOf()) {   // in this case tad is max array
        for(int i = 0; i < diff; ++i)
            numOfMaxTads *= max.sizeAt(i);

        return numOfMaxTads == 1 ? maxTadDims : std::vector<int>();
    }

    return numOfMinTads == 1 ? maxTadDims : std::vector<int>();
}

void ShapeUtils::copyCertainStridesFromShapeInfo(const Nd4jLong* inShapeInfo, const int nRank, const int dimsSize, const int* dims, Nd4jLong* outStrides) {

    int yRank = shape::rank(inShapeInfo);
    auto  yOrigStride = shape::stride(inShapeInfo);

    if (yRank == nRank) {
        for (int i = 0; i < yRank; ++i) {
            // x[2,3,4] * y[2,1,4] = z[2,3,4]
            outStrides[i] = (1 == shape::sizeAt(inShapeInfo, i)) ? 0 : yOrigStride[i];
        }
    }
    else {

        auto dimEx = sd::ShapeUtils::evalDimsToExclude(nRank, dimsSize, dims);

        for (int i = 0, it = 0; i < nRank; ++i) {
            auto nCount = std::count(dimEx.cbegin(), dimEx.cend(), i);
            outStrides[i] = (0 == nCount) ? yOrigStride[it++] : 0;
            if (it == yRank)
                break;
        }
    }
}

bool ShapeUtils::areShapesEqual(const Nd4jLong* shapeInfo, const std::vector<Nd4jLong>& shapeOnly) {

    if(shape::rank(shapeInfo) != shapeOnly.size())
        return false;

    for(uint i = 0; i < shape::rank(shapeInfo); ++i)
        if(shape::shapeOf(shapeInfo)[i] != shapeOnly[i])
            return false;

    return true;
}

////////////////////////////////////////////////////////////////////////////////
std::vector<int> ShapeUtils::evalDimsForReduceOp(const int rank, const std::vector<int>& dimsToExclude) {

    std::vector<int> output = ShapeUtils::evalDimsToExclude(rank, dimsToExclude);

    for(uint j = 0; j < dimsToExclude.size(); ++j)
        output.emplace_back(dimsToExclude[j]);

    return output;
}

////////////////////////////////////////////////////////////////////////////////
/*
bool ShapeUtils::isSubArrayCase(const NDArray& arr1, const NDArray& arr2, std::vector<int>& sameDims) {

    if(!sameDims.empty())
        sameDims.clear();

    const NDArray* max = &arr1;
    const NDArray* min = &arr2;

    if(arr1.lengthOf() < arr2.lengthOf()) {
        max = &arr2;
        min = &arr1;
    }

    int numUnitiesInMin = 0;

    for (int iMax = -1, iMin = -1; iMax >= -max->rankOf() && iMin >= -min->rankOf(); ) {

        if(max->sizeAt(iMax) == 1) {      // ignore unities in shape
            --iMax;
            continue;
        }

        if(min->sizeAt(iMin) == 1) {     // ignore unities in shape
            ++numUnitiesInMin;
            --iMin;
            continue;
        }

        if(max->sizeAt(iMax) == min->sizeAt(iMin)) {
            sameDims.insert(sameDims.begin(), iMax + max->rankOf());
            --iMin;
        }

        --iMax;
    }

    return sameDims.size() + numUnitiesInMin == min->rankOf();
}
*/

}