From da39a63c9bcbb49fed5f6a6f02c2613e406293d3 Mon Sep 17 00:00:00 2001 From: raver119 Date: Tue, 18 Feb 2020 11:20:38 +0300 Subject: [PATCH] one more bert-like test Signed-off-by: raver119 --- .../layers_tests/PlaygroundTests.cpp | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp index 7cdf40c7f..93fb5d6b3 100644 --- a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp @@ -149,6 +149,56 @@ TEST_F(PlaygroundTests, test_bert_1) { delete graph; } +TEST_F(PlaygroundTests, test_bert_2) { + // this test will run ONLY if this model exists + if (nd4j::graph::getFileSize("/home/raver119/Downloads/Bert_minimal_model/bert_like_ops.fb") < 0) + return; + + auto graph = GraphExecutioner::importFromFlatBuffers("/home/raver119/Downloads/Bert_minimal_model/bert_like_ops.fb"); + + //graph->printOut(); + + graph->tagInplaceNodes(); + + +/* + // validating graph now + auto status = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), status); + ASSERT_TRUE(graph->getVariableSpace()->hasVariable(198)); + + auto array = graph->getVariableSpace()->getVariable(198)->getNDArray(); + ASSERT_EQ(z, *array); +*/ + + nd4j::Environment::getInstance()->setProfiling(true); + auto profile = GraphProfilingHelper::profile(graph, 1); + + profile->printOut(); + + nd4j::Environment::getInstance()->setProfiling(false); + delete profile; + +/* + std::vector values; + + for (int e = 0; e < 1; e++) { + auto timeStart = std::chrono::system_clock::now(); + + GraphExecutioner::execute(graph); + + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); + values.emplace_back(outerTime); + } + + std::sort(values.begin(), values.end()); + + nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); +*/ + delete graph; +} + TEST_F(PlaygroundTests, test_one_off_ops_1) { auto x = NDArrayFactory::create('c', {4, 128, 768}); auto y = NDArrayFactory::create('c', {4, 128, 1});