Create C wrappers for some of the C++ classes currently used by ND4J

master
Samuel Audet 2019-07-24 21:14:54 +09:00 committed by AlexDBlack
parent 8881bfe7aa
commit 526b782e51
28 changed files with 1034 additions and 208 deletions

View File

@ -882,6 +882,8 @@ ND4J_EXPORT void enableVerboseMode(bool reallyEnable);
*/ */
ND4J_EXPORT void setGridLimit(int gridSize); ND4J_EXPORT void setGridLimit(int gridSize);
typedef nd4j::TadPack OpaqueTadPack;
/** /**
* *
* @param xShapeInfo * @param xShapeInfo
@ -890,10 +892,19 @@ ND4J_EXPORT void setGridLimit(int gridSize);
* @param targetBuffer * @param targetBuffer
* @param offsetsBuffer * @param offsetsBuffer
*/ */
ND4J_EXPORT nd4j::TadPack* tadOnlyShapeInfo(Nd4jLong *xShapeInfo, ND4J_EXPORT OpaqueTadPack* tadOnlyShapeInfo(Nd4jLong *xShapeInfo,
int *dimension, int *dimension,
int dimensionLength); int dimensionLength);
ND4J_EXPORT Nd4jLong* getPrimaryShapeInfo(OpaqueTadPack* pack);
ND4J_EXPORT Nd4jLong* getPrimaryOffsets(OpaqueTadPack* pack);
ND4J_EXPORT Nd4jLong* getSpecialShapeInfo(OpaqueTadPack* pack);
ND4J_EXPORT Nd4jLong* getSpecialOffsets(OpaqueTadPack* pack);
ND4J_EXPORT Nd4jLong getNumberOfTads(OpaqueTadPack* pack);
ND4J_EXPORT int getShapeInfoLength(OpaqueTadPack* pack);
ND4J_EXPORT void deleteTadPack(OpaqueTadPack* ptr);
/* /*
* PullRow special op * PullRow special op
*/ */
@ -1639,10 +1650,13 @@ ND4J_EXPORT Nd4jLong* mmapFile(Nd4jPointer *extraPointers, const char *fileName,
ND4J_EXPORT void munmapFile(Nd4jPointer *extraPointers, Nd4jLong* ptrMap, Nd4jLong length); ND4J_EXPORT void munmapFile(Nd4jPointer *extraPointers, Nd4jLong* ptrMap, Nd4jLong length);
typedef nd4j::graph::ResultWrapper OpaqueResultWrapper;
// flatbuffers execution // flatbuffers execution
ND4J_EXPORT nd4j::graph::ResultWrapper* executeFlatGraph(Nd4jPointer *extraPointers, Nd4jPointer flatBufferPointer); ND4J_EXPORT OpaqueResultWrapper* executeFlatGraph(Nd4jPointer *extraPointers, Nd4jPointer flatBufferPointer);
ND4J_EXPORT Nd4jLong getResultWrapperSize(OpaqueResultWrapper* ptr);
ND4J_EXPORT Nd4jPointer getResultWrapperPointer(OpaqueResultWrapper* ptr);
ND4J_EXPORT const char* getAllCustomOps(); ND4J_EXPORT const char* getAllCustomOps();
@ -1652,14 +1666,31 @@ ND4J_EXPORT const char* getAllOperations();
ND4J_EXPORT int execCustomOp(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputs, Nd4jPointer* outputBuffers, Nd4jPointer* outputShapes, int numOutputs, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool* bArgs, int numBArgs, bool isInplace); ND4J_EXPORT int execCustomOp(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputs, Nd4jPointer* outputBuffers, Nd4jPointer* outputShapes, int numOutputs, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool* bArgs, int numBArgs, bool isInplace);
ND4J_EXPORT int execCustomOp2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer opContext); ND4J_EXPORT int execCustomOp2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer opContext);
ND4J_EXPORT nd4j::ShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs); typedef nd4j::ShapeList OpaqueShapeList;
ND4J_EXPORT nd4j::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs);
ND4J_EXPORT OpaqueShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs);
ND4J_EXPORT OpaqueShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs);
ND4J_EXPORT Nd4jLong getShapeListSize(OpaqueShapeList* list);
ND4J_EXPORT Nd4jLong* getShape(OpaqueShapeList* list, Nd4jLong i);
ND4J_EXPORT void deleteShapeList(Nd4jPointer shapeList); ND4J_EXPORT void deleteShapeList(Nd4jPointer shapeList);
ND4J_EXPORT int registerGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer flatBufferPointer); ND4J_EXPORT int registerGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer flatBufferPointer);
ND4J_EXPORT nd4j::graph::VariablesSet *executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int* inputIndices, int numInputs); typedef nd4j::graph::VariablesSet OpaqueVariableSet;
typedef nd4j::graph::Variable OpaqueVariable;
ND4J_EXPORT OpaqueVariableSet *executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int* inputIndices, int numInputs);
ND4J_EXPORT Nd4jLong getVariableSetSize(OpaqueVariableSet* set);
ND4J_EXPORT Nd4jStatus getVariableSetStatus(OpaqueVariableSet* set);
ND4J_EXPORT OpaqueVariable* getVariable(OpaqueVariableSet* set, Nd4jLong i);
ND4J_EXPORT int getVariableId(OpaqueVariable* variable);
ND4J_EXPORT int getVariableIndex(OpaqueVariable* variable);
ND4J_EXPORT const char* getVariableName(OpaqueVariable* variable);
ND4J_EXPORT Nd4jLong* getVariableShape(OpaqueVariable* variable);
ND4J_EXPORT void* getVariableBuffer(OpaqueVariable* variable);
ND4J_EXPORT int unregisterGraph(Nd4jPointer *extraPointers, Nd4jLong graphId); ND4J_EXPORT int unregisterGraph(Nd4jPointer *extraPointers, Nd4jLong graphId);
@ -1668,7 +1699,7 @@ ND4J_EXPORT void deleteIntArray(Nd4jPointer pointer);
ND4J_EXPORT void deleteLongArray(Nd4jPointer pointer); ND4J_EXPORT void deleteLongArray(Nd4jPointer pointer);
ND4J_EXPORT void deletePointerArray(Nd4jPointer pointer); ND4J_EXPORT void deletePointerArray(Nd4jPointer pointer);
ND4J_EXPORT void deleteVariablesSet(Nd4jPointer pointer); ND4J_EXPORT void deleteVariablesSet(OpaqueVariableSet pointer);
// GraphState creation // GraphState creation
ND4J_EXPORT Nd4jPointer getGraphState(Nd4jLong id); ND4J_EXPORT Nd4jPointer getGraphState(Nd4jLong id);
@ -1684,7 +1715,9 @@ ND4J_EXPORT Nd4jStatus execCustomOpWithScope(Nd4jPointer *extraPointers, Nd4jPoi
//void fillUtf8String(Nd4jPointer *extraPointers, const char **string, int numStrings, Nd4jPointer buffer); //void fillUtf8String(Nd4jPointer *extraPointers, const char **string, int numStrings, Nd4jPointer buffer);
ND4J_EXPORT Nd4jPointer createUtf8String(Nd4jPointer *extraPointers, const char *string, int length); ND4J_EXPORT Nd4jPointer createUtf8String(Nd4jPointer *extraPointers, const char *string, int length);
void deleteUtf8String(Nd4jPointer *extraPointers, Nd4jPointer ptr); ND4J_EXPORT Nd4jLong getUtf8StringLength(Nd4jPointer *extraPointers, Nd4jPointer ptr);
ND4J_EXPORT char* getUtf8StringBuffer(Nd4jPointer *extraPointers, Nd4jPointer ptr);
ND4J_EXPORT void deleteUtf8String(Nd4jPointer *extraPointers, Nd4jPointer ptr);
ND4J_EXPORT void scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSubArrs, ND4J_EXPORT void scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSubArrs,
void* hX, Nd4jLong* hXShapeInfo, Nd4jLong* hXOffsets, void* hX, Nd4jLong* hXShapeInfo, Nd4jLong* hXOffsets,
@ -1693,18 +1726,45 @@ ND4J_EXPORT void scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOf
void* dY, Nd4jLong* dYShapeInfo, Nd4jLong* dYOffsets, void* dY, Nd4jLong* dYShapeInfo, Nd4jLong* dYOffsets,
int* hIindexes, int* dIindexes); int* hIindexes, int* dIindexes);
ND4J_EXPORT void deleteShapeBuffer(Nd4jPointer ptr);
ND4J_EXPORT void deleteTadPack(Nd4jPointer ptr);
ND4J_EXPORT void inspectArray(Nd4jPointer *extraPointers, Nd4jPointer buffer, Nd4jLong *shapeInfo, Nd4jPointer specialBuffer, Nd4jLong *specialShapeInfo, Nd4jPointer debugInfo); ND4J_EXPORT void inspectArray(Nd4jPointer *extraPointers, Nd4jPointer buffer, Nd4jLong *shapeInfo, Nd4jPointer specialBuffer, Nd4jLong *specialShapeInfo, Nd4jPointer debugInfo);
ND4J_EXPORT nd4j::ConstantDataBuffer* shapeBuffer(int rank, Nd4jLong *shape, Nd4jLong *strides, nd4j::DataType dtype, char order, Nd4jLong ews, bool empty); typedef nd4j::ConstantDataBuffer OpaqueConstantDataBuffer;
ND4J_EXPORT nd4j::ConstantDataBuffer* constantBufferLong(nd4j::DataType dtype, Nd4jLong *data, int length); ND4J_EXPORT OpaqueConstantDataBuffer* shapeBuffer(int rank, Nd4jLong *shape, Nd4jLong *strides, nd4j::DataType dtype, char order, Nd4jLong ews, bool empty);
ND4J_EXPORT nd4j::ConstantDataBuffer* constantBufferDouble(nd4j::DataType dtype, double *data, int length);
ND4J_EXPORT nd4j::ConstantDataBuffer* constantBuffer(nd4j::DataType dtype, nd4j::ConstantDescriptor *descriptor);
ND4J_EXPORT OpaqueConstantDataBuffer* constantBufferLong(nd4j::DataType dtype, Nd4jLong *data, int length);
ND4J_EXPORT OpaqueConstantDataBuffer* constantBufferDouble(nd4j::DataType dtype, double *data, int length);
ND4J_EXPORT OpaqueConstantDataBuffer* constantBuffer(nd4j::DataType dtype, nd4j::ConstantDescriptor *descriptor);
ND4J_EXPORT Nd4jPointer getConstantDataBufferPrimary(OpaqueConstantDataBuffer* dbf);
ND4J_EXPORT Nd4jPointer getConstantDataBufferSpecial(OpaqueConstantDataBuffer* dbf);
ND4J_EXPORT Nd4jLong getConstantDataBufferLength(OpaqueConstantDataBuffer* dbf);
ND4J_EXPORT Nd4jLong getConstantDataBufferSizeOf(OpaqueConstantDataBuffer* dbf);
ND4J_EXPORT void deleteShapeBuffer(OpaqueConstantDataBuffer* ptr);
typedef nd4j::graph::Context OpaqueContext;
typedef nd4j::graph::RandomGenerator OpaqueRandomGenerator;
ND4J_EXPORT OpaqueContext* createGraphContext(int nodeId);
ND4J_EXPORT OpaqueRandomGenerator* getGraphContextRandomGenerator(OpaqueContext* ptr);
ND4J_EXPORT void markGraphContextInplace(OpaqueContext* ptr, bool reallyInplace);
ND4J_EXPORT void setGraphContextCudaContext(OpaqueContext* ptr, void *stream, void *reductionPointer, void *allocationPointer);
ND4J_EXPORT void setGraphContextInputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo);
ND4J_EXPORT void setGraphContextOutputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo);
ND4J_EXPORT void setGraphContextTArguments(OpaqueContext* ptr, double *arguments, int numberOfArguments);
ND4J_EXPORT void setGraphContextIArguments(OpaqueContext* ptr, Nd4jLong *arguments, int numberOfArguments);
ND4J_EXPORT void setGraphContextBArguments(OpaqueContext* ptr, bool *arguments, int numberOfArguments);
ND4J_EXPORT void deleteGraphContext(OpaqueContext* ptr);
ND4J_EXPORT OpaqueRandomGenerator* createRandomGenerator(Nd4jLong rootSeed = 0, Nd4jLong nodeSeed = 0);
ND4J_EXPORT Nd4jLong getRandomGeneratorRootState(OpaqueRandomGenerator* ptr);
ND4J_EXPORT Nd4jLong getRandomGeneratorNodeState(OpaqueRandomGenerator* ptr);
ND4J_EXPORT void setRandomGeneratorStates(OpaqueRandomGenerator* ptr, Nd4jLong rootSeed = 0, Nd4jLong nodeSeed = 0);
ND4J_EXPORT int getRandomGeneratorRelativeInt(OpaqueRandomGenerator* ptr, Nd4jLong index);
ND4J_EXPORT Nd4jLong getRandomGeneratorRelativeLong(OpaqueRandomGenerator* ptr, Nd4jLong index);
ND4J_EXPORT void deleteRandomGenerator(OpaqueRandomGenerator* ptr);
ND4J_EXPORT const char* runLightBenchmarkSuit(bool printOut); ND4J_EXPORT const char* runLightBenchmarkSuit(bool printOut);
ND4J_EXPORT const char* runFullBenchmarkSuit(bool printOut); ND4J_EXPORT const char* runFullBenchmarkSuit(bool printOut);

View File

@ -1328,6 +1328,25 @@ nd4j::TadPack* tadOnlyShapeInfo(Nd4jLong *hXShapeInfo, int *dimension, int dimen
return pack; return pack;
} }
Nd4jLong* getPrimaryShapeInfo(nd4j::TadPack* pack) {
return pack->primaryShapeInfo();
}
Nd4jLong* getPrimaryOffsets(nd4j::TadPack* pack) {
return pack->primaryOffsets();
}
Nd4jLong* getSpecialShapeInfo(nd4j::TadPack* pack) {
return pack->specialShapeInfo();
}
Nd4jLong* getSpecialOffsets(nd4j::TadPack* pack) {
return pack->specialOffsets();
}
Nd4jLong getNumberOfTads(nd4j::TadPack* pack) {
return pack->numberOfTads();
}
int getShapeInfoLength(nd4j::TadPack* pack) {
return pack->shapeInfoLength();
}
int memcpyConstantAsync(Nd4jLong dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4jPointer reserved) { int memcpyConstantAsync(Nd4jLong dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4jPointer reserved) {
// no-op // no-op
return 0L; return 0L;
@ -2005,6 +2024,13 @@ nd4j::graph::ResultWrapper* executeFlatGraph(Nd4jPointer *extraPointers, Nd4jPoi
return nd4j::graph::GraphExecutioner::executeFlatBuffer(flatBufferPointer); return nd4j::graph::GraphExecutioner::executeFlatBuffer(flatBufferPointer);
} }
Nd4jLong getResultWrapperSize(nd4j::graph::ResultWrapper* ptr) {
return ptr->size();
}
Nd4jPointer getResultWrapperPointer(nd4j::graph::ResultWrapper* ptr) {
return ptr->pointer();
}
const char* getAllCustomOps() { const char* getAllCustomOps() {
return nd4j::ops::OpRegistrator::getInstance()->getAllCustomOperations(); return nd4j::ops::OpRegistrator::getInstance()->getAllCustomOperations();
} }
@ -2041,7 +2067,13 @@ int estimateThreshold(Nd4jPointer *extraPointers, Nd4jPointer hX, Nd4jLong *hXSh
BUILD_SINGLE_SELECTOR(xType, return estimateThresholdGeneric, (extraPointers, hX, N, threshold), FLOAT_TYPES); BUILD_SINGLE_SELECTOR(xType, return estimateThresholdGeneric, (extraPointers, hX, N, threshold), FLOAT_TYPES);
} }
Nd4jLong getShapeListSize(nd4j::ShapeList* list) {
return list->size();
}
Nd4jLong* getShape(nd4j::ShapeList* list, Nd4jLong i) {
return list->at(i);
}
void deleteShapeList(Nd4jPointer shapeList) { void deleteShapeList(Nd4jPointer shapeList) {
auto list = reinterpret_cast<nd4j::ShapeList*>(shapeList); auto list = reinterpret_cast<nd4j::ShapeList*>(shapeList);
@ -2305,6 +2337,38 @@ nd4j::graph::VariablesSet* executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLo
return nullptr; return nullptr;
} }
Nd4jLong getVariableSetSize(nd4j::graph::VariablesSet* set) {
return set->size();
}
Nd4jStatus getVariableSetStatus(nd4j::graph::VariablesSet* set) {
return set->status();
}
nd4j::graph::Variable* getVariable(nd4j::graph::VariablesSet* set, Nd4jLong i) {
return set->at(i);
}
int getVariableId(nd4j::graph::Variable* variable) {
return variable->id();
}
int getVariableIndex(nd4j::graph::Variable* variable) {
return variable->index();
}
const char* getVariableName(nd4j::graph::Variable* variable) {
return variable->getName()->c_str();
}
Nd4jLong* getVariableShape(nd4j::graph::Variable* variable) {
return variable->getNDArray()->shapeInfo();
}
void* getVariableBuffer(nd4j::graph::Variable* variable) {
return variable->getNDArray()->buffer();
}
int unregisterGraph(Nd4jPointer *extraPointers, Nd4jLong graphId) { int unregisterGraph(Nd4jPointer *extraPointers, Nd4jLong graphId) {
nd4j::graph::GraphHolder::getInstance()->dropGraphAny(graphId); nd4j::graph::GraphHolder::getInstance()->dropGraphAny(graphId);
@ -2628,6 +2692,13 @@ Nd4jPointer createUtf8String(Nd4jPointer *extraPointers, const char *string, int
return reinterpret_cast<Nd4jPointer>(u); return reinterpret_cast<Nd4jPointer>(u);
} }
Nd4jLong getUtf8StringLength(Nd4jPointer *extraPointers, Nd4jPointer ptr) {
return reinterpret_cast<nd4j::utf8string*>(ptr)->_length;
}
char* getUtf8StringBuffer(Nd4jPointer *extraPointers, Nd4jPointer ptr) {
return reinterpret_cast<nd4j::utf8string*>(ptr)->_buffer;
}
void deleteUtf8String(Nd4jPointer *extraPointers, Nd4jPointer ptr) { void deleteUtf8String(Nd4jPointer *extraPointers, Nd4jPointer ptr) {
delete(reinterpret_cast<nd4j::utf8string*>(ptr)); delete(reinterpret_cast<nd4j::utf8string*>(ptr));
} }
@ -2710,14 +2781,12 @@ nd4j::ConstantDataBuffer* shapeBuffer(int rank, Nd4jLong *shape, Nd4jLong *strid
return buffer; return buffer;
} }
void deleteShapeBuffer(Nd4jPointer ptr) { void deleteShapeBuffer(nd4j::ConstantDataBuffer* ptr) {
auto buffer = reinterpret_cast<nd4j::ConstantDataBuffer*>(ptr); delete ptr;
delete buffer;
} }
void deleteTadPack(Nd4jPointer ptr) { void deleteTadPack(nd4j::TadPack* ptr) {
auto buffer = reinterpret_cast<nd4j::TadPack*>(ptr); delete ptr;
delete buffer;
} }
nd4j::ConstantDataBuffer* constantBufferLong(nd4j::DataType dtype, Nd4jLong *data, int length) { nd4j::ConstantDataBuffer* constantBufferLong(nd4j::DataType dtype, Nd4jLong *data, int length) {
@ -2732,6 +2801,78 @@ nd4j::ConstantDataBuffer* constantBuffer(nd4j::DataType dtype, nd4j::ConstantDes
return nd4j::ConstantHelper::getInstance()->constantBuffer(*descriptor, dtype); return nd4j::ConstantHelper::getInstance()->constantBuffer(*descriptor, dtype);
} }
Nd4jPointer getConstantDataBufferPrimary(nd4j::ConstantDataBuffer* dbf) {
return dbf->primary();
}
Nd4jPointer getConstantDataBufferSpecial(nd4j::ConstantDataBuffer* dbf) {
return dbf->special();
}
Nd4jLong getConstantDataBufferLength(nd4j::ConstantDataBuffer* dbf) {
return dbf->length();
}
Nd4jLong getConstantDataBufferSizeOf(nd4j::ConstantDataBuffer* dbf) {
return dbf->sizeOf();
}
nd4j::graph::Context* createGraphContext(int nodeId) {
return new nd4j::graph::Context(nodeId);
}
nd4j::graph::RandomGenerator* getGraphContextRandomGenerator(nd4j::graph::Context* ptr) {
return &ptr->randomGenerator();
}
void markGraphContextInplace(nd4j::graph::Context* ptr, bool reallyInplace) {
ptr->markInplace(reallyInplace);
}
void setGraphContextCudaContext(nd4j::graph::Context* ptr, void *stream, void *reductionPointer, void *allocationPointer) {
}
void setGraphContextInputArray(nd4j::graph::Context* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo) {
ptr->setInputArray(index, buffer, shapeInfo, specialBuffer, specialShapeInfo);
}
void setGraphContextOutputArray(nd4j::graph::Context* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo) {
ptr->setOutputArray(index, buffer, shapeInfo, specialBuffer, specialShapeInfo);
}
void setGraphContextTArguments(nd4j::graph::Context* ptr, double *arguments, int numberOfArguments) {
ptr->setTArguments(arguments, numberOfArguments);
}
void setGraphContextIArguments(nd4j::graph::Context* ptr, Nd4jLong *arguments, int numberOfArguments) {
ptr->setIArguments(arguments, numberOfArguments);
}
void setGraphContextBArguments(nd4j::graph::Context* ptr, bool *arguments, int numberOfArguments) {
ptr->setBArguments(arguments, numberOfArguments);
}
void deleteGraphContext(nd4j::graph::Context* ptr) {
delete ptr;
}
nd4j::graph::RandomGenerator* createRandomGenerator(Nd4jLong rootSeed, Nd4jLong nodeSeed) {
return new nd4j::graph::RandomGenerator(rootSeed, nodeSeed);
}
Nd4jLong getRandomGeneratorRootState(nd4j::graph::RandomGenerator* ptr) {
return ptr->rootState();
}
Nd4jLong getRandomGeneratorNodeState(nd4j::graph::RandomGenerator* ptr) {
return ptr->nodeState();
}
void setRandomGeneratorStates(nd4j::graph::RandomGenerator* ptr, Nd4jLong rootSeed, Nd4jLong nodeSeed) {
ptr->setStates(rootSeed, nodeSeed);
}
int getRandomGeneratorRelativeInt(nd4j::graph::RandomGenerator* ptr, Nd4jLong index) {
return ptr->relativeInt(index);
}
Nd4jLong getRandomGeneratorRelativeLong(nd4j::graph::RandomGenerator* ptr, Nd4jLong index) {
return ptr->relativeLong(index);
}
void deleteRandomGenerator(nd4j::graph::RandomGenerator* ptr) {
delete ptr;
}
int dataTypeFromNpyHeader(void *header) { int dataTypeFromNpyHeader(void *header) {

View File

@ -1499,6 +1499,25 @@ nd4j::TadPack* tadOnlyShapeInfo(Nd4jLong *dXShapeInfo, int *dimension, int dimen
return pack; return pack;
} }
Nd4jLong* getPrimaryShapeInfo(nd4j::TadPack* pack) {
return pack->primaryShapeInfo();
}
Nd4jLong* getPrimaryOffsets(nd4j::TadPack* pack) {
return pack->primaryOffsets();
}
Nd4jLong* getSpecialShapeInfo(nd4j::TadPack* pack) {
return pack->specialShapeInfo();
}
Nd4jLong* getSpecialOffsets(nd4j::TadPack* pack) {
return pack->specialOffsets();
}
Nd4jLong getNumberOfTads(nd4j::TadPack* pack) {
return pack->numberOfTads();
}
int getShapeInfoLength(nd4j::TadPack* pack) {
return pack->shapeInfoLength();
}
int memcpyConstantAsync(Nd4jLong dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4jPointer reserved) { int memcpyConstantAsync(Nd4jLong dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4jPointer reserved) {
cudaStream_t *pStream = reinterpret_cast<cudaStream_t *>(reserved); cudaStream_t *pStream = reinterpret_cast<cudaStream_t *>(reserved);
@ -2533,6 +2552,13 @@ nd4j::graph::ResultWrapper* executeFlatGraph(Nd4jPointer *extraPointers, Nd4jPoi
return nd4j::graph::GraphExecutioner::executeFlatBuffer(flatBufferPointer); return nd4j::graph::GraphExecutioner::executeFlatBuffer(flatBufferPointer);
} }
Nd4jLong getResultWrapperSize(nd4j::graph::ResultWrapper* ptr) {
return ptr->size();
}
Nd4jPointer getResultWrapperPointer(nd4j::graph::ResultWrapper* ptr) {
return ptr->pointer();
}
const char* getAllCustomOps() { const char* getAllCustomOps() {
return nd4j::ops::OpRegistrator::getInstance()->getAllCustomOperations(); return nd4j::ops::OpRegistrator::getInstance()->getAllCustomOperations();
@ -2607,6 +2633,13 @@ nd4j::ShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, Nd4jLong hash
return _calculateOutputShapes(extraPointers, op, inputShapes, numInputShapes, tArgs, numTArgs, iArgs, numIArgs); return _calculateOutputShapes(extraPointers, op, inputShapes, numInputShapes, tArgs, numTArgs, iArgs, numIArgs);
} }
Nd4jLong getShapeListSize(nd4j::ShapeList* list) {
return list->size();
}
Nd4jLong* getShape(nd4j::ShapeList* list, Nd4jLong i) {
return list->at(i);
}
static FORCEINLINE Nd4jStatus realExec(nd4j::ops::DeclarableOp* op, Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputs, Nd4jPointer* outputBuffers, Nd4jPointer* outputShapes, int numOutputs, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool* bArgs, int numBArgs, bool isInplace) { static FORCEINLINE Nd4jStatus realExec(nd4j::ops::DeclarableOp* op, Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputs, Nd4jPointer* outputBuffers, Nd4jPointer* outputShapes, int numOutputs, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool* bArgs, int numBArgs, bool isInplace) {
if (op == nullptr) if (op == nullptr)
@ -2775,6 +2808,38 @@ VariablesSet* executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, N
return executeStoredGraphT(extraPointers, graphId, inputBuffers, inputShapes, inputIndices, numInputs); return executeStoredGraphT(extraPointers, graphId, inputBuffers, inputShapes, inputIndices, numInputs);
} }
Nd4jLong getVariableSetSize(nd4j::graph::VariablesSet* set) {
return set->size();
}
Nd4jStatus getVariableSetStatus(nd4j::graph::VariablesSet* set) {
return set->status();
}
nd4j::graph::Variable* getVariable(nd4j::graph::VariablesSet* set, Nd4jLong i) {
return set->at(i);
}
int getVariableId(nd4j::graph::Variable* variable) {
return variable->id();
}
int getVariableIndex(nd4j::graph::Variable* variable) {
return variable->index();
}
const char* getVariableName(nd4j::graph::Variable* variable) {
return variable->getName()->c_str();
}
Nd4jLong* getVariableShape(nd4j::graph::Variable* variable) {
return variable->getNDArray()->shapeInfo();
}
void* getVariableBuffer(nd4j::graph::Variable* variable) {
return variable->getNDArray()->buffer();
}
int unregisterGraph(Nd4jPointer *extraPointers, Nd4jLong graphId) { int unregisterGraph(Nd4jPointer *extraPointers, Nd4jLong graphId) {
nd4j::graph::GraphHolder::getInstance()->dropGraphAny(graphId); nd4j::graph::GraphHolder::getInstance()->dropGraphAny(graphId);
@ -3102,6 +3167,13 @@ Nd4jPointer createUtf8String(Nd4jPointer *extraPointers, const char *string, int
return reinterpret_cast<Nd4jPointer>(u); return reinterpret_cast<Nd4jPointer>(u);
} }
Nd4jLong getUtf8StringLength(Nd4jPointer *extraPointers, Nd4jPointer ptr) {
return reinterpret_cast<nd4j::utf8string*>(ptr)->_length;
}
char* getUtf8StringBuffer(Nd4jPointer *extraPointers, Nd4jPointer ptr) {
return reinterpret_cast<nd4j::utf8string*>(ptr)->_buffer;
}
void deleteUtf8String(Nd4jPointer *extraPointers, Nd4jPointer ptr) { void deleteUtf8String(Nd4jPointer *extraPointers, Nd4jPointer ptr) {
delete(reinterpret_cast<nd4j::utf8string*>(ptr)); delete(reinterpret_cast<nd4j::utf8string*>(ptr));
} }
@ -3237,14 +3309,12 @@ nd4j::ConstantDataBuffer* shapeBuffer(int rank, Nd4jLong *shape, Nd4jLong *strid
return buffer; return buffer;
} }
void deleteShapeBuffer(Nd4jPointer ptr) { void deleteShapeBuffer(nd4j::ConstantDataBuffer* ptr) {
auto buffer = reinterpret_cast<nd4j::ConstantDataBuffer*>(ptr); delete ptr;
delete buffer;
} }
void deleteTadPack(Nd4jPointer ptr) { void deleteTadPack(nd4j::TadPack* ptr) {
auto buffer = reinterpret_cast<nd4j::TadPack*>(ptr); delete ptr;
delete buffer;
} }
nd4j::ConstantDataBuffer* constantBufferLong(nd4j::DataType dtype, Nd4jLong *data, int length) { nd4j::ConstantDataBuffer* constantBufferLong(nd4j::DataType dtype, Nd4jLong *data, int length) {
@ -3259,6 +3329,82 @@ nd4j::ConstantDataBuffer* constantBuffer(nd4j::DataType dtype, nd4j::ConstantDes
return nd4j::ConstantHelper::getInstance()->constantBuffer(*descriptor, dtype); return nd4j::ConstantHelper::getInstance()->constantBuffer(*descriptor, dtype);
} }
Nd4jPointer getConstantDataBufferPrimary(nd4j::ConstantDataBuffer* dbf) {
return dbf->primary();
}
Nd4jPointer getConstantDataBufferSpecial(nd4j::ConstantDataBuffer* dbf) {
return dbf->special();
}
Nd4jLong getConstantDataBufferLength(nd4j::ConstantDataBuffer* dbf) {
return dbf->length();
}
Nd4jLong getConstantDataBufferSizeOf(nd4j::ConstantDataBuffer* dbf) {
return dbf->sizeOf();
}
nd4j::graph::Context* createGraphContext(int nodeId) {
return new nd4j::graph::Context(nodeId);
}
nd4j::graph::RandomGenerator* getGraphContextRandomGenerator(nd4j::graph::Context* ptr) {
return &ptr->randomGenerator();
}
void markGraphContextInplace(nd4j::graph::Context* ptr, bool reallyInplace) {
ptr->markInplace(reallyInplace);
}
void setGraphContextCudaContext(nd4j::graph::Context* ptr, void *stream, void *reductionPointer, void *allocationPointer) {
ptr->setCudaContext(stream, reductionPointer, allocationPointer);
}
void setGraphContextInputArray(nd4j::graph::Context* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo) {
ptr->setInputArray(index, buffer, shapeInfo, specialBuffer, specialShapeInfo);
}
void setGraphContextOutputArray(nd4j::graph::Context* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo) {
ptr->setOutputArray(index, buffer, shapeInfo, specialBuffer, specialShapeInfo);
}
void setGraphContextTArguments(nd4j::graph::Context* ptr, double *arguments, int numberOfArguments) {
ptr->setTArguments(arguments, numberOfArguments);
}
void setGraphContextIArguments(nd4j::graph::Context* ptr, Nd4jLong *arguments, int numberOfArguments) {
ptr->setIArguments(arguments, numberOfArguments);
}
void setGraphContextBArguments(nd4j::graph::Context* ptr, bool *arguments, int numberOfArguments) {
ptr->setBArguments(arguments, numberOfArguments);
}
void deleteGraphContext(nd4j::graph::Context* ptr) {
delete ptr;
}
nd4j::graph::RandomGenerator* createRandomGenerator(Nd4jLong rootSeed, Nd4jLong nodeSeed) {
return new nd4j::graph::RandomGenerator(rootSeed, nodeSeed);
}
Nd4jLong getRandomGeneratorRootState(nd4j::graph::RandomGenerator* ptr) {
return ptr->rootState();
}
Nd4jLong getRandomGeneratorNodeState(nd4j::graph::RandomGenerator* ptr) {
return ptr->nodeState();
}
void setRandomGeneratorStates(nd4j::graph::RandomGenerator* ptr, Nd4jLong rootSeed, Nd4jLong nodeSeed) {
ptr->setStates(rootSeed, nodeSeed);
}
int getRandomGeneratorRelativeInt(nd4j::graph::RandomGenerator* ptr, Nd4jLong index) {
return ptr->relativeInt(index);
}
Nd4jLong getRandomGeneratorRelativeLong(nd4j::graph::RandomGenerator* ptr, Nd4jLong index) {
return ptr->relativeLong(index);
}
void deleteRandomGenerator(nd4j::graph::RandomGenerator* ptr) {
delete ptr;
}
Nd4jPointer shapeBufferForNumpy(Nd4jPointer npyArray) { Nd4jPointer shapeBufferForNumpy(Nd4jPointer npyArray) {
cnpy::NpyArray arr = cnpy::loadNpyFromPointer(reinterpret_cast<char *>(npyArray)); cnpy::NpyArray arr = cnpy::loadNpyFromPointer(reinterpret_cast<char *>(npyArray));
unsigned int shapeSize = arr.shape.size(); unsigned int shapeSize = arr.shape.size();

View File

@ -27,6 +27,9 @@
#include <ops/declarable/CustomOperations.h> #include <ops/declarable/CustomOperations.h>
#include <graph/GraphUtils.h> #include <graph/GraphUtils.h>
using namespace nd4j::ops;
using namespace nd4j::graph;
int int
main(int argc, char *argv[]) { main(int argc, char *argv[]) {
// this string will contain list of operations // this string will contain list of operations

View File

@ -25,6 +25,7 @@
#include <ops/declarable/CustomOperations.h> #include <ops/declarable/CustomOperations.h>
using namespace nd4j; using namespace nd4j;
using namespace nd4j::ops;
using namespace nd4j::graph; using namespace nd4j::graph;
class OpTrackerTests : public testing::Test { class OpTrackerTests : public testing::Test {

View File

@ -28,7 +28,7 @@ import java.util.List;
* *
* @author raver119@gmail.com * @author raver119@gmail.com
*/ */
public interface OpContext { public interface OpContext extends AutoCloseable {
/** /**
* This method sets integer arguments required for operation * This method sets integer arguments required for operation

View File

@ -38,8 +38,9 @@ import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder; import org.nd4j.nativeblas.NativeOpsHolder;
import org.nd4j.nativeblas.ResultWrapperAbstraction; import org.nd4j.nativeblas.OpaqueResultWrapper;
import java.io.File; import java.io.File;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
@ -100,11 +101,12 @@ public class NativeGraphExecutioner implements GraphExecutioner {
log.info("Buffer length: {}", buffer.limit()); log.info("Buffer length: {}", buffer.limit());
val res = NativeOpsHolder.getInstance().getDeviceNativeOps().executeFlatGraph(null, bPtr); NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
OpaqueResultWrapper res = nativeOps.executeFlatGraph(null, bPtr);
if (res == null) if (res == null)
throw new ND4JIllegalStateException("Graph execution failed"); throw new ND4JIllegalStateException("Graph execution failed");
PagedPointer pagedPointer = new PagedPointer(res.pointer(),res.size()); PagedPointer pagedPointer = new PagedPointer(nativeOps.getResultWrapperPointer(res), nativeOps.getResultWrapperSize(res));
FlatResult fr = FlatResult.getRootAsFlatResult(pagedPointer.asBytePointer().asByteBuffer()); FlatResult fr = FlatResult.getRootAsFlatResult(pagedPointer.asBytePointer().asByteBuffer());
log.info("VarMap: {}", sd.variableMap()); log.info("VarMap: {}", sd.variableMap());
@ -132,7 +134,7 @@ public class NativeGraphExecutioner implements GraphExecutioner {
} }
// now we need to release native memory // now we need to release native memory
NativeOpsHolder.getInstance().getDeviceNativeOps().deleteResultWrapper(res); nativeOps.deleteResultWrapper(res);
return results; return results;
} }

View File

@ -697,7 +697,16 @@ public interface NativeOps {
void setGridLimit(int gridSize); void setGridLimit(int gridSize);
Pointer tadOnlyShapeInfo(@Cast("Nd4jLong *") LongPointer shapeInfo, IntPointer dimension, int dimensionLength); OpaqueTadPack tadOnlyShapeInfo(LongPointer shapeInfo, IntPointer dimension, int dimensionLength);
LongPointer getPrimaryShapeInfo(OpaqueTadPack pack);
LongPointer getPrimaryOffsets(OpaqueTadPack pack);
LongPointer getSpecialShapeInfo(OpaqueTadPack pack);
LongPointer getSpecialOffsets(OpaqueTadPack pack);
long getNumberOfTads(OpaqueTadPack pack);
int getShapeInfoLength(OpaqueTadPack pack);
void deleteTadPack(OpaqueTadPack pointer);
/////////////// ///////////////
@ -1037,7 +1046,10 @@ public interface NativeOps {
void munmapFile(PointerPointer extraPointers, LongPointer ptrMap, long length); void munmapFile(PointerPointer extraPointers, LongPointer ptrMap, long length);
ResultWrapperAbstraction executeFlatGraph(PointerPointer extraPointers, Pointer flatBufferPointer); OpaqueResultWrapper executeFlatGraph(PointerPointer extraPointers, Pointer flatBufferPointer);
long getResultWrapperSize(OpaqueResultWrapper ptr);
Pointer getResultWrapperPointer(OpaqueResultWrapper ptr);
String getAllCustomOps(); String getAllCustomOps();
@ -1047,13 +1059,25 @@ public interface NativeOps {
int execCustomOp(PointerPointer extraPointers, long opHashCode, PointerPointer inputBuffers, PointerPointer inputShapes, int numInput, PointerPointer outputBuffers, PointerPointer outputShapes, int numOutputs, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong *") LongPointer iArgs, int numIArgs, @Cast("bool *") BooleanPointer bArgs, int numBArgs, boolean isInplace); int execCustomOp(PointerPointer extraPointers, long opHashCode, PointerPointer inputBuffers, PointerPointer inputShapes, int numInput, PointerPointer outputBuffers, PointerPointer outputShapes, int numOutputs, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong *") LongPointer iArgs, int numIArgs, @Cast("bool *") BooleanPointer bArgs, int numBArgs, boolean isInplace);
Pointer calculateOutputShapes(PointerPointer extraPointers, long hash, PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong *") LongPointer iArgs, int numIArgs); OpaqueShapeList calculateOutputShapes(PointerPointer extraPointers, long hash, PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong *") LongPointer iArgs, int numIArgs);
Pointer calculateOutputShapes2(PointerPointer extraPointers, long hash, PointerPointer inputBunffers, PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong *") LongPointer iArgs, int numIArgs, @Cast("bool *") BooleanPointer bArgs, int numBArgs); OpaqueShapeList calculateOutputShapes2(PointerPointer extraPointers, long hash, PointerPointer inputBunffers, PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong *") LongPointer iArgs, int numIArgs, @Cast("bool *") BooleanPointer bArgs, int numBArgs);
long getShapeListSize(OpaqueShapeList list);
LongPointer getShape(OpaqueShapeList list, long i);
int registerGraph(PointerPointer extraPointers, long graphId, Pointer flatBufferPointer); int registerGraph(PointerPointer extraPointers, long graphId, Pointer flatBufferPointer);
Pointer executeStoredGraph(PointerPointer extraPointers, long graphId, PointerPointer inputBuffers, PointerPointer inputShapes, IntPointer inputIndices, int numInputs); OpaqueVariableSet executeStoredGraph(PointerPointer extraPointers, long graphId, PointerPointer inputBuffers, PointerPointer inputShapes, IntPointer inputIndices, int numInputs);
long getVariableSetSize(OpaqueVariableSet set);
int getVariableSetStatus(OpaqueVariableSet set);
OpaqueVariable getVariable(OpaqueVariableSet set, long i);
int getVariableId(OpaqueVariable variable);
int getVariableIndex(OpaqueVariable variable);
String getVariableName(OpaqueVariable variable);
LongPointer getVariableShape(OpaqueVariable variable);
Pointer getVariableBuffer(OpaqueVariable variable);
void deleteResultWrapper(Pointer ptr); void deleteResultWrapper(Pointer ptr);
@ -1071,15 +1095,11 @@ public interface NativeOps {
void deleteNPArrayMap(Pointer pointer); void deleteNPArrayMap(Pointer pointer);
void deleteVariablesSet(Pointer pointer); void deleteVariablesSet(OpaqueVariableSet pointer);
// GraphState creation // GraphState creation
Pointer getGraphState(long id); Pointer getGraphState(long id);
void deleteShapeBuffer(Pointer state);
void deleteTadPack(Pointer pointer);
void deleteGraphState(Pointer state); void deleteGraphState(Pointer state);
int estimateThreshold(PointerPointer extraPointers, Pointer x, LongPointer xShapeInfo, int N, float threshold); int estimateThreshold(PointerPointer extraPointers, Pointer x, LongPointer xShapeInfo, int N, float threshold);
@ -1096,6 +1116,8 @@ public interface NativeOps {
//void fillUtf8String(PointerPointer extraPointers, String[] string, int numStrings, Pointer buffer); //void fillUtf8String(PointerPointer extraPointers, String[] string, int numStrings, Pointer buffer);
Pointer createUtf8String(PointerPointer extraPointers, String string, int length); Pointer createUtf8String(PointerPointer extraPointers, String string, int length);
long getUtf8StringLength(PointerPointer extraPointers, Pointer ptr);
BytePointer getUtf8StringBuffer(PointerPointer extraPointers, Pointer ptr);
void deleteUtf8String(PointerPointer extraPointers, Pointer ptr); void deleteUtf8String(PointerPointer extraPointers, Pointer ptr);
@ -1116,11 +1138,37 @@ public interface NativeOps {
*/ */
int dataTypeFromNpyHeader(Pointer numpyHeader); int dataTypeFromNpyHeader(Pointer numpyHeader);
Pointer shapeBuffer(int rank, @Cast("Nd4jLong *") LongPointer shape, @Cast("Nd4jLong *") LongPointer strides, int dtype, char order, long ews, boolean empty); OpaqueConstantDataBuffer shapeBuffer(int rank, LongPointer shape, LongPointer strides, int dtype, char order, long ews, boolean empty);
Pointer constantBufferDouble(int dtype, DoublePointer data, int length); OpaqueConstantDataBuffer constantBufferDouble(int dtype, DoublePointer data, int length);
Pointer constantBufferLong(int dtype, @Cast("Nd4jLong *") LongPointer data, int length); OpaqueConstantDataBuffer constantBufferLong(int dtype, LongPointer data, int length);
Pointer getConstantDataBufferPrimary(OpaqueConstantDataBuffer dbf);
Pointer getConstantDataBufferSpecial(OpaqueConstantDataBuffer dbf);
long getConstantDataBufferLength(OpaqueConstantDataBuffer dbf);
long getConstantDataBufferSizeOf(OpaqueConstantDataBuffer dbf);
void deleteShapeBuffer(OpaqueConstantDataBuffer state);
OpaqueContext createGraphContext(int nodeId);
OpaqueRandomGenerator getGraphContextRandomGenerator(OpaqueContext ptr);
void markGraphContextInplace(OpaqueContext ptr, boolean reallyInplace);
void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer);
void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);
void setGraphContextOutputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);
void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, int numberOfArguments);
void setGraphContextIArguments(OpaqueContext ptr, LongPointer arguments, int numberOfArguments);
void setGraphContextBArguments(OpaqueContext ptr, BooleanPointer arguments, int numberOfArguments);
void deleteGraphContext(OpaqueContext ptr);
OpaqueRandomGenerator createRandomGenerator(long rootSeed, long nodeSeed);
long getRandomGeneratorRootState(OpaqueRandomGenerator ptr);
long getRandomGeneratorNodeState(OpaqueRandomGenerator ptr);
void setRandomGeneratorStates(OpaqueRandomGenerator ptr, @Cast("Nd4jLong") long rootSeed/*=0*/, @Cast("Nd4jLong") long nodeSeed/*=0*/);
int getRandomGeneratorRelativeInt(OpaqueRandomGenerator ptr, @Cast("Nd4jLong") long index);
long getRandomGeneratorRelativeLong(OpaqueRandomGenerator ptr, @Cast("Nd4jLong") long index);
void deleteRandomGenerator(OpaqueRandomGenerator ptr);
String runLightBenchmarkSuit(boolean printOut); String runLightBenchmarkSuit(boolean printOut);

View File

@ -0,0 +1,27 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.nativeblas;
import org.bytedeco.javacpp.Pointer;
/**
*
* @author saudet
*/
public class OpaqueConstantDataBuffer extends Pointer {
public OpaqueConstantDataBuffer(Pointer p) { super(p); }
}

View File

@ -0,0 +1,27 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.nativeblas;
import org.bytedeco.javacpp.Pointer;
/**
*
* @author saudet
*/
public class OpaqueContext extends Pointer {
public OpaqueContext(Pointer p) { super(p); }
}

View File

@ -0,0 +1,27 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.nativeblas;
import org.bytedeco.javacpp.Pointer;
/**
*
* @author saudet
*/
public class OpaqueRandomGenerator extends Pointer {
public OpaqueRandomGenerator(Pointer p) { super(p); }
}

View File

@ -0,0 +1,27 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.nativeblas;
import org.bytedeco.javacpp.Pointer;
/**
*
* @author saudet
*/
public class OpaqueResultWrapper extends Pointer {
public OpaqueResultWrapper(Pointer p) { super(p); }
}

View File

@ -0,0 +1,27 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.nativeblas;
import org.bytedeco.javacpp.Pointer;
/**
*
* @author saudet
*/
public class OpaqueShapeList extends Pointer {
public OpaqueShapeList(Pointer p) { super(p); }
}

View File

@ -0,0 +1,27 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.nativeblas;
import org.bytedeco.javacpp.Pointer;
/**
*
* @author saudet
*/
public class OpaqueTadPack extends Pointer {
public OpaqueTadPack(Pointer p) { super(p); }
}

View File

@ -0,0 +1,27 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.nativeblas;
import org.bytedeco.javacpp.Pointer;
/**
*
* @author saudet
*/
public class OpaqueVariable extends Pointer {
public OpaqueVariable(Pointer p) { super(p); }
}

View File

@ -0,0 +1,27 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.nativeblas;
import org.bytedeco.javacpp.Pointer;
/**
*
* @author saudet
*/
public class OpaqueVariableSet extends Pointer {
public OpaqueVariableSet(Pointer p) { super(p); }
}

View File

@ -72,6 +72,11 @@ import org.nd4j.nativeblas.LongPointerWrapper;
import org.nd4j.nativeblas.NativeOps; import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder; import org.nd4j.nativeblas.NativeOpsHolder;
import org.nd4j.nativeblas.Nd4jCuda; import org.nd4j.nativeblas.Nd4jCuda;
import org.nd4j.nativeblas.OpaqueConstantDataBuffer;
import org.nd4j.nativeblas.OpaqueShapeList;
import org.nd4j.nativeblas.OpaqueTadPack;
import org.nd4j.nativeblas.OpaqueVariable;
import org.nd4j.nativeblas.OpaqueVariableSet;
import java.util.*; import java.util.*;
@ -2208,13 +2213,13 @@ public class CudaExecutioner extends DefaultOpExecutioner {
for (val t: op.tArgs()) for (val t: op.tArgs())
tArgs.put(cnt++, (float) t); tArgs.put(cnt++, (float) t);
val ptrptr = (Nd4jCuda.ShapeList) nativeOps.calculateOutputShapes2(null, hash, inputBuffers, inputShapes, op.inputArguments().length, tArgs, op.tArgs().length, iArgs, op.iArgs().length, bArgs, op.numBArguments()); OpaqueShapeList ptrptr = nativeOps.calculateOutputShapes2(null, hash, inputBuffers, inputShapes, op.inputArguments().length, tArgs, op.tArgs().length, iArgs, op.iArgs().length, bArgs, op.numBArguments());
if (ptrptr == null) if (ptrptr == null)
throw new RuntimeException(); throw new RuntimeException();
for (int e = 0; e < ptrptr.size(); e++ ) for (int e = 0; e < nativeOps.getShapeListSize(ptrptr); e++ )
result.add(getShapeFromPointer(new PagedPointer(ptrptr.at(e)).asLongPointer())); result.add(getShapeFromPointer(new PagedPointer(nativeOps.getShape(ptrptr, e)).asLongPointer()));
nativeOps.deleteShapeList(ptrptr); nativeOps.deleteShapeList(ptrptr);
@ -2251,28 +2256,32 @@ public class CudaExecutioner extends DefaultOpExecutioner {
val ctx = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); val ctx = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();
val context = (CudaOpContext) buildContext(); val name = op.opName();
context.markInplace(op.isInplaceCall()); try (val context = (CudaOpContext) buildContext()) {
context.markInplace(op.isInplaceCall());
// transferring rng state // transferring rng state
context.setRngStates(Nd4j.getRandom().rootState(), Nd4j.getRandom().nodeState()); context.setRngStates(Nd4j.getRandom().rootState(), Nd4j.getRandom().nodeState());
//transferring input/output arrays //transferring input/output arrays
context.setInputArrays(op.inputArguments()); context.setInputArrays(op.inputArguments());
context.setOutputArrays(op.outputArguments()); context.setOutputArrays(op.outputArguments());
// transferring static args // transferring static args
context.setBArguments(op.bArgs()); context.setBArguments(op.bArgs());
context.setIArguments(op.iArgs()); context.setIArguments(op.iArgs());
context.setTArguments(op.tArgs()); context.setTArguments(op.tArgs());
val result = exec(op, context); val result = exec(op, context);
val states = context.getRngStates(); val states = context.getRngStates();
// pulling states back // pulling states back
Nd4j.getRandom().setStates(states.getFirst(), states.getSecond()); Nd4j.getRandom().setStates(states.getFirst(), states.getSecond());
return result; return result;
} catch (Exception e) {
throw new RuntimeException("Op [" + name + "] execution failed", e);
}
/* /*
long st = profilingConfigurableHookIn(op); long st = profilingConfigurableHookIn(op);
@ -2418,19 +2427,19 @@ public class CudaExecutioner extends DefaultOpExecutioner {
val newMap = new LinkedHashMap<String, INDArray>(); val newMap = new LinkedHashMap<String, INDArray>();
val result = (Nd4jCuda.VariablesSet) nativeOps.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size()); OpaqueVariableSet result = nativeOps.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size());
val status = OpStatus.byNumber(result.status()); OpStatus status = OpStatus.byNumber(nativeOps.getVariableSetStatus(result));
if (status != OpStatus.ND4J_STATUS_OK) if (status != OpStatus.ND4J_STATUS_OK)
throw new ND4JIllegalStateException("Op execution failed: " + status); throw new ND4JIllegalStateException("Op execution failed: " + status);
for (int e = 0; e < result.size(); e++) { for (int e = 0; e < nativeOps.getVariableSetSize(result); e++) {
val var = result.at(e); OpaqueVariable var = nativeOps.getVariable(result, e);
val nodeId = var.id(); int nodeId = nativeOps.getVariableId(var);
val index = var.index(); int index = nativeOps.getVariableIndex(var);
val shapeInfo = var.getNDArray().shapeInfo(); LongPointer shapeInfo = nativeOps.getVariableShape(var);
val buffer = var.getNDArray().buffer(); Pointer buffer = nativeOps.getVariableBuffer(var);
val rank = (int) shapeInfo.get(0); val rank = (int) shapeInfo.get(0);
val jshape = new long[rank * 2 + 4]; val jshape = new long[rank * 2 + 4];
@ -2446,7 +2455,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
Pointer.memcpy(AtomicAllocator.getInstance().getHostPointer(array), buffer, ArrayUtil.prod(shapeOf) * Nd4j.sizeOfDataType()); Pointer.memcpy(AtomicAllocator.getInstance().getHostPointer(array), buffer, ArrayUtil.prod(shapeOf) * Nd4j.sizeOfDataType());
AtomicAllocator.getInstance().getAllocationPoint(array).tickHostWrite(); AtomicAllocator.getInstance().getAllocationPoint(array).tickHostWrite();
val nodeName = var.getName().getString(); String nodeName = nativeOps.getVariableName(var);
newMap.put(nodeName, array); newMap.put(nodeName, array);
} }
@ -2584,9 +2593,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
@Override @Override
public DataBuffer createShapeInfo(long[] shape, long[] stride, long elementWiseStride, char order, DataType dtype, boolean empty) { public DataBuffer createShapeInfo(long[] shape, long[] stride, long elementWiseStride, char order, DataType dtype, boolean empty) {
val dbf = (Nd4jCuda.ConstantDataBuffer) nativeOps.shapeBuffer(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, empty); OpaqueConstantDataBuffer dbf = nativeOps.shapeBuffer(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, empty);
val result = new CudaLongDataBuffer(dbf.primary(), dbf.special(), Shape.shapeInfoLength(shape.length)); val result = new CudaLongDataBuffer(nativeOps.getConstantDataBufferPrimary(dbf), nativeOps.getConstantDataBufferSpecial(dbf), Shape.shapeInfoLength(shape.length));
nativeOps.deleteShapeBuffer(dbf); nativeOps.deleteShapeBuffer(dbf);
@ -2595,10 +2604,10 @@ public class CudaExecutioner extends DefaultOpExecutioner {
@Override @Override
public TadPack tadShapeInfoAndOffsets(INDArray array, int[] dimension) { public TadPack tadShapeInfoAndOffsets(INDArray array, int[] dimension) {
val pack = (Nd4jCuda.TadPack) nativeOps.tadOnlyShapeInfo((LongPointer) array.shapeInfoDataBuffer().addressPointer(), new IntPointer(dimension), dimension.length); OpaqueTadPack pack = nativeOps.tadOnlyShapeInfo((LongPointer) array.shapeInfoDataBuffer().addressPointer(), new IntPointer(dimension), dimension.length);
val tadShape = new CudaLongDataBuffer(pack.primaryShapeInfo(), pack.specialShapeInfo(), pack.shapeInfoLength()); val tadShape = new CudaLongDataBuffer(nativeOps.getPrimaryShapeInfo(pack), nativeOps.getSpecialShapeInfo(pack), nativeOps.getShapeInfoLength(pack));
val tadOffsets = new CudaLongDataBuffer(pack.primaryOffsets(), pack.specialOffsets(), pack.numberOfTads()); val tadOffsets = new CudaLongDataBuffer(nativeOps.getPrimaryOffsets(pack), nativeOps.getSpecialOffsets(pack), nativeOps.getNumberOfTads(pack));
nativeOps.deleteTadPack(pack); nativeOps.deleteTadPack(pack);
@ -2607,9 +2616,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
@Override @Override
public DataBuffer createConstantBuffer(long[] values, DataType desiredType) { public DataBuffer createConstantBuffer(long[] values, DataType desiredType) {
val dbf = (Nd4jCuda.ConstantDataBuffer) nativeOps.constantBufferLong(desiredType.toInt(), new LongPointer(values), values.length); OpaqueConstantDataBuffer dbf = nativeOps.constantBufferLong(desiredType.toInt(), new LongPointer(values), values.length);
val buffer = Nd4j.createBuffer(dbf.primary(), dbf.special(), values.length, desiredType); val buffer = Nd4j.createBuffer(nativeOps.getConstantDataBufferPrimary(dbf), nativeOps.getConstantDataBufferSpecial(dbf), values.length, desiredType);
buffer.setConstant(true); buffer.setConstant(true);
return buffer; return buffer;
@ -2617,9 +2626,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
@Override @Override
public DataBuffer createConstantBuffer(double[] values, DataType desiredType) { public DataBuffer createConstantBuffer(double[] values, DataType desiredType) {
val dbf = (Nd4jCuda.ConstantDataBuffer) nativeOps.constantBufferDouble(desiredType.toInt(), new DoublePointer(values), values.length); OpaqueConstantDataBuffer dbf = nativeOps.constantBufferDouble(desiredType.toInt(), new DoublePointer(values), values.length);
val buffer = Nd4j.createBuffer(dbf.primary(), dbf.special(), values.length, desiredType); val buffer = Nd4j.createBuffer(nativeOps.getConstantDataBufferPrimary(dbf), nativeOps.getConstantDataBufferSpecial(dbf), values.length, desiredType);
buffer.setConstant(true); buffer.setConstant(true);
return buffer; return buffer;

View File

@ -18,6 +18,9 @@ package org.nd4j.linalg.jcublas.ops.executioner;
import lombok.NonNull; import lombok.NonNull;
import lombok.val; import lombok.val;
import org.bytedeco.javacpp.BooleanPointer;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.impl.AtomicAllocator; import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t; import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t;
@ -28,7 +31,10 @@ import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.context.CudaContext; import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import org.nd4j.nativeblas.Nd4jCuda; import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.nd4j.nativeblas.OpaqueContext;
import org.nd4j.nativeblas.OpaqueRandomGenerator;
/** /**
* CUDA wrapper for op Context * CUDA wrapper for op Context
@ -36,34 +42,41 @@ import org.nd4j.nativeblas.Nd4jCuda;
*/ */
public class CudaOpContext extends BaseOpContext implements OpContext { public class CudaOpContext extends BaseOpContext implements OpContext {
// we might want to have configurable // we might want to have configurable
private Nd4jCuda.Context context = new Nd4jCuda.Context(1); private NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
private OpaqueContext context = nativeOps.createGraphContext(1);
@Override
public void close() {
nativeOps.deleteGraphContext(context);
}
@Override @Override
public void setIArguments(long... arguments) { public void setIArguments(long... arguments) {
super.setIArguments(arguments); super.setIArguments(arguments);
context.setIArguments(arguments, arguments.length); nativeOps.setGraphContextIArguments(context, new LongPointer(arguments), arguments.length);
} }
@Override @Override
public void setBArguments(boolean... arguments) { public void setBArguments(boolean... arguments) {
super.setBArguments(arguments); super.setBArguments(arguments);
context.setBArguments(arguments, arguments.length); nativeOps.setGraphContextBArguments(context, new BooleanPointer(arguments), arguments.length);
} }
@Override @Override
public void setTArguments(double... arguments) { public void setTArguments(double... arguments) {
super.setTArguments(arguments); super.setTArguments(arguments);
context.setTArguments(arguments, arguments.length); nativeOps.setGraphContextTArguments(context, new DoublePointer(arguments), arguments.length);
} }
@Override @Override
public void setRngStates(long rootState, long nodeState) { public void setRngStates(long rootState, long nodeState) {
context.randomGenerator().setStates(rootState, nodeState); nativeOps.setRandomGeneratorStates(nativeOps.getGraphContextRandomGenerator(context), rootState, nodeState);
} }
@Override @Override
public Pair<Long, Long> getRngStates() { public Pair<Long, Long> getRngStates() {
return Pair.makePair(context.randomGenerator().rootState(), context.randomGenerator().nodeState()); OpaqueRandomGenerator g = nativeOps.getGraphContextRandomGenerator(context);
return Pair.makePair(nativeOps.getRandomGeneratorRootState(g), nativeOps.getRandomGeneratorNodeState(g));
} }
@Override @Override
@ -72,7 +85,7 @@ public class CudaOpContext extends BaseOpContext implements OpContext {
Nd4j.getAffinityManager().ensureLocation(array, AffinityManager.Location.EVERYWHERE); Nd4j.getAffinityManager().ensureLocation(array, AffinityManager.Location.EVERYWHERE);
val ctx = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); val ctx = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();
context.setInputArray(index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), array.isEmpty() ? null : AtomicAllocator.getInstance().getPointer(array, ctx), AtomicAllocator.getInstance().getPointer(array.shapeInfoDataBuffer())); nativeOps.setGraphContextInputArray(context, index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), array.isEmpty() ? null : AtomicAllocator.getInstance().getPointer(array, ctx), AtomicAllocator.getInstance().getPointer(array.shapeInfoDataBuffer()));
super.setInputArray(index, array); super.setInputArray(index, array);
} }
@ -82,7 +95,7 @@ public class CudaOpContext extends BaseOpContext implements OpContext {
Nd4j.getAffinityManager().ensureLocation(array, AffinityManager.Location.EVERYWHERE); Nd4j.getAffinityManager().ensureLocation(array, AffinityManager.Location.EVERYWHERE);
val ctx = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); val ctx = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();
context.setOutputArray(index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), array.isEmpty() ? null : AtomicAllocator.getInstance().getPointer(array, ctx), AtomicAllocator.getInstance().getPointer(array.shapeInfoDataBuffer())); nativeOps.setGraphContextOutputArray(context, index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), array.isEmpty() ? null : AtomicAllocator.getInstance().getPointer(array, ctx), AtomicAllocator.getInstance().getPointer(array.shapeInfoDataBuffer()));
super.setOutputArray(index, array); super.setOutputArray(index, array);
} }
@ -113,11 +126,11 @@ public class CudaOpContext extends BaseOpContext implements OpContext {
public void setCudaStream(cudaStream_t stream, Pointer reductionPointer, Pointer allocationPointer) { public void setCudaStream(cudaStream_t stream, Pointer reductionPointer, Pointer allocationPointer) {
context.setCudaContext(stream, reductionPointer, allocationPointer); nativeOps.setGraphContextCudaContext(context, stream, reductionPointer, allocationPointer);
} }
@Override @Override
public void markInplace(boolean reallyInplace) { public void markInplace(boolean reallyInplace) {
context.markInplace(reallyInplace); nativeOps.markGraphContextInplace(context, reallyInplace);
} }
} }

View File

@ -21,7 +21,9 @@ import org.bytedeco.javacpp.PointerPointer;
import org.nd4j.jita.allocator.impl.AtomicAllocator; import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.jcublas.context.CudaContext; import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.nativeblas.Nd4jCuda; import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.nd4j.nativeblas.OpaqueRandomGenerator;
import org.nd4j.rng.NativeRandom; import org.nd4j.rng.NativeRandom;
import java.util.List; import java.util.List;
@ -33,7 +35,7 @@ import java.util.List;
*/ */
@Slf4j @Slf4j
public class CudaNativeRandom extends NativeRandom { public class CudaNativeRandom extends NativeRandom {
private NativeOps nativeOps;
protected List<DataBuffer> stateBuffers; protected List<DataBuffer> stateBuffers;
public CudaNativeRandom() { public CudaNativeRandom() {
@ -50,10 +52,16 @@ public class CudaNativeRandom extends NativeRandom {
@Override @Override
public void init() { public void init() {
statePointer = new Nd4jCuda.RandomGenerator(seed, seed ^ 0xdeadbeef); nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
statePointer = nativeOps.createRandomGenerator(this.seed, this.seed ^ 0xdeadbeef);
setSeed(seed); setSeed(seed);
} }
@Override
public void close() {
nativeOps.deleteRandomGenerator((OpaqueRandomGenerator)statePointer);
}
@Override @Override
public PointerPointer getExtraPointers() { public PointerPointer getExtraPointers() {
return null; return null;
@ -63,7 +71,7 @@ public class CudaNativeRandom extends NativeRandom {
public void setSeed(long seed) { public void setSeed(long seed) {
this.seed = seed; this.seed = seed;
this.currentPosition.set(0); this.currentPosition.set(0);
((Nd4jCuda.RandomGenerator) statePointer).setStates(seed, seed ^ 0xdeadbeef); nativeOps.setRandomGeneratorStates((OpaqueRandomGenerator)statePointer, seed, seed ^ 0xdeadbeef);
} }
@Override @Override
@ -73,24 +81,24 @@ public class CudaNativeRandom extends NativeRandom {
@Override @Override
public int nextInt() { public int nextInt() {
return ((Nd4jCuda.RandomGenerator) statePointer).relativeInt(currentPosition.getAndIncrement()); return nativeOps.getRandomGeneratorRelativeInt((OpaqueRandomGenerator)statePointer, currentPosition.getAndIncrement());
} }
@Override @Override
public long nextLong() { public long nextLong() {
return ((Nd4jCuda.RandomGenerator) statePointer).relativeLong(currentPosition.getAndIncrement()); return nativeOps.getRandomGeneratorRelativeLong((OpaqueRandomGenerator)statePointer, currentPosition.getAndIncrement());
} }
public long rootState() { public long rootState() {
return ((Nd4jCuda.RandomGenerator) statePointer).rootState(); return nativeOps.getRandomGeneratorRootState((OpaqueRandomGenerator)statePointer);
} }
public long nodeState() { public long nodeState() {
return ((Nd4jCuda.RandomGenerator) statePointer).nodeState(); return nativeOps.getRandomGeneratorNodeState((OpaqueRandomGenerator)statePointer);
} }
@Override @Override
public void setStates(long rootState, long nodeState) { public void setStates(long rootState, long nodeState) {
((Nd4jCuda.RandomGenerator) statePointer).setStates(rootState, nodeState); nativeOps.setRandomGeneratorStates((OpaqueRandomGenerator)statePointer, rootState, nodeState);
} }
} }

View File

@ -2097,16 +2097,25 @@ public native void setGridLimit(int gridSize);
* @param targetBuffer * @param targetBuffer
* @param offsetsBuffer * @param offsetsBuffer
*/ */
public native TadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") LongPointer xShapeInfo, public native OpaqueTadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") LongPointer xShapeInfo,
IntPointer dimension, IntPointer dimension,
int dimensionLength); int dimensionLength);
public native TadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") LongBuffer xShapeInfo, public native OpaqueTadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") LongBuffer xShapeInfo,
IntBuffer dimension, IntBuffer dimension,
int dimensionLength); int dimensionLength);
public native TadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") long[] xShapeInfo, public native OpaqueTadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") long[] xShapeInfo,
int[] dimension, int[] dimension,
int dimensionLength); int dimensionLength);
public native @Cast("Nd4jLong*") LongPointer getPrimaryShapeInfo(OpaqueTadPack pack);
public native @Cast("Nd4jLong*") LongPointer getPrimaryOffsets(OpaqueTadPack pack);
public native @Cast("Nd4jLong*") LongPointer getSpecialShapeInfo(OpaqueTadPack pack);
public native @Cast("Nd4jLong*") LongPointer getSpecialOffsets(OpaqueTadPack pack);
public native @Cast("Nd4jLong") long getNumberOfTads(OpaqueTadPack pack);
public native int getShapeInfoLength(OpaqueTadPack pack);
public native void deleteTadPack(OpaqueTadPack ptr);
/* /*
* PullRow special op * PullRow special op
*/ */
@ -2943,10 +2952,11 @@ public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers
public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongBuffer ptrMap, @Cast("Nd4jLong") long length); public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongBuffer ptrMap, @Cast("Nd4jLong") long length);
public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") long[] ptrMap, @Cast("Nd4jLong") long length); public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") long[] ptrMap, @Cast("Nd4jLong") long length);
// flatbuffers execution // flatbuffers execution
public native ResultWrapper executeFlatGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer flatBufferPointer); public native OpaqueResultWrapper executeFlatGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer flatBufferPointer);
public native @Cast("Nd4jLong") long getResultWrapperSize(OpaqueResultWrapper ptr);
public native @Cast("Nd4jPointer") Pointer getResultWrapperPointer(OpaqueResultWrapper ptr);
public native @Cast("char*") String getAllCustomOps(); public native @Cast("char*") String getAllCustomOps();
@ -2961,23 +2971,35 @@ public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointer
public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, @Cast("bool") boolean isInplace); public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, @Cast("bool") boolean isInplace);
public native int execCustomOp2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer") Pointer opContext); public native int execCustomOp2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer") Pointer opContext);
public native ShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs); public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs);
public native ShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs); public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs);
public native ShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs); public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs);
public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs); public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs); public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs); public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs); public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs); public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs); public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
public native @Cast("Nd4jLong") long getShapeListSize(OpaqueShapeList list);
public native @Cast("Nd4jLong*") LongPointer getShape(OpaqueShapeList list, @Cast("Nd4jLong") long i);
public native void deleteShapeList(@Cast("Nd4jPointer") Pointer shapeList); public native void deleteShapeList(@Cast("Nd4jPointer") Pointer shapeList);
public native int registerGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer") Pointer flatBufferPointer); public native int registerGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer") Pointer flatBufferPointer);
public native VariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntPointer inputIndices, int numInputs); public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntPointer inputIndices, int numInputs);
public native VariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntBuffer inputIndices, int numInputs); public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntBuffer inputIndices, int numInputs);
public native VariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int[] inputIndices, int numInputs); public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int[] inputIndices, int numInputs);
public native @Cast("Nd4jLong") long getVariableSetSize(OpaqueVariableSet set);
public native @Cast("Nd4jStatus") int getVariableSetStatus(OpaqueVariableSet set);
public native OpaqueVariable getVariable(OpaqueVariableSet set, @Cast("Nd4jLong") long i);
public native int getVariableId(OpaqueVariable variable);
public native int getVariableIndex(OpaqueVariable variable);
public native @Cast("char*") String getVariableName(OpaqueVariable variable);
public native @Cast("Nd4jLong*") LongPointer getVariableShape(OpaqueVariable variable);
public native Pointer getVariableBuffer(OpaqueVariable variable);
public native int unregisterGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId); public native int unregisterGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId);
@ -2986,7 +3008,7 @@ public native void deleteIntArray(@Cast("Nd4jPointer") Pointer pointer);
public native void deleteLongArray(@Cast("Nd4jPointer") Pointer pointer); public native void deleteLongArray(@Cast("Nd4jPointer") Pointer pointer);
public native void deletePointerArray(@Cast("Nd4jPointer") Pointer pointer); public native void deletePointerArray(@Cast("Nd4jPointer") Pointer pointer);
public native void deleteVariablesSet(@Cast("Nd4jPointer") Pointer pointer); public native void deleteVariablesSet(@ByVal OpaqueVariableSet pointer);
// GraphState creation // GraphState creation
public native @Cast("Nd4jPointer") Pointer getGraphState(@Cast("Nd4jLong") long id); public native @Cast("Nd4jPointer") Pointer getGraphState(@Cast("Nd4jLong") long id);
@ -3007,6 +3029,8 @@ public native @Cast("Nd4jStatus") int execCustomOpWithScope(@Cast("Nd4jPointer*"
//void fillUtf8String(Nd4jPointer *extraPointers, const char **string, int numStrings, Nd4jPointer buffer); //void fillUtf8String(Nd4jPointer *extraPointers, const char **string, int numStrings, Nd4jPointer buffer);
public native @Cast("Nd4jPointer") Pointer createUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") String string, int length); public native @Cast("Nd4jPointer") Pointer createUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") String string, int length);
public native @Cast("Nd4jPointer") Pointer createUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") BytePointer string, int length); public native @Cast("Nd4jPointer") Pointer createUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") BytePointer string, int length);
public native @Cast("Nd4jLong") long getUtf8StringLength(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer ptr);
public native @Cast("char*") BytePointer getUtf8StringBuffer(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer ptr);
public native void deleteUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer ptr); public native void deleteUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer ptr);
public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opCode, int numOfSubArrs, public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opCode, int numOfSubArrs,
@ -3032,19 +3056,50 @@ public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointe
public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer buffer, @Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jPointer") Pointer specialBuffer, @Cast("Nd4jLong*") LongBuffer specialShapeInfo, @Cast("Nd4jPointer") Pointer debugInfo); public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer buffer, @Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jPointer") Pointer specialBuffer, @Cast("Nd4jLong*") LongBuffer specialShapeInfo, @Cast("Nd4jPointer") Pointer debugInfo);
public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer buffer, @Cast("Nd4jLong*") long[] shapeInfo, @Cast("Nd4jPointer") Pointer specialBuffer, @Cast("Nd4jLong*") long[] specialShapeInfo, @Cast("Nd4jPointer") Pointer debugInfo); public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer buffer, @Cast("Nd4jLong*") long[] shapeInfo, @Cast("Nd4jPointer") Pointer specialBuffer, @Cast("Nd4jLong*") long[] specialShapeInfo, @Cast("Nd4jPointer") Pointer debugInfo);
public native OpaqueConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty);
public native OpaqueConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty);
public native OpaqueConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty);
public native ConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty); public native OpaqueConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") LongPointer data, int length);
public native ConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty); public native OpaqueConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") LongBuffer data, int length);
public native ConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty); public native OpaqueConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") long[] data, int length);
public native OpaqueConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, DoublePointer data, int length);
public native OpaqueConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, DoubleBuffer data, int length);
public native OpaqueConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, double[] data, int length);
public native OpaqueConstantDataBuffer constantBuffer(@Cast("nd4j::DataType") int dtype, ConstantDescriptor descriptor);
public native ConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") LongPointer data, int length); public native @Cast("Nd4jPointer") Pointer getConstantDataBufferPrimary(OpaqueConstantDataBuffer dbf);
public native ConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") LongBuffer data, int length); public native @Cast("Nd4jPointer") Pointer getConstantDataBufferSpecial(OpaqueConstantDataBuffer dbf);
public native ConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") long[] data, int length); public native @Cast("Nd4jLong") long getConstantDataBufferLength(OpaqueConstantDataBuffer dbf);
public native ConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, DoublePointer data, int length); public native @Cast("Nd4jLong") long getConstantDataBufferSizeOf(OpaqueConstantDataBuffer dbf);
public native ConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, DoubleBuffer data, int length);
public native ConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, double[] data, int length);
public native ConstantDataBuffer constantBuffer(@Cast("nd4j::DataType") int dtype, ConstantDescriptor descriptor);
public native void deleteShapeBuffer(OpaqueConstantDataBuffer ptr);
public native OpaqueContext createGraphContext(int nodeId);
public native OpaqueRandomGenerator getGraphContextRandomGenerator(OpaqueContext ptr);
public native void markGraphContextInplace(OpaqueContext ptr, @Cast("bool") boolean reallyInplace);
public native void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer);
public native void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);
public native void setGraphContextOutputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);
public native void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, int numberOfArguments);
public native void setGraphContextTArguments(OpaqueContext ptr, DoubleBuffer arguments, int numberOfArguments);
public native void setGraphContextTArguments(OpaqueContext ptr, double[] arguments, int numberOfArguments);
public native void setGraphContextIArguments(OpaqueContext ptr, @Cast("Nd4jLong*") LongPointer arguments, int numberOfArguments);
public native void setGraphContextIArguments(OpaqueContext ptr, @Cast("Nd4jLong*") LongBuffer arguments, int numberOfArguments);
public native void setGraphContextIArguments(OpaqueContext ptr, @Cast("Nd4jLong*") long[] arguments, int numberOfArguments);
public native void setGraphContextBArguments(OpaqueContext ptr, @Cast("bool*") BooleanPointer arguments, int numberOfArguments);
public native void setGraphContextBArguments(OpaqueContext ptr, @Cast("bool*") boolean[] arguments, int numberOfArguments);
public native void deleteGraphContext(OpaqueContext ptr);
public native OpaqueRandomGenerator createRandomGenerator(@Cast("Nd4jLong") long rootSeed/*=0*/, @Cast("Nd4jLong") long nodeSeed/*=0*/);
public native OpaqueRandomGenerator createRandomGenerator();
public native @Cast("Nd4jLong") long getRandomGeneratorRootState(OpaqueRandomGenerator ptr);
public native @Cast("Nd4jLong") long getRandomGeneratorNodeState(OpaqueRandomGenerator ptr);
public native void setRandomGeneratorStates(OpaqueRandomGenerator ptr, @Cast("Nd4jLong") long rootSeed/*=0*/, @Cast("Nd4jLong") long nodeSeed/*=0*/);
public native void setRandomGeneratorStates(OpaqueRandomGenerator ptr);
public native int getRandomGeneratorRelativeInt(OpaqueRandomGenerator ptr, @Cast("Nd4jLong") long index);
public native @Cast("Nd4jLong") long getRandomGeneratorRelativeLong(OpaqueRandomGenerator ptr, @Cast("Nd4jLong") long index);
public native void deleteRandomGenerator(OpaqueRandomGenerator ptr);
public native @Cast("char*") String runLightBenchmarkSuit(@Cast("bool") boolean printOut); public native @Cast("char*") String runLightBenchmarkSuit(@Cast("bool") boolean printOut);
public native @Cast("char*") String runFullBenchmarkSuit(@Cast("bool") boolean printOut); public native @Cast("char*") String runFullBenchmarkSuit(@Cast("bool") boolean printOut);
@ -3705,6 +3760,20 @@ public native @Cast("char*") String runFullBenchmarkSuit(@Cast("bool") boolean p
public NDArray(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("nd4j::DataType") int dtype) { super((Pointer)null); allocate(buffer, order, shape, dtype); } public NDArray(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("nd4j::DataType") int dtype) { super((Pointer)null); allocate(buffer, order, shape, dtype); }
private native void allocate(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("nd4j::DataType") int dtype); private native void allocate(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("nd4j::DataType") int dtype);
/**
* This method returns new array with the same shape & data type
* @return
*/
public native @ByVal NDArray like();
/**
* This method returns new uninitialized array with the same shape & data type
* @return
*/
public native @ByVal NDArray ulike();
/** /**
* this constructor creates new NDArray with shape matching "other" array, * this constructor creates new NDArray with shape matching "other" array,
* doesn't copy "other" elements into new array !!! * doesn't copy "other" elements into new array !!!

View File

@ -113,6 +113,14 @@ public class Nd4jCudaPresets implements InfoMapper {
infoMap.put(new Info("thread_local", "ND4J_EXPORT", "INLINEDEF", "CUBLASWINAPI", "FORCEINLINE", infoMap.put(new Info("thread_local", "ND4J_EXPORT", "INLINEDEF", "CUBLASWINAPI", "FORCEINLINE",
"_CUDA_H", "_CUDA_D", "_CUDA_G", "_CUDA_HD", "LIBND4J_ALL_OPS", "NOT_EXCLUDED").cppTypes().annotations()) "_CUDA_H", "_CUDA_D", "_CUDA_G", "_CUDA_HD", "LIBND4J_ALL_OPS", "NOT_EXCLUDED").cppTypes().annotations())
.put(new Info("NativeOps.h").objectify()) .put(new Info("NativeOps.h").objectify())
.put(new Info("OpaqueTadPack").pointerTypes("OpaqueTadPack"))
.put(new Info("OpaqueResultWrapper").pointerTypes("OpaqueResultWrapper"))
.put(new Info("OpaqueShapeList").pointerTypes("OpaqueShapeList"))
.put(new Info("OpaqueVariableSet").pointerTypes("OpaqueVariableSet"))
.put(new Info("OpaqueVariable").pointerTypes("OpaqueVariable"))
.put(new Info("OpaqueConstantDataBuffer").pointerTypes("OpaqueConstantDataBuffer"))
.put(new Info("OpaqueContext").pointerTypes("OpaqueContext"))
.put(new Info("OpaqueRandomGenerator").pointerTypes("OpaqueRandomGenerator"))
.put(new Info("const char").valueTypes("byte").pointerTypes("@Cast(\"char*\") String", .put(new Info("const char").valueTypes("byte").pointerTypes("@Cast(\"char*\") String",
"@Cast(\"char*\") BytePointer")) "@Cast(\"char*\") BytePointer"))
.put(new Info("char").valueTypes("char").pointerTypes("@Cast(\"char*\") BytePointer", .put(new Info("char").valueTypes("char").pointerTypes("@Cast(\"char*\") BytePointer",

View File

@ -48,7 +48,6 @@ import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.nativeblas.BaseNativeNDArrayFactory; import org.nd4j.nativeblas.BaseNativeNDArrayFactory;
import org.nd4j.nativeblas.LongPointerWrapper; import org.nd4j.nativeblas.LongPointerWrapper;
import org.nd4j.nativeblas.NativeOpsHolder; import org.nd4j.nativeblas.NativeOpsHolder;
import org.nd4j.nativeblas.Nd4jCpu;
import java.util.*; import java.util.*;

View File

@ -17,12 +17,18 @@
package org.nd4j.linalg.cpu.nativecpu.ops; package org.nd4j.linalg.cpu.nativecpu.ops;
import lombok.NonNull; import lombok.NonNull;
import org.bytedeco.javacpp.BooleanPointer;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.Pointer;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseOpContext; import org.nd4j.linalg.api.ops.BaseOpContext;
import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import org.nd4j.nativeblas.Nd4jCpu; import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.nd4j.nativeblas.OpaqueContext;
import org.nd4j.nativeblas.OpaqueRandomGenerator;
import java.util.List; import java.util.List;
@ -33,46 +39,53 @@ import java.util.List;
*/ */
public class CpuOpContext extends BaseOpContext implements OpContext { public class CpuOpContext extends BaseOpContext implements OpContext {
// we might want to have configurable // we might want to have configurable
private Nd4jCpu.Context context = new Nd4jCpu.Context(1); private NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
private OpaqueContext context = nativeOps.createGraphContext(1);
@Override
public void close() {
nativeOps.deleteGraphContext(context);
}
@Override @Override
public void setIArguments(long... arguments) { public void setIArguments(long... arguments) {
super.setIArguments(arguments); super.setIArguments(arguments);
context.setIArguments(arguments, arguments.length); nativeOps.setGraphContextIArguments(context, new LongPointer(arguments), arguments.length);
} }
@Override @Override
public void setBArguments(boolean... arguments) { public void setBArguments(boolean... arguments) {
super.setBArguments(arguments); super.setBArguments(arguments);
context.setBArguments(arguments, arguments.length); nativeOps.setGraphContextBArguments(context, new BooleanPointer(arguments), arguments.length);
} }
@Override @Override
public void setTArguments(double... arguments) { public void setTArguments(double... arguments) {
super.setTArguments(arguments); super.setTArguments(arguments);
context.setTArguments(arguments, arguments.length); nativeOps.setGraphContextTArguments(context, new DoublePointer(arguments), arguments.length);
} }
@Override @Override
public void setRngStates(long rootState, long nodeState) { public void setRngStates(long rootState, long nodeState) {
context.randomGenerator().setStates(rootState, nodeState); nativeOps.setRandomGeneratorStates(nativeOps.getGraphContextRandomGenerator(context), rootState, nodeState);
} }
@Override @Override
public Pair<Long, Long> getRngStates() { public Pair<Long, Long> getRngStates() {
return Pair.makePair(context.randomGenerator().rootState(), context.randomGenerator().nodeState()); OpaqueRandomGenerator g = nativeOps.getGraphContextRandomGenerator(context);
return Pair.makePair(nativeOps.getRandomGeneratorRootState(g), nativeOps.getRandomGeneratorNodeState(g));
} }
@Override @Override
public void setInputArray(int index, @NonNull INDArray array) { public void setInputArray(int index, @NonNull INDArray array) {
context.setInputArray(index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), null, null); nativeOps.setGraphContextInputArray(context, index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), null, null);
super.setInputArray(index, array); super.setInputArray(index, array);
} }
@Override @Override
public void setOutputArray(int index, @NonNull INDArray array) { public void setOutputArray(int index, @NonNull INDArray array) {
context.setOutputArray(index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), null, null); nativeOps.setGraphContextOutputArray(context, index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), null, null);
super.setOutputArray(index, array); super.setOutputArray(index, array);
} }
@ -84,6 +97,6 @@ public class CpuOpContext extends BaseOpContext implements OpContext {
@Override @Override
public void markInplace(boolean reallyInplace) { public void markInplace(boolean reallyInplace) {
context.markInplace(reallyInplace); nativeOps.markGraphContextInplace(context, reallyInplace);
} }
} }

View File

@ -70,10 +70,14 @@ import org.nd4j.nativeblas.LongPointerWrapper;
import org.nd4j.nativeblas.NativeOps; import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder; import org.nd4j.nativeblas.NativeOpsHolder;
import org.nd4j.nativeblas.Nd4jCpu; import org.nd4j.nativeblas.Nd4jCpu;
import org.nd4j.nativeblas.OpaqueConstantDataBuffer;
import org.nd4j.nativeblas.OpaqueShapeList;
import org.nd4j.nativeblas.OpaqueTadPack;
import org.nd4j.nativeblas.OpaqueVariable;
import org.nd4j.nativeblas.OpaqueVariableSet;
import java.util.*; import java.util.*;
/** /**
* *
* Native operation * Native operation
@ -1641,23 +1645,22 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
} }
val name = op.opName(); val name = op.opName();
val context = buildContext(); try (val context = buildContext()) {
context.markInplace(op.isInplaceCall()); context.markInplace(op.isInplaceCall());
// transferring rng state // transferring rng state
context.setRngStates(Nd4j.getRandom().rootState(), Nd4j.getRandom().nodeState()); context.setRngStates(Nd4j.getRandom().rootState(), Nd4j.getRandom().nodeState());
//transferring input/output arrays //transferring input/output arrays
context.setInputArrays(op.inputArguments()); context.setInputArrays(op.inputArguments());
context.setOutputArrays(op.outputArguments()); context.setOutputArrays(op.outputArguments());
// transferring static args // transferring static args
context.setBArguments(op.bArgs()); context.setBArguments(op.bArgs());
context.setIArguments(op.iArgs()); context.setIArguments(op.iArgs());
context.setTArguments(op.tArgs()); context.setTArguments(op.tArgs());
try {
val result = exec(op, context); val result = exec(op, context);
val states = context.getRngStates(); val states = context.getRngStates();
@ -1860,9 +1863,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
for (val t: tArgs1) for (val t: tArgs1)
tArgs.put(cnt++, t); tArgs.put(cnt++, t);
Nd4jCpu.ShapeList ptrptr; OpaqueShapeList ptrptr;
try { try {
ptrptr = (Nd4jCpu.ShapeList) loop.calculateOutputShapes2(null, ptrptr = loop.calculateOutputShapes2(null,
hash, inputBuffers, inputShapes, op.numInputArguments(), tArgs, hash, inputBuffers, inputShapes, op.numInputArguments(), tArgs,
op.numTArguments(), iArgs, op.numIArguments(), bArgs, op.numBArguments()); op.numTArguments(), iArgs, op.numIArguments(), bArgs, op.numBArguments());
} catch (Throwable t){ } catch (Throwable t){
@ -1891,8 +1894,8 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
if (ptrptr == null) if (ptrptr == null)
throw new RuntimeException(); throw new RuntimeException();
for (int e = 0; e < ptrptr.size(); e++ ) for (int e = 0; e < loop.getShapeListSize(ptrptr); e++ )
result.add(getShapeFromPointer(new PagedPointer(ptrptr.at(e)).asLongPointer())); result.add(getShapeFromPointer(new PagedPointer(loop.getShape(ptrptr, e)).asLongPointer()));
loop.deleteShapeList(ptrptr); loop.deleteShapeList(ptrptr);
@ -1947,19 +1950,19 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
val newMap = new LinkedHashMap<String, INDArray>(); val newMap = new LinkedHashMap<String, INDArray>();
val result = (Nd4jCpu.VariablesSet) loop.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size()); OpaqueVariableSet result = loop.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size());
val status = OpStatus.byNumber(result.status()); OpStatus status = OpStatus.byNumber(loop.getVariableSetStatus(result));
if (status != OpStatus.ND4J_STATUS_OK) if (status != OpStatus.ND4J_STATUS_OK)
throw new ND4JIllegalStateException("Op execution failed: " + status); throw new ND4JIllegalStateException("Op execution failed: " + status);
for (int e = 0; e < result.size(); e++) { for (int e = 0; e < loop.getVariableSetSize(result); e++) {
val var = result.at(e); OpaqueVariable var = loop.getVariable(result, e);
val nodeId = var.id(); int nodeId = loop.getVariableId(var);
val index = var.index(); int index = loop.getVariableIndex(var);
val shapeInfo = var.getNDArray().shapeInfo(); LongPointer shapeInfo = loop.getVariableShape(var);
val buffer = var.getNDArray().buffer(); Pointer buffer = loop.getVariableBuffer(var);
val rank = (int) shapeInfo.get(0); val rank = (int) shapeInfo.get(0);
val jshape = new long[rank * 2 + 4]; val jshape = new long[rank * 2 + 4];
@ -1979,7 +1982,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, Shape.lengthOf(shapeOf) * Nd4j.sizeOfDataType(array.dataType()), MemcpyDirection.HOST_TO_HOST); PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, Shape.lengthOf(shapeOf) * Nd4j.sizeOfDataType(array.dataType()), MemcpyDirection.HOST_TO_HOST);
//newMap.put(keySet.get(nodeId), array); //newMap.put(keySet.get(nodeId), array);
val nodeName = var.getName().getString(); String nodeName = loop.getVariableName(var);
newMap.put(nodeName, array); newMap.put(nodeName, array);
} }
@ -2160,9 +2163,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
@Override @Override
public DataBuffer createShapeInfo(long[] shape, long[] stride, long elementWiseStride, char order, DataType dtype, boolean empty) { public DataBuffer createShapeInfo(long[] shape, long[] stride, long elementWiseStride, char order, DataType dtype, boolean empty) {
val dbf = (Nd4jCpu.ConstantDataBuffer) loop.shapeBuffer(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, empty); OpaqueConstantDataBuffer dbf = loop.shapeBuffer(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, empty);
val result = new LongBuffer(dbf.primary(), Shape.shapeInfoLength(shape.length)); val result = new LongBuffer(loop.getConstantDataBufferPrimary(dbf), Shape.shapeInfoLength(shape.length));
loop.deleteShapeBuffer(dbf); loop.deleteShapeBuffer(dbf);
@ -2171,10 +2174,10 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
@Override @Override
public TadPack tadShapeInfoAndOffsets(INDArray array, int[] dimension) { public TadPack tadShapeInfoAndOffsets(INDArray array, int[] dimension) {
val pack = (Nd4jCpu.TadPack) loop.tadOnlyShapeInfo((LongPointer) array.shapeInfoDataBuffer().addressPointer(), new IntPointer(dimension), dimension.length); OpaqueTadPack pack = loop.tadOnlyShapeInfo((LongPointer) array.shapeInfoDataBuffer().addressPointer(), new IntPointer(dimension), dimension.length);
val tadShape = new LongBuffer(pack.primaryShapeInfo(), pack.shapeInfoLength()); val tadShape = new LongBuffer(loop.getPrimaryShapeInfo(pack), loop.getShapeInfoLength(pack));
val tadOffsets = new LongBuffer(pack.primaryOffsets(), pack.numberOfTads()); val tadOffsets = new LongBuffer(loop.getPrimaryOffsets(pack), loop.getNumberOfTads(pack));
loop.deleteTadPack(pack); loop.deleteTadPack(pack);

View File

@ -18,7 +18,9 @@ package org.nd4j.linalg.cpu.nativecpu.rng;
import org.bytedeco.javacpp.PointerPointer; import org.bytedeco.javacpp.PointerPointer;
import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.nativeblas.Nd4jCpu; import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.nd4j.nativeblas.OpaqueRandomGenerator;
import org.nd4j.rng.NativeRandom; import org.nd4j.rng.NativeRandom;
import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicLong;
@ -29,6 +31,8 @@ import java.util.concurrent.atomic.AtomicLong;
* @author raver119@gmail.com * @author raver119@gmail.com
*/ */
public class CpuNativeRandom extends NativeRandom { public class CpuNativeRandom extends NativeRandom {
private NativeOps nativeOps;
public CpuNativeRandom() { public CpuNativeRandom() {
super(); super();
} }
@ -43,7 +47,13 @@ public class CpuNativeRandom extends NativeRandom {
@Override @Override
public void init() { public void init() {
statePointer = new Nd4jCpu.RandomGenerator(this.seed, this.seed ^ 0xdeadbeef); nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
statePointer = nativeOps.createRandomGenerator(this.seed, this.seed ^ 0xdeadbeef);
}
@Override
public void close() {
nativeOps.deleteRandomGenerator((OpaqueRandomGenerator)statePointer);
} }
@Override @Override
@ -55,7 +65,7 @@ public class CpuNativeRandom extends NativeRandom {
public void setSeed(long seed) { public void setSeed(long seed) {
this.seed = seed; this.seed = seed;
this.currentPosition.set(0); this.currentPosition.set(0);
((Nd4jCpu.RandomGenerator)statePointer).setStates(seed, seed ^ 0xdeadbeef); nativeOps.setRandomGeneratorStates((OpaqueRandomGenerator)statePointer, seed, seed ^ 0xdeadbeef);
} }
@Override @Override
@ -65,24 +75,24 @@ public class CpuNativeRandom extends NativeRandom {
@Override @Override
public int nextInt() { public int nextInt() {
return ((Nd4jCpu.RandomGenerator)statePointer).relativeInt(currentPosition.getAndIncrement()); return nativeOps.getRandomGeneratorRelativeInt((OpaqueRandomGenerator)statePointer, currentPosition.getAndIncrement());
} }
@Override @Override
public long nextLong() { public long nextLong() {
return ((Nd4jCpu.RandomGenerator)statePointer).relativeLong(currentPosition.getAndIncrement()); return nativeOps.getRandomGeneratorRelativeLong((OpaqueRandomGenerator)statePointer, currentPosition.getAndIncrement());
} }
public long rootState() { public long rootState() {
return ((Nd4jCpu.RandomGenerator) statePointer).rootState(); return nativeOps.getRandomGeneratorRootState((OpaqueRandomGenerator)statePointer);
} }
public long nodeState() { public long nodeState() {
return ((Nd4jCpu.RandomGenerator) statePointer).nodeState(); return nativeOps.getRandomGeneratorNodeState((OpaqueRandomGenerator)statePointer);
} }
@Override @Override
public void setStates(long rootState, long nodeState) { public void setStates(long rootState, long nodeState) {
((Nd4jCpu.RandomGenerator) statePointer).setStates(rootState, nodeState); nativeOps.setRandomGeneratorStates((OpaqueRandomGenerator)statePointer, rootState, nodeState);
} }
} }

View File

@ -2097,16 +2097,25 @@ public native void setGridLimit(int gridSize);
* @param targetBuffer * @param targetBuffer
* @param offsetsBuffer * @param offsetsBuffer
*/ */
public native TadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") LongPointer xShapeInfo, public native OpaqueTadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") LongPointer xShapeInfo,
IntPointer dimension, IntPointer dimension,
int dimensionLength); int dimensionLength);
public native TadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") LongBuffer xShapeInfo, public native OpaqueTadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") LongBuffer xShapeInfo,
IntBuffer dimension, IntBuffer dimension,
int dimensionLength); int dimensionLength);
public native TadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") long[] xShapeInfo, public native OpaqueTadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") long[] xShapeInfo,
int[] dimension, int[] dimension,
int dimensionLength); int dimensionLength);
public native @Cast("Nd4jLong*") LongPointer getPrimaryShapeInfo(OpaqueTadPack pack);
public native @Cast("Nd4jLong*") LongPointer getPrimaryOffsets(OpaqueTadPack pack);
public native @Cast("Nd4jLong*") LongPointer getSpecialShapeInfo(OpaqueTadPack pack);
public native @Cast("Nd4jLong*") LongPointer getSpecialOffsets(OpaqueTadPack pack);
public native @Cast("Nd4jLong") long getNumberOfTads(OpaqueTadPack pack);
public native int getShapeInfoLength(OpaqueTadPack pack);
public native void deleteTadPack(OpaqueTadPack ptr);
/* /*
* PullRow special op * PullRow special op
*/ */
@ -2943,10 +2952,11 @@ public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers
public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongBuffer ptrMap, @Cast("Nd4jLong") long length); public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongBuffer ptrMap, @Cast("Nd4jLong") long length);
public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") long[] ptrMap, @Cast("Nd4jLong") long length); public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") long[] ptrMap, @Cast("Nd4jLong") long length);
// flatbuffers execution // flatbuffers execution
public native ResultWrapper executeFlatGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer flatBufferPointer); public native OpaqueResultWrapper executeFlatGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer flatBufferPointer);
public native @Cast("Nd4jLong") long getResultWrapperSize(OpaqueResultWrapper ptr);
public native @Cast("Nd4jPointer") Pointer getResultWrapperPointer(OpaqueResultWrapper ptr);
public native @Cast("char*") String getAllCustomOps(); public native @Cast("char*") String getAllCustomOps();
@ -2961,23 +2971,35 @@ public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointer
public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, @Cast("bool") boolean isInplace); public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, @Cast("bool") boolean isInplace);
public native int execCustomOp2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer") Pointer opContext); public native int execCustomOp2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer") Pointer opContext);
public native ShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs); public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs);
public native ShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs); public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs);
public native ShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs); public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs);
public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs); public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs); public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs); public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs); public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs); public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs); public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
public native @Cast("Nd4jLong") long getShapeListSize(OpaqueShapeList list);
public native @Cast("Nd4jLong*") LongPointer getShape(OpaqueShapeList list, @Cast("Nd4jLong") long i);
public native void deleteShapeList(@Cast("Nd4jPointer") Pointer shapeList); public native void deleteShapeList(@Cast("Nd4jPointer") Pointer shapeList);
public native int registerGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer") Pointer flatBufferPointer); public native int registerGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer") Pointer flatBufferPointer);
public native VariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntPointer inputIndices, int numInputs); public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntPointer inputIndices, int numInputs);
public native VariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntBuffer inputIndices, int numInputs); public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntBuffer inputIndices, int numInputs);
public native VariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int[] inputIndices, int numInputs); public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int[] inputIndices, int numInputs);
public native @Cast("Nd4jLong") long getVariableSetSize(OpaqueVariableSet set);
public native @Cast("Nd4jStatus") int getVariableSetStatus(OpaqueVariableSet set);
public native OpaqueVariable getVariable(OpaqueVariableSet set, @Cast("Nd4jLong") long i);
public native int getVariableId(OpaqueVariable variable);
public native int getVariableIndex(OpaqueVariable variable);
public native @Cast("char*") String getVariableName(OpaqueVariable variable);
public native @Cast("Nd4jLong*") LongPointer getVariableShape(OpaqueVariable variable);
public native Pointer getVariableBuffer(OpaqueVariable variable);
public native int unregisterGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId); public native int unregisterGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId);
@ -2986,7 +3008,7 @@ public native void deleteIntArray(@Cast("Nd4jPointer") Pointer pointer);
public native void deleteLongArray(@Cast("Nd4jPointer") Pointer pointer); public native void deleteLongArray(@Cast("Nd4jPointer") Pointer pointer);
public native void deletePointerArray(@Cast("Nd4jPointer") Pointer pointer); public native void deletePointerArray(@Cast("Nd4jPointer") Pointer pointer);
public native void deleteVariablesSet(@Cast("Nd4jPointer") Pointer pointer); public native void deleteVariablesSet(@ByVal OpaqueVariableSet pointer);
// GraphState creation // GraphState creation
public native @Cast("Nd4jPointer") Pointer getGraphState(@Cast("Nd4jLong") long id); public native @Cast("Nd4jPointer") Pointer getGraphState(@Cast("Nd4jLong") long id);
@ -3007,6 +3029,8 @@ public native @Cast("Nd4jStatus") int execCustomOpWithScope(@Cast("Nd4jPointer*"
//void fillUtf8String(Nd4jPointer *extraPointers, const char **string, int numStrings, Nd4jPointer buffer); //void fillUtf8String(Nd4jPointer *extraPointers, const char **string, int numStrings, Nd4jPointer buffer);
public native @Cast("Nd4jPointer") Pointer createUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") String string, int length); public native @Cast("Nd4jPointer") Pointer createUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") String string, int length);
public native @Cast("Nd4jPointer") Pointer createUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") BytePointer string, int length); public native @Cast("Nd4jPointer") Pointer createUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") BytePointer string, int length);
public native @Cast("Nd4jLong") long getUtf8StringLength(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer ptr);
public native @Cast("char*") BytePointer getUtf8StringBuffer(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer ptr);
public native void deleteUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer ptr); public native void deleteUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer ptr);
public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opCode, int numOfSubArrs, public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opCode, int numOfSubArrs,
@ -3032,19 +3056,50 @@ public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointe
public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer buffer, @Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jPointer") Pointer specialBuffer, @Cast("Nd4jLong*") LongBuffer specialShapeInfo, @Cast("Nd4jPointer") Pointer debugInfo); public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer buffer, @Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jPointer") Pointer specialBuffer, @Cast("Nd4jLong*") LongBuffer specialShapeInfo, @Cast("Nd4jPointer") Pointer debugInfo);
public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer buffer, @Cast("Nd4jLong*") long[] shapeInfo, @Cast("Nd4jPointer") Pointer specialBuffer, @Cast("Nd4jLong*") long[] specialShapeInfo, @Cast("Nd4jPointer") Pointer debugInfo); public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer buffer, @Cast("Nd4jLong*") long[] shapeInfo, @Cast("Nd4jPointer") Pointer specialBuffer, @Cast("Nd4jLong*") long[] specialShapeInfo, @Cast("Nd4jPointer") Pointer debugInfo);
public native OpaqueConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty);
public native OpaqueConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty);
public native OpaqueConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty);
public native ConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty); public native OpaqueConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") LongPointer data, int length);
public native ConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty); public native OpaqueConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") LongBuffer data, int length);
public native ConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty); public native OpaqueConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") long[] data, int length);
public native OpaqueConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, DoublePointer data, int length);
public native OpaqueConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, DoubleBuffer data, int length);
public native OpaqueConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, double[] data, int length);
public native OpaqueConstantDataBuffer constantBuffer(@Cast("nd4j::DataType") int dtype, ConstantDescriptor descriptor);
public native ConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") LongPointer data, int length); public native @Cast("Nd4jPointer") Pointer getConstantDataBufferPrimary(OpaqueConstantDataBuffer dbf);
public native ConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") LongBuffer data, int length); public native @Cast("Nd4jPointer") Pointer getConstantDataBufferSpecial(OpaqueConstantDataBuffer dbf);
public native ConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") long[] data, int length); public native @Cast("Nd4jLong") long getConstantDataBufferLength(OpaqueConstantDataBuffer dbf);
public native ConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, DoublePointer data, int length); public native @Cast("Nd4jLong") long getConstantDataBufferSizeOf(OpaqueConstantDataBuffer dbf);
public native ConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, DoubleBuffer data, int length);
public native ConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, double[] data, int length);
public native ConstantDataBuffer constantBuffer(@Cast("nd4j::DataType") int dtype, ConstantDescriptor descriptor);
public native void deleteShapeBuffer(OpaqueConstantDataBuffer ptr);
public native OpaqueContext createGraphContext(int nodeId);
public native OpaqueRandomGenerator getGraphContextRandomGenerator(OpaqueContext ptr);
public native void markGraphContextInplace(OpaqueContext ptr, @Cast("bool") boolean reallyInplace);
public native void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer);
public native void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);
public native void setGraphContextOutputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);
public native void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, int numberOfArguments);
public native void setGraphContextTArguments(OpaqueContext ptr, DoubleBuffer arguments, int numberOfArguments);
public native void setGraphContextTArguments(OpaqueContext ptr, double[] arguments, int numberOfArguments);
public native void setGraphContextIArguments(OpaqueContext ptr, @Cast("Nd4jLong*") LongPointer arguments, int numberOfArguments);
public native void setGraphContextIArguments(OpaqueContext ptr, @Cast("Nd4jLong*") LongBuffer arguments, int numberOfArguments);
public native void setGraphContextIArguments(OpaqueContext ptr, @Cast("Nd4jLong*") long[] arguments, int numberOfArguments);
public native void setGraphContextBArguments(OpaqueContext ptr, @Cast("bool*") BooleanPointer arguments, int numberOfArguments);
public native void setGraphContextBArguments(OpaqueContext ptr, @Cast("bool*") boolean[] arguments, int numberOfArguments);
public native void deleteGraphContext(OpaqueContext ptr);
public native OpaqueRandomGenerator createRandomGenerator(@Cast("Nd4jLong") long rootSeed/*=0*/, @Cast("Nd4jLong") long nodeSeed/*=0*/);
public native OpaqueRandomGenerator createRandomGenerator();
public native @Cast("Nd4jLong") long getRandomGeneratorRootState(OpaqueRandomGenerator ptr);
public native @Cast("Nd4jLong") long getRandomGeneratorNodeState(OpaqueRandomGenerator ptr);
public native void setRandomGeneratorStates(OpaqueRandomGenerator ptr, @Cast("Nd4jLong") long rootSeed/*=0*/, @Cast("Nd4jLong") long nodeSeed/*=0*/);
public native void setRandomGeneratorStates(OpaqueRandomGenerator ptr);
public native int getRandomGeneratorRelativeInt(OpaqueRandomGenerator ptr, @Cast("Nd4jLong") long index);
public native @Cast("Nd4jLong") long getRandomGeneratorRelativeLong(OpaqueRandomGenerator ptr, @Cast("Nd4jLong") long index);
public native void deleteRandomGenerator(OpaqueRandomGenerator ptr);
public native @Cast("char*") String runLightBenchmarkSuit(@Cast("bool") boolean printOut); public native @Cast("char*") String runLightBenchmarkSuit(@Cast("bool") boolean printOut);
public native @Cast("char*") String runFullBenchmarkSuit(@Cast("bool") boolean printOut); public native @Cast("char*") String runFullBenchmarkSuit(@Cast("bool") boolean printOut);
@ -3705,6 +3760,20 @@ public native @Cast("char*") String runFullBenchmarkSuit(@Cast("bool") boolean p
public NDArray(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("nd4j::DataType") int dtype) { super((Pointer)null); allocate(buffer, order, shape, dtype); } public NDArray(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("nd4j::DataType") int dtype) { super((Pointer)null); allocate(buffer, order, shape, dtype); }
private native void allocate(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("nd4j::DataType") int dtype); private native void allocate(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("nd4j::DataType") int dtype);
/**
* This method returns new array with the same shape & data type
* @return
*/
public native @ByVal NDArray like();
/**
* This method returns new uninitialized array with the same shape & data type
* @return
*/
public native @ByVal NDArray ulike();
/** /**
* this constructor creates new NDArray with shape matching "other" array, * this constructor creates new NDArray with shape matching "other" array,
* doesn't copy "other" elements into new array !!! * doesn't copy "other" elements into new array !!!

View File

@ -156,6 +156,14 @@ public class Nd4jCpuPresets implements InfoMapper, BuildEnabled {
infoMap.put(new Info("thread_local", "ND4J_EXPORT", "INLINEDEF", "CUBLASWINAPI", "FORCEINLINE", infoMap.put(new Info("thread_local", "ND4J_EXPORT", "INLINEDEF", "CUBLASWINAPI", "FORCEINLINE",
"_CUDA_H", "_CUDA_D", "_CUDA_G", "_CUDA_HD", "LIBND4J_ALL_OPS", "NOT_EXCLUDED").cppTypes().annotations()) "_CUDA_H", "_CUDA_D", "_CUDA_G", "_CUDA_HD", "LIBND4J_ALL_OPS", "NOT_EXCLUDED").cppTypes().annotations())
.put(new Info("NativeOps.h").objectify()) .put(new Info("NativeOps.h").objectify())
.put(new Info("OpaqueTadPack").pointerTypes("OpaqueTadPack"))
.put(new Info("OpaqueResultWrapper").pointerTypes("OpaqueResultWrapper"))
.put(new Info("OpaqueShapeList").pointerTypes("OpaqueShapeList"))
.put(new Info("OpaqueVariableSet").pointerTypes("OpaqueVariableSet"))
.put(new Info("OpaqueVariable").pointerTypes("OpaqueVariable"))
.put(new Info("OpaqueConstantDataBuffer").pointerTypes("OpaqueConstantDataBuffer"))
.put(new Info("OpaqueContext").pointerTypes("OpaqueContext"))
.put(new Info("OpaqueRandomGenerator").pointerTypes("OpaqueRandomGenerator"))
.put(new Info("const char").valueTypes("byte").pointerTypes("@Cast(\"char*\") String", .put(new Info("const char").valueTypes("byte").pointerTypes("@Cast(\"char*\") String",
"@Cast(\"char*\") BytePointer")) "@Cast(\"char*\") BytePointer"))
.put(new Info("char").valueTypes("char").pointerTypes("@Cast(\"char*\") BytePointer", .put(new Info("char").valueTypes("char").pointerTypes("@Cast(\"char*\") BytePointer",

View File

@ -442,7 +442,7 @@ public class CustomOpsTests extends BaseNd4jTest {
context.setOutputArray(0, arrayZ); context.setOutputArray(0, arrayZ);
val addOp = new AddOp(); val addOp = new AddOp();
NativeOpsHolder.getInstance().getDeviceNativeOps().execCustomOp(null, addOp.opHash(), context.contextPointer()); NativeOpsHolder.getInstance().getDeviceNativeOps().execCustomOp2(null, addOp.opHash(), context.contextPointer());
assertEquals(exp, arrayZ); assertEquals(exp, arrayZ);
} }