Broadcast perf improvements (#248)
* broadcast as scalar edge case  Signed-off-by: raver119 <raver119@gmail.com>
* missing return  Signed-off-by: raver119 <raver119@gmail.com>
* few fixes  Signed-off-by: raver119 <raver119@gmail.com>
* one more fix  Signed-off-by: raver119 <raver119@gmail.com>
* no need for lambdas  Signed-off-by: raver119 <raver119@gmail.com>
parent f9d51b7278
commit 2698fbf541
@@ -163,15 +163,32 @@ void NativeOpExecutioner::execBroadcast(nd4j::LaunchContext *lc,
     BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES);
 #else
+    auto loopKind = nd4j::LoopKind::deduceKindOfLoopBroadcast(hXShapeInfo, hYShapeInfo, hZShapeInfo);
+
     auto func = PRAGMA_THREADS_FOR {
-        BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES);
+        BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, loopKind, start, stop), LIBND4J_TYPES);
     };

-    auto xLen = shape::length(hXShapeInfo);
-    auto yLen = shape::length(hYShapeInfo);
-    auto numTads = xLen / yLen;
+    Nd4jLong numTads = 0;
+    switch (loopKind) {
+        case nd4j::LoopKind::BROADCAST_SCALAR_X: {
+                numTads = shape::length(hXShapeInfo);
+            }
+            break;
+        case nd4j::LoopKind::BROADCAST_SCALAR_Y: {
+                numTads = shape::length(hYShapeInfo);
+            }
+            break;
+        default: {
+                auto xLen = shape::length(hXShapeInfo);
+                auto yLen = shape::length(hYShapeInfo);
+                numTads = xLen / yLen;
+            }
+    }

     samediff::Threads::parallel_tad(func, 0, numTads);
 #endif
 }
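The switch above only changes the grain that samediff::Threads::parallel_tad splits the work into. Below is a stand-alone sketch of that arithmetic for the shapes used by the new Playground test further down (x: {4, 128, 768}, y: {4, 128, 1}); the file and the printed comparison are illustrative and not part of the commit:

#include <cstdint>
#include <cstdio>

int main() {
    const int64_t xLen = 4LL * 128 * 768;    // 393216 elements in x
    const int64_t yLen = 4LL * 128 * 1;      // 512 elements in y

    // deduceKindOfLoopBroadcast() would report BROADCAST_SCALAR_Y here,
    // since only y's last dimension is 1, so the new code parallelizes
    // over one tad per y element:
    const int64_t numTadsNew = yLen;         // 512 rows of 768 elements each

    // the old code always divided the two lengths:
    const int64_t numTadsOld = xLen / yLen;  // 768

    std::printf("new grain: %lld tads, old grain: %lld tads\n",
                (long long) numTadsNew, (long long) numTadsOld);
    return 0;
}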
@@ -37,12 +37,13 @@ namespace nd4j {
     class ND4J_EXPORT LoopKind {

         public:
-            enum Kind {SMALLARR2DX, EWS1, EWSNONZERO, RANK1, RANK2, RANK3, RANK4, RANK5, X_EWSNONZERO, Y_EWSNONZERO, Z_EWSNONZERO, COMMON};
+            enum Kind {SMALLARR2DX, EWS1, EWSNONZERO, RANK1, RANK2, RANK3, RANK4, RANK5, X_EWSNONZERO, Y_EWSNONZERO, Z_EWSNONZERO, COMMON, BROADCAST_SCALAR_X, BROADCAST_SCALAR_Y};

             static FORCEINLINE Kind deduceKindOfLoopXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo);
             static FORCEINLINE Kind deduceKindOfLoopXYZ(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo);
             static FORCEINLINE Kind deduceKindOfLoopTadXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo);
             static FORCEINLINE Kind deduceKindOfLoopTadXYZ(const Nd4jLong* xTadShapeInfo, const Nd4jLong* yTadShapeInfo, const Nd4jLong* zShapeInfo);
+            static FORCEINLINE Kind deduceKindOfLoopBroadcast(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo);
     };
@@ -82,6 +83,38 @@ LoopKind::Kind LoopKind::deduceKindOfLoopXZ(const Nd4jLong* xShapeInfo, const Nd
     return COMMON;
 }

+LoopKind::Kind LoopKind::deduceKindOfLoopBroadcast(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo) {
+    auto xRank = shape::rank(xShapeInfo);
+    auto yRank = shape::rank(yShapeInfo);
+    auto zRank = shape::rank(zShapeInfo);
+
+    auto xOrder = shape::order(xShapeInfo);
+    auto yOrder = shape::order(yShapeInfo);
+    auto zOrder = shape::order(zShapeInfo);
+
+    auto xEws = shape::elementWiseStride(xShapeInfo);
+    auto yEws = shape::elementWiseStride(yShapeInfo);
+    auto zEws = shape::elementWiseStride(zShapeInfo);
+
+    if (xRank == yRank && xRank == zRank && xOrder == 'c' && yOrder == 'c' && zOrder == 'c' && xEws == 1 && yEws == 1 && zEws == 1 && xRank >= 2) {
+        // we validate that shapes are equal till the last dim
+        for (int e = 0; e < xRank - 1; e++) {
+            if (xShapeInfo[e+1] != yShapeInfo[e+1])
+                return COMMON;
+        }
+
+        // now, if one of the shapes has 1 as last dim
+        auto detect = xShapeInfo[xRank] == 1 ? -1 : (yShapeInfo[xRank] == 1) ? 1 : 0;
+
+        if (detect == 1)
+            return nd4j::LoopKind::BROADCAST_SCALAR_Y;
+        else if (detect == -1)
+            return nd4j::LoopKind::BROADCAST_SCALAR_X;
+    }
+
+    return nd4j::LoopKind::COMMON;
+}
+
 //////////////////////////////////////////////////////////////////////////////
 LoopKind::Kind LoopKind::deduceKindOfLoopXYZ(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo) {
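For context, a shapeInfo buffer stores the rank at index 0 and the dimensions at indices 1..rank, which is why the new function reads xShapeInfo[e + 1] and xShapeInfo[xRank]. The sketch below restates the same detection rule on plain dimension vectors; it is a simplified illustration that assumes the c-order and ews == 1 preconditions already hold, and its names are hypothetical rather than library API:

#include <cstddef>
#include <vector>

enum Kind { COMMON, BROADCAST_SCALAR_X, BROADCAST_SCALAR_Y };

// assumes c order and ews == 1 were already verified, as the real code does
Kind deduceBroadcastKind(const std::vector<long long>& xShape,
                         const std::vector<long long>& yShape) {
    if (xShape.size() != yShape.size() || xShape.size() < 2)
        return COMMON;

    // every leading dimension must match exactly
    for (size_t e = 0; e + 1 < xShape.size(); e++)
        if (xShape[e] != yShape[e])
            return COMMON;

    // the operand whose last dimension is 1 acts as a per-row scalar
    if (xShape.back() == 1)
        return BROADCAST_SCALAR_X;
    if (yShape.back() == 1)
        return BROADCAST_SCALAR_Y;

    return COMMON;
}

// deduceBroadcastKind({4, 128, 768}, {4, 128, 1}) -> BROADCAST_SCALAR_Y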
@@ -40,6 +40,7 @@
 #endif

 #include <helpers/TAD.h>
+#include <helpers/LoopKind.h>

 #include "legacy_ops.h"
@@ -122,6 +123,7 @@ namespace functions {
                              Nd4jLong *tadOffset,
                              Nd4jLong *tadShapeInfoZ,
                              Nd4jLong *tadOffsetZ,
+                             nd4j::LoopKind::Kind loopKind,
                              uint64_t start,
                              uint64_t stop);
@@ -149,6 +151,7 @@ namespace functions {
                              Nd4jLong *tadOffset,
                              Nd4jLong *tadShapeInfoZ,
                              Nd4jLong *tadOffsetZ,
+                             nd4j::LoopKind::Kind loopKind,
                              uint64_t start,
                              uint64_t stop);
@@ -75,6 +75,7 @@ namespace functions {
                              Nd4jLong *xTadOffset,
                              Nd4jLong *zTadShapeInfo,
                              Nd4jLong *zTadOffset,
+                             nd4j::LoopKind::Kind loopKind,
                              uint64_t start,
                              uint64_t stop) {
         DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x,
@@ -88,7 +89,7 @@ namespace functions {
                                             xTadShapeInfo,
                                             xTadOffset,
                                             zTadShapeInfo,
-                                            zTadOffset, start, stop), BROADCAST_OPS);
+                                            zTadOffset, loopKind, start, stop), BROADCAST_OPS);
     }

     template <typename X, typename Y, typename Z>
@@ -105,6 +106,7 @@ namespace functions {
                              Nd4jLong *xTadOffset,
                              Nd4jLong *zTadShapeInfo,
                              Nd4jLong *zTadOffset,
+                             nd4j::LoopKind::Kind loopKind,
                              uint64_t start,
                              uint64_t stop) {
@@ -142,7 +144,7 @@ namespace functions {
         auto yEws = shape::elementWiseStride(yShapeInfo);
         auto zEws = shape::elementWiseStride(zTadShapeInfo);

-        const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo);
+        const nd4j::LoopKind::Kind kindOfLoop = loopKind == nd4j::LoopKind::BROADCAST_SCALAR_X || loopKind == nd4j::LoopKind::BROADCAST_SCALAR_Y ? loopKind : nd4j::LoopKind::deduceKindOfLoopXYZ(xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo);

         if (kindOfLoop == nd4j::LoopKind::EWS1) {
             for (auto i = start; i < stop; i++) {
@@ -163,6 +165,34 @@ namespace functions {
                 for (unsigned int f = 0; f < tadLength; f++)
                     oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws]);
             }
+        } else if(kindOfLoop == nd4j::LoopKind::BROADCAST_SCALAR_X){
+            // this loop effectively turns broadcast into series of scalar ops
+            auto loopLength = yShapeInfo[shape::rank(yShapeInfo)];
+
+            for (auto i = start; i < stop; i++) {
+                auto oY = y + (i * loopLength);
+                auto oZ = z + (i * loopLength);
+
+                const auto oX = x[i];
+
+                PRAGMA_OMP_SIMD
+                for (unsigned int f = 0; f < loopLength; f++)
+                    oZ[f] = OpType::op(oX, oY[f]);
+            }
+        } else if(kindOfLoop == nd4j::LoopKind::BROADCAST_SCALAR_Y){
+            // this loop effectively turns broadcast into series of scalar ops
+            auto loopLength = xShapeInfo[shape::rank(xShapeInfo)];
+
+            for (auto i = start; i < stop; i++) {
+                auto oX = x + (i * loopLength);
+                auto oZ = z + (i * loopLength);
+
+                const auto oY = y[i];
+
+                PRAGMA_OMP_SIMD
+                for (unsigned int f = 0; f < loopLength; f++)
+                    oZ[f] = OpType::op(oX[f], oY);
+            }
         }
         else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) {
             uint tadShapeShapeInfoCast[MAX_RANK];
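Stripped of the templates and the DISPATCH_BY_OPNUM_TTT machinery, the BROADCAST_SCALAR_Y branch added above boils down to the loop below. This is an illustrative float-only sketch with hypothetical names, not the library implementation; loopLength stands for the last dimension of x and numTads for the count chosen earlier in execBroadcast:

#include <cstddef>

// stand-in for OpType::op, here a squared difference
static inline float op(float a, float b) { float d = a - b; return d * d; }

void broadcastScalarY(const float* x, const float* y, float* z,
                      size_t numTads, size_t loopLength,
                      size_t start, size_t stop) {
    for (size_t i = start; i < stop && i < numTads; i++) {
        const float* oX = x + i * loopLength; // i-th row of x
        float*       oZ = z + i * loopLength; // i-th row of z
        const float  oY = y[i];               // single y value for this row

        // plain scalar op per element, trivially vectorizable
        for (size_t f = 0; f < loopLength; f++)
            oZ[f] = op(oX[f], oY);
    }
}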
@@ -179,6 +179,7 @@ TEST_F(BroadcastableOpsTests, Test_Minimum_1) {
     auto z = result->at(0);

     ASSERT_TRUE(exp.isSameShape(z));
     ASSERT_TRUE(exp.equalsTo(z));

     delete result;
@@ -54,7 +54,7 @@ TEST_F(BroadcastMultiDimTest,MultimDimTest) {
                 tad->tadOnlyShapeInfo, //tadShapeInfo
                 tad->tadOffsets, //tadOffset
                 tad->tadOnlyShapeInfo, //tadShapeInfoZ
-                tad->tadOffsets, 0, tad->numTads); //tadOffsetZ
+                tad->tadOffsets, nd4j::LoopKind::COMMON, 0, tad->numTads); //tadOffsetZ
     for(int i = 0; i < 30; i++) {
         ASSERT_EQ(dataAssertion[i],result[i]);
     }
@@ -149,6 +149,16 @@ TEST_F(PlaygroundTests, test_bert_1) {
     delete graph;
 }

+TEST_F(PlaygroundTests, test_one_off_ops_1) {
+    auto x = NDArrayFactory::create<float>('c', {4, 128, 768});
+    auto y = NDArrayFactory::create<float>('c', {4, 128, 1});
+    auto z = x.ulike();
+
+    nd4j::ops::squaredsubtract op;
+    op.execute({&x, &y}, {&z});
+}
+
 /*
 TEST_F(PlaygroundTests, test_broadcast_1) {