202 lines
8.2 KiB
Java
202 lines
8.2 KiB
Java
/*
|
|
* ******************************************************************************
|
|
* *
|
|
* *
|
|
* * This program and the accompanying materials are made available under the
|
|
* * terms of the Apache License, Version 2.0 which is available at
|
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
|
* *
|
|
* * See the NOTICE file distributed with this work for additional
|
|
* * information regarding copyright ownership.
|
|
* * Unless required by applicable law or agreed to in writing, software
|
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
* * License for the specific language governing permissions and limitations
|
|
* * under the License.
|
|
* *
|
|
* * SPDX-License-Identifier: Apache-2.0
|
|
* *****************************************************************************
|
|
*/
|
|
|
|
package org.deeplearning4j;
|
|
|
|
import ch.qos.logback.classic.LoggerContext;
|
|
import lombok.extern.slf4j.Slf4j;
|
|
import org.bytedeco.javacpp.Pointer;
|
|
import org.junit.jupiter.api.AfterEach;
|
|
import org.junit.jupiter.api.Assumptions;
|
|
import org.junit.jupiter.api.BeforeEach;
|
|
import org.junit.jupiter.api.Timeout;
|
|
import org.nd4j.common.base.Preconditions;
|
|
import org.nd4j.common.config.ND4JSystemProperties;
|
|
import org.nd4j.linalg.api.buffer.DataType;
|
|
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
|
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
|
import org.nd4j.linalg.factory.Nd4j;
|
|
import org.nd4j.linalg.profiler.ProfilerConfig;
|
|
import org.slf4j.ILoggerFactory;
|
|
import org.slf4j.LoggerFactory;
|
|
|
|
import java.lang.management.ManagementFactory;
|
|
import java.util.List;
|
|
import java.util.Map;
|
|
import java.util.Properties;
|
|
|
|
@Slf4j
|
|
@Timeout(60*10)
|
|
public abstract class BaseDL4JTest {
|
|
|
|
protected long startTime;
|
|
protected int threadCountBefore;
|
|
|
|
private final int DEFAULT_THREADS = Runtime.getRuntime().availableProcessors();
|
|
|
|
/**
|
|
* Override this to specify the number of threads for C++ execution, via
|
|
* {@link org.nd4j.linalg.factory.Environment#setMaxMasterThreads(int)}
|
|
* @return Number of threads to use for C++ op execution
|
|
*/
|
|
public int numThreads(){
|
|
return DEFAULT_THREADS;
|
|
}
|
|
|
|
/**
|
|
* Override this to set the profiling mode for the tests defined in the child class
|
|
*/
|
|
public OpExecutioner.ProfilingMode getProfilingMode(){
|
|
return OpExecutioner.ProfilingMode.SCOPE_PANIC;
|
|
}
|
|
|
|
/**
|
|
* Override this to set the datatype of the tests defined in the child class
|
|
*/
|
|
public DataType getDataType(){
|
|
return DataType.DOUBLE;
|
|
}
|
|
|
|
public DataType getDefaultFPDataType(){
|
|
return getDataType();
|
|
}
|
|
|
|
protected static Boolean integrationTest;
|
|
|
|
/**
|
|
* @return True if integration tests maven profile is enabled, false otherwise.
|
|
*/
|
|
public static boolean isIntegrationTests(){
|
|
if(integrationTest == null){
|
|
String prop = System.getenv("DL4J_INTEGRATION_TESTS");
|
|
integrationTest = Boolean.parseBoolean(prop);
|
|
}
|
|
return integrationTest;
|
|
}
|
|
|
|
/**
|
|
* Call this as the first line of a test in order to skip that test, only when the integration tests maven profile is not enabled.
|
|
* This can be used to dynamically skip integration tests when the integration test profile is not enabled.
|
|
* Note that the integration test profile is not enabled by default - "integration-tests" profile
|
|
*/
|
|
public static void skipUnlessIntegrationTests(){
|
|
Assumptions.assumeTrue(isIntegrationTests(), () -> "Skipping integration test - integration profile is not enabled");
|
|
}
|
|
|
|
@BeforeEach
|
|
public void beforeTest(){
|
|
log.info("{}:", getClass().getName());
|
|
//log.info("{}.{}", getClass().getSimpleName());
|
|
//Suppress ND4J initialization - don't need this logged for every test...
|
|
System.setProperty(ND4JSystemProperties.LOG_INITIALIZATION, "false");
|
|
System.setProperty(ND4JSystemProperties.ND4J_IGNORE_AVX, "true");
|
|
Nd4j.getExecutioner().setProfilingMode(getProfilingMode());
|
|
Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build());
|
|
Nd4j.setDefaultDataTypes(getDataType(), getDefaultFPDataType());
|
|
Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build());
|
|
Nd4j.getExecutioner().enableDebugMode(false);
|
|
Nd4j.getExecutioner().enableVerboseMode(false);
|
|
int numThreads = numThreads();
|
|
Preconditions.checkState(numThreads > 0, "Number of threads must be > 0");
|
|
if(numThreads != Nd4j.getEnvironment().maxMasterThreads()) {
|
|
Nd4j.getEnvironment().setMaxMasterThreads(numThreads);
|
|
}
|
|
startTime = System.currentTimeMillis();
|
|
threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount();
|
|
}
|
|
|
|
@AfterEach
|
|
public void afterTest(){
|
|
//Attempt to keep workspaces isolated between tests
|
|
Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread();
|
|
MemoryWorkspace currWS = Nd4j.getMemoryManager().getCurrentWorkspace();
|
|
Nd4j.getMemoryManager().setCurrentWorkspace(null);
|
|
if(currWS != null){
|
|
//Not really safe to continue testing under this situation... other tests will likely fail with obscure
|
|
// errors that are hard to track back to this
|
|
log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}", currWS.getId(), currWS.isScopeActive(), currWS);
|
|
System.out.println("Open workspace leaked from test! Exiting - " + currWS.getId() + ", isOpen = " + currWS.isScopeActive() + " - " + currWS);
|
|
System.out.flush();
|
|
//Try to flush logs also:
|
|
try{ Thread.sleep(1000); } catch (InterruptedException e){ }
|
|
ILoggerFactory lf = LoggerFactory.getILoggerFactory();
|
|
if( lf instanceof LoggerContext){
|
|
((LoggerContext)lf).stop();
|
|
}
|
|
try{ Thread.sleep(1000); } catch (InterruptedException e){ }
|
|
System.exit(1);
|
|
}
|
|
|
|
StringBuilder sb = new StringBuilder();
|
|
long maxPhys = Pointer.maxPhysicalBytes();
|
|
long maxBytes = Pointer.maxBytes();
|
|
long currPhys = Pointer.physicalBytes();
|
|
long currBytes = Pointer.totalBytes();
|
|
|
|
long jvmTotal = Runtime.getRuntime().totalMemory();
|
|
long jvmMax = Runtime.getRuntime().maxMemory();
|
|
|
|
int threadsAfter = ManagementFactory.getThreadMXBean().getThreadCount();
|
|
|
|
long duration = System.currentTimeMillis() - startTime;
|
|
sb.append(getClass().getSimpleName()).append(".")
|
|
.append(": ").append(duration).append(" ms")
|
|
.append(", threadCount: (").append(threadCountBefore).append("->").append(threadsAfter).append(")")
|
|
.append(", jvmTotal=").append(jvmTotal)
|
|
.append(", jvmMax=").append(jvmMax)
|
|
.append(", totalBytes=").append(currBytes).append(", maxBytes=").append(maxBytes)
|
|
.append(", currPhys=").append(currPhys).append(", maxPhys=").append(maxPhys);
|
|
|
|
List<MemoryWorkspace> ws = Nd4j.getWorkspaceManager().getAllWorkspacesForCurrentThread();
|
|
if(ws != null && ws.size() > 0){
|
|
long currSize = 0;
|
|
for(MemoryWorkspace w : ws){
|
|
currSize += w.getCurrentSize();
|
|
}
|
|
if(currSize > 0){
|
|
sb.append(", threadWSSize=").append(currSize)
|
|
.append(" (").append(ws.size()).append(" WSs)");
|
|
}
|
|
}
|
|
|
|
|
|
Properties p = Nd4j.getExecutioner().getEnvironmentInformation();
|
|
Object o = p.get("cuda.devicesInformation");
|
|
if(o instanceof List){
|
|
List<Map<String,Object>> l = (List<Map<String, Object>>) o;
|
|
if(l.size() > 0) {
|
|
|
|
sb.append(" [").append(l.size())
|
|
.append(" GPUs: ");
|
|
|
|
for (int i = 0; i < l.size(); i++) {
|
|
Map<String,Object> m = l.get(i);
|
|
if(i > 0)
|
|
sb.append(",");
|
|
sb.append("(").append(m.get("cuda.freeMemory")).append(" free, ")
|
|
.append(m.get("cuda.totalMemory")).append(" total)");
|
|
}
|
|
sb.append("]");
|
|
}
|
|
}
|
|
log.info(sb.toString());
|
|
}
|
|
}
|