[WIP] CUDA concat tweak (#148)
* one special test Signed-off-by: raver119 <raver119@gmail.com> * one special test Signed-off-by: raver119 <raver119@gmail.com> * local memory for concat Signed-off-by: raver119 <raver119@gmail.com> * fixed grid size for concat Signed-off-by: raver119 <raver119@gmail.com> * fixed grid size for concat Signed-off-by: raver119 <raver119@gmail.com> * test commented out Signed-off-by: raver119 <raver119@gmail.com>master
parent
39d43ca170
commit
d1e5e79c10
|
@ -39,13 +39,10 @@ template<typename T>
|
|||
__global__ static void concatCuda(void* pVx, void* pxShapeInfo, void* vz, Nd4jLong* zShapeInfo, const int axis) {
|
||||
|
||||
T* z = reinterpret_cast<T*>(vz);
|
||||
__shared__ Nd4jLong zLen, totalThreads, *sharedMem;
|
||||
__shared__ Nd4jLong zLen, totalThreads;
|
||||
__shared__ int rank;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
extern __shared__ unsigned char shmem[];
|
||||
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
|
||||
|
||||
zLen = shape::length(zShapeInfo);
|
||||
rank = shape::rank(zShapeInfo);
|
||||
totalThreads = gridDim.x * blockDim.x;
|
||||
|
@ -54,27 +51,26 @@ __global__ static void concatCuda(void* pVx, void* pxShapeInfo, void* vz, Nd4jL
|
|||
|
||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if(tid >= zLen)
|
||||
return;
|
||||
Nd4jLong coords[MAX_RANK];
|
||||
|
||||
auto coords = sharedMem + threadIdx.x * rank;
|
||||
|
||||
shape::index2coords(tid, zShapeInfo, coords);
|
||||
for (uint64_t i = tid; i < zLen; i += totalThreads) {
|
||||
shape::index2coords(i, zShapeInfo, coords);
|
||||
|
||||
const auto zOffset = shape::getOffset(zShapeInfo, coords);
|
||||
|
||||
int inArrIdx = 0;
|
||||
Nd4jLong *xShapeInfo = reinterpret_cast<Nd4jLong**>(pxShapeInfo)[inArrIdx];
|
||||
Nd4jLong *xShapeInfo = reinterpret_cast<Nd4jLong **>(pxShapeInfo)[inArrIdx];
|
||||
|
||||
while(coords[axis] >= xShapeInfo[axis + 1]) {
|
||||
while (coords[axis] >= xShapeInfo[axis + 1]) {
|
||||
coords[axis] -= xShapeInfo[axis + 1];
|
||||
xShapeInfo = reinterpret_cast<Nd4jLong**>(pxShapeInfo)[++inArrIdx];
|
||||
xShapeInfo = reinterpret_cast<Nd4jLong **>(pxShapeInfo)[++inArrIdx];
|
||||
}
|
||||
|
||||
const auto* x = reinterpret_cast<T*>(reinterpret_cast<void**>(pVx)[inArrIdx]);
|
||||
const auto *x = reinterpret_cast<T *>(reinterpret_cast<void **>(pVx)[inArrIdx]);
|
||||
const auto xOffset = shape::getOffset(xShapeInfo, coords);
|
||||
|
||||
z[zOffset] = x[xOffset];
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
|
@ -89,9 +85,9 @@ BUILD_SINGLE_TEMPLATE(template void concatCudaLauncher, (const int blocksPerGrid
|
|||
//////////////////////////////////////////////////////////////////////////
|
||||
void concat(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output, const int axis) {
|
||||
|
||||
const int threadsPerBlock = MAX_NUM_THREADS / 4;
|
||||
const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
||||
const int sharedMem = threadsPerBlock * sizeof(Nd4jLong) * output.rankOf() + 128;
|
||||
const int threadsPerBlock = 256;
|
||||
const int blocksPerGrid = 512;
|
||||
const int sharedMem = 512;
|
||||
|
||||
const int numOfArrs = inArrs.size();
|
||||
|
||||
|
|
|
@ -62,19 +62,31 @@ public:
|
|||
|
||||
/*
|
||||
TEST_F(PlaygroundTests, test_s_1) {
|
||||
auto x = NDArrayFactory::create<float>('c', {32,112,112,16});
|
||||
auto y = NDArrayFactory::create<float>('c', {16});
|
||||
auto z = x.ulike();
|
||||
auto x0 = NDArrayFactory::create<float>('c', {32, 7, 7, 176});
|
||||
auto x1 = x0.ulike();
|
||||
auto x2 = x0.ulike();
|
||||
auto x3 = x0.ulike();
|
||||
auto x4 = x0.ulike();
|
||||
auto x5 = x0.ulike();
|
||||
|
||||
auto y = NDArrayFactory::create<int >(3);
|
||||
auto z = NDArrayFactory::create<float>('c', {32, 7, 7, 1056});
|
||||
|
||||
Context ctx(1);
|
||||
ctx.setInputArray(0, &x);
|
||||
ctx.setInputArray(1, &y);
|
||||
ctx.setInputArray(0, &x0);
|
||||
ctx.setInputArray(1, &x1);
|
||||
ctx.setInputArray(2, &x2);
|
||||
ctx.setInputArray(3, &x3);
|
||||
ctx.setInputArray(4, &x4);
|
||||
ctx.setInputArray(5, &x5);
|
||||
|
||||
ctx.setInputArray(6, &y);
|
||||
ctx.setOutputArray(0, &z);
|
||||
ctx.setBArguments({true});
|
||||
|
||||
std::vector<Nd4jLong> values;
|
||||
|
||||
|
||||
nd4j::ops::biasadd op;
|
||||
nd4j::ops::concat op;
|
||||
op.execute(&ctx);
|
||||
|
||||
for (int e = 0; e < 1000; e++) {
|
||||
|
|
Loading…
Reference in New Issue