2022-09-20 15:40:53 +02:00

81 lines
2.9 KiB
Java

/*
*
* ******************************************************************************
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*
*/
package org.nd4j.nativeblas;
import lombok.extern.log4j.Log4j2;
import lombok.extern.slf4j.Slf4j;
import org.bytedeco.javacpp.Loader;
import org.nd4j.common.config.ND4JEnvironmentVars;
import org.nd4j.common.config.ND4JSystemProperties;
import org.nd4j.linalg.api.blas.Blas;
/**
 * Base class for BLAS bindings that configures the maximum number of BLAS
 * threads when an instance is constructed and maps the numeric vendor id to a
 * {@link Vendor}.
 */
@Slf4j
public abstract class Nd4jBlas implements Blas {

    /**
     * Configures the maximum number of BLAS threads.
     *
     * <p>Nothing is configured when the {@code ND4J_SKIP_BLAS_THREADS} environment
     * variable is set to a non-empty value. Otherwise the thread count comes from
     * {@code OMP_NUM_THREADS} when it holds a usable positive integer, and from the
     * detected core count when it does not. A malformed {@code OMP_NUM_THREADS}
     * value is reported with a warning instead of failing construction.
     */
    public Nd4jBlas() {
        String skipper = System.getenv(ND4JEnvironmentVars.ND4J_SKIP_BLAS_THREADS);
        if (skipper != null && !skipper.isEmpty()) {
            // Caller explicitly asked us to leave the BLAS thread settings alone.
            return;
        }

        int numThreads = parseThreadCount(System.getenv(ND4JEnvironmentVars.OMP_NUM_THREADS));
        if (numThreads <= 0) {
            numThreads = detectThreadCount();
        }
        setMaxThreads(numThreads);

        // Logging is on by default; it is suppressed only when the property is
        // set to something that does not parse as "true".
        String logInit = System.getProperty(ND4JSystemProperties.LOG_INITIALIZATION);
        boolean initLoggingEnabled = logInit == null || logInit.isEmpty() || Boolean.parseBoolean(logInit);
        if (logOpenMPBlasThreads() && initLoggingEnabled) {
            log.info("Number of threads used for OpenMP BLAS: {}", getMaxThreads());
        }
    }

    /**
     * Parses the value of {@code OMP_NUM_THREADS}.
     *
     * <p>OpenMP allows a comma-separated list (one entry per nesting level); only
     * the first entry is used here. Surrounding whitespace is ignored.
     *
     * @param raw the raw environment value, may be {@code null}
     * @return the positive thread count, or {@code -1} when the value is absent,
     *         malformed or not positive
     */
    private static int parseThreadCount(String raw) {
        if (raw == null || raw.isEmpty()) {
            return -1;
        }
        int comma = raw.indexOf(',');
        String first = (comma >= 0 ? raw.substring(0, comma) : raw).trim();
        try {
            int parsed = Integer.parseInt(first);
            if (parsed > 0) {
                return parsed;
            }
        } catch (NumberFormatException ignored) {
            // Handled below together with non-positive values.
        }
        log.warn("Ignoring invalid {} value '{}'; falling back to the detected core count",
                ND4JEnvironmentVars.OMP_NUM_THREADS, raw);
        return -1;
    }

    /**
     * Derives a thread count from the hardware: cores per chip as reported by
     * JavaCPP's {@link Loader} when both figures are positive, otherwise the value
     * returned by {@code NativeOpsHolder.getCores} for the JVM's available
     * processor count.
     *
     * @return the thread count to use when none was configured explicitly
     */
    private static int detectThreadCount() {
        int cores = Loader.totalCores();
        int chips = Loader.totalChips();
        if (cores > 0 && chips > 0) {
            return Math.max(1, cores / chips);
        }
        return NativeOpsHolder.getCores(Runtime.getRuntime().availableProcessors());
    }

    /**
     * Returns the BLAS library vendor
     *
     * @return the BLAS library vendor, or {@link Vendor#UNKNOWN} when the vendor id
     *         is not a positive index into {@link Vendor#values()}
     */
    @Override
    public Vendor getBlasVendor() {
        int vendor = getBlasVendorId();
        Vendor[] vendors = Vendor.values();
        boolean isUnknownVendor = vendor <= 0 || vendor > vendors.length - 1;
        if (isUnknownVendor) {
            return Vendor.UNKNOWN;
        }
        return vendors[vendor];
    }

    /**
     * Whether the constructor should log the configured BLAS thread count.
     * Subclasses may override this to silence the message.
     *
     * @return {@code true} in this base implementation
     */
    public boolean logOpenMPBlasThreads() {
        return true;
    }
}