Shyrma concat2 (#119)

* - rewrite/improve concat

Signed-off-by: Yurii <yurii@skymind.io>

* - get rid of unnecessary argument in concat kernel

Signed-off-by: Yurii <yurii@skymind.io>
master
Yurii Shyrma 2019-08-16 06:57:20 +03:00 committed by raver119
parent b370544b8f
commit 9f2ba6a85d
6 changed files with 266 additions and 274 deletions

View File

@@ -29,104 +29,102 @@
#include <PointersManager.h>
#include <ConstantTadHelper.h>
namespace nd4j {
namespace ops {
namespace helpers {
namespace nd4j {
namespace ops {
namespace helpers {
///////////////////////////////////////////////////////////////////
template<typename T>
__global__ static void concatCuda(const int numOfArrs, void* pVx, void* pxShapeInfo, void* pVz, void* pzShapeInfo) {
template<typename T>
__global__ static void concatCuda(void* pVx, void* pxShapeInfo, void* vz, Nd4jLong* zShapeInfo, const int axis) {
__shared__ int arrIdx, blocksPerArr;
T* z = reinterpret_cast<T*>(vz);
__shared__ Nd4jLong zLen, totalThreads, *sharedMem;
__shared__ int rank;
if (threadIdx.x == 0) {
if (threadIdx.x == 0) {
extern __shared__ unsigned char shmem[];
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
blocksPerArr = (gridDim.x + numOfArrs - 1) / numOfArrs; // ceil
arrIdx = blockIdx.x / blocksPerArr;
}
__syncthreads();
for(int j = arrIdx; j < numOfArrs; j += gridDim.x) {
const auto* x = reinterpret_cast<T*>(reinterpret_cast<void**>(pVx)[j]);
auto* z = reinterpret_cast<T*>(reinterpret_cast<void**>(pVz)[j]);
const auto* xShapeInfo = reinterpret_cast<Nd4jLong**>(pxShapeInfo)[j];
const auto* zShapeInfo = reinterpret_cast<Nd4jLong**>(pzShapeInfo)[j];
const auto arrLen = shape::length(xShapeInfo);
const auto arrLenPerBlock = (arrLen + blocksPerArr - 1) / blocksPerArr; // ceil
const auto start = (blockIdx.x % blocksPerArr) * arrLenPerBlock;
const auto end = (start + arrLenPerBlock) > arrLen ? arrLen : (start + arrLenPerBlock);
for (Nd4jLong i = start + threadIdx.x; i < end; i += blockDim.x)
z[shape::getIndexOffset(i, zShapeInfo, arrLen)] = x[shape::getIndexOffset(i, xShapeInfo, arrLen)];
}
}
///////////////////////////////////////////////////////////////////
template<typename T>
__host__ static void concatCudaLauncher(const int numOfArrs, const cudaStream_t *stream, void* pVx, void* pxShapeInfo, void* pVz, void* pzShapeInfo) {
concatCuda<T><<<512, 512, 512, *stream>>>(numOfArrs, pVx, pxShapeInfo, pVz, pzShapeInfo);
}
BUILD_SINGLE_TEMPLATE(template void concatCudaLauncher, (const int numOfArrs, const cudaStream_t *stream, void* pVx, void* pxShapeInfo, void* pVz, void* pzShapeInfo), LIBND4J_TYPES);
//////////////////////////////////////////////////////////////////////////
void concat(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output, const int axis) {
const int numOfArrs = inArrs.size();
for(int i = 0; i < numOfArrs; ++i)
if(!inArrs[i]->isActualOnDeviceSide()) inArrs[i]->syncToDevice();
const int rank = inArrs[0]->rankOf();
const int rank2 = 2*rank;
std::vector<std::vector<Nd4jLong>> indices(numOfArrs, std::vector<Nd4jLong>(rank2,0));
// take into account indices for first array
indices[0][2 * axis + 1] = inArrs[0]->sizeAt(axis);
// loop through the rest of input arrays
for(int i = 1; i < numOfArrs; ++i) {
indices[i][2 * axis] = indices[i-1][2 * axis + 1]; // index start from
indices[i][2 * axis + 1] = indices[i-1][2 * axis + 1] + inArrs[i]->sizeAt(axis); // index end with (excluding)
}
std::vector<NDArray*> outSubArrs(numOfArrs);
for(int i = 0; i < numOfArrs; ++i)
outSubArrs[i] = new NDArray(output(indices[i], true));
// prepare arrays of pointers on buffers and shapes
std::vector<void*> hOutBuffers(numOfArrs), hInBuffers(numOfArrs);
std::vector<Nd4jLong*> hOutShapeInfo(numOfArrs), hInShapeInfo(numOfArrs);
for(int i = 0; i < numOfArrs; ++i) {
hOutBuffers[i] = outSubArrs[i]->getSpecialBuffer();
hInBuffers[i] = inArrs[i]->getSpecialBuffer();
hOutShapeInfo[i] = outSubArrs[i]->getSpecialShapeInfo();
hInShapeInfo[i] = inArrs[i]->getSpecialShapeInfo();
}
// allocate and copy all buffers and shapes arrays to global memory
PointersManager manager(context, "helpers::concat");
void* dOutBuffers = manager.replicatePointer(hOutBuffers.data(), hOutBuffers.size() * sizeof(void*));
void* dInBuffers = manager.replicatePointer(hInBuffers.data(), hInBuffers.size() * sizeof(void*));
void* dInShapeInfo = manager.replicatePointer(hInShapeInfo.data(), hInShapeInfo.size() * sizeof(Nd4jLong*));
void* dOutShapeInfo = manager.replicatePointer(hOutShapeInfo.data(), hOutShapeInfo.size() * sizeof(Nd4jLong*));
BUILD_SINGLE_SELECTOR(inArrs[0]->dataType(), concatCudaLauncher, (numOfArrs, context->getCudaStream(), dInBuffers, dInShapeInfo, dOutBuffers, dOutShapeInfo), LIBND4J_TYPES);
manager.synchronize();
for(int i = 0; i < numOfArrs; ++i)
delete outSubArrs[i];
for(int i = 0; i < numOfArrs; ++i)
inArrs[i]->tickReadHost();
output.tickWriteDevice();
}
}
zLen = shape::length(zShapeInfo);
rank = shape::rank(zShapeInfo);
totalThreads = gridDim.x * blockDim.x;
}
__syncthreads();
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
if(tid >= zLen)
return;
auto coords = sharedMem + threadIdx.x * rank;
shape::index2coords(rank, zShapeInfo + 1, tid, zLen, coords);
const auto zOffset = shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank);
int inArrIdx = 0;
Nd4jLong *xShapeInfo = reinterpret_cast<Nd4jLong**>(pxShapeInfo)[inArrIdx];
while(coords[axis] >= xShapeInfo[axis + 1]) {
coords[axis] -= xShapeInfo[axis + 1];
xShapeInfo = reinterpret_cast<Nd4jLong**>(pxShapeInfo)[++inArrIdx];
}
const auto* x = reinterpret_cast<T*>(reinterpret_cast<void**>(pVx)[inArrIdx]);
const auto xOffset = shape::getOffset(0, xShapeInfo + 1, xShapeInfo + rank + 1, coords, rank);
z[zOffset] = x[xOffset];
}
///////////////////////////////////////////////////////////////////
template<typename T>
__host__ static void concatCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream,
void* pVx, void* pxShapeInfo, void* vz, Nd4jLong* zShapeInfo, const int axis) {
concatCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(pVx, pxShapeInfo, vz, zShapeInfo, axis);
}
BUILD_SINGLE_TEMPLATE(template void concatCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, void* pVx, void* pxShapeInfo, void* vz, Nd4jLong* zShapeInfo, const int axis), LIBND4J_TYPES);
//////////////////////////////////////////////////////////////////////////
void concat(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output, const int axis) {
const int threadsPerBlock = MAX_NUM_THREADS / 4;
const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
const int sharedMem = threadsPerBlock * sizeof(Nd4jLong) * output.rankOf() + 128;
const int numOfArrs = inArrs.size();
for(int i = 0; i < numOfArrs; ++i)
inArrs[i]->syncToDevice();
output.syncToDevice();
// prepare arrays of pointers on buffers and shapes
std::vector<void*> hInBuffers(numOfArrs);
std::vector<Nd4jLong*> hInShapeInfo(numOfArrs);
for(int i = 0; i < numOfArrs; ++i) {
hInBuffers[i] = inArrs[i]->getSpecialBuffer();
hInShapeInfo[i] = inArrs[i]->getSpecialShapeInfo();
}
PointersManager manager(context, "helpers::concat");
void* dInBuffers = manager.replicatePointer(hInBuffers.data(), hInBuffers.size() * sizeof(void*));
void* dInShapeInfo = manager.replicatePointer(hInShapeInfo.data(), hInShapeInfo.size() * sizeof(Nd4jLong*));
BUILD_SINGLE_SELECTOR(inArrs[0]->dataType(), concatCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), dInBuffers, dInShapeInfo, output.specialBuffer(), output.specialShapeInfo(), axis), LIBND4J_TYPES);
manager.synchronize();
for(int i = 0; i < numOfArrs; ++i)
inArrs[i]->tickReadDevice();
output.tickWriteDevice();
}
}
}
}
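Note on the rewritten kernel above: the old concatCuda partitioned the grid across the input arrays and had each block copy a slice of one input into a pre-built sub-array view of the output, which is why the old host code created (and later deleted) an NDArray sub-array per input. The new kernel launches one thread per output element: each thread converts its linear output index into coordinates, walks the input shapes along the concat axis until the coordinate lands inside one particular input, and copies that single element. The CPU sketch below mirrors this mapping for dense row-major ('c'-ordered) buffers; concatRef and every name in it are illustrative only, not libnd4j API — the real kernel uses shape::index2coords and shape::getOffset so that strided and 'f'-ordered arrays (as exercised by concat_test21/22) are handled as well.

    #include <cstdio>
    #include <vector>

    // Reference sketch of the per-output-element concat mapping, restricted to
    // dense row-major buffers. Hypothetical helper, not part of the libnd4j API.
    static void concatRef(const std::vector<const float*>& ins,
                          const std::vector<std::vector<long>>& inShapes,
                          float* out, const std::vector<long>& outShape, int axis) {

        const int rank = (int) outShape.size();
        long outLen = 1;
        for (long d : outShape) outLen *= d;

        std::vector<long> coords(rank);

        for (long tid = 0; tid < outLen; ++tid) {     // "tid" plays the role of the CUDA thread id

            // linear output index -> coordinates (what shape::index2coords does for strided arrays)
            long rem = tid;
            for (int d = rank - 1; d >= 0; --d) { coords[d] = rem % outShape[d]; rem /= outShape[d]; }

            // walk the inputs along the concat axis until the coordinate fits into one of them
            int arr = 0;
            while (coords[axis] >= inShapes[arr][axis]) { coords[axis] -= inShapes[arr][axis]; ++arr; }

            // coordinates -> linear offset inside the chosen input (what shape::getOffset does)
            long xOffset = 0;
            for (int d = 0; d < rank; ++d) xOffset = xOffset * inShapes[arr][d] + coords[d];

            out[tid] = ins[arr][xOffset];
        }
    }

    int main() {
        // two 2x3 inputs concatenated along axis 0 -> 4x3 output holding 0..11 in order
        float a[6] = {0, 1, 2, 3, 4, 5}, b[6] = {6, 7, 8, 9, 10, 11}, z[12];
        concatRef({a, b}, {{2, 3}, {2, 3}}, z, {4, 3}, 0);
        for (float v : z) std::printf("%g ", v);
        std::printf("\n");
        return 0;
    }

Running the sketch prints 0 through 11 in order, which is what concatenating two row-major 2x3 blocks along axis 0 should produce.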

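The host wrapper now derives its launch configuration from the output rather than from the number of inputs: one thread per output element, MAX_NUM_THREADS / 4 threads per block, and threadsPerBlock * sizeof(Nd4jLong) * rank + 128 bytes of dynamic shared memory so every thread gets scratch space for its coordinate array. A rough worked example of that arithmetic, assuming for illustration only that MAX_NUM_THREADS is 1024 (the actual value comes from the library's configuration):

    #include <cstdio>

    // Worked example of the launch-configuration arithmetic in the new concat().
    int main() {
        const long long outputLen = 3 * 4 * 5;   // e.g. the {3,4,5} output of concat_test21
        const int rank            = 3;
        const int MAX_NUM_THREADS = 1024;        // assumption for illustration, not the library's definition

        const int threadsPerBlock = MAX_NUM_THREADS / 4;                                        // 256
        const int blocksPerGrid   = (int) ((outputLen + threadsPerBlock - 1) / threadsPerBlock); // ceil(60/256) = 1
        const int sharedMem       = threadsPerBlock * (int) sizeof(long long) * rank + 128;      // 256*8*3 + 128 = 6272 bytes

        std::printf("threadsPerBlock=%d blocksPerGrid=%d sharedMem=%d\n",
                    threadsPerBlock, blocksPerGrid, sharedMem);
        return 0;
    }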
View File

@@ -750,59 +750,6 @@ TEST_F(DeclarableOpsTests12, tensormmul_6) {
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests12, concat_test10) {
NDArray x0('c', {1,4,5}, nd4j::DataType::FLOAT32);
NDArray x1('c', {2,4,5}, nd4j::DataType::FLOAT32);
NDArray z('f', {3,4,5}, nd4j::DataType::FLOAT32);
x0 = 0.;
x1 = 1.;
nd4j::ops::concat op;
auto status = op.execute({&x0, &x1}, {&z}, {}, {0}, {});
ASSERT_EQ(ND4J_STATUS_OK, status);
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests12, concat_14) {
NDArray x0('c', {1,6}, {1,2,3,4,5,6});
NDArray x1('c', {1,6}, {7,8,9,10,11,12});
NDArray output('f', {2,6}, nd4j::DataType::DOUBLE);
NDArray exp('c', {2,6}, {1,2,3,4,5,6,7,8,9,10,11,12});
nd4j::ops::concat op;
auto status = op.execute({&x0, &x1}, {&output}, {}, {0}, {});
ASSERT_EQ(ND4J_STATUS_OK, status);
// output.printBuffer();
// output.printIndexedBuffer();
ASSERT_TRUE(exp.equalsTo(output));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests12, concat_15) {
NDArray x0('c', {1,4}, {1,2,3,4});
NDArray x1('c', {1,4}, {5,6,7,8});
NDArray output('c', {2,4}, nd4j::DataType::DOUBLE);
NDArray exp('c', {2,4}, {1,2,3,4,5,6,7,8});
nd4j::ops::concat op;
auto status = op.execute({&x0, &x1}, {&output}, {}, {0}, {});
ASSERT_EQ(ND4J_STATUS_OK, status);
// output.printBuffer();
// output.printIndexedBuffer();
ASSERT_TRUE(exp.equalsTo(output));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests12, reduceMeanBp_4) {

View File

@@ -364,77 +364,6 @@ TEST_F(DeclarableOpsTests15, test_rank_2) {
delete result;
}
TEST_F(DeclarableOpsTests15, test_concat_column_1) {
auto x = NDArrayFactory::create<double>('c', {2, 1}, {1, 1});
auto y = NDArrayFactory::create<double>('c', {2, 1}, {0, 0});
auto e = NDArrayFactory::create<double>('c', {2, 2}, {1, 0, 1, 0});
auto z = NDArrayFactory::create<double>('c', {2, 2});
nd4j::ops::concat op;
auto status = op.execute({&x, &y}, {&z}, {}, {1}, {});
ASSERT_EQ(Status::OK(), status);
z.printIndexedBuffer("z");
ASSERT_EQ(e, z);
}
TEST_F(DeclarableOpsTests15, test_concat_large_1) {
std::array<NDArray*, 2000> arrays;
Context context(1);
Nd4jLong axis = 0;
// we create a bunch of arrays, filled with specific values
for (int e = 0; e < arrays.size(); e++) {
auto array = NDArrayFactory::create_<float>('c', {1, 300});
array->assign(e);
context.setInputArray(e, array, true);
}
auto z = NDArrayFactory::create<float>('c', {2000, 300});
context.setOutputArray(0, &z, false);
context.setIArguments(&axis, 1);
nd4j::ops::concat op;
op.execute(&context);
for (int e = 0; e < arrays.size(); e++) {
auto row = z.tensorAlongDimension(e, {1});
ASSERT_NEAR((float) e, row->e<float>(0), 1e-5f);
delete row;
}
}
TEST_F(DeclarableOpsTests15, test_concat_large_2) {
std::array<NDArray*, 10> arrays;
Context context(1);
Nd4jLong axis = 0;
// we create a bunch of arrays, filled with specific values
for (int e = 0; e < arrays.size(); e++) {
auto array = NDArrayFactory::create_<float>('c', {1, 5, 20});
array->assign(e);
context.setInputArray(e, array, true);
}
auto z = NDArrayFactory::create<float>('c', {arrays.size(), 5, 20});
context.setOutputArray(0, &z, false);
context.setIArguments(&axis, 1);
nd4j::ops::concat op;
op.execute(&context);
for (int e = 0; e < arrays.size(); e++) {
auto row = z.tensorAlongDimension(e, {1, 2});
ASSERT_NEAR((float) e, row->meanNumber().e<float>(0), 1e-5f);
delete row;
}
}
TEST_F(DeclarableOpsTests15, test_lstmBlock_1) {
auto x0 = NDArrayFactory::create<Nd4jLong>(5);
auto x1 = NDArrayFactory::create<float>('c', {5, 1, 4}, {0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f, 0.50563407f, 0.89252293f, 0.5461209f, 0.92336726f, 0.085571885f, 0.7937801f, 0.65908563f, 0.55552566f, 0.15962744f, 0.30874777f, 0.15476847f, 0.46954823f, 0.9938899f, 0.6112741f});

View File

@@ -373,35 +373,6 @@ TEST_F(DeclarableOpsTests2, NLP_Cbow_Test_1) {
delete result;
}
TEST_F(DeclarableOpsTests2, Test_Concat_3D_1) {
auto x0 = NDArrayFactory::create<double>('c', {1, 100, 150});
auto x1 = NDArrayFactory::create<double>('c', {1, 100, 150});
auto x2 = NDArrayFactory::create<double>('c', {1, 100, 150});
auto x3 = NDArrayFactory::create<double>('c', {1, 100, 150});
x0.assign(1.0);
x1.assign(2.0);
x2.assign(3.0);
x3.assign(4.0);
nd4j::ops::concat op;
auto result = op.execute({&x0, &x1, &x2, &x3}, {}, {0}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
Nd4jLong numOfTads= ShapeUtils::getNumOfSubArrs(z->getShapeInfo(), {0});
ASSERT_TRUE(4 == numOfTads);
for (int e = 0; e < numOfTads; e++) {
NDArray tad = (*z)(e, {0});
auto mean = tad.meanNumber().e<double>(0);
ASSERT_NEAR((double) e+1, mean, 1e-5);
}
delete result;
}
TEST_F(DeclarableOpsTests2, YetAnotherMatmulTest_1) {
auto A = NDArrayFactory::create<float>('c', {3, 3});
auto B = NDArrayFactory::create<float>('c', {3, 1});

View File

@@ -2845,31 +2845,4 @@ TEST_F(DeclarableOpsTests6, Test_Diag_119_3) {
delete result;
}
TEST_F(DeclarableOpsTests6, concat_test14) {
NDArray x0('c', {1, 55, 40}, nd4j::DataType::DOUBLE);
NDArray x1('c', {1, 55, 40}, nd4j::DataType::DOUBLE);
x0 = 1.;
x1 = 2.;
nd4j::ops::concat op;
auto result = op.execute({&x0, &x1}, {}, {0}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
// z->printShapeInfo();
// z->printIndexedBuffer();
Nd4jLong numOfTads= ShapeUtils::getNumOfSubArrs(z->getShapeInfo(), {0});
ASSERT_TRUE(2 == numOfTads);
for (int e = 0; e < numOfTads; ++e) {
NDArray tad = (*z)(e, {0});
auto mean = tad.meanNumber().e<double>(0);
ASSERT_NEAR((e+1)*1., mean, 1e-5);
}
delete result;
}

View File

@@ -584,6 +584,180 @@ TEST_F(DeclarableOpsTests9, concat_test16) {
delete result;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test17) {
NDArray x0('c', {1, 55, 40}, nd4j::DataType::DOUBLE);
NDArray x1('c', {1, 55, 40}, nd4j::DataType::DOUBLE);
x0 = 1.;
x1 = 2.;
nd4j::ops::concat op;
auto result = op.execute({&x0, &x1}, {}, {0}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
// z->printShapeInfo();
// z->printIndexedBuffer();
Nd4jLong numOfTads= ShapeUtils::getNumOfSubArrs(z->getShapeInfo(), {0});
ASSERT_TRUE(2 == numOfTads);
for (int e = 0; e < numOfTads; ++e) {
NDArray tad = (*z)(e, {0});
auto mean = tad.meanNumber().e<double>(0);
ASSERT_NEAR((e+1)*1., mean, 1e-5);
}
delete result;
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test18) {
std::array<NDArray*, 2000> arrays;
Context context(1);
Nd4jLong axis = 0;
// we create a bunch of arrays, filled with specific values
for (int e = 0; e < arrays.size(); e++) {
auto array = NDArrayFactory::create_<float>('c', {1, 300});
array->assign(e);
context.setInputArray(e, array, true);
}
auto z = NDArrayFactory::create<float>('c', {2000, 300});
context.setOutputArray(0, &z, false);
context.setIArguments(&axis, 1);
nd4j::ops::concat op;
op.execute(&context);
for (int e = 0; e < arrays.size(); e++) {
auto row = z.tensorAlongDimension(e, {1});
ASSERT_NEAR((float) e, row->e<float>(0), 1e-5f);
delete row;
}
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test19) {
std::array<NDArray*, 10> arrays;
Context context(1);
Nd4jLong axis = 0;
// we create a bunch of arrays, filled with specific values
for (int e = 0; e < arrays.size(); e++) {
auto array = NDArrayFactory::create_<float>('c', {1, 5, 20});
array->assign(e);
context.setInputArray(e, array, true);
}
auto z = NDArrayFactory::create<float>('c', {arrays.size(), 5, 20});
context.setOutputArray(0, &z, false);
context.setIArguments(&axis, 1);
nd4j::ops::concat op;
op.execute(&context);
for (int e = 0; e < arrays.size(); e++)
ASSERT_NEAR((float) e, z(e, {0}).meanNumber().e<float>(0), 1e-5f);
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test20) {
auto x0 = NDArrayFactory::create<double>('c', {1, 100, 150});
auto x1 = NDArrayFactory::create<double>('c', {1, 100, 150});
auto x2 = NDArrayFactory::create<double>('c', {1, 100, 150});
auto x3 = NDArrayFactory::create<double>('c', {1, 100, 150});
x0.assign(1.0);
x1.assign(2.0);
x2.assign(3.0);
x3.assign(4.0);
nd4j::ops::concat op;
auto result = op.execute({&x0, &x1, &x2, &x3}, {}, {0}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
Nd4jLong numOfTads= ShapeUtils::getNumOfSubArrs(z->getShapeInfo(), {0});
ASSERT_TRUE(4 == numOfTads);
for (int e = 0; e < numOfTads; e++) {
NDArray tad = (*z)(e, {0});
auto mean = tad.meanNumber().e<double>(0);
ASSERT_NEAR((double) e+1, mean, 1e-5);
}
delete result;
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test21) {
NDArray x0('c', {1,4,5}, nd4j::DataType::FLOAT32);
NDArray x1('c', {2,4,5}, nd4j::DataType::FLOAT32);
NDArray z('f', {3,4,5}, nd4j::DataType::FLOAT32);
x0 = 0.;
x1 = 1.;
nd4j::ops::concat op;
auto status = op.execute({&x0, &x1}, {&z}, {}, {0}, {});
ASSERT_EQ(ND4J_STATUS_OK, status);
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test22) {
NDArray x0('c', {1,6}, {1,2,3,4,5,6});
NDArray x1('c', {1,6}, {7,8,9,10,11,12});
NDArray output('f', {2,6}, nd4j::DataType::DOUBLE);
NDArray exp('c', {2,6}, {1,2,3,4,5,6,7,8,9,10,11,12});
nd4j::ops::concat op;
auto status = op.execute({&x0, &x1}, {&output}, {}, {0}, {});
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(exp.equalsTo(output));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test23) {
NDArray x0('c', {1,4}, {1,2,3,4});
NDArray x1('c', {1,4}, {5,6,7,8});
NDArray output('c', {2,4}, nd4j::DataType::DOUBLE);
NDArray exp('c', {2,4}, {1,2,3,4,5,6,7,8});
nd4j::ops::concat op;
auto status = op.execute({&x0, &x1}, {&output}, {}, {0}, {});
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(exp.equalsTo(output));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test24) {
auto x = NDArrayFactory::create<double>('c', {2, 1}, {1, 1});
auto y = NDArrayFactory::create<double>('c', {2, 1}, {0, 0});
auto e = NDArrayFactory::create<double>('c', {2, 2}, {1, 0, 1, 0});
auto z = NDArrayFactory::create<double>('c', {2, 2});
nd4j::ops::concat op;
auto status = op.execute({&x, &y}, {&z}, {}, {1}, {});
ASSERT_EQ(Status::OK(), status);
ASSERT_EQ(e, z);
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, tile_bp_test1) {