From ae7933a42842b2bcb472e26b0dc256fd6bf2bc58 Mon Sep 17 00:00:00 2001 From: raver119 Date: Mon, 9 Dec 2019 08:01:12 +0300 Subject: [PATCH] cpu truebroadcast fix Signed-off-by: raver119 --- .../helpers/cpu/TrueBroadcastHelper.cpp | 9 +++--- .../layers_tests/BroadcastableOpsTests.cpp | 30 +++++++++++++++++++ 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/libnd4j/include/helpers/cpu/TrueBroadcastHelper.cpp b/libnd4j/include/helpers/cpu/TrueBroadcastHelper.cpp index dbf080ac9..171d082a7 100644 --- a/libnd4j/include/helpers/cpu/TrueBroadcastHelper.cpp +++ b/libnd4j/include/helpers/cpu/TrueBroadcastHelper.cpp @@ -46,9 +46,9 @@ void TrueBroadcastHelper::exec(const NDArray& xArr, const NDArray& yArr const Nd4jLong zLen = zArr.lengthOf(); - std::vector xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf()); - auto func = PRAGMA_THREADS_FOR { + std::vector xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf()); + for (auto i = start; i < stop; ++i) { shape::index2coords(i, zShapeInfo, zCoords.data()); @@ -109,6 +109,7 @@ void TrueBroadcastBoolHelper::exec(const NDArray& xArr, const NDArray& yAr auto func = PRAGMA_THREADS_FOR { std::vector xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf()); + for (auto i = start; i < stop; ++i) { shape::index2coords(i, zShapeInfo, zCoords.data()); @@ -167,9 +168,9 @@ void TrueBroadcastIntHelper::exec(const NDArray& xArr, const NDArray& yArr, N const Nd4jLong zLen = zArr.lengthOf(); - std::vector xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf()); - auto func = PRAGMA_THREADS_FOR { + std::vector xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf()); + for (auto i = start; i < stop; ++i) { shape::index2coords(i, zShapeInfo, zCoords.data()); diff --git a/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp index ffa19412a..036117aa9 100644 --- a/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp @@ -832,3 +832,33 @@ TEST_F(BroadcastableOpsTests, broadcast_3) { ASSERT_TRUE(z.isSameShape(e)); ASSERT_TRUE(z.equalsTo(e)); } + +TEST_F(BroadcastableOpsTests, test_bert_multiply_1) { + auto x = NDArrayFactory::create('c', {4, 128, 1}); + auto y = NDArrayFactory::create('c', {4, 1, 128}); + auto z = NDArrayFactory::create('c', {4, 128, 128}); + auto e = NDArrayFactory::create('c', {4, 128, 128}); + + x.assign(0.f); + y.assign(1.f); + z.assign(119.f); + e.assign(0.f); +/* + Context ctx(1); + ctx.setInputArray(0, &x); + ctx.setInputArray(1, &y); + ctx.setOutputArray(0, &z); + + nd4j::ops::multiply op; + auto status = op.execute(&ctx); + ASSERT_EQ(Status::OK(), status); + + z.printIndexedBuffer(); +*/ + + x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), &y, &z); + + //z.printIndexedBuffer(); + + ASSERT_EQ(e, z); +}