From f6442b6724397ea30feae8b1f4f59720d307b540 Mon Sep 17 00:00:00 2001 From: raver119 Date: Tue, 25 Feb 2020 11:13:23 +0300 Subject: [PATCH] few minor tweaks (#272) Signed-off-by: raver119 --- libnd4j/include/graph/impl/Context.cpp | 6 ++--- libnd4j/include/graph/impl/VariableSpace.cpp | 26 +++---------------- .../generic/activations/identity.cpp | 10 ++++--- .../generic/parity_ops/stop_gradient.cpp | 11 +++++--- 4 files changed, 20 insertions(+), 33 deletions(-) diff --git a/libnd4j/include/graph/impl/Context.cpp b/libnd4j/include/graph/impl/Context.cpp index 5add8280d..671d89c24 100644 --- a/libnd4j/include/graph/impl/Context.cpp +++ b/libnd4j/include/graph/impl/Context.cpp @@ -261,12 +261,12 @@ namespace nd4j { } Variable* Context::variable(std::pair& p) { - if (!_variableSpace->hasVariable(p)) { + try { + return _variableSpace->getVariable(p); + } catch (std::exception &e) { nd4j_printf("Node %i; Non-existent variable requested: [%i:%i]\n", this->_nodeId, p.first, p.second); throw std::runtime_error("Bad variable"); } - - return _variableSpace->getVariable(p); } void Context::pushNDArrayToVariableSpace(int nodeId, int index, NDArray *array, bool removable) { diff --git a/libnd4j/include/graph/impl/VariableSpace.cpp b/libnd4j/include/graph/impl/VariableSpace.cpp index 735f0260a..287935eda 100644 --- a/libnd4j/include/graph/impl/VariableSpace.cpp +++ b/libnd4j/include/graph/impl/VariableSpace.cpp @@ -132,23 +132,13 @@ namespace nd4j { } nd4j::graph::Variable * nd4j::graph::VariableSpace::getVariable(std::pair& pair) { -// if (pair.first == 0) -// throw "0 requested"; - - //nd4j_debug("Requested variable: [%i:%i]\n", pair.first, pair.second); - if (pair.first < 0) return getVariable(pair.first); - else if (_paired.count(pair) > 0) + else return _paired.at(pair); - else { - if (hasVariable(pair.first) && pair.second == 0) - return getVariable(pair.first); - } nd4j_printf("Unknown variable requested: [%i,%i]\n", pair.first, pair.second); - - return nullptr; + throw std::runtime_error("Unknown variable requested"); } bool nd4j::graph::VariableSpace::hasVariable(int id) { @@ -335,18 +325,10 @@ namespace nd4j { } nd4j::graph::Variable * nd4j::graph::VariableSpace::getVariable(int id) { -// _varmap.lock(); - if (id < 0) { - auto v = _variables.at(id); - // _varmap.unlock(); - - return v; + return _variables.at(id); } else { - auto v = _temporary.at(id); - // _varmap.unlock(); - - return v; + return _temporary.at(id); } } diff --git a/libnd4j/include/ops/declarable/generic/activations/identity.cpp b/libnd4j/include/ops/declarable/generic/activations/identity.cpp index e424772fc..c2b600374 100644 --- a/libnd4j/include/ops/declarable/generic/activations/identity.cpp +++ b/libnd4j/include/ops/declarable/generic/activations/identity.cpp @@ -26,11 +26,13 @@ namespace nd4j { namespace ops { OP_IMPL(identity, 1, 1, true) { - auto first = INPUT_VARIABLE(0); - auto z = OUTPUT_VARIABLE(0); + if (!block.isInplace()) { + auto first = INPUT_VARIABLE(0); + auto z = OUTPUT_VARIABLE(0); - if (!block.isInplace()) - first->applyTransform(nd4j::transform::Identity, *z); + // we hope for memcpy here + z->assign(first); + } return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/stop_gradient.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/stop_gradient.cpp index 81f81c326..746e1f9cf 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/stop_gradient.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/stop_gradient.cpp @@ -26,10 +26,13 @@ namespace nd4j { namespace ops { OP_IMPL(stop_gradient, 1, 1, true) { - auto x = INPUT_VARIABLE(0); - auto out = OUTPUT_VARIABLE(0); - // just for lulz - x->applyTransform(transform::Identity, *out); + if (!block.isInplace()) { + auto x = INPUT_VARIABLE(0); + auto out = OUTPUT_VARIABLE(0); + + // we hope for memcpy here + out->assign(x); + } return Status::OK(); }