parent
330a69d4e2
commit
5332ace32b
|
@ -53,11 +53,11 @@ namespace nd4j {
|
|||
|
||||
virtual std::vector<Variable*> getVariables();
|
||||
|
||||
virtual void putVariable(std::pair<int,int>& pair, NDArray *array);
|
||||
virtual Variable* putVariable(std::pair<int,int>& pair, NDArray *array);
|
||||
virtual void putVariable(std::pair<int,int>& pair, Variable *variable);
|
||||
virtual void putVariable(int id, Variable *variable);
|
||||
virtual void putVariable(int id, NDArray *array);
|
||||
virtual void putVariable(int id, int idx, NDArray *array);
|
||||
virtual Variable* putVariable(int id, int idx, NDArray *array);
|
||||
virtual void putVariable(int id, int idx, NDArray &array);
|
||||
virtual void putVariable(int id, int idx, Variable *array);
|
||||
|
||||
|
|
|
@ -95,11 +95,11 @@ namespace nd4j {
|
|||
|
||||
virtual std::vector<Variable*> getVariables();
|
||||
|
||||
virtual void putVariable(std::pair<int,int>& pair, NDArray *array);
|
||||
virtual Variable* putVariable(std::pair<int,int>& pair, NDArray *array);
|
||||
virtual void putVariable(std::pair<int,int>& pair, Variable *variable);
|
||||
virtual void putVariable(int id, Variable *variable);
|
||||
virtual void putVariable(int id, NDArray *array);
|
||||
virtual void putVariable(int id, int idx, NDArray *array);
|
||||
virtual Variable* putVariable(int id, int idx, NDArray *array);
|
||||
virtual void putVariable(int id, int idx, NDArray &array);
|
||||
virtual void putVariable(int id, int idx, Variable *array);
|
||||
|
||||
|
|
|
@ -140,7 +140,7 @@ namespace nd4j {
|
|||
auto io = _fastpath_out.empty();
|
||||
// two options here.
|
||||
// either both IN/OUT are filled
|
||||
auto b1 = (!ie && !io);
|
||||
auto b1 = (!ie && !io) || (!ie && _isInplace);
|
||||
|
||||
// or at least something is filled, and FastPath is NOT forbidden
|
||||
auto b2 = (!ie || !io) && !_forbidFastPath;
|
||||
|
|
|
@ -172,8 +172,8 @@ namespace nd4j {
|
|||
}
|
||||
|
||||
|
||||
void VariableProxy::putVariable(std::pair<int,int>& pair, NDArray *array) {
|
||||
_current->putVariable(pair, array);
|
||||
Variable* VariableProxy::putVariable(std::pair<int,int>& pair, NDArray *array) {
|
||||
return _current->putVariable(pair, array);
|
||||
}
|
||||
|
||||
|
||||
|
@ -195,8 +195,8 @@ namespace nd4j {
|
|||
_current->putVariable(id, idx, array);
|
||||
}
|
||||
|
||||
void VariableProxy::putVariable(int id, int idx, NDArray *array) {
|
||||
_current->putVariable(id, idx, array);
|
||||
Variable* VariableProxy::putVariable(int id, int idx, NDArray *array) {
|
||||
return _current->putVariable(id, idx, array);
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -200,14 +200,15 @@ namespace nd4j {
|
|||
return externalMemory() + internalMemory();
|
||||
}
|
||||
|
||||
void nd4j::graph::VariableSpace::putVariable(std::pair<int,int>& pair, NDArray *array) {
|
||||
Variable* nd4j::graph::VariableSpace::putVariable(std::pair<int,int>& pair, NDArray *array) {
|
||||
auto variable = new Variable(array, nullptr, pair.first, pair.second);
|
||||
this->putVariable(pair, variable);
|
||||
return variable;
|
||||
}
|
||||
|
||||
void nd4j::graph::VariableSpace::putVariable(int node, int idx, NDArray *array) {
|
||||
Variable* nd4j::graph::VariableSpace::putVariable(int node, int idx, NDArray *array) {
|
||||
std::pair<int, int> pair(node, idx);
|
||||
this->putVariable(pair, array);
|
||||
return this->putVariable(pair, array);
|
||||
}
|
||||
|
||||
void nd4j::graph::VariableSpace::putVariable(int node, int idx, Variable *variable) {
|
||||
|
|
|
@ -68,7 +68,7 @@ namespace nd4j {
|
|||
private:
|
||||
std::mutex _registrator;
|
||||
bool _registered = false;
|
||||
|
||||
std::string _name;
|
||||
protected:
|
||||
OpDescriptor *_descriptor;
|
||||
NDArray *_scalar = nullptr;
|
||||
|
|
|
@ -46,6 +46,7 @@ namespace ops {
|
|||
axis = *block.getIArguments();
|
||||
|
||||
if(axis.empty()) { // do not perform reversion
|
||||
if (!block.isInplace())
|
||||
output->assign(input);
|
||||
}
|
||||
else {
|
||||
|
|
|
@ -53,22 +53,27 @@ namespace nd4j {
|
|||
|
||||
DeclarableOp::DeclarableOp(const char *name, bool isLogical) {
|
||||
_descriptor = new OpDescriptor(name, isLogical);
|
||||
_name = name;
|
||||
}
|
||||
|
||||
DeclarableOp::DeclarableOp(const char *name, int numInputs, bool scalar) {
|
||||
_descriptor = new OpDescriptor(numInputs, name, scalar);
|
||||
_name = name;
|
||||
}
|
||||
|
||||
DeclarableOp::DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace) {
|
||||
_descriptor = new OpDescriptor(numInputs, numOutputs, opName, allowsInplace);
|
||||
_name = opName;
|
||||
}
|
||||
|
||||
DeclarableOp::DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, bool divergent) {
|
||||
_descriptor = new OpDescriptor(numInputs, numOutputs, opName, allowsInplace, divergent);
|
||||
_name = opName;
|
||||
}
|
||||
|
||||
DeclarableOp::DeclarableOp(int numInputs, int numOutputs, const char *opName, bool allowsInplace, int tArgs, int iArgs) {
|
||||
_descriptor = new OpDescriptor(numInputs, numOutputs, opName, allowsInplace, tArgs, iArgs);
|
||||
_name = opName;
|
||||
}
|
||||
|
||||
DeclarableOp::~DeclarableOp() {
|
||||
|
@ -141,6 +146,7 @@ namespace nd4j {
|
|||
GraphProfile *prof = nullptr;
|
||||
NodeProfile *node = nullptr;
|
||||
std::chrono::time_point<std::chrono::system_clock> inputEnd, inputStart, shapeStart, shapeEnd, arrayStart, arrayEnd;
|
||||
bool canUseFastPath = true;
|
||||
|
||||
auto fp = ctx.isFastPath();
|
||||
|
||||
|
@ -153,7 +159,7 @@ namespace nd4j {
|
|||
|
||||
if (ctx.isInplace()) {
|
||||
if (Environment::getInstance()->isProfiling() && node != nullptr) {
|
||||
if (ctx.isFastPath()) {
|
||||
if (fp) {
|
||||
//
|
||||
} else {
|
||||
for (auto p: *ctx.inputs()) {
|
||||
|
@ -168,6 +174,44 @@ namespace nd4j {
|
|||
}
|
||||
}
|
||||
|
||||
// if that's not fp, we can still propagate inputs and outputs
|
||||
if (!fp) {
|
||||
int cnt = 0;
|
||||
auto id = ctx.nodeId();
|
||||
auto vs = ctx.getVariableSpace();
|
||||
for (auto p: *ctx.inputs()) {
|
||||
auto var = ctx.variable(p);
|
||||
if (var->variableType() == VariableType::NDARRAY) {
|
||||
NDArray *array = var->getNDArray();
|
||||
ctx.setInputArray(cnt, array);
|
||||
ctx.setOutputArray(cnt, array);
|
||||
|
||||
|
||||
// in case of this override we might need to update outputs in the Graph VariableSpace as well
|
||||
if (vs != nullptr) {
|
||||
if (vs->hasVariable(id, cnt)) {
|
||||
auto v2 = vs->getVariable(id, cnt);
|
||||
if (!v2->hasNDArray()) {
|
||||
v2->setNDArray(array);
|
||||
v2->markRemovable(false);
|
||||
|
||||
}
|
||||
} else {
|
||||
auto v2 = vs->putVariable(id, cnt, array);
|
||||
v2->markRemovable(false);
|
||||
}
|
||||
}
|
||||
|
||||
cnt++;
|
||||
} else {
|
||||
canUseFastPath = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!canUseFastPath)
|
||||
ctx.forbidFastPath(true);
|
||||
|
||||
// do nothing, getZ result will do the trick
|
||||
return static_cast<int>(ctx.width());
|
||||
} else {
|
||||
|
@ -175,8 +219,6 @@ namespace nd4j {
|
|||
ShapeList inSha;
|
||||
int results = 0;
|
||||
|
||||
bool canUseFastPath = true;
|
||||
|
||||
if (Environment::getInstance()->isProfiling() && node != nullptr)
|
||||
inputStart = std::chrono::system_clock::now();
|
||||
|
||||
|
@ -1007,15 +1049,15 @@ namespace nd4j {
|
|||
if (status != ND4J_STATUS_OK)
|
||||
return arrayList;
|
||||
|
||||
|
||||
if (!isInplace) {
|
||||
for (int e = 0; e < DataTypeUtils::max<int>(); e++) {
|
||||
std::pair<int,int> pair(1, e);
|
||||
std::pair<int, int> pair(1, e);
|
||||
if (variableSpace.hasVariable(pair)) {
|
||||
auto var = variableSpace.getVariable(pair);
|
||||
auto arr = var->getNDArray();
|
||||
if (!arr->isAttached()) {
|
||||
var->markRemovable(false);
|
||||
arr->setContext(nd4j::LaunchContext ::defaultContext());
|
||||
arr->setContext(nd4j::LaunchContext::defaultContext());
|
||||
arrayList->push_back(arr);
|
||||
} else {
|
||||
arrayList->push_back(arr->detach());
|
||||
|
@ -1023,6 +1065,11 @@ namespace nd4j {
|
|||
} else
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
for (auto v:inputs) {
|
||||
arrayList->push_back(v);
|
||||
}
|
||||
}
|
||||
|
||||
return arrayList;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue