blas fallback (#291)

Signed-off-by: raver119 <raver119@gmail.com>

Co-authored-by: raver119 <raver119@gmail.com>
master
raver119 2020-03-05 14:11:13 +03:00 committed by GitHub
parent 784a2d13f8
commit 2911da061b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 43 additions and 0 deletions

View File

@ -74,6 +74,9 @@ namespace sd {
template <>
bool BlasHelper::hasGEMV<float>() {
if (sd::Environment::getInstance()->blasFallback())
return false;
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
return true;
#else
@ -83,6 +86,9 @@ namespace sd {
template <>
bool BlasHelper::hasGEMV<double>() {
if (sd::Environment::getInstance()->blasFallback())
return false;
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
return true;
#else
@ -132,6 +138,9 @@ namespace sd {
bool BlasHelper::hasGEMV(const sd::DataType dtype) {
if(dtype == DataType::FLOAT32) {
if (sd::Environment::getInstance()->blasFallback())
return false;
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
return true;
#else
@ -139,6 +148,9 @@ namespace sd {
#endif
}
if(dtype == DataType::DOUBLE) {
if (sd::Environment::getInstance()->blasFallback())
return false;
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
return true;
#else
@ -150,6 +162,9 @@ namespace sd {
template <>
bool BlasHelper::hasGEMM<float>() {
if (sd::Environment::getInstance()->blasFallback())
return false;
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
return true;
#else
@ -159,6 +174,9 @@ namespace sd {
template <>
bool BlasHelper::hasGEMM<double>() {
if (sd::Environment::getInstance()->blasFallback())
return false;
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
return true;
#else
@ -208,6 +226,9 @@ namespace sd {
bool BlasHelper:: hasGEMM(const sd::DataType dtype) {
if(dtype == DataType::FLOAT32) {
if (sd::Environment::getInstance()->blasFallback())
return false;
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
return true;
#else
@ -215,6 +236,9 @@ namespace sd {
#endif
}
if(dtype == DataType::DOUBLE) {
if (sd::Environment::getInstance()->blasFallback())
return false;
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
return true;
#else
@ -227,11 +251,17 @@ namespace sd {
template <>
bool BlasHelper::hasBatchedGEMM<float>() {
if (sd::Environment::getInstance()->blasFallback())
return false;
return _hasSgemmBatch;
}
template <>
bool BlasHelper::hasBatchedGEMM<double>() {
if (sd::Environment::getInstance()->blasFallback())
return false;
return _hasDgemmBatch;
}

View File

@ -162,6 +162,11 @@ namespace sd {
// still do nothing
}
}
const char* blas_fallback = std::getenv("SD_BLAS_FALLBACK");
if (blas_fallback != nullptr) {
_blasFallback = true;
}
#endif
#ifdef __CUDABLAS__
@ -189,6 +194,10 @@ namespace sd {
#endif
}
bool sd::Environment::blasFallback() {
return _blasFallback;
}
sd::Environment::~Environment() {
//
}

View File

@ -51,6 +51,8 @@ namespace sd{
std::atomic<int64_t> _maxTotalSpecialMemory{-1};
std::atomic<int64_t> _maxDeviceMemory{-1};
bool _blasFallback = false;
#ifdef __ND4J_EXPERIMENTAL__
const bool _experimental = true;
#else
@ -86,6 +88,8 @@ namespace sd{
bool helpersAllowed();
void allowHelpers(bool reallyAllow);
bool blasFallback();
int tadThreshold();
void setTadThreshold(int threshold);