few minor tweaks (#272)

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2020-02-25 11:13:23 +03:00 committed by GitHub
parent 241ed05c64
commit f6442b6724
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 20 additions and 33 deletions

View File

@ -261,12 +261,12 @@ namespace nd4j {
}
Variable* Context::variable(std::pair<int,int>& p) {
if (!_variableSpace->hasVariable(p)) {
try {
return _variableSpace->getVariable(p);
} catch (std::exception &e) {
nd4j_printf("Node %i; Non-existent variable requested: [%i:%i]\n", this->_nodeId, p.first, p.second);
throw std::runtime_error("Bad variable");
}
return _variableSpace->getVariable(p);
}
void Context::pushNDArrayToVariableSpace(int nodeId, int index, NDArray *array, bool removable) {

View File

@ -132,23 +132,13 @@ namespace nd4j {
}
nd4j::graph::Variable * nd4j::graph::VariableSpace::getVariable(std::pair<int, int>& pair) {
// if (pair.first == 0)
// throw "0 requested";
//nd4j_debug("Requested variable: [%i:%i]\n", pair.first, pair.second);
if (pair.first < 0)
return getVariable(pair.first);
else if (_paired.count(pair) > 0)
else
return _paired.at(pair);
else {
if (hasVariable(pair.first) && pair.second == 0)
return getVariable(pair.first);
}
nd4j_printf("Unknown variable requested: [%i,%i]\n", pair.first, pair.second);
return nullptr;
throw std::runtime_error("Unknown variable requested");
}
bool nd4j::graph::VariableSpace::hasVariable(int id) {
@ -335,18 +325,10 @@ namespace nd4j {
}
nd4j::graph::Variable * nd4j::graph::VariableSpace::getVariable(int id) {
// _varmap.lock();
if (id < 0) {
auto v = _variables.at(id);
// _varmap.unlock();
return v;
return _variables.at(id);
} else {
auto v = _temporary.at(id);
// _varmap.unlock();
return v;
return _temporary.at(id);
}
}

View File

@ -26,11 +26,13 @@
namespace nd4j {
namespace ops {
OP_IMPL(identity, 1, 1, true) {
auto first = INPUT_VARIABLE(0);
auto z = OUTPUT_VARIABLE(0);
if (!block.isInplace()) {
auto first = INPUT_VARIABLE(0);
auto z = OUTPUT_VARIABLE(0);
if (!block.isInplace())
first->applyTransform(nd4j::transform::Identity, *z);
// we hope for memcpy here
z->assign(first);
}
return Status::OK();
}

View File

@ -26,10 +26,13 @@
namespace nd4j {
namespace ops {
OP_IMPL(stop_gradient, 1, 1, true) {
auto x = INPUT_VARIABLE(0);
auto out = OUTPUT_VARIABLE(0);
// just for lulz
x->applyTransform(transform::Identity, *out);
if (!block.isInplace()) {
auto x = INPUT_VARIABLE(0);
auto out = OUTPUT_VARIABLE(0);
// we hope for memcpy here
out->assign(x);
}
return Status::OK();
}