3046 lines
122 KiB
C++
3046 lines
122 KiB
C++
/*******************************************************************************
|
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
*
|
|
* This program and the accompanying materials are made available under the
|
|
* terms of the Apache License, Version 2.0 which is available at
|
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
* License for the specific language governing permissions and limitations
|
|
* under the License.
|
|
*
|
|
* SPDX-License-Identifier: Apache-2.0
|
|
******************************************************************************/
|
|
|
|
#include "testlayers.h"
|
|
#include <ops/declarable/CustomOperations.h>
|
|
#include <helpers/helper_hash.h>
|
|
#include <NDArray.h>
|
|
#include <array/NDArrayList.h>
|
|
#include <MmulHelper.h>
|
|
#include <PointersManager.h>
|
|
|
|
|
|
using namespace nd4j;
|
|
using namespace nd4j::graph;
|
|
|
|
// Fixture used by every TEST_F in this file; it holds no state and needs no setup.
class DeclarableOpsTests3 : public testing::Test {
public:
    DeclarableOpsTests3() = default;
};
|
|
|
|
|
|
// tile op with the repeat counts supplied as a second input array.
TEST_F(DeclarableOpsTests3, Test_Tile_1) {
    auto input   = NDArrayFactory::create<float>('c', {2, 3}, {1, 2, 3, 4, 5, 6});
    auto repeats = NDArrayFactory::create<int>('c', {1, 2}, {2, 2});

    // Reference result comes straight from NDArray::tile with the same counts.
    std::vector<Nd4jLong> tileCounts{2, 2};
    auto expected = input.tile(tileCounts);

    nd4j::ops::tile tileOp;
    auto *results = tileOp.execute({&input, &repeats}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *output = results->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete results;
}
|
|
|
|
|
|
// tile op with the repeat counts supplied as integer arguments.
TEST_F(DeclarableOpsTests3, Test_Tile_2) {
    auto input = NDArrayFactory::create<float>('c', {2, 3}, {1, 2, 3, 4, 5, 6});

    // Reference result comes straight from NDArray::tile with the same counts.
    std::vector<Nd4jLong> tileCounts{2, 2};
    auto expected = input.tile(tileCounts);

    nd4j::ops::tile tileOp;
    auto *results = tileOp.execute({&input}, {}, {2, 2});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *output = results->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete results;
}
|
|
|
|
// permute op with the axis order {0, 2, 1} supplied as an input array; only the
// resulting shape is verified.
TEST_F(DeclarableOpsTests3, Test_Permute_1) {
    auto input     = NDArrayFactory::create<float>('c', {2, 3, 4});
    auto axisOrder = NDArrayFactory::create<Nd4jLong>('c', {1, 3}, {0, 2, 1});
    auto expected  = NDArrayFactory::create<float>('c', {2, 4, 3});

    nd4j::ops::permute permuteOp;
    auto *results = permuteOp.execute({&input, &axisOrder}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *output = results->at(0);
    ASSERT_TRUE(expected.isSameShape(output));

    delete results;
}
|
|
|
|
// permute op called without any axis arguments; the expected shape is the
// input shape reversed ({2, 3, 4} -> {4, 3, 2}). Only the shape is verified.
TEST_F(DeclarableOpsTests3, Test_Permute_2) {
    auto input    = NDArrayFactory::create<float>('c', {2, 3, 4});
    auto expected = NDArrayFactory::create<float>('c', {4, 3, 2});

    nd4j::ops::permute permuteOp;
    auto *results = permuteOp.execute({&input}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *output = results->at(0);
    ASSERT_TRUE(expected.isSameShape(output));

    delete results;
}
|
|
|
|
|
|
// unique op: expects two outputs, the distinct values and, for every input
// element, the index of its value inside the distinct-values array.
TEST_F(DeclarableOpsTests3, Test_Unique_1) {
    auto x    = NDArrayFactory::create<float>('c', {1, 5}, {1, 2, 1, 2, 3});
    auto expV = NDArrayFactory::create<float>('c', {3}, {1, 2, 3});
    auto expI = NDArrayFactory::create<Nd4jLong>('c', {5}, {0, 1, 0, 1, 2});

    nd4j::ops::unique op;
    auto result = op.execute({&x}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    ASSERT_EQ(2, result->size());

    auto v = result->at(0);
    auto i = result->at(1);

    // Debug printing of the outputs was removed: it only added noise to the
    // test log and the assertions below already cover shape and content.
    ASSERT_TRUE(expV.isSameShape(v));
    ASSERT_TRUE(expV.equalsTo(v));

    ASSERT_TRUE(expI.isSameShape(i));
    ASSERT_TRUE(expI.equalsTo(i));

    delete result;
}
|
|
|
|
// unique_with_counts op: expects three outputs, the distinct values, the
// per-element index into the distinct values, and the occurrence count of
// each distinct value.
TEST_F(DeclarableOpsTests3, Test_Unique_2) {
    auto x    = NDArrayFactory::create<float>('c', {1, 5}, {1, 2, 1, 2, 3});
    auto expV = NDArrayFactory::create<float>('c', {3}, {1, 2, 3});
    auto expI = NDArrayFactory::create<Nd4jLong>('c', {5}, {0, 1, 0, 1, 2});
    auto expC = NDArrayFactory::create<Nd4jLong>('c', {3}, {2, 2, 1});

    nd4j::ops::unique_with_counts op;
    auto result = op.execute({&x}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, result->status());
    ASSERT_EQ(3, result->size());

    auto v = result->at(0);
    auto i = result->at(1);
    auto c = result->at(2);

    // Debug printing of shapes and buffers was removed: it only added noise to
    // the test log and the assertions below already cover shape and content.
    ASSERT_TRUE(expV.isSameShape(v));
    ASSERT_TRUE(expV.equalsTo(v));

    ASSERT_TRUE(expI.isSameShape(i));
    ASSERT_TRUE(expI.equalsTo(i));

    ASSERT_TRUE(expC.isSameShape(c));
    ASSERT_TRUE(expC.equalsTo(c));

    delete result;
}
|
|
|
|
// rint op on a row of positive and negative fractional values, including the
// +/-1.5 midpoints; only element values are compared.
TEST_F(DeclarableOpsTests3, Test_Rint_1) {
    auto input    = NDArrayFactory::create<float>('c', {1, 7}, {-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0});
    auto expected = NDArrayFactory::create<float>('c', {1, 7}, {-2., -2., -0., 0., 2., 2., 2.});

    nd4j::ops::rint rintOp;
    auto *results = rintOp.execute({&input}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *output = results->at(0);
    ASSERT_TRUE(expected.equalsTo(output));

    delete results;
}
|
|
|
|
|
|
// norm op driven by a floating-point mode argument (0, 1 and 4) with the
// reduction axis given as an integer argument. Each result is compared with
// the matching NDArray::reduceAlongDims reduction:
//   mode 0 -> reduce::NormFrobenius over all dimensions
//   mode 1 -> reduce::Norm2 along dimension 1
//   mode 4 -> reduce::NormMax along dimension 1
TEST_F(DeclarableOpsTests3, Test_Norm_1) {
    auto x = NDArrayFactory::create<float>('c', {100, 100});
    x.linspace(1);

    std::vector<int> empty;
    std::vector<int> dims({1});
    nd4j::ops::norm op;

    auto result0 = op.execute({&x}, {0.}, {});
    // Check the status before touching the outputs, as the other tests do.
    ASSERT_EQ(ND4J_STATUS_OK, result0->status());

    auto z0 = result0->at(0);
    auto exp0 = x.reduceAlongDims(reduce::NormFrobenius, empty, false, false);
    ASSERT_TRUE(exp0.isSameShape(z0));
    ASSERT_TRUE(exp0.equalsTo(z0));

    delete result0;

    auto result1 = op.execute({&x}, {1.}, {1});
    // Expected value first, so a failure message labels the operands correctly.
    ASSERT_EQ(ND4J_STATUS_OK, result1->status());

    auto z1 = result1->at(0);
    auto exp1 = x.reduceAlongDims(reduce::Norm2, dims, false, false);
    ASSERT_TRUE(exp1.isSameShape(z1));
    ASSERT_TRUE(exp1.equalsTo(z1));

    delete result1;

    auto result4 = op.execute({&x}, {4.}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, result4->status());

    auto z4 = result4->at(0);
    auto exp4 = x.reduceAlongDims(reduce::NormMax, dims, false, false);
    ASSERT_TRUE(exp4.isSameShape(z4));
    ASSERT_TRUE(exp4.equalsTo(z4));

    delete result4;
}
|
|
|
|
|
|
// Same checks as Test_Norm_1, but the reduction axis is passed as a second
// input array instead of an integer argument. Each result is compared with
// the matching NDArray::reduceAlongDims reduction.
TEST_F(DeclarableOpsTests3, Test_Norm_2) {
    auto x = NDArrayFactory::create<float>('c', {100, 100});
    x.linspace(1);
    auto axis = NDArrayFactory::create<Nd4jLong>('c', {1, 1}, {1});

    std::vector<int> empty;
    std::vector<int> dims({1});
    nd4j::ops::norm op;

    auto result0 = op.execute({&x}, {0}, {});
    // Check the status before touching the outputs, as the other tests do.
    ASSERT_EQ(ND4J_STATUS_OK, result0->status());

    auto z0 = result0->at(0);
    auto exp0 = x.reduceAlongDims(reduce::NormFrobenius, empty, false, false);
    ASSERT_TRUE(exp0.isSameShape(z0));
    ASSERT_TRUE(exp0.equalsTo(z0));

    delete result0;

    auto result1 = op.execute({&x, &axis}, {1}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result1->status());

    auto z1 = result1->at(0);
    auto exp1 = x.reduceAlongDims(reduce::Norm2, dims, false, false);
    ASSERT_TRUE(exp1.isSameShape(z1));
    ASSERT_TRUE(exp1.equalsTo(z1));

    delete result1;

    auto result4 = op.execute({&x, &axis}, {4}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result4->status());

    auto z4 = result4->at(0);
    auto exp4 = x.reduceAlongDims(reduce::NormMax, dims, false, false);
    ASSERT_TRUE(exp4.isSameShape(z4));
    ASSERT_TRUE(exp4.equalsTo(z4));

    delete result4;
}
|
|
|
|
|
|
// clipbyavgnorm with threshold 0.8 on a double array; the expected output is
// a scaled-down copy of the input.
TEST_F(DeclarableOpsTests3, Test_ClipByAvgNorm_1) {
    auto x   = NDArrayFactory::create<double>('c', {2, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0});
    auto exp = NDArrayFactory::create<double>('c', {2, 3}, {-2.88, 0.0, 0.0, 3.84, 0.0, 0.0});

    nd4j::ops::clipbyavgnorm op;
    auto result = op.execute({&x}, {0.8}, {}, {}, false, nd4j::DataType::DOUBLE);
    // Check the status before touching the outputs, as the other tests do.
    ASSERT_EQ(ND4J_STATUS_OK, result->status());

    auto z = result->at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete result;
}
|
|
|
|
// clipbyavgnorm with threshold 0.9; the expected output equals the input,
// i.e. no clipping takes place for this data.
TEST_F(DeclarableOpsTests3, Test_ClipByAvgNorm_2) {
    auto x   = NDArrayFactory::create<float>('c', {2, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0});
    auto exp = NDArrayFactory::create<float>('c', {2, 3}, {-3, 0.0, 0.0, 4, 0.0, 0.0});

    nd4j::ops::clipbyavgnorm op;
    auto result = op.execute({&x}, {0.9}, {});
    // Check the status before touching the outputs, as the other tests do.
    ASSERT_EQ(ND4J_STATUS_OK, result->status());

    auto z = result->at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete result;
}
|
|
|
|
|
|
// clipbynorm with threshold 4.0; the expected output is a scaled-down copy
// of the input.
TEST_F(DeclarableOpsTests3, Test_ClipByNorm_1) {
    auto x   = NDArrayFactory::create<float>('c', {2, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0});
    auto exp = NDArrayFactory::create<float>('c', {2, 3}, {-2.4, 0.0, 0.0, 3.2, 0.0, 0.0});

    nd4j::ops::clipbynorm op;
    auto result = op.execute({&x}, {4.0}, {});
    // Check the status before touching the outputs, as the other tests do.
    ASSERT_EQ(ND4J_STATUS_OK, result->status());

    auto z = result->at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete result;
}
|
|
|
|
// clipbynorm with threshold 6.0; the expected output equals the input,
// i.e. no clipping takes place for this data.
TEST_F(DeclarableOpsTests3, Test_ClipByNorm_2) {
    auto x   = NDArrayFactory::create<float>('c', {2, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0});
    auto exp = NDArrayFactory::create<float>('c', {2, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0});

    nd4j::ops::clipbynorm op;
    auto result = op.execute({&x}, {6.0}, {});
    // Check the status before touching the outputs, as the other tests do.
    ASSERT_EQ(ND4J_STATUS_OK, result->status());

    auto z = result->at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete result;
}
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
// clipbynorm along dimension 1 with threshold 1.0.
// The rows of x are first normalised to unit Norm2 and then scaled by
// {1.1, 1.0, 0.9}. After clipping, the per-row norms are expected to be
// {1, 1, <norm of the third row before clipping>}.
TEST_F(DeclarableOpsTests3, Test_ClipByNorm_3) {
    auto x       = NDArrayFactory::create<double>('c', {3, 5});
    auto unities = NDArrayFactory::create<double>('c', {3, 1}, {1., 1., 1.});
    auto scale   = NDArrayFactory::create<double>('c', {3, 1}, {1.1, 1., 0.9});

    x.linspace(100.);

    // Normalise every row and verify that the row norms are now all 1.
    auto xNorm1 = x.reduceAlongDims(reduce::Norm2, {1}, true);
    x /= xNorm1;
    xNorm1 = x.reduceAlongDims(reduce::Norm2, {1}, true);

    ASSERT_TRUE(unities.isSameShape(xNorm1));
    ASSERT_TRUE(unities.equalsTo(xNorm1));

    // Rescale the rows and remember the resulting norms for the expectation.
    x *= scale;
    xNorm1 = x.reduceAlongDims(reduce::Norm2, {1}, true);

    nd4j::ops::clipbynorm op;
    auto result = op.execute({&x}, {1.0}, {1}, {}, false, nd4j::DataType::DOUBLE);
    // Check the status before touching the outputs, as the other tests do.
    ASSERT_EQ(ND4J_STATUS_OK, result->status());

    auto z = result->at(0);

    auto zNorm1 = z->reduceAlongDims(reduce::Norm2, {1}, true);
    auto exp = NDArrayFactory::create<double>('c', {3, 1}, {1., 1., xNorm1.e<double>(2)});

    ASSERT_TRUE(exp.isSameShape(&zNorm1));
    ASSERT_TRUE(exp.equalsTo(&zNorm1));

    delete result;
}
|
|
|
|
// listdiff op: the first output is compared with the values of `values` that
// are absent from `excluded`, the second with their positions in `values`.
TEST_F(DeclarableOpsTests3, Test_ListDiff_1) {
    auto values   = NDArrayFactory::create<float>('c', {6}, {1, 2, 3, 4, 5, 6});
    auto excluded = NDArrayFactory::create<float>('c', {3}, {1, 3, 5});

    auto expectedValues  = NDArrayFactory::create<float>('c', {3}, {2, 4, 6});
    auto expectedIndices = NDArrayFactory::create<Nd4jLong>('c', {3}, {1, 3, 5});

    nd4j::ops::listdiff listDiffOp;
    auto *results = listDiffOp.execute({&values, &excluded}, {}, {});

    ASSERT_EQ(Status::OK(), results->status());

    auto *outValues  = results->at(0);
    auto *outIndices = results->at(1);

    // Explicitly force a buffer sync on both outputs before comparing them.
    outValues->getDataBuffer()->syncToSpecial(true);
    outIndices->getDataBuffer()->syncToSpecial(true);

    ASSERT_TRUE(expectedValues.isSameShape(outValues));
    ASSERT_TRUE(expectedValues.equalsTo(outValues));

    ASSERT_TRUE(expectedIndices.isSameShape(outIndices));
    ASSERT_TRUE(expectedIndices.equalsTo(outIndices));

    delete results;
}
|
|
|
|
// range op with start/stop/step given as scalar input arrays and a negative
// fractional step (0.3 down towards -5 in steps of -0.33): 17 values expected.
TEST_F(DeclarableOpsTests3, Test_Range_1) {
    auto from  = NDArrayFactory::create<float>(0.3);
    auto to    = NDArrayFactory::create<float>(-5);
    auto delta = NDArrayFactory::create<float>(-0.33);
    auto expected = NDArrayFactory::create<float>('c', {17}, { 0.3 , -0.03, -0.36, -0.69, -1.02, -1.35, -1.68, -2.01, -2.34, -2.67,-3. , -3.33, -3.66, -3.99, -4.32, -4.65, -4.98});

    nd4j::ops::range rangeOp;
    auto *results = rangeOp.execute({&from, &to, &delta}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *output = results->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete results;
}
|
|
|
|
|
|
// range op with start/stop/step given as {1, 1}-shaped input arrays,
// counting down from 2 to 0 (exclusive) in steps of -1.
TEST_F(DeclarableOpsTests3, Test_Range_2) {
    auto from  = NDArrayFactory::create<float>('c', {1, 1}, {2});
    auto to    = NDArrayFactory::create<float>('c', {1, 1}, {0.});
    auto delta = NDArrayFactory::create<float>('c', {1, 1}, {-1});
    auto expected = NDArrayFactory::create<float>('c', {2}, {2, 1});

    nd4j::ops::range rangeOp;
    auto *results = rangeOp.execute({&from, &to, &delta}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *output = results->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete results;
}
|
|
|
|
// range op with start/stop/step given as {1, 1}-shaped input arrays,
// counting up from 0 to 2 (exclusive) in steps of 1.
TEST_F(DeclarableOpsTests3, Test_Range_3) {
    auto from  = NDArrayFactory::create<float>('c', {1, 1}, {0.});
    auto to    = NDArrayFactory::create<float>('c', {1, 1}, {2});
    auto delta = NDArrayFactory::create<float>('c', {1, 1}, {1});
    auto expected = NDArrayFactory::create<float>('c', {2}, {0, 1});

    nd4j::ops::range rangeOp;
    auto *results = rangeOp.execute({&from, &to, &delta}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *output = results->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete results;
}
|
|
|
|
|
|
// range op with start/stop/step given as floating-point arguments
// (-10 up towards 10 in steps of 1.666): 13 values expected.
TEST_F(DeclarableOpsTests3, Test_Range_4) {
    auto expected = NDArrayFactory::create<float>('c', {13}, {-10., -8.334, -6.668, -5.002, -3.336, -1.67 , -0.004, 1.662, 3.328, 4.994, 6.66 , 8.326, 9.992});

    nd4j::ops::range rangeOp;
    auto *results = rangeOp.execute({}, {-10., 10., 1.666}, {});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *output = results->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete results;
}
|
|
|
|
|
|
// range op with floating-point arguments {2, 0, -1}: counts down 2, 1.
TEST_F(DeclarableOpsTests3, Test_Range_5) {
    auto expected = NDArrayFactory::create<float>('c', {2}, {2, 1});

    nd4j::ops::range rangeOp;
    auto *results = rangeOp.execute({}, {2, 0, -1}, {});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *output = results->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete results;
}
|
|
|
|
// range op with floating-point arguments {0, 2, 1}: counts up 0, 1.
TEST_F(DeclarableOpsTests3, Test_Range_6) {
    auto expected = NDArrayFactory::create<float>('c', {2}, {0, 1});

    nd4j::ops::range rangeOp;
    auto *results = rangeOp.execute({}, {0, 2, 1}, {});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *output = results->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete results;
}
|
|
|
|
// range op with floating-point arguments (10 down towards -5 in steps of
// -1.666): 10 values expected.
TEST_F(DeclarableOpsTests3, Test_Range_7) {
    auto expected = NDArrayFactory::create<float>('c', {10}, {10., 8.334, 6.668, 5.002, 3.336, 1.67 , 0.004, -1.662, -3.328, -4.994});

    nd4j::ops::range rangeOp;
    auto *results = rangeOp.execute({}, {10,-5,-1.666}, {});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *output = results->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete results;
}
|
|
|
|
|
|
|
|
// range op with integer arguments {2, 0, -1}; the expectation is an int array.
TEST_F(DeclarableOpsTests3, Test_Range_8) {
    auto expected = NDArrayFactory::create<int>('c', {2}, {2, 1});

    nd4j::ops::range rangeOp;
    auto *results = rangeOp.execute({}, {}, {2, 0, -1});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *output = results->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete results;
}
|
|
|
|
// range op with integer arguments {0, 2, 1}; the expectation is an int array.
TEST_F(DeclarableOpsTests3, Test_Range_9) {
    auto expected = NDArrayFactory::create<int>('c', {2}, {0, 1});

    nd4j::ops::range rangeOp;
    auto *results = rangeOp.execute({}, {}, {0, 2, 1});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *output = results->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete results;
}
|
|
|
|
// batched_gemm over three identical (lhs, rhs) pairs of 3x3 'f'-order
// matrices; every output must match MmulHelper::mmul(lhs, rhs).
// NOTE(review): the first two inputs look like per-batch alpha/beta
// coefficients and 111/112 like transpose codes - confirm against the op.
TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_1) {
    auto alphas = NDArrayFactory::create<float>('c', {1, 3}, {1, 1, 1});
    auto betas  = NDArrayFactory::create<float>('c', {1, 3}, {0, 0, 0});
    auto lhs    = NDArrayFactory::create<float>('f', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
    auto rhs    = NDArrayFactory::create<float>('f', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9});

    auto *expected = MmulHelper::mmul(&lhs, &rhs);

    nd4j::ops::batched_gemm gemmOp;
    auto *results = gemmOp.execute({&alphas, &betas, &lhs, &lhs, &lhs, &rhs, &rhs, &rhs}, {}, {111, 111, 3, 3, 3, 3, 3, 3, 3});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    ASSERT_EQ(3, results->size());

    for (int idx = 0; idx < 3; ++idx) {
        auto *output = results->at(idx);

        ASSERT_TRUE(expected->isSameShape(output));
        ASSERT_TRUE(expected->equalsTo(output));
    }

    delete expected;
    delete results;
}
|
|
|
|
// batched_gemm over three identical (lhs, rhs) pairs of 3x3 'c'-order
// matrices with leading integer arguments 112, 112; every output must match
// MmulHelper::mmul(lhs, rhs).
TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_2) {
    auto alphas = NDArrayFactory::create<float>('c', {1, 3}, {1, 1, 1});
    auto betas  = NDArrayFactory::create<float>('c', {1, 3}, {0, 0, 0});
    auto lhs    = NDArrayFactory::create<float>('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
    auto rhs    = NDArrayFactory::create<float>('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9});

    auto *expected = MmulHelper::mmul(&lhs, &rhs);

    nd4j::ops::batched_gemm gemmOp;
    auto *results = gemmOp.execute({&alphas, &betas, &lhs, &lhs, &lhs, &rhs, &rhs, &rhs}, {}, {112, 112, 3, 3, 3, 3, 3, 3, 3});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    ASSERT_EQ(3, results->size());

    for (int idx = 0; idx < 3; ++idx) {
        auto *output = results->at(idx);

        ASSERT_TRUE(expected->isSameShape(output));
        ASSERT_TRUE(expected->equalsTo(output));
    }

    delete expected;
    delete results;
}
|
|
|
|
// batched_gemm mixing a 'c'-order lhs with an 'f'-order rhs (leading integer
// arguments 112, 111); every output must match MmulHelper::mmul(lhs, rhs).
TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_3) {
    auto alphas = NDArrayFactory::create<float>('c', {1, 3}, {1, 1, 1});
    auto betas  = NDArrayFactory::create<float>('c', {1, 3}, {0, 0, 0});
    auto lhs    = NDArrayFactory::create<float>('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
    auto rhs    = NDArrayFactory::create<float>('f', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9});

    auto *expected = MmulHelper::mmul(&lhs, &rhs);

    nd4j::ops::batched_gemm gemmOp;
    auto *results = gemmOp.execute({&alphas, &betas, &lhs, &lhs, &lhs, &rhs, &rhs, &rhs}, {}, {112, 111, 3, 3, 3, 3, 3, 3, 3});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    ASSERT_EQ(3, results->size());

    for (int idx = 0; idx < 3; ++idx) {
        auto *output = results->at(idx);

        ASSERT_TRUE(expected->isSameShape(output));
        ASSERT_TRUE(expected->equalsTo(output));
    }

    delete expected;
    delete results;
}
|
|
|
|
// batched_gemm on non-square 'f'-order operands (5x3 times 3x4); every
// output must match MmulHelper::mmul(lhs, rhs).
TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_4) {
    auto alphas = NDArrayFactory::create<float>('c', {1, 3}, {1, 1, 1});
    auto betas  = NDArrayFactory::create<float>('c', {1, 3}, {0, 0, 0});
    auto lhs    = NDArrayFactory::create<float>('f', {5, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
    auto rhs    = NDArrayFactory::create<float>('f', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});

    auto *expected = MmulHelper::mmul(&lhs, &rhs);

    nd4j::ops::batched_gemm gemmOp;
    auto *results = gemmOp.execute({&alphas, &betas, &lhs, &lhs, &lhs, &rhs, &rhs, &rhs}, {}, {111, 111, 5, 4, 3, 5, 3, 5, 3});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    ASSERT_EQ(3, results->size());

    for (int idx = 0; idx < 3; ++idx) {
        auto *output = results->at(idx);

        ASSERT_TRUE(expected->isSameShape(output));
        ASSERT_TRUE(expected->equalsTo(output));
    }

    delete expected;
    delete results;
}
|
|
|
|
// batched_gemm on non-square 'c'-order operands (5x3 times 3x4); every
// output must match MmulHelper::mmul(lhs, rhs).
TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_5) {
    auto alphas = NDArrayFactory::create<float>('c', {1, 3}, {1, 1, 1});
    auto betas  = NDArrayFactory::create<float>('c', {1, 3}, {0, 0, 0});
    auto lhs    = NDArrayFactory::create<float>('c', {5, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
    auto rhs    = NDArrayFactory::create<float>('c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});

    auto *expected = MmulHelper::mmul(&lhs, &rhs);

    nd4j::ops::batched_gemm gemmOp;
    auto *results = gemmOp.execute({&alphas, &betas, &lhs, &lhs, &lhs, &rhs, &rhs, &rhs}, {}, {112, 112, 5, 4, 3, 3, 4, 5, 3});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    ASSERT_EQ(3, results->size());

    for (int idx = 0; idx < 3; ++idx) {
        auto *output = results->at(idx);

        ASSERT_TRUE(expected->isSameShape(output));
        ASSERT_TRUE(expected->equalsTo(output));
    }

    delete expected;
    delete results;
}
|
|
|
|
|
|
// batched_gemm on non-square 'f'-order operands (2x5 times 5x3); every
// output must match MmulHelper::mmul(lhs, rhs).
TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_6) {
    auto alphas = NDArrayFactory::create<float>('c', {1, 3}, {1, 1, 1});
    auto betas  = NDArrayFactory::create<float>('c', {1, 3}, {0, 0, 0});
    auto lhs    = NDArrayFactory::create<float>('f', {2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
    auto rhs    = NDArrayFactory::create<float>('f', {5, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});

    auto *expected = MmulHelper::mmul(&lhs, &rhs);

    nd4j::ops::batched_gemm gemmOp;
    auto *results = gemmOp.execute({&alphas, &betas, &lhs, &lhs, &lhs, &rhs, &rhs, &rhs}, {}, {111, 111, 2, 3, 5, 2, 5, 2, 3});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    ASSERT_EQ(3, results->size());

    for (int idx = 0; idx < 3; ++idx) {
        auto *output = results->at(idx);

        ASSERT_TRUE(expected->isSameShape(output));
        ASSERT_TRUE(expected->equalsTo(output));
    }

    delete expected;
    delete results;
}
|
|
|
|
// batched_gemm on non-square 'c'-order operands (2x5 times 5x3); every
// output must match MmulHelper::mmul(x, y).
TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_7) {
    auto a = NDArrayFactory::create<float>('c', {1, 3}, {1, 1, 1});
    auto b = NDArrayFactory::create<float>('c', {1, 3}, {0, 0, 0});
    auto x = NDArrayFactory::create<float>('c', {2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
    auto y = NDArrayFactory::create<float>('c', {5, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});

    // The stray shape print of `exp` was dropped: none of the sibling gemm
    // tests print, and the shape is asserted for each output below.
    auto exp = MmulHelper::mmul(&x, &y);

    nd4j::ops::batched_gemm op;
    auto result = op.execute({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 2, 3, 5, 5, 3, 2, 3});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());

    ASSERT_EQ(3, result->size());

    for (int e = 0; e < 3; e++) {
        auto z = result->at(e);

        ASSERT_TRUE(exp->isSameShape(z));
        ASSERT_TRUE(exp->equalsTo(z));
    }

    delete exp;
    delete result;
}
|
|
|
|
// batched_gemm fed inputs of mixed data types (float, double, float16):
// execute() is expected to throw std::invalid_argument. Reaching the line
// after the call means nothing was thrown, which fails the test.
TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_Validation_1) {
    auto alphas = NDArrayFactory::create<float>('c', {1, 3}, {1, 1, 1});
    auto betas  = NDArrayFactory::create<double>('c', {1, 3}, {0, 0, 0});
    auto lhs    = NDArrayFactory::create<float16>('c', {2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
    auto rhs    = NDArrayFactory::create<float>('c', {5, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});

    nd4j::ops::batched_gemm gemmOp;
    try {
        auto *results = gemmOp.execute({&alphas, &betas, &lhs, &lhs, &lhs, &rhs, &rhs, &rhs}, {}, {112, 112, 2, 3, 5, 5, 3, 2, 3});
        delete results;
        ASSERT_TRUE(false);
    } catch (const std::invalid_argument &) {
        // expected: validation rejected the inputs
    }
}
|
|
|
|
// batched_gemm with float inputs but a caller-supplied double output array:
// execute() is expected to throw std::invalid_argument. Reaching the line
// after the call means nothing was thrown, which fails the test.
TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_Validation_2) {
    auto alphas = NDArrayFactory::create<float>('c', {1, 3}, {1, 1, 1});
    auto betas  = NDArrayFactory::create<float>('c', {1, 3}, {0, 0, 0});
    auto lhs    = NDArrayFactory::create<float>('c', {2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
    auto rhs    = NDArrayFactory::create<float>('c', {5, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});

    auto output = NDArrayFactory::create<double>('c', {2, 3});

    nd4j::ops::batched_gemm gemmOp;
    try {
        auto outcome = gemmOp.execute({&alphas, &betas, &lhs, &lhs, &lhs, &rhs, &rhs, &rhs}, {&output}, {}, {112, 112, 2, 3, 5, 5, 3, 2, 3}, {});
        ASSERT_TRUE(false);
    } catch (const std::invalid_argument &) {
        // expected: validation rejected the output type
    }
}
|
|
|
|
// matmul of a 3x4 and a 4x3 operand with integer arguments {1, 1}; expects a
// 4x4 result. NOTE(review): the two arguments look like transpose flags for
// the left/right operand - confirm against the op.
TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_1) {
    auto lhs = NDArrayFactory::create<float>('c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8 , 9, 10, 11, 12});
    auto rhs = NDArrayFactory::create<float>('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8 , 9, 10, 11, 12});
    auto expected = NDArrayFactory::create<float>('f', {4, 4}, {38.0, 44.0, 50.0, 56.0, 83.0, 98.0, 113.0, 128.0, 128.0, 152.0, 176.0, 200.0, 173.0, 206.0, 239.0, 272.0});

    nd4j::ops::matmul matmulOp;
    auto *results = matmulOp.execute({&lhs, &rhs}, {}, {1, 1});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *output = results->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete results;
}
|
|
|
|
|
|
// matmul of a 3x4 and a 4x3 operand with integer arguments {0, 0}; expects
// the plain 3x3 product.
TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_2) {
    auto lhs = NDArrayFactory::create<float>('c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8 , 9, 10, 11, 12});
    auto rhs = NDArrayFactory::create<float>('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8 , 9, 10, 11, 12});
    auto expected = NDArrayFactory::create<float>('f', {3, 3}, {70.0, 158.0, 246.0, 80.0, 184.0, 288.0, 90.0, 210.0, 330.0});

    nd4j::ops::matmul matmulOp;
    auto *results = matmulOp.execute({&lhs, &rhs}, {}, {0, 0});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *output = results->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete results;
}
|
|
|
|
|
|
// matmul of a 1x3 and a 1x4 operand with integer arguments {1, 0}; expects
// the 3x4 outer product of the two vectors.
TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_3) {
    auto lhs = NDArrayFactory::create<float>('c', {1, 3}, {1, 2, 3});
    auto rhs = NDArrayFactory::create<float>('c', {1, 4}, {1, 2, 3, 4});
    auto expected = NDArrayFactory::create<float>('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0});

    nd4j::ops::matmul matmulOp;
    auto *results = matmulOp.execute({&lhs, &rhs}, {}, {1, 0});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *output = results->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete results;
}
|
|
|
|
// matmul of a 3x1 and a 4x1 operand with integer arguments {0, 1}; expects
// the 3x4 outer product of the two vectors.
TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_4) {
    auto lhs = NDArrayFactory::create<float>('c', {3, 1}, {1, 2, 3});
    auto rhs = NDArrayFactory::create<float>('c', {4, 1}, {1, 2, 3, 4});
    auto expected = NDArrayFactory::create<float>('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0});

    nd4j::ops::matmul matmulOp;
    auto *results = matmulOp.execute({&lhs, &rhs}, {}, {0, 1});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *output = results->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete results;
}
|
|
|
|
// matmul of a 3x1 column and a 1x4 row with no integer arguments; expects
// the 3x4 outer product.
TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_5) {
    auto lhs = NDArrayFactory::create<float>('c', {3, 1}, {1, 2, 3});
    auto rhs = NDArrayFactory::create<float>('c', {1, 4}, {1, 2, 3, 4});
    auto expected = NDArrayFactory::create<float>('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0});

    nd4j::ops::matmul matmulOp;
    auto *results = matmulOp.execute({&lhs, &rhs}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *output = results->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete results;
}
|
|
|
|
// matmul of a 4x1 column and a 1x4 row with no integer arguments; expects
// the 4x4 outer product.
TEST_F(DeclarableOpsTests3, Test_Manual_Gemm_6) {
    auto lhs = NDArrayFactory::create<float>('c', {4, 1}, {1, 2, 3, 4});
    auto rhs = NDArrayFactory::create<float>('c', {1, 4}, {1, 2, 3, 4});
    auto expected = NDArrayFactory::create<float>('f', {4, 4}, {1,2, 3, 4,2,4, 6, 8,3,6, 9,12,4,8,12,16});

    nd4j::ops::matmul matmulOp;
    auto *results = matmulOp.execute({&lhs, &rhs}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *output = results->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete results;
}
|
|
|
|
// Smoke test for avgpool2d on a {2, 10, 10, 3} input filled by linspace.
// Previously the test asserted nothing; it now at least requires the op to
// report success and to produce an output array. No expected values are
// pinned here.
TEST_F(DeclarableOpsTests3, Test_AvgPool_1) {
    auto x = NDArrayFactory::create<float>('c', {2, 10, 10, 3});
    x.linspace(1);

    nd4j::ops::avgpool2d op;
    //                                  kY kX sY sX pY pX dY dX M  P
    auto result = op.execute({&x}, {}, {3, 3, 3, 3, 0, 0, 1, 1, 1, 0, 1});
    //                                  0  1  2  3  4  5  6  7  8  9  10
    ASSERT_EQ(ND4J_STATUS_OK, result->status());

    auto z = result->at(0);
    ASSERT_TRUE(z != nullptr);

    delete result;
}
|
|
|
|
// reversedivide op: with first input {2, 2, 2} and second input {4, 6, 8}
// the expected output is {2, 3, 4}, i.e. second / first element-wise.
TEST_F(DeclarableOpsTests3, Test_ReverseDivide_1) {
    auto divisor  = NDArrayFactory::create<float>('c', {1, 3}, {2, 2, 2});
    auto dividend = NDArrayFactory::create<float>('c', {1, 3}, {4, 6, 8});
    auto expected = NDArrayFactory::create<float>('c', {1, 3}, {2, 3, 4});

    nd4j::ops::reversedivide reverseDivideOp;
    auto *results = reverseDivideOp.execute({&divisor, &dividend}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *output = results->at(0);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete results;
}
|
|
|
|
////////////////////////////////////////////////////////////////////
|
|
// sruCell with constant-filled inputs (xt = 1, ct_1 = 2, w = 0.5, b = 0.7);
// checks both outputs (index 0 against expHt, index 1 against expCt).
TEST_F(DeclarableOpsTests3, sruCell_test1) {
    const int bS  = 2;   // batch size
    const int nIn = 5;   // input size

    auto input    = NDArrayFactory::create<float>('c', {bS, nIn});
    auto prevCell = NDArrayFactory::create<float>('c', {bS, nIn});
    auto weights  = NDArrayFactory::create<float>('c', {nIn, 3*nIn});
    auto biases   = NDArrayFactory::create<float>('c', {2*nIn});

    input.assign(1.);
    prevCell.assign(2.);
    weights.assign(0.5);
    biases.assign(0.7);

    auto expectedH = NDArrayFactory::create<float>('c', {bS, nIn}, {0.96674103,0.96674103,0.96674103,0.96674103,0.96674103,0.96674103,0.96674103,0.96674103,0.96674103,0.96674103});
    auto expectedC = NDArrayFactory::create<float>('c', {bS, nIn}, {2.01958286,2.01958286,2.01958286,2.01958286,2.01958286, 2.01958286,2.01958286,2.01958286,2.01958286,2.01958286});

    nd4j::ops::sruCell cellOp;
    auto *outputs = cellOp.execute({&input, &prevCell, &weights, &biases}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *outH = outputs->at(0);
    auto *outC = outputs->at(1);

    ASSERT_TRUE(expectedH.isSameShape(outH));
    ASSERT_TRUE(expectedH.equalsTo(outH));
    ASSERT_TRUE(expectedC.isSameShape(outC));
    ASSERT_TRUE(expectedC.equalsTo(outC));

    delete outputs;
}
|
|
|
|
|
|
////////////////////////////////////////////////////////////////////
|
|
// sruCell with constant-filled inputs (xt = 1, ct_1 = 2, w = 0.5, b = -1);
// checks both outputs (index 0 against expHt, index 1 against expCt).
TEST_F(DeclarableOpsTests3, sruCell_test2) {
    const int bS  = 2;   // batch size
    const int nIn = 5;   // input size

    auto input    = NDArrayFactory::create<float>('c', {bS, nIn});
    auto prevCell = NDArrayFactory::create<float>('c', {bS, nIn});
    auto weights  = NDArrayFactory::create<float>('c', {nIn, 3*nIn});
    auto biases   = NDArrayFactory::create<float>('c', {2*nIn});

    input.assign(1.);
    prevCell.assign(2.);
    weights.assign(0.5);
    biases.assign(-1.);

    auto expectedH = NDArrayFactory::create<float>('c', {bS, nIn}, {0.97542038,0.97542038,0.97542038,0.97542038,0.97542038,0.97542038,0.97542038,0.97542038,0.97542038,0.97542038});
    auto expectedC = NDArrayFactory::create<float>('c', {bS, nIn}, {2.09121276,2.09121276,2.09121276,2.09121276,2.09121276,2.09121276,2.09121276,2.09121276,2.09121276,2.09121276});

    nd4j::ops::sruCell cellOp;
    auto *outputs = cellOp.execute({&input, &prevCell, &weights, &biases}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *outH = outputs->at(0);
    auto *outC = outputs->at(1);

    ASSERT_TRUE(expectedH.isSameShape(outH));
    ASSERT_TRUE(expectedH.equalsTo(outH));
    ASSERT_TRUE(expectedC.isSameShape(outC));
    ASSERT_TRUE(expectedC.equalsTo(outC));

    delete outputs;
}
|
|
|
|
////////////////////////////////////////////////////////////////////
|
|
// sruCell with constant-filled inputs (xt = 10, ct_1 = 1, w = 0.5, b = -1);
// checks both outputs (index 0 against expHt, index 1 against expCt).
TEST_F(DeclarableOpsTests3, sruCell_test3) {
    const int bS  = 2;   // batch size
    const int nIn = 5;   // input size

    auto input    = NDArrayFactory::create<float>('c', {bS, nIn});
    auto prevCell = NDArrayFactory::create<float>('c', {bS, nIn});
    auto weights  = NDArrayFactory::create<float>('c', {nIn, 3*nIn});
    auto biases   = NDArrayFactory::create<float>('c', {2*nIn});

    input.assign(10.);
    prevCell.assign(1.);
    weights.assign(0.5);
    biases.assign(-1.);

    auto expectedH = NDArrayFactory::create<float>('c', {bS, nIn}, {0.76159416,0.76159416,0.76159416,0.76159416,0.76159416,0.76159416,0.76159416,0.76159416,0.76159416,0.76159416});
    auto expectedC = NDArrayFactory::create<float>('c', {bS, nIn}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.});

    nd4j::ops::sruCell cellOp;
    auto *outputs = cellOp.execute({&input, &prevCell, &weights, &biases}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto *outH = outputs->at(0);
    auto *outC = outputs->at(1);

    ASSERT_TRUE(expectedH.isSameShape(outH));
    ASSERT_TRUE(expectedH.equalsTo(outH));
    ASSERT_TRUE(expectedC.isSameShape(outC));
    ASSERT_TRUE(expectedC.equalsTo(outC));

    delete outputs;
}
|
|
|
|
////////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, gruCell_test1) {
    // Runs gruCell on constant-filled operands and checks the array at
    // index 3 of the result set against a precomputed constant tensor.
    const int batchSize = 2;
    const int inSize    = 10;
    const int numUnits  = 4;

    auto expHt = NDArrayFactory::create<float>('c', {batchSize, numUnits},
        {1.99993872,1.99993872,1.99993872,1.99993872,
         1.99993872,1.99993872,1.99993872,1.99993872});

    auto xt = NDArrayFactory::create<float>('c', {batchSize, inSize});
    xt.assign(1.);

    auto ht_1 = NDArrayFactory::create<float>('c', {batchSize, numUnits});
    ht_1.assign(2.);

    auto Wru = NDArrayFactory::create<float>('c', {(inSize+numUnits), 2*numUnits});
    Wru.assign(0.5);

    auto Wc = NDArrayFactory::create<float>('c', {(inSize+numUnits), numUnits});
    Wc.assign(0.5);

    auto bru = NDArrayFactory::create<float>('c', {2*numUnits});
    bru.assign(0.7);

    auto bc = NDArrayFactory::create<float>('c', {numUnits});
    bc.assign(0.7);

    nd4j::ops::gruCell cell;
    auto outcome = cell.execute({&xt, &ht_1, &Wru, &Wc, &bru, &bc}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    // Only the fourth returned array is verified by this test.
    auto *actualHt = outcome->at(3);

    ASSERT_TRUE(expHt.isSameShape(actualHt));
    ASSERT_TRUE(expHt.equalsTo(actualHt));

    delete outcome;
}
|
|
|
|
////////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, gruCell_test2) {
    // gruCell with ht_1 filled with zeros, weights of 1.5 and biases of -10;
    // the array at result index 3 is compared with a precomputed tensor.
    const int batchSize = 2;
    const int inSize    = 10;
    const int numUnits  = 4;

    auto expHt = NDArrayFactory::create<float>('c', {batchSize, numUnits},
        {0.00669224,0.00669224,0.00669224,0.00669224,
         0.00669224,0.00669224,0.00669224,0.00669224});

    auto xt = NDArrayFactory::create<float>('c', {batchSize, inSize});
    xt.assign(1.);

    auto ht_1 = NDArrayFactory::create<float>('c', {batchSize, numUnits});
    ht_1.assign(0.);

    auto Wru = NDArrayFactory::create<float>('c', {(inSize+numUnits), 2*numUnits});
    Wru.assign(1.5);

    auto Wc = NDArrayFactory::create<float>('c', {(inSize+numUnits), numUnits});
    Wc.assign(1.5);

    auto bru = NDArrayFactory::create<float>('c', {2*numUnits});
    bru.assign(-10);

    auto bc = NDArrayFactory::create<float>('c', {numUnits});
    bc.assign(-10);

    nd4j::ops::gruCell cell;
    auto outcome = cell.execute({&xt, &ht_1, &Wru, &Wc, &bru, &bc}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *actualHt = outcome->at(3);

    ASSERT_TRUE(expHt.isSameShape(actualHt));
    ASSERT_TRUE(expHt.equalsTo(actualHt));

    delete outcome;
}
|
|
|
|
////////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, gruCell_test3) {
    // gruCell with ht_1 filled with zeros, weights of 0.1 and biases of 1;
    // the array at result index 3 is compared with a precomputed tensor.
    const int batchSize = 2;
    const int inSize    = 10;
    const int numUnits  = 4;

    auto expHt = NDArrayFactory::create<float>('c', {batchSize, numUnits},
        {0.1149149,0.1149149,0.1149149,0.1149149,
         0.1149149,0.1149149,0.1149149,0.1149149});

    auto xt = NDArrayFactory::create<float>('c', {batchSize, inSize});
    xt.assign(1.);

    auto ht_1 = NDArrayFactory::create<float>('c', {batchSize, numUnits});
    ht_1.assign(0.);

    auto Wru = NDArrayFactory::create<float>('c', {(inSize+numUnits), 2*numUnits});
    Wru.assign(0.1);

    auto Wc = NDArrayFactory::create<float>('c', {(inSize+numUnits), numUnits});
    Wc.assign(0.1);

    auto bru = NDArrayFactory::create<float>('c', {2*numUnits});
    bru.assign(1);

    auto bc = NDArrayFactory::create<float>('c', {numUnits});
    bc.assign(1);

    nd4j::ops::gruCell cell;
    auto outcome = cell.execute({&xt, &ht_1, &Wru, &Wc, &bru, &bc}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *actualHt = outcome->at(3);

    ASSERT_TRUE(expHt.isSameShape(actualHt));
    ASSERT_TRUE(expHt.equalsTo(actualHt));

    delete outcome;
}
|
|
|
|
////////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, invertPermutation_test1) {
    // For the permutation p = input, the expected array q satisfies
    // q[p[i]] == i for every index i.
    auto expected = NDArrayFactory::create<float>('c', {1, 8}, {7, 6, 1, 5, 3, 0, 4, 2});
    auto input    = NDArrayFactory::create<float>('c', {1, 8}, {5,2,7,4,6,3,1,0});

    nd4j::ops::invert_permutation inverter;
    auto outcome = inverter.execute({&input}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *inverted = outcome->at(0);

    ASSERT_TRUE(expected.isSameShape(inverted));
    ASSERT_TRUE(expected.equalsTo(inverted));

    delete outcome;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, invertPermutation_test2) {
    // NOTE(review): input and expected data are identical to
    // invertPermutation_test1, so this case adds no extra coverage.
    auto expected = NDArrayFactory::create<float>('c', {1, 8}, {7, 6, 1, 5, 3, 0, 4, 2});
    auto input    = NDArrayFactory::create<float>('c', {1, 8}, {5,2,7,4,6,3,1,0});

    nd4j::ops::invert_permutation inverter;
    auto outcome = inverter.execute({&input}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *inverted = outcome->at(0);

    ASSERT_TRUE(expected.isSameShape(inverted));
    ASSERT_TRUE(expected.equalsTo(inverted));

    delete outcome;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, invertPermutation_test3) {
    // A second permutation; the expected array q again satisfies
    // q[p[i]] == i for p = input.
    auto expected = NDArrayFactory::create<float>('c', {1, 8}, {2, 0, 1, 5, 3, 6, 4, 7});
    auto input    = NDArrayFactory::create<float>('c', {1, 8}, {1,2,0,4,6,3,5,7});

    nd4j::ops::invert_permutation inverter;
    auto outcome = inverter.execute({&input}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *inverted = outcome->at(0);

    ASSERT_TRUE(expected.isSameShape(inverted));
    ASSERT_TRUE(expected.equalsTo(inverted));

    delete outcome;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, diag_test1) {
    // A {3,2} input holding 1..6 is expected to yield a {3,2,3,2} tensor
    // whose flattened 6x6 view carries the inputs on its diagonal.
    auto expected = NDArrayFactory::create<float>('c', {3,2,3,2},
        {1,0,0,0,0,0,
         0,2,0,0,0,0,
         0,0,3,0,0,0,
         0,0,0,4,0,0,
         0,0,0,0,5,0,
         0,0,0,0,0,6});

    auto input = NDArrayFactory::create<float>('c', {3, 2});
    input.linspace(1);

    nd4j::ops::diag diagOp;
    auto outcome = diagOp.execute({&input}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *actual = outcome->at(0);

    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete outcome;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, diag_test2) {
    // Same data as a {2,3} input; the expected result has shape {2,3,2,3}.
    auto expected = NDArrayFactory::create<float>('c', {2,3,2,3},
        {1,0,0,0,0,0,
         0,2,0,0,0,0,
         0,0,3,0,0,0,
         0,0,0,4,0,0,
         0,0,0,0,5,0,
         0,0,0,0,0,6});

    auto input = NDArrayFactory::create<float>('c', {2, 3});
    input.linspace(1);

    nd4j::ops::diag diagOp;
    auto outcome = diagOp.execute({&input}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *actual = outcome->at(0);

    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete outcome;
}
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, diag_test_vector) {
    // The input comes from NDArrayFactory::linspace, which hands back a
    // pointer; it is passed to the op as-is and deleted at the end.
    auto expected = NDArrayFactory::create<float>('c', {4,4},
        {1,0,0,0,
         0,2,0,0,
         0,0,3,0,
         0,0,0,4});

    auto input = NDArrayFactory::linspace<float>(1,4,4);

    nd4j::ops::diag diagOp;
    auto outcome = diagOp.execute({input}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto actual = outcome->at(0);

    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete outcome;
    delete input;
}
|
|
|
|
|
|
|
|
|
|
TEST_F(DeclarableOpsTests3, diag_test_col_vector) {
    // Like diag_test_vector, but the linspace result is reshaped in place
    // to {4,1} before being handed to the op.
    auto expected = NDArrayFactory::create<float>('c', {4,4},
        {1,0,0,0,
         0,2,0,0,
         0,0,3,0,
         0,0,0,4});

    auto input = NDArrayFactory::linspace<float>(1,4,4);
    input->reshapei({4,1});

    nd4j::ops::diag diagOp;
    auto outcome = diagOp.execute({input}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *actual = outcome->at(0);

    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete outcome;
    delete input;
}
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, diag_test3) {
    // A {1,3} input holding 1..3 is expected to give a {3,3} matrix with
    // those values on the diagonal.
    auto expected = NDArrayFactory::create<float>('c', {3,3},
        {1,0,0,
         0,2,0,
         0,0,3});

    auto input = NDArrayFactory::create<float>('c', {1, 3});
    input.linspace(1);

    nd4j::ops::diag diagOp;
    auto outcome = diagOp.execute({&input}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *actual = outcome->at(0);

    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete outcome;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, diag_test4) {
    // A {3,1} input holding 1..3; the expected result is the same {3,3}
    // matrix that diag_test3 uses for its {1,3} input.
    auto expected = NDArrayFactory::create<float>('c', {3,3},
        {1,0,0,
         0,2,0,
         0,0,3});

    auto input = NDArrayFactory::create<float>('c', {3, 1});
    input.linspace(1);

    nd4j::ops::diag diagOp;
    auto outcome = diagOp.execute({&input}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *actual = outcome->at(0);

    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete outcome;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, diag_test5) {
    // Single-element case: a {1,1} input holding 2 is expected to come back
    // as a {1,1} array holding 2.
    auto expected = NDArrayFactory::create<float>('c', {1,1}, {2});

    auto input = NDArrayFactory::create<float>('c', {1, 1});
    input.linspace(2);

    nd4j::ops::diag diagOp;
    auto outcome = diagOp.execute({&input}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *actual = outcome->at(0);

    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete outcome;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, diag_test6) {
    // Rank-3 input {2,2,2} holding 1..8; the expected rank-6 result, viewed
    // as an 8x8 matrix, has the inputs on its diagonal.
    auto expected = NDArrayFactory::create<float>('c', {2,2,2,2,2,2},
        {1,0,0,0, 0,0,0,0,
         0,2,0,0, 0,0,0,0,
         0,0,3,0, 0,0,0,0,
         0,0,0,4, 0,0,0,0,
         0,0,0,0, 5,0,0,0,
         0,0,0,0, 0,6,0,0,
         0,0,0,0, 0,0,7,0,
         0,0,0,0, 0,0,0,8});

    auto input = NDArrayFactory::create<float>('c', {2,2,2});
    input.linspace(1);

    nd4j::ops::diag diagOp;
    auto outcome = diagOp.execute({&input}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *actual = outcome->at(0);

    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete outcome;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, matrixSetDiag_test1) {
    // Zero-filled {4,3,2} input and a ones-filled {4,2} diagonal; each of
    // the four expected 3x2 slices has ones at (0,0) and (1,1).
    auto expected = NDArrayFactory::create<float>('c', {4,3,2},
        {1,0,0,1,0,0,
         1,0,0,1,0,0,
         1,0,0,1,0,0,
         1,0,0,1,0,0});

    auto input = NDArrayFactory::create<float>('c', {4,3,2});
    input.assign(0.);

    auto diagonal = NDArrayFactory::create<float>('c', {4,2});
    diagonal.assign(1.);

    nd4j::ops::matrix_set_diag setDiag;
    auto outcome = setDiag.execute({&input, &diagonal}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *actual = outcome->at(0);

    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete outcome;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, matrixSetDiag_test2) {
    // Zero-filled {1,1,2} input with a ones-filled {1,1} diagonal.
    auto expected = NDArrayFactory::create<float>('c', {1,1,2}, {1,0});

    auto input = NDArrayFactory::create<float>('c', {1,1,2});
    input.assign(0.);

    auto diagonal = NDArrayFactory::create<float>('c', {1,1});
    diagonal.assign(1.);

    nd4j::ops::matrix_set_diag setDiag;
    auto outcome = setDiag.execute({&input, &diagonal}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *actual = outcome->at(0);

    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete outcome;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, matrixSetDiag_test3) {
    // Zero-filled {2,1,4} input with a ones-filled {2,1} diagonal; each
    // expected 1x4 slice has a one in its first position only.
    auto expected = NDArrayFactory::create<float>('c', {2,1,4},
        {1,0,0,0,
         1,0,0,0});

    auto input = NDArrayFactory::create<float>('c', {2,1,4});
    input.assign(0.);

    auto diagonal = NDArrayFactory::create<float>('c', {2,1});
    diagonal.assign(1.);

    nd4j::ops::matrix_set_diag setDiag;
    auto outcome = setDiag.execute({&input, &diagonal}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *actual = outcome->at(0);

    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete outcome;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, matrixSetDiag_test4) {
    // Rank-4 variant: zero-filled {2,1,4,1} input with a ones-filled
    // {2,1,1} diagonal.
    auto expected = NDArrayFactory::create<float>('c', {2,1,4,1},
        {1,0,0,0,
         1,0,0,0});

    auto input = NDArrayFactory::create<float>('c', {2,1,4,1});
    input.assign(0.);

    auto diagonal = NDArrayFactory::create<float>('c', {2,1,1});
    diagonal.assign(1.);

    nd4j::ops::matrix_set_diag setDiag;
    auto outcome = setDiag.execute({&input, &diagonal}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *actual = outcome->at(0);

    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete outcome;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, diagPart_test1) {
    // {2,2} input holding 1..4; the expected output is its diagonal {1,4}.
    auto expected = NDArrayFactory::create<float>('c', {2}, {1,4});

    auto input = NDArrayFactory::create<float>('c', {2,2});
    input.linspace(1);

    nd4j::ops::diag_part diagPart;
    auto outcome = diagPart.execute({&input}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *actual = outcome->at(0);

    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete outcome;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, diagPart_test2) {
    // {2,2,2,2} input holding 1..16; the expected {2,2} output takes every
    // fifth element of the flattened input (1, 6, 11, 16).
    auto expected = NDArrayFactory::create<float>('c', {2,2}, {1,6,11,16});

    auto input = NDArrayFactory::create<float>('c', {2,2,2,2});
    input.linspace(1);

    nd4j::ops::diag_part diagPart;
    auto outcome = diagPart.execute({&input}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *actual = outcome->at(0);

    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete outcome;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, diagPart_test3) {
    // Rank-6 input holding 1..64; the expected {2,2,2} output takes every
    // ninth element of the flattened input (1, 10, ..., 64).
    auto expected = NDArrayFactory::create<float>('c', {2,2,2}, {1,10,19,28,37,46,55,64});

    auto input = NDArrayFactory::create<float>('c', {2,2,2,2,2,2});
    input.linspace(1);

    nd4j::ops::diag_part diagPart;
    auto outcome = diagPart.execute({&input}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *actual = outcome->at(0);

    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete outcome;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, betainc_test1) {
    // float16 variant: a and b run 0.1, 0.2, ... in steps of 0.1 and x is
    // 0.1 everywhere. The comparison uses a tolerance of 1e-2.
    auto expected = NDArrayFactory::create<float16>('c', {3,3},
        {0.40638509,0.33668978,0.28271242,
         0.23973916,0.20483276,0.17604725,
         0.15203027,0.13180567,0.114647});

    auto a = NDArrayFactory::create<float16>('c', {3,3});
    a.linspace((float16)0.1, (float16)0.1);

    auto b = NDArrayFactory::create<float16>('c', {3,3});
    b.linspace((float16)0.1, (float16)0.1);

    auto x = NDArrayFactory::create<float16>('c', {3,3});
    x.assign(0.1);

    nd4j::ops::betainc incBeta;
    auto outcome = incBeta.execute({&a, &b, &x}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto actual = outcome->at(0);

    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual, 1e-2));

    delete outcome;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, betainc_test2) {
    // float variant: a and b run 0.1, 0.2, ... in steps of 0.1 and x is 0.1
    // everywhere; compared with equalsTo's default tolerance.
    auto expected = NDArrayFactory::create<float>('c', {3,3},
        {0.40638509,0.33668978,0.28271242,
         0.23973916,0.20483276,0.17604725,
         0.15203027,0.13180567,0.114647});

    auto a = NDArrayFactory::create<float>('c', {3,3});
    a.linspace(0.1, 0.1);

    auto b = NDArrayFactory::create<float>('c', {3,3});
    b.linspace(0.1, 0.1);

    auto x = NDArrayFactory::create<float>('c', {3,3});
    x.assign(0.1);

    nd4j::ops::betainc incBeta;
    auto outcome = incBeta.execute({&a, &b, &x}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *actual = outcome->at(0);

    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete outcome;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, betainc_test3) {
    // NOTE(review): inputs and expected values match betainc_test2 exactly,
    // so this case adds no extra coverage.
    auto expected = NDArrayFactory::create<float>('c', {3,3},
        {0.40638509,0.33668978,0.28271242,
         0.23973916,0.20483276,0.17604725,
         0.15203027,0.13180567,0.114647});

    auto a = NDArrayFactory::create<float>('c', {3,3});
    a.linspace(0.1, 0.1);

    auto b = NDArrayFactory::create<float>('c', {3,3});
    b.linspace(0.1, 0.1);

    auto x = NDArrayFactory::create<float>('c', {3,3});
    x.assign(0.1);

    nd4j::ops::betainc incBeta;
    auto outcome = incBeta.execute({&a, &b, &x}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *actual = outcome->at(0);

    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete outcome;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, betainc_test4) {
    // a and b are filled by linspace(1), x is 0.1 everywhere; compared with
    // a tolerance of 1e-6.
    auto expected = NDArrayFactory::create<float>('c', {3,3},
        {1.00000000e-01,2.80000000e-02,8.56000000e-03,
         2.72800000e-03,8.90920000e-04,2.95706080e-04,
         9.92854864e-05,3.36248880e-05,1.14644360e-05});

    auto a = NDArrayFactory::create<float>('c', {3,3});
    a.linspace(1);

    auto b = NDArrayFactory::create<float>('c', {3,3});
    b.linspace(1);

    auto x = NDArrayFactory::create<float>('c', {3,3});
    x.assign(0.1);

    nd4j::ops::betainc incBeta;
    auto outcome = incBeta.execute({&a, &b, &x}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *actual = outcome->at(0);

    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual, 1e-6));

    delete outcome;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, betainc_test5) {
    // Large parameters: a and b are filled by linspace(3200.), x is 0.1;
    // the expected result is all zeros within a tolerance of 1e-6.
    auto expected = NDArrayFactory::create<float>('c', {3,3},
        {0.,0.,0.,
         0.,0.,0.,
         0.,0.,0.});

    auto a = NDArrayFactory::create<float>('c', {3,3});
    a.linspace(3200.);

    auto b = NDArrayFactory::create<float>('c', {3,3});
    b.linspace(3200.);

    auto x = NDArrayFactory::create<float>('c', {3,3});
    x.assign(0.1);

    nd4j::ops::betainc incBeta;
    auto outcome = incBeta.execute({&a, &b, &x}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *actual = outcome->at(0);

    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual, 1e-6));

    delete outcome;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, betainc_test6) {
    // a and b are filled by linspace(10.), x is 0.1 everywhere; compared
    // with a tolerance of 1e-6.
    auto expected = NDArrayFactory::create<float>('c', {3,3},
        {3.92988233e-06,1.35306497e-06,4.67576826e-07,
         1.62083416e-07,5.63356971e-08,1.96261318e-08,
         6.85120307e-09,2.39594668e-09,8.39227685e-10});

    auto a = NDArrayFactory::create<float>('c', {3,3});
    a.linspace(10.);

    auto b = NDArrayFactory::create<float>('c', {3,3});
    b.linspace(10.);

    auto x = NDArrayFactory::create<float>('c', {3,3});
    x.assign(0.1);

    nd4j::ops::betainc incBeta;
    auto outcome = incBeta.execute({&a, &b, &x}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *actual = outcome->at(0);

    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual, 1e-6));

    delete outcome;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, betainc_test7) {
    // Same a and b as betainc_test6 (linspace(10.)), with x set to 0.9;
    // compared with a tolerance of 1e-6.
    auto expected = NDArrayFactory::create<float>('c', {3,3},
        {0.99999607,0.99999865,0.99999953,
         0.99999984,0.99999994,0.99999998,
         0.99999999,1.,1.});

    auto a = NDArrayFactory::create<float>('c', {3,3});
    a.linspace(10.);

    auto b = NDArrayFactory::create<float>('c', {3,3});
    b.linspace(10.);

    auto x = NDArrayFactory::create<float>('c', {3,3});
    x.assign(0.9);

    nd4j::ops::betainc incBeta;
    auto outcome = incBeta.execute({&a, &b, &x}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *actual = outcome->at(0);

    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual, 1e-6));

    delete outcome;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, betainc_test8) {
    // Boundary x = 1: the expected result is all ones within 1e-6.
    auto expected = NDArrayFactory::create<float>('c', {3,3},
        {1.,1.,1.,
         1.,1.,1.,
         1.,1.,1.});

    auto a = NDArrayFactory::create<float>('c', {3,3});
    a.linspace(10.);

    auto b = NDArrayFactory::create<float>('c', {3,3});
    b.linspace(10.);

    auto x = NDArrayFactory::create<float>('c', {3,3});
    x.assign(1.);

    nd4j::ops::betainc incBeta;
    auto outcome = incBeta.execute({&a, &b, &x}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *actual = outcome->at(0);

    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual, 1e-6));

    delete outcome;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, betainc_test9) {
    // Boundary x = 0: the expected result is all zeros (default tolerance).
    auto expected = NDArrayFactory::create<float>('c', {3,3},
        {0.,0.,0.,
         0.,0.,0.,
         0.,0.,0.});

    auto a = NDArrayFactory::create<float>('c', {3,3});
    a.linspace(10.);

    auto b = NDArrayFactory::create<float>('c', {3,3});
    b.linspace(10.);

    auto x = NDArrayFactory::create<float>('c', {3,3});
    x.assign(0.);

    nd4j::ops::betainc incBeta;
    auto outcome = incBeta.execute({&a, &b, &x}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *actual = outcome->at(0);

    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete outcome;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, betainc_test10) {
    // a and b receive identical values (both linspace(10.)) and x is 0.5;
    // the expected result is 0.5 everywhere (default tolerance).
    auto expected = NDArrayFactory::create<float>('c', {3,3},
        {0.5,0.5,0.5,
         0.5,0.5,0.5,
         0.5,0.5,0.5});

    auto a = NDArrayFactory::create<float>('c', {3,3});
    a.linspace(10.);

    auto b = NDArrayFactory::create<float>('c', {3,3});
    b.linspace(10.);

    auto x = NDArrayFactory::create<float>('c', {3,3});
    x.assign(0.5);

    nd4j::ops::betainc incBeta;
    auto outcome = incBeta.execute({&a, &b, &x}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *actual = outcome->at(0);

    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete outcome;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, zeta_test1) {
    // x is 2 everywhere and q is filled by linspace(1.); the op is called
    // with the operands in the order (x, q).
    auto expected = NDArrayFactory::create<float>('c', {3,3},
        {1.64493407,0.64493407,0.39493407,
         0.28382296,0.22132296,0.18132296,
         0.15354518,0.13313701,0.11751201});

    auto x = NDArrayFactory::create<float>('c', {3,3});
    x.assign(2.);

    auto q = NDArrayFactory::create<float>('c', {3,3});
    q.linspace(1.);

    nd4j::ops::zeta zetaOp;
    auto outcome = zetaOp.execute({&x, &q}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *actual = outcome->at(0);

    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete outcome;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, zeta_test2) {
    // x is 2 everywhere and q is filled by linspace(10.).
    auto expected = NDArrayFactory::create<float>('c', {3,3},
        {0.10516634,0.09516634,0.08690187,
         0.07995743,0.07404027,0.06893823,
         0.06449378,0.06058753,0.05712733});

    auto x = NDArrayFactory::create<float>('c', {3,3});
    x.assign(2.);

    auto q = NDArrayFactory::create<float>('c', {3,3});
    q.linspace(10.);

    nd4j::ops::zeta zetaOp;
    auto outcome = zetaOp.execute({&x, &q}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *actual = outcome->at(0);

    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete outcome;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, zeta_test3) {
    // x is 2 everywhere and q is filled by linspace(100.).
    auto expected = NDArrayFactory::create<float>('c', {3,3},
        {0.01005017,0.00995017,0.00985214,
         0.00975602,0.00966176,0.0095693 ,
         0.0094786 ,0.0093896 ,0.00930226});

    auto x = NDArrayFactory::create<float>('c', {3,3});
    x.assign(2.);

    auto q = NDArrayFactory::create<float>('c', {3,3});
    q.linspace(100.);

    nd4j::ops::zeta zetaOp;
    auto outcome = zetaOp.execute({&x, &q}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *actual = outcome->at(0);

    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete outcome;
}
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, zeta_test4) {
    // NOTE(review): inputs and expected values match zeta_test3 exactly,
    // so this case adds no extra coverage.
    auto expected = NDArrayFactory::create<float>('c', {3,3},
        {0.01005017,0.00995017,0.00985214,
         0.00975602,0.00966176,0.0095693 ,
         0.0094786 ,0.0093896 ,0.00930226});

    auto x = NDArrayFactory::create<float>('c', {3,3});
    x.assign(2.);

    auto q = NDArrayFactory::create<float>('c', {3,3});
    q.linspace(100.);

    nd4j::ops::zeta zetaOp;
    auto outcome = zetaOp.execute({&x, &q}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *actual = outcome->at(0);

    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete outcome;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, zeta_test5) {
    // x is 1.1 everywhere and q is filled by linspace(1.).
    auto expected = NDArrayFactory::create<float>('c', {3,3},
        {10.58444846,9.58444846,9.11793197,
         8.81927915,8.60164151,8.43137352,
         8.29204706,8.17445116,8.07291961});

    auto x = NDArrayFactory::create<float>('c', {3,3});
    x.assign(1.1);

    auto q = NDArrayFactory::create<float>('c', {3,3});
    q.linspace(1.);

    nd4j::ops::zeta zetaOp;
    auto outcome = zetaOp.execute({&x, &q}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *actual = outcome->at(0);

    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete outcome;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, zeta_test6) {
    // x is 1.01 everywhere and q is filled by linspace(1.).
    auto expected = NDArrayFactory::create<float>('c', {3,3},
        {100.57794334,99.57794334,99.08139709,
         98.75170576,98.50514758,98.30834069,
         98.1446337 ,98.00452955,97.88210202});

    auto x = NDArrayFactory::create<float>('c', {3,3});
    x.assign(1.01);

    auto q = NDArrayFactory::create<float>('c', {3,3});
    q.linspace(1.);

    nd4j::ops::zeta zetaOp;
    auto outcome = zetaOp.execute({&x, &q}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *actual = outcome->at(0);

    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete outcome;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, zeta_test7) {
    // x is 10 everywhere and q is filled by linspace(1.).
    auto expected = NDArrayFactory::create<float>('c', {3,3},
        {1.00099458e+00,9.94575128e-04,1.80126278e-05,
         1.07754001e-06,1.23865693e-07,2.14656932e-08,
         4.92752156e-09,1.38738839e-09,4.56065812e-10});

    auto x = NDArrayFactory::create<float>('c', {3,3});
    x.assign(10.);

    auto q = NDArrayFactory::create<float>('c', {3,3});
    q.linspace(1.);

    nd4j::ops::zeta zetaOp;
    auto outcome = zetaOp.execute({&x, &q}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *actual = outcome->at(0);

    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete outcome;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, zeta_test8) {
    // Element-wise varying operands given as literals, shape {3,4}.
    auto expected = NDArrayFactory::create<float>('c', {3,4},
        {23.014574, 12.184081, 8.275731, 6.1532226,
         4.776538, 3.7945523, 3.0541048, 2.4765317,
         2.0163891, 205.27448, 21.090889, 19.477398});

    auto x = NDArrayFactory::create<float>('c', {3,4},
        {1.1,1.2,1.3,1.4,1.5,1.6,1.7,1.8,1.9,1.01,1.11,1.12});
    auto q = NDArrayFactory::create<float>('c', {3,4},
        {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.01, 0.11, 0.12});

    nd4j::ops::zeta zetaOp;
    auto outcome = zetaOp.execute({&x, &q}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto *actual = outcome->at(0);

    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    delete outcome;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, zeta_test9) {
    // Same operands as zeta_test8, but the output array z is supplied by
    // the caller (pre-filled with ones). This execute() overload yields a
    // status value, so there is no result set to delete.
    auto expected = NDArrayFactory::create<float>('c', {3,4},
        {23.014574, 12.184081, 8.275731, 6.1532226,
         4.776538, 3.7945523, 3.0541048, 2.4765317,
         2.0163891, 205.27448, 21.090889, 19.477398});

    auto x = NDArrayFactory::create<float>('c', {3,4},
        {1.1,1.2,1.3,1.4,1.5,1.6,1.7,1.8,1.9,1.01,1.11,1.12});
    auto q = NDArrayFactory::create<float>('c', {3,4},
        {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.01, 0.11, 0.12});
    auto z = NDArrayFactory::create<float>('c', {3,4},
        {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.});

    nd4j::ops::zeta zetaOp;
    auto status = zetaOp.execute({&x, &q}, {&z}, {}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, status);

    ASSERT_TRUE(expected.isSameShape(z));
    ASSERT_TRUE(expected.equalsTo(z));
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, zeta_test10) {
    // Same as zeta_test9 except that the caller-supplied output array z is
    // created without explicit initial values.
    auto expected = NDArrayFactory::create<float>('c', {3,4},
        {23.014574, 12.184081, 8.275731, 6.1532226,
         4.776538, 3.7945523, 3.0541048, 2.4765317,
         2.0163891, 205.27448, 21.090889, 19.477398});

    auto x = NDArrayFactory::create<float>('c', {3,4},
        {1.1,1.2,1.3,1.4,1.5,1.6,1.7,1.8,1.9,1.01,1.11,1.12});
    auto q = NDArrayFactory::create<float>('c', {3,4},
        {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.01, 0.11, 0.12});
    auto z = NDArrayFactory::create<float>('c', {3,4});

    nd4j::ops::zeta zetaOp;
    auto status = zetaOp.execute({&x, &q}, {&z}, {}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, status);

    ASSERT_TRUE(expected.isSameShape(z));
    ASSERT_TRUE(expected.equalsTo(z));
}
|
|
|
|
|
|
// Validation-only test: split_v must accept an 8x7 input, sizes {5, 3} and a
// negative axis (-2) together with pre-allocated 5x7 and 3x7 outputs, and
// report success. Output contents are not inspected.
TEST_F(DeclarableOpsTests3, Test_SplitV_Validation_1) {
    auto x       = NDArrayFactory::create<float>('c', {8, 7});
    auto indices = NDArrayFactory::create<int>('c',{2}, {5, 3});
    auto axis    = NDArrayFactory::create<int>(-2);

    // Caller-provided output arrays, one per requested split size.
    auto z0 = NDArrayFactory::create<float>('c', {5, 7});
    auto z1 = NDArrayFactory::create<float>('c', {3, 7});

    nd4j::ops::split_v op;
    const auto status = op.execute({&x, &indices, &axis}, {&z0, &z1}, {}, {}, {});

    ASSERT_EQ(Status::OK(), status);
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
// polygamma on 3x3 float arrays: n is filled by linspace starting at 1 and
// x is a constant 0.5 everywhere; the first op output must match the
// reference values.
TEST_F(DeclarableOpsTests3, polygamma_test1) {
    auto expected = NDArrayFactory::create<float>('c', {3,3}, {4.934802, -16.828796, 97.409088, -771.474243, 7691.113770, -92203.460938, 1290440.250000, -20644900.000000, 3.71595e+08});

    auto n = NDArrayFactory::create<float>('c', {3,3});
    n.linspace(1.);

    auto x = NDArrayFactory::create<float>('c', {3,3});
    x.assign(0.5);

    nd4j::ops::polygamma op;
    auto resultSet = op.execute({&n, &x}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, resultSet->status());

    // First (index 0) output of the op.
    auto actual = resultSet->at(0);
    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    // The result set is released manually once all checks have passed.
    delete resultSet;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
// polygamma on 3x3 float arrays: n is filled by linspace starting at 10 and
// x by linspace starting at 0.5; the first op output must match the
// reference values.
TEST_F(DeclarableOpsTests3, polygamma_test2) {
    auto expected = NDArrayFactory::create<float>('c', {3,3}, {-7.43182451e+09, 3.08334759e+05,-3.25669798e+03, 1.55186197e+02,-1.46220433e+01, 2.00905201e+00,-3.48791235e-01, 7.08016273e-02,-1.60476052e-02});

    auto n = NDArrayFactory::create<float>('c', {3,3});
    n.linspace(10.);

    auto x = NDArrayFactory::create<float>('c', {3,3});
    x.linspace(0.5);

    nd4j::ops::polygamma op;
    auto resultSet = op.execute({&n, &x}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, resultSet->status());

    // First (index 0) output of the op.
    auto actual = resultSet->at(0);
    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    // The result set is released manually once all checks have passed.
    delete resultSet;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
// polygamma on 3x3 float arrays: n is filled by linspace starting at 1 and
// x by linspace starting at 10; the first op output must match the
// reference values.
TEST_F(DeclarableOpsTests3, polygamma_test3) {
    auto expected = NDArrayFactory::create<float>('c', {3,3}, {1.05166336e-01,-9.04983497e-03, 1.31009323e-03,-2.44459433e-04, 5.31593880e-05,-1.28049888e-05, 3.31755364e-06,-9.07408791e-07, 2.58758130e-07});

    auto n = NDArrayFactory::create<float>('c', {3,3});
    n.linspace(1.);

    auto x = NDArrayFactory::create<float>('c', {3,3});
    x.linspace(10.);

    nd4j::ops::polygamma op;
    auto resultSet = op.execute({&n, &x}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, resultSet->status());

    // First (index 0) output of the op.
    auto actual = resultSet->at(0);
    ASSERT_TRUE(expected.isSameShape(actual));
    ASSERT_TRUE(expected.equalsTo(actual));

    // The result set is released manually once all checks have passed.
    delete resultSet;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
// svd on a square 6x6 float matrix with integer args {1, 1, 16}.
// Checks the three outputs (index 0 = S, 1 = U, 2 = V) against reference
// values: S always exactly via equalsTo; U and V exactly on CPU, and by
// element magnitude (tolerance 1e-5) on other backends.
TEST_F(DeclarableOpsTests3, svd_test1) {

    auto x= NDArrayFactory::create<float>('c', {6,6}, {0. ,-9. ,-6 ,9 ,-10 ,-12 ,2 ,13 ,5 ,-11 ,20 ,-17 ,1 ,-2 ,-11 ,3 ,-8 ,3 ,-14 ,19 ,-20 ,20 ,-17 ,-5 ,6 ,-16 ,0 ,-1 ,-16 ,11 ,7 ,-19 ,2 ,-17 ,17 ,-16});
    auto expS= NDArrayFactory::create<float>('c', {6}, {54.12775, 38.79293, 25.89287, 9.82168, 6.07227, 2.91827});
    auto expU= NDArrayFactory::create<float>('c', {6,6}, {0.14692,-0.11132,-0.69568, 0.59282,-0.14881, 0.32935,-0.38751, 0.60378,-0.04927,-0.01397,-0.69456,-0.01581, 0.19293,-0.12795,-0.18682,-0.69065,-0.20597, 0.62617, 0.66806, 0.4314 ,-0.33849,-0.22166, 0.04099,-0.44967, 0.11121,-0.64065,-0.02138,-0.07378,-0.60568,-0.45216,-0.5765 ,-0.1007 ,-0.60305,-0.34175, 0.29068,-0.3042});
    auto expV= NDArrayFactory::create<float>('c', {6,6}, {-0.24577,-0.24512, 0.00401,-0.04585,-0.62058, 0.70162, 0.27937, 0.75961, 0.43885,-0.06857,-0.3839 , 0.01669,-0.35944,-0.09629, 0.44593, 0.78602,-0.09103,-0.19125, 0.53973, 0.07613,-0.10721, 0.49559, 0.35687, 0.56431,-0.6226 , 0.39742, 0.12785,-0.15716, 0.52372, 0.37297, 0.23113,-0.43578, 0.76204,-0.32414, 0.23996, 0.11543});

    nd4j::ops::svd op;
    auto results = op.execute({&x}, {}, {1, 1, 16});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *s = results->at(0);
    auto *u = results->at(1);
    auto *v = results->at(2);

    ASSERT_TRUE(expS.isSameShape(s));
    ASSERT_TRUE(expU.isSameShape(u));
    ASSERT_TRUE(expV.isSameShape(v));

    ASSERT_TRUE(expS.equalsTo(s));

    if (nd4j::Environment::getInstance()->isCPU()) {
        ASSERT_TRUE(expU.equalsTo(u));
        ASSERT_TRUE(expV.equalsTo(v));
    }
    else {
        // Off CPU only absolute values are compared -- presumably because the
        // sign of singular vectors may differ between backends (TODO confirm).
        // The index is Nd4jLong rather than the non-standard `uint`, so the
        // loop condition does not mix signed and unsigned operands.
        for (Nd4jLong i = 0; i < expU.lengthOf(); ++i) {
            ASSERT_NEAR(nd4j::math::nd4j_abs(expU.t<float>(i)), nd4j::math::nd4j_abs(u->t<float>(i)), 1e-5);
        }
        for (Nd4jLong i = 0; i < expV.lengthOf(); ++i) {
            ASSERT_NEAR(nd4j::math::nd4j_abs(expV.t<float>(i)), nd4j::math::nd4j_abs(v->t<float>(i)), 1e-5);
        }
    }

    delete results;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
// svd on a tall 7x6 float matrix with integer args {1, 1, 16}; the reference
// U is 7x7 and V is 6x6.
// Checks the three outputs (index 0 = S, 1 = U, 2 = V) against reference
// values: S always exactly via equalsTo; U and V exactly on CPU, and by
// element magnitude (tolerance 1e-5) on other backends.
TEST_F(DeclarableOpsTests3, svd_test2) {

    auto x = NDArrayFactory::create<float>('c', {7,6}, {0. ,-9. ,-6 ,9 ,-10 ,-12 ,2 ,13 ,5 ,-11 ,20 ,-17 ,1 ,-2 ,-11 ,3 ,-8 ,3 ,-14 ,19 ,-20 ,20 ,-17 ,-5 ,6 ,-16 ,0 ,-1 ,-16 ,11 ,7 ,-19 ,2 ,-17 ,17 ,-16, 4, -9, 1, -15, 7, -2});
    auto expS= NDArrayFactory::create<float>('c', {6}, {56.76573, 39.11776, 26.00713, 11.83606, 6.16578, 3.99672});
    auto expU= NDArrayFactory::create<float>('c', {7,7}, {-0.13417,-0.12443, -0.68854, 0.5196 , 0.21706, 0.03974, 0.41683, 0.347 , 0.62666, -0.04964, -0.01912, 0.66932, 0.1457 , -0.12183,-0.17329,-0.14666, -0.19639, -0.55355, 0.0614 , 0.75729, 0.1619 ,-0.64703, 0.37056, -0.37398, -0.32922, -0.0186 , -0.35656, -0.26134,-0.08027,-0.64405, -0.0127 , -0.06934, 0.59287, -0.14956, -0.44712, 0.55906,-0.06235, -0.58017, -0.12911, -0.359 , -0.00393, -0.44877, 0.30645,-0.11953, -0.09083, -0.54163, 0.14283, -0.50417, 0.56178});
    auto expV= NDArrayFactory::create<float>('c', {6,6}, {0.2508 ,-0.2265 , 0.01689, 0.04486, 0.53132, 0.77537,-0.32281, 0.74559, 0.41845, -0.13821, 0.37642, 0.06315, 0.33139,-0.05528, 0.47186, 0.73171, 0.18905, -0.3055 ,-0.57263, 0.06276,-0.09542, 0.59396, -0.36152, 0.419 , 0.59193, 0.4361 , 0.13557, -0.03632, -0.5755 , 0.32944,-0.21165,-0.44227, 0.75794, -0.29895, -0.27993, 0.13187});

    nd4j::ops::svd op;
    auto results = op.execute({&x}, {}, {1, 1, 16});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *s = results->at(0);
    auto *u = results->at(1);
    auto *v = results->at(2);

    ASSERT_TRUE(expS.isSameShape(s));
    ASSERT_TRUE(expU.isSameShape(u));
    ASSERT_TRUE(expV.isSameShape(v));

    ASSERT_TRUE(expS.equalsTo(s));

    if (nd4j::Environment::getInstance()->isCPU()) {
        ASSERT_TRUE(expU.equalsTo(u));
        ASSERT_TRUE(expV.equalsTo(v));
    }
    else {
        // Off CPU only absolute values are compared -- presumably because the
        // sign of singular vectors may differ between backends (TODO confirm).
        // The index is Nd4jLong rather than the non-standard `uint`, so the
        // loop condition does not mix signed and unsigned operands.
        for (Nd4jLong i = 0; i < expU.lengthOf(); ++i) {
            ASSERT_NEAR(nd4j::math::nd4j_abs(expU.t<float>(i)), nd4j::math::nd4j_abs(u->t<float>(i)), 1e-5);
        }
        for (Nd4jLong i = 0; i < expV.lengthOf(); ++i) {
            ASSERT_NEAR(nd4j::math::nd4j_abs(expV.t<float>(i)), nd4j::math::nd4j_abs(v->t<float>(i)), 1e-5);
        }
    }

    delete results;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
// svd on a tall 7x6 float matrix with integer args {0, 1, 16}; the reference
// U is 7x6 and V is 6x6.
// Checks the three outputs (index 0 = S, 1 = U, 2 = V) against reference
// values: S always exactly via equalsTo; U and V exactly on CPU, and by
// element magnitude (tolerance 1e-5) on other backends.
TEST_F(DeclarableOpsTests3, svd_test3) {

    auto x= NDArrayFactory::create<float>('c', {7,6}, {0. ,-9. ,-6 ,9 ,-10 ,-12 ,2 ,13 ,5 ,-11 ,20 ,-17 ,1 ,-2 ,-11 ,3 ,-8 ,3 ,-14 ,19 ,-20 ,20 ,-17 ,-5 ,6 ,-16 ,0 ,-1 ,-16 ,11 ,7 ,-19 ,2 ,-17 ,17 ,-16, 4, -9, 1, -15, 7, -2});
    auto expS= NDArrayFactory::create<float>('c', {6}, {56.76573, 39.11776, 26.00713, 11.83606, 6.16578, 3.99672});
    auto expU= NDArrayFactory::create<float>('c', {7,6}, {-0.13417, -0.12443, -0.68854, 0.5196 , 0.21706, 0.03974, 0.347 , 0.62666, -0.04964, -0.01912, 0.66932, 0.1457 ,-0.17329, -0.14666, -0.19639, -0.55355, 0.0614 , 0.75729,-0.64703, 0.37056, -0.37398, -0.32922, -0.0186 , -0.35656,-0.08027, -0.64405, -0.0127 , -0.06934, 0.59287, -0.14956, 0.55906, -0.06235, -0.58017, -0.12911, -0.359 , -0.00393, 0.30645, -0.11953, -0.09083, -0.54163, 0.14283, -0.50417});
    auto expV= NDArrayFactory::create<float>('c', {6,6}, {0.2508 ,-0.2265 , 0.01689, 0.04486, 0.53132, 0.77537,-0.32281, 0.74559, 0.41845, -0.13821, 0.37642, 0.06315, 0.33139,-0.05528, 0.47186, 0.73171, 0.18905, -0.3055 ,-0.57263, 0.06276,-0.09542, 0.59396, -0.36152, 0.419 , 0.59193, 0.4361 , 0.13557, -0.03632, -0.5755 , 0.32944,-0.21165,-0.44227, 0.75794, -0.29895, -0.27993, 0.13187});

    nd4j::ops::svd op;
    auto results = op.execute({&x}, {}, {0, 1, 16});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *s = results->at(0);
    auto *u = results->at(1);
    auto *v = results->at(2);

    ASSERT_TRUE(expS.isSameShape(s));
    ASSERT_TRUE(expU.isSameShape(u));
    ASSERT_TRUE(expV.isSameShape(v));

    ASSERT_TRUE(expS.equalsTo(s));

    if (nd4j::Environment::getInstance()->isCPU()) {
        ASSERT_TRUE(expU.equalsTo(u));
        ASSERT_TRUE(expV.equalsTo(v));
    }
    else {
        // Off CPU only absolute values are compared -- presumably because the
        // sign of singular vectors may differ between backends (TODO confirm).
        // The index is Nd4jLong rather than the non-standard `uint`, so the
        // loop condition does not mix signed and unsigned operands.
        for (Nd4jLong i = 0; i < expU.lengthOf(); ++i) {
            ASSERT_NEAR(nd4j::math::nd4j_abs(expU.t<float>(i)), nd4j::math::nd4j_abs(u->t<float>(i)), 1e-5);
        }
        for (Nd4jLong i = 0; i < expV.lengthOf(); ++i) {
            ASSERT_NEAR(nd4j::math::nd4j_abs(expV.t<float>(i)), nd4j::math::nd4j_abs(v->t<float>(i)), 1e-5);
        }
    }

    delete results;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
// svd on a wide 6x7 float matrix with integer args {1, 1, 16}; the reference
// U is 6x6 and V is 7x7.
// Checks the three outputs (index 0 = S, 1 = U, 2 = V) against reference
// values: S always exactly via equalsTo; U and V exactly on CPU, and by
// element magnitude (tolerance 1e-5) on other backends.
TEST_F(DeclarableOpsTests3, svd_test4) {

    auto x= NDArrayFactory::create<float>('c', {6,7}, {0. ,-9. ,-6 ,9 ,-10 ,-12 ,2 ,13 ,5 ,-11 ,20 ,-17 ,1 ,-2 ,-11 ,3 ,-8 ,3 ,-14 ,19 ,-20 ,20 ,-17 ,-5 ,6 ,-16 ,0 ,-1 ,-16 ,11 ,7 ,-19 ,2 ,-17 ,17 ,-16, 4, -9, 1, -15, 7, -2});
    auto expS= NDArrayFactory::create<float>('c', {6}, {53.11053, 39.09542, 28.1987, 17.7468, 11.61684, 5.36217});
    auto expU= NDArrayFactory::create<float>('c', {6,6}, {-0.16541, 0.21276, 0.51284, 0.20472, 0.74797, 0.25102,-0.49879, 0.12076, 0.37629, -0.7211 , -0.24585, 0.12086,-0.36569,-0.70218, -0.08012, 0.21274, -0.07314, 0.56231,-0.44508, 0.4329 , 0.1356 , 0.60909, -0.47398, -0.02164, 0.61238,-0.05674, 0.59489, 0.06588, -0.3874 , 0.33685,-0.13044,-0.50644, 0.46552, 0.13236, -0.00474, -0.70161});
    auto expV= NDArrayFactory::create<float>('c', {7,7}, {-0.35914, 0.68966, -0.30077, -0.15238, -0.48179, 0.14716, -0.16709, 0.21989, -0.34343, 0.11086, -0.78381, -0.37902, 0.24224, -0.06862, 0.32179, 0.12812, -0.25812, 0.0691 , -0.12891, 0.26979, 0.84807,-0.50833, 0.13793, 0.06658, -0.53001, 0.52572, -0.16194, 0.36692, 0.48118, 0.15876, -0.65132, -0.24602, 0.3963 , -0.16651, -0.27155,-0.31605, -0.46947, -0.50195, 0.0378 , -0.34937, -0.53062, 0.15069, 0.35957, 0.35408, 0.38732, -0.12154, -0.22827, -0.7151 , 0.13065});

    nd4j::ops::svd op;
    auto results = op.execute({&x}, {}, {1, 1, 16});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *s = results->at(0);
    auto *u = results->at(1);
    auto *v = results->at(2);

    ASSERT_TRUE(expS.isSameShape(s));
    ASSERT_TRUE(expU.isSameShape(u));
    ASSERT_TRUE(expV.isSameShape(v));

    ASSERT_TRUE(expS.equalsTo(s));

    if (nd4j::Environment::getInstance()->isCPU()) {
        ASSERT_TRUE(expU.equalsTo(u));
        ASSERT_TRUE(expV.equalsTo(v));
    }
    else {
        // Off CPU only absolute values are compared -- presumably because the
        // sign of singular vectors may differ between backends (TODO confirm).
        // The index is Nd4jLong rather than the non-standard `uint`, so the
        // loop condition does not mix signed and unsigned operands.
        for (Nd4jLong i = 0; i < expU.lengthOf(); ++i) {
            ASSERT_NEAR(nd4j::math::nd4j_abs(expU.t<float>(i)), nd4j::math::nd4j_abs(u->t<float>(i)), 1e-5);
        }
        for (Nd4jLong i = 0; i < expV.lengthOf(); ++i) {
            ASSERT_NEAR(nd4j::math::nd4j_abs(expV.t<float>(i)), nd4j::math::nd4j_abs(v->t<float>(i)), 1e-5);
        }
    }

    delete results;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
// svd on a wide 6x7 float matrix with integer args {0, 1, 16}; the reference
// U is 6x6 and V is 7x6.
// Checks the three outputs (index 0 = S, 1 = U, 2 = V) against reference
// values: S always exactly via equalsTo; U and V exactly on CPU, and by
// element magnitude (tolerance 1e-5) on other backends.
TEST_F(DeclarableOpsTests3, svd_test5) {

    auto x= NDArrayFactory::create<float>('c', {6,7}, {0. ,-9. ,-6 ,9 ,-10 ,-12 ,2 ,13 ,5 ,-11 ,20 ,-17 ,1 ,-2 ,-11 ,3 ,-8 ,3 ,-14 ,19 ,-20 ,20 ,-17 ,-5 ,6 ,-16 ,0 ,-1 ,-16 ,11 ,7 ,-19 ,2 ,-17 ,17 ,-16, 4, -9, 1, -15, 7, -2});
    auto expS= NDArrayFactory::create<float>('c', {6}, {53.11053, 39.09542, 28.1987, 17.7468, 11.61684, 5.36217});
    auto expU= NDArrayFactory::create<float>('c', {6,6}, {-0.16541, 0.21276, 0.51284, 0.20472, 0.74797, 0.25102,-0.49879, 0.12076, 0.37629, -0.7211 , -0.24585, 0.12086,-0.36569,-0.70218, -0.08012, 0.21274, -0.07314, 0.56231,-0.44508, 0.4329 , 0.1356 , 0.60909, -0.47398, -0.02164, 0.61238,-0.05674, 0.59489, 0.06588, -0.3874 , 0.33685,-0.13044,-0.50644, 0.46552, 0.13236, -0.00474, -0.70161});
    auto expV= NDArrayFactory::create<float>('c', {7,6}, {-0.35914, 0.68966, -0.30077, -0.15238, -0.48179, 0.14716, 0.21989, -0.34343, 0.11086, -0.78381, -0.37902, 0.24224, 0.32179, 0.12812, -0.25812, 0.0691 , -0.12891, 0.26979,-0.50833, 0.13793, 0.06658, -0.53001, 0.52572, -0.16194, 0.48118, 0.15876, -0.65132, -0.24602, 0.3963 , -0.16651,-0.31605, -0.46947, -0.50195, 0.0378 , -0.34937, -0.53062, 0.35957, 0.35408, 0.38732, -0.12154, -0.22827, -0.7151});

    nd4j::ops::svd op;
    auto results = op.execute({&x}, {}, {0, 1, 16});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *s = results->at(0);
    auto *u = results->at(1);
    auto *v = results->at(2);

    ASSERT_TRUE(expS.isSameShape(s));
    ASSERT_TRUE(expU.isSameShape(u));
    ASSERT_TRUE(expV.isSameShape(v));

    ASSERT_TRUE(expS.equalsTo(s));

    if (nd4j::Environment::getInstance()->isCPU()) {
        ASSERT_TRUE(expU.equalsTo(u));
        ASSERT_TRUE(expV.equalsTo(v));
    }
    else {
        // Off CPU only absolute values are compared -- presumably because the
        // sign of singular vectors may differ between backends (TODO confirm).
        // The index is Nd4jLong rather than the non-standard `uint`, so the
        // loop condition does not mix signed and unsigned operands.
        for (Nd4jLong i = 0; i < expU.lengthOf(); ++i) {
            ASSERT_NEAR(nd4j::math::nd4j_abs(expU.t<float>(i)), nd4j::math::nd4j_abs(u->t<float>(i)), 1e-5);
        }
        for (Nd4jLong i = 0; i < expV.lengthOf(); ++i) {
            ASSERT_NEAR(nd4j::math::nd4j_abs(expV.t<float>(i)), nd4j::math::nd4j_abs(v->t<float>(i)), 1e-5);
        }
    }

    delete results;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
// svd on a rank-4 float input of shape {2,2,5,5} with integer args
// {1, 1, 16}; the reference S has shape {2,2,5} and U, V have shape
// {2,2,5,5}.
// Checks the three outputs (index 0 = S, 1 = U, 2 = V) against reference
// values: S always exactly via equalsTo; U and V exactly on CPU, and by
// element magnitude (tolerance 1e-5) on other backends.
TEST_F(DeclarableOpsTests3, svd_test6) {

    auto x= NDArrayFactory::create<float>('c', {2,2,5,5}, {-7. ,17 ,4 ,-10 ,5 ,1 ,-5 ,-19 ,13 ,-8 ,9 ,13 ,19 ,13 ,-2
            ,-8 ,10 ,-9 ,0 ,-20 ,-2 ,14 ,19 ,5 ,-18 ,4 ,-13 ,12 ,-10 ,5 ,-10 ,-10 ,17 ,-5 ,-2 ,10 ,5 ,-4 ,-11 ,15 ,-3 ,15 ,-17
            ,-20 ,-10 ,-4 ,12 ,-9 ,16 ,13 ,10 ,-19 ,2 ,-9 ,-10 ,8 ,-2 ,-4 ,3 ,7 ,10 ,-19 ,-11 ,-4 ,-6 ,2 ,-12 ,6 ,-4 ,-14 ,14
            ,16 ,7 ,19 ,-17 ,2 ,-14 ,5 ,-1 ,16 ,19 ,-11 ,-14 ,-16 ,-19 ,15 ,-18 ,-12 ,-16 ,16 ,1 ,5 ,7 ,8 ,2 ,13 ,-3 ,6 ,2 ,-5});
    auto expS= NDArrayFactory::create<float>('c', {2,2,5}, {40.95395, 31.46869, 24.79993, 12.33768, 1.80031,
            38.18412, 31.52287, 23.52755, 11.79484, 1.90195,
            39.34498, 32.54861, 17.52492, 7.03003, 2.2399,
            44.72126, 32.3164 , 16.60139, 6.88783, 0.78122});
    auto expU= NDArrayFactory::create<float>('c', {2,2,5,5}, {0.25441, 0.16908, -0.68564, 0.58844, -0.30054,
            -0.32285, -0.58332, 0.3451 , 0.4746 , -0.45953,0.58332, 0.10605, 0.51533, 0.50234, 0.36136,0.12588, -0.73123, -0.37812, -0.00215, 0.55361,
            0.68915, -0.2919 , 0.04767, -0.4197 , -0.51132,0.44464, -0.25326, -0.42493, -0.01712, -0.74653,0.516 , -0.16688, 0.1854 , -0.77155, 0.27611,
            -0.19321, -0.14317, -0.85886, -0.15224, 0.42585,-0.60155, -0.68323, 0.18819, -0.29053, -0.22696,-0.36993, 0.64862, -0.10956, -0.54483, -0.36552,
            -0.57697, -0.32277, 0.11229, 0.55495, 0.4923 ,-0.02937, 0.01689, -0.63257, 0.57075, -0.52245,-0.56002, -0.2036 , -0.53119, -0.6022 , 0.01017,
            -0.33605, -0.35257, 0.53215, -0.04936, -0.69075,0.48958, -0.85427, -0.14796, -0.03449, 0.08633,0.15008, 0.60996, 0.31071, -0.67721, 0.22421,
            0.67717, -0.59857, 0.04372, -0.2565 , 0.33979,0.68116, 0.49852, -0.13441, 0.51374, -0.07421,-0.20066, 0.04504, 0.42865, 0.44418, 0.75939,0.12113, -0.13826, 0.83651, 0.11988, -0.50209});
    auto expV= NDArrayFactory::create<float>('c', {2,2,5,5}, {0.01858, 0.17863, 0.51259, 0.14048, 0.82781,
            0.59651, -0.13439, -0.395 , 0.66979, 0.14654,0.73731, 0.47061, 0.19357, -0.41127, -0.16817,0.1047 , -0.29727, 0.73711, 0.38235, -0.45951,
            -0.29873, 0.80012, -0.02078, 0.4651 , -0.23201,-0.05314, -0.0419 , -0.52146, 0.77792, 0.344 ,-0.66438, 0.05648, 0.03756, -0.31531, 0.67422,
            0.74471, 0.01504, -0.03081, -0.24335, 0.62049,0.03172, 0.91947, 0.30828, 0.23713, 0.04796,-0.01311, 0.38652, -0.79415, -0.42423, -0.19945,
            -0.13783, -0.54667, -0.58527, 0.49955, 0.3001 ,0.85214, 0.01628, 0.02688, -0.02891, 0.52157,0.16608, -0.20181, 0.61371, 0.69894, -0.25794,
            0.45726, -0.33952, -0.32659, -0.18938, -0.73015,0.13486, 0.73816, -0.41646, 0.47458, -0.1956 ,0.5536 , -0.137 , 0.64688, 0.50536, 0.03017,
            -0.51827, -0.31837, -0.16732, 0.71378, -0.30425,-0.39314, 0.15266, 0.63693, -0.30945, -0.5663 ,-0.51981, 0.03325, 0.37603, 0.05147, 0.76462,-0.01282, 0.92491, -0.08042, 0.36977, -0.03428});

    nd4j::ops::svd op;
    auto results = op.execute({&x}, {}, {1, 1, 16});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *s = results->at(0);
    auto *u = results->at(1);
    auto *v = results->at(2);

    ASSERT_TRUE(expS.isSameShape(s));
    ASSERT_TRUE(expU.isSameShape(u));
    ASSERT_TRUE(expV.isSameShape(v));

    ASSERT_TRUE(expS.equalsTo(s));

    if (nd4j::Environment::getInstance()->isCPU()) {
        ASSERT_TRUE(expU.equalsTo(u));
        ASSERT_TRUE(expV.equalsTo(v));
    }
    else {
        // Off CPU only absolute values are compared -- presumably because the
        // sign of singular vectors may differ between backends (TODO confirm).
        // The index is Nd4jLong rather than the non-standard `uint`, so the
        // loop condition does not mix signed and unsigned operands.
        for (Nd4jLong i = 0; i < expU.lengthOf(); ++i) {
            ASSERT_NEAR(nd4j::math::nd4j_abs(expU.t<float>(i)), nd4j::math::nd4j_abs(u->t<float>(i)), 1e-5);
        }
        for (Nd4jLong i = 0; i < expV.lengthOf(); ++i) {
            ASSERT_NEAR(nd4j::math::nd4j_abs(expV.t<float>(i)), nd4j::math::nd4j_abs(v->t<float>(i)), 1e-5);
        }
    }

    delete results;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
// svd on a rank-4 float input of shape {2,2,5,5} with integer args
// {0, 0, 16}. Only the first output (index 0, S) is inspected and compared
// against the reference values of shape {2,2,5}.
// NOTE(review): the second integer arg being 0 presumably disables U/V
// computation -- confirm against the svd op.
TEST_F(DeclarableOpsTests3, svd_test7) {

    auto x= NDArrayFactory::create<float>('c', {2,2,5,5}, {-7. ,17 ,4 ,-10 ,5 ,1 ,-5 ,-19 ,13 ,-8 ,9 ,13 ,19 ,13 ,-2
            ,-8 ,10 ,-9 ,0 ,-20 ,-2 ,14 ,19 ,5 ,-18 ,4 ,-13 ,12 ,-10
            ,5 ,-10 ,-10 ,17 ,-5 ,-2 ,10 ,5 ,-4 ,-11 ,15 ,-3 ,15 ,-17
            ,-20 ,-10 ,-4 ,12 ,-9 ,16 ,13 ,10 ,-19 ,2 ,-9 ,-10 ,8 ,-2
            ,-4 ,3 ,7 ,10 ,-19 ,-11 ,-4 ,-6 ,2 ,-12 ,6 ,-4 ,-14 ,14
            ,16 ,7 ,19 ,-17 ,2 ,-14 ,5 ,-1 ,16 ,19 ,-11 ,-14 ,-16
            ,-19 ,15 ,-18 ,-12 ,-16 ,16 ,1 ,5 ,7 ,8 ,2 ,13 ,-3 ,6 ,2 ,-5});
    auto expS= NDArrayFactory::create<float>('c', {2,2,5}, {40.95395, 31.46869, 24.79993, 12.33768, 1.80031,
            38.18412, 31.52287, 23.52755, 11.79484, 1.90195,
            39.34498, 32.54861, 17.52492, 7.03003, 2.2399,
            44.72126, 32.3164 , 16.60139, 6.88783, 0.78122});

    nd4j::ops::svd op;
    auto results = op.execute({&x}, {}, {0, 0, 16});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

    auto *s = results->at(0);

    // Shape is verified before values, matching the other svd tests, so a
    // shape mismatch is reported as such rather than as a value mismatch.
    ASSERT_TRUE(expS.isSameShape(s));
    ASSERT_TRUE(expS.equalsTo(s));

    delete results;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
// TEST_F(DeclarableOpsTests3, svd_test8) {
|
|
|
|
// auto x= NDArrayFactory::create<float>('c', {2,2,11,10}, {3 ,-8 ,0 ,3 ,-5 ,16 ,-3 ,7 ,-4 ,19 ,19 ,13 ,15 ,15 ,9 ,6 ,-7 ,-5 ,-9 ,-12 ,7 ,-1 ,-1 ,6 ,19
|
|
// ,-6 ,16 ,0 ,16 ,16 ,7 ,14 ,18. ,0 ,18 ,-4 ,10 ,-16 ,-17 ,15 ,13 ,-17 ,-14 ,-17 ,-5 ,-9 ,-1 ,-19
|
|
// ,-18 ,5 ,-5 ,-13 ,17 ,-19 ,-5 ,18 ,4 ,10 ,17 ,-7 ,-10 ,16 ,10 ,8 ,-10 ,-3 ,10 ,1 ,-4 ,-16 ,-1
|
|
// ,-1 ,5 ,5 ,17 ,14 ,20 ,15 ,-6 ,19 ,14 ,17 ,0 ,-17 ,-16 ,-8 ,-6 ,3 ,-6 ,-11 ,-4 ,-2 ,-7 ,4 ,-6
|
|
// ,-6 ,-17 ,16 ,-8 ,-20 ,2 ,7 ,-12 ,15 ,-15 ,-19 ,14 ,17 ,9 ,10 ,5 ,18 ,2 ,-6 ,0 ,2 ,-10 ,7 ,8
|
|
// ,-13 ,2 ,8 ,20 ,11 ,-15 ,13 ,-10 ,-14 ,-2 ,20 ,5 ,2 ,16 ,18 ,-3 ,3 ,-18 ,15 ,-11 ,17 ,-8 ,-18
|
|
// ,20 ,-12 ,20 ,20 ,-16 ,20 ,-8. ,19 ,-8 ,3 ,-3 ,17 ,7 ,13 ,9 ,-2 ,11 ,16 ,4 ,-18 ,5 ,0 ,-12 ,9
|
|
// ,-6 ,6 ,0 ,-9 ,-13 ,13 ,17 ,-12 ,3 ,-13 ,17 ,-19 ,17 ,0 ,-8 ,4 ,-19 ,-9 ,-7 ,12 ,-1 ,-12 ,-1
|
|
// ,7 ,2 ,19 ,10 ,19 ,-15 ,-18 ,17 ,-1 ,1 ,14 ,-7 ,-10 ,12 ,-20 ,6 ,-5 ,14 ,5 ,5 ,3 ,-18 ,5 ,17
|
|
// ,-13 ,20 ,-1 ,-2 ,-11 ,-5 ,14 ,8 ,7 ,-13 ,-9 ,-12 ,11 ,3 ,14 ,-6 ,-2 ,13 ,8 ,-15 ,-5 ,-6 ,-7 ,19
|
|
// ,-1 ,6 ,1 ,14 ,8 ,18 ,-20 ,-14 ,-3 ,-5 ,19 ,15 ,13 ,2 ,-20 ,2 ,14 ,13 ,4 ,-15 ,1 ,-14
|
|
// ,0 ,9 ,-1 ,10 ,4 ,6 ,4 ,-7 ,-2 ,-1 ,-15 ,-1 ,-16 ,-5 ,-12 ,-10 ,16 ,-16 ,-15 ,-17 ,-5 ,-6
|
|
// ,18 ,14 ,-3 ,-10 ,8 ,20 ,19 ,20 ,-3 ,-6 ,9 ,10 ,-1 ,-20 ,-5 ,5 ,12 ,8 ,17 ,13 ,-18 ,-14 ,0
|
|
// ,4 ,-11 ,3 ,-12 ,-2 ,-5 ,19 ,-15 ,19 ,16 ,-16 ,13 ,-6 ,11 ,11 ,0 ,-18 ,4 ,5 ,6 ,-12 ,-10
|
|
// ,-3 ,2 ,-18 ,16 ,-5 ,17 ,16 ,-16 ,-20 ,14 ,6 ,10 ,-5 ,-3 ,4 ,20 ,18 ,5 ,1 ,-10 ,15 ,10 ,16
|
|
// ,-18 ,2 ,12 ,20 ,6 ,14 ,8 ,3 ,-2 ,9 ,15 ,-4 ,13 ,-19 ,-5 ,3 ,3 ,-20 ,-4 ,18 ,-11 ,11 ,-10
|
|
// ,3 ,8 ,9 ,20 ,-19 ,6 ,18 ,9 ,20 ,-12 ,4 ,15 ,19 ,3 ,5 ,1 ,2 ,20 ,-3 ,-1 ,-8 ,-3 ,8 ,17 ,
|
|
// -14 ,18 ,-10 ,4 ,13 ,-5 ,13 ,-6 ,12 ,-10 ,19 ,4 ,-7 ,-17 ,20 ,8 ,6 ,-3 ,3 ,-7 ,-18 ,17 ,
|
|
// -13 ,18 ,-20 ,-16 ,-5 ,12 ,5 ,17 ,-4 ,4 ,7 ,8 ,17 ,-9 ,-12 ,-10 ,8 ,-14 ,-11 ,7 ,19 ,-17});
|
|
|
|
// auto expS= NDArrayFactory::create<float>('c', {2,2,10}, { 64.12636, 54.37044, 50.63744, 48.10308, 33.7364 , 29.96456,
|
|
// 25.53945, 19.31856, 15.30939, 9.31349,
|
|
// 67.41342, 59.64963, 58.72687, 39.22496, 32.39772, 29.30833,
|
|
// 23.1491 , 16.92442, 6.38613, 3.49563,
|
|
// 74.37477, 52.07016, 46.10758, 39.10742, 32.02261, 27.05888,
|
|
// 20.54921, 13.17989, 8.4158 , 4.39974,
|
|
// 65.47447, 56.31305, 54.13371, 46.26955, 43.47755, 30.25799,
|
|
// 20.71463, 16.89671, 10.39572, 7.81631});
|
|
|
|
// auto expU= NDArrayFactory::create<float>('c', {2,2,11,11}, {-0.177870, -0.149461, -0.196911, 0.036990, -0.338237, 0.548901,
|
|
// -0.074396, 0.497067, -0.083636, -0.111810, -0.466989, -0.010465, 0.434732, 0.337198, 0.305239, -0.292813,
|
|
// 0.041280, -0.517144, 0.121499, 0.464908, 0.003658, 0.135017, -0.446916, -0.098318, 0.073571, -0.200521,
|
|
// 0.186776, -0.353022, -0.435582, -0.225959, 0.052972, 0.032390, -0.583801, -0.402790, 0.562809, 0.102744,
|
|
// 0.066555, 0.206079, 0.115322, 0.217220, -0.062591, -0.273173, -0.569645, 0.005612, 0.092601, 0.350055,
|
|
// -0.608007, -0.367743, 0.064860, 0.112656, 0.091576, -0.144262, 0.554655, -0.042100, -0.092023, 0.026986,
|
|
// -0.395811, -0.245209, 0.572522, 0.429430, 0.099621, -0.159236, -0.086263, 0.268160, -0.391298, 0.050417,
|
|
// 0.150175, 0.045253, 0.464173, 0.138376, 0.265551, 0.049691, 0.528778, 0.116951, 0.384609, 0.144416,
|
|
// -0.453591, -0.519390, -0.150671, 0.072897, 0.102406, -0.154184, 0.450735, 0.174171, -0.519405, 0.147109,
|
|
// 0.333670, 0.178053, 0.360763, 0.226976, 0.069976, -0.046765, 0.448897, 0.511309, -0.361050, -0.191690,
|
|
// -0.304442, 0.270383, -0.124133, 0.417183, -0.083359, 0.137022, 0.004276, -0.462336, 0.051267, 0.020622,
|
|
// -0.566932, -0.051351, -0.417106, -0.292202, -0.021595, -0.315956, 0.396626, -0.604952, 0.155990, 0.258395,
|
|
// -0.125080, 0.115404, 0.234517, -0.357460, 0.271271, 0.063771, -0.087400, -0.024710, -0.179892, 0.584339,
|
|
// -0.413085, 0.510580, 0.334646, 0.044424, 0.224735, 0.134434, -0.147861, 0.291853, 0.487948, 0.238917,
|
|
// 0.433893, 0.435884, 0.056370, -0.051216, -0.450902, 0.062411, 0.080733, -0.365211, 0.031931, 0.493926,
|
|
// -0.239428, 0.038247, -0.180721, -0.118035, 0.042175, 0.377296, -0.516399, 0.324744, -0.756196, 0.160856,
|
|
// -0.152527, -0.046867, -0.092933, -0.044945, 0.137659, 0.246552, -0.071709, 0.032821, -0.529356, -0.029669,
|
|
// 0.200178, 0.188916, 0.428036, -0.496734, -0.164185, 0.629070, -0.131588, 0.073992, 0.066877, 0.208450,
|
|
// -0.156170, -0.253670, -0.000365, -0.121172, 0.067774, 0.618226, 0.230460, -0.118865, 0.579424, 0.324523,
|
|
// 0.038653, 0.310308, 0.570186, -0.217271, -0.110967, 0.196375, 0.167058, 0.264071, -0.130023, 0.254189,
|
|
// -0.459057, -0.301033, 0.069932, -0.033338, -0.070600, 0.685064, 0.130274, 0.074929, -0.206899, 0.574057,
|
|
// 0.327277, -0.131588, -0.018497, 0.312445, 0.314594, 0.480422, -0.293858, -0.273277, -0.006598, -0.134574,
|
|
// 0.403501, 0.140025, 0.380693, -0.257039, -0.067012, 0.248776, -0.361838, -0.270296, -0.225844, 0.320245,
|
|
// 0.055730, 0.454809, -0.212163, -0.063281, 0.563112, -0.200737, 0.537389, -0.210845, 0.109997, 0.166215,
|
|
// -0.243725, -0.347349, -0.274348, 0.263950, 0.437134, 0.265820, -0.127520, -0.033325, -0.137156, 0.518557,
|
|
// 0.246720, 0.389394, -0.600568, 0.062027, -0.047838, -0.338416, 0.032778, -0.141998, -0.338022, -0.381467,
|
|
// 0.210512, -0.314413, 0.256321, 0.001460, 0.238901, 0.139840, 0.633423, -0.182575, -0.461504, 0.290250,
|
|
// -0.025930, 0.336998, -0.211280, -0.662387, -0.207946, -0.003860, -0.147842, 0.157217, 0.123704, 0.345686,
|
|
// 0.337946, 0.138261, -0.178814, -0.109597, 0.087135, -0.509500, -0.300296, -0.262279, 0.377476, -0.366815,
|
|
// 0.091787, 0.247495, -0.193812, -0.179714, 0.238552, -0.162305, -0.029549, 0.785426, -0.157586, -0.084533,
|
|
// -0.357024, 0.317878, 0.217656, 0.125319, 0.648832, 0.344045, -0.001109, 0.457190, -0.072439, -0.106278,
|
|
// 0.228962, -0.136139, -0.528342, -0.020840, -0.108908, -0.231661, 0.396864, 0.234925, 0.180894, -0.179430,
|
|
// -0.587730, 0.178276, -0.008672, -0.386172, 0.033155, 0.319568, 0.101457, -0.272011, 0.126007, 0.175374,
|
|
// -0.081668, 0.112987, -0.296422, -0.713743, 0.269413, -0.082098, -0.338649, 0.131035, -0.518616, 0.022478,
|
|
// 0.177802, -0.042432, -0.606219, -0.343848, 0.014416, -0.141375, 0.748332, -0.165911, -0.049067, -0.241062,
|
|
// 0.436318, 0.173318, 0.058066, 0.193764, -0.000647, 0.265777, -0.027847, -0.096305, 0.711632, 0.066506,
|
|
// -0.223124, 0.219165, -0.038165, 0.427444, -0.296887, 0.139982, 0.298976, 0.294876, -0.001315, 0.419802,
|
|
// 0.475401, -0.156256, -0.289477, -0.438761, -0.116348, 0.108350, -0.369368, -0.219943, 0.433088, 0.187565,
|
|
// -0.217259, 0.147014, -0.538991, -0.065052, 0.310337, 0.491887, 0.254439, 0.075052, 0.071155, -0.084856,
|
|
// 0.402098, 0.096270, 0.093662, -0.475769, 0.256832, 0.161394, -0.390050, -0.513551, -0.184665, 0.211506,
|
|
// -0.112525, -0.493409, -0.258765, 0.262124, -0.272998, 0.269370, 0.266226, -0.367919, 0.192386, -0.006422,
|
|
// -0.466728, -0.481792, 0.090611, -0.156359, 0.178693, -0.371658, -0.214190, -0.469058, -0.006134, 0.081902,
|
|
// 0.536950, 0.064836, -0.334010, 0.523530, -0.182061, -0.206686, 0.002985, 0.054858, -0.038727, -0.075390,
|
|
// 0.543839, -0.442964, -0.190550, -0.298127, -0.065323, 0.131415, 0.329899, 0.122096, -0.507075, 0.523751,
|
|
// -0.167317, 0.198593, -0.069066, 0.402739, 0.328583, 0.314184, -0.268003, -0.148549, 0.118925, -0.508174,
|
|
// 0.128716, -0.405597, -0.157224, 0.271021, -0.384444, -0.174935, 0.343919, -0.076726, 0.607931, 0.383931,
|
|
// 0.198254, 0.133707, 0.321460, -0.232543, 0.099988, -0.321954, -0.366304, -0.137440, 0.232835, -0.290306,
|
|
// -0.260804, -0.347721, 0.182895, 0.382311, -0.332847, -0.192469, -0.438258, -0.017533, -0.192976, -0.702531,
|
|
// 0.124463, 0.039719, -0.221319, -0.224785, 0.096356, -0.302131, -0.462598, 0.194320});
|
|
|
|
|
|
// auto expV= NDArrayFactory::create<float>('c', {2,2,10,10}, {-0.050761, 0.370975, -0.061567, -0.125530, 0.024081, 0.275524, -0.800334,
|
|
// -0.025855, 0.348132, 0.036882, 0.034921, 0.307295, 0.629837, 0.014276, 0.265687, 0.188407, -0.035481, 0.082827,
|
|
// -0.490175, 0.391118, -0.180180, 0.169108, 0.206663, 0.623321, 0.260009, 0.081943, 0.004485, 0.136199, 0.060353,
|
|
// -0.641224, -0.181559, -0.041761, 0.578416, -0.161798, -0.573128, -0.187563, 0.012533, 0.368041, 0.314619,
|
|
// -0.079349, -0.527508, 0.216020, 0.004721, 0.188769, -0.242534, -0.442685, -0.121683, -0.565306, -0.202894,
|
|
// 0.095280, -0.181900, -0.170627, -0.201655, 0.620259, -0.257996, 0.277656, -0.009623, 0.266775, 0.081952,
|
|
// 0.539241, -0.452254, -0.136142, 0.177049, -0.144734, 0.494673, 0.101613, 0.280091, -0.186281, 0.548779,
|
|
// 0.235160, 0.054763, -0.571503, 0.298086, 0.035312, -0.195188, 0.474030, -0.175457, -0.497267, -0.101439,
|
|
// -0.170678, -0.060605, -0.557305, 0.073433, 0.057195, 0.352091, -0.486102, -0.483569, 0.252091, -0.121245,
|
|
// 0.068719, -0.638919, -0.078029, -0.236556, -0.351440, -0.024437, 0.319855, -0.007406, 0.319691, -0.402334,
|
|
// -0.197966, 0.058936, -0.360900, 0.233414, -0.251532, 0.105457, 0.048097, 0.029321, 0.002714, -0.845953,
|
|
// -0.136344, 0.378037, 0.277491, 0.278420, 0.037491, 0.432117, -0.586745, 0.104573, 0.316569, -0.039848,
|
|
// 0.239645, -0.320923, 0.555156, 0.145059, -0.546959, 0.267760, 0.298029, 0.177831, -0.191286, -0.032427,
|
|
// 0.197034, 0.081887, -0.113063, 0.711713, 0.020279, -0.362346, -0.145776, 0.173289, -0.500880, 0.181624,
|
|
// 0.084391, -0.278967, 0.212143, -0.413382, 0.012879, -0.216886, -0.625774, 0.066795, -0.421937, -0.291320,
|
|
// 0.011402, -0.416660, -0.134200, 0.043039, 0.554715, 0.126867, 0.147315, 0.474334, 0.094354, -0.156458,
|
|
// 0.450168, 0.447448, 0.261750, -0.161426, -0.064309, -0.592417, 0.210891, 0.104312, 0.176178, -0.237020,
|
|
// 0.455579, -0.358056, -0.307454, 0.033700, -0.486831, -0.303963, -0.284916, 0.241549, 0.510701, 0.206104,
|
|
// 0.062587, 0.248212, 0.132088, -0.122704, 0.026342, -0.011108, 0.066306, 0.763127, 0.009491, 0.038822,
|
|
// -0.562773, -0.320104, 0.477773, 0.354169, 0.293329, -0.304227, -0.001662, -0.213324, 0.365277, -0.198056,
|
|
// -0.383499, -0.017789, 0.324542, -0.642856, 0.238689, -0.360461, -0.060599, -0.257192, 0.342400, 0.180845,
|
|
// 0.272810, -0.452278, -0.409323, 0.077013, -0.082561, 0.334893, -0.103309, -0.198049, 0.480416, 0.470593,
|
|
// 0.029072, -0.300574, 0.532293, 0.250892, -0.355298, 0.079716, -0.319781, 0.259925, 0.277872, -0.251917,
|
|
// 0.346821, 0.161642, 0.205861, 0.107125, -0.594779, -0.226272, 0.610183, -0.065926, 0.170332, 0.312553,
|
|
// -0.108093, 0.368268, -0.183109, -0.192222, -0.544559, 0.136824, -0.412352, -0.398250, -0.257291, 0.019911,
|
|
// 0.288797, 0.013350, 0.349817, -0.108331, 0.180576, 0.652863, 0.319319, 0.020218, -0.324499, 0.290877,
|
|
// 0.338518, -0.301776, -0.440871, -0.281683, -0.158759, -0.080281, 0.418260, 0.189926, -0.064112, -0.390914,
|
|
// 0.485420, -0.464327, 0.211070, 0.044295, -0.032292, 0.043985, 0.147160, -0.702247, -0.198395, -0.352940,
|
|
// -0.237014, -0.438235, 0.073448, -0.418712, -0.280275, -0.091373, -0.194273, 0.347558, -0.421767, 0.283011,
|
|
// -0.351869, -0.210088, -0.034628, 0.448410, 0.149194, -0.488551, -0.068805, -0.117007, -0.390999, 0.377100,
|
|
// 0.423252, -0.041944, 0.455115, -0.537818, 0.266732, 0.218202, 0.047475, -0.383506, -0.158858, 0.450881,
|
|
// 0.072415, 0.355772, 0.002360, 0.138976, 0.541349, -0.295405, 0.463832, 0.400676, -0.168962, 0.259334,
|
|
// -0.047960, 0.272197, 0.582658, 0.198052, 0.127300, -0.320468, -0.104858, -0.229698, 0.046672, -0.474224,
|
|
// 0.370765, -0.246450, 0.212667, 0.024935, -0.344530, -0.238547, 0.185931, 0.269068, 0.487414, 0.421376,
|
|
// 0.442391, -0.284247, 0.304973, -0.365006, -0.159016, -0.129088, -0.126454, 0.600462, -0.461163, -0.243552,
|
|
// -0.049814, -0.381340, -0.054504, 0.436237, 0.126120, -0.359677, -0.409734, -0.179422, -0.414820, 0.371149,
|
|
// 0.078299, 0.503544, 0.322165, 0.148341, -0.495447, -0.084355, -0.174667, 0.016802, -0.066954, 0.318825,
|
|
// -0.480771, -0.060163, 0.144302, -0.041555, 0.459106, 0.029882, -0.565026, 0.282336, 0.528472, 0.044916,
|
|
// -0.286167, -0.101052, -0.181529, -0.419406, -0.032204, -0.732282, 0.106833, -0.288881, 0.171516, -0.096242,
|
|
// -0.331834, -0.493188, 0.393195, 0.358365, 0.049125, 0.123457, 0.438169, -0.105015, 0.092386, -0.130413, -0.476991});
|
|
|
|
// nd4j::ops::svd op;
|
|
// auto results = op.execute({&x}, {}, {1, 1, 7});
|
|
|
|
// ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
|
|
|
// auto *s = results->at(0);
|
|
// auto *u = results->at(1);
|
|
// auto *v = results->at(2);
|
|
|
|
// ASSERT_TRUE(expS.isSameShape(s));
|
|
// ASSERT_TRUE(expU.isSameShape(u));
|
|
// ASSERT_TRUE(expV.isSameShape(v));
|
|
|
|
// ASSERT_TRUE(expS.equalsTo(s));
|
|
|
|
// if(nd4j::Environment::getInstance()->isCPU()) {
|
|
// ASSERT_TRUE(expU.equalsTo(u));
|
|
// ASSERT_TRUE(expV.equalsTo(v));
|
|
// }
|
|
// else {
|
|
// for(uint i = 0; i < expU.lengthOf(); ++i)
|
|
// ASSERT_NEAR(nd4j::math::nd4j_abs(expU.t<float>(i)), nd4j::math::nd4j_abs(u->t<float>(i)), 1e-5);
|
|
// for(uint i = 0; i < expV.lengthOf(); ++i)
|
|
// ASSERT_NEAR(nd4j::math::nd4j_abs(expV.t<float>(i)), nd4j::math::nd4j_abs(v->t<float>(i)), 1e-5);
|
|
// }
|
|
|
|
// delete results;
|
|
// }
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, svd_test9) {

    // Input batch: shape {2,2,5,6}, i.e. four 5x6 matrices stored in 'c' order.
    auto x = NDArrayFactory::create<float>('c', {2,2,5,6}, {
        17, -11, 20, -10, 19, 13, -18, 6, -2, -6, -10, 4, -6, -4, 3, 16, 12,
        -15, 8, -8, 12, -1, 20, 19, -13, 0, 20, 17, -8, 16, -19, 7, -16, -14, -5, 7, 7, -5, 12, -15, 7, 8,
        1, -8, -17, 10, -11, 8, -10, 1, -6, 10, 15, 19, -15, 8, 2, 8, 12, 7, -5, 1, 8, 4, -13, 2, 19, -2, -10,
        -8, 11, 1, 20, -11, 4, 1, -17, -15, 0, -9, -4, -1, -6, -9, -13, 10, 7, -2, 15, -10, -1, 11, -20, -2,
        -1, -18, 12, 16, 8, -9, -20, -7, -20, 3, -9, 12, 8, -19, -2, 2, 1, 7, 10, -18, 13, 6, 14, 0, 19, 8});

    // Expected first output: shape {2,2,5}, one row of five values per input matrix.
    auto expS = NDArrayFactory::create<float>('c', {2,2,5}, {
        50.46507, 35.75599, 28.12787, 12.45245, 9.08545,
        38.56035, 30.62846, 26.31646, 19.42605, 3.01162,
        38.56369, 29.18881, 19.54565, 10.89746, 2.017,
        44.99108, 34.95059, 26.00453, 15.43898, 7.18752});

    // Expected second output: shape {2,2,5,5}.
    auto expU = NDArrayFactory::create<float>('c', {2,2,5,5}, {
        -0.73644, -0.10751, 0.10081, -0.00325, 0.66025,
        0.26329, 0.3079, 0.38582, 0.77696, 0.28872,
        0.03076, 0.03015, -0.9128, 0.36387, 0.18039,
        -0.61335, 0.10076, 0.01381, 0.40922, -0.66783,
        -0.10577, 0.93946, -0.0871, -0.31058, 0.04677,
        0.52823, 0.31163, -0.78777, 0.02322, -0.05234,
        -0.23942, -0.45801, -0.34248, 0.71286, 0.32778,
        0.26147, 0.60409, 0.39933, 0.46862, 0.43318,
        0.62118, -0.37993, 0.30992, 0.34537, -0.50444,
        0.45763, -0.42877, 0.08128, -0.3904, 0.66912,
        -0.05428, 0.53632, 0.19774, -0.32198, 0.75276,
        -0.21986, -0.8214, -0.00392, -0.1659, 0.49944,
        -0.79443, 0.1633, -0.45374, -0.31666, -0.18989,
        -0.24459, 0.10463, -0.27652, 0.85595, 0.34657,
        0.50772, 0.00757, -0.82374, -0.18941, 0.16658,
        0.49473, -0.39923, -0.20758, 0.74339, -0.01213,
        -0.2024, -0.80239, -0.35502, -0.3982, -0.17492,
        0.68875, 0.1822, -0.08046, -0.39238, -0.57619,
        0.34555, 0.12488, -0.50703, -0.29269, 0.72267,
        -0.34713, 0.3847, -0.7532, 0.22176, -0.33913});

    // Expected third output: shape {2,2,6,6}, laid out one 6-element row per line.
    auto expV = NDArrayFactory::create<float>('c', {2,2,6,6}, {
        -4.15640000e-01, -5.30190000e-01, 5.29200000e-02, -7.15710000e-01, -1.10690000e-01, 1.37280000e-01,
        2.86620000e-01, 5.88200000e-02, 1.68760000e-01, -2.55000000e-03, -1.00090000e-01, 9.35890000e-01,
        -4.88230000e-01, 4.84470000e-01, -1.09150000e-01, -1.46810000e-01, 6.70320000e-01, 2.10040000e-01,
        1.00910000e-01, 4.35740000e-01, -6.90500000e-01, -3.61090000e-01, -4.38680000e-01, 1.83200000e-02,
        -5.48440000e-01, -2.86950000e-01, -4.23900000e-01, 5.78540000e-01, -2.10060000e-01, 2.41550000e-01,
        -4.42450000e-01, 4.56640000e-01, 5.48020000e-01, 3.32100000e-02, -5.40210000e-01, -4.97000000e-02,
        -6.36070000e-01, 5.57600000e-02, 3.28740000e-01, 3.81950000e-01, -4.21850000e-01, 4.00490000e-01,
        1.83740000e-01, -1.36190000e-01, -2.29380000e-01, -5.11090000e-01, -2.06580000e-01, 7.68890000e-01,
        -4.81880000e-01, -6.31100000e-01, 3.40000000e-04, -1.35730000e-01, 5.88210000e-01, 7.12900000e-02,
        2.25200000e-01, 4.30600000e-02, 9.08510000e-01, -3.08940000e-01, 1.51570000e-01, 6.02100000e-02,
        1.97510000e-01, -7.26560000e-01, 1.05370000e-01, 1.10600000e-02, -5.79750000e-01, -2.92870000e-01,
        4.89620000e-01, -2.24300000e-01, 5.31200000e-02, 6.92040000e-01, 2.72560000e-01, 3.92350000e-01,
        -6.84450000e-01, -5.18030000e-01, 2.92000000e-02, -4.96740000e-01, -1.17970000e-01, -4.08100000e-02,
        4.25340000e-01, -1.65500000e-02, -2.82400000e-02, -5.60180000e-01, 1.93050000e-01, -6.83340000e-01,
        8.08800000e-02, 4.38260000e-01, -2.48340000e-01, -6.36220000e-01, 2.37500000e-02, 5.78250000e-01,
        -6.10000000e-04, 3.00110000e-01, 1.17290000e-01, -6.92400000e-02, -9.19220000e-01, -2.15420000e-01,
        5.41330000e-01, -6.61130000e-01, -2.86360000e-01, -2.13500000e-02, -3.19580000e-01, 2.92020000e-01,
        2.25920000e-01, -1.10170000e-01, 9.17020000e-01, -1.71540000e-01, 3.39100000e-02, 2.55590000e-01,
        -4.86810000e-01, -2.32390000e-01, -4.31500000e-01, 3.75290000e-01, 4.98470000e-01, -3.65370000e-01,
        6.39700000e-02, -4.04150000e-01, -5.28310000e-01, 8.90000000e-02, -7.30460000e-01, -1.09390000e-01,
        -4.94030000e-01, 1.55540000e-01, -3.46720000e-01, -7.58460000e-01, 5.20000000e-04, 1.90420000e-01,
        2.55960000e-01, 3.17040000e-01, -3.47800000e-02, -3.01860000e-01, -3.57600000e-02, -8.60450000e-01,
        1.31650000e-01, 7.57150000e-01, -4.89030000e-01, 3.47710000e-01, -4.39400000e-02, 2.17750000e-01,
        -6.57270000e-01, 2.91000000e-01, 4.17280000e-01, 2.52880000e-01, -4.63400000e-01, -1.74620000e-01});

    // Integer args {1, 1, 16}: presumably (fullUV, calcUV, switchNum) -- TODO confirm
    // against the svd op declaration. The expected V above is 6x6 for this setting.
    nd4j::ops::svd op;
    auto outputs = op.execute({&x}, {}, {1, 1, 16});

    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto* sOut = outputs->at(0);
    auto* uOut = outputs->at(1);
    auto* vOut = outputs->at(2);

    ASSERT_TRUE(expS.isSameShape(sOut));
    ASSERT_TRUE(expU.isSameShape(uOut));
    ASSERT_TRUE(expV.isSameShape(vOut));

    ASSERT_TRUE(expS.equalsTo(sOut));

    if (!nd4j::Environment::getInstance()->isCPU()) {
        // Off-CPU only magnitudes are compared, so an element may differ in sign
        // from the reference and still pass.
        for (uint idx = 0; idx < expU.lengthOf(); ++idx) {
            ASSERT_NEAR(nd4j::math::nd4j_abs(expU.t<float>(idx)), nd4j::math::nd4j_abs(uOut->t<float>(idx)), 1e-5);
        }
        for (uint idx = 0; idx < expV.lengthOf(); ++idx) {
            ASSERT_NEAR(nd4j::math::nd4j_abs(expV.t<float>(idx)), nd4j::math::nd4j_abs(vOut->t<float>(idx)), 1e-5);
        }
    }
    else {
        // On CPU the outputs must match the reference via equalsTo.
        ASSERT_TRUE(expU.equalsTo(uOut));
        ASSERT_TRUE(expV.equalsTo(vOut));
    }

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, svd_test10) {

    // Input batch: shape {2,2,5,6}, i.e. four 5x6 matrices stored in 'c' order.
    auto x = NDArrayFactory::create<float>('c', {2,2,5,6}, {
        17, -11, 20, -10, 19, 13, -18, 6, -2, -6, -10, 4, -6, -4, 3, 16, 12,
        -15, 8, -8, 12, -1, 20, 19, -13, 0, 20, 17, -8, 16, -19, 7, -16, -14, -5, 7, 7, -5, 12, -15, 7, 8,
        1, -8, -17, 10, -11, 8, -10, 1, -6, 10, 15, 19, -15, 8, 2, 8, 12, 7, -5, 1, 8, 4, -13, 2, 19, -2, -10,
        -8, 11, 1, 20, -11, 4, 1, -17, -15, 0, -9, -4, -1, -6, -9, -13, 10, 7, -2, 15, -10, -1, 11, -20, -2,
        -1, -18, 12, 16, 8, -9, -20, -7, -20, 3, -9, 12, 8, -19, -2, 2, 1, 7, 10, -18, 13, 6, 14, 0, 19, 8});

    // Expected first output: shape {2,2,5}, one row of five values per input matrix.
    auto expS = NDArrayFactory::create<float>('c', {2,2,5}, {
        50.46507, 35.75599, 28.12787, 12.45245, 9.08545,
        38.56035, 30.62846, 26.31646, 19.42605, 3.01162,
        38.56369, 29.18881, 19.54565, 10.89746, 2.017,
        44.99108, 34.95059, 26.00453, 15.43898, 7.18752});

    // Expected second output: shape {2,2,5,5}.
    auto expU = NDArrayFactory::create<float>('c', {2,2,5,5}, {
        -0.73644, -0.10751, 0.10081, -0.00325, 0.66025,
        0.26329, 0.3079, 0.38582, 0.77696, 0.28872,
        0.03076, 0.03015, -0.9128, 0.36387, 0.18039,
        -0.61335, 0.10076, 0.01381, 0.40922, -0.66783,
        -0.10577, 0.93946, -0.0871, -0.31058, 0.04677,
        0.52823, 0.31163, -0.78777, 0.02322, -0.05234,
        -0.23942, -0.45801, -0.34248, 0.71286, 0.32778,
        0.26147, 0.60409, 0.39933, 0.46862, 0.43318,
        0.62118, -0.37993, 0.30992, 0.34537, -0.50444,
        0.45763, -0.42877, 0.08128, -0.3904, 0.66912,
        -0.05428, 0.53632, 0.19774, -0.32198, 0.75276,
        -0.21986, -0.8214, -0.00392, -0.1659, 0.49944,
        -0.79443, 0.1633, -0.45374, -0.31666, -0.18989,
        -0.24459, 0.10463, -0.27652, 0.85595, 0.34657,
        0.50772, 0.00757, -0.82374, -0.18941, 0.16658,
        0.49473, -0.39923, -0.20758, 0.74339, -0.01213,
        -0.2024, -0.80239, -0.35502, -0.3982, -0.17492,
        0.68875, 0.1822, -0.08046, -0.39238, -0.57619,
        0.34555, 0.12488, -0.50703, -0.29269, 0.72267,
        -0.34713, 0.3847, -0.7532, 0.22176, -0.33913});

    // Expected third output: shape {2,2,6,5}, laid out one 5-element row per line.
    auto expV = NDArrayFactory::create<float>('c', {2,2,6,5}, {
        -4.15640000e-01, -5.30190000e-01, 5.29200000e-02, -7.15710000e-01, -1.10690000e-01,
        2.86620000e-01, 5.88200000e-02, 1.68760000e-01, -2.55000000e-03, -1.00090000e-01,
        -4.88230000e-01, 4.84470000e-01, -1.09150000e-01, -1.46810000e-01, 6.70320000e-01,
        1.00910000e-01, 4.35740000e-01, -6.90500000e-01, -3.61090000e-01, -4.38680000e-01,
        -5.48440000e-01, -2.86950000e-01, -4.23900000e-01, 5.78540000e-01, -2.10060000e-01,
        -4.42450000e-01, 4.56640000e-01, 5.48020000e-01, 3.32100000e-02, -5.40210000e-01,
        -6.36070000e-01, 5.57600000e-02, 3.28740000e-01, 3.81950000e-01, -4.21850000e-01,
        1.83740000e-01, -1.36190000e-01, -2.29380000e-01, -5.11090000e-01, -2.06580000e-01,
        -4.81880000e-01, -6.31100000e-01, 3.40000000e-04, -1.35730000e-01, 5.88210000e-01,
        2.25200000e-01, 4.30600000e-02, 9.08510000e-01, -3.08940000e-01, 1.51570000e-01,
        1.97510000e-01, -7.26560000e-01, 1.05370000e-01, 1.10600000e-02, -5.79750000e-01,
        4.89620000e-01, -2.24300000e-01, 5.31200000e-02, 6.92040000e-01, 2.72560000e-01,
        -6.84450000e-01, -5.18030000e-01, 2.92000000e-02, -4.96740000e-01, -1.17970000e-01,
        4.25340000e-01, -1.65500000e-02, -2.82400000e-02, -5.60180000e-01, 1.93050000e-01,
        8.08800000e-02, 4.38260000e-01, -2.48340000e-01, -6.36220000e-01, 2.37500000e-02,
        -6.10000000e-04, 3.00110000e-01, 1.17290000e-01, -6.92400000e-02, -9.19220000e-01,
        5.41330000e-01, -6.61130000e-01, -2.86360000e-01, -2.13500000e-02, -3.19580000e-01,
        2.25920000e-01, -1.10170000e-01, 9.17020000e-01, -1.71540000e-01, 3.39100000e-02,
        -4.86810000e-01, -2.32390000e-01, -4.31500000e-01, 3.75290000e-01, 4.98470000e-01,
        6.39700000e-02, -4.04150000e-01, -5.28310000e-01, 8.90000000e-02, -7.30460000e-01,
        -4.94030000e-01, 1.55540000e-01, -3.46720000e-01, -7.58460000e-01, 5.20000000e-04,
        2.55960000e-01, 3.17040000e-01, -3.47800000e-02, -3.01860000e-01, -3.57600000e-02,
        1.31650000e-01, 7.57150000e-01, -4.89030000e-01, 3.47710000e-01, -4.39400000e-02,
        -6.57270000e-01, 2.91000000e-01, 4.17280000e-01, 2.52880000e-01, -4.63400000e-01});

    // Integer args {0, 1, 16}: presumably (fullUV, calcUV, switchNum) -- TODO confirm
    // against the svd op declaration. The expected V above is 6x5 for this setting.
    nd4j::ops::svd op;
    auto outputs = op.execute({&x}, {}, {0, 1, 16});

    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto* sOut = outputs->at(0);
    auto* uOut = outputs->at(1);
    auto* vOut = outputs->at(2);

    ASSERT_TRUE(expS.isSameShape(sOut));
    ASSERT_TRUE(expU.isSameShape(uOut));
    ASSERT_TRUE(expV.isSameShape(vOut));

    ASSERT_TRUE(expS.equalsTo(sOut));

    if (!nd4j::Environment::getInstance()->isCPU()) {
        // Off-CPU only magnitudes are compared, so an element may differ in sign
        // from the reference and still pass.
        for (uint idx = 0; idx < expU.lengthOf(); ++idx) {
            ASSERT_NEAR(nd4j::math::nd4j_abs(expU.t<float>(idx)), nd4j::math::nd4j_abs(uOut->t<float>(idx)), 1e-5);
        }
        for (uint idx = 0; idx < expV.lengthOf(); ++idx) {
            ASSERT_NEAR(nd4j::math::nd4j_abs(expV.t<float>(idx)), nd4j::math::nd4j_abs(vOut->t<float>(idx)), 1e-5);
        }
    }
    else {
        // On CPU the outputs must match the reference via equalsTo.
        ASSERT_TRUE(expU.equalsTo(uOut));
        ASSERT_TRUE(expV.equalsTo(vOut));
    }

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, svd_test11) {

    // Input: shape {2,2,3,3}; each line below is one 3x3 matrix in 'c' order.
    NDArray x('c', {2,2,3,3}, {
        0.7788, 0.8012, 0.7244, 0.2309, 0.7271, 0.1804, 0.5056, 0.8925, 0.5461,
        0.9234, 0.0856, 0.7938, 0.6591, 0.5555, 0.1596, 0.3087, 0.1548, 0.4695,
        0.7788, 0.8012, 0.7244, 0.2309, 0.7271, 0.1804, 0.5056, 0.8925, -0.5461,
        0.9234, 0.0856, -0.7938, 0.6591, 0.5555, 0.1500, 0.3087, 0.1548, 0.4695});

    // Expected first output: shape {2,2,3}, three values per input matrix.
    NDArray expS('c', {2,2,3}, {
        1.89671, 0.37095, 0.05525,
        1.51296, 0.52741, 0.17622,
        1.69095, 0.90438, 0.24688,
        1.33551, 0.87475, 0.21571});

    // Expected second output: shape {2,2,3,3}, one 3x3 matrix per line.
    NDArray expU('c', {2,2,3,3}, {
        6.9205e-01, 6.0147e-01, -3.9914e-01, 3.8423e-01, -7.7503e-01, -5.0170e-01, 6.1110e-01, -1.9384e-01, 7.6746e-01,
        7.8967e-01, 4.5442e-01, -4.1222e-01, 4.9381e-01, -8.6948e-01, -1.2540e-02, 3.6412e-01, 1.9366e-01, 9.1100e-01,
        7.1764e-01, 5.9844e-01, 3.5617e-01, 4.4477e-01, -3.1000e-04, -8.9564e-01, 5.3588e-01, -8.0116e-01, 2.6639e-01,
        8.7050e-01, -4.2088e-01, -2.5513e-01, 4.8622e-01, 6.5499e-01, 5.7843e-01, 7.6340e-02, 6.2757e-01, -7.7481e-01});

    // Expected third output: shape {2,2,3,3}, one 3x3 matrix per line.
    NDArray expV('c', {2,2,3,3}, {
        0.49383, 0.51614, -0.69981, 0.72718, -0.68641, 0.00688, 0.4768, 0.51228, 0.7143,
        0.77137, -0.17763, -0.6111, 0.26324, -0.7852, 0.56051, 0.57939, 0.59322, 0.55892,
        0.55149, 0.06737, 0.83146, 0.81413, -0.26072, -0.51887, 0.18182, 0.96306, -0.19863,
        0.85948, 0.2707, -0.4336, 0.26688, 0.48582, 0.83232, -0.43596, 0.83108, -0.34531});

    // Same integer args as svd_test10; unlike that test, U and V are compared
    // with equalsTo unconditionally (no CPU / non-CPU split).
    nd4j::ops::svd op;
    auto outputs = op.execute({&x}, {}, {0, 1, 16});

    ASSERT_EQ(ND4J_STATUS_OK, outputs->status());

    auto sOut = outputs->at(0);
    auto uOut = outputs->at(1);
    auto vOut = outputs->at(2);

    // Shapes first, then element values.
    ASSERT_TRUE(expS.isSameShape(sOut));
    ASSERT_TRUE(expU.isSameShape(uOut));
    ASSERT_TRUE(expV.isSameShape(vOut));

    ASSERT_TRUE(expS.equalsTo(sOut));
    ASSERT_TRUE(expU.equalsTo(uOut));
    ASSERT_TRUE(expV.equalsTo(vOut));

    delete outputs;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, elu_test1) {

    // 3x3 input with a positive first row, negative middle row, positive last row.
    auto x = NDArrayFactory::create<double>('c', {3,3}, {
        0.1, .2, .3,
        -.4, -.5, -.6,
        .7, .8, .9});

    // Positive entries are expected unchanged; the negative row is written as
    // 0.5 * <value>, matching the 0.5 passed to the op below (presumably alpha).
    auto exp = NDArrayFactory::create<double>('c', {3,3}, {
        .1, .2, .3,
        0.5*-0.32968, 0.5*-0.393469, 0.5*-0.451188,
        .7, .8, .9});

    nd4j::ops::elu op;
    auto outcome = op.execute({&x}, {0.5}, {});

    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto z = outcome->at(0);
    ASSERT_TRUE(exp.equalsTo(z));

    delete outcome;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, elu_bp_test1) {

    // Same 3x3 input pattern as the forward test: positive, negative, positive rows.
    auto x = NDArrayFactory::create<double>('c', {3, 3}, {
        0.1, .2, .3,
        -.4, -.5, -.6,
        .7, .8, .9});

    // Second op input: a 3x3 array filled with 2.
    auto eps = NDArrayFactory::create<double>('c', {3,3});
    eps.assign(2.);

    // Expected output: 2 where x is positive; the negative row is written as
    // 0.5 * <value>, matching the 0.5 passed to the op below (presumably alpha).
    auto exp = NDArrayFactory::create<double>('c', {3, 3}, {
        2, 2, 2,
        0.5*1.34064, 0.5*1.213061, 0.5*1.097623,
        2, 2, 2});

    nd4j::ops::elu_bp op;
    auto outcome = op.execute({ &x, &eps }, {0.5}, {});

    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto z = outcome->at(0);
    ASSERT_TRUE(exp.equalsTo(z));

    delete outcome;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, lrelu_test1) {

    auto x = NDArrayFactory::create<double>('c', {3,3}, {
        1, 2, 3,
        -4, -5, -6,
        7, 8, 9});

    // Positive entries are expected unchanged; negative entries are expected
    // multiplied by the 0.2 passed to the op (-4 -> -0.8, -5 -> -1, -6 -> -1.2).
    auto exp = NDArrayFactory::create<double>('c', {3,3}, {
        1, 2, 3,
        -0.8, -1., -1.2,
        7, 8, 9});

    nd4j::ops::lrelu op;
    auto outcome = op.execute({&x}, {0.2}, {});

    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto z = outcome->at(0);
    ASSERT_TRUE(exp.equalsTo(z));

    delete outcome;
}
|
|
|
|
TEST_F(DeclarableOpsTests3, lrelu_bp_test1) {

    auto x = NDArrayFactory::create<double>('c', {3,3}, {
        1, 2, 3,
        -4, -5, -6,
        7, 8, 9});

    // Second op input: every element is 2.
    auto eps = NDArrayFactory::create<double>('c', {3,3}, {
        2, 2, 2,
        2, 2, 2,
        2, 2, 2});

    // Expected output: 2 where x is positive, 0.4 (= 0.2 * 2) where x is negative.
    auto exp = NDArrayFactory::create<double>('c', {3,3}, {
        2, 2, 2,
        0.4, 0.4, 0.4,
        2, 2, 2});

    nd4j::ops::lrelu_bp op;
    auto outcome = op.execute({&x, &eps}, {0.2}, {});

    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto z = outcome->at(0);
    ASSERT_TRUE(exp.equalsTo(z));

    delete outcome;
}
|
|
|
|
///////////////////////////////////////////////////////////////////
|
|
TEST_F(DeclarableOpsTests3, selu_test1) {

    auto x = NDArrayFactory::create<double>('c', {3,3}, {
        1, 2, 3,
        -4, -5, -6,
        7, 8, 9});

    // Reference values for selu applied element-wise to x; the op is run with
    // no floating-point or integer arguments.
    auto exp = NDArrayFactory::create<double>('c', {3,3}, {
        1.050701, 2.101402, 3.152103,
        -1.725899, -1.746253, -1.753742,
        7.354907, 8.405608, 9.456309});

    nd4j::ops::selu op;
    auto outcome = op.execute({&x}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto z = outcome->at(0);
    ASSERT_TRUE(exp.equalsTo(z));

    delete outcome;
}
|
|
|
|
TEST_F(DeclarableOpsTests3, selu_test2) {

    // Despite the test name, this exercises the backprop op (selu_bp).
    auto x = NDArrayFactory::create<double>('c', {3,3}, {
        1, 2, 3,
        -4, -5, -6,
        7, 8, 9});

    // Second op input: every element is 2.
    auto eps = NDArrayFactory::create<double>('c', {3,3}, {
        2, 2, 2,
        2, 2, 2,
        2, 2, 2});

    // Reference output for selu_bp on (x, eps).
    auto exp = NDArrayFactory::create<double>('c', {3,3}, {
        2.101401, 2.101402, 2.101402,
        0.064401, 0.023692, 0.008716,
        2.101402, 2.101402, 2.101402});

    // NOTE(review): a 0.2 floating-point argument is passed here, while the
    // forward selu test passes none -- confirm whether selu_bp reads it.
    nd4j::ops::selu_bp op;
    auto outcome = op.execute({&x, &eps}, {0.2}, {});

    ASSERT_EQ(ND4J_STATUS_OK, outcome->status());

    auto z = outcome->at(0);
    ASSERT_TRUE(exp.equalsTo(z));

    delete outcome;
}
|
|
|
|
TEST_F(DeclarableOpsTests3, EQScalarTests_1) {
    Graph graph;

    // Operands: 1.0f and 1.0f.
    auto lhs = NDArrayFactory::create(1.0f);
    auto rhs = NDArrayFactory::create(1.0f);

    // eq_scalar is expected to evaluate to true for equal operands.
    nd4j::ops::eq_scalar op;
    auto verdict = op.evaluate({&lhs, &rhs});
    ASSERT_TRUE(verdict);
}
|
|
|
|
TEST_F(DeclarableOpsTests3, EQScalarTests_2) {
    Graph graph;

    // Operands: 2.0f and 1.0f.
    auto lhs = NDArrayFactory::create(2.0f);
    auto rhs = NDArrayFactory::create(1.0f);

    // eq_scalar is expected to evaluate to false for differing operands.
    nd4j::ops::eq_scalar op;
    auto verdict = op.evaluate({&lhs, &rhs});
    ASSERT_FALSE(verdict);
}
|
|
|
|
TEST_F(DeclarableOpsTests3, GTScalarTests_1) {
    Graph graph;

    // Operands: 1.0f and 1.0f.
    auto lhs = NDArrayFactory::create(1.0f);
    auto rhs = NDArrayFactory::create(1.0f);

    // gt_scalar is expected to evaluate to false when the operands are equal.
    nd4j::ops::gt_scalar op;
    auto verdict = op.evaluate({&lhs, &rhs});
    ASSERT_FALSE(verdict);
}
|
|
|
|
TEST_F(DeclarableOpsTests3, GTScalarTests_2) {
    Graph graph;

    // Operands: 2.0f and 1.0f.
    auto lhs = NDArrayFactory::create(2.0f);
    auto rhs = NDArrayFactory::create(1.0f);

    // gt_scalar is expected to evaluate to true when the first operand is larger.
    nd4j::ops::gt_scalar op;
    auto verdict = op.evaluate({&lhs, &rhs});
    ASSERT_TRUE(verdict);
}
|
|
|
|
TEST_F(DeclarableOpsTests3, GTEScalarTests_1) {
    Graph graph;

    // Operands: 1.0f and 1.0f.
    auto lhs = NDArrayFactory::create(1.0f);
    auto rhs = NDArrayFactory::create(1.0f);

    // gte_scalar is expected to evaluate to true when the operands are equal.
    nd4j::ops::gte_scalar op;
    auto verdict = op.evaluate({&lhs, &rhs});
    ASSERT_TRUE(verdict);
}
|
|
|
|
TEST_F(DeclarableOpsTests3, GTEScalarTests_2) {
    Graph graph;

    // Operands: 2.0f and 1.0f.
    auto lhs = NDArrayFactory::create(2.0f);
    auto rhs = NDArrayFactory::create(1.0f);

    // gte_scalar is expected to evaluate to true when the first operand is larger.
    nd4j::ops::gte_scalar op;
    auto verdict = op.evaluate({&lhs, &rhs});
    ASSERT_TRUE(verdict);
}
|
|
|
|
TEST_F(DeclarableOpsTests3, GTEScalarTests_3) {
    Graph graph;

    // Operands: 1.0f and 2.0f.
    auto lhs = NDArrayFactory::create(1.0f);
    auto rhs = NDArrayFactory::create(2.0f);

    // gte_scalar is expected to evaluate to false when the first operand is smaller.
    nd4j::ops::gte_scalar op;
    auto verdict = op.evaluate({&lhs, &rhs});
    ASSERT_FALSE(verdict);
}
|
|
|
|
TEST_F(DeclarableOpsTests3, LTEScalarTests_1) {
    Graph graph;

    // Operands: 1.0f and 1.0f.
    auto lhs = NDArrayFactory::create(1.0f);
    auto rhs = NDArrayFactory::create(1.0f);

    // lte_scalar is expected to evaluate to true when the operands are equal.
    nd4j::ops::lte_scalar op;
    auto verdict = op.evaluate({&lhs, &rhs});
    ASSERT_TRUE(verdict);
}
|
|
|
|
TEST_F(DeclarableOpsTests3, LTEScalarTests_2) {
    Graph graph;

    // Operands: 2.0f and 1.0f.
    auto lhs = NDArrayFactory::create(2.0f);
    auto rhs = NDArrayFactory::create(1.0f);

    // lte_scalar is expected to evaluate to false when the first operand is larger.
    nd4j::ops::lte_scalar op;
    auto verdict = op.evaluate({&lhs, &rhs});
    ASSERT_FALSE(verdict);
}
|
|
|
|
TEST_F(DeclarableOpsTests3, LTEScalarTests_3) {
    Graph graph;

    // Operands: 1.0f and 2.0f.
    auto lhs = NDArrayFactory::create(1.0f);
    auto rhs = NDArrayFactory::create(2.0f);

    // lte_scalar is expected to evaluate to true when the first operand is smaller.
    nd4j::ops::lte_scalar op;
    auto verdict = op.evaluate({&lhs, &rhs});
    ASSERT_TRUE(verdict);
}
|
|
|
|
TEST_F(DeclarableOpsTests3, NEQScalarTests_1) {
    Graph graph;

    // Operands: 1.0f and 1.0f.
    auto lhs = NDArrayFactory::create(1.0f);
    auto rhs = NDArrayFactory::create(1.0f);

    // neq_scalar is expected to evaluate to false when the operands are equal.
    nd4j::ops::neq_scalar op;
    auto verdict = op.evaluate({&lhs, &rhs});
    ASSERT_FALSE(verdict);
}
|
|
|
|
TEST_F(DeclarableOpsTests3, NEQScalarTests_2) {
    Graph graph;

    // Operands: 2.0f and 1.0f.
    auto lhs = NDArrayFactory::create(2.0f);
    auto rhs = NDArrayFactory::create(1.0f);

    // neq_scalar is expected to evaluate to true when the operands differ.
    nd4j::ops::neq_scalar op;
    auto verdict = op.evaluate({&lhs, &rhs});
    ASSERT_TRUE(verdict);
}
|
|
|
|
TEST_F(DeclarableOpsTests3, NOOPTests_1) {
    Graph graph;

    // Two scalar inputs; only the returned status is checked, not any output.
    auto first = NDArrayFactory::create(2.0f);
    auto second = NDArrayFactory::create(1.0f);

    nd4j::ops::noop op;
    auto outcome = op.execute({&first, &second}, {}, {});
    ASSERT_TRUE(outcome->status() == nd4j::Status::OK());

    delete outcome;
}
|