transparent conversion to FastPath execution within Graph (#278)
Signed-off-by: raver119 <raver119@gmail.com>master
parent
353f901c7c
commit
31e3a2f7a5
|
@ -69,6 +69,9 @@ namespace nd4j {
|
||||||
|
|
||||||
// in some cases we might be able to skip shape function for validation purposes
|
// in some cases we might be able to skip shape function for validation purposes
|
||||||
bool _shapeFunctionOverride = false;
|
bool _shapeFunctionOverride = false;
|
||||||
|
|
||||||
|
// special flag used during conversion from Graph exec to FastPath exec
|
||||||
|
bool _forbidFastPath = false;
|
||||||
public:
|
public:
|
||||||
Context(ContextPrototype* prototype, VariableSpace* variableSpace);
|
Context(ContextPrototype* prototype, VariableSpace* variableSpace);
|
||||||
|
|
||||||
|
@ -176,11 +179,17 @@ namespace nd4j {
|
||||||
|
|
||||||
// methods used in java interop
|
// methods used in java interop
|
||||||
/**
|
/**
|
||||||
* This method checks, if Context uses fastpath variable access
|
* This method checks if Context uses fastpath variable access
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
bool isFastPath();
|
bool isFastPath();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Method allows to forbid FastPath execution
|
||||||
|
* @param reallyForbid
|
||||||
|
*/
|
||||||
|
void forbidFastPath(bool reallyForbid);
|
||||||
|
|
||||||
#ifndef __JAVACPP_HACK__
|
#ifndef __JAVACPP_HACK__
|
||||||
std::vector<NDArray*>& fastpath_in();
|
std::vector<NDArray*>& fastpath_in();
|
||||||
std::vector<NDArray*>& fastpath_out();
|
std::vector<NDArray*>& fastpath_out();
|
||||||
|
|
|
@ -136,7 +136,19 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Context::isFastPath() {
|
bool Context::isFastPath() {
|
||||||
return !(_fastpath_in.empty() && _fastpath_out.empty());
|
auto ie = _fastpath_in.empty();
|
||||||
|
auto io = _fastpath_out.empty();
|
||||||
|
// two options here.
|
||||||
|
// either both IN/OUT are filled
|
||||||
|
auto b1 = (!ie && !io);
|
||||||
|
|
||||||
|
// or at least something is filled, and FastPath is NOT forbidden
|
||||||
|
auto b2 = (!ie || !io) && !_forbidFastPath;
|
||||||
|
return b1 || b2;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Context::forbidFastPath(bool reallyForbid) {
|
||||||
|
_forbidFastPath = reallyForbid;
|
||||||
}
|
}
|
||||||
|
|
||||||
VariableSpace *Context::getVariableSpace() {
|
VariableSpace *Context::getVariableSpace() {
|
||||||
|
|
|
@ -142,6 +142,8 @@ namespace nd4j {
|
||||||
NodeProfile *node = nullptr;
|
NodeProfile *node = nullptr;
|
||||||
std::chrono::time_point<std::chrono::system_clock> inputEnd, inputStart, shapeStart, shapeEnd, arrayStart, arrayEnd;
|
std::chrono::time_point<std::chrono::system_clock> inputEnd, inputStart, shapeStart, shapeEnd, arrayStart, arrayEnd;
|
||||||
|
|
||||||
|
auto fp = ctx.isFastPath();
|
||||||
|
|
||||||
if (Environment::getInstance()->isProfiling()) {
|
if (Environment::getInstance()->isProfiling()) {
|
||||||
if (ctx.getVariableSpace() != nullptr && ctx.getVariableSpace()->flowPath() != nullptr) {
|
if (ctx.getVariableSpace() != nullptr && ctx.getVariableSpace()->flowPath() != nullptr) {
|
||||||
prof = ctx.getVariableSpace()->flowPath()->profile();
|
prof = ctx.getVariableSpace()->flowPath()->profile();
|
||||||
|
@ -170,20 +172,22 @@ namespace nd4j {
|
||||||
return static_cast<int>(ctx.width());
|
return static_cast<int>(ctx.width());
|
||||||
} else {
|
} else {
|
||||||
// if op is not inplace - we should pre-allocate arrays
|
// if op is not inplace - we should pre-allocate arrays
|
||||||
|
|
||||||
ShapeList inSha;
|
ShapeList inSha;
|
||||||
int results = 0;
|
int results = 0;
|
||||||
|
|
||||||
|
bool canUseFastPath = true;
|
||||||
|
|
||||||
if (Environment::getInstance()->isProfiling() && node != nullptr)
|
if (Environment::getInstance()->isProfiling() && node != nullptr)
|
||||||
inputStart = std::chrono::system_clock::now();
|
inputStart = std::chrono::system_clock::now();
|
||||||
|
|
||||||
int cntIn = 0;
|
int cntIn = 0;
|
||||||
// we build list of input shapes
|
// we build list of input shapes
|
||||||
if (ctx.isFastPath()) {
|
if (fp) {
|
||||||
for (const auto p:ctx.fastpath_in()) {
|
for (const auto p:ctx.fastpath_in()) {
|
||||||
inSha.push_back(p == nullptr ? nullptr : p->getShapeInfo());
|
inSha.push_back(p == nullptr ? nullptr : p->getShapeInfo());
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
int arrCnt = 0;
|
||||||
for (auto p: *ctx.inputs()) {
|
for (auto p: *ctx.inputs()) {
|
||||||
auto var = ctx.variable(p);
|
auto var = ctx.variable(p);
|
||||||
if (var->variableType() == VariableType::NDARRAY) {
|
if (var->variableType() == VariableType::NDARRAY) {
|
||||||
|
@ -192,13 +196,19 @@ namespace nd4j {
|
||||||
throw unresolved_input_exception::build("Variable wasn't resolved prior shape calculation", p);
|
throw unresolved_input_exception::build("Variable wasn't resolved prior shape calculation", p);
|
||||||
|
|
||||||
inSha.push_back(array->getShapeInfo());
|
inSha.push_back(array->getShapeInfo());
|
||||||
|
|
||||||
|
// we're also filling ctx with arrays
|
||||||
|
if (canUseFastPath)
|
||||||
|
ctx.setInputArray(arrCnt++, array);
|
||||||
|
} else {
|
||||||
|
canUseFastPath = false;
|
||||||
}
|
}
|
||||||
cntIn++;
|
cntIn++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// if we override shape function, we'll return size of fastPath
|
// if we override shape function, we'll return size of fastPath
|
||||||
if (ctx.isFastPath() && ctx.shapeFunctionOverride()) {
|
if (fp && ctx.shapeFunctionOverride()) {
|
||||||
return (int) ctx.fastpath_out().size();
|
return (int) ctx.fastpath_out().size();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -232,8 +242,9 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
int cnt = 0;
|
int cnt = 0;
|
||||||
|
|
||||||
for (auto out: *outSha->asVector()) {
|
for (auto out: *outSha->asVector()) {
|
||||||
if (!ctx.isFastPath()) {
|
if (!fp) {
|
||||||
// we need to check, if Z is really needed
|
// we need to check, if Z is really needed
|
||||||
std::pair<int, int> pair(ctx.nodeId(), cnt++);
|
std::pair<int, int> pair(ctx.nodeId(), cnt++);
|
||||||
|
|
||||||
|
@ -244,11 +255,17 @@ namespace nd4j {
|
||||||
auto outArr = new NDArray(out, true, ctx.launchContext());
|
auto outArr = new NDArray(out, true, ctx.launchContext());
|
||||||
|
|
||||||
ctx.pushNDArrayToVariableSpace(pair, outArr);
|
ctx.pushNDArrayToVariableSpace(pair, outArr);
|
||||||
|
|
||||||
|
if (canUseFastPath)
|
||||||
|
ctx.setOutputArray(pair.second, outArr);
|
||||||
} else {
|
} else {
|
||||||
// validate/compare shapes here. existent vs provided in outSha
|
// validate/compare shapes here. existent vs provided in outSha
|
||||||
auto var = ctx.variable(pair);
|
auto var = ctx.variable(pair);
|
||||||
auto shape = var->getNDArray()->shapeInfo();
|
auto shape = var->getNDArray()->shapeInfo();
|
||||||
|
|
||||||
|
if (canUseFastPath)
|
||||||
|
ctx.setOutputArray(pair.second, var->getNDArray());
|
||||||
|
|
||||||
if (!shape::equalsSoft(out, shape) || shape::isEmpty(out) != shape::isEmpty(shape)) {
|
if (!shape::equalsSoft(out, shape) || shape::isEmpty(out) != shape::isEmpty(shape)) {
|
||||||
auto eShape = ShapeUtils::shapeAsString(out);
|
auto eShape = ShapeUtils::shapeAsString(out);
|
||||||
auto aShape = ShapeUtils::shapeAsString(shape);
|
auto aShape = ShapeUtils::shapeAsString(shape);
|
||||||
|
@ -289,20 +306,13 @@ namespace nd4j {
|
||||||
nd4j_printf("Expected vs provided shape mismatch %s vs %s at index %i\n", eShape.c_str(), aShape.c_str(), idx);
|
nd4j_printf("Expected vs provided shape mismatch %s vs %s at index %i\n", eShape.c_str(), aShape.c_str(), idx);
|
||||||
throw std::runtime_error("Expected vs provided shape mismatch");
|
throw std::runtime_error("Expected vs provided shape mismatch");
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
|
||||||
* FIXME: we want to uncomment this eventually, and check data types equality
|
|
||||||
//checking out data type equality
|
|
||||||
if (ArrayOptions::dataType(out) != array->dataType()) {
|
|
||||||
std::string msg = "Provided array [" + StringUtils::valueToString<int>(idx) + "] has unexpected data type";
|
|
||||||
throw nd4j::datatype_exception::build(msg, ArrayOptions::dataType(out), array->dataType());
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//outSha->destroy();
|
if (!canUseFastPath)
|
||||||
|
ctx.forbidFastPath(true);
|
||||||
|
|
||||||
delete outSha;
|
delete outSha;
|
||||||
|
|
||||||
// saving arrayTime
|
// saving arrayTime
|
||||||
|
|
Loading…
Reference in New Issue