/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

#ifndef LIBND4J_VARIABLESPACE_H
#define LIBND4J_VARIABLESPACE_H

// NOTE(review): in this copy every #include below has lost its header name,
// and most template argument lists in the class (on MAP_IMPL, std::vector and
// std::pair) are missing too -- apparently stripped angle-bracket text. The
// header does not compile as shown; restore the missing text from version
// control rather than guessing it.
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

namespace nd4j {
    namespace graph {
        /**
         * Registry of graph Variables. Going by the overload sets declared
         * below, entries are addressed by integer id, by (id, idx), by a
         * std::pair (apparently the same (id, idx) key) or by a symbolic name
         * passed as std::string*.
         *
         * This header only declares the interface. Ownership of the stored
         * Variable / NDArray pointers, behavior on missing keys, and the
         * thread-safety of individual operations are defined in the
         * implementation file and are not established here.
         *
         * NOTE(review): the declaration order of data members and virtual
         * functions determines object layout and vtable layout. Do not
         * reorder them without rebuilding every dependent binary.
         *
         * NOTE(review): a destructor and a copy-assignment operator are
         * user-declared, but no copy constructor is. Because the std::mutex
         * member (_varmap) is non-copyable, the implicitly-declared copy
         * constructor is defined as deleted, and no move operations are
         * implicitly declared. clone() appears to be the intended way to
         * duplicate an instance.
         */
        class ND4J_EXPORT VariableSpace {
        protected:
            // Workspace pointer; presumably assigned through setWorkspace().
            nd4j::memory::Workspace *_workspace;

            // stash is NOT cloned
            nd4j::graph::Stash _stash;

            // Lookup tables. Their key/value types were lost in this copy (see
            // the note above the includes). Presumably they are keyed by
            // (id, idx) pair, by symbolic name and by integer id respectively
            // -- TODO confirm.
            MAP_IMPL, Variable*> _paired;
            MAP_IMPL _symbolic;
            MAP_IMPL _variables;

            // External / internal groups; cf. the external*/internal*
            // statistics methods declared below.
            std::vector _external;
            std::vector _internal;

            // Presumably the NDArrayLists registered through trackList().
            std::vector _lists;

            // Presumably the storage behind numberOfPlaceholders() and
            // getPlaceholders().
            std::vector _placeholders;

            // Protected insertion helper. What "silent" means is defined in the
            // implementation file, not here.
            void silentPutVariable(std::pair& pair, Variable *variable);

            // Starts at -1 (in-class initializer). How it advances is not
            // visible in this header.
            int _auto_counter = -1;

            // Mutex member. Which operations lock it is not visible here, so do
            // not assume any given method is thread-safe.
            std::mutex _varmap;

            // Purpose not visible in this header.
            MAP_IMPL _temporary;

            // Exposed through handles(). Ownership of the pointed-to vector is
            // not established by this header.
            std::vector *_handles;

            // Defaults to nullptr (in-class initializer); presumably set through
            // setFlowPath() and read through flowPath().
            FlowPath* _flow = nullptr;

        public:
            VariableSpace();
            virtual ~VariableSpace();

            virtual VariableSpace& operator=(const VariableSpace& other);

            // Placeholder variables.
            virtual int numberOfPlaceholders();
            virtual std::vector* getPlaceholders();

            virtual void setWorkspace(nd4j::memory::Workspace *workspace);

            virtual LaunchContext* launchContext();

            // Existence checks restricted to external variables.
            virtual bool hasExternalVariable(int it);
            virtual bool hasExternalVariable(std::pair& pair);
            virtual bool hasExternalVariable(std::string *symbol);

            // Existence checks by id, (id, idx), pair or symbolic name.
            virtual bool hasVariable(int id);
            virtual bool hasVariable(int id, int idx);
            virtual bool hasVariable(std::pair& pair);
            virtual bool hasVariable(std::string *symbol);

            // Lookups using the same addressing schemes as hasVariable().
            // Behavior for a missing key is defined in the implementation --
            // check hasVariable() first rather than relying on a nullptr return.
            virtual nd4j::graph::Variable* getVariable(int id);
            virtual nd4j::graph::Variable* getVariable(int id, int idx);
            virtual nd4j::graph::Variable* getVariable(std::pair& pair);
            virtual nd4j::graph::Variable* getVariable(std::string *symbol);

            // Returned by value (a copy of the container, not a view).
            virtual std::vector getVariables();

            // Insertion overloads, taking either an NDArray (by pointer or
            // reference) or a Variable*. Whether the space takes ownership is
            // not visible here.
            virtual void putVariable(std::pair& pair, NDArray *array);
            virtual void putVariable(std::pair& pair, Variable *variable);
            virtual void putVariable(int id, Variable *variable);
            virtual void putVariable(int id, NDArray *array);
            virtual void putVariable(int id, int idx, NDArray *array);
            virtual void putVariable(int id, int idx, NDArray &array);
            // The parameter is a Variable* despite being named "array".
            virtual void putVariable(int id, int idx, Variable *array);

            // Removal by pair or by (id, idx).
            virtual void dropVariable(std::pair &pair);
            virtual void dropVariable(int id, int idx);

            virtual void trackList(nd4j::NDArrayList *list);

            virtual void putOutputVariable(Variable *variable);

            virtual void replaceVariable(Variable *variable);

            // memory-related statistics
            // (Nd4jLong amounts; units presumably bytes -- confirm in the
            // implementation.)
            virtual Nd4jLong externalMemory();
            virtual Nd4jLong internalMemory();
            virtual Nd4jLong totalMemory();

            virtual int externalEntries();
            virtual int internalEntries();
            virtual int totalEntries();

            // Returns a new VariableSpace*, presumably owned by the caller --
            // confirm. Per the note on _stash, the stash is not cloned.
            virtual nd4j::graph::VariableSpace* clone();

            // handles(), asT() and injectVariable() are non-virtual, unlike
            // most of this interface.
            std::vector *handles();

            nd4j::graph::VariableSpace* asT();
            void injectVariable(std::pair &pair, Variable* variable);

            virtual nd4j::graph::Stash* getStash();

            virtual std::vector * getExternalVariables();

            virtual void setFlowPath(FlowPath* timers);
            virtual FlowPath* flowPath();
        };
    }
}

#endif //LIBND4J_VARIABLESPACE_H