blas fallback (#291)
Signed-off-by: raver119 <raver119@gmail.com> Co-authored-by: raver119 <raver119@gmail.com>master
parent
784a2d13f8
commit
2911da061b
|
@ -74,6 +74,9 @@ namespace sd {
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
bool BlasHelper::hasGEMV<float>() {
|
bool BlasHelper::hasGEMV<float>() {
|
||||||
|
if (sd::Environment::getInstance()->blasFallback())
|
||||||
|
return false;
|
||||||
|
|
||||||
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
|
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
|
||||||
return true;
|
return true;
|
||||||
#else
|
#else
|
||||||
|
@ -83,6 +86,9 @@ namespace sd {
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
bool BlasHelper::hasGEMV<double>() {
|
bool BlasHelper::hasGEMV<double>() {
|
||||||
|
if (sd::Environment::getInstance()->blasFallback())
|
||||||
|
return false;
|
||||||
|
|
||||||
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
|
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
|
||||||
return true;
|
return true;
|
||||||
#else
|
#else
|
||||||
|
@ -132,6 +138,9 @@ namespace sd {
|
||||||
|
|
||||||
bool BlasHelper::hasGEMV(const sd::DataType dtype) {
|
bool BlasHelper::hasGEMV(const sd::DataType dtype) {
|
||||||
if(dtype == DataType::FLOAT32) {
|
if(dtype == DataType::FLOAT32) {
|
||||||
|
if (sd::Environment::getInstance()->blasFallback())
|
||||||
|
return false;
|
||||||
|
|
||||||
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
|
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
|
||||||
return true;
|
return true;
|
||||||
#else
|
#else
|
||||||
|
@ -139,6 +148,9 @@ namespace sd {
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
if(dtype == DataType::DOUBLE) {
|
if(dtype == DataType::DOUBLE) {
|
||||||
|
if (sd::Environment::getInstance()->blasFallback())
|
||||||
|
return false;
|
||||||
|
|
||||||
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
|
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
|
||||||
return true;
|
return true;
|
||||||
#else
|
#else
|
||||||
|
@ -150,6 +162,9 @@ namespace sd {
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
bool BlasHelper::hasGEMM<float>() {
|
bool BlasHelper::hasGEMM<float>() {
|
||||||
|
if (sd::Environment::getInstance()->blasFallback())
|
||||||
|
return false;
|
||||||
|
|
||||||
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
|
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
|
||||||
return true;
|
return true;
|
||||||
#else
|
#else
|
||||||
|
@ -159,6 +174,9 @@ namespace sd {
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
bool BlasHelper::hasGEMM<double>() {
|
bool BlasHelper::hasGEMM<double>() {
|
||||||
|
if (sd::Environment::getInstance()->blasFallback())
|
||||||
|
return false;
|
||||||
|
|
||||||
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
|
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
|
||||||
return true;
|
return true;
|
||||||
#else
|
#else
|
||||||
|
@ -208,6 +226,9 @@ namespace sd {
|
||||||
|
|
||||||
bool BlasHelper:: hasGEMM(const sd::DataType dtype) {
|
bool BlasHelper:: hasGEMM(const sd::DataType dtype) {
|
||||||
if(dtype == DataType::FLOAT32) {
|
if(dtype == DataType::FLOAT32) {
|
||||||
|
if (sd::Environment::getInstance()->blasFallback())
|
||||||
|
return false;
|
||||||
|
|
||||||
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
|
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
|
||||||
return true;
|
return true;
|
||||||
#else
|
#else
|
||||||
|
@ -215,6 +236,9 @@ namespace sd {
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
if(dtype == DataType::DOUBLE) {
|
if(dtype == DataType::DOUBLE) {
|
||||||
|
if (sd::Environment::getInstance()->blasFallback())
|
||||||
|
return false;
|
||||||
|
|
||||||
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
|
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
|
||||||
return true;
|
return true;
|
||||||
#else
|
#else
|
||||||
|
@ -227,11 +251,17 @@ namespace sd {
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
bool BlasHelper::hasBatchedGEMM<float>() {
|
bool BlasHelper::hasBatchedGEMM<float>() {
|
||||||
|
if (sd::Environment::getInstance()->blasFallback())
|
||||||
|
return false;
|
||||||
|
|
||||||
return _hasSgemmBatch;
|
return _hasSgemmBatch;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
bool BlasHelper::hasBatchedGEMM<double>() {
|
bool BlasHelper::hasBatchedGEMM<double>() {
|
||||||
|
if (sd::Environment::getInstance()->blasFallback())
|
||||||
|
return false;
|
||||||
|
|
||||||
return _hasDgemmBatch;
|
return _hasDgemmBatch;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -162,6 +162,11 @@ namespace sd {
|
||||||
// still do nothing
|
// still do nothing
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const char* blas_fallback = std::getenv("SD_BLAS_FALLBACK");
|
||||||
|
if (blas_fallback != nullptr) {
|
||||||
|
_blasFallback = true;
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef __CUDABLAS__
|
#ifdef __CUDABLAS__
|
||||||
|
@ -189,6 +194,10 @@ namespace sd {
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool sd::Environment::blasFallback() {
|
||||||
|
return _blasFallback;
|
||||||
|
}
|
||||||
|
|
||||||
sd::Environment::~Environment() {
|
sd::Environment::~Environment() {
|
||||||
//
|
//
|
||||||
}
|
}
|
||||||
|
|
|
@ -51,6 +51,8 @@ namespace sd{
|
||||||
std::atomic<int64_t> _maxTotalSpecialMemory{-1};
|
std::atomic<int64_t> _maxTotalSpecialMemory{-1};
|
||||||
std::atomic<int64_t> _maxDeviceMemory{-1};
|
std::atomic<int64_t> _maxDeviceMemory{-1};
|
||||||
|
|
||||||
|
bool _blasFallback = false;
|
||||||
|
|
||||||
#ifdef __ND4J_EXPERIMENTAL__
|
#ifdef __ND4J_EXPERIMENTAL__
|
||||||
const bool _experimental = true;
|
const bool _experimental = true;
|
||||||
#else
|
#else
|
||||||
|
@ -85,6 +87,8 @@ namespace sd{
|
||||||
void setLeaksDetector(bool reallyDetect);
|
void setLeaksDetector(bool reallyDetect);
|
||||||
bool helpersAllowed();
|
bool helpersAllowed();
|
||||||
void allowHelpers(bool reallyAllow);
|
void allowHelpers(bool reallyAllow);
|
||||||
|
|
||||||
|
bool blasFallback();
|
||||||
|
|
||||||
int tadThreshold();
|
int tadThreshold();
|
||||||
void setTadThreshold(int threshold);
|
void setTadThreshold(int threshold);
|
||||||
|
|
Loading…
Reference in New Issue