few minor tweaks (#272)

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2020-02-25 11:13:23 +03:00 committed by GitHub
parent 241ed05c64
commit f6442b6724
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 20 additions and 33 deletions

View File

@ -261,12 +261,12 @@ namespace nd4j {
} }
Variable* Context::variable(std::pair<int,int>& p) { Variable* Context::variable(std::pair<int,int>& p) {
if (!_variableSpace->hasVariable(p)) { try {
return _variableSpace->getVariable(p);
} catch (std::exception &e) {
nd4j_printf("Node %i; Non-existent variable requested: [%i:%i]\n", this->_nodeId, p.first, p.second); nd4j_printf("Node %i; Non-existent variable requested: [%i:%i]\n", this->_nodeId, p.first, p.second);
throw std::runtime_error("Bad variable"); throw std::runtime_error("Bad variable");
} }
return _variableSpace->getVariable(p);
} }
void Context::pushNDArrayToVariableSpace(int nodeId, int index, NDArray *array, bool removable) { void Context::pushNDArrayToVariableSpace(int nodeId, int index, NDArray *array, bool removable) {

View File

@ -132,23 +132,13 @@ namespace nd4j {
} }
nd4j::graph::Variable * nd4j::graph::VariableSpace::getVariable(std::pair<int, int>& pair) { nd4j::graph::Variable * nd4j::graph::VariableSpace::getVariable(std::pair<int, int>& pair) {
// if (pair.first == 0)
// throw "0 requested";
//nd4j_debug("Requested variable: [%i:%i]\n", pair.first, pair.second);
if (pair.first < 0) if (pair.first < 0)
return getVariable(pair.first); return getVariable(pair.first);
else if (_paired.count(pair) > 0) else
return _paired.at(pair); return _paired.at(pair);
else {
if (hasVariable(pair.first) && pair.second == 0)
return getVariable(pair.first);
}
nd4j_printf("Unknown variable requested: [%i,%i]\n", pair.first, pair.second); nd4j_printf("Unknown variable requested: [%i,%i]\n", pair.first, pair.second);
throw std::runtime_error("Unknown variable requested");
return nullptr;
} }
bool nd4j::graph::VariableSpace::hasVariable(int id) { bool nd4j::graph::VariableSpace::hasVariable(int id) {
@ -335,18 +325,10 @@ namespace nd4j {
} }
nd4j::graph::Variable * nd4j::graph::VariableSpace::getVariable(int id) { nd4j::graph::Variable * nd4j::graph::VariableSpace::getVariable(int id) {
// _varmap.lock();
if (id < 0) { if (id < 0) {
auto v = _variables.at(id); return _variables.at(id);
// _varmap.unlock();
return v;
} else { } else {
auto v = _temporary.at(id); return _temporary.at(id);
// _varmap.unlock();
return v;
} }
} }

View File

@ -26,11 +26,13 @@
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
OP_IMPL(identity, 1, 1, true) { OP_IMPL(identity, 1, 1, true) {
if (!block.isInplace()) {
auto first = INPUT_VARIABLE(0); auto first = INPUT_VARIABLE(0);
auto z = OUTPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0);
if (!block.isInplace()) // we hope for memcpy here
first->applyTransform(nd4j::transform::Identity, *z); z->assign(first);
}
return Status::OK(); return Status::OK();
} }

View File

@ -26,10 +26,13 @@
namespace nd4j { namespace nd4j {
namespace ops { namespace ops {
OP_IMPL(stop_gradient, 1, 1, true) { OP_IMPL(stop_gradient, 1, 1, true) {
if (!block.isInplace()) {
auto x = INPUT_VARIABLE(0); auto x = INPUT_VARIABLE(0);
auto out = OUTPUT_VARIABLE(0); auto out = OUTPUT_VARIABLE(0);
// just for lulz
x->applyTransform(transform::Identity, *out); // we hope for memcpy here
out->assign(x);
}
return Status::OK(); return Status::OK();
} }