better inplace exec with FastPath (#280)

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2020-02-28 12:06:30 +03:00 committed by GitHub
parent 330a69d4e2
commit 5332ace32b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 81 additions and 32 deletions

View File

@ -53,11 +53,11 @@ namespace nd4j {
virtual std::vector<Variable*> getVariables();
virtual void putVariable(std::pair<int,int>& pair, NDArray *array);
virtual Variable* putVariable(std::pair<int,int>& pair, NDArray *array);
virtual void putVariable(std::pair<int,int>& pair, Variable *variable);
virtual void putVariable(int id, Variable *variable);
virtual void putVariable(int id, NDArray *array);
virtual void putVariable(int id, int idx, NDArray *array);
virtual Variable* putVariable(int id, int idx, NDArray *array);
virtual void putVariable(int id, int idx, NDArray &array);
virtual void putVariable(int id, int idx, Variable *array);

View File

@ -95,11 +95,11 @@ namespace nd4j {
virtual std::vector<Variable*> getVariables();
virtual void putVariable(std::pair<int,int>& pair, NDArray *array);
virtual Variable* putVariable(std::pair<int,int>& pair, NDArray *array);
virtual void putVariable(std::pair<int,int>& pair, Variable *variable);
virtual void putVariable(int id, Variable *variable);
virtual void putVariable(int id, NDArray *array);
virtual void putVariable(int id, int idx, NDArray *array);
virtual Variable* putVariable(int id, int idx, NDArray *array);
virtual void putVariable(int id, int idx, NDArray &array);
virtual void putVariable(int id, int idx, Variable *array);

View File

@ -140,7 +140,7 @@ namespace nd4j {
auto io = _fastpath_out.empty();
// two options here.
// either both IN/OUT are filled
auto b1 = (!ie && !io);
auto b1 = (!ie && !io) || (!ie && _isInplace);
// or at least something is filled, and FastPath is NOT forbidden
auto b2 = (!ie || !io) && !_forbidFastPath;

View File

@ -172,8 +172,8 @@ namespace nd4j {
}
void VariableProxy::putVariable(std::pair<int,int>& pair, NDArray *array) {
_current->putVariable(pair, array);
Variable* VariableProxy::putVariable(std::pair<int,int>& pair, NDArray *array) {
return _current->putVariable(pair, array);
}
@ -195,8 +195,8 @@ namespace nd4j {
_current->putVariable(id, idx, array);
}
void VariableProxy::putVariable(int id, int idx, NDArray *array) {
_current->putVariable(id, idx, array);
Variable* VariableProxy::putVariable(int id, int idx, NDArray *array) {
return _current->putVariable(id, idx, array);
}

View File

@ -200,14 +200,15 @@ namespace nd4j {
return externalMemory() + internalMemory();
}
void nd4j::graph::VariableSpace::putVariable(std::pair<int,int>& pair, NDArray *array) {
Variable* nd4j::graph::VariableSpace::putVariable(std::pair<int,int>& pair, NDArray *array) {
auto variable = new Variable(array, nullptr, pair.first, pair.second);
this->putVariable(pair, variable);
return variable;
}
void nd4j::graph::VariableSpace::putVariable(int node, int idx, NDArray *array) {
Variable* nd4j::graph::VariableSpace::putVariable(int node, int idx, NDArray *array) {
std::pair<int, int> pair(node, idx);
this->putVariable(pair, array);
return this->putVariable(pair, array);
}
void nd4j::graph::VariableSpace::putVariable(int node, int idx, Variable *variable) {

View File

@ -68,7 +68,7 @@ namespace nd4j {
private:
std::mutex _registrator;
bool _registered = false;
std::string _name;
protected:
OpDescriptor *_descriptor;
NDArray *_scalar = nullptr;

View File

@ -46,6 +46,7 @@ namespace ops {
axis = *block.getIArguments();
if(axis.empty()) { // do not perform reversion
if (!block.isInplace())
output->assign(input);
}
else {

View File

@ -53,22 +53,27 @@ namespace nd4j {
DeclarableOp::DeclarableOp(const char *name, bool isLogical) {
_descriptor = new OpDescriptor(name, isLogical);
_name = name;
}
DeclarableOp::DeclarableOp(const char *name, int numInputs, bool scalar) {
_descriptor = new OpDescriptor(numInputs, name, scalar);
_name = name;
}
DeclarableOp::DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace) {
_descriptor = new OpDescriptor(numInputs, numOutputs, opName, allowsInplace);
_name = opName;
}
DeclarableOp::DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, bool divergent) {
_descriptor = new OpDescriptor(numInputs, numOutputs, opName, allowsInplace, divergent);
_name = opName;
}
DeclarableOp::DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, int tArgs, int iArgs) {
_descriptor = new OpDescriptor(numInputs, numOutputs, opName, allowsInplace, tArgs, iArgs);
_name = opName;
}
DeclarableOp::~DeclarableOp() {
@ -141,6 +146,7 @@ namespace nd4j {
GraphProfile *prof = nullptr;
NodeProfile *node = nullptr;
std::chrono::time_point<std::chrono::system_clock> inputEnd, inputStart, shapeStart, shapeEnd, arrayStart, arrayEnd;
bool canUseFastPath = true;
auto fp = ctx.isFastPath();
@ -153,7 +159,7 @@ namespace nd4j {
if (ctx.isInplace()) {
if (Environment::getInstance()->isProfiling() && node != nullptr) {
if (ctx.isFastPath()) {
if (fp) {
//
} else {
for (auto p: *ctx.inputs()) {
@ -168,6 +174,44 @@ namespace nd4j {
}
}
// if that's not fp, we can still propagate inputs and outputs
if (!fp) {
int cnt = 0;
auto id = ctx.nodeId();
auto vs = ctx.getVariableSpace();
for (auto p: *ctx.inputs()) {
auto var = ctx.variable(p);
if (var->variableType() == VariableType::NDARRAY) {
NDArray *array = var->getNDArray();
ctx.setInputArray(cnt, array);
ctx.setOutputArray(cnt, array);
// in case of this override we might need to update outputs in the Graph VariableSpace as well
if (vs != nullptr) {
if (vs->hasVariable(id, cnt)) {
auto v2 = vs->getVariable(id, cnt);
if (!v2->hasNDArray()) {
v2->setNDArray(array);
v2->markRemovable(false);
}
} else {
auto v2 = vs->putVariable(id, cnt, array);
v2->markRemovable(false);
}
}
cnt++;
} else {
canUseFastPath = false;
}
}
}
if (!canUseFastPath)
ctx.forbidFastPath(true);
// do nothing, getZ result will do the trick
return static_cast<int>(ctx.width());
} else {
@ -175,8 +219,6 @@ namespace nd4j {
ShapeList inSha;
int results = 0;
bool canUseFastPath = true;
if (Environment::getInstance()->isProfiling() && node != nullptr)
inputStart = std::chrono::system_clock::now();
@ -1007,7 +1049,7 @@ namespace nd4j {
if (status != ND4J_STATUS_OK)
return arrayList;
if (!isInplace) {
for (int e = 0; e < DataTypeUtils::max<int>(); e++) {
std::pair<int, int> pair(1, e);
if (variableSpace.hasVariable(pair)) {
@ -1023,6 +1065,11 @@ namespace nd4j {
} else
break;
}
} else {
for (auto v:inputs) {
arrayList->push_back(v);
}
}
return arrayList;
}