Merge pull request #8872 from printomi/ui_multisession_arbiter

DL4J UI: Add multi-session support for Arbiter
master
Alex Black 2020-05-11 19:24:17 +10:00 committed by GitHub
commit a76f957b72
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 591 additions and 274 deletions

View File

@ -36,6 +36,7 @@ import org.deeplearning4j.arbiter.ui.data.ModelInfoPersistable;
import org.deeplearning4j.arbiter.ui.misc.UIUtils; import org.deeplearning4j.arbiter.ui.misc.UIUtils;
import org.deeplearning4j.arbiter.util.ObjectUtils; import org.deeplearning4j.arbiter.util.ObjectUtils;
import org.deeplearning4j.nn.conf.serde.JsonMappers; import org.deeplearning4j.nn.conf.serde.JsonMappers;
import org.deeplearning4j.ui.VertxUIServer;
import org.deeplearning4j.ui.api.Component; import org.deeplearning4j.ui.api.Component;
import org.deeplearning4j.ui.api.*; import org.deeplearning4j.ui.api.*;
import org.deeplearning4j.ui.components.chart.ChartLine; import org.deeplearning4j.ui.components.chart.ChartLine;
@ -50,6 +51,7 @@ import org.deeplearning4j.ui.components.text.style.StyleText;
import org.deeplearning4j.ui.i18n.I18NResource; import org.deeplearning4j.ui.i18n.I18NResource;
import org.joda.time.format.DateTimeFormat; import org.joda.time.format.DateTimeFormat;
import org.joda.time.format.DateTimeFormatter; import org.joda.time.format.DateTimeFormatter;
import org.nd4j.common.function.Function;
import org.nd4j.common.primitives.Pair; import org.nd4j.common.primitives.Pair;
import org.nd4j.shade.jackson.core.JsonProcessingException; import org.nd4j.shade.jackson.core.JsonProcessingException;
@ -77,7 +79,6 @@ public class ArbiterModule implements UIModule {
private Map<String, Long> lastUpdateForSession = Collections.synchronizedMap(new HashMap<>()); private Map<String, Long> lastUpdateForSession = Collections.synchronizedMap(new HashMap<>());
//Styles for UI: //Styles for UI:
private static final StyleTable STYLE_TABLE = new StyleTable.Builder() private static final StyleTable STYLE_TABLE = new StyleTable.Builder()
.width(100, LengthUnit.Percent) .width(100, LengthUnit.Percent)
@ -134,20 +135,135 @@ public class ArbiterModule implements UIModule {
@Override @Override
public List<Route> getRoutes() { public List<Route> getRoutes() {
Route r1 = new Route("/arbiter", HttpMethod.GET, (path, rc) -> rc.response() boolean multiSession = VertxUIServer.getMultiSession().get();
.putHeader("content-type", "text/html; charset=utf-8").sendFile("templates/ArbiterUI.html")); List<Route> r = new ArrayList<>();
Route r3 = new Route("/arbiter/lastUpdate", HttpMethod.GET, (path, rc) -> this.getLastUpdateTime(rc)); r.add(new Route("/arbiter/multisession", HttpMethod.GET,
Route r4 = new Route("/arbiter/lastUpdate/:ids", HttpMethod.GET, (path, rc) -> this.getModelLastUpdateTimes(path.get(0), rc)); (path, rc) -> rc.response().end(multiSession ? "true" : "false")));
Route r5 = new Route("/arbiter/candidateInfo/:id", HttpMethod.GET, (path, rc) -> this.getCandidateInfo(path.get(0), rc)); if (multiSession) {
Route r6 = new Route("/arbiter/config", HttpMethod.GET, (path, rc) -> this.getOptimizationConfig(rc)); r.add(new Route("/arbiter", HttpMethod.GET, (path, rc) -> this.listSessions(rc)));
Route r7 = new Route("/arbiter/results", HttpMethod.GET, (path, rc) -> this.getSummaryResults(rc)); r.add(new Route("/arbiter/:sessionId", HttpMethod.GET, (path, rc) -> {
Route r8 = new Route("/arbiter/summary", HttpMethod.GET, (path, rc) -> this.getSummaryStatus(rc)); if (knownSessionIDs.containsKey(path.get(0))) {
rc.response()
.putHeader("content-type", "text/html; charset=utf-8")
.sendFile("templates/ArbiterUI.html");
} else {
sessionNotFound(path.get(0), rc.request().path(), rc);
}
}));
Route r9a = new Route("/arbiter/sessions/all", HttpMethod.GET, (path, rc) -> this.listSessions(rc)); r.add(new Route("/arbiter/:sessionId/lastUpdate", HttpMethod.GET, (path, rc) -> {
Route r9b = new Route("/arbiter/sessions/current", HttpMethod.GET, (path, rc) -> this.currentSession(rc)); if (knownSessionIDs.containsKey(path.get(0))) {
Route r9c = new Route("/arbiter/sessions/set/:to", HttpMethod.GET, (path, rc) -> this.setSession(path.get(0), rc)); this.getLastUpdateTime(path.get(0), rc);
} else {
sessionNotFound(path.get(0), rc.request().path(), rc);
}
}));
r.add(new Route("/arbiter/:sessionId/candidateInfo/:id", HttpMethod.GET, (path, rc) -> {
if (knownSessionIDs.containsKey(path.get(0))) {
this.getCandidateInfo(path.get(0), path.get(1), rc);
} else {
sessionNotFound(path.get(0), rc.request().path(), rc);
}
}));
r.add(new Route("/arbiter/:sessionId/config", HttpMethod.GET, (path, rc) -> {
if (knownSessionIDs.containsKey(path.get(0))) {
this.getOptimizationConfig(path.get(0), rc);
} else {
sessionNotFound(path.get(0), rc.request().path(), rc);
}
}));
r.add(new Route("/arbiter/:sessionId/results", HttpMethod.GET, (path, rc) -> {
if (knownSessionIDs.containsKey(path.get(0))) {
this.getSummaryResults(path.get(0), rc);
} else {
sessionNotFound(path.get(0), rc.request().path(), rc);
}
}));
r.add(new Route("/arbiter/:sessionId/summary", HttpMethod.GET, (path, rc) -> {
if (knownSessionIDs.containsKey(path.get(0))) {
this.getSummaryStatus(path.get(0), rc);
} else {
sessionNotFound(path.get(0), rc.request().path(), rc);
}
}));
} else {
r.add(new Route("/arbiter", HttpMethod.GET, (path, rc) -> rc.response()
.putHeader("content-type", "text/html; charset=utf-8")
.sendFile("templates/ArbiterUI.html")));
r.add(new Route("/arbiter/lastUpdate", HttpMethod.GET, (path, rc) -> this.getLastUpdateTime(null, rc)));
r.add(new Route("/arbiter/candidateInfo/:id", HttpMethod.GET,
(path, rc) -> this.getCandidateInfo(null, path.get(0), rc)));
r.add(new Route("/arbiter/config", HttpMethod.GET, (path, rc) -> this.getOptimizationConfig(null, rc)));
r.add(new Route("/arbiter/results", HttpMethod.GET, (path, rc) -> this.getSummaryResults(null, rc)));
r.add(new Route("/arbiter/summary", HttpMethod.GET, (path, rc) -> this.getSummaryStatus(null, rc)));
return Arrays.asList(r1, r3, r4, r5, r6, r7, r8, r9a, r9b, r9c); r.add(new Route("/arbiter/sessions/current", HttpMethod.GET, (path, rc) -> this.currentSession(rc)));
r.add(new Route("/arbiter/sessions/set/:to", HttpMethod.GET,
(path, rc) -> this.setSession(path.get(0), rc)));
}
// common for single- and multi-session mode
r.add(new Route("/arbiter/sessions/all", HttpMethod.GET, (path, rc) -> this.sessionInfo(rc)));
return r;
}
/**
* Load StatsStorage via provider, or return "not found"
*
* @param sessionId session ID to look fo with provider
* @param targetPath one of overview / model / system, or null
* @param rc routing context
*/
private void sessionNotFound(String sessionId, String targetPath, RoutingContext rc) {
Function<String, Boolean> loader = VertxUIServer.getInstance().getStatsStorageLoader();
if (loader != null && loader.apply(sessionId)) {
if (targetPath != null) {
rc.reroute(targetPath);
} else {
rc.response().end();
}
} else {
rc.response().setStatusCode(HttpResponseStatus.NOT_FOUND.code())
.end("Unknown session ID: " + sessionId);
}
}
/**
* List optimization sessions. Returns a HTML list of arbiter sessions
*/
private synchronized void listSessions(RoutingContext rc) {
StringBuilder sb = new StringBuilder("<!DOCTYPE html>\n" +
"<html lang=\"en\">\n" +
"<head>\n" +
" <meta charset=\"utf-8\">\n" +
" <title>Optimization sessions - DL4J Arbiter UI</title>\n" +
" </head>\n" +
"\n" +
" <body>\n" +
" <h1>DL4J Arbiter UI</h1>\n" +
" <p>UI server is in multi-session mode." +
" To visualize an optimization session, please select one from the following list.</p>\n" +
" <h2>List of attached optimization sessions</h2>\n");
if (!knownSessionIDs.isEmpty()) {
sb.append(" <ul>");
for (String sessionId : knownSessionIDs.keySet()) {
sb.append(" <li><a href=\"/arbiter/")
.append(sessionId).append("\">")
.append(sessionId).append("</a></li>\n");
}
sb.append(" </ul>");
} else {
sb.append("No optimization session attached.");
}
sb.append(" </body>\n" +
"</html>\n");
rc.response()
.putHeader("content-type", "text/html; charset=utf-8")
.end(sb.toString());
} }
@Override @Override
@ -201,7 +317,7 @@ public class ArbiterModule implements UIModule {
.end(asJson(sid)); .end(asJson(sid));
} }
private void listSessions(RoutingContext rc) { private void sessionInfo(RoutingContext rc) {
rc.response() rc.response()
.putHeader("content-type", "application/json") .putHeader("content-type", "application/json")
.end(asJson(knownSessionIDs.keySet())); .end(asJson(knownSessionIDs.keySet()));
@ -257,10 +373,25 @@ public class ArbiterModule implements UIModule {
/** /**
* Return the last update time for the page * Return the last update time for the page
* @param sessionId session ID (optional, for multi-session mode)
* @param rc routing context
*/ */
private void getLastUpdateTime(RoutingContext rc){ private void getLastUpdateTime(String sessionId, RoutingContext rc){
//TODO - this forces updates on every request... which is fine, just inefficient if (sessionId == null) {
long t = System.currentTimeMillis(); sessionId = currentSessionID;
}
StatsStorage ss = knownSessionIDs.get(sessionId);
List<Persistable> latestUpdates = ss.getLatestUpdateAllWorkers(sessionId, ARBITER_UI_TYPE_ID);
long t = 0;
if (latestUpdates.isEmpty()) {
t = System.currentTimeMillis();
} else {
for (Persistable update : latestUpdates) {
if (update.getTimeStamp() > t) {
t = update.getTimeStamp();
}
}
}
UpdateStatus us = new UpdateStatus(t, t, t); UpdateStatus us = new UpdateStatus(t, t, t);
rc.response().putHeader("content-type", "application/json").end(asJson(us)); rc.response().putHeader("content-type", "application/json").end(asJson(us));
@ -274,56 +405,28 @@ public class ArbiterModule implements UIModule {
} }
} }
/**
* Get the last update time for the specified model IDs
* @param modelIDs Model IDs to get the update time for
*/
private void getModelLastUpdateTimes(String modelIDs, RoutingContext rc){
if(currentSessionID == null){
rc.response().end();
return;
}
StatsStorage ss = knownSessionIDs.get(currentSessionID);
if(ss == null){
log.debug("getModelLastUpdateTimes(): Session ID is unknown: {}", currentSessionID);
rc.response().end("-1");
return;
}
String[] split = modelIDs.split(",");
long[] lastUpdateTimes = new long[split.length];
for( int i=0; i<split.length; i++ ){
String s = split[i];
Persistable p = ss.getLatestUpdate(currentSessionID, ARBITER_UI_TYPE_ID, s);
if(p != null){
lastUpdateTimes[i] = p.getTimeStamp();
}
}
rc.response().putHeader("content-type", "application/json").end(asJson(lastUpdateTimes));
}
/** /**
* Get the info for a specific candidate - last section in the UI * Get the info for a specific candidate - last section in the UI
* * @param sessionId session ID (optional, for multi-session mode)
* @param candidateId ID for the candidate * @param candidateId ID for the candidate
* @return Content/info for the candidate * @param rc routing context
*/ */
private void getCandidateInfo(String candidateId, RoutingContext rc){ private void getCandidateInfo(String sessionId, String candidateId, RoutingContext rc){
if (sessionId == null) {
StatsStorage ss = knownSessionIDs.get(currentSessionID); sessionId = currentSessionID;
}
StatsStorage ss = knownSessionIDs.get(sessionId);
if(ss == null){ if(ss == null){
log.debug("getModelLastUpdateTimes(): Session ID is unknown: {}", currentSessionID); log.debug("getModelLastUpdateTimes(): Session ID is unknown: {}", sessionId);
rc.response().end(); rc.response().end();
return; return;
} }
GlobalConfigPersistable gcp = (GlobalConfigPersistable)ss.getStaticInfo(currentSessionID, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID);; GlobalConfigPersistable gcp = (GlobalConfigPersistable)ss
.getStaticInfo(sessionId, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID);
OptimizationConfiguration oc = gcp.getOptimizationConfiguration(); OptimizationConfiguration oc = gcp.getOptimizationConfiguration();
Persistable p = ss.getLatestUpdate(currentSessionID, ARBITER_UI_TYPE_ID, candidateId); Persistable p = ss.getLatestUpdate(sessionId, ARBITER_UI_TYPE_ID, candidateId);
if(p == null){ if(p == null){
String title = "No results found for model " + candidateId + "."; String title = "No results found for model " + candidateId + ".";
ComponentText ct = new ComponentText.Builder(title,STYLE_TEXT_SZ12).build(); ComponentText ct = new ComponentText.Builder(title,STYLE_TEXT_SZ12).build();
@ -493,17 +596,21 @@ public class ArbiterModule implements UIModule {
/** /**
* Get the optimization configuration - second section in the page * Get the optimization configuration - second section in the page
* @param sessionId session ID (optional, for multi-session mode)
* @param rc routing context
*/ */
private void getOptimizationConfig(RoutingContext rc){ private void getOptimizationConfig(String sessionId, RoutingContext rc){
if (sessionId == null) {
StatsStorage ss = knownSessionIDs.get(currentSessionID); sessionId = currentSessionID;
}
StatsStorage ss = knownSessionIDs.get(sessionId);
if(ss == null){ if(ss == null){
log.debug("getOptimizationConfig(): Session ID is unknown: {}", currentSessionID); log.debug("getOptimizationConfig(): Session ID is unknown: {}", sessionId);
rc.response().end(); rc.response().end();
return; return;
} }
Persistable p = ss.getStaticInfo(currentSessionID, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID); Persistable p = ss.getStaticInfo(sessionId, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID);
if(p == null){ if(p == null){
log.debug("No static info"); log.debug("No static info");
@ -593,15 +700,23 @@ public class ArbiterModule implements UIModule {
rc.response().putHeader("content-type", "application/json").end(asJson(cd)); rc.response().putHeader("content-type", "application/json").end(asJson(cd));
} }
private void getSummaryResults(RoutingContext rc){ /**
StatsStorage ss = knownSessionIDs.get(currentSessionID); * Get candidates summary results list - third section on the page: Results table
* @param sessionId session ID (optional, for multi-session mode)
* @param rc routing context
*/
private void getSummaryResults(String sessionId, RoutingContext rc){
if (sessionId == null) {
sessionId = currentSessionID;
}
StatsStorage ss = knownSessionIDs.get(sessionId);
if(ss == null){ if(ss == null){
log.debug("getSummaryResults(): Session ID is unknown: {}", currentSessionID); log.debug("getSummaryResults(): Session ID is unknown: {}", sessionId);
rc.response().end(); rc.response().end();
return; return;
} }
List<Persistable> allModelInfoTemp = new ArrayList<>(ss.getLatestUpdateAllWorkers(currentSessionID, ARBITER_UI_TYPE_ID)); List<Persistable> allModelInfoTemp = new ArrayList<>(ss.getLatestUpdateAllWorkers(sessionId, ARBITER_UI_TYPE_ID));
List<String[]> table = new ArrayList<>(); List<String[]> table = new ArrayList<>();
for(Persistable per : allModelInfoTemp){ for(Persistable per : allModelInfoTemp){
ModelInfoPersistable mip = (ModelInfoPersistable)per; ModelInfoPersistable mip = (ModelInfoPersistable)per;
@ -614,16 +729,21 @@ public class ArbiterModule implements UIModule {
/** /**
* Get summary status information: first section in the page * Get summary status information: first section in the page
* @param sessionId session ID (optional, for multi-session mode)
* @param rc routing context
*/ */
private void getSummaryStatus(RoutingContext rc){ private void getSummaryStatus(String sessionId, RoutingContext rc){
StatsStorage ss = knownSessionIDs.get(currentSessionID); if (sessionId == null) {
sessionId = currentSessionID;
}
StatsStorage ss = knownSessionIDs.get(sessionId);
if(ss == null){ if(ss == null){
log.debug("getOptimizationConfig(): Session ID is unknown: {}", currentSessionID); log.debug("getOptimizationConfig(): Session ID is unknown: {}", sessionId);
rc.response().end(); rc.response().end();
return; return;
} }
Persistable p = ss.getStaticInfo(currentSessionID, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID); Persistable p = ss.getStaticInfo(sessionId, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID);
if(p == null){ if(p == null){
log.info("No static info"); log.info("No static info");
@ -643,7 +763,7 @@ public class ArbiterModule implements UIModule {
//How to get this? query all model infos... //How to get this? query all model infos...
List<Persistable> allModelInfoTemp = new ArrayList<>(ss.getLatestUpdateAllWorkers(currentSessionID, ARBITER_UI_TYPE_ID)); List<Persistable> allModelInfoTemp = new ArrayList<>(ss.getLatestUpdateAllWorkers(sessionId, ARBITER_UI_TYPE_ID));
List<ModelInfoPersistable> allModelInfo = new ArrayList<>(); List<ModelInfoPersistable> allModelInfo = new ArrayList<>();
for(Persistable per : allModelInfoTemp){ for(Persistable per : allModelInfoTemp){
ModelInfoPersistable mip = (ModelInfoPersistable)per; ModelInfoPersistable mip = (ModelInfoPersistable)per;
@ -668,7 +788,6 @@ public class ArbiterModule implements UIModule {
//TODO: I18N //TODO: I18N
//TODO don't use currentTimeMillis due to stored data??
long bestTime; long bestTime;
Double bestScore = null; Double bestScore = null;
String bestModelString = null; String bestModelString = null;
@ -685,7 +804,12 @@ public class ArbiterModule implements UIModule {
String execTotalRuntimeStr = ""; String execTotalRuntimeStr = "";
if(execStartTime > 0){ if(execStartTime > 0){
execStartTimeStr = TIME_FORMATTER.print(execStartTime); execStartTimeStr = TIME_FORMATTER.print(execStartTime);
execTotalRuntimeStr = UIUtils.formatDuration(System.currentTimeMillis() - execStartTime); // allModelInfo is sorted by Persistable::getTimeStamp
long lastCompleteTime = execStartTime;
if (!allModelInfo.isEmpty()) {
lastCompleteTime = allModelInfo.get(allModelInfo.size() - 1).getTimeStamp();
}
execTotalRuntimeStr = UIUtils.formatDuration(lastCompleteTime - execStartTime);
} }

View File

@ -198,27 +198,38 @@
var selectedCandidateIdx = null; var selectedCandidateIdx = null;
//Set basic interval function to do updates //Multi-session mode
setInterval(doUpdate,5000); //Loop every 5 seconds var multiSession = null;
//Session selection
var currSession = "";
function getSessionIdFromUrl() {
// path is like /arbiter/:sessionId/overview
var sessionIdRegexp = /\/arbiter\/([^\/]+)/g;
var match = sessionIdRegexp.exec(window.location.pathname)
return match[1];
}
function doUpdate(){ function getCurrSession(callback) {
//Get the update status, and do something with it: if (multiSession) {
$.get("/arbiter/lastUpdate",function(data){ if (currSession == "") {
//Encoding: matches names in UpdateStatus class // get only once
var jsonObj = JSON.parse(JSON.stringify(data)); currSession = getSessionIdFromUrl();
var statusTime = jsonObj['statusUpdateTime']; }
var settingsTime = jsonObj['settingsUpdateTime']; //we don't show session selector in multi-session mode (one can list sessions at /arbiter)
var resultsTime = jsonObj['resultsUpdateTime']; callback();
//console.log("Last update times: " + statusTime + ", " + settingsTime + ", " + resultsTime); } else {
$.ajax({
//Update available sessions: url: "/arbiter/sessions/current",
var currSession; async: true,
$.get("/arbiter/sessions/current", function(data){ error: function (query, status, error) {
currSession = data; //JSON.stringify(data); console.log("Error getting data: " + error);
console.log("Current: " + currSession); },
}); success: function (data) {
currSession = data;
console.log("Current session: " + currSession);
//Update available sessions in session selector
$.get("/arbiter/sessions/all", function(data){ $.get("/arbiter/sessions/all", function(data){
var keys = data; // JSON.stringify(data); var keys = data; // JSON.stringify(data);
@ -239,15 +250,56 @@
$("#sessionSelect option[value='" + keys[currSelectedIdx] +"']").attr("selected", "selected"); $("#sessionSelect option[value='" + keys[currSelectedIdx] +"']").attr("selected", "selected");
$("#sessionSelectDiv").show(); $("#sessionSelectDiv").show();
} }
// console.log("Got sessions: " + keys + ", current: " + currSession); // console.log("Got sessions: " + keys + ", current: " + currSession);
callback();
}); });
}
});
}
}
function getSessionSettings(callback) {
// load only once
if (multiSession != null) {
getCurrSession(callback);
} else {
$.ajax({
url: "/arbiter/multisession",
async: true,
error: function (query, status, error) {
console.log("Error getting data: " + error);
},
success: function (data) {
multiSession = data == "true";
getCurrSession(callback);
}
});
}
}
//Initial update
doUpdate();
//Set basic interval function to do updates
setInterval(doUpdate,5000); //Loop every 5 seconds
function doUpdate(){
//Get the update status, and do something with it:
getSessionSettings(function(){
var sessionUpdateUrl = multiSession ? "/arbiter/" + currSession + "/lastUpdate" : "/arbiter/lastUpdate";
$.get(sessionUpdateUrl,function(data){
//Encoding: matches names in UpdateStatus class
var jsonObj = JSON.parse(JSON.stringify(data));
var statusTime = jsonObj['statusUpdateTime'];
var settingsTime = jsonObj['settingsUpdateTime'];
var resultsTime = jsonObj['resultsUpdateTime'];
//console.log("Last update times: " + statusTime + ", " + settingsTime + ", " + resultsTime);
//Check last update times for each part of document, and update as necessary //Check last update times for each part of document, and update as necessary
//First section: summary status //First section: summary status
if(lastStatusUpdateTime != statusTime){ if(lastStatusUpdateTime != statusTime){
//Get JSON: address set by SummaryStatusResource var summaryStatusUrl = multiSession ? "/arbiter/" + currSession + "/summary" : "/arbiter/summary";
$.get("/arbiter/summary",function(data){ $.get(summaryStatusUrl,function(data){
var summaryStatusDiv = $('#statusdiv'); var summaryStatusDiv = $('#statusdiv');
summaryStatusDiv.html(''); summaryStatusDiv.html('');
@ -262,7 +314,8 @@
//Second section: Optimization settings //Second section: Optimization settings
if(lastSettingsUpdateTime != settingsTime){ if(lastSettingsUpdateTime != settingsTime){
//Get JSON for components //Get JSON for components
$.get("/arbiter/config",function(data){ var settingsUrl = multiSession ? "/arbiter/" + currSession + "/config" : "/arbiter/config";
$.get(settingsUrl,function(data){
var str = JSON.stringify(data); var str = JSON.stringify(data);
var configDiv = $('#settingsdiv'); var configDiv = $('#settingsdiv');
@ -277,9 +330,9 @@
//Third section: Summary results table (summary info for each candidate) //Third section: Summary results table (summary info for each candidate)
if(lastResultsUpdateTime != resultsTime){ if(lastResultsUpdateTime != resultsTime){
//Get JSON for results table
//Get JSON; address set by SummaryResultsResource var resultsUrl = multiSession ? "/arbiter/" + currSession + "/results" : "/arbiter/results";
$.get("/arbiter/results",function(data){ $.get(resultsUrl,function(data){
//Expect an array of CandidateInfo type objects here //Expect an array of CandidateInfo type objects here
resultsTableContent = data; resultsTableContent = data;
drawResultTable(); drawResultTable();
@ -291,7 +344,10 @@
//Finally: Currently selected result //Finally: Currently selected result
if(selectedCandidateIdx != null){ if(selectedCandidateIdx != null){
//Get JSON for components //Get JSON for components
$.get("/arbiter/candidateInfo/"+selectedCandidateIdx,function(data){ var candidateInfoUrl = multiSession
? "/arbiter/" + currSession + "/candidateInfo/" + selectedCandidateIdx
: "/arbiter/candidateInfo/" + selectedCandidateIdx;
$.get(candidateInfoUrl,function(data){
var str = JSON.stringify(data); var str = JSON.stringify(data);
var resultsViewDiv = $('#resultsviewdiv'); var resultsViewDiv = $('#resultsviewdiv');
@ -302,6 +358,7 @@
}); });
} }
}) })
})
} }
function createTable(tableObj,tableId,appendTo){ function createTable(tableObj,tableId,appendTo){

View File

@ -16,6 +16,9 @@
package org.deeplearning4j.arbiter.optimize; package org.deeplearning4j.arbiter.optimize;
import io.netty.handler.codec.http.HttpResponseStatus;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.core.storage.StatsStorage; import org.deeplearning4j.core.storage.StatsStorage;
import org.deeplearning4j.arbiter.ComputationGraphSpace; import org.deeplearning4j.arbiter.ComputationGraphSpace;
@ -54,6 +57,7 @@ import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage; import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
import org.junit.Ignore; import org.junit.Ignore;
import org.junit.Test; import org.junit.Test;
import org.nd4j.common.function.Function;
import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
@ -62,52 +66,43 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.File; import java.io.File;
import java.util.Collections; import java.io.IOException;
import java.util.HashMap; import java.io.UnsupportedEncodingException;
import java.util.Map; import java.net.HttpURLConnection;
import java.util.Properties; import java.net.URL;
import java.net.URLEncoder;
import java.util.*;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
/** /**
* Created by Alex on 19/07/2017. * Created by Alex on 19/07/2017.
*/ */
@Slf4j
public class TestBasic extends BaseDL4JTest { public class TestBasic extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 3600_000L;
}
@Test @Test
@Ignore @Ignore
public void testBasicUiOnly() throws Exception { public void testBasicUiOnly() throws Exception {
UIServer.getInstance(); UIServer.getInstance();
Thread.sleep(1000000); Thread.sleep(1000_000);
} }
@Test @Test
@Ignore @Ignore
public void testBasicMnist() throws Exception { public void testBasicMnist() throws Exception {
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
MultiLayerSpace mls = new MultiLayerSpace.Builder() MultiLayerSpace mls = getMultiLayerSpaceMnist();
.updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.2)))
.l2(new ContinuousParameterSpace(0.0001, 0.05))
.addLayer(
new ConvolutionLayerSpace.Builder().nIn(1)
.nOut(new IntegerParameterSpace(5, 30))
.kernelSize(new DiscreteParameterSpace<>(new int[]{3, 3},
new int[]{4, 4}, new int[]{5, 5}))
.stride(new DiscreteParameterSpace<>(new int[]{1, 1},
new int[]{2, 2}))
.activation(new DiscreteParameterSpace<>(Activation.RELU,
Activation.SOFTPLUS, Activation.LEAKYRELU))
.build())
.addLayer(new DenseLayerSpace.Builder().nOut(new IntegerParameterSpace(32, 128))
.activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH))
.build(), new IntegerParameterSpace(0, 1), true) //0 to 1 layers
.addLayer(new OutputLayerSpace.Builder().nOut(10).activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.setInputType(InputType.convolutionalFlat(28, 28, 1))
.build();
Map<String, Object> commands = new HashMap<>(); Map<String, Object> commands = new HashMap<>();
// commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, TestDataFactoryProviderMnist.class.getCanonicalName()); // commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, TestDataFactoryProviderMnist.class.getCanonicalName());
@ -144,7 +139,30 @@ public class TestBasic extends BaseDL4JTest {
UIServer.getInstance().attach(ss); UIServer.getInstance().attach(ss);
runner.execute(); runner.execute();
Thread.sleep(100000); Thread.sleep(1000_000);
}
private static MultiLayerSpace getMultiLayerSpaceMnist() {
return new MultiLayerSpace.Builder()
.updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.2)))
.l2(new ContinuousParameterSpace(0.0001, 0.05))
.addLayer(
new ConvolutionLayerSpace.Builder().nIn(1)
.nOut(new IntegerParameterSpace(5, 30))
.kernelSize(new DiscreteParameterSpace<>(new int[]{3, 3},
new int[]{4, 4}, new int[]{5, 5}))
.stride(new DiscreteParameterSpace<>(new int[]{1, 1},
new int[]{2, 2}))
.activation(new DiscreteParameterSpace<>(Activation.RELU,
Activation.SOFTPLUS, Activation.LEAKYRELU))
.build())
.addLayer(new DenseLayerSpace.Builder().nOut(new IntegerParameterSpace(32, 128))
.activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH))
.build(), new IntegerParameterSpace(0, 1), true) //0 to 1 layers
.addLayer(new OutputLayerSpace.Builder().nOut(10).activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.setInputType(InputType.convolutionalFlat(28, 28, 1))
.build();
} }
@Test @Test
@ -233,7 +251,7 @@ public class TestBasic extends BaseDL4JTest {
.build(); .build();
//Define configuration: //Define configuration:
CandidateGenerator candidateGenerator = new RandomSearchGenerator(cgs, Collections.EMPTY_MAP); CandidateGenerator candidateGenerator = new RandomSearchGenerator(cgs);
DataProvider dataProvider = new MnistDataSetProvider(); DataProvider dataProvider = new MnistDataSetProvider();
@ -331,7 +349,7 @@ public class TestBasic extends BaseDL4JTest {
UIServer.getInstance().attach(ss); UIServer.getInstance().attach(ss);
runner.execute(); runner.execute();
Thread.sleep(100000); Thread.sleep(1000_000);
} }
@ -396,7 +414,7 @@ public class TestBasic extends BaseDL4JTest {
UIServer.getInstance().attach(ss); UIServer.getInstance().attach(ss);
runner.execute(); runner.execute();
Thread.sleep(100000); Thread.sleep(1000_000);
} }
@ -433,7 +451,7 @@ public class TestBasic extends BaseDL4JTest {
.build(); .build();
//Define configuration: //Define configuration:
CandidateGenerator candidateGenerator = new RandomSearchGenerator(cgs, Collections.EMPTY_MAP); CandidateGenerator candidateGenerator = new RandomSearchGenerator(cgs);
DataProvider dataProvider = new MnistDataSetProvider(); DataProvider dataProvider = new MnistDataSetProvider();
@ -465,13 +483,17 @@ public class TestBasic extends BaseDL4JTest {
UIServer.getInstance().attach(ss); UIServer.getInstance().attach(ss);
runner.execute(); runner.execute();
Thread.sleep(100000); Thread.sleep(1000_000);
} }
/**
* Visualize multiple optimization sessions run one after another on single-session mode UI
* @throws InterruptedException if current thread has been interrupted
*/
@Test @Test
@Ignore @Ignore
public void testBasicMnistMultipleSessions() throws Exception { public void testBasicMnistMultipleSessions() throws InterruptedException {
MultiLayerSpace mls = new MultiLayerSpace.Builder() MultiLayerSpace mls = new MultiLayerSpace.Builder()
.updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.2))) .updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.2)))
@ -499,8 +521,10 @@ public class TestBasic extends BaseDL4JTest {
//Define configuration: //Define configuration:
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls, commands); CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls, commands);
DataProvider dataProvider = new MnistDataSetProvider();
Class<? extends DataSource> ds = MnistDataSource.class;
Properties dsp = new Properties();
dsp.setProperty("minibatch", "8");
String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterUiTestBasicMnist\\").getAbsolutePath(); String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterUiTestBasicMnist\\").getAbsolutePath();
@ -513,7 +537,7 @@ public class TestBasic extends BaseDL4JTest {
OptimizationConfiguration configuration = OptimizationConfiguration configuration =
new OptimizationConfiguration.Builder() new OptimizationConfiguration.Builder()
.candidateGenerator(candidateGenerator).dataProvider(dataProvider) .candidateGenerator(candidateGenerator).dataSource(ds, dsp)
.modelSaver(new FileModelSaver(modelSavePath)) .modelSaver(new FileModelSaver(modelSavePath))
.scoreFunction(new TestSetLossScoreFunction(true)) .scoreFunction(new TestSetLossScoreFunction(true))
.terminationConditions(new MaxTimeCondition(1, TimeUnit.MINUTES), .terminationConditions(new MaxTimeCondition(1, TimeUnit.MINUTES),
@ -535,7 +559,7 @@ public class TestBasic extends BaseDL4JTest {
candidateGenerator = new RandomSearchGenerator(mls, commands); candidateGenerator = new RandomSearchGenerator(mls, commands);
configuration = new OptimizationConfiguration.Builder() configuration = new OptimizationConfiguration.Builder()
.candidateGenerator(candidateGenerator).dataProvider(dataProvider) .candidateGenerator(candidateGenerator).dataSource(ds, dsp)
.modelSaver(new FileModelSaver(modelSavePath)) .modelSaver(new FileModelSaver(modelSavePath))
.scoreFunction(new TestSetLossScoreFunction(true)) .scoreFunction(new TestSetLossScoreFunction(true))
.terminationConditions(new MaxTimeCondition(1, TimeUnit.MINUTES), .terminationConditions(new MaxTimeCondition(1, TimeUnit.MINUTES),
@ -550,7 +574,148 @@ public class TestBasic extends BaseDL4JTest {
runner.execute(); runner.execute();
Thread.sleep(100000); Thread.sleep(1000_000);
}
/**
* Auto-attach multiple optimization sessions to multi-session mode UI
* @throws IOException if could not connect to the server
*/
@Test
public void testUiMultiSessionAutoAttach() throws IOException {
//Define configuration:
MultiLayerSpace mls = getMultiLayerSpaceMnist();
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
Class<? extends DataSource> ds = MnistDataSource.class;
Properties dsp = new Properties();
dsp.setProperty("minibatch", "8");
String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterUiTestMultiSessionAutoAttach\\")
.getAbsolutePath();
File f = new File(modelSavePath);
if (f.exists())
f.delete();
f.mkdir();
if (!f.exists())
throw new RuntimeException();
OptimizationConfiguration configuration =
new OptimizationConfiguration.Builder()
.candidateGenerator(candidateGenerator).dataSource(ds, dsp)
.modelSaver(new FileModelSaver(modelSavePath))
.scoreFunction(new TestSetLossScoreFunction(true))
.terminationConditions(new MaxTimeCondition(10, TimeUnit.SECONDS),
new MaxCandidatesCondition(1))
.build();
IOptimizationRunner runner =
new LocalOptimizationRunner(configuration, new MultiLayerNetworkTaskCreator());
// add 3 different sessions to the same execution
HashMap<String, StatsStorage> statsStorageForSession = new HashMap<>();
for (int i = 0; i < 3; i++) {
StatsStorage ss = new InMemoryStatsStorage();
@NonNull String sessionId = "sid" + i;
statsStorageForSession.put(sessionId, ss);
StatusListener sl = new ArbiterStatusListener(sessionId, ss);
runner.addListeners(sl);
}
Function<String, StatsStorage> statsStorageProvider = statsStorageForSession::get;
UIServer uIServer = UIServer.getInstance(true, statsStorageProvider);
String serverAddress = uIServer.getAddress();
runner.execute();
for (String sessionId : statsStorageForSession.keySet()) {
/*
* Visiting /arbiter/:sessionId to auto-attach StatsStorage
*/
String sessionUrl = sessionUrl(uIServer.getAddress(), sessionId);
HttpURLConnection conn = (HttpURLConnection) new URL(sessionUrl).openConnection();
conn.connect();
log.info("Checking auto-attaching Arbiter session at {}", sessionUrl(serverAddress, sessionId));
assertEquals(HttpResponseStatus.OK.code(), conn.getResponseCode());
assertTrue(uIServer.isAttached(statsStorageForSession.get(sessionId)));
}
}
/**
* Attach multiple optimization sessions to multi-session mode UI by manually visiting session URL
* @throws Exception if an error occurred
*/
@Test
@Ignore
public void testUiMultiSessionManualAttach() throws Exception {
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
//Define configuration:
MultiLayerSpace mls = getMultiLayerSpaceMnist();
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
Class<? extends DataSource> ds = MnistDataSource.class;
Properties dsp = new Properties();
dsp.setProperty("minibatch", "8");
String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterUiTestBasicMnist\\")
.getAbsolutePath();
File f = new File(modelSavePath);
if (f.exists())
f.delete();
f.mkdir();
if (!f.exists())
throw new RuntimeException();
OptimizationConfiguration configuration =
new OptimizationConfiguration.Builder()
.candidateGenerator(candidateGenerator).dataSource(ds, dsp)
.modelSaver(new FileModelSaver(modelSavePath))
.scoreFunction(new TestSetLossScoreFunction(true))
.terminationConditions(new MaxTimeCondition(10, TimeUnit.MINUTES),
new MaxCandidatesCondition(10))
.build();
// parallel execution of multiple optimization sessions
HashMap<String, StatsStorage> statsStorageForSession = new HashMap<>();
for (int i = 0; i < 3; i++) {
String sessionId = "sid" + i;
IOptimizationRunner runner =
new LocalOptimizationRunner(configuration, new MultiLayerNetworkTaskCreator());
StatsStorage ss = new InMemoryStatsStorage();
statsStorageForSession.put(sessionId, ss);
StatusListener sl = new ArbiterStatusListener(sessionId, ss);
runner.addListeners(sl);
// Asynchronous execution
new Thread(runner::execute).start();
}
Function<String, StatsStorage> statsStorageProvider = statsStorageForSession::get;
UIServer uIServer = UIServer.getInstance(true, statsStorageProvider);
String serverAddress = uIServer.getAddress();
for (String sessionId : statsStorageForSession.keySet()) {
log.info("Arbiter session can be attached at {}", sessionUrl(serverAddress, sessionId));
}
Thread.sleep(1000_000);
}
/**
* Get URL for arbiter session on given server address
* @param serverAddress server address, e.g.: http://localhost:9000
* @param sessionId session ID (will be URL-encoded)
* @return URL
* @throws UnsupportedEncodingException if the character encoding is not supported
*/
private static String sessionUrl(String serverAddress, String sessionId) throws UnsupportedEncodingException {
return String.format("%s/arbiter/%s", serverAddress, URLEncoder.encode(sessionId, "UTF-8"));
} }
private static class MnistDataSetProvider implements DataProvider { private static class MnistDataSetProvider implements DataProvider {

View File

@ -77,8 +77,6 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
private static Integer instancePort; private static Integer instancePort;
private static Thread autoStopThread; private static Thread autoStopThread;
private TrainModule trainModule;
/** /**
* Get (and, initialize if necessary) the UI server. This synchronous function will wait until the server started. * Get (and, initialize if necessary) the UI server. This synchronous function will wait until the server started.
* @param port TCP socket port for {@link HttpServer} to listen * @param port TCP socket port for {@link HttpServer} to listen
@ -194,8 +192,9 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
VertxUIServer.autoStopThread = new Thread(() -> { VertxUIServer.autoStopThread = new Thread(() -> {
try { try {
currentThread.join(); currentThread.join();
log.info("Deeplearning4j UI server is auto-stopping.");
if (VertxUIServer.instance != null && !VertxUIServer.instance.isStopped()) { if (VertxUIServer.instance != null && !VertxUIServer.instance.isStopped()) {
log.info("Deeplearning4j UI server is auto-stopping after thread (name: {}) died.",
currentThread.getName());
instance.stop(); instance.stop();
} }
} catch (InterruptedException e) { } catch (InterruptedException e) {
@ -207,7 +206,11 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
private List<UIModule> uiModules = new CopyOnWriteArrayList<>(); private List<UIModule> uiModules = new CopyOnWriteArrayList<>();
private RemoteReceiverModule remoteReceiverModule; private RemoteReceiverModule remoteReceiverModule;
private StatsStorageLoader statsStorageLoader; /**
* Loader that attaches {@code StatsStorage} provided by {@code #statsStorageProvider} for the given session ID
*/
@Getter
private Function<String, Boolean> statsStorageLoader;
//typeIDModuleMap: Records which modules are registered for which type IDs //typeIDModuleMap: Records which modules are registered for which type IDs
private Map<String, List<UIModule>> typeIDModuleMap = new ConcurrentHashMap<>(); private Map<String, List<UIModule>> typeIDModuleMap = new ConcurrentHashMap<>();
@ -247,10 +250,23 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
*/ */
public void autoAttachStatsStorageBySessionId(Function<String, StatsStorage> statsStorageProvider) { public void autoAttachStatsStorageBySessionId(Function<String, StatsStorage> statsStorageProvider) {
if (statsStorageProvider != null) { if (statsStorageProvider != null) {
this.statsStorageLoader = new StatsStorageLoader(statsStorageProvider); this.statsStorageLoader = (sessionId) -> {
if (trainModule != null) { log.info("Loading StatsStorage via StatsStorageProvider for session ID (" + sessionId + ").");
this.trainModule.setSessionLoader(this.statsStorageLoader); StatsStorage statsStorage = statsStorageProvider.apply(sessionId);
if (statsStorage != null) {
if (statsStorage.sessionExists(sessionId)) {
attach(statsStorage);
return true;
} }
log.info("Failed to load StatsStorage via StatsStorageProvider for session ID. " +
"Session ID (" + sessionId + ") does not exist in StatsStorage.");
return false;
} else {
log.info("Failed to load StatsStorage via StatsStorageProvider for session ID (" + sessionId + "). " +
"StatsStorageProvider returned null.");
return false;
}
};
} }
} }
@ -302,8 +318,7 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
} }
uiModules.add(new DefaultModule(isMultiSession())); //For: navigation page "/" uiModules.add(new DefaultModule(isMultiSession())); //For: navigation page "/"
trainModule = new TrainModule(isMultiSession(), statsStorageLoader, this::getAddress); uiModules.add(new TrainModule());
uiModules.add(trainModule);
uiModules.add(new ConvolutionalListenerModule()); uiModules.add(new ConvolutionalListenerModule());
uiModules.add(new TsneModule()); uiModules.add(new TsneModule());
uiModules.add(new SameDiffModule()); uiModules.add(new SameDiffModule());
@ -596,37 +611,6 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
} }
} }
/**
* Loader that attaches {@code StatsStorage} provided by {@code #statsStorageProvider} for the given session ID
*/
private class StatsStorageLoader implements Function<String, Boolean> {
Function<String, StatsStorage> statsStorageProvider;
StatsStorageLoader(Function<String, StatsStorage> statsStorageProvider) {
this.statsStorageProvider = statsStorageProvider;
}
@Override
public Boolean apply(String sessionId) {
log.info("Loading StatsStorage via StatsStorageProvider for session ID (" + sessionId + ").");
StatsStorage statsStorage = statsStorageProvider.apply(sessionId);
if (statsStorage != null) {
if (statsStorage.sessionExists(sessionId)) {
attach(statsStorage);
return true;
}
log.info("Failed to load StatsStorage via StatsStorageProvider for session ID. " +
"Session ID (" + sessionId + ") does not exist in StatsStorage.");
return false;
} else {
log.info("Failed to load StatsStorage via StatsStorageProvider for session ID (" + sessionId + "). " +
"StatsStorageProvider returned null.");
return false;
}
}
}
//================================================================================================================== //==================================================================================================================
// CLI Launcher // CLI Launcher

View File

@ -157,8 +157,9 @@ public interface UIServer {
/** /**
* Stop/shut down the UI server. This synchronous function should wait until the server is stopped. * Stop/shut down the UI server. This synchronous function should wait until the server is stopped.
* @throws InterruptedException if the current thread is interrupted while waiting
*/ */
void stop() throws Exception; void stop() throws InterruptedException;
/** /**
* Stop/shut down the UI server. * Stop/shut down the UI server.

View File

@ -26,8 +26,6 @@ import io.vertx.ext.web.RoutingContext;
import it.unimi.dsi.fastutil.longs.LongArrayList; import it.unimi.dsi.fastutil.longs.LongArrayList;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Data; import lombok.Data;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils; import org.apache.commons.io.FilenameUtils;
@ -43,6 +41,7 @@ import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.graph.LayerVertex; import org.deeplearning4j.nn.conf.graph.LayerVertex;
import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.conf.serde.JsonMappers; import org.deeplearning4j.nn.conf.serde.JsonMappers;
import org.deeplearning4j.ui.VertxUIServer;
import org.deeplearning4j.ui.api.HttpMethod; import org.deeplearning4j.ui.api.HttpMethod;
import org.deeplearning4j.ui.api.I18N; import org.deeplearning4j.ui.api.I18N;
import org.deeplearning4j.ui.api.Route; import org.deeplearning4j.ui.api.Route;
@ -56,7 +55,6 @@ import org.deeplearning4j.ui.model.stats.api.StatsInitializationReport;
import org.deeplearning4j.ui.model.stats.api.StatsReport; import org.deeplearning4j.ui.model.stats.api.StatsReport;
import org.deeplearning4j.ui.model.stats.api.StatsType; import org.deeplearning4j.ui.model.stats.api.StatsType;
import org.nd4j.common.function.Function; import org.nd4j.common.function.Function;
import org.nd4j.common.function.Supplier;
import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.common.primitives.Pair; import org.nd4j.common.primitives.Pair;
import org.nd4j.common.primitives.Triple; import org.nd4j.common.primitives.Triple;
@ -86,8 +84,6 @@ public class TrainModule implements UIModule {
private static final DecimalFormat df2 = new DecimalFormat("#.00"); private static final DecimalFormat df2 = new DecimalFormat("#.00");
private static DateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"); private static DateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
private final Supplier<String> addressSupplier;
private enum ModelType { private enum ModelType {
MLN, CG, Layer MLN, CG, Layer
} }
@ -99,29 +95,14 @@ public class TrainModule implements UIModule {
private Map<String, AtomicInteger> workerIdxCount = new ConcurrentHashMap<>(); //Key: session ID private Map<String, AtomicInteger> workerIdxCount = new ConcurrentHashMap<>(); //Key: session ID
private Map<String, Map<Integer, String>> workerIdxToName = new ConcurrentHashMap<>(); //Key: session ID private Map<String, Map<Integer, String>> workerIdxToName = new ConcurrentHashMap<>(); //Key: session ID
private Map<String, Long> lastUpdateForSession = new ConcurrentHashMap<>(); private Map<String, Long> lastUpdateForSession = new ConcurrentHashMap<>();
private final boolean multiSession;
@Getter @Setter
private Function<String, Boolean> sessionLoader;
private final Configuration configuration; private final Configuration configuration;
public TrainModule() {
this(false, null, null);
}
/** /**
* TrainModule * TrainModule
*
* @param multiSession multi-session mode
* @param sessionLoader StatsStorage loader to call if an unknown session ID is passed as URL path parameter
* in multi-session mode
* @param addressSupplier supplier for server address (server address in PlayUIServer gets initialized after modules)
*/ */
public TrainModule(boolean multiSession, Function<String, Boolean> sessionLoader, Supplier<String> addressSupplier) { public TrainModule() {
this.multiSession = multiSession;
this.sessionLoader = sessionLoader;
this.addressSupplier = addressSupplier;
String maxChartPointsProp = System.getProperty(DL4JSystemProperties.CHART_MAX_POINTS_PROPERTY); String maxChartPointsProp = System.getProperty(DL4JSystemProperties.CHART_MAX_POINTS_PROPERTY);
int value = DEFAULT_MAX_CHART_POINTS; int value = DEFAULT_MAX_CHART_POINTS;
if (maxChartPointsProp != null) { if (maxChartPointsProp != null) {
@ -159,8 +140,9 @@ public class TrainModule implements UIModule {
@Override @Override
public List<Route> getRoutes() { public List<Route> getRoutes() {
List<Route> r = new ArrayList<>(); List<Route> r = new ArrayList<>();
r.add(new Route("/train/multisession", HttpMethod.GET, (path, rc) -> rc.response().end(multiSession ? "true" : "false"))); r.add(new Route("/train/multisession", HttpMethod.GET,
if (multiSession) { (path, rc) -> rc.response().end(VertxUIServer.getInstance().isMultiSession() ? "true" : "false")));
if (VertxUIServer.getInstance().isMultiSession()) {
r.add(new Route("/train", HttpMethod.GET, (path, rc) -> this.listSessions(rc))); r.add(new Route("/train", HttpMethod.GET, (path, rc) -> this.listSessions(rc)));
r.add(new Route("/train/:sessionId", HttpMethod.GET, (path, rc) -> { r.add(new Route("/train/:sessionId", HttpMethod.GET, (path, rc) -> {
rc.response() rc.response()
@ -264,7 +246,9 @@ public class TrainModule implements UIModule {
if (!knownSessionIDs.isEmpty()) { if (!knownSessionIDs.isEmpty()) {
sb.append(" <ul>"); sb.append(" <ul>");
for (String sessionId : knownSessionIDs.keySet()) { for (String sessionId : knownSessionIDs.keySet()) {
sb.append(" <li><a href=\"train/").append(sessionId).append("\">").append(sessionId).append("</a></li>\n"); sb.append(" <li><a href=\"/train/")
.append(sessionId).append("\">")
.append(sessionId).append("</a></li>\n");
} }
sb.append(" </ul>"); sb.append(" </ul>");
} else { } else {
@ -284,9 +268,11 @@ public class TrainModule implements UIModule {
* *
* @param sessionId session ID to look fo with provider * @param sessionId session ID to look fo with provider
* @param targetPath one of overview / model / system, or null * @param targetPath one of overview / model / system, or null
* @param rc routing context
*/ */
private void sessionNotFound(String sessionId, String targetPath, RoutingContext rc) { private void sessionNotFound(String sessionId, String targetPath, RoutingContext rc) {
if (sessionLoader != null && sessionLoader.apply(sessionId)) { Function<String, Boolean> loader = VertxUIServer.getInstance().getStatsStorageLoader();
if (loader != null && loader.apply(sessionId)) {
if (targetPath != null) { if (targetPath != null) {
rc.reroute(targetPath); rc.reroute(targetPath);
} else { } else {
@ -306,9 +292,9 @@ public class TrainModule implements UIModule {
&& StatsListener.TYPE_ID.equals(sse.getTypeID()) && StatsListener.TYPE_ID.equals(sse.getTypeID())
&& !knownSessionIDs.containsKey(sse.getSessionID())) { && !knownSessionIDs.containsKey(sse.getSessionID())) {
knownSessionIDs.put(sse.getSessionID(), sse.getStatsStorage()); knownSessionIDs.put(sse.getSessionID(), sse.getStatsStorage());
if (multiSession) { if (VertxUIServer.getInstance().isMultiSession()) {
log.info("Adding training session {}/train/{} of StatsStorage instance {}", log.info("Adding training session {}/train/{} of StatsStorage instance {}",
addressSupplier.get(), sse.getSessionID(), sse.getStatsStorage()); VertxUIServer.getInstance().getAddress(), sse.getSessionID(), sse.getStatsStorage());
} }
} }
@ -332,9 +318,9 @@ public class TrainModule implements UIModule {
if (!StatsListener.TYPE_ID.equals(typeID)) if (!StatsListener.TYPE_ID.equals(typeID))
continue; continue;
knownSessionIDs.put(sessionID, statsStorage); knownSessionIDs.put(sessionID, statsStorage);
if (multiSession) { if (VertxUIServer.getInstance().isMultiSession()) {
log.info("Adding training session {}/train/{} of StatsStorage instance {}", log.info("Adding training session {}/train/{} of StatsStorage instance {}",
addressSupplier.get(), sessionID, statsStorage); VertxUIServer.getInstance().getAddress(), sessionID, statsStorage);
} }
List<Persistable> latestUpdates = statsStorage.getLatestUpdateAllWorkers(sessionID, typeID); List<Persistable> latestUpdates = statsStorage.getLatestUpdateAllWorkers(sessionID, typeID);
@ -364,9 +350,9 @@ public class TrainModule implements UIModule {
} }
for (String s : toRemove) { for (String s : toRemove) {
knownSessionIDs.remove(s); knownSessionIDs.remove(s);
if (multiSession) { if (VertxUIServer.getInstance().isMultiSession()) {
log.info("Removing training session {}/train/{} of StatsStorage instance {}.", log.info("Removing training session {}/train/{} of StatsStorage instance {}.",
addressSupplier.get(), s, statsStorage); VertxUIServer.getInstance().getAddress(), s, statsStorage);
} }
lastUpdateForSession.remove(s); lastUpdateForSession.remove(s);
} }
@ -602,13 +588,13 @@ public class TrainModule implements UIModule {
} }
/** /**
* Get global {@link I18N} instance if {@link #multiSession} is {@code true}, or instance for session * Get global {@link I18N} instance if {@link VertxUIServer#isMultiSession()} is {@code true}, or instance for session
* *
* @param sessionId session ID * @param sessionId session ID
* @return {@link I18N} instance * @return {@link I18N} instance
*/ */
private I18N getI18N(String sessionId) { private I18N getI18N(String sessionId) {
return multiSession ? I18NProvider.getInstance(sessionId) : I18NProvider.getInstance(); return VertxUIServer.getInstance().isMultiSession() ? I18NProvider.getInstance(sessionId) : I18NProvider.getInstance();
} }

View File

@ -259,7 +259,7 @@ public class TestVertxUIManual extends BaseDL4JTest {
log.info("Auto-detaching StatsStorage (session ID: {}) after {} ms.", log.info("Auto-detaching StatsStorage (session ID: {}) after {} ms.",
sessionId, autoDetachTimeoutMillis); sessionId, autoDetachTimeoutMillis);
uIServer.detach(statsStorage); uIServer.detach(statsStorage);
log.info(" To re-attach StatsStorage of training session, visit {}}/train/{}", log.info(" To re-attach StatsStorage of training session, visit {}/train/{}",
uIServer.getAddress(), sessionId); uIServer.getAddress(), sessionId);
} }
}).start(); }).start();

View File

@ -197,9 +197,9 @@ public class TestVertxUIMultiSession extends BaseDL4JTest {
} }
/** /**
* Get URL-encoded URL for training session on given server address * Get URL for training session on given server address
* @param serverAddress server address * @param serverAddress server address
* @param sessionId session ID * @param sessionId session ID (will be URL-encoded)
* @return URL * @return URL
* @throws UnsupportedEncodingException if the used encoding is not supported * @throws UnsupportedEncodingException if the used encoding is not supported
*/ */