diff --git a/libnd4j/blas/NativeOps.h b/libnd4j/blas/NativeOps.h index 65929b452..12d34ba06 100755 --- a/libnd4j/blas/NativeOps.h +++ b/libnd4j/blas/NativeOps.h @@ -882,6 +882,8 @@ ND4J_EXPORT void enableVerboseMode(bool reallyEnable); */ ND4J_EXPORT void setGridLimit(int gridSize); +typedef nd4j::TadPack OpaqueTadPack; + /** * * @param xShapeInfo @@ -890,10 +892,19 @@ ND4J_EXPORT void setGridLimit(int gridSize); * @param targetBuffer * @param offsetsBuffer */ -ND4J_EXPORT nd4j::TadPack* tadOnlyShapeInfo(Nd4jLong *xShapeInfo, +ND4J_EXPORT OpaqueTadPack* tadOnlyShapeInfo(Nd4jLong *xShapeInfo, int *dimension, int dimensionLength); +ND4J_EXPORT Nd4jLong* getPrimaryShapeInfo(OpaqueTadPack* pack); +ND4J_EXPORT Nd4jLong* getPrimaryOffsets(OpaqueTadPack* pack); +ND4J_EXPORT Nd4jLong* getSpecialShapeInfo(OpaqueTadPack* pack); +ND4J_EXPORT Nd4jLong* getSpecialOffsets(OpaqueTadPack* pack); +ND4J_EXPORT Nd4jLong getNumberOfTads(OpaqueTadPack* pack); +ND4J_EXPORT int getShapeInfoLength(OpaqueTadPack* pack); + +ND4J_EXPORT void deleteTadPack(OpaqueTadPack* ptr); + /* * PullRow special op */ @@ -1639,10 +1650,13 @@ ND4J_EXPORT Nd4jLong* mmapFile(Nd4jPointer *extraPointers, const char *fileName, ND4J_EXPORT void munmapFile(Nd4jPointer *extraPointers, Nd4jLong* ptrMap, Nd4jLong length); +typedef nd4j::graph::ResultWrapper OpaqueResultWrapper; // flatbuffers execution -ND4J_EXPORT nd4j::graph::ResultWrapper* executeFlatGraph(Nd4jPointer *extraPointers, Nd4jPointer flatBufferPointer); +ND4J_EXPORT OpaqueResultWrapper* executeFlatGraph(Nd4jPointer *extraPointers, Nd4jPointer flatBufferPointer); +ND4J_EXPORT Nd4jLong getResultWrapperSize(OpaqueResultWrapper* ptr); +ND4J_EXPORT Nd4jPointer getResultWrapperPointer(OpaqueResultWrapper* ptr); ND4J_EXPORT const char* getAllCustomOps(); @@ -1652,14 +1666,31 @@ ND4J_EXPORT const char* getAllOperations(); ND4J_EXPORT int execCustomOp(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputs, Nd4jPointer* outputBuffers, Nd4jPointer* outputShapes, int numOutputs, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool* bArgs, int numBArgs, bool isInplace); ND4J_EXPORT int execCustomOp2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer opContext); -ND4J_EXPORT nd4j::ShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs); -ND4J_EXPORT nd4j::ShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs); +typedef nd4j::ShapeList OpaqueShapeList; + +ND4J_EXPORT OpaqueShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs); +ND4J_EXPORT OpaqueShapeList* calculateOutputShapes2(Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs); + +ND4J_EXPORT Nd4jLong getShapeListSize(OpaqueShapeList* list); +ND4J_EXPORT Nd4jLong* getShape(OpaqueShapeList* list, Nd4jLong i); ND4J_EXPORT void deleteShapeList(Nd4jPointer shapeList); ND4J_EXPORT int registerGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer flatBufferPointer); -ND4J_EXPORT nd4j::graph::VariablesSet *executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int* inputIndices, int numInputs); +typedef nd4j::graph::VariablesSet OpaqueVariableSet; +typedef nd4j::graph::Variable OpaqueVariable; + +ND4J_EXPORT OpaqueVariableSet *executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, Nd4jPointer *inputBuffers, Nd4jPointer *inputShapes, int* inputIndices, int numInputs); + +ND4J_EXPORT Nd4jLong getVariableSetSize(OpaqueVariableSet* set); +ND4J_EXPORT Nd4jStatus getVariableSetStatus(OpaqueVariableSet* set); +ND4J_EXPORT OpaqueVariable* getVariable(OpaqueVariableSet* set, Nd4jLong i); +ND4J_EXPORT int getVariableId(OpaqueVariable* variable); +ND4J_EXPORT int getVariableIndex(OpaqueVariable* variable); +ND4J_EXPORT const char* getVariableName(OpaqueVariable* variable); +ND4J_EXPORT Nd4jLong* getVariableShape(OpaqueVariable* variable); +ND4J_EXPORT void* getVariableBuffer(OpaqueVariable* variable); ND4J_EXPORT int unregisterGraph(Nd4jPointer *extraPointers, Nd4jLong graphId); @@ -1668,7 +1699,7 @@ ND4J_EXPORT void deleteIntArray(Nd4jPointer pointer); ND4J_EXPORT void deleteLongArray(Nd4jPointer pointer); ND4J_EXPORT void deletePointerArray(Nd4jPointer pointer); -ND4J_EXPORT void deleteVariablesSet(Nd4jPointer pointer); +ND4J_EXPORT void deleteVariablesSet(OpaqueVariableSet pointer); // GraphState creation ND4J_EXPORT Nd4jPointer getGraphState(Nd4jLong id); @@ -1684,7 +1715,9 @@ ND4J_EXPORT Nd4jStatus execCustomOpWithScope(Nd4jPointer *extraPointers, Nd4jPoi //void fillUtf8String(Nd4jPointer *extraPointers, const char **string, int numStrings, Nd4jPointer buffer); ND4J_EXPORT Nd4jPointer createUtf8String(Nd4jPointer *extraPointers, const char *string, int length); -void deleteUtf8String(Nd4jPointer *extraPointers, Nd4jPointer ptr); +ND4J_EXPORT Nd4jLong getUtf8StringLength(Nd4jPointer *extraPointers, Nd4jPointer ptr); +ND4J_EXPORT char* getUtf8StringBuffer(Nd4jPointer *extraPointers, Nd4jPointer ptr); +ND4J_EXPORT void deleteUtf8String(Nd4jPointer *extraPointers, Nd4jPointer ptr); ND4J_EXPORT void scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOfSubArrs, void* hX, Nd4jLong* hXShapeInfo, Nd4jLong* hXOffsets, @@ -1693,18 +1726,45 @@ ND4J_EXPORT void scatterUpdate(Nd4jPointer *extraPointers, int opCode, int numOf void* dY, Nd4jLong* dYShapeInfo, Nd4jLong* dYOffsets, int* hIindexes, int* dIindexes); -ND4J_EXPORT void deleteShapeBuffer(Nd4jPointer ptr); -ND4J_EXPORT void deleteTadPack(Nd4jPointer ptr); - ND4J_EXPORT void inspectArray(Nd4jPointer *extraPointers, Nd4jPointer buffer, Nd4jLong *shapeInfo, Nd4jPointer specialBuffer, Nd4jLong *specialShapeInfo, Nd4jPointer debugInfo); -ND4J_EXPORT nd4j::ConstantDataBuffer* shapeBuffer(int rank, Nd4jLong *shape, Nd4jLong *strides, nd4j::DataType dtype, char order, Nd4jLong ews, bool empty); +typedef nd4j::ConstantDataBuffer OpaqueConstantDataBuffer; -ND4J_EXPORT nd4j::ConstantDataBuffer* constantBufferLong(nd4j::DataType dtype, Nd4jLong *data, int length); -ND4J_EXPORT nd4j::ConstantDataBuffer* constantBufferDouble(nd4j::DataType dtype, double *data, int length); -ND4J_EXPORT nd4j::ConstantDataBuffer* constantBuffer(nd4j::DataType dtype, nd4j::ConstantDescriptor *descriptor); +ND4J_EXPORT OpaqueConstantDataBuffer* shapeBuffer(int rank, Nd4jLong *shape, Nd4jLong *strides, nd4j::DataType dtype, char order, Nd4jLong ews, bool empty); +ND4J_EXPORT OpaqueConstantDataBuffer* constantBufferLong(nd4j::DataType dtype, Nd4jLong *data, int length); +ND4J_EXPORT OpaqueConstantDataBuffer* constantBufferDouble(nd4j::DataType dtype, double *data, int length); +ND4J_EXPORT OpaqueConstantDataBuffer* constantBuffer(nd4j::DataType dtype, nd4j::ConstantDescriptor *descriptor); + +ND4J_EXPORT Nd4jPointer getConstantDataBufferPrimary(OpaqueConstantDataBuffer* dbf); +ND4J_EXPORT Nd4jPointer getConstantDataBufferSpecial(OpaqueConstantDataBuffer* dbf); +ND4J_EXPORT Nd4jLong getConstantDataBufferLength(OpaqueConstantDataBuffer* dbf); +ND4J_EXPORT Nd4jLong getConstantDataBufferSizeOf(OpaqueConstantDataBuffer* dbf); + +ND4J_EXPORT void deleteShapeBuffer(OpaqueConstantDataBuffer* ptr); + +typedef nd4j::graph::Context OpaqueContext; +typedef nd4j::graph::RandomGenerator OpaqueRandomGenerator; + +ND4J_EXPORT OpaqueContext* createGraphContext(int nodeId); +ND4J_EXPORT OpaqueRandomGenerator* getGraphContextRandomGenerator(OpaqueContext* ptr); +ND4J_EXPORT void markGraphContextInplace(OpaqueContext* ptr, bool reallyInplace); +ND4J_EXPORT void setGraphContextCudaContext(OpaqueContext* ptr, void *stream, void *reductionPointer, void *allocationPointer); +ND4J_EXPORT void setGraphContextInputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo); +ND4J_EXPORT void setGraphContextOutputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo); +ND4J_EXPORT void setGraphContextTArguments(OpaqueContext* ptr, double *arguments, int numberOfArguments); +ND4J_EXPORT void setGraphContextIArguments(OpaqueContext* ptr, Nd4jLong *arguments, int numberOfArguments); +ND4J_EXPORT void setGraphContextBArguments(OpaqueContext* ptr, bool *arguments, int numberOfArguments); +ND4J_EXPORT void deleteGraphContext(OpaqueContext* ptr); + +ND4J_EXPORT OpaqueRandomGenerator* createRandomGenerator(Nd4jLong rootSeed = 0, Nd4jLong nodeSeed = 0); +ND4J_EXPORT Nd4jLong getRandomGeneratorRootState(OpaqueRandomGenerator* ptr); +ND4J_EXPORT Nd4jLong getRandomGeneratorNodeState(OpaqueRandomGenerator* ptr); +ND4J_EXPORT void setRandomGeneratorStates(OpaqueRandomGenerator* ptr, Nd4jLong rootSeed = 0, Nd4jLong nodeSeed = 0); +ND4J_EXPORT int getRandomGeneratorRelativeInt(OpaqueRandomGenerator* ptr, Nd4jLong index); +ND4J_EXPORT Nd4jLong getRandomGeneratorRelativeLong(OpaqueRandomGenerator* ptr, Nd4jLong index); +ND4J_EXPORT void deleteRandomGenerator(OpaqueRandomGenerator* ptr); ND4J_EXPORT const char* runLightBenchmarkSuit(bool printOut); ND4J_EXPORT const char* runFullBenchmarkSuit(bool printOut); diff --git a/libnd4j/blas/cpu/NativeOps.cpp b/libnd4j/blas/cpu/NativeOps.cpp index 7e5560536..9dffb412c 100644 --- a/libnd4j/blas/cpu/NativeOps.cpp +++ b/libnd4j/blas/cpu/NativeOps.cpp @@ -1328,6 +1328,25 @@ nd4j::TadPack* tadOnlyShapeInfo(Nd4jLong *hXShapeInfo, int *dimension, int dimen return pack; } +Nd4jLong* getPrimaryShapeInfo(nd4j::TadPack* pack) { + return pack->primaryShapeInfo(); +} +Nd4jLong* getPrimaryOffsets(nd4j::TadPack* pack) { + return pack->primaryOffsets(); +} +Nd4jLong* getSpecialShapeInfo(nd4j::TadPack* pack) { + return pack->specialShapeInfo(); +} +Nd4jLong* getSpecialOffsets(nd4j::TadPack* pack) { + return pack->specialOffsets(); +} +Nd4jLong getNumberOfTads(nd4j::TadPack* pack) { + return pack->numberOfTads(); +} +int getShapeInfoLength(nd4j::TadPack* pack) { + return pack->shapeInfoLength(); +} + int memcpyConstantAsync(Nd4jLong dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4jPointer reserved) { // no-op return 0L; @@ -2005,6 +2024,13 @@ nd4j::graph::ResultWrapper* executeFlatGraph(Nd4jPointer *extraPointers, Nd4jPoi return nd4j::graph::GraphExecutioner::executeFlatBuffer(flatBufferPointer); } +Nd4jLong getResultWrapperSize(nd4j::graph::ResultWrapper* ptr) { + return ptr->size(); +} +Nd4jPointer getResultWrapperPointer(nd4j::graph::ResultWrapper* ptr) { + return ptr->pointer(); +} + const char* getAllCustomOps() { return nd4j::ops::OpRegistrator::getInstance()->getAllCustomOperations(); } @@ -2041,7 +2067,13 @@ int estimateThreshold(Nd4jPointer *extraPointers, Nd4jPointer hX, Nd4jLong *hXSh BUILD_SINGLE_SELECTOR(xType, return estimateThresholdGeneric, (extraPointers, hX, N, threshold), FLOAT_TYPES); } +Nd4jLong getShapeListSize(nd4j::ShapeList* list) { + return list->size(); +} +Nd4jLong* getShape(nd4j::ShapeList* list, Nd4jLong i) { + return list->at(i); +} void deleteShapeList(Nd4jPointer shapeList) { auto list = reinterpret_cast(shapeList); @@ -2305,6 +2337,38 @@ nd4j::graph::VariablesSet* executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLo return nullptr; } +Nd4jLong getVariableSetSize(nd4j::graph::VariablesSet* set) { + return set->size(); +} + +Nd4jStatus getVariableSetStatus(nd4j::graph::VariablesSet* set) { + return set->status(); +} + +nd4j::graph::Variable* getVariable(nd4j::graph::VariablesSet* set, Nd4jLong i) { + return set->at(i); +} + +int getVariableId(nd4j::graph::Variable* variable) { + return variable->id(); +} + +int getVariableIndex(nd4j::graph::Variable* variable) { + return variable->index(); +} + +const char* getVariableName(nd4j::graph::Variable* variable) { + return variable->getName()->c_str(); +} + +Nd4jLong* getVariableShape(nd4j::graph::Variable* variable) { + return variable->getNDArray()->shapeInfo(); +} + +void* getVariableBuffer(nd4j::graph::Variable* variable) { + return variable->getNDArray()->buffer(); +} + int unregisterGraph(Nd4jPointer *extraPointers, Nd4jLong graphId) { nd4j::graph::GraphHolder::getInstance()->dropGraphAny(graphId); @@ -2628,6 +2692,13 @@ Nd4jPointer createUtf8String(Nd4jPointer *extraPointers, const char *string, int return reinterpret_cast(u); } +Nd4jLong getUtf8StringLength(Nd4jPointer *extraPointers, Nd4jPointer ptr) { + return reinterpret_cast(ptr)->_length; +} +char* getUtf8StringBuffer(Nd4jPointer *extraPointers, Nd4jPointer ptr) { + return reinterpret_cast(ptr)->_buffer; +} + void deleteUtf8String(Nd4jPointer *extraPointers, Nd4jPointer ptr) { delete(reinterpret_cast(ptr)); } @@ -2710,14 +2781,12 @@ nd4j::ConstantDataBuffer* shapeBuffer(int rank, Nd4jLong *shape, Nd4jLong *strid return buffer; } -void deleteShapeBuffer(Nd4jPointer ptr) { - auto buffer = reinterpret_cast(ptr); - delete buffer; +void deleteShapeBuffer(nd4j::ConstantDataBuffer* ptr) { + delete ptr; } -void deleteTadPack(Nd4jPointer ptr) { - auto buffer = reinterpret_cast(ptr); - delete buffer; +void deleteTadPack(nd4j::TadPack* ptr) { + delete ptr; } nd4j::ConstantDataBuffer* constantBufferLong(nd4j::DataType dtype, Nd4jLong *data, int length) { @@ -2732,6 +2801,78 @@ nd4j::ConstantDataBuffer* constantBuffer(nd4j::DataType dtype, nd4j::ConstantDes return nd4j::ConstantHelper::getInstance()->constantBuffer(*descriptor, dtype); } +Nd4jPointer getConstantDataBufferPrimary(nd4j::ConstantDataBuffer* dbf) { + return dbf->primary(); +} +Nd4jPointer getConstantDataBufferSpecial(nd4j::ConstantDataBuffer* dbf) { + return dbf->special(); +} +Nd4jLong getConstantDataBufferLength(nd4j::ConstantDataBuffer* dbf) { + return dbf->length(); +} +Nd4jLong getConstantDataBufferSizeOf(nd4j::ConstantDataBuffer* dbf) { + return dbf->sizeOf(); +} + + +nd4j::graph::Context* createGraphContext(int nodeId) { + return new nd4j::graph::Context(nodeId); +} +nd4j::graph::RandomGenerator* getGraphContextRandomGenerator(nd4j::graph::Context* ptr) { + return &ptr->randomGenerator(); +} +void markGraphContextInplace(nd4j::graph::Context* ptr, bool reallyInplace) { + ptr->markInplace(reallyInplace); +} +void setGraphContextCudaContext(nd4j::graph::Context* ptr, void *stream, void *reductionPointer, void *allocationPointer) { +} +void setGraphContextInputArray(nd4j::graph::Context* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo) { + ptr->setInputArray(index, buffer, shapeInfo, specialBuffer, specialShapeInfo); +} +void setGraphContextOutputArray(nd4j::graph::Context* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo) { + ptr->setOutputArray(index, buffer, shapeInfo, specialBuffer, specialShapeInfo); +} +void setGraphContextTArguments(nd4j::graph::Context* ptr, double *arguments, int numberOfArguments) { + ptr->setTArguments(arguments, numberOfArguments); +} +void setGraphContextIArguments(nd4j::graph::Context* ptr, Nd4jLong *arguments, int numberOfArguments) { + ptr->setIArguments(arguments, numberOfArguments); +} +void setGraphContextBArguments(nd4j::graph::Context* ptr, bool *arguments, int numberOfArguments) { + ptr->setBArguments(arguments, numberOfArguments); +} +void deleteGraphContext(nd4j::graph::Context* ptr) { + delete ptr; +} + + +nd4j::graph::RandomGenerator* createRandomGenerator(Nd4jLong rootSeed, Nd4jLong nodeSeed) { + return new nd4j::graph::RandomGenerator(rootSeed, nodeSeed); +} + +Nd4jLong getRandomGeneratorRootState(nd4j::graph::RandomGenerator* ptr) { + return ptr->rootState(); +} + +Nd4jLong getRandomGeneratorNodeState(nd4j::graph::RandomGenerator* ptr) { + return ptr->nodeState(); +} + +void setRandomGeneratorStates(nd4j::graph::RandomGenerator* ptr, Nd4jLong rootSeed, Nd4jLong nodeSeed) { + ptr->setStates(rootSeed, nodeSeed); +} + +int getRandomGeneratorRelativeInt(nd4j::graph::RandomGenerator* ptr, Nd4jLong index) { + return ptr->relativeInt(index); +} + +Nd4jLong getRandomGeneratorRelativeLong(nd4j::graph::RandomGenerator* ptr, Nd4jLong index) { + return ptr->relativeLong(index); +} + +void deleteRandomGenerator(nd4j::graph::RandomGenerator* ptr) { + delete ptr; +} int dataTypeFromNpyHeader(void *header) { diff --git a/libnd4j/blas/cuda/NativeOps.cu b/libnd4j/blas/cuda/NativeOps.cu index b592343f2..3f0e432a9 100755 --- a/libnd4j/blas/cuda/NativeOps.cu +++ b/libnd4j/blas/cuda/NativeOps.cu @@ -1499,6 +1499,25 @@ nd4j::TadPack* tadOnlyShapeInfo(Nd4jLong *dXShapeInfo, int *dimension, int dimen return pack; } +Nd4jLong* getPrimaryShapeInfo(nd4j::TadPack* pack) { + return pack->primaryShapeInfo(); +} +Nd4jLong* getPrimaryOffsets(nd4j::TadPack* pack) { + return pack->primaryOffsets(); +} +Nd4jLong* getSpecialShapeInfo(nd4j::TadPack* pack) { + return pack->specialShapeInfo(); +} +Nd4jLong* getSpecialOffsets(nd4j::TadPack* pack) { + return pack->specialOffsets(); +} +Nd4jLong getNumberOfTads(nd4j::TadPack* pack) { + return pack->numberOfTads(); +} +int getShapeInfoLength(nd4j::TadPack* pack) { + return pack->shapeInfoLength(); +} + int memcpyConstantAsync(Nd4jLong dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4jPointer reserved) { cudaStream_t *pStream = reinterpret_cast(reserved); @@ -2533,6 +2552,13 @@ nd4j::graph::ResultWrapper* executeFlatGraph(Nd4jPointer *extraPointers, Nd4jPoi return nd4j::graph::GraphExecutioner::executeFlatBuffer(flatBufferPointer); } +Nd4jLong getResultWrapperSize(nd4j::graph::ResultWrapper* ptr) { + return ptr->size(); +} +Nd4jPointer getResultWrapperPointer(nd4j::graph::ResultWrapper* ptr) { + return ptr->pointer(); +} + const char* getAllCustomOps() { return nd4j::ops::OpRegistrator::getInstance()->getAllCustomOperations(); @@ -2607,6 +2633,13 @@ nd4j::ShapeList* calculateOutputShapes(Nd4jPointer* extraPointers, Nd4jLong hash return _calculateOutputShapes(extraPointers, op, inputShapes, numInputShapes, tArgs, numTArgs, iArgs, numIArgs); } +Nd4jLong getShapeListSize(nd4j::ShapeList* list) { + return list->size(); +} + +Nd4jLong* getShape(nd4j::ShapeList* list, Nd4jLong i) { + return list->at(i); +} static FORCEINLINE Nd4jStatus realExec(nd4j::ops::DeclarableOp* op, Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputs, Nd4jPointer* outputBuffers, Nd4jPointer* outputShapes, int numOutputs, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool* bArgs, int numBArgs, bool isInplace) { if (op == nullptr) @@ -2775,6 +2808,38 @@ VariablesSet* executeStoredGraph(Nd4jPointer *extraPointers, Nd4jLong graphId, N return executeStoredGraphT(extraPointers, graphId, inputBuffers, inputShapes, inputIndices, numInputs); } +Nd4jLong getVariableSetSize(nd4j::graph::VariablesSet* set) { + return set->size(); +} + +Nd4jStatus getVariableSetStatus(nd4j::graph::VariablesSet* set) { + return set->status(); +} + +nd4j::graph::Variable* getVariable(nd4j::graph::VariablesSet* set, Nd4jLong i) { + return set->at(i); +} + +int getVariableId(nd4j::graph::Variable* variable) { + return variable->id(); +} + +int getVariableIndex(nd4j::graph::Variable* variable) { + return variable->index(); +} + +const char* getVariableName(nd4j::graph::Variable* variable) { + return variable->getName()->c_str(); +} + +Nd4jLong* getVariableShape(nd4j::graph::Variable* variable) { + return variable->getNDArray()->shapeInfo(); +} + +void* getVariableBuffer(nd4j::graph::Variable* variable) { + return variable->getNDArray()->buffer(); +} + int unregisterGraph(Nd4jPointer *extraPointers, Nd4jLong graphId) { nd4j::graph::GraphHolder::getInstance()->dropGraphAny(graphId); @@ -3102,6 +3167,13 @@ Nd4jPointer createUtf8String(Nd4jPointer *extraPointers, const char *string, int return reinterpret_cast(u); } +Nd4jLong getUtf8StringLength(Nd4jPointer *extraPointers, Nd4jPointer ptr) { + return reinterpret_cast(ptr)->_length; +} +char* getUtf8StringBuffer(Nd4jPointer *extraPointers, Nd4jPointer ptr) { + return reinterpret_cast(ptr)->_buffer; +} + void deleteUtf8String(Nd4jPointer *extraPointers, Nd4jPointer ptr) { delete(reinterpret_cast(ptr)); } @@ -3237,14 +3309,12 @@ nd4j::ConstantDataBuffer* shapeBuffer(int rank, Nd4jLong *shape, Nd4jLong *strid return buffer; } -void deleteShapeBuffer(Nd4jPointer ptr) { - auto buffer = reinterpret_cast(ptr); - delete buffer; +void deleteShapeBuffer(nd4j::ConstantDataBuffer* ptr) { + delete ptr; } -void deleteTadPack(Nd4jPointer ptr) { - auto buffer = reinterpret_cast(ptr); - delete buffer; +void deleteTadPack(nd4j::TadPack* ptr) { + delete ptr; } nd4j::ConstantDataBuffer* constantBufferLong(nd4j::DataType dtype, Nd4jLong *data, int length) { @@ -3259,6 +3329,82 @@ nd4j::ConstantDataBuffer* constantBuffer(nd4j::DataType dtype, nd4j::ConstantDes return nd4j::ConstantHelper::getInstance()->constantBuffer(*descriptor, dtype); } + +Nd4jPointer getConstantDataBufferPrimary(nd4j::ConstantDataBuffer* dbf) { + return dbf->primary(); +} +Nd4jPointer getConstantDataBufferSpecial(nd4j::ConstantDataBuffer* dbf) { + return dbf->special(); +} +Nd4jLong getConstantDataBufferLength(nd4j::ConstantDataBuffer* dbf) { + return dbf->length(); +} +Nd4jLong getConstantDataBufferSizeOf(nd4j::ConstantDataBuffer* dbf) { + return dbf->sizeOf(); +} + + +nd4j::graph::Context* createGraphContext(int nodeId) { + return new nd4j::graph::Context(nodeId); +} +nd4j::graph::RandomGenerator* getGraphContextRandomGenerator(nd4j::graph::Context* ptr) { + return &ptr->randomGenerator(); +} +void markGraphContextInplace(nd4j::graph::Context* ptr, bool reallyInplace) { + ptr->markInplace(reallyInplace); +} +void setGraphContextCudaContext(nd4j::graph::Context* ptr, void *stream, void *reductionPointer, void *allocationPointer) { + ptr->setCudaContext(stream, reductionPointer, allocationPointer); +} +void setGraphContextInputArray(nd4j::graph::Context* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo) { + ptr->setInputArray(index, buffer, shapeInfo, specialBuffer, specialShapeInfo); +} +void setGraphContextOutputArray(nd4j::graph::Context* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo) { + ptr->setOutputArray(index, buffer, shapeInfo, specialBuffer, specialShapeInfo); +} +void setGraphContextTArguments(nd4j::graph::Context* ptr, double *arguments, int numberOfArguments) { + ptr->setTArguments(arguments, numberOfArguments); +} +void setGraphContextIArguments(nd4j::graph::Context* ptr, Nd4jLong *arguments, int numberOfArguments) { + ptr->setIArguments(arguments, numberOfArguments); +} +void setGraphContextBArguments(nd4j::graph::Context* ptr, bool *arguments, int numberOfArguments) { + ptr->setBArguments(arguments, numberOfArguments); +} +void deleteGraphContext(nd4j::graph::Context* ptr) { + delete ptr; +} + + +nd4j::graph::RandomGenerator* createRandomGenerator(Nd4jLong rootSeed, Nd4jLong nodeSeed) { + return new nd4j::graph::RandomGenerator(rootSeed, nodeSeed); +} + +Nd4jLong getRandomGeneratorRootState(nd4j::graph::RandomGenerator* ptr) { + return ptr->rootState(); +} + +Nd4jLong getRandomGeneratorNodeState(nd4j::graph::RandomGenerator* ptr) { + return ptr->nodeState(); +} + +void setRandomGeneratorStates(nd4j::graph::RandomGenerator* ptr, Nd4jLong rootSeed, Nd4jLong nodeSeed) { + ptr->setStates(rootSeed, nodeSeed); +} + +int getRandomGeneratorRelativeInt(nd4j::graph::RandomGenerator* ptr, Nd4jLong index) { + return ptr->relativeInt(index); +} + +Nd4jLong getRandomGeneratorRelativeLong(nd4j::graph::RandomGenerator* ptr, Nd4jLong index) { + return ptr->relativeLong(index); +} + +void deleteRandomGenerator(nd4j::graph::RandomGenerator* ptr) { + delete ptr; +} + + Nd4jPointer shapeBufferForNumpy(Nd4jPointer npyArray) { cnpy::NpyArray arr = cnpy::loadNpyFromPointer(reinterpret_cast(npyArray)); unsigned int shapeSize = arr.shape.size(); diff --git a/libnd4j/minifier/minifier.cpp b/libnd4j/minifier/minifier.cpp index 954853b7d..071dacc17 100644 --- a/libnd4j/minifier/minifier.cpp +++ b/libnd4j/minifier/minifier.cpp @@ -27,6 +27,9 @@ #include #include +using namespace nd4j::ops; +using namespace nd4j::graph; + int main(int argc, char *argv[]) { // this string will contain list of operations diff --git a/libnd4j/tests_cpu/layers_tests/OpTrackerTests.cpp b/libnd4j/tests_cpu/layers_tests/OpTrackerTests.cpp index d4d01fc09..36828a807 100644 --- a/libnd4j/tests_cpu/layers_tests/OpTrackerTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/OpTrackerTests.cpp @@ -25,6 +25,7 @@ #include using namespace nd4j; +using namespace nd4j::ops; using namespace nd4j::graph; class OpTrackerTests : public testing::Test { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java index b994630b3..cd74a60a0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java @@ -28,7 +28,7 @@ import java.util.List; * * @author raver119@gmail.com */ -public interface OpContext { +public interface OpContext extends AutoCloseable { /** * This method sets integer arguments required for operation diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/autodiff/execution/NativeGraphExecutioner.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/autodiff/execution/NativeGraphExecutioner.java index 4089223ae..8253b67bb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/autodiff/execution/NativeGraphExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/autodiff/execution/NativeGraphExecutioner.java @@ -38,8 +38,9 @@ import org.nd4j.linalg.api.ops.Op; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.nativeblas.NativeOps; import org.nd4j.nativeblas.NativeOpsHolder; -import org.nd4j.nativeblas.ResultWrapperAbstraction; +import org.nd4j.nativeblas.OpaqueResultWrapper; import java.io.File; import java.nio.ByteBuffer; @@ -100,11 +101,12 @@ public class NativeGraphExecutioner implements GraphExecutioner { log.info("Buffer length: {}", buffer.limit()); - val res = NativeOpsHolder.getInstance().getDeviceNativeOps().executeFlatGraph(null, bPtr); + NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); + OpaqueResultWrapper res = nativeOps.executeFlatGraph(null, bPtr); if (res == null) throw new ND4JIllegalStateException("Graph execution failed"); - PagedPointer pagedPointer = new PagedPointer(res.pointer(),res.size()); + PagedPointer pagedPointer = new PagedPointer(nativeOps.getResultWrapperPointer(res), nativeOps.getResultWrapperSize(res)); FlatResult fr = FlatResult.getRootAsFlatResult(pagedPointer.asBytePointer().asByteBuffer()); log.info("VarMap: {}", sd.variableMap()); @@ -132,7 +134,7 @@ public class NativeGraphExecutioner implements GraphExecutioner { } // now we need to release native memory - NativeOpsHolder.getInstance().getDeviceNativeOps().deleteResultWrapper(res); + nativeOps.deleteResultWrapper(res); return results; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java index c3942d0aa..77cbdd4eb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java @@ -697,7 +697,16 @@ public interface NativeOps { void setGridLimit(int gridSize); - Pointer tadOnlyShapeInfo(@Cast("Nd4jLong *") LongPointer shapeInfo, IntPointer dimension, int dimensionLength); + OpaqueTadPack tadOnlyShapeInfo(LongPointer shapeInfo, IntPointer dimension, int dimensionLength); + + LongPointer getPrimaryShapeInfo(OpaqueTadPack pack); + LongPointer getPrimaryOffsets(OpaqueTadPack pack); + LongPointer getSpecialShapeInfo(OpaqueTadPack pack); + LongPointer getSpecialOffsets(OpaqueTadPack pack); + long getNumberOfTads(OpaqueTadPack pack); + int getShapeInfoLength(OpaqueTadPack pack); + + void deleteTadPack(OpaqueTadPack pointer); /////////////// @@ -1037,7 +1046,10 @@ public interface NativeOps { void munmapFile(PointerPointer extraPointers, LongPointer ptrMap, long length); - ResultWrapperAbstraction executeFlatGraph(PointerPointer extraPointers, Pointer flatBufferPointer); + OpaqueResultWrapper executeFlatGraph(PointerPointer extraPointers, Pointer flatBufferPointer); + + long getResultWrapperSize(OpaqueResultWrapper ptr); + Pointer getResultWrapperPointer(OpaqueResultWrapper ptr); String getAllCustomOps(); @@ -1047,13 +1059,25 @@ public interface NativeOps { int execCustomOp(PointerPointer extraPointers, long opHashCode, PointerPointer inputBuffers, PointerPointer inputShapes, int numInput, PointerPointer outputBuffers, PointerPointer outputShapes, int numOutputs, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong *") LongPointer iArgs, int numIArgs, @Cast("bool *") BooleanPointer bArgs, int numBArgs, boolean isInplace); - Pointer calculateOutputShapes(PointerPointer extraPointers, long hash, PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong *") LongPointer iArgs, int numIArgs); + OpaqueShapeList calculateOutputShapes(PointerPointer extraPointers, long hash, PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong *") LongPointer iArgs, int numIArgs); - Pointer calculateOutputShapes2(PointerPointer extraPointers, long hash, PointerPointer inputBunffers, PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong *") LongPointer iArgs, int numIArgs, @Cast("bool *") BooleanPointer bArgs, int numBArgs); + OpaqueShapeList calculateOutputShapes2(PointerPointer extraPointers, long hash, PointerPointer inputBunffers, PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong *") LongPointer iArgs, int numIArgs, @Cast("bool *") BooleanPointer bArgs, int numBArgs); + + long getShapeListSize(OpaqueShapeList list); + LongPointer getShape(OpaqueShapeList list, long i); int registerGraph(PointerPointer extraPointers, long graphId, Pointer flatBufferPointer); - Pointer executeStoredGraph(PointerPointer extraPointers, long graphId, PointerPointer inputBuffers, PointerPointer inputShapes, IntPointer inputIndices, int numInputs); + OpaqueVariableSet executeStoredGraph(PointerPointer extraPointers, long graphId, PointerPointer inputBuffers, PointerPointer inputShapes, IntPointer inputIndices, int numInputs); + + long getVariableSetSize(OpaqueVariableSet set); + int getVariableSetStatus(OpaqueVariableSet set); + OpaqueVariable getVariable(OpaqueVariableSet set, long i); + int getVariableId(OpaqueVariable variable); + int getVariableIndex(OpaqueVariable variable); + String getVariableName(OpaqueVariable variable); + LongPointer getVariableShape(OpaqueVariable variable); + Pointer getVariableBuffer(OpaqueVariable variable); void deleteResultWrapper(Pointer ptr); @@ -1071,15 +1095,11 @@ public interface NativeOps { void deleteNPArrayMap(Pointer pointer); - void deleteVariablesSet(Pointer pointer); + void deleteVariablesSet(OpaqueVariableSet pointer); // GraphState creation Pointer getGraphState(long id); - void deleteShapeBuffer(Pointer state); - - void deleteTadPack(Pointer pointer); - void deleteGraphState(Pointer state); int estimateThreshold(PointerPointer extraPointers, Pointer x, LongPointer xShapeInfo, int N, float threshold); @@ -1096,6 +1116,8 @@ public interface NativeOps { //void fillUtf8String(PointerPointer extraPointers, String[] string, int numStrings, Pointer buffer); Pointer createUtf8String(PointerPointer extraPointers, String string, int length); + long getUtf8StringLength(PointerPointer extraPointers, Pointer ptr); + BytePointer getUtf8StringBuffer(PointerPointer extraPointers, Pointer ptr); void deleteUtf8String(PointerPointer extraPointers, Pointer ptr); @@ -1116,11 +1138,37 @@ public interface NativeOps { */ int dataTypeFromNpyHeader(Pointer numpyHeader); - Pointer shapeBuffer(int rank, @Cast("Nd4jLong *") LongPointer shape, @Cast("Nd4jLong *") LongPointer strides, int dtype, char order, long ews, boolean empty); + OpaqueConstantDataBuffer shapeBuffer(int rank, LongPointer shape, LongPointer strides, int dtype, char order, long ews, boolean empty); - Pointer constantBufferDouble(int dtype, DoublePointer data, int length); + OpaqueConstantDataBuffer constantBufferDouble(int dtype, DoublePointer data, int length); - Pointer constantBufferLong(int dtype, @Cast("Nd4jLong *") LongPointer data, int length); + OpaqueConstantDataBuffer constantBufferLong(int dtype, LongPointer data, int length); + + Pointer getConstantDataBufferPrimary(OpaqueConstantDataBuffer dbf); + Pointer getConstantDataBufferSpecial(OpaqueConstantDataBuffer dbf); + long getConstantDataBufferLength(OpaqueConstantDataBuffer dbf); + long getConstantDataBufferSizeOf(OpaqueConstantDataBuffer dbf); + + void deleteShapeBuffer(OpaqueConstantDataBuffer state); + + OpaqueContext createGraphContext(int nodeId); + OpaqueRandomGenerator getGraphContextRandomGenerator(OpaqueContext ptr); + void markGraphContextInplace(OpaqueContext ptr, boolean reallyInplace); + void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer); + void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo); + void setGraphContextOutputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo); + void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, int numberOfArguments); + void setGraphContextIArguments(OpaqueContext ptr, LongPointer arguments, int numberOfArguments); + void setGraphContextBArguments(OpaqueContext ptr, BooleanPointer arguments, int numberOfArguments); + void deleteGraphContext(OpaqueContext ptr); + + OpaqueRandomGenerator createRandomGenerator(long rootSeed, long nodeSeed); + long getRandomGeneratorRootState(OpaqueRandomGenerator ptr); + long getRandomGeneratorNodeState(OpaqueRandomGenerator ptr); + void setRandomGeneratorStates(OpaqueRandomGenerator ptr, @Cast("Nd4jLong") long rootSeed/*=0*/, @Cast("Nd4jLong") long nodeSeed/*=0*/); + int getRandomGeneratorRelativeInt(OpaqueRandomGenerator ptr, @Cast("Nd4jLong") long index); + long getRandomGeneratorRelativeLong(OpaqueRandomGenerator ptr, @Cast("Nd4jLong") long index); + void deleteRandomGenerator(OpaqueRandomGenerator ptr); String runLightBenchmarkSuit(boolean printOut); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/OpaqueConstantDataBuffer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/OpaqueConstantDataBuffer.java new file mode 100644 index 000000000..8e198c186 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/OpaqueConstantDataBuffer.java @@ -0,0 +1,27 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.nativeblas; + +import org.bytedeco.javacpp.Pointer; + +/** + * + * @author saudet + */ +public class OpaqueConstantDataBuffer extends Pointer { + public OpaqueConstantDataBuffer(Pointer p) { super(p); } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/OpaqueContext.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/OpaqueContext.java new file mode 100644 index 000000000..058649c02 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/OpaqueContext.java @@ -0,0 +1,27 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.nativeblas; + +import org.bytedeco.javacpp.Pointer; + +/** + * + * @author saudet + */ +public class OpaqueContext extends Pointer { + public OpaqueContext(Pointer p) { super(p); } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/OpaqueRandomGenerator.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/OpaqueRandomGenerator.java new file mode 100644 index 000000000..b76015285 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/OpaqueRandomGenerator.java @@ -0,0 +1,27 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.nativeblas; + +import org.bytedeco.javacpp.Pointer; + +/** + * + * @author saudet + */ +public class OpaqueRandomGenerator extends Pointer { + public OpaqueRandomGenerator(Pointer p) { super(p); } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/OpaqueResultWrapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/OpaqueResultWrapper.java new file mode 100644 index 000000000..331bf465c --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/OpaqueResultWrapper.java @@ -0,0 +1,27 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.nativeblas; + +import org.bytedeco.javacpp.Pointer; + +/** + * + * @author saudet + */ +public class OpaqueResultWrapper extends Pointer { + public OpaqueResultWrapper(Pointer p) { super(p); } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/OpaqueShapeList.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/OpaqueShapeList.java new file mode 100644 index 000000000..b290d88cf --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/OpaqueShapeList.java @@ -0,0 +1,27 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.nativeblas; + +import org.bytedeco.javacpp.Pointer; + +/** + * + * @author saudet + */ +public class OpaqueShapeList extends Pointer { + public OpaqueShapeList(Pointer p) { super(p); } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/OpaqueTadPack.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/OpaqueTadPack.java new file mode 100644 index 000000000..a959fd375 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/OpaqueTadPack.java @@ -0,0 +1,27 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.nativeblas; + +import org.bytedeco.javacpp.Pointer; + +/** + * + * @author saudet + */ +public class OpaqueTadPack extends Pointer { + public OpaqueTadPack(Pointer p) { super(p); } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/OpaqueVariable.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/OpaqueVariable.java new file mode 100644 index 000000000..6051e81a2 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/OpaqueVariable.java @@ -0,0 +1,27 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.nativeblas; + +import org.bytedeco.javacpp.Pointer; + +/** + * + * @author saudet + */ +public class OpaqueVariable extends Pointer { + public OpaqueVariable(Pointer p) { super(p); } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/OpaqueVariableSet.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/OpaqueVariableSet.java new file mode 100644 index 000000000..321a049ce --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/OpaqueVariableSet.java @@ -0,0 +1,27 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.nativeblas; + +import org.bytedeco.javacpp.Pointer; + +/** + * + * @author saudet + */ +public class OpaqueVariableSet extends Pointer { + public OpaqueVariableSet(Pointer p) { super(p); } +} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java index 3d096bca2..90e77c7ae 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java @@ -72,6 +72,11 @@ import org.nd4j.nativeblas.LongPointerWrapper; import org.nd4j.nativeblas.NativeOps; import org.nd4j.nativeblas.NativeOpsHolder; import org.nd4j.nativeblas.Nd4jCuda; +import org.nd4j.nativeblas.OpaqueConstantDataBuffer; +import org.nd4j.nativeblas.OpaqueShapeList; +import org.nd4j.nativeblas.OpaqueTadPack; +import org.nd4j.nativeblas.OpaqueVariable; +import org.nd4j.nativeblas.OpaqueVariableSet; import java.util.*; @@ -2208,13 +2213,13 @@ public class CudaExecutioner extends DefaultOpExecutioner { for (val t: op.tArgs()) tArgs.put(cnt++, (float) t); - val ptrptr = (Nd4jCuda.ShapeList) nativeOps.calculateOutputShapes2(null, hash, inputBuffers, inputShapes, op.inputArguments().length, tArgs, op.tArgs().length, iArgs, op.iArgs().length, bArgs, op.numBArguments()); + OpaqueShapeList ptrptr = nativeOps.calculateOutputShapes2(null, hash, inputBuffers, inputShapes, op.inputArguments().length, tArgs, op.tArgs().length, iArgs, op.iArgs().length, bArgs, op.numBArguments()); if (ptrptr == null) throw new RuntimeException(); - for (int e = 0; e < ptrptr.size(); e++ ) - result.add(getShapeFromPointer(new PagedPointer(ptrptr.at(e)).asLongPointer())); + for (int e = 0; e < nativeOps.getShapeListSize(ptrptr); e++ ) + result.add(getShapeFromPointer(new PagedPointer(nativeOps.getShape(ptrptr, e)).asLongPointer())); nativeOps.deleteShapeList(ptrptr); @@ -2251,28 +2256,32 @@ public class CudaExecutioner extends DefaultOpExecutioner { val ctx = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); - val context = (CudaOpContext) buildContext(); - context.markInplace(op.isInplaceCall()); + val name = op.opName(); + try (val context = (CudaOpContext) buildContext()) { + context.markInplace(op.isInplaceCall()); - // transferring rng state - context.setRngStates(Nd4j.getRandom().rootState(), Nd4j.getRandom().nodeState()); + // transferring rng state + context.setRngStates(Nd4j.getRandom().rootState(), Nd4j.getRandom().nodeState()); - //transferring input/output arrays - context.setInputArrays(op.inputArguments()); - context.setOutputArrays(op.outputArguments()); + //transferring input/output arrays + context.setInputArrays(op.inputArguments()); + context.setOutputArrays(op.outputArguments()); - // transferring static args - context.setBArguments(op.bArgs()); - context.setIArguments(op.iArgs()); - context.setTArguments(op.tArgs()); + // transferring static args + context.setBArguments(op.bArgs()); + context.setIArguments(op.iArgs()); + context.setTArguments(op.tArgs()); - val result = exec(op, context); - val states = context.getRngStates(); + val result = exec(op, context); + val states = context.getRngStates(); - // pulling states back - Nd4j.getRandom().setStates(states.getFirst(), states.getSecond()); + // pulling states back + Nd4j.getRandom().setStates(states.getFirst(), states.getSecond()); - return result; + return result; + } catch (Exception e) { + throw new RuntimeException("Op [" + name + "] execution failed", e); + } /* long st = profilingConfigurableHookIn(op); @@ -2418,19 +2427,19 @@ public class CudaExecutioner extends DefaultOpExecutioner { val newMap = new LinkedHashMap(); - val result = (Nd4jCuda.VariablesSet) nativeOps.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size()); + OpaqueVariableSet result = nativeOps.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size()); - val status = OpStatus.byNumber(result.status()); + OpStatus status = OpStatus.byNumber(nativeOps.getVariableSetStatus(result)); if (status != OpStatus.ND4J_STATUS_OK) throw new ND4JIllegalStateException("Op execution failed: " + status); - for (int e = 0; e < result.size(); e++) { - val var = result.at(e); - val nodeId = var.id(); - val index = var.index(); - val shapeInfo = var.getNDArray().shapeInfo(); - val buffer = var.getNDArray().buffer(); + for (int e = 0; e < nativeOps.getVariableSetSize(result); e++) { + OpaqueVariable var = nativeOps.getVariable(result, e); + int nodeId = nativeOps.getVariableId(var); + int index = nativeOps.getVariableIndex(var); + LongPointer shapeInfo = nativeOps.getVariableShape(var); + Pointer buffer = nativeOps.getVariableBuffer(var); val rank = (int) shapeInfo.get(0); val jshape = new long[rank * 2 + 4]; @@ -2446,7 +2455,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { Pointer.memcpy(AtomicAllocator.getInstance().getHostPointer(array), buffer, ArrayUtil.prod(shapeOf) * Nd4j.sizeOfDataType()); AtomicAllocator.getInstance().getAllocationPoint(array).tickHostWrite(); - val nodeName = var.getName().getString(); + String nodeName = nativeOps.getVariableName(var); newMap.put(nodeName, array); } @@ -2584,9 +2593,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { @Override public DataBuffer createShapeInfo(long[] shape, long[] stride, long elementWiseStride, char order, DataType dtype, boolean empty) { - val dbf = (Nd4jCuda.ConstantDataBuffer) nativeOps.shapeBuffer(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, empty); + OpaqueConstantDataBuffer dbf = nativeOps.shapeBuffer(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, empty); - val result = new CudaLongDataBuffer(dbf.primary(), dbf.special(), Shape.shapeInfoLength(shape.length)); + val result = new CudaLongDataBuffer(nativeOps.getConstantDataBufferPrimary(dbf), nativeOps.getConstantDataBufferSpecial(dbf), Shape.shapeInfoLength(shape.length)); nativeOps.deleteShapeBuffer(dbf); @@ -2595,10 +2604,10 @@ public class CudaExecutioner extends DefaultOpExecutioner { @Override public TadPack tadShapeInfoAndOffsets(INDArray array, int[] dimension) { - val pack = (Nd4jCuda.TadPack) nativeOps.tadOnlyShapeInfo((LongPointer) array.shapeInfoDataBuffer().addressPointer(), new IntPointer(dimension), dimension.length); + OpaqueTadPack pack = nativeOps.tadOnlyShapeInfo((LongPointer) array.shapeInfoDataBuffer().addressPointer(), new IntPointer(dimension), dimension.length); - val tadShape = new CudaLongDataBuffer(pack.primaryShapeInfo(), pack.specialShapeInfo(), pack.shapeInfoLength()); - val tadOffsets = new CudaLongDataBuffer(pack.primaryOffsets(), pack.specialOffsets(), pack.numberOfTads()); + val tadShape = new CudaLongDataBuffer(nativeOps.getPrimaryShapeInfo(pack), nativeOps.getSpecialShapeInfo(pack), nativeOps.getShapeInfoLength(pack)); + val tadOffsets = new CudaLongDataBuffer(nativeOps.getPrimaryOffsets(pack), nativeOps.getSpecialOffsets(pack), nativeOps.getNumberOfTads(pack)); nativeOps.deleteTadPack(pack); @@ -2607,9 +2616,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { @Override public DataBuffer createConstantBuffer(long[] values, DataType desiredType) { - val dbf = (Nd4jCuda.ConstantDataBuffer) nativeOps.constantBufferLong(desiredType.toInt(), new LongPointer(values), values.length); + OpaqueConstantDataBuffer dbf = nativeOps.constantBufferLong(desiredType.toInt(), new LongPointer(values), values.length); - val buffer = Nd4j.createBuffer(dbf.primary(), dbf.special(), values.length, desiredType); + val buffer = Nd4j.createBuffer(nativeOps.getConstantDataBufferPrimary(dbf), nativeOps.getConstantDataBufferSpecial(dbf), values.length, desiredType); buffer.setConstant(true); return buffer; @@ -2617,9 +2626,9 @@ public class CudaExecutioner extends DefaultOpExecutioner { @Override public DataBuffer createConstantBuffer(double[] values, DataType desiredType) { - val dbf = (Nd4jCuda.ConstantDataBuffer) nativeOps.constantBufferDouble(desiredType.toInt(), new DoublePointer(values), values.length); + OpaqueConstantDataBuffer dbf = nativeOps.constantBufferDouble(desiredType.toInt(), new DoublePointer(values), values.length); - val buffer = Nd4j.createBuffer(dbf.primary(), dbf.special(), values.length, desiredType); + val buffer = Nd4j.createBuffer(nativeOps.getConstantDataBufferPrimary(dbf), nativeOps.getConstantDataBufferSpecial(dbf), values.length, desiredType); buffer.setConstant(true); return buffer; diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java index 476105ee6..8db04257b 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java @@ -18,6 +18,9 @@ package org.nd4j.linalg.jcublas.ops.executioner; import lombok.NonNull; import lombok.val; +import org.bytedeco.javacpp.BooleanPointer; +import org.bytedeco.javacpp.DoublePointer; +import org.bytedeco.javacpp.LongPointer; import org.bytedeco.javacpp.Pointer; import org.nd4j.jita.allocator.impl.AtomicAllocator; import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t; @@ -28,7 +31,10 @@ import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.jcublas.context.CudaContext; import org.nd4j.linalg.primitives.Pair; -import org.nd4j.nativeblas.Nd4jCuda; +import org.nd4j.nativeblas.NativeOps; +import org.nd4j.nativeblas.NativeOpsHolder; +import org.nd4j.nativeblas.OpaqueContext; +import org.nd4j.nativeblas.OpaqueRandomGenerator; /** * CUDA wrapper for op Context @@ -36,34 +42,41 @@ import org.nd4j.nativeblas.Nd4jCuda; */ public class CudaOpContext extends BaseOpContext implements OpContext { // we might want to have configurable - private Nd4jCuda.Context context = new Nd4jCuda.Context(1); + private NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); + private OpaqueContext context = nativeOps.createGraphContext(1); + + @Override + public void close() { + nativeOps.deleteGraphContext(context); + } @Override public void setIArguments(long... arguments) { super.setIArguments(arguments); - context.setIArguments(arguments, arguments.length); + nativeOps.setGraphContextIArguments(context, new LongPointer(arguments), arguments.length); } @Override public void setBArguments(boolean... arguments) { super.setBArguments(arguments); - context.setBArguments(arguments, arguments.length); + nativeOps.setGraphContextBArguments(context, new BooleanPointer(arguments), arguments.length); } @Override public void setTArguments(double... arguments) { super.setTArguments(arguments); - context.setTArguments(arguments, arguments.length); + nativeOps.setGraphContextTArguments(context, new DoublePointer(arguments), arguments.length); } @Override public void setRngStates(long rootState, long nodeState) { - context.randomGenerator().setStates(rootState, nodeState); + nativeOps.setRandomGeneratorStates(nativeOps.getGraphContextRandomGenerator(context), rootState, nodeState); } @Override public Pair getRngStates() { - return Pair.makePair(context.randomGenerator().rootState(), context.randomGenerator().nodeState()); + OpaqueRandomGenerator g = nativeOps.getGraphContextRandomGenerator(context); + return Pair.makePair(nativeOps.getRandomGeneratorRootState(g), nativeOps.getRandomGeneratorNodeState(g)); } @Override @@ -72,7 +85,7 @@ public class CudaOpContext extends BaseOpContext implements OpContext { Nd4j.getAffinityManager().ensureLocation(array, AffinityManager.Location.EVERYWHERE); val ctx = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); - context.setInputArray(index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), array.isEmpty() ? null : AtomicAllocator.getInstance().getPointer(array, ctx), AtomicAllocator.getInstance().getPointer(array.shapeInfoDataBuffer())); + nativeOps.setGraphContextInputArray(context, index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), array.isEmpty() ? null : AtomicAllocator.getInstance().getPointer(array, ctx), AtomicAllocator.getInstance().getPointer(array.shapeInfoDataBuffer())); super.setInputArray(index, array); } @@ -82,7 +95,7 @@ public class CudaOpContext extends BaseOpContext implements OpContext { Nd4j.getAffinityManager().ensureLocation(array, AffinityManager.Location.EVERYWHERE); val ctx = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); - context.setOutputArray(index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), array.isEmpty() ? null : AtomicAllocator.getInstance().getPointer(array, ctx), AtomicAllocator.getInstance().getPointer(array.shapeInfoDataBuffer())); + nativeOps.setGraphContextOutputArray(context, index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), array.isEmpty() ? null : AtomicAllocator.getInstance().getPointer(array, ctx), AtomicAllocator.getInstance().getPointer(array.shapeInfoDataBuffer())); super.setOutputArray(index, array); } @@ -113,11 +126,11 @@ public class CudaOpContext extends BaseOpContext implements OpContext { public void setCudaStream(cudaStream_t stream, Pointer reductionPointer, Pointer allocationPointer) { - context.setCudaContext(stream, reductionPointer, allocationPointer); + nativeOps.setGraphContextCudaContext(context, stream, reductionPointer, allocationPointer); } @Override public void markInplace(boolean reallyInplace) { - context.markInplace(reallyInplace); + nativeOps.markGraphContextInplace(context, reallyInplace); } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/rng/CudaNativeRandom.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/rng/CudaNativeRandom.java index 384da19b9..6e2d8ebf0 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/rng/CudaNativeRandom.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/rng/CudaNativeRandom.java @@ -21,7 +21,9 @@ import org.bytedeco.javacpp.PointerPointer; import org.nd4j.jita.allocator.impl.AtomicAllocator; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.jcublas.context.CudaContext; -import org.nd4j.nativeblas.Nd4jCuda; +import org.nd4j.nativeblas.NativeOps; +import org.nd4j.nativeblas.NativeOpsHolder; +import org.nd4j.nativeblas.OpaqueRandomGenerator; import org.nd4j.rng.NativeRandom; import java.util.List; @@ -33,7 +35,7 @@ import java.util.List; */ @Slf4j public class CudaNativeRandom extends NativeRandom { - + private NativeOps nativeOps; protected List stateBuffers; public CudaNativeRandom() { @@ -50,10 +52,16 @@ public class CudaNativeRandom extends NativeRandom { @Override public void init() { - statePointer = new Nd4jCuda.RandomGenerator(seed, seed ^ 0xdeadbeef); + nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); + statePointer = nativeOps.createRandomGenerator(this.seed, this.seed ^ 0xdeadbeef); setSeed(seed); } + @Override + public void close() { + nativeOps.deleteRandomGenerator((OpaqueRandomGenerator)statePointer); + } + @Override public PointerPointer getExtraPointers() { return null; @@ -63,7 +71,7 @@ public class CudaNativeRandom extends NativeRandom { public void setSeed(long seed) { this.seed = seed; this.currentPosition.set(0); - ((Nd4jCuda.RandomGenerator) statePointer).setStates(seed, seed ^ 0xdeadbeef); + nativeOps.setRandomGeneratorStates((OpaqueRandomGenerator)statePointer, seed, seed ^ 0xdeadbeef); } @Override @@ -73,24 +81,24 @@ public class CudaNativeRandom extends NativeRandom { @Override public int nextInt() { - return ((Nd4jCuda.RandomGenerator) statePointer).relativeInt(currentPosition.getAndIncrement()); + return nativeOps.getRandomGeneratorRelativeInt((OpaqueRandomGenerator)statePointer, currentPosition.getAndIncrement()); } @Override public long nextLong() { - return ((Nd4jCuda.RandomGenerator) statePointer).relativeLong(currentPosition.getAndIncrement()); + return nativeOps.getRandomGeneratorRelativeLong((OpaqueRandomGenerator)statePointer, currentPosition.getAndIncrement()); } public long rootState() { - return ((Nd4jCuda.RandomGenerator) statePointer).rootState(); + return nativeOps.getRandomGeneratorRootState((OpaqueRandomGenerator)statePointer); } public long nodeState() { - return ((Nd4jCuda.RandomGenerator) statePointer).nodeState(); + return nativeOps.getRandomGeneratorNodeState((OpaqueRandomGenerator)statePointer); } @Override public void setStates(long rootState, long nodeState) { - ((Nd4jCuda.RandomGenerator) statePointer).setStates(rootState, nodeState); + nativeOps.setRandomGeneratorStates((OpaqueRandomGenerator)statePointer, rootState, nodeState); } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java index 934ba6537..be024e5f4 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java @@ -2097,16 +2097,25 @@ public native void setGridLimit(int gridSize); * @param targetBuffer * @param offsetsBuffer */ -public native TadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") LongPointer xShapeInfo, +public native OpaqueTadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") LongPointer xShapeInfo, IntPointer dimension, int dimensionLength); -public native TadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") LongBuffer xShapeInfo, +public native OpaqueTadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") LongBuffer xShapeInfo, IntBuffer dimension, int dimensionLength); -public native TadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") long[] xShapeInfo, +public native OpaqueTadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") long[] xShapeInfo, int[] dimension, int dimensionLength); +public native @Cast("Nd4jLong*") LongPointer getPrimaryShapeInfo(OpaqueTadPack pack); +public native @Cast("Nd4jLong*") LongPointer getPrimaryOffsets(OpaqueTadPack pack); +public native @Cast("Nd4jLong*") LongPointer getSpecialShapeInfo(OpaqueTadPack pack); +public native @Cast("Nd4jLong*") LongPointer getSpecialOffsets(OpaqueTadPack pack); +public native @Cast("Nd4jLong") long getNumberOfTads(OpaqueTadPack pack); +public native int getShapeInfoLength(OpaqueTadPack pack); + +public native void deleteTadPack(OpaqueTadPack ptr); + /* * PullRow special op */ @@ -2943,10 +2952,11 @@ public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongBuffer ptrMap, @Cast("Nd4jLong") long length); public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") long[] ptrMap, @Cast("Nd4jLong") long length); - // flatbuffers execution -public native ResultWrapper executeFlatGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer flatBufferPointer); +public native OpaqueResultWrapper executeFlatGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer flatBufferPointer); +public native @Cast("Nd4jLong") long getResultWrapperSize(OpaqueResultWrapper ptr); +public native @Cast("Nd4jPointer") Pointer getResultWrapperPointer(OpaqueResultWrapper ptr); public native @Cast("char*") String getAllCustomOps(); @@ -2961,23 +2971,35 @@ public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointer public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, @Cast("bool") boolean isInplace); public native int execCustomOp2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer") Pointer opContext); -public native ShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs); -public native ShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs); -public native ShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs); -public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs); -public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs); -public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs); -public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs); -public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs); -public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs); +public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs); +public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs); +public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs); +public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs); +public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs); +public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs); +public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs); +public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs); +public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs); + +public native @Cast("Nd4jLong") long getShapeListSize(OpaqueShapeList list); +public native @Cast("Nd4jLong*") LongPointer getShape(OpaqueShapeList list, @Cast("Nd4jLong") long i); public native void deleteShapeList(@Cast("Nd4jPointer") Pointer shapeList); public native int registerGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer") Pointer flatBufferPointer); -public native VariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntPointer inputIndices, int numInputs); -public native VariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntBuffer inputIndices, int numInputs); -public native VariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int[] inputIndices, int numInputs); +public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntPointer inputIndices, int numInputs); +public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntBuffer inputIndices, int numInputs); +public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int[] inputIndices, int numInputs); + +public native @Cast("Nd4jLong") long getVariableSetSize(OpaqueVariableSet set); +public native @Cast("Nd4jStatus") int getVariableSetStatus(OpaqueVariableSet set); +public native OpaqueVariable getVariable(OpaqueVariableSet set, @Cast("Nd4jLong") long i); +public native int getVariableId(OpaqueVariable variable); +public native int getVariableIndex(OpaqueVariable variable); +public native @Cast("char*") String getVariableName(OpaqueVariable variable); +public native @Cast("Nd4jLong*") LongPointer getVariableShape(OpaqueVariable variable); +public native Pointer getVariableBuffer(OpaqueVariable variable); public native int unregisterGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId); @@ -2986,7 +3008,7 @@ public native void deleteIntArray(@Cast("Nd4jPointer") Pointer pointer); public native void deleteLongArray(@Cast("Nd4jPointer") Pointer pointer); public native void deletePointerArray(@Cast("Nd4jPointer") Pointer pointer); -public native void deleteVariablesSet(@Cast("Nd4jPointer") Pointer pointer); +public native void deleteVariablesSet(@ByVal OpaqueVariableSet pointer); // GraphState creation public native @Cast("Nd4jPointer") Pointer getGraphState(@Cast("Nd4jLong") long id); @@ -3007,6 +3029,8 @@ public native @Cast("Nd4jStatus") int execCustomOpWithScope(@Cast("Nd4jPointer*" //void fillUtf8String(Nd4jPointer *extraPointers, const char **string, int numStrings, Nd4jPointer buffer); public native @Cast("Nd4jPointer") Pointer createUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") String string, int length); public native @Cast("Nd4jPointer") Pointer createUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") BytePointer string, int length); +public native @Cast("Nd4jLong") long getUtf8StringLength(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer ptr); +public native @Cast("char*") BytePointer getUtf8StringBuffer(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer ptr); public native void deleteUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer ptr); public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opCode, int numOfSubArrs, @@ -3032,19 +3056,50 @@ public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointe public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer buffer, @Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jPointer") Pointer specialBuffer, @Cast("Nd4jLong*") LongBuffer specialShapeInfo, @Cast("Nd4jPointer") Pointer debugInfo); public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer buffer, @Cast("Nd4jLong*") long[] shapeInfo, @Cast("Nd4jPointer") Pointer specialBuffer, @Cast("Nd4jLong*") long[] specialShapeInfo, @Cast("Nd4jPointer") Pointer debugInfo); +public native OpaqueConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty); +public native OpaqueConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty); +public native OpaqueConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty); -public native ConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty); -public native ConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty); -public native ConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty); +public native OpaqueConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") LongPointer data, int length); +public native OpaqueConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") LongBuffer data, int length); +public native OpaqueConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") long[] data, int length); +public native OpaqueConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, DoublePointer data, int length); +public native OpaqueConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, DoubleBuffer data, int length); +public native OpaqueConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, double[] data, int length); +public native OpaqueConstantDataBuffer constantBuffer(@Cast("nd4j::DataType") int dtype, ConstantDescriptor descriptor); -public native ConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") LongPointer data, int length); -public native ConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") LongBuffer data, int length); -public native ConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") long[] data, int length); -public native ConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, DoublePointer data, int length); -public native ConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, DoubleBuffer data, int length); -public native ConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, double[] data, int length); -public native ConstantDataBuffer constantBuffer(@Cast("nd4j::DataType") int dtype, ConstantDescriptor descriptor); +public native @Cast("Nd4jPointer") Pointer getConstantDataBufferPrimary(OpaqueConstantDataBuffer dbf); +public native @Cast("Nd4jPointer") Pointer getConstantDataBufferSpecial(OpaqueConstantDataBuffer dbf); +public native @Cast("Nd4jLong") long getConstantDataBufferLength(OpaqueConstantDataBuffer dbf); +public native @Cast("Nd4jLong") long getConstantDataBufferSizeOf(OpaqueConstantDataBuffer dbf); +public native void deleteShapeBuffer(OpaqueConstantDataBuffer ptr); + +public native OpaqueContext createGraphContext(int nodeId); +public native OpaqueRandomGenerator getGraphContextRandomGenerator(OpaqueContext ptr); +public native void markGraphContextInplace(OpaqueContext ptr, @Cast("bool") boolean reallyInplace); +public native void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer); +public native void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo); +public native void setGraphContextOutputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo); +public native void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, int numberOfArguments); +public native void setGraphContextTArguments(OpaqueContext ptr, DoubleBuffer arguments, int numberOfArguments); +public native void setGraphContextTArguments(OpaqueContext ptr, double[] arguments, int numberOfArguments); +public native void setGraphContextIArguments(OpaqueContext ptr, @Cast("Nd4jLong*") LongPointer arguments, int numberOfArguments); +public native void setGraphContextIArguments(OpaqueContext ptr, @Cast("Nd4jLong*") LongBuffer arguments, int numberOfArguments); +public native void setGraphContextIArguments(OpaqueContext ptr, @Cast("Nd4jLong*") long[] arguments, int numberOfArguments); +public native void setGraphContextBArguments(OpaqueContext ptr, @Cast("bool*") BooleanPointer arguments, int numberOfArguments); +public native void setGraphContextBArguments(OpaqueContext ptr, @Cast("bool*") boolean[] arguments, int numberOfArguments); +public native void deleteGraphContext(OpaqueContext ptr); + +public native OpaqueRandomGenerator createRandomGenerator(@Cast("Nd4jLong") long rootSeed/*=0*/, @Cast("Nd4jLong") long nodeSeed/*=0*/); +public native OpaqueRandomGenerator createRandomGenerator(); +public native @Cast("Nd4jLong") long getRandomGeneratorRootState(OpaqueRandomGenerator ptr); +public native @Cast("Nd4jLong") long getRandomGeneratorNodeState(OpaqueRandomGenerator ptr); +public native void setRandomGeneratorStates(OpaqueRandomGenerator ptr, @Cast("Nd4jLong") long rootSeed/*=0*/, @Cast("Nd4jLong") long nodeSeed/*=0*/); +public native void setRandomGeneratorStates(OpaqueRandomGenerator ptr); +public native int getRandomGeneratorRelativeInt(OpaqueRandomGenerator ptr, @Cast("Nd4jLong") long index); +public native @Cast("Nd4jLong") long getRandomGeneratorRelativeLong(OpaqueRandomGenerator ptr, @Cast("Nd4jLong") long index); +public native void deleteRandomGenerator(OpaqueRandomGenerator ptr); public native @Cast("char*") String runLightBenchmarkSuit(@Cast("bool") boolean printOut); public native @Cast("char*") String runFullBenchmarkSuit(@Cast("bool") boolean printOut); @@ -3705,6 +3760,20 @@ public native @Cast("char*") String runFullBenchmarkSuit(@Cast("bool") boolean p public NDArray(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("nd4j::DataType") int dtype) { super((Pointer)null); allocate(buffer, order, shape, dtype); } private native void allocate(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("nd4j::DataType") int dtype); + + /** + * This method returns new array with the same shape & data type + * @return + */ + public native @ByVal NDArray like(); + + /** + * This method returns new uninitialized array with the same shape & data type + * @return + */ + public native @ByVal NDArray ulike(); + + /** * this constructor creates new NDArray with shape matching "other" array, * doesn't copy "other" elements into new array !!! diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCudaPresets.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCudaPresets.java index 7df0abd1c..78093cf04 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCudaPresets.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCudaPresets.java @@ -113,6 +113,14 @@ public class Nd4jCudaPresets implements InfoMapper { infoMap.put(new Info("thread_local", "ND4J_EXPORT", "INLINEDEF", "CUBLASWINAPI", "FORCEINLINE", "_CUDA_H", "_CUDA_D", "_CUDA_G", "_CUDA_HD", "LIBND4J_ALL_OPS", "NOT_EXCLUDED").cppTypes().annotations()) .put(new Info("NativeOps.h").objectify()) + .put(new Info("OpaqueTadPack").pointerTypes("OpaqueTadPack")) + .put(new Info("OpaqueResultWrapper").pointerTypes("OpaqueResultWrapper")) + .put(new Info("OpaqueShapeList").pointerTypes("OpaqueShapeList")) + .put(new Info("OpaqueVariableSet").pointerTypes("OpaqueVariableSet")) + .put(new Info("OpaqueVariable").pointerTypes("OpaqueVariable")) + .put(new Info("OpaqueConstantDataBuffer").pointerTypes("OpaqueConstantDataBuffer")) + .put(new Info("OpaqueContext").pointerTypes("OpaqueContext")) + .put(new Info("OpaqueRandomGenerator").pointerTypes("OpaqueRandomGenerator")) .put(new Info("const char").valueTypes("byte").pointerTypes("@Cast(\"char*\") String", "@Cast(\"char*\") BytePointer")) .put(new Info("char").valueTypes("char").pointerTypes("@Cast(\"char*\") BytePointer", diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java index 5f6fd2e5f..0a7987492 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java @@ -48,7 +48,6 @@ import org.nd4j.linalg.util.ArrayUtil; import org.nd4j.nativeblas.BaseNativeNDArrayFactory; import org.nd4j.nativeblas.LongPointerWrapper; import org.nd4j.nativeblas.NativeOpsHolder; -import org.nd4j.nativeblas.Nd4jCpu; import java.util.*; diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java index fb9845b21..8db359d01 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java @@ -17,12 +17,18 @@ package org.nd4j.linalg.cpu.nativecpu.ops; import lombok.NonNull; +import org.bytedeco.javacpp.BooleanPointer; +import org.bytedeco.javacpp.DoublePointer; +import org.bytedeco.javacpp.LongPointer; import org.bytedeco.javacpp.Pointer; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseOpContext; import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.primitives.Pair; -import org.nd4j.nativeblas.Nd4jCpu; +import org.nd4j.nativeblas.NativeOps; +import org.nd4j.nativeblas.NativeOpsHolder; +import org.nd4j.nativeblas.OpaqueContext; +import org.nd4j.nativeblas.OpaqueRandomGenerator; import java.util.List; @@ -33,46 +39,53 @@ import java.util.List; */ public class CpuOpContext extends BaseOpContext implements OpContext { // we might want to have configurable - private Nd4jCpu.Context context = new Nd4jCpu.Context(1); + private NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); + private OpaqueContext context = nativeOps.createGraphContext(1); + + @Override + public void close() { + nativeOps.deleteGraphContext(context); + } @Override public void setIArguments(long... arguments) { super.setIArguments(arguments); - context.setIArguments(arguments, arguments.length); + nativeOps.setGraphContextIArguments(context, new LongPointer(arguments), arguments.length); } @Override public void setBArguments(boolean... arguments) { super.setBArguments(arguments); - context.setBArguments(arguments, arguments.length); + nativeOps.setGraphContextBArguments(context, new BooleanPointer(arguments), arguments.length); } @Override public void setTArguments(double... arguments) { super.setTArguments(arguments); - context.setTArguments(arguments, arguments.length); + nativeOps.setGraphContextTArguments(context, new DoublePointer(arguments), arguments.length); } @Override public void setRngStates(long rootState, long nodeState) { - context.randomGenerator().setStates(rootState, nodeState); + nativeOps.setRandomGeneratorStates(nativeOps.getGraphContextRandomGenerator(context), rootState, nodeState); } @Override public Pair getRngStates() { - return Pair.makePair(context.randomGenerator().rootState(), context.randomGenerator().nodeState()); + OpaqueRandomGenerator g = nativeOps.getGraphContextRandomGenerator(context); + return Pair.makePair(nativeOps.getRandomGeneratorRootState(g), nativeOps.getRandomGeneratorNodeState(g)); } @Override public void setInputArray(int index, @NonNull INDArray array) { - context.setInputArray(index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), null, null); + nativeOps.setGraphContextInputArray(context, index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), null, null); super.setInputArray(index, array); } @Override public void setOutputArray(int index, @NonNull INDArray array) { - context.setOutputArray(index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), null, null); + nativeOps.setGraphContextOutputArray(context, index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), null, null); super.setOutputArray(index, array); } @@ -84,6 +97,6 @@ public class CpuOpContext extends BaseOpContext implements OpContext { @Override public void markInplace(boolean reallyInplace) { - context.markInplace(reallyInplace); + nativeOps.markGraphContextInplace(context, reallyInplace); } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java index e84f82186..49eb05208 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java @@ -70,10 +70,14 @@ import org.nd4j.nativeblas.LongPointerWrapper; import org.nd4j.nativeblas.NativeOps; import org.nd4j.nativeblas.NativeOpsHolder; import org.nd4j.nativeblas.Nd4jCpu; +import org.nd4j.nativeblas.OpaqueConstantDataBuffer; +import org.nd4j.nativeblas.OpaqueShapeList; +import org.nd4j.nativeblas.OpaqueTadPack; +import org.nd4j.nativeblas.OpaqueVariable; +import org.nd4j.nativeblas.OpaqueVariableSet; import java.util.*; - /** * * Native operation @@ -1641,23 +1645,22 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { } val name = op.opName(); - val context = buildContext(); + try (val context = buildContext()) { - context.markInplace(op.isInplaceCall()); + context.markInplace(op.isInplaceCall()); - // transferring rng state - context.setRngStates(Nd4j.getRandom().rootState(), Nd4j.getRandom().nodeState()); + // transferring rng state + context.setRngStates(Nd4j.getRandom().rootState(), Nd4j.getRandom().nodeState()); - //transferring input/output arrays - context.setInputArrays(op.inputArguments()); - context.setOutputArrays(op.outputArguments()); + //transferring input/output arrays + context.setInputArrays(op.inputArguments()); + context.setOutputArrays(op.outputArguments()); - // transferring static args - context.setBArguments(op.bArgs()); - context.setIArguments(op.iArgs()); - context.setTArguments(op.tArgs()); + // transferring static args + context.setBArguments(op.bArgs()); + context.setIArguments(op.iArgs()); + context.setTArguments(op.tArgs()); - try { val result = exec(op, context); val states = context.getRngStates(); @@ -1860,9 +1863,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { for (val t: tArgs1) tArgs.put(cnt++, t); - Nd4jCpu.ShapeList ptrptr; + OpaqueShapeList ptrptr; try { - ptrptr = (Nd4jCpu.ShapeList) loop.calculateOutputShapes2(null, + ptrptr = loop.calculateOutputShapes2(null, hash, inputBuffers, inputShapes, op.numInputArguments(), tArgs, op.numTArguments(), iArgs, op.numIArguments(), bArgs, op.numBArguments()); } catch (Throwable t){ @@ -1891,8 +1894,8 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { if (ptrptr == null) throw new RuntimeException(); - for (int e = 0; e < ptrptr.size(); e++ ) - result.add(getShapeFromPointer(new PagedPointer(ptrptr.at(e)).asLongPointer())); + for (int e = 0; e < loop.getShapeListSize(ptrptr); e++ ) + result.add(getShapeFromPointer(new PagedPointer(loop.getShape(ptrptr, e)).asLongPointer())); loop.deleteShapeList(ptrptr); @@ -1947,19 +1950,19 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { val newMap = new LinkedHashMap(); - val result = (Nd4jCpu.VariablesSet) loop.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size()); + OpaqueVariableSet result = loop.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size()); - val status = OpStatus.byNumber(result.status()); + OpStatus status = OpStatus.byNumber(loop.getVariableSetStatus(result)); if (status != OpStatus.ND4J_STATUS_OK) throw new ND4JIllegalStateException("Op execution failed: " + status); - for (int e = 0; e < result.size(); e++) { - val var = result.at(e); - val nodeId = var.id(); - val index = var.index(); - val shapeInfo = var.getNDArray().shapeInfo(); - val buffer = var.getNDArray().buffer(); + for (int e = 0; e < loop.getVariableSetSize(result); e++) { + OpaqueVariable var = loop.getVariable(result, e); + int nodeId = loop.getVariableId(var); + int index = loop.getVariableIndex(var); + LongPointer shapeInfo = loop.getVariableShape(var); + Pointer buffer = loop.getVariableBuffer(var); val rank = (int) shapeInfo.get(0); val jshape = new long[rank * 2 + 4]; @@ -1979,7 +1982,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, Shape.lengthOf(shapeOf) * Nd4j.sizeOfDataType(array.dataType()), MemcpyDirection.HOST_TO_HOST); //newMap.put(keySet.get(nodeId), array); - val nodeName = var.getName().getString(); + String nodeName = loop.getVariableName(var); newMap.put(nodeName, array); } @@ -2160,9 +2163,9 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { @Override public DataBuffer createShapeInfo(long[] shape, long[] stride, long elementWiseStride, char order, DataType dtype, boolean empty) { - val dbf = (Nd4jCpu.ConstantDataBuffer) loop.shapeBuffer(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, empty); + OpaqueConstantDataBuffer dbf = loop.shapeBuffer(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, empty); - val result = new LongBuffer(dbf.primary(), Shape.shapeInfoLength(shape.length)); + val result = new LongBuffer(loop.getConstantDataBufferPrimary(dbf), Shape.shapeInfoLength(shape.length)); loop.deleteShapeBuffer(dbf); @@ -2171,10 +2174,10 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { @Override public TadPack tadShapeInfoAndOffsets(INDArray array, int[] dimension) { - val pack = (Nd4jCpu.TadPack) loop.tadOnlyShapeInfo((LongPointer) array.shapeInfoDataBuffer().addressPointer(), new IntPointer(dimension), dimension.length); + OpaqueTadPack pack = loop.tadOnlyShapeInfo((LongPointer) array.shapeInfoDataBuffer().addressPointer(), new IntPointer(dimension), dimension.length); - val tadShape = new LongBuffer(pack.primaryShapeInfo(), pack.shapeInfoLength()); - val tadOffsets = new LongBuffer(pack.primaryOffsets(), pack.numberOfTads()); + val tadShape = new LongBuffer(loop.getPrimaryShapeInfo(pack), loop.getShapeInfoLength(pack)); + val tadOffsets = new LongBuffer(loop.getPrimaryOffsets(pack), loop.getNumberOfTads(pack)); loop.deleteTadPack(pack); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/rng/CpuNativeRandom.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/rng/CpuNativeRandom.java index 30c82cf87..360bc9c2e 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/rng/CpuNativeRandom.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/rng/CpuNativeRandom.java @@ -18,7 +18,9 @@ package org.nd4j.linalg.cpu.nativecpu.rng; import org.bytedeco.javacpp.PointerPointer; import org.nd4j.linalg.api.buffer.DataBuffer; -import org.nd4j.nativeblas.Nd4jCpu; +import org.nd4j.nativeblas.NativeOps; +import org.nd4j.nativeblas.NativeOpsHolder; +import org.nd4j.nativeblas.OpaqueRandomGenerator; import org.nd4j.rng.NativeRandom; import java.util.concurrent.atomic.AtomicLong; @@ -29,6 +31,8 @@ import java.util.concurrent.atomic.AtomicLong; * @author raver119@gmail.com */ public class CpuNativeRandom extends NativeRandom { + private NativeOps nativeOps; + public CpuNativeRandom() { super(); } @@ -43,7 +47,13 @@ public class CpuNativeRandom extends NativeRandom { @Override public void init() { - statePointer = new Nd4jCpu.RandomGenerator(this.seed, this.seed ^ 0xdeadbeef); + nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); + statePointer = nativeOps.createRandomGenerator(this.seed, this.seed ^ 0xdeadbeef); + } + + @Override + public void close() { + nativeOps.deleteRandomGenerator((OpaqueRandomGenerator)statePointer); } @Override @@ -55,7 +65,7 @@ public class CpuNativeRandom extends NativeRandom { public void setSeed(long seed) { this.seed = seed; this.currentPosition.set(0); - ((Nd4jCpu.RandomGenerator)statePointer).setStates(seed, seed ^ 0xdeadbeef); + nativeOps.setRandomGeneratorStates((OpaqueRandomGenerator)statePointer, seed, seed ^ 0xdeadbeef); } @Override @@ -65,24 +75,24 @@ public class CpuNativeRandom extends NativeRandom { @Override public int nextInt() { - return ((Nd4jCpu.RandomGenerator)statePointer).relativeInt(currentPosition.getAndIncrement()); + return nativeOps.getRandomGeneratorRelativeInt((OpaqueRandomGenerator)statePointer, currentPosition.getAndIncrement()); } @Override public long nextLong() { - return ((Nd4jCpu.RandomGenerator)statePointer).relativeLong(currentPosition.getAndIncrement()); + return nativeOps.getRandomGeneratorRelativeLong((OpaqueRandomGenerator)statePointer, currentPosition.getAndIncrement()); } public long rootState() { - return ((Nd4jCpu.RandomGenerator) statePointer).rootState(); + return nativeOps.getRandomGeneratorRootState((OpaqueRandomGenerator)statePointer); } public long nodeState() { - return ((Nd4jCpu.RandomGenerator) statePointer).nodeState(); + return nativeOps.getRandomGeneratorNodeState((OpaqueRandomGenerator)statePointer); } @Override public void setStates(long rootState, long nodeState) { - ((Nd4jCpu.RandomGenerator) statePointer).setStates(rootState, nodeState); + nativeOps.setRandomGeneratorStates((OpaqueRandomGenerator)statePointer, rootState, nodeState); } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index d93d0f759..0af5910e8 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -2097,16 +2097,25 @@ public native void setGridLimit(int gridSize); * @param targetBuffer * @param offsetsBuffer */ -public native TadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") LongPointer xShapeInfo, +public native OpaqueTadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") LongPointer xShapeInfo, IntPointer dimension, int dimensionLength); -public native TadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") LongBuffer xShapeInfo, +public native OpaqueTadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") LongBuffer xShapeInfo, IntBuffer dimension, int dimensionLength); -public native TadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") long[] xShapeInfo, +public native OpaqueTadPack tadOnlyShapeInfo(@Cast("Nd4jLong*") long[] xShapeInfo, int[] dimension, int dimensionLength); +public native @Cast("Nd4jLong*") LongPointer getPrimaryShapeInfo(OpaqueTadPack pack); +public native @Cast("Nd4jLong*") LongPointer getPrimaryOffsets(OpaqueTadPack pack); +public native @Cast("Nd4jLong*") LongPointer getSpecialShapeInfo(OpaqueTadPack pack); +public native @Cast("Nd4jLong*") LongPointer getSpecialOffsets(OpaqueTadPack pack); +public native @Cast("Nd4jLong") long getNumberOfTads(OpaqueTadPack pack); +public native int getShapeInfoLength(OpaqueTadPack pack); + +public native void deleteTadPack(OpaqueTadPack ptr); + /* * PullRow special op */ @@ -2943,10 +2952,11 @@ public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") LongBuffer ptrMap, @Cast("Nd4jLong") long length); public native void munmapFile(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong*") long[] ptrMap, @Cast("Nd4jLong") long length); - // flatbuffers execution -public native ResultWrapper executeFlatGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer flatBufferPointer); +public native OpaqueResultWrapper executeFlatGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer flatBufferPointer); +public native @Cast("Nd4jLong") long getResultWrapperSize(OpaqueResultWrapper ptr); +public native @Cast("Nd4jPointer") Pointer getResultWrapperPointer(OpaqueResultWrapper ptr); public native @Cast("char*") String getAllCustomOps(); @@ -2961,23 +2971,35 @@ public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointer public native int execCustomOp(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputs, @Cast("Nd4jPointer*") PointerPointer outputBuffers, @Cast("Nd4jPointer*") PointerPointer outputShapes, int numOutputs, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs, @Cast("bool") boolean isInplace); public native int execCustomOp2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer") Pointer opContext); -public native ShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs); -public native ShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs); -public native ShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs); -public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs); -public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs); -public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs); -public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs); -public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs); -public native ShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs); +public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs); +public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs); +public native OpaqueShapeList calculateOutputShapes(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs); +public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs); +public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs); +public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs); +public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoublePointer tArgs, int numTArgs, @Cast("Nd4jLong*") LongPointer iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs); +public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, DoubleBuffer tArgs, int numTArgs, @Cast("Nd4jLong*") LongBuffer iArgs, int numIArgs, @Cast("bool*") BooleanPointer bArgs, int numBArgs); +public native OpaqueShapeList calculateOutputShapes2(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long hash, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int numInputShapes, double[] tArgs, int numTArgs, @Cast("Nd4jLong*") long[] iArgs, int numIArgs, @Cast("bool*") boolean[] bArgs, int numBArgs); + +public native @Cast("Nd4jLong") long getShapeListSize(OpaqueShapeList list); +public native @Cast("Nd4jLong*") LongPointer getShape(OpaqueShapeList list, @Cast("Nd4jLong") long i); public native void deleteShapeList(@Cast("Nd4jPointer") Pointer shapeList); public native int registerGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer") Pointer flatBufferPointer); -public native VariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntPointer inputIndices, int numInputs); -public native VariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntBuffer inputIndices, int numInputs); -public native VariablesSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int[] inputIndices, int numInputs); +public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntPointer inputIndices, int numInputs); +public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, IntBuffer inputIndices, int numInputs); +public native OpaqueVariableSet executeStoredGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId, @Cast("Nd4jPointer*") PointerPointer inputBuffers, @Cast("Nd4jPointer*") PointerPointer inputShapes, int[] inputIndices, int numInputs); + +public native @Cast("Nd4jLong") long getVariableSetSize(OpaqueVariableSet set); +public native @Cast("Nd4jStatus") int getVariableSetStatus(OpaqueVariableSet set); +public native OpaqueVariable getVariable(OpaqueVariableSet set, @Cast("Nd4jLong") long i); +public native int getVariableId(OpaqueVariable variable); +public native int getVariableIndex(OpaqueVariable variable); +public native @Cast("char*") String getVariableName(OpaqueVariable variable); +public native @Cast("Nd4jLong*") LongPointer getVariableShape(OpaqueVariable variable); +public native Pointer getVariableBuffer(OpaqueVariable variable); public native int unregisterGraph(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jLong") long graphId); @@ -2986,7 +3008,7 @@ public native void deleteIntArray(@Cast("Nd4jPointer") Pointer pointer); public native void deleteLongArray(@Cast("Nd4jPointer") Pointer pointer); public native void deletePointerArray(@Cast("Nd4jPointer") Pointer pointer); -public native void deleteVariablesSet(@Cast("Nd4jPointer") Pointer pointer); +public native void deleteVariablesSet(@ByVal OpaqueVariableSet pointer); // GraphState creation public native @Cast("Nd4jPointer") Pointer getGraphState(@Cast("Nd4jLong") long id); @@ -3007,6 +3029,8 @@ public native @Cast("Nd4jStatus") int execCustomOpWithScope(@Cast("Nd4jPointer*" //void fillUtf8String(Nd4jPointer *extraPointers, const char **string, int numStrings, Nd4jPointer buffer); public native @Cast("Nd4jPointer") Pointer createUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") String string, int length); public native @Cast("Nd4jPointer") Pointer createUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("char*") BytePointer string, int length); +public native @Cast("Nd4jLong") long getUtf8StringLength(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer ptr); +public native @Cast("char*") BytePointer getUtf8StringBuffer(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer ptr); public native void deleteUtf8String(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer ptr); public native void scatterUpdate(@Cast("Nd4jPointer*") PointerPointer extraPointers, int opCode, int numOfSubArrs, @@ -3032,19 +3056,50 @@ public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointe public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer buffer, @Cast("Nd4jLong*") LongBuffer shapeInfo, @Cast("Nd4jPointer") Pointer specialBuffer, @Cast("Nd4jLong*") LongBuffer specialShapeInfo, @Cast("Nd4jPointer") Pointer debugInfo); public native void inspectArray(@Cast("Nd4jPointer*") PointerPointer extraPointers, @Cast("Nd4jPointer") Pointer buffer, @Cast("Nd4jLong*") long[] shapeInfo, @Cast("Nd4jPointer") Pointer specialBuffer, @Cast("Nd4jLong*") long[] specialShapeInfo, @Cast("Nd4jPointer") Pointer debugInfo); +public native OpaqueConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty); +public native OpaqueConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty); +public native OpaqueConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty); -public native ConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") LongPointer shape, @Cast("Nd4jLong*") LongPointer strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty); -public native ConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") LongBuffer shape, @Cast("Nd4jLong*") LongBuffer strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty); -public native ConstantDataBuffer shapeBuffer(int rank, @Cast("Nd4jLong*") long[] shape, @Cast("Nd4jLong*") long[] strides, @Cast("nd4j::DataType") int dtype, char order, @Cast("Nd4jLong") long ews, @Cast("bool") boolean empty); +public native OpaqueConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") LongPointer data, int length); +public native OpaqueConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") LongBuffer data, int length); +public native OpaqueConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") long[] data, int length); +public native OpaqueConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, DoublePointer data, int length); +public native OpaqueConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, DoubleBuffer data, int length); +public native OpaqueConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, double[] data, int length); +public native OpaqueConstantDataBuffer constantBuffer(@Cast("nd4j::DataType") int dtype, ConstantDescriptor descriptor); -public native ConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") LongPointer data, int length); -public native ConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") LongBuffer data, int length); -public native ConstantDataBuffer constantBufferLong(@Cast("nd4j::DataType") int dtype, @Cast("Nd4jLong*") long[] data, int length); -public native ConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, DoublePointer data, int length); -public native ConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, DoubleBuffer data, int length); -public native ConstantDataBuffer constantBufferDouble(@Cast("nd4j::DataType") int dtype, double[] data, int length); -public native ConstantDataBuffer constantBuffer(@Cast("nd4j::DataType") int dtype, ConstantDescriptor descriptor); +public native @Cast("Nd4jPointer") Pointer getConstantDataBufferPrimary(OpaqueConstantDataBuffer dbf); +public native @Cast("Nd4jPointer") Pointer getConstantDataBufferSpecial(OpaqueConstantDataBuffer dbf); +public native @Cast("Nd4jLong") long getConstantDataBufferLength(OpaqueConstantDataBuffer dbf); +public native @Cast("Nd4jLong") long getConstantDataBufferSizeOf(OpaqueConstantDataBuffer dbf); +public native void deleteShapeBuffer(OpaqueConstantDataBuffer ptr); + +public native OpaqueContext createGraphContext(int nodeId); +public native OpaqueRandomGenerator getGraphContextRandomGenerator(OpaqueContext ptr); +public native void markGraphContextInplace(OpaqueContext ptr, @Cast("bool") boolean reallyInplace); +public native void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer); +public native void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo); +public native void setGraphContextOutputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo); +public native void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, int numberOfArguments); +public native void setGraphContextTArguments(OpaqueContext ptr, DoubleBuffer arguments, int numberOfArguments); +public native void setGraphContextTArguments(OpaqueContext ptr, double[] arguments, int numberOfArguments); +public native void setGraphContextIArguments(OpaqueContext ptr, @Cast("Nd4jLong*") LongPointer arguments, int numberOfArguments); +public native void setGraphContextIArguments(OpaqueContext ptr, @Cast("Nd4jLong*") LongBuffer arguments, int numberOfArguments); +public native void setGraphContextIArguments(OpaqueContext ptr, @Cast("Nd4jLong*") long[] arguments, int numberOfArguments); +public native void setGraphContextBArguments(OpaqueContext ptr, @Cast("bool*") BooleanPointer arguments, int numberOfArguments); +public native void setGraphContextBArguments(OpaqueContext ptr, @Cast("bool*") boolean[] arguments, int numberOfArguments); +public native void deleteGraphContext(OpaqueContext ptr); + +public native OpaqueRandomGenerator createRandomGenerator(@Cast("Nd4jLong") long rootSeed/*=0*/, @Cast("Nd4jLong") long nodeSeed/*=0*/); +public native OpaqueRandomGenerator createRandomGenerator(); +public native @Cast("Nd4jLong") long getRandomGeneratorRootState(OpaqueRandomGenerator ptr); +public native @Cast("Nd4jLong") long getRandomGeneratorNodeState(OpaqueRandomGenerator ptr); +public native void setRandomGeneratorStates(OpaqueRandomGenerator ptr, @Cast("Nd4jLong") long rootSeed/*=0*/, @Cast("Nd4jLong") long nodeSeed/*=0*/); +public native void setRandomGeneratorStates(OpaqueRandomGenerator ptr); +public native int getRandomGeneratorRelativeInt(OpaqueRandomGenerator ptr, @Cast("Nd4jLong") long index); +public native @Cast("Nd4jLong") long getRandomGeneratorRelativeLong(OpaqueRandomGenerator ptr, @Cast("Nd4jLong") long index); +public native void deleteRandomGenerator(OpaqueRandomGenerator ptr); public native @Cast("char*") String runLightBenchmarkSuit(@Cast("bool") boolean printOut); public native @Cast("char*") String runFullBenchmarkSuit(@Cast("bool") boolean printOut); @@ -3705,6 +3760,20 @@ public native @Cast("char*") String runFullBenchmarkSuit(@Cast("bool") boolean p public NDArray(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("nd4j::DataType") int dtype) { super((Pointer)null); allocate(buffer, order, shape, dtype); } private native void allocate(Pointer buffer, byte order, @Cast("Nd4jLong*") @StdVector long[] shape, @Cast("nd4j::DataType") int dtype); + + /** + * This method returns new array with the same shape & data type + * @return + */ + public native @ByVal NDArray like(); + + /** + * This method returns new uninitialized array with the same shape & data type + * @return + */ + public native @ByVal NDArray ulike(); + + /** * this constructor creates new NDArray with shape matching "other" array, * doesn't copy "other" elements into new array !!! diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java index 374a2c52d..e695a724b 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpuPresets.java @@ -156,6 +156,14 @@ public class Nd4jCpuPresets implements InfoMapper, BuildEnabled { infoMap.put(new Info("thread_local", "ND4J_EXPORT", "INLINEDEF", "CUBLASWINAPI", "FORCEINLINE", "_CUDA_H", "_CUDA_D", "_CUDA_G", "_CUDA_HD", "LIBND4J_ALL_OPS", "NOT_EXCLUDED").cppTypes().annotations()) .put(new Info("NativeOps.h").objectify()) + .put(new Info("OpaqueTadPack").pointerTypes("OpaqueTadPack")) + .put(new Info("OpaqueResultWrapper").pointerTypes("OpaqueResultWrapper")) + .put(new Info("OpaqueShapeList").pointerTypes("OpaqueShapeList")) + .put(new Info("OpaqueVariableSet").pointerTypes("OpaqueVariableSet")) + .put(new Info("OpaqueVariable").pointerTypes("OpaqueVariable")) + .put(new Info("OpaqueConstantDataBuffer").pointerTypes("OpaqueConstantDataBuffer")) + .put(new Info("OpaqueContext").pointerTypes("OpaqueContext")) + .put(new Info("OpaqueRandomGenerator").pointerTypes("OpaqueRandomGenerator")) .put(new Info("const char").valueTypes("byte").pointerTypes("@Cast(\"char*\") String", "@Cast(\"char*\") BytePointer")) .put(new Info("char").valueTypes("char").pointerTypes("@Cast(\"char*\") BytePointer", diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java index d99c50af9..57f2e75d9 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java @@ -442,7 +442,7 @@ public class CustomOpsTests extends BaseNd4jTest { context.setOutputArray(0, arrayZ); val addOp = new AddOp(); - NativeOpsHolder.getInstance().getDeviceNativeOps().execCustomOp(null, addOp.opHash(), context.contextPointer()); + NativeOpsHolder.getInstance().getDeviceNativeOps().execCustomOp2(null, addOp.opHash(), context.contextPointer()); assertEquals(exp, arrayZ); }