/* ******************************************************************************
 *
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 *  See the NOTICE file distributed with this work for additional
 *  information regarding copyright ownership.
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author Adam Gibson
//

#ifndef LIBND4J_TAD_H
#define LIBND4J_TAD_H


#include <helpers/shape.h>
#include <system/pointercast.h>


namespace shape {
    /**
     * TAD ("tensor along dimension", cf. tensorsAlongDimension) describes the
     * sub-arrays of an input array that span the axes listed in `dimension`:
     * each TAD has tadLength(shapeInfo, dimension, dimensionLength) elements and
     * there are numTads of them. After init() the object can lazily build the
     * shape info of one TAD (tadOnlyShapeInfo, see createTadOnlyShapeInfo()) and
     * the start offset of every TAD (tadOffsets, see createOffsets()).
     *
     * Buffers allocated by this object are released by the destructor with
     * delete[].
     * NOTE(review): copy construction/assignment are compiler-generated and
     * shallow, so copying an instance that owns buffers and destroying both
     * copies would delete[] the same memory twice - avoid copying TADs.
     */
    class TAD {
    public:
        // index of the TAD described by tadOnlyShapeInfo
        Nd4jLong tadIndex = 0;
        // working copy of the reduction dimensions (collapse() may replace it)
        int dimensionLength = 0;
        int* dimension = nullptr;
        // working shape info; starts out equal to originalShapeInfo
        Nd4jLong const* shapeInfo = nullptr;
        // shape info of a single TAD
        Nd4jLong* tadOnlyShapeInfo = nullptr;
        Nd4jLong numTads = 0;
        int tadRank = 0;
        Nd4jLong* tadShape = nullptr;
        Nd4jLong* tadStride = nullptr;
        // start offset of every TAD (numTads entries)
        Nd4jLong* tadOffsets = nullptr;
        Nd4jLong tadOffsetForBlock = 0;
        int rank = 0;
        int numOnes = 0;
        //pointers to original
        int originalDimensionLength = 0;
        int const* originalDimension = nullptr;
        Nd4jLong const* originalShapeInfo = nullptr;
        bool squeezed = false;
        bool newSqueezeDimensions = false;
        int numOnesInMiddle = 0;
        // when set, tadOffset(i) simply returns i
        bool wholeThing = false;
        //need to track whether we create a new dimension array or not, we could have just moved the pointer forward
        //due to leading ones
        bool createdNewDimension = false;

        // special case for CUDA, we're passing in __shared__ memory pointers to be used instead of new/malloc
        void *ptrManager = nullptr;
        int *ptrOutput = nullptr;

        INLINEDEF bool dimensionsDescending(int rank, int const* dimensions, int length);

#ifdef __CUDACC__
        __host__ __device__
#endif
        INLINEDEF TAD() {}


#ifdef __CUDACC__
        __host__ __device__
#endif
        INLINEDEF void setExternalBuffers(void *ptrManager);



#ifdef __CUDACC__
        __host__ __device__
#endif
        INLINEDEF void setOutputBuffer(int *ptrOutput);

#ifdef __CUDACC__
        __host__ __device__
#endif
        /**
         * This method is for GPU mostly, it allows to initialize TAD instance with precalculated tadOnlyShapeInfo
         */
        INLINEDEF void initWithExternalTAD(Nd4jLong *existingTAD, Nd4jLong *originalShape, int *dimension, int dimensionLength);



#ifdef __CUDACC__
        __host__ __device__
#endif
        INLINEDEF void init(Nd4jLong const* shapeInfo,int const* dimension,int dimensionLength);

#ifdef __CUDACC__
        __host__ __device__
#endif
        INLINEDEF void init(int index, Nd4jLong const* shapeInfo,int const* dimension,int dimensionLength);



        template <typename T>
#ifdef __CUDACC__
        __host__ __device__
#endif
        INLINEDEF void printTADsND(T *x);



#ifdef __CUDACC__
        __host__ __device__
#endif
        INLINEDEF void permuteShapeBufferInPlace(Nd4jLong const* shapeBuffer, int const* rearrange, Nd4jLong *out);

#ifdef __CUDACC__
        __host__ __device__
#endif
        INLINEDEF Nd4jLong* permuteShapeBuffer(Nd4jLong const* shapeBuffer, int *rearrange);




#ifdef __CUDACC__
        __host__ __device__
#endif
        INLINEDEF void createTadOnlyShapeInfo();


#ifdef __CUDACC__
        __host__ __device__
#endif
        INLINEDEF Nd4jLong lengthPerSlice(Nd4jLong const* shapeBuffer);


#ifdef __CUDACC__
        __host__ __device__
#endif
        INLINEDEF Nd4jLong* tad2Sub(Nd4jLong index);



#ifdef __CUDACC__
        __host__ __device__
#endif
        INLINEDEF ~TAD();


#ifdef __CUDACC__
        __host__ __device__
#endif
        INLINEDEF  int* permuteDims();


        /**
        * Compute the offset of the first element of TAD `index`.
        *
        * The general pattern for computing a tad offset is as follows:
        * Every $STRIDE that was removed (the first dimension)
        * do a jump by the major stride of the parent array
        * (stride[0] of the parent array)
        *
        * For example given a c ordered 2,2,3,2 with stride 12,6,2,1
        * A tad of dimension 1 will jump 12 every 6 tads.
        *
        * You then end up with offsets of:
        * 0
        * 1
        * 2
        * 3
        * 4
        * 5
        * 12
        * 13
        * 14
        * 15
        * 16
        * 17
        *
        * notice there are 12 tads here. This same incremental jump will happen
        * every time.
        * Note here that by default the
        * element wise stride is used for the hops.
        *
        * Sometimes a jump doesn't happen. If there are fewer tads
        * than the stride of the dimension you removed, the
        * element wise stride will always be used.
        *
        * For example in a dimension of 0,1, you end up with offsets of:
        * 0,1,2,3,4,5
        *
        * Given that the inner most stride of the dimensions that was removed (1)
        * had a stride of 6, we never need to do a major stride jump.
        *
        */
#ifdef __CUDACC__
        __host__ __device__
#endif
        INLINEDEF Nd4jLong tadOffset(Nd4jLong index);


#ifdef __CUDACC__
        __host__ __device__
#endif
        INLINEDEF Nd4jLong* tensorShape();

#ifdef __CUDACC__
        __host__ __device__
#endif
        INLINEDEF Nd4jLong* tad2Sub(Nd4jLong index, void *ptrManager);


#ifdef __CUDACC__
        __host__ __device__
#endif
        INLINEDEF void createOffsets();


#ifdef __CUDACC__
        __host__ __device__
#endif
        INLINEDEF Nd4jLong* shapeInfoOnlyShapeAndStride();



        /**
         * Length of a tad given
         * the shape information
         */
#ifdef __CUDACC__
        __host__ __device__
#endif
        INLINEDEF Nd4jLong tadLength(Nd4jLong const* shapeInfo, int const* dimension, int dimensionLength);

        /**
         * Computes the number
         * of tensors along
         * a given dimension
         */
#ifdef __CUDACC__
        __host__ __device__
#endif
        INLINEDEF Nd4jLong tensorsAlongDimension(Nd4jLong const* shapeInfo, int const* dimension, int dimensionLength);


#ifdef __CUDACC__
        __host__ __device__
    INLINEDEF void createOffsetForBlock(int blockIdx) {
        this->tadOffsetForBlock = this->tadOffset(blockIdx);
    }
#endif


        /**
         * Dimension collapse is an algorithm
         * for collapsing singular dimensions.
         * This algorithm will adjust the dimensions
         * with respect to the original.
         *
         * The algorithm has 3 components:
         * trailing ones
         * middle ones
         * beginning ones
         *
         * dimensions that are specified to reduce along
         * that are singular should be truncated
         *
         * dimensions that are specified that are singular
         * at the beginning should be removed with middle dimensions
         * decremented.
         *
         * For any time there is a no op, a collapse will
         * set the first dimension to be -1.
         */
#ifdef __CUDACC__
        __host__ __device__
#endif
        INLINEDEF void collapse();
    };










    ////
/*
#ifdef __CUDACC__
    __host__ __device__
#endif
    INLINEDEF TAD::TAD(int tadIndex,Nd4jLong *shapeInfo,int *dimension,int dimensionLength) {
        this->tadIndex = tadIndex;
        this->init(shapeInfo, dimension, dimensionLength);
    }


#ifdef __CUDACC__
    __host__ __device__
#endif
    INLINEDEF TAD::TAD(Nd4jLong *shapeInfo,int *dimension,int dimensionLength) {
        this->init(shapeInfo, dimension, dimensionLength);
    }
        */

    /**
     * Remembers `ptrManager` in the member of the same name.
     * Nothing is allocated or freed here.
     */
    INLINEDEF void TAD::setExternalBuffers(void *ptrManager) {
        TAD::ptrManager = ptrManager;
    }

    /**
     * Remembers `ptrOutput` in the member of the same name.
     * Nothing is allocated or freed here.
     */
    INLINEDEF void TAD::setOutputBuffer(int *ptrOutput) {
        TAD::ptrOutput = ptrOutput;
    }

    /**
     * Initializes this TAD from an already computed single-TAD shape info.
     *
     * The "original" and the working shape/dimension pointers all refer to the
     * caller's arrays. existingTAD is stored as tadOnlyShapeInfo without copying,
     * and tadShape/tadStride point into it.
     * numTads is derived as length(originalShape) / length(existingTAD).
     *
     * NOTE(review): ~TAD() delete[]s tadOnlyShapeInfo when it differs from
     * shapeInfo, so existingTAD is assumed to be new[]-allocated (or the object
     * is never destroyed) - confirm with callers.
     */
    INLINEDEF void TAD::initWithExternalTAD(Nd4jLong *existingTAD, Nd4jLong *originalShape, int *dimension, int dimensionLength) {
        // caller-provided shape and dimensions serve both as "original" and working copies
        this->originalShapeInfo = originalShape;
        this->shapeInfo = originalShape;
        this->originalDimension = dimension;
        this->dimension = dimension;
        this->originalDimensionLength = dimensionLength;
        this->dimensionLength = dimensionLength;
        this->rank = shape::rank(originalShape);

        // the precomputed TAD shape info is used as-is
        this->tadOnlyShapeInfo = existingTAD;
        this->tadShape = shape::shapeOf(existingTAD);
        this->tadStride = shape::stride(existingTAD);

        const Nd4jLong ews = shape::elementWiseStride(originalShape);
        const Nd4jLong totalLength = shape::length(originalShape);
        this->numTads = totalLength / shape::length(existingTAD);

        const bool reducesAllDims = this->dimensionLength == this->rank;
        const bool oneElementPerTad = this->numTads == totalLength;
        this->wholeThing = this->numTads == 1 || ((reducesAllDims || oneElementPerTad) && ews == 1);
    }

    /**
     * Records which TAD (tadIndex) this instance should describe, then performs
     * the regular init(shapeInfo, dimension, dimensionLength).
     */
    INLINEDEF void TAD::init(int tadIndex, Nd4jLong const* shapeInfo,int const* dimension,int dimensionLength) {
        TAD::tadIndex = tadIndex;
        init(shapeInfo, dimension, dimensionLength);
    }

    /**
     * Initializes this TAD for `shapeInfo` reduced along `dimension`.
     *
     * Stores the caller's shape/dimension pointers as both the original and the
     * working copies and computes rank and numTads. numTads is 1 when
     * dimensionLength is 0.
     *
     * It also decides wholeThing (when set, tadOffset(i) returns i):
     *  - no dimensions or a scalar input: always true;
     *  - a non-vector input: true if there is a single TAD, or if
     *    (all axes are reduced OR there is one element per TAD in a 'c' ordered
     *    array) and the element-wise stride is 1;
     *  - a vector input: true if the first requested axis has size 1.
     */
    INLINEDEF  void TAD::init(Nd4jLong const* shapeInfo, int const* dimension,int dimensionLength) {
        // keep references to the caller's buffers; collapse() may later replace the working copies
        this->originalShapeInfo = shapeInfo;
        this->originalDimension = dimension;
        this->originalDimensionLength = dimensionLength;

        this->shapeInfo = shapeInfo;
        this->dimensionLength = dimensionLength;
        this->dimension = const_cast<int*>(dimension);
        this->rank = shape::rank(shapeInfo);

        const bool noDimensions = dimensionLength == 0;
        if (noDimensions)
            this->numTads = 1;
        else
            this->numTads = this->tensorsAlongDimension(this->shapeInfo, this->dimension, this->dimensionLength);

        const Nd4jLong ews = shape::elementWiseStride(shapeInfo);

        if (noDimensions) {
            wholeThing = true;
            return;
        }

        if (!shape::isVector(shapeInfo)) {
            // a single TAD means the TAD shape equals the input shape
            const bool singleTad = this->numTads == 1;
            // reducing along every axis is whole-array too, but only for non-views (EWS == 1)
            const bool reducesAllDims = this->dimensionLength == this->rank;
            // one element per TAD is contiguous only for 'c' order; 'f' needs real offsets
            const bool oneElementPerTadInC = this->numTads == shape::length(shapeInfo) && shape::order(shapeInfo) == 'c';
            wholeThing = singleTad || ((reducesAllDims || oneElementPerTadInC) && ews == 1);
            return;
        }

        if (shape::isScalar(shapeInfo)) {
            wholeThing = true;
            return;
        }

        // vector input (dimensionLength is known to be non-zero here)
        if (dimension != nullptr && shape::shapeOf(shapeInfo)[dimension[0]] == 1)
            wholeThing = true;
    }

    /**
     * Debug helper that prints the contents of the TADs of `x` to stdout.
     *
     * If wholeThing is set, the first shape::length(tadOnlyShapeInfo) elements
     * are printed on one line. Otherwise each of the numTads TADs is walked,
     * starting at tadOffsets[i] and using tadOnlyShapeInfo's shape and strides,
     * and printed on its own line. tadOnlyShapeInfo must already be populated,
     * and so must tadOffsets in the second case; neither is checked here.
     *
     * Values are converted to double before being formatted with "%f", because
     * printf requires a double for %f and passing, e.g., an integral T is undefined.
     */
    template <typename T>
    INLINEDEF void TAD::printTADsND(T *x) {
        if(wholeThing) {
            const Nd4jLong tadLen = shape::length(tadOnlyShapeInfo);
            for(Nd4jLong i = 0; i < tadLen; i++) {
                printf(" %f ", static_cast<double>(x[i]));
            }
            printf("\n");
        }
        else {
            for (Nd4jLong i = 0; i < numTads; i++) {
                auto  offset = tadOffsets[i];
                Nd4jLong shapeIter[MAX_RANK];
                Nd4jLong coord[MAX_RANK];
                int dim;
                int rankIter = shape::rank(tadOnlyShapeInfo);
                Nd4jLong xStridesIter[MAX_RANK];
                T *xPointer = x + offset;
                if (PrepareOneRawArrayIter<T>(rankIter,
                                              shape::shapeOf(tadOnlyShapeInfo),
                                              xPointer,
                                              shape::stride(tadOnlyShapeInfo),
                                              &rankIter,
                                              shapeIter,
                                              &xPointer,
                                              xStridesIter) >= 0) {
                    ND4J_RAW_ITER_START(dim, shape::rank(tadOnlyShapeInfo), coord, shapeIter); {
                        /* Process the innermost dimension */
                        printf(" %f ", static_cast<double>(xPointer[0]));
                    }
                    ND4J_RAW_ITER_ONE_NEXT(dim,
                                           rankIter,
                                           coord,
                                           shapeIter,
                                           xPointer,
                                           xStridesIter);
                    printf("\n");

                }
                else {
                    printf("Unable to prepare array\n");
                }
            }
        }
    }


    /**
     * Copies `shapeBuffer` into `out` and permutes `out` with doPermuteShapeInfo
     * according to `rearrange`.
     *
     * The number of entries copied is derived from shapeBuffer's own rank, not
     * from this->rank, so buffers of a different rank are copied completely and
     * without over-reading. `out` may be the very same buffer as `shapeBuffer`;
     * the copy is then skipped, because memcpy between identical (overlapping)
     * regions is undefined behaviour.
     */
    INLINEDEF void TAD::permuteShapeBufferInPlace(Nd4jLong const* shapeBuffer, int const* rearrange, Nd4jLong* out) {
        if (out != shapeBuffer) {
            memcpy(out, shapeBuffer, sizeof(Nd4jLong) * shape::shapeInfoLength(shape::rank(shapeBuffer)));
        }
        doPermuteShapeInfo(out, rearrange);
    }

    /**
     * Returns a new copy of `shapeBuffer` permuted according to `rearrange`.
     *
     * The copy length is derived from shapeBuffer's own rank (not this->rank),
     * so shape buffers of any rank are copied completely.
     *
     * @return buffer allocated by shape::copyOf; the caller owns it
     */
    INLINEDEF Nd4jLong* TAD::permuteShapeBuffer(Nd4jLong const* shapeBuffer, int *rearrange) {
        int len = shape::shapeInfoLength(shape::rank(shapeBuffer));
        Nd4jLong *copy = shape::copyOf(len, shapeBuffer);
        doPermuteShapeInfo(copy, rearrange);
        return copy;
    }

    /**
     * Despite its name, checks whether `dimensions[0..length)` is exactly the
     * trailing run of axes {rank-length, ..., rank-1}, listed in ascending order
     * (i.e. the last `length` axes of a rank-`rank` array).
     * Returns true when length <= 0.
     */
    INLINEDEF bool TAD::dimensionsDescending(int rank, int const* dimensions, int length) {
        // compare from the last entry backwards against rank-1, rank-2, ...
        for (int back = 1; back <= length; ++back) {
            if (dimensions[length - back] != rank - back)
                return false;
        }
        return true;
    }

    /**
     * Builds tadOnlyShapeInfo (the shape info of a single TAD) via
     * shapeInfoOnlyShapeAndStride() and copies the original array's data type into it.
     *
     * For a 'c' ordered input whose strides are descending and which is reduced
     * along its trailing axes, the original's second-to-last shape-info entry is
     * copied over, since the layout stays contiguous. If the resulting
     * element-wise stride is >= 1, the original order is kept in the last entry.
     * Finally tadShape/tadStride are pointed into the new buffer.
     *
     * NOTE(review): a previous tadShape is delete[]'d. If tadShape already pointed
     * into an earlier tadOnlyShapeInfo (second call, or after initWithExternalTAD),
     * that delete[] frees an interior pointer. tadOffset() only calls this while
     * tadOnlyShapeInfo is nullptr.
     */
    INLINEDEF void TAD::createTadOnlyShapeInfo() {
        Nd4jLong *tadInfo = this->shapeInfoOnlyShapeAndStride();
        this->tadOnlyShapeInfo = tadInfo;
        sd::ArrayOptions::setDataType(tadInfo, sd::ArrayOptions::dataType(this->originalShapeInfo));

        const auto tadInfoLength = shape::shapeInfoLength(tadInfo);
        const auto originalOrder = shape::order(this->originalShapeInfo);

        // possible optimization: for C order, if outer dimensions are used, continuous layout is preserved
        const bool contiguousTail = originalOrder == 'c'
                && shape::strideDescendingCAscendingF(this->originalShapeInfo)
                && dimensionsDescending(shape::rank(this->originalShapeInfo), this->originalDimension, this->originalDimensionLength);
        if (contiguousTail) {
            const auto originalLength = shape::shapeInfoLength(this->originalShapeInfo);
            tadInfo[tadInfoLength - 2] = this->originalShapeInfo[originalLength - 2];
        }

        // keep the original order as long as a positive element-wise stride survived
        if (shape::elementWiseStride(tadInfo) >= 1)
            tadInfo[tadInfoLength - 1] = originalOrder;

        if (this->tadShape != nullptr)
            delete[] this->tadShape;

        this->tadShape = shape::shapeOf(tadInfo);
        this->tadStride = shape::stride(tadInfo);
    }

    /**
     * Returns the number of elements in one slice along axis 0 of `shapeBuffer`,
     * i.e. the product of all of its dimensions except the first.
     * For rank <= 1 the product is empty and 1 is returned.
     *
     * The product is read directly from the shape. This needs no temporary
     * heap buffer, so rank-0 inputs no longer lead to a negative-size allocation.
     */
    INLINEDEF Nd4jLong TAD::lengthPerSlice(Nd4jLong const* shapeBuffer) {
        const int bufferRank = shape::rank(shapeBuffer);
        if (bufferRank <= 1)
            return 1;

        Nd4jLong *bufferShape = shape::shapeOf(shapeBuffer);
        return shape::prodLong(bufferShape + 1, bufferRank - 1);
    }


    /**
     * Converts a TAD index into the rank(shapeInfo) coordinates of that TAD's
     * first element.
     *
     * The axes NOT listed in originalDimension form the "outer" shape. `index` is
     * decomposed over that shape with shape::index2coords, and the result is
     * scattered back into a full-rank coordinate array in which every TAD axis
     * stays 0.
     *
     * @param index index of the TAD
     * @return new[]-allocated array of rank(shapeInfo) coordinates; the caller must delete[] it
     */
    INLINEDEF Nd4jLong* TAD::tad2Sub(Nd4jLong index) {
        Nd4jLong *inputShape = shape::shapeOf(shapeInfo);
        int inputRank = shape::rank(shapeInfo);
        int leftOverIndexLen = inputRank - originalDimensionLength;

        Nd4jLong *ret = new Nd4jLong[inputRank];
        // shape spanned by the axes that are not part of the tad
        Nd4jLong *outerShape = new Nd4jLong[leftOverIndexLen];
        Nd4jLong *leftOverIndexes = new Nd4jLong[leftOverIndexLen];
        Nd4jLong *sub = new Nd4jLong[inputRank];

        // every coordinate starts as zero; ret holds exactly inputRank elements
        // (clearing shapeInfoByteLength(rank) bytes would write past its end)
        memset(ret, 0, sizeof(Nd4jLong) * inputRank);

        // cursor into leftOverIndexes / outerShape
        int leftOverIndex = 0;
        for(int i = 0; i < inputRank; i++) {
            // set difference: all axes minus the tad dimensions
            bool found = false;
            for(int j = 0; j < originalDimensionLength; j++) {
                if(i == originalDimension[j]) {
                    found = true;
                    break;
                }
            }

            if(!found) {
                leftOverIndexes[leftOverIndex] = i;
                outerShape[leftOverIndex] = inputShape[i];
                leftOverIndex++;
            }
        }

        // coordinates of `index` within the outer shape
        shape::index2coords(index, leftOverIndexLen, outerShape, sub);

        for(int i = 0; i < leftOverIndexLen; i++) {
            ret[leftOverIndexes[i]] = sub[i];
        }

        // these temporaries are always allocated with new[] above (ptrManager plays no part),
        // so they must always be released
        delete[] outerShape;
        delete[] leftOverIndexes;
        delete[] sub;

        return ret;
    }


    /**
     * Releases buffers owned by this TAD:
     *  - dimension, only when collapse() created a new array for it
     *    (collapse() may instead just have advanced the pointer);
     *  - shapeInfo, when it no longer equals originalShapeInfo;
     *  - tadOffsets;
     *  - tadOnlyShapeInfo, unless it aliases shapeInfo.
     *
     * NOTE(review): after initWithExternalTAD() tadOnlyShapeInfo is the caller's
     * existingTAD and is still delete[]'d here. A tadShape allocated by
     * tensorShape() is freed only by createTadOnlyShapeInfo(), not here.
     * Confirm both are intended.
     */
    INLINEDEF TAD::~TAD() {
        const bool ownsDimension = createdNewDimension && this->dimension != originalDimension;
        if (ownsDimension)
            delete[] this->dimension;

        if (this->shapeInfo != this->originalShapeInfo)
            delete[] this->shapeInfo;

        if (this->tadOffsets != nullptr)
            delete[] this->tadOffsets;

        const bool ownsTadInfo = this->tadOnlyShapeInfo != nullptr && this->tadOnlyShapeInfo != shapeInfo;
        if (ownsTadInfo)
            delete[] this->tadOnlyShapeInfo;
    }

    /**
     * Builds a permutation of the axes of shapeInfo. It lists, in ascending
     * order, every axis NOT contained in originalDimension, followed by the
     * originalDimension entries in reverse order.
     *
     * @return new[]-allocated int array of rank(shapeInfo) entries; the caller owns it
     */
    INLINEDEF int* TAD::permuteDims() {
        const int inputRank = shape::rank(shapeInfo);
        int *perm = new int[inputRank];
        int cursor = 0;

        // axes outside the tad keep their relative order
        for (int axis = 0; axis < inputRank; axis++) {
            bool isTadAxis = false;
            for (int j = 0; j < originalDimensionLength && !isTadAxis; j++)
                isTadAxis = originalDimension[j] == axis;

            if (!isTadAxis)
                perm[cursor++] = axis;
        }

        // tad axes go last, in reverse of the order they were given
        for (int j = originalDimensionLength; j > 0; j--)
            perm[cursor++] = originalDimension[j - 1];

        return perm;
    }


    /**
     * Returns the linear offset of the first element of TAD `index` within shapeInfo.
     *
     * Builds tadOnlyShapeInfo on first use. When wholeThing is set, `index` itself
     * is returned. Otherwise the TAD index is turned into full-rank coordinates
     * (tad2Sub) and mapped through shape::getOffset. For multi-dimensional TADs a
     * negative offset is reported as -1.
     */
    INLINEDEF Nd4jLong TAD::tadOffset(Nd4jLong index) {
        if(tadOnlyShapeInfo == nullptr) {
            this->createTadOnlyShapeInfo();
        }

        if(wholeThing)
            return index;

        Nd4jLong *coords = this->tad2Sub(index, ptrManager);
        Nd4jLong ret = shape::getOffset(shapeInfo, coords);

        // tad2Sub always allocates its result with new[], regardless of ptrManager,
        // so it must always be released here
        delete[] coords;

        if(dimensionLength > 1 && ret < 0)
            return -1;

        return ret;
    }


    /**
     * Returns tadShape, computing and caching it on first use via shape::keep over
     * the `dimension` axes of shapeInfo (presumably the sizes of those axes);
     * tadRank is set to dimensionLength at that point.
     */
    INLINEDEF Nd4jLong* TAD::tensorShape(){
        if (this->tadShape == nullptr) {
            Nd4jLong *fullShape = shape::shapeOf(shapeInfo);
            this->tadShape = shape::keep(fullShape, this->dimension, dimensionLength, shape::rank(shapeInfo));
            this->tadRank = dimensionLength;
        }
        return this->tadShape;
    }

    /**
     * Converts a TAD index into the rank(shapeInfo) coordinates of that TAD's
     * first element. The axes NOT listed in originalDimension form the "outer"
     * shape; `index` is decomposed over it and scattered into a full-rank
     * coordinate array whose TAD axes stay 0.
     *
     * @param index      index of the TAD
     * @param ptrManager accepted for interface compatibility; every buffer here is
     *                   allocated with new[], so it does not influence allocation
     * @return new[]-allocated array of rank(shapeInfo) coordinates; the caller must delete[] it
     */
    INLINEDEF Nd4jLong* TAD::tad2Sub(Nd4jLong index, void *ptrManager) {
        (void) ptrManager;

        auto inputShape = shape::shapeOf(shapeInfo);
        int inputRank = shape::rank(shapeInfo);
        int leftOverIndexLen = inputRank - originalDimensionLength;

        Nd4jLong *ret = new Nd4jLong[inputRank];
        Nd4jLong *leftOverIndexes = new Nd4jLong[leftOverIndexLen];
        Nd4jLong *sub = new Nd4jLong[inputRank];
        // shape spanned by the axes that are not part of the tad
        Nd4jLong *outerShape = new Nd4jLong[leftOverIndexLen];

        // every coordinate starts as zero
        memset(ret, 0, sizeof(Nd4jLong) * inputRank);

        // cursor into leftOverIndexes / outerShape
        int leftOverIndex = 0;
        for(int i = 0; i < inputRank; i++) {
            // set difference: all axes minus the tad dimensions
            bool found = false;
            for(int j = 0; j < originalDimensionLength; j++) {
                if(i == originalDimension[j])  {
                    found = true;
                    break;
                }
            }

            if(!found) {
                leftOverIndexes[leftOverIndex] = i;
                outerShape[leftOverIndex] = inputShape[i];
                leftOverIndex++;
            }
        }

        // coordinates of `index` within the outer shape
        shape::index2coords(index, leftOverIndexLen, outerShape, sub);

        for(int i = 0; i < leftOverIndexLen; i++) {
            ret[leftOverIndexes[i]] = sub[i];
        }

        // the temporaries are always new[]-allocated above, so always release them
        // (skipping this when ptrManager was non-null leaked them)
        delete[] leftOverIndexes;
        delete[] outerShape;
        delete[] sub;

        return ret;
    }

    /**
     * Fills tadOffsets with the start offset of every TAD (numTads entries),
     * computed through tadOffset(). An offsets array from a previous call is
     * released first, so repeated calls do not leak.
     */
    INLINEDEF void TAD::createOffsets() {
        delete[] this->tadOffsets;

        const Nd4jLong count = this->numTads;
        this->tadOffsets = new Nd4jLong[count];

        // Nd4jLong index: numTads is Nd4jLong and may exceed a 32-bit unsigned range
        for(Nd4jLong i = 0; i < count; i++)
            this->tadOffsets[i] = this->tadOffset(i);
    }


    /**
     * Builds a new shape-info buffer describing a single TAD (tadIndex) of
     * shapeInfo along `dimension`. Every path returns a freshly allocated buffer
     * (shape::copyOf, permuteShapeBuffer, createScalarShapeInfo or
     * sliceOfShapeBuffer).
     *
     * Special cases handled first:
     *  - a multi-dimensional reduction of a vector: a copy of shapeInfo;
     *  - the TAD coincides with the whole array: a copy, permuted if `dimension` is permuted;
     *  - a single dimension: a transposed copy, a plain copy, or a scalar shape info.
     * The general path moves the TAD axes to the end (reversed), slices down to
     * the requested TAD, and reverses the axes where needed. All scratch buffers
     * are released on every return path of the general case.
     */
    INLINEDEF Nd4jLong* TAD::shapeInfoOnlyShapeAndStride() {
        //ensure tad shapes get setup right for vectors
        if(dimensionLength > 1 && shape::isVector(shapeInfo))
            return shape::copyOf(shape::shapeInfoLength(shape::rank(shapeInfo)),shapeInfo);

        // case when tad coincides with whole array
        if( this->numTads == 1 && ((shape::rank(originalShapeInfo) == originalDimensionLength) || originalDimensionLength == 0)) {
            // we might have special case here: skipped dimensions might be just full of ones
            Nd4jLong *ret = shape::copyOf(shape::shapeInfoLength(shape::rank(shapeInfo)), shapeInfo);
            if (shape::isDimPermuted<int>(dimension, (Nd4jLong) dimensionLength))    // check whether we need permutation
                doPermuteShapeInfo(ret, dimension);

            return ret;
        }

        Nd4jLong *theShape = shape::shapeOf(shapeInfo);
        int rank = shape::rank(shapeInfo);

        if(dimensionLength == 1) {
            if(dimension[0] == 0 && shape::isVector(shapeInfo) && theShape[1] == 1) {
                int permuted[2] = {1,0};
                Nd4jLong *permutedRet2 = shape::permuteShapeBuffer(shapeInfo, permuted);
                return permutedRet2;
            } else if(dimension[0] == 1 && shape::isVector(shapeInfo) && theShape[0] == 1) {
                Nd4jLong *ret = shape::copyOf(shape::shapeInfoLength(shape::rank(shapeInfo)),shapeInfo);
                return ret;
            }
            else if(shape::shapeOf(shapeInfo)[dimension[0]] == 1) {
                Nd4jLong *scalarInfo = shape::createScalarShapeInfo();
                scalarInfo[shape::shapeInfoLength(shape::rank(scalarInfo)) - 3] = this->tadIndex;
                return scalarInfo;
            }
        }

        Nd4jLong *tensorShape = this->tensorShape();
        int *reverseDimensions = shape::reverseCopy(dimension, dimensionLength);
        int *rankRange = shape::range<int>(0, rank);
        int *remove  = shape::removeIndex<int>(rankRange, dimension, (Nd4jLong) rank, (Nd4jLong) dimensionLength);
        //concat is wrong here with the length
        int *newPermuteDims = shape::concat(remove,rank - dimensionLength,reverseDimensions,dimensionLength);

        Nd4jLong* permuted = shape::permuteShapeBuffer(shapeInfo,newPermuteDims);


        Nd4jLong sliceIndex = shape::sliceOffsetForTensor(shape::rank(permuted),
                                                     this->tadIndex,
                                                     shape::shapeOf(shapeInfo),
                                                     tensorShape,
                                                     dimensionLength,
                                                     dimension,
                                                     dimensionLength);



        Nd4jLong *ret2 = shape::sliceOfShapeBuffer(sliceIndex, permuted);
        Nd4jLong tensorLength = shape::prodLong(tensorShape,tadRank);

        Nd4jLong compLength = shape::isVector(ret2) ? shape::length(ret2) : shape::prodLong(tensorShape,tadRank);

        if(dimensionLength == tadRank && compLength == shape::length(ret2)) {
            if(dimensionLength == 1 && shape::isVector(ret2) && shape::shapeOf(ret2)[0] == 1) {
                //go to the bottom and return ret2 after proper freeing of pointers
                //basic idea; we *don't* permute row vectors
            }
            else if(dimensionLength > 1) {
                //permute *then* return ret2
                int *finalPermuteDims = new int[shape::rank(ret2)];
                int forward = 0;
                for(int i = shape::rank(ret2) - 1; i >= 0; i--) {
                    finalPermuteDims[forward++] = i;
                }
                shape::permuteShapeBufferInPlace(ret2,finalPermuteDims,ret2);
                delete[] finalPermuteDims;

            }

        }
        else {
            Nd4jLong length = tensorLength;
            Nd4jLong lengthPerSlice = this->lengthPerSlice(ret2);
            if(lengthPerSlice < 1) {
                // release the scratch buffers before the early return as well
                delete[] permuted;
                delete[] newPermuteDims;
                delete[] rankRange;
                delete[] remove;
                delete[] reverseDimensions;
                return ret2;
            }

            Nd4jLong offset = tadIndex * tensorLength /lengthPerSlice;
            if(sliceIndex == 0 && length == lengthPerSlice) {
                Nd4jLong *newRet2 = shape::sliceOfShapeBuffer(offset, ret2);
                delete[] ret2;
                ret2 = newRet2;
                int *finalPermuteDims = new int[shape::rank(ret2)];
                int forward = 0;
                for(int i = shape::rank(ret2) - 1; i >= 0; i--) {
                    finalPermuteDims[forward++] = i;
                }
                bool isRowVector2 = shape::isRowVector(ret2);
                if(isRowVector2 == false) {
                    shape::permuteShapeBufferInPlace(ret2, finalPermuteDims, ret2);
                }

                delete[] finalPermuteDims;

            }
            else if(length == lengthPerSlice) {
                offset -= shape::slices(ret2) * (offset / shape::slices(ret2));
                Nd4jLong *newRet2 = shape::sliceOfShapeBuffer(offset,ret2);
                delete[] ret2;
                ret2 = newRet2;
                if(dimensionLength == 1 && shape::isVector(ret2) && shape::shapeOf(ret2)[0] == 1) {
                    //go to the bottom and return ret2 after proper freeing of pointers
                    //basic idea; we *don't* permute row vectors
                }
                else {
                    int *finalPermuteDims = new int[shape::rank(ret2)];
                    int forward = 0;
                    for(int i = shape::rank(ret2) - 1; i >= 0; i--) {
                        finalPermuteDims[forward++] = i;
                    }
                    Nd4jLong *newRet = shape::permuteShapeBuffer(ret2, finalPermuteDims);
                    delete[] ret2;
                    delete[] finalPermuteDims;
                    ret2 = newRet;

                }

            }
            else {
                //execute final part, note that this is mainly so delete[] gets called
                //at the bottom of the method
                while(shape::length(ret2) > length) {
                    auto lengthPerSlice2 = this->lengthPerSlice(ret2);
                    sliceIndex =    sliceOffsetForTensor(sliceIndex,shape::length(ret2),lengthPerSlice2);
                    sliceIndex -= shape::slices(ret2) * (sliceIndex / shape::slices(ret2));
                    auto newRet2 = shape::sliceOfShapeBuffer(sliceIndex,ret2);
                    delete[] ret2;
                    ret2 = newRet2;
                }

                //don't permute on a row vector
                if(dimensionLength == 1 &&  shape::isVector(ret2) && shape::shapeOf(ret2)[0] == 1) {
                    //go to the bottom and return ret2 after proper freeing of pointers
                    //basic idea; we *don't* permute row vectors
                }
                else if(dimensionLength > 1){
                    //permute *then* return ret
                    int *finalPermuteDims = new int[shape::rank(ret2)];
                    int forward = 0;
                    for(int i = shape::rank(ret2) - 1; i >= 0; i--) {
                        finalPermuteDims[forward++] = i;
                    }
                    auto newPermute = shape::permuteShapeBuffer(ret2,finalPermuteDims);
                    delete[] ret2;
                    delete[] finalPermuteDims;
                    ret2 = newPermute;
                }

            }
        }


        delete[] permuted;
        delete[] newPermuteDims;
        delete[] rankRange;
        delete[] remove;
        delete[] reverseDimensions;
        return ret2;
    }


    /**
     * Number of elements in one TAD.
     *
     * With exactly one dimension, shape[dimension[0]] is returned directly.
     * Otherwise the result is the product of shape[d] over the entries d of
     * `dimension`. Entries outside [0, rank) are ignored, and duplicates
     * contribute once per occurrence. With no dimensions the result is 1.
     */
    INLINEDEF Nd4jLong TAD::tadLength(Nd4jLong const* shapeInfo, int const* dimension, int dimensionLength) {
        auto sizes = shape::shapeOf(shapeInfo);
        if (dimensionLength == 1)
            return sizes[dimension[0]];

        const int inputRank = shape::rank(shapeInfo);
        Nd4jLong product = 1;
        for (int j = 0; j < dimensionLength; j++) {
            const int axis = dimension[j];
            if (axis >= 0 && axis < inputRank)
                product *= sizes[axis];
        }
        return product;
    }


    /**
     * Number of TADs obtained by splitting `shapeInfo` along `dimension`.
     *
     * Computed as the total element count of `shapeInfo` divided by the element
     * count of one TAD (see tadLength).
     *
     * NOTE(review): integer division - if the TAD length is 0 (any selected extent
     * is 0) this divides by zero; confirm empty arrays never reach here.
     */
    INLINEDEF Nd4jLong TAD::tensorsAlongDimension(Nd4jLong const* shapeInfo, int const* dimension, int dimensionLength) {
        const Nd4jLong totalElements = shape::length(shapeInfo);
        const Nd4jLong elementsPerTad = tadLength(shapeInfo, dimension, dimensionLength);
        return totalElements / elementsPerTad;
    }


    /**
     * Removes singular (extent == 1) axes from this TAD's reduction dimensions.
     *
     * The result is left in `this->dimension` / `this->dimensionLength`.
     *
     * Steps:
     *  - `this->dimension` is replaced by a fresh heap copy of the first
     *    `dimensionLength` entries of `originalDimension`.
     *  - Negative axes are wrapped by adding the rank of `shapeInfo`.
     *  - Singular axes at the end of the list are dropped.
     *  - Singular axes at the start of the list are dropped.
     *  - If no singular axis remains, the function returns immediately.
     *  - Otherwise an iterative pass repeatedly:
     *    - drops singular axes found at the end of the list;
     *    - removes singular axes in the middle, decrementing every surviving
     *      positive axis by the number removed so far.
     *  - If that pass empties the list, `dimension[0]` is set to -1 as a no-op
     *    marker.
     *  - If the pass never changed anything and `numOnes > 0`, every remaining
     *    axis is decremented by `numOnes`.
     */
    INLINEDEF void TAD::collapse() {
        // Extents of the input shape. Named `extents` (not `shape`) so it does not
        // shadow the `shape` namespace.
        auto extents = shape::shapeOf(shapeInfo);
        const auto inputRank = shape::rank(this->shapeInfo);

        // Handle negative dimensions/backwards indexing in whatever buffer
        // `dimension` points at on entry. Kept so the original write-through side
        // effect is preserved. If that buffer aliases `originalDimension`
        // (presumably the post-init state - confirm), the copy below already
        // receives normalized values.
        for(int i = 0; i < dimensionLength; i++) {
            if(dimension[i] < 0)
                dimension[i] += inputRank;
        }

        // Work on a private heap copy of the original axes.
        // NOTE(review): any buffer `dimension` owned before this call is not freed
        // here - confirm ownership against the destructor.
        this->dimension = new int[dimensionLength];
        memcpy(this->dimension, this->originalDimension, sizeof(int) * dimensionLength);

        // Wrap negative axes in the copy as well. When the entry buffer did not
        // alias `originalDimension`, the loop above normalized a different buffer.
        // The copy would then still hold negative axes, which are used below as
        // indices into `extents`. For already-normalized values this is a no-op.
        for(int i = 0; i < dimensionLength; i++) {
            if(dimension[i] < 0)
                dimension[i] += inputRank;
        }

        // Drop singular axes at the end of the list. For example, with shape
        // 4,3,1,2 and dimension 0,2, reducing along 0,2 is equivalent to reducing
        // along 0. The rest of the algorithm handles cases such as shape 4,1,1,2
        // with dimension 0,1, where other axes (e.g. at the end) still matter.
        int trailingOneDimensions = 0;
        for(int i = dimensionLength - 1; i >= 0 && extents[dimension[i]] == 1; i--)
            trailingOneDimensions++;
        dimensionLength -= trailingOneDimensions;

        // Drop singular axes at the start of the list.
        int leadingOneDimensions = 0;
        while(leadingOneDimensions < dimensionLength && extents[dimension[leadingOneDimensions]] == 1)
            leadingOneDimensions++;

        // Shift the surviving axes to the front of the buffer rather than
        // advancing the `dimension` pointer. An advanced pointer would point into
        // the middle of the new[] allocation above and could never be passed to
        // delete[].
        for(int i = leadingOneDimensions; i < dimensionLength; i++)
            dimension[i - leadingOneDimensions] = dimension[i];
        dimensionLength -= leadingOneDimensions;

        // If no singular axis is left inside the list there is nothing more to do.
        bool preConverged = true;
        for(int i = 0; i < dimensionLength; i++) {
            if(extents[dimension[i]] == 1) {
                preConverged = false;
                break;
            }
        }
        if(preConverged)
            return;

        bool done = false;
        // NOTE(review): declared outside the loop, so this count accumulates
        // across passes. Later passes subtract the running total from both the
        // axes and dimensionLength, which may over-shrink the list. Confirm this
        // is intended.
        int onesDecrement = 0;
        bool changed = false;
        while(!done) {
            // Terminate early: only singular dimensions were specified for the reduce.
            if(dimensionLength < 1) {
                done = true;
                // Signal a no-op.
                dimension[0] = -1;
                break;
            }
            traceNew(3);

            // Captures the intermediary result of this pass.
            int intermediaryResult[MAX_RANK];
            for(int i = 0; i < dimensionLength; i++) {
                intermediaryResult[i] = dimension[i];
            }

            bool oneEncountered = false;
            bool nonOneEncountered = false;
            bool hitBeginning = false;
            // Assume middle singular axes will be collapsed in this pass.
            bool collapseMiddleDimensions = true;
            // Note: dimensionLength MAY end up being zero.
            for(int i = dimensionLength - 1; i >= 0; i--) {
                if(extents[dimension[i]] == 1) {
                    oneEncountered = true;
                    if(!nonOneEncountered) {
                        // Trailing singular axis: drop it. Then restart the scan
                        // from the new end of the list on the next pass.
                        dimensionLength--;
                        collapseMiddleDimensions = false;
                        // Only the tail was removed, so the intermediary result is
                        // a plain copy of the remaining axes.
                        memcpy(intermediaryResult, dimension, sizeof(int) * dimensionLength);
                        changed = true;
                        break;
                    }
                    else {
                        // Singular axis with a non-singular axis after it.
                        // Reaching index 0 means we hit the beginning, so this is
                        // the last pass.
                        if(i == 0) {
                            hitBeginning = true;
                        }
                        // Mark for removal. Surviving positive axes are shifted
                        // back by onesDecrement below.
                        intermediaryResult[i] = -1;
                        onesDecrement++;
                    }
                }
                else {
                    intermediaryResult[i] = dimension[i];
                    nonOneEncountered = true;
                }
            }

            if(collapseMiddleDimensions && oneEncountered) {
                // Remove the marked (-1) axes and shift the survivors down.
                // NOTE(review): every positive axis is decremented, including
                // axes that precede the removed singular one. Later passes then
                // index the original `extents` with these shifted axes. For
                // example, shape 4,3,1,2 with dims 1,2,3 becomes 0,2 here.
                // Confirm the intended semantics.
                int newIntermediary[MAX_RANK];
                int idx = 0;
                for(int i = 0; i < dimensionLength; i++) {
                    if(intermediaryResult[i] >= 0) {
                        // Axis 0 doesn't need to be decremented.
                        if(intermediaryResult[i] > 0)
                            newIntermediary[idx++] = intermediaryResult[i] - onesDecrement;
                        else
                            newIntermediary[idx++] = intermediaryResult[i];
                    }
                }

                // Shrink by the number of axes where ones appeared.
                dimensionLength -= onesDecrement;
                memcpy(dimension, newIntermediary, sizeof(int) * dimensionLength);
                changed = true;
            }
            else {
                // Either no middle singular axis was found or only the tail was
                // dropped: store the current result.
                memcpy(dimension, intermediaryResult, sizeof(int) * dimensionLength);
            }

            // Converge when no singular axes remain in the reduce list, or when
            // the beginning of the list was reached.
            done = (!oneEncountered && nonOneEncountered) || hitBeginning;
        }

        // Nothing changed, but the dimensions still need to be collapsed.
        if(!changed && this->numOnes > 0) {
            for(int i = 0; i < dimensionLength; i++) {
                dimension[i] -= numOnes;
            }
        }
    }
}

#endif //LIBND4J_TAD_H