[WIP] maxpool_bp cuda fix (#212)

* one test for alex

Signed-off-by: raver119 <raver119@gmail.com>

* fix

Signed-off-by: raver119 <raver119@gmail.com>

* get rid of safety offset in cpp

Signed-off-by: raver119 <raver119@gmail.com>

* bfloat16

Signed-off-by: raver119 <raver119@gmail.com>

* minor test rearrangement to fastpath launch

Signed-off-by: raver119 <raver119@gmail.com>

* - atomicAdd/Mul/Div fix for float16/bfloat16 misalignment
- one special test for maxpoolbp java
- safety offset of 8 bytes is back to libnd4j legacy

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-08-31 20:57:05 +03:00 committed by GitHub
parent f00a7bb3f2
commit b71c993ded
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 60 additions and 21 deletions

View File

@ -931,13 +931,13 @@ void initializeFunctions(Nd4jPointer *functions) {
/**
 * Allocates pinned (page-locked) host memory of at least memorySize bytes.
 *
 * @param memorySize requested size in bytes; an extra 8-byte safety offset is
 *                   added to match libnd4j legacy allocation behavior (see commit notes)
 * @param flags      currently unused; allocation always uses cudaHostAllocDefault
 * @return pointer to the allocation, or nullptr when cudaHostAlloc fails
 *         (the error code/message is also recorded on the default LaunchContext)
 */
Nd4jPointer mallocHost(Nd4jLong memorySize, int flags) {
// initialize to nullptr so a failed allocation returns a well-defined value
// instead of an uninitialized (garbage) pointer
Nd4jPointer pointer = nullptr;
// cudaHostAllocMapped |cudaHostAllocPortable
auto res = cudaHostAlloc(reinterpret_cast<void **>(&pointer), memorySize + 8, cudaHostAllocDefault);
if (res != 0) {
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(res);
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaHostAlloc failed");
}
return reinterpret_cast<int8_t*>(pointer);
}
/**
@ -950,13 +950,13 @@ Nd4jPointer mallocHost(Nd4jLong memorySize, int flags) {
*/
/**
 * Allocates device (GPU) memory of at least memorySize bytes.
 *
 * @param memorySize requested size in bytes; an extra 8-byte safety offset is
 *                   added to match libnd4j legacy allocation behavior (see commit notes)
 * @param deviceId   target device id — NOTE(review): not consulted here; presumably
 *                   the caller has already selected the device — confirm
 * @param flags      currently unused
 * @return pointer to the device allocation, or nullptr when cudaMalloc fails
 *         (the error code/message is also recorded on the default LaunchContext)
 */
Nd4jPointer mallocDevice(Nd4jLong memorySize, int deviceId, int flags) {
// initialize to nullptr so a failed allocation returns a well-defined value
// instead of an uninitialized (garbage) pointer
Nd4jPointer pointer = nullptr;
auto res = cudaMalloc(reinterpret_cast<void **>(&pointer), memorySize + 8);
if (res != 0) {
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(res);
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaMalloc failed");
}
return reinterpret_cast<int8_t*>(pointer);
}
/**

View File

@ -908,7 +908,7 @@ __global__ static void pooling2dBPCuda(const void* vx, const Nd4jLong* xShapeInf
/*** max ***/
case 0: {
coord2 = hstart;
coord3 = hend;
coord3 = wstart;
T max = -DataTypeUtils::max<T>();
for (coords[2] = hstart; coords[2] < hend; coords[2] += dH) {
@ -923,8 +923,9 @@ __global__ static void pooling2dBPCuda(const void* vx, const Nd4jLong* xShapeInf
}
coords[2] = coord2;
coords[3] = coord3;
nd4j::math::atomics::nd4j_atomicAdd<T>(&z[shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank)], y[yOffset]);
auto zOffset = shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank);
nd4j::math::atomics::nd4j_atomicAdd<T>(&z[zOffset], y[yOffset]);
//z[zOffset] += y[yOffset];
}
break;
@ -987,7 +988,7 @@ void ConvolutionUtils::pooling2dBP(nd4j::graph::Context& block, const NDArray& i
PointersManager manager(block.launchContext(), "pooling2dBP");
const int threadsPerBlock = MAX_NUM_THREADS / 2;
const int threadsPerBlock = 256;
const int blocksPerGrid = (gradO.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
const int sharedMem = gradO.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;

View File

@ -904,13 +904,13 @@ inline __device__ int16_t nd4j_atomicMax<int16_t>(int16_t* address, int16_t val)
template <>
inline __device__ float16 nd4j_atomicMax<float16>(float16* address, float16 val) {
int* address_as_ull = (int*) address;
auto address_as_ull = (int*) address;
long addr = (long) address;
bool misaligned = addr & 0x3;
if (misaligned)
address_as_ull = (int *) (addr - 2);
address_as_ull = (int *) (address - 1);
PAIR old, assumed, fresh;
@ -937,13 +937,13 @@ inline __device__ float16 nd4j_atomicMax<float16>(float16* address, float16 val)
template <>
inline __device__ bfloat16 nd4j_atomicMax<bfloat16>(bfloat16* address, bfloat16 val) {
int* address_as_ull = (int*) address;
auto address_as_ull = (int*) address;
long addr = (long)(address);
bool misaligned = addr & 0x3;
if (misaligned)
address_as_ull = (int *) (addr - 2);
address_as_ull = (int *) (address - 1);
BPAIR old, assumed, fresh;
@ -1060,13 +1060,13 @@ inline __device__ float16 nd4j_atomicAdd<float16>(float16* address, float16 val)
#if __CUDA_ARCH__ >= 700
atomicAdd(reinterpret_cast<__half*>(address), val.data);
#else
int* address_as_ull = (int*) address;
auto address_as_ull = (int*) address;
long addr = (long) address;
bool misaligned = addr & 0x3;
if (misaligned)
address_as_ull = (int *) (addr - 2);
address_as_ull = (int *) (address - 1);
PAIR old, assumed, fresh;
@ -1094,13 +1094,13 @@ inline __device__ float16 nd4j_atomicAdd<float16>(float16* address, float16 val)
template <>
inline __device__ bfloat16 nd4j_atomicAdd<bfloat16>(bfloat16* address, bfloat16 val) {
int* address_as_ull = (int*) address;
auto address_as_ull = (int*) address;
long addr = (long)(address);
auto addr = (long)(address);
bool misaligned = addr & 0x3;
if (misaligned)
address_as_ull = (int *) (addr - 2);
address_as_ull = (int *) (address - 1);
BPAIR old, assumed, fresh;
@ -1367,13 +1367,13 @@ inline __device__ Nd4jLong nd4j_atomicMul<Nd4jLong>(Nd4jLong* address, Nd4jLong
template <>
inline __device__ bfloat16 nd4j_atomicMul<bfloat16>(bfloat16* address, bfloat16 val) {
int* address_as_ull = (int*) address;
auto address_as_ull = (int*) address;
long addr = (long)(address);
bool misaligned = addr & 0x3;
if (misaligned)
address_as_ull = (int *) (addr - 2);
address_as_ull = (int *) (address - 1);
BPAIR old, assumed, fresh;
@ -1400,13 +1400,13 @@ inline __device__ bfloat16 nd4j_atomicMul<bfloat16>(bfloat16* address, bfloat16
template <>
inline __device__ float16 nd4j_atomicMul<float16>(float16* address, float16 val) {
int* address_as_ull = (int*) address;
auto address_as_ull = (int*) address;
long addr = (long)(address);
bool misaligned = addr & 0x3;
if (misaligned)
address_as_ull = (int *) (addr - 2);
address_as_ull = (int *) (address - 1);
BPAIR old, assumed, fresh;

View File

@ -905,6 +905,25 @@ TEST_F(DeclarableOpsTests12, softmax_9) {
delete arrF;
}
// Regression test for maxpool2d_bp on bfloat16 inputs: half-precision elements are
// 2 bytes wide, so the backward kernel's atomicAdd could hit a misaligned address
// (see the atomicAdd/Mul/Div misalignment fix in this commit). The test only checks
// that execution completes with Status::OK() — no output values are asserted.
TEST_F(DeclarableOpsTests12, maxpool_bp_half_1) {
// x: forward-pass input, y: incoming gradient (gradO), z: output gradient buffer
auto x = NDArrayFactory::create<bfloat16>('c', {2, 3, 10, 1}, {0.2019043f, 0.6464844f, 0.9116211f, 0.60058594f, 0.34033203f, 0.7036133f, 0.6772461f, 0.3815918f, 0.87353516f, 0.04650879f, 0.67822266f, 0.8618164f, 0.88378906f, 0.7573242f, 0.66796875f, 0.63427734f, 0.33764648f, 0.46923828f, 0.62939453f, 0.76464844f, -0.8618164f, -0.94873047f, -0.9902344f, -0.88916016f, -0.86572266f, -0.92089844f, -0.90722656f, -0.96533203f, -0.97509766f, -0.4975586f, -0.84814453f, -0.984375f, -0.98828125f, -0.95458984f, -0.9472656f, -0.91064453f, -0.80859375f, -0.83496094f, -0.9140625f, -0.82470703f, 0.4802246f, 0.45361328f, 0.28125f, 0.28320312f, 0.79345703f, 0.44604492f, -0.30273438f, 0.11730957f, 0.56396484f, 0.73583984f, 0.1418457f, -0.44848633f, 0.6923828f, -0.40234375f, 0.40185547f, 0.48632812f, 0.14538574f, 0.4638672f, 0.13000488f, 0.5058594f});
auto y = NDArrayFactory::create<bfloat16>('c', {2, 3, 10, 1}, {0.0f, -0.13391113f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, -0.1751709f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.51904297f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.5107422f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f});
auto z = NDArrayFactory::create<bfloat16>('c', {2, 3, 10, 1});
nd4j::ops::maxpool2d_bp op;
Context ctx(1);
// presumably {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, extra, dataFormat} per the
// maxpool2d_bp op spec — TODO confirm against the op declaration
Nd4jLong iArgs[] = {5,1,1, 2,2,0, 1,1,1, 0,0};
ctx.setIArguments(iArgs, 11);
ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo());
ctx.setInputArray(1, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo());
ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());
auto status = op.execute(&ctx);
// only successful completion is verified; result values are not checked here
ASSERT_EQ(Status::OK(), status);
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests12, lrn_bp_1) {

View File

@ -788,4 +788,23 @@ public class CustomOpsTests extends BaseNd4jTest {
Nd4j.exec(op);
Nd4j.getExecutioner().commit();
}
// Java-side counterpart of the native maxpool_bp_half_1 test: runs maxpool2d_bp
// on BFLOAT16 arrays to exercise the CUDA backward kernel's half-precision
// atomicAdd path. Only checks that execution completes without throwing —
// no output values are asserted (the value fixtures are left commented out).
@Test
public void test() throws Exception {
INDArray in1 = Nd4j.create(DataType.BFLOAT16, 2, 3, 10, 1);//Nd4j.createFromArray(0.2019043,0.6464844,0.9116211,0.60058594,0.34033203,0.7036133,0.6772461,0.3815918,0.87353516,0.04650879,0.67822266,0.8618164,0.88378906,0.7573242,0.66796875,0.63427734,0.33764648,0.46923828,0.62939453,0.76464844,-0.8618164,-0.94873047,-0.9902344,-0.88916016,-0.86572266,-0.92089844,-0.90722656,-0.96533203,-0.97509766,-0.4975586,-0.84814453,-0.984375,-0.98828125,-0.95458984,-0.9472656,-0.91064453,-0.80859375,-0.83496094,-0.9140625,-0.82470703,0.4802246,0.45361328,0.28125,0.28320312,0.79345703,0.44604492,-0.30273438,0.11730957,0.56396484,0.73583984,0.1418457,-0.44848633,0.6923828,-0.40234375,0.40185547,0.48632812,0.14538574,0.4638672,0.13000488,0.5058594)
//.castTo(DataType.BFLOAT16).reshape(2,3,10,1);
INDArray in2 = Nd4j.create(DataType.BFLOAT16, 2, 3, 10, 1); //Nd4j.createFromArray(0.0,-0.13391113,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-0.1751709,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.51904297,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.5107422,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0)
//.castTo(DataType.BFLOAT16).reshape(2,3,10,1);
// output buffer with same shape/dtype as in1
INDArray out = in1.ulike();
// integer args mirror the native test's iArgs ({5,1,1,2,2,0,1,1,1,0,0});
// presumably kernel/stride/pad/dilation/mode settings — TODO confirm against op spec
Nd4j.exec(DynamicCustomOp.builder("maxpool2d_bp")
.addInputs(in1, in2)
.addOutputs(out)
.addIntegerArguments(5,1,1,2,2,0,1,1,1,0,0)
.build());
// force synchronous execution so any native-side failure surfaces in this test
Nd4j.getExecutioner().commit();
}
}