// cavis/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp
// 2300 lines, 93 KiB, C++ (captured from a web blame view: Raw | Normal View | History)
// Revision shown for the following lines: 2019-06-06 14:21:15 +02:00
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// Created by raver119 on 09.02.18.
//
#include "testlayers.h"
#include <ops/declarable/CustomOperations.h>
#include <helpers/helper_hash.h>
#include <NDArray.h>
#include <array/NDArrayList.h>
using namespace nd4j;
using namespace nd4j::graph;
class DeclarableOpsTests6 : public testing::Test {
public:
DeclarableOpsTests6() {
printf("\n");
fflush(stdout);
}
};
TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_1) {
    // A 5x2 matrix holding 1..10 is sliced over [0:1) on the first axis with the
    // last integer arg set (presumably the shrink-axis mask); the result must be
    // the first row {1, 2}.
    auto matrix = NDArrayFactory::create<double>('c', {5, 2});
    matrix.linspace(1);

    auto begin    = NDArrayFactory::create<double>('c', {1}, {0.});
    auto end      = NDArrayFactory::create<double>('c', {1}, {1});
    auto stride   = NDArrayFactory::create<double>('c', {1}, {1});
    auto expected = NDArrayFactory::create<double>('c', {2}, {1.0f, 2.0f});

    nd4j::ops::strided_slice op;
    auto results = op.execute({&matrix, &begin, &end, &stride}, {}, {0, 0, 0, 0, 1});
    ASSERT_EQ(Status::OK(), results->status());

    ASSERT_TRUE(expected.equalsTo(results->at(0)));
    delete results;
}
TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_2) {
    // Same slice as Test_StridedSlice_Once_Again_1, written with float-literal
    // bounds and checked through NDArray's equality operator.
    auto matrix = NDArrayFactory::create<double>('c', {5, 2});
    matrix.linspace(1);
    auto expected = NDArrayFactory::create<double>('c', {2}, {1.0f, 2.0f});

    auto begin  = NDArrayFactory::create<double>('c', {1}, {0.0f});
    auto end    = NDArrayFactory::create<double>('c', {1}, {1.0f});
    auto stride = NDArrayFactory::create<double>('c', {1}, {1.0f});

    nd4j::ops::strided_slice op;
    auto results = op.execute({&matrix, &begin, &end, &stride}, {}, {0, 0, 0, 0, 1});
    ASSERT_EQ(Status::OK(), results->status());

    auto sliced = results->at(0);
    ASSERT_EQ(expected, *sliced);
    delete results;
}
TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_3) {
    // Slicing a scalar with begin == end == 0 is expected to produce an empty array.
    auto scalar = NDArrayFactory::create<double>(10);
    auto begin  = NDArrayFactory::create<double>(0);
    auto end    = NDArrayFactory::create<double>(0);
    auto stride = NDArrayFactory::create<double>(1.0);

    nd4j::ops::strided_slice op;
    auto results = op.execute({&scalar, &begin, &end, &stride}, {}, {0, 0, 0, 0, 1});
    ASSERT_EQ(Status::OK(), results->status());

    auto sliced = results->at(0);
    ASSERT_TRUE(sliced->isEmpty());
    delete results;
}
TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_4) {
auto matrix = NDArrayFactory::create<double>('c', {1}, {10});
auto b = NDArrayFactory::create<double>('c', {1}, {0.});
auto e = NDArrayFactory::create<double>('c', {1}, {0.});
auto s = NDArrayFactory::create<double>('c', {1}, {1.0});
Dev branch merge: dev_20190606 (#7904) * correct logsoftmax looss (#2) * Small SameDiff listener fix (#4) * Various fixes (#6) * #7839 Fix for asXMatrix and tests * #7866 EmbeddingSequenceLayer dtype fix + test * #7856 SameDiff save/load stream methods * #7859 RegressionEvaluation rank 4 fix + tests + axis configuration * EvaluationBinary 3d/4d * More evaluation 3d/4d tests * #7847 Evaluation empty checks * Small test ifx * #7848 Fix median edge case * Improve DL4J samediff layer tests * [WIP] FastText wrapper implemented (#8) * FastText implemented * Some fixes * Fix shapes for wordsNearest * Validation of input vectors * Fixes * Fixed test * Thread tagged * Some tweaks * setContextClassLoader for DeallocatorServiceThread * Numpy format tests (#1) * Various fixes (#11) * #7852 SameDiff gather fix * #7892 SameDiff placeholder to constant conversion * #7890 validate input rank for MLN/CG init methods * Fix broken permute shape calculation * Permute and gather fixes * Tests * #7850 LogSumExp fix + test * Handful of test fixes * Empty arrays with non-scalar shapes (#10) * minor rearrangements for lambdas * empty tensors with non-scalar shapes * numpy empty tensors with non-scalar shapes * few more empty tweaks * Small fixes * conv3d signature update * micro fix in batchnorm mkldnn * Import fixes * Fix * MKL-DNN update * Small fill fix * fill with empty input + test * Fixes * Small error improvement * Fix * one special test * couple of fixes for lstm * Rewrite TFGraphMapper.getNDArrayFromTensor to be maintainable and less error prone * Fixes * FP16 * Unsigned * BFloat16 * Fill op - empty tweaks * - couple of fixes for empty arrays construction - stack updated * strided slice fix * one transform test * provide method for reducing shapeInfo in case of input array is empty * Fixed reduceAlongDimensions to use empty input properly. 
* couple of broadcast tests * couple of tests broadcast tests + tweak to make them pass * add check of non-empty to methods producing sub-arrays * Fixed reshapeC with zeros in shape. * complete empty check in reduce_... legacy ops * Concat and cumsum/prod * Tweak to empty shape inference on import * add empty check to the rest of reduce legacy ops * one more test * correct typo in evalReduceShapeInfoEmpty * Added tests for reduce_* ops to tests with zero shapes. * few more tests for empty reductions * Fixed strided_slice op with empty case and tests. * one more empty reduction test * Fixed strided_slice test. * add empty check to NDArray::reshapei * infOrMax * empty min/max with infinity tests * made unstack working correctly with empty arrays * few IndexReduce tests + tweaks for empty shapes * add test for empty concat * few tests fixed * Validation fix for reductions on empty shapes * Reverse fix * Reduction shape calc fixes * SameDiff.generateOutputVariable: don't use shape function to determine number of outputs * Range fix * - NDArray constructor updated for scalars/empty arrays - few tests fixed * More fixes * Empty creator fixes * concat fix * concat fix * TF import tests: allow 'both all NaN' and 'both all inf' to pass * Slice, zero fraction, and reshape fixes * transpose, gather * Zero fraction * scalar cast fix * Empty reduction axis support * few more tests fixed * Fixed input checks conforming with TF for concat op and tests. * few tests fixed * matmul scalar shape fix * Fixed checkout for data type and scalarity with concat to allow non-empty scalars with vector concats. 
* broadcast bool fix * few more tests * few more tests * correct evalReduceShapeInfoEmpty * argmax/argmin + tests * one more empty edge case + one more test * argmax/argmin/realdiv_bp tweaks * empty reshape test + fix * Helper fixes * Small fixes * Gather test fix * Gather test fix * Small fixes * reduce scalar zero values * scalar mean workaround * Remove debug code * along dim mean workaround * one more test * - equalsTo() tweak for empty arrays - one more test * broadcast tweaks
2019-06-15 13:34:34 +02:00
auto exp = NDArrayFactory::create<double>(10);
2019-06-06 14:21:15 +02:00
//matrix.linspace(1);
nd4j::ops::strided_slice op;
auto result = op.execute({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
z->printShapeInfo("SS OS shape");
Dev branch merge: dev_20190606 (#7904) * correct logsoftmax looss (#2) * Small SameDiff listener fix (#4) * Various fixes (#6) * #7839 Fix for asXMatrix and tests * #7866 EmbeddingSequenceLayer dtype fix + test * #7856 SameDiff save/load stream methods * #7859 RegressionEvaluation rank 4 fix + tests + axis configuration * EvaluationBinary 3d/4d * More evaluation 3d/4d tests * #7847 Evaluation empty checks * Small test ifx * #7848 Fix median edge case * Improve DL4J samediff layer tests * [WIP] FastText wrapper implemented (#8) * FastText implemented * Some fixes * Fix shapes for wordsNearest * Validation of input vectors * Fixes * Fixed test * Thread tagged * Some tweaks * setContextClassLoader for DeallocatorServiceThread * Numpy format tests (#1) * Various fixes (#11) * #7852 SameDiff gather fix * #7892 SameDiff placeholder to constant conversion * #7890 validate input rank for MLN/CG init methods * Fix broken permute shape calculation * Permute and gather fixes * Tests * #7850 LogSumExp fix + test * Handful of test fixes * Empty arrays with non-scalar shapes (#10) * minor rearrangements for lambdas * empty tensors with non-scalar shapes * numpy empty tensors with non-scalar shapes * few more empty tweaks * Small fixes * conv3d signature update * micro fix in batchnorm mkldnn * Import fixes * Fix * MKL-DNN update * Small fill fix * fill with empty input + test * Fixes * Small error improvement * Fix * one special test * couple of fixes for lstm * Rewrite TFGraphMapper.getNDArrayFromTensor to be maintainable and less error prone * Fixes * FP16 * Unsigned * BFloat16 * Fill op - empty tweaks * - couple of fixes for empty arrays construction - stack updated * strided slice fix * one transform test * provide method for reducing shapeInfo in case of input array is empty * Fixed reduceAlongDimensions to use empty input properly. 
* couple of broadcast tests * couple of tests broadcast tests + tweak to make them pass * add check of non-empty to methods producing sub-arrays * Fixed reshapeC with zeros in shape. * complete empty check in reduce_... legacy ops * Concat and cumsum/prod * Tweak to empty shape inference on import * add empty check to the rest of reduce legacy ops * one more test * correct typo in evalReduceShapeInfoEmpty * Added tests for reduce_* ops to tests with zero shapes. * few more tests for empty reductions * Fixed strided_slice op with empty case and tests. * one more empty reduction test * Fixed strided_slice test. * add empty check to NDArray::reshapei * infOrMax * empty min/max with infinity tests * made unstack working correctly with empty arrays * few IndexReduce tests + tweaks for empty shapes * add test for empty concat * few tests fixed * Validation fix for reductions on empty shapes * Reverse fix * Reduction shape calc fixes * SameDiff.generateOutputVariable: don't use shape function to determine number of outputs * Range fix * - NDArray constructor updated for scalars/empty arrays - few tests fixed * More fixes * Empty creator fixes * concat fix * concat fix * TF import tests: allow 'both all NaN' and 'both all inf' to pass * Slice, zero fraction, and reshape fixes * transpose, gather * Zero fraction * scalar cast fix * Empty reduction axis support * few more tests fixed * Fixed input checks conforming with TF for concat op and tests. * few tests fixed * matmul scalar shape fix * Fixed checkout for data type and scalarity with concat to allow non-empty scalars with vector concats. 
* broadcast bool fix * few more tests * few more tests * correct evalReduceShapeInfoEmpty * argmax/argmin + tests * one more empty edge case + one more test * argmax/argmin/realdiv_bp tweaks * empty reshape test + fix * Helper fixes * Small fixes * Gather test fix * Gather test fix * Small fixes * reduce scalar zero values * scalar mean workaround * Remove debug code * along dim mean workaround * one more test * - equalsTo() tweak for empty arrays - one more test * broadcast tweaks
2019-06-15 13:34:34 +02:00
z->printIndexedBuffer("SS OS out");
ASSERT_TRUE(z->equalsTo(exp));
2019-06-06 14:21:15 +02:00
//ASSERT_EQ(exp, *z);
delete result;
}
// Shape-inference-only check for strided_slice. Instead of calling op.execute,
// it builds a Context by hand and calls strided_slice::calculateOutputShape.
// The input is a length-1 vector (ones_as(matrix) * 10), begin = {1},
// end = {0}, stride = {1}. Exactly one output shape must come back, and it
// must be empty.
TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_04) {
auto matrix = NDArrayFactory::create<double>('c', {1}, {10});
// begin/end/stride use create_ (heap pointers) and are never deleted in this
// test; presumably the VariableSpace below takes ownership of them — TODO confirm.
auto b = NDArrayFactory::create_<int>('c', {1}, {1});
auto e = NDArrayFactory::create_<int>('c', {1}, {(int)0});
auto s = NDArrayFactory::create_<int>('c', {1}, {1});
nd4j::ops::ones_as opOnes;
//auto exp = NDArrayFactory::create<double>('c', {2}, {1.0f, 2.0f});
auto onesRes = opOnes.execute({&matrix}, {}, {});
//matrix.linspace(1);
ASSERT_EQ(onesRes->status(), Status::OK());
// `ones` belongs to onesRes; it is scaled in place to {10}, and a copy (onesD)
// is what actually gets registered as the op input.
auto ones = onesRes->at(0);
ones->printShapeInfo("Shape ones");
*ones *= 10;
auto onesD = ones->dup();
// Register the inputs under negative variable ids -1..-4 and wire them into the
// context in the order strided_slice reads them: input, begin, end, stride.
auto variableSpace = new VariableSpace();
variableSpace->putVariable(-1, onesD);
variableSpace->putVariable(-2, b);
variableSpace->putVariable(-3, e);
variableSpace->putVariable(-4, s);
auto block = new Context(1, variableSpace, false); // not-in-place
block->fillInputs({-1});
block->fillInputs({-2});
block->fillInputs({-3});
block->fillInputs({-4});
// Integer args pushed here are {0, 0, 1, 0, 0}. NOTE(review): this differs from
// the {0, 1, 0, 0, 0} list in the commented-out execute() call below; confirm
// which mask layout the test intends.
block->getIArguments()->push_back(0);
block->getIArguments()->push_back(0);
block->getIArguments()->push_back(1);
block->getIArguments()->push_back(0);
block->getIArguments()->push_back(0);
// The input shape is taken from `ones`, which has the same shape as its copy onesD.
// NOTE(review): inputShapes keeps raw shape-info pointers owned by ones/b/e/s,
// yet it is deleted after onesRes and variableSpace; presumably safe only if
// ShapeList does not touch them on destruction — confirm.
auto inputShapes = new ShapeList({ones->getShapeInfo(), b->getShapeInfo(), e->getShapeInfo(), s->getShapeInfo()});
nd4j::ops::strided_slice op;
auto result = op.calculateOutputShape(inputShapes, *block); //execute({ones, &b, &e, &s}, {}, {0, 1, 0, 0, 0});
ASSERT_EQ(result->size(), 1);
shape::printShapeInfoLinear(result->at(0));
//auto z = result->at(0);
// z->printShapeInfo("SS OS shape");
ASSERT_TRUE(shape::isEmpty(result->at(0)));
//ASSERT_EQ(exp, *z);
// Manual cleanup; note that an earlier failing ASSERT returns before reaching here.
delete block;
delete onesRes;
delete result;
delete variableSpace;
delete inputShapes;
}
TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_5) {
    // The 3x2x2 input is never filled (the expectation is all zeros, so it is
    // presumably zero-initialised). Slicing [2:3) on the first axis with the
    // last integer arg = 1 should yield a 2x2 block of zeros.
    auto input    = NDArrayFactory::create<double>('c', {3, 2, 2});
    auto expected = NDArrayFactory::create<double>('c', {2,2}, {0.0f, 0.0f, 0., 0.});

    auto begin  = NDArrayFactory::create<int>('c', {1}, {2});
    auto end    = NDArrayFactory::create<int>('c', {1}, {3});
    auto stride = NDArrayFactory::create<int>('c', {1}, {1});

    nd4j::ops::strided_slice op;
    auto results = op.execute({&input, &begin, &end, &stride}, {}, {0, 0, 0, 0, 1});
    ASSERT_EQ(Status::OK(), results->status());

    auto slab = results->at(0);
    slab->printShapeInfo("Output shape");
    slab->printIndexedBuffer("Output");
    ASSERT_TRUE(expected.equalsTo(slab));
    delete results;
}
TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_6) {
    // Same [2:3) slice as _5, but the last integer arg is 2 instead of 1.
    // The sliced first axis is therefore kept, and the expected result is a
    // rank-3 1x2x2 block of zeros.
    auto input    = NDArrayFactory::create<double>('c', {3, 2, 2});
    auto expected = NDArrayFactory::create<double>('c', {1,2,2}, {0.0f, 0.0f, 0., 0.});

    auto begin  = NDArrayFactory::create<int>('c', {1}, {2});
    auto end    = NDArrayFactory::create<int>('c', {1}, {3});
    auto stride = NDArrayFactory::create<int>('c', {1}, {1});

    nd4j::ops::strided_slice op;
    auto results = op.execute({&input, &begin, &end, &stride}, {}, {0, 0, 0, 0, 2});
    ASSERT_EQ(Status::OK(), results->status());

    auto slab = results->at(0);
    slab->printShapeInfo("Output shape");
    slab->printIndexedBuffer("Output");
    ASSERT_TRUE(expected.equalsTo(slab));
    delete results;
}
TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_7) {
    // Smoke test: a 5x4 input sliced with begin == end == 0 and the first integer
    // arg set to 1. Only the status is asserted; the output is just dumped.
    const int zero = 0;
    auto input  = NDArrayFactory::create<double>('c', {5, 4});
    auto begin  = NDArrayFactory::create<int>('c', {1}, {zero});
    auto end    = NDArrayFactory::create<int>('c', {1}, {zero});
    auto stride = NDArrayFactory::create<int>('c', {1}, {1});

    nd4j::ops::strided_slice op;
    auto results = op.execute({&input, &begin, &end, &stride}, {}, {1, 0, 0, 0, 0});
    ASSERT_EQ(Status::OK(), results->status());

    auto sliced = results->at(0);
    sliced->printShapeInfo("Output shape");
    sliced->printIndexedBuffer("Output");
    delete results;
}
TEST_F(DeclarableOpsTests6, Test_Simple_Scalar_1) {
    // test_scalar on a 1x1 array holding 2 must return a 1x1 array holding 4.
    auto input    = NDArrayFactory::create<double>('c', {1, 1}, {2.0f});
    auto expected = NDArrayFactory::create<double>('c', {1, 1}, {4.0f});

    nd4j::ops::test_scalar op;
    auto results = op.execute({&input}, {}, {});
    ASSERT_EQ(Status::OK(), results->status());

    auto output = results->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));
    delete results;
}
TEST_F(DeclarableOpsTests6, Test_gatherNd_Edge_1) {
    // Each index triple addresses the first three axes of the 2x4x2x2 input
    // (filled with 1..32). Every result row is therefore one length-2 vector
    // along the last axis, e.g. [0,2,1] -> {11, 12}.
    auto params = NDArrayFactory::create<double>('c', {2, 4, 2, 2});
    params.linspace(1);

    auto indices  = NDArrayFactory::create<int>('c', {3, 3}, {0,2,1, 0,1,0, 1,3,1});
    auto expected = NDArrayFactory::create<double>('c', {3,2}, {11.f, 12.f, 5.f, 6.f, 31.f, 32.f});

    nd4j::ops::gather_nd op;
    auto results = op.execute({&params, &indices}, {}, {});
    ASSERT_EQ(Status::OK(), results->status());

    auto gathered = results->at(0);
    ASSERT_TRUE(expected.isSameShape(gathered));
    ASSERT_TRUE(expected.equalsTo(gathered));
    delete results;
}
TEST_F(DeclarableOpsTests6, Test_StB_1) {
    // space_to_batch on a constant 4x64x64x4 input with 8x8 blocks and paddings
    // {12, 12} (first spatial axis) and {16, 16} (second spatial axis).
    auto x = NDArrayFactory::create<double>('c', {4, 64, 64, 4});
    auto blocks = NDArrayFactory::create<double>('c', {2}, {8, 8});
    auto paddings = NDArrayFactory::create<double>('c', {2, 2}, {12, 12, 16, 16});
    x.assign(1.0f);

    nd4j::ops::space_to_batch op;
    auto result = op.execute({&x, &blocks, &paddings}, {}, {});
    ASSERT_EQ(Status::OK(), result->status());

    // Padding grows the spatial axes to 64 + 24 = 88 and 64 + 32 = 96. The block
    // rearrangement only moves elements, so the output holds exactly
    // batch * 88 * 96 * channels elements.
    auto z = result->at(0);
    ASSERT_EQ(4 * 88 * 96 * 4, z->lengthOf());

    delete result;
}
TEST_F(DeclarableOpsTests6, Test_StB_2) {
    // space_to_batch on a constant 2x6x6x2 input with 2x2 blocks and a padding
    // of 2 on every spatial side.
    auto x = NDArrayFactory::create<double>('c', {2, 6, 6, 2});
    auto blocks = NDArrayFactory::create<double>('c', {2}, {2, 2});
    auto paddings = NDArrayFactory::create<double>('c', {2, 2}, {2, 2, 2, 2});
    x.assign(1.0f);

    nd4j::ops::space_to_batch op;
    auto result = op.execute({&x, &blocks, &paddings}, {}, {});
    ASSERT_EQ(Status::OK(), result->status());

    // Padding grows each 6-wide spatial axis to 10. The block rearrangement
    // neither adds nor drops elements, so the output holds 2 * 10 * 10 * 2 of them.
    auto z = result->at(0);
    ASSERT_EQ(2 * 10 * 10 * 2, z->lengthOf());

    delete result;
}
TEST_F(DeclarableOpsTests6, Test_BtS_1) {
    // batch_to_space on an 'f'-ordered 256x8x8x2 input with 8x8 blocks and an
    // all-zero crops array (created without values, as the test relies on).
    auto x = NDArrayFactory::create<double>('f', {256, 8, 8, 2});
    auto blocks = NDArrayFactory::create<double>('c',{2}, {8, 8});
    auto crops = NDArrayFactory::create<double>('c', {2, 2});

    nd4j::ops::batch_to_space op;
    auto result = op.execute({&x, &blocks, &crops}, {}, {});
    ASSERT_EQ(Status::OK(), result->status());

    // With nothing cropped, batch_to_space only rearranges elements, so the
    // element count must be preserved.
    auto z = result->at(0);
    ASSERT_EQ(x.lengthOf(), z->lengthOf());

    delete result;
}
TEST_F(DeclarableOpsTests6, Test_Order_1) {
    // The order op, given integer arg 0, must return the same logical contents
    // as the 'f'-ordered input but in a different memory ordering.
    auto source   = NDArrayFactory::create<double>('f', {2, 3});
    auto expected = NDArrayFactory::create<double>('c', {2, 3});
    source.linspace(1);
    expected.linspace(1);

    nd4j::ops::order op;
    auto results = op.execute({&source}, {}, {0});
    ASSERT_EQ(Status::OK(), results->status());

    auto reordered = results->at(0);
    reordered->printIndexedBuffer("O Output");
    expected.printIndexedBuffer("O Expect");

    ASSERT_TRUE(expected.equalsTo(reordered));
    ASSERT_NE(source.ordering(), reordered->ordering());
    delete results;
}
TEST_F(DeclarableOpsTests6, Test_CumSum_Inclusive_Reverse_1) {
    // Integer args {exclusive = 0, reverse = 1, axis = 0}, as implied by the test
    // name and the expected values. Each row holds the sum of itself and all
    // rows below it.
    auto expected = NDArrayFactory::create<double>('c', {3, 3}, {12., 15., 18., 11., 13., 15., 7., 8., 9.});
    auto input    = NDArrayFactory::create<double>('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9});

    nd4j::ops::cumsum op;
    auto results = op.execute({&input}, {}, {0, 1, 0}, {}, false, nd4j::DataType::DOUBLE);
    ASSERT_EQ(Status::OK(), results->status());

    auto cumulative = results->at(0);
    ASSERT_TRUE(expected.equalsTo(cumulative));
    delete results;
}
TEST_F(DeclarableOpsTests6, Test_CumSum_Inclusive_Reverse_2) {
    // Inclusive reverse cumsum along axis 1: each element becomes the sum of
    // itself and everything to its right in the same row.
    auto expected = NDArrayFactory::create<double>('c', {3, 3}, {6.f, 5.f, 3.f, 15.f, 11.f, 6.f, 24.f, 17.f, 9.f,});
    auto input    = NDArrayFactory::create<double>('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9});

    nd4j::ops::cumsum op;
    auto results = op.execute({&input}, {}, {0, 1, 1}, {}, false, nd4j::DataType::DOUBLE);
    ASSERT_EQ(Status::OK(), results->status());

    auto cumulative = results->at(0);
    ASSERT_TRUE(expected.equalsTo(cumulative));
    delete results;
}
TEST_F(DeclarableOpsTests6, Test_CumSum_Exclusive_Reverse_1) {
    // Exclusive reverse cumsum along axis 0: each row is the sum of the rows
    // strictly below it, and the last row is zero.
    auto expected = NDArrayFactory::create<double>('c', {3, 3}, {11.f, 13.f, 15.f, 7.f, 8.f, 9.f, 0.f, 0.f, 0.f});
    auto input    = NDArrayFactory::create<double>('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9});

    nd4j::ops::cumsum op;
    auto results = op.execute({&input}, {}, {1, 1, 0}, {}, false, nd4j::DataType::DOUBLE);
    ASSERT_EQ(Status::OK(), results->status());

    auto cumulative = results->at(0);
    ASSERT_TRUE(expected.equalsTo(cumulative));
    delete results;
}
TEST_F(DeclarableOpsTests6, Test_CumSum_Exclusive_Reverse_2) {
    // Exclusive reverse cumsum along axis 1: each element is the sum of the
    // elements strictly to its right, and the last column is zero.
    auto expected = NDArrayFactory::create<double>('c', {3, 3}, {5.f, 3.f, 0.f, 11.f, 6.f, 0.f, 17.f, 9.f, 0.f});
    auto input    = NDArrayFactory::create<double>('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9});

    nd4j::ops::cumsum op;
    auto results = op.execute({&input}, {}, {1, 1, 1}, {}, false, nd4j::DataType::DOUBLE);
    ASSERT_EQ(Status::OK(), results->status());

    auto cumulative = results->at(0);
    ASSERT_TRUE(expected.equalsTo(cumulative));
    delete results;
}
TEST_F(DeclarableOpsTests6, Test_CumSum_Exclusive_Reverse_2_1) {
    // Same computation as Test_CumSum_Exclusive_Reverse_2, but the axis is
    // passed as a second input array instead of a third integer argument.
    auto input    = NDArrayFactory::create<double>('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
    auto axisArr  = NDArrayFactory::create<Nd4jLong>('c', {1}, {1});
    auto expected = NDArrayFactory::create<double>('c', {3, 3}, {5.f, 3.f, 0.f, 11.f, 6.f, 0.f, 17.f, 9.f, 0.f});

    nd4j::ops::cumsum op;
    auto results = op.execute({&input, &axisArr}, {}, {1, 1}, {}, false, nd4j::DataType::DOUBLE);
    ASSERT_EQ(Status::OK(), results->status());

    auto cumulative = results->at(0);
    ASSERT_TRUE(expected.equalsTo(cumulative));
    delete results;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, TestDropout_1) {
    // Smoke test for dropout on a 2x2x2 input with a {2, 2} shape array as the
    // second input, float arg 0.2 and integer arg 113 (presumably an RNG seed).
    // Only the status is checked.
    auto input      = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f});
    auto noiseShape = NDArrayFactory::create<Nd4jLong>({2, 2});

    nd4j::ops::dropout op;
    auto results = op.execute({&input, &noiseShape}, {0.2f}, {113}, {}, false, nd4j::DataType::DOUBLE);
    ASSERT_EQ(ND4J_STATUS_OK, results->status());
    delete results;
}
TEST_F(DeclarableOpsTests6, TestDropout_2) {
    // Smoke test for dropout on a 3x3 input with no shape input, float arg 0.4
    // and integer arg 113. Only the status is checked.
    auto input = NDArrayFactory::create<double>('c', {3, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f});

    nd4j::ops::dropout op;
    auto results = op.execute({&input}, {0.4f}, {113}, {}, false, nd4j::DataType::DOUBLE);
    ASSERT_EQ(ND4J_STATUS_OK, results->status());
    delete results;
}
TEST_F(DeclarableOpsTests6, TestDropout_3) {
    // Smoke test for dropout on a 2x2x2 input with an int {1, 2} shape array,
    // float arg 0.4 and integer arg 113. Only the status is checked.
    auto input      = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f});
    auto noiseShape = NDArrayFactory::create<int>({1, 2});

    nd4j::ops::dropout op;
    auto results = op.execute({&input, &noiseShape}, {0.4f}, {113}, {}, false, nd4j::DataType::DOUBLE);
    ASSERT_EQ(ND4J_STATUS_OK, results->status());
    delete results;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MaxPoolWithArgmax_1) {
    // All nine integer args are 1 (presumably a 1x1 kernel with unit stride and
    // dilation), so pooling acts as the identity. Output 0 must equal the input,
    // and output 1 (argmax indices) must enumerate positions 0..15 within each
    // of the two leading entries.
    auto x = NDArrayFactory::create<double>('c', {2, 2, 2, 4}, {5.5, 0., 0.3, 5.5,1.5, 0., 1.3, 6.5,8.6, 0., 0., 0.4,2.5, 1., 0.3, 4.5,
                                                                1.5, 1., 1.3, 1.5, 3.5, 0., 1.3, 2.5, 2.6, 2., 3., 1.4, 4.5, 1., 0.3, 0.5});
    auto expI = NDArrayFactory::create<Nd4jLong>('c', {2, 2, 2, 4}, {0, 1, 2, 3,4, 5, 6, 7,8, 9, 10, 11,12, 13, 14, 15,
                                                                     0, 1, 2, 3,4, 5, 6, 7,8, 9, 10, 11,12, 13, 14, 15});

    nd4j::ops::max_pool_with_argmax op;
    auto ress = op.execute({&x}, {}, {1,1,1,1,1,1,1,1,1});
    ASSERT_EQ(ND4J_STATUS_OK, ress->status());

    ASSERT_TRUE(expI.isSameShape(ress->at(0)));
    ASSERT_TRUE(expI.isSameShape(ress->at(1)));
    ASSERT_TRUE(x.equalsTo(ress->at(0)));
    ASSERT_TRUE(expI.equalsTo(ress->at(1)));

    delete ress;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, SufficientStatistics_1) {
    // Reducing over axes {0, 1, 2} of a 2x2x2x4 input leaves 8 elements per
    // last-axis slot. The outputs are: 0 = element count, 1 = per-slot sum,
    // 2 = per-slot sum of squares.
    auto x = NDArrayFactory::create<double>('c', {2, 2, 2, 4}, {5.5, 0., 0.3, 5.5,1.5, 0., 1.3, 6.5,8.6, 0., 0., 0.4,2.5, 1., 0.3, 4.5,1.5, 1.,
                                                                1.3, 1.5,3.5, 0., 1.3, 2.5,2.6, 2., 3., 1.4,4.5, 1., 0.3, 0.5});
    double count = 8.0;
    auto sumExp = NDArrayFactory::create<double>({30.2, 5., 7.8, 22.8});
    auto sqrExp = NDArrayFactory::create<double>({154.22, 7., 14.34, 103.62});
    auto axis = NDArrayFactory::create<Nd4jLong>({0, 1, 2});

    nd4j::ops::sufficient_statistics op;
    auto ress = op.execute({&x, &axis}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
    ASSERT_EQ(ND4J_STATUS_OK, ress->status());

    // Floating-point comparison; expected value first so gtest reports it correctly.
    ASSERT_DOUBLE_EQ(count, ress->at(0)->e<double>(0));
    ASSERT_TRUE(sumExp.equalsTo(ress->at(1)));
    ASSERT_TRUE(sqrExp.equalsTo(ress->at(2)));

    delete ress;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, SufficientStatistics_2) {
    // Reducing over axes {0, 1} of a 2x2x2x4 input leaves 4 elements per
    // remaining (2, 4) slot. The outputs are: 0 = element count, 1 = sums,
    // 2 = sums of squares.
    auto x = NDArrayFactory::create<double>('c', {2, 2, 2, 4}, {5.5, 0., 0.3, 5.5,1.5, 0., 1.3, 6.5,8.6, 0., 0., 0.4,2.5, 1., 0.3, 4.5,
                                                                1.5, 1., 1.3, 1.5,3.5, 0., 1.3, 2.5,2.6, 2., 3., 1.4,4.5, 1., 0.3, 0.5});
    double count = 4.0;
    auto sumExp = NDArrayFactory::create<double>('c', {2, 4}, {
            18.2, 3., 4.6, 8.8,
            12., 2., 3.2, 14.}
    );
    auto sqrExp = NDArrayFactory::create<double>('c', {2, 4}, {
            113.22, 5., 10.78, 34.62,
            41., 2., 3.56, 69.}
    );
    auto axis = NDArrayFactory::create<int>({0, 1});

    nd4j::ops::sufficient_statistics op;
    auto ress = op.execute({&x, &axis}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
    ASSERT_EQ(ND4J_STATUS_OK, ress->status());

    // Floating-point comparison; expected value first so gtest reports it correctly.
    ASSERT_DOUBLE_EQ(count, ress->at(0)->e<double>(0));
    ASSERT_TRUE(sumExp.equalsTo(ress->at(1)));
    ASSERT_TRUE(sqrExp.equalsTo(ress->at(2)));

    delete ress;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BinCount_1) {
    // Unweighted bincount: 0 occurs once, 1 three times, 2 four times.
    NDArray expected('c', {3}, {1, 3, 4}, nd4j::DataType::INT32);
    auto values = NDArrayFactory::create<int>('c', {2, 2, 2}, {1, 2, 0, 1, 2, 2, 1, 2});

    nd4j::ops::bincount op;
    auto results = op.execute({&values}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    ASSERT_TRUE(expected.equalsTo(results->at(0)));
    delete results;
}
/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BinCount_2) {
    // Weighted bincount: each bin sums the weights at the positions holding its
    // value, e.g. bin 2 = 1 + 5 + 1 + 6 = 13.
    auto values  = NDArrayFactory::create<int>('c', {2, 2, 2}, {1, 2, 0, 1, 2, 2, 1, 2});
    auto weights = NDArrayFactory::create<double>('c', {2, 2, 2}, {2, 1, 3, 1, 5, 1, 1, 6});
    auto expected = NDArrayFactory::create<double>({3., 4., 13.});

    nd4j::ops::bincount op;
    auto results = op.execute({&values, &weights}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    ASSERT_TRUE(expected.equalsTo(results->at(0)));
    delete results;
}
/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BinCount_3) {
    // Weighted bincount with integer args {0, 2} (presumably min/max output
    // length): the result is truncated to the first two bins.
    auto values  = NDArrayFactory::create<int>('c', {2, 2, 2}, {1, 2, 0, 1, 2, 2, 1, 2});
    auto weights = NDArrayFactory::create<double>('c', {2, 2, 2}, {2, 1, 3, 1, 5, 1, 1, 6});
    auto expected = NDArrayFactory::create<double>({3., 4.});

    nd4j::ops::bincount op;
    auto results = op.execute({&values, &weights}, {}, {0, 2});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    ASSERT_TRUE(expected.equalsTo(results->at(0)));
    delete results;
}
/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BinCount_4) {
    // Weighted bincount with integer args {4, 4}: the output is padded to four
    // bins, the last one being empty.
    auto values  = NDArrayFactory::create<int>('c', {2, 2, 2}, {1, 2, 0, 1, 2, 2, 1, 2});
    auto weights = NDArrayFactory::create<double>('c', {2, 2, 2}, {2, 1, 3, 1, 5, 1, 1, 6});
    auto expected = NDArrayFactory::create<double>({3., 4., 13., 0.0});

    nd4j::ops::bincount op;
    auto results = op.execute({&values, &weights}, {}, {4, 4});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    ASSERT_TRUE(expected.equalsTo(results->at(0)));
    delete results;
}
/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BinCount_5) {
    // Same as BinCount_4, but the two length limits are supplied as scalar input
    // arrays rather than integer arguments.
    auto values  = NDArrayFactory::create<int>('c', {2, 2, 2}, {1, 2, 0, 1, 2, 2, 1, 2});
    auto weights = NDArrayFactory::create<double>('c', {2, 2, 2}, {2, 1, 3, 1, 5, 1, 1, 6});
    auto minLength = NDArrayFactory::create(4);
    auto maxLength = NDArrayFactory::create(4);
    auto expected  = NDArrayFactory::create<double>({3., 4., 13., 0.0});

    nd4j::ops::bincount op;
    auto results = op.execute({&values, &weights, &minLength, &maxLength}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto counts = results->at(0);
    counts->printBuffer("BC out");
    ASSERT_TRUE(expected.equalsTo(counts));
    delete results;
}
/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_1) {
    // Broadcasting shapes {2,2,2} and {2,1,2} (INT32) gives {2,2,2}.
    auto lhsShape = NDArrayFactory::create<int>({2, 2, 2});
    auto rhsShape = NDArrayFactory::create<int>({2, 1, 2});
    auto expected = NDArrayFactory::create<int>({2, 2, 2});

    nd4j::ops::broadcast_dynamic_shape op;
    auto results = op.execute({&lhsShape, &rhsShape}, {}, {}, {}, false, nd4j::DataType::INT32);
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    ASSERT_TRUE(expected.equalsTo(results->at(0)));
    delete results;
}
/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_2) {
    // Rank-2 {2,2} broadcast against rank-3 {2,1,2} (INT64) gives {2,2,2}.
    auto lhsShape = NDArrayFactory::create<Nd4jLong>({2, 2});
    auto rhsShape = NDArrayFactory::create<Nd4jLong>({2, 1, 2});
    auto expected = NDArrayFactory::create<Nd4jLong>({2, 2, 2});

    nd4j::ops::broadcast_dynamic_shape op;
    auto results = op.execute({&lhsShape, &rhsShape}, {}, {}, {}, false, nd4j::DataType::INT64);
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    ASSERT_TRUE(expected.equalsTo(results->at(0)));
    delete results;
}
/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_3) {
    // Rank-3 {2,2,2} broadcast against rank-2 {2,1} (INT64) gives {2,2,2}.
    auto lhsShape = NDArrayFactory::create<Nd4jLong>({2, 2, 2});
    auto rhsShape = NDArrayFactory::create<Nd4jLong>({2, 1});
    auto expected = NDArrayFactory::create<Nd4jLong>({2, 2, 2});

    nd4j::ops::broadcast_dynamic_shape op;
    auto results = op.execute({&lhsShape, &rhsShape}, {}, {}, {}, false, nd4j::DataType::INT64);
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    ASSERT_TRUE(expected.equalsTo(results->at(0)));
    delete results;
}
/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_4) {
    // {2,1} broadcast against the single-element shape {4} gives {2,4}.
    auto lhsShape = NDArrayFactory::create<Nd4jLong>({2, 1});
    auto rhsShape = NDArrayFactory::create<Nd4jLong>('c', {1}, { 4,});
    auto expected = NDArrayFactory::create<Nd4jLong>({2, 4});

    nd4j::ops::broadcast_dynamic_shape op;
    auto results = op.execute({&lhsShape, &rhsShape}, {}, {}, {}, false, nd4j::DataType::INT64);
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    ASSERT_TRUE(expected.equalsTo(results->at(0)));
    delete results;
}
/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_5) {
    // {2,2,2} broadcast against the lower-rank {2,2} gives {2,2,2}.
    auto lhsShape = NDArrayFactory::create<Nd4jLong>({2, 2, 2});
    auto rhsShape = NDArrayFactory::create<Nd4jLong>({2, 2});
    auto expected = NDArrayFactory::create<Nd4jLong>({2, 2, 2});

    nd4j::ops::broadcast_dynamic_shape op;
    auto results = op.execute({&lhsShape, &rhsShape}, {}, {}, {}, false, nd4j::DataType::INT64);
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    ASSERT_TRUE(expected.equalsTo(results->at(0)));
    delete results;
}
/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_5) {
    // {2,1,2} broadcast against {2,2,4}: the expected result is {2,2,4}.
    auto lhsShape = NDArrayFactory::create<Nd4jLong>({2, 1, 2});
    auto rhsShape = NDArrayFactory::create<Nd4jLong>({2, 2, 4});
    auto expected = NDArrayFactory::create<Nd4jLong>({2, 2, 4});

    nd4j::ops::broadcast_dynamic_shape op;
    auto results = op.execute({&lhsShape, &rhsShape}, {}, {}, {}, false, nd4j::DataType::INT64);
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto broadcastShape = results->at(0);
    broadcastShape->printIndexedBuffer("Output SGO 5");
    ASSERT_TRUE(expected.equalsTo(broadcastShape));
    delete results;
}
/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_6) {
    // {2,1,4} broadcast against {2,2,4} gives {2,2,4}.
    auto lhsShape = NDArrayFactory::create<Nd4jLong>({2, 1, 4});
    auto rhsShape = NDArrayFactory::create<Nd4jLong>({2, 2, 4});
    auto expected = NDArrayFactory::create<Nd4jLong>({2, 2, 4});

    nd4j::ops::broadcast_dynamic_shape op;
    auto results = op.execute({&lhsShape, &rhsShape}, {}, {}, {}, false, nd4j::DataType::INT64);
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto broadcastShape = results->at(0);
    broadcastShape->printIndexedBuffer("Output SGO 6");
    ASSERT_TRUE(expected.equalsTo(broadcastShape));
    delete results;
}
/////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_7) {
    // Broadcast shape vectors {1, 1, 3} and {2, 4, 1}: every size-1 axis takes the other side's size.
    auto x = NDArrayFactory::create<Nd4jLong>({1, 1, 3});
    auto y = NDArrayFactory::create<Nd4jLong>({2, 4, 1});
    // ------------------------------------
    auto exp = NDArrayFactory::create<Nd4jLong>({2, 4, 3});
    nd4j::ops::broadcast_dynamic_shape op;
    auto res = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT64);
    ASSERT_EQ(ND4J_STATUS_OK, res->status());
    // Debug dumps removed so passing runs do not spam the test log.
    ASSERT_TRUE(exp.equalsTo(res->at(0)));
    delete res;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_1) {
    // Single input: sum of squares = 3 * (3^2 + 4^2) = 75, so the global norm is
    // sqrt(75) ~= 8.660254. With clip_norm = 0.8 every element is scaled by 0.8 / 8.660254.
    auto x = NDArrayFactory::create<double>('c', {2, 3, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0,
                                                             -3.0, 0.0, 0.0, 4.0, 0.0, 0.0,
                                                             -3.0, 0.0, 0.0, 4.0, 0.0, 0.0}
    );
    auto exp = NDArrayFactory::create<double>('c', {2, 3, 3}, {
            -0.2771281,  0., 0.,
            0.36950415, 0., 0.,
            -0.2771281,  0., 0.,
            0.36950415, 0., 0.,
            -0.2771281,  0., 0.,
            0.36950415, 0., 0.}
    );
    // Expected global norm as a scalar array so it can be compared with equalsTo().
    auto expNorm = NDArrayFactory::create<double>(8.660254);

    nd4j::ops::clip_by_global_norm op;
    auto result = op.execute({&x}, {0.8}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());

    auto z = result->at(0);
    // With N inputs the norm appears to be output N (cf. ClipByGlobalNorm_3, where it is at(2)).
    auto norm = result->at(1);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));
    ASSERT_TRUE(norm->isScalar());
    ASSERT_TRUE(expNorm.equalsTo(norm));

    delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_2) {
    // Two identical inputs: joint sum of squares = 2 * 75 = 150, so the global norm is
    // ~12.247449. With clip_norm = 1.8 both tensors are scaled by 1.8 / 12.247449.
    auto first = NDArrayFactory::create<double>('c', {2, 3, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0,
                                                                 -3.0, 0.0, 0.0, 4.0, 0.0, 0.0,
                                                                 -3.0, 0.0, 0.0, 4.0, 0.0, 0.0});
    auto second = NDArrayFactory::create<double>('c', {2, 3, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0,
                                                                  -3.0, 0.0, 0.0, 4.0, 0.0, 0.0,
                                                                  -3.0, 0.0, 0.0, 4.0, 0.0, 0.0});
    // The same expectation applies to both clipped outputs.
    auto expected = NDArrayFactory::create<double>('c', {2, 3, 3}, {
            -0.44090813, 0., 0.,
            0.5878775,  0., 0.,
            -0.44090813, 0., 0.,
            0.5878775,  0., 0.,
            -0.44090813, 0., 0.,
            0.5878775,  0., 0.});

    nd4j::ops::clip_by_global_norm op;
    auto outputs = op.execute({&first, &second}, {1.8}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto clippedFirst  = outputs->at(0);
    auto clippedSecond = outputs->at(1);

    ASSERT_TRUE(expected.isSameShape(clippedFirst));
    ASSERT_TRUE(expected.isSameShape(clippedSecond));
    ASSERT_TRUE(expected.equalsTo(clippedFirst));
    ASSERT_TRUE(expected.equalsTo(clippedSecond));

    delete outputs;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_3) {
    // Two identical inputs (global norm ~12.247449) clipped to 0.8. The third output is
    // checked only for being a scalar.
    auto first  = NDArrayFactory::create<double>('c', {2, 3, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, 0.0, 4.0, 0.0, 0.0});
    auto second = NDArrayFactory::create<double>('c', {2, 3, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, 0.0, 4.0, 0.0, 0.0});
    auto expected = NDArrayFactory::create<double>('c', {2, 3, 3}, {
            -0.19595918, 0., 0.,
            0.2612789,  0., 0.,
            -0.19595918, 0., 0.,
            0.2612789,  0., 0.,
            -0.19595918, 0., 0.,
            0.2612789,  0., 0.});

    nd4j::ops::clip_by_global_norm op;
    auto outputs = op.execute({&first, &second}, {0.8}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto clippedFirst  = outputs->at(0);
    auto clippedSecond = outputs->at(1);
    auto globalNorm    = outputs->at(2);

    ASSERT_TRUE(expected.isSameShape(clippedFirst));
    ASSERT_TRUE(expected.isSameShape(clippedSecond));
    ASSERT_TRUE(globalNorm->isScalar());
    ASSERT_TRUE(expected.equalsTo(clippedFirst));
    ASSERT_TRUE(expected.equalsTo(clippedSecond));

    delete outputs;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixDeterminant_1) {
    // Batch of two diagonal 3x3 matrices, diag(-3, 4, -3) and diag(4, -3, 4).
    // Each determinant is the product of its diagonal: 36 and -48.
    auto input = NDArrayFactory::create<double>('c', {2, 3, 3}, {-3.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, -3.0, 4.0, 0.0, 0.0, 0.0, -3.0, 0.0, 0.0, 0.0, 4.0});
    auto expected = NDArrayFactory::create<double>({36.0, -48.0});

    nd4j::ops::matrix_determinant op;
    auto outputs = op.execute({&input}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto determinants = outputs->at(0);
    ASSERT_TRUE(expected.isSameShape(determinants));
    ASSERT_TRUE(expected.equalsTo(determinants));

    delete outputs;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixDeterminant_2) {
    // Two copies of [[1, 2], [3, 4]]; det = 1*4 - 2*3 = -2 for each.
    auto input = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0});
    auto expected = NDArrayFactory::create<double>({-2.0, -2.0});

    nd4j::ops::matrix_determinant op;
    auto outputs = op.execute({&input}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto determinants = outputs->at(0);
    ASSERT_TRUE(expected.isSameShape(determinants));
    ASSERT_TRUE(expected.equalsTo(determinants));

    delete outputs;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixDeterminant_3) {
    // Single-matrix batch: det([[3, 1, 2], [3, 4, 5], [6, 7, 3]]) = -54.
    // The result keeps the batch axis, so the expectation has shape {1}.
    auto input = NDArrayFactory::create<double>('c', {1, 3, 3}, {3.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 3.0});
    auto expected = NDArrayFactory::create<double>('c', {1}, {-54.0});

    nd4j::ops::matrix_determinant op;
    auto outputs = op.execute({&input}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto determinants = outputs->at(0);
    ASSERT_TRUE(expected.isSameShape(determinants));
    ASSERT_TRUE(expected.equalsTo(determinants));

    delete outputs;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixDeterminant_4) {
    // det([[12, 1, 2], [3, 4, 5], [6, 7, 13]]) = 12*17 - 1*9 + 2*(-3) = 189.
    auto x = NDArrayFactory::create<double>('c', {1, 3, 3}, {12.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 13.0});
    auto exp = NDArrayFactory::create<double>('c', {1}, {189.0});
    nd4j::ops::matrix_determinant op;
    auto result = op.execute({&x}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    auto z = result->at(0);
    // Debug dumps (buffers and shape info) removed so passing runs stay quiet.
    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));
    delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixDeterminant_5) {
    // 1x4x4 batch filled with 1..16. Two entries are then overwritten (flat index 5 -> 4,
    // flat index 12 -> 12) so the matrix is non-singular, with determinant -16.
    auto input = NDArrayFactory::create<double>('c', {1, 4, 4});
    input.linspace(1);
    input.p(5, 4.0);
    input.p(12, 12.0);

    auto expected = NDArrayFactory::create<double>('c', {1}, {-16.0});

    nd4j::ops::matrix_determinant op;
    auto outputs = op.execute({&input}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto determinants = outputs->at(0);
    ASSERT_TRUE(expected.isSameShape(determinants));
    ASSERT_TRUE(expected.equalsTo(determinants));

    delete outputs;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixDeterminant_6) {
    // Same matrix as MatrixDeterminant_5 but rank-2 (no batch axis), so the result is a scalar.
    auto input = NDArrayFactory::create<double>('c', {4, 4});
    input.linspace(1);
    input.p(5, 4.0);
    input.p(12, 12.0);

    auto expected = NDArrayFactory::create<double>(-16.0);

    nd4j::ops::matrix_determinant op;
    auto outputs = op.execute({&input}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto determinant = outputs->at(0);
    ASSERT_TRUE(determinant->isScalar());
    ASSERT_TRUE(expected.isSameShape(determinant));
    ASSERT_TRUE(expected.equalsTo(determinant));

    delete outputs;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, LogMatrixDeterminant_1) {
    // Same diagonal batch as MatrixDeterminant_1 (dets 36 and -48). The expectations are
    // ln(36) and ln(48), i.e. the log of |det|.
    auto input = NDArrayFactory::create<double>('c', {2, 3, 3}, {-3.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, -3.0, 4.0, 0.0, 0.0, 0.0, -3.0, 0.0, 0.0, 0.0, 4.0});
    auto expected = NDArrayFactory::create<double>({3.58351893845611, 3.871201010907891});

    nd4j::ops::log_matrix_determinant op;
    auto outputs = op.execute({&input}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto logDets = outputs->at(0);
    ASSERT_TRUE(expected.isSameShape(logDets));
    ASSERT_TRUE(expected.equalsTo(logDets));

    delete outputs;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, LogDet_1) {
    // Two symmetric 3x3 matrices. The first has det 36, so the expected value is ln(36);
    // the second's det is ~64.008, so the expected value is ~4.159008.
    auto input = NDArrayFactory::create<double>('c', {2, 3, 3}, {4,12,-16,12,37,-43,-16,-43,98, 4,1.2,-1.6,1.2,3.7,-4.3,-1.6,-4.3,9.8});
    auto expected = NDArrayFactory::create<double>({ 3.5835189, 4.159008});

    nd4j::ops::logdet op;
    auto outputs = op.execute({&input}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto logDets = outputs->at(0);
    ASSERT_TRUE(expected.isSameShape(logDets));
    ASSERT_TRUE(expected.equalsTo(logDets));

    delete outputs;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixInverse_1) {
    // Batch of two 5x5 triangular matrices: the first is upper-triangular, the second
    // lower-triangular. Each inverse keeps the same triangular structure.
    auto input = NDArrayFactory::create<double>('c', {2, 5, 5}, {
                    2.,  4., 60.,  8., 10.,
                    0.,  1.,  2.,  3.,  4.,
                    0.,  0.,  2.,  4.,  6.,
                    0.,  0.,  0.,  1.,  2.,
                    0.,  0.,  0.,  0.,  4.,

                    1.,  0.,  0.,  0.,  0.,
                    2.,  1.,  0.,  0.,  0.,
                   30.,  2.,  1.,  0.,  0.,
                    4.,  3.,  2.,  1.,  0.,
                    5.,  4.,  3.,  2.,  1.,
    });

    auto expected = NDArrayFactory::create<double>('c', {2, 5, 5}, {
                    0.5, -2.0, -13.0, 54.0, -6.75,
                    0.0,  1.0,  -1.0,  1.0,  0.0,
                      0,    0,   0.5, -2.0,  0.25,
                      0,    0,     0,  1.0, -0.5,
                      0,    0,     0,    0,  0.25,

                    1.0,  0.0,   0.0,  0.0,  0.,
                   -2.0,  1.0,    0.,   0.,  0.,
                  -26.0, -2.0,     1,    0,  0.,
                   54.0,  1.0,  -2.0,    1,  0.,
                  -27.0,  0.0,   1.0, -2.0,  1.
    });

    nd4j::ops::matrix_inverse op;
    auto outputs = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto inverses = outputs->at(0);
    ASSERT_TRUE(expected.isSameShape(inverses));
    ASSERT_TRUE(expected.equalsTo(inverses));

    delete outputs;
}
////////////////////////////////////////////////////////////////////////////////
/*
TEST_F(DeclarableOpsTests6, MatrixInverse_2) {
auto x = NDArrayFactory::create<double>('c', {2, 5, 5}, {
1., 2., 30., 4., 5.,
0., 1., 2., 3., 4.,
0., 0., 1., 2., 3.,
0., 0., 0., 1., 2.,
0., 0., 0., 0., 1.,
4., 0., 0., 0., 0.,
4., 2., 0., 0., 0.,
30., 2., 1., 0., 0.,
8., 6., 4., 2., 0.,
15., 12., 9., 6., 3.,
});
auto exp = NDArrayFactory::create<double>('c', {2, 5, 5}, {
1.0, -2.0, -26.0, 54.0, -27.0,
0.0, 1.0, -2.0, 1.0, 0.0,
0.0, 0.0, 1.0, -2.0, 1.0,
0.0, 0.0, 0.0, 1.0, -2.0,
0.0, 0.0, 0.0, 0.0, 1.0,
0.25, 0.0, 0.0, 0.0, 0.0,
-0.50, 0.5, 0.0, 0.0, 0.0,
-6.50, -1.0, 1.0, 0.0, 0.0,
13.50, 0.5, -2.0, 0.5, 0.0,
-6.75, 0.0, 1.0, -1.0, 0.33333333
});
nd4j::ops::matrix_inverse op;
auto result = op.execute({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, result->status());
auto z = result->at(0);
z->printIndexedBuffer("Output ");
exp.printIndexedBuffer("Expected ");
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;
}
*/
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixInverse_3) {
    // Single lower-triangular 5x5 matrix (no batch axis); the inverse is lower-triangular too.
    auto input = NDArrayFactory::create<double>('c', {5, 5}, {
                    4.,  0., 0., 0., 0.,
                    4.,  2., 0., 0., 0.,
                   30.,  2., 1., 0., 0.,
                    8.,  6., 4., 2., 0.,
                   15., 12., 9., 6., 3.,
    });

    auto expected = NDArrayFactory::create<double>('c', {5, 5}, {
                    0.25,  0.0,  0.0,  0.0, 0.0,
                   -0.50,  0.5,  0.0,  0.0, 0.0,
                   -6.50, -1.0,  1.0,  0.0, 0.0,
                   13.50,  0.5, -2.0,  0.5, 0.0,
                   -6.75,  0.0,  1.0, -1.0, 0.33333333
    });

    nd4j::ops::matrix_inverse op;
    auto outputs = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto inverse = outputs->at(0);
    ASSERT_TRUE(expected.isSameShape(inverse));
    ASSERT_TRUE(expected.equalsTo(inverse));

    delete outputs;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, MatrixInverse_4) {
    // Single upper-triangular 5x5 matrix with a unit diagonal; its inverse also has a unit diagonal.
    auto input = NDArrayFactory::create<double>('c', {5, 5}, {
                    1., 2., 30., 4., 5.,
                    0., 1.,  2., 3., 4.,
                    0., 0.,  1., 2., 3.,
                    0., 0.,  0., 1., 2.,
                    0., 0.,  0., 0., 1.
    });

    auto expected = NDArrayFactory::create<double>('c', {5, 5}, {
                    1.0, -2.0, -26.0, 54.0, -27.0,
                    0.0,  1.0,  -2.0,  1.0,   0.0,
                    0.0,  0.0,   1.0, -2.0,   1.0,
                    0.0,  0.0,   0.0,  1.0,  -2.0,
                    0.0,  0.0,   0.0,  0.0,   1.0
    });

    nd4j::ops::matrix_inverse op;
    auto outputs = op.execute({&input}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto inverse = outputs->at(0);
    ASSERT_TRUE(expected.isSameShape(inverse));
    ASSERT_TRUE(expected.equalsTo(inverse));

    delete outputs;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, ReluLayer_1) {
    // Expected values equal x * w + b. All are positive, so the ReLU leaves them unchanged.
    // Example: row 0, col 0 = (1*0.5 - 2*0.5 + 3*0.5 + 4*0.1) + 20 = 21.4.
    auto input   = NDArrayFactory::create<double>('c', {3, 4}, {1.0, -2.0, 3.0, 4.0, 5.0, -6.0, 7.0, 8.0, 9.0, -10.0, 11.0, 12});
    auto weights = NDArrayFactory::create<double>('c', {4, 3}, {0.5, 0.1, 0.8, 0.5, 0.2, 0.5, 0.5, 0.25, 0.5, 0.1, 0.0, 0.25});
    auto bias    = NDArrayFactory::create<double>({20.0, 30.0, 50.0});

    auto expected = NDArrayFactory::create<double>('c', {3, 3}, {
            21.4, 30.45, 52.3,
            23.8, 31.05, 56.5,
            26.2, 31.65, 60.7});

    nd4j::ops::relu_layer op;
    auto outputs = op.execute({&input, &weights, &bias}, {}, {}, {}, false, nd4j::DataType::DOUBLE);
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto activations = outputs->at(0);
    ASSERT_TRUE(expected.isSameShape(activations));
    ASSERT_TRUE(expected.equalsTo(activations));

    delete outputs;
}
TEST_F(DeclarableOpsTests6, Test_Reduce3_Edge) {
    // Smoke test: a cosine-similarity reduce3 over dims {0, 1} of two 3x4x5 arrays must
    // return a result array. The result values are not inspected.
    auto first  = NDArrayFactory::create<double>('c', {3, 4, 5});
    auto second = NDArrayFactory::create<double>('c', {3, 4, 5});
    std::vector<int> reduceDims{0, 1};

    auto reduced = first.applyReduce3(reduce3::CosineSimilarity, &second, reduceDims, nullptr);
    ASSERT_TRUE(nullptr != reduced);

    // applyReduce3 hands back a heap-allocated array that the caller must free.
    delete reduced;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, static_rnn_test1) {
    constexpr int bS       = 2;
    constexpr int inSize   = 3;
    constexpr int numUnits = 4;
    constexpr int time     = 5;

    // x is laid out {time, bS, inSize}; the bias holds 2*numUnits entries.
    auto x           = NDArrayFactory::create<double>('c', {time, bS, inSize});
    auto Wx          = NDArrayFactory::create<double>('c', {inSize, numUnits});
    auto Wh          = NDArrayFactory::create<double>('c', {numUnits, numUnits});
    auto b           = NDArrayFactory::create<double>('c', {2*numUnits});
    auto h0          = NDArrayFactory::create<double>('c', {bS, numUnits});
    auto maxTimeStep = NDArrayFactory::create<double>('c', {bS}, {time-1, time-3});

    // Deterministic fill: a ramp for x, constants for weights, bias and initial state.
    x.linspace(0.01, 0.01);
    Wx = 0.3;
    Wh = 0.4;
    b  = 0.25;
    h0 = 0.2;

    // Per the expectations, batch 0 runs 4 steps and batch 1 runs 2 steps; later steps are
    // zero, and the final state is each batch's last computed step.
    auto expH = NDArrayFactory::create<double>('c', {time, bS, numUnits}, {0.68474828, 0.68474828, 0.68474828, 0.68474828,0.69882484, 0.69882484, 0.69882484, 0.69882484, 0.9312333 , 0.9312333 , 0.9312333 , 0.9312333 ,
                                                                            0.93751527, 0.93751527, 0.93751527, 0.93751527,0.97136768, 0.97136768, 0.97136768, 0.97136768,0., 0., 0., 0. ,
                                                                            0.97732812, 0.97732812, 0.97732812, 0.97732812,0., 0., 0., 0. ,0., 0., 0., 0.,0., 0., 0., 0.});
    auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97732812, 0.97732812, 0.97732812, 0.97732812, 0.93751527, 0.93751527, 0.93751527, 0.93751527});

    nd4j::ops::static_rnn op;
    auto outputs = op.execute({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto hAll  = outputs->at(0);
    auto hLast = outputs->at(1);

    ASSERT_TRUE(expH.isSameShape(hAll));
    ASSERT_TRUE(expH.equalsTo(hAll));
    ASSERT_TRUE(expHFinal.isSameShape(hLast));
    ASSERT_TRUE(expHFinal.equalsTo(hLast));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, static_rnn_test2) {
    constexpr int bS       = 2;
    constexpr int inSize   = 3;
    constexpr int numUnits = 4;
    constexpr int time     = 5;

    // No maxTimeStep input: every batch entry runs all `time` steps.
    auto x  = NDArrayFactory::create<double>('c', {time, bS, inSize});
    auto Wx = NDArrayFactory::create<double>('c', {inSize, numUnits});
    auto Wh = NDArrayFactory::create<double>('c', {numUnits, numUnits});
    auto b  = NDArrayFactory::create<double>('c', {2*numUnits});
    auto h0 = NDArrayFactory::create<double>('c', {bS, numUnits});

    x.linspace(0.01, 0.01);
    Wx = 0.3;
    Wh = 0.4;
    b  = 0.25;
    h0 = 0.2;

    // The final state equals the last time step of expH.
    auto expH = NDArrayFactory::create<double>('c', {time, bS, numUnits}, {0.68474828, 0.68474828, 0.68474828, 0.68474828,0.69882484, 0.69882484, 0.69882484, 0.69882484,0.9312333 , 0.9312333 , 0.9312333 , 0.9312333,
                                                                            0.93751527, 0.93751527, 0.93751527, 0.93751527,0.97136768, 0.97136768, 0.97136768, 0.97136768,0.97338548, 0.97338548, 0.97338548, 0.97338548,
                                                                            0.97732812, 0.97732812, 0.97732812, 0.97732812,0.97864398, 0.97864398, 0.97864398, 0.97864398,0.98000654, 0.98000654, 0.98000654, 0.98000654,
                                                                            0.98112648, 0.98112648, 0.98112648, 0.98112648});
    auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.98000654, 0.98000654, 0.98000654, 0.98000654,0.98112648, 0.98112648, 0.98112648, 0.98112648});

    nd4j::ops::static_rnn op;
    auto outputs = op.execute({&x, &Wx, &Wh, &b, &h0}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto hAll  = outputs->at(0);
    auto hLast = outputs->at(1);

    ASSERT_TRUE(expH.isSameShape(hAll));
    ASSERT_TRUE(expH.equalsTo(hAll));
    ASSERT_TRUE(expHFinal.isSameShape(hLast));
    ASSERT_TRUE(expHFinal.equalsTo(hLast));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, static_rnn_test3) {
    constexpr int bS       = 2;
    constexpr int inSize   = 3;
    constexpr int numUnits = 4;
    constexpr int time     = 5;

    auto x           = NDArrayFactory::create<double>('c', {time, bS, inSize});
    auto Wx          = NDArrayFactory::create<double>('c', {inSize, numUnits});
    auto Wh          = NDArrayFactory::create<double>('c', {numUnits, numUnits});
    auto b           = NDArrayFactory::create<double>('c', {2*numUnits});
    auto h0          = NDArrayFactory::create<double>('c', {bS, numUnits});
    // Batch 1 is given zero steps.
    auto maxTimeStep = NDArrayFactory::create<double>('c', {bS}, {time-1, 0});

    x.linspace(0.01, 0.01);
    Wx = 0.3;
    Wh = 0.4;
    b  = 0.25;
    h0 = 0.2;

    // Batch 1's rows of expH are all zero and its final state is h0 (0.2) unchanged.
    auto expH = NDArrayFactory::create<double>('c', {time, bS, numUnits}, {0.68474828, 0.68474828, 0.68474828, 0.68474828,0., 0., 0., 0., 0.9312333, 0.9312333, 0.9312333, 0.9312333,
                                                                            0., 0., 0., 0. , 0.97136768, 0.97136768, 0.97136768, 0.97136768,0., 0., 0., 0. ,
                                                                            0.97732812, 0.97732812, 0.97732812, 0.97732812,0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0.});
    auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97732812, 0.97732812, 0.97732812, 0.97732812, 0.2 , 0.2 , 0.2 , 0.2});

    nd4j::ops::static_rnn op;
    auto outputs = op.execute({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto hAll  = outputs->at(0);
    auto hLast = outputs->at(1);

    ASSERT_TRUE(expH.isSameShape(hAll));
    ASSERT_TRUE(expH.equalsTo(hAll));
    ASSERT_TRUE(expHFinal.isSameShape(hLast));
    ASSERT_TRUE(expHFinal.equalsTo(hLast));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, static_rnn_test4) {
    // Variant without an initial hidden state: the op receives {x, Wx, Wh, b, maxTimeStep}.
    // The first expected value 0.47615493 ~= tanh(0.3*(0.01+0.02+0.03) + 2*0.25), which is
    // consistent with an implicit zero initial state.
    const int bS = 2;
    const int inSize = 3;
    const int numUnits = 4;
    const int time = 5;
    auto x = NDArrayFactory::create<double>('c', {time, bS, inSize});
    auto Wx = NDArrayFactory::create<double>('c', {inSize, numUnits});
    auto Wh = NDArrayFactory::create<double>('c', {numUnits, numUnits});
    auto b = NDArrayFactory::create<double>('c', {2*numUnits});
    auto maxTimeStep = NDArrayFactory::create<double>('c', {bS}, {time-1, time-3});
    x.linspace(0.01, 0.01);
    Wx = 0.3;
    Wh = 0.4;
    b = 0.25;
    auto expH = NDArrayFactory::create<double>('c', {time, bS, numUnits}, {0.47615493, 0.47615493, 0.47615493, 0.47615493,0.49676344, 0.49676344, 0.49676344, 0.49676344, 0.87018664, 0.87018664, 0.87018664, 0.87018664,
                                                                            0.88400882, 0.88400882, 0.88400882, 0.88400882, 0.96529784, 0.96529784, 0.96529784, 0.96529784,0., 0., 0., 0. ,
                                                                            0.97688859, 0.97688859, 0.97688859, 0.97688859,0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0.});
    auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97688859, 0.97688859, 0.97688859, 0.97688859, 0.88400882, 0.88400882, 0.88400882, 0.88400882});
    nd4j::ops::static_rnn op;
    auto results = op.execute({&x, &Wx, &Wh, &b, &maxTimeStep}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());
    auto h = results->at(0);
    auto hFinal = results->at(1);
    ASSERT_TRUE(expH.isSameShape(h));
    ASSERT_TRUE(expH.equalsTo(h));
    ASSERT_TRUE(expHFinal.isSameShape(hFinal));
    ASSERT_TRUE(expHFinal.equalsTo(hFinal));
    delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, static_rnn_test5) {
    // Minimal-input variant: neither an initial state nor maxTimeStep is passed, so both
    // batch entries run all `time` steps. The values are consistent with a zero initial state.
    const int bS = 2;
    const int inSize = 3;
    const int numUnits = 4;
    const int time = 5;
    auto x = NDArrayFactory::create<double>('c', {time, bS, inSize});
    auto Wx = NDArrayFactory::create<double>('c', {inSize, numUnits});
    auto Wh = NDArrayFactory::create<double>('c', {numUnits, numUnits});
    auto b = NDArrayFactory::create<double>('c', {2*numUnits});
    x.linspace(0.01, 0.01);
    Wx = 0.3;
    Wh = 0.4;
    b = 0.25;
    auto expH = NDArrayFactory::create<double>('c', {time, bS, numUnits}, {0.47615493, 0.47615493, 0.47615493, 0.47615493,0.49676344, 0.49676344, 0.49676344, 0.49676344, 0.87018664, 0.87018664, 0.87018664, 0.87018664,
                                                                            0.88400882, 0.88400882, 0.88400882, 0.88400882, 0.96529784, 0.96529784, 0.96529784, 0.96529784,0.96849345, 0.96849345, 0.96849345, 0.96849345,
                                                                            0.97688859, 0.97688859, 0.97688859, 0.97688859,0.97831069, 0.97831069, 0.97831069, 0.97831069, 0.97997868, 0.97997868, 0.97997868, 0.97997868,
                                                                            0.98110653, 0.98110653, 0.98110653, 0.98110653});
    auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97997868, 0.97997868, 0.97997868, 0.97997868, 0.98110653, 0.98110653, 0.98110653, 0.98110653});
    nd4j::ops::static_rnn op;
    auto results = op.execute({&x, &Wx, &Wh, &b}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());
    auto h = results->at(0);
    auto hFinal = results->at(1);
    ASSERT_TRUE(expH.isSameShape(h));
    ASSERT_TRUE(expH.equalsTo(h));
    ASSERT_TRUE(expHFinal.isSameShape(hFinal));
    ASSERT_TRUE(expHFinal.equalsTo(hFinal));
    delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, static_bidir_rnn_test1) {
    constexpr int bS         = 4;
    constexpr int inSize     = 4;
    constexpr int numUnitsFW = 3;
    constexpr int numUnitsBW = 3;
    constexpr int time       = 5;

    auto x           = NDArrayFactory::create<double>('c', {time, bS, inSize});
    auto WxFW        = NDArrayFactory::create<double>('c', {inSize, numUnitsFW});
    auto WhFW        = NDArrayFactory::create<double>('c', {numUnitsFW, numUnitsFW});
    auto bFW         = NDArrayFactory::create<double>('c', {2*numUnitsFW});
    auto h0FW        = NDArrayFactory::create<double>('c', {bS, numUnitsFW});
    auto h0BW        = NDArrayFactory::create<double>('c', {bS, numUnitsBW});
    // Step counts 4, 2, 1 and 0 for the four batch entries.
    auto maxTimeStep = NDArrayFactory::create<double>('c', {bS}, {time-1, time-3, time-4, 0});

    x.linspace(0.01, 0.01);
    WxFW = 0.3;
    WhFW = 0.4;
    bFW  = 0.1;
    h0FW = 0.2;
    h0BW = 0.25;

    // expH concatenates forward and backward units per row (numUnitsFW + numUnitsBW columns).
    // Batch 3 gets 0 steps, so its final states are the initial states 0.2 / 0.25.
    auto expH = NDArrayFactory::create<double>('c', {time, bS, numUnitsFW+numUnitsBW}, {0.43819931, 0.43819931, 0.43819931, 0.86708881, 0.86708881,0.86708881,0.47615493, 0.47615493, 0.47615493, 0.78347842, 0.78347842,0.78347842,
                                                                                         0.51241561, 0.51241561, 0.51241561, 0.55529176, 0.55529176,0.55529176,0., 0., 0., 0., 0.,0.,0.73880324, 0.73880324, 0.73880324, 0.90935605, 0.90935605,
                                                                                         0.90935605, 0.77843476, 0.77843476, 0.77843476, 0.64692945, 0.64692945,0.64692945,0., 0., 0., 0., 0.,0.,0., 0., 0., 0., 0.,0.,
                                                                                         0.9052501, 0.9052501, 0.9052501, 0.9181592, 0.9181592, 0.9181592,0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.,
                                                                                         0.9555734, 0.9555734, 0.9555734, 0.8026439, 0.8026439, 0.8026439,0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.,
                                                                                         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.});
    auto expHFWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsFW}, {0.9555734 , 0.9555734 , 0.9555734 , 0.77843476, 0.77843476, 0.77843476, 0.51241561, 0.51241561, 0.51241561, 0.2, 0.2, 0.2});
    auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.86708881, 0.86708881, 0.86708881, 0.78347842, 0.78347842, 0.78347842, 0.55529176, 0.55529176, 0.55529176, 0.25, 0.25, 0.25});

    // The forward weights are reused as the backward-direction weights.
    nd4j::ops::static_bidirectional_rnn op;
    auto outputs = op.execute({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW, &maxTimeStep}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto hAll      = outputs->at(0);
    auto hFwdFinal = outputs->at(1);
    auto hBwdFinal = outputs->at(2);

    ASSERT_TRUE(expH.isSameShape(hAll));
    ASSERT_TRUE(expH.equalsTo(hAll));
    ASSERT_TRUE(expHFWfinal.isSameShape(hFwdFinal));
    ASSERT_TRUE(expHFWfinal.equalsTo(hFwdFinal));
    ASSERT_TRUE(expHBWfinal.isSameShape(hBwdFinal));
    ASSERT_TRUE(expHBWfinal.equalsTo(hBwdFinal));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, static_bidir_rnn_test2) {
    constexpr int bS         = 4;
    constexpr int inSize     = 4;
    constexpr int numUnitsFW = 3;
    constexpr int numUnitsBW = 3;
    constexpr int time       = 5;

    // No initial states are passed; only per-batch step counts (4, 2, 1, 0).
    auto x           = NDArrayFactory::create<double>('c', {time, bS, inSize});
    auto WxFW        = NDArrayFactory::create<double>('c', {inSize, numUnitsFW});
    auto WhFW        = NDArrayFactory::create<double>('c', {numUnitsFW, numUnitsFW});
    auto bFW         = NDArrayFactory::create<double>('c', {2*numUnitsFW});
    auto maxTimeStep = NDArrayFactory::create<double>('c', {bS}, {time-1, time-3, time-4, 0});

    x.linspace(0.01, 0.01);
    WxFW = 0.3;
    WhFW = 0.4;
    bFW  = 0.1;

    // Batch 3 gets 0 steps, so its final states in the expectations are all zeros.
    auto expH = NDArrayFactory::create<double>('c', {time, bS, numUnitsFW+numUnitsBW}, {0.22602835, 0.22602835, 0.22602835, 0.86518273, 0.86518273,0.86518273,0.27105303, 0.27105303, 0.27105303, 0.66617761, 0.66617761,0.66617761,
                                                                                         0.31492203, 0.31492203, 0.31492203, 0.31492203, 0.31492203,0.31492203,0. , 0. , 0. , 0. , 0. ,0. ,
                                                                                         0.60005558, 0.60005558, 0.60005558, 0.9029975 , 0.9029975 ,0.9029975 ,0.66138054, 0.66138054, 0.66138054, 0.43819931, 0.43819931,0.43819931,
                                                                                         0. , 0. , 0. , 0. , 0. ,0. ,0. , 0. , 0. , 0. , 0. ,0. ,
                                                                                         0.87023975, 0.87023975, 0.87023975, 0.88852032, 0.88852032,0.88852032,0. , 0. , 0. , 0. , 0. ,0. ,
                                                                                         0. , 0. , 0. , 0. , 0. ,0. ,0. , 0. , 0. , 0. , 0. ,0. ,
                                                                                         0.95177305, 0.95177305, 0.95177305, 0.66737775, 0.66737775,0.66737775,0. , 0. , 0. , 0. , 0. ,0. ,
                                                                                         0. , 0. , 0. , 0. , 0. ,0. ,0. , 0. , 0. , 0. , 0. ,0. ,
                                                                                         0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.});
    auto expHFWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsFW}, {0.95177305, 0.95177305, 0.95177305, 0.66138054, 0.66138054, 0.66138054, 0.31492203, 0.31492203, 0.31492203, 0. , 0. , 0.});
    auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.86518273, 0.86518273, 0.86518273, 0.66617761, 0.66617761, 0.66617761, 0.31492203, 0.31492203, 0.31492203, 0. , 0. , 0.});

    // The forward weights are reused as the backward-direction weights.
    nd4j::ops::static_bidirectional_rnn op;
    auto outputs = op.execute({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &maxTimeStep}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto hAll      = outputs->at(0);
    auto hFwdFinal = outputs->at(1);
    auto hBwdFinal = outputs->at(2);

    ASSERT_TRUE(expH.isSameShape(hAll));
    ASSERT_TRUE(expH.equalsTo(hAll));
    ASSERT_TRUE(expHFWfinal.isSameShape(hFwdFinal));
    ASSERT_TRUE(expHFWfinal.equalsTo(hFwdFinal));
    ASSERT_TRUE(expHBWfinal.isSameShape(hBwdFinal));
    ASSERT_TRUE(expHBWfinal.equalsTo(hBwdFinal));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, static_bidir_rnn_test3) {
    constexpr int bS         = 4;
    constexpr int inSize     = 4;
    constexpr int numUnitsFW = 3;
    constexpr int numUnitsBW = 3;
    constexpr int time       = 5;

    // Neither initial states nor maxTimeStep: every batch entry runs all `time` steps.
    auto x    = NDArrayFactory::create<double>('c', {time, bS, inSize});
    auto WxFW = NDArrayFactory::create<double>('c', {inSize, numUnitsFW});
    auto WhFW = NDArrayFactory::create<double>('c', {numUnitsFW, numUnitsFW});
    auto bFW  = NDArrayFactory::create<double>('c', {2*numUnitsFW});

    x.linspace(0.01, 0.01);
    WxFW = 0.3;
    WhFW = 0.4;
    bFW  = 0.1;

    auto expH = NDArrayFactory::create<double>('c', {time, bS, numUnitsFW+numUnitsBW}, {0.22602835, 0.22602835, 0.22602835, 0.86841012, 0.86841012,0.86841012,0.27105303, 0.27105303, 0.27105303, 0.88207531, 0.88207531,0.88207531,
                                                                                         0.31492203, 0.31492203, 0.31492203, 0.8941667 , 0.8941667 ,0.8941667 ,0.35748551, 0.35748551, 0.35748551, 0.90489713, 0.90489713,
                                                                                         0.90489713, 0.60005558, 0.60005558, 0.60005558, 0.91381375, 0.91381375,0.91381375,0.66138054, 0.66138054, 0.66138054, 0.92253504, 0.92253504,
                                                                                         0.92253504,0.71429879, 0.71429879, 0.71429879, 0.93027876, 0.93027876,0.93027876,0.75947891, 0.75947891, 0.75947891, 0.9371767 , 0.9371767 ,
                                                                                         0.9371767 , 0.87023975, 0.87023975, 0.87023975, 0.94014274, 0.94014274,0.94014274,0.89680574, 0.89680574, 0.89680574, 0.94648926, 0.94648926,
                                                                                         0.94648926,0.91657261, 0.91657261, 0.91657261, 0.95204779, 0.95204779,0.95204779,0.93146896, 0.93146896, 0.93146896, 0.95694206, 0.95694206,
                                                                                         0.95694206, 0.95177305, 0.95177305, 0.95177305, 0.93773086, 0.93773086,0.93773086,0.95874689, 0.95874689, 0.95874689, 0.94579176, 0.94579176,
                                                                                         0.94579176,0.96416067, 0.96416067, 0.96416067, 0.95267886, 0.95267886,0.95267886,0.96851506, 0.96851506, 0.96851506, 0.95857985, 0.95857985,
                                                                                         0.95857985, 0.97269956, 0.97269956, 0.97269956, 0.76075293, 0.76075293,0.76075293,0.97557464, 0.97557464, 0.97557464, 0.78024637, 0.78024637,
                                                                                         0.78024637,0.97806922, 0.97806922, 0.97806922, 0.79833344, 0.79833344,0.79833344,0.98026195, 0.98026195, 0.98026195, 0.81508646, 0.81508646,0.81508646});
    // The forward final state matches the forward half of the last time step; the backward
    // final state matches the backward half of the first time step.
    auto expHFWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsFW}, {0.97269956, 0.97269956, 0.97269956, 0.97557464, 0.97557464, 0.97557464, 0.97806922, 0.97806922, 0.97806922, 0.98026195, 0.98026195, 0.98026195});
    auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.86841012, 0.86841012, 0.86841012, 0.88207531, 0.88207531, 0.88207531, 0.8941667 , 0.8941667 , 0.8941667 , 0.90489713, 0.90489713, 0.90489713});

    // The forward weights are reused as the backward-direction weights.
    nd4j::ops::static_bidirectional_rnn op;
    auto outputs = op.execute({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto hAll      = outputs->at(0);
    auto hFwdFinal = outputs->at(1);
    auto hBwdFinal = outputs->at(2);

    ASSERT_TRUE(expH.isSameShape(hAll));
    ASSERT_TRUE(expH.equalsTo(hAll));
    ASSERT_TRUE(expHFWfinal.isSameShape(hFwdFinal));
    ASSERT_TRUE(expHFWfinal.equalsTo(hFwdFinal));
    ASSERT_TRUE(expHBWfinal.isSameShape(hBwdFinal));
    ASSERT_TRUE(expHBWfinal.equalsTo(hBwdFinal));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, dynamic_rnn_test1) {
    constexpr int bS       = 2;
    constexpr int inSize   = 3;
    constexpr int numUnits = 4;
    constexpr int time     = 5;

    // x is {time, bS, inSize}, paired with integer arg 1 below.
    // NOTE(review): iArg 1 appears to select this time-leading layout (dynamic_rnn_test2
    // uses {bS, time, inSize} without it) - confirm against the op definition.
    auto x           = NDArrayFactory::create<double>('c', {time, bS, inSize});
    auto Wx          = NDArrayFactory::create<double>('c', {inSize, numUnits});
    auto Wh          = NDArrayFactory::create<double>('c', {numUnits, numUnits});
    auto b           = NDArrayFactory::create<double>('c', {2*numUnits});
    auto h0          = NDArrayFactory::create<double>('c', {bS, numUnits});
    auto maxTimeStep = NDArrayFactory::create<Nd4jLong>('c', {bS}, {time-1, time-3});

    x.linspace(0.01, 0.01);
    Wx = 0.3;
    Wh = 0.4;
    b  = 0.25;
    h0 = 0.2;

    auto expH = NDArrayFactory::create<double>('c', {time, bS, numUnits}, {0.68474828, 0.68474828, 0.68474828, 0.68474828,0.69882484, 0.69882484, 0.69882484, 0.69882484,0.9312333 , 0.9312333 , 0.9312333 , 0.9312333 ,
                                                                            0.93751527, 0.93751527, 0.93751527, 0.93751527,0.97136768, 0.97136768, 0.97136768, 0.97136768,0. , 0. , 0. , 0. ,
                                                                            0.97732812, 0.97732812, 0.97732812, 0.97732812,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. });
    auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97732812, 0.97732812, 0.97732812, 0.97732812, 0.93751527, 0.93751527, 0.93751527, 0.93751527});

    nd4j::ops::dynamic_rnn op;
    auto outputs = op.execute({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto hAll  = outputs->at(0);
    auto hLast = outputs->at(1);

    ASSERT_TRUE(expH.isSameShape(hAll));
    ASSERT_TRUE(expH.equalsTo(hAll));
    ASSERT_TRUE(expHFinal.isSameShape(hLast));
    ASSERT_TRUE(expHFinal.equalsTo(hLast));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, dynamic_rnn_test2) {
    constexpr int bS       = 2;
    constexpr int inSize   = 3;
    constexpr int numUnits = 4;
    constexpr int time     = 5;

    // Batch-leading layout: x is {bS, time, inSize} and expH is {bS, time, numUnits}.
    auto x           = NDArrayFactory::create<double>('c', {bS, time, inSize});
    auto Wx          = NDArrayFactory::create<double>('c', {inSize, numUnits});
    auto Wh          = NDArrayFactory::create<double>('c', {numUnits, numUnits});
    auto b           = NDArrayFactory::create<double>('c', {2*numUnits});
    auto h0          = NDArrayFactory::create<double>('c', {bS, numUnits});
    // Batch 0 runs 4 steps and batch 1 runs the full 5 (int-typed step counts).
    auto maxTimeStep = NDArrayFactory::create<int>('c', {bS}, {time-1, time});

    x.linspace(0.01, 0.01);
    Wx = 0.3;
    Wh = 0.4;
    b  = 0.25;
    h0 = 0.2;

    auto expH = NDArrayFactory::create<double>('c', {bS, time, numUnits}, {0.68474828, 0.68474828, 0.68474828, 0.68474828,0.92755601, 0.92755601, 0.92755601, 0.92755601,0.96778334, 0.96778334, 0.96778334,
                                                                            0.96778334,0.97309129, 0.97309129, 0.97309129, 0.97309129,0. , 0. , 0. , 0. ,
                                                                            0.75001965, 0.75001965, 0.75001965, 0.75001965,0.95449491, 0.95449491, 0.95449491, 0.95449491,0.97732828, 0.97732828, 0.97732828,
                                                                            0.97732828,0.98000655, 0.98000655, 0.98000655, 0.98000655,0.98120782, 0.98120782, 0.98120782, 0.98120782});
    auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {0.97309129, 0.97309129, 0.97309129, 0.97309129, 0.98120782, 0.98120782, 0.98120782, 0.98120782});

    nd4j::ops::dynamic_rnn op;
    auto outputs = op.execute({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto hAll  = outputs->at(0);
    auto hLast = outputs->at(1);

    ASSERT_TRUE(expH.isSameShape(hAll));
    ASSERT_TRUE(expH.equalsTo(hAll));
    ASSERT_TRUE(expHFinal.isSameShape(hLast));
    ASSERT_TRUE(expHFinal.equalsTo(hLast));

    delete outputs;
}
///////////////////////////////////////////////////////////////////
// Batch-major dynamic_rnn with an explicit initial state h0 and no step counts, so every
// batch entry runs all `time` steps.
TEST_F(DeclarableOpsTests6, dynamic_rnn_test3) {
    const int bS       = 2;
    const int inSize   = 3;
    const int numUnits = 4;
    const int time     = 5;

    auto x  = NDArrayFactory::create<double>('c', {bS, time, inSize});
    auto Wx = NDArrayFactory::create<double>('c', {inSize, numUnits});
    auto Wh = NDArrayFactory::create<double>('c', {numUnits, numUnits});
    auto b  = NDArrayFactory::create<double>('c', {2*numUnits});
    auto h0 = NDArrayFactory::create<double>('c', {bS, numUnits});

    // deterministic inputs: a ramp for x, constants for the weights, bias and initial state
    x.linspace(0.01, 0.01);
    Wx = 0.3;
    Wh = 0.4;
    b  = 0.25;
    h0 = 0.2;

    nd4j::ops::dynamic_rnn op;
    auto res = op.execute({&x, &Wx, &Wh, &b, &h0}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, res->status());

    // one line per (batch entry, time step)
    auto expH = NDArrayFactory::create<double>('c', {bS, time, numUnits}, {
        0.68474828, 0.68474828, 0.68474828, 0.68474828,
        0.92755601, 0.92755601, 0.92755601, 0.92755601,
        0.96778334, 0.96778334, 0.96778334, 0.96778334,
        0.97309129, 0.97309129, 0.97309129, 0.97309129,
        0.97491207, 0.97491207, 0.97491207, 0.97491207,
        0.75001965, 0.75001965, 0.75001965, 0.75001965,
        0.95449491, 0.95449491, 0.95449491, 0.95449491,
        0.97732828, 0.97732828, 0.97732828, 0.97732828,
        0.98000655, 0.98000655, 0.98000655, 0.98000655,
        0.98120782, 0.98120782, 0.98120782, 0.98120782});
    auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {
        0.97491207, 0.97491207, 0.97491207, 0.97491207,
        0.98120782, 0.98120782, 0.98120782, 0.98120782});

    auto hAll  = res->at(0);   // per-step outputs, [bS, time, numUnits]
    auto hLast = res->at(1);   // final state per batch entry, [bS, numUnits]

    ASSERT_TRUE(expH.isSameShape(hAll));
    ASSERT_TRUE(expH.equalsTo(hAll));
    ASSERT_TRUE(expHFinal.isSameShape(hLast));
    ASSERT_TRUE(expHFinal.equalsTo(hLast));

    delete res;
}
///////////////////////////////////////////////////////////////////
// Batch-major dynamic_rnn without an initial state: the 5th input is the rank-1
// per-batch step-count array {4, 1}.
TEST_F(DeclarableOpsTests6, dynamic_rnn_test4) {
    const int bS       = 2;
    const int inSize   = 3;
    const int numUnits = 4;
    const int time     = 5;

    auto x           = NDArrayFactory::create<double>('c', {bS, time, inSize});
    auto Wx          = NDArrayFactory::create<double>('c', {inSize, numUnits});
    auto Wh          = NDArrayFactory::create<double>('c', {numUnits, numUnits});
    auto b           = NDArrayFactory::create<double>('c', {2*numUnits});
    auto maxTimeStep = NDArrayFactory::create<double>('c', {bS}, {time-1, time-4});

    // deterministic inputs: a ramp for x, constants for the weights and bias
    x.linspace(0.01, 0.01);
    Wx = 0.3;
    Wh = 0.4;
    b  = 0.25;

    nd4j::ops::dynamic_rnn op;
    auto res = op.execute({&x, &Wx, &Wh, &b, &maxTimeStep}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, res->status());

    // one line per (batch entry, time step); entry 0 runs 4 steps, entry 1 only 1
    auto expH = NDArrayFactory::create<double>('c', {bS, time, numUnits}, {
        0.47615493, 0.47615493, 0.47615493, 0.47615493,
        0.86347567, 0.86347567, 0.86347567, 0.86347567,
        0.96059545, 0.96059545, 0.96059545, 0.96059545,
        0.9724738,  0.9724738,  0.9724738,  0.9724738,
        0.,         0.,         0.,         0.,
        0.57368608, 0.57368608, 0.57368608, 0.57368608,
        0.,         0.,         0.,         0.,
        0.,         0.,         0.,         0.,
        0.,         0.,         0.,         0.,
        0.,         0.,         0.,         0.});
    auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {
        0.9724738,  0.9724738,  0.9724738,  0.9724738,
        0.57368608, 0.57368608, 0.57368608, 0.57368608});

    auto hAll  = res->at(0);   // per-step outputs, [bS, time, numUnits]
    auto hLast = res->at(1);   // final state per batch entry, [bS, numUnits]

    ASSERT_TRUE(expH.isSameShape(hAll));
    ASSERT_TRUE(expH.equalsTo(hAll));
    ASSERT_TRUE(expHFinal.isSameShape(hLast));
    ASSERT_TRUE(expHFinal.equalsTo(hLast));

    delete res;
}
///////////////////////////////////////////////////////////////////
// Batch-major dynamic_rnn with only the required inputs (x, Wx, Wh, b): no initial
// state and no step counts.
TEST_F(DeclarableOpsTests6, dynamic_rnn_test5) {
    const int bS       = 2;
    const int inSize   = 3;
    const int numUnits = 4;
    const int time     = 5;

    auto x  = NDArrayFactory::create<double>('c', {bS, time, inSize});
    auto Wx = NDArrayFactory::create<double>('c', {inSize, numUnits});
    auto Wh = NDArrayFactory::create<double>('c', {numUnits, numUnits});
    auto b  = NDArrayFactory::create<double>('c', {2*numUnits});

    // deterministic inputs: a ramp for x, constants for the weights and bias
    x.linspace(0.01, 0.01);
    Wx = 0.3;
    Wh = 0.4;
    b  = 0.25;

    nd4j::ops::dynamic_rnn op;
    auto res = op.execute({&x, &Wx, &Wh, &b}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, res->status());

    // one line per (batch entry, time step)
    auto expH = NDArrayFactory::create<double>('c', {bS, time, numUnits}, {
        0.47615493, 0.47615493, 0.47615493, 0.47615493,
        0.86347567, 0.86347567, 0.86347567, 0.86347567,
        0.96059545, 0.96059545, 0.96059545, 0.96059545,
        0.9724738,  0.9724738,  0.9724738,  0.9724738,
        0.97486307, 0.97486307, 0.97486307, 0.97486307,
        0.57368608, 0.57368608, 0.57368608, 0.57368608,
        0.92135149, 0.92135149, 0.92135149, 0.92135149,
        0.97482354, 0.97482354, 0.97482354, 0.97482354,
        0.97984727, 0.97984727, 0.97984727, 0.97984727,
        0.98119833, 0.98119833, 0.98119833, 0.98119833});
    auto expHFinal = NDArrayFactory::create<double>('c', {bS, numUnits}, {
        0.97486307, 0.97486307, 0.97486307, 0.97486307,
        0.98119833, 0.98119833, 0.98119833, 0.98119833});

    auto hAll  = res->at(0);   // per-step outputs, [bS, time, numUnits]
    auto hLast = res->at(1);   // final state per batch entry, [bS, numUnits]

    ASSERT_TRUE(expH.isSameShape(hAll));
    ASSERT_TRUE(expH.equalsTo(hAll));
    ASSERT_TRUE(expHFinal.isSameShape(hLast));
    ASSERT_TRUE(expHFinal.equalsTo(hLast));

    delete res;
}
///////////////////////////////////////////////////////////////////
// Time-major (iArg 1) dynamic_bidirectional_rnn with explicit initial states for both
// directions and per-batch step counts {4, 2, 1, 0}. The forward weights/bias are reused
// as the backward ones. The entry with 0 steps is expected to report its initial state
// as the final state.
TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test1) {
    const int bS         = 4;
    const int inSize     = 4;
    const int numUnitsFW = 3;
    const int numUnitsBW = 3;
    const int time       = 5;

    auto x           = NDArrayFactory::create<double>('c', {time, bS, inSize});
    auto WxFW        = NDArrayFactory::create<double>('c', {inSize, numUnitsFW});
    auto WhFW        = NDArrayFactory::create<double>('c', {numUnitsFW, numUnitsFW});
    auto bFW         = NDArrayFactory::create<double>('c', {2*numUnitsFW});
    auto h0FW        = NDArrayFactory::create<double>('c', {bS, numUnitsFW});
    auto h0BW        = NDArrayFactory::create<double>('c', {bS, numUnitsBW});
    auto maxTimeStep = NDArrayFactory::create<int>('c', {bS}, {time-1, time-3, time-4, 0});

    // deterministic inputs: a ramp for x, constants elsewhere
    x.linspace(0.01, 0.01);
    WxFW = 0.3;
    WhFW = 0.4;
    bFW  = 0.1;
    h0FW = 0.2;
    h0BW = 0.25;

    nd4j::ops::dynamic_bidirectional_rnn op;
    auto res = op.execute({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW, &maxTimeStep}, {}, {1}, {}, false, nd4j::DataType::DOUBLE);
    ASSERT_EQ(ND4J_STATUS_OK, res->status());

    auto expHFW = NDArrayFactory::create<double>('c', {time, bS, numUnitsFW}, {0.43819931, 0.43819931, 0.43819931,0.47615493, 0.47615493, 0.47615493,0.51241561, 0.51241561, 0.51241561,0. , 0. , 0. ,
        0.73880324, 0.73880324, 0.73880324,0.77843476, 0.77843476, 0.77843476,0. , 0. , 0. ,0. , 0. , 0. ,
        0.9052501 , 0.9052501 , 0.9052501 ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,
        0.9555734 , 0.9555734 , 0.9555734 ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,
        0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. });
    auto expHBW = NDArrayFactory::create<double>('c', {time, bS, numUnitsBW}, {0.86708881, 0.86708881, 0.86708881,0.78347842, 0.78347842, 0.78347842,0.55529176, 0.55529176, 0.55529176,0. , 0. , 0. ,
        0.90935605, 0.90935605, 0.90935605,0.64692945, 0.64692945, 0.64692945,0. , 0. , 0. ,0. , 0. , 0. ,
        0.9181592 , 0.9181592 , 0.9181592 ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,
        0.8026439 , 0.8026439 , 0.8026439 ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,
        0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. });
    auto expHFWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsFW}, {0.9555734 , 0.9555734 , 0.9555734 , 0.77843476, 0.77843476, 0.77843476, 0.51241561, 0.51241561, 0.51241561, 0.2 , 0.2 , 0.2});
    auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.86708881, 0.86708881, 0.86708881, 0.78347842, 0.78347842, 0.78347842, 0.55529176, 0.55529176, 0.55529176, 0.25 , 0.25 , 0.25});

    // outputs: per-step FW, per-step BW, final FW, final BW
    auto outFW      = res->at(0);
    auto outBW      = res->at(1);
    auto outFWfinal = res->at(2);
    auto outBWfinal = res->at(3);

    ASSERT_TRUE(expHFW.isSameShape(outFW));
    ASSERT_TRUE(expHFW.equalsTo(outFW));
    ASSERT_TRUE(expHBW.isSameShape(outBW));
    ASSERT_TRUE(expHBW.equalsTo(outBW));
    ASSERT_TRUE(expHFWfinal.isSameShape(outFWfinal));
    ASSERT_TRUE(expHFWfinal.equalsTo(outFWfinal));
    ASSERT_TRUE(expHBWfinal.isSameShape(outBWfinal));
    ASSERT_TRUE(expHBWfinal.equalsTo(outBWfinal));

    delete res;
}
///////////////////////////////////////////////////////////////////
// Batch-major (no iArgs) dynamic_bidirectional_rnn with explicit initial states and
// per-batch step counts {4, 2, 1, 0}. The forward weights/bias are reused as the
// backward ones.
TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test2) {
    const int bS         = 4;
    const int inSize     = 4;
    const int numUnitsFW = 3;
    const int numUnitsBW = 3;
    const int time       = 5;

    auto x           = NDArrayFactory::create<double>('c', {bS, time, inSize});
    auto WxFW        = NDArrayFactory::create<double>('c', {inSize, numUnitsFW});
    auto WhFW        = NDArrayFactory::create<double>('c', {numUnitsFW, numUnitsFW});
    auto bFW         = NDArrayFactory::create<double>('c', {2*numUnitsFW});
    auto h0FW        = NDArrayFactory::create<double>('c', {bS, numUnitsFW});
    auto h0BW        = NDArrayFactory::create<double>('c', {bS, numUnitsBW});
    auto maxTimeStep = NDArrayFactory::create<int>('c', {bS}, {time-1, time-3, time-4, 0});

    // deterministic inputs: a ramp for x, constants elsewhere
    x.linspace(0.01, 0.01);
    WxFW = 0.3;
    WhFW = 0.4;
    bFW  = 0.1;
    h0FW = 0.2;
    h0BW = 0.25;

    nd4j::ops::dynamic_bidirectional_rnn op;
    auto res = op.execute({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW, &maxTimeStep}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, res->status());

    auto expHFW = NDArrayFactory::create<double>('c', {bS, time, numUnitsFW}, {0.43819931, 0.43819931, 0.43819931,0.66617761, 0.66617761, 0.66617761,0.80944357, 0.80944357, 0.80944357,0.87294706, 0.87294706, 0.87294706,0. , 0. , 0. ,
        0.61067683, 0.61067683, 0.61067683,0.84851124, 0.84851124, 0.84851124,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,
        0.73978305, 0.73978305, 0.73978305,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,
        0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. });
    auto expHBW = NDArrayFactory::create<double>('c', {bS, time, numUnitsBW}, {0.84345207, 0.84345207, 0.84345207,0.83584708, 0.83584708, 0.83584708,0.77435951, 0.77435951, 0.77435951,0.58760492, 0.58760492, 0.58760492,0. , 0. , 0. ,
        0.85615841, 0.85615841, 0.85615841,0.67397984, 0.67397984, 0.67397984,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,
        0.76576202, 0.76576202, 0.76576202,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,
        0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. });
    auto expHFWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsFW}, {0.87294706, 0.87294706, 0.87294706,0.84851124, 0.84851124, 0.84851124,0.73978305, 0.73978305, 0.73978305,0.2 , 0.2 , 0.2});
    auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.84345207, 0.84345207, 0.84345207, 0.85615841, 0.85615841, 0.85615841, 0.76576202, 0.76576202, 0.76576202, 0.25 , 0.25 , 0.25});

    // outputs: per-step FW, per-step BW, final FW, final BW
    auto outFW      = res->at(0);
    auto outBW      = res->at(1);
    auto outFWfinal = res->at(2);
    auto outBWfinal = res->at(3);

    ASSERT_TRUE(expHFW.isSameShape(outFW));
    ASSERT_TRUE(expHFW.equalsTo(outFW));
    ASSERT_TRUE(expHBW.isSameShape(outBW));
    ASSERT_TRUE(expHBW.equalsTo(outBW));
    ASSERT_TRUE(expHFWfinal.isSameShape(outFWfinal));
    ASSERT_TRUE(expHFWfinal.equalsTo(outFWfinal));
    ASSERT_TRUE(expHBWfinal.isSameShape(outBWfinal));
    ASSERT_TRUE(expHBWfinal.equalsTo(outBWfinal));

    delete res;
}
///////////////////////////////////////////////////////////////////
// Batch-major dynamic_bidirectional_rnn without initial states: the 8th input is the
// step-count array {4, 2, 1, 0}. The entry with 0 steps is expected to have an all-zero
// final state.
TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test3) {
    const int bS         = 4;
    const int inSize     = 4;
    const int numUnitsFW = 3;
    const int numUnitsBW = 3;
    const int time       = 5;

    auto x           = NDArrayFactory::create<double>('c', {bS, time, inSize});
    auto WxFW        = NDArrayFactory::create<double>('c', {inSize, numUnitsFW});
    auto WhFW        = NDArrayFactory::create<double>('c', {numUnitsFW, numUnitsFW});
    auto bFW         = NDArrayFactory::create<double>('c', {2*numUnitsFW});
    auto maxTimeStep = NDArrayFactory::create<int>('c', {bS}, {time-1, time-3, time-4, 0});

    // deterministic inputs: a ramp for x, constants for the shared weights and bias
    x.linspace(0.01, 0.01);
    WxFW = 0.3;
    WhFW = 0.4;
    bFW  = 0.1;

    nd4j::ops::dynamic_bidirectional_rnn op;
    auto res = op.execute({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &maxTimeStep}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, res->status());

    auto expHFW = NDArrayFactory::create<double>('c', {bS, time, numUnitsFW}, {0.22602835, 0.22602835, 0.22602835,0.49994591, 0.49994591, 0.49994591,0.72869307, 0.72869307, 0.72869307,0.84784327, 0.84784327, 0.84784327,0. , 0. , 0. ,
        0.43819931, 0.43819931, 0.43819931,0.7793996 , 0.7793996 , 0.7793996 ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,
        0.61067683, 0.61067683, 0.61067683,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,
        0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. });
    auto expHBW = NDArrayFactory::create<double>('c', {bS, time, numUnitsBW}, {0.82273707, 0.82273707, 0.82273707,0.77935851, 0.77935851, 0.77935851,0.6381121 , 0.6381121 , 0.6381121 ,0.35748551, 0.35748551, 0.35748551,0. , 0. , 0. ,
        0.77843476, 0.77843476, 0.77843476,0.47615493, 0.47615493, 0.47615493,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,
        0.61067683, 0.61067683, 0.61067683,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,
        0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. });
    auto expHFWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsFW}, {0.84784327, 0.84784327, 0.84784327, 0.7793996 , 0.7793996 , 0.7793996 , 0.61067683, 0.61067683, 0.61067683, 0. , 0. , 0.});
    auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.82273707, 0.82273707, 0.82273707, 0.77843476, 0.77843476, 0.77843476, 0.61067683, 0.61067683, 0.61067683, 0. , 0. , 0.});

    // outputs: per-step FW, per-step BW, final FW, final BW
    auto outFW      = res->at(0);
    auto outBW      = res->at(1);
    auto outFWfinal = res->at(2);
    auto outBWfinal = res->at(3);

    ASSERT_TRUE(expHFW.isSameShape(outFW));
    ASSERT_TRUE(expHFW.equalsTo(outFW));
    ASSERT_TRUE(expHBW.isSameShape(outBW));
    ASSERT_TRUE(expHBW.equalsTo(outBW));
    ASSERT_TRUE(expHFWfinal.isSameShape(outFWfinal));
    ASSERT_TRUE(expHFWfinal.equalsTo(outFWfinal));
    ASSERT_TRUE(expHBWfinal.isSameShape(outBWfinal));
    ASSERT_TRUE(expHBWfinal.equalsTo(outBWfinal));

    delete res;
}
///////////////////////////////////////////////////////////////////
// Batch-major dynamic_bidirectional_rnn with explicit initial states for both directions
// and no step counts, so every batch entry runs all `time` steps.
TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test4) {
    const int bS         = 4;
    const int inSize     = 4;
    const int numUnitsFW = 3;
    const int numUnitsBW = 3;
    const int time       = 5;

    auto x    = NDArrayFactory::create<double>('c', {bS, time, inSize});
    auto WxFW = NDArrayFactory::create<double>('c', {inSize, numUnitsFW});
    auto WhFW = NDArrayFactory::create<double>('c', {numUnitsFW, numUnitsFW});
    auto bFW  = NDArrayFactory::create<double>('c', {2*numUnitsFW});
    auto h0FW = NDArrayFactory::create<double>('c', {bS, numUnitsFW});
    auto h0BW = NDArrayFactory::create<double>('c', {bS, numUnitsBW});

    // deterministic inputs: a ramp for x, constants elsewhere
    x.linspace(0.01, 0.01);
    WxFW = 0.3;
    WhFW = 0.4;
    bFW  = 0.1;
    h0FW = 0.2;
    h0BW = 0.25;

    nd4j::ops::dynamic_bidirectional_rnn op;
    auto res = op.execute({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, res->status());

    auto expHFW = NDArrayFactory::create<double>('c', {bS, time, numUnitsFW}, {0.43819931, 0.43819931, 0.43819931,0.66617761, 0.66617761, 0.66617761,0.80944357, 0.80944357, 0.80944357,0.87294706, 0.87294706, 0.87294706,0.89948899, 0.89948899, 0.89948899,
        0.61067683, 0.61067683, 0.61067683,0.84851124, 0.84851124, 0.84851124,0.91925737, 0.91925737, 0.91925737,0.93751395, 0.93751395, 0.93751395,0.94544483, 0.94544483, 0.94544483,
        0.73978305, 0.73978305, 0.73978305,0.92827068, 0.92827068, 0.92827068,0.95791111, 0.95791111, 0.95791111,0.96427356, 0.96427356, 0.96427356,0.96797541, 0.96797541, 0.96797541,
        0.83057887, 0.83057887, 0.83057887,0.96365083, 0.96365083, 0.96365083,0.97585698, 0.97585698, 0.97585698,0.97866981, 0.97866981, 0.97866981,0.9807326 , 0.9807326 , 0.9807326 });
    auto expHBW = NDArrayFactory::create<double>('c', {bS, time, numUnitsBW}, {0.85301722, 0.85301722, 0.85301722,0.86427295, 0.86427295, 0.86427295,0.8599919 , 0.8599919 , 0.8599919 ,0.80609463, 0.80609463, 0.80609463,0.61814662, 0.61814662, 0.61814662,
        0.91888753, 0.91888753, 0.91888753,0.92652672, 0.92652672, 0.92652672,0.92939674, 0.92939674, 0.92939674,0.90661931, 0.90661931, 0.90661931,0.74516764, 0.74516764, 0.74516764,
        0.95254269, 0.95254269, 0.95254269,0.95710717, 0.95710717, 0.95710717,0.96021584, 0.96021584, 0.96021584,0.95222547, 0.95222547, 0.95222547,0.83426363, 0.83426363, 0.83426363,
        0.97154357, 0.97154357, 0.97154357,0.97424915, 0.97424915, 0.97424915,0.97644817, 0.97644817, 0.97644817,0.97410547, 0.97410547, 0.97410547,0.89409962, 0.89409962, 0.89409962});
    auto expHFWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsFW}, {0.89948899, 0.89948899, 0.89948899, 0.94544483, 0.94544483, 0.94544483, 0.96797541, 0.96797541, 0.96797541, 0.9807326 , 0.9807326 , 0.9807326 });
    auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.85301722, 0.85301722, 0.85301722, 0.91888753, 0.91888753, 0.91888753, 0.95254269, 0.95254269, 0.95254269, 0.97154357, 0.97154357, 0.97154357});

    // outputs: per-step FW, per-step BW, final FW, final BW
    auto outFW      = res->at(0);
    auto outBW      = res->at(1);
    auto outFWfinal = res->at(2);
    auto outBWfinal = res->at(3);

    ASSERT_TRUE(expHFW.isSameShape(outFW));
    ASSERT_TRUE(expHFW.equalsTo(outFW));
    ASSERT_TRUE(expHBW.isSameShape(outBW));
    ASSERT_TRUE(expHBW.equalsTo(outBW));
    ASSERT_TRUE(expHFWfinal.isSameShape(outFWfinal));
    ASSERT_TRUE(expHFWfinal.equalsTo(outFWfinal));
    ASSERT_TRUE(expHBWfinal.isSameShape(outBWfinal));
    ASSERT_TRUE(expHBWfinal.equalsTo(outBWfinal));

    delete res;
}
// Batch-major dynamic_bidirectional_rnn with only the required inputs: no initial
// states and no step counts. The forward weights/bias are reused as the backward ones.
TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test5) {
    const int bS         = 4;
    const int inSize     = 4;
    const int numUnitsFW = 3;
    const int numUnitsBW = 3;
    const int time       = 5;

    auto x    = NDArrayFactory::create<double>('c', {bS, time, inSize});
    auto WxFW = NDArrayFactory::create<double>('c', {inSize, numUnitsFW});
    auto WhFW = NDArrayFactory::create<double>('c', {numUnitsFW, numUnitsFW});
    auto bFW  = NDArrayFactory::create<double>('c', {2*numUnitsFW});

    // deterministic inputs: a ramp for x, constants for the shared weights and bias
    x.linspace(0.01, 0.01);
    WxFW = 0.3;
    WhFW = 0.4;
    bFW  = 0.1;

    nd4j::ops::dynamic_bidirectional_rnn op;
    auto res = op.execute({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, res->status());

    auto expHFW = NDArrayFactory::create<double>('c', {bS, time, numUnitsFW}, {0.22602835, 0.22602835, 0.22602835,0.49994591, 0.49994591, 0.49994591,0.72869307, 0.72869307, 0.72869307,0.84784327, 0.84784327, 0.84784327,0.89357928, 0.89357928, 0.89357928,
        0.43819931, 0.43819931, 0.43819931,0.7793996 , 0.7793996 , 0.7793996 ,0.9053792 , 0.9053792 , 0.9053792 ,0.93546593, 0.93546593, 0.93546593,0.94518339, 0.94518339, 0.94518339,
        0.61067683, 0.61067683, 0.61067683,0.90347408, 0.90347408, 0.90347408,0.95538786, 0.95538786, 0.95538786,0.96406045, 0.96406045, 0.96406045,0.96795929, 0.96795929, 0.96795929,
        0.73978305, 0.73978305, 0.73978305,0.95499984, 0.95499984, 0.95499984,0.97535671, 0.97535671, 0.97535671,0.97864446, 0.97864446, 0.97864446,0.98073144, 0.98073144, 0.98073144});
    auto expHBW = NDArrayFactory::create<double>('c', {bS, time, numUnitsBW}, {0.84882345, 0.84882345, 0.84882345,0.85160683, 0.85160683, 0.85160683,0.81997657, 0.81997657, 0.81997657,0.69228829, 0.69228829, 0.69228829,0.39861399, 0.39861399, 0.39861399,
        0.91865453, 0.91865453, 0.91865453,0.92528094, 0.92528094, 0.92528094,0.92212167, 0.92212167, 0.92212167,0.86418213, 0.86418213, 0.86418213,0.57969286, 0.57969286, 0.57969286,
        0.95252666, 0.95252666, 0.95252666,0.95696305, 0.95696305, 0.95696305,0.95878749, 0.95878749, 0.95878749,0.93722463, 0.93722463, 0.93722463,0.71727031, 0.71727031, 0.71727031,
        0.97154234, 0.97154234, 0.97154234,0.97423089, 0.97423089, 0.97423089,0.976149 , 0.976149 , 0.976149 ,0.96878298, 0.96878298, 0.96878298,0.81508646, 0.81508646, 0.81508646});
    auto expHFWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsFW}, {0.89357928, 0.89357928, 0.89357928, 0.94518339, 0.94518339, 0.94518339, 0.96795929, 0.96795929, 0.96795929, 0.98073144, 0.98073144, 0.98073144});
    auto expHBWfinal = NDArrayFactory::create<double>('c', {bS, numUnitsBW}, {0.84882345, 0.84882345, 0.84882345, 0.91865453, 0.91865453, 0.91865453, 0.95252666, 0.95252666, 0.95252666, 0.97154234, 0.97154234, 0.97154234});

    // outputs: per-step FW, per-step BW, final FW, final BW
    auto outFW      = res->at(0);
    auto outBW      = res->at(1);
    auto outFWfinal = res->at(2);
    auto outBWfinal = res->at(3);

    ASSERT_TRUE(expHFW.isSameShape(outFW));
    ASSERT_TRUE(expHFW.equalsTo(outFW));
    ASSERT_TRUE(expHBW.isSameShape(outBW));
    ASSERT_TRUE(expHBW.equalsTo(outBW));
    ASSERT_TRUE(expHFWfinal.isSameShape(outFWfinal));
    ASSERT_TRUE(expHFWfinal.equalsTo(outFWfinal));
    ASSERT_TRUE(expHBWfinal.isSameShape(outBWfinal));
    ASSERT_TRUE(expHBWfinal.equalsTo(outBWfinal));

    delete res;
}
// diag of a length-3 vector is expected to be a 3x3 matrix carrying the vector on its
// main diagonal and zeros elsewhere. Both sides use the same float literals, so the
// double values compared are identical.
TEST_F(DeclarableOpsTests6, Test_Diag_119_1) {
    auto input    = NDArrayFactory::create<double>('c', {3}, {0.15f, 0.25f, 0.35f});
    auto expected = NDArrayFactory::create<double>('c', {3, 3}, {0.15f, 0.0f, 0.0f, 0.0f, 0.25f, 0.0f, 0.0f, 0.0f, 0.35f});

    nd4j::ops::diag op;
    auto res = op.execute({&input}, {}, {});
    ASSERT_EQ(Status::OK(), res->status());

    auto out = res->at(0);
    ASSERT_EQ(expected, *out);

    delete res;
}
// diag of a single-element vector is expected to be a 1x1 matrix holding that element.
TEST_F(DeclarableOpsTests6, Test_Diag_119_2) {
    auto input    = NDArrayFactory::create<double>('c', {1}, {0.15f});
    auto expected = NDArrayFactory::create<double>('c', {1, 1}, {0.15f});

    nd4j::ops::diag op;
    auto res = op.execute({&input}, {}, {});
    ASSERT_EQ(Status::OK(), res->status());

    auto out = res->at(0);
    ASSERT_EQ(expected, *out);

    delete res;
}
// diag of a scalar is expected to be promoted to a 1x1 matrix holding that value.
TEST_F(DeclarableOpsTests6, Test_Diag_119_3) {
    auto input    = NDArrayFactory::create<double>(0.15f);
    auto expected = NDArrayFactory::create<double>('c', {1, 1}, {0.15f});

    nd4j::ops::diag op;
    auto res = op.execute({&input}, {}, {});
    ASSERT_EQ(Status::OK(), res->status());

    auto out = res->at(0);
    ASSERT_EQ(expected, *out);

    delete res;
}
// maxpool2d on a float32 [1,1,4,5] input. The original test allocated an array with the
// expected output shape but never used it. It is now used to check the output shape
// as well as the status.
TEST_F(DeclarableOpsTests6, maxPool2D_float_test1) {
    NDArray input('c', {1,1,4,5}, nd4j::DataType::FLOAT32);
    // expected output shape: SAME mode with unit stride presumably keeps the spatial size
    // of the (NCHW) input -- NOTE(review): relies on the iArg layout noted below
    NDArray expShape('c', {1,1,4,5}, nd4j::DataType::FLOAT32);
    input.linspace(1.);

    nd4j::ops::maxpool2d op;
    // NOTE(review): iArgs presumably are kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, extraParam0, isNHWC
    auto results = op.execute({&input}, {}, {2,2, 1,1, 1,1, 2,2, 1,0,0});
    ASSERT_EQ(Status::OK(), results->status());
    ASSERT_TRUE(expShape.isSameShape(results->at(0)));

    delete results;
}
// Concatenates two [1, 55, 40] arrays filled with 1 and 2 along axis 0. Each resulting
// sub-array along axis 0 must then have mean (index + 1).
TEST_F(DeclarableOpsTests6, concat_test14) {
    NDArray x0('c', {1, 55, 40}, nd4j::DataType::DOUBLE);
    NDArray x1('c', {1, 55, 40}, nd4j::DataType::DOUBLE);
    x0 = 1.;
    x1 = 2.;

    nd4j::ops::concat op;
    auto result = op.execute({&x0, &x1}, {}, {0}, {});
    ASSERT_EQ(Status::OK(), result->status());

    auto z = result->at(0);

    // number of sub-arrays obtained by excluding dimension 0: one per concatenated input
    const Nd4jLong numOfTads = ShapeUtils::getNumOfSubArrs(z->getShapeInfo(), {0});
    ASSERT_EQ(2, numOfTads);

    // sub-array e originates from input e (filled with e + 1), so its mean is e + 1;
    // the index uses Nd4jLong to match the bound and the sub-array index type
    for (Nd4jLong e = 0; e < numOfTads; ++e) {
        NDArray tad = (*z)(e, {0});
        auto mean = tad.meanNumber().e<double>(0);
        ASSERT_NEAR(static_cast<double>(e + 1), mean, 1e-5);
    }

    delete result;
}