parent
b66154a9d4
commit
ae7933a428
|
@ -46,9 +46,9 @@ void TrueBroadcastHelper<X, Y, Z>::exec(const NDArray& xArr, const NDArray& yArr
|
||||||
|
|
||||||
const Nd4jLong zLen = zArr.lengthOf();
|
const Nd4jLong zLen = zArr.lengthOf();
|
||||||
|
|
||||||
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
||||||
|
|
||||||
for (auto i = start; i < stop; ++i) {
|
for (auto i = start; i < stop; ++i) {
|
||||||
|
|
||||||
shape::index2coords(i, zShapeInfo, zCoords.data());
|
shape::index2coords(i, zShapeInfo, zCoords.data());
|
||||||
|
@ -109,6 +109,7 @@ void TrueBroadcastBoolHelper<X, Z>::exec(const NDArray& xArr, const NDArray& yAr
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
||||||
|
|
||||||
for (auto i = start; i < stop; ++i) {
|
for (auto i = start; i < stop; ++i) {
|
||||||
|
|
||||||
shape::index2coords(i, zShapeInfo, zCoords.data());
|
shape::index2coords(i, zShapeInfo, zCoords.data());
|
||||||
|
@ -167,9 +168,9 @@ void TrueBroadcastIntHelper<X>::exec(const NDArray& xArr, const NDArray& yArr, N
|
||||||
|
|
||||||
const Nd4jLong zLen = zArr.lengthOf();
|
const Nd4jLong zLen = zArr.lengthOf();
|
||||||
|
|
||||||
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
|
||||||
|
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
|
std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
|
||||||
|
|
||||||
for (auto i = start; i < stop; ++i) {
|
for (auto i = start; i < stop; ++i) {
|
||||||
|
|
||||||
shape::index2coords(i, zShapeInfo, zCoords.data());
|
shape::index2coords(i, zShapeInfo, zCoords.data());
|
||||||
|
|
|
@ -832,3 +832,33 @@ TEST_F(BroadcastableOpsTests, broadcast_3) {
|
||||||
ASSERT_TRUE(z.isSameShape(e));
|
ASSERT_TRUE(z.isSameShape(e));
|
||||||
ASSERT_TRUE(z.equalsTo(e));
|
ASSERT_TRUE(z.equalsTo(e));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Checks a true-broadcast multiply of a {4, 128, 1} array by a {4, 1, 128}
// array into a {4, 128, 128} output. x is all zeros and y is all ones, so the
// expected result `e` is all zeros.
TEST_F(BroadcastableOpsTests, test_bert_multiply_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {4, 128, 1});
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {4, 1, 128});
|
||||||
|
auto z = NDArrayFactory::create<float>('c', {4, 128, 128});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {4, 128, 128});
|
||||||
|
|
||||||
|
x.assign(0.f);
|
||||||
|
y.assign(1.f);
|
||||||
|
// NOTE(review): z is pre-filled with a value that differs from the expected
|
// 0.f, presumably so that any output element the op fails to write is caught
|
// by the final comparison -- confirm intent.
|
z.assign(119.f);
|
||||||
|
e.assign(0.f);
|
||||||
|
/*
|
||||||
|
Context ctx(1);
|
||||||
|
ctx.setInputArray(0, &x);
|
||||||
|
ctx.setInputArray(1, &y);
|
||||||
|
ctx.setOutputArray(0, &z);
|
||||||
|
|
||||||
|
nd4j::ops::multiply op;
|
||||||
|
auto status = op.execute(&ctx);
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
z.printIndexedBuffer();
|
||||||
|
*/
|
||||||
|
|
||||||
|
// The disabled block above drives the same multiply through the op/Context
|
// path; the live check calls applyTrueBroadcast on the arrays directly.
|
x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), &y, &z);
|
||||||
|
|
||||||
|
//z.printIndexedBuffer();
|
||||||
|
|
||||||
|
// The output must equal the all-zero expectation.
|
ASSERT_EQ(e, z);
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue