[WIP] CUDA concat tweak (#148)
* one special test Signed-off-by: raver119 <raver119@gmail.com> * one special test Signed-off-by: raver119 <raver119@gmail.com> * local memory for concat Signed-off-by: raver119 <raver119@gmail.com> * fixed grid size for concat Signed-off-by: raver119 <raver119@gmail.com> * fixed grid size for concat Signed-off-by: raver119 <raver119@gmail.com> * test commented out Signed-off-by: raver119 <raver119@gmail.com>master
parent
39d43ca170
commit
d1e5e79c10
|
@ -39,13 +39,10 @@ template<typename T>
|
||||||
__global__ static void concatCuda(void* pVx, void* pxShapeInfo, void* vz, Nd4jLong* zShapeInfo, const int axis) {
|
__global__ static void concatCuda(void* pVx, void* pxShapeInfo, void* vz, Nd4jLong* zShapeInfo, const int axis) {
|
||||||
|
|
||||||
T* z = reinterpret_cast<T*>(vz);
|
T* z = reinterpret_cast<T*>(vz);
|
||||||
__shared__ Nd4jLong zLen, totalThreads, *sharedMem;
|
__shared__ Nd4jLong zLen, totalThreads;
|
||||||
__shared__ int rank;
|
__shared__ int rank;
|
||||||
|
|
||||||
if (threadIdx.x == 0) {
|
if (threadIdx.x == 0) {
|
||||||
extern __shared__ unsigned char shmem[];
|
|
||||||
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
|
|
||||||
|
|
||||||
zLen = shape::length(zShapeInfo);
|
zLen = shape::length(zShapeInfo);
|
||||||
rank = shape::rank(zShapeInfo);
|
rank = shape::rank(zShapeInfo);
|
||||||
totalThreads = gridDim.x * blockDim.x;
|
totalThreads = gridDim.x * blockDim.x;
|
||||||
|
@ -54,12 +51,10 @@ __global__ static void concatCuda(void* pVx, void* pxShapeInfo, void* vz, Nd4jL
|
||||||
|
|
||||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
if(tid >= zLen)
|
Nd4jLong coords[MAX_RANK];
|
||||||
return;
|
|
||||||
|
|
||||||
auto coords = sharedMem + threadIdx.x * rank;
|
for (uint64_t i = tid; i < zLen; i += totalThreads) {
|
||||||
|
shape::index2coords(i, zShapeInfo, coords);
|
||||||
shape::index2coords(tid, zShapeInfo, coords);
|
|
||||||
|
|
||||||
const auto zOffset = shape::getOffset(zShapeInfo, coords);
|
const auto zOffset = shape::getOffset(zShapeInfo, coords);
|
||||||
|
|
||||||
|
@ -76,6 +71,7 @@ __global__ static void concatCuda(void* pVx, void* pxShapeInfo, void* vz, Nd4jL
|
||||||
|
|
||||||
z[zOffset] = x[xOffset];
|
z[zOffset] = x[xOffset];
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////
|
||||||
template<typename T>
|
template<typename T>
|
||||||
|
@ -89,9 +85,9 @@ BUILD_SINGLE_TEMPLATE(template void concatCudaLauncher, (const int blocksPerGrid
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void concat(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output, const int axis) {
|
void concat(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output, const int axis) {
|
||||||
|
|
||||||
const int threadsPerBlock = MAX_NUM_THREADS / 4;
|
const int threadsPerBlock = 256;
|
||||||
const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
const int blocksPerGrid = 512;
|
||||||
const int sharedMem = threadsPerBlock * sizeof(Nd4jLong) * output.rankOf() + 128;
|
const int sharedMem = 512;
|
||||||
|
|
||||||
const int numOfArrs = inArrs.size();
|
const int numOfArrs = inArrs.size();
|
||||||
|
|
||||||
|
|
|
@ -62,19 +62,31 @@ public:
|
||||||
|
|
||||||
/*
|
/*
|
||||||
TEST_F(PlaygroundTests, test_s_1) {
|
TEST_F(PlaygroundTests, test_s_1) {
|
||||||
auto x = NDArrayFactory::create<float>('c', {32,112,112,16});
|
auto x0 = NDArrayFactory::create<float>('c', {32, 7, 7, 176});
|
||||||
auto y = NDArrayFactory::create<float>('c', {16});
|
auto x1 = x0.ulike();
|
||||||
auto z = x.ulike();
|
auto x2 = x0.ulike();
|
||||||
|
auto x3 = x0.ulike();
|
||||||
|
auto x4 = x0.ulike();
|
||||||
|
auto x5 = x0.ulike();
|
||||||
|
|
||||||
|
auto y = NDArrayFactory::create<int >(3);
|
||||||
|
auto z = NDArrayFactory::create<float>('c', {32, 7, 7, 1056});
|
||||||
|
|
||||||
Context ctx(1);
|
Context ctx(1);
|
||||||
ctx.setInputArray(0, &x);
|
ctx.setInputArray(0, &x0);
|
||||||
ctx.setInputArray(1, &y);
|
ctx.setInputArray(1, &x1);
|
||||||
|
ctx.setInputArray(2, &x2);
|
||||||
|
ctx.setInputArray(3, &x3);
|
||||||
|
ctx.setInputArray(4, &x4);
|
||||||
|
ctx.setInputArray(5, &x5);
|
||||||
|
|
||||||
|
ctx.setInputArray(6, &y);
|
||||||
ctx.setOutputArray(0, &z);
|
ctx.setOutputArray(0, &z);
|
||||||
|
ctx.setBArguments({true});
|
||||||
|
|
||||||
std::vector<Nd4jLong> values;
|
std::vector<Nd4jLong> values;
|
||||||
|
|
||||||
|
nd4j::ops::concat op;
|
||||||
nd4j::ops::biasadd op;
|
|
||||||
op.execute(&ctx);
|
op.execute(&ctx);
|
||||||
|
|
||||||
for (int e = 0; e < 1000; e++) {
|
for (int e = 0; e < 1000; e++) {
|
||||||
|
|
Loading…
Reference in New Issue