// cavis/libnd4j/tests_cpu/layers_tests/SwitchTests.cpp
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// Created by raver119 on 13.10.2017.
//
#include "testlayers.h"
#include <ops/declarable/CustomOperations.h>
using namespace sd;
using namespace sd::ops;
using namespace sd::graph;
2019-06-06 14:21:15 +02:00
class SwitchTests : public testing::Test {
public:
};
/**
 * Switch fed by a directly connected eq_scalar condition node.
 *
 * Both condition variables hold 0, so the equality condition is TRUE and only
 * the node attached to Switch output :1 (OneMinus) is expected to be active.
 * Expected value: 1 - abs(abs(-119)) = -118.
 */
TEST_F(SwitchTests, SwitchTest1) {
    Graph graph;
    FlowPath flowPath;

    auto variableSpace = graph.getVariableSpace();
    variableSpace->setFlowPath(&flowPath);

    auto input = NDArrayFactory::create_<float>('c', {32, 100});
    input->assign(-119.0f);

    auto conditionX = NDArrayFactory::create_<float>('c', {1, 1});
    conditionX->p(0, 0.0f);

    auto conditionY = NDArrayFactory::create_<float>('c', {1, 1});
    conditionY->p(0, 0.0f);

    variableSpace->putVariable(-1, input);
    variableSpace->putVariable(-2, conditionX);
    variableSpace->putVariable(-3, conditionY);

    // these are just 2 ops that are executed sequentially. We don't really care about them
    auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2});
    auto nodeB = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {3});

    // this is our condition op, we'll be using Equals condition, on variables conditionX and conditionY (ids -2 and -3 respectively)
    // we're creating this op manually in tests, as always.
    sd::ops::eq_scalar eqOp;
    auto nodeCondition = new Node(&eqOp, 119, {-2, -3});
    //nodeCondition->setOpType(OpType_BOOLEAN);

    // now, this is Switch operation. It takes BooleanOperation operation in,
    // and based on evaluation result (true/false) - it'll pass data via :0 or :1 output
    // other idx will be considered disabled, and that graph branch won't be executed
    sd::ops::Switch switchOp;
    auto nodeSwitch = new Node(&switchOp, 3, {2, 119}, {4, 5});

    // these 2 ops are connected to FALSE and TRUE outputs. output :0 considered FALSE, and output :1 considered TRUE
    auto nodeZ0 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 4, {}, {});
    nodeZ0->pickInput(3, 0);
    auto nodeZ1 = new Node(OpType_TRANSFORM_SAME, transform::OneMinus, 5, {}, {});
    nodeZ1->pickInput(3, 1);

    graph.addNode(nodeA);
    graph.addNode(nodeB);
    graph.addNode(nodeCondition);
    graph.addNode(nodeSwitch);
    graph.addNode(nodeZ0);
    graph.addNode(nodeZ1);

    graph.buildGraph();

    // we're making sure nodes connected to the Switch have no other inputs in this Graph
    ASSERT_EQ(1, nodeZ0->input()->size());
    ASSERT_EQ(1, nodeZ1->input()->size());

    // just validating topo sort
    ASSERT_EQ(0, nodeA->getLayer());
    ASSERT_EQ(0, nodeCondition->getLayer());
    ASSERT_EQ(1, nodeB->getLayer());
    ASSERT_EQ(2, nodeSwitch->getLayer());
    ASSERT_EQ(3, nodeZ0->getLayer());
    ASSERT_EQ(3, nodeZ1->getLayer());

    // executing graph
    Nd4jStatus status = GraphExecutioner::execute(&graph);

    ASSERT_EQ(ND4J_STATUS_OK, status);

    // nd4j_printf("Z0: [%i]; Z1: [%i]\n", flowPath.isNodeActive(nodeZ0->id()), flowPath.isNodeActive(nodeZ1->id()));

    // we know that Switch got TRUE evaluation, so :0 should be inactive
    ASSERT_FALSE(flowPath.isNodeActive(nodeZ0->id()));

    // and :1 should be active
    ASSERT_TRUE(flowPath.isNodeActive(nodeZ1->id()));

    // the result is looked up as output 0 of node 5 (nodeZ1)
    std::pair<int, int> expectedResultIndex(5, 0);
    ASSERT_TRUE(variableSpace->hasVariable(expectedResultIndex));

    // getting output of nodeZ1
    auto output = variableSpace->getVariable(expectedResultIndex)->getNDArray();

    // and verifying it against known expected value
    ASSERT_NEAR(-118.0f, output->e<float>(0), 1e-5f);
}
/**
 * Switch fed by a condition that is attached to a named scope.
 *
 * Both condition variables hold 1, so the equality condition is TRUE and only
 * the node attached to Switch output :1 (OneMinus) is expected to be active.
 * Expected value: 1 - abs(abs(-119)) = -118.
 */
TEST_F(SwitchTests, SwitchTest2) {
    Graph graph;
    FlowPath flowPath;

    auto variableSpace = graph.getVariableSpace();
    variableSpace->setFlowPath(&flowPath);

    auto input = NDArrayFactory::create_<float>('c', {32, 100});
    input->assign(-119.0f);

    auto conditionX = NDArrayFactory::create_<float>('c', {1, 1});
    conditionX->p(0, 1.0f);

    auto conditionY = NDArrayFactory::create_<float>('c', {1, 1});
    conditionY->p(0, 1.0f);

    variableSpace->putVariable(-1, input);
    variableSpace->putVariable(-2, conditionX);
    variableSpace->putVariable(-3, conditionY);

    // two sequential Abs ops that produce the data input of the Switch
    auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2});
    auto nodeB = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {3});

    // scope node (id 3) that the condition node is attached to below
    auto scopeCondition = new Node(OpType_LOGIC, logic::Scope, 3);
    scopeCondition->setName("scopeCondition");

    // condition node: eq_scalar over variables -2 and -3, registered in scope 3
    auto nodeCondition = new Node(OpType_LOGIC, logic::Scope, 119, {-2, -3});
    nodeCondition->setScopeInfo(3, "scopeCondition");

    sd::ops::eq_scalar eqOp;
    nodeCondition->setCustomOp(&eqOp);

    // Switch node takes the scope (id 3) and the data node (id 2) as inputs
    auto nodeSwitch = new Node(OpType_LOGIC, logic::Switch, 5, {3, 2});

    sd::ops::Switch switchOp;
    nodeSwitch->setCustomOp(&switchOp);

    // these 2 ops are connected to FALSE and TRUE outputs. output :0 considered FALSE, and output :1 considered TRUE
    auto nodeZ0 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 6, {}, {});
    nodeZ0->pickInput(5, 0);
    auto nodeZ1 = new Node(OpType_TRANSFORM_SAME, transform::OneMinus, 7, {}, {});
    nodeZ1->pickInput(5, 1);

    graph.addNode(nodeA);
    graph.addNode(nodeB);
    graph.addNode(scopeCondition);
    graph.addNode(nodeCondition);
    graph.addNode(nodeSwitch);
    graph.addNode(nodeZ0);
    graph.addNode(nodeZ1);

    Nd4jStatus status = GraphExecutioner::execute(&graph);

    ASSERT_EQ(ND4J_STATUS_OK, status);

    // condition is TRUE: the :0 branch must be disabled, the :1 branch enabled
    ASSERT_FALSE(flowPath.isNodeActive(nodeZ0->id()));
    ASSERT_TRUE(flowPath.isNodeActive(nodeZ1->id()));

    auto z = graph.getVariableSpace()->getVariable(7)->getNDArray();

    // abs(-119) = 119; 1 - 119 = -118
    ASSERT_NEAR(-118.f, z->e<float>(0), 1e-5);
}
/**
 * Switch fed by a scoped condition that evaluates to FALSE.
 *
 * The condition variables hold 2 and 1, so the equality condition is FALSE and
 * only the node attached to Switch output :0 (Neg) is expected to be active.
 * Expected value: -abs(abs(-119)) = -119.
 */
TEST_F(SwitchTests, SwitchTest3) {
    Graph graph;
    FlowPath flowPath;

    auto variableSpace = graph.getVariableSpace();
    variableSpace->setFlowPath(&flowPath);

    auto input = NDArrayFactory::create_<float>('c', {32, 100});
    input->assign(-119.0f);

    auto conditionX = NDArrayFactory::create_<float>('c', {1, 1});
    conditionX->p(0, 2.0f);

    auto conditionY = NDArrayFactory::create_<float>('c', {1, 1});
    conditionY->p(0, 1.0f);

    variableSpace->putVariable(-1, input);
    variableSpace->putVariable(-2, conditionX);
    variableSpace->putVariable(-3, conditionY);

    // two sequential Abs ops that produce the data input of the Switch
    auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2});
    auto nodeB = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {3});

    // scope node (id 3) that the condition node is attached to below
    auto scopeCondition = new Node(OpType_LOGIC, logic::Scope, 3);
    scopeCondition->setName("scopeCondition");

    // condition node: eq_scalar over variables -2 and -3, registered in scope 3
    auto nodeCondition = new Node(OpType_LOGIC, logic::Scope, 119, {-2, -3});
    nodeCondition->setScopeInfo(3, "scopeCondition");

    sd::ops::eq_scalar eqOp;
    nodeCondition->setCustomOp(&eqOp);

    // Switch node takes the scope (id 3) and the data node (id 2) as inputs
    auto nodeSwitch = new Node(OpType_LOGIC, logic::Switch, 5, {3, 2});

    sd::ops::Switch switchOp;
    nodeSwitch->setCustomOp(&switchOp);

    // these 2 ops are connected to FALSE and TRUE outputs. output :0 considered FALSE, and output :1 considered TRUE
    auto nodeZ0 = new Node(OpType_TRANSFORM_SAME, transform::Neg, 6, {}, {});
    nodeZ0->pickInput(5, 0);
    auto nodeZ1 = new Node(OpType_TRANSFORM_SAME, transform::OneMinus, 7, {}, {});
    nodeZ1->pickInput(5, 1);

    graph.addNode(nodeA);
    graph.addNode(nodeB);
    graph.addNode(scopeCondition);
    graph.addNode(nodeCondition);
    graph.addNode(nodeSwitch);
    graph.addNode(nodeZ0);
    graph.addNode(nodeZ1);

    Nd4jStatus status = GraphExecutioner::execute(&graph);

    ASSERT_EQ(ND4J_STATUS_OK, status);

    // condition is FALSE: the :0 branch must be enabled, the :1 branch disabled
    ASSERT_TRUE(flowPath.isNodeActive(nodeZ0->id()));
    ASSERT_FALSE(flowPath.isNodeActive(nodeZ1->id()));

    auto z = graph.getVariableSpace()->getVariable(6)->getNDArray();

    // abs(-119) = 119; Neg(119) = -119
    ASSERT_NEAR(-119.f, z->e<float>(0), 1e-5);
}