cuda broadcast exec fix

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-12-09 11:17:16 +03:00
parent cea68c18f1
commit ee5d25caa9
2 changed files with 28 additions and 10 deletions

View File

@ -66,18 +66,20 @@ __global__ static void trueBroadcastCuda(const void* vx, const Nd4jLong* xShapeI
for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) { for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
if(ix >= 0) if(ix >= 0) {
if(xShapeInfo[ix + 1] == zShapeInfo[iz + 1]) if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1])
xCoords[ix--] = zCoords[iz]; xCoords[ix--] = zCoords[iz];
else else
xCoords[ix--] = 0; xCoords[ix--] = 0;
}
if(iy >= 0) if(iy >= 0) {
if(yShapeInfo[iy + 1] == zShapeInfo[iz + 1]) if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1])
yCoords[iy--] = zCoords[iz]; yCoords[iy--] = zCoords[iz];
else else
yCoords[iy--] = 0; yCoords[iy--] = 0;
} }
}
const auto xOffset = shape::getOffset(xShapeInfo, xCoords); const auto xOffset = shape::getOffset(xShapeInfo, xCoords);
const auto zOffset = shape::getOffset(zShapeInfo, zCoords); const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
@ -100,8 +102,8 @@ void TrueBroadcastHelper<X,Y,Z>::exec(const nd4j::broadcast::Ops opNum, const ND
dim3 launchDims; dim3 launchDims;
launchDims.x = MAX_NUM_THREADS / 8; // threadsPerBlock launchDims.x = 128; //MAX_NUM_THREADS / 8; // threadsPerBlock
launchDims.y = (zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid launchDims.y = 256; //(zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid
launchDims.z = 1024; // sharedMem launchDims.z = 1024; // sharedMem
PointersManager manager(xArr.getContext(), "TrueBroadcastHelper<X,Y,Z>::exec"); PointersManager manager(xArr.getContext(), "TrueBroadcastHelper<X,Y,Z>::exec");
@ -182,8 +184,8 @@ template<typename X, typename Y>
void TrueBroadcastBoolHelper<X,Y>::exec(const nd4j::broadcast::BoolOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) { void TrueBroadcastBoolHelper<X,Y>::exec(const nd4j::broadcast::BoolOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
dim3 launchDims; dim3 launchDims;
launchDims.x = MAX_NUM_THREADS / 8; // threadsPerBlock launchDims.x = 128; //MAX_NUM_THREADS / 8; // threadsPerBlock
launchDims.y = (zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid launchDims.y = 256; //(zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid
launchDims.z = 1024; // sharedMem launchDims.z = 1024; // sharedMem
PointersManager manager(xArr.getContext(), "TrueBroadcastBoolHelper<X,Y>::exec"); PointersManager manager(xArr.getContext(), "TrueBroadcastBoolHelper<X,Y>::exec");
@ -264,8 +266,8 @@ template<typename X>
void TrueBroadcastIntHelper<X>::exec(const nd4j::broadcast::IntOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) { void TrueBroadcastIntHelper<X>::exec(const nd4j::broadcast::IntOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
dim3 launchDims; dim3 launchDims;
launchDims.x = MAX_NUM_THREADS / 8; // threadsPerBlock launchDims.x = 128; //MAX_NUM_THREADS / 8; // threadsPerBlock
launchDims.y = (zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid launchDims.y = 256; //(zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid
launchDims.z = 1024; // sharedMem launchDims.z = 1024; // sharedMem
PointersManager manager(xArr.getContext(), "TrueBroadcastIntHelper<X>::exec"); PointersManager manager(xArr.getContext(), "TrueBroadcastIntHelper<X>::exec");

View File

@ -862,3 +862,19 @@ TEST_F(BroadcastableOpsTests, test_bert_multiply_1) {
ASSERT_EQ(e, z); ASSERT_EQ(e, z);
} }
TEST_F(BroadcastableOpsTests, test_bert_multiply_2) {
auto x = NDArrayFactory::create<float>('c', {4, 128, 1});
auto y = NDArrayFactory::create<float>('c', {768});
auto z = NDArrayFactory::create<float>('c', {4, 128, 768});
auto e = NDArrayFactory::create<float>('c', {4, 128, 768});
x.assign(1.f);
y.assign(2.f);
z.assign(119.f);
e.assign(2.f);
x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), &y, &z);
ASSERT_EQ(e, z);
}