From ee5d25caa94d527ace43fc645437496ec8b80c83 Mon Sep 17 00:00:00 2001 From: raver119 Date: Mon, 9 Dec 2019 11:17:16 +0300 Subject: [PATCH] cuda broadcast exec fix Signed-off-by: raver119 --- .../helpers/cuda/TrueBroadcastHelper.cu | 22 ++++++++++--------- .../layers_tests/BroadcastableOpsTests.cpp | 16 ++++++++++++++ 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/libnd4j/include/helpers/cuda/TrueBroadcastHelper.cu b/libnd4j/include/helpers/cuda/TrueBroadcastHelper.cu index f40690795..fdbd001fd 100644 --- a/libnd4j/include/helpers/cuda/TrueBroadcastHelper.cu +++ b/libnd4j/include/helpers/cuda/TrueBroadcastHelper.cu @@ -66,17 +66,19 @@ __global__ static void trueBroadcastCuda(const void* vx, const Nd4jLong* xShapeI for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) { - if(ix >= 0) - if(xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) + if(ix >= 0) { + if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) xCoords[ix--] = zCoords[iz]; else xCoords[ix--] = 0; + } - if(iy >= 0) - if(yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) + if(iy >= 0) { + if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) yCoords[iy--] = zCoords[iz]; else yCoords[iy--] = 0; + } } const auto xOffset = shape::getOffset(xShapeInfo, xCoords); @@ -100,8 +102,8 @@ void TrueBroadcastHelper::exec(const nd4j::broadcast::Ops opNum, const ND dim3 launchDims; - launchDims.x = MAX_NUM_THREADS / 8; // threadsPerBlock - launchDims.y = (zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid + launchDims.x = 128; //MAX_NUM_THREADS / 8; // threadsPerBlock + launchDims.y = 256; //(zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid launchDims.z = 1024; // sharedMem PointersManager manager(xArr.getContext(), "TrueBroadcastHelper::exec"); @@ -182,8 +184,8 @@ template void TrueBroadcastBoolHelper::exec(const nd4j::broadcast::BoolOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) { dim3 launchDims; - launchDims.x = MAX_NUM_THREADS / 8; // threadsPerBlock - launchDims.y = (zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid + launchDims.x = 128; //MAX_NUM_THREADS / 8; // threadsPerBlock + launchDims.y = 256; //(zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid launchDims.z = 1024; // sharedMem PointersManager manager(xArr.getContext(), "TrueBroadcastBoolHelper::exec"); @@ -264,8 +266,8 @@ template void TrueBroadcastIntHelper::exec(const nd4j::broadcast::IntOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) { dim3 launchDims; - launchDims.x = MAX_NUM_THREADS / 8; // threadsPerBlock - launchDims.y = (zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid + launchDims.x = 128; //MAX_NUM_THREADS / 8; // threadsPerBlock + launchDims.y = 256; //(zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid launchDims.z = 1024; // sharedMem PointersManager manager(xArr.getContext(), "TrueBroadcastIntHelper::exec"); diff --git a/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp index 036117aa9..238c2f15d 100644 --- a/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/BroadcastableOpsTests.cpp @@ -862,3 +862,19 @@ TEST_F(BroadcastableOpsTests, test_bert_multiply_1) { ASSERT_EQ(e, z); } + +TEST_F(BroadcastableOpsTests, test_bert_multiply_2) { + auto x = NDArrayFactory::create('c', {4, 128, 1}); + auto y = NDArrayFactory::create('c', {768}); + auto z = NDArrayFactory::create('c', {4, 128, 768}); + auto e = NDArrayFactory::create('c', {4, 128, 768}); + + x.assign(1.f); + y.assign(2.f); + z.assign(119.f); + e.assign(2.f); + + x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), &y, &z); + + ASSERT_EQ(e, z); +}