Merge pull request #8872 from printomi/ui_multisession_arbiter
DL4J UI: Add multi-session support for Arbiter
This commit is contained in:
		
						commit
						a76f957b72
					
				| @ -36,6 +36,7 @@ import org.deeplearning4j.arbiter.ui.data.ModelInfoPersistable; | |||||||
| import org.deeplearning4j.arbiter.ui.misc.UIUtils; | import org.deeplearning4j.arbiter.ui.misc.UIUtils; | ||||||
| import org.deeplearning4j.arbiter.util.ObjectUtils; | import org.deeplearning4j.arbiter.util.ObjectUtils; | ||||||
| import org.deeplearning4j.nn.conf.serde.JsonMappers; | import org.deeplearning4j.nn.conf.serde.JsonMappers; | ||||||
|  | import org.deeplearning4j.ui.VertxUIServer; | ||||||
| import org.deeplearning4j.ui.api.Component; | import org.deeplearning4j.ui.api.Component; | ||||||
| import org.deeplearning4j.ui.api.*; | import org.deeplearning4j.ui.api.*; | ||||||
| import org.deeplearning4j.ui.components.chart.ChartLine; | import org.deeplearning4j.ui.components.chart.ChartLine; | ||||||
| @ -50,6 +51,7 @@ import org.deeplearning4j.ui.components.text.style.StyleText; | |||||||
| import org.deeplearning4j.ui.i18n.I18NResource; | import org.deeplearning4j.ui.i18n.I18NResource; | ||||||
| import org.joda.time.format.DateTimeFormat; | import org.joda.time.format.DateTimeFormat; | ||||||
| import org.joda.time.format.DateTimeFormatter; | import org.joda.time.format.DateTimeFormatter; | ||||||
|  | import org.nd4j.common.function.Function; | ||||||
| import org.nd4j.common.primitives.Pair; | import org.nd4j.common.primitives.Pair; | ||||||
| import org.nd4j.shade.jackson.core.JsonProcessingException; | import org.nd4j.shade.jackson.core.JsonProcessingException; | ||||||
| 
 | 
 | ||||||
| @ -77,7 +79,6 @@ public class ArbiterModule implements UIModule { | |||||||
| 
 | 
 | ||||||
|     private Map<String, Long> lastUpdateForSession = Collections.synchronizedMap(new HashMap<>()); |     private Map<String, Long> lastUpdateForSession = Collections.synchronizedMap(new HashMap<>()); | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
|     //Styles for UI: |     //Styles for UI: | ||||||
|     private static final StyleTable STYLE_TABLE = new StyleTable.Builder() |     private static final StyleTable STYLE_TABLE = new StyleTable.Builder() | ||||||
|             .width(100, LengthUnit.Percent) |             .width(100, LengthUnit.Percent) | ||||||
| @ -134,20 +135,135 @@ public class ArbiterModule implements UIModule { | |||||||
| 
 | 
 | ||||||
|     @Override |     @Override | ||||||
|     public List<Route> getRoutes() { |     public List<Route> getRoutes() { | ||||||
|         Route r1 = new Route("/arbiter", HttpMethod.GET, (path, rc) -> rc.response() |         boolean multiSession = VertxUIServer.getMultiSession().get(); | ||||||
|                 .putHeader("content-type", "text/html; charset=utf-8").sendFile("templates/ArbiterUI.html")); |         List<Route> r = new ArrayList<>(); | ||||||
|         Route r3 = new Route("/arbiter/lastUpdate", HttpMethod.GET, (path, rc) -> this.getLastUpdateTime(rc)); |         r.add(new Route("/arbiter/multisession", HttpMethod.GET, | ||||||
|         Route r4 = new Route("/arbiter/lastUpdate/:ids", HttpMethod.GET, (path, rc) -> this.getModelLastUpdateTimes(path.get(0), rc)); |                 (path, rc) -> rc.response().end(multiSession ? "true" : "false"))); | ||||||
|         Route r5 = new Route("/arbiter/candidateInfo/:id", HttpMethod.GET, (path, rc) -> this.getCandidateInfo(path.get(0), rc)); |         if (multiSession) { | ||||||
|         Route r6 = new Route("/arbiter/config", HttpMethod.GET, (path, rc) -> this.getOptimizationConfig(rc)); |             r.add(new Route("/arbiter", HttpMethod.GET, (path, rc) -> this.listSessions(rc))); | ||||||
|         Route r7 = new Route("/arbiter/results", HttpMethod.GET, (path, rc) -> this.getSummaryResults(rc)); |             r.add(new Route("/arbiter/:sessionId", HttpMethod.GET, (path, rc) -> { | ||||||
|         Route r8 = new Route("/arbiter/summary", HttpMethod.GET, (path, rc) -> this.getSummaryStatus(rc)); |                 if (knownSessionIDs.containsKey(path.get(0))) { | ||||||
|  |                     rc.response() | ||||||
|  |                             .putHeader("content-type", "text/html; charset=utf-8") | ||||||
|  |                             .sendFile("templates/ArbiterUI.html"); | ||||||
|  |                 } else { | ||||||
|  |                     sessionNotFound(path.get(0), rc.request().path(), rc); | ||||||
|  |                 } | ||||||
|  |             })); | ||||||
| 
 | 
 | ||||||
|         Route r9a = new Route("/arbiter/sessions/all", HttpMethod.GET, (path, rc) -> this.listSessions(rc)); |             r.add(new Route("/arbiter/:sessionId/lastUpdate", HttpMethod.GET, (path, rc) -> { | ||||||
|         Route r9b = new Route("/arbiter/sessions/current", HttpMethod.GET, (path, rc) -> this.currentSession(rc)); |                 if (knownSessionIDs.containsKey(path.get(0))) { | ||||||
|         Route r9c = new Route("/arbiter/sessions/set/:to", HttpMethod.GET, (path, rc) -> this.setSession(path.get(0), rc)); |                     this.getLastUpdateTime(path.get(0), rc); | ||||||
|  |                 } else { | ||||||
|  |                     sessionNotFound(path.get(0), rc.request().path(), rc); | ||||||
|  |                 } | ||||||
|  |             })); | ||||||
|  |             r.add(new Route("/arbiter/:sessionId/candidateInfo/:id", HttpMethod.GET, (path, rc) -> { | ||||||
|  |                 if (knownSessionIDs.containsKey(path.get(0))) { | ||||||
|  |                     this.getCandidateInfo(path.get(0), path.get(1), rc); | ||||||
|  |                 } else { | ||||||
|  |                     sessionNotFound(path.get(0), rc.request().path(), rc); | ||||||
|  |                 } | ||||||
|  |             })); | ||||||
|  |             r.add(new Route("/arbiter/:sessionId/config", HttpMethod.GET, (path, rc) -> { | ||||||
|  |                 if (knownSessionIDs.containsKey(path.get(0))) { | ||||||
|  |                     this.getOptimizationConfig(path.get(0), rc); | ||||||
|  |                 } else { | ||||||
|  |                     sessionNotFound(path.get(0), rc.request().path(), rc); | ||||||
|  |                 } | ||||||
|  |             })); | ||||||
|  |             r.add(new Route("/arbiter/:sessionId/results", HttpMethod.GET, (path, rc) -> { | ||||||
|  |                 if (knownSessionIDs.containsKey(path.get(0))) { | ||||||
|  |                     this.getSummaryResults(path.get(0), rc); | ||||||
|  |                 } else { | ||||||
|  |                     sessionNotFound(path.get(0), rc.request().path(), rc); | ||||||
|  |                 } | ||||||
|  |             })); | ||||||
|  |             r.add(new Route("/arbiter/:sessionId/summary", HttpMethod.GET, (path, rc) -> { | ||||||
|  |                 if (knownSessionIDs.containsKey(path.get(0))) { | ||||||
|  |                     this.getSummaryStatus(path.get(0), rc); | ||||||
|  |                 } else { | ||||||
|  |                     sessionNotFound(path.get(0), rc.request().path(), rc); | ||||||
|  |                 } | ||||||
|  |             })); | ||||||
|  |         } else { | ||||||
|  |             r.add(new Route("/arbiter", HttpMethod.GET, (path, rc) -> rc.response() | ||||||
|  |                     .putHeader("content-type", "text/html; charset=utf-8") | ||||||
|  |                     .sendFile("templates/ArbiterUI.html"))); | ||||||
|  |             r.add(new Route("/arbiter/lastUpdate", HttpMethod.GET, (path, rc) -> this.getLastUpdateTime(null, rc))); | ||||||
|  |             r.add(new Route("/arbiter/candidateInfo/:id", HttpMethod.GET, | ||||||
|  |                     (path, rc) -> this.getCandidateInfo(null, path.get(0), rc))); | ||||||
|  |             r.add(new Route("/arbiter/config", HttpMethod.GET, (path, rc) -> this.getOptimizationConfig(null, rc))); | ||||||
|  |             r.add(new Route("/arbiter/results", HttpMethod.GET, (path, rc) -> this.getSummaryResults(null, rc))); | ||||||
|  |             r.add(new Route("/arbiter/summary", HttpMethod.GET, (path, rc) -> this.getSummaryStatus(null, rc))); | ||||||
| 
 | 
 | ||||||
|         return Arrays.asList(r1, r3, r4, r5, r6, r7, r8, r9a, r9b, r9c); |             r.add(new Route("/arbiter/sessions/current", HttpMethod.GET, (path, rc) -> this.currentSession(rc))); | ||||||
|  |             r.add(new Route("/arbiter/sessions/set/:to", HttpMethod.GET, | ||||||
|  |                     (path, rc) -> this.setSession(path.get(0), rc))); | ||||||
|  |         } | ||||||
|  |         // common for single- and multi-session mode | ||||||
|  |         r.add(new Route("/arbiter/sessions/all", HttpMethod.GET, (path, rc) -> this.sessionInfo(rc))); | ||||||
|  | 
 | ||||||
|  |         return r; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |     /** | ||||||
|  |      * Load StatsStorage via provider, or return "not found" | ||||||
|  |      * | ||||||
|  |      * @param sessionId  session ID to look fo with provider | ||||||
|  |      * @param targetPath one of overview / model / system, or null | ||||||
|  |      * @param rc routing context | ||||||
|  |      */ | ||||||
|  |     private void sessionNotFound(String sessionId, String targetPath, RoutingContext rc) { | ||||||
|  |         Function<String, Boolean> loader = VertxUIServer.getInstance().getStatsStorageLoader(); | ||||||
|  |         if (loader != null && loader.apply(sessionId)) { | ||||||
|  |             if (targetPath != null) { | ||||||
|  |                 rc.reroute(targetPath); | ||||||
|  |             } else { | ||||||
|  |                 rc.response().end(); | ||||||
|  |             } | ||||||
|  |         } else { | ||||||
|  |             rc.response().setStatusCode(HttpResponseStatus.NOT_FOUND.code()) | ||||||
|  |                     .end("Unknown session ID: " + sessionId); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |     /** | ||||||
|  |      * List optimization sessions. Returns a HTML list of arbiter sessions | ||||||
|  |      */ | ||||||
|  |     private synchronized void listSessions(RoutingContext rc) { | ||||||
|  |         StringBuilder sb = new StringBuilder("<!DOCTYPE html>\n" + | ||||||
|  |                 "<html lang=\"en\">\n" + | ||||||
|  |                 "<head>\n" + | ||||||
|  |                 "        <meta charset=\"utf-8\">\n" + | ||||||
|  |                 "        <title>Optimization sessions - DL4J Arbiter UI</title>\n" + | ||||||
|  |                 "    </head>\n" + | ||||||
|  |                 "\n" + | ||||||
|  |                 "    <body>\n" + | ||||||
|  |                 "        <h1>DL4J Arbiter UI</h1>\n" + | ||||||
|  |                 "        <p>UI server is in multi-session mode." + | ||||||
|  |                 " To visualize an optimization session, please select one from the following list.</p>\n" + | ||||||
|  |                 "        <h2>List of attached optimization sessions</h2>\n"); | ||||||
|  |         if (!knownSessionIDs.isEmpty()) { | ||||||
|  |             sb.append("        <ul>"); | ||||||
|  |             for (String sessionId : knownSessionIDs.keySet()) { | ||||||
|  |                 sb.append("            <li><a href=\"/arbiter/") | ||||||
|  |                         .append(sessionId).append("\">") | ||||||
|  |                         .append(sessionId).append("</a></li>\n"); | ||||||
|  |             } | ||||||
|  |             sb.append("        </ul>"); | ||||||
|  |         } else { | ||||||
|  |             sb.append("No optimization session attached."); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         sb.append("    </body>\n" + | ||||||
|  |                 "</html>\n"); | ||||||
|  | 
 | ||||||
|  |         rc.response() | ||||||
|  |                 .putHeader("content-type", "text/html; charset=utf-8") | ||||||
|  |                 .end(sb.toString()); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Override |     @Override | ||||||
| @ -201,7 +317,7 @@ public class ArbiterModule implements UIModule { | |||||||
|                 .end(asJson(sid)); |                 .end(asJson(sid)); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     private void listSessions(RoutingContext rc) { |     private void sessionInfo(RoutingContext rc) { | ||||||
|         rc.response() |         rc.response() | ||||||
|                 .putHeader("content-type", "application/json") |                 .putHeader("content-type", "application/json") | ||||||
|                 .end(asJson(knownSessionIDs.keySet())); |                 .end(asJson(knownSessionIDs.keySet())); | ||||||
| @ -257,10 +373,25 @@ public class ArbiterModule implements UIModule { | |||||||
| 
 | 
 | ||||||
|     /** |     /** | ||||||
|      * Return the last update time for the page |      * Return the last update time for the page | ||||||
|  |      * @param sessionId session ID (optional, for multi-session mode) | ||||||
|  |      * @param rc routing context | ||||||
|      */ |      */ | ||||||
|     private void getLastUpdateTime(RoutingContext rc){ |     private void getLastUpdateTime(String sessionId, RoutingContext rc){ | ||||||
|         //TODO - this forces updates on every request... which is fine, just inefficient |         if (sessionId == null) { | ||||||
|         long t = System.currentTimeMillis(); |             sessionId = currentSessionID; | ||||||
|  |         } | ||||||
|  |         StatsStorage ss = knownSessionIDs.get(sessionId); | ||||||
|  |         List<Persistable> latestUpdates = ss.getLatestUpdateAllWorkers(sessionId, ARBITER_UI_TYPE_ID); | ||||||
|  |         long t = 0; | ||||||
|  |         if (latestUpdates.isEmpty()) { | ||||||
|  |             t = System.currentTimeMillis(); | ||||||
|  |         } else { | ||||||
|  |             for (Persistable update : latestUpdates) { | ||||||
|  |                 if (update.getTimeStamp() > t) { | ||||||
|  |                     t = update.getTimeStamp(); | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|         UpdateStatus us = new UpdateStatus(t, t, t); |         UpdateStatus us = new UpdateStatus(t, t, t); | ||||||
| 
 | 
 | ||||||
|         rc.response().putHeader("content-type", "application/json").end(asJson(us)); |         rc.response().putHeader("content-type", "application/json").end(asJson(us)); | ||||||
| @ -274,56 +405,28 @@ public class ArbiterModule implements UIModule { | |||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /** |  | ||||||
|      * Get the last update time for the specified model IDs |  | ||||||
|      * @param modelIDs Model IDs to get the update time for |  | ||||||
|      */ |  | ||||||
|     private void getModelLastUpdateTimes(String modelIDs, RoutingContext rc){ |  | ||||||
|         if(currentSessionID == null){ |  | ||||||
|             rc.response().end(); |  | ||||||
|             return; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         StatsStorage ss = knownSessionIDs.get(currentSessionID); |  | ||||||
|         if(ss == null){ |  | ||||||
|             log.debug("getModelLastUpdateTimes(): Session ID is unknown: {}", currentSessionID); |  | ||||||
|             rc.response().end("-1"); |  | ||||||
|             return; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         String[] split = modelIDs.split(","); |  | ||||||
| 
 |  | ||||||
|         long[] lastUpdateTimes = new long[split.length]; |  | ||||||
|         for( int i=0; i<split.length; i++ ){ |  | ||||||
|             String s = split[i]; |  | ||||||
|             Persistable p = ss.getLatestUpdate(currentSessionID, ARBITER_UI_TYPE_ID, s); |  | ||||||
|             if(p != null){ |  | ||||||
|                 lastUpdateTimes[i] = p.getTimeStamp(); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         rc.response().putHeader("content-type", "application/json").end(asJson(lastUpdateTimes)); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |     /** | ||||||
|      * Get the info for a specific candidate - last section in the UI |      * Get the info for a specific candidate - last section in the UI | ||||||
|      * |      * @param sessionId session ID (optional, for multi-session mode) | ||||||
|      * @param candidateId ID for the candidate |      * @param candidateId ID for the candidate | ||||||
|      * @return Content/info for the candidate |      * @param rc routing context | ||||||
|      */ |      */ | ||||||
|     private void getCandidateInfo(String candidateId, RoutingContext rc){ |     private void getCandidateInfo(String sessionId, String candidateId, RoutingContext rc){ | ||||||
| 
 |         if (sessionId == null) { | ||||||
|         StatsStorage ss = knownSessionIDs.get(currentSessionID); |             sessionId = currentSessionID; | ||||||
|  |         } | ||||||
|  |         StatsStorage ss = knownSessionIDs.get(sessionId); | ||||||
|         if(ss == null){ |         if(ss == null){ | ||||||
|             log.debug("getModelLastUpdateTimes(): Session ID is unknown: {}", currentSessionID); |             log.debug("getModelLastUpdateTimes(): Session ID is unknown: {}", sessionId); | ||||||
|             rc.response().end(); |             rc.response().end(); | ||||||
|             return; |             return; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         GlobalConfigPersistable gcp = (GlobalConfigPersistable)ss.getStaticInfo(currentSessionID, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID);; |         GlobalConfigPersistable gcp = (GlobalConfigPersistable)ss | ||||||
|  |                 .getStaticInfo(sessionId, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID); | ||||||
|         OptimizationConfiguration oc = gcp.getOptimizationConfiguration(); |         OptimizationConfiguration oc = gcp.getOptimizationConfiguration(); | ||||||
| 
 | 
 | ||||||
|         Persistable p = ss.getLatestUpdate(currentSessionID, ARBITER_UI_TYPE_ID, candidateId); |         Persistable p = ss.getLatestUpdate(sessionId, ARBITER_UI_TYPE_ID, candidateId); | ||||||
|         if(p == null){ |         if(p == null){ | ||||||
|             String title = "No results found for model " + candidateId + "."; |             String title = "No results found for model " + candidateId + "."; | ||||||
|             ComponentText ct = new ComponentText.Builder(title,STYLE_TEXT_SZ12).build(); |             ComponentText ct = new ComponentText.Builder(title,STYLE_TEXT_SZ12).build(); | ||||||
| @ -493,17 +596,21 @@ public class ArbiterModule implements UIModule { | |||||||
| 
 | 
 | ||||||
|     /** |     /** | ||||||
|      * Get the optimization configuration - second section in the page |      * Get the optimization configuration - second section in the page | ||||||
|  |      * @param sessionId session ID (optional, for multi-session mode) | ||||||
|  |      * @param rc routing context | ||||||
|      */ |      */ | ||||||
|     private void getOptimizationConfig(RoutingContext rc){ |     private void getOptimizationConfig(String sessionId, RoutingContext rc){ | ||||||
| 
 |         if (sessionId == null) { | ||||||
|         StatsStorage ss = knownSessionIDs.get(currentSessionID); |             sessionId = currentSessionID; | ||||||
|  |         } | ||||||
|  |         StatsStorage ss = knownSessionIDs.get(sessionId); | ||||||
|         if(ss == null){ |         if(ss == null){ | ||||||
|             log.debug("getOptimizationConfig(): Session ID is unknown: {}", currentSessionID); |             log.debug("getOptimizationConfig(): Session ID is unknown: {}", sessionId); | ||||||
|             rc.response().end(); |             rc.response().end(); | ||||||
|             return; |             return; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         Persistable p = ss.getStaticInfo(currentSessionID, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID); |         Persistable p = ss.getStaticInfo(sessionId, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID); | ||||||
| 
 | 
 | ||||||
|         if(p == null){ |         if(p == null){ | ||||||
|             log.debug("No static info"); |             log.debug("No static info"); | ||||||
| @ -593,15 +700,23 @@ public class ArbiterModule implements UIModule { | |||||||
|         rc.response().putHeader("content-type", "application/json").end(asJson(cd)); |         rc.response().putHeader("content-type", "application/json").end(asJson(cd)); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     private void getSummaryResults(RoutingContext rc){ |     /** | ||||||
|         StatsStorage ss = knownSessionIDs.get(currentSessionID); |      * Get candidates summary results list - third section on the page: Results table | ||||||
|  |      * @param sessionId session ID (optional, for multi-session mode) | ||||||
|  |      * @param rc routing context | ||||||
|  |      */ | ||||||
|  |     private void getSummaryResults(String sessionId, RoutingContext rc){ | ||||||
|  |         if (sessionId == null) { | ||||||
|  |             sessionId = currentSessionID; | ||||||
|  |         } | ||||||
|  |         StatsStorage ss = knownSessionIDs.get(sessionId); | ||||||
|         if(ss == null){ |         if(ss == null){ | ||||||
|             log.debug("getSummaryResults(): Session ID is unknown: {}", currentSessionID); |             log.debug("getSummaryResults(): Session ID is unknown: {}", sessionId); | ||||||
|             rc.response().end(); |             rc.response().end(); | ||||||
|             return; |             return; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         List<Persistable> allModelInfoTemp = new ArrayList<>(ss.getLatestUpdateAllWorkers(currentSessionID, ARBITER_UI_TYPE_ID)); |         List<Persistable> allModelInfoTemp = new ArrayList<>(ss.getLatestUpdateAllWorkers(sessionId, ARBITER_UI_TYPE_ID)); | ||||||
|         List<String[]> table = new ArrayList<>(); |         List<String[]> table = new ArrayList<>(); | ||||||
|         for(Persistable per : allModelInfoTemp){ |         for(Persistable per : allModelInfoTemp){ | ||||||
|             ModelInfoPersistable mip = (ModelInfoPersistable)per; |             ModelInfoPersistable mip = (ModelInfoPersistable)per; | ||||||
| @ -614,16 +729,21 @@ public class ArbiterModule implements UIModule { | |||||||
| 
 | 
 | ||||||
|     /** |     /** | ||||||
|      * Get summary status information: first section in the page |      * Get summary status information: first section in the page | ||||||
|  |      * @param sessionId session ID (optional, for multi-session mode) | ||||||
|  |      * @param rc routing context | ||||||
|      */ |      */ | ||||||
|     private void getSummaryStatus(RoutingContext rc){ |     private void getSummaryStatus(String sessionId, RoutingContext rc){ | ||||||
|         StatsStorage ss = knownSessionIDs.get(currentSessionID); |         if (sessionId == null) { | ||||||
|  |             sessionId = currentSessionID; | ||||||
|  |         } | ||||||
|  |         StatsStorage ss = knownSessionIDs.get(sessionId); | ||||||
|         if(ss == null){ |         if(ss == null){ | ||||||
|             log.debug("getOptimizationConfig(): Session ID is unknown: {}", currentSessionID); |             log.debug("getOptimizationConfig(): Session ID is unknown: {}", sessionId); | ||||||
|             rc.response().end(); |             rc.response().end(); | ||||||
|             return; |             return; | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         Persistable p = ss.getStaticInfo(currentSessionID, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID); |         Persistable p = ss.getStaticInfo(sessionId, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID); | ||||||
| 
 | 
 | ||||||
|         if(p == null){ |         if(p == null){ | ||||||
|             log.info("No static info"); |             log.info("No static info"); | ||||||
| @ -643,7 +763,7 @@ public class ArbiterModule implements UIModule { | |||||||
| 
 | 
 | ||||||
|         //How to get this? query all model infos... |         //How to get this? query all model infos... | ||||||
| 
 | 
 | ||||||
|         List<Persistable> allModelInfoTemp = new ArrayList<>(ss.getLatestUpdateAllWorkers(currentSessionID, ARBITER_UI_TYPE_ID)); |         List<Persistable> allModelInfoTemp = new ArrayList<>(ss.getLatestUpdateAllWorkers(sessionId, ARBITER_UI_TYPE_ID)); | ||||||
|         List<ModelInfoPersistable> allModelInfo = new ArrayList<>(); |         List<ModelInfoPersistable> allModelInfo = new ArrayList<>(); | ||||||
|         for(Persistable per : allModelInfoTemp){ |         for(Persistable per : allModelInfoTemp){ | ||||||
|             ModelInfoPersistable mip = (ModelInfoPersistable)per; |             ModelInfoPersistable mip = (ModelInfoPersistable)per; | ||||||
| @ -668,7 +788,6 @@ public class ArbiterModule implements UIModule { | |||||||
| 
 | 
 | ||||||
|         //TODO: I18N |         //TODO: I18N | ||||||
| 
 | 
 | ||||||
|         //TODO don't use currentTimeMillis due to stored data?? |  | ||||||
|         long bestTime; |         long bestTime; | ||||||
|         Double bestScore = null; |         Double bestScore = null; | ||||||
|         String bestModelString = null; |         String bestModelString = null; | ||||||
| @ -685,7 +804,12 @@ public class ArbiterModule implements UIModule { | |||||||
|         String execTotalRuntimeStr = ""; |         String execTotalRuntimeStr = ""; | ||||||
|         if(execStartTime > 0){ |         if(execStartTime > 0){ | ||||||
|             execStartTimeStr = TIME_FORMATTER.print(execStartTime); |             execStartTimeStr = TIME_FORMATTER.print(execStartTime); | ||||||
|             execTotalRuntimeStr = UIUtils.formatDuration(System.currentTimeMillis() - execStartTime); |             // allModelInfo is sorted by Persistable::getTimeStamp | ||||||
|  |             long lastCompleteTime = execStartTime; | ||||||
|  |             if (!allModelInfo.isEmpty()) { | ||||||
|  |                 lastCompleteTime = allModelInfo.get(allModelInfo.size() - 1).getTimeStamp(); | ||||||
|  |             } | ||||||
|  |             execTotalRuntimeStr = UIUtils.formatDuration(lastCompleteTime - execStartTime); | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -198,27 +198,38 @@ | |||||||
| 
 | 
 | ||||||
|     var selectedCandidateIdx = null; |     var selectedCandidateIdx = null; | ||||||
|      |      | ||||||
|     //Set basic interval function to do updates |     //Multi-session mode | ||||||
|     setInterval(doUpdate,5000);    //Loop every 5 seconds |     var multiSession = null; | ||||||
|  |     //Session selection | ||||||
|  |     var currSession = ""; | ||||||
|      |      | ||||||
|  |     function getSessionIdFromUrl() { | ||||||
|  |         // path is like /arbiter/:sessionId/overview | ||||||
|  |         var sessionIdRegexp = /\/arbiter\/([^\/]+)/g; | ||||||
|  |         var match = sessionIdRegexp.exec(window.location.pathname) | ||||||
|  |         return match[1]; | ||||||
|  |     } | ||||||
| 
 | 
 | ||||||
|     function doUpdate(){ |     function getCurrSession(callback) { | ||||||
|         //Get the update status, and do something with it: |         if (multiSession) { | ||||||
|         $.get("/arbiter/lastUpdate",function(data){ |             if (currSession == "") { | ||||||
|             //Encoding: matches names in UpdateStatus class |                 // get only once | ||||||
|             var jsonObj = JSON.parse(JSON.stringify(data)); |                 currSession = getSessionIdFromUrl(); | ||||||
|             var statusTime = jsonObj['statusUpdateTime']; |             } | ||||||
|             var settingsTime = jsonObj['settingsUpdateTime']; |             //we don't show session selector in multi-session mode (one can list sessions at /arbiter) | ||||||
|             var resultsTime = jsonObj['resultsUpdateTime']; |             callback(); | ||||||
|             //console.log("Last update times: " + statusTime + ", " + settingsTime + ", " + resultsTime); |         } else { | ||||||
| 
 |             $.ajax({ | ||||||
|             //Update available sessions: |                 url: "/arbiter/sessions/current", | ||||||
|             var currSession; |                 async: true, | ||||||
|             $.get("/arbiter/sessions/current", function(data){ |                 error: function (query, status, error) { | ||||||
|                 currSession = data; //JSON.stringify(data); |                     console.log("Error getting data: " + error); | ||||||
|                 console.log("Current: " + currSession); |                 }, | ||||||
|             }); |                 success: function (data) { | ||||||
|  |                     currSession = data; | ||||||
|  |                     console.log("Current session: " + currSession); | ||||||
|                      |                      | ||||||
|  |                     //Update available sessions in session selector | ||||||
|                     $.get("/arbiter/sessions/all", function(data){ |                     $.get("/arbiter/sessions/all", function(data){ | ||||||
|                         var keys = data;    // JSON.stringify(data); |                         var keys = data;    // JSON.stringify(data); | ||||||
| 
 | 
 | ||||||
| @ -240,14 +251,55 @@ | |||||||
|                             $("#sessionSelectDiv").show(); |                             $("#sessionSelectDiv").show(); | ||||||
|                         } |                         } | ||||||
|         //                console.log("Got sessions: " + keys + ", current: " + currSession); |         //                console.log("Got sessions: " + keys + ", current: " + currSession); | ||||||
|  |                         callback(); | ||||||
|                     }); |                     }); | ||||||
|                      |                      | ||||||
|  |                 } | ||||||
|  |             }); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |      | ||||||
|  |     function getSessionSettings(callback) { | ||||||
|  |         // load only once | ||||||
|  |         if (multiSession != null) { | ||||||
|  |             getCurrSession(callback); | ||||||
|  |         } else { | ||||||
|  |             $.ajax({ | ||||||
|  |                 url: "/arbiter/multisession", | ||||||
|  |                 async: true, | ||||||
|  |                 error: function (query, status, error) { | ||||||
|  |                     console.log("Error getting data: " + error); | ||||||
|  |                 }, | ||||||
|  |                 success: function (data) { | ||||||
|  |                     multiSession = data == "true"; | ||||||
|  |                     getCurrSession(callback); | ||||||
|  |                 } | ||||||
|  |             }); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |      | ||||||
|  |     //Initial update | ||||||
|  |     doUpdate(); | ||||||
|  |     //Set basic interval function to do updates | ||||||
|  |     setInterval(doUpdate,5000);    //Loop every 5 seconds | ||||||
|  |      | ||||||
|  |     function doUpdate(){ | ||||||
|  |         //Get the update status, and do something with it: | ||||||
|  |         getSessionSettings(function(){ | ||||||
|  |             var sessionUpdateUrl = multiSession ? "/arbiter/" + currSession + "/lastUpdate" : "/arbiter/lastUpdate"; | ||||||
|  |             $.get(sessionUpdateUrl,function(data){ | ||||||
|  |                 //Encoding: matches names in UpdateStatus class | ||||||
|  |                 var jsonObj = JSON.parse(JSON.stringify(data)); | ||||||
|  |                 var statusTime = jsonObj['statusUpdateTime']; | ||||||
|  |                 var settingsTime = jsonObj['settingsUpdateTime']; | ||||||
|  |                 var resultsTime = jsonObj['resultsUpdateTime']; | ||||||
|  |                 //console.log("Last update times: " + statusTime + ", " + settingsTime + ", " + resultsTime); | ||||||
| 
 | 
 | ||||||
|                 //Check last update times for each part of document, and update as necessary |                 //Check last update times for each part of document, and update as necessary | ||||||
|                 //First section: summary status |                 //First section: summary status | ||||||
|                 if(lastStatusUpdateTime != statusTime){ |                 if(lastStatusUpdateTime != statusTime){ | ||||||
|                 //Get JSON: address set by SummaryStatusResource |                     var summaryStatusUrl = multiSession ? "/arbiter/" + currSession + "/summary" : "/arbiter/summary"; | ||||||
|                 $.get("/arbiter/summary",function(data){ |                     $.get(summaryStatusUrl,function(data){ | ||||||
|                         var summaryStatusDiv = $('#statusdiv'); |                         var summaryStatusDiv = $('#statusdiv'); | ||||||
|                         summaryStatusDiv.html(''); |                         summaryStatusDiv.html(''); | ||||||
| 
 | 
 | ||||||
| @ -262,7 +314,8 @@ | |||||||
|                 //Second section: Optimization settings |                 //Second section: Optimization settings | ||||||
|                 if(lastSettingsUpdateTime != settingsTime){ |                 if(lastSettingsUpdateTime != settingsTime){ | ||||||
|                     //Get JSON for components |                     //Get JSON for components | ||||||
|                 $.get("/arbiter/config",function(data){ |                     var settingsUrl = multiSession ? "/arbiter/" + currSession + "/config" : "/arbiter/config"; | ||||||
|  |                     $.get(settingsUrl,function(data){ | ||||||
|                         var str = JSON.stringify(data); |                         var str = JSON.stringify(data); | ||||||
| 
 | 
 | ||||||
|                         var configDiv = $('#settingsdiv'); |                         var configDiv = $('#settingsdiv'); | ||||||
| @ -277,9 +330,9 @@ | |||||||
| 
 | 
 | ||||||
|                 //Third section: Summary results table (summary info for each candidate) |                 //Third section: Summary results table (summary info for each candidate) | ||||||
|                 if(lastResultsUpdateTime != resultsTime){ |                 if(lastResultsUpdateTime != resultsTime){ | ||||||
| 
 |                     //Get JSON for results table | ||||||
|                 //Get JSON; address set by SummaryResultsResource |                     var resultsUrl = multiSession ? "/arbiter/" + currSession + "/results" : "/arbiter/results"; | ||||||
|                 $.get("/arbiter/results",function(data){ |                     $.get(resultsUrl,function(data){ | ||||||
|                         //Expect an array of CandidateInfo type objects here |                         //Expect an array of CandidateInfo type objects here | ||||||
|                         resultsTableContent = data; |                         resultsTableContent = data; | ||||||
|                         drawResultTable(); |                         drawResultTable(); | ||||||
| @ -291,7 +344,10 @@ | |||||||
|                 //Finally: Currently selected result |                 //Finally: Currently selected result | ||||||
|                 if(selectedCandidateIdx != null){ |                 if(selectedCandidateIdx != null){ | ||||||
|                     //Get JSON for components |                     //Get JSON for components | ||||||
|                 $.get("/arbiter/candidateInfo/"+selectedCandidateIdx,function(data){ |                     var candidateInfoUrl = multiSession  | ||||||
|  |                         ? "/arbiter/" + currSession + "/candidateInfo/" + selectedCandidateIdx | ||||||
|  |                         : "/arbiter/candidateInfo/" + selectedCandidateIdx; | ||||||
|  |                     $.get(candidateInfoUrl,function(data){ | ||||||
|                         var str = JSON.stringify(data); |                         var str = JSON.stringify(data); | ||||||
| 
 | 
 | ||||||
|                         var resultsViewDiv = $('#resultsviewdiv'); |                         var resultsViewDiv = $('#resultsviewdiv'); | ||||||
| @ -302,6 +358,7 @@ | |||||||
|                     }); |                     }); | ||||||
|                 } |                 } | ||||||
|             }) |             }) | ||||||
|  |         }) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     function createTable(tableObj,tableId,appendTo){ |     function createTable(tableObj,tableId,appendTo){ | ||||||
|  | |||||||
| @ -16,6 +16,9 @@ | |||||||
| 
 | 
 | ||||||
| package org.deeplearning4j.arbiter.optimize; | package org.deeplearning4j.arbiter.optimize; | ||||||
| 
 | 
 | ||||||
|  | import io.netty.handler.codec.http.HttpResponseStatus; | ||||||
|  | import lombok.NonNull; | ||||||
|  | import lombok.extern.slf4j.Slf4j; | ||||||
| import org.deeplearning4j.BaseDL4JTest; | import org.deeplearning4j.BaseDL4JTest; | ||||||
| import org.deeplearning4j.core.storage.StatsStorage; | import org.deeplearning4j.core.storage.StatsStorage; | ||||||
| import org.deeplearning4j.arbiter.ComputationGraphSpace; | import org.deeplearning4j.arbiter.ComputationGraphSpace; | ||||||
| @ -54,6 +57,7 @@ import org.deeplearning4j.ui.api.UIServer; | |||||||
| import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage; | import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage; | ||||||
| import org.junit.Ignore; | import org.junit.Ignore; | ||||||
| import org.junit.Test; | import org.junit.Test; | ||||||
|  | import org.nd4j.common.function.Function; | ||||||
| import org.nd4j.evaluation.classification.Evaluation; | import org.nd4j.evaluation.classification.Evaluation; | ||||||
| import org.nd4j.linalg.activations.Activation; | import org.nd4j.linalg.activations.Activation; | ||||||
| import org.nd4j.linalg.api.buffer.DataType; | import org.nd4j.linalg.api.buffer.DataType; | ||||||
| @ -62,52 +66,43 @@ import org.nd4j.linalg.factory.Nd4j; | |||||||
| import org.nd4j.linalg.lossfunctions.LossFunctions; | import org.nd4j.linalg.lossfunctions.LossFunctions; | ||||||
| 
 | 
 | ||||||
| import java.io.File; | import java.io.File; | ||||||
| import java.util.Collections; | import java.io.IOException; | ||||||
| import java.util.HashMap; | import java.io.UnsupportedEncodingException; | ||||||
| import java.util.Map; | import java.net.HttpURLConnection; | ||||||
| import java.util.Properties; | import java.net.URL; | ||||||
|  | import java.net.URLEncoder; | ||||||
|  | import java.util.*; | ||||||
| import java.util.concurrent.TimeUnit; | import java.util.concurrent.TimeUnit; | ||||||
| 
 | 
 | ||||||
|  | import static org.junit.Assert.assertEquals; | ||||||
|  | import static org.junit.Assert.assertTrue; | ||||||
|  | 
 | ||||||
| /** | /** | ||||||
|  * Created by Alex on 19/07/2017. |  * Created by Alex on 19/07/2017. | ||||||
|  */ |  */ | ||||||
|  | @Slf4j | ||||||
| public class TestBasic extends BaseDL4JTest { | public class TestBasic extends BaseDL4JTest { | ||||||
| 
 | 
 | ||||||
|  |     @Override | ||||||
|  |     public long getTimeoutMilliseconds() { | ||||||
|  |         return 3600_000L; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     @Test |     @Test | ||||||
|     @Ignore |     @Ignore | ||||||
|     public void testBasicUiOnly() throws Exception { |     public void testBasicUiOnly() throws Exception { | ||||||
| 
 | 
 | ||||||
|         UIServer.getInstance(); |         UIServer.getInstance(); | ||||||
| 
 | 
 | ||||||
|         Thread.sleep(1000000); |         Thread.sleep(1000_000); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
|     @Test |     @Test | ||||||
|     @Ignore |     @Ignore | ||||||
|     public void testBasicMnist() throws Exception { |     public void testBasicMnist() throws Exception { | ||||||
|         Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); |         Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); | ||||||
| 
 | 
 | ||||||
|         MultiLayerSpace mls = new MultiLayerSpace.Builder() |         MultiLayerSpace mls = getMultiLayerSpaceMnist(); | ||||||
|                 .updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.2))) |  | ||||||
|                 .l2(new ContinuousParameterSpace(0.0001, 0.05)) |  | ||||||
|                 .addLayer( |  | ||||||
|                         new ConvolutionLayerSpace.Builder().nIn(1) |  | ||||||
|                                 .nOut(new IntegerParameterSpace(5, 30)) |  | ||||||
|                                 .kernelSize(new DiscreteParameterSpace<>(new int[]{3, 3}, |  | ||||||
|                                         new int[]{4, 4}, new int[]{5, 5})) |  | ||||||
|                                 .stride(new DiscreteParameterSpace<>(new int[]{1, 1}, |  | ||||||
|                                         new int[]{2, 2})) |  | ||||||
|                                 .activation(new DiscreteParameterSpace<>(Activation.RELU, |  | ||||||
|                                         Activation.SOFTPLUS, Activation.LEAKYRELU)) |  | ||||||
|                                 .build()) |  | ||||||
|                 .addLayer(new DenseLayerSpace.Builder().nOut(new IntegerParameterSpace(32, 128)) |  | ||||||
|                         .activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH)) |  | ||||||
|                         .build(), new IntegerParameterSpace(0, 1), true) //0 to 1 layers |  | ||||||
|                 .addLayer(new OutputLayerSpace.Builder().nOut(10).activation(Activation.SOFTMAX) |  | ||||||
|                         .lossFunction(LossFunctions.LossFunction.MCXENT).build()) |  | ||||||
|                 .setInputType(InputType.convolutionalFlat(28, 28, 1)) |  | ||||||
|                 .build(); |  | ||||||
|         Map<String, Object> commands = new HashMap<>(); |         Map<String, Object> commands = new HashMap<>(); | ||||||
| //        commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, TestDataFactoryProviderMnist.class.getCanonicalName()); | //        commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, TestDataFactoryProviderMnist.class.getCanonicalName()); | ||||||
| 
 | 
 | ||||||
| @ -144,7 +139,30 @@ public class TestBasic extends BaseDL4JTest { | |||||||
|         UIServer.getInstance().attach(ss); |         UIServer.getInstance().attach(ss); | ||||||
| 
 | 
 | ||||||
|         runner.execute(); |         runner.execute(); | ||||||
|         Thread.sleep(100000); |         Thread.sleep(1000_000); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     private static MultiLayerSpace getMultiLayerSpaceMnist() { | ||||||
|  |         return new MultiLayerSpace.Builder() | ||||||
|  |                 .updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.2))) | ||||||
|  |                 .l2(new ContinuousParameterSpace(0.0001, 0.05)) | ||||||
|  |                 .addLayer( | ||||||
|  |                         new ConvolutionLayerSpace.Builder().nIn(1) | ||||||
|  |                                 .nOut(new IntegerParameterSpace(5, 30)) | ||||||
|  |                                 .kernelSize(new DiscreteParameterSpace<>(new int[]{3, 3}, | ||||||
|  |                                         new int[]{4, 4}, new int[]{5, 5})) | ||||||
|  |                                 .stride(new DiscreteParameterSpace<>(new int[]{1, 1}, | ||||||
|  |                                         new int[]{2, 2})) | ||||||
|  |                                 .activation(new DiscreteParameterSpace<>(Activation.RELU, | ||||||
|  |                                         Activation.SOFTPLUS, Activation.LEAKYRELU)) | ||||||
|  |                                 .build()) | ||||||
|  |                 .addLayer(new DenseLayerSpace.Builder().nOut(new IntegerParameterSpace(32, 128)) | ||||||
|  |                         .activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH)) | ||||||
|  |                         .build(), new IntegerParameterSpace(0, 1), true) //0 to 1 layers | ||||||
|  |                 .addLayer(new OutputLayerSpace.Builder().nOut(10).activation(Activation.SOFTMAX) | ||||||
|  |                         .lossFunction(LossFunctions.LossFunction.MCXENT).build()) | ||||||
|  |                 .setInputType(InputType.convolutionalFlat(28, 28, 1)) | ||||||
|  |                 .build(); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Test |     @Test | ||||||
| @ -233,7 +251,7 @@ public class TestBasic extends BaseDL4JTest { | |||||||
|                 .build(); |                 .build(); | ||||||
| 
 | 
 | ||||||
|         //Define configuration: |         //Define configuration: | ||||||
|         CandidateGenerator candidateGenerator = new RandomSearchGenerator(cgs, Collections.EMPTY_MAP); |         CandidateGenerator candidateGenerator = new RandomSearchGenerator(cgs); | ||||||
|         DataProvider dataProvider = new MnistDataSetProvider(); |         DataProvider dataProvider = new MnistDataSetProvider(); | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -331,7 +349,7 @@ public class TestBasic extends BaseDL4JTest { | |||||||
|         UIServer.getInstance().attach(ss); |         UIServer.getInstance().attach(ss); | ||||||
| 
 | 
 | ||||||
|         runner.execute(); |         runner.execute(); | ||||||
|         Thread.sleep(100000); |         Thread.sleep(1000_000); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -396,7 +414,7 @@ public class TestBasic extends BaseDL4JTest { | |||||||
|         UIServer.getInstance().attach(ss); |         UIServer.getInstance().attach(ss); | ||||||
| 
 | 
 | ||||||
|         runner.execute(); |         runner.execute(); | ||||||
|         Thread.sleep(100000); |         Thread.sleep(1000_000); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -433,7 +451,7 @@ public class TestBasic extends BaseDL4JTest { | |||||||
|                 .build(); |                 .build(); | ||||||
| 
 | 
 | ||||||
|         //Define configuration: |         //Define configuration: | ||||||
|         CandidateGenerator candidateGenerator = new RandomSearchGenerator(cgs, Collections.EMPTY_MAP); |         CandidateGenerator candidateGenerator = new RandomSearchGenerator(cgs); | ||||||
|         DataProvider dataProvider = new MnistDataSetProvider(); |         DataProvider dataProvider = new MnistDataSetProvider(); | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @ -465,13 +483,17 @@ public class TestBasic extends BaseDL4JTest { | |||||||
|         UIServer.getInstance().attach(ss); |         UIServer.getInstance().attach(ss); | ||||||
| 
 | 
 | ||||||
|         runner.execute(); |         runner.execute(); | ||||||
|         Thread.sleep(100000); |         Thread.sleep(1000_000); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  |     /** | ||||||
|  |      * Visualize multiple optimization sessions run one after another on single-session mode UI | ||||||
|  |      * @throws InterruptedException if current thread has been interrupted | ||||||
|  |      */ | ||||||
|     @Test |     @Test | ||||||
|     @Ignore |     @Ignore | ||||||
|     public void testBasicMnistMultipleSessions() throws Exception { |     public void testBasicMnistMultipleSessions() throws InterruptedException { | ||||||
| 
 | 
 | ||||||
|         MultiLayerSpace mls = new MultiLayerSpace.Builder() |         MultiLayerSpace mls = new MultiLayerSpace.Builder() | ||||||
|                 .updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.2))) |                 .updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.2))) | ||||||
| @ -499,8 +521,10 @@ public class TestBasic extends BaseDL4JTest { | |||||||
| 
 | 
 | ||||||
|         //Define configuration: |         //Define configuration: | ||||||
|         CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls, commands); |         CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls, commands); | ||||||
|         DataProvider dataProvider = new MnistDataSetProvider(); |  | ||||||
| 
 | 
 | ||||||
|  |         Class<? extends DataSource> ds = MnistDataSource.class; | ||||||
|  |         Properties dsp = new Properties(); | ||||||
|  |         dsp.setProperty("minibatch", "8"); | ||||||
| 
 | 
 | ||||||
|         String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterUiTestBasicMnist\\").getAbsolutePath(); |         String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterUiTestBasicMnist\\").getAbsolutePath(); | ||||||
| 
 | 
 | ||||||
| @ -513,7 +537,7 @@ public class TestBasic extends BaseDL4JTest { | |||||||
| 
 | 
 | ||||||
|         OptimizationConfiguration configuration = |         OptimizationConfiguration configuration = | ||||||
|                 new OptimizationConfiguration.Builder() |                 new OptimizationConfiguration.Builder() | ||||||
|                         .candidateGenerator(candidateGenerator).dataProvider(dataProvider) |                         .candidateGenerator(candidateGenerator).dataSource(ds, dsp) | ||||||
|                         .modelSaver(new FileModelSaver(modelSavePath)) |                         .modelSaver(new FileModelSaver(modelSavePath)) | ||||||
|                         .scoreFunction(new TestSetLossScoreFunction(true)) |                         .scoreFunction(new TestSetLossScoreFunction(true)) | ||||||
|                         .terminationConditions(new MaxTimeCondition(1, TimeUnit.MINUTES), |                         .terminationConditions(new MaxTimeCondition(1, TimeUnit.MINUTES), | ||||||
| @ -535,7 +559,7 @@ public class TestBasic extends BaseDL4JTest { | |||||||
| 
 | 
 | ||||||
|         candidateGenerator = new RandomSearchGenerator(mls, commands); |         candidateGenerator = new RandomSearchGenerator(mls, commands); | ||||||
|         configuration = new OptimizationConfiguration.Builder() |         configuration = new OptimizationConfiguration.Builder() | ||||||
|                 .candidateGenerator(candidateGenerator).dataProvider(dataProvider) |                 .candidateGenerator(candidateGenerator).dataSource(ds, dsp) | ||||||
|                 .modelSaver(new FileModelSaver(modelSavePath)) |                 .modelSaver(new FileModelSaver(modelSavePath)) | ||||||
|                 .scoreFunction(new TestSetLossScoreFunction(true)) |                 .scoreFunction(new TestSetLossScoreFunction(true)) | ||||||
|                 .terminationConditions(new MaxTimeCondition(1, TimeUnit.MINUTES), |                 .terminationConditions(new MaxTimeCondition(1, TimeUnit.MINUTES), | ||||||
| @ -550,7 +574,148 @@ public class TestBasic extends BaseDL4JTest { | |||||||
| 
 | 
 | ||||||
|         runner.execute(); |         runner.execute(); | ||||||
| 
 | 
 | ||||||
|         Thread.sleep(100000); |         Thread.sleep(1000_000); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     /** | ||||||
|  |      * Auto-attach multiple optimization sessions to multi-session mode UI | ||||||
|  |      * @throws IOException if could not connect to the server | ||||||
|  |      */ | ||||||
|  |     @Test | ||||||
|  |     public void testUiMultiSessionAutoAttach() throws IOException { | ||||||
|  | 
 | ||||||
|  |         //Define configuration: | ||||||
|  |         MultiLayerSpace mls = getMultiLayerSpaceMnist(); | ||||||
|  |         CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls); | ||||||
|  | 
 | ||||||
|  |         Class<? extends DataSource> ds = MnistDataSource.class; | ||||||
|  |         Properties dsp = new Properties(); | ||||||
|  |         dsp.setProperty("minibatch", "8"); | ||||||
|  | 
 | ||||||
|  |         String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterUiTestMultiSessionAutoAttach\\") | ||||||
|  |                 .getAbsolutePath(); | ||||||
|  | 
 | ||||||
|  |         File f = new File(modelSavePath); | ||||||
|  |         if (f.exists()) | ||||||
|  |             f.delete(); | ||||||
|  |         f.mkdir(); | ||||||
|  |         if (!f.exists()) | ||||||
|  |             throw new RuntimeException(); | ||||||
|  | 
 | ||||||
|  |         OptimizationConfiguration configuration = | ||||||
|  |                 new OptimizationConfiguration.Builder() | ||||||
|  |                         .candidateGenerator(candidateGenerator).dataSource(ds, dsp) | ||||||
|  |                         .modelSaver(new FileModelSaver(modelSavePath)) | ||||||
|  |                         .scoreFunction(new TestSetLossScoreFunction(true)) | ||||||
|  |                         .terminationConditions(new MaxTimeCondition(10, TimeUnit.SECONDS), | ||||||
|  |                                 new MaxCandidatesCondition(1)) | ||||||
|  |                         .build(); | ||||||
|  | 
 | ||||||
|  |         IOptimizationRunner runner = | ||||||
|  |                 new LocalOptimizationRunner(configuration, new MultiLayerNetworkTaskCreator()); | ||||||
|  | 
 | ||||||
|  |         // add 3 different sessions to the same execution | ||||||
|  |         HashMap<String, StatsStorage> statsStorageForSession = new HashMap<>(); | ||||||
|  |         for (int i = 0; i < 3; i++) { | ||||||
|  |             StatsStorage ss = new InMemoryStatsStorage(); | ||||||
|  |             @NonNull String sessionId = "sid" + i; | ||||||
|  |             statsStorageForSession.put(sessionId, ss); | ||||||
|  |             StatusListener sl = new ArbiterStatusListener(sessionId, ss); | ||||||
|  |             runner.addListeners(sl); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         Function<String, StatsStorage> statsStorageProvider = statsStorageForSession::get; | ||||||
|  |         UIServer uIServer = UIServer.getInstance(true, statsStorageProvider); | ||||||
|  |         String serverAddress = uIServer.getAddress(); | ||||||
|  | 
 | ||||||
|  |         runner.execute(); | ||||||
|  | 
 | ||||||
|  |         for (String sessionId : statsStorageForSession.keySet()) { | ||||||
|  |             /* | ||||||
|  |              * Visiting /arbiter/:sessionId to auto-attach StatsStorage | ||||||
|  |              */ | ||||||
|  |             String sessionUrl = sessionUrl(uIServer.getAddress(), sessionId); | ||||||
|  |             HttpURLConnection conn = (HttpURLConnection) new URL(sessionUrl).openConnection(); | ||||||
|  |             conn.connect(); | ||||||
|  | 
 | ||||||
|  |             log.info("Checking auto-attaching Arbiter session at {}", sessionUrl(serverAddress, sessionId)); | ||||||
|  |             assertEquals(HttpResponseStatus.OK.code(), conn.getResponseCode()); | ||||||
|  |             assertTrue(uIServer.isAttached(statsStorageForSession.get(sessionId))); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     /** | ||||||
|  |      * Attach multiple optimization sessions to multi-session mode UI by manually visiting session URL | ||||||
|  |      * @throws Exception if an error occurred | ||||||
|  |      */ | ||||||
|  |     @Test | ||||||
|  |     @Ignore | ||||||
|  |     public void testUiMultiSessionManualAttach() throws Exception { | ||||||
|  |         Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); | ||||||
|  | 
 | ||||||
|  |         //Define configuration: | ||||||
|  |         MultiLayerSpace mls = getMultiLayerSpaceMnist(); | ||||||
|  |         CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls); | ||||||
|  | 
 | ||||||
|  |         Class<? extends DataSource> ds = MnistDataSource.class; | ||||||
|  |         Properties dsp = new Properties(); | ||||||
|  |         dsp.setProperty("minibatch", "8"); | ||||||
|  | 
 | ||||||
|  |         String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterUiTestBasicMnist\\") | ||||||
|  |                 .getAbsolutePath(); | ||||||
|  | 
 | ||||||
|  |         File f = new File(modelSavePath); | ||||||
|  |         if (f.exists()) | ||||||
|  |             f.delete(); | ||||||
|  |         f.mkdir(); | ||||||
|  |         if (!f.exists()) | ||||||
|  |             throw new RuntimeException(); | ||||||
|  | 
 | ||||||
|  |         OptimizationConfiguration configuration = | ||||||
|  |                 new OptimizationConfiguration.Builder() | ||||||
|  |                         .candidateGenerator(candidateGenerator).dataSource(ds, dsp) | ||||||
|  |                         .modelSaver(new FileModelSaver(modelSavePath)) | ||||||
|  |                         .scoreFunction(new TestSetLossScoreFunction(true)) | ||||||
|  |                         .terminationConditions(new MaxTimeCondition(10, TimeUnit.MINUTES), | ||||||
|  |                                 new MaxCandidatesCondition(10)) | ||||||
|  |                         .build(); | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |         // parallel execution of multiple optimization sessions | ||||||
|  |         HashMap<String, StatsStorage> statsStorageForSession = new HashMap<>(); | ||||||
|  |         for (int i = 0; i < 3; i++) { | ||||||
|  |             String sessionId = "sid" + i; | ||||||
|  |             IOptimizationRunner runner = | ||||||
|  |                     new LocalOptimizationRunner(configuration, new MultiLayerNetworkTaskCreator()); | ||||||
|  |             StatsStorage ss = new InMemoryStatsStorage(); | ||||||
|  |             statsStorageForSession.put(sessionId, ss); | ||||||
|  |             StatusListener sl = new ArbiterStatusListener(sessionId, ss); | ||||||
|  |             runner.addListeners(sl); | ||||||
|  |             // Asynchronous execution | ||||||
|  |             new Thread(runner::execute).start(); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         Function<String, StatsStorage> statsStorageProvider = statsStorageForSession::get; | ||||||
|  |         UIServer uIServer = UIServer.getInstance(true, statsStorageProvider); | ||||||
|  |         String serverAddress = uIServer.getAddress(); | ||||||
|  | 
 | ||||||
|  |         for (String sessionId : statsStorageForSession.keySet()) { | ||||||
|  |             log.info("Arbiter session can be attached at {}", sessionUrl(serverAddress, sessionId)); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         Thread.sleep(1000_000); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |     /** | ||||||
|  |      * Get URL for arbiter session on given server address | ||||||
|  |      * @param serverAddress server address, e.g.: http://localhost:9000 | ||||||
|  |      * @param sessionId session ID (will be URL-encoded) | ||||||
|  |      * @return URL | ||||||
|  |      * @throws UnsupportedEncodingException if the character encoding is not supported | ||||||
|  |      */ | ||||||
|  |     private static String sessionUrl(String serverAddress, String sessionId) throws UnsupportedEncodingException { | ||||||
|  |         return String.format("%s/arbiter/%s", serverAddress, URLEncoder.encode(sessionId, "UTF-8")); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     private static class MnistDataSetProvider implements DataProvider { |     private static class MnistDataSetProvider implements DataProvider { | ||||||
|  | |||||||
| @ -77,8 +77,6 @@ public class VertxUIServer extends AbstractVerticle implements UIServer { | |||||||
|     private static Integer instancePort; |     private static Integer instancePort; | ||||||
|     private static Thread autoStopThread; |     private static Thread autoStopThread; | ||||||
| 
 | 
 | ||||||
|     private TrainModule trainModule; |  | ||||||
| 
 |  | ||||||
|     /** |     /** | ||||||
|      * Get (and, initialize if necessary) the UI server. This synchronous function will wait until the server started. |      * Get (and, initialize if necessary) the UI server. This synchronous function will wait until the server started. | ||||||
|      * @param port TCP socket port for {@link HttpServer} to listen |      * @param port TCP socket port for {@link HttpServer} to listen | ||||||
| @ -194,8 +192,9 @@ public class VertxUIServer extends AbstractVerticle implements UIServer { | |||||||
|         VertxUIServer.autoStopThread = new Thread(() -> { |         VertxUIServer.autoStopThread = new Thread(() -> { | ||||||
|             try { |             try { | ||||||
|                 currentThread.join(); |                 currentThread.join(); | ||||||
|                 log.info("Deeplearning4j UI server is auto-stopping."); |  | ||||||
|                 if (VertxUIServer.instance != null && !VertxUIServer.instance.isStopped()) { |                 if (VertxUIServer.instance != null && !VertxUIServer.instance.isStopped()) { | ||||||
|  |                     log.info("Deeplearning4j UI server is auto-stopping after thread (name: {}) died.", | ||||||
|  |                             currentThread.getName()); | ||||||
|                     instance.stop(); |                     instance.stop(); | ||||||
|                 } |                 } | ||||||
|             } catch (InterruptedException e) { |             } catch (InterruptedException e) { | ||||||
| @ -207,7 +206,11 @@ public class VertxUIServer extends AbstractVerticle implements UIServer { | |||||||
| 
 | 
 | ||||||
|     private List<UIModule> uiModules = new CopyOnWriteArrayList<>(); |     private List<UIModule> uiModules = new CopyOnWriteArrayList<>(); | ||||||
|     private RemoteReceiverModule remoteReceiverModule; |     private RemoteReceiverModule remoteReceiverModule; | ||||||
|     private StatsStorageLoader statsStorageLoader; |     /** | ||||||
|  |      * Loader that attaches {@code StatsStorage} provided by {@code #statsStorageProvider} for the given session ID | ||||||
|  |      */ | ||||||
|  |     @Getter | ||||||
|  |     private Function<String, Boolean> statsStorageLoader; | ||||||
| 
 | 
 | ||||||
|     //typeIDModuleMap: Records which modules are registered for which type IDs |     //typeIDModuleMap: Records which modules are registered for which type IDs | ||||||
|     private Map<String, List<UIModule>> typeIDModuleMap = new ConcurrentHashMap<>(); |     private Map<String, List<UIModule>> typeIDModuleMap = new ConcurrentHashMap<>(); | ||||||
| @ -247,10 +250,23 @@ public class VertxUIServer extends AbstractVerticle implements UIServer { | |||||||
|      */ |      */ | ||||||
|     public void autoAttachStatsStorageBySessionId(Function<String, StatsStorage> statsStorageProvider) { |     public void autoAttachStatsStorageBySessionId(Function<String, StatsStorage> statsStorageProvider) { | ||||||
|         if (statsStorageProvider != null) { |         if (statsStorageProvider != null) { | ||||||
|             this.statsStorageLoader = new StatsStorageLoader(statsStorageProvider); |             this.statsStorageLoader = (sessionId) -> { | ||||||
|             if (trainModule != null) { |                 log.info("Loading StatsStorage via StatsStorageProvider for session ID (" + sessionId + ")."); | ||||||
|                 this.trainModule.setSessionLoader(this.statsStorageLoader); |                 StatsStorage statsStorage = statsStorageProvider.apply(sessionId); | ||||||
|  |                 if (statsStorage != null) { | ||||||
|  |                     if (statsStorage.sessionExists(sessionId)) { | ||||||
|  |                         attach(statsStorage); | ||||||
|  |                         return true; | ||||||
|                     } |                     } | ||||||
|  |                     log.info("Failed to load StatsStorage via StatsStorageProvider for session ID. " + | ||||||
|  |                             "Session ID (" + sessionId + ") does not exist in StatsStorage."); | ||||||
|  |                     return false; | ||||||
|  |                 } else { | ||||||
|  |                     log.info("Failed to load StatsStorage via StatsStorageProvider for session ID (" + sessionId + "). " + | ||||||
|  |                             "StatsStorageProvider returned null."); | ||||||
|  |                     return false; | ||||||
|  |                 } | ||||||
|  |             }; | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| @ -302,8 +318,7 @@ public class VertxUIServer extends AbstractVerticle implements UIServer { | |||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|         uiModules.add(new DefaultModule(isMultiSession())); //For: navigation page "/" |         uiModules.add(new DefaultModule(isMultiSession())); //For: navigation page "/" | ||||||
|         trainModule = new TrainModule(isMultiSession(), statsStorageLoader, this::getAddress); |         uiModules.add(new TrainModule()); | ||||||
|         uiModules.add(trainModule); |  | ||||||
|         uiModules.add(new ConvolutionalListenerModule()); |         uiModules.add(new ConvolutionalListenerModule()); | ||||||
|         uiModules.add(new TsneModule()); |         uiModules.add(new TsneModule()); | ||||||
|         uiModules.add(new SameDiffModule()); |         uiModules.add(new SameDiffModule()); | ||||||
| @ -596,37 +611,6 @@ public class VertxUIServer extends AbstractVerticle implements UIServer { | |||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /** |  | ||||||
|      * Loader that attaches {@code StatsStorage} provided by {@code #statsStorageProvider} for the given session ID |  | ||||||
|      */ |  | ||||||
|     private class StatsStorageLoader implements Function<String, Boolean> { |  | ||||||
| 
 |  | ||||||
|         Function<String, StatsStorage> statsStorageProvider; |  | ||||||
| 
 |  | ||||||
|         StatsStorageLoader(Function<String, StatsStorage> statsStorageProvider) { |  | ||||||
|             this.statsStorageProvider = statsStorageProvider; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         @Override |  | ||||||
|         public Boolean apply(String sessionId) { |  | ||||||
|             log.info("Loading StatsStorage via StatsStorageProvider for session ID (" + sessionId + ")."); |  | ||||||
|             StatsStorage statsStorage = statsStorageProvider.apply(sessionId); |  | ||||||
|             if (statsStorage != null) { |  | ||||||
|                 if (statsStorage.sessionExists(sessionId)) { |  | ||||||
|                     attach(statsStorage); |  | ||||||
|                     return true; |  | ||||||
|                 } |  | ||||||
|                 log.info("Failed to load StatsStorage via StatsStorageProvider for session ID. " + |  | ||||||
|                         "Session ID (" + sessionId + ") does not exist in StatsStorage."); |  | ||||||
|                 return false; |  | ||||||
|             } else { |  | ||||||
|                 log.info("Failed to load StatsStorage via StatsStorageProvider for session ID (" + sessionId + "). " + |  | ||||||
|                         "StatsStorageProvider returned null."); |  | ||||||
|                 return false; |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     //================================================================================================================== |     //================================================================================================================== | ||||||
|     // CLI Launcher |     // CLI Launcher | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -157,8 +157,9 @@ public interface UIServer { | |||||||
| 
 | 
 | ||||||
|     /** |     /** | ||||||
|      * Stop/shut down the UI server. This synchronous function should wait until the server is stopped. |      * Stop/shut down the UI server. This synchronous function should wait until the server is stopped. | ||||||
|  |      * @throws InterruptedException if the current thread is interrupted while waiting | ||||||
|      */ |      */ | ||||||
|     void stop() throws Exception; |     void stop() throws InterruptedException; | ||||||
| 
 | 
 | ||||||
|     /** |     /** | ||||||
|      * Stop/shut down the UI server. |      * Stop/shut down the UI server. | ||||||
|  | |||||||
| @ -26,8 +26,6 @@ import io.vertx.ext.web.RoutingContext; | |||||||
| import it.unimi.dsi.fastutil.longs.LongArrayList; | import it.unimi.dsi.fastutil.longs.LongArrayList; | ||||||
| import lombok.AllArgsConstructor; | import lombok.AllArgsConstructor; | ||||||
| import lombok.Data; | import lombok.Data; | ||||||
| import lombok.Getter; |  | ||||||
| import lombok.Setter; |  | ||||||
| import lombok.extern.slf4j.Slf4j; | import lombok.extern.slf4j.Slf4j; | ||||||
| import org.apache.commons.io.FileUtils; | import org.apache.commons.io.FileUtils; | ||||||
| import org.apache.commons.io.FilenameUtils; | import org.apache.commons.io.FilenameUtils; | ||||||
| @ -43,6 +41,7 @@ import org.deeplearning4j.nn.conf.graph.GraphVertex; | |||||||
| import org.deeplearning4j.nn.conf.graph.LayerVertex; | import org.deeplearning4j.nn.conf.graph.LayerVertex; | ||||||
| import org.deeplearning4j.nn.conf.layers.*; | import org.deeplearning4j.nn.conf.layers.*; | ||||||
| import org.deeplearning4j.nn.conf.serde.JsonMappers; | import org.deeplearning4j.nn.conf.serde.JsonMappers; | ||||||
|  | import org.deeplearning4j.ui.VertxUIServer; | ||||||
| import org.deeplearning4j.ui.api.HttpMethod; | import org.deeplearning4j.ui.api.HttpMethod; | ||||||
| import org.deeplearning4j.ui.api.I18N; | import org.deeplearning4j.ui.api.I18N; | ||||||
| import org.deeplearning4j.ui.api.Route; | import org.deeplearning4j.ui.api.Route; | ||||||
| @ -56,7 +55,6 @@ import org.deeplearning4j.ui.model.stats.api.StatsInitializationReport; | |||||||
| import org.deeplearning4j.ui.model.stats.api.StatsReport; | import org.deeplearning4j.ui.model.stats.api.StatsReport; | ||||||
| import org.deeplearning4j.ui.model.stats.api.StatsType; | import org.deeplearning4j.ui.model.stats.api.StatsType; | ||||||
| import org.nd4j.common.function.Function; | import org.nd4j.common.function.Function; | ||||||
| import org.nd4j.common.function.Supplier; |  | ||||||
| import org.nd4j.linalg.learning.config.IUpdater; | import org.nd4j.linalg.learning.config.IUpdater; | ||||||
| import org.nd4j.common.primitives.Pair; | import org.nd4j.common.primitives.Pair; | ||||||
| import org.nd4j.common.primitives.Triple; | import org.nd4j.common.primitives.Triple; | ||||||
| @ -86,8 +84,6 @@ public class TrainModule implements UIModule { | |||||||
|     private static final DecimalFormat df2 = new DecimalFormat("#.00"); |     private static final DecimalFormat df2 = new DecimalFormat("#.00"); | ||||||
|     private static DateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"); |     private static DateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"); | ||||||
| 
 | 
 | ||||||
|     private final Supplier<String> addressSupplier; |  | ||||||
| 
 |  | ||||||
|     private enum ModelType { |     private enum ModelType { | ||||||
|         MLN, CG, Layer |         MLN, CG, Layer | ||||||
|     } |     } | ||||||
| @ -99,29 +95,14 @@ public class TrainModule implements UIModule { | |||||||
|     private Map<String, AtomicInteger> workerIdxCount = new ConcurrentHashMap<>(); //Key: session ID |     private Map<String, AtomicInteger> workerIdxCount = new ConcurrentHashMap<>(); //Key: session ID | ||||||
|     private Map<String, Map<Integer, String>> workerIdxToName = new ConcurrentHashMap<>(); //Key: session ID |     private Map<String, Map<Integer, String>> workerIdxToName = new ConcurrentHashMap<>(); //Key: session ID | ||||||
|     private Map<String, Long> lastUpdateForSession = new ConcurrentHashMap<>(); |     private Map<String, Long> lastUpdateForSession = new ConcurrentHashMap<>(); | ||||||
|     private final boolean multiSession; |  | ||||||
|     @Getter @Setter |  | ||||||
|     private Function<String, Boolean> sessionLoader; |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|     private final Configuration configuration; |     private final Configuration configuration; | ||||||
| 
 | 
 | ||||||
|     public TrainModule() { |  | ||||||
|         this(false, null, null); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |     /** | ||||||
|      * TrainModule |      * TrainModule | ||||||
|      * |  | ||||||
|      * @param multiSession    multi-session mode |  | ||||||
|      * @param sessionLoader   StatsStorage loader to call if an unknown session ID is passed as URL path parameter |  | ||||||
|      *                        in multi-session mode |  | ||||||
|      * @param addressSupplier supplier for server address (server address in PlayUIServer gets initialized after modules) |  | ||||||
|      */ |      */ | ||||||
|     public TrainModule(boolean multiSession, Function<String, Boolean> sessionLoader, Supplier<String> addressSupplier) { |     public TrainModule() { | ||||||
|         this.multiSession = multiSession; |  | ||||||
|         this.sessionLoader = sessionLoader; |  | ||||||
|         this.addressSupplier = addressSupplier; |  | ||||||
|         String maxChartPointsProp = System.getProperty(DL4JSystemProperties.CHART_MAX_POINTS_PROPERTY); |         String maxChartPointsProp = System.getProperty(DL4JSystemProperties.CHART_MAX_POINTS_PROPERTY); | ||||||
|         int value = DEFAULT_MAX_CHART_POINTS; |         int value = DEFAULT_MAX_CHART_POINTS; | ||||||
|         if (maxChartPointsProp != null) { |         if (maxChartPointsProp != null) { | ||||||
| @ -159,8 +140,9 @@ public class TrainModule implements UIModule { | |||||||
|     @Override |     @Override | ||||||
|     public List<Route> getRoutes() { |     public List<Route> getRoutes() { | ||||||
|         List<Route> r = new ArrayList<>(); |         List<Route> r = new ArrayList<>(); | ||||||
|         r.add(new Route("/train/multisession", HttpMethod.GET, (path, rc) -> rc.response().end(multiSession ? "true" : "false"))); |         r.add(new Route("/train/multisession", HttpMethod.GET, | ||||||
|         if (multiSession) { |                 (path, rc) -> rc.response().end(VertxUIServer.getInstance().isMultiSession() ? "true" : "false"))); | ||||||
|  |         if (VertxUIServer.getInstance().isMultiSession()) { | ||||||
|             r.add(new Route("/train", HttpMethod.GET, (path, rc) -> this.listSessions(rc))); |             r.add(new Route("/train", HttpMethod.GET, (path, rc) -> this.listSessions(rc))); | ||||||
|             r.add(new Route("/train/:sessionId", HttpMethod.GET, (path, rc) -> { |             r.add(new Route("/train/:sessionId", HttpMethod.GET, (path, rc) -> { | ||||||
|                 rc.response() |                 rc.response() | ||||||
| @ -264,7 +246,9 @@ public class TrainModule implements UIModule { | |||||||
|         if (!knownSessionIDs.isEmpty()) { |         if (!knownSessionIDs.isEmpty()) { | ||||||
|             sb.append("        <ul>"); |             sb.append("        <ul>"); | ||||||
|             for (String sessionId : knownSessionIDs.keySet()) { |             for (String sessionId : knownSessionIDs.keySet()) { | ||||||
|                 sb.append("            <li><a href=\"train/").append(sessionId).append("\">").append(sessionId).append("</a></li>\n"); |                 sb.append("            <li><a href=\"/train/") | ||||||
|  |                         .append(sessionId).append("\">") | ||||||
|  |                         .append(sessionId).append("</a></li>\n"); | ||||||
|             } |             } | ||||||
|             sb.append("        </ul>"); |             sb.append("        </ul>"); | ||||||
|         } else { |         } else { | ||||||
| @ -284,9 +268,11 @@ public class TrainModule implements UIModule { | |||||||
|      * |      * | ||||||
|      * @param sessionId  session ID to look fo with provider |      * @param sessionId  session ID to look fo with provider | ||||||
|      * @param targetPath one of overview / model / system, or null |      * @param targetPath one of overview / model / system, or null | ||||||
|  |      * @param rc routing context | ||||||
|      */ |      */ | ||||||
|     private void sessionNotFound(String sessionId, String targetPath, RoutingContext rc) { |     private void sessionNotFound(String sessionId, String targetPath, RoutingContext rc) { | ||||||
|         if (sessionLoader != null && sessionLoader.apply(sessionId)) { |         Function<String, Boolean> loader = VertxUIServer.getInstance().getStatsStorageLoader(); | ||||||
|  |         if (loader != null && loader.apply(sessionId)) { | ||||||
|             if (targetPath != null) { |             if (targetPath != null) { | ||||||
|                 rc.reroute(targetPath); |                 rc.reroute(targetPath); | ||||||
|             } else { |             } else { | ||||||
| @ -306,9 +292,9 @@ public class TrainModule implements UIModule { | |||||||
|                         && StatsListener.TYPE_ID.equals(sse.getTypeID()) |                         && StatsListener.TYPE_ID.equals(sse.getTypeID()) | ||||||
|                         && !knownSessionIDs.containsKey(sse.getSessionID())) { |                         && !knownSessionIDs.containsKey(sse.getSessionID())) { | ||||||
|                     knownSessionIDs.put(sse.getSessionID(), sse.getStatsStorage()); |                     knownSessionIDs.put(sse.getSessionID(), sse.getStatsStorage()); | ||||||
|                     if (multiSession) { |                     if (VertxUIServer.getInstance().isMultiSession()) { | ||||||
|                         log.info("Adding training session {}/train/{} of StatsStorage instance {}", |                         log.info("Adding training session {}/train/{} of StatsStorage instance {}", | ||||||
|                                 addressSupplier.get(), sse.getSessionID(), sse.getStatsStorage()); |                                 VertxUIServer.getInstance().getAddress(), sse.getSessionID(), sse.getStatsStorage()); | ||||||
|                     } |                     } | ||||||
|                 } |                 } | ||||||
| 
 | 
 | ||||||
| @ -332,9 +318,9 @@ public class TrainModule implements UIModule { | |||||||
|                 if (!StatsListener.TYPE_ID.equals(typeID)) |                 if (!StatsListener.TYPE_ID.equals(typeID)) | ||||||
|                     continue; |                     continue; | ||||||
|                 knownSessionIDs.put(sessionID, statsStorage); |                 knownSessionIDs.put(sessionID, statsStorage); | ||||||
|                 if (multiSession) { |                 if (VertxUIServer.getInstance().isMultiSession()) { | ||||||
|                     log.info("Adding training session {}/train/{} of StatsStorage instance {}", |                     log.info("Adding training session {}/train/{} of StatsStorage instance {}", | ||||||
|                             addressSupplier.get(), sessionID, statsStorage); |                             VertxUIServer.getInstance().getAddress(), sessionID, statsStorage); | ||||||
|                 } |                 } | ||||||
| 
 | 
 | ||||||
|                 List<Persistable> latestUpdates = statsStorage.getLatestUpdateAllWorkers(sessionID, typeID); |                 List<Persistable> latestUpdates = statsStorage.getLatestUpdateAllWorkers(sessionID, typeID); | ||||||
| @ -364,9 +350,9 @@ public class TrainModule implements UIModule { | |||||||
|         } |         } | ||||||
|         for (String s : toRemove) { |         for (String s : toRemove) { | ||||||
|             knownSessionIDs.remove(s); |             knownSessionIDs.remove(s); | ||||||
|             if (multiSession) { |             if (VertxUIServer.getInstance().isMultiSession()) { | ||||||
|                 log.info("Removing training session {}/train/{} of StatsStorage instance {}.", |                 log.info("Removing training session {}/train/{} of StatsStorage instance {}.", | ||||||
|                         addressSupplier.get(), s, statsStorage); |                         VertxUIServer.getInstance().getAddress(), s, statsStorage); | ||||||
|             } |             } | ||||||
|             lastUpdateForSession.remove(s); |             lastUpdateForSession.remove(s); | ||||||
|         } |         } | ||||||
| @ -602,13 +588,13 @@ public class TrainModule implements UIModule { | |||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /** |     /** | ||||||
|      * Get global {@link I18N} instance if {@link #multiSession} is {@code true}, or instance for session |      * Get global {@link I18N} instance if {@link VertxUIServer#isMultiSession()} is {@code true}, or instance for session | ||||||
|      * |      * | ||||||
|      * @param sessionId session ID |      * @param sessionId session ID | ||||||
|      * @return {@link I18N} instance |      * @return {@link I18N} instance | ||||||
|      */ |      */ | ||||||
|     private I18N getI18N(String sessionId) { |     private I18N getI18N(String sessionId) { | ||||||
|         return multiSession ? I18NProvider.getInstance(sessionId) : I18NProvider.getInstance(); |         return VertxUIServer.getInstance().isMultiSession() ? I18NProvider.getInstance(sessionId) : I18NProvider.getInstance(); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -259,7 +259,7 @@ public class TestVertxUIManual extends BaseDL4JTest { | |||||||
|                         log.info("Auto-detaching StatsStorage (session ID: {}) after {} ms.", |                         log.info("Auto-detaching StatsStorage (session ID: {}) after {} ms.", | ||||||
|                                 sessionId, autoDetachTimeoutMillis); |                                 sessionId, autoDetachTimeoutMillis); | ||||||
|                         uIServer.detach(statsStorage); |                         uIServer.detach(statsStorage); | ||||||
|                         log.info(" To re-attach StatsStorage of training session, visit {}}/train/{}", |                         log.info(" To re-attach StatsStorage of training session, visit {}/train/{}", | ||||||
|                                 uIServer.getAddress(), sessionId); |                                 uIServer.getAddress(), sessionId); | ||||||
|                     } |                     } | ||||||
|                 }).start(); |                 }).start(); | ||||||
|  | |||||||
| @ -197,9 +197,9 @@ public class TestVertxUIMultiSession extends BaseDL4JTest { | |||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     /** |     /** | ||||||
|      * Get URL-encoded URL for training session on given server address |      * Get URL for training session on given server address | ||||||
|      * @param serverAddress server address |      * @param serverAddress server address | ||||||
|      * @param sessionId session ID |      * @param sessionId session ID (will be URL-encoded) | ||||||
|      * @return URL |      * @return URL | ||||||
|      * @throws UnsupportedEncodingException if the used encoding is not supported |      * @throws UnsupportedEncodingException if the used encoding is not supported | ||||||
|      */ |      */ | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user