multi-session routes in ArbiterModule, optimize getLastUpdateTime

Signed-off-by: Tamás Fenyvesi <tamas.fenyvesi@doknet.hu>
master
Tamás Fenyvesi 2020-04-30 11:39:52 +02:00
parent d3c759e03f
commit 64c9008ea7
2 changed files with 130 additions and 65 deletions

View File

@ -51,6 +51,7 @@ import org.deeplearning4j.ui.components.text.style.StyleText;
import org.deeplearning4j.ui.i18n.I18NResource; import org.deeplearning4j.ui.i18n.I18NResource;
import org.joda.time.format.DateTimeFormat; import org.joda.time.format.DateTimeFormat;
import org.joda.time.format.DateTimeFormatter; import org.joda.time.format.DateTimeFormatter;
import org.nd4j.common.function.Function;
import org.nd4j.common.primitives.Pair; import org.nd4j.common.primitives.Pair;
import org.nd4j.shade.jackson.core.JsonProcessingException; import org.nd4j.shade.jackson.core.JsonProcessingException;
@ -140,29 +141,89 @@ public class ArbiterModule implements UIModule {
(path, rc) -> rc.response().end(multiSession ? "true" : "false"))); (path, rc) -> rc.response().end(multiSession ? "true" : "false")));
if (multiSession) { if (multiSession) {
r.add(new Route("/arbiter", HttpMethod.GET, (path, rc) -> this.listSessions(rc))); r.add(new Route("/arbiter", HttpMethod.GET, (path, rc) -> this.listSessions(rc)));
r.add(new Route("/arbiter/:sessionId", HttpMethod.GET, (path, rc) -> rc.response()
.putHeader("content-type", "text/html; charset=utf-8")
.sendFile("templates/ArbiterUI.html")));
r.add(new Route("/arbiter/:sessionId/lastUpdate", HttpMethod.GET, (path, rc) -> {
if (knownSessionIDs.containsKey(path.get(0))) {
this.getLastUpdateTime(path.get(0), rc);
} else {
sessionNotFound(path.get(0), rc.request().path(), rc);
}
}));
r.add(new Route("/arbiter/:sessionId/candidateInfo/:id", HttpMethod.GET, (path, rc) -> {
if (knownSessionIDs.containsKey(path.get(0))) {
this.getCandidateInfo(path.get(0), path.get(1), rc);
} else {
sessionNotFound(path.get(0), rc.request().path(), rc);
}
}));
r.add(new Route("/arbiter/:sessionId/config", HttpMethod.GET, (path, rc) -> {
if (knownSessionIDs.containsKey(path.get(0))) {
this.getOptimizationConfig(path.get(0), rc);
} else {
sessionNotFound(path.get(0), rc.request().path(), rc);
}
}));
r.add(new Route("/arbiter/:sessionId/results", HttpMethod.GET, (path, rc) -> {
if (knownSessionIDs.containsKey(path.get(0))) {
this.getSummaryResults(path.get(0), rc);
} else {
sessionNotFound(path.get(0), rc.request().path(), rc);
}
}));
r.add(new Route("/arbiter/:sessionId/summary", HttpMethod.GET, (path, rc) -> {
if (knownSessionIDs.containsKey(path.get(0))) {
this.getSummaryStatus(path.get(0), rc);
} else {
sessionNotFound(path.get(0), rc.request().path(), rc);
}
}));
} else { } else {
r.add(new Route("/arbiter", HttpMethod.GET, (path, rc) -> rc.response() r.add(new Route("/arbiter", HttpMethod.GET, (path, rc) -> rc.response()
.putHeader("content-type", "text/html; charset=utf-8") .putHeader("content-type", "text/html; charset=utf-8")
.sendFile("templates/ArbiterUI.html"))); .sendFile("templates/ArbiterUI.html")));
r.add(new Route("/arbiter/lastUpdate", HttpMethod.GET, (path, rc) -> this.getLastUpdateTime(rc))); r.add(new Route("/arbiter/lastUpdate", HttpMethod.GET, (path, rc) -> this.getLastUpdateTime(null, rc)));
r.add(new Route("/arbiter/lastUpdate/:ids", HttpMethod.GET,
(path, rc) -> this.getModelLastUpdateTimes(path.get(0), rc)));
r.add(new Route("/arbiter/candidateInfo/:id", HttpMethod.GET, r.add(new Route("/arbiter/candidateInfo/:id", HttpMethod.GET,
(path, rc) -> this.getCandidateInfo(path.get(0), rc))); (path, rc) -> this.getCandidateInfo(null, path.get(0), rc)));
r.add(new Route("/arbiter/config", HttpMethod.GET, (path, rc) -> this.getOptimizationConfig(rc))); r.add(new Route("/arbiter/config", HttpMethod.GET, (path, rc) -> this.getOptimizationConfig(null, rc)));
r.add(new Route("/arbiter/results", HttpMethod.GET, (path, rc) -> this.getSummaryResults(rc))); r.add(new Route("/arbiter/results", HttpMethod.GET, (path, rc) -> this.getSummaryResults(null, rc)));
r.add(new Route("/arbiter/summary", HttpMethod.GET, (path, rc) -> this.getSummaryStatus(rc))); r.add(new Route("/arbiter/summary", HttpMethod.GET, (path, rc) -> this.getSummaryStatus(null, rc)));
r.add(new Route("/arbiter/sessions/all", HttpMethod.GET, (path, rc) -> this.sessionInfo(rc)));
r.add(new Route("/arbiter/sessions/current", HttpMethod.GET, (path, rc) -> this.currentSession(rc))); r.add(new Route("/arbiter/sessions/current", HttpMethod.GET, (path, rc) -> this.currentSession(rc)));
r.add(new Route("/arbiter/sessions/set/:to", HttpMethod.GET, r.add(new Route("/arbiter/sessions/set/:to", HttpMethod.GET,
(path, rc) -> this.setSession(path.get(0), rc))); (path, rc) -> this.setSession(path.get(0), rc)));
} }
// common for single- and multi-session mode
r.add(new Route("/arbiter/sessions/all", HttpMethod.GET, (path, rc) -> this.sessionInfo(rc)));
return r; return r;
} }
/**
* Load StatsStorage via provider, or return "not found"
*
* @param sessionId session ID to look fo with provider
* @param targetPath one of overview / model / system, or null
* @param rc routing context
*/
private void sessionNotFound(String sessionId, String targetPath, RoutingContext rc) {
Function<String, Boolean> loader = VertxUIServer.getInstance().getStatsStorageLoader();
if (loader != null && loader.apply(sessionId)) {
if (targetPath != null) {
rc.reroute(targetPath);
} else {
rc.response().end();
}
} else {
rc.response().setStatusCode(HttpResponseStatus.NOT_FOUND.code())
.end("Unknown session ID: " + sessionId);
}
}
/** /**
* List optimization sessions. Returns a HTML list of arbiter sessions * List optimization sessions. Returns a HTML list of arbiter sessions
*/ */
@ -182,7 +243,7 @@ public class ArbiterModule implements UIModule {
if (!knownSessionIDs.isEmpty()) { if (!knownSessionIDs.isEmpty()) {
sb.append(" <ul>"); sb.append(" <ul>");
for (String sessionId : knownSessionIDs.keySet()) { for (String sessionId : knownSessionIDs.keySet()) {
sb.append(" <li><a href=\"train/") sb.append(" <li><a href=\"arbiter/")
.append(sessionId).append("\">") .append(sessionId).append("\">")
.append(sessionId).append("</a></li>\n"); .append(sessionId).append("</a></li>\n");
} }
@ -306,10 +367,25 @@ public class ArbiterModule implements UIModule {
/** /**
* Return the last update time for the page * Return the last update time for the page
* @param sessionId session ID (optional, for multi-session mode)
* @param rc routing context
*/ */
private void getLastUpdateTime(RoutingContext rc){ private void getLastUpdateTime(String sessionId, RoutingContext rc){
//TODO - this forces updates on every request... which is fine, just inefficient if (sessionId == null) {
long t = System.currentTimeMillis(); sessionId = currentSessionID;
}
StatsStorage ss = knownSessionIDs.get(sessionId);
List<Persistable> latestUpdates = ss.getLatestUpdateAllWorkers(sessionId, ARBITER_UI_TYPE_ID);
long t = 0;
if (latestUpdates.isEmpty()) {
t = System.currentTimeMillis();
} else {
for (Persistable update : latestUpdates) {
if (update.getTimeStamp() > t) {
t = update.getTimeStamp();
}
}
}
UpdateStatus us = new UpdateStatus(t, t, t); UpdateStatus us = new UpdateStatus(t, t, t);
rc.response().putHeader("content-type", "application/json").end(asJson(us)); rc.response().putHeader("content-type", "application/json").end(asJson(us));
@ -323,56 +399,28 @@ public class ArbiterModule implements UIModule {
} }
} }
/**
* Get the last update time for the specified model IDs
* @param modelIDs Model IDs to get the update time for
*/
private void getModelLastUpdateTimes(String modelIDs, RoutingContext rc){
if(currentSessionID == null){
rc.response().end();
return;
}
StatsStorage ss = knownSessionIDs.get(currentSessionID);
if(ss == null){
log.debug("getModelLastUpdateTimes(): Session ID is unknown: {}", currentSessionID);
rc.response().end("-1");
return;
}
String[] split = modelIDs.split(",");
long[] lastUpdateTimes = new long[split.length];
for( int i=0; i<split.length; i++ ){
String s = split[i];
Persistable p = ss.getLatestUpdate(currentSessionID, ARBITER_UI_TYPE_ID, s);
if(p != null){
lastUpdateTimes[i] = p.getTimeStamp();
}
}
rc.response().putHeader("content-type", "application/json").end(asJson(lastUpdateTimes));
}
/** /**
* Get the info for a specific candidate - last section in the UI * Get the info for a specific candidate - last section in the UI
* * @param sessionId session ID (optional, for multi-session mode)
* @param candidateId ID for the candidate * @param candidateId ID for the candidate
* @param rc routing context
*/ */
private void getCandidateInfo(String candidateId, RoutingContext rc){ private void getCandidateInfo(String sessionId, String candidateId, RoutingContext rc){
if (sessionId == null) {
StatsStorage ss = knownSessionIDs.get(currentSessionID); sessionId = currentSessionID;
}
StatsStorage ss = knownSessionIDs.get(sessionId);
if(ss == null){ if(ss == null){
log.debug("getModelLastUpdateTimes(): Session ID is unknown: {}", currentSessionID); log.debug("getModelLastUpdateTimes(): Session ID is unknown: {}", sessionId);
rc.response().end(); rc.response().end();
return; return;
} }
GlobalConfigPersistable gcp = (GlobalConfigPersistable)ss GlobalConfigPersistable gcp = (GlobalConfigPersistable)ss
.getStaticInfo(currentSessionID, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID); .getStaticInfo(sessionId, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID);
OptimizationConfiguration oc = gcp.getOptimizationConfiguration(); OptimizationConfiguration oc = gcp.getOptimizationConfiguration();
Persistable p = ss.getLatestUpdate(currentSessionID, ARBITER_UI_TYPE_ID, candidateId); Persistable p = ss.getLatestUpdate(sessionId, ARBITER_UI_TYPE_ID, candidateId);
if(p == null){ if(p == null){
String title = "No results found for model " + candidateId + "."; String title = "No results found for model " + candidateId + ".";
ComponentText ct = new ComponentText.Builder(title,STYLE_TEXT_SZ12).build(); ComponentText ct = new ComponentText.Builder(title,STYLE_TEXT_SZ12).build();
@ -542,17 +590,21 @@ public class ArbiterModule implements UIModule {
/** /**
* Get the optimization configuration - second section in the page * Get the optimization configuration - second section in the page
* @param sessionId session ID (optional, for multi-session mode)
* @param rc routing context
*/ */
private void getOptimizationConfig(RoutingContext rc){ private void getOptimizationConfig(String sessionId, RoutingContext rc){
if (sessionId == null) {
StatsStorage ss = knownSessionIDs.get(currentSessionID); sessionId = currentSessionID;
}
StatsStorage ss = knownSessionIDs.get(sessionId);
if(ss == null){ if(ss == null){
log.debug("getOptimizationConfig(): Session ID is unknown: {}", currentSessionID); log.debug("getOptimizationConfig(): Session ID is unknown: {}", sessionId);
rc.response().end(); rc.response().end();
return; return;
} }
Persistable p = ss.getStaticInfo(currentSessionID, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID); Persistable p = ss.getStaticInfo(sessionId, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID);
if(p == null){ if(p == null){
log.debug("No static info"); log.debug("No static info");
@ -642,15 +694,23 @@ public class ArbiterModule implements UIModule {
rc.response().putHeader("content-type", "application/json").end(asJson(cd)); rc.response().putHeader("content-type", "application/json").end(asJson(cd));
} }
private void getSummaryResults(RoutingContext rc){ /**
StatsStorage ss = knownSessionIDs.get(currentSessionID); * Get candidates summary results list - third section on the page: Results table
* @param sessionId session ID (optional, for multi-session mode)
* @param rc routing context
*/
private void getSummaryResults(String sessionId, RoutingContext rc){
if (sessionId == null) {
sessionId = currentSessionID;
}
StatsStorage ss = knownSessionIDs.get(sessionId);
if(ss == null){ if(ss == null){
log.debug("getSummaryResults(): Session ID is unknown: {}", currentSessionID); log.debug("getSummaryResults(): Session ID is unknown: {}", sessionId);
rc.response().end(); rc.response().end();
return; return;
} }
List<Persistable> allModelInfoTemp = new ArrayList<>(ss.getLatestUpdateAllWorkers(currentSessionID, ARBITER_UI_TYPE_ID)); List<Persistable> allModelInfoTemp = new ArrayList<>(ss.getLatestUpdateAllWorkers(sessionId, ARBITER_UI_TYPE_ID));
List<String[]> table = new ArrayList<>(); List<String[]> table = new ArrayList<>();
for(Persistable per : allModelInfoTemp){ for(Persistable per : allModelInfoTemp){
ModelInfoPersistable mip = (ModelInfoPersistable)per; ModelInfoPersistable mip = (ModelInfoPersistable)per;
@ -663,16 +723,21 @@ public class ArbiterModule implements UIModule {
/** /**
* Get summary status information: first section in the page * Get summary status information: first section in the page
* @param sessionId session ID (optional, for multi-session mode)
* @param rc routing context
*/ */
private void getSummaryStatus(RoutingContext rc){ private void getSummaryStatus(String sessionId, RoutingContext rc){
StatsStorage ss = knownSessionIDs.get(currentSessionID); if (sessionId == null) {
sessionId = currentSessionID;
}
StatsStorage ss = knownSessionIDs.get(sessionId);
if(ss == null){ if(ss == null){
log.debug("getOptimizationConfig(): Session ID is unknown: {}", currentSessionID); log.debug("getOptimizationConfig(): Session ID is unknown: {}", sessionId);
rc.response().end(); rc.response().end();
return; return;
} }
Persistable p = ss.getStaticInfo(currentSessionID, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID); Persistable p = ss.getStaticInfo(sessionId, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID);
if(p == null){ if(p == null){
log.info("No static info"); log.info("No static info");
@ -692,7 +757,7 @@ public class ArbiterModule implements UIModule {
//How to get this? query all model infos... //How to get this? query all model infos...
List<Persistable> allModelInfoTemp = new ArrayList<>(ss.getLatestUpdateAllWorkers(currentSessionID, ARBITER_UI_TYPE_ID)); List<Persistable> allModelInfoTemp = new ArrayList<>(ss.getLatestUpdateAllWorkers(sessionId, ARBITER_UI_TYPE_ID));
List<ModelInfoPersistable> allModelInfo = new ArrayList<>(); List<ModelInfoPersistable> allModelInfo = new ArrayList<>();
for(Persistable per : allModelInfoTemp){ for(Persistable per : allModelInfoTemp){
ModelInfoPersistable mip = (ModelInfoPersistable)per; ModelInfoPersistable mip = (ModelInfoPersistable)per;

View File

@ -160,7 +160,7 @@ public class TestBasic extends BaseDL4JTest {
UIServer uIServer = UIServer.getInstance(true, statsStorageProvider); UIServer uIServer = UIServer.getInstance(true, statsStorageProvider);
String serverAddress = uIServer.getAddress(); String serverAddress = uIServer.getAddress();
for (String sessionId : statsStorageForSession.keySet()) { for (String sessionId : statsStorageForSession.keySet()) {
log.info("Arbiter session will start at {}/arbiter/{}", serverAddress, sessionId); log.info("Arbiter session can be attached at {}/arbiter/{}", serverAddress, sessionId);
} }
runner.execute(); runner.execute();