CPU TrueBroadcast fix: declare the coordinate buffers inside the threaded lambda
Signed-off-by: raver119 <raver119@gmail.com>
This commit is contained in:
		
							parent
							
								
									b66154a9d4
								
							
						
					
					
						commit
						ae7933a428
					
				@ -46,9 +46,9 @@ void TrueBroadcastHelper<X, Y, Z>::exec(const NDArray& xArr, const NDArray& yArr
 | 
			
		||||
 | 
			
		||||
    const Nd4jLong zLen  = zArr.lengthOf();
 | 
			
		||||
 | 
			
		||||
    std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
 | 
			
		||||
 | 
			
		||||
    auto func = PRAGMA_THREADS_FOR {
 | 
			
		||||
        std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
 | 
			
		||||
 | 
			
		||||
        for (auto i = start; i < stop; ++i) {
 | 
			
		||||
 | 
			
		||||
            shape::index2coords(i, zShapeInfo, zCoords.data());
 | 
			
		||||
@ -109,6 +109,7 @@ void TrueBroadcastBoolHelper<X, Z>::exec(const NDArray& xArr, const NDArray& yAr
 | 
			
		||||
 | 
			
		||||
    auto func = PRAGMA_THREADS_FOR {
 | 
			
		||||
        std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
 | 
			
		||||
 | 
			
		||||
        for (auto i = start; i < stop; ++i) {
 | 
			
		||||
 | 
			
		||||
            shape::index2coords(i, zShapeInfo, zCoords.data());
 | 
			
		||||
@ -167,9 +168,9 @@ void TrueBroadcastIntHelper<X>::exec(const NDArray& xArr, const NDArray& yArr, N
 | 
			
		||||
 | 
			
		||||
    const Nd4jLong zLen  = zArr.lengthOf();
 | 
			
		||||
 | 
			
		||||
    std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
 | 
			
		||||
 | 
			
		||||
    auto func = PRAGMA_THREADS_FOR {
 | 
			
		||||
        std::vector<Nd4jLong> xCoords(xArr.rankOf()), yCoords(yArr.rankOf()), zCoords(zArr.rankOf());
 | 
			
		||||
 | 
			
		||||
        for (auto i = start; i < stop; ++i) {
 | 
			
		||||
 | 
			
		||||
            shape::index2coords(i, zShapeInfo, zCoords.data());
 | 
			
		||||
 | 
			
		||||
@ -832,3 +832,33 @@ TEST_F(BroadcastableOpsTests, broadcast_3) {
 | 
			
		||||
    ASSERT_TRUE(z.isSameShape(e));
 | 
			
		||||
    ASSERT_TRUE(z.equalsTo(e));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Checks NDArray::applyTrueBroadcast with BroadcastOpsTuple::Multiply on float arrays
// of shapes {4, 128, 1} and {4, 1, 128}, writing into a {4, 128, 128} output.
// x is filled with 0 and y with 1; the output is compared against an all-zero array.
TEST_F(BroadcastableOpsTests, test_bert_multiply_1) {
 | 
			
		||||
    auto x = NDArrayFactory::create<float>('c', {4, 128, 1});
 | 
			
		||||
    auto y = NDArrayFactory::create<float>('c', {4, 1, 128});
 | 
			
		||||
    auto z = NDArrayFactory::create<float>('c', {4, 128, 128});
 | 
			
		||||
    auto e = NDArrayFactory::create<float>('c', {4, 128, 128});
 | 
			
		||||
 | 
			
		||||
    x.assign(0.f);
 | 
			
		||||
    y.assign(1.f);
 | 
			
		||||
    // Output is pre-filled with a non-zero value; presumably so that elements the
    // op fails to overwrite show up in the final comparison - TODO confirm intent.
    z.assign(119.f);
 | 
			
		||||
    // Expected result: every element is 0.
    e.assign(0.f);
 | 
			
		||||
    // Disabled variant below: runs nd4j::ops::multiply through a Context on the same arrays.
/*
 | 
			
		||||
    Context ctx(1);
 | 
			
		||||
    ctx.setInputArray(0, &x);
 | 
			
		||||
    ctx.setInputArray(1, &y);
 | 
			
		||||
    ctx.setOutputArray(0, &z);
 | 
			
		||||
 | 
			
		||||
    nd4j::ops::multiply op;
 | 
			
		||||
    auto status = op.execute(&ctx);
 | 
			
		||||
    ASSERT_EQ(Status::OK(), status);
 | 
			
		||||
 | 
			
		||||
    z.printIndexedBuffer();
 | 
			
		||||
*/
 | 
			
		||||
 | 
			
		||||
    // Active path: call the true-broadcast multiply directly on the arrays, with z as the target.
    x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), &y, &z);
 | 
			
		||||
 | 
			
		||||
    //z.printIndexedBuffer();
 | 
			
		||||
 | 
			
		||||
    // z is expected to equal the all-zero array e after the multiply.
    ASSERT_EQ(e, z);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user