Shyrma concat2 (#119)
* - rewrite/improve concat
  Signed-off-by: Yurii <yurii@skymind.io>
* - get rid of unnecessary argument in concat kernel
  Signed-off-by: Yurii <yurii@skymind.io>

branch: master
parent: b370544b8f
commit: 9f2ba6a85d
@@ -30,103 +30,101 @@
#include <ConstantTadHelper.h>

namespace nd4j {
namespace ops {
namespace helpers {


///////////////////////////////////////////////////////////////////
template<typename T>
-__global__ static void concatCuda(const int numOfArrs, void* pVx, void* pxShapeInfo, void* pVz, void* pzShapeInfo) {
+__global__ static void concatCuda(void* pVx, void* pxShapeInfo, void* vz, Nd4jLong* zShapeInfo, const int axis) {

-    __shared__ int arrIdx, blocksPerArr;
+    T* z = reinterpret_cast<T*>(vz);
+
+    __shared__ Nd4jLong zLen, totalThreads, *sharedMem;
+    __shared__ int rank;

    if (threadIdx.x == 0) {
-        blocksPerArr = (gridDim.x + numOfArrs - 1) / numOfArrs;     // ceil
-        arrIdx = blockIdx.x / blocksPerArr;
+        extern __shared__ unsigned char shmem[];
+        sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
+
+        zLen = shape::length(zShapeInfo);
+        rank = shape::rank(zShapeInfo);
+        totalThreads = gridDim.x * blockDim.x;
    }

    __syncthreads();

-    for(int j = arrIdx; j < numOfArrs; j += gridDim.x) {
-
-        const auto* x = reinterpret_cast<T*>(reinterpret_cast<void**>(pVx)[j]);
-        auto* z = reinterpret_cast<T*>(reinterpret_cast<void**>(pVz)[j]);
-        const auto* xShapeInfo = reinterpret_cast<Nd4jLong**>(pxShapeInfo)[j];
-        const auto* zShapeInfo = reinterpret_cast<Nd4jLong**>(pzShapeInfo)[j];
-
-        const auto arrLen = shape::length(xShapeInfo);
-
-        const auto arrLenPerBlock = (arrLen + blocksPerArr - 1) / blocksPerArr;  // ceil
-
-        const auto start = (blockIdx.x % blocksPerArr) * arrLenPerBlock;
-        const auto end   = (start + arrLenPerBlock) > arrLen ? arrLen : (start + arrLenPerBlock);
-
-        for (Nd4jLong i = start + threadIdx.x; i < end; i += blockDim.x)
-            z[shape::getIndexOffset(i, zShapeInfo, arrLen)] = x[shape::getIndexOffset(i, xShapeInfo, arrLen)];
-    }
+    const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if(tid >= zLen)
+        return;
+
+    auto coords = sharedMem + threadIdx.x * rank;
+
+    shape::index2coords(rank, zShapeInfo + 1, tid, zLen, coords);
+
+    const auto zOffset = shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank);
+
+    int inArrIdx = 0;
+    Nd4jLong *xShapeInfo = reinterpret_cast<Nd4jLong**>(pxShapeInfo)[inArrIdx];
+
+    while(coords[axis] >= xShapeInfo[axis + 1]) {
+        coords[axis] -= xShapeInfo[axis + 1];
+        xShapeInfo = reinterpret_cast<Nd4jLong**>(pxShapeInfo)[++inArrIdx];
+    }
+
+    const auto* x = reinterpret_cast<T*>(reinterpret_cast<void**>(pVx)[inArrIdx]);
+    const auto xOffset = shape::getOffset(0, xShapeInfo + 1, xShapeInfo + rank + 1, coords, rank);
+
+    z[zOffset] = x[xOffset];
}
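In plain terms, the rewritten kernel drops the per-array block partitioning of the old version: every thread now owns exactly one element of the output, converts its flat index into coordinates, and walks the inputs along the concatenation axis until it reaches the array that owns that coordinate. The host-side C++ sketch below only illustrates that mapping; it assumes dense C-ordered buffers, and the helper names (flatToCoords, coordsToFlat, concatReference) are invented for the example rather than libnd4j APIs.

#include <cstdint>
#include <iostream>
#include <vector>

using Shape = std::vector<int64_t>;

// decompose a flat row-major index into per-dimension coordinates
static void flatToCoords(int64_t idx, const Shape& shape, std::vector<int64_t>& coords) {
    for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
        coords[i] = idx % shape[i];
        idx /= shape[i];
    }
}

// compose per-dimension coordinates back into a flat row-major index
static int64_t coordsToFlat(const std::vector<int64_t>& coords, const Shape& shape) {
    int64_t idx = 0;
    for (size_t i = 0; i < shape.size(); ++i)
        idx = idx * shape[i] + coords[i];
    return idx;
}

// one output element per iteration, each iteration playing the role of one CUDA thread
static void concatReference(const std::vector<std::vector<float>>& inputs, const std::vector<Shape>& inShapes,
                            std::vector<float>& out, const Shape& outShape, int axis) {
    const int rank = static_cast<int>(outShape.size());
    int64_t outLen = 1;
    for (auto d : outShape) outLen *= d;

    std::vector<int64_t> coords(rank);
    for (int64_t tid = 0; tid < outLen; ++tid) {
        flatToCoords(tid, outShape, coords);

        // same owning-array search as the kernel's while(coords[axis] >= xShapeInfo[axis + 1])
        int inArrIdx = 0;
        while (coords[axis] >= inShapes[inArrIdx][axis]) {
            coords[axis] -= inShapes[inArrIdx][axis];
            ++inArrIdx;
        }
        out[tid] = inputs[inArrIdx][coordsToFlat(coords, inShapes[inArrIdx])];
    }
}

int main() {
    // concat {1,2,3} and {4,5} along axis 0 -> expect 1 2 3 4 5
    std::vector<std::vector<float>> in = {{1, 2, 3}, {4, 5}};
    std::vector<Shape> shapes = {{3}, {2}};
    std::vector<float> out(5);
    concatReference(in, shapes, out, {5}, 0);
    for (auto v : out) std::cout << v << " ";
    std::cout << "\n";
}

Because each thread computes its own output offset from the output shape info, the kernel no longer relies on contiguous per-array sub-buffers, which is what the consolidated tests further down exercise.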

///////////////////////////////////////////////////////////////////
template<typename T>
-__host__ static void concatCudaLauncher(const int numOfArrs, const cudaStream_t *stream, void* pVx, void* pxShapeInfo, void* pVz, void* pzShapeInfo) {
+__host__ static void concatCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream,
+                                        void* pVx, void* pxShapeInfo, void* vz, Nd4jLong* zShapeInfo, const int axis) {

-    concatCuda<T><<<512, 512, 512, *stream>>>(numOfArrs, pVx, pxShapeInfo, pVz, pzShapeInfo);
+    concatCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(pVx, pxShapeInfo, vz, zShapeInfo, axis);
}
-BUILD_SINGLE_TEMPLATE(template void concatCudaLauncher, (const int numOfArrs, const cudaStream_t *stream, void* pVx, void* pxShapeInfo, void* pVz, void* pzShapeInfo), LIBND4J_TYPES);
+BUILD_SINGLE_TEMPLATE(template void concatCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, void* pVx, void* pxShapeInfo, void* vz, Nd4jLong* zShapeInfo, const int axis), LIBND4J_TYPES);

//////////////////////////////////////////////////////////////////////////
void concat(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output, const int axis) {

+    const int threadsPerBlock = MAX_NUM_THREADS / 4;
+    const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
+    const int sharedMem = threadsPerBlock * sizeof(Nd4jLong) * output.rankOf() + 128;
+
    const int numOfArrs = inArrs.size();

    for(int i = 0; i < numOfArrs; ++i)
-        if(!inArrs[i]->isActualOnDeviceSide()) inArrs[i]->syncToDevice();
+        inArrs[i]->syncToDevice();
+
+    output.syncToDevice();

-    const int rank  = inArrs[0]->rankOf();
-    const int rank2 = 2*rank;
-    std::vector<std::vector<Nd4jLong>> indices(numOfArrs, std::vector<Nd4jLong>(rank2,0));
-
-    // take into account indices for first array
-    indices[0][2 * axis + 1] = inArrs[0]->sizeAt(axis);
-
-    // loop through the rest of input arrays
-    for(int i = 1; i < numOfArrs; ++i) {
-        indices[i][2 * axis]     = indices[i-1][2 * axis + 1];                               // index start from
-        indices[i][2 * axis + 1] = indices[i-1][2 * axis + 1] + inArrs[i]->sizeAt(axis);     // index end with (excluding)
-    }
-
-    std::vector<NDArray*> outSubArrs(numOfArrs);
-    for(int i = 0; i < numOfArrs; ++i)
-        outSubArrs[i] = new NDArray(output(indices[i], true));

    // prepare arrays of pointers on buffers and shapes
-    std::vector<void*> hOutBuffers(numOfArrs), hInBuffers(numOfArrs);
-    std::vector<Nd4jLong*> hOutShapeInfo(numOfArrs), hInShapeInfo(numOfArrs);
+    std::vector<void*> hInBuffers(numOfArrs);
+    std::vector<Nd4jLong*> hInShapeInfo(numOfArrs);

    for(int i = 0; i < numOfArrs; ++i) {
-        hOutBuffers[i]   = outSubArrs[i]->getSpecialBuffer();
        hInBuffers[i]    = inArrs[i]->getSpecialBuffer();
-        hOutShapeInfo[i] = outSubArrs[i]->getSpecialShapeInfo();
        hInShapeInfo[i]  = inArrs[i]->getSpecialShapeInfo();
    }

    // allocate and copy all buffers and shapes arrays to global memory
    PointersManager manager(context, "helpers::concat");
-    void* dOutBuffers   = manager.replicatePointer(hOutBuffers.data(), hOutBuffers.size() * sizeof(void*));
    void* dInBuffers    = manager.replicatePointer(hInBuffers.data(), hInBuffers.size() * sizeof(void*));
    void* dInShapeInfo  = manager.replicatePointer(hInShapeInfo.data(), hInShapeInfo.size() * sizeof(Nd4jLong*));
-    void* dOutShapeInfo = manager.replicatePointer(hOutShapeInfo.data(), hOutShapeInfo.size() * sizeof(Nd4jLong*));

-    BUILD_SINGLE_SELECTOR(inArrs[0]->dataType(), concatCudaLauncher, (numOfArrs, context->getCudaStream(), dInBuffers, dInShapeInfo, dOutBuffers, dOutShapeInfo), LIBND4J_TYPES);
+    BUILD_SINGLE_SELECTOR(inArrs[0]->dataType(), concatCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), dInBuffers, dInShapeInfo, output.specialBuffer(), output.specialShapeInfo(), axis), LIBND4J_TYPES);

    manager.synchronize();

-    for(int i = 0; i < numOfArrs; ++i)
-        delete outSubArrs[i];
-
    for(int i = 0; i < numOfArrs; ++i)
-        inArrs[i]->tickReadHost();
+        inArrs[i]->tickReadDevice();

    output.tickWriteDevice();
}

}
}
}
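A side note on the launch configuration chosen in the host function above: the dynamic shared-memory request grows with the block size and the output rank because each thread keeps a private coordinate vector of rank Nd4jLong values in shared memory. The following is a minimal sketch of that sizing only, assuming MAX_NUM_THREADS resolves to 1024 (an assumption made for this example) and using an invented helper name.

#include <cstddef>
#include <cstdint>

using Nd4jLong = int64_t;   // stands in for the libnd4j typedef in this sketch

struct LaunchConfig {
    int blocksPerGrid;
    int threadsPerBlock;
    size_t sharedMemBytes;
};

// mirrors the arithmetic in the diff: one output element per thread,
// plus a per-thread scratch array of `rank` coordinates in dynamic shared memory
LaunchConfig makeConcatLaunchConfig(Nd4jLong outputLength, int outputRank, int maxThreads = 1024) {
    LaunchConfig cfg;
    cfg.threadsPerBlock = maxThreads / 4;                                                                    // MAX_NUM_THREADS / 4
    cfg.blocksPerGrid   = static_cast<int>((outputLength + cfg.threadsPerBlock - 1) / cfg.threadsPerBlock);  // ceil divide
    cfg.sharedMemBytes  = static_cast<size_t>(cfg.threadsPerBlock) * sizeof(Nd4jLong) * outputRank + 128;    // per-thread coords + slack
    return cfg;
}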
@@ -750,59 +750,6 @@ TEST_F(DeclarableOpsTests12, tensormmul_6) {
}

-////////////////////////////////////////////////////////////////////////////////
-TEST_F(DeclarableOpsTests12, concat_test10) {
-
-    NDArray x0('c', {1,4,5}, nd4j::DataType::FLOAT32);
-    NDArray x1('c', {2,4,5}, nd4j::DataType::FLOAT32);
-    NDArray z('f', {3,4,5}, nd4j::DataType::FLOAT32);
-
-    x0 = 0.;
-    x1 = 1.;
-
-    nd4j::ops::concat op;
-    auto status = op.execute({&x0, &x1}, {&z}, {}, {0}, {});
-    ASSERT_EQ(ND4J_STATUS_OK, status);
-}
-
-////////////////////////////////////////////////////////////////////////////////
-TEST_F(DeclarableOpsTests12, concat_14) {
-
-    NDArray x0('c', {1,6}, {1,2,3,4,5,6});
-    NDArray x1('c', {1,6}, {7,8,9,10,11,12});
-    NDArray output('f', {2,6}, nd4j::DataType::DOUBLE);
-    NDArray exp('c', {2,6}, {1,2,3,4,5,6,7,8,9,10,11,12});
-
-    nd4j::ops::concat op;
-
-    auto status = op.execute({&x0, &x1}, {&output}, {}, {0}, {});
-    ASSERT_EQ(ND4J_STATUS_OK, status);
-    // output.printBuffer();
-    // output.printIndexedBuffer();
-
-    ASSERT_TRUE(exp.equalsTo(output));
-}
-
-////////////////////////////////////////////////////////////////////////////////
-TEST_F(DeclarableOpsTests12, concat_15) {
-
-    NDArray x0('c', {1,4}, {1,2,3,4});
-    NDArray x1('c', {1,4}, {5,6,7,8});
-    NDArray output('c', {2,4}, nd4j::DataType::DOUBLE);
-    NDArray exp('c', {2,4}, {1,2,3,4,5,6,7,8});
-
-    nd4j::ops::concat op;
-
-    auto status = op.execute({&x0, &x1}, {&output}, {}, {0}, {});
-    ASSERT_EQ(ND4J_STATUS_OK, status);
-    // output.printBuffer();
-    // output.printIndexedBuffer();
-
-    ASSERT_TRUE(exp.equalsTo(output));
-}
-

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests12, reduceMeanBp_4) {
@@ -364,77 +364,6 @@ TEST_F(DeclarableOpsTests15, test_rank_2) {
    delete result;
}

-TEST_F(DeclarableOpsTests15, test_concat_column_1) {
-    auto x = NDArrayFactory::create<double>('c', {2, 1}, {1, 1});
-    auto y = NDArrayFactory::create<double>('c', {2, 1}, {0, 0});
-    auto e = NDArrayFactory::create<double>('c', {2, 2}, {1, 0, 1, 0});
-    auto z = NDArrayFactory::create<double>('c', {2, 2});
-
-    nd4j::ops::concat op;
-    auto status = op.execute({&x, &y}, {&z}, {}, {1}, {});
-    ASSERT_EQ(Status::OK(), status);
-
-    z.printIndexedBuffer("z");
-
-    ASSERT_EQ(e, z);
-}
-
-TEST_F(DeclarableOpsTests15, test_concat_large_1) {
-    std::array<NDArray*, 2000> arrays;
-    Context context(1);
-    Nd4jLong axis = 0;
-
-    // we crate bunch of arrays, filled with specific values
-    for (int e = 0; e < arrays.size(); e++) {
-        auto array = NDArrayFactory::create_<float>('c', {1, 300});
-        array->assign(e);
-        context.setInputArray(e, array, true);
-    }
-
-    auto z = NDArrayFactory::create<float>('c', {2000, 300});
-    context.setOutputArray(0, &z, false);
-    context.setIArguments(&axis, 1);
-
-    nd4j::ops::concat op;
-    op.execute(&context);
-
-    for (int e = 0; e < arrays.size(); e++) {
-        auto row = z.tensorAlongDimension(e, {1});
-
-        ASSERT_NEAR((float) e, row->e<float>(0), 1e-5f);
-
-        delete row;
-    }
-}
-
-TEST_F(DeclarableOpsTests15, test_concat_large_2) {
-    std::array<NDArray*, 10> arrays;
-    Context context(1);
-    Nd4jLong axis = 0;
-
-    // we crate bunch of arrays, filled with specific values
-    for (int e = 0; e < arrays.size(); e++) {
-        auto array = NDArrayFactory::create_<float>('c', {1, 5, 20});
-        array->assign(e);
-        context.setInputArray(e, array, true);
-    }
-
-    auto z = NDArrayFactory::create<float>('c', {arrays.size(), 5, 20});
-    context.setOutputArray(0, &z, false);
-    context.setIArguments(&axis, 1);
-
-    nd4j::ops::concat op;
-    op.execute(&context);
-
-    for (int e = 0; e < arrays.size(); e++) {
-        auto row = z.tensorAlongDimension(e, {1, 2});
-
-        ASSERT_NEAR((float) e, row->meanNumber().e<float>(0), 1e-5f);
-
-        delete row;
-    }
-}
-
TEST_F(DeclarableOpsTests15, test_lstmBlock_1) {
    auto x0 = NDArrayFactory::create<Nd4jLong>(5);
    auto x1 = NDArrayFactory::create<float>('c', {5, 1, 4}, {0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f, 0.50563407f, 0.89252293f, 0.5461209f, 0.92336726f, 0.085571885f, 0.7937801f, 0.65908563f, 0.55552566f, 0.15962744f, 0.30874777f, 0.15476847f, 0.46954823f, 0.9938899f, 0.6112741f});
@@ -373,35 +373,6 @@ TEST_F(DeclarableOpsTests2, NLP_Cbow_Test_1) {
    delete result;
}

-TEST_F(DeclarableOpsTests2, Test_Concat_3D_1) {
-    auto x0 = NDArrayFactory::create<double>('c', {1, 100, 150});
-    auto x1 = NDArrayFactory::create<double>('c', {1, 100, 150});
-    auto x2 = NDArrayFactory::create<double>('c', {1, 100, 150});
-    auto x3 = NDArrayFactory::create<double>('c', {1, 100, 150});
-
-    x0.assign(1.0);
-    x1.assign(2.0);
-    x2.assign(3.0);
-    x3.assign(4.0);
-
-    nd4j::ops::concat op;
-    auto result = op.execute({&x0, &x1, &x2, &x3}, {}, {0}, {});
-    ASSERT_EQ(Status::OK(), result->status());
-
-    auto z = result->at(0);
-
-    Nd4jLong numOfTads = ShapeUtils::getNumOfSubArrs(z->getShapeInfo(), {0});
-    ASSERT_TRUE(4 == numOfTads);
-
-    for (int e = 0; e < numOfTads; e++) {
-        NDArray tad = (*z)(e, {0});
-        auto mean = tad.meanNumber().e<double>(0);
-        ASSERT_NEAR((double) e+1, mean, 1e-5);
-    }
-
-    delete result;
-}
-
TEST_F(DeclarableOpsTests2, YetAnotherMatmulTest_1) {
    auto A = NDArrayFactory::create<float>('c', {3, 3});
    auto B = NDArrayFactory::create<float>('c', {3, 1});
@@ -2845,31 +2845,4 @@ TEST_F(DeclarableOpsTests6, Test_Diag_119_3) {
    delete result;
}

-TEST_F(DeclarableOpsTests6, concat_test14) {
-
-    NDArray x0('c', {1, 55, 40}, nd4j::DataType::DOUBLE);
-    NDArray x1('c', {1, 55, 40}, nd4j::DataType::DOUBLE);
-
-    x0 = 1.;
-    x1 = 2.;
-
-    nd4j::ops::concat op;
-    auto result = op.execute({&x0, &x1}, {}, {0}, {});
-    ASSERT_EQ(Status::OK(), result->status());
-
-    auto z = result->at(0);
-    // z->printShapeInfo();
-    // z->printIndexedBuffer();
-
-    Nd4jLong numOfTads = ShapeUtils::getNumOfSubArrs(z->getShapeInfo(), {0});
-    ASSERT_TRUE(2 == numOfTads);
-
-    for (int e = 0; e < numOfTads; ++e) {
-        NDArray tad = (*z)(e, {0});
-        auto mean = tad.meanNumber().e<double>(0);
-        ASSERT_NEAR((e+1)*1., mean, 1e-5);
-    }
-
-    delete result;
-}
@@ -584,6 +584,180 @@ TEST_F(DeclarableOpsTests9, concat_test16) {
    delete result;
}

+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests9, concat_test17) {
+
+    NDArray x0('c', {1, 55, 40}, nd4j::DataType::DOUBLE);
+    NDArray x1('c', {1, 55, 40}, nd4j::DataType::DOUBLE);
+
+    x0 = 1.;
+    x1 = 2.;
+
+    nd4j::ops::concat op;
+    auto result = op.execute({&x0, &x1}, {}, {0}, {});
+    ASSERT_EQ(Status::OK(), result->status());
+
+    auto z = result->at(0);
+    // z->printShapeInfo();
+    // z->printIndexedBuffer();
+
+    Nd4jLong numOfTads = ShapeUtils::getNumOfSubArrs(z->getShapeInfo(), {0});
+    ASSERT_TRUE(2 == numOfTads);
+
+    for (int e = 0; e < numOfTads; ++e) {
+        NDArray tad = (*z)(e, {0});
+        auto mean = tad.meanNumber().e<double>(0);
+        ASSERT_NEAR((e+1)*1., mean, 1e-5);
+    }
+
+    delete result;
+}
+
+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests9, concat_test18) {
+    std::array<NDArray*, 2000> arrays;
+    Context context(1);
+    Nd4jLong axis = 0;
+
+    // we crate bunch of arrays, filled with specific values
+    for (int e = 0; e < arrays.size(); e++) {
+        auto array = NDArrayFactory::create_<float>('c', {1, 300});
+        array->assign(e);
+        context.setInputArray(e, array, true);
+    }
+
+    auto z = NDArrayFactory::create<float>('c', {2000, 300});
+    context.setOutputArray(0, &z, false);
+    context.setIArguments(&axis, 1);
+
+    nd4j::ops::concat op;
+    op.execute(&context);
+
+    for (int e = 0; e < arrays.size(); e++) {
+        auto row = z.tensorAlongDimension(e, {1});
+
+        ASSERT_NEAR((float) e, row->e<float>(0), 1e-5f);
+
+        delete row;
+    }
+}
+
+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests9, concat_test19) {
+
+    std::array<NDArray*, 10> arrays;
+    Context context(1);
+    Nd4jLong axis = 0;
+
+    // we crate bunch of arrays, filled with specific values
+    for (int e = 0; e < arrays.size(); e++) {
+        auto array = NDArrayFactory::create_<float>('c', {1, 5, 20});
+        array->assign(e);
+        context.setInputArray(e, array, true);
+    }
+
+    auto z = NDArrayFactory::create<float>('c', {arrays.size(), 5, 20});
+    context.setOutputArray(0, &z, false);
+    context.setIArguments(&axis, 1);
+
+    nd4j::ops::concat op;
+    op.execute(&context);
+
+    for (int e = 0; e < arrays.size(); e++)
+        ASSERT_NEAR((float) e, z(e, {0}).meanNumber().e<float>(0), 1e-5f);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests9, concat_test20) {
+    auto x0 = NDArrayFactory::create<double>('c', {1, 100, 150});
+    auto x1 = NDArrayFactory::create<double>('c', {1, 100, 150});
+    auto x2 = NDArrayFactory::create<double>('c', {1, 100, 150});
+    auto x3 = NDArrayFactory::create<double>('c', {1, 100, 150});
+
+    x0.assign(1.0);
+    x1.assign(2.0);
+    x2.assign(3.0);
+    x3.assign(4.0);
+
+    nd4j::ops::concat op;
+    auto result = op.execute({&x0, &x1, &x2, &x3}, {}, {0}, {});
+    ASSERT_EQ(Status::OK(), result->status());
+
+    auto z = result->at(0);
+
+    Nd4jLong numOfTads = ShapeUtils::getNumOfSubArrs(z->getShapeInfo(), {0});
+    ASSERT_TRUE(4 == numOfTads);
+
+    for (int e = 0; e < numOfTads; e++) {
+        NDArray tad = (*z)(e, {0});
+        auto mean = tad.meanNumber().e<double>(0);
+        ASSERT_NEAR((double) e+1, mean, 1e-5);
+    }
+
+    delete result;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests9, concat_test21) {
+
+    NDArray x0('c', {1,4,5}, nd4j::DataType::FLOAT32);
+    NDArray x1('c', {2,4,5}, nd4j::DataType::FLOAT32);
+    NDArray z('f', {3,4,5}, nd4j::DataType::FLOAT32);
+
+    x0 = 0.;
+    x1 = 1.;
+
+    nd4j::ops::concat op;
+    auto status = op.execute({&x0, &x1}, {&z}, {}, {0}, {});
+    ASSERT_EQ(ND4J_STATUS_OK, status);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests9, concat_test22) {
+
+    NDArray x0('c', {1,6}, {1,2,3,4,5,6});
+    NDArray x1('c', {1,6}, {7,8,9,10,11,12});
+    NDArray output('f', {2,6}, nd4j::DataType::DOUBLE);
+    NDArray exp('c', {2,6}, {1,2,3,4,5,6,7,8,9,10,11,12});
+
+    nd4j::ops::concat op;
+
+    auto status = op.execute({&x0, &x1}, {&output}, {}, {0}, {});
+    ASSERT_EQ(ND4J_STATUS_OK, status);
+
+    ASSERT_TRUE(exp.equalsTo(output));
+}
+
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests9, concat_test23) {
+
+    NDArray x0('c', {1,4}, {1,2,3,4});
+    NDArray x1('c', {1,4}, {5,6,7,8});
+    NDArray output('c', {2,4}, nd4j::DataType::DOUBLE);
+    NDArray exp('c', {2,4}, {1,2,3,4,5,6,7,8});
+
+    nd4j::ops::concat op;
+
+    auto status = op.execute({&x0, &x1}, {&output}, {}, {0}, {});
+    ASSERT_EQ(ND4J_STATUS_OK, status);
+
+    ASSERT_TRUE(exp.equalsTo(output));
+}
+
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests9, concat_test24) {
+    auto x = NDArrayFactory::create<double>('c', {2, 1}, {1, 1});
+    auto y = NDArrayFactory::create<double>('c', {2, 1}, {0, 0});
+    auto e = NDArrayFactory::create<double>('c', {2, 2}, {1, 0, 1, 0});
+    auto z = NDArrayFactory::create<double>('c', {2, 2});
+
+    nd4j::ops::concat op;
+    auto status = op.execute({&x, &y}, {&z}, {}, {1}, {});
+    ASSERT_EQ(Status::OK(), status);
+
+    ASSERT_EQ(e, z);
+}
+
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, tile_bp_test1) {
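One reason several of the relocated tests (concat_test21 and concat_test22 above) target an 'f'-ordered output: the element-wise kernel derives each output offset from the output shape info rather than assuming contiguous sub-array writes, so column-major outputs are handled the same way as row-major ones. The snippet below is only a small sketch of how the same coordinate maps to different offsets in 'c' versus 'f' order, with dense strides assumed and invented helper names.

#include <cstdint>
#include <iostream>
#include <vector>

// offset of a coordinate in a dense row-major ('c') buffer
int64_t offsetC(const std::vector<int64_t>& coords, const std::vector<int64_t>& shape) {
    int64_t off = 0;
    for (size_t i = 0; i < shape.size(); ++i)
        off = off * shape[i] + coords[i];
    return off;
}

// offset of the same coordinate in a dense column-major ('f') buffer
int64_t offsetF(const std::vector<int64_t>& coords, const std::vector<int64_t>& shape) {
    int64_t off = 0;
    for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i)
        off = off * shape[i] + coords[i];
    return off;
}

int main() {
    std::vector<int64_t> shape  = {2, 6};   // like the {2,6} 'f'-ordered output in concat_test22
    std::vector<int64_t> coords = {1, 2};   // row 1, column 2
    std::cout << "c-order offset: " << offsetC(coords, shape) << "\n";   // 1*6 + 2 = 8
    std::cout << "f-order offset: " << offsetF(coords, shape) << "\n";   // 1 + 2*2 = 5
}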