better inplace exec with FastPath (#280)

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2020-02-28 12:06:30 +03:00 committed by GitHub
parent 330a69d4e2
commit 5332ace32b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 81 additions and 32 deletions

View File

@ -53,11 +53,11 @@ namespace nd4j {
virtual std::vector<Variable*> getVariables(); virtual std::vector<Variable*> getVariables();
virtual void putVariable(std::pair<int,int>& pair, NDArray *array); virtual Variable* putVariable(std::pair<int,int>& pair, NDArray *array);
virtual void putVariable(std::pair<int,int>& pair, Variable *variable); virtual void putVariable(std::pair<int,int>& pair, Variable *variable);
virtual void putVariable(int id, Variable *variable); virtual void putVariable(int id, Variable *variable);
virtual void putVariable(int id, NDArray *array); virtual void putVariable(int id, NDArray *array);
virtual void putVariable(int id, int idx, NDArray *array); virtual Variable* putVariable(int id, int idx, NDArray *array);
virtual void putVariable(int id, int idx, NDArray &array); virtual void putVariable(int id, int idx, NDArray &array);
virtual void putVariable(int id, int idx, Variable *array); virtual void putVariable(int id, int idx, Variable *array);

View File

@ -95,11 +95,11 @@ namespace nd4j {
virtual std::vector<Variable*> getVariables(); virtual std::vector<Variable*> getVariables();
virtual void putVariable(std::pair<int,int>& pair, NDArray *array); virtual Variable* putVariable(std::pair<int,int>& pair, NDArray *array);
virtual void putVariable(std::pair<int,int>& pair, Variable *variable); virtual void putVariable(std::pair<int,int>& pair, Variable *variable);
virtual void putVariable(int id, Variable *variable); virtual void putVariable(int id, Variable *variable);
virtual void putVariable(int id, NDArray *array); virtual void putVariable(int id, NDArray *array);
virtual void putVariable(int id, int idx, NDArray *array); virtual Variable* putVariable(int id, int idx, NDArray *array);
virtual void putVariable(int id, int idx, NDArray &array); virtual void putVariable(int id, int idx, NDArray &array);
virtual void putVariable(int id, int idx, Variable *array); virtual void putVariable(int id, int idx, Variable *array);

View File

@ -140,7 +140,7 @@ namespace nd4j {
auto io = _fastpath_out.empty(); auto io = _fastpath_out.empty();
// two options here. // two options here.
// either both IN/OUT are filled // either both IN/OUT are filled
auto b1 = (!ie && !io); auto b1 = (!ie && !io) || (!ie && _isInplace);
// or at least something is filled, and FastPath is NOT forbidden // or at least something is filled, and FastPath is NOT forbidden
auto b2 = (!ie || !io) && !_forbidFastPath; auto b2 = (!ie || !io) && !_forbidFastPath;

View File

@ -172,8 +172,8 @@ namespace nd4j {
} }
void VariableProxy::putVariable(std::pair<int,int>& pair, NDArray *array) { Variable* VariableProxy::putVariable(std::pair<int,int>& pair, NDArray *array) {
_current->putVariable(pair, array); return _current->putVariable(pair, array);
} }
@ -195,8 +195,8 @@ namespace nd4j {
_current->putVariable(id, idx, array); _current->putVariable(id, idx, array);
} }
void VariableProxy::putVariable(int id, int idx, NDArray *array) { Variable* VariableProxy::putVariable(int id, int idx, NDArray *array) {
_current->putVariable(id, idx, array); return _current->putVariable(id, idx, array);
} }

View File

@ -200,14 +200,15 @@ namespace nd4j {
return externalMemory() + internalMemory(); return externalMemory() + internalMemory();
} }
void nd4j::graph::VariableSpace::putVariable(std::pair<int,int>& pair, NDArray *array) { Variable* nd4j::graph::VariableSpace::putVariable(std::pair<int,int>& pair, NDArray *array) {
auto variable = new Variable(array, nullptr, pair.first, pair.second); auto variable = new Variable(array, nullptr, pair.first, pair.second);
this->putVariable(pair, variable); this->putVariable(pair, variable);
return variable;
} }
void nd4j::graph::VariableSpace::putVariable(int node, int idx, NDArray *array) { Variable* nd4j::graph::VariableSpace::putVariable(int node, int idx, NDArray *array) {
std::pair<int, int> pair(node, idx); std::pair<int, int> pair(node, idx);
this->putVariable(pair, array); return this->putVariable(pair, array);
} }
void nd4j::graph::VariableSpace::putVariable(int node, int idx, Variable *variable) { void nd4j::graph::VariableSpace::putVariable(int node, int idx, Variable *variable) {

View File

@ -68,7 +68,7 @@ namespace nd4j {
private: private:
std::mutex _registrator; std::mutex _registrator;
bool _registered = false; bool _registered = false;
std::string _name;
protected: protected:
OpDescriptor *_descriptor; OpDescriptor *_descriptor;
NDArray *_scalar = nullptr; NDArray *_scalar = nullptr;

View File

@ -46,7 +46,8 @@ namespace ops {
axis = *block.getIArguments(); axis = *block.getIArguments();
if(axis.empty()) { // do not perform reversion if(axis.empty()) { // do not perform reversion
output->assign(input); if (!block.isInplace())
output->assign(input);
} }
else { else {
// check the consistency of input dimensions to reverse along // check the consistency of input dimensions to reverse along

View File

@ -53,22 +53,27 @@ namespace nd4j {
DeclarableOp::DeclarableOp(const char *name, bool isLogical) { DeclarableOp::DeclarableOp(const char *name, bool isLogical) {
_descriptor = new OpDescriptor(name, isLogical); _descriptor = new OpDescriptor(name, isLogical);
_name = name;
} }
DeclarableOp::DeclarableOp(const char *name, int numInputs, bool scalar) { DeclarableOp::DeclarableOp(const char *name, int numInputs, bool scalar) {
_descriptor = new OpDescriptor(numInputs, name, scalar); _descriptor = new OpDescriptor(numInputs, name, scalar);
_name = name;
} }
DeclarableOp::DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace) { DeclarableOp::DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace) {
_descriptor = new OpDescriptor(numInputs, numOutputs, opName, allowsInplace); _descriptor = new OpDescriptor(numInputs, numOutputs, opName, allowsInplace);
_name = opName;
} }
DeclarableOp::DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, bool divergent) { DeclarableOp::DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, bool divergent) {
_descriptor = new OpDescriptor(numInputs, numOutputs, opName, allowsInplace, divergent); _descriptor = new OpDescriptor(numInputs, numOutputs, opName, allowsInplace, divergent);
_name = opName;
} }
DeclarableOp::DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, int tArgs, int iArgs) { DeclarableOp::DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, int tArgs, int iArgs) {
_descriptor = new OpDescriptor(numInputs, numOutputs, opName, allowsInplace, tArgs, iArgs); _descriptor = new OpDescriptor(numInputs, numOutputs, opName, allowsInplace, tArgs, iArgs);
_name = opName;
} }
DeclarableOp::~DeclarableOp() { DeclarableOp::~DeclarableOp() {
@ -141,6 +146,7 @@ namespace nd4j {
GraphProfile *prof = nullptr; GraphProfile *prof = nullptr;
NodeProfile *node = nullptr; NodeProfile *node = nullptr;
std::chrono::time_point<std::chrono::system_clock> inputEnd, inputStart, shapeStart, shapeEnd, arrayStart, arrayEnd; std::chrono::time_point<std::chrono::system_clock> inputEnd, inputStart, shapeStart, shapeEnd, arrayStart, arrayEnd;
bool canUseFastPath = true;
auto fp = ctx.isFastPath(); auto fp = ctx.isFastPath();
@ -153,7 +159,7 @@ namespace nd4j {
if (ctx.isInplace()) { if (ctx.isInplace()) {
if (Environment::getInstance()->isProfiling() && node != nullptr) { if (Environment::getInstance()->isProfiling() && node != nullptr) {
if (ctx.isFastPath()) { if (fp) {
// //
} else { } else {
for (auto p: *ctx.inputs()) { for (auto p: *ctx.inputs()) {
@ -168,6 +174,44 @@ namespace nd4j {
} }
} }
// if that's not fp, we can still propagate inputs and outputs
if (!fp) {
int cnt = 0;
auto id = ctx.nodeId();
auto vs = ctx.getVariableSpace();
for (auto p: *ctx.inputs()) {
auto var = ctx.variable(p);
if (var->variableType() == VariableType::NDARRAY) {
NDArray *array = var->getNDArray();
ctx.setInputArray(cnt, array);
ctx.setOutputArray(cnt, array);
// in case of this override we might need to update outputs in the Graph VariableSpace as well
if (vs != nullptr) {
if (vs->hasVariable(id, cnt)) {
auto v2 = vs->getVariable(id, cnt);
if (!v2->hasNDArray()) {
v2->setNDArray(array);
v2->markRemovable(false);
}
} else {
auto v2 = vs->putVariable(id, cnt, array);
v2->markRemovable(false);
}
}
cnt++;
} else {
canUseFastPath = false;
}
}
}
if (!canUseFastPath)
ctx.forbidFastPath(true);
// do nothing, getZ result will do the trick // do nothing, getZ result will do the trick
return static_cast<int>(ctx.width()); return static_cast<int>(ctx.width());
} else { } else {
@ -175,8 +219,6 @@ namespace nd4j {
ShapeList inSha; ShapeList inSha;
int results = 0; int results = 0;
bool canUseFastPath = true;
if (Environment::getInstance()->isProfiling() && node != nullptr) if (Environment::getInstance()->isProfiling() && node != nullptr)
inputStart = std::chrono::system_clock::now(); inputStart = std::chrono::system_clock::now();
@ -1007,21 +1049,26 @@ namespace nd4j {
if (status != ND4J_STATUS_OK) if (status != ND4J_STATUS_OK)
return arrayList; return arrayList;
if (!isInplace) {
for (int e = 0; e < DataTypeUtils::max<int>(); e++) { for (int e = 0; e < DataTypeUtils::max<int>(); e++) {
std::pair<int,int> pair(1, e); std::pair<int, int> pair(1, e);
if (variableSpace.hasVariable(pair)) { if (variableSpace.hasVariable(pair)) {
auto var = variableSpace.getVariable(pair); auto var = variableSpace.getVariable(pair);
auto arr = var->getNDArray(); auto arr = var->getNDArray();
if (!arr->isAttached()) { if (!arr->isAttached()) {
var->markRemovable(false); var->markRemovable(false);
arr->setContext(nd4j::LaunchContext ::defaultContext()); arr->setContext(nd4j::LaunchContext::defaultContext());
arrayList->push_back(arr); arrayList->push_back(arr);
} else { } else {
arrayList->push_back(arr->detach()); arrayList->push_back(arr->detach());
} }
} else } else
break; break;
}
} else {
for (auto v:inputs) {
arrayList->push_back(v);
}
} }
return arrayList; return arrayList;