Refactor NativeOps.h to export C functions

master
Samuel Audet 2019-07-22 20:34:08 +09:00 committed by AlexDBlack
parent fad8da878f
commit dcc72e23b2
23 changed files with 1950 additions and 2089 deletions

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -27,11 +27,10 @@ namespace nd4j {
ProviderRNG::ProviderRNG() { ProviderRNG::ProviderRNG() {
Nd4jLong *buffer = new Nd4jLong[100000]; Nd4jLong *buffer = new Nd4jLong[100000];
NativeOps nativeOps;
std::lock_guard<std::mutex> lock(_mutex); std::lock_guard<std::mutex> lock(_mutex);
#ifndef __CUDABLAS__ #ifndef __CUDABLAS__
// at this moment we don't have streams etc, so let's just skip this for now // at this moment we don't have streams etc, so let's just skip this for now
_rng = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, 123, 100000, (Nd4jPointer) buffer); _rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, 123, 100000, (Nd4jPointer) buffer);
#endif #endif
// if(_rng != nullptr) // if(_rng != nullptr)
} }

View File

@ -41,8 +41,7 @@ namespace nd4j {
} }
// FIXME: this approach isn't really good for cuda, since it'll assume that CUDA might get nullptr instead of stream // FIXME: this approach isn't really good for cuda, since it'll assume that CUDA might get nullptr instead of stream
NativeOps nativeOps; refreshBuffer(nullptr, seed, (Nd4jPointer) rng);
nativeOps.refreshBuffer(nullptr, seed, (Nd4jPointer) rng);
return Status::OK(); return Status::OK();
} }

View File

@ -110,11 +110,9 @@ namespace helpers {
indices->syncToDevice(); // linspace only on CPU, so sync to Device as well indices->syncToDevice(); // linspace only on CPU, so sync to Device as well
NDArray scores(*scales); NDArray scores(*scales);
NativeOps nativeOps;
Nd4jPointer extras[2] = {nullptr, stream}; Nd4jPointer extras[2] = {nullptr, stream};
nativeOps.sortByValue(extras, indices->buffer(), indices->shapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), scores.buffer(), scores.shapeInfo(), scores.specialBuffer(), scores.specialShapeInfo(), true); sortByValue(extras, indices->buffer(), indices->shapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), scores.buffer(), scores.shapeInfo(), scores.specialBuffer(), scores.specialShapeInfo(), true);
// TO DO: sort indices using scales as value row // TO DO: sort indices using scales as value row
//std::sort(indices.begin(), indices.end(), [scales](int i, int j) {return scales->e<T>(i) > scales->e<T>(j);}); //std::sort(indices.begin(), indices.end(), [scales](int i, int j) {return scales->e<T>(i) > scales->e<T>(j);});
I* indexBuf = reinterpret_cast<I*>(indices->specialBuffer()); I* indexBuf = reinterpret_cast<I*>(indices->specialBuffer());

View File

@ -60,8 +60,7 @@ namespace helpers {
params[1] = context->getCudaStream(); params[1] = context->getCudaStream();
if (input->isVector()) { if (input->isVector()) {
NativeOps ops; sort(params, nullptr, sortedVals.shapeInfo(), sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), reverse);
ops.sort(params, nullptr, sortedVals.shapeInfo(), sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), reverse);
cudaMemcpy(reinterpret_cast<T*>(output->specialBuffer()), reinterpret_cast<T*>(sortedVals.specialBuffer()) + n, sizeof(T), cudaMemcpyDeviceToDevice); cudaMemcpy(reinterpret_cast<T*>(output->specialBuffer()), reinterpret_cast<T*>(sortedVals.specialBuffer()) + n, sizeof(T), cudaMemcpyDeviceToDevice);
} }
@ -74,8 +73,7 @@ namespace helpers {
auto pTadShapeH = packX.primaryShapeInfo(); auto pTadShapeH = packX.primaryShapeInfo();
auto pTadOffsets = packX.specialOffsets(); auto pTadOffsets = packX.specialOffsets();
// auto pLastDimData = (int*) manager.replicatePointer(lastDims.data(), lastDims.size() * sizeof(int)); // auto pLastDimData = (int*) manager.replicatePointer(lastDims.data(), lastDims.size() * sizeof(int));
NativeOps ops; sortTad(params, sortedVals.buffer(), sortedVals.shapeInfo(), sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), lastDims.data(), lastDims.size(), pTadShape, pTadOffsets, reverse);
ops.sortTad(params, sortedVals.buffer(), sortedVals.shapeInfo(), sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), lastDims.data(), lastDims.size(), pTadShape, pTadOffsets, reverse);
// manager.synchronize(); // manager.synchronize();
sortedVals.tickWriteDevice(); sortedVals.tickWriteDevice();
sortedVals.syncToHost(); sortedVals.syncToHost();

View File

@ -38,32 +38,28 @@ TEST_F(HeaderTest, test_dataTypes_1) {
std::string header("0NUMPY6789{'descr': '>f4"); std::string header("0NUMPY6789{'descr': '>f4");
NativeOps nativeOps; ASSERT_EQ(nd4j::DataType::FLOAT32, dataTypeFromNpyHeader(const_cast<char *>(header.data())));
ASSERT_EQ(nd4j::DataType::FLOAT32, nativeOps.dataTypeFromNpyHeader(const_cast<char *>(header.data())));
} }
TEST_F(HeaderTest, test_dataTypes_2) { TEST_F(HeaderTest, test_dataTypes_2) {
std::string header("0NUMPY6789{'descr': '>f8"); std::string header("0NUMPY6789{'descr': '>f8");
NativeOps nativeOps; ASSERT_EQ(nd4j::DataType::DOUBLE, dataTypeFromNpyHeader(const_cast<char *>(header.data())));
ASSERT_EQ(nd4j::DataType::DOUBLE, nativeOps.dataTypeFromNpyHeader(const_cast<char *>(header.data())));
} }
TEST_F(HeaderTest, test_dataTypes_3) { TEST_F(HeaderTest, test_dataTypes_3) {
std::string header("0NUMPY6789{'descr': '<i4"); std::string header("0NUMPY6789{'descr': '<i4");
NativeOps nativeOps; ASSERT_EQ(nd4j::DataType::INT32, dataTypeFromNpyHeader(const_cast<char *>(header.data())));
ASSERT_EQ(nd4j::DataType::INT32, nativeOps.dataTypeFromNpyHeader(const_cast<char *>(header.data())));
} }
TEST_F(HeaderTest, test_dataTypes_4) { TEST_F(HeaderTest, test_dataTypes_4) {
std::string header("0NUMPY6789{'descr': '>u2"); std::string header("0NUMPY6789{'descr': '>u2");
NativeOps nativeOps; ASSERT_EQ(nd4j::DataType::UINT16, dataTypeFromNpyHeader(const_cast<char *>(header.data())));
ASSERT_EQ(nd4j::DataType::UINT16, nativeOps.dataTypeFromNpyHeader(const_cast<char *>(header.data())));
} }
/* /*
@ -88,8 +84,7 @@ TEST_F(LoadFromStringTest,PathTest) {
ASSERT_EQ(4.0,data[3]); ASSERT_EQ(4.0,data[3]);
Nd4jPointer pointer = reinterpret_cast<Nd4jPointer >(&loadedArr); Nd4jPointer pointer = reinterpret_cast<Nd4jPointer >(&loadedArr);
int *shapeBuffer = shape::shapeBufferOfNpy(loadedArr); int *shapeBuffer = shape::shapeBufferOfNpy(loadedArr);
NativeOps nativeOps; Nd4jPointer pointer1 = dataPointForNumpy(loaded);
Nd4jPointer pointer1 = nativeOps.dataPointForNumpy(loaded);
delete[] shapeBuffer; delete[] shapeBuffer;
double *data2 = reinterpret_cast<double *>(pointer1); double *data2 = reinterpret_cast<double *>(pointer1);

View File

@ -472,9 +472,7 @@ TEST_F(DeclarableOpsTests1, TestRng1) {
/* /*
Nd4jLong *buffer = new Nd4jLong[100000]; Nd4jLong *buffer = new Nd4jLong[100000];
NativeOps nativeOps; nd4j::random::RandomBuffer *rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, 123, 100000, (Nd4jPointer) buffer);
nd4j::random::RandomBuffer *rng = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, 123, 100000, (Nd4jPointer) buffer);
if (rng == nullptr) if (rng == nullptr)
throw std::runtime_error("RNG initialization failed"); throw std::runtime_error("RNG initialization failed");
@ -496,7 +494,7 @@ TEST_F(DeclarableOpsTests1, TestRng1) {
ASSERT_TRUE(x->sumNumber() > 0.0); ASSERT_TRUE(x->sumNumber() > 0.0);
nativeOps.destroyRandom((Nd4jPointer) rng); destroyRandom((Nd4jPointer) rng);
delete[] buffer; delete[] buffer;
delete variableSpace; delete variableSpace;
@ -1450,8 +1448,6 @@ TEST_F(DeclarableOpsTests1, TestRegistrator1) {
// ////////////////////////////////////////////////////////////////////// // //////////////////////////////////////////////////////////////////////
// TEST_F(DeclarableOpsTests1, TestLegacyExecution1) { // TEST_F(DeclarableOpsTests1, TestLegacyExecution1) {
// NativeOps nativeOps;
// auto x = NDArrayFactory::create_<float>('c', {10, 10}); // auto x = NDArrayFactory::create_<float>('c', {10, 10});
// x->assign(1.0f); // x->assign(1.0f);
@ -1483,8 +1479,8 @@ TEST_F(DeclarableOpsTests1, TestRegistrator1) {
// outputShapes[0] = (Nd4jPointer) z->getShapeInfo(); // outputShapes[0] = (Nd4jPointer) z->getShapeInfo();
// //auto status = nativeOps.execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, false); // //auto status = execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, false);
// auto status = nativeOps.execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); // auto status = execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
// ASSERT_EQ(ND4J_STATUS_OK, status); // ASSERT_EQ(ND4J_STATUS_OK, status);
// // z->printIndexedBuffer("Output add"); // // z->printIndexedBuffer("Output add");
// ASSERT_NEAR(2.0f, y->meanNumber().e<float>(0), 1e-5); // ASSERT_NEAR(2.0f, y->meanNumber().e<float>(0), 1e-5);
@ -1503,8 +1499,6 @@ TEST_F(DeclarableOpsTests1, TestRegistrator1) {
// ////////////////////////////////////////////////////////////////////// // //////////////////////////////////////////////////////////////////////
// TEST_F(DeclarableOpsTests1, TestLegacyExecution2) { // TEST_F(DeclarableOpsTests1, TestLegacyExecution2) {
// NativeOps nativeOps;
// auto x = NDArrayFactory::create_<float>('c', {10, 10}); // auto x = NDArrayFactory::create_<float>('c', {10, 10});
// x->assign(1.0f); // x->assign(1.0f);
@ -1532,7 +1526,7 @@ TEST_F(DeclarableOpsTests1, TestRegistrator1) {
// auto outputBuffers = new Nd4jPointer[1]; // auto outputBuffers = new Nd4jPointer[1];
// auto outputShapes = new Nd4jPointer[1]; // auto outputShapes = new Nd4jPointer[1];
// nativeOps.execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, true); // execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, true);
// ASSERT_NEAR(2.0, y->meanNumber().e<float>(0), 1e-5); // ASSERT_NEAR(2.0, y->meanNumber().e<float>(0), 1e-5);
// ASSERT_NEAR(3.0, x->meanNumber().e<float>(0), 1e-5); // ASSERT_NEAR(3.0, x->meanNumber().e<float>(0), 1e-5);

View File

@ -876,14 +876,13 @@ TEST_F(DeclarableOpsTests12, pullRows_1) {
auto xTadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.getShapeInfo(), dims); auto xTadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.getShapeInfo(), dims);
auto zTadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(z.getShapeInfo(), dims); auto zTadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(z.getShapeInfo(), dims);
NativeOps op;
Nd4jPointer nativeStart[2]; Nd4jPointer nativeStart[2];
#ifdef __CUDABLAS__ #ifdef __CUDABLAS__
nativeStart[1] = *(x.getContext()->getCudaStream()); nativeStart[1] = *(x.getContext()->getCudaStream());
#endif #endif
op.pullRows(nativeStart, x.buffer(), x.getShapeInfo(), x.getSpecialBuffer(), x.getSpecialShapeInfo(), pullRows(nativeStart, x.buffer(), x.getShapeInfo(), x.getSpecialBuffer(), x.getSpecialShapeInfo(),
z.buffer(), z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(), z.buffer(), z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
4, pidx, 4, pidx,
xTadPack.platformShapeInfo(), xTadPack.platformOffsets(), xTadPack.platformShapeInfo(), xTadPack.platformOffsets(),
@ -912,12 +911,11 @@ TEST_F(DeclarableOpsTests12, pullRows_2) {
auto xTadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.getShapeInfo(), dims); auto xTadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.getShapeInfo(), dims);
auto zTadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(z.getShapeInfo(), dims); auto zTadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(z.getShapeInfo(), dims);
NativeOps op;
Nd4jPointer nativeStart[2]; Nd4jPointer nativeStart[2];
#ifdef __CUDABLAS__ #ifdef __CUDABLAS__
nativeStart[1] = *(x.getContext()->getCudaStream()); nativeStart[1] = *(x.getContext()->getCudaStream());
#endif #endif
op.pullRows(nativeStart, x.buffer(), x.getShapeInfo(), x.specialBuffer(), x.specialShapeInfo(), pullRows(nativeStart, x.buffer(), x.getShapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
z.buffer(), z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(), z.buffer(), z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
4, pidx, 4, pidx,
xTadPack.platformShapeInfo(), xTadPack.platformOffsets(), xTadPack.platformShapeInfo(), xTadPack.platformOffsets(),

View File

@ -110,8 +110,7 @@ TEST_F(DeclarableOpsTests9, exponentialDistributionInv_test1) {
double extraParams[] = {lambda}; double extraParams[] = {lambda};
Nd4jLong *buffer = new Nd4jLong[N]; Nd4jLong *buffer = new Nd4jLong[N];
NativeOps nativeOps; auto rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
auto rng = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
if (rng == nullptr) if (rng == nullptr)
throw std::runtime_error("DeclarableOpsTests9.exponentialDistributionInv_test1: RNG initialization failed !"); throw std::runtime_error("DeclarableOpsTests9.exponentialDistributionInv_test1: RNG initialization failed !");
@ -122,7 +121,7 @@ TEST_F(DeclarableOpsTests9, exponentialDistributionInv_test1) {
ASSERT_NEAR(mean, actualMean, 0.01); ASSERT_NEAR(mean, actualMean, 0.01);
ASSERT_NEAR(std, actualStd, 0.01); ASSERT_NEAR(std, actualStd, 0.01);
nativeOps.destroyRandom((Nd4jPointer) rng); destroyRandom((Nd4jPointer) rng);
delete[] buffer; delete[] buffer;
} }
@ -142,8 +141,7 @@ TEST_F(DeclarableOpsTests9, exponentialDistributionInv_test2) {
Nd4jLong *buffer = new Nd4jLong[N]; Nd4jLong *buffer = new Nd4jLong[N];
NativeOps nativeOps; auto rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
auto rng = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
if (rng == nullptr) if (rng == nullptr)
throw std::runtime_error("DeclarableOpsTests9.exponentialDistributionInv_test2: RNG initialization failed !"); throw std::runtime_error("DeclarableOpsTests9.exponentialDistributionInv_test2: RNG initialization failed !");
@ -155,7 +153,7 @@ TEST_F(DeclarableOpsTests9, exponentialDistributionInv_test2) {
ASSERT_NEAR(mean, actualMean, 0.01); ASSERT_NEAR(mean, actualMean, 0.01);
ASSERT_NEAR(std, actualStd, 0.01); ASSERT_NEAR(std, actualStd, 0.01);
nativeOps.destroyRandom((Nd4jPointer) rng); destroyRandom((Nd4jPointer) rng);
delete[] buffer; delete[] buffer;
} }
@ -172,8 +170,7 @@ TEST_F(DeclarableOpsTests9, exponentialDistribution_test1) {
double extraParams[] = {lambda}; double extraParams[] = {lambda};
Nd4jLong *buffer = new Nd4jLong[N]; Nd4jLong *buffer = new Nd4jLong[N];
NativeOps nativeOps; auto rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
auto rng = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
if (rng == nullptr) if (rng == nullptr)
throw std::runtime_error("DeclarableOpsTests9.exponentialDistribution_test1: RNG initialization failed !"); throw std::runtime_error("DeclarableOpsTests9.exponentialDistribution_test1: RNG initialization failed !");
@ -184,7 +181,7 @@ TEST_F(DeclarableOpsTests9, exponentialDistribution_test1) {
ASSERT_NEAR(mean, actualMean, 0.01); ASSERT_NEAR(mean, actualMean, 0.01);
ASSERT_NEAR(std, actualStd, 0.01); ASSERT_NEAR(std, actualStd, 0.01);
nativeOps.destroyRandom((Nd4jPointer) rng); destroyRandom((Nd4jPointer) rng);
delete[] buffer; delete[] buffer;
} }
*/ */
@ -206,14 +203,13 @@ TEST_F(DeclarableOpsTests9, exponentialDistribution_test2) {
Nd4jLong *buffer = new Nd4jLong[N]; Nd4jLong *buffer = new Nd4jLong[N];
// Nd4jPointer extra[2]; // Nd4jPointer extra[2];
#ifndef __CUDABLAS__ #ifndef __CUDABLAS__
NativeOps nativeOps; nd4j::random::RandomBuffer* rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
nd4j::random::RandomBuffer* rng = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
if (rng == nullptr) if (rng == nullptr)
throw std::runtime_error("DeclarableOpsTests9.exponentialDistribution_test2: RNG initialization failed !"); throw std::runtime_error("DeclarableOpsTests9.exponentialDistribution_test2: RNG initialization failed !");
functions::random::RandomFunction<double>::template execTransform<randomOps::ExponentialDistribution<double>>(rng, y.getBuffer(), y.getShapeInfo(), x.getBuffer(), x.getShapeInfo(), extraParams); functions::random::RandomFunction<double>::template execTransform<randomOps::ExponentialDistribution<double>>(rng, y.getBuffer(), y.getShapeInfo(), x.getBuffer(), x.getShapeInfo(), extraParams);
nativeOps.destroyRandom((Nd4jPointer) rng); destroyRandom((Nd4jPointer) rng);
#endif #endif
const double actualMean = x.meanNumber().e<double>(0); const double actualMean = x.meanNumber().e<double>(0);
const double actualStd = x.varianceNumber(variance::SummaryStatsStandardDeviation, true).e<double>(0); const double actualStd = x.varianceNumber(variance::SummaryStatsStandardDeviation, true).e<double>(0);
@ -1005,12 +1001,10 @@ TEST_F(DeclarableOpsTests9, Test_DropoutInverted_01) {
x0.linspace(1); x0.linspace(1);
x1.linspace(1); x1.linspace(1);
/* /*
NativeOps nativeOps;
float prob[] = {0.5f}; float prob[] = {0.5f};
Nd4jLong* _bufferA = new Nd4jLong[100000]; Nd4jLong* _bufferA = new Nd4jLong[100000];
long _seed = 119L; long _seed = 119L;
auto _rngA = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, _seed, 100000, (Nd4jPointer) _bufferA); auto _rngA = (nd4j::random::RandomBuffer *) initRandom(nullptr, _seed, 100000, (Nd4jPointer) _bufferA);
x0. applyTransform(random::DropOutInverted, &x0, prob); x0. applyTransform(random::DropOutInverted, &x0, prob);
// x1.template applyRandom<randomOps::DropOutInverted<float>>(_rngB, nullptr, &x1, prob); // x1.template applyRandom<randomOps::DropOutInverted<float>>(_rngB, nullptr, &x1, prob);
@ -1026,7 +1020,7 @@ TEST_F(DeclarableOpsTests9, Test_DropoutInverted_01) {
// ASSERT_FALSE(x0.equalsTo(nexp0)); // ASSERT_FALSE(x0.equalsTo(nexp0));
// ASSERT_FALSE(x0.equalsTo(nexp1)); // ASSERT_FALSE(x0.equalsTo(nexp1));
// ASSERT_FALSE(x0.equalsTo(nexp2)); // ASSERT_FALSE(x0.equalsTo(nexp2));
nativeOps.destroyRandom(_rngA); destroyRandom(_rngA);
delete [] _bufferA; delete [] _bufferA;
*/ */
nd4j::ops::dropout op; nd4j::ops::dropout op;

View File

@ -51,9 +51,7 @@ public:
*/ */
TEST_F(GraphStateTests, Basic_Tests_1) { TEST_F(GraphStateTests, Basic_Tests_1) {
NativeOps nativeOps; auto state = (GraphState *) getGraphState(117L);
auto state = (GraphState *) nativeOps.getGraphState(117L);
ASSERT_EQ(117L, state->id()); ASSERT_EQ(117L, state->id());
// this call will create scope internally // this call will create scope internally
@ -72,14 +70,12 @@ TEST_F(GraphStateTests, Basic_Tests_1) {
ASSERT_TRUE(scope != nullptr); ASSERT_TRUE(scope != nullptr);
ASSERT_EQ(2, scope->size()); ASSERT_EQ(2, scope->size());
nativeOps.deleteGraphState(state); deleteGraphState(state);
} }
// just separate case for doubles wrapper in NativeOps, nothing else // just separate case for doubles wrapper in NativeOps, nothing else
TEST_F(GraphStateTests, Basic_Tests_2) { TEST_F(GraphStateTests, Basic_Tests_2) {
NativeOps nativeOps; auto state = (GraphState *) getGraphState(117L);
auto state = (GraphState *) nativeOps.getGraphState(117L);
ASSERT_EQ(117L, state->id()); ASSERT_EQ(117L, state->id());
// this call will create scope internally // this call will create scope internally
@ -98,46 +94,40 @@ TEST_F(GraphStateTests, Basic_Tests_2) {
ASSERT_TRUE(scope != nullptr); ASSERT_TRUE(scope != nullptr);
ASSERT_EQ(2, scope->size()); ASSERT_EQ(2, scope->size());
nativeOps.deleteGraphState(state); deleteGraphState(state);
} }
TEST_F(GraphStateTests, Stateful_Execution_1) { TEST_F(GraphStateTests, Stateful_Execution_1) {
NativeOps nativeOps; auto state = getGraphState(117L);
auto state = nativeOps.getGraphState(117L);
Nd4jLong scopes[] = {22, 33}; Nd4jLong scopes[] = {22, 33};
//auto status = nativeOps.execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0); //auto status = execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0);
auto status = nativeOps.execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0); auto status = execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0);
ASSERT_EQ(Status::THROW(), status); ASSERT_EQ(Status::THROW(), status);
nativeOps.deleteGraphState(state); deleteGraphState(state);
} }
TEST_F(GraphStateTests, Stateful_Execution_2) { TEST_F(GraphStateTests, Stateful_Execution_2) {
NativeOps nativeOps; auto state = (GraphState *) getGraphState(117L);
auto state = (GraphState *) nativeOps.getGraphState(117L);
state->registerScope(22); state->registerScope(22);
state->registerScope(33); state->registerScope(33);
Nd4jLong scopes[] = {22, 33}; Nd4jLong scopes[] = {22, 33};
auto status = nativeOps.execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0); auto status = execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0);
// it's no-op: just LogicScope // it's no-op: just LogicScope
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
nativeOps.deleteGraphState(state); deleteGraphState(state);
} }
/** /**
* This test checks WHILE loop * This test checks WHILE loop
*/ */
TEST_F(GraphStateTests, Stateful_Execution_3) { TEST_F(GraphStateTests, Stateful_Execution_3) {
NativeOps nativeOps;
auto var0 = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4}); auto var0 = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
auto var1 = NDArrayFactory::create<float>(11.0f); auto var1 = NDArrayFactory::create<float>(11.0f);
auto var2 = NDArrayFactory::create<float>(2.0f); auto var2 = NDArrayFactory::create<float>(2.0f);
@ -147,7 +137,7 @@ TEST_F(GraphStateTests, Stateful_Execution_3) {
auto res2 = NDArrayFactory::create<float>(0.0f); auto res2 = NDArrayFactory::create<float>(0.0f);
// registering our GraphState holder // registering our GraphState holder
auto state = (GraphState *) nativeOps.getGraphState(117L); auto state = (GraphState *) getGraphState(117L);
// we're prepping pointers to input/output buffers // we're prepping pointers to input/output buffers
Nd4jPointer ptrBuffers[] = {(Nd4jPointer) var0.buffer(), (Nd4jPointer) var1.buffer(), (Nd4jPointer)var2.buffer()}; Nd4jPointer ptrBuffers[] = {(Nd4jPointer) var0.buffer(), (Nd4jPointer) var1.buffer(), (Nd4jPointer)var2.buffer()};
@ -197,7 +187,7 @@ TEST_F(GraphStateTests, Stateful_Execution_3) {
Nd4jLong scopes[] = {22, 33}; Nd4jLong scopes[] = {22, 33};
// we're executing while loop // we're executing while loop
auto status = nativeOps.execCustomOpWithScope(nullptr, state, 0, scopes, 2, ptrBuffers, ptrShapes, 3, outBuffers, outShapes, 3); auto status = execCustomOpWithScope(nullptr, state, 0, scopes, 2, ptrBuffers, ptrShapes, 3, outBuffers, outShapes, 3);
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
// now we check provided result array // now we check provided result array
@ -211,7 +201,7 @@ TEST_F(GraphStateTests, Stateful_Execution_3) {
// nd4j_printf("0 ------------------\n",""); // nd4j_printf("0 ------------------\n","");
nativeOps.deleteGraphState(state); deleteGraphState(state);
// nd4j_printf("1 ------------------\n",""); // nd4j_printf("1 ------------------\n","");
} }
@ -220,8 +210,6 @@ TEST_F(GraphStateTests, Stateful_Execution_3) {
* This test checks CONDITIONAL execution for FALSE * This test checks CONDITIONAL execution for FALSE
*/ */
TEST_F(GraphStateTests, Stateful_Execution_4) { TEST_F(GraphStateTests, Stateful_Execution_4) {
NativeOps nativeOps;
auto var0 = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4}); auto var0 = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
auto var1 = NDArrayFactory::create<float>(5.0f); auto var1 = NDArrayFactory::create<float>(5.0f);
@ -232,7 +220,7 @@ TEST_F(GraphStateTests, Stateful_Execution_4) {
// registering our GraphState holder // registering our GraphState holder
auto state = (GraphState *) nativeOps.getGraphState(117L); auto state = (GraphState *) getGraphState(117L);
// we're prepping pointers to input/output buffers // we're prepping pointers to input/output buffers
Nd4jPointer ptrBuffers[] = {(Nd4jPointer) var0.buffer(), (Nd4jPointer) var1.buffer()}; Nd4jPointer ptrBuffers[] = {(Nd4jPointer) var0.buffer(), (Nd4jPointer) var1.buffer()};
@ -283,14 +271,14 @@ TEST_F(GraphStateTests, Stateful_Execution_4) {
Nd4jLong scopes[] = {22, 33, 44}; Nd4jLong scopes[] = {22, 33, 44};
// we're executing conditional op // we're executing conditional op
auto status = nativeOps.execCustomOpWithScope(nullptr, state, 20, scopes, 3, ptrBuffers, ptrShapes, 2, outBuffers, outShapes, 2); auto status = execCustomOpWithScope(nullptr, state, 20, scopes, 3, ptrBuffers, ptrShapes, 2, outBuffers, outShapes, 2);
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
ASSERT_TRUE(exp.isSameShape(&res0)); ASSERT_TRUE(exp.isSameShape(&res0));
ASSERT_TRUE(exp.equalsTo(&res0)); ASSERT_TRUE(exp.equalsTo(&res0));
nativeOps.deleteGraphState(state); deleteGraphState(state);
} }
@ -298,8 +286,6 @@ TEST_F(GraphStateTests, Stateful_Execution_4) {
* This test checks CONDITIONAL execution for TRUE * This test checks CONDITIONAL execution for TRUE
*/ */
TEST_F(GraphStateTests, Stateful_Execution_5) { TEST_F(GraphStateTests, Stateful_Execution_5) {
NativeOps nativeOps;
auto var0 = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4}); auto var0 = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
auto var1 = NDArrayFactory::create<float>(5.0f); auto var1 = NDArrayFactory::create<float>(5.0f);
@ -310,7 +296,7 @@ TEST_F(GraphStateTests, Stateful_Execution_5) {
// registering our GraphState holder // registering our GraphState holder
auto state = (GraphState *) nativeOps.getGraphState(117L); auto state = (GraphState *) getGraphState(117L);
// we're prepping pointers to input/output buffers // we're prepping pointers to input/output buffers
Nd4jPointer ptrBuffers[] = {(Nd4jPointer) var0.buffer(), (Nd4jPointer) var1.buffer()}; Nd4jPointer ptrBuffers[] = {(Nd4jPointer) var0.buffer(), (Nd4jPointer) var1.buffer()};
@ -361,12 +347,11 @@ TEST_F(GraphStateTests, Stateful_Execution_5) {
Nd4jLong scopes[] = {22, 33, 44}; Nd4jLong scopes[] = {22, 33, 44};
// we're executing conditional op // we're executing conditional op
auto status = nativeOps.execCustomOpWithScope(nullptr, state, 20, scopes, 3, ptrBuffers, ptrShapes, 2, outBuffers, outShapes, 2); auto status = execCustomOpWithScope(nullptr, state, 20, scopes, 3, ptrBuffers, ptrShapes, 2, outBuffers, outShapes, 2);
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
ASSERT_TRUE(exp.isSameShape(&res0)); ASSERT_TRUE(exp.isSameShape(&res0));
ASSERT_TRUE(exp.equalsTo(&res0)); ASSERT_TRUE(exp.equalsTo(&res0));
deleteGraphState(state);
nativeOps.deleteGraphState(state);
} }

View File

@ -42,7 +42,6 @@ TEST_F(JavaInteropCudaTests, test_DeclarableOp_execution_1) {
e.assign(2.f); e.assign(2.f);
nd4j::ops::add op; nd4j::ops::add op;
NativeOps nativeOps;
Context context(1); Context context(1);
context.setCudaContext(LaunchContext::defaultContext()->getCudaStream(), LaunchContext::defaultContext()->getReductionPointer(), LaunchContext::defaultContext()->getAllocationPointer()); context.setCudaContext(LaunchContext::defaultContext()->getCudaStream(), LaunchContext::defaultContext()->getReductionPointer(), LaunchContext::defaultContext()->getAllocationPointer());
@ -53,7 +52,7 @@ TEST_F(JavaInteropCudaTests, test_DeclarableOp_execution_1) {
nd4j_printf("Starting execution...\n",""); nd4j_printf("Starting execution...\n","");
PointersManager pm(LaunchContext::defaultContext(), "test_DeclarableOp_execution_1"); PointersManager pm(LaunchContext::defaultContext(), "test_DeclarableOp_execution_1");
nativeOps.execCustomOp(nullptr, op.getOpHash(), &context); execCustomOp2(nullptr, op.getOpHash(), &context);
pm.synchronize(); pm.synchronize();
@ -71,7 +70,6 @@ TEST_F(JavaInteropCudaTests, test_DeclarableOp_execution_2) {
e.assign(false); e.assign(false);
nd4j::ops::equals op; nd4j::ops::equals op;
NativeOps nativeOps;
Context context(1); Context context(1);
context.setCudaContext(LaunchContext::defaultContext()->getCudaStream(), LaunchContext::defaultContext()->getReductionPointer(), LaunchContext::defaultContext()->getAllocationPointer()); context.setCudaContext(LaunchContext::defaultContext()->getCudaStream(), LaunchContext::defaultContext()->getReductionPointer(), LaunchContext::defaultContext()->getAllocationPointer());
@ -82,7 +80,7 @@ TEST_F(JavaInteropCudaTests, test_DeclarableOp_execution_2) {
nd4j_printf("Starting execution...\n",""); nd4j_printf("Starting execution...\n","");
PointersManager pm(LaunchContext::defaultContext(), "test_DeclarableOp_execution_2"); PointersManager pm(LaunchContext::defaultContext(), "test_DeclarableOp_execution_2");
nativeOps.execCustomOp(nullptr, op.getOpHash(), &context); execCustomOp2(nullptr, op.getOpHash(), &context);
pm.synchronize(); pm.synchronize();

View File

@ -41,8 +41,6 @@ TEST_F(JavaInteropTests, TestShapeExposure1) {
auto weights = NDArrayFactory::create<float>('c', {2, 2, 2, 3}); auto weights = NDArrayFactory::create<float>('c', {2, 2, 2, 3});
auto exp = NDArrayFactory::create<float>('c', {1, 3, 5, 4}); auto exp = NDArrayFactory::create<float>('c', {1, 3, 5, 4});
NativeOps nativeOps;
nd4j::ops::conv2d op; nd4j::ops::conv2d op;
std::vector<double> tArgs({}); std::vector<double> tArgs({});
@ -50,7 +48,7 @@ TEST_F(JavaInteropTests, TestShapeExposure1) {
Nd4jPointer ptrs[] = {(Nd4jPointer) input.getShapeInfo(), (Nd4jPointer) weights.getShapeInfo()}; Nd4jPointer ptrs[] = {(Nd4jPointer) input.getShapeInfo(), (Nd4jPointer) weights.getShapeInfo()};
auto shapeList = nativeOps.calculateOutputShapes(nullptr, op.getOpHash(), ptrs, 2, tArgs.data(), tArgs.size(), iArgs.data(), iArgs.size()); auto shapeList = calculateOutputShapes(nullptr, op.getOpHash(), ptrs, 2, tArgs.data(), tArgs.size(), iArgs.data(), iArgs.size());
ASSERT_EQ(1, shapeList->size()); ASSERT_EQ(1, shapeList->size());
@ -64,7 +62,7 @@ TEST_F(JavaInteropTests, TestShapeExposure1) {
//delete[] ptr; //delete[] ptr;
//delete shapeList; //delete shapeList;
nativeOps.deleteShapeList((Nd4jPointer) shapeList); deleteShapeList((Nd4jPointer) shapeList);
} }
@ -72,9 +70,6 @@ TEST_F(JavaInteropTests, TestShapeExposure2) {
auto input = NDArrayFactory::create<float>('c', {1, 2, 5, 4}); auto input = NDArrayFactory::create<float>('c', {1, 2, 5, 4});
auto exp = NDArrayFactory::create<float>('c', {4}, {1, 2, 5, 4}); auto exp = NDArrayFactory::create<float>('c', {4}, {1, 2, 5, 4});
NativeOps nativeOps;
nd4j::ops::shape_of op; nd4j::ops::shape_of op;
std::vector<double> tArgs({}); std::vector<double> tArgs({});
@ -83,14 +78,14 @@ TEST_F(JavaInteropTests, TestShapeExposure2) {
Nd4jPointer ptrs[] = {(Nd4jPointer) input.getShapeInfo()}; Nd4jPointer ptrs[] = {(Nd4jPointer) input.getShapeInfo()};
auto shapeList = nativeOps.calculateOutputShapes(nullptr, op.getOpHash(), ptrs, 1, tArgs.data(), tArgs.size(), iArgs.data(), iArgs.size()); auto shapeList = calculateOutputShapes(nullptr, op.getOpHash(), ptrs, 1, tArgs.data(), tArgs.size(), iArgs.data(), iArgs.size());
ASSERT_EQ(1, shapeList->size()); ASSERT_EQ(1, shapeList->size());
ASSERT_EQ(exp.rankOf(), shape::rank((Nd4jLong *)shapeList->at(0))); ASSERT_EQ(exp.rankOf(), shape::rank((Nd4jLong *)shapeList->at(0)));
ASSERT_EQ(exp.sizeAt(0), shape::shapeOf((Nd4jLong *)shapeList->at(0))[0]); ASSERT_EQ(exp.sizeAt(0), shape::shapeOf((Nd4jLong *)shapeList->at(0))[0]);
nativeOps.deleteShapeList((Nd4jPointer) shapeList); deleteShapeList((Nd4jPointer) shapeList);
} }
TEST_F(JavaInteropTests, TestShapeExposure3) { TEST_F(JavaInteropTests, TestShapeExposure3) {
@ -112,13 +107,12 @@ TEST_F(JavaInteropTests, TestShapeExposure3) {
Nd4jPointer inputBuffers[] = {x.buffer(), sizes.buffer()}; Nd4jPointer inputBuffers[] = {x.buffer(), sizes.buffer()};
Nd4jPointer inputShapes[] = {x.shapeInfo(), sizes.shapeInfo()}; Nd4jPointer inputShapes[] = {x.shapeInfo(), sizes.shapeInfo()};
NativeOps nativeOps;
nd4j::ops::split_v op; nd4j::ops::split_v op;
Nd4jLong iArgs[] = {1}; Nd4jLong iArgs[] = {1};
auto hash = op.getOpHash(); auto hash = op.getOpHash();
auto shapeList = nativeOps.calculateOutputShapes(nullptr, hash, inputBuffers, inputShapes, 2, nullptr, 0, iArgs, 1, nullptr, 0); auto shapeList = calculateOutputShapes2(nullptr, hash, inputBuffers, inputShapes, 2, nullptr, 0, iArgs, 1, nullptr, 0);
ASSERT_EQ(3, shapeList->size()); ASSERT_EQ(3, shapeList->size());
@ -126,7 +120,7 @@ TEST_F(JavaInteropTests, TestShapeExposure3) {
ASSERT_TRUE(shape::equalsSoft(sub1.shapeInfo(), shapeList->at(1))); ASSERT_TRUE(shape::equalsSoft(sub1.shapeInfo(), shapeList->at(1)));
ASSERT_TRUE(shape::equalsSoft(sub2.shapeInfo(), shapeList->at(2))); ASSERT_TRUE(shape::equalsSoft(sub2.shapeInfo(), shapeList->at(2)));
nativeOps.deleteShapeList((Nd4jPointer) shapeList); deleteShapeList((Nd4jPointer) shapeList);
} }
TEST_F(JavaInteropTests, Test_Squeeze_1) { TEST_F(JavaInteropTests, Test_Squeeze_1) {
@ -143,10 +137,7 @@ TEST_F(JavaInteropTests, Test_Squeeze_1) {
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()}; Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()}; Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()};
auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
NativeOps nativeOps;
auto status = nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
ASSERT_EQ(e, z); ASSERT_EQ(e, z);
@ -167,10 +158,7 @@ TEST_F(JavaInteropTests, Test_RDiv_1) {
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()}; Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()}; Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()};
auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
NativeOps nativeOps;
auto status = nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
ASSERT_EQ(e, z); ASSERT_EQ(e, z);
@ -203,11 +191,9 @@ TEST_F(JavaInteropTests, TestSconv2d_1) {
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer()}; Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo()}; Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo()};
NativeOps nativeOps;
Nd4jLong exp[] = {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}; Nd4jLong exp[] = {1, 1, 1, 1, 0, 0, 1, 1, 0, 0};
nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 4, ptrsOutBuffers, ptrsOutShapes, 1, execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 4, ptrsOutBuffers, ptrsOutShapes, 1,
nullptr, 0, exp, 9, nullptr, 0, false); nullptr, 0, exp, 9, nullptr, 0, false);
//output.printBuffer("output"); //output.printBuffer("output");
@ -238,11 +224,9 @@ TEST_F(JavaInteropTests, TestSconv2d_2) {
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer()}; Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo()}; Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo()};
NativeOps nativeOps;
Nd4jLong exp[] = {1, 1, 1, 1, 0, 0, 1, 1, 0}; Nd4jLong exp[] = {1, 1, 1, 1, 0, 0, 1, 1, 0};
nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 9, nullptr, 0, false); execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 9, nullptr, 0, false);
//output.printBuffer("output"); //output.printBuffer("output");
@ -266,9 +250,7 @@ TEST_F(JavaInteropTests, TestMaxPooling2d_1) {
nd4j::ops::maxpool2d op; nd4j::ops::maxpool2d op;
NativeOps nativeOps; Nd4jStatus status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs.data(), 9, nullptr, 0, false);
Nd4jStatus status = nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs.data(), 9, nullptr, 0, false);
ASSERT_EQ(ND4J_STATUS_OK, status); ASSERT_EQ(ND4J_STATUS_OK, status);
} }
@ -294,13 +276,11 @@ TEST_F(JavaInteropTests, TestCol2Im_1) {
nd4j::ops::col2im op; nd4j::ops::col2im op;
NativeOps nativeOps;
Nd4jLong exp[] = {1, 1, 1, 1, 4, 5, 1, 1, 1}; Nd4jLong exp[] = {1, 1, 1, 1, 4, 5, 1, 1, 1};
auto hash = op.getOpHash(); auto hash = op.getOpHash();
nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 9, nullptr, 0, false); execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 9, nullptr, 0, false);
ASSERT_TRUE(output.meanNumber().e<float>(0) > 0.0f); ASSERT_TRUE(output.meanNumber().e<float>(0) > 0.0f);
} }
@ -320,8 +300,6 @@ TEST_F(JavaInteropTests, TestPNorm_1) {
auto output = NDArrayFactory::create<float>('c', {1, 3, 3, 3}); auto output = NDArrayFactory::create<float>('c', {1, 3, 3, 3});
input.linspace(1); input.linspace(1);
NativeOps nativeOps;
nd4j::ops::pnormpool2d op; nd4j::ops::pnormpool2d op;
Nd4jLong exp[] = {2, 2, 1, 1, 0, 0, 1, 1, 0, 2, 0, 0}; Nd4jLong exp[] = {2, 2, 1, 1, 0, 0, 1, 1, 0, 2, 0, 0};
@ -332,7 +310,7 @@ TEST_F(JavaInteropTests, TestPNorm_1) {
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer()}; Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo()}; Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo()};
nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 11, nullptr, 0, false); execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 11, nullptr, 0, false);
ASSERT_TRUE(output.meanNumber().e<double>(0) > 0.0); ASSERT_TRUE(output.meanNumber().e<double>(0) > 0.0);
} }
@ -343,8 +321,6 @@ TEST_F(JavaInteropTests, TestInplace_1) {
//auto exp('c', {10, 10}); //auto exp('c', {10, 10});
input.linspace(1); input.linspace(1);
NativeOps nativeOps;
nd4j::ops::clipbyvalue op; nd4j::ops::clipbyvalue op;
double extras[] = {-1.0f, 1.0f}; double extras[] = {-1.0f, 1.0f};
@ -353,7 +329,7 @@ TEST_F(JavaInteropTests, TestInplace_1) {
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.getShapeInfo()}; Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.getShapeInfo()};
Nd4jStatus result = nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, nullptr, nullptr, 0, extras, 2, nullptr, 0, nullptr, 0, true); Nd4jStatus result = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, nullptr, nullptr, 0, extras, 2, nullptr, 0, nullptr, 0, true);
ASSERT_EQ(ND4J_STATUS_OK, result); ASSERT_EQ(ND4J_STATUS_OK, result);
@ -415,7 +391,6 @@ TEST_F(JavaInteropTests, test_avgpooling_edge_1) {
x.linspace(1.0); x.linspace(1.0);
z.linspace(1.0); z.linspace(1.0);
NativeOps nativeOps;
nd4j::ops::avgpool2d op; nd4j::ops::avgpool2d op;
//auto result = op.execute({&x}, {}, {3,3, 1,1, 0,0, 1,1, 1, 0, 1}); //auto result = op.execute({&x}, {}, {3,3, 1,1, 0,0, 1,1, 1, 0, 1});
@ -427,7 +402,7 @@ TEST_F(JavaInteropTests, test_avgpooling_edge_1) {
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()}; Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()}; Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()};
auto result = nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 11, nullptr, 0, false); auto result = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 11, nullptr, 0, false);
ASSERT_EQ(Status::OK(), result); ASSERT_EQ(Status::OK(), result);
@ -496,15 +471,13 @@ TEST_F(JavaInteropTests, test_avgpooling_edge_1) {
/* /*
TEST_F(JavaInteropTests, Test_GraphReuse_1) { TEST_F(JavaInteropTests, Test_GraphReuse_1) {
NativeOps nativeOps;
uint8_t* data = nd4j::graph::readFlatBuffers("./resources/reduce_dim_false.fb"); uint8_t* data = nd4j::graph::readFlatBuffers("./resources/reduce_dim_false.fb");
nativeOps.registerGraph(nullptr, 119, (Nd4jPointer) data); registerGraph(nullptr, 119, (Nd4jPointer) data);
ASSERT_TRUE(GraphHolder::getInstance()->hasGraph(119)); ASSERT_TRUE(GraphHolder::getInstance()->hasGraph(119));
nativeOps.unregisterGraph(nullptr, 119); unregisterGraph(nullptr, 119);
ASSERT_FALSE(GraphHolder::getInstance()->hasGraph(119)); ASSERT_FALSE(GraphHolder::getInstance()->hasGraph(119));
@ -520,8 +493,6 @@ TEST_F(JavaInteropTests, Test_GraphReuse_2) {
auto exp1 = NDArrayFactory::create<float>('c', {3}, {6, 6, 6}); auto exp1 = NDArrayFactory::create<float>('c', {3}, {6, 6, 6});
auto exp2 = NDArrayFactory::create<float>('c', {3}, {9, 9, 9}); auto exp2 = NDArrayFactory::create<float>('c', {3}, {9, 9, 9});
NativeOps nativeOps;
// we load graph from file, because we're not in java here, and dont have buffer ready // we load graph from file, because we're not in java here, and dont have buffer ready
uint8_t* data = nd4j::graph::readFlatBuffers("./resources/reduce_dim_false.fb"); uint8_t* data = nd4j::graph::readFlatBuffers("./resources/reduce_dim_false.fb");
@ -529,7 +500,7 @@ TEST_F(JavaInteropTests, Test_GraphReuse_2) {
ASSERT_FALSE(GraphHolder::getInstance()->hasGraph(119)); ASSERT_FALSE(GraphHolder::getInstance()->hasGraph(119));
// register the graph, to call for it later // register the graph, to call for it later
nativeOps.registerGraph(nullptr, 119, (Nd4jPointer) data); registerGraph(nullptr, 119, (Nd4jPointer) data);
// and ensure we're ok // and ensure we're ok
ASSERT_TRUE(GraphHolder::getInstance()->hasGraph(119)); ASSERT_TRUE(GraphHolder::getInstance()->hasGraph(119));
@ -547,7 +518,7 @@ TEST_F(JavaInteropTests, Test_GraphReuse_2) {
Nd4jPointer shapes_0[] = {(Nd4jPointer) input_0.shapeInfo()}; Nd4jPointer shapes_0[] = {(Nd4jPointer) input_0.shapeInfo()};
// now we're executing stored graph and providing replacement for input variable // now we're executing stored graph and providing replacement for input variable
auto res_0 = nativeOps.executeStoredGraph(nullptr, 119, inputs_0, shapes_0, idx, 1); auto res_0 = executeStoredGraph(nullptr, 119, inputs_0, shapes_0, idx, 1);
ASSERT_EQ(ND4J_STATUS_OK, res_0->status()); ASSERT_EQ(ND4J_STATUS_OK, res_0->status());
ASSERT_EQ(1, res_0->size()); ASSERT_EQ(1, res_0->size());
@ -562,7 +533,7 @@ TEST_F(JavaInteropTests, Test_GraphReuse_2) {
Nd4jPointer shapes_1[] = {(Nd4jPointer) input_1.shapeInfo()}; Nd4jPointer shapes_1[] = {(Nd4jPointer) input_1.shapeInfo()};
// doing it again // doing it again
auto res_1 = nativeOps.executeStoredGraph(nullptr, 119, inputs_1, shapes_1, idx, 1); auto res_1 = executeStoredGraph(nullptr, 119, inputs_1, shapes_1, idx, 1);
ASSERT_EQ(ND4J_STATUS_OK, res_1->status()); ASSERT_EQ(ND4J_STATUS_OK, res_1->status());
ASSERT_EQ(1, res_1->size()); ASSERT_EQ(1, res_1->size());
@ -577,7 +548,7 @@ TEST_F(JavaInteropTests, Test_GraphReuse_2) {
Nd4jPointer shapes_2[] = {(Nd4jPointer) input_2.shapeInfo()}; Nd4jPointer shapes_2[] = {(Nd4jPointer) input_2.shapeInfo()};
// and again // and again
auto res_2 = nativeOps.executeStoredGraph(nullptr, 119, inputs_2, shapes_2, idx, 1); auto res_2 = executeStoredGraph(nullptr, 119, inputs_2, shapes_2, idx, 1);
ASSERT_EQ(ND4J_STATUS_OK, res_1->status()); ASSERT_EQ(ND4J_STATUS_OK, res_1->status());
ASSERT_EQ(1, res_2->size()); ASSERT_EQ(1, res_2->size());
@ -586,7 +557,7 @@ TEST_F(JavaInteropTests, Test_GraphReuse_2) {
//////// clean out //////// clean out
nativeOps.unregisterGraph(nullptr, 119); unregisterGraph(nullptr, 119);
ASSERT_FALSE(GraphHolder::getInstance()->hasGraph(119)); ASSERT_FALSE(GraphHolder::getInstance()->hasGraph(119));
@ -616,9 +587,7 @@ TEST_F(JavaInteropTests, Test_Greater_1) {
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.getBuffer()}; Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.getBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.getShapeInfo()}; Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.getShapeInfo()};
NativeOps nativeOps; execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
o.printIndexedBuffer("Greater JIT"); o.printIndexedBuffer("Greater JIT");
ASSERT_TRUE(exp.equalsTo(&o)); ASSERT_TRUE(exp.equalsTo(&o));
} }
@ -641,9 +610,7 @@ TEST_F(JavaInteropTests, Test_Greater_2) {
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.getBuffer()}; Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.getBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.getShapeInfo()}; Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.getShapeInfo()};
NativeOps nativeOps; execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
ASSERT_TRUE(exp.equalsTo(&o)); ASSERT_TRUE(exp.equalsTo(&o));
} }
@ -662,9 +629,8 @@ TEST_F(JavaInteropTests, Test_Boolean_Op_1) {
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.getBuffer()}; Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.getBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.getShapeInfo()}; Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.getShapeInfo()};
NativeOps nativeOps;
auto hash = op.getOpHash(); auto hash = op.getOpHash();
auto status = nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
ASSERT_TRUE(exp.equalsTo(&o)); ASSERT_TRUE(exp.equalsTo(&o));
@ -685,9 +651,8 @@ TEST_F(JavaInteropTests, Test_Inplace_Outputs_1) {
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()}; Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()}; Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()};
NativeOps nativeOps;
auto hash = op.getOpHash(); auto hash = op.getOpHash();
auto status = nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
@ -710,9 +675,8 @@ TEST_F(JavaInteropTests, Test_Inplace_Outputs_2) {
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()}; Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()}; Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()};
NativeOps nativeOps;
auto hash = op.getOpHash(); auto hash = op.getOpHash();
auto status = nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
ASSERT_TRUE(e.isSameShape(z)); ASSERT_TRUE(e.isSameShape(z));
@ -736,9 +700,8 @@ TEST_F(JavaInteropTests, Test_Inplace_Outputs_3) {
Nd4jLong iArgs[] = {1}; Nd4jLong iArgs[] = {1};
NativeOps nativeOps;
auto hash = op.getOpHash(); auto hash = op.getOpHash();
auto status = nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 1, nullptr, 0, false); auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 1, nullptr, 0, false);
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
ASSERT_TRUE(e.isSameShape(output)); ASSERT_TRUE(e.isSameShape(output));
@ -753,8 +716,7 @@ TEST_F(JavaInteropTests, Test_Reduce3_EdgeCase) {
auto dims = NDArrayFactory::create<int>('c', {2}, {0, 1}); auto dims = NDArrayFactory::create<int>('c', {2}, {0, 1});
NativeOps nativeOps; execReduce3Tad(nullptr, 2, x.buffer(), x.shapeInfo(), nullptr, nullptr, nullptr,
nativeOps.execReduce3(nullptr, 2, x.buffer(), x.shapeInfo(), nullptr, nullptr, nullptr,
y.buffer(), y.shapeInfo(), nullptr, nullptr, y.buffer(), y.shapeInfo(), nullptr, nullptr,
z.buffer(), z.shapeInfo(), nullptr, nullptr, z.buffer(), z.shapeInfo(), nullptr, nullptr,
dims.buffer(), dims.shapeInfo(), dims.specialBuffer(), dims.specialShapeInfo(), nullptr, nullptr, nullptr, nullptr); dims.buffer(), dims.shapeInfo(), dims.specialBuffer(), dims.specialShapeInfo(), nullptr, nullptr, nullptr, nullptr);
@ -764,10 +726,8 @@ TEST_F(JavaInteropTests, Test_SimpleIf_Output) {
Environment::getInstance()->setDebug(true); Environment::getInstance()->setDebug(true);
Environment::getInstance()->setVerbose(false); Environment::getInstance()->setVerbose(false);
NativeOps ops;
auto pl = nd4j::graph::readFlatBuffers("./resources/simpleif_0_1.fb"); auto pl = nd4j::graph::readFlatBuffers("./resources/simpleif_0_1.fb");
auto ptr = ops.executeFlatGraph(nullptr, pl); auto ptr = executeFlatGraph(nullptr, pl);
Environment::getInstance()->setDebug(false); Environment::getInstance()->setDebug(false);
Environment::getInstance()->setVerbose(false); Environment::getInstance()->setVerbose(false);
@ -792,9 +752,8 @@ TEST_F(JavaInteropTests, Test_AveragePooling_FF_TF_double) {
Nd4jPointer ptrsOutShapes[] = {reinterpret_cast<Nd4jPointer>(z.shapeInfo())}; Nd4jPointer ptrsOutShapes[] = {reinterpret_cast<Nd4jPointer>(z.shapeInfo())};
Nd4jLong iArgs[] = {3,3, 3,3, 0,0, 1,1,1, 0,1}; Nd4jLong iArgs[] = {3,3, 3,3, 0,0, 1,1,1, 0,1};
NativeOps nativeOps;
auto hash = op.getOpHash(); auto hash = op.getOpHash();
auto status = nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false); auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false);
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
@ -818,9 +777,8 @@ TEST_F(JavaInteropTests, Test_MaxPool2D_float_1) {
nd4j::ops::maxpool2d op; nd4j::ops::maxpool2d op;
NativeOps nativeOps;
auto hash = op.getOpHash(); auto hash = op.getOpHash();
auto status = nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false); auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false);
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
} }
@ -843,9 +801,8 @@ TEST_F(JavaInteropTests, Test_Unstack_1) {
nd4j::ops::unstack op; nd4j::ops::unstack op;
NativeOps nativeOps;
auto hash = op.getOpHash(); auto hash = op.getOpHash();
auto status = nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 5, nullptr, 0, iArgs, 1, nullptr, 0, false); auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 5, nullptr, 0, iArgs, 1, nullptr, 0, false);
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
} }
@ -864,9 +821,8 @@ TEST_F(JavaInteropTests, Test_AveragePooling_FF_TF_float) {
Nd4jPointer ptrsOutShapes[] = {reinterpret_cast<Nd4jPointer>(z.shapeInfo())}; Nd4jPointer ptrsOutShapes[] = {reinterpret_cast<Nd4jPointer>(z.shapeInfo())};
Nd4jLong iArgs[] = {3,3, 3,3, 0,0, 1,1,1, 0,1}; Nd4jLong iArgs[] = {3,3, 3,3, 0,0, 1,1,1, 0,1};
NativeOps nativeOps;
auto hash = op.getOpHash(); auto hash = op.getOpHash();
auto status = nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false); auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false);
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
@ -883,8 +839,7 @@ TEST_F(JavaInteropTests, Test_Mixed_Add_1) {
auto arrayZ = NDArrayFactory::create<double>({0, 0, 0, 0}); auto arrayZ = NDArrayFactory::create<double>({0, 0, 0, 0});
auto arrayE = NDArrayFactory::create<double>({2, 4, 6, 8}); auto arrayE = NDArrayFactory::create<double>({2, 4, 6, 8});
NativeOps ops; execPairwiseTransform(nullptr, pairwise::Add,
ops.execPairwiseTransform(nullptr, pairwise::Add,
arrayX.buffer(), arrayX.shapeInfo(), nullptr, nullptr, arrayX.buffer(), arrayX.shapeInfo(), nullptr, nullptr,
arrayY.buffer(), arrayY.shapeInfo(), nullptr, nullptr, arrayY.buffer(), arrayY.shapeInfo(), nullptr, nullptr,
arrayZ.buffer(), arrayZ.shapeInfo(), nullptr, nullptr, arrayZ.buffer(), arrayZ.shapeInfo(), nullptr, nullptr,
@ -898,7 +853,6 @@ TEST_F(JavaInteropTests, Test_Add_1) {
auto y = NDArrayFactory::create<int>('c', {5}, {1, 1, 1, 1, 1}); auto y = NDArrayFactory::create<int>('c', {5}, {1, 1, 1, 1, 1});
auto e = NDArrayFactory::create<int>('c', {5}, {2, 2, 2, 2, 2}); auto e = NDArrayFactory::create<int>('c', {5}, {2, 2, 2, 2, 2});
NativeOps nativeOps;
nd4j::ops::add op; nd4j::ops::add op;
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), y.getBuffer()}; Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), y.getBuffer()};
@ -907,7 +861,7 @@ TEST_F(JavaInteropTests, Test_Add_1) {
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) x.getBuffer()}; Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) x.getBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) x.getShapeInfo()}; Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) x.getShapeInfo()};
nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
ASSERT_EQ(e, x); ASSERT_EQ(e, x);
} }
@ -920,7 +874,6 @@ TEST_F(JavaInteropTests, zeta_test10) {
auto e = NDArrayFactory::create<double>('c', {3, 4}, {23.014574, 12.184081, 8.275731, 6.1532226, 4.776538, 3.7945523, 3.0541048, 2.4765317, 2.0163891, 205.27448, 21.090889, 19.477398}); auto e = NDArrayFactory::create<double>('c', {3, 4}, {23.014574, 12.184081, 8.275731, 6.1532226, 4.776538, 3.7945523, 3.0541048, 2.4765317, 2.0163891, 205.27448, 21.090889, 19.477398});
NativeOps nativeOps;
nd4j::ops::zeta op; nd4j::ops::zeta op;
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), q.getBuffer()}; Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), q.getBuffer()};
@ -929,7 +882,7 @@ TEST_F(JavaInteropTests, zeta_test10) {
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()}; Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()}; Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()};
nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
ASSERT_EQ(e, z); ASSERT_EQ(e, z);
} }
@ -939,8 +892,7 @@ TEST_F(JavaInteropTests, Test_Is_Max_1) {
auto arrayZ = NDArrayFactory::create<bool>({0, 0, 0, 0}); auto arrayZ = NDArrayFactory::create<bool>({0, 0, 0, 0});
auto arrayE = NDArrayFactory::create<bool>({0, 1, 0, 0}); auto arrayE = NDArrayFactory::create<bool>({0, 1, 0, 0});
NativeOps ops; execTransformAny(nullptr, transform::IsMax,
ops.execTransformAny(nullptr, transform::IsMax,
arrayX.buffer(), arrayX.shapeInfo(), nullptr, nullptr, arrayX.buffer(), arrayX.shapeInfo(), nullptr, nullptr,
arrayZ.buffer(), arrayZ.shapeInfo(), nullptr, nullptr, arrayZ.buffer(), arrayZ.shapeInfo(), nullptr, nullptr,
nullptr); nullptr);
@ -953,8 +905,7 @@ TEST_F(JavaInteropTests, Test_Is_Max_1_2) {
auto arrayZ = NDArrayFactory::create<float>({0, 0, 0, 0}); auto arrayZ = NDArrayFactory::create<float>({0, 0, 0, 0});
auto arrayE = NDArrayFactory::create<float>({0, 1, 0, 0}); auto arrayE = NDArrayFactory::create<float>({0, 1, 0, 0});
NativeOps ops; execTransformAny(nullptr, transform::IsMax,
ops.execTransformAny(nullptr, transform::IsMax,
arrayX.buffer(), arrayX.shapeInfo(), nullptr, nullptr, arrayX.buffer(), arrayX.shapeInfo(), nullptr, nullptr,
arrayZ.buffer(), arrayZ.shapeInfo(), nullptr, nullptr, arrayZ.buffer(), arrayZ.shapeInfo(), nullptr, nullptr,
nullptr); nullptr);
@ -970,8 +921,7 @@ TEST_F(JavaInteropTests, Test_Is_Max_2) {
Nd4jLong *ex[] = {tad, off}; Nd4jLong *ex[] = {tad, off};
float ea[] = {2, 1, 2}; float ea[] = {2, 1, 2};
NativeOps ops; execTransformBool(reinterpret_cast<void **>(ex), transform::IsMax,
ops.execTransformBool(reinterpret_cast<void **>(ex), transform::IsMax,
arrayX.buffer(), arrayX.shapeInfo(), nullptr, nullptr, arrayX.buffer(), arrayX.shapeInfo(), nullptr, nullptr,
arrayZ.buffer(), arrayZ.shapeInfo(), nullptr, nullptr, arrayZ.buffer(), arrayZ.shapeInfo(), nullptr, nullptr,
ea); ea);
@ -995,8 +945,7 @@ TEST_F(JavaInteropTests, Test_Boolean_Broadcastables_1) {
nd4j::ops::greater_equal op; nd4j::ops::greater_equal op;
NativeOps ops; auto shapeList = calculateOutputShapes2(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, nullptr, 0, nullptr, 0, nullptr, 0);
auto shapeList = ops.calculateOutputShapes(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, nullptr, 0, nullptr, 0, nullptr, 0);
delete shapeList; delete shapeList;
} }
@ -1013,8 +962,7 @@ TEST_F(JavaInteropTests, Test_L2_Loss_3) {
Nd4jPointer ptrsOutShapes[] = {reinterpret_cast<Nd4jPointer>(z.shapeInfo())}; Nd4jPointer ptrsOutShapes[] = {reinterpret_cast<Nd4jPointer>(z.shapeInfo())};
nd4j::ops::l2_loss op; nd4j::ops::l2_loss op;
NativeOps ops; auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffer, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
auto status = ops.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffer, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
z.printIndexedBuffer("z"); z.printIndexedBuffer("z");
@ -1036,9 +984,8 @@ TEST_F(JavaInteropTests, Test_Fastpath_3) {
ASSERT_EQ(2, ctx.width()); ASSERT_EQ(2, ctx.width());
NativeOps nativeOps;
nd4j::ops::add op; nd4j::ops::add op;
nativeOps.execCustomOp(nullptr, op.getOpHash(), &ctx); execCustomOp2(nullptr, op.getOpHash(), &ctx);
ASSERT_EQ(exp, z); ASSERT_EQ(exp, z);
} }
@ -1054,9 +1001,8 @@ TEST_F(JavaInteropTests, Test_Fastpath_4) {
ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());
ctx.setIArguments(iArgs, 3); ctx.setIArguments(iArgs, 3);
NativeOps nativeOps;
nd4j::ops::tri op; nd4j::ops::tri op;
nativeOps.execCustomOp(nullptr, op.getOpHash(), &ctx); execCustomOp2(nullptr, op.getOpHash(), &ctx);
ASSERT_EQ(exp, z); ASSERT_EQ(exp, z);
} }
@ -1074,9 +1020,8 @@ TEST_F(JavaInteropTests, Test_Fastpath_5) {
ctx.setInputArray(1, b.buffer(), b.shapeInfo(), b.specialBuffer(), b.specialShapeInfo()); ctx.setInputArray(1, b.buffer(), b.shapeInfo(), b.specialBuffer(), b.specialShapeInfo());
ctx.setOutputArray(0, c.buffer(), c.shapeInfo(), c.specialBuffer(), c.specialShapeInfo()); ctx.setOutputArray(0, c.buffer(), c.shapeInfo(), c.specialBuffer(), c.specialShapeInfo());
NativeOps nativeOps;
nd4j::ops::matmul op; nd4j::ops::matmul op;
auto status = nativeOps.execCustomOp(nullptr, op.getOpHash(), &ctx); auto status = execCustomOp2(nullptr, op.getOpHash(), &ctx);
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
} }
@ -1104,9 +1049,8 @@ TEST_F(JavaInteropTests, Test_Fastpath_6) {
ctx.setIArguments(iArgs, 3); ctx.setIArguments(iArgs, 3);
NativeOps nativeOps;
nd4j::ops::matmul_bp op; nd4j::ops::matmul_bp op;
auto status = nativeOps.execCustomOp(nullptr, op.getOpHash(), &ctx); auto status = execCustomOp2(nullptr, op.getOpHash(), &ctx);
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
} }
@ -1122,7 +1066,6 @@ TEST_F(JavaInteropTests, Test_Fastpath_7) {
ctx.setIArguments(iArgs, 1); ctx.setIArguments(iArgs, 1);
NativeOps nativeOps;
nd4j::ops::concat op; nd4j::ops::concat op;
ctx.setInputArray(0, a.buffer(), a.shapeInfo(), a.specialBuffer(), a.specialShapeInfo()); ctx.setInputArray(0, a.buffer(), a.shapeInfo(), a.specialBuffer(), a.specialShapeInfo());
@ -1130,7 +1073,7 @@ TEST_F(JavaInteropTests, Test_Fastpath_7) {
ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());
auto status = nativeOps.execCustomOp(nullptr, op.getOpHash(), &ctx); auto status = execCustomOp2(nullptr, op.getOpHash(), &ctx);
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
ASSERT_EQ(e, z); ASSERT_EQ(e, z);
@ -1138,10 +1081,8 @@ TEST_F(JavaInteropTests, Test_Fastpath_7) {
/* /*
TEST_F(JavaInteropTests, Test_Results_Conversion_1) { TEST_F(JavaInteropTests, Test_Results_Conversion_1) {
NativeOps ops;
auto pl = nd4j::graph::readFlatBuffers("./resources/gru_dynamic_mnist.fb"); auto pl = nd4j::graph::readFlatBuffers("./resources/gru_dynamic_mnist.fb");
auto ptr = ops.executeFlatGraph(nullptr, pl); auto ptr = executeFlatGraph(nullptr, pl);
// at this point we have FlatResults // at this point we have FlatResults
auto flatResult = GetFlatResult(ptr->pointer()); auto flatResult = GetFlatResult(ptr->pointer());
@ -1190,8 +1131,6 @@ TEST_F(JavaInteropTests, Test_Results_Conversion_1) {
} }
*/ */
// TEST_F(JavaInteropTests, Test_NLP_Aggregations_1) { // TEST_F(JavaInteropTests, Test_NLP_Aggregations_1) {
// NativeOps ops;
// std::array<float, 60> syn0 = {-0.022756476f, 0.0126427775f, 0.011029151f, -0.013542821f, -0.012327666f, -0.0032439455f, -0.008405109f, -0.016651405f, 0.0015980572f, -0.007442479f, 0.019937921f, -0.016222188f, -0.016541665f, 0.013372547f, 0.006625724f, 0.0058958204f, -0.01281835f, -6.2343775E-4f, 0.0019826533f, 0.010253737f, -0.010291531f, 0.0019767822f, 0.018071089f, -0.0117441565f, 0.023176769f, 0.0032820583f, 0.0061427564f, -0.01696018f, 0.0054971874f, 0.0043818625f, 0.019323621f, 0.0036080598f, 0.024376748f, -0.0024499625f, 0.019496754f, 0.010563821f, -2.0503551E-4f, -0.0146056535f, 0.009949291f, 0.017604528f, -0.0050302492f, -0.022060446f, 0.016468976f, -0.0034482107f, 0.010270384f, -0.0063356445f, -0.019934833f, -0.02325993f, 0.016109904f, -0.0031106502f, -0.0020592287f, 0.024031803f, 0.005184144f, -0.024887865f, 0.02100272f, 3.395051E-4f, 0.018432347f, 5.673498E-4f, -0.020073576f, 0.010949242f}; // std::array<float, 60> syn0 = {-0.022756476f, 0.0126427775f, 0.011029151f, -0.013542821f, -0.012327666f, -0.0032439455f, -0.008405109f, -0.016651405f, 0.0015980572f, -0.007442479f, 0.019937921f, -0.016222188f, -0.016541665f, 0.013372547f, 0.006625724f, 0.0058958204f, -0.01281835f, -6.2343775E-4f, 0.0019826533f, 0.010253737f, -0.010291531f, 0.0019767822f, 0.018071089f, -0.0117441565f, 0.023176769f, 0.0032820583f, 0.0061427564f, -0.01696018f, 0.0054971874f, 0.0043818625f, 0.019323621f, 0.0036080598f, 0.024376748f, -0.0024499625f, 0.019496754f, 0.010563821f, -2.0503551E-4f, -0.0146056535f, 0.009949291f, 0.017604528f, -0.0050302492f, -0.022060446f, 0.016468976f, -0.0034482107f, 0.010270384f, -0.0063356445f, -0.019934833f, -0.02325993f, 0.016109904f, -0.0031106502f, -0.0020592287f, 0.024031803f, 0.005184144f, -0.024887865f, 0.02100272f, 3.395051E-4f, 0.018432347f, 5.673498E-4f, -0.020073576f, 0.010949242f};
// std::array<float, 60> syn1; // std::array<float, 60> syn1;
// std::array<float, 100000> exp; // std::array<float, 100000> exp;
@ -1283,5 +1222,5 @@ TEST_F(JavaInteropTests, Test_Results_Conversion_1) {
// ptrptr[idx+2] = reinterpret_cast<void*>(exp.data()); // ptrptr[idx+2] = reinterpret_cast<void*>(exp.data());
// ops.execAggregateBatchFloat(nullptr, numAggregates, opNum, maxArgs, maxShapes, maxIntArrays, maxIntArraySize, maxIndexArguments, maxRealArguments, pointer.data()); // execAggregateBatchFloat(nullptr, numAggregates, opNum, maxArgs, maxShapes, maxIntArrays, maxIntArraySize, maxIndexArguments, maxRealArguments, pointer.data());
// } // }

View File

@ -53,8 +53,7 @@ TEST_F(LegacyOpsCudaTests, test_sortTad_1) {
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()}; Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
x.syncToDevice(); x.syncToDevice();
NativeOps nativeOps; sortTad(extras, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), &axis, 1, packX.platformShapeInfo(), packX.platformOffsets(), false);
nativeOps.sortTad(extras, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), &axis, 1, packX.platformShapeInfo(), packX.platformOffsets(), false);
x.tickWriteDevice(); x.tickWriteDevice();
ASSERT_EQ(e, x); ASSERT_EQ(e, x);

View File

@ -501,8 +501,7 @@ TEST_F(LegacyOpsTests, Reduce3_2) {
auto dim = NDArrayFactory::create<int>('c', {1}, {1}); auto dim = NDArrayFactory::create<int>('c', {1}, {1});
NativeOps nativeOps; execReduce3Tad(nullptr, reduce3::CosineSimilarity, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(),
nativeOps.execReduce3(nullptr, reduce3::CosineSimilarity, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(),
nullptr, nullptr, nullptr, nullptr); nullptr, nullptr, nullptr, nullptr);
} }
@ -517,9 +516,8 @@ TEST_F(LegacyOpsTests, Reduce3_3) {
auto dim = NDArrayFactory::create<int>('c', {1}, {1}); auto dim = NDArrayFactory::create<int>('c', {1}, {1});
NativeOps nativeOps;
nativeOps.execReduce3(nullptr, reduce3::CosineDistance, execReduce3Tad(nullptr, reduce3::CosineDistance,
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
nullptr, nullptr,
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
@ -543,9 +541,8 @@ TEST_F(LegacyOpsTests, Reduce3_4) {
auto dim = NDArrayFactory::create<int>('c', {1}, {1}); auto dim = NDArrayFactory::create<int>('c', {1}, {1});
NativeOps nativeOps;
nativeOps.execReduce3(nullptr, reduce3::CosineDistance, execReduce3Tad(nullptr, reduce3::CosineDistance,
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
nullptr, nullptr,
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
@ -569,9 +566,8 @@ TEST_F(LegacyOpsTests, Reduce3_5) {
auto dim = NDArrayFactory::create<int>('c', {1}, {1}); auto dim = NDArrayFactory::create<int>('c', {1}, {1});
NativeOps nativeOps;
nativeOps.execReduce3(nullptr, reduce3::CosineDistance, execReduce3Tad(nullptr, reduce3::CosineDistance,
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
nullptr, nullptr,
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
@ -593,8 +589,7 @@ TEST_F(LegacyOpsTests, test_Reduce3_All_1) {
auto tadPackX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), -1); auto tadPackX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), -1);
auto tadPackY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.shapeInfo(), -1); auto tadPackY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.shapeInfo(), -1);
NativeOps ops; execReduce3All(nullptr, reduce3::EuclideanDistance, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
ops.execReduce3All(nullptr, reduce3::EuclideanDistance, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(), dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(),

View File

@ -33,8 +33,6 @@ public:
}; };
TEST_F(MmapTests, Test_Basic_Mmap_1) { TEST_F(MmapTests, Test_Basic_Mmap_1) {
NativeOps nativeOps;
// just 10GB // just 10GB
Nd4jLong size = 100000L; Nd4jLong size = 100000L;
@ -43,11 +41,11 @@ TEST_F(MmapTests, Test_Basic_Mmap_1) {
ofs.write("", 1); ofs.write("", 1);
ofs.close(); ofs.close();
auto result = nativeOps.mmapFile(nullptr, "file", size); auto result = mmapFile(nullptr, "file", size);
ASSERT_FALSE(result == nullptr); ASSERT_FALSE(result == nullptr);
nativeOps.munmapFile(nullptr, result, size); munmapFile(nullptr, result, size);
remove("file"); remove("file");
} }

View File

@ -2258,7 +2258,6 @@ TEST_F(NDArrayCudaBasicsTests, Test_Empty_4) {
TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_1) { TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_1) {
auto x = NDArrayFactory::create<float>('c', {5,2}, {0,1,2,3,4,5,6,7,8,9}); auto x = NDArrayFactory::create<float>('c', {5,2}, {0,1,2,3,4,5,6,7,8,9});
NativeOps native;
auto z = NDArrayFactory::create<float>('c', {5, 8}); auto z = NDArrayFactory::create<float>('c', {5, 8});
auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream); auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
std::vector<void*> buffers(4); std::vector<void*> buffers(4);
@ -2272,7 +2271,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_1) {
} }
Nd4jPointer extra[2]; Nd4jPointer extra[2];
extra[1] = *stream; extra[1] = *stream;
native.concat(extra, 1, 4, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr); ::concat(extra, 1, 4, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
z.syncToHost(); z.syncToHost();
z.printIndexedBuffer("Concat result"); z.printIndexedBuffer("Concat result");
z.printBuffer("C Concat result linear"); z.printBuffer("C Concat result linear");
@ -2281,7 +2280,6 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_1) {
TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_2) { TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_2) {
auto x = NDArrayFactory::create<float>('c', {5,2}, {0,1,2,3,4,5,6,7,8,9}); auto x = NDArrayFactory::create<float>('c', {5,2}, {0,1,2,3,4,5,6,7,8,9});
NativeOps native;
auto z = NDArrayFactory::create<float>('f', {5, 8}); auto z = NDArrayFactory::create<float>('f', {5, 8});
auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream); auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
std::vector<void*> buffers(4); std::vector<void*> buffers(4);
@ -2295,7 +2293,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_2) {
} }
Nd4jPointer extra[2]; Nd4jPointer extra[2];
extra[1] = *stream; extra[1] = *stream;
native.concat(extra, 1, 4, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr); ::concat(extra, 1, 4, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
z.syncToHost(); z.syncToHost();
z.printIndexedBuffer("Concat result"); z.printIndexedBuffer("Concat result");
z.printBuffer("F Concat result linear"); z.printBuffer("F Concat result linear");
@ -2304,7 +2302,6 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_2) {
TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_3) { TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_3) {
auto x = NDArrayFactory::create<float>('c', {2,3}, {1,2,3,4,5,6}); auto x = NDArrayFactory::create<float>('c', {2,3}, {1,2,3,4,5,6});
auto y = NDArrayFactory::create<float>('c', {1,3}, {7,8,9}); auto y = NDArrayFactory::create<float>('c', {1,3}, {7,8,9});
NativeOps native;
auto z = NDArrayFactory::create<float>('f', {3, 3}); auto z = NDArrayFactory::create<float>('f', {3, 3});
auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream); auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
std::vector<void*> buffers(2); std::vector<void*> buffers(2);
@ -2321,7 +2318,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_3) {
//} //}
Nd4jPointer extra[2]; Nd4jPointer extra[2];
extra[1] = *stream; extra[1] = *stream;
native.concat(extra, 0, 2, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr); ::concat(extra, 0, 2, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
z.syncToHost(); z.syncToHost();
z.printIndexedBuffer("Concat result"); z.printIndexedBuffer("Concat result");
z.printBuffer("F Concat result linear"); z.printBuffer("F Concat result linear");
@ -2331,7 +2328,6 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_3) {
TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_4) { TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_4) {
auto x = NDArrayFactory::create<float>('c', {2,3}, {1,2,3,4,5,6}); auto x = NDArrayFactory::create<float>('c', {2,3}, {1,2,3,4,5,6});
auto y = NDArrayFactory::create<float>('c', {1,3}, {7,8,9}); auto y = NDArrayFactory::create<float>('c', {1,3}, {7,8,9});
NativeOps native;
auto z = NDArrayFactory::create<float>('c', {3, 3}); auto z = NDArrayFactory::create<float>('c', {3, 3});
auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream); auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
std::vector<void*> buffers(2); std::vector<void*> buffers(2);
@ -2348,7 +2344,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_4) {
//} //}
Nd4jPointer extra[2]; Nd4jPointer extra[2];
extra[1] = *stream; extra[1] = *stream;
native.concat(extra, 0, 2, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr); ::concat(extra, 0, 2, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
z.syncToHost(); z.syncToHost();
z.printIndexedBuffer("Concat result"); z.printIndexedBuffer("Concat result");
z.printBuffer("C Concat result linear"); z.printBuffer("C Concat result linear");
@ -2358,7 +2354,6 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_4) {
TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_5) { TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_5) {
auto x = NDArrayFactory::create<float>('c', {1,2,3}, {1,2,3,4,5,6}); auto x = NDArrayFactory::create<float>('c', {1,2,3}, {1,2,3,4,5,6});
auto y = NDArrayFactory::create<float>('c', {1,2,3}, {7,8,9,10,11, 12}); auto y = NDArrayFactory::create<float>('c', {1,2,3}, {7,8,9,10,11, 12});
NativeOps native;
auto z = NDArrayFactory::create<float>('c', {2, 2, 3}); auto z = NDArrayFactory::create<float>('c', {2, 2, 3});
auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream); auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
std::vector<void*> buffers(2); std::vector<void*> buffers(2);
@ -2375,7 +2370,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_5) {
//} //}
Nd4jPointer extra[2]; Nd4jPointer extra[2];
extra[1] = *stream; extra[1] = *stream;
native.concat(extra, 0, 2, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr); ::concat(extra, 0, 2, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
z.syncToHost(); z.syncToHost();
z.printIndexedBuffer("Concat result"); z.printIndexedBuffer("Concat result");
z.printBuffer("C Concat result linear"); z.printBuffer("C Concat result linear");
@ -2385,7 +2380,6 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_6) {
auto x1 = NDArrayFactory::create<float>('c', {2,2,3}, {1,2,3,4,5,6,7,8, 9, 10,11,12}); auto x1 = NDArrayFactory::create<float>('c', {2,2,3}, {1,2,3,4,5,6,7,8, 9, 10,11,12});
auto x2 = NDArrayFactory::create<float>('c', {1,2,3}, {13,14,15,16,17, 18}); auto x2 = NDArrayFactory::create<float>('c', {1,2,3}, {13,14,15,16,17, 18});
auto x3 = NDArrayFactory::create<float>('c', {1,2,3}, {19,20,21,22,23, 24}); auto x3 = NDArrayFactory::create<float>('c', {1,2,3}, {19,20,21,22,23, 24});
NativeOps native;
auto z = NDArrayFactory::create<float>('c', {4, 2, 3}); auto z = NDArrayFactory::create<float>('c', {4, 2, 3});
auto stream = x1.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream); auto stream = x1.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
std::vector<void*> buffers(3); std::vector<void*> buffers(3);
@ -2406,7 +2400,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_6) {
printf("The third array is %p\n", buffers[2]); printf("The third array is %p\n", buffers[2]);
Nd4jPointer extra[2]; Nd4jPointer extra[2];
extra[1] = *stream; extra[1] = *stream;
native.concat(extra, 0, 3, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr); ::concat(extra, 0, 3, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
z.syncToHost(); z.syncToHost();
z.printIndexedBuffer("Concat result"); z.printIndexedBuffer("Concat result");
z.printBuffer("C Concat3D result linear"); z.printBuffer("C Concat3D result linear");
@ -2417,7 +2411,6 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_7) {
auto x1 = NDArrayFactory::create<float>(1); auto x1 = NDArrayFactory::create<float>(1);
auto x2 = NDArrayFactory::create<float>(2); auto x2 = NDArrayFactory::create<float>(2);
auto x3 = NDArrayFactory::create<float>(3); auto x3 = NDArrayFactory::create<float>(3);
NativeOps native;
auto z = NDArrayFactory::create<float>('c', {3}, {1,2,3}); auto z = NDArrayFactory::create<float>('c', {3}, {1,2,3});
auto stream = x1.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream); auto stream = x1.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
std::vector<void*> buffers(3); std::vector<void*> buffers(3);
@ -2438,7 +2431,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_7) {
printf("The third array is %p\n", buffers[2]); printf("The third array is %p\n", buffers[2]);
Nd4jPointer extra[2]; Nd4jPointer extra[2];
extra[1] = *stream; extra[1] = *stream;
native.concat(extra, 0, 3, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr); ::concat(extra, 0, 3, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
z.syncToHost(); z.syncToHost();
z.printIndexedBuffer("Concat result"); z.printIndexedBuffer("Concat result");
z.printBuffer("C Concat scalar result linear"); z.printBuffer("C Concat scalar result linear");
@ -2462,7 +2455,6 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_8) {
lx[i].assign(i); lx[i].assign(i);
} }
NativeOps native;
auto z = NDArrayFactory::create<float>('c', {totalCount, width}); auto z = NDArrayFactory::create<float>('c', {totalCount, width});
auto stream = nd4j::LaunchContext ::defaultContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream); auto stream = nd4j::LaunchContext ::defaultContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
std::vector<void*> buffers(totalCount); std::vector<void*> buffers(totalCount);
@ -2478,7 +2470,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_8) {
printf("The third array is %p\n", buffers[2]); printf("The third array is %p\n", buffers[2]);
Nd4jPointer extra[2]; Nd4jPointer extra[2];
extra[1] = *stream; extra[1] = *stream;
native.concat(extra, 0, totalCount, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr); ::concat(extra, 0, totalCount, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
z.syncToHost(); z.syncToHost();
nd4j_printf("%f %f %f\n", z.e<float>(0), z.e<float>(width * totalCount / 2), z.e<float>(width * (totalCount - 1))); nd4j_printf("%f %f %f\n", z.e<float>(0), z.e<float>(width * totalCount / 2), z.e<float>(width * (totalCount - 1)));
//z.printIndexedBuffer("Concat result"); //z.printIndexedBuffer("Concat result");
@ -2496,7 +2488,6 @@ TEST_F(NDArrayCudaBasicsTests, TestTear_1) {
arrays.emplace_back(input); arrays.emplace_back(input);
} }
auto z = NDArrayFactory::create<float>('c', {total, 10, 10}); auto z = NDArrayFactory::create<float>('c', {total, 10, 10});
NativeOps native;
auto stream = input.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream); auto stream = input.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
Nd4jPointer extra[2]; Nd4jPointer extra[2];
@ -2512,7 +2503,7 @@ TEST_F(NDArrayCudaBasicsTests, TestTear_1) {
hostShapes[i] = arrays[i].shapeInfo(); hostShapes[i] = arrays[i].shapeInfo();
} }
native.concat(extra, 0, total, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr); ::concat(extra, 0, total, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
nd4j::ops::tear op; nd4j::ops::tear op;
auto result = op.execute({&z}, {}, {1, 2}); auto result = op.execute({&z}, {}, {1, 2});
@ -2536,7 +2527,6 @@ TEST_F(NDArrayCudaBasicsTests, TestTear_2) {
arrays.emplace_back(input); arrays.emplace_back(input);
} }
auto z = NDArrayFactory::create<float>('c', {10, 10, 10}); auto z = NDArrayFactory::create<float>('c', {10, 10, 10});
NativeOps native;
auto stream = input.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream); auto stream = input.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
Nd4jPointer extra[2]; Nd4jPointer extra[2];
@ -2552,7 +2542,7 @@ TEST_F(NDArrayCudaBasicsTests, TestTear_2) {
hostShapes[i] = arrays[i].shapeInfo(); hostShapes[i] = arrays[i].shapeInfo();
} }
std::vector<int> dimsToExclude({1,2}); std::vector<int> dimsToExclude({1,2});
native.concat(extra, 0, 10, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr); ::concat(extra, 0, 10, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
// z.syncToHost(); // z.syncToHost();
// z.printBuffer("Pile OK"); // z.printBuffer("Pile OK");
// z.printIndexedBuffer("Pile 10x10"); // z.printIndexedBuffer("Pile 10x10");
@ -2569,7 +2559,7 @@ TEST_F(NDArrayCudaBasicsTests, TestTear_2) {
Nd4jPointer target = arrays[i].specialBuffer(); Nd4jPointer target = arrays[i].specialBuffer();
cudaMemcpy(&arraysData[i], &target, sizeof(Nd4jPointer), cudaMemcpyHostToDevice); cudaMemcpy(&arraysData[i], &target, sizeof(Nd4jPointer), cudaMemcpyHostToDevice);
} }
native.tear(extra, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), arraysData, input.specialShapeInfo(), packX.specialShapeInfo(), packX.specialOffsets()); ::tear(extra, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), arraysData, input.specialShapeInfo(), packX.specialShapeInfo(), packX.specialOffsets());
// auto result = op.execute({&z}, {}, {1, 2}); // auto result = op.execute({&z}, {}, {1, 2});
// nd4j_printf("Result count is %lu\n", result->size()); // nd4j_printf("Result count is %lu\n", result->size());
//ASSERT_EQ(10, result->size()); //ASSERT_EQ(10, result->size());

View File

@ -313,12 +313,10 @@ TEST_F(PlaygroundTests, test_reduce_3) {
Nd4jLong max = 0L; Nd4jLong max = 0L;
Nd4jLong min = DataTypeUtils::max<Nd4jLong>(); Nd4jLong min = DataTypeUtils::max<Nd4jLong>();
NativeOps nativeOps;
for (int e = 0; e < iterations; e++) { for (int e = 0; e < iterations; e++) {
auto timeStart = std::chrono::system_clock::now(); auto timeStart = std::chrono::system_clock::now();
nativeOps.execReduce3(nullptr, reduce3::CosineDistance, x.buffer(), x.shapeInfo(), x.specialBuffer(), execReduce3Tad(nullptr, reduce3::CosineDistance, x.buffer(), x.shapeInfo(), x.specialBuffer(),
x.specialShapeInfo(), nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(), x.specialShapeInfo(), nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(),
y.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), y.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(), nullptr, dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(), nullptr,
@ -964,8 +962,6 @@ TEST_F(PlaygroundTests, Test_Im2Col_1) {
auto legacyPermTime = std::chrono::duration_cast<std::chrono::microseconds> (legacyPermEnd - legacyPermStart).count(); auto legacyPermTime = std::chrono::duration_cast<std::chrono::microseconds> (legacyPermEnd - legacyPermStart).count();
NativeOps nativeOps;
Nd4jLong iArgs[] = {kH, kW, sH, sW, pH, pW, dH, dW, 0}; Nd4jLong iArgs[] = {kH, kW, sH, sW, pH, pW, dH, dW, 0};
Nd4jPointer inputBuffers[] = {input.buffer()}; Nd4jPointer inputBuffers[] = {input.buffer()};
Nd4jPointer inputShapes[] = {input.shapeInfo()}; Nd4jPointer inputShapes[] = {input.shapeInfo()};
@ -976,7 +972,7 @@ TEST_F(PlaygroundTests, Test_Im2Col_1) {
auto javaStart = std::chrono::system_clock::now(); auto javaStart = std::chrono::system_clock::now();
for (int e = 0; e < iterations; e++) { for (int e = 0; e < iterations; e++) {
nativeOps.execCustomOp(nullptr, op.getOpHash(), inputBuffers, inputShapes, 1, outputBuffers, outputShapes, 1, nullptr, 0, iArgs, 9, nullptr, 0, false); execCustomOp(nullptr, op.getOpHash(), inputBuffers, inputShapes, 1, outputBuffers, outputShapes, 1, nullptr, 0, iArgs, 9, nullptr, 0, false);
} }
auto javaEnd = std::chrono::system_clock::now(); auto javaEnd = std::chrono::system_clock::now();
@ -990,7 +986,7 @@ TEST_F(PlaygroundTests, Test_Im2Col_1) {
for (int e = 0; e < iterations; e++) { for (int e = 0; e < iterations; e++) {
nativeOps.execCustomOp(nullptr, op.getOpHash(), inputBuffers, inputShapes, 1, outputPermBuffers, outputPermShapes, 1, nullptr, 0, iArgs, 9, nullptr, 0, false); execCustomOp(nullptr, op.getOpHash(), inputBuffers, inputShapes, 1, outputPermBuffers, outputPermShapes, 1, nullptr, 0, iArgs, 9, nullptr, 0, false);
} }
auto javaPermEnd = std::chrono::system_clock::now(); auto javaPermEnd = std::chrono::system_clock::now();
@ -1020,9 +1016,7 @@ TEST_F(PlaygroundTests, Test_Im2Col_2) {
Nd4jPointer outputPermBuffers[] = {outputPermuted.buffer()}; Nd4jPointer outputPermBuffers[] = {outputPermuted.buffer()};
Nd4jPointer outputPermShapes[] = {outputPermuted.shapeInfo()}; Nd4jPointer outputPermShapes[] = {outputPermuted.shapeInfo()};
NativeOps nativeOps; execCustomOp(nullptr, op.getOpHash(), inputBuffers, inputShapes, 1, outputPermBuffers, outputPermShapes, 1, nullptr, 0, iArgs, 9, nullptr, 0, false);
nativeOps.execCustomOp(nullptr, op.getOpHash(), inputBuffers, inputShapes, 1, outputPermBuffers, outputPermShapes, 1, nullptr, 0, iArgs, 9, nullptr, 0, false);
} }
TEST_F(PlaygroundTests, Test_Col2Im_1) { TEST_F(PlaygroundTests, Test_Col2Im_1) {
@ -1140,8 +1134,6 @@ TEST_F(PlaygroundTests, loop_test_1) {
int length = (int) array->lengthOf(); int length = (int) array->lengthOf();
int span = (int) (array->lengthOf() / 6) + 8; int span = (int) (array->lengthOf() / 6) + 8;
NativeOps ops;
auto t = new int[1000000]; auto t = new int[1000000];
@ -1150,7 +1142,7 @@ TEST_F(PlaygroundTests, loop_test_1) {
FloatBits fb; FloatBits fb;
float threshold = 0.99f; float threshold = 0.99f;
fb.f_ = threshold; fb.f_ = threshold;
int le = ops.estimateThreshold(nullptr, reinterpret_cast<void *>(array->buffer()), array->shapeInfo(), static_cast<int>(array->lengthOf()), threshold); int le = estimateThreshold(nullptr, reinterpret_cast<void *>(array->buffer()), array->shapeInfo(), static_cast<int>(array->lengthOf()), threshold);
t[0] = le; t[0] = le;
t[1] = length; t[1] = length;
@ -1162,7 +1154,7 @@ TEST_F(PlaygroundTests, loop_test_1) {
for (int x = 0; x < iterations; x++) { for (int x = 0; x < iterations; x++) {
auto permStart = std::chrono::system_clock::now(); auto permStart = std::chrono::system_clock::now();
ops.estimateThreshold(nullptr, reinterpret_cast<void *>(array->buffer()), array->shapeInfo(), static_cast<int>(array->lengthOf()), threshold); estimateThreshold(nullptr, reinterpret_cast<void *>(array->buffer()), array->shapeInfo(), static_cast<int>(array->lengthOf()), threshold);
TypeCast::convertToThreshold<float>(nullptr, buffer, array->lengthOf(), t); TypeCast::convertToThreshold<float>(nullptr, buffer, array->lengthOf(), t);
auto permEnd = std::chrono::system_clock::now(); auto permEnd = std::chrono::system_clock::now();

View File

@ -29,7 +29,6 @@ using namespace nd4j;
class RNGTests : public testing::Test { class RNGTests : public testing::Test {
private: private:
NativeOps nativeOps;
//Nd4jLong *_bufferA; //Nd4jLong *_bufferA;
//Nd4jLong *_bufferB; //Nd4jLong *_bufferB;
@ -47,8 +46,8 @@ public:
RNGTests() { RNGTests() {
//_bufferA = new Nd4jLong[100000]; //_bufferA = new Nd4jLong[100000];
//_bufferB = new Nd4jLong[100000]; //_bufferB = new Nd4jLong[100000];
//_rngA = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, _seed, 100000, (Nd4jPointer) _bufferA); //_rngA = (nd4j::random::RandomBuffer *) initRandom(nullptr, _seed, 100000, (Nd4jPointer) _bufferA);
//_rngB = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, _seed, 100000, (Nd4jPointer) _bufferB); //_rngB = (nd4j::random::RandomBuffer *) initRandom(nullptr, _seed, 100000, (Nd4jPointer) _bufferB);
_rngA.setStates(_seed, _seed); _rngA.setStates(_seed, _seed);
_rngB.setStates(_seed, _seed); _rngB.setStates(_seed, _seed);
nexp0->assign(-1.0f); nexp0->assign(-1.0f);
@ -57,8 +56,8 @@ public:
} }
~RNGTests() { ~RNGTests() {
//nativeOps.destroyRandom(_rngA); //destroyRandom(_rngA);
//nativeOps.destroyRandom(_rngB); //destroyRandom(_rngB);
//delete[] _bufferA; //delete[] _bufferA;
//delete[] _bufferB; //delete[] _bufferB;
@ -791,14 +790,13 @@ namespace nd4j {
} }
TEST_F(RNGTests, Test_Reproducibility_9) { TEST_F(RNGTests, Test_Reproducibility_9) {
NativeOps ops;
Nd4jLong seed = 123; Nd4jLong seed = 123;
std::vector<Nd4jLong> shape = {32, 3, 28, 28}; std::vector<Nd4jLong> shape = {32, 3, 28, 28};
const int bufferSize = 10000; const int bufferSize = 10000;
int64_t buffer[bufferSize]; int64_t buffer[bufferSize];
auto rng = (nd4j::random::RandomBuffer *) ops.initRandom(nullptr, seed, bufferSize, buffer); auto rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, seed, bufferSize, buffer);
const int length = 4000000; const int length = 4000000;
int *arrayE = new int[length]; int *arrayE = new int[length];
@ -809,7 +807,7 @@ TEST_F(RNGTests, Test_Reproducibility_9) {
rng->rewindH(static_cast<Nd4jLong>(length)); rng->rewindH(static_cast<Nd4jLong>(length));
ops.refreshBuffer(nullptr, seed, reinterpret_cast<Nd4jPointer>(rng)); refreshBuffer(nullptr, seed, reinterpret_cast<Nd4jPointer>(rng));
for (int e = 0; e < length; e++) for (int e = 0; e < length; e++)
arrayT[e] = rng->relativeInt(e); arrayT[e] = rng->relativeInt(e);
@ -825,18 +823,17 @@ TEST_F(RNGTests, Test_Reproducibility_9) {
delete[] arrayE; delete[] arrayE;
delete[] arrayT; delete[] arrayT;
ops.destroyRandom(reinterpret_cast<Nd4jPointer>(rng)); destroyRandom(reinterpret_cast<Nd4jPointer>(rng));
} }
TEST_F(RNGTests, Test_Reproducibility_8) { TEST_F(RNGTests, Test_Reproducibility_8) {
NativeOps ops;
Nd4jLong seed = 123; Nd4jLong seed = 123;
std::vector<int> shape = {32, 3, 28, 28}; std::vector<int> shape = {32, 3, 28, 28};
const int bufferSize = 10000; const int bufferSize = 10000;
int64_t buffer[bufferSize]; int64_t buffer[bufferSize];
auto rng = (nd4j::random::RandomBuffer *) ops.initRandom(nullptr, seed, bufferSize, buffer); auto rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, seed, bufferSize, buffer);
const int length = 4000000; const int length = 4000000;
int *arrayE = new int[length]; int *arrayE = new int[length];
@ -847,7 +844,7 @@ TEST_F(RNGTests, Test_Reproducibility_8) {
rng->rewindH(static_cast<Nd4jLong>(length)); rng->rewindH(static_cast<Nd4jLong>(length));
ops.refreshBuffer(nullptr, seed, reinterpret_cast<Nd4jPointer>(rng)); refreshBuffer(nullptr, seed, reinterpret_cast<Nd4jPointer>(rng));
for (int e = 0; e < length; e++) for (int e = 0; e < length; e++)
arrayT[e] = static_cast<int>(rng->relativeT<float>(e)); arrayT[e] = static_cast<int>(rng->relativeT<float>(e));
@ -863,29 +860,27 @@ TEST_F(RNGTests, Test_Reproducibility_8) {
delete[] arrayE; delete[] arrayE;
delete[] arrayT; delete[] arrayT;
ops.destroyRandom(reinterpret_cast<Nd4jPointer>(rng)); destroyRandom(reinterpret_cast<Nd4jPointer>(rng));
} }
TEST_F(RNGTests, Test_RandomBuffer_Half_1) { TEST_F(RNGTests, Test_RandomBuffer_Half_1) {
NativeOps ops;
Nd4jLong seed = 123; Nd4jLong seed = 123;
std::vector<Nd4jLong> shape = {32, 3, 28, 28}; std::vector<Nd4jLong> shape = {32, 3, 28, 28};
const int bufferSize = 10000; const int bufferSize = 10000;
int64_t buffer[bufferSize]; int64_t buffer[bufferSize];
auto rng = (nd4j::random::RandomBuffer *) ops.initRandom(nullptr, seed, bufferSize, buffer); auto rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, seed, bufferSize, buffer);
auto r0 = rng->relativeT<float16>(12L); auto r0 = rng->relativeT<float16>(12L);
auto r1 = rng->relativeT<float16>(13L); auto r1 = rng->relativeT<float16>(13L);
ASSERT_NE(r0, r1); ASSERT_NE(r0, r1);
ops.destroyRandom(reinterpret_cast<Nd4jPointer>(rng)); destroyRandom(reinterpret_cast<Nd4jPointer>(rng));
} }
TEST_F(RNGTests, Test_Reproducibility_1) { TEST_F(RNGTests, Test_Reproducibility_1) {
NativeOps ops;
Nd4jLong seed = 123; Nd4jLong seed = 123;
std::vector<Nd4jLong> shape = {32, 3, 28, 28}; std::vector<Nd4jLong> shape = {32, 3, 28, 28};
@ -918,7 +913,6 @@ TEST_F(RNGTests, Test_Reproducibility_1) {
#ifndef DEBUG_BUILD #ifndef DEBUG_BUILD
TEST_F(RNGTests, Test_Reproducibility_2) { TEST_F(RNGTests, Test_Reproducibility_2) {
NativeOps ops;
Nd4jLong seed = 123; Nd4jLong seed = 123;
std::vector<Nd4jLong> shape = {32, 3, 64, 64}; std::vector<Nd4jLong> shape = {32, 3, 64, 64};

View File

@ -44,8 +44,7 @@ TEST_F(SortCpuTests, test_linear_sort_by_key_1) {
auto ev = NDArrayFactory::create<double>('c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); auto ev = NDArrayFactory::create<double>('c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5});
NativeOps nativeOps; sortByKey(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);
nativeOps.sortByKey(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);
ASSERT_EQ(ek, k); ASSERT_EQ(ek, k);
ASSERT_EQ(ev, v); ASSERT_EQ(ev, v);
@ -62,8 +61,7 @@ TEST_F(SortCpuTests, test_linear_sort_by_val_1) {
auto ev = NDArrayFactory::create<double>('c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); auto ev = NDArrayFactory::create<double>('c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5});
NativeOps nativeOps; sortByValue(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);
nativeOps.sortByValue(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);
ASSERT_EQ(ek, k); ASSERT_EQ(ek, k);
ASSERT_EQ(ev, v); ASSERT_EQ(ev, v);
@ -81,8 +79,7 @@ TEST_F(SortCpuTests, test_tad_sort_by_key_1) {
int axis = 1; int axis = 1;
NativeOps nativeOps; sortTadByKey(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false);
nativeOps.sortTadByKey(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false);
ASSERT_EQ(ek, k); ASSERT_EQ(ek, k);
ASSERT_EQ(ev, v); ASSERT_EQ(ev, v);
@ -100,8 +97,7 @@ TEST_F(SortCpuTests, test_tad_sort_by_val_1) {
int axis = 1; int axis = 1;
NativeOps nativeOps; sortTadByValue(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false);
nativeOps.sortTadByValue(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false);
ASSERT_EQ(ek, k); ASSERT_EQ(ek, k);
ASSERT_EQ(ev, v); ASSERT_EQ(ev, v);

View File

@ -42,8 +42,7 @@ TEST_F(SortCudaTests, test_linear_sort_by_key_1) {
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()}; Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
NativeOps nativeOps; sortByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);
nativeOps.sortByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);
k.tickWriteDevice(); k.tickWriteDevice();
v.tickWriteDevice(); v.tickWriteDevice();
@ -60,8 +59,7 @@ TEST_F(SortCudaTests, test_linear_sort_by_val_1) {
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()}; Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
NativeOps nativeOps; sortByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);
nativeOps.sortByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);
k.tickWriteDevice(); k.tickWriteDevice();
v.tickWriteDevice(); v.tickWriteDevice();
@ -78,8 +76,7 @@ TEST_F(SortCudaTests, test_linear_sort_by_val_2) {
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()}; Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
NativeOps nativeOps; sortByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), true);
nativeOps.sortByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), true);
k.tickWriteDevice(); k.tickWriteDevice();
v.tickWriteDevice(); v.tickWriteDevice();
k.printIndexedBuffer("KEYS"); k.printIndexedBuffer("KEYS");
@ -97,8 +94,7 @@ TEST_F(SortCudaTests, test_tad_sort_by_key_1) {
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()}; Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
int axis = 1; int axis = 1;
NativeOps nativeOps; sortTadByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false);
nativeOps.sortTadByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false);
k.tickWriteDevice(); k.tickWriteDevice();
v.tickWriteDevice(); v.tickWriteDevice();
@ -119,8 +115,7 @@ TEST_F(SortCudaTests, test_tad_sort_by_val_1) {
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()}; Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
int axis = 1; int axis = 1;
NativeOps nativeOps; sortTadByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false);
nativeOps.sortTadByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false);
k.tickWriteDevice(); k.tickWriteDevice();
v.tickWriteDevice(); v.tickWriteDevice();

View File

@ -58,8 +58,7 @@ TEST_F(TypeCastTests, Test_ConvertDtype_1) {
float16 dst[5]; float16 dst[5];
float16 exp[] = {(float16) 1.0f, (float16) 2.0f, (float16) 3.0f, (float16) 4.0f, (float16) 5.0f}; float16 exp[] = {(float16) 1.0f, (float16) 2.0f, (float16) 3.0f, (float16) 4.0f, (float16) 5.0f};
NativeOps ops; convertTypes(nullptr, ND4J_FLOAT32, src, 5, ND4J_FLOAT16, dst);
ops.convertTypes(nullptr, ND4J_FLOAT32, src, 5, ND4J_FLOAT16, dst);
for (int e = 0; e < 5; e++) for (int e = 0; e < 5; e++)
ASSERT_NEAR(exp[e], dst[e], (float16) 0.01f); ASSERT_NEAR(exp[e], dst[e], (float16) 0.01f);