/* ******************************************************************************
 *
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 *  See the NOTICE file distributed with this work for additional
 *  information regarding copyright ownership.
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

#ifndef LIBND4J_HELPER_PTRMAP_H
#define LIBND4J_HELPER_PTRMAP_H

// ptr_def is the qualifier prefix placed in front of the PointersHelper accessors below.
// When __CUDACC__ is defined (CUDA compilation) it expands to `__host__ __device__ inline`,
// otherwise it is plain `inline`.
#ifdef __CUDACC__
#define ptr_def __host__ __device__ inline
#else
#define ptr_def inline
#endif

namespace sd {

    /**
     * This class is a simple wrapper to represent batch arguments as a single surface of parameters.
     * Batch parameters are passed as one contiguous memory chunk, and this helper is used to extract
     * the arguments that belong to each aggregate.
     *
     * Surface map format, in address order. Every region is sized for batchLimit aggregates,
     * regardless of how many aggregates are actually stored:
     * [0] counters, batchLimit * argTypes ints; for aggregate i the slots are
     *         [i * argTypes + 0] number of array arguments
     *         [i * argTypes + 1] number of shape arguments
     *         [i * argTypes + 2] number of index arguments
     *         [i * argTypes + 3] number of real arguments
     *         [i * argTypes + 4] number of JVM IntArray arguments
     * [1] index arguments, batchLimit * maxIndexArguments ints
     * [2] JVM IntArrays stored by value, batchLimit * maxIntArrays * maxArraySize ints
     * [3] real arguments, batchLimit * maxRealArguments values of type T
     * [4] argument pointers, batchLimit * maxArguments pointers (T *)
     * [5] shape pointers, maxShapeArguments pointers (Nd4jLong *) per aggregate
     *
     * The helper neither owns nor copies the surface, and none of the accessors performs
     * bounds checking on the indices it is given.
     */
    template <typename T>
    class PointersHelper {
    private:
        // we enforce maximal batch size limit, to simplify.
        // Kept static so that instances stay copy-assignable and carry no per-object copy of the constant.
#ifdef __CUDACC__
        static constexpr int batchLimit = 8192;
#else
        static constexpr int batchLimit = 512;
#endif

        // we have 5 kinds of per-aggregate counters, stored in this slot order:
        // arguments, shapeArguments, indexArguments, realArguments, intArrayArguments
        static constexpr int argTypes = 5;

        // number of aggregates passed to the constructor; stored, but not read by any accessor of this class
        int aggregates;

        // start of the surface
        void *ptrGeneral;

        int maxIntArrays;
        int maxArraySize;

        // right now we hardcode maximas, but we'll probably change that later
        int maxIndexArguments;
        int maxRealArguments;

        // since that's pointers (which is 64-bit on 64bit systems), we limit number of maximum arguments to 1/2 of maxIndex arguments
        int maxArguments;
        int maxShapeArguments;

        // sizeof(T) and sizeof(void *); stored, but not read by any accessor of this class
        int sizeT;
        int sizePtr;

        /**
         * Reads one entry of the counters region located at the very beginning of the surface.
         *
         * @param aggregateIdx index of the aggregate
         * @param slot counter kind, 0 .. argTypes - 1
         * @return value stored at [aggregateIdx * argTypes + slot]
         */
        ptr_def int readCounter(int aggregateIdx, int slot) const {
            const int *counters = static_cast<const int *>(ptrGeneral);
            return counters[aggregateIdx * argTypes + slot];
        }

    public:

        /**
         * We accept single memory chunk and number of jobs stored.
         * The constructor only records its arguments; the surface itself is not touched.
         *
         * @param ptrToParams pointer to "surface"
         * @param numAggregates actual number of aggregates being passed in
         * @param maxArgs number of argument pointer slots reserved per aggregate
         * @param maxShapes number of shape pointer slots reserved per aggregate
         * @param maxIntArrays number of JVM IntArray slots reserved per aggregate
         * @param maxIntArraySize number of ints reserved for each JVM IntArray
         * @param maxIdx number of index argument slots reserved per aggregate
         * @param maxReals number of real argument slots reserved per aggregate
         */
#ifdef __CUDACC__
        __host__ __device__
#endif
        PointersHelper(void *ptrToParams, int numAggregates, int maxArgs, int maxShapes, int maxIntArrays, int maxIntArraySize, int maxIdx, int maxReals)
            : aggregates(numAggregates),
              ptrGeneral(ptrToParams),
              maxIntArrays(maxIntArrays),
              maxArraySize(maxIntArraySize),
              maxIndexArguments(maxIdx),
              maxRealArguments(maxReals),
              maxArguments(maxArgs),
              maxShapeArguments(maxShapes),
              // unfortunately we have to know sizeOf(T)
              sizeT(static_cast<int>(sizeof(T))),
              // ptrSize for hypothetical 32-bit compatibility
              sizePtr(static_cast<int>(sizeof(ptrToParams))) {
        }

        /**
         * This method returns set of pointers to array arguments for specified aggregate
         *
         * @param aggregateIdx index of the aggregate
         * @return pointer to the first of maxArguments pointer slots of that aggregate
         */
        ptr_def T **getArguments(int aggregateIdx) const {
            // the pointers region starts right after the real arguments of the last possible aggregate
            T **pointersRegion = reinterpret_cast<T **>(getRealArguments(batchLimit));

            // offsets are widened before multiplication so that large surfaces can't overflow int
            return pointersRegion + static_cast<long long>(aggregateIdx) * maxArguments;
        }

        /**
         * This method returns number of array arguments for specified aggregate
         *
         * @param aggregateIdx index of the aggregate
         * @return counter slot 0 of that aggregate
         */
        ptr_def int getNumArguments(int aggregateIdx) const {
            return readCounter(aggregateIdx, 0);
        }

        /**
         * This method returns set of pointers to shape arguments for specified aggregate
         *
         * @param aggregateIdx index of the aggregate
         * @return pointer to the first of maxShapeArguments shape pointer slots of that aggregate
         */
        ptr_def Nd4jLong **getShapeArguments(int aggregateIdx) const {
            // the shapes region starts right after the argument pointers of the last possible aggregate
            Nd4jLong **shapesRegion = reinterpret_cast<Nd4jLong **>(getArguments(batchLimit));

            return shapesRegion + static_cast<long long>(aggregateIdx) * maxShapeArguments;
        }

        /**
         * This method returns number of shape arguments for specified aggregate
         *
         * @param aggregateIdx index of the aggregate
         * @return counter slot 1 of that aggregate
         */
        ptr_def int getNumShapeArguments(int aggregateIdx) const {
            return readCounter(aggregateIdx, 1);
        }

        /**
         * This method returns pointer to array of int/index arguments for specified aggregate
         *
         * @param aggregateIdx index of the aggregate
         * @return pointer to the first of maxIndexArguments ints of that aggregate
         */
        ptr_def int *getIndexArguments(int aggregateIdx) const {
            // we skip first numeric num*arguments
            int *indexRegion = static_cast<int *>(ptrGeneral) + (batchLimit * argTypes);

            // and return index for requested aggregate
            return indexRegion + static_cast<long long>(aggregateIdx) * maxIndexArguments;
        }

        /**
         * This method returns number of int/index arguments for specified aggregate
         *
         * @param aggregateIdx index of the aggregate
         * @return counter slot 2 of that aggregate
         */
        ptr_def int getNumIndexArguments(int aggregateIdx) const {
            return readCounter(aggregateIdx, 2);
        }

        /**
         * This method returns pointer to array of jvm IntArray arguments
         *
         * @param aggregateIdx index of the aggregate
         * @param argumentIdx index of the IntArray within that aggregate
         * @return pointer to the first of maxArraySize ints of the requested IntArray
         */
        ptr_def int *getIntArrayArguments(int aggregateIdx, int argumentIdx) const {
            // the int arrays region starts right after the index arguments of the last possible aggregate
            int *intArraysRegion = getIndexArguments(batchLimit);

            const long long intsPerAggregate = static_cast<long long>(maxIntArrays) * maxArraySize;

            return intArraysRegion
                   + aggregateIdx * intsPerAggregate
                   + static_cast<long long>(argumentIdx) * maxArraySize;
        }

        /**
         * This method returns number of jvm IntArray arguments
         *
         * @param aggregateIdx index of the aggregate
         * @return counter slot 4 of that aggregate
         */
        ptr_def int getNumIntArrayArguments(int aggregateIdx) const {
            return readCounter(aggregateIdx, 4);
        }

        /**
         * This method returns real arguments for specific aggregate
         *
         * @param aggregateIdx index of the aggregate
         * @return pointer to the first of maxRealArguments values of that aggregate
         */
        ptr_def T *getRealArguments(int aggregateIdx) const {
            // we get pointer for last batchElement + 1, so that'll be pointer for 0 realArgument
            T *realsRegion = reinterpret_cast<T *>(getIntArrayArguments(batchLimit, 0));

            return realsRegion + static_cast<long long>(aggregateIdx) * maxRealArguments;
        }

        /**
         * This method returns number of real arguments for specified aggregate
         *
         * @param aggregateIdx index of the aggregate
         * @return counter slot 3 of that aggregate
         */
        ptr_def int getNumRealArguments(int aggregateIdx) const {
            return readCounter(aggregateIdx, 3);
        }
    };
}

#endif //LIBND4J_HELPER_PTRMAP_H