transparent conversion to FastPath execution within Graph (#278)
Signed-off-by: raver119 <raver119@gmail.com>master
parent
353f901c7c
commit
31e3a2f7a5
|
@ -69,6 +69,9 @@ namespace nd4j {
|
|||
|
||||
// in some cases we might be able to skip shape function for validation purposes
|
||||
bool _shapeFunctionOverride = false;
|
||||
|
||||
// special flag used during conversion from Graph exec to FastPath exec
|
||||
bool _forbidFastPath = false;
|
||||
public:
|
||||
Context(ContextPrototype* prototype, VariableSpace* variableSpace);
|
||||
|
||||
|
@ -176,11 +179,17 @@ namespace nd4j {
|
|||
|
||||
// methods used in java interop
|
||||
/**
|
||||
* This method checks, if Context uses fastpath variable access
|
||||
* This method checks if Context uses fastpath variable access
|
||||
* @return
|
||||
*/
|
||||
bool isFastPath();
|
||||
|
||||
/**
|
||||
* Method allows to forbid FastPath execution
|
||||
* @param reallyForbid
|
||||
*/
|
||||
void forbidFastPath(bool reallyForbid);
|
||||
|
||||
#ifndef __JAVACPP_HACK__
|
||||
std::vector<NDArray*>& fastpath_in();
|
||||
std::vector<NDArray*>& fastpath_out();
|
||||
|
|
|
@ -136,7 +136,19 @@ namespace nd4j {
|
|||
}
|
||||
|
||||
bool Context::isFastPath() {
|
||||
return !(_fastpath_in.empty() && _fastpath_out.empty());
|
||||
auto ie = _fastpath_in.empty();
|
||||
auto io = _fastpath_out.empty();
|
||||
// two options here.
|
||||
// either both IN/OUT are filled
|
||||
auto b1 = (!ie && !io);
|
||||
|
||||
// or at least something is filled, and FastPath is NOT forbidden
|
||||
auto b2 = (!ie || !io) && !_forbidFastPath;
|
||||
return b1 || b2;
|
||||
}
|
||||
|
||||
void Context::forbidFastPath(bool reallyForbid) {
|
||||
_forbidFastPath = reallyForbid;
|
||||
}
|
||||
|
||||
VariableSpace *Context::getVariableSpace() {
|
||||
|
|
|
@ -142,6 +142,8 @@ namespace nd4j {
|
|||
NodeProfile *node = nullptr;
|
||||
std::chrono::time_point<std::chrono::system_clock> inputEnd, inputStart, shapeStart, shapeEnd, arrayStart, arrayEnd;
|
||||
|
||||
auto fp = ctx.isFastPath();
|
||||
|
||||
if (Environment::getInstance()->isProfiling()) {
|
||||
if (ctx.getVariableSpace() != nullptr && ctx.getVariableSpace()->flowPath() != nullptr) {
|
||||
prof = ctx.getVariableSpace()->flowPath()->profile();
|
||||
|
@ -170,20 +172,22 @@ namespace nd4j {
|
|||
return static_cast<int>(ctx.width());
|
||||
} else {
|
||||
// if op is not inplace - we should pre-allocate arrays
|
||||
|
||||
ShapeList inSha;
|
||||
int results = 0;
|
||||
|
||||
bool canUseFastPath = true;
|
||||
|
||||
if (Environment::getInstance()->isProfiling() && node != nullptr)
|
||||
inputStart = std::chrono::system_clock::now();
|
||||
|
||||
int cntIn = 0;
|
||||
// we build list of input shapes
|
||||
if (ctx.isFastPath()) {
|
||||
if (fp) {
|
||||
for (const auto p:ctx.fastpath_in()) {
|
||||
inSha.push_back(p == nullptr ? nullptr : p->getShapeInfo());
|
||||
}
|
||||
} else {
|
||||
int arrCnt = 0;
|
||||
for (auto p: *ctx.inputs()) {
|
||||
auto var = ctx.variable(p);
|
||||
if (var->variableType() == VariableType::NDARRAY) {
|
||||
|
@ -192,13 +196,19 @@ namespace nd4j {
|
|||
throw unresolved_input_exception::build("Variable wasn't resolved prior shape calculation", p);
|
||||
|
||||
inSha.push_back(array->getShapeInfo());
|
||||
|
||||
// we're also filling ctx with arrays
|
||||
if (canUseFastPath)
|
||||
ctx.setInputArray(arrCnt++, array);
|
||||
} else {
|
||||
canUseFastPath = false;
|
||||
}
|
||||
cntIn++;
|
||||
}
|
||||
}
|
||||
|
||||
// if we override shape function, we'll return size of fastPath
|
||||
if (ctx.isFastPath() && ctx.shapeFunctionOverride()) {
|
||||
if (fp && ctx.shapeFunctionOverride()) {
|
||||
return (int) ctx.fastpath_out().size();
|
||||
}
|
||||
|
||||
|
@ -232,8 +242,9 @@ namespace nd4j {
|
|||
}
|
||||
|
||||
int cnt = 0;
|
||||
|
||||
for (auto out: *outSha->asVector()) {
|
||||
if (!ctx.isFastPath()) {
|
||||
if (!fp) {
|
||||
// we need to check, if Z is really needed
|
||||
std::pair<int, int> pair(ctx.nodeId(), cnt++);
|
||||
|
||||
|
@ -244,11 +255,17 @@ namespace nd4j {
|
|||
auto outArr = new NDArray(out, true, ctx.launchContext());
|
||||
|
||||
ctx.pushNDArrayToVariableSpace(pair, outArr);
|
||||
|
||||
if (canUseFastPath)
|
||||
ctx.setOutputArray(pair.second, outArr);
|
||||
} else {
|
||||
// validate/compare shapes here. existent vs provided in outSha
|
||||
auto var = ctx.variable(pair);
|
||||
auto shape = var->getNDArray()->shapeInfo();
|
||||
|
||||
if (canUseFastPath)
|
||||
ctx.setOutputArray(pair.second, var->getNDArray());
|
||||
|
||||
if (!shape::equalsSoft(out, shape) || shape::isEmpty(out) != shape::isEmpty(shape)) {
|
||||
auto eShape = ShapeUtils::shapeAsString(out);
|
||||
auto aShape = ShapeUtils::shapeAsString(shape);
|
||||
|
@ -289,20 +306,13 @@ namespace nd4j {
|
|||
nd4j_printf("Expected vs provided shape mismatch %s vs %s at index %i\n", eShape.c_str(), aShape.c_str(), idx);
|
||||
throw std::runtime_error("Expected vs provided shape mismatch");
|
||||
}
|
||||
|
||||
/*
|
||||
* FIXME: we want to uncomment this eventually, and check data types equality
|
||||
//checking out data type equality
|
||||
if (ArrayOptions::dataType(out) != array->dataType()) {
|
||||
std::string msg = "Provided array [" + StringUtils::valueToString<int>(idx) + "] has unexpected data type";
|
||||
throw nd4j::datatype_exception::build(msg, ArrayOptions::dataType(out), array->dataType());
|
||||
}
|
||||
*/
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//outSha->destroy();
|
||||
if (!canUseFastPath)
|
||||
ctx.forbidFastPath(true);
|
||||
|
||||
delete outSha;
|
||||
|
||||
// saving arrayTime
|
||||
|
|
Loading…
Reference in New Issue