parent
b66154a9d4
commit
ae7933a428
|
@ -46,9 +46,9 @@ void TrueBroadcastHelper<X, Y, Z>::exec(const NDArray& xArr, const NDArray& yArr
|
|||
|
||||
const Nd4jLong zLen = zArr.lengthOf();
|
||||
|
||||
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
||||
|
||||
for (auto i = start; i < stop; ++i) {
|
||||
|
||||
shape::index2coords(i, zShapeInfo, zCoords.data());
|
||||
|
@ -109,6 +109,7 @@ void TrueBroadcastBoolHelper<X, Z>::exec(const NDArray& xArr, const NDArray& yAr
|
|||
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
||||
|
||||
for (auto i = start; i < stop; ++i) {
|
||||
|
||||
shape::index2coords(i, zShapeInfo, zCoords.data());
|
||||
|
@ -167,9 +168,9 @@ void TrueBroadcastIntHelper<X>::exec(const NDArray& xArr, const NDArray& yArr, N
|
|||
|
||||
const Nd4jLong zLen = zArr.lengthOf();
|
||||
|
||||
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
||||
|
||||
for (auto i = start; i < stop; ++i) {
|
||||
|
||||
shape::index2coords(i, zShapeInfo, zCoords.data());
|
||||
|
|
|
@ -832,3 +832,33 @@ TEST_F(BroadcastableOpsTests, broadcast_3) {
|
|||
ASSERT_TRUE(z.isSameShape(e));
|
||||
ASSERT_TRUE(z.equalsTo(e));
|
||||
}
|
||||
|
||||
// Checks a true-broadcast multiply of a {4, 128, 1} array by a {4, 1, 128} array
// into a preallocated {4, 128, 128} output (shapes presumably taken from a BERT
// workload, judging by the test name).
TEST_F(BroadcastableOpsTests, test_bert_multiply_1) {
    // Left operand: all zeros, so every product is expected to be zero.
    auto x = NDArrayFactory::create<float>('c', {4, 128, 1});
    x.assign(0.f);

    // Right operand: all ones.
    auto y = NDArrayFactory::create<float>('c', {4, 1, 128});
    y.assign(1.f);

    // Output buffer, pre-filled with a non-zero value (119) that differs from the
    // expected result; presumably so that elements the op fails to write stand out.
    auto z = NDArrayFactory::create<float>('c', {4, 128, 128});
    z.assign(119.f);

    // Expected result: zeros everywhere.
    auto e = NDArrayFactory::create<float>('c', {4, 128, 128});
    e.assign(0.f);

    // Disabled variant that runs the same inputs through the declarable multiply op:
    /*
    Context ctx(1);
    ctx.setInputArray(0, &x);
    ctx.setInputArray(1, &y);
    ctx.setOutputArray(0, &z);

    nd4j::ops::multiply op;
    auto status = op.execute(&ctx);
    ASSERT_EQ(Status::OK(), status);

    z.printIndexedBuffer();
    */

    // Active path: invoke the broadcast directly on the arrays, writing into z.
    x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), &y, &z);

    //z.printIndexedBuffer();

    ASSERT_EQ(e, z);
}
|
||||
|
|
Loading…
Reference in New Issue