diff --git a/libnd4j/CMakeSettings.json b/libnd4j/CMakeSettings.json index afda69260..867132ab2 100644 --- a/libnd4j/CMakeSettings.json +++ b/libnd4j/CMakeSettings.json @@ -1,4 +1,4 @@ -{ +{ "configurations": [ { "name": "x64-Debug", diff --git a/libnd4j/include/loops/cpu/TrueBroadcastHelper.hpp b/libnd4j/include/loops/cpu/TrueBroadcastHelper.hpp index c79c1f242..6005c3647 100644 --- a/libnd4j/include/loops/cpu/TrueBroadcastHelper.hpp +++ b/libnd4j/include/loops/cpu/TrueBroadcastHelper.hpp @@ -32,9 +32,10 @@ template template void TrueBroadcastHelper::exec(const NDArray& xArr, const NDArray& yArr, NDArray& zArr) { + const X* x = reinterpret_cast(xArr.getBuffer()); const Y* y = reinterpret_cast(yArr.getBuffer()); - Z* z = reinterpret_cast(zArr.getBuffer()); + Z* z = reinterpret_cast(zArr.getBuffer()); const auto xShapeInfo = xArr.getShapeInfo(); const auto yShapeInfo = yArr.getShapeInfo(); @@ -44,8 +45,26 @@ void TrueBroadcastHelper::exec(const NDArray& xArr, const NDArray& yArr const int yRank = yArr.rankOf(); const int zRank = zArr.rankOf(); - const Nd4jLong zLen = zArr.lengthOf(); + bool bSpecialCase = (1 == xArr.ews() && 'c' == xArr.ordering() && 1 == yRank && + 1 == yArr.ews() && 'c' == yArr.ordering() && + 1 == zArr.ews() && 'c' == zArr.ordering()); + if (bSpecialCase) { + auto yLen = (uint32_t)yArr.lengthOf(); + auto func = PRAGMA_THREADS_FOR{ + for (uint32_t i = start; i < stop; i++) { + auto rZ = z + (i * yLen); + auto v = x[i]; + for (uint32_t j = 0; j < yLen; j++) { + rZ[j] = OpType::op(v, y[j]); + } + } + }; + samediff::Threads::parallel_tad(func, 0, xArr.lengthOf()); + return; + } + + const Nd4jLong zLen = zArr.lengthOf(); auto func = PRAGMA_THREADS_FOR { std::vector xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf()); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp index 1815e5336..600004ec2 100644 --- 
a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests14.cpp
@@ -532,3 +532,35 @@ TEST_F(DeclarableOpsTests14, repeat_5) {
 
     delete result;
 }
+/////////////////////////////////////////////////////////////////////////
+// Adds a rank-1 array of length 3 to a {5, 2, 1} array and checks the
+// {5, 2, 3} result produced by the "add" op.
+TEST_F(DeclarableOpsTests14, Test_scalar_broadcast_SpecialCaseTest) {
+
+    // left operand; the expected values below correspond to linspace filling it with 1..10
+    auto x = NDArray('c', { 5, 2, 1 }, nd4j::DataType::FLOAT32);
+    x.linspace(1.0);
+
+    // right operand: three ones
+    auto y = NDArray('c', { 3 }, nd4j::DataType::FLOAT32);
+    y.assign(1.0);
+
+    // expected output: 2..11, each value repeated 3 times along the last axis
+    auto e = NDArray('c', { 5, 2, 3 },
+                     { 2., 2., 2., 3., 3., 3.,
+                       4., 4., 4., 5., 5., 5.,
+                       6., 6., 6., 7., 7., 7.,
+                       8., 8., 8., 9., 9., 9.,
+                       10., 10., 10., 11., 11., 11. },
+                     nd4j::DataType::FLOAT32);
+
+    nd4j::ops::add addOp;
+    auto result = addOp.evaluate({ &x, &y });
+    ASSERT_EQ(Status::OK(), result->status());
+
+    auto res = *result->at(0);
+    ASSERT_EQ(e, res);
+
+    // result is deleted by hand, as in the preceding test
+    delete result;
+}