simplify TrainModule to use getters of VertxUIServer instance for multi-session mode, session loader and server address

Signed-off-by: Tamás Fenyvesi <tamas.fenyvesi@doknet.hu>
master
Tamás Fenyvesi 2020-04-23 15:05:29 +02:00
parent ef7c21c204
commit f1ebced7a1
2 changed files with 42 additions and 73 deletions

View File

@ -77,8 +77,6 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
private static Integer instancePort;
private static Thread autoStopThread;
private TrainModule trainModule;
/**
* Get (and, initialize if necessary) the UI server. This synchronous function will wait until the server started.
* @param port TCP socket port for {@link HttpServer} to listen
@ -208,7 +206,11 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
private List<UIModule> uiModules = new CopyOnWriteArrayList<>();
private RemoteReceiverModule remoteReceiverModule;
private StatsStorageLoader statsStorageLoader;
/**
* Loader that attaches {@code StatsStorage} provided by {@code #statsStorageProvider} for the given session ID
*/
@Getter
private Function<String, Boolean> statsStorageLoader;
//typeIDModuleMap: Records which modules are registered for which type IDs
private Map<String, List<UIModule>> typeIDModuleMap = new ConcurrentHashMap<>();
@ -248,10 +250,23 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
*/
public void autoAttachStatsStorageBySessionId(Function<String, StatsStorage> statsStorageProvider) {
if (statsStorageProvider != null) {
this.statsStorageLoader = new StatsStorageLoader(statsStorageProvider);
if (trainModule != null) {
this.trainModule.setSessionLoader(this.statsStorageLoader);
}
this.statsStorageLoader = (sessionId) -> {
log.info("Loading StatsStorage via StatsStorageProvider for session ID (" + sessionId + ").");
StatsStorage statsStorage = statsStorageProvider.apply(sessionId);
if (statsStorage != null) {
if (statsStorage.sessionExists(sessionId)) {
attach(statsStorage);
return true;
}
log.info("Failed to load StatsStorage via StatsStorageProvider for session ID. " +
"Session ID (" + sessionId + ") does not exist in StatsStorage.");
return false;
} else {
log.info("Failed to load StatsStorage via StatsStorageProvider for session ID (" + sessionId + "). " +
"StatsStorageProvider returned null.");
return false;
}
};
}
}
@ -303,8 +318,7 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
}
uiModules.add(new DefaultModule(isMultiSession())); //For: navigation page "/"
trainModule = new TrainModule(isMultiSession(), statsStorageLoader, this::getAddress);
uiModules.add(trainModule);
uiModules.add(new TrainModule());
uiModules.add(new ConvolutionalListenerModule());
uiModules.add(new TsneModule());
uiModules.add(new SameDiffModule());
@ -597,37 +611,6 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
}
}
/**
* Loader that attaches {@code StatsStorage} provided by {@code #statsStorageProvider} for the given session ID
*/
private class StatsStorageLoader implements Function<String, Boolean> {
Function<String, StatsStorage> statsStorageProvider;
StatsStorageLoader(Function<String, StatsStorage> statsStorageProvider) {
this.statsStorageProvider = statsStorageProvider;
}
@Override
public Boolean apply(String sessionId) {
log.info("Loading StatsStorage via StatsStorageProvider for session ID (" + sessionId + ").");
StatsStorage statsStorage = statsStorageProvider.apply(sessionId);
if (statsStorage != null) {
if (statsStorage.sessionExists(sessionId)) {
attach(statsStorage);
return true;
}
log.info("Failed to load StatsStorage via StatsStorageProvider for session ID. " +
"Session ID (" + sessionId + ") does not exist in StatsStorage.");
return false;
} else {
log.info("Failed to load StatsStorage via StatsStorageProvider for session ID (" + sessionId + "). " +
"StatsStorageProvider returned null.");
return false;
}
}
}
//==================================================================================================================
// CLI Launcher

View File

@ -26,8 +26,6 @@ import io.vertx.ext.web.RoutingContext;
import it.unimi.dsi.fastutil.longs.LongArrayList;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
@ -43,6 +41,7 @@ import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.graph.LayerVertex;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.serde.JsonMappers;
import org.deeplearning4j.ui.VertxUIServer;
import org.deeplearning4j.ui.api.HttpMethod;
import org.deeplearning4j.ui.api.I18N;
import org.deeplearning4j.ui.api.Route;
@ -56,7 +55,6 @@ import org.deeplearning4j.ui.model.stats.api.StatsInitializationReport;
import org.deeplearning4j.ui.model.stats.api.StatsReport;
import org.deeplearning4j.ui.model.stats.api.StatsType;
import org.nd4j.common.function.Function;
import org.nd4j.common.function.Supplier;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.primitives.Triple;
@ -86,8 +84,6 @@ public class TrainModule implements UIModule {
private static final DecimalFormat df2 = new DecimalFormat("#.00");
private static DateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
private final Supplier<String> addressSupplier;
private enum ModelType {
MLN, CG, Layer
}
@ -99,29 +95,14 @@ public class TrainModule implements UIModule {
private Map<String, AtomicInteger> workerIdxCount = new ConcurrentHashMap<>(); //Key: session ID
private Map<String, Map<Integer, String>> workerIdxToName = new ConcurrentHashMap<>(); //Key: session ID
private Map<String, Long> lastUpdateForSession = new ConcurrentHashMap<>();
private final boolean multiSession;
@Getter @Setter
private Function<String, Boolean> sessionLoader;
private final Configuration configuration;
public TrainModule() {
this(false, null, null);
}
/**
* TrainModule
*
* @param multiSession multi-session mode
* @param sessionLoader StatsStorage loader to call if an unknown session ID is passed as URL path parameter
* in multi-session mode
* @param addressSupplier supplier for server address (server address in PlayUIServer gets initialized after modules)
*/
public TrainModule(boolean multiSession, Function<String, Boolean> sessionLoader, Supplier<String> addressSupplier) {
this.multiSession = multiSession;
this.sessionLoader = sessionLoader;
this.addressSupplier = addressSupplier;
public TrainModule() {
String maxChartPointsProp = System.getProperty(DL4JSystemProperties.CHART_MAX_POINTS_PROPERTY);
int value = DEFAULT_MAX_CHART_POINTS;
if (maxChartPointsProp != null) {
@ -159,8 +140,9 @@ public class TrainModule implements UIModule {
@Override
public List<Route> getRoutes() {
List<Route> r = new ArrayList<>();
r.add(new Route("/train/multisession", HttpMethod.GET, (path, rc) -> rc.response().end(multiSession ? "true" : "false")));
if (multiSession) {
r.add(new Route("/train/multisession", HttpMethod.GET,
(path, rc) -> rc.response().end(VertxUIServer.getInstance().isMultiSession() ? "true" : "false")));
if (VertxUIServer.getInstance().isMultiSession()) {
r.add(new Route("/train", HttpMethod.GET, (path, rc) -> this.listSessions(rc)));
r.add(new Route("/train/:sessionId", HttpMethod.GET, (path, rc) -> {
rc.response()
@ -264,7 +246,9 @@ public class TrainModule implements UIModule {
if (!knownSessionIDs.isEmpty()) {
sb.append(" <ul>");
for (String sessionId : knownSessionIDs.keySet()) {
sb.append(" <li><a href=\"train/").append(sessionId).append("\">").append(sessionId).append("</a></li>\n");
sb.append(" <li><a href=\"train/")
.append(sessionId).append("\">")
.append(sessionId).append("</a></li>\n");
}
sb.append(" </ul>");
} else {
@ -284,9 +268,11 @@ public class TrainModule implements UIModule {
*
* @param sessionId session ID to look fo with provider
* @param targetPath one of overview / model / system, or null
* @param rc routing context
*/
private void sessionNotFound(String sessionId, String targetPath, RoutingContext rc) {
if (sessionLoader != null && sessionLoader.apply(sessionId)) {
Function<String, Boolean> loader = VertxUIServer.getInstance().getStatsStorageLoader();
if (loader != null && loader.apply(sessionId)) {
if (targetPath != null) {
rc.reroute(targetPath);
} else {
@ -306,9 +292,9 @@ public class TrainModule implements UIModule {
&& StatsListener.TYPE_ID.equals(sse.getTypeID())
&& !knownSessionIDs.containsKey(sse.getSessionID())) {
knownSessionIDs.put(sse.getSessionID(), sse.getStatsStorage());
if (multiSession) {
if (VertxUIServer.getInstance().isMultiSession()) {
log.info("Adding training session {}/train/{} of StatsStorage instance {}",
addressSupplier.get(), sse.getSessionID(), sse.getStatsStorage());
VertxUIServer.getInstance().getAddress(), sse.getSessionID(), sse.getStatsStorage());
}
}
@ -332,9 +318,9 @@ public class TrainModule implements UIModule {
if (!StatsListener.TYPE_ID.equals(typeID))
continue;
knownSessionIDs.put(sessionID, statsStorage);
if (multiSession) {
if (VertxUIServer.getInstance().isMultiSession()) {
log.info("Adding training session {}/train/{} of StatsStorage instance {}",
addressSupplier.get(), sessionID, statsStorage);
VertxUIServer.getInstance().getAddress(), sessionID, statsStorage);
}
List<Persistable> latestUpdates = statsStorage.getLatestUpdateAllWorkers(sessionID, typeID);
@ -364,9 +350,9 @@ public class TrainModule implements UIModule {
}
for (String s : toRemove) {
knownSessionIDs.remove(s);
if (multiSession) {
if (VertxUIServer.getInstance().isMultiSession()) {
log.info("Removing training session {}/train/{} of StatsStorage instance {}.",
addressSupplier.get(), s, statsStorage);
VertxUIServer.getInstance().getAddress(), s, statsStorage);
}
lastUpdateForSession.remove(s);
}
@ -602,13 +588,13 @@ public class TrainModule implements UIModule {
}
/**
* Get global {@link I18N} instance if {@link #multiSession} is {@code true}, or instance for session
* Get global {@link I18N} instance if {@link VertxUIServer#isMultiSession()} is {@code true}, or instance for session
*
* @param sessionId session ID
* @return {@link I18N} instance
*/
private I18N getI18N(String sessionId) {
return multiSession ? I18NProvider.getInstance(sessionId) : I18NProvider.getInstance();
return VertxUIServer.getInstance().isMultiSession() ? I18NProvider.getInstance(sessionId) : I18NProvider.getInstance();
}