Refactor NativeOps.h to export C functions
parent
fad8da878f
commit
dcc72e23b2
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -27,11 +27,10 @@ namespace nd4j {
|
||||||
ProviderRNG::ProviderRNG() {
|
ProviderRNG::ProviderRNG() {
|
||||||
|
|
||||||
Nd4jLong *buffer = new Nd4jLong[100000];
|
Nd4jLong *buffer = new Nd4jLong[100000];
|
||||||
NativeOps nativeOps;
|
|
||||||
std::lock_guard<std::mutex> lock(_mutex);
|
std::lock_guard<std::mutex> lock(_mutex);
|
||||||
#ifndef __CUDABLAS__
|
#ifndef __CUDABLAS__
|
||||||
// at this moment we don't have streams etc, so let's just skip this for now
|
// at this moment we don't have streams etc, so let's just skip this for now
|
||||||
_rng = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, 123, 100000, (Nd4jPointer) buffer);
|
_rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, 123, 100000, (Nd4jPointer) buffer);
|
||||||
#endif
|
#endif
|
||||||
// if(_rng != nullptr)
|
// if(_rng != nullptr)
|
||||||
}
|
}
|
||||||
|
@ -49,4 +48,4 @@ random::RandomBuffer* ProviderRNG::getRNG() const {
|
||||||
|
|
||||||
std::mutex ProviderRNG::_mutex;
|
std::mutex ProviderRNG::_mutex;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -41,8 +41,7 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
// FIXME: this approach isn't really good for cuda, since it'll assume that CUDA might get nullptr instead of stream
|
// FIXME: this approach isn't really good for cuda, since it'll assume that CUDA might get nullptr instead of stream
|
||||||
NativeOps nativeOps;
|
refreshBuffer(nullptr, seed, (Nd4jPointer) rng);
|
||||||
nativeOps.refreshBuffer(nullptr, seed, (Nd4jPointer) rng);
|
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -60,4 +59,4 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -110,11 +110,9 @@ namespace helpers {
|
||||||
indices->syncToDevice(); // linspace only on CPU, so sync to Device as well
|
indices->syncToDevice(); // linspace only on CPU, so sync to Device as well
|
||||||
|
|
||||||
NDArray scores(*scales);
|
NDArray scores(*scales);
|
||||||
NativeOps nativeOps;
|
|
||||||
|
|
||||||
Nd4jPointer extras[2] = {nullptr, stream};
|
Nd4jPointer extras[2] = {nullptr, stream};
|
||||||
|
|
||||||
nativeOps.sortByValue(extras, indices->buffer(), indices->shapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), scores.buffer(), scores.shapeInfo(), scores.specialBuffer(), scores.specialShapeInfo(), true);
|
sortByValue(extras, indices->buffer(), indices->shapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), scores.buffer(), scores.shapeInfo(), scores.specialBuffer(), scores.specialShapeInfo(), true);
|
||||||
// TO DO: sort indices using scales as value row
|
// TO DO: sort indices using scales as value row
|
||||||
//std::sort(indices.begin(), indices.end(), [scales](int i, int j) {return scales->e<T>(i) > scales->e<T>(j);});
|
//std::sort(indices.begin(), indices.end(), [scales](int i, int j) {return scales->e<T>(i) > scales->e<T>(j);});
|
||||||
I* indexBuf = reinterpret_cast<I*>(indices->specialBuffer());
|
I* indexBuf = reinterpret_cast<I*>(indices->specialBuffer());
|
||||||
|
@ -169,4 +167,4 @@ namespace helpers {
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -60,8 +60,7 @@ namespace helpers {
|
||||||
params[1] = context->getCudaStream();
|
params[1] = context->getCudaStream();
|
||||||
|
|
||||||
if (input->isVector()) {
|
if (input->isVector()) {
|
||||||
NativeOps ops;
|
sort(params, nullptr, sortedVals.shapeInfo(), sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), reverse);
|
||||||
ops.sort(params, nullptr, sortedVals.shapeInfo(), sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), reverse);
|
|
||||||
|
|
||||||
cudaMemcpy(reinterpret_cast<T*>(output->specialBuffer()), reinterpret_cast<T*>(sortedVals.specialBuffer()) + n, sizeof(T), cudaMemcpyDeviceToDevice);
|
cudaMemcpy(reinterpret_cast<T*>(output->specialBuffer()), reinterpret_cast<T*>(sortedVals.specialBuffer()) + n, sizeof(T), cudaMemcpyDeviceToDevice);
|
||||||
}
|
}
|
||||||
|
@ -74,8 +73,7 @@ namespace helpers {
|
||||||
auto pTadShapeH = packX.primaryShapeInfo();
|
auto pTadShapeH = packX.primaryShapeInfo();
|
||||||
auto pTadOffsets = packX.specialOffsets();
|
auto pTadOffsets = packX.specialOffsets();
|
||||||
// auto pLastDimData = (int*) manager.replicatePointer(lastDims.data(), lastDims.size() * sizeof(int));
|
// auto pLastDimData = (int*) manager.replicatePointer(lastDims.data(), lastDims.size() * sizeof(int));
|
||||||
NativeOps ops;
|
sortTad(params, sortedVals.buffer(), sortedVals.shapeInfo(), sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), lastDims.data(), lastDims.size(), pTadShape, pTadOffsets, reverse);
|
||||||
ops.sortTad(params, sortedVals.buffer(), sortedVals.shapeInfo(), sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), lastDims.data(), lastDims.size(), pTadShape, pTadOffsets, reverse);
|
|
||||||
// manager.synchronize();
|
// manager.synchronize();
|
||||||
sortedVals.tickWriteDevice();
|
sortedVals.tickWriteDevice();
|
||||||
sortedVals.syncToHost();
|
sortedVals.syncToHost();
|
||||||
|
|
|
@ -38,32 +38,28 @@ TEST_F(HeaderTest, test_dataTypes_1) {
|
||||||
std::string header("0NUMPY6789{'descr': '>f4");
|
std::string header("0NUMPY6789{'descr': '>f4");
|
||||||
|
|
||||||
|
|
||||||
NativeOps nativeOps;
|
ASSERT_EQ(nd4j::DataType::FLOAT32, dataTypeFromNpyHeader(const_cast<char *>(header.data())));
|
||||||
ASSERT_EQ(nd4j::DataType::FLOAT32, nativeOps.dataTypeFromNpyHeader(const_cast<char *>(header.data())));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(HeaderTest, test_dataTypes_2) {
|
TEST_F(HeaderTest, test_dataTypes_2) {
|
||||||
std::string header("0NUMPY6789{'descr': '>f8");
|
std::string header("0NUMPY6789{'descr': '>f8");
|
||||||
|
|
||||||
|
|
||||||
NativeOps nativeOps;
|
ASSERT_EQ(nd4j::DataType::DOUBLE, dataTypeFromNpyHeader(const_cast<char *>(header.data())));
|
||||||
ASSERT_EQ(nd4j::DataType::DOUBLE, nativeOps.dataTypeFromNpyHeader(const_cast<char *>(header.data())));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(HeaderTest, test_dataTypes_3) {
|
TEST_F(HeaderTest, test_dataTypes_3) {
|
||||||
std::string header("0NUMPY6789{'descr': '<i4");
|
std::string header("0NUMPY6789{'descr': '<i4");
|
||||||
|
|
||||||
|
|
||||||
NativeOps nativeOps;
|
ASSERT_EQ(nd4j::DataType::INT32, dataTypeFromNpyHeader(const_cast<char *>(header.data())));
|
||||||
ASSERT_EQ(nd4j::DataType::INT32, nativeOps.dataTypeFromNpyHeader(const_cast<char *>(header.data())));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(HeaderTest, test_dataTypes_4) {
|
TEST_F(HeaderTest, test_dataTypes_4) {
|
||||||
std::string header("0NUMPY6789{'descr': '>u2");
|
std::string header("0NUMPY6789{'descr': '>u2");
|
||||||
|
|
||||||
|
|
||||||
NativeOps nativeOps;
|
ASSERT_EQ(nd4j::DataType::UINT16, dataTypeFromNpyHeader(const_cast<char *>(header.data())));
|
||||||
ASSERT_EQ(nd4j::DataType::UINT16, nativeOps.dataTypeFromNpyHeader(const_cast<char *>(header.data())));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
@ -88,12 +84,11 @@ TEST_F(LoadFromStringTest,PathTest) {
|
||||||
ASSERT_EQ(4.0,data[3]);
|
ASSERT_EQ(4.0,data[3]);
|
||||||
Nd4jPointer pointer = reinterpret_cast<Nd4jPointer >(&loadedArr);
|
Nd4jPointer pointer = reinterpret_cast<Nd4jPointer >(&loadedArr);
|
||||||
int *shapeBuffer = shape::shapeBufferOfNpy(loadedArr);
|
int *shapeBuffer = shape::shapeBufferOfNpy(loadedArr);
|
||||||
NativeOps nativeOps;
|
Nd4jPointer pointer1 = dataPointForNumpy(loaded);
|
||||||
Nd4jPointer pointer1 = nativeOps.dataPointForNumpy(loaded);
|
|
||||||
delete[] shapeBuffer;
|
delete[] shapeBuffer;
|
||||||
|
|
||||||
double *data2 = reinterpret_cast<double *>(pointer1);
|
double *data2 = reinterpret_cast<double *>(pointer1);
|
||||||
delete[] loaded;
|
delete[] loaded;
|
||||||
}
|
}
|
||||||
|
|
||||||
*/
|
*/
|
||||||
|
|
|
@ -472,9 +472,7 @@ TEST_F(DeclarableOpsTests1, TestRng1) {
|
||||||
/*
|
/*
|
||||||
Nd4jLong *buffer = new Nd4jLong[100000];
|
Nd4jLong *buffer = new Nd4jLong[100000];
|
||||||
|
|
||||||
NativeOps nativeOps;
|
nd4j::random::RandomBuffer *rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, 123, 100000, (Nd4jPointer) buffer);
|
||||||
|
|
||||||
nd4j::random::RandomBuffer *rng = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, 123, 100000, (Nd4jPointer) buffer);
|
|
||||||
|
|
||||||
if (rng == nullptr)
|
if (rng == nullptr)
|
||||||
throw std::runtime_error("RNG initialization failed");
|
throw std::runtime_error("RNG initialization failed");
|
||||||
|
@ -496,7 +494,7 @@ TEST_F(DeclarableOpsTests1, TestRng1) {
|
||||||
|
|
||||||
ASSERT_TRUE(x->sumNumber() > 0.0);
|
ASSERT_TRUE(x->sumNumber() > 0.0);
|
||||||
|
|
||||||
nativeOps.destroyRandom((Nd4jPointer) rng);
|
destroyRandom((Nd4jPointer) rng);
|
||||||
delete[] buffer;
|
delete[] buffer;
|
||||||
|
|
||||||
delete variableSpace;
|
delete variableSpace;
|
||||||
|
@ -1450,8 +1448,6 @@ TEST_F(DeclarableOpsTests1, TestRegistrator1) {
|
||||||
|
|
||||||
// //////////////////////////////////////////////////////////////////////
|
// //////////////////////////////////////////////////////////////////////
|
||||||
// TEST_F(DeclarableOpsTests1, TestLegacyExecution1) {
|
// TEST_F(DeclarableOpsTests1, TestLegacyExecution1) {
|
||||||
// NativeOps nativeOps;
|
|
||||||
|
|
||||||
// auto x = NDArrayFactory::create_<float>('c', {10, 10});
|
// auto x = NDArrayFactory::create_<float>('c', {10, 10});
|
||||||
// x->assign(1.0f);
|
// x->assign(1.0f);
|
||||||
|
|
||||||
|
@ -1483,8 +1479,8 @@ TEST_F(DeclarableOpsTests1, TestRegistrator1) {
|
||||||
// outputShapes[0] = (Nd4jPointer) z->getShapeInfo();
|
// outputShapes[0] = (Nd4jPointer) z->getShapeInfo();
|
||||||
|
|
||||||
|
|
||||||
// //auto status = nativeOps.execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, false);
|
// //auto status = execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, false);
|
||||||
// auto status = nativeOps.execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
// auto status = execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
||||||
// ASSERT_EQ(ND4J_STATUS_OK, status);
|
// ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||||
// // z->printIndexedBuffer("Output add");
|
// // z->printIndexedBuffer("Output add");
|
||||||
// ASSERT_NEAR(2.0f, y->meanNumber().e<float>(0), 1e-5);
|
// ASSERT_NEAR(2.0f, y->meanNumber().e<float>(0), 1e-5);
|
||||||
|
@ -1503,8 +1499,6 @@ TEST_F(DeclarableOpsTests1, TestRegistrator1) {
|
||||||
|
|
||||||
// //////////////////////////////////////////////////////////////////////
|
// //////////////////////////////////////////////////////////////////////
|
||||||
// TEST_F(DeclarableOpsTests1, TestLegacyExecution2) {
|
// TEST_F(DeclarableOpsTests1, TestLegacyExecution2) {
|
||||||
// NativeOps nativeOps;
|
|
||||||
|
|
||||||
// auto x = NDArrayFactory::create_<float>('c', {10, 10});
|
// auto x = NDArrayFactory::create_<float>('c', {10, 10});
|
||||||
// x->assign(1.0f);
|
// x->assign(1.0f);
|
||||||
|
|
||||||
|
@ -1532,7 +1526,7 @@ TEST_F(DeclarableOpsTests1, TestRegistrator1) {
|
||||||
// auto outputBuffers = new Nd4jPointer[1];
|
// auto outputBuffers = new Nd4jPointer[1];
|
||||||
// auto outputShapes = new Nd4jPointer[1];
|
// auto outputShapes = new Nd4jPointer[1];
|
||||||
|
|
||||||
// nativeOps.execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, true);
|
// execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, true);
|
||||||
|
|
||||||
// ASSERT_NEAR(2.0, y->meanNumber().e<float>(0), 1e-5);
|
// ASSERT_NEAR(2.0, y->meanNumber().e<float>(0), 1e-5);
|
||||||
// ASSERT_NEAR(3.0, x->meanNumber().e<float>(0), 1e-5);
|
// ASSERT_NEAR(3.0, x->meanNumber().e<float>(0), 1e-5);
|
||||||
|
|
|
@ -876,14 +876,13 @@ TEST_F(DeclarableOpsTests12, pullRows_1) {
|
||||||
auto xTadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.getShapeInfo(), dims);
|
auto xTadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.getShapeInfo(), dims);
|
||||||
auto zTadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(z.getShapeInfo(), dims);
|
auto zTadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(z.getShapeInfo(), dims);
|
||||||
|
|
||||||
NativeOps op;
|
|
||||||
Nd4jPointer nativeStart[2];
|
Nd4jPointer nativeStart[2];
|
||||||
|
|
||||||
#ifdef __CUDABLAS__
|
#ifdef __CUDABLAS__
|
||||||
nativeStart[1] = *(x.getContext()->getCudaStream());
|
nativeStart[1] = *(x.getContext()->getCudaStream());
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
op.pullRows(nativeStart, x.buffer(), x.getShapeInfo(), x.getSpecialBuffer(), x.getSpecialShapeInfo(),
|
pullRows(nativeStart, x.buffer(), x.getShapeInfo(), x.getSpecialBuffer(), x.getSpecialShapeInfo(),
|
||||||
z.buffer(), z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
|
z.buffer(), z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
|
||||||
4, pidx,
|
4, pidx,
|
||||||
xTadPack.platformShapeInfo(), xTadPack.platformOffsets(),
|
xTadPack.platformShapeInfo(), xTadPack.platformOffsets(),
|
||||||
|
@ -912,12 +911,11 @@ TEST_F(DeclarableOpsTests12, pullRows_2) {
|
||||||
auto xTadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.getShapeInfo(), dims);
|
auto xTadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.getShapeInfo(), dims);
|
||||||
auto zTadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(z.getShapeInfo(), dims);
|
auto zTadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(z.getShapeInfo(), dims);
|
||||||
|
|
||||||
NativeOps op;
|
|
||||||
Nd4jPointer nativeStart[2];
|
Nd4jPointer nativeStart[2];
|
||||||
#ifdef __CUDABLAS__
|
#ifdef __CUDABLAS__
|
||||||
nativeStart[1] = *(x.getContext()->getCudaStream());
|
nativeStart[1] = *(x.getContext()->getCudaStream());
|
||||||
#endif
|
#endif
|
||||||
op.pullRows(nativeStart, x.buffer(), x.getShapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
pullRows(nativeStart, x.buffer(), x.getShapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
||||||
z.buffer(), z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
|
z.buffer(), z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
|
||||||
4, pidx,
|
4, pidx,
|
||||||
xTadPack.platformShapeInfo(), xTadPack.platformOffsets(),
|
xTadPack.platformShapeInfo(), xTadPack.platformOffsets(),
|
||||||
|
|
|
@ -110,8 +110,7 @@ TEST_F(DeclarableOpsTests9, exponentialDistributionInv_test1) {
|
||||||
double extraParams[] = {lambda};
|
double extraParams[] = {lambda};
|
||||||
|
|
||||||
Nd4jLong *buffer = new Nd4jLong[N];
|
Nd4jLong *buffer = new Nd4jLong[N];
|
||||||
NativeOps nativeOps;
|
auto rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
|
||||||
auto rng = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
|
|
||||||
if (rng == nullptr)
|
if (rng == nullptr)
|
||||||
throw std::runtime_error("DeclarableOpsTests9.exponentialDistributionInv_test1: RNG initialization failed !");
|
throw std::runtime_error("DeclarableOpsTests9.exponentialDistributionInv_test1: RNG initialization failed !");
|
||||||
|
|
||||||
|
@ -122,7 +121,7 @@ TEST_F(DeclarableOpsTests9, exponentialDistributionInv_test1) {
|
||||||
ASSERT_NEAR(mean, actualMean, 0.01);
|
ASSERT_NEAR(mean, actualMean, 0.01);
|
||||||
ASSERT_NEAR(std, actualStd, 0.01);
|
ASSERT_NEAR(std, actualStd, 0.01);
|
||||||
|
|
||||||
nativeOps.destroyRandom((Nd4jPointer) rng);
|
destroyRandom((Nd4jPointer) rng);
|
||||||
delete[] buffer;
|
delete[] buffer;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -142,8 +141,7 @@ TEST_F(DeclarableOpsTests9, exponentialDistributionInv_test2) {
|
||||||
|
|
||||||
|
|
||||||
Nd4jLong *buffer = new Nd4jLong[N];
|
Nd4jLong *buffer = new Nd4jLong[N];
|
||||||
NativeOps nativeOps;
|
auto rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
|
||||||
auto rng = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
|
|
||||||
if (rng == nullptr)
|
if (rng == nullptr)
|
||||||
throw std::runtime_error("DeclarableOpsTests9.exponentialDistributionInv_test2: RNG initialization failed !");
|
throw std::runtime_error("DeclarableOpsTests9.exponentialDistributionInv_test2: RNG initialization failed !");
|
||||||
|
|
||||||
|
@ -155,7 +153,7 @@ TEST_F(DeclarableOpsTests9, exponentialDistributionInv_test2) {
|
||||||
ASSERT_NEAR(mean, actualMean, 0.01);
|
ASSERT_NEAR(mean, actualMean, 0.01);
|
||||||
ASSERT_NEAR(std, actualStd, 0.01);
|
ASSERT_NEAR(std, actualStd, 0.01);
|
||||||
|
|
||||||
nativeOps.destroyRandom((Nd4jPointer) rng);
|
destroyRandom((Nd4jPointer) rng);
|
||||||
delete[] buffer;
|
delete[] buffer;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -172,8 +170,7 @@ TEST_F(DeclarableOpsTests9, exponentialDistribution_test1) {
|
||||||
double extraParams[] = {lambda};
|
double extraParams[] = {lambda};
|
||||||
|
|
||||||
Nd4jLong *buffer = new Nd4jLong[N];
|
Nd4jLong *buffer = new Nd4jLong[N];
|
||||||
NativeOps nativeOps;
|
auto rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
|
||||||
auto rng = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
|
|
||||||
if (rng == nullptr)
|
if (rng == nullptr)
|
||||||
throw std::runtime_error("DeclarableOpsTests9.exponentialDistribution_test1: RNG initialization failed !");
|
throw std::runtime_error("DeclarableOpsTests9.exponentialDistribution_test1: RNG initialization failed !");
|
||||||
|
|
||||||
|
@ -184,7 +181,7 @@ TEST_F(DeclarableOpsTests9, exponentialDistribution_test1) {
|
||||||
ASSERT_NEAR(mean, actualMean, 0.01);
|
ASSERT_NEAR(mean, actualMean, 0.01);
|
||||||
ASSERT_NEAR(std, actualStd, 0.01);
|
ASSERT_NEAR(std, actualStd, 0.01);
|
||||||
|
|
||||||
nativeOps.destroyRandom((Nd4jPointer) rng);
|
destroyRandom((Nd4jPointer) rng);
|
||||||
delete[] buffer;
|
delete[] buffer;
|
||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
|
@ -206,14 +203,13 @@ TEST_F(DeclarableOpsTests9, exponentialDistribution_test2) {
|
||||||
Nd4jLong *buffer = new Nd4jLong[N];
|
Nd4jLong *buffer = new Nd4jLong[N];
|
||||||
// Nd4jPointer extra[2];
|
// Nd4jPointer extra[2];
|
||||||
#ifndef __CUDABLAS__
|
#ifndef __CUDABLAS__
|
||||||
NativeOps nativeOps;
|
nd4j::random::RandomBuffer* rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
|
||||||
nd4j::random::RandomBuffer* rng = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
|
|
||||||
if (rng == nullptr)
|
if (rng == nullptr)
|
||||||
throw std::runtime_error("DeclarableOpsTests9.exponentialDistribution_test2: RNG initialization failed !");
|
throw std::runtime_error("DeclarableOpsTests9.exponentialDistribution_test2: RNG initialization failed !");
|
||||||
|
|
||||||
functions::random::RandomFunction<double>::template execTransform<randomOps::ExponentialDistribution<double>>(rng, y.getBuffer(), y.getShapeInfo(), x.getBuffer(), x.getShapeInfo(), extraParams);
|
functions::random::RandomFunction<double>::template execTransform<randomOps::ExponentialDistribution<double>>(rng, y.getBuffer(), y.getShapeInfo(), x.getBuffer(), x.getShapeInfo(), extraParams);
|
||||||
|
|
||||||
nativeOps.destroyRandom((Nd4jPointer) rng);
|
destroyRandom((Nd4jPointer) rng);
|
||||||
#endif
|
#endif
|
||||||
const double actualMean = x.meanNumber().e<double>(0);
|
const double actualMean = x.meanNumber().e<double>(0);
|
||||||
const double actualStd = x.varianceNumber(variance::SummaryStatsStandardDeviation, true).e<double>(0);
|
const double actualStd = x.varianceNumber(variance::SummaryStatsStandardDeviation, true).e<double>(0);
|
||||||
|
@ -1005,12 +1001,10 @@ TEST_F(DeclarableOpsTests9, Test_DropoutInverted_01) {
|
||||||
x0.linspace(1);
|
x0.linspace(1);
|
||||||
x1.linspace(1);
|
x1.linspace(1);
|
||||||
/*
|
/*
|
||||||
NativeOps nativeOps;
|
|
||||||
|
|
||||||
float prob[] = {0.5f};
|
float prob[] = {0.5f};
|
||||||
Nd4jLong* _bufferA = new Nd4jLong[100000];
|
Nd4jLong* _bufferA = new Nd4jLong[100000];
|
||||||
long _seed = 119L;
|
long _seed = 119L;
|
||||||
auto _rngA = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, _seed, 100000, (Nd4jPointer) _bufferA);
|
auto _rngA = (nd4j::random::RandomBuffer *) initRandom(nullptr, _seed, 100000, (Nd4jPointer) _bufferA);
|
||||||
|
|
||||||
x0. applyTransform(random::DropOutInverted, &x0, prob);
|
x0. applyTransform(random::DropOutInverted, &x0, prob);
|
||||||
// x1.template applyRandom<randomOps::DropOutInverted<float>>(_rngB, nullptr, &x1, prob);
|
// x1.template applyRandom<randomOps::DropOutInverted<float>>(_rngB, nullptr, &x1, prob);
|
||||||
|
@ -1026,7 +1020,7 @@ TEST_F(DeclarableOpsTests9, Test_DropoutInverted_01) {
|
||||||
// ASSERT_FALSE(x0.equalsTo(nexp0));
|
// ASSERT_FALSE(x0.equalsTo(nexp0));
|
||||||
// ASSERT_FALSE(x0.equalsTo(nexp1));
|
// ASSERT_FALSE(x0.equalsTo(nexp1));
|
||||||
// ASSERT_FALSE(x0.equalsTo(nexp2));
|
// ASSERT_FALSE(x0.equalsTo(nexp2));
|
||||||
nativeOps.destroyRandom(_rngA);
|
destroyRandom(_rngA);
|
||||||
delete [] _bufferA;
|
delete [] _bufferA;
|
||||||
*/
|
*/
|
||||||
nd4j::ops::dropout op;
|
nd4j::ops::dropout op;
|
||||||
|
@ -2911,4 +2905,4 @@ TEST_F(DeclarableOpsTests9, Cholesky_Test_3) {
|
||||||
// ASSERT_TRUE(isGradCorrect);
|
// ASSERT_TRUE(isGradCorrect);
|
||||||
// }
|
// }
|
||||||
|
|
||||||
//
|
//
|
||||||
|
|
|
@ -51,9 +51,7 @@ public:
|
||||||
*/
|
*/
|
||||||
|
|
||||||
TEST_F(GraphStateTests, Basic_Tests_1) {
|
TEST_F(GraphStateTests, Basic_Tests_1) {
|
||||||
NativeOps nativeOps;
|
auto state = (GraphState *) getGraphState(117L);
|
||||||
|
|
||||||
auto state = (GraphState *) nativeOps.getGraphState(117L);
|
|
||||||
ASSERT_EQ(117L, state->id());
|
ASSERT_EQ(117L, state->id());
|
||||||
|
|
||||||
// this call will create scope internally
|
// this call will create scope internally
|
||||||
|
@ -72,14 +70,12 @@ TEST_F(GraphStateTests, Basic_Tests_1) {
|
||||||
ASSERT_TRUE(scope != nullptr);
|
ASSERT_TRUE(scope != nullptr);
|
||||||
ASSERT_EQ(2, scope->size());
|
ASSERT_EQ(2, scope->size());
|
||||||
|
|
||||||
nativeOps.deleteGraphState(state);
|
deleteGraphState(state);
|
||||||
}
|
}
|
||||||
|
|
||||||
// just separate case for doubles wrapper in NativeOps, nothing else
|
// just separate case for doubles wrapper in NativeOps, nothing else
|
||||||
TEST_F(GraphStateTests, Basic_Tests_2) {
|
TEST_F(GraphStateTests, Basic_Tests_2) {
|
||||||
NativeOps nativeOps;
|
auto state = (GraphState *) getGraphState(117L);
|
||||||
|
|
||||||
auto state = (GraphState *) nativeOps.getGraphState(117L);
|
|
||||||
ASSERT_EQ(117L, state->id());
|
ASSERT_EQ(117L, state->id());
|
||||||
|
|
||||||
// this call will create scope internally
|
// this call will create scope internally
|
||||||
|
@ -98,46 +94,40 @@ TEST_F(GraphStateTests, Basic_Tests_2) {
|
||||||
ASSERT_TRUE(scope != nullptr);
|
ASSERT_TRUE(scope != nullptr);
|
||||||
ASSERT_EQ(2, scope->size());
|
ASSERT_EQ(2, scope->size());
|
||||||
|
|
||||||
nativeOps.deleteGraphState(state);
|
deleteGraphState(state);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(GraphStateTests, Stateful_Execution_1) {
|
TEST_F(GraphStateTests, Stateful_Execution_1) {
|
||||||
NativeOps nativeOps;
|
auto state = getGraphState(117L);
|
||||||
|
|
||||||
auto state = nativeOps.getGraphState(117L);
|
|
||||||
|
|
||||||
Nd4jLong scopes[] = {22, 33};
|
Nd4jLong scopes[] = {22, 33};
|
||||||
//auto status = nativeOps.execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0);
|
//auto status = execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0);
|
||||||
auto status = nativeOps.execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0);
|
auto status = execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0);
|
||||||
|
|
||||||
ASSERT_EQ(Status::THROW(), status);
|
ASSERT_EQ(Status::THROW(), status);
|
||||||
|
|
||||||
nativeOps.deleteGraphState(state);
|
deleteGraphState(state);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(GraphStateTests, Stateful_Execution_2) {
|
TEST_F(GraphStateTests, Stateful_Execution_2) {
|
||||||
NativeOps nativeOps;
|
auto state = (GraphState *) getGraphState(117L);
|
||||||
|
|
||||||
auto state = (GraphState *) nativeOps.getGraphState(117L);
|
|
||||||
|
|
||||||
state->registerScope(22);
|
state->registerScope(22);
|
||||||
state->registerScope(33);
|
state->registerScope(33);
|
||||||
|
|
||||||
Nd4jLong scopes[] = {22, 33};
|
Nd4jLong scopes[] = {22, 33};
|
||||||
auto status = nativeOps.execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0);
|
auto status = execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0);
|
||||||
|
|
||||||
// it's no-op: just LogicScope
|
// it's no-op: just LogicScope
|
||||||
ASSERT_EQ(Status::OK(), status);
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
nativeOps.deleteGraphState(state);
|
deleteGraphState(state);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This test checks WHILE loop
|
* This test checks WHILE loop
|
||||||
*/
|
*/
|
||||||
TEST_F(GraphStateTests, Stateful_Execution_3) {
|
TEST_F(GraphStateTests, Stateful_Execution_3) {
|
||||||
NativeOps nativeOps;
|
|
||||||
|
|
||||||
auto var0 = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
|
auto var0 = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
|
||||||
auto var1 = NDArrayFactory::create<float>(11.0f);
|
auto var1 = NDArrayFactory::create<float>(11.0f);
|
||||||
auto var2 = NDArrayFactory::create<float>(2.0f);
|
auto var2 = NDArrayFactory::create<float>(2.0f);
|
||||||
|
@ -147,7 +137,7 @@ TEST_F(GraphStateTests, Stateful_Execution_3) {
|
||||||
auto res2 = NDArrayFactory::create<float>(0.0f);
|
auto res2 = NDArrayFactory::create<float>(0.0f);
|
||||||
|
|
||||||
// registering our GraphState holder
|
// registering our GraphState holder
|
||||||
auto state = (GraphState *) nativeOps.getGraphState(117L);
|
auto state = (GraphState *) getGraphState(117L);
|
||||||
|
|
||||||
// we're prepping pointers to input/output buffers
|
// we're prepping pointers to input/output buffers
|
||||||
Nd4jPointer ptrBuffers[] = {(Nd4jPointer) var0.buffer(), (Nd4jPointer) var1.buffer(), (Nd4jPointer)var2.buffer()};
|
Nd4jPointer ptrBuffers[] = {(Nd4jPointer) var0.buffer(), (Nd4jPointer) var1.buffer(), (Nd4jPointer)var2.buffer()};
|
||||||
|
@ -197,7 +187,7 @@ TEST_F(GraphStateTests, Stateful_Execution_3) {
|
||||||
Nd4jLong scopes[] = {22, 33};
|
Nd4jLong scopes[] = {22, 33};
|
||||||
|
|
||||||
// we're executing while loop
|
// we're executing while loop
|
||||||
auto status = nativeOps.execCustomOpWithScope(nullptr, state, 0, scopes, 2, ptrBuffers, ptrShapes, 3, outBuffers, outShapes, 3);
|
auto status = execCustomOpWithScope(nullptr, state, 0, scopes, 2, ptrBuffers, ptrShapes, 3, outBuffers, outShapes, 3);
|
||||||
ASSERT_EQ(Status::OK(), status);
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
// now we check provided result array
|
// now we check provided result array
|
||||||
|
@ -211,7 +201,7 @@ TEST_F(GraphStateTests, Stateful_Execution_3) {
|
||||||
|
|
||||||
// nd4j_printf("0 ------------------\n","");
|
// nd4j_printf("0 ------------------\n","");
|
||||||
|
|
||||||
nativeOps.deleteGraphState(state);
|
deleteGraphState(state);
|
||||||
|
|
||||||
// nd4j_printf("1 ------------------\n","");
|
// nd4j_printf("1 ------------------\n","");
|
||||||
}
|
}
|
||||||
|
@ -220,8 +210,6 @@ TEST_F(GraphStateTests, Stateful_Execution_3) {
|
||||||
* This test checks CONDITIONAL execution for FALSE
|
* This test checks CONDITIONAL execution for FALSE
|
||||||
*/
|
*/
|
||||||
TEST_F(GraphStateTests, Stateful_Execution_4) {
|
TEST_F(GraphStateTests, Stateful_Execution_4) {
|
||||||
NativeOps nativeOps;
|
|
||||||
|
|
||||||
auto var0 = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
|
auto var0 = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
|
||||||
auto var1 = NDArrayFactory::create<float>(5.0f);
|
auto var1 = NDArrayFactory::create<float>(5.0f);
|
||||||
|
|
||||||
|
@ -232,7 +220,7 @@ TEST_F(GraphStateTests, Stateful_Execution_4) {
|
||||||
|
|
||||||
|
|
||||||
// registering our GraphState holder
|
// registering our GraphState holder
|
||||||
auto state = (GraphState *) nativeOps.getGraphState(117L);
|
auto state = (GraphState *) getGraphState(117L);
|
||||||
|
|
||||||
// we're prepping pointers to input/output buffers
|
// we're prepping pointers to input/output buffers
|
||||||
Nd4jPointer ptrBuffers[] = {(Nd4jPointer) var0.buffer(), (Nd4jPointer) var1.buffer()};
|
Nd4jPointer ptrBuffers[] = {(Nd4jPointer) var0.buffer(), (Nd4jPointer) var1.buffer()};
|
||||||
|
@ -283,14 +271,14 @@ TEST_F(GraphStateTests, Stateful_Execution_4) {
|
||||||
Nd4jLong scopes[] = {22, 33, 44};
|
Nd4jLong scopes[] = {22, 33, 44};
|
||||||
|
|
||||||
// we're executing conditional op
|
// we're executing conditional op
|
||||||
auto status = nativeOps.execCustomOpWithScope(nullptr, state, 20, scopes, 3, ptrBuffers, ptrShapes, 2, outBuffers, outShapes, 2);
|
auto status = execCustomOpWithScope(nullptr, state, 20, scopes, 3, ptrBuffers, ptrShapes, 2, outBuffers, outShapes, 2);
|
||||||
ASSERT_EQ(Status::OK(), status);
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(&res0));
|
ASSERT_TRUE(exp.isSameShape(&res0));
|
||||||
ASSERT_TRUE(exp.equalsTo(&res0));
|
ASSERT_TRUE(exp.equalsTo(&res0));
|
||||||
|
|
||||||
|
|
||||||
nativeOps.deleteGraphState(state);
|
deleteGraphState(state);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -298,8 +286,6 @@ TEST_F(GraphStateTests, Stateful_Execution_4) {
|
||||||
* This test checks CONDITIONAL execution for TRUE
|
* This test checks CONDITIONAL execution for TRUE
|
||||||
*/
|
*/
|
||||||
TEST_F(GraphStateTests, Stateful_Execution_5) {
|
TEST_F(GraphStateTests, Stateful_Execution_5) {
|
||||||
NativeOps nativeOps;
|
|
||||||
|
|
||||||
auto var0 = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
|
auto var0 = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
|
||||||
auto var1 = NDArrayFactory::create<float>(5.0f);
|
auto var1 = NDArrayFactory::create<float>(5.0f);
|
||||||
|
|
||||||
|
@ -310,7 +296,7 @@ TEST_F(GraphStateTests, Stateful_Execution_5) {
|
||||||
|
|
||||||
|
|
||||||
// registering our GraphState holder
|
// registering our GraphState holder
|
||||||
auto state = (GraphState *) nativeOps.getGraphState(117L);
|
auto state = (GraphState *) getGraphState(117L);
|
||||||
|
|
||||||
// we're prepping pointers to input/output buffers
|
// we're prepping pointers to input/output buffers
|
||||||
Nd4jPointer ptrBuffers[] = {(Nd4jPointer) var0.buffer(), (Nd4jPointer) var1.buffer()};
|
Nd4jPointer ptrBuffers[] = {(Nd4jPointer) var0.buffer(), (Nd4jPointer) var1.buffer()};
|
||||||
|
@ -361,12 +347,11 @@ TEST_F(GraphStateTests, Stateful_Execution_5) {
|
||||||
Nd4jLong scopes[] = {22, 33, 44};
|
Nd4jLong scopes[] = {22, 33, 44};
|
||||||
|
|
||||||
// we're executing conditional op
|
// we're executing conditional op
|
||||||
auto status = nativeOps.execCustomOpWithScope(nullptr, state, 20, scopes, 3, ptrBuffers, ptrShapes, 2, outBuffers, outShapes, 2);
|
auto status = execCustomOpWithScope(nullptr, state, 20, scopes, 3, ptrBuffers, ptrShapes, 2, outBuffers, outShapes, 2);
|
||||||
ASSERT_EQ(Status::OK(), status);
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(&res0));
|
ASSERT_TRUE(exp.isSameShape(&res0));
|
||||||
ASSERT_TRUE(exp.equalsTo(&res0));
|
ASSERT_TRUE(exp.equalsTo(&res0));
|
||||||
|
|
||||||
|
deleteGraphState(state);
|
||||||
nativeOps.deleteGraphState(state);
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -42,7 +42,6 @@ TEST_F(JavaInteropCudaTests, test_DeclarableOp_execution_1) {
|
||||||
e.assign(2.f);
|
e.assign(2.f);
|
||||||
|
|
||||||
nd4j::ops::add op;
|
nd4j::ops::add op;
|
||||||
NativeOps nativeOps;
|
|
||||||
Context context(1);
|
Context context(1);
|
||||||
|
|
||||||
context.setCudaContext(LaunchContext::defaultContext()->getCudaStream(), LaunchContext::defaultContext()->getReductionPointer(), LaunchContext::defaultContext()->getAllocationPointer());
|
context.setCudaContext(LaunchContext::defaultContext()->getCudaStream(), LaunchContext::defaultContext()->getReductionPointer(), LaunchContext::defaultContext()->getAllocationPointer());
|
||||||
|
@ -53,7 +52,7 @@ TEST_F(JavaInteropCudaTests, test_DeclarableOp_execution_1) {
|
||||||
|
|
||||||
nd4j_printf("Starting execution...\n","");
|
nd4j_printf("Starting execution...\n","");
|
||||||
PointersManager pm(LaunchContext::defaultContext(), "test_DeclarableOp_execution_1");
|
PointersManager pm(LaunchContext::defaultContext(), "test_DeclarableOp_execution_1");
|
||||||
nativeOps.execCustomOp(nullptr, op.getOpHash(), &context);
|
execCustomOp2(nullptr, op.getOpHash(), &context);
|
||||||
|
|
||||||
pm.synchronize();
|
pm.synchronize();
|
||||||
|
|
||||||
|
@ -71,7 +70,6 @@ TEST_F(JavaInteropCudaTests, test_DeclarableOp_execution_2) {
|
||||||
e.assign(false);
|
e.assign(false);
|
||||||
|
|
||||||
nd4j::ops::equals op;
|
nd4j::ops::equals op;
|
||||||
NativeOps nativeOps;
|
|
||||||
Context context(1);
|
Context context(1);
|
||||||
|
|
||||||
context.setCudaContext(LaunchContext::defaultContext()->getCudaStream(), LaunchContext::defaultContext()->getReductionPointer(), LaunchContext::defaultContext()->getAllocationPointer());
|
context.setCudaContext(LaunchContext::defaultContext()->getCudaStream(), LaunchContext::defaultContext()->getReductionPointer(), LaunchContext::defaultContext()->getAllocationPointer());
|
||||||
|
@ -82,7 +80,7 @@ TEST_F(JavaInteropCudaTests, test_DeclarableOp_execution_2) {
|
||||||
|
|
||||||
nd4j_printf("Starting execution...\n","");
|
nd4j_printf("Starting execution...\n","");
|
||||||
PointersManager pm(LaunchContext::defaultContext(), "test_DeclarableOp_execution_2");
|
PointersManager pm(LaunchContext::defaultContext(), "test_DeclarableOp_execution_2");
|
||||||
nativeOps.execCustomOp(nullptr, op.getOpHash(), &context);
|
execCustomOp2(nullptr, op.getOpHash(), &context);
|
||||||
|
|
||||||
pm.synchronize();
|
pm.synchronize();
|
||||||
|
|
||||||
|
|
|
@ -41,8 +41,6 @@ TEST_F(JavaInteropTests, TestShapeExposure1) {
|
||||||
auto weights = NDArrayFactory::create<float>('c', {2, 2, 2, 3});
|
auto weights = NDArrayFactory::create<float>('c', {2, 2, 2, 3});
|
||||||
auto exp = NDArrayFactory::create<float>('c', {1, 3, 5, 4});
|
auto exp = NDArrayFactory::create<float>('c', {1, 3, 5, 4});
|
||||||
|
|
||||||
NativeOps nativeOps;
|
|
||||||
|
|
||||||
nd4j::ops::conv2d op;
|
nd4j::ops::conv2d op;
|
||||||
|
|
||||||
std::vector<double> tArgs({});
|
std::vector<double> tArgs({});
|
||||||
|
@ -50,7 +48,7 @@ TEST_F(JavaInteropTests, TestShapeExposure1) {
|
||||||
|
|
||||||
Nd4jPointer ptrs[] = {(Nd4jPointer) input.getShapeInfo(), (Nd4jPointer) weights.getShapeInfo()};
|
Nd4jPointer ptrs[] = {(Nd4jPointer) input.getShapeInfo(), (Nd4jPointer) weights.getShapeInfo()};
|
||||||
|
|
||||||
auto shapeList = nativeOps.calculateOutputShapes(nullptr, op.getOpHash(), ptrs, 2, tArgs.data(), tArgs.size(), iArgs.data(), iArgs.size());
|
auto shapeList = calculateOutputShapes(nullptr, op.getOpHash(), ptrs, 2, tArgs.data(), tArgs.size(), iArgs.data(), iArgs.size());
|
||||||
|
|
||||||
ASSERT_EQ(1, shapeList->size());
|
ASSERT_EQ(1, shapeList->size());
|
||||||
|
|
||||||
|
@ -64,7 +62,7 @@ TEST_F(JavaInteropTests, TestShapeExposure1) {
|
||||||
//delete[] ptr;
|
//delete[] ptr;
|
||||||
//delete shapeList;
|
//delete shapeList;
|
||||||
|
|
||||||
nativeOps.deleteShapeList((Nd4jPointer) shapeList);
|
deleteShapeList((Nd4jPointer) shapeList);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -72,9 +70,6 @@ TEST_F(JavaInteropTests, TestShapeExposure2) {
|
||||||
auto input = NDArrayFactory::create<float>('c', {1, 2, 5, 4});
|
auto input = NDArrayFactory::create<float>('c', {1, 2, 5, 4});
|
||||||
auto exp = NDArrayFactory::create<float>('c', {4}, {1, 2, 5, 4});
|
auto exp = NDArrayFactory::create<float>('c', {4}, {1, 2, 5, 4});
|
||||||
|
|
||||||
|
|
||||||
NativeOps nativeOps;
|
|
||||||
|
|
||||||
nd4j::ops::shape_of op;
|
nd4j::ops::shape_of op;
|
||||||
|
|
||||||
std::vector<double> tArgs({});
|
std::vector<double> tArgs({});
|
||||||
|
@ -83,14 +78,14 @@ TEST_F(JavaInteropTests, TestShapeExposure2) {
|
||||||
|
|
||||||
Nd4jPointer ptrs[] = {(Nd4jPointer) input.getShapeInfo()};
|
Nd4jPointer ptrs[] = {(Nd4jPointer) input.getShapeInfo()};
|
||||||
|
|
||||||
auto shapeList = nativeOps.calculateOutputShapes(nullptr, op.getOpHash(), ptrs, 1, tArgs.data(), tArgs.size(), iArgs.data(), iArgs.size());
|
auto shapeList = calculateOutputShapes(nullptr, op.getOpHash(), ptrs, 1, tArgs.data(), tArgs.size(), iArgs.data(), iArgs.size());
|
||||||
|
|
||||||
ASSERT_EQ(1, shapeList->size());
|
ASSERT_EQ(1, shapeList->size());
|
||||||
|
|
||||||
ASSERT_EQ(exp.rankOf(), shape::rank((Nd4jLong *)shapeList->at(0)));
|
ASSERT_EQ(exp.rankOf(), shape::rank((Nd4jLong *)shapeList->at(0)));
|
||||||
ASSERT_EQ(exp.sizeAt(0), shape::shapeOf((Nd4jLong *)shapeList->at(0))[0]);
|
ASSERT_EQ(exp.sizeAt(0), shape::shapeOf((Nd4jLong *)shapeList->at(0))[0]);
|
||||||
|
|
||||||
nativeOps.deleteShapeList((Nd4jPointer) shapeList);
|
deleteShapeList((Nd4jPointer) shapeList);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(JavaInteropTests, TestShapeExposure3) {
|
TEST_F(JavaInteropTests, TestShapeExposure3) {
|
||||||
|
@ -112,13 +107,12 @@ TEST_F(JavaInteropTests, TestShapeExposure3) {
|
||||||
Nd4jPointer inputBuffers[] = {x.buffer(), sizes.buffer()};
|
Nd4jPointer inputBuffers[] = {x.buffer(), sizes.buffer()};
|
||||||
Nd4jPointer inputShapes[] = {x.shapeInfo(), sizes.shapeInfo()};
|
Nd4jPointer inputShapes[] = {x.shapeInfo(), sizes.shapeInfo()};
|
||||||
|
|
||||||
NativeOps nativeOps;
|
|
||||||
nd4j::ops::split_v op;
|
nd4j::ops::split_v op;
|
||||||
|
|
||||||
Nd4jLong iArgs[] = {1};
|
Nd4jLong iArgs[] = {1};
|
||||||
auto hash = op.getOpHash();
|
auto hash = op.getOpHash();
|
||||||
|
|
||||||
auto shapeList = nativeOps.calculateOutputShapes(nullptr, hash, inputBuffers, inputShapes, 2, nullptr, 0, iArgs, 1, nullptr, 0);
|
auto shapeList = calculateOutputShapes2(nullptr, hash, inputBuffers, inputShapes, 2, nullptr, 0, iArgs, 1, nullptr, 0);
|
||||||
|
|
||||||
ASSERT_EQ(3, shapeList->size());
|
ASSERT_EQ(3, shapeList->size());
|
||||||
|
|
||||||
|
@ -126,7 +120,7 @@ TEST_F(JavaInteropTests, TestShapeExposure3) {
|
||||||
ASSERT_TRUE(shape::equalsSoft(sub1.shapeInfo(), shapeList->at(1)));
|
ASSERT_TRUE(shape::equalsSoft(sub1.shapeInfo(), shapeList->at(1)));
|
||||||
ASSERT_TRUE(shape::equalsSoft(sub2.shapeInfo(), shapeList->at(2)));
|
ASSERT_TRUE(shape::equalsSoft(sub2.shapeInfo(), shapeList->at(2)));
|
||||||
|
|
||||||
nativeOps.deleteShapeList((Nd4jPointer) shapeList);
|
deleteShapeList((Nd4jPointer) shapeList);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(JavaInteropTests, Test_Squeeze_1) {
|
TEST_F(JavaInteropTests, Test_Squeeze_1) {
|
||||||
|
@ -143,10 +137,7 @@ TEST_F(JavaInteropTests, Test_Squeeze_1) {
|
||||||
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()};
|
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()};
|
||||||
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()};
|
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()};
|
||||||
|
|
||||||
|
auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
||||||
NativeOps nativeOps;
|
|
||||||
|
|
||||||
auto status = nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
|
||||||
ASSERT_EQ(Status::OK(), status);
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
ASSERT_EQ(e, z);
|
ASSERT_EQ(e, z);
|
||||||
|
@ -167,10 +158,7 @@ TEST_F(JavaInteropTests, Test_RDiv_1) {
|
||||||
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()};
|
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()};
|
||||||
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()};
|
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()};
|
||||||
|
|
||||||
|
auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
||||||
NativeOps nativeOps;
|
|
||||||
|
|
||||||
auto status = nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
|
||||||
ASSERT_EQ(Status::OK(), status);
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
ASSERT_EQ(e, z);
|
ASSERT_EQ(e, z);
|
||||||
|
@ -203,11 +191,9 @@ TEST_F(JavaInteropTests, TestSconv2d_1) {
|
||||||
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer()};
|
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer()};
|
||||||
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo()};
|
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo()};
|
||||||
|
|
||||||
NativeOps nativeOps;
|
|
||||||
|
|
||||||
Nd4jLong exp[] = {1, 1, 1, 1, 0, 0, 1, 1, 0, 0};
|
Nd4jLong exp[] = {1, 1, 1, 1, 0, 0, 1, 1, 0, 0};
|
||||||
|
|
||||||
nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 4, ptrsOutBuffers, ptrsOutShapes, 1,
|
execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 4, ptrsOutBuffers, ptrsOutShapes, 1,
|
||||||
nullptr, 0, exp, 9, nullptr, 0, false);
|
nullptr, 0, exp, 9, nullptr, 0, false);
|
||||||
|
|
||||||
//output.printBuffer("output");
|
//output.printBuffer("output");
|
||||||
|
@ -238,11 +224,9 @@ TEST_F(JavaInteropTests, TestSconv2d_2) {
|
||||||
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer()};
|
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer()};
|
||||||
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo()};
|
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo()};
|
||||||
|
|
||||||
NativeOps nativeOps;
|
|
||||||
|
|
||||||
Nd4jLong exp[] = {1, 1, 1, 1, 0, 0, 1, 1, 0};
|
Nd4jLong exp[] = {1, 1, 1, 1, 0, 0, 1, 1, 0};
|
||||||
|
|
||||||
nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 9, nullptr, 0, false);
|
execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 9, nullptr, 0, false);
|
||||||
|
|
||||||
//output.printBuffer("output");
|
//output.printBuffer("output");
|
||||||
|
|
||||||
|
@ -266,9 +250,7 @@ TEST_F(JavaInteropTests, TestMaxPooling2d_1) {
|
||||||
|
|
||||||
nd4j::ops::maxpool2d op;
|
nd4j::ops::maxpool2d op;
|
||||||
|
|
||||||
NativeOps nativeOps;
|
Nd4jStatus status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs.data(), 9, nullptr, 0, false);
|
||||||
|
|
||||||
Nd4jStatus status = nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs.data(), 9, nullptr, 0, false);
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -294,13 +276,11 @@ TEST_F(JavaInteropTests, TestCol2Im_1) {
|
||||||
|
|
||||||
nd4j::ops::col2im op;
|
nd4j::ops::col2im op;
|
||||||
|
|
||||||
NativeOps nativeOps;
|
|
||||||
|
|
||||||
Nd4jLong exp[] = {1, 1, 1, 1, 4, 5, 1, 1, 1};
|
Nd4jLong exp[] = {1, 1, 1, 1, 4, 5, 1, 1, 1};
|
||||||
|
|
||||||
auto hash = op.getOpHash();
|
auto hash = op.getOpHash();
|
||||||
|
|
||||||
nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 9, nullptr, 0, false);
|
execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 9, nullptr, 0, false);
|
||||||
|
|
||||||
ASSERT_TRUE(output.meanNumber().e<float>(0) > 0.0f);
|
ASSERT_TRUE(output.meanNumber().e<float>(0) > 0.0f);
|
||||||
}
|
}
|
||||||
|
@ -320,8 +300,6 @@ TEST_F(JavaInteropTests, TestPNorm_1) {
|
||||||
auto output = NDArrayFactory::create<float>('c', {1, 3, 3, 3});
|
auto output = NDArrayFactory::create<float>('c', {1, 3, 3, 3});
|
||||||
input.linspace(1);
|
input.linspace(1);
|
||||||
|
|
||||||
NativeOps nativeOps;
|
|
||||||
|
|
||||||
nd4j::ops::pnormpool2d op;
|
nd4j::ops::pnormpool2d op;
|
||||||
|
|
||||||
Nd4jLong exp[] = {2, 2, 1, 1, 0, 0, 1, 1, 0, 2, 0, 0};
|
Nd4jLong exp[] = {2, 2, 1, 1, 0, 0, 1, 1, 0, 2, 0, 0};
|
||||||
|
@ -332,7 +310,7 @@ TEST_F(JavaInteropTests, TestPNorm_1) {
|
||||||
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer()};
|
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer()};
|
||||||
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo()};
|
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo()};
|
||||||
|
|
||||||
nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 11, nullptr, 0, false);
|
execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 11, nullptr, 0, false);
|
||||||
|
|
||||||
ASSERT_TRUE(output.meanNumber().e<double>(0) > 0.0);
|
ASSERT_TRUE(output.meanNumber().e<double>(0) > 0.0);
|
||||||
}
|
}
|
||||||
|
@ -343,8 +321,6 @@ TEST_F(JavaInteropTests, TestInplace_1) {
|
||||||
//auto exp('c', {10, 10});
|
//auto exp('c', {10, 10});
|
||||||
input.linspace(1);
|
input.linspace(1);
|
||||||
|
|
||||||
NativeOps nativeOps;
|
|
||||||
|
|
||||||
nd4j::ops::clipbyvalue op;
|
nd4j::ops::clipbyvalue op;
|
||||||
|
|
||||||
double extras[] = {-1.0f, 1.0f};
|
double extras[] = {-1.0f, 1.0f};
|
||||||
|
@ -353,7 +329,7 @@ TEST_F(JavaInteropTests, TestInplace_1) {
|
||||||
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.getShapeInfo()};
|
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.getShapeInfo()};
|
||||||
|
|
||||||
|
|
||||||
Nd4jStatus result = nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, nullptr, nullptr, 0, extras, 2, nullptr, 0, nullptr, 0, true);
|
Nd4jStatus result = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, nullptr, nullptr, 0, extras, 2, nullptr, 0, nullptr, 0, true);
|
||||||
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, result);
|
ASSERT_EQ(ND4J_STATUS_OK, result);
|
||||||
|
|
||||||
|
@ -415,7 +391,6 @@ TEST_F(JavaInteropTests, test_avgpooling_edge_1) {
|
||||||
x.linspace(1.0);
|
x.linspace(1.0);
|
||||||
z.linspace(1.0);
|
z.linspace(1.0);
|
||||||
|
|
||||||
NativeOps nativeOps;
|
|
||||||
nd4j::ops::avgpool2d op;
|
nd4j::ops::avgpool2d op;
|
||||||
//auto result = op.execute({&x}, {}, {3,3, 1,1, 0,0, 1,1, 1, 0, 1});
|
//auto result = op.execute({&x}, {}, {3,3, 1,1, 0,0, 1,1, 1, 0, 1});
|
||||||
|
|
||||||
|
@ -427,7 +402,7 @@ TEST_F(JavaInteropTests, test_avgpooling_edge_1) {
|
||||||
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()};
|
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()};
|
||||||
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()};
|
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()};
|
||||||
|
|
||||||
auto result = nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 11, nullptr, 0, false);
|
auto result = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 11, nullptr, 0, false);
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), result);
|
ASSERT_EQ(Status::OK(), result);
|
||||||
|
|
||||||
|
@ -496,15 +471,13 @@ TEST_F(JavaInteropTests, test_avgpooling_edge_1) {
|
||||||
|
|
||||||
/*
|
/*
|
||||||
TEST_F(JavaInteropTests, Test_GraphReuse_1) {
|
TEST_F(JavaInteropTests, Test_GraphReuse_1) {
|
||||||
NativeOps nativeOps;
|
|
||||||
|
|
||||||
uint8_t* data = nd4j::graph::readFlatBuffers("./resources/reduce_dim_false.fb");
|
uint8_t* data = nd4j::graph::readFlatBuffers("./resources/reduce_dim_false.fb");
|
||||||
|
|
||||||
nativeOps.registerGraph(nullptr, 119, (Nd4jPointer) data);
|
registerGraph(nullptr, 119, (Nd4jPointer) data);
|
||||||
|
|
||||||
ASSERT_TRUE(GraphHolder::getInstance()->hasGraph(119));
|
ASSERT_TRUE(GraphHolder::getInstance()->hasGraph(119));
|
||||||
|
|
||||||
nativeOps.unregisterGraph(nullptr, 119);
|
unregisterGraph(nullptr, 119);
|
||||||
|
|
||||||
ASSERT_FALSE(GraphHolder::getInstance()->hasGraph(119));
|
ASSERT_FALSE(GraphHolder::getInstance()->hasGraph(119));
|
||||||
|
|
||||||
|
@ -520,8 +493,6 @@ TEST_F(JavaInteropTests, Test_GraphReuse_2) {
|
||||||
auto exp1 = NDArrayFactory::create<float>('c', {3}, {6, 6, 6});
|
auto exp1 = NDArrayFactory::create<float>('c', {3}, {6, 6, 6});
|
||||||
auto exp2 = NDArrayFactory::create<float>('c', {3}, {9, 9, 9});
|
auto exp2 = NDArrayFactory::create<float>('c', {3}, {9, 9, 9});
|
||||||
|
|
||||||
NativeOps nativeOps;
|
|
||||||
|
|
||||||
// we load graph from file, because we're not in java here, and dont have buffer ready
|
// we load graph from file, because we're not in java here, and dont have buffer ready
|
||||||
uint8_t* data = nd4j::graph::readFlatBuffers("./resources/reduce_dim_false.fb");
|
uint8_t* data = nd4j::graph::readFlatBuffers("./resources/reduce_dim_false.fb");
|
||||||
|
|
||||||
|
@ -529,7 +500,7 @@ TEST_F(JavaInteropTests, Test_GraphReuse_2) {
|
||||||
ASSERT_FALSE(GraphHolder::getInstance()->hasGraph(119));
|
ASSERT_FALSE(GraphHolder::getInstance()->hasGraph(119));
|
||||||
|
|
||||||
// register the graph, to call for it later
|
// register the graph, to call for it later
|
||||||
nativeOps.registerGraph(nullptr, 119, (Nd4jPointer) data);
|
registerGraph(nullptr, 119, (Nd4jPointer) data);
|
||||||
|
|
||||||
// and ensure we're ok
|
// and ensure we're ok
|
||||||
ASSERT_TRUE(GraphHolder::getInstance()->hasGraph(119));
|
ASSERT_TRUE(GraphHolder::getInstance()->hasGraph(119));
|
||||||
|
@ -547,7 +518,7 @@ TEST_F(JavaInteropTests, Test_GraphReuse_2) {
|
||||||
Nd4jPointer shapes_0[] = {(Nd4jPointer) input_0.shapeInfo()};
|
Nd4jPointer shapes_0[] = {(Nd4jPointer) input_0.shapeInfo()};
|
||||||
|
|
||||||
// now we're executing stored graph and providing replacement for input variable
|
// now we're executing stored graph and providing replacement for input variable
|
||||||
auto res_0 = nativeOps.executeStoredGraph(nullptr, 119, inputs_0, shapes_0, idx, 1);
|
auto res_0 = executeStoredGraph(nullptr, 119, inputs_0, shapes_0, idx, 1);
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, res_0->status());
|
ASSERT_EQ(ND4J_STATUS_OK, res_0->status());
|
||||||
ASSERT_EQ(1, res_0->size());
|
ASSERT_EQ(1, res_0->size());
|
||||||
|
|
||||||
|
@ -562,7 +533,7 @@ TEST_F(JavaInteropTests, Test_GraphReuse_2) {
|
||||||
Nd4jPointer shapes_1[] = {(Nd4jPointer) input_1.shapeInfo()};
|
Nd4jPointer shapes_1[] = {(Nd4jPointer) input_1.shapeInfo()};
|
||||||
|
|
||||||
// doing it again
|
// doing it again
|
||||||
auto res_1 = nativeOps.executeStoredGraph(nullptr, 119, inputs_1, shapes_1, idx, 1);
|
auto res_1 = executeStoredGraph(nullptr, 119, inputs_1, shapes_1, idx, 1);
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, res_1->status());
|
ASSERT_EQ(ND4J_STATUS_OK, res_1->status());
|
||||||
ASSERT_EQ(1, res_1->size());
|
ASSERT_EQ(1, res_1->size());
|
||||||
|
|
||||||
|
@ -577,7 +548,7 @@ TEST_F(JavaInteropTests, Test_GraphReuse_2) {
|
||||||
Nd4jPointer shapes_2[] = {(Nd4jPointer) input_2.shapeInfo()};
|
Nd4jPointer shapes_2[] = {(Nd4jPointer) input_2.shapeInfo()};
|
||||||
|
|
||||||
// and again
|
// and again
|
||||||
auto res_2 = nativeOps.executeStoredGraph(nullptr, 119, inputs_2, shapes_2, idx, 1);
|
auto res_2 = executeStoredGraph(nullptr, 119, inputs_2, shapes_2, idx, 1);
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, res_1->status());
|
ASSERT_EQ(ND4J_STATUS_OK, res_1->status());
|
||||||
ASSERT_EQ(1, res_2->size());
|
ASSERT_EQ(1, res_2->size());
|
||||||
|
|
||||||
|
@ -586,7 +557,7 @@ TEST_F(JavaInteropTests, Test_GraphReuse_2) {
|
||||||
|
|
||||||
|
|
||||||
//////// clean out
|
//////// clean out
|
||||||
nativeOps.unregisterGraph(nullptr, 119);
|
unregisterGraph(nullptr, 119);
|
||||||
|
|
||||||
ASSERT_FALSE(GraphHolder::getInstance()->hasGraph(119));
|
ASSERT_FALSE(GraphHolder::getInstance()->hasGraph(119));
|
||||||
|
|
||||||
|
@ -616,9 +587,7 @@ TEST_F(JavaInteropTests, Test_Greater_1) {
|
||||||
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.getBuffer()};
|
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.getBuffer()};
|
||||||
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.getShapeInfo()};
|
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.getShapeInfo()};
|
||||||
|
|
||||||
NativeOps nativeOps;
|
execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
||||||
|
|
||||||
nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
|
||||||
o.printIndexedBuffer("Greater JIT");
|
o.printIndexedBuffer("Greater JIT");
|
||||||
ASSERT_TRUE(exp.equalsTo(&o));
|
ASSERT_TRUE(exp.equalsTo(&o));
|
||||||
}
|
}
|
||||||
|
@ -641,9 +610,7 @@ TEST_F(JavaInteropTests, Test_Greater_2) {
|
||||||
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.getBuffer()};
|
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.getBuffer()};
|
||||||
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.getShapeInfo()};
|
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.getShapeInfo()};
|
||||||
|
|
||||||
NativeOps nativeOps;
|
execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
||||||
|
|
||||||
nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
|
||||||
|
|
||||||
ASSERT_TRUE(exp.equalsTo(&o));
|
ASSERT_TRUE(exp.equalsTo(&o));
|
||||||
}
|
}
|
||||||
|
@ -662,9 +629,8 @@ TEST_F(JavaInteropTests, Test_Boolean_Op_1) {
|
||||||
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.getBuffer()};
|
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.getBuffer()};
|
||||||
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.getShapeInfo()};
|
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.getShapeInfo()};
|
||||||
|
|
||||||
NativeOps nativeOps;
|
|
||||||
auto hash = op.getOpHash();
|
auto hash = op.getOpHash();
|
||||||
auto status = nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
||||||
ASSERT_EQ(Status::OK(), status);
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.equalsTo(&o));
|
ASSERT_TRUE(exp.equalsTo(&o));
|
||||||
|
@ -685,9 +651,8 @@ TEST_F(JavaInteropTests, Test_Inplace_Outputs_1) {
|
||||||
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()};
|
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()};
|
||||||
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()};
|
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()};
|
||||||
|
|
||||||
NativeOps nativeOps;
|
|
||||||
auto hash = op.getOpHash();
|
auto hash = op.getOpHash();
|
||||||
auto status = nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
||||||
ASSERT_EQ(Status::OK(), status);
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
ASSERT_TRUE(exp.isSameShape(z));
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
@ -710,9 +675,8 @@ TEST_F(JavaInteropTests, Test_Inplace_Outputs_2) {
|
||||||
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()};
|
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()};
|
||||||
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()};
|
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()};
|
||||||
|
|
||||||
NativeOps nativeOps;
|
|
||||||
auto hash = op.getOpHash();
|
auto hash = op.getOpHash();
|
||||||
auto status = nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
||||||
ASSERT_EQ(Status::OK(), status);
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
ASSERT_TRUE(e.isSameShape(z));
|
ASSERT_TRUE(e.isSameShape(z));
|
||||||
|
@ -736,9 +700,8 @@ TEST_F(JavaInteropTests, Test_Inplace_Outputs_3) {
|
||||||
|
|
||||||
Nd4jLong iArgs[] = {1};
|
Nd4jLong iArgs[] = {1};
|
||||||
|
|
||||||
NativeOps nativeOps;
|
|
||||||
auto hash = op.getOpHash();
|
auto hash = op.getOpHash();
|
||||||
auto status = nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 1, nullptr, 0, false);
|
auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 1, nullptr, 0, false);
|
||||||
ASSERT_EQ(Status::OK(), status);
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
ASSERT_TRUE(e.isSameShape(output));
|
ASSERT_TRUE(e.isSameShape(output));
|
||||||
|
@ -753,8 +716,7 @@ TEST_F(JavaInteropTests, Test_Reduce3_EdgeCase) {
|
||||||
|
|
||||||
auto dims = NDArrayFactory::create<int>('c', {2}, {0, 1});
|
auto dims = NDArrayFactory::create<int>('c', {2}, {0, 1});
|
||||||
|
|
||||||
NativeOps nativeOps;
|
execReduce3Tad(nullptr, 2, x.buffer(), x.shapeInfo(), nullptr, nullptr, nullptr,
|
||||||
nativeOps.execReduce3(nullptr, 2, x.buffer(), x.shapeInfo(), nullptr, nullptr, nullptr,
|
|
||||||
y.buffer(), y.shapeInfo(), nullptr, nullptr,
|
y.buffer(), y.shapeInfo(), nullptr, nullptr,
|
||||||
z.buffer(), z.shapeInfo(), nullptr, nullptr,
|
z.buffer(), z.shapeInfo(), nullptr, nullptr,
|
||||||
dims.buffer(), dims.shapeInfo(), dims.specialBuffer(), dims.specialShapeInfo(), nullptr, nullptr, nullptr, nullptr);
|
dims.buffer(), dims.shapeInfo(), dims.specialBuffer(), dims.specialShapeInfo(), nullptr, nullptr, nullptr, nullptr);
|
||||||
|
@ -764,10 +726,8 @@ TEST_F(JavaInteropTests, Test_SimpleIf_Output) {
|
||||||
Environment::getInstance()->setDebug(true);
|
Environment::getInstance()->setDebug(true);
|
||||||
Environment::getInstance()->setVerbose(false);
|
Environment::getInstance()->setVerbose(false);
|
||||||
|
|
||||||
NativeOps ops;
|
|
||||||
|
|
||||||
auto pl = nd4j::graph::readFlatBuffers("./resources/simpleif_0_1.fb");
|
auto pl = nd4j::graph::readFlatBuffers("./resources/simpleif_0_1.fb");
|
||||||
auto ptr = ops.executeFlatGraph(nullptr, pl);
|
auto ptr = executeFlatGraph(nullptr, pl);
|
||||||
|
|
||||||
Environment::getInstance()->setDebug(false);
|
Environment::getInstance()->setDebug(false);
|
||||||
Environment::getInstance()->setVerbose(false);
|
Environment::getInstance()->setVerbose(false);
|
||||||
|
@ -792,9 +752,8 @@ TEST_F(JavaInteropTests, Test_AveragePooling_FF_TF_double) {
|
||||||
Nd4jPointer ptrsOutShapes[] = {reinterpret_cast<Nd4jPointer>(z.shapeInfo())};
|
Nd4jPointer ptrsOutShapes[] = {reinterpret_cast<Nd4jPointer>(z.shapeInfo())};
|
||||||
Nd4jLong iArgs[] = {3,3, 3,3, 0,0, 1,1,1, 0,1};
|
Nd4jLong iArgs[] = {3,3, 3,3, 0,0, 1,1,1, 0,1};
|
||||||
|
|
||||||
NativeOps nativeOps;
|
|
||||||
auto hash = op.getOpHash();
|
auto hash = op.getOpHash();
|
||||||
auto status = nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false);
|
auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false);
|
||||||
ASSERT_EQ(Status::OK(), status);
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
|
||||||
|
@ -818,9 +777,8 @@ TEST_F(JavaInteropTests, Test_MaxPool2D_float_1) {
|
||||||
|
|
||||||
nd4j::ops::maxpool2d op;
|
nd4j::ops::maxpool2d op;
|
||||||
|
|
||||||
NativeOps nativeOps;
|
|
||||||
auto hash = op.getOpHash();
|
auto hash = op.getOpHash();
|
||||||
auto status = nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false);
|
auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false);
|
||||||
ASSERT_EQ(Status::OK(), status);
|
ASSERT_EQ(Status::OK(), status);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -843,9 +801,8 @@ TEST_F(JavaInteropTests, Test_Unstack_1) {
|
||||||
|
|
||||||
nd4j::ops::unstack op;
|
nd4j::ops::unstack op;
|
||||||
|
|
||||||
NativeOps nativeOps;
|
|
||||||
auto hash = op.getOpHash();
|
auto hash = op.getOpHash();
|
||||||
auto status = nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 5, nullptr, 0, iArgs, 1, nullptr, 0, false);
|
auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 5, nullptr, 0, iArgs, 1, nullptr, 0, false);
|
||||||
ASSERT_EQ(Status::OK(), status);
|
ASSERT_EQ(Status::OK(), status);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -864,9 +821,8 @@ TEST_F(JavaInteropTests, Test_AveragePooling_FF_TF_float) {
|
||||||
Nd4jPointer ptrsOutShapes[] = {reinterpret_cast<Nd4jPointer>(z.shapeInfo())};
|
Nd4jPointer ptrsOutShapes[] = {reinterpret_cast<Nd4jPointer>(z.shapeInfo())};
|
||||||
Nd4jLong iArgs[] = {3,3, 3,3, 0,0, 1,1,1, 0,1};
|
Nd4jLong iArgs[] = {3,3, 3,3, 0,0, 1,1,1, 0,1};
|
||||||
|
|
||||||
NativeOps nativeOps;
|
|
||||||
auto hash = op.getOpHash();
|
auto hash = op.getOpHash();
|
||||||
auto status = nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false);
|
auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false);
|
||||||
ASSERT_EQ(Status::OK(), status);
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
|
||||||
|
@ -883,8 +839,7 @@ TEST_F(JavaInteropTests, Test_Mixed_Add_1) {
|
||||||
auto arrayZ = NDArrayFactory::create<double>({0, 0, 0, 0});
|
auto arrayZ = NDArrayFactory::create<double>({0, 0, 0, 0});
|
||||||
auto arrayE = NDArrayFactory::create<double>({2, 4, 6, 8});
|
auto arrayE = NDArrayFactory::create<double>({2, 4, 6, 8});
|
||||||
|
|
||||||
NativeOps ops;
|
execPairwiseTransform(nullptr, pairwise::Add,
|
||||||
ops.execPairwiseTransform(nullptr, pairwise::Add,
|
|
||||||
arrayX.buffer(), arrayX.shapeInfo(), nullptr, nullptr,
|
arrayX.buffer(), arrayX.shapeInfo(), nullptr, nullptr,
|
||||||
arrayY.buffer(), arrayY.shapeInfo(), nullptr, nullptr,
|
arrayY.buffer(), arrayY.shapeInfo(), nullptr, nullptr,
|
||||||
arrayZ.buffer(), arrayZ.shapeInfo(), nullptr, nullptr,
|
arrayZ.buffer(), arrayZ.shapeInfo(), nullptr, nullptr,
|
||||||
|
@ -898,7 +853,6 @@ TEST_F(JavaInteropTests, Test_Add_1) {
|
||||||
auto y = NDArrayFactory::create<int>('c', {5}, {1, 1, 1, 1, 1});
|
auto y = NDArrayFactory::create<int>('c', {5}, {1, 1, 1, 1, 1});
|
||||||
auto e = NDArrayFactory::create<int>('c', {5}, {2, 2, 2, 2, 2});
|
auto e = NDArrayFactory::create<int>('c', {5}, {2, 2, 2, 2, 2});
|
||||||
|
|
||||||
NativeOps nativeOps;
|
|
||||||
nd4j::ops::add op;
|
nd4j::ops::add op;
|
||||||
|
|
||||||
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), y.getBuffer()};
|
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), y.getBuffer()};
|
||||||
|
@ -907,7 +861,7 @@ TEST_F(JavaInteropTests, Test_Add_1) {
|
||||||
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) x.getBuffer()};
|
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) x.getBuffer()};
|
||||||
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) x.getShapeInfo()};
|
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) x.getShapeInfo()};
|
||||||
|
|
||||||
nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
||||||
|
|
||||||
ASSERT_EQ(e, x);
|
ASSERT_EQ(e, x);
|
||||||
}
|
}
|
||||||
|
@ -920,7 +874,6 @@ TEST_F(JavaInteropTests, zeta_test10) {
|
||||||
|
|
||||||
auto e = NDArrayFactory::create<double>('c', {3, 4}, {23.014574, 12.184081, 8.275731, 6.1532226, 4.776538, 3.7945523, 3.0541048, 2.4765317, 2.0163891, 205.27448, 21.090889, 19.477398});
|
auto e = NDArrayFactory::create<double>('c', {3, 4}, {23.014574, 12.184081, 8.275731, 6.1532226, 4.776538, 3.7945523, 3.0541048, 2.4765317, 2.0163891, 205.27448, 21.090889, 19.477398});
|
||||||
|
|
||||||
NativeOps nativeOps;
|
|
||||||
nd4j::ops::zeta op;
|
nd4j::ops::zeta op;
|
||||||
|
|
||||||
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), q.getBuffer()};
|
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), q.getBuffer()};
|
||||||
|
@ -929,7 +882,7 @@ TEST_F(JavaInteropTests, zeta_test10) {
|
||||||
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()};
|
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()};
|
||||||
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()};
|
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()};
|
||||||
|
|
||||||
nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
||||||
|
|
||||||
ASSERT_EQ(e, z);
|
ASSERT_EQ(e, z);
|
||||||
}
|
}
|
||||||
|
@ -939,8 +892,7 @@ TEST_F(JavaInteropTests, Test_Is_Max_1) {
|
||||||
auto arrayZ = NDArrayFactory::create<bool>({0, 0, 0, 0});
|
auto arrayZ = NDArrayFactory::create<bool>({0, 0, 0, 0});
|
||||||
auto arrayE = NDArrayFactory::create<bool>({0, 1, 0, 0});
|
auto arrayE = NDArrayFactory::create<bool>({0, 1, 0, 0});
|
||||||
|
|
||||||
NativeOps ops;
|
execTransformAny(nullptr, transform::IsMax,
|
||||||
ops.execTransformAny(nullptr, transform::IsMax,
|
|
||||||
arrayX.buffer(), arrayX.shapeInfo(), nullptr, nullptr,
|
arrayX.buffer(), arrayX.shapeInfo(), nullptr, nullptr,
|
||||||
arrayZ.buffer(), arrayZ.shapeInfo(), nullptr, nullptr,
|
arrayZ.buffer(), arrayZ.shapeInfo(), nullptr, nullptr,
|
||||||
nullptr);
|
nullptr);
|
||||||
|
@ -953,8 +905,7 @@ TEST_F(JavaInteropTests, Test_Is_Max_1_2) {
|
||||||
auto arrayZ = NDArrayFactory::create<float>({0, 0, 0, 0});
|
auto arrayZ = NDArrayFactory::create<float>({0, 0, 0, 0});
|
||||||
auto arrayE = NDArrayFactory::create<float>({0, 1, 0, 0});
|
auto arrayE = NDArrayFactory::create<float>({0, 1, 0, 0});
|
||||||
|
|
||||||
NativeOps ops;
|
execTransformAny(nullptr, transform::IsMax,
|
||||||
ops.execTransformAny(nullptr, transform::IsMax,
|
|
||||||
arrayX.buffer(), arrayX.shapeInfo(), nullptr, nullptr,
|
arrayX.buffer(), arrayX.shapeInfo(), nullptr, nullptr,
|
||||||
arrayZ.buffer(), arrayZ.shapeInfo(), nullptr, nullptr,
|
arrayZ.buffer(), arrayZ.shapeInfo(), nullptr, nullptr,
|
||||||
nullptr);
|
nullptr);
|
||||||
|
@ -970,8 +921,7 @@ TEST_F(JavaInteropTests, Test_Is_Max_2) {
|
||||||
Nd4jLong *ex[] = {tad, off};
|
Nd4jLong *ex[] = {tad, off};
|
||||||
float ea[] = {2, 1, 2};
|
float ea[] = {2, 1, 2};
|
||||||
|
|
||||||
NativeOps ops;
|
execTransformBool(reinterpret_cast<void **>(ex), transform::IsMax,
|
||||||
ops.execTransformBool(reinterpret_cast<void **>(ex), transform::IsMax,
|
|
||||||
arrayX.buffer(), arrayX.shapeInfo(), nullptr, nullptr,
|
arrayX.buffer(), arrayX.shapeInfo(), nullptr, nullptr,
|
||||||
arrayZ.buffer(), arrayZ.shapeInfo(), nullptr, nullptr,
|
arrayZ.buffer(), arrayZ.shapeInfo(), nullptr, nullptr,
|
||||||
ea);
|
ea);
|
||||||
|
@ -995,8 +945,7 @@ TEST_F(JavaInteropTests, Test_Boolean_Broadcastables_1) {
|
||||||
|
|
||||||
|
|
||||||
nd4j::ops::greater_equal op;
|
nd4j::ops::greater_equal op;
|
||||||
NativeOps ops;
|
auto shapeList = calculateOutputShapes2(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, nullptr, 0, nullptr, 0, nullptr, 0);
|
||||||
auto shapeList = ops.calculateOutputShapes(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, nullptr, 0, nullptr, 0, nullptr, 0);
|
|
||||||
|
|
||||||
delete shapeList;
|
delete shapeList;
|
||||||
}
|
}
|
||||||
|
@ -1013,8 +962,7 @@ TEST_F(JavaInteropTests, Test_L2_Loss_3) {
|
||||||
Nd4jPointer ptrsOutShapes[] = {reinterpret_cast<Nd4jPointer>(z.shapeInfo())};
|
Nd4jPointer ptrsOutShapes[] = {reinterpret_cast<Nd4jPointer>(z.shapeInfo())};
|
||||||
|
|
||||||
nd4j::ops::l2_loss op;
|
nd4j::ops::l2_loss op;
|
||||||
NativeOps ops;
|
auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffer, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
||||||
auto status = ops.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffer, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
|
||||||
ASSERT_EQ(Status::OK(), status);
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
z.printIndexedBuffer("z");
|
z.printIndexedBuffer("z");
|
||||||
|
@ -1036,9 +984,8 @@ TEST_F(JavaInteropTests, Test_Fastpath_3) {
|
||||||
|
|
||||||
ASSERT_EQ(2, ctx.width());
|
ASSERT_EQ(2, ctx.width());
|
||||||
|
|
||||||
NativeOps nativeOps;
|
|
||||||
nd4j::ops::add op;
|
nd4j::ops::add op;
|
||||||
nativeOps.execCustomOp(nullptr, op.getOpHash(), &ctx);
|
execCustomOp2(nullptr, op.getOpHash(), &ctx);
|
||||||
|
|
||||||
ASSERT_EQ(exp, z);
|
ASSERT_EQ(exp, z);
|
||||||
}
|
}
|
||||||
|
@ -1054,9 +1001,8 @@ TEST_F(JavaInteropTests, Test_Fastpath_4) {
|
||||||
ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());
|
ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());
|
||||||
ctx.setIArguments(iArgs, 3);
|
ctx.setIArguments(iArgs, 3);
|
||||||
|
|
||||||
NativeOps nativeOps;
|
|
||||||
nd4j::ops::tri op;
|
nd4j::ops::tri op;
|
||||||
nativeOps.execCustomOp(nullptr, op.getOpHash(), &ctx);
|
execCustomOp2(nullptr, op.getOpHash(), &ctx);
|
||||||
|
|
||||||
ASSERT_EQ(exp, z);
|
ASSERT_EQ(exp, z);
|
||||||
}
|
}
|
||||||
|
@ -1074,9 +1020,8 @@ TEST_F(JavaInteropTests, Test_Fastpath_5) {
|
||||||
ctx.setInputArray(1, b.buffer(), b.shapeInfo(), b.specialBuffer(), b.specialShapeInfo());
|
ctx.setInputArray(1, b.buffer(), b.shapeInfo(), b.specialBuffer(), b.specialShapeInfo());
|
||||||
ctx.setOutputArray(0, c.buffer(), c.shapeInfo(), c.specialBuffer(), c.specialShapeInfo());
|
ctx.setOutputArray(0, c.buffer(), c.shapeInfo(), c.specialBuffer(), c.specialShapeInfo());
|
||||||
|
|
||||||
NativeOps nativeOps;
|
|
||||||
nd4j::ops::matmul op;
|
nd4j::ops::matmul op;
|
||||||
auto status = nativeOps.execCustomOp(nullptr, op.getOpHash(), &ctx);
|
auto status = execCustomOp2(nullptr, op.getOpHash(), &ctx);
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), status);
|
ASSERT_EQ(Status::OK(), status);
|
||||||
}
|
}
|
||||||
|
@ -1104,9 +1049,8 @@ TEST_F(JavaInteropTests, Test_Fastpath_6) {
|
||||||
|
|
||||||
ctx.setIArguments(iArgs, 3);
|
ctx.setIArguments(iArgs, 3);
|
||||||
|
|
||||||
NativeOps nativeOps;
|
|
||||||
nd4j::ops::matmul_bp op;
|
nd4j::ops::matmul_bp op;
|
||||||
auto status = nativeOps.execCustomOp(nullptr, op.getOpHash(), &ctx);
|
auto status = execCustomOp2(nullptr, op.getOpHash(), &ctx);
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), status);
|
ASSERT_EQ(Status::OK(), status);
|
||||||
}
|
}
|
||||||
|
@ -1122,7 +1066,6 @@ TEST_F(JavaInteropTests, Test_Fastpath_7) {
|
||||||
|
|
||||||
ctx.setIArguments(iArgs, 1);
|
ctx.setIArguments(iArgs, 1);
|
||||||
|
|
||||||
NativeOps nativeOps;
|
|
||||||
nd4j::ops::concat op;
|
nd4j::ops::concat op;
|
||||||
|
|
||||||
ctx.setInputArray(0, a.buffer(), a.shapeInfo(), a.specialBuffer(), a.specialShapeInfo());
|
ctx.setInputArray(0, a.buffer(), a.shapeInfo(), a.specialBuffer(), a.specialShapeInfo());
|
||||||
|
@ -1130,7 +1073,7 @@ TEST_F(JavaInteropTests, Test_Fastpath_7) {
|
||||||
|
|
||||||
ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());
|
ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());
|
||||||
|
|
||||||
auto status = nativeOps.execCustomOp(nullptr, op.getOpHash(), &ctx);
|
auto status = execCustomOp2(nullptr, op.getOpHash(), &ctx);
|
||||||
ASSERT_EQ(Status::OK(), status);
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
ASSERT_EQ(e, z);
|
ASSERT_EQ(e, z);
|
||||||
|
@ -1138,10 +1081,8 @@ TEST_F(JavaInteropTests, Test_Fastpath_7) {
|
||||||
|
|
||||||
/*
|
/*
|
||||||
TEST_F(JavaInteropTests, Test_Results_Conversion_1) {
|
TEST_F(JavaInteropTests, Test_Results_Conversion_1) {
|
||||||
NativeOps ops;
|
|
||||||
|
|
||||||
auto pl = nd4j::graph::readFlatBuffers("./resources/gru_dynamic_mnist.fb");
|
auto pl = nd4j::graph::readFlatBuffers("./resources/gru_dynamic_mnist.fb");
|
||||||
auto ptr = ops.executeFlatGraph(nullptr, pl);
|
auto ptr = executeFlatGraph(nullptr, pl);
|
||||||
|
|
||||||
// at this point we have FlatResults
|
// at this point we have FlatResults
|
||||||
auto flatResult = GetFlatResult(ptr->pointer());
|
auto flatResult = GetFlatResult(ptr->pointer());
|
||||||
|
@ -1190,8 +1131,6 @@ TEST_F(JavaInteropTests, Test_Results_Conversion_1) {
|
||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
// TEST_F(JavaInteropTests, Test_NLP_Aggregations_1) {
|
// TEST_F(JavaInteropTests, Test_NLP_Aggregations_1) {
|
||||||
// NativeOps ops;
|
|
||||||
|
|
||||||
// std::array<float, 60> syn0 = {-0.022756476f, 0.0126427775f, 0.011029151f, -0.013542821f, -0.012327666f, -0.0032439455f, -0.008405109f, -0.016651405f, 0.0015980572f, -0.007442479f, 0.019937921f, -0.016222188f, -0.016541665f, 0.013372547f, 0.006625724f, 0.0058958204f, -0.01281835f, -6.2343775E-4f, 0.0019826533f, 0.010253737f, -0.010291531f, 0.0019767822f, 0.018071089f, -0.0117441565f, 0.023176769f, 0.0032820583f, 0.0061427564f, -0.01696018f, 0.0054971874f, 0.0043818625f, 0.019323621f, 0.0036080598f, 0.024376748f, -0.0024499625f, 0.019496754f, 0.010563821f, -2.0503551E-4f, -0.0146056535f, 0.009949291f, 0.017604528f, -0.0050302492f, -0.022060446f, 0.016468976f, -0.0034482107f, 0.010270384f, -0.0063356445f, -0.019934833f, -0.02325993f, 0.016109904f, -0.0031106502f, -0.0020592287f, 0.024031803f, 0.005184144f, -0.024887865f, 0.02100272f, 3.395051E-4f, 0.018432347f, 5.673498E-4f, -0.020073576f, 0.010949242f};
|
// std::array<float, 60> syn0 = {-0.022756476f, 0.0126427775f, 0.011029151f, -0.013542821f, -0.012327666f, -0.0032439455f, -0.008405109f, -0.016651405f, 0.0015980572f, -0.007442479f, 0.019937921f, -0.016222188f, -0.016541665f, 0.013372547f, 0.006625724f, 0.0058958204f, -0.01281835f, -6.2343775E-4f, 0.0019826533f, 0.010253737f, -0.010291531f, 0.0019767822f, 0.018071089f, -0.0117441565f, 0.023176769f, 0.0032820583f, 0.0061427564f, -0.01696018f, 0.0054971874f, 0.0043818625f, 0.019323621f, 0.0036080598f, 0.024376748f, -0.0024499625f, 0.019496754f, 0.010563821f, -2.0503551E-4f, -0.0146056535f, 0.009949291f, 0.017604528f, -0.0050302492f, -0.022060446f, 0.016468976f, -0.0034482107f, 0.010270384f, -0.0063356445f, -0.019934833f, -0.02325993f, 0.016109904f, -0.0031106502f, -0.0020592287f, 0.024031803f, 0.005184144f, -0.024887865f, 0.02100272f, 3.395051E-4f, 0.018432347f, 5.673498E-4f, -0.020073576f, 0.010949242f};
|
||||||
// std::array<float, 60> syn1;
|
// std::array<float, 60> syn1;
|
||||||
// std::array<float, 100000> exp;
|
// std::array<float, 100000> exp;
|
||||||
|
@ -1283,5 +1222,5 @@ TEST_F(JavaInteropTests, Test_Results_Conversion_1) {
|
||||||
// ptrptr[idx+2] = reinterpret_cast<void*>(exp.data());
|
// ptrptr[idx+2] = reinterpret_cast<void*>(exp.data());
|
||||||
|
|
||||||
|
|
||||||
// ops.execAggregateBatchFloat(nullptr, numAggregates, opNum, maxArgs, maxShapes, maxIntArrays, maxIntArraySize, maxIndexArguments, maxRealArguments, pointer.data());
|
// execAggregateBatchFloat(nullptr, numAggregates, opNum, maxArgs, maxShapes, maxIntArrays, maxIntArraySize, maxIndexArguments, maxRealArguments, pointer.data());
|
||||||
// }
|
// }
|
||||||
|
|
|
@ -53,9 +53,8 @@ TEST_F(LegacyOpsCudaTests, test_sortTad_1) {
|
||||||
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
|
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
|
||||||
|
|
||||||
x.syncToDevice();
|
x.syncToDevice();
|
||||||
NativeOps nativeOps;
|
sortTad(extras, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), &axis, 1, packX.platformShapeInfo(), packX.platformOffsets(), false);
|
||||||
nativeOps.sortTad(extras, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), &axis, 1, packX.platformShapeInfo(), packX.platformOffsets(), false);
|
|
||||||
x.tickWriteDevice();
|
x.tickWriteDevice();
|
||||||
|
|
||||||
ASSERT_EQ(e, x);
|
ASSERT_EQ(e, x);
|
||||||
}
|
}
|
||||||
|
|
|
@ -501,8 +501,7 @@ TEST_F(LegacyOpsTests, Reduce3_2) {
|
||||||
|
|
||||||
auto dim = NDArrayFactory::create<int>('c', {1}, {1});
|
auto dim = NDArrayFactory::create<int>('c', {1}, {1});
|
||||||
|
|
||||||
NativeOps nativeOps;
|
execReduce3Tad(nullptr, reduce3::CosineSimilarity, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(),
|
||||||
nativeOps.execReduce3(nullptr, reduce3::CosineSimilarity, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(),
|
|
||||||
nullptr, nullptr, nullptr, nullptr);
|
nullptr, nullptr, nullptr, nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -517,9 +516,8 @@ TEST_F(LegacyOpsTests, Reduce3_3) {
|
||||||
|
|
||||||
auto dim = NDArrayFactory::create<int>('c', {1}, {1});
|
auto dim = NDArrayFactory::create<int>('c', {1}, {1});
|
||||||
|
|
||||||
NativeOps nativeOps;
|
|
||||||
|
|
||||||
nativeOps.execReduce3(nullptr, reduce3::CosineDistance,
|
execReduce3Tad(nullptr, reduce3::CosineDistance,
|
||||||
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
||||||
nullptr,
|
nullptr,
|
||||||
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
|
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
|
||||||
|
@ -543,9 +541,8 @@ TEST_F(LegacyOpsTests, Reduce3_4) {
|
||||||
|
|
||||||
auto dim = NDArrayFactory::create<int>('c', {1}, {1});
|
auto dim = NDArrayFactory::create<int>('c', {1}, {1});
|
||||||
|
|
||||||
NativeOps nativeOps;
|
|
||||||
|
|
||||||
nativeOps.execReduce3(nullptr, reduce3::CosineDistance,
|
execReduce3Tad(nullptr, reduce3::CosineDistance,
|
||||||
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
||||||
nullptr,
|
nullptr,
|
||||||
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
|
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
|
||||||
|
@ -569,9 +566,8 @@ TEST_F(LegacyOpsTests, Reduce3_5) {
|
||||||
|
|
||||||
auto dim = NDArrayFactory::create<int>('c', {1}, {1});
|
auto dim = NDArrayFactory::create<int>('c', {1}, {1});
|
||||||
|
|
||||||
NativeOps nativeOps;
|
|
||||||
|
|
||||||
nativeOps.execReduce3(nullptr, reduce3::CosineDistance,
|
execReduce3Tad(nullptr, reduce3::CosineDistance,
|
||||||
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
||||||
nullptr,
|
nullptr,
|
||||||
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
|
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
|
||||||
|
@ -593,8 +589,7 @@ TEST_F(LegacyOpsTests, test_Reduce3_All_1) {
|
||||||
auto tadPackX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), -1);
|
auto tadPackX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), -1);
|
||||||
auto tadPackY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.shapeInfo(), -1);
|
auto tadPackY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.shapeInfo(), -1);
|
||||||
|
|
||||||
NativeOps ops;
|
execReduce3All(nullptr, reduce3::EuclideanDistance, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
||||||
ops.execReduce3All(nullptr, reduce3::EuclideanDistance, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
|
||||||
nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
|
nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
|
||||||
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
|
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
|
||||||
dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(),
|
dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(),
|
||||||
|
@ -697,4 +692,4 @@ TEST_F(LegacyOpsTests, test_legacy_transform_float_1) {
|
||||||
auto x = NDArrayFactory::create<float>('c', {1, 0, 4});
|
auto x = NDArrayFactory::create<float>('c', {1, 0, 4});
|
||||||
|
|
||||||
NativeOpExecutioner::execTransformFloat(LaunchContext::defaultContext(), transform::FloatOps::RSqrt, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, nullptr, nullptr);
|
NativeOpExecutioner::execTransformFloat(LaunchContext::defaultContext(), transform::FloatOps::RSqrt, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, nullptr, nullptr);
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,8 +33,6 @@ public:
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(MmapTests, Test_Basic_Mmap_1) {
|
TEST_F(MmapTests, Test_Basic_Mmap_1) {
|
||||||
NativeOps nativeOps;
|
|
||||||
|
|
||||||
// just 10GB
|
// just 10GB
|
||||||
Nd4jLong size = 100000L;
|
Nd4jLong size = 100000L;
|
||||||
|
|
||||||
|
@ -43,11 +41,11 @@ TEST_F(MmapTests, Test_Basic_Mmap_1) {
|
||||||
ofs.write("", 1);
|
ofs.write("", 1);
|
||||||
ofs.close();
|
ofs.close();
|
||||||
|
|
||||||
auto result = nativeOps.mmapFile(nullptr, "file", size);
|
auto result = mmapFile(nullptr, "file", size);
|
||||||
|
|
||||||
ASSERT_FALSE(result == nullptr);
|
ASSERT_FALSE(result == nullptr);
|
||||||
|
|
||||||
nativeOps.munmapFile(nullptr, result, size);
|
munmapFile(nullptr, result, size);
|
||||||
|
|
||||||
remove("file");
|
remove("file");
|
||||||
}
|
}
|
||||||
|
|
|
@ -2258,7 +2258,6 @@ TEST_F(NDArrayCudaBasicsTests, Test_Empty_4) {
|
||||||
|
|
||||||
TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_1) {
|
TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_1) {
|
||||||
auto x = NDArrayFactory::create<float>('c', {5,2}, {0,1,2,3,4,5,6,7,8,9});
|
auto x = NDArrayFactory::create<float>('c', {5,2}, {0,1,2,3,4,5,6,7,8,9});
|
||||||
NativeOps native;
|
|
||||||
auto z = NDArrayFactory::create<float>('c', {5, 8});
|
auto z = NDArrayFactory::create<float>('c', {5, 8});
|
||||||
auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
|
auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
|
||||||
std::vector<void*> buffers(4);
|
std::vector<void*> buffers(4);
|
||||||
|
@ -2272,7 +2271,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_1) {
|
||||||
}
|
}
|
||||||
Nd4jPointer extra[2];
|
Nd4jPointer extra[2];
|
||||||
extra[1] = *stream;
|
extra[1] = *stream;
|
||||||
native.concat(extra, 1, 4, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
::concat(extra, 1, 4, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
||||||
z.syncToHost();
|
z.syncToHost();
|
||||||
z.printIndexedBuffer("Concat result");
|
z.printIndexedBuffer("Concat result");
|
||||||
z.printBuffer("C Concat result linear");
|
z.printBuffer("C Concat result linear");
|
||||||
|
@ -2281,7 +2280,6 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_1) {
|
||||||
|
|
||||||
TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_2) {
|
TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_2) {
|
||||||
auto x = NDArrayFactory::create<float>('c', {5,2}, {0,1,2,3,4,5,6,7,8,9});
|
auto x = NDArrayFactory::create<float>('c', {5,2}, {0,1,2,3,4,5,6,7,8,9});
|
||||||
NativeOps native;
|
|
||||||
auto z = NDArrayFactory::create<float>('f', {5, 8});
|
auto z = NDArrayFactory::create<float>('f', {5, 8});
|
||||||
auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
|
auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
|
||||||
std::vector<void*> buffers(4);
|
std::vector<void*> buffers(4);
|
||||||
|
@ -2295,7 +2293,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_2) {
|
||||||
}
|
}
|
||||||
Nd4jPointer extra[2];
|
Nd4jPointer extra[2];
|
||||||
extra[1] = *stream;
|
extra[1] = *stream;
|
||||||
native.concat(extra, 1, 4, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
::concat(extra, 1, 4, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
||||||
z.syncToHost();
|
z.syncToHost();
|
||||||
z.printIndexedBuffer("Concat result");
|
z.printIndexedBuffer("Concat result");
|
||||||
z.printBuffer("F Concat result linear");
|
z.printBuffer("F Concat result linear");
|
||||||
|
@ -2304,7 +2302,6 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_2) {
|
||||||
TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_3) {
|
TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_3) {
|
||||||
auto x = NDArrayFactory::create<float>('c', {2,3}, {1,2,3,4,5,6});
|
auto x = NDArrayFactory::create<float>('c', {2,3}, {1,2,3,4,5,6});
|
||||||
auto y = NDArrayFactory::create<float>('c', {1,3}, {7,8,9});
|
auto y = NDArrayFactory::create<float>('c', {1,3}, {7,8,9});
|
||||||
NativeOps native;
|
|
||||||
auto z = NDArrayFactory::create<float>('f', {3, 3});
|
auto z = NDArrayFactory::create<float>('f', {3, 3});
|
||||||
auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
|
auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
|
||||||
std::vector<void*> buffers(2);
|
std::vector<void*> buffers(2);
|
||||||
|
@ -2321,7 +2318,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_3) {
|
||||||
//}
|
//}
|
||||||
Nd4jPointer extra[2];
|
Nd4jPointer extra[2];
|
||||||
extra[1] = *stream;
|
extra[1] = *stream;
|
||||||
native.concat(extra, 0, 2, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
::concat(extra, 0, 2, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
||||||
z.syncToHost();
|
z.syncToHost();
|
||||||
z.printIndexedBuffer("Concat result");
|
z.printIndexedBuffer("Concat result");
|
||||||
z.printBuffer("F Concat result linear");
|
z.printBuffer("F Concat result linear");
|
||||||
|
@ -2331,7 +2328,6 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_3) {
|
||||||
TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_4) {
|
TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_4) {
|
||||||
auto x = NDArrayFactory::create<float>('c', {2,3}, {1,2,3,4,5,6});
|
auto x = NDArrayFactory::create<float>('c', {2,3}, {1,2,3,4,5,6});
|
||||||
auto y = NDArrayFactory::create<float>('c', {1,3}, {7,8,9});
|
auto y = NDArrayFactory::create<float>('c', {1,3}, {7,8,9});
|
||||||
NativeOps native;
|
|
||||||
auto z = NDArrayFactory::create<float>('c', {3, 3});
|
auto z = NDArrayFactory::create<float>('c', {3, 3});
|
||||||
auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
|
auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
|
||||||
std::vector<void*> buffers(2);
|
std::vector<void*> buffers(2);
|
||||||
|
@ -2348,7 +2344,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_4) {
|
||||||
//}
|
//}
|
||||||
Nd4jPointer extra[2];
|
Nd4jPointer extra[2];
|
||||||
extra[1] = *stream;
|
extra[1] = *stream;
|
||||||
native.concat(extra, 0, 2, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
::concat(extra, 0, 2, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
||||||
z.syncToHost();
|
z.syncToHost();
|
||||||
z.printIndexedBuffer("Concat result");
|
z.printIndexedBuffer("Concat result");
|
||||||
z.printBuffer("C Concat result linear");
|
z.printBuffer("C Concat result linear");
|
||||||
|
@ -2358,7 +2354,6 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_4) {
|
||||||
TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_5) {
|
TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_5) {
|
||||||
auto x = NDArrayFactory::create<float>('c', {1,2,3}, {1,2,3,4,5,6});
|
auto x = NDArrayFactory::create<float>('c', {1,2,3}, {1,2,3,4,5,6});
|
||||||
auto y = NDArrayFactory::create<float>('c', {1,2,3}, {7,8,9,10,11, 12});
|
auto y = NDArrayFactory::create<float>('c', {1,2,3}, {7,8,9,10,11, 12});
|
||||||
NativeOps native;
|
|
||||||
auto z = NDArrayFactory::create<float>('c', {2, 2, 3});
|
auto z = NDArrayFactory::create<float>('c', {2, 2, 3});
|
||||||
auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
|
auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
|
||||||
std::vector<void*> buffers(2);
|
std::vector<void*> buffers(2);
|
||||||
|
@ -2375,7 +2370,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_5) {
|
||||||
//}
|
//}
|
||||||
Nd4jPointer extra[2];
|
Nd4jPointer extra[2];
|
||||||
extra[1] = *stream;
|
extra[1] = *stream;
|
||||||
native.concat(extra, 0, 2, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
::concat(extra, 0, 2, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
||||||
z.syncToHost();
|
z.syncToHost();
|
||||||
z.printIndexedBuffer("Concat result");
|
z.printIndexedBuffer("Concat result");
|
||||||
z.printBuffer("C Concat result linear");
|
z.printBuffer("C Concat result linear");
|
||||||
|
@ -2385,7 +2380,6 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_6) {
|
||||||
auto x1 = NDArrayFactory::create<float>('c', {2,2,3}, {1,2,3,4,5,6,7,8, 9, 10,11,12});
|
auto x1 = NDArrayFactory::create<float>('c', {2,2,3}, {1,2,3,4,5,6,7,8, 9, 10,11,12});
|
||||||
auto x2 = NDArrayFactory::create<float>('c', {1,2,3}, {13,14,15,16,17, 18});
|
auto x2 = NDArrayFactory::create<float>('c', {1,2,3}, {13,14,15,16,17, 18});
|
||||||
auto x3 = NDArrayFactory::create<float>('c', {1,2,3}, {19,20,21,22,23, 24});
|
auto x3 = NDArrayFactory::create<float>('c', {1,2,3}, {19,20,21,22,23, 24});
|
||||||
NativeOps native;
|
|
||||||
auto z = NDArrayFactory::create<float>('c', {4, 2, 3});
|
auto z = NDArrayFactory::create<float>('c', {4, 2, 3});
|
||||||
auto stream = x1.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
|
auto stream = x1.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
|
||||||
std::vector<void*> buffers(3);
|
std::vector<void*> buffers(3);
|
||||||
|
@ -2406,7 +2400,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_6) {
|
||||||
printf("The third array is %p\n", buffers[2]);
|
printf("The third array is %p\n", buffers[2]);
|
||||||
Nd4jPointer extra[2];
|
Nd4jPointer extra[2];
|
||||||
extra[1] = *stream;
|
extra[1] = *stream;
|
||||||
native.concat(extra, 0, 3, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
::concat(extra, 0, 3, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
||||||
z.syncToHost();
|
z.syncToHost();
|
||||||
z.printIndexedBuffer("Concat result");
|
z.printIndexedBuffer("Concat result");
|
||||||
z.printBuffer("C Concat3D result linear");
|
z.printBuffer("C Concat3D result linear");
|
||||||
|
@ -2417,7 +2411,6 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_7) {
|
||||||
auto x1 = NDArrayFactory::create<float>(1);
|
auto x1 = NDArrayFactory::create<float>(1);
|
||||||
auto x2 = NDArrayFactory::create<float>(2);
|
auto x2 = NDArrayFactory::create<float>(2);
|
||||||
auto x3 = NDArrayFactory::create<float>(3);
|
auto x3 = NDArrayFactory::create<float>(3);
|
||||||
NativeOps native;
|
|
||||||
auto z = NDArrayFactory::create<float>('c', {3}, {1,2,3});
|
auto z = NDArrayFactory::create<float>('c', {3}, {1,2,3});
|
||||||
auto stream = x1.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
|
auto stream = x1.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
|
||||||
std::vector<void*> buffers(3);
|
std::vector<void*> buffers(3);
|
||||||
|
@ -2438,7 +2431,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_7) {
|
||||||
printf("The third array is %p\n", buffers[2]);
|
printf("The third array is %p\n", buffers[2]);
|
||||||
Nd4jPointer extra[2];
|
Nd4jPointer extra[2];
|
||||||
extra[1] = *stream;
|
extra[1] = *stream;
|
||||||
native.concat(extra, 0, 3, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
::concat(extra, 0, 3, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
||||||
z.syncToHost();
|
z.syncToHost();
|
||||||
z.printIndexedBuffer("Concat result");
|
z.printIndexedBuffer("Concat result");
|
||||||
z.printBuffer("C Concat scalar result linear");
|
z.printBuffer("C Concat scalar result linear");
|
||||||
|
@ -2462,7 +2455,6 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_8) {
|
||||||
lx[i].assign(i);
|
lx[i].assign(i);
|
||||||
}
|
}
|
||||||
|
|
||||||
NativeOps native;
|
|
||||||
auto z = NDArrayFactory::create<float>('c', {totalCount, width});
|
auto z = NDArrayFactory::create<float>('c', {totalCount, width});
|
||||||
auto stream = nd4j::LaunchContext ::defaultContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
|
auto stream = nd4j::LaunchContext ::defaultContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
|
||||||
std::vector<void*> buffers(totalCount);
|
std::vector<void*> buffers(totalCount);
|
||||||
|
@ -2478,7 +2470,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_8) {
|
||||||
printf("The third array is %p\n", buffers[2]);
|
printf("The third array is %p\n", buffers[2]);
|
||||||
Nd4jPointer extra[2];
|
Nd4jPointer extra[2];
|
||||||
extra[1] = *stream;
|
extra[1] = *stream;
|
||||||
native.concat(extra, 0, totalCount, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
::concat(extra, 0, totalCount, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
||||||
z.syncToHost();
|
z.syncToHost();
|
||||||
nd4j_printf("%f %f %f\n", z.e<float>(0), z.e<float>(width * totalCount / 2), z.e<float>(width * (totalCount - 1)));
|
nd4j_printf("%f %f %f\n", z.e<float>(0), z.e<float>(width * totalCount / 2), z.e<float>(width * (totalCount - 1)));
|
||||||
//z.printIndexedBuffer("Concat result");
|
//z.printIndexedBuffer("Concat result");
|
||||||
|
@ -2496,7 +2488,6 @@ TEST_F(NDArrayCudaBasicsTests, TestTear_1) {
|
||||||
arrays.emplace_back(input);
|
arrays.emplace_back(input);
|
||||||
}
|
}
|
||||||
auto z = NDArrayFactory::create<float>('c', {total, 10, 10});
|
auto z = NDArrayFactory::create<float>('c', {total, 10, 10});
|
||||||
NativeOps native;
|
|
||||||
|
|
||||||
auto stream = input.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
|
auto stream = input.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
|
||||||
Nd4jPointer extra[2];
|
Nd4jPointer extra[2];
|
||||||
|
@ -2512,7 +2503,7 @@ TEST_F(NDArrayCudaBasicsTests, TestTear_1) {
|
||||||
hostShapes[i] = arrays[i].shapeInfo();
|
hostShapes[i] = arrays[i].shapeInfo();
|
||||||
}
|
}
|
||||||
|
|
||||||
native.concat(extra, 0, total, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
::concat(extra, 0, total, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
||||||
nd4j::ops::tear op;
|
nd4j::ops::tear op;
|
||||||
|
|
||||||
auto result = op.execute({&z}, {}, {1, 2});
|
auto result = op.execute({&z}, {}, {1, 2});
|
||||||
|
@ -2536,7 +2527,6 @@ TEST_F(NDArrayCudaBasicsTests, TestTear_2) {
|
||||||
arrays.emplace_back(input);
|
arrays.emplace_back(input);
|
||||||
}
|
}
|
||||||
auto z = NDArrayFactory::create<float>('c', {10, 10, 10});
|
auto z = NDArrayFactory::create<float>('c', {10, 10, 10});
|
||||||
NativeOps native;
|
|
||||||
|
|
||||||
auto stream = input.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
|
auto stream = input.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
|
||||||
Nd4jPointer extra[2];
|
Nd4jPointer extra[2];
|
||||||
|
@ -2552,7 +2542,7 @@ TEST_F(NDArrayCudaBasicsTests, TestTear_2) {
|
||||||
hostShapes[i] = arrays[i].shapeInfo();
|
hostShapes[i] = arrays[i].shapeInfo();
|
||||||
}
|
}
|
||||||
std::vector<int> dimsToExclude({1,2});
|
std::vector<int> dimsToExclude({1,2});
|
||||||
native.concat(extra, 0, 10, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
::concat(extra, 0, 10, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
||||||
// z.syncToHost();
|
// z.syncToHost();
|
||||||
// z.printBuffer("Pile OK");
|
// z.printBuffer("Pile OK");
|
||||||
// z.printIndexedBuffer("Pile 10x10");
|
// z.printIndexedBuffer("Pile 10x10");
|
||||||
|
@ -2569,7 +2559,7 @@ TEST_F(NDArrayCudaBasicsTests, TestTear_2) {
|
||||||
Nd4jPointer target = arrays[i].specialBuffer();
|
Nd4jPointer target = arrays[i].specialBuffer();
|
||||||
cudaMemcpy(&arraysData[i], &target, sizeof(Nd4jPointer), cudaMemcpyHostToDevice);
|
cudaMemcpy(&arraysData[i], &target, sizeof(Nd4jPointer), cudaMemcpyHostToDevice);
|
||||||
}
|
}
|
||||||
native.tear(extra, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), arraysData, input.specialShapeInfo(), packX.specialShapeInfo(), packX.specialOffsets());
|
::tear(extra, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), arraysData, input.specialShapeInfo(), packX.specialShapeInfo(), packX.specialOffsets());
|
||||||
// auto result = op.execute({&z}, {}, {1, 2});
|
// auto result = op.execute({&z}, {}, {1, 2});
|
||||||
// nd4j_printf("Result count is %lu\n", result->size());
|
// nd4j_printf("Result count is %lu\n", result->size());
|
||||||
//ASSERT_EQ(10, result->size());
|
//ASSERT_EQ(10, result->size());
|
||||||
|
|
|
@ -313,12 +313,10 @@ TEST_F(PlaygroundTests, test_reduce_3) {
|
||||||
Nd4jLong max = 0L;
|
Nd4jLong max = 0L;
|
||||||
Nd4jLong min = DataTypeUtils::max<Nd4jLong>();
|
Nd4jLong min = DataTypeUtils::max<Nd4jLong>();
|
||||||
|
|
||||||
NativeOps nativeOps;
|
|
||||||
|
|
||||||
for (int e = 0; e < iterations; e++) {
|
for (int e = 0; e < iterations; e++) {
|
||||||
auto timeStart = std::chrono::system_clock::now();
|
auto timeStart = std::chrono::system_clock::now();
|
||||||
|
|
||||||
nativeOps.execReduce3(nullptr, reduce3::CosineDistance, x.buffer(), x.shapeInfo(), x.specialBuffer(),
|
execReduce3Tad(nullptr, reduce3::CosineDistance, x.buffer(), x.shapeInfo(), x.specialBuffer(),
|
||||||
x.specialShapeInfo(), nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(),
|
x.specialShapeInfo(), nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(),
|
||||||
y.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
|
y.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
|
||||||
dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(), nullptr,
|
dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(), nullptr,
|
||||||
|
@ -964,8 +962,6 @@ TEST_F(PlaygroundTests, Test_Im2Col_1) {
|
||||||
auto legacyPermTime = std::chrono::duration_cast<std::chrono::microseconds> (legacyPermEnd - legacyPermStart).count();
|
auto legacyPermTime = std::chrono::duration_cast<std::chrono::microseconds> (legacyPermEnd - legacyPermStart).count();
|
||||||
|
|
||||||
|
|
||||||
NativeOps nativeOps;
|
|
||||||
|
|
||||||
Nd4jLong iArgs[] = {kH, kW, sH, sW, pH, pW, dH, dW, 0};
|
Nd4jLong iArgs[] = {kH, kW, sH, sW, pH, pW, dH, dW, 0};
|
||||||
Nd4jPointer inputBuffers[] = {input.buffer()};
|
Nd4jPointer inputBuffers[] = {input.buffer()};
|
||||||
Nd4jPointer inputShapes[] = {input.shapeInfo()};
|
Nd4jPointer inputShapes[] = {input.shapeInfo()};
|
||||||
|
@ -976,7 +972,7 @@ TEST_F(PlaygroundTests, Test_Im2Col_1) {
|
||||||
auto javaStart = std::chrono::system_clock::now();
|
auto javaStart = std::chrono::system_clock::now();
|
||||||
|
|
||||||
for (int e = 0; e < iterations; e++) {
|
for (int e = 0; e < iterations; e++) {
|
||||||
nativeOps.execCustomOp(nullptr, op.getOpHash(), inputBuffers, inputShapes, 1, outputBuffers, outputShapes, 1, nullptr, 0, iArgs, 9, nullptr, 0, false);
|
execCustomOp(nullptr, op.getOpHash(), inputBuffers, inputShapes, 1, outputBuffers, outputShapes, 1, nullptr, 0, iArgs, 9, nullptr, 0, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto javaEnd = std::chrono::system_clock::now();
|
auto javaEnd = std::chrono::system_clock::now();
|
||||||
|
@ -990,7 +986,7 @@ TEST_F(PlaygroundTests, Test_Im2Col_1) {
|
||||||
|
|
||||||
|
|
||||||
for (int e = 0; e < iterations; e++) {
|
for (int e = 0; e < iterations; e++) {
|
||||||
nativeOps.execCustomOp(nullptr, op.getOpHash(), inputBuffers, inputShapes, 1, outputPermBuffers, outputPermShapes, 1, nullptr, 0, iArgs, 9, nullptr, 0, false);
|
execCustomOp(nullptr, op.getOpHash(), inputBuffers, inputShapes, 1, outputPermBuffers, outputPermShapes, 1, nullptr, 0, iArgs, 9, nullptr, 0, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto javaPermEnd = std::chrono::system_clock::now();
|
auto javaPermEnd = std::chrono::system_clock::now();
|
||||||
|
@ -1020,9 +1016,7 @@ TEST_F(PlaygroundTests, Test_Im2Col_2) {
|
||||||
Nd4jPointer outputPermBuffers[] = {outputPermuted.buffer()};
|
Nd4jPointer outputPermBuffers[] = {outputPermuted.buffer()};
|
||||||
Nd4jPointer outputPermShapes[] = {outputPermuted.shapeInfo()};
|
Nd4jPointer outputPermShapes[] = {outputPermuted.shapeInfo()};
|
||||||
|
|
||||||
NativeOps nativeOps;
|
execCustomOp(nullptr, op.getOpHash(), inputBuffers, inputShapes, 1, outputPermBuffers, outputPermShapes, 1, nullptr, 0, iArgs, 9, nullptr, 0, false);
|
||||||
|
|
||||||
nativeOps.execCustomOp(nullptr, op.getOpHash(), inputBuffers, inputShapes, 1, outputPermBuffers, outputPermShapes, 1, nullptr, 0, iArgs, 9, nullptr, 0, false);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(PlaygroundTests, Test_Col2Im_1) {
|
TEST_F(PlaygroundTests, Test_Col2Im_1) {
|
||||||
|
@ -1140,8 +1134,6 @@ TEST_F(PlaygroundTests, loop_test_1) {
|
||||||
int length = (int) array->lengthOf();
|
int length = (int) array->lengthOf();
|
||||||
int span = (int) (array->lengthOf() / 6) + 8;
|
int span = (int) (array->lengthOf() / 6) + 8;
|
||||||
|
|
||||||
NativeOps ops;
|
|
||||||
|
|
||||||
auto t = new int[1000000];
|
auto t = new int[1000000];
|
||||||
|
|
||||||
|
|
||||||
|
@ -1150,7 +1142,7 @@ TEST_F(PlaygroundTests, loop_test_1) {
|
||||||
FloatBits fb;
|
FloatBits fb;
|
||||||
float threshold = 0.99f;
|
float threshold = 0.99f;
|
||||||
fb.f_ = threshold;
|
fb.f_ = threshold;
|
||||||
int le = ops.estimateThreshold(nullptr, reinterpret_cast<void *>(array->buffer()), array->shapeInfo(), static_cast<int>(array->lengthOf()), threshold);
|
int le = estimateThreshold(nullptr, reinterpret_cast<void *>(array->buffer()), array->shapeInfo(), static_cast<int>(array->lengthOf()), threshold);
|
||||||
|
|
||||||
t[0] = le;
|
t[0] = le;
|
||||||
t[1] = length;
|
t[1] = length;
|
||||||
|
@ -1162,7 +1154,7 @@ TEST_F(PlaygroundTests, loop_test_1) {
|
||||||
|
|
||||||
for (int x = 0; x < iterations; x++) {
|
for (int x = 0; x < iterations; x++) {
|
||||||
auto permStart = std::chrono::system_clock::now();
|
auto permStart = std::chrono::system_clock::now();
|
||||||
ops.estimateThreshold(nullptr, reinterpret_cast<void *>(array->buffer()), array->shapeInfo(), static_cast<int>(array->lengthOf()), threshold);
|
estimateThreshold(nullptr, reinterpret_cast<void *>(array->buffer()), array->shapeInfo(), static_cast<int>(array->lengthOf()), threshold);
|
||||||
TypeCast::convertToThreshold<float>(nullptr, buffer, array->lengthOf(), t);
|
TypeCast::convertToThreshold<float>(nullptr, buffer, array->lengthOf(), t);
|
||||||
|
|
||||||
auto permEnd = std::chrono::system_clock::now();
|
auto permEnd = std::chrono::system_clock::now();
|
||||||
|
|
|
@ -29,7 +29,6 @@ using namespace nd4j;
|
||||||
|
|
||||||
class RNGTests : public testing::Test {
|
class RNGTests : public testing::Test {
|
||||||
private:
|
private:
|
||||||
NativeOps nativeOps;
|
|
||||||
//Nd4jLong *_bufferA;
|
//Nd4jLong *_bufferA;
|
||||||
//Nd4jLong *_bufferB;
|
//Nd4jLong *_bufferB;
|
||||||
|
|
||||||
|
@ -47,8 +46,8 @@ public:
|
||||||
RNGTests() {
|
RNGTests() {
|
||||||
//_bufferA = new Nd4jLong[100000];
|
//_bufferA = new Nd4jLong[100000];
|
||||||
//_bufferB = new Nd4jLong[100000];
|
//_bufferB = new Nd4jLong[100000];
|
||||||
//_rngA = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, _seed, 100000, (Nd4jPointer) _bufferA);
|
//_rngA = (nd4j::random::RandomBuffer *) initRandom(nullptr, _seed, 100000, (Nd4jPointer) _bufferA);
|
||||||
//_rngB = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, _seed, 100000, (Nd4jPointer) _bufferB);
|
//_rngB = (nd4j::random::RandomBuffer *) initRandom(nullptr, _seed, 100000, (Nd4jPointer) _bufferB);
|
||||||
_rngA.setStates(_seed, _seed);
|
_rngA.setStates(_seed, _seed);
|
||||||
_rngB.setStates(_seed, _seed);
|
_rngB.setStates(_seed, _seed);
|
||||||
nexp0->assign(-1.0f);
|
nexp0->assign(-1.0f);
|
||||||
|
@ -57,8 +56,8 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
~RNGTests() {
|
~RNGTests() {
|
||||||
//nativeOps.destroyRandom(_rngA);
|
//destroyRandom(_rngA);
|
||||||
//nativeOps.destroyRandom(_rngB);
|
//destroyRandom(_rngB);
|
||||||
//delete[] _bufferA;
|
//delete[] _bufferA;
|
||||||
//delete[] _bufferB;
|
//delete[] _bufferB;
|
||||||
|
|
||||||
|
@ -791,14 +790,13 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(RNGTests, Test_Reproducibility_9) {
|
TEST_F(RNGTests, Test_Reproducibility_9) {
|
||||||
NativeOps ops;
|
|
||||||
Nd4jLong seed = 123;
|
Nd4jLong seed = 123;
|
||||||
|
|
||||||
std::vector<Nd4jLong> shape = {32, 3, 28, 28};
|
std::vector<Nd4jLong> shape = {32, 3, 28, 28};
|
||||||
const int bufferSize = 10000;
|
const int bufferSize = 10000;
|
||||||
int64_t buffer[bufferSize];
|
int64_t buffer[bufferSize];
|
||||||
|
|
||||||
auto rng = (nd4j::random::RandomBuffer *) ops.initRandom(nullptr, seed, bufferSize, buffer);
|
auto rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, seed, bufferSize, buffer);
|
||||||
|
|
||||||
const int length = 4000000;
|
const int length = 4000000;
|
||||||
int *arrayE = new int[length];
|
int *arrayE = new int[length];
|
||||||
|
@ -809,7 +807,7 @@ TEST_F(RNGTests, Test_Reproducibility_9) {
|
||||||
|
|
||||||
rng->rewindH(static_cast<Nd4jLong>(length));
|
rng->rewindH(static_cast<Nd4jLong>(length));
|
||||||
|
|
||||||
ops.refreshBuffer(nullptr, seed, reinterpret_cast<Nd4jPointer>(rng));
|
refreshBuffer(nullptr, seed, reinterpret_cast<Nd4jPointer>(rng));
|
||||||
|
|
||||||
for (int e = 0; e < length; e++)
|
for (int e = 0; e < length; e++)
|
||||||
arrayT[e] = rng->relativeInt(e);
|
arrayT[e] = rng->relativeInt(e);
|
||||||
|
@ -825,18 +823,17 @@ TEST_F(RNGTests, Test_Reproducibility_9) {
|
||||||
delete[] arrayE;
|
delete[] arrayE;
|
||||||
delete[] arrayT;
|
delete[] arrayT;
|
||||||
|
|
||||||
ops.destroyRandom(reinterpret_cast<Nd4jPointer>(rng));
|
destroyRandom(reinterpret_cast<Nd4jPointer>(rng));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(RNGTests, Test_Reproducibility_8) {
|
TEST_F(RNGTests, Test_Reproducibility_8) {
|
||||||
NativeOps ops;
|
|
||||||
Nd4jLong seed = 123;
|
Nd4jLong seed = 123;
|
||||||
|
|
||||||
std::vector<int> shape = {32, 3, 28, 28};
|
std::vector<int> shape = {32, 3, 28, 28};
|
||||||
const int bufferSize = 10000;
|
const int bufferSize = 10000;
|
||||||
int64_t buffer[bufferSize];
|
int64_t buffer[bufferSize];
|
||||||
|
|
||||||
auto rng = (nd4j::random::RandomBuffer *) ops.initRandom(nullptr, seed, bufferSize, buffer);
|
auto rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, seed, bufferSize, buffer);
|
||||||
|
|
||||||
const int length = 4000000;
|
const int length = 4000000;
|
||||||
int *arrayE = new int[length];
|
int *arrayE = new int[length];
|
||||||
|
@ -847,7 +844,7 @@ TEST_F(RNGTests, Test_Reproducibility_8) {
|
||||||
|
|
||||||
rng->rewindH(static_cast<Nd4jLong>(length));
|
rng->rewindH(static_cast<Nd4jLong>(length));
|
||||||
|
|
||||||
ops.refreshBuffer(nullptr, seed, reinterpret_cast<Nd4jPointer>(rng));
|
refreshBuffer(nullptr, seed, reinterpret_cast<Nd4jPointer>(rng));
|
||||||
|
|
||||||
for (int e = 0; e < length; e++)
|
for (int e = 0; e < length; e++)
|
||||||
arrayT[e] = static_cast<int>(rng->relativeT<float>(e));
|
arrayT[e] = static_cast<int>(rng->relativeT<float>(e));
|
||||||
|
@ -863,29 +860,27 @@ TEST_F(RNGTests, Test_Reproducibility_8) {
|
||||||
delete[] arrayE;
|
delete[] arrayE;
|
||||||
delete[] arrayT;
|
delete[] arrayT;
|
||||||
|
|
||||||
ops.destroyRandom(reinterpret_cast<Nd4jPointer>(rng));
|
destroyRandom(reinterpret_cast<Nd4jPointer>(rng));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(RNGTests, Test_RandomBuffer_Half_1) {
|
TEST_F(RNGTests, Test_RandomBuffer_Half_1) {
|
||||||
NativeOps ops;
|
|
||||||
Nd4jLong seed = 123;
|
Nd4jLong seed = 123;
|
||||||
|
|
||||||
std::vector<Nd4jLong> shape = {32, 3, 28, 28};
|
std::vector<Nd4jLong> shape = {32, 3, 28, 28};
|
||||||
const int bufferSize = 10000;
|
const int bufferSize = 10000;
|
||||||
int64_t buffer[bufferSize];
|
int64_t buffer[bufferSize];
|
||||||
|
|
||||||
auto rng = (nd4j::random::RandomBuffer *) ops.initRandom(nullptr, seed, bufferSize, buffer);
|
auto rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, seed, bufferSize, buffer);
|
||||||
|
|
||||||
auto r0 = rng->relativeT<float16>(12L);
|
auto r0 = rng->relativeT<float16>(12L);
|
||||||
auto r1 = rng->relativeT<float16>(13L);
|
auto r1 = rng->relativeT<float16>(13L);
|
||||||
|
|
||||||
ASSERT_NE(r0, r1);
|
ASSERT_NE(r0, r1);
|
||||||
|
|
||||||
ops.destroyRandom(reinterpret_cast<Nd4jPointer>(rng));
|
destroyRandom(reinterpret_cast<Nd4jPointer>(rng));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(RNGTests, Test_Reproducibility_1) {
|
TEST_F(RNGTests, Test_Reproducibility_1) {
|
||||||
NativeOps ops;
|
|
||||||
Nd4jLong seed = 123;
|
Nd4jLong seed = 123;
|
||||||
|
|
||||||
std::vector<Nd4jLong> shape = {32, 3, 28, 28};
|
std::vector<Nd4jLong> shape = {32, 3, 28, 28};
|
||||||
|
@ -918,7 +913,6 @@ TEST_F(RNGTests, Test_Reproducibility_1) {
|
||||||
|
|
||||||
#ifndef DEBUG_BUILD
|
#ifndef DEBUG_BUILD
|
||||||
TEST_F(RNGTests, Test_Reproducibility_2) {
|
TEST_F(RNGTests, Test_Reproducibility_2) {
|
||||||
NativeOps ops;
|
|
||||||
Nd4jLong seed = 123;
|
Nd4jLong seed = 123;
|
||||||
|
|
||||||
std::vector<Nd4jLong> shape = {32, 3, 64, 64};
|
std::vector<Nd4jLong> shape = {32, 3, 64, 64};
|
||||||
|
|
|
@ -44,8 +44,7 @@ TEST_F(SortCpuTests, test_linear_sort_by_key_1) {
|
||||||
auto ev = NDArrayFactory::create<double>('c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5});
|
auto ev = NDArrayFactory::create<double>('c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5});
|
||||||
|
|
||||||
|
|
||||||
NativeOps nativeOps;
|
sortByKey(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);
|
||||||
nativeOps.sortByKey(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);
|
|
||||||
|
|
||||||
ASSERT_EQ(ek, k);
|
ASSERT_EQ(ek, k);
|
||||||
ASSERT_EQ(ev, v);
|
ASSERT_EQ(ev, v);
|
||||||
|
@ -62,8 +61,7 @@ TEST_F(SortCpuTests, test_linear_sort_by_val_1) {
|
||||||
auto ev = NDArrayFactory::create<double>('c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5});
|
auto ev = NDArrayFactory::create<double>('c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5});
|
||||||
|
|
||||||
|
|
||||||
NativeOps nativeOps;
|
sortByValue(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);
|
||||||
nativeOps.sortByValue(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);
|
|
||||||
|
|
||||||
ASSERT_EQ(ek, k);
|
ASSERT_EQ(ek, k);
|
||||||
ASSERT_EQ(ev, v);
|
ASSERT_EQ(ev, v);
|
||||||
|
@ -81,8 +79,7 @@ TEST_F(SortCpuTests, test_tad_sort_by_key_1) {
|
||||||
|
|
||||||
|
|
||||||
int axis = 1;
|
int axis = 1;
|
||||||
NativeOps nativeOps;
|
sortTadByKey(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false);
|
||||||
nativeOps.sortTadByKey(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false);
|
|
||||||
|
|
||||||
ASSERT_EQ(ek, k);
|
ASSERT_EQ(ek, k);
|
||||||
ASSERT_EQ(ev, v);
|
ASSERT_EQ(ev, v);
|
||||||
|
@ -100,9 +97,8 @@ TEST_F(SortCpuTests, test_tad_sort_by_val_1) {
|
||||||
|
|
||||||
|
|
||||||
int axis = 1;
|
int axis = 1;
|
||||||
NativeOps nativeOps;
|
sortTadByValue(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false);
|
||||||
nativeOps.sortTadByValue(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false);
|
|
||||||
|
|
||||||
ASSERT_EQ(ek, k);
|
ASSERT_EQ(ek, k);
|
||||||
ASSERT_EQ(ev, v);
|
ASSERT_EQ(ev, v);
|
||||||
}
|
}
|
||||||
|
|
|
@ -42,8 +42,7 @@ TEST_F(SortCudaTests, test_linear_sort_by_key_1) {
|
||||||
|
|
||||||
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
|
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
|
||||||
|
|
||||||
NativeOps nativeOps;
|
sortByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);
|
||||||
nativeOps.sortByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);
|
|
||||||
k.tickWriteDevice();
|
k.tickWriteDevice();
|
||||||
v.tickWriteDevice();
|
v.tickWriteDevice();
|
||||||
|
|
||||||
|
@ -60,8 +59,7 @@ TEST_F(SortCudaTests, test_linear_sort_by_val_1) {
|
||||||
|
|
||||||
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
|
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
|
||||||
|
|
||||||
NativeOps nativeOps;
|
sortByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);
|
||||||
nativeOps.sortByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);
|
|
||||||
k.tickWriteDevice();
|
k.tickWriteDevice();
|
||||||
v.tickWriteDevice();
|
v.tickWriteDevice();
|
||||||
|
|
||||||
|
@ -78,8 +76,7 @@ TEST_F(SortCudaTests, test_linear_sort_by_val_2) {
|
||||||
|
|
||||||
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
|
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
|
||||||
|
|
||||||
NativeOps nativeOps;
|
sortByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), true);
|
||||||
nativeOps.sortByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), true);
|
|
||||||
k.tickWriteDevice();
|
k.tickWriteDevice();
|
||||||
v.tickWriteDevice();
|
v.tickWriteDevice();
|
||||||
k.printIndexedBuffer("KEYS");
|
k.printIndexedBuffer("KEYS");
|
||||||
|
@ -97,8 +94,7 @@ TEST_F(SortCudaTests, test_tad_sort_by_key_1) {
|
||||||
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
|
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
|
||||||
|
|
||||||
int axis = 1;
|
int axis = 1;
|
||||||
NativeOps nativeOps;
|
sortTadByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false);
|
||||||
nativeOps.sortTadByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false);
|
|
||||||
k.tickWriteDevice();
|
k.tickWriteDevice();
|
||||||
v.tickWriteDevice();
|
v.tickWriteDevice();
|
||||||
|
|
||||||
|
@ -119,11 +115,10 @@ TEST_F(SortCudaTests, test_tad_sort_by_val_1) {
|
||||||
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
|
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
|
||||||
|
|
||||||
int axis = 1;
|
int axis = 1;
|
||||||
NativeOps nativeOps;
|
sortTadByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false);
|
||||||
nativeOps.sortTadByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false);
|
|
||||||
k.tickWriteDevice();
|
k.tickWriteDevice();
|
||||||
v.tickWriteDevice();
|
v.tickWriteDevice();
|
||||||
|
|
||||||
ASSERT_EQ(ek, k);
|
ASSERT_EQ(ek, k);
|
||||||
ASSERT_EQ(ev, v);
|
ASSERT_EQ(ev, v);
|
||||||
}
|
}
|
||||||
|
|
|
@ -58,9 +58,8 @@ TEST_F(TypeCastTests, Test_ConvertDtype_1) {
|
||||||
float16 dst[5];
|
float16 dst[5];
|
||||||
float16 exp[] = {(float16) 1.0f, (float16) 2.0f, (float16) 3.0f, (float16) 4.0f, (float16) 5.0f};
|
float16 exp[] = {(float16) 1.0f, (float16) 2.0f, (float16) 3.0f, (float16) 4.0f, (float16) 5.0f};
|
||||||
|
|
||||||
NativeOps ops;
|
convertTypes(nullptr, ND4J_FLOAT32, src, 5, ND4J_FLOAT16, dst);
|
||||||
ops.convertTypes(nullptr, ND4J_FLOAT32, src, 5, ND4J_FLOAT16, dst);
|
|
||||||
|
|
||||||
for (int e = 0; e < 5; e++)
|
for (int e = 0; e < 5; e++)
|
||||||
ASSERT_NEAR(exp[e], dst[e], (float16) 0.01f);
|
ASSERT_NEAR(exp[e], dst[e], (float16) 0.01f);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue