Refactor NativeOps.h to export C functions

master
Samuel Audet 2019-07-22 20:34:08 +09:00 committed by AlexDBlack
parent fad8da878f
commit dcc72e23b2
23 changed files with 1950 additions and 2089 deletions

View File

@ -77,10 +77,7 @@ bool verbose = false;
#include <graph/ResultWrapper.h> #include <graph/ResultWrapper.h>
#include <DebugInfo.h> #include <DebugInfo.h>
class ND4J_EXPORT NativeOps { extern "C" {
public:
NativeOps();
/** /**
* *
@ -259,7 +256,7 @@ public:
* @param result * @param result
* @param resultShapeInfo * @param resultShapeInfo
*/ */
void execReduceFloat(Nd4jPointer *extraPointers, void execReduceFloat2(Nd4jPointer *extraPointers,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void *dX, Nd4jLong *dXShapeInfo,
@ -270,7 +267,7 @@ public:
void *dDimension, Nd4jLong *dDimensionShape); void *dDimension, Nd4jLong *dDimensionShape);
void execReduceSame(Nd4jPointer *extraPointers, void execReduceSame2(Nd4jPointer *extraPointers,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void *dX, Nd4jLong *dXShapeInfo,
@ -281,7 +278,7 @@ public:
void *dDimension, Nd4jLong *dDimensionShape); void *dDimension, Nd4jLong *dDimensionShape);
void execReduceBool(Nd4jPointer *extraPointers, void execReduceBool2(Nd4jPointer *extraPointers,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void *dX, Nd4jLong *dXShapeInfo,
@ -292,7 +289,7 @@ public:
void *dDimension, Nd4jLong *dDimensionShape); void *dDimension, Nd4jLong *dDimensionShape);
void execReduceLong(Nd4jPointer *extraPointers, void execReduceLong2(Nd4jPointer *extraPointers,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void *dX, Nd4jLong *dXShapeInfo,
@ -354,7 +351,7 @@ public:
* @param dimension * @param dimension
* @param dimensionLength * @param dimensionLength
*/ */
void execReduce3(Nd4jPointer *extraPointers, void execReduce3Tad(Nd4jPointer *extraPointers,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void *dX, Nd4jLong *dXShapeInfo,
@ -457,7 +454,7 @@ public:
* @param dimension * @param dimension
* @param dimensionLength * @param dimensionLength
*/ */
void execSummaryStats(Nd4jPointer *extraPointers, void execSummaryStatsTad(Nd4jPointer *extraPointers,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void *dX, Nd4jLong *dXShapeInfo,
@ -532,7 +529,7 @@ public:
* @param dimension * @param dimension
* @param dimensionLength * @param dimensionLength
*/ */
void execScalar(Nd4jPointer *extraPointers, void execScalarTad(Nd4jPointer *extraPointers,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void *dX, Nd4jLong *dXShapeInfo,
@ -546,7 +543,7 @@ public:
Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ); Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ);
void execScalarBool(Nd4jPointer *extraPointers, void execScalarBoolTad(Nd4jPointer *extraPointers,
int opNum, int opNum,
void *hX, Nd4jLong *hXShapeInfo, void *hX, Nd4jLong *hXShapeInfo,
void *dX, Nd4jLong *dXShapeInfo, void *dX, Nd4jLong *dXShapeInfo,
@ -743,7 +740,7 @@ public:
* Returns amount of free memory for current device * Returns amount of free memory for current device
* @return * @return
*/ */
Nd4jLong getDeviceFreeMemory(); Nd4jLong getDeviceFreeMemoryDefault();
/** /**
* *
@ -789,7 +786,7 @@ public:
* @param reserved * @param reserved
* @return * @return
*/ */
int memcpy(Nd4jPointer dst, int memcpySync(Nd4jPointer dst,
Nd4jPointer src, Nd4jPointer src,
Nd4jLong size, Nd4jLong size,
int flags, int flags,
@ -819,7 +816,7 @@ public:
* @param reserved * @param reserved
* @return * @return
*/ */
int memset(Nd4jPointer dst, int memsetSync(Nd4jPointer dst,
int value, int value,
Nd4jLong size, Nd4jLong size,
int flags, int flags,
@ -1058,8 +1055,7 @@ public:
nd4j::DataType dtype); nd4j::DataType dtype);
template <typename T> void batchExecutor(Nd4jPointer *extraPointers,
void _batchExecutor(Nd4jPointer *extraPointers,
int numAggregates, int numAggregates,
int opNum, int opNum,
int maxArgs, int maxArgs,
@ -1116,7 +1112,7 @@ public:
* @param zShapeBuffer * @param zShapeBuffer
* @param extraArguments * @param extraArguments
*/ */
void execRandom(Nd4jPointer *extraPointers, void execRandom3(Nd4jPointer *extraPointers,
int opNum, int opNum,
Nd4jPointer state, Nd4jPointer state,
void *hX, Nd4jLong *hXShapeBuffer, void *hX, Nd4jLong *hXShapeBuffer,
@ -1138,7 +1134,7 @@ public:
* @param zShapeBuffer * @param zShapeBuffer
* @param extraArguments * @param extraArguments
*/ */
void execRandom(Nd4jPointer *extraPointers, void execRandom2(Nd4jPointer *extraPointers,
int opNum, int opNum,
Nd4jPointer state, Nd4jPointer state,
void *hX, Nd4jLong *hXShapeBuffer, void *hX, Nd4jLong *hXShapeBuffer,
@ -1232,6 +1228,9 @@ public:
double scalarB); double scalarB);
*/ */
}
/** /**
* *
* @param data * @param data
@ -1267,7 +1266,9 @@ public:
return reinterpret_cast<Nd4jPointer>(ret); return reinterpret_cast<Nd4jPointer>(ret);
} }
Nd4jPointer numpyHeaderForNd4j(Nd4jPointer data,Nd4jPointer shapeBuffer,Nd4jLong wordSize,Nd4jLong *headerSize) { extern "C" {
static Nd4jPointer numpyHeaderForNd4j(Nd4jPointer data,Nd4jPointer shapeBuffer,Nd4jLong wordSize,Nd4jLong *headerSize) {
auto shapeBufferCast = reinterpret_cast<Nd4jLong *>(shapeBuffer); auto shapeBufferCast = reinterpret_cast<Nd4jLong *>(shapeBuffer);
auto type = nd4j::ArrayOptions::dataType(shapeBufferCast); auto type = nd4j::ArrayOptions::dataType(shapeBufferCast);
BUILD_SINGLE_SELECTOR(type, return _numpyHeaderForNd4j, (data, shapeBuffer, wordSize, headerSize), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR(type, return _numpyHeaderForNd4j, (data, shapeBuffer, wordSize, headerSize), LIBND4J_TYPES);
@ -1279,7 +1280,7 @@ public:
* @param data the header data to parse * @param data the header data to parse
* @return a pointer to a numpy cnpy:NpyArray struct * @return a pointer to a numpy cnpy:NpyArray struct
*/ */
Nd4jPointer loadNpyFromHeader(Nd4jPointer data) { static Nd4jPointer loadNpyFromHeader(Nd4jPointer data) {
char *header = reinterpret_cast<char *>(data); char *header = reinterpret_cast<char *>(data);
cnpy::NpyArray arr = cnpy::loadNpyFromHeader(header); cnpy::NpyArray arr = cnpy::loadNpyFromHeader(header);
@ -1295,6 +1296,7 @@ public:
return reinterpret_cast<Nd4jPointer>(ret); return reinterpret_cast<Nd4jPointer>(ret);
} }
}
/** /**
* Create a numpy array from an nd4j * Create a numpy array from an nd4j
@ -1329,8 +1331,9 @@ public:
return rettPointer; return rettPointer;
} }
extern "C" {
Nd4jPointer numpyFromNd4j(Nd4jPointer data,Nd4jPointer shapeBuffer,Nd4jLong wordSize) { static Nd4jPointer numpyFromNd4j(Nd4jPointer data,Nd4jPointer shapeBuffer,Nd4jLong wordSize) {
auto shapeBufferCast = reinterpret_cast<Nd4jLong *>(shapeBuffer); auto shapeBufferCast = reinterpret_cast<Nd4jLong *>(shapeBuffer);
auto type = nd4j::ArrayOptions::dataType(shapeBufferCast); auto type = nd4j::ArrayOptions::dataType(shapeBufferCast);
BUILD_SINGLE_SELECTOR(type, return _numpyFromNd4j, (data, shapeBuffer, wordSize), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR(type, return _numpyFromNd4j, (data, shapeBuffer, wordSize), LIBND4J_TYPES);
@ -1352,7 +1355,7 @@ public:
* @param npyArray * @param npyArray
* @return * @return
*/ */
Nd4jPointer shapeBufferForNumpyHeader(Nd4jPointer npyArray) { static Nd4jPointer shapeBufferForNumpyHeader(Nd4jPointer npyArray) {
cnpy::NpyArray arr = cnpy::loadNpyFromHeader(reinterpret_cast<char *>(npyArray)); cnpy::NpyArray arr = cnpy::loadNpyFromHeader(reinterpret_cast<char *>(npyArray));
auto shape = new unsigned int[arr.shape.size()]; auto shape = new unsigned int[arr.shape.size()];
for(unsigned int i = 0; i < arr.shape.size(); i++) { for(unsigned int i = 0; i < arr.shape.size(); i++) {
@ -1371,7 +1374,7 @@ public:
* @param npyArray * @param npyArray
* @return * @return
*/ */
Nd4jPointer dataPointForNumpyHeader(Nd4jPointer npyArray) { static Nd4jPointer dataPointForNumpyHeader(Nd4jPointer npyArray) {
cnpy::NpyArray arr = cnpy::loadNpyFromHeader(reinterpret_cast<char *>(npyArray)); cnpy::NpyArray arr = cnpy::loadNpyFromHeader(reinterpret_cast<char *>(npyArray));
unsigned char *dataToPrint = reinterpret_cast<unsigned char *>(arr.data); unsigned char *dataToPrint = reinterpret_cast<unsigned char *>(arr.data);
return dataToPrint; return dataToPrint;
@ -1382,7 +1385,7 @@ public:
* @param npyArray * @param npyArray
* @return * @return
*/ */
Nd4jPointer dataPointForNumpyStruct(Nd4jPointer npyArrayStruct) { static Nd4jPointer dataPointForNumpyStruct(Nd4jPointer npyArrayStruct) {
cnpy::NpyArray *arrPointer = reinterpret_cast<cnpy::NpyArray *>(npyArrayStruct); cnpy::NpyArray *arrPointer = reinterpret_cast<cnpy::NpyArray *>(npyArrayStruct);
unsigned char *dataToPrint = reinterpret_cast<unsigned char *>(arrPointer->data); unsigned char *dataToPrint = reinterpret_cast<unsigned char *>(arrPointer->data);
return reinterpret_cast<Nd4jPointer>(dataToPrint); return reinterpret_cast<Nd4jPointer>(dataToPrint);
@ -1394,7 +1397,7 @@ public:
* @param fromFile * @param fromFile
* @return * @return
*/ */
Nd4jPointer dataPointForNumpy(Nd4jPointer npyArray) { static Nd4jPointer dataPointForNumpy(Nd4jPointer npyArray) {
char *npyArrayBuffer = reinterpret_cast< char *>(npyArray); char *npyArrayBuffer = reinterpret_cast< char *>(npyArray);
cnpy::NpyArray arr = cnpy::loadNpyFromPointer(npyArrayBuffer); cnpy::NpyArray arr = cnpy::loadNpyFromPointer(npyArrayBuffer);
return dataPointForNumpyStruct(reinterpret_cast<Nd4jPointer>(&arr)); return dataPointForNumpyStruct(reinterpret_cast<Nd4jPointer>(&arr));
@ -1406,7 +1409,7 @@ public:
* @param path * @param path
* @return * @return
*/ */
Nd4jPointer numpyFromFile(std::string path) { static Nd4jPointer numpyFromFile(std::string path) {
char *numpyBuffer = cnpy::loadFile(path.data()); char *numpyBuffer = cnpy::loadFile(path.data());
return reinterpret_cast<Nd4jPointer >(numpyBuffer); return reinterpret_cast<Nd4jPointer >(numpyBuffer);
} }
@ -1414,7 +1417,7 @@ public:
////// NPZ ////// ////// NPZ //////
void* mapFromNpzFile(std::string path){ static void* mapFromNpzFile(std::string path){
cnpy::npz_t* mapPtr = new cnpy::npz_t(); cnpy::npz_t* mapPtr = new cnpy::npz_t();
cnpy::npz_t map = cnpy::npzLoad(path); cnpy::npz_t map = cnpy::npzLoad(path);
mapPtr->insert(map.begin(), map.end()); mapPtr->insert(map.begin(), map.end());
@ -1422,13 +1425,13 @@ public:
} }
int getNumNpyArraysInMap(void *map){ static int getNumNpyArraysInMap(void *map){
cnpy::npz_t* arrays = reinterpret_cast<cnpy::npz_t*>(map); cnpy::npz_t* arrays = reinterpret_cast<cnpy::npz_t*>(map);
int n = arrays->size(); int n = arrays->size();
return n; return n;
} }
const char* getNpyArrayNameFromMap(void *map, int index){ static const char* getNpyArrayNameFromMap(void *map, int index){
cnpy::npz_t* arrays = reinterpret_cast<cnpy::npz_t*>(map); cnpy::npz_t* arrays = reinterpret_cast<cnpy::npz_t*>(map);
cnpy::npz_t::iterator it = arrays->begin(); cnpy::npz_t::iterator it = arrays->begin();
cnpy::npz_t::iterator end = arrays->end(); cnpy::npz_t::iterator end = arrays->end();
@ -1442,7 +1445,7 @@ public:
throw std::runtime_error("No array at index."); throw std::runtime_error("No array at index.");
} }
void* getNpyArrayFromMap(void *map, int index){ static void* getNpyArrayFromMap(void *map, int index){
cnpy::npz_t* arrays = reinterpret_cast<cnpy::npz_t*>(map); cnpy::npz_t* arrays = reinterpret_cast<cnpy::npz_t*>(map);
cnpy::npz_t::iterator it = arrays->begin(); cnpy::npz_t::iterator it = arrays->begin();
cnpy::npz_t::iterator end = arrays->end(); cnpy::npz_t::iterator end = arrays->end();
@ -1459,18 +1462,18 @@ public:
int dataTypeFromNpyHeader(void *header); int dataTypeFromNpyHeader(void *header);
void* getNpyArrayData(void *npArray){ static void* getNpyArrayData(void *npArray){
cnpy::NpyArray* npyArray2 = reinterpret_cast<cnpy::NpyArray*>(npArray); cnpy::NpyArray* npyArray2 = reinterpret_cast<cnpy::NpyArray*>(npArray);
return reinterpret_cast<void*>(npyArray2->data); return reinterpret_cast<void*>(npyArray2->data);
} }
int getNpyArrayRank(void *npArray){ static int getNpyArrayRank(void *npArray){
cnpy::NpyArray* arr = reinterpret_cast<cnpy::NpyArray*>(npArray); cnpy::NpyArray* arr = reinterpret_cast<cnpy::NpyArray*>(npArray);
int rank = arr->shape.size(); int rank = arr->shape.size();
return rank; return rank;
} }
Nd4jLong* getNpyArrayShape(void *npArray){ static Nd4jLong* getNpyArrayShape(void *npArray){
cnpy::NpyArray* arr = reinterpret_cast<cnpy::NpyArray*>(npArray); cnpy::NpyArray* arr = reinterpret_cast<cnpy::NpyArray*>(npArray);
int ndim = arr->shape.size(); int ndim = arr->shape.size();
Nd4jLong* shape = new Nd4jLong[ndim]; Nd4jLong* shape = new Nd4jLong[ndim];
@ -1480,22 +1483,22 @@ public:
return shape; return shape;
} }
char getNpyArrayOrder(void *npArray){ static char getNpyArrayOrder(void *npArray){
cnpy::NpyArray* arr = reinterpret_cast<cnpy::NpyArray*>(npArray); cnpy::NpyArray* arr = reinterpret_cast<cnpy::NpyArray*>(npArray);
return (arr->fortranOrder)?'f':'c'; return (arr->fortranOrder)?'f':'c';
} }
int getNpyArrayElemSize(void *npArray){ static int getNpyArrayElemSize(void *npArray){
cnpy::NpyArray* arr = reinterpret_cast<cnpy::NpyArray*>(npArray); cnpy::NpyArray* arr = reinterpret_cast<cnpy::NpyArray*>(npArray);
return arr->wordSize; return arr->wordSize;
} }
void deleteNPArrayStruct(void *npArray){ static void deleteNPArrayStruct(void *npArray){
cnpy::NpyArray* arr = reinterpret_cast<cnpy::NpyArray*>(npArray); cnpy::NpyArray* arr = reinterpret_cast<cnpy::NpyArray*>(npArray);
delete arr; delete arr;
} }
void deleteNPArrayMap(void *map){ static void deleteNPArrayMap(void *map){
cnpy::npz_t* arrays = reinterpret_cast<cnpy::npz_t*>(map); cnpy::npz_t* arrays = reinterpret_cast<cnpy::npz_t*>(map);
delete arrays; delete arrays;
} }
@ -1507,7 +1510,7 @@ public:
* to get the length for * to get the length for
* @return * @return
*/ */
int elementSizeForNpyArray(Nd4jPointer npyArray) { static int elementSizeForNpyArray(Nd4jPointer npyArray) {
cnpy::NpyArray arr = cnpy::loadNpyFromPointer(reinterpret_cast<char *>(npyArray)); cnpy::NpyArray arr = cnpy::loadNpyFromPointer(reinterpret_cast<char *>(npyArray));
cnpy::NpyArray *arrPointer = &arr; cnpy::NpyArray *arrPointer = &arr;
int size = arrPointer->wordSize; int size = arrPointer->wordSize;
@ -1522,7 +1525,7 @@ public:
* to get the length for * to get the length for
* @return * @return
*/ */
int elementSizeForNpyArrayHeader(Nd4jPointer npyArray) { static int elementSizeForNpyArrayHeader(Nd4jPointer npyArray) {
cnpy::NpyArray arr = cnpy::loadNpyFromHeader(reinterpret_cast<char *>(npyArray)); cnpy::NpyArray arr = cnpy::loadNpyFromHeader(reinterpret_cast<char *>(npyArray));
cnpy::NpyArray *arrPointer = &arr; cnpy::NpyArray *arrPointer = &arr;
int size = arrPointer->wordSize; int size = arrPointer->wordSize;
@ -1530,7 +1533,7 @@ public:
} }
void releaseNumpy(Nd4jPointer npyArray) { static void releaseNumpy(Nd4jPointer npyArray) {
free(reinterpret_cast<void *>(npyArray)); free(reinterpret_cast<void *>(npyArray));
} }
@ -1647,10 +1650,10 @@ public:
// customOp executioner // customOp executioner
int execCustomOp(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputs, Nd4jPointer* outputBuffers, Nd4jPointer* outputShapes, int numOutputs, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool* bArgs, int numBArgs, bool isInplace); int execCustomOp(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputs, Nd4jPointer* outputBuffers, Nd4jPointer* outputShapes, int numOutputs, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool* bArgs, int numBArgs, bool isInplace);
int execCustomOp(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer opContext); int execCustomOp2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer opContext);
nd4j::ShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs); nd4j::ShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs);
nd4j::ShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs); nd4j::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs);
void deleteShapeList(Nd4jPointer shapeList); void deleteShapeList(Nd4jPointer shapeList);
@ -1690,25 +1693,22 @@ public:
void* dY, Nd4jLong* dYShapeInfo, Nd4jLong* dYOffsets, void* dY, Nd4jLong* dYShapeInfo, Nd4jLong* dYOffsets,
int* hIindexes, int* dIindexes); int* hIindexes, int* dIindexes);
void deleteShapeBuffer(Nd4jPointer ptr);
void deleteTadPack(Nd4jPointer ptr);
void inspectArray(Nd4jPointer *extraPointers, Nd4jPointer buffer, Nd4jLong *shapeInfo, Nd4jPointer specialBuffer, Nd4jLong *specialShapeInfo, Nd4jPointer debugInfo); void inspectArray(Nd4jPointer *extraPointers, Nd4jPointer buffer, Nd4jLong *shapeInfo, Nd4jPointer specialBuffer, Nd4jLong *specialShapeInfo, Nd4jPointer debugInfo);
nd4j::ConstantDataBuffer* shapeBuffer(int rank, Nd4jLong *shape, Nd4jLong *strides, nd4j::DataType dtype, char order, Nd4jLong ews, bool empty); nd4j::ConstantDataBuffer* shapeBuffer(int rank, Nd4jLong *shape, Nd4jLong *strides, nd4j::DataType dtype, char order, Nd4jLong ews, bool empty);
nd4j::ConstantDataBuffer* constantBuffer(nd4j::DataType dtype, Nd4jLong *data, int length); nd4j::ConstantDataBuffer* constantBufferLong(nd4j::DataType dtype, Nd4jLong *data, int length);
nd4j::ConstantDataBuffer* constantBuffer(nd4j::DataType dtype, double *data, int length); nd4j::ConstantDataBuffer* constantBufferDouble(nd4j::DataType dtype, double *data, int length);
nd4j::ConstantDataBuffer* constantBuffer(nd4j::DataType dtype, nd4j::ConstantDescriptor *descriptor); nd4j::ConstantDataBuffer* constantBuffer(nd4j::DataType dtype, nd4j::ConstantDescriptor *descriptor);
void deleteShapeBuffer(Nd4jPointer ptr);
void deleteTadPack(Nd4jPointer ptr);
const char* runLightBenchmarkSuit(bool printOut); const char* runLightBenchmarkSuit(bool printOut);
const char* runFullBenchmarkSuit(bool printOut); const char* runFullBenchmarkSuit(bool printOut);
};
}
#endif //NATIVEOPERATIONS_NATIVEOPS_H #endif //NATIVEOPERATIONS_NATIVEOPS_H

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -27,11 +27,10 @@ namespace nd4j {
ProviderRNG::ProviderRNG() { ProviderRNG::ProviderRNG() {
Nd4jLong *buffer = new Nd4jLong[100000]; Nd4jLong *buffer = new Nd4jLong[100000];
NativeOps nativeOps;
std::lock_guard<std::mutex> lock(_mutex); std::lock_guard<std::mutex> lock(_mutex);
#ifndef __CUDABLAS__ #ifndef __CUDABLAS__
// at this moment we don't have streams etc, so let's just skip this for now // at this moment we don't have streams etc, so let's just skip this for now
_rng = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, 123, 100000, (Nd4jPointer) buffer); _rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, 123, 100000, (Nd4jPointer) buffer);
#endif #endif
// if(_rng != nullptr) // if(_rng != nullptr)
} }

View File

@ -41,8 +41,7 @@ namespace nd4j {
} }
// FIXME: this approach isn't really good for cuda, since it'll assume that CUDA might get nullptr instead of stream // FIXME: this approach isn't really good for cuda, since it'll assume that CUDA might get nullptr instead of stream
NativeOps nativeOps; refreshBuffer(nullptr, seed, (Nd4jPointer) rng);
nativeOps.refreshBuffer(nullptr, seed, (Nd4jPointer) rng);
return Status::OK(); return Status::OK();
} }

View File

@ -110,11 +110,9 @@ namespace helpers {
indices->syncToDevice(); // linspace only on CPU, so sync to Device as well indices->syncToDevice(); // linspace only on CPU, so sync to Device as well
NDArray scores(*scales); NDArray scores(*scales);
NativeOps nativeOps;
Nd4jPointer extras[2] = {nullptr, stream}; Nd4jPointer extras[2] = {nullptr, stream};
nativeOps.sortByValue(extras, indices->buffer(), indices->shapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), scores.buffer(), scores.shapeInfo(), scores.specialBuffer(), scores.specialShapeInfo(), true); sortByValue(extras, indices->buffer(), indices->shapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), scores.buffer(), scores.shapeInfo(), scores.specialBuffer(), scores.specialShapeInfo(), true);
// TO DO: sort indices using scales as value row // TO DO: sort indices using scales as value row
//std::sort(indices.begin(), indices.end(), [scales](int i, int j) {return scales->e<T>(i) > scales->e<T>(j);}); //std::sort(indices.begin(), indices.end(), [scales](int i, int j) {return scales->e<T>(i) > scales->e<T>(j);});
I* indexBuf = reinterpret_cast<I*>(indices->specialBuffer()); I* indexBuf = reinterpret_cast<I*>(indices->specialBuffer());

View File

@ -60,8 +60,7 @@ namespace helpers {
params[1] = context->getCudaStream(); params[1] = context->getCudaStream();
if (input->isVector()) { if (input->isVector()) {
NativeOps ops; sort(params, nullptr, sortedVals.shapeInfo(), sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), reverse);
ops.sort(params, nullptr, sortedVals.shapeInfo(), sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), reverse);
cudaMemcpy(reinterpret_cast<T*>(output->specialBuffer()), reinterpret_cast<T*>(sortedVals.specialBuffer()) + n, sizeof(T), cudaMemcpyDeviceToDevice); cudaMemcpy(reinterpret_cast<T*>(output->specialBuffer()), reinterpret_cast<T*>(sortedVals.specialBuffer()) + n, sizeof(T), cudaMemcpyDeviceToDevice);
} }
@ -74,8 +73,7 @@ namespace helpers {
auto pTadShapeH = packX.primaryShapeInfo(); auto pTadShapeH = packX.primaryShapeInfo();
auto pTadOffsets = packX.specialOffsets(); auto pTadOffsets = packX.specialOffsets();
// auto pLastDimData = (int*) manager.replicatePointer(lastDims.data(), lastDims.size() * sizeof(int)); // auto pLastDimData = (int*) manager.replicatePointer(lastDims.data(), lastDims.size() * sizeof(int));
NativeOps ops; sortTad(params, sortedVals.buffer(), sortedVals.shapeInfo(), sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), lastDims.data(), lastDims.size(), pTadShape, pTadOffsets, reverse);
ops.sortTad(params, sortedVals.buffer(), sortedVals.shapeInfo(), sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), lastDims.data(), lastDims.size(), pTadShape, pTadOffsets, reverse);
// manager.synchronize(); // manager.synchronize();
sortedVals.tickWriteDevice(); sortedVals.tickWriteDevice();
sortedVals.syncToHost(); sortedVals.syncToHost();

View File

@ -38,32 +38,28 @@ TEST_F(HeaderTest, test_dataTypes_1) {
std::string header("0NUMPY6789{'descr': '>f4"); std::string header("0NUMPY6789{'descr': '>f4");
NativeOps nativeOps; ASSERT_EQ(nd4j::DataType::FLOAT32, dataTypeFromNpyHeader(const_cast<char *>(header.data())));
ASSERT_EQ(nd4j::DataType::FLOAT32, nativeOps.dataTypeFromNpyHeader(const_cast<char *>(header.data())));
} }
TEST_F(HeaderTest, test_dataTypes_2) { TEST_F(HeaderTest, test_dataTypes_2) {
std::string header("0NUMPY6789{'descr': '>f8"); std::string header("0NUMPY6789{'descr': '>f8");
NativeOps nativeOps; ASSERT_EQ(nd4j::DataType::DOUBLE, dataTypeFromNpyHeader(const_cast<char *>(header.data())));
ASSERT_EQ(nd4j::DataType::DOUBLE, nativeOps.dataTypeFromNpyHeader(const_cast<char *>(header.data())));
} }
TEST_F(HeaderTest, test_dataTypes_3) { TEST_F(HeaderTest, test_dataTypes_3) {
std::string header("0NUMPY6789{'descr': '<i4"); std::string header("0NUMPY6789{'descr': '<i4");
NativeOps nativeOps; ASSERT_EQ(nd4j::DataType::INT32, dataTypeFromNpyHeader(const_cast<char *>(header.data())));
ASSERT_EQ(nd4j::DataType::INT32, nativeOps.dataTypeFromNpyHeader(const_cast<char *>(header.data())));
} }
TEST_F(HeaderTest, test_dataTypes_4) { TEST_F(HeaderTest, test_dataTypes_4) {
std::string header("0NUMPY6789{'descr': '>u2"); std::string header("0NUMPY6789{'descr': '>u2");
NativeOps nativeOps; ASSERT_EQ(nd4j::DataType::UINT16, dataTypeFromNpyHeader(const_cast<char *>(header.data())));
ASSERT_EQ(nd4j::DataType::UINT16, nativeOps.dataTypeFromNpyHeader(const_cast<char *>(header.data())));
} }
/* /*
@ -88,8 +84,7 @@ TEST_F(LoadFromStringTest,PathTest) {
ASSERT_EQ(4.0,data[3]); ASSERT_EQ(4.0,data[3]);
Nd4jPointer pointer = reinterpret_cast<Nd4jPointer >(&loadedArr); Nd4jPointer pointer = reinterpret_cast<Nd4jPointer >(&loadedArr);
int *shapeBuffer = shape::shapeBufferOfNpy(loadedArr); int *shapeBuffer = shape::shapeBufferOfNpy(loadedArr);
NativeOps nativeOps; Nd4jPointer pointer1 = dataPointForNumpy(loaded);
Nd4jPointer pointer1 = nativeOps.dataPointForNumpy(loaded);
delete[] shapeBuffer; delete[] shapeBuffer;
double *data2 = reinterpret_cast<double *>(pointer1); double *data2 = reinterpret_cast<double *>(pointer1);

View File

@ -472,9 +472,7 @@ TEST_F(DeclarableOpsTests1, TestRng1) {
/* /*
Nd4jLong *buffer = new Nd4jLong[100000]; Nd4jLong *buffer = new Nd4jLong[100000];
NativeOps nativeOps; nd4j::random::RandomBuffer *rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, 123, 100000, (Nd4jPointer) buffer);
nd4j::random::RandomBuffer *rng = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, 123, 100000, (Nd4jPointer) buffer);
if (rng == nullptr) if (rng == nullptr)
throw std::runtime_error("RNG initialization failed"); throw std::runtime_error("RNG initialization failed");
@ -496,7 +494,7 @@ TEST_F(DeclarableOpsTests1, TestRng1) {
ASSERT_TRUE(x->sumNumber() > 0.0); ASSERT_TRUE(x->sumNumber() > 0.0);
nativeOps.destroyRandom((Nd4jPointer) rng); destroyRandom((Nd4jPointer) rng);
delete[] buffer; delete[] buffer;
delete variableSpace; delete variableSpace;
@ -1450,8 +1448,6 @@ TEST_F(DeclarableOpsTests1, TestRegistrator1) {
// ////////////////////////////////////////////////////////////////////// // //////////////////////////////////////////////////////////////////////
// TEST_F(DeclarableOpsTests1, TestLegacyExecution1) { // TEST_F(DeclarableOpsTests1, TestLegacyExecution1) {
// NativeOps nativeOps;
// auto x = NDArrayFactory::create_<float>('c', {10, 10}); // auto x = NDArrayFactory::create_<float>('c', {10, 10});
// x->assign(1.0f); // x->assign(1.0f);
@ -1483,8 +1479,8 @@ TEST_F(DeclarableOpsTests1, TestRegistrator1) {
// outputShapes[0] = (Nd4jPointer) z->getShapeInfo(); // outputShapes[0] = (Nd4jPointer) z->getShapeInfo();
// //auto status = nativeOps.execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, false); // //auto status = execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, false);
// auto status = nativeOps.execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); // auto status = execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
// ASSERT_EQ(ND4J_STATUS_OK, status); // ASSERT_EQ(ND4J_STATUS_OK, status);
// // z->printIndexedBuffer("Output add"); // // z->printIndexedBuffer("Output add");
// ASSERT_NEAR(2.0f, y->meanNumber().e<float>(0), 1e-5); // ASSERT_NEAR(2.0f, y->meanNumber().e<float>(0), 1e-5);
@ -1503,8 +1499,6 @@ TEST_F(DeclarableOpsTests1, TestRegistrator1) {
// ////////////////////////////////////////////////////////////////////// // //////////////////////////////////////////////////////////////////////
// TEST_F(DeclarableOpsTests1, TestLegacyExecution2) { // TEST_F(DeclarableOpsTests1, TestLegacyExecution2) {
// NativeOps nativeOps;
// auto x = NDArrayFactory::create_<float>('c', {10, 10}); // auto x = NDArrayFactory::create_<float>('c', {10, 10});
// x->assign(1.0f); // x->assign(1.0f);
@ -1532,7 +1526,7 @@ TEST_F(DeclarableOpsTests1, TestRegistrator1) {
// auto outputBuffers = new Nd4jPointer[1]; // auto outputBuffers = new Nd4jPointer[1];
// auto outputShapes = new Nd4jPointer[1]; // auto outputShapes = new Nd4jPointer[1];
// nativeOps.execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, true); // execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, true);
// ASSERT_NEAR(2.0, y->meanNumber().e<float>(0), 1e-5); // ASSERT_NEAR(2.0, y->meanNumber().e<float>(0), 1e-5);
// ASSERT_NEAR(3.0, x->meanNumber().e<float>(0), 1e-5); // ASSERT_NEAR(3.0, x->meanNumber().e<float>(0), 1e-5);

View File

@ -876,14 +876,13 @@ TEST_F(DeclarableOpsTests12, pullRows_1) {
auto xTadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.getShapeInfo(), dims); auto xTadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.getShapeInfo(), dims);
auto zTadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(z.getShapeInfo(), dims); auto zTadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(z.getShapeInfo(), dims);
NativeOps op;
Nd4jPointer nativeStart[2]; Nd4jPointer nativeStart[2];
#ifdef __CUDABLAS__ #ifdef __CUDABLAS__
nativeStart[1] = *(x.getContext()->getCudaStream()); nativeStart[1] = *(x.getContext()->getCudaStream());
#endif #endif
op.pullRows(nativeStart, x.buffer(), x.getShapeInfo(), x.getSpecialBuffer(), x.getSpecialShapeInfo(), pullRows(nativeStart, x.buffer(), x.getShapeInfo(), x.getSpecialBuffer(), x.getSpecialShapeInfo(),
z.buffer(), z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(), z.buffer(), z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
4, pidx, 4, pidx,
xTadPack.platformShapeInfo(), xTadPack.platformOffsets(), xTadPack.platformShapeInfo(), xTadPack.platformOffsets(),
@ -912,12 +911,11 @@ TEST_F(DeclarableOpsTests12, pullRows_2) {
auto xTadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.getShapeInfo(), dims); auto xTadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.getShapeInfo(), dims);
auto zTadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(z.getShapeInfo(), dims); auto zTadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(z.getShapeInfo(), dims);
NativeOps op;
Nd4jPointer nativeStart[2]; Nd4jPointer nativeStart[2];
#ifdef __CUDABLAS__ #ifdef __CUDABLAS__
nativeStart[1] = *(x.getContext()->getCudaStream()); nativeStart[1] = *(x.getContext()->getCudaStream());
#endif #endif
op.pullRows(nativeStart, x.buffer(), x.getShapeInfo(), x.specialBuffer(), x.specialShapeInfo(), pullRows(nativeStart, x.buffer(), x.getShapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
z.buffer(), z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(), z.buffer(), z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
4, pidx, 4, pidx,
xTadPack.platformShapeInfo(), xTadPack.platformOffsets(), xTadPack.platformShapeInfo(), xTadPack.platformOffsets(),

View File

@ -110,8 +110,7 @@ TEST_F(DeclarableOpsTests9, exponentialDistributionInv_test1) {
double extraParams[] = {lambda}; double extraParams[] = {lambda};
Nd4jLong *buffer = new Nd4jLong[N]; Nd4jLong *buffer = new Nd4jLong[N];
NativeOps nativeOps; auto rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
auto rng = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
if (rng == nullptr) if (rng == nullptr)
throw std::runtime_error("DeclarableOpsTests9.exponentialDistributionInv_test1: RNG initialization failed !"); throw std::runtime_error("DeclarableOpsTests9.exponentialDistributionInv_test1: RNG initialization failed !");
@ -122,7 +121,7 @@ TEST_F(DeclarableOpsTests9, exponentialDistributionInv_test1) {
ASSERT_NEAR(mean, actualMean, 0.01); ASSERT_NEAR(mean, actualMean, 0.01);
ASSERT_NEAR(std, actualStd, 0.01); ASSERT_NEAR(std, actualStd, 0.01);
nativeOps.destroyRandom((Nd4jPointer) rng); destroyRandom((Nd4jPointer) rng);
delete[] buffer; delete[] buffer;
} }
@ -142,8 +141,7 @@ TEST_F(DeclarableOpsTests9, exponentialDistributionInv_test2) {
Nd4jLong *buffer = new Nd4jLong[N]; Nd4jLong *buffer = new Nd4jLong[N];
NativeOps nativeOps; auto rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
auto rng = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
if (rng == nullptr) if (rng == nullptr)
throw std::runtime_error("DeclarableOpsTests9.exponentialDistributionInv_test2: RNG initialization failed !"); throw std::runtime_error("DeclarableOpsTests9.exponentialDistributionInv_test2: RNG initialization failed !");
@ -155,7 +153,7 @@ TEST_F(DeclarableOpsTests9, exponentialDistributionInv_test2) {
ASSERT_NEAR(mean, actualMean, 0.01); ASSERT_NEAR(mean, actualMean, 0.01);
ASSERT_NEAR(std, actualStd, 0.01); ASSERT_NEAR(std, actualStd, 0.01);
nativeOps.destroyRandom((Nd4jPointer) rng); destroyRandom((Nd4jPointer) rng);
delete[] buffer; delete[] buffer;
} }
@ -172,8 +170,7 @@ TEST_F(DeclarableOpsTests9, exponentialDistribution_test1) {
double extraParams[] = {lambda}; double extraParams[] = {lambda};
Nd4jLong *buffer = new Nd4jLong[N]; Nd4jLong *buffer = new Nd4jLong[N];
NativeOps nativeOps; auto rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
auto rng = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
if (rng == nullptr) if (rng == nullptr)
throw std::runtime_error("DeclarableOpsTests9.exponentialDistribution_test1: RNG initialization failed !"); throw std::runtime_error("DeclarableOpsTests9.exponentialDistribution_test1: RNG initialization failed !");
@ -184,7 +181,7 @@ TEST_F(DeclarableOpsTests9, exponentialDistribution_test1) {
ASSERT_NEAR(mean, actualMean, 0.01); ASSERT_NEAR(mean, actualMean, 0.01);
ASSERT_NEAR(std, actualStd, 0.01); ASSERT_NEAR(std, actualStd, 0.01);
nativeOps.destroyRandom((Nd4jPointer) rng); destroyRandom((Nd4jPointer) rng);
delete[] buffer; delete[] buffer;
} }
*/ */
@ -206,14 +203,13 @@ TEST_F(DeclarableOpsTests9, exponentialDistribution_test2) {
Nd4jLong *buffer = new Nd4jLong[N]; Nd4jLong *buffer = new Nd4jLong[N];
// Nd4jPointer extra[2]; // Nd4jPointer extra[2];
#ifndef __CUDABLAS__ #ifndef __CUDABLAS__
NativeOps nativeOps; nd4j::random::RandomBuffer* rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
nd4j::random::RandomBuffer* rng = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
if (rng == nullptr) if (rng == nullptr)
throw std::runtime_error("DeclarableOpsTests9.exponentialDistribution_test2: RNG initialization failed !"); throw std::runtime_error("DeclarableOpsTests9.exponentialDistribution_test2: RNG initialization failed !");
functions::random::RandomFunction<double>::template execTransform<randomOps::ExponentialDistribution<double>>(rng, y.getBuffer(), y.getShapeInfo(), x.getBuffer(), x.getShapeInfo(), extraParams); functions::random::RandomFunction<double>::template execTransform<randomOps::ExponentialDistribution<double>>(rng, y.getBuffer(), y.getShapeInfo(), x.getBuffer(), x.getShapeInfo(), extraParams);
nativeOps.destroyRandom((Nd4jPointer) rng); destroyRandom((Nd4jPointer) rng);
#endif #endif
const double actualMean = x.meanNumber().e<double>(0); const double actualMean = x.meanNumber().e<double>(0);
const double actualStd = x.varianceNumber(variance::SummaryStatsStandardDeviation, true).e<double>(0); const double actualStd = x.varianceNumber(variance::SummaryStatsStandardDeviation, true).e<double>(0);
@ -1005,12 +1001,10 @@ TEST_F(DeclarableOpsTests9, Test_DropoutInverted_01) {
x0.linspace(1); x0.linspace(1);
x1.linspace(1); x1.linspace(1);
/* /*
NativeOps nativeOps;
float prob[] = {0.5f}; float prob[] = {0.5f};
Nd4jLong* _bufferA = new Nd4jLong[100000]; Nd4jLong* _bufferA = new Nd4jLong[100000];
long _seed = 119L; long _seed = 119L;
auto _rngA = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, _seed, 100000, (Nd4jPointer) _bufferA); auto _rngA = (nd4j::random::RandomBuffer *) initRandom(nullptr, _seed, 100000, (Nd4jPointer) _bufferA);
x0. applyTransform(random::DropOutInverted, &x0, prob); x0. applyTransform(random::DropOutInverted, &x0, prob);
// x1.template applyRandom<randomOps::DropOutInverted<float>>(_rngB, nullptr, &x1, prob); // x1.template applyRandom<randomOps::DropOutInverted<float>>(_rngB, nullptr, &x1, prob);
@ -1026,7 +1020,7 @@ TEST_F(DeclarableOpsTests9, Test_DropoutInverted_01) {
// ASSERT_FALSE(x0.equalsTo(nexp0)); // ASSERT_FALSE(x0.equalsTo(nexp0));
// ASSERT_FALSE(x0.equalsTo(nexp1)); // ASSERT_FALSE(x0.equalsTo(nexp1));
// ASSERT_FALSE(x0.equalsTo(nexp2)); // ASSERT_FALSE(x0.equalsTo(nexp2));
nativeOps.destroyRandom(_rngA); destroyRandom(_rngA);
delete [] _bufferA; delete [] _bufferA;
*/ */
nd4j::ops::dropout op; nd4j::ops::dropout op;

View File

@ -51,9 +51,7 @@ public:
*/ */
TEST_F(GraphStateTests, Basic_Tests_1) { TEST_F(GraphStateTests, Basic_Tests_1) {
NativeOps nativeOps; auto state = (GraphState *) getGraphState(117L);
auto state = (GraphState *) nativeOps.getGraphState(117L);
ASSERT_EQ(117L, state->id()); ASSERT_EQ(117L, state->id());
// this call will create scope internally // this call will create scope internally
@ -72,14 +70,12 @@ TEST_F(GraphStateTests, Basic_Tests_1) {
ASSERT_TRUE(scope != nullptr); ASSERT_TRUE(scope != nullptr);
ASSERT_EQ(2, scope->size()); ASSERT_EQ(2, scope->size());
nativeOps.deleteGraphState(state); deleteGraphState(state);
} }
// just separate case for doubles wrapper in NativeOps, nothing else // just separate case for doubles wrapper in NativeOps, nothing else
TEST_F(GraphStateTests, Basic_Tests_2) { TEST_F(GraphStateTests, Basic_Tests_2) {
NativeOps nativeOps; auto state = (GraphState *) getGraphState(117L);
auto state = (GraphState *) nativeOps.getGraphState(117L);
ASSERT_EQ(117L, state->id()); ASSERT_EQ(117L, state->id());
// this call will create scope internally // this call will create scope internally
@ -98,46 +94,40 @@ TEST_F(GraphStateTests, Basic_Tests_2) {
ASSERT_TRUE(scope != nullptr); ASSERT_TRUE(scope != nullptr);
ASSERT_EQ(2, scope->size()); ASSERT_EQ(2, scope->size());
nativeOps.deleteGraphState(state); deleteGraphState(state);
} }
TEST_F(GraphStateTests, Stateful_Execution_1) { TEST_F(GraphStateTests, Stateful_Execution_1) {
NativeOps nativeOps; auto state = getGraphState(117L);
auto state = nativeOps.getGraphState(117L);
Nd4jLong scopes[] = {22, 33}; Nd4jLong scopes[] = {22, 33};
//auto status = nativeOps.execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0); //auto status = execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0);
auto status = nativeOps.execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0); auto status = execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0);
ASSERT_EQ(Status::THROW(), status); ASSERT_EQ(Status::THROW(), status);
nativeOps.deleteGraphState(state); deleteGraphState(state);
} }
TEST_F(GraphStateTests, Stateful_Execution_2) { TEST_F(GraphStateTests, Stateful_Execution_2) {
NativeOps nativeOps; auto state = (GraphState *) getGraphState(117L);
auto state = (GraphState *) nativeOps.getGraphState(117L);
state->registerScope(22); state->registerScope(22);
state->registerScope(33); state->registerScope(33);
Nd4jLong scopes[] = {22, 33}; Nd4jLong scopes[] = {22, 33};
auto status = nativeOps.execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0); auto status = execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0);
// it's no-op: just LogicScope // it's no-op: just LogicScope
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
nativeOps.deleteGraphState(state); deleteGraphState(state);
} }
/** /**
* This test checks WHILE loop * This test checks WHILE loop
*/ */
TEST_F(GraphStateTests, Stateful_Execution_3) { TEST_F(GraphStateTests, Stateful_Execution_3) {
NativeOps nativeOps;
auto var0 = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4}); auto var0 = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
auto var1 = NDArrayFactory::create<float>(11.0f); auto var1 = NDArrayFactory::create<float>(11.0f);
auto var2 = NDArrayFactory::create<float>(2.0f); auto var2 = NDArrayFactory::create<float>(2.0f);
@ -147,7 +137,7 @@ TEST_F(GraphStateTests, Stateful_Execution_3) {
auto res2 = NDArrayFactory::create<float>(0.0f); auto res2 = NDArrayFactory::create<float>(0.0f);
// registering our GraphState holder // registering our GraphState holder
auto state = (GraphState *) nativeOps.getGraphState(117L); auto state = (GraphState *) getGraphState(117L);
// we're prepping pointers to input/output buffers // we're prepping pointers to input/output buffers
Nd4jPointer ptrBuffers[] = {(Nd4jPointer) var0.buffer(), (Nd4jPointer) var1.buffer(), (Nd4jPointer)var2.buffer()}; Nd4jPointer ptrBuffers[] = {(Nd4jPointer) var0.buffer(), (Nd4jPointer) var1.buffer(), (Nd4jPointer)var2.buffer()};
@ -197,7 +187,7 @@ TEST_F(GraphStateTests, Stateful_Execution_3) {
Nd4jLong scopes[] = {22, 33}; Nd4jLong scopes[] = {22, 33};
// we're executing while loop // we're executing while loop
auto status = nativeOps.execCustomOpWithScope(nullptr, state, 0, scopes, 2, ptrBuffers, ptrShapes, 3, outBuffers, outShapes, 3); auto status = execCustomOpWithScope(nullptr, state, 0, scopes, 2, ptrBuffers, ptrShapes, 3, outBuffers, outShapes, 3);
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
// now we check provided result array // now we check provided result array
@ -211,7 +201,7 @@ TEST_F(GraphStateTests, Stateful_Execution_3) {
// nd4j_printf("0 ------------------\n",""); // nd4j_printf("0 ------------------\n","");
nativeOps.deleteGraphState(state); deleteGraphState(state);
// nd4j_printf("1 ------------------\n",""); // nd4j_printf("1 ------------------\n","");
} }
@ -220,8 +210,6 @@ TEST_F(GraphStateTests, Stateful_Execution_3) {
* This test checks CONDITIONAL execution for FALSE * This test checks CONDITIONAL execution for FALSE
*/ */
TEST_F(GraphStateTests, Stateful_Execution_4) { TEST_F(GraphStateTests, Stateful_Execution_4) {
NativeOps nativeOps;
auto var0 = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4}); auto var0 = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
auto var1 = NDArrayFactory::create<float>(5.0f); auto var1 = NDArrayFactory::create<float>(5.0f);
@ -232,7 +220,7 @@ TEST_F(GraphStateTests, Stateful_Execution_4) {
// registering our GraphState holder // registering our GraphState holder
auto state = (GraphState *) nativeOps.getGraphState(117L); auto state = (GraphState *) getGraphState(117L);
// we're prepping pointers to input/output buffers // we're prepping pointers to input/output buffers
Nd4jPointer ptrBuffers[] = {(Nd4jPointer) var0.buffer(), (Nd4jPointer) var1.buffer()}; Nd4jPointer ptrBuffers[] = {(Nd4jPointer) var0.buffer(), (Nd4jPointer) var1.buffer()};
@ -283,14 +271,14 @@ TEST_F(GraphStateTests, Stateful_Execution_4) {
Nd4jLong scopes[] = {22, 33, 44}; Nd4jLong scopes[] = {22, 33, 44};
// we're executing conditional op // we're executing conditional op
auto status = nativeOps.execCustomOpWithScope(nullptr, state, 20, scopes, 3, ptrBuffers, ptrShapes, 2, outBuffers, outShapes, 2); auto status = execCustomOpWithScope(nullptr, state, 20, scopes, 3, ptrBuffers, ptrShapes, 2, outBuffers, outShapes, 2);
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
ASSERT_TRUE(exp.isSameShape(&res0)); ASSERT_TRUE(exp.isSameShape(&res0));
ASSERT_TRUE(exp.equalsTo(&res0)); ASSERT_TRUE(exp.equalsTo(&res0));
nativeOps.deleteGraphState(state); deleteGraphState(state);
} }
@ -298,8 +286,6 @@ TEST_F(GraphStateTests, Stateful_Execution_4) {
* This test checks CONDITIONAL execution for TRUE * This test checks CONDITIONAL execution for TRUE
*/ */
TEST_F(GraphStateTests, Stateful_Execution_5) { TEST_F(GraphStateTests, Stateful_Execution_5) {
NativeOps nativeOps;
auto var0 = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4}); auto var0 = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
auto var1 = NDArrayFactory::create<float>(5.0f); auto var1 = NDArrayFactory::create<float>(5.0f);
@ -310,7 +296,7 @@ TEST_F(GraphStateTests, Stateful_Execution_5) {
// registering our GraphState holder // registering our GraphState holder
auto state = (GraphState *) nativeOps.getGraphState(117L); auto state = (GraphState *) getGraphState(117L);
// we're prepping pointers to input/output buffers // we're prepping pointers to input/output buffers
Nd4jPointer ptrBuffers[] = {(Nd4jPointer) var0.buffer(), (Nd4jPointer) var1.buffer()}; Nd4jPointer ptrBuffers[] = {(Nd4jPointer) var0.buffer(), (Nd4jPointer) var1.buffer()};
@ -361,12 +347,11 @@ TEST_F(GraphStateTests, Stateful_Execution_5) {
Nd4jLong scopes[] = {22, 33, 44}; Nd4jLong scopes[] = {22, 33, 44};
// we're executing conditional op // we're executing conditional op
auto status = nativeOps.execCustomOpWithScope(nullptr, state, 20, scopes, 3, ptrBuffers, ptrShapes, 2, outBuffers, outShapes, 2); auto status = execCustomOpWithScope(nullptr, state, 20, scopes, 3, ptrBuffers, ptrShapes, 2, outBuffers, outShapes, 2);
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
ASSERT_TRUE(exp.isSameShape(&res0)); ASSERT_TRUE(exp.isSameShape(&res0));
ASSERT_TRUE(exp.equalsTo(&res0)); ASSERT_TRUE(exp.equalsTo(&res0));
deleteGraphState(state);
nativeOps.deleteGraphState(state);
} }

View File

@ -42,7 +42,6 @@ TEST_F(JavaInteropCudaTests, test_DeclarableOp_execution_1) {
e.assign(2.f); e.assign(2.f);
nd4j::ops::add op; nd4j::ops::add op;
NativeOps nativeOps;
Context context(1); Context context(1);
context.setCudaContext(LaunchContext::defaultContext()->getCudaStream(), LaunchContext::defaultContext()->getReductionPointer(), LaunchContext::defaultContext()->getAllocationPointer()); context.setCudaContext(LaunchContext::defaultContext()->getCudaStream(), LaunchContext::defaultContext()->getReductionPointer(), LaunchContext::defaultContext()->getAllocationPointer());
@ -53,7 +52,7 @@ TEST_F(JavaInteropCudaTests, test_DeclarableOp_execution_1) {
nd4j_printf("Starting execution...\n",""); nd4j_printf("Starting execution...\n","");
PointersManager pm(LaunchContext::defaultContext(), "test_DeclarableOp_execution_1"); PointersManager pm(LaunchContext::defaultContext(), "test_DeclarableOp_execution_1");
nativeOps.execCustomOp(nullptr, op.getOpHash(), &context); execCustomOp2(nullptr, op.getOpHash(), &context);
pm.synchronize(); pm.synchronize();
@ -71,7 +70,6 @@ TEST_F(JavaInteropCudaTests, test_DeclarableOp_execution_2) {
e.assign(false); e.assign(false);
nd4j::ops::equals op; nd4j::ops::equals op;
NativeOps nativeOps;
Context context(1); Context context(1);
context.setCudaContext(LaunchContext::defaultContext()->getCudaStream(), LaunchContext::defaultContext()->getReductionPointer(), LaunchContext::defaultContext()->getAllocationPointer()); context.setCudaContext(LaunchContext::defaultContext()->getCudaStream(), LaunchContext::defaultContext()->getReductionPointer(), LaunchContext::defaultContext()->getAllocationPointer());
@ -82,7 +80,7 @@ TEST_F(JavaInteropCudaTests, test_DeclarableOp_execution_2) {
nd4j_printf("Starting execution...\n",""); nd4j_printf("Starting execution...\n","");
PointersManager pm(LaunchContext::defaultContext(), "test_DeclarableOp_execution_2"); PointersManager pm(LaunchContext::defaultContext(), "test_DeclarableOp_execution_2");
nativeOps.execCustomOp(nullptr, op.getOpHash(), &context); execCustomOp2(nullptr, op.getOpHash(), &context);
pm.synchronize(); pm.synchronize();

View File

@ -41,8 +41,6 @@ TEST_F(JavaInteropTests, TestShapeExposure1) {
auto weights = NDArrayFactory::create<float>('c', {2, 2, 2, 3}); auto weights = NDArrayFactory::create<float>('c', {2, 2, 2, 3});
auto exp = NDArrayFactory::create<float>('c', {1, 3, 5, 4}); auto exp = NDArrayFactory::create<float>('c', {1, 3, 5, 4});
NativeOps nativeOps;
nd4j::ops::conv2d op; nd4j::ops::conv2d op;
std::vector<double> tArgs({}); std::vector<double> tArgs({});
@ -50,7 +48,7 @@ TEST_F(JavaInteropTests, TestShapeExposure1) {
Nd4jPointer ptrs[] = {(Nd4jPointer) input.getShapeInfo(), (Nd4jPointer) weights.getShapeInfo()}; Nd4jPointer ptrs[] = {(Nd4jPointer) input.getShapeInfo(), (Nd4jPointer) weights.getShapeInfo()};
auto shapeList = nativeOps.calculateOutputShapes(nullptr, op.getOpHash(), ptrs, 2, tArgs.data(), tArgs.size(), iArgs.data(), iArgs.size()); auto shapeList = calculateOutputShapes(nullptr, op.getOpHash(), ptrs, 2, tArgs.data(), tArgs.size(), iArgs.data(), iArgs.size());
ASSERT_EQ(1, shapeList->size()); ASSERT_EQ(1, shapeList->size());
@ -64,7 +62,7 @@ TEST_F(JavaInteropTests, TestShapeExposure1) {
//delete[] ptr; //delete[] ptr;
//delete shapeList; //delete shapeList;
nativeOps.deleteShapeList((Nd4jPointer) shapeList); deleteShapeList((Nd4jPointer) shapeList);
} }
@ -72,9 +70,6 @@ TEST_F(JavaInteropTests, TestShapeExposure2) {
auto input = NDArrayFactory::create<float>('c', {1, 2, 5, 4}); auto input = NDArrayFactory::create<float>('c', {1, 2, 5, 4});
auto exp = NDArrayFactory::create<float>('c', {4}, {1, 2, 5, 4}); auto exp = NDArrayFactory::create<float>('c', {4}, {1, 2, 5, 4});
NativeOps nativeOps;
nd4j::ops::shape_of op; nd4j::ops::shape_of op;
std::vector<double> tArgs({}); std::vector<double> tArgs({});
@ -83,14 +78,14 @@ TEST_F(JavaInteropTests, TestShapeExposure2) {
Nd4jPointer ptrs[] = {(Nd4jPointer) input.getShapeInfo()}; Nd4jPointer ptrs[] = {(Nd4jPointer) input.getShapeInfo()};
auto shapeList = nativeOps.calculateOutputShapes(nullptr, op.getOpHash(), ptrs, 1, tArgs.data(), tArgs.size(), iArgs.data(), iArgs.size()); auto shapeList = calculateOutputShapes(nullptr, op.getOpHash(), ptrs, 1, tArgs.data(), tArgs.size(), iArgs.data(), iArgs.size());
ASSERT_EQ(1, shapeList->size()); ASSERT_EQ(1, shapeList->size());
ASSERT_EQ(exp.rankOf(), shape::rank((Nd4jLong *)shapeList->at(0))); ASSERT_EQ(exp.rankOf(), shape::rank((Nd4jLong *)shapeList->at(0)));
ASSERT_EQ(exp.sizeAt(0), shape::shapeOf((Nd4jLong *)shapeList->at(0))[0]); ASSERT_EQ(exp.sizeAt(0), shape::shapeOf((Nd4jLong *)shapeList->at(0))[0]);
nativeOps.deleteShapeList((Nd4jPointer) shapeList); deleteShapeList((Nd4jPointer) shapeList);
} }
TEST_F(JavaInteropTests, TestShapeExposure3) { TEST_F(JavaInteropTests, TestShapeExposure3) {
@ -112,13 +107,12 @@ TEST_F(JavaInteropTests, TestShapeExposure3) {
Nd4jPointer inputBuffers[] = {x.buffer(), sizes.buffer()}; Nd4jPointer inputBuffers[] = {x.buffer(), sizes.buffer()};
Nd4jPointer inputShapes[] = {x.shapeInfo(), sizes.shapeInfo()}; Nd4jPointer inputShapes[] = {x.shapeInfo(), sizes.shapeInfo()};
NativeOps nativeOps;
nd4j::ops::split_v op; nd4j::ops::split_v op;
Nd4jLong iArgs[] = {1}; Nd4jLong iArgs[] = {1};
auto hash = op.getOpHash(); auto hash = op.getOpHash();
auto shapeList = nativeOps.calculateOutputShapes(nullptr, hash, inputBuffers, inputShapes, 2, nullptr, 0, iArgs, 1, nullptr, 0); auto shapeList = calculateOutputShapes2(nullptr, hash, inputBuffers, inputShapes, 2, nullptr, 0, iArgs, 1, nullptr, 0);
ASSERT_EQ(3, shapeList->size()); ASSERT_EQ(3, shapeList->size());
@ -126,7 +120,7 @@ TEST_F(JavaInteropTests, TestShapeExposure3) {
ASSERT_TRUE(shape::equalsSoft(sub1.shapeInfo(), shapeList->at(1))); ASSERT_TRUE(shape::equalsSoft(sub1.shapeInfo(), shapeList->at(1)));
ASSERT_TRUE(shape::equalsSoft(sub2.shapeInfo(), shapeList->at(2))); ASSERT_TRUE(shape::equalsSoft(sub2.shapeInfo(), shapeList->at(2)));
nativeOps.deleteShapeList((Nd4jPointer) shapeList); deleteShapeList((Nd4jPointer) shapeList);
} }
TEST_F(JavaInteropTests, Test_Squeeze_1) { TEST_F(JavaInteropTests, Test_Squeeze_1) {
@ -143,10 +137,7 @@ TEST_F(JavaInteropTests, Test_Squeeze_1) {
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()}; Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()}; Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()};
auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
NativeOps nativeOps;
auto status = nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
ASSERT_EQ(e, z); ASSERT_EQ(e, z);
@ -167,10 +158,7 @@ TEST_F(JavaInteropTests, Test_RDiv_1) {
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()}; Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()}; Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()};
auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
NativeOps nativeOps;
auto status = nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
ASSERT_EQ(e, z); ASSERT_EQ(e, z);
@ -203,11 +191,9 @@ TEST_F(JavaInteropTests, TestSconv2d_1) {
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer()}; Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo()}; Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo()};
NativeOps nativeOps;
Nd4jLong exp[] = {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}; Nd4jLong exp[] = {1, 1, 1, 1, 0, 0, 1, 1, 0, 0};
nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 4, ptrsOutBuffers, ptrsOutShapes, 1, execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 4, ptrsOutBuffers, ptrsOutShapes, 1,
nullptr, 0, exp, 9, nullptr, 0, false); nullptr, 0, exp, 9, nullptr, 0, false);
//output.printBuffer("output"); //output.printBuffer("output");
@ -238,11 +224,9 @@ TEST_F(JavaInteropTests, TestSconv2d_2) {
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer()}; Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo()}; Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo()};
NativeOps nativeOps;
Nd4jLong exp[] = {1, 1, 1, 1, 0, 0, 1, 1, 0}; Nd4jLong exp[] = {1, 1, 1, 1, 0, 0, 1, 1, 0};
nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 9, nullptr, 0, false); execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 9, nullptr, 0, false);
//output.printBuffer("output"); //output.printBuffer("output");
@ -266,9 +250,7 @@ TEST_F(JavaInteropTests, TestMaxPooling2d_1) {
nd4j::ops::maxpool2d op; nd4j::ops::maxpool2d op;
NativeOps nativeOps; Nd4jStatus status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs.data(), 9, nullptr, 0, false);
Nd4jStatus status = nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs.data(), 9, nullptr, 0, false);
ASSERT_EQ(ND4J_STATUS_OK, status); ASSERT_EQ(ND4J_STATUS_OK, status);
} }
@ -294,13 +276,11 @@ TEST_F(JavaInteropTests, TestCol2Im_1) {
nd4j::ops::col2im op; nd4j::ops::col2im op;
NativeOps nativeOps;
Nd4jLong exp[] = {1, 1, 1, 1, 4, 5, 1, 1, 1}; Nd4jLong exp[] = {1, 1, 1, 1, 4, 5, 1, 1, 1};
auto hash = op.getOpHash(); auto hash = op.getOpHash();
nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 9, nullptr, 0, false); execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 9, nullptr, 0, false);
ASSERT_TRUE(output.meanNumber().e<float>(0) > 0.0f); ASSERT_TRUE(output.meanNumber().e<float>(0) > 0.0f);
} }
@ -320,8 +300,6 @@ TEST_F(JavaInteropTests, TestPNorm_1) {
auto output = NDArrayFactory::create<float>('c', {1, 3, 3, 3}); auto output = NDArrayFactory::create<float>('c', {1, 3, 3, 3});
input.linspace(1); input.linspace(1);
NativeOps nativeOps;
nd4j::ops::pnormpool2d op; nd4j::ops::pnormpool2d op;
Nd4jLong exp[] = {2, 2, 1, 1, 0, 0, 1, 1, 0, 2, 0, 0}; Nd4jLong exp[] = {2, 2, 1, 1, 0, 0, 1, 1, 0, 2, 0, 0};
@ -332,7 +310,7 @@ TEST_F(JavaInteropTests, TestPNorm_1) {
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer()}; Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo()}; Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo()};
nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 11, nullptr, 0, false); execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 11, nullptr, 0, false);
ASSERT_TRUE(output.meanNumber().e<double>(0) > 0.0); ASSERT_TRUE(output.meanNumber().e<double>(0) > 0.0);
} }
@ -343,8 +321,6 @@ TEST_F(JavaInteropTests, TestInplace_1) {
//auto exp('c', {10, 10}); //auto exp('c', {10, 10});
input.linspace(1); input.linspace(1);
NativeOps nativeOps;
nd4j::ops::clipbyvalue op; nd4j::ops::clipbyvalue op;
double extras[] = {-1.0f, 1.0f}; double extras[] = {-1.0f, 1.0f};
@ -353,7 +329,7 @@ TEST_F(JavaInteropTests, TestInplace_1) {
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.getShapeInfo()}; Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.getShapeInfo()};
Nd4jStatus result = nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, nullptr, nullptr, 0, extras, 2, nullptr, 0, nullptr, 0, true); Nd4jStatus result = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, nullptr, nullptr, 0, extras, 2, nullptr, 0, nullptr, 0, true);
ASSERT_EQ(ND4J_STATUS_OK, result); ASSERT_EQ(ND4J_STATUS_OK, result);
@ -415,7 +391,6 @@ TEST_F(JavaInteropTests, test_avgpooling_edge_1) {
x.linspace(1.0); x.linspace(1.0);
z.linspace(1.0); z.linspace(1.0);
NativeOps nativeOps;
nd4j::ops::avgpool2d op; nd4j::ops::avgpool2d op;
//auto result = op.execute({&x}, {}, {3,3, 1,1, 0,0, 1,1, 1, 0, 1}); //auto result = op.execute({&x}, {}, {3,3, 1,1, 0,0, 1,1, 1, 0, 1});
@ -427,7 +402,7 @@ TEST_F(JavaInteropTests, test_avgpooling_edge_1) {
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()}; Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()}; Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()};
auto result = nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 11, nullptr, 0, false); auto result = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 11, nullptr, 0, false);
ASSERT_EQ(Status::OK(), result); ASSERT_EQ(Status::OK(), result);
@ -496,15 +471,13 @@ TEST_F(JavaInteropTests, test_avgpooling_edge_1) {
/* /*
TEST_F(JavaInteropTests, Test_GraphReuse_1) { TEST_F(JavaInteropTests, Test_GraphReuse_1) {
NativeOps nativeOps;
uint8_t* data = nd4j::graph::readFlatBuffers("./resources/reduce_dim_false.fb"); uint8_t* data = nd4j::graph::readFlatBuffers("./resources/reduce_dim_false.fb");
nativeOps.registerGraph(nullptr, 119, (Nd4jPointer) data); registerGraph(nullptr, 119, (Nd4jPointer) data);
ASSERT_TRUE(GraphHolder::getInstance()->hasGraph(119)); ASSERT_TRUE(GraphHolder::getInstance()->hasGraph(119));
nativeOps.unregisterGraph(nullptr, 119); unregisterGraph(nullptr, 119);
ASSERT_FALSE(GraphHolder::getInstance()->hasGraph(119)); ASSERT_FALSE(GraphHolder::getInstance()->hasGraph(119));
@ -520,8 +493,6 @@ TEST_F(JavaInteropTests, Test_GraphReuse_2) {
auto exp1 = NDArrayFactory::create<float>('c', {3}, {6, 6, 6}); auto exp1 = NDArrayFactory::create<float>('c', {3}, {6, 6, 6});
auto exp2 = NDArrayFactory::create<float>('c', {3}, {9, 9, 9}); auto exp2 = NDArrayFactory::create<float>('c', {3}, {9, 9, 9});
NativeOps nativeOps;
// we load graph from file, because we're not in java here, and dont have buffer ready // we load graph from file, because we're not in java here, and dont have buffer ready
uint8_t* data = nd4j::graph::readFlatBuffers("./resources/reduce_dim_false.fb"); uint8_t* data = nd4j::graph::readFlatBuffers("./resources/reduce_dim_false.fb");
@ -529,7 +500,7 @@ TEST_F(JavaInteropTests, Test_GraphReuse_2) {
ASSERT_FALSE(GraphHolder::getInstance()->hasGraph(119)); ASSERT_FALSE(GraphHolder::getInstance()->hasGraph(119));
// register the graph, to call for it later // register the graph, to call for it later
nativeOps.registerGraph(nullptr, 119, (Nd4jPointer) data); registerGraph(nullptr, 119, (Nd4jPointer) data);
// and ensure we're ok // and ensure we're ok
ASSERT_TRUE(GraphHolder::getInstance()->hasGraph(119)); ASSERT_TRUE(GraphHolder::getInstance()->hasGraph(119));
@ -547,7 +518,7 @@ TEST_F(JavaInteropTests, Test_GraphReuse_2) {
Nd4jPointer shapes_0[] = {(Nd4jPointer) input_0.shapeInfo()}; Nd4jPointer shapes_0[] = {(Nd4jPointer) input_0.shapeInfo()};
// now we're executing stored graph and providing replacement for input variable // now we're executing stored graph and providing replacement for input variable
auto res_0 = nativeOps.executeStoredGraph(nullptr, 119, inputs_0, shapes_0, idx, 1); auto res_0 = executeStoredGraph(nullptr, 119, inputs_0, shapes_0, idx, 1);
ASSERT_EQ(ND4J_STATUS_OK, res_0->status()); ASSERT_EQ(ND4J_STATUS_OK, res_0->status());
ASSERT_EQ(1, res_0->size()); ASSERT_EQ(1, res_0->size());
@ -562,7 +533,7 @@ TEST_F(JavaInteropTests, Test_GraphReuse_2) {
Nd4jPointer shapes_1[] = {(Nd4jPointer) input_1.shapeInfo()}; Nd4jPointer shapes_1[] = {(Nd4jPointer) input_1.shapeInfo()};
// doing it again // doing it again
auto res_1 = nativeOps.executeStoredGraph(nullptr, 119, inputs_1, shapes_1, idx, 1); auto res_1 = executeStoredGraph(nullptr, 119, inputs_1, shapes_1, idx, 1);
ASSERT_EQ(ND4J_STATUS_OK, res_1->status()); ASSERT_EQ(ND4J_STATUS_OK, res_1->status());
ASSERT_EQ(1, res_1->size()); ASSERT_EQ(1, res_1->size());
@ -577,7 +548,7 @@ TEST_F(JavaInteropTests, Test_GraphReuse_2) {
Nd4jPointer shapes_2[] = {(Nd4jPointer) input_2.shapeInfo()}; Nd4jPointer shapes_2[] = {(Nd4jPointer) input_2.shapeInfo()};
// and again // and again
auto res_2 = nativeOps.executeStoredGraph(nullptr, 119, inputs_2, shapes_2, idx, 1); auto res_2 = executeStoredGraph(nullptr, 119, inputs_2, shapes_2, idx, 1);
ASSERT_EQ(ND4J_STATUS_OK, res_1->status()); ASSERT_EQ(ND4J_STATUS_OK, res_1->status());
ASSERT_EQ(1, res_2->size()); ASSERT_EQ(1, res_2->size());
@ -586,7 +557,7 @@ TEST_F(JavaInteropTests, Test_GraphReuse_2) {
//////// clean out //////// clean out
nativeOps.unregisterGraph(nullptr, 119); unregisterGraph(nullptr, 119);
ASSERT_FALSE(GraphHolder::getInstance()->hasGraph(119)); ASSERT_FALSE(GraphHolder::getInstance()->hasGraph(119));
@ -616,9 +587,7 @@ TEST_F(JavaInteropTests, Test_Greater_1) {
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.getBuffer()}; Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.getBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.getShapeInfo()}; Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.getShapeInfo()};
NativeOps nativeOps; execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
o.printIndexedBuffer("Greater JIT"); o.printIndexedBuffer("Greater JIT");
ASSERT_TRUE(exp.equalsTo(&o)); ASSERT_TRUE(exp.equalsTo(&o));
} }
@ -641,9 +610,7 @@ TEST_F(JavaInteropTests, Test_Greater_2) {
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.getBuffer()}; Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.getBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.getShapeInfo()}; Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.getShapeInfo()};
NativeOps nativeOps; execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
ASSERT_TRUE(exp.equalsTo(&o)); ASSERT_TRUE(exp.equalsTo(&o));
} }
@ -662,9 +629,8 @@ TEST_F(JavaInteropTests, Test_Boolean_Op_1) {
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.getBuffer()}; Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.getBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.getShapeInfo()}; Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.getShapeInfo()};
NativeOps nativeOps;
auto hash = op.getOpHash(); auto hash = op.getOpHash();
auto status = nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
ASSERT_TRUE(exp.equalsTo(&o)); ASSERT_TRUE(exp.equalsTo(&o));
@ -685,9 +651,8 @@ TEST_F(JavaInteropTests, Test_Inplace_Outputs_1) {
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()}; Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()}; Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()};
NativeOps nativeOps;
auto hash = op.getOpHash(); auto hash = op.getOpHash();
auto status = nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
@ -710,9 +675,8 @@ TEST_F(JavaInteropTests, Test_Inplace_Outputs_2) {
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()}; Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()}; Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()};
NativeOps nativeOps;
auto hash = op.getOpHash(); auto hash = op.getOpHash();
auto status = nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
ASSERT_TRUE(e.isSameShape(z)); ASSERT_TRUE(e.isSameShape(z));
@ -736,9 +700,8 @@ TEST_F(JavaInteropTests, Test_Inplace_Outputs_3) {
Nd4jLong iArgs[] = {1}; Nd4jLong iArgs[] = {1};
NativeOps nativeOps;
auto hash = op.getOpHash(); auto hash = op.getOpHash();
auto status = nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 1, nullptr, 0, false); auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 1, nullptr, 0, false);
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
ASSERT_TRUE(e.isSameShape(output)); ASSERT_TRUE(e.isSameShape(output));
@ -753,8 +716,7 @@ TEST_F(JavaInteropTests, Test_Reduce3_EdgeCase) {
auto dims = NDArrayFactory::create<int>('c', {2}, {0, 1}); auto dims = NDArrayFactory::create<int>('c', {2}, {0, 1});
NativeOps nativeOps; execReduce3Tad(nullptr, 2, x.buffer(), x.shapeInfo(), nullptr, nullptr, nullptr,
nativeOps.execReduce3(nullptr, 2, x.buffer(), x.shapeInfo(), nullptr, nullptr, nullptr,
y.buffer(), y.shapeInfo(), nullptr, nullptr, y.buffer(), y.shapeInfo(), nullptr, nullptr,
z.buffer(), z.shapeInfo(), nullptr, nullptr, z.buffer(), z.shapeInfo(), nullptr, nullptr,
dims.buffer(), dims.shapeInfo(), dims.specialBuffer(), dims.specialShapeInfo(), nullptr, nullptr, nullptr, nullptr); dims.buffer(), dims.shapeInfo(), dims.specialBuffer(), dims.specialShapeInfo(), nullptr, nullptr, nullptr, nullptr);
@ -764,10 +726,8 @@ TEST_F(JavaInteropTests, Test_SimpleIf_Output) {
Environment::getInstance()->setDebug(true); Environment::getInstance()->setDebug(true);
Environment::getInstance()->setVerbose(false); Environment::getInstance()->setVerbose(false);
NativeOps ops;
auto pl = nd4j::graph::readFlatBuffers("./resources/simpleif_0_1.fb"); auto pl = nd4j::graph::readFlatBuffers("./resources/simpleif_0_1.fb");
auto ptr = ops.executeFlatGraph(nullptr, pl); auto ptr = executeFlatGraph(nullptr, pl);
Environment::getInstance()->setDebug(false); Environment::getInstance()->setDebug(false);
Environment::getInstance()->setVerbose(false); Environment::getInstance()->setVerbose(false);
@ -792,9 +752,8 @@ TEST_F(JavaInteropTests, Test_AveragePooling_FF_TF_double) {
Nd4jPointer ptrsOutShapes[] = {reinterpret_cast<Nd4jPointer>(z.shapeInfo())}; Nd4jPointer ptrsOutShapes[] = {reinterpret_cast<Nd4jPointer>(z.shapeInfo())};
Nd4jLong iArgs[] = {3,3, 3,3, 0,0, 1,1,1, 0,1}; Nd4jLong iArgs[] = {3,3, 3,3, 0,0, 1,1,1, 0,1};
NativeOps nativeOps;
auto hash = op.getOpHash(); auto hash = op.getOpHash();
auto status = nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false); auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false);
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
@ -818,9 +777,8 @@ TEST_F(JavaInteropTests, Test_MaxPool2D_float_1) {
nd4j::ops::maxpool2d op; nd4j::ops::maxpool2d op;
NativeOps nativeOps;
auto hash = op.getOpHash(); auto hash = op.getOpHash();
auto status = nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false); auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false);
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
} }
@ -843,9 +801,8 @@ TEST_F(JavaInteropTests, Test_Unstack_1) {
nd4j::ops::unstack op; nd4j::ops::unstack op;
NativeOps nativeOps;
auto hash = op.getOpHash(); auto hash = op.getOpHash();
auto status = nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 5, nullptr, 0, iArgs, 1, nullptr, 0, false); auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 5, nullptr, 0, iArgs, 1, nullptr, 0, false);
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
} }
@ -864,9 +821,8 @@ TEST_F(JavaInteropTests, Test_AveragePooling_FF_TF_float) {
Nd4jPointer ptrsOutShapes[] = {reinterpret_cast<Nd4jPointer>(z.shapeInfo())}; Nd4jPointer ptrsOutShapes[] = {reinterpret_cast<Nd4jPointer>(z.shapeInfo())};
Nd4jLong iArgs[] = {3,3, 3,3, 0,0, 1,1,1, 0,1}; Nd4jLong iArgs[] = {3,3, 3,3, 0,0, 1,1,1, 0,1};
NativeOps nativeOps;
auto hash = op.getOpHash(); auto hash = op.getOpHash();
auto status = nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false); auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false);
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
@ -883,8 +839,7 @@ TEST_F(JavaInteropTests, Test_Mixed_Add_1) {
auto arrayZ = NDArrayFactory::create<double>({0, 0, 0, 0}); auto arrayZ = NDArrayFactory::create<double>({0, 0, 0, 0});
auto arrayE = NDArrayFactory::create<double>({2, 4, 6, 8}); auto arrayE = NDArrayFactory::create<double>({2, 4, 6, 8});
NativeOps ops; execPairwiseTransform(nullptr, pairwise::Add,
ops.execPairwiseTransform(nullptr, pairwise::Add,
arrayX.buffer(), arrayX.shapeInfo(), nullptr, nullptr, arrayX.buffer(), arrayX.shapeInfo(), nullptr, nullptr,
arrayY.buffer(), arrayY.shapeInfo(), nullptr, nullptr, arrayY.buffer(), arrayY.shapeInfo(), nullptr, nullptr,
arrayZ.buffer(), arrayZ.shapeInfo(), nullptr, nullptr, arrayZ.buffer(), arrayZ.shapeInfo(), nullptr, nullptr,
@ -898,7 +853,6 @@ TEST_F(JavaInteropTests, Test_Add_1) {
auto y = NDArrayFactory::create<int>('c', {5}, {1, 1, 1, 1, 1}); auto y = NDArrayFactory::create<int>('c', {5}, {1, 1, 1, 1, 1});
auto e = NDArrayFactory::create<int>('c', {5}, {2, 2, 2, 2, 2}); auto e = NDArrayFactory::create<int>('c', {5}, {2, 2, 2, 2, 2});
NativeOps nativeOps;
nd4j::ops::add op; nd4j::ops::add op;
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), y.getBuffer()}; Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), y.getBuffer()};
@ -907,7 +861,7 @@ TEST_F(JavaInteropTests, Test_Add_1) {
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) x.getBuffer()}; Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) x.getBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) x.getShapeInfo()}; Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) x.getShapeInfo()};
nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
ASSERT_EQ(e, x); ASSERT_EQ(e, x);
} }
@ -920,7 +874,6 @@ TEST_F(JavaInteropTests, zeta_test10) {
auto e = NDArrayFactory::create<double>('c', {3, 4}, {23.014574, 12.184081, 8.275731, 6.1532226, 4.776538, 3.7945523, 3.0541048, 2.4765317, 2.0163891, 205.27448, 21.090889, 19.477398}); auto e = NDArrayFactory::create<double>('c', {3, 4}, {23.014574, 12.184081, 8.275731, 6.1532226, 4.776538, 3.7945523, 3.0541048, 2.4765317, 2.0163891, 205.27448, 21.090889, 19.477398});
NativeOps nativeOps;
nd4j::ops::zeta op; nd4j::ops::zeta op;
Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), q.getBuffer()}; Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), q.getBuffer()};
@ -929,7 +882,7 @@ TEST_F(JavaInteropTests, zeta_test10) {
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()}; Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()}; Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()};
nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
ASSERT_EQ(e, z); ASSERT_EQ(e, z);
} }
@ -939,8 +892,7 @@ TEST_F(JavaInteropTests, Test_Is_Max_1) {
auto arrayZ = NDArrayFactory::create<bool>({0, 0, 0, 0}); auto arrayZ = NDArrayFactory::create<bool>({0, 0, 0, 0});
auto arrayE = NDArrayFactory::create<bool>({0, 1, 0, 0}); auto arrayE = NDArrayFactory::create<bool>({0, 1, 0, 0});
NativeOps ops; execTransformAny(nullptr, transform::IsMax,
ops.execTransformAny(nullptr, transform::IsMax,
arrayX.buffer(), arrayX.shapeInfo(), nullptr, nullptr, arrayX.buffer(), arrayX.shapeInfo(), nullptr, nullptr,
arrayZ.buffer(), arrayZ.shapeInfo(), nullptr, nullptr, arrayZ.buffer(), arrayZ.shapeInfo(), nullptr, nullptr,
nullptr); nullptr);
@ -953,8 +905,7 @@ TEST_F(JavaInteropTests, Test_Is_Max_1_2) {
auto arrayZ = NDArrayFactory::create<float>({0, 0, 0, 0}); auto arrayZ = NDArrayFactory::create<float>({0, 0, 0, 0});
auto arrayE = NDArrayFactory::create<float>({0, 1, 0, 0}); auto arrayE = NDArrayFactory::create<float>({0, 1, 0, 0});
NativeOps ops; execTransformAny(nullptr, transform::IsMax,
ops.execTransformAny(nullptr, transform::IsMax,
arrayX.buffer(), arrayX.shapeInfo(), nullptr, nullptr, arrayX.buffer(), arrayX.shapeInfo(), nullptr, nullptr,
arrayZ.buffer(), arrayZ.shapeInfo(), nullptr, nullptr, arrayZ.buffer(), arrayZ.shapeInfo(), nullptr, nullptr,
nullptr); nullptr);
@ -970,8 +921,7 @@ TEST_F(JavaInteropTests, Test_Is_Max_2) {
Nd4jLong *ex[] = {tad, off}; Nd4jLong *ex[] = {tad, off};
float ea[] = {2, 1, 2}; float ea[] = {2, 1, 2};
NativeOps ops; execTransformBool(reinterpret_cast<void **>(ex), transform::IsMax,
ops.execTransformBool(reinterpret_cast<void **>(ex), transform::IsMax,
arrayX.buffer(), arrayX.shapeInfo(), nullptr, nullptr, arrayX.buffer(), arrayX.shapeInfo(), nullptr, nullptr,
arrayZ.buffer(), arrayZ.shapeInfo(), nullptr, nullptr, arrayZ.buffer(), arrayZ.shapeInfo(), nullptr, nullptr,
ea); ea);
@ -995,8 +945,7 @@ TEST_F(JavaInteropTests, Test_Boolean_Broadcastables_1) {
nd4j::ops::greater_equal op; nd4j::ops::greater_equal op;
NativeOps ops; auto shapeList = calculateOutputShapes2(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, nullptr, 0, nullptr, 0, nullptr, 0);
auto shapeList = ops.calculateOutputShapes(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, nullptr, 0, nullptr, 0, nullptr, 0);
delete shapeList; delete shapeList;
} }
@ -1013,8 +962,7 @@ TEST_F(JavaInteropTests, Test_L2_Loss_3) {
Nd4jPointer ptrsOutShapes[] = {reinterpret_cast<Nd4jPointer>(z.shapeInfo())}; Nd4jPointer ptrsOutShapes[] = {reinterpret_cast<Nd4jPointer>(z.shapeInfo())};
nd4j::ops::l2_loss op; nd4j::ops::l2_loss op;
NativeOps ops; auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffer, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
auto status = ops.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffer, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
z.printIndexedBuffer("z"); z.printIndexedBuffer("z");
@ -1036,9 +984,8 @@ TEST_F(JavaInteropTests, Test_Fastpath_3) {
ASSERT_EQ(2, ctx.width()); ASSERT_EQ(2, ctx.width());
NativeOps nativeOps;
nd4j::ops::add op; nd4j::ops::add op;
nativeOps.execCustomOp(nullptr, op.getOpHash(), &ctx); execCustomOp2(nullptr, op.getOpHash(), &ctx);
ASSERT_EQ(exp, z); ASSERT_EQ(exp, z);
} }
@ -1054,9 +1001,8 @@ TEST_F(JavaInteropTests, Test_Fastpath_4) {
ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());
ctx.setIArguments(iArgs, 3); ctx.setIArguments(iArgs, 3);
NativeOps nativeOps;
nd4j::ops::tri op; nd4j::ops::tri op;
nativeOps.execCustomOp(nullptr, op.getOpHash(), &ctx); execCustomOp2(nullptr, op.getOpHash(), &ctx);
ASSERT_EQ(exp, z); ASSERT_EQ(exp, z);
} }
@ -1074,9 +1020,8 @@ TEST_F(JavaInteropTests, Test_Fastpath_5) {
ctx.setInputArray(1, b.buffer(), b.shapeInfo(), b.specialBuffer(), b.specialShapeInfo()); ctx.setInputArray(1, b.buffer(), b.shapeInfo(), b.specialBuffer(), b.specialShapeInfo());
ctx.setOutputArray(0, c.buffer(), c.shapeInfo(), c.specialBuffer(), c.specialShapeInfo()); ctx.setOutputArray(0, c.buffer(), c.shapeInfo(), c.specialBuffer(), c.specialShapeInfo());
NativeOps nativeOps;
nd4j::ops::matmul op; nd4j::ops::matmul op;
auto status = nativeOps.execCustomOp(nullptr, op.getOpHash(), &ctx); auto status = execCustomOp2(nullptr, op.getOpHash(), &ctx);
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
} }
@ -1104,9 +1049,8 @@ TEST_F(JavaInteropTests, Test_Fastpath_6) {
ctx.setIArguments(iArgs, 3); ctx.setIArguments(iArgs, 3);
NativeOps nativeOps;
nd4j::ops::matmul_bp op; nd4j::ops::matmul_bp op;
auto status = nativeOps.execCustomOp(nullptr, op.getOpHash(), &ctx); auto status = execCustomOp2(nullptr, op.getOpHash(), &ctx);
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
} }
@ -1122,7 +1066,6 @@ TEST_F(JavaInteropTests, Test_Fastpath_7) {
ctx.setIArguments(iArgs, 1); ctx.setIArguments(iArgs, 1);
NativeOps nativeOps;
nd4j::ops::concat op; nd4j::ops::concat op;
ctx.setInputArray(0, a.buffer(), a.shapeInfo(), a.specialBuffer(), a.specialShapeInfo()); ctx.setInputArray(0, a.buffer(), a.shapeInfo(), a.specialBuffer(), a.specialShapeInfo());
@ -1130,7 +1073,7 @@ TEST_F(JavaInteropTests, Test_Fastpath_7) {
ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());
auto status = nativeOps.execCustomOp(nullptr, op.getOpHash(), &ctx); auto status = execCustomOp2(nullptr, op.getOpHash(), &ctx);
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
ASSERT_EQ(e, z); ASSERT_EQ(e, z);
@ -1138,10 +1081,8 @@ TEST_F(JavaInteropTests, Test_Fastpath_7) {
/* /*
TEST_F(JavaInteropTests, Test_Results_Conversion_1) { TEST_F(JavaInteropTests, Test_Results_Conversion_1) {
NativeOps ops;
auto pl = nd4j::graph::readFlatBuffers("./resources/gru_dynamic_mnist.fb"); auto pl = nd4j::graph::readFlatBuffers("./resources/gru_dynamic_mnist.fb");
auto ptr = ops.executeFlatGraph(nullptr, pl); auto ptr = executeFlatGraph(nullptr, pl);
// at this point we have FlatResults // at this point we have FlatResults
auto flatResult = GetFlatResult(ptr->pointer()); auto flatResult = GetFlatResult(ptr->pointer());
@ -1190,8 +1131,6 @@ TEST_F(JavaInteropTests, Test_Results_Conversion_1) {
} }
*/ */
// TEST_F(JavaInteropTests, Test_NLP_Aggregations_1) { // TEST_F(JavaInteropTests, Test_NLP_Aggregations_1) {
// NativeOps ops;
// std::array<float, 60> syn0 = {-0.022756476f, 0.0126427775f, 0.011029151f, -0.013542821f, -0.012327666f, -0.0032439455f, -0.008405109f, -0.016651405f, 0.0015980572f, -0.007442479f, 0.019937921f, -0.016222188f, -0.016541665f, 0.013372547f, 0.006625724f, 0.0058958204f, -0.01281835f, -6.2343775E-4f, 0.0019826533f, 0.010253737f, -0.010291531f, 0.0019767822f, 0.018071089f, -0.0117441565f, 0.023176769f, 0.0032820583f, 0.0061427564f, -0.01696018f, 0.0054971874f, 0.0043818625f, 0.019323621f, 0.0036080598f, 0.024376748f, -0.0024499625f, 0.019496754f, 0.010563821f, -2.0503551E-4f, -0.0146056535f, 0.009949291f, 0.017604528f, -0.0050302492f, -0.022060446f, 0.016468976f, -0.0034482107f, 0.010270384f, -0.0063356445f, -0.019934833f, -0.02325993f, 0.016109904f, -0.0031106502f, -0.0020592287f, 0.024031803f, 0.005184144f, -0.024887865f, 0.02100272f, 3.395051E-4f, 0.018432347f, 5.673498E-4f, -0.020073576f, 0.010949242f}; // std::array<float, 60> syn0 = {-0.022756476f, 0.0126427775f, 0.011029151f, -0.013542821f, -0.012327666f, -0.0032439455f, -0.008405109f, -0.016651405f, 0.0015980572f, -0.007442479f, 0.019937921f, -0.016222188f, -0.016541665f, 0.013372547f, 0.006625724f, 0.0058958204f, -0.01281835f, -6.2343775E-4f, 0.0019826533f, 0.010253737f, -0.010291531f, 0.0019767822f, 0.018071089f, -0.0117441565f, 0.023176769f, 0.0032820583f, 0.0061427564f, -0.01696018f, 0.0054971874f, 0.0043818625f, 0.019323621f, 0.0036080598f, 0.024376748f, -0.0024499625f, 0.019496754f, 0.010563821f, -2.0503551E-4f, -0.0146056535f, 0.009949291f, 0.017604528f, -0.0050302492f, -0.022060446f, 0.016468976f, -0.0034482107f, 0.010270384f, -0.0063356445f, -0.019934833f, -0.02325993f, 0.016109904f, -0.0031106502f, -0.0020592287f, 0.024031803f, 0.005184144f, -0.024887865f, 0.02100272f, 3.395051E-4f, 0.018432347f, 5.673498E-4f, -0.020073576f, 0.010949242f};
// std::array<float, 60> syn1; // std::array<float, 60> syn1;
// std::array<float, 100000> exp; // std::array<float, 100000> exp;
@ -1283,5 +1222,5 @@ TEST_F(JavaInteropTests, Test_Results_Conversion_1) {
// ptrptr[idx+2] = reinterpret_cast<void*>(exp.data()); // ptrptr[idx+2] = reinterpret_cast<void*>(exp.data());
// ops.execAggregateBatchFloat(nullptr, numAggregates, opNum, maxArgs, maxShapes, maxIntArrays, maxIntArraySize, maxIndexArguments, maxRealArguments, pointer.data()); // execAggregateBatchFloat(nullptr, numAggregates, opNum, maxArgs, maxShapes, maxIntArrays, maxIntArraySize, maxIndexArguments, maxRealArguments, pointer.data());
// } // }

View File

@ -53,8 +53,7 @@ TEST_F(LegacyOpsCudaTests, test_sortTad_1) {
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()}; Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
x.syncToDevice(); x.syncToDevice();
NativeOps nativeOps; sortTad(extras, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), &axis, 1, packX.platformShapeInfo(), packX.platformOffsets(), false);
nativeOps.sortTad(extras, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), &axis, 1, packX.platformShapeInfo(), packX.platformOffsets(), false);
x.tickWriteDevice(); x.tickWriteDevice();
ASSERT_EQ(e, x); ASSERT_EQ(e, x);

View File

@ -501,8 +501,7 @@ TEST_F(LegacyOpsTests, Reduce3_2) {
auto dim = NDArrayFactory::create<int>('c', {1}, {1}); auto dim = NDArrayFactory::create<int>('c', {1}, {1});
NativeOps nativeOps; execReduce3Tad(nullptr, reduce3::CosineSimilarity, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(),
nativeOps.execReduce3(nullptr, reduce3::CosineSimilarity, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(),
nullptr, nullptr, nullptr, nullptr); nullptr, nullptr, nullptr, nullptr);
} }
@ -517,9 +516,8 @@ TEST_F(LegacyOpsTests, Reduce3_3) {
auto dim = NDArrayFactory::create<int>('c', {1}, {1}); auto dim = NDArrayFactory::create<int>('c', {1}, {1});
NativeOps nativeOps;
nativeOps.execReduce3(nullptr, reduce3::CosineDistance, execReduce3Tad(nullptr, reduce3::CosineDistance,
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
nullptr, nullptr,
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
@ -543,9 +541,8 @@ TEST_F(LegacyOpsTests, Reduce3_4) {
auto dim = NDArrayFactory::create<int>('c', {1}, {1}); auto dim = NDArrayFactory::create<int>('c', {1}, {1});
NativeOps nativeOps;
nativeOps.execReduce3(nullptr, reduce3::CosineDistance, execReduce3Tad(nullptr, reduce3::CosineDistance,
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
nullptr, nullptr,
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
@ -569,9 +566,8 @@ TEST_F(LegacyOpsTests, Reduce3_5) {
auto dim = NDArrayFactory::create<int>('c', {1}, {1}); auto dim = NDArrayFactory::create<int>('c', {1}, {1});
NativeOps nativeOps;
nativeOps.execReduce3(nullptr, reduce3::CosineDistance, execReduce3Tad(nullptr, reduce3::CosineDistance,
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
nullptr, nullptr,
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
@ -593,8 +589,7 @@ TEST_F(LegacyOpsTests, test_Reduce3_All_1) {
auto tadPackX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), -1); auto tadPackX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), -1);
auto tadPackY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.shapeInfo(), -1); auto tadPackY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.shapeInfo(), -1);
NativeOps ops; execReduce3All(nullptr, reduce3::EuclideanDistance, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
ops.execReduce3All(nullptr, reduce3::EuclideanDistance, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(), dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(),

View File

@ -33,8 +33,6 @@ public:
}; };
TEST_F(MmapTests, Test_Basic_Mmap_1) { TEST_F(MmapTests, Test_Basic_Mmap_1) {
NativeOps nativeOps;
// just 10GB // just 10GB
Nd4jLong size = 100000L; Nd4jLong size = 100000L;
@ -43,11 +41,11 @@ TEST_F(MmapTests, Test_Basic_Mmap_1) {
ofs.write("", 1); ofs.write("", 1);
ofs.close(); ofs.close();
auto result = nativeOps.mmapFile(nullptr, "file", size); auto result = mmapFile(nullptr, "file", size);
ASSERT_FALSE(result == nullptr); ASSERT_FALSE(result == nullptr);
nativeOps.munmapFile(nullptr, result, size); munmapFile(nullptr, result, size);
remove("file"); remove("file");
} }

View File

@ -2258,7 +2258,6 @@ TEST_F(NDArrayCudaBasicsTests, Test_Empty_4) {
TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_1) { TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_1) {
auto x = NDArrayFactory::create<float>('c', {5,2}, {0,1,2,3,4,5,6,7,8,9}); auto x = NDArrayFactory::create<float>('c', {5,2}, {0,1,2,3,4,5,6,7,8,9});
NativeOps native;
auto z = NDArrayFactory::create<float>('c', {5, 8}); auto z = NDArrayFactory::create<float>('c', {5, 8});
auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream); auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
std::vector<void*> buffers(4); std::vector<void*> buffers(4);
@ -2272,7 +2271,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_1) {
} }
Nd4jPointer extra[2]; Nd4jPointer extra[2];
extra[1] = *stream; extra[1] = *stream;
native.concat(extra, 1, 4, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr); ::concat(extra, 1, 4, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
z.syncToHost(); z.syncToHost();
z.printIndexedBuffer("Concat result"); z.printIndexedBuffer("Concat result");
z.printBuffer("C Concat result linear"); z.printBuffer("C Concat result linear");
@ -2281,7 +2280,6 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_1) {
TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_2) { TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_2) {
auto x = NDArrayFactory::create<float>('c', {5,2}, {0,1,2,3,4,5,6,7,8,9}); auto x = NDArrayFactory::create<float>('c', {5,2}, {0,1,2,3,4,5,6,7,8,9});
NativeOps native;
auto z = NDArrayFactory::create<float>('f', {5, 8}); auto z = NDArrayFactory::create<float>('f', {5, 8});
auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream); auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
std::vector<void*> buffers(4); std::vector<void*> buffers(4);
@ -2295,7 +2293,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_2) {
} }
Nd4jPointer extra[2]; Nd4jPointer extra[2];
extra[1] = *stream; extra[1] = *stream;
native.concat(extra, 1, 4, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr); ::concat(extra, 1, 4, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
z.syncToHost(); z.syncToHost();
z.printIndexedBuffer("Concat result"); z.printIndexedBuffer("Concat result");
z.printBuffer("F Concat result linear"); z.printBuffer("F Concat result linear");
@ -2304,7 +2302,6 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_2) {
TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_3) { TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_3) {
auto x = NDArrayFactory::create<float>('c', {2,3}, {1,2,3,4,5,6}); auto x = NDArrayFactory::create<float>('c', {2,3}, {1,2,3,4,5,6});
auto y = NDArrayFactory::create<float>('c', {1,3}, {7,8,9}); auto y = NDArrayFactory::create<float>('c', {1,3}, {7,8,9});
NativeOps native;
auto z = NDArrayFactory::create<float>('f', {3, 3}); auto z = NDArrayFactory::create<float>('f', {3, 3});
auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream); auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
std::vector<void*> buffers(2); std::vector<void*> buffers(2);
@ -2321,7 +2318,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_3) {
//} //}
Nd4jPointer extra[2]; Nd4jPointer extra[2];
extra[1] = *stream; extra[1] = *stream;
native.concat(extra, 0, 2, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr); ::concat(extra, 0, 2, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
z.syncToHost(); z.syncToHost();
z.printIndexedBuffer("Concat result"); z.printIndexedBuffer("Concat result");
z.printBuffer("F Concat result linear"); z.printBuffer("F Concat result linear");
@ -2331,7 +2328,6 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_3) {
TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_4) { TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_4) {
auto x = NDArrayFactory::create<float>('c', {2,3}, {1,2,3,4,5,6}); auto x = NDArrayFactory::create<float>('c', {2,3}, {1,2,3,4,5,6});
auto y = NDArrayFactory::create<float>('c', {1,3}, {7,8,9}); auto y = NDArrayFactory::create<float>('c', {1,3}, {7,8,9});
NativeOps native;
auto z = NDArrayFactory::create<float>('c', {3, 3}); auto z = NDArrayFactory::create<float>('c', {3, 3});
auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream); auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
std::vector<void*> buffers(2); std::vector<void*> buffers(2);
@ -2348,7 +2344,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_4) {
//} //}
Nd4jPointer extra[2]; Nd4jPointer extra[2];
extra[1] = *stream; extra[1] = *stream;
native.concat(extra, 0, 2, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr); ::concat(extra, 0, 2, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
z.syncToHost(); z.syncToHost();
z.printIndexedBuffer("Concat result"); z.printIndexedBuffer("Concat result");
z.printBuffer("C Concat result linear"); z.printBuffer("C Concat result linear");
@ -2358,7 +2354,6 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_4) {
TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_5) { TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_5) {
auto x = NDArrayFactory::create<float>('c', {1,2,3}, {1,2,3,4,5,6}); auto x = NDArrayFactory::create<float>('c', {1,2,3}, {1,2,3,4,5,6});
auto y = NDArrayFactory::create<float>('c', {1,2,3}, {7,8,9,10,11, 12}); auto y = NDArrayFactory::create<float>('c', {1,2,3}, {7,8,9,10,11, 12});
NativeOps native;
auto z = NDArrayFactory::create<float>('c', {2, 2, 3}); auto z = NDArrayFactory::create<float>('c', {2, 2, 3});
auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream); auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
std::vector<void*> buffers(2); std::vector<void*> buffers(2);
@ -2375,7 +2370,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_5) {
//} //}
Nd4jPointer extra[2]; Nd4jPointer extra[2];
extra[1] = *stream; extra[1] = *stream;
native.concat(extra, 0, 2, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr); ::concat(extra, 0, 2, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
z.syncToHost(); z.syncToHost();
z.printIndexedBuffer("Concat result"); z.printIndexedBuffer("Concat result");
z.printBuffer("C Concat result linear"); z.printBuffer("C Concat result linear");
@ -2385,7 +2380,6 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_6) {
auto x1 = NDArrayFactory::create<float>('c', {2,2,3}, {1,2,3,4,5,6,7,8, 9, 10,11,12}); auto x1 = NDArrayFactory::create<float>('c', {2,2,3}, {1,2,3,4,5,6,7,8, 9, 10,11,12});
auto x2 = NDArrayFactory::create<float>('c', {1,2,3}, {13,14,15,16,17, 18}); auto x2 = NDArrayFactory::create<float>('c', {1,2,3}, {13,14,15,16,17, 18});
auto x3 = NDArrayFactory::create<float>('c', {1,2,3}, {19,20,21,22,23, 24}); auto x3 = NDArrayFactory::create<float>('c', {1,2,3}, {19,20,21,22,23, 24});
NativeOps native;
auto z = NDArrayFactory::create<float>('c', {4, 2, 3}); auto z = NDArrayFactory::create<float>('c', {4, 2, 3});
auto stream = x1.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream); auto stream = x1.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
std::vector<void*> buffers(3); std::vector<void*> buffers(3);
@ -2406,7 +2400,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_6) {
printf("The third array is %p\n", buffers[2]); printf("The third array is %p\n", buffers[2]);
Nd4jPointer extra[2]; Nd4jPointer extra[2];
extra[1] = *stream; extra[1] = *stream;
native.concat(extra, 0, 3, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr); ::concat(extra, 0, 3, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
z.syncToHost(); z.syncToHost();
z.printIndexedBuffer("Concat result"); z.printIndexedBuffer("Concat result");
z.printBuffer("C Concat3D result linear"); z.printBuffer("C Concat3D result linear");
@ -2417,7 +2411,6 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_7) {
auto x1 = NDArrayFactory::create<float>(1); auto x1 = NDArrayFactory::create<float>(1);
auto x2 = NDArrayFactory::create<float>(2); auto x2 = NDArrayFactory::create<float>(2);
auto x3 = NDArrayFactory::create<float>(3); auto x3 = NDArrayFactory::create<float>(3);
NativeOps native;
auto z = NDArrayFactory::create<float>('c', {3}, {1,2,3}); auto z = NDArrayFactory::create<float>('c', {3}, {1,2,3});
auto stream = x1.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream); auto stream = x1.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
std::vector<void*> buffers(3); std::vector<void*> buffers(3);
@ -2438,7 +2431,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_7) {
printf("The third array is %p\n", buffers[2]); printf("The third array is %p\n", buffers[2]);
Nd4jPointer extra[2]; Nd4jPointer extra[2];
extra[1] = *stream; extra[1] = *stream;
native.concat(extra, 0, 3, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr); ::concat(extra, 0, 3, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
z.syncToHost(); z.syncToHost();
z.printIndexedBuffer("Concat result"); z.printIndexedBuffer("Concat result");
z.printBuffer("C Concat scalar result linear"); z.printBuffer("C Concat scalar result linear");
@ -2462,7 +2455,6 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_8) {
lx[i].assign(i); lx[i].assign(i);
} }
NativeOps native;
auto z = NDArrayFactory::create<float>('c', {totalCount, width}); auto z = NDArrayFactory::create<float>('c', {totalCount, width});
auto stream = nd4j::LaunchContext ::defaultContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream); auto stream = nd4j::LaunchContext ::defaultContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
std::vector<void*> buffers(totalCount); std::vector<void*> buffers(totalCount);
@ -2478,7 +2470,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_8) {
printf("The third array is %p\n", buffers[2]); printf("The third array is %p\n", buffers[2]);
Nd4jPointer extra[2]; Nd4jPointer extra[2];
extra[1] = *stream; extra[1] = *stream;
native.concat(extra, 0, totalCount, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr); ::concat(extra, 0, totalCount, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
z.syncToHost(); z.syncToHost();
nd4j_printf("%f %f %f\n", z.e<float>(0), z.e<float>(width * totalCount / 2), z.e<float>(width * (totalCount - 1))); nd4j_printf("%f %f %f\n", z.e<float>(0), z.e<float>(width * totalCount / 2), z.e<float>(width * (totalCount - 1)));
//z.printIndexedBuffer("Concat result"); //z.printIndexedBuffer("Concat result");
@ -2496,7 +2488,6 @@ TEST_F(NDArrayCudaBasicsTests, TestTear_1) {
arrays.emplace_back(input); arrays.emplace_back(input);
} }
auto z = NDArrayFactory::create<float>('c', {total, 10, 10}); auto z = NDArrayFactory::create<float>('c', {total, 10, 10});
NativeOps native;
auto stream = input.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream); auto stream = input.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
Nd4jPointer extra[2]; Nd4jPointer extra[2];
@ -2512,7 +2503,7 @@ TEST_F(NDArrayCudaBasicsTests, TestTear_1) {
hostShapes[i] = arrays[i].shapeInfo(); hostShapes[i] = arrays[i].shapeInfo();
} }
native.concat(extra, 0, total, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr); ::concat(extra, 0, total, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
nd4j::ops::tear op; nd4j::ops::tear op;
auto result = op.execute({&z}, {}, {1, 2}); auto result = op.execute({&z}, {}, {1, 2});
@ -2536,7 +2527,6 @@ TEST_F(NDArrayCudaBasicsTests, TestTear_2) {
arrays.emplace_back(input); arrays.emplace_back(input);
} }
auto z = NDArrayFactory::create<float>('c', {10, 10, 10}); auto z = NDArrayFactory::create<float>('c', {10, 10, 10});
NativeOps native;
auto stream = input.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream); auto stream = input.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
Nd4jPointer extra[2]; Nd4jPointer extra[2];
@ -2552,7 +2542,7 @@ TEST_F(NDArrayCudaBasicsTests, TestTear_2) {
hostShapes[i] = arrays[i].shapeInfo(); hostShapes[i] = arrays[i].shapeInfo();
} }
std::vector<int> dimsToExclude({1,2}); std::vector<int> dimsToExclude({1,2});
native.concat(extra, 0, 10, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr); ::concat(extra, 0, 10, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
// z.syncToHost(); // z.syncToHost();
// z.printBuffer("Pile OK"); // z.printBuffer("Pile OK");
// z.printIndexedBuffer("Pile 10x10"); // z.printIndexedBuffer("Pile 10x10");
@ -2569,7 +2559,7 @@ TEST_F(NDArrayCudaBasicsTests, TestTear_2) {
Nd4jPointer target = arrays[i].specialBuffer(); Nd4jPointer target = arrays[i].specialBuffer();
cudaMemcpy(&arraysData[i], &target, sizeof(Nd4jPointer), cudaMemcpyHostToDevice); cudaMemcpy(&arraysData[i], &target, sizeof(Nd4jPointer), cudaMemcpyHostToDevice);
} }
native.tear(extra, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), arraysData, input.specialShapeInfo(), packX.specialShapeInfo(), packX.specialOffsets()); ::tear(extra, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), arraysData, input.specialShapeInfo(), packX.specialShapeInfo(), packX.specialOffsets());
// auto result = op.execute({&z}, {}, {1, 2}); // auto result = op.execute({&z}, {}, {1, 2});
// nd4j_printf("Result count is %lu\n", result->size()); // nd4j_printf("Result count is %lu\n", result->size());
//ASSERT_EQ(10, result->size()); //ASSERT_EQ(10, result->size());

View File

@ -313,12 +313,10 @@ TEST_F(PlaygroundTests, test_reduce_3) {
Nd4jLong max = 0L; Nd4jLong max = 0L;
Nd4jLong min = DataTypeUtils::max<Nd4jLong>(); Nd4jLong min = DataTypeUtils::max<Nd4jLong>();
NativeOps nativeOps;
for (int e = 0; e < iterations; e++) { for (int e = 0; e < iterations; e++) {
auto timeStart = std::chrono::system_clock::now(); auto timeStart = std::chrono::system_clock::now();
nativeOps.execReduce3(nullptr, reduce3::CosineDistance, x.buffer(), x.shapeInfo(), x.specialBuffer(), execReduce3Tad(nullptr, reduce3::CosineDistance, x.buffer(), x.shapeInfo(), x.specialBuffer(),
x.specialShapeInfo(), nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(), x.specialShapeInfo(), nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(),
y.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), y.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(), nullptr, dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(), nullptr,
@ -964,8 +962,6 @@ TEST_F(PlaygroundTests, Test_Im2Col_1) {
auto legacyPermTime = std::chrono::duration_cast<std::chrono::microseconds> (legacyPermEnd - legacyPermStart).count(); auto legacyPermTime = std::chrono::duration_cast<std::chrono::microseconds> (legacyPermEnd - legacyPermStart).count();
NativeOps nativeOps;
Nd4jLong iArgs[] = {kH, kW, sH, sW, pH, pW, dH, dW, 0}; Nd4jLong iArgs[] = {kH, kW, sH, sW, pH, pW, dH, dW, 0};
Nd4jPointer inputBuffers[] = {input.buffer()}; Nd4jPointer inputBuffers[] = {input.buffer()};
Nd4jPointer inputShapes[] = {input.shapeInfo()}; Nd4jPointer inputShapes[] = {input.shapeInfo()};
@ -976,7 +972,7 @@ TEST_F(PlaygroundTests, Test_Im2Col_1) {
auto javaStart = std::chrono::system_clock::now(); auto javaStart = std::chrono::system_clock::now();
for (int e = 0; e < iterations; e++) { for (int e = 0; e < iterations; e++) {
nativeOps.execCustomOp(nullptr, op.getOpHash(), inputBuffers, inputShapes, 1, outputBuffers, outputShapes, 1, nullptr, 0, iArgs, 9, nullptr, 0, false); execCustomOp(nullptr, op.getOpHash(), inputBuffers, inputShapes, 1, outputBuffers, outputShapes, 1, nullptr, 0, iArgs, 9, nullptr, 0, false);
} }
auto javaEnd = std::chrono::system_clock::now(); auto javaEnd = std::chrono::system_clock::now();
@ -990,7 +986,7 @@ TEST_F(PlaygroundTests, Test_Im2Col_1) {
for (int e = 0; e < iterations; e++) { for (int e = 0; e < iterations; e++) {
nativeOps.execCustomOp(nullptr, op.getOpHash(), inputBuffers, inputShapes, 1, outputPermBuffers, outputPermShapes, 1, nullptr, 0, iArgs, 9, nullptr, 0, false); execCustomOp(nullptr, op.getOpHash(), inputBuffers, inputShapes, 1, outputPermBuffers, outputPermShapes, 1, nullptr, 0, iArgs, 9, nullptr, 0, false);
} }
auto javaPermEnd = std::chrono::system_clock::now(); auto javaPermEnd = std::chrono::system_clock::now();
@ -1020,9 +1016,7 @@ TEST_F(PlaygroundTests, Test_Im2Col_2) {
Nd4jPointer outputPermBuffers[] = {outputPermuted.buffer()}; Nd4jPointer outputPermBuffers[] = {outputPermuted.buffer()};
Nd4jPointer outputPermShapes[] = {outputPermuted.shapeInfo()}; Nd4jPointer outputPermShapes[] = {outputPermuted.shapeInfo()};
NativeOps nativeOps; execCustomOp(nullptr, op.getOpHash(), inputBuffers, inputShapes, 1, outputPermBuffers, outputPermShapes, 1, nullptr, 0, iArgs, 9, nullptr, 0, false);
nativeOps.execCustomOp(nullptr, op.getOpHash(), inputBuffers, inputShapes, 1, outputPermBuffers, outputPermShapes, 1, nullptr, 0, iArgs, 9, nullptr, 0, false);
} }
TEST_F(PlaygroundTests, Test_Col2Im_1) { TEST_F(PlaygroundTests, Test_Col2Im_1) {
@ -1140,8 +1134,6 @@ TEST_F(PlaygroundTests, loop_test_1) {
int length = (int) array->lengthOf(); int length = (int) array->lengthOf();
int span = (int) (array->lengthOf() / 6) + 8; int span = (int) (array->lengthOf() / 6) + 8;
NativeOps ops;
auto t = new int[1000000]; auto t = new int[1000000];
@ -1150,7 +1142,7 @@ TEST_F(PlaygroundTests, loop_test_1) {
FloatBits fb; FloatBits fb;
float threshold = 0.99f; float threshold = 0.99f;
fb.f_ = threshold; fb.f_ = threshold;
int le = ops.estimateThreshold(nullptr, reinterpret_cast<void *>(array->buffer()), array->shapeInfo(), static_cast<int>(array->lengthOf()), threshold); int le = estimateThreshold(nullptr, reinterpret_cast<void *>(array->buffer()), array->shapeInfo(), static_cast<int>(array->lengthOf()), threshold);
t[0] = le; t[0] = le;
t[1] = length; t[1] = length;
@ -1162,7 +1154,7 @@ TEST_F(PlaygroundTests, loop_test_1) {
for (int x = 0; x < iterations; x++) { for (int x = 0; x < iterations; x++) {
auto permStart = std::chrono::system_clock::now(); auto permStart = std::chrono::system_clock::now();
ops.estimateThreshold(nullptr, reinterpret_cast<void *>(array->buffer()), array->shapeInfo(), static_cast<int>(array->lengthOf()), threshold); estimateThreshold(nullptr, reinterpret_cast<void *>(array->buffer()), array->shapeInfo(), static_cast<int>(array->lengthOf()), threshold);
TypeCast::convertToThreshold<float>(nullptr, buffer, array->lengthOf(), t); TypeCast::convertToThreshold<float>(nullptr, buffer, array->lengthOf(), t);
auto permEnd = std::chrono::system_clock::now(); auto permEnd = std::chrono::system_clock::now();

View File

@ -29,7 +29,6 @@ using namespace nd4j;
class RNGTests : public testing::Test { class RNGTests : public testing::Test {
private: private:
NativeOps nativeOps;
//Nd4jLong *_bufferA; //Nd4jLong *_bufferA;
//Nd4jLong *_bufferB; //Nd4jLong *_bufferB;
@ -47,8 +46,8 @@ public:
RNGTests() { RNGTests() {
//_bufferA = new Nd4jLong[100000]; //_bufferA = new Nd4jLong[100000];
//_bufferB = new Nd4jLong[100000]; //_bufferB = new Nd4jLong[100000];
//_rngA = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, _seed, 100000, (Nd4jPointer) _bufferA); //_rngA = (nd4j::random::RandomBuffer *) initRandom(nullptr, _seed, 100000, (Nd4jPointer) _bufferA);
//_rngB = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, _seed, 100000, (Nd4jPointer) _bufferB); //_rngB = (nd4j::random::RandomBuffer *) initRandom(nullptr, _seed, 100000, (Nd4jPointer) _bufferB);
_rngA.setStates(_seed, _seed); _rngA.setStates(_seed, _seed);
_rngB.setStates(_seed, _seed); _rngB.setStates(_seed, _seed);
nexp0->assign(-1.0f); nexp0->assign(-1.0f);
@ -57,8 +56,8 @@ public:
} }
~RNGTests() { ~RNGTests() {
//nativeOps.destroyRandom(_rngA); //destroyRandom(_rngA);
//nativeOps.destroyRandom(_rngB); //destroyRandom(_rngB);
//delete[] _bufferA; //delete[] _bufferA;
//delete[] _bufferB; //delete[] _bufferB;
@ -791,14 +790,13 @@ namespace nd4j {
} }
TEST_F(RNGTests, Test_Reproducibility_9) { TEST_F(RNGTests, Test_Reproducibility_9) {
NativeOps ops;
Nd4jLong seed = 123; Nd4jLong seed = 123;
std::vector<Nd4jLong> shape = {32, 3, 28, 28}; std::vector<Nd4jLong> shape = {32, 3, 28, 28};
const int bufferSize = 10000; const int bufferSize = 10000;
int64_t buffer[bufferSize]; int64_t buffer[bufferSize];
auto rng = (nd4j::random::RandomBuffer *) ops.initRandom(nullptr, seed, bufferSize, buffer); auto rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, seed, bufferSize, buffer);
const int length = 4000000; const int length = 4000000;
int *arrayE = new int[length]; int *arrayE = new int[length];
@ -809,7 +807,7 @@ TEST_F(RNGTests, Test_Reproducibility_9) {
rng->rewindH(static_cast<Nd4jLong>(length)); rng->rewindH(static_cast<Nd4jLong>(length));
ops.refreshBuffer(nullptr, seed, reinterpret_cast<Nd4jPointer>(rng)); refreshBuffer(nullptr, seed, reinterpret_cast<Nd4jPointer>(rng));
for (int e = 0; e < length; e++) for (int e = 0; e < length; e++)
arrayT[e] = rng->relativeInt(e); arrayT[e] = rng->relativeInt(e);
@ -825,18 +823,17 @@ TEST_F(RNGTests, Test_Reproducibility_9) {
delete[] arrayE; delete[] arrayE;
delete[] arrayT; delete[] arrayT;
ops.destroyRandom(reinterpret_cast<Nd4jPointer>(rng)); destroyRandom(reinterpret_cast<Nd4jPointer>(rng));
} }
TEST_F(RNGTests, Test_Reproducibility_8) { TEST_F(RNGTests, Test_Reproducibility_8) {
NativeOps ops;
Nd4jLong seed = 123; Nd4jLong seed = 123;
std::vector<int> shape = {32, 3, 28, 28}; std::vector<int> shape = {32, 3, 28, 28};
const int bufferSize = 10000; const int bufferSize = 10000;
int64_t buffer[bufferSize]; int64_t buffer[bufferSize];
auto rng = (nd4j::random::RandomBuffer *) ops.initRandom(nullptr, seed, bufferSize, buffer); auto rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, seed, bufferSize, buffer);
const int length = 4000000; const int length = 4000000;
int *arrayE = new int[length]; int *arrayE = new int[length];
@ -847,7 +844,7 @@ TEST_F(RNGTests, Test_Reproducibility_8) {
rng->rewindH(static_cast<Nd4jLong>(length)); rng->rewindH(static_cast<Nd4jLong>(length));
ops.refreshBuffer(nullptr, seed, reinterpret_cast<Nd4jPointer>(rng)); refreshBuffer(nullptr, seed, reinterpret_cast<Nd4jPointer>(rng));
for (int e = 0; e < length; e++) for (int e = 0; e < length; e++)
arrayT[e] = static_cast<int>(rng->relativeT<float>(e)); arrayT[e] = static_cast<int>(rng->relativeT<float>(e));
@ -863,29 +860,27 @@ TEST_F(RNGTests, Test_Reproducibility_8) {
delete[] arrayE; delete[] arrayE;
delete[] arrayT; delete[] arrayT;
ops.destroyRandom(reinterpret_cast<Nd4jPointer>(rng)); destroyRandom(reinterpret_cast<Nd4jPointer>(rng));
} }
TEST_F(RNGTests, Test_RandomBuffer_Half_1) { TEST_F(RNGTests, Test_RandomBuffer_Half_1) {
NativeOps ops;
Nd4jLong seed = 123; Nd4jLong seed = 123;
std::vector<Nd4jLong> shape = {32, 3, 28, 28}; std::vector<Nd4jLong> shape = {32, 3, 28, 28};
const int bufferSize = 10000; const int bufferSize = 10000;
int64_t buffer[bufferSize]; int64_t buffer[bufferSize];
auto rng = (nd4j::random::RandomBuffer *) ops.initRandom(nullptr, seed, bufferSize, buffer); auto rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, seed, bufferSize, buffer);
auto r0 = rng->relativeT<float16>(12L); auto r0 = rng->relativeT<float16>(12L);
auto r1 = rng->relativeT<float16>(13L); auto r1 = rng->relativeT<float16>(13L);
ASSERT_NE(r0, r1); ASSERT_NE(r0, r1);
ops.destroyRandom(reinterpret_cast<Nd4jPointer>(rng)); destroyRandom(reinterpret_cast<Nd4jPointer>(rng));
} }
TEST_F(RNGTests, Test_Reproducibility_1) { TEST_F(RNGTests, Test_Reproducibility_1) {
NativeOps ops;
Nd4jLong seed = 123; Nd4jLong seed = 123;
std::vector<Nd4jLong> shape = {32, 3, 28, 28}; std::vector<Nd4jLong> shape = {32, 3, 28, 28};
@ -918,7 +913,6 @@ TEST_F(RNGTests, Test_Reproducibility_1) {
#ifndef DEBUG_BUILD #ifndef DEBUG_BUILD
TEST_F(RNGTests, Test_Reproducibility_2) { TEST_F(RNGTests, Test_Reproducibility_2) {
NativeOps ops;
Nd4jLong seed = 123; Nd4jLong seed = 123;
std::vector<Nd4jLong> shape = {32, 3, 64, 64}; std::vector<Nd4jLong> shape = {32, 3, 64, 64};

View File

@ -44,8 +44,7 @@ TEST_F(SortCpuTests, test_linear_sort_by_key_1) {
auto ev = NDArrayFactory::create<double>('c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); auto ev = NDArrayFactory::create<double>('c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5});
NativeOps nativeOps; sortByKey(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);
nativeOps.sortByKey(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);
ASSERT_EQ(ek, k); ASSERT_EQ(ek, k);
ASSERT_EQ(ev, v); ASSERT_EQ(ev, v);
@ -62,8 +61,7 @@ TEST_F(SortCpuTests, test_linear_sort_by_val_1) {
auto ev = NDArrayFactory::create<double>('c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); auto ev = NDArrayFactory::create<double>('c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5});
NativeOps nativeOps; sortByValue(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);
nativeOps.sortByValue(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);
ASSERT_EQ(ek, k); ASSERT_EQ(ek, k);
ASSERT_EQ(ev, v); ASSERT_EQ(ev, v);
@ -81,8 +79,7 @@ TEST_F(SortCpuTests, test_tad_sort_by_key_1) {
int axis = 1; int axis = 1;
NativeOps nativeOps; sortTadByKey(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false);
nativeOps.sortTadByKey(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false);
ASSERT_EQ(ek, k); ASSERT_EQ(ek, k);
ASSERT_EQ(ev, v); ASSERT_EQ(ev, v);
@ -100,8 +97,7 @@ TEST_F(SortCpuTests, test_tad_sort_by_val_1) {
int axis = 1; int axis = 1;
NativeOps nativeOps; sortTadByValue(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false);
nativeOps.sortTadByValue(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false);
ASSERT_EQ(ek, k); ASSERT_EQ(ek, k);
ASSERT_EQ(ev, v); ASSERT_EQ(ev, v);

View File

@ -42,8 +42,7 @@ TEST_F(SortCudaTests, test_linear_sort_by_key_1) {
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()}; Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
NativeOps nativeOps; sortByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);
nativeOps.sortByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);
k.tickWriteDevice(); k.tickWriteDevice();
v.tickWriteDevice(); v.tickWriteDevice();
@ -60,8 +59,7 @@ TEST_F(SortCudaTests, test_linear_sort_by_val_1) {
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()}; Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
NativeOps nativeOps; sortByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);
nativeOps.sortByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);
k.tickWriteDevice(); k.tickWriteDevice();
v.tickWriteDevice(); v.tickWriteDevice();
@ -78,8 +76,7 @@ TEST_F(SortCudaTests, test_linear_sort_by_val_2) {
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()}; Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
NativeOps nativeOps; sortByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), true);
nativeOps.sortByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), true);
k.tickWriteDevice(); k.tickWriteDevice();
v.tickWriteDevice(); v.tickWriteDevice();
k.printIndexedBuffer("KEYS"); k.printIndexedBuffer("KEYS");
@ -97,8 +94,7 @@ TEST_F(SortCudaTests, test_tad_sort_by_key_1) {
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()}; Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
int axis = 1; int axis = 1;
NativeOps nativeOps; sortTadByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false);
nativeOps.sortTadByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false);
k.tickWriteDevice(); k.tickWriteDevice();
v.tickWriteDevice(); v.tickWriteDevice();
@ -119,8 +115,7 @@ TEST_F(SortCudaTests, test_tad_sort_by_val_1) {
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()}; Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};
int axis = 1; int axis = 1;
NativeOps nativeOps; sortTadByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false);
nativeOps.sortTadByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false);
k.tickWriteDevice(); k.tickWriteDevice();
v.tickWriteDevice(); v.tickWriteDevice();

View File

@ -58,8 +58,7 @@ TEST_F(TypeCastTests, Test_ConvertDtype_1) {
float16 dst[5]; float16 dst[5];
float16 exp[] = {(float16) 1.0f, (float16) 2.0f, (float16) 3.0f, (float16) 4.0f, (float16) 5.0f}; float16 exp[] = {(float16) 1.0f, (float16) 2.0f, (float16) 3.0f, (float16) 4.0f, (float16) 5.0f};
NativeOps ops; convertTypes(nullptr, ND4J_FLOAT32, src, 5, ND4J_FLOAT16, dst);
ops.convertTypes(nullptr, ND4J_FLOAT32, src, 5, ND4J_FLOAT16, dst);
for (int e = 0; e < 5; e++) for (int e = 0; e < 5; e++)
ASSERT_NEAR(exp[e], dst[e], (float16) 0.01f); ASSERT_NEAR(exp[e], dst[e], (float16) 0.01f);