cavis/cavis-dnn/cavis-dnn-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java

202 lines
8.2 KiB
Java

/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j;
import ch.qos.logback.classic.LoggerContext;
import lombok.extern.slf4j.Slf4j;
import org.bytedeco.javacpp.Pointer;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assumptions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Timeout;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.config.ND4JSystemProperties;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.profiler.ProfilerConfig;
import org.slf4j.ILoggerFactory;
import org.slf4j.LoggerFactory;
import java.lang.management.ManagementFactory;
import java.util.List;
import java.util.Map;
import java.util.Properties;
@Slf4j
@Timeout(60*10)
public abstract class BaseDL4JTest {
protected long startTime;
protected int threadCountBefore;
private final int DEFAULT_THREADS = Runtime.getRuntime().availableProcessors();
/**
* Override this to specify the number of threads for C++ execution, via
* {@link org.nd4j.linalg.factory.Environment#setMaxMasterThreads(int)}
* @return Number of threads to use for C++ op execution
*/
public int numThreads(){
return DEFAULT_THREADS;
}
/**
* Override this to set the profiling mode for the tests defined in the child class
*/
public OpExecutioner.ProfilingMode getProfilingMode(){
return OpExecutioner.ProfilingMode.SCOPE_PANIC;
}
/**
* Override this to set the datatype of the tests defined in the child class
*/
public DataType getDataType(){
return DataType.DOUBLE;
}
public DataType getDefaultFPDataType(){
return getDataType();
}
protected static Boolean integrationTest;
/**
* @return True if integration tests maven profile is enabled, false otherwise.
*/
public static boolean isIntegrationTests(){
if(integrationTest == null){
String prop = System.getenv("DL4J_INTEGRATION_TESTS");
integrationTest = Boolean.parseBoolean(prop);
}
return integrationTest;
}
/**
* Call this as the first line of a test in order to skip that test, only when the integration tests maven profile is not enabled.
* This can be used to dynamically skip integration tests when the integration test profile is not enabled.
* Note that the integration test profile is not enabled by default - "integration-tests" profile
*/
public static void skipUnlessIntegrationTests(){
Assumptions.assumeTrue(isIntegrationTests(), () -> "Skipping integration test - integration profile is not enabled");
}
@BeforeEach
public void beforeTest(){
log.info("{}:", getClass().getName());
//log.info("{}.{}", getClass().getSimpleName());
//Suppress ND4J initialization - don't need this logged for every test...
System.setProperty(ND4JSystemProperties.LOG_INITIALIZATION, "false");
System.setProperty(ND4JSystemProperties.ND4J_IGNORE_AVX, "true");
Nd4j.getExecutioner().setProfilingMode(getProfilingMode());
Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build());
Nd4j.setDefaultDataTypes(getDataType(), getDefaultFPDataType());
Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build());
Nd4j.getExecutioner().enableDebugMode(false);
Nd4j.getExecutioner().enableVerboseMode(false);
int numThreads = numThreads();
Preconditions.checkState(numThreads > 0, "Number of threads must be > 0");
if(numThreads != Nd4j.getEnvironment().maxMasterThreads()) {
Nd4j.getEnvironment().setMaxMasterThreads(numThreads);
}
startTime = System.currentTimeMillis();
threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount();
}
@AfterEach
public void afterTest(){
//Attempt to keep workspaces isolated between tests
Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread();
MemoryWorkspace currWS = Nd4j.getMemoryManager().getCurrentWorkspace();
Nd4j.getMemoryManager().setCurrentWorkspace(null);
if(currWS != null){
//Not really safe to continue testing under this situation... other tests will likely fail with obscure
// errors that are hard to track back to this
log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}", currWS.getId(), currWS.isScopeActive(), currWS);
System.out.println("Open workspace leaked from test! Exiting - " + currWS.getId() + ", isOpen = " + currWS.isScopeActive() + " - " + currWS);
System.out.flush();
//Try to flush logs also:
try{ Thread.sleep(1000); } catch (InterruptedException e){ }
ILoggerFactory lf = LoggerFactory.getILoggerFactory();
if( lf instanceof LoggerContext){
((LoggerContext)lf).stop();
}
try{ Thread.sleep(1000); } catch (InterruptedException e){ }
System.exit(1);
}
StringBuilder sb = new StringBuilder();
long maxPhys = Pointer.maxPhysicalBytes();
long maxBytes = Pointer.maxBytes();
long currPhys = Pointer.physicalBytes();
long currBytes = Pointer.totalBytes();
long jvmTotal = Runtime.getRuntime().totalMemory();
long jvmMax = Runtime.getRuntime().maxMemory();
int threadsAfter = ManagementFactory.getThreadMXBean().getThreadCount();
long duration = System.currentTimeMillis() - startTime;
sb.append(getClass().getSimpleName()).append(".")
.append(": ").append(duration).append(" ms")
.append(", threadCount: (").append(threadCountBefore).append("->").append(threadsAfter).append(")")
.append(", jvmTotal=").append(jvmTotal)
.append(", jvmMax=").append(jvmMax)
.append(", totalBytes=").append(currBytes).append(", maxBytes=").append(maxBytes)
.append(", currPhys=").append(currPhys).append(", maxPhys=").append(maxPhys);
List<MemoryWorkspace> ws = Nd4j.getWorkspaceManager().getAllWorkspacesForCurrentThread();
if(ws != null && ws.size() > 0){
long currSize = 0;
for(MemoryWorkspace w : ws){
currSize += w.getCurrentSize();
}
if(currSize > 0){
sb.append(", threadWSSize=").append(currSize)
.append(" (").append(ws.size()).append(" WSs)");
}
}
Properties p = Nd4j.getExecutioner().getEnvironmentInformation();
Object o = p.get("cuda.devicesInformation");
if(o instanceof List){
List<Map<String,Object>> l = (List<Map<String, Object>>) o;
if(l.size() > 0) {
sb.append(" [").append(l.size())
.append(" GPUs: ");
for (int i = 0; i < l.size(); i++) {
Map<String,Object> m = l.get(i);
if(i > 0)
sb.append(",");
sb.append("(").append(m.get("cuda.freeMemory")).append(" free, ")
.append(m.get("cuda.totalMemory")).append(" total)");
}
sb.append("]");
}
}
log.info(sb.toString());
}
}