Refactor NativeOps.h to export C functions
parent
fad8da878f
commit
dcc72e23b2
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -27,11 +27,10 @@ namespace nd4j {
|
|||
ProviderRNG::ProviderRNG() {
|
||||
|
||||
Nd4jLong *buffer = new Nd4jLong[100000];
|
||||
NativeOps nativeOps;
|
||||
std::lock_guard<std::mutex> lock(_mutex);
|
||||
#ifndef __CUDABLAS__
|
||||
// at this moment we don't have streams etc, so let's just skip this for now
|
||||
_rng = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, 123, 100000, (Nd4jPointer) buffer);
|
||||
_rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, 123, 100000, (Nd4jPointer) buffer);
|
||||
#endif
|
||||
// if(_rng != nullptr)
|
||||
}
|
||||
|
|
|
@ -41,8 +41,7 @@ namespace nd4j {
|
|||
}
|
||||
|
||||
// FIXME: this approach isn't really good for cuda, since it'll assume that CUDA might get nullptr instead of stream
|
||||
NativeOps nativeOps;
|
||||
nativeOps.refreshBuffer(nullptr, seed, (Nd4jPointer) rng);
|
||||
refreshBuffer(nullptr, seed, (Nd4jPointer) rng);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -110,11 +110,9 @@ namespace helpers {
|
|||
indices->syncToDevice(); // linspace only on CPU, so sync to Device as well
|
||||
|
||||
NDArray scores(*scales);
|
||||
NativeOps nativeOps;
|
||||
|
||||
Nd4jPointer extras[2] = {nullptr, stream};
|
||||
|
||||
nativeOps.sortByValue(extras, indices->buffer(), indices->shapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), scores.buffer(), scores.shapeInfo(), scores.specialBuffer(), scores.specialShapeInfo(), true);
|
||||
sortByValue(extras, indices->buffer(), indices->shapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), scores.buffer(), scores.shapeInfo(), scores.specialBuffer(), scores.specialShapeInfo(), true);
|
||||
// TO DO: sort indices using scales as value row
|
||||
//std::sort(indices.begin(), indices.end(), [scales](int i, int j) {return scales->e<T>(i) > scales->e<T>(j);});
|
||||
I* indexBuf = reinterpret_cast<I*>(indices->specialBuffer());
|
||||
|
|
|
@ -60,8 +60,7 @@ namespace helpers {
|
|||
params[1] = context->getCudaStream();
|
||||
|
||||
if (input->isVector()) {
|
||||
NativeOps ops;
|
||||
ops.sort(params, nullptr, sortedVals.shapeInfo(), sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), reverse);
|
||||
sort(params, nullptr, sortedVals.shapeInfo(), sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), reverse);
|
||||
|
||||
cudaMemcpy(reinterpret_cast<T*>(output->specialBuffer()), reinterpret_cast<T*>(sortedVals.specialBuffer()) + n, sizeof(T), cudaMemcpyDeviceToDevice);
|
||||
}
|
||||
|
@ -74,8 +73,7 @@ namespace helpers {
|
|||
auto pTadShapeH = packX.primaryShapeInfo();
|
||||
auto pTadOffsets = packX.specialOffsets();
|
||||
// auto pLastDimData = (int*) manager.replicatePointer(lastDims.data(), lastDims.size() * sizeof(int));
|
||||
NativeOps ops;
|
||||
ops.sortTad(params, sortedVals.buffer(), sortedVals.shapeInfo(), sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), lastDims.data(), lastDims.size(), pTadShape, pTadOffsets, reverse);
|
||||
sortTad(params, sortedVals.buffer(), sortedVals.shapeInfo(), sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), lastDims.data(), lastDims.size(), pTadShape, pTadOffsets, reverse);
|
||||
// manager.synchronize();
|
||||
sortedVals.tickWriteDevice();
|
||||
sortedVals.syncToHost();
|
||||
|
|
|
@ -38,32 +38,28 @@ TEST_F(HeaderTest, test_dataTypes_1) {
|
|||
std::string header("0NUMPY6789{'descr': '>f4");
|
||||
|
||||
|
||||
NativeOps nativeOps;
|
||||
ASSERT_EQ(nd4j::DataType::FLOAT32, nativeOps.dataTypeFromNpyHeader(const_cast<char *>(header.data())));
|
||||
ASSERT_EQ(nd4j::DataType::FLOAT32, dataTypeFromNpyHeader(const_cast<char *>(header.data())));
|
||||
}
|
||||
|
||||
TEST_F(HeaderTest, test_dataTypes_2) {
|
||||
std::string header("0NUMPY6789{'descr': '>f8");
|
||||
|
||||
|
||||
NativeOps nativeOps;
|
||||
ASSERT_EQ(nd4j::DataType::DOUBLE, nativeOps.dataTypeFromNpyHeader(const_cast<char *>(header.data())));
|
||||
ASSERT_EQ(nd4j::DataType::DOUBLE, dataTypeFromNpyHeader(const_cast<char *>(header.data())));
|
||||
}
|
||||
|
||||
TEST_F(HeaderTest, test_dataTypes_3) {
|
||||
std::string header("0NUMPY6789{'descr': '<i4");
|
||||
|
||||
|
||||
NativeOps nativeOps;
|
||||
ASSERT_EQ(nd4j::DataType::INT32, nativeOps.dataTypeFromNpyHeader(const_cast<char *>(header.data())));
|
||||
ASSERT_EQ(nd4j::DataType::INT32, dataTypeFromNpyHeader(const_cast<char *>(header.data())));
|
||||
}
|
||||
|
||||
TEST_F(HeaderTest, test_dataTypes_4) {
|
||||
std::string header("0NUMPY6789{'descr': '>u2");
|
||||
|
||||
|
||||
NativeOps nativeOps;
|
||||
ASSERT_EQ(nd4j::DataType::UINT16, nativeOps.dataTypeFromNpyHeader(const_cast<char *>(header.data())));
|
||||
ASSERT_EQ(nd4j::DataType::UINT16, dataTypeFromNpyHeader(const_cast<char *>(header.data())));
|
||||
}
|
||||
|
||||
/*
|
||||
|
@ -88,8 +84,7 @@ TEST_F(LoadFromStringTest,PathTest) {
|
|||
ASSERT_EQ(4.0,data[3]);
|
||||
Nd4jPointer pointer = reinterpret_cast<Nd4jPointer >(&loadedArr);
|
||||
int *shapeBuffer = shape::shapeBufferOfNpy(loadedArr);
|
||||
NativeOps nativeOps;
|
||||
Nd4jPointer pointer1 = nativeOps.dataPointForNumpy(loaded);
|
||||
Nd4jPointer pointer1 = dataPointForNumpy(loaded);
|
||||
delete[] shapeBuffer;
|
||||
|
||||
double *data2 = reinterpret_cast<double *>(pointer1);
|
||||
|
|
|
@ -472,9 +472,7 @@ TEST_F(DeclarableOpsTests1, TestRng1) {
|
|||
/*
|
||||
Nd4jLong *buffer = new Nd4jLong[100000];
|
||||
|
||||
NativeOps nativeOps;
|
||||
|
||||
nd4j::random::RandomBuffer *rng = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, 123, 100000, (Nd4jPointer) buffer);
|
||||
nd4j::random::RandomBuffer *rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, 123, 100000, (Nd4jPointer) buffer);
|
||||
|
||||
if (rng == nullptr)
|
||||
throw std::runtime_error("RNG initialization failed");
|
||||
|
@ -496,7 +494,7 @@ TEST_F(DeclarableOpsTests1, TestRng1) {
|
|||
|
||||
ASSERT_TRUE(x->sumNumber() > 0.0);
|
||||
|
||||
nativeOps.destroyRandom((Nd4jPointer) rng);
|
||||
destroyRandom((Nd4jPointer) rng);
|
||||
delete[] buffer;
|
||||
|
||||
delete variableSpace;
|
||||
|
@ -1450,8 +1448,6 @@ TEST_F(DeclarableOpsTests1, TestRegistrator1) {
|
|||
|
||||
// //////////////////////////////////////////////////////////////////////
|
||||
// TEST_F(DeclarableOpsTests1, TestLegacyExecution1) {
|
||||
// NativeOps nativeOps;
|
||||
|
||||
// auto x = NDArrayFactory::create_<float>('c', {10, 10});
|
||||
// x->assign(1.0f);
|
||||
|
||||
|
@ -1483,8 +1479,8 @@ TEST_F(DeclarableOpsTests1, TestRegistrator1) {
|
|||
// outputShapes[0] = (Nd4jPointer) z->getShapeInfo();
|
||||
|
||||
|
||||
// //auto status = nativeOps.execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, false);
|
||||
// auto status = nativeOps.execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
||||
// //auto status = execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, false);
|
||||
// auto status = execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
||||
// ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||
// // z->printIndexedBuffer("Output add");
|
||||
// ASSERT_NEAR(2.0f, y->meanNumber().e<float>(0), 1e-5);
|
||||
|
@ -1503,8 +1499,6 @@ TEST_F(DeclarableOpsTests1, TestRegistrator1) {
|
|||
|
||||
// //////////////////////////////////////////////////////////////////////
|
||||
// TEST_F(DeclarableOpsTests1, TestLegacyExecution2) {
|
||||
// NativeOps nativeOps;
|
||||
|
||||
// auto x = NDArrayFactory::create_<float>('c', {10, 10});
|
||||
// x->assign(1.0f);
|
||||
|
||||
|
@ -1532,7 +1526,7 @@ TEST_F(DeclarableOpsTests1, TestRegistrator1) {
|
|||
// auto outputBuffers = new Nd4jPointer[1];
|
||||
// auto outputShapes = new Nd4jPointer[1];
|
||||
|
||||
// nativeOps.execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, true);
|
||||
// execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, true);
|
||||
|
||||
// ASSERT_NEAR(2.0, y->meanNumber().e<float>(0), 1e-5);
|
||||
// ASSERT_NEAR(3.0, x->meanNumber().e<float>(0), 1e-5);
|
||||
|
|
|
@ -876,14 +876,13 @@ TEST_F(DeclarableOpsTests12, pullRows_1) {
|
|||
auto xTadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.getShapeInfo(), dims);
|
||||
auto zTadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(z.getShapeInfo(), dims);
|
||||
|
||||
NativeOps op;
|
||||
Nd4jPointer nativeStart[2];
|
||||
|
||||
#ifdef __CUDABLAS__
|
||||
nativeStart[1] = *(x.getContext()->getCudaStream());
|
||||
#endif
|
||||
|
||||
op.pullRows(nativeStart, x.buffer(), x.getShapeInfo(), x.getSpecialBuffer(), x.getSpecialShapeInfo(),
|
||||
pullRows(nativeStart, x.buffer(), x.getShapeInfo(), x.getSpecialBuffer(), x.getSpecialShapeInfo(),
|
||||
z.buffer(), z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
|
||||
4, pidx,
|
||||
xTadPack.platformShapeInfo(), xTadPack.platformOffsets(),
|
||||
|
@ -912,12 +911,11 @@ TEST_F(DeclarableOpsTests12, pullRows_2) {
|
|||
auto xTadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.getShapeInfo(), dims);
|
||||
auto zTadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(z.getShapeInfo(), dims);
|
||||
|
||||
NativeOps op;
|
||||
Nd4jPointer nativeStart[2];
|
||||
#ifdef __CUDABLAS__
|
||||
nativeStart[1] = *(x.getContext()->getCudaStream());
|
||||
#endif
|
||||
op.pullRows(nativeStart, x.buffer(), x.getShapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
||||
pullRows(nativeStart, x.buffer(), x.getShapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
||||
z.buffer(), z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
|
||||
4, pidx,
|
||||
xTadPack.platformShapeInfo(), xTadPack.platformOffsets(),
|
||||
|
|
|
@ -110,8 +110,7 @@ TEST_F(DeclarableOpsTests9, exponentialDistributionInv_test1) {
|
|||
double extraParams[] = {lambda};
|
||||
|
||||
Nd4jLong *buffer = new Nd4jLong[N];
|
||||
NativeOps nativeOps;
|
||||
auto rng = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
|
||||
auto rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
|
||||
if (rng == nullptr)
|
||||
throw std::runtime_error("DeclarableOpsTests9.exponentialDistributionInv_test1: RNG initialization failed !");
|
||||
|
||||
|
@ -122,7 +121,7 @@ TEST_F(DeclarableOpsTests9, exponentialDistributionInv_test1) {
|
|||
ASSERT_NEAR(mean, actualMean, 0.01);
|
||||
ASSERT_NEAR(std, actualStd, 0.01);
|
||||
|
||||
nativeOps.destroyRandom((Nd4jPointer) rng);
|
||||
destroyRandom((Nd4jPointer) rng);
|
||||
delete[] buffer;
|
||||
|
||||
}
|
||||
|
@ -142,8 +141,7 @@ TEST_F(DeclarableOpsTests9, exponentialDistributionInv_test2) {
|
|||
|
||||
|
||||
Nd4jLong *buffer = new Nd4jLong[N];
|
||||
NativeOps nativeOps;
|
||||
auto rng = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
|
||||
auto rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
|
||||
if (rng == nullptr)
|
||||
throw std::runtime_error("DeclarableOpsTests9.exponentialDistributionInv_test2: RNG initialization failed !");
|
||||
|
||||
|
@ -155,7 +153,7 @@ TEST_F(DeclarableOpsTests9, exponentialDistributionInv_test2) {
|
|||
ASSERT_NEAR(mean, actualMean, 0.01);
|
||||
ASSERT_NEAR(std, actualStd, 0.01);
|
||||
|
||||
nativeOps.destroyRandom((Nd4jPointer) rng);
|
||||
destroyRandom((Nd4jPointer) rng);
|
||||
delete[] buffer;
|
||||
|
||||
}
|
||||
|
@ -172,8 +170,7 @@ TEST_F(DeclarableOpsTests9, exponentialDistribution_test1) {
|
|||
double extraParams[] = {lambda};
|
||||
|
||||
Nd4jLong *buffer = new Nd4jLong[N];
|
||||
NativeOps nativeOps;
|
||||
auto rng = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
|
||||
auto rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
|
||||
if (rng == nullptr)
|
||||
throw std::runtime_error("DeclarableOpsTests9.exponentialDistribution_test1: RNG initialization failed !");
|
||||
|
||||
|
@ -184,7 +181,7 @@ TEST_F(DeclarableOpsTests9, exponentialDistribution_test1) {
|
|||
ASSERT_NEAR(mean, actualMean, 0.01);
|
||||
ASSERT_NEAR(std, actualStd, 0.01);
|
||||
|
||||
nativeOps.destroyRandom((Nd4jPointer) rng);
|
||||
destroyRandom((Nd4jPointer) rng);
|
||||
delete[] buffer;
|
||||
}
|
||||
*/
|
||||
|
@ -206,14 +203,13 @@ TEST_F(DeclarableOpsTests9, exponentialDistribution_test2) {
|
|||
Nd4jLong *buffer = new Nd4jLong[N];
|
||||
// Nd4jPointer extra[2];
|
||||
#ifndef __CUDABLAS__
|
||||
NativeOps nativeOps;
|
||||
nd4j::random::RandomBuffer* rng = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
|
||||
nd4j::random::RandomBuffer* rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
|
||||
if (rng == nullptr)
|
||||
throw std::runtime_error("DeclarableOpsTests9.exponentialDistribution_test2: RNG initialization failed !");
|
||||
|
||||
functions::random::RandomFunction<double>::template execTransform<randomOps::ExponentialDistribution<double>>(rng, y.getBuffer(), y.getShapeInfo(), x.getBuffer(), x.getShapeInfo(), extraParams);
|
||||
|
||||
nativeOps.destroyRandom((Nd4jPointer) rng);
|
||||
destroyRandom((Nd4jPointer) rng);
|
||||
#endif
|
||||
const double actualMean = x.meanNumber().e<double>(0);
|
||||
const double actualStd = x.varianceNumber(variance::SummaryStatsStandardDeviation, true).e<double>(0);
|
||||
|
@ -1005,12 +1001,10 @@ TEST_F(DeclarableOpsTests9, Test_DropoutInverted_01) {
|
|||
x0.linspace(1);
|
||||
x1.linspace(1);
|
||||
/*
|
||||
NativeOps nativeOps;
|
||||
|
||||
float prob[] = {0.5f};
|
||||
Nd4jLong* _bufferA = new Nd4jLong[100000];
|
||||
long _seed = 119L;
|
||||
auto _rngA = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, _seed, 100000, (Nd4jPointer) _bufferA);
|
||||
auto _rngA = (nd4j::random::RandomBuffer *) initRandom(nullptr, _seed, 100000, (Nd4jPointer) _bufferA);
|
||||
|
||||
x0. applyTransform(random::DropOutInverted, &x0, prob);
|
||||
// x1.template applyRandom<randomOps::DropOutInverted<float>>(_rngB, nullptr, &x1, prob);
|
||||
|
@ -1026,7 +1020,7 @@ TEST_F(DeclarableOpsTests9, Test_DropoutInverted_01) {
|
|||
// ASSERT_FALSE(x0.equalsTo(nexp0));
|
||||
// ASSERT_FALSE(x0.equalsTo(nexp1));
|
||||
// ASSERT_FALSE(x0.equalsTo(nexp2));
|
||||
nativeOps.destroyRandom(_rngA);
|
||||
destroyRandom(_rngA);
|
||||
delete [] _bufferA;
|
||||
*/
|
||||
nd4j::ops::dropout op;
|
||||
|
|
|
@ -51,9 +51,7 @@ public:
|
|||
*/
|
||||
|
||||
TEST_F(GraphStateTests, Basic_Tests_1) {
|
||||
NativeOps nativeOps;
|
||||
|
||||
auto state = (GraphState *) nativeOps.getGraphState(117L);
|
||||
auto state = (GraphState *) getGraphState(117L);
|
||||
ASSERT_EQ(117L, state->id());
|
||||
|
||||
// this call will create scope internally
|
||||
|
@ -72,14 +70,12 @@ TEST_F(GraphStateTests, Basic_Tests_1) {
|
|||
ASSERT_TRUE(scope != nullptr);
|
||||
ASSERT_EQ(2, scope->size());
|
||||
|
||||
nativeOps.deleteGraphState(state);
|
||||
deleteGraphState(state);
|
||||
}
|
||||
|
||||
// just separate case for doubles wrapper in NativeOps, nothing else
|
||||
TEST_F(GraphStateTests, Basic_Tests_2) {
|
||||
NativeOps nativeOps;
|
||||
|
||||
auto state = (GraphState *) nativeOps.getGraphState(117L);
|
||||
auto state = (GraphState *) getGraphState(117L);
|
||||
ASSERT_EQ(117L, state->id());
|
||||
|
||||
// this call will create scope internally
|
||||
|
@ -98,46 +94,40 @@ TEST_F(GraphStateTests, Basic_Tests_2) {
|
|||
ASSERT_TRUE(scope != nullptr);
|
||||
ASSERT_EQ(2, scope->size());
|
||||
|
||||
nativeOps.deleteGraphState(state);
|
||||
deleteGraphState(state);
|
||||
}
|
||||
|
||||
TEST_F(GraphStateTests, Stateful_Execution_1) {
|
||||
NativeOps nativeOps;
|
||||
|
||||
auto state = nativeOps.getGraphState(117L);
|
||||
auto state = getGraphState(117L);
|
||||
|
||||
Nd4jLong scopes[] = {22, 33};
|
||||
//auto status = nativeOps.execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0);
|
||||
auto status = nativeOps.execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0);
|
||||
//auto status = execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0);
|
||||
auto status = execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0);
|
||||
|
||||
ASSERT_EQ(Status::THROW(), status);
|
||||
|
||||
nativeOps.deleteGraphState(state);
|
||||
deleteGraphState(state);
|
||||
}
|
||||
|
||||
TEST_F(GraphStateTests, Stateful_Execution_2) {
|
||||
NativeOps nativeOps;
|
||||
|
||||
auto state = (GraphState *) nativeOps.getGraphState(117L);
|
||||
auto state = (GraphState *) getGraphState(117L);
|
||||
|
||||
state->registerScope(22);
|
||||
state->registerScope(33);
|
||||
|
||||
Nd4jLong scopes[] = {22, 33};
|
||||
auto status = nativeOps.execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0);
|
||||
auto status = execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0);
|
||||
|
||||
// it's no-op: just LogicScope
|
||||
ASSERT_EQ(Status::OK(), status);
|
||||
|
||||
nativeOps.deleteGraphState(state);
|
||||
deleteGraphState(state);
|
||||
}
|
||||
|
||||
/**
|
||||
* This test checks WHILE loop
|
||||
*/
|
||||
TEST_F(GraphStateTests, Stateful_Execution_3) {
|
||||
NativeOps nativeOps;
|
||||
|
||||
auto var0 = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
|
||||
auto var1 = NDArrayFactory::create<float>(11.0f);
|
||||
auto var2 = NDArrayFactory::create<float>(2.0f);
|
||||
|
@ -147,7 +137,7 @@ TEST_F(GraphStateTests, Stateful_Execution_3) {
|
|||
auto res2 = NDArrayFactory::create<float>(0.0f);
|
||||
|
||||
// registering our GraphState holder
|
||||
auto state = (GraphState *) nativeOps.getGraphState(117L);
|
||||
auto state = (GraphState *) getGraphState(117L);
|
||||
|
||||
// we're prepping pointers to input/output buffers
|
||||
Nd4jPointer ptrBuffers[] = {(Nd4jPointer) var0.buffer(), (Nd4jPointer) var1.buffer(), (Nd4jPointer)var2.buffer()};
|
||||
|
@ -197,7 +187,7 @@ TEST_F(GraphStateTests, Stateful_Execution_3) {
|
|||
Nd4jLong scopes[] = {22, 33};
|
||||
|
||||
// we're executing while loop
|
||||
auto status = nativeOps.execCustomOpWithScope(nullptr, state, 0, scopes, 2, ptrBuffers, ptrShapes, 3, outBuffers, outShapes, 3);
|
||||
auto status = execCustomOpWithScope(nullptr, state, 0, scopes, 2, ptrBuffers, ptrShapes, 3, outBuffers, outShapes, 3);
|
||||
ASSERT_EQ(Status::OK(), status);
|
||||
|
||||
// now we check provided result array
|
||||
|
@ -211,7 +201,7 @@ TEST_F(GraphStateTests, Stateful_Execution_3) {
|
|||
|
||||
// nd4j_printf("0 ------------------\n","");
|
||||
|
||||
nativeOps.deleteGraphState(state);
|
||||
deleteGraphState(state);
|
||||
|
||||
// nd4j_printf("1 ------------------\n","");
|
||||
}
|
||||
|
@ -220,8 +210,6 @@ TEST_F(GraphStateTests, Stateful_Execution_3) {
|
|||
* This test checks CONDITIONAL execution for FALSE
|
||||
*/
|
||||
TEST_F(GraphStateTests, Stateful_Execution_4) {
|
||||
NativeOps nativeOps;
|
||||
|
||||
auto var0 = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
|
||||
auto var1 = NDArrayFactory::create<float>(5.0f);
|
||||
|
||||
|
@ -232,7 +220,7 @@ TEST_F(GraphStateTests, Stateful_Execution_4) {
|
|||
|
||||
|
||||
// registering our GraphState holder
|
||||
auto state = (GraphState *) nativeOps.getGraphState(117L);
|
||||
auto state = (GraphState *) getGraphState(117L);
|
||||
|
||||
// we're prepping pointers to input/output buffers
|
||||
Nd4jPointer ptrBuffers[] = {(Nd4jPointer) var0.buffer(), (Nd4jPointer) var1.buffer()};
|
||||
|
@ -283,14 +271,14 @@ TEST_F(GraphStateTests, Stateful_Execution_4) {
|
|||
Nd4jLong scopes[] = {22, 33, 44};
|
||||
|
||||
// we're executing conditional op
|
||||
auto status = nativeOps.execCustomOpWithScope(nullptr, state, 20, scopes, 3, ptrBuffers, ptrShapes, 2, outBuffers, outShapes, 2);
|
||||
auto status = execCustomOpWithScope(nullptr, state, 20, scopes, 3, ptrBuffers, ptrShapes, 2, outBuffers, outShapes, 2);
|
||||
ASSERT_EQ(Status::OK(), status);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(&res0));
|
||||
ASSERT_TRUE(exp.equalsTo(&res0));
|
||||
|
||||
|
||||
nativeOps.deleteGraphState(state);
|
||||
deleteGraphState(state);
|
||||
}
|
||||
|
||||
|
||||
|
@ -298,8 +286,6 @@ TEST_F(GraphStateTests, Stateful_Execution_4) {
|
|||
* This test checks CONDITIONAL execution for TRUE
|
||||
*/
|
||||
TEST_F(GraphStateTests, Stateful_Execution_5) {
|
||||
NativeOps nativeOps;
|
||||
|
||||
auto var0 = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
|
||||
auto var1 = NDArrayFactory::create<float>(5.0f);
|
||||
|
||||
|
@ -310,7 +296,7 @@ TEST_F(GraphStateTests, Stateful_Execution_5) {
|
|||
|
||||
|
||||
// registering our GraphState holder
|
||||
auto state = (GraphState *) nativeOps.getGraphState(117L);
|
||||
auto state = (GraphState *) getGraphState(117L);
|
||||
|
||||
// we're prepping pointers to input/output buffers
|
||||
Nd4jPointer ptrBuffers[] = {(Nd4jPointer) var0.buffer(), (Nd4jPointer) var1.buffer()};
|
||||
|
@ -361,12 +347,11 @@ TEST_F(GraphStateTests, Stateful_Execution_5) {
|
|||
Nd4jLong scopes[] = {22, 33, 44};
|
||||
|
||||
// we're executing conditional op
|
||||
auto status = nativeOps.execCustomOpWithScope(nullptr, state, 20, scopes, 3, ptrBuffers, ptrShapes, 2, outBuffers, outShapes, 2);
|
||||
auto status = execCustomOpWithScope(nullptr, state, 20, scopes, 3, ptrBuffers, ptrShapes, 2, outBuffers, outShapes, 2);
|
||||
ASSERT_EQ(Status::OK(), status);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(&res0));
|
||||
ASSERT_TRUE(exp.equalsTo(&res0));
|
||||
|
||||
|
||||
nativeOps.deleteGraphState(state);
|
||||
deleteGraphState(state);
|
||||
}
|
||||
|
|
|
@ -42,7 +42,6 @@ TEST_F(JavaInteropCudaTests, test_DeclarableOp_execution_1) {
|
|||
e.assign(2.f);
|
||||
|
||||
nd4j::ops::add op;
|
||||
NativeOps nativeOps;
|
||||
Context context(1);
|
||||
|
||||
context.setCudaContext(LaunchContext::defaultContext()->getCudaStream(), LaunchContext::defaultContext()->getReductionPointer(), LaunchContext::defaultContext()->getAllocationPointer());
|
||||
|
@ -53,7 +52,7 @@ TEST_F(JavaInteropCudaTests, test_DeclarableOp_execution_1) {
|
|||
|
||||
nd4j_printf("Starting execution...\n","");
|
||||
PointersManager pm(LaunchContext::defaultContext(), "test_DeclarableOp_execution_1");
|
||||
nativeOps.execCustomOp(nullptr, op.getOpHash(), &context);
|
||||
execCustomOp2(nullptr, op.getOpHash(), &context);
|
||||
|
||||
pm.synchronize();
|
||||
|
||||
|
@ -71,7 +70,6 @@ TEST_F(JavaInteropCudaTests, test_DeclarableOp_execution_2) {
|
|||
e.assign(false);
|
||||
|
||||
nd4j::ops::equals op;
|
||||
NativeOps nativeOps;
|
||||
Context context(1);
|
||||
|
||||
context.setCudaContext(LaunchContext::defaultContext()->getCudaStream(), LaunchContext::defaultContext()->getReductionPointer(), LaunchContext::defaultContext()->getAllocationPointer());
|
||||
|
@ -82,7 +80,7 @@ TEST_F(JavaInteropCudaTests, test_DeclarableOp_execution_2) {
|
|||
|
||||
nd4j_printf("Starting execution...\n","");
|
||||
PointersManager pm(LaunchContext::defaultContext(), "test_DeclarableOp_execution_2");
|
||||
nativeOps.execCustomOp(nullptr, op.getOpHash(), &context);
|
||||
execCustomOp2(nullptr, op.getOpHash(), &context);
|
||||
|
||||
pm.synchronize();
|
||||
|
||||
|
|
|
@ -41,8 +41,6 @@ TEST_F(JavaInteropTests, TestShapeExposure1) {
|
|||
auto weights = NDArrayFactory::create<float>('c', {2, 2, 2, 3});
|
||||
auto exp = NDArrayFactory::create<float>('c', {1, 3, 5, 4});
|
||||
|
||||
NativeOps nativeOps;
|
||||
|
||||
nd4j::ops::conv2d op;
|
||||
|
||||
std::vector<double> tArgs({});
|
||||
|
@ -50,7 +48,7 @@ TEST_F(JavaInteropTests, TestShapeExposure1) {
|
|||
|
||||
Nd4jPointer ptrs[] = {(Nd4jPointer) input.getShapeInfo(), (Nd4jPointer) weights.getShapeInfo()};
|
||||
|
||||
auto shapeList = nativeOps.calculateOutputShapes(nullptr, op.getOpHash(), ptrs, 2, tArgs.data(), tArgs.size(), iArgs.data(), iArgs.size());
|
||||
auto shapeList = calculateOutputShapes(nullptr, op.getOpHash(), ptrs, 2, tArgs.data(), tArgs.size(), iArgs.data(), iArgs.size());
|
||||
|
||||
ASSERT_EQ(1, shapeList->size());
|
||||
|
||||
|
@ -64,7 +62,7 @@ TEST_F(JavaInteropTests, TestShapeExposure1) {
|
|||
//delete[] ptr;
|
||||
//delete shapeList;
|
||||
|
||||
nativeOps.deleteShapeList((Nd4jPointer) shapeList);
|
||||
deleteShapeList((Nd4jPointer) shapeList);
|
||||
}
|
||||
|
||||
|
||||
|
@ -72,9 +70,6 @@ TEST_F(JavaInteropTests, TestShapeExposure2) {
|
|||
auto input = NDArrayFactory::create<float>('c', {1, 2, 5, 4});
|
||||
auto exp = NDArrayFactory::create<float>('c', {4}, {1, 2, 5, 4});
|
||||
|
||||
|
||||
NativeOps nativeOps;
|
||||
|
||||
nd4j::ops::shape_of op;
|
||||
|
||||
std::vector<double> tArgs({});
|
||||
|
@ -83,14 +78,14 @@ TEST_F(JavaInteropTests, TestShapeExposure2) {
|
|||
|
||||
Nd4jPointer ptrs[] = {(Nd4jPointer) input.getShapeInfo()};
|
||||
|
||||
auto shapeList = nativeOps.calculateOutputShapes(nullptr, op.getOpHash(), ptrs, 1, tArgs.data(), tArgs.size(), iArgs.data(), iArgs.size());
|
||||
auto shapeList = calculateOutputShapes(nullptr, op.getOpHash(), ptrs, 1, tArgs.data(), tArgs.size(), iArgs.data(), iArgs.size());
|
||||
|
||||
ASSERT_EQ(1, shapeList->size());
|
||||
|
||||
ASSERT_EQ(exp.rankOf(), shape::rank((Nd4jLong *)shapeList->at(0)));
|
||||
ASSERT_EQ(exp.sizeAt(0), shape::shapeOf((Nd4jLong *)shapeList->at(0))[0]);
|
||||
|
||||
nativeOps.deleteShapeList((Nd4jPointer) shapeList);
|
||||
deleteShapeList((Nd4jPointer) shapeList);
|
||||
}
|
||||
|
||||
TEST_F(JavaInteropTests, TestShapeExposure3) {
|
||||
|
@ -112,13 +107,12 @@ TEST_F(JavaInteropTests, TestShapeExposure3) {
|
|||
Nd4jPointer inputBuffers[] = {x.buffer(), sizes.buffer()};
|
||||
Nd4jPointer inputShapes[] = {x.shapeInfo(), sizes.shapeInfo()};
|
||||
|
||||
NativeOps nativeOps;
|
||||
nd4j::ops::split_v op;
|
||||
|
||||
Nd4jLong iArgs[] = {1};
|
||||
auto hash = op.getOpHash();
|
||||
|
||||
auto shapeList = nativeOps.calculateOutputShapes(nullptr, hash, inputBuffers, inputShapes, 2, nullptr, 0, iArgs, 1, nullptr, 0);
|
||||
auto shapeList = calculateOutputShapes2(nullptr, hash, inputBuffers, inputShapes, 2, nullptr, 0, iArgs, 1, nullptr, 0);
|
||||
|
||||
ASSERT_EQ(3, shapeList->size());
|
||||
|
||||
|
@ -126,7 +120,7 @@ TEST_F(JavaInteropTests, TestShapeExposure3) {
|
|||
ASSERT_TRUE(shape::equalsSoft(sub1.shapeInfo(), shapeList->at(1)));
|
||||
ASSERT_TRUE(shape::equalsSoft(sub2.shapeInfo(), shapeList->at(2)));
|
||||
|
||||
nativeOps.deleteShapeList((Nd4jPointer) shapeList);
|
||||
deleteShapeList((Nd4jPointer) shapeList);
|
||||
}
|
||||
|
||||
TEST_F(JavaInteropTests, Test_Squeeze_1) {
|
||||
|
@ -143,10 +137,7 @@ TEST_F(JavaInteropTests, Test_Squeeze_1) {
|
|||
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()};
|
||||
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()};
|
||||
|
||||
|
||||
NativeOps nativeOps;
|
||||
|
||||
auto status = nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
||||
auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
||||
ASSERT_EQ(Status::OK(), status);
|
||||
|
||||
ASSERT_EQ(e, z);
|
||||
|
@ -167,10 +158,7 @@ TEST_F(JavaInteropTests, Test_RDiv_1) {
|
|||
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()};
|
||||
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()};
|
||||
|
||||
|
||||
NativeOps nativeOps;
|
||||
|
||||
auto status = nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
||||
auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
||||
ASSERT_EQ(Status::OK(), status);
|
||||
|
||||
ASSERT_EQ(e, z);
|
||||
|
@ -203,11 +191,9 @@ TEST_F(JavaInteropTests, TestSconv2d_1) {
|
|||
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer()};
|
||||
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo()};
|
||||
|
||||
NativeOps nativeOps;
|
||||
|
||||
Nd4jLong exp[] = {1, 1, 1, 1, 0, 0, 1, 1, 0, 0};
|
||||
|
||||
nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 4, ptrsOutBuffers, ptrsOutShapes, 1,
|
||||
execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 4, ptrsOutBuffers, ptrsOutShapes, 1,
|
||||
nullptr, 0, exp, 9, nullptr, 0, false);
|
||||
|
||||
//output.printBuffer("output");
|
||||
|
@ -238,11 +224,9 @@ TEST_F(JavaInteropTests, TestSconv2d_2) {
|
|||
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer()};
|
||||
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo()};
|
||||
|
||||
NativeOps nativeOps;
|
||||
|
||||
Nd4jLong exp[] = {1, 1, 1, 1, 0, 0, 1, 1, 0};
|
||||
|
||||
nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 9, nullptr, 0, false);
|
||||
execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 9, nullptr, 0, false);
|
||||
|
||||
//output.printBuffer("output");
|
||||
|
||||
|
@ -266,9 +250,7 @@ TEST_F(JavaInteropTests, TestMaxPooling2d_1) {
|
|||
|
||||
nd4j::ops::maxpool2d op;
|
||||
|
||||
NativeOps nativeOps;
|
||||
|
||||
Nd4jStatus status = nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs.data(), 9, nullptr, 0, false);
|
||||
Nd4jStatus status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs.data(), 9, nullptr, 0, false);
|
||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||
|
||||
}
|
||||
|
@ -294,13 +276,11 @@ TEST_F(JavaInteropTests, TestCol2Im_1) {
|
|||
|
||||
nd4j::ops::col2im op;
|
||||
|
||||
NativeOps nativeOps;
|
||||
|
||||
Nd4jLong exp[] = {1, 1, 1, 1, 4, 5, 1, 1, 1};
|
||||
|
||||
auto hash = op.getOpHash();
|
||||
|
||||
nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 9, nullptr, 0, false);
|
||||
execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 9, nullptr, 0, false);
|
||||
|
||||
ASSERT_TRUE(output.meanNumber().e<float>(0) > 0.0f);
|
||||
}
|
||||
|
@ -320,8 +300,6 @@ TEST_F(JavaInteropTests, TestPNorm_1) {
|
|||
auto output = NDArrayFactory::create<float>('c', {1, 3, 3, 3});
|
||||
input.linspace(1);
|
||||
|
||||
NativeOps nativeOps;
|
||||
|
||||
nd4j::ops::pnormpool2d op;
|
||||
|
||||
Nd4jLong exp[] = {2, 2, 1, 1, 0, 0, 1, 1, 0, 2, 0, 0};
|
||||
|
@ -332,7 +310,7 @@ TEST_F(JavaInteropTests, TestPNorm_1) {
|
|||
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer()};
|
||||
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo()};
|
||||
|
||||
nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 11, nullptr, 0, false);
|
||||
execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 11, nullptr, 0, false);
|
||||
|
||||
ASSERT_TRUE(output.meanNumber().e<double>(0) > 0.0);
|
||||
}
|
||||
|
@ -343,8 +321,6 @@ TEST_F(JavaInteropTests, TestInplace_1) {
|
|||
//auto exp('c', {10, 10});
|
||||
input.linspace(1);
|
||||
|
||||
NativeOps nativeOps;
|
||||
|
||||
nd4j::ops::clipbyvalue op;
|
||||
|
||||
double extras[] = {-1.0f, 1.0f};
|
||||
|
@ -353,7 +329,7 @@ TEST_F(JavaInteropTests, TestInplace_1) {
|
|||
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.getShapeInfo()};
|
||||
|
||||
|
||||
Nd4jStatus result = nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, nullptr, nullptr, 0, extras, 2, nullptr, 0, nullptr, 0, true);
|
||||
Nd4jStatus result = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, nullptr, nullptr, 0, extras, 2, nullptr, 0, nullptr, 0, true);
|
||||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result);
|
||||
|
||||
|
@ -415,7 +391,6 @@ TEST_F(JavaInteropTests, test_avgpooling_edge_1) {
|
|||
x.linspace(1.0);
|
||||
z.linspace(1.0);
|
||||
|
||||
NativeOps nativeOps;
|
||||
nd4j::ops::avgpool2d op;
|
||||
//auto result = op.execute({&x}, {}, {3,3, 1,1, 0,0, 1,1, 1, 0, 1});
|
||||
|
||||
|
@ -427,7 +402,7 @@ TEST_F(JavaInteropTests, test_avgpooling_edge_1) {
|
|||
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()};
|
||||
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()};
|
||||
|
||||
auto result = nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 11, nullptr, 0, false);
|
||||
auto result = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 11, nullptr, 0, false);
|
||||
|
||||
ASSERT_EQ(Status::OK(), result);
|
||||
|
||||
|
@ -496,15 +471,13 @@ TEST_F(JavaInteropTests, test_avgpooling_edge_1) {
|
|||
|
||||
/*
|
||||
TEST_F(JavaInteropTests, Test_GraphReuse_1) {
|
||||
NativeOps nativeOps;
|
||||
|
||||
uint8_t* data = nd4j::graph::readFlatBuffers("./resources/reduce_dim_false.fb");
|
||||
|
||||
nativeOps.registerGraph(nullptr, 119, (Nd4jPointer) data);
|
||||
registerGraph(nullptr, 119, (Nd4jPointer) data);
|
||||
|
||||
ASSERT_TRUE(GraphHolder::getInstance()->hasGraph(119));
|
||||
|
||||
nativeOps.unregisterGraph(nullptr, 119);
|
||||
unregisterGraph(nullptr, 119);
|
||||
|
||||
ASSERT_FALSE(GraphHolder::getInstance()->hasGraph(119));
|
||||
|
||||
|
@ -520,8 +493,6 @@ TEST_F(JavaInteropTests, Test_GraphReuse_2) {
|
|||
auto exp1 = NDArrayFactory::create<float>('c', {3}, {6, 6, 6});
|
||||
auto exp2 = NDArrayFactory::create<float>('c', {3}, {9, 9, 9});
|
||||
|
||||
NativeOps nativeOps;
|
||||
|
||||
// we load graph from file, because we're not in java here, and dont have buffer ready
|
||||
uint8_t* data = nd4j::graph::readFlatBuffers("./resources/reduce_dim_false.fb");
|
||||
|
||||
|
@ -529,7 +500,7 @@ TEST_F(JavaInteropTests, Test_GraphReuse_2) {
|
|||
ASSERT_FALSE(GraphHolder::getInstance()->hasGraph(119));
|
||||
|
||||
// register the graph, to call for it later
|
||||
nativeOps.registerGraph(nullptr, 119, (Nd4jPointer) data);
|
||||
registerGraph(nullptr, 119, (Nd4jPointer) data);
|
||||
|
||||
// and ensure we're ok
|
||||
ASSERT_TRUE(GraphHolder::getInstance()->hasGraph(119));
|
||||
|
@ -547,7 +518,7 @@ TEST_F(JavaInteropTests, Test_GraphReuse_2) {
|
|||
Nd4jPointer shapes_0[] = {(Nd4jPointer) input_0.shapeInfo()};
|
||||
|
||||
// now we're executing stored graph and providing replacement for input variable
|
||||
auto res_0 = nativeOps.executeStoredGraph(nullptr, 119, inputs_0, shapes_0, idx, 1);
|
||||
auto res_0 = executeStoredGraph(nullptr, 119, inputs_0, shapes_0, idx, 1);
|
||||
ASSERT_EQ(ND4J_STATUS_OK, res_0->status());
|
||||
ASSERT_EQ(1, res_0->size());
|
||||
|
||||
|
@ -562,7 +533,7 @@ TEST_F(JavaInteropTests, Test_GraphReuse_2) {
|
|||
Nd4jPointer shapes_1[] = {(Nd4jPointer) input_1.shapeInfo()};
|
||||
|
||||
// doing it again
|
||||
auto res_1 = nativeOps.executeStoredGraph(nullptr, 119, inputs_1, shapes_1, idx, 1);
|
||||
auto res_1 = executeStoredGraph(nullptr, 119, inputs_1, shapes_1, idx, 1);
|
||||
ASSERT_EQ(ND4J_STATUS_OK, res_1->status());
|
||||
ASSERT_EQ(1, res_1->size());
|
||||
|
||||
|
@ -577,7 +548,7 @@ TEST_F(JavaInteropTests, Test_GraphReuse_2) {
|
|||
Nd4jPointer shapes_2[] = {(Nd4jPointer) input_2.shapeInfo()};
|
||||
|
||||
// and again
|
||||
auto res_2 = nativeOps.executeStoredGraph(nullptr, 119, inputs_2, shapes_2, idx, 1);
|
||||
auto res_2 = executeStoredGraph(nullptr, 119, inputs_2, shapes_2, idx, 1);
|
||||
ASSERT_EQ(ND4J_STATUS_OK, res_1->status());
|
||||
ASSERT_EQ(1, res_2->size());
|
||||
|
||||
|
@ -586,7 +557,7 @@ TEST_F(JavaInteropTests, Test_GraphReuse_2) {
|
|||
|
||||
|
||||
//////// clean out
|
||||
nativeOps.unregisterGraph(nullptr, 119);
|
||||
unregisterGraph(nullptr, 119);
|
||||
|
||||
ASSERT_FALSE(GraphHolder::getInstance()->hasGraph(119));
|
||||
|
||||
|
@ -616,9 +587,7 @@ TEST_F(JavaInteropTests, Test_Greater_1) {
|
|||
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.getBuffer()};
|
||||
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.getShapeInfo()};
|
||||
|
||||
NativeOps nativeOps;
|
||||
|
||||
nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
||||
execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
||||
o.printIndexedBuffer("Greater JIT");
|
||||
ASSERT_TRUE(exp.equalsTo(&o));
|
||||
}
|
||||
|
@ -641,9 +610,7 @@ TEST_F(JavaInteropTests, Test_Greater_2) {
|
|||
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.getBuffer()};
|
||||
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.getShapeInfo()};
|
||||
|
||||
NativeOps nativeOps;
|
||||
|
||||
nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
||||
execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
||||
|
||||
ASSERT_TRUE(exp.equalsTo(&o));
|
||||
}
|
||||
|
@ -662,9 +629,8 @@ TEST_F(JavaInteropTests, Test_Boolean_Op_1) {
|
|||
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.getBuffer()};
|
||||
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.getShapeInfo()};
|
||||
|
||||
NativeOps nativeOps;
|
||||
auto hash = op.getOpHash();
|
||||
auto status = nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
||||
auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
||||
ASSERT_EQ(Status::OK(), status);
|
||||
|
||||
ASSERT_TRUE(exp.equalsTo(&o));
|
||||
|
@ -685,9 +651,8 @@ TEST_F(JavaInteropTests, Test_Inplace_Outputs_1) {
|
|||
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()};
|
||||
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()};
|
||||
|
||||
NativeOps nativeOps;
|
||||
auto hash = op.getOpHash();
|
||||
auto status = nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
||||
auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
||||
ASSERT_EQ(Status::OK(), status);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
|
@ -710,9 +675,8 @@ TEST_F(JavaInteropTests, Test_Inplace_Outputs_2) {
|
|||
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()};
|
||||
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()};
|
||||
|
||||
NativeOps nativeOps;
|
||||
auto hash = op.getOpHash();
|
||||
auto status = nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
||||
auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
||||
ASSERT_EQ(Status::OK(), status);
|
||||
|
||||
ASSERT_TRUE(e.isSameShape(z));
|
||||
|
@ -736,9 +700,8 @@ TEST_F(JavaInteropTests, Test_Inplace_Outputs_3) {
|
|||
|
||||
Nd4jLong iArgs[] = {1};
|
||||
|
||||
NativeOps nativeOps;
|
||||
auto hash = op.getOpHash();
|
||||
auto status = nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 1, nullptr, 0, false);
|
||||
auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 1, nullptr, 0, false);
|
||||
ASSERT_EQ(Status::OK(), status);
|
||||
|
||||
ASSERT_TRUE(e.isSameShape(output));
|
||||
|
@ -753,8 +716,7 @@ TEST_F(JavaInteropTests, Test_Reduce3_EdgeCase) {
|
|||
|
||||
auto dims = NDArrayFactory::create<int>('c', {2}, {0, 1});
|
||||
|
||||
NativeOps nativeOps;
|
||||
nativeOps.execReduce3(nullptr, 2, x.buffer(), x.shapeInfo(), nullptr, nullptr, nullptr,
|
||||
execReduce3Tad(nullptr, 2, x.buffer(), x.shapeInfo(), nullptr, nullptr, nullptr,
|
||||
y.buffer(), y.shapeInfo(), nullptr, nullptr,
|
||||
z.buffer(), z.shapeInfo(), nullptr, nullptr,
|
||||
dims.buffer(), dims.shapeInfo(), dims.specialBuffer(), dims.specialShapeInfo(), nullptr, nullptr, nullptr, nullptr);
|
||||
|
@ -764,10 +726,8 @@ TEST_F(JavaInteropTests, Test_SimpleIf_Output) {
|
|||
Environment::getInstance()->setDebug(true);
|
||||
Environment::getInstance()->setVerbose(false);
|
||||
|
||||
NativeOps ops;
|
||||
|
||||
auto pl = nd4j::graph::readFlatBuffers("./resources/simpleif_0_1.fb");
|
||||
auto ptr = ops.executeFlatGraph(nullptr, pl);
|
||||
auto ptr = executeFlatGraph(nullptr, pl);
|
||||
|
||||
Environment::getInstance()->setDebug(false);
|
||||
Environment::getInstance()->setVerbose(false);
|
||||
|
@ -792,9 +752,8 @@ TEST_F(JavaInteropTests, Test_AveragePooling_FF_TF_double) {
|
|||
Nd4jPointer ptrsOutShapes[] = {reinterpret_cast<Nd4jPointer>(z.shapeInfo())};
|
||||
Nd4jLong iArgs[] = {3,3, 3,3, 0,0, 1,1,1, 0,1};
|
||||
|
||||
NativeOps nativeOps;
|
||||
auto hash = op.getOpHash();
|
||||
auto status = nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false);
|
||||
auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false);
|
||||
ASSERT_EQ(Status::OK(), status);
|
||||
|
||||
|
||||
|
@ -818,9 +777,8 @@ TEST_F(JavaInteropTests, Test_MaxPool2D_float_1) {
|
|||
|
||||
nd4j::ops::maxpool2d op;
|
||||
|
||||
NativeOps nativeOps;
|
||||
auto hash = op.getOpHash();
|
||||
auto status = nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false);
|
||||
auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false);
|
||||
ASSERT_EQ(Status::OK(), status);
|
||||
}
|
||||
|
||||
|
@ -843,9 +801,8 @@ TEST_F(JavaInteropTests, Test_Unstack_1) {
|
|||
|
||||
nd4j::ops::unstack op;
|
||||
|
||||
NativeOps nativeOps;
|
||||
auto hash = op.getOpHash();
|
||||
auto status = nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 5, nullptr, 0, iArgs, 1, nullptr, 0, false);
|
||||
auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 5, nullptr, 0, iArgs, 1, nullptr, 0, false);
|
||||
ASSERT_EQ(Status::OK(), status);
|
||||
}
|
||||
|
||||
|
@ -864,9 +821,8 @@ TEST_F(JavaInteropTests, Test_AveragePooling_FF_TF_float) {
|
|||
Nd4jPointer ptrsOutShapes[] = {reinterpret_cast<Nd4jPointer>(z.shapeInfo())};
|
||||
Nd4jLong iArgs[] = {3,3, 3,3, 0,0, 1,1,1, 0,1};
|
||||
|
||||
NativeOps nativeOps;
|
||||
auto hash = op.getOpHash();
|
||||
auto status = nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false);
|
||||
auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false);
|
||||
ASSERT_EQ(Status::OK(), status);
|
||||
|
||||
|
||||
|
@ -883,8 +839,7 @@ TEST_F(JavaInteropTests, Test_Mixed_Add_1) {
|
|||
auto arrayZ = NDArrayFactory::create<double>({0, 0, 0, 0});
|
||||
auto arrayE = NDArrayFactory::create<double>({2, 4, 6, 8});
|
||||
|
||||
NativeOps ops;
|
||||
ops.execPairwiseTransform(nullptr, pairwise::Add,
|
||||
execPairwiseTransform(nullptr, pairwise::Add,
|
||||
arrayX.buffer(), arrayX.shapeInfo(), nullptr, nullptr,
|
||||
arrayY.buffer(), arrayY.shapeInfo(), nullptr, nullptr,
|
||||
arrayZ.buffer(), arrayZ.shapeInfo(), nullptr, nullptr,
|
||||
|
@ -898,7 +853,6 @@ TEST_F(JavaInteropTests, Test_Add_1) {
|
|||
auto y = NDArrayFactory::create<int>('c', {5}, {1, 1, 1, 1, 1});
|
||||
auto e = NDArrayFactory::create<int>('c', {5}, {2, 2, 2, 2, 2});
|
||||
|
||||
NativeOps nativeOps;
|
||||
nd4j::ops::add op;
|
||||
|
||||
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), y.getBuffer()};
|
||||
|
@ -907,7 +861,7 @@ TEST_F(JavaInteropTests, Test_Add_1) {
|
|||
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) x.getBuffer()};
|
||||
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) x.getShapeInfo()};
|
||||
|
||||
nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
||||
execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
||||
|
||||
ASSERT_EQ(e, x);
|
||||
}
|
||||
|
@ -920,7 +874,6 @@ TEST_F(JavaInteropTests, zeta_test10) {
|
|||
|
||||
auto e = NDArrayFactory::create<double>('c', {3, 4}, {23.014574, 12.184081, 8.275731, 6.1532226, 4.776538, 3.7945523, 3.0541048, 2.4765317, 2.0163891, 205.27448, 21.090889, 19.477398});
|
||||
|
||||
NativeOps nativeOps;
|
||||
nd4j::ops::zeta op;
|
||||
|
||||
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), q.getBuffer()};
|
||||
|
@ -929,7 +882,7 @@ TEST_F(JavaInteropTests, zeta_test10) {
|
|||
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()};
|
||||
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()};
|
||||
|
||||
nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
||||
execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
||||
|
||||
ASSERT_EQ(e, z);
|
||||
}
|
||||
|
@ -939,8 +892,7 @@ TEST_F(JavaInteropTests, Test_Is_Max_1) {
|
|||
auto arrayZ = NDArrayFactory::create<bool>({0, 0, 0, 0});
|
||||
auto arrayE = NDArrayFactory::create<bool>({0, 1, 0, 0});
|
||||
|
||||
NativeOps ops;
|
||||
ops.execTransformAny(nullptr, transform::IsMax,
|
||||
execTransformAny(nullptr, transform::IsMax,
|
||||
arrayX.buffer(), arrayX.shapeInfo(), nullptr, nullptr,
|
||||
arrayZ.buffer(), arrayZ.shapeInfo(), nullptr, nullptr,
|
||||
nullptr);
|
||||
|
@ -953,8 +905,7 @@ TEST_F(JavaInteropTests, Test_Is_Max_1_2) {
|
|||
auto arrayZ = NDArrayFactory::create<float>({0, 0, 0, 0});
|
||||
auto arrayE = NDArrayFactory::create<float>({0, 1, 0, 0});
|
||||
|
||||
NativeOps ops;
|
||||
ops.execTransformAny(nullptr, transform::IsMax,
|
||||
execTransformAny(nullptr, transform::IsMax,
|
||||
arrayX.buffer(), arrayX.shapeInfo(), nullptr, nullptr,
|
||||
arrayZ.buffer(), arrayZ.shapeInfo(), nullptr, nullptr,
|
||||
nullptr);
|
||||
|
@ -970,8 +921,7 @@ TEST_F(JavaInteropTests, Test_Is_Max_2) {
|
|||
Nd4jLong *ex[] = {tad, off};
|
||||
float ea[] = {2, 1, 2};
|
||||
|
||||
NativeOps ops;
|
||||
ops.execTransformBool(reinterpret_cast<void **>(ex), transform::IsMax,
|
||||
execTransformBool(reinterpret_cast<void **>(ex), transform::IsMax,
|
||||
arrayX.buffer(), arrayX.shapeInfo(), nullptr, nullptr,
|
||||
arrayZ.buffer(), arrayZ.shapeInfo(), nullptr, nullptr,
|
||||
ea);
|
||||
|
@ -995,8 +945,7 @@ TEST_F(JavaInteropTests, Test_Boolean_Broadcastables_1) {
|
|||
|
||||
|
||||
nd4j::ops::greater_equal op;
|
||||
NativeOps ops;
|
||||
auto shapeList = ops.calculateOutputShapes(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, nullptr, 0, nullptr, 0, nullptr, 0);
|
||||
auto shapeList = calculateOutputShapes2(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, nullptr, 0, nullptr, 0, nullptr, 0);
|
||||
|
||||
delete shapeList;
|
||||
}
|
||||
|
@ -1013,8 +962,7 @@ TEST_F(JavaInteropTests, Test_L2_Loss_3) {
|
|||
Nd4jPointer ptrsOutShapes[] = {reinterpret_cast<Nd4jPointer>(z.shapeInfo())};
|
||||
|
||||
nd4j::ops::l2_loss op;
|
||||
NativeOps ops;
|
||||
auto status = ops.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffer, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
||||
auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffer, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
|
||||
ASSERT_EQ(Status::OK(), status);
|
||||
|
||||
z.printIndexedBuffer("z");
|
||||
|
@ -1036,9 +984,8 @@ TEST_F(JavaInteropTests, Test_Fastpath_3) {
|
|||
|
||||
ASSERT_EQ(2, ctx.width());
|
||||
|
||||
NativeOps nativeOps;
|
||||
nd4j::ops::add op;
|
||||
nativeOps.execCustomOp(nullptr, op.getOpHash(), &ctx);
|
||||
execCustomOp2(nullptr, op.getOpHash(), &ctx);
|
||||
|
||||
ASSERT_EQ(exp, z);
|
||||
}
|
||||
|
@ -1054,9 +1001,8 @@ TEST_F(JavaInteropTests, Test_Fastpath_4) {
|
|||
ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());
|
||||
ctx.setIArguments(iArgs, 3);
|
||||
|
||||
NativeOps nativeOps;
|
||||
nd4j::ops::tri op;
|
||||
nativeOps.execCustomOp(nullptr, op.getOpHash(), &ctx);
|
||||
execCustomOp2(nullptr, op.getOpHash(), &ctx);
|
||||
|
||||
ASSERT_EQ(exp, z);
|
||||
}
|
||||
|
@ -1074,9 +1020,8 @@ TEST_F(JavaInteropTests, Test_Fastpath_5) {
|
|||
ctx.setInputArray(1, b.buffer(), b.shapeInfo(), b.specialBuffer(), b.specialShapeInfo());
|
||||
ctx.setOutputArray(0, c.buffer(), c.shapeInfo(), c.specialBuffer(), c.specialShapeInfo());
|
||||
|
||||
NativeOps nativeOps;
|
||||
nd4j::ops::matmul op;
|
||||
auto status = nativeOps.execCustomOp(nullptr, op.getOpHash(), &ctx);
|
||||
auto status = execCustomOp2(nullptr, op.getOpHash(), &ctx);
|
||||
|
||||
ASSERT_EQ(Status::OK(), status);
|
||||
}
|
||||
|
@ -1104,9 +1049,8 @@ TEST_F(JavaInteropTests, Test_Fastpath_6) {
|
|||
|
||||
ctx.setIArguments(iArgs, 3);
|
||||
|
||||
NativeOps nativeOps;
|
||||
nd4j::ops::matmul_bp op;
|
||||
auto status = nativeOps.execCustomOp(nullptr, op.getOpHash(), &ctx);
|
||||
auto status = execCustomOp2(nullptr, op.getOpHash(), &ctx);
|
||||
|
||||
ASSERT_EQ(Status::OK(), status);
|
||||
}
|
||||
|
@ -1122,7 +1066,6 @@ TEST_F(JavaInteropTests, Test_Fastpath_7) {
|
|||
|
||||
ctx.setIArguments(iArgs, 1);
|
||||
|
||||
NativeOps nativeOps;
|
||||
nd4j::ops::concat op;
|
||||
|
||||
ctx.setInputArray(0, a.buffer(), a.shapeInfo(), a.specialBuffer(), a.specialShapeInfo());
|
||||
|
@ -1130,7 +1073,7 @@ TEST_F(JavaInteropTests, Test_Fastpath_7) {
|
|||
|
||||
ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());
|
||||
|
||||
auto status = nativeOps.execCustomOp(nullptr, op.getOpHash(), &ctx);
|
||||
auto status = execCustomOp2(nullptr, op.getOpHash(), &ctx);
|
||||
ASSERT_EQ(Status::OK(), status);
|
||||
|
||||
ASSERT_EQ(e, z);
|
||||
|
@ -1138,10 +1081,8 @@ TEST_F(JavaInteropTests, Test_Fastpath_7) {
|
|||
|
||||
/*
|
||||
TEST_F(JavaInteropTests, Test_Results_Conversion_1) {
|
||||
NativeOps ops;
|
||||
|
||||
auto pl = nd4j::graph::readFlatBuffers("./resources/gru_dynamic_mnist.fb");
|
||||
auto ptr = ops.executeFlatGraph(nullptr, pl);
|
||||
auto ptr = executeFlatGraph(nullptr, pl);
|
||||
|
||||
// at this point we have FlatResults
|
||||
auto flatResult = GetFlatResult(ptr->pointer());
|
||||
|
@ -1190,8 +1131,6 @@ TEST_F(JavaInteropTests, Test_Results_Conversion_1) {
|
|||
}
|
||||
*/
|
||||
// TEST_F(JavaInteropTests, Test_NLP_Aggregations_1) {
|
||||
// NativeOps ops;
|
||||
|
||||
// std::array<float, 60> syn0 = {-0.022756476f, 0.0126427775f, 0.011029151f, -0.013542821f, -0.012327666f, -0.0032439455f, -0.008405109f, -0.016651405f, 0.0015980572f, -0.007442479f, 0.019937921f, -0.016222188f, -0.016541665f, 0.013372547f, 0.006625724f, 0.0058958204f, -0.01281835f, -6.2343775E-4f, 0.0019826533f, 0.010253737f, -0.010291531f, 0.0019767822f, 0.018071089f, -0.0117441565f, 0.023176769f, 0.0032820583f, 0.0061427564f, -0.01696018f, 0.0054971874f, 0.0043818625f, 0.019323621f, 0.0036080598f, 0.024376748f, -0.0024499625f, 0.019496754f, 0.010563821f, -2.0503551E-4f, -0.0146056535f, 0.009949291f, 0.017604528f, -0.0050302492f, -0.022060446f, 0.016468976f, -0.0034482107f, 0.010270384f, -0.0063356445f, -0.019934833f, -0.02325993f, 0.016109904f, -0.0031106502f, -0.0020592287f, 0.024031803f, 0.005184144f, -0.024887865f, 0.02100272f, 3.395051E-4f, 0.018432347f, 5.673498E-4f, -0.020073576f, 0.010949242f};
|
||||
// std::array<float, 60> syn1;
|
||||
// std::array<float, 100000> exp;
|
||||
|
@ -1283,5 +1222,5 @@ TEST_F(JavaInteropTests, Test_Results_Conversion_1) {
|
|||
// ptrptr[idx+2] = reinterpret_cast<void*>(exp.data());
|
||||
|
||||
|
||||
// ops.execAggregateBatchFloat(nullptr, numAggregates, opNum, maxArgs, maxShapes, maxIntArrays, maxIntArraySize, maxIndexArguments, maxRealArguments, pointer.data());
|
||||
// execAggregateBatchFloat(nullptr, numAggregates, opNum, maxArgs, maxShapes, maxIntArrays, maxIntArraySize, maxIndexArguments, maxRealArguments, pointer.data());
|
||||
// }
|
|
@ -53,8 +53,7 @@ TEST_F(LegacyOpsCudaTests, test_sortTad_1) {
|
|||
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
|
||||
|
||||
x.syncToDevice();
|
||||
NativeOps nativeOps;
|
||||
nativeOps.sortTad(extras, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), &axis, 1, packX.platformShapeInfo(), packX.platformOffsets(), false);
|
||||
sortTad(extras, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), &axis, 1, packX.platformShapeInfo(), packX.platformOffsets(), false);
|
||||
x.tickWriteDevice();
|
||||
|
||||
ASSERT_EQ(e, x);
|
||||
|
|
|
@ -501,8 +501,7 @@ TEST_F(LegacyOpsTests, Reduce3_2) {
|
|||
|
||||
auto dim = NDArrayFactory::create<int>('c', {1}, {1});
|
||||
|
||||
NativeOps nativeOps;
|
||||
nativeOps.execReduce3(nullptr, reduce3::CosineSimilarity, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(),
|
||||
execReduce3Tad(nullptr, reduce3::CosineSimilarity, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(),
|
||||
nullptr, nullptr, nullptr, nullptr);
|
||||
}
|
||||
|
||||
|
@ -517,9 +516,8 @@ TEST_F(LegacyOpsTests, Reduce3_3) {
|
|||
|
||||
auto dim = NDArrayFactory::create<int>('c', {1}, {1});
|
||||
|
||||
NativeOps nativeOps;
|
||||
|
||||
nativeOps.execReduce3(nullptr, reduce3::CosineDistance,
|
||||
execReduce3Tad(nullptr, reduce3::CosineDistance,
|
||||
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
||||
nullptr,
|
||||
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
|
||||
|
@ -543,9 +541,8 @@ TEST_F(LegacyOpsTests, Reduce3_4) {
|
|||
|
||||
auto dim = NDArrayFactory::create<int>('c', {1}, {1});
|
||||
|
||||
NativeOps nativeOps;
|
||||
|
||||
nativeOps.execReduce3(nullptr, reduce3::CosineDistance,
|
||||
execReduce3Tad(nullptr, reduce3::CosineDistance,
|
||||
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
||||
nullptr,
|
||||
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
|
||||
|
@ -569,9 +566,8 @@ TEST_F(LegacyOpsTests, Reduce3_5) {
|
|||
|
||||
auto dim = NDArrayFactory::create<int>('c', {1}, {1});
|
||||
|
||||
NativeOps nativeOps;
|
||||
|
||||
nativeOps.execReduce3(nullptr, reduce3::CosineDistance,
|
||||
execReduce3Tad(nullptr, reduce3::CosineDistance,
|
||||
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
||||
nullptr,
|
||||
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
|
||||
|
@ -593,8 +589,7 @@ TEST_F(LegacyOpsTests, test_Reduce3_All_1) {
|
|||
auto tadPackX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), -1);
|
||||
auto tadPackY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.shapeInfo(), -1);
|
||||
|
||||
NativeOps ops;
|
||||
ops.execReduce3All(nullptr, reduce3::EuclideanDistance, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
||||
execReduce3All(nullptr, reduce3::EuclideanDistance, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
|
||||
nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
|
||||
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
|
||||
dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(),
|
||||
|
|
|
@ -33,8 +33,6 @@ public:
|
|||
};
|
||||
|
||||
TEST_F(MmapTests, Test_Basic_Mmap_1) {
|
||||
NativeOps nativeOps;
|
||||
|
||||
// just 10GB
|
||||
Nd4jLong size = 100000L;
|
||||
|
||||
|
@ -43,11 +41,11 @@ TEST_F(MmapTests, Test_Basic_Mmap_1) {
|
|||
ofs.write("", 1);
|
||||
ofs.close();
|
||||
|
||||
auto result = nativeOps.mmapFile(nullptr, "file", size);
|
||||
auto result = mmapFile(nullptr, "file", size);
|
||||
|
||||
ASSERT_FALSE(result == nullptr);
|
||||
|
||||
nativeOps.munmapFile(nullptr, result, size);
|
||||
munmapFile(nullptr, result, size);
|
||||
|
||||
remove("file");
|
||||
}
|
|
@ -2258,7 +2258,6 @@ TEST_F(NDArrayCudaBasicsTests, Test_Empty_4) {
|
|||
|
||||
TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_1) {
|
||||
auto x = NDArrayFactory::create<float>('c', {5,2}, {0,1,2,3,4,5,6,7,8,9});
|
||||
NativeOps native;
|
||||
auto z = NDArrayFactory::create<float>('c', {5, 8});
|
||||
auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
|
||||
std::vector<void*> buffers(4);
|
||||
|
@ -2272,7 +2271,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_1) {
|
|||
}
|
||||
Nd4jPointer extra[2];
|
||||
extra[1] = *stream;
|
||||
native.concat(extra, 1, 4, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
||||
::concat(extra, 1, 4, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
||||
z.syncToHost();
|
||||
z.printIndexedBuffer("Concat result");
|
||||
z.printBuffer("C Concat result linear");
|
||||
|
@ -2281,7 +2280,6 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_1) {
|
|||
|
||||
TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_2) {
|
||||
auto x = NDArrayFactory::create<float>('c', {5,2}, {0,1,2,3,4,5,6,7,8,9});
|
||||
NativeOps native;
|
||||
auto z = NDArrayFactory::create<float>('f', {5, 8});
|
||||
auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
|
||||
std::vector<void*> buffers(4);
|
||||
|
@ -2295,7 +2293,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_2) {
|
|||
}
|
||||
Nd4jPointer extra[2];
|
||||
extra[1] = *stream;
|
||||
native.concat(extra, 1, 4, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
||||
::concat(extra, 1, 4, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
||||
z.syncToHost();
|
||||
z.printIndexedBuffer("Concat result");
|
||||
z.printBuffer("F Concat result linear");
|
||||
|
@ -2304,7 +2302,6 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_2) {
|
|||
TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_3) {
|
||||
auto x = NDArrayFactory::create<float>('c', {2,3}, {1,2,3,4,5,6});
|
||||
auto y = NDArrayFactory::create<float>('c', {1,3}, {7,8,9});
|
||||
NativeOps native;
|
||||
auto z = NDArrayFactory::create<float>('f', {3, 3});
|
||||
auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
|
||||
std::vector<void*> buffers(2);
|
||||
|
@ -2321,7 +2318,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_3) {
|
|||
//}
|
||||
Nd4jPointer extra[2];
|
||||
extra[1] = *stream;
|
||||
native.concat(extra, 0, 2, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
||||
::concat(extra, 0, 2, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
||||
z.syncToHost();
|
||||
z.printIndexedBuffer("Concat result");
|
||||
z.printBuffer("F Concat result linear");
|
||||
|
@ -2331,7 +2328,6 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_3) {
|
|||
TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_4) {
|
||||
auto x = NDArrayFactory::create<float>('c', {2,3}, {1,2,3,4,5,6});
|
||||
auto y = NDArrayFactory::create<float>('c', {1,3}, {7,8,9});
|
||||
NativeOps native;
|
||||
auto z = NDArrayFactory::create<float>('c', {3, 3});
|
||||
auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
|
||||
std::vector<void*> buffers(2);
|
||||
|
@ -2348,7 +2344,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_4) {
|
|||
//}
|
||||
Nd4jPointer extra[2];
|
||||
extra[1] = *stream;
|
||||
native.concat(extra, 0, 2, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
||||
::concat(extra, 0, 2, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
||||
z.syncToHost();
|
||||
z.printIndexedBuffer("Concat result");
|
||||
z.printBuffer("C Concat result linear");
|
||||
|
@ -2358,7 +2354,6 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_4) {
|
|||
TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_5) {
|
||||
auto x = NDArrayFactory::create<float>('c', {1,2,3}, {1,2,3,4,5,6});
|
||||
auto y = NDArrayFactory::create<float>('c', {1,2,3}, {7,8,9,10,11, 12});
|
||||
NativeOps native;
|
||||
auto z = NDArrayFactory::create<float>('c', {2, 2, 3});
|
||||
auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
|
||||
std::vector<void*> buffers(2);
|
||||
|
@ -2375,7 +2370,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_5) {
|
|||
//}
|
||||
Nd4jPointer extra[2];
|
||||
extra[1] = *stream;
|
||||
native.concat(extra, 0, 2, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
||||
::concat(extra, 0, 2, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
||||
z.syncToHost();
|
||||
z.printIndexedBuffer("Concat result");
|
||||
z.printBuffer("C Concat result linear");
|
||||
|
@ -2385,7 +2380,6 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_6) {
|
|||
auto x1 = NDArrayFactory::create<float>('c', {2,2,3}, {1,2,3,4,5,6,7,8, 9, 10,11,12});
|
||||
auto x2 = NDArrayFactory::create<float>('c', {1,2,3}, {13,14,15,16,17, 18});
|
||||
auto x3 = NDArrayFactory::create<float>('c', {1,2,3}, {19,20,21,22,23, 24});
|
||||
NativeOps native;
|
||||
auto z = NDArrayFactory::create<float>('c', {4, 2, 3});
|
||||
auto stream = x1.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
|
||||
std::vector<void*> buffers(3);
|
||||
|
@ -2406,7 +2400,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_6) {
|
|||
printf("The third array is %p\n", buffers[2]);
|
||||
Nd4jPointer extra[2];
|
||||
extra[1] = *stream;
|
||||
native.concat(extra, 0, 3, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
||||
::concat(extra, 0, 3, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
||||
z.syncToHost();
|
||||
z.printIndexedBuffer("Concat result");
|
||||
z.printBuffer("C Concat3D result linear");
|
||||
|
@ -2417,7 +2411,6 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_7) {
|
|||
auto x1 = NDArrayFactory::create<float>(1);
|
||||
auto x2 = NDArrayFactory::create<float>(2);
|
||||
auto x3 = NDArrayFactory::create<float>(3);
|
||||
NativeOps native;
|
||||
auto z = NDArrayFactory::create<float>('c', {3}, {1,2,3});
|
||||
auto stream = x1.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
|
||||
std::vector<void*> buffers(3);
|
||||
|
@ -2438,7 +2431,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_7) {
|
|||
printf("The third array is %p\n", buffers[2]);
|
||||
Nd4jPointer extra[2];
|
||||
extra[1] = *stream;
|
||||
native.concat(extra, 0, 3, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
||||
::concat(extra, 0, 3, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
||||
z.syncToHost();
|
||||
z.printIndexedBuffer("Concat result");
|
||||
z.printBuffer("C Concat scalar result linear");
|
||||
|
@ -2462,7 +2455,6 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_8) {
|
|||
lx[i].assign(i);
|
||||
}
|
||||
|
||||
NativeOps native;
|
||||
auto z = NDArrayFactory::create<float>('c', {totalCount, width});
|
||||
auto stream = nd4j::LaunchContext ::defaultContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
|
||||
std::vector<void*> buffers(totalCount);
|
||||
|
@ -2478,7 +2470,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_8) {
|
|||
printf("The third array is %p\n", buffers[2]);
|
||||
Nd4jPointer extra[2];
|
||||
extra[1] = *stream;
|
||||
native.concat(extra, 0, totalCount, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
||||
::concat(extra, 0, totalCount, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
||||
z.syncToHost();
|
||||
nd4j_printf("%f %f %f\n", z.e<float>(0), z.e<float>(width * totalCount / 2), z.e<float>(width * (totalCount - 1)));
|
||||
//z.printIndexedBuffer("Concat result");
|
||||
|
@ -2496,7 +2488,6 @@ TEST_F(NDArrayCudaBasicsTests, TestTear_1) {
|
|||
arrays.emplace_back(input);
|
||||
}
|
||||
auto z = NDArrayFactory::create<float>('c', {total, 10, 10});
|
||||
NativeOps native;
|
||||
|
||||
auto stream = input.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
|
||||
Nd4jPointer extra[2];
|
||||
|
@ -2512,7 +2503,7 @@ TEST_F(NDArrayCudaBasicsTests, TestTear_1) {
|
|||
hostShapes[i] = arrays[i].shapeInfo();
|
||||
}
|
||||
|
||||
native.concat(extra, 0, total, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
||||
::concat(extra, 0, total, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
||||
nd4j::ops::tear op;
|
||||
|
||||
auto result = op.execute({&z}, {}, {1, 2});
|
||||
|
@ -2536,7 +2527,6 @@ TEST_F(NDArrayCudaBasicsTests, TestTear_2) {
|
|||
arrays.emplace_back(input);
|
||||
}
|
||||
auto z = NDArrayFactory::create<float>('c', {10, 10, 10});
|
||||
NativeOps native;
|
||||
|
||||
auto stream = input.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
|
||||
Nd4jPointer extra[2];
|
||||
|
@ -2552,7 +2542,7 @@ TEST_F(NDArrayCudaBasicsTests, TestTear_2) {
|
|||
hostShapes[i] = arrays[i].shapeInfo();
|
||||
}
|
||||
std::vector<int> dimsToExclude({1,2});
|
||||
native.concat(extra, 0, 10, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
||||
::concat(extra, 0, 10, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
||||
// z.syncToHost();
|
||||
// z.printBuffer("Pile OK");
|
||||
// z.printIndexedBuffer("Pile 10x10");
|
||||
|
@ -2569,7 +2559,7 @@ TEST_F(NDArrayCudaBasicsTests, TestTear_2) {
|
|||
Nd4jPointer target = arrays[i].specialBuffer();
|
||||
cudaMemcpy(&arraysData[i], &target, sizeof(Nd4jPointer), cudaMemcpyHostToDevice);
|
||||
}
|
||||
native.tear(extra, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), arraysData, input.specialShapeInfo(), packX.specialShapeInfo(), packX.specialOffsets());
|
||||
::tear(extra, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), arraysData, input.specialShapeInfo(), packX.specialShapeInfo(), packX.specialOffsets());
|
||||
// auto result = op.execute({&z}, {}, {1, 2});
|
||||
// nd4j_printf("Result count is %lu\n", result->size());
|
||||
//ASSERT_EQ(10, result->size());
|
||||
|
|
|
@ -313,12 +313,10 @@ TEST_F(PlaygroundTests, test_reduce_3) {
|
|||
Nd4jLong max = 0L;
|
||||
Nd4jLong min = DataTypeUtils::max<Nd4jLong>();
|
||||
|
||||
NativeOps nativeOps;
|
||||
|
||||
for (int e = 0; e < iterations; e++) {
|
||||
auto timeStart = std::chrono::system_clock::now();
|
||||
|
||||
nativeOps.execReduce3(nullptr, reduce3::CosineDistance, x.buffer(), x.shapeInfo(), x.specialBuffer(),
|
||||
execReduce3Tad(nullptr, reduce3::CosineDistance, x.buffer(), x.shapeInfo(), x.specialBuffer(),
|
||||
x.specialShapeInfo(), nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(),
|
||||
y.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
|
||||
dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(), nullptr,
|
||||
|
@ -964,8 +962,6 @@ TEST_F(PlaygroundTests, Test_Im2Col_1) {
|
|||
auto legacyPermTime = std::chrono::duration_cast<std::chrono::microseconds> (legacyPermEnd - legacyPermStart).count();
|
||||
|
||||
|
||||
NativeOps nativeOps;
|
||||
|
||||
Nd4jLong iArgs[] = {kH, kW, sH, sW, pH, pW, dH, dW, 0};
|
||||
Nd4jPointer inputBuffers[] = {input.buffer()};
|
||||
Nd4jPointer inputShapes[] = {input.shapeInfo()};
|
||||
|
@ -976,7 +972,7 @@ TEST_F(PlaygroundTests, Test_Im2Col_1) {
|
|||
auto javaStart = std::chrono::system_clock::now();
|
||||
|
||||
for (int e = 0; e < iterations; e++) {
|
||||
nativeOps.execCustomOp(nullptr, op.getOpHash(), inputBuffers, inputShapes, 1, outputBuffers, outputShapes, 1, nullptr, 0, iArgs, 9, nullptr, 0, false);
|
||||
execCustomOp(nullptr, op.getOpHash(), inputBuffers, inputShapes, 1, outputBuffers, outputShapes, 1, nullptr, 0, iArgs, 9, nullptr, 0, false);
|
||||
}
|
||||
|
||||
auto javaEnd = std::chrono::system_clock::now();
|
||||
|
@ -990,7 +986,7 @@ TEST_F(PlaygroundTests, Test_Im2Col_1) {
|
|||
|
||||
|
||||
for (int e = 0; e < iterations; e++) {
|
||||
nativeOps.execCustomOp(nullptr, op.getOpHash(), inputBuffers, inputShapes, 1, outputPermBuffers, outputPermShapes, 1, nullptr, 0, iArgs, 9, nullptr, 0, false);
|
||||
execCustomOp(nullptr, op.getOpHash(), inputBuffers, inputShapes, 1, outputPermBuffers, outputPermShapes, 1, nullptr, 0, iArgs, 9, nullptr, 0, false);
|
||||
}
|
||||
|
||||
auto javaPermEnd = std::chrono::system_clock::now();
|
||||
|
@ -1020,9 +1016,7 @@ TEST_F(PlaygroundTests, Test_Im2Col_2) {
|
|||
Nd4jPointer outputPermBuffers[] = {outputPermuted.buffer()};
|
||||
Nd4jPointer outputPermShapes[] = {outputPermuted.shapeInfo()};
|
||||
|
||||
NativeOps nativeOps;
|
||||
|
||||
nativeOps.execCustomOp(nullptr, op.getOpHash(), inputBuffers, inputShapes, 1, outputPermBuffers, outputPermShapes, 1, nullptr, 0, iArgs, 9, nullptr, 0, false);
|
||||
execCustomOp(nullptr, op.getOpHash(), inputBuffers, inputShapes, 1, outputPermBuffers, outputPermShapes, 1, nullptr, 0, iArgs, 9, nullptr, 0, false);
|
||||
}
|
||||
|
||||
TEST_F(PlaygroundTests, Test_Col2Im_1) {
|
||||
|
@ -1140,8 +1134,6 @@ TEST_F(PlaygroundTests, loop_test_1) {
|
|||
int length = (int) array->lengthOf();
|
||||
int span = (int) (array->lengthOf() / 6) + 8;
|
||||
|
||||
NativeOps ops;
|
||||
|
||||
auto t = new int[1000000];
|
||||
|
||||
|
||||
|
@ -1150,7 +1142,7 @@ TEST_F(PlaygroundTests, loop_test_1) {
|
|||
FloatBits fb;
|
||||
float threshold = 0.99f;
|
||||
fb.f_ = threshold;
|
||||
int le = ops.estimateThreshold(nullptr, reinterpret_cast<void *>(array->buffer()), array->shapeInfo(), static_cast<int>(array->lengthOf()), threshold);
|
||||
int le = estimateThreshold(nullptr, reinterpret_cast<void *>(array->buffer()), array->shapeInfo(), static_cast<int>(array->lengthOf()), threshold);
|
||||
|
||||
t[0] = le;
|
||||
t[1] = length;
|
||||
|
@ -1162,7 +1154,7 @@ TEST_F(PlaygroundTests, loop_test_1) {
|
|||
|
||||
for (int x = 0; x < iterations; x++) {
|
||||
auto permStart = std::chrono::system_clock::now();
|
||||
ops.estimateThreshold(nullptr, reinterpret_cast<void *>(array->buffer()), array->shapeInfo(), static_cast<int>(array->lengthOf()), threshold);
|
||||
estimateThreshold(nullptr, reinterpret_cast<void *>(array->buffer()), array->shapeInfo(), static_cast<int>(array->lengthOf()), threshold);
|
||||
TypeCast::convertToThreshold<float>(nullptr, buffer, array->lengthOf(), t);
|
||||
|
||||
auto permEnd = std::chrono::system_clock::now();
|
||||
|
|
|
@ -29,7 +29,6 @@ using namespace nd4j;
|
|||
|
||||
class RNGTests : public testing::Test {
|
||||
private:
|
||||
NativeOps nativeOps;
|
||||
//Nd4jLong *_bufferA;
|
||||
//Nd4jLong *_bufferB;
|
||||
|
||||
|
@ -47,8 +46,8 @@ public:
|
|||
RNGTests() {
|
||||
//_bufferA = new Nd4jLong[100000];
|
||||
//_bufferB = new Nd4jLong[100000];
|
||||
//_rngA = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, _seed, 100000, (Nd4jPointer) _bufferA);
|
||||
//_rngB = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, _seed, 100000, (Nd4jPointer) _bufferB);
|
||||
//_rngA = (nd4j::random::RandomBuffer *) initRandom(nullptr, _seed, 100000, (Nd4jPointer) _bufferA);
|
||||
//_rngB = (nd4j::random::RandomBuffer *) initRandom(nullptr, _seed, 100000, (Nd4jPointer) _bufferB);
|
||||
_rngA.setStates(_seed, _seed);
|
||||
_rngB.setStates(_seed, _seed);
|
||||
nexp0->assign(-1.0f);
|
||||
|
@ -57,8 +56,8 @@ public:
|
|||
}
|
||||
|
||||
~RNGTests() {
|
||||
//nativeOps.destroyRandom(_rngA);
|
||||
//nativeOps.destroyRandom(_rngB);
|
||||
//destroyRandom(_rngA);
|
||||
//destroyRandom(_rngB);
|
||||
//delete[] _bufferA;
|
||||
//delete[] _bufferB;
|
||||
|
||||
|
@ -791,14 +790,13 @@ namespace nd4j {
|
|||
}
|
||||
|
||||
TEST_F(RNGTests, Test_Reproducibility_9) {
|
||||
NativeOps ops;
|
||||
Nd4jLong seed = 123;
|
||||
|
||||
std::vector<Nd4jLong> shape = {32, 3, 28, 28};
|
||||
const int bufferSize = 10000;
|
||||
int64_t buffer[bufferSize];
|
||||
|
||||
auto rng = (nd4j::random::RandomBuffer *) ops.initRandom(nullptr, seed, bufferSize, buffer);
|
||||
auto rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, seed, bufferSize, buffer);
|
||||
|
||||
const int length = 4000000;
|
||||
int *arrayE = new int[length];
|
||||
|
@ -809,7 +807,7 @@ TEST_F(RNGTests, Test_Reproducibility_9) {
|
|||
|
||||
rng->rewindH(static_cast<Nd4jLong>(length));
|
||||
|
||||
ops.refreshBuffer(nullptr, seed, reinterpret_cast<Nd4jPointer>(rng));
|
||||
refreshBuffer(nullptr, seed, reinterpret_cast<Nd4jPointer>(rng));
|
||||
|
||||
for (int e = 0; e < length; e++)
|
||||
arrayT[e] = rng->relativeInt(e);
|
||||
|
@ -825,18 +823,17 @@ TEST_F(RNGTests, Test_Reproducibility_9) {
|
|||
delete[] arrayE;
|
||||
delete[] arrayT;
|
||||
|
||||
ops.destroyRandom(reinterpret_cast<Nd4jPointer>(rng));
|
||||
destroyRandom(reinterpret_cast<Nd4jPointer>(rng));
|
||||
}
|
||||
|
||||
TEST_F(RNGTests, Test_Reproducibility_8) {
|
||||
NativeOps ops;
|
||||
Nd4jLong seed = 123;
|
||||
|
||||
std::vector<int> shape = {32, 3, 28, 28};
|
||||
const int bufferSize = 10000;
|
||||
int64_t buffer[bufferSize];
|
||||
|
||||
auto rng = (nd4j::random::RandomBuffer *) ops.initRandom(nullptr, seed, bufferSize, buffer);
|
||||
auto rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, seed, bufferSize, buffer);
|
||||
|
||||
const int length = 4000000;
|
||||
int *arrayE = new int[length];
|
||||
|
@ -847,7 +844,7 @@ TEST_F(RNGTests, Test_Reproducibility_8) {
|
|||
|
||||
rng->rewindH(static_cast<Nd4jLong>(length));
|
||||
|
||||
ops.refreshBuffer(nullptr, seed, reinterpret_cast<Nd4jPointer>(rng));
|
||||
refreshBuffer(nullptr, seed, reinterpret_cast<Nd4jPointer>(rng));
|
||||
|
||||
for (int e = 0; e < length; e++)
|
||||
arrayT[e] = static_cast<int>(rng->relativeT<float>(e));
|
||||
|
@ -863,29 +860,27 @@ TEST_F(RNGTests, Test_Reproducibility_8) {
|
|||
delete[] arrayE;
|
||||
delete[] arrayT;
|
||||
|
||||
ops.destroyRandom(reinterpret_cast<Nd4jPointer>(rng));
|
||||
destroyRandom(reinterpret_cast<Nd4jPointer>(rng));
|
||||
}
|
||||
|
||||
TEST_F(RNGTests, Test_RandomBuffer_Half_1) {
|
||||
NativeOps ops;
|
||||
Nd4jLong seed = 123;
|
||||
|
||||
std::vector<Nd4jLong> shape = {32, 3, 28, 28};
|
||||
const int bufferSize = 10000;
|
||||
int64_t buffer[bufferSize];
|
||||
|
||||
auto rng = (nd4j::random::RandomBuffer *) ops.initRandom(nullptr, seed, bufferSize, buffer);
|
||||
auto rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, seed, bufferSize, buffer);
|
||||
|
||||
auto r0 = rng->relativeT<float16>(12L);
|
||||
auto r1 = rng->relativeT<float16>(13L);
|
||||
|
||||
ASSERT_NE(r0, r1);
|
||||
|
||||
ops.destroyRandom(reinterpret_cast<Nd4jPointer>(rng));
|
||||
destroyRandom(reinterpret_cast<Nd4jPointer>(rng));
|
||||
}
|
||||
|
||||
TEST_F(RNGTests, Test_Reproducibility_1) {
|
||||
NativeOps ops;
|
||||
Nd4jLong seed = 123;
|
||||
|
||||
std::vector<Nd4jLong> shape = {32, 3, 28, 28};
|
||||
|
@ -918,7 +913,6 @@ TEST_F(RNGTests, Test_Reproducibility_1) {
|
|||
|
||||
#ifndef DEBUG_BUILD
|
||||
TEST_F(RNGTests, Test_Reproducibility_2) {
|
||||
NativeOps ops;
|
||||
Nd4jLong seed = 123;
|
||||
|
||||
std::vector<Nd4jLong> shape = {32, 3, 64, 64};
|
||||
|
|
|
@ -44,8 +44,7 @@ TEST_F(SortCpuTests, test_linear_sort_by_key_1) {
|
|||
auto ev = NDArrayFactory::create<double>('c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5});
|
||||
|
||||
|
||||
NativeOps nativeOps;
|
||||
nativeOps.sortByKey(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);
|
||||
sortByKey(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);
|
||||
|
||||
ASSERT_EQ(ek, k);
|
||||
ASSERT_EQ(ev, v);
|
||||
|
@ -62,8 +61,7 @@ TEST_F(SortCpuTests, test_linear_sort_by_val_1) {
|
|||
auto ev = NDArrayFactory::create<double>('c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5});
|
||||
|
||||
|
||||
NativeOps nativeOps;
|
||||
nativeOps.sortByValue(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);
|
||||
sortByValue(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);
|
||||
|
||||
ASSERT_EQ(ek, k);
|
||||
ASSERT_EQ(ev, v);
|
||||
|
@ -81,8 +79,7 @@ TEST_F(SortCpuTests, test_tad_sort_by_key_1) {
|
|||
|
||||
|
||||
int axis = 1;
|
||||
NativeOps nativeOps;
|
||||
nativeOps.sortTadByKey(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false);
|
||||
sortTadByKey(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false);
|
||||
|
||||
ASSERT_EQ(ek, k);
|
||||
ASSERT_EQ(ev, v);
|
||||
|
@ -100,8 +97,7 @@ TEST_F(SortCpuTests, test_tad_sort_by_val_1) {
|
|||
|
||||
|
||||
int axis = 1;
|
||||
NativeOps nativeOps;
|
||||
nativeOps.sortTadByValue(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false);
|
||||
sortTadByValue(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false);
|
||||
|
||||
ASSERT_EQ(ek, k);
|
||||
ASSERT_EQ(ev, v);
|
||||
|
|
|
@ -42,8 +42,7 @@ TEST_F(SortCudaTests, test_linear_sort_by_key_1) {
|
|||
|
||||
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
|
||||
|
||||
NativeOps nativeOps;
|
||||
nativeOps.sortByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);
|
||||
sortByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);
|
||||
k.tickWriteDevice();
|
||||
v.tickWriteDevice();
|
||||
|
||||
|
@ -60,8 +59,7 @@ TEST_F(SortCudaTests, test_linear_sort_by_val_1) {
|
|||
|
||||
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
|
||||
|
||||
NativeOps nativeOps;
|
||||
nativeOps.sortByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);
|
||||
sortByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);
|
||||
k.tickWriteDevice();
|
||||
v.tickWriteDevice();
|
||||
|
||||
|
@ -78,8 +76,7 @@ TEST_F(SortCudaTests, test_linear_sort_by_val_2) {
|
|||
|
||||
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
|
||||
|
||||
NativeOps nativeOps;
|
||||
nativeOps.sortByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), true);
|
||||
sortByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), true);
|
||||
k.tickWriteDevice();
|
||||
v.tickWriteDevice();
|
||||
k.printIndexedBuffer("KEYS");
|
||||
|
@ -97,8 +94,7 @@ TEST_F(SortCudaTests, test_tad_sort_by_key_1) {
|
|||
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
|
||||
|
||||
int axis = 1;
|
||||
NativeOps nativeOps;
|
||||
nativeOps.sortTadByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false);
|
||||
sortTadByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false);
|
||||
k.tickWriteDevice();
|
||||
v.tickWriteDevice();
|
||||
|
||||
|
@ -119,8 +115,7 @@ TEST_F(SortCudaTests, test_tad_sort_by_val_1) {
|
|||
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
|
||||
|
||||
int axis = 1;
|
||||
NativeOps nativeOps;
|
||||
nativeOps.sortTadByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false);
|
||||
sortTadByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false);
|
||||
k.tickWriteDevice();
|
||||
v.tickWriteDevice();
|
||||
|
||||
|
|
|
@ -58,8 +58,7 @@ TEST_F(TypeCastTests, Test_ConvertDtype_1) {
|
|||
float16 dst[5];
|
||||
float16 exp[] = {(float16) 1.0f, (float16) 2.0f, (float16) 3.0f, (float16) 4.0f, (float16) 5.0f};
|
||||
|
||||
NativeOps ops;
|
||||
ops.convertTypes(nullptr, ND4J_FLOAT32, src, 5, ND4J_FLOAT16, dst);
|
||||
convertTypes(nullptr, ND4J_FLOAT32, src, 5, ND4J_FLOAT16, dst);
|
||||
|
||||
for (int e = 0; e < 5; e++)
|
||||
ASSERT_NEAR(exp[e], dst[e], (float16) 0.01f);
|
||||
|
|
Loading…
Reference in New Issue