From d1e5e79c1079845efb55be654295a0a482a9326c Mon Sep 17 00:00:00 2001 From: raver119 Date: Tue, 24 Dec 2019 17:01:03 +0300 Subject: [PATCH] [WIP] CUDA concat tweak (#148) * one special test Signed-off-by: raver119 * one special test Signed-off-by: raver119 * local memory for concat Signed-off-by: raver119 * fixed grid size for concat Signed-off-by: raver119 * fixed grid size for concat Signed-off-by: raver119 * test commented out Signed-off-by: raver119 --- .../ops/declarable/helpers/cuda/concat.cu | 40 +++++++++---------- .../layers_tests/PlaygroundTests.cpp | 26 ++++++++---- 2 files changed, 37 insertions(+), 29 deletions(-) diff --git a/libnd4j/include/ops/declarable/helpers/cuda/concat.cu b/libnd4j/include/ops/declarable/helpers/cuda/concat.cu index 6f9a8c6ab..43c0e4af9 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/concat.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/concat.cu @@ -39,13 +39,10 @@ template __global__ static void concatCuda(void* pVx, void* pxShapeInfo, void* vz, Nd4jLong* zShapeInfo, const int axis) { T* z = reinterpret_cast(vz); - __shared__ Nd4jLong zLen, totalThreads, *sharedMem; + __shared__ Nd4jLong zLen, totalThreads; __shared__ int rank; if (threadIdx.x == 0) { - extern __shared__ unsigned char shmem[]; - sharedMem = reinterpret_cast(shmem); - zLen = shape::length(zShapeInfo); rank = shape::rank(zShapeInfo); totalThreads = gridDim.x * blockDim.x; @@ -54,27 +51,26 @@ __global__ static void concatCuda(void* pVx, void* pxShapeInfo, void* vz, Nd4jL const auto tid = blockIdx.x * blockDim.x + threadIdx.x; - if(tid >= zLen) - return; + Nd4jLong coords[MAX_RANK]; - auto coords = sharedMem + threadIdx.x * rank; + for (uint64_t i = tid; i < zLen; i += totalThreads) { + shape::index2coords(i, zShapeInfo, coords); - shape::index2coords(tid, zShapeInfo, coords); + const auto zOffset = shape::getOffset(zShapeInfo, coords); - const auto zOffset = shape::getOffset(zShapeInfo, coords); + int inArrIdx = 0; + Nd4jLong *xShapeInfo = reinterpret_cast(pxShapeInfo)[inArrIdx]; - int inArrIdx = 0; - Nd4jLong *xShapeInfo = reinterpret_cast(pxShapeInfo)[inArrIdx]; + while (coords[axis] >= xShapeInfo[axis + 1]) { + coords[axis] -= xShapeInfo[axis + 1]; + xShapeInfo = reinterpret_cast(pxShapeInfo)[++inArrIdx]; + } - while(coords[axis] >= xShapeInfo[axis + 1]) { - coords[axis] -= xShapeInfo[axis + 1]; - xShapeInfo = reinterpret_cast(pxShapeInfo)[++inArrIdx]; + const auto *x = reinterpret_cast(reinterpret_cast(pVx)[inArrIdx]); + const auto xOffset = shape::getOffset(xShapeInfo, coords); + + z[zOffset] = x[xOffset]; } - - const auto* x = reinterpret_cast(reinterpret_cast(pVx)[inArrIdx]); - const auto xOffset = shape::getOffset(xShapeInfo, coords); - - z[zOffset] = x[xOffset]; } /////////////////////////////////////////////////////////////////// @@ -89,9 +85,9 @@ BUILD_SINGLE_TEMPLATE(template void concatCudaLauncher, (const int blocksPerGrid ////////////////////////////////////////////////////////////////////////// void concat(nd4j::LaunchContext * context, const std::vector& inArrs, NDArray& output, const int axis) { - const int threadsPerBlock = MAX_NUM_THREADS / 4; - const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; - const int sharedMem = threadsPerBlock * sizeof(Nd4jLong) * output.rankOf() + 128; + const int threadsPerBlock = 256; + const int blocksPerGrid = 512; + const int sharedMem = 512; const int numOfArrs = inArrs.size(); diff --git a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp index d35d736ed..410ec53a7 100644 --- a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp @@ -62,19 +62,31 @@ public: /* TEST_F(PlaygroundTests, test_s_1) { - auto x = NDArrayFactory::create('c', {32,112,112,16}); - auto y = NDArrayFactory::create('c', {16}); - auto z = x.ulike(); + auto x0 = NDArrayFactory::create('c', {32, 7, 7, 176}); + auto x1 = x0.ulike(); + auto x2 = x0.ulike(); + auto x3 = x0.ulike(); + auto x4 = x0.ulike(); + auto x5 = x0.ulike(); + + auto y = NDArrayFactory::create(3); + auto z = NDArrayFactory::create('c', {32, 7, 7, 1056}); Context ctx(1); - ctx.setInputArray(0, &x); - ctx.setInputArray(1, &y); + ctx.setInputArray(0, &x0); + ctx.setInputArray(1, &x1); + ctx.setInputArray(2, &x2); + ctx.setInputArray(3, &x3); + ctx.setInputArray(4, &x4); + ctx.setInputArray(5, &x5); + + ctx.setInputArray(6, &y); ctx.setOutputArray(0, &z); + ctx.setBArguments({true}); std::vector values; - - nd4j::ops::biasadd op; + nd4j::ops::concat op; op.execute(&ctx); for (int e = 0; e < 1000; e++) {