Create C wrappers for some of the C++ classes currently used by ND4J
parent
8881bfe7aa
commit
526b782e51
|
@ -882,6 +882,8 @@ ND4J_EXPORT void enableVerboseMode(bool reallyEnable);
|
||||||
*/
|
*/
|
||||||
ND4J_EXPORT void setGridLimit(int gridSize);
|
ND4J_EXPORT void setGridLimit(int gridSize);
|
||||||
|
|
||||||
|
typedef nd4j::TadPack OpaqueTadPack;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
* @param xShapeInfo
|
* @param xShapeInfo
|
||||||
|
@ -890,10 +892,19 @@ ND4J_EXPORT void setGridLimit(int gridSize);
|
||||||
* @param targetBuffer
|
* @param targetBuffer
|
||||||
* @param offsetsBuffer
|
* @param offsetsBuffer
|
||||||
*/
|
*/
|
||||||
ND4J_EXPORT nd4j::TadPack* tadOnlyShapeInfo(Nd4jLong *xShapeInfo,
|
ND4J_EXPORT OpaqueTadPack* tadOnlyShapeInfo(Nd4jLong *xShapeInfo,
|
||||||
int *dimension,
|
int *dimension,
|
||||||
int dimensionLength);
|
int dimensionLength);
|
||||||
|
|
||||||
|
ND4J_EXPORT Nd4jLong* getPrimaryShapeInfo(OpaqueTadPack* pack);
|
||||||
|
ND4J_EXPORT Nd4jLong* getPrimaryOffsets(OpaqueTadPack* pack);
|
||||||
|
ND4J_EXPORT Nd4jLong* getSpecialShapeInfo(OpaqueTadPack* pack);
|
||||||
|
ND4J_EXPORT Nd4jLong* getSpecialOffsets(OpaqueTadPack* pack);
|
||||||
|
ND4J_EXPORT Nd4jLong getNumberOfTads(OpaqueTadPack* pack);
|
||||||
|
ND4J_EXPORT int getShapeInfoLength(OpaqueTadPack* pack);
|
||||||
|
|
||||||
|
ND4J_EXPORT void deleteTadPack(OpaqueTadPack* ptr);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* PullRow special op
|
* PullRow special op
|
||||||
*/
|
*/
|
||||||
|
@ -1639,10 +1650,13 @@ ND4J_EXPORT Nd4jLong* mmapFile(Nd4jPointer *extraPointers, const char *fileName,
|
||||||
|
|
||||||
ND4J_EXPORT void munmapFile(Nd4jPointer *extraPointers, Nd4jLong* ptrMap, Nd4jLong length);
|
ND4J_EXPORT void munmapFile(Nd4jPointer *extraPointers, Nd4jLong* ptrMap, Nd4jLong length);
|
||||||
|
|
||||||
|
typedef nd4j::graph::ResultWrapper OpaqueResultWrapper;
|
||||||
|
|
||||||
// flatbuffers execution
|
// flatbuffers execution
|
||||||
ND4J_EXPORT nd4j::graph::ResultWrapper* executeFlatGraph(Nd4jPointer *extraPointers, Nd4jPointer flatBufferPointer);
|
ND4J_EXPORT OpaqueResultWrapper* executeFlatGraph(Nd4jPointer *extraPointers, Nd4jPointer flatBufferPointer);
|
||||||
|
|
||||||
|
ND4J_EXPORT Nd4jLong getResultWrapperSize(OpaqueResultWrapper* ptr);
|
||||||
|
ND4J_EXPORT Nd4jPointer getResultWrapperPointer(OpaqueResultWrapper* ptr);
|
||||||
|
|
||||||
ND4J_EXPORT const char* getAllCustomOps();
|
ND4J_EXPORT const char* getAllCustomOps();
|
||||||
|
|
||||||
|
@ -1652,14 +1666,31 @@ ND4J_EXPORT const char* getAllOperations();
|
||||||
ND4J_EXPORT int execCustomOp(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputs, Nd4jPointer* outputBuffers, Nd4jPointer* outputShapes, int numOutputs, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool* bArgs, int numBArgs, bool isInplace);
|
ND4J_EXPORT int execCustomOp(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputs, Nd4jPointer* outputBuffers, Nd4jPointer* outputShapes, int numOutputs, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool* bArgs, int numBArgs, bool isInplace);
|
||||||
ND4J_EXPORT int execCustomOp2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer opContext);
|
ND4J_EXPORT int execCustomOp2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer opContext);
|
||||||
|
|
||||||
ND4J_EXPORT nd4j::ShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs);
|
typedef nd4j::ShapeList OpaqueShapeList;
|
||||||
ND4J_EXPORT nd4j::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs);
|
|
||||||
|
ND4J_EXPORT OpaqueShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs);
|
||||||
|
ND4J_EXPORT OpaqueShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs);
|
||||||
|
|
||||||
|
ND4J_EXPORT Nd4jLong getShapeListSize(OpaqueShapeList* list);
|
||||||
|
ND4J_EXPORT Nd4jLong* getShape(OpaqueShapeList* list, Nd4jLong i);
|
||||||
|
|
||||||
ND4J_EXPORT void deleteShapeList(Nd4jPointer shapeList);
|
ND4J_EXPORT void deleteShapeList(Nd4jPointer shapeList);
|
||||||
|
|
||||||
ND4J_EXPORT int registerGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer flatBufferPointer);
|
ND4J_EXPORT int registerGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer flatBufferPointer);
|
||||||
|
|
||||||
ND4J_EXPORT nd4j::graph::VariablesSet *executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int* inputIndices, int numInputs);
|
typedef nd4j::graph::VariablesSet OpaqueVariableSet;
|
||||||
|
typedef nd4j::graph::Variable OpaqueVariable;
|
||||||
|
|
||||||
|
ND4J_EXPORT OpaqueVariableSet *executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int* inputIndices, int numInputs);
|
||||||
|
|
||||||
|
ND4J_EXPORT Nd4jLong getVariableSetSize(OpaqueVariableSet* set);
|
||||||
|
ND4J_EXPORT Nd4jStatus getVariableSetStatus(OpaqueVariableSet* set);
|
||||||
|
ND4J_EXPORT OpaqueVariable* getVariable(OpaqueVariableSet* set, Nd4jLong i);
|
||||||
|
ND4J_EXPORT int getVariableId(OpaqueVariable* variable);
|
||||||
|
ND4J_EXPORT int getVariableIndex(OpaqueVariable* variable);
|
||||||
|
ND4J_EXPORT const char* getVariableName(OpaqueVariable* variable);
|
||||||
|
ND4J_EXPORT Nd4jLong* getVariableShape(OpaqueVariable* variable);
|
||||||
|
ND4J_EXPORT void* getVariableBuffer(OpaqueVariable* variable);
|
||||||
|
|
||||||
ND4J_EXPORT int unregisterGraph(Nd4jPointer *extraPointers, Nd4jLong graphId);
|
ND4J_EXPORT int unregisterGraph(Nd4jPointer *extraPointers, Nd4jLong graphId);
|
||||||
|
|
||||||
|
@ -1668,7 +1699,7 @@ ND4J_EXPORT void deleteIntArray(Nd4jPointer pointer);
|
||||||
ND4J_EXPORT void deleteLongArray(Nd4jPointer pointer);
|
ND4J_EXPORT void deleteLongArray(Nd4jPointer pointer);
|
||||||
ND4J_EXPORT void deletePointerArray(Nd4jPointer pointer);
|
ND4J_EXPORT void deletePointerArray(Nd4jPointer pointer);
|
||||||
|
|
||||||
ND4J_EXPORT void deleteVariablesSet(Nd4jPointer pointer);
|
ND4J_EXPORT void deleteVariablesSet(OpaqueVariableSet pointer);
|
||||||
|
|
||||||
// GraphState creation
|
// GraphState creation
|
||||||
ND4J_EXPORT Nd4jPointer getGraphState(Nd4jLong id);
|
ND4J_EXPORT Nd4jPointer getGraphState(Nd4jLong id);
|
||||||
|
@ -1684,7 +1715,9 @@ ND4J_EXPORT Nd4jStatus execCustomOpWithScope(Nd4jPointer *extraPointers, Nd4jPoi
|
||||||
|
|
||||||
//void fillUtf8String(Nd4jPointer *extraPointers, const char **string, int numStrings, Nd4jPointer buffer);
|
//void fillUtf8String(Nd4jPointer *extraPointers, const char **string, int numStrings, Nd4jPointer buffer);
|
||||||
ND4J_EXPORT Nd4jPointer createUtf8String(Nd4jPointer *extraPointers, const char *string, int length);
|
ND4J_EXPORT Nd4jPointer createUtf8String(Nd4jPointer *extraPointers, const char *string, int length);
|
||||||
void deleteUtf8String(Nd4jPointer *extraPointers, Nd4jPointer ptr);
|
ND4J_EXPORT Nd4jLong getUtf8StringLength(Nd4jPointer *extraPointers, Nd4jPointer ptr);
|
||||||
|
ND4J_EXPORT char* getUtf8StringBuffer(Nd4jPointer *extraPointers, Nd4jPointer ptr);
|
||||||
|
ND4J_EXPORT void deleteUtf8String(Nd4jPointer *extraPointers, Nd4jPointer ptr);
|
||||||
|
|
||||||
ND4J_EXPORT void scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSubArrs,
|
ND4J_EXPORT void scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSubArrs,
|
||||||
void* hX, Nd4jLong* hXShapeInfo, Nd4jLong* hXOffsets,
|
void* hX, Nd4jLong* hXShapeInfo, Nd4jLong* hXOffsets,
|
||||||
|
@ -1693,18 +1726,45 @@ ND4J_EXPORT void scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOf
|
||||||
void* dY, Nd4jLong* dYShapeInfo, Nd4jLong* dYOffsets,
|
void* dY, Nd4jLong* dYShapeInfo, Nd4jLong* dYOffsets,
|
||||||
int* hIindexes, int* dIindexes);
|
int* hIindexes, int* dIindexes);
|
||||||
|
|
||||||
ND4J_EXPORT void deleteShapeBuffer(Nd4jPointer ptr);
|
|
||||||
ND4J_EXPORT void deleteTadPack(Nd4jPointer ptr);
|
|
||||||
|
|
||||||
ND4J_EXPORT void inspectArray(Nd4jPointer *extraPointers, Nd4jPointer buffer, Nd4jLong *shapeInfo, Nd4jPointer specialBuffer, Nd4jLong *specialShapeInfo, Nd4jPointer debugInfo);
|
ND4J_EXPORT void inspectArray(Nd4jPointer *extraPointers, Nd4jPointer buffer, Nd4jLong *shapeInfo, Nd4jPointer specialBuffer, Nd4jLong *specialShapeInfo, Nd4jPointer debugInfo);
|
||||||
|
|
||||||
|
|
||||||
ND4J_EXPORT nd4j::ConstantDataBuffer* shapeBuffer(int rank, Nd4jLong *shape, Nd4jLong *strides, nd4j::DataType dtype, char order, Nd4jLong ews, bool empty);
|
typedef nd4j::ConstantDataBuffer OpaqueConstantDataBuffer;
|
||||||
|
|
||||||
ND4J_EXPORT nd4j::ConstantDataBuffer* constantBufferLong(nd4j::DataType dtype, Nd4jLong *data, int length);
|
ND4J_EXPORT OpaqueConstantDataBuffer* shapeBuffer(int rank, Nd4jLong *shape, Nd4jLong *strides, nd4j::DataType dtype, char order, Nd4jLong ews, bool empty);
|
||||||
ND4J_EXPORT nd4j::ConstantDataBuffer* constantBufferDouble(nd4j::DataType dtype, double *data, int length);
|
|
||||||
ND4J_EXPORT nd4j::ConstantDataBuffer* constantBuffer(nd4j::DataType dtype, nd4j::ConstantDescriptor *descriptor);
|
|
||||||
|
|
||||||
|
ND4J_EXPORT OpaqueConstantDataBuffer* constantBufferLong(nd4j::DataType dtype, Nd4jLong *data, int length);
|
||||||
|
ND4J_EXPORT OpaqueConstantDataBuffer* constantBufferDouble(nd4j::DataType dtype, double *data, int length);
|
||||||
|
ND4J_EXPORT OpaqueConstantDataBuffer* constantBuffer(nd4j::DataType dtype, nd4j::ConstantDescriptor *descriptor);
|
||||||
|
|
||||||
|
ND4J_EXPORT Nd4jPointer getConstantDataBufferPrimary(OpaqueConstantDataBuffer* dbf);
|
||||||
|
ND4J_EXPORT Nd4jPointer getConstantDataBufferSpecial(OpaqueConstantDataBuffer* dbf);
|
||||||
|
ND4J_EXPORT Nd4jLong getConstantDataBufferLength(OpaqueConstantDataBuffer* dbf);
|
||||||
|
ND4J_EXPORT Nd4jLong getConstantDataBufferSizeOf(OpaqueConstantDataBuffer* dbf);
|
||||||
|
|
||||||
|
ND4J_EXPORT void deleteShapeBuffer(OpaqueConstantDataBuffer* ptr);
|
||||||
|
|
||||||
|
typedef nd4j::graph::Context OpaqueContext;
|
||||||
|
typedef nd4j::graph::RandomGenerator OpaqueRandomGenerator;
|
||||||
|
|
||||||
|
ND4J_EXPORT OpaqueContext* createGraphContext(int nodeId);
|
||||||
|
ND4J_EXPORT OpaqueRandomGenerator* getGraphContextRandomGenerator(OpaqueContext* ptr);
|
||||||
|
ND4J_EXPORT void markGraphContextInplace(OpaqueContext* ptr, bool reallyInplace);
|
||||||
|
ND4J_EXPORT void setGraphContextCudaContext(OpaqueContext* ptr, void *stream, void *reductionPointer, void *allocationPointer);
|
||||||
|
ND4J_EXPORT void setGraphContextInputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo);
|
||||||
|
ND4J_EXPORT void setGraphContextOutputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo);
|
||||||
|
ND4J_EXPORT void setGraphContextTArguments(OpaqueContext* ptr, double *arguments, int numberOfArguments);
|
||||||
|
ND4J_EXPORT void setGraphContextIArguments(OpaqueContext* ptr, Nd4jLong *arguments, int numberOfArguments);
|
||||||
|
ND4J_EXPORT void setGraphContextBArguments(OpaqueContext* ptr, bool *arguments, int numberOfArguments);
|
||||||
|
ND4J_EXPORT void deleteGraphContext(OpaqueContext* ptr);
|
||||||
|
|
||||||
|
ND4J_EXPORT OpaqueRandomGenerator* createRandomGenerator(Nd4jLong rootSeed = 0, Nd4jLong nodeSeed = 0);
|
||||||
|
ND4J_EXPORT Nd4jLong getRandomGeneratorRootState(OpaqueRandomGenerator* ptr);
|
||||||
|
ND4J_EXPORT Nd4jLong getRandomGeneratorNodeState(OpaqueRandomGenerator* ptr);
|
||||||
|
ND4J_EXPORT void setRandomGeneratorStates(OpaqueRandomGenerator* ptr, Nd4jLong rootSeed = 0, Nd4jLong nodeSeed = 0);
|
||||||
|
ND4J_EXPORT int getRandomGeneratorRelativeInt(OpaqueRandomGenerator* ptr, Nd4jLong index);
|
||||||
|
ND4J_EXPORT Nd4jLong getRandomGeneratorRelativeLong(OpaqueRandomGenerator* ptr, Nd4jLong index);
|
||||||
|
ND4J_EXPORT void deleteRandomGenerator(OpaqueRandomGenerator* ptr);
|
||||||
|
|
||||||
ND4J_EXPORT const char* runLightBenchmarkSuit(bool printOut);
|
ND4J_EXPORT const char* runLightBenchmarkSuit(bool printOut);
|
||||||
ND4J_EXPORT const char* runFullBenchmarkSuit(bool printOut);
|
ND4J_EXPORT const char* runFullBenchmarkSuit(bool printOut);
|
||||||
|
|
|
@ -1328,6 +1328,25 @@ nd4j::TadPack* tadOnlyShapeInfo(Nd4jLong *hXShapeInfo, int *dimension, int dimen
|
||||||
return pack;
|
return pack;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Nd4jLong* getPrimaryShapeInfo(nd4j::TadPack* pack) {
|
||||||
|
return pack->primaryShapeInfo();
|
||||||
|
}
|
||||||
|
Nd4jLong* getPrimaryOffsets(nd4j::TadPack* pack) {
|
||||||
|
return pack->primaryOffsets();
|
||||||
|
}
|
||||||
|
Nd4jLong* getSpecialShapeInfo(nd4j::TadPack* pack) {
|
||||||
|
return pack->specialShapeInfo();
|
||||||
|
}
|
||||||
|
Nd4jLong* getSpecialOffsets(nd4j::TadPack* pack) {
|
||||||
|
return pack->specialOffsets();
|
||||||
|
}
|
||||||
|
Nd4jLong getNumberOfTads(nd4j::TadPack* pack) {
|
||||||
|
return pack->numberOfTads();
|
||||||
|
}
|
||||||
|
int getShapeInfoLength(nd4j::TadPack* pack) {
|
||||||
|
return pack->shapeInfoLength();
|
||||||
|
}
|
||||||
|
|
||||||
int memcpyConstantAsync(Nd4jLong dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4jPointer reserved) {
|
int memcpyConstantAsync(Nd4jLong dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4jPointer reserved) {
|
||||||
// no-op
|
// no-op
|
||||||
return 0L;
|
return 0L;
|
||||||
|
@ -2005,6 +2024,13 @@ nd4j::graph::ResultWrapper* executeFlatGraph(Nd4jPointer *extraPointers, Nd4jPoi
|
||||||
return nd4j::graph::GraphExecutioner::executeFlatBuffer(flatBufferPointer);
|
return nd4j::graph::GraphExecutioner::executeFlatBuffer(flatBufferPointer);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Nd4jLong getResultWrapperSize(nd4j::graph::ResultWrapper* ptr) {
|
||||||
|
return ptr->size();
|
||||||
|
}
|
||||||
|
Nd4jPointer getResultWrapperPointer(nd4j::graph::ResultWrapper* ptr) {
|
||||||
|
return ptr->pointer();
|
||||||
|
}
|
||||||
|
|
||||||
const char* getAllCustomOps() {
|
const char* getAllCustomOps() {
|
||||||
return nd4j::ops::OpRegistrator::getInstance()->getAllCustomOperations();
|
return nd4j::ops::OpRegistrator::getInstance()->getAllCustomOperations();
|
||||||
}
|
}
|
||||||
|
@ -2041,7 +2067,13 @@ int estimateThreshold(Nd4jPointer *extraPointers, Nd4jPointer hX, Nd4jLong *hXSh
|
||||||
BUILD_SINGLE_SELECTOR(xType, return estimateThresholdGeneric, (extraPointers, hX, N, threshold), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR(xType, return estimateThresholdGeneric, (extraPointers, hX, N, threshold), FLOAT_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Nd4jLong getShapeListSize(nd4j::ShapeList* list) {
|
||||||
|
return list->size();
|
||||||
|
}
|
||||||
|
|
||||||
|
Nd4jLong* getShape(nd4j::ShapeList* list, Nd4jLong i) {
|
||||||
|
return list->at(i);
|
||||||
|
}
|
||||||
|
|
||||||
void deleteShapeList(Nd4jPointer shapeList) {
|
void deleteShapeList(Nd4jPointer shapeList) {
|
||||||
auto list = reinterpret_cast<nd4j::ShapeList*>(shapeList);
|
auto list = reinterpret_cast<nd4j::ShapeList*>(shapeList);
|
||||||
|
@ -2305,6 +2337,38 @@ nd4j::graph::VariablesSet* executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLo
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Nd4jLong getVariableSetSize(nd4j::graph::VariablesSet* set) {
|
||||||
|
return set->size();
|
||||||
|
}
|
||||||
|
|
||||||
|
Nd4jStatus getVariableSetStatus(nd4j::graph::VariablesSet* set) {
|
||||||
|
return set->status();
|
||||||
|
}
|
||||||
|
|
||||||
|
nd4j::graph::Variable* getVariable(nd4j::graph::VariablesSet* set, Nd4jLong i) {
|
||||||
|
return set->at(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
int getVariableId(nd4j::graph::Variable* variable) {
|
||||||
|
return variable->id();
|
||||||
|
}
|
||||||
|
|
||||||
|
int getVariableIndex(nd4j::graph::Variable* variable) {
|
||||||
|
return variable->index();
|
||||||
|
}
|
||||||
|
|
||||||
|
const char* getVariableName(nd4j::graph::Variable* variable) {
|
||||||
|
return variable->getName()->c_str();
|
||||||
|
}
|
||||||
|
|
||||||
|
Nd4jLong* getVariableShape(nd4j::graph::Variable* variable) {
|
||||||
|
return variable->getNDArray()->shapeInfo();
|
||||||
|
}
|
||||||
|
|
||||||
|
void* getVariableBuffer(nd4j::graph::Variable* variable) {
|
||||||
|
return variable->getNDArray()->buffer();
|
||||||
|
}
|
||||||
|
|
||||||
int unregisterGraph(Nd4jPointer *extraPointers, Nd4jLong graphId) {
|
int unregisterGraph(Nd4jPointer *extraPointers, Nd4jLong graphId) {
|
||||||
|
|
||||||
nd4j::graph::GraphHolder::getInstance()->dropGraphAny(graphId);
|
nd4j::graph::GraphHolder::getInstance()->dropGraphAny(graphId);
|
||||||
|
@ -2628,6 +2692,13 @@ Nd4jPointer createUtf8String(Nd4jPointer *extraPointers, const char *string, int
|
||||||
return reinterpret_cast<Nd4jPointer>(u);
|
return reinterpret_cast<Nd4jPointer>(u);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Nd4jLong getUtf8StringLength(Nd4jPointer *extraPointers, Nd4jPointer ptr) {
|
||||||
|
return reinterpret_cast<nd4j::utf8string*>(ptr)->_length;
|
||||||
|
}
|
||||||
|
char* getUtf8StringBuffer(Nd4jPointer *extraPointers, Nd4jPointer ptr) {
|
||||||
|
return reinterpret_cast<nd4j::utf8string*>(ptr)->_buffer;
|
||||||
|
}
|
||||||
|
|
||||||
void deleteUtf8String(Nd4jPointer *extraPointers, Nd4jPointer ptr) {
|
void deleteUtf8String(Nd4jPointer *extraPointers, Nd4jPointer ptr) {
|
||||||
delete(reinterpret_cast<nd4j::utf8string*>(ptr));
|
delete(reinterpret_cast<nd4j::utf8string*>(ptr));
|
||||||
}
|
}
|
||||||
|
@ -2710,14 +2781,12 @@ nd4j::ConstantDataBuffer* shapeBuffer(int rank, Nd4jLong *shape, Nd4jLong *strid
|
||||||
return buffer;
|
return buffer;
|
||||||
}
|
}
|
||||||
|
|
||||||
void deleteShapeBuffer(Nd4jPointer ptr) {
|
void deleteShapeBuffer(nd4j::ConstantDataBuffer* ptr) {
|
||||||
auto buffer = reinterpret_cast<nd4j::ConstantDataBuffer*>(ptr);
|
delete ptr;
|
||||||
delete buffer;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void deleteTadPack(Nd4jPointer ptr) {
|
void deleteTadPack(nd4j::TadPack* ptr) {
|
||||||
auto buffer = reinterpret_cast<nd4j::TadPack*>(ptr);
|
delete ptr;
|
||||||
delete buffer;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
nd4j::ConstantDataBuffer* constantBufferLong(nd4j::DataType dtype, Nd4jLong *data, int length) {
|
nd4j::ConstantDataBuffer* constantBufferLong(nd4j::DataType dtype, Nd4jLong *data, int length) {
|
||||||
|
@ -2732,6 +2801,78 @@ nd4j::ConstantDataBuffer* constantBuffer(nd4j::DataType dtype, nd4j::ConstantDes
|
||||||
return nd4j::ConstantHelper::getInstance()->constantBuffer(*descriptor, dtype);
|
return nd4j::ConstantHelper::getInstance()->constantBuffer(*descriptor, dtype);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Nd4jPointer getConstantDataBufferPrimary(nd4j::ConstantDataBuffer* dbf) {
|
||||||
|
return dbf->primary();
|
||||||
|
}
|
||||||
|
Nd4jPointer getConstantDataBufferSpecial(nd4j::ConstantDataBuffer* dbf) {
|
||||||
|
return dbf->special();
|
||||||
|
}
|
||||||
|
Nd4jLong getConstantDataBufferLength(nd4j::ConstantDataBuffer* dbf) {
|
||||||
|
return dbf->length();
|
||||||
|
}
|
||||||
|
Nd4jLong getConstantDataBufferSizeOf(nd4j::ConstantDataBuffer* dbf) {
|
||||||
|
return dbf->sizeOf();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
nd4j::graph::Context* createGraphContext(int nodeId) {
|
||||||
|
return new nd4j::graph::Context(nodeId);
|
||||||
|
}
|
||||||
|
nd4j::graph::RandomGenerator* getGraphContextRandomGenerator(nd4j::graph::Context* ptr) {
|
||||||
|
return &ptr->randomGenerator();
|
||||||
|
}
|
||||||
|
void markGraphContextInplace(nd4j::graph::Context* ptr, bool reallyInplace) {
|
||||||
|
ptr->markInplace(reallyInplace);
|
||||||
|
}
|
||||||
|
void setGraphContextCudaContext(nd4j::graph::Context* ptr, void *stream, void *reductionPointer, void *allocationPointer) {
|
||||||
|
}
|
||||||
|
void setGraphContextInputArray(nd4j::graph::Context* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo) {
|
||||||
|
ptr->setInputArray(index, buffer, shapeInfo, specialBuffer, specialShapeInfo);
|
||||||
|
}
|
||||||
|
void setGraphContextOutputArray(nd4j::graph::Context* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo) {
|
||||||
|
ptr->setOutputArray(index, buffer, shapeInfo, specialBuffer, specialShapeInfo);
|
||||||
|
}
|
||||||
|
void setGraphContextTArguments(nd4j::graph::Context* ptr, double *arguments, int numberOfArguments) {
|
||||||
|
ptr->setTArguments(arguments, numberOfArguments);
|
||||||
|
}
|
||||||
|
void setGraphContextIArguments(nd4j::graph::Context* ptr, Nd4jLong *arguments, int numberOfArguments) {
|
||||||
|
ptr->setIArguments(arguments, numberOfArguments);
|
||||||
|
}
|
||||||
|
void setGraphContextBArguments(nd4j::graph::Context* ptr, bool *arguments, int numberOfArguments) {
|
||||||
|
ptr->setBArguments(arguments, numberOfArguments);
|
||||||
|
}
|
||||||
|
void deleteGraphContext(nd4j::graph::Context* ptr) {
|
||||||
|
delete ptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
nd4j::graph::RandomGenerator* createRandomGenerator(Nd4jLong rootSeed, Nd4jLong nodeSeed) {
|
||||||
|
return new nd4j::graph::RandomGenerator(rootSeed, nodeSeed);
|
||||||
|
}
|
||||||
|
|
||||||
|
Nd4jLong getRandomGeneratorRootState(nd4j::graph::RandomGenerator* ptr) {
|
||||||
|
return ptr->rootState();
|
||||||
|
}
|
||||||
|
|
||||||
|
Nd4jLong getRandomGeneratorNodeState(nd4j::graph::RandomGenerator* ptr) {
|
||||||
|
return ptr->nodeState();
|
||||||
|
}
|
||||||
|
|
||||||
|
void setRandomGeneratorStates(nd4j::graph::RandomGenerator* ptr, Nd4jLong rootSeed, Nd4jLong nodeSeed) {
|
||||||
|
ptr->setStates(rootSeed, nodeSeed);
|
||||||
|
}
|
||||||
|
|
||||||
|
int getRandomGeneratorRelativeInt(nd4j::graph::RandomGenerator* ptr, Nd4jLong index) {
|
||||||
|
return ptr->relativeInt(index);
|
||||||
|
}
|
||||||
|
|
||||||
|
Nd4jLong getRandomGeneratorRelativeLong(nd4j::graph::RandomGenerator* ptr, Nd4jLong index) {
|
||||||
|
return ptr->relativeLong(index);
|
||||||
|
}
|
||||||
|
|
||||||
|
void deleteRandomGenerator(nd4j::graph::RandomGenerator* ptr) {
|
||||||
|
delete ptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
int dataTypeFromNpyHeader(void *header) {
|
int dataTypeFromNpyHeader(void *header) {
|
||||||
|
|
|
@ -1499,6 +1499,25 @@ nd4j::TadPack* tadOnlyShapeInfo(Nd4jLong *dXShapeInfo, int *dimension, int dimen
|
||||||
return pack;
|
return pack;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Nd4jLong* getPrimaryShapeInfo(nd4j::TadPack* pack) {
|
||||||
|
return pack->primaryShapeInfo();
|
||||||
|
}
|
||||||
|
Nd4jLong* getPrimaryOffsets(nd4j::TadPack* pack) {
|
||||||
|
return pack->primaryOffsets();
|
||||||
|
}
|
||||||
|
Nd4jLong* getSpecialShapeInfo(nd4j::TadPack* pack) {
|
||||||
|
return pack->specialShapeInfo();
|
||||||
|
}
|
||||||
|
Nd4jLong* getSpecialOffsets(nd4j::TadPack* pack) {
|
||||||
|
return pack->specialOffsets();
|
||||||
|
}
|
||||||
|
Nd4jLong getNumberOfTads(nd4j::TadPack* pack) {
|
||||||
|
return pack->numberOfTads();
|
||||||
|
}
|
||||||
|
int getShapeInfoLength(nd4j::TadPack* pack) {
|
||||||
|
return pack->shapeInfoLength();
|
||||||
|
}
|
||||||
|
|
||||||
int memcpyConstantAsync(Nd4jLong dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4jPointer reserved) {
|
int memcpyConstantAsync(Nd4jLong dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4jPointer reserved) {
|
||||||
cudaStream_t *pStream = reinterpret_cast<cudaStream_t *>(reserved);
|
cudaStream_t *pStream = reinterpret_cast<cudaStream_t *>(reserved);
|
||||||
|
|
||||||
|
@ -2533,6 +2552,13 @@ nd4j::graph::ResultWrapper* executeFlatGraph(Nd4jPointer *extraPointers, Nd4jPoi
|
||||||
return nd4j::graph::GraphExecutioner::executeFlatBuffer(flatBufferPointer);
|
return nd4j::graph::GraphExecutioner::executeFlatBuffer(flatBufferPointer);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Nd4jLong getResultWrapperSize(nd4j::graph::ResultWrapper* ptr) {
|
||||||
|
return ptr->size();
|
||||||
|
}
|
||||||
|
Nd4jPointer getResultWrapperPointer(nd4j::graph::ResultWrapper* ptr) {
|
||||||
|
return ptr->pointer();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
const char* getAllCustomOps() {
|
const char* getAllCustomOps() {
|
||||||
return nd4j::ops::OpRegistrator::getInstance()->getAllCustomOperations();
|
return nd4j::ops::OpRegistrator::getInstance()->getAllCustomOperations();
|
||||||
|
@ -2607,6 +2633,13 @@ nd4j::ShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, Nd4jLong hash
|
||||||
return _calculateOutputShapes(extraPointers, op, inputShapes, numInputShapes, tArgs, numTArgs, iArgs, numIArgs);
|
return _calculateOutputShapes(extraPointers, op, inputShapes, numInputShapes, tArgs, numTArgs, iArgs, numIArgs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Nd4jLong getShapeListSize(nd4j::ShapeList* list) {
|
||||||
|
return list->size();
|
||||||
|
}
|
||||||
|
|
||||||
|
Nd4jLong* getShape(nd4j::ShapeList* list, Nd4jLong i) {
|
||||||
|
return list->at(i);
|
||||||
|
}
|
||||||
|
|
||||||
static FORCEINLINE Nd4jStatus realExec(nd4j::ops::DeclarableOp* op, Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputs, Nd4jPointer* outputBuffers, Nd4jPointer* outputShapes, int numOutputs, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool* bArgs, int numBArgs, bool isInplace) {
|
static FORCEINLINE Nd4jStatus realExec(nd4j::ops::DeclarableOp* op, Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputs, Nd4jPointer* outputBuffers, Nd4jPointer* outputShapes, int numOutputs, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool* bArgs, int numBArgs, bool isInplace) {
|
||||||
if (op == nullptr)
|
if (op == nullptr)
|
||||||
|
@ -2775,6 +2808,38 @@ VariablesSet* executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, N
|
||||||
return executeStoredGraphT(extraPointers, graphId, inputBuffers, inputShapes, inputIndices, numInputs);
|
return executeStoredGraphT(extraPointers, graphId, inputBuffers, inputShapes, inputIndices, numInputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Nd4jLong getVariableSetSize(nd4j::graph::VariablesSet* set) {
|
||||||
|
return set->size();
|
||||||
|
}
|
||||||
|
|
||||||
|
Nd4jStatus getVariableSetStatus(nd4j::graph::VariablesSet* set) {
|
||||||
|
return set->status();
|
||||||
|
}
|
||||||
|
|
||||||
|
nd4j::graph::Variable* getVariable(nd4j::graph::VariablesSet* set, Nd4jLong i) {
|
||||||
|
return set->at(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
int getVariableId(nd4j::graph::Variable* variable) {
|
||||||
|
return variable->id();
|
||||||
|
}
|
||||||
|
|
||||||
|
int getVariableIndex(nd4j::graph::Variable* variable) {
|
||||||
|
return variable->index();
|
||||||
|
}
|
||||||
|
|
||||||
|
const char* getVariableName(nd4j::graph::Variable* variable) {
|
||||||
|
return variable->getName()->c_str();
|
||||||
|
}
|
||||||
|
|
||||||
|
Nd4jLong* getVariableShape(nd4j::graph::Variable* variable) {
|
||||||
|
return variable->getNDArray()->shapeInfo();
|
||||||
|
}
|
||||||
|
|
||||||
|
void* getVariableBuffer(nd4j::graph::Variable* variable) {
|
||||||
|
return variable->getNDArray()->buffer();
|
||||||
|
}
|
||||||
|
|
||||||
int unregisterGraph(Nd4jPointer *extraPointers, Nd4jLong graphId) {
|
int unregisterGraph(Nd4jPointer *extraPointers, Nd4jLong graphId) {
|
||||||
|
|
||||||
nd4j::graph::GraphHolder::getInstance()->dropGraphAny(graphId);
|
nd4j::graph::GraphHolder::getInstance()->dropGraphAny(graphId);
|
||||||
|
@ -3102,6 +3167,13 @@ Nd4jPointer createUtf8String(Nd4jPointer *extraPointers, const char *string, int
|
||||||
return reinterpret_cast<Nd4jPointer>(u);
|
return reinterpret_cast<Nd4jPointer>(u);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Nd4jLong getUtf8StringLength(Nd4jPointer *extraPointers, Nd4jPointer ptr) {
|
||||||
|
return reinterpret_cast<nd4j::utf8string*>(ptr)->_length;
|
||||||
|
}
|
||||||
|
char* getUtf8StringBuffer(Nd4jPointer *extraPointers, Nd4jPointer ptr) {
|
||||||
|
return reinterpret_cast<nd4j::utf8string*>(ptr)->_buffer;
|
||||||
|
}
|
||||||
|
|
||||||
void deleteUtf8String(Nd4jPointer *extraPointers, Nd4jPointer ptr) {
|
void deleteUtf8String(Nd4jPointer *extraPointers, Nd4jPointer ptr) {
|
||||||
delete(reinterpret_cast<nd4j::utf8string*>(ptr));
|
delete(reinterpret_cast<nd4j::utf8string*>(ptr));
|
||||||
}
|
}
|
||||||
|
@ -3237,14 +3309,12 @@ nd4j::ConstantDataBuffer* shapeBuffer(int rank, Nd4jLong *shape, Nd4jLong *strid
|
||||||
return buffer;
|
return buffer;
|
||||||
}
|
}
|
||||||
|
|
||||||
void deleteShapeBuffer(Nd4jPointer ptr) {
|
void deleteShapeBuffer(nd4j::ConstantDataBuffer* ptr) {
|
||||||
auto buffer = reinterpret_cast<nd4j::ConstantDataBuffer*>(ptr);
|
delete ptr;
|
||||||
delete buffer;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void deleteTadPack(Nd4jPointer ptr) {
|
void deleteTadPack(nd4j::TadPack* ptr) {
|
||||||
auto buffer = reinterpret_cast<nd4j::TadPack*>(ptr);
|
delete ptr;
|
||||||
delete buffer;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
nd4j::ConstantDataBuffer* constantBufferLong(nd4j::DataType dtype, Nd4jLong *data, int length) {
|
nd4j::ConstantDataBuffer* constantBufferLong(nd4j::DataType dtype, Nd4jLong *data, int length) {
|
||||||
|
@ -3259,6 +3329,82 @@ nd4j::ConstantDataBuffer* constantBuffer(nd4j::DataType dtype, nd4j::ConstantDes
|
||||||
return nd4j::ConstantHelper::getInstance()->constantBuffer(*descriptor, dtype);
|
return nd4j::ConstantHelper::getInstance()->constantBuffer(*descriptor, dtype);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
Nd4jPointer getConstantDataBufferPrimary(nd4j::ConstantDataBuffer* dbf) {
|
||||||
|
return dbf->primary();
|
||||||
|
}
|
||||||
|
Nd4jPointer getConstantDataBufferSpecial(nd4j::ConstantDataBuffer* dbf) {
|
||||||
|
return dbf->special();
|
||||||
|
}
|
||||||
|
Nd4jLong getConstantDataBufferLength(nd4j::ConstantDataBuffer* dbf) {
|
||||||
|
return dbf->length();
|
||||||
|
}
|
||||||
|
Nd4jLong getConstantDataBufferSizeOf(nd4j::ConstantDataBuffer* dbf) {
|
||||||
|
return dbf->sizeOf();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
nd4j::graph::Context* createGraphContext(int nodeId) {
|
||||||
|
return new nd4j::graph::Context(nodeId);
|
||||||
|
}
|
||||||
|
nd4j::graph::RandomGenerator* getGraphContextRandomGenerator(nd4j::graph::Context* ptr) {
|
||||||
|
return &ptr->randomGenerator();
|
||||||
|
}
|
||||||
|
void markGraphContextInplace(nd4j::graph::Context* ptr, bool reallyInplace) {
|
||||||
|
ptr->markInplace(reallyInplace);
|
||||||
|
}
|
||||||
|
void setGraphContextCudaContext(nd4j::graph::Context* ptr, void *stream, void *reductionPointer, void *allocationPointer) {
|
||||||
|
ptr->setCudaContext(stream, reductionPointer, allocationPointer);
|
||||||
|
}
|
||||||
|
void setGraphContextInputArray(nd4j::graph::Context* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo) {
|
||||||
|
ptr->setInputArray(index, buffer, shapeInfo, specialBuffer, specialShapeInfo);
|
||||||
|
}
|
||||||
|
void setGraphContextOutputArray(nd4j::graph::Context* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo) {
|
||||||
|
ptr->setOutputArray(index, buffer, shapeInfo, specialBuffer, specialShapeInfo);
|
||||||
|
}
|
||||||
|
void setGraphContextTArguments(nd4j::graph::Context* ptr, double *arguments, int numberOfArguments) {
|
||||||
|
ptr->setTArguments(arguments, numberOfArguments);
|
||||||
|
}
|
||||||
|
void setGraphContextIArguments(nd4j::graph::Context* ptr, Nd4jLong *arguments, int numberOfArguments) {
|
||||||
|
ptr->setIArguments(arguments, numberOfArguments);
|
||||||
|
}
|
||||||
|
void setGraphContextBArguments(nd4j::graph::Context* ptr, bool *arguments, int numberOfArguments) {
|
||||||
|
ptr->setBArguments(arguments, numberOfArguments);
|
||||||
|
}
|
||||||
|
void deleteGraphContext(nd4j::graph::Context* ptr) {
|
||||||
|
delete ptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
nd4j::graph::RandomGenerator* createRandomGenerator(Nd4jLong rootSeed, Nd4jLong nodeSeed) {
|
||||||
|
return new nd4j::graph::RandomGenerator(rootSeed, nodeSeed);
|
||||||
|
}
|
||||||
|
|
||||||
|
Nd4jLong getRandomGeneratorRootState(nd4j::graph::RandomGenerator* ptr) {
|
||||||
|
return ptr->rootState();
|
||||||
|
}
|
||||||
|
|
||||||
|
Nd4jLong getRandomGeneratorNodeState(nd4j::graph::RandomGenerator* ptr) {
|
||||||
|
return ptr->nodeState();
|
||||||
|
}
|
||||||
|
|
||||||
|
void setRandomGeneratorStates(nd4j::graph::RandomGenerator* ptr, Nd4jLong rootSeed, Nd4jLong nodeSeed) {
|
||||||
|
ptr->setStates(rootSeed, nodeSeed);
|
||||||
|
}
|
||||||
|
|
||||||
|
int getRandomGeneratorRelativeInt(nd4j::graph::RandomGenerator* ptr, Nd4jLong index) {
|
||||||
|
return ptr->relativeInt(index);
|
||||||
|
}
|
||||||
|
|
||||||
|
Nd4jLong getRandomGeneratorRelativeLong(nd4j::graph::RandomGenerator* ptr, Nd4jLong index) {
|
||||||
|
return ptr->relativeLong(index);
|
||||||
|
}
|
||||||
|
|
||||||
|
void deleteRandomGenerator(nd4j::graph::RandomGenerator* ptr) {
|
||||||
|
delete ptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
Nd4jPointer shapeBufferForNumpy(Nd4jPointer npyArray) {
|
Nd4jPointer shapeBufferForNumpy(Nd4jPointer npyArray) {
|
||||||
cnpy::NpyArray arr = cnpy::loadNpyFromPointer(reinterpret_cast<char *>(npyArray));
|
cnpy::NpyArray arr = cnpy::loadNpyFromPointer(reinterpret_cast<char *>(npyArray));
|
||||||
unsigned int shapeSize = arr.shape.size();
|
unsigned int shapeSize = arr.shape.size();
|
||||||
|
|
|
@ -27,6 +27,9 @@
|
||||||
#include <ops/declarable/CustomOperations.h>
|
#include <ops/declarable/CustomOperations.h>
|
||||||
#include <graph/GraphUtils.h>
|
#include <graph/GraphUtils.h>
|
||||||
|
|
||||||
|
using namespace nd4j::ops;
|
||||||
|
using namespace nd4j::graph;
|
||||||
|
|
||||||
int
|
int
|
||||||
main(int argc, char *argv[]) {
|
main(int argc, char *argv[]) {
|
||||||
// this string will contain list of operations
|
// this string will contain list of operations
|
||||||
|
|
|
@ -25,6 +25,7 @@
|
||||||
#include <ops/declarable/CustomOperations.h>
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
|
||||||
using namespace nd4j;
|
using namespace nd4j;
|
||||||
|
using namespace nd4j::ops;
|
||||||
using namespace nd4j::graph;
|
using namespace nd4j::graph;
|
||||||
|
|
||||||
class OpTrackerTests : public testing::Test {
|
class OpTrackerTests : public testing::Test {
|
||||||
|
|
|
@ -28,7 +28,7 @@ import java.util.List;
|
||||||
*
|
*
|
||||||
* @author raver119@gmail.com
|
* @author raver119@gmail.com
|
||||||
*/
|
*/
|
||||||
public interface OpContext {
|
public interface OpContext extends AutoCloseable {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method sets integer arguments required for operation
|
* This method sets integer arguments required for operation
|
||||||
|
|
|
@ -38,8 +38,9 @@ import org.nd4j.linalg.api.ops.Op;
|
||||||
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
||||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.nativeblas.NativeOps;
|
||||||
import org.nd4j.nativeblas.NativeOpsHolder;
|
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||||
import org.nd4j.nativeblas.ResultWrapperAbstraction;
|
import org.nd4j.nativeblas.OpaqueResultWrapper;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
|
@ -100,11 +101,12 @@ public class NativeGraphExecutioner implements GraphExecutioner {
|
||||||
|
|
||||||
log.info("Buffer length: {}", buffer.limit());
|
log.info("Buffer length: {}", buffer.limit());
|
||||||
|
|
||||||
val res = NativeOpsHolder.getInstance().getDeviceNativeOps().executeFlatGraph(null, bPtr);
|
NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
|
||||||
|
OpaqueResultWrapper res = nativeOps.executeFlatGraph(null, bPtr);
|
||||||
if (res == null)
|
if (res == null)
|
||||||
throw new ND4JIllegalStateException("Graph execution failed");
|
throw new ND4JIllegalStateException("Graph execution failed");
|
||||||
|
|
||||||
PagedPointer pagedPointer = new PagedPointer(res.pointer(),res.size());
|
PagedPointer pagedPointer = new PagedPointer(nativeOps.getResultWrapperPointer(res), nativeOps.getResultWrapperSize(res));
|
||||||
FlatResult fr = FlatResult.getRootAsFlatResult(pagedPointer.asBytePointer().asByteBuffer());
|
FlatResult fr = FlatResult.getRootAsFlatResult(pagedPointer.asBytePointer().asByteBuffer());
|
||||||
|
|
||||||
log.info("VarMap: {}", sd.variableMap());
|
log.info("VarMap: {}", sd.variableMap());
|
||||||
|
@ -132,7 +134,7 @@ public class NativeGraphExecutioner implements GraphExecutioner {
|
||||||
}
|
}
|
||||||
|
|
||||||
// now we need to release native memory
|
// now we need to release native memory
|
||||||
NativeOpsHolder.getInstance().getDeviceNativeOps().deleteResultWrapper(res);
|
nativeOps.deleteResultWrapper(res);
|
||||||
|
|
||||||
return results;
|
return results;
|
||||||
}
|
}
|
||||||
|
|
|
@ -697,7 +697,16 @@ public interface NativeOps {
|
||||||
|
|
||||||
void setGridLimit(int gridSize);
|
void setGridLimit(int gridSize);
|
||||||
|
|
||||||
Pointer tadOnlyShapeInfo(@Cast("Nd4jLong *") LongPointer shapeInfo, IntPointer dimension, int dimensionLength);
|
OpaqueTadPack tadOnlyShapeInfo(LongPointer shapeInfo, IntPointer dimension, int dimensionLength);
|
||||||
|
|
||||||
|
LongPointer getPrimaryShapeInfo(OpaqueTadPack pack);
|
||||||
|
LongPointer getPrimaryOffsets(OpaqueTadPack pack);
|
||||||
|
LongPointer getSpecialShapeInfo(OpaqueTadPack pack);
|
||||||
|
LongPointer getSpecialOffsets(OpaqueTadPack pack);
|
||||||
|
long getNumberOfTads(OpaqueTadPack pack);
|
||||||
|
int getShapeInfoLength(OpaqueTadPack pack);
|
||||||
|
|
||||||
|
void deleteTadPack(OpaqueTadPack pointer);
|
||||||
|
|
||||||
///////////////
|
///////////////
|
||||||
|
|
||||||
|
@ -1037,7 +1046,10 @@ public interface NativeOps {
|
||||||
|
|
||||||
void munmapFile(PointerPointer extraPointers, LongPointer ptrMap, long length);
|
void munmapFile(PointerPointer extraPointers, LongPointer ptrMap, long length);
|
||||||
|
|
||||||
ResultWrapperAbstraction executeFlatGraph(PointerPointer extraPointers, Pointer flatBufferPointer);
|
OpaqueResultWrapper executeFlatGraph(PointerPointer extraPointers, Pointer flatBufferPointer);
|
||||||
|
|
||||||
|
long getResultWrapperSize(OpaqueResultWrapper ptr);
|
||||||
|
Pointer getResultWrapperPointer(OpaqueResultWrapper ptr);
|
||||||
|
|
||||||
String getAllCustomOps();
|
String getAllCustomOps();
|
||||||
|
|
||||||
|
@ -1047,13 +1059,25 @@ public interface NativeOps {
|
||||||
|
|
||||||
int execCustomOp(PointerPointer extraPointers, long opHashCode, PointerPointer inputBuffers, PointerPointer inputShapes, int numInput, PointerPointer outputBuffers, PointerPointer outputShapes, int numOutputs, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong *") LongPointer iArgs, int numIArgs, @Cast("bool *") BooleanPointer bArgs, int numBArgs, boolean isInplace);
|
int execCustomOp(PointerPointer extraPointers, long opHashCode, PointerPointer inputBuffers, PointerPointer inputShapes, int numInput, PointerPointer outputBuffers, PointerPointer outputShapes, int numOutputs, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong *") LongPointer iArgs, int numIArgs, @Cast("bool *") BooleanPointer bArgs, int numBArgs, boolean isInplace);
|
||||||
|
|
||||||
Pointer calculateOutputShapes(PointerPointer extraPointers, long hash, PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong *") LongPointer iArgs, int numIArgs);
|
OpaqueShapeList calculateOutputShapes(PointerPointer extraPointers, long hash, PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong *") LongPointer iArgs, int numIArgs);
|
||||||
|
|
||||||
Pointer calculateOutputShapes2(PointerPointer extraPointers, long hash, PointerPointer inputBunffers, PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong *") LongPointer iArgs, int numIArgs, @Cast("bool *") BooleanPointer bArgs, int numBArgs);
|
OpaqueShapeList calculateOutputShapes2(PointerPointer extraPointers, long hash, PointerPointer inputBunffers, PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong *") LongPointer iArgs, int numIArgs, @Cast("bool *") BooleanPointer bArgs, int numBArgs);
|
||||||
|
|
||||||
|
long getShapeListSize(OpaqueShapeList list);
|
||||||
|
LongPointer getShape(OpaqueShapeList list, long i);
|
||||||
|
|
||||||
int registerGraph(PointerPointer extraPointers, long graphId, Pointer flatBufferPointer);
|
int registerGraph(PointerPointer extraPointers, long graphId, Pointer flatBufferPointer);
|
||||||
|
|
||||||
Pointer executeStoredGraph(PointerPointer extraPointers, long graphId, PointerPointer inputBuffers, PointerPointer inputShapes, IntPointer inputIndices, int numInputs);
|
OpaqueVariableSet executeStoredGraph(PointerPointer extraPointers, long graphId, PointerPointer inputBuffers, PointerPointer inputShapes, IntPointer inputIndices, int numInputs);
|
||||||
|
|
||||||
|
long getVariableSetSize(OpaqueVariableSet set);
|
||||||
|
int getVariableSetStatus(OpaqueVariableSet set);
|
||||||
|
OpaqueVariable getVariable(OpaqueVariableSet set, long i);
|
||||||
|
int getVariableId(OpaqueVariable variable);
|
||||||
|
int getVariableIndex(OpaqueVariable variable);
|
||||||
|
String getVariableName(OpaqueVariable variable);
|
||||||
|
LongPointer getVariableShape(OpaqueVariable variable);
|
||||||
|
Pointer getVariableBuffer(OpaqueVariable variable);
|
||||||
|
|
||||||
void deleteResultWrapper(Pointer ptr);
|
void deleteResultWrapper(Pointer ptr);
|
||||||
|
|
||||||
|
@ -1071,15 +1095,11 @@ public interface NativeOps {
|
||||||
|
|
||||||
void deleteNPArrayMap(Pointer pointer);
|
void deleteNPArrayMap(Pointer pointer);
|
||||||
|
|
||||||
void deleteVariablesSet(Pointer pointer);
|
void deleteVariablesSet(OpaqueVariableSet pointer);
|
||||||
|
|
||||||
// GraphState creation
|
// GraphState creation
|
||||||
Pointer getGraphState(long id);
|
Pointer getGraphState(long id);
|
||||||
|
|
||||||
void deleteShapeBuffer(Pointer state);
|
|
||||||
|
|
||||||
void deleteTadPack(Pointer pointer);
|
|
||||||
|
|
||||||
void deleteGraphState(Pointer state);
|
void deleteGraphState(Pointer state);
|
||||||
|
|
||||||
int estimateThreshold(PointerPointer extraPointers, Pointer x, LongPointer xShapeInfo, int N, float threshold);
|
int estimateThreshold(PointerPointer extraPointers, Pointer x, LongPointer xShapeInfo, int N, float threshold);
|
||||||
|
@ -1096,6 +1116,8 @@ public interface NativeOps {
|
||||||
|
|
||||||
//void fillUtf8String(PointerPointer extraPointers, String[] string, int numStrings, Pointer buffer);
|
//void fillUtf8String(PointerPointer extraPointers, String[] string, int numStrings, Pointer buffer);
|
||||||
Pointer createUtf8String(PointerPointer extraPointers, String string, int length);
|
Pointer createUtf8String(PointerPointer extraPointers, String string, int length);
|
||||||
|
long getUtf8StringLength(PointerPointer extraPointers, Pointer ptr);
|
||||||
|
BytePointer getUtf8StringBuffer(PointerPointer extraPointers, Pointer ptr);
|
||||||
void deleteUtf8String(PointerPointer extraPointers, Pointer ptr);
|
void deleteUtf8String(PointerPointer extraPointers, Pointer ptr);
|
||||||
|
|
||||||
|
|
||||||
|
@ -1116,11 +1138,37 @@ public interface NativeOps {
|
||||||
*/
|
*/
|
||||||
int dataTypeFromNpyHeader(Pointer numpyHeader);
|
int dataTypeFromNpyHeader(Pointer numpyHeader);
|
||||||
|
|
||||||
Pointer shapeBuffer(int rank, @Cast("Nd4jLong *") LongPointer shape, @Cast("Nd4jLong *") LongPointer strides, int dtype, char order, long ews, boolean empty);
|
OpaqueConstantDataBuffer shapeBuffer(int rank, LongPointer shape, LongPointer strides, int dtype, char order, long ews, boolean empty);
|
||||||
|
|
||||||
Pointer constantBufferDouble(int dtype, DoublePointer data, int length);
|
OpaqueConstantDataBuffer constantBufferDouble(int dtype, DoublePointer data, int length);
|
||||||
|
|
||||||
Pointer constantBufferLong(int dtype, @Cast("Nd4jLong *") LongPointer data, int length);
|
OpaqueConstantDataBuffer constantBufferLong(int dtype, LongPointer data, int length);
|
||||||
|
|
||||||
|
Pointer getConstantDataBufferPrimary(OpaqueConstantDataBuffer dbf);
|
||||||
|
Pointer getConstantDataBufferSpecial(OpaqueConstantDataBuffer dbf);
|
||||||
|
long getConstantDataBufferLength(OpaqueConstantDataBuffer dbf);
|
||||||
|
long getConstantDataBufferSizeOf(OpaqueConstantDataBuffer dbf);
|
||||||
|
|
||||||
|
void deleteShapeBuffer(OpaqueConstantDataBuffer state);
|
||||||
|
|
||||||
|
OpaqueContext createGraphContext(int nodeId);
|
||||||
|
OpaqueRandomGenerator getGraphContextRandomGenerator(OpaqueContext ptr);
|
||||||
|
void markGraphContextInplace(OpaqueContext ptr, boolean reallyInplace);
|
||||||
|
void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer);
|
||||||
|
void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);
|
||||||
|
void setGraphContextOutputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);
|
||||||
|
void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, int numberOfArguments);
|
||||||
|
void setGraphContextIArguments(OpaqueContext ptr, LongPointer arguments, int numberOfArguments);
|
||||||
|
void setGraphContextBArguments(OpaqueContext ptr, BooleanPointer arguments, int numberOfArguments);
|
||||||
|
void deleteGraphContext(OpaqueContext ptr);
|
||||||
|
|
||||||
|
OpaqueRandomGenerator createRandomGenerator(long rootSeed, long nodeSeed);
|
||||||
|
long getRandomGeneratorRootState(OpaqueRandomGenerator ptr);
|
||||||
|
long getRandomGeneratorNodeState(OpaqueRandomGenerator ptr);
|
||||||
|
void setRandomGeneratorStates(OpaqueRandomGenerator ptr, @Cast("Nd4jLong") long rootSeed/*=0*/, @Cast("Nd4jLong") long nodeSeed/*=0*/);
|
||||||
|
int getRandomGeneratorRelativeInt(OpaqueRandomGenerator ptr, @Cast("Nd4jLong") long index);
|
||||||
|
long getRandomGeneratorRelativeLong(OpaqueRandomGenerator ptr, @Cast("Nd4jLong") long index);
|
||||||
|
void deleteRandomGenerator(OpaqueRandomGenerator ptr);
|
||||||
|
|
||||||
String runLightBenchmarkSuit(boolean printOut);
|
String runLightBenchmarkSuit(boolean printOut);
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.nd4j.nativeblas;
|
||||||
|
|
||||||
|
import org.bytedeco.javacpp.Pointer;
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @author saudet
|
||||||
|
*/
|
||||||
|
public class OpaqueConstantDataBuffer extends Pointer {
|
||||||
|
public OpaqueConstantDataBuffer(Pointer p) { super(p); }
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.nd4j.nativeblas;
|
||||||
|
|
||||||
|
import org.bytedeco.javacpp.Pointer;
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @author saudet
|
||||||
|
*/
|
||||||
|
public class OpaqueContext extends Pointer {
|
||||||
|
public OpaqueContext(Pointer p) { super(p); }
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.nd4j.nativeblas;
|
||||||
|
|
||||||
|
import org.bytedeco.javacpp.Pointer;
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @author saudet
|
||||||
|
*/
|
||||||
|
public class OpaqueRandomGenerator extends Pointer {
|
||||||
|
public OpaqueRandomGenerator(Pointer p) { super(p); }
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.nd4j.nativeblas;
|
||||||
|
|
||||||
|
import org.bytedeco.javacpp.Pointer;
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @author saudet
|
||||||
|
*/
|
||||||
|
public class OpaqueResultWrapper extends Pointer {
|
||||||
|
public OpaqueResultWrapper(Pointer p) { super(p); }
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.nd4j.nativeblas;
|
||||||
|
|
||||||
|
import org.bytedeco.javacpp.Pointer;
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @author saudet
|
||||||
|
*/
|
||||||
|
public class OpaqueShapeList extends Pointer {
|
||||||
|
public OpaqueShapeList(Pointer p) { super(p); }
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.nd4j.nativeblas;
|
||||||
|
|
||||||
|
import org.bytedeco.javacpp.Pointer;
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @author saudet
|
||||||
|
*/
|
||||||
|
public class OpaqueTadPack extends Pointer {
|
||||||
|
public OpaqueTadPack(Pointer p) { super(p); }
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.nd4j.nativeblas;
|
||||||
|
|
||||||
|
import org.bytedeco.javacpp.Pointer;
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @author saudet
|
||||||
|
*/
|
||||||
|
public class OpaqueVariable extends Pointer {
|
||||||
|
public OpaqueVariable(Pointer p) { super(p); }
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.nd4j.nativeblas;
|
||||||
|
|
||||||
|
import org.bytedeco.javacpp.Pointer;
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @author saudet
|
||||||
|
*/
|
||||||
|
public class OpaqueVariableSet extends Pointer {
|
||||||
|
public OpaqueVariableSet(Pointer p) { super(p); }
|
||||||
|
}
|
|
@ -72,6 +72,11 @@ import org.nd4j.nativeblas.LongPointerWrapper;
|
||||||
import org.nd4j.nativeblas.NativeOps;
|
import org.nd4j.nativeblas.NativeOps;
|
||||||
import org.nd4j.nativeblas.NativeOpsHolder;
|
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||||
import org.nd4j.nativeblas.Nd4jCuda;
|
import org.nd4j.nativeblas.Nd4jCuda;
|
||||||
|
import org.nd4j.nativeblas.OpaqueConstantDataBuffer;
|
||||||
|
import org.nd4j.nativeblas.OpaqueShapeList;
|
||||||
|
import org.nd4j.nativeblas.OpaqueTadPack;
|
||||||
|
import org.nd4j.nativeblas.OpaqueVariable;
|
||||||
|
import org.nd4j.nativeblas.OpaqueVariableSet;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
|
@ -2208,13 +2213,13 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
||||||
for (val t: op.tArgs())
|
for (val t: op.tArgs())
|
||||||
tArgs.put(cnt++, (float) t);
|
tArgs.put(cnt++, (float) t);
|
||||||
|
|
||||||
val ptrptr = (Nd4jCuda.ShapeList) nativeOps.calculateOutputShapes2(null, hash, inputBuffers, inputShapes, op.inputArguments().length, tArgs, op.tArgs().length, iArgs, op.iArgs().length, bArgs, op.numBArguments());
|
OpaqueShapeList ptrptr = nativeOps.calculateOutputShapes2(null, hash, inputBuffers, inputShapes, op.inputArguments().length, tArgs, op.tArgs().length, iArgs, op.iArgs().length, bArgs, op.numBArguments());
|
||||||
|
|
||||||
if (ptrptr == null)
|
if (ptrptr == null)
|
||||||
throw new RuntimeException();
|
throw new RuntimeException();
|
||||||
|
|
||||||
for (int e = 0; e < ptrptr.size(); e++ )
|
for (int e = 0; e < nativeOps.getShapeListSize(ptrptr); e++ )
|
||||||
result.add(getShapeFromPointer(new PagedPointer(ptrptr.at(e)).asLongPointer()));
|
result.add(getShapeFromPointer(new PagedPointer(nativeOps.getShape(ptrptr, e)).asLongPointer()));
|
||||||
|
|
||||||
nativeOps.deleteShapeList(ptrptr);
|
nativeOps.deleteShapeList(ptrptr);
|
||||||
|
|
||||||
|
@ -2251,28 +2256,32 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
||||||
|
|
||||||
val ctx = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();
|
val ctx = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();
|
||||||
|
|
||||||
val context = (CudaOpContext) buildContext();
|
val name = op.opName();
|
||||||
context.markInplace(op.isInplaceCall());
|
try (val context = (CudaOpContext) buildContext()) {
|
||||||
|
context.markInplace(op.isInplaceCall());
|
||||||
|
|
||||||
// transferring rng state
|
// transferring rng state
|
||||||
context.setRngStates(Nd4j.getRandom().rootState(), Nd4j.getRandom().nodeState());
|
context.setRngStates(Nd4j.getRandom().rootState(), Nd4j.getRandom().nodeState());
|
||||||
|
|
||||||
//transferring input/output arrays
|
//transferring input/output arrays
|
||||||
context.setInputArrays(op.inputArguments());
|
context.setInputArrays(op.inputArguments());
|
||||||
context.setOutputArrays(op.outputArguments());
|
context.setOutputArrays(op.outputArguments());
|
||||||
|
|
||||||
// transferring static args
|
// transferring static args
|
||||||
context.setBArguments(op.bArgs());
|
context.setBArguments(op.bArgs());
|
||||||
context.setIArguments(op.iArgs());
|
context.setIArguments(op.iArgs());
|
||||||
context.setTArguments(op.tArgs());
|
context.setTArguments(op.tArgs());
|
||||||
|
|
||||||
val result = exec(op, context);
|
val result = exec(op, context);
|
||||||
val states = context.getRngStates();
|
val states = context.getRngStates();
|
||||||
|
|
||||||
// pulling states back
|
// pulling states back
|
||||||
Nd4j.getRandom().setStates(states.getFirst(), states.getSecond());
|
Nd4j.getRandom().setStates(states.getFirst(), states.getSecond());
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new RuntimeException("Op [" + name + "] execution failed", e);
|
||||||
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
long st = profilingConfigurableHookIn(op);
|
long st = profilingConfigurableHookIn(op);
|
||||||
|
@ -2418,19 +2427,19 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
||||||
|
|
||||||
val newMap = new LinkedHashMap<String, INDArray>();
|
val newMap = new LinkedHashMap<String, INDArray>();
|
||||||
|
|
||||||
val result = (Nd4jCuda.VariablesSet) nativeOps.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size());
|
OpaqueVariableSet result = nativeOps.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size());
|
||||||
|
|
||||||
val status = OpStatus.byNumber(result.status());
|
OpStatus status = OpStatus.byNumber(nativeOps.getVariableSetStatus(result));
|
||||||
|
|
||||||
if (status != OpStatus.ND4J_STATUS_OK)
|
if (status != OpStatus.ND4J_STATUS_OK)
|
||||||
throw new ND4JIllegalStateException("Op execution failed: " + status);
|
throw new ND4JIllegalStateException("Op execution failed: " + status);
|
||||||
|
|
||||||
for (int e = 0; e < result.size(); e++) {
|
for (int e = 0; e < nativeOps.getVariableSetSize(result); e++) {
|
||||||
val var = result.at(e);
|
OpaqueVariable var = nativeOps.getVariable(result, e);
|
||||||
val nodeId = var.id();
|
int nodeId = nativeOps.getVariableId(var);
|
||||||
val index = var.index();
|
int index = nativeOps.getVariableIndex(var);
|
||||||
val shapeInfo = var.getNDArray().shapeInfo();
|
LongPointer shapeInfo = nativeOps.getVariableShape(var);
|
||||||
val buffer = var.getNDArray().buffer();
|
Pointer buffer = nativeOps.getVariableBuffer(var);
|
||||||
|
|
||||||
val rank = (int) shapeInfo.get(0);
|
val rank = (int) shapeInfo.get(0);
|
||||||
val jshape = new long[rank * 2 + 4];
|
val jshape = new long[rank * 2 + 4];
|
||||||
|
@ -2446,7 +2455,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
||||||
Pointer.memcpy(AtomicAllocator.getInstance().getHostPointer(array), buffer, ArrayUtil.prod(shapeOf) * Nd4j.sizeOfDataType());
|
Pointer.memcpy(AtomicAllocator.getInstance().getHostPointer(array), buffer, ArrayUtil.prod(shapeOf) * Nd4j.sizeOfDataType());
|
||||||
AtomicAllocator.getInstance().getAllocationPoint(array).tickHostWrite();
|
AtomicAllocator.getInstance().getAllocationPoint(array).tickHostWrite();
|
||||||
|
|
||||||
val nodeName = var.getName().getString();
|
String nodeName = nativeOps.getVariableName(var);
|
||||||
newMap.put(nodeName, array);
|
newMap.put(nodeName, array);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2584,9 +2593,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public DataBuffer createShapeInfo(long[] shape, long[] stride, long elementWiseStride, char order, DataType dtype, boolean empty) {
|
public DataBuffer createShapeInfo(long[] shape, long[] stride, long elementWiseStride, char order, DataType dtype, boolean empty) {
|
||||||
val dbf = (Nd4jCuda.ConstantDataBuffer) nativeOps.shapeBuffer(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, empty);
|
OpaqueConstantDataBuffer dbf = nativeOps.shapeBuffer(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, empty);
|
||||||
|
|
||||||
val result = new CudaLongDataBuffer(dbf.primary(), dbf.special(), Shape.shapeInfoLength(shape.length));
|
val result = new CudaLongDataBuffer(nativeOps.getConstantDataBufferPrimary(dbf), nativeOps.getConstantDataBufferSpecial(dbf), Shape.shapeInfoLength(shape.length));
|
||||||
|
|
||||||
nativeOps.deleteShapeBuffer(dbf);
|
nativeOps.deleteShapeBuffer(dbf);
|
||||||
|
|
||||||
|
@ -2595,10 +2604,10 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public TadPack tadShapeInfoAndOffsets(INDArray array, int[] dimension) {
|
public TadPack tadShapeInfoAndOffsets(INDArray array, int[] dimension) {
|
||||||
val pack = (Nd4jCuda.TadPack) nativeOps.tadOnlyShapeInfo((LongPointer) array.shapeInfoDataBuffer().addressPointer(), new IntPointer(dimension), dimension.length);
|
OpaqueTadPack pack = nativeOps.tadOnlyShapeInfo((LongPointer) array.shapeInfoDataBuffer().addressPointer(), new IntPointer(dimension), dimension.length);
|
||||||
|
|
||||||
val tadShape = new CudaLongDataBuffer(pack.primaryShapeInfo(), pack.specialShapeInfo(), pack.shapeInfoLength());
|
val tadShape = new CudaLongDataBuffer(nativeOps.getPrimaryShapeInfo(pack), nativeOps.getSpecialShapeInfo(pack), nativeOps.getShapeInfoLength(pack));
|
||||||
val tadOffsets = new CudaLongDataBuffer(pack.primaryOffsets(), pack.specialOffsets(), pack.numberOfTads());
|
val tadOffsets = new CudaLongDataBuffer(nativeOps.getPrimaryOffsets(pack), nativeOps.getSpecialOffsets(pack), nativeOps.getNumberOfTads(pack));
|
||||||
|
|
||||||
nativeOps.deleteTadPack(pack);
|
nativeOps.deleteTadPack(pack);
|
||||||
|
|
||||||
|
@ -2607,9 +2616,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public DataBuffer createConstantBuffer(long[] values, DataType desiredType) {
|
public DataBuffer createConstantBuffer(long[] values, DataType desiredType) {
|
||||||
val dbf = (Nd4jCuda.ConstantDataBuffer) nativeOps.constantBufferLong(desiredType.toInt(), new LongPointer(values), values.length);
|
OpaqueConstantDataBuffer dbf = nativeOps.constantBufferLong(desiredType.toInt(), new LongPointer(values), values.length);
|
||||||
|
|
||||||
val buffer = Nd4j.createBuffer(dbf.primary(), dbf.special(), values.length, desiredType);
|
val buffer = Nd4j.createBuffer(nativeOps.getConstantDataBufferPrimary(dbf), nativeOps.getConstantDataBufferSpecial(dbf), values.length, desiredType);
|
||||||
buffer.setConstant(true);
|
buffer.setConstant(true);
|
||||||
|
|
||||||
return buffer;
|
return buffer;
|
||||||
|
@ -2617,9 +2626,9 @@ public class CudaExecutioner extends DefaultOpExecutioner {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public DataBuffer createConstantBuffer(double[] values, DataType desiredType) {
|
public DataBuffer createConstantBuffer(double[] values, DataType desiredType) {
|
||||||
val dbf = (Nd4jCuda.ConstantDataBuffer) nativeOps.constantBufferDouble(desiredType.toInt(), new DoublePointer(values), values.length);
|
OpaqueConstantDataBuffer dbf = nativeOps.constantBufferDouble(desiredType.toInt(), new DoublePointer(values), values.length);
|
||||||
|
|
||||||
val buffer = Nd4j.createBuffer(dbf.primary(), dbf.special(), values.length, desiredType);
|
val buffer = Nd4j.createBuffer(nativeOps.getConstantDataBufferPrimary(dbf), nativeOps.getConstantDataBufferSpecial(dbf), values.length, desiredType);
|
||||||
buffer.setConstant(true);
|
buffer.setConstant(true);
|
||||||
|
|
||||||
return buffer;
|
return buffer;
|
||||||
|
|
|
@ -18,6 +18,9 @@ package org.nd4j.linalg.jcublas.ops.executioner;
|
||||||
|
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
|
import org.bytedeco.javacpp.BooleanPointer;
|
||||||
|
import org.bytedeco.javacpp.DoublePointer;
|
||||||
|
import org.bytedeco.javacpp.LongPointer;
|
||||||
import org.bytedeco.javacpp.Pointer;
|
import org.bytedeco.javacpp.Pointer;
|
||||||
import org.nd4j.jita.allocator.impl.AtomicAllocator;
|
import org.nd4j.jita.allocator.impl.AtomicAllocator;
|
||||||
import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t;
|
import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t;
|
||||||
|
@ -28,7 +31,10 @@ import org.nd4j.linalg.api.ops.OpContext;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.jcublas.context.CudaContext;
|
import org.nd4j.linalg.jcublas.context.CudaContext;
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
import org.nd4j.nativeblas.Nd4jCuda;
|
import org.nd4j.nativeblas.NativeOps;
|
||||||
|
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||||
|
import org.nd4j.nativeblas.OpaqueContext;
|
||||||
|
import org.nd4j.nativeblas.OpaqueRandomGenerator;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* CUDA wrapper for op Context
|
* CUDA wrapper for op Context
|
||||||
|
@ -36,34 +42,41 @@ import org.nd4j.nativeblas.Nd4jCuda;
|
||||||
*/
|
*/
|
||||||
public class CudaOpContext extends BaseOpContext implements OpContext {
|
public class CudaOpContext extends BaseOpContext implements OpContext {
|
||||||
// we might want to have configurable
|
// we might want to have configurable
|
||||||
private Nd4jCuda.Context context = new Nd4jCuda.Context(1);
|
private NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
|
||||||
|
private OpaqueContext context = nativeOps.createGraphContext(1);
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void close() {
|
||||||
|
nativeOps.deleteGraphContext(context);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void setIArguments(long... arguments) {
|
public void setIArguments(long... arguments) {
|
||||||
super.setIArguments(arguments);
|
super.setIArguments(arguments);
|
||||||
context.setIArguments(arguments, arguments.length);
|
nativeOps.setGraphContextIArguments(context, new LongPointer(arguments), arguments.length);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void setBArguments(boolean... arguments) {
|
public void setBArguments(boolean... arguments) {
|
||||||
super.setBArguments(arguments);
|
super.setBArguments(arguments);
|
||||||
context.setBArguments(arguments, arguments.length);
|
nativeOps.setGraphContextBArguments(context, new BooleanPointer(arguments), arguments.length);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void setTArguments(double... arguments) {
|
public void setTArguments(double... arguments) {
|
||||||
super.setTArguments(arguments);
|
super.setTArguments(arguments);
|
||||||
context.setTArguments(arguments, arguments.length);
|
nativeOps.setGraphContextTArguments(context, new DoublePointer(arguments), arguments.length);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void setRngStates(long rootState, long nodeState) {
|
public void setRngStates(long rootState, long nodeState) {
|
||||||
context.randomGenerator().setStates(rootState, nodeState);
|
nativeOps.setRandomGeneratorStates(nativeOps.getGraphContextRandomGenerator(context), rootState, nodeState);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Pair<Long, Long> getRngStates() {
|
public Pair<Long, Long> getRngStates() {
|
||||||
return Pair.makePair(context.randomGenerator().rootState(), context.randomGenerator().nodeState());
|
OpaqueRandomGenerator g = nativeOps.getGraphContextRandomGenerator(context);
|
||||||
|
return Pair.makePair(nativeOps.getRandomGeneratorRootState(g), nativeOps.getRandomGeneratorNodeState(g));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -72,7 +85,7 @@ public class CudaOpContext extends BaseOpContext implements OpContext {
|
||||||
Nd4j.getAffinityManager().ensureLocation(array, AffinityManager.Location.EVERYWHERE);
|
Nd4j.getAffinityManager().ensureLocation(array, AffinityManager.Location.EVERYWHERE);
|
||||||
|
|
||||||
val ctx = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();
|
val ctx = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();
|
||||||
context.setInputArray(index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), array.isEmpty() ? null : AtomicAllocator.getInstance().getPointer(array, ctx), AtomicAllocator.getInstance().getPointer(array.shapeInfoDataBuffer()));
|
nativeOps.setGraphContextInputArray(context, index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), array.isEmpty() ? null : AtomicAllocator.getInstance().getPointer(array, ctx), AtomicAllocator.getInstance().getPointer(array.shapeInfoDataBuffer()));
|
||||||
|
|
||||||
super.setInputArray(index, array);
|
super.setInputArray(index, array);
|
||||||
}
|
}
|
||||||
|
@ -82,7 +95,7 @@ public class CudaOpContext extends BaseOpContext implements OpContext {
|
||||||
Nd4j.getAffinityManager().ensureLocation(array, AffinityManager.Location.EVERYWHERE);
|
Nd4j.getAffinityManager().ensureLocation(array, AffinityManager.Location.EVERYWHERE);
|
||||||
|
|
||||||
val ctx = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();
|
val ctx = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();
|
||||||
context.setOutputArray(index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), array.isEmpty() ? null : AtomicAllocator.getInstance().getPointer(array, ctx), AtomicAllocator.getInstance().getPointer(array.shapeInfoDataBuffer()));
|
nativeOps.setGraphContextOutputArray(context, index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), array.isEmpty() ? null : AtomicAllocator.getInstance().getPointer(array, ctx), AtomicAllocator.getInstance().getPointer(array.shapeInfoDataBuffer()));
|
||||||
|
|
||||||
super.setOutputArray(index, array);
|
super.setOutputArray(index, array);
|
||||||
}
|
}
|
||||||
|
@ -113,11 +126,11 @@ public class CudaOpContext extends BaseOpContext implements OpContext {
|
||||||
|
|
||||||
|
|
||||||
public void setCudaStream(cudaStream_t stream, Pointer reductionPointer, Pointer allocationPointer) {
|
public void setCudaStream(cudaStream_t stream, Pointer reductionPointer, Pointer allocationPointer) {
|
||||||
context.setCudaContext(stream, reductionPointer, allocationPointer);
|
nativeOps.setGraphContextCudaContext(context, stream, reductionPointer, allocationPointer);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void markInplace(boolean reallyInplace) {
|
public void markInplace(boolean reallyInplace) {
|
||||||
context.markInplace(reallyInplace);
|
nativeOps.markGraphContextInplace(context, reallyInplace);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,7 +21,9 @@ import org.bytedeco.javacpp.PointerPointer;
|
||||||
import org.nd4j.jita.allocator.impl.AtomicAllocator;
|
import org.nd4j.jita.allocator.impl.AtomicAllocator;
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||||
import org.nd4j.linalg.jcublas.context.CudaContext;
|
import org.nd4j.linalg.jcublas.context.CudaContext;
|
||||||
import org.nd4j.nativeblas.Nd4jCuda;
|
import org.nd4j.nativeblas.NativeOps;
|
||||||
|
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||||
|
import org.nd4j.nativeblas.OpaqueRandomGenerator;
|
||||||
import org.nd4j.rng.NativeRandom;
|
import org.nd4j.rng.NativeRandom;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
@ -33,7 +35,7 @@ import java.util.List;
|
||||||
*/
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class CudaNativeRandom extends NativeRandom {
|
public class CudaNativeRandom extends NativeRandom {
|
||||||
|
private NativeOps nativeOps;
|
||||||
protected List<DataBuffer> stateBuffers;
|
protected List<DataBuffer> stateBuffers;
|
||||||
|
|
||||||
public CudaNativeRandom() {
|
public CudaNativeRandom() {
|
||||||
|
@ -50,10 +52,16 @@ public class CudaNativeRandom extends NativeRandom {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void init() {
|
public void init() {
|
||||||
statePointer = new Nd4jCuda.RandomGenerator(seed, seed ^ 0xdeadbeef);
|
nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
|
||||||
|
statePointer = nativeOps.createRandomGenerator(this.seed, this.seed ^ 0xdeadbeef);
|
||||||
setSeed(seed);
|
setSeed(seed);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void close() {
|
||||||
|
nativeOps.deleteRandomGenerator((OpaqueRandomGenerator)statePointer);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public PointerPointer getExtraPointers() {
|
public PointerPointer getExtraPointers() {
|
||||||
return null;
|
return null;
|
||||||
|
@ -63,7 +71,7 @@ public class CudaNativeRandom extends NativeRandom {
|
||||||
public void setSeed(long seed) {
|
public void setSeed(long seed) {
|
||||||
this.seed = seed;
|
this.seed = seed;
|
||||||
this.currentPosition.set(0);
|
this.currentPosition.set(0);
|
||||||
((Nd4jCuda.RandomGenerator) statePointer).setStates(seed, seed ^ 0xdeadbeef);
|
nativeOps.setRandomGeneratorStates((OpaqueRandomGenerator)statePointer, seed, seed ^ 0xdeadbeef);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -73,24 +81,24 @@ public class CudaNativeRandom extends NativeRandom {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int nextInt() {
|
public int nextInt() {
|
||||||
return ((Nd4jCuda.RandomGenerator) statePointer).relativeInt(currentPosition.getAndIncrement());
|
return nativeOps.getRandomGeneratorRelativeInt((OpaqueRandomGenerator)statePointer, currentPosition.getAndIncrement());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long nextLong() {
|
public long nextLong() {
|
||||||
return ((Nd4jCuda.RandomGenerator) statePointer).relativeLong(currentPosition.getAndIncrement());
|
return nativeOps.getRandomGeneratorRelativeLong((OpaqueRandomGenerator)statePointer, currentPosition.getAndIncrement());
|
||||||
}
|
}
|
||||||
|
|
||||||
public long rootState() {
|
public long rootState() {
|
||||||
return ((Nd4jCuda.RandomGenerator) statePointer).rootState();
|
return nativeOps.getRandomGeneratorRootState((OpaqueRandomGenerator)statePointer);
|
||||||
}
|
}
|
||||||
|
|
||||||
public long nodeState() {
|
public long nodeState() {
|
||||||
return ((Nd4jCuda.RandomGenerator) statePointer).nodeState();
|
return nativeOps.getRandomGeneratorNodeState((OpaqueRandomGenerator)statePointer);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void setStates(long rootState, long nodeState) {
|
public void setStates(long rootState, long nodeState) {
|
||||||
((Nd4jCuda.RandomGenerator) statePointer).setStates(rootState, nodeState);
|
nativeOps.setRandomGeneratorStates((OpaqueRandomGenerator)statePointer, rootState, nodeState);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2097,16 +2097,25 @@ public native void setGridLimit(int gridSize);
|
||||||
* @param targetBuffer
|
* @param targetBuffer
|
||||||
* @param offsetsBuffer
|
* @param offsetsBuffer
|
||||||
*/
|
*/
|
||||||
public native TadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") LongPointer xShapeInfo,
|
public native OpaqueTadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") LongPointer xShapeInfo,
|
||||||
IntPointer dimension,
|
IntPointer dimension,
|
||||||
int dimensionLength);
|
int dimensionLength);
|
||||||
public native TadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") LongBuffer xShapeInfo,
|
public native OpaqueTadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") LongBuffer xShapeInfo,
|
||||||
IntBuffer dimension,
|
IntBuffer dimension,
|
||||||
int dimensionLength);
|
int dimensionLength);
|
||||||
public native TadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") long[] xShapeInfo,
|
public native OpaqueTadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") long[] xShapeInfo,
|
||||||
int[] dimension,
|
int[] dimension,
|
||||||
int dimensionLength);
|
int dimensionLength);
|
||||||
|
|
||||||
|
public native @Cast("Nd4jLong*") LongPointer getPrimaryShapeInfo(OpaqueTadPack pack);
|
||||||
|
public native @Cast("Nd4jLong*") LongPointer getPrimaryOffsets(OpaqueTadPack pack);
|
||||||
|
public native @Cast("Nd4jLong*") LongPointer getSpecialShapeInfo(OpaqueTadPack pack);
|
||||||
|
public native @Cast("Nd4jLong*") LongPointer getSpecialOffsets(OpaqueTadPack pack);
|
||||||
|
public native @Cast("Nd4jLong") long getNumberOfTads(OpaqueTadPack pack);
|
||||||
|
public native int getShapeInfoLength(OpaqueTadPack pack);
|
||||||
|
|
||||||
|
public native void deleteTadPack(OpaqueTadPack ptr);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* PullRow special op
|
* PullRow special op
|
||||||
*/
|
*/
|
||||||
|
@ -2943,10 +2952,11 @@ public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers
|
||||||
public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongBuffer ptrMap, @Cast("Nd4jLong") long length);
|
public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongBuffer ptrMap, @Cast("Nd4jLong") long length);
|
||||||
public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") long[] ptrMap, @Cast("Nd4jLong") long length);
|
public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") long[] ptrMap, @Cast("Nd4jLong") long length);
|
||||||
|
|
||||||
|
|
||||||
// flatbuffers execution
|
// flatbuffers execution
|
||||||
public native ResultWrapper executeFlatGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer flatBufferPointer);
|
public native OpaqueResultWrapper executeFlatGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer flatBufferPointer);
|
||||||
|
|
||||||
|
public native @Cast("Nd4jLong") long getResultWrapperSize(OpaqueResultWrapper ptr);
|
||||||
|
public native @Cast("Nd4jPointer") Pointer getResultWrapperPointer(OpaqueResultWrapper ptr);
|
||||||
|
|
||||||
public native @Cast("char*") String getAllCustomOps();
|
public native @Cast("char*") String getAllCustomOps();
|
||||||
|
|
||||||
|
@ -2961,23 +2971,35 @@ public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointer
|
||||||
public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, @Cast("bool") boolean isInplace);
|
public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, @Cast("bool") boolean isInplace);
|
||||||
public native int execCustomOp2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer") Pointer opContext);
|
public native int execCustomOp2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer") Pointer opContext);
|
||||||
|
|
||||||
public native ShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs);
|
public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs);
|
||||||
public native ShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs);
|
public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs);
|
||||||
public native ShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs);
|
public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs);
|
||||||
public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
|
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
|
||||||
public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
|
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
|
||||||
public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
|
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
|
||||||
public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
|
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
|
||||||
public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
|
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
|
||||||
public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
|
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
|
||||||
|
|
||||||
|
public native @Cast("Nd4jLong") long getShapeListSize(OpaqueShapeList list);
|
||||||
|
public native @Cast("Nd4jLong*") LongPointer getShape(OpaqueShapeList list, @Cast("Nd4jLong") long i);
|
||||||
|
|
||||||
public native void deleteShapeList(@Cast("Nd4jPointer") Pointer shapeList);
|
public native void deleteShapeList(@Cast("Nd4jPointer") Pointer shapeList);
|
||||||
|
|
||||||
public native int registerGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer") Pointer flatBufferPointer);
|
public native int registerGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer") Pointer flatBufferPointer);
|
||||||
|
|
||||||
public native VariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntPointer inputIndices, int numInputs);
|
public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntPointer inputIndices, int numInputs);
|
||||||
public native VariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntBuffer inputIndices, int numInputs);
|
public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntBuffer inputIndices, int numInputs);
|
||||||
public native VariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int[] inputIndices, int numInputs);
|
public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int[] inputIndices, int numInputs);
|
||||||
|
|
||||||
|
public native @Cast("Nd4jLong") long getVariableSetSize(OpaqueVariableSet set);
|
||||||
|
public native @Cast("Nd4jStatus") int getVariableSetStatus(OpaqueVariableSet set);
|
||||||
|
public native OpaqueVariable getVariable(OpaqueVariableSet set, @Cast("Nd4jLong") long i);
|
||||||
|
public native int getVariableId(OpaqueVariable variable);
|
||||||
|
public native int getVariableIndex(OpaqueVariable variable);
|
||||||
|
public native @Cast("char*") String getVariableName(OpaqueVariable variable);
|
||||||
|
public native @Cast("Nd4jLong*") LongPointer getVariableShape(OpaqueVariable variable);
|
||||||
|
public native Pointer getVariableBuffer(OpaqueVariable variable);
|
||||||
|
|
||||||
public native int unregisterGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId);
|
public native int unregisterGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId);
|
||||||
|
|
||||||
|
@ -2986,7 +3008,7 @@ public native void deleteIntArray(@Cast("Nd4jPointer") Pointer pointer);
|
||||||
public native void deleteLongArray(@Cast("Nd4jPointer") Pointer pointer);
|
public native void deleteLongArray(@Cast("Nd4jPointer") Pointer pointer);
|
||||||
public native void deletePointerArray(@Cast("Nd4jPointer") Pointer pointer);
|
public native void deletePointerArray(@Cast("Nd4jPointer") Pointer pointer);
|
||||||
|
|
||||||
public native void deleteVariablesSet(@Cast("Nd4jPointer") Pointer pointer);
|
public native void deleteVariablesSet(@ByVal OpaqueVariableSet pointer);
|
||||||
|
|
||||||
// GraphState creation
|
// GraphState creation
|
||||||
public native @Cast("Nd4jPointer") Pointer getGraphState(@Cast("Nd4jLong") long id);
|
public native @Cast("Nd4jPointer") Pointer getGraphState(@Cast("Nd4jLong") long id);
|
||||||
|
@ -3007,6 +3029,8 @@ public native @Cast("Nd4jStatus") int execCustomOpWithScope(@Cast("Nd4jPointer*"
|
||||||
//void fillUtf8String(Nd4jPointer *extraPointers, const char **string, int numStrings, Nd4jPointer buffer);
|
//void fillUtf8String(Nd4jPointer *extraPointers, const char **string, int numStrings, Nd4jPointer buffer);
|
||||||
public native @Cast("Nd4jPointer") Pointer createUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") String string, int length);
|
public native @Cast("Nd4jPointer") Pointer createUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") String string, int length);
|
||||||
public native @Cast("Nd4jPointer") Pointer createUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") BytePointer string, int length);
|
public native @Cast("Nd4jPointer") Pointer createUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") BytePointer string, int length);
|
||||||
|
public native @Cast("Nd4jLong") long getUtf8StringLength(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer ptr);
|
||||||
|
public native @Cast("char*") BytePointer getUtf8StringBuffer(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer ptr);
|
||||||
public native void deleteUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer ptr);
|
public native void deleteUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer ptr);
|
||||||
|
|
||||||
public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opCode, int numOfSubArrs,
|
public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opCode, int numOfSubArrs,
|
||||||
|
@ -3032,19 +3056,50 @@ public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointe
|
||||||
public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer buffer, @Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jPointer") Pointer specialBuffer, @Cast("Nd4jLong*") LongBuffer specialShapeInfo, @Cast("Nd4jPointer") Pointer debugInfo);
|
public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer buffer, @Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jPointer") Pointer specialBuffer, @Cast("Nd4jLong*") LongBuffer specialShapeInfo, @Cast("Nd4jPointer") Pointer debugInfo);
|
||||||
public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer buffer, @Cast("Nd4jLong*") long[] shapeInfo, @Cast("Nd4jPointer") Pointer specialBuffer, @Cast("Nd4jLong*") long[] specialShapeInfo, @Cast("Nd4jPointer") Pointer debugInfo);
|
public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer buffer, @Cast("Nd4jLong*") long[] shapeInfo, @Cast("Nd4jPointer") Pointer specialBuffer, @Cast("Nd4jLong*") long[] specialShapeInfo, @Cast("Nd4jPointer") Pointer debugInfo);
|
||||||
|
|
||||||
|
public native OpaqueConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty);
|
||||||
|
public native OpaqueConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty);
|
||||||
|
public native OpaqueConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty);
|
||||||
|
|
||||||
public native ConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty);
|
public native OpaqueConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") LongPointer data, int length);
|
||||||
public native ConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty);
|
public native OpaqueConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") LongBuffer data, int length);
|
||||||
public native ConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty);
|
public native OpaqueConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") long[] data, int length);
|
||||||
|
public native OpaqueConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, DoublePointer data, int length);
|
||||||
|
public native OpaqueConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, DoubleBuffer data, int length);
|
||||||
|
public native OpaqueConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, double[] data, int length);
|
||||||
|
public native OpaqueConstantDataBuffer constantBuffer(@Cast("nd4j::DataType") int dtype, ConstantDescriptor descriptor);
|
||||||
|
|
||||||
public native ConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") LongPointer data, int length);
|
public native @Cast("Nd4jPointer") Pointer getConstantDataBufferPrimary(OpaqueConstantDataBuffer dbf);
|
||||||
public native ConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") LongBuffer data, int length);
|
public native @Cast("Nd4jPointer") Pointer getConstantDataBufferSpecial(OpaqueConstantDataBuffer dbf);
|
||||||
public native ConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") long[] data, int length);
|
public native @Cast("Nd4jLong") long getConstantDataBufferLength(OpaqueConstantDataBuffer dbf);
|
||||||
public native ConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, DoublePointer data, int length);
|
public native @Cast("Nd4jLong") long getConstantDataBufferSizeOf(OpaqueConstantDataBuffer dbf);
|
||||||
public native ConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, DoubleBuffer data, int length);
|
|
||||||
public native ConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, double[] data, int length);
|
|
||||||
public native ConstantDataBuffer constantBuffer(@Cast("nd4j::DataType") int dtype, ConstantDescriptor descriptor);
|
|
||||||
|
|
||||||
|
public native void deleteShapeBuffer(OpaqueConstantDataBuffer ptr);
|
||||||
|
|
||||||
|
public native OpaqueContext createGraphContext(int nodeId);
|
||||||
|
public native OpaqueRandomGenerator getGraphContextRandomGenerator(OpaqueContext ptr);
|
||||||
|
public native void markGraphContextInplace(OpaqueContext ptr, @Cast("bool") boolean reallyInplace);
|
||||||
|
public native void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer);
|
||||||
|
public native void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);
|
||||||
|
public native void setGraphContextOutputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);
|
||||||
|
public native void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, int numberOfArguments);
|
||||||
|
public native void setGraphContextTArguments(OpaqueContext ptr, DoubleBuffer arguments, int numberOfArguments);
|
||||||
|
public native void setGraphContextTArguments(OpaqueContext ptr, double[] arguments, int numberOfArguments);
|
||||||
|
public native void setGraphContextIArguments(OpaqueContext ptr, @Cast("Nd4jLong*") LongPointer arguments, int numberOfArguments);
|
||||||
|
public native void setGraphContextIArguments(OpaqueContext ptr, @Cast("Nd4jLong*") LongBuffer arguments, int numberOfArguments);
|
||||||
|
public native void setGraphContextIArguments(OpaqueContext ptr, @Cast("Nd4jLong*") long[] arguments, int numberOfArguments);
|
||||||
|
public native void setGraphContextBArguments(OpaqueContext ptr, @Cast("bool*") BooleanPointer arguments, int numberOfArguments);
|
||||||
|
public native void setGraphContextBArguments(OpaqueContext ptr, @Cast("bool*") boolean[] arguments, int numberOfArguments);
|
||||||
|
public native void deleteGraphContext(OpaqueContext ptr);
|
||||||
|
|
||||||
|
public native OpaqueRandomGenerator createRandomGenerator(@Cast("Nd4jLong") long rootSeed/*=0*/, @Cast("Nd4jLong") long nodeSeed/*=0*/);
|
||||||
|
public native OpaqueRandomGenerator createRandomGenerator();
|
||||||
|
public native @Cast("Nd4jLong") long getRandomGeneratorRootState(OpaqueRandomGenerator ptr);
|
||||||
|
public native @Cast("Nd4jLong") long getRandomGeneratorNodeState(OpaqueRandomGenerator ptr);
|
||||||
|
public native void setRandomGeneratorStates(OpaqueRandomGenerator ptr, @Cast("Nd4jLong") long rootSeed/*=0*/, @Cast("Nd4jLong") long nodeSeed/*=0*/);
|
||||||
|
public native void setRandomGeneratorStates(OpaqueRandomGenerator ptr);
|
||||||
|
public native int getRandomGeneratorRelativeInt(OpaqueRandomGenerator ptr, @Cast("Nd4jLong") long index);
|
||||||
|
public native @Cast("Nd4jLong") long getRandomGeneratorRelativeLong(OpaqueRandomGenerator ptr, @Cast("Nd4jLong") long index);
|
||||||
|
public native void deleteRandomGenerator(OpaqueRandomGenerator ptr);
|
||||||
|
|
||||||
public native @Cast("char*") String runLightBenchmarkSuit(@Cast("bool") boolean printOut);
|
public native @Cast("char*") String runLightBenchmarkSuit(@Cast("bool") boolean printOut);
|
||||||
public native @Cast("char*") String runFullBenchmarkSuit(@Cast("bool") boolean printOut);
|
public native @Cast("char*") String runFullBenchmarkSuit(@Cast("bool") boolean printOut);
|
||||||
|
@ -3705,6 +3760,20 @@ public native @Cast("char*") String runFullBenchmarkSuit(@Cast("bool") boolean p
|
||||||
public NDArray(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("nd4j::DataType") int dtype) { super((Pointer)null); allocate(buffer, order, shape, dtype); }
|
public NDArray(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("nd4j::DataType") int dtype) { super((Pointer)null); allocate(buffer, order, shape, dtype); }
|
||||||
private native void allocate(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("nd4j::DataType") int dtype);
|
private native void allocate(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("nd4j::DataType") int dtype);
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method returns new array with the same shape & data type
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public native @ByVal NDArray like();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method returns new uninitialized array with the same shape & data type
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public native @ByVal NDArray ulike();
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* this constructor creates new NDArray with shape matching "other" array,
|
* this constructor creates new NDArray with shape matching "other" array,
|
||||||
* doesn't copy "other" elements into new array !!!
|
* doesn't copy "other" elements into new array !!!
|
||||||
|
|
|
@ -113,6 +113,14 @@ public class Nd4jCudaPresets implements InfoMapper {
|
||||||
infoMap.put(new Info("thread_local", "ND4J_EXPORT", "INLINEDEF", "CUBLASWINAPI", "FORCEINLINE",
|
infoMap.put(new Info("thread_local", "ND4J_EXPORT", "INLINEDEF", "CUBLASWINAPI", "FORCEINLINE",
|
||||||
"_CUDA_H", "_CUDA_D", "_CUDA_G", "_CUDA_HD", "LIBND4J_ALL_OPS", "NOT_EXCLUDED").cppTypes().annotations())
|
"_CUDA_H", "_CUDA_D", "_CUDA_G", "_CUDA_HD", "LIBND4J_ALL_OPS", "NOT_EXCLUDED").cppTypes().annotations())
|
||||||
.put(new Info("NativeOps.h").objectify())
|
.put(new Info("NativeOps.h").objectify())
|
||||||
|
.put(new Info("OpaqueTadPack").pointerTypes("OpaqueTadPack"))
|
||||||
|
.put(new Info("OpaqueResultWrapper").pointerTypes("OpaqueResultWrapper"))
|
||||||
|
.put(new Info("OpaqueShapeList").pointerTypes("OpaqueShapeList"))
|
||||||
|
.put(new Info("OpaqueVariableSet").pointerTypes("OpaqueVariableSet"))
|
||||||
|
.put(new Info("OpaqueVariable").pointerTypes("OpaqueVariable"))
|
||||||
|
.put(new Info("OpaqueConstantDataBuffer").pointerTypes("OpaqueConstantDataBuffer"))
|
||||||
|
.put(new Info("OpaqueContext").pointerTypes("OpaqueContext"))
|
||||||
|
.put(new Info("OpaqueRandomGenerator").pointerTypes("OpaqueRandomGenerator"))
|
||||||
.put(new Info("const char").valueTypes("byte").pointerTypes("@Cast(\"char*\") String",
|
.put(new Info("const char").valueTypes("byte").pointerTypes("@Cast(\"char*\") String",
|
||||||
"@Cast(\"char*\") BytePointer"))
|
"@Cast(\"char*\") BytePointer"))
|
||||||
.put(new Info("char").valueTypes("char").pointerTypes("@Cast(\"char*\") BytePointer",
|
.put(new Info("char").valueTypes("char").pointerTypes("@Cast(\"char*\") BytePointer",
|
||||||
|
|
|
@ -48,7 +48,6 @@ import org.nd4j.linalg.util.ArrayUtil;
|
||||||
import org.nd4j.nativeblas.BaseNativeNDArrayFactory;
|
import org.nd4j.nativeblas.BaseNativeNDArrayFactory;
|
||||||
import org.nd4j.nativeblas.LongPointerWrapper;
|
import org.nd4j.nativeblas.LongPointerWrapper;
|
||||||
import org.nd4j.nativeblas.NativeOpsHolder;
|
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||||
import org.nd4j.nativeblas.Nd4jCpu;
|
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
|
|
|
@ -17,12 +17,18 @@
|
||||||
package org.nd4j.linalg.cpu.nativecpu.ops;
|
package org.nd4j.linalg.cpu.nativecpu.ops;
|
||||||
|
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
|
import org.bytedeco.javacpp.BooleanPointer;
|
||||||
|
import org.bytedeco.javacpp.DoublePointer;
|
||||||
|
import org.bytedeco.javacpp.LongPointer;
|
||||||
import org.bytedeco.javacpp.Pointer;
|
import org.bytedeco.javacpp.Pointer;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.BaseOpContext;
|
import org.nd4j.linalg.api.ops.BaseOpContext;
|
||||||
import org.nd4j.linalg.api.ops.OpContext;
|
import org.nd4j.linalg.api.ops.OpContext;
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
import org.nd4j.nativeblas.Nd4jCpu;
|
import org.nd4j.nativeblas.NativeOps;
|
||||||
|
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||||
|
import org.nd4j.nativeblas.OpaqueContext;
|
||||||
|
import org.nd4j.nativeblas.OpaqueRandomGenerator;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
@ -33,46 +39,53 @@ import java.util.List;
|
||||||
*/
|
*/
|
||||||
public class CpuOpContext extends BaseOpContext implements OpContext {
|
public class CpuOpContext extends BaseOpContext implements OpContext {
|
||||||
// we might want to have configurable
|
// we might want to have configurable
|
||||||
private Nd4jCpu.Context context = new Nd4jCpu.Context(1);
|
private NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
|
||||||
|
private OpaqueContext context = nativeOps.createGraphContext(1);
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void close() {
|
||||||
|
nativeOps.deleteGraphContext(context);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void setIArguments(long... arguments) {
|
public void setIArguments(long... arguments) {
|
||||||
super.setIArguments(arguments);
|
super.setIArguments(arguments);
|
||||||
context.setIArguments(arguments, arguments.length);
|
nativeOps.setGraphContextIArguments(context, new LongPointer(arguments), arguments.length);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void setBArguments(boolean... arguments) {
|
public void setBArguments(boolean... arguments) {
|
||||||
super.setBArguments(arguments);
|
super.setBArguments(arguments);
|
||||||
context.setBArguments(arguments, arguments.length);
|
nativeOps.setGraphContextBArguments(context, new BooleanPointer(arguments), arguments.length);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void setTArguments(double... arguments) {
|
public void setTArguments(double... arguments) {
|
||||||
super.setTArguments(arguments);
|
super.setTArguments(arguments);
|
||||||
context.setTArguments(arguments, arguments.length);
|
nativeOps.setGraphContextTArguments(context, new DoublePointer(arguments), arguments.length);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void setRngStates(long rootState, long nodeState) {
|
public void setRngStates(long rootState, long nodeState) {
|
||||||
context.randomGenerator().setStates(rootState, nodeState);
|
nativeOps.setRandomGeneratorStates(nativeOps.getGraphContextRandomGenerator(context), rootState, nodeState);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Pair<Long, Long> getRngStates() {
|
public Pair<Long, Long> getRngStates() {
|
||||||
return Pair.makePair(context.randomGenerator().rootState(), context.randomGenerator().nodeState());
|
OpaqueRandomGenerator g = nativeOps.getGraphContextRandomGenerator(context);
|
||||||
|
return Pair.makePair(nativeOps.getRandomGeneratorRootState(g), nativeOps.getRandomGeneratorNodeState(g));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void setInputArray(int index, @NonNull INDArray array) {
|
public void setInputArray(int index, @NonNull INDArray array) {
|
||||||
context.setInputArray(index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), null, null);
|
nativeOps.setGraphContextInputArray(context, index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), null, null);
|
||||||
|
|
||||||
super.setInputArray(index, array);
|
super.setInputArray(index, array);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void setOutputArray(int index, @NonNull INDArray array) {
|
public void setOutputArray(int index, @NonNull INDArray array) {
|
||||||
context.setOutputArray(index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), null, null);
|
nativeOps.setGraphContextOutputArray(context, index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), null, null);
|
||||||
|
|
||||||
super.setOutputArray(index, array);
|
super.setOutputArray(index, array);
|
||||||
}
|
}
|
||||||
|
@ -84,6 +97,6 @@ public class CpuOpContext extends BaseOpContext implements OpContext {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void markInplace(boolean reallyInplace) {
|
public void markInplace(boolean reallyInplace) {
|
||||||
context.markInplace(reallyInplace);
|
nativeOps.markGraphContextInplace(context, reallyInplace);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -70,10 +70,14 @@ import org.nd4j.nativeblas.LongPointerWrapper;
|
||||||
import org.nd4j.nativeblas.NativeOps;
|
import org.nd4j.nativeblas.NativeOps;
|
||||||
import org.nd4j.nativeblas.NativeOpsHolder;
|
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||||
import org.nd4j.nativeblas.Nd4jCpu;
|
import org.nd4j.nativeblas.Nd4jCpu;
|
||||||
|
import org.nd4j.nativeblas.OpaqueConstantDataBuffer;
|
||||||
|
import org.nd4j.nativeblas.OpaqueShapeList;
|
||||||
|
import org.nd4j.nativeblas.OpaqueTadPack;
|
||||||
|
import org.nd4j.nativeblas.OpaqueVariable;
|
||||||
|
import org.nd4j.nativeblas.OpaqueVariableSet;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
* Native operation
|
* Native operation
|
||||||
|
@ -1641,23 +1645,22 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
||||||
}
|
}
|
||||||
|
|
||||||
val name = op.opName();
|
val name = op.opName();
|
||||||
val context = buildContext();
|
try (val context = buildContext()) {
|
||||||
|
|
||||||
context.markInplace(op.isInplaceCall());
|
context.markInplace(op.isInplaceCall());
|
||||||
|
|
||||||
// transferring rng state
|
// transferring rng state
|
||||||
context.setRngStates(Nd4j.getRandom().rootState(), Nd4j.getRandom().nodeState());
|
context.setRngStates(Nd4j.getRandom().rootState(), Nd4j.getRandom().nodeState());
|
||||||
|
|
||||||
//transferring input/output arrays
|
//transferring input/output arrays
|
||||||
context.setInputArrays(op.inputArguments());
|
context.setInputArrays(op.inputArguments());
|
||||||
context.setOutputArrays(op.outputArguments());
|
context.setOutputArrays(op.outputArguments());
|
||||||
|
|
||||||
// transferring static args
|
// transferring static args
|
||||||
context.setBArguments(op.bArgs());
|
context.setBArguments(op.bArgs());
|
||||||
context.setIArguments(op.iArgs());
|
context.setIArguments(op.iArgs());
|
||||||
context.setTArguments(op.tArgs());
|
context.setTArguments(op.tArgs());
|
||||||
|
|
||||||
try {
|
|
||||||
val result = exec(op, context);
|
val result = exec(op, context);
|
||||||
val states = context.getRngStates();
|
val states = context.getRngStates();
|
||||||
|
|
||||||
|
@ -1860,9 +1863,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
||||||
for (val t: tArgs1)
|
for (val t: tArgs1)
|
||||||
tArgs.put(cnt++, t);
|
tArgs.put(cnt++, t);
|
||||||
|
|
||||||
Nd4jCpu.ShapeList ptrptr;
|
OpaqueShapeList ptrptr;
|
||||||
try {
|
try {
|
||||||
ptrptr = (Nd4jCpu.ShapeList) loop.calculateOutputShapes2(null,
|
ptrptr = loop.calculateOutputShapes2(null,
|
||||||
hash, inputBuffers, inputShapes, op.numInputArguments(), tArgs,
|
hash, inputBuffers, inputShapes, op.numInputArguments(), tArgs,
|
||||||
op.numTArguments(), iArgs, op.numIArguments(), bArgs, op.numBArguments());
|
op.numTArguments(), iArgs, op.numIArguments(), bArgs, op.numBArguments());
|
||||||
} catch (Throwable t){
|
} catch (Throwable t){
|
||||||
|
@ -1891,8 +1894,8 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
||||||
if (ptrptr == null)
|
if (ptrptr == null)
|
||||||
throw new RuntimeException();
|
throw new RuntimeException();
|
||||||
|
|
||||||
for (int e = 0; e < ptrptr.size(); e++ )
|
for (int e = 0; e < loop.getShapeListSize(ptrptr); e++ )
|
||||||
result.add(getShapeFromPointer(new PagedPointer(ptrptr.at(e)).asLongPointer()));
|
result.add(getShapeFromPointer(new PagedPointer(loop.getShape(ptrptr, e)).asLongPointer()));
|
||||||
|
|
||||||
|
|
||||||
loop.deleteShapeList(ptrptr);
|
loop.deleteShapeList(ptrptr);
|
||||||
|
@ -1947,19 +1950,19 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
||||||
|
|
||||||
val newMap = new LinkedHashMap<String, INDArray>();
|
val newMap = new LinkedHashMap<String, INDArray>();
|
||||||
|
|
||||||
val result = (Nd4jCpu.VariablesSet) loop.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size());
|
OpaqueVariableSet result = loop.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size());
|
||||||
|
|
||||||
val status = OpStatus.byNumber(result.status());
|
OpStatus status = OpStatus.byNumber(loop.getVariableSetStatus(result));
|
||||||
|
|
||||||
if (status != OpStatus.ND4J_STATUS_OK)
|
if (status != OpStatus.ND4J_STATUS_OK)
|
||||||
throw new ND4JIllegalStateException("Op execution failed: " + status);
|
throw new ND4JIllegalStateException("Op execution failed: " + status);
|
||||||
|
|
||||||
for (int e = 0; e < result.size(); e++) {
|
for (int e = 0; e < loop.getVariableSetSize(result); e++) {
|
||||||
val var = result.at(e);
|
OpaqueVariable var = loop.getVariable(result, e);
|
||||||
val nodeId = var.id();
|
int nodeId = loop.getVariableId(var);
|
||||||
val index = var.index();
|
int index = loop.getVariableIndex(var);
|
||||||
val shapeInfo = var.getNDArray().shapeInfo();
|
LongPointer shapeInfo = loop.getVariableShape(var);
|
||||||
val buffer = var.getNDArray().buffer();
|
Pointer buffer = loop.getVariableBuffer(var);
|
||||||
|
|
||||||
val rank = (int) shapeInfo.get(0);
|
val rank = (int) shapeInfo.get(0);
|
||||||
val jshape = new long[rank * 2 + 4];
|
val jshape = new long[rank * 2 + 4];
|
||||||
|
@ -1979,7 +1982,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
||||||
PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, Shape.lengthOf(shapeOf) * Nd4j.sizeOfDataType(array.dataType()), MemcpyDirection.HOST_TO_HOST);
|
PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, Shape.lengthOf(shapeOf) * Nd4j.sizeOfDataType(array.dataType()), MemcpyDirection.HOST_TO_HOST);
|
||||||
|
|
||||||
//newMap.put(keySet.get(nodeId), array);
|
//newMap.put(keySet.get(nodeId), array);
|
||||||
val nodeName = var.getName().getString();
|
String nodeName = loop.getVariableName(var);
|
||||||
newMap.put(nodeName, array);
|
newMap.put(nodeName, array);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2160,9 +2163,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public DataBuffer createShapeInfo(long[] shape, long[] stride, long elementWiseStride, char order, DataType dtype, boolean empty) {
|
public DataBuffer createShapeInfo(long[] shape, long[] stride, long elementWiseStride, char order, DataType dtype, boolean empty) {
|
||||||
val dbf = (Nd4jCpu.ConstantDataBuffer) loop.shapeBuffer(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, empty);
|
OpaqueConstantDataBuffer dbf = loop.shapeBuffer(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, empty);
|
||||||
|
|
||||||
val result = new LongBuffer(dbf.primary(), Shape.shapeInfoLength(shape.length));
|
val result = new LongBuffer(loop.getConstantDataBufferPrimary(dbf), Shape.shapeInfoLength(shape.length));
|
||||||
|
|
||||||
loop.deleteShapeBuffer(dbf);
|
loop.deleteShapeBuffer(dbf);
|
||||||
|
|
||||||
|
@ -2171,10 +2174,10 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public TadPack tadShapeInfoAndOffsets(INDArray array, int[] dimension) {
|
public TadPack tadShapeInfoAndOffsets(INDArray array, int[] dimension) {
|
||||||
val pack = (Nd4jCpu.TadPack) loop.tadOnlyShapeInfo((LongPointer) array.shapeInfoDataBuffer().addressPointer(), new IntPointer(dimension), dimension.length);
|
OpaqueTadPack pack = loop.tadOnlyShapeInfo((LongPointer) array.shapeInfoDataBuffer().addressPointer(), new IntPointer(dimension), dimension.length);
|
||||||
|
|
||||||
val tadShape = new LongBuffer(pack.primaryShapeInfo(), pack.shapeInfoLength());
|
val tadShape = new LongBuffer(loop.getPrimaryShapeInfo(pack), loop.getShapeInfoLength(pack));
|
||||||
val tadOffsets = new LongBuffer(pack.primaryOffsets(), pack.numberOfTads());
|
val tadOffsets = new LongBuffer(loop.getPrimaryOffsets(pack), loop.getNumberOfTads(pack));
|
||||||
|
|
||||||
loop.deleteTadPack(pack);
|
loop.deleteTadPack(pack);
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,9 @@ package org.nd4j.linalg.cpu.nativecpu.rng;
|
||||||
|
|
||||||
import org.bytedeco.javacpp.PointerPointer;
|
import org.bytedeco.javacpp.PointerPointer;
|
||||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||||
import org.nd4j.nativeblas.Nd4jCpu;
|
import org.nd4j.nativeblas.NativeOps;
|
||||||
|
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||||
|
import org.nd4j.nativeblas.OpaqueRandomGenerator;
|
||||||
import org.nd4j.rng.NativeRandom;
|
import org.nd4j.rng.NativeRandom;
|
||||||
|
|
||||||
import java.util.concurrent.atomic.AtomicLong;
|
import java.util.concurrent.atomic.AtomicLong;
|
||||||
|
@ -29,6 +31,8 @@ import java.util.concurrent.atomic.AtomicLong;
|
||||||
* @author raver119@gmail.com
|
* @author raver119@gmail.com
|
||||||
*/
|
*/
|
||||||
public class CpuNativeRandom extends NativeRandom {
|
public class CpuNativeRandom extends NativeRandom {
|
||||||
|
private NativeOps nativeOps;
|
||||||
|
|
||||||
public CpuNativeRandom() {
|
public CpuNativeRandom() {
|
||||||
super();
|
super();
|
||||||
}
|
}
|
||||||
|
@ -43,7 +47,13 @@ public class CpuNativeRandom extends NativeRandom {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void init() {
|
public void init() {
|
||||||
statePointer = new Nd4jCpu.RandomGenerator(this.seed, this.seed ^ 0xdeadbeef);
|
nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
|
||||||
|
statePointer = nativeOps.createRandomGenerator(this.seed, this.seed ^ 0xdeadbeef);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void close() {
|
||||||
|
nativeOps.deleteRandomGenerator((OpaqueRandomGenerator)statePointer);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -55,7 +65,7 @@ public class CpuNativeRandom extends NativeRandom {
|
||||||
public void setSeed(long seed) {
|
public void setSeed(long seed) {
|
||||||
this.seed = seed;
|
this.seed = seed;
|
||||||
this.currentPosition.set(0);
|
this.currentPosition.set(0);
|
||||||
((Nd4jCpu.RandomGenerator)statePointer).setStates(seed, seed ^ 0xdeadbeef);
|
nativeOps.setRandomGeneratorStates((OpaqueRandomGenerator)statePointer, seed, seed ^ 0xdeadbeef);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -65,24 +75,24 @@ public class CpuNativeRandom extends NativeRandom {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int nextInt() {
|
public int nextInt() {
|
||||||
return ((Nd4jCpu.RandomGenerator)statePointer).relativeInt(currentPosition.getAndIncrement());
|
return nativeOps.getRandomGeneratorRelativeInt((OpaqueRandomGenerator)statePointer, currentPosition.getAndIncrement());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long nextLong() {
|
public long nextLong() {
|
||||||
return ((Nd4jCpu.RandomGenerator)statePointer).relativeLong(currentPosition.getAndIncrement());
|
return nativeOps.getRandomGeneratorRelativeLong((OpaqueRandomGenerator)statePointer, currentPosition.getAndIncrement());
|
||||||
}
|
}
|
||||||
|
|
||||||
public long rootState() {
|
public long rootState() {
|
||||||
return ((Nd4jCpu.RandomGenerator) statePointer).rootState();
|
return nativeOps.getRandomGeneratorRootState((OpaqueRandomGenerator)statePointer);
|
||||||
}
|
}
|
||||||
|
|
||||||
public long nodeState() {
|
public long nodeState() {
|
||||||
return ((Nd4jCpu.RandomGenerator) statePointer).nodeState();
|
return nativeOps.getRandomGeneratorNodeState((OpaqueRandomGenerator)statePointer);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void setStates(long rootState, long nodeState) {
|
public void setStates(long rootState, long nodeState) {
|
||||||
((Nd4jCpu.RandomGenerator) statePointer).setStates(rootState, nodeState);
|
nativeOps.setRandomGeneratorStates((OpaqueRandomGenerator)statePointer, rootState, nodeState);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2097,16 +2097,25 @@ public native void setGridLimit(int gridSize);
|
||||||
* @param targetBuffer
|
* @param targetBuffer
|
||||||
* @param offsetsBuffer
|
* @param offsetsBuffer
|
||||||
*/
|
*/
|
||||||
public native TadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") LongPointer xShapeInfo,
|
public native OpaqueTadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") LongPointer xShapeInfo,
|
||||||
IntPointer dimension,
|
IntPointer dimension,
|
||||||
int dimensionLength);
|
int dimensionLength);
|
||||||
public native TadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") LongBuffer xShapeInfo,
|
public native OpaqueTadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") LongBuffer xShapeInfo,
|
||||||
IntBuffer dimension,
|
IntBuffer dimension,
|
||||||
int dimensionLength);
|
int dimensionLength);
|
||||||
public native TadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") long[] xShapeInfo,
|
public native OpaqueTadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") long[] xShapeInfo,
|
||||||
int[] dimension,
|
int[] dimension,
|
||||||
int dimensionLength);
|
int dimensionLength);
|
||||||
|
|
||||||
|
public native @Cast("Nd4jLong*") LongPointer getPrimaryShapeInfo(OpaqueTadPack pack);
|
||||||
|
public native @Cast("Nd4jLong*") LongPointer getPrimaryOffsets(OpaqueTadPack pack);
|
||||||
|
public native @Cast("Nd4jLong*") LongPointer getSpecialShapeInfo(OpaqueTadPack pack);
|
||||||
|
public native @Cast("Nd4jLong*") LongPointer getSpecialOffsets(OpaqueTadPack pack);
|
||||||
|
public native @Cast("Nd4jLong") long getNumberOfTads(OpaqueTadPack pack);
|
||||||
|
public native int getShapeInfoLength(OpaqueTadPack pack);
|
||||||
|
|
||||||
|
public native void deleteTadPack(OpaqueTadPack ptr);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* PullRow special op
|
* PullRow special op
|
||||||
*/
|
*/
|
||||||
|
@ -2943,10 +2952,11 @@ public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers
|
||||||
public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongBuffer ptrMap, @Cast("Nd4jLong") long length);
|
public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongBuffer ptrMap, @Cast("Nd4jLong") long length);
|
||||||
public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") long[] ptrMap, @Cast("Nd4jLong") long length);
|
public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") long[] ptrMap, @Cast("Nd4jLong") long length);
|
||||||
|
|
||||||
|
|
||||||
// flatbuffers execution
|
// flatbuffers execution
|
||||||
public native ResultWrapper executeFlatGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer flatBufferPointer);
|
public native OpaqueResultWrapper executeFlatGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer flatBufferPointer);
|
||||||
|
|
||||||
|
public native @Cast("Nd4jLong") long getResultWrapperSize(OpaqueResultWrapper ptr);
|
||||||
|
public native @Cast("Nd4jPointer") Pointer getResultWrapperPointer(OpaqueResultWrapper ptr);
|
||||||
|
|
||||||
public native @Cast("char*") String getAllCustomOps();
|
public native @Cast("char*") String getAllCustomOps();
|
||||||
|
|
||||||
|
@ -2961,23 +2971,35 @@ public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointer
|
||||||
public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, @Cast("bool") boolean isInplace);
|
public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, @Cast("bool") boolean isInplace);
|
||||||
public native int execCustomOp2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer") Pointer opContext);
|
public native int execCustomOp2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer") Pointer opContext);
|
||||||
|
|
||||||
public native ShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs);
|
public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs);
|
||||||
public native ShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs);
|
public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs);
|
||||||
public native ShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs);
|
public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs);
|
||||||
public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
|
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
|
||||||
public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
|
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
|
||||||
public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
|
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
|
||||||
public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
|
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
|
||||||
public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
|
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs);
|
||||||
public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
|
public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs);
|
||||||
|
|
||||||
|
public native @Cast("Nd4jLong") long getShapeListSize(OpaqueShapeList list);
|
||||||
|
public native @Cast("Nd4jLong*") LongPointer getShape(OpaqueShapeList list, @Cast("Nd4jLong") long i);
|
||||||
|
|
||||||
public native void deleteShapeList(@Cast("Nd4jPointer") Pointer shapeList);
|
public native void deleteShapeList(@Cast("Nd4jPointer") Pointer shapeList);
|
||||||
|
|
||||||
public native int registerGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer") Pointer flatBufferPointer);
|
public native int registerGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer") Pointer flatBufferPointer);
|
||||||
|
|
||||||
public native VariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntPointer inputIndices, int numInputs);
|
public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntPointer inputIndices, int numInputs);
|
||||||
public native VariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntBuffer inputIndices, int numInputs);
|
public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntBuffer inputIndices, int numInputs);
|
||||||
public native VariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int[] inputIndices, int numInputs);
|
public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int[] inputIndices, int numInputs);
|
||||||
|
|
||||||
|
public native @Cast("Nd4jLong") long getVariableSetSize(OpaqueVariableSet set);
|
||||||
|
public native @Cast("Nd4jStatus") int getVariableSetStatus(OpaqueVariableSet set);
|
||||||
|
public native OpaqueVariable getVariable(OpaqueVariableSet set, @Cast("Nd4jLong") long i);
|
||||||
|
public native int getVariableId(OpaqueVariable variable);
|
||||||
|
public native int getVariableIndex(OpaqueVariable variable);
|
||||||
|
public native @Cast("char*") String getVariableName(OpaqueVariable variable);
|
||||||
|
public native @Cast("Nd4jLong*") LongPointer getVariableShape(OpaqueVariable variable);
|
||||||
|
public native Pointer getVariableBuffer(OpaqueVariable variable);
|
||||||
|
|
||||||
public native int unregisterGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId);
|
public native int unregisterGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId);
|
||||||
|
|
||||||
|
@ -2986,7 +3008,7 @@ public native void deleteIntArray(@Cast("Nd4jPointer") Pointer pointer);
|
||||||
public native void deleteLongArray(@Cast("Nd4jPointer") Pointer pointer);
|
public native void deleteLongArray(@Cast("Nd4jPointer") Pointer pointer);
|
||||||
public native void deletePointerArray(@Cast("Nd4jPointer") Pointer pointer);
|
public native void deletePointerArray(@Cast("Nd4jPointer") Pointer pointer);
|
||||||
|
|
||||||
public native void deleteVariablesSet(@Cast("Nd4jPointer") Pointer pointer);
|
public native void deleteVariablesSet(@ByVal OpaqueVariableSet pointer);
|
||||||
|
|
||||||
// GraphState creation
|
// GraphState creation
|
||||||
public native @Cast("Nd4jPointer") Pointer getGraphState(@Cast("Nd4jLong") long id);
|
public native @Cast("Nd4jPointer") Pointer getGraphState(@Cast("Nd4jLong") long id);
|
||||||
|
@ -3007,6 +3029,8 @@ public native @Cast("Nd4jStatus") int execCustomOpWithScope(@Cast("Nd4jPointer*"
|
||||||
//void fillUtf8String(Nd4jPointer *extraPointers, const char **string, int numStrings, Nd4jPointer buffer);
|
//void fillUtf8String(Nd4jPointer *extraPointers, const char **string, int numStrings, Nd4jPointer buffer);
|
||||||
public native @Cast("Nd4jPointer") Pointer createUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") String string, int length);
|
public native @Cast("Nd4jPointer") Pointer createUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") String string, int length);
|
||||||
public native @Cast("Nd4jPointer") Pointer createUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") BytePointer string, int length);
|
public native @Cast("Nd4jPointer") Pointer createUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") BytePointer string, int length);
|
||||||
|
public native @Cast("Nd4jLong") long getUtf8StringLength(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer ptr);
|
||||||
|
public native @Cast("char*") BytePointer getUtf8StringBuffer(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer ptr);
|
||||||
public native void deleteUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer ptr);
|
public native void deleteUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer ptr);
|
||||||
|
|
||||||
public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opCode, int numOfSubArrs,
|
public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opCode, int numOfSubArrs,
|
||||||
|
@ -3032,19 +3056,50 @@ public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointe
|
||||||
public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer buffer, @Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jPointer") Pointer specialBuffer, @Cast("Nd4jLong*") LongBuffer specialShapeInfo, @Cast("Nd4jPointer") Pointer debugInfo);
|
public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer buffer, @Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jPointer") Pointer specialBuffer, @Cast("Nd4jLong*") LongBuffer specialShapeInfo, @Cast("Nd4jPointer") Pointer debugInfo);
|
||||||
public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer buffer, @Cast("Nd4jLong*") long[] shapeInfo, @Cast("Nd4jPointer") Pointer specialBuffer, @Cast("Nd4jLong*") long[] specialShapeInfo, @Cast("Nd4jPointer") Pointer debugInfo);
|
public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer buffer, @Cast("Nd4jLong*") long[] shapeInfo, @Cast("Nd4jPointer") Pointer specialBuffer, @Cast("Nd4jLong*") long[] specialShapeInfo, @Cast("Nd4jPointer") Pointer debugInfo);
|
||||||
|
|
||||||
|
public native OpaqueConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty);
|
||||||
|
public native OpaqueConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty);
|
||||||
|
public native OpaqueConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty);
|
||||||
|
|
||||||
public native ConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty);
|
public native OpaqueConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") LongPointer data, int length);
|
||||||
public native ConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty);
|
public native OpaqueConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") LongBuffer data, int length);
|
||||||
public native ConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty);
|
public native OpaqueConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") long[] data, int length);
|
||||||
|
public native OpaqueConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, DoublePointer data, int length);
|
||||||
|
public native OpaqueConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, DoubleBuffer data, int length);
|
||||||
|
public native OpaqueConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, double[] data, int length);
|
||||||
|
public native OpaqueConstantDataBuffer constantBuffer(@Cast("nd4j::DataType") int dtype, ConstantDescriptor descriptor);
|
||||||
|
|
||||||
public native ConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") LongPointer data, int length);
|
public native @Cast("Nd4jPointer") Pointer getConstantDataBufferPrimary(OpaqueConstantDataBuffer dbf);
|
||||||
public native ConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") LongBuffer data, int length);
|
public native @Cast("Nd4jPointer") Pointer getConstantDataBufferSpecial(OpaqueConstantDataBuffer dbf);
|
||||||
public native ConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") long[] data, int length);
|
public native @Cast("Nd4jLong") long getConstantDataBufferLength(OpaqueConstantDataBuffer dbf);
|
||||||
public native ConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, DoublePointer data, int length);
|
public native @Cast("Nd4jLong") long getConstantDataBufferSizeOf(OpaqueConstantDataBuffer dbf);
|
||||||
public native ConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, DoubleBuffer data, int length);
|
|
||||||
public native ConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, double[] data, int length);
|
|
||||||
public native ConstantDataBuffer constantBuffer(@Cast("nd4j::DataType") int dtype, ConstantDescriptor descriptor);
|
|
||||||
|
|
||||||
|
public native void deleteShapeBuffer(OpaqueConstantDataBuffer ptr);
|
||||||
|
|
||||||
|
public native OpaqueContext createGraphContext(int nodeId);
|
||||||
|
public native OpaqueRandomGenerator getGraphContextRandomGenerator(OpaqueContext ptr);
|
||||||
|
public native void markGraphContextInplace(OpaqueContext ptr, @Cast("bool") boolean reallyInplace);
|
||||||
|
public native void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer);
|
||||||
|
public native void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);
|
||||||
|
public native void setGraphContextOutputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);
|
||||||
|
public native void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, int numberOfArguments);
|
||||||
|
public native void setGraphContextTArguments(OpaqueContext ptr, DoubleBuffer arguments, int numberOfArguments);
|
||||||
|
public native void setGraphContextTArguments(OpaqueContext ptr, double[] arguments, int numberOfArguments);
|
||||||
|
public native void setGraphContextIArguments(OpaqueContext ptr, @Cast("Nd4jLong*") LongPointer arguments, int numberOfArguments);
|
||||||
|
public native void setGraphContextIArguments(OpaqueContext ptr, @Cast("Nd4jLong*") LongBuffer arguments, int numberOfArguments);
|
||||||
|
public native void setGraphContextIArguments(OpaqueContext ptr, @Cast("Nd4jLong*") long[] arguments, int numberOfArguments);
|
||||||
|
public native void setGraphContextBArguments(OpaqueContext ptr, @Cast("bool*") BooleanPointer arguments, int numberOfArguments);
|
||||||
|
public native void setGraphContextBArguments(OpaqueContext ptr, @Cast("bool*") boolean[] arguments, int numberOfArguments);
|
||||||
|
public native void deleteGraphContext(OpaqueContext ptr);
|
||||||
|
|
||||||
|
public native OpaqueRandomGenerator createRandomGenerator(@Cast("Nd4jLong") long rootSeed/*=0*/, @Cast("Nd4jLong") long nodeSeed/*=0*/);
|
||||||
|
public native OpaqueRandomGenerator createRandomGenerator();
|
||||||
|
public native @Cast("Nd4jLong") long getRandomGeneratorRootState(OpaqueRandomGenerator ptr);
|
||||||
|
public native @Cast("Nd4jLong") long getRandomGeneratorNodeState(OpaqueRandomGenerator ptr);
|
||||||
|
public native void setRandomGeneratorStates(OpaqueRandomGenerator ptr, @Cast("Nd4jLong") long rootSeed/*=0*/, @Cast("Nd4jLong") long nodeSeed/*=0*/);
|
||||||
|
public native void setRandomGeneratorStates(OpaqueRandomGenerator ptr);
|
||||||
|
public native int getRandomGeneratorRelativeInt(OpaqueRandomGenerator ptr, @Cast("Nd4jLong") long index);
|
||||||
|
public native @Cast("Nd4jLong") long getRandomGeneratorRelativeLong(OpaqueRandomGenerator ptr, @Cast("Nd4jLong") long index);
|
||||||
|
public native void deleteRandomGenerator(OpaqueRandomGenerator ptr);
|
||||||
|
|
||||||
public native @Cast("char*") String runLightBenchmarkSuit(@Cast("bool") boolean printOut);
|
public native @Cast("char*") String runLightBenchmarkSuit(@Cast("bool") boolean printOut);
|
||||||
public native @Cast("char*") String runFullBenchmarkSuit(@Cast("bool") boolean printOut);
|
public native @Cast("char*") String runFullBenchmarkSuit(@Cast("bool") boolean printOut);
|
||||||
|
@ -3705,6 +3760,20 @@ public native @Cast("char*") String runFullBenchmarkSuit(@Cast("bool") boolean p
|
||||||
public NDArray(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("nd4j::DataType") int dtype) { super((Pointer)null); allocate(buffer, order, shape, dtype); }
|
public NDArray(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("nd4j::DataType") int dtype) { super((Pointer)null); allocate(buffer, order, shape, dtype); }
|
||||||
private native void allocate(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("nd4j::DataType") int dtype);
|
private native void allocate(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("nd4j::DataType") int dtype);
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method returns new array with the same shape & data type
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public native @ByVal NDArray like();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This method returns new uninitialized array with the same shape & data type
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public native @ByVal NDArray ulike();
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* this constructor creates new NDArray with shape matching "other" array,
|
* this constructor creates new NDArray with shape matching "other" array,
|
||||||
* doesn't copy "other" elements into new array !!!
|
* doesn't copy "other" elements into new array !!!
|
||||||
|
|
|
@ -156,6 +156,14 @@ public class Nd4jCpuPresets implements InfoMapper, BuildEnabled {
|
||||||
infoMap.put(new Info("thread_local", "ND4J_EXPORT", "INLINEDEF", "CUBLASWINAPI", "FORCEINLINE",
|
infoMap.put(new Info("thread_local", "ND4J_EXPORT", "INLINEDEF", "CUBLASWINAPI", "FORCEINLINE",
|
||||||
"_CUDA_H", "_CUDA_D", "_CUDA_G", "_CUDA_HD", "LIBND4J_ALL_OPS", "NOT_EXCLUDED").cppTypes().annotations())
|
"_CUDA_H", "_CUDA_D", "_CUDA_G", "_CUDA_HD", "LIBND4J_ALL_OPS", "NOT_EXCLUDED").cppTypes().annotations())
|
||||||
.put(new Info("NativeOps.h").objectify())
|
.put(new Info("NativeOps.h").objectify())
|
||||||
|
.put(new Info("OpaqueTadPack").pointerTypes("OpaqueTadPack"))
|
||||||
|
.put(new Info("OpaqueResultWrapper").pointerTypes("OpaqueResultWrapper"))
|
||||||
|
.put(new Info("OpaqueShapeList").pointerTypes("OpaqueShapeList"))
|
||||||
|
.put(new Info("OpaqueVariableSet").pointerTypes("OpaqueVariableSet"))
|
||||||
|
.put(new Info("OpaqueVariable").pointerTypes("OpaqueVariable"))
|
||||||
|
.put(new Info("OpaqueConstantDataBuffer").pointerTypes("OpaqueConstantDataBuffer"))
|
||||||
|
.put(new Info("OpaqueContext").pointerTypes("OpaqueContext"))
|
||||||
|
.put(new Info("OpaqueRandomGenerator").pointerTypes("OpaqueRandomGenerator"))
|
||||||
.put(new Info("const char").valueTypes("byte").pointerTypes("@Cast(\"char*\") String",
|
.put(new Info("const char").valueTypes("byte").pointerTypes("@Cast(\"char*\") String",
|
||||||
"@Cast(\"char*\") BytePointer"))
|
"@Cast(\"char*\") BytePointer"))
|
||||||
.put(new Info("char").valueTypes("char").pointerTypes("@Cast(\"char*\") BytePointer",
|
.put(new Info("char").valueTypes("char").pointerTypes("@Cast(\"char*\") BytePointer",
|
||||||
|
|
|
@ -442,7 +442,7 @@ public class CustomOpsTests extends BaseNd4jTest {
|
||||||
context.setOutputArray(0, arrayZ);
|
context.setOutputArray(0, arrayZ);
|
||||||
|
|
||||||
val addOp = new AddOp();
|
val addOp = new AddOp();
|
||||||
NativeOpsHolder.getInstance().getDeviceNativeOps().execCustomOp(null, addOp.opHash(), context.contextPointer());
|
NativeOpsHolder.getInstance().getDeviceNativeOps().execCustomOp2(null, addOp.opHash(), context.contextPointer());
|
||||||
|
|
||||||
assertEquals(exp, arrayZ);
|
assertEquals(exp, arrayZ);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue