From b686368b82027a756572799182b14c97f8892b44 Mon Sep 17 00:00:00 2001
From: Oleh
Date: Wed, 26 Feb 2020 09:20:39 +0200
Subject: [PATCH] Refactoring split operation (#266)

* libnd4j moved split operation implementation to helpers before special case adding

Signed-off-by: Oleg

* libnd4j minor fixes for general split operation move, merge master

Signed-off-by: Oleg

* libndj4 split cpu implementation

Signed-off-by: Oleg

* - provide cuda helper for split op

Signed-off-by: Yurii

* - minor correction

Signed-off-by: Yurii

* - minor correction 2

Signed-off-by: Yurii

* libnd4j moved split implementation from specials to split.cpp

Signed-off-by: Oleg

* libnd4j update loopkind selections for 3D, 4D and 5D cases

Signed-off-by: Oleg

* libnd4j removed unnecessary BUILD_SINGLE_TEMPLATE

Signed-off-by: Oleg

Co-authored-by: Yurii Shyrma
---
 libnd4j/include/helpers/LoopKind.h            |  22 +++-
 .../ops/declarable/helpers/cpu/split.cpp      | 109 +++++++++++++++---
 libnd4j/include/ops/impl/specials_single.hpp  |  94 ---------------
 libnd4j/include/ops/specials.h                |   1 -
 4 files changed, 112 insertions(+), 114 deletions(-)

diff --git a/libnd4j/include/helpers/LoopKind.h b/libnd4j/include/helpers/LoopKind.h
index d97f3b225..541e6e5a7 100644
--- a/libnd4j/include/helpers/LoopKind.h
+++ b/libnd4j/include/helpers/LoopKind.h
@@ -107,12 +107,22 @@ LoopKind::Kind LoopKind::deduceKindOfLoopBroadcast(const Nd4jLong* xShapeInfo, c
 
     bool bNotCommonVectorCase = (countUnityDimsInY != yRank - 1) && (countUnityDimsInX != xRank - 1);
 
-    if (3 == xRank && bNDLoopsRanks && bNotCommonVectorCase)
-        return nd4j::LoopKind::BROADCAST_3D;
-    if (4 == xRank && bNDLoopsRanks && bNotCommonVectorCase)
-        return nd4j::LoopKind::BROADCAST_4D;
-    if (5 == xRank && bNDLoopsRanks && bNotCommonVectorCase)
-        return nd4j::LoopKind::BROADCAST_5D;
+
+    if (bNDLoopsRanks && bNotCommonVectorCase) {
+        // case x[3,4,5] * y[1,4,5] = z[3,4,5] or reverse x[1,4,5] + y[3,4,5] = z[3,4,5]
+        if (nd4j::LoopKind::EWS1 == deduceKindOfLoopXYZ(xShapeInfo, yShapeInfo, zShapeInfo)
+            && (1 == shape::sizeAt(yShapeInfo, 0) || 1 == shape::sizeAt(xShapeInfo, 0))) {
+            return EWS1;
+        }
+
+        if (3 == xRank)
+            return nd4j::LoopKind::BROADCAST_3D;
+        if (4 == xRank)
+            return nd4j::LoopKind::BROADCAST_4D;
+        if (5 == xRank)
+            return nd4j::LoopKind::BROADCAST_5D;
+
+    }
 
     if (xRank == yRank && xRank == zRank && xOrder == 'c' && yOrder == 'c' && zOrder == 'c' &&
         xEws == 1 && yEws == 1 && zEws == 1 && xRank >= 2) {
diff --git a/libnd4j/include/ops/declarable/helpers/cpu/split.cpp b/libnd4j/include/ops/declarable/helpers/cpu/split.cpp
index bdae61f16..5c9c2bbf7 100644
--- a/libnd4j/include/ops/declarable/helpers/cpu/split.cpp
+++ b/libnd4j/include/ops/declarable/helpers/cpu/split.cpp
@@ -18,26 +18,109 @@
 // @author Oleh Semeniv (oleg.semeniv@gmail.com)
 //
 
-#include <...>
-#include <...>
 
 namespace nd4j {
 namespace ops {
 namespace helpers {
 
-//////////////////////////////////////////////////////////////////////////
-template <typename T>
-static void split_(const NDArray& input, const std::vector<NDArray*>& outArrs, const int axis) {
-    nd4j::SpecialMethods<T>::splitCpuGeneric(input, outArrs, axis);
-}
-void split(nd4j::LaunchContext* context, const NDArray& input, std::vector<NDArray*>& outArrs, const int axis) {
-    BUILD_SINGLE_SELECTOR(input.dataType(), split_, (input, outArrs, axis), LIBND4J_TYPES);
-}
+    //////////////////////////////////////////////////////////////////////////
+    template <typename T>
+    static void split_(const NDArray& input, const std::vector<NDArray*>& outArrs, const int axis) {
+        int numSplits = outArrs.size();
 
-BUILD_SINGLE_TEMPLATE(template void split_, (const NDArray& input, const std::vector<NDArray*>& outArrs, const int axis), LIBND4J_TYPES);
+        const auto sizeofT = input.sizeOfT();
+        T* xBuff = input.bufferAsT<T>();
+
+        bool luckCase1 = ((axis == 0 && input.ordering() == 'c') || (axis == input.rankOf() - 1 && input.ordering() == 'f')) && input.ews() == 1;
+
+        if (luckCase1) {
+            for (uint i = 0; i < numSplits; ++i) {
+                luckCase1 &= outArrs[i]->ordering() == input.ordering() && outArrs[i]->ews() == 1;
+                if (!luckCase1)
+                    break;
+            }
+        }
+
+        if (luckCase1) {
+
+            T* x = const_cast<T*>(xBuff);
+            for (uint i = 0; i < numSplits; ++i) {
+                const auto memAmountToCopy = outArrs[i]->lengthOf();
+                memcpy(outArrs[i]->bufferAsT<T>(), x, memAmountToCopy * sizeofT);
+                x += memAmountToCopy;
+            }
+            return;
+        }
+
+        const bool isXcontin = input.strideAt(axis) == 1 && input.ordering() == 'c';
+        bool areOutsContin = true;
+        bool allSameOrder = true;
+
+        if (isXcontin) {
+            for (uint i = 0; i < numSplits; ++i) {
+                areOutsContin &= outArrs[i]->strideAt(axis) == 1;
+                allSameOrder &= outArrs[i]->ordering() == input.ordering();
+                if (!areOutsContin || !allSameOrder)
+                    break;
+            }
+        }
+
+        const bool luckCase2 = isXcontin && areOutsContin && allSameOrder;
+
+        if (luckCase2) {
+
+            const uint xDim = input.sizeAt(axis);
+
+            for (uint i = 0; i < input.lengthOf() / xDim; ++i) {
+
+                T* x = xBuff + xDim * i;
+
+                for (uint j = 0; j < numSplits; ++j) {
+                    const auto zDim = outArrs[j]->sizeAt(axis);
+                    T* z = outArrs[j]->bufferAsT<T>() + zDim * i;
+                    memcpy(z, x, zDim * sizeofT);
+                    z += zDim;
+                    x += zDim;
+                }
+            }
+
+            return;
+        }
+
+        uint zDim = outArrs[0]->sizeAt(axis);
+        // general case
+
+        auto func = PRAGMA_THREADS_FOR {
+
+            Nd4jLong coords[MAX_RANK];
+            for (auto i = start; i < stop; i += increment) {
+
+                shape::index2coords(i, input.getShapeInfo(), coords);
+                const auto xOffset = shape::getOffset(input.getShapeInfo(), coords);
+
+                uint outArrIdx = 0;
+
+                while (coords[axis] >= zDim) {
+                    coords[axis] -= zDim;
+                    ++outArrIdx;
+                }
+
+                T* z = outArrs[outArrIdx]->bufferAsT<T>();
+                const auto zOffset = shape::getOffset(outArrs[outArrIdx]->getShapeInfo(), coords);
+                z[zOffset] = xBuff[xOffset];
+            }
+        };
+
+        samediff::Threads::parallel_for(func, 0, input.lengthOf());
+    }
+
+    void split(nd4j::LaunchContext* context, const NDArray& input, std::vector<NDArray*>& outArrs, const int axis) {
+        BUILD_SINGLE_SELECTOR(input.dataType(), split_, (input, outArrs, axis), LIBND4J_TYPES);
+    }
+    }
+    }
 }
-}
-}
\ No newline at end of file
diff --git a/libnd4j/include/ops/impl/specials_single.hpp b/libnd4j/include/ops/impl/specials_single.hpp
index 779cb5c2a..ad4c96e7c 100644
--- a/libnd4j/include/ops/impl/specials_single.hpp
+++ b/libnd4j/include/ops/impl/specials_single.hpp
@@ -217,100 +217,6 @@ void SpecialMethods<T>::concatCpuGeneric(int dimension, int numArrays, Nd4jPoint
         delete inputs[i];
 }
 
-
-template <typename T>
-void SpecialMethods<T>::splitCpuGeneric(const NDArray& input, const std::vector<NDArray*>& outArrs, const int axis) {
-
-    int numSplits = outArrs.size();
-
-    const auto sizeofT = input.sizeOfT();
-
-    T* xBuff = input.bufferAsT<T>();
-
-    bool luckCase1 = ((axis == 0 && input.ordering() == 'c') || (axis == input.rankOf() - 1 && input.ordering() == 'f')) && input.ews() == 1;
-
-    if (luckCase1) {
-        for (uint i = 0; i < numSplits; ++i) {
-            luckCase1 &= outArrs[i]->ordering() == input.ordering() && outArrs[i]->ews() == 1;
-            if (!luckCase1)
-                break;
-        }
-    }
-
-    if (luckCase1) {
-
-        T* x = const_cast<T*>(xBuff);
-        for (uint i = 0; i < numSplits; ++i) {
-            const auto memAmountToCopy = outArrs[i]->lengthOf();
-            memcpy(outArrs[i]->bufferAsT<T>(), x, memAmountToCopy * sizeofT);
-            x += memAmountToCopy;
-        }
-        return;
-    }
-
-    const bool isXcontin = input.strideAt(axis) == 1 && input.ordering() == 'c';
-    bool areOutsContin = true;
-    bool allSameOrder = true;
-
-    if (isXcontin) {
-        for (uint i = 0; i < numSplits; ++i) {
-            areOutsContin &= outArrs[i]->strideAt(axis) == 1;
-            allSameOrder &= outArrs[i]->ordering() == input.ordering();
-            if (!areOutsContin || !allSameOrder)
-                break;
-        }
-    }
-
-    const bool luckCase2 = isXcontin && areOutsContin && allSameOrder;
-
-    if (luckCase2) {
-
-        const uint xDim = input.sizeAt(axis);
-
-        for (uint i = 0; i < input.lengthOf() / xDim; ++i) {
-
-            T* x = xBuff + xDim * i;
-
-            for (uint j = 0; j < numSplits; ++j) {
-                const auto zDim = outArrs[j]->sizeAt(axis);
-                T* z = outArrs[j]->bufferAsT<T>() + zDim * i;
-                memcpy(z, x, zDim * sizeofT);
-                z += zDim;
-                x += zDim;
-            }
-        }
-
-        return;
-    }
-
-    uint zDim = outArrs[0]->sizeAt(axis);
-    // general case
-
-    auto func = PRAGMA_THREADS_FOR {
-
-        Nd4jLong coords[MAX_RANK];
-        for (auto i = start; i < stop; i += increment) {
-
-            shape::index2coords(i, input.getShapeInfo(), coords);
-            const auto xOffset = shape::getOffset(input.getShapeInfo(), coords);
-
-            uint outArrIdx = 0;
-
-            while (coords[axis] >= zDim) {
-                coords[axis] -= zDim;
-                ++outArrIdx;
-            }
-
-            T* z = outArrs[outArrIdx]->bufferAsT<T>();
-            const auto zOffset = shape::getOffset(outArrs[outArrIdx]->getShapeInfo(), coords);
-            z[zOffset] = xBuff[xOffset];
-        }
-    };
-
-    samediff::Threads::parallel_for(func, 0, input.lengthOf());
-}
-
-
 /**
 * This kernel accumulates X arrays, and stores result into Z
 *
diff --git a/libnd4j/include/ops/specials.h b/libnd4j/include/ops/specials.h
index fea31cf6f..94fce8477 100644
--- a/libnd4j/include/ops/specials.h
+++ b/libnd4j/include/ops/specials.h
@@ -68,7 +68,6 @@ namespace nd4j {
 
         static void decodeBitmapGeneric(void *dx, Nd4jLong N, void *dz, Nd4jLong *zShapeInfo);
         static Nd4jLong encodeBitmapGeneric(void *dx, Nd4jLong *zShapeInfo, Nd4jLong N, int *dz, float threshold);
-        static void splitCpuGeneric(const NDArray& input, const std::vector<NDArray*>& outArrs, const int axis);
     };
 
     template
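
Editor's note on the fast path introduced above: the "luckCase1" branch relies on the fact that when the input is contiguous and the split axis is the outermost dimension in 'c' order (or the innermost in 'f' order), every output occupies one contiguous slice of the input buffer, so the whole split collapses into sequential memcpy calls. The sketch below illustrates that idea in standard C++ only; it is not libnd4j code, and the names splitAxis0 and rowsPerSplit are hypothetical, introduced here purely for illustration.

// Standalone illustration of the "luckCase1" fast path (standard C++, not libnd4j code).
// For a row-major [rows, cols] buffer split along axis 0, each output part is one
// contiguous block of the input, so splitting reduces to a sequence of memcpy calls.
#include <cstdio>
#include <cstring>
#include <vector>

// Hypothetical helper: split a row-major [rows, cols] buffer along axis 0 into equal parts.
static std::vector<std::vector<float>> splitAxis0(const std::vector<float>& input,
                                                  std::size_t rows, std::size_t cols,
                                                  std::size_t numSplits) {
    std::vector<std::vector<float>> outputs(numSplits);
    const std::size_t rowsPerSplit = rows / numSplits;    // assumes rows % numSplits == 0
    const float* x = input.data();
    for (std::size_t i = 0; i < numSplits; ++i) {
        const std::size_t lengthOf = rowsPerSplit * cols; // elements in this output part
        outputs[i].resize(lengthOf);
        std::memcpy(outputs[i].data(), x, lengthOf * sizeof(float));
        x += lengthOf;                                    // advance past the slice just copied
    }
    return outputs;
}

int main() {
    // 6x2 row-major input holding 0..11, split into 3 parts of 2 rows each.
    std::vector<float> input(12);
    for (std::size_t i = 0; i < input.size(); ++i)
        input[i] = static_cast<float>(i);

    const auto parts = splitAxis0(input, 6, 2, 3);
    for (std::size_t p = 0; p < parts.size(); ++p) {
        std::printf("part %zu:", p);
        for (float v : parts[p])
            std::printf(" %g", v);
        std::printf("\n");
    }
    return 0;
}

When that layout condition does not hold, the patch falls back to its general case: each flat index is mapped back to coordinates with shape::index2coords, and the element is routed to the output whose range along the split axis contains that coordinate.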