diff --git a/libnd4j/include/graph/VariableProxy.h b/libnd4j/include/graph/VariableProxy.h index c2a6e9c62..63027f237 100644 --- a/libnd4j/include/graph/VariableProxy.h +++ b/libnd4j/include/graph/VariableProxy.h @@ -53,11 +53,11 @@ namespace nd4j { virtual std::vector getVariables(); - virtual void putVariable(std::pair& pair, NDArray *array); + virtual Variable* putVariable(std::pair& pair, NDArray *array); virtual void putVariable(std::pair& pair, Variable *variable); virtual void putVariable(int id, Variable *variable); virtual void putVariable(int id, NDArray *array); - virtual void putVariable(int id, int idx, NDArray *array); + virtual Variable* putVariable(int id, int idx, NDArray *array); virtual void putVariable(int id, int idx, NDArray &array); virtual void putVariable(int id, int idx, Variable *array); diff --git a/libnd4j/include/graph/VariableSpace.h b/libnd4j/include/graph/VariableSpace.h index 94cdf6bb0..6ae0339ab 100644 --- a/libnd4j/include/graph/VariableSpace.h +++ b/libnd4j/include/graph/VariableSpace.h @@ -95,11 +95,11 @@ namespace nd4j { virtual std::vector getVariables(); - virtual void putVariable(std::pair& pair, NDArray *array); + virtual Variable* putVariable(std::pair& pair, NDArray *array); virtual void putVariable(std::pair& pair, Variable *variable); virtual void putVariable(int id, Variable *variable); virtual void putVariable(int id, NDArray *array); - virtual void putVariable(int id, int idx, NDArray *array); + virtual Variable* putVariable(int id, int idx, NDArray *array); virtual void putVariable(int id, int idx, NDArray &array); virtual void putVariable(int id, int idx, Variable *array); diff --git a/libnd4j/include/graph/impl/Context.cpp b/libnd4j/include/graph/impl/Context.cpp index 26f8875a8..02955a9ca 100644 --- a/libnd4j/include/graph/impl/Context.cpp +++ b/libnd4j/include/graph/impl/Context.cpp @@ -140,7 +140,7 @@ namespace nd4j { auto io = _fastpath_out.empty(); // two options here. // either both IN/OUT are filled - auto b1 = (!ie && !io); + auto b1 = (!ie && !io) || (!ie && _isInplace); // or at least something is filled, and FastPath is NOT forbidden auto b2 = (!ie || !io) && !_forbidFastPath; diff --git a/libnd4j/include/graph/impl/VariableProxy.cpp b/libnd4j/include/graph/impl/VariableProxy.cpp index e8abf1310..5dee4b261 100644 --- a/libnd4j/include/graph/impl/VariableProxy.cpp +++ b/libnd4j/include/graph/impl/VariableProxy.cpp @@ -172,8 +172,8 @@ namespace nd4j { } - void VariableProxy::putVariable(std::pair& pair, NDArray *array) { - _current->putVariable(pair, array); + Variable* VariableProxy::putVariable(std::pair& pair, NDArray *array) { + return _current->putVariable(pair, array); } @@ -195,8 +195,8 @@ namespace nd4j { _current->putVariable(id, idx, array); } - void VariableProxy::putVariable(int id, int idx, NDArray *array) { - _current->putVariable(id, idx, array); + Variable* VariableProxy::putVariable(int id, int idx, NDArray *array) { + return _current->putVariable(id, idx, array); } diff --git a/libnd4j/include/graph/impl/VariableSpace.cpp b/libnd4j/include/graph/impl/VariableSpace.cpp index 287935eda..de6000630 100644 --- a/libnd4j/include/graph/impl/VariableSpace.cpp +++ b/libnd4j/include/graph/impl/VariableSpace.cpp @@ -200,14 +200,15 @@ namespace nd4j { return externalMemory() + internalMemory(); } - void nd4j::graph::VariableSpace::putVariable(std::pair& pair, NDArray *array) { + Variable* nd4j::graph::VariableSpace::putVariable(std::pair& pair, NDArray *array) { auto variable = new Variable(array, nullptr, pair.first, pair.second); this->putVariable(pair, variable); + return variable; } - void nd4j::graph::VariableSpace::putVariable(int node, int idx, NDArray *array) { + Variable* nd4j::graph::VariableSpace::putVariable(int node, int idx, NDArray *array) { std::pair pair(node, idx); - this->putVariable(pair, array); + return this->putVariable(pair, array); } void nd4j::graph::VariableSpace::putVariable(int node, int idx, Variable *variable) { diff --git a/libnd4j/include/ops/declarable/DeclarableOp.h b/libnd4j/include/ops/declarable/DeclarableOp.h index 78f5fcaa4..39a0b7041 100644 --- a/libnd4j/include/ops/declarable/DeclarableOp.h +++ b/libnd4j/include/ops/declarable/DeclarableOp.h @@ -68,7 +68,7 @@ namespace nd4j { private: std::mutex _registrator; bool _registered = false; - + std::string _name; protected: OpDescriptor *_descriptor; NDArray *_scalar = nullptr; diff --git a/libnd4j/include/ops/declarable/generic/transforms/reverse.cpp b/libnd4j/include/ops/declarable/generic/transforms/reverse.cpp index 8047da41a..9f93d37a8 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/reverse.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/reverse.cpp @@ -46,7 +46,8 @@ namespace ops { axis = *block.getIArguments(); if(axis.empty()) { // do not perform reversion - output->assign(input); + if (!block.isInplace()) + output->assign(input); } else { // check the consistency of input dimensions to reverse along diff --git a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp index 77cd5c937..25c25fc2d 100644 --- a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp +++ b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp @@ -53,22 +53,27 @@ namespace nd4j { DeclarableOp::DeclarableOp(const char *name, bool isLogical) { _descriptor = new OpDescriptor(name, isLogical); + _name = name; } DeclarableOp::DeclarableOp(const char *name, int numInputs, bool scalar) { _descriptor = new OpDescriptor(numInputs, name, scalar); + _name = name; } DeclarableOp::DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace) { _descriptor = new OpDescriptor(numInputs, numOutputs, opName, allowsInplace); + _name = opName; } DeclarableOp::DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, bool divergent) { _descriptor = new OpDescriptor(numInputs, numOutputs, opName, allowsInplace, divergent); + _name = opName; } DeclarableOp::DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, int tArgs, int iArgs) { _descriptor = new OpDescriptor(numInputs, numOutputs, opName, allowsInplace, tArgs, iArgs); + _name = opName; } DeclarableOp::~DeclarableOp() { @@ -141,6 +146,7 @@ namespace nd4j { GraphProfile *prof = nullptr; NodeProfile *node = nullptr; std::chrono::time_point inputEnd, inputStart, shapeStart, shapeEnd, arrayStart, arrayEnd; + bool canUseFastPath = true; auto fp = ctx.isFastPath(); @@ -153,7 +159,7 @@ namespace nd4j { if (ctx.isInplace()) { if (Environment::getInstance()->isProfiling() && node != nullptr) { - if (ctx.isFastPath()) { + if (fp) { // } else { for (auto p: *ctx.inputs()) { @@ -168,6 +174,44 @@ namespace nd4j { } } + // if that's not fp, we can still propagate inputs and outputs + if (!fp) { + int cnt = 0; + auto id = ctx.nodeId(); + auto vs = ctx.getVariableSpace(); + for (auto p: *ctx.inputs()) { + auto var = ctx.variable(p); + if (var->variableType() == VariableType::NDARRAY) { + NDArray *array = var->getNDArray(); + ctx.setInputArray(cnt, array); + ctx.setOutputArray(cnt, array); + + + // in case of this override we might need to update outputs in the Graph VariableSpace as well + if (vs != nullptr) { + if (vs->hasVariable(id, cnt)) { + auto v2 = vs->getVariable(id, cnt); + if (!v2->hasNDArray()) { + v2->setNDArray(array); + v2->markRemovable(false); + + } + } else { + auto v2 = vs->putVariable(id, cnt, array); + v2->markRemovable(false); + } + } + + cnt++; + } else { + canUseFastPath = false; + } + } + } + + if (!canUseFastPath) + ctx.forbidFastPath(true); + // do nothing, getZ result will do the trick return static_cast(ctx.width()); } else { @@ -175,8 +219,6 @@ namespace nd4j { ShapeList inSha; int results = 0; - bool canUseFastPath = true; - if (Environment::getInstance()->isProfiling() && node != nullptr) inputStart = std::chrono::system_clock::now(); @@ -1007,21 +1049,26 @@ namespace nd4j { if (status != ND4J_STATUS_OK) return arrayList; - - for (int e = 0; e < DataTypeUtils::max(); e++) { - std::pair pair(1, e); - if (variableSpace.hasVariable(pair)) { - auto var = variableSpace.getVariable(pair); - auto arr = var->getNDArray(); - if (!arr->isAttached()) { - var->markRemovable(false); - arr->setContext(nd4j::LaunchContext ::defaultContext()); - arrayList->push_back(arr); - } else { - arrayList->push_back(arr->detach()); - } - } else - break; + if (!isInplace) { + for (int e = 0; e < DataTypeUtils::max(); e++) { + std::pair pair(1, e); + if (variableSpace.hasVariable(pair)) { + auto var = variableSpace.getVariable(pair); + auto arr = var->getNDArray(); + if (!arr->isAttached()) { + var->markRemovable(false); + arr->setContext(nd4j::LaunchContext::defaultContext()); + arrayList->push_back(arr); + } else { + arrayList->push_back(arr->detach()); + } + } else + break; + } + } else { + for (auto v:inputs) { + arrayList->push_back(v); + } } return arrayList;