simplify TrainModule to use getters of VertxUIServer instance for multi-session mode, session loader and server address
Signed-off-by: Tamás Fenyvesi <tamas.fenyvesi@doknet.hu>master
parent
ef7c21c204
commit
f1ebced7a1
|
@ -77,8 +77,6 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
||||||
private static Integer instancePort;
|
private static Integer instancePort;
|
||||||
private static Thread autoStopThread;
|
private static Thread autoStopThread;
|
||||||
|
|
||||||
private TrainModule trainModule;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get (and, initialize if necessary) the UI server. This synchronous function will wait until the server started.
|
* Get (and, initialize if necessary) the UI server. This synchronous function will wait until the server started.
|
||||||
* @param port TCP socket port for {@link HttpServer} to listen
|
* @param port TCP socket port for {@link HttpServer} to listen
|
||||||
|
@ -208,7 +206,11 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
||||||
|
|
||||||
private List<UIModule> uiModules = new CopyOnWriteArrayList<>();
|
private List<UIModule> uiModules = new CopyOnWriteArrayList<>();
|
||||||
private RemoteReceiverModule remoteReceiverModule;
|
private RemoteReceiverModule remoteReceiverModule;
|
||||||
private StatsStorageLoader statsStorageLoader;
|
/**
|
||||||
|
* Loader that attaches {@code StatsStorage} provided by {@code #statsStorageProvider} for the given session ID
|
||||||
|
*/
|
||||||
|
@Getter
|
||||||
|
private Function<String, Boolean> statsStorageLoader;
|
||||||
|
|
||||||
//typeIDModuleMap: Records which modules are registered for which type IDs
|
//typeIDModuleMap: Records which modules are registered for which type IDs
|
||||||
private Map<String, List<UIModule>> typeIDModuleMap = new ConcurrentHashMap<>();
|
private Map<String, List<UIModule>> typeIDModuleMap = new ConcurrentHashMap<>();
|
||||||
|
@ -248,10 +250,23 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
||||||
*/
|
*/
|
||||||
public void autoAttachStatsStorageBySessionId(Function<String, StatsStorage> statsStorageProvider) {
|
public void autoAttachStatsStorageBySessionId(Function<String, StatsStorage> statsStorageProvider) {
|
||||||
if (statsStorageProvider != null) {
|
if (statsStorageProvider != null) {
|
||||||
this.statsStorageLoader = new StatsStorageLoader(statsStorageProvider);
|
this.statsStorageLoader = (sessionId) -> {
|
||||||
if (trainModule != null) {
|
log.info("Loading StatsStorage via StatsStorageProvider for session ID (" + sessionId + ").");
|
||||||
this.trainModule.setSessionLoader(this.statsStorageLoader);
|
StatsStorage statsStorage = statsStorageProvider.apply(sessionId);
|
||||||
|
if (statsStorage != null) {
|
||||||
|
if (statsStorage.sessionExists(sessionId)) {
|
||||||
|
attach(statsStorage);
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
log.info("Failed to load StatsStorage via StatsStorageProvider for session ID. " +
|
||||||
|
"Session ID (" + sessionId + ") does not exist in StatsStorage.");
|
||||||
|
return false;
|
||||||
|
} else {
|
||||||
|
log.info("Failed to load StatsStorage via StatsStorageProvider for session ID (" + sessionId + "). " +
|
||||||
|
"StatsStorageProvider returned null.");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -303,8 +318,7 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
||||||
}
|
}
|
||||||
|
|
||||||
uiModules.add(new DefaultModule(isMultiSession())); //For: navigation page "/"
|
uiModules.add(new DefaultModule(isMultiSession())); //For: navigation page "/"
|
||||||
trainModule = new TrainModule(isMultiSession(), statsStorageLoader, this::getAddress);
|
uiModules.add(new TrainModule());
|
||||||
uiModules.add(trainModule);
|
|
||||||
uiModules.add(new ConvolutionalListenerModule());
|
uiModules.add(new ConvolutionalListenerModule());
|
||||||
uiModules.add(new TsneModule());
|
uiModules.add(new TsneModule());
|
||||||
uiModules.add(new SameDiffModule());
|
uiModules.add(new SameDiffModule());
|
||||||
|
@ -597,37 +611,6 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Loader that attaches {@code StatsStorage} provided by {@code #statsStorageProvider} for the given session ID
|
|
||||||
*/
|
|
||||||
private class StatsStorageLoader implements Function<String, Boolean> {
|
|
||||||
|
|
||||||
Function<String, StatsStorage> statsStorageProvider;
|
|
||||||
|
|
||||||
StatsStorageLoader(Function<String, StatsStorage> statsStorageProvider) {
|
|
||||||
this.statsStorageProvider = statsStorageProvider;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Boolean apply(String sessionId) {
|
|
||||||
log.info("Loading StatsStorage via StatsStorageProvider for session ID (" + sessionId + ").");
|
|
||||||
StatsStorage statsStorage = statsStorageProvider.apply(sessionId);
|
|
||||||
if (statsStorage != null) {
|
|
||||||
if (statsStorage.sessionExists(sessionId)) {
|
|
||||||
attach(statsStorage);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
log.info("Failed to load StatsStorage via StatsStorageProvider for session ID. " +
|
|
||||||
"Session ID (" + sessionId + ") does not exist in StatsStorage.");
|
|
||||||
return false;
|
|
||||||
} else {
|
|
||||||
log.info("Failed to load StatsStorage via StatsStorageProvider for session ID (" + sessionId + "). " +
|
|
||||||
"StatsStorageProvider returned null.");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//==================================================================================================================
|
//==================================================================================================================
|
||||||
// CLI Launcher
|
// CLI Launcher
|
||||||
|
|
||||||
|
|
|
@ -26,8 +26,6 @@ import io.vertx.ext.web.RoutingContext;
|
||||||
import it.unimi.dsi.fastutil.longs.LongArrayList;
|
import it.unimi.dsi.fastutil.longs.LongArrayList;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.Getter;
|
|
||||||
import lombok.Setter;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.io.FileUtils;
|
import org.apache.commons.io.FileUtils;
|
||||||
import org.apache.commons.io.FilenameUtils;
|
import org.apache.commons.io.FilenameUtils;
|
||||||
|
@ -43,6 +41,7 @@ import org.deeplearning4j.nn.conf.graph.GraphVertex;
|
||||||
import org.deeplearning4j.nn.conf.graph.LayerVertex;
|
import org.deeplearning4j.nn.conf.graph.LayerVertex;
|
||||||
import org.deeplearning4j.nn.conf.layers.*;
|
import org.deeplearning4j.nn.conf.layers.*;
|
||||||
import org.deeplearning4j.nn.conf.serde.JsonMappers;
|
import org.deeplearning4j.nn.conf.serde.JsonMappers;
|
||||||
|
import org.deeplearning4j.ui.VertxUIServer;
|
||||||
import org.deeplearning4j.ui.api.HttpMethod;
|
import org.deeplearning4j.ui.api.HttpMethod;
|
||||||
import org.deeplearning4j.ui.api.I18N;
|
import org.deeplearning4j.ui.api.I18N;
|
||||||
import org.deeplearning4j.ui.api.Route;
|
import org.deeplearning4j.ui.api.Route;
|
||||||
|
@ -56,7 +55,6 @@ import org.deeplearning4j.ui.model.stats.api.StatsInitializationReport;
|
||||||
import org.deeplearning4j.ui.model.stats.api.StatsReport;
|
import org.deeplearning4j.ui.model.stats.api.StatsReport;
|
||||||
import org.deeplearning4j.ui.model.stats.api.StatsType;
|
import org.deeplearning4j.ui.model.stats.api.StatsType;
|
||||||
import org.nd4j.common.function.Function;
|
import org.nd4j.common.function.Function;
|
||||||
import org.nd4j.common.function.Supplier;
|
|
||||||
import org.nd4j.linalg.learning.config.IUpdater;
|
import org.nd4j.linalg.learning.config.IUpdater;
|
||||||
import org.nd4j.common.primitives.Pair;
|
import org.nd4j.common.primitives.Pair;
|
||||||
import org.nd4j.common.primitives.Triple;
|
import org.nd4j.common.primitives.Triple;
|
||||||
|
@ -86,8 +84,6 @@ public class TrainModule implements UIModule {
|
||||||
private static final DecimalFormat df2 = new DecimalFormat("#.00");
|
private static final DecimalFormat df2 = new DecimalFormat("#.00");
|
||||||
private static DateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
|
private static DateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
|
||||||
|
|
||||||
private final Supplier<String> addressSupplier;
|
|
||||||
|
|
||||||
private enum ModelType {
|
private enum ModelType {
|
||||||
MLN, CG, Layer
|
MLN, CG, Layer
|
||||||
}
|
}
|
||||||
|
@ -99,29 +95,14 @@ public class TrainModule implements UIModule {
|
||||||
private Map<String, AtomicInteger> workerIdxCount = new ConcurrentHashMap<>(); //Key: session ID
|
private Map<String, AtomicInteger> workerIdxCount = new ConcurrentHashMap<>(); //Key: session ID
|
||||||
private Map<String, Map<Integer, String>> workerIdxToName = new ConcurrentHashMap<>(); //Key: session ID
|
private Map<String, Map<Integer, String>> workerIdxToName = new ConcurrentHashMap<>(); //Key: session ID
|
||||||
private Map<String, Long> lastUpdateForSession = new ConcurrentHashMap<>();
|
private Map<String, Long> lastUpdateForSession = new ConcurrentHashMap<>();
|
||||||
private final boolean multiSession;
|
|
||||||
@Getter @Setter
|
|
||||||
private Function<String, Boolean> sessionLoader;
|
|
||||||
|
|
||||||
|
|
||||||
private final Configuration configuration;
|
private final Configuration configuration;
|
||||||
|
|
||||||
public TrainModule() {
|
|
||||||
this(false, null, null);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* TrainModule
|
* TrainModule
|
||||||
*
|
|
||||||
* @param multiSession multi-session mode
|
|
||||||
* @param sessionLoader StatsStorage loader to call if an unknown session ID is passed as URL path parameter
|
|
||||||
* in multi-session mode
|
|
||||||
* @param addressSupplier supplier for server address (server address in PlayUIServer gets initialized after modules)
|
|
||||||
*/
|
*/
|
||||||
public TrainModule(boolean multiSession, Function<String, Boolean> sessionLoader, Supplier<String> addressSupplier) {
|
public TrainModule() {
|
||||||
this.multiSession = multiSession;
|
|
||||||
this.sessionLoader = sessionLoader;
|
|
||||||
this.addressSupplier = addressSupplier;
|
|
||||||
String maxChartPointsProp = System.getProperty(DL4JSystemProperties.CHART_MAX_POINTS_PROPERTY);
|
String maxChartPointsProp = System.getProperty(DL4JSystemProperties.CHART_MAX_POINTS_PROPERTY);
|
||||||
int value = DEFAULT_MAX_CHART_POINTS;
|
int value = DEFAULT_MAX_CHART_POINTS;
|
||||||
if (maxChartPointsProp != null) {
|
if (maxChartPointsProp != null) {
|
||||||
|
@ -159,8 +140,9 @@ public class TrainModule implements UIModule {
|
||||||
@Override
|
@Override
|
||||||
public List<Route> getRoutes() {
|
public List<Route> getRoutes() {
|
||||||
List<Route> r = new ArrayList<>();
|
List<Route> r = new ArrayList<>();
|
||||||
r.add(new Route("/train/multisession", HttpMethod.GET, (path, rc) -> rc.response().end(multiSession ? "true" : "false")));
|
r.add(new Route("/train/multisession", HttpMethod.GET,
|
||||||
if (multiSession) {
|
(path, rc) -> rc.response().end(VertxUIServer.getInstance().isMultiSession() ? "true" : "false")));
|
||||||
|
if (VertxUIServer.getInstance().isMultiSession()) {
|
||||||
r.add(new Route("/train", HttpMethod.GET, (path, rc) -> this.listSessions(rc)));
|
r.add(new Route("/train", HttpMethod.GET, (path, rc) -> this.listSessions(rc)));
|
||||||
r.add(new Route("/train/:sessionId", HttpMethod.GET, (path, rc) -> {
|
r.add(new Route("/train/:sessionId", HttpMethod.GET, (path, rc) -> {
|
||||||
rc.response()
|
rc.response()
|
||||||
|
@ -264,7 +246,9 @@ public class TrainModule implements UIModule {
|
||||||
if (!knownSessionIDs.isEmpty()) {
|
if (!knownSessionIDs.isEmpty()) {
|
||||||
sb.append(" <ul>");
|
sb.append(" <ul>");
|
||||||
for (String sessionId : knownSessionIDs.keySet()) {
|
for (String sessionId : knownSessionIDs.keySet()) {
|
||||||
sb.append(" <li><a href=\"train/").append(sessionId).append("\">").append(sessionId).append("</a></li>\n");
|
sb.append(" <li><a href=\"train/")
|
||||||
|
.append(sessionId).append("\">")
|
||||||
|
.append(sessionId).append("</a></li>\n");
|
||||||
}
|
}
|
||||||
sb.append(" </ul>");
|
sb.append(" </ul>");
|
||||||
} else {
|
} else {
|
||||||
|
@ -284,9 +268,11 @@ public class TrainModule implements UIModule {
|
||||||
*
|
*
|
||||||
* @param sessionId session ID to look fo with provider
|
* @param sessionId session ID to look fo with provider
|
||||||
* @param targetPath one of overview / model / system, or null
|
* @param targetPath one of overview / model / system, or null
|
||||||
|
* @param rc routing context
|
||||||
*/
|
*/
|
||||||
private void sessionNotFound(String sessionId, String targetPath, RoutingContext rc) {
|
private void sessionNotFound(String sessionId, String targetPath, RoutingContext rc) {
|
||||||
if (sessionLoader != null && sessionLoader.apply(sessionId)) {
|
Function<String, Boolean> loader = VertxUIServer.getInstance().getStatsStorageLoader();
|
||||||
|
if (loader != null && loader.apply(sessionId)) {
|
||||||
if (targetPath != null) {
|
if (targetPath != null) {
|
||||||
rc.reroute(targetPath);
|
rc.reroute(targetPath);
|
||||||
} else {
|
} else {
|
||||||
|
@ -306,9 +292,9 @@ public class TrainModule implements UIModule {
|
||||||
&& StatsListener.TYPE_ID.equals(sse.getTypeID())
|
&& StatsListener.TYPE_ID.equals(sse.getTypeID())
|
||||||
&& !knownSessionIDs.containsKey(sse.getSessionID())) {
|
&& !knownSessionIDs.containsKey(sse.getSessionID())) {
|
||||||
knownSessionIDs.put(sse.getSessionID(), sse.getStatsStorage());
|
knownSessionIDs.put(sse.getSessionID(), sse.getStatsStorage());
|
||||||
if (multiSession) {
|
if (VertxUIServer.getInstance().isMultiSession()) {
|
||||||
log.info("Adding training session {}/train/{} of StatsStorage instance {}",
|
log.info("Adding training session {}/train/{} of StatsStorage instance {}",
|
||||||
addressSupplier.get(), sse.getSessionID(), sse.getStatsStorage());
|
VertxUIServer.getInstance().getAddress(), sse.getSessionID(), sse.getStatsStorage());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -332,9 +318,9 @@ public class TrainModule implements UIModule {
|
||||||
if (!StatsListener.TYPE_ID.equals(typeID))
|
if (!StatsListener.TYPE_ID.equals(typeID))
|
||||||
continue;
|
continue;
|
||||||
knownSessionIDs.put(sessionID, statsStorage);
|
knownSessionIDs.put(sessionID, statsStorage);
|
||||||
if (multiSession) {
|
if (VertxUIServer.getInstance().isMultiSession()) {
|
||||||
log.info("Adding training session {}/train/{} of StatsStorage instance {}",
|
log.info("Adding training session {}/train/{} of StatsStorage instance {}",
|
||||||
addressSupplier.get(), sessionID, statsStorage);
|
VertxUIServer.getInstance().getAddress(), sessionID, statsStorage);
|
||||||
}
|
}
|
||||||
|
|
||||||
List<Persistable> latestUpdates = statsStorage.getLatestUpdateAllWorkers(sessionID, typeID);
|
List<Persistable> latestUpdates = statsStorage.getLatestUpdateAllWorkers(sessionID, typeID);
|
||||||
|
@ -364,9 +350,9 @@ public class TrainModule implements UIModule {
|
||||||
}
|
}
|
||||||
for (String s : toRemove) {
|
for (String s : toRemove) {
|
||||||
knownSessionIDs.remove(s);
|
knownSessionIDs.remove(s);
|
||||||
if (multiSession) {
|
if (VertxUIServer.getInstance().isMultiSession()) {
|
||||||
log.info("Removing training session {}/train/{} of StatsStorage instance {}.",
|
log.info("Removing training session {}/train/{} of StatsStorage instance {}.",
|
||||||
addressSupplier.get(), s, statsStorage);
|
VertxUIServer.getInstance().getAddress(), s, statsStorage);
|
||||||
}
|
}
|
||||||
lastUpdateForSession.remove(s);
|
lastUpdateForSession.remove(s);
|
||||||
}
|
}
|
||||||
|
@ -602,13 +588,13 @@ public class TrainModule implements UIModule {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get global {@link I18N} instance if {@link #multiSession} is {@code true}, or instance for session
|
* Get global {@link I18N} instance if {@link VertxUIServer#isMultiSession()} is {@code true}, or instance for session
|
||||||
*
|
*
|
||||||
* @param sessionId session ID
|
* @param sessionId session ID
|
||||||
* @return {@link I18N} instance
|
* @return {@link I18N} instance
|
||||||
*/
|
*/
|
||||||
private I18N getI18N(String sessionId) {
|
private I18N getI18N(String sessionId) {
|
||||||
return multiSession ? I18NProvider.getInstance(sessionId) : I18NProvider.getInstance();
|
return VertxUIServer.getInstance().isMultiSession() ? I18NProvider.getInstance(sessionId) : I18NProvider.getInstance();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue