cpu truebroadcast fix

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-12-09 08:01:12 +03:00
parent b66154a9d4
commit ae7933a428
2 changed files with 35 additions and 4 deletions

View File

@ -46,9 +46,9 @@ void TrueBroadcastHelper<X, Y, Z>::exec(const NDArray& xArr, const NDArray& yArr
const Nd4jLong zLen = zArr.lengthOf(); const Nd4jLong zLen = zArr.lengthOf();
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
for (auto i = start; i < stop; ++i) { for (auto i = start; i < stop; ++i) {
shape::index2coords(i, zShapeInfo, zCoords.data()); shape::index2coords(i, zShapeInfo, zCoords.data());
@ -109,6 +109,7 @@ void TrueBroadcastBoolHelper<X, Z>::exec(const NDArray& xArr, const NDArray& yAr
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf()); std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
for (auto i = start; i < stop; ++i) { for (auto i = start; i < stop; ++i) {
shape::index2coords(i, zShapeInfo, zCoords.data()); shape::index2coords(i, zShapeInfo, zCoords.data());
@ -167,9 +168,9 @@ void TrueBroadcastIntHelper<X>::exec(const NDArray& xArr, const NDArray& yArr, N
const Nd4jLong zLen = zArr.lengthOf(); const Nd4jLong zLen = zArr.lengthOf();
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
auto func = PRAGMA_THREADS_FOR { auto func = PRAGMA_THREADS_FOR {
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
for (auto i = start; i < stop; ++i) { for (auto i = start; i < stop; ++i) {
shape::index2coords(i, zShapeInfo, zCoords.data()); shape::index2coords(i, zShapeInfo, zCoords.data());

View File

@ -832,3 +832,33 @@ TEST_F(BroadcastableOpsTests, broadcast_3) {
ASSERT_TRUE(z.isSameShape(e)); ASSERT_TRUE(z.isSameShape(e));
ASSERT_TRUE(z.equalsTo(e)); ASSERT_TRUE(z.equalsTo(e));
} }
TEST_F(BroadcastableOpsTests, test_bert_multiply_1) {
auto x = NDArrayFactory::create<float>('c', {4, 128, 1});
auto y = NDArrayFactory::create<float>('c', {4, 1, 128});
auto z = NDArrayFactory::create<float>('c', {4, 128, 128});
auto e = NDArrayFactory::create<float>('c', {4, 128, 128});
x.assign(0.f);
y.assign(1.f);
z.assign(119.f);
e.assign(0.f);
/*
Context ctx(1);
ctx.setInputArray(0, &x);
ctx.setInputArray(1, &y);
ctx.setOutputArray(0, &z);
nd4j::ops::multiply op;
auto status = op.execute(&ctx);
ASSERT_EQ(Status::OK(), status);
z.printIndexedBuffer();
*/
x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), &y, &z);
//z.printIndexedBuffer();
ASSERT_EQ(e, z);
}