/* ******************************************************************************
 *
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 *  See the NOTICE file distributed with this work for additional
 *  information regarding copyright ownership.
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

#ifndef LIBND4J_WORKSPACETESTS_H
#define LIBND4J_WORKSPACETESTS_H

#include "testlayers.h"
#include <array/NDArray.h>
#include <memory/Workspace.h>
#include <memory/MemoryRegistrator.h>
#include <helpers/MmulHelper.h>

using namespace sd;
using namespace sd::memory;

/**
 * Fixture used by every TEST_F in this file. It declares no members and no
 * set-up/tear-down of its own; it only supplies the suite name.
 */
class WorkspaceTests : public testing::Test {};


// A workspace built with an explicit size is expected to report that size
// and to start with a zero offset.
TEST_F(WorkspaceTests, BasicInitialization1) {
    Workspace ws(1024);

    ASSERT_EQ(1024, ws.getCurrentSize());
    ASSERT_EQ(0, ws.getCurrentOffset());
}

// Creates a 5x5 float array through a LaunchContext bound to a pre-sized
// workspace, checks the array sums correctly, and expects the workspace
// offset to have moved past zero afterwards.
TEST_F(WorkspaceTests, BasicInitialization2) {
    Workspace ws(65536);
    ASSERT_EQ(0, ws.getCurrentOffset());

    LaunchContext context;
    context.setWorkspace(&ws);

    auto matrix = NDArrayFactory::create<float>('c', {5, 5}, &context);
    matrix.p(0, 1.0f);
    matrix.p(5, 1.0f);

    // Two elements were set to 1, so the total is expected to be 2.
    auto total = matrix.reduceNumber(reduce::Sum);
    auto sum = total.e<float>(0);
    ASSERT_NEAR(2.0f, sum, 1e-5);

    ASSERT_TRUE(ws.getCurrentOffset() > 0);
}


// Same scenario as BasicInitialization2, but with a default-constructed
// workspace: the array must still be usable, and the test expects the
// workspace offset to stay at zero afterwards.
TEST_F(WorkspaceTests, BasicInitialization3) {
    Workspace workspace;

    ASSERT_EQ(0, workspace.getCurrentOffset());
    LaunchContext ctx;
    ctx.setWorkspace(&workspace);

    auto array = NDArrayFactory::create<float>('c', {5, 5}, &ctx);

    array.p(0, 1.0f);
    array.p(5, 1.0f);

    // Two elements were set to 1, so the total is expected to be 2.
    auto v = array.reduceNumber(reduce::Sum);
    auto f = v.e<float>(0);

    // Assert on the value computed above instead of running the reduction a
    // second time and leaving `v`/`f` unused.
    ASSERT_NEAR(2.0f, f, 1e-5);

    // ASSERT_EQ reports the actual offset on failure, unlike ASSERT_TRUE(a == b).
    ASSERT_EQ(0, workspace.getCurrentOffset());
}


// Allocates an array, leaves the scope, then repeats scopeIn / allocate /
// scopeOut five times. Afterwards the workspace is expected to keep its
// original size, have a zero offset, and report nothing spilled.
TEST_F(WorkspaceTests, ResetTest1) {
    Workspace ws(65536);
    LaunchContext context;
    context.setWorkspace(&ws);

    auto first = NDArrayFactory::create<float>('c', {5, 5}, &context);
    first.p(0, 1.0f);
    first.p(5, 1.0f);
    ws.scopeOut();

    for (int round = 0; round < 5; ++round) {
        ws.scopeIn();

        auto scoped = NDArrayFactory::create<float>('c', {5, 5}, &context);
        scoped.p(0, 1.0f);
        scoped.p(5, 1.0f);

        // Each freshly created array must still sum to the two ones written above.
        ASSERT_NEAR(2.0f, scoped.reduceNumber(reduce::Sum).e<float>(0), 1e-5);

        ws.scopeOut();
    }

    ASSERT_EQ(65536, ws.getCurrentSize());
    ASSERT_EQ(0, ws.getCurrentOffset());
    ASSERT_EQ(0, ws.getSpilledSize());
}


// Over-allocates from a 128-byte workspace and checks the spill accounting,
// then expects the workspace to report a size of 1280 bytes with nothing
// spilled once the scope has been closed and reopened. Runs on CPU only.
TEST_F(WorkspaceTests, StretchTest1) {
    if (!Environment::getInstance().isCPU())
        return;

    Workspace ws(128);
    void* ptr = ws.allocateBytes(8);
    ws.scopeOut();

    // Closing the scope is expected to reset every counter.
    ASSERT_EQ(0, ws.getSpilledSize());
    ASSERT_EQ(0, ws.getSpilledSecondarySize());
    ASSERT_EQ(0, ws.getCurrentOffset());
    ASSERT_EQ(0, ws.getCurrentSecondaryOffset());

    ws.scopeIn();
    // Ten 128-byte requests; the test expects nine of them to count as spilled.
    for (int i = 0; i < 10; ++i) {
        ws.allocateBytes(128);
    }
    ASSERT_EQ(128 * 9, ws.getSpilledSize());

    ws.scopeOut();
    ws.scopeIn();
    ASSERT_EQ(0, ws.getCurrentOffset());

    // The pointer was originally expected to differ from `ptr` because of
    // reallocation; that check is currently disabled.
    void* ptr2 = ws.allocateBytes(8);
    //ASSERT_FALSE(ptr == ptr2);

    ASSERT_EQ(1280, ws.getCurrentSize());
    ASSERT_EQ(0, ws.getSpilledSize());
}

// Attaches a workspace to the MemoryRegistrator instance, creates an array
// without passing a context, and expects the attached workspace's offset to
// advance. The workspace is detached again before the test ends.
// Runs on CPU only.
TEST_F(WorkspaceTests, NewInWorkspaceTest1) {
    if (!Environment::getInstance().isCPU())
        return;

    Workspace ws(65536);

    ASSERT_EQ(65536, ws.getCurrentSize());
    ASSERT_EQ(0, ws.getCurrentOffset());

    ASSERT_FALSE(MemoryRegistrator::getInstance().hasWorkspaceAttached());

    MemoryRegistrator::getInstance().attachWorkspace(&ws);

    // From here until forgetWorkspace() only non-fatal EXPECT_* checks are
    // used. A fatal ASSERT_* returns from the test body immediately, which
    // would leak `ast` and leave the registrator holding the address of the
    // stack-allocated `ws` after it has been destroyed.
    EXPECT_TRUE(MemoryRegistrator::getInstance().hasWorkspaceAttached());

    auto ast = NDArrayFactory::create_<float>('c', {5, 5});

    EXPECT_TRUE(ws.getCurrentOffset() > 0);

    delete ast;

    MemoryRegistrator::getInstance().forgetWorkspace();

    ASSERT_FALSE(MemoryRegistrator::getInstance().hasWorkspaceAttached());
    ASSERT_TRUE(MemoryRegistrator::getInstance().getWorkspace() == nullptr);
}


// Creates a heap-allocated array through a LaunchContext bound to the
// workspace, while the same workspace is also attached to the
// MemoryRegistrator instance, and expects the workspace offset to advance.
TEST_F(WorkspaceTests, NewInWorkspaceTest2) {
    Workspace ws(65536);
    LaunchContext ctx;
    ctx.setWorkspace(&ws);

    ASSERT_EQ(65536, ws.getCurrentSize());
    ASSERT_EQ(0, ws.getCurrentOffset());

    MemoryRegistrator::getInstance().attachWorkspace(&ws);

    auto ast = NDArrayFactory::create_<float>('c', {5, 5}, &ctx);

    // Non-fatal on purpose: a fatal ASSERT_* would return before the cleanup
    // below, leaking `ast` and leaving the registrator pointing at the
    // stack-allocated `ws` after it has been destroyed.
    EXPECT_TRUE(ws.getCurrentOffset() > 0);

    delete ast;

    MemoryRegistrator::getInstance().forgetWorkspace();
}

// Spills twice the workspace capacity, then clones the workspace. The clone
// is expected to be sized to the spilled amount and to start with a zero
// offset and nothing spilled. Runs on CPU only.
TEST_F(WorkspaceTests, CloneTest1) {
    if (!Environment::getInstance().isCPU())
        return;

    Workspace ws(65536);

    ws.allocateBytes(65536 * 2);

    ASSERT_EQ(65536 * 2, ws.getSpilledSize());

    auto clone = ws.clone();

    // Non-fatal checks: a fatal ASSERT_* would return before `delete clone`
    // and leak the clone whenever one of these expectations fails.
    EXPECT_EQ(65536 * 2, clone->getCurrentSize());
    EXPECT_EQ(0, clone->getCurrentOffset());
    EXPECT_EQ(0, clone->getSpilledSize());

    delete clone;
}

// Smoke test: builds three 3x3 arrays through a workspace-bound context,
// runs a matrix multiply into the third and assigns the first into the
// second. There are no assertions; the test passes if nothing crashes.
TEST_F(WorkspaceTests, Test_Arrays_1) {
    Workspace ws(65536);
    LaunchContext context;
    context.setWorkspace(&ws);

    auto lhs = NDArrayFactory::create<float>('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}, &context);
    auto rhs = NDArrayFactory::create<float>('c', {3, 3}, {-1, -2, -3, -4, -5, -6, -7, -8, -9}, &context);
    auto product = NDArrayFactory::create<float>('c', {3, 3}, {0, 0, 0, 0, 0, 0, 0, 0, 0}, &context);

    MmulHelper::mmul(&lhs, &rhs, &product);

    rhs.assign(&lhs);

    // For debugging, printIndexedBuffer() can be called on any of the arrays here.
}

#ifdef GRAPH_FILES_OK
// Imports a graph from a FlatBuffers resource file, executes it and expects
// an OK status. Only compiled when GRAPH_FILES_OK is defined.
TEST_F(WorkspaceTests, Test_Graph_1) {
    auto importedGraph = GraphExecutioner::importFromFlatBuffers("./resources/ae_00.fb");

    // The returned workspace is not used below.
    // NOTE(review): the call is kept in case the accessor has side effects — confirm.
    auto graphWorkspace = importedGraph->getVariableSpace()->workspace();

    auto executionStatus = GraphExecutioner::execute(importedGraph);
    ASSERT_EQ(Status::OK(), executionStatus);

    delete importedGraph;
}
#endif

// Wraps a caller-provided stack buffer in an ExternalWorkspace, builds a
// Workspace on top of it and checks the size accounting before and after
// creating a 10x10 float array. Runs on CPU only.
TEST_F(WorkspaceTests, Test_Externalized_1) {
    if (!Environment::getInstance().isCPU())
        return;

    constexpr int kBufferBytes = 10000;
    char hostBuffer[kBufferBytes];

    // The first pointer/size pair is given the stack buffer; the second pair
    // is left empty.
    ExternalWorkspace external((Nd4jPointer) hostBuffer, kBufferBytes, nullptr, 0);
    ASSERT_EQ(kBufferBytes, external.sizeHost());
    ASSERT_EQ(0, external.sizeDevice());

    Workspace ws(&external);
    ASSERT_EQ(kBufferBytes, ws.getCurrentSize());
    ASSERT_EQ(kBufferBytes, ws.getAllocatedSize());

    LaunchContext context;
    context.setWorkspace(&ws);

    auto grid = NDArrayFactory::create<float>('c', {10, 10}, &context);

    // Only the data buffer is expected to be accounted for: 10 * 10 floats = 400 bytes.
    ASSERT_EQ(400, ws.getUsedSize());
    ASSERT_EQ(400, ws.getCurrentOffset());

    grid.assign(2.0);

    float mean = grid.meanNumber().e<float>(0);
    ASSERT_NEAR(2.0f, mean, 1e-5);
}

// TODO: uncomment this test once long shapes are introduced
/*
TEST_F(WorkspaceTests, Test_Big_Allocation_1) {
    Workspace ws(65536);
    NDArray<float> x('c', {256, 64, 384, 384}, &ws);
}
*/


#endif //LIBND4J_WORKSPACETESTS_H