/* ****************************************************************************** * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * See the NOTICE file distributed with this work for additional * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ // // @author raver119@gmail.com // #ifndef LIBND4J_GRAPH_H #define LIBND4J_GRAPH_H #include #include #include #include //#include #include #include #include #include #include #include #include #include #include #include namespace sd { namespace graph { class ND4J_EXPORT Graph { protected: ExecutorConfiguration *_configuration; VariableSpace *_variableSpace; Stash* _stash; // this list holds references to Node ptrs, which should be free'd in Graph destructor std::vector _handles; // vector holds ID's of top nodes only std::vector *_nodes; MAP_IMPL *_mapped; MAP_IMPL *> *_onion; MAP_IMPL _unmapped; std::vector _unmappedMap; // macOS? std::mutex _mutexPreprocessing; std::atomic _built; std::vector _output; std::vector _autos; MAP_IMPL _mappedScopes; std::vector _scopes; //////////////////////////////////////// Nd4jStatus validateNode(sd::graph::Node *node); void expandOnion(int newLayer); void injectNode(sd::graph::Node *node); void pushToOutputOnce(int id); void printOutNode(Node* node); void prepareOutputs(); public: Graph(const FlatGraph *flatGraph = nullptr, VariableSpace *variableSpace = nullptr); ~Graph(); // this method applies toposort to nodes void toposortNodes(); // method that'll print out graph Nd4jStatus validate(); // this method will build structured representation of graph Nd4jStatus buildGraph(); // this method will return estimated memory size (in bytes) required for 1 full graph execution round Nd4jLong estimateRequiredMemory(); // this method returns number of root nodes in this graph int rootNodes(); // this method returns total number of nodes in this graph int totalNodes(); int numberOfPlaceholders(); std::vector* getPlaceholders(); /** * This method returns pointer to thread_local VariableSpace * @return */ sd::graph::VariableSpace *getVariableSpace(); /** * This method adds given node to the graph * * @param node */ void addNode(sd::graph::Node *node); /** * This method returns layered representation of the graph * * @return */ MAP_IMPL *> *getOnion(); /** * This method returns map of all nodes of the graph * @return */ MAP_IMPL* getMapped(); /** * This method returns outputs of this graph * @return */ std::vector *fetchOutputs(); /** * This method returns pointer to ExecutorConfiguration * * @return */ ExecutorConfiguration *getExecutorConfiguration(); /** * This method adds specified node (by ID) to de * @param id */ void addOutput(int id); /** * This method returns all nodes at once (order is NOT guaranteed) * @return */ std::vector *getAllNodes(); /** * This method prints out Graph op-by-op, and respective inputs */ void printOut(); /** * This method collect all ops from the graph into ops vector */ std::vector getOperations(); /** * This method returns Scope ptr specified with id * * @param id * @return */ Scope* scopeById(int id); /** * This method returns TRUE if specified ID refers to Scope, and false otherwise * @param id * @return */ bool hasScope(int id); /** * This method returns clone of the graph */ Graph* clone(); /** * This method returns clone of the graph, backed by VariableProxy instead of VariableSpace */ Graph* cloneWithProxy(); /** * This method removes reference to VariableSpace from this Graph */ void forgetVariableSpace(); /** * This method returns Node with given Id */ Node* nodeById(int nodeId); /** * This method returns True if node with given ID exists, False otherwise * @param nodeId * @return */ bool hasNode(int nodeId); /** * This method returns hash of given Graph instance */ Nd4jLong hashCode(); /** * PLEASE NOTE: This method will be moved to private section */ void tagInplaceNodes(); void replaceState(VariableSpace *state, ExecutorConfiguration *configuration); FORCEINLINE std::vector* nodes() { return _nodes; } FORCEINLINE std::vector* autos() { return &_autos; } FORCEINLINE std::vector* output() { return &_output; } FORCEINLINE MAP_IMPL* scopes() { return &_mappedScopes; } FORCEINLINE bool built() { return _built.load(); } FORCEINLINE void pullState(Graph *other) { for (int e = 0; e < other->nodes()->size(); e++) this->_nodes->emplace_back(other->nodes()->at(e)); for (int e = 0; e < other->output()->size(); e++) this->_output.emplace_back(other->output()->at(e)); for (int e = 0; e < other->autos()->size(); e++) this->_autos.emplace_back(other->autos()->at(e)); for (auto &v: *other->scopes()) { auto scp = v.second->clone(); this->_mappedScopes[v.first] = scp; this->_scopes.emplace_back(scp); } for (auto &v: *other->getOnion()) { auto vec = this->_onion->count(v.first) > 0 ? this->_onion->at(v.first) : new std::vector(); auto ovec = (*other->getOnion())[v.first]; for (auto x: *(ovec)) { auto n = x->clone(); vec->emplace_back(n); _handles.emplace_back(n); (*this->_mapped)[n->id()] = n; } if (this->_onion->count(v.first) < 1) (*this->_onion)[v.first] = vec; } this->_built.store(other->built()); } }; } } #endif //LIBND4J_GRAPH_H