Oleh bert multiply true broad cast (#239)
* libnd4j trueBroadcast rank 3 row implementation of special case Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j rule clarify for second special case for all tests pass * libnd4j parallel_tad loop switch on in special case * libnd4j more general case for special case 2, need additional testing Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j more general case for trueBroadcast special cases added * libnd4j minor corrections and clean up * libnd4j one more minor fix Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j fixed check point to support all Y common vector representations in first special case for trueBroadcast Signed-off-by: Oleg <oleg.semeniv@gmail.com> Co-authored-by: raver119 <raver119@gmail.com>master
parent
4206171b70
commit
6e6289b6b9
|
@ -14,9 +14,9 @@
|
|||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||
//
|
||||
//
|
||||
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||
//
|
||||
|
||||
#include <loops/TrueBroadcastHelper.h>
|
||||
#include <ops/ops.h>
|
||||
|
@ -24,226 +24,268 @@
|
|||
|
||||
using namespace simdOps;
|
||||
|
||||
namespace nd4j {
|
||||
namespace helpers {
|
||||
namespace nd4j {
|
||||
namespace helpers {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X, typename Y, typename Z>
|
||||
template<typename OpType>
|
||||
void TrueBroadcastHelper<X, Y, Z>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X, typename Y, typename Z>
|
||||
template<typename OpType>
|
||||
void TrueBroadcastHelper<X, Y, Z>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||
|
||||
|
||||
const X* x = reinterpret_cast<X*>(xArr.getBuffer());
|
||||
const Y* y = reinterpret_cast<Y*>(yArr.getBuffer());
|
||||
Z* z = reinterpret_cast<Z*>(zArr.getBuffer());
|
||||
const X* x = reinterpret_cast<X*>(xArr.getBuffer());
|
||||
const Y* y = reinterpret_cast<Y*>(yArr.getBuffer());
|
||||
Z* z = reinterpret_cast<Z*>(zArr.getBuffer());
|
||||
|
||||
const auto xShapeInfo = xArr.getShapeInfo();
|
||||
const auto yShapeInfo = yArr.getShapeInfo();
|
||||
const auto zShapeInfo = zArr.getShapeInfo();
|
||||
const auto xShapeInfo = xArr.getShapeInfo();
|
||||
const auto yShapeInfo = yArr.getShapeInfo();
|
||||
const auto zShapeInfo = zArr.getShapeInfo();
|
||||
|
||||
const int xRank = xArr.rankOf();
|
||||
const int yRank = yArr.rankOf();
|
||||
const int zRank = zArr.rankOf();
|
||||
const int xRank = xArr.rankOf();
|
||||
const int yRank = yArr.rankOf();
|
||||
const int zRank = zArr.rankOf();
|
||||
|
||||
bool bSpecialCase = (1 == xArr.ews() && 'c' == xArr.ordering() && 1 == yRank &&
|
||||
1 == yArr.ews() && 'c' == yArr.ordering() &&
|
||||
1 == zArr.ews() && 'c' == zArr.ordering());
|
||||
bool bSpecialCase = (1 == xArr.ews() && 'c' == xArr.ordering() &&
|
||||
1 == yArr.ews() && 'c' == yArr.ordering() &&
|
||||
1 == zArr.ews() && 'c' == zArr.ordering());
|
||||
|
||||
if (bSpecialCase) {
|
||||
auto yLen = (uint32_t)yArr.lengthOf();
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
for (uint32_t i = start; i < stop; i++) {
|
||||
auto rZ = z + (i * yLen);
|
||||
auto v = x[i];
|
||||
for (uint32_t j = 0; j < yLen; j++) {
|
||||
rZ[j] = OpType::op(v, y[j]);
|
||||
}
|
||||
}
|
||||
};
|
||||
samediff::Threads::parallel_tad(func, 0, xArr.lengthOf());
|
||||
return;
|
||||
if (bSpecialCase && yArr.isColumnVector() && 1 == xArr.sizeAt(-1) ) {
|
||||
auto yLen = (uint32_t)yArr.lengthOf();
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
for (uint32_t i = start; i < stop; i++) {
|
||||
auto rZ = z + (i * yLen);
|
||||
auto v = x[i];
|
||||
for (uint32_t j = 0; j < yLen; j++) {
|
||||
rZ[j] = OpType::op(v, y[j]);
|
||||
}
|
||||
}
|
||||
};
|
||||
samediff::Threads::parallel_tad(func, 0, xArr.lengthOf());
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
auto yShapeInt = yArr.getShapeAsVectorInt();
|
||||
auto xShapeInt = xArr.getShapeAsVectorInt();
|
||||
auto nCountY = std::count_if(yShapeInt.cbegin(), yShapeInt.cend(), [](int i) { return i == 1; });
|
||||
auto nCountX = std::count_if(xShapeInt.cbegin(), xShapeInt.cend(), [](int i) { return i == 1; });
|
||||
|
||||
bool bSpecialCase2 = (xRank == zRank && yRank == zRank && 1 == xArr.sizeAt(-1) && 1 == yArr.sizeAt(-2) && 1 == nCountY && 1 == nCountX);
|
||||
|
||||
if (bSpecialCase && bSpecialCase2) {
|
||||
|
||||
int zDim1 = zArr.sizeAt(-2);
|
||||
int zDim2 = zArr.sizeAt(-1);
|
||||
|
||||
int nLen = zArr.lengthOf() / yArr.sizeAt(-1);
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
for (uint32_t total = start; total < stop; total += increment) {
|
||||
|
||||
uint32_t i = total / zDim1;
|
||||
uint32_t j = total % zDim1;
|
||||
|
||||
uint32_t index = (i * zDim1) + j;
|
||||
auto rZ = z + (index * zDim2);
|
||||
auto rY = y + (i * zDim2);
|
||||
auto rX = x[index];
|
||||
|
||||
for (uint32_t n = 0; n < zDim2; n++) {
|
||||
rZ[n] = OpType::op(rX, rY[n]);
|
||||
}
|
||||
}
|
||||
};
|
||||
samediff::Threads::parallel_tad(func, 0, nLen, 1);
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
const Nd4jLong zLen = zArr.lengthOf();
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
||||
|
||||
for (auto i = start; i < stop; ++i) {
|
||||
|
||||
shape::index2coords(i, zShapeInfo, zCoords.data());
|
||||
|
||||
for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
|
||||
|
||||
if (ix >= 0) {
|
||||
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
|
||||
xCoords[ix--] = zCoords[iz];
|
||||
}
|
||||
else {
|
||||
xCoords[ix--] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
if (iy >= 0) {
|
||||
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
|
||||
yCoords[iy--] = zCoords[iz];
|
||||
}
|
||||
else {
|
||||
yCoords[iy--] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
|
||||
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
|
||||
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
|
||||
|
||||
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
|
||||
}
|
||||
};
|
||||
|
||||
samediff::Threads::parallel_for(func, 0, zLen);
|
||||
}
|
||||
|
||||
template <typename X, typename Y, typename Z>
|
||||
void TrueBroadcastHelper<X, Y, Z>::exec(const nd4j::broadcast::Ops opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||
DISPATCH_BY_OPNUM_TTT(exec, PARAMS(xArr, yArr, zArr), BROADCAST_OPS);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X, typename Z>
|
||||
template<typename OpType>
|
||||
void TrueBroadcastBoolHelper<X, Z>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||
|
||||
const X* x = reinterpret_cast<X*>(xArr.getBuffer());
|
||||
const X* y = reinterpret_cast<X*>(yArr.getBuffer());
|
||||
Z* z = reinterpret_cast<Z*>(zArr.getBuffer());
|
||||
|
||||
const auto xShapeInfo = xArr.getShapeInfo();
|
||||
const auto yShapeInfo = yArr.getShapeInfo();
|
||||
const auto zShapeInfo = zArr.getShapeInfo();
|
||||
|
||||
const int xRank = xArr.rankOf();
|
||||
const int yRank = yArr.rankOf();
|
||||
const int zRank = zArr.rankOf();
|
||||
|
||||
const Nd4jLong zLen = zArr.lengthOf();
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
||||
|
||||
for (auto i = start; i < stop; ++i) {
|
||||
|
||||
shape::index2coords(i, zShapeInfo, zCoords.data());
|
||||
|
||||
for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
|
||||
|
||||
if (ix >= 0) {
|
||||
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
|
||||
xCoords[ix--] = zCoords[iz];
|
||||
}
|
||||
else {
|
||||
xCoords[ix--] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
if (iy >= 0) {
|
||||
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
|
||||
yCoords[iy--] = zCoords[iz];
|
||||
}
|
||||
else {
|
||||
yCoords[iy--] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
|
||||
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
|
||||
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
|
||||
|
||||
z[zOffset] = OpType::op(x[xOffset], y[yOffset], nullptr);
|
||||
}
|
||||
};
|
||||
|
||||
samediff::Threads::parallel_for(func, 0, zLen);
|
||||
}
|
||||
|
||||
template <typename X, typename Y>
|
||||
void TrueBroadcastBoolHelper<X, Y>::exec(const nd4j::broadcast::BoolOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||
DISPATCH_BY_OPNUM_TT(exec, PARAMS(xArr, yArr, zArr), BROADCAST_BOOL_OPS);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X>
|
||||
template<typename OpType>
|
||||
void TrueBroadcastIntHelper<X>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||
|
||||
const X* x = reinterpret_cast<X*>(xArr.getBuffer());
|
||||
const X* y = reinterpret_cast<X*>(yArr.getBuffer());
|
||||
X* z = reinterpret_cast<X*>(zArr.getBuffer());
|
||||
|
||||
const auto xShapeInfo = xArr.getShapeInfo();
|
||||
const auto yShapeInfo = yArr.getShapeInfo();
|
||||
const auto zShapeInfo = zArr.getShapeInfo();
|
||||
|
||||
const int xRank = xArr.rankOf();
|
||||
const int yRank = yArr.rankOf();
|
||||
const int zRank = zArr.rankOf();
|
||||
|
||||
const Nd4jLong zLen = zArr.lengthOf();
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
||||
|
||||
for (auto i = start; i < stop; ++i) {
|
||||
|
||||
shape::index2coords(i, zShapeInfo, zCoords.data());
|
||||
|
||||
for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
|
||||
|
||||
if (ix >= 0) {
|
||||
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
|
||||
xCoords[ix--] = zCoords[iz];
|
||||
}
|
||||
else {
|
||||
xCoords[ix--] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
if (iy >= 0) {
|
||||
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
|
||||
yCoords[iy--] = zCoords[iz];
|
||||
}
|
||||
else {
|
||||
yCoords[iy--] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
|
||||
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
|
||||
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
|
||||
|
||||
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
|
||||
}
|
||||
};
|
||||
|
||||
samediff::Threads::parallel_for(func, 0, zLen);
|
||||
}
|
||||
|
||||
template <typename X>
|
||||
void TrueBroadcastIntHelper<X>::exec(const nd4j::broadcast::IntOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||
DISPATCH_BY_OPNUM_T(exec, PARAMS(xArr, yArr, zArr), BROADCAST_INT_OPS);
|
||||
}
|
||||
|
||||
/*
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_0);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_1);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_2);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_3);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_4);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_5);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_6);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_7);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_8);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_9);
|
||||
|
||||
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastBoolHelper, , LIBND4J_TYPES, BOOL_TYPES);
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastIntHelper, , INTEGER_TYPES);
|
||||
*/
|
||||
}
|
||||
|
||||
const Nd4jLong zLen = zArr.lengthOf();
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
||||
|
||||
for (auto i = start; i < stop; ++i) {
|
||||
|
||||
shape::index2coords(i, zShapeInfo, zCoords.data());
|
||||
|
||||
for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
|
||||
|
||||
if (ix >= 0) {
|
||||
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
|
||||
xCoords[ix--] = zCoords[iz];
|
||||
} else {
|
||||
xCoords[ix--] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
if (iy >= 0) {
|
||||
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
|
||||
yCoords[iy--] = zCoords[iz];
|
||||
} else {
|
||||
yCoords[iy--] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
|
||||
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
|
||||
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
|
||||
|
||||
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
|
||||
}
|
||||
};
|
||||
|
||||
samediff::Threads::parallel_for(func, 0, zLen);
|
||||
}
|
||||
|
||||
template <typename X, typename Y, typename Z>
|
||||
void TrueBroadcastHelper<X, Y, Z>::exec(const nd4j::broadcast::Ops opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||
DISPATCH_BY_OPNUM_TTT(exec, PARAMS(xArr, yArr, zArr), BROADCAST_OPS);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X, typename Z>
|
||||
template<typename OpType>
|
||||
void TrueBroadcastBoolHelper<X, Z>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||
|
||||
const X* x = reinterpret_cast<X*>(xArr.getBuffer());
|
||||
const X* y = reinterpret_cast<X*>(yArr.getBuffer());
|
||||
Z* z = reinterpret_cast<Z*>(zArr.getBuffer());
|
||||
|
||||
const auto xShapeInfo = xArr.getShapeInfo();
|
||||
const auto yShapeInfo = yArr.getShapeInfo();
|
||||
const auto zShapeInfo = zArr.getShapeInfo();
|
||||
|
||||
const int xRank = xArr.rankOf();
|
||||
const int yRank = yArr.rankOf();
|
||||
const int zRank = zArr.rankOf();
|
||||
|
||||
const Nd4jLong zLen = zArr.lengthOf();
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
||||
|
||||
for (auto i = start; i < stop; ++i) {
|
||||
|
||||
shape::index2coords(i, zShapeInfo, zCoords.data());
|
||||
|
||||
for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
|
||||
|
||||
if (ix >= 0) {
|
||||
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
|
||||
xCoords[ix--] = zCoords[iz];
|
||||
} else {
|
||||
xCoords[ix--] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
if (iy >= 0) {
|
||||
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
|
||||
yCoords[iy--] = zCoords[iz];
|
||||
} else {
|
||||
yCoords[iy--] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
|
||||
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
|
||||
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
|
||||
|
||||
z[zOffset] = OpType::op(x[xOffset], y[yOffset], nullptr);
|
||||
}
|
||||
};
|
||||
|
||||
samediff::Threads::parallel_for(func, 0, zLen);
|
||||
}
|
||||
|
||||
template <typename X, typename Y>
|
||||
void TrueBroadcastBoolHelper<X, Y>::exec(const nd4j::broadcast::BoolOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||
DISPATCH_BY_OPNUM_TT(exec, PARAMS(xArr, yArr, zArr), BROADCAST_BOOL_OPS);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
template <typename X>
|
||||
template<typename OpType>
|
||||
void TrueBroadcastIntHelper<X>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||
|
||||
const X* x = reinterpret_cast<X*>(xArr.getBuffer());
|
||||
const X* y = reinterpret_cast<X*>(yArr.getBuffer());
|
||||
X* z = reinterpret_cast<X*>(zArr.getBuffer());
|
||||
|
||||
const auto xShapeInfo = xArr.getShapeInfo();
|
||||
const auto yShapeInfo = yArr.getShapeInfo();
|
||||
const auto zShapeInfo = zArr.getShapeInfo();
|
||||
|
||||
const int xRank = xArr.rankOf();
|
||||
const int yRank = yArr.rankOf();
|
||||
const int zRank = zArr.rankOf();
|
||||
|
||||
const Nd4jLong zLen = zArr.lengthOf();
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
||||
|
||||
for (auto i = start; i < stop; ++i) {
|
||||
|
||||
shape::index2coords(i, zShapeInfo, zCoords.data());
|
||||
|
||||
for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
|
||||
|
||||
if (ix >= 0) {
|
||||
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
|
||||
xCoords[ix--] = zCoords[iz];
|
||||
} else {
|
||||
xCoords[ix--] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
if (iy >= 0) {
|
||||
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
|
||||
yCoords[iy--] = zCoords[iz];
|
||||
} else {
|
||||
yCoords[iy--] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
|
||||
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
|
||||
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
|
||||
|
||||
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
|
||||
}
|
||||
};
|
||||
|
||||
samediff::Threads::parallel_for(func, 0, zLen);
|
||||
}
|
||||
|
||||
template <typename X>
|
||||
void TrueBroadcastIntHelper<X>::exec(const nd4j::broadcast::IntOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||
DISPATCH_BY_OPNUM_T(exec, PARAMS(xArr, yArr, zArr), BROADCAST_INT_OPS);
|
||||
}
|
||||
|
||||
/*
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_0);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_1);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_2);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_3);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_4);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_5);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_6);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_7);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_8);
|
||||
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_9);
|
||||
|
||||
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastBoolHelper, , LIBND4J_TYPES, BOOL_TYPES);
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastIntHelper, , INTEGER_TYPES);
|
||||
*/
|
||||
}
|
||||
}
|
|
@ -546,13 +546,13 @@ TEST_F(DeclarableOpsTests14, repeat_5) {
|
|||
delete result;
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests14, Test_scalar_broadcast_SpecialCaseTest) {
|
||||
TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest) {
|
||||
|
||||
auto y = NDArray('c', { 3 }, nd4j::DataType::FLOAT32);
|
||||
auto x = NDArray('c', { 5, 2, 1 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
auto e = NDArray('c', { 5, 2, 3 }, { 2., 2., 2., 3., 3., 3., 4., 4., 4., 5., 5., 5., 6., 6., 6., 7., 7., 7., 8., 8., 8., 9., 9., 9., 10., 10., 10., 11., 11., 11. }, nd4j::DataType::FLOAT32);
|
||||
|
||||
|
||||
y.assign(1.0);
|
||||
x.linspace(1.0);
|
||||
|
||||
|
@ -566,3 +566,119 @@ TEST_F(DeclarableOpsTests14, Test_scalar_broadcast_SpecialCaseTest) {
|
|||
|
||||
delete result;
|
||||
}
|
||||
/////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest2) {
|
||||
|
||||
auto y = NDArray('c', { 1, 3 }, nd4j::DataType::FLOAT32);
|
||||
auto x = NDArray('c', { 5, 2, 1 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
auto e = NDArray('c', { 5, 2, 3 }, { 2., 2., 2., 3., 3., 3., 4., 4., 4., 5., 5., 5., 6., 6., 6., 7., 7., 7., 8., 8., 8., 9., 9., 9., 10., 10., 10., 11., 11., 11. }, nd4j::DataType::FLOAT32);
|
||||
|
||||
y.assign(1.0);
|
||||
x.linspace(1.0);
|
||||
|
||||
nd4j::ops::add op;
|
||||
auto result = op.evaluate({ &x, &y });
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
auto res = *result->at(0);
|
||||
|
||||
ASSERT_EQ(e, res);
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest3) {
|
||||
|
||||
auto x = NDArray('c', { 3, 5, 1 }, nd4j::DataType::FLOAT32);
|
||||
auto y = NDArray('c', { 3, 1, 4 }, nd4j::DataType::FLOAT32);
|
||||
auto z = NDArray('c', { 3, 5, 4 }, nd4j::DataType::FLOAT32);
|
||||
// recieved by main algorithm
|
||||
auto e = NDArray('c', { 3, 5, 4 }, { 10., 11., 12., 13., 20., 22., 24., 26., 30., 33., 36., 39., 40., 44., 48., 52., 50., 55., 60., 65., 84., 90., 96., 102., 98., 105., 112., 119., 112., 120., 128., 136., 126., 135., 144., 153., 140., 150., 160., 170., 198., 209., 220., 231., 216., 228., 240., 252., 234., 247., 260., 273., 252., 266., 280., 294., 270., 285., 300., 315. }, nd4j::DataType::FLOAT32);
|
||||
|
||||
x.linspace(1.f);
|
||||
y.linspace(10.f);
|
||||
z.assign(0.f);
|
||||
|
||||
x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z);
|
||||
ASSERT_EQ(e, z);
|
||||
}
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest4) {
|
||||
|
||||
auto x = NDArray('c', { 2, 3, 5, 1 }, nd4j::DataType::FLOAT32);
|
||||
auto y = NDArray('c', { 2, 3, 1, 4 }, nd4j::DataType::FLOAT32);
|
||||
auto z = NDArray('c', { 2, 3, 5, 4 }, nd4j::DataType::FLOAT32);
|
||||
// recieved by main algorithm
|
||||
auto e = NDArray('c', { 2, 3, 5, 4 }, { 10., 11., 12., 13.,20., 22., 24., 26.,30., 33., 36., 39.,40., 44., 48., 52.,50., 55., 60., 65.,84., 90., 96., 102.,98., 105., 112., 119.,112., 120., 128., 136.,126., 135., 144., 153.,140., 150., 160., 170.,198., 209., 220., 231.,216., 228., 240., 252.,234., 247., 260., 273.,252., 266., 280., 294.,270., 285., 300., 315.,352., 368., 384., 400.,374., 391., 408., 425.,396., 414., 432., 450.,418., 437., 456., 475.,440., 460., 480., 500.,546., 567., 588., 609.,572., 594., 616., 638.,598., 621., 644., 667.,624., 648., 672., 696.,650., 675., 700., 725.,780., 806., 832., 858.,810., 837., 864., 891.,840., 868., 896., 924.,870., 899., 928., 957.,900., 930., 960., 990. }, nd4j::DataType::FLOAT32);
|
||||
x.linspace(1.f);
|
||||
y.linspace(10.f);
|
||||
z.assign(0.f);
|
||||
|
||||
x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z);
|
||||
ASSERT_EQ(e, z);
|
||||
}
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest5) {
|
||||
|
||||
auto x = NDArray('c', { 3, 5, 1 }, nd4j::DataType::FLOAT32);
|
||||
auto y = NDArray('c', { 3, 1, 4 }, nd4j::DataType::FLOAT32);
|
||||
auto z = NDArray('c', { 3, 5, 4 }, nd4j::DataType::FLOAT32);
|
||||
// recieved by main algorithm
|
||||
auto e = NDArray('c', { 3, 5, 4 }, { 0.1, 0.090909, 0.083333, 0.076923,0.2, 0.181818, 0.166667, 0.153846,0.3, 0.272727, 0.250000, 0.230769,0.4, 0.363636, 0.333333, 0.307692,0.5, 0.454545, 0.416667, 0.384615, 0.428571, 0.400000, 0.375000, 0.352941, 0.500000, 0.466667, 0.437500, 0.411765, 0.571429, 0.533333, 0.500000, 0.470588, 0.642857, 0.600000, 0.562500, 0.529412, 0.714286, 0.666667, 0.625000, 0.588235, 0.611111, 0.578947, 0.550000, 0.523810, 0.666667, 0.631579, 0.600000, 0.571429, 0.722222, 0.684211, 0.650000, 0.619048, 0.777778, 0.736842, 0.700000, 0.666667, 0.833333, 0.789474, 0.750000, 0.714286 }, nd4j::DataType::FLOAT32);
|
||||
x.linspace(1.f);
|
||||
y.linspace(10.f);
|
||||
z.assign(0.f);
|
||||
|
||||
x.applyTrueBroadcast(BroadcastOpsTuple::Divide(), y, z);
|
||||
ASSERT_EQ(e, z);
|
||||
}
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest6) {
|
||||
|
||||
auto x = NDArray('c', { 2, 3, 5, 1 }, nd4j::DataType::FLOAT32);
|
||||
auto y = NDArray('c', { 2, 3, 1, 4 }, nd4j::DataType::FLOAT32);
|
||||
auto z = NDArray('c', { 2, 3, 5, 4 }, nd4j::DataType::FLOAT32);
|
||||
// recieved by main algorithm
|
||||
auto e = NDArray('c', { 2, 3, 5, 4 }, { 0.1, 0.090909, 0.083333, 0.076923,0.2, 0.181818, 0.166667, 0.153846,0.3, 0.272727, 0.250000, 0.230769,0.4, 0.363636, 0.333333, 0.307692,0.5, 0.454545, 0.416667, 0.384615, 0.428571, 0.400000, 0.375000, 0.352941, 0.500000, 0.466667, 0.437500, 0.411765, 0.571429, 0.533333, 0.500000, 0.470588, 0.642857, 0.600000, 0.562500, 0.529412, 0.714286, 0.666667, 0.625000, 0.588235,0.611111, 0.578947, 0.550000, 0.523810,0.666667, 0.631579, 0.600000, 0.571429,0.722222, 0.684211, 0.650000, 0.619048,0.777778, 0.736842, 0.700000, 0.666667,0.833333, 0.789474, 0.750000, 0.714286, 0.727273, 0.695652, 0.666667, 0.64, 0.772727, 0.739130, 0.708333, 0.68, 0.818182, 0.782609, 0.750000, 0.72, 0.863636, 0.826087, 0.791667, 0.76, 0.909091, 0.869565, 0.833333, 0.80, 0.807692, 0.777778, 0.750000, 0.724138, 0.846154, 0.814815, 0.785714, 0.758621, 0.884615, 0.851852, 0.821429, 0.793103, 0.923077, 0.888889, 0.857143, 0.827586, 0.961538, 0.925926, 0.892857, 0.862069, 0.866667, 0.838710, 0.812500, 0.787879, 0.900000, 0.870968, 0.843750, 0.818182, 0.933333, 0.903226, 0.875000, 0.848485, 0.966667, 0.935484, 0.906250, 0.878788, 1.000000, 0.967742, 0.937500, 0.909091 }, nd4j::DataType::FLOAT32);
|
||||
|
||||
x.linspace(1.f);
|
||||
y.linspace(10.f);
|
||||
z.assign(0.f);
|
||||
|
||||
x.applyTrueBroadcast(BroadcastOpsTuple::Divide(), y, z);
|
||||
ASSERT_EQ(e, z);
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest7) {
|
||||
|
||||
auto x = NDArray('c', { 3, 5, 1 }, nd4j::DataType::FLOAT32);
|
||||
auto y = NDArray('c', { 3, 1, 4 }, nd4j::DataType::FLOAT32);
|
||||
auto z = NDArray('c', { 3, 5, 4 }, nd4j::DataType::FLOAT32);
|
||||
// recieved by main algorithm
|
||||
auto e = NDArray('c', { 3, 5, 4 }, { -9., -10., -11., -12.,-8., -9., -10., -11., -7., -8., -9., -10.,-6., -7., -8., -9.,-5., -6., -7., -8.,-8., -9., -10., -11.,-7., -8., -9., -10.,-6., -7., -8., -9.,-5., -6., -7., -8.,-4., -5., -6., -7.,-7., -8.000000, -9.000000, -10.00,-6.000000, -7.000000, -8.000000, -9.000,-5.000000, -6.000000, -7.000000, -8.000,-4.000000, -5.000000, -6.000000, -7.000,-3.000000, -4.000000, -5.000000, -6.000 }, nd4j::DataType::FLOAT32);
|
||||
x.linspace(1.f);
|
||||
y.linspace(10.f);
|
||||
z.assign(0.f);
|
||||
|
||||
x.applyTrueBroadcast(BroadcastOpsTuple::Subtract(), y, z);
|
||||
ASSERT_EQ(e, z);
|
||||
}
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest8) {
|
||||
|
||||
auto x = NDArray('c', { 2, 3, 5, 1 }, nd4j::DataType::FLOAT32);
|
||||
auto y = NDArray('c', { 2, 3, 1, 4 }, nd4j::DataType::FLOAT32);
|
||||
auto z = NDArray('c', { 2, 3, 5, 4 }, nd4j::DataType::FLOAT32);
|
||||
// recieved by main algorithm
|
||||
auto e = NDArray('c', { 2, 3, 5, 4 }, { -9.0, -10., -11., -12.,-8., -9., -10., -11.0,-7., -8., -9., -10.,-6., -7., -8., -9.,-5., -6., -7., -8.,-8., -9., -10., -11.,-7., -8., -9., -10.,-6., -7., -8., -9.,-5., -6., -7., -8.,-4., -5., -6., -7.,-7., -8., -9., -10.,-6., -7., -8., -9.,-5., -6., -7., -8.,-4., -5., -6., -7.,-3., -4., -5., -6.,-6., -7., -8., -9.,-5., -6., -7., -8.,-4., -5., -6., -7.,-3., -4., -5., -6.,-2., -3., -4., -5.,-5., -6., -7., -8.,-4., -5., -6., -7.,-3., -4., -5., -6.,-2., -3., -4., -5.,-1., -2., -3., -4.,-4., -5., -6., -7.,-3., -4., -5., -6.,-2., -3., -4., -5.,-1., -2., -3., -4., 0., -1., -2., -3. }, nd4j::DataType::FLOAT32);
|
||||
|
||||
x.linspace(1.f);
|
||||
y.linspace(10.f);
|
||||
z.assign(0.f);
|
||||
|
||||
x.applyTrueBroadcast(BroadcastOpsTuple::Subtract(), y, z);
|
||||
ASSERT_EQ(e, z);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue