Oleh bert multiply true broadcast (#239)

* libnd4j trueBroadcast rank-3 row-wise implementation of the special case

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j clarify rule for the second special case so that all tests pass

* libnd4j switch on the parallel_tad loop in the special case

* libnd4j more general handling for special case 2, needs additional testing

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j added more general handling of the trueBroadcast special cases

* libnd4j minor corrections and clean up

* libnd4j one more minor fix

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j fixed the check to support all common Y vector representations in the first trueBroadcast special case

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

Co-authored-by: raver119 <raver119@gmail.com>
master
Oleh 2020-02-14 11:04:38 +02:00 committed by GitHub
parent 4206171b70
commit 6e6289b6b9
2 changed files with 377 additions and 219 deletions
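
The special cases added here target broadcasts where one operand carries a trailing unit dimension. As a rough sketch (shapes and API calls are taken from the tests added in this commit), the pattern exercised by the second special case looks like this:

auto x = NDArray('c', { 3, 5, 1 }, nd4j::DataType::FLOAT32);   // last dimension == 1
auto y = NDArray('c', { 3, 1, 4 }, nd4j::DataType::FLOAT32);   // second-to-last dimension == 1
auto z = NDArray('c', { 3, 5, 4 }, nd4j::DataType::FLOAT32);   // broadcast result shape
x.linspace(1.f);
y.linspace(10.f);
x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z);     // z[b, n, m] = x[b, n, 0] * y[b, 0, m]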


@@ -14,9 +14,9 @@
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include <loops/TrueBroadcastHelper.h>
#include <ops/ops.h>
@@ -24,226 +24,268 @@
using namespace simdOps;
namespace nd4j {
namespace helpers {
////////////////////////////////////////////////////////////////////////
template <typename X, typename Y, typename Z>
template<typename OpType>
void TrueBroadcastHelper<X, Y, Z>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
const X* x = reinterpret_cast<X*>(xArr.getBuffer());
const Y* y = reinterpret_cast<Y*>(yArr.getBuffer());
Z* z = reinterpret_cast<Z*>(zArr.getBuffer());
const auto xShapeInfo = xArr.getShapeInfo();
const auto yShapeInfo = yArr.getShapeInfo();
const auto zShapeInfo = zArr.getShapeInfo();
const int xRank = xArr.rankOf();
const int yRank = yArr.rankOf();
const int zRank = zArr.rankOf();
bool bSpecialCase = (1 == xArr.ews() && 'c' == xArr.ordering() &&
1 == yArr.ews() && 'c' == yArr.ordering() &&
1 == zArr.ews() && 'c' == zArr.ordering());
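// Special case 1: all three arrays are 'c'-ordered and contiguous (ews == 1), y is a vector and
// the last dimension of x equals 1, so every element of x is combined with the whole y vector to
// produce one row of z.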
if (bSpecialCase && yArr.isColumnVector() && 1 == xArr.sizeAt(-1)) {
auto yLen = (uint32_t)yArr.lengthOf();
auto func = PRAGMA_THREADS_FOR{
for (uint32_t i = start; i < stop; i++) {
auto rZ = z + (i * yLen);
auto v = x[i];
for (uint32_t j = 0; j < yLen; j++) {
rZ[j] = OpType::op(v, y[j]);
}
}
};
samediff::Threads::parallel_tad(func, 0, xArr.lengthOf());
return;
}
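// Special case 2: x, y and z have equal ranks, x ends with a unit last dimension ([..., N, 1]),
// y has a unit second-to-last dimension ([..., 1, M]), and each of x and y contains exactly one
// unit dimension, so z[..., n, m] = op(x[..., n, 0], y[..., 0, m]).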
auto yShapeInt = yArr.getShapeAsVectorInt();
auto xShapeInt = xArr.getShapeAsVectorInt();
auto nCountY = std::count_if(yShapeInt.cbegin(), yShapeInt.cend(), [](int i) { return i == 1; });
auto nCountX = std::count_if(xShapeInt.cbegin(), xShapeInt.cend(), [](int i) { return i == 1; });
bool bSpecialCase2 = (xRank == zRank && yRank == zRank && 1 == xArr.sizeAt(-1) && 1 == yArr.sizeAt(-2) && 1 == nCountY && 1 == nCountX);
if (bSpecialCase && bSpecialCase2) {
int zDim1 = zArr.sizeAt(-2);
int zDim2 = zArr.sizeAt(-1);
int nLen = zArr.lengthOf() / yArr.sizeAt(-1);
auto func = PRAGMA_THREADS_FOR{
for (uint32_t total = start; total < stop; total += increment) {
uint32_t i = total / zDim1;
uint32_t j = total % zDim1;
uint32_t index = (i * zDim1) + j;
auto rZ = z + (index * zDim2);
auto rY = y + (i * zDim2);
auto rX = x[index];
for (uint32_t n = 0; n < zDim2; n++) {
rZ[n] = OpType::op(rX, rY[n]);
}
}
};
samediff::Threads::parallel_tad(func, 0, nLen, 1);
return;
}
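// General case: for every linear index of z, restore its coordinates, map them back to the
// (possibly broadcast) coordinates of x and y, and apply the op element-wise.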
const Nd4jLong zLen = zArr.lengthOf();
auto func = PRAGMA_THREADS_FOR {
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
for (auto i = start; i < stop; ++i) {
shape::index2coords(i, zShapeInfo, zCoords.data());
for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
if (ix >= 0) {
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
xCoords[ix--] = zCoords[iz];
} else {
xCoords[ix--] = 0;
}
}
if (iy >= 0) {
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
yCoords[iy--] = zCoords[iz];
} else {
yCoords[iy--] = 0;
}
}
}
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
}
};
samediff::Threads::parallel_for(func, 0, zLen);
}
template <typename X, typename Y, typename Z>
void TrueBroadcastHelper<X, Y, Z>::exec(const nd4j::broadcast::Ops opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
DISPATCH_BY_OPNUM_TTT(exec, PARAMS(xArr, yArr, zArr), BROADCAST_OPS);
}
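// The boolean and integer helpers below reuse the same general coordinate-mapping loop; the
// boolean variant additionally passes a nullptr extra-params argument to OpType::op.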
////////////////////////////////////////////////////////////////////////
template <typename X, typename Z>
template<typename OpType>
void TrueBroadcastBoolHelper<X, Z>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
const X* x = reinterpret_cast<X*>(xArr.getBuffer());
const X* y = reinterpret_cast<X*>(yArr.getBuffer());
Z* z = reinterpret_cast<Z*>(zArr.getBuffer());
const auto xShapeInfo = xArr.getShapeInfo();
const auto yShapeInfo = yArr.getShapeInfo();
const auto zShapeInfo = zArr.getShapeInfo();
const int xRank = xArr.rankOf();
const int yRank = yArr.rankOf();
const int zRank = zArr.rankOf();
const Nd4jLong zLen = zArr.lengthOf();
auto func = PRAGMA_THREADS_FOR {
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
for (auto i = start; i < stop; ++i) {
shape::index2coords(i, zShapeInfo, zCoords.data());
for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
if (ix >= 0) {
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
xCoords[ix--] = zCoords[iz];
} else {
xCoords[ix--] = 0;
}
}
if (iy >= 0) {
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
yCoords[iy--] = zCoords[iz];
} else {
yCoords[iy--] = 0;
}
}
}
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
z[zOffset] = OpType::op(x[xOffset], y[yOffset], nullptr);
}
};
samediff::Threads::parallel_for(func, 0, zLen);
}
template <typename X, typename Y>
void TrueBroadcastBoolHelper<X, Y>::exec(const nd4j::broadcast::BoolOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
DISPATCH_BY_OPNUM_TT(exec, PARAMS(xArr, yArr, zArr), BROADCAST_BOOL_OPS);
}
////////////////////////////////////////////////////////////////////////
template <typename X>
template<typename OpType>
void TrueBroadcastIntHelper<X>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
const X* x = reinterpret_cast<X*>(xArr.getBuffer());
const X* y = reinterpret_cast<X*>(yArr.getBuffer());
X* z = reinterpret_cast<X*>(zArr.getBuffer());
const auto xShapeInfo = xArr.getShapeInfo();
const auto yShapeInfo = yArr.getShapeInfo();
const auto zShapeInfo = zArr.getShapeInfo();
const int xRank = xArr.rankOf();
const int yRank = yArr.rankOf();
const int zRank = zArr.rankOf();
const Nd4jLong zLen = zArr.lengthOf();
auto func = PRAGMA_THREADS_FOR {
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
for (auto i = start; i < stop; ++i) {
shape::index2coords(i, zShapeInfo, zCoords.data());
for (int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
if (ix >= 0) {
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) {
xCoords[ix--] = zCoords[iz];
} else {
xCoords[ix--] = 0;
}
}
if (iy >= 0) {
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) {
yCoords[iy--] = zCoords[iz];
} else {
yCoords[iy--] = 0;
}
}
}
const auto xOffset = shape::getOffset(xShapeInfo, xCoords.data());
const auto yOffset = shape::getOffset(yShapeInfo, yCoords.data());
const auto zOffset = shape::getOffset(zShapeInfo, zCoords.data());
z[zOffset] = OpType::op(x[xOffset], y[yOffset]);
}
};
samediff::Threads::parallel_for(func, 0, zLen);
}
template <typename X>
void TrueBroadcastIntHelper<X>::exec(const nd4j::broadcast::IntOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
DISPATCH_BY_OPNUM_T(exec, PARAMS(xArr, yArr, zArr), BROADCAST_INT_OPS);
}
/*
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_0);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_1);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_2);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_3);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_4);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_5);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_6);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_7);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_8);
BUILD_PAIRWISE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastHelper, , PAIRWISE_TYPES_9);
BUILD_DOUBLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastBoolHelper, , LIBND4J_TYPES, BOOL_TYPES);
BUILD_SINGLE_TEMPLATE(template class ND4J_EXPORT TrueBroadcastIntHelper, , INTEGER_TYPES);
*/
}
}


@@ -546,13 +546,13 @@ TEST_F(DeclarableOpsTests14, repeat_5) {
delete result;
}
/////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest) {
auto y = NDArray('c', { 3 }, nd4j::DataType::FLOAT32);
auto x = NDArray('c', { 5, 2, 1 }, nd4j::DataType::FLOAT32);
auto e = NDArray('c', { 5, 2, 3 }, { 2., 2., 2., 3., 3., 3., 4., 4., 4., 5., 5., 5., 6., 6., 6., 7., 7., 7., 8., 8., 8., 9., 9., 9., 10., 10., 10., 11., 11., 11. }, nd4j::DataType::FLOAT32);
y.assign(1.0);
x.linspace(1.0);
@@ -566,3 +566,119 @@ TEST_F(DeclarableOpsTests14, Test_scalar_broadcast_SpecialCaseTest) {
delete result;
}
/////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest2) {
auto y = NDArray('c', { 1, 3 }, nd4j::DataType::FLOAT32);
auto x = NDArray('c', { 5, 2, 1 }, nd4j::DataType::FLOAT32);
auto e = NDArray('c', { 5, 2, 3 }, { 2., 2., 2., 3., 3., 3., 4., 4., 4., 5., 5., 5., 6., 6., 6., 7., 7., 7., 8., 8., 8., 9., 9., 9., 10., 10., 10., 11., 11., 11. }, nd4j::DataType::FLOAT32);
y.assign(1.0);
x.linspace(1.0);
nd4j::ops::add op;
auto result = op.evaluate({ &x, &y });
ASSERT_EQ(Status::OK(), result->status());
auto res = *result->at(0);
ASSERT_EQ(e, res);
delete result;
}
///////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest3) {
auto x = NDArray('c', { 3, 5, 1 }, nd4j::DataType::FLOAT32);
auto y = NDArray('c', { 3, 1, 4 }, nd4j::DataType::FLOAT32);
auto z = NDArray('c', { 3, 5, 4 }, nd4j::DataType::FLOAT32);
// expected values obtained from the main algorithm
auto e = NDArray('c', { 3, 5, 4 }, { 10., 11., 12., 13., 20., 22., 24., 26., 30., 33., 36., 39., 40., 44., 48., 52., 50., 55., 60., 65., 84., 90., 96., 102., 98., 105., 112., 119., 112., 120., 128., 136., 126., 135., 144., 153., 140., 150., 160., 170., 198., 209., 220., 231., 216., 228., 240., 252., 234., 247., 260., 273., 252., 266., 280., 294., 270., 285., 300., 315. }, nd4j::DataType::FLOAT32);
x.linspace(1.f);
y.linspace(10.f);
z.assign(0.f);
x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z);
ASSERT_EQ(e, z);
}
///////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest4) {
auto x = NDArray('c', { 2, 3, 5, 1 }, nd4j::DataType::FLOAT32);
auto y = NDArray('c', { 2, 3, 1, 4 }, nd4j::DataType::FLOAT32);
auto z = NDArray('c', { 2, 3, 5, 4 }, nd4j::DataType::FLOAT32);
// expected values obtained from the main algorithm
auto e = NDArray('c', { 2, 3, 5, 4 }, { 10., 11., 12., 13.,20., 22., 24., 26.,30., 33., 36., 39.,40., 44., 48., 52.,50., 55., 60., 65.,84., 90., 96., 102.,98., 105., 112., 119.,112., 120., 128., 136.,126., 135., 144., 153.,140., 150., 160., 170.,198., 209., 220., 231.,216., 228., 240., 252.,234., 247., 260., 273.,252., 266., 280., 294.,270., 285., 300., 315.,352., 368., 384., 400.,374., 391., 408., 425.,396., 414., 432., 450.,418., 437., 456., 475.,440., 460., 480., 500.,546., 567., 588., 609.,572., 594., 616., 638.,598., 621., 644., 667.,624., 648., 672., 696.,650., 675., 700., 725.,780., 806., 832., 858.,810., 837., 864., 891.,840., 868., 896., 924.,870., 899., 928., 957.,900., 930., 960., 990. }, nd4j::DataType::FLOAT32);
x.linspace(1.f);
y.linspace(10.f);
z.assign(0.f);
x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z);
ASSERT_EQ(e, z);
}
///////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest5) {
auto x = NDArray('c', { 3, 5, 1 }, nd4j::DataType::FLOAT32);
auto y = NDArray('c', { 3, 1, 4 }, nd4j::DataType::FLOAT32);
auto z = NDArray('c', { 3, 5, 4 }, nd4j::DataType::FLOAT32);
// expected values obtained from the main algorithm
auto e = NDArray('c', { 3, 5, 4 }, { 0.1, 0.090909, 0.083333, 0.076923,0.2, 0.181818, 0.166667, 0.153846,0.3, 0.272727, 0.250000, 0.230769,0.4, 0.363636, 0.333333, 0.307692,0.5, 0.454545, 0.416667, 0.384615, 0.428571, 0.400000, 0.375000, 0.352941, 0.500000, 0.466667, 0.437500, 0.411765, 0.571429, 0.533333, 0.500000, 0.470588, 0.642857, 0.600000, 0.562500, 0.529412, 0.714286, 0.666667, 0.625000, 0.588235, 0.611111, 0.578947, 0.550000, 0.523810, 0.666667, 0.631579, 0.600000, 0.571429, 0.722222, 0.684211, 0.650000, 0.619048, 0.777778, 0.736842, 0.700000, 0.666667, 0.833333, 0.789474, 0.750000, 0.714286 }, nd4j::DataType::FLOAT32);
x.linspace(1.f);
y.linspace(10.f);
z.assign(0.f);
x.applyTrueBroadcast(BroadcastOpsTuple::Divide(), y, z);
ASSERT_EQ(e, z);
}
///////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest6) {
auto x = NDArray('c', { 2, 3, 5, 1 }, nd4j::DataType::FLOAT32);
auto y = NDArray('c', { 2, 3, 1, 4 }, nd4j::DataType::FLOAT32);
auto z = NDArray('c', { 2, 3, 5, 4 }, nd4j::DataType::FLOAT32);
// expected values obtained from the main algorithm
auto e = NDArray('c', { 2, 3, 5, 4 }, { 0.1, 0.090909, 0.083333, 0.076923,0.2, 0.181818, 0.166667, 0.153846,0.3, 0.272727, 0.250000, 0.230769,0.4, 0.363636, 0.333333, 0.307692,0.5, 0.454545, 0.416667, 0.384615, 0.428571, 0.400000, 0.375000, 0.352941, 0.500000, 0.466667, 0.437500, 0.411765, 0.571429, 0.533333, 0.500000, 0.470588, 0.642857, 0.600000, 0.562500, 0.529412, 0.714286, 0.666667, 0.625000, 0.588235,0.611111, 0.578947, 0.550000, 0.523810,0.666667, 0.631579, 0.600000, 0.571429,0.722222, 0.684211, 0.650000, 0.619048,0.777778, 0.736842, 0.700000, 0.666667,0.833333, 0.789474, 0.750000, 0.714286, 0.727273, 0.695652, 0.666667, 0.64, 0.772727, 0.739130, 0.708333, 0.68, 0.818182, 0.782609, 0.750000, 0.72, 0.863636, 0.826087, 0.791667, 0.76, 0.909091, 0.869565, 0.833333, 0.80, 0.807692, 0.777778, 0.750000, 0.724138, 0.846154, 0.814815, 0.785714, 0.758621, 0.884615, 0.851852, 0.821429, 0.793103, 0.923077, 0.888889, 0.857143, 0.827586, 0.961538, 0.925926, 0.892857, 0.862069, 0.866667, 0.838710, 0.812500, 0.787879, 0.900000, 0.870968, 0.843750, 0.818182, 0.933333, 0.903226, 0.875000, 0.848485, 0.966667, 0.935484, 0.906250, 0.878788, 1.000000, 0.967742, 0.937500, 0.909091 }, nd4j::DataType::FLOAT32);
x.linspace(1.f);
y.linspace(10.f);
z.assign(0.f);
x.applyTrueBroadcast(BroadcastOpsTuple::Divide(), y, z);
ASSERT_EQ(e, z);
}
///////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest7) {
auto x = NDArray('c', { 3, 5, 1 }, nd4j::DataType::FLOAT32);
auto y = NDArray('c', { 3, 1, 4 }, nd4j::DataType::FLOAT32);
auto z = NDArray('c', { 3, 5, 4 }, nd4j::DataType::FLOAT32);
// expected values obtained from the main algorithm
auto e = NDArray('c', { 3, 5, 4 }, { -9., -10., -11., -12.,-8., -9., -10., -11., -7., -8., -9., -10.,-6., -7., -8., -9.,-5., -6., -7., -8.,-8., -9., -10., -11.,-7., -8., -9., -10.,-6., -7., -8., -9.,-5., -6., -7., -8.,-4., -5., -6., -7.,-7., -8.000000, -9.000000, -10.00,-6.000000, -7.000000, -8.000000, -9.000,-5.000000, -6.000000, -7.000000, -8.000,-4.000000, -5.000000, -6.000000, -7.000,-3.000000, -4.000000, -5.000000, -6.000 }, nd4j::DataType::FLOAT32);
x.linspace(1.f);
y.linspace(10.f);
z.assign(0.f);
x.applyTrueBroadcast(BroadcastOpsTuple::Subtract(), y, z);
ASSERT_EQ(e, z);
}
///////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest8) {
auto x = NDArray('c', { 2, 3, 5, 1 }, nd4j::DataType::FLOAT32);
auto y = NDArray('c', { 2, 3, 1, 4 }, nd4j::DataType::FLOAT32);
auto z = NDArray('c', { 2, 3, 5, 4 }, nd4j::DataType::FLOAT32);
// expected values obtained from the main algorithm
auto e = NDArray('c', { 2, 3, 5, 4 }, { -9.0, -10., -11., -12.,-8., -9., -10., -11.0,-7., -8., -9., -10.,-6., -7., -8., -9.,-5., -6., -7., -8.,-8., -9., -10., -11.,-7., -8., -9., -10.,-6., -7., -8., -9.,-5., -6., -7., -8.,-4., -5., -6., -7.,-7., -8., -9., -10.,-6., -7., -8., -9.,-5., -6., -7., -8.,-4., -5., -6., -7.,-3., -4., -5., -6.,-6., -7., -8., -9.,-5., -6., -7., -8.,-4., -5., -6., -7.,-3., -4., -5., -6.,-2., -3., -4., -5.,-5., -6., -7., -8.,-4., -5., -6., -7.,-3., -4., -5., -6.,-2., -3., -4., -5.,-1., -2., -3., -4.,-4., -5., -6., -7.,-3., -4., -5., -6.,-2., -3., -4., -5.,-1., -2., -3., -4., 0., -1., -2., -3. }, nd4j::DataType::FLOAT32);
x.linspace(1.f);
y.linspace(10.f);
z.assign(0.f);
x.applyTrueBroadcast(BroadcastOpsTuple::Subtract(), y, z);
ASSERT_EQ(e, z);
}