one more bert-like test

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2020-02-18 11:20:38 +03:00
parent 22c7aa9acf
commit da39a63c9b
1 changed files with 50 additions and 0 deletions

View File

@ -149,6 +149,56 @@ TEST_F(PlaygroundTests, test_bert_1) {
delete graph;
}
TEST_F(PlaygroundTests, test_bert_2) {
// this test will run ONLY if this model exists
if (nd4j::graph::getFileSize("/home/raver119/Downloads/Bert_minimal_model/bert_like_ops.fb") < 0)
return;
auto graph = GraphExecutioner::importFromFlatBuffers("/home/raver119/Downloads/Bert_minimal_model/bert_like_ops.fb");
//graph->printOut();
graph->tagInplaceNodes();
/*
// validating graph now
auto status = GraphExecutioner::execute(graph);
ASSERT_EQ(Status::OK(), status);
ASSERT_TRUE(graph->getVariableSpace()->hasVariable(198));
auto array = graph->getVariableSpace()->getVariable(198)->getNDArray();
ASSERT_EQ(z, *array);
*/
nd4j::Environment::getInstance()->setProfiling(true);
auto profile = GraphProfilingHelper::profile(graph, 1);
profile->printOut();
nd4j::Environment::getInstance()->setProfiling(false);
delete profile;
/*
std::vector<Nd4jLong> values;
for (int e = 0; e < 1; e++) {
auto timeStart = std::chrono::system_clock::now();
GraphExecutioner::execute(graph);
auto timeEnd = std::chrono::system_clock::now();
auto outerTime = std::chrono::duration_cast<std::chrono::microseconds>(timeEnd - timeStart).count();
values.emplace_back(outerTime);
}
std::sort(values.begin(), values.end());
nd4j_printf("Time: %lld us;\n", values[values.size() / 2]);
*/
delete graph;
}
TEST_F(PlaygroundTests, test_one_off_ops_1) {
auto x = NDArrayFactory::create<float>('c', {4, 128, 768});
auto y = NDArrayFactory::create<float>('c', {4, 128, 1});