/* ******************************************************************************
 *
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 *  See the NOTICE file distributed with this work for additional
 *  information regarding copyright ownership.
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

#include "GraphServer.h"
#include <graph/GraphHolder.h>
#include <graph/GraphExecutioner.h>
#include <graph/generated/result_generated.h>
#include <helpers/StringUtils.h>
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>

#include <graph/exceptions/unknown_graph_exception.h>
#include <graph/exceptions/graph_exists_exception.h>
#include <graph/exceptions/no_results_exception.h>
#include <graph/exceptions/graph_execution_exception.h>



namespace sd {
    namespace graph {
            grpc::Status GraphInferenceServerImpl::RegisterGraph( grpc::ServerContext *context, const flatbuffers::grpc::Message<FlatGraph> *request_msg, flatbuffers::grpc::Message<FlatResponse> *response_msg) {
                auto flat_graph = request_msg->GetRoot();

                try {
                    // building our graph
                    auto graph = new Graph<float>(flat_graph);

                    // single data type for now
                    GraphHolder::getInstance().registerGraph<float>(flat_graph->id(), graph);

                    // sending out OK response
                    auto response_offset = CreateFlatResponse(mb_, 0);
                    mb_.Finish(response_offset);
                    *response_msg = mb_.ReleaseMessage<FlatResponse>();
                    assert(response_msg->Verify());

                    return grpc::Status::OK;
                } catch (std::runtime_error &e) {
                    grpc::string gmsg("Caught runtime_error exception");
                    return grpc::Status(grpc::StatusCode::UNKNOWN, gmsg);
                }
            }

            grpc::Status GraphInferenceServerImpl::ReplaceGraph( grpc::ServerContext *context, const flatbuffers::grpc::Message<FlatGraph> *request_msg, flatbuffers::grpc::Message<FlatResponse> *response_msg) {
                auto flat_graph = request_msg->GetRoot();

                try {
                    // building our graph
                    auto graph = new Graph<float>(flat_graph);

                    // single data type for now
                    GraphHolder::getInstance().replaceGraph(flat_graph->id(), graph);

                    // sending out OK response
                    auto response_offset = CreateFlatResponse(mb_, 0);
                    mb_.Finish(response_offset);
                    *response_msg = mb_.ReleaseMessage<FlatResponse>();
                    assert(response_msg->Verify());

                    return grpc::Status::OK;
                } catch (sd::graph::unknown_graph_exception &e) {
                    grpc::string gmsg(e.message());
                    return grpc::Status(grpc::StatusCode::NOT_FOUND, gmsg);
                } catch (std::runtime_error &e) {
                    grpc::string gmsg("Caught runtime_error exception");
                    return grpc::Status(grpc::StatusCode::UNKNOWN, gmsg);
                }
            }

            grpc::Status GraphInferenceServerImpl::ForgetGraph( grpc::ServerContext *context, const flatbuffers::grpc::Message<FlatDropRequest> *request_msg, flatbuffers::grpc::Message<FlatResponse> *response_msg) {
                try {

                    // getting drop request
                    auto request = request_msg->GetRoot();

                    // dropping out graph (any datatype)
                    GraphHolder::getInstance().dropGraphAny(request->id());

                    // sending out OK response
                    auto response_offset = CreateFlatResponse(mb_, 0);
                    mb_.Finish(response_offset);
                    *response_msg = mb_.ReleaseMessage<FlatResponse>();
                    assert(response_msg->Verify());

                    return grpc::Status::OK;
                } catch (sd::graph::unknown_graph_exception &e) {
                    grpc::string gmsg(e.message());
                    return grpc::Status(grpc::StatusCode::NOT_FOUND, gmsg);
                }
            }

            grpc::Status GraphInferenceServerImpl::InferenceRequest( grpc::ServerContext *context, const flatbuffers::grpc::Message<FlatInferenceRequest> *request_msg, flatbuffers::grpc::Message<FlatResult> *response_msg) {
                auto request = request_msg->GetRoot();

                try {
                    // GraphHolder
                    auto response_offset = GraphHolder::getInstance().execute(request->id(), mb_, request);

                    mb_.Finish(response_offset);
                    *response_msg = mb_.ReleaseMessage<FlatResult>();
                    assert(response_msg->Verify());

                    return grpc::Status::OK;
                } catch (sd::graph::no_results_exception &e) {
                    grpc::string gmsg(e.message());
                    return grpc::Status(grpc::StatusCode::INTERNAL, gmsg);
                } catch (sd::graph::unknown_graph_exception &e) {
                    grpc::string gmsg(e.message());
                    return grpc::Status(grpc::StatusCode::NOT_FOUND, gmsg);
                } catch (sd::graph::graph_execution_exception &e) {
                    grpc::string gmsg(e.message());
                    return grpc::Status(grpc::StatusCode::INTERNAL, gmsg);
                } catch (std::runtime_error &e) {
                    grpc::string gmsg("Caught runtime_error exception");
                    return grpc::Status(grpc::StatusCode::UNKNOWN, gmsg);
                }
            }
    }
}

/**
 * Starts the gRPC graph inference server on 0.0.0.0:<port> and blocks until it shuts down.
 *
 * @param port TCP port to listen on; must be in [1, 65535]. An out-of-range port, or a
 *             server that fails to build/start (e.g. the port cannot be bound), is reported
 *             on stderr and the function returns without serving.
 */
void RunServer(int port) {
  // Runtime check instead of assert(): assert() disappears under NDEBUG, and the old
  // bound (port < 65535) wrongly rejected the valid port 65535.
  if (port <= 0 || port > 65535) {
    std::cerr << "Invalid port: [" << port << "]; expected a value in [1, 65535]" << std::endl;
    return;
  }

  std::string server_address("0.0.0.0:");
  server_address += sd::StringUtils::valueToString<int>(port);

  sd::graph::GraphInferenceServerImpl service;
  auto registrator = sd::ops::OpRegistrator::getInstance();

  grpc::ServerBuilder builder;
  builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
  builder.RegisterService(&service);
  std::unique_ptr<grpc::Server> server(builder.BuildAndStart());

  // BuildAndStart() yields nullptr when the server cannot be started; don't dereference it.
  if (!server) {
    std::cerr << "Failed to start server on: [" << server_address << "]" << std::endl;
    return;
  }

  std::cerr << "Server listening on: [" << server_address << "]; Number of operations: [" <<  registrator->numberOfOperations()  << "]"<< std::endl;

  server->Wait();
}

/**
 * Looks up the value that follows a command-line flag.
 *
 * @param begin  first element of the argument range (e.g. argv)
 * @param end    one past the last element (e.g. argv + argc)
 * @param option flag to search for, compared by exact string equality
 * @return the element right after the FIRST occurrence of option, or nullptr when the
 *         option is absent or is the last element of the range
 */
char* getCmdOption(char **begin, char **end, const std::string & option) {
    for (char **cursor = begin; cursor != end; ++cursor) {
        if (option != *cursor)
            continue;

        // only the first match counts: answer immediately, even if no value follows it
        char **value = cursor + 1;
        return value == end ? nullptr : *value;
    }

    return nullptr;
}

/**
 * Reports whether a flag appears anywhere in a command-line argument range.
 *
 * @param begin  first element of the argument range (e.g. argv)
 * @param end    one past the last element (e.g. argv + argc)
 * @param option flag to search for, compared by exact string equality
 * @return true if some element equals option
 */
bool cmdOptionExists(char** begin, char** end, const std::string& option) {
    auto matchesOption = [&option](const char *arg) { return option == arg; };
    return std::any_of(begin, end, matchesOption);
}

/**
 * Entry point of the graph inference server.
 *
 * Recognized options:
 *   -p <port>  TCP port to listen on (default 40123), must be in [1, 65535]
 *   -f <file>  FlatBuffers graph file to register under graph id 0 before serving
 *
 * @return 1 when an option is missing its value or has an invalid value, or when the
 *         graph given with -f cannot be loaded; otherwise 0 once RunServer() returns.
 */
int main(int argc, char *argv[]) {
    /**
     * basically we only care about a few things here:
     * 1) port number
     * 2) whether we should use gRPC, JSON, or both
     * 3) whether any graph(s) were provided at startup
     */
    int port = 40123;
    if (cmdOptionExists(argv, argv + argc, "-p")) {
        auto sPort = getCmdOption(argv, argv + argc, "-p");

        // getCmdOption() returns nullptr when "-p" is the last argument (atoi(nullptr) is UB),
        // and atoi() silently maps garbage to 0, so parse strictly with strtol instead.
        char *portEnd = nullptr;
        long parsed = sPort != nullptr ? std::strtol(sPort, &portEnd, 10) : 0;
        if (sPort == nullptr || portEnd == sPort || *portEnd != '\0' || parsed <= 0 || parsed > 65535) {
            std::cerr << "Invalid or missing value for -p; expected a port in [1, 65535]" << std::endl;
            return 1;
        }
        port = static_cast<int>(parsed);
    }

    if (cmdOptionExists(argv, argv + argc, "-f")) {
        auto file = getCmdOption(argv, argv + argc, "-f");
        if (file == nullptr) {
            std::cerr << "Missing value for -f; expected a path to a FlatBuffers graph file" << std::endl;
            return 1;
        }

        auto graph = sd::graph::GraphExecutioner<float>::importFromFlatBuffers(file);

        // defensive: never register a null graph — it would only fail later, at inference time
        if (graph == nullptr) {
            std::cerr << "Failed to load graph from: [" << file << "]" << std::endl;
            return 1;
        }

        sd::graph::GraphHolder::getInstance().registerGraph<float>(0L, graph);
    }

    RunServer(port);

    return 0;
}