cavis/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp
raver119 3c4e959e21 [WIP] More of CUDA (#95)
* initial commit

Signed-off-by: raver119 <raver119@gmail.com>

* Implementation of hashcode cuda helper. Working edition.

* Fixed parallel test input arangements.

* Fixed tests for hashcode op.

* Fixed shape calculation for image:crop_and_resize op and test.

* NativeOps tests. Initial test suite.

* Added tests for indexReduce methods.

* Added test on execBroadcast with NDArray as dimensions.

* Added test on execBroadcastBool with NDArray as dimensions.

* Added tests on execPairwiseTransform and execPairwiseTransofrmBool.

* Added tests for execReduce with scalar results.

* Added reduce tests for non-empty dims array.

* Added tests for reduce3.

* Added tests for execScalar.

* Added tests for execSummaryStats.

* - provide cpu/cuda code for batch_to_space
- testing it

Signed-off-by: Yurii <yurii@skymind.io>

* - remove old test for batch_to_space (had wrong format and numbers were not checked)

Signed-off-by: Yurii <yurii@skymind.io>

* Fixed complilation errors with test.

* Added test for execTransformFloat.

* Added test for execTransformSame.

* Added test for execTransformBool.

* Added test for execTransformStrict.

* Added tests for execScalar/execScalarBool with TADs.

* Added test for flatten.

* - provide cpu/cuda code for space_to_Batch operaion

Signed-off-by: Yurii <yurii@skymind.io>

* Added test for concat.

* comment unnecessary stuff in s_t_b

Signed-off-by: Yurii <yurii@skymind.io>

* Added test for specialConcat.

* Added tests for memcpy/set routines.

* Fixed pullRow cuda test.

* Added pullRow test.

* Added average test.

* - correct typo in NDArray::applyPairwiseTransform(nd4j::pairwise::BoolOps op...)

Signed-off-by: Yurii <yurii@skymind.io>

* - debugging and fixing cuda tests in JavaInteropTests file

Signed-off-by: Yurii <yurii@skymind.io>

* - correct some tests

Signed-off-by: Yurii <yurii@skymind.io>

* Added test for shuffle.

* Fixed ops declarations.

* Restored omp and added shuffle test.

* Added convertTypes test.

* Added tests for execRandom. Eliminated usage of RandomBuffer with NativeOps.

* Added sort tests.

* Added tests for execCustomOp.

* - further debuging and fixing tests terminated with crash

Signed-off-by: Yurii <yurii@skymind.io>

* Added tests for calculateOutputShapes.

* Addded Benchmarks test.

* Commented benchmark tests.

* change assertion

Signed-off-by: raver119 <raver119@gmail.com>

* Added tests for apply_sgd op. Added cpu helper for that op.

* Implement cuda helper for aplly_sgd op. Fixed tests for NativeOps.

* Added test for assign broadcastable.

* Added tests for assign_bp op.

* Added tests for axpy op.

* - assign/execScalar/execTransformAny signature change
- minor test fix

Signed-off-by: raver119 <raver119@gmail.com>

* Fixed axpy op.

* meh

Signed-off-by: raver119 <raver119@gmail.com>

* - fix tests for nativeOps::concat

Signed-off-by: Yurii <yurii@skymind.io>

* sequential transform/scalar

Signed-off-by: raver119 <raver119@gmail.com>

* allow nested parallelism

Signed-off-by: raver119 <raver119@gmail.com>

* assign_bp leak fix

Signed-off-by: raver119 <raver119@gmail.com>

* block setRNG fix

Signed-off-by: raver119 <raver119@gmail.com>

* enable parallelism by default

Signed-off-by: raver119 <raver119@gmail.com>

* enable nested parallelism by default

Signed-off-by: raver119 <raver119@gmail.com>

* Added cuda implementation for row_count helper.

* Added implementation for tnse gains op helper.

* - take into account possible situations when input arrays are empty in reduce_ cuda stuff

Signed-off-by: Yurii <yurii@skymind.io>

* Implemented tsne/edge_forces op cuda-based helper. Parallelized cpu-based helper for edge_forces.

* Added kernel for tsne/symmetrized op heleper.

* Implementation of tsne/symmetrized op cuda helper. Working edition.

* Eliminated waste printfs.

* Added test for broadcastgradientargs op.

* host-only fallback for empty reduce float

Signed-off-by: raver119 <raver119@gmail.com>

* - some tests fixes

Signed-off-by: Yurii <yurii@skymind.io>

* - correct the rest of reduce_ stuff

Signed-off-by: Yurii <yurii@skymind.io>

* - further correction of reduce_ stuff

Signed-off-by: Yurii <yurii@skymind.io>

* Added test for Cbow op. Also added cuda implementation for cbow helpers.

* - improve code of stack operation for scalar case

Signed-off-by: Yurii <yurii@skymind.io>

* - provide cuda kernel for gatherND operation

Signed-off-by: Yurii <yurii@skymind.io>

* Implementation of cbow helpers with cuda kernels.

* minor tests tweaks

Signed-off-by: raver119 <raver119@gmail.com>

* minor tests tweaks

Signed-off-by: raver119 <raver119@gmail.com>

* - further correction of cuda stuff

Signed-off-by: Yurii <yurii@skymind.io>

* Implementatation of cbow op helper with cuda kernels. Working edition.

* Skip random testing for cudablas case.

* lstmBlockCell context fix

Signed-off-by: raver119 <raver119@gmail.com>

* Added tests for ELU and ELU_BP ops.

* Added tests for eq_scalar, gt_scalar, gte_scalar and lte_scalar ops.

* Added tests for neq_scalar.

* Added test for noop.

* - further work on clipbynorm_bp

Signed-off-by: Yurii <yurii@skymind.io>

* - get rid of concat op call, use instead direct concat helper call

Signed-off-by: Yurii <yurii@skymind.io>

* lstmBlockCell context fix

Signed-off-by: raver119 <raver119@gmail.com>

* Added tests for lrelu and lrelu_bp.

* Added tests for selu and selu_bp.

* Fixed lrelu derivative helpers.

* - some corrections in lstm

Signed-off-by: Yurii <yurii@skymind.io>

* operator * result shape fix

Signed-off-by: raver119 <raver119@gmail.com>

* - correct typo in lstmCell

Signed-off-by: Yurii <yurii@skymind.io>

* few tests fixed

Signed-off-by: raver119 <raver119@gmail.com>

* CUDA inverse broadcast bool fix

Signed-off-by: raver119 <raver119@gmail.com>

* disable MMAP test for CUDA

Signed-off-by: raver119 <raver119@gmail.com>

* BooleanOp syncToDevice

Signed-off-by: raver119 <raver119@gmail.com>

* meh

Signed-off-by: raver119 <raver119@gmail.com>

* additional data types for im2col/col2im

Signed-off-by: raver119 <raver119@gmail.com>

* Added test for firas_sparse op.

* one more RandomBuffer test excluded

Signed-off-by: raver119 <raver119@gmail.com>

* Added tests for flatten op.

* Added test for Floor op.

* bunch of tests fixed

Signed-off-by: raver119 <raver119@gmail.com>

* mmulDot tests fixed

Signed-off-by: raver119 <raver119@gmail.com>

* more tests fixed

Signed-off-by: raver119 <raver119@gmail.com>

* Implemented floordiv_bp op and tests.

* Fixed scalar case with cuda implementation for bds.

* - work on cuda kernel for clip_by_norm backprop op is completed

Signed-off-by: Yurii <yurii@skymind.io>

* Eliminate cbow crach.

* more tests fixed

Signed-off-by: raver119 <raver119@gmail.com>

* more tests fixed

Signed-off-by: raver119 <raver119@gmail.com>

* Eliminated abortion with batched nlp test.

* more tests fixed

Signed-off-by: raver119 <raver119@gmail.com>

* Fixed shared flag initializing.

* disabled bunch of cpu workspaces tests

Signed-off-by: raver119 <raver119@gmail.com>

* scalar operators fix: missing registerSpecialUse call

Signed-off-by: raver119 <raver119@gmail.com>

* Fixed logdet for cuda and tests.

* - correct clipBynorm_bp

Signed-off-by: Yurii <yurii@skymind.io>

* Fixed crop_and_resize shape datatype.

* - correct some mmul tests

Signed-off-by: Yurii <yurii@skymind.io>
2019-08-05 11:27:05 +10:00

2786 lines
107 KiB
C++

/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// Created by raver119 on 09.02.18.
//
#include "testlayers.h"
#include <ops/declarable/CustomOperations.h>
#include <helpers/helper_hash.h>
#include <NDArray.h>
#include <array/NDArrayList.h>
using namespace nd4j;
using namespace nd4j::graph;
class DeclarableOpsTests6 : public testing::Test {
public:
DeclarableOpsTests6() {
printf("\n");
fflush(stdout);
}
};
TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_1) {
auto matrix = NDArrayFactory::create<double>('c', {5, 2});
auto b = NDArrayFactory::create<double>('c', {1}, {0.});
auto e = NDArrayFactory::create<double>('c', {1}, {1});
auto s = NDArrayFactory::create<double>('c', {1}, {1});
auto exp = NDArrayFactory::create<double>('c', {2}, {1.0f, 2.0f});
matrix.linspace(1);
nd4j::ops::strided_slice op;
auto result = op.execute({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_2) {
auto matrix = NDArrayFactory::create<double>('c', {5, 2});
auto b = NDArrayFactory::create<double>('c', {1}, {0.0f});
auto e = NDArrayFactory::create<double>('c', {1}, {1.0f});
auto s = NDArrayFactory::create<double>('c', {1}, {1.0f});
auto exp = NDArrayFactory::create<double>('c', {2}, {1.0f, 2.0f});
matrix.linspace(1);
nd4j::ops::strided_slice op;
auto result = op.execute({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_EQ(exp, *z);
delete result;
}
TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_3) {
auto matrix = NDArrayFactory::create<double>(10);
auto b = NDArrayFactory::create<double>(0);
auto e = NDArrayFactory::create<double>(0);
auto s = NDArrayFactory::create<double>(1.0);
//auto exp = NDArrayFactory::create<double>('c', {2}, {1.0f, 2.0f});
//matrix.linspace(1);
nd4j::ops::strided_slice op;
auto result = op.execute({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
//z->printShapeInfo("SS OS shape");
ASSERT_TRUE(z->isEmpty());
//ASSERT_EQ(exp, *z);
delete result;
}
TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_4) {
auto matrix = NDArrayFactory::create<double>('c', {1}, {10});
auto b = NDArrayFactory::create<double>('c', {1}, {0.});
auto e = NDArrayFactory::create<double>('c', {1}, {0.});
auto s = NDArrayFactory::create<double>('c', {1}, {1.0});
auto exp = NDArrayFactory::create<double>(10);
//matrix.linspace(1);
nd4j::ops::strided_slice op;
auto result = op.execute({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
z->printShapeInfo("SS OS shape");
z->printIndexedBuffer("SS OS out");
ASSERT_TRUE(z->equalsTo(exp));
//ASSERT_EQ(exp, *z);
delete result;
}
TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_04) {
auto matrix = NDArrayFactory::create<double>('c', {1}, {10});
auto b = NDArrayFactory::create_<int>('c', {1}, {1});
auto e = NDArrayFactory::create_<int>('c', {1}, {(int)0});
auto s = NDArrayFactory::create_<int>('c', {1}, {1});
nd4j::ops::ones_as opOnes;
//auto exp = NDArrayFactory::create<double>('c', {2}, {1.0f, 2.0f});
auto onesRes = opOnes.execute({&matrix}, {}, {});
//matrix.linspace(1);
ASSERT_EQ(onesRes->status(), Status::OK());
auto ones = onesRes->at(0);
ones->printShapeInfo("Shape ones");
*ones *= 10;
auto onesD = ones->dup();
auto variableSpace = new VariableSpace();
variableSpace->putVariable(-1, onesD);
variableSpace->putVariable(-2, b);
variableSpace->putVariable(-3, e);
variableSpace->putVariable(-4, s);
auto block = new Context(1, variableSpace, false); // not-in-place
block->fillInputs({-1});
block->fillInputs({-2});
block->fillInputs({-3});
block->fillInputs({-4});
block->getIArguments()->push_back(0);
block->getIArguments()->push_back(0);
block->getIArguments()->push_back(1);
block->getIArguments()->push_back(0);
block->getIArguments()->push_back(0);
auto inputShapes = new ShapeList({ones->getShapeInfo(), b->getShapeInfo(), e->getShapeInfo(), s->getShapeInfo()});
nd4j::ops::strided_slice op;
auto result = op.calculateOutputShape(inputShapes, *block); //execute({ones, &b, &e, &s}, {}, {0, 1, 0, 0, 0});
ASSERT_EQ(result->size(), 1);
shape::printShapeInfoLinear(result->at(0));
//auto z = result->at(0);
// z->printShapeInfo("SS OS shape");
ASSERT_TRUE(shape::isEmpty(result->at(0)));
//ASSERT_EQ(exp, *z);
delete block;
delete onesRes;
delete result;
delete variableSpace;
delete inputShapes;
}
TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_5) {
auto matrix = NDArrayFactory::create<double>('c', {3, 2, 2});
auto b = NDArrayFactory::create<int>('c', {1}, {2});
auto e = NDArrayFactory::create<int>('c', {1}, {3});
auto s = NDArrayFactory::create<int>('c', {1}, {1});
auto exp = NDArrayFactory::create<double>('c', {2,2}, {0.0f, 0.0f, 0., 0.});
//matrix.linspace(1);
nd4j::ops::strided_slice op;
auto result = op.execute({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
z->printShapeInfo("Output shape");
z->printIndexedBuffer("Output");
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_6) {
auto matrix = NDArrayFactory::create<double>('c', {3, 2, 2});
auto b = NDArrayFactory::create<int>('c', {1}, {2});
auto e = NDArrayFactory::create<int>('c', {1}, {3});
auto s = NDArrayFactory::create<int>('c', {1}, {1});
auto exp = NDArrayFactory::create<double>('c', {1,2,2}, {0.0f, 0.0f, 0., 0.});
//matrix.linspace(1);
nd4j::ops::strided_slice op;
auto result = op.execute({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 2});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
z->printShapeInfo("Output shape");
z->printIndexedBuffer("Output");
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_7) {
int zero = 0;
auto matrix = NDArrayFactory::create<double>('c', {5, 4});
auto b = NDArrayFactory::create<int>('c', {1}, {zero});
auto e = NDArrayFactory::create<int>('c', {1}, {zero});
auto s = NDArrayFactory::create<int>('c', {1}, {1});
//auto exp = NDArrayFactory::create<double>('c', {1,2,2}, {0.0f, 0.0f, 0., 0.});
//matrix.linspace(1);
nd4j::ops::strided_slice op;
auto result = op.execute({&matrix, &b, &e, &s}, {}, {1, 0, 0, 0, 0});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
z->printShapeInfo("Output shape");
z->printIndexedBuffer("Output");
//ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests6, Test_Simple_Scalar_1) {
auto x = NDArrayFactory::create<double>('c', {1, 1}, {2.0f});
auto exp = NDArrayFactory::create<double>('c', {1, 1}, {4.0f});
nd4j::ops::test_scalar op;
auto result = op.execute({&x}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests6, Test_gatherNd_Edge_1) {
auto x = NDArrayFactory::create<double>('c', {2, 4, 2, 2});
auto indices = NDArrayFactory::create<int>('c', {3, 3}, {0,2,1, 0,1,0, 1,3,1});
auto exp = NDArrayFactory::create<double>('c', {3,2}, {11.f, 12.f, 5.f, 6.f, 31.f, 32.f});
x.linspace(1);
nd4j::ops::gather_nd op;
auto result = op.execute({&x, &indices}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
//z->printIndexedBuffer();
//z->printShapeInfo("z shape");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests6, Test_Order_1) {
auto x = NDArrayFactory::create<double>('f', {2, 3});
auto exp = NDArrayFactory::create<double>('c', {2, 3});
x.linspace(1);
exp.linspace(1);
nd4j::ops::order op;
auto result = op.execute({&x}, {}, {0});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
z->printIndexedBuffer("O Output");
exp.printIndexedBuffer("O Expect");
ASSERT_TRUE(exp.equalsTo(z));
ASSERT_NE(x.ordering(), z->ordering());
delete result;
}
TEST_F(DeclarableOpsTests6, cumSum_1) {
auto x = NDArrayFactory::create<float>('c', {1, 4}, {1.f, 2.f, 3.f, 4.f});
auto exp = NDArrayFactory::create<float>('c', {1, 4}, {1.f, 3.f, 6.f, 10.f});
nd4j::ops::cumsum op;
auto result = op.execute({&x}, {}, {0, 0});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
// z->printIndexedBuffer("CumSum1");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests6, cumSum_2) {
auto x= NDArrayFactory::create<float>('c', {2, 4}, {1, 2, 3, 4, 1, 2, 3, 4});
auto exp= NDArrayFactory::create<float>('c', {2, 4}, {1, 3, 6, 10, 1, 3, 6, 10});
nd4j::ops::cumsum op;
auto result = op.execute({&x}, {}, {0, 0, 1});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
// z->printIndexedBuffer("CumSum1");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests6, cumSum_3) {
auto x= NDArrayFactory::create<float>('c', {2, 4}, {1, 2, 3, 4, 1, 2, 3, 4});
auto exp= NDArrayFactory::create<float>('c', {2, 4}, {1, 2, 3, 4, 2, 4, 6, 8});
nd4j::ops::cumsum op;
auto result = op.execute({&x}, {}, {0, 0, 0});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests6, cumSum_4) {
auto x = NDArrayFactory::create<double>('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
auto exp = NDArrayFactory::create<double>('c', {3, 3}, {12., 15., 18., 11., 13., 15., 7., 8., 9.});
nd4j::ops::cumsum op;
auto result = op.execute({&x}, {}, {0, 1, 0}, {}, false, nd4j::DataType::DOUBLE);
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
// z->printBuffer();
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests6, cumSum_5) {
auto x = NDArrayFactory::create<double>('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
auto exp = NDArrayFactory::create<double>('c', {3, 3}, {6.f, 5.f, 3.f, 15.f, 11.f, 6.f, 24.f, 17.f, 9.f,});
nd4j::ops::cumsum op;
auto result = op.execute({&x}, {}, {0, 1, 1}, {}, false, nd4j::DataType::DOUBLE);
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests6, cumSum_6) {
auto x = NDArrayFactory::create<double>('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
auto exp = NDArrayFactory::create<double>('c', {3, 3}, {11.f, 13.f, 15.f, 7.f, 8.f, 9.f, 0.f, 0.f, 0.f});
nd4j::ops::cumsum op;
auto result = op.execute({&x}, {}, {1, 1, 0}, {}, false, nd4j::DataType::DOUBLE);
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests6, cumSum_7) {
auto x = NDArrayFactory::create<double>('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
auto exp = NDArrayFactory::create<double>('c', {3, 3}, {5.f, 3.f, 0.f, 11.f, 6.f, 0.f, 17.f, 9.f, 0.f});
nd4j::ops::cumsum op;
auto result = op.execute({&x}, {}, {1, 1, 1}, {}, false, nd4j::DataType::DOUBLE);
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests6, cumSum_8) {
auto x = NDArrayFactory::create<double>('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
auto axis = NDArrayFactory::create<Nd4jLong>('c', {1}, {1});
auto exp = NDArrayFactory::create<double>('c', {3, 3}, {5.f, 3.f, 0.f, 11.f, 6.f, 0.f, 17.f, 9.f, 0.f});
nd4j::ops::cumsum op;
auto result = op.execute({&x, &axis}, {}, {1, 1}, {}, false, nd4j::DataType::DOUBLE);
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, cumSum_9) {
auto inputC = NDArrayFactory::create<double>('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.});
auto axis = NDArrayFactory::create<Nd4jLong>(1);
auto expFF = NDArrayFactory::create<double>('c', {3, 5}, {1., 3., 6., 10., 15., 6., 13., 21., 30., 40., 11., 23., 36., 50., 65.});
auto expTF = NDArrayFactory::create<double>('c', {3, 5}, {0., 1., 3., 6., 10., 0., 6., 13., 21., 30., 0., 11., 23., 36., 50.});
auto expFT = NDArrayFactory::create<double>('c', {3, 5}, {15, 14, 12, 9, 5,40, 34, 27, 19, 10,65, 54, 42, 29, 15}); //+++
auto expTT = NDArrayFactory::create<double>('c', {3, 5}, {14, 12, 9, 5, 0,34, 27, 19, 10, 0,54, 42, 29, 15, 0});
int exclusive, reverse;
//************************************//
exclusive = 0; reverse = 0;
nd4j::ops::cumsum op;
auto result = op.execute({&inputC, &axis}, {}, {exclusive, reverse}, {}, false, nd4j::DataType::DOUBLE);
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_TRUE(expFF.equalsTo(z));
delete result;
//************************************//
exclusive = 1; reverse = 0;
result = op.execute({&inputC, &axis}, {}, {exclusive, reverse}, {}, false, nd4j::DataType::DOUBLE);
ASSERT_EQ(Status::OK(), result->status());
z = result->at(0);
ASSERT_TRUE(expTF.equalsTo(z));
delete result;
//************************************//
exclusive = 0; reverse = 1;
result = op.execute({&inputC, &axis}, {}, {exclusive, reverse}, {}, false, nd4j::DataType::DOUBLE);
ASSERT_EQ(Status::OK(), result->status());
z = result->at(0);
ASSERT_TRUE(expFT.equalsTo(z));
delete result;
//************************************//
exclusive = 1; reverse = 1;
result = op.execute({&inputC, &axis}, {}, {exclusive, reverse}, {}, false, nd4j::DataType::DOUBLE);
ASSERT_EQ(Status::OK(), result->status());
z = result->at(0);
ASSERT_TRUE(expTT.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, cumSum_10) {
auto x = NDArrayFactory::create<double>('c', {4, 16, 16, 1});
auto y = NDArrayFactory::create<int>(-3);
nd4j::ops::cumsum op;
auto result = op.execute({&x, &y}, {}, {1, 1}, {});
ASSERT_EQ(Status::OK(), result->status());
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, cumSum_11) {
NDArray x('c', {3, 3, 3}, nd4j::DataType::DOUBLE);
auto exp = NDArrayFactory::create<double>('c', {3,3,3}, {12., 15., 18.,11., 13., 15.,7., 8., 9., 39., 42., 45.,29., 31., 33.,16., 17., 18., 66., 69., 72.,47., 49., 51.,25., 26., 27.});
x.linspace(1);
nd4j::ops::cumsum op;
auto result = op.execute({&x}, {}, {0, 1, 1}, {}, false, nd4j::DataType::DOUBLE);
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, cumSum_12) {
NDArray x('c', {3, 3, 3}, nd4j::DataType::DOUBLE);
auto exp = NDArrayFactory::create<double>('c', {3,3,3}, {1., 2., 3.,5., 7., 9.,12., 15., 18., 10., 11., 12.,23., 25., 27.,39., 42., 45., 19., 20., 21.,41., 43., 45., 66., 69., 72.});
x.linspace(1);
nd4j::ops::cumsum op;
auto result = op.execute({&x}, {}, {0, 0, 1}, {}, false, nd4j::DataType::DOUBLE);
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, cumSum_13) {
NDArray x('c', {3, 3, 3}, nd4j::DataType::DOUBLE);
auto exp = NDArrayFactory::create<double>('c', {3,3,3}, {11., 13., 15.,7., 8., 9.,0., 0., 0., 29., 31., 33.,16., 17., 18.,0., 0., 0., 47., 49., 51.,25., 26., 27.,0., 0., 0.});
x.linspace(1);
nd4j::ops::cumsum op;
auto result = op.execute({&x}, {}, {1, 1, 1}, {}, false, nd4j::DataType::DOUBLE);
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, cumSum_14) {
NDArray x('c', {3, 3, 3}, nd4j::DataType::DOUBLE);
auto exp = NDArrayFactory::create<double>('c', {3,3,3}, {29., 31., 33.,35., 37., 39.,41., 43., 45., 19., 20., 21.,22., 23., 24.,25., 26., 27., 0., 0., 0.,0., 0., 0.,0., 0., 0.});
x.linspace(1);
nd4j::ops::cumsum op;
auto result = op.execute({&x}, {}, {1, 1, 0}, {}, false, nd4j::DataType::DOUBLE);
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, cumSum_15) {
NDArray x('c', {3, 3, 3}, nd4j::DataType::DOUBLE);
auto exp = NDArrayFactory::create<double>('c', {3,3,3}, {6., 5., 3.,15., 11., 6.,24., 17., 9., 33., 23., 12.,42., 29., 15.,51., 35., 18., 60., 41., 21.,69., 47., 24.,78., 53., 27.});
x.linspace(1);
nd4j::ops::cumsum op;
auto result = op.execute({&x}, {}, {0, 1, 2}, {}, false, nd4j::DataType::DOUBLE);
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, cumSum_16) {
NDArray x('f', {3, 4}, nd4j::DataType::FLOAT32);
nd4j::ops::cumsum op;
auto result = op.execute({&x}, {}, {0, 0, 1});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
// z->printShapeInfo();
// x.printShapeInfo();
ASSERT_TRUE(z->ews() == 1);
ASSERT_TRUE(x.ews() == 1);
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, cumSum_17) {
NDArray x('c', {2, 1500}, nd4j::DataType::FLOAT32);
NDArray x0 = x(0, {0});
NDArray x1 = x(1, {0});
x0.linspace(1);
x1.linspace(1);
NDArray exp('c', {2, 1500}, nd4j::DataType::FLOAT32);
NDArray exp0 = exp(0, {0});
NDArray exp1 = exp(1, {0});
exp0.p<float>(0, 1.);
exp1.p<float>(0, 1.);
for (int i = 1; i < 1500; ++i) {
const auto prev = exp0.e<float>(i-1);
exp0.p<float>(i, prev + i + 1);
exp1.p<float>(i, prev + i + 1);
}
nd4j::ops::cumsum op;
auto result = op.execute({&x}, {}, {0, 0, 1});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, cumSum_18) {
NDArray x('c', {2, 1500}, nd4j::DataType::FLOAT32);
NDArray x0 = x(0, {0});
NDArray x1 = x(1, {0});
x0.linspace(1);
x1.linspace(1);
NDArray exp('c', {2, 1500}, nd4j::DataType::FLOAT32);
NDArray exp0 = exp(0, {0});
NDArray exp1 = exp(1, {0});
exp0.p<float>(0, 0.);
exp1.p<float>(0, 0.);
for (int i = 1; i < 1500; ++i) {
const auto prev = exp0.e<float>(i-1);
exp0.p<float>(i, prev + i);
exp1.p<float>(i, prev + i);
}
nd4j::ops::cumsum op;
auto result = op.execute({&x}, {}, {1, 0, 1});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, cumSum_19) {
NDArray x('c', {2, 1500}, nd4j::DataType::FLOAT32);
NDArray x0 = x(0, {0});
NDArray x1 = x(1, {0});
x0.linspace(1);
x1.linspace(1);
NDArray exp('c', {2, 1500}, nd4j::DataType::FLOAT32);
NDArray exp0 = exp(0, {0});
NDArray exp1 = exp(1, {0});
exp0.p<float>(1499, 1500.);
exp1.p<float>(1499, 1500.);
for (int i = 1498; i >= 0; --i) {
const auto prev = exp0.e<float>(i + 1);
exp0.p<float>(i, prev + i + 1);
exp1.p<float>(i, prev + i + 1);
}
nd4j::ops::cumsum op;
auto result = op.execute({&x}, {}, {0, 1, 1});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
// exp0.printBuffer();
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, cumSum_20) {
NDArray x('c', {2, 1500}, nd4j::DataType::FLOAT32);
NDArray x0 = x(0, {0});
NDArray x1 = x(1, {0});
x0.linspace(1);
x1.linspace(1);
NDArray exp('c', {2, 1500}, nd4j::DataType::FLOAT32);
NDArray exp0 = exp(0, {0});
NDArray exp1 = exp(1, {0});
exp0.p<float>(1499, 0.);
exp1.p<float>(1499, 0.);
for (int i = 1498; i >= 0; --i) {
const auto prev = exp0.e<float>(i + 1);
exp0.p<float>(i, prev + i + 2);
exp1.p<float>(i, prev + i + 2);
}
nd4j::ops::cumsum op;
auto result = op.execute({&x}, {}, {1, 1, 1});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, TestDropout_1) {
auto x = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f});
auto shape = NDArrayFactory::create<Nd4jLong>({2, 2});
nd4j::ops::dropout op;
auto ress = op.execute({&x, &shape}, {0.2f}, {113}, {}, false, nd4j::DataType::DOUBLE);
ASSERT_EQ(ND4J_STATUS_OK, ress->status());
//ress->at(0)->printIndexedBuffer("Result is ");
//x.printIndexedBuffer("Input is");
delete ress;
}
TEST_F(DeclarableOpsTests6, TestDropout_2) {
// auto x0 = NDArrayFactory::create<double>('c', {10, 10});
// auto x1 = NDArrayFactory::create<double>('c', {10, 10});
auto x = NDArrayFactory::create<double>('c', {3, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f});
nd4j::ops::dropout op;
auto ress = op.execute({&x}, {0.4f}, {113}, {}, false, nd4j::DataType::DOUBLE);
ASSERT_EQ(ND4J_STATUS_OK, ress->status());
//x.printIndexedBuffer("Input is");
//ress->at(0)->printIndexedBuffer("Result is ");
delete ress;
}
TEST_F(DeclarableOpsTests6, TestDropout_3) {
// auto x0 = NDArrayFactory::create<double>('c', {10, 10});
// auto x1 = NDArrayFactory::create<double>('c', {10, 10});
auto x = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f});
auto shape = NDArrayFactory::create<int>({1, 2});
nd4j::ops::dropout op;
auto ress = op.execute({&x, &shape}, {0.4f}, {113}, {}, false, nd4j::DataType::DOUBLE);
ASSERT_EQ(ND4J_STATUS_OK, ress->status());
//x.printIndexedBuffer("Input is");
//ress->at(0)->printIndexedBuffer("Result is ");
delete ress;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MaxPoolWithArgmax_1) {
auto x = NDArrayFactory::create<double>('c', {2, 2, 2, 4}, {5.5, 0., 0.3, 5.5,1.5, 0., 1.3, 6.5,8.6, 0., 0., 0.4,2.5, 1., 0.3, 4.5,
1.5, 1., 1.3, 1.5, 3.5, 0., 1.3, 2.5, 2.6, 2., 3., 1.4, 4.5, 1., 0.3, 0.5});
auto expI = NDArrayFactory::create<Nd4jLong>('c', {2, 2, 2, 4}, {0, 1, 2, 3,4, 5, 6, 7,8, 9, 10, 11,12, 13, 14, 15,
0, 1, 2, 3,4, 5, 6, 7,8, 9, 10, 11,12, 13, 14, 15});
nd4j::ops::max_pool_with_argmax op;
auto ress = op.execute({&x}, {}, {1,1,1,1,1,1,1,1,1});
ASSERT_EQ(ND4J_STATUS_OK, ress->status());
ASSERT_TRUE(expI.isSameShape(ress->at(0)));
ASSERT_TRUE(expI.isSameShape(ress->at(1)));
ASSERT_TRUE(x.equalsTo(ress->at(0)));
ASSERT_TRUE(expI.equalsTo(ress->at(1)));
//x.printIndexedBuffer("Input is");
ASSERT_TRUE(expI.equalsTo(ress->at(1)));
delete ress;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, SufficientStatistics_1) {
// auto x0 = NDArrayFactory::create<double>('c', {10, 10});
// auto x1 = NDArrayFactory::create<double>('c', {10, 10});
auto x = NDArrayFactory::create<double>('c', {2, 2, 2, 4}, {5.5, 0., 0.3, 5.5,1.5, 0., 1.3, 6.5,8.6, 0., 0., 0.4,2.5, 1., 0.3, 4.5,1.5, 1.,
1.3, 1.5,3.5, 0., 1.3, 2.5,2.6, 2., 3., 1.4,4.5, 1., 0.3, 0.5});
// ------------------------------------
double count = 8.0;
auto sumExp = NDArrayFactory::create<double>({30.2, 5., 7.8, 22.8});
auto sqrExp = NDArrayFactory::create<double>({154.22, 7., 14.34, 103.62});
auto axis = NDArrayFactory::create<Nd4jLong>({0, 1, 2});
nd4j::ops::sufficient_statistics op;
auto ress = op.execute({&x, &axis}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
ASSERT_EQ(ND4J_STATUS_OK, ress->status());
ASSERT_EQ(ress->at(0)->e<double>(0), count);
ASSERT_TRUE(sumExp.equalsTo(ress->at(1)));
ASSERT_TRUE(sqrExp.equalsTo(ress->at(2)));
delete ress;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, SufficientStatistics_2) {
// auto x0 = NDArrayFactory::create<double>('c', {10, 10});
// auto x1 = NDArrayFactory::create<double>('c', {10, 10});
auto x = NDArrayFactory::create<double>('c', {2, 2, 2, 4}, {5.5, 0., 0.3, 5.5,1.5, 0., 1.3, 6.5,8.6, 0., 0., 0.4,2.5, 1., 0.3, 4.5,
1.5, 1., 1.3, 1.5,3.5, 0., 1.3, 2.5,2.6, 2., 3., 1.4,4.5, 1., 0.3, 0.5});
// ------------------------------------
double count = 4.0;
auto sumExp = NDArrayFactory::create<double>('c', {2, 4}, {
18.2, 3., 4.6, 8.8,
12., 2., 3.2, 14.}
);
auto sqrExp = NDArrayFactory::create<double>('c', {2, 4}, {
113.22, 5., 10.78, 34.62,
41., 2., 3.56, 69.}
);
auto axis = NDArrayFactory::create<int>({0, 1});
nd4j::ops::sufficient_statistics op;
auto ress = op.execute({&x, &axis}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
ASSERT_EQ(ND4J_STATUS_OK, ress->status());
ASSERT_EQ(ress->at(0)->e<double>(0), count);
ASSERT_TRUE(sumExp.equalsTo(ress->at(1)));
ASSERT_TRUE(sqrExp.equalsTo(ress->at(2)));
delete ress;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BinCount_1) {
auto x = NDArrayFactory::create<int>('c', {2, 2, 2}, {
1, 2, 0, 1, 2, 2, 1, 2}
);
// ------------------------------------
NDArray exp('c', {3}, {1, 3, 4}, nd4j::DataType::INT32);
nd4j::ops::bincount op;
auto res = op.execute({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, res->status());
ASSERT_TRUE(exp.equalsTo(res->at(0)));
delete res;
}
/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BinCount_2) {
auto x = NDArrayFactory::create<int>('c', {2, 2, 2}, {
1, 2, 0, 1, 2, 2, 1, 2}
);
auto weights = NDArrayFactory::create<double>('c', {2, 2, 2}, {
2, 1, 3, 1, 5, 1, 1, 6}
);
// ------------------------------------
auto exp = NDArrayFactory::create<double>({3., 4., 13.});
nd4j::ops::bincount op;
auto res = op.execute({&x, &weights}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, res->status());
ASSERT_TRUE(exp.equalsTo(res->at(0)));
delete res;
}
/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BinCount_3) {
auto x = NDArrayFactory::create<int>('c', {2, 2, 2}, {
1, 2, 0, 1, 2, 2, 1, 2}
);
auto weights = NDArrayFactory::create<double>('c', {2, 2, 2}, {
2, 1, 3, 1, 5, 1, 1, 6}
);
// ------------------------------------
auto exp = NDArrayFactory::create<double>({3., 4.});
nd4j::ops::bincount op;
auto res = op.execute({&x, &weights}, {}, {0, 2});
ASSERT_EQ(ND4J_STATUS_OK, res->status());
ASSERT_TRUE(exp.equalsTo(res->at(0)));
delete res;
}
/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BinCount_4) {
auto x = NDArrayFactory::create<int>('c', {2, 2, 2}, {
1, 2, 0, 1, 2, 2, 1, 2}
);
auto weights = NDArrayFactory::create<double>('c', {2, 2, 2}, {
2, 1, 3, 1, 5, 1, 1, 6}
);
// ------------------------------------
auto exp = NDArrayFactory::create<double>({3., 4., 13., 0.0});
nd4j::ops::bincount op;
auto res = op.execute({&x, &weights}, {}, {4, 4});
ASSERT_EQ(ND4J_STATUS_OK, res->status());
ASSERT_TRUE(exp.equalsTo(res->at(0)));
delete res;
}
/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BinCount_5) {
auto x = NDArrayFactory::create<int>('c', {2, 2, 2}, {
1, 2, 0, 1, 2, 2, 1, 2}
);
auto weights = NDArrayFactory::create<double>('c', {2, 2, 2}, {
2, 1, 3, 1, 5, 1, 1, 6}
);
auto minV = NDArrayFactory::create(4);
auto maxV = NDArrayFactory::create(4);
// ------------------------------------
auto exp = NDArrayFactory::create<double>({3., 4., 13., 0.0});
nd4j::ops::bincount op;
auto res = op.execute({&x, &weights, &minV, &maxV}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, res->status());
// res->at(0)->printBuffer("BC out");
ASSERT_TRUE(exp.equalsTo(res->at(0)));
delete res;
}
/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_1) {
auto x = NDArrayFactory::create<int>( {2, 2, 2} );
auto y = NDArrayFactory::create<int>({ 2, 1, 2});
// ------------------------------------
auto exp = NDArrayFactory::create<int>({2, 2, 2});
nd4j::ops::broadcast_dynamic_shape op;
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT32);
ASSERT_EQ(ND4J_STATUS_OK, res->status());
ASSERT_TRUE(exp.equalsTo(res->at(0)));
delete res;
}
/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_2) {
auto x = NDArrayFactory::create<Nd4jLong>( {2, 2} );
auto y = NDArrayFactory::create<Nd4jLong>({2, 1, 2});
// ------------------------------------
auto exp = NDArrayFactory::create<Nd4jLong>({2, 2, 2});
nd4j::ops::broadcast_dynamic_shape op;
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT64);
ASSERT_EQ(ND4J_STATUS_OK, res->status());
ASSERT_TRUE(exp.equalsTo(res->at(0)));
delete res;
}
/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_3) {
auto x = NDArrayFactory::create<Nd4jLong>( {2, 2, 2} );
auto y = NDArrayFactory::create<Nd4jLong>({ 2, 1});
// ------------------------------------
auto exp = NDArrayFactory::create<Nd4jLong>({2, 2, 2});
nd4j::ops::broadcast_dynamic_shape op;
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT64);
ASSERT_EQ(ND4J_STATUS_OK, res->status());
ASSERT_TRUE(exp.equalsTo(res->at(0)));
delete res;
}
/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_4) {
auto x = NDArrayFactory::create<Nd4jLong>( {2, 1} );
auto y = NDArrayFactory::create<Nd4jLong>('c', {1}, { 4,});
// ------------------------------------
auto exp = NDArrayFactory::create<Nd4jLong>({2, 4});
nd4j::ops::broadcast_dynamic_shape op;
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT64);
ASSERT_EQ(ND4J_STATUS_OK, res->status());
//res->at(0)->printBuffer("Shape SGO 4");
ASSERT_TRUE(exp.equalsTo(res->at(0)));
delete res;
}
/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_5) {
auto x = NDArrayFactory::create<Nd4jLong>({2, 2, 2});
auto y = NDArrayFactory::create<Nd4jLong>({2, 2});
// ------------------------------------
auto exp = NDArrayFactory::create<Nd4jLong>({2, 2, 2});
nd4j::ops::broadcast_dynamic_shape op;
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT64);
ASSERT_EQ(ND4J_STATUS_OK, res->status());
// res->at(0)->printIndexedBuffer("Output");
// exp.printIndexedBuffer("Expect");
ASSERT_TRUE(exp.equalsTo(res->at(0)));
delete res;
}
/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_5) {
auto x = NDArrayFactory::create<Nd4jLong>({2, 1, 2});
auto y = NDArrayFactory::create<Nd4jLong>({2, 2, 4});
// ------------------------------------
auto exp = NDArrayFactory::create<Nd4jLong>({2, 2, 4});
nd4j::ops::broadcast_dynamic_shape op;
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT64);
ASSERT_EQ(ND4J_STATUS_OK, res->status());
// res->at(0)->printIndexedBuffer("Output SGO 5");
// exp.printIndexedBuffer("Expect");
ASSERT_TRUE(exp.equalsTo(res->at(0)));
delete res;
}
/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_6) {
auto x = NDArrayFactory::create<Nd4jLong>({2, 1, 4});
auto y = NDArrayFactory::create<Nd4jLong>({2, 2, 4});
// ------------------------------------
auto exp = NDArrayFactory::create<Nd4jLong>({2, 2, 4});
nd4j::ops::broadcast_dynamic_shape op;
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT64);
ASSERT_EQ(ND4J_STATUS_OK, res->status());
// res->at(0)->printIndexedBuffer("Output SGO 6");
// exp.printIndexedBuffer("Expect");
ASSERT_TRUE(exp.equalsTo(res->at(0)));
delete res;
}
/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_7) {
auto x = NDArrayFactory::create<Nd4jLong>({1, 1, 3});
auto y = NDArrayFactory::create<Nd4jLong>({2, 4, 1});
// ------------------------------------
auto exp = NDArrayFactory::create<Nd4jLong>({2, 4, 3});
nd4j::ops::broadcast_dynamic_shape op;
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT64);
ASSERT_EQ(ND4J_STATUS_OK, res->status());
// res->at(0)->printIndexedBuffer("Output SGO 7");
// exp.printIndexedBuffer("Expect");
ASSERT_TRUE(exp.equalsTo(res->at(0)));
delete res;
}
/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_8) {
auto x = NDArrayFactory::create<int>('c', {1}, {1});
auto y = NDArrayFactory::create<int>('c', {1}, {4});
// ------------------------------------
auto exp = NDArrayFactory::create<int>('c', {1}, {4});
nd4j::ops::broadcast_dynamic_shape op;
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT32);
ASSERT_EQ(ND4J_STATUS_OK, res->status());
// res->at(0)->printIndexedBuffer("Output SGO 8");
// exp.printIndexedBuffer("Expect");
ASSERT_TRUE(exp.equalsTo(res->at(0)));
delete res;
}
/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_9) {
auto x = NDArrayFactory::create<int>('c', {2}, {2,2});
auto y = NDArrayFactory::create<int>('c', {1}, {1});
// ------------------------------------
auto exp = NDArrayFactory::create<int>('c', {2}, {2,2});
nd4j::ops::broadcast_dynamic_shape op;
auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT32);
ASSERT_EQ(ND4J_STATUS_OK, res->status());
res->at(0)->printIndexedBuffer("Output SGO 9");
exp.printIndexedBuffer("Expect9");
ASSERT_TRUE(exp.equalsTo(res->at(0)));
delete res;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_1) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0,
-3.0, 0.0, 0.0, 4.0, 0.0, 0.0,
-3.0, 0.0, 0.0, 4.0, 0.0, 0.0}
);
auto exp = NDArrayFactory::create<double>('c', {2, 3, 3}, {
-0.2771281, 0., 0.,
0.36950415, 0., 0.,
-0.2771281, 0., 0.,
0.36950415, 0., 0.,
-0.2771281, 0., 0.,
0.36950415, 0., 0.}
);
// 8.660254
// auto expNorm(8.660254);
nd4j::ops::clip_by_global_norm op;
auto result = op.execute({&x}, {0.8}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
auto norm = result->at(1);
//z->printIndexedBuffer("Output");
//exp.printIndexedBuffer("Expected");
//norm->printIndexedBuffer("Norm");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
// ASSERT_TRUE(expNorm.equalsTo(norm));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_2) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0,
-3.0, 0.0, 0.0, 4.0, 0.0, 0.0,
-3.0, 0.0, 0.0, 4.0, 0.0, 0.0}
);
auto a = NDArrayFactory::create<double>('c', {2, 3, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0,
-3.0, 0.0, 0.0, 4.0, 0.0, 0.0,
-3.0, 0.0, 0.0, 4.0, 0.0, 0.0}
);
auto exp = NDArrayFactory::create<double>('c', {2, 3, 3}, {
-0.44090813, 0., 0.,
0.5878775, 0., 0.,
-0.44090813, 0., 0.,
0.5878775, 0., 0.,
-0.44090813, 0., 0.,
0.5878775, 0., 0.}
//12.247449
);
nd4j::ops::clip_by_global_norm op;
auto result = op.execute({&x, &a}, {1.8}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
auto y = result->at(1);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.isSameShape(y));
ASSERT_TRUE(exp.equalsTo(z));
ASSERT_TRUE(exp.equalsTo(y));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_3) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, 0.0, 4.0, 0.0, 0.0});
auto a = NDArrayFactory::create<double>('c', {2, 3, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, 0.0, 4.0, 0.0, 0.0});
auto exp = NDArrayFactory::create<double>('c', {2, 3, 3}, {
-0.19595918, 0., 0.,
0.2612789, 0., 0.,
-0.19595918, 0., 0.,
0.2612789, 0., 0.,
-0.19595918, 0., 0.,
0.2612789, 0., 0.}
);
nd4j::ops::clip_by_global_norm op;
auto result = op.execute({&x, &a}, {0.8}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
auto y = result->at(1);
//z->printIndexedBuffer("Output 1");
//y->printIndexedBuffer("Output 2");
//result->at(2)->printIndexedBuffer("Global norm is");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.isSameShape(y));
ASSERT_TRUE(result->at(2)->isScalar());
ASSERT_TRUE(exp.equalsTo(z));
ASSERT_TRUE(exp.equalsTo(y));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixDeterminant_1) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 3}, {-3.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, -3.0, 4.0, 0.0, 0.0, 0.0, -3.0, 0.0, 0.0, 0.0, 4.0});
auto exp = NDArrayFactory::create<double>({36.0, -48.0});
nd4j::ops::matrix_determinant op;
auto result = op.execute({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
//z->printIndexedBuffer("Output ");
//exp.printIndexedBuffer("Expected ");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixDeterminant_2) {
auto x = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0});
auto exp = NDArrayFactory::create<double>({-2.0, -2.0});
nd4j::ops::matrix_determinant op;
auto result = op.execute({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
//z->printIndexedBuffer("Output ");
//exp.printIndexedBuffer("Expected ");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixDeterminant_3) {
auto x = NDArrayFactory::create<double>('c', {1, 3, 3}, {3.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 3.0});
NDArray exp('c', {1}, {-54.0});
nd4j::ops::matrix_determinant op;
auto result = op.execute({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
//z->printIndexedBuffer("Output ");
//exp.printIndexedBuffer("Expected ");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixDeterminant_4) {
auto x = NDArrayFactory::create<double>('c', {1, 3, 3}, {12.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 13.0});
auto exp = NDArrayFactory::create<double>('c', {1}, {189.0});
nd4j::ops::matrix_determinant op;
auto result = op.execute({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
// z->printIndexedBuffer("Output ");
// exp.printIndexedBuffer("Expected ");
// z->printShapeInfo("Output shape");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixDeterminant_5) {
auto x = NDArrayFactory::create<double>('c', {1, 4, 4});
NDArray exp('c', {1}, {-16.0});
x.linspace(1);
x.p(5, 4.0);
x.p(12, 12.0);
nd4j::ops::matrix_determinant op;
auto result = op.execute({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
//z->printIndexedBuffer("Output ");
//exp.printIndexedBuffer("Expected ");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixDeterminant_6) {
auto x = NDArrayFactory::create<double>('c', {4, 4});
auto exp = NDArrayFactory::create<double>(-16.0);
x.linspace(1);
x.p(5, 4.0);
x.p(12, 12.0);
nd4j::ops::matrix_determinant op;
auto result = op.execute({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
//z->printIndexedBuffer("Output ");
//z->printShapeInfo("Shape");
//exp.printIndexedBuffer("Expected ");
ASSERT_TRUE(z->isScalar());
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, LogMatrixDeterminant_1) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 3}, {-3.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, -3.0, 4.0, 0.0, 0.0, 0.0, -3.0, 0.0, 0.0, 0.0, 4.0});
auto exp = NDArrayFactory::create<double>({3.58351893845611, 3.871201010907891});
nd4j::ops::log_matrix_determinant op;
auto result = op.execute({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
//z->printIndexedBuffer("Output ");
//exp.printIndexedBuffer("Expected ");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, LogDet_1) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 3}, {4,12,-16,12,37,-43,-16,-43,98, 4,1.2,-1.6,1.2,3.7,-4.3,-1.6,-4.3,9.8});
auto exp = NDArrayFactory::create<double>({ 3.5835189, 4.159008});
//x.printIndexedBuffer("Input");
nd4j::ops::logdet op;
auto result = op.execute({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
z->printIndexedBuffer("LogDet Output1 ");
exp.printIndexedBuffer("LogDet Expected1 ");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, LogDet_2) {
auto x = NDArrayFactory::create<double>('c', {1, 3, 3}, {4,12,-16,12,37,-43,-16,-43,98});
auto exp = NDArrayFactory::create<double>('c', {1}, { 3.5835189});
//x.printIndexedBuffer("Input");
nd4j::ops::logdet op;
auto result = op.execute({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
z->printIndexedBuffer("LogDet Output2 ");
// z->printShapeInfo("Shape");
exp.printIndexedBuffer("LogDet Expected2 ");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, LogDet_3) {
auto x = NDArrayFactory::create<double>('c', {3, 3}, {4,12,-16,12,37,-43,-16,-43,98});
auto exp = NDArrayFactory::create<double>( 3.5835189);
//x.printIndexedBuffer("Input");
nd4j::ops::logdet op;
auto result = op.execute({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
z->printIndexedBuffer("LogDet Output3 ");
// z->printShapeInfo("Shape");
exp.printIndexedBuffer("LogDet Expected3 ");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixInverse_1) {
auto x = NDArrayFactory::create<float>('c', {2, 5, 5}, {
2., 4., 60., 8., 10.,
0., 1., 2., 3., 4.,
0., 0., 2., 4., 6.,
0., 0., 0., 1., 2.,
0., 0., 0., 0., 4.,
1., 0., 0., 0., 0.,
2., 1., 0., 0., 0.,
30., 2., 1., 0., 0.,
4., 3., 2., 1., 0.,
5., 4., 3., 2., 1.,
});
auto exp = NDArrayFactory::create<float>('c', {2, 5, 5}, {
0.5, -2.0, -13.0, 54.0, -6.75,
0.0, 1.0, -1.0, 1.0, 0.0,
0, 0, 0.5, -2.0, 0.25,
0, 0, 0, 1.0, -0.5,
0, 0, 0, 0, 0.25,
1.0, 0.0, 0.0, 0.0, 0.,
-2.0, 1.0, 0., 0., 0.,
-26.0, -2.0, 1, 0, 0.,
54.0, 1.0, -2.0, 1, 0.,
-27.0, 0.0, 1.0, -2.0, 1.
});
nd4j::ops::matrix_inverse op;
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
z->printIndexedBuffer("Output ");
exp.printIndexedBuffer("Expected ");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixInverse_01) {
auto x = NDArrayFactory::create<float>('c', {1, 5, 5}, {
2., 4., 60., 8., 10.,
0., 1., 2., 3., 4.,
0., 0., 2., 4., 6.,
0., 0., 0., 1., 2.,
0., 0., 0., 0., 4.
});
auto exp = NDArrayFactory::create<float>('c', {1, 5, 5}, {
0.5, -2.0, -13.0, 54.0, -6.75,
0.0, 1.0, -1.0, 1.0, 0.0,
0, 0, 0.5, -2.0, 0.25,
0, 0, 0, 1.0, -0.5,
0, 0, 0, 0, 0.25
});
nd4j::ops::matrix_inverse op;
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
z->printIndexedBuffer("Output ");
exp.printIndexedBuffer("Expected ");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixInverse_02) {
auto x = NDArrayFactory::create<float>('c', {1, 5, 5}, {
1., 0., 0., 0., 0.,
2., 1., 0., 0., 0.,
30., 2., 1., 0., 0.,
4., 3., 2., 1., 0.,
5., 4., 3., 2., 1.
});
auto exp = NDArrayFactory::create<float>('c', {1, 5, 5}, {
1.0, 0.0, 0.0, 0.0, 0.,
-2.0, 1.0, 0., 0., 0.,
-26.0, -2.0, 1, 0, 0.,
54.0, 1.0, -2.0, 1, 0.,
-27.0, 0.0, 1.0, -2.0, 1.
});
nd4j::ops::matrix_inverse op;
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
z->printIndexedBuffer("Output ");
exp.printIndexedBuffer("Expected ");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
/*
TEST_F(DeclarableOpsTests6, MatrixInverse_2) {
auto x = NDArrayFactory::create<double>('c', {2, 5, 5}, {
1., 2., 30., 4., 5.,
0., 1., 2., 3., 4.,
0., 0., 1., 2., 3.,
0., 0., 0., 1., 2.,
0., 0., 0., 0., 1.,
4., 0., 0., 0., 0.,
4., 2., 0., 0., 0.,
30., 2., 1., 0., 0.,
8., 6., 4., 2., 0.,
15., 12., 9., 6., 3.,
});
auto exp = NDArrayFactory::create<double>('c', {2, 5, 5}, {
1.0, -2.0, -26.0, 54.0, -27.0,
0.0, 1.0, -2.0, 1.0, 0.0,
0.0, 0.0, 1.0, -2.0, 1.0,
0.0, 0.0, 0.0, 1.0, -2.0,
0.0, 0.0, 0.0, 0.0, 1.0,
0.25, 0.0, 0.0, 0.0, 0.0,
-0.50, 0.5, 0.0, 0.0, 0.0,
-6.50, -1.0, 1.0, 0.0, 0.0,
13.50, 0.5, -2.0, 0.5, 0.0,
-6.75, 0.0, 1.0, -1.0, 0.33333333
});
nd4j::ops::matrix_inverse op;
auto result = op.execute({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
z->printIndexedBuffer("Output ");
exp.printIndexedBuffer("Expected ");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
*/
TEST_F(DeclarableOpsTests6, MatrixInverse_03) {
auto x = NDArrayFactory::create<float>('c', {5, 5}, {
4., 0., 0., 0., 0.,
4., 2., 0., 0., 0.,
30., 2., 1., 0., 0.,
8., 6., 4., 2., 0.,
15., 12., 9., 6., 3.,
});
auto exp = NDArrayFactory::create<float>('c', {5, 5}, {
0.25, 0.0, 0.0, 0.0, 0.0,
-0.50, 0.5, 0.0, 0.0, 0.0,
-6.50, -1.0, 1.0, 0.0, 0.0,
13.50, 0.5, -2.0, 0.5, 0.0,
-6.75, 0.0, 1.0, -1.0, 0.33333333
});
nd4j::ops::matrix_inverse op;
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
z->printIndexedBuffer("Output ");
exp.printIndexedBuffer("Expected ");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixInverse_3) {
auto x = NDArrayFactory::create<double>('c', {5, 5}, {
4., 0., 0., 0., 0.,
4., 2., 0., 0., 0.,
30., 2., 1., 0., 0.,
8., 6., 4., 2., 0.,
15., 12., 9., 6., 3.,
});
auto exp = NDArrayFactory::create<double>('c', {5, 5}, {
0.25, 0.0, 0.0, 0.0, 0.0,
-0.50, 0.5, 0.0, 0.0, 0.0,
-6.50, -1.0, 1.0, 0.0, 0.0,
13.50, 0.5, -2.0, 0.5, 0.0,
-6.75, 0.0, 1.0, -1.0, 0.33333333
});
nd4j::ops::matrix_inverse op;
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
//z->printIndexedBuffer("Output ");
//exp.printIndexedBuffer("Expected ");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixInverse_4) {
auto x = NDArrayFactory::create<float>('c', {5, 5}, {
1., 2., 30., 4., 5.,
0., 1., 2., 3., 4.,
0., 0., 1., 2., 3.,
0., 0., 0., 1., 2.,
0., 0., 0., 0., 1.
});
auto exp = NDArrayFactory::create<float>('c', {5, 5}, {
1.0, -2.0, -26.0, 54.0, -27.0,
0.0, 1.0, -2.0, 1.0, 0.0,
0.0, 0.0, 1.0, -2.0, 1.0,
0.0, 0.0, 0.0, 1.0, -2.0,
0.0, 0.0, 0.0, 0.0, 1.0
});
nd4j::ops::matrix_inverse op;
auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::FLOAT32);
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
z->printIndexedBuffer("Output ");
exp.printIndexedBuffer("Expected ");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, ReluLayer_1) {
auto x = NDArrayFactory::create<double>('c', {3, 4}, {1.0, -2.0, 3.0, 4.0, 5.0, -6.0, 7.0, 8.0, 9.0, -10.0, 11.0, 12});
auto w = NDArrayFactory::create<double>('c', {4, 3}, {0.5, 0.1, 0.8, 0.5, 0.2, 0.5, 0.5, 0.25, 0.5, 0.1, 0.0, 0.25});
auto b = NDArrayFactory::create<double>({20.0, 30.0, 50.0});
auto exp = NDArrayFactory::create<double>('c', {3, 3}, {
21.4, 30.45, 52.3,
23.8, 31.05, 56.5,
26.2, 31.65, 60.7});
nd4j::ops::relu_layer op;
auto result = op.execute({&x, &w, &b}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
// z->printShapeInfo("Output shape");
// z->printIndexedBuffer("Output ");
// exp.printIndexedBuffer("Expected ");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
TEST_F(DeclarableOpsTests6, Test_Reduce3_Edge) {
auto x = NDArrayFactory::create<double>('c', {3, 4, 5});
auto y = NDArrayFactory::create<double>('c', {3, 4, 5});
std::vector<int> dims = {0, 1};
auto z = x.applyReduce3(reduce3::CosineSimilarity, &y, dims, nullptr);
ASSERT_TRUE(z != nullptr);
delete z;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, static_rnn_test1) {
const int bS = 2;
const int inSize = 3;
const int numUnits = 4;
const int time = 5;
auto x = NDArrayFactory::create<double>('c', {time, bS, inSize});
auto Wx = NDArrayFactory::create<double>('c', {inSize, numUnits});
auto Wh = NDArrayFactory::create<double>('c', {numUnits, numUnits});
auto b = NDArrayFactory::create<double>('c', {2*numUnits});
auto h0 = NDArrayFactory::create<double>('c', {bS, numUnits});
auto maxTimeStep = NDArrayFactory::create<double>('c', {bS}, {time-1, time-3});
x.linspace(0.01, 0.01);
h0 = 0.2;
Wx = 0.3;
Wh = 0.4;
b = 0.25;
auto expH = NDArrayFactory::create<double>('c', {time, bS, numUnits}, {0.68474828, 0.68474828, 0.68474828, 0.68474828,0.69882484, 0.69882484, 0.69882484, 0.69882484, 0.9312333 , 0.9312333 , 0.9312333 , 0.9312333 ,
0.93751527, 0.93751527, 0.93751527, 0.93751527,0.97136768, 0.97136768, 0.97136768, 0.97136768,0., 0., 0., 0. ,
0.97732812, 0.97732812, 0.97732812, 0.97732812,0., 0., 0., 0. ,0., 0., 0., 0.,0., 0., 0., 0.});
auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97732812, 0.97732812, 0.97732812, 0.97732812, 0.93751527, 0.93751527, 0.93751527, 0.93751527});
nd4j::ops::static_rnn op;
auto results = op.execute({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto h = results->at(0);
auto hFinal = results->at(1);
ASSERT_TRUE(expH.isSameShape(h));
ASSERT_TRUE(expH.equalsTo(h));
ASSERT_TRUE(expHFinal.isSameShape(hFinal));
ASSERT_TRUE(expHFinal.equalsTo(hFinal));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, static_rnn_test2) {
const int bS = 2;
const int inSize = 3;
const int numUnits = 4;
const int time = 5;
auto x = NDArrayFactory::create<double>('c', {time, bS, inSize});
auto Wx = NDArrayFactory::create<double>('c', {inSize, numUnits});
auto Wh = NDArrayFactory::create<double>('c', {numUnits, numUnits});
auto b = NDArrayFactory::create<double>('c', {2*numUnits});
auto h0 = NDArrayFactory::create<double>('c', {bS, numUnits});
x.linspace(0.01, 0.01);
h0 = 0.2;
Wx = 0.3;
Wh = 0.4;
b = 0.25;
auto expH = NDArrayFactory::create<double>('c', {time, bS, numUnits}, {0.68474828, 0.68474828, 0.68474828, 0.68474828,0.69882484, 0.69882484, 0.69882484, 0.69882484,0.9312333 , 0.9312333 , 0.9312333 , 0.9312333,
0.93751527, 0.93751527, 0.93751527, 0.93751527,0.97136768, 0.97136768, 0.97136768, 0.97136768,0.97338548, 0.97338548, 0.97338548, 0.97338548,
0.97732812, 0.97732812, 0.97732812, 0.97732812,0.97864398, 0.97864398, 0.97864398, 0.97864398,0.98000654, 0.98000654, 0.98000654, 0.98000654,
0.98112648, 0.98112648, 0.98112648, 0.98112648});
auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.98000654, 0.98000654, 0.98000654, 0.98000654,0.98112648, 0.98112648, 0.98112648, 0.98112648});
nd4j::ops::static_rnn op;
auto results = op.execute({&x, &Wx, &Wh, &b, &h0}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto h = results->at(0);
auto hFinal = results->at(1);
ASSERT_TRUE(expH.isSameShape(h));
ASSERT_TRUE(expH.equalsTo(h));
ASSERT_TRUE(expHFinal.isSameShape(hFinal));
ASSERT_TRUE(expHFinal.equalsTo(hFinal));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, static_rnn_test3) {
const int bS = 2;
const int inSize = 3;
const int numUnits = 4;
const int time = 5;
auto x = NDArrayFactory::create<double>('c', {time, bS, inSize});
auto Wx = NDArrayFactory::create<double>('c', {inSize, numUnits});
auto Wh = NDArrayFactory::create<double>('c', {numUnits, numUnits});
auto b = NDArrayFactory::create<double>('c', {2*numUnits});
auto h0 = NDArrayFactory::create<double>('c', {bS, numUnits});
auto maxTimeStep = NDArrayFactory::create<double>('c', {bS}, {time-1, 0});
x.linspace(0.01, 0.01);
h0 = 0.2;
Wx = 0.3;
Wh = 0.4;
b = 0.25;
auto expH = NDArrayFactory::create<double>('c', {time, bS, numUnits}, {0.68474828, 0.68474828, 0.68474828, 0.68474828,0., 0., 0., 0., 0.9312333, 0.9312333, 0.9312333, 0.9312333,
0., 0., 0., 0. , 0.97136768, 0.97136768, 0.97136768, 0.97136768,0., 0., 0., 0. ,
0.97732812, 0.97732812, 0.97732812, 0.97732812,0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0.});
auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97732812, 0.97732812, 0.97732812, 0.97732812, 0.2 , 0.2 , 0.2 , 0.2});
nd4j::ops::static_rnn op;
auto results = op.execute({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto h = results->at(0);
auto hFinal = results->at(1);
ASSERT_TRUE(expH.isSameShape(h));
ASSERT_TRUE(expH.equalsTo(h));
ASSERT_TRUE(expHFinal.isSameShape(hFinal));
ASSERT_TRUE(expHFinal.equalsTo(hFinal));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, static_rnn_test4) {
const int bS = 2;
const int inSize = 3;
const int numUnits = 4;
const int time = 5;
auto x = NDArrayFactory::create<double>('c', {time, bS, inSize});
auto Wx = NDArrayFactory::create<double>('c', {inSize, numUnits});
auto Wh = NDArrayFactory::create<double>('c', {numUnits, numUnits});
auto b = NDArrayFactory::create<double>('c', {2*numUnits});
auto h0 = NDArrayFactory::create<double>('c', {bS, numUnits});
auto maxTimeStep = NDArrayFactory::create<double>('c', {bS}, {time-1, time-3});
x.linspace(0.01, 0.01);
Wx = 0.3;
Wh = 0.4;
b = 0.25;
auto expH = NDArrayFactory::create<double>('c', {time, bS, numUnits}, {0.47615493, 0.47615493, 0.47615493, 0.47615493,0.49676344, 0.49676344, 0.49676344, 0.49676344, 0.87018664, 0.87018664, 0.87018664, 0.87018664,
0.88400882, 0.88400882, 0.88400882, 0.88400882, 0.96529784, 0.96529784, 0.96529784, 0.96529784,0., 0., 0., 0. ,
0.97688859, 0.97688859, 0.97688859, 0.97688859,0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0.});
auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97688859, 0.97688859, 0.97688859, 0.97688859, 0.88400882, 0.88400882, 0.88400882, 0.88400882});
nd4j::ops::static_rnn op;
auto results = op.execute({&x, &Wx, &Wh, &b, &maxTimeStep}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto h = results->at(0);
auto hFinal = results->at(1);
ASSERT_TRUE(expH.isSameShape(h));
ASSERT_TRUE(expH.equalsTo(h));
ASSERT_TRUE(expHFinal.isSameShape(hFinal));
ASSERT_TRUE(expHFinal.equalsTo(hFinal));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, static_rnn_test5) {
const int bS = 2;
const int inSize = 3;
const int numUnits = 4;
const int time = 5;
auto x = NDArrayFactory::create<double>('c', {time, bS, inSize});
auto Wx = NDArrayFactory::create<double>('c', {inSize, numUnits});
auto Wh = NDArrayFactory::create<double>('c', {numUnits, numUnits});
auto b = NDArrayFactory::create<double>('c', {2*numUnits});
auto h0 = NDArrayFactory::create<double>('c', {bS, numUnits});
x.linspace(0.01, 0.01);
Wx = 0.3;
Wh = 0.4;
b = 0.25;
auto expH = NDArrayFactory::create<double>('c', {time, bS, numUnits}, {0.47615493, 0.47615493, 0.47615493, 0.47615493,0.49676344, 0.49676344, 0.49676344, 0.49676344, 0.87018664, 0.87018664, 0.87018664, 0.87018664,
0.88400882, 0.88400882, 0.88400882, 0.88400882, 0.96529784, 0.96529784, 0.96529784, 0.96529784,0.96849345, 0.96849345, 0.96849345, 0.96849345,
0.97688859, 0.97688859, 0.97688859, 0.97688859,0.97831069, 0.97831069, 0.97831069, 0.97831069, 0.97997868, 0.97997868, 0.97997868, 0.97997868,
0.98110653, 0.98110653, 0.98110653, 0.98110653});
auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97997868, 0.97997868, 0.97997868, 0.97997868, 0.98110653, 0.98110653, 0.98110653, 0.98110653});
nd4j::ops::static_rnn op;
auto results = op.execute({&x, &Wx, &Wh, &b}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto h = results->at(0);
auto hFinal = results->at(1);
ASSERT_TRUE(expH.isSameShape(h));
ASSERT_TRUE(expH.equalsTo(h));
ASSERT_TRUE(expHFinal.isSameShape(hFinal));
ASSERT_TRUE(expHFinal.equalsTo(hFinal));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, static_bidir_rnn_test1) {
const int bS = 4;
const int inSize = 4;
const int numUnitsFW = 3;
const int numUnitsBW = 3;
const int time = 5;
auto x = NDArrayFactory::create<double>('c', {time, bS, inSize});
auto WxFW = NDArrayFactory::create<double>('c', {inSize, numUnitsFW});
auto WhFW = NDArrayFactory::create<double>('c', {numUnitsFW, numUnitsFW});
auto bFW = NDArrayFactory::create<double>('c', {2*numUnitsFW});
auto h0FW = NDArrayFactory::create<double>('c', {bS, numUnitsFW});
auto h0BW = NDArrayFactory::create<double>('c', {bS, numUnitsBW});
auto maxTimeStep = NDArrayFactory::create<double>('c', {bS}, {time-1, time-3, time-4, 0});
x.linspace(0.01, 0.01);
h0FW = 0.2;
h0BW = 0.25;
WxFW = 0.3;
WhFW = 0.4;
bFW = 0.1;
auto expH = NDArrayFactory::create<double>('c', {time, bS, numUnitsFW+numUnitsBW}, {0.43819931, 0.43819931, 0.43819931, 0.86708881, 0.86708881,0.86708881,0.47615493, 0.47615493, 0.47615493, 0.78347842, 0.78347842,0.78347842,
0.51241561, 0.51241561, 0.51241561, 0.55529176, 0.55529176,0.55529176,0., 0., 0., 0., 0.,0.,0.73880324, 0.73880324, 0.73880324, 0.90935605, 0.90935605,
0.90935605, 0.77843476, 0.77843476, 0.77843476, 0.64692945, 0.64692945,0.64692945,0., 0., 0., 0., 0.,0.,0., 0., 0., 0., 0.,0.,
0.9052501, 0.9052501, 0.9052501, 0.9181592, 0.9181592, 0.9181592,0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.,
0.9555734, 0.9555734, 0.9555734, 0.8026439, 0.8026439, 0.8026439,0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.});
auto expHFWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsFW}, {0.9555734 , 0.9555734 , 0.9555734 , 0.77843476, 0.77843476, 0.77843476, 0.51241561, 0.51241561, 0.51241561, 0.2, 0.2, 0.2});
auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.86708881, 0.86708881, 0.86708881, 0.78347842, 0.78347842, 0.78347842, 0.55529176, 0.55529176, 0.55529176, 0.25, 0.25, 0.25});
nd4j::ops::static_bidirectional_rnn op;
auto results = op.execute({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW, &maxTimeStep}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto h = results->at(0);
auto hFWfinal = results->at(1);
auto hBWfinal = results->at(2);
ASSERT_TRUE(expH.isSameShape(h));
ASSERT_TRUE(expH.equalsTo(h));
ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal));
ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal));
ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal));
ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, static_bidir_rnn_test2) {
const int bS = 4;
const int inSize = 4;
const int numUnitsFW = 3;
const int numUnitsBW = 3;
const int time = 5;
auto x = NDArrayFactory::create<double>('c', {time, bS, inSize});
auto WxFW = NDArrayFactory::create<double>('c', {inSize, numUnitsFW});
auto WhFW = NDArrayFactory::create<double>('c', {numUnitsFW, numUnitsFW});
auto bFW = NDArrayFactory::create<double>('c', {2*numUnitsFW});
auto maxTimeStep = NDArrayFactory::create<double>('c', {bS}, {time-1, time-3, time-4, 0});
x.linspace(0.01, 0.01);
WxFW = 0.3;
WhFW = 0.4;
bFW = 0.1;
auto expH = NDArrayFactory::create<double>('c', {time, bS, numUnitsFW+numUnitsBW}, {0.22602835, 0.22602835, 0.22602835, 0.86518273, 0.86518273,0.86518273,0.27105303, 0.27105303, 0.27105303, 0.66617761, 0.66617761,0.66617761,
0.31492203, 0.31492203, 0.31492203, 0.31492203, 0.31492203,0.31492203,0. , 0. , 0. , 0. , 0. ,0. ,
0.60005558, 0.60005558, 0.60005558, 0.9029975 , 0.9029975 ,0.9029975 ,0.66138054, 0.66138054, 0.66138054, 0.43819931, 0.43819931,0.43819931,
0. , 0. , 0. , 0. , 0. ,0. ,0. , 0. , 0. , 0. , 0. ,0. ,
0.87023975, 0.87023975, 0.87023975, 0.88852032, 0.88852032,0.88852032,0. , 0. , 0. , 0. , 0. ,0. ,
0. , 0. , 0. , 0. , 0. ,0. ,0. , 0. , 0. , 0. , 0. ,0. ,
0.95177305, 0.95177305, 0.95177305, 0.66737775, 0.66737775,0.66737775,0. , 0. , 0. , 0. , 0. ,0. ,
0. , 0. , 0. , 0. , 0. ,0. ,0. , 0. , 0. , 0. , 0. ,0. ,
0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.});
auto expHFWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsFW}, {0.95177305, 0.95177305, 0.95177305, 0.66138054, 0.66138054, 0.66138054, 0.31492203, 0.31492203, 0.31492203, 0. , 0. , 0.});
auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.86518273, 0.86518273, 0.86518273, 0.66617761, 0.66617761, 0.66617761, 0.31492203, 0.31492203, 0.31492203, 0. , 0. , 0.});
nd4j::ops::static_bidirectional_rnn op;
auto results = op.execute({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &maxTimeStep}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto h = results->at(0);
auto hFWfinal = results->at(1);
auto hBWfinal = results->at(2);
ASSERT_TRUE(expH.isSameShape(h));
ASSERT_TRUE(expH.equalsTo(h));
ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal));
ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal));
ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal));
ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, static_bidir_rnn_test3) {
const int bS = 4;
const int inSize = 4;
const int numUnitsFW = 3;
const int numUnitsBW = 3;
const int time = 5;
auto x = NDArrayFactory::create<double>('c', {time, bS, inSize});
auto WxFW = NDArrayFactory::create<double>('c', {inSize, numUnitsFW});
auto WhFW = NDArrayFactory::create<double>('c', {numUnitsFW, numUnitsFW});
auto bFW = NDArrayFactory::create<double>('c', {2*numUnitsFW});
x.linspace(0.01, 0.01);
WxFW = 0.3;
WhFW = 0.4;
bFW = 0.1;
auto expH = NDArrayFactory::create<double>('c', {time, bS, numUnitsFW+numUnitsBW}, {0.22602835, 0.22602835, 0.22602835, 0.86841012, 0.86841012,0.86841012,0.27105303, 0.27105303, 0.27105303, 0.88207531, 0.88207531,0.88207531,
0.31492203, 0.31492203, 0.31492203, 0.8941667 , 0.8941667 ,0.8941667 ,0.35748551, 0.35748551, 0.35748551, 0.90489713, 0.90489713,
0.90489713, 0.60005558, 0.60005558, 0.60005558, 0.91381375, 0.91381375,0.91381375,0.66138054, 0.66138054, 0.66138054, 0.92253504, 0.92253504,
0.92253504,0.71429879, 0.71429879, 0.71429879, 0.93027876, 0.93027876,0.93027876,0.75947891, 0.75947891, 0.75947891, 0.9371767 , 0.9371767 ,
0.9371767 , 0.87023975, 0.87023975, 0.87023975, 0.94014274, 0.94014274,0.94014274,0.89680574, 0.89680574, 0.89680574, 0.94648926, 0.94648926,
0.94648926,0.91657261, 0.91657261, 0.91657261, 0.95204779, 0.95204779,0.95204779,0.93146896, 0.93146896, 0.93146896, 0.95694206, 0.95694206,
0.95694206, 0.95177305, 0.95177305, 0.95177305, 0.93773086, 0.93773086,0.93773086,0.95874689, 0.95874689, 0.95874689, 0.94579176, 0.94579176,
0.94579176,0.96416067, 0.96416067, 0.96416067, 0.95267886, 0.95267886,0.95267886,0.96851506, 0.96851506, 0.96851506, 0.95857985, 0.95857985,
0.95857985, 0.97269956, 0.97269956, 0.97269956, 0.76075293, 0.76075293,0.76075293,0.97557464, 0.97557464, 0.97557464, 0.78024637, 0.78024637,
0.78024637,0.97806922, 0.97806922, 0.97806922, 0.79833344, 0.79833344,0.79833344,0.98026195, 0.98026195, 0.98026195, 0.81508646, 0.81508646,0.81508646});
auto expHFWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsFW}, {0.97269956, 0.97269956, 0.97269956, 0.97557464, 0.97557464, 0.97557464, 0.97806922, 0.97806922, 0.97806922, 0.98026195, 0.98026195, 0.98026195});
auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.86841012, 0.86841012, 0.86841012, 0.88207531, 0.88207531, 0.88207531, 0.8941667 , 0.8941667 , 0.8941667 , 0.90489713, 0.90489713, 0.90489713});
nd4j::ops::static_bidirectional_rnn op;
auto results = op.execute({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto h = results->at(0);
auto hFWfinal = results->at(1);
auto hBWfinal = results->at(2);
ASSERT_TRUE(expH.isSameShape(h));
ASSERT_TRUE(expH.equalsTo(h));
ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal));
ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal));
ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal));
ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, dynamic_rnn_test1) {
const int bS = 2;
const int inSize = 3;
const int numUnits = 4;
const int time = 5;
auto x = NDArrayFactory::create<double>('c', {time, bS, inSize});
auto Wx = NDArrayFactory::create<double>('c', {inSize, numUnits});
auto Wh = NDArrayFactory::create<double>('c', {numUnits, numUnits});
auto b = NDArrayFactory::create<double>('c', {2*numUnits});
auto h0 = NDArrayFactory::create<double>('c', {bS, numUnits});
auto maxTimeStep = NDArrayFactory::create<Nd4jLong>('c', {bS}, {time-1, time-3});
x.linspace(0.01, 0.01);
h0 = 0.2;
Wx = 0.3;
Wh = 0.4;
b = 0.25;
auto expH = NDArrayFactory::create<double>('c', {time, bS, numUnits}, {0.68474828, 0.68474828, 0.68474828, 0.68474828,0.69882484, 0.69882484, 0.69882484, 0.69882484,0.9312333 , 0.9312333 , 0.9312333 , 0.9312333 ,
0.93751527, 0.93751527, 0.93751527, 0.93751527,0.97136768, 0.97136768, 0.97136768, 0.97136768,0. , 0. , 0. , 0. ,
0.97732812, 0.97732812, 0.97732812, 0.97732812,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. });
auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97732812, 0.97732812, 0.97732812, 0.97732812, 0.93751527, 0.93751527, 0.93751527, 0.93751527});
nd4j::ops::dynamic_rnn op;
auto results = op.execute({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto h = results->at(0);
auto hFinal = results->at(1);
ASSERT_TRUE(expH.isSameShape(h));
ASSERT_TRUE(expH.equalsTo(h));
ASSERT_TRUE(expHFinal.isSameShape(hFinal));
ASSERT_TRUE(expHFinal.equalsTo(hFinal));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, dynamic_rnn_test2) {
const int bS = 2;
const int inSize = 3;
const int numUnits = 4;
const int time = 5;
auto x = NDArrayFactory::create<double>('c', {bS, time, inSize});
auto Wx = NDArrayFactory::create<double>('c', {inSize, numUnits});
auto Wh = NDArrayFactory::create<double>('c', {numUnits, numUnits});
auto b = NDArrayFactory::create<double>('c', {2*numUnits});
auto h0 = NDArrayFactory::create<double>('c', {bS, numUnits});
auto maxTimeStep = NDArrayFactory::create<int>('c', {bS}, {time-1, time});
x.linspace(0.01, 0.01);
h0 = 0.2;
Wx = 0.3;
Wh = 0.4;
b = 0.25;
auto expH = NDArrayFactory::create<double>('c', {bS, time, numUnits}, {0.68474828, 0.68474828, 0.68474828, 0.68474828,0.92755601, 0.92755601, 0.92755601, 0.92755601,0.96778334, 0.96778334, 0.96778334,
0.96778334,0.97309129, 0.97309129, 0.97309129, 0.97309129,0. , 0. , 0. , 0. ,
0.75001965, 0.75001965, 0.75001965, 0.75001965,0.95449491, 0.95449491, 0.95449491, 0.95449491,0.97732828, 0.97732828, 0.97732828,
0.97732828,0.98000655, 0.98000655, 0.98000655, 0.98000655,0.98120782, 0.98120782, 0.98120782, 0.98120782});
auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97309129, 0.97309129, 0.97309129, 0.97309129, 0.98120782, 0.98120782, 0.98120782, 0.98120782});
nd4j::ops::dynamic_rnn op;
auto results = op.execute({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto h = results->at(0);
auto hFinal = results->at(1);
ASSERT_TRUE(expH.isSameShape(h));
ASSERT_TRUE(expH.equalsTo(h));
ASSERT_TRUE(expHFinal.isSameShape(hFinal));
ASSERT_TRUE(expHFinal.equalsTo(hFinal));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, dynamic_rnn_test3) {
const int bS = 2;
const int inSize = 3;
const int numUnits = 4;
const int time = 5;
auto x = NDArrayFactory::create<double>('c', {bS, time, inSize});
auto Wx = NDArrayFactory::create<double>('c', {inSize, numUnits});
auto Wh = NDArrayFactory::create<double>('c', {numUnits, numUnits});
auto b = NDArrayFactory::create<double>('c', {2*numUnits});
auto h0 = NDArrayFactory::create<double>('c', {bS, numUnits});
x.linspace(0.01, 0.01);
h0 = 0.2;
Wx = 0.3;
Wh = 0.4;
b = 0.25;
auto expH = NDArrayFactory::create<double>('c', {bS, time, numUnits}, {0.68474828, 0.68474828, 0.68474828, 0.68474828,0.92755601, 0.92755601, 0.92755601, 0.92755601,0.96778334, 0.96778334, 0.96778334, 0.96778334,0.97309129,
0.97309129, 0.97309129, 0.97309129,0.97491207, 0.97491207, 0.97491207, 0.97491207,0.75001965, 0.75001965, 0.75001965, 0.75001965,0.95449491, 0.95449491,
0.95449491, 0.95449491,0.97732828, 0.97732828, 0.97732828, 0.97732828,0.98000655, 0.98000655, 0.98000655, 0.98000655,0.98120782, 0.98120782, 0.98120782, 0.98120782});
auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97491207, 0.97491207, 0.97491207, 0.97491207, 0.98120782, 0.98120782, 0.98120782, 0.98120782});
nd4j::ops::dynamic_rnn op;
auto results = op.execute({&x, &Wx, &Wh, &b, &h0}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto h = results->at(0);
auto hFinal = results->at(1);
ASSERT_TRUE(expH.isSameShape(h));
ASSERT_TRUE(expH.equalsTo(h));
ASSERT_TRUE(expHFinal.isSameShape(hFinal));
ASSERT_TRUE(expHFinal.equalsTo(hFinal));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, dynamic_rnn_test4) {
const int bS = 2;
const int inSize = 3;
const int numUnits = 4;
const int time = 5;
auto x = NDArrayFactory::create<double>('c', {bS, time, inSize});
auto Wx = NDArrayFactory::create<double>('c', {inSize, numUnits});
auto Wh = NDArrayFactory::create<double>('c', {numUnits, numUnits});
auto b = NDArrayFactory::create<double>('c', {2*numUnits});
auto maxTimeStep = NDArrayFactory::create<double>('c', {bS}, {time-1, time-4});
x.linspace(0.01, 0.01);
Wx = 0.3;
Wh = 0.4;
b = 0.25;
auto expH = NDArrayFactory::create<double>('c', {bS, time, numUnits}, {0.47615493, 0.47615493, 0.47615493, 0.47615493,0.86347567, 0.86347567, 0.86347567, 0.86347567,0.96059545, 0.96059545,
0.96059545, 0.96059545,0.9724738 , 0.9724738 , 0.9724738 , 0.9724738 ,0. , 0. , 0. , 0. ,
0.57368608, 0.57368608, 0.57368608, 0.57368608,0. , 0. , 0 , 0. ,0., 0. , 0, 0.,0., 0., 0. , 0. ,0. , 0. , 0., 0. });
auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.9724738 , 0.9724738 , 0.9724738 , 0.9724738 ,0.57368608, 0.57368608, 0.57368608, 0.57368608});
nd4j::ops::dynamic_rnn op;
auto results = op.execute({&x, &Wx, &Wh, &b, &maxTimeStep}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto h = results->at(0);
auto hFinal = results->at(1);
ASSERT_TRUE(expH.isSameShape(h));
ASSERT_TRUE(expH.equalsTo(h));
ASSERT_TRUE(expHFinal.isSameShape(hFinal));
ASSERT_TRUE(expHFinal.equalsTo(hFinal));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, dynamic_rnn_test5) {
const int bS = 2;
const int inSize = 3;
const int numUnits = 4;
const int time = 5;
auto x = NDArrayFactory::create<double>('c', {bS, time, inSize});
auto Wx = NDArrayFactory::create<double>('c', {inSize, numUnits});
auto Wh = NDArrayFactory::create<double>('c', {numUnits, numUnits});
auto b = NDArrayFactory::create<double>('c', {2*numUnits});
x.linspace(0.01, 0.01);
Wx = 0.3;
Wh = 0.4;
b = 0.25;
auto expH = NDArrayFactory::create<double>('c', {bS, time, numUnits}, {0.47615493, 0.47615493, 0.47615493, 0.47615493,0.86347567, 0.86347567, 0.86347567, 0.86347567,0.96059545, 0.96059545, 0.96059545, 0.96059545,
0.9724738 , 0.9724738 , 0.9724738 , 0.9724738 ,0.97486307, 0.97486307, 0.97486307, 0.97486307,0.57368608, 0.57368608, 0.57368608, 0.57368608,
0.92135149, 0.92135149, 0.92135149, 0.92135149,0.97482354, 0.97482354, 0.97482354, 0.97482354,0.97984727, 0.97984727, 0.97984727, 0.97984727,
0.98119833, 0.98119833, 0.98119833, 0.98119833});
auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97486307, 0.97486307, 0.97486307, 0.97486307,0.98119833, 0.98119833, 0.98119833, 0.98119833});
nd4j::ops::dynamic_rnn op;
auto results = op.execute({&x, &Wx, &Wh, &b}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto h = results->at(0);
auto hFinal = results->at(1);
ASSERT_TRUE(expH.isSameShape(h));
ASSERT_TRUE(expH.equalsTo(h));
ASSERT_TRUE(expHFinal.isSameShape(hFinal));
ASSERT_TRUE(expHFinal.equalsTo(hFinal));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test1) {
const int bS = 4;
const int inSize = 4;
const int numUnitsFW = 3;
const int numUnitsBW = 3;
const int time = 5;
auto x = NDArrayFactory::create<double>('c', {time, bS, inSize});
auto WxFW = NDArrayFactory::create<double>('c', {inSize, numUnitsFW});
auto WhFW = NDArrayFactory::create<double>('c', {numUnitsFW, numUnitsFW});
auto bFW = NDArrayFactory::create<double>('c', {2*numUnitsFW});
auto h0FW = NDArrayFactory::create<double>('c', {bS, numUnitsFW});
auto h0BW = NDArrayFactory::create<double>('c', {bS, numUnitsBW});
auto maxTimeStep = NDArrayFactory::create<int>('c', {bS}, {time-1, time-3, time-4, 0});
x.linspace(0.01, 0.01);
h0FW = 0.2;
h0BW = 0.25;
WxFW = 0.3;
WhFW = 0.4;
bFW = 0.1;
auto expHFW = NDArrayFactory::create<double>('c', {time, bS, numUnitsFW}, {0.43819931, 0.43819931, 0.43819931,0.47615493, 0.47615493, 0.47615493,0.51241561, 0.51241561, 0.51241561,0. , 0. , 0. ,
0.73880324, 0.73880324, 0.73880324,0.77843476, 0.77843476, 0.77843476,0. , 0. , 0. ,0. , 0. , 0. ,
0.9052501 , 0.9052501 , 0.9052501 ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,
0.9555734 , 0.9555734 , 0.9555734 ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,
0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. });
auto expHBW = NDArrayFactory::create<double>('c', {time, bS, numUnitsBW}, {0.86708881, 0.86708881, 0.86708881,0.78347842, 0.78347842, 0.78347842,0.55529176, 0.55529176, 0.55529176,0. , 0. , 0. ,
0.90935605, 0.90935605, 0.90935605,0.64692945, 0.64692945, 0.64692945,0. , 0. , 0. ,0. , 0. , 0. ,
0.9181592 , 0.9181592 , 0.9181592 ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,
0.8026439 , 0.8026439 , 0.8026439 ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,
0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. });
auto expHFWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsFW}, {0.9555734 , 0.9555734 , 0.9555734 , 0.77843476, 0.77843476, 0.77843476, 0.51241561, 0.51241561, 0.51241561, 0.2 , 0.2 , 0.2});
auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.86708881, 0.86708881, 0.86708881, 0.78347842, 0.78347842, 0.78347842, 0.55529176, 0.55529176, 0.55529176, 0.25 , 0.25 , 0.25});
nd4j::ops::dynamic_bidirectional_rnn op;
auto results = op.execute({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW, &maxTimeStep}, {}, {1}, {}, false, nd4j::DataType::DOUBLE);
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto hFW = results->at(0);
auto hBW = results->at(1);
auto hFWfinal = results->at(2);
auto hBWfinal = results->at(3);
ASSERT_TRUE(expHFW.isSameShape(hFW));
ASSERT_TRUE(expHFW.equalsTo(hFW));
ASSERT_TRUE(expHBW.isSameShape(hBW));
ASSERT_TRUE(expHBW.equalsTo(hBW));
ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal));
ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal));
ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal));
ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test2) {
const int bS = 4;
const int inSize = 4;
const int numUnitsFW = 3;
const int numUnitsBW = 3;
const int time = 5;
auto x = NDArrayFactory::create<double>('c', {bS, time, inSize});
auto WxFW = NDArrayFactory::create<double>('c', {inSize, numUnitsFW});
auto WhFW = NDArrayFactory::create<double>('c', {numUnitsFW, numUnitsFW});
auto bFW = NDArrayFactory::create<double>('c', {2*numUnitsFW});
auto h0FW = NDArrayFactory::create<double>('c', {bS, numUnitsFW});
auto h0BW = NDArrayFactory::create<double>('c', {bS, numUnitsBW});
auto maxTimeStep = NDArrayFactory::create<int>('c', {bS}, {time-1, time-3, time-4, 0});
x.linspace(0.01, 0.01);
h0FW = 0.2;
h0BW = 0.25;
WxFW = 0.3;
WhFW = 0.4;
bFW = 0.1;
auto expHFW = NDArrayFactory::create<double>('c', {bS, time, numUnitsFW}, {0.43819931, 0.43819931, 0.43819931,0.66617761, 0.66617761, 0.66617761,0.80944357, 0.80944357, 0.80944357,0.87294706, 0.87294706, 0.87294706,0. , 0. , 0. ,
0.61067683, 0.61067683, 0.61067683,0.84851124, 0.84851124, 0.84851124,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,
0.73978305, 0.73978305, 0.73978305,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,
0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. });
auto expHBW = NDArrayFactory::create<double>('c', {bS, time, numUnitsBW}, {0.84345207, 0.84345207, 0.84345207,0.83584708, 0.83584708, 0.83584708,0.77435951, 0.77435951, 0.77435951,0.58760492, 0.58760492, 0.58760492,0. , 0. , 0. ,
0.85615841, 0.85615841, 0.85615841,0.67397984, 0.67397984, 0.67397984,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,
0.76576202, 0.76576202, 0.76576202,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,
0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. });
auto expHFWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsFW}, {0.87294706, 0.87294706, 0.87294706,0.84851124, 0.84851124, 0.84851124,0.73978305, 0.73978305, 0.73978305,0.2 , 0.2 , 0.2});
auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.84345207, 0.84345207, 0.84345207, 0.85615841, 0.85615841, 0.85615841, 0.76576202, 0.76576202, 0.76576202, 0.25 , 0.25 , 0.25});
nd4j::ops::dynamic_bidirectional_rnn op;
auto results = op.execute({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW, &maxTimeStep}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto hFW = results->at(0);
auto hBW = results->at(1);
auto hFWfinal = results->at(2);
auto hBWfinal = results->at(3);
ASSERT_TRUE(expHFW.isSameShape(hFW));
ASSERT_TRUE(expHFW.equalsTo(hFW));
ASSERT_TRUE(expHBW.isSameShape(hBW));
ASSERT_TRUE(expHBW.equalsTo(hBW));
ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal));
ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal));
ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal));
ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test3) {
const int bS = 4;
const int inSize = 4;
const int numUnitsFW = 3;
const int numUnitsBW = 3;
const int time = 5;
auto x = NDArrayFactory::create<double>('c', {bS, time, inSize});
auto WxFW = NDArrayFactory::create<double>('c', {inSize, numUnitsFW});
auto WhFW = NDArrayFactory::create<double>('c', {numUnitsFW, numUnitsFW});
auto bFW = NDArrayFactory::create<double>('c', {2*numUnitsFW});
auto maxTimeStep = NDArrayFactory::create<int>('c', {bS}, {time-1, time-3, time-4, 0});
x.linspace(0.01, 0.01);
WxFW = 0.3;
WhFW = 0.4;
bFW = 0.1;
auto expHFW = NDArrayFactory::create<double>('c', {bS, time, numUnitsFW}, {0.22602835, 0.22602835, 0.22602835,0.49994591, 0.49994591, 0.49994591,0.72869307, 0.72869307, 0.72869307,0.84784327, 0.84784327, 0.84784327,0. , 0. , 0. ,
0.43819931, 0.43819931, 0.43819931,0.7793996 , 0.7793996 , 0.7793996 ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,
0.61067683, 0.61067683, 0.61067683,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,
0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. });
auto expHBW = NDArrayFactory::create<double>('c', {bS, time, numUnitsBW}, {0.82273707, 0.82273707, 0.82273707,0.77935851, 0.77935851, 0.77935851,0.6381121 , 0.6381121 , 0.6381121 ,0.35748551, 0.35748551, 0.35748551,0. , 0. , 0. ,
0.77843476, 0.77843476, 0.77843476,0.47615493, 0.47615493, 0.47615493,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,
0.61067683, 0.61067683, 0.61067683,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,
0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. });
auto expHFWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsFW}, {0.84784327, 0.84784327, 0.84784327, 0.7793996 , 0.7793996 , 0.7793996 , 0.61067683, 0.61067683, 0.61067683, 0. , 0. , 0.});
auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.82273707, 0.82273707, 0.82273707, 0.77843476, 0.77843476, 0.77843476, 0.61067683, 0.61067683, 0.61067683, 0. , 0. , 0.});
nd4j::ops::dynamic_bidirectional_rnn op;
auto results = op.execute({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &maxTimeStep}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto hFW = results->at(0);
auto hBW = results->at(1);
auto hFWfinal = results->at(2);
auto hBWfinal = results->at(3);
ASSERT_TRUE(expHFW.isSameShape(hFW));
ASSERT_TRUE(expHFW.equalsTo(hFW));
ASSERT_TRUE(expHBW.isSameShape(hBW));
ASSERT_TRUE(expHBW.equalsTo(hBW));
ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal));
ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal));
ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal));
ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test4) {
const int bS = 4;
const int inSize = 4;
const int numUnitsFW = 3;
const int numUnitsBW = 3;
const int time = 5;
auto x = NDArrayFactory::create<double>('c', {bS, time, inSize});
auto WxFW = NDArrayFactory::create<double>('c', {inSize, numUnitsFW});
auto WhFW = NDArrayFactory::create<double>('c', {numUnitsFW, numUnitsFW});
auto bFW = NDArrayFactory::create<double>('c', {2*numUnitsFW});
auto h0FW = NDArrayFactory::create<double>('c', {bS, numUnitsFW});
auto h0BW = NDArrayFactory::create<double>('c', {bS, numUnitsBW});
x.linspace(0.01, 0.01);
h0FW = 0.2;
h0BW = 0.25;
WxFW = 0.3;
WhFW = 0.4;
bFW = 0.1;
auto expHFW = NDArrayFactory::create<double>('c', {bS, time, numUnitsFW}, {0.43819931, 0.43819931, 0.43819931,0.66617761, 0.66617761, 0.66617761,0.80944357, 0.80944357, 0.80944357,0.87294706, 0.87294706, 0.87294706,0.89948899, 0.89948899, 0.89948899,
0.61067683, 0.61067683, 0.61067683,0.84851124, 0.84851124, 0.84851124,0.91925737, 0.91925737, 0.91925737,0.93751395, 0.93751395, 0.93751395,0.94544483, 0.94544483, 0.94544483,
0.73978305, 0.73978305, 0.73978305,0.92827068, 0.92827068, 0.92827068,0.95791111, 0.95791111, 0.95791111,0.96427356, 0.96427356, 0.96427356,0.96797541, 0.96797541, 0.96797541,
0.83057887, 0.83057887, 0.83057887,0.96365083, 0.96365083, 0.96365083,0.97585698, 0.97585698, 0.97585698,0.97866981, 0.97866981, 0.97866981,0.9807326 , 0.9807326 , 0.9807326 });
auto expHBW = NDArrayFactory::create<double>('c', {bS, time, numUnitsBW}, {0.85301722, 0.85301722, 0.85301722,0.86427295, 0.86427295, 0.86427295,0.8599919 , 0.8599919 , 0.8599919 ,0.80609463, 0.80609463, 0.80609463,0.61814662, 0.61814662, 0.61814662,
0.91888753, 0.91888753, 0.91888753,0.92652672, 0.92652672, 0.92652672,0.92939674, 0.92939674, 0.92939674,0.90661931, 0.90661931, 0.90661931,0.74516764, 0.74516764, 0.74516764,
0.95254269, 0.95254269, 0.95254269,0.95710717, 0.95710717, 0.95710717,0.96021584, 0.96021584, 0.96021584,0.95222547, 0.95222547, 0.95222547,0.83426363, 0.83426363, 0.83426363,
0.97154357, 0.97154357, 0.97154357,0.97424915, 0.97424915, 0.97424915,0.97644817, 0.97644817, 0.97644817,0.97410547, 0.97410547, 0.97410547,0.89409962, 0.89409962, 0.89409962});
auto expHFWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsFW}, {0.89948899, 0.89948899, 0.89948899, 0.94544483, 0.94544483, 0.94544483, 0.96797541, 0.96797541, 0.96797541, 0.9807326 , 0.9807326 , 0.9807326 });
auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.85301722, 0.85301722, 0.85301722, 0.91888753, 0.91888753, 0.91888753, 0.95254269, 0.95254269, 0.95254269, 0.97154357, 0.97154357, 0.97154357});
nd4j::ops::dynamic_bidirectional_rnn op;
auto results = op.execute({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto hFW = results->at(0);
auto hBW = results->at(1);
auto hFWfinal = results->at(2);
auto hBWfinal = results->at(3);
ASSERT_TRUE(expHFW.isSameShape(hFW));
ASSERT_TRUE(expHFW.equalsTo(hFW));
ASSERT_TRUE(expHBW.isSameShape(hBW));
ASSERT_TRUE(expHBW.equalsTo(hBW));
ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal));
ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal));
ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal));
ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal));
delete results;
}
TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test5) {
const int bS = 4;
const int inSize = 4;
const int numUnitsFW = 3;
const int numUnitsBW = 3;
const int time = 5;
auto x = NDArrayFactory::create<double>('c', {bS, time, inSize});
auto WxFW = NDArrayFactory::create<double>('c', {inSize, numUnitsFW});
auto WhFW = NDArrayFactory::create<double>('c', {numUnitsFW, numUnitsFW});
auto bFW = NDArrayFactory::create<double>('c', {2*numUnitsFW});
x.linspace(0.01, 0.01);
WxFW = 0.3;
WhFW = 0.4;
bFW = 0.1;
auto expHFW = NDArrayFactory::create<double>('c', {bS, time, numUnitsFW}, {0.22602835, 0.22602835, 0.22602835,0.49994591, 0.49994591, 0.49994591,0.72869307, 0.72869307, 0.72869307,0.84784327, 0.84784327, 0.84784327,0.89357928, 0.89357928, 0.89357928,
0.43819931, 0.43819931, 0.43819931,0.7793996 , 0.7793996 , 0.7793996 ,0.9053792 , 0.9053792 , 0.9053792 ,0.93546593, 0.93546593, 0.93546593,0.94518339, 0.94518339, 0.94518339,
0.61067683, 0.61067683, 0.61067683,0.90347408, 0.90347408, 0.90347408,0.95538786, 0.95538786, 0.95538786,0.96406045, 0.96406045, 0.96406045,0.96795929, 0.96795929, 0.96795929,
0.73978305, 0.73978305, 0.73978305,0.95499984, 0.95499984, 0.95499984,0.97535671, 0.97535671, 0.97535671,0.97864446, 0.97864446, 0.97864446,0.98073144, 0.98073144, 0.98073144});
auto expHBW = NDArrayFactory::create<double>('c', {bS, time, numUnitsBW}, {0.84882345, 0.84882345, 0.84882345,0.85160683, 0.85160683, 0.85160683,0.81997657, 0.81997657, 0.81997657,0.69228829, 0.69228829, 0.69228829,0.39861399, 0.39861399, 0.39861399,
0.91865453, 0.91865453, 0.91865453,0.92528094, 0.92528094, 0.92528094,0.92212167, 0.92212167, 0.92212167,0.86418213, 0.86418213, 0.86418213,0.57969286, 0.57969286, 0.57969286,
0.95252666, 0.95252666, 0.95252666,0.95696305, 0.95696305, 0.95696305,0.95878749, 0.95878749, 0.95878749,0.93722463, 0.93722463, 0.93722463,0.71727031, 0.71727031, 0.71727031,
0.97154234, 0.97154234, 0.97154234,0.97423089, 0.97423089, 0.97423089,0.976149 , 0.976149 , 0.976149 ,0.96878298, 0.96878298, 0.96878298,0.81508646, 0.81508646, 0.81508646});
auto expHFWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsFW}, {0.89357928, 0.89357928, 0.89357928, 0.94518339, 0.94518339, 0.94518339, 0.96795929, 0.96795929, 0.96795929, 0.98073144, 0.98073144, 0.98073144});
auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.84882345, 0.84882345, 0.84882345, 0.91865453, 0.91865453, 0.91865453, 0.95252666, 0.95252666, 0.95252666, 0.97154234, 0.97154234, 0.97154234});
nd4j::ops::dynamic_bidirectional_rnn op;
auto results = op.execute({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto hFW = results->at(0);
auto hBW = results->at(1);
auto hFWfinal = results->at(2);
auto hBWfinal = results->at(3);
ASSERT_TRUE(expHFW.isSameShape(hFW));
ASSERT_TRUE(expHFW.equalsTo(hFW));
ASSERT_TRUE(expHBW.isSameShape(hBW));
ASSERT_TRUE(expHBW.equalsTo(hBW));
ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal));
ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal));
ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal));
ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal));
delete results;
}
TEST_F(DeclarableOpsTests6, Test_Diag_119_1) {
auto x = NDArrayFactory::create<double>('c', {3}, {0.15f, 0.25f, 0.35f});
auto e = NDArrayFactory::create<double>('c', {3, 3}, {0.15f, 0.0f, 0.0f, 0.0f, 0.25f, 0.0f, 0.0f, 0.0f, 0.35f});
nd4j::ops::diag op;
auto result = op.execute({&x}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(e, *result->at(0));
delete result;
}
TEST_F(DeclarableOpsTests6, Test_Diag_119_2) {
auto x = NDArrayFactory::create<double>('c', {1}, {0.15f});
auto e = NDArrayFactory::create<double>('c', {1, 1}, {0.15f});
nd4j::ops::diag op;
auto result = op.execute({&x}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(e, *result->at(0));
delete result;
}
TEST_F(DeclarableOpsTests6, Test_Diag_119_3) {
auto x = NDArrayFactory::create<double>(0.15f);
auto e = NDArrayFactory::create<double>('c', {1, 1}, {0.15f});
nd4j::ops::diag op;
auto result = op.execute({&x}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
ASSERT_EQ(e, *result->at(0));
delete result;
}
TEST_F(DeclarableOpsTests6, concat_test14) {
NDArray x0('c', {1, 55, 40}, nd4j::DataType::DOUBLE);
NDArray x1('c', {1, 55, 40}, nd4j::DataType::DOUBLE);
x0 = 1.;
x1 = 2.;
nd4j::ops::concat op;
auto result = op.execute({&x0, &x1}, {}, {0}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
// z->printShapeInfo();
// z->printIndexedBuffer();
Nd4jLong numOfTads= ShapeUtils::getNumOfSubArrs(z->getShapeInfo(), {0});
ASSERT_TRUE(2 == numOfTads);
for (int e = 0; e < numOfTads; ++e) {
NDArray tad = (*z)(e, {0});
auto mean = tad.meanNumber().e<double>(0);
ASSERT_NEAR((e+1)*1., mean, 1e-5);
}
delete result;
}