190 lines
7.7 KiB
C++
190 lines
7.7 KiB
C++
/*******************************************************************************
|
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
*
|
|
* This program and the accompanying materials are made available under the
|
|
* terms of the Apache License, Version 2.0 which is available at
|
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
* License for the specific language governing permissions and limitations
|
|
* under the License.
|
|
*
|
|
* SPDX-License-Identifier: Apache-2.0
|
|
******************************************************************************/
|
|
|
|
//
|
|
// @author raver119@gmail.com
|
|
//
|
|
|
|
#include "GraphServer.h"
|
|
#include <graph/GraphHolder.h>
|
|
#include <graph/GraphExecutioner.h>
|
|
#include <graph/generated/result_generated.h>
|
|
#include <helpers/StringUtils.h>
|
|
#include <algorithm>
|
|
#include <stdexcept>
|
|
|
|
#include <graph/exceptions/unknown_graph_exception.h>
|
|
#include <graph/exceptions/graph_exists_exception.h>
|
|
#include <graph/exceptions/no_results_exception.h>
|
|
#include <graph/exceptions/graph_execution_exception.h>
|
|
|
|
|
|
|
|
namespace sd {
|
|
namespace graph {
|
|
grpc::Status GraphInferenceServerImpl::RegisterGraph( grpc::ServerContext *context, const flatbuffers::grpc::Message<FlatGraph> *request_msg, flatbuffers::grpc::Message<FlatResponse> *response_msg) {
|
|
auto flat_graph = request_msg->GetRoot();
|
|
|
|
try {
|
|
// building our graph
|
|
auto graph = new Graph<float>(flat_graph);
|
|
|
|
// single data type for now
|
|
GraphHolder::getInstance().registerGraph<float>(flat_graph->id(), graph);
|
|
|
|
// sending out OK response
|
|
auto response_offset = CreateFlatResponse(mb_, 0);
|
|
mb_.Finish(response_offset);
|
|
*response_msg = mb_.ReleaseMessage<FlatResponse>();
|
|
assert(response_msg->Verify());
|
|
|
|
return grpc::Status::OK;
|
|
} catch (std::runtime_error &e) {
|
|
grpc::string gmsg("Caught runtime_error exception");
|
|
return grpc::Status(grpc::StatusCode::UNKNOWN, gmsg);
|
|
}
|
|
}
|
|
|
|
grpc::Status GraphInferenceServerImpl::ReplaceGraph( grpc::ServerContext *context, const flatbuffers::grpc::Message<FlatGraph> *request_msg, flatbuffers::grpc::Message<FlatResponse> *response_msg) {
|
|
auto flat_graph = request_msg->GetRoot();
|
|
|
|
try {
|
|
// building our graph
|
|
auto graph = new Graph<float>(flat_graph);
|
|
|
|
// single data type for now
|
|
GraphHolder::getInstance().replaceGraph(flat_graph->id(), graph);
|
|
|
|
// sending out OK response
|
|
auto response_offset = CreateFlatResponse(mb_, 0);
|
|
mb_.Finish(response_offset);
|
|
*response_msg = mb_.ReleaseMessage<FlatResponse>();
|
|
assert(response_msg->Verify());
|
|
|
|
return grpc::Status::OK;
|
|
} catch (sd::graph::unknown_graph_exception &e) {
|
|
grpc::string gmsg(e.message());
|
|
return grpc::Status(grpc::StatusCode::NOT_FOUND, gmsg);
|
|
} catch (std::runtime_error &e) {
|
|
grpc::string gmsg("Caught runtime_error exception");
|
|
return grpc::Status(grpc::StatusCode::UNKNOWN, gmsg);
|
|
}
|
|
}
|
|
|
|
grpc::Status GraphInferenceServerImpl::ForgetGraph( grpc::ServerContext *context, const flatbuffers::grpc::Message<FlatDropRequest> *request_msg, flatbuffers::grpc::Message<FlatResponse> *response_msg) {
|
|
try {
|
|
|
|
// getting drop request
|
|
auto request = request_msg->GetRoot();
|
|
|
|
// dropping out graph (any datatype)
|
|
GraphHolder::getInstance().dropGraphAny(request->id());
|
|
|
|
// sending out OK response
|
|
auto response_offset = CreateFlatResponse(mb_, 0);
|
|
mb_.Finish(response_offset);
|
|
*response_msg = mb_.ReleaseMessage<FlatResponse>();
|
|
assert(response_msg->Verify());
|
|
|
|
return grpc::Status::OK;
|
|
} catch (sd::graph::unknown_graph_exception &e) {
|
|
grpc::string gmsg(e.message());
|
|
return grpc::Status(grpc::StatusCode::NOT_FOUND, gmsg);
|
|
}
|
|
}
|
|
|
|
grpc::Status GraphInferenceServerImpl::InferenceRequest( grpc::ServerContext *context, const flatbuffers::grpc::Message<FlatInferenceRequest> *request_msg, flatbuffers::grpc::Message<FlatResult> *response_msg) {
|
|
auto request = request_msg->GetRoot();
|
|
|
|
try {
|
|
// GraphHolder
|
|
auto response_offset = GraphHolder::getInstance().execute(request->id(), mb_, request);
|
|
|
|
mb_.Finish(response_offset);
|
|
*response_msg = mb_.ReleaseMessage<FlatResult>();
|
|
assert(response_msg->Verify());
|
|
|
|
return grpc::Status::OK;
|
|
} catch (sd::graph::no_results_exception &e) {
|
|
grpc::string gmsg(e.message());
|
|
return grpc::Status(grpc::StatusCode::INTERNAL, gmsg);
|
|
} catch (sd::graph::unknown_graph_exception &e) {
|
|
grpc::string gmsg(e.message());
|
|
return grpc::Status(grpc::StatusCode::NOT_FOUND, gmsg);
|
|
} catch (sd::graph::graph_execution_exception &e) {
|
|
grpc::string gmsg(e.message());
|
|
return grpc::Status(grpc::StatusCode::INTERNAL, gmsg);
|
|
} catch (std::runtime_error &e) {
|
|
grpc::string gmsg("Caught runtime_error exception");
|
|
return grpc::Status(grpc::StatusCode::UNKNOWN, gmsg);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
void RunServer(int port) {
|
|
assert(port > 0 && port < 65535);
|
|
|
|
std::string server_address("0.0.0.0:");
|
|
server_address += sd::StringUtils::valueToString<int>(port);
|
|
|
|
sd::graph::GraphInferenceServerImpl service;
|
|
auto registrator = sd::ops::OpRegistrator::getInstance();
|
|
|
|
grpc::ServerBuilder builder;
|
|
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
|
|
builder.RegisterService(&service);
|
|
std::unique_ptr<grpc::Server> server(builder.BuildAndStart());
|
|
std::cerr << "Server listening on: [" << server_address << "]; Number of operations: [" << registrator->numberOfOperations() << "]"<< std::endl;
|
|
|
|
server->Wait();
|
|
}
|
|
|
|
char* getCmdOption(char **begin, char **end, const std::string & option) {
|
|
auto itr = std::find(begin, end, option);
|
|
if (itr != end && ++itr != end)
|
|
return *itr;
|
|
|
|
return 0;
|
|
}
|
|
|
|
bool cmdOptionExists(char** begin, char** end, const std::string& option) {
|
|
return std::find(begin, end, option) != end;
|
|
}
|
|
|
|
int main(int argc, char *argv[]) {
|
|
/**
|
|
* basically we only care about few things here:
|
|
* 1) port number
|
|
* 2) if we should use gprc, json, or both
|
|
* 3) if there's any graph(s) provided at startup
|
|
*/
|
|
int port = 40123;
|
|
if(cmdOptionExists(argv, argv+argc, "-p")) {
|
|
auto sPort = getCmdOption(argv, argv + argc, "-p");
|
|
port = atoi(sPort);
|
|
}
|
|
|
|
if(cmdOptionExists(argv, argv+argc, "-f")) {
|
|
auto file = getCmdOption(argv, argv + argc, "-f");
|
|
auto graph = GraphExecutioner<float>::importFromFlatBuffers(file);
|
|
sd::graph::GraphHolder::getInstance().registerGraph<float>(0L, graph);
|
|
}
|
|
|
|
RunServer(port);
|
|
|
|
return 0;
|
|
} |