[WIP] CUDA concat tweak (#148)

* one special test

Signed-off-by: raver119 <raver119@gmail.com>

* one special test

Signed-off-by: raver119 <raver119@gmail.com>

* local memory for concat

Signed-off-by: raver119 <raver119@gmail.com>

* fixed grid size for concat

Signed-off-by: raver119 <raver119@gmail.com>

* fixed grid size for concat

Signed-off-by: raver119 <raver119@gmail.com>

* test commented out

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-12-24 17:01:03 +03:00 committed by GitHub
parent 39d43ca170
commit d1e5e79c10
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 37 additions and 29 deletions

View File

@ -39,13 +39,10 @@ template<typename T>
__global__ static void concatCuda(void* pVx, void* pxShapeInfo, void* vz, Nd4jLong* zShapeInfo, const int axis) { __global__ static void concatCuda(void* pVx, void* pxShapeInfo, void* vz, Nd4jLong* zShapeInfo, const int axis) {
T* z = reinterpret_cast<T*>(vz); T* z = reinterpret_cast<T*>(vz);
__shared__ Nd4jLong zLen, totalThreads, *sharedMem; __shared__ Nd4jLong zLen, totalThreads;
__shared__ int rank; __shared__ int rank;
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
extern __shared__ unsigned char shmem[];
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
zLen = shape::length(zShapeInfo); zLen = shape::length(zShapeInfo);
rank = shape::rank(zShapeInfo); rank = shape::rank(zShapeInfo);
totalThreads = gridDim.x * blockDim.x; totalThreads = gridDim.x * blockDim.x;
@ -54,12 +51,10 @@ __global__ static void concatCuda(void* pVx, void* pxShapeInfo, void* vz, Nd4jL
const auto tid = blockIdx.x * blockDim.x + threadIdx.x; const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
if(tid >= zLen) Nd4jLong coords[MAX_RANK];
return;
auto coords = sharedMem + threadIdx.x * rank; for (uint64_t i = tid; i < zLen; i += totalThreads) {
shape::index2coords(i, zShapeInfo, coords);
shape::index2coords(tid, zShapeInfo, coords);
const auto zOffset = shape::getOffset(zShapeInfo, coords); const auto zOffset = shape::getOffset(zShapeInfo, coords);
@ -76,6 +71,7 @@ __global__ static void concatCuda(void* pVx, void* pxShapeInfo, void* vz, Nd4jL
z[zOffset] = x[xOffset]; z[zOffset] = x[xOffset];
} }
}
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
template<typename T> template<typename T>
@ -89,9 +85,9 @@ BUILD_SINGLE_TEMPLATE(template void concatCudaLauncher, (const int blocksPerGrid
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void concat(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output, const int axis) { void concat(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output, const int axis) {
const int threadsPerBlock = MAX_NUM_THREADS / 4; const int threadsPerBlock = 256;
const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; const int blocksPerGrid = 512;
const int sharedMem = threadsPerBlock * sizeof(Nd4jLong) * output.rankOf() + 128; const int sharedMem = 512;
const int numOfArrs = inArrs.size(); const int numOfArrs = inArrs.size();

View File

@ -62,19 +62,31 @@ public:
/* /*
TEST_F(PlaygroundTests, test_s_1) { TEST_F(PlaygroundTests, test_s_1) {
auto x = NDArrayFactory::create<float>('c', {32,112,112,16}); auto x0 = NDArrayFactory::create<float>('c', {32, 7, 7, 176});
auto y = NDArrayFactory::create<float>('c', {16}); auto x1 = x0.ulike();
auto z = x.ulike(); auto x2 = x0.ulike();
auto x3 = x0.ulike();
auto x4 = x0.ulike();
auto x5 = x0.ulike();
auto y = NDArrayFactory::create<int >(3);
auto z = NDArrayFactory::create<float>('c', {32, 7, 7, 1056});
Context ctx(1); Context ctx(1);
ctx.setInputArray(0, &x); ctx.setInputArray(0, &x0);
ctx.setInputArray(1, &y); ctx.setInputArray(1, &x1);
ctx.setInputArray(2, &x2);
ctx.setInputArray(3, &x3);
ctx.setInputArray(4, &x4);
ctx.setInputArray(5, &x5);
ctx.setInputArray(6, &y);
ctx.setOutputArray(0, &z); ctx.setOutputArray(0, &z);
ctx.setBArguments({true});
std::vector<Nd4jLong> values; std::vector<Nd4jLong> values;
nd4j::ops::concat op;
nd4j::ops::biasadd op;
op.execute(&ctx); op.execute(&ctx);
for (int e = 0; e < 1000; e++) { for (int e = 0; e < 1000; e++) {