Create C wrappers for some of the C++ classes currently used by ND4J
parent
8881bfe7aa
commit
526b782e51
|
@ -882,6 +882,8 @@ ND4J_EXPORT void enableVerboseMode(bool reallyEnable);
|
|||
*/
|
||||
ND4J_EXPORT void setGridLimit(int gridSize);
|
||||
|
||||
typedef nd4j::TadPack OpaqueTadPack;
|
||||
|
||||
/**
|
||||
*
|
||||
* @param xShapeInfo
|
||||
|
@ -890,10 +892,19 @@ ND4J_EXPORT void setGridLimit(int gridSize);
|
|||
* @param targetBuffer
|
||||
* @param offsetsBuffer
|
||||
*/
|
||||
ND4J_EXPORT nd4j::TadPack* tadOnlyShapeInfo(Nd4jLong *xShapeInfo,
|
||||
ND4J_EXPORT OpaqueTadPack* tadOnlyShapeInfo(Nd4jLong *xShapeInfo,
|
||||
int *dimension,
|
||||
int dimensionLength);
|
||||
|
||||
ND4J_EXPORT Nd4jLong* getPrimaryShapeInfo(OpaqueTadPack* pack);
|
||||
ND4J_EXPORT Nd4jLong* getPrimaryOffsets(OpaqueTadPack* pack);
|
||||
ND4J_EXPORT Nd4jLong* getSpecialShapeInfo(OpaqueTadPack* pack);
|
||||
ND4J_EXPORT Nd4jLong* getSpecialOffsets(OpaqueTadPack* pack);
|
||||
ND4J_EXPORT Nd4jLong getNumberOfTads(OpaqueTadPack* pack);
|
||||
ND4J_EXPORT int getShapeInfoLength(OpaqueTadPack* pack);
|
||||
|
||||
ND4J_EXPORT void deleteTadPack(OpaqueTadPack* ptr);
|
||||
|
||||
/*
|
||||
* PullRow special op
|
||||
*/
|
||||
|
@ -1639,10 +1650,13 @@ ND4J_EXPORT Nd4jLong* mmapFile(Nd4jPointer *extraPointers, const char *fileName,
|
|||
|
||||
ND4J_EXPORT void munmapFile(Nd4jPointer *extraPointers, Nd4jLong* ptrMap, Nd4jLong length);
|
||||
|
||||
typedef nd4j::graph::ResultWrapper OpaqueResultWrapper;
|
||||
|
||||
// flatbuffers execution
|
||||
ND4J_EXPORT nd4j::graph::ResultWrapper* executeFlatGraph(Nd4jPointer *extraPointers, Nd4jPointer flatBufferPointer);
|
||||
ND4J_EXPORT OpaqueResultWrapper* executeFlatGraph(Nd4jPointer *extraPointers, Nd4jPointer flatBufferPointer);
|
||||
|
||||
ND4J_EXPORT Nd4jLong getResultWrapperSize(OpaqueResultWrapper* ptr);
|
||||
ND4J_EXPORT Nd4jPointer getResultWrapperPointer(OpaqueResultWrapper* ptr);
|
||||
|
||||
ND4J_EXPORT const char* getAllCustomOps();
|
||||
|
||||
|
@ -1652,14 +1666,31 @@ ND4J_EXPORT const char* getAllOperations();
|
|||
ND4J_EXPORT int execCustomOp(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputs, Nd4jPointer* outputBuffers, Nd4jPointer* outputShapes, int numOutputs, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool* bArgs, int numBArgs, bool isInplace);
|
||||
ND4J_EXPORT int execCustomOp2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer opContext);
|
||||
|
||||
ND4J_EXPORT nd4j::ShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs);
|
||||
ND4J_EXPORT nd4j::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs);
|
||||
typedef nd4j::ShapeList OpaqueShapeList;
|
||||
|
||||
ND4J_EXPORT OpaqueShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs);
|
||||
ND4J_EXPORT OpaqueShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs);
|
||||
|
||||
ND4J_EXPORT Nd4jLong getShapeListSize(OpaqueShapeList* list);
|
||||
ND4J_EXPORT Nd4jLong* getShape(OpaqueShapeList* list, Nd4jLong i);
|
||||
|
||||
ND4J_EXPORT void deleteShapeList(Nd4jPointer shapeList);
|
||||
|
||||
ND4J_EXPORT int registerGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer flatBufferPointer);
|
||||
|
||||
ND4J_EXPORT nd4j::graph::VariablesSet *executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int* inputIndices, int numInputs);
|
||||
typedef nd4j::graph::VariablesSet OpaqueVariableSet;
|
||||
typedef nd4j::graph::Variable OpaqueVariable;
|
||||
|
||||
ND4J_EXPORT OpaqueVariableSet *executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int* inputIndices, int numInputs);
|
||||
|
||||
ND4J_EXPORT Nd4jLong getVariableSetSize(OpaqueVariableSet* set);
|
||||
ND4J_EXPORT Nd4jStatus getVariableSetStatus(OpaqueVariableSet* set);
|
||||
ND4J_EXPORT OpaqueVariable* getVariable(OpaqueVariableSet* set, Nd4jLong i);
|
||||
ND4J_EXPORT int getVariableId(OpaqueVariable* variable);
|
||||
ND4J_EXPORT int getVariableIndex(OpaqueVariable* variable);
|
||||
ND4J_EXPORT const char* getVariableName(OpaqueVariable* variable);
|
||||
ND4J_EXPORT Nd4jLong* getVariableShape(OpaqueVariable* variable);
|
||||
ND4J_EXPORT void* getVariableBuffer(OpaqueVariable* variable);
|
||||
|
||||
ND4J_EXPORT int unregisterGraph(Nd4jPointer *extraPointers, Nd4jLong graphId);
|
||||
|
||||
|
@ -1668,7 +1699,7 @@ ND4J_EXPORT void deleteIntArray(Nd4jPointer pointer);
|
|||
ND4J_EXPORT void deleteLongArray(Nd4jPointer pointer);
|
||||
ND4J_EXPORT void deletePointerArray(Nd4jPointer pointer);
|
||||
|
||||
ND4J_EXPORT void deleteVariablesSet(Nd4jPointer pointer);
|
||||
ND4J_EXPORT void deleteVariablesSet(OpaqueVariableSet pointer);
|
||||
|
||||
// GraphState creation
|
||||
ND4J_EXPORT Nd4jPointer getGraphState(Nd4jLong id);
|
||||
|
@ -1684,7 +1715,9 @@ ND4J_EXPORT Nd4jStatus execCustomOpWithScope(Nd4jPointer *extraPointers, Nd4jPoi
|
|||
|
||||
//void fillUtf8String(Nd4jPointer *extraPointers, const char **string, int numStrings, Nd4jPointer buffer);
|
||||
ND4J_EXPORT Nd4jPointer createUtf8String(Nd4jPointer *extraPointers, const char *string, int length);
|
||||
void deleteUtf8String(Nd4jPointer *extraPointers, Nd4jPointer ptr);
|
||||
ND4J_EXPORT Nd4jLong getUtf8StringLength(Nd4jPointer *extraPointers, Nd4jPointer ptr);
|
||||
ND4J_EXPORT char* getUtf8StringBuffer(Nd4jPointer *extraPointers, Nd4jPointer ptr);
|
||||
ND4J_EXPORT void deleteUtf8String(Nd4jPointer *extraPointers, Nd4jPointer ptr);
|
||||
|
||||
ND4J_EXPORT void scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSubArrs,
|
||||
void* hX, Nd4jLong* hXShapeInfo, Nd4jLong* hXOffsets,
|
||||
|
@ -1693,18 +1726,45 @@ ND4J_EXPORT void scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOf
|
|||
void* dY, Nd4jLong* dYShapeInfo, Nd4jLong* dYOffsets,
|
||||
int* hIindexes, int* dIindexes);
|
||||
|
||||
ND4J_EXPORT void deleteShapeBuffer(Nd4jPointer ptr);
|
||||
ND4J_EXPORT void deleteTadPack(Nd4jPointer ptr);
|
||||
|
||||
ND4J_EXPORT void inspectArray(Nd4jPointer *extraPointers, Nd4jPointer buffer, Nd4jLong *shapeInfo, Nd4jPointer specialBuffer, Nd4jLong *specialShapeInfo, Nd4jPointer debugInfo);
|
||||
|
||||
|
||||
ND4J_EXPORT nd4j::ConstantDataBuffer* shapeBuffer(int rank, Nd4jLong *shape, Nd4jLong *strides, nd4j::DataType dtype, char order, Nd4jLong ews, bool empty);
|
||||
typedef nd4j::ConstantDataBuffer OpaqueConstantDataBuffer;
|
||||
|
||||
ND4J_EXPORT nd4j::ConstantDataBuffer* constantBufferLong(nd4j::DataType dtype, Nd4jLong *data, int length);
|
||||
ND4J_EXPORT nd4j::ConstantDataBuffer* constantBufferDouble(nd4j::DataType dtype, double *data, int length);
|
||||
ND4J_EXPORT nd4j::ConstantDataBuffer* constantBuffer(nd4j::DataType dtype, nd4j::ConstantDescriptor *descriptor);
|
||||
ND4J_EXPORT OpaqueConstantDataBuffer* shapeBuffer(int rank, Nd4jLong *shape, Nd4jLong *strides, nd4j::DataType dtype, char order, Nd4jLong ews, bool empty);
|
||||
|
||||
ND4J_EXPORT OpaqueConstantDataBuffer* constantBufferLong(nd4j::DataType dtype, Nd4jLong *data, int length);
|
||||
ND4J_EXPORT OpaqueConstantDataBuffer* constantBufferDouble(nd4j::DataType dtype, double *data, int length);
|
||||
ND4J_EXPORT OpaqueConstantDataBuffer* constantBuffer(nd4j::DataType dtype, nd4j::ConstantDescriptor *descriptor);
|
||||
|
||||
ND4J_EXPORT Nd4jPointer getConstantDataBufferPrimary(OpaqueConstantDataBuffer* dbf);
|
||||
ND4J_EXPORT Nd4jPointer getConstantDataBufferSpecial(OpaqueConstantDataBuffer* dbf);
|
||||
ND4J_EXPORT Nd4jLong getConstantDataBufferLength(OpaqueConstantDataBuffer* dbf);
|
||||
ND4J_EXPORT Nd4jLong getConstantDataBufferSizeOf(OpaqueConstantDataBuffer* dbf);
|
||||
|
||||
ND4J_EXPORT void deleteShapeBuffer(OpaqueConstantDataBuffer* ptr);
|
||||
|
||||
typedef nd4j::graph::Context OpaqueContext;
|
||||
typedef nd4j::graph::RandomGenerator OpaqueRandomGenerator;
|
||||
|
||||
ND4J_EXPORT OpaqueContext* createGraphContext(int nodeId);
|
||||
ND4J_EXPORT OpaqueRandomGenerator* getGraphContextRandomGenerator(OpaqueContext* ptr);
|
||||
ND4J_EXPORT void markGraphContextInplace(OpaqueContext* ptr, bool reallyInplace);
|
||||
ND4J_EXPORT void setGraphContextCudaContext(OpaqueContext* ptr, void *stream, void *reductionPointer, void *allocationPointer);
|
||||
ND4J_EXPORT void setGraphContextInputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo);
|
||||
ND4J_EXPORT void setGraphContextOutputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo);
|
||||
ND4J_EXPORT void setGraphContextTArguments(OpaqueContext* ptr, double *arguments, int numberOfArguments);
|
||||
ND4J_EXPORT void setGraphContextIArguments(OpaqueContext* ptr, Nd4jLong *arguments, int numberOfArguments);
|
||||
ND4J_EXPORT void setGraphContextBArguments(OpaqueContext* ptr, bool *arguments, int numberOfArguments);
|
||||
ND4J_EXPORT void deleteGraphContext(OpaqueContext* ptr);
|
||||
|
||||
ND4J_EXPORT OpaqueRandomGenerator* createRandomGenerator(Nd4jLong rootSeed = 0, Nd4jLong nodeSeed = 0);
|
||||
ND4J_EXPORT Nd4jLong getRandomGeneratorRootState(OpaqueRandomGenerator* ptr);
|
||||
ND4J_EXPORT Nd4jLong getRandomGeneratorNodeState(OpaqueRandomGenerator* ptr);
|
||||
ND4J_EXPORT void setRandomGeneratorStates(OpaqueRandomGenerator* ptr, Nd4jLong rootSeed = 0, Nd4jLong nodeSeed = 0);
|
||||
ND4J_EXPORT int getRandomGeneratorRelativeInt(OpaqueRandomGenerator* ptr, Nd4jLong index);
|
||||
ND4J_EXPORT Nd4jLong getRandomGeneratorRelativeLong(OpaqueRandomGenerator* ptr, Nd4jLong index);
|
||||
ND4J_EXPORT void deleteRandomGenerator(OpaqueRandomGenerator* ptr);
|
||||
|
||||
ND4J_EXPORT const char* runLightBenchmarkSuit(bool printOut);
|
||||
ND4J_EXPORT const char* runFullBenchmarkSuit(bool printOut);
|
||||
|
|
|
@ -1328,6 +1328,25 @@ nd4j::TadPack* tadOnlyShapeInfo(Nd4jLong *hXShapeInfo, int *dimension, int dimen
|
|||
return pack;
|
||||
}
|
||||
|
||||
Nd4jLong* getPrimaryShapeInfo(nd4j::TadPack* pack) {
|
||||
return pack->primaryShapeInfo();
|
||||
}
|
||||
Nd4jLong* getPrimaryOffsets(nd4j::TadPack* pack) {
|
||||
return pack->primaryOffsets();
|
||||
}
|
||||
Nd4jLong* getSpecialShapeInfo(nd4j::TadPack* pack) {
|
||||
return pack->specialShapeInfo();
|
||||
}
|
||||
Nd4jLong* getSpecialOffsets(nd4j::TadPack* pack) {
|
||||
return pack->specialOffsets();
|
||||
}
|
||||
Nd4jLong getNumberOfTads(nd4j::TadPack* pack) {
|
||||
return pack->numberOfTads();
|
||||
}
|
||||
int getShapeInfoLength(nd4j::TadPack* pack) {
|
||||
return pack->shapeInfoLength();
|
||||
}
|
||||
|
||||
int memcpyConstantAsync(Nd4jLong dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4jPointer reserved) {
|
||||
// no-op
|
||||
return 0L;
|
||||
|
@ -2005,6 +2024,13 @@ nd4j::graph::ResultWrapper* executeFlatGraph(Nd4jPointer *extraPointers, Nd4jPoi
|
|||
return nd4j::graph::GraphExecutioner::executeFlatBuffer(flatBufferPointer);
|
||||
}
|
||||
|
||||
Nd4jLong getResultWrapperSize(nd4j::graph::ResultWrapper* ptr) {
|
||||
return ptr->size();
|
||||
}
|
||||
Nd4jPointer getResultWrapperPointer(nd4j::graph::ResultWrapper* ptr) {
|
||||
return ptr->pointer();
|
||||
}
|
||||
|
||||
const char* getAllCustomOps() {
|
||||
return nd4j::ops::OpRegistrator::getInstance()->getAllCustomOperations();
|
||||
}
|
||||
|
@ -2041,7 +2067,13 @@ int estimateThreshold(Nd4jPointer *extraPointers, Nd4jPointer hX, Nd4jLong *hXSh
|
|||
BUILD_SINGLE_SELECTOR(xType, return estimateThresholdGeneric, (extraPointers, hX, N, threshold), FLOAT_TYPES);
|
||||
}
|
||||
|
||||
Nd4jLong getShapeListSize(nd4j::ShapeList* list) {
|
||||
return list->size();
|
||||
}
|
||||
|
||||
Nd4jLong* getShape(nd4j::ShapeList* list, Nd4jLong i) {
|
||||
return list->at(i);
|
||||
}
|
||||
|
||||
void deleteShapeList(Nd4jPointer shapeList) {
|
||||
auto list = reinterpret_cast<nd4j::ShapeList*>(shapeList);
|
||||
|
@ -2305,6 +2337,38 @@ nd4j::graph::VariablesSet* executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLo
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
Nd4jLong getVariableSetSize(nd4j::graph::VariablesSet* set) {
|
||||
return set->size();
|
||||
}
|
||||
|
||||
Nd4jStatus getVariableSetStatus(nd4j::graph::VariablesSet* set) {
|
||||
return set->status();
|
||||
}
|
||||
|
||||
nd4j::graph::Variable* getVariable(nd4j::graph::VariablesSet* set, Nd4jLong i) {
|
||||
return set->at(i);
|
||||
}
|
||||
|
||||
int getVariableId(nd4j::graph::Variable* variable) {
|
||||
return variable->id();
|
||||
}
|
||||
|
||||
int getVariableIndex(nd4j::graph::Variable* variable) {
|
||||
return variable->index();
|
||||
}
|
||||
|
||||
const char* getVariableName(nd4j::graph::Variable* variable) {
|
||||
return variable->getName()->c_str();
|
||||
}
|
||||
|
||||
Nd4jLong* getVariableShape(nd4j::graph::Variable* variable) {
|
||||
return variable->getNDArray()->shapeInfo();
|
||||
}
|
||||
|
||||
void* getVariableBuffer(nd4j::graph::Variable* variable) {
|
||||
return variable->getNDArray()->buffer();
|
||||
}
|
||||
|
||||
int unregisterGraph(Nd4jPointer *extraPointers, Nd4jLong graphId) {
|
||||
|
||||
nd4j::graph::GraphHolder::getInstance()->dropGraphAny(graphId);
|
||||
|
@ -2628,6 +2692,13 @@ Nd4jPointer createUtf8String(Nd4jPointer *extraPointers, const char *string, int
|
|||
return reinterpret_cast<Nd4jPointer>(u);
|
||||
}
|
||||
|
||||
Nd4jLong getUtf8StringLength(Nd4jPointer *extraPointers, Nd4jPointer ptr) {
|
||||
return reinterpret_cast<nd4j::utf8string*>(ptr)->_length;
|
||||
}
|
||||
char* getUtf8StringBuffer(Nd4jPointer *extraPointers, Nd4jPointer ptr) {
|
||||
return reinterpret_cast<nd4j::utf8string*>(ptr)->_buffer;
|
||||
}
|
||||
|
||||
void deleteUtf8String(Nd4jPointer *extraPointers, Nd4jPointer ptr) {
|
||||
delete(reinterpret_cast<nd4j::utf8string*>(ptr));
|
||||
}
|
||||
|
@ -2710,14 +2781,12 @@ nd4j::ConstantDataBuffer* shapeBuffer(int rank, Nd4jLong *shape, Nd4jLong *strid
|
|||
return buffer;
|
||||
}
|
||||
|
||||
void deleteShapeBuffer(Nd4jPointer ptr) {
|
||||
auto buffer = reinterpret_cast<nd4j::ConstantDataBuffer*>(ptr);
|
||||
delete buffer;
|
||||
void deleteShapeBuffer(nd4j::ConstantDataBuffer* ptr) {
|
||||
delete ptr;
|
||||
}
|
||||
|
||||
void deleteTadPack(Nd4jPointer ptr) {
|
||||
auto buffer = reinterpret_cast<nd4j::TadPack*>(ptr);
|
||||
delete buffer;
|
||||
void deleteTadPack(nd4j::TadPack* ptr) {
|
||||
delete ptr;
|
||||
}
|
||||
|
||||
nd4j::ConstantDataBuffer* constantBufferLong(nd4j::DataType dtype, Nd4jLong *data, int length) {
|
||||
|
@ -2732,6 +2801,78 @@ nd4j::ConstantDataBuffer* constantBuffer(nd4j::DataType dtype, nd4j::ConstantDes
|
|||
return nd4j::ConstantHelper::getInstance()->constantBuffer(*descriptor, dtype);
|
||||
}
|
||||
|
||||
Nd4jPointer getConstantDataBufferPrimary(nd4j::ConstantDataBuffer* dbf) {
|
||||
return dbf->primary();
|
||||
}
|
||||
Nd4jPointer getConstantDataBufferSpecial(nd4j::ConstantDataBuffer* dbf) {
|
||||
return dbf->special();
|
||||
}
|
||||
Nd4jLong getConstantDataBufferLength(nd4j::ConstantDataBuffer* dbf) {
|
||||
return dbf->length();
|
||||
}
|
||||
Nd4jLong getConstantDataBufferSizeOf(nd4j::ConstantDataBuffer* dbf) {
|
||||
return dbf->sizeOf();
|
||||
}
|
||||
|
||||
|
||||
nd4j::graph::Context* createGraphContext(int nodeId) {
|
||||
return new nd4j::graph::Context(nodeId);
|
||||
}
|
||||
nd4j::graph::RandomGenerator* getGraphContextRandomGenerator(nd4j::graph::Context* ptr) {
|
||||
return &ptr->randomGenerator();
|
||||
}
|
||||
void markGraphContextInplace(nd4j::graph::Context* ptr, bool reallyInplace) {
|
||||
ptr->markInplace(reallyInplace);
|
||||
}
|
||||
void setGraphContextCudaContext(nd4j::graph::Context* ptr, void *stream, void *reductionPointer, void *allocationPointer) {
|
||||
}
|
||||
void setGraphContextInputArray(nd4j::graph::Context* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo) {
|
||||
ptr->setInputArray(index, buffer, shapeInfo, specialBuffer, specialShapeInfo);
|
||||
}
|
||||
void setGraphContextOutputArray(nd4j::graph::Context* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo) {
|
||||
ptr->setOutputArray(index, buffer, shapeInfo, specialBuffer, specialShapeInfo);
|
||||
}
|
||||
void setGraphContextTArguments(nd4j::graph::Context* ptr, double *arguments, int numberOfArguments) {
|
||||
ptr->setTArguments(arguments, numberOfArguments);
|
||||
}
|
||||
void setGraphContextIArguments(nd4j::graph::Context* ptr, Nd4jLong *arguments, int numberOfArguments) {
|
||||
ptr->setIArguments(arguments, numberOfArguments);
|
||||
}
|
||||
void setGraphContextBArguments(nd4j::graph::Context* ptr, bool *arguments, int numberOfArguments) {
|
||||
ptr->setBArguments(arguments, numberOfArguments);
|
||||
}
|
||||
void deleteGraphContext(nd4j::graph::Context* ptr) {
|
||||
delete ptr;
|
||||
}
|
||||
|
||||
|
||||
nd4j::graph::RandomGenerator* createRandomGenerator(Nd4jLong rootSeed, Nd4jLong nodeSeed) {
|
||||
return new nd4j::graph::RandomGenerator(rootSeed, nodeSeed);
|
||||
}
|
||||
|
||||
Nd4jLong getRandomGeneratorRootState(nd4j::graph::RandomGenerator* ptr) {
|
||||
return ptr->rootState();
|
||||
}
|
||||
|
||||
Nd4jLong getRandomGeneratorNodeState(nd4j::graph::RandomGenerator* ptr) {
|
||||
return ptr->nodeState();
|
||||
}
|
||||
|
||||
void setRandomGeneratorStates(nd4j::graph::RandomGenerator* ptr, Nd4jLong rootSeed, Nd4jLong nodeSeed) {
|
||||
ptr->setStates(rootSeed, nodeSeed);
|
||||
}
|
||||
|
||||
int getRandomGeneratorRelativeInt(nd4j::graph::RandomGenerator* ptr, Nd4jLong index) {
|
||||
return ptr->relativeInt(index);
|
||||
}
|
||||
|
||||
Nd4jLong getRandomGeneratorRelativeLong(nd4j::graph::RandomGenerator* ptr, Nd4jLong index) {
|
||||
return ptr->relativeLong(index);
|
||||
}
|
||||
|
||||
void deleteRandomGenerator(nd4j::graph::RandomGenerator* ptr) {
|
||||
delete ptr;
|
||||
}
|
||||
|
||||
|
||||
int dataTypeFromNpyHeader(void *header) {
|
||||
|
|
|
@ -1499,6 +1499,25 @@ nd4j::TadPack* tadOnlyShapeInfo(Nd4jLong *dXShapeInfo, int *dimension, int dimen
|
|||
return pack;
|
||||
}
|
||||
|
||||
Nd4jLong* getPrimaryShapeInfo(nd4j::TadPack* pack) {
|
||||
return pack->primaryShapeInfo();
|
||||
}
|
||||
Nd4jLong* getPrimaryOffsets(nd4j::TadPack* pack) {
|
||||
return pack->primaryOffsets();
|
||||
}
|
||||
Nd4jLong* getSpecialShapeInfo(nd4j::TadPack* pack) {
|
||||
return pack->specialShapeInfo();
|
||||
}
|
||||
Nd4jLong* getSpecialOffsets(nd4j::TadPack* pack) {
|
||||
return pack->specialOffsets();
|
||||
}
|
||||
Nd4jLong getNumberOfTads(nd4j::TadPack* pack) {
|
||||
return pack->numberOfTads();
|
||||
}
|
||||
int getShapeInfoLength(nd4j::TadPack* pack) {
|
||||
return pack->shapeInfoLength();
|
||||
}
|
||||
|
||||
int memcpyConstantAsync(Nd4jLong dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4jPointer reserved) {
|
||||
cudaStream_t *pStream = reinterpret_cast<cudaStream_t *>(reserved);
|
||||
|
||||
|
@ -2533,6 +2552,13 @@ nd4j::graph::ResultWrapper* executeFlatGraph(Nd4jPointer *extraPointers, Nd4jPoi
|
|||
return nd4j::graph::GraphExecutioner::executeFlatBuffer(flatBufferPointer);
|
||||
}
|
||||
|
||||
Nd4jLong getResultWrapperSize(nd4j::graph::ResultWrapper* ptr) {
|
||||
return ptr->size();
|
||||
}
|
||||
Nd4jPointer getResultWrapperPointer(nd4j::graph::ResultWrapper* ptr) {
|
||||
return ptr->pointer();
|
||||
}
|
||||
|
||||
|
||||
const char* getAllCustomOps() {
|
||||
return nd4j::ops::OpRegistrator::getInstance()->getAllCustomOperations();
|
||||
|
@ -2607,6 +2633,13 @@ nd4j::ShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, Nd4jLong hash
|
|||
return _calculateOutputShapes(extraPointers, op, inputShapes, numInputShapes, tArgs, numTArgs, iArgs, numIArgs);
|
||||
}
|
||||
|
||||
Nd4jLong getShapeListSize(nd4j::ShapeList* list) {
|
||||
return list->size();
|
||||
}
|
||||
|
||||
Nd4jLong* getShape(nd4j::ShapeList* list, Nd4jLong i) {
|
||||
return list->at(i);
|
||||
}
|
||||
|
||||
static FORCEINLINE Nd4jStatus realExec(nd4j::ops::DeclarableOp* op, Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputs, Nd4jPointer* outputBuffers, Nd4jPointer* outputShapes, int numOutputs, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool* bArgs, int numBArgs, bool isInplace) {
|
||||
if (op == nullptr)
|
||||
|
@ -2775,6 +2808,38 @@ VariablesSet* executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, N
|
|||
return executeStoredGraphT(extraPointers, graphId, inputBuffers, inputShapes, inputIndices, numInputs);
|
||||
}
|
||||
|
||||
Nd4jLong getVariableSetSize(nd4j::graph::VariablesSet* set) {
|
||||
return set->size();
|
||||
}
|
||||
|
||||
Nd4jStatus getVariableSetStatus(nd4j::graph::VariablesSet* set) {
|
||||
return set->status();
|
||||
}
|
||||
|
||||
nd4j::graph::Variable* getVariable(nd4j::graph::VariablesSet* set, Nd4jLong i) {
|
||||
return set->at(i);
|
||||
}
|
||||
|
||||
int getVariableId(nd4j::graph::Variable* variable) {
|
||||
return variable->id();
|
||||
}
|
||||
|
||||
int getVariableIndex(nd4j::graph::Variable* variable) {
|
||||
return variable->index();
|
||||
}
|
||||
|
||||
const char* getVariableName(nd4j::graph::Variable* variable) {
|
||||
return variable->getName()->c_str();
|
||||
}
|
||||
|
||||
Nd4jLong* getVariableShape(nd4j::graph::Variable* variable) {
|
||||
return variable->getNDArray()->shapeInfo();
|
||||
}
|
||||
|
||||
void* getVariableBuffer(nd4j::graph::Variable* variable) {
|
||||
return variable->getNDArray()->buffer();
|
||||
}
|
||||
|
||||
int unregisterGraph(Nd4jPointer *extraPointers, Nd4jLong graphId) {
|
||||
|
||||
nd4j::graph::GraphHolder::getInstance()->dropGraphAny(graphId);
|
||||
|
@ -3102,6 +3167,13 @@ Nd4jPointer createUtf8String(Nd4jPointer *extraPointers, const char *string, int
|
|||
return reinterpret_cast<Nd4jPointer>(u);
|
||||
}
|
||||
|
||||
Nd4jLong getUtf8StringLength(Nd4jPointer *extraPointers, Nd4jPointer ptr) {
|
||||
return reinterpret_cast<nd4j::utf8string*>(ptr)->_length;
|
||||
}
|
||||
char* getUtf8StringBuffer(Nd4jPointer *extraPointers, Nd4jPointer ptr) {
|
||||
return reinterpret_cast<nd4j::utf8string*>(ptr)->_buffer;
|
||||
}
|
||||
|
||||
void deleteUtf8String(Nd4jPointer *extraPointers, Nd4jPointer ptr) {
|
||||
delete(reinterpret_cast<nd4j::utf8string*>(ptr));
|
||||
}
|
||||
|
@ -3237,14 +3309,12 @@ nd4j::ConstantDataBuffer* shapeBuffer(int rank, Nd4jLong *shape, Nd4jLong *strid
|
|||
return buffer;
|
||||
}
|
||||
|
||||
void deleteShapeBuffer(Nd4jPointer ptr) {
|
||||
auto buffer = reinterpret_cast<nd4j::ConstantDataBuffer*>(ptr);
|
||||
delete buffer;
|
||||
void deleteShapeBuffer(nd4j::ConstantDataBuffer* ptr) {
|
||||
delete ptr;
|
||||
}
|
||||
|
||||
void deleteTadPack(Nd4jPointer ptr) {
|
||||
auto buffer = reinterpret_cast<nd4j::TadPack*>(ptr);
|
||||
delete buffer;
|
||||
void deleteTadPack(nd4j::TadPack* ptr) {
|
||||
delete ptr;
|
||||
}
|
||||
|
||||
nd4j::ConstantDataBuffer* constantBufferLong(nd4j::DataType dtype, Nd4jLong *data, int length) {
|
||||
|
@ -3259,6 +3329,82 @@ nd4j::ConstantDataBuffer* constantBuffer(nd4j::DataType dtype, nd4j::ConstantDes
|
|||
return nd4j::ConstantHelper::getInstance()->constantBuffer(*descriptor, dtype);
|
||||
}
|
||||
|
||||
|
||||
Nd4jPointer getConstantDataBufferPrimary(nd4j::ConstantDataBuffer* dbf) {
|
||||
return dbf->primary();
|
||||
}
|
||||
Nd4jPointer getConstantDataBufferSpecial(nd4j::ConstantDataBuffer* dbf) {
|
||||
return dbf->special();
|
||||
}
|
||||
Nd4jLong getConstantDataBufferLength(nd4j::ConstantDataBuffer* dbf) {
|
||||
return dbf->length();
|
||||
}
|
||||
Nd4jLong getConstantDataBufferSizeOf(nd4j::ConstantDataBuffer* dbf) {
|
||||
return dbf->sizeOf();
|
||||
}
|
||||
|
||||
|
||||
nd4j::graph::Context* createGraphContext(int nodeId) {
|
||||
return new nd4j::graph::Context(nodeId);
|
||||
}
|
||||
nd4j::graph::RandomGenerator* getGraphContextRandomGenerator(nd4j::graph::Context* ptr) {
|
||||
return &ptr->randomGenerator();
|
||||
}
|
||||
void markGraphContextInplace(nd4j::graph::Context* ptr, bool reallyInplace) {
|
||||
ptr->markInplace(reallyInplace);
|
||||
}
|
||||
void setGraphContextCudaContext(nd4j::graph::Context* ptr, void *stream, void *reductionPointer, void *allocationPointer) {
|
||||
ptr->setCudaContext(stream, reductionPointer, allocationPointer);
|
||||
}
|
||||
void setGraphContextInputArray(nd4j::graph::Context* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo) {
|
||||
ptr->setInputArray(index, buffer, shapeInfo, specialBuffer, specialShapeInfo);
|
||||
}
|
||||
void setGraphContextOutputArray(nd4j::graph::Context* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo) {
|
||||
ptr->setOutputArray(index, buffer, shapeInfo, specialBuffer, specialShapeInfo);
|
||||
}
|
||||
void setGraphContextTArguments(nd4j::graph::Context* ptr, double *arguments, int numberOfArguments) {
|
||||
ptr->setTArguments(arguments, numberOfArguments);
|
||||
}
|
||||
void setGraphContextIArguments(nd4j::graph::Context* ptr, Nd4jLong *arguments, int numberOfArguments) {
|
||||
ptr->setIArguments(arguments, numberOfArguments);
|
||||
}
|
||||
void setGraphContextBArguments(nd4j::graph::Context* ptr, bool *arguments, int numberOfArguments) {
|
||||
ptr->setBArguments(arguments, numberOfArguments);
|
||||
}
|
||||
void deleteGraphContext(nd4j::graph::Context* ptr) {
|
||||
delete ptr;
|
||||
}
|
||||
|
||||
|
||||
nd4j::graph::RandomGenerator* createRandomGenerator(Nd4jLong rootSeed, Nd4jLong nodeSeed) {
|
||||
return new nd4j::graph::RandomGenerator(rootSeed, nodeSeed);
|
||||
}
|
||||
|
||||
Nd4jLong getRandomGeneratorRootState(nd4j::graph::RandomGenerator* ptr) {
|
||||
return ptr->rootState();
|
||||
}
|
||||
|
||||
Nd4jLong getRandomGeneratorNodeState(nd4j::graph::RandomGenerator* ptr) {
|
||||
return ptr->nodeState();
|
||||
}
|
||||
|
||||
void setRandomGeneratorStates(nd4j::graph::RandomGenerator* ptr, Nd4jLong rootSeed, Nd4jLong nodeSeed) {
|
||||
ptr->setStates(rootSeed, nodeSeed);
|
||||
}
|
||||
|
||||
int getRandomGeneratorRelativeInt(nd4j::graph::RandomGenerator* ptr, Nd4jLong index) {
|
||||
return ptr->relativeInt(index);
|
||||
}
|
||||
|
||||
Nd4jLong getRandomGeneratorRelativeLong(nd4j::graph::RandomGenerator* ptr, Nd4jLong index) {
|
||||
return ptr->relativeLong(index);
|
||||
}
|
||||
|
||||
void deleteRandomGenerator(nd4j::graph::RandomGenerator* ptr) {
|
||||
delete ptr;
|
||||
}
|
||||
|
||||
|
||||
Nd4jPointer shapeBufferForNumpy(Nd4jPointer npyArray) {
|
||||
cnpy::NpyArray arr = cnpy::loadNpyFromPointer(reinterpret_cast<char *>(npyArray));
|
||||
unsigned int shapeSize = arr.shape.size();
|
||||
|
|
|
@ -27,6 +27,9 @@
|
|||
#include <ops/declarable/CustomOperations.h>
|
||||
#include <graph/GraphUtils.h>
|
||||
|
||||
using namespace nd4j::ops;
|
||||
using namespace nd4j::graph;
|
||||
|
||||
int
|
||||
main(int argc, char *argv[]) {
|
||||
// this string will contain list of operations
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include <ops/declarable/CustomOperations.h>
|
||||
|
||||
using namespace nd4j;
|
||||
using namespace nd4j::ops;
|
||||
using namespace nd4j::graph;
|
||||
|
||||
class OpTrackerTests : public testing::Test {
|
||||
|
|
|
@ -28,7 +28,7 @@ import java.util.List;
|
|||
*
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
public interface OpContext {
|
||||
public interface OpContext extends AutoCloseable {
|
||||
|
||||
/**
|
||||
* This method sets integer arguments required for operation
|
||||
|
|
|
@ -38,8 +38,9 @@ import org.nd4j.linalg.api.ops.Op;
|
|||
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.nativeblas.NativeOps;
|
||||
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||
import org.nd4j.nativeblas.ResultWrapperAbstraction;
|
||||
import org.nd4j.nativeblas.OpaqueResultWrapper;
|
||||
|
||||
import java.io.File;
|
||||
import java.nio.ByteBuffer;
|
||||
|
@ -100,11 +101,12 @@ public class NativeGraphExecutioner implements GraphExecutioner {
|
|||
|
||||
log.info("Buffer length: {}", buffer.limit());
|
||||
|
||||
val res = NativeOpsHolder.getInstance().getDeviceNativeOps().executeFlatGraph(null, bPtr);
|
||||
NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
|
||||
OpaqueResultWrapper res = nativeOps.executeFlatGraph(null, bPtr);
|
||||
if (res == null)
|
||||
throw new ND4JIllegalStateException("Graph execution failed");
|
||||
|
||||
PagedPointer pagedPointer = new PagedPointer(res.pointer(),res.size());
|
||||
PagedPointer pagedPointer = new PagedPointer(nativeOps.getResultWrapperPointer(res), nativeOps.getResultWrapperSize(res));
|
||||
FlatResult fr = FlatResult.getRootAsFlatResult(pagedPointer.asBytePointer().asByteBuffer());
|
||||
|
||||
log.info("VarMap: {}", sd.variableMap());
|
||||
|
@ -132,7 +134,7 @@ public class NativeGraphExecutioner implements GraphExecutioner {
|
|||
}
|
||||
|
||||
// now we need to release native memory
|
||||
NativeOpsHolder.getInstance().getDeviceNativeOps().deleteResultWrapper(res);
|
||||
nativeOps.deleteResultWrapper(res);
|
||||
|
||||
return results;
|
||||
}
|
||||
|
|
|
@ -697,7 +697,16 @@ public interface NativeOps {
|
|||
|
||||
void setGridLimit(int gridSize);
|
||||
|
||||
Pointer tadOnlyShapeInfo(@Cast("Nd4jLong *") LongPointer shapeInfo, IntPointer dimension, int dimensionLength);
|
||||
OpaqueTadPack tadOnlyShapeInfo(LongPointer shapeInfo, IntPointer dimension, int dimensionLength);
|
||||
|
||||
LongPointer getPrimaryShapeInfo(OpaqueTadPack pack);
|
||||
LongPointer getPrimaryOffsets(OpaqueTadPack pack);
|
||||
LongPointer getSpecialShapeInfo(OpaqueTadPack pack);
|
||||
LongPointer getSpecialOffsets(OpaqueTadPack pack);
|
||||
long getNumberOfTads(OpaqueTadPack pack);
|
||||
int getShapeInfoLength(OpaqueTadPack pack);
|
||||
|
||||
void deleteTadPack(OpaqueTadPack pointer);
|
||||
|
||||
///////////////
|
||||
|
||||
|
@ -1037,7 +1046,10 @@ public interface NativeOps {
|
|||
|
||||
void munmapFile(PointerPointer extraPointers, LongPointer ptrMap, long length);
|
||||
|
||||
ResultWrapperAbstraction executeFlatGraph(PointerPointer extraPointers, Pointer flatBufferPointer);
|
||||
OpaqueResultWrapper executeFlatGraph(PointerPointer extraPointers, Pointer flatBufferPointer);
|
||||
|
||||
long getResultWrapperSize(OpaqueResultWrapper ptr);
|
||||
Pointer getResultWrapperPointer(OpaqueResultWrapper ptr);
|
||||
|
||||
String getAllCustomOps();
|
||||
|
||||
|
@ -1047,13 +1059,25 @@ public interface NativeOps {
|
|||
|
||||
int execCustomOp(PointerPointer extraPointers, long opHashCode, PointerPointer inputBuffers, PointerPointer inputShapes, int numInput, PointerPointer outputBuffers, PointerPointer outputShapes, int numOutputs, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong *") LongPointer iArgs, int numIArgs, @Cast("bool *") BooleanPointer bArgs, int numBArgs, boolean isInplace);
|
||||
|
||||
Pointer calculateOutputShapes(PointerPointer extraPointers, long hash, PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong *") LongPointer iArgs, int numIArgs);
|
||||
OpaqueShapeList calculateOutputShapes(PointerPointer extraPointers, long hash, PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong *") LongPointer iArgs, int numIArgs);
|
||||
|
||||
Pointer calculateOutputShapes2(PointerPointer extraPointers, long hash, PointerPointer inputBunffers, PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong *") LongPointer iArgs, int numIArgs, @Cast("bool *") BooleanPointer bArgs, int numBArgs);
|
||||
OpaqueShapeList calculateOutputShapes2(PointerPointer extraPointers, long hash, PointerPointer inputBunffers, PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong *") LongPointer iArgs, int numIArgs, @Cast("bool *") BooleanPointer bArgs, int numBArgs);
|
||||
|
||||
long getShapeListSize(OpaqueShapeList list);
|
||||
LongPointer getShape(OpaqueShapeList list, long i);
|
||||
|
||||
int registerGraph(PointerPointer extraPointers, long graphId, Pointer flatBufferPointer);
|
||||
|
||||
Pointer executeStoredGraph(PointerPointer extraPointers, long graphId, PointerPointer inputBuffers, PointerPointer inputShapes, IntPointer inputIndices, int numInputs);
|
||||
OpaqueVariableSet executeStoredGraph(PointerPointer extraPointers, long graphId, PointerPointer inputBuffers, PointerPointer inputShapes, IntPointer inputIndices, int numInputs);
|
||||
|
||||
long getVariableSetSize(OpaqueVariableSet set);
|
||||
int getVariableSetStatus(OpaqueVariableSet set);
|
||||
OpaqueVariable getVariable(OpaqueVariableSet set, long i);
|
||||
int getVariableId(OpaqueVariable variable);
|
||||
int getVariableIndex(OpaqueVariable variable);
|
||||
String getVariableName(OpaqueVariable variable);
|
||||
LongPointer getVariableShape(OpaqueVariable variable);
|
||||
Pointer getVariableBuffer(OpaqueVariable variable);
|
||||
|
||||
void deleteResultWrapper(Pointer ptr);
|
||||
|
||||
|
@ -1071,15 +1095,11 @@ public interface NativeOps {
|
|||
|
||||
void deleteNPArrayMap(Pointer pointer);
|
||||
|
||||
void deleteVariablesSet(Pointer pointer);
|
||||
void deleteVariablesSet(OpaqueVariableSet pointer);
|
||||
|
||||
// GraphState creation
|
||||
Pointer getGraphState(long id);
|
||||
|
||||
void deleteShapeBuffer(Pointer state);
|
||||
|
||||
void deleteTadPack(Pointer pointer);
|
||||
|
||||
void deleteGraphState(Pointer state);
|
||||
|
||||
int estimateThreshold(PointerPointer extraPointers, Pointer x, LongPointer xShapeInfo, int N, float threshold);
|
||||
|
@ -1096,6 +1116,8 @@ public interface NativeOps {
|
|||
|
||||
//void fillUtf8String(PointerPointer extraPointers, String[] string, int numStrings, Pointer buffer);
|
||||
Pointer createUtf8String(PointerPointer extraPointers, String string, int length);
|
||||
long getUtf8StringLength(PointerPointer extraPointers, Pointer ptr);
|
||||
BytePointer getUtf8StringBuffer(PointerPointer extraPointers, Pointer ptr);
|
||||
void deleteUtf8String(PointerPointer extraPointers, Pointer ptr);
|
||||
|
||||
|
||||
|
@ -1116,11 +1138,37 @@ public interface NativeOps {
|
|||
*/
|
||||
int dataTypeFromNpyHeader(Pointer numpyHeader);
|
||||
|
||||
Pointer shapeBuffer(int rank, @Cast("Nd4jLong *") LongPointer shape, @Cast("Nd4jLong *") LongPointer strides, int dtype, char order, long ews, boolean empty);
|
||||
OpaqueConstantDataBuffer shapeBuffer(int rank, LongPointer shape, LongPointer strides, int dtype, char order, long ews, boolean empty);
|
||||
|
||||
Pointer constantBufferDouble(int dtype, DoublePointer data, int length);
|
||||
OpaqueConstantDataBuffer constantBufferDouble(int dtype, DoublePointer data, int length);
|
||||
|
||||
Pointer constantBufferLong(int dtype, @Cast("Nd4jLong *") LongPointer data, int length);
|
||||
OpaqueConstantDataBuffer constantBufferLong(int dtype, LongPointer data, int length);
|
||||
|
||||
Pointer getConstantDataBufferPrimary(OpaqueConstantDataBuffer dbf);
|
||||
Pointer getConstantDataBufferSpecial(OpaqueConstantDataBuffer dbf);
|
||||
long getConstantDataBufferLength(OpaqueConstantDataBuffer dbf);
|
||||
long getConstantDataBufferSizeOf(OpaqueConstantDataBuffer dbf);
|
||||
|
||||
void deleteShapeBuffer(OpaqueConstantDataBuffer state);
|
||||
|
||||
OpaqueContext createGraphContext(int nodeId);
|
||||
OpaqueRandomGenerator getGraphContextRandomGenerator(OpaqueContext ptr);
|
||||
void markGraphContextInplace(OpaqueContext ptr, boolean reallyInplace);
|
||||
void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer);
|
||||
void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);
|
||||
void setGraphContextOutputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);
|
||||
void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, int numberOfArguments);
|
||||
void setGraphContextIArguments(OpaqueContext ptr, LongPointer arguments, int numberOfArguments);
|
||||
void setGraphContextBArguments(OpaqueContext ptr, BooleanPointer arguments, int numberOfArguments);
|
||||
void deleteGraphContext(OpaqueContext ptr);
|
||||
|
||||
OpaqueRandomGenerator createRandomGenerator(long rootSeed, long nodeSeed);
|
||||
long getRandomGeneratorRootState(OpaqueRandomGenerator ptr);
|
||||
long getRandomGeneratorNodeState(OpaqueRandomGenerator ptr);
|
||||
void setRandomGeneratorStates(OpaqueRandomGenerator ptr, @Cast("Nd4jLong") long rootSeed/*=0*/, @Cast("Nd4jLong") long nodeSeed/*=0*/);
|
||||
int getRandomGeneratorRelativeInt(OpaqueRandomGenerator ptr, @Cast("Nd4jLong") long index);
|
||||
long getRandomGeneratorRelativeLong(OpaqueRandomGenerator ptr, @Cast("Nd4jLong") long index);
|
||||
void deleteRandomGenerator(OpaqueRandomGenerator ptr);
|
||||
|
||||
String runLightBenchmarkSuit(boolean printOut);
|
||||
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.nativeblas;
|
||||
|
||||
import org.bytedeco.javacpp.Pointer;
|
||||
|
||||
/**
|
||||
*
|
||||
* @author saudet
|
||||
*/
|
||||
public class OpaqueConstantDataBuffer extends Pointer {
|
||||
public OpaqueConstantDataBuffer(Pointer p) { super(p); }
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.nativeblas;
|
||||
|
||||
import org.bytedeco.javacpp.Pointer;
|
||||
|
||||
/**
|
||||
*
|
||||
* @author saudet
|
||||
*/
|
||||
public class OpaqueContext extends Pointer {
|
||||
public OpaqueContext(Pointer p) { super(p); }
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.nativeblas;
|
||||
|
||||
import org.bytedeco.javacpp.Pointer;
|
||||
|
||||
/**
|
||||
*
|
||||
* @author saudet
|
||||
*/
|
||||
public class OpaqueRandomGenerator extends Pointer {
|
||||
public OpaqueRandomGenerator(Pointer p) { super(p); }
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.nativeblas;
|
||||
|
||||
import org.bytedeco.javacpp.Pointer;
|
||||
|
||||
/**
|
||||
*
|
||||
* @author saudet
|
||||
*/
|
||||
public class OpaqueResultWrapper extends Pointer {
|
||||
public OpaqueResultWrapper(Pointer p) { super(p); }
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.nativeblas;
|
||||
|
||||
import org.bytedeco.javacpp.Pointer;
|
||||
|
||||
/**
|
||||
*
|
||||
* @author saudet
|
||||
*/
|
||||
public class OpaqueShapeList extends Pointer {
|
||||
public OpaqueShapeList(Pointer p) { super(p); }
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.nativeblas;
|
||||
|
||||
import org.bytedeco.javacpp.Pointer;
|
||||
|
||||
/**
|
||||
*
|
||||
* @author saudet
|
||||
*/
|
||||
public class OpaqueTadPack extends Pointer {
|
||||
public OpaqueTadPack(Pointer p) { super(p); }
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.nativeblas;
|
||||
|
||||
import org.bytedeco.javacpp.Pointer;
|
||||
|
||||
/**
|
||||
*
|
||||
* @author saudet
|
||||
*/
|
||||
public class OpaqueVariable extends Pointer {
|
||||
public OpaqueVariable(Pointer p) { super(p); }
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.nativeblas;
|
||||
|
||||
import org.bytedeco.javacpp.Pointer;
|
||||
|
||||
/**
|
||||
*
|
||||
* @author saudet
|
||||
*/
|
||||
public class OpaqueVariableSet extends Pointer {
|
||||
public OpaqueVariableSet(Pointer p) { super(p); }
|
||||
}
|
|
@ -72,6 +72,11 @@ import org.nd4j.nativeblas.LongPointerWrapper;
|
|||
import org.nd4j.nativeblas.NativeOps;
|
||||
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||
import org.nd4j.nativeblas.Nd4jCuda;
|
||||
import org.nd4j.nativeblas.OpaqueConstantDataBuffer;
|
||||
import org.nd4j.nativeblas.OpaqueShapeList;
|
||||
import org.nd4j.nativeblas.OpaqueTadPack;
|
||||
import org.nd4j.nativeblas.OpaqueVariable;
|
||||
import org.nd4j.nativeblas.OpaqueVariableSet;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
|
@ -2208,13 +2213,13 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
|||
for (val t: op.tArgs())
|
||||
tArgs.put(cnt++, (float) t);
|
||||
|
||||
val ptrptr = (Nd4jCuda.ShapeList) nativeOps.calculateOutputShapes2(null, hash, inputBuffers, inputShapes, op.inputArguments().length, tArgs, op.tArgs().length, iArgs, op.iArgs().length, bArgs, op.numBArguments());
|
||||
OpaqueShapeList ptrptr = nativeOps.calculateOutputShapes2(null, hash, inputBuffers, inputShapes, op.inputArguments().length, tArgs, op.tArgs().length, iArgs, op.iArgs().length, bArgs, op.numBArguments());
|
||||
|
||||
if (ptrptr == null)
|
||||
throw new RuntimeException();
|
||||
|
||||
for (int e = 0; e < ptrptr.size(); e++ )
|
||||
result.add(getShapeFromPointer(new PagedPointer(ptrptr.at(e)).asLongPointer()));
|
||||
for (int e = 0; e < nativeOps.getShapeListSize(ptrptr); e++ )
|
||||
result.add(getShapeFromPointer(new PagedPointer(nativeOps.getShape(ptrptr, e)).asLongPointer()));
|
||||
|
||||
nativeOps.deleteShapeList(ptrptr);
|
||||
|
||||
|
@ -2251,7 +2256,8 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
|||
|
||||
val ctx = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();
|
||||
|
||||
val context = (CudaOpContext) buildContext();
|
||||
val name = op.opName();
|
||||
try (val context = (CudaOpContext) buildContext()) {
|
||||
context.markInplace(op.isInplaceCall());
|
||||
|
||||
// transferring rng state
|
||||
|
@ -2273,6 +2279,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
|||
Nd4j.getRandom().setStates(states.getFirst(), states.getSecond());
|
||||
|
||||
return result;
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Op [" + name + "] execution failed", e);
|
||||
}
|
||||
|
||||
/*
|
||||
long st = profilingConfigurableHookIn(op);
|
||||
|
@ -2418,19 +2427,19 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
|||
|
||||
val newMap = new LinkedHashMap<String, INDArray>();
|
||||
|
||||
val result = (Nd4jCuda.VariablesSet) nativeOps.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size());
|
||||
OpaqueVariableSet result = nativeOps.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size());
|
||||
|
||||
val status = OpStatus.byNumber(result.status());
|
||||
OpStatus status = OpStatus.byNumber(nativeOps.getVariableSetStatus(result));
|
||||
|
||||
if (status != OpStatus.ND4J_STATUS_OK)
|
||||
throw new ND4JIllegalStateException("Op execution failed: " + status);
|
||||
|
||||
for (int e = 0; e < result.size(); e++) {
|
||||
val var = result.at(e);
|
||||
val nodeId = var.id();
|
||||
val index = var.index();
|
||||
val shapeInfo = var.getNDArray().shapeInfo();
|
||||
val buffer = var.getNDArray().buffer();
|
||||
for (int e = 0; e < nativeOps.getVariableSetSize(result); e++) {
|
||||
OpaqueVariable var = nativeOps.getVariable(result, e);
|
||||
int nodeId = nativeOps.getVariableId(var);
|
||||
int index = nativeOps.getVariableIndex(var);
|
||||
LongPointer shapeInfo = nativeOps.getVariableShape(var);
|
||||
Pointer buffer = nativeOps.getVariableBuffer(var);
|
||||
|
||||
val rank = (int) shapeInfo.get(0);
|
||||
val jshape = new long[rank * 2 + 4];
|
||||
|
@ -2446,7 +2455,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
|||
Pointer.memcpy(AtomicAllocator.getInstance().getHostPointer(array), buffer, ArrayUtil.prod(shapeOf) * Nd4j.sizeOfDataType());
|
||||
AtomicAllocator.getInstance().getAllocationPoint(array).tickHostWrite();
|
||||
|
||||
val nodeName = var.getName().getString();
|
||||
String nodeName = nativeOps.getVariableName(var);
|
||||
newMap.put(nodeName, array);
|
||||
}
|
||||
|
||||
|
@ -2584,9 +2593,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
|||
|
||||
@Override
|
||||
public DataBuffer createShapeInfo(long[] shape, long[] stride, long elementWiseStride, char order, DataType dtype, boolean empty) {
|
||||
val dbf = (Nd4jCuda.ConstantDataBuffer) nativeOps.shapeBuffer(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, empty);
|
||||
OpaqueConstantDataBuffer dbf = nativeOps.shapeBuffer(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, empty);
|
||||
|
||||
val result = new CudaLongDataBuffer(dbf.primary(), dbf.special(), Shape.shapeInfoLength(shape.length));
|
||||
val result = new CudaLongDataBuffer(nativeOps.getConstantDataBufferPrimary(dbf), nativeOps.getConstantDataBufferSpecial(dbf), Shape.shapeInfoLength(shape.length));
|
||||
|
||||
nativeOps.deleteShapeBuffer(dbf);
|
||||
|
||||
|
@ -2595,10 +2604,10 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
|||
|
||||
@Override
|
||||
public TadPack tadShapeInfoAndOffsets(INDArray array, int[] dimension) {
|
||||
val pack = (Nd4jCuda.TadPack) nativeOps.tadOnlyShapeInfo((LongPointer) array.shapeInfoDataBuffer().addressPointer(), new IntPointer(dimension), dimension.length);
|
||||
OpaqueTadPack pack = nativeOps.tadOnlyShapeInfo((LongPointer) array.shapeInfoDataBuffer().addressPointer(), new IntPointer(dimension), dimension.length);
|
||||
|
||||
val tadShape = new CudaLongDataBuffer(pack.primaryShapeInfo(), pack.specialShapeInfo(), pack.shapeInfoLength());
|
||||
val tadOffsets = new CudaLongDataBuffer(pack.primaryOffsets(), pack.specialOffsets(), pack.numberOfTads());
|
||||
val tadShape = new CudaLongDataBuffer(nativeOps.getPrimaryShapeInfo(pack), nativeOps.getSpecialShapeInfo(pack), nativeOps.getShapeInfoLength(pack));
|
||||
val tadOffsets = new CudaLongDataBuffer(nativeOps.getPrimaryOffsets(pack), nativeOps.getSpecialOffsets(pack), nativeOps.getNumberOfTads(pack));
|
||||
|
||||
nativeOps.deleteTadPack(pack);
|
||||
|
||||
|
@ -2607,9 +2616,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
|||
|
||||
@Override
|
||||
public DataBuffer createConstantBuffer(long[] values, DataType desiredType) {
|
||||
val dbf = (Nd4jCuda.ConstantDataBuffer) nativeOps.constantBufferLong(desiredType.toInt(), new LongPointer(values), values.length);
|
||||
OpaqueConstantDataBuffer dbf = nativeOps.constantBufferLong(desiredType.toInt(), new LongPointer(values), values.length);
|
||||
|
||||
val buffer = Nd4j.createBuffer(dbf.primary(), dbf.special(), values.length, desiredType);
|
||||
val buffer = Nd4j.createBuffer(nativeOps.getConstantDataBufferPrimary(dbf), nativeOps.getConstantDataBufferSpecial(dbf), values.length, desiredType);
|
||||
buffer.setConstant(true);
|
||||
|
||||
return buffer;
|
||||
|
@ -2617,9 +2626,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
|||
|
||||
@Override
|
||||
public DataBuffer createConstantBuffer(double[] values, DataType desiredType) {
|
||||
val dbf = (Nd4jCuda.ConstantDataBuffer) nativeOps.constantBufferDouble(desiredType.toInt(), new DoublePointer(values), values.length);
|
||||
OpaqueConstantDataBuffer dbf = nativeOps.constantBufferDouble(desiredType.toInt(), new DoublePointer(values), values.length);
|
||||
|
||||
val buffer = Nd4j.createBuffer(dbf.primary(), dbf.special(), values.length, desiredType);
|
||||
val buffer = Nd4j.createBuffer(nativeOps.getConstantDataBufferPrimary(dbf), nativeOps.getConstantDataBufferSpecial(dbf), values.length, desiredType);
|
||||
buffer.setConstant(true);
|
||||
|
||||
return buffer;
|
||||
|
|
|
@ -18,6 +18,9 @@ package org.nd4j.linalg.jcublas.ops.executioner;
|
|||
|
||||
import lombok.NonNull;
|
||||
import lombok.val;
|
||||
import org.bytedeco.javacpp.BooleanPointer;
|
||||
import org.bytedeco.javacpp.DoublePointer;
|
||||
import org.bytedeco.javacpp.LongPointer;
|
||||
import org.bytedeco.javacpp.Pointer;
|
||||
import org.nd4j.jita.allocator.impl.AtomicAllocator;
|
||||
import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t;
|
||||
|
@ -28,7 +31,10 @@ import org.nd4j.linalg.api.ops.OpContext;
|
|||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.jcublas.context.CudaContext;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
import org.nd4j.nativeblas.Nd4jCuda;
|
||||
import org.nd4j.nativeblas.NativeOps;
|
||||
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||
import org.nd4j.nativeblas.OpaqueContext;
|
||||
import org.nd4j.nativeblas.OpaqueRandomGenerator;
|
||||
|
||||
/**
|
||||
* CUDA wrapper for op Context
|
||||
|
@ -36,34 +42,41 @@ import org.nd4j.nativeblas.Nd4jCuda;
|
|||
*/
|
||||
public class CudaOpContext extends BaseOpContext implements OpContext {
|
||||
// we might want to have configurable
|
||||
private Nd4jCuda.Context context = new Nd4jCuda.Context(1);
|
||||
private NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
|
||||
private OpaqueContext context = nativeOps.createGraphContext(1);
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
nativeOps.deleteGraphContext(context);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setIArguments(long... arguments) {
|
||||
super.setIArguments(arguments);
|
||||
context.setIArguments(arguments, arguments.length);
|
||||
nativeOps.setGraphContextIArguments(context, new LongPointer(arguments), arguments.length);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setBArguments(boolean... arguments) {
|
||||
super.setBArguments(arguments);
|
||||
context.setBArguments(arguments, arguments.length);
|
||||
nativeOps.setGraphContextBArguments(context, new BooleanPointer(arguments), arguments.length);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setTArguments(double... arguments) {
|
||||
super.setTArguments(arguments);
|
||||
context.setTArguments(arguments, arguments.length);
|
||||
nativeOps.setGraphContextTArguments(context, new DoublePointer(arguments), arguments.length);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setRngStates(long rootState, long nodeState) {
|
||||
context.randomGenerator().setStates(rootState, nodeState);
|
||||
nativeOps.setRandomGeneratorStates(nativeOps.getGraphContextRandomGenerator(context), rootState, nodeState);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Pair<Long, Long> getRngStates() {
|
||||
return Pair.makePair(context.randomGenerator().rootState(), context.randomGenerator().nodeState());
|
||||
OpaqueRandomGenerator g = nativeOps.getGraphContextRandomGenerator(context);
|
||||
return Pair.makePair(nativeOps.getRandomGeneratorRootState(g), nativeOps.getRandomGeneratorNodeState(g));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -72,7 +85,7 @@ public class CudaOpContext extends BaseOpContext implements OpContext {
|
|||
Nd4j.getAffinityManager().ensureLocation(array, AffinityManager.Location.EVERYWHERE);
|
||||
|
||||
val ctx = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();
|
||||
context.setInputArray(index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), array.isEmpty() ? null : AtomicAllocator.getInstance().getPointer(array, ctx), AtomicAllocator.getInstance().getPointer(array.shapeInfoDataBuffer()));
|
||||
nativeOps.setGraphContextInputArray(context, index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), array.isEmpty() ? null : AtomicAllocator.getInstance().getPointer(array, ctx), AtomicAllocator.getInstance().getPointer(array.shapeInfoDataBuffer()));
|
||||
|
||||
super.setInputArray(index, array);
|
||||
}
|
||||
|
@ -82,7 +95,7 @@ public class CudaOpContext extends BaseOpContext implements OpContext {
|
|||
Nd4j.getAffinityManager().ensureLocation(array, AffinityManager.Location.EVERYWHERE);
|
||||
|
||||
val ctx = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();
|
||||
context.setOutputArray(index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), array.isEmpty() ? null : AtomicAllocator.getInstance().getPointer(array, ctx), AtomicAllocator.getInstance().getPointer(array.shapeInfoDataBuffer()));
|
||||
nativeOps.setGraphContextOutputArray(context, index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), array.isEmpty() ? null : AtomicAllocator.getInstance().getPointer(array, ctx), AtomicAllocator.getInstance().getPointer(array.shapeInfoDataBuffer()));
|
||||
|
||||
super.setOutputArray(index, array);
|
||||
}
|
||||
|
@ -113,11 +126,11 @@ public class CudaOpContext extends BaseOpContext implements OpContext {
|
|||
|
||||
|
||||
public void setCudaStream(cudaStream_t stream, Pointer reductionPointer, Pointer allocationPointer) {
|
||||
context.setCudaContext(stream, reductionPointer, allocationPointer);
|
||||
nativeOps.setGraphContextCudaContext(context, stream, reductionPointer, allocationPointer);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void markInplace(boolean reallyInplace) {
|
||||
context.markInplace(reallyInplace);
|
||||
nativeOps.markGraphContextInplace(context, reallyInplace);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,7 +21,9 @@ import org.bytedeco.javacpp.PointerPointer;
|
|||
import org.nd4j.jita.allocator.impl.AtomicAllocator;
|
||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||
import org.nd4j.linalg.jcublas.context.CudaContext;
|
||||
import org.nd4j.nativeblas.Nd4jCuda;
|
||||
import org.nd4j.nativeblas.NativeOps;
|
||||
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||
import org.nd4j.nativeblas.OpaqueRandomGenerator;
|
||||
import org.nd4j.rng.NativeRandom;
|
||||
|
||||
import java.util.List;
|
||||
|
@ -33,7 +35,7 @@ import java.util.List;
|
|||
*/
|
||||
@Slf4j
|
||||
public class CudaNativeRandom extends NativeRandom {
|
||||
|
||||
private NativeOps nativeOps;
|
||||
protected List<DataBuffer> stateBuffers;
|
||||
|
||||
public CudaNativeRandom() {
|
||||
|
@ -50,10 +52,16 @@ public class CudaNativeRandom extends NativeRandom {
|
|||
|
||||
@Override
|
||||
public void init() {
|
||||
statePointer = new Nd4jCuda.RandomGenerator(seed, seed ^ 0xdeadbeef);
|
||||
nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
|
||||
statePointer = nativeOps.createRandomGenerator(this.seed, this.seed ^ 0xdeadbeef);
|
||||
setSeed(seed);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
nativeOps.deleteRandomGenerator((OpaqueRandomGenerator)statePointer);
|
||||
}
|
||||
|
||||
@Override
|
||||
public PointerPointer getExtraPointers() {
|
||||
return null;
|
||||
|
@ -63,7 +71,7 @@ public class CudaNativeRandom extends NativeRandom {
|
|||
public void setSeed(long seed) {
|
||||
this.seed = seed;
|
||||
this.currentPosition.set(0);
|
||||
((Nd4jCuda.RandomGenerator) statePointer).setStates(seed, seed ^ 0xdeadbeef);
|
||||
nativeOps.setRandomGeneratorStates((OpaqueRandomGenerator)statePointer, seed, seed ^ 0xdeadbeef);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -73,24 +81,24 @@ public class CudaNativeRandom extends NativeRandom {
|
|||
|
||||
@Override
|
||||
public int nextInt() {
|
||||
return ((Nd4jCuda.RandomGenerator) statePointer).relativeInt(currentPosition.getAndIncrement());
|
||||
return nativeOps.getRandomGeneratorRelativeInt((OpaqueRandomGenerator)statePointer, currentPosition.getAndIncrement());
|
||||
}
|
||||
|
||||
@Override
|
||||
public long nextLong() {
|
||||
return ((Nd4jCuda.RandomGenerator) statePointer).relativeLong(currentPosition.getAndIncrement());
|
||||
return nativeOps.getRandomGeneratorRelativeLong((OpaqueRandomGenerator)statePointer, currentPosition.getAndIncrement());
|
||||
}
|
||||
|
||||
public long rootState() {
|
||||
return ((Nd4jCuda.RandomGenerator) statePointer).rootState();
|
||||
return nativeOps.getRandomGeneratorRootState((OpaqueRandomGenerator)statePointer);
|
||||
}
|
||||
|
||||
public long nodeState() {
|
||||
return ((Nd4jCuda.RandomGenerator) statePointer).nodeState();
|
||||
return nativeOps.getRandomGeneratorNodeState((OpaqueRandomGenerator)statePointer);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setStates(long rootState, long nodeState) {
|
||||
((Nd4jCuda.RandomGenerator) statePointer).setStates(rootState, nodeState);
|
||||
nativeOps.setRandomGeneratorStates((OpaqueRandomGenerator)statePointer, rootState, nodeState);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2097,16 +2097,25 @@ public native void setGridLimit(int gridSize);
|
|||
* @param targetBuffer
|
||||
* @param offsetsBuffer
|
||||
*/
|
||||
public native TadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") LongPointer xShapeInfo,
|
||||
public native OpaqueTadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") LongPointer xShapeInfo,
|
||||
IntPointer dimension,
|
||||
int dimensionLength);
|
||||
public native TadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") LongBuffer xShapeInfo,
|
||||
public native OpaqueTadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") LongBuffer xShapeInfo,
|
||||
IntBuffer dimension,
|
||||
int dimensionLength);
|
||||
public native TadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") long[] xShapeInfo,
|
||||
public native OpaqueTadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") long[] xShapeInfo,
|
||||
int[] dimension,
|
||||
int dimensionLength);
|
||||
|
||||
public native @Cast("Nd4jLong*") LongPointer getPrimaryShapeInfo(OpaqueTadPack pack);
|
||||
public native @Cast("Nd4jLong*") LongPointer getPrimaryOffsets(OpaqueTadPack pack);
|
||||
public native @Cast("Nd4jLong*") LongPointer getSpecialShapeInfo(OpaqueTadPack pack);
|
||||
public native @Cast("Nd4jLong*") LongPointer getSpecialOffsets(OpaqueTadPack pack);
|
||||
public native @Cast("Nd4jLong") long getNumberOfTads(OpaqueTadPack pack);
|
||||
public native int getShapeInfoLength(OpaqueTadPack pack);
|
||||
|
||||
public native void deleteTadPack(OpaqueTadPack ptr);
|
||||
|
||||
/*
|
||||
* PullRow special op
|
||||
*/
|
||||
|
@ -2943,10 +2952,11 @@ public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers
|
|||
public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongBuffer ptrMap, @Cast("Nd4jLong") long length);
|
||||
public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") long[] ptrMap, @Cast("Nd4jLong") long length);
|
||||
|
||||
|
||||
// flatbuffers execution
|
||||
public native ResultWrapper executeFlatGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer flatBufferPointer);
|
||||
public native OpaqueResultWrapper executeFlatGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer flatBufferPointer);
|
||||
|
||||
public native @Cast("Nd4jLong") long getResultWrapperSize(OpaqueResultWrapper ptr);
|
||||
public native @Cast("Nd4jPointer") Pointer getResultWrapperPointer(OpaqueResultWrapper ptr);
|
||||
|
||||
public native @Cast("char*") String getAllCustomOps();
|
||||
|
||||
|
@ -2961,23 +2971,35 @@ public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointer
|
|||
public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, @Cast("bool") boolean isInplace);
|
||||
public native int execCustomOp2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer") Pointer opContext);
|
||||
|
||||
public native ShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs);
|
||||
public native ShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs);
|
||||
public native ShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs);
|
||||
public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
|
||||
public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
|
||||
public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
|
||||
public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
|
||||
public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
|
||||
public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
|
||||
|
||||
public native @Cast("Nd4jLong") long getShapeListSize(OpaqueShapeList list);
|
||||
public native @Cast("Nd4jLong*") LongPointer getShape(OpaqueShapeList list, @Cast("Nd4jLong") long i);
|
||||
|
||||
public native void deleteShapeList(@Cast("Nd4jPointer") Pointer shapeList);
|
||||
|
||||
public native int registerGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer") Pointer flatBufferPointer);
|
||||
|
||||
public native VariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntPointer inputIndices, int numInputs);
|
||||
public native VariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntBuffer inputIndices, int numInputs);
|
||||
public native VariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int[] inputIndices, int numInputs);
|
||||
public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntPointer inputIndices, int numInputs);
|
||||
public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntBuffer inputIndices, int numInputs);
|
||||
public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int[] inputIndices, int numInputs);
|
||||
|
||||
public native @Cast("Nd4jLong") long getVariableSetSize(OpaqueVariableSet set);
|
||||
public native @Cast("Nd4jStatus") int getVariableSetStatus(OpaqueVariableSet set);
|
||||
public native OpaqueVariable getVariable(OpaqueVariableSet set, @Cast("Nd4jLong") long i);
|
||||
public native int getVariableId(OpaqueVariable variable);
|
||||
public native int getVariableIndex(OpaqueVariable variable);
|
||||
public native @Cast("char*") String getVariableName(OpaqueVariable variable);
|
||||
public native @Cast("Nd4jLong*") LongPointer getVariableShape(OpaqueVariable variable);
|
||||
public native Pointer getVariableBuffer(OpaqueVariable variable);
|
||||
|
||||
public native int unregisterGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId);
|
||||
|
||||
|
@ -2986,7 +3008,7 @@ public native void deleteIntArray(@Cast("Nd4jPointer") Pointer pointer);
|
|||
public native void deleteLongArray(@Cast("Nd4jPointer") Pointer pointer);
|
||||
public native void deletePointerArray(@Cast("Nd4jPointer") Pointer pointer);
|
||||
|
||||
public native void deleteVariablesSet(@Cast("Nd4jPointer") Pointer pointer);
|
||||
public native void deleteVariablesSet(@ByVal OpaqueVariableSet pointer);
|
||||
|
||||
// GraphState creation
|
||||
public native @Cast("Nd4jPointer") Pointer getGraphState(@Cast("Nd4jLong") long id);
|
||||
|
@ -3007,6 +3029,8 @@ public native @Cast("Nd4jStatus") int execCustomOpWithScope(@Cast("Nd4jPointer*"
|
|||
//void fillUtf8String(Nd4jPointer *extraPointers, const char **string, int numStrings, Nd4jPointer buffer);
|
||||
public native @Cast("Nd4jPointer") Pointer createUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") String string, int length);
|
||||
public native @Cast("Nd4jPointer") Pointer createUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") BytePointer string, int length);
|
||||
public native @Cast("Nd4jLong") long getUtf8StringLength(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer ptr);
|
||||
public native @Cast("char*") BytePointer getUtf8StringBuffer(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer ptr);
|
||||
public native void deleteUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer ptr);
|
||||
|
||||
public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opCode, int numOfSubArrs,
|
||||
|
@ -3032,19 +3056,50 @@ public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointe
|
|||
public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer buffer, @Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jPointer") Pointer specialBuffer, @Cast("Nd4jLong*") LongBuffer specialShapeInfo, @Cast("Nd4jPointer") Pointer debugInfo);
|
||||
public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer buffer, @Cast("Nd4jLong*") long[] shapeInfo, @Cast("Nd4jPointer") Pointer specialBuffer, @Cast("Nd4jLong*") long[] specialShapeInfo, @Cast("Nd4jPointer") Pointer debugInfo);
|
||||
|
||||
public native OpaqueConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty);
|
||||
public native OpaqueConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty);
|
||||
public native OpaqueConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty);
|
||||
|
||||
public native ConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty);
|
||||
public native ConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty);
|
||||
public native ConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty);
|
||||
public native OpaqueConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") LongPointer data, int length);
|
||||
public native OpaqueConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") LongBuffer data, int length);
|
||||
public native OpaqueConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") long[] data, int length);
|
||||
public native OpaqueConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, DoublePointer data, int length);
|
||||
public native OpaqueConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, DoubleBuffer data, int length);
|
||||
public native OpaqueConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, double[] data, int length);
|
||||
public native OpaqueConstantDataBuffer constantBuffer(@Cast("nd4j::DataType") int dtype, ConstantDescriptor descriptor);
|
||||
|
||||
public native ConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") LongPointer data, int length);
|
||||
public native ConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") LongBuffer data, int length);
|
||||
public native ConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") long[] data, int length);
|
||||
public native ConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, DoublePointer data, int length);
|
||||
public native ConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, DoubleBuffer data, int length);
|
||||
public native ConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, double[] data, int length);
|
||||
public native ConstantDataBuffer constantBuffer(@Cast("nd4j::DataType") int dtype, ConstantDescriptor descriptor);
|
||||
public native @Cast("Nd4jPointer") Pointer getConstantDataBufferPrimary(OpaqueConstantDataBuffer dbf);
|
||||
public native @Cast("Nd4jPointer") Pointer getConstantDataBufferSpecial(OpaqueConstantDataBuffer dbf);
|
||||
public native @Cast("Nd4jLong") long getConstantDataBufferLength(OpaqueConstantDataBuffer dbf);
|
||||
public native @Cast("Nd4jLong") long getConstantDataBufferSizeOf(OpaqueConstantDataBuffer dbf);
|
||||
|
||||
public native void deleteShapeBuffer(OpaqueConstantDataBuffer ptr);
|
||||
|
||||
public native OpaqueContext createGraphContext(int nodeId);
|
||||
public native OpaqueRandomGenerator getGraphContextRandomGenerator(OpaqueContext ptr);
|
||||
public native void markGraphContextInplace(OpaqueContext ptr, @Cast("bool") boolean reallyInplace);
|
||||
public native void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer);
|
||||
public native void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);
|
||||
public native void setGraphContextOutputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);
|
||||
public native void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, int numberOfArguments);
|
||||
public native void setGraphContextTArguments(OpaqueContext ptr, DoubleBuffer arguments, int numberOfArguments);
|
||||
public native void setGraphContextTArguments(OpaqueContext ptr, double[] arguments, int numberOfArguments);
|
||||
public native void setGraphContextIArguments(OpaqueContext ptr, @Cast("Nd4jLong*") LongPointer arguments, int numberOfArguments);
|
||||
public native void setGraphContextIArguments(OpaqueContext ptr, @Cast("Nd4jLong*") LongBuffer arguments, int numberOfArguments);
|
||||
public native void setGraphContextIArguments(OpaqueContext ptr, @Cast("Nd4jLong*") long[] arguments, int numberOfArguments);
|
||||
public native void setGraphContextBArguments(OpaqueContext ptr, @Cast("bool*") BooleanPointer arguments, int numberOfArguments);
|
||||
public native void setGraphContextBArguments(OpaqueContext ptr, @Cast("bool*") boolean[] arguments, int numberOfArguments);
|
||||
public native void deleteGraphContext(OpaqueContext ptr);
|
||||
|
||||
public native OpaqueRandomGenerator createRandomGenerator(@Cast("Nd4jLong") long rootSeed/*=0*/, @Cast("Nd4jLong") long nodeSeed/*=0*/);
|
||||
public native OpaqueRandomGenerator createRandomGenerator();
|
||||
public native @Cast("Nd4jLong") long getRandomGeneratorRootState(OpaqueRandomGenerator ptr);
|
||||
public native @Cast("Nd4jLong") long getRandomGeneratorNodeState(OpaqueRandomGenerator ptr);
|
||||
public native void setRandomGeneratorStates(OpaqueRandomGenerator ptr, @Cast("Nd4jLong") long rootSeed/*=0*/, @Cast("Nd4jLong") long nodeSeed/*=0*/);
|
||||
public native void setRandomGeneratorStates(OpaqueRandomGenerator ptr);
|
||||
public native int getRandomGeneratorRelativeInt(OpaqueRandomGenerator ptr, @Cast("Nd4jLong") long index);
|
||||
public native @Cast("Nd4jLong") long getRandomGeneratorRelativeLong(OpaqueRandomGenerator ptr, @Cast("Nd4jLong") long index);
|
||||
public native void deleteRandomGenerator(OpaqueRandomGenerator ptr);
|
||||
|
||||
public native @Cast("char*") String runLightBenchmarkSuit(@Cast("bool") boolean printOut);
|
||||
public native @Cast("char*") String runFullBenchmarkSuit(@Cast("bool") boolean printOut);
|
||||
|
@ -3705,6 +3760,20 @@ public native @Cast("char*") String runFullBenchmarkSuit(@Cast("bool") boolean p
|
|||
public NDArray(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("nd4j::DataType") int dtype) { super((Pointer)null); allocate(buffer, order, shape, dtype); }
|
||||
private native void allocate(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("nd4j::DataType") int dtype);
|
||||
|
||||
|
||||
/**
|
||||
* This method returns new array with the same shape & data type
|
||||
* @return
|
||||
*/
|
||||
public native @ByVal NDArray like();
|
||||
|
||||
/**
|
||||
* This method returns new uninitialized array with the same shape & data type
|
||||
* @return
|
||||
*/
|
||||
public native @ByVal NDArray ulike();
|
||||
|
||||
|
||||
/**
|
||||
* this constructor creates new NDArray with shape matching "other" array,
|
||||
* doesn't copy "other" elements into new array !!!
|
||||
|
|
|
@ -113,6 +113,14 @@ public class Nd4jCudaPresets implements InfoMapper {
|
|||
infoMap.put(new Info("thread_local", "ND4J_EXPORT", "INLINEDEF", "CUBLASWINAPI", "FORCEINLINE",
|
||||
"_CUDA_H", "_CUDA_D", "_CUDA_G", "_CUDA_HD", "LIBND4J_ALL_OPS", "NOT_EXCLUDED").cppTypes().annotations())
|
||||
.put(new Info("NativeOps.h").objectify())
|
||||
.put(new Info("OpaqueTadPack").pointerTypes("OpaqueTadPack"))
|
||||
.put(new Info("OpaqueResultWrapper").pointerTypes("OpaqueResultWrapper"))
|
||||
.put(new Info("OpaqueShapeList").pointerTypes("OpaqueShapeList"))
|
||||
.put(new Info("OpaqueVariableSet").pointerTypes("OpaqueVariableSet"))
|
||||
.put(new Info("OpaqueVariable").pointerTypes("OpaqueVariable"))
|
||||
.put(new Info("OpaqueConstantDataBuffer").pointerTypes("OpaqueConstantDataBuffer"))
|
||||
.put(new Info("OpaqueContext").pointerTypes("OpaqueContext"))
|
||||
.put(new Info("OpaqueRandomGenerator").pointerTypes("OpaqueRandomGenerator"))
|
||||
.put(new Info("const char").valueTypes("byte").pointerTypes("@Cast(\"char*\") String",
|
||||
"@Cast(\"char*\") BytePointer"))
|
||||
.put(new Info("char").valueTypes("char").pointerTypes("@Cast(\"char*\") BytePointer",
|
||||
|
|
|
@ -48,7 +48,6 @@ import org.nd4j.linalg.util.ArrayUtil;
|
|||
import org.nd4j.nativeblas.BaseNativeNDArrayFactory;
|
||||
import org.nd4j.nativeblas.LongPointerWrapper;
|
||||
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||
import org.nd4j.nativeblas.Nd4jCpu;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
|
|
|
@ -17,12 +17,18 @@
|
|||
package org.nd4j.linalg.cpu.nativecpu.ops;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.bytedeco.javacpp.BooleanPointer;
|
||||
import org.bytedeco.javacpp.DoublePointer;
|
||||
import org.bytedeco.javacpp.LongPointer;
|
||||
import org.bytedeco.javacpp.Pointer;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.BaseOpContext;
|
||||
import org.nd4j.linalg.api.ops.OpContext;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
import org.nd4j.nativeblas.Nd4jCpu;
|
||||
import org.nd4j.nativeblas.NativeOps;
|
||||
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||
import org.nd4j.nativeblas.OpaqueContext;
|
||||
import org.nd4j.nativeblas.OpaqueRandomGenerator;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
|
@ -33,46 +39,53 @@ import java.util.List;
|
|||
*/
|
||||
public class CpuOpContext extends BaseOpContext implements OpContext {
|
||||
// we might want to have configurable
|
||||
private Nd4jCpu.Context context = new Nd4jCpu.Context(1);
|
||||
private NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
|
||||
private OpaqueContext context = nativeOps.createGraphContext(1);
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
nativeOps.deleteGraphContext(context);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setIArguments(long... arguments) {
|
||||
super.setIArguments(arguments);
|
||||
context.setIArguments(arguments, arguments.length);
|
||||
nativeOps.setGraphContextIArguments(context, new LongPointer(arguments), arguments.length);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setBArguments(boolean... arguments) {
|
||||
super.setBArguments(arguments);
|
||||
context.setBArguments(arguments, arguments.length);
|
||||
nativeOps.setGraphContextBArguments(context, new BooleanPointer(arguments), arguments.length);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setTArguments(double... arguments) {
|
||||
super.setTArguments(arguments);
|
||||
context.setTArguments(arguments, arguments.length);
|
||||
nativeOps.setGraphContextTArguments(context, new DoublePointer(arguments), arguments.length);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setRngStates(long rootState, long nodeState) {
|
||||
context.randomGenerator().setStates(rootState, nodeState);
|
||||
nativeOps.setRandomGeneratorStates(nativeOps.getGraphContextRandomGenerator(context), rootState, nodeState);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Pair<Long, Long> getRngStates() {
|
||||
return Pair.makePair(context.randomGenerator().rootState(), context.randomGenerator().nodeState());
|
||||
OpaqueRandomGenerator g = nativeOps.getGraphContextRandomGenerator(context);
|
||||
return Pair.makePair(nativeOps.getRandomGeneratorRootState(g), nativeOps.getRandomGeneratorNodeState(g));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setInputArray(int index, @NonNull INDArray array) {
|
||||
context.setInputArray(index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), null, null);
|
||||
nativeOps.setGraphContextInputArray(context, index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), null, null);
|
||||
|
||||
super.setInputArray(index, array);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setOutputArray(int index, @NonNull INDArray array) {
|
||||
context.setOutputArray(index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), null, null);
|
||||
nativeOps.setGraphContextOutputArray(context, index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), null, null);
|
||||
|
||||
super.setOutputArray(index, array);
|
||||
}
|
||||
|
@ -84,6 +97,6 @@ public class CpuOpContext extends BaseOpContext implements OpContext {
|
|||
|
||||
@Override
|
||||
public void markInplace(boolean reallyInplace) {
|
||||
context.markInplace(reallyInplace);
|
||||
nativeOps.markGraphContextInplace(context, reallyInplace);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -70,10 +70,14 @@ import org.nd4j.nativeblas.LongPointerWrapper;
|
|||
import org.nd4j.nativeblas.NativeOps;
|
||||
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||
import org.nd4j.nativeblas.Nd4jCpu;
|
||||
import org.nd4j.nativeblas.OpaqueConstantDataBuffer;
|
||||
import org.nd4j.nativeblas.OpaqueShapeList;
|
||||
import org.nd4j.nativeblas.OpaqueTadPack;
|
||||
import org.nd4j.nativeblas.OpaqueVariable;
|
||||
import org.nd4j.nativeblas.OpaqueVariableSet;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
|
||||
/**
|
||||
*
|
||||
* Native operation
|
||||
|
@ -1641,7 +1645,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
|||
}
|
||||
|
||||
val name = op.opName();
|
||||
val context = buildContext();
|
||||
try (val context = buildContext()) {
|
||||
|
||||
context.markInplace(op.isInplaceCall());
|
||||
|
||||
|
@ -1657,7 +1661,6 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
|||
context.setIArguments(op.iArgs());
|
||||
context.setTArguments(op.tArgs());
|
||||
|
||||
try {
|
||||
val result = exec(op, context);
|
||||
val states = context.getRngStates();
|
||||
|
||||
|
@ -1860,9 +1863,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
|||
for (val t: tArgs1)
|
||||
tArgs.put(cnt++, t);
|
||||
|
||||
Nd4jCpu.ShapeList ptrptr;
|
||||
OpaqueShapeList ptrptr;
|
||||
try {
|
||||
ptrptr = (Nd4jCpu.ShapeList) loop.calculateOutputShapes2(null,
|
||||
ptrptr = loop.calculateOutputShapes2(null,
|
||||
hash, inputBuffers, inputShapes, op.numInputArguments(), tArgs,
|
||||
op.numTArguments(), iArgs, op.numIArguments(), bArgs, op.numBArguments());
|
||||
} catch (Throwable t){
|
||||
|
@ -1891,8 +1894,8 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
|||
if (ptrptr == null)
|
||||
throw new RuntimeException();
|
||||
|
||||
for (int e = 0; e < ptrptr.size(); e++ )
|
||||
result.add(getShapeFromPointer(new PagedPointer(ptrptr.at(e)).asLongPointer()));
|
||||
for (int e = 0; e < loop.getShapeListSize(ptrptr); e++ )
|
||||
result.add(getShapeFromPointer(new PagedPointer(loop.getShape(ptrptr, e)).asLongPointer()));
|
||||
|
||||
|
||||
loop.deleteShapeList(ptrptr);
|
||||
|
@ -1947,19 +1950,19 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
|||
|
||||
val newMap = new LinkedHashMap<String, INDArray>();
|
||||
|
||||
val result = (Nd4jCpu.VariablesSet) loop.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size());
|
||||
OpaqueVariableSet result = loop.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size());
|
||||
|
||||
val status = OpStatus.byNumber(result.status());
|
||||
OpStatus status = OpStatus.byNumber(loop.getVariableSetStatus(result));
|
||||
|
||||
if (status != OpStatus.ND4J_STATUS_OK)
|
||||
throw new ND4JIllegalStateException("Op execution failed: " + status);
|
||||
|
||||
for (int e = 0; e < result.size(); e++) {
|
||||
val var = result.at(e);
|
||||
val nodeId = var.id();
|
||||
val index = var.index();
|
||||
val shapeInfo = var.getNDArray().shapeInfo();
|
||||
val buffer = var.getNDArray().buffer();
|
||||
for (int e = 0; e < loop.getVariableSetSize(result); e++) {
|
||||
OpaqueVariable var = loop.getVariable(result, e);
|
||||
int nodeId = loop.getVariableId(var);
|
||||
int index = loop.getVariableIndex(var);
|
||||
LongPointer shapeInfo = loop.getVariableShape(var);
|
||||
Pointer buffer = loop.getVariableBuffer(var);
|
||||
|
||||
val rank = (int) shapeInfo.get(0);
|
||||
val jshape = new long[rank * 2 + 4];
|
||||
|
@ -1979,7 +1982,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
|||
PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, Shape.lengthOf(shapeOf) * Nd4j.sizeOfDataType(array.dataType()), MemcpyDirection.HOST_TO_HOST);
|
||||
|
||||
//newMap.put(keySet.get(nodeId), array);
|
||||
val nodeName = var.getName().getString();
|
||||
String nodeName = loop.getVariableName(var);
|
||||
newMap.put(nodeName, array);
|
||||
}
|
||||
|
||||
|
@ -2160,9 +2163,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
|||
|
||||
@Override
|
||||
public DataBuffer createShapeInfo(long[] shape, long[] stride, long elementWiseStride, char order, DataType dtype, boolean empty) {
|
||||
val dbf = (Nd4jCpu.ConstantDataBuffer) loop.shapeBuffer(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, empty);
|
||||
OpaqueConstantDataBuffer dbf = loop.shapeBuffer(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, empty);
|
||||
|
||||
val result = new LongBuffer(dbf.primary(), Shape.shapeInfoLength(shape.length));
|
||||
val result = new LongBuffer(loop.getConstantDataBufferPrimary(dbf), Shape.shapeInfoLength(shape.length));
|
||||
|
||||
loop.deleteShapeBuffer(dbf);
|
||||
|
||||
|
@ -2171,10 +2174,10 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
|||
|
||||
@Override
|
||||
public TadPack tadShapeInfoAndOffsets(INDArray array, int[] dimension) {
|
||||
val pack = (Nd4jCpu.TadPack) loop.tadOnlyShapeInfo((LongPointer) array.shapeInfoDataBuffer().addressPointer(), new IntPointer(dimension), dimension.length);
|
||||
OpaqueTadPack pack = loop.tadOnlyShapeInfo((LongPointer) array.shapeInfoDataBuffer().addressPointer(), new IntPointer(dimension), dimension.length);
|
||||
|
||||
val tadShape = new LongBuffer(pack.primaryShapeInfo(), pack.shapeInfoLength());
|
||||
val tadOffsets = new LongBuffer(pack.primaryOffsets(), pack.numberOfTads());
|
||||
val tadShape = new LongBuffer(loop.getPrimaryShapeInfo(pack), loop.getShapeInfoLength(pack));
|
||||
val tadOffsets = new LongBuffer(loop.getPrimaryOffsets(pack), loop.getNumberOfTads(pack));
|
||||
|
||||
loop.deleteTadPack(pack);
|
||||
|
||||
|
|
|
@ -18,7 +18,9 @@ package org.nd4j.linalg.cpu.nativecpu.rng;
|
|||
|
||||
import org.bytedeco.javacpp.PointerPointer;
|
||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||
import org.nd4j.nativeblas.Nd4jCpu;
|
||||
import org.nd4j.nativeblas.NativeOps;
|
||||
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||
import org.nd4j.nativeblas.OpaqueRandomGenerator;
|
||||
import org.nd4j.rng.NativeRandom;
|
||||
|
||||
import java.util.concurrent.atomic.AtomicLong;
|
||||
|
@ -29,6 +31,8 @@ import java.util.concurrent.atomic.AtomicLong;
|
|||
* @author raver119@gmail.com
|
||||
*/
|
||||
public class CpuNativeRandom extends NativeRandom {
|
||||
private NativeOps nativeOps;
|
||||
|
||||
public CpuNativeRandom() {
|
||||
super();
|
||||
}
|
||||
|
@ -43,7 +47,13 @@ public class CpuNativeRandom extends NativeRandom {
|
|||
|
||||
@Override
|
||||
public void init() {
|
||||
statePointer = new Nd4jCpu.RandomGenerator(this.seed, this.seed ^ 0xdeadbeef);
|
||||
nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
|
||||
statePointer = nativeOps.createRandomGenerator(this.seed, this.seed ^ 0xdeadbeef);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
nativeOps.deleteRandomGenerator((OpaqueRandomGenerator)statePointer);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -55,7 +65,7 @@ public class CpuNativeRandom extends NativeRandom {
|
|||
public void setSeed(long seed) {
|
||||
this.seed = seed;
|
||||
this.currentPosition.set(0);
|
||||
((Nd4jCpu.RandomGenerator)statePointer).setStates(seed, seed ^ 0xdeadbeef);
|
||||
nativeOps.setRandomGeneratorStates((OpaqueRandomGenerator)statePointer, seed, seed ^ 0xdeadbeef);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -65,24 +75,24 @@ public class CpuNativeRandom extends NativeRandom {
|
|||
|
||||
@Override
|
||||
public int nextInt() {
|
||||
return ((Nd4jCpu.RandomGenerator)statePointer).relativeInt(currentPosition.getAndIncrement());
|
||||
return nativeOps.getRandomGeneratorRelativeInt((OpaqueRandomGenerator)statePointer, currentPosition.getAndIncrement());
|
||||
}
|
||||
|
||||
@Override
|
||||
public long nextLong() {
|
||||
return ((Nd4jCpu.RandomGenerator)statePointer).relativeLong(currentPosition.getAndIncrement());
|
||||
return nativeOps.getRandomGeneratorRelativeLong((OpaqueRandomGenerator)statePointer, currentPosition.getAndIncrement());
|
||||
}
|
||||
|
||||
public long rootState() {
|
||||
return ((Nd4jCpu.RandomGenerator) statePointer).rootState();
|
||||
return nativeOps.getRandomGeneratorRootState((OpaqueRandomGenerator)statePointer);
|
||||
}
|
||||
|
||||
public long nodeState() {
|
||||
return ((Nd4jCpu.RandomGenerator) statePointer).nodeState();
|
||||
return nativeOps.getRandomGeneratorNodeState((OpaqueRandomGenerator)statePointer);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setStates(long rootState, long nodeState) {
|
||||
((Nd4jCpu.RandomGenerator) statePointer).setStates(rootState, nodeState);
|
||||
nativeOps.setRandomGeneratorStates((OpaqueRandomGenerator)statePointer, rootState, nodeState);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2097,16 +2097,25 @@ public native void setGridLimit(int gridSize);
|
|||
* @param targetBuffer
|
||||
* @param offsetsBuffer
|
||||
*/
|
||||
public native TadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") LongPointer xShapeInfo,
|
||||
public native OpaqueTadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") LongPointer xShapeInfo,
|
||||
IntPointer dimension,
|
||||
int dimensionLength);
|
||||
public native TadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") LongBuffer xShapeInfo,
|
||||
public native OpaqueTadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") LongBuffer xShapeInfo,
|
||||
IntBuffer dimension,
|
||||
int dimensionLength);
|
||||
public native TadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") long[] xShapeInfo,
|
||||
public native OpaqueTadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") long[] xShapeInfo,
|
||||
int[] dimension,
|
||||
int dimensionLength);
|
||||
|
||||
public native @Cast("Nd4jLong*") LongPointer getPrimaryShapeInfo(OpaqueTadPack pack);
|
||||
public native @Cast("Nd4jLong*") LongPointer getPrimaryOffsets(OpaqueTadPack pack);
|
||||
public native @Cast("Nd4jLong*") LongPointer getSpecialShapeInfo(OpaqueTadPack pack);
|
||||
public native @Cast("Nd4jLong*") LongPointer getSpecialOffsets(OpaqueTadPack pack);
|
||||
public native @Cast("Nd4jLong") long getNumberOfTads(OpaqueTadPack pack);
|
||||
public native int getShapeInfoLength(OpaqueTadPack pack);
|
||||
|
||||
public native void deleteTadPack(OpaqueTadPack ptr);
|
||||
|
||||
/*
|
||||
* PullRow special op
|
||||
*/
|
||||
|
@ -2943,10 +2952,11 @@ public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers
|
|||
public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongBuffer ptrMap, @Cast("Nd4jLong") long length);
|
||||
public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") long[] ptrMap, @Cast("Nd4jLong") long length);
|
||||
|
||||
|
||||
// flatbuffers execution
|
||||
public native ResultWrapper executeFlatGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer flatBufferPointer);
|
||||
public native OpaqueResultWrapper executeFlatGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer flatBufferPointer);
|
||||
|
||||
public native @Cast("Nd4jLong") long getResultWrapperSize(OpaqueResultWrapper ptr);
|
||||
public native @Cast("Nd4jPointer") Pointer getResultWrapperPointer(OpaqueResultWrapper ptr);
|
||||
|
||||
public native @Cast("char*") String getAllCustomOps();
|
||||
|
||||
|
@ -2961,23 +2971,35 @@ public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointer
|
|||
public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, @Cast("bool") boolean isInplace);
|
||||
public native int execCustomOp2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer") Pointer opContext);
|
||||
|
||||
public native ShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs);
|
||||
public native ShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs);
|
||||
public native ShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs);
|
||||
public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
|
||||
public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
|
||||
public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
|
||||
public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
|
||||
public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
|
||||
public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
|
||||
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
|
||||
|
||||
public native @Cast("Nd4jLong") long getShapeListSize(OpaqueShapeList list);
|
||||
public native @Cast("Nd4jLong*") LongPointer getShape(OpaqueShapeList list, @Cast("Nd4jLong") long i);
|
||||
|
||||
public native void deleteShapeList(@Cast("Nd4jPointer") Pointer shapeList);
|
||||
|
||||
public native int registerGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer") Pointer flatBufferPointer);
|
||||
|
||||
public native VariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntPointer inputIndices, int numInputs);
|
||||
public native VariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntBuffer inputIndices, int numInputs);
|
||||
public native VariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int[] inputIndices, int numInputs);
|
||||
public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntPointer inputIndices, int numInputs);
|
||||
public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntBuffer inputIndices, int numInputs);
|
||||
public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int[] inputIndices, int numInputs);
|
||||
|
||||
public native @Cast("Nd4jLong") long getVariableSetSize(OpaqueVariableSet set);
|
||||
public native @Cast("Nd4jStatus") int getVariableSetStatus(OpaqueVariableSet set);
|
||||
public native OpaqueVariable getVariable(OpaqueVariableSet set, @Cast("Nd4jLong") long i);
|
||||
public native int getVariableId(OpaqueVariable variable);
|
||||
public native int getVariableIndex(OpaqueVariable variable);
|
||||
public native @Cast("char*") String getVariableName(OpaqueVariable variable);
|
||||
public native @Cast("Nd4jLong*") LongPointer getVariableShape(OpaqueVariable variable);
|
||||
public native Pointer getVariableBuffer(OpaqueVariable variable);
|
||||
|
||||
public native int unregisterGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId);
|
||||
|
||||
|
@ -2986,7 +3008,7 @@ public native void deleteIntArray(@Cast("Nd4jPointer") Pointer pointer);
|
|||
public native void deleteLongArray(@Cast("Nd4jPointer") Pointer pointer);
|
||||
public native void deletePointerArray(@Cast("Nd4jPointer") Pointer pointer);
|
||||
|
||||
public native void deleteVariablesSet(@Cast("Nd4jPointer") Pointer pointer);
|
||||
public native void deleteVariablesSet(@ByVal OpaqueVariableSet pointer);
|
||||
|
||||
// GraphState creation
|
||||
public native @Cast("Nd4jPointer") Pointer getGraphState(@Cast("Nd4jLong") long id);
|
||||
|
@ -3007,6 +3029,8 @@ public native @Cast("Nd4jStatus") int execCustomOpWithScope(@Cast("Nd4jPointer*"
|
|||
//void fillUtf8String(Nd4jPointer *extraPointers, const char **string, int numStrings, Nd4jPointer buffer);
|
||||
public native @Cast("Nd4jPointer") Pointer createUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") String string, int length);
|
||||
public native @Cast("Nd4jPointer") Pointer createUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") BytePointer string, int length);
|
||||
public native @Cast("Nd4jLong") long getUtf8StringLength(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer ptr);
|
||||
public native @Cast("char*") BytePointer getUtf8StringBuffer(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer ptr);
|
||||
public native void deleteUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer ptr);
|
||||
|
||||
public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opCode, int numOfSubArrs,
|
||||
|
@ -3032,19 +3056,50 @@ public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointe
|
|||
public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer buffer, @Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jPointer") Pointer specialBuffer, @Cast("Nd4jLong*") LongBuffer specialShapeInfo, @Cast("Nd4jPointer") Pointer debugInfo);
|
||||
public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer buffer, @Cast("Nd4jLong*") long[] shapeInfo, @Cast("Nd4jPointer") Pointer specialBuffer, @Cast("Nd4jLong*") long[] specialShapeInfo, @Cast("Nd4jPointer") Pointer debugInfo);
|
||||
|
||||
public native OpaqueConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty);
|
||||
public native OpaqueConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty);
|
||||
public native OpaqueConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty);
|
||||
|
||||
public native ConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty);
|
||||
public native ConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty);
|
||||
public native ConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty);
|
||||
public native OpaqueConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") LongPointer data, int length);
|
||||
public native OpaqueConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") LongBuffer data, int length);
|
||||
public native OpaqueConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") long[] data, int length);
|
||||
public native OpaqueConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, DoublePointer data, int length);
|
||||
public native OpaqueConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, DoubleBuffer data, int length);
|
||||
public native OpaqueConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, double[] data, int length);
|
||||
public native OpaqueConstantDataBuffer constantBuffer(@Cast("nd4j::DataType") int dtype, ConstantDescriptor descriptor);
|
||||
|
||||
public native ConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") LongPointer data, int length);
|
||||
public native ConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") LongBuffer data, int length);
|
||||
public native ConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") long[] data, int length);
|
||||
public native ConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, DoublePointer data, int length);
|
||||
public native ConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, DoubleBuffer data, int length);
|
||||
public native ConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, double[] data, int length);
|
||||
public native ConstantDataBuffer constantBuffer(@Cast("nd4j::DataType") int dtype, ConstantDescriptor descriptor);
|
||||
public native @Cast("Nd4jPointer") Pointer getConstantDataBufferPrimary(OpaqueConstantDataBuffer dbf);
|
||||
public native @Cast("Nd4jPointer") Pointer getConstantDataBufferSpecial(OpaqueConstantDataBuffer dbf);
|
||||
public native @Cast("Nd4jLong") long getConstantDataBufferLength(OpaqueConstantDataBuffer dbf);
|
||||
public native @Cast("Nd4jLong") long getConstantDataBufferSizeOf(OpaqueConstantDataBuffer dbf);
|
||||
|
||||
public native void deleteShapeBuffer(OpaqueConstantDataBuffer ptr);
|
||||
|
||||
public native OpaqueContext createGraphContext(int nodeId);
|
||||
public native OpaqueRandomGenerator getGraphContextRandomGenerator(OpaqueContext ptr);
|
||||
public native void markGraphContextInplace(OpaqueContext ptr, @Cast("bool") boolean reallyInplace);
|
||||
public native void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer);
|
||||
public native void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);
|
||||
public native void setGraphContextOutputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);
|
||||
public native void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, int numberOfArguments);
|
||||
public native void setGraphContextTArguments(OpaqueContext ptr, DoubleBuffer arguments, int numberOfArguments);
|
||||
public native void setGraphContextTArguments(OpaqueContext ptr, double[] arguments, int numberOfArguments);
|
||||
public native void setGraphContextIArguments(OpaqueContext ptr, @Cast("Nd4jLong*") LongPointer arguments, int numberOfArguments);
|
||||
public native void setGraphContextIArguments(OpaqueContext ptr, @Cast("Nd4jLong*") LongBuffer arguments, int numberOfArguments);
|
||||
public native void setGraphContextIArguments(OpaqueContext ptr, @Cast("Nd4jLong*") long[] arguments, int numberOfArguments);
|
||||
public native void setGraphContextBArguments(OpaqueContext ptr, @Cast("bool*") BooleanPointer arguments, int numberOfArguments);
|
||||
public native void setGraphContextBArguments(OpaqueContext ptr, @Cast("bool*") boolean[] arguments, int numberOfArguments);
|
||||
public native void deleteGraphContext(OpaqueContext ptr);
|
||||
|
||||
public native OpaqueRandomGenerator createRandomGenerator(@Cast("Nd4jLong") long rootSeed/*=0*/, @Cast("Nd4jLong") long nodeSeed/*=0*/);
|
||||
public native OpaqueRandomGenerator createRandomGenerator();
|
||||
public native @Cast("Nd4jLong") long getRandomGeneratorRootState(OpaqueRandomGenerator ptr);
|
||||
public native @Cast("Nd4jLong") long getRandomGeneratorNodeState(OpaqueRandomGenerator ptr);
|
||||
public native void setRandomGeneratorStates(OpaqueRandomGenerator ptr, @Cast("Nd4jLong") long rootSeed/*=0*/, @Cast("Nd4jLong") long nodeSeed/*=0*/);
|
||||
public native void setRandomGeneratorStates(OpaqueRandomGenerator ptr);
|
||||
public native int getRandomGeneratorRelativeInt(OpaqueRandomGenerator ptr, @Cast("Nd4jLong") long index);
|
||||
public native @Cast("Nd4jLong") long getRandomGeneratorRelativeLong(OpaqueRandomGenerator ptr, @Cast("Nd4jLong") long index);
|
||||
public native void deleteRandomGenerator(OpaqueRandomGenerator ptr);
|
||||
|
||||
public native @Cast("char*") String runLightBenchmarkSuit(@Cast("bool") boolean printOut);
|
||||
public native @Cast("char*") String runFullBenchmarkSuit(@Cast("bool") boolean printOut);
|
||||
|
@ -3705,6 +3760,20 @@ public native @Cast("char*") String runFullBenchmarkSuit(@Cast("bool") boolean p
|
|||
public NDArray(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("nd4j::DataType") int dtype) { super((Pointer)null); allocate(buffer, order, shape, dtype); }
|
||||
private native void allocate(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("nd4j::DataType") int dtype);
|
||||
|
||||
|
||||
/**
|
||||
* This method returns new array with the same shape & data type
|
||||
* @return
|
||||
*/
|
||||
public native @ByVal NDArray like();
|
||||
|
||||
/**
|
||||
* This method returns new uninitialized array with the same shape & data type
|
||||
* @return
|
||||
*/
|
||||
public native @ByVal NDArray ulike();
|
||||
|
||||
|
||||
/**
|
||||
* this constructor creates new NDArray with shape matching "other" array,
|
||||
* doesn't copy "other" elements into new array !!!
|
||||
|
|
|
@ -156,6 +156,14 @@ public class Nd4jCpuPresets implements InfoMapper, BuildEnabled {
|
|||
infoMap.put(new Info("thread_local", "ND4J_EXPORT", "INLINEDEF", "CUBLASWINAPI", "FORCEINLINE",
|
||||
"_CUDA_H", "_CUDA_D", "_CUDA_G", "_CUDA_HD", "LIBND4J_ALL_OPS", "NOT_EXCLUDED").cppTypes().annotations())
|
||||
.put(new Info("NativeOps.h").objectify())
|
||||
.put(new Info("OpaqueTadPack").pointerTypes("OpaqueTadPack"))
|
||||
.put(new Info("OpaqueResultWrapper").pointerTypes("OpaqueResultWrapper"))
|
||||
.put(new Info("OpaqueShapeList").pointerTypes("OpaqueShapeList"))
|
||||
.put(new Info("OpaqueVariableSet").pointerTypes("OpaqueVariableSet"))
|
||||
.put(new Info("OpaqueVariable").pointerTypes("OpaqueVariable"))
|
||||
.put(new Info("OpaqueConstantDataBuffer").pointerTypes("OpaqueConstantDataBuffer"))
|
||||
.put(new Info("OpaqueContext").pointerTypes("OpaqueContext"))
|
||||
.put(new Info("OpaqueRandomGenerator").pointerTypes("OpaqueRandomGenerator"))
|
||||
.put(new Info("const char").valueTypes("byte").pointerTypes("@Cast(\"char*\") String",
|
||||
"@Cast(\"char*\") BytePointer"))
|
||||
.put(new Info("char").valueTypes("char").pointerTypes("@Cast(\"char*\") BytePointer",
|
||||
|
|
|
@ -442,7 +442,7 @@ public class CustomOpsTests extends BaseNd4jTest {
|
|||
context.setOutputArray(0, arrayZ);
|
||||
|
||||
val addOp = new AddOp();
|
||||
NativeOpsHolder.getInstance().getDeviceNativeOps().execCustomOp(null, addOp.opHash(), context.contextPointer());
|
||||
NativeOpsHolder.getInstance().getDeviceNativeOps().execCustomOp2(null, addOp.opHash(), context.contextPointer());
|
||||
|
||||
assertEquals(exp, arrayZ);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue