/* ****************************************************************************** * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * See the NOTICE file distributed with this work for additional * information regarding copyright ownership. * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ // // @author raver119@gmail.com // #include #include namespace sd { namespace graph { VariableProxy::VariableProxy(VariableSpace* ref) { if (ref == nullptr) _backed = new VariableSpace(); _backed = ref; _current = new VariableSpace(); } VariableProxy::~VariableProxy() { delete _current; } int VariableProxy::numberOfPlaceholders() { return _backed->numberOfPlaceholders(); } std::vector* VariableProxy::getPlaceholders() { return _backed->getPlaceholders(); } bool VariableProxy::hasExternalVariable(int it) { return _backed->hasExternalVariable(it); } bool VariableProxy::hasExternalVariable(std::pair& pair) { return _backed->hasExternalVariable(pair); } bool VariableProxy::hasExternalVariable(std::string *symbol) { return _backed->hasExternalVariable(symbol); } bool VariableProxy::hasVariable(int id) { return _current->hasVariable(id) || _backed->hasVariable(id); } bool VariableProxy::hasVariable(int id, int idx) { return _current->hasVariable(id, idx) || _backed->hasVariable(id, idx); } bool VariableProxy::hasVariable(std::pair& pair) { return _current->hasVariable(pair) || _backed->hasVariable(pair); } void VariableProxy::dropVariable(std::pair &pair) { dropVariable(pair.first, pair.second); } void VariableProxy::dropVariable(int id, int idx) { assert(_current->hasVariable(id, idx)); _current->dropVariable(id, idx); } std::vector VariableProxy::getVariables() { std::vector result; auto b = _backed->getVariables(); auto c = _current->getVariables(); for (auto v: b) result.emplace_back(v); for (auto v: c) result.emplace_back(v); return result; } bool VariableProxy::hasVariable(std::string *symbol) { return _current->hasVariable(symbol) || _backed->hasVariable(symbol); } sd::graph::Variable *VariableProxy::getVariable(int id) { if (_current->hasVariable(id)) return _current->getVariable(id); if (_backed->hasVariable(id)) return _backed->getVariable(id); nd4j_printf("Unable to get Variable to proxy: [%i]\n", id); throw std::runtime_error("Bad arguments"); } sd::graph::Variable *VariableProxy::getVariable(int id, int idx) { if (_current->hasVariable(id, idx)) return _current->getVariable(id, idx); if (_backed->hasVariable(id, idx)) return _backed->getVariable(id, idx); nd4j_printf("Unable to get Variable to proxy: [%i:%i]\n", id, idx); throw std::runtime_error("Bad arguments"); } sd::graph::Variable *VariableProxy::getVariable(std::pair& pair) { if (_current->hasVariable(pair)) return _current->getVariable(pair); if (_backed->hasVariable(pair)) return _backed->getVariable(pair); nd4j_printf("Unable to get Variable to proxy: [%i:%i]\n", pair.first, pair.second); throw std::runtime_error("Bad arguments"); } sd::graph::Variable *VariableProxy::getVariable(std::string *symbol) { if (_current->hasVariable(symbol)) return _current->getVariable(symbol); if (_backed->hasVariable(symbol)) return _backed->getVariable(symbol); nd4j_printf("Unable to get Variable to proxy: [%s]\n", symbol->c_str()); throw std::runtime_error("Bad arguments"); } void VariableProxy::replaceVariable(Variable *variable) { if (variable->getName() != nullptr && !variable->getName()->empty()) { // if variable has name defined - we should resolve it via backing var space if (_backed->hasVariable(variable->getName())) { auto origVar = _backed->getVariable(variable->getName()); variable->setId(origVar->id(), origVar->index()); _current->replaceVariable(variable); } else _current->replaceVariable(variable); } else // if proxy has variable - that's one story _current->replaceVariable(variable); } Variable* VariableProxy::putVariable(std::pair& pair, NDArray *array) { return _current->putVariable(pair, array); } void VariableProxy::putVariable(std::pair& pair, Variable *variable) { _current->putVariable(pair, variable); } void VariableProxy::putVariable(int id, Variable *variable) { _current->putVariable(id, variable); } void VariableProxy::putVariable(int id, NDArray *array) { _current->putVariable(id, array); } void sd::graph::VariableProxy::putVariable(int id, int idx, NDArray &array) { _current->putVariable(id, idx, array); } Variable* VariableProxy::putVariable(int id, int idx, NDArray *array) { return _current->putVariable(id, idx, array); } void VariableProxy::putVariable(int id, int idx, Variable *array) { _current->putVariable(id, idx, array); } void VariableProxy::trackList(sd::NDArrayList* list) { _current->trackList(list); } sd::graph::Stash* VariableProxy::getStash() { return _current->getStash(); } void VariableProxy::setFlowPath(FlowPath* timers) { _current->setFlowPath(timers); } FlowPath* VariableProxy::flowPath() { return _current->flowPath(); } void VariableProxy::putOutputVariable(Variable *variable) { _current->putOutputVariable(variable); } Nd4jLong VariableProxy::externalMemory() { return _backed->externalMemory() + _current->externalMemory(); } Nd4jLong VariableProxy::internalMemory() { return _backed->internalMemory() + _current->internalMemory(); } Nd4jLong VariableProxy::totalMemory() { return _backed->totalMemory() + _current->totalMemory(); } int VariableProxy::externalEntries() { return _backed->externalEntries() + _current->externalEntries(); } int VariableProxy::internalEntries() { return _backed->internalEntries() + _current->internalEntries(); } int VariableProxy::totalEntries() { return _backed->totalEntries() + _current->totalEntries(); } sd::graph::VariableSpace* VariableProxy::clone() { auto clone = new VariableProxy(_backed); delete clone->_current; clone->_current = _current->clone(); return clone; } VariableSpace& VariableProxy::operator=(const VariableSpace& other) { if (this == &other) return *this; nd4j_printf("VariableProxy = not implemented\n",""); return *this; } sd::memory::Workspace * sd::graph::VariableProxy::workspace() { return _workspace; } } }