816 lines
51 KiB
C++
816 lines
51 KiB
C++
|
/*******************************************************************************
|
||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||
|
*
|
||
|
* This program and the accompanying materials are made available under the
|
||
|
* terms of the Apache License, Version 2.0 which is available at
|
||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||
|
* License for the specific language governing permissions and limitations
|
||
|
* under the License.
|
||
|
*
|
||
|
* SPDX-License-Identifier: Apache-2.0
|
||
|
******************************************************************************/
|
||
|
|
||
|
//
|
||
|
// @author raver119@gmail.com
|
||
|
//
|
||
|
|
||
|
#include "testlayers.h"

#include <chrono>
#include <memory>

#include <flatbuffers/flatbuffers.h>
#include <graph/generated/node_generated.h>
#include <graph/generated/graph_generated.h>
#include <graph/generated/result_generated.h>
#include <graph/Node.h>
#include <graph/Graph.h>
#include <GraphExecutioner.h>
#include <ops/declarable/CustomOperations.h>
|
||
|
|
||
|
using namespace nd4j;
|
||
|
using namespace nd4j::graph;
|
||
|
|
||
|
class FlatBuffersTest : public testing::Test {
|
||
|
public:
|
||
|
int alpha = 0;
|
||
|
|
||
|
Nd4jLong *cShape = new Nd4jLong[8]{2, 2, 2, 2, 1, 8192, 1, 99};
|
||
|
Nd4jLong *fShape = new Nd4jLong[8]{2, 2, 2, 1, 2, 8192, 1, 102};
|
||
|
|
||
|
FlatBuffersTest() {
|
||
|
Environment::getInstance()->setDebug(false);
|
||
|
Environment::getInstance()->setVerbose(false);
|
||
|
Environment::getInstance()->setProfiling(false);
|
||
|
}
|
||
|
|
||
|
~FlatBuffersTest() {
|
||
|
Environment::getInstance()->setDebug(false);
|
||
|
Environment::getInstance()->setVerbose(false);
|
||
|
Environment::getInstance()->setProfiling(false);
|
||
|
|
||
|
delete[] cShape;
|
||
|
delete[] fShape;
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/**
|
||
|
* Simple test that creates Node & reads it
|
||
|
*/
|
||
|
TEST_F(FlatBuffersTest, BasicTest1) {
|
||
|
flatbuffers::FlatBufferBuilder builder(1024);
|
||
|
|
||
|
auto name = builder.CreateString("wow");
|
||
|
|
||
|
auto node = CreateFlatNode(builder, -1, name, OpType_TRANSFORM_SAME, transform::Ones, {0});
|
||
|
|
||
|
builder.Finish(node);
|
||
|
|
||
|
// now we have our buffer with data
|
||
|
uint8_t *buf = builder.GetBufferPointer();
|
||
|
int size = builder.GetSize();
|
||
|
ASSERT_TRUE(size > 0);
|
||
|
|
||
|
|
||
|
|
||
|
auto restored = GetFlatNode(buf);
|
||
|
|
||
|
auto gA = new Node(restored);
|
||
|
auto gB = new Node(restored);
|
||
|
|
||
|
ASSERT_TRUE(gA->equals(gB));
|
||
|
|
||
|
delete gA;
|
||
|
delete gB;
|
||
|
}
|
||
|
|
||
|
|
||
|
/**
 * Builds a two-node FlatGraph by hand and round-trips it through FlatBuffers.
 *
 * Wiring: node 1 (Abs) reads variable -1 and outputs to node 2; node 2 (Cosine)
 * reads node 1. The test checks the restored structure and executes the graph
 * both in memory and via executeFlatBuffer. Finally it verifies the mean of the
 * result: cos(|-2|) ~= -0.4161468.
 */
TEST_F(FlatBuffersTest, FlatGraphTest1) {
    flatbuffers::FlatBufferBuilder builder(4096);

    // Source array for external variable -1; every element is -2.
    // unique_ptr frees it even when an ASSERT_* below returns early.
    std::unique_ptr<NDArray> array(NDArrayFactory::create_<float>('c', {5, 5}));
    array->assign(-2.0f);

    // Named so they do not shadow the fixture's fShape member.
    auto flatShape = builder.CreateVector(array->getShapeInfoAsFlatVector());
    auto flatBuffer = builder.CreateVector(array->asByteVector());

    auto fArray = CreateFlatArray(builder, flatShape, flatBuffer, nd4j::graph::DataType::DataType_FLOAT);
    auto fVid = CreateIntPair(builder, -1);

    auto fVar = CreateFlatVariable(builder, fVid, 0, nd4j::graph::DataType::DataType_FLOAT, 0, fArray);

    std::vector<int> outputs1, outputs2, inputs1, inputs2;
    outputs1.push_back(2);
    outputs2.push_back(0);

    inputs1.push_back(-1);
    inputs2.push_back(1);

    auto vec1 = builder.CreateVector(outputs1);
    auto vec2 = builder.CreateVector(outputs2);

    auto in1 = builder.CreateVector(inputs1);
    auto in2 = builder.CreateVector(inputs2);

    auto name1 = builder.CreateString("wow1");
    auto name2 = builder.CreateString("wow2");

    auto node1 = CreateFlatNode(builder, 1, name1, OpType_TRANSFORM_SAME, transform::Abs, 0, in1, 0, vec1);
    auto node2 = CreateFlatNode(builder, 2, name2, OpType_TRANSFORM_STRICT, transform::Cosine, 0, in2, 0, vec2);

    std::vector<flatbuffers::Offset<FlatVariable>> variables_vector;
    variables_vector.push_back(fVar);

    std::vector<flatbuffers::Offset<FlatNode>> nodes_vector;
    nodes_vector.push_back(node1);
    nodes_vector.push_back(node2);

    auto nodes = builder.CreateVector(nodes_vector);
    auto variables = builder.CreateVector(variables_vector);

    FlatGraphBuilder graphBuilder(builder);
    graphBuilder.add_variables(variables);
    graphBuilder.add_id(119);
    graphBuilder.add_nodes(nodes);

    auto flatGraph = graphBuilder.Finish();
    builder.Finish(flatGraph);

    uint8_t *buf = builder.GetBufferPointer();
    auto size = builder.GetSize();
    ASSERT_TRUE(size > 0);

    auto restoredGraph = GetFlatGraph(buf);
    ASSERT_EQ(119, restoredGraph->id());
    ASSERT_EQ(2, restoredGraph->nodes()->size());

    // checking op nodes
    ASSERT_EQ(transform::Abs, restoredGraph->nodes()->Get(0)->opNum());
    ASSERT_EQ(transform::Cosine, restoredGraph->nodes()->Get(1)->opNum());

    // checking variables
    ASSERT_EQ(1, restoredGraph->variables()->size());
    ASSERT_EQ(-1, restoredGraph->variables()->Get(0)->id()->first());

    Graph graph(restoredGraph);

    ASSERT_EQ(2, graph.totalNodes());
    ASSERT_EQ(1, graph.rootNodes());

    auto vs = graph.getVariableSpace();

    ASSERT_EQ(OutputMode_IMPLICIT, graph.getExecutorConfiguration()->_outputMode);

    ASSERT_EQ(3, vs->totalEntries());
    ASSERT_EQ(1, vs->externalEntries());
    ASSERT_EQ(2, vs->internalEntries());

    auto var = vs->getVariable(-1)->getNDArray();

    ASSERT_TRUE(var != nullptr);
    // float reduction: compare with a tolerance rather than exact equality
    ASSERT_NEAR(-2.0f, var->reduceNumber(reduce::Mean).e<float>(0), 1e-5);

    // Execute the in-memory graph; the status used to be silently discarded.
    Nd4jStatus status = nd4j::graph::GraphExecutioner::execute(&graph);
    ASSERT_EQ(ND4J_STATUS_OK, status);

    // Execute the serialized graph through the flatbuffer entry point.
    std::unique_ptr<ResultWrapper> resultWrapper(
            nd4j::graph::GraphExecutioner::executeFlatBuffer(reinterpret_cast<Nd4jPointer>(buf)));
    ASSERT_TRUE(resultWrapper != nullptr);

    auto flatResults = GetFlatResult(resultWrapper->pointer());

    ASSERT_EQ(1, flatResults->variables()->size());
    ASSERT_TRUE(flatResults->variables()->Get(0)->name() != nullptr);
    ASSERT_TRUE(flatResults->variables()->Get(0)->name()->c_str() != nullptr);

    // Declared after resultWrapper, so it is destroyed first (as in the original cleanup order).
    std::unique_ptr<Variable> var0(new Variable(flatResults->variables()->Get(0)));
    auto avg = var0->getNDArray()->reduceNumber(reduce::Mean);
    //avg.printIndexedBuffer("FBT_1");

    // Abs(-2) = 2, then cos(2) ~= -0.4161468
    ASSERT_NEAR(-0.4161468, avg.e<float>(0), 1e-5);
}
|
||
|
|
||
|
/**
 * Constructs a transform Node and two NDArrays over caller-owned buffers.
 *
 * The actual execution/comparison is currently disabled (see the commented
 * lines below).
 */
TEST_F(FlatBuffersTest, ExecutionTest1) {
    Node gA(OpType_TRANSFORM_SAME);

    // Each buffer is declared before the NDArray that wraps it. Locals are
    // destroyed in reverse order, so every array is gone before its buffer.
    // (The original freed `c` before deleting the array that referenced it.)
    float c[4] = {-1, -2, -3, -4};
    NDArray array(c, cShape);

    float e[4] = {1, 2, 3, 4};
    NDArray exp(e, cShape);

    //gA.execute(&array, nullptr, &array);

    //ASSERT_TRUE(exp.equalsTo(&array));
}
|
||
|
|
||
|
/*
|
||
|
TEST_F(FlatBuffersTest, ExplicitOutputTest1) {
|
||
|
flatbuffers::FlatBufferBuilder builder(4096);
|
||
|
|
||
|
auto x = NDArrayFactory::create_<float>(5, 5, 'c');
|
||
|
x->assign(-2.0f);
|
||
|
|
||
|
auto fXShape = builder.CreateVector(x->getShapeInfoAsVector());
|
||
|
auto fXBuffer = builder.CreateVector(x->asByteVector());
|
||
|
auto fXArray = CreateFlatArray(builder, fXShape, fXBuffer);
|
||
|
auto fXid = CreateIntPair(builder, -1);
|
||
|
|
||
|
auto fXVar = CreateFlatVariable(builder, fXid, 0, 0, fXArray);
|
||
|
|
||
|
|
||
|
auto y = NDArrayFactory::create_<float>(5, 5, 'c');
|
||
|
y->assign(-1.0f);
|
||
|
|
||
|
auto fYShape = builder.CreateVector(y->getShapeInfoAsVector());
|
||
|
auto fYBuffer = builder.CreateVector(y->asByteVector());
|
||
|
auto fYArray = CreateFlatArray(builder, fYShape, fYBuffer);
|
||
|
auto fYid = CreateIntPair(builder, -2);
|
||
|
|
||
|
auto fYVar = CreateFlatVariable(builder, fYid, 0, 0, fYArray);
|
||
|
|
||
|
|
||
|
std::vector<flatbuffers::Offset<IntPair>> inputs1, outputs1, outputs;
|
||
|
inputs1.push_back(CreateIntPair(builder, -1));
|
||
|
inputs1.push_back(CreateIntPair(builder, -2));
|
||
|
|
||
|
outputs.push_back(CreateIntPair(builder, -1));
|
||
|
outputs.push_back(CreateIntPair(builder, -2));
|
||
|
|
||
|
auto out1 = builder.CreateVector(outputs1);
|
||
|
auto in1 = builder.CreateVector(inputs1);
|
||
|
auto o = builder.CreateVector(outputs);
|
||
|
|
||
|
auto name1 = builder.CreateString("wow1");
|
||
|
|
||
|
auto node1 = CreateFlatNode(builder, 1, name1, OpType_TRANSFORM, 0, in1, 0, nd4j::graph::DataType::FLOAT);
|
||
|
|
||
|
std::vector<flatbuffers::Offset<FlatVariable>> variables_vector;
|
||
|
variables_vector.push_back(fXVar);
|
||
|
variables_vector.push_back(fYVar);
|
||
|
|
||
|
std::vector<flatbuffers::Offset<FlatNode>> nodes_vector;
|
||
|
nodes_vector.push_back(node1);
|
||
|
|
||
|
|
||
|
|
||
|
auto nodes = builder.CreateVector(nodes_vector);
|
||
|
auto variables = builder.CreateVector(variables_vector);
|
||
|
|
||
|
FlatGraphBuilder graphBuilder(builder);
|
||
|
|
||
|
graphBuilder.add_variables(variables);
|
||
|
graphBuilder.add_id(119);
|
||
|
graphBuilder.add_nodes(nodes);
|
||
|
graphBuilder.add_outputs(o);
|
||
|
|
||
|
|
||
|
auto flatGraph = graphBuilder.Finish();
|
||
|
builder.Finish(flatGraph);
|
||
|
|
||
|
auto restoredGraph = new Graph<float>(GetFlatGraph(builder.GetBufferPointer()));
|
||
|
|
||
|
GraphExecutioner<float>::execute(restoredGraph);
|
||
|
|
||
|
auto results = restoredGraph->fetchOutputs();
|
||
|
|
||
|
// IMPLICIT is default
|
||
|
ASSERT_EQ(1, results->size());
|
||
|
|
||
|
//ASSERT_NEAR(-2.0, results->at(0)->getNDArray()->reduceNumber<simdOps::Mean<float>>(), 1e-5);
|
||
|
//ASSERT_NEAR(-1.0, results->at(1)->getNDArray()->reduceNumber<simdOps::Mean<float>>(), 1e-5);
|
||
|
ASSERT_NEAR(-3.0, results->at(0)->getNDArray()->reduceNumber<simdOps::Mean<float>>(), 1e-5);
|
||
|
|
||
|
//ASSERT_EQ(-1, results->at(0)->id());
|
||
|
//ASSERT_EQ(-2, results->at(1)->id());
|
||
|
|
||
|
delete restoredGraph;
|
||
|
delete results;
|
||
|
delete x;
|
||
|
delete y;
|
||
|
}
|
||
|
*/
|
||
|
|
||
|
/*
|
||
|
TEST_F(FlatBuffersTest, ReadFile1) {
|
||
|
|
||
|
uint8_t* data = nd4j::graph::readFlatBuffers("./resources/adam_sum.fb");
|
||
|
|
||
|
auto fg = GetFlatGraph(data);
|
||
|
auto restoredGraph = new Graph<float>(fg);
|
||
|
|
||
|
ASSERT_EQ(1, restoredGraph->rootNodes());
|
||
|
ASSERT_EQ(2, restoredGraph->totalNodes());
|
||
|
|
||
|
auto ones = restoredGraph->getVariableSpace()->getVariable(-1)->getNDArray();
|
||
|
|
||
|
ASSERT_EQ(4, ones->lengthOf());
|
||
|
ASSERT_NEAR(4.0f, ones->template reduceNumber<simdOps::Sum<float>>(), 1e-5);
|
||
|
|
||
|
Nd4jStatus status = GraphExecutioner<float>::execute(restoredGraph);
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||
|
|
||
|
auto result = restoredGraph->getVariableSpace()->getVariable(2)->getNDArray();
|
||
|
ASSERT_EQ(1, result->lengthOf());
|
||
|
ASSERT_EQ(8, result->e(0));
|
||
|
|
||
|
delete[] data;
|
||
|
delete restoredGraph;
|
||
|
}
|
||
|
|
||
|
TEST_F(FlatBuffersTest, ReadFile2) {
|
||
|
uint8_t* data = nd4j::graph::readFlatBuffers("./resources/adam_sum.fb");
|
||
|
Nd4jPointer result = GraphExecutioner<float>::executeFlatBuffer((Nd4jPointer) data);
|
||
|
|
||
|
ResultSet<float> arrays(GetFlatResult(result));
|
||
|
|
||
|
ASSERT_EQ(1, arrays.size());
|
||
|
ASSERT_EQ(1, arrays.at(0)->lengthOf());
|
||
|
ASSERT_EQ(8, arrays.at(0)->e(0));
|
||
|
|
||
|
delete[] data;
|
||
|
delete[] (char *) result;
|
||
|
}
|
||
|
|
||
|
TEST_F(FlatBuffersTest, ReadFile3) {
|
||
|
auto graph = GraphExecutioner<float>::importFromFlatBuffers("./resources/adam_sum.fb");
|
||
|
Nd4jStatus status = GraphExecutioner<float>::execute(graph);
|
||
|
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||
|
|
||
|
auto z = graph->getVariableSpace()->getVariable(2)->getNDArray();
|
||
|
|
||
|
ASSERT_EQ(1, z->lengthOf());
|
||
|
ASSERT_EQ(8, z->e(0));
|
||
|
|
||
|
delete graph;
|
||
|
}
|
||
|
|
||
|
|
||
|
TEST_F(FlatBuffersTest, ReadInception1) {
|
||
|
auto graph = GraphExecutioner<float>::importFromFlatBuffers("./resources/inception.fb");
|
||
|
|
||
|
Nd4jStatus status = GraphExecutioner<float>::execute(graph);
|
||
|
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||
|
ASSERT_TRUE(graph->getVariableSpace()->hasVariable(227));
|
||
|
|
||
|
auto lastNode = graph->getVariableSpace()->getVariable(227)->getNDArray();
|
||
|
|
||
|
lastNode->printShapeInfo("Result shape");
|
||
|
|
||
|
auto argMax = lastNode->argMax();
|
||
|
|
||
|
//nd4j_printf("Predicted class: %i\n", (int) argMax);
|
||
|
//nd4j_printf("Probability: %f\n", lastNode->e(argMax));
|
||
|
//nd4j_printf("Probability ipod: %f\n", lastNode->e(980));
|
||
|
//lastNode->printBuffer("Whole output");
|
||
|
|
||
|
ASSERT_EQ(561, (int) argMax);
|
||
|
|
||
|
delete graph;
|
||
|
}
|
||
|
|
||
|
TEST_F(FlatBuffersTest, ReadLoops_3argsWhile_1) {
|
||
|
// TF graph:
|
||
|
// https://gist.github.com/raver119/b86ef727e9a094aab386e2b35e878966
|
||
|
auto graph = GraphExecutioner<float>::importFromFlatBuffers("./resources/three_args_while.fb");
|
||
|
|
||
|
ASSERT_TRUE(graph != nullptr);
|
||
|
|
||
|
//graph->printOut();
|
||
|
|
||
|
auto expPhi('c', {2, 2});
|
||
|
|
||
|
ASSERT_TRUE(graph->getVariableSpace()->hasVariable(-1));
|
||
|
ASSERT_TRUE(graph->getVariableSpace()->hasVariable(-2));
|
||
|
|
||
|
auto phi = graph->getVariableSpace()->getVariable(-2)->getNDArray();
|
||
|
auto constA = graph->getVariableSpace()->getVariable(-5)->getNDArray();
|
||
|
auto lessY = graph->getVariableSpace()->getVariable(-6)->getNDArray();
|
||
|
|
||
|
//constA->printBuffer("constA");
|
||
|
//lessY->printBuffer("lessY");
|
||
|
|
||
|
ASSERT_TRUE(expPhi.isSameShape(phi));
|
||
|
|
||
|
Nd4jStatus status = GraphExecutioner<float>::execute(graph);
|
||
|
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||
|
|
||
|
// now, we expect some values
|
||
|
|
||
|
auto x = graph->getVariableSpace()->getVariable(20);
|
||
|
auto y = graph->getVariableSpace()->getVariable(21);
|
||
|
|
||
|
ASSERT_NEAR(110.0f, x->getNDArray()->meanNumber(), 1e-5);
|
||
|
ASSERT_NEAR(33.0f, y->getNDArray()->meanNumber(), 1e-5);
|
||
|
|
||
|
delete graph;
|
||
|
}
|
||
|
|
||
|
|
||
|
|
||
|
TEST_F(FlatBuffersTest, ReadTensorArrayLoop_1) {
|
||
|
auto exp('c', {5, 2}, {3., 6., 9., 12., 15., 18., 21., 24., 27., 30.});
|
||
|
auto graph = GraphExecutioner<float>::importFromFlatBuffers("./resources/tensor_array_loop.fb");
|
||
|
|
||
|
ASSERT_TRUE(graph != nullptr);
|
||
|
|
||
|
//graph->printOut();
|
||
|
|
||
|
Nd4jStatus status = GraphExecutioner<float>::execute(graph);
|
||
|
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||
|
|
||
|
auto variableSpace = graph->getVariableSpace();
|
||
|
|
||
|
ASSERT_TRUE(variableSpace->hasVariable(23,0));
|
||
|
|
||
|
auto z = variableSpace->getVariable(23)->getNDArray();
|
||
|
|
||
|
//z->printShapeInfo("z shape");
|
||
|
//z->printIndexedBuffer("z buffer");
|
||
|
|
||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||
|
|
||
|
delete graph;
|
||
|
}
|
||
|
|
||
|
*/
|
||
|
|
||
|
/*
|
||
|
TEST_F(FlatBuffersTest, ReadLoops_NestedWhile_1) {
|
||
|
// TF graph:
|
||
|
// https://gist.github.com/raver119/2aa49daf7ec09ed4ddddbc6262f213a0
|
||
|
nd4j::ops::assign<float> op1;
|
||
|
|
||
|
auto graph = GraphExecutioner<float>::importFromFlatBuffers("./resources/nested_while.fb");
|
||
|
|
||
|
ASSERT_TRUE(graph != nullptr);
|
||
|
|
||
|
Nd4jStatus status = GraphExecutioner<float>::execute(graph);
|
||
|
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||
|
|
||
|
auto x = graph->getVariableSpace()->getVariable(28);
|
||
|
auto y = graph->getVariableSpace()->getVariable(29);
|
||
|
auto z = graph->getVariableSpace()->getVariable(11, 2);
|
||
|
|
||
|
ASSERT_NEAR(110.0f, x->getNDArray()->meanNumber(), 1e-5);
|
||
|
ASSERT_NEAR(33.0f, y->getNDArray()->meanNumber(), 1e-5);
|
||
|
|
||
|
// we should have only 3 cycles in nested loop
|
||
|
ASSERT_NEAR(30.0f, z->getNDArray()->meanNumber(), 1e-5);
|
||
|
|
||
|
delete graph;
|
||
|
}
|
||
|
*/
|
||
|
/*
|
||
|
|
||
|
TEST_F(FlatBuffersTest, ReadTensorArray_1) {
|
||
|
// TF graph: https://gist.github.com/raver119/3265923eed48feecc465d17ec842b6e2
|
||
|
|
||
|
auto exp('c', {3, 2}, {1.000000, 1.000000, 2.000000, 2.000000, 3.000000, 3.000000});
|
||
|
|
||
|
auto graph = GraphExecutioner<float>::importFromFlatBuffers("./resources/tensor_array.fb");
|
||
|
|
||
|
ASSERT_TRUE(graph != nullptr);
|
||
|
|
||
|
Nd4jStatus status = GraphExecutioner<float>::execute(graph);
|
||
|
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||
|
|
||
|
ASSERT_TRUE(graph->getVariableSpace()->hasVariable(14));
|
||
|
|
||
|
auto z = graph->getVariableSpace()->getVariable(14)->getNDArray();
|
||
|
|
||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||
|
|
||
|
delete graph;
|
||
|
}
|
||
|
|
||
|
*/
|
||
|
/*
|
||
|
TEST_F(FlatBuffersTest, ReadStridedSlice_1) {
|
||
|
// TF graph: https://gist.github.com/raver119/fc3bf2d31c91e465c635b24020fd798d
|
||
|
auto graph = GraphExecutioner<float>::importFromFlatBuffers("./resources/tensor_slice.fb");
|
||
|
|
||
|
ASSERT_TRUE(graph != nullptr);
|
||
|
|
||
|
Nd4jStatus status = GraphExecutioner<float>::execute(graph);
|
||
|
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||
|
|
||
|
ASSERT_TRUE(graph->getVariableSpace()->hasVariable(7));
|
||
|
|
||
|
auto z = graph->getVariableSpace()->getVariable(7)->getNDArray();
|
||
|
|
||
|
ASSERT_NEAR(73.5f, z->e(0), 1e-5);
|
||
|
|
||
|
delete graph;
|
||
|
}
|
||
|
|
||
|
|
||
|
|
||
|
TEST_F(FlatBuffersTest, ReduceDim_1) {
|
||
|
auto exp = NDArrayFactory::create<float>('c', {3});
|
||
|
exp.assign(3.0);
|
||
|
|
||
|
|
||
|
auto graph = GraphExecutioner::importFromFlatBuffers("./resources/reduce_dim_false.fb");
|
||
|
|
||
|
graph->printOut();
|
||
|
|
||
|
auto variableSpace = graph->getVariableSpace();
|
||
|
|
||
|
|
||
|
ASSERT_TRUE(variableSpace->hasVariable(1));
|
||
|
ASSERT_TRUE(variableSpace->hasVariable(2));
|
||
|
|
||
|
auto x = variableSpace->getVariable(1)->getNDArray();
|
||
|
auto y = variableSpace->getVariable(2)->getNDArray();
|
||
|
|
||
|
Nd4jStatus status = GraphExecutioner::execute(graph);
|
||
|
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||
|
|
||
|
ASSERT_TRUE(variableSpace->hasVariable(3));
|
||
|
|
||
|
auto result = variableSpace->getVariable(3)->getNDArray();
|
||
|
|
||
|
result->printShapeInfo("z");
|
||
|
ASSERT_TRUE(exp.isSameShape(result));
|
||
|
ASSERT_TRUE(exp.equalsTo(result));
|
||
|
|
||
|
delete graph;
|
||
|
}
|
||
|
|
||
|
TEST_F(FlatBuffersTest, ReduceDim_2) {
|
||
|
auto exp = NDArrayFactory::create<float>('c', {3, 1});
|
||
|
exp.assign(3.0);
|
||
|
|
||
|
|
||
|
auto graph = GraphExecutioner::importFromFlatBuffers("./resources/reduce_dim_true.fb");
|
||
|
|
||
|
graph->printOut();
|
||
|
|
||
|
auto variableSpace = graph->getVariableSpace();
|
||
|
|
||
|
|
||
|
ASSERT_TRUE(variableSpace->hasVariable(1));
|
||
|
ASSERT_TRUE(variableSpace->hasVariable(2));
|
||
|
|
||
|
auto x = variableSpace->getVariable(1)->getNDArray();
|
||
|
auto y = variableSpace->getVariable(2)->getNDArray();
|
||
|
|
||
|
Nd4jStatus status = GraphExecutioner::execute(graph);
|
||
|
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||
|
|
||
|
ASSERT_TRUE(variableSpace->hasVariable(3));
|
||
|
|
||
|
auto result = variableSpace->getVariable(3)->getNDArray();
|
||
|
|
||
|
ASSERT_TRUE(exp.isSameShape(result));
|
||
|
ASSERT_TRUE(exp.equalsTo(result));
|
||
|
|
||
|
delete graph;
|
||
|
}
|
||
|
*/
|
||
|
|
||
|
#ifdef GRAPH_FILES_OK
|
||
|
/**
 * Imports ./resources/ae_00.fb and executes it. The output mode must be
 * VARIABLE_SPACE, and variable 18 must match the expected 5x4 values.
 */
TEST_F(FlatBuffersTest, Ae_00) {
    // Unused locally; kept from the original. Presumably instantiated so the
    // op is registered/linked — TODO confirm.
    nd4j::ops::rank op1;

    // unique_ptr releases the graph even when an ASSERT_* returns early.
    std::unique_ptr<Graph> graph(GraphExecutioner::importFromFlatBuffers("./resources/ae_00.fb"));
    ASSERT_TRUE(graph != nullptr);

    auto exp = NDArrayFactory::create<float>('c', {5, 4}, {0.32454616f, -0.06604697f, 0.22593613f, 0.43166467f, -0.18320604f, 0.00102305f, -0.06963076f, 0.25266643f, 0.07568010f, -0.03009197f, 0.07805517f, 0.33180334f, -0.06220427f, 0.07249600f, -0.06726961f, -0.22998397f, -0.06343779f, 0.07384885f, -0.06891008f, -0.23745790f});

    // graph->printOut();

    ASSERT_EQ(OutputMode_VARIABLE_SPACE, graph->getExecutorConfiguration()->_outputMode);

    auto result = GraphExecutioner::execute(graph.get());
    ASSERT_EQ(ND4J_STATUS_OK, result);

    ASSERT_TRUE(graph->getVariableSpace()->hasVariable(18));

    auto z = graph->getVariableSpace()->getVariable(18)->getNDArray();

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));
}
|
||
|
|
||
|
/**
 * Imports ./resources/expand_dim.fb, executes it, and checks that variable 5
 * equals the expected 3x1x4 array.
 */
TEST_F(FlatBuffersTest, expand_dims) {
    // Unused locally; kept from the original — presumably forces op registration; TODO confirm.
    nd4j::ops::rank op1;

    auto exp = NDArrayFactory::create<float>('c', {3, 1, 4}, {-0.95938617f, -1.20301781f, 1.22260064f, 0.50172403f, 0.59972949f, 0.78568028f, 0.31609724f, 1.51674747f, 0.68013491f, -0.05227458f, 0.25903158f, 1.13243439f});

    // unique_ptr releases the graph even when an ASSERT_* returns early.
    std::unique_ptr<Graph> graph(GraphExecutioner::importFromFlatBuffers("./resources/expand_dim.fb"));
    ASSERT_TRUE(graph != nullptr);

    // graph->printOut();

    auto result = GraphExecutioner::execute(graph.get());
    ASSERT_EQ(ND4J_STATUS_OK, result);
    ASSERT_TRUE(graph->getVariableSpace()->hasVariable(5));

    auto z = graph->getVariableSpace()->getVariable(5)->getNDArray();

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));
}
|
||
|
|
||
|
/**
 * Imports ./resources/transpose.fb and checks that it executes successfully.
 */
TEST_F(FlatBuffersTest, transpose) {
    // Unused locally; kept from the original — presumably forces op registration; TODO confirm.
    nd4j::ops::rank op1;

    // unique_ptr releases the graph even when an ASSERT_* returns early.
    std::unique_ptr<Graph> graph(GraphExecutioner::importFromFlatBuffers("./resources/transpose.fb"));
    ASSERT_TRUE(graph != nullptr);

    //graph->printOut();

    auto result = GraphExecutioner::execute(graph.get());
    ASSERT_EQ(ND4J_STATUS_OK, result);
}
|
||
|
|
||
|
/**
 * Imports ./resources/partition_stitch_misc.fb and checks that it executes
 * successfully.
 */
TEST_F(FlatBuffersTest, Test_Stitches) {
    // Unused locally; kept from the original — presumably forces op registration; TODO confirm.
    nd4j::ops::realdiv op0;

    // unique_ptr releases the graph even when an ASSERT_* returns early.
    std::unique_ptr<Graph> graph(GraphExecutioner::importFromFlatBuffers("./resources/partition_stitch_misc.fb"));
    ASSERT_TRUE(graph != nullptr);
    //graph->printOut();

    auto result = GraphExecutioner::execute(graph.get());
    ASSERT_EQ(ND4J_STATUS_OK, result);
}
|
||
|
|
||
|
/**
 * Imports ./resources/gru_dynamic_mnist.fb, executes it, and measures the
 * execution time. The timing is only reported via a commented-out debug print.
 */
TEST_F(FlatBuffersTest, Test_GruDynamicMnist) {
    nd4j::Environment::getInstance()->setDebug(false);
    nd4j::Environment::getInstance()->setVerbose(false);

    // unique_ptr releases the graph even when an ASSERT_* returns early.
    std::unique_ptr<Graph> graph(GraphExecutioner::importFromFlatBuffers("./resources/gru_dynamic_mnist.fb"));
    ASSERT_TRUE(graph != nullptr);
    //graph->printOut();

    // steady_clock is monotonic, so wall-clock adjustments cannot skew the
    // interval. The end timestamp is taken before asserting, so only
    // execute() is timed.
    auto timeStart = std::chrono::steady_clock::now();
    auto result = GraphExecutioner::execute(graph.get());
    auto timeEnd = std::chrono::steady_clock::now();

    ASSERT_EQ(ND4J_STATUS_OK, result);

    auto outerTime = std::chrono::duration_cast<std::chrono::microseconds>(timeEnd - timeStart).count();
    static_cast<void>(outerTime);   // only consumed by the debug print below

    // nd4j_printf("GRU time 1 time %lld us\n", outerTime);
}
|
||
|
|
||
|
/**
 * Imports ./resources/non2d_2.fb and checks that it executes successfully.
 */
TEST_F(FlatBuffersTest, Test_Non2D_2) {
    nd4j::Environment::getInstance()->setDebug(false);
    nd4j::Environment::getInstance()->setVerbose(false);

    // Unused locally; kept from the original — presumably forces op registration; TODO confirm.
    nd4j::ops::realdiv op0;

    // unique_ptr releases the graph even when an ASSERT_* returns early.
    std::unique_ptr<Graph> graph(GraphExecutioner::importFromFlatBuffers("./resources/non2d_2.fb"));
    ASSERT_TRUE(graph != nullptr);
    //graph->printOut();

    auto result = GraphExecutioner::execute(graph.get());
    ASSERT_EQ(ND4J_STATUS_OK, result);
}
|
||
|
|
||
|
|
||
|
TEST_F(FlatBuffersTest, Test_TensorDotMisc) {
|
||
|
Environment::getInstance()->setVerbose(false);
|
||
|
Environment::getInstance()->setDebug(false);
|
||
|
|
||
|
auto e = NDArrayFactory::create<float>('c', {1, 3, 16, 20}, {4.f, 6.f, 6.f, 5.f, 6.f, 4.f, 2.f, 3.f, 5.f, 5.f, 1.f, 4.f, 6.f, 3.f, 2.f, 1.f, 5.f, 4.f, 4.f, 4.f, 4.f, 4.f, 3.f, 4.f, 2.f, 3.f, 3.f, 5.f, 3.f, 6.f, 5.f, 4.f, 4.f, 3.f, 6.f, 1.f, 2.f, 4.f, 2.f, 6.f, 4.f, 2.f, 3.f, 2.f, 3.f, 1.f, 2.f, 4.f, 3.f, 5.f, 3.f, 3.f, 5.f, 2.f, 6.f, 3.f, 4.f, 4.f, 4.f, 4.f, 6.f, 4.f, 5.f, 2.f, 5.f, 5.f, 5.f, 5.f, 2.f, 4.f, 4.f, 4.f, 5.f, 4.f, 3.f, 6.f, 3.f, 4.f, 5.f, 2.f, 5.f, 4.f, 4.f, 5.f, 4.f, 3.f, 4.f, 5.f, 5.f, 3.f, 5.f, 6.f, 6.f, 3.f, 4.f, 5.f, 7.f, 6.f, 5.f, 2.f, 4.f, 5.f, 5.f, 4.f, 5.f, 4.f, 4.f, 6.f, 3.f, 4.f, 5.f, 4.f, 6.f, 2.f, 3.f, 4.f, 3.f, 3.f, 2.f, 2.f, 3.f, 4.f, 7.f, 3.f, 5.f, 4.f, 5.f, 4.f, 4.f, 4.f, 4.f, 6.f, 2.f, 3.f, 2.f, 5.f, 5.f, 4.f, 5.f, 2.f, 2.f, 1.f, 6.f, 2.f, 2.f, 3.f, 4.f, 5.f, 5.f, 3.f, 6.f, 6.f, 4.f, 3.f, 3.f, 3.f, 3.f, 3.f, 4.f, 5.f, 4.f, 4.f, 3.f, 5.f, 2.f, 3.f, 4.f, 5.f, 3.f, 4.f, 5.f, 5.f, 8.f, 4.f, 5.f, 3.f, 3.f, 4.f, 4.f, 5.f, 4.f, 5.f, 3.f, 3.f, 7.f, 2.f, 3.f, 2.f, 6.f, 6.f, 4.f, 4.f, 3.f, 5.f, 6.f, 2.f, 4.f, 3.f, 3.f, 4.f, 5.f, 3.f, 3.f, 6.f, 5.f, 3.f, 2.f, 5.f, 4.f, 4.f, 3.f, 5.f, 5.f, 6.f, 7.f, 3.f, 4.f, 3.f, 5.f, 6.f, 7.f, 5.f, 6.f, 5.f, 7.f, 4.f, 6.f, 5.f, 5.f, 6.f, 4.f, 2.f, 5.f, 4.f, 3.f, 4.f, 1.f, 5.f, 5.f, 3.f, 2.f, 2.f, 6.f, 5.f, 5.f, 2.f, 5.f, 2.f, 4.f, 4.f, 5.f, 5.f, 4.f, 3.f, 7.f, 4.f, 5.f, 3.f, 3.f, 3.f, 2.f, 3.f, 2.f, 3.f, 3.f, 4.f, 4.f, 2.f, 4.f, 5.f, 3.f, 4.f, 5.f, 3.f, 7.f, 2.f, 1.f, 3.f, 2.f, 3.f, 2.f, 3.f, 3.f, 4.f, 3.f, 4.f, 2.f, 4.f, 4.f, 4.f, 5.f, 3.f, 5.f, 3.f, 6.f, 6.f, 5.f, 3.f, 5.f, 3.f, 4.f, 3.f, 5.f, 3.f, 5.f, 6.f, 5.f, 3.f, 4.f, 5.f, 5.f, 3.f, 3.f, 3.f, 4.f, 6.f, 4.f, 3.f, 7.f, 4.f, 4.f, 6.f, 7.f, 5.f, 5.f, 3.f, 1.f, 2.f, 5.f, 5.f, 2.f, 5.f, 7.f, 5.f, 3.f, 1.f, 4.f, 6.f, 5.f, 7.f, 5.f, 6.f, 5.f, 6.f, 4.f, 3.f, 3.f, 4.f, 3.f, 4.f, 4.f, 4.f, 4.f, 3.f, 5.f, 2.f, 4.f, 5.f, 2.f, 5.f, 5.f, 4.f, 5.f, 4.f, 5.f, 2.f, 3.f, 5.f, 3.f, 6.f, 3.f, 4.f, 5.f, 3.f, 6.f, 5.f, 5.f, 6.f, 4.f, 6.f, 7.f, 4.f, 5.f, 3.f, 5.f, 4.f, 4.f, 
4.f, 2.f, 2.f, 5.f, 3.f, 5.f, 3.f, 4.f, 6.f, 3.f, 5.f, 5.f, 3.f, 5.f, 4.f, 4.f, 4.f, 5.f, 2.f, 3.f, 5.f, 4.f, 2.f, 4.f, 5.f, 4.f, 2.f, 3.f, 4.f, 4.f, 5.f, 5.f, 1.f, 4.f, 4.f, 4.f, 3.f, 4.f, 5.f, 5.f, 8.f, 4.f, 4.f, 4.f, 3.f, 6.f, 2.f, 3.f, 4.f, 4.f, 4.f, 3.f, 2.f, 3.f, 4.f, 8.f, 3.f, 5.f, 5.f, 5.f, 3.f, 3.f, 4.f, 5.f, 7.f, 3.f, 3.f, 3.f, 6.f, 6.f, 5.f, 5.f, 3.f, 4.f, 3.f, 8.f, 3.f, 4.f, 2.f, 3.f, 4.f, 4.f, 3.f, 5.f, 5.f, 3.f, 2.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, 6.f, 6.f, 5.f, 6.f, 4.f, 5.f, 4.f, 6.f, 4.f, 5.f, 5.f, 4.f, 7.f, 3.f, 5.f, 5.f, 3.f, 5.f, 5.f, 6.f, 4.f, 5.f, 4.f, 2.f, 7.f, 2.f, 3.f, 1.f, 4.f, 5.f, 5.f, 4.f, 4.f, 5.f, 7.f, 2.f, 3.f, 3.f, 4.f, 4.f, 5.f, 3.f, 3.f, 6.f, 6.f, 3.f, 2.f, 4.f, 3.f, 3.f, 3.f, 3.f, 4.f, 4.f, 5.f, 1.f, 2.f, 3.f, 3.f, 4.f, 5.f, 4.f, 5.f, 4.f, 5.f, 6.f, 6.f, 6.f, 6.f, 7.f, 4.f, 3.f, 4.f, 5.f, 4.f, 4.f, 2.f, 5.f, 6.f, 4.f, 2.f, 2.f, 6.f, 5.f, 5.f, 1.f, 4.f, 2.f, 3.f, 4.f, 5.f, 5.f, 4.f, 5.f, 9.f, 4.f, 6.f, 4.f, 5.f, 5.f, 3.f, 4.f, 5.f, 5.f, 5.f, 4.f, 3.f, 1.f, 3.f, 4.f, 3.f, 4.f, 4.f, 3.f, 6.f, 2.f, 3.f, 3.f, 2.f, 3.f, 3.f, 4.f, 5.f, 6.f, 5.f, 5.f, 3.f, 4.f, 5.f, 5.f, 4.f, 3.f, 4.f, 3.f, 6.f, 7.f, 6.f, 4.f, 6.f, 4.f, 3.f, 3.f, 4.f, 3.f, 5.f, 5.f, 4.f, 2.f, 3.f, 4.f, 5.f, 3.f, 4.f, 2.f, 4.f, 5.f, 3.f, 3.f, 7.f, 4.f, 2.f, 5.f, 6.f, 5.f, 5.f, 3.f, 1.f, 2.f, 4.f, 4.f, 1.f, 3.f, 6.f, 3.f, 3.f, 1.f, 4.f, 4.f, 4.f, 5.f, 3.f, 4.f, 3.f, 4.f, 2.f, 3.f, 3.f, 4.f, 3.f, 4.f, 3.f, 3.f, 4.f, 2.f, 5.f, 1.f, 3.f, 4.f, 2.f, 6.f, 4.f, 3.f, 4.f, 3.f, 3.f, 1.f, 2.f, 5.f, 2.f, 6.f, 4.f, 5.f, 6.f, 3.f, 6.f, 4.f, 4.f, 5.f, 3.f, 5.f, 6.f, 3.f, 4.f, 2.f, 4.f, 5.f, 5.f, 5.f, 2.f, 3.f, 4.f, 3.f, 5.f, 3.f, 3.f, 9.f, 6.f, 7.f, 7.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 3.f, 4.f, 6.f, 5.f, 3.f, 5.f, 5.f, 5.f, 2.f, 4.f, 6.f, 7.f, 7.f, 5.f, 3.f, 4.f, 5.f, 4.f, 4.f, 5.f, 5.f, 5.f, 8.f, 4.f, 4.f, 4.f, 3.f, 5.f, 3.f, 3.f, 4.f, 4.f, 5.f, 3.f, 3.f, 2.f, 3.f, 6.f, 2.f, 5.f, 4.f, 4.f, 3.f, 3.f, 3.f, 5.f, 7.f, 2.f, 3.f, 2.f, 5.f, 5.f, 4.f, 4.f, 2.f, 2.f, 1.f, 6.f, 1.f, 2.f, 2.f, 3.f, 
5.f, 4.f, 3.f, 5.f, 5.f, 3.f, 2.f, 2.f, 2.f, 2.f, 4.f, 3.f, 4.f, 4.f, 4.f, 4.f, 5.f, 2.f, 4.f,
|
||
|
|
||
|
auto graph = GraphExecutioner::importFromFlatBuffers("./resources/tensor_dot_misc.fb");
|
||
|
// graph->printOut();
|
||
|
|
||
|
auto result = GraphExecutioner::execute(graph);
|
||
|
ASSERT_EQ(Status::OK(), result);
|
||
|
|
||
|
ASSERT_TRUE(graph->getVariableSpace()->hasVariable(77));
|
||
|
|
||
|
auto z = graph->getVariableSpace()->getVariable(77,0)->getNDArray();
|
||
|
|
||
|
ASSERT_EQ(e, *z);
|
||
|
|
||
|
delete graph;
|
||
|
}
|
||
|
|
||
|
|
||
|
TEST_F(FlatBuffersTest, Test_MNIST_00_1) {
|
||
|
auto e = NDArrayFactory::create<float>('c', {100, 10}, {0.00066107f, 0.00002358f, 0.00031518f, 0.00238039f, 0.00027216f, 0.00030300f, 0.00004659f, 0.98962247f, 0.00050380f, 0.00587174f, 0.05895791f, 0.00323104f, 0.52636790f, 0.12912551f, 0.00003951f, 0.03615341f, 0.22013727f, 0.00007333f, 0.02566659f, 0.00024759f, 0.00192367f, 0.90509874f, 0.01985082f, 0.02080356f, 0.00260053f, 0.00497826f, 0.01107823f, 0.00872595f, 0.01559795f, 0.00934229f, 0.98202229f, 0.00000150f, 0.00137381f, 0.00082931f, 0.00001806f, 0.00384426f, 0.00758274f, 0.00305049f, 0.00052152f, 0.00075617f, 0.01094264f, 0.00044708f, 0.03576852f, 0.00711267f, 0.65963465f, 0.00734364f, 0.02747800f, 0.06494589f, 0.02966754f, 0.15665947f, 0.00035806f, 0.95196360f, 0.00622721f, 0.01610696f, 0.00084180f, 0.00139947f, 0.00127350f, 0.00577912f, 0.00980321f, 0.00624705f, 0.00167418f, 0.00125611f, 0.00109477f, 0.04061511f, 0.57403159f, 0.08173440f, 0.00423709f, 0.10187119f, 0.07103974f, 0.12244581f, 0.00073566f, 0.00624759f, 0.00559816f, 0.01215601f, 0.08299568f, 0.06209232f, 0.01742392f, 0.01341172f, 0.02181461f, 0.77752429f, 0.08474547f, 0.00957346f, 0.29235491f, 0.00243696f, 0.06653537f, 0.03792902f, 0.43910959f, 0.00344940f, 0.02626713f, 0.03759870f, 0.00143713f, 0.00011047f, 0.00018820f, 0.00047970f, 0.02127167f, 0.00308758f, 0.00093357f, 0.17067374f, 0.00545499f, 0.79636300f, 0.95257199f, 0.00002157f, 0.00647615f, 0.01024892f, 0.00005942f, 0.01910058f, 0.00044579f, 0.00008416f, 0.01097712f, 0.00001441f, 0.16705236f, 0.01782482f, 0.17580827f, 0.06262068f, 0.03860324f, 0.01763505f, 0.32766294f, 0.00555595f, 0.17227779f, 0.01495883f, 0.00180449f, 0.00010494f, 0.00075124f, 0.00161161f, 0.08859238f, 0.00364861f, 0.00162414f, 0.06005199f, 0.00805061f, 0.83375996f, 0.97355360f, 0.00000305f, 0.00144336f, 0.00051544f, 0.00010043f, 0.00714774f, 0.00021183f, 0.00042562f, 0.01294680f, 0.00365222f, 0.00026871f, 0.95752406f, 0.00408361f, 0.02153200f, 0.00015639f, 0.00153930f, 0.00323335f, 0.00178700f, 0.00516464f, 
0.00471107f, 0.07408376f, 0.00468759f, 0.02638813f, 0.33325842f, 0.01172767f, 0.36993489f, 0.01118315f, 0.01460529f, 0.14850292f, 0.00562817f, 0.00551083f, 0.00015134f, 0.01184739f, 0.00643833f, 0.11686873f, 0.00163741f, 0.00582776f, 0.11497385f, 0.0
|
||
|
auto graph = GraphExecutioner::importFromFlatBuffers("./resources/mnist_00.fb");
|
||
|
//graph->printOut();
|
||
|
|
||
|
auto result = GraphExecutioner::execute(graph);
|
||
|
ASSERT_EQ(Status::OK(), result);
|
||
|
|
||
|
ASSERT_TRUE(graph->getVariableSpace()->hasVariable(6));
|
||
|
|
||
|
auto z = graph->getVariableSpace()->getVariable(6,0)->getNDArray();
|
||
|
|
||
|
ASSERT_EQ(e, *z);
|
||
|
|
||
|
delete graph;
|
||
|
}
|
||
|
|
||
|
|
||
|
|
||
|
/**
 * Imports ./resources/mnist.fb and checks that it executes successfully.
 */
TEST_F(FlatBuffersTest, Test_MNIST_1) {
    // unique_ptr releases the graph even when an ASSERT_* returns early.
    std::unique_ptr<Graph> graph(GraphExecutioner::importFromFlatBuffers("./resources/mnist.fb"));
    ASSERT_TRUE(graph != nullptr);
    //graph->printOut();

    auto result = GraphExecutioner::execute(graph.get());
    ASSERT_EQ(Status::OK(), result);
}
|
||
|
|
||
|
/*
|
||
|
// FIXME: uncomment this test once conv_0 fb reexported
|
||
|
TEST_F(FlatBuffersTest, nhwc_conv_0) {
|
||
|
nd4j::ops::rank<float> op1;
|
||
|
|
||
|
auto exp('c', {4, 2}, {2.958640f, 0.602521f, 7.571267f, 1.496686f, -2.292647f, -1.791460f, 13.055838f, 4.278642f});
|
||
|
|
||
|
auto graph = GraphExecutioner<float>::importFromFlatBuffers("./resources/conv_0.fb");
|
||
|
|
||
|
graph->printOut();
|
||
|
|
||
|
auto result = GraphExecutioner<float>::execute(graph);
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, result);
|
||
|
|
||
|
ASSERT_TRUE(graph->getVariableSpace()->hasVariable(11));
|
||
|
|
||
|
auto z = graph->getVariableSpace()->getVariable(11)->getNDArray();
|
||
|
|
||
|
//z->printShapeInfo("z buffr");
|
||
|
//z->printIndexedBuffer("z shape");
|
||
|
|
||
|
// [[2.96, 0.60],
|
||
|
// [7.57, 1.50],
|
||
|
// [-2.29, -1.79],
|
||
|
// [13.06, 4.28]]
|
||
|
|
||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||
|
|
||
|
delete graph;
|
||
|
}
|
||
|
|
||
|
*/
|
||
|
|
||
|
|
||
|
/*
|
||
|
TEST_F(FlatBuffersTest, ReadLoops_SimpleWhile_1) {
|
||
|
// TF graph:
|
||
|
// https://gist.github.com/raver119/2aa49daf7ec09ed4ddddbc6262f213a0
|
||
|
auto graph = GraphExecutioner<float>::importFromFlatBuffers("./resources/simple_while.fb");
|
||
|
|
||
|
ASSERT_TRUE(graph != nullptr);
|
||
|
|
||
|
Nd4jStatus status = GraphExecutioner<float>::execute(graph);
|
||
|
|
||
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||
|
|
||
|
delete graph;
|
||
|
}
|
||
|
|
||
|
*/
|
||
|
#endif
|