Merge pull request #8872 from printomi/ui_multisession_arbiter
DL4J UI: Add multi-session support for Arbitermaster
commit
a76f957b72
|
@ -36,6 +36,7 @@ import org.deeplearning4j.arbiter.ui.data.ModelInfoPersistable;
|
|||
import org.deeplearning4j.arbiter.ui.misc.UIUtils;
|
||||
import org.deeplearning4j.arbiter.util.ObjectUtils;
|
||||
import org.deeplearning4j.nn.conf.serde.JsonMappers;
|
||||
import org.deeplearning4j.ui.VertxUIServer;
|
||||
import org.deeplearning4j.ui.api.Component;
|
||||
import org.deeplearning4j.ui.api.*;
|
||||
import org.deeplearning4j.ui.components.chart.ChartLine;
|
||||
|
@ -50,6 +51,7 @@ import org.deeplearning4j.ui.components.text.style.StyleText;
|
|||
import org.deeplearning4j.ui.i18n.I18NResource;
|
||||
import org.joda.time.format.DateTimeFormat;
|
||||
import org.joda.time.format.DateTimeFormatter;
|
||||
import org.nd4j.common.function.Function;
|
||||
import org.nd4j.common.primitives.Pair;
|
||||
import org.nd4j.shade.jackson.core.JsonProcessingException;
|
||||
|
||||
|
@ -77,7 +79,6 @@ public class ArbiterModule implements UIModule {
|
|||
|
||||
private Map<String, Long> lastUpdateForSession = Collections.synchronizedMap(new HashMap<>());
|
||||
|
||||
|
||||
//Styles for UI:
|
||||
private static final StyleTable STYLE_TABLE = new StyleTable.Builder()
|
||||
.width(100, LengthUnit.Percent)
|
||||
|
@ -134,20 +135,135 @@ public class ArbiterModule implements UIModule {
|
|||
|
||||
@Override
|
||||
public List<Route> getRoutes() {
|
||||
Route r1 = new Route("/arbiter", HttpMethod.GET, (path, rc) -> rc.response()
|
||||
.putHeader("content-type", "text/html; charset=utf-8").sendFile("templates/ArbiterUI.html"));
|
||||
Route r3 = new Route("/arbiter/lastUpdate", HttpMethod.GET, (path, rc) -> this.getLastUpdateTime(rc));
|
||||
Route r4 = new Route("/arbiter/lastUpdate/:ids", HttpMethod.GET, (path, rc) -> this.getModelLastUpdateTimes(path.get(0), rc));
|
||||
Route r5 = new Route("/arbiter/candidateInfo/:id", HttpMethod.GET, (path, rc) -> this.getCandidateInfo(path.get(0), rc));
|
||||
Route r6 = new Route("/arbiter/config", HttpMethod.GET, (path, rc) -> this.getOptimizationConfig(rc));
|
||||
Route r7 = new Route("/arbiter/results", HttpMethod.GET, (path, rc) -> this.getSummaryResults(rc));
|
||||
Route r8 = new Route("/arbiter/summary", HttpMethod.GET, (path, rc) -> this.getSummaryStatus(rc));
|
||||
boolean multiSession = VertxUIServer.getMultiSession().get();
|
||||
List<Route> r = new ArrayList<>();
|
||||
r.add(new Route("/arbiter/multisession", HttpMethod.GET,
|
||||
(path, rc) -> rc.response().end(multiSession ? "true" : "false")));
|
||||
if (multiSession) {
|
||||
r.add(new Route("/arbiter", HttpMethod.GET, (path, rc) -> this.listSessions(rc)));
|
||||
r.add(new Route("/arbiter/:sessionId", HttpMethod.GET, (path, rc) -> {
|
||||
if (knownSessionIDs.containsKey(path.get(0))) {
|
||||
rc.response()
|
||||
.putHeader("content-type", "text/html; charset=utf-8")
|
||||
.sendFile("templates/ArbiterUI.html");
|
||||
} else {
|
||||
sessionNotFound(path.get(0), rc.request().path(), rc);
|
||||
}
|
||||
}));
|
||||
|
||||
Route r9a = new Route("/arbiter/sessions/all", HttpMethod.GET, (path, rc) -> this.listSessions(rc));
|
||||
Route r9b = new Route("/arbiter/sessions/current", HttpMethod.GET, (path, rc) -> this.currentSession(rc));
|
||||
Route r9c = new Route("/arbiter/sessions/set/:to", HttpMethod.GET, (path, rc) -> this.setSession(path.get(0), rc));
|
||||
r.add(new Route("/arbiter/:sessionId/lastUpdate", HttpMethod.GET, (path, rc) -> {
|
||||
if (knownSessionIDs.containsKey(path.get(0))) {
|
||||
this.getLastUpdateTime(path.get(0), rc);
|
||||
} else {
|
||||
sessionNotFound(path.get(0), rc.request().path(), rc);
|
||||
}
|
||||
}));
|
||||
r.add(new Route("/arbiter/:sessionId/candidateInfo/:id", HttpMethod.GET, (path, rc) -> {
|
||||
if (knownSessionIDs.containsKey(path.get(0))) {
|
||||
this.getCandidateInfo(path.get(0), path.get(1), rc);
|
||||
} else {
|
||||
sessionNotFound(path.get(0), rc.request().path(), rc);
|
||||
}
|
||||
}));
|
||||
r.add(new Route("/arbiter/:sessionId/config", HttpMethod.GET, (path, rc) -> {
|
||||
if (knownSessionIDs.containsKey(path.get(0))) {
|
||||
this.getOptimizationConfig(path.get(0), rc);
|
||||
} else {
|
||||
sessionNotFound(path.get(0), rc.request().path(), rc);
|
||||
}
|
||||
}));
|
||||
r.add(new Route("/arbiter/:sessionId/results", HttpMethod.GET, (path, rc) -> {
|
||||
if (knownSessionIDs.containsKey(path.get(0))) {
|
||||
this.getSummaryResults(path.get(0), rc);
|
||||
} else {
|
||||
sessionNotFound(path.get(0), rc.request().path(), rc);
|
||||
}
|
||||
}));
|
||||
r.add(new Route("/arbiter/:sessionId/summary", HttpMethod.GET, (path, rc) -> {
|
||||
if (knownSessionIDs.containsKey(path.get(0))) {
|
||||
this.getSummaryStatus(path.get(0), rc);
|
||||
} else {
|
||||
sessionNotFound(path.get(0), rc.request().path(), rc);
|
||||
}
|
||||
}));
|
||||
} else {
|
||||
r.add(new Route("/arbiter", HttpMethod.GET, (path, rc) -> rc.response()
|
||||
.putHeader("content-type", "text/html; charset=utf-8")
|
||||
.sendFile("templates/ArbiterUI.html")));
|
||||
r.add(new Route("/arbiter/lastUpdate", HttpMethod.GET, (path, rc) -> this.getLastUpdateTime(null, rc)));
|
||||
r.add(new Route("/arbiter/candidateInfo/:id", HttpMethod.GET,
|
||||
(path, rc) -> this.getCandidateInfo(null, path.get(0), rc)));
|
||||
r.add(new Route("/arbiter/config", HttpMethod.GET, (path, rc) -> this.getOptimizationConfig(null, rc)));
|
||||
r.add(new Route("/arbiter/results", HttpMethod.GET, (path, rc) -> this.getSummaryResults(null, rc)));
|
||||
r.add(new Route("/arbiter/summary", HttpMethod.GET, (path, rc) -> this.getSummaryStatus(null, rc)));
|
||||
|
||||
return Arrays.asList(r1, r3, r4, r5, r6, r7, r8, r9a, r9b, r9c);
|
||||
r.add(new Route("/arbiter/sessions/current", HttpMethod.GET, (path, rc) -> this.currentSession(rc)));
|
||||
r.add(new Route("/arbiter/sessions/set/:to", HttpMethod.GET,
|
||||
(path, rc) -> this.setSession(path.get(0), rc)));
|
||||
}
|
||||
// common for single- and multi-session mode
|
||||
r.add(new Route("/arbiter/sessions/all", HttpMethod.GET, (path, rc) -> this.sessionInfo(rc)));
|
||||
|
||||
return r;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Load StatsStorage via provider, or return "not found"
|
||||
*
|
||||
* @param sessionId session ID to look fo with provider
|
||||
* @param targetPath one of overview / model / system, or null
|
||||
* @param rc routing context
|
||||
*/
|
||||
private void sessionNotFound(String sessionId, String targetPath, RoutingContext rc) {
|
||||
Function<String, Boolean> loader = VertxUIServer.getInstance().getStatsStorageLoader();
|
||||
if (loader != null && loader.apply(sessionId)) {
|
||||
if (targetPath != null) {
|
||||
rc.reroute(targetPath);
|
||||
} else {
|
||||
rc.response().end();
|
||||
}
|
||||
} else {
|
||||
rc.response().setStatusCode(HttpResponseStatus.NOT_FOUND.code())
|
||||
.end("Unknown session ID: " + sessionId);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* List optimization sessions. Returns a HTML list of arbiter sessions
|
||||
*/
|
||||
private synchronized void listSessions(RoutingContext rc) {
|
||||
StringBuilder sb = new StringBuilder("<!DOCTYPE html>\n" +
|
||||
"<html lang=\"en\">\n" +
|
||||
"<head>\n" +
|
||||
" <meta charset=\"utf-8\">\n" +
|
||||
" <title>Optimization sessions - DL4J Arbiter UI</title>\n" +
|
||||
" </head>\n" +
|
||||
"\n" +
|
||||
" <body>\n" +
|
||||
" <h1>DL4J Arbiter UI</h1>\n" +
|
||||
" <p>UI server is in multi-session mode." +
|
||||
" To visualize an optimization session, please select one from the following list.</p>\n" +
|
||||
" <h2>List of attached optimization sessions</h2>\n");
|
||||
if (!knownSessionIDs.isEmpty()) {
|
||||
sb.append(" <ul>");
|
||||
for (String sessionId : knownSessionIDs.keySet()) {
|
||||
sb.append(" <li><a href=\"/arbiter/")
|
||||
.append(sessionId).append("\">")
|
||||
.append(sessionId).append("</a></li>\n");
|
||||
}
|
||||
sb.append(" </ul>");
|
||||
} else {
|
||||
sb.append("No optimization session attached.");
|
||||
}
|
||||
|
||||
sb.append(" </body>\n" +
|
||||
"</html>\n");
|
||||
|
||||
rc.response()
|
||||
.putHeader("content-type", "text/html; charset=utf-8")
|
||||
.end(sb.toString());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -201,7 +317,7 @@ public class ArbiterModule implements UIModule {
|
|||
.end(asJson(sid));
|
||||
}
|
||||
|
||||
private void listSessions(RoutingContext rc) {
|
||||
private void sessionInfo(RoutingContext rc) {
|
||||
rc.response()
|
||||
.putHeader("content-type", "application/json")
|
||||
.end(asJson(knownSessionIDs.keySet()));
|
||||
|
@ -257,10 +373,25 @@ public class ArbiterModule implements UIModule {
|
|||
|
||||
/**
|
||||
* Return the last update time for the page
|
||||
* @param sessionId session ID (optional, for multi-session mode)
|
||||
* @param rc routing context
|
||||
*/
|
||||
private void getLastUpdateTime(RoutingContext rc){
|
||||
//TODO - this forces updates on every request... which is fine, just inefficient
|
||||
long t = System.currentTimeMillis();
|
||||
private void getLastUpdateTime(String sessionId, RoutingContext rc){
|
||||
if (sessionId == null) {
|
||||
sessionId = currentSessionID;
|
||||
}
|
||||
StatsStorage ss = knownSessionIDs.get(sessionId);
|
||||
List<Persistable> latestUpdates = ss.getLatestUpdateAllWorkers(sessionId, ARBITER_UI_TYPE_ID);
|
||||
long t = 0;
|
||||
if (latestUpdates.isEmpty()) {
|
||||
t = System.currentTimeMillis();
|
||||
} else {
|
||||
for (Persistable update : latestUpdates) {
|
||||
if (update.getTimeStamp() > t) {
|
||||
t = update.getTimeStamp();
|
||||
}
|
||||
}
|
||||
}
|
||||
UpdateStatus us = new UpdateStatus(t, t, t);
|
||||
|
||||
rc.response().putHeader("content-type", "application/json").end(asJson(us));
|
||||
|
@ -274,56 +405,28 @@ public class ArbiterModule implements UIModule {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the last update time for the specified model IDs
|
||||
* @param modelIDs Model IDs to get the update time for
|
||||
*/
|
||||
private void getModelLastUpdateTimes(String modelIDs, RoutingContext rc){
|
||||
if(currentSessionID == null){
|
||||
rc.response().end();
|
||||
return;
|
||||
}
|
||||
|
||||
StatsStorage ss = knownSessionIDs.get(currentSessionID);
|
||||
if(ss == null){
|
||||
log.debug("getModelLastUpdateTimes(): Session ID is unknown: {}", currentSessionID);
|
||||
rc.response().end("-1");
|
||||
return;
|
||||
}
|
||||
|
||||
String[] split = modelIDs.split(",");
|
||||
|
||||
long[] lastUpdateTimes = new long[split.length];
|
||||
for( int i=0; i<split.length; i++ ){
|
||||
String s = split[i];
|
||||
Persistable p = ss.getLatestUpdate(currentSessionID, ARBITER_UI_TYPE_ID, s);
|
||||
if(p != null){
|
||||
lastUpdateTimes[i] = p.getTimeStamp();
|
||||
}
|
||||
}
|
||||
|
||||
rc.response().putHeader("content-type", "application/json").end(asJson(lastUpdateTimes));
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the info for a specific candidate - last section in the UI
|
||||
*
|
||||
* @param sessionId session ID (optional, for multi-session mode)
|
||||
* @param candidateId ID for the candidate
|
||||
* @return Content/info for the candidate
|
||||
* @param rc routing context
|
||||
*/
|
||||
private void getCandidateInfo(String candidateId, RoutingContext rc){
|
||||
|
||||
StatsStorage ss = knownSessionIDs.get(currentSessionID);
|
||||
private void getCandidateInfo(String sessionId, String candidateId, RoutingContext rc){
|
||||
if (sessionId == null) {
|
||||
sessionId = currentSessionID;
|
||||
}
|
||||
StatsStorage ss = knownSessionIDs.get(sessionId);
|
||||
if(ss == null){
|
||||
log.debug("getModelLastUpdateTimes(): Session ID is unknown: {}", currentSessionID);
|
||||
log.debug("getModelLastUpdateTimes(): Session ID is unknown: {}", sessionId);
|
||||
rc.response().end();
|
||||
return;
|
||||
}
|
||||
|
||||
GlobalConfigPersistable gcp = (GlobalConfigPersistable)ss.getStaticInfo(currentSessionID, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID);;
|
||||
GlobalConfigPersistable gcp = (GlobalConfigPersistable)ss
|
||||
.getStaticInfo(sessionId, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID);
|
||||
OptimizationConfiguration oc = gcp.getOptimizationConfiguration();
|
||||
|
||||
Persistable p = ss.getLatestUpdate(currentSessionID, ARBITER_UI_TYPE_ID, candidateId);
|
||||
Persistable p = ss.getLatestUpdate(sessionId, ARBITER_UI_TYPE_ID, candidateId);
|
||||
if(p == null){
|
||||
String title = "No results found for model " + candidateId + ".";
|
||||
ComponentText ct = new ComponentText.Builder(title,STYLE_TEXT_SZ12).build();
|
||||
|
@ -493,17 +596,21 @@ public class ArbiterModule implements UIModule {
|
|||
|
||||
/**
|
||||
* Get the optimization configuration - second section in the page
|
||||
* @param sessionId session ID (optional, for multi-session mode)
|
||||
* @param rc routing context
|
||||
*/
|
||||
private void getOptimizationConfig(RoutingContext rc){
|
||||
|
||||
StatsStorage ss = knownSessionIDs.get(currentSessionID);
|
||||
private void getOptimizationConfig(String sessionId, RoutingContext rc){
|
||||
if (sessionId == null) {
|
||||
sessionId = currentSessionID;
|
||||
}
|
||||
StatsStorage ss = knownSessionIDs.get(sessionId);
|
||||
if(ss == null){
|
||||
log.debug("getOptimizationConfig(): Session ID is unknown: {}", currentSessionID);
|
||||
log.debug("getOptimizationConfig(): Session ID is unknown: {}", sessionId);
|
||||
rc.response().end();
|
||||
return;
|
||||
}
|
||||
|
||||
Persistable p = ss.getStaticInfo(currentSessionID, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID);
|
||||
Persistable p = ss.getStaticInfo(sessionId, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID);
|
||||
|
||||
if(p == null){
|
||||
log.debug("No static info");
|
||||
|
@ -593,15 +700,23 @@ public class ArbiterModule implements UIModule {
|
|||
rc.response().putHeader("content-type", "application/json").end(asJson(cd));
|
||||
}
|
||||
|
||||
private void getSummaryResults(RoutingContext rc){
|
||||
StatsStorage ss = knownSessionIDs.get(currentSessionID);
|
||||
/**
|
||||
* Get candidates summary results list - third section on the page: Results table
|
||||
* @param sessionId session ID (optional, for multi-session mode)
|
||||
* @param rc routing context
|
||||
*/
|
||||
private void getSummaryResults(String sessionId, RoutingContext rc){
|
||||
if (sessionId == null) {
|
||||
sessionId = currentSessionID;
|
||||
}
|
||||
StatsStorage ss = knownSessionIDs.get(sessionId);
|
||||
if(ss == null){
|
||||
log.debug("getSummaryResults(): Session ID is unknown: {}", currentSessionID);
|
||||
log.debug("getSummaryResults(): Session ID is unknown: {}", sessionId);
|
||||
rc.response().end();
|
||||
return;
|
||||
}
|
||||
|
||||
List<Persistable> allModelInfoTemp = new ArrayList<>(ss.getLatestUpdateAllWorkers(currentSessionID, ARBITER_UI_TYPE_ID));
|
||||
List<Persistable> allModelInfoTemp = new ArrayList<>(ss.getLatestUpdateAllWorkers(sessionId, ARBITER_UI_TYPE_ID));
|
||||
List<String[]> table = new ArrayList<>();
|
||||
for(Persistable per : allModelInfoTemp){
|
||||
ModelInfoPersistable mip = (ModelInfoPersistable)per;
|
||||
|
@ -614,16 +729,21 @@ public class ArbiterModule implements UIModule {
|
|||
|
||||
/**
|
||||
* Get summary status information: first section in the page
|
||||
* @param sessionId session ID (optional, for multi-session mode)
|
||||
* @param rc routing context
|
||||
*/
|
||||
private void getSummaryStatus(RoutingContext rc){
|
||||
StatsStorage ss = knownSessionIDs.get(currentSessionID);
|
||||
private void getSummaryStatus(String sessionId, RoutingContext rc){
|
||||
if (sessionId == null) {
|
||||
sessionId = currentSessionID;
|
||||
}
|
||||
StatsStorage ss = knownSessionIDs.get(sessionId);
|
||||
if(ss == null){
|
||||
log.debug("getOptimizationConfig(): Session ID is unknown: {}", currentSessionID);
|
||||
log.debug("getOptimizationConfig(): Session ID is unknown: {}", sessionId);
|
||||
rc.response().end();
|
||||
return;
|
||||
}
|
||||
|
||||
Persistable p = ss.getStaticInfo(currentSessionID, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID);
|
||||
Persistable p = ss.getStaticInfo(sessionId, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID);
|
||||
|
||||
if(p == null){
|
||||
log.info("No static info");
|
||||
|
@ -643,7 +763,7 @@ public class ArbiterModule implements UIModule {
|
|||
|
||||
//How to get this? query all model infos...
|
||||
|
||||
List<Persistable> allModelInfoTemp = new ArrayList<>(ss.getLatestUpdateAllWorkers(currentSessionID, ARBITER_UI_TYPE_ID));
|
||||
List<Persistable> allModelInfoTemp = new ArrayList<>(ss.getLatestUpdateAllWorkers(sessionId, ARBITER_UI_TYPE_ID));
|
||||
List<ModelInfoPersistable> allModelInfo = new ArrayList<>();
|
||||
for(Persistable per : allModelInfoTemp){
|
||||
ModelInfoPersistable mip = (ModelInfoPersistable)per;
|
||||
|
@ -668,7 +788,6 @@ public class ArbiterModule implements UIModule {
|
|||
|
||||
//TODO: I18N
|
||||
|
||||
//TODO don't use currentTimeMillis due to stored data??
|
||||
long bestTime;
|
||||
Double bestScore = null;
|
||||
String bestModelString = null;
|
||||
|
@ -685,7 +804,12 @@ public class ArbiterModule implements UIModule {
|
|||
String execTotalRuntimeStr = "";
|
||||
if(execStartTime > 0){
|
||||
execStartTimeStr = TIME_FORMATTER.print(execStartTime);
|
||||
execTotalRuntimeStr = UIUtils.formatDuration(System.currentTimeMillis() - execStartTime);
|
||||
// allModelInfo is sorted by Persistable::getTimeStamp
|
||||
long lastCompleteTime = execStartTime;
|
||||
if (!allModelInfo.isEmpty()) {
|
||||
lastCompleteTime = allModelInfo.get(allModelInfo.size() - 1).getTimeStamp();
|
||||
}
|
||||
execTotalRuntimeStr = UIUtils.formatDuration(lastCompleteTime - execStartTime);
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -198,27 +198,38 @@
|
|||
|
||||
var selectedCandidateIdx = null;
|
||||
|
||||
//Set basic interval function to do updates
|
||||
setInterval(doUpdate,5000); //Loop every 5 seconds
|
||||
//Multi-session mode
|
||||
var multiSession = null;
|
||||
//Session selection
|
||||
var currSession = "";
|
||||
|
||||
function getSessionIdFromUrl() {
|
||||
// path is like /arbiter/:sessionId/overview
|
||||
var sessionIdRegexp = /\/arbiter\/([^\/]+)/g;
|
||||
var match = sessionIdRegexp.exec(window.location.pathname)
|
||||
return match[1];
|
||||
}
|
||||
|
||||
function doUpdate(){
|
||||
//Get the update status, and do something with it:
|
||||
$.get("/arbiter/lastUpdate",function(data){
|
||||
//Encoding: matches names in UpdateStatus class
|
||||
var jsonObj = JSON.parse(JSON.stringify(data));
|
||||
var statusTime = jsonObj['statusUpdateTime'];
|
||||
var settingsTime = jsonObj['settingsUpdateTime'];
|
||||
var resultsTime = jsonObj['resultsUpdateTime'];
|
||||
//console.log("Last update times: " + statusTime + ", " + settingsTime + ", " + resultsTime);
|
||||
|
||||
//Update available sessions:
|
||||
var currSession;
|
||||
$.get("/arbiter/sessions/current", function(data){
|
||||
currSession = data; //JSON.stringify(data);
|
||||
console.log("Current: " + currSession);
|
||||
});
|
||||
function getCurrSession(callback) {
|
||||
if (multiSession) {
|
||||
if (currSession == "") {
|
||||
// get only once
|
||||
currSession = getSessionIdFromUrl();
|
||||
}
|
||||
//we don't show session selector in multi-session mode (one can list sessions at /arbiter)
|
||||
callback();
|
||||
} else {
|
||||
$.ajax({
|
||||
url: "/arbiter/sessions/current",
|
||||
async: true,
|
||||
error: function (query, status, error) {
|
||||
console.log("Error getting data: " + error);
|
||||
},
|
||||
success: function (data) {
|
||||
currSession = data;
|
||||
console.log("Current session: " + currSession);
|
||||
|
||||
//Update available sessions in session selector
|
||||
$.get("/arbiter/sessions/all", function(data){
|
||||
var keys = data; // JSON.stringify(data);
|
||||
|
||||
|
@ -239,15 +250,56 @@
|
|||
$("#sessionSelect option[value='" + keys[currSelectedIdx] +"']").attr("selected", "selected");
|
||||
$("#sessionSelectDiv").show();
|
||||
}
|
||||
// console.log("Got sessions: " + keys + ", current: " + currSession);
|
||||
// console.log("Got sessions: " + keys + ", current: " + currSession);
|
||||
callback();
|
||||
});
|
||||
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
function getSessionSettings(callback) {
|
||||
// load only once
|
||||
if (multiSession != null) {
|
||||
getCurrSession(callback);
|
||||
} else {
|
||||
$.ajax({
|
||||
url: "/arbiter/multisession",
|
||||
async: true,
|
||||
error: function (query, status, error) {
|
||||
console.log("Error getting data: " + error);
|
||||
},
|
||||
success: function (data) {
|
||||
multiSession = data == "true";
|
||||
getCurrSession(callback);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
//Initial update
|
||||
doUpdate();
|
||||
//Set basic interval function to do updates
|
||||
setInterval(doUpdate,5000); //Loop every 5 seconds
|
||||
|
||||
function doUpdate(){
|
||||
//Get the update status, and do something with it:
|
||||
getSessionSettings(function(){
|
||||
var sessionUpdateUrl = multiSession ? "/arbiter/" + currSession + "/lastUpdate" : "/arbiter/lastUpdate";
|
||||
$.get(sessionUpdateUrl,function(data){
|
||||
//Encoding: matches names in UpdateStatus class
|
||||
var jsonObj = JSON.parse(JSON.stringify(data));
|
||||
var statusTime = jsonObj['statusUpdateTime'];
|
||||
var settingsTime = jsonObj['settingsUpdateTime'];
|
||||
var resultsTime = jsonObj['resultsUpdateTime'];
|
||||
//console.log("Last update times: " + statusTime + ", " + settingsTime + ", " + resultsTime);
|
||||
|
||||
//Check last update times for each part of document, and update as necessary
|
||||
//First section: summary status
|
||||
if(lastStatusUpdateTime != statusTime){
|
||||
//Get JSON: address set by SummaryStatusResource
|
||||
$.get("/arbiter/summary",function(data){
|
||||
var summaryStatusUrl = multiSession ? "/arbiter/" + currSession + "/summary" : "/arbiter/summary";
|
||||
$.get(summaryStatusUrl,function(data){
|
||||
var summaryStatusDiv = $('#statusdiv');
|
||||
summaryStatusDiv.html('');
|
||||
|
||||
|
@ -262,7 +314,8 @@
|
|||
//Second section: Optimization settings
|
||||
if(lastSettingsUpdateTime != settingsTime){
|
||||
//Get JSON for components
|
||||
$.get("/arbiter/config",function(data){
|
||||
var settingsUrl = multiSession ? "/arbiter/" + currSession + "/config" : "/arbiter/config";
|
||||
$.get(settingsUrl,function(data){
|
||||
var str = JSON.stringify(data);
|
||||
|
||||
var configDiv = $('#settingsdiv');
|
||||
|
@ -277,9 +330,9 @@
|
|||
|
||||
//Third section: Summary results table (summary info for each candidate)
|
||||
if(lastResultsUpdateTime != resultsTime){
|
||||
|
||||
//Get JSON; address set by SummaryResultsResource
|
||||
$.get("/arbiter/results",function(data){
|
||||
//Get JSON for results table
|
||||
var resultsUrl = multiSession ? "/arbiter/" + currSession + "/results" : "/arbiter/results";
|
||||
$.get(resultsUrl,function(data){
|
||||
//Expect an array of CandidateInfo type objects here
|
||||
resultsTableContent = data;
|
||||
drawResultTable();
|
||||
|
@ -291,7 +344,10 @@
|
|||
//Finally: Currently selected result
|
||||
if(selectedCandidateIdx != null){
|
||||
//Get JSON for components
|
||||
$.get("/arbiter/candidateInfo/"+selectedCandidateIdx,function(data){
|
||||
var candidateInfoUrl = multiSession
|
||||
? "/arbiter/" + currSession + "/candidateInfo/" + selectedCandidateIdx
|
||||
: "/arbiter/candidateInfo/" + selectedCandidateIdx;
|
||||
$.get(candidateInfoUrl,function(data){
|
||||
var str = JSON.stringify(data);
|
||||
|
||||
var resultsViewDiv = $('#resultsviewdiv');
|
||||
|
@ -302,6 +358,7 @@
|
|||
});
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
function createTable(tableObj,tableId,appendTo){
|
||||
|
|
|
@ -16,6 +16,9 @@
|
|||
|
||||
package org.deeplearning4j.arbiter.optimize;
|
||||
|
||||
import io.netty.handler.codec.http.HttpResponseStatus;
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.core.storage.StatsStorage;
|
||||
import org.deeplearning4j.arbiter.ComputationGraphSpace;
|
||||
|
@ -54,6 +57,7 @@ import org.deeplearning4j.ui.api.UIServer;
|
|||
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
|
||||
import org.junit.Ignore;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.common.function.Function;
|
||||
import org.nd4j.evaluation.classification.Evaluation;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
|
@ -62,52 +66,43 @@ import org.nd4j.linalg.factory.Nd4j;
|
|||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Properties;
|
||||
import java.io.IOException;
|
||||
import java.io.UnsupportedEncodingException;
|
||||
import java.net.HttpURLConnection;
|
||||
import java.net.URL;
|
||||
import java.net.URLEncoder;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
/**
|
||||
* Created by Alex on 19/07/2017.
|
||||
*/
|
||||
@Slf4j
|
||||
public class TestBasic extends BaseDL4JTest {
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 3600_000L;
|
||||
}
|
||||
|
||||
@Test
|
||||
@Ignore
|
||||
public void testBasicUiOnly() throws Exception {
|
||||
|
||||
UIServer.getInstance();
|
||||
|
||||
Thread.sleep(1000000);
|
||||
Thread.sleep(1000_000);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
@Ignore
|
||||
public void testBasicMnist() throws Exception {
|
||||
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
|
||||
|
||||
MultiLayerSpace mls = new MultiLayerSpace.Builder()
|
||||
.updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.2)))
|
||||
.l2(new ContinuousParameterSpace(0.0001, 0.05))
|
||||
.addLayer(
|
||||
new ConvolutionLayerSpace.Builder().nIn(1)
|
||||
.nOut(new IntegerParameterSpace(5, 30))
|
||||
.kernelSize(new DiscreteParameterSpace<>(new int[]{3, 3},
|
||||
new int[]{4, 4}, new int[]{5, 5}))
|
||||
.stride(new DiscreteParameterSpace<>(new int[]{1, 1},
|
||||
new int[]{2, 2}))
|
||||
.activation(new DiscreteParameterSpace<>(Activation.RELU,
|
||||
Activation.SOFTPLUS, Activation.LEAKYRELU))
|
||||
.build())
|
||||
.addLayer(new DenseLayerSpace.Builder().nOut(new IntegerParameterSpace(32, 128))
|
||||
.activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH))
|
||||
.build(), new IntegerParameterSpace(0, 1), true) //0 to 1 layers
|
||||
.addLayer(new OutputLayerSpace.Builder().nOut(10).activation(Activation.SOFTMAX)
|
||||
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
||||
.setInputType(InputType.convolutionalFlat(28, 28, 1))
|
||||
.build();
|
||||
MultiLayerSpace mls = getMultiLayerSpaceMnist();
|
||||
Map<String, Object> commands = new HashMap<>();
|
||||
// commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, TestDataFactoryProviderMnist.class.getCanonicalName());
|
||||
|
||||
|
@ -144,7 +139,30 @@ public class TestBasic extends BaseDL4JTest {
|
|||
UIServer.getInstance().attach(ss);
|
||||
|
||||
runner.execute();
|
||||
Thread.sleep(100000);
|
||||
Thread.sleep(1000_000);
|
||||
}
|
||||
|
||||
private static MultiLayerSpace getMultiLayerSpaceMnist() {
|
||||
return new MultiLayerSpace.Builder()
|
||||
.updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.2)))
|
||||
.l2(new ContinuousParameterSpace(0.0001, 0.05))
|
||||
.addLayer(
|
||||
new ConvolutionLayerSpace.Builder().nIn(1)
|
||||
.nOut(new IntegerParameterSpace(5, 30))
|
||||
.kernelSize(new DiscreteParameterSpace<>(new int[]{3, 3},
|
||||
new int[]{4, 4}, new int[]{5, 5}))
|
||||
.stride(new DiscreteParameterSpace<>(new int[]{1, 1},
|
||||
new int[]{2, 2}))
|
||||
.activation(new DiscreteParameterSpace<>(Activation.RELU,
|
||||
Activation.SOFTPLUS, Activation.LEAKYRELU))
|
||||
.build())
|
||||
.addLayer(new DenseLayerSpace.Builder().nOut(new IntegerParameterSpace(32, 128))
|
||||
.activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH))
|
||||
.build(), new IntegerParameterSpace(0, 1), true) //0 to 1 layers
|
||||
.addLayer(new OutputLayerSpace.Builder().nOut(10).activation(Activation.SOFTMAX)
|
||||
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
||||
.setInputType(InputType.convolutionalFlat(28, 28, 1))
|
||||
.build();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -233,7 +251,7 @@ public class TestBasic extends BaseDL4JTest {
|
|||
.build();
|
||||
|
||||
//Define configuration:
|
||||
CandidateGenerator candidateGenerator = new RandomSearchGenerator(cgs, Collections.EMPTY_MAP);
|
||||
CandidateGenerator candidateGenerator = new RandomSearchGenerator(cgs);
|
||||
DataProvider dataProvider = new MnistDataSetProvider();
|
||||
|
||||
|
||||
|
@ -331,7 +349,7 @@ public class TestBasic extends BaseDL4JTest {
|
|||
UIServer.getInstance().attach(ss);
|
||||
|
||||
runner.execute();
|
||||
Thread.sleep(100000);
|
||||
Thread.sleep(1000_000);
|
||||
}
|
||||
|
||||
|
||||
|
@ -396,7 +414,7 @@ public class TestBasic extends BaseDL4JTest {
|
|||
UIServer.getInstance().attach(ss);
|
||||
|
||||
runner.execute();
|
||||
Thread.sleep(100000);
|
||||
Thread.sleep(1000_000);
|
||||
}
|
||||
|
||||
|
||||
|
@ -433,7 +451,7 @@ public class TestBasic extends BaseDL4JTest {
|
|||
.build();
|
||||
|
||||
//Define configuration:
|
||||
CandidateGenerator candidateGenerator = new RandomSearchGenerator(cgs, Collections.EMPTY_MAP);
|
||||
CandidateGenerator candidateGenerator = new RandomSearchGenerator(cgs);
|
||||
DataProvider dataProvider = new MnistDataSetProvider();
|
||||
|
||||
|
||||
|
@ -465,13 +483,17 @@ public class TestBasic extends BaseDL4JTest {
|
|||
UIServer.getInstance().attach(ss);
|
||||
|
||||
runner.execute();
|
||||
Thread.sleep(100000);
|
||||
Thread.sleep(1000_000);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Visualize multiple optimization sessions run one after another on single-session mode UI
|
||||
* @throws InterruptedException if current thread has been interrupted
|
||||
*/
|
||||
@Test
|
||||
@Ignore
|
||||
public void testBasicMnistMultipleSessions() throws Exception {
|
||||
public void testBasicMnistMultipleSessions() throws InterruptedException {
|
||||
|
||||
MultiLayerSpace mls = new MultiLayerSpace.Builder()
|
||||
.updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.2)))
|
||||
|
@ -499,8 +521,10 @@ public class TestBasic extends BaseDL4JTest {
|
|||
|
||||
//Define configuration:
|
||||
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls, commands);
|
||||
DataProvider dataProvider = new MnistDataSetProvider();
|
||||
|
||||
Class<? extends DataSource> ds = MnistDataSource.class;
|
||||
Properties dsp = new Properties();
|
||||
dsp.setProperty("minibatch", "8");
|
||||
|
||||
String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterUiTestBasicMnist\\").getAbsolutePath();
|
||||
|
||||
|
@ -513,7 +537,7 @@ public class TestBasic extends BaseDL4JTest {
|
|||
|
||||
OptimizationConfiguration configuration =
|
||||
new OptimizationConfiguration.Builder()
|
||||
.candidateGenerator(candidateGenerator).dataProvider(dataProvider)
|
||||
.candidateGenerator(candidateGenerator).dataSource(ds, dsp)
|
||||
.modelSaver(new FileModelSaver(modelSavePath))
|
||||
.scoreFunction(new TestSetLossScoreFunction(true))
|
||||
.terminationConditions(new MaxTimeCondition(1, TimeUnit.MINUTES),
|
||||
|
@ -535,7 +559,7 @@ public class TestBasic extends BaseDL4JTest {
|
|||
|
||||
candidateGenerator = new RandomSearchGenerator(mls, commands);
|
||||
configuration = new OptimizationConfiguration.Builder()
|
||||
.candidateGenerator(candidateGenerator).dataProvider(dataProvider)
|
||||
.candidateGenerator(candidateGenerator).dataSource(ds, dsp)
|
||||
.modelSaver(new FileModelSaver(modelSavePath))
|
||||
.scoreFunction(new TestSetLossScoreFunction(true))
|
||||
.terminationConditions(new MaxTimeCondition(1, TimeUnit.MINUTES),
|
||||
|
@ -550,7 +574,148 @@ public class TestBasic extends BaseDL4JTest {
|
|||
|
||||
runner.execute();
|
||||
|
||||
Thread.sleep(100000);
|
||||
Thread.sleep(1000_000);
|
||||
}
|
||||
|
||||
/**
|
||||
* Auto-attach multiple optimization sessions to multi-session mode UI
|
||||
* @throws IOException if could not connect to the server
|
||||
*/
|
||||
@Test
|
||||
public void testUiMultiSessionAutoAttach() throws IOException {
|
||||
|
||||
//Define configuration:
|
||||
MultiLayerSpace mls = getMultiLayerSpaceMnist();
|
||||
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
|
||||
|
||||
Class<? extends DataSource> ds = MnistDataSource.class;
|
||||
Properties dsp = new Properties();
|
||||
dsp.setProperty("minibatch", "8");
|
||||
|
||||
String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterUiTestMultiSessionAutoAttach\\")
|
||||
.getAbsolutePath();
|
||||
|
||||
File f = new File(modelSavePath);
|
||||
if (f.exists())
|
||||
f.delete();
|
||||
f.mkdir();
|
||||
if (!f.exists())
|
||||
throw new RuntimeException();
|
||||
|
||||
OptimizationConfiguration configuration =
|
||||
new OptimizationConfiguration.Builder()
|
||||
.candidateGenerator(candidateGenerator).dataSource(ds, dsp)
|
||||
.modelSaver(new FileModelSaver(modelSavePath))
|
||||
.scoreFunction(new TestSetLossScoreFunction(true))
|
||||
.terminationConditions(new MaxTimeCondition(10, TimeUnit.SECONDS),
|
||||
new MaxCandidatesCondition(1))
|
||||
.build();
|
||||
|
||||
IOptimizationRunner runner =
|
||||
new LocalOptimizationRunner(configuration, new MultiLayerNetworkTaskCreator());
|
||||
|
||||
// add 3 different sessions to the same execution
|
||||
HashMap<String, StatsStorage> statsStorageForSession = new HashMap<>();
|
||||
for (int i = 0; i < 3; i++) {
|
||||
StatsStorage ss = new InMemoryStatsStorage();
|
||||
@NonNull String sessionId = "sid" + i;
|
||||
statsStorageForSession.put(sessionId, ss);
|
||||
StatusListener sl = new ArbiterStatusListener(sessionId, ss);
|
||||
runner.addListeners(sl);
|
||||
}
|
||||
|
||||
Function<String, StatsStorage> statsStorageProvider = statsStorageForSession::get;
|
||||
UIServer uIServer = UIServer.getInstance(true, statsStorageProvider);
|
||||
String serverAddress = uIServer.getAddress();
|
||||
|
||||
runner.execute();
|
||||
|
||||
for (String sessionId : statsStorageForSession.keySet()) {
|
||||
/*
|
||||
* Visiting /arbiter/:sessionId to auto-attach StatsStorage
|
||||
*/
|
||||
String sessionUrl = sessionUrl(uIServer.getAddress(), sessionId);
|
||||
HttpURLConnection conn = (HttpURLConnection) new URL(sessionUrl).openConnection();
|
||||
conn.connect();
|
||||
|
||||
log.info("Checking auto-attaching Arbiter session at {}", sessionUrl(serverAddress, sessionId));
|
||||
assertEquals(HttpResponseStatus.OK.code(), conn.getResponseCode());
|
||||
assertTrue(uIServer.isAttached(statsStorageForSession.get(sessionId)));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Attach multiple optimization sessions to multi-session mode UI by manually visiting session URL
|
||||
* @throws Exception if an error occurred
|
||||
*/
|
||||
@Test
|
||||
@Ignore
|
||||
public void testUiMultiSessionManualAttach() throws Exception {
|
||||
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
|
||||
|
||||
//Define configuration:
|
||||
MultiLayerSpace mls = getMultiLayerSpaceMnist();
|
||||
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
|
||||
|
||||
Class<? extends DataSource> ds = MnistDataSource.class;
|
||||
Properties dsp = new Properties();
|
||||
dsp.setProperty("minibatch", "8");
|
||||
|
||||
String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterUiTestBasicMnist\\")
|
||||
.getAbsolutePath();
|
||||
|
||||
File f = new File(modelSavePath);
|
||||
if (f.exists())
|
||||
f.delete();
|
||||
f.mkdir();
|
||||
if (!f.exists())
|
||||
throw new RuntimeException();
|
||||
|
||||
OptimizationConfiguration configuration =
|
||||
new OptimizationConfiguration.Builder()
|
||||
.candidateGenerator(candidateGenerator).dataSource(ds, dsp)
|
||||
.modelSaver(new FileModelSaver(modelSavePath))
|
||||
.scoreFunction(new TestSetLossScoreFunction(true))
|
||||
.terminationConditions(new MaxTimeCondition(10, TimeUnit.MINUTES),
|
||||
new MaxCandidatesCondition(10))
|
||||
.build();
|
||||
|
||||
|
||||
// parallel execution of multiple optimization sessions
|
||||
HashMap<String, StatsStorage> statsStorageForSession = new HashMap<>();
|
||||
for (int i = 0; i < 3; i++) {
|
||||
String sessionId = "sid" + i;
|
||||
IOptimizationRunner runner =
|
||||
new LocalOptimizationRunner(configuration, new MultiLayerNetworkTaskCreator());
|
||||
StatsStorage ss = new InMemoryStatsStorage();
|
||||
statsStorageForSession.put(sessionId, ss);
|
||||
StatusListener sl = new ArbiterStatusListener(sessionId, ss);
|
||||
runner.addListeners(sl);
|
||||
// Asynchronous execution
|
||||
new Thread(runner::execute).start();
|
||||
}
|
||||
|
||||
Function<String, StatsStorage> statsStorageProvider = statsStorageForSession::get;
|
||||
UIServer uIServer = UIServer.getInstance(true, statsStorageProvider);
|
||||
String serverAddress = uIServer.getAddress();
|
||||
|
||||
for (String sessionId : statsStorageForSession.keySet()) {
|
||||
log.info("Arbiter session can be attached at {}", sessionUrl(serverAddress, sessionId));
|
||||
}
|
||||
|
||||
Thread.sleep(1000_000);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Get URL for arbiter session on given server address
|
||||
* @param serverAddress server address, e.g.: http://localhost:9000
|
||||
* @param sessionId session ID (will be URL-encoded)
|
||||
* @return URL
|
||||
* @throws UnsupportedEncodingException if the character encoding is not supported
|
||||
*/
|
||||
private static String sessionUrl(String serverAddress, String sessionId) throws UnsupportedEncodingException {
|
||||
return String.format("%s/arbiter/%s", serverAddress, URLEncoder.encode(sessionId, "UTF-8"));
|
||||
}
|
||||
|
||||
private static class MnistDataSetProvider implements DataProvider {
|
||||
|
|
|
@ -77,8 +77,6 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
|||
private static Integer instancePort;
|
||||
private static Thread autoStopThread;
|
||||
|
||||
private TrainModule trainModule;
|
||||
|
||||
/**
|
||||
* Get (and, initialize if necessary) the UI server. This synchronous function will wait until the server started.
|
||||
* @param port TCP socket port for {@link HttpServer} to listen
|
||||
|
@ -194,8 +192,9 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
|||
VertxUIServer.autoStopThread = new Thread(() -> {
|
||||
try {
|
||||
currentThread.join();
|
||||
log.info("Deeplearning4j UI server is auto-stopping.");
|
||||
if (VertxUIServer.instance != null && !VertxUIServer.instance.isStopped()) {
|
||||
log.info("Deeplearning4j UI server is auto-stopping after thread (name: {}) died.",
|
||||
currentThread.getName());
|
||||
instance.stop();
|
||||
}
|
||||
} catch (InterruptedException e) {
|
||||
|
@ -207,7 +206,11 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
|||
|
||||
private List<UIModule> uiModules = new CopyOnWriteArrayList<>();
|
||||
private RemoteReceiverModule remoteReceiverModule;
|
||||
private StatsStorageLoader statsStorageLoader;
|
||||
/**
|
||||
* Loader that attaches {@code StatsStorage} provided by {@code #statsStorageProvider} for the given session ID
|
||||
*/
|
||||
@Getter
|
||||
private Function<String, Boolean> statsStorageLoader;
|
||||
|
||||
//typeIDModuleMap: Records which modules are registered for which type IDs
|
||||
private Map<String, List<UIModule>> typeIDModuleMap = new ConcurrentHashMap<>();
|
||||
|
@ -247,10 +250,23 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
|||
*/
|
||||
public void autoAttachStatsStorageBySessionId(Function<String, StatsStorage> statsStorageProvider) {
|
||||
if (statsStorageProvider != null) {
|
||||
this.statsStorageLoader = new StatsStorageLoader(statsStorageProvider);
|
||||
if (trainModule != null) {
|
||||
this.trainModule.setSessionLoader(this.statsStorageLoader);
|
||||
this.statsStorageLoader = (sessionId) -> {
|
||||
log.info("Loading StatsStorage via StatsStorageProvider for session ID (" + sessionId + ").");
|
||||
StatsStorage statsStorage = statsStorageProvider.apply(sessionId);
|
||||
if (statsStorage != null) {
|
||||
if (statsStorage.sessionExists(sessionId)) {
|
||||
attach(statsStorage);
|
||||
return true;
|
||||
}
|
||||
log.info("Failed to load StatsStorage via StatsStorageProvider for session ID. " +
|
||||
"Session ID (" + sessionId + ") does not exist in StatsStorage.");
|
||||
return false;
|
||||
} else {
|
||||
log.info("Failed to load StatsStorage via StatsStorageProvider for session ID (" + sessionId + "). " +
|
||||
"StatsStorageProvider returned null.");
|
||||
return false;
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -302,8 +318,7 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
|||
}
|
||||
|
||||
uiModules.add(new DefaultModule(isMultiSession())); //For: navigation page "/"
|
||||
trainModule = new TrainModule(isMultiSession(), statsStorageLoader, this::getAddress);
|
||||
uiModules.add(trainModule);
|
||||
uiModules.add(new TrainModule());
|
||||
uiModules.add(new ConvolutionalListenerModule());
|
||||
uiModules.add(new TsneModule());
|
||||
uiModules.add(new SameDiffModule());
|
||||
|
@ -596,37 +611,6 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Loader that attaches {@code StatsStorage} provided by {@code #statsStorageProvider} for the given session ID
|
||||
*/
|
||||
private class StatsStorageLoader implements Function<String, Boolean> {
|
||||
|
||||
Function<String, StatsStorage> statsStorageProvider;
|
||||
|
||||
StatsStorageLoader(Function<String, StatsStorage> statsStorageProvider) {
|
||||
this.statsStorageProvider = statsStorageProvider;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Boolean apply(String sessionId) {
|
||||
log.info("Loading StatsStorage via StatsStorageProvider for session ID (" + sessionId + ").");
|
||||
StatsStorage statsStorage = statsStorageProvider.apply(sessionId);
|
||||
if (statsStorage != null) {
|
||||
if (statsStorage.sessionExists(sessionId)) {
|
||||
attach(statsStorage);
|
||||
return true;
|
||||
}
|
||||
log.info("Failed to load StatsStorage via StatsStorageProvider for session ID. " +
|
||||
"Session ID (" + sessionId + ") does not exist in StatsStorage.");
|
||||
return false;
|
||||
} else {
|
||||
log.info("Failed to load StatsStorage via StatsStorageProvider for session ID (" + sessionId + "). " +
|
||||
"StatsStorageProvider returned null.");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//==================================================================================================================
|
||||
// CLI Launcher
|
||||
|
||||
|
|
|
@ -157,8 +157,9 @@ public interface UIServer {
|
|||
|
||||
/**
|
||||
* Stop/shut down the UI server. This synchronous function should wait until the server is stopped.
|
||||
* @throws InterruptedException if the current thread is interrupted while waiting
|
||||
*/
|
||||
void stop() throws Exception;
|
||||
void stop() throws InterruptedException;
|
||||
|
||||
/**
|
||||
* Stop/shut down the UI server.
|
||||
|
|
|
@ -26,8 +26,6 @@ import io.vertx.ext.web.RoutingContext;
|
|||
import it.unimi.dsi.fastutil.longs.LongArrayList;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.apache.commons.io.FilenameUtils;
|
||||
|
@ -43,6 +41,7 @@ import org.deeplearning4j.nn.conf.graph.GraphVertex;
|
|||
import org.deeplearning4j.nn.conf.graph.LayerVertex;
|
||||
import org.deeplearning4j.nn.conf.layers.*;
|
||||
import org.deeplearning4j.nn.conf.serde.JsonMappers;
|
||||
import org.deeplearning4j.ui.VertxUIServer;
|
||||
import org.deeplearning4j.ui.api.HttpMethod;
|
||||
import org.deeplearning4j.ui.api.I18N;
|
||||
import org.deeplearning4j.ui.api.Route;
|
||||
|
@ -56,7 +55,6 @@ import org.deeplearning4j.ui.model.stats.api.StatsInitializationReport;
|
|||
import org.deeplearning4j.ui.model.stats.api.StatsReport;
|
||||
import org.deeplearning4j.ui.model.stats.api.StatsType;
|
||||
import org.nd4j.common.function.Function;
|
||||
import org.nd4j.common.function.Supplier;
|
||||
import org.nd4j.linalg.learning.config.IUpdater;
|
||||
import org.nd4j.common.primitives.Pair;
|
||||
import org.nd4j.common.primitives.Triple;
|
||||
|
@ -86,8 +84,6 @@ public class TrainModule implements UIModule {
|
|||
private static final DecimalFormat df2 = new DecimalFormat("#.00");
|
||||
private static DateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
|
||||
|
||||
private final Supplier<String> addressSupplier;
|
||||
|
||||
private enum ModelType {
|
||||
MLN, CG, Layer
|
||||
}
|
||||
|
@ -99,29 +95,14 @@ public class TrainModule implements UIModule {
|
|||
private Map<String, AtomicInteger> workerIdxCount = new ConcurrentHashMap<>(); //Key: session ID
|
||||
private Map<String, Map<Integer, String>> workerIdxToName = new ConcurrentHashMap<>(); //Key: session ID
|
||||
private Map<String, Long> lastUpdateForSession = new ConcurrentHashMap<>();
|
||||
private final boolean multiSession;
|
||||
@Getter @Setter
|
||||
private Function<String, Boolean> sessionLoader;
|
||||
|
||||
|
||||
private final Configuration configuration;
|
||||
|
||||
public TrainModule() {
|
||||
this(false, null, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* TrainModule
|
||||
*
|
||||
* @param multiSession multi-session mode
|
||||
* @param sessionLoader StatsStorage loader to call if an unknown session ID is passed as URL path parameter
|
||||
* in multi-session mode
|
||||
* @param addressSupplier supplier for server address (server address in PlayUIServer gets initialized after modules)
|
||||
*/
|
||||
public TrainModule(boolean multiSession, Function<String, Boolean> sessionLoader, Supplier<String> addressSupplier) {
|
||||
this.multiSession = multiSession;
|
||||
this.sessionLoader = sessionLoader;
|
||||
this.addressSupplier = addressSupplier;
|
||||
public TrainModule() {
|
||||
String maxChartPointsProp = System.getProperty(DL4JSystemProperties.CHART_MAX_POINTS_PROPERTY);
|
||||
int value = DEFAULT_MAX_CHART_POINTS;
|
||||
if (maxChartPointsProp != null) {
|
||||
|
@ -159,8 +140,9 @@ public class TrainModule implements UIModule {
|
|||
@Override
|
||||
public List<Route> getRoutes() {
|
||||
List<Route> r = new ArrayList<>();
|
||||
r.add(new Route("/train/multisession", HttpMethod.GET, (path, rc) -> rc.response().end(multiSession ? "true" : "false")));
|
||||
if (multiSession) {
|
||||
r.add(new Route("/train/multisession", HttpMethod.GET,
|
||||
(path, rc) -> rc.response().end(VertxUIServer.getInstance().isMultiSession() ? "true" : "false")));
|
||||
if (VertxUIServer.getInstance().isMultiSession()) {
|
||||
r.add(new Route("/train", HttpMethod.GET, (path, rc) -> this.listSessions(rc)));
|
||||
r.add(new Route("/train/:sessionId", HttpMethod.GET, (path, rc) -> {
|
||||
rc.response()
|
||||
|
@ -264,7 +246,9 @@ public class TrainModule implements UIModule {
|
|||
if (!knownSessionIDs.isEmpty()) {
|
||||
sb.append(" <ul>");
|
||||
for (String sessionId : knownSessionIDs.keySet()) {
|
||||
sb.append(" <li><a href=\"train/").append(sessionId).append("\">").append(sessionId).append("</a></li>\n");
|
||||
sb.append(" <li><a href=\"/train/")
|
||||
.append(sessionId).append("\">")
|
||||
.append(sessionId).append("</a></li>\n");
|
||||
}
|
||||
sb.append(" </ul>");
|
||||
} else {
|
||||
|
@ -284,9 +268,11 @@ public class TrainModule implements UIModule {
|
|||
*
|
||||
* @param sessionId session ID to look fo with provider
|
||||
* @param targetPath one of overview / model / system, or null
|
||||
* @param rc routing context
|
||||
*/
|
||||
private void sessionNotFound(String sessionId, String targetPath, RoutingContext rc) {
|
||||
if (sessionLoader != null && sessionLoader.apply(sessionId)) {
|
||||
Function<String, Boolean> loader = VertxUIServer.getInstance().getStatsStorageLoader();
|
||||
if (loader != null && loader.apply(sessionId)) {
|
||||
if (targetPath != null) {
|
||||
rc.reroute(targetPath);
|
||||
} else {
|
||||
|
@ -306,9 +292,9 @@ public class TrainModule implements UIModule {
|
|||
&& StatsListener.TYPE_ID.equals(sse.getTypeID())
|
||||
&& !knownSessionIDs.containsKey(sse.getSessionID())) {
|
||||
knownSessionIDs.put(sse.getSessionID(), sse.getStatsStorage());
|
||||
if (multiSession) {
|
||||
if (VertxUIServer.getInstance().isMultiSession()) {
|
||||
log.info("Adding training session {}/train/{} of StatsStorage instance {}",
|
||||
addressSupplier.get(), sse.getSessionID(), sse.getStatsStorage());
|
||||
VertxUIServer.getInstance().getAddress(), sse.getSessionID(), sse.getStatsStorage());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -332,9 +318,9 @@ public class TrainModule implements UIModule {
|
|||
if (!StatsListener.TYPE_ID.equals(typeID))
|
||||
continue;
|
||||
knownSessionIDs.put(sessionID, statsStorage);
|
||||
if (multiSession) {
|
||||
if (VertxUIServer.getInstance().isMultiSession()) {
|
||||
log.info("Adding training session {}/train/{} of StatsStorage instance {}",
|
||||
addressSupplier.get(), sessionID, statsStorage);
|
||||
VertxUIServer.getInstance().getAddress(), sessionID, statsStorage);
|
||||
}
|
||||
|
||||
List<Persistable> latestUpdates = statsStorage.getLatestUpdateAllWorkers(sessionID, typeID);
|
||||
|
@ -364,9 +350,9 @@ public class TrainModule implements UIModule {
|
|||
}
|
||||
for (String s : toRemove) {
|
||||
knownSessionIDs.remove(s);
|
||||
if (multiSession) {
|
||||
if (VertxUIServer.getInstance().isMultiSession()) {
|
||||
log.info("Removing training session {}/train/{} of StatsStorage instance {}.",
|
||||
addressSupplier.get(), s, statsStorage);
|
||||
VertxUIServer.getInstance().getAddress(), s, statsStorage);
|
||||
}
|
||||
lastUpdateForSession.remove(s);
|
||||
}
|
||||
|
@ -602,13 +588,13 @@ public class TrainModule implements UIModule {
|
|||
}
|
||||
|
||||
/**
|
||||
* Get global {@link I18N} instance if {@link #multiSession} is {@code true}, or instance for session
|
||||
* Get global {@link I18N} instance if {@link VertxUIServer#isMultiSession()} is {@code true}, or instance for session
|
||||
*
|
||||
* @param sessionId session ID
|
||||
* @return {@link I18N} instance
|
||||
*/
|
||||
private I18N getI18N(String sessionId) {
|
||||
return multiSession ? I18NProvider.getInstance(sessionId) : I18NProvider.getInstance();
|
||||
return VertxUIServer.getInstance().isMultiSession() ? I18NProvider.getInstance(sessionId) : I18NProvider.getInstance();
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -259,7 +259,7 @@ public class TestVertxUIManual extends BaseDL4JTest {
|
|||
log.info("Auto-detaching StatsStorage (session ID: {}) after {} ms.",
|
||||
sessionId, autoDetachTimeoutMillis);
|
||||
uIServer.detach(statsStorage);
|
||||
log.info(" To re-attach StatsStorage of training session, visit {}}/train/{}",
|
||||
log.info(" To re-attach StatsStorage of training session, visit {}/train/{}",
|
||||
uIServer.getAddress(), sessionId);
|
||||
}
|
||||
}).start();
|
||||
|
|
|
@ -197,9 +197,9 @@ public class TestVertxUIMultiSession extends BaseDL4JTest {
|
|||
}
|
||||
|
||||
/**
|
||||
* Get URL-encoded URL for training session on given server address
|
||||
* Get URL for training session on given server address
|
||||
* @param serverAddress server address
|
||||
* @param sessionId session ID
|
||||
* @param sessionId session ID (will be URL-encoded)
|
||||
* @return URL
|
||||
* @throws UnsupportedEncodingException if the used encoding is not supported
|
||||
*/
|
||||
|
|
Loading…
Reference in New Issue