diff --git a/libnd4j/blas/cpu/NativeOpExecutioner.cpp b/libnd4j/blas/cpu/NativeOpExecutioner.cpp index c155bd781..cbc224838 100644 --- a/libnd4j/blas/cpu/NativeOpExecutioner.cpp +++ b/libnd4j/blas/cpu/NativeOpExecutioner.cpp @@ -163,15 +163,32 @@ void NativeOpExecutioner::execBroadcast(nd4j::LaunchContext *lc, BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES); #else + auto loopKind = nd4j::LoopKind::deduceKindOfLoopBroadcast(hXShapeInfo, hYShapeInfo, hZShapeInfo); + auto func = PRAGMA_THREADS_FOR { - BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, loopKind, start, stop), LIBND4J_TYPES); }; - auto xLen = shape::length(hXShapeInfo); - auto yLen = shape::length(hYShapeInfo); - auto numTads = xLen / yLen; + Nd4jLong numTads = 0; + + switch (loopKind) { + case nd4j::LoopKind::BROADCAST_SCALAR_X: { + numTads = shape::length(hXShapeInfo); + } + break; + case nd4j::LoopKind::BROADCAST_SCALAR_Y: { + numTads = shape::length(hYShapeInfo); + } + break; + default: { + auto xLen = shape::length(hXShapeInfo); + auto yLen = shape::length(hYShapeInfo); + numTads = xLen / yLen; + } + } samediff::Threads::parallel_tad(func, 0, numTads); + #endif } diff --git a/libnd4j/include/helpers/LoopKind.h b/libnd4j/include/helpers/LoopKind.h index f8f8084c8..ddd1c95e5 100644 --- a/libnd4j/include/helpers/LoopKind.h +++ b/libnd4j/include/helpers/LoopKind.h @@ -37,12 +37,13 @@ namespace nd4j { class ND4J_EXPORT LoopKind { public: - enum Kind {SMALLARR2DX, EWS1, EWSNONZERO, RANK1, RANK2, RANK3, RANK4, RANK5, X_EWSNONZERO, Y_EWSNONZERO, Z_EWSNONZERO, COMMON}; + enum Kind {SMALLARR2DX, EWS1, EWSNONZERO, RANK1, RANK2, RANK3, RANK4, RANK5, X_EWSNONZERO, Y_EWSNONZERO, Z_EWSNONZERO, COMMON, BROADCAST_SCALAR_X, BROADCAST_SCALAR_Y}; static FORCEINLINE Kind deduceKindOfLoopXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo); static FORCEINLINE Kind deduceKindOfLoopXYZ(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo); static FORCEINLINE Kind deduceKindOfLoopTadXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo); static FORCEINLINE Kind deduceKindOfLoopTadXYZ(const Nd4jLong* xTadShapeInfo, const Nd4jLong* yTadShapeInfo, const Nd4jLong* zShapeInfo); + static FORCEINLINE Kind deduceKindOfLoopBroadcast(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo); }; @@ -82,6 +83,38 @@ LoopKind::Kind LoopKind::deduceKindOfLoopXZ(const Nd4jLong* xShapeInfo, const Nd return COMMON; } +LoopKind::Kind LoopKind::deduceKindOfLoopBroadcast(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo) { + auto xRank = shape::rank(xShapeInfo); + auto yRank = shape::rank(yShapeInfo); + auto zRank = shape::rank(zShapeInfo); + + auto xOrder = shape::order(xShapeInfo); + auto yOrder = shape::order(yShapeInfo); + auto zOrder = shape::order(zShapeInfo); + + auto xEws = shape::elementWiseStride(xShapeInfo); + auto yEws = shape::elementWiseStride(yShapeInfo); + auto zEws = shape::elementWiseStride(zShapeInfo); + + if (xRank == yRank && xRank == zRank && xOrder == 'c' && yOrder == 'c' && zOrder == 'c' && xEws == 1 && yEws == 1 && zEws == 1 && xRank >= 2) { + // we validate that shapes are equal till the last dim + for (int e = 0; e < xRank - 1; e++) { + if (xShapeInfo[e+1] != yShapeInfo[e+1]) + return COMMON; + } + + // now, if one of the shapes has 1 as last dim + auto detect = xShapeInfo[xRank] == 1 ? -1 : (yShapeInfo[xRank] == 1) ? 1 : 0; + + if (detect == 1) + return nd4j::LoopKind::BROADCAST_SCALAR_Y; + else if (detect == -1) + return nd4j::LoopKind::BROADCAST_SCALAR_X; + } + + return nd4j::LoopKind::COMMON; +} + ////////////////////////////////////////////////////////////////////////////// LoopKind::Kind LoopKind::deduceKindOfLoopXYZ(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo) { diff --git a/libnd4j/include/loops/broadcasting.h b/libnd4j/include/loops/broadcasting.h index a38e79c3f..ebf702004 100755 --- a/libnd4j/include/loops/broadcasting.h +++ b/libnd4j/include/loops/broadcasting.h @@ -40,6 +40,7 @@ #endif #include +#include #include "legacy_ops.h" @@ -122,6 +123,7 @@ namespace functions { Nd4jLong *tadOffset, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetZ, + nd4j::LoopKind::Kind loopKind, uint64_t start, uint64_t stop); @@ -149,6 +151,7 @@ namespace functions { Nd4jLong *tadOffset, Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetZ, + nd4j::LoopKind::Kind loopKind, uint64_t start, uint64_t stop); diff --git a/libnd4j/include/loops/cpu/broadcasting.hpp b/libnd4j/include/loops/cpu/broadcasting.hpp index 37dbf833f..691b95b83 100644 --- a/libnd4j/include/loops/cpu/broadcasting.hpp +++ b/libnd4j/include/loops/cpu/broadcasting.hpp @@ -75,6 +75,7 @@ namespace functions { Nd4jLong *xTadOffset, Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffset, + nd4j::LoopKind::Kind loopKind, uint64_t start, uint64_t stop) { DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x, @@ -88,7 +89,7 @@ namespace functions { xTadShapeInfo, xTadOffset, zTadShapeInfo, - zTadOffset, start, stop), BROADCAST_OPS); + zTadOffset, loopKind, start, stop), BROADCAST_OPS); } template @@ -105,6 +106,7 @@ namespace functions { Nd4jLong *xTadOffset, Nd4jLong *zTadShapeInfo, Nd4jLong *zTadOffset, + nd4j::LoopKind::Kind loopKind, uint64_t start, uint64_t stop) { @@ -142,7 +144,7 @@ namespace functions { auto yEws = shape::elementWiseStride(yShapeInfo); auto zEws = shape::elementWiseStride(zTadShapeInfo); - const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo); + const nd4j::LoopKind::Kind kindOfLoop = loopKind == nd4j::LoopKind::BROADCAST_SCALAR_X || loopKind == nd4j::LoopKind::BROADCAST_SCALAR_Y ? loopKind : nd4j::LoopKind::deduceKindOfLoopXYZ(xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo); if (kindOfLoop == nd4j::LoopKind::EWS1) { for (auto i = start; i < stop; i++) { @@ -163,6 +165,34 @@ namespace functions { for (unsigned int f = 0; f < tadLength; f++) oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws]); } + } else if(kindOfLoop == nd4j::LoopKind::BROADCAST_SCALAR_X){ + // this loop effectively turns broadcast into series of scalar ops + auto loopLength = yShapeInfo[shape::rank(yShapeInfo)]; + + for (auto i = start; i < stop; i++) { + auto oY = y + (i * loopLength); + auto oZ = z + (i * loopLength); + + const auto oX = x[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < loopLength; f++) + oZ[f] = OpType::op(oX, oY[f]); + } + } else if(kindOfLoop == nd4j::LoopKind::BROADCAST_SCALAR_Y){ + // this loop effectively turns broadcast into series of scalar ops + auto loopLength = xShapeInfo[shape::rank(xShapeInfo)]; + + for (auto i = start; i < stop; i++) { + auto oX = x + (i * loopLength); + auto oZ = z + (i * loopLength); + + const auto oY = y[i]; + + PRAGMA_OMP_SIMD + for (unsigned int f = 0; f < loopLength; f++) + oZ[f] = OpType::op(oX[f], oY); + } } else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) { uint tadShapeShapeInfoCast[MAX_RANK]; diff --git a/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp index 1f6000f06..9b6d06ec6 100644 --- a/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp @@ -179,6 +179,7 @@ TEST_F(BroadcastableOpsTests, Test_Minimum_1) { auto z = result->at(0); ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); delete result; diff --git a/libnd4j/tests_cpu/layers_tests/BrodcastTests.cpp b/libnd4j/tests_cpu/layers_tests/BrodcastTests.cpp index 9a8f09b87..bc2ae2152 100644 --- a/libnd4j/tests_cpu/layers_tests/BrodcastTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/BrodcastTests.cpp @@ -54,7 +54,7 @@ TEST_F(BroadcastMultiDimTest,MultimDimTest) { tad->tadOnlyShapeInfo, //tadShapeInfo tad->tadOffsets, //tadOffset tad->tadOnlyShapeInfo, //tadShapeInfoZ - tad->tadOffsets, 0, tad->numTads); //tadOffsetZ + tad->tadOffsets, nd4j::LoopKind::COMMON, 0, tad->numTads); //tadOffsetZ for(int i = 0; i < 30; i++) { ASSERT_EQ(dataAssertion[i],result[i]); } diff --git a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp index 9ce56d463..7cdf40c7f 100644 --- a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp @@ -149,6 +149,16 @@ TEST_F(PlaygroundTests, test_bert_1) { delete graph; } +TEST_F(PlaygroundTests, test_one_off_ops_1) { + auto x = NDArrayFactory::create('c', {4, 128, 768}); + auto y = NDArrayFactory::create('c', {4, 128, 1}); + auto z = x.ulike(); + + nd4j::ops::squaredsubtract op; + op.execute({&x, &y}, {&z}); +} + + /* TEST_F(PlaygroundTests, test_broadcast_1) {