transparent conversion to FastPath execution within Graph (#278)

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2020-02-27 16:10:38 +03:00 committed by GitHub
parent 353f901c7c
commit 31e3a2f7a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 47 additions and 16 deletions

View File

@ -69,6 +69,9 @@ namespace nd4j {
// in some cases we might be able to skip shape function for validation purposes
bool _shapeFunctionOverride = false;
// special flag used during conversion from Graph exec to FastPath exec
bool _forbidFastPath = false;
public:
Context(ContextPrototype* prototype, VariableSpace* variableSpace);
@ -176,11 +179,17 @@ namespace nd4j {
// methods used in java interop
/**
* This method checks, if Context uses fastpath variable access
* This method checks if Context uses fastpath variable access
* @return
*/
bool isFastPath();
/**
* Method allows to forbid FastPath execution
* @param reallyForbid
*/
void forbidFastPath(bool reallyForbid);
#ifndef __JAVACPP_HACK__
std::vector<NDArray*>& fastpath_in();
std::vector<NDArray*>& fastpath_out();

View File

@ -136,7 +136,19 @@ namespace nd4j {
}
bool Context::isFastPath() {
return !(_fastpath_in.empty() && _fastpath_out.empty());
auto ie = _fastpath_in.empty();
auto io = _fastpath_out.empty();
// two options here.
// either both IN/OUT are filled
auto b1 = (!ie && !io);
// or at least something is filled, and FastPath is NOT forbidden
auto b2 = (!ie || !io) && !_forbidFastPath;
return b1 || b2;
}
void Context::forbidFastPath(bool reallyForbid) {
_forbidFastPath = reallyForbid;
}
VariableSpace *Context::getVariableSpace() {

View File

@ -142,6 +142,8 @@ namespace nd4j {
NodeProfile *node = nullptr;
std::chrono::time_point<std::chrono::system_clock> inputEnd, inputStart, shapeStart, shapeEnd, arrayStart, arrayEnd;
auto fp = ctx.isFastPath();
if (Environment::getInstance()->isProfiling()) {
if (ctx.getVariableSpace() != nullptr && ctx.getVariableSpace()->flowPath() != nullptr) {
prof = ctx.getVariableSpace()->flowPath()->profile();
@ -170,20 +172,22 @@ namespace nd4j {
return static_cast<int>(ctx.width());
} else {
// if op is not inplace - we should pre-allocate arrays
ShapeList inSha;
int results = 0;
bool canUseFastPath = true;
if (Environment::getInstance()->isProfiling() && node != nullptr)
inputStart = std::chrono::system_clock::now();
int cntIn = 0;
// we build list of input shapes
if (ctx.isFastPath()) {
if (fp) {
for (const auto p:ctx.fastpath_in()) {
inSha.push_back(p == nullptr ? nullptr : p->getShapeInfo());
}
} else {
int arrCnt = 0;
for (auto p: *ctx.inputs()) {
auto var = ctx.variable(p);
if (var->variableType() == VariableType::NDARRAY) {
@ -192,13 +196,19 @@ namespace nd4j {
throw unresolved_input_exception::build("Variable wasn't resolved prior shape calculation", p);
inSha.push_back(array->getShapeInfo());
// we're also filling ctx with arrays
if (canUseFastPath)
ctx.setInputArray(arrCnt++, array);
} else {
canUseFastPath = false;
}
cntIn++;
}
}
// if we override shape function, we'll return size of fastPath
if (ctx.isFastPath() && ctx.shapeFunctionOverride()) {
if (fp && ctx.shapeFunctionOverride()) {
return (int) ctx.fastpath_out().size();
}
@ -232,8 +242,9 @@ namespace nd4j {
}
int cnt = 0;
for (auto out: *outSha->asVector()) {
if (!ctx.isFastPath()) {
if (!fp) {
// we need to check, if Z is really needed
std::pair<int, int> pair(ctx.nodeId(), cnt++);
@ -244,11 +255,17 @@ namespace nd4j {
auto outArr = new NDArray(out, true, ctx.launchContext());
ctx.pushNDArrayToVariableSpace(pair, outArr);
if (canUseFastPath)
ctx.setOutputArray(pair.second, outArr);
} else {
// validate/compare shapes here. existent vs provided in outSha
auto var = ctx.variable(pair);
auto shape = var->getNDArray()->shapeInfo();
if (canUseFastPath)
ctx.setOutputArray(pair.second, var->getNDArray());
if (!shape::equalsSoft(out, shape) || shape::isEmpty(out) != shape::isEmpty(shape)) {
auto eShape = ShapeUtils::shapeAsString(out);
auto aShape = ShapeUtils::shapeAsString(shape);
@ -289,20 +306,13 @@ namespace nd4j {
nd4j_printf("Expected vs provided shape mismatch %s vs %s at index %i\n", eShape.c_str(), aShape.c_str(), idx);
throw std::runtime_error("Expected vs provided shape mismatch");
}
/*
* FIXME: we want to uncomment this eventually, and check data types equality
//checking out data type equality
if (ArrayOptions::dataType(out) != array->dataType()) {
std::string msg = "Provided array [" + StringUtils::valueToString<int>(idx) + "] has unexpected data type";
throw nd4j::datatype_exception::build(msg, ArrayOptions::dataType(out), array->dataType());
}
*/
}
}
}
//outSha->destroy();
if (!canUseFastPath)
ctx.forbidFastPath(true);
delete outSha;
// saving arrayTime