parent
330a69d4e2
commit
5332ace32b
|
@ -53,11 +53,11 @@ namespace nd4j {
|
||||||
|
|
||||||
virtual std::vector<Variable*> getVariables();
|
virtual std::vector<Variable*> getVariables();
|
||||||
|
|
||||||
virtual void putVariable(std::pair<int,int>& pair, NDArray *array);
|
virtual Variable* putVariable(std::pair<int,int>& pair, NDArray *array);
|
||||||
virtual void putVariable(std::pair<int,int>& pair, Variable *variable);
|
virtual void putVariable(std::pair<int,int>& pair, Variable *variable);
|
||||||
virtual void putVariable(int id, Variable *variable);
|
virtual void putVariable(int id, Variable *variable);
|
||||||
virtual void putVariable(int id, NDArray *array);
|
virtual void putVariable(int id, NDArray *array);
|
||||||
virtual void putVariable(int id, int idx, NDArray *array);
|
virtual Variable* putVariable(int id, int idx, NDArray *array);
|
||||||
virtual void putVariable(int id, int idx, NDArray &array);
|
virtual void putVariable(int id, int idx, NDArray &array);
|
||||||
virtual void putVariable(int id, int idx, Variable *array);
|
virtual void putVariable(int id, int idx, Variable *array);
|
||||||
|
|
||||||
|
|
|
@ -95,11 +95,11 @@ namespace nd4j {
|
||||||
|
|
||||||
virtual std::vector<Variable*> getVariables();
|
virtual std::vector<Variable*> getVariables();
|
||||||
|
|
||||||
virtual void putVariable(std::pair<int,int>& pair, NDArray *array);
|
virtual Variable* putVariable(std::pair<int,int>& pair, NDArray *array);
|
||||||
virtual void putVariable(std::pair<int,int>& pair, Variable *variable);
|
virtual void putVariable(std::pair<int,int>& pair, Variable *variable);
|
||||||
virtual void putVariable(int id, Variable *variable);
|
virtual void putVariable(int id, Variable *variable);
|
||||||
virtual void putVariable(int id, NDArray *array);
|
virtual void putVariable(int id, NDArray *array);
|
||||||
virtual void putVariable(int id, int idx, NDArray *array);
|
virtual Variable* putVariable(int id, int idx, NDArray *array);
|
||||||
virtual void putVariable(int id, int idx, NDArray &array);
|
virtual void putVariable(int id, int idx, NDArray &array);
|
||||||
virtual void putVariable(int id, int idx, Variable *array);
|
virtual void putVariable(int id, int idx, Variable *array);
|
||||||
|
|
||||||
|
|
|
@ -140,7 +140,7 @@ namespace nd4j {
|
||||||
auto io = _fastpath_out.empty();
|
auto io = _fastpath_out.empty();
|
||||||
// two options here.
|
// two options here.
|
||||||
// either both IN/OUT are filled
|
// either both IN/OUT are filled
|
||||||
auto b1 = (!ie && !io);
|
auto b1 = (!ie && !io) || (!ie && _isInplace);
|
||||||
|
|
||||||
// or at least something is filled, and FastPath is NOT forbidden
|
// or at least something is filled, and FastPath is NOT forbidden
|
||||||
auto b2 = (!ie || !io) && !_forbidFastPath;
|
auto b2 = (!ie || !io) && !_forbidFastPath;
|
||||||
|
|
|
@ -172,8 +172,8 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void VariableProxy::putVariable(std::pair<int,int>& pair, NDArray *array) {
|
Variable* VariableProxy::putVariable(std::pair<int,int>& pair, NDArray *array) {
|
||||||
_current->putVariable(pair, array);
|
return _current->putVariable(pair, array);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -195,8 +195,8 @@ namespace nd4j {
|
||||||
_current->putVariable(id, idx, array);
|
_current->putVariable(id, idx, array);
|
||||||
}
|
}
|
||||||
|
|
||||||
void VariableProxy::putVariable(int id, int idx, NDArray *array) {
|
Variable* VariableProxy::putVariable(int id, int idx, NDArray *array) {
|
||||||
_current->putVariable(id, idx, array);
|
return _current->putVariable(id, idx, array);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -200,14 +200,15 @@ namespace nd4j {
|
||||||
return externalMemory() + internalMemory();
|
return externalMemory() + internalMemory();
|
||||||
}
|
}
|
||||||
|
|
||||||
void nd4j::graph::VariableSpace::putVariable(std::pair<int,int>& pair, NDArray *array) {
|
Variable* nd4j::graph::VariableSpace::putVariable(std::pair<int,int>& pair, NDArray *array) {
|
||||||
auto variable = new Variable(array, nullptr, pair.first, pair.second);
|
auto variable = new Variable(array, nullptr, pair.first, pair.second);
|
||||||
this->putVariable(pair, variable);
|
this->putVariable(pair, variable);
|
||||||
|
return variable;
|
||||||
}
|
}
|
||||||
|
|
||||||
void nd4j::graph::VariableSpace::putVariable(int node, int idx, NDArray *array) {
|
Variable* nd4j::graph::VariableSpace::putVariable(int node, int idx, NDArray *array) {
|
||||||
std::pair<int, int> pair(node, idx);
|
std::pair<int, int> pair(node, idx);
|
||||||
this->putVariable(pair, array);
|
return this->putVariable(pair, array);
|
||||||
}
|
}
|
||||||
|
|
||||||
void nd4j::graph::VariableSpace::putVariable(int node, int idx, Variable *variable) {
|
void nd4j::graph::VariableSpace::putVariable(int node, int idx, Variable *variable) {
|
||||||
|
|
|
@ -68,7 +68,7 @@ namespace nd4j {
|
||||||
private:
|
private:
|
||||||
std::mutex _registrator;
|
std::mutex _registrator;
|
||||||
bool _registered = false;
|
bool _registered = false;
|
||||||
|
std::string _name;
|
||||||
protected:
|
protected:
|
||||||
OpDescriptor *_descriptor;
|
OpDescriptor *_descriptor;
|
||||||
NDArray *_scalar = nullptr;
|
NDArray *_scalar = nullptr;
|
||||||
|
|
|
@ -46,6 +46,7 @@ namespace ops {
|
||||||
axis = *block.getIArguments();
|
axis = *block.getIArguments();
|
||||||
|
|
||||||
if(axis.empty()) { // do not perform reversion
|
if(axis.empty()) { // do not perform reversion
|
||||||
|
if (!block.isInplace())
|
||||||
output->assign(input);
|
output->assign(input);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
|
|
@ -53,22 +53,27 @@ namespace nd4j {
|
||||||
|
|
||||||
DeclarableOp::DeclarableOp(const char *name, bool isLogical) {
|
DeclarableOp::DeclarableOp(const char *name, bool isLogical) {
|
||||||
_descriptor = new OpDescriptor(name, isLogical);
|
_descriptor = new OpDescriptor(name, isLogical);
|
||||||
|
_name = name;
|
||||||
}
|
}
|
||||||
|
|
||||||
DeclarableOp::DeclarableOp(const char *name, int numInputs, bool scalar) {
|
DeclarableOp::DeclarableOp(const char *name, int numInputs, bool scalar) {
|
||||||
_descriptor = new OpDescriptor(numInputs, name, scalar);
|
_descriptor = new OpDescriptor(numInputs, name, scalar);
|
||||||
|
_name = name;
|
||||||
}
|
}
|
||||||
|
|
||||||
DeclarableOp::DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace) {
|
DeclarableOp::DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace) {
|
||||||
_descriptor = new OpDescriptor(numInputs, numOutputs, opName, allowsInplace);
|
_descriptor = new OpDescriptor(numInputs, numOutputs, opName, allowsInplace);
|
||||||
|
_name = opName;
|
||||||
}
|
}
|
||||||
|
|
||||||
DeclarableOp::DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, bool divergent) {
|
DeclarableOp::DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, bool divergent) {
|
||||||
_descriptor = new OpDescriptor(numInputs, numOutputs, opName, allowsInplace, divergent);
|
_descriptor = new OpDescriptor(numInputs, numOutputs, opName, allowsInplace, divergent);
|
||||||
|
_name = opName;
|
||||||
}
|
}
|
||||||
|
|
||||||
DeclarableOp::DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, int tArgs, int iArgs) {
|
DeclarableOp::DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, int tArgs, int iArgs) {
|
||||||
_descriptor = new OpDescriptor(numInputs, numOutputs, opName, allowsInplace, tArgs, iArgs);
|
_descriptor = new OpDescriptor(numInputs, numOutputs, opName, allowsInplace, tArgs, iArgs);
|
||||||
|
_name = opName;
|
||||||
}
|
}
|
||||||
|
|
||||||
DeclarableOp::~DeclarableOp() {
|
DeclarableOp::~DeclarableOp() {
|
||||||
|
@ -141,6 +146,7 @@ namespace nd4j {
|
||||||
GraphProfile *prof = nullptr;
|
GraphProfile *prof = nullptr;
|
||||||
NodeProfile *node = nullptr;
|
NodeProfile *node = nullptr;
|
||||||
std::chrono::time_point<std::chrono::system_clock> inputEnd, inputStart, shapeStart, shapeEnd, arrayStart, arrayEnd;
|
std::chrono::time_point<std::chrono::system_clock> inputEnd, inputStart, shapeStart, shapeEnd, arrayStart, arrayEnd;
|
||||||
|
bool canUseFastPath = true;
|
||||||
|
|
||||||
auto fp = ctx.isFastPath();
|
auto fp = ctx.isFastPath();
|
||||||
|
|
||||||
|
@ -153,7 +159,7 @@ namespace nd4j {
|
||||||
|
|
||||||
if (ctx.isInplace()) {
|
if (ctx.isInplace()) {
|
||||||
if (Environment::getInstance()->isProfiling() && node != nullptr) {
|
if (Environment::getInstance()->isProfiling() && node != nullptr) {
|
||||||
if (ctx.isFastPath()) {
|
if (fp) {
|
||||||
//
|
//
|
||||||
} else {
|
} else {
|
||||||
for (auto p: *ctx.inputs()) {
|
for (auto p: *ctx.inputs()) {
|
||||||
|
@ -168,6 +174,44 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if that's not fp, we can still propagate inputs and outputs
|
||||||
|
if (!fp) {
|
||||||
|
int cnt = 0;
|
||||||
|
auto id = ctx.nodeId();
|
||||||
|
auto vs = ctx.getVariableSpace();
|
||||||
|
for (auto p: *ctx.inputs()) {
|
||||||
|
auto var = ctx.variable(p);
|
||||||
|
if (var->variableType() == VariableType::NDARRAY) {
|
||||||
|
NDArray *array = var->getNDArray();
|
||||||
|
ctx.setInputArray(cnt, array);
|
||||||
|
ctx.setOutputArray(cnt, array);
|
||||||
|
|
||||||
|
|
||||||
|
// in case of this override we might need to update outputs in the Graph VariableSpace as well
|
||||||
|
if (vs != nullptr) {
|
||||||
|
if (vs->hasVariable(id, cnt)) {
|
||||||
|
auto v2 = vs->getVariable(id, cnt);
|
||||||
|
if (!v2->hasNDArray()) {
|
||||||
|
v2->setNDArray(array);
|
||||||
|
v2->markRemovable(false);
|
||||||
|
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
auto v2 = vs->putVariable(id, cnt, array);
|
||||||
|
v2->markRemovable(false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cnt++;
|
||||||
|
} else {
|
||||||
|
canUseFastPath = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!canUseFastPath)
|
||||||
|
ctx.forbidFastPath(true);
|
||||||
|
|
||||||
// do nothing, getZ result will do the trick
|
// do nothing, getZ result will do the trick
|
||||||
return static_cast<int>(ctx.width());
|
return static_cast<int>(ctx.width());
|
||||||
} else {
|
} else {
|
||||||
|
@ -175,8 +219,6 @@ namespace nd4j {
|
||||||
ShapeList inSha;
|
ShapeList inSha;
|
||||||
int results = 0;
|
int results = 0;
|
||||||
|
|
||||||
bool canUseFastPath = true;
|
|
||||||
|
|
||||||
if (Environment::getInstance()->isProfiling() && node != nullptr)
|
if (Environment::getInstance()->isProfiling() && node != nullptr)
|
||||||
inputStart = std::chrono::system_clock::now();
|
inputStart = std::chrono::system_clock::now();
|
||||||
|
|
||||||
|
@ -1007,15 +1049,15 @@ namespace nd4j {
|
||||||
if (status != ND4J_STATUS_OK)
|
if (status != ND4J_STATUS_OK)
|
||||||
return arrayList;
|
return arrayList;
|
||||||
|
|
||||||
|
if (!isInplace) {
|
||||||
for (int e = 0; e < DataTypeUtils::max<int>(); e++) {
|
for (int e = 0; e < DataTypeUtils::max<int>(); e++) {
|
||||||
std::pair<int,int> pair(1, e);
|
std::pair<int, int> pair(1, e);
|
||||||
if (variableSpace.hasVariable(pair)) {
|
if (variableSpace.hasVariable(pair)) {
|
||||||
auto var = variableSpace.getVariable(pair);
|
auto var = variableSpace.getVariable(pair);
|
||||||
auto arr = var->getNDArray();
|
auto arr = var->getNDArray();
|
||||||
if (!arr->isAttached()) {
|
if (!arr->isAttached()) {
|
||||||
var->markRemovable(false);
|
var->markRemovable(false);
|
||||||
arr->setContext(nd4j::LaunchContext ::defaultContext());
|
arr->setContext(nd4j::LaunchContext::defaultContext());
|
||||||
arrayList->push_back(arr);
|
arrayList->push_back(arr);
|
||||||
} else {
|
} else {
|
||||||
arrayList->push_back(arr->detach());
|
arrayList->push_back(arr->detach());
|
||||||
|
@ -1023,6 +1065,11 @@ namespace nd4j {
|
||||||
} else
|
} else
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
for (auto v:inputs) {
|
||||||
|
arrayList->push_back(v);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return arrayList;
|
return arrayList;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue