[WIP] CUDA concat tweak (#148)

* one special test

Signed-off-by: raver119 <raver119@gmail.com>

* one special test

Signed-off-by: raver119 <raver119@gmail.com>

* local memory for concat

Signed-off-by: raver119 <raver119@gmail.com>

* fixed grid size for concat

Signed-off-by: raver119 <raver119@gmail.com>

* fixed grid size for concat

Signed-off-by: raver119 <raver119@gmail.com>

* test commented out

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-12-24 17:01:03 +03:00 committed by GitHub
parent 39d43ca170
commit d1e5e79c10
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 37 additions and 29 deletions

View File

@@ -39,13 +39,10 @@ template<typename T>
__global__ static void concatCuda(void* pVx, void* pxShapeInfo, void* vz, Nd4jLong* zShapeInfo, const int axis) {
T* z = reinterpret_cast<T*>(vz);
__shared__ Nd4jLong zLen, totalThreads, *sharedMem;
__shared__ Nd4jLong zLen, totalThreads;
__shared__ int rank;
if (threadIdx.x == 0) {
extern __shared__ unsigned char shmem[];
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
zLen = shape::length(zShapeInfo);
rank = shape::rank(zShapeInfo);
totalThreads = gridDim.x * blockDim.x;
@@ -54,27 +51,26 @@ __global__ static void concatCuda(void* pVx, void* pxShapeInfo, void* vz, Nd4jL
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
if(tid >= zLen)
return;
Nd4jLong coords[MAX_RANK];
auto coords = sharedMem + threadIdx.x * rank;
for (uint64_t i = tid; i < zLen; i += totalThreads) {
shape::index2coords(i, zShapeInfo, coords);
shape::index2coords(tid, zShapeInfo, coords);
const auto zOffset = shape::getOffset(zShapeInfo, coords);
const auto zOffset = shape::getOffset(zShapeInfo, coords);
int inArrIdx = 0;
Nd4jLong *xShapeInfo = reinterpret_cast<Nd4jLong **>(pxShapeInfo)[inArrIdx];
int inArrIdx = 0;
Nd4jLong *xShapeInfo = reinterpret_cast<Nd4jLong**>(pxShapeInfo)[inArrIdx];
while (coords[axis] >= xShapeInfo[axis + 1]) {
coords[axis] -= xShapeInfo[axis + 1];
xShapeInfo = reinterpret_cast<Nd4jLong **>(pxShapeInfo)[++inArrIdx];
}
while(coords[axis] >= xShapeInfo[axis + 1]) {
coords[axis] -= xShapeInfo[axis + 1];
xShapeInfo = reinterpret_cast<Nd4jLong**>(pxShapeInfo)[++inArrIdx];
const auto *x = reinterpret_cast<T *>(reinterpret_cast<void **>(pVx)[inArrIdx]);
const auto xOffset = shape::getOffset(xShapeInfo, coords);
z[zOffset] = x[xOffset];
}
const auto* x = reinterpret_cast<T*>(reinterpret_cast<void**>(pVx)[inArrIdx]);
const auto xOffset = shape::getOffset(xShapeInfo, coords);
z[zOffset] = x[xOffset];
}
///////////////////////////////////////////////////////////////////
@@ -89,9 +85,9 @@ BUILD_SINGLE_TEMPLATE(template void concatCudaLauncher, (const int blocksPerGrid
//////////////////////////////////////////////////////////////////////////
void concat(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output, const int axis) {
const int threadsPerBlock = MAX_NUM_THREADS / 4;
const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
const int sharedMem = threadsPerBlock * sizeof(Nd4jLong) * output.rankOf() + 128;
const int threadsPerBlock = 256;
const int blocksPerGrid = 512;
const int sharedMem = 512;
const int numOfArrs = inArrs.size();

View File

@@ -62,19 +62,31 @@ public:
/*
TEST_F(PlaygroundTests, test_s_1) {
auto x = NDArrayFactory::create<float>('c', {32,112,112,16});
auto y = NDArrayFactory::create<float>('c', {16});
auto z = x.ulike();
auto x0 = NDArrayFactory::create<float>('c', {32, 7, 7, 176});
auto x1 = x0.ulike();
auto x2 = x0.ulike();
auto x3 = x0.ulike();
auto x4 = x0.ulike();
auto x5 = x0.ulike();
auto y = NDArrayFactory::create<int >(3);
auto z = NDArrayFactory::create<float>('c', {32, 7, 7, 1056});
Context ctx(1);
ctx.setInputArray(0, &x);
ctx.setInputArray(1, &y);
ctx.setInputArray(0, &x0);
ctx.setInputArray(1, &x1);
ctx.setInputArray(2, &x2);
ctx.setInputArray(3, &x3);
ctx.setInputArray(4, &x4);
ctx.setInputArray(5, &x5);
ctx.setInputArray(6, &y);
ctx.setOutputArray(0, &z);
ctx.setBArguments({true});
std::vector<Nd4jLong> values;
nd4j::ops::biasadd op;
nd4j::ops::concat op;
op.execute(&ctx);
for (int e = 0; e < 1000; e++) {