diff --git a/libnd4j/blas/Environment.cpp b/libnd4j/blas/Environment.cpp index 90c391cf1..de0ac925b 100644 --- a/libnd4j/blas/Environment.cpp +++ b/libnd4j/blas/Environment.cpp @@ -43,7 +43,7 @@ namespace nd4j { nd4j::Environment::Environment() { - _tadThreshold.store(8); + _tadThreshold.store(1); _elementThreshold.store(1024); _verbose.store(false); _debug.store(false); @@ -52,6 +52,7 @@ namespace nd4j { _leaks.store(false); _dataType.store(nd4j::DataType::FLOAT32); _maxThreads = std::thread::hardware_concurrency(); + _maxMasterThreads = _maxThreads.load(); #ifndef ANDROID const char* omp_threads = std::getenv("OMP_NUM_THREADS"); @@ -66,6 +67,94 @@ namespace nd4j { // still do nothing } } + + /** + * Defines size of thread pool used for parallelism + */ + const char* max_threads = std::getenv("SD_MAX_THREADS"); + if (max_threads != nullptr) { + try { + std::string t(max_threads); + int val = std::stoi(t); + _maxThreads.store(val); + } catch (std::invalid_argument &e) { + // just do nothing + } catch (std::out_of_range &e) { + // still do nothing + } + } + + /** + * Defines max number of threads usable at once + */ + const char* max_master_threads = std::getenv("SD_MASTER_THREADS"); + if (max_master_threads != nullptr) { + try { + std::string t(max_master_threads); + int val = std::stoi(t); + _maxMasterThreads.store(val); + } catch (std::invalid_argument &e) { + // just do nothing + } catch (std::out_of_range &e) { + // still do nothing + } + } + + /** + * If this env var is defined - we'll disallow use of platform-specific helpers (mkldnn, cudnn, etc) + */ + const char* forbid_helpers = std::getenv("SD_FORBID_HELPERS"); + if (max_master_threads != nullptr) { + _allowHelpers = false; + } + + /** + * This var defines max amount of host memory library can allocate + */ + const char* max_primary_memory = std::getenv("SD_MAX_PRIMARY_BYTES"); + if (max_primary_memory != nullptr) { + try { + std::string t(max_primary_memory); + auto val = std::stol(t); + _maxTotalPrimaryMemory.store(val); + } catch (std::invalid_argument &e) { + // just do nothing + } catch (std::out_of_range &e) { + // still do nothing + } + } + + /** + * This var defines max amount of special (i.e. device) memory library can allocate on all devices combined + */ + const char* max_special_memory = std::getenv("SD_MAX_SPECIAL_BYTES"); + if (max_special_memory != nullptr) { + try { + std::string t(max_special_memory); + auto val = std::stol(t); + _maxTotalSpecialMemory.store(val); + } catch (std::invalid_argument &e) { + // just do nothing + } catch (std::out_of_range &e) { + // still do nothing + } + } + + /** + * This var defines max amount of special (i.e. device) memory library can allocate on all devices combined + */ + const char* max_device_memory = std::getenv("SD_MAX_DEVICE_BYTES"); + if (max_device_memory != nullptr) { + try { + std::string t(max_device_memory); + auto val = std::stol(t); + _maxDeviceMemory.store(val); + } catch (std::invalid_argument &e) { + // just do nothing + } catch (std::out_of_range &e) { + // still do nothing + } + } #endif #ifdef __CUDABLAS__ @@ -97,6 +186,18 @@ namespace nd4j { // } + void Environment::setMaxPrimaryMemory(uint64_t maxBytes) { + _maxTotalPrimaryMemory = maxBytes; + } + + void Environment::setMaxSpecialyMemory(uint64_t maxBytes) { + _maxTotalSpecialMemory; + } + + void Environment::setMaxDeviceMemory(uint64_t maxBytes) { + _maxDeviceMemory = maxBytes; + } + Environment *Environment::getInstance() { if (_instance == 0) _instance = new Environment(); @@ -179,8 +280,16 @@ namespace nd4j { return _maxThreads.load(); } + int Environment::maxMasterThreads() { + return _maxMasterThreads.load(); + } + void Environment::setMaxThreads(int max) { - _maxThreads.store(max); + //_maxThreads.store(max); + } + + void Environment::setMaxMasterThreads(int max) { + //_maxMasterThreads = max; } bool Environment::precisionBoostAllowed() { @@ -211,6 +320,14 @@ namespace nd4j { return _blasPatchVersion; } + bool Environment::helpersAllowed() { + return _allowHelpers.load(); + } + + void Environment::allowHelpers(bool reallyAllow) { + _allowHelpers.store(reallyAllow); + } + nd4j::Environment *nd4j::Environment::_instance = 0; } diff --git a/libnd4j/blas/Environment.h b/libnd4j/blas/Environment.h index a303d27d0..54982471f 100644 --- a/libnd4j/blas/Environment.h +++ b/libnd4j/blas/Environment.h @@ -37,10 +37,18 @@ namespace nd4j{ std::atomic _debug; std::atomic _leaks; std::atomic _profile; - std::atomic _maxThreads; std::atomic _dataType; std::atomic _precBoost; std::atomic _useMKLDNN{true}; + std::atomic _allowHelpers{true}; + + std::atomic _maxThreads; + std::atomic _maxMasterThreads; + + // these fields hold defaults + std::atomic _maxTotalPrimaryMemory{-1}; + std::atomic _maxTotalSpecialMemory{-1}; + std::atomic _maxDeviceMemory{-1}; #ifdef __ND4J_EXPERIMENTAL__ const bool _experimental = true; @@ -74,6 +82,8 @@ namespace nd4j{ void setDebug(bool reallyDebug); void setProfiling(bool reallyProfile); void setLeaksDetector(bool reallyDetect); + bool helpersAllowed(); + void allowHelpers(bool reallyAllow); int tadThreshold(); void setTadThreshold(int threshold); @@ -84,6 +94,13 @@ namespace nd4j{ int maxThreads(); void setMaxThreads(int max); + int maxMasterThreads(); + void setMaxMasterThreads(int max); + + void setMaxPrimaryMemory(uint64_t maxBytes); + void setMaxSpecialyMemory(uint64_t maxBytes); + void setMaxDeviceMemory(uint64_t maxBytes); + bool isUseMKLDNN() { return _useMKLDNN.load(); } void setUseMKLDNN(bool useMKLDNN) { _useMKLDNN.store(useMKLDNN); } diff --git a/libnd4j/blas/NativeOps.h b/libnd4j/blas/NativeOps.h index b10b3807a..ff368d7c8 100755 --- a/libnd4j/blas/NativeOps.h +++ b/libnd4j/blas/NativeOps.h @@ -1732,6 +1732,7 @@ typedef nd4j::graph::RandomGenerator OpaqueRandomGenerator; ND4J_EXPORT OpaqueContext* createGraphContext(int nodeId); ND4J_EXPORT OpaqueRandomGenerator* getGraphContextRandomGenerator(OpaqueContext* ptr); +ND4J_EXPORT void ctxAllowHelpers(OpaqueContext* ptr, bool reallyAllow); ND4J_EXPORT void markGraphContextInplace(OpaqueContext* ptr, bool reallyInplace); ND4J_EXPORT void setGraphContextCudaContext(OpaqueContext* ptr, void *stream, void *reductionPointer, void *allocationPointer); ND4J_EXPORT void setGraphContextInputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo); diff --git a/libnd4j/blas/cpu/NativeOps.cpp b/libnd4j/blas/cpu/NativeOps.cpp index 151f5c883..df6ccc240 100644 --- a/libnd4j/blas/cpu/NativeOps.cpp +++ b/libnd4j/blas/cpu/NativeOps.cpp @@ -2874,6 +2874,9 @@ void deleteGraphContext(nd4j::graph::Context* ptr) { delete ptr; } +void ctxAllowHelpers(OpaqueContext* ptr, bool reallyAllow) { + ptr->allowHelpers(reallyAllow); +} nd4j::graph::RandomGenerator* createRandomGenerator(Nd4jLong rootSeed, Nd4jLong nodeSeed) { return new nd4j::graph::RandomGenerator(rootSeed, nodeSeed); diff --git a/libnd4j/blas/cuda/NativeOps.cu b/libnd4j/blas/cuda/NativeOps.cu index 2af0e3783..cda6acbad 100755 --- a/libnd4j/blas/cuda/NativeOps.cu +++ b/libnd4j/blas/cuda/NativeOps.cu @@ -3558,4 +3558,8 @@ bool isMinimalRequirementsMet() { bool isOptimalRequirementsMet() { return true; +} + +void ctxAllowHelpers(OpaqueContext* ptr, bool reallyAllow) { + ptr->allowHelpers(reallyAllow); } \ No newline at end of file diff --git a/libnd4j/include/execution/Threads.h b/libnd4j/include/execution/Threads.h index 683220b61..be12a311a 100644 --- a/libnd4j/include/execution/Threads.h +++ b/libnd4j/include/execution/Threads.h @@ -107,11 +107,22 @@ namespace samediff { * @param increment * @return */ - static int parallel_for(FUNC_1D function, int64_t start, int64_t stop, int64_t increment = 1, uint32_t numThreads = nd4j::Environment::getInstance()->maxThreads()); - - static int parallel_tad(FUNC_1D function, int64_t start, int64_t stop, int64_t increment = 1, uint32_t numThreads = nd4j::Environment::getInstance()->maxThreads()); + static int parallel_for(FUNC_1D function, int64_t start, int64_t stop, int64_t increment = 1, uint32_t numThreads = nd4j::Environment::getInstance()->maxMasterThreads()); /** + * This function executes 1 dimensional loop for a given number of threads + * + * @param function + * @param start + * @param stop + * @param increment + * @param numThreads + * @return + */ + static int parallel_tad(FUNC_1D function, int64_t start, int64_t stop, int64_t increment = 1, uint32_t numThreads = nd4j::Environment::getInstance()->maxMasterThreads()); + + /** + * This method will execute function splitting 2 nested loops space with multiple threads * * @param function * @param numThreads @@ -123,9 +134,10 @@ namespace samediff { * @param inc_y * @return */ - static int parallel_for(FUNC_2D function, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, uint64_t numThreads = nd4j::Environment::getInstance()->maxThreads(), bool debug = false); + static int parallel_for(FUNC_2D function, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, uint64_t numThreads = nd4j::Environment::getInstance()->maxMasterThreads(), bool debug = false); /** + * This method will execute function splitting 3 nested loops space with multiple threads * * @param function * @param numThreads @@ -140,7 +152,7 @@ namespace samediff { * @param inc_z * @return */ - static int parallel_for(FUNC_3D function, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_z, int64_t stop_z, int64_t inc_z, uint64_t numThreads = nd4j::Environment::getInstance()->maxThreads()); + static int parallel_for(FUNC_3D function, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_z, int64_t stop_z, int64_t inc_z, uint64_t numThreads = nd4j::Environment::getInstance()->maxMasterThreads()); /** * @@ -148,11 +160,11 @@ namespace samediff { * @param numThreads * @return */ - static int parallel_do(FUNC_DO function, uint64_t numThreads = nd4j::Environment::getInstance()->maxThreads()); + static int parallel_do(FUNC_DO function, uint64_t numThreads = nd4j::Environment::getInstance()->maxMasterThreads()); - static int64_t parallel_long(FUNC_RL function, FUNC_AL aggregator, int64_t start, int64_t stop, int64_t increment = 1, uint64_t numThreads = nd4j::Environment::getInstance()->maxThreads()); + static int64_t parallel_long(FUNC_RL function, FUNC_AL aggregator, int64_t start, int64_t stop, int64_t increment = 1, uint64_t numThreads = nd4j::Environment::getInstance()->maxMasterThreads()); - static double parallel_double(FUNC_RD function, FUNC_AD aggregator, int64_t start, int64_t stop, int64_t increment = 1, uint64_t numThreads = nd4j::Environment::getInstance()->maxThreads()); + static double parallel_double(FUNC_RD function, FUNC_AD aggregator, int64_t start, int64_t stop, int64_t increment = 1, uint64_t numThreads = nd4j::Environment::getInstance()->maxMasterThreads()); }; } diff --git a/libnd4j/include/graph/Context.h b/libnd4j/include/graph/Context.h index f397d46f3..96079e5a2 100644 --- a/libnd4j/include/graph/Context.h +++ b/libnd4j/include/graph/Context.h @@ -58,9 +58,12 @@ namespace nd4j { std::vector _dataTypes; + // fields for fast execution (out-of-graph ops use) std::vector _fastpath_in; std::vector _fastpath_out; std::vector _handles; + + bool _helpersAllowed = true; public: Context(ContextPrototype* prototype, VariableSpace* variableSpace); @@ -188,6 +191,10 @@ namespace nd4j { void setBArguments(bool *arguments, int numberOfArguments); void setCudaContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer allocationPointer); + + + void allowHelpers(bool reallyAllow); + bool helpersAllowed(); }; } } diff --git a/libnd4j/include/graph/impl/Context.cpp b/libnd4j/include/graph/impl/Context.cpp index 085fa969e..b18d3f347 100644 --- a/libnd4j/include/graph/impl/Context.cpp +++ b/libnd4j/include/graph/impl/Context.cpp @@ -461,6 +461,14 @@ namespace nd4j { v->setContext(_context); #endif } + + void Context::allowHelpers(bool reallyAllow) { + _helpersAllowed = reallyAllow; + } + + bool Context::helpersAllowed() { + return _helpersAllowed; + } } } diff --git a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp index fe1574ea1..5ee19b007 100644 --- a/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp +++ b/libnd4j/include/ops/declarable/impl/DeclarableOp.cpp @@ -506,12 +506,15 @@ namespace nd4j { Nd4jStatus status; bool hasHelper = false; - // if we have platform-specific helper for this op - invoke it - if (OpRegistrator::getInstance()->hasHelper(this->getOpHash())) { - auto helper = OpRegistrator::getInstance()->getPlatformHelper(this->getOpHash()); - if (helper->isUsable(*block)) { - status = helper->invokeHelper(*block); - hasHelper = true; + // platform helpers use might be forbidden for various reasons, so we'll check it out first + if (block->helpersAllowed() && nd4j::Environment::getInstance()->helpersAllowed()) { + // if we have platform-specific helper for this op - invoke it + if (OpRegistrator::getInstance()->hasHelper(this->getOpHash())) { + auto helper = OpRegistrator::getInstance()->getPlatformHelper(this->getOpHash()); + if (helper->isUsable(*block)) { + status = helper->invokeHelper(*block); + hasHelper = true; + } } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java index cd74a60a0..e66d52f91 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java @@ -128,4 +128,12 @@ public interface OpContext extends AutoCloseable { * @param reallyInplace */ void markInplace(boolean reallyInplace); + + /** + * This method allows to enable/disable use of platform helpers within ops. I.e. mkldnn or cuDNN. + * PLEASE NOTE: default value is True + * + * @param reallyAllow + */ + void allowHelpers(boolean reallyAllow); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java index 8f621668b..d4a7b8f8b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java @@ -1123,6 +1123,7 @@ public interface NativeOps { void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, int numberOfArguments); void setGraphContextIArguments(OpaqueContext ptr, LongPointer arguments, int numberOfArguments); void setGraphContextBArguments(OpaqueContext ptr, BooleanPointer arguments, int numberOfArguments); + void ctxAllowHelpers(OpaqueContext ptr, boolean reallyAllow); void deleteGraphContext(OpaqueContext ptr); OpaqueRandomGenerator createRandomGenerator(long rootSeed, long nodeSeed); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java index 32f1b0a10..b75f688fe 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java @@ -136,4 +136,9 @@ public class CudaOpContext extends BaseOpContext implements OpContext { public void markInplace(boolean reallyInplace) { nativeOps.markGraphContextInplace(context, reallyInplace); } + + @Override + public void allowHelpers(boolean reallyAllow) { + nativeOps.ctxAllowHelpers(context, reallyAllow); + } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java index fecb64012..22b2068d4 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java @@ -1,4 +1,4 @@ -// Targeted by JavaCPP version 1.5.1-1: DO NOT EDIT THIS FILE +// Targeted by JavaCPP version 1.5.2: DO NOT EDIT THIS FILE package org.nd4j.nativeblas; @@ -575,6 +575,8 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper { public native void setDebug(@Cast("bool") boolean reallyDebug); public native void setProfiling(@Cast("bool") boolean reallyProfile); public native void setLeaksDetector(@Cast("bool") boolean reallyDetect); + public native @Cast("bool") boolean helpersAllowed(); + public native void allowHelpers(@Cast("bool") boolean reallyAllow); public native int tadThreshold(); public native void setTadThreshold(int threshold); @@ -585,6 +587,13 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper { public native int maxThreads(); public native void setMaxThreads(int max); + public native int maxMasterThreads(); + public native void setMaxMasterThreads(int max); + + public native void setMaxPrimaryMemory(@Cast("uint64_t") long maxBytes); + public native void setMaxSpecialyMemory(@Cast("uint64_t") long maxBytes); + public native void setMaxDeviceMemory(@Cast("uint64_t") long maxBytes); + public native @Cast("bool") boolean isUseMKLDNN(); public native void setUseMKLDNN(@Cast("bool") boolean useMKLDNN); @@ -3087,6 +3096,7 @@ public native void deleteShapeBuffer(OpaqueConstantDataBuffer ptr); public native OpaqueContext createGraphContext(int nodeId); public native OpaqueRandomGenerator getGraphContextRandomGenerator(OpaqueContext ptr); +public native void ctxAllowHelpers(OpaqueContext ptr, @Cast("bool") boolean reallyAllow); public native void markGraphContextInplace(OpaqueContext ptr, @Cast("bool") boolean reallyInplace); public native void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer); public native void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo); @@ -5454,6 +5464,10 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) { + + + + @@ -6741,6 +6755,10 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) { public native void setBArguments(@Cast("bool*") boolean[] arguments, int numberOfArguments); public native void setCudaContext(@Cast("Nd4jPointer") Pointer cudaStream, @Cast("Nd4jPointer") Pointer reductionPointer, @Cast("Nd4jPointer") Pointer allocationPointer); + + + public native void allowHelpers(@Cast("bool") boolean reallyAllow); + public native @Cast("bool") boolean helpersAllowed(); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java index 9431a3453..6700f9019 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java @@ -105,4 +105,9 @@ public class CpuOpContext extends BaseOpContext implements OpContext { public void markInplace(boolean reallyInplace) { nativeOps.markGraphContextInplace(context, reallyInplace); } + + @Override + public void allowHelpers(boolean reallyAllow) { + nativeOps.ctxAllowHelpers(context, reallyAllow); + } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index 0441cd3b3..d99a8240a 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -575,6 +575,8 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper { public native void setDebug(@Cast("bool") boolean reallyDebug); public native void setProfiling(@Cast("bool") boolean reallyProfile); public native void setLeaksDetector(@Cast("bool") boolean reallyDetect); + public native @Cast("bool") boolean helpersAllowed(); + public native void allowHelpers(@Cast("bool") boolean reallyAllow); public native int tadThreshold(); public native void setTadThreshold(int threshold); @@ -585,6 +587,13 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper { public native int maxThreads(); public native void setMaxThreads(int max); + public native int maxMasterThreads(); + public native void setMaxMasterThreads(int max); + + public native void setMaxPrimaryMemory(@Cast("uint64_t") long maxBytes); + public native void setMaxSpecialyMemory(@Cast("uint64_t") long maxBytes); + public native void setMaxDeviceMemory(@Cast("uint64_t") long maxBytes); + public native @Cast("bool") boolean isUseMKLDNN(); public native void setUseMKLDNN(@Cast("bool") boolean useMKLDNN); @@ -3087,6 +3096,7 @@ public native void deleteShapeBuffer(OpaqueConstantDataBuffer ptr); public native OpaqueContext createGraphContext(int nodeId); public native OpaqueRandomGenerator getGraphContextRandomGenerator(OpaqueContext ptr); +public native void ctxAllowHelpers(OpaqueContext ptr, @Cast("bool") boolean reallyAllow); public native void markGraphContextInplace(OpaqueContext ptr, @Cast("bool") boolean reallyInplace); public native void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer); public native void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo); @@ -6745,6 +6755,10 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) { public native void setBArguments(@Cast("bool*") boolean[] arguments, int numberOfArguments); public native void setCudaContext(@Cast("Nd4jPointer") Pointer cudaStream, @Cast("Nd4jPointer") Pointer reductionPointer, @Cast("Nd4jPointer") Pointer allocationPointer); + + + public native void allowHelpers(@Cast("bool") boolean reallyAllow); + public native @Cast("bool") boolean helpersAllowed(); } @@ -11383,6 +11397,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); // #elif _MSC_VER // #define FORCEINLINE __forceinline // #elif __GNUC__ +// #define INLINE_LOOPS // #define FORCEINLINE __attribute__((always_inline)) inline // #elif __CUDACC__ // #else diff --git a/nd4j/nd4j-common/src/main/java/org/nd4j/config/ND4JEnvironmentVars.java b/nd4j/nd4j-common/src/main/java/org/nd4j/config/ND4JEnvironmentVars.java index 3bcff03f0..c77f945d0 100644 --- a/nd4j/nd4j-common/src/main/java/org/nd4j/config/ND4JEnvironmentVars.java +++ b/nd4j/nd4j-common/src/main/java/org/nd4j/config/ND4JEnvironmentVars.java @@ -137,6 +137,39 @@ public class ND4JEnvironmentVars { */ public static final String ND4J_IGNORE_AVX = "ND4J_IGNORE_AVX"; + /** + * This variable defines how many threads will be used in ThreadPool for parallel execution of linear algebra. + * Default value: number of threads supported by this system. + */ + public static final String SD_MAX_THREADS = "SD_MAX_THREADS"; + + /** + * This variable defines how many threads will be used for any 1 linear algebra operation. + * Default value: number of threads supported by this system. + */ + public static final String SD_MASTER_THREADS = "SD_MASTER_THREADS"; + + /** + * If set, this variable disables use of optimized platform helpers (i.e. mkldnn or cuDNN) + */ + public static final String SD_FORBID_HELPERS = "SD_FORBID_HELPERS"; + + /** + * If set, this variables defines how much memory application is allowed to use off-heap. + * PLEASE NOTE: this option is separate from JVM XMS/XMX options + */ + public static final String SD_MAX_PRIMARY_BYTES = "SD_MAX_PRIMARY_BYTES"; + + /** + * If set, this variable defines how much memory application is allowed to use ON ALL computational devices COMBINED. + */ + public static final String SD_MAX_SPECIAL_BYTES = "SD_MAX_SPECIAL_BYTES"; + + /** + * If set, this variable defines how much memory application is allowed to use on any one computational device + */ + public static final String SD_MAX_DEVICE_BYTES = "SD_MAX_DEVICE_BYTES"; + private ND4JEnvironmentVars() { } }