simplify TrainModule to use getters of VertxUIServer instance for multi-session mode, session loader and server address
Signed-off-by: Tamás Fenyvesi <tamas.fenyvesi@doknet.hu>master
parent
ef7c21c204
commit
f1ebced7a1
|
@ -77,8 +77,6 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
|||
private static Integer instancePort;
|
||||
private static Thread autoStopThread;
|
||||
|
||||
private TrainModule trainModule;
|
||||
|
||||
/**
|
||||
* Get (and, initialize if necessary) the UI server. This synchronous function will wait until the server started.
|
||||
* @param port TCP socket port for {@link HttpServer} to listen
|
||||
|
@ -208,7 +206,11 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
|||
|
||||
private List<UIModule> uiModules = new CopyOnWriteArrayList<>();
|
||||
private RemoteReceiverModule remoteReceiverModule;
|
||||
private StatsStorageLoader statsStorageLoader;
|
||||
/**
|
||||
* Loader that attaches {@code StatsStorage} provided by {@code #statsStorageProvider} for the given session ID
|
||||
*/
|
||||
@Getter
|
||||
private Function<String, Boolean> statsStorageLoader;
|
||||
|
||||
//typeIDModuleMap: Records which modules are registered for which type IDs
|
||||
private Map<String, List<UIModule>> typeIDModuleMap = new ConcurrentHashMap<>();
|
||||
|
@ -248,10 +250,23 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
|||
*/
|
||||
public void autoAttachStatsStorageBySessionId(Function<String, StatsStorage> statsStorageProvider) {
|
||||
if (statsStorageProvider != null) {
|
||||
this.statsStorageLoader = new StatsStorageLoader(statsStorageProvider);
|
||||
if (trainModule != null) {
|
||||
this.trainModule.setSessionLoader(this.statsStorageLoader);
|
||||
}
|
||||
this.statsStorageLoader = (sessionId) -> {
|
||||
log.info("Loading StatsStorage via StatsStorageProvider for session ID (" + sessionId + ").");
|
||||
StatsStorage statsStorage = statsStorageProvider.apply(sessionId);
|
||||
if (statsStorage != null) {
|
||||
if (statsStorage.sessionExists(sessionId)) {
|
||||
attach(statsStorage);
|
||||
return true;
|
||||
}
|
||||
log.info("Failed to load StatsStorage via StatsStorageProvider for session ID. " +
|
||||
"Session ID (" + sessionId + ") does not exist in StatsStorage.");
|
||||
return false;
|
||||
} else {
|
||||
log.info("Failed to load StatsStorage via StatsStorageProvider for session ID (" + sessionId + "). " +
|
||||
"StatsStorageProvider returned null.");
|
||||
return false;
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -303,8 +318,7 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
|||
}
|
||||
|
||||
uiModules.add(new DefaultModule(isMultiSession())); //For: navigation page "/"
|
||||
trainModule = new TrainModule(isMultiSession(), statsStorageLoader, this::getAddress);
|
||||
uiModules.add(trainModule);
|
||||
uiModules.add(new TrainModule());
|
||||
uiModules.add(new ConvolutionalListenerModule());
|
||||
uiModules.add(new TsneModule());
|
||||
uiModules.add(new SameDiffModule());
|
||||
|
@ -597,37 +611,6 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Loader that attaches {@code StatsStorage} provided by {@code #statsStorageProvider} for the given session ID
|
||||
*/
|
||||
private class StatsStorageLoader implements Function<String, Boolean> {
|
||||
|
||||
Function<String, StatsStorage> statsStorageProvider;
|
||||
|
||||
StatsStorageLoader(Function<String, StatsStorage> statsStorageProvider) {
|
||||
this.statsStorageProvider = statsStorageProvider;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Boolean apply(String sessionId) {
|
||||
log.info("Loading StatsStorage via StatsStorageProvider for session ID (" + sessionId + ").");
|
||||
StatsStorage statsStorage = statsStorageProvider.apply(sessionId);
|
||||
if (statsStorage != null) {
|
||||
if (statsStorage.sessionExists(sessionId)) {
|
||||
attach(statsStorage);
|
||||
return true;
|
||||
}
|
||||
log.info("Failed to load StatsStorage via StatsStorageProvider for session ID. " +
|
||||
"Session ID (" + sessionId + ") does not exist in StatsStorage.");
|
||||
return false;
|
||||
} else {
|
||||
log.info("Failed to load StatsStorage via StatsStorageProvider for session ID (" + sessionId + "). " +
|
||||
"StatsStorageProvider returned null.");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//==================================================================================================================
|
||||
// CLI Launcher
|
||||
|
||||
|
|
|
@ -26,8 +26,6 @@ import io.vertx.ext.web.RoutingContext;
|
|||
import it.unimi.dsi.fastutil.longs.LongArrayList;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.apache.commons.io.FilenameUtils;
|
||||
|
@ -43,6 +41,7 @@ import org.deeplearning4j.nn.conf.graph.GraphVertex;
|
|||
import org.deeplearning4j.nn.conf.graph.LayerVertex;
|
||||
import org.deeplearning4j.nn.conf.layers.*;
|
||||
import org.deeplearning4j.nn.conf.serde.JsonMappers;
|
||||
import org.deeplearning4j.ui.VertxUIServer;
|
||||
import org.deeplearning4j.ui.api.HttpMethod;
|
||||
import org.deeplearning4j.ui.api.I18N;
|
||||
import org.deeplearning4j.ui.api.Route;
|
||||
|
@ -56,7 +55,6 @@ import org.deeplearning4j.ui.model.stats.api.StatsInitializationReport;
|
|||
import org.deeplearning4j.ui.model.stats.api.StatsReport;
|
||||
import org.deeplearning4j.ui.model.stats.api.StatsType;
|
||||
import org.nd4j.common.function.Function;
|
||||
import org.nd4j.common.function.Supplier;
|
||||
import org.nd4j.linalg.learning.config.IUpdater;
|
||||
import org.nd4j.common.primitives.Pair;
|
||||
import org.nd4j.common.primitives.Triple;
|
||||
|
@ -86,8 +84,6 @@ public class TrainModule implements UIModule {
|
|||
private static final DecimalFormat df2 = new DecimalFormat("#.00");
|
||||
private static DateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
|
||||
|
||||
private final Supplier<String> addressSupplier;
|
||||
|
||||
private enum ModelType {
|
||||
MLN, CG, Layer
|
||||
}
|
||||
|
@ -99,29 +95,14 @@ public class TrainModule implements UIModule {
|
|||
private Map<String, AtomicInteger> workerIdxCount = new ConcurrentHashMap<>(); //Key: session ID
|
||||
private Map<String, Map<Integer, String>> workerIdxToName = new ConcurrentHashMap<>(); //Key: session ID
|
||||
private Map<String, Long> lastUpdateForSession = new ConcurrentHashMap<>();
|
||||
private final boolean multiSession;
|
||||
@Getter @Setter
|
||||
private Function<String, Boolean> sessionLoader;
|
||||
|
||||
|
||||
private final Configuration configuration;
|
||||
|
||||
public TrainModule() {
|
||||
this(false, null, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* TrainModule
|
||||
*
|
||||
* @param multiSession multi-session mode
|
||||
* @param sessionLoader StatsStorage loader to call if an unknown session ID is passed as URL path parameter
|
||||
* in multi-session mode
|
||||
* @param addressSupplier supplier for server address (server address in PlayUIServer gets initialized after modules)
|
||||
*/
|
||||
public TrainModule(boolean multiSession, Function<String, Boolean> sessionLoader, Supplier<String> addressSupplier) {
|
||||
this.multiSession = multiSession;
|
||||
this.sessionLoader = sessionLoader;
|
||||
this.addressSupplier = addressSupplier;
|
||||
public TrainModule() {
|
||||
String maxChartPointsProp = System.getProperty(DL4JSystemProperties.CHART_MAX_POINTS_PROPERTY);
|
||||
int value = DEFAULT_MAX_CHART_POINTS;
|
||||
if (maxChartPointsProp != null) {
|
||||
|
@ -159,8 +140,9 @@ public class TrainModule implements UIModule {
|
|||
@Override
|
||||
public List<Route> getRoutes() {
|
||||
List<Route> r = new ArrayList<>();
|
||||
r.add(new Route("/train/multisession", HttpMethod.GET, (path, rc) -> rc.response().end(multiSession ? "true" : "false")));
|
||||
if (multiSession) {
|
||||
r.add(new Route("/train/multisession", HttpMethod.GET,
|
||||
(path, rc) -> rc.response().end(VertxUIServer.getInstance().isMultiSession() ? "true" : "false")));
|
||||
if (VertxUIServer.getInstance().isMultiSession()) {
|
||||
r.add(new Route("/train", HttpMethod.GET, (path, rc) -> this.listSessions(rc)));
|
||||
r.add(new Route("/train/:sessionId", HttpMethod.GET, (path, rc) -> {
|
||||
rc.response()
|
||||
|
@ -264,7 +246,9 @@ public class TrainModule implements UIModule {
|
|||
if (!knownSessionIDs.isEmpty()) {
|
||||
sb.append(" <ul>");
|
||||
for (String sessionId : knownSessionIDs.keySet()) {
|
||||
sb.append(" <li><a href=\"train/").append(sessionId).append("\">").append(sessionId).append("</a></li>\n");
|
||||
sb.append(" <li><a href=\"train/")
|
||||
.append(sessionId).append("\">")
|
||||
.append(sessionId).append("</a></li>\n");
|
||||
}
|
||||
sb.append(" </ul>");
|
||||
} else {
|
||||
|
@ -284,9 +268,11 @@ public class TrainModule implements UIModule {
|
|||
*
|
||||
* @param sessionId session ID to look fo with provider
|
||||
* @param targetPath one of overview / model / system, or null
|
||||
* @param rc routing context
|
||||
*/
|
||||
private void sessionNotFound(String sessionId, String targetPath, RoutingContext rc) {
|
||||
if (sessionLoader != null && sessionLoader.apply(sessionId)) {
|
||||
Function<String, Boolean> loader = VertxUIServer.getInstance().getStatsStorageLoader();
|
||||
if (loader != null && loader.apply(sessionId)) {
|
||||
if (targetPath != null) {
|
||||
rc.reroute(targetPath);
|
||||
} else {
|
||||
|
@ -306,9 +292,9 @@ public class TrainModule implements UIModule {
|
|||
&& StatsListener.TYPE_ID.equals(sse.getTypeID())
|
||||
&& !knownSessionIDs.containsKey(sse.getSessionID())) {
|
||||
knownSessionIDs.put(sse.getSessionID(), sse.getStatsStorage());
|
||||
if (multiSession) {
|
||||
if (VertxUIServer.getInstance().isMultiSession()) {
|
||||
log.info("Adding training session {}/train/{} of StatsStorage instance {}",
|
||||
addressSupplier.get(), sse.getSessionID(), sse.getStatsStorage());
|
||||
VertxUIServer.getInstance().getAddress(), sse.getSessionID(), sse.getStatsStorage());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -332,9 +318,9 @@ public class TrainModule implements UIModule {
|
|||
if (!StatsListener.TYPE_ID.equals(typeID))
|
||||
continue;
|
||||
knownSessionIDs.put(sessionID, statsStorage);
|
||||
if (multiSession) {
|
||||
if (VertxUIServer.getInstance().isMultiSession()) {
|
||||
log.info("Adding training session {}/train/{} of StatsStorage instance {}",
|
||||
addressSupplier.get(), sessionID, statsStorage);
|
||||
VertxUIServer.getInstance().getAddress(), sessionID, statsStorage);
|
||||
}
|
||||
|
||||
List<Persistable> latestUpdates = statsStorage.getLatestUpdateAllWorkers(sessionID, typeID);
|
||||
|
@ -364,9 +350,9 @@ public class TrainModule implements UIModule {
|
|||
}
|
||||
for (String s : toRemove) {
|
||||
knownSessionIDs.remove(s);
|
||||
if (multiSession) {
|
||||
if (VertxUIServer.getInstance().isMultiSession()) {
|
||||
log.info("Removing training session {}/train/{} of StatsStorage instance {}.",
|
||||
addressSupplier.get(), s, statsStorage);
|
||||
VertxUIServer.getInstance().getAddress(), s, statsStorage);
|
||||
}
|
||||
lastUpdateForSession.remove(s);
|
||||
}
|
||||
|
@ -602,13 +588,13 @@ public class TrainModule implements UIModule {
|
|||
}
|
||||
|
||||
/**
|
||||
* Get global {@link I18N} instance if {@link #multiSession} is {@code true}, or instance for session
|
||||
* Get global {@link I18N} instance if {@link VertxUIServer#isMultiSession()} is {@code true}, or instance for session
|
||||
*
|
||||
* @param sessionId session ID
|
||||
* @return {@link I18N} instance
|
||||
*/
|
||||
private I18N getI18N(String sessionId) {
|
||||
return multiSession ? I18NProvider.getInstance(sessionId) : I18NProvider.getInstance();
|
||||
return VertxUIServer.getInstance().isMultiSession() ? I18NProvider.getInstance(sessionId) : I18NProvider.getInstance();
|
||||
}
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue