[WIP] Platform helpers switches (#44)

* - platform helpers can be disabled on per-op basis now via Context::allowHelpers
- java has access to it as well

Signed-off-by: raver119 <raver119@gmail.com>

* global platform-helpers trigger

Signed-off-by: raver119 <raver119@gmail.com>

* few signatures renamed

Signed-off-by: raver119 <raver119@gmail.com>

* - few new env variables to follow
- maxThreads/masterThreads differentiation

Signed-off-by: raver119 <raver119@gmail.com>

* Javadoc update

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-11-14 14:35:02 +03:00 committed by GitHub
parent 47d19908f4
commit 1eb3de90d7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 275 additions and 18 deletions

View File

@ -43,7 +43,7 @@
namespace nd4j { namespace nd4j {
nd4j::Environment::Environment() { nd4j::Environment::Environment() {
_tadThreshold.store(8); _tadThreshold.store(1);
_elementThreshold.store(1024); _elementThreshold.store(1024);
_verbose.store(false); _verbose.store(false);
_debug.store(false); _debug.store(false);
@ -52,6 +52,7 @@ namespace nd4j {
_leaks.store(false); _leaks.store(false);
_dataType.store(nd4j::DataType::FLOAT32); _dataType.store(nd4j::DataType::FLOAT32);
_maxThreads = std::thread::hardware_concurrency(); _maxThreads = std::thread::hardware_concurrency();
_maxMasterThreads = _maxThreads.load();
#ifndef ANDROID #ifndef ANDROID
const char* omp_threads = std::getenv("OMP_NUM_THREADS"); const char* omp_threads = std::getenv("OMP_NUM_THREADS");
@ -66,6 +67,94 @@ namespace nd4j {
// still do nothing // still do nothing
} }
} }
/**
* Defines size of thread pool used for parallelism
*/
const char* max_threads = std::getenv("SD_MAX_THREADS");
if (max_threads != nullptr) {
try {
std::string t(max_threads);
int val = std::stoi(t);
_maxThreads.store(val);
} catch (std::invalid_argument &e) {
// just do nothing
} catch (std::out_of_range &e) {
// still do nothing
}
}
/**
* Defines max number of threads usable at once
*/
const char* max_master_threads = std::getenv("SD_MASTER_THREADS");
if (max_master_threads != nullptr) {
try {
std::string t(max_master_threads);
int val = std::stoi(t);
_maxMasterThreads.store(val);
} catch (std::invalid_argument &e) {
// just do nothing
} catch (std::out_of_range &e) {
// still do nothing
}
}
/**
* If this env var is defined - we'll disallow use of platform-specific helpers (mkldnn, cudnn, etc)
*/
const char* forbid_helpers = std::getenv("SD_FORBID_HELPERS");
if (max_master_threads != nullptr) {
_allowHelpers = false;
}
/**
* This var defines max amount of host memory library can allocate
*/
const char* max_primary_memory = std::getenv("SD_MAX_PRIMARY_BYTES");
if (max_primary_memory != nullptr) {
try {
std::string t(max_primary_memory);
auto val = std::stol(t);
_maxTotalPrimaryMemory.store(val);
} catch (std::invalid_argument &e) {
// just do nothing
} catch (std::out_of_range &e) {
// still do nothing
}
}
/**
* This var defines max amount of special (i.e. device) memory library can allocate on all devices combined
*/
const char* max_special_memory = std::getenv("SD_MAX_SPECIAL_BYTES");
if (max_special_memory != nullptr) {
try {
std::string t(max_special_memory);
auto val = std::stol(t);
_maxTotalSpecialMemory.store(val);
} catch (std::invalid_argument &e) {
// just do nothing
} catch (std::out_of_range &e) {
// still do nothing
}
}
/**
* This var defines max amount of special (i.e. device) memory library can allocate on all devices combined
*/
const char* max_device_memory = std::getenv("SD_MAX_DEVICE_BYTES");
if (max_device_memory != nullptr) {
try {
std::string t(max_device_memory);
auto val = std::stol(t);
_maxDeviceMemory.store(val);
} catch (std::invalid_argument &e) {
// just do nothing
} catch (std::out_of_range &e) {
// still do nothing
}
}
#endif #endif
#ifdef __CUDABLAS__ #ifdef __CUDABLAS__
@ -97,6 +186,18 @@ namespace nd4j {
// //
} }
void Environment::setMaxPrimaryMemory(uint64_t maxBytes) {
_maxTotalPrimaryMemory = maxBytes;
}
void Environment::setMaxSpecialyMemory(uint64_t maxBytes) {
_maxTotalSpecialMemory;
}
void Environment::setMaxDeviceMemory(uint64_t maxBytes) {
_maxDeviceMemory = maxBytes;
}
Environment *Environment::getInstance() { Environment *Environment::getInstance() {
if (_instance == 0) if (_instance == 0)
_instance = new Environment(); _instance = new Environment();
@ -179,8 +280,16 @@ namespace nd4j {
return _maxThreads.load(); return _maxThreads.load();
} }
int Environment::maxMasterThreads() {
return _maxMasterThreads.load();
}
void Environment::setMaxThreads(int max) { void Environment::setMaxThreads(int max) {
_maxThreads.store(max); //_maxThreads.store(max);
}
void Environment::setMaxMasterThreads(int max) {
//_maxMasterThreads = max;
} }
bool Environment::precisionBoostAllowed() { bool Environment::precisionBoostAllowed() {
@ -211,6 +320,14 @@ namespace nd4j {
return _blasPatchVersion; return _blasPatchVersion;
} }
bool Environment::helpersAllowed() {
return _allowHelpers.load();
}
void Environment::allowHelpers(bool reallyAllow) {
_allowHelpers.store(reallyAllow);
}
nd4j::Environment *nd4j::Environment::_instance = 0; nd4j::Environment *nd4j::Environment::_instance = 0;
} }

View File

@ -37,10 +37,18 @@ namespace nd4j{
std::atomic<bool> _debug; std::atomic<bool> _debug;
std::atomic<bool> _leaks; std::atomic<bool> _leaks;
std::atomic<bool> _profile; std::atomic<bool> _profile;
std::atomic<int> _maxThreads;
std::atomic<nd4j::DataType> _dataType; std::atomic<nd4j::DataType> _dataType;
std::atomic<bool> _precBoost; std::atomic<bool> _precBoost;
std::atomic<bool> _useMKLDNN{true}; std::atomic<bool> _useMKLDNN{true};
std::atomic<bool> _allowHelpers{true};
std::atomic<int> _maxThreads;
std::atomic<int> _maxMasterThreads;
// these fields hold defaults
std::atomic<int64_t> _maxTotalPrimaryMemory{-1};
std::atomic<int64_t> _maxTotalSpecialMemory{-1};
std::atomic<int64_t> _maxDeviceMemory{-1};
#ifdef __ND4J_EXPERIMENTAL__ #ifdef __ND4J_EXPERIMENTAL__
const bool _experimental = true; const bool _experimental = true;
@ -74,6 +82,8 @@ namespace nd4j{
void setDebug(bool reallyDebug); void setDebug(bool reallyDebug);
void setProfiling(bool reallyProfile); void setProfiling(bool reallyProfile);
void setLeaksDetector(bool reallyDetect); void setLeaksDetector(bool reallyDetect);
bool helpersAllowed();
void allowHelpers(bool reallyAllow);
int tadThreshold(); int tadThreshold();
void setTadThreshold(int threshold); void setTadThreshold(int threshold);
@ -84,6 +94,13 @@ namespace nd4j{
int maxThreads(); int maxThreads();
void setMaxThreads(int max); void setMaxThreads(int max);
int maxMasterThreads();
void setMaxMasterThreads(int max);
void setMaxPrimaryMemory(uint64_t maxBytes);
void setMaxSpecialyMemory(uint64_t maxBytes);
void setMaxDeviceMemory(uint64_t maxBytes);
bool isUseMKLDNN() { return _useMKLDNN.load(); } bool isUseMKLDNN() { return _useMKLDNN.load(); }
void setUseMKLDNN(bool useMKLDNN) { _useMKLDNN.store(useMKLDNN); } void setUseMKLDNN(bool useMKLDNN) { _useMKLDNN.store(useMKLDNN); }

View File

@ -1732,6 +1732,7 @@ typedef nd4j::graph::RandomGenerator OpaqueRandomGenerator;
ND4J_EXPORT OpaqueContext* createGraphContext(int nodeId); ND4J_EXPORT OpaqueContext* createGraphContext(int nodeId);
ND4J_EXPORT OpaqueRandomGenerator* getGraphContextRandomGenerator(OpaqueContext* ptr); ND4J_EXPORT OpaqueRandomGenerator* getGraphContextRandomGenerator(OpaqueContext* ptr);
ND4J_EXPORT void ctxAllowHelpers(OpaqueContext* ptr, bool reallyAllow);
ND4J_EXPORT void markGraphContextInplace(OpaqueContext* ptr, bool reallyInplace); ND4J_EXPORT void markGraphContextInplace(OpaqueContext* ptr, bool reallyInplace);
ND4J_EXPORT void setGraphContextCudaContext(OpaqueContext* ptr, void *stream, void *reductionPointer, void *allocationPointer); ND4J_EXPORT void setGraphContextCudaContext(OpaqueContext* ptr, void *stream, void *reductionPointer, void *allocationPointer);
ND4J_EXPORT void setGraphContextInputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo); ND4J_EXPORT void setGraphContextInputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo);

View File

@ -2874,6 +2874,9 @@ void deleteGraphContext(nd4j::graph::Context* ptr) {
delete ptr; delete ptr;
} }
void ctxAllowHelpers(OpaqueContext* ptr, bool reallyAllow) {
ptr->allowHelpers(reallyAllow);
}
nd4j::graph::RandomGenerator* createRandomGenerator(Nd4jLong rootSeed, Nd4jLong nodeSeed) { nd4j::graph::RandomGenerator* createRandomGenerator(Nd4jLong rootSeed, Nd4jLong nodeSeed) {
return new nd4j::graph::RandomGenerator(rootSeed, nodeSeed); return new nd4j::graph::RandomGenerator(rootSeed, nodeSeed);

View File

@ -3559,3 +3559,7 @@ bool isMinimalRequirementsMet() {
bool isOptimalRequirementsMet() { bool isOptimalRequirementsMet() {
return true; return true;
} }
void ctxAllowHelpers(OpaqueContext* ptr, bool reallyAllow) {
ptr->allowHelpers(reallyAllow);
}

View File

@ -107,11 +107,22 @@ namespace samediff {
* @param increment * @param increment
* @return * @return
*/ */
static int parallel_for(FUNC_1D function, int64_t start, int64_t stop, int64_t increment = 1, uint32_t numThreads = nd4j::Environment::getInstance()->maxThreads()); static int parallel_for(FUNC_1D function, int64_t start, int64_t stop, int64_t increment = 1, uint32_t numThreads = nd4j::Environment::getInstance()->maxMasterThreads());
static int parallel_tad(FUNC_1D function, int64_t start, int64_t stop, int64_t increment = 1, uint32_t numThreads = nd4j::Environment::getInstance()->maxThreads());
/** /**
* This function executes 1 dimensional loop for a given number of threads
*
* @param function
* @param start
* @param stop
* @param increment
* @param numThreads
* @return
*/
static int parallel_tad(FUNC_1D function, int64_t start, int64_t stop, int64_t increment = 1, uint32_t numThreads = nd4j::Environment::getInstance()->maxMasterThreads());
/**
* This method will execute function splitting 2 nested loops space with multiple threads
* *
* @param function * @param function
* @param numThreads * @param numThreads
@ -123,9 +134,10 @@ namespace samediff {
* @param inc_y * @param inc_y
* @return * @return
*/ */
static int parallel_for(FUNC_2D function, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, uint64_t numThreads = nd4j::Environment::getInstance()->maxThreads(), bool debug = false); static int parallel_for(FUNC_2D function, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, uint64_t numThreads = nd4j::Environment::getInstance()->maxMasterThreads(), bool debug = false);
/** /**
* This method will execute function splitting 3 nested loops space with multiple threads
* *
* @param function * @param function
* @param numThreads * @param numThreads
@ -140,7 +152,7 @@ namespace samediff {
* @param inc_z * @param inc_z
* @return * @return
*/ */
static int parallel_for(FUNC_3D function, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_z, int64_t stop_z, int64_t inc_z, uint64_t numThreads = nd4j::Environment::getInstance()->maxThreads()); static int parallel_for(FUNC_3D function, int64_t start_x, int64_t stop_x, int64_t inc_x, int64_t start_y, int64_t stop_y, int64_t inc_y, int64_t start_z, int64_t stop_z, int64_t inc_z, uint64_t numThreads = nd4j::Environment::getInstance()->maxMasterThreads());
/** /**
* *
@ -148,11 +160,11 @@ namespace samediff {
* @param numThreads * @param numThreads
* @return * @return
*/ */
static int parallel_do(FUNC_DO function, uint64_t numThreads = nd4j::Environment::getInstance()->maxThreads()); static int parallel_do(FUNC_DO function, uint64_t numThreads = nd4j::Environment::getInstance()->maxMasterThreads());
static int64_t parallel_long(FUNC_RL function, FUNC_AL aggregator, int64_t start, int64_t stop, int64_t increment = 1, uint64_t numThreads = nd4j::Environment::getInstance()->maxThreads()); static int64_t parallel_long(FUNC_RL function, FUNC_AL aggregator, int64_t start, int64_t stop, int64_t increment = 1, uint64_t numThreads = nd4j::Environment::getInstance()->maxMasterThreads());
static double parallel_double(FUNC_RD function, FUNC_AD aggregator, int64_t start, int64_t stop, int64_t increment = 1, uint64_t numThreads = nd4j::Environment::getInstance()->maxThreads()); static double parallel_double(FUNC_RD function, FUNC_AD aggregator, int64_t start, int64_t stop, int64_t increment = 1, uint64_t numThreads = nd4j::Environment::getInstance()->maxMasterThreads());
}; };
} }

View File

@ -58,9 +58,12 @@ namespace nd4j {
std::vector<nd4j::DataType> _dataTypes; std::vector<nd4j::DataType> _dataTypes;
// fields for fast execution (out-of-graph ops use)
std::vector<NDArray*> _fastpath_in; std::vector<NDArray*> _fastpath_in;
std::vector<NDArray*> _fastpath_out; std::vector<NDArray*> _fastpath_out;
std::vector<NDArray*> _handles; std::vector<NDArray*> _handles;
bool _helpersAllowed = true;
public: public:
Context(ContextPrototype* prototype, VariableSpace* variableSpace); Context(ContextPrototype* prototype, VariableSpace* variableSpace);
@ -188,6 +191,10 @@ namespace nd4j {
void setBArguments(bool *arguments, int numberOfArguments); void setBArguments(bool *arguments, int numberOfArguments);
void setCudaContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer allocationPointer); void setCudaContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer allocationPointer);
void allowHelpers(bool reallyAllow);
bool helpersAllowed();
}; };
} }
} }

View File

@ -461,6 +461,14 @@ namespace nd4j {
v->setContext(_context); v->setContext(_context);
#endif #endif
} }
void Context::allowHelpers(bool reallyAllow) {
_helpersAllowed = reallyAllow;
}
bool Context::helpersAllowed() {
return _helpersAllowed;
}
} }
} }

View File

@ -506,12 +506,15 @@ namespace nd4j {
Nd4jStatus status; Nd4jStatus status;
bool hasHelper = false; bool hasHelper = false;
// if we have platform-specific helper for this op - invoke it // platform helpers use might be forbidden for various reasons, so we'll check it out first
if (OpRegistrator::getInstance()->hasHelper(this->getOpHash())) { if (block->helpersAllowed() && nd4j::Environment::getInstance()->helpersAllowed()) {
auto helper = OpRegistrator::getInstance()->getPlatformHelper(this->getOpHash()); // if we have platform-specific helper for this op - invoke it
if (helper->isUsable(*block)) { if (OpRegistrator::getInstance()->hasHelper(this->getOpHash())) {
status = helper->invokeHelper(*block); auto helper = OpRegistrator::getInstance()->getPlatformHelper(this->getOpHash());
hasHelper = true; if (helper->isUsable(*block)) {
status = helper->invokeHelper(*block);
hasHelper = true;
}
} }
} }

View File

@ -128,4 +128,12 @@ public interface OpContext extends AutoCloseable {
* @param reallyInplace * @param reallyInplace
*/ */
void markInplace(boolean reallyInplace); void markInplace(boolean reallyInplace);
/**
* This method allows to enable/disable use of platform helpers within ops. I.e. mkldnn or cuDNN.
* PLEASE NOTE: default value is True
*
* @param reallyAllow
*/
void allowHelpers(boolean reallyAllow);
} }

View File

@ -1123,6 +1123,7 @@ public interface NativeOps {
void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, int numberOfArguments); void setGraphContextTArguments(OpaqueContext ptr, DoublePointer arguments, int numberOfArguments);
void setGraphContextIArguments(OpaqueContext ptr, LongPointer arguments, int numberOfArguments); void setGraphContextIArguments(OpaqueContext ptr, LongPointer arguments, int numberOfArguments);
void setGraphContextBArguments(OpaqueContext ptr, BooleanPointer arguments, int numberOfArguments); void setGraphContextBArguments(OpaqueContext ptr, BooleanPointer arguments, int numberOfArguments);
void ctxAllowHelpers(OpaqueContext ptr, boolean reallyAllow);
void deleteGraphContext(OpaqueContext ptr); void deleteGraphContext(OpaqueContext ptr);
OpaqueRandomGenerator createRandomGenerator(long rootSeed, long nodeSeed); OpaqueRandomGenerator createRandomGenerator(long rootSeed, long nodeSeed);

View File

@ -136,4 +136,9 @@ public class CudaOpContext extends BaseOpContext implements OpContext {
public void markInplace(boolean reallyInplace) { public void markInplace(boolean reallyInplace) {
nativeOps.markGraphContextInplace(context, reallyInplace); nativeOps.markGraphContextInplace(context, reallyInplace);
} }
@Override
public void allowHelpers(boolean reallyAllow) {
nativeOps.ctxAllowHelpers(context, reallyAllow);
}
} }

View File

@ -1,4 +1,4 @@
// Targeted by JavaCPP version 1.5.1-1: DO NOT EDIT THIS FILE // Targeted by JavaCPP version 1.5.2: DO NOT EDIT THIS FILE
package org.nd4j.nativeblas; package org.nd4j.nativeblas;
@ -575,6 +575,8 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper {
public native void setDebug(@Cast("bool") boolean reallyDebug); public native void setDebug(@Cast("bool") boolean reallyDebug);
public native void setProfiling(@Cast("bool") boolean reallyProfile); public native void setProfiling(@Cast("bool") boolean reallyProfile);
public native void setLeaksDetector(@Cast("bool") boolean reallyDetect); public native void setLeaksDetector(@Cast("bool") boolean reallyDetect);
public native @Cast("bool") boolean helpersAllowed();
public native void allowHelpers(@Cast("bool") boolean reallyAllow);
public native int tadThreshold(); public native int tadThreshold();
public native void setTadThreshold(int threshold); public native void setTadThreshold(int threshold);
@ -585,6 +587,13 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper {
public native int maxThreads(); public native int maxThreads();
public native void setMaxThreads(int max); public native void setMaxThreads(int max);
public native int maxMasterThreads();
public native void setMaxMasterThreads(int max);
public native void setMaxPrimaryMemory(@Cast("uint64_t") long maxBytes);
public native void setMaxSpecialyMemory(@Cast("uint64_t") long maxBytes);
public native void setMaxDeviceMemory(@Cast("uint64_t") long maxBytes);
public native @Cast("bool") boolean isUseMKLDNN(); public native @Cast("bool") boolean isUseMKLDNN();
public native void setUseMKLDNN(@Cast("bool") boolean useMKLDNN); public native void setUseMKLDNN(@Cast("bool") boolean useMKLDNN);
@ -3087,6 +3096,7 @@ public native void deleteShapeBuffer(OpaqueConstantDataBuffer ptr);
public native OpaqueContext createGraphContext(int nodeId); public native OpaqueContext createGraphContext(int nodeId);
public native OpaqueRandomGenerator getGraphContextRandomGenerator(OpaqueContext ptr); public native OpaqueRandomGenerator getGraphContextRandomGenerator(OpaqueContext ptr);
public native void ctxAllowHelpers(OpaqueContext ptr, @Cast("bool") boolean reallyAllow);
public native void markGraphContextInplace(OpaqueContext ptr, @Cast("bool") boolean reallyInplace); public native void markGraphContextInplace(OpaqueContext ptr, @Cast("bool") boolean reallyInplace);
public native void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer); public native void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer);
public native void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo); public native void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);
@ -5458,6 +5468,10 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) {
@ -6741,6 +6755,10 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) {
public native void setBArguments(@Cast("bool*") boolean[] arguments, int numberOfArguments); public native void setBArguments(@Cast("bool*") boolean[] arguments, int numberOfArguments);
public native void setCudaContext(@Cast("Nd4jPointer") Pointer cudaStream, @Cast("Nd4jPointer") Pointer reductionPointer, @Cast("Nd4jPointer") Pointer allocationPointer); public native void setCudaContext(@Cast("Nd4jPointer") Pointer cudaStream, @Cast("Nd4jPointer") Pointer reductionPointer, @Cast("Nd4jPointer") Pointer allocationPointer);
public native void allowHelpers(@Cast("bool") boolean reallyAllow);
public native @Cast("bool") boolean helpersAllowed();
} }

View File

@ -105,4 +105,9 @@ public class CpuOpContext extends BaseOpContext implements OpContext {
public void markInplace(boolean reallyInplace) { public void markInplace(boolean reallyInplace) {
nativeOps.markGraphContextInplace(context, reallyInplace); nativeOps.markGraphContextInplace(context, reallyInplace);
} }
@Override
public void allowHelpers(boolean reallyAllow) {
nativeOps.ctxAllowHelpers(context, reallyAllow);
}
} }

View File

@ -575,6 +575,8 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
public native void setDebug(@Cast("bool") boolean reallyDebug); public native void setDebug(@Cast("bool") boolean reallyDebug);
public native void setProfiling(@Cast("bool") boolean reallyProfile); public native void setProfiling(@Cast("bool") boolean reallyProfile);
public native void setLeaksDetector(@Cast("bool") boolean reallyDetect); public native void setLeaksDetector(@Cast("bool") boolean reallyDetect);
public native @Cast("bool") boolean helpersAllowed();
public native void allowHelpers(@Cast("bool") boolean reallyAllow);
public native int tadThreshold(); public native int tadThreshold();
public native void setTadThreshold(int threshold); public native void setTadThreshold(int threshold);
@ -585,6 +587,13 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
public native int maxThreads(); public native int maxThreads();
public native void setMaxThreads(int max); public native void setMaxThreads(int max);
public native int maxMasterThreads();
public native void setMaxMasterThreads(int max);
public native void setMaxPrimaryMemory(@Cast("uint64_t") long maxBytes);
public native void setMaxSpecialyMemory(@Cast("uint64_t") long maxBytes);
public native void setMaxDeviceMemory(@Cast("uint64_t") long maxBytes);
public native @Cast("bool") boolean isUseMKLDNN(); public native @Cast("bool") boolean isUseMKLDNN();
public native void setUseMKLDNN(@Cast("bool") boolean useMKLDNN); public native void setUseMKLDNN(@Cast("bool") boolean useMKLDNN);
@ -3087,6 +3096,7 @@ public native void deleteShapeBuffer(OpaqueConstantDataBuffer ptr);
public native OpaqueContext createGraphContext(int nodeId); public native OpaqueContext createGraphContext(int nodeId);
public native OpaqueRandomGenerator getGraphContextRandomGenerator(OpaqueContext ptr); public native OpaqueRandomGenerator getGraphContextRandomGenerator(OpaqueContext ptr);
public native void ctxAllowHelpers(OpaqueContext ptr, @Cast("bool") boolean reallyAllow);
public native void markGraphContextInplace(OpaqueContext ptr, @Cast("bool") boolean reallyInplace); public native void markGraphContextInplace(OpaqueContext ptr, @Cast("bool") boolean reallyInplace);
public native void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer); public native void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer);
public native void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo); public native void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo);
@ -6745,6 +6755,10 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) {
public native void setBArguments(@Cast("bool*") boolean[] arguments, int numberOfArguments); public native void setBArguments(@Cast("bool*") boolean[] arguments, int numberOfArguments);
public native void setCudaContext(@Cast("Nd4jPointer") Pointer cudaStream, @Cast("Nd4jPointer") Pointer reductionPointer, @Cast("Nd4jPointer") Pointer allocationPointer); public native void setCudaContext(@Cast("Nd4jPointer") Pointer cudaStream, @Cast("Nd4jPointer") Pointer reductionPointer, @Cast("Nd4jPointer") Pointer allocationPointer);
public native void allowHelpers(@Cast("bool") boolean reallyAllow);
public native @Cast("bool") boolean helpersAllowed();
} }
@ -11383,6 +11397,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
// #elif _MSC_VER // #elif _MSC_VER
// #define FORCEINLINE __forceinline // #define FORCEINLINE __forceinline
// #elif __GNUC__ // #elif __GNUC__
// #define INLINE_LOOPS
// #define FORCEINLINE __attribute__((always_inline)) inline // #define FORCEINLINE __attribute__((always_inline)) inline
// #elif __CUDACC__ // #elif __CUDACC__
// #else // #else

View File

@ -137,6 +137,39 @@ public class ND4JEnvironmentVars {
*/ */
public static final String ND4J_IGNORE_AVX = "ND4J_IGNORE_AVX"; public static final String ND4J_IGNORE_AVX = "ND4J_IGNORE_AVX";
/**
* This variable defines how many threads will be used in ThreadPool for parallel execution of linear algebra.
* Default value: number of threads supported by this system.
*/
public static final String SD_MAX_THREADS = "SD_MAX_THREADS";
/**
* This variable defines how many threads will be used for any 1 linear algebra operation.
* Default value: number of threads supported by this system.
*/
public static final String SD_MASTER_THREADS = "SD_MASTER_THREADS";
/**
* If set, this variable disables use of optimized platform helpers (i.e. mkldnn or cuDNN)
*/
public static final String SD_FORBID_HELPERS = "SD_FORBID_HELPERS";
/**
* If set, this variables defines how much memory application is allowed to use off-heap.
* PLEASE NOTE: this option is separate from JVM XMS/XMX options
*/
public static final String SD_MAX_PRIMARY_BYTES = "SD_MAX_PRIMARY_BYTES";
/**
* If set, this variable defines how much memory application is allowed to use ON ALL computational devices COMBINED.
*/
public static final String SD_MAX_SPECIAL_BYTES = "SD_MAX_SPECIAL_BYTES";
/**
* If set, this variable defines how much memory application is allowed to use on any one computational device
*/
public static final String SD_MAX_DEVICE_BYTES = "SD_MAX_DEVICE_BYTES";
private ND4JEnvironmentVars() { private ND4JEnvironmentVars() {
} }
} }