2019-06-06 14:21:15 +02:00
|
|
|
/*******************************************************************************
|
|
|
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
|
|
*
|
|
|
|
* This program and the accompanying materials are made available under the
|
|
|
|
* terms of the Apache License, Version 2.0 which is available at
|
|
|
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
|
|
*
|
|
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
|
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
|
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
|
|
* License for the specific language governing permissions and limitations
|
|
|
|
* under the License.
|
|
|
|
*
|
|
|
|
* SPDX-License-Identifier: Apache-2.0
|
|
|
|
******************************************************************************/
|
|
|
|
|
|
|
|
//
|
|
|
|
// @author raver119@gmail.com
|
|
|
|
//
|
|
|
|
|
|
|
|
#include "testlayers.h"
|
2020-03-02 10:49:41 +01:00
|
|
|
#include <array/NDArray.h>
|
|
|
|
#include <graph/GraphExecutioner.h>
|
2019-06-06 14:21:15 +02:00
|
|
|
#include <ops/declarable/CustomOperations.h>
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
using namespace sd;
|
|
|
|
using namespace sd::ops;
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
class ListOperationsTests : public testing::Test {
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
// write_list: each successful write of a new entry must grow the list by one
// element, and every call must report success.
TEST_F(ListOperationsTests, BasicTest_Write_1) {
    NDArrayList list(5);

    auto x = NDArrayFactory::create<double>('c', {128});
    x.linspace(1);

    sd::ops::write_list op;

    // first write; the integer argument (1) presumably selects the target index
    auto result = op.execute(&list, {&x}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, result.status());
    ASSERT_EQ(1, list.elements());

    // second write to another index; its status must be checked as well,
    // otherwise a failing op could go unnoticed
    auto result2 = op.execute(&list, {&x}, {}, {2});
    ASSERT_EQ(ND4J_STATUS_OK, result2.status());
    ASSERT_EQ(2, list.elements());
}
|
|
|
|
|
|
|
|
// stack_list: ten 100-element entries (entry i filled with i) must stack into
// a 10x100 array equal to one assembled row by row.
TEST_F(ListOperationsTests, BasicTest_Stack_1) {
    NDArrayList list(10);

    auto expected = NDArrayFactory::create<double>('c', {10, 100});
    auto expectedRows = expected.allTensorsAlongDimension({1});

    for (int idx = 0; idx < 10; ++idx) {
        // the heap array is handed to list.write() and never deleted here;
        // ownership presumably moves to the list
        auto chunk = NDArrayFactory::create_<double>('c', {100});
        chunk->assign(static_cast<double>(idx));
        list.write(idx, chunk);
        expectedRows.at(idx)->assign(chunk);
    }

    sd::ops::stack_list stackOp;
    auto result = stackOp.execute(&list, {}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto stacked = result.at(0);
    ASSERT_TRUE(expected.isSameShape(stacked));
    ASSERT_TRUE(expected.equalsTo(stacked));
}
|
|
|
|
|
2019-08-15 12:54:47 +02:00
|
|
|
// unstack_list: unstacking a 10x100 matrix (row e filled with e) along
// dimension 0 must fill the list with 10 entries matching the matrix rows.
TEST_F(ListOperationsTests, BasicTest_UnStackList_1) {
    NDArrayList list(0, true);

    auto x = NDArrayFactory::create<double>('c', {10, 100});
    auto tads = x.allTensorsAlongDimension({1});
    for (int e = 0; e < 10; e++) {
        // stack-allocated temporary: no manual delete needed
        auto row = NDArrayFactory::create<double>('c', {100});
        row.assign((double) e);
        tads.at(e)->assign(&row);
    }

    sd::ops::unstack_list op;
    auto result = op.execute(&list, {&x}, {}, {0});
    ASSERT_EQ(ND4J_STATUS_OK, result.status());
    ASSERT_EQ(10, list.elements());

    for (int e = 0; e < 10; e++) {
        // the array returned by list.read() is deleted by this test; compare and
        // release it before asserting so a failing assertion (which returns
        // early) does not leak it
        auto row = list.read(e);
        const bool matches = row->equalsTo(tads.at(e));
        delete row;
        ASSERT_TRUE(matches) << "unstacked entry " << e << " differs from source row";
    }
}
|
|
|
|
|
|
|
|
//TEST_F(ListOperationsTests, BasicTest_UnStackList_2) {
|
|
|
|
//// NDArrayList list(0, true);
|
|
|
|
// auto x = NDArrayFactory::create<double>('c', {10, 100});
|
|
|
|
// auto tads = x.allTensorsAlongDimension({1});
|
|
|
|
// for (int e = 0; e < 10; e++) {
|
|
|
|
// auto row = NDArrayFactory::create_<double>('c', {100});
|
|
|
|
// row->assign((double) e);
|
|
|
|
// //list.write(e, row);
|
|
|
|
// tads->at(e)->assign(row);
|
|
|
|
// delete row;
|
|
|
|
// }
|
|
|
|
//
|
2020-03-02 10:49:41 +01:00
|
|
|
// sd::ops::unstack_list op;
|
2019-08-15 12:54:47 +02:00
|
|
|
//
|
|
|
|
// auto result = op.execute(nullptr, {&x}, {}, {0});
|
|
|
|
//
|
2020-03-10 05:42:50 +01:00
|
|
|
// ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
2019-08-15 12:54:47 +02:00
|
|
|
// ASSERT_EQ(result->size(), 10);
|
|
|
|
//
|
2020-03-10 05:42:50 +01:00
|
|
|
// // auto z = result.at(0);
|
2019-08-15 12:54:47 +02:00
|
|
|
//// z->printShapeInfo("The first of");
|
|
|
|
//// ASSERT_TRUE(exp.isSameShape(z));
|
|
|
|
//// ASSERT_TRUE(exp.equalsTo(z));
|
|
|
|
// for (int e = 0; e < 10; e++) {
|
2020-03-10 05:42:50 +01:00
|
|
|
// auto row = result.at(e);
|
2019-08-15 12:54:47 +02:00
|
|
|
// ASSERT_TRUE(row->equalsTo(tads->at(e)));
|
|
|
|
// //list.write(e, row);
|
|
|
|
// }
|
|
|
|
//
|
2020-03-10 05:42:50 +01:00
|
|
|
//
|
2019-08-15 12:54:47 +02:00
|
|
|
// delete tads;
|
|
|
|
//}
|
2019-06-06 14:21:15 +02:00
|
|
|
|
|
|
|
// read_list: reading index 4 from a list whose entry i is filled with i must
// return a 1x100 array of fours.
TEST_F(ListOperationsTests, BasicTest_Read_1) {
    NDArrayList list(10);

    auto expected = NDArrayFactory::create<double>('c', {1, 100});
    expected.assign(4.0f);

    for (int idx = 0; idx < 10; ++idx) {
        auto source = NDArrayFactory::create<double>('c', {1, 100});
        source.assign(static_cast<double>(idx));
        // the list receives its own heap copy; `source` dies at scope exit
        list.write(idx, new NDArray(source.dup()));
    }

    sd::ops::read_list readOp;
    auto result = readOp.execute(&list, {}, {}, {4});
    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto fetched = result.at(0);
    ASSERT_TRUE(expected.isSameShape(fetched));
    ASSERT_TRUE(expected.equalsTo(fetched));
}
|
|
|
|
|
|
|
|
// pick_list: picking indices {1, 1, 3, 3} from a list whose entry i is filled
// with i must yield a 4x100 array with rows 1, 1, 3, 3.
TEST_F(ListOperationsTests, BasicTest_Pick_1) {
    NDArrayList list(10);

    auto expected = NDArrayFactory::create<double>('c', {4, 100});

    for (int idx = 0; idx < 10; ++idx) {
        auto source = NDArrayFactory::create<double>('c', {100});
        source.assign(static_cast<double>(idx));
        list.write(idx, new NDArray(source.dup()));
    }

    // expected row r mirrors the r-th picked index
    const float pickedValues[] = {1.0f, 1.0f, 3.0f, 3.0f};
    auto expectedRows = expected.allTensorsAlongDimension({1});
    for (int r = 0; r < 4; ++r)
        expectedRows.at(r)->assign(pickedValues[r]);

    sd::ops::pick_list pickOp;
    auto result = pickOp.execute(&list, {}, {}, {1, 1, 3, 3});
    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto picked = result.at(0);
    ASSERT_TRUE(expected.isSameShape(picked));
    ASSERT_TRUE(expected.equalsTo(picked));
}
|
|
|
|
|
|
|
|
// size_list: a list holding ten arrays must report its size as the int
// scalar 10.
TEST_F(ListOperationsTests, BasicTest_Size_1) {
    NDArrayList list(10);
    auto expected = NDArrayFactory::create<int>(10);

    for (int idx = 0; idx < 10; ++idx) {
        auto source = NDArrayFactory::create<double>('c', {100});
        source.assign(static_cast<double>(idx));
        list.write(idx, new NDArray(source.dup()));
    }

    sd::ops::size_list sizeOp;
    auto result = sizeOp.execute(&list, {}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto reported = result.at(0);
    ASSERT_TRUE(expected.isSameShape(reported));
    ASSERT_TRUE(expected.equalsTo(reported));
}
|
|
|
|
|
|
|
|
// create_list: invoked with no existing list (nullptr) and a 3x2 matrix input,
// the op must succeed and produce exactly one output.
TEST_F(ListOperationsTests, BasicTest_Create_1) {
    auto matrix = NDArrayFactory::create<double>('c', {3, 2});
    matrix.linspace(1);

    sd::ops::create_list op;
    auto result = op.execute(nullptr, {&matrix}, {}, {1, 1});
    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    // a single output is expected; no additional "flow" output is returned
    ASSERT_EQ(1, result.size());
}
|
|
|
|
|
|
|
|
// split_list: a 10x5 matrix (row r filled with r) split with lengths {2, 3, 5}
// must produce three list entries holding rows [0,2), [2,5) and [5,10).
TEST_F(ListOperationsTests, BasicTest_Split_1) {
    NDArrayList list(0, true);

    auto exp0 = NDArrayFactory::create<double>('c', {2, 5});
    auto exp1 = NDArrayFactory::create<double>('c', {3, 5});
    auto exp2 = NDArrayFactory::create<double>('c', {5, 5});

    auto matrix = NDArrayFactory::create<double>('c', {10, 5});

    auto lengths = NDArrayFactory::create<int>('c', {3});
    const int chunkSizes[] = {2, 3, 5};
    for (int i = 0; i < 3; ++i)
        lengths.p(i, chunkSizes[i]);

    auto matrixRows = matrix.allTensorsAlongDimension({1});
    auto rows0 = exp0.allTensorsAlongDimension({1});
    auto rows1 = exp1.allTensorsAlongDimension({1});
    auto rows2 = exp2.allTensorsAlongDimension({1});

    for (int r = 0; r < 10; ++r) {
        auto fill = NDArrayFactory::create<double>('c', {5});
        fill.assign(static_cast<double>(r));
        matrixRows.at(r)->assign(&fill);

        // mirror the row into the expected chunk it should land in, at its
        // offset relative to that chunk's first row
        if (r < 2)
            rows0.at(r)->assign(&fill);
        else if (r < 5)
            rows1.at(r - 2)->assign(&fill);
        else
            rows2.at(r - 5)->assign(&fill);
    }

    sd::ops::split_list splitOp;
    auto result = splitOp.execute(&list, {&matrix, &lengths}, {}, {});
    ASSERT_EQ(Status::OK(), result.status());

    ASSERT_EQ(3, list.height());

    ASSERT_TRUE(exp0.isSameShape(list.readRaw(0)));
    ASSERT_TRUE(exp0.equalsTo(list.readRaw(0)));

    ASSERT_TRUE(exp1.isSameShape(list.readRaw(1)));
    ASSERT_TRUE(exp1.equalsTo(list.readRaw(1)));

    ASSERT_TRUE(exp2.isSameShape(list.readRaw(2)));
    ASSERT_TRUE(exp2.equalsTo(list.readRaw(2)));
}
|
|
|
|
|
|
|
|
// scatter_list: scattering a 10x5 matrix (row r filled with r) with indices
// 9, 8, ..., 0 is expected to leave list slot s holding matrix row 9 - s.
TEST_F(ListOperationsTests, BasicTest_Scatter_1) {
    NDArrayList list(0, true);
    auto s = NDArrayFactory::create<double>(0.0);

    auto matrix = NDArrayFactory::create<double>('c', {10, 5});
    auto matrixRows = matrix.allTensorsAlongDimension({1});
    for (int r = 0; r < 10; ++r) {
        auto fill = NDArrayFactory::create<double>('c', {1, 5});
        fill.assign(static_cast<double>(r));
        matrixRows.at(r)->assign(&fill);
    }

    auto indices = NDArrayFactory::create<double>('c', {1, 10});
    for (int r = 0; r < matrix.rows(); ++r)
        indices.p(r, 9 - r);

    sd::ops::scatter_list scatterOp;
    auto result = scatterOp.execute(&list, {&indices, &matrix, &s}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    for (int slot = 0; slot < 10; ++slot) {
        auto expectedRow = matrixRows.at(9 - slot);
        auto chunk = list.readRaw(slot);

        ASSERT_TRUE(chunk->isSameShape(expectedRow));
        ASSERT_TRUE(chunk->equalsTo(expectedRow));
    }
}
|
|
|
|
|
|
|
|
// clone_list: cloning the list bound to variable -1 must produce, in variable 1,
// a distinct list that compares equal to the source.
TEST_F(ListOperationsTests, BasicTest_Clone_1) {
    // heap objects handed to the variable space; no delete follows here, so
    // they are presumably owned by it after putVariable()/trackList()
    auto list = new NDArrayList(0, true);

    VariableSpace variableSpace;
    auto var = new Variable(nullptr, nullptr, -1, 0);
    var->setNDArrayList(list);

    variableSpace.putVariable(-1, var);
    variableSpace.trackList(list);

    Context block(1, &variableSpace);
    block.pickInput(-1);

    sd::ops::clone_list cloneOp;

    // sanity check: the context's first input resolves to our list
    ASSERT_TRUE(list == block.variable(0)->getNDArrayList());

    auto status = cloneOp.execute(&block);
    ASSERT_EQ(ND4J_STATUS_OK, status);

    auto clonedList = variableSpace.getVariable(1)->getNDArrayList();
    ASSERT_TRUE(clonedList != nullptr);
    ASSERT_TRUE(list->equals(*clonedList));
}
|
|
|
|
|
|
|
|
// gather_list: gathering indices 9, 8, ..., 0 from a list whose entry i is
// filled with i is expected to give a 10x3 array whose row r holds 9 - r.
TEST_F(ListOperationsTests, BasicTest_Gather_1) {
    NDArrayList list(0, true);
    for (int idx = 0; idx < 10; ++idx) {
        auto source = NDArrayFactory::create<double>('c', {3});
        source.assign(static_cast<double>(idx));
        list.write(idx, new NDArray(source.dup()));
    }

    auto expected = NDArrayFactory::create<double>('c', {10, 3});
    auto expectedRows = expected.allTensorsAlongDimension({1});
    for (int r = 0; r < 10; ++r)
        expectedRows.at(r)->assign(9 - r);

    auto indices = NDArrayFactory::create<double>('c', {1, 10});
    indices.linspace(9, -1);

    sd::ops::gather_list gatherOp;
    auto result = gatherOp.execute(&list, {&indices}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result.status());
    ASSERT_EQ(1, result.size());

    auto gathered = result.at(0);
    ASSERT_TRUE(expected.isSameShape(gathered));
    ASSERT_TRUE(expected.equalsTo(gathered));
}
|
|
|
|
|
|
|
|
// Graph pipeline over a list: a 3x3 float matrix (row i holds i + 1) is turned
// into a list, split into chunks, each chunk goes through OneMinus and is
// written back, then the list is stacked. Expected stacked rows: 0, -1, -2.
TEST_F(ListOperationsTests, GraphTests_Sequential_1) {
    Graph graph;

    // heap-allocated: the pointer is handed to the variable space below
    auto matrix = NDArrayFactory::create_<float>('c', {3, 3});
    auto inputRows = matrix->allTensorsAlongDimension({1});
    for (int r = 0; r < inputRows.size(); ++r)
        inputRows.at(r)->assign((float) (r + 1));

    const float expectedValues[] = {0.f, -1.f, -2.f};
    auto expected = NDArrayFactory::create<float>('c', {3, 3});
    auto expectedRows = expected.allTensorsAlongDimension({1});
    for (int r = 0; r < 3; ++r)
        expectedRows.at(r)->assign(expectedValues[r]);

    // split lengths: presumably a length-3 vector filled with 1 (one row per chunk)
    auto indices = NDArrayFactory::valueOf<int>({3}, 1, 'c');

    auto variableSpace = graph.getVariableSpace();
    variableSpace->putVariable(-1, matrix);
    variableSpace->putVariable(-2, indices);

    // node 1: TRANSFORM_SAME op #0 applied to the input matrix
    auto nodeA = new Node(OpType_TRANSFORM_SAME, 0, 1, {-1});

    // node 2: create the list
    sd::ops::create_list createOp;
    auto nodeB = new Node(&createOp, 2, {1},{},{}, 0.0f, {}, {0, 1});

    // node 3: split the matrix into the list using the lengths in variable -2
    sd::ops::split_list splitOp;
    auto nodeC = new Node(&splitOp, 3, {2, 1, -2});

    // nodes 5-7: read chunks; node 3 is an input so reads run after the split
    sd::ops::read_list readOp;
    auto nodeD0 = new Node(&readOp, 5, {2, 3}, {},{}, 0.0f, {}, {0});
    auto nodeD1 = new Node(&readOp, 6, {2, 3}, {},{}, 0.0f, {}, {1});
    auto nodeD2 = new Node(&readOp, 7, {2, 3}, {},{}, 0.0f, {}, {2});

    // nodes 10-12: OneMinus on each chunk independently
    auto nodeE0 = new Node(OpType_TRANSFORM_SAME, sd::transform::OneMinus, 10, {5});
    auto nodeE1 = new Node(OpType_TRANSFORM_SAME, sd::transform::OneMinus, 11, {6});
    auto nodeE2 = new Node(OpType_TRANSFORM_SAME, sd::transform::OneMinus, 12, {7});

    // nodes 15-17: write the transformed chunks back into the list
    sd::ops::write_list writeOp;
    auto nodeF0 = new Node(&writeOp, 15, {2, 10}, {},{}, 0.0f, {}, {0});
    auto nodeF1 = new Node(&writeOp, 16, {2, 11}, {},{}, 0.0f, {}, {1});
    auto nodeF2 = new Node(&writeOp, 17, {2, 12}, {},{}, 0.0f, {}, {2});

    // node 20: stack the list back into a matrix once all writes are done
    sd::ops::stack_list stackOp;
    auto nodeG = new Node(&stackOp, 20, {2, 15, 16, 17});

    for (auto node : {nodeA, nodeB, nodeC,
                      nodeD0, nodeD1, nodeD2,
                      nodeE0, nodeE1, nodeE2,
                      nodeF0, nodeF1, nodeF2,
                      nodeG})
        graph.addNode(node);

    // structural check: expected execution layer of every node
    graph.buildGraph();

    ASSERT_EQ(0, nodeA->getLayer());
    ASSERT_EQ(1, nodeB->getLayer());
    ASSERT_EQ(2, nodeC->getLayer());

    ASSERT_EQ(3, nodeD0->getLayer());
    ASSERT_EQ(3, nodeD1->getLayer());
    ASSERT_EQ(3, nodeD2->getLayer());

    ASSERT_EQ(4, nodeE0->getLayer());
    ASSERT_EQ(4, nodeE1->getLayer());
    ASSERT_EQ(4, nodeE2->getLayer());

    ASSERT_EQ(5, nodeF0->getLayer());
    ASSERT_EQ(5, nodeF1->getLayer());
    ASSERT_EQ(5, nodeF2->getLayer());

    ASSERT_EQ(6, nodeG->getLayer());

    auto status = GraphExecutioner::execute(&graph);
    ASSERT_EQ(ND4J_STATUS_OK, status);

    ASSERT_TRUE(variableSpace->hasVariable(2));
    auto list = variableSpace->getVariable(2)->getNDArrayList();
    ASSERT_TRUE(list != nullptr);
    ASSERT_EQ(3, list->height());
    ASSERT_EQ(3, list->elements());

    ASSERT_TRUE(variableSpace->hasVariable(20));
    auto stack = variableSpace->getVariable(20)->getNDArray();
    ASSERT_TRUE(stack != nullptr);
    ASSERT_TRUE(expected.isSameShape(stack));
    ASSERT_TRUE(expected.equalsTo(stack));
}
|
|
|
|
|
|
|
|
|
|
|
|
// Graph pipeline over a list with scatter/pick: a 3x3 double matrix (row i
// holds i + 1) is scattered into a list, chunks are read, passed through
// OneMinus and written back one after another, then picked into a matrix.
// Expected picked rows: 0, -1, -2.
TEST_F(ListOperationsTests, GraphTests_Sequential_2) {
    Graph graph;

    // heap-allocated inputs: the pointers are handed to the variable space below
    auto scalar = NDArrayFactory::create_<double>(0.0f);
    auto matrix = NDArrayFactory::create_<double>('c', {3, 3});
    auto inputRows = matrix->allTensorsAlongDimension({1});
    for (int r = 0; r < inputRows.size(); ++r)
        inputRows.at(r)->assign((float) (r + 1));

    const float expectedValues[] = {0.f, -1.f, -2.f};
    auto expected = NDArrayFactory::create<double>('c', {3, 3});
    auto expectedRows = expected.allTensorsAlongDimension({1});
    for (int r = 0; r < 3; ++r)
        expectedRows.at(r)->assign(expectedValues[r]);

    // scatter/pick indices 0, 1, 2
    auto indices = NDArrayFactory::create_<double>('c', {1, 3});
    indices->linspace(0);

    auto variableSpace = graph.getVariableSpace();
    variableSpace->putVariable(-1, matrix);
    variableSpace->putVariable(-2, indices);
    variableSpace->putVariable(-3, scalar);

    // node 1: TRANSFORM_SAME op #0 applied to the input matrix
    auto nodeA = new Node(OpType_TRANSFORM_SAME, 0, 1, {-1});

    // node 2: create the list
    sd::ops::create_list createOp;
    auto nodeB = new Node(&createOp, 2, {1},{},{}, 0.0f, {}, {0, 1});

    // node 3: scatter the matrix rows into the list at the given indices
    sd::ops::scatter_list scatterOp;
    auto nodeC = new Node(&scatterOp, 3, {2, -2, 1, -3});

    // nodes 5-7: reads; D1 and D2 also depend on the previous write (15, 16),
    // which chains the read/transform/write triples sequentially
    sd::ops::read_list readOp;
    auto nodeD0 = new Node(&readOp, 5, {2, 3}, {},{}, 0.0f, {}, {0});
    auto nodeD1 = new Node(&readOp, 6, {2, 3, 15}, {},{}, 0.0f, {}, {1});
    auto nodeD2 = new Node(&readOp, 7, {2, 3, 16}, {},{}, 0.0f, {}, {2});

    // nodes 10-12: OneMinus on each chunk
    auto nodeE0 = new Node(OpType_TRANSFORM_SAME, sd::transform::OneMinus, 10, {5});
    auto nodeE1 = new Node(OpType_TRANSFORM_SAME, sd::transform::OneMinus, 11, {6});
    auto nodeE2 = new Node(OpType_TRANSFORM_SAME, sd::transform::OneMinus, 12, {7});

    // nodes 15-17: write the transformed chunks back into the list
    sd::ops::write_list writeOp;
    auto nodeF0 = new Node(&writeOp, 15, {2, 10}, {},{}, 0.0f, {}, {0});
    auto nodeF1 = new Node(&writeOp, 16, {2, 11}, {},{}, 0.0f, {}, {1});
    auto nodeF2 = new Node(&writeOp, 17, {2, 12}, {},{}, 0.0f, {}, {2});

    // node 20: pick the chunks back into a matrix once all writes are done
    sd::ops::pick_list pickOp;
    auto nodeG = new Node(&pickOp, 20, {2, -2, 15, 16, 17});

    for (auto node : {nodeA, nodeB, nodeC,
                      nodeD0, nodeD1, nodeD2,
                      nodeE0, nodeE1, nodeE2,
                      nodeF0, nodeF1, nodeF2,
                      nodeG})
        graph.addNode(node);

    // structural check: expected execution layer of every node
    graph.buildGraph();

    ASSERT_EQ(0, nodeA->getLayer());
    ASSERT_EQ(1, nodeB->getLayer());
    ASSERT_EQ(2, nodeC->getLayer());

    ASSERT_EQ(3, nodeD0->getLayer());
    ASSERT_EQ(4, nodeE0->getLayer());
    ASSERT_EQ(5, nodeF0->getLayer());

    ASSERT_EQ(6, nodeD1->getLayer());
    ASSERT_EQ(7, nodeE1->getLayer());
    ASSERT_EQ(8, nodeF1->getLayer());

    ASSERT_EQ(9, nodeD2->getLayer());
    ASSERT_EQ(10, nodeE2->getLayer());
    ASSERT_EQ(11, nodeF2->getLayer());

    ASSERT_EQ(12, nodeG->getLayer());

    auto status = GraphExecutioner::execute(&graph);
    ASSERT_EQ(ND4J_STATUS_OK, status);

    ASSERT_TRUE(variableSpace->hasVariable(2));
    auto list = variableSpace->getVariable(2)->getNDArrayList();
    ASSERT_TRUE(list != nullptr);
    ASSERT_EQ(3, list->height());
    ASSERT_EQ(3, list->elements());

    ASSERT_TRUE(variableSpace->hasVariable(20));
    auto stack = variableSpace->getVariable(20)->getNDArray();
    ASSERT_TRUE(stack != nullptr);
    ASSERT_TRUE(expected.isSameShape(stack));
    ASSERT_TRUE(expected.equalsTo(stack));
}
|