/* ******************************************************************************
 *
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 *  See the NOTICE file distributed with this work for additional
 *  information regarding copyright ownership.
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// Created by raver119 on 09.02.18.
//


#include "testlayers.h"
#include <ops/declarable/CustomOperations.h>
#include <helpers/helper_hash.h>
#include <array/NDArray.h>
#include <array/NDArrayList.h>


using namespace sd;
using namespace sd::graph;

class DeclarableOpsTests6 : public testing::Test {
public:

    DeclarableOpsTests6() {
        printf("\n");
        fflush(stdout);
    }
};


TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_1) {
    auto matrix = NDArrayFactory::create<double>('c', {5, 2});
    auto b = NDArrayFactory::create<double>('c', {1}, {0.});
    auto e = NDArrayFactory::create<double>('c', {1}, {1});
    auto s = NDArrayFactory::create<double>('c', {1}, {1});

    auto exp = NDArrayFactory::create<double>('c', {2}, {1.0f, 2.0f});

    matrix.linspace(1);

    sd::ops::strided_slice op;
    auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1});
    ASSERT_EQ(Status::OK(), result.status());

    auto z = result.at(0);

    ASSERT_TRUE(exp.equalsTo(z));


}

TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_2) {
    auto matrix = NDArrayFactory::create<double>('c', {5, 2});
    auto b = NDArrayFactory::create<double>('c', {1}, {0.0f});
    auto e = NDArrayFactory::create<double>('c', {1}, {1.0f});
    auto s = NDArrayFactory::create<double>('c', {1}, {1.0f});

    auto exp = NDArrayFactory::create<double>('c', {2}, {1.0f, 2.0f});

    matrix.linspace(1);

    sd::ops::strided_slice op;
    auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1});
    ASSERT_EQ(Status::OK(), result.status());

    auto z = result.at(0);

    ASSERT_EQ(exp, *z);


}

TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_3) {
    auto matrix = NDArrayFactory::create<double>(10);
    auto b = NDArrayFactory::create<double>(0);
    auto e = NDArrayFactory::create<double>(0);
    auto s = NDArrayFactory::create<double>(1.0);

    //auto exp = NDArrayFactory::create<double>('c', {2}, {1.0f, 2.0f});

    //matrix.linspace(1);

    sd::ops::strided_slice op;
    auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1});
    ASSERT_EQ(Status::OK(), result.status());

    auto z = result.at(0);
    //z->printShapeInfo("SS OS shape");
    ASSERT_TRUE(z->isEmpty());
    //ASSERT_EQ(exp, *z);


}

TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_4) {
    auto matrix = NDArrayFactory::create<double>('c', {1}, {10});
    auto b = NDArrayFactory::create<double>('c', {1}, {0.});
    auto e = NDArrayFactory::create<double>('c', {1}, {0.});
    auto s = NDArrayFactory::create<double>('c', {1}, {1.0});

    auto exp = NDArrayFactory::create<double>(10);

    //matrix.linspace(1);

    sd::ops::strided_slice op;
    auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1});
    ASSERT_EQ(Status::OK(), result.status());

    auto z = result.at(0);

    ASSERT_TRUE(z->equalsTo(exp));
    //ASSERT_EQ(exp, *z);


}

TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_04) {
    int z = 0;
    auto matrix = NDArrayFactory::create<double>('c', {1}, {10});
    auto b = NDArrayFactory::create_<int>('c', {1}, {1});
    auto e = NDArrayFactory::create_<int>('c', {1}, {z});
    auto s = NDArrayFactory::create_<int>('c', {1}, {1});
    sd::ops::ones_as opOnes;
    //auto exp = NDArrayFactory::create<double>('c', {2}, {1.0f, 2.0f});
    auto onesRes = opOnes.evaluate({&matrix});
    //matrix.linspace(1);
    ASSERT_EQ(onesRes.status(), Status::OK());

    auto ones = onesRes.at(0);
    *ones *= 10;
    auto onesD = new NDArray(ones->dup());

    auto variableSpace = new VariableSpace();
    variableSpace->putVariable(-1, onesD);
    variableSpace->putVariable(-2, b);
    variableSpace->putVariable(-3, e);
    variableSpace->putVariable(-4, s);
    auto block = new Context(1, variableSpace, false);  // not-in-place
    block->fillInputs({-1});
    block->fillInputs({-2});
    block->fillInputs({-3});
    block->fillInputs({-4});
    block->getIArguments()->push_back(0);
    block->getIArguments()->push_back(0);
    block->getIArguments()->push_back(1);
    block->getIArguments()->push_back(0);
    block->getIArguments()->push_back(0);
    auto inputShapes = new ShapeList({ones->shapeInfo(), b->shapeInfo(), e->shapeInfo(), s->shapeInfo()});
    sd::ops::strided_slice op;
    auto result = op.calculateOutputShape(inputShapes, *block); //execute({ones, &b, &e, &s}, {}, {0, 1, 0, 0, 0});
    ASSERT_EQ(result->size(), 1);
    ASSERT_TRUE(shape::isEmpty(result->at(0)));
    //ASSERT_EQ(exp, *z);
    delete block;
    delete result;
    delete variableSpace;
    delete inputShapes;
}

TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_5) {
    auto matrix = NDArrayFactory::create<double>('c', {3, 2, 2});
    auto b = NDArrayFactory::create<int>('c', {1}, {2});
    auto e = NDArrayFactory::create<int>('c', {1}, {3});
    auto s = NDArrayFactory::create<int>('c', {1}, {1});

    auto exp = NDArrayFactory::create<double>('c', {2,2}, {0.0f, 0.0f, 0., 0.});

    //matrix.linspace(1);

    sd::ops::strided_slice op;
    auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1});
    ASSERT_EQ(Status::OK(), result.status());

    auto z = result.at(0);
    ASSERT_TRUE(exp.equalsTo(z));


}

TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_6) {
    auto matrix = NDArrayFactory::create<double>('c', {3, 2, 2});
    auto b = NDArrayFactory::create<int>('c', {1}, {2});
    auto e = NDArrayFactory::create<int>('c', {1}, {3});
    auto s = NDArrayFactory::create<int>('c', {1}, {1});

    auto exp = NDArrayFactory::create<double>('c', {1,2,2}, {0.0f, 0.0f, 0., 0.});

    //matrix.linspace(1);

    sd::ops::strided_slice op;
    auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 2});
    ASSERT_EQ(Status::OK(), result.status());

    auto z = result.at(0);
    ASSERT_TRUE(exp.equalsTo(z));


}

TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_7) {
    int zero = 0;
    auto matrix = NDArrayFactory::create<double>('c', {5, 4});
    auto b = NDArrayFactory::create<int>('c', {1}, {zero});
    auto e = NDArrayFactory::create<int>('c', {1}, {zero});
    auto s = NDArrayFactory::create<int>('c', {1}, {1});

    //auto exp = NDArrayFactory::create<double>('c', {1,2,2}, {0.0f, 0.0f, 0., 0.});

    //matrix.linspace(1);

    sd::ops::strided_slice op;
    auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {1, 0, 0, 0, 0});
    ASSERT_EQ(Status::OK(), result.status());

    auto z = result.at(0);
    //ASSERT_TRUE(exp.equalsTo(z));


}

TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_1) {
    int zero = 0;
    auto matrix = NDArrayFactory::create<double>('c', {5, 4});
//    auto b = NDArrayFactory::create<int>('c', {1}, {zero});
//    auto e = NDArrayFactory::create<int>('c', {1}, {zero});
//    auto s = NDArrayFactory::create<int>('c', {1}, {1});

    auto grad = NDArrayFactory::create<double>('c', {5});

    matrix.linspace(1);
    grad.linspace(1);

    sd::ops::strided_slice_bp op;
    auto result = op.evaluate({&matrix, &grad}, {}, {1, 0, 1, 0, 2, 0, 0, 0, 1, 1, 1});
    ASSERT_EQ(Status::OK(), result.status());

    auto z = result.at(0);
    //ASSERT_TRUE(exp.equalsTo(z));


}

TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_2) {
    int zero = 0;
    auto matrix = NDArrayFactory::create<double>('c', {1, 2});
//    auto b = NDArrayFactory::create<int>('c', {1}, {zero});
//    auto e = NDArrayFactory::create<int>('c', {1}, {zero});
//    auto s = NDArrayFactory::create<int>('c', {1}, {1});

    auto grad = NDArrayFactory::create<double>('c', {1}, {1.});

    matrix.linspace(1);
    //grad.linspace(1);

    sd::ops::strided_slice_bp op;
    auto result = op.evaluate({&matrix, &grad}, {}, {1, 0, 1, 0, 2, 0, 0, 0, 1, 1, 1});
    ASSERT_EQ(Status::OK(), result.status());

    auto z = result.at(0);
    //ASSERT_TRUE(exp.equalsTo(z));


}

TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_3) {
    int zero = 0;
    auto matrix = NDArrayFactory::create<float>('c', {4, 8192});
//    auto b = NDArrayFactory::create<int>('c', {1}, {zero});
//    auto e = NDArrayFactory::create<int>('c', {1}, {zero});
//    auto s = NDArrayFactory::create<int>('c', {1}, {1});

    auto grad = NDArrayFactory::create<double>('c', {4, 256});

    matrix.linspace(1);
    grad.linspace(1);

    sd::ops::strided_slice_bp op;
    auto result = op.evaluate({&matrix, &grad}, {}, {1, 0, 1, 0, 0, 0, 0, 0, 256, 1, 1});
    ASSERT_EQ(Status::OK(), result.status());

    auto z = result.at(0);
    //ASSERT_TRUE(exp.equalsTo(z));


}

TEST_F(DeclarableOpsTests6, Test_Simple_Scalar_1) {
    auto x = NDArrayFactory::create<double>('c', {1, 1}, {2.0f});
    auto exp = NDArrayFactory::create<double>('c', {1, 1}, {4.0f});

    sd::ops::test_scalar op;
    auto result = op.evaluate({&x}, {}, {});

    ASSERT_EQ(Status::OK(), result.status());

    auto z = result.at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));


}

TEST_F(DeclarableOpsTests6, Test_Order_1) {
    auto x = NDArrayFactory::create<double>('f', {2, 3});
    auto exp = NDArrayFactory::create<double>('c', {2, 3});
    x.linspace(1);
    exp.linspace(1);

    sd::ops::order op;
    auto result = op.evaluate({&x}, {}, {0});
    ASSERT_EQ(Status::OK(), result.status());

    auto z = result.at(0);
    ASSERT_TRUE(exp.equalsTo(z));
    ASSERT_NE(x.ordering(), z->ordering());


}

TEST_F(DeclarableOpsTests6, cumSum_1) {
    auto x = NDArrayFactory::create<float>('c', {1, 4}, {1.f, 2.f, 3.f, 4.f});
    auto exp = NDArrayFactory::create<float>('c', {1, 4}, {1.f, 3.f, 6.f, 10.f});

    sd::ops::cumsum op;
    auto result = op.evaluate({&x}, {}, {0, 0});
    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto z = result.at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));


}

TEST_F(DeclarableOpsTests6, cumSum_2) {
    auto x= NDArrayFactory::create<float>('c', {2, 4}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f});
    auto exp= NDArrayFactory::create<float>('c', {2, 4}, {1.f, 3.f, 6.f, 10.f, 1.f, 3.f, 6.f, 10.f});

    sd::ops::cumsum op;
    auto result = op.evaluate({&x}, {}, {0, 0, 1});
    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto z = result.at(0);

    // z->printIndexedBuffer("CumSum1");
    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));


}

TEST_F(DeclarableOpsTests6, cumSum_3) {
    auto x= NDArrayFactory::create<float>('c', {2, 4}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f});
    auto exp= NDArrayFactory::create<float>('c', {2, 4}, {1.f, 2.f, 3.f, 4.f, 2.f, 4.f, 6.f, 8.f});

    sd::ops::cumsum op;
    auto result = op.evaluate({&x}, {}, {0, 0, 0});
    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto z = result.at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));


}

TEST_F(DeclarableOpsTests6, cumSum_4) {
    auto x = NDArrayFactory::create<double>('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
    auto exp = NDArrayFactory::create<double>('c', {3, 3}, {12., 15., 18., 11., 13., 15., 7., 8., 9.});

    sd::ops::cumsum op;
    auto result = op.evaluate({&x}, {}, {0, 1, 0}, {}, {});
    ASSERT_EQ(Status::OK(), result.status());

    auto z = result.at(0);
    // z->printBuffer();

    ASSERT_TRUE(exp.equalsTo(z));


}

TEST_F(DeclarableOpsTests6, cumSum_5) {
    auto x = NDArrayFactory::create<double>('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
    auto exp = NDArrayFactory::create<double>('c', {3, 3}, {6.f, 5.f, 3.f, 15.f, 11.f, 6.f, 24.f, 17.f, 9.f,});

    sd::ops::cumsum op;
    auto result = op.evaluate({&x}, {}, {0, 1, 1}, {});
    ASSERT_EQ(Status::OK(), result.status());

    auto z = result.at(0);

    ASSERT_TRUE(exp.equalsTo(z));


}

TEST_F(DeclarableOpsTests6, cumSum_6) {
    auto x = NDArrayFactory::create<double>('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
    auto exp = NDArrayFactory::create<double>('c', {3, 3}, {11.f, 13.f, 15.f, 7.f, 8.f, 9.f, 0.f, 0.f, 0.f});

    sd::ops::cumsum op;
    auto result = op.evaluate({&x}, {}, {1, 1, 0}, {});
    ASSERT_EQ(Status::OK(), result.status());

    auto z = result.at(0);

    ASSERT_TRUE(exp.equalsTo(z));


}

TEST_F(DeclarableOpsTests6, cumSum_7) {
    auto x = NDArrayFactory::create<double>('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
    auto exp = NDArrayFactory::create<double>('c', {3, 3}, {5.f, 3.f, 0.f, 11.f, 6.f, 0.f, 17.f, 9.f, 0.f});

    sd::ops::cumsum op;
    auto result = op.evaluate({&x}, {}, {1, 1, 1}, {});
    ASSERT_EQ(Status::OK(), result.status());

    auto z = result.at(0);

    ASSERT_TRUE(exp.equalsTo(z));


}

TEST_F(DeclarableOpsTests6, cumSum_8) {
    auto x = NDArrayFactory::create<double>('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
    auto axis = NDArrayFactory::create<Nd4jLong>('c', {1}, {1});
    auto exp = NDArrayFactory::create<double>('c', {3, 3}, {5.f, 3.f, 0.f, 11.f, 6.f, 0.f, 17.f, 9.f, 0.f});

    sd::ops::cumsum op;
    auto result = op.evaluate({&x, &axis}, {}, {1, 1}, {});
    ASSERT_EQ(Status::OK(), result.status());

    auto z = result.at(0);

    ASSERT_TRUE(exp.equalsTo(z));


}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, cumSum_9) {

    auto inputC = NDArrayFactory::create<double>('c', {3, 5},   {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.});
    auto axis = NDArrayFactory::create<Nd4jLong>(1);

    auto expFF = NDArrayFactory::create<double>('c', {3, 5}, {1.,  3.,  6., 10., 15., 6., 13., 21., 30., 40., 11., 23., 36., 50., 65.});
    auto expTF = NDArrayFactory::create<double>('c', {3, 5}, {0., 1., 3.,  6., 10., 0.,  6., 13., 21., 30., 0., 11., 23., 36., 50.});

    auto expFT = NDArrayFactory::create<double>('c', {3, 5}, {15, 14, 12, 9, 5,40, 34, 27, 19, 10,65, 54, 42, 29, 15});    //+++
    auto expTT = NDArrayFactory::create<double>('c', {3, 5}, {14, 12, 9, 5, 0,34, 27, 19, 10, 0,54, 42, 29, 15, 0});

    int exclusive, reverse;

    //************************************//
    exclusive = 0; reverse = 0;

    sd::ops::cumsum op;
    auto result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}, {});
    ASSERT_EQ(Status::OK(), result.status());
    auto z = result.at(0);
    ASSERT_TRUE(expFF.equalsTo(z));


    //************************************//
    exclusive = 1; reverse = 0;

    result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse});
    ASSERT_EQ(Status::OK(), result.status());
    z = result.at(0);
    ASSERT_TRUE(expTF.equalsTo(z));


    //************************************//
    exclusive = 0; reverse = 1;

    result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse});
    ASSERT_EQ(Status::OK(), result.status());
    z = result.at(0);
    ASSERT_TRUE(expFT.equalsTo(z));


    //************************************//
    exclusive = 1; reverse = 1;

    result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse});
    ASSERT_EQ(Status::OK(), result.status());
    z = result.at(0);
    ASSERT_TRUE(expTT.equalsTo(z));


}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, cumSum_10) {
    auto x = NDArrayFactory::create<double>('c', {4, 16, 16, 1});
    auto y = NDArrayFactory::create<int>(-3);

    sd::ops::cumsum op;
    auto result = op.evaluate({&x, &y}, {}, {1, 1});
    ASSERT_EQ(Status::OK(), result.status());


}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, cumSum_11) {

    NDArray x('c', {3, 3, 3}, sd::DataType::DOUBLE);
    auto exp = NDArrayFactory::create<double>('c', {3,3,3}, {12., 15., 18.,11., 13., 15.,7.,  8.,  9., 39., 42., 45.,29., 31., 33.,16., 17., 18., 66., 69., 72.,47., 49., 51.,25., 26., 27.});

    x.linspace(1);

    sd::ops::cumsum op;
    auto result = op.evaluate({&x}, {}, {0, 1, 1});
    ASSERT_EQ(Status::OK(), result.status());

    auto z = result.at(0);

    ASSERT_TRUE(exp.equalsTo(z));


}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, cumSum_12) {

    NDArray x('c', {3, 3, 3}, sd::DataType::DOUBLE);
    auto exp = NDArrayFactory::create<double>('c', {3,3,3}, {1.,  2.,  3.,5.,  7.,  9.,12., 15., 18., 10., 11., 12.,23., 25., 27.,39., 42., 45., 19., 20., 21.,41., 43., 45., 66., 69., 72.});

    x.linspace(1);

    sd::ops::cumsum op;
    auto result = op.evaluate({&x}, {}, {0, 0, 1});
    ASSERT_EQ(Status::OK(), result.status());

    auto z = result.at(0);

    ASSERT_TRUE(exp.equalsTo(z));


}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, cumSum_13) {

    NDArray x('c', {3, 3, 3}, sd::DataType::DOUBLE);
    auto exp = NDArrayFactory::create<double>('c', {3,3,3}, {11., 13., 15.,7.,  8.,  9.,0.,  0.,  0., 29., 31., 33.,16., 17., 18.,0.,  0.,  0., 47., 49., 51.,25., 26., 27.,0.,  0.,  0.});

    x.linspace(1);

    sd::ops::cumsum op;
    auto result = op.evaluate({&x}, {}, {1, 1, 1});
    ASSERT_EQ(Status::OK(), result.status());

    auto z = result.at(0);

    ASSERT_TRUE(exp.equalsTo(z));


}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, cumSum_14) {

    NDArray x('c', {3, 3, 3}, sd::DataType::DOUBLE);
    auto exp = NDArrayFactory::create<double>('c', {3,3,3}, {29., 31., 33.,35., 37., 39.,41., 43., 45., 19., 20., 21.,22., 23., 24.,25., 26., 27.,  0.,  0.,  0.,0.,  0.,  0.,0.,  0.,  0.});

    x.linspace(1);

    sd::ops::cumsum op;
    auto result = op.evaluate({&x}, {}, {1, 1, 0});
    ASSERT_EQ(Status::OK(), result.status());

    auto z = result.at(0);

    ASSERT_TRUE(exp.equalsTo(z));


}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, cumSum_15) {

    NDArray x('c', {3, 3, 3}, sd::DataType::DOUBLE);
    auto exp = NDArrayFactory::create<double>('c', {3,3,3}, {6.,  5.,  3.,15., 11.,  6.,24., 17.,  9., 33., 23., 12.,42., 29., 15.,51., 35., 18., 60., 41., 21.,69., 47., 24.,78., 53., 27.});

    x.linspace(1);

    sd::ops::cumsum op;
    auto result = op.evaluate({&x}, {}, {0, 1, 2});
    ASSERT_EQ(Status::OK(), result.status());

    auto z = result.at(0);

    ASSERT_TRUE(exp.equalsTo(z));


}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, cumSum_16) {

    NDArray x('f', {3, 4}, sd::DataType::FLOAT32);

    sd::ops::cumsum op;
    auto result = op.evaluate({&x}, {}, {0, 0, 1});
    ASSERT_EQ(Status::OK(), result.status());

    auto z = result.at(0);
    // z->printShapeInfo();
    // x.printShapeInfo();

    ASSERT_TRUE(z->ews() == 1);
    ASSERT_TRUE(x.ews() == 1);


}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, cumSum_17) {

    NDArray x('c', {2, 1500}, sd::DataType::FLOAT32);
    NDArray x0 = x(0, {0});
    NDArray x1 = x(1, {0});
    x0.linspace(1);
    x1.linspace(1);

    NDArray exp('c', {2, 1500}, sd::DataType::FLOAT32);
    NDArray exp0 = exp(0, {0});
    NDArray exp1 = exp(1, {0});

    exp0.p(0, 1.);
    exp1.p(0, 1.);

    for (int i = 1; i < 1500; ++i) {
        const auto prev = exp0.e<float>(i-1);
        exp0.p(i, prev + i + 1);
        exp1.p(i, prev + i + 1);
    }

    sd::ops::cumsum op;
    auto result = op.evaluate({&x}, {}, {0, 0, 1});
    ASSERT_EQ(Status::OK(), result.status());

    auto z = result.at(0);

    ASSERT_TRUE(exp.equalsTo(z));


}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, cumSum_18) {

    NDArray x('c', {2, 1500}, sd::DataType::FLOAT32);
    NDArray x0 = x(0, {0});
    NDArray x1 = x(1, {0});
    x0.linspace(1);
    x1.linspace(1);

    NDArray exp('c', {2, 1500}, sd::DataType::FLOAT32);
    NDArray exp0 = exp(0, {0});
    NDArray exp1 = exp(1, {0});

    exp0.p(0, 0.);
    exp1.p(0, 0.);

    for (int i = 1; i < 1500; ++i) {
        const auto prev = exp0.e<float>(i-1);
        exp0.p(i, prev + i);
        exp1.p(i, prev + i);
    }

    sd::ops::cumsum op;
    auto result = op.evaluate({&x}, {}, {1, 0, 1});
    ASSERT_EQ(Status::OK(), result.status());

    auto z = result.at(0);

    ASSERT_TRUE(exp.equalsTo(z));


}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, cumSum_19) {

    NDArray x('c', {2, 1500}, sd::DataType::FLOAT32);
    NDArray x0 = x(0, {0});
    NDArray x1 = x(1, {0});
    x0.linspace(1);
    x1.linspace(1);

    NDArray exp('c', {2, 1500}, sd::DataType::FLOAT32);
    NDArray exp0 = exp(0, {0});
    NDArray exp1 = exp(1, {0});

    exp0.p(1499, 1500.f);
    exp1.p(1499, 1500.f);

    for (int i = 1498; i >= 0; --i) {
        const auto prev = exp0.e<float>(i + 1);
        exp0.p(i, prev + i + 1);
        exp1.p(i, prev + i + 1);
    }

    sd::ops::cumsum op;
    auto result = op.evaluate({&x}, {}, {0, 1, 1});
    ASSERT_EQ(Status::OK(), result.status());

    auto z = result.at(0);
    // exp0.printBuffer();

    ASSERT_TRUE(exp.equalsTo(z));


}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, cumSum_20) {

    NDArray x('c', {2, 1500}, sd::DataType::FLOAT32);
    NDArray x0 = x(0, {0});
    NDArray x1 = x(1, {0});
    x0.linspace(1);
    x1.linspace(1);

    NDArray exp('c', {2, 1500}, sd::DataType::FLOAT32);
    NDArray exp0 = exp(0, {0});
    NDArray exp1 = exp(1, {0});

    exp0.p(1499, 0.);
    exp1.p(1499, 0.);

    for (int i = 1498; i >= 0; --i) {
        const auto prev = exp0.e<float>(i + 1);
        exp0.p(i, prev + i + 2);
        exp1.p(i, prev + i + 2);
    }

    sd::ops::cumsum op;
    auto result = op.evaluate({&x}, {}, {1, 1, 1});
    ASSERT_EQ(Status::OK(), result.status());

    auto z = result.at(0);

    ASSERT_TRUE(exp.equalsTo(z));


}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, TestMergeMaxIndex_1) {

    auto x = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f});
    auto y = NDArrayFactory::create<double>('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f});
    auto z = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.f, 20.f, 3.f, 40.f, 5.f, 60.f, 7.f, 80.f});
    auto exp = NDArrayFactory::create<int>('c', {2, 2, 2}, {1, 2, 1, 2, 1, 2, 1, 2});
    sd::ops::mergemaxindex op;

    auto res = op.evaluate({&x, &y, &z}, {}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, res.status());
    ASSERT_TRUE(res.at(0)->equalsTo(exp));

}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, TestMergeMaxIndex_2) {

    auto x = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.f,  2.f,  3.f,   4.f,  5.f,  60.f,  7.f,  8.f});
    auto y = NDArrayFactory::create<double>('c', {2, 2, 2}, {10.f, 2.f,  30.f,  4.f,  50.f, 6.f,  70.f, 8.f});
    auto z = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.f,  20.f, 3.f,   40.f, 5.f,  6.f, 7.f,  80.f});
    auto exp = NDArrayFactory::create<Nd4jLong>('c', {2, 2, 2}, {1, 2,     1,    2,    1,    0,   1,    2});
    sd::ops::mergemaxindex op;

    auto ress = op.evaluate({&x, &y, &z}, {}, {sd::DataType::INT64});

    ASSERT_EQ(ND4J_STATUS_OK, ress.status());
    ASSERT_TRUE(ress.at(0)->equalsTo(exp));

}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, TestMergeMaxIndex_3) {

    auto x1 = NDArrayFactory::create<double>('c', {3}, {1.f, 0.f, 0.f});
    auto x2 = NDArrayFactory::create<double>('c', {3}, {0.f, 1.f, 0.f});
    auto x3 = NDArrayFactory::create<double>('c', {3}, {0.f, 0.f, 1.f});
    NDArray z('c', {3}, sd::DataType::INT32);
    NDArray expZ('c', {3}, {0, 1, 2}, sd::DataType::INT32);

    sd::ops::mergemaxindex op;
    auto result = op.execute({&x1, &x2, &x3}, {&z}, {}, {}, {});

    ASSERT_EQ(Status::OK(), result);
    ASSERT_TRUE(z.equalsTo(expZ));
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, TestDropout_1) {

    auto x = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f});
    auto shape = NDArrayFactory::create<Nd4jLong>({2, 2});
    sd::ops::dropout op;

    auto res = op.evaluate({&x, &shape}, {0.2f}, {113});

    ASSERT_EQ(ND4J_STATUS_OK, res.status());
    //res.at(0)->printIndexedBuffer("Result is ");
    //x.printIndexedBuffer("Input is");


}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, TestMod_1) {

    auto x = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f});
    auto y = NDArrayFactory::create<double>('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f});
    auto exp = NDArrayFactory::create<double>('c', {2, 2, 2}, {1, 0, 3, 0, 5, 0, 7, 0});
    sd::ops::mod op;

    auto res = op.evaluate({&x, &y});

    ASSERT_EQ(ND4J_STATUS_OK, res.status());
//    res.at(0)->printIndexedBuffer("MOD Result is ");
//    x.printIndexedBuffer("Input is");
    ASSERT_TRUE(res.at(0)->equalsTo(exp));

}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, TestMod_BP_1) {

    auto x = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f});
    auto y = NDArrayFactory::create<double>('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f});
    auto eps = NDArrayFactory::create<double>('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f});
    auto exp = NDArrayFactory::create<double>('c', {2, 2, 2});
    sd::ops::mod_bp op;

    auto res = op.evaluate({&x, &y, &eps});

    ASSERT_EQ(ND4J_STATUS_OK, res.status());
//    res.at(0)->printIndexedBuffer("MOD_BP Result is ");

    //    x.printIndexedBuffer("Input is");
    ASSERT_TRUE(res.at(0)->equalsTo(exp));

}

///////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, TestRank_1) {

    auto x = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f});
    auto y = NDArrayFactory::create<double>('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f});
    auto eps = NDArrayFactory::create<double>('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f});
    auto exp = NDArrayFactory::create<int>(3);
    sd::ops::rank op;

    auto res = op.evaluate({&x});

    ASSERT_EQ(ND4J_STATUS_OK, res.status());

    ASSERT_TRUE(res.at(0)->equalsTo(exp));

}
TEST_F(DeclarableOpsTests6, TestDropout_2) {
//    auto x0 = NDArrayFactory::create<double>('c', {10, 10});
//    auto x1 = NDArrayFactory::create<double>('c', {10, 10});
    auto x = NDArrayFactory::create<double>('c', {3, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f});

    sd::ops::dropout op;

    auto res = op.evaluate({&x}, {0.4f}, {113});

    ASSERT_EQ(ND4J_STATUS_OK, res.status());


}

TEST_F(DeclarableOpsTests6, TestDropout_3) {
//    auto x0 = NDArrayFactory::create<double>('c', {10, 10});
//    auto x1 = NDArrayFactory::create<double>('c', {10, 10});
    auto x = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f});
    auto shape = NDArrayFactory::create<int>({1, 2});

    sd::ops::dropout op;

    auto res = op.evaluate({&x, &shape}, {0.4f}, {113});

    ASSERT_EQ(ND4J_STATUS_OK, res.status());


}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MaxPoolWithArgmax_1) {

    auto x = NDArrayFactory::create<double>('c', {2, 2, 2, 4}, {5.5, 0.,   0.3,  5.5,1.5, 0.,   1.3,  6.5,8.6, 0.,    0.,  0.4,2.5, 1.,   0.3,  4.5,
                                                                1.5, 1.,   1.3,  1.5, 3.5, 0.,   1.3,  2.5, 2.6, 2.,    3.,  1.4, 4.5, 1.,   0.3,  0.5});
    auto expI = NDArrayFactory::create<Nd4jLong>('c', {2, 2, 2, 4}, {0,  1,  2,  3,4,  5,  6,  7,8,  9, 10, 11,12, 13, 14, 15,
                                                                0,  1,  2,  3,4,  5,  6,  7,8,  9, 10, 11,12, 13, 14, 15});

    sd::ops::max_pool_with_argmax op;

    auto res = op.evaluate({&x}, {}, {1,1,1,1,1,1,1,1,1});

    ASSERT_EQ(ND4J_STATUS_OK, res.status());
    ASSERT_TRUE(expI.isSameShape(res.at(0)));
    ASSERT_TRUE(expI.isSameShape(res.at(1)));
    ASSERT_TRUE(x.equalsTo(res.at(0)));
    ASSERT_TRUE(expI.equalsTo(res.at(1)));
    //x.printIndexedBuffer("Input is");

    ASSERT_TRUE(expI.equalsTo(res.at(1)));


}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, SufficientStatistics_1) {
//    auto x0 = NDArrayFactory::create<double>('c', {10, 10});
//    auto x1 = NDArrayFactory::create<double>('c', {10, 10});
    auto x = NDArrayFactory::create<double>('c', {2, 2, 2, 4}, {5.5, 0.,  0.3, 5.5,1.5, 0.,  1.3, 6.5,8.6, 0.,   0., 0.4,2.5, 1.,  0.3, 4.5,1.5, 1.,
                                                                1.3, 1.5,3.5, 0.,  1.3, 2.5,2.6, 2.,   3., 1.4,4.5, 1.,  0.3, 0.5});
// ------------------------------------
    double count = 8.0;
    auto sumExp = NDArrayFactory::create<double>({30.2, 5., 7.8, 22.8});
    auto sqrExp = NDArrayFactory::create<double>({154.22,   7.,    14.34, 103.62});

    auto axis = NDArrayFactory::create<Nd4jLong>({0, 1, 2});

    sd::ops::sufficient_statistics op;

    auto res = op.evaluate({&x, &axis});

    ASSERT_EQ(ND4J_STATUS_OK, res.status());
    ASSERT_EQ(res.at(0)->e<double>(0), count);
    ASSERT_TRUE(sumExp.equalsTo(res.at(1)));
    ASSERT_TRUE(sqrExp.equalsTo(res.at(2)));


}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, SufficientStatistics_2) {
//    auto x0 = NDArrayFactory::create<double>('c', {10, 10});
//    auto x1 = NDArrayFactory::create<double>('c', {10, 10});
    auto x = NDArrayFactory::create<double>('c', {2, 2, 2, 4}, {5.5, 0.,  0.3, 5.5,1.5, 0.,  1.3, 6.5,8.6, 0.,   0., 0.4,2.5, 1.,  0.3, 4.5,
                                                                1.5, 1.,  1.3, 1.5,3.5, 0.,  1.3, 2.5,2.6, 2.,   3., 1.4,4.5, 1.,  0.3, 0.5});
// ------------------------------------
    double count = 4.0;
    auto sumExp = NDArrayFactory::create<double>('c', {2, 4}, {
        18.2,        3.,         4.6,        8.8,
        12.,         2.,         3.2,        14.}
    );

    auto sqrExp = NDArrayFactory::create<double>('c', {2, 4}, {
        113.22, 5., 10.78, 34.62,
           41., 2.,  3.56, 69.}
    );

    auto axis = NDArrayFactory::create<int>({0, 1});

    sd::ops::sufficient_statistics op;

    auto res = op.evaluate({&x, &axis});

    ASSERT_EQ(ND4J_STATUS_OK, res.status());
    ASSERT_EQ(res.at(0)->e<double>(0), count);
    ASSERT_TRUE(sumExp.equalsTo(res.at(1)));
    ASSERT_TRUE(sqrExp.equalsTo(res.at(2)));


}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BinCount_1) {

    auto x = NDArrayFactory::create<int>('c', {2, 2, 2}, {
        1, 2, 0, 1, 2, 2, 1, 2}
    );
// ------------------------------------

    NDArray exp('c', {3}, {1, 3, 4}, sd::DataType::INT32);

    sd::ops::bincount op;

    auto res = op.evaluate({&x});

    ASSERT_EQ(ND4J_STATUS_OK, res.status());
    ASSERT_TRUE(exp.equalsTo(res.at(0)));

}

/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BinCount_2) {

    auto x = NDArrayFactory::create<int>('c', {2, 2, 2}, {
        1, 2, 0, 1, 2, 2, 1, 2}
    );

    auto weights = NDArrayFactory::create<double>('c', {2, 2, 2}, {
        2, 1, 3, 1, 5, 1, 1, 6}
    );

// ------------------------------------

    auto exp = NDArrayFactory::create<double>({3., 4., 13.});

    sd::ops::bincount op;

    auto res = op.evaluate({&x, &weights});

    ASSERT_EQ(ND4J_STATUS_OK, res.status());
    ASSERT_TRUE(exp.equalsTo(res.at(0)));
}

/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BinCount_3) {

    auto x = NDArrayFactory::create<int>('c', {2, 2, 2}, {
        1, 2, 0, 1, 2, 2, 1, 2}
    );

    auto weights = NDArrayFactory::create<double>('c', {2, 2, 2}, {
        2, 1, 3, 1, 5, 1, 1, 6}
    );

// ------------------------------------

    auto exp = NDArrayFactory::create<double>({3., 4.});

    sd::ops::bincount op;

    auto res = op.evaluate({&x, &weights}, {}, {0, 2});

    ASSERT_EQ(ND4J_STATUS_OK, res.status());
    ASSERT_TRUE(exp.equalsTo(res.at(0)));

}

/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BinCount_4) {

    auto x = NDArrayFactory::create<int>('c', {2, 2, 2}, {
        1, 2, 0, 1, 2, 2, 1, 2}
    );

    auto weights = NDArrayFactory::create<double>('c', {2, 2, 2}, {
        2, 1, 3, 1, 5, 1, 1, 6}
    );

// ------------------------------------

    auto exp = NDArrayFactory::create<double>({3., 4.,  13., 0.0});

    sd::ops::bincount op;

    auto res = op.evaluate({&x, &weights}, {}, {4, 4});

    ASSERT_EQ(ND4J_STATUS_OK, res.status());
    ASSERT_TRUE(exp.equalsTo(res.at(0)));
}

/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BinCount_5) {

    auto x = NDArrayFactory::create<int>('c', {2, 2, 2}, {
            1, 2, 0, 1, 2, 2, 1, 2}
    );

    auto weights = NDArrayFactory::create<double>('c', {2, 2, 2}, {
            2, 1, 3, 1, 5, 1, 1, 6}
    );
    auto minV = NDArrayFactory::create(4);
    auto maxV = NDArrayFactory::create(4);
// ------------------------------------

    auto exp = NDArrayFactory::create<double>({3., 4., 13., 0.0});

    sd::ops::bincount op;

    auto res = op.evaluate({&x, &weights, &minV, &maxV});
    ASSERT_EQ(ND4J_STATUS_OK, res.status());
    // res->at(0)->printBuffer("BC out");
    ASSERT_TRUE(exp.equalsTo(res.at(0)));

}

/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_1) {

    auto x = NDArrayFactory::create<int>( {2, 2, 2} );

    auto y = NDArrayFactory::create<int>({ 2, 1, 2});

    auto exp = NDArrayFactory::create<int>({2, 2, 2});

    sd::ops::broadcast_dynamic_shape op;

    auto res = op.evaluate({&x, &y});

    ASSERT_EQ(ND4J_STATUS_OK, res.status());
    ASSERT_TRUE(exp.equalsTo(res.at(0)));

}

/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_2) {

    auto x = NDArrayFactory::create<Nd4jLong>( {2, 2} );

    auto y = NDArrayFactory::create<Nd4jLong>({2, 1, 2});

    auto exp = NDArrayFactory::create<Nd4jLong>({2, 2, 2});

    sd::ops::broadcast_dynamic_shape op;

    auto res = op.evaluate({&x, &y});
    ASSERT_EQ(ND4J_STATUS_OK, res.status());
    ASSERT_TRUE(exp.equalsTo(res.at(0)));

}

/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_3) {

    auto x = NDArrayFactory::create<int>( {2, 2, 2} );

    auto y = NDArrayFactory::create<int>({2, 1});

    auto exp = NDArrayFactory::create<int>({2, 2, 2});

    sd::ops::broadcast_dynamic_shape op;

    auto res = op.evaluate({&x, &y}, {}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, res.status());
    ASSERT_TRUE(exp.equalsTo(res.at(0)));
}

/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_4) {

    auto x = NDArrayFactory::create<Nd4jLong>( {2, 1} );

    auto y = NDArrayFactory::create<Nd4jLong>('c', {1}, {4});

    auto exp = NDArrayFactory::create<Nd4jLong>({2, 4});

    sd::ops::broadcast_dynamic_shape op;

    auto res = op.evaluate({&x, &y});

    ASSERT_EQ(ND4J_STATUS_OK, res.status());
    //res->at(0)->printBuffer("Shape SGO 4");
    ASSERT_TRUE(exp.equalsTo(res.at(0)));
}

/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_6) {

    auto x = NDArrayFactory::create<Nd4jLong>({2, 1, 4});

    auto y = NDArrayFactory::create<Nd4jLong>({2, 2, 4});

    auto exp = NDArrayFactory::create<Nd4jLong>({2, 2, 4});

    sd::ops::broadcast_dynamic_shape op;
    auto res = op.evaluate({&x, &y});

    ASSERT_EQ(ND4J_STATUS_OK, res.status());
    ASSERT_TRUE(exp.equalsTo(res.at(0)));

}

/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_7) {

    auto x = NDArrayFactory::create<Nd4jLong>({1, 1, 3});

    auto y = NDArrayFactory::create<Nd4jLong>({2, 4, 1});

    auto exp = NDArrayFactory::create<Nd4jLong>({2, 4, 3});

    sd::ops::broadcast_dynamic_shape op;
    auto res = op.evaluate({&x, &y});

    ASSERT_EQ(ND4J_STATUS_OK, res.status());
    ASSERT_TRUE(exp.equalsTo(res.at(0)));
}

/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_8) {

    auto x = NDArrayFactory::create<int>('c', {1}, {1});

    auto y = NDArrayFactory::create<int>('c', {1}, {4});

    auto z = NDArrayFactory::create<int>('c', {1});

    auto exp = NDArrayFactory::create<int>('c', {1}, {4});

    sd::ops::broadcast_dynamic_shape op;
    auto status = op.execute({&x, &y}, {&z}, {}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, status);
    ASSERT_TRUE(exp.equalsTo(z));
}

/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_9) {

    auto x = NDArrayFactory::create<int>('c', {2}, {2,2});

    auto y = NDArrayFactory::create<int>('c', {1}, {1});

    auto z = NDArrayFactory::create<Nd4jLong>('c', {2});

    auto exp = NDArrayFactory::create<Nd4jLong>('c', {2}, {2,2});

    sd::ops::broadcast_dynamic_shape op;
    auto status = op.execute({&x, &y}, {&z}, {}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, status);
    // ASSERT_TRUE(exp.equalsTo(z));

}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_1) {

    auto x = NDArrayFactory::create<double>('c', {2, 3, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0,
                                      -3.0, 0.0, 0.0, 4.0, 0.0, 0.0,
                                      -3.0, 0.0, 0.0, 4.0, 0.0, 0.0}
    );

    auto exp = NDArrayFactory::create<double>('c', {2, 3, 3}, {
            -0.2771281,  0.,          0.,
            0.36950415,  0.,          0.,
            -0.2771281,  0.,          0.,
            0.36950415,  0.,          0.,
            -0.2771281,  0.,          0.,
            0.36950415,  0.,          0.}
    );
//    8.660254
//    auto expNorm(8.660254);

    sd::ops::clip_by_global_norm op;
    auto result = op.evaluate({&x}, {0.8}, {});

    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto z = result.at(0);
    auto norm = result.at(1);
    //z->printIndexedBuffer("Output");
    //exp.printIndexedBuffer("Expected");
    //norm->printIndexedBuffer("Norm");

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));
//    ASSERT_TRUE(expNorm.equalsTo(norm));


}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_2) {

    auto x = NDArrayFactory::create<double>('c', {2, 3, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0,
                                      -3.0, 0.0, 0.0, 4.0, 0.0, 0.0,
                                      -3.0, 0.0, 0.0, 4.0, 0.0, 0.0}
    );

    auto a = NDArrayFactory::create<double>('c', {2, 3, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0,
                                      -3.0, 0.0, 0.0, 4.0, 0.0, 0.0,
                                      -3.0, 0.0, 0.0, 4.0, 0.0, 0.0}
    );

    auto exp = NDArrayFactory::create<double>('c', {2, 3, 3}, {
                                    -0.44090813,   0.,          0.,
                                      0.5878775,   0.,          0.,
                                    -0.44090813,   0.,          0.,
                                      0.5878775,   0.,          0.,
                                    -0.44090813,   0.,          0.,
                                      0.5878775,   0.,          0.}
//12.247449

    );

    sd::ops::clip_by_global_norm op;
    auto result = op.evaluate({&x, &a}, {1.8}, {});

    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto z = result.at(0);
    auto y = result.at(1);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.isSameShape(y));
    ASSERT_TRUE(exp.equalsTo(z));
    ASSERT_TRUE(exp.equalsTo(y));


}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_3) {

    auto x = NDArrayFactory::create<double>('c', {2, 3, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, 0.0, 4.0, 0.0, 0.0});
    auto a = NDArrayFactory::create<double>('c', {2, 3, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, 0.0, 4.0, 0.0, 0.0});
    auto exp = NDArrayFactory::create<double>('c', {2, 3, 3}, {
            -0.19595918,  0.,          0.,
              0.2612789,  0.,          0.,
            -0.19595918,  0.,          0.,
              0.2612789,  0.,          0.,
            -0.19595918,  0.,          0.,
              0.2612789,   0.,          0.}
    );

    sd::ops::clip_by_global_norm op;
    auto result = op.evaluate({&x, &a}, {0.8}, {});

    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto z = result.at(0);
    auto y = result.at(1);
    //z->printIndexedBuffer("Output 1");
    //y->printIndexedBuffer("Output 2");
    //result.at(2)->printIndexedBuffer("Global norm is");

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.isSameShape(y));
    ASSERT_TRUE(result.at(2)->isScalar());
    ASSERT_TRUE(exp.equalsTo(z));
    ASSERT_TRUE(exp.equalsTo(y));


}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixDeterminant_1) {

    auto x = NDArrayFactory::create<double>('c', {2, 3, 3}, {-3.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, -3.0, 4.0, 0.0, 0.0, 0.0, -3.0, 0.0, 0.0, 0.0, 4.0});
    auto exp = NDArrayFactory::create<double>({36.0, -48.0});

    sd::ops::matrix_determinant op;
    auto result = op.evaluate({&x}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto z = result.at(0);
    //z->printIndexedBuffer("Output ");
    //exp.printIndexedBuffer("Expected ");

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));


}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixDeterminant_2) {

    auto x = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0});
    auto exp = NDArrayFactory::create<double>({-2.0, -2.0});

    sd::ops::matrix_determinant op;
    auto result = op.evaluate({&x}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto z = result.at(0);
    //z->printIndexedBuffer("Output ");
    //exp.printIndexedBuffer("Expected ");

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));


}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixDeterminant_3) {

    auto x = NDArrayFactory::create<double>('c', {1, 3, 3}, {3.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 3.0});
    NDArray exp('c', {1}, std::vector<double>{-54.0});

    sd::ops::matrix_determinant op;
    auto result = op.evaluate({&x}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto z = result.at(0);
    //z->printIndexedBuffer("Output ");
    //exp.printIndexedBuffer("Expected ");

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));


}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixDeterminant_4) {

    auto x = NDArrayFactory::create<double>('c', {1, 3, 3}, {12.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 13.0});
    auto exp = NDArrayFactory::create<double>('c', {1}, {189.0});

    sd::ops::matrix_determinant op;
    auto result = op.evaluate({&x}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto z = result.at(0);
    // z->printIndexedBuffer("Output ");
    // exp.printIndexedBuffer("Expected ");
    // z->printShapeInfo("Output shape");
    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));


}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixDeterminant_5) {

    auto x = NDArrayFactory::create<double>('c', {1, 4, 4});
    NDArray exp('c', {1}, std::vector<double>{-16.0});
    x.linspace(1);
    x.p(5, 4.0);
    x.p(12, 12.0);

    sd::ops::matrix_determinant op;
    auto result = op.evaluate({&x}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto z = result.at(0);
    //z->printIndexedBuffer("Output ");
    //exp.printIndexedBuffer("Expected ");

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));


}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixDeterminant_6) {

    auto x = NDArrayFactory::create<double>('c', {4, 4});
    auto exp = NDArrayFactory::create<double>(-16.0);
    x.linspace(1);
    x.p(5, 4.0);
    x.p(12, 12.0);

    sd::ops::matrix_determinant op;
    auto result = op.evaluate({&x}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto z = result.at(0);
    //z->printIndexedBuffer("Output ");
    //z->printShapeInfo("Shape");
    //exp.printIndexedBuffer("Expected ");
    ASSERT_TRUE(z->isScalar());
    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));


}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, LogMatrixDeterminant_1) {

    auto x = NDArrayFactory::create<double>('c', {2, 3, 3}, {-3.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, -3.0, 4.0, 0.0, 0.0, 0.0, -3.0, 0.0, 0.0, 0.0, 4.0});
    auto exp = NDArrayFactory::create<double>({3.58351893845611, 3.871201010907891});

    sd::ops::log_matrix_determinant op;
    auto result = op.evaluate({&x}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto z = result.at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));


}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, LogDet_1) {

    auto x = NDArrayFactory::create<double>('c', {2, 3, 3}, {4,12,-16,12,37,-43,-16,-43,98, 4,1.2,-1.6,1.2,3.7,-4.3,-1.6,-4.3,9.8});
    auto exp = NDArrayFactory::create<double>({ 3.5835189, 4.159008});

    sd::ops::logdet op;
    auto result = op.evaluate({&x}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto z = result.at(0);
    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));


}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, LogDet_2) {

    auto x = NDArrayFactory::create<double>('c', {1, 3, 3}, {4,12,-16,12,37,-43,-16,-43,98});
    auto exp = NDArrayFactory::create<double>('c', {1}, { 3.5835189});

    sd::ops::logdet op;
    auto result = op.evaluate({&x}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto z = result.at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));


}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, LogDet_3) {

    auto x = NDArrayFactory::create<double>('c', {3, 3}, {4,12,-16,12,37,-43,-16,-43,98});
    auto exp = NDArrayFactory::create<double>( 3.5835189);

    sd::ops::logdet op;
    auto result = op.evaluate({&x}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto z = result.at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));


}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixInverse_1) {

    auto x = NDArrayFactory::create<float>('c', {2, 5, 5}, {
                    2.f, 4.f, 60.f, 8.f, 10.f,
                    0.f, 1.f, 2.f, 3.f, 4.f,
                    0.f, 0.f, 2.f, 4.f, 6.f,
                    0.f, 0.f, 0.f, 1.f, 2.f,
                    0.f, 0.f, 0.f, 0.f, 4.f,

                     1.f, 0.f, 0.f, 0.f, 0.f,
                     2.f, 1.f, 0.f, 0.f, 0.f,
                    30.f, 2.f, 1.f, 0.f, 0.f,
                     4.f, 3.f, 2.f, 1.f, 0.f,
                     5.f, 4.f, 3.f, 2.f, 1.f
    });

    auto exp = NDArrayFactory::create<float>('c', {2, 5, 5}, {
                    0.5f, -2.0f, -13.0f, 54.0f, -6.75f,
                    0.0f, 1.0f, -1.0f, 1.0f, 0.0f,
                      0.f, 0.f, 0.5f, -2.0f, 0.25f,
                      0.f, 0.f, 0.f, 1.0f, -0.5f,
                      0.f, 0.f, 0.f, 0.f, 0.25f,

                    1.0f, 0.0f, 0.0f, 0.0f, 0.f,
                   -2.0f, 1.0f, 0.f, 0.f, 0.f,
                  -26.0f, -2.0f, 1.f, 0.f, 0.f,
                   54.0f, 1.0f, -2.0f, 1.f, 0.f,
                  -27.0f, 0.0f, 1.0f, -2.0f, 1.f,
    });

    sd::ops::matrix_inverse op;
    auto result = op.evaluate({&x});

    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto z = result.at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));


}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixInverse_010) {

    auto x = NDArrayFactory::create<float>('c', {1, 5, 5}, {1.f, 0.f, 0.f, 0.f, 0.f, 2.f, 1.f, 0.f, 0.f, 0.f, 30.f, 2.f, 1.f, 0.f, 0.f, 4.f, 3.f, 2.f, 1.f, 0.f, 5.f, 4.f, 3.f, 2.f, 1.f, });
    auto exp = NDArrayFactory::create<float>('c', {1, 5, 5}, {1.0f, 0.0f, 0.0f, 0.0f, 0.f, -2.0f, 1.0f, 0.f, 0.f, 0.f, -26.0f, -2.0f, 1.f, 0.f, 0.f, 54.0f, 1.0f, -2.0f, 1.f, 0.f, -27.0f, 0.0f, 1.0f, -2.0f, 1.f});

    sd::ops::matrix_inverse op;
    auto result = op.evaluate({&x});

    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto z = result.at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));


}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixInverse_01) {

    auto x = NDArrayFactory::create<float>('c', {1, 5, 5}, {2.f, 4.f, 60.f, 8.f, 10.f, 0.f, 1.f, 2.f, 3.f, 4.f, 0.f, 0.f, 2.f, 4.f, 6.f, 0.f, 0.f, 0.f, 1.f, 2.f, 0.f, 0.f, 0.f, 0.f, 4.f });

    auto exp = NDArrayFactory::create<float>('c', {1, 5, 5}, {0.5f, -2.0f, -13.0f, 54.0f, -6.75f, 0.0f, 1.0f, -1.0f, 1.0f, 0.0f, 0.f, 0.f, 0.5f, -2.0f, 0.25f, 0.f, 0.f, 0.f, 1.0f, -0.5f, 0.f, 0.f, 0.f, 0.f, 0.25f });
    sd::ops::matrix_inverse op;
    auto result = op.evaluate({&x});

    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto z = result.at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));


}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixInverse_02) {

    auto x = NDArrayFactory::create<float>('c', {1, 5, 5}, {1.f, 0.f, 0.f, 0.f, 0.f, 2.f, 1.f, 0.f, 0.f, 0.f, 30.f, 2.f, 1.f, 0.f, 0.f, 4.f, 3.f, 2.f, 1.f, 0.f, 5.f, 4.f, 3.f, 2.f, 1.f });
    auto exp = NDArrayFactory::create<float>('c', {1, 5, 5}, {1.0f, 0.0f, 0.0f, 0.0f, 0.f, -2.0f, 1.0f, 0.f, 0.f, 0.f, -26.0f, -2.0f, 1.f, 0.f, 0.f, 54.0f, 1.0f, -2.0f, 1.f, 0.f, -27.0f, 0.0f, 1.0f, -2.0f, 1.f });

    sd::ops::matrix_inverse op;
    auto result = op.evaluate({&x});

    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto z = result.at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));


}

////////////////////////////////////////////////////////////////////////////////
/*
TEST_F(DeclarableOpsTests6, MatrixInverse_2) {

    auto x = NDArrayFactory::create<double>('c', {2, 5, 5}, {
                    1.,  2., 30.,  4.,  5.,
                    0.,  1.,  2.,  3.,  4.,
                    0.,  0.,  1.,  2.,  3.,
                    0.,  0.,  0.,  1.,  2.,
                    0.,  0.,  0.,  0.,  1.,

                     4.,   0.,  0.,  0.,  0.,
                     4.,   2.,  0.,  0.,  0.,
                    30.,   2.,  1.,  0.,  0.,
                     8.,   6.,  4.,  2.,  0.,
                    15.,  12.,  9.,  6.,  3.,
    });

    auto exp = NDArrayFactory::create<double>('c', {2, 5, 5}, {
     1.0,  -2.0,  -26.0,  54.0, -27.0,
     0.0,   1.0,  -2.0,    1.0,   0.0,
     0.0,   0.0,   1.0,   -2.0,   1.0,
     0.0,   0.0,   0.0,    1.0,  -2.0,
     0.0,   0.0,   0.0,    0.0,   1.0,

     0.25,  0.0,    0.0,   0.0,   0.0,
    -0.50,  0.5,    0.0,   0.0,   0.0,
    -6.50, -1.0,    1.0,   0.0,   0.0,
    13.50,  0.5,   -2.0,   0.5,   0.0,
    -6.75,  0.0,    1.0,  -1.0,   0.33333333
    });

    sd::ops::matrix_inverse op;
    auto result = op.evaluate({&x}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto z = result.at(0);
    z->printIndexedBuffer("Output ");
    exp.printIndexedBuffer("Expected ");

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));


}
*/
TEST_F(DeclarableOpsTests6, MatrixInverse_03) {

    auto x = NDArrayFactory::create<float>('c', {5, 5}, {
            4.f,   0.f,  0.f,  0.f,  0.f,
            4.f,   2.f,  0.f,  0.f,  0.f,
           30.f,   2.f,  1.f,  0.f,  0.f,
            8.f,   6.f,  4.f,  2.f,  0.f,
           15.f,  12.f,  9.f,  6.f,  3.f,
    });

    auto exp = NDArrayFactory::create<float>('c', {5, 5}, {
            0.25f,  0.0f,    0.0f,   0.0f,   0.0f,
            -0.50f,  0.5f,    0.0f,   0.0f,   0.0f,
            -6.50f, -1.0f,    1.0f,   0.0f,   0.0f,
            13.50f,  0.5f,   -2.0f,   0.5f,   0.0f,
            -6.75f,  0.0f,    1.0f,  -1.0f,   0.33333333f
    });

    sd::ops::matrix_inverse op;
    auto result = op.evaluate({&x});

    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto z = result.at(0);
//    z->printIndexedBuffer("Output ");
//    exp.printIndexedBuffer("Expected ");

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));


}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixInverse_3) {

    auto x = NDArrayFactory::create<float>('c', {5, 5}, {
                     4.f,   0.f,  0.f,  0.f,  0.f,
                     4.f,   2.f,  0.f,  0.f,  0.f,
                    30.f,   2.f,  1.f,  0.f,  0.f,
                     8.f,   6.f,  4.f,  2.f,  0.f,
                    15.f,  12.f,  9.f,  6.f,  3.f,
    });

    auto exp = NDArrayFactory::create<float>('c', {5, 5}, {
     0.25f,  0.0f,    0.0f,   0.0f,   0.0f,
    -0.50f,  0.5f,    0.0f,   0.0f,   0.0f,
    -6.50f, -1.0f,    1.0f,   0.0f,   0.0f,
    13.50f,  0.5f,   -2.0f,   0.5f,   0.0f,
    -6.75f,  0.0f,    1.0f,  -1.0f,   0.33333333f
    });

    sd::ops::matrix_inverse op;
    auto result = op.evaluate({&x});

    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto z = result.at(0);
//    exp.printIndexedBuffer("Expected ");
//    z->printIndexedBuffer("Output ");

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));


}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixInverse_4) {

    auto x = NDArrayFactory::create<float>('c', {5, 5}, {
                    1.f,  2.f, 30.f,  4.f,  5.f,
                    0.f,  1.f,  2.f,  3.f,  4.f,
                    0.f,  0.f,  1.f,  2.f,  3.f,
                    0.f,  0.f,  0.f,  1.f,  2.f,
                    0.f,  0.f,  0.f,  0.f,  1.f
    });

    auto exp = NDArrayFactory::create<float>('c', {5, 5}, {
     1.0f,  -2.0f,  -26.0f,  54.0f, -27.0f,
     0.0f,   1.0f,  -2.0f,    1.0f,   0.0f,
     0.0f,   0.0f,   1.0f,   -2.0f,   1.0f,
     0.0f,   0.0f,   0.0f,    1.0f,  -2.0f,
     0.0f,   0.0f,   0.0f,    0.0f,   1.0f
    });

    sd::ops::matrix_inverse op;
    auto result = op.evaluate({&x});

    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto z = result.at(0);
//    z->printIndexedBuffer("Output ");
//    exp.printIndexedBuffer("Expected ");

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));


}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixInverse_04) {

    auto x = NDArrayFactory::create<float>('c', {5, 5}, {
            1.f,  2.f, 30.f,  4.f,  5.f,
            0.f,  1.f,  2.f,  3.f,  4.f,
            0.f,  0.f,  1.f,  2.f,  3.f,
            0.f,  0.f,  0.f,  1.f,  2.f,
            0.f,  0.f,  0.f,  0.f,  1.f
    });

    auto exp = NDArrayFactory::create<float>('c', {5, 5}, {
            1.0f,  -2.0f,  -26.0f,  54.0f, -27.0f,
            0.0f,   1.0f,  -2.0f,    1.0f,   0.0f,
            0.0f,   0.0f,   1.0f,   -2.0f,   1.0f,
            0.0f,   0.0f,   0.0f,    1.0f,  -2.0f,
            0.0f,   0.0f,   0.0f,    0.0f,   1.0f
    });

    sd::ops::matrix_inverse op;
    auto result = op.evaluate({&x});

    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto z = result.at(0);
//    z->printIndexedBuffer("Output ");
//    exp.printIndexedBuffer("Expected ");

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));


}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, ReluLayer_1) {
    auto x = NDArrayFactory::create<double>('c', {3, 4}, {1.0, -2.0, 3.0, 4.0, 5.0, -6.0, 7.0, 8.0, 9.0, -10.0, 11.0, 12});
    auto w = NDArrayFactory::create<double>('c', {4, 3}, {0.5, 0.1, 0.8, 0.5, 0.2, 0.5, 0.5, 0.25, 0.5, 0.1, 0.0, 0.25});
    auto b = NDArrayFactory::create<double>({20.0, 30.0, 50.0});



    auto exp = NDArrayFactory::create<double>('c', {3, 3}, {
                        21.4,  30.45, 52.3,
                        23.8,  31.05, 56.5,
                        26.2,  31.65, 60.7});

    sd::ops::relu_layer op;
    auto result = op.evaluate({&x, &w, &b});

    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto z = result.at(0);
    // z->printShapeInfo("Output shape");
    // z->printIndexedBuffer("Output ");
    // exp.printIndexedBuffer("Expected ");

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));


}

TEST_F(DeclarableOpsTests6, Test_Reduce3_Edge) {
    auto x = NDArrayFactory::create<double>('c', {3, 4, 5});
    auto y = NDArrayFactory::create<double>('c', {3, 4, 5});


    std::vector<int> dims = {0, 1};
    auto z = x.applyReduce3(reduce3::CosineSimilarity, y, dims);
    ASSERT_TRUE(&z != nullptr);
}

///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, static_rnn_test1) {

    const int bS       = 2;
    const int inSize   = 3;
    const int numUnits = 4;
    const int time     = 5;

    auto x  = NDArrayFactory::create<double>('c', {time, bS, inSize});
    auto Wx = NDArrayFactory::create<double>('c', {inSize, numUnits});
    auto Wh = NDArrayFactory::create<double>('c', {numUnits, numUnits});
    auto b  = NDArrayFactory::create<double>('c', {2*numUnits});
    auto h0 = NDArrayFactory::create<double>('c', {bS, numUnits});
    auto maxTimeStep = NDArrayFactory::create<double>('c', {bS}, {time-1, time-3});

    x.linspace(0.01, 0.01);
    h0 = 0.2;
    Wx = 0.3;
    Wh = 0.4;
    b  = 0.25;

    auto expH      = NDArrayFactory::create<double>('c', {time, bS, numUnits}, {0.68474828, 0.68474828, 0.68474828, 0.68474828,0.69882484, 0.69882484, 0.69882484, 0.69882484, 0.9312333 , 0.9312333 , 0.9312333 , 0.9312333 ,
                                                          0.93751527, 0.93751527, 0.93751527, 0.93751527,0.97136768, 0.97136768, 0.97136768, 0.97136768,0., 0., 0., 0.        ,
                                                          0.97732812, 0.97732812, 0.97732812, 0.97732812,0., 0., 0., 0.        ,0., 0., 0., 0.,0., 0., 0., 0.});

    auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits},       {0.97732812, 0.97732812, 0.97732812, 0.97732812, 0.93751527, 0.93751527, 0.93751527, 0.93751527});

    sd::ops::static_rnn op;
    auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, results.status());

    auto h = results.at(0);
    auto hFinal = results.at(1);

    ASSERT_TRUE(expH.isSameShape(h));
    ASSERT_TRUE(expH.equalsTo(h));
    ASSERT_TRUE(expHFinal.isSameShape(hFinal));
    ASSERT_TRUE(expHFinal.equalsTo(hFinal));


}

///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, static_rnn_test2) {

    const int bS       = 2;
    const int inSize   = 3;
    const int numUnits = 4;
    const int time     = 5;

    auto x  = NDArrayFactory::create<double>('c', {time, bS, inSize});
    auto Wx = NDArrayFactory::create<double>('c', {inSize, numUnits});
    auto Wh = NDArrayFactory::create<double>('c', {numUnits, numUnits});
    auto b  = NDArrayFactory::create<double>('c', {2*numUnits});
    auto h0 = NDArrayFactory::create<double>('c', {bS, numUnits});

    x.linspace(0.01, 0.01);
    h0 = 0.2;
    Wx = 0.3;
    Wh = 0.4;
    b  = 0.25;

    auto expH      = NDArrayFactory::create<double>('c', {time, bS, numUnits}, {0.68474828, 0.68474828, 0.68474828, 0.68474828,0.69882484, 0.69882484, 0.69882484, 0.69882484,0.9312333 , 0.9312333 , 0.9312333 , 0.9312333,
                                                          0.93751527, 0.93751527, 0.93751527, 0.93751527,0.97136768, 0.97136768, 0.97136768, 0.97136768,0.97338548, 0.97338548, 0.97338548, 0.97338548,
                                                          0.97732812, 0.97732812, 0.97732812, 0.97732812,0.97864398, 0.97864398, 0.97864398, 0.97864398,0.98000654, 0.98000654, 0.98000654, 0.98000654,
                                                          0.98112648, 0.98112648, 0.98112648, 0.98112648});

    auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits},       {0.98000654, 0.98000654, 0.98000654, 0.98000654,0.98112648, 0.98112648, 0.98112648, 0.98112648});

    sd::ops::static_rnn op;
    auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, results.status());

    auto h = results.at(0);
    auto hFinal = results.at(1);

    ASSERT_TRUE(expH.isSameShape(h));
    ASSERT_TRUE(expH.equalsTo(h));
    ASSERT_TRUE(expHFinal.isSameShape(hFinal));
    ASSERT_TRUE(expHFinal.equalsTo(hFinal));


}

///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, static_rnn_test3) {

    const int bS       = 2;
    const int inSize   = 3;
    const int numUnits = 4;
    const int time     = 5;

    auto x  = NDArrayFactory::create<double>('c', {time, bS, inSize});
    auto Wx = NDArrayFactory::create<double>('c', {inSize, numUnits});
    auto Wh = NDArrayFactory::create<double>('c', {numUnits, numUnits});
    auto b  = NDArrayFactory::create<double>('c', {2*numUnits});
    auto h0 = NDArrayFactory::create<double>('c', {bS, numUnits});
    auto maxTimeStep = NDArrayFactory::create<double>('c', {bS}, {time-1, 0});

    x.linspace(0.01, 0.01);
    h0 = 0.2;
    Wx = 0.3;
    Wh = 0.4;
    b  = 0.25;

    auto expH      = NDArrayFactory::create<double>('c', {time, bS, numUnits}, {0.68474828, 0.68474828, 0.68474828, 0.68474828,0., 0., 0., 0., 0.9312333, 0.9312333, 0.9312333, 0.9312333,
                                                          0., 0., 0., 0.           , 0.97136768, 0.97136768, 0.97136768, 0.97136768,0., 0., 0., 0. ,
                                                          0.97732812, 0.97732812, 0.97732812, 0.97732812,0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0.});

    auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits},       {0.97732812, 0.97732812, 0.97732812, 0.97732812, 0.2       , 0.2       , 0.2       , 0.2});

    sd::ops::static_rnn op;
    auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, results.status());

    auto h = results.at(0);
    auto hFinal = results.at(1);

    ASSERT_TRUE(expH.isSameShape(h));
    ASSERT_TRUE(expH.equalsTo(h));
    ASSERT_TRUE(expHFinal.isSameShape(hFinal));
    ASSERT_TRUE(expHFinal.equalsTo(hFinal));


}

///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, static_rnn_test4) {

    const int bS       = 2;
    const int inSize   = 3;
    const int numUnits = 4;
    const int time     = 5;

    auto x  = NDArrayFactory::create<double>('c', {time, bS, inSize});
    auto Wx = NDArrayFactory::create<double>('c', {inSize, numUnits});
    auto Wh = NDArrayFactory::create<double>('c', {numUnits, numUnits});
    auto b  = NDArrayFactory::create<double>('c', {2*numUnits});
    auto h0 = NDArrayFactory::create<double>('c', {bS, numUnits});
    auto maxTimeStep = NDArrayFactory::create<double>('c', {bS}, {time-1, time-3});

    x.linspace(0.01, 0.01);
    Wx = 0.3;
    Wh = 0.4;
    b  = 0.25;

    auto expH      = NDArrayFactory::create<double>('c', {time, bS, numUnits}, {0.47615493, 0.47615493, 0.47615493, 0.47615493,0.49676344, 0.49676344, 0.49676344, 0.49676344, 0.87018664, 0.87018664, 0.87018664, 0.87018664,
                                                          0.88400882, 0.88400882, 0.88400882, 0.88400882, 0.96529784, 0.96529784, 0.96529784, 0.96529784,0., 0., 0., 0.        ,
                                                          0.97688859, 0.97688859, 0.97688859, 0.97688859,0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0.});

    auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits},       {0.97688859, 0.97688859, 0.97688859, 0.97688859, 0.88400882, 0.88400882, 0.88400882, 0.88400882});

    sd::ops::static_rnn op;
    auto results = op.evaluate({&x, &Wx, &Wh, &b, &maxTimeStep}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, results.status());

    auto h = results.at(0);
    auto hFinal = results.at(1);

    ASSERT_TRUE(expH.isSameShape(h));
    ASSERT_TRUE(expH.equalsTo(h));
    ASSERT_TRUE(expHFinal.isSameShape(hFinal));
    ASSERT_TRUE(expHFinal.equalsTo(hFinal));


}

///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, static_rnn_test5) {

    const int bS       = 2;
    const int inSize   = 3;
    const int numUnits = 4;
    const int time     = 5;

    auto x  = NDArrayFactory::create<double>('c', {time, bS, inSize});
    auto Wx = NDArrayFactory::create<double>('c', {inSize, numUnits});
    auto Wh = NDArrayFactory::create<double>('c', {numUnits, numUnits});
    auto b  = NDArrayFactory::create<double>('c', {2*numUnits});
    auto h0 = NDArrayFactory::create<double>('c', {bS, numUnits});

    x.linspace(0.01, 0.01);
    Wx = 0.3;
    Wh = 0.4;
    b  = 0.25;

    auto expH      = NDArrayFactory::create<double>('c', {time, bS, numUnits}, {0.47615493, 0.47615493, 0.47615493, 0.47615493,0.49676344, 0.49676344, 0.49676344, 0.49676344, 0.87018664, 0.87018664, 0.87018664, 0.87018664,
                                                          0.88400882, 0.88400882, 0.88400882, 0.88400882, 0.96529784, 0.96529784, 0.96529784, 0.96529784,0.96849345, 0.96849345, 0.96849345, 0.96849345,
                                                          0.97688859, 0.97688859, 0.97688859, 0.97688859,0.97831069, 0.97831069, 0.97831069, 0.97831069, 0.97997868, 0.97997868, 0.97997868, 0.97997868,
                                                          0.98110653, 0.98110653, 0.98110653, 0.98110653});

    auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits},       {0.97997868, 0.97997868, 0.97997868, 0.97997868, 0.98110653, 0.98110653, 0.98110653, 0.98110653});

    sd::ops::static_rnn op;
    auto results = op.evaluate({&x, &Wx, &Wh, &b}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, results.status());

    auto h = results.at(0);
    auto hFinal = results.at(1);

    ASSERT_TRUE(expH.isSameShape(h));
    ASSERT_TRUE(expH.equalsTo(h));
    ASSERT_TRUE(expHFinal.isSameShape(hFinal));
    ASSERT_TRUE(expHFinal.equalsTo(hFinal));


}

///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, static_bidir_rnn_test1) {

    const int bS         = 4;
    const int inSize     = 4;
    const int numUnitsFW = 3;
    const int numUnitsBW = 3;
    const int time       = 5;

    auto x  = NDArrayFactory::create<double>('c', {time, bS, inSize});
    auto WxFW = NDArrayFactory::create<double>('c', {inSize, numUnitsFW});
    auto WhFW = NDArrayFactory::create<double>('c', {numUnitsFW, numUnitsFW});
    auto bFW  = NDArrayFactory::create<double>('c', {2*numUnitsFW});

    auto h0FW = NDArrayFactory::create<double>('c', {bS, numUnitsFW});
    auto h0BW = NDArrayFactory::create<double>('c', {bS, numUnitsBW});
    auto maxTimeStep = NDArrayFactory::create<double>('c', {bS}, {time-1, time-3, time-4, 0});

    x.linspace(0.01, 0.01);
    h0FW = 0.2;
    h0BW = 0.25;
    WxFW = 0.3;
    WhFW = 0.4;
    bFW  = 0.1;

    auto expH      = NDArrayFactory::create<double>('c', {time, bS, numUnitsFW+numUnitsBW}, {0.43819931, 0.43819931, 0.43819931, 0.86708881, 0.86708881,0.86708881,0.47615493, 0.47615493, 0.47615493, 0.78347842, 0.78347842,0.78347842,
                                                                       0.51241561, 0.51241561, 0.51241561, 0.55529176, 0.55529176,0.55529176,0., 0., 0., 0., 0.,0.,0.73880324, 0.73880324, 0.73880324, 0.90935605, 0.90935605,
                                                                       0.90935605, 0.77843476, 0.77843476, 0.77843476, 0.64692945, 0.64692945,0.64692945,0., 0., 0., 0., 0.,0.,0., 0., 0., 0., 0.,0.,
                                                                       0.9052501, 0.9052501, 0.9052501, 0.9181592, 0.9181592, 0.9181592,0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.,
                                                                       0.9555734, 0.9555734, 0.9555734, 0.8026439, 0.8026439, 0.8026439,0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.,
                                                                       0., 0., 0., 0., 0., 0.,       0., 0., 0., 0., 0., 0.,       0., 0., 0., 0., 0., 0.,       0., 0., 0., 0., 0., 0.});

    auto expHFWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsFW},  {0.9555734 , 0.9555734 , 0.9555734 , 0.77843476, 0.77843476, 0.77843476, 0.51241561, 0.51241561, 0.51241561, 0.2, 0.2, 0.2});
    auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW},  {0.86708881, 0.86708881, 0.86708881, 0.78347842, 0.78347842, 0.78347842, 0.55529176, 0.55529176, 0.55529176, 0.25, 0.25, 0.25});

    sd::ops::static_bidirectional_rnn op;
    auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW,  &WxFW,&WhFW,&bFW,  &h0FW, &h0BW, &maxTimeStep}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, results.status());

    auto h = results.at(0);
    auto hFWfinal = results.at(1);
    auto hBWfinal = results.at(2);

    ASSERT_TRUE(expH.isSameShape(h));
    ASSERT_TRUE(expH.equalsTo(h));
    ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal));
    ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal));
    ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal));
    ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal));


}

///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, static_bidir_rnn_test2) {

    const int bS         = 4;
    const int inSize     = 4;
    const int numUnitsFW = 3;
    const int numUnitsBW = 3;
    const int time       = 5;

    auto x  = NDArrayFactory::create<double>('c', {time, bS, inSize});
    auto WxFW = NDArrayFactory::create<double>('c', {inSize, numUnitsFW});
    auto WhFW = NDArrayFactory::create<double>('c', {numUnitsFW, numUnitsFW});
    auto bFW  = NDArrayFactory::create<double>('c', {2*numUnitsFW});

    auto maxTimeStep = NDArrayFactory::create<double>('c', {bS}, {time-1, time-3, time-4, 0});

    x.linspace(0.01, 0.01);
    WxFW = 0.3;
    WhFW = 0.4;
    bFW  = 0.1;

    auto expH      = NDArrayFactory::create<double>('c', {time, bS, numUnitsFW+numUnitsBW}, {0.22602835, 0.22602835, 0.22602835, 0.86518273, 0.86518273,0.86518273,0.27105303, 0.27105303, 0.27105303, 0.66617761, 0.66617761,0.66617761,
                                                                       0.31492203, 0.31492203, 0.31492203, 0.31492203, 0.31492203,0.31492203,0.        , 0.        , 0.        , 0.        , 0.        ,0.        ,
                                                                       0.60005558, 0.60005558, 0.60005558, 0.9029975 , 0.9029975 ,0.9029975 ,0.66138054, 0.66138054, 0.66138054, 0.43819931, 0.43819931,0.43819931,
                                                                       0.        , 0.        , 0.        , 0.        , 0.        ,0.        ,0.        , 0.        , 0.        , 0.        , 0.        ,0.        ,
                                                                       0.87023975, 0.87023975, 0.87023975, 0.88852032, 0.88852032,0.88852032,0.        , 0.        , 0.        , 0.        , 0.        ,0.        ,
                                                                       0.        , 0.        , 0.        , 0.        , 0.        ,0.        ,0.        , 0.        , 0.        , 0.        , 0.        ,0.        ,
                                                                       0.95177305, 0.95177305, 0.95177305, 0.66737775, 0.66737775,0.66737775,0.        , 0.        , 0.        , 0.        , 0.        ,0.        ,
                                                                       0.        , 0.        , 0.        , 0.        , 0.        ,0.        ,0.        , 0.        , 0.        , 0.        , 0.        ,0.        ,
                                                                       0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.});

    auto expHFWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsFW},  {0.95177305, 0.95177305, 0.95177305, 0.66138054, 0.66138054, 0.66138054, 0.31492203, 0.31492203, 0.31492203, 0.        , 0.        , 0.});
    auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW},  {0.86518273, 0.86518273, 0.86518273, 0.66617761, 0.66617761, 0.66617761, 0.31492203, 0.31492203, 0.31492203, 0.        , 0.        , 0.});

    sd::ops::static_bidirectional_rnn op;
    auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW,  &WxFW,&WhFW,&bFW,  &maxTimeStep}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, results.status());

    auto h = results.at(0);
    auto hFWfinal = results.at(1);
    auto hBWfinal = results.at(2);

    ASSERT_TRUE(expH.isSameShape(h));
    ASSERT_TRUE(expH.equalsTo(h));
    ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal));
    ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal));
    ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal));
    ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal));


}


///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, static_bidir_rnn_test3) {

    const int bS         = 4;
    const int inSize     = 4;
    const int numUnitsFW = 3;
    const int numUnitsBW = 3;
    const int time       = 5;

    auto x  = NDArrayFactory::create<double>('c', {time, bS, inSize});
    auto WxFW = NDArrayFactory::create<double>('c', {inSize, numUnitsFW});
    auto WhFW = NDArrayFactory::create<double>('c', {numUnitsFW, numUnitsFW});
    auto bFW  = NDArrayFactory::create<double>('c', {2*numUnitsFW});

    x.linspace(0.01, 0.01);
    WxFW = 0.3;
    WhFW = 0.4;
    bFW  = 0.1;

    auto expH      = NDArrayFactory::create<double>('c', {time, bS, numUnitsFW+numUnitsBW}, {0.22602835, 0.22602835, 0.22602835, 0.86841012, 0.86841012,0.86841012,0.27105303, 0.27105303, 0.27105303, 0.88207531, 0.88207531,0.88207531,
                                                                       0.31492203, 0.31492203, 0.31492203, 0.8941667 , 0.8941667 ,0.8941667 ,0.35748551, 0.35748551, 0.35748551, 0.90489713, 0.90489713,
                                                                       0.90489713, 0.60005558, 0.60005558, 0.60005558, 0.91381375, 0.91381375,0.91381375,0.66138054, 0.66138054, 0.66138054, 0.92253504, 0.92253504,
                                                                       0.92253504,0.71429879, 0.71429879, 0.71429879, 0.93027876, 0.93027876,0.93027876,0.75947891, 0.75947891, 0.75947891, 0.9371767 , 0.9371767 ,
                                                                       0.9371767 , 0.87023975, 0.87023975, 0.87023975, 0.94014274, 0.94014274,0.94014274,0.89680574, 0.89680574, 0.89680574, 0.94648926, 0.94648926,
                                                                       0.94648926,0.91657261, 0.91657261, 0.91657261, 0.95204779, 0.95204779,0.95204779,0.93146896, 0.93146896, 0.93146896, 0.95694206, 0.95694206,
                                                                       0.95694206, 0.95177305, 0.95177305, 0.95177305, 0.93773086, 0.93773086,0.93773086,0.95874689, 0.95874689, 0.95874689, 0.94579176, 0.94579176,
                                                                       0.94579176,0.96416067, 0.96416067, 0.96416067, 0.95267886, 0.95267886,0.95267886,0.96851506, 0.96851506, 0.96851506, 0.95857985, 0.95857985,
                                                                       0.95857985, 0.97269956, 0.97269956, 0.97269956, 0.76075293, 0.76075293,0.76075293,0.97557464, 0.97557464, 0.97557464, 0.78024637, 0.78024637,
                                                                       0.78024637,0.97806922, 0.97806922, 0.97806922, 0.79833344, 0.79833344,0.79833344,0.98026195, 0.98026195, 0.98026195, 0.81508646, 0.81508646,0.81508646});

    auto expHFWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsFW},  {0.97269956, 0.97269956, 0.97269956, 0.97557464, 0.97557464, 0.97557464, 0.97806922, 0.97806922, 0.97806922, 0.98026195, 0.98026195, 0.98026195});
    auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW},  {0.86841012, 0.86841012, 0.86841012, 0.88207531, 0.88207531, 0.88207531, 0.8941667 , 0.8941667 , 0.8941667 , 0.90489713, 0.90489713, 0.90489713});

    sd::ops::static_bidirectional_rnn op;
    auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW,  &WxFW,&WhFW,&bFW}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, results.status());

    auto h = results.at(0);
    auto hFWfinal = results.at(1);
    auto hBWfinal = results.at(2);

    ASSERT_TRUE(expH.isSameShape(h));
    ASSERT_TRUE(expH.equalsTo(h));
    ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal));
    ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal));
    ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal));
    ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal));


}

///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, dynamic_rnn_test1) {

    const int bS       = 2;
    const int inSize   = 3;
    const int numUnits = 4;
    const int time     = 5;

    auto x  = NDArrayFactory::create<double>('c', {time, bS, inSize});
    auto Wx = NDArrayFactory::create<double>('c', {inSize, numUnits});
    auto Wh = NDArrayFactory::create<double>('c', {numUnits, numUnits});
    auto b  = NDArrayFactory::create<double>('c', {2*numUnits});
    auto h0 = NDArrayFactory::create<double>('c', {bS, numUnits});
    auto maxTimeStep = NDArrayFactory::create<Nd4jLong>('c', {bS}, {time-1, time-3});

    x.linspace(0.01, 0.01);
    h0 = 0.2;
    Wx = 0.3;
    Wh = 0.4;
    b  = 0.25;

    auto expH      = NDArrayFactory::create<double>('c', {time, bS, numUnits}, {0.68474828, 0.68474828, 0.68474828, 0.68474828,0.69882484, 0.69882484, 0.69882484, 0.69882484,0.9312333 , 0.9312333 , 0.9312333 , 0.9312333 ,
                                                          0.93751527, 0.93751527, 0.93751527, 0.93751527,0.97136768, 0.97136768, 0.97136768, 0.97136768,0.        , 0.        , 0.        , 0.        ,
                                                          0.97732812, 0.97732812, 0.97732812, 0.97732812,0.    , 0.  , 0.  , 0. ,0.   , 0.  , 0.   , 0.  ,0.      , 0.        , 0.        , 0.        });

    auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits},       {0.97732812, 0.97732812, 0.97732812, 0.97732812, 0.93751527, 0.93751527, 0.93751527, 0.93751527});

    sd::ops::dynamic_rnn op;
    auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {1});

    ASSERT_EQ(ND4J_STATUS_OK, results.status());

    auto h = results.at(0);
    auto hFinal = results.at(1);

    ASSERT_TRUE(expH.isSameShape(h));
    ASSERT_TRUE(expH.equalsTo(h));
    ASSERT_TRUE(expHFinal.isSameShape(hFinal));
    ASSERT_TRUE(expHFinal.equalsTo(hFinal));


}


///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, dynamic_rnn_test2) {

    const int bS       = 2;
    const int inSize   = 3;
    const int numUnits = 4;
    const int time     = 5;

    auto x  = NDArrayFactory::create<double>('c', {bS, time, inSize});
    auto Wx = NDArrayFactory::create<double>('c', {inSize, numUnits});
    auto Wh = NDArrayFactory::create<double>('c', {numUnits, numUnits});
    auto b  = NDArrayFactory::create<double>('c', {2*numUnits});
    auto h0 = NDArrayFactory::create<double>('c', {bS, numUnits});
    auto maxTimeStep = NDArrayFactory::create<int>('c', {bS}, {time-1, time});

    x.linspace(0.01, 0.01);
    h0 = 0.2;
    Wx = 0.3;
    Wh = 0.4;
    b  = 0.25;

    auto expH      = NDArrayFactory::create<double>('c', {bS, time, numUnits}, {0.68474828, 0.68474828, 0.68474828, 0.68474828,0.92755601, 0.92755601, 0.92755601, 0.92755601,0.96778334, 0.96778334, 0.96778334,
                                                          0.96778334,0.97309129, 0.97309129, 0.97309129, 0.97309129,0.        , 0.        , 0.        , 0.        ,
                                                          0.75001965, 0.75001965, 0.75001965, 0.75001965,0.95449491, 0.95449491, 0.95449491, 0.95449491,0.97732828, 0.97732828, 0.97732828,
                                                          0.97732828,0.98000655, 0.98000655, 0.98000655, 0.98000655,0.98120782, 0.98120782, 0.98120782, 0.98120782});

    auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits},       {0.97309129, 0.97309129, 0.97309129, 0.97309129, 0.98120782, 0.98120782, 0.98120782, 0.98120782});

    sd::ops::dynamic_rnn op;
    auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, results.status());

    auto h = results.at(0);
    auto hFinal = results.at(1);

    ASSERT_TRUE(expH.isSameShape(h));
    ASSERT_TRUE(expH.equalsTo(h));
    ASSERT_TRUE(expHFinal.isSameShape(hFinal));
    ASSERT_TRUE(expHFinal.equalsTo(hFinal));


}

///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, dynamic_rnn_test3) {

    const int bS       = 2;
    const int inSize   = 3;
    const int numUnits = 4;
    const int time     = 5;

    auto x  = NDArrayFactory::create<double>('c', {bS, time, inSize});
    auto Wx = NDArrayFactory::create<double>('c', {inSize, numUnits});
    auto Wh = NDArrayFactory::create<double>('c', {numUnits, numUnits});
    auto b  = NDArrayFactory::create<double>('c', {2*numUnits});
    auto h0 = NDArrayFactory::create<double>('c', {bS, numUnits});

    x.linspace(0.01, 0.01);
    h0 = 0.2;
    Wx = 0.3;
    Wh = 0.4;
    b  = 0.25;

    auto expH      = NDArrayFactory::create<double>('c', {bS, time, numUnits}, {0.68474828, 0.68474828, 0.68474828, 0.68474828,0.92755601, 0.92755601, 0.92755601, 0.92755601,0.96778334, 0.96778334, 0.96778334, 0.96778334,0.97309129,
                                                          0.97309129, 0.97309129, 0.97309129,0.97491207, 0.97491207, 0.97491207, 0.97491207,0.75001965, 0.75001965, 0.75001965, 0.75001965,0.95449491, 0.95449491,
                                                          0.95449491, 0.95449491,0.97732828, 0.97732828, 0.97732828, 0.97732828,0.98000655, 0.98000655, 0.98000655, 0.98000655,0.98120782, 0.98120782, 0.98120782, 0.98120782});

    auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits},       {0.97491207, 0.97491207, 0.97491207, 0.97491207, 0.98120782, 0.98120782, 0.98120782, 0.98120782});

    sd::ops::dynamic_rnn op;
    auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, results.status());

    auto h = results.at(0);
    auto hFinal = results.at(1);

    ASSERT_TRUE(expH.isSameShape(h));
    ASSERT_TRUE(expH.equalsTo(h));
    ASSERT_TRUE(expHFinal.isSameShape(hFinal));
    ASSERT_TRUE(expHFinal.equalsTo(hFinal));


}

///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, dynamic_rnn_test4) {

    const int bS       = 2;
    const int inSize   = 3;
    const int numUnits = 4;
    const int time     = 5;

    auto x  = NDArrayFactory::create<double>('c', {bS, time, inSize});
    auto Wx = NDArrayFactory::create<double>('c', {inSize, numUnits});
    auto Wh = NDArrayFactory::create<double>('c', {numUnits, numUnits});
    auto b  = NDArrayFactory::create<double>('c', {2*numUnits});
    auto maxTimeStep = NDArrayFactory::create<double>('c', {bS}, {time-1, time-4});

    x.linspace(0.01, 0.01);
    Wx = 0.3;
    Wh = 0.4;
    b  = 0.25;

    auto expH      = NDArrayFactory::create<double>('c', {bS, time, numUnits}, {0.47615493, 0.47615493, 0.47615493, 0.47615493,0.86347567, 0.86347567, 0.86347567, 0.86347567,0.96059545, 0.96059545,
                                                          0.96059545, 0.96059545,0.9724738 , 0.9724738 , 0.9724738 , 0.9724738 ,0.        , 0.        , 0.        , 0.        ,
                                                          0.57368608, 0.57368608, 0.57368608, 0.57368608,0. , 0. , 0  , 0. ,0., 0. , 0, 0.,0., 0., 0. , 0. ,0. , 0. , 0., 0. });

    auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits},       {0.9724738 , 0.9724738 , 0.9724738 , 0.9724738 ,0.57368608, 0.57368608, 0.57368608, 0.57368608});

    sd::ops::dynamic_rnn op;
    auto results = op.evaluate({&x, &Wx, &Wh, &b, &maxTimeStep}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, results.status());

    auto h = results.at(0);
    auto hFinal = results.at(1);

    ASSERT_TRUE(expH.isSameShape(h));
    ASSERT_TRUE(expH.equalsTo(h));
    ASSERT_TRUE(expHFinal.isSameShape(hFinal));
    ASSERT_TRUE(expHFinal.equalsTo(hFinal));


}

///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, dynamic_rnn_test5) {

    const int bS       = 2;
    const int inSize   = 3;
    const int numUnits = 4;
    const int time     = 5;

    auto x  = NDArrayFactory::create<double>('c', {bS, time, inSize});
    auto Wx = NDArrayFactory::create<double>('c', {inSize, numUnits});
    auto Wh = NDArrayFactory::create<double>('c', {numUnits, numUnits});
    auto b  = NDArrayFactory::create<double>('c', {2*numUnits});

    x.linspace(0.01, 0.01);
    Wx = 0.3;
    Wh = 0.4;
    b  = 0.25;

    auto expH      = NDArrayFactory::create<double>('c', {bS, time, numUnits}, {0.47615493, 0.47615493, 0.47615493, 0.47615493,0.86347567, 0.86347567, 0.86347567, 0.86347567,0.96059545, 0.96059545, 0.96059545, 0.96059545,
                                                          0.9724738 , 0.9724738 , 0.9724738 , 0.9724738 ,0.97486307, 0.97486307, 0.97486307, 0.97486307,0.57368608, 0.57368608, 0.57368608, 0.57368608,
                                                          0.92135149, 0.92135149, 0.92135149, 0.92135149,0.97482354, 0.97482354, 0.97482354, 0.97482354,0.97984727, 0.97984727, 0.97984727, 0.97984727,
                                                          0.98119833, 0.98119833, 0.98119833, 0.98119833});

    auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits},       {0.97486307, 0.97486307, 0.97486307, 0.97486307,0.98119833, 0.98119833, 0.98119833, 0.98119833});

    sd::ops::dynamic_rnn op;
    auto results = op.evaluate({&x, &Wx, &Wh, &b}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, results.status());

    auto h = results.at(0);
    auto hFinal = results.at(1);

    ASSERT_TRUE(expH.isSameShape(h));
    ASSERT_TRUE(expH.equalsTo(h));
    ASSERT_TRUE(expHFinal.isSameShape(hFinal));
    ASSERT_TRUE(expHFinal.equalsTo(hFinal));


}

///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test1) {

    const int bS         = 4;
    const int inSize     = 4;
    const int numUnitsFW = 3;
    const int numUnitsBW = 3;
    const int time       = 5;

    auto x  = NDArrayFactory::create<double>('c', {time, bS, inSize});
    auto WxFW = NDArrayFactory::create<double>('c', {inSize, numUnitsFW});
    auto WhFW = NDArrayFactory::create<double>('c', {numUnitsFW, numUnitsFW});
    auto bFW  = NDArrayFactory::create<double>('c', {2*numUnitsFW});

    auto h0FW = NDArrayFactory::create<double>('c', {bS, numUnitsFW});
    auto h0BW = NDArrayFactory::create<double>('c', {bS, numUnitsBW});
    auto maxTimeStep = NDArrayFactory::create<int>('c', {bS}, {time-1, time-3, time-4, 0});

    x.linspace(0.01, 0.01);
    h0FW = 0.2;
    h0BW = 0.25;
    WxFW = 0.3;
    WhFW = 0.4;
    bFW  = 0.1;

    auto expHFW  = NDArrayFactory::create<double>('c', {time, bS, numUnitsFW}, {0.43819931, 0.43819931, 0.43819931,0.47615493, 0.47615493, 0.47615493,0.51241561, 0.51241561, 0.51241561,0.        , 0.        , 0.        ,
                                                          0.73880324, 0.73880324, 0.73880324,0.77843476, 0.77843476, 0.77843476,0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,
                                                          0.9052501 , 0.9052501 , 0.9052501 ,0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,
                                                          0.9555734 , 0.9555734 , 0.9555734 ,0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,
                                                          0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,0.        , 0.        , 0.        });

    auto expHBW  = NDArrayFactory::create<double>('c', {time, bS, numUnitsBW}, {0.86708881, 0.86708881, 0.86708881,0.78347842, 0.78347842, 0.78347842,0.55529176, 0.55529176, 0.55529176,0.        , 0.        , 0.        ,
                                                          0.90935605, 0.90935605, 0.90935605,0.64692945, 0.64692945, 0.64692945,0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,
                                                          0.9181592 , 0.9181592 , 0.9181592 ,0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,
                                                          0.8026439 , 0.8026439 , 0.8026439 ,0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,
                                                          0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,0.        , 0.        , 0.        });

    auto expHFWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsFW},  {0.9555734 , 0.9555734 , 0.9555734 , 0.77843476, 0.77843476, 0.77843476, 0.51241561, 0.51241561, 0.51241561, 0.2       , 0.2       , 0.2});
    auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW},  {0.86708881, 0.86708881, 0.86708881, 0.78347842, 0.78347842, 0.78347842, 0.55529176, 0.55529176, 0.55529176, 0.25      , 0.25      , 0.25});

    sd::ops::dynamic_bidirectional_rnn op;
    auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW,  &WxFW,&WhFW,&bFW,  &h0FW, &h0BW, &maxTimeStep}, {}, {1});

    ASSERT_EQ(ND4J_STATUS_OK, results.status());

    auto hFW = results.at(0);
    auto hBW = results.at(1);
    auto hFWfinal = results.at(2);
    auto hBWfinal = results.at(3);

    ASSERT_TRUE(expHFW.isSameShape(hFW));
    ASSERT_TRUE(expHFW.equalsTo(hFW));
    ASSERT_TRUE(expHBW.isSameShape(hBW));
    ASSERT_TRUE(expHBW.equalsTo(hBW));
    ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal));
    ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal));
    ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal));
    ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal));


}

///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test2) {

    const int bS         = 4;
    const int inSize     = 4;
    const int numUnitsFW = 3;
    const int numUnitsBW = 3;
    const int time       = 5;

    auto x  = NDArrayFactory::create<double>('c', {bS, time, inSize});
    auto WxFW = NDArrayFactory::create<double>('c', {inSize, numUnitsFW});
    auto WhFW = NDArrayFactory::create<double>('c', {numUnitsFW, numUnitsFW});
    auto bFW  = NDArrayFactory::create<double>('c', {2*numUnitsFW});

    auto h0FW = NDArrayFactory::create<double>('c', {bS, numUnitsFW});
    auto h0BW = NDArrayFactory::create<double>('c', {bS, numUnitsBW});
    auto maxTimeStep = NDArrayFactory::create<int>('c', {bS}, {time-1, time-3, time-4, 0});

    x.linspace(0.01, 0.01);
    h0FW = 0.2;
    h0BW = 0.25;
    WxFW = 0.3;
    WhFW = 0.4;
    bFW  = 0.1;

    auto expHFW  = NDArrayFactory::create<double>('c', {bS, time, numUnitsFW}, {0.43819931, 0.43819931, 0.43819931,0.66617761, 0.66617761, 0.66617761,0.80944357, 0.80944357, 0.80944357,0.87294706, 0.87294706, 0.87294706,0.        , 0.        , 0.        ,
                                                          0.61067683, 0.61067683, 0.61067683,0.84851124, 0.84851124, 0.84851124,0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,
                                                          0.73978305, 0.73978305, 0.73978305,0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,
                                                          0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,0.        , 0.        , 0.        });

    auto expHBW  = NDArrayFactory::create<double>('c', {bS, time, numUnitsBW}, {0.84345207, 0.84345207, 0.84345207,0.83584708, 0.83584708, 0.83584708,0.77435951, 0.77435951, 0.77435951,0.58760492, 0.58760492, 0.58760492,0.        , 0.        , 0.        ,
                                                          0.85615841, 0.85615841, 0.85615841,0.67397984, 0.67397984, 0.67397984,0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,
                                                          0.76576202, 0.76576202, 0.76576202,0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,
                                                          0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,0.        , 0.        , 0.        });

    auto expHFWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsFW},  {0.87294706, 0.87294706, 0.87294706,0.84851124, 0.84851124, 0.84851124,0.73978305, 0.73978305, 0.73978305,0.2       , 0.2       , 0.2});
    auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW},  {0.84345207, 0.84345207, 0.84345207, 0.85615841, 0.85615841, 0.85615841, 0.76576202, 0.76576202, 0.76576202, 0.25      , 0.25      , 0.25});

    sd::ops::dynamic_bidirectional_rnn op;
    auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW,  &WxFW,&WhFW,&bFW,  &h0FW, &h0BW, &maxTimeStep}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, results.status());

    auto hFW = results.at(0);
    auto hBW = results.at(1);
    auto hFWfinal = results.at(2);
    auto hBWfinal = results.at(3);

    ASSERT_TRUE(expHFW.isSameShape(hFW));
    ASSERT_TRUE(expHFW.equalsTo(hFW));
    ASSERT_TRUE(expHBW.isSameShape(hBW));
    ASSERT_TRUE(expHBW.equalsTo(hBW));
    ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal));
    ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal));
    ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal));
    ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal));


}

///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test3) {

    const int bS         = 4;
    const int inSize     = 4;
    const int numUnitsFW = 3;
    const int numUnitsBW = 3;
    const int time       = 5;

    auto x  = NDArrayFactory::create<double>('c', {bS, time, inSize});
    auto WxFW = NDArrayFactory::create<double>('c', {inSize, numUnitsFW});
    auto WhFW = NDArrayFactory::create<double>('c', {numUnitsFW, numUnitsFW});
    auto bFW  = NDArrayFactory::create<double>('c', {2*numUnitsFW});

    auto maxTimeStep = NDArrayFactory::create<int>('c', {bS}, {time-1, time-3, time-4, 0});

    x.linspace(0.01, 0.01);
    WxFW = 0.3;
    WhFW = 0.4;
    bFW  = 0.1;

    auto expHFW  = NDArrayFactory::create<double>('c', {bS, time, numUnitsFW}, {0.22602835, 0.22602835, 0.22602835,0.49994591, 0.49994591, 0.49994591,0.72869307, 0.72869307, 0.72869307,0.84784327, 0.84784327, 0.84784327,0.        , 0.        , 0.        ,
                                                          0.43819931, 0.43819931, 0.43819931,0.7793996 , 0.7793996 , 0.7793996 ,0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,
                                                          0.61067683, 0.61067683, 0.61067683,0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,
                                                          0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,0.        , 0.        , 0.        });

    auto expHBW  = NDArrayFactory::create<double>('c', {bS, time, numUnitsBW}, {0.82273707, 0.82273707, 0.82273707,0.77935851, 0.77935851, 0.77935851,0.6381121 , 0.6381121 , 0.6381121 ,0.35748551, 0.35748551, 0.35748551,0.        , 0.        , 0.        ,
                                                          0.77843476, 0.77843476, 0.77843476,0.47615493, 0.47615493, 0.47615493,0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,
                                                          0.61067683, 0.61067683, 0.61067683,0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,
                                                          0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,0.        , 0.        , 0.        ,0.        , 0.        , 0.        });

    auto expHFWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsFW},  {0.84784327, 0.84784327, 0.84784327, 0.7793996 , 0.7793996 , 0.7793996 , 0.61067683, 0.61067683, 0.61067683, 0.        , 0.        , 0.});
    auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW},  {0.82273707, 0.82273707, 0.82273707, 0.77843476, 0.77843476, 0.77843476, 0.61067683, 0.61067683, 0.61067683, 0.        , 0.        , 0.});

    sd::ops::dynamic_bidirectional_rnn op;
    auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW,  &WxFW,&WhFW,&bFW,  &maxTimeStep}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, results.status());

    auto hFW = results.at(0);
    auto hBW = results.at(1);
    auto hFWfinal = results.at(2);
    auto hBWfinal = results.at(3);

    ASSERT_TRUE(expHFW.isSameShape(hFW));
    ASSERT_TRUE(expHFW.equalsTo(hFW));
    ASSERT_TRUE(expHBW.isSameShape(hBW));
    ASSERT_TRUE(expHBW.equalsTo(hBW));
    ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal));
    ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal));
    ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal));
    ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal));


}

///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test4) {

    const int bS         = 4;
    const int inSize     = 4;
    const int numUnitsFW = 3;
    const int numUnitsBW = 3;
    const int time       = 5;

    auto x  = NDArrayFactory::create<double>('c', {bS, time, inSize});
    auto WxFW = NDArrayFactory::create<double>('c', {inSize, numUnitsFW});
    auto WhFW = NDArrayFactory::create<double>('c', {numUnitsFW, numUnitsFW});
    auto bFW  = NDArrayFactory::create<double>('c', {2*numUnitsFW});

    auto h0FW = NDArrayFactory::create<double>('c', {bS, numUnitsFW});
    auto h0BW = NDArrayFactory::create<double>('c', {bS, numUnitsBW});

    x.linspace(0.01, 0.01);
    h0FW = 0.2;
    h0BW = 0.25;
    WxFW = 0.3;
    WhFW = 0.4;
    bFW  = 0.1;

    auto expHFW  = NDArrayFactory::create<double>('c', {bS, time, numUnitsFW}, {0.43819931, 0.43819931, 0.43819931,0.66617761, 0.66617761, 0.66617761,0.80944357, 0.80944357, 0.80944357,0.87294706, 0.87294706, 0.87294706,0.89948899, 0.89948899, 0.89948899,
                                                          0.61067683, 0.61067683, 0.61067683,0.84851124, 0.84851124, 0.84851124,0.91925737, 0.91925737, 0.91925737,0.93751395, 0.93751395, 0.93751395,0.94544483, 0.94544483, 0.94544483,
                                                          0.73978305, 0.73978305, 0.73978305,0.92827068, 0.92827068, 0.92827068,0.95791111, 0.95791111, 0.95791111,0.96427356, 0.96427356, 0.96427356,0.96797541, 0.96797541, 0.96797541,
                                                          0.83057887, 0.83057887, 0.83057887,0.96365083, 0.96365083, 0.96365083,0.97585698, 0.97585698, 0.97585698,0.97866981, 0.97866981, 0.97866981,0.9807326 , 0.9807326 , 0.9807326 });

    auto expHBW  = NDArrayFactory::create<double>('c', {bS, time, numUnitsBW}, {0.85301722, 0.85301722, 0.85301722,0.86427295, 0.86427295, 0.86427295,0.8599919 , 0.8599919 , 0.8599919 ,0.80609463, 0.80609463, 0.80609463,0.61814662, 0.61814662, 0.61814662,
                                                          0.91888753, 0.91888753, 0.91888753,0.92652672, 0.92652672, 0.92652672,0.92939674, 0.92939674, 0.92939674,0.90661931, 0.90661931, 0.90661931,0.74516764, 0.74516764, 0.74516764,
                                                          0.95254269, 0.95254269, 0.95254269,0.95710717, 0.95710717, 0.95710717,0.96021584, 0.96021584, 0.96021584,0.95222547, 0.95222547, 0.95222547,0.83426363, 0.83426363, 0.83426363,
                                                          0.97154357, 0.97154357, 0.97154357,0.97424915, 0.97424915, 0.97424915,0.97644817, 0.97644817, 0.97644817,0.97410547, 0.97410547, 0.97410547,0.89409962, 0.89409962, 0.89409962});

    auto expHFWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsFW},  {0.89948899, 0.89948899, 0.89948899, 0.94544483, 0.94544483, 0.94544483, 0.96797541, 0.96797541, 0.96797541, 0.9807326 , 0.9807326 , 0.9807326 });
    auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW},  {0.85301722, 0.85301722, 0.85301722, 0.91888753, 0.91888753, 0.91888753, 0.95254269, 0.95254269, 0.95254269, 0.97154357, 0.97154357, 0.97154357});

    sd::ops::dynamic_bidirectional_rnn op;
    auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW,  &WxFW,&WhFW,&bFW,  &h0FW, &h0BW}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, results.status());

    auto hFW = results.at(0);
    auto hBW = results.at(1);
    auto hFWfinal = results.at(2);
    auto hBWfinal = results.at(3);

    ASSERT_TRUE(expHFW.isSameShape(hFW));
    ASSERT_TRUE(expHFW.equalsTo(hFW));
    ASSERT_TRUE(expHBW.isSameShape(hBW));
    ASSERT_TRUE(expHBW.equalsTo(hBW));
    ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal));
    ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal));
    ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal));
    ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal));


}

TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test5) {

    const int bS         = 4;
    const int inSize     = 4;
    const int numUnitsFW = 3;
    const int numUnitsBW = 3;
    const int time       = 5;

    auto x  = NDArrayFactory::create<double>('c', {bS, time, inSize});
    auto WxFW = NDArrayFactory::create<double>('c', {inSize, numUnitsFW});
    auto WhFW = NDArrayFactory::create<double>('c', {numUnitsFW, numUnitsFW});
    auto bFW  = NDArrayFactory::create<double>('c', {2*numUnitsFW});

    x.linspace(0.01, 0.01);
    WxFW = 0.3;
    WhFW = 0.4;
    bFW  = 0.1;

    auto expHFW  = NDArrayFactory::create<double>('c', {bS, time, numUnitsFW}, {0.22602835, 0.22602835, 0.22602835,0.49994591, 0.49994591, 0.49994591,0.72869307, 0.72869307, 0.72869307,0.84784327, 0.84784327, 0.84784327,0.89357928, 0.89357928, 0.89357928,
                                                          0.43819931, 0.43819931, 0.43819931,0.7793996 , 0.7793996 , 0.7793996 ,0.9053792 , 0.9053792 , 0.9053792 ,0.93546593, 0.93546593, 0.93546593,0.94518339, 0.94518339, 0.94518339,
                                                          0.61067683, 0.61067683, 0.61067683,0.90347408, 0.90347408, 0.90347408,0.95538786, 0.95538786, 0.95538786,0.96406045, 0.96406045, 0.96406045,0.96795929, 0.96795929, 0.96795929,
                                                          0.73978305, 0.73978305, 0.73978305,0.95499984, 0.95499984, 0.95499984,0.97535671, 0.97535671, 0.97535671,0.97864446, 0.97864446, 0.97864446,0.98073144, 0.98073144, 0.98073144});

    auto expHBW  = NDArrayFactory::create<double>('c', {bS, time, numUnitsBW}, {0.84882345, 0.84882345, 0.84882345,0.85160683, 0.85160683, 0.85160683,0.81997657, 0.81997657, 0.81997657,0.69228829, 0.69228829, 0.69228829,0.39861399, 0.39861399, 0.39861399,
                                                          0.91865453, 0.91865453, 0.91865453,0.92528094, 0.92528094, 0.92528094,0.92212167, 0.92212167, 0.92212167,0.86418213, 0.86418213, 0.86418213,0.57969286, 0.57969286, 0.57969286,
                                                          0.95252666, 0.95252666, 0.95252666,0.95696305, 0.95696305, 0.95696305,0.95878749, 0.95878749, 0.95878749,0.93722463, 0.93722463, 0.93722463,0.71727031, 0.71727031, 0.71727031,
                                                          0.97154234, 0.97154234, 0.97154234,0.97423089, 0.97423089, 0.97423089,0.976149  , 0.976149  , 0.976149  ,0.96878298, 0.96878298, 0.96878298,0.81508646, 0.81508646, 0.81508646});

    auto expHFWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsFW},  {0.89357928, 0.89357928, 0.89357928, 0.94518339, 0.94518339, 0.94518339, 0.96795929, 0.96795929, 0.96795929, 0.98073144, 0.98073144, 0.98073144});
    auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW},  {0.84882345, 0.84882345, 0.84882345, 0.91865453, 0.91865453, 0.91865453, 0.95252666, 0.95252666, 0.95252666, 0.97154234, 0.97154234, 0.97154234});

    sd::ops::dynamic_bidirectional_rnn op;
    auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW,  &WxFW,&WhFW,&bFW}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, results.status());

    auto hFW = results.at(0);
    auto hBW = results.at(1);
    auto hFWfinal = results.at(2);
    auto hBWfinal = results.at(3);

    ASSERT_TRUE(expHFW.isSameShape(hFW));
    ASSERT_TRUE(expHFW.equalsTo(hFW));
    ASSERT_TRUE(expHBW.isSameShape(hBW));
    ASSERT_TRUE(expHBW.equalsTo(hBW));
    ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal));
    ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal));
    ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal));
    ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal));


}


TEST_F(DeclarableOpsTests6, Test_Diag_119_1) {
    auto x = NDArrayFactory::create<double>('c', {3}, {0.15f, 0.25f, 0.35f});
    auto e = NDArrayFactory::create<double>('c', {3, 3}, {0.15f, 0.0f, 0.0f,   0.0f, 0.25f, 0.0f,   0.0f, 0.0f, 0.35f});

    sd::ops::diag op;
    auto result = op.evaluate({&x}, {}, {});
    ASSERT_EQ(Status::OK(), result.status());

    ASSERT_EQ(e, *result.at(0));


}

TEST_F(DeclarableOpsTests6, Test_Diag_119_2) {
    auto x = NDArrayFactory::create<double>('c', {1}, {0.15f});
    auto e = NDArrayFactory::create<double>('c', {1, 1}, {0.15f});

    sd::ops::diag op;
    auto result = op.evaluate({&x}, {}, {});
    ASSERT_EQ(Status::OK(), result.status());

    ASSERT_EQ(e, *result.at(0));


}

TEST_F(DeclarableOpsTests6, Test_Diag_119_3) {
    auto x = NDArrayFactory::create<double>(0.15f);
    auto e = NDArrayFactory::create<double>('c', {1, 1}, {0.15f});

    sd::ops::diag op;
    auto result = op.evaluate({&x}, {}, {});
    ASSERT_EQ(Status::OK(), result.status());

    ASSERT_EQ(e, *result.at(0));


}