From 31e3a2f7a51e20620afe52d314f29a8da75b07af Mon Sep 17 00:00:00 2001 From: raver119 Date: Thu, 27 Feb 2020 16:10:38 +0300 Subject: [PATCH] transparent conversion to FastPath execution within Graph (#278) Signed-off-by: raver119 --- libnd4j/include/graph/Context.h | 11 +++++- libnd4j/include/graph/impl/Context.cpp | 14 ++++++- .../ops/declarable/impl/DeclarableOp.cpp | 38 ++++++++++++------- 3 files changed, 47 insertions(+), 16 deletions(-) diff --git a/libnd4j/include/graph/Context.h b/libnd4j/include/graph/Context.h index d1e8a4dad..51f7bfa2b 100644 --- a/libnd4j/include/graph/Context.h +++ b/libnd4j/include/graph/Context.h @@ -69,6 +69,9 @@ namespace nd4j { // in some cases we might be able to skip shape function for validation purposes bool _shapeFunctionOverride = false; + + // special flag used during conversion from Graph exec to FastPath exec + bool _forbidFastPath = false; public: Context(ContextPrototype* prototype, VariableSpace* variableSpace); @@ -176,11 +179,17 @@ namespace nd4j { // methods used in java interop /** - * This method checks, if Context uses fastpath variable access + * This method checks if Context uses fastpath variable access * @return */ bool isFastPath(); + /** + * Method allows to forbid FastPath execution + * @param reallyForbid + */ + void forbidFastPath(bool reallyForbid); + #ifndef __JAVACPP_HACK__ std::vector& fastpath_in(); std::vector& fastpath_out(); diff --git a/libnd4j/include/graph/impl/Context.cpp b/libnd4j/include/graph/impl/Context.cpp index 671d89c24..26f8875a8 100644 --- a/libnd4j/include/graph/impl/Context.cpp +++ b/libnd4j/include/graph/impl/Context.cpp @@ -136,7 +136,19 @@ namespace nd4j { } bool Context::isFastPath() { - return !(_fastpath_in.empty() && _fastpath_out.empty()); + auto ie = _fastpath_in.empty(); + auto io = _fastpath_out.empty(); + // two options here. + // either both IN/OUT are filled + auto b1 = (!ie && !io); + + // or at least something is filled, and FastPath is NOT forbidden + auto b2 = (!ie || !io) && !_forbidFastPath; + return b1 || b2; + } + + void Context::forbidFastPath(bool reallyForbid) { + _forbidFastPath = reallyForbid; } VariableSpace *Context::getVariableSpace() { diff --git a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp index 9724b6ba5..6949bb4ed 100644 --- a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp +++ b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp @@ -142,6 +142,8 @@ namespace nd4j { NodeProfile *node = nullptr; std::chrono::time_point inputEnd, inputStart, shapeStart, shapeEnd, arrayStart, arrayEnd; + auto fp = ctx.isFastPath(); + if (Environment::getInstance()->isProfiling()) { if (ctx.getVariableSpace() != nullptr && ctx.getVariableSpace()->flowPath() != nullptr) { prof = ctx.getVariableSpace()->flowPath()->profile(); @@ -170,20 +172,22 @@ namespace nd4j { return static_cast(ctx.width()); } else { // if op is not inplace - we should pre-allocate arrays - ShapeList inSha; int results = 0; + bool canUseFastPath = true; + if (Environment::getInstance()->isProfiling() && node != nullptr) inputStart = std::chrono::system_clock::now(); int cntIn = 0; // we build list of input shapes - if (ctx.isFastPath()) { + if (fp) { for (const auto p:ctx.fastpath_in()) { inSha.push_back(p == nullptr ? nullptr : p->getShapeInfo()); } } else { + int arrCnt = 0; for (auto p: *ctx.inputs()) { auto var = ctx.variable(p); if (var->variableType() == VariableType::NDARRAY) { @@ -192,13 +196,19 @@ namespace nd4j { throw unresolved_input_exception::build("Variable wasn't resolved prior shape calculation", p); inSha.push_back(array->getShapeInfo()); + + // we're also filling ctx with arrays + if (canUseFastPath) + ctx.setInputArray(arrCnt++, array); + } else { + canUseFastPath = false; } cntIn++; } } // if we override shape function, we'll return size of fastPath - if (ctx.isFastPath() && ctx.shapeFunctionOverride()) { + if (fp && ctx.shapeFunctionOverride()) { return (int) ctx.fastpath_out().size(); } @@ -232,8 +242,9 @@ namespace nd4j { } int cnt = 0; + for (auto out: *outSha->asVector()) { - if (!ctx.isFastPath()) { + if (!fp) { // we need to check, if Z is really needed std::pair pair(ctx.nodeId(), cnt++); @@ -244,11 +255,17 @@ namespace nd4j { auto outArr = new NDArray(out, true, ctx.launchContext()); ctx.pushNDArrayToVariableSpace(pair, outArr); + + if (canUseFastPath) + ctx.setOutputArray(pair.second, outArr); } else { // validate/compare shapes here. existent vs provided in outSha auto var = ctx.variable(pair); auto shape = var->getNDArray()->shapeInfo(); + if (canUseFastPath) + ctx.setOutputArray(pair.second, var->getNDArray()); + if (!shape::equalsSoft(out, shape) || shape::isEmpty(out) != shape::isEmpty(shape)) { auto eShape = ShapeUtils::shapeAsString(out); auto aShape = ShapeUtils::shapeAsString(shape); @@ -289,20 +306,13 @@ namespace nd4j { nd4j_printf("Expected vs provided shape mismatch %s vs %s at index %i\n", eShape.c_str(), aShape.c_str(), idx); throw std::runtime_error("Expected vs provided shape mismatch"); } - - /* - * FIXME: we want to uncomment this eventually, and check data types equality - //checking out data type equality - if (ArrayOptions::dataType(out) != array->dataType()) { - std::string msg = "Provided array [" + StringUtils::valueToString(idx) + "] has unexpected data type"; - throw nd4j::datatype_exception::build(msg, ArrayOptions::dataType(out), array->dataType()); - } - */ } } } - //outSha->destroy(); + if (!canUseFastPath) + ctx.forbidFastPath(true); + delete outSha; // saving arrayTime