Broadcast perf improvements (#248)
* broadcast as scalar edge case
* missing return
* few fixes
* one more fix
* no need for lambdas

Signed-off-by: raver119 <raver119@gmail.com>
parent f9d51b7278
commit 2698fbf541
@@ -163,15 +163,32 @@ void NativeOpExecutioner::execBroadcast(nd4j::LaunchContext *lc,
     BUILD_PAIRWISE_SELECTOR(xType, yType, zType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ), LIBND4J_TYPES, LIBND4J_TYPES);
 #else
+    auto loopKind = nd4j::LoopKind::deduceKindOfLoopBroadcast(hXShapeInfo, hYShapeInfo, hZShapeInfo);
+
     auto func = PRAGMA_THREADS_FOR {
-        BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, start, stop), LIBND4J_TYPES);
+        BUILD_SINGLE_SELECTOR_THRICE(xType, functions::broadcast::Broadcast, ::exec(opNum, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, dimension, dimensionLength, tadOnlyShapeInfo, tadOffsets, tadOnlyShapeInfoZ, tadOffsetsZ, loopKind, start, stop), LIBND4J_TYPES);
     };

-    auto xLen = shape::length(hXShapeInfo);
-    auto yLen = shape::length(hYShapeInfo);
-    auto numTads = xLen / yLen;
+    Nd4jLong numTads = 0;
+
+    switch (loopKind) {
+        case nd4j::LoopKind::BROADCAST_SCALAR_X: {
+                numTads = shape::length(hXShapeInfo);
+            }
+            break;
+        case nd4j::LoopKind::BROADCAST_SCALAR_Y: {
+                numTads = shape::length(hYShapeInfo);
+            }
+            break;
+        default: {
+                auto xLen = shape::length(hXShapeInfo);
+                auto yLen = shape::length(hYShapeInfo);
+                numTads = xLen / yLen;
+            }
+    }

     samediff::Threads::parallel_tad(func, 0, numTads);
 #endif
 }
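For context, a minimal standalone sketch (not the libnd4j API itself) of how the tad count above feeds samediff::Threads::parallel_tad: in the BROADCAST_SCALAR_* cases every scalar of the degenerate operand becomes its own tad, otherwise the classic xLen / yLen count is used. The helper name pickNumTads and the plain integer types are illustrative assumptions.

    #include <cstdint>

    // Illustrative only: mirrors the switch above on plain integers.
    enum class LoopKind { COMMON, BROADCAST_SCALAR_X, BROADCAST_SCALAR_Y };

    int64_t pickNumTads(LoopKind kind, int64_t xLen, int64_t yLen) {
        switch (kind) {
            case LoopKind::BROADCAST_SCALAR_X: return xLen;        // one tad per element of X
            case LoopKind::BROADCAST_SCALAR_Y: return yLen;        // one tad per element of Y
            default:                           return xLen / yLen; // classic TAD count
        }
    }

    // e.g. for X {4, 128, 768} and Y {4, 128, 1}: BROADCAST_SCALAR_Y, 4 * 128 = 512 tads.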
@@ -37,12 +37,13 @@ namespace nd4j {

 class ND4J_EXPORT LoopKind {

     public:
-        enum Kind {SMALLARR2DX, EWS1, EWSNONZERO, RANK1, RANK2, RANK3, RANK4, RANK5, X_EWSNONZERO, Y_EWSNONZERO, Z_EWSNONZERO, COMMON};
+        enum Kind {SMALLARR2DX, EWS1, EWSNONZERO, RANK1, RANK2, RANK3, RANK4, RANK5, X_EWSNONZERO, Y_EWSNONZERO, Z_EWSNONZERO, COMMON, BROADCAST_SCALAR_X, BROADCAST_SCALAR_Y};

         static FORCEINLINE Kind deduceKindOfLoopXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo);
         static FORCEINLINE Kind deduceKindOfLoopXYZ(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo);
         static FORCEINLINE Kind deduceKindOfLoopTadXZ(const Nd4jLong* xShapeInfo, const Nd4jLong* zShapeInfo, const Nd4jLong* tadShapeInfo);
         static FORCEINLINE Kind deduceKindOfLoopTadXYZ(const Nd4jLong* xTadShapeInfo, const Nd4jLong* yTadShapeInfo, const Nd4jLong* zShapeInfo);
+        static FORCEINLINE Kind deduceKindOfLoopBroadcast(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo);

 };
@@ -82,6 +83,38 @@ LoopKind::Kind LoopKind::deduceKindOfLoopXZ(const Nd4jLong* xShapeInfo, const Nd
     return COMMON;
 }

+LoopKind::Kind LoopKind::deduceKindOfLoopBroadcast(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo) {
+    auto xRank = shape::rank(xShapeInfo);
+    auto yRank = shape::rank(yShapeInfo);
+    auto zRank = shape::rank(zShapeInfo);
+
+    auto xOrder = shape::order(xShapeInfo);
+    auto yOrder = shape::order(yShapeInfo);
+    auto zOrder = shape::order(zShapeInfo);
+
+    auto xEws = shape::elementWiseStride(xShapeInfo);
+    auto yEws = shape::elementWiseStride(yShapeInfo);
+    auto zEws = shape::elementWiseStride(zShapeInfo);
+
+    if (xRank == yRank && xRank == zRank && xOrder == 'c' && yOrder == 'c' && zOrder == 'c' && xEws == 1 && yEws == 1 && zEws == 1 && xRank >= 2) {
+        // validate that the shapes are equal up to the last dimension
+        for (int e = 0; e < xRank - 1; e++) {
+            if (xShapeInfo[e+1] != yShapeInfo[e+1])
+                return COMMON;
+        }
+
+        // now check whether one of the shapes has 1 as its last dimension
+        auto detect = xShapeInfo[xRank] == 1 ? -1 : (yShapeInfo[xRank] == 1) ? 1 : 0;
+
+        if (detect == 1)
+            return nd4j::LoopKind::BROADCAST_SCALAR_Y;
+        else if (detect == -1)
+            return nd4j::LoopKind::BROADCAST_SCALAR_X;
+    }
+
+    return nd4j::LoopKind::COMMON;
+}
+
 //////////////////////////////////////////////////////////////////////////////
 LoopKind::Kind LoopKind::deduceKindOfLoopXYZ(const Nd4jLong* xShapeInfo, const Nd4jLong* yShapeInfo, const Nd4jLong* zShapeInfo) {
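As a hedged, standalone restatement of the detection rule (using std::vector shapes instead of Nd4jLong shapeInfo buffers, ignoring the order and ews checks performed above, and with the hypothetical helper name detect): both operands must match on every dimension except the last, and exactly one of them must have 1 as its last dimension; anything else falls back to COMMON.

    #include <vector>

    enum class Kind { COMMON, BROADCAST_SCALAR_X, BROADCAST_SCALAR_Y };

    Kind detect(const std::vector<long long>& xShape, const std::vector<long long>& yShape) {
        if (xShape.size() != yShape.size() || xShape.size() < 2)
            return Kind::COMMON;

        // shapes must agree on every dimension except the last
        for (size_t e = 0; e + 1 < xShape.size(); e++)
            if (xShape[e] != yShape[e])
                return Kind::COMMON;

        if (xShape.back() == 1) return Kind::BROADCAST_SCALAR_X; // X provides the scalars
        if (yShape.back() == 1) return Kind::BROADCAST_SCALAR_Y; // Y provides the scalars
        return Kind::COMMON;
    }

    // detect({4, 128, 768}, {4, 128, 1}) yields Kind::BROADCAST_SCALAR_Y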
@@ -40,6 +40,7 @@
 #endif

 #include <helpers/TAD.h>
+#include <helpers/LoopKind.h>

 #include "legacy_ops.h"
@@ -122,6 +123,7 @@ namespace functions {
                     Nd4jLong *tadOffset,
                     Nd4jLong *tadShapeInfoZ,
                     Nd4jLong *tadOffsetZ,
+                    nd4j::LoopKind::Kind loopKind,
                     uint64_t start,
                     uint64_t stop);
@@ -149,6 +151,7 @@ namespace functions {
                     Nd4jLong *tadOffset,
                     Nd4jLong *tadShapeInfoZ,
                     Nd4jLong *tadOffsetZ,
+                    nd4j::LoopKind::Kind loopKind,
                     uint64_t start,
                     uint64_t stop);
@@ -75,6 +75,7 @@ namespace functions {
                     Nd4jLong *xTadOffset,
                     Nd4jLong *zTadShapeInfo,
                     Nd4jLong *zTadOffset,
+                    nd4j::LoopKind::Kind loopKind,
                     uint64_t start,
                     uint64_t stop) {
         DISPATCH_BY_OPNUM_TTT(exec, PARAMS(x,
@@ -88,7 +89,7 @@ namespace functions {
                     xTadShapeInfo,
                     xTadOffset,
                     zTadShapeInfo,
-                    zTadOffset, start, stop), BROADCAST_OPS);
+                    zTadOffset, loopKind, start, stop), BROADCAST_OPS);
     }

     template <typename X, typename Y, typename Z>
@@ -105,6 +106,7 @@ namespace functions {
                     Nd4jLong *xTadOffset,
                     Nd4jLong *zTadShapeInfo,
                     Nd4jLong *zTadOffset,
+                    nd4j::LoopKind::Kind loopKind,
                     uint64_t start,
                     uint64_t stop) {
@@ -142,7 +144,7 @@ namespace functions {
         auto yEws = shape::elementWiseStride(yShapeInfo);
         auto zEws = shape::elementWiseStride(zTadShapeInfo);

-        const nd4j::LoopKind::Kind kindOfLoop = nd4j::LoopKind::deduceKindOfLoopXYZ(xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo);
+        const nd4j::LoopKind::Kind kindOfLoop = loopKind == nd4j::LoopKind::BROADCAST_SCALAR_X || loopKind == nd4j::LoopKind::BROADCAST_SCALAR_Y ? loopKind : nd4j::LoopKind::deduceKindOfLoopXYZ(xTadShapeShapeInfo, yShapeInfo, zTadShapeInfo);

         if (kindOfLoop == nd4j::LoopKind::EWS1) {
             for (auto i = start; i < stop; i++) {
@@ -163,6 +165,34 @@ namespace functions {
                 for (unsigned int f = 0; f < tadLength; f++)
                     oZ[f * zEws] = OpType::op(oX[f * xEws], y[f * yEws]);
             }
+        } else if (kindOfLoop == nd4j::LoopKind::BROADCAST_SCALAR_X) {
+            // this loop effectively turns the broadcast into a series of scalar ops
+            auto loopLength = yShapeInfo[shape::rank(yShapeInfo)];
+
+            for (auto i = start; i < stop; i++) {
+                auto oY = y + (i * loopLength);
+                auto oZ = z + (i * loopLength);
+
+                const auto oX = x[i];
+
+                PRAGMA_OMP_SIMD
+                for (unsigned int f = 0; f < loopLength; f++)
+                    oZ[f] = OpType::op(oX, oY[f]);
+            }
+        } else if (kindOfLoop == nd4j::LoopKind::BROADCAST_SCALAR_Y) {
+            // this loop effectively turns the broadcast into a series of scalar ops
+            auto loopLength = xShapeInfo[shape::rank(xShapeInfo)];
+
+            for (auto i = start; i < stop; i++) {
+                auto oX = x + (i * loopLength);
+                auto oZ = z + (i * loopLength);
+
+                const auto oY = y[i];
+
+                PRAGMA_OMP_SIMD
+                for (unsigned int f = 0; f < loopLength; f++)
+                    oZ[f] = OpType::op(oX[f], oY);
+            }
         }
         else if(shape::haveSameShapeAndStrides(xTadShapeShapeInfo, yShapeInfo) && shape::haveSameShapeAndStrides(xTadShapeShapeInfo, zTadShapeInfo)) {
             uint tadShapeShapeInfoCast[MAX_RANK];
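A minimal sketch of what the BROADCAST_SCALAR_Y branch does once the templated machinery is stripped away, on plain float buffers and with subtraction standing in for OpType::op (the function name and signature are assumptions, not the libnd4j kernel):

    #include <cstdint>

    void broadcastScalarY(const float* x, const float* y, float* z,
                          int64_t numTads, int64_t loopLength) {
        for (int64_t i = 0; i < numTads; i++) {
            const float* oX = x + i * loopLength; // i-th row of X
            float*       oZ = z + i * loopLength; // i-th row of Z
            const float  oY = y[i];               // single scalar taken from Y

            // the broadcast degenerates into a scalar op over one contiguous row
            for (int64_t f = 0; f < loopLength; f++)
                oZ[f] = oX[f] - oY;
        }
    }

The BROADCAST_SCALAR_X branch is the mirror image: the scalar comes from X and the row comes from Y.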
@@ -179,6 +179,7 @@ TEST_F(BroadcastableOpsTests, Test_Minimum_1) {
     auto z = result->at(0);

     ASSERT_TRUE(exp.isSameShape(z));

     ASSERT_TRUE(exp.equalsTo(z));

     delete result;
@@ -54,7 +54,7 @@ TEST_F(BroadcastMultiDimTest,MultimDimTest) {
             tad->tadOnlyShapeInfo, //tadShapeInfo
             tad->tadOffsets, //tadOffset
             tad->tadOnlyShapeInfo, //tadShapeInfoZ
-            tad->tadOffsets, 0, tad->numTads); //tadOffsetZ
+            tad->tadOffsets, nd4j::LoopKind::COMMON, 0, tad->numTads); //tadOffsetZ
     for(int i = 0; i < 30; i++) {
         ASSERT_EQ(dataAssertion[i],result[i]);
     }
@@ -149,6 +149,16 @@ TEST_F(PlaygroundTests, test_bert_1) {
     delete graph;
 }

+TEST_F(PlaygroundTests, test_one_off_ops_1) {
+    auto x = NDArrayFactory::create<float>('c', {4, 128, 768});
+    auto y = NDArrayFactory::create<float>('c', {4, 128, 1});
+    auto z = x.ulike();
+
+    nd4j::ops::squaredsubtract op;
+    op.execute({&x, &y}, {&z});
+}
+
 /*

 TEST_F(PlaygroundTests, test_broadcast_1) {
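With those shapes the new test should exercise the BROADCAST_SCALAR_Y path: 4 * 128 = 512 rows of 768 elements, each paired with one scalar from y. A reference sketch of the expected result, assuming squaredsubtract computes (x - y)^2 elementwise (that semantic is an assumption, as are the helper name and the flat-buffer layout):

    #include <cstddef>
    #include <vector>

    std::vector<float> squaredSubtractRef(const std::vector<float>& x, // rows * cols values
                                          const std::vector<float>& y, // rows values
                                          std::size_t rows, std::size_t cols) {
        std::vector<float> z(rows * cols);
        for (std::size_t i = 0; i < rows; i++)
            for (std::size_t f = 0; f < cols; f++) {
                const float d = x[i * cols + f] - y[i]; // one scalar of y per row of x
                z[i * cols + f] = d * d;
            }
        return z;
    }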