parent
241ed05c64
commit
f6442b6724
|
@ -261,12 +261,12 @@ namespace nd4j {
|
|||
}
|
||||
|
||||
Variable* Context::variable(std::pair<int,int>& p) {
|
||||
if (!_variableSpace->hasVariable(p)) {
|
||||
try {
|
||||
return _variableSpace->getVariable(p);
|
||||
} catch (std::exception &e) {
|
||||
nd4j_printf("Node %i; Non-existent variable requested: [%i:%i]\n", this->_nodeId, p.first, p.second);
|
||||
throw std::runtime_error("Bad variable");
|
||||
}
|
||||
|
||||
return _variableSpace->getVariable(p);
|
||||
}
|
||||
|
||||
void Context::pushNDArrayToVariableSpace(int nodeId, int index, NDArray *array, bool removable) {
|
||||
|
|
|
@ -132,23 +132,13 @@ namespace nd4j {
|
|||
}
|
||||
|
||||
nd4j::graph::Variable * nd4j::graph::VariableSpace::getVariable(std::pair<int, int>& pair) {
|
||||
// if (pair.first == 0)
|
||||
// throw "0 requested";
|
||||
|
||||
//nd4j_debug("Requested variable: [%i:%i]\n", pair.first, pair.second);
|
||||
|
||||
if (pair.first < 0)
|
||||
return getVariable(pair.first);
|
||||
else if (_paired.count(pair) > 0)
|
||||
else
|
||||
return _paired.at(pair);
|
||||
else {
|
||||
if (hasVariable(pair.first) && pair.second == 0)
|
||||
return getVariable(pair.first);
|
||||
}
|
||||
|
||||
nd4j_printf("Unknown variable requested: [%i,%i]\n", pair.first, pair.second);
|
||||
|
||||
return nullptr;
|
||||
throw std::runtime_error("Unknown variable requested");
|
||||
}
|
||||
|
||||
bool nd4j::graph::VariableSpace::hasVariable(int id) {
|
||||
|
@ -335,18 +325,10 @@ namespace nd4j {
|
|||
}
|
||||
|
||||
nd4j::graph::Variable * nd4j::graph::VariableSpace::getVariable(int id) {
|
||||
// _varmap.lock();
|
||||
|
||||
if (id < 0) {
|
||||
auto v = _variables.at(id);
|
||||
// _varmap.unlock();
|
||||
|
||||
return v;
|
||||
return _variables.at(id);
|
||||
} else {
|
||||
auto v = _temporary.at(id);
|
||||
// _varmap.unlock();
|
||||
|
||||
return v;
|
||||
return _temporary.at(id);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -26,11 +26,13 @@
|
|||
namespace nd4j {
|
||||
namespace ops {
|
||||
OP_IMPL(identity, 1, 1, true) {
|
||||
auto first = INPUT_VARIABLE(0);
|
||||
auto z = OUTPUT_VARIABLE(0);
|
||||
if (!block.isInplace()) {
|
||||
auto first = INPUT_VARIABLE(0);
|
||||
auto z = OUTPUT_VARIABLE(0);
|
||||
|
||||
if (!block.isInplace())
|
||||
first->applyTransform(nd4j::transform::Identity, *z);
|
||||
// we hope for memcpy here
|
||||
z->assign(first);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -26,10 +26,13 @@
|
|||
namespace nd4j {
|
||||
namespace ops {
|
||||
OP_IMPL(stop_gradient, 1, 1, true) {
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
auto out = OUTPUT_VARIABLE(0);
|
||||
// just for lulz
|
||||
x->applyTransform(transform::Identity, *out);
|
||||
if (!block.isInplace()) {
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
auto out = OUTPUT_VARIABLE(0);
|
||||
|
||||
// we hope for memcpy here
|
||||
out->assign(x);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue