/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

#ifndef LIBND4J_GRAPH_H
#define LIBND4J_GRAPH_H

#include <algorithm>
#include <atomic>
#include <list>
#include <map>
#include <mutex>
#include <unordered_map>
#include <vector>
//#include <NDArray.h>
#include <graph/Node.h>
#include <graph/Stash.h>
#include <graph/Scope.h>
#include <graph/Variable.h>
#include <graph/VariableSpace.h>
#include <graph/generated/node_generated.h>
#include <graph/generated/graph_generated.h>
#include <graph/generated/config_generated.h>
#include <graph/ExecutorConfiguration.h>
#include <ops/declarable/OpDescriptor.h>

namespace sd {
    namespace graph {

        /**
         * Graph of Nodes together with the state declared alongside it:
         * a VariableSpace, an ExecutorConfiguration, a Stash, a set of Scopes,
         * and a layered ("onion") view of the nodes.
         *
         * Only declarations are visible in this header (except for the small inline
         * accessors and pullState()); the remaining behaviour lives in the implementation file.
         */
        class ND4J_EXPORT Graph {
        protected:
            ExecutorConfiguration *_configuration;
            VariableSpace *_variableSpace;
            Stash* _stash;

            // this list holds references to Node ptrs, which should be free'd in Graph destructor
            std::vector<Node*> _handles;

            // vector holds ID's of top nodes only
            std::vector<int > *_nodes;

            // node id -> Node (see getMapped())
            MAP_IMPL<int, sd::graph::Node*> *_mapped;

            // layer index -> nodes of that layer (see getOnion())
            MAP_IMPL<int, std::vector<sd::graph::Node*> *> *_onion;
            MAP_IMPL<int, sd::graph::Node*> _unmapped;
            std::vector<int> _unmappedMap; // macOS?

            std::mutex _mutexPreprocessing;

            // flag exposed through built(); copied from the source graph in pullState()
            std::atomic<bool> _built;

            std::vector<int> _output;
            std::vector<int> _autos;


            // scope id -> Scope (see scopes()); _scopes keeps the same Scope pointers as a flat list
            MAP_IMPL<int, Scope*> _mappedScopes;
            std::vector<Scope*> _scopes;

////////////////////////////////////////
            Nd4jStatus validateNode(sd::graph::Node *node);

            void expandOnion(int newLayer);

            void injectNode(sd::graph::Node *node);

            void pushToOutputOnce(int id);

            void printOutNode(Node* node);

            void prepareOutputs();

        public:
            Graph(const FlatGraph *flatGraph = nullptr, VariableSpace *variableSpace = nullptr);

            ~Graph();

            // this method applies toposort to nodes
            void toposortNodes();

            // method that'll print out graph
            // NOTE(review): the comment above looks copy-pasted from printOut(); the name and
            // return type suggest this validates the graph - confirm against the implementation
            Nd4jStatus validate();

            // this method will build structured representation of graph
            Nd4jStatus buildGraph();

            // this method will return estimated memory size (in bytes) required for 1 full graph execution round
            Nd4jLong estimateRequiredMemory();

            // this method returns number of root nodes in this graph
            int rootNodes();

            // this method returns total number of nodes in this graph
            int totalNodes();

            int numberOfPlaceholders();

            std::vector<sd::graph::Variable*>* getPlaceholders();

            /**
             * This method returns pointer to thread_local VariableSpace
             * @return
             */
            sd::graph::VariableSpace *getVariableSpace();

            /**
             * This method adds given node to the graph
             *
             * @param node
             */
            void addNode(sd::graph::Node *node);

            /**
             * This method returns layered representation of the graph
             *
             * @return
             */
            MAP_IMPL<int, std::vector<sd::graph::Node*> *> *getOnion();

            /**
             * This method returns map of all nodes of the graph
             * @return
             */
            MAP_IMPL<int, sd::graph::Node*>* getMapped();

            /**
             * This method returns outputs of this graph
             * @return
             */
            std::vector<sd::graph::Variable*> *fetchOutputs();

            /**
             * This method returns pointer to ExecutorConfiguration
             *
             * @return
             */
            ExecutorConfiguration *getExecutorConfiguration();

            /**
             * This method adds specified node (by ID) to de
             * NOTE(review): the sentence above is truncated in the original source; presumably
             * the node is added to the graph outputs - confirm against the implementation
             * @param id
             */
            void addOutput(int id);

            /**
             * This method returns all nodes at once (order is NOT guaranteed)
             * @return
             */
            std::vector<sd::graph::Node*> *getAllNodes();

            /**
             * This method prints out Graph op-by-op, and respective inputs
             */
            void printOut();

            /**
             * This method collect all ops from the graph into ops vector
             */
            std::vector<sd::ops::OpDescriptor> getOperations();

            /**
             * This method returns Scope ptr specified with id
             *
             * @param id
             * @return
             */
            Scope* scopeById(int id);

            /**
             * This method returns TRUE if specified ID refers to Scope, and false otherwise
             * @param id
             * @return
             */
            bool hasScope(int id);

            /**
             * This method returns clone of the graph
             */
            Graph* clone();

            /**
             * This method returns clone of the graph, backed by VariableProxy instead of VariableSpace
             */
            Graph* cloneWithProxy();

            /**
             * This method removes reference to VariableSpace from this Graph
             */
            void forgetVariableSpace();

            /**
             * This method returns Node with given Id
             */
            Node* nodeById(int nodeId);

            /**
             * This method returns True if node with given ID exists, False otherwise
             * @param nodeId
             * @return
             */
            bool hasNode(int nodeId);

            /**
             * This method returns hash of given Graph instance
             */
            Nd4jLong hashCode();

            /**
             * PLEASE NOTE: This method will be moved to private section
             */
            void tagInplaceNodes();

            void replaceState(VariableSpace *state, ExecutorConfiguration *configuration);

            /**
             * This method returns pointer to the vector of top node IDs
             * @return
             */
            FORCEINLINE std::vector<int>* nodes() {
                return _nodes;
            }

            /**
             * This method returns pointer to the internal _autos vector
             * @return
             */
            FORCEINLINE std::vector<int>* autos() {
                return &_autos;
            }

            /**
             * This method returns pointer to the internal _output vector
             * @return
             */
            FORCEINLINE std::vector<int>* output() {
                return &_output;
            }

            /**
             * This method returns pointer to the internal id -> Scope map
             * @return
             */
            FORCEINLINE MAP_IMPL<int, Scope*>* scopes() {
                return &_mappedScopes;
            }

            /**
             * This method returns current value of the _built flag
             * @return
             */
            FORCEINLINE bool built() {
                return _built.load();
            }

            /**
             * This method appends state of another Graph to this one:
             *  - node IDs, output IDs and auto IDs are appended as-is
             *  - every Scope of the other graph is cloned and registered in this graph
             *  - every Node of every onion layer is cloned; the clone is added to the matching
             *    layer of this graph (layer is created if missing), to the id -> Node map and
             *    to _handles
             *  - the _built flag is copied from the other graph
             *
             * Passing nullptr or this graph itself is a no-op: nullptr cannot be dereferenced,
             * and a self-pull would keep appending to the very containers being iterated.
             *
             * @param other graph to pull state from
             */
            FORCEINLINE void pullState(Graph *other) {
                if (other == nullptr || other == this)
                    return;

                // plain ID lists: bulk-append instead of indexing with a signed counter
                const auto otherNodes = other->nodes();
                this->_nodes->insert(this->_nodes->end(), otherNodes->begin(), otherNodes->end());

                const auto otherOutput = other->output();
                this->_output.insert(this->_output.end(), otherOutput->begin(), otherOutput->end());

                const auto otherAutos = other->autos();
                this->_autos.insert(this->_autos.end(), otherAutos->begin(), otherAutos->end());

                // scopes are cloned, so both graphs keep independent Scope objects
                for (const auto &entry: *other->scopes()) {
                    auto scopeCopy = entry.second->clone();
                    this->_mappedScopes[entry.first] = scopeCopy;
                    this->_scopes.emplace_back(scopeCopy);
                }

                for (const auto &layer: *other->getOnion()) {
                    // single lookup; a missing layer is stored in the onion right away, so the
                    // freshly allocated vector is not lost if cloning below throws
                    auto &slot = (*this->_onion)[layer.first];
                    if (slot == nullptr)
                        slot = new std::vector<Node*>();

                    auto layerNodes = slot;
                    for (auto node: *layer.second) {
                        auto nodeCopy = node->clone();
                        layerNodes->emplace_back(nodeCopy);
                        _handles.emplace_back(nodeCopy);
                        (*this->_mapped)[nodeCopy->id()] = nodeCopy;
                    }
                }

                this->_built.store(other->built());
            }
        };
    }
}

#endif //LIBND4J_GRAPH_H