Oleh true broadcast opt (#234)

* libnd4j trueBroadcast special case

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j fix trueBroadcast special case

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j special case of TrueBroadcastHelper

Signed-off-by: Oleg <oleg.semeniv@gmail.com>

* libnd4j trueBroadCast special case and test

* libnd4j minor changes sync with master

* libnd4j changes to TrueBroadcastHelper.hpp per require

Signed-off-by: Oleg <oleg.semeniv@gmail.com>
master
Oleh 2020-02-12 13:12:17 +02:00 committed by GitHub
parent b941186302
commit 11cb561045
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 43 additions and 3 deletions

View File

@ -1,4 +1,4 @@
{
{
"configurations": [
{
"name": "x64-Debug",

View File

@ -32,9 +32,10 @@ template <typename X, typename Y, typename Z>
template<typename OpType>
void TrueBroadcastHelper<X, Y, Z>::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
const X* x = reinterpret_cast<X*>(xArr.getBuffer());
const Y* y = reinterpret_cast<Y*>(yArr.getBuffer());
Z* z = reinterpret_cast<Z*>(zArr.getBuffer());
Z* z = reinterpret_cast<Z*>(zArr.getBuffer());
const auto xShapeInfo = xArr.getShapeInfo();
const auto yShapeInfo = yArr.getShapeInfo();
@ -44,8 +45,26 @@ void TrueBroadcastHelper<X, Y, Z>::exec(const NDArray& xArr, const NDArray& yArr
const int yRank = yArr.rankOf();
const int zRank = zArr.rankOf();
const Nd4jLong zLen = zArr.lengthOf();
bool bSpecialCase = (1 == xArr.ews() && 'c' == xArr.ordering() && 1 == yRank &&
1 == yArr.ews() && 'c' == yArr.ordering() &&
1 == zArr.ews() && 'c' == zArr.ordering());
if (bSpecialCase) {
auto yLen = (uint32_t)yArr.lengthOf();
auto func = PRAGMA_THREADS_FOR{
for (uint32_t i = start; i < stop; i++) {
auto rZ = z + (i * yLen);
auto v = x[i];
for (uint32_t j = 0; j < yLen; j++) {
rZ[j] = OpType::op(v, y[j]);
}
}
};
samediff::Threads::parallel_tad(func, 0, xArr.lengthOf());
return;
}
const Nd4jLong zLen = zArr.lengthOf();
auto func = PRAGMA_THREADS_FOR {
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());

View File

@ -532,3 +532,24 @@ TEST_F(DeclarableOpsTests14, repeat_5) {
delete result;
}
/////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests14, Test_scalar_broadcast_SpecialCaseTest) {
auto y = NDArray('c', { 3 }, nd4j::DataType::FLOAT32);
auto x = NDArray('c', { 5, 2, 1 }, nd4j::DataType::FLOAT32);
auto e = NDArray('c', { 5, 2, 3 }, { 2., 2., 2., 3., 3., 3., 4., 4., 4., 5., 5., 5., 6., 6., 6., 7., 7., 7., 8., 8., 8., 9., 9., 9., 10., 10., 10., 11., 11., 11. }, nd4j::DataType::FLOAT32);
y.assign(1.0);
x.linspace(1.0);
nd4j::ops::add op;
auto result = op.evaluate({ &x, &y });
ASSERT_EQ(Status::OK(), result->status());
auto res = *result->at(0);
ASSERT_EQ(e, res);
delete result;
}