// @author Paul Dubs


#include <helpers/AttentionHelper.h>

#include "../AttentionHelper.h"
#include <ops/declarable/CustomOperations.h>

namespace sd {

    /**
     * Multi-head input projection.
     *
     * Applies the projection weights of every head to every time step of every example with a
     * single matrix multiplication. The input is flattened to one column per (example, time step).
     * The per-head weight matrices are stacked into one [nHeads*hS, nIn] matrix. The product is
     * then reshaped and permuted back into per-head form.
     *
     * @param input            input sequences, [batch, nIn, timeSteps]
     * @param projectionMatrix projection weights, [nHeads, hS, nIn]
     * @param context          launch context passed to the constructor of the result array
     * @return projected input, [batch, nHeads, hS, timeSteps], with the data type of input
     */
    sd::NDArray AttentionHelper::multiHeadProject(const sd::NDArray *input, const sd::NDArray *projectionMatrix, sd::LaunchContext * context) {
        const auto miniBatchSize = input->sizeAt(0);
        const auto nIn = input->sizeAt(1);
        const auto seqLength = input->sizeAt(2);
        const auto numHeads = projectionMatrix->sizeAt(0);
        const auto projectedSize = projectionMatrix->sizeAt(1);
        const auto projectionIn = projectionMatrix->sizeAt(2);   // must equal nIn for the matmul below

        // Move nIn to the front before flattening, so that column j of the flattened input
        // corresponds to (example j / seqLength, time step j % seqLength).
        auto inputPerm = input->permute({1, 0, 2});     //[batch, nIn, timeSteps] -> [nIn, batch, timeSteps]
        auto inputPrep = inputPerm.reshape('c', {nIn, miniBatchSize * seqLength});   //[nIn, batch*timeSteps]

        // Stack all heads' weights row-wise: row r belongs to head r / hS.
        auto projectionPrep = projectionMatrix->reshape('c', {numHeads * projectedSize, projectionIn});    //[nHeads, hS, nIn] -> [nHeads*hS, nIn]

        NDArray projected('c', {numHeads * projectedSize, miniBatchSize * seqLength}, input->dataType(), context);  //[nHeads*hS, batch*timeSteps]
        sd::ops::matmul mmul;
        // NOTE(review): the status returned by execute() is not checked here.
        mmul.execute({&projectionPrep, &inputPrep}, {&projected});

        // Undo both flattenings: rows split into (head, hS), columns into (batch, timeSteps).
        projected.reshapei({numHeads, projectedSize, miniBatchSize, seqLength});
        projected.permutei({2, 0, 1, 3});   //[minibatch, numHeads, projectedSize, seqLength]

        return projected;
    }

    void AttentionHelper::multiHeadProjectBp(const sd::NDArray *input, const sd::NDArray *projectionMatrix,
                                        const sd::NDArray *eps, sd::NDArray *dLdInput,
                                        sd::NDArray *dLdProjectionMatrix, sd::LaunchContext * context) {
        auto miniBatchSize = input->sizeAt(0);
        auto seqLength = input->sizeAt(2);
        auto numHeads = projectionMatrix->sizeAt(0);
        auto projectedSize = projectionMatrix->sizeAt(1);

        auto epsPerm = eps->permute({1, 2, 0, 3});
        auto epsReshaped = epsPerm.reshape('c', {numHeads * projectedSize, miniBatchSize * seqLength});

        auto inputPerm = input->permute({1, 0, 2});
        auto inputPrep = inputPerm.reshape('c', {input->sizeAt(1), (miniBatchSize * seqLength)});
        auto projectionPrep = projectionMatrix->reshape('c', {numHeads * projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)});

        sd::ops::matmul_bp mmulBp;
        NDArray dLdProjectionPrep(projectionPrep.shapeInfo(), false, context);
        NDArray dLdInputPrep(inputPrep.shapeInfo(), false, context);
        mmulBp.execute({&projectionPrep, &inputPrep, &epsReshaped}, std::vector<NDArray*>{&dLdProjectionPrep, &dLdInputPrep}, {}, {}, {});

        dLdProjectionPrep.reshapei({numHeads, projectionMatrix->sizeAt(1), projectionMatrix->sizeAt(2)});

        dLdInputPrep.reshapei({input->sizeAt(1), miniBatchSize, seqLength});
        dLdInputPrep.permutei({1, 0, 2});
