/* ******************************************************************************
 *
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * See the NOTICE file distributed with this work for additional
 * information regarding copyright ownership.
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//
|
|
|
|
|
|
|
|
#include <helpers/BlasHelper.h>
|
2020-03-02 10:49:41 +01:00
|
|
|
namespace sd {
|
2020-06-06 14:26:55 +02:00
|
|
|
/**
 * Returns the process-wide BlasHelper.
 *
 * The instance is a function-local static: it is constructed on the first call,
 * and C++11 guarantees that this initialization happens exactly once even when
 * several threads call concurrently.
 */
BlasHelper& BlasHelper::getInstance() {
  static BlasHelper singleton;
  return singleton;
}
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Installs the host BLAS/LAPACKE entry points supplied by the caller.
 *
 * Slot layout of the table, as read below:
 *   [0] CblasSgemv       [1] CblasDgemv
 *   [2] CblasSgemm       [3] CblasDgemm
 *   [4] CblasSgemmBatch  [5] CblasDgemmBatch
 *   [6] LapackeSgesvd    [7] LapackeDgesvd
 *   [8] LapackeSgesdd    [9] LapackeDgesdd
 * A null slot means "routine not available". The _has* flags consulted by
 * hasGEMV / hasGEMM / hasBatchedGEMM are derived from slots 0-5 only.
 *
 * @param functions table holding at least 10 entries, or nullptr. A null table
 *        is treated as "nothing available": every flag is cleared and every
 *        stored pointer reset. Previously it was dereferenced and crashed.
 */
void BlasHelper::initializeFunctions(Nd4jPointer *functions) {
  nd4j_debug("Initializing BLAS\n", "");

  if (functions == nullptr) {
    _hasSgemv = _hasDgemv = false;
    _hasSgemm = _hasDgemm = false;
    _hasSgemmBatch = _hasDgemmBatch = false;

    cblasSgemv = nullptr;
    cblasDgemv = nullptr;
    cblasSgemm = nullptr;
    cblasDgemm = nullptr;
    cblasSgemmBatch = nullptr;
    cblasDgemmBatch = nullptr;
    lapackeSgesvd = nullptr;
    lapackeDgesvd = nullptr;
    lapackeSgesdd = nullptr;
    lapackeDgesdd = nullptr;
    return;
  }

  // Availability flags: a slot is usable iff the caller supplied a non-null pointer.
  _hasSgemv = functions[0] != nullptr;
  _hasDgemv = functions[1] != nullptr;
  _hasSgemm = functions[2] != nullptr;
  _hasDgemm = functions[3] != nullptr;
  _hasSgemmBatch = functions[4] != nullptr;
  _hasDgemmBatch = functions[5] != nullptr;

  // Each opaque slot is reinterpreted as its function-pointer type. This is the
  // same conversion the previous C-style casts performed.
  cblasSgemv = reinterpret_cast<CblasSgemv>(functions[0]);
  cblasDgemv = reinterpret_cast<CblasDgemv>(functions[1]);
  cblasSgemm = reinterpret_cast<CblasSgemm>(functions[2]);
  cblasDgemm = reinterpret_cast<CblasDgemm>(functions[3]);
  cblasSgemmBatch = reinterpret_cast<CblasSgemmBatch>(functions[4]);
  cblasDgemmBatch = reinterpret_cast<CblasDgemmBatch>(functions[5]);
  lapackeSgesvd = reinterpret_cast<LapackeSgesvd>(functions[6]);
  lapackeDgesvd = reinterpret_cast<LapackeDgesvd>(functions[7]);
  lapackeSgesdd = reinterpret_cast<LapackeSgesdd>(functions[8]);
  lapackeDgesdd = reinterpret_cast<LapackeDgesdd>(functions[9]);
}
|
|
|
|
|
|
|
|
/**
 * Hook for registering device-side BLAS / solver entry points.
 *
 * Currently a no-op apart from the debug message: the registration code is
 * disabled, so the supplied table is never read. The slot layout that the
 * disabled code expected was:
 *   [0] CublasSgemv                 [1] CublasDgemv
 *   [2] CublasHgemm                 [3] CublasSgemm
 *   [4] CublasDgemm                 [5] CublasSgemmEx
 *   [6] CublasHgemmBatched          [7] CublasSgemmBatched
 *   [8] CublasDgemmBatched          [9] CusolverDnSgesvdBufferSize
 *   [10] CusolverDnDgesvdBufferSize [11] CusolverDnSgesvd
 *   [12] CusolverDnDgesvd
 */
void BlasHelper::initializeDeviceFunctions(Nd4jPointer *functions) {
  nd4j_debug("Initializing device BLAS\n", "");
  (void)functions;  // intentionally unused while device registration is disabled
}
|
|
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
bool BlasHelper::hasGEMV<float>() {
|
2020-06-06 14:26:55 +02:00
|
|
|
if (sd::Environment::getInstance().blasFallback())
|
2020-03-05 12:11:13 +01:00
|
|
|
return false;
|
|
|
|
|
2019-11-13 15:15:18 +01:00
|
|
|
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
|
2019-06-06 14:21:15 +02:00
|
|
|
return true;
|
|
|
|
#else
|
|
|
|
return _hasSgemv;
|
|
|
|
#endif
|
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
bool BlasHelper::hasGEMV<double>() {
|
2020-06-06 14:26:55 +02:00
|
|
|
if (sd::Environment::getInstance().blasFallback())
|
2020-03-05 12:11:13 +01:00
|
|
|
return false;
|
|
|
|
|
2019-11-13 15:15:18 +01:00
|
|
|
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
|
2019-06-06 14:21:15 +02:00
|
|
|
return true;
|
|
|
|
#else
|
|
|
|
return _hasDgemv;
|
|
|
|
#endif
|
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
bool BlasHelper::hasGEMV<float16>() {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
bool BlasHelper::hasGEMV<bfloat16>() {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
bool BlasHelper::hasGEMV<bool>() {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
bool BlasHelper::hasGEMV<int>() {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
bool BlasHelper::hasGEMV<int8_t>() {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
bool BlasHelper::hasGEMV<uint8_t>() {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
bool BlasHelper::hasGEMV<int16_t>() {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
bool BlasHelper::hasGEMV<Nd4jLong>() {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/**
 * Runtime-dtype counterpart of hasGEMV<T>().
 *
 * FLOAT32 and DOUBLE forward to the float / double specializations defined
 * above, which apply the same fallback and availability rules. Every other
 * dtype reports false.
 */
bool BlasHelper::hasGEMV(const sd::DataType dtype) {
  switch (dtype) {
    case DataType::FLOAT32:
      return hasGEMV<float>();
    case DataType::DOUBLE:
      return hasGEMV<double>();
    default:
      return false;
  }
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
bool BlasHelper::hasGEMM<float>() {
|
2020-06-06 14:26:55 +02:00
|
|
|
if (sd::Environment::getInstance().blasFallback())
|
2020-03-05 12:11:13 +01:00
|
|
|
return false;
|
|
|
|
|
2019-11-13 15:15:18 +01:00
|
|
|
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
|
2019-06-06 14:21:15 +02:00
|
|
|
return true;
|
|
|
|
#else
|
|
|
|
return _hasSgemm;
|
|
|
|
#endif
|
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
bool BlasHelper::hasGEMM<double>() {
|
2020-06-06 14:26:55 +02:00
|
|
|
if (sd::Environment::getInstance().blasFallback())
|
2020-03-05 12:11:13 +01:00
|
|
|
return false;
|
|
|
|
|
2019-11-13 15:15:18 +01:00
|
|
|
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
|
2019-06-06 14:21:15 +02:00
|
|
|
return true;
|
|
|
|
#else
|
|
|
|
return _hasDgemm;
|
|
|
|
#endif
|
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
bool BlasHelper::hasGEMM<float16>() {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
bool BlasHelper::hasGEMM<bfloat16>() {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
bool BlasHelper::hasGEMM<int>() {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
bool BlasHelper::hasGEMM<uint8_t>() {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
bool BlasHelper::hasGEMM<int8_t>() {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
bool BlasHelper::hasGEMM<int16_t>() {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
bool BlasHelper::hasGEMM<bool>() {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
bool BlasHelper::hasGEMM<Nd4jLong>() {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
2020-03-02 10:49:41 +01:00
|
|
|
/**
 * Runtime-dtype counterpart of hasGEMM<T>().
 *
 * FLOAT32 and DOUBLE forward to the float / double specializations defined
 * above, which apply the same fallback and availability rules. Every other
 * dtype reports false.
 */
bool BlasHelper::hasGEMM(const sd::DataType dtype) {
  switch (dtype) {
    case DataType::FLOAT32:
      return hasGEMM<float>();
    case DataType::DOUBLE:
      return hasGEMM<double>();
    default:
      return false;
  }
}
|
|
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
bool BlasHelper::hasBatchedGEMM<float>() {
|
2020-06-06 14:26:55 +02:00
|
|
|
if (sd::Environment::getInstance().blasFallback())
|
2020-03-05 12:11:13 +01:00
|
|
|
return false;
|
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
return _hasSgemmBatch;
|
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
bool BlasHelper::hasBatchedGEMM<double>() {
|
2020-06-06 14:26:55 +02:00
|
|
|
if (sd::Environment::getInstance().blasFallback())
|
2020-03-05 12:11:13 +01:00
|
|
|
return false;
|
|
|
|
|
2019-06-06 14:21:15 +02:00
|
|
|
return _hasDgemmBatch;
|
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
bool BlasHelper::hasBatchedGEMM<float16>() {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
bool BlasHelper::hasBatchedGEMM<bfloat16>() {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
bool BlasHelper::hasBatchedGEMM<Nd4jLong>() {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
bool BlasHelper::hasBatchedGEMM<int>() {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
bool BlasHelper::hasBatchedGEMM<int8_t>() {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
bool BlasHelper::hasBatchedGEMM<uint8_t>() {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
bool BlasHelper::hasBatchedGEMM<int16_t>() {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
bool BlasHelper::hasBatchedGEMM<bool>() {
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
|
|
|
CblasSgemv BlasHelper::sgemv() {
|
2019-11-13 15:15:18 +01:00
|
|
|
#if defined(__EXTERNAL_BLAS__)|| defined(HAVE_OPENBLAS)
|
2019-06-06 14:21:15 +02:00
|
|
|
return (CblasSgemv)&cblas_sgemv;
|
|
|
|
#else
|
|
|
|
return this->cblasSgemv;
|
|
|
|
#endif
|
|
|
|
}
|
|
|
|
CblasDgemv BlasHelper::dgemv() {
|
2019-11-13 15:15:18 +01:00
|
|
|
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
|
2019-06-06 14:21:15 +02:00
|
|
|
return (CblasDgemv)&cblas_dgemv;
|
|
|
|
#else
|
|
|
|
return this->cblasDgemv;
|
|
|
|
#endif
|
|
|
|
}
|
|
|
|
|
|
|
|
CblasSgemm BlasHelper::sgemm() {
|
2019-11-13 15:15:18 +01:00
|
|
|
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
|
2019-06-06 14:21:15 +02:00
|
|
|
return (CblasSgemm)&cblas_sgemm;
|
|
|
|
#else
|
|
|
|
return this->cblasSgemm;
|
|
|
|
#endif
|
|
|
|
}
|
|
|
|
|
|
|
|
CblasDgemm BlasHelper::dgemm() {
|
2019-11-13 15:15:18 +01:00
|
|
|
#if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS)
|
2019-06-06 14:21:15 +02:00
|
|
|
return (CblasDgemm)&cblas_dgemm;
|
|
|
|
#else
|
|
|
|
return this->cblasDgemm;
|
|
|
|
#endif
|
|
|
|
}
|
|
|
|
|
|
|
|
CblasSgemmBatch BlasHelper::sgemmBatched() {
|
|
|
|
return this->cblasSgemmBatch;
|
|
|
|
}
|
|
|
|
|
|
|
|
CblasDgemmBatch BlasHelper::dgemmBatched() {
|
|
|
|
return this->cblasDgemmBatch;
|
|
|
|
}
|
|
|
|
|
|
|
|
LapackeSgesvd BlasHelper::sgesvd() {
|
|
|
|
return this->lapackeSgesvd;
|
|
|
|
}
|
|
|
|
|
|
|
|
LapackeDgesvd BlasHelper::dgesvd() {
|
|
|
|
return this->lapackeDgesvd;
|
|
|
|
}
|
|
|
|
|
|
|
|
LapackeSgesdd BlasHelper::sgesdd() {
|
|
|
|
return this->lapackeSgesdd;
|
|
|
|
}
|
|
|
|
|
|
|
|
LapackeDgesdd BlasHelper::dgesdd() {
|
|
|
|
return this->lapackeDgesdd;
|
|
|
|
}
|
|
|
|
|
|
|
|
// destructor
|
|
|
|
BlasHelper::~BlasHelper() noexcept { }
|
|
|
|
}
|