2021-02-01 13:31:45 +01:00
|
|
|
/* ******************************************************************************
|
|
|
|
*
|
2019-06-06 14:21:15 +02:00
|
|
|
*
|
|
|
|
* This program and the accompanying materials are made available under the
|
|
|
|
* terms of the Apache License, Version 2.0 which is available at
|
|
|
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
|
|
*
|
2021-02-01 13:31:45 +01:00
|
|
|
* See the NOTICE file distributed with this work for additional
|
|
|
|
* information regarding copyright ownership.
|
2019-06-06 14:21:15 +02:00
|
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
|
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
|
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
|
|
* License for the specific language governing permissions and limitations
|
|
|
|
* under the License.
|
|
|
|
*
|
|
|
|
* SPDX-License-Identifier: Apache-2.0
|
|
|
|
******************************************************************************/
|
|
|
|
|
|
|
|
//
|
|
|
|
// @author raver119@gmail.com
|
|
|
|
//
|
|
|
|
|
|
|
|
#include <graph/VariableSpace.h>
|
2020-03-02 10:49:41 +01:00
|
|
|
#include <legacy/NativeOps.h>
|
2019-06-06 14:21:15 +02:00
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
namespace sd {
|
2019-06-06 14:21:15 +02:00
|
|
|
namespace graph {
|
2020-03-02 10:49:41 +01:00
|
|
|
std::vector<sd::graph::Variable *> * sd::graph::VariableSpace::getExternalVariables() {
|
2019-06-06 14:21:15 +02:00
|
|
|
return &_external;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/**
 * Exposes the stash owned by this space.
 *
 * @return pointer to this object's own _stash member (no copy is made)
 */
sd::graph::Stash* sd::graph::VariableSpace::getStash() {
    return &this->_stash;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/**
 * Builds a new heap-allocated VariableSpace holding a clone of every variable
 * registered in _paired.
 *
 * Only the entries of _paired are walked; each clone is registered in the new
 * space through injectVariable() under the same (id, index) key.
 *
 * @return newly allocated VariableSpace
 */
sd::graph::VariableSpace* sd::graph::VariableSpace::clone() {
    auto copy = new VariableSpace();

    for (auto const& entry : _paired) {
        // injectVariable() takes the key by non-const reference, hence the local copy
        std::pair<int, int> key(entry.first.first, entry.first.second);

        Variable* duplicate = entry.second->clone();
        copy->injectVariable(key, duplicate);
    }

    return copy;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/**
 * Currently a no-op: the assignment that used to store the workspace is disabled,
 * so the argument is ignored.
 *
 * @param workspace workspace pointer (unused)
 */
void VariableSpace::setWorkspace(sd::memory::Workspace *workspace) {
    // disabled: _workspace = *workspace;
}
|
|
|
|
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/**
 * Stub for type conversion of the whole space.
 *
 * The per-variable conversion is disabled, so no variable is transferred and the
 * returned space is empty. The previous implementation iterated over _paired only
 * to build an unused key; that dead loop has been removed.
 *
 * @return newly allocated, empty VariableSpace
 */
sd::graph::VariableSpace* sd::graph::VariableSpace::asT() {
    // Disabled per-variable conversion, kept for reference:
    //   Variable* clonedVar = x.second->template asT<N>();
    //   result->injectVariable(pair, clonedVar);
    return new VariableSpace();
}
|
|
|
|
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
void sd::graph::VariableSpace::injectVariable(std::pair<int, int> &pair, Variable* variable) {
|
2019-06-06 14:21:15 +02:00
|
|
|
if (pair.second == 0) {
|
|
|
|
if (pair.first < 0)
|
|
|
|
this->_variables[pair.first] = variable;
|
|
|
|
else
|
|
|
|
this->_temporary[pair.first] = variable;
|
|
|
|
}
|
|
|
|
|
|
|
|
if (variable->getName() != nullptr && variable->getName()->length() > 0)
|
|
|
|
this->_symbolic[*(variable->getName())] = variable;
|
|
|
|
|
|
|
|
this->_paired[pair] = variable;
|
|
|
|
|
|
|
|
this->_handles->push_back(variable);
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
std::vector<sd::graph::Variable*> * sd::graph::VariableSpace::getPlaceholders() {
|
2019-06-06 14:21:15 +02:00
|
|
|
return &_placeholders;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
int sd::graph::VariableSpace ::numberOfPlaceholders() {
|
2019-06-06 14:21:15 +02:00
|
|
|
return _placeholders.size();
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
bool sd::graph::VariableSpace::hasVariable(std::string *symbol) {
|
2019-06-06 14:21:15 +02:00
|
|
|
return _symbolic.count(*symbol) == 1;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/**
 * Looks a variable up by symbolic name.
 *
 * @param symbol name to look up; must not be null (it is dereferenced)
 * @return the variable stored in _symbolic under that name
 * @note uses the container's at(), so an unknown name raises whatever at() raises
 */
sd::graph::Variable * sd::graph::VariableSpace::getVariable(std::string *symbol) {
    const std::string& name = *symbol;
    return _symbolic.at(name);
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
bool sd::graph::VariableSpace::hasVariable(int id, int index) {
|
2019-06-06 14:21:15 +02:00
|
|
|
std::pair<int, int> pair(id, index);
|
|
|
|
return hasVariable(pair);
|
|
|
|
}
|
|
|
|
|
|
|
|
bool VariableSpace::hasExternalVariable(int id) {
|
|
|
|
if (!hasVariable(id))
|
|
|
|
return false;
|
|
|
|
|
|
|
|
auto var = getVariable(id);
|
|
|
|
return var->isExternal();
|
|
|
|
}
|
|
|
|
|
|
|
|
bool VariableSpace::hasExternalVariable(std::pair<int,int>& pair) {
|
|
|
|
if (!hasVariable(pair))
|
|
|
|
return false;
|
|
|
|
|
|
|
|
auto var = getVariable(pair);
|
|
|
|
return var->isExternal();
|
|
|
|
}
|
|
|
|
|
|
|
|
bool VariableSpace::hasExternalVariable(std::string *symbol) {
|
|
|
|
if (!hasVariable(symbol))
|
|
|
|
return false;
|
|
|
|
|
|
|
|
auto var = getVariable(symbol);
|
|
|
|
return var->isExternal();
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/**
 * Convenience overload: fetches a variable by separate id and index.
 *
 * @param id    node id
 * @param index output index
 * @return result of getVariable(pair) for the (id, index) key
 */
sd::graph::Variable * sd::graph::VariableSpace::getVariable(int id, int index) {
    // the pair overload takes a non-const reference, so a named key is required
    std::pair<int, int> key(id, index);
    return getVariable(key);
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/**
 * Fetches a variable by its (id, index) key.
 *
 * Negative ids are resolved through getVariable(int), i.e. by id only and the
 * index is ignored; all other ids are resolved through _paired.
 *
 * The former trailing "Unknown variable requested" log + throw was unreachable
 * (both branches above it returned) and has been removed; the behaviour for
 * unknown keys is unchanged.
 *
 * @param pair (id, index) key
 * @return the registered variable
 * @note lookups use the containers' at(); an unknown key raises whatever at()
 *       raises (std::out_of_range for standard maps - TODO confirm container type)
 */
sd::graph::Variable * sd::graph::VariableSpace::getVariable(std::pair<int, int>& pair) {
    if (pair.first < 0)
        return getVariable(pair.first);

    return _paired.at(pair);
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
bool sd::graph::VariableSpace::hasVariable(int id) {
|
2019-06-06 14:21:15 +02:00
|
|
|
return _variables.count(id) == 1 || _temporary.count(id) == 1;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
bool sd::graph::VariableSpace::hasVariable(std::pair<int,int>& id) {
|
2019-06-06 14:21:15 +02:00
|
|
|
return _paired.count(id) > 0;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/**
 * Registers a variable under the id it already carries.
 *
 * @param variable variable to register; must not be null (its id() is read)
 */
void sd::graph::VariableSpace::putOutputVariable(Variable *variable) {
    // an auto-assigned id (_auto_counter--) was used here previously; now disabled
    const auto ownId = variable->id();
    putVariable(ownId, variable);
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
int sd::graph::VariableSpace::externalEntries() {
|
2019-06-06 14:21:15 +02:00
|
|
|
return _external.size();
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
int sd::graph::VariableSpace::internalEntries() {
|
2019-06-06 14:21:15 +02:00
|
|
|
return _internal.size();
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
int sd::graph::VariableSpace::totalEntries() {
|
2019-06-06 14:21:15 +02:00
|
|
|
return externalEntries() + internalEntries();
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/**
 * Sums memoryFootprint() of the NDArray of every variable in _external.
 *
 * @return accumulated footprint
 * @note getNDArray() is dereferenced unconditionally for each variable
 */
Nd4jLong sd::graph::VariableSpace::externalMemory() {
    Nd4jLong total = 0;

    for (auto variable : _external)
        total += variable->getNDArray()->memoryFootprint();

    return total;
}
|
|
|
|
|
|
|
|
std::vector<Variable*> VariableSpace::getVariables() {
|
|
|
|
std::vector<Variable*> result;
|
|
|
|
|
|
|
|
for (auto v: _internal)
|
|
|
|
result.emplace_back(v);
|
|
|
|
|
|
|
|
for (auto v: _external)
|
|
|
|
result.emplace_back(v);
|
|
|
|
|
|
|
|
return result;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/**
 * Sums memoryFootprint() of the NDArray of every variable in _internal.
 *
 * @return accumulated footprint
 * @note getNDArray() is dereferenced unconditionally for each variable
 */
Nd4jLong sd::graph::VariableSpace::internalMemory() {
    Nd4jLong total = 0;

    for (auto variable : _internal)
        total += variable->getNDArray()->memoryFootprint();

    return total;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/**
 * @return sum of externalMemory() and internalMemory()
 */
Nd4jLong sd::graph::VariableSpace::totalMemory() {
    const Nd4jLong external = externalMemory();
    const Nd4jLong internal = internalMemory();
    return external + internal;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/**
 * Wraps an array into a new unnamed Variable and registers it under the given key.
 *
 * @param pair  (id, index) key; both parts are also passed to the Variable constructor
 * @param array array to wrap
 * @return the newly allocated Variable that was registered
 */
Variable* sd::graph::VariableSpace::putVariable(std::pair<int,int>& pair, NDArray *array) {
    auto created = new Variable(array, nullptr, pair.first, pair.second);
    putVariable(pair, created);

    return created;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/**
 * Convenience overload: wraps an array into a new Variable registered under (node, idx).
 *
 * @param node  node id
 * @param idx   output index
 * @param array array to wrap
 * @return the Variable created by putVariable(pair, array)
 */
Variable* sd::graph::VariableSpace::putVariable(int node, int idx, NDArray *array) {
    // the pair overload takes a non-const reference, so a named key is required
    std::pair<int, int> key(node, idx);
    return putVariable(key, array);
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
void sd::graph::VariableSpace::putVariable(int node, int idx, Variable *variable) {
|
2019-06-06 14:21:15 +02:00
|
|
|
std::pair<int, int> pair(node, idx);
|
|
|
|
this->putVariable(pair, variable);
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
void sd::graph::VariableSpace::silentPutVariable(std::pair<int,int>& pair, Variable *variable) {
|
2019-06-06 14:21:15 +02:00
|
|
|
_varmap.lock();
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
//std::pair<std::pair<int, int>, sd::graph::Variable *> p(pair, variable);
|
2019-06-06 14:21:15 +02:00
|
|
|
_paired[pair] = variable;
|
|
|
|
|
|
|
|
_varmap.unlock();
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
void sd::graph::VariableSpace::putVariable(std::pair<int,int>& pair, Variable *variable) {
|
2019-06-06 14:21:15 +02:00
|
|
|
silentPutVariable(pair, variable);
|
|
|
|
|
|
|
|
if (variable->isPlaceholder())
|
|
|
|
_placeholders.push_back(variable);
|
|
|
|
|
|
|
|
// copying duplicate for compatibility
|
|
|
|
if (pair.second == 0 && !this->hasVariable(pair.first)) {
|
|
|
|
this->putVariable(pair.first, variable);
|
|
|
|
} else {
|
|
|
|
if (variable->getName() != nullptr && variable->getName()->length() != 0) {
|
|
|
|
_symbolic[*(variable->getName())] = variable;
|
|
|
|
}
|
|
|
|
|
|
|
|
_varmap.lock();
|
|
|
|
|
|
|
|
_handles->push_back(variable);
|
|
|
|
|
|
|
|
_varmap.unlock();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/**
 * Appends a list to _lists; every entry of _lists is deleted by the destructor.
 *
 * @param list list to track
 */
void VariableSpace::trackList(sd::NDArrayList* list) {
    _lists.push_back(list);
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
void sd::graph::VariableSpace::putVariable(int id, Variable *variable) {
|
2019-06-06 14:21:15 +02:00
|
|
|
// we don't want to add variables more then once
|
|
|
|
if (_variables.count(id) > 0 || _temporary.count(id) > 0) {
|
|
|
|
auto local = id < 0 ? _variables.at(id) : _temporary.at(id);
|
|
|
|
|
|
|
|
if (!local->hasNDArray() && variable->hasNDArray()) {
|
|
|
|
local->setNDArray(variable->getNDArray());
|
2020-02-13 18:59:35 +01:00
|
|
|
|
|
|
|
// we're inheriting this from Variable
|
|
|
|
local->markReadOnly(variable->isReadOnly());
|
|
|
|
local->markRemovable(variable->isRemovable());
|
2019-06-06 14:21:15 +02:00
|
|
|
}
|
2020-02-13 18:59:35 +01:00
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
return;
|
|
|
|
}
|
|
|
|
|
|
|
|
_varmap.lock();
|
|
|
|
|
|
|
|
_handles->emplace_back(variable);
|
|
|
|
|
|
|
|
if (_auto_counter >= id)
|
|
|
|
_auto_counter = id - 1;
|
|
|
|
|
|
|
|
variable->setId(id);
|
|
|
|
|
|
|
|
if (variable->getName() != nullptr && variable->getName()->length() != 0) {
|
2020-03-02 10:49:41 +01:00
|
|
|
//std::pair<std::string, sd::graph::Variable *> pair(*(variable->getName()), variable);
|
2019-06-06 14:21:15 +02:00
|
|
|
_symbolic[*(variable->getName())] = variable;
|
|
|
|
}
|
|
|
|
|
|
|
|
// we have special list for external variables to ensure graph completeness
|
|
|
|
|
|
|
|
if (id < 0) {
|
|
|
|
//if (variable->isExternal())
|
|
|
|
_external.push_back(variable);
|
|
|
|
|
|
|
|
_variables[id] = variable;
|
|
|
|
} else {
|
|
|
|
_internal.push_back(variable);
|
|
|
|
|
|
|
|
_temporary[id] = variable;
|
|
|
|
}
|
|
|
|
|
|
|
|
_varmap.unlock();
|
|
|
|
|
|
|
|
std::pair<int,int> pair(id, 0);
|
|
|
|
if (!hasVariable(pair)) {
|
|
|
|
this->silentPutVariable(pair, variable);
|
|
|
|
|
|
|
|
if (variable->isPlaceholder())
|
|
|
|
_placeholders.push_back(variable);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
void sd::graph::VariableSpace::putVariable(int id, int idx, NDArray &array) {
|
|
|
|
auto *var = new sd::graph::Variable(&array, "", id, idx);
|
2020-02-13 18:59:35 +01:00
|
|
|
var->markRemovable(false);
|
|
|
|
var->markReadOnly(true);
|
|
|
|
|
|
|
|
// let's see if this op needs
|
|
|
|
bool d = this->hasVariable(id, idx);
|
|
|
|
|
|
|
|
this->putVariable(id, var);
|
|
|
|
|
|
|
|
// if var for this nodeid already exists - we'll just delete variable
|
|
|
|
if (d)
|
|
|
|
delete var;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
void sd::graph::VariableSpace::putVariable(int id, NDArray *array) {
|
|
|
|
auto *var = new sd::graph::Variable(array);
|
2019-06-06 14:21:15 +02:00
|
|
|
this->putVariable(id, var);
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/**
 * Fetches a variable by node id: negative ids are read from _variables,
 * non-negative ids from _temporary.
 *
 * @param id node id
 * @return the registered variable
 * @note uses the containers' at(), so an unknown id raises whatever at() raises
 */
sd::graph::Variable * sd::graph::VariableSpace::getVariable(int id) {
    return id < 0 ? _variables.at(id) : _temporary.at(id);
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/**
 * @return the value of LaunchContext::defaultContext(); no per-space context is kept here
 */
LaunchContext* sd::graph::VariableSpace::launchContext() {
    auto context = LaunchContext::defaultContext();
    return context;
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
std::vector<Variable*>* sd::graph::VariableSpace::handles() {
|
2019-06-06 14:21:15 +02:00
|
|
|
return _handles;
|
|
|
|
}
|
|
|
|
|
|
|
|
/*
 * FIXME: this is likely to become backend-specific!
 */
|
2020-03-02 10:49:41 +01:00
|
|
|
/**
 * Releases everything this space holds: every Variable in _handles, the
 * _handles vector itself, and every tracked NDArrayList in _lists.
 *
 * The other containers (_internal, _external, _temporary, ...) are not cleared
 * explicitly.
 */
sd::graph::VariableSpace::~VariableSpace() {
    // release all variables first, then the container that held them
    for (auto handle : *_handles)
        delete handle;

    delete _handles;

    // tracked lists are released as well
    for (auto list : _lists)
        delete list;

    _lists.clear();
}
|
|
|
|
|
|
|
|
/**
 * Copy-assignment: clones every variable registered in other._paired and
 * registers the clone in this space under the same (id, index) key.
 *
 * The per-variable registration was a statement-for-statement duplicate of
 * injectVariable(); it now calls that helper so the two cannot drift apart.
 *
 * NOTE(review): existing entries of this space are not removed first, so
 * entries under keys absent from `other` remain - confirm that this is intended.
 *
 * @param other space to copy variables from
 * @return *this
 */
VariableSpace& VariableSpace::operator=(const VariableSpace& other) {
    if (this == &other) return *this;

    for (auto const& x : other._paired) {
        // injectVariable() takes the key by non-const reference, hence the local copy
        std::pair<int, int> pair(x.first.first, x.first.second);

        Variable* clonedVar = x.second->clone();

        injectVariable(pair, clonedVar);
    }

    return *this;
}
|
|
|
|
|
|
|
|
/**
 * Replaces a registered variable with the given one, or registers it if no
 * match exists.
 *
 * A variable with a non-empty name is matched by name only; otherwise it is
 * matched by its own (id, index). On a match, dropVariable() is called and the
 * new variable is registered under the id/index reported by the existing one.
 * Without a match, the variable is registered under its own id/index.
 *
 * @param variable replacement variable; must not be null (it is dereferenced)
 */
void VariableSpace::replaceVariable(Variable *variable) {
    auto name = variable->getName();

    // trying name first
    if (name != nullptr && !name->empty()) {
        nd4j_printf("Trying to replace variable by name: [%s]\n", name->c_str());
        if (hasVariable(name)) {
            nd4j_printf("Replacing by name: [%s]\n", name->c_str());
            auto existing = getVariable(name);
            dropVariable(existing->id(), existing->index());
            putVariable(existing->id(), existing->index(), variable);
            // the previous variable is intentionally not deleted here
            return;
        }
    } else {
        nd4j_printf("Trying to replace variable by id: [%i:%i]\n", variable->id(), variable->index());
        if (hasVariable(variable->id(), variable->index())) {
            nd4j_printf("Replacing by id: [%i:%i]\n", variable->id(), variable->index());
            auto existing = getVariable(variable->id(), variable->index());
            dropVariable(variable->id(), variable->index());
            putVariable(existing->id(), existing->index(), variable);
            // the previous variable is intentionally not deleted here
            return;
        }
    }

    // nothing matched: register the variable under its own id/index
    nd4j_printf("wasn't able to replace variable, putting\n", "");
    putVariable(variable->id(), variable->index(), variable);
}
|
|
|
|
|
|
|
|
void VariableSpace::dropVariable(std::pair<int,int> &pair) {
|
|
|
|
dropVariable(pair.first, pair.second);
|
|
|
|
}
|
|
|
|
|
|
|
|
/**
 * Currently a no-op: the body is empty, so nothing is removed from this space.
 *
 * @param id  node id (unused)
 * @param idx output index (unused)
 */
void VariableSpace::dropVariable(int id, int idx) {
    // intentionally empty
}
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Stores the given FlowPath pointer in _flow.
 *
 * @param flow pointer to store; only the pointer is kept
 */
void VariableSpace::setFlowPath(FlowPath* flow) {
    this->_flow = flow;
}
|
|
|
|
|
|
|
|
/**
 * @return the FlowPath pointer currently stored in _flow
 */
FlowPath* VariableSpace::flowPath() {
    return this->_flow;
}
|
|
|
|
|
|
|
|
/**
 * Creates an empty space and allocates the _handles vector, which the
 * destructor releases.
 */
VariableSpace::VariableSpace() {
    this->_handles = new std::vector<Variable *>();
}
|
|
|
|
}
|
|
|
|
}
|