Merge pull request #8872 from printomi/ui_multisession_arbiter
DL4J UI: Add multi-session support for Arbitermaster
commit
a76f957b72
|
@ -36,6 +36,7 @@ import org.deeplearning4j.arbiter.ui.data.ModelInfoPersistable;
|
||||||
import org.deeplearning4j.arbiter.ui.misc.UIUtils;
|
import org.deeplearning4j.arbiter.ui.misc.UIUtils;
|
||||||
import org.deeplearning4j.arbiter.util.ObjectUtils;
|
import org.deeplearning4j.arbiter.util.ObjectUtils;
|
||||||
import org.deeplearning4j.nn.conf.serde.JsonMappers;
|
import org.deeplearning4j.nn.conf.serde.JsonMappers;
|
||||||
|
import org.deeplearning4j.ui.VertxUIServer;
|
||||||
import org.deeplearning4j.ui.api.Component;
|
import org.deeplearning4j.ui.api.Component;
|
||||||
import org.deeplearning4j.ui.api.*;
|
import org.deeplearning4j.ui.api.*;
|
||||||
import org.deeplearning4j.ui.components.chart.ChartLine;
|
import org.deeplearning4j.ui.components.chart.ChartLine;
|
||||||
|
@ -50,6 +51,7 @@ import org.deeplearning4j.ui.components.text.style.StyleText;
|
||||||
import org.deeplearning4j.ui.i18n.I18NResource;
|
import org.deeplearning4j.ui.i18n.I18NResource;
|
||||||
import org.joda.time.format.DateTimeFormat;
|
import org.joda.time.format.DateTimeFormat;
|
||||||
import org.joda.time.format.DateTimeFormatter;
|
import org.joda.time.format.DateTimeFormatter;
|
||||||
|
import org.nd4j.common.function.Function;
|
||||||
import org.nd4j.common.primitives.Pair;
|
import org.nd4j.common.primitives.Pair;
|
||||||
import org.nd4j.shade.jackson.core.JsonProcessingException;
|
import org.nd4j.shade.jackson.core.JsonProcessingException;
|
||||||
|
|
||||||
|
@ -77,7 +79,6 @@ public class ArbiterModule implements UIModule {
|
||||||
|
|
||||||
private Map<String, Long> lastUpdateForSession = Collections.synchronizedMap(new HashMap<>());
|
private Map<String, Long> lastUpdateForSession = Collections.synchronizedMap(new HashMap<>());
|
||||||
|
|
||||||
|
|
||||||
//Styles for UI:
|
//Styles for UI:
|
||||||
private static final StyleTable STYLE_TABLE = new StyleTable.Builder()
|
private static final StyleTable STYLE_TABLE = new StyleTable.Builder()
|
||||||
.width(100, LengthUnit.Percent)
|
.width(100, LengthUnit.Percent)
|
||||||
|
@ -134,20 +135,135 @@ public class ArbiterModule implements UIModule {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<Route> getRoutes() {
|
public List<Route> getRoutes() {
|
||||||
Route r1 = new Route("/arbiter", HttpMethod.GET, (path, rc) -> rc.response()
|
boolean multiSession = VertxUIServer.getMultiSession().get();
|
||||||
.putHeader("content-type", "text/html; charset=utf-8").sendFile("templates/ArbiterUI.html"));
|
List<Route> r = new ArrayList<>();
|
||||||
Route r3 = new Route("/arbiter/lastUpdate", HttpMethod.GET, (path, rc) -> this.getLastUpdateTime(rc));
|
r.add(new Route("/arbiter/multisession", HttpMethod.GET,
|
||||||
Route r4 = new Route("/arbiter/lastUpdate/:ids", HttpMethod.GET, (path, rc) -> this.getModelLastUpdateTimes(path.get(0), rc));
|
(path, rc) -> rc.response().end(multiSession ? "true" : "false")));
|
||||||
Route r5 = new Route("/arbiter/candidateInfo/:id", HttpMethod.GET, (path, rc) -> this.getCandidateInfo(path.get(0), rc));
|
if (multiSession) {
|
||||||
Route r6 = new Route("/arbiter/config", HttpMethod.GET, (path, rc) -> this.getOptimizationConfig(rc));
|
r.add(new Route("/arbiter", HttpMethod.GET, (path, rc) -> this.listSessions(rc)));
|
||||||
Route r7 = new Route("/arbiter/results", HttpMethod.GET, (path, rc) -> this.getSummaryResults(rc));
|
r.add(new Route("/arbiter/:sessionId", HttpMethod.GET, (path, rc) -> {
|
||||||
Route r8 = new Route("/arbiter/summary", HttpMethod.GET, (path, rc) -> this.getSummaryStatus(rc));
|
if (knownSessionIDs.containsKey(path.get(0))) {
|
||||||
|
rc.response()
|
||||||
|
.putHeader("content-type", "text/html; charset=utf-8")
|
||||||
|
.sendFile("templates/ArbiterUI.html");
|
||||||
|
} else {
|
||||||
|
sessionNotFound(path.get(0), rc.request().path(), rc);
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
|
||||||
Route r9a = new Route("/arbiter/sessions/all", HttpMethod.GET, (path, rc) -> this.listSessions(rc));
|
r.add(new Route("/arbiter/:sessionId/lastUpdate", HttpMethod.GET, (path, rc) -> {
|
||||||
Route r9b = new Route("/arbiter/sessions/current", HttpMethod.GET, (path, rc) -> this.currentSession(rc));
|
if (knownSessionIDs.containsKey(path.get(0))) {
|
||||||
Route r9c = new Route("/arbiter/sessions/set/:to", HttpMethod.GET, (path, rc) -> this.setSession(path.get(0), rc));
|
this.getLastUpdateTime(path.get(0), rc);
|
||||||
|
} else {
|
||||||
|
sessionNotFound(path.get(0), rc.request().path(), rc);
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
r.add(new Route("/arbiter/:sessionId/candidateInfo/:id", HttpMethod.GET, (path, rc) -> {
|
||||||
|
if (knownSessionIDs.containsKey(path.get(0))) {
|
||||||
|
this.getCandidateInfo(path.get(0), path.get(1), rc);
|
||||||
|
} else {
|
||||||
|
sessionNotFound(path.get(0), rc.request().path(), rc);
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
r.add(new Route("/arbiter/:sessionId/config", HttpMethod.GET, (path, rc) -> {
|
||||||
|
if (knownSessionIDs.containsKey(path.get(0))) {
|
||||||
|
this.getOptimizationConfig(path.get(0), rc);
|
||||||
|
} else {
|
||||||
|
sessionNotFound(path.get(0), rc.request().path(), rc);
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
r.add(new Route("/arbiter/:sessionId/results", HttpMethod.GET, (path, rc) -> {
|
||||||
|
if (knownSessionIDs.containsKey(path.get(0))) {
|
||||||
|
this.getSummaryResults(path.get(0), rc);
|
||||||
|
} else {
|
||||||
|
sessionNotFound(path.get(0), rc.request().path(), rc);
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
r.add(new Route("/arbiter/:sessionId/summary", HttpMethod.GET, (path, rc) -> {
|
||||||
|
if (knownSessionIDs.containsKey(path.get(0))) {
|
||||||
|
this.getSummaryStatus(path.get(0), rc);
|
||||||
|
} else {
|
||||||
|
sessionNotFound(path.get(0), rc.request().path(), rc);
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
} else {
|
||||||
|
r.add(new Route("/arbiter", HttpMethod.GET, (path, rc) -> rc.response()
|
||||||
|
.putHeader("content-type", "text/html; charset=utf-8")
|
||||||
|
.sendFile("templates/ArbiterUI.html")));
|
||||||
|
r.add(new Route("/arbiter/lastUpdate", HttpMethod.GET, (path, rc) -> this.getLastUpdateTime(null, rc)));
|
||||||
|
r.add(new Route("/arbiter/candidateInfo/:id", HttpMethod.GET,
|
||||||
|
(path, rc) -> this.getCandidateInfo(null, path.get(0), rc)));
|
||||||
|
r.add(new Route("/arbiter/config", HttpMethod.GET, (path, rc) -> this.getOptimizationConfig(null, rc)));
|
||||||
|
r.add(new Route("/arbiter/results", HttpMethod.GET, (path, rc) -> this.getSummaryResults(null, rc)));
|
||||||
|
r.add(new Route("/arbiter/summary", HttpMethod.GET, (path, rc) -> this.getSummaryStatus(null, rc)));
|
||||||
|
|
||||||
return Arrays.asList(r1, r3, r4, r5, r6, r7, r8, r9a, r9b, r9c);
|
r.add(new Route("/arbiter/sessions/current", HttpMethod.GET, (path, rc) -> this.currentSession(rc)));
|
||||||
|
r.add(new Route("/arbiter/sessions/set/:to", HttpMethod.GET,
|
||||||
|
(path, rc) -> this.setSession(path.get(0), rc)));
|
||||||
|
}
|
||||||
|
// common for single- and multi-session mode
|
||||||
|
r.add(new Route("/arbiter/sessions/all", HttpMethod.GET, (path, rc) -> this.sessionInfo(rc)));
|
||||||
|
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Load StatsStorage via provider, or return "not found"
|
||||||
|
*
|
||||||
|
* @param sessionId session ID to look fo with provider
|
||||||
|
* @param targetPath one of overview / model / system, or null
|
||||||
|
* @param rc routing context
|
||||||
|
*/
|
||||||
|
private void sessionNotFound(String sessionId, String targetPath, RoutingContext rc) {
|
||||||
|
Function<String, Boolean> loader = VertxUIServer.getInstance().getStatsStorageLoader();
|
||||||
|
if (loader != null && loader.apply(sessionId)) {
|
||||||
|
if (targetPath != null) {
|
||||||
|
rc.reroute(targetPath);
|
||||||
|
} else {
|
||||||
|
rc.response().end();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
rc.response().setStatusCode(HttpResponseStatus.NOT_FOUND.code())
|
||||||
|
.end("Unknown session ID: " + sessionId);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* List optimization sessions. Returns a HTML list of arbiter sessions
|
||||||
|
*/
|
||||||
|
private synchronized void listSessions(RoutingContext rc) {
|
||||||
|
StringBuilder sb = new StringBuilder("<!DOCTYPE html>\n" +
|
||||||
|
"<html lang=\"en\">\n" +
|
||||||
|
"<head>\n" +
|
||||||
|
" <meta charset=\"utf-8\">\n" +
|
||||||
|
" <title>Optimization sessions - DL4J Arbiter UI</title>\n" +
|
||||||
|
" </head>\n" +
|
||||||
|
"\n" +
|
||||||
|
" <body>\n" +
|
||||||
|
" <h1>DL4J Arbiter UI</h1>\n" +
|
||||||
|
" <p>UI server is in multi-session mode." +
|
||||||
|
" To visualize an optimization session, please select one from the following list.</p>\n" +
|
||||||
|
" <h2>List of attached optimization sessions</h2>\n");
|
||||||
|
if (!knownSessionIDs.isEmpty()) {
|
||||||
|
sb.append(" <ul>");
|
||||||
|
for (String sessionId : knownSessionIDs.keySet()) {
|
||||||
|
sb.append(" <li><a href=\"/arbiter/")
|
||||||
|
.append(sessionId).append("\">")
|
||||||
|
.append(sessionId).append("</a></li>\n");
|
||||||
|
}
|
||||||
|
sb.append(" </ul>");
|
||||||
|
} else {
|
||||||
|
sb.append("No optimization session attached.");
|
||||||
|
}
|
||||||
|
|
||||||
|
sb.append(" </body>\n" +
|
||||||
|
"</html>\n");
|
||||||
|
|
||||||
|
rc.response()
|
||||||
|
.putHeader("content-type", "text/html; charset=utf-8")
|
||||||
|
.end(sb.toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -201,7 +317,7 @@ public class ArbiterModule implements UIModule {
|
||||||
.end(asJson(sid));
|
.end(asJson(sid));
|
||||||
}
|
}
|
||||||
|
|
||||||
private void listSessions(RoutingContext rc) {
|
private void sessionInfo(RoutingContext rc) {
|
||||||
rc.response()
|
rc.response()
|
||||||
.putHeader("content-type", "application/json")
|
.putHeader("content-type", "application/json")
|
||||||
.end(asJson(knownSessionIDs.keySet()));
|
.end(asJson(knownSessionIDs.keySet()));
|
||||||
|
@ -257,10 +373,25 @@ public class ArbiterModule implements UIModule {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Return the last update time for the page
|
* Return the last update time for the page
|
||||||
|
* @param sessionId session ID (optional, for multi-session mode)
|
||||||
|
* @param rc routing context
|
||||||
*/
|
*/
|
||||||
private void getLastUpdateTime(RoutingContext rc){
|
private void getLastUpdateTime(String sessionId, RoutingContext rc){
|
||||||
//TODO - this forces updates on every request... which is fine, just inefficient
|
if (sessionId == null) {
|
||||||
long t = System.currentTimeMillis();
|
sessionId = currentSessionID;
|
||||||
|
}
|
||||||
|
StatsStorage ss = knownSessionIDs.get(sessionId);
|
||||||
|
List<Persistable> latestUpdates = ss.getLatestUpdateAllWorkers(sessionId, ARBITER_UI_TYPE_ID);
|
||||||
|
long t = 0;
|
||||||
|
if (latestUpdates.isEmpty()) {
|
||||||
|
t = System.currentTimeMillis();
|
||||||
|
} else {
|
||||||
|
for (Persistable update : latestUpdates) {
|
||||||
|
if (update.getTimeStamp() > t) {
|
||||||
|
t = update.getTimeStamp();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
UpdateStatus us = new UpdateStatus(t, t, t);
|
UpdateStatus us = new UpdateStatus(t, t, t);
|
||||||
|
|
||||||
rc.response().putHeader("content-type", "application/json").end(asJson(us));
|
rc.response().putHeader("content-type", "application/json").end(asJson(us));
|
||||||
|
@ -274,56 +405,28 @@ public class ArbiterModule implements UIModule {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Get the last update time for the specified model IDs
|
|
||||||
* @param modelIDs Model IDs to get the update time for
|
|
||||||
*/
|
|
||||||
private void getModelLastUpdateTimes(String modelIDs, RoutingContext rc){
|
|
||||||
if(currentSessionID == null){
|
|
||||||
rc.response().end();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
StatsStorage ss = knownSessionIDs.get(currentSessionID);
|
|
||||||
if(ss == null){
|
|
||||||
log.debug("getModelLastUpdateTimes(): Session ID is unknown: {}", currentSessionID);
|
|
||||||
rc.response().end("-1");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
String[] split = modelIDs.split(",");
|
|
||||||
|
|
||||||
long[] lastUpdateTimes = new long[split.length];
|
|
||||||
for( int i=0; i<split.length; i++ ){
|
|
||||||
String s = split[i];
|
|
||||||
Persistable p = ss.getLatestUpdate(currentSessionID, ARBITER_UI_TYPE_ID, s);
|
|
||||||
if(p != null){
|
|
||||||
lastUpdateTimes[i] = p.getTimeStamp();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
rc.response().putHeader("content-type", "application/json").end(asJson(lastUpdateTimes));
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the info for a specific candidate - last section in the UI
|
* Get the info for a specific candidate - last section in the UI
|
||||||
*
|
* @param sessionId session ID (optional, for multi-session mode)
|
||||||
* @param candidateId ID for the candidate
|
* @param candidateId ID for the candidate
|
||||||
* @return Content/info for the candidate
|
* @param rc routing context
|
||||||
*/
|
*/
|
||||||
private void getCandidateInfo(String candidateId, RoutingContext rc){
|
private void getCandidateInfo(String sessionId, String candidateId, RoutingContext rc){
|
||||||
|
if (sessionId == null) {
|
||||||
StatsStorage ss = knownSessionIDs.get(currentSessionID);
|
sessionId = currentSessionID;
|
||||||
|
}
|
||||||
|
StatsStorage ss = knownSessionIDs.get(sessionId);
|
||||||
if(ss == null){
|
if(ss == null){
|
||||||
log.debug("getModelLastUpdateTimes(): Session ID is unknown: {}", currentSessionID);
|
log.debug("getModelLastUpdateTimes(): Session ID is unknown: {}", sessionId);
|
||||||
rc.response().end();
|
rc.response().end();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
GlobalConfigPersistable gcp = (GlobalConfigPersistable)ss.getStaticInfo(currentSessionID, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID);;
|
GlobalConfigPersistable gcp = (GlobalConfigPersistable)ss
|
||||||
|
.getStaticInfo(sessionId, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID);
|
||||||
OptimizationConfiguration oc = gcp.getOptimizationConfiguration();
|
OptimizationConfiguration oc = gcp.getOptimizationConfiguration();
|
||||||
|
|
||||||
Persistable p = ss.getLatestUpdate(currentSessionID, ARBITER_UI_TYPE_ID, candidateId);
|
Persistable p = ss.getLatestUpdate(sessionId, ARBITER_UI_TYPE_ID, candidateId);
|
||||||
if(p == null){
|
if(p == null){
|
||||||
String title = "No results found for model " + candidateId + ".";
|
String title = "No results found for model " + candidateId + ".";
|
||||||
ComponentText ct = new ComponentText.Builder(title,STYLE_TEXT_SZ12).build();
|
ComponentText ct = new ComponentText.Builder(title,STYLE_TEXT_SZ12).build();
|
||||||
|
@ -493,17 +596,21 @@ public class ArbiterModule implements UIModule {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the optimization configuration - second section in the page
|
* Get the optimization configuration - second section in the page
|
||||||
|
* @param sessionId session ID (optional, for multi-session mode)
|
||||||
|
* @param rc routing context
|
||||||
*/
|
*/
|
||||||
private void getOptimizationConfig(RoutingContext rc){
|
private void getOptimizationConfig(String sessionId, RoutingContext rc){
|
||||||
|
if (sessionId == null) {
|
||||||
StatsStorage ss = knownSessionIDs.get(currentSessionID);
|
sessionId = currentSessionID;
|
||||||
|
}
|
||||||
|
StatsStorage ss = knownSessionIDs.get(sessionId);
|
||||||
if(ss == null){
|
if(ss == null){
|
||||||
log.debug("getOptimizationConfig(): Session ID is unknown: {}", currentSessionID);
|
log.debug("getOptimizationConfig(): Session ID is unknown: {}", sessionId);
|
||||||
rc.response().end();
|
rc.response().end();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
Persistable p = ss.getStaticInfo(currentSessionID, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID);
|
Persistable p = ss.getStaticInfo(sessionId, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID);
|
||||||
|
|
||||||
if(p == null){
|
if(p == null){
|
||||||
log.debug("No static info");
|
log.debug("No static info");
|
||||||
|
@ -593,15 +700,23 @@ public class ArbiterModule implements UIModule {
|
||||||
rc.response().putHeader("content-type", "application/json").end(asJson(cd));
|
rc.response().putHeader("content-type", "application/json").end(asJson(cd));
|
||||||
}
|
}
|
||||||
|
|
||||||
private void getSummaryResults(RoutingContext rc){
|
/**
|
||||||
StatsStorage ss = knownSessionIDs.get(currentSessionID);
|
* Get candidates summary results list - third section on the page: Results table
|
||||||
|
* @param sessionId session ID (optional, for multi-session mode)
|
||||||
|
* @param rc routing context
|
||||||
|
*/
|
||||||
|
private void getSummaryResults(String sessionId, RoutingContext rc){
|
||||||
|
if (sessionId == null) {
|
||||||
|
sessionId = currentSessionID;
|
||||||
|
}
|
||||||
|
StatsStorage ss = knownSessionIDs.get(sessionId);
|
||||||
if(ss == null){
|
if(ss == null){
|
||||||
log.debug("getSummaryResults(): Session ID is unknown: {}", currentSessionID);
|
log.debug("getSummaryResults(): Session ID is unknown: {}", sessionId);
|
||||||
rc.response().end();
|
rc.response().end();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
List<Persistable> allModelInfoTemp = new ArrayList<>(ss.getLatestUpdateAllWorkers(currentSessionID, ARBITER_UI_TYPE_ID));
|
List<Persistable> allModelInfoTemp = new ArrayList<>(ss.getLatestUpdateAllWorkers(sessionId, ARBITER_UI_TYPE_ID));
|
||||||
List<String[]> table = new ArrayList<>();
|
List<String[]> table = new ArrayList<>();
|
||||||
for(Persistable per : allModelInfoTemp){
|
for(Persistable per : allModelInfoTemp){
|
||||||
ModelInfoPersistable mip = (ModelInfoPersistable)per;
|
ModelInfoPersistable mip = (ModelInfoPersistable)per;
|
||||||
|
@ -614,16 +729,21 @@ public class ArbiterModule implements UIModule {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get summary status information: first section in the page
|
* Get summary status information: first section in the page
|
||||||
|
* @param sessionId session ID (optional, for multi-session mode)
|
||||||
|
* @param rc routing context
|
||||||
*/
|
*/
|
||||||
private void getSummaryStatus(RoutingContext rc){
|
private void getSummaryStatus(String sessionId, RoutingContext rc){
|
||||||
StatsStorage ss = knownSessionIDs.get(currentSessionID);
|
if (sessionId == null) {
|
||||||
|
sessionId = currentSessionID;
|
||||||
|
}
|
||||||
|
StatsStorage ss = knownSessionIDs.get(sessionId);
|
||||||
if(ss == null){
|
if(ss == null){
|
||||||
log.debug("getOptimizationConfig(): Session ID is unknown: {}", currentSessionID);
|
log.debug("getOptimizationConfig(): Session ID is unknown: {}", sessionId);
|
||||||
rc.response().end();
|
rc.response().end();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
Persistable p = ss.getStaticInfo(currentSessionID, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID);
|
Persistable p = ss.getStaticInfo(sessionId, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID);
|
||||||
|
|
||||||
if(p == null){
|
if(p == null){
|
||||||
log.info("No static info");
|
log.info("No static info");
|
||||||
|
@ -643,7 +763,7 @@ public class ArbiterModule implements UIModule {
|
||||||
|
|
||||||
//How to get this? query all model infos...
|
//How to get this? query all model infos...
|
||||||
|
|
||||||
List<Persistable> allModelInfoTemp = new ArrayList<>(ss.getLatestUpdateAllWorkers(currentSessionID, ARBITER_UI_TYPE_ID));
|
List<Persistable> allModelInfoTemp = new ArrayList<>(ss.getLatestUpdateAllWorkers(sessionId, ARBITER_UI_TYPE_ID));
|
||||||
List<ModelInfoPersistable> allModelInfo = new ArrayList<>();
|
List<ModelInfoPersistable> allModelInfo = new ArrayList<>();
|
||||||
for(Persistable per : allModelInfoTemp){
|
for(Persistable per : allModelInfoTemp){
|
||||||
ModelInfoPersistable mip = (ModelInfoPersistable)per;
|
ModelInfoPersistable mip = (ModelInfoPersistable)per;
|
||||||
|
@ -668,7 +788,6 @@ public class ArbiterModule implements UIModule {
|
||||||
|
|
||||||
//TODO: I18N
|
//TODO: I18N
|
||||||
|
|
||||||
//TODO don't use currentTimeMillis due to stored data??
|
|
||||||
long bestTime;
|
long bestTime;
|
||||||
Double bestScore = null;
|
Double bestScore = null;
|
||||||
String bestModelString = null;
|
String bestModelString = null;
|
||||||
|
@ -685,7 +804,12 @@ public class ArbiterModule implements UIModule {
|
||||||
String execTotalRuntimeStr = "";
|
String execTotalRuntimeStr = "";
|
||||||
if(execStartTime > 0){
|
if(execStartTime > 0){
|
||||||
execStartTimeStr = TIME_FORMATTER.print(execStartTime);
|
execStartTimeStr = TIME_FORMATTER.print(execStartTime);
|
||||||
execTotalRuntimeStr = UIUtils.formatDuration(System.currentTimeMillis() - execStartTime);
|
// allModelInfo is sorted by Persistable::getTimeStamp
|
||||||
|
long lastCompleteTime = execStartTime;
|
||||||
|
if (!allModelInfo.isEmpty()) {
|
||||||
|
lastCompleteTime = allModelInfo.get(allModelInfo.size() - 1).getTimeStamp();
|
||||||
|
}
|
||||||
|
execTotalRuntimeStr = UIUtils.formatDuration(lastCompleteTime - execStartTime);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -198,109 +198,166 @@
|
||||||
|
|
||||||
var selectedCandidateIdx = null;
|
var selectedCandidateIdx = null;
|
||||||
|
|
||||||
|
//Multi-session mode
|
||||||
|
var multiSession = null;
|
||||||
|
//Session selection
|
||||||
|
var currSession = "";
|
||||||
|
|
||||||
|
function getSessionIdFromUrl() {
|
||||||
|
// path is like /arbiter/:sessionId/overview
|
||||||
|
var sessionIdRegexp = /\/arbiter\/([^\/]+)/g;
|
||||||
|
var match = sessionIdRegexp.exec(window.location.pathname)
|
||||||
|
return match[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
function getCurrSession(callback) {
|
||||||
|
if (multiSession) {
|
||||||
|
if (currSession == "") {
|
||||||
|
// get only once
|
||||||
|
currSession = getSessionIdFromUrl();
|
||||||
|
}
|
||||||
|
//we don't show session selector in multi-session mode (one can list sessions at /arbiter)
|
||||||
|
callback();
|
||||||
|
} else {
|
||||||
|
$.ajax({
|
||||||
|
url: "/arbiter/sessions/current",
|
||||||
|
async: true,
|
||||||
|
error: function (query, status, error) {
|
||||||
|
console.log("Error getting data: " + error);
|
||||||
|
},
|
||||||
|
success: function (data) {
|
||||||
|
currSession = data;
|
||||||
|
console.log("Current session: " + currSession);
|
||||||
|
|
||||||
|
//Update available sessions in session selector
|
||||||
|
$.get("/arbiter/sessions/all", function(data){
|
||||||
|
var keys = data; // JSON.stringify(data);
|
||||||
|
|
||||||
|
if(keys.length > 1){
|
||||||
|
$("#sessionSelectDiv").show();
|
||||||
|
|
||||||
|
var elem = $("#sessionSelect");
|
||||||
|
elem.empty();
|
||||||
|
|
||||||
|
var currSelectedIdx = 0;
|
||||||
|
for (var i = 0; i < keys.length; i++) {
|
||||||
|
if(keys[i] == currSession){
|
||||||
|
currSelectedIdx = i;
|
||||||
|
}
|
||||||
|
elem.append("<option value='" + keys[i] + "'>" + keys[i] + "</option>");
|
||||||
|
}
|
||||||
|
|
||||||
|
$("#sessionSelect option[value='" + keys[currSelectedIdx] +"']").attr("selected", "selected");
|
||||||
|
$("#sessionSelectDiv").show();
|
||||||
|
}
|
||||||
|
// console.log("Got sessions: " + keys + ", current: " + currSession);
|
||||||
|
callback();
|
||||||
|
});
|
||||||
|
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function getSessionSettings(callback) {
|
||||||
|
// load only once
|
||||||
|
if (multiSession != null) {
|
||||||
|
getCurrSession(callback);
|
||||||
|
} else {
|
||||||
|
$.ajax({
|
||||||
|
url: "/arbiter/multisession",
|
||||||
|
async: true,
|
||||||
|
error: function (query, status, error) {
|
||||||
|
console.log("Error getting data: " + error);
|
||||||
|
},
|
||||||
|
success: function (data) {
|
||||||
|
multiSession = data == "true";
|
||||||
|
getCurrSession(callback);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//Initial update
|
||||||
|
doUpdate();
|
||||||
//Set basic interval function to do updates
|
//Set basic interval function to do updates
|
||||||
setInterval(doUpdate,5000); //Loop every 5 seconds
|
setInterval(doUpdate,5000); //Loop every 5 seconds
|
||||||
|
|
||||||
|
|
||||||
function doUpdate(){
|
function doUpdate(){
|
||||||
//Get the update status, and do something with it:
|
//Get the update status, and do something with it:
|
||||||
$.get("/arbiter/lastUpdate",function(data){
|
getSessionSettings(function(){
|
||||||
//Encoding: matches names in UpdateStatus class
|
var sessionUpdateUrl = multiSession ? "/arbiter/" + currSession + "/lastUpdate" : "/arbiter/lastUpdate";
|
||||||
var jsonObj = JSON.parse(JSON.stringify(data));
|
$.get(sessionUpdateUrl,function(data){
|
||||||
var statusTime = jsonObj['statusUpdateTime'];
|
//Encoding: matches names in UpdateStatus class
|
||||||
var settingsTime = jsonObj['settingsUpdateTime'];
|
var jsonObj = JSON.parse(JSON.stringify(data));
|
||||||
var resultsTime = jsonObj['resultsUpdateTime'];
|
var statusTime = jsonObj['statusUpdateTime'];
|
||||||
//console.log("Last update times: " + statusTime + ", " + settingsTime + ", " + resultsTime);
|
var settingsTime = jsonObj['settingsUpdateTime'];
|
||||||
|
var resultsTime = jsonObj['resultsUpdateTime'];
|
||||||
|
//console.log("Last update times: " + statusTime + ", " + settingsTime + ", " + resultsTime);
|
||||||
|
|
||||||
//Update available sessions:
|
//Check last update times for each part of document, and update as necessary
|
||||||
var currSession;
|
//First section: summary status
|
||||||
$.get("/arbiter/sessions/current", function(data){
|
if(lastStatusUpdateTime != statusTime){
|
||||||
currSession = data; //JSON.stringify(data);
|
var summaryStatusUrl = multiSession ? "/arbiter/" + currSession + "/summary" : "/arbiter/summary";
|
||||||
console.log("Current: " + currSession);
|
$.get(summaryStatusUrl,function(data){
|
||||||
});
|
var summaryStatusDiv = $('#statusdiv');
|
||||||
|
summaryStatusDiv.html('');
|
||||||
|
|
||||||
$.get("/arbiter/sessions/all", function(data){
|
var str = JSON.stringify(data);
|
||||||
var keys = data; // JSON.stringify(data);
|
var component = Component.getComponent(str);
|
||||||
|
component.render(summaryStatusDiv);
|
||||||
|
});
|
||||||
|
|
||||||
if(keys.length > 1){
|
lastStatusUpdateTime = statusTime;
|
||||||
$("#sessionSelectDiv").show();
|
|
||||||
|
|
||||||
var elem = $("#sessionSelect");
|
|
||||||
elem.empty();
|
|
||||||
|
|
||||||
var currSelectedIdx = 0;
|
|
||||||
for (var i = 0; i < keys.length; i++) {
|
|
||||||
if(keys[i] == currSession){
|
|
||||||
currSelectedIdx = i;
|
|
||||||
}
|
|
||||||
elem.append("<option value='" + keys[i] + "'>" + keys[i] + "</option>");
|
|
||||||
}
|
|
||||||
|
|
||||||
$("#sessionSelect option[value='" + keys[currSelectedIdx] +"']").attr("selected", "selected");
|
|
||||||
$("#sessionSelectDiv").show();
|
|
||||||
}
|
}
|
||||||
// console.log("Got sessions: " + keys + ", current: " + currSession);
|
|
||||||
});
|
|
||||||
|
|
||||||
|
//Second section: Optimization settings
|
||||||
|
if(lastSettingsUpdateTime != settingsTime){
|
||||||
|
//Get JSON for components
|
||||||
|
var settingsUrl = multiSession ? "/arbiter/" + currSession + "/config" : "/arbiter/config";
|
||||||
|
$.get(settingsUrl,function(data){
|
||||||
|
var str = JSON.stringify(data);
|
||||||
|
|
||||||
//Check last update times for each part of document, and update as necessary
|
var configDiv = $('#settingsdiv');
|
||||||
//First section: summary status
|
configDiv.html('');
|
||||||
if(lastStatusUpdateTime != statusTime){
|
|
||||||
//Get JSON: address set by SummaryStatusResource
|
|
||||||
$.get("/arbiter/summary",function(data){
|
|
||||||
var summaryStatusDiv = $('#statusdiv');
|
|
||||||
summaryStatusDiv.html('');
|
|
||||||
|
|
||||||
var str = JSON.stringify(data);
|
var component = Component.getComponent(str);
|
||||||
var component = Component.getComponent(str);
|
component.render(configDiv);
|
||||||
component.render(summaryStatusDiv);
|
});
|
||||||
});
|
|
||||||
|
|
||||||
lastStatusUpdateTime = statusTime;
|
lastSettingsUpdateTime = settingsTime;
|
||||||
}
|
}
|
||||||
|
|
||||||
//Second section: Optimization settings
|
//Third section: Summary results table (summary info for each candidate)
|
||||||
if(lastSettingsUpdateTime != settingsTime){
|
if(lastResultsUpdateTime != resultsTime){
|
||||||
//Get JSON for components
|
//Get JSON for results table
|
||||||
$.get("/arbiter/config",function(data){
|
var resultsUrl = multiSession ? "/arbiter/" + currSession + "/results" : "/arbiter/results";
|
||||||
var str = JSON.stringify(data);
|
$.get(resultsUrl,function(data){
|
||||||
|
//Expect an array of CandidateInfo type objects here
|
||||||
|
resultsTableContent = data;
|
||||||
|
drawResultTable();
|
||||||
|
});
|
||||||
|
|
||||||
var configDiv = $('#settingsdiv');
|
lastResultsUpdateTime = resultsTime;
|
||||||
configDiv.html('');
|
}
|
||||||
|
|
||||||
var component = Component.getComponent(str);
|
//Finally: Currently selected result
|
||||||
component.render(configDiv);
|
if(selectedCandidateIdx != null){
|
||||||
});
|
//Get JSON for components
|
||||||
|
var candidateInfoUrl = multiSession
|
||||||
|
? "/arbiter/" + currSession + "/candidateInfo/" + selectedCandidateIdx
|
||||||
|
: "/arbiter/candidateInfo/" + selectedCandidateIdx;
|
||||||
|
$.get(candidateInfoUrl,function(data){
|
||||||
|
var str = JSON.stringify(data);
|
||||||
|
|
||||||
lastSettingsUpdateTime = settingsTime;
|
var resultsViewDiv = $('#resultsviewdiv');
|
||||||
}
|
resultsViewDiv.html('');
|
||||||
|
|
||||||
//Third section: Summary results table (summary info for each candidate)
|
var component = Component.getComponent(str);
|
||||||
if(lastResultsUpdateTime != resultsTime){
|
component.render(resultsViewDiv);
|
||||||
|
});
|
||||||
//Get JSON; address set by SummaryResultsResource
|
}
|
||||||
$.get("/arbiter/results",function(data){
|
})
|
||||||
//Expect an array of CandidateInfo type objects here
|
|
||||||
resultsTableContent = data;
|
|
||||||
drawResultTable();
|
|
||||||
});
|
|
||||||
|
|
||||||
lastResultsUpdateTime = resultsTime;
|
|
||||||
}
|
|
||||||
|
|
||||||
//Finally: Currently selected result
|
|
||||||
if(selectedCandidateIdx != null){
|
|
||||||
//Get JSON for components
|
|
||||||
$.get("/arbiter/candidateInfo/"+selectedCandidateIdx,function(data){
|
|
||||||
var str = JSON.stringify(data);
|
|
||||||
|
|
||||||
var resultsViewDiv = $('#resultsviewdiv');
|
|
||||||
resultsViewDiv.html('');
|
|
||||||
|
|
||||||
var component = Component.getComponent(str);
|
|
||||||
component.render(resultsViewDiv);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,6 +16,9 @@
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize;
|
package org.deeplearning4j.arbiter.optimize;
|
||||||
|
|
||||||
|
import io.netty.handler.codec.http.HttpResponseStatus;
|
||||||
|
import lombok.NonNull;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.core.storage.StatsStorage;
|
import org.deeplearning4j.core.storage.StatsStorage;
|
||||||
import org.deeplearning4j.arbiter.ComputationGraphSpace;
|
import org.deeplearning4j.arbiter.ComputationGraphSpace;
|
||||||
|
@ -54,6 +57,7 @@ import org.deeplearning4j.ui.api.UIServer;
|
||||||
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
|
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
|
||||||
import org.junit.Ignore;
|
import org.junit.Ignore;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
import org.nd4j.common.function.Function;
|
||||||
import org.nd4j.evaluation.classification.Evaluation;
|
import org.nd4j.evaluation.classification.Evaluation;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
@ -62,52 +66,43 @@ import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.util.Collections;
|
import java.io.IOException;
|
||||||
import java.util.HashMap;
|
import java.io.UnsupportedEncodingException;
|
||||||
import java.util.Map;
|
import java.net.HttpURLConnection;
|
||||||
import java.util.Properties;
|
import java.net.URL;
|
||||||
|
import java.net.URLEncoder;
|
||||||
|
import java.util.*;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Created by Alex on 19/07/2017.
|
* Created by Alex on 19/07/2017.
|
||||||
*/
|
*/
|
||||||
|
@Slf4j
|
||||||
public class TestBasic extends BaseDL4JTest {
|
public class TestBasic extends BaseDL4JTest {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long getTimeoutMilliseconds() {
|
||||||
|
return 3600_000L;
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@Ignore
|
@Ignore
|
||||||
public void testBasicUiOnly() throws Exception {
|
public void testBasicUiOnly() throws Exception {
|
||||||
|
|
||||||
UIServer.getInstance();
|
UIServer.getInstance();
|
||||||
|
|
||||||
Thread.sleep(1000000);
|
Thread.sleep(1000_000);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@Ignore
|
@Ignore
|
||||||
public void testBasicMnist() throws Exception {
|
public void testBasicMnist() throws Exception {
|
||||||
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
|
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
|
||||||
|
|
||||||
MultiLayerSpace mls = new MultiLayerSpace.Builder()
|
MultiLayerSpace mls = getMultiLayerSpaceMnist();
|
||||||
.updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.2)))
|
|
||||||
.l2(new ContinuousParameterSpace(0.0001, 0.05))
|
|
||||||
.addLayer(
|
|
||||||
new ConvolutionLayerSpace.Builder().nIn(1)
|
|
||||||
.nOut(new IntegerParameterSpace(5, 30))
|
|
||||||
.kernelSize(new DiscreteParameterSpace<>(new int[]{3, 3},
|
|
||||||
new int[]{4, 4}, new int[]{5, 5}))
|
|
||||||
.stride(new DiscreteParameterSpace<>(new int[]{1, 1},
|
|
||||||
new int[]{2, 2}))
|
|
||||||
.activation(new DiscreteParameterSpace<>(Activation.RELU,
|
|
||||||
Activation.SOFTPLUS, Activation.LEAKYRELU))
|
|
||||||
.build())
|
|
||||||
.addLayer(new DenseLayerSpace.Builder().nOut(new IntegerParameterSpace(32, 128))
|
|
||||||
.activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH))
|
|
||||||
.build(), new IntegerParameterSpace(0, 1), true) //0 to 1 layers
|
|
||||||
.addLayer(new OutputLayerSpace.Builder().nOut(10).activation(Activation.SOFTMAX)
|
|
||||||
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
|
||||||
.setInputType(InputType.convolutionalFlat(28, 28, 1))
|
|
||||||
.build();
|
|
||||||
Map<String, Object> commands = new HashMap<>();
|
Map<String, Object> commands = new HashMap<>();
|
||||||
// commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, TestDataFactoryProviderMnist.class.getCanonicalName());
|
// commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, TestDataFactoryProviderMnist.class.getCanonicalName());
|
||||||
|
|
||||||
|
@ -144,7 +139,30 @@ public class TestBasic extends BaseDL4JTest {
|
||||||
UIServer.getInstance().attach(ss);
|
UIServer.getInstance().attach(ss);
|
||||||
|
|
||||||
runner.execute();
|
runner.execute();
|
||||||
Thread.sleep(100000);
|
Thread.sleep(1000_000);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static MultiLayerSpace getMultiLayerSpaceMnist() {
|
||||||
|
return new MultiLayerSpace.Builder()
|
||||||
|
.updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.2)))
|
||||||
|
.l2(new ContinuousParameterSpace(0.0001, 0.05))
|
||||||
|
.addLayer(
|
||||||
|
new ConvolutionLayerSpace.Builder().nIn(1)
|
||||||
|
.nOut(new IntegerParameterSpace(5, 30))
|
||||||
|
.kernelSize(new DiscreteParameterSpace<>(new int[]{3, 3},
|
||||||
|
new int[]{4, 4}, new int[]{5, 5}))
|
||||||
|
.stride(new DiscreteParameterSpace<>(new int[]{1, 1},
|
||||||
|
new int[]{2, 2}))
|
||||||
|
.activation(new DiscreteParameterSpace<>(Activation.RELU,
|
||||||
|
Activation.SOFTPLUS, Activation.LEAKYRELU))
|
||||||
|
.build())
|
||||||
|
.addLayer(new DenseLayerSpace.Builder().nOut(new IntegerParameterSpace(32, 128))
|
||||||
|
.activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH))
|
||||||
|
.build(), new IntegerParameterSpace(0, 1), true) //0 to 1 layers
|
||||||
|
.addLayer(new OutputLayerSpace.Builder().nOut(10).activation(Activation.SOFTMAX)
|
||||||
|
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
||||||
|
.setInputType(InputType.convolutionalFlat(28, 28, 1))
|
||||||
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -233,7 +251,7 @@ public class TestBasic extends BaseDL4JTest {
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
//Define configuration:
|
//Define configuration:
|
||||||
CandidateGenerator candidateGenerator = new RandomSearchGenerator(cgs, Collections.EMPTY_MAP);
|
CandidateGenerator candidateGenerator = new RandomSearchGenerator(cgs);
|
||||||
DataProvider dataProvider = new MnistDataSetProvider();
|
DataProvider dataProvider = new MnistDataSetProvider();
|
||||||
|
|
||||||
|
|
||||||
|
@ -331,7 +349,7 @@ public class TestBasic extends BaseDL4JTest {
|
||||||
UIServer.getInstance().attach(ss);
|
UIServer.getInstance().attach(ss);
|
||||||
|
|
||||||
runner.execute();
|
runner.execute();
|
||||||
Thread.sleep(100000);
|
Thread.sleep(1000_000);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -396,7 +414,7 @@ public class TestBasic extends BaseDL4JTest {
|
||||||
UIServer.getInstance().attach(ss);
|
UIServer.getInstance().attach(ss);
|
||||||
|
|
||||||
runner.execute();
|
runner.execute();
|
||||||
Thread.sleep(100000);
|
Thread.sleep(1000_000);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -433,7 +451,7 @@ public class TestBasic extends BaseDL4JTest {
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
//Define configuration:
|
//Define configuration:
|
||||||
CandidateGenerator candidateGenerator = new RandomSearchGenerator(cgs, Collections.EMPTY_MAP);
|
CandidateGenerator candidateGenerator = new RandomSearchGenerator(cgs);
|
||||||
DataProvider dataProvider = new MnistDataSetProvider();
|
DataProvider dataProvider = new MnistDataSetProvider();
|
||||||
|
|
||||||
|
|
||||||
|
@ -465,13 +483,17 @@ public class TestBasic extends BaseDL4JTest {
|
||||||
UIServer.getInstance().attach(ss);
|
UIServer.getInstance().attach(ss);
|
||||||
|
|
||||||
runner.execute();
|
runner.execute();
|
||||||
Thread.sleep(100000);
|
Thread.sleep(1000_000);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Visualize multiple optimization sessions run one after another on single-session mode UI
|
||||||
|
* @throws InterruptedException if current thread has been interrupted
|
||||||
|
*/
|
||||||
@Test
|
@Test
|
||||||
@Ignore
|
@Ignore
|
||||||
public void testBasicMnistMultipleSessions() throws Exception {
|
public void testBasicMnistMultipleSessions() throws InterruptedException {
|
||||||
|
|
||||||
MultiLayerSpace mls = new MultiLayerSpace.Builder()
|
MultiLayerSpace mls = new MultiLayerSpace.Builder()
|
||||||
.updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.2)))
|
.updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.2)))
|
||||||
|
@ -499,8 +521,10 @@ public class TestBasic extends BaseDL4JTest {
|
||||||
|
|
||||||
//Define configuration:
|
//Define configuration:
|
||||||
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls, commands);
|
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls, commands);
|
||||||
DataProvider dataProvider = new MnistDataSetProvider();
|
|
||||||
|
|
||||||
|
Class<? extends DataSource> ds = MnistDataSource.class;
|
||||||
|
Properties dsp = new Properties();
|
||||||
|
dsp.setProperty("minibatch", "8");
|
||||||
|
|
||||||
String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterUiTestBasicMnist\\").getAbsolutePath();
|
String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterUiTestBasicMnist\\").getAbsolutePath();
|
||||||
|
|
||||||
|
@ -513,7 +537,7 @@ public class TestBasic extends BaseDL4JTest {
|
||||||
|
|
||||||
OptimizationConfiguration configuration =
|
OptimizationConfiguration configuration =
|
||||||
new OptimizationConfiguration.Builder()
|
new OptimizationConfiguration.Builder()
|
||||||
.candidateGenerator(candidateGenerator).dataProvider(dataProvider)
|
.candidateGenerator(candidateGenerator).dataSource(ds, dsp)
|
||||||
.modelSaver(new FileModelSaver(modelSavePath))
|
.modelSaver(new FileModelSaver(modelSavePath))
|
||||||
.scoreFunction(new TestSetLossScoreFunction(true))
|
.scoreFunction(new TestSetLossScoreFunction(true))
|
||||||
.terminationConditions(new MaxTimeCondition(1, TimeUnit.MINUTES),
|
.terminationConditions(new MaxTimeCondition(1, TimeUnit.MINUTES),
|
||||||
|
@ -535,7 +559,7 @@ public class TestBasic extends BaseDL4JTest {
|
||||||
|
|
||||||
candidateGenerator = new RandomSearchGenerator(mls, commands);
|
candidateGenerator = new RandomSearchGenerator(mls, commands);
|
||||||
configuration = new OptimizationConfiguration.Builder()
|
configuration = new OptimizationConfiguration.Builder()
|
||||||
.candidateGenerator(candidateGenerator).dataProvider(dataProvider)
|
.candidateGenerator(candidateGenerator).dataSource(ds, dsp)
|
||||||
.modelSaver(new FileModelSaver(modelSavePath))
|
.modelSaver(new FileModelSaver(modelSavePath))
|
||||||
.scoreFunction(new TestSetLossScoreFunction(true))
|
.scoreFunction(new TestSetLossScoreFunction(true))
|
||||||
.terminationConditions(new MaxTimeCondition(1, TimeUnit.MINUTES),
|
.terminationConditions(new MaxTimeCondition(1, TimeUnit.MINUTES),
|
||||||
|
@ -550,7 +574,148 @@ public class TestBasic extends BaseDL4JTest {
|
||||||
|
|
||||||
runner.execute();
|
runner.execute();
|
||||||
|
|
||||||
Thread.sleep(100000);
|
Thread.sleep(1000_000);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Auto-attach multiple optimization sessions to multi-session mode UI
|
||||||
|
* @throws IOException if could not connect to the server
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void testUiMultiSessionAutoAttach() throws IOException {
|
||||||
|
|
||||||
|
//Define configuration:
|
||||||
|
MultiLayerSpace mls = getMultiLayerSpaceMnist();
|
||||||
|
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
|
||||||
|
|
||||||
|
Class<? extends DataSource> ds = MnistDataSource.class;
|
||||||
|
Properties dsp = new Properties();
|
||||||
|
dsp.setProperty("minibatch", "8");
|
||||||
|
|
||||||
|
String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterUiTestMultiSessionAutoAttach\\")
|
||||||
|
.getAbsolutePath();
|
||||||
|
|
||||||
|
File f = new File(modelSavePath);
|
||||||
|
if (f.exists())
|
||||||
|
f.delete();
|
||||||
|
f.mkdir();
|
||||||
|
if (!f.exists())
|
||||||
|
throw new RuntimeException();
|
||||||
|
|
||||||
|
OptimizationConfiguration configuration =
|
||||||
|
new OptimizationConfiguration.Builder()
|
||||||
|
.candidateGenerator(candidateGenerator).dataSource(ds, dsp)
|
||||||
|
.modelSaver(new FileModelSaver(modelSavePath))
|
||||||
|
.scoreFunction(new TestSetLossScoreFunction(true))
|
||||||
|
.terminationConditions(new MaxTimeCondition(10, TimeUnit.SECONDS),
|
||||||
|
new MaxCandidatesCondition(1))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
IOptimizationRunner runner =
|
||||||
|
new LocalOptimizationRunner(configuration, new MultiLayerNetworkTaskCreator());
|
||||||
|
|
||||||
|
// add 3 different sessions to the same execution
|
||||||
|
HashMap<String, StatsStorage> statsStorageForSession = new HashMap<>();
|
||||||
|
for (int i = 0; i < 3; i++) {
|
||||||
|
StatsStorage ss = new InMemoryStatsStorage();
|
||||||
|
@NonNull String sessionId = "sid" + i;
|
||||||
|
statsStorageForSession.put(sessionId, ss);
|
||||||
|
StatusListener sl = new ArbiterStatusListener(sessionId, ss);
|
||||||
|
runner.addListeners(sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
Function<String, StatsStorage> statsStorageProvider = statsStorageForSession::get;
|
||||||
|
UIServer uIServer = UIServer.getInstance(true, statsStorageProvider);
|
||||||
|
String serverAddress = uIServer.getAddress();
|
||||||
|
|
||||||
|
runner.execute();
|
||||||
|
|
||||||
|
for (String sessionId : statsStorageForSession.keySet()) {
|
||||||
|
/*
|
||||||
|
* Visiting /arbiter/:sessionId to auto-attach StatsStorage
|
||||||
|
*/
|
||||||
|
String sessionUrl = sessionUrl(uIServer.getAddress(), sessionId);
|
||||||
|
HttpURLConnection conn = (HttpURLConnection) new URL(sessionUrl).openConnection();
|
||||||
|
conn.connect();
|
||||||
|
|
||||||
|
log.info("Checking auto-attaching Arbiter session at {}", sessionUrl(serverAddress, sessionId));
|
||||||
|
assertEquals(HttpResponseStatus.OK.code(), conn.getResponseCode());
|
||||||
|
assertTrue(uIServer.isAttached(statsStorageForSession.get(sessionId)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Attach multiple optimization sessions to multi-session mode UI by manually visiting session URL
|
||||||
|
* @throws Exception if an error occurred
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
@Ignore
|
||||||
|
public void testUiMultiSessionManualAttach() throws Exception {
|
||||||
|
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
|
||||||
|
|
||||||
|
//Define configuration:
|
||||||
|
MultiLayerSpace mls = getMultiLayerSpaceMnist();
|
||||||
|
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
|
||||||
|
|
||||||
|
Class<? extends DataSource> ds = MnistDataSource.class;
|
||||||
|
Properties dsp = new Properties();
|
||||||
|
dsp.setProperty("minibatch", "8");
|
||||||
|
|
||||||
|
String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterUiTestBasicMnist\\")
|
||||||
|
.getAbsolutePath();
|
||||||
|
|
||||||
|
File f = new File(modelSavePath);
|
||||||
|
if (f.exists())
|
||||||
|
f.delete();
|
||||||
|
f.mkdir();
|
||||||
|
if (!f.exists())
|
||||||
|
throw new RuntimeException();
|
||||||
|
|
||||||
|
OptimizationConfiguration configuration =
|
||||||
|
new OptimizationConfiguration.Builder()
|
||||||
|
.candidateGenerator(candidateGenerator).dataSource(ds, dsp)
|
||||||
|
.modelSaver(new FileModelSaver(modelSavePath))
|
||||||
|
.scoreFunction(new TestSetLossScoreFunction(true))
|
||||||
|
.terminationConditions(new MaxTimeCondition(10, TimeUnit.MINUTES),
|
||||||
|
new MaxCandidatesCondition(10))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
|
||||||
|
// parallel execution of multiple optimization sessions
|
||||||
|
HashMap<String, StatsStorage> statsStorageForSession = new HashMap<>();
|
||||||
|
for (int i = 0; i < 3; i++) {
|
||||||
|
String sessionId = "sid" + i;
|
||||||
|
IOptimizationRunner runner =
|
||||||
|
new LocalOptimizationRunner(configuration, new MultiLayerNetworkTaskCreator());
|
||||||
|
StatsStorage ss = new InMemoryStatsStorage();
|
||||||
|
statsStorageForSession.put(sessionId, ss);
|
||||||
|
StatusListener sl = new ArbiterStatusListener(sessionId, ss);
|
||||||
|
runner.addListeners(sl);
|
||||||
|
// Asynchronous execution
|
||||||
|
new Thread(runner::execute).start();
|
||||||
|
}
|
||||||
|
|
||||||
|
Function<String, StatsStorage> statsStorageProvider = statsStorageForSession::get;
|
||||||
|
UIServer uIServer = UIServer.getInstance(true, statsStorageProvider);
|
||||||
|
String serverAddress = uIServer.getAddress();
|
||||||
|
|
||||||
|
for (String sessionId : statsStorageForSession.keySet()) {
|
||||||
|
log.info("Arbiter session can be attached at {}", sessionUrl(serverAddress, sessionId));
|
||||||
|
}
|
||||||
|
|
||||||
|
Thread.sleep(1000_000);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get URL for arbiter session on given server address
|
||||||
|
* @param serverAddress server address, e.g.: http://localhost:9000
|
||||||
|
* @param sessionId session ID (will be URL-encoded)
|
||||||
|
* @return URL
|
||||||
|
* @throws UnsupportedEncodingException if the character encoding is not supported
|
||||||
|
*/
|
||||||
|
private static String sessionUrl(String serverAddress, String sessionId) throws UnsupportedEncodingException {
|
||||||
|
return String.format("%s/arbiter/%s", serverAddress, URLEncoder.encode(sessionId, "UTF-8"));
|
||||||
}
|
}
|
||||||
|
|
||||||
private static class MnistDataSetProvider implements DataProvider {
|
private static class MnistDataSetProvider implements DataProvider {
|
||||||
|
|
|
@ -77,8 +77,6 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
||||||
private static Integer instancePort;
|
private static Integer instancePort;
|
||||||
private static Thread autoStopThread;
|
private static Thread autoStopThread;
|
||||||
|
|
||||||
private TrainModule trainModule;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get (and, initialize if necessary) the UI server. This synchronous function will wait until the server started.
|
* Get (and, initialize if necessary) the UI server. This synchronous function will wait until the server started.
|
||||||
* @param port TCP socket port for {@link HttpServer} to listen
|
* @param port TCP socket port for {@link HttpServer} to listen
|
||||||
|
@ -194,8 +192,9 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
||||||
VertxUIServer.autoStopThread = new Thread(() -> {
|
VertxUIServer.autoStopThread = new Thread(() -> {
|
||||||
try {
|
try {
|
||||||
currentThread.join();
|
currentThread.join();
|
||||||
log.info("Deeplearning4j UI server is auto-stopping.");
|
|
||||||
if (VertxUIServer.instance != null && !VertxUIServer.instance.isStopped()) {
|
if (VertxUIServer.instance != null && !VertxUIServer.instance.isStopped()) {
|
||||||
|
log.info("Deeplearning4j UI server is auto-stopping after thread (name: {}) died.",
|
||||||
|
currentThread.getName());
|
||||||
instance.stop();
|
instance.stop();
|
||||||
}
|
}
|
||||||
} catch (InterruptedException e) {
|
} catch (InterruptedException e) {
|
||||||
|
@ -207,7 +206,11 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
||||||
|
|
||||||
private List<UIModule> uiModules = new CopyOnWriteArrayList<>();
|
private List<UIModule> uiModules = new CopyOnWriteArrayList<>();
|
||||||
private RemoteReceiverModule remoteReceiverModule;
|
private RemoteReceiverModule remoteReceiverModule;
|
||||||
private StatsStorageLoader statsStorageLoader;
|
/**
|
||||||
|
* Loader that attaches {@code StatsStorage} provided by {@code #statsStorageProvider} for the given session ID
|
||||||
|
*/
|
||||||
|
@Getter
|
||||||
|
private Function<String, Boolean> statsStorageLoader;
|
||||||
|
|
||||||
//typeIDModuleMap: Records which modules are registered for which type IDs
|
//typeIDModuleMap: Records which modules are registered for which type IDs
|
||||||
private Map<String, List<UIModule>> typeIDModuleMap = new ConcurrentHashMap<>();
|
private Map<String, List<UIModule>> typeIDModuleMap = new ConcurrentHashMap<>();
|
||||||
|
@ -247,10 +250,23 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
||||||
*/
|
*/
|
||||||
public void autoAttachStatsStorageBySessionId(Function<String, StatsStorage> statsStorageProvider) {
|
public void autoAttachStatsStorageBySessionId(Function<String, StatsStorage> statsStorageProvider) {
|
||||||
if (statsStorageProvider != null) {
|
if (statsStorageProvider != null) {
|
||||||
this.statsStorageLoader = new StatsStorageLoader(statsStorageProvider);
|
this.statsStorageLoader = (sessionId) -> {
|
||||||
if (trainModule != null) {
|
log.info("Loading StatsStorage via StatsStorageProvider for session ID (" + sessionId + ").");
|
||||||
this.trainModule.setSessionLoader(this.statsStorageLoader);
|
StatsStorage statsStorage = statsStorageProvider.apply(sessionId);
|
||||||
}
|
if (statsStorage != null) {
|
||||||
|
if (statsStorage.sessionExists(sessionId)) {
|
||||||
|
attach(statsStorage);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
log.info("Failed to load StatsStorage via StatsStorageProvider for session ID. " +
|
||||||
|
"Session ID (" + sessionId + ") does not exist in StatsStorage.");
|
||||||
|
return false;
|
||||||
|
} else {
|
||||||
|
log.info("Failed to load StatsStorage via StatsStorageProvider for session ID (" + sessionId + "). " +
|
||||||
|
"StatsStorageProvider returned null.");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -302,8 +318,7 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
||||||
}
|
}
|
||||||
|
|
||||||
uiModules.add(new DefaultModule(isMultiSession())); //For: navigation page "/"
|
uiModules.add(new DefaultModule(isMultiSession())); //For: navigation page "/"
|
||||||
trainModule = new TrainModule(isMultiSession(), statsStorageLoader, this::getAddress);
|
uiModules.add(new TrainModule());
|
||||||
uiModules.add(trainModule);
|
|
||||||
uiModules.add(new ConvolutionalListenerModule());
|
uiModules.add(new ConvolutionalListenerModule());
|
||||||
uiModules.add(new TsneModule());
|
uiModules.add(new TsneModule());
|
||||||
uiModules.add(new SameDiffModule());
|
uiModules.add(new SameDiffModule());
|
||||||
|
@ -596,37 +611,6 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Loader that attaches {@code StatsStorage} provided by {@code #statsStorageProvider} for the given session ID
|
|
||||||
*/
|
|
||||||
private class StatsStorageLoader implements Function<String, Boolean> {
|
|
||||||
|
|
||||||
Function<String, StatsStorage> statsStorageProvider;
|
|
||||||
|
|
||||||
StatsStorageLoader(Function<String, StatsStorage> statsStorageProvider) {
|
|
||||||
this.statsStorageProvider = statsStorageProvider;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Boolean apply(String sessionId) {
|
|
||||||
log.info("Loading StatsStorage via StatsStorageProvider for session ID (" + sessionId + ").");
|
|
||||||
StatsStorage statsStorage = statsStorageProvider.apply(sessionId);
|
|
||||||
if (statsStorage != null) {
|
|
||||||
if (statsStorage.sessionExists(sessionId)) {
|
|
||||||
attach(statsStorage);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
log.info("Failed to load StatsStorage via StatsStorageProvider for session ID. " +
|
|
||||||
"Session ID (" + sessionId + ") does not exist in StatsStorage.");
|
|
||||||
return false;
|
|
||||||
} else {
|
|
||||||
log.info("Failed to load StatsStorage via StatsStorageProvider for session ID (" + sessionId + "). " +
|
|
||||||
"StatsStorageProvider returned null.");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//==================================================================================================================
|
//==================================================================================================================
|
||||||
// CLI Launcher
|
// CLI Launcher
|
||||||
|
|
||||||
|
|
|
@ -157,8 +157,9 @@ public interface UIServer {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Stop/shut down the UI server. This synchronous function should wait until the server is stopped.
|
* Stop/shut down the UI server. This synchronous function should wait until the server is stopped.
|
||||||
|
* @throws InterruptedException if the current thread is interrupted while waiting
|
||||||
*/
|
*/
|
||||||
void stop() throws Exception;
|
void stop() throws InterruptedException;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Stop/shut down the UI server.
|
* Stop/shut down the UI server.
|
||||||
|
|
|
@ -26,8 +26,6 @@ import io.vertx.ext.web.RoutingContext;
|
||||||
import it.unimi.dsi.fastutil.longs.LongArrayList;
|
import it.unimi.dsi.fastutil.longs.LongArrayList;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.Getter;
|
|
||||||
import lombok.Setter;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.io.FileUtils;
|
import org.apache.commons.io.FileUtils;
|
||||||
import org.apache.commons.io.FilenameUtils;
|
import org.apache.commons.io.FilenameUtils;
|
||||||
|
@ -43,6 +41,7 @@ import org.deeplearning4j.nn.conf.graph.GraphVertex;
|
||||||
import org.deeplearning4j.nn.conf.graph.LayerVertex;
|
import org.deeplearning4j.nn.conf.graph.LayerVertex;
|
||||||
import org.deeplearning4j.nn.conf.layers.*;
|
import org.deeplearning4j.nn.conf.layers.*;
|
||||||
import org.deeplearning4j.nn.conf.serde.JsonMappers;
|
import org.deeplearning4j.nn.conf.serde.JsonMappers;
|
||||||
|
import org.deeplearning4j.ui.VertxUIServer;
|
||||||
import org.deeplearning4j.ui.api.HttpMethod;
|
import org.deeplearning4j.ui.api.HttpMethod;
|
||||||
import org.deeplearning4j.ui.api.I18N;
|
import org.deeplearning4j.ui.api.I18N;
|
||||||
import org.deeplearning4j.ui.api.Route;
|
import org.deeplearning4j.ui.api.Route;
|
||||||
|
@ -56,7 +55,6 @@ import org.deeplearning4j.ui.model.stats.api.StatsInitializationReport;
|
||||||
import org.deeplearning4j.ui.model.stats.api.StatsReport;
|
import org.deeplearning4j.ui.model.stats.api.StatsReport;
|
||||||
import org.deeplearning4j.ui.model.stats.api.StatsType;
|
import org.deeplearning4j.ui.model.stats.api.StatsType;
|
||||||
import org.nd4j.common.function.Function;
|
import org.nd4j.common.function.Function;
|
||||||
import org.nd4j.common.function.Supplier;
|
|
||||||
import org.nd4j.linalg.learning.config.IUpdater;
|
import org.nd4j.linalg.learning.config.IUpdater;
|
||||||
import org.nd4j.common.primitives.Pair;
|
import org.nd4j.common.primitives.Pair;
|
||||||
import org.nd4j.common.primitives.Triple;
|
import org.nd4j.common.primitives.Triple;
|
||||||
|
@ -86,8 +84,6 @@ public class TrainModule implements UIModule {
|
||||||
private static final DecimalFormat df2 = new DecimalFormat("#.00");
|
private static final DecimalFormat df2 = new DecimalFormat("#.00");
|
||||||
private static DateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
|
private static DateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
|
||||||
|
|
||||||
private final Supplier<String> addressSupplier;
|
|
||||||
|
|
||||||
private enum ModelType {
|
private enum ModelType {
|
||||||
MLN, CG, Layer
|
MLN, CG, Layer
|
||||||
}
|
}
|
||||||
|
@ -99,29 +95,14 @@ public class TrainModule implements UIModule {
|
||||||
private Map<String, AtomicInteger> workerIdxCount = new ConcurrentHashMap<>(); //Key: session ID
|
private Map<String, AtomicInteger> workerIdxCount = new ConcurrentHashMap<>(); //Key: session ID
|
||||||
private Map<String, Map<Integer, String>> workerIdxToName = new ConcurrentHashMap<>(); //Key: session ID
|
private Map<String, Map<Integer, String>> workerIdxToName = new ConcurrentHashMap<>(); //Key: session ID
|
||||||
private Map<String, Long> lastUpdateForSession = new ConcurrentHashMap<>();
|
private Map<String, Long> lastUpdateForSession = new ConcurrentHashMap<>();
|
||||||
private final boolean multiSession;
|
|
||||||
@Getter @Setter
|
|
||||||
private Function<String, Boolean> sessionLoader;
|
|
||||||
|
|
||||||
|
|
||||||
private final Configuration configuration;
|
private final Configuration configuration;
|
||||||
|
|
||||||
public TrainModule() {
|
|
||||||
this(false, null, null);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* TrainModule
|
* TrainModule
|
||||||
*
|
|
||||||
* @param multiSession multi-session mode
|
|
||||||
* @param sessionLoader StatsStorage loader to call if an unknown session ID is passed as URL path parameter
|
|
||||||
* in multi-session mode
|
|
||||||
* @param addressSupplier supplier for server address (server address in PlayUIServer gets initialized after modules)
|
|
||||||
*/
|
*/
|
||||||
public TrainModule(boolean multiSession, Function<String, Boolean> sessionLoader, Supplier<String> addressSupplier) {
|
public TrainModule() {
|
||||||
this.multiSession = multiSession;
|
|
||||||
this.sessionLoader = sessionLoader;
|
|
||||||
this.addressSupplier = addressSupplier;
|
|
||||||
String maxChartPointsProp = System.getProperty(DL4JSystemProperties.CHART_MAX_POINTS_PROPERTY);
|
String maxChartPointsProp = System.getProperty(DL4JSystemProperties.CHART_MAX_POINTS_PROPERTY);
|
||||||
int value = DEFAULT_MAX_CHART_POINTS;
|
int value = DEFAULT_MAX_CHART_POINTS;
|
||||||
if (maxChartPointsProp != null) {
|
if (maxChartPointsProp != null) {
|
||||||
|
@ -159,8 +140,9 @@ public class TrainModule implements UIModule {
|
||||||
@Override
|
@Override
|
||||||
public List<Route> getRoutes() {
|
public List<Route> getRoutes() {
|
||||||
List<Route> r = new ArrayList<>();
|
List<Route> r = new ArrayList<>();
|
||||||
r.add(new Route("/train/multisession", HttpMethod.GET, (path, rc) -> rc.response().end(multiSession ? "true" : "false")));
|
r.add(new Route("/train/multisession", HttpMethod.GET,
|
||||||
if (multiSession) {
|
(path, rc) -> rc.response().end(VertxUIServer.getInstance().isMultiSession() ? "true" : "false")));
|
||||||
|
if (VertxUIServer.getInstance().isMultiSession()) {
|
||||||
r.add(new Route("/train", HttpMethod.GET, (path, rc) -> this.listSessions(rc)));
|
r.add(new Route("/train", HttpMethod.GET, (path, rc) -> this.listSessions(rc)));
|
||||||
r.add(new Route("/train/:sessionId", HttpMethod.GET, (path, rc) -> {
|
r.add(new Route("/train/:sessionId", HttpMethod.GET, (path, rc) -> {
|
||||||
rc.response()
|
rc.response()
|
||||||
|
@ -264,7 +246,9 @@ public class TrainModule implements UIModule {
|
||||||
if (!knownSessionIDs.isEmpty()) {
|
if (!knownSessionIDs.isEmpty()) {
|
||||||
sb.append(" <ul>");
|
sb.append(" <ul>");
|
||||||
for (String sessionId : knownSessionIDs.keySet()) {
|
for (String sessionId : knownSessionIDs.keySet()) {
|
||||||
sb.append(" <li><a href=\"train/").append(sessionId).append("\">").append(sessionId).append("</a></li>\n");
|
sb.append(" <li><a href=\"/train/")
|
||||||
|
.append(sessionId).append("\">")
|
||||||
|
.append(sessionId).append("</a></li>\n");
|
||||||
}
|
}
|
||||||
sb.append(" </ul>");
|
sb.append(" </ul>");
|
||||||
} else {
|
} else {
|
||||||
|
@ -284,9 +268,11 @@ public class TrainModule implements UIModule {
|
||||||
*
|
*
|
||||||
* @param sessionId session ID to look fo with provider
|
* @param sessionId session ID to look fo with provider
|
||||||
* @param targetPath one of overview / model / system, or null
|
* @param targetPath one of overview / model / system, or null
|
||||||
|
* @param rc routing context
|
||||||
*/
|
*/
|
||||||
private void sessionNotFound(String sessionId, String targetPath, RoutingContext rc) {
|
private void sessionNotFound(String sessionId, String targetPath, RoutingContext rc) {
|
||||||
if (sessionLoader != null && sessionLoader.apply(sessionId)) {
|
Function<String, Boolean> loader = VertxUIServer.getInstance().getStatsStorageLoader();
|
||||||
|
if (loader != null && loader.apply(sessionId)) {
|
||||||
if (targetPath != null) {
|
if (targetPath != null) {
|
||||||
rc.reroute(targetPath);
|
rc.reroute(targetPath);
|
||||||
} else {
|
} else {
|
||||||
|
@ -306,9 +292,9 @@ public class TrainModule implements UIModule {
|
||||||
&& StatsListener.TYPE_ID.equals(sse.getTypeID())
|
&& StatsListener.TYPE_ID.equals(sse.getTypeID())
|
||||||
&& !knownSessionIDs.containsKey(sse.getSessionID())) {
|
&& !knownSessionIDs.containsKey(sse.getSessionID())) {
|
||||||
knownSessionIDs.put(sse.getSessionID(), sse.getStatsStorage());
|
knownSessionIDs.put(sse.getSessionID(), sse.getStatsStorage());
|
||||||
if (multiSession) {
|
if (VertxUIServer.getInstance().isMultiSession()) {
|
||||||
log.info("Adding training session {}/train/{} of StatsStorage instance {}",
|
log.info("Adding training session {}/train/{} of StatsStorage instance {}",
|
||||||
addressSupplier.get(), sse.getSessionID(), sse.getStatsStorage());
|
VertxUIServer.getInstance().getAddress(), sse.getSessionID(), sse.getStatsStorage());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -332,9 +318,9 @@ public class TrainModule implements UIModule {
|
||||||
if (!StatsListener.TYPE_ID.equals(typeID))
|
if (!StatsListener.TYPE_ID.equals(typeID))
|
||||||
continue;
|
continue;
|
||||||
knownSessionIDs.put(sessionID, statsStorage);
|
knownSessionIDs.put(sessionID, statsStorage);
|
||||||
if (multiSession) {
|
if (VertxUIServer.getInstance().isMultiSession()) {
|
||||||
log.info("Adding training session {}/train/{} of StatsStorage instance {}",
|
log.info("Adding training session {}/train/{} of StatsStorage instance {}",
|
||||||
addressSupplier.get(), sessionID, statsStorage);
|
VertxUIServer.getInstance().getAddress(), sessionID, statsStorage);
|
||||||
}
|
}
|
||||||
|
|
||||||
List<Persistable> latestUpdates = statsStorage.getLatestUpdateAllWorkers(sessionID, typeID);
|
List<Persistable> latestUpdates = statsStorage.getLatestUpdateAllWorkers(sessionID, typeID);
|
||||||
|
@ -364,9 +350,9 @@ public class TrainModule implements UIModule {
|
||||||
}
|
}
|
||||||
for (String s : toRemove) {
|
for (String s : toRemove) {
|
||||||
knownSessionIDs.remove(s);
|
knownSessionIDs.remove(s);
|
||||||
if (multiSession) {
|
if (VertxUIServer.getInstance().isMultiSession()) {
|
||||||
log.info("Removing training session {}/train/{} of StatsStorage instance {}.",
|
log.info("Removing training session {}/train/{} of StatsStorage instance {}.",
|
||||||
addressSupplier.get(), s, statsStorage);
|
VertxUIServer.getInstance().getAddress(), s, statsStorage);
|
||||||
}
|
}
|
||||||
lastUpdateForSession.remove(s);
|
lastUpdateForSession.remove(s);
|
||||||
}
|
}
|
||||||
|
@ -602,13 +588,13 @@ public class TrainModule implements UIModule {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get global {@link I18N} instance if {@link #multiSession} is {@code true}, or instance for session
|
* Get global {@link I18N} instance if {@link VertxUIServer#isMultiSession()} is {@code true}, or instance for session
|
||||||
*
|
*
|
||||||
* @param sessionId session ID
|
* @param sessionId session ID
|
||||||
* @return {@link I18N} instance
|
* @return {@link I18N} instance
|
||||||
*/
|
*/
|
||||||
private I18N getI18N(String sessionId) {
|
private I18N getI18N(String sessionId) {
|
||||||
return multiSession ? I18NProvider.getInstance(sessionId) : I18NProvider.getInstance();
|
return VertxUIServer.getInstance().isMultiSession() ? I18NProvider.getInstance(sessionId) : I18NProvider.getInstance();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -40,7 +40,7 @@ import static org.junit.Assert.*;
|
||||||
public class TestVertxUIManual extends BaseDL4JTest {
|
public class TestVertxUIManual extends BaseDL4JTest {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long getTimeoutMilliseconds() {
|
public long getTimeoutMilliseconds() {
|
||||||
return 3600_000L;
|
return 3600_000L;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -259,7 +259,7 @@ public class TestVertxUIManual extends BaseDL4JTest {
|
||||||
log.info("Auto-detaching StatsStorage (session ID: {}) after {} ms.",
|
log.info("Auto-detaching StatsStorage (session ID: {}) after {} ms.",
|
||||||
sessionId, autoDetachTimeoutMillis);
|
sessionId, autoDetachTimeoutMillis);
|
||||||
uIServer.detach(statsStorage);
|
uIServer.detach(statsStorage);
|
||||||
log.info(" To re-attach StatsStorage of training session, visit {}}/train/{}",
|
log.info(" To re-attach StatsStorage of training session, visit {}/train/{}",
|
||||||
uIServer.getAddress(), sessionId);
|
uIServer.getAddress(), sessionId);
|
||||||
}
|
}
|
||||||
}).start();
|
}).start();
|
||||||
|
|
|
@ -197,9 +197,9 @@ public class TestVertxUIMultiSession extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get URL-encoded URL for training session on given server address
|
* Get URL for training session on given server address
|
||||||
* @param serverAddress server address
|
* @param serverAddress server address
|
||||||
* @param sessionId session ID
|
* @param sessionId session ID (will be URL-encoded)
|
||||||
* @return URL
|
* @return URL
|
||||||
* @throws UnsupportedEncodingException if the used encoding is not supported
|
* @throws UnsupportedEncodingException if the used encoding is not supported
|
||||||
*/
|
*/
|
||||||
|
|
Loading…
Reference in New Issue