parent
cea68c18f1
commit
ee5d25caa9
|
@ -66,18 +66,20 @@ __global__ static void trueBroadcastCuda(const void* vx, const Nd4jLong* xShapeI
|
||||||
|
|
||||||
for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
|
for(int ix = xRank - 1, iy = yRank - 1, iz = zRank - 1; iz >= 0; --iz) {
|
||||||
|
|
||||||
if(ix >= 0)
|
if(ix >= 0) {
|
||||||
if(xShapeInfo[ix + 1] == zShapeInfo[iz + 1])
|
if (xShapeInfo[ix + 1] == zShapeInfo[iz + 1])
|
||||||
xCoords[ix--] = zCoords[iz];
|
xCoords[ix--] = zCoords[iz];
|
||||||
else
|
else
|
||||||
xCoords[ix--] = 0;
|
xCoords[ix--] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
if(iy >= 0)
|
if(iy >= 0) {
|
||||||
if(yShapeInfo[iy + 1] == zShapeInfo[iz + 1])
|
if (yShapeInfo[iy + 1] == zShapeInfo[iz + 1])
|
||||||
yCoords[iy--] = zCoords[iz];
|
yCoords[iy--] = zCoords[iz];
|
||||||
else
|
else
|
||||||
yCoords[iy--] = 0;
|
yCoords[iy--] = 0;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const auto xOffset = shape::getOffset(xShapeInfo, xCoords);
|
const auto xOffset = shape::getOffset(xShapeInfo, xCoords);
|
||||||
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
|
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
|
||||||
|
@ -100,8 +102,8 @@ void TrueBroadcastHelper<X,Y,Z>::exec(const nd4j::broadcast::Ops opNum, const ND
|
||||||
|
|
||||||
dim3 launchDims;
|
dim3 launchDims;
|
||||||
|
|
||||||
launchDims.x = MAX_NUM_THREADS / 8; // threadsPerBlock
|
launchDims.x = 128; //MAX_NUM_THREADS / 8; // threadsPerBlock
|
||||||
launchDims.y = (zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid
|
launchDims.y = 256; //(zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid
|
||||||
launchDims.z = 1024; // sharedMem
|
launchDims.z = 1024; // sharedMem
|
||||||
|
|
||||||
PointersManager manager(xArr.getContext(), "TrueBroadcastHelper<X,Y,Z>::exec");
|
PointersManager manager(xArr.getContext(), "TrueBroadcastHelper<X,Y,Z>::exec");
|
||||||
|
@ -182,8 +184,8 @@ template<typename X, typename Y>
|
||||||
void TrueBroadcastBoolHelper<X,Y>::exec(const nd4j::broadcast::BoolOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
void TrueBroadcastBoolHelper<X,Y>::exec(const nd4j::broadcast::BoolOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||||
|
|
||||||
dim3 launchDims;
|
dim3 launchDims;
|
||||||
launchDims.x = MAX_NUM_THREADS / 8; // threadsPerBlock
|
launchDims.x = 128; //MAX_NUM_THREADS / 8; // threadsPerBlock
|
||||||
launchDims.y = (zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid
|
launchDims.y = 256; //(zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid
|
||||||
launchDims.z = 1024; // sharedMem
|
launchDims.z = 1024; // sharedMem
|
||||||
|
|
||||||
PointersManager manager(xArr.getContext(), "TrueBroadcastBoolHelper<X,Y>::exec");
|
PointersManager manager(xArr.getContext(), "TrueBroadcastBoolHelper<X,Y>::exec");
|
||||||
|
@ -264,8 +266,8 @@ template<typename X>
|
||||||
void TrueBroadcastIntHelper<X>::exec(const nd4j::broadcast::IntOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
void TrueBroadcastIntHelper<X>::exec(const nd4j::broadcast::IntOps opNum, const NDArray& xArr, const NDArray& yArr, NDArray& zArr) {
|
||||||
|
|
||||||
dim3 launchDims;
|
dim3 launchDims;
|
||||||
launchDims.x = MAX_NUM_THREADS / 8; // threadsPerBlock
|
launchDims.x = 128; //MAX_NUM_THREADS / 8; // threadsPerBlock
|
||||||
launchDims.y = (zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid
|
launchDims.y = 256; //(zArr.lengthOf() + launchDims.x - 1) / launchDims.x; // blocksPerGrid
|
||||||
launchDims.z = 1024; // sharedMem
|
launchDims.z = 1024; // sharedMem
|
||||||
|
|
||||||
PointersManager manager(xArr.getContext(), "TrueBroadcastIntHelper<X>::exec");
|
PointersManager manager(xArr.getContext(), "TrueBroadcastIntHelper<X>::exec");
|
||||||
|
|
|
@ -862,3 +862,19 @@ TEST_F(BroadcastableOpsTests, test_bert_multiply_1) {
|
||||||
|
|
||||||
ASSERT_EQ(e, z);
|
ASSERT_EQ(e, z);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(BroadcastableOpsTests, test_bert_multiply_2) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {4, 128, 1});
|
||||||
|
auto y = NDArrayFactory::create<float>('c', {768});
|
||||||
|
auto z = NDArrayFactory::create<float>('c', {4, 128, 768});
|
||||||
|
auto e = NDArrayFactory::create<float>('c', {4, 128, 768});
|
||||||
|
|
||||||
|
x.assign(1.f);
|
||||||
|
y.assign(2.f);
|
||||||
|
z.assign(119.f);
|
||||||
|
e.assign(2.f);
|
||||||
|
|
||||||
|
x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), &y, &z);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, z);
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue