/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
//  @author raver119@gmail.com
//

#include <ops/declarable/helpers/dropout.h>
#include <legacy/NativeOps.h>
#include <vector>
#include <memory>
#include <exceptions/cuda_exception.h>

namespace sd {
namespace ops {
namespace helpers {

    template <typename T>
    static __global__ void dropoutSimpleKernel(void const* inputBuf, Nd4jLong const* inputShape, void* outputBuf, Nd4jLong const* outputShape, double probVal, Nd4jLong inLen, sd::graph::RandomGenerator* nodeRng) {
        auto tid = blockIdx.x * blockDim.x + threadIdx.x;
        auto step = blockDim.x * gridDim.x;
        T const* input = reinterpret_cast<T const*>(inputBuf);
        T* output = reinterpret_cast<T*>(outputBuf);

        // grid-stride loop: each element draws an independent uniform sample that decides whether it survives
        for (Nd4jLong e = tid; e < inLen; e += step) {
            T val = nodeRng->relativeT(e, T(0.f), T(1.f));

            // inverted dropout: survivors are scaled by 1/probVal so the expected activation is preserved;
            // dropped elements are written as zero explicitly, since the output buffer is not assumed to be initialized
            if (double(val) < probVal)
                output[shape::getIndexOffset(e, outputShape)] = T(input[shape::getIndexOffset(e, inputShape)] / probVal);
            else
                output[shape::getIndexOffset(e, outputShape)] = T(0.f);
        }
    }
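    // A quick sanity check on the scaling above (a sketch, treating probVal as the
    // keep probability p): each element survives with probability p and is divided
    // by p, so
    //
    //     E[output] = p * (x / p) + (1 - p) * 0 = x
    //
    // e.g. with p = 0.5 roughly half the elements are zeroed and the survivors are
    // doubled, which is why no extra rescaling is needed at inference time.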

    template <typename T>
    static void dropoutSimple(sd::LaunchContext* context, NDArray const* input, NDArray* output, double probValue, int seed) {
        sd::graph::RandomGenerator nodeRng(3019L, seed);
        auto inLen = input->lengthOf();
        sd::graph::RandomGenerator* dRandom;
        auto stream = context->getCudaStream();
        NDArray::prepareSpecialUse({output}, {input});

        auto err = cudaMalloc(&dRandom, sizeof(sd::graph::RandomGenerator));
        if (err) {
            throw cuda_exception::build("helpers::dropoutSimple: Cannot allocate device memory for random generator.", err);
        }
        err = cudaMemcpy(dRandom, &nodeRng, sizeof(sd::graph::RandomGenerator), cudaMemcpyHostToDevice);
        if (err) {
            throw cuda_exception::build("helpers::dropoutSimple: Cannot set up device memory for random generator.", err);
        }

        dropoutSimpleKernel<T><<<128, 256, 1024, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), probValue, inLen, dRandom);

        // wait for the kernel before releasing dRandom: freeing it while the kernel may still read it would be a race
        err = cudaStreamSynchronize(*stream);
        if (err) {
            throw cuda_exception::build("helpers::dropoutSimple: Kernel failed.", err);
        }
        err = cudaFree(dRandom);
        if (err) {
            throw cuda_exception::build("helpers::dropoutSimple: Cannot deallocate device memory for random generator.", err);
        }
        NDArray::registerSpecialUse({output}, {input});
    }
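    // Minimal usage sketch for the helper above (illustrative, not part of the
    // build; `ctx` is an assumed sd::LaunchContext* and the shape/values are
    // arbitrary):
    //
    //     NDArray in('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}, sd::DataType::FLOAT32, ctx);
    //     NDArray out(in);
    //     dropoutSimple<float>(ctx, &in, &out, 0.5, 119); // keep-prob 0.5, seed 119
    //
    // The same seed reproduces the same mask, which the backward helpers below
    // rely on.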

    template <typename T>
    int _dropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue) {

        if (reduceShape == nullptr){
            dropoutSimple<T>(context.launchContext(), input, output, probValue, seed);
        }
        else {
            REQUIRE_TRUE(reduceShape->lengthOf() <= input->rankOf(), 0, "dropout: noise shape length should not exceed input rank");

            std::vector<Nd4jLong> dims(reduceShape->lengthOf());
            reduceShape->syncToHost(); // sync to host so the element reads below see the current values
            bool fit = true;

            // every noise-shape dimension must evenly divide every input dimension
            for (size_t i = 0; i < dims.size() && fit; i++) {
                dims[i] = reduceShape->e<Nd4jLong>(i);
                for (int e = 0; e < input->rankOf() && fit; ++e)
                    if (input->sizeAt(e) % dims[i])
                        fit = false;
            }

            // reject noise shapes that cannot tile the input
            REQUIRE_TRUE(fit, 0, "dropout: noise shape should fit the input dimensions.");
            std::unique_ptr<NDArray> chunk(new NDArray('c', dims, output->dataType(), context.launchContext()));
            chunk->assign(1.f);

            dropoutSimple<T>(context.launchContext(), chunk.get(), chunk.get(), probValue, seed);
            // broadcast chunk to full matrix
            std::unique_ptr<NDArray> dropOutMultiplier(new NDArray(*input));
            dropOutMultiplier->assign(1.f);

            *dropOutMultiplier += *chunk;

            // FIXME: the multiply and assign could be fused into a single pass
            output->assign(*input * *dropOutMultiplier);
        }

        return Status::OK();
    }
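    // Worked example for the reduce-shape branch (a hedged illustration): with an
    // input of shape {4, 8} and reduceShape [2, 4], every noise dim divides every
    // input dim (4 % 2 == 0, 8 % 2 == 0, 4 % 4 == 0, 8 % 4 == 0), so `fit` holds.
    // Dropout then runs once on the small {2, 4} chunk, and the broadcasted `+=`
    // spreads that single mask across all matching tiles of the full input.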

    int dropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue) {
        auto xType = input->dataType();
        int res = Status::OK();
        NDArray::prepareSpecialUse({output}, {input});

        BUILD_SINGLE_SELECTOR(xType, res = _dropOutFunctor, (context, input, output, reduceShape, seed, probValue), FLOAT_TYPES);

        NDArray::registerSpecialUse({output}, {input});
        return res;
    }
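    // Dispatch note: BUILD_SINGLE_SELECTOR expands to a switch over xType that
    // forwards to the matching _dropOutFunctor<T> instantiation; limiting it to
    // FLOAT_TYPES keeps the kernels instantiated for floating-point inputs only.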

/////////////////////////////////// backpropagation ///////////////////////////////////////////////
    template <typename T>
    static __global__ void dropoutBPKernel(void* outputBuf, Nd4jLong const* outputShape, void* gradOutBuf, Nd4jLong const* gradOutShape, double probValue) {
        __shared__ T* output;
        __shared__ T* gradOut;
        __shared__ Nd4jLong len;

        if (threadIdx.x == 0) {
            len = shape::length(outputShape);
            output = reinterpret_cast<T*>(outputBuf);
            gradOut = reinterpret_cast<T*>(gradOutBuf);
        }
        __syncthreads();

        auto tid = blockIdx.x * blockDim.x + threadIdx.x;
        auto step = blockDim.x * gridDim.x;

        for (Nd4jLong e = tid; e < len; e += step) {
            const auto zOffset = shape::getIndexOffset(e, outputShape);

            // units that survived the forward pass (non-zero output) receive the scaled gradient;
            // dropped units keep their zero, which blocks the gradient
            if (output[zOffset] != T(0.))
                output[zOffset] = T(gradOut[shape::getIndexOffset(e, gradOutShape)] / probValue);

        }
    }
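    // Backward-pass math behind the kernel above (a sketch): the forward pass
    // computes y = x / p for kept units and y = 0 otherwise, so dL/dx = (dL/dy) / p
    // where the unit was kept and 0 where it was dropped. The kernel recovers the
    // kept/dropped mask from the zeros the forward re-run left in `output`.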
    template <typename T>
    static int dropOutFunctorBP_(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue) {
        // re-run the forward pass with the same seed to reconstruct which units were kept
        int res = dropOutFunctor(context, input, output, reduceShape, seed, probValue);
        auto stream = context.launchContext()->getCudaStream();

        NDArray::prepareSpecialUse({output}, {input, gradOut});

        if (ND4J_STATUS_OK == res)
            dropoutBPKernel<T><<<128, 256, 1024, *stream>>>(output->specialBuffer(), output->specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), probValue);

        NDArray::registerSpecialUse({output}, {input, gradOut});

        return res;
    }

    template <typename T>
    static __global__ void alphaDropoutSimpleKernel(void const* inputBuf, Nd4jLong const* inputShape, void* outputBuf, Nd4jLong const* outputShape, double probValue, double alpha, double alpha1, double beta, Nd4jLong inLen, sd::graph::RandomGenerator* nodeRng) {
        auto tid = blockIdx.x * blockDim.x + threadIdx.x;
        auto step = blockDim.x * gridDim.x;
        T const* input = reinterpret_cast<T const*>(inputBuf);
        T* output = reinterpret_cast<T*>(outputBuf);

        for (Nd4jLong e = tid; e < inLen; e += step) {
            T val = nodeRng->relativeT(e, T(0.f), T(1.f));
            T xVal = input[shape::getIndexOffset(e, inputShape)];
            output[shape::getIndexOffset(e, outputShape)] = (val >= T(probValue) ? T(alpha * beta + alpha1) : T(alpha * (double)xVal + alpha1));
        }
    }
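    // Alpha dropout in a nutshell (a sketch): instead of zeroing, dropped units are
    // set to the constant alpha * beta + alpha1, and kept units get the affine map
    // alpha * x + alpha1. With beta chosen as the negative saturation value of
    // SELU (alpha' ~ -1.7581, per Klambauer et al., 2017), this preserves the
    // self-normalizing zero-mean / unit-variance property. The concrete alpha,
    // alpha1 and beta values are supplied by the caller, not fixed here.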
    template <typename T>
    static void alphaDropoutSimple(sd::LaunchContext* context, NDArray const* input, NDArray* output, int seed, double probValue, double alpha, double alpha1, double beta) {
        sd::graph::RandomGenerator nodeRng(3019L, seed), *dRandom;
        auto stream = context->getCudaStream();
        NDArray::prepareSpecialUse({output}, {input});

        auto err = cudaMalloc(&dRandom, sizeof(sd::graph::RandomGenerator));
        if (err) {
            throw cuda_exception::build("helpers::alphaDropoutSimple: Cannot allocate device memory for random generator.", err);
        }
        err = cudaMemcpy(dRandom, &nodeRng, sizeof(sd::graph::RandomGenerator), cudaMemcpyHostToDevice);
        if (err) {
            throw cuda_exception::build("helpers::alphaDropoutSimple: Cannot set up device memory for random generator.", err);
        }

        alphaDropoutSimpleKernel<T><<<128, 256, 1024, *stream>>>(input->specialBuffer(), input->specialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), probValue, alpha, alpha1, beta, output->lengthOf(), dRandom);

        // wait for the kernel before releasing dRandom: freeing it while the kernel may still read it would be a race
        err = cudaStreamSynchronize(*stream);
        if (err) {
            throw cuda_exception::build("helpers::alphaDropoutSimple: Kernel failed.", err);
        }
        err = cudaFree(dRandom);
        if (err) {
            throw cuda_exception::build("helpers::alphaDropoutSimple: Cannot deallocate device memory for random generator.", err);
        }
        NDArray::registerSpecialUse({output}, {input});
    }
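    // Usage sketch (illustrative values only; `a` and `b` stand for the affine
    // correction terms the caller derives from the keep probability and alpha'):
    //
    //     alphaDropoutSimple<float>(ctx, &in, &out, 119, /*probValue=*/0.95,
    //                               /*alpha=*/a, /*alpha1=*/b, /*beta=*/-1.7581);
    //
    // Here `ctx`, `in`, `out`, `a` and `b` are assumed to exist in the caller.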

    template <typename T>
    static int alphaDropOutFunctor_(graph::Context& context, NDArray* input, NDArray* output,
                            NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) {

        if (reduceShape == nullptr){
            alphaDropoutSimple<T>(context.launchContext(), input, output, seed, probValue, alpha, alpha1, beta);
        }
        else {
            REQUIRE_TRUE(reduceShape->lengthOf() <= input->rankOf(), 0, "alpha_dropout: noise shape length should not exceed input rank");

            std::vector<Nd4jLong> dims(reduceShape->lengthOf());
            reduceShape->syncToHost(); // sync to host so the element reads below see the current values
            bool fit = true;

            // every noise-shape dimension must evenly divide every input dimension
            for (size_t i = 0; i < dims.size() && fit; i++) {
                dims[i] = reduceShape->e<Nd4jLong>(i);
                for (int e = 0; e < input->rankOf() && fit; ++e)
                    if (input->sizeAt(e) % dims[i])
                        fit = false;
            }

            // reject noise shapes that cannot tile the input
            REQUIRE_TRUE(fit, 0, "alpha_dropout: noise shape should fit the input dimensions.");
            std::unique_ptr<NDArray> chunk(new NDArray('c', dims, output->dataType(), context.launchContext()));
            chunk->assign(1.f);

            alphaDropoutSimple<T>(context.launchContext(), chunk.get(), chunk.get(), seed, probValue, alpha, alpha1, beta);

            // broadcast chunk to full matrix
            std::unique_ptr<NDArray> dropOutMultiplier(new NDArray(*input));
            dropOutMultiplier->assign(1.f);

            *dropOutMultiplier += *chunk;

            // FIXME: the multiply and assign could be fused into a single pass
            output->assign(*input * *dropOutMultiplier);
        }

        return Status::OK();
    }

    template <typename T>
    int alphaDropOutFunctorBP_(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output,
                              NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) {

        int res = alphaDropOutFunctor(context, input, output, reduceShape, seed, probValue, alpha, alpha1, beta);
        if (res == ND4J_STATUS_OK) {
            // FIXME: the scale and multiply could be fused into a single loop
            (*output) *= alpha;
            (*output) *= (*gradOut);
        }
        return res;
    }
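    // Gradient sketch for the alpha-dropout backward pass: on kept units the
    // forward map is y = alpha * x + alpha1 with dy/dx = alpha, so each kept unit
    // should contribute gradOut * alpha; dropped units contribute nothing. The
    // forward re-run above decides, for the given seed, which units count as kept.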

    int dropOutFunctorBP(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue) {
        BUILD_SINGLE_SELECTOR(context.dataType(), return dropOutFunctorBP_, (context, input, gradOut, output, reduceShape, seed, probValue), FLOAT_TYPES);
    }

    int alphaDropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) {
        BUILD_SINGLE_SELECTOR(context.dataType(), return alphaDropOutFunctor_, (context, input, output, reduceShape, seed, probValue, alpha, alpha1, beta), FLOAT_TYPES);
    }

    int alphaDropOutFunctorBP(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) {
        BUILD_SINGLE_SELECTOR(context.dataType(), return alphaDropOutFunctorBP_, (context, input, gradOut, output, reduceShape, seed, probValue, alpha, alpha1, beta), FLOAT_TYPES);
    }

}  // namespace helpers
}  // namespace ops
}  // namespace sd