/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// Created by raver119 on 01.11.2017.
//

#include "testlayers.h"
#include <helpers/ShapeUtils.h>
#include <array/NDArray.h>


using namespace sd;
using namespace sd::graph;

// Empty fixture: it only provides the common suite name used by the TEST_F cases in this file.
class ShapeUtilsTests : public testing::Test {
};

//////////////////////////////////////////////////////////////////
// evalDimsToExclude(3, {0}) must yield exactly {1, 2}.
TEST_F(ShapeUtilsTests, evalDimsToExclude_1) {
    const std::vector<int> expected = {1, 2};

    const std::vector<int> actual = ShapeUtils::evalDimsToExclude(3, {0});

    // Vector equality covers both the element count and every element value.
    ASSERT_EQ(expected, actual);
}

//////////////////////////////////////////////////////////////////
// evalDimsToExclude(4, {2, 3}) must yield exactly {0, 1}.
TEST_F(ShapeUtilsTests, evalDimsToExclude_2) {
    const std::vector<int> expected = {0, 1};

    const std::vector<int> actual = ShapeUtils::evalDimsToExclude(4, {2, 3});

    // Vector equality covers both the element count and every element value.
    ASSERT_EQ(expected, actual);
}

//////////////////////////////////////////////////////////////////
// Broadcasting x against y must produce a shape info strictly equal to expShapeInfo.
// NOTE(review): the raw arrays appear to be laid out as
// {rank, shape..., strides..., extra/type, ews, order} - confirm against the shape helpers.
TEST_F(ShapeUtilsTests, EvalBroadcastShapeInfo_1)
{

    Nd4jLong xShapeInfo[]   = {3, 3, 2, 2, 4, 2, 1, 8192, 1, 99};
    Nd4jLong yShapeInfo[]   = {2,    1, 2,    2, 1, 8192, 1, 99};
    Nd4jLong expShapeInfo[] = {3, 3, 2, 2, 4, 2, 1, 8192, 1, 99};

    NDArray x(xShapeInfo);
    NDArray y(yShapeInfo);

    Nd4jLong *newShapeInfo = nullptr;
    ShapeUtils::evalBroadcastShapeInfo(x, y, false, newShapeInfo, nullptr);

    // Fail cleanly (instead of dereferencing a null pointer) if no shape info was produced.
    ASSERT_TRUE(newShapeInfo != nullptr);
    ASSERT_TRUE(shape::equalsStrict(expShapeInfo, newShapeInfo));
}

//////////////////////////////////////////////////////////////////
// Broadcasting x against y, where both operands carry unit dimensions, must produce a
// shape info strictly equal to expShapeInfo.
// NOTE(review): the raw arrays appear to be laid out as
// {rank, shape..., strides..., extra/type, ews, order} - confirm against the shape helpers.
TEST_F(ShapeUtilsTests, EvalBroadcastShapeInfo_2)
{

    Nd4jLong xShapeInfo[]   = {4, 8, 1, 6, 1, 6,   6,  1, 1, 8192, 1, 99};
    Nd4jLong yShapeInfo[]   = {3,    7, 1, 5,      5,  5, 1, 8192, 1, 99};
    Nd4jLong expShapeInfo[] = {4, 8, 7, 6, 5, 210, 30, 5, 1, 8192, 1, 99};

    NDArray x(xShapeInfo);
    NDArray y(yShapeInfo);

    Nd4jLong *newShapeInfo = nullptr;
    ShapeUtils::evalBroadcastShapeInfo(x, y, false, newShapeInfo, nullptr);

    // Fail cleanly (instead of dereferencing a null pointer) if no shape info was produced.
    ASSERT_TRUE(newShapeInfo != nullptr);
    ASSERT_TRUE(shape::equalsStrict(expShapeInfo, newShapeInfo));
}

//////////////////////////////////////////////////////////////////
// Broadcasting two operands with the same number of dimensions must produce a shape info
// strictly equal to expShapeInfo (which here is identical to xShapeInfo).
// NOTE(review): the raw arrays appear to be laid out as
// {rank, shape..., strides..., extra/type, ews, order} - confirm against the shape helpers.
TEST_F(ShapeUtilsTests, EvalBroadcastShapeInfo_3)
{

    Nd4jLong xShapeInfo[]   = {3, 15, 3, 5, 15, 5, 1, 8192, 1, 99};
    Nd4jLong yShapeInfo[]   = {3, 15, 1, 5,  5, 5, 1, 8192, 1, 99};
    Nd4jLong expShapeInfo[] = {3, 15, 3, 5, 15, 5, 1, 8192, 1, 99};

    NDArray x(xShapeInfo);
    NDArray y(yShapeInfo);

    Nd4jLong *newShapeInfo = nullptr;
    ShapeUtils::evalBroadcastShapeInfo(x, y, false, newShapeInfo, nullptr);

    // Fail cleanly (instead of dereferencing a null pointer) if no shape info was produced.
    ASSERT_TRUE(newShapeInfo != nullptr);
    ASSERT_TRUE(shape::equalsStrict(expShapeInfo, newShapeInfo));
}

//////////////////////////////////////////////////////////////////
// Broadcasting x against a y with fewer dimensions must produce a shape info strictly
// equal to expShapeInfo.
// NOTE(review): the raw arrays appear to be laid out as
// {rank, shape..., strides..., extra/type, ews, order} - confirm against the shape helpers.
TEST_F(ShapeUtilsTests, EvalBroadcastShapeInfo_4)
{

    Nd4jLong xShapeInfo[]   = {3, 8, 1, 3,  3, 3, 1, 8192, 1, 99};
    Nd4jLong yShapeInfo[]   = {2,    4, 3,     3, 1, 8192, 1, 99};
    Nd4jLong expShapeInfo[] = {3, 8, 4, 3, 12, 3, 1, 8192, 1, 99};

    NDArray x(xShapeInfo);
    NDArray y(yShapeInfo);

    Nd4jLong *newShapeInfo = nullptr;
    ShapeUtils::evalBroadcastShapeInfo(x, y, false, newShapeInfo, nullptr);

    // Fail cleanly (instead of dereferencing a null pointer) if no shape info was produced.
    ASSERT_TRUE(newShapeInfo != nullptr);

    // Debug aid, left disabled: dump the produced shape info.
    // for (int i = 0; i < 2*newShapeInfo[0]+4; ++i) std::cout << newShapeInfo[i] << " ";
    // std::cout << std::endl;

    ASSERT_TRUE(shape::equalsStrict(expShapeInfo, newShapeInfo));
}

//////////////////////////////////////////////////////////////////
// Reducing a 'c'-ordered {2,3,4,5} array over dimension 1, with the trailing arguments left
// at their defaults, must give a shape equal to that of a {2,4,5} array.
TEST_F(ShapeUtilsTests, evalReduceShapeInfo_test1)
{
    auto input = NDArrayFactory::create<float>('c', {2,3,4,5});
    auto reference = NDArrayFactory::create<float>('c', {2,4,5});
    std::vector<int> axes = {1};

    auto reducedShapeInfo = ShapeUtils::evalReduceShapeInfo('c', axes, input.getShapeInfo());
    const auto sameShape = shape::shapeEquals(reference.getShapeInfo(), reducedShapeInfo);

    ASSERT_TRUE(sameShape);
}

//////////////////////////////////////////////////////////////////
// Reducing a 'c'-ordered {2,3,4,5} array over dimension 1 with the fourth argument set to
// true must give shape {2,1,4,5}, i.e. the reduced dimension is expected to remain as size 1.
TEST_F(ShapeUtilsTests, evalReduceShapeInfo_test2)
{
    auto input = NDArrayFactory::create<float>('c', {2,3,4,5});
    auto reference = NDArrayFactory::create<float>('c', {2,1,4,5});
    std::vector<int> axes = {1};

    auto reducedShapeInfo = ShapeUtils::evalReduceShapeInfo('c', axes, input.getShapeInfo(), true);
    const auto sameShape = shape::shapeEquals(reference.getShapeInfo(), reducedShapeInfo);

    ASSERT_TRUE(sameShape);
}

//////////////////////////////////////////////////////////////////
// Reducing a 'c'-ordered {2,3,4,5} array over dimensions {0,1,2} with the fourth argument
// set to true must give shape {1,1,1,5}.
TEST_F(ShapeUtilsTests, evalReduceShapeInfo_test3)
{
    auto input = NDArrayFactory::create<float>('c', {2,3,4,5});
    auto reference = NDArrayFactory::create<float>('c', {1,1,1,5});
    std::vector<int> axes = {0,1,2};

    auto reducedShapeInfo = ShapeUtils::evalReduceShapeInfo('c', axes, input.getShapeInfo(), true);
    const auto sameShape = shape::shapeEquals(reference.getShapeInfo(), reducedShapeInfo);

    ASSERT_TRUE(sameShape);
}

//////////////////////////////////////////////////////////////////
// Reducing a 'c'-ordered {2,3,4,5} array over all four dimensions with the fourth argument
// set to true must give shape {1,1,1,1}.
TEST_F(ShapeUtilsTests, evalReduceShapeInfo_test4)
{
    auto input = NDArrayFactory::create<float>('c', {2,3,4,5});
    auto reference = NDArrayFactory::create<float>('c', {1,1,1,1});
    std::vector<int> axes = {0,1,2,3};

    auto reducedShapeInfo = ShapeUtils::evalReduceShapeInfo('c', axes, input.getShapeInfo(), true);
    const auto sameShape = shape::shapeEquals(reference.getShapeInfo(), reducedShapeInfo);

    ASSERT_TRUE(sameShape);
}

// shapeAsString on a {2,3,4,5} array must produce the text "[2, 3, 4, 5]".
TEST_F(ShapeUtilsTests, Test_Strings_1) {
    const std::string expectedText("[2, 3, 4, 5]");
    auto array = NDArrayFactory::create<float>('c', {2, 3, 4, 5});

    const auto actualText = ShapeUtils::shapeAsString(&array);

    ASSERT_EQ(expectedText, actualText);
}

// evalBroadcastBackwardAxis for a {4,3} operand against a {2,4,3} operand must return {0}.
TEST_F(ShapeUtilsTests, Test_Backward_Axis_1) {
    auto larger = NDArrayFactory::create<float>('c', {2, 4, 3});
    auto smaller = NDArrayFactory::create<float>('c', {4, 3});
    const std::vector<int> expectedAxes = {0};

    const auto actualAxes = ShapeUtils::evalBroadcastBackwardAxis(smaller.shapeInfo(), larger.shapeInfo());

    ASSERT_EQ(expectedAxes, actualAxes);
}

// evalBroadcastBackwardAxis for a {4,1,3} operand against a {2,4,4,3} operand must return {0, 2}.
TEST_F(ShapeUtilsTests, Test_Backward_Axis_2) {
    auto larger = NDArrayFactory::create<float>('c', {2, 4, 4, 3});
    auto smaller = NDArrayFactory::create<float>('c', {4, 1, 3});
    const std::vector<int> expectedAxes = {0, 2};

    const auto actualAxes = ShapeUtils::evalBroadcastBackwardAxis(smaller.shapeInfo(), larger.shapeInfo());

    ASSERT_EQ(expectedAxes, actualAxes);
}


// evalBroadcastBackwardAxis for a {2,1,1,3} operand against a {2,4,4,3} operand must return {1, 2}.
TEST_F(ShapeUtilsTests, Test_Backward_Axis_3) {
    auto larger = NDArrayFactory::create<float>('c', {2, 4, 4, 3});
    auto smaller = NDArrayFactory::create<float>('c', {2, 1, 1, 3});
    const std::vector<int> expectedAxes = {1, 2};

    const auto actualAxes = ShapeUtils::evalBroadcastBackwardAxis(smaller.shapeInfo(), larger.shapeInfo());

    ASSERT_EQ(expectedAxes, actualAxes);
}

//////////////////////////////////////////////////////////////////
// evalPermutFromTo from {1,2,3,4} to {3,4,1,2} must return the index vector {2, 3, 0, 1}.
TEST_F(ShapeUtilsTests, evalPermutFromTo_test1) {

    int a=1, b=2, c=3, d=4;
    const std::vector<int> expected = {2, 3, 0, 1};

    const std::vector<int> result = ShapeUtils::evalPermutFromTo({a,b,c,d}, {c,d,a,b});

    // Whole-vector comparison also checks the length. The three-iterator std::equal form
    // would read past the end of a shorter result and accept a longer one.
    ASSERT_EQ(expected, result);

}

//////////////////////////////////////////////////////////////////
// evalPermutFromTo from {1,2,3,4} to {1,2,4,3} must return the index vector {0, 1, 3, 2}.
TEST_F(ShapeUtilsTests, evalPermutFromTo_test2) {

    int a=1, b=2, c=3, d=4;
    const std::vector<int> expected = {0, 1, 3, 2};

    const std::vector<int> result = ShapeUtils::evalPermutFromTo({a,b,c,d}, {a,b,d,c});

    // Whole-vector comparison also checks the length. The three-iterator std::equal form
    // would read past the end of a shorter result and accept a longer one.
    ASSERT_EQ(expected, result);

}

//////////////////////////////////////////////////////////////////
// Same swap of the last two entries as the previous case, but with repeated values in the
// shape: from {2,2,3,2} to {2,2,2,3} the expected index vector is still {0, 1, 3, 2}.
TEST_F(ShapeUtilsTests, evalPermutFromTo_test3) {

    int a=2, b=2, c=3, d=2;
    const std::vector<int> expected = {0, 1, 3, 2};

    const std::vector<int> result = ShapeUtils::evalPermutFromTo({a,b,c,d}, {a,b,d,c});

    // Whole-vector comparison also checks the length. The three-iterator std::equal form
    // would read past the end of a shorter result and accept a longer one.
    ASSERT_EQ(expected, result);

}

//////////////////////////////////////////////////////////////////
// When the source and target shapes are identical ({2,3,4,5}), evalPermutFromTo must
// return an empty vector.
TEST_F(ShapeUtilsTests, evalPermutFromTo_test4) {

    int d0 = 2;
    int d1 = 3;
    int d2 = 4;
    int d3 = 5;

    std::vector<int> permutation = ShapeUtils::evalPermutFromTo({d0,d1,d2,d3}, {d0,d1,d2,d3});
    const bool nothingToPermute = permutation.empty();

    ASSERT_TRUE(nothingToPermute);

}

//////////////////////////////////////////////////////////////////
// Placeholder: the real expectation below is disabled, so this case currently verifies
// nothing and always passes.
// TODO(review): confirm what evalPermutFromTo does for a target shape containing a value
// absent from the source shape, then re-enable the check with the matching exception type.
TEST_F(ShapeUtilsTests, evalPermutFromTo_test5) {

    // EXPECT_THROW(ShapeUtils::evalPermutFromTo({1,2,3,4}, {3,4,1,8}), const char*);
    SUCCEED();
}

//////////////////////////////////////////////////////////////////
// Placeholder: the real expectation below is disabled, so this case currently verifies
// nothing and always passes.
// TODO(review): confirm what evalPermutFromTo does when the target shape has more entries
// than the source shape, then re-enable the check with the matching exception type.
TEST_F(ShapeUtilsTests, evalPermutFromTo_test6) {

    // EXPECT_THROW(ShapeUtils::evalPermutFromTo({1,2,3,4}, {1,2,3,4,4}), const char*);
    SUCCEED();
}

//////////////////////////////////////////////////////////////////
// isPermutNecessary must report true for {1,0,2,3}, where the first two entries are swapped.
TEST_F(ShapeUtilsTests, isPermutNecessary_test1) {

    const auto needed = ShapeUtils::isPermutNecessary({1,0,2,3});

    ASSERT_TRUE(needed);
}

//////////////////////////////////////////////////////////////////
// isPermutNecessary must report false for the already ordered sequence {0,1,2,3}.
TEST_F(ShapeUtilsTests, isPermutNecessary_test2) {

    const auto needed = ShapeUtils::isPermutNecessary({0,1,2,3});

    ASSERT_FALSE(needed);
}