blas fallback (#291)

Signed-off-by: raver119 <raver119@gmail.com>

Co-authored-by: raver119 <raver119@gmail.com>
master
raver119 2020-03-05 14:11:13 +03:00 committed by GitHub
parent 784a2d13f8
commit 2911da061b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 43 additions and 0 deletions

View File

@ -74,6 +74,9 @@ namespace sd {
template <> template <>
bool BlasHelper::hasGEMV<float>() { bool BlasHelper::hasGEMV<float>() {
if (sd::Environment::getInstance()->blasFallback())
return false;
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
return true; return true;
#else #else
@ -83,6 +86,9 @@ namespace sd {
template <> template <>
bool BlasHelper::hasGEMV<double>() { bool BlasHelper::hasGEMV<double>() {
if (sd::Environment::getInstance()->blasFallback())
return false;
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
return true; return true;
#else #else
@ -132,6 +138,9 @@ namespace sd {
bool BlasHelper::hasGEMV(const sd::DataType dtype) { bool BlasHelper::hasGEMV(const sd::DataType dtype) {
if(dtype == DataType::FLOAT32) { if(dtype == DataType::FLOAT32) {
if (sd::Environment::getInstance()->blasFallback())
return false;
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
return true; return true;
#else #else
@ -139,6 +148,9 @@ namespace sd {
#endif #endif
} }
if(dtype == DataType::DOUBLE) { if(dtype == DataType::DOUBLE) {
if (sd::Environment::getInstance()->blasFallback())
return false;
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
return true; return true;
#else #else
@ -150,6 +162,9 @@ namespace sd {
template <> template <>
bool BlasHelper::hasGEMM<float>() { bool BlasHelper::hasGEMM<float>() {
if (sd::Environment::getInstance()->blasFallback())
return false;
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
return true; return true;
#else #else
@ -159,6 +174,9 @@ namespace sd {
template <> template <>
bool BlasHelper::hasGEMM<double>() { bool BlasHelper::hasGEMM<double>() {
if (sd::Environment::getInstance()->blasFallback())
return false;
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
return true; return true;
#else #else
@ -208,6 +226,9 @@ namespace sd {
bool BlasHelper:: hasGEMM(const sd::DataType dtype) { bool BlasHelper:: hasGEMM(const sd::DataType dtype) {
if(dtype == DataType::FLOAT32) { if(dtype == DataType::FLOAT32) {
if (sd::Environment::getInstance()->blasFallback())
return false;
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
return true; return true;
#else #else
@ -215,6 +236,9 @@ namespace sd {
#endif #endif
} }
if(dtype == DataType::DOUBLE) { if(dtype == DataType::DOUBLE) {
if (sd::Environment::getInstance()->blasFallback())
return false;
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
return true; return true;
#else #else
@ -227,11 +251,17 @@ namespace sd {
template <> template <>
bool BlasHelper::hasBatchedGEMM<float>() { bool BlasHelper::hasBatchedGEMM<float>() {
if (sd::Environment::getInstance()->blasFallback())
return false;
return _hasSgemmBatch; return _hasSgemmBatch;
} }
template <> template <>
bool BlasHelper::hasBatchedGEMM<double>() { bool BlasHelper::hasBatchedGEMM<double>() {
if (sd::Environment::getInstance()->blasFallback())
return false;
return _hasDgemmBatch; return _hasDgemmBatch;
} }

View File

@ -162,6 +162,11 @@ namespace sd {
// still do nothing // still do nothing
} }
} }
const char* blas_fallback = std::getenv("SD_BLAS_FALLBACK");
if (blas_fallback != nullptr) {
_blasFallback = true;
}
#endif #endif
#ifdef __CUDABLAS__ #ifdef __CUDABLAS__
@ -189,6 +194,10 @@ namespace sd {
#endif #endif
} }
bool sd::Environment::blasFallback() {
return _blasFallback;
}
sd::Environment::~Environment() { sd::Environment::~Environment() {
// //
} }

View File

@ -51,6 +51,8 @@ namespace sd{
std::atomic<int64_t> _maxTotalSpecialMemory{-1}; std::atomic<int64_t> _maxTotalSpecialMemory{-1};
std::atomic<int64_t> _maxDeviceMemory{-1}; std::atomic<int64_t> _maxDeviceMemory{-1};
bool _blasFallback = false;
#ifdef __ND4J_EXPERIMENTAL__ #ifdef __ND4J_EXPERIMENTAL__
const bool _experimental = true; const bool _experimental = true;
#else #else
@ -85,6 +87,8 @@ namespace sd{
void setLeaksDetector(bool reallyDetect); void setLeaksDetector(bool reallyDetect);
bool helpersAllowed(); bool helpersAllowed();
void allowHelpers(bool reallyAllow); void allowHelpers(bool reallyAllow);
bool blasFallback();
int tadThreshold(); int tadThreshold();
void setTadThreshold(int threshold); void setTadThreshold(int threshold);