From 64c9008ea765f8f869c43463241807338d8ebd37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20Fenyvesi?= Date: Thu, 30 Apr 2020 11:39:52 +0200 Subject: [PATCH] multi-session routes in ArbiterModule, optimize getLastUpdateTime MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Tamás Fenyvesi --- .../arbiter/ui/module/ArbiterModule.java | 193 ++++++++++++------ .../arbiter/optimize/TestBasic.java | 2 +- 2 files changed, 130 insertions(+), 65 deletions(-) diff --git a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/module/ArbiterModule.java b/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/module/ArbiterModule.java index 6bcf3acbd..dab0174df 100644 --- a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/module/ArbiterModule.java +++ b/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/module/ArbiterModule.java @@ -51,6 +51,7 @@ import org.deeplearning4j.ui.components.text.style.StyleText; import org.deeplearning4j.ui.i18n.I18NResource; import org.joda.time.format.DateTimeFormat; import org.joda.time.format.DateTimeFormatter; +import org.nd4j.common.function.Function; import org.nd4j.common.primitives.Pair; import org.nd4j.shade.jackson.core.JsonProcessingException; @@ -140,29 +141,89 @@ public class ArbiterModule implements UIModule { (path, rc) -> rc.response().end(multiSession ? "true" : "false"))); if (multiSession) { r.add(new Route("/arbiter", HttpMethod.GET, (path, rc) -> this.listSessions(rc))); + r.add(new Route("/arbiter/:sessionId", HttpMethod.GET, (path, rc) -> rc.response() + .putHeader("content-type", "text/html; charset=utf-8") + .sendFile("templates/ArbiterUI.html"))); + + r.add(new Route("/arbiter/:sessionId/lastUpdate", HttpMethod.GET, (path, rc) -> { + if (knownSessionIDs.containsKey(path.get(0))) { + this.getLastUpdateTime(path.get(0), rc); + } else { + sessionNotFound(path.get(0), rc.request().path(), rc); + } + })); + r.add(new Route("/arbiter/:sessionId/candidateInfo/:id", HttpMethod.GET, (path, rc) -> { + if (knownSessionIDs.containsKey(path.get(0))) { + this.getCandidateInfo(path.get(0), path.get(1), rc); + } else { + sessionNotFound(path.get(0), rc.request().path(), rc); + } + })); + r.add(new Route("/arbiter/:sessionId/config", HttpMethod.GET, (path, rc) -> { + if (knownSessionIDs.containsKey(path.get(0))) { + this.getOptimizationConfig(path.get(0), rc); + } else { + sessionNotFound(path.get(0), rc.request().path(), rc); + } + })); + r.add(new Route("/arbiter/:sessionId/results", HttpMethod.GET, (path, rc) -> { + if (knownSessionIDs.containsKey(path.get(0))) { + this.getSummaryResults(path.get(0), rc); + } else { + sessionNotFound(path.get(0), rc.request().path(), rc); + } + })); + r.add(new Route("/arbiter/:sessionId/summary", HttpMethod.GET, (path, rc) -> { + if (knownSessionIDs.containsKey(path.get(0))) { + this.getSummaryStatus(path.get(0), rc); + } else { + sessionNotFound(path.get(0), rc.request().path(), rc); + } + })); } else { r.add(new Route("/arbiter", HttpMethod.GET, (path, rc) -> rc.response() .putHeader("content-type", "text/html; charset=utf-8") .sendFile("templates/ArbiterUI.html"))); - r.add(new Route("/arbiter/lastUpdate", HttpMethod.GET, (path, rc) -> this.getLastUpdateTime(rc))); - r.add(new Route("/arbiter/lastUpdate/:ids", HttpMethod.GET, - (path, rc) -> this.getModelLastUpdateTimes(path.get(0), rc))); + r.add(new Route("/arbiter/lastUpdate", HttpMethod.GET, (path, rc) -> this.getLastUpdateTime(null, rc))); r.add(new Route("/arbiter/candidateInfo/:id", HttpMethod.GET, - (path, rc) -> this.getCandidateInfo(path.get(0), rc))); - r.add(new Route("/arbiter/config", HttpMethod.GET, (path, rc) -> this.getOptimizationConfig(rc))); - r.add(new Route("/arbiter/results", HttpMethod.GET, (path, rc) -> this.getSummaryResults(rc))); - r.add(new Route("/arbiter/summary", HttpMethod.GET, (path, rc) -> this.getSummaryStatus(rc))); + (path, rc) -> this.getCandidateInfo(null, path.get(0), rc))); + r.add(new Route("/arbiter/config", HttpMethod.GET, (path, rc) -> this.getOptimizationConfig(null, rc))); + r.add(new Route("/arbiter/results", HttpMethod.GET, (path, rc) -> this.getSummaryResults(null, rc))); + r.add(new Route("/arbiter/summary", HttpMethod.GET, (path, rc) -> this.getSummaryStatus(null, rc))); - r.add(new Route("/arbiter/sessions/all", HttpMethod.GET, (path, rc) -> this.sessionInfo(rc))); r.add(new Route("/arbiter/sessions/current", HttpMethod.GET, (path, rc) -> this.currentSession(rc))); r.add(new Route("/arbiter/sessions/set/:to", HttpMethod.GET, (path, rc) -> this.setSession(path.get(0), rc))); } + // common for single- and multi-session mode + r.add(new Route("/arbiter/sessions/all", HttpMethod.GET, (path, rc) -> this.sessionInfo(rc))); return r; } + /** + * Load StatsStorage via provider, or return "not found" + * + * @param sessionId session ID to look fo with provider + * @param targetPath one of overview / model / system, or null + * @param rc routing context + */ + private void sessionNotFound(String sessionId, String targetPath, RoutingContext rc) { + Function loader = VertxUIServer.getInstance().getStatsStorageLoader(); + if (loader != null && loader.apply(sessionId)) { + if (targetPath != null) { + rc.reroute(targetPath); + } else { + rc.response().end(); + } + } else { + rc.response().setStatusCode(HttpResponseStatus.NOT_FOUND.code()) + .end("Unknown session ID: " + sessionId); + } + } + + /** * List optimization sessions. Returns a HTML list of arbiter sessions */ @@ -182,7 +243,7 @@ public class ArbiterModule implements UIModule { if (!knownSessionIDs.isEmpty()) { sb.append("
    "); for (String sessionId : knownSessionIDs.keySet()) { - sb.append("
  • ") .append(sessionId).append("
  • \n"); } @@ -306,10 +367,25 @@ public class ArbiterModule implements UIModule { /** * Return the last update time for the page + * @param sessionId session ID (optional, for multi-session mode) + * @param rc routing context */ - private void getLastUpdateTime(RoutingContext rc){ - //TODO - this forces updates on every request... which is fine, just inefficient - long t = System.currentTimeMillis(); + private void getLastUpdateTime(String sessionId, RoutingContext rc){ + if (sessionId == null) { + sessionId = currentSessionID; + } + StatsStorage ss = knownSessionIDs.get(sessionId); + List latestUpdates = ss.getLatestUpdateAllWorkers(sessionId, ARBITER_UI_TYPE_ID); + long t = 0; + if (latestUpdates.isEmpty()) { + t = System.currentTimeMillis(); + } else { + for (Persistable update : latestUpdates) { + if (update.getTimeStamp() > t) { + t = update.getTimeStamp(); + } + } + } UpdateStatus us = new UpdateStatus(t, t, t); rc.response().putHeader("content-type", "application/json").end(asJson(us)); @@ -323,56 +399,28 @@ public class ArbiterModule implements UIModule { } } - /** - * Get the last update time for the specified model IDs - * @param modelIDs Model IDs to get the update time for - */ - private void getModelLastUpdateTimes(String modelIDs, RoutingContext rc){ - if(currentSessionID == null){ - rc.response().end(); - return; - } - - StatsStorage ss = knownSessionIDs.get(currentSessionID); - if(ss == null){ - log.debug("getModelLastUpdateTimes(): Session ID is unknown: {}", currentSessionID); - rc.response().end("-1"); - return; - } - - String[] split = modelIDs.split(","); - - long[] lastUpdateTimes = new long[split.length]; - for( int i=0; i allModelInfoTemp = new ArrayList<>(ss.getLatestUpdateAllWorkers(currentSessionID, ARBITER_UI_TYPE_ID)); + List allModelInfoTemp = new ArrayList<>(ss.getLatestUpdateAllWorkers(sessionId, ARBITER_UI_TYPE_ID)); List table = new ArrayList<>(); for(Persistable per : allModelInfoTemp){ ModelInfoPersistable mip = (ModelInfoPersistable)per; @@ -663,16 +723,21 @@ public class ArbiterModule implements UIModule { /** * Get summary status information: first section in the page + * @param sessionId session ID (optional, for multi-session mode) + * @param rc routing context */ - private void getSummaryStatus(RoutingContext rc){ - StatsStorage ss = knownSessionIDs.get(currentSessionID); + private void getSummaryStatus(String sessionId, RoutingContext rc){ + if (sessionId == null) { + sessionId = currentSessionID; + } + StatsStorage ss = knownSessionIDs.get(sessionId); if(ss == null){ - log.debug("getOptimizationConfig(): Session ID is unknown: {}", currentSessionID); + log.debug("getOptimizationConfig(): Session ID is unknown: {}", sessionId); rc.response().end(); return; } - Persistable p = ss.getStaticInfo(currentSessionID, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID); + Persistable p = ss.getStaticInfo(sessionId, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID); if(p == null){ log.info("No static info"); @@ -692,7 +757,7 @@ public class ArbiterModule implements UIModule { //How to get this? query all model infos... - List allModelInfoTemp = new ArrayList<>(ss.getLatestUpdateAllWorkers(currentSessionID, ARBITER_UI_TYPE_ID)); + List allModelInfoTemp = new ArrayList<>(ss.getLatestUpdateAllWorkers(sessionId, ARBITER_UI_TYPE_ID)); List allModelInfo = new ArrayList<>(); for(Persistable per : allModelInfoTemp){ ModelInfoPersistable mip = (ModelInfoPersistable)per; diff --git a/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java b/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java index cb2374d64..025ce85c6 100644 --- a/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java +++ b/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java @@ -160,7 +160,7 @@ public class TestBasic extends BaseDL4JTest { UIServer uIServer = UIServer.getInstance(true, statsStorageProvider); String serverAddress = uIServer.getAddress(); for (String sessionId : statsStorageForSession.keySet()) { - log.info("Arbiter session will start at {}/arbiter/{}", serverAddress, sessionId); + log.info("Arbiter session can be attached at {}/arbiter/{}", serverAddress, sessionId); } runner.execute();