Refactor NativeOps.h to export C functions
parent fad8da878f
commit dcc72e23b2

@@ -77,10 +77,7 @@ bool verbose = false;
 #include <graph/ResultWrapper.h>
 #include <DebugInfo.h>
 
-class ND4J_EXPORT NativeOps {
-
-public:
-    NativeOps();
+extern "C" {
 
 /**
  *
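For orientation, a minimal sketch (not part of the patch) of how a call site changes once the class wrapper is gone; getDeviceFreeMemory/getDeviceFreeMemoryDefault come from the hunks below, while the surrounding function and include path are illustrative assumptions only:

    #include <NativeOps.h>   // assumed include; the actual path may differ in the repo

    void caller() {
        // before this patch: entry points were members of the (stateless) NativeOps class
        //   NativeOps nativeOps;
        //   Nd4jLong freeMem = nativeOps.getDeviceFreeMemory();

        // after this patch: the same entry point is a free function with C linkage
        Nd4jLong freeMem = getDeviceFreeMemoryDefault();
        (void) freeMem;
    }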
@@ -259,7 +256,7 @@ public:
     * @param result
     * @param resultShapeInfo
     */
-    void execReduceFloat(Nd4jPointer *extraPointers,
+    void execReduceFloat2(Nd4jPointer *extraPointers,
                          int opNum,
                          void *hX, Nd4jLong *hXShapeInfo,
                          void *dX, Nd4jLong *dXShapeInfo,
@@ -270,7 +267,7 @@ public:
                          void *dDimension, Nd4jLong *dDimensionShape);


-    void execReduceSame(Nd4jPointer *extraPointers,
+    void execReduceSame2(Nd4jPointer *extraPointers,
                         int opNum,
                         void *hX, Nd4jLong *hXShapeInfo,
                         void *dX, Nd4jLong *dXShapeInfo,
@@ -281,7 +278,7 @@ public:
                         void *dDimension, Nd4jLong *dDimensionShape);


-    void execReduceBool(Nd4jPointer *extraPointers,
+    void execReduceBool2(Nd4jPointer *extraPointers,
                         int opNum,
                         void *hX, Nd4jLong *hXShapeInfo,
                         void *dX, Nd4jLong *dXShapeInfo,
@@ -292,7 +289,7 @@ public:
                         void *dDimension, Nd4jLong *dDimensionShape);


-    void execReduceLong(Nd4jPointer *extraPointers,
+    void execReduceLong2(Nd4jPointer *extraPointers,
                         int opNum,
                         void *hX, Nd4jLong *hXShapeInfo,
                         void *dX, Nd4jLong *dXShapeInfo,
@@ -354,7 +351,7 @@ public:
     * @param dimension
     * @param dimensionLength
     */
-    void execReduce3(Nd4jPointer *extraPointers,
+    void execReduce3Tad(Nd4jPointer *extraPointers,
                      int opNum,
                      void *hX, Nd4jLong *hXShapeInfo,
                      void *dX, Nd4jLong *dXShapeInfo,
@@ -457,7 +454,7 @@ public:
     * @param dimension
     * @param dimensionLength
     */
-    void execSummaryStats(Nd4jPointer *extraPointers,
+    void execSummaryStatsTad(Nd4jPointer *extraPointers,
                           int opNum,
                           void *hX, Nd4jLong *hXShapeInfo,
                           void *dX, Nd4jLong *dXShapeInfo,
@@ -532,7 +529,7 @@ public:
     * @param dimension
     * @param dimensionLength
     */
-    void execScalar(Nd4jPointer *extraPointers,
+    void execScalarTad(Nd4jPointer *extraPointers,
                     int opNum,
                     void *hX, Nd4jLong *hXShapeInfo,
                     void *dX, Nd4jLong *dXShapeInfo,
@@ -546,7 +543,7 @@ public:
                     Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets,
                     Nd4jLong *tadShapeInfoZ, Nd4jLong *tadOffsetsZ);

-    void execScalarBool(Nd4jPointer *extraPointers,
+    void execScalarBoolTad(Nd4jPointer *extraPointers,
                         int opNum,
                         void *hX, Nd4jLong *hXShapeInfo,
                         void *dX, Nd4jLong *dXShapeInfo,
@@ -743,7 +740,7 @@ public:
     * Returns amount of free memory for current device
     * @return
     */
-    Nd4jLong getDeviceFreeMemory();
+    Nd4jLong getDeviceFreeMemoryDefault();

    /**
     *
@@ -789,7 +786,7 @@ public:
     * @param reserved
     * @return
     */
-    int memcpy(Nd4jPointer dst,
+    int memcpySync(Nd4jPointer dst,
               Nd4jPointer src,
               Nd4jLong size,
               int flags,
@@ -819,7 +816,7 @@ public:
     * @param reserved
     * @return
     */
-    int memset(Nd4jPointer dst,
+    int memsetSync(Nd4jPointer dst,
              int value,
              Nd4jLong size,
              int flags,
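These renames are forced by the new linkage rather than being cosmetic: functions declared with C linkage cannot be overloaded, and exporting memcpy/memset would collide with the C standard library, so each former overload gets a distinct exported name (execReduceFloat2, execReduce3Tad, execScalarTad, memcpySync, memsetSync, and so on). A minimal illustration of the conflict the renames avoid, with stand-in typedefs (not part of the patch):

    // illustration only
    typedef void*     Nd4jPointer;   // stand-ins for the sketch
    typedef long long Nd4jLong;

    extern "C" {
        void execReduceFloat(Nd4jPointer *extraPointers, int opNum,
                             void *hX, Nd4jLong *hXShapeInfo);

        // A second overload with extra dimension arguments would not compile:
        //   error: conflicting declaration of C function 'void execReduceFloat(...)'
        // void execReduceFloat(Nd4jPointer *extraPointers, int opNum,
        //                      void *hX, Nd4jLong *hXShapeInfo,
        //                      void *hDimension, Nd4jLong *hDimensionShape);

        // Likewise, exporting memcpy/memset here would clash with the C library,
        // hence the memcpySync/memsetSync names.
    }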
@@ -1058,8 +1055,7 @@ public:
                       nd4j::DataType dtype);


-    template <typename T>
-    void _batchExecutor(Nd4jPointer *extraPointers,
+    void batchExecutor(Nd4jPointer *extraPointers,
                        int numAggregates,
                        int opNum,
                        int maxArgs,
@@ -1116,7 +1112,7 @@ public:
     * @param zShapeBuffer
     * @param extraArguments
     */
-    void execRandom(Nd4jPointer *extraPointers,
+    void execRandom3(Nd4jPointer *extraPointers,
                     int opNum,
                     Nd4jPointer state,
                     void *hX, Nd4jLong *hXShapeBuffer,
@@ -1138,7 +1134,7 @@ public:
     * @param zShapeBuffer
     * @param extraArguments
     */
-    void execRandom(Nd4jPointer *extraPointers,
+    void execRandom2(Nd4jPointer *extraPointers,
                     int opNum,
                     Nd4jPointer state,
                     void *hX, Nd4jLong *hXShapeBuffer,
@@ -1232,6 +1228,9 @@ public:
                    double scalarB);

*/
+
+}
+
/**
 *
 * @param data
@@ -1267,7 +1266,9 @@ public:
        return reinterpret_cast<Nd4jPointer>(ret);
    }

-    Nd4jPointer numpyHeaderForNd4j(Nd4jPointer data,Nd4jPointer shapeBuffer,Nd4jLong wordSize,Nd4jLong *headerSize) {
+extern "C" {
+
+    static Nd4jPointer numpyHeaderForNd4j(Nd4jPointer data,Nd4jPointer shapeBuffer,Nd4jLong wordSize,Nd4jLong *headerSize) {
        auto shapeBufferCast = reinterpret_cast<Nd4jLong *>(shapeBuffer);
        auto type = nd4j::ArrayOptions::dataType(shapeBufferCast);
        BUILD_SINGLE_SELECTOR(type, return _numpyHeaderForNd4j, (data, shapeBuffer, wordSize, headerSize), LIBND4J_TYPES);
@@ -1279,7 +1280,7 @@ public:
     * @param data the header data to parse
     * @return a pointer to a numpy cnpy:NpyArray struct
     */
-    Nd4jPointer loadNpyFromHeader(Nd4jPointer data) {
+    static Nd4jPointer loadNpyFromHeader(Nd4jPointer data) {
        char *header = reinterpret_cast<char *>(data);

        cnpy::NpyArray arr = cnpy::loadNpyFromHeader(header);
@@ -1295,6 +1296,7 @@ public:
        return reinterpret_cast<Nd4jPointer>(ret);
    }

+}

/**
 * Create a numpy array from an nd4j
@@ -1329,8 +1331,9 @@ public:
        return rettPointer;
    }

+extern "C" {

-    Nd4jPointer numpyFromNd4j(Nd4jPointer data,Nd4jPointer shapeBuffer,Nd4jLong wordSize) {
+    static Nd4jPointer numpyFromNd4j(Nd4jPointer data,Nd4jPointer shapeBuffer,Nd4jLong wordSize) {
        auto shapeBufferCast = reinterpret_cast<Nd4jLong *>(shapeBuffer);
        auto type = nd4j::ArrayOptions::dataType(shapeBufferCast);
        BUILD_SINGLE_SELECTOR(type, return _numpyFromNd4j, (data, shapeBuffer, wordSize), LIBND4J_TYPES);
@@ -1352,7 +1355,7 @@ public:
     * @param npyArray
     * @return
     */
-    Nd4jPointer shapeBufferForNumpyHeader(Nd4jPointer npyArray) {
+    static Nd4jPointer shapeBufferForNumpyHeader(Nd4jPointer npyArray) {
        cnpy::NpyArray arr = cnpy::loadNpyFromHeader(reinterpret_cast<char *>(npyArray));
        auto shape = new unsigned int[arr.shape.size()];
        for(unsigned int i = 0; i < arr.shape.size(); i++) {
@@ -1371,7 +1374,7 @@ public:
     * @param npyArray
     * @return
     */
-    Nd4jPointer dataPointForNumpyHeader(Nd4jPointer npyArray) {
+    static Nd4jPointer dataPointForNumpyHeader(Nd4jPointer npyArray) {
        cnpy::NpyArray arr = cnpy::loadNpyFromHeader(reinterpret_cast<char *>(npyArray));
        unsigned char *dataToPrint = reinterpret_cast<unsigned char *>(arr.data);
        return dataToPrint;
@@ -1382,7 +1385,7 @@ public:
     * @param npyArray
     * @return
     */
-    Nd4jPointer dataPointForNumpyStruct(Nd4jPointer npyArrayStruct) {
+    static Nd4jPointer dataPointForNumpyStruct(Nd4jPointer npyArrayStruct) {
        cnpy::NpyArray *arrPointer = reinterpret_cast<cnpy::NpyArray *>(npyArrayStruct);
        unsigned char *dataToPrint = reinterpret_cast<unsigned char *>(arrPointer->data);
        return reinterpret_cast<Nd4jPointer>(dataToPrint);
@@ -1394,7 +1397,7 @@ public:
     * @param fromFile
     * @return
     */
-    Nd4jPointer dataPointForNumpy(Nd4jPointer npyArray) {
+    static Nd4jPointer dataPointForNumpy(Nd4jPointer npyArray) {
        char *npyArrayBuffer = reinterpret_cast< char *>(npyArray);
        cnpy::NpyArray arr = cnpy::loadNpyFromPointer(npyArrayBuffer);
        return dataPointForNumpyStruct(reinterpret_cast<Nd4jPointer>(&arr));
@@ -1406,7 +1409,7 @@ public:
     * @param path
     * @return
     */
-    Nd4jPointer numpyFromFile(std::string path) {
+    static Nd4jPointer numpyFromFile(std::string path) {
        char *numpyBuffer = cnpy::loadFile(path.data());
        return reinterpret_cast<Nd4jPointer >(numpyBuffer);
    }
@@ -1414,7 +1417,7 @@ public:

    ////// NPZ //////

-    void* mapFromNpzFile(std::string path){
+    static void* mapFromNpzFile(std::string path){
        cnpy::npz_t* mapPtr = new cnpy::npz_t();
        cnpy::npz_t map = cnpy::npzLoad(path);
        mapPtr->insert(map.begin(), map.end());
@@ -1422,13 +1425,13 @@ public:
    }


-    int getNumNpyArraysInMap(void *map){
+    static int getNumNpyArraysInMap(void *map){
        cnpy::npz_t* arrays = reinterpret_cast<cnpy::npz_t*>(map);
        int n = arrays->size();
        return n;
    }

-    const char* getNpyArrayNameFromMap(void *map, int index){
+    static const char* getNpyArrayNameFromMap(void *map, int index){
        cnpy::npz_t* arrays = reinterpret_cast<cnpy::npz_t*>(map);
        cnpy::npz_t::iterator it = arrays->begin();
        cnpy::npz_t::iterator end = arrays->end();
@@ -1442,7 +1445,7 @@ public:
        throw std::runtime_error("No array at index.");
    }

-    void* getNpyArrayFromMap(void *map, int index){
+    static void* getNpyArrayFromMap(void *map, int index){
        cnpy::npz_t* arrays = reinterpret_cast<cnpy::npz_t*>(map);
        cnpy::npz_t::iterator it = arrays->begin();
        cnpy::npz_t::iterator end = arrays->end();
@@ -1459,18 +1462,18 @@ public:

    int dataTypeFromNpyHeader(void *header);

-    void* getNpyArrayData(void *npArray){
+    static void* getNpyArrayData(void *npArray){
        cnpy::NpyArray* npyArray2 = reinterpret_cast<cnpy::NpyArray*>(npArray);
        return reinterpret_cast<void*>(npyArray2->data);
    }

-    int getNpyArrayRank(void *npArray){
+    static int getNpyArrayRank(void *npArray){
        cnpy::NpyArray* arr = reinterpret_cast<cnpy::NpyArray*>(npArray);
        int rank = arr->shape.size();
        return rank;
    }

-    Nd4jLong* getNpyArrayShape(void *npArray){
+    static Nd4jLong* getNpyArrayShape(void *npArray){
        cnpy::NpyArray* arr = reinterpret_cast<cnpy::NpyArray*>(npArray);
        int ndim = arr->shape.size();
        Nd4jLong* shape = new Nd4jLong[ndim];
@@ -1480,22 +1483,22 @@ public:
        return shape;
    }

-    char getNpyArrayOrder(void *npArray){
+    static char getNpyArrayOrder(void *npArray){
        cnpy::NpyArray* arr = reinterpret_cast<cnpy::NpyArray*>(npArray);
        return (arr->fortranOrder)?'f':'c';
    }

-    int getNpyArrayElemSize(void *npArray){
+    static int getNpyArrayElemSize(void *npArray){
        cnpy::NpyArray* arr = reinterpret_cast<cnpy::NpyArray*>(npArray);
        return arr->wordSize;
    }

-    void deleteNPArrayStruct(void *npArray){
+    static void deleteNPArrayStruct(void *npArray){
        cnpy::NpyArray* arr = reinterpret_cast<cnpy::NpyArray*>(npArray);
        delete arr;
    }

-    void deleteNPArrayMap(void *map){
+    static void deleteNPArrayMap(void *map){
        cnpy::npz_t* arrays = reinterpret_cast<cnpy::npz_t*>(map);
        delete arrays;
    }
@@ -1507,7 +1510,7 @@ public:
     * to get the length for
     * @return
     */
-    int elementSizeForNpyArray(Nd4jPointer npyArray) {
+    static int elementSizeForNpyArray(Nd4jPointer npyArray) {
        cnpy::NpyArray arr = cnpy::loadNpyFromPointer(reinterpret_cast<char *>(npyArray));
        cnpy::NpyArray *arrPointer = &arr;
        int size = arrPointer->wordSize;
@@ -1522,7 +1525,7 @@ public:
     * to get the length for
     * @return
     */
-    int elementSizeForNpyArrayHeader(Nd4jPointer npyArray) {
+    static int elementSizeForNpyArrayHeader(Nd4jPointer npyArray) {
        cnpy::NpyArray arr = cnpy::loadNpyFromHeader(reinterpret_cast<char *>(npyArray));
        cnpy::NpyArray *arrPointer = &arr;
        int size = arrPointer->wordSize;
@@ -1530,7 +1533,7 @@ public:
    }


-    void releaseNumpy(Nd4jPointer npyArray) {
+    static void releaseNumpy(Nd4jPointer npyArray) {
        free(reinterpret_cast<void *>(npyArray));
    }

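A note on the static keyword that appears throughout the numpy/npz helpers above: these functions are defined, not just declared, in the header, so without static (or inline) every translation unit that includes the header would emit an external definition and linking would fail with duplicate symbols; static gives each translation unit its own private copy. A tiny self-contained illustration (not from the repository):

    // duplicate_symbol_demo.h — illustrative header only
    #pragma once

    // Without 'static' this definition has external linkage; including the header
    // from two .cpp files makes the linker see two definitions of headerHelper()
    // and fail. With 'static' each translation unit keeps its own copy, which is
    // the pattern used by the helpers above.
    static int headerHelper(int x) {
        return x * 2;
    }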
@@ -1647,10 +1650,10 @@ public:

    // customOp executioner
    int execCustomOp(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputs, Nd4jPointer* outputBuffers, Nd4jPointer* outputShapes, int numOutputs, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool* bArgs, int numBArgs, bool isInplace);
-    int execCustomOp(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer opContext);
+    int execCustomOp2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer opContext);

    nd4j::ShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs);
-    nd4j::ShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs);
+    nd4j::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs);

    void deleteShapeList(Nd4jPointer shapeList);

@@ -1690,25 +1693,22 @@ public:
                          void* dY, Nd4jLong* dYShapeInfo, Nd4jLong* dYOffsets,
                          int* hIindexes, int* dIindexes);

+    void deleteShapeBuffer(Nd4jPointer ptr);
+    void deleteTadPack(Nd4jPointer ptr);

    void inspectArray(Nd4jPointer *extraPointers, Nd4jPointer buffer, Nd4jLong *shapeInfo, Nd4jPointer specialBuffer, Nd4jLong *specialShapeInfo, Nd4jPointer debugInfo);


    nd4j::ConstantDataBuffer* shapeBuffer(int rank, Nd4jLong *shape, Nd4jLong *strides, nd4j::DataType dtype, char order, Nd4jLong ews, bool empty);

-    nd4j::ConstantDataBuffer* constantBuffer(nd4j::DataType dtype, Nd4jLong *data, int length);
-    nd4j::ConstantDataBuffer* constantBuffer(nd4j::DataType dtype, double *data, int length);
+    nd4j::ConstantDataBuffer* constantBufferLong(nd4j::DataType dtype, Nd4jLong *data, int length);
+    nd4j::ConstantDataBuffer* constantBufferDouble(nd4j::DataType dtype, double *data, int length);
    nd4j::ConstantDataBuffer* constantBuffer(nd4j::DataType dtype, nd4j::ConstantDescriptor *descriptor);

-    void deleteShapeBuffer(Nd4jPointer ptr);
-    void deleteTadPack(Nd4jPointer ptr);

    const char* runLightBenchmarkSuit(bool printOut);
    const char* runFullBenchmarkSuit(bool printOut);
-};



+}

#endif //NATIVEOPERATIONS_NATIVEOPS_H
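Taken together, the refactored header now has roughly this shape. The sketch below is an illustration only: the guard macro name is assumed from the #endif comment, stand-in typedefs replace the real ones, and almost all declarations are omitted; releaseNumpy's body is taken from the hunk above.

    #ifndef NATIVEOPERATIONS_NATIVEOPS_H
    #define NATIVEOPERATIONS_NATIVEOPS_H

    #include <cstdlib>

    typedef void*     Nd4jPointer;   // stand-ins for the sketch
    typedef long long Nd4jLong;

    extern "C" {

        // former NativeOps members, now exported with C linkage:
        Nd4jLong getDeviceFreeMemoryDefault();
        int execCustomOp2(Nd4jPointer *extraPointers, Nd4jLong hash, Nd4jPointer opContext);
        void deleteShapeList(Nd4jPointer shapeList);

        // header-defined helpers keep their bodies but become static:
        static void releaseNumpy(Nd4jPointer npyArray) {
            free(reinterpret_cast<void *>(npyArray));
        }

    }

    #endif // NATIVEOPERATIONS_NATIVEOPS_H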
(File diff suppressed because it is too large)
(File diff suppressed because it is too large)
@@ -27,11 +27,10 @@ namespace nd4j {
    ProviderRNG::ProviderRNG() {

        Nd4jLong *buffer = new Nd4jLong[100000];
-        NativeOps nativeOps;
        std::lock_guard<std::mutex> lock(_mutex);
#ifndef __CUDABLAS__
        // at this moment we don't have streams etc, so let's just skip this for now
-        _rng = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, 123, 100000, (Nd4jPointer) buffer);
+        _rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, 123, 100000, (Nd4jPointer) buffer);
#endif
        // if(_rng != nullptr)
    }
@@ -41,8 +41,7 @@ namespace nd4j {
    }

    // FIXME: this approach isn't really good for cuda, since it'll assume that CUDA might get nullptr instead of stream
-    NativeOps nativeOps;
-    nativeOps.refreshBuffer(nullptr, seed, (Nd4jPointer) rng);
+    refreshBuffer(nullptr, seed, (Nd4jPointer) rng);

    return Status::OK();
    }
@@ -110,11 +110,9 @@ namespace helpers {
        indices->syncToDevice(); // linspace only on CPU, so sync to Device as well

        NDArray scores(*scales);
-        NativeOps nativeOps;
-
        Nd4jPointer extras[2] = {nullptr, stream};

-        nativeOps.sortByValue(extras, indices->buffer(), indices->shapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), scores.buffer(), scores.shapeInfo(), scores.specialBuffer(), scores.specialShapeInfo(), true);
+        sortByValue(extras, indices->buffer(), indices->shapeInfo(), indices->specialBuffer(), indices->specialShapeInfo(), scores.buffer(), scores.shapeInfo(), scores.specialBuffer(), scores.specialShapeInfo(), true);
        // TO DO: sort indices using scales as value row
        //std::sort(indices.begin(), indices.end(), [scales](int i, int j) {return scales->e<T>(i) > scales->e<T>(j);});
        I* indexBuf = reinterpret_cast<I*>(indices->specialBuffer());
@@ -60,8 +60,7 @@ namespace helpers {
            params[1] = context->getCudaStream();

            if (input->isVector()) {
-                NativeOps ops;
-                ops.sort(params, nullptr, sortedVals.shapeInfo(), sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), reverse);
+                sort(params, nullptr, sortedVals.shapeInfo(), sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), reverse);

                cudaMemcpy(reinterpret_cast<T*>(output->specialBuffer()), reinterpret_cast<T*>(sortedVals.specialBuffer()) + n, sizeof(T), cudaMemcpyDeviceToDevice);
            }
@@ -74,8 +73,7 @@ namespace helpers {
                auto pTadShapeH = packX.primaryShapeInfo();
                auto pTadOffsets = packX.specialOffsets();
                // auto pLastDimData = (int*) manager.replicatePointer(lastDims.data(), lastDims.size() * sizeof(int));
-                NativeOps ops;
-                ops.sortTad(params, sortedVals.buffer(), sortedVals.shapeInfo(), sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), lastDims.data(), lastDims.size(), pTadShape, pTadOffsets, reverse);
+                sortTad(params, sortedVals.buffer(), sortedVals.shapeInfo(), sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), lastDims.data(), lastDims.size(), pTadShape, pTadOffsets, reverse);
                // manager.synchronize();
                sortedVals.tickWriteDevice();
                sortedVals.syncToHost();
@@ -38,32 +38,28 @@ TEST_F(HeaderTest, test_dataTypes_1) {
    std::string header("0NUMPY6789{'descr': '>f4");


-    NativeOps nativeOps;
-    ASSERT_EQ(nd4j::DataType::FLOAT32, nativeOps.dataTypeFromNpyHeader(const_cast<char *>(header.data())));
+    ASSERT_EQ(nd4j::DataType::FLOAT32, dataTypeFromNpyHeader(const_cast<char *>(header.data())));
}

TEST_F(HeaderTest, test_dataTypes_2) {
    std::string header("0NUMPY6789{'descr': '>f8");


-    NativeOps nativeOps;
-    ASSERT_EQ(nd4j::DataType::DOUBLE, nativeOps.dataTypeFromNpyHeader(const_cast<char *>(header.data())));
+    ASSERT_EQ(nd4j::DataType::DOUBLE, dataTypeFromNpyHeader(const_cast<char *>(header.data())));
}

TEST_F(HeaderTest, test_dataTypes_3) {
    std::string header("0NUMPY6789{'descr': '<i4");


-    NativeOps nativeOps;
-    ASSERT_EQ(nd4j::DataType::INT32, nativeOps.dataTypeFromNpyHeader(const_cast<char *>(header.data())));
+    ASSERT_EQ(nd4j::DataType::INT32, dataTypeFromNpyHeader(const_cast<char *>(header.data())));
}

TEST_F(HeaderTest, test_dataTypes_4) {
    std::string header("0NUMPY6789{'descr': '>u2");


-    NativeOps nativeOps;
-    ASSERT_EQ(nd4j::DataType::UINT16, nativeOps.dataTypeFromNpyHeader(const_cast<char *>(header.data())));
+    ASSERT_EQ(nd4j::DataType::UINT16, dataTypeFromNpyHeader(const_cast<char *>(header.data())));
}

/*
@@ -88,8 +84,7 @@ TEST_F(LoadFromStringTest,PathTest) {
    ASSERT_EQ(4.0,data[3]);
    Nd4jPointer pointer = reinterpret_cast<Nd4jPointer >(&loadedArr);
    int *shapeBuffer = shape::shapeBufferOfNpy(loadedArr);
-    NativeOps nativeOps;
-    Nd4jPointer pointer1 = nativeOps.dataPointForNumpy(loaded);
+    Nd4jPointer pointer1 = dataPointForNumpy(loaded);
    delete[] shapeBuffer;

    double *data2 = reinterpret_cast<double *>(pointer1);
@@ -472,9 +472,7 @@ TEST_F(DeclarableOpsTests1, TestRng1) {
/*
    Nd4jLong *buffer = new Nd4jLong[100000];

-    NativeOps nativeOps;
-
-    nd4j::random::RandomBuffer *rng = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, 123, 100000, (Nd4jPointer) buffer);
+    nd4j::random::RandomBuffer *rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, 123, 100000, (Nd4jPointer) buffer);

    if (rng == nullptr)
        throw std::runtime_error("RNG initialization failed");
@@ -496,7 +494,7 @@ TEST_F(DeclarableOpsTests1, TestRng1) {

    ASSERT_TRUE(x->sumNumber() > 0.0);

-    nativeOps.destroyRandom((Nd4jPointer) rng);
+    destroyRandom((Nd4jPointer) rng);
    delete[] buffer;

    delete variableSpace;
@@ -1450,8 +1448,6 @@ TEST_F(DeclarableOpsTests1, TestRegistrator1) {

// //////////////////////////////////////////////////////////////////////
// TEST_F(DeclarableOpsTests1, TestLegacyExecution1) {
-//    NativeOps nativeOps;
-
//    auto x = NDArrayFactory::create_<float>('c', {10, 10});
//    x->assign(1.0f);

@@ -1483,8 +1479,8 @@ TEST_F(DeclarableOpsTests1, TestRegistrator1) {
//    outputShapes[0] = (Nd4jPointer) z->getShapeInfo();


-//    //auto status = nativeOps.execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, false);
-//    auto status = nativeOps.execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
+//    //auto status = execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, false);
+//    auto status = execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
//    ASSERT_EQ(ND4J_STATUS_OK, status);
//    // z->printIndexedBuffer("Output add");
//    ASSERT_NEAR(2.0f, y->meanNumber().e<float>(0), 1e-5);
@@ -1503,8 +1499,6 @@ TEST_F(DeclarableOpsTests1, TestRegistrator1) {

// //////////////////////////////////////////////////////////////////////
// TEST_F(DeclarableOpsTests1, TestLegacyExecution2) {
-//    NativeOps nativeOps;
-
//    auto x = NDArrayFactory::create_<float>('c', {10, 10});
//    x->assign(1.0f);

@@ -1532,7 +1526,7 @@ TEST_F(DeclarableOpsTests1, TestRegistrator1) {
//    auto outputBuffers = new Nd4jPointer[1];
//    auto outputShapes = new Nd4jPointer[1];

-//    nativeOps.execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, true);
+//    execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, true);

//    ASSERT_NEAR(2.0, y->meanNumber().e<float>(0), 1e-5);
//    ASSERT_NEAR(3.0, x->meanNumber().e<float>(0), 1e-5);
@@ -876,14 +876,13 @@ TEST_F(DeclarableOpsTests12, pullRows_1) {
    auto xTadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.getShapeInfo(), dims);
    auto zTadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(z.getShapeInfo(), dims);

-    NativeOps op;
    Nd4jPointer nativeStart[2];

#ifdef __CUDABLAS__
    nativeStart[1] = *(x.getContext()->getCudaStream());
#endif

-    op.pullRows(nativeStart, x.buffer(), x.getShapeInfo(), x.getSpecialBuffer(), x.getSpecialShapeInfo(),
+    pullRows(nativeStart, x.buffer(), x.getShapeInfo(), x.getSpecialBuffer(), x.getSpecialShapeInfo(),
             z.buffer(), z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
             4, pidx,
             xTadPack.platformShapeInfo(), xTadPack.platformOffsets(),
@@ -912,12 +911,11 @@ TEST_F(DeclarableOpsTests12, pullRows_2) {
    auto xTadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.getShapeInfo(), dims);
    auto zTadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(z.getShapeInfo(), dims);

-    NativeOps op;
    Nd4jPointer nativeStart[2];
#ifdef __CUDABLAS__
    nativeStart[1] = *(x.getContext()->getCudaStream());
#endif
-    op.pullRows(nativeStart, x.buffer(), x.getShapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
+    pullRows(nativeStart, x.buffer(), x.getShapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
             z.buffer(), z.getShapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
             4, pidx,
             xTadPack.platformShapeInfo(), xTadPack.platformOffsets(),
@@ -110,8 +110,7 @@ TEST_F(DeclarableOpsTests9, exponentialDistributionInv_test1) {
    double extraParams[] = {lambda};

    Nd4jLong *buffer = new Nd4jLong[N];
-    NativeOps nativeOps;
-    auto rng = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
+    auto rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
    if (rng == nullptr)
        throw std::runtime_error("DeclarableOpsTests9.exponentialDistributionInv_test1: RNG initialization failed !");

@@ -122,7 +121,7 @@ TEST_F(DeclarableOpsTests9, exponentialDistributionInv_test1) {
    ASSERT_NEAR(mean, actualMean, 0.01);
    ASSERT_NEAR(std, actualStd, 0.01);

-    nativeOps.destroyRandom((Nd4jPointer) rng);
+    destroyRandom((Nd4jPointer) rng);
    delete[] buffer;

}
@@ -142,8 +141,7 @@ TEST_F(DeclarableOpsTests9, exponentialDistributionInv_test2) {


    Nd4jLong *buffer = new Nd4jLong[N];
-    NativeOps nativeOps;
-    auto rng = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
+    auto rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
    if (rng == nullptr)
        throw std::runtime_error("DeclarableOpsTests9.exponentialDistributionInv_test2: RNG initialization failed !");

@@ -155,7 +153,7 @@ TEST_F(DeclarableOpsTests9, exponentialDistributionInv_test2) {
    ASSERT_NEAR(mean, actualMean, 0.01);
    ASSERT_NEAR(std, actualStd, 0.01);

-    nativeOps.destroyRandom((Nd4jPointer) rng);
+    destroyRandom((Nd4jPointer) rng);
    delete[] buffer;

}
@@ -172,8 +170,7 @@ TEST_F(DeclarableOpsTests9, exponentialDistribution_test1) {
    double extraParams[] = {lambda};

    Nd4jLong *buffer = new Nd4jLong[N];
-    NativeOps nativeOps;
-    auto rng = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
+    auto rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
    if (rng == nullptr)
        throw std::runtime_error("DeclarableOpsTests9.exponentialDistribution_test1: RNG initialization failed !");

@@ -184,7 +181,7 @@ TEST_F(DeclarableOpsTests9, exponentialDistribution_test1) {
    ASSERT_NEAR(mean, actualMean, 0.01);
    ASSERT_NEAR(std, actualStd, 0.01);

-    nativeOps.destroyRandom((Nd4jPointer) rng);
+    destroyRandom((Nd4jPointer) rng);
    delete[] buffer;
}
*/
@@ -206,14 +203,13 @@ TEST_F(DeclarableOpsTests9, exponentialDistribution_test2) {
    Nd4jLong *buffer = new Nd4jLong[N];
    // Nd4jPointer extra[2];
#ifndef __CUDABLAS__
-    NativeOps nativeOps;
-    nd4j::random::RandomBuffer* rng = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
+    nd4j::random::RandomBuffer* rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, 123, N, (Nd4jPointer) buffer);
    if (rng == nullptr)
        throw std::runtime_error("DeclarableOpsTests9.exponentialDistribution_test2: RNG initialization failed !");

    functions::random::RandomFunction<double>::template execTransform<randomOps::ExponentialDistribution<double>>(rng, y.getBuffer(), y.getShapeInfo(), x.getBuffer(), x.getShapeInfo(), extraParams);

-    nativeOps.destroyRandom((Nd4jPointer) rng);
+    destroyRandom((Nd4jPointer) rng);
#endif
    const double actualMean = x.meanNumber().e<double>(0);
    const double actualStd = x.varianceNumber(variance::SummaryStatsStandardDeviation, true).e<double>(0);
@@ -1005,12 +1001,10 @@ TEST_F(DeclarableOpsTests9, Test_DropoutInverted_01) {
    x0.linspace(1);
    x1.linspace(1);
/*
-    NativeOps nativeOps;
-
    float prob[] = {0.5f};
    Nd4jLong* _bufferA = new Nd4jLong[100000];
    long _seed = 119L;
-    auto _rngA = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, _seed, 100000, (Nd4jPointer) _bufferA);
+    auto _rngA = (nd4j::random::RandomBuffer *) initRandom(nullptr, _seed, 100000, (Nd4jPointer) _bufferA);

    x0. applyTransform(random::DropOutInverted, &x0, prob);
//    x1.template applyRandom<randomOps::DropOutInverted<float>>(_rngB, nullptr, &x1, prob);
@@ -1026,7 +1020,7 @@ TEST_F(DeclarableOpsTests9, Test_DropoutInverted_01) {
//    ASSERT_FALSE(x0.equalsTo(nexp0));
//    ASSERT_FALSE(x0.equalsTo(nexp1));
//    ASSERT_FALSE(x0.equalsTo(nexp2));
-    nativeOps.destroyRandom(_rngA);
+    destroyRandom(_rngA);
    delete [] _bufferA;
*/
    nd4j::ops::dropout op;
@@ -51,9 +51,7 @@ public:
*/

TEST_F(GraphStateTests, Basic_Tests_1) {
-    NativeOps nativeOps;
-
-    auto state = (GraphState *) nativeOps.getGraphState(117L);
+    auto state = (GraphState *) getGraphState(117L);
    ASSERT_EQ(117L, state->id());

    // this call will create scope internally
@@ -72,14 +70,12 @@ TEST_F(GraphStateTests, Basic_Tests_1) {
    ASSERT_TRUE(scope != nullptr);
    ASSERT_EQ(2, scope->size());

-    nativeOps.deleteGraphState(state);
+    deleteGraphState(state);
}

// just separate case for doubles wrapper in NativeOps, nothing else
TEST_F(GraphStateTests, Basic_Tests_2) {
-    NativeOps nativeOps;
-
-    auto state = (GraphState *) nativeOps.getGraphState(117L);
+    auto state = (GraphState *) getGraphState(117L);
    ASSERT_EQ(117L, state->id());

    // this call will create scope internally
@@ -98,46 +94,40 @@ TEST_F(GraphStateTests, Basic_Tests_2) {
    ASSERT_TRUE(scope != nullptr);
    ASSERT_EQ(2, scope->size());

-    nativeOps.deleteGraphState(state);
+    deleteGraphState(state);
}

TEST_F(GraphStateTests, Stateful_Execution_1) {
-    NativeOps nativeOps;
-
-    auto state = nativeOps.getGraphState(117L);
+    auto state = getGraphState(117L);

    Nd4jLong scopes[] = {22, 33};
-    //auto status = nativeOps.execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0);
-    auto status = nativeOps.execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0);
+    //auto status = execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0);
+    auto status = execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0);

    ASSERT_EQ(Status::THROW(), status);

-    nativeOps.deleteGraphState(state);
+    deleteGraphState(state);
}

TEST_F(GraphStateTests, Stateful_Execution_2) {
-    NativeOps nativeOps;
-
-    auto state = (GraphState *) nativeOps.getGraphState(117L);
+    auto state = (GraphState *) getGraphState(117L);

    state->registerScope(22);
    state->registerScope(33);

    Nd4jLong scopes[] = {22, 33};
-    auto status = nativeOps.execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0);
+    auto status = execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0);

    // it's no-op: just LogicScope
    ASSERT_EQ(Status::OK(), status);

-    nativeOps.deleteGraphState(state);
+    deleteGraphState(state);
}

/**
 * This test checks WHILE loop
 */
TEST_F(GraphStateTests, Stateful_Execution_3) {
-    NativeOps nativeOps;
-
    auto var0 = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
    auto var1 = NDArrayFactory::create<float>(11.0f);
    auto var2 = NDArrayFactory::create<float>(2.0f);
@@ -147,7 +137,7 @@ TEST_F(GraphStateTests, Stateful_Execution_3) {
    auto res2 = NDArrayFactory::create<float>(0.0f);

    // registering our GraphState holder
-    auto state = (GraphState *) nativeOps.getGraphState(117L);
+    auto state = (GraphState *) getGraphState(117L);

    // we're prepping pointers to input/output buffers
    Nd4jPointer ptrBuffers[] = {(Nd4jPointer) var0.buffer(), (Nd4jPointer) var1.buffer(), (Nd4jPointer)var2.buffer()};
@@ -197,7 +187,7 @@ TEST_F(GraphStateTests, Stateful_Execution_3) {
    Nd4jLong scopes[] = {22, 33};

    // we're executing while loop
-    auto status = nativeOps.execCustomOpWithScope(nullptr, state, 0, scopes, 2, ptrBuffers, ptrShapes, 3, outBuffers, outShapes, 3);
+    auto status = execCustomOpWithScope(nullptr, state, 0, scopes, 2, ptrBuffers, ptrShapes, 3, outBuffers, outShapes, 3);
    ASSERT_EQ(Status::OK(), status);

    // now we check provided result array
@@ -211,7 +201,7 @@ TEST_F(GraphStateTests, Stateful_Execution_3) {

    // nd4j_printf("0 ------------------\n","");

-    nativeOps.deleteGraphState(state);
+    deleteGraphState(state);

    // nd4j_printf("1 ------------------\n","");
}
@@ -220,8 +210,6 @@ TEST_F(GraphStateTests, Stateful_Execution_3) {
 * This test checks CONDITIONAL execution for FALSE
 */
TEST_F(GraphStateTests, Stateful_Execution_4) {
-    NativeOps nativeOps;
-
    auto var0 = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
    auto var1 = NDArrayFactory::create<float>(5.0f);

@@ -232,7 +220,7 @@ TEST_F(GraphStateTests, Stateful_Execution_4) {


    // registering our GraphState holder
-    auto state = (GraphState *) nativeOps.getGraphState(117L);
+    auto state = (GraphState *) getGraphState(117L);

    // we're prepping pointers to input/output buffers
    Nd4jPointer ptrBuffers[] = {(Nd4jPointer) var0.buffer(), (Nd4jPointer) var1.buffer()};
@@ -283,14 +271,14 @@ TEST_F(GraphStateTests, Stateful_Execution_4) {
    Nd4jLong scopes[] = {22, 33, 44};

    // we're executing conditional op
-    auto status = nativeOps.execCustomOpWithScope(nullptr, state, 20, scopes, 3, ptrBuffers, ptrShapes, 2, outBuffers, outShapes, 2);
+    auto status = execCustomOpWithScope(nullptr, state, 20, scopes, 3, ptrBuffers, ptrShapes, 2, outBuffers, outShapes, 2);
    ASSERT_EQ(Status::OK(), status);

    ASSERT_TRUE(exp.isSameShape(&res0));
    ASSERT_TRUE(exp.equalsTo(&res0));


-    nativeOps.deleteGraphState(state);
+    deleteGraphState(state);
}


@@ -298,8 +286,6 @@ TEST_F(GraphStateTests, Stateful_Execution_4) {
 * This test checks CONDITIONAL execution for TRUE
 */
TEST_F(GraphStateTests, Stateful_Execution_5) {
-    NativeOps nativeOps;
-
    auto var0 = NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
    auto var1 = NDArrayFactory::create<float>(5.0f);

@@ -310,7 +296,7 @@ TEST_F(GraphStateTests, Stateful_Execution_5) {


    // registering our GraphState holder
-    auto state = (GraphState *) nativeOps.getGraphState(117L);
+    auto state = (GraphState *) getGraphState(117L);

    // we're prepping pointers to input/output buffers
    Nd4jPointer ptrBuffers[] = {(Nd4jPointer) var0.buffer(), (Nd4jPointer) var1.buffer()};
@@ -361,12 +347,11 @@ TEST_F(GraphStateTests, Stateful_Execution_5) {
    Nd4jLong scopes[] = {22, 33, 44};

    // we're executing conditional op
-    auto status = nativeOps.execCustomOpWithScope(nullptr, state, 20, scopes, 3, ptrBuffers, ptrShapes, 2, outBuffers, outShapes, 2);
+    auto status = execCustomOpWithScope(nullptr, state, 20, scopes, 3, ptrBuffers, ptrShapes, 2, outBuffers, outShapes, 2);
    ASSERT_EQ(Status::OK(), status);

    ASSERT_TRUE(exp.isSameShape(&res0));
    ASSERT_TRUE(exp.equalsTo(&res0));

-    nativeOps.deleteGraphState(state);
+    deleteGraphState(state);
}
@@ -42,7 +42,6 @@ TEST_F(JavaInteropCudaTests, test_DeclarableOp_execution_1) {
    e.assign(2.f);

    nd4j::ops::add op;
-    NativeOps nativeOps;
    Context context(1);

    context.setCudaContext(LaunchContext::defaultContext()->getCudaStream(), LaunchContext::defaultContext()->getReductionPointer(), LaunchContext::defaultContext()->getAllocationPointer());
@@ -53,7 +52,7 @@ TEST_F(JavaInteropCudaTests, test_DeclarableOp_execution_1) {

    nd4j_printf("Starting execution...\n","");
    PointersManager pm(LaunchContext::defaultContext(), "test_DeclarableOp_execution_1");
-    nativeOps.execCustomOp(nullptr, op.getOpHash(), &context);
+    execCustomOp2(nullptr, op.getOpHash(), &context);

    pm.synchronize();

@@ -71,7 +70,6 @@ TEST_F(JavaInteropCudaTests, test_DeclarableOp_execution_2) {
    e.assign(false);

    nd4j::ops::equals op;
-    NativeOps nativeOps;
    Context context(1);

    context.setCudaContext(LaunchContext::defaultContext()->getCudaStream(), LaunchContext::defaultContext()->getReductionPointer(), LaunchContext::defaultContext()->getAllocationPointer());
@@ -82,7 +80,7 @@ TEST_F(JavaInteropCudaTests, test_DeclarableOp_execution_2) {

    nd4j_printf("Starting execution...\n","");
    PointersManager pm(LaunchContext::defaultContext(), "test_DeclarableOp_execution_2");
-    nativeOps.execCustomOp(nullptr, op.getOpHash(), &context);
+    execCustomOp2(nullptr, op.getOpHash(), &context);

    pm.synchronize();

@@ -41,8 +41,6 @@ TEST_F(JavaInteropTests, TestShapeExposure1) {
auto weights = NDArrayFactory::create<float>('c', {2, 2, 2, 3});
auto exp = NDArrayFactory::create<float>('c', {1, 3, 5, 4});

-NativeOps nativeOps;

nd4j::ops::conv2d op;

std::vector<double> tArgs({});

@@ -50,7 +48,7 @@ TEST_F(JavaInteropTests, TestShapeExposure1) {

Nd4jPointer ptrs[] = {(Nd4jPointer) input.getShapeInfo(), (Nd4jPointer) weights.getShapeInfo()};

-auto shapeList = nativeOps.calculateOutputShapes(nullptr, op.getOpHash(), ptrs, 2, tArgs.data(), tArgs.size(), iArgs.data(), iArgs.size());
+auto shapeList = calculateOutputShapes(nullptr, op.getOpHash(), ptrs, 2, tArgs.data(), tArgs.size(), iArgs.data(), iArgs.size());

ASSERT_EQ(1, shapeList->size());

@@ -64,7 +62,7 @@ TEST_F(JavaInteropTests, TestShapeExposure1) {
//delete[] ptr;
//delete shapeList;

-nativeOps.deleteShapeList((Nd4jPointer) shapeList);
+deleteShapeList((Nd4jPointer) shapeList);
}


@@ -72,9 +70,6 @@ TEST_F(JavaInteropTests, TestShapeExposure2) {
auto input = NDArrayFactory::create<float>('c', {1, 2, 5, 4});
auto exp = NDArrayFactory::create<float>('c', {4}, {1, 2, 5, 4});


-NativeOps nativeOps;

nd4j::ops::shape_of op;

std::vector<double> tArgs({});

@@ -83,14 +78,14 @@ TEST_F(JavaInteropTests, TestShapeExposure2) {

Nd4jPointer ptrs[] = {(Nd4jPointer) input.getShapeInfo()};

-auto shapeList = nativeOps.calculateOutputShapes(nullptr, op.getOpHash(), ptrs, 1, tArgs.data(), tArgs.size(), iArgs.data(), iArgs.size());
+auto shapeList = calculateOutputShapes(nullptr, op.getOpHash(), ptrs, 1, tArgs.data(), tArgs.size(), iArgs.data(), iArgs.size());

ASSERT_EQ(1, shapeList->size());

ASSERT_EQ(exp.rankOf(), shape::rank((Nd4jLong *)shapeList->at(0)));
ASSERT_EQ(exp.sizeAt(0), shape::shapeOf((Nd4jLong *)shapeList->at(0))[0]);

-nativeOps.deleteShapeList((Nd4jPointer) shapeList);
+deleteShapeList((Nd4jPointer) shapeList);
}

TEST_F(JavaInteropTests, TestShapeExposure3) {

@@ -112,13 +107,12 @@ TEST_F(JavaInteropTests, TestShapeExposure3) {
Nd4jPointer inputBuffers[] = {x.buffer(), sizes.buffer()};
Nd4jPointer inputShapes[] = {x.shapeInfo(), sizes.shapeInfo()};

-NativeOps nativeOps;
nd4j::ops::split_v op;

Nd4jLong iArgs[] = {1};
auto hash = op.getOpHash();

-auto shapeList = nativeOps.calculateOutputShapes(nullptr, hash, inputBuffers, inputShapes, 2, nullptr, 0, iArgs, 1, nullptr, 0);
+auto shapeList = calculateOutputShapes2(nullptr, hash, inputBuffers, inputShapes, 2, nullptr, 0, iArgs, 1, nullptr, 0);

ASSERT_EQ(3, shapeList->size());

@@ -126,7 +120,7 @@ TEST_F(JavaInteropTests, TestShapeExposure3) {
ASSERT_TRUE(shape::equalsSoft(sub1.shapeInfo(), shapeList->at(1)));
ASSERT_TRUE(shape::equalsSoft(sub2.shapeInfo(), shapeList->at(2)));

-nativeOps.deleteShapeList((Nd4jPointer) shapeList);
+deleteShapeList((Nd4jPointer) shapeList);
}

TEST_F(JavaInteropTests, Test_Squeeze_1) {
@@ -143,10 +137,7 @@ TEST_F(JavaInteropTests, Test_Squeeze_1) {
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()};

-NativeOps nativeOps;
-
-auto status = nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
+auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
ASSERT_EQ(Status::OK(), status);

ASSERT_EQ(e, z);

@@ -167,10 +158,7 @@ TEST_F(JavaInteropTests, Test_RDiv_1) {
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()};

-NativeOps nativeOps;
-
-auto status = nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
+auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
ASSERT_EQ(Status::OK(), status);

ASSERT_EQ(e, z);

@@ -203,11 +191,9 @@ TEST_F(JavaInteropTests, TestSconv2d_1) {
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo()};

-NativeOps nativeOps;

Nd4jLong exp[] = {1, 1, 1, 1, 0, 0, 1, 1, 0, 0};

-nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 4, ptrsOutBuffers, ptrsOutShapes, 1,
+execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 4, ptrsOutBuffers, ptrsOutShapes, 1,
nullptr, 0, exp, 9, nullptr, 0, false);

//output.printBuffer("output");

@@ -238,11 +224,9 @@ TEST_F(JavaInteropTests, TestSconv2d_2) {
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo()};

-NativeOps nativeOps;

Nd4jLong exp[] = {1, 1, 1, 1, 0, 0, 1, 1, 0};

-nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 9, nullptr, 0, false);
+execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 9, nullptr, 0, false);

//output.printBuffer("output");


@@ -266,9 +250,7 @@ TEST_F(JavaInteropTests, TestMaxPooling2d_1) {

nd4j::ops::maxpool2d op;

-NativeOps nativeOps;
-
-Nd4jStatus status = nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs.data(), 9, nullptr, 0, false);
+Nd4jStatus status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs.data(), 9, nullptr, 0, false);
ASSERT_EQ(ND4J_STATUS_OK, status);

}

@@ -294,13 +276,11 @@ TEST_F(JavaInteropTests, TestCol2Im_1) {

nd4j::ops::col2im op;

-NativeOps nativeOps;

Nd4jLong exp[] = {1, 1, 1, 1, 4, 5, 1, 1, 1};

auto hash = op.getOpHash();

-nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 9, nullptr, 0, false);
+execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 9, nullptr, 0, false);

ASSERT_TRUE(output.meanNumber().e<float>(0) > 0.0f);
}

@@ -320,8 +300,6 @@ TEST_F(JavaInteropTests, TestPNorm_1) {
auto output = NDArrayFactory::create<float>('c', {1, 3, 3, 3});
input.linspace(1);

-NativeOps nativeOps;

nd4j::ops::pnormpool2d op;

Nd4jLong exp[] = {2, 2, 1, 1, 0, 0, 1, 1, 0, 2, 0, 0};

@@ -332,7 +310,7 @@ TEST_F(JavaInteropTests, TestPNorm_1) {
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.getBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.getShapeInfo()};

-nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 11, nullptr, 0, false);
+execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 11, nullptr, 0, false);

ASSERT_TRUE(output.meanNumber().e<double>(0) > 0.0);
}

@@ -343,8 +321,6 @@ TEST_F(JavaInteropTests, TestInplace_1) {
//auto exp('c', {10, 10});
input.linspace(1);

-NativeOps nativeOps;

nd4j::ops::clipbyvalue op;

double extras[] = {-1.0f, 1.0f};

@@ -353,7 +329,7 @@ TEST_F(JavaInteropTests, TestInplace_1) {
Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.getShapeInfo()};


-Nd4jStatus result = nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, nullptr, nullptr, 0, extras, 2, nullptr, 0, nullptr, 0, true);
+Nd4jStatus result = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, nullptr, nullptr, 0, extras, 2, nullptr, 0, nullptr, 0, true);

ASSERT_EQ(ND4J_STATUS_OK, result);

@@ -415,7 +391,6 @@ TEST_F(JavaInteropTests, test_avgpooling_edge_1) {
x.linspace(1.0);
z.linspace(1.0);

-NativeOps nativeOps;
nd4j::ops::avgpool2d op;
//auto result = op.execute({&x}, {}, {3,3, 1,1, 0,0, 1,1, 1, 0, 1});

@@ -427,7 +402,7 @@ TEST_F(JavaInteropTests, test_avgpooling_edge_1) {
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()};

-auto result = nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 11, nullptr, 0, false);
+auto result = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 11, nullptr, 0, false);

ASSERT_EQ(Status::OK(), result);

@@ -496,15 +471,13 @@ TEST_F(JavaInteropTests, test_avgpooling_edge_1) {

/*
TEST_F(JavaInteropTests, Test_GraphReuse_1) {
-NativeOps nativeOps;

uint8_t* data = nd4j::graph::readFlatBuffers("./resources/reduce_dim_false.fb");

-nativeOps.registerGraph(nullptr, 119, (Nd4jPointer) data);
+registerGraph(nullptr, 119, (Nd4jPointer) data);

ASSERT_TRUE(GraphHolder::getInstance()->hasGraph(119));

-nativeOps.unregisterGraph(nullptr, 119);
+unregisterGraph(nullptr, 119);

ASSERT_FALSE(GraphHolder::getInstance()->hasGraph(119));

@@ -520,8 +493,6 @@ TEST_F(JavaInteropTests, Test_GraphReuse_2) {
auto exp1 = NDArrayFactory::create<float>('c', {3}, {6, 6, 6});
auto exp2 = NDArrayFactory::create<float>('c', {3}, {9, 9, 9});

-NativeOps nativeOps;

// we load graph from file, because we're not in java here, and dont have buffer ready
uint8_t* data = nd4j::graph::readFlatBuffers("./resources/reduce_dim_false.fb");

@@ -529,7 +500,7 @@ TEST_F(JavaInteropTests, Test_GraphReuse_2) {
ASSERT_FALSE(GraphHolder::getInstance()->hasGraph(119));

// register the graph, to call for it later
-nativeOps.registerGraph(nullptr, 119, (Nd4jPointer) data);
+registerGraph(nullptr, 119, (Nd4jPointer) data);

// and ensure we're ok
ASSERT_TRUE(GraphHolder::getInstance()->hasGraph(119));

@@ -547,7 +518,7 @@ TEST_F(JavaInteropTests, Test_GraphReuse_2) {
Nd4jPointer shapes_0[] = {(Nd4jPointer) input_0.shapeInfo()};

// now we're executing stored graph and providing replacement for input variable
-auto res_0 = nativeOps.executeStoredGraph(nullptr, 119, inputs_0, shapes_0, idx, 1);
+auto res_0 = executeStoredGraph(nullptr, 119, inputs_0, shapes_0, idx, 1);
ASSERT_EQ(ND4J_STATUS_OK, res_0->status());
ASSERT_EQ(1, res_0->size());

@@ -562,7 +533,7 @@ TEST_F(JavaInteropTests, Test_GraphReuse_2) {
Nd4jPointer shapes_1[] = {(Nd4jPointer) input_1.shapeInfo()};

// doing it again
-auto res_1 = nativeOps.executeStoredGraph(nullptr, 119, inputs_1, shapes_1, idx, 1);
+auto res_1 = executeStoredGraph(nullptr, 119, inputs_1, shapes_1, idx, 1);
ASSERT_EQ(ND4J_STATUS_OK, res_1->status());
ASSERT_EQ(1, res_1->size());

@@ -577,7 +548,7 @@ TEST_F(JavaInteropTests, Test_GraphReuse_2) {
Nd4jPointer shapes_2[] = {(Nd4jPointer) input_2.shapeInfo()};

// and again
-auto res_2 = nativeOps.executeStoredGraph(nullptr, 119, inputs_2, shapes_2, idx, 1);
+auto res_2 = executeStoredGraph(nullptr, 119, inputs_2, shapes_2, idx, 1);
ASSERT_EQ(ND4J_STATUS_OK, res_1->status());
ASSERT_EQ(1, res_2->size());

@@ -586,7 +557,7 @@ TEST_F(JavaInteropTests, Test_GraphReuse_2) {


//////// clean out
-nativeOps.unregisterGraph(nullptr, 119);
+unregisterGraph(nullptr, 119);

ASSERT_FALSE(GraphHolder::getInstance()->hasGraph(119));

@@ -616,9 +587,7 @@ TEST_F(JavaInteropTests, Test_Greater_1) {
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.getBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.getShapeInfo()};

-NativeOps nativeOps;
-
-nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
+execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
o.printIndexedBuffer("Greater JIT");
ASSERT_TRUE(exp.equalsTo(&o));
}

@@ -641,9 +610,7 @@ TEST_F(JavaInteropTests, Test_Greater_2) {
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.getBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.getShapeInfo()};

-NativeOps nativeOps;
-
-nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
+execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);

ASSERT_TRUE(exp.equalsTo(&o));
}

@@ -662,9 +629,8 @@ TEST_F(JavaInteropTests, Test_Boolean_Op_1) {
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.getBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.getShapeInfo()};

-NativeOps nativeOps;
auto hash = op.getOpHash();
-auto status = nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
+auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
ASSERT_EQ(Status::OK(), status);

ASSERT_TRUE(exp.equalsTo(&o));

@@ -685,9 +651,8 @@ TEST_F(JavaInteropTests, Test_Inplace_Outputs_1) {
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()};

-NativeOps nativeOps;
auto hash = op.getOpHash();
-auto status = nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
+auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
ASSERT_EQ(Status::OK(), status);

ASSERT_TRUE(exp.isSameShape(z));

@@ -710,9 +675,8 @@ TEST_F(JavaInteropTests, Test_Inplace_Outputs_2) {
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()};

-NativeOps nativeOps;
auto hash = op.getOpHash();
-auto status = nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
+auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
ASSERT_EQ(Status::OK(), status);

ASSERT_TRUE(e.isSameShape(z));

@@ -736,9 +700,8 @@ TEST_F(JavaInteropTests, Test_Inplace_Outputs_3) {

Nd4jLong iArgs[] = {1};

-NativeOps nativeOps;
auto hash = op.getOpHash();
-auto status = nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 1, nullptr, 0, false);
+auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 1, nullptr, 0, false);
ASSERT_EQ(Status::OK(), status);

ASSERT_TRUE(e.isSameShape(output));

@@ -753,8 +716,7 @@ TEST_F(JavaInteropTests, Test_Reduce3_EdgeCase) {

auto dims = NDArrayFactory::create<int>('c', {2}, {0, 1});

-NativeOps nativeOps;
-nativeOps.execReduce3(nullptr, 2, x.buffer(), x.shapeInfo(), nullptr, nullptr, nullptr,
+execReduce3Tad(nullptr, 2, x.buffer(), x.shapeInfo(), nullptr, nullptr, nullptr,
y.buffer(), y.shapeInfo(), nullptr, nullptr,
z.buffer(), z.shapeInfo(), nullptr, nullptr,
dims.buffer(), dims.shapeInfo(), dims.specialBuffer(), dims.specialShapeInfo(), nullptr, nullptr, nullptr, nullptr);

@@ -764,10 +726,8 @@ TEST_F(JavaInteropTests, Test_SimpleIf_Output) {
Environment::getInstance()->setDebug(true);
Environment::getInstance()->setVerbose(false);

-NativeOps ops;

auto pl = nd4j::graph::readFlatBuffers("./resources/simpleif_0_1.fb");
-auto ptr = ops.executeFlatGraph(nullptr, pl);
+auto ptr = executeFlatGraph(nullptr, pl);

Environment::getInstance()->setDebug(false);
Environment::getInstance()->setVerbose(false);

@@ -792,9 +752,8 @@ TEST_F(JavaInteropTests, Test_AveragePooling_FF_TF_double) {
Nd4jPointer ptrsOutShapes[] = {reinterpret_cast<Nd4jPointer>(z.shapeInfo())};
Nd4jLong iArgs[] = {3,3, 3,3, 0,0, 1,1,1, 0,1};

-NativeOps nativeOps;
auto hash = op.getOpHash();
-auto status = nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false);
+auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false);
ASSERT_EQ(Status::OK(), status);


@@ -818,9 +777,8 @@ TEST_F(JavaInteropTests, Test_MaxPool2D_float_1) {

nd4j::ops::maxpool2d op;

-NativeOps nativeOps;
auto hash = op.getOpHash();
-auto status = nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false);
+auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false);
ASSERT_EQ(Status::OK(), status);
}

@@ -843,9 +801,8 @@ TEST_F(JavaInteropTests, Test_Unstack_1) {

nd4j::ops::unstack op;

-NativeOps nativeOps;
auto hash = op.getOpHash();
-auto status = nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 5, nullptr, 0, iArgs, 1, nullptr, 0, false);
+auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 5, nullptr, 0, iArgs, 1, nullptr, 0, false);
ASSERT_EQ(Status::OK(), status);
}

@@ -864,9 +821,8 @@ TEST_F(JavaInteropTests, Test_AveragePooling_FF_TF_float) {
Nd4jPointer ptrsOutShapes[] = {reinterpret_cast<Nd4jPointer>(z.shapeInfo())};
Nd4jLong iArgs[] = {3,3, 3,3, 0,0, 1,1,1, 0,1};

-NativeOps nativeOps;
auto hash = op.getOpHash();
-auto status = nativeOps.execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false);
+auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false);
ASSERT_EQ(Status::OK(), status);


@@ -883,8 +839,7 @@ TEST_F(JavaInteropTests, Test_Mixed_Add_1) {
auto arrayZ = NDArrayFactory::create<double>({0, 0, 0, 0});
auto arrayE = NDArrayFactory::create<double>({2, 4, 6, 8});

-NativeOps ops;
-ops.execPairwiseTransform(nullptr, pairwise::Add,
+execPairwiseTransform(nullptr, pairwise::Add,
arrayX.buffer(), arrayX.shapeInfo(), nullptr, nullptr,
arrayY.buffer(), arrayY.shapeInfo(), nullptr, nullptr,
arrayZ.buffer(), arrayZ.shapeInfo(), nullptr, nullptr,

@@ -898,7 +853,6 @@ TEST_F(JavaInteropTests, Test_Add_1) {
auto y = NDArrayFactory::create<int>('c', {5}, {1, 1, 1, 1, 1});
auto e = NDArrayFactory::create<int>('c', {5}, {2, 2, 2, 2, 2});

-NativeOps nativeOps;
nd4j::ops::add op;

Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), y.getBuffer()};

@@ -907,7 +861,7 @@ TEST_F(JavaInteropTests, Test_Add_1) {
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) x.getBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) x.getShapeInfo()};

-nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
+execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);

ASSERT_EQ(e, x);
}

@@ -920,7 +874,6 @@ TEST_F(JavaInteropTests, zeta_test10) {

auto e = NDArrayFactory::create<double>('c', {3, 4}, {23.014574, 12.184081, 8.275731, 6.1532226, 4.776538, 3.7945523, 3.0541048, 2.4765317, 2.0163891, 205.27448, 21.090889, 19.477398});

-NativeOps nativeOps;
nd4j::ops::zeta op;

Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.getBuffer(), q.getBuffer()};

@@ -929,7 +882,7 @@ TEST_F(JavaInteropTests, zeta_test10) {
Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.getBuffer()};
Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.getShapeInfo()};

-nativeOps.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
+execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);

ASSERT_EQ(e, z);
}

@@ -939,8 +892,7 @@ TEST_F(JavaInteropTests, Test_Is_Max_1) {
auto arrayZ = NDArrayFactory::create<bool>({0, 0, 0, 0});
auto arrayE = NDArrayFactory::create<bool>({0, 1, 0, 0});

-NativeOps ops;
-ops.execTransformAny(nullptr, transform::IsMax,
+execTransformAny(nullptr, transform::IsMax,
arrayX.buffer(), arrayX.shapeInfo(), nullptr, nullptr,
arrayZ.buffer(), arrayZ.shapeInfo(), nullptr, nullptr,
nullptr);

@@ -953,8 +905,7 @@ TEST_F(JavaInteropTests, Test_Is_Max_1_2) {
auto arrayZ = NDArrayFactory::create<float>({0, 0, 0, 0});
auto arrayE = NDArrayFactory::create<float>({0, 1, 0, 0});

-NativeOps ops;
-ops.execTransformAny(nullptr, transform::IsMax,
+execTransformAny(nullptr, transform::IsMax,
arrayX.buffer(), arrayX.shapeInfo(), nullptr, nullptr,
arrayZ.buffer(), arrayZ.shapeInfo(), nullptr, nullptr,
nullptr);

@@ -970,8 +921,7 @@ TEST_F(JavaInteropTests, Test_Is_Max_2) {
Nd4jLong *ex[] = {tad, off};
float ea[] = {2, 1, 2};

-NativeOps ops;
-ops.execTransformBool(reinterpret_cast<void **>(ex), transform::IsMax,
+execTransformBool(reinterpret_cast<void **>(ex), transform::IsMax,
arrayX.buffer(), arrayX.shapeInfo(), nullptr, nullptr,
arrayZ.buffer(), arrayZ.shapeInfo(), nullptr, nullptr,
ea);

@@ -995,8 +945,7 @@ TEST_F(JavaInteropTests, Test_Boolean_Broadcastables_1) {


nd4j::ops::greater_equal op;
-NativeOps ops;
-auto shapeList = ops.calculateOutputShapes(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, nullptr, 0, nullptr, 0, nullptr, 0);
+auto shapeList = calculateOutputShapes2(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, nullptr, 0, nullptr, 0, nullptr, 0);

delete shapeList;
}

@@ -1013,8 +962,7 @@ TEST_F(JavaInteropTests, Test_L2_Loss_3) {
Nd4jPointer ptrsOutShapes[] = {reinterpret_cast<Nd4jPointer>(z.shapeInfo())};

nd4j::ops::l2_loss op;
-NativeOps ops;
-auto status = ops.execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffer, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
+auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffer, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false);
ASSERT_EQ(Status::OK(), status);

z.printIndexedBuffer("z");

@@ -1036,9 +984,8 @@ TEST_F(JavaInteropTests, Test_Fastpath_3) {

ASSERT_EQ(2, ctx.width());

-NativeOps nativeOps;
nd4j::ops::add op;
-nativeOps.execCustomOp(nullptr, op.getOpHash(), &ctx);
+execCustomOp2(nullptr, op.getOpHash(), &ctx);

ASSERT_EQ(exp, z);
}

@@ -1054,9 +1001,8 @@ TEST_F(JavaInteropTests, Test_Fastpath_4) {
ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());
ctx.setIArguments(iArgs, 3);

-NativeOps nativeOps;
nd4j::ops::tri op;
-nativeOps.execCustomOp(nullptr, op.getOpHash(), &ctx);
+execCustomOp2(nullptr, op.getOpHash(), &ctx);

ASSERT_EQ(exp, z);
}

@@ -1074,9 +1020,8 @@ TEST_F(JavaInteropTests, Test_Fastpath_5) {
ctx.setInputArray(1, b.buffer(), b.shapeInfo(), b.specialBuffer(), b.specialShapeInfo());
ctx.setOutputArray(0, c.buffer(), c.shapeInfo(), c.specialBuffer(), c.specialShapeInfo());

-NativeOps nativeOps;
nd4j::ops::matmul op;
-auto status = nativeOps.execCustomOp(nullptr, op.getOpHash(), &ctx);
+auto status = execCustomOp2(nullptr, op.getOpHash(), &ctx);

ASSERT_EQ(Status::OK(), status);
}

@@ -1104,9 +1049,8 @@ TEST_F(JavaInteropTests, Test_Fastpath_6) {

ctx.setIArguments(iArgs, 3);

-NativeOps nativeOps;
nd4j::ops::matmul_bp op;
-auto status = nativeOps.execCustomOp(nullptr, op.getOpHash(), &ctx);
+auto status = execCustomOp2(nullptr, op.getOpHash(), &ctx);

ASSERT_EQ(Status::OK(), status);
}

@@ -1122,7 +1066,6 @@ TEST_F(JavaInteropTests, Test_Fastpath_7) {

ctx.setIArguments(iArgs, 1);

-NativeOps nativeOps;
nd4j::ops::concat op;

ctx.setInputArray(0, a.buffer(), a.shapeInfo(), a.specialBuffer(), a.specialShapeInfo());

@@ -1130,7 +1073,7 @@ TEST_F(JavaInteropTests, Test_Fastpath_7) {

ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());

-auto status = nativeOps.execCustomOp(nullptr, op.getOpHash(), &ctx);
+auto status = execCustomOp2(nullptr, op.getOpHash(), &ctx);
ASSERT_EQ(Status::OK(), status);

ASSERT_EQ(e, z);

@@ -1138,10 +1081,8 @@ TEST_F(JavaInteropTests, Test_Fastpath_7) {

/*
TEST_F(JavaInteropTests, Test_Results_Conversion_1) {
-NativeOps ops;

auto pl = nd4j::graph::readFlatBuffers("./resources/gru_dynamic_mnist.fb");
-auto ptr = ops.executeFlatGraph(nullptr, pl);
+auto ptr = executeFlatGraph(nullptr, pl);

// at this point we have FlatResults
auto flatResult = GetFlatResult(ptr->pointer());

@@ -1190,8 +1131,6 @@ TEST_F(JavaInteropTests, Test_Results_Conversion_1) {
}
*/
// TEST_F(JavaInteropTests, Test_NLP_Aggregations_1) {
-// NativeOps ops;

// std::array<float, 60> syn0 = {-0.022756476f, 0.0126427775f, 0.011029151f, -0.013542821f, -0.012327666f, -0.0032439455f, -0.008405109f, -0.016651405f, 0.0015980572f, -0.007442479f, 0.019937921f, -0.016222188f, -0.016541665f, 0.013372547f, 0.006625724f, 0.0058958204f, -0.01281835f, -6.2343775E-4f, 0.0019826533f, 0.010253737f, -0.010291531f, 0.0019767822f, 0.018071089f, -0.0117441565f, 0.023176769f, 0.0032820583f, 0.0061427564f, -0.01696018f, 0.0054971874f, 0.0043818625f, 0.019323621f, 0.0036080598f, 0.024376748f, -0.0024499625f, 0.019496754f, 0.010563821f, -2.0503551E-4f, -0.0146056535f, 0.009949291f, 0.017604528f, -0.0050302492f, -0.022060446f, 0.016468976f, -0.0034482107f, 0.010270384f, -0.0063356445f, -0.019934833f, -0.02325993f, 0.016109904f, -0.0031106502f, -0.0020592287f, 0.024031803f, 0.005184144f, -0.024887865f, 0.02100272f, 3.395051E-4f, 0.018432347f, 5.673498E-4f, -0.020073576f, 0.010949242f};
// std::array<float, 60> syn1;
// std::array<float, 100000> exp;

@@ -1283,5 +1222,5 @@ TEST_F(JavaInteropTests, Test_Results_Conversion_1) {
// ptrptr[idx+2] = reinterpret_cast<void*>(exp.data());


-// ops.execAggregateBatchFloat(nullptr, numAggregates, opNum, maxArgs, maxShapes, maxIntArrays, maxIntArraySize, maxIndexArguments, maxRealArguments, pointer.data());
+// execAggregateBatchFloat(nullptr, numAggregates, opNum, maxArgs, maxShapes, maxIntArrays, maxIntArraySize, maxIndexArguments, maxRealArguments, pointer.data());
// }
@@ -53,8 +53,7 @@ TEST_F(LegacyOpsCudaTests, test_sortTad_1) {
Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};

x.syncToDevice();
-NativeOps nativeOps;
-nativeOps.sortTad(extras, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), &axis, 1, packX.platformShapeInfo(), packX.platformOffsets(), false);
+sortTad(extras, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), &axis, 1, packX.platformShapeInfo(), packX.platformOffsets(), false);
x.tickWriteDevice();

ASSERT_EQ(e, x);

@@ -501,8 +501,7 @@ TEST_F(LegacyOpsTests, Reduce3_2) {

auto dim = NDArrayFactory::create<int>('c', {1}, {1});

-NativeOps nativeOps;
-nativeOps.execReduce3(nullptr, reduce3::CosineSimilarity, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(),
+execReduce3Tad(nullptr, reduce3::CosineSimilarity, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(),
nullptr, nullptr, nullptr, nullptr);
}

@@ -517,9 +516,8 @@ TEST_F(LegacyOpsTests, Reduce3_3) {

auto dim = NDArrayFactory::create<int>('c', {1}, {1});

-NativeOps nativeOps;

-nativeOps.execReduce3(nullptr, reduce3::CosineDistance,
+execReduce3Tad(nullptr, reduce3::CosineDistance,
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
nullptr,
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),

@@ -543,9 +541,8 @@ TEST_F(LegacyOpsTests, Reduce3_4) {

auto dim = NDArrayFactory::create<int>('c', {1}, {1});

-NativeOps nativeOps;

-nativeOps.execReduce3(nullptr, reduce3::CosineDistance,
+execReduce3Tad(nullptr, reduce3::CosineDistance,
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
nullptr,
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),

@@ -569,9 +566,8 @@ TEST_F(LegacyOpsTests, Reduce3_5) {

auto dim = NDArrayFactory::create<int>('c', {1}, {1});

-NativeOps nativeOps;

-nativeOps.execReduce3(nullptr, reduce3::CosineDistance,
+execReduce3Tad(nullptr, reduce3::CosineDistance,
x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
nullptr,
y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),

@@ -593,8 +589,7 @@ TEST_F(LegacyOpsTests, test_Reduce3_All_1) {
auto tadPackX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(x.shapeInfo(), -1);
auto tadPackY = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(y.shapeInfo(), -1);

-NativeOps ops;
-ops.execReduce3All(nullptr, reduce3::EuclideanDistance, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
+execReduce3All(nullptr, reduce3::EuclideanDistance, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(),
nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(),
z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(),

@@ -33,8 +33,6 @@ public:
};

TEST_F(MmapTests, Test_Basic_Mmap_1) {
-NativeOps nativeOps;

// just 10GB
Nd4jLong size = 100000L;

@@ -43,11 +41,11 @@ TEST_F(MmapTests, Test_Basic_Mmap_1) {
ofs.write("", 1);
ofs.close();

-auto result = nativeOps.mmapFile(nullptr, "file", size);
+auto result = mmapFile(nullptr, "file", size);

ASSERT_FALSE(result == nullptr);

-nativeOps.munmapFile(nullptr, result, size);
+munmapFile(nullptr, result, size);

remove("file");
}

@ -2258,7 +2258,6 @@ TEST_F(NDArrayCudaBasicsTests, Test_Empty_4) {
|
||||||
|
|
||||||
TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_1) {
|
TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_1) {
|
||||||
auto x = NDArrayFactory::create<float>('c', {5,2}, {0,1,2,3,4,5,6,7,8,9});
|
auto x = NDArrayFactory::create<float>('c', {5,2}, {0,1,2,3,4,5,6,7,8,9});
|
||||||
NativeOps native;
|
|
||||||
auto z = NDArrayFactory::create<float>('c', {5, 8});
|
auto z = NDArrayFactory::create<float>('c', {5, 8});
|
||||||
auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
|
auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
|
||||||
std::vector<void*> buffers(4);
|
std::vector<void*> buffers(4);
|
||||||
|
@ -2272,7 +2271,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_1) {
|
||||||
}
|
}
|
||||||
Nd4jPointer extra[2];
|
Nd4jPointer extra[2];
|
||||||
extra[1] = *stream;
|
extra[1] = *stream;
|
||||||
native.concat(extra, 1, 4, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
::concat(extra, 1, 4, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
||||||
z.syncToHost();
|
z.syncToHost();
|
||||||
z.printIndexedBuffer("Concat result");
|
z.printIndexedBuffer("Concat result");
|
||||||
z.printBuffer("C Concat result linear");
|
z.printBuffer("C Concat result linear");
|
||||||
|
@ -2281,7 +2280,6 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_1) {
|
||||||
|
|
||||||
TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_2) {
|
TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_2) {
|
||||||
auto x = NDArrayFactory::create<float>('c', {5,2}, {0,1,2,3,4,5,6,7,8,9});
|
auto x = NDArrayFactory::create<float>('c', {5,2}, {0,1,2,3,4,5,6,7,8,9});
|
||||||
NativeOps native;
|
|
||||||
auto z = NDArrayFactory::create<float>('f', {5, 8});
|
auto z = NDArrayFactory::create<float>('f', {5, 8});
|
||||||
auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
|
auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
|
||||||
std::vector<void*> buffers(4);
|
std::vector<void*> buffers(4);
|
||||||
|
@ -2295,7 +2293,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_2) {
|
||||||
}
|
}
|
||||||
Nd4jPointer extra[2];
|
Nd4jPointer extra[2];
|
||||||
extra[1] = *stream;
|
extra[1] = *stream;
|
||||||
native.concat(extra, 1, 4, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
::concat(extra, 1, 4, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
||||||
z.syncToHost();
|
z.syncToHost();
|
||||||
z.printIndexedBuffer("Concat result");
|
z.printIndexedBuffer("Concat result");
|
||||||
z.printBuffer("F Concat result linear");
|
z.printBuffer("F Concat result linear");
|
||||||
|
@ -2304,7 +2302,6 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_2) {
|
||||||
TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_3) {
|
TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_3) {
|
||||||
auto x = NDArrayFactory::create<float>('c', {2,3}, {1,2,3,4,5,6});
|
auto x = NDArrayFactory::create<float>('c', {2,3}, {1,2,3,4,5,6});
|
||||||
auto y = NDArrayFactory::create<float>('c', {1,3}, {7,8,9});
|
auto y = NDArrayFactory::create<float>('c', {1,3}, {7,8,9});
|
||||||
NativeOps native;
|
|
||||||
auto z = NDArrayFactory::create<float>('f', {3, 3});
|
auto z = NDArrayFactory::create<float>('f', {3, 3});
|
||||||
auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
|
auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
|
||||||
std::vector<void*> buffers(2);
|
std::vector<void*> buffers(2);
|
||||||
|
@ -2321,7 +2318,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_3) {
|
||||||
//}
|
//}
|
||||||
Nd4jPointer extra[2];
|
Nd4jPointer extra[2];
|
||||||
extra[1] = *stream;
|
extra[1] = *stream;
|
||||||
native.concat(extra, 0, 2, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
::concat(extra, 0, 2, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
||||||
z.syncToHost();
|
z.syncToHost();
|
||||||
z.printIndexedBuffer("Concat result");
|
z.printIndexedBuffer("Concat result");
|
||||||
z.printBuffer("F Concat result linear");
|
z.printBuffer("F Concat result linear");
|
||||||
|
@ -2331,7 +2328,6 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_3) {
|
||||||
TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_4) {
|
TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_4) {
|
||||||
auto x = NDArrayFactory::create<float>('c', {2,3}, {1,2,3,4,5,6});
|
auto x = NDArrayFactory::create<float>('c', {2,3}, {1,2,3,4,5,6});
|
||||||
auto y = NDArrayFactory::create<float>('c', {1,3}, {7,8,9});
|
auto y = NDArrayFactory::create<float>('c', {1,3}, {7,8,9});
|
||||||
NativeOps native;
|
|
||||||
auto z = NDArrayFactory::create<float>('c', {3, 3});
|
auto z = NDArrayFactory::create<float>('c', {3, 3});
|
||||||
auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
|
auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
|
||||||
std::vector<void*> buffers(2);
|
std::vector<void*> buffers(2);
|
||||||
|
@ -2348,7 +2344,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_4) {
|
||||||
//}
|
//}
|
||||||
Nd4jPointer extra[2];
|
Nd4jPointer extra[2];
|
||||||
extra[1] = *stream;
|
extra[1] = *stream;
|
||||||
native.concat(extra, 0, 2, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
::concat(extra, 0, 2, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
|
||||||
z.syncToHost();
|
z.syncToHost();
|
||||||
z.printIndexedBuffer("Concat result");
|
z.printIndexedBuffer("Concat result");
|
||||||
z.printBuffer("C Concat result linear");
|
z.printBuffer("C Concat result linear");
|
||||||
|
@ -2358,7 +2354,6 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_4) {
|
||||||
TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_5) {
|
TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_5) {
|
||||||
auto x = NDArrayFactory::create<float>('c', {1,2,3}, {1,2,3,4,5,6});
|
auto x = NDArrayFactory::create<float>('c', {1,2,3}, {1,2,3,4,5,6});
|
||||||
auto y = NDArrayFactory::create<float>('c', {1,2,3}, {7,8,9,10,11, 12});
|
auto y = NDArrayFactory::create<float>('c', {1,2,3}, {7,8,9,10,11, 12});
|
||||||
NativeOps native;
|
|
||||||
auto z = NDArrayFactory::create<float>('c', {2, 2, 3});
|
auto z = NDArrayFactory::create<float>('c', {2, 2, 3});
|
||||||
auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
|
auto stream = x.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
|
||||||
std::vector<void*> buffers(2);
|
std::vector<void*> buffers(2);
|
||||||
@@ -2375,7 +2370,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_5) {
 //}
 Nd4jPointer extra[2];
 extra[1] = *stream;
-native.concat(extra, 0, 2, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
+::concat(extra, 0, 2, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
 z.syncToHost();
 z.printIndexedBuffer("Concat result");
 z.printBuffer("C Concat result linear");
@@ -2385,7 +2380,6 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_6) {
 auto x1 = NDArrayFactory::create<float>('c', {2,2,3}, {1,2,3,4,5,6,7,8, 9, 10,11,12});
 auto x2 = NDArrayFactory::create<float>('c', {1,2,3}, {13,14,15,16,17, 18});
 auto x3 = NDArrayFactory::create<float>('c', {1,2,3}, {19,20,21,22,23, 24});
-NativeOps native;
 auto z = NDArrayFactory::create<float>('c', {4, 2, 3});
 auto stream = x1.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
 std::vector<void*> buffers(3);
@@ -2406,7 +2400,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_6) {
 printf("The third array is %p\n", buffers[2]);
 Nd4jPointer extra[2];
 extra[1] = *stream;
-native.concat(extra, 0, 3, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
+::concat(extra, 0, 3, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
 z.syncToHost();
 z.printIndexedBuffer("Concat result");
 z.printBuffer("C Concat3D result linear");
@@ -2417,7 +2411,6 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_7) {
 auto x1 = NDArrayFactory::create<float>(1);
 auto x2 = NDArrayFactory::create<float>(2);
 auto x3 = NDArrayFactory::create<float>(3);
-NativeOps native;
 auto z = NDArrayFactory::create<float>('c', {3}, {1,2,3});
 auto stream = x1.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
 std::vector<void*> buffers(3);
@@ -2438,7 +2431,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_7) {
 printf("The third array is %p\n", buffers[2]);
 Nd4jPointer extra[2];
 extra[1] = *stream;
-native.concat(extra, 0, 3, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
+::concat(extra, 0, 3, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
 z.syncToHost();
 z.printIndexedBuffer("Concat result");
 z.printBuffer("C Concat scalar result linear");
@@ -2462,7 +2455,6 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_8) {
 lx[i].assign(i);
 }

-NativeOps native;
 auto z = NDArrayFactory::create<float>('c', {totalCount, width});
 auto stream = nd4j::LaunchContext ::defaultContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
 std::vector<void*> buffers(totalCount);
@@ -2478,7 +2470,7 @@ TEST_F(NDArrayCudaBasicsTests, Test_ConcatNative_8) {
 printf("The third array is %p\n", buffers[2]);
 Nd4jPointer extra[2];
 extra[1] = *stream;
-native.concat(extra, 0, totalCount, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
+::concat(extra, 0, totalCount, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
 z.syncToHost();
 nd4j_printf("%f %f %f\n", z.e<float>(0), z.e<float>(width * totalCount / 2), z.e<float>(width * (totalCount - 1)));
 //z.printIndexedBuffer("Concat result");
@@ -2496,7 +2488,6 @@ TEST_F(NDArrayCudaBasicsTests, TestTear_1) {
 arrays.emplace_back(input);
 }
 auto z = NDArrayFactory::create<float>('c', {total, 10, 10});
-NativeOps native;

 auto stream = input.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
 Nd4jPointer extra[2];
@@ -2512,7 +2503,7 @@ TEST_F(NDArrayCudaBasicsTests, TestTear_1) {
 hostShapes[i] = arrays[i].shapeInfo();
 }

-native.concat(extra, 0, total, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
+::concat(extra, 0, total, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
 nd4j::ops::tear op;

 auto result = op.execute({&z}, {}, {1, 2});
@@ -2536,7 +2527,6 @@ TEST_F(NDArrayCudaBasicsTests, TestTear_2) {
 arrays.emplace_back(input);
 }
 auto z = NDArrayFactory::create<float>('c', {10, 10, 10});
-NativeOps native;

 auto stream = input.getContext()->getCudaStream();//reinterpret_cast<cudaStream_t *>(&nativeStream);
 Nd4jPointer extra[2];
@@ -2552,7 +2542,7 @@ TEST_F(NDArrayCudaBasicsTests, TestTear_2) {
 hostShapes[i] = arrays[i].shapeInfo();
 }
 std::vector<int> dimsToExclude({1,2});
-native.concat(extra, 0, 10, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
+::concat(extra, 0, 10, nullptr, (Nd4jPointer*)hostShapes.data(), (Nd4jPointer*)buffers.data(), (Nd4jPointer*)shapes.data(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr, nullptr);
 // z.syncToHost();
 // z.printBuffer("Pile OK");
 // z.printIndexedBuffer("Pile 10x10");
@@ -2569,7 +2559,7 @@ TEST_F(NDArrayCudaBasicsTests, TestTear_2) {
 Nd4jPointer target = arrays[i].specialBuffer();
 cudaMemcpy(&arraysData[i], &target, sizeof(Nd4jPointer), cudaMemcpyHostToDevice);
 }
-native.tear(extra, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), arraysData, input.specialShapeInfo(), packX.specialShapeInfo(), packX.specialOffsets());
+::tear(extra, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), arraysData, input.specialShapeInfo(), packX.specialShapeInfo(), packX.specialOffsets());
 // auto result = op.execute({&z}, {}, {1, 2});
 // nd4j_printf("Result count is %lu\n", result->size());
 //ASSERT_EQ(10, result->size());

@@ -313,12 +313,10 @@ TEST_F(PlaygroundTests, test_reduce_3) {
 Nd4jLong max = 0L;
 Nd4jLong min = DataTypeUtils::max<Nd4jLong>();

-NativeOps nativeOps;
-
 for (int e = 0; e < iterations; e++) {
 auto timeStart = std::chrono::system_clock::now();

-nativeOps.execReduce3(nullptr, reduce3::CosineDistance, x.buffer(), x.shapeInfo(), x.specialBuffer(),
+execReduce3Tad(nullptr, reduce3::CosineDistance, x.buffer(), x.shapeInfo(), x.specialBuffer(),
 x.specialShapeInfo(), nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(),
 y.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(),
 dim.buffer(), dim.shapeInfo(), dim.specialBuffer(), dim.specialShapeInfo(), nullptr,
@@ -964,8 +962,6 @@ TEST_F(PlaygroundTests, Test_Im2Col_1) {
 auto legacyPermTime = std::chrono::duration_cast<std::chrono::microseconds> (legacyPermEnd - legacyPermStart).count();


-NativeOps nativeOps;
-
 Nd4jLong iArgs[] = {kH, kW, sH, sW, pH, pW, dH, dW, 0};
 Nd4jPointer inputBuffers[] = {input.buffer()};
 Nd4jPointer inputShapes[] = {input.shapeInfo()};
@@ -976,7 +972,7 @@ TEST_F(PlaygroundTests, Test_Im2Col_1) {
 auto javaStart = std::chrono::system_clock::now();

 for (int e = 0; e < iterations; e++) {
-nativeOps.execCustomOp(nullptr, op.getOpHash(), inputBuffers, inputShapes, 1, outputBuffers, outputShapes, 1, nullptr, 0, iArgs, 9, nullptr, 0, false);
+execCustomOp(nullptr, op.getOpHash(), inputBuffers, inputShapes, 1, outputBuffers, outputShapes, 1, nullptr, 0, iArgs, 9, nullptr, 0, false);
 }

 auto javaEnd = std::chrono::system_clock::now();
@@ -990,7 +986,7 @@ TEST_F(PlaygroundTests, Test_Im2Col_1) {


 for (int e = 0; e < iterations; e++) {
-nativeOps.execCustomOp(nullptr, op.getOpHash(), inputBuffers, inputShapes, 1, outputPermBuffers, outputPermShapes, 1, nullptr, 0, iArgs, 9, nullptr, 0, false);
+execCustomOp(nullptr, op.getOpHash(), inputBuffers, inputShapes, 1, outputPermBuffers, outputPermShapes, 1, nullptr, 0, iArgs, 9, nullptr, 0, false);
 }

 auto javaPermEnd = std::chrono::system_clock::now();
@@ -1020,9 +1016,7 @@ TEST_F(PlaygroundTests, Test_Im2Col_2) {
 Nd4jPointer outputPermBuffers[] = {outputPermuted.buffer()};
 Nd4jPointer outputPermShapes[] = {outputPermuted.shapeInfo()};

-NativeOps nativeOps;
-
-nativeOps.execCustomOp(nullptr, op.getOpHash(), inputBuffers, inputShapes, 1, outputPermBuffers, outputPermShapes, 1, nullptr, 0, iArgs, 9, nullptr, 0, false);
+execCustomOp(nullptr, op.getOpHash(), inputBuffers, inputShapes, 1, outputPermBuffers, outputPermShapes, 1, nullptr, 0, iArgs, 9, nullptr, 0, false);
 }

 TEST_F(PlaygroundTests, Test_Col2Im_1) {
@@ -1140,8 +1134,6 @@ TEST_F(PlaygroundTests, loop_test_1) {
 int length = (int) array->lengthOf();
 int span = (int) (array->lengthOf() / 6) + 8;

-NativeOps ops;
-
 auto t = new int[1000000];


@@ -1150,7 +1142,7 @@ TEST_F(PlaygroundTests, loop_test_1) {
 FloatBits fb;
 float threshold = 0.99f;
 fb.f_ = threshold;
-int le = ops.estimateThreshold(nullptr, reinterpret_cast<void *>(array->buffer()), array->shapeInfo(), static_cast<int>(array->lengthOf()), threshold);
+int le = estimateThreshold(nullptr, reinterpret_cast<void *>(array->buffer()), array->shapeInfo(), static_cast<int>(array->lengthOf()), threshold);

 t[0] = le;
 t[1] = length;
@@ -1162,7 +1154,7 @@ TEST_F(PlaygroundTests, loop_test_1) {

 for (int x = 0; x < iterations; x++) {
 auto permStart = std::chrono::system_clock::now();
-ops.estimateThreshold(nullptr, reinterpret_cast<void *>(array->buffer()), array->shapeInfo(), static_cast<int>(array->lengthOf()), threshold);
+estimateThreshold(nullptr, reinterpret_cast<void *>(array->buffer()), array->shapeInfo(), static_cast<int>(array->lengthOf()), threshold);
 TypeCast::convertToThreshold<float>(nullptr, buffer, array->lengthOf(), t);

 auto permEnd = std::chrono::system_clock::now();

@@ -29,7 +29,6 @@ using namespace nd4j;

 class RNGTests : public testing::Test {
 private:
-NativeOps nativeOps;
 //Nd4jLong *_bufferA;
 //Nd4jLong *_bufferB;

@@ -47,8 +46,8 @@ public:
 RNGTests() {
 //_bufferA = new Nd4jLong[100000];
 //_bufferB = new Nd4jLong[100000];
-//_rngA = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, _seed, 100000, (Nd4jPointer) _bufferA);
-//_rngB = (nd4j::random::RandomBuffer *) nativeOps.initRandom(nullptr, _seed, 100000, (Nd4jPointer) _bufferB);
+//_rngA = (nd4j::random::RandomBuffer *) initRandom(nullptr, _seed, 100000, (Nd4jPointer) _bufferA);
+//_rngB = (nd4j::random::RandomBuffer *) initRandom(nullptr, _seed, 100000, (Nd4jPointer) _bufferB);
 _rngA.setStates(_seed, _seed);
 _rngB.setStates(_seed, _seed);
 nexp0->assign(-1.0f);
@@ -57,8 +56,8 @@ public:
 }

 ~RNGTests() {
-//nativeOps.destroyRandom(_rngA);
-//nativeOps.destroyRandom(_rngB);
+//destroyRandom(_rngA);
+//destroyRandom(_rngB);
 //delete[] _bufferA;
 //delete[] _bufferB;

@@ -791,14 +790,13 @@ namespace nd4j {
 }

 TEST_F(RNGTests, Test_Reproducibility_9) {
-NativeOps ops;
 Nd4jLong seed = 123;

 std::vector<Nd4jLong> shape = {32, 3, 28, 28};
 const int bufferSize = 10000;
 int64_t buffer[bufferSize];

-auto rng = (nd4j::random::RandomBuffer *) ops.initRandom(nullptr, seed, bufferSize, buffer);
+auto rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, seed, bufferSize, buffer);

 const int length = 4000000;
 int *arrayE = new int[length];
@@ -809,7 +807,7 @@ TEST_F(RNGTests, Test_Reproducibility_9) {

 rng->rewindH(static_cast<Nd4jLong>(length));

-ops.refreshBuffer(nullptr, seed, reinterpret_cast<Nd4jPointer>(rng));
+refreshBuffer(nullptr, seed, reinterpret_cast<Nd4jPointer>(rng));

 for (int e = 0; e < length; e++)
 arrayT[e] = rng->relativeInt(e);
@@ -825,18 +823,17 @@ TEST_F(RNGTests, Test_Reproducibility_9) {
 delete[] arrayE;
 delete[] arrayT;

-ops.destroyRandom(reinterpret_cast<Nd4jPointer>(rng));
+destroyRandom(reinterpret_cast<Nd4jPointer>(rng));
 }

 TEST_F(RNGTests, Test_Reproducibility_8) {
-NativeOps ops;
 Nd4jLong seed = 123;

 std::vector<int> shape = {32, 3, 28, 28};
 const int bufferSize = 10000;
 int64_t buffer[bufferSize];

-auto rng = (nd4j::random::RandomBuffer *) ops.initRandom(nullptr, seed, bufferSize, buffer);
+auto rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, seed, bufferSize, buffer);

 const int length = 4000000;
 int *arrayE = new int[length];
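The RNG tests exercise the full lifecycle of the random-buffer helpers, which are now plain C exports. A minimal sketch of that lifecycle, assembled only from calls visible in these hunks; seed, bufferSize and buffer mirror the locals declared in the tests.

// Sketch: allocate, refresh, and release a RandomBuffer via the exported functions.
Nd4jLong seed = 123;
const int bufferSize = 10000;
int64_t buffer[bufferSize];

auto rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, seed, bufferSize, buffer);  // was ops.initRandom(...)
refreshBuffer(nullptr, seed, reinterpret_cast<Nd4jPointer>(rng));                         // was ops.refreshBuffer(...)
destroyRandom(reinterpret_cast<Nd4jPointer>(rng));                                        // was ops.destroyRandom(...)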
@@ -847,7 +844,7 @@ TEST_F(RNGTests, Test_Reproducibility_8) {

 rng->rewindH(static_cast<Nd4jLong>(length));

-ops.refreshBuffer(nullptr, seed, reinterpret_cast<Nd4jPointer>(rng));
+refreshBuffer(nullptr, seed, reinterpret_cast<Nd4jPointer>(rng));

 for (int e = 0; e < length; e++)
 arrayT[e] = static_cast<int>(rng->relativeT<float>(e));
@@ -863,29 +860,27 @@ TEST_F(RNGTests, Test_Reproducibility_8) {
 delete[] arrayE;
 delete[] arrayT;

-ops.destroyRandom(reinterpret_cast<Nd4jPointer>(rng));
+destroyRandom(reinterpret_cast<Nd4jPointer>(rng));
 }

 TEST_F(RNGTests, Test_RandomBuffer_Half_1) {
-NativeOps ops;
 Nd4jLong seed = 123;

 std::vector<Nd4jLong> shape = {32, 3, 28, 28};
 const int bufferSize = 10000;
 int64_t buffer[bufferSize];

-auto rng = (nd4j::random::RandomBuffer *) ops.initRandom(nullptr, seed, bufferSize, buffer);
+auto rng = (nd4j::random::RandomBuffer *) initRandom(nullptr, seed, bufferSize, buffer);

 auto r0 = rng->relativeT<float16>(12L);
 auto r1 = rng->relativeT<float16>(13L);

 ASSERT_NE(r0, r1);

-ops.destroyRandom(reinterpret_cast<Nd4jPointer>(rng));
+destroyRandom(reinterpret_cast<Nd4jPointer>(rng));
 }

 TEST_F(RNGTests, Test_Reproducibility_1) {
-NativeOps ops;
 Nd4jLong seed = 123;

 std::vector<Nd4jLong> shape = {32, 3, 28, 28};
@@ -918,7 +913,6 @@ TEST_F(RNGTests, Test_Reproducibility_1) {

 #ifndef DEBUG_BUILD
 TEST_F(RNGTests, Test_Reproducibility_2) {
-NativeOps ops;
 Nd4jLong seed = 123;

 std::vector<Nd4jLong> shape = {32, 3, 64, 64};

@@ -44,8 +44,7 @@ TEST_F(SortCpuTests, test_linear_sort_by_key_1) {
 auto ev = NDArrayFactory::create<double>('c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5});


-NativeOps nativeOps;
-nativeOps.sortByKey(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);
+sortByKey(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);

 ASSERT_EQ(ek, k);
 ASSERT_EQ(ev, v);
@@ -62,8 +61,7 @@ TEST_F(SortCpuTests, test_linear_sort_by_val_1) {
 auto ev = NDArrayFactory::create<double>('c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5});


-NativeOps nativeOps;
-nativeOps.sortByValue(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);
+sortByValue(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);

 ASSERT_EQ(ek, k);
 ASSERT_EQ(ev, v);
@@ -81,8 +79,7 @@ TEST_F(SortCpuTests, test_tad_sort_by_key_1) {


 int axis = 1;
-NativeOps nativeOps;
-nativeOps.sortTadByKey(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false);
+sortTadByKey(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false);

 ASSERT_EQ(ek, k);
 ASSERT_EQ(ev, v);
@@ -100,8 +97,7 @@ TEST_F(SortCpuTests, test_tad_sort_by_val_1) {


 int axis = 1;
-NativeOps nativeOps;
-nativeOps.sortTadByValue(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false);
+sortTadByValue(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false);

 ASSERT_EQ(ek, k);
 ASSERT_EQ(ev, v);

@@ -42,8 +42,7 @@ TEST_F(SortCudaTests, test_linear_sort_by_key_1) {

 Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};

-NativeOps nativeOps;
-nativeOps.sortByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);
+sortByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);
 k.tickWriteDevice();
 v.tickWriteDevice();

@@ -60,8 +59,7 @@ TEST_F(SortCudaTests, test_linear_sort_by_val_1) {

 Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};

-NativeOps nativeOps;
-nativeOps.sortByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);
+sortByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false);
 k.tickWriteDevice();
 v.tickWriteDevice();

@@ -78,8 +76,7 @@ TEST_F(SortCudaTests, test_linear_sort_by_val_2) {

 Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};

-NativeOps nativeOps;
-nativeOps.sortByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), true);
+sortByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), true);
 k.tickWriteDevice();
 v.tickWriteDevice();
 k.printIndexedBuffer("KEYS");
@@ -97,8 +94,7 @@ TEST_F(SortCudaTests, test_tad_sort_by_key_1) {
 Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};

 int axis = 1;
-NativeOps nativeOps;
-nativeOps.sortTadByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false);
+sortTadByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false);
 k.tickWriteDevice();
 v.tickWriteDevice();

@@ -119,8 +115,7 @@ TEST_F(SortCudaTests, test_tad_sort_by_val_1) {
 Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()};

 int axis = 1;
-NativeOps nativeOps;
-nativeOps.sortTadByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false);
+sortTadByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false);
 k.tickWriteDevice();
 v.tickWriteDevice();


@@ -58,8 +58,7 @@ TEST_F(TypeCastTests, Test_ConvertDtype_1) {
 float16 dst[5];
 float16 exp[] = {(float16) 1.0f, (float16) 2.0f, (float16) 3.0f, (float16) 4.0f, (float16) 5.0f};

-NativeOps ops;
-ops.convertTypes(nullptr, ND4J_FLOAT32, src, 5, ND4J_FLOAT16, dst);
+convertTypes(nullptr, ND4J_FLOAT32, src, 5, ND4J_FLOAT16, dst);

 for (int e = 0; e < 5; e++)
 ASSERT_NEAR(exp[e], dst[e], (float16) 0.01f);
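The type-cast test follows the same pattern; a small sketch of the exported conversion entry point as the updated test calls it. The five-element src array here is an assumed stand-in for the test's actual input, which is declared just above this hunk.

// Sketch: float32 -> float16 conversion through the exported C entry point.
float src[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};  // assumed input values
float16 dst[5];
convertTypes(nullptr, ND4J_FLOAT32, src, 5, ND4J_FLOAT16, dst);  // was ops.convertTypes(...)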