/******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ // // Created by raver119 on 06.10.2017. // #ifndef LIBND4J_ENVIRONMENT_H #define LIBND4J_ENVIRONMENT_H #include #include #include #include #include #include namespace nd4j{ class ND4J_EXPORT Environment { private: std::atomic _tadThreshold; std::atomic _elementThreshold; std::atomic _verbose; std::atomic _debug; std::atomic _leaks; std::atomic _profile; std::atomic _maxThreads; std::atomic _dataType; std::atomic _precBoost; std::atomic _useMKLDNN{true}; #ifdef __ND4J_EXPERIMENTAL__ const bool _experimental = true; #else const bool _experimental = false; #endif // device compute capability for CUDA std::vector _capabilities; static Environment* _instance; Environment(); ~Environment(); public: /** * These 3 fields are mostly for CUDA/cuBLAS version tracking */ int _blasMajorVersion = 0; int _blasMinorVersion = 0; int _blasPatchVersion = 0; static Environment* getInstance(); bool isVerbose(); void setVerbose(bool reallyVerbose); bool isDebug(); bool isProfiling(); bool isDetectingLeaks(); bool isDebugAndVerbose(); void setDebug(bool reallyDebug); void setProfiling(bool reallyProfile); void setLeaksDetector(bool reallyDetect); int tadThreshold(); void setTadThreshold(int threshold); int elementwiseThreshold(); void setElementwiseThreshold(int threshold); int maxThreads(); void setMaxThreads(int max); bool isUseMKLDNN() { return _useMKLDNN.load(); } void setUseMKLDNN(bool useMKLDNN) { _useMKLDNN.store(useMKLDNN); } nd4j::DataType defaultFloatDataType(); void setDefaultFloatDataType(nd4j::DataType dtype); bool precisionBoostAllowed(); void allowPrecisionBoost(bool reallyAllow); bool isExperimentalBuild(); bool isCPU(); int blasMajorVersion(); int blasMinorVersion(); int blasPatchVersion(); std::vector& capabilities(); }; } #endif //LIBND4J_ENVIRONMENT_H