/******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ // // @author raver119@gmail.com // #ifndef LIBND4J_GRAPH_H #define LIBND4J_GRAPH_H #include <list> #include <algorithm> #include <unordered_map> #include <map> //#include <NDArray.h> #include <graph/Node.h> #include <graph/Stash.h> #include <graph/Scope.h> #include <graph/Variable.h> #include <graph/VariableSpace.h> #include <graph/generated/node_generated.h> #include <graph/generated/graph_generated.h> #include <graph/generated/config_generated.h> #include <graph/ExecutorConfiguration.h> #include <ops/declarable/OpDescriptor.h> namespace sd { namespace graph { class ND4J_EXPORT Graph { protected: ExecutorConfiguration *_configuration; VariableSpace *_variableSpace; Stash* _stash; // this list holds references to Node ptrs, which should be free'd in Graph destructor std::vector<Node*> _handles; // vector holds ID's of top nodes only std::vector<int > *_nodes; MAP_IMPL<int, sd::graph::Node*> *_mapped; MAP_IMPL<int, std::vector<sd::graph::Node*> *> *_onion; MAP_IMPL<int, sd::graph::Node*> _unmapped; std::vector<int> _unmappedMap; // macOS? std::mutex _mutexPreprocessing; std::atomic<bool> _built; std::vector<int> _output; std::vector<int> _autos; MAP_IMPL<int, Scope*> _mappedScopes; std::vector<Scope*> _scopes; //////////////////////////////////////// Nd4jStatus validateNode(sd::graph::Node *node); void expandOnion(int newLayer); void injectNode(sd::graph::Node *node); void pushToOutputOnce(int id); void printOutNode(Node* node); void prepareOutputs(); public: Graph(const FlatGraph *flatGraph = nullptr, VariableSpace *variableSpace = nullptr); ~Graph(); // this method applies toposort to nodes void toposortNodes(); // method that'll print out graph Nd4jStatus validate(); // this method will build structured representation of graph Nd4jStatus buildGraph(); // this method will return estimated memory size (in bytes) required for 1 full graph execution round Nd4jLong estimateRequiredMemory(); // this method returns number of root nodes in this graph int rootNodes(); // this method returns total number of nodes in this graph int totalNodes(); int numberOfPlaceholders(); std::vector<sd::graph::Variable*>* getPlaceholders(); /** * This method returns pointer to thread_local VariableSpace * @return */ sd::graph::VariableSpace *getVariableSpace(); /** * This method adds given node to the graph * * @param node */ void addNode(sd::graph::Node *node); /** * This method returns layered representation of the graph * * @return */ MAP_IMPL<int, std::vector<sd::graph::Node*> *> *getOnion(); /** * This method returns map of all nodes of the graph * @return */ MAP_IMPL<int, sd::graph::Node*>* getMapped(); /** * This method returns outputs of this graph * @return */ std::vector<sd::graph::Variable*> *fetchOutputs(); /** * This method returns pointer to ExecutorConfiguration * * @return */ ExecutorConfiguration *getExecutorConfiguration(); /** * This method adds specified node (by ID) to de * @param id */ void addOutput(int id); /** * This method returns all nodes at once (order is NOT guaranteed) * @return */ std::vector<sd::graph::Node*> *getAllNodes(); /** * This method prints out Graph op-by-op, and respective inputs */ void printOut(); /** * This method collect all ops from the graph into ops vector */ std::vector<sd::ops::OpDescriptor> getOperations(); /** * This method returns Scope ptr specified with id * * @param id * @return */ Scope* scopeById(int id); /** * This method returns TRUE if specified ID refers to Scope, and false otherwise * @param id * @return */ bool hasScope(int id); /** * This method returns clone of the graph */ Graph* clone(); /** * This method returns clone of the graph, backed by VariableProxy instead of VariableSpace */ Graph* cloneWithProxy(); /** * This method removes reference to VariableSpace from this Graph */ void forgetVariableSpace(); /** * This method returns Node with given Id */ Node* nodeById(int nodeId); /** * This method returns True if node with given ID exists, False otherwise * @param nodeId * @return */ bool hasNode(int nodeId); /** * This method returns hash of given Graph instance */ Nd4jLong hashCode(); /** * PLEASE NOTE: This method will be moved to private section */ void tagInplaceNodes(); void replaceState(VariableSpace *state, ExecutorConfiguration *configuration); FORCEINLINE std::vector<int>* nodes() { return _nodes; } FORCEINLINE std::vector<int>* autos() { return &_autos; } FORCEINLINE std::vector<int>* output() { return &_output; } FORCEINLINE MAP_IMPL<int, Scope*>* scopes() { return &_mappedScopes; } FORCEINLINE bool built() { return _built.load(); } FORCEINLINE void pullState(Graph *other) { for (int e = 0; e < other->nodes()->size(); e++) this->_nodes->emplace_back(other->nodes()->at(e)); for (int e = 0; e < other->output()->size(); e++) this->_output.emplace_back(other->output()->at(e)); for (int e = 0; e < other->autos()->size(); e++) this->_autos.emplace_back(other->autos()->at(e)); for (auto &v: *other->scopes()) { auto scp = v.second->clone(); this->_mappedScopes[v.first] = scp; this->_scopes.emplace_back(scp); } for (auto &v: *other->getOnion()) { auto vec = this->_onion->count(v.first) > 0 ? this->_onion->at(v.first) : new std::vector<Node*>(); auto ovec = (*other->getOnion())[v.first]; for (auto x: *(ovec)) { auto n = x->clone(); vec->emplace_back(n); _handles.emplace_back(n); (*this->_mapped)[n->id()] = n; } if (this->_onion->count(v.first) < 1) (*this->_onion)[v.first] = vec; } this->_built.store(other->built()); } }; } } #endif //LIBND4J_GRAPH_H