parent
cea68c18f1
commit
ee5d25caa9
|
@ -66,17 +66,19 @@ __global__ static void trueBroadcastCuda(const void* vx, const Nd4jLong* xShapeI
|
|||
|
||||
for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
|
||||
|
||||
if(ix >= 0)
|
||||
if(xShapeInfo[ix + 1] == zShapeInfo[iz + 1])
|
||||
if(ix >= 0) {
|
||||
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1])
|
||||
xCoords[ix--] = zCoords[iz];
|
||||
else
|
||||
xCoords[ix--] = 0;
|
||||
}
|
||||
|
||||
if(iy >= 0)
|
||||
if(yShapeInfo[iy + 1] == zShapeInfo[iz + 1])
|
||||
if(iy >= 0) {
|
||||
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1])
|
||||
yCoords[iy--] = zCoords[iz];
|
||||
else
|
||||
yCoords[iy--] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
const auto xOffset = shape::getOffset(xShapeInfo, xCoords);
|
||||
|
@ -100,8 +102,8 @@ void TrueBroadcastHelper<X,Y,Z>::exec(const nd4j::broadcast::Ops opNum, const ND
|
|||
|
||||
dim3 launchDims;
|
||||
|
||||
launchDims.x = MAX_NUM_THREADS / 8; // threadsPerBlock
|
||||
launchDims.y = (zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid
|
||||
launchDims.x = 128; //MAX_NUM_THREADS / 8; // threadsPerBlock
|
||||
launchDims.y = 256; //(zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid
|
||||
launchDims.z = 1024; // sharedMem
|
||||
|
||||
PointersManager manager(xArr.getContext(), "TrueBroadcastHelper<X,Y,Z>::exec");
|
||||
|
@ -182,8 +184,8 @@ template<typename X, typename Y>
|
|||
void TrueBroadcastBoolHelper<X,Y>::exec(const nd4j::broadcast::BoolOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||
|
||||
dim3 launchDims;
|
||||
launchDims.x = MAX_NUM_THREADS / 8; // threadsPerBlock
|
||||
launchDims.y = (zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid
|
||||
launchDims.x = 128; //MAX_NUM_THREADS / 8; // threadsPerBlock
|
||||
launchDims.y = 256; //(zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid
|
||||
launchDims.z = 1024; // sharedMem
|
||||
|
||||
PointersManager manager(xArr.getContext(), "TrueBroadcastBoolHelper<X,Y>::exec");
|
||||
|
@ -264,8 +266,8 @@ template<typename X>
|
|||
void TrueBroadcastIntHelper<X>::exec(const nd4j::broadcast::IntOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||
|
||||
dim3 launchDims;
|
||||
launchDims.x = MAX_NUM_THREADS / 8; // threadsPerBlock
|
||||
launchDims.y = (zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid
|
||||
launchDims.x = 128; //MAX_NUM_THREADS / 8; // threadsPerBlock
|
||||
launchDims.y = 256; //(zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid
|
||||
launchDims.z = 1024; // sharedMem
|
||||
|
||||
PointersManager manager(xArr.getContext(), "TrueBroadcastIntHelper<X>::exec");
|
||||
|
|
|
@ -862,3 +862,19 @@ TEST_F(BroadcastableOpsTests, test_bert_multiply_1) {
|
|||
|
||||
ASSERT_EQ(e, z);
|
||||
}
|
||||
|
||||
TEST_F(BroadcastableOpsTests, test_bert_multiply_2) {
|
||||
auto x = NDArrayFactory::create<float>('c', {4, 128, 1});
|
||||
auto y = NDArrayFactory::create<float>('c', {768});
|
||||
auto z = NDArrayFactory::create<float>('c', {4, 128, 768});
|
||||
auto e = NDArrayFactory::create<float>('c', {4, 128, 768});
|
||||
|
||||
x.assign(1.f);
|
||||
y.assign(2.f);
|
||||
z.assign(119.f);
|
||||
e.assign(2.f);
|
||||
|
||||
x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), &y, &z);
|
||||
|
||||
ASSERT_EQ(e, z);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue