Broadcast perf improvements (#248)

* broadcast as scalar edge case

Signed-off-by: raver119 <raver119@gmail.com>

* missing return

Signed-off-by: raver119 <raver119@gmail.com>

* few fixes

Signed-off-by: raver119 <raver119@gmail.com>

* one more fix

Signed-off-by: raver119 <raver119@gmail.com>

* no need for lambdas

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2020-02-17 16:25:10 +03:00 committed by GitHub
parent f9d51b7278
commit 2698fbf541
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 102 additions and 8 deletions

View File

@ -163,15 +163,32 @@ void NativeOpExecutioner::execBroadcast(nd4j::LaunchContext *lc,
BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES);
#else
auto loopKind = nd4j::LoopKind::deduceKindOfLoopBroadcast(hXShapeInfo, hYShapeInfo, hZShapeInfo);
auto func = PRAGMA_THREADS_FOR {
BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES);
BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, loopKind, start, stop), LIBND4J_TYPES);
};
auto xLen = shape::length(hXShapeInfo);
auto yLen = shape::length(hYShapeInfo);
auto numTads = xLen / yLen;
Nd4jLong numTads = 0;
switch (loopKind) {
case nd4j::LoopKind::BROADCAST_SCALAR_X: {
numTads = shape::length(hXShapeInfo);
}
break;
case nd4j::LoopKind::BROADCAST_SCALAR_Y: {
numTads = shape::length(hYShapeInfo);
}
break;
default: {
auto xLen = shape::length(hXShapeInfo);
auto yLen = shape::length(hYShapeInfo);
numTads = xLen / yLen;
}
}
samediff::Threads::parallel_tad(func, 0, numTads);
#endif
}

View File

@ -37,12 +37,13 @@ namespace nd4j {
class ND4J_EXPORT LoopKind {
public:
enum Kind {SMALLARR2DX, EWS1, EWSNONZERO, RANK1, RANK2, RANK3, RANK4, RANK5, X_EWSNONZERO, Y_EWSNONZERO, Z_EWSNONZERO, COMMON};
enum Kind {SMALLARR2DX, EWS1, EWSNONZERO, RANK1, RANK2, RANK3, RANK4, RANK5, X_EWSNONZERO, Y_EWSNONZERO, Z_EWSNONZERO, COMMON, BROADCAST_SCALAR_X, BROADCAST_SCALAR_Y};
static FORCEINLINE Kind deduceKindOfLoopXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo);
static FORCEINLINE Kind deduceKindOfLoopXYZ(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo);
static FORCEINLINE Kind deduceKindOfLoopTadXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo);
static FORCEINLINE Kind deduceKindOfLoopTadXYZ(const Nd4jLong* xTadShapeInfo, const Nd4jLong* yTadShapeInfo, const Nd4jLong* zShapeInfo);
static FORCEINLINE Kind deduceKindOfLoopBroadcast(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo);
};
@ -82,6 +83,38 @@ LoopKind::Kind LoopKind::deduceKindOfLoopXZ(const Nd4jLong* xShapeInfo, const Nd
return COMMON;
}
LoopKind::Kind LoopKind::deduceKindOfLoopBroadcast(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo) {
auto xRank = shape::rank(xShapeInfo);
auto yRank = shape::rank(yShapeInfo);
auto zRank = shape::rank(zShapeInfo);
auto xOrder = shape::order(xShapeInfo);
auto yOrder = shape::order(yShapeInfo);
auto zOrder = shape::order(zShapeInfo);
auto xEws = shape::elementWiseStride(xShapeInfo);
auto yEws = shape::elementWiseStride(yShapeInfo);
auto zEws = shape::elementWiseStride(zShapeInfo);
if (xRank == yRank && xRank == zRank && xOrder == 'c' && yOrder == 'c' && zOrder == 'c' && xEws == 1 && yEws == 1 && zEws == 1 && xRank >= 2) {
// we validate that shapes are equal till the last dim
for (int e = 0; e < xRank - 1; e++) {
if (xShapeInfo[e+1] != yShapeInfo[e+1])
return COMMON;
}
// now, if one of the shapes has 1 as last dim
auto detect = xShapeInfo[xRank] == 1 ? -1 : (yShapeInfo[xRank] == 1) ? 1 : 0;
if (detect == 1)
return nd4j::LoopKind::BROADCAST_SCALAR_Y;
else if (detect == -1)
return nd4j::LoopKind::BROADCAST_SCALAR_X;
}
return nd4j::LoopKind::COMMON;
}
//////////////////////////////////////////////////////////////////////////////
LoopKind::Kind LoopKind::deduceKindOfLoopXYZ(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo) {

View File

@ -40,6 +40,7 @@
#endif
#include <helpers/TAD.h>
#include <helpers/LoopKind.h>
#include "legacy_ops.h"
@ -122,6 +123,7 @@ namespace functions {
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ,
nd4j::LoopKind::Kind loopKind,
uint64_t start,
uint64_t stop);
@ -149,6 +151,7 @@ namespace functions {
Nd4jLong *tadOffset,
Nd4jLong *tadShapeInfoZ,
Nd4jLong *tadOffsetZ,
nd4j::LoopKind::Kind loopKind,
uint64_t start,
uint64_t stop);

View File

@ -75,6 +75,7 @@ namespace functions {
Nd4jLong *xTadOffset,
Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffset,
nd4j::LoopKind::Kind loopKind,
uint64_t start,
uint64_t stop) {
DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x,
@ -88,7 +89,7 @@ namespace functions {
xTadShapeInfo,
xTadOffset,
zTadShapeInfo,
zTadOffset, start, stop), BROADCAST_OPS);
zTadOffset, loopKind, start, stop), BROADCAST_OPS);
}
template <typename X, typename Y, typename Z>
@ -105,6 +106,7 @@ namespace functions {
Nd4jLong *xTadOffset,
Nd4jLong *zTadShapeInfo,
Nd4jLong *zTadOffset,
nd4j::LoopKind::Kind loopKind,
uint64_t start,
uint64_t stop) {
@ -142,7 +144,7 @@ namespace functions {
auto yEws = shape::elementWiseStride(yShapeInfo);
auto zEws = shape::elementWiseStride(zTadShapeInfo);
const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo);
const nd4j::LoopKind::Kind kindOfLoop = loopKind == nd4j::LoopKind::BROADCAST_SCALAR_X || loopKind == nd4j::LoopKind::BROADCAST_SCALAR_Y ? loopKind : nd4j::LoopKind::deduceKindOfLoopXYZ(xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo);
if (kindOfLoop == nd4j::LoopKind::EWS1) {
for (auto i = start; i < stop; i++) {
@ -163,6 +165,34 @@ namespace functions {
for (unsigned int f = 0; f < tadLength; f++)
oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws]);
}
} else if(kindOfLoop == nd4j::LoopKind::BROADCAST_SCALAR_X){
// this loop effectively turns broadcast into series of scalar ops
auto loopLength = yShapeInfo[shape::rank(yShapeInfo)];
for (auto i = start; i < stop; i++) {
auto oY = y + (i * loopLength);
auto oZ = z + (i * loopLength);
const auto oX = x[i];
PRAGMA_OMP_SIMD
for (unsigned int f = 0; f < loopLength; f++)
oZ[f] = OpType::op(oX, oY[f]);
}
} else if(kindOfLoop == nd4j::LoopKind::BROADCAST_SCALAR_Y){
// this loop effectively turns broadcast into series of scalar ops
auto loopLength = xShapeInfo[shape::rank(xShapeInfo)];
for (auto i = start; i < stop; i++) {
auto oX = x + (i * loopLength);
auto oZ = z + (i * loopLength);
const auto oY = y[i];
PRAGMA_OMP_SIMD
for (unsigned int f = 0; f < loopLength; f++)
oZ[f] = OpType::op(oX[f], oY);
}
}
else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) {
uint tadShapeShapeInfoCast[MAX_RANK];

View File

@ -179,6 +179,7 @@ TEST_F(BroadcastableOpsTests, Test_Minimum_1) {
auto z = result->at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
delete result;

View File

@ -54,7 +54,7 @@ TEST_F(BroadcastMultiDimTest,MultimDimTest) {
tad->tadOnlyShapeInfo, //tadShapeInfo
tad->tadOffsets, //tadOffset
tad->tadOnlyShapeInfo, //tadShapeInfoZ
tad->tadOffsets, 0, tad->numTads); //tadOffsetZ
tad->tadOffsets, nd4j::LoopKind::COMMON, 0, tad->numTads); //tadOffsetZ
for(int i = 0; i < 30; i++) {
ASSERT_EQ(dataAssertion[i],result[i]);
}

View File

@ -149,6 +149,16 @@ TEST_F(PlaygroundTests, test_bert_1) {
delete graph;
}
TEST_F(PlaygroundTests, test_one_off_ops_1) {
auto x = NDArrayFactory::create<float>('c', {4, 128, 768});
auto y = NDArrayFactory::create<float>('c', {4, 128, 1});
auto z = x.ulike();
nd4j::ops::squaredsubtract op;
op.execute({&x, &y}, {&z});
}
/*
TEST_F(PlaygroundTests, test_broadcast_1) {