[WIP] basic handling of multi-session in ArbiterModule (routes: arbiter, arbiter/multisession)
Signed-off-by: Tamás Fenyvesi <tamas.fenyvesi@doknet.hu>master
parent
f1ebced7a1
commit
d3c759e03f
|
@ -36,6 +36,7 @@ import org.deeplearning4j.arbiter.ui.data.ModelInfoPersistable;
|
||||||
import org.deeplearning4j.arbiter.ui.misc.UIUtils;
|
import org.deeplearning4j.arbiter.ui.misc.UIUtils;
|
||||||
import org.deeplearning4j.arbiter.util.ObjectUtils;
|
import org.deeplearning4j.arbiter.util.ObjectUtils;
|
||||||
import org.deeplearning4j.nn.conf.serde.JsonMappers;
|
import org.deeplearning4j.nn.conf.serde.JsonMappers;
|
||||||
|
import org.deeplearning4j.ui.VertxUIServer;
|
||||||
import org.deeplearning4j.ui.api.Component;
|
import org.deeplearning4j.ui.api.Component;
|
||||||
import org.deeplearning4j.ui.api.*;
|
import org.deeplearning4j.ui.api.*;
|
||||||
import org.deeplearning4j.ui.components.chart.ChartLine;
|
import org.deeplearning4j.ui.components.chart.ChartLine;
|
||||||
|
@ -77,7 +78,6 @@ public class ArbiterModule implements UIModule {
|
||||||
|
|
||||||
private Map<String, Long> lastUpdateForSession = Collections.synchronizedMap(new HashMap<>());
|
private Map<String, Long> lastUpdateForSession = Collections.synchronizedMap(new HashMap<>());
|
||||||
|
|
||||||
|
|
||||||
//Styles for UI:
|
//Styles for UI:
|
||||||
private static final StyleTable STYLE_TABLE = new StyleTable.Builder()
|
private static final StyleTable STYLE_TABLE = new StyleTable.Builder()
|
||||||
.width(100, LengthUnit.Percent)
|
.width(100, LengthUnit.Percent)
|
||||||
|
@ -134,20 +134,69 @@ public class ArbiterModule implements UIModule {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<Route> getRoutes() {
|
public List<Route> getRoutes() {
|
||||||
Route r1 = new Route("/arbiter", HttpMethod.GET, (path, rc) -> rc.response()
|
boolean multiSession = VertxUIServer.getMultiSession().get();
|
||||||
.putHeader("content-type", "text/html; charset=utf-8").sendFile("templates/ArbiterUI.html"));
|
List<Route> r = new ArrayList<>();
|
||||||
Route r3 = new Route("/arbiter/lastUpdate", HttpMethod.GET, (path, rc) -> this.getLastUpdateTime(rc));
|
r.add(new Route("/arbiter/multisession", HttpMethod.GET,
|
||||||
Route r4 = new Route("/arbiter/lastUpdate/:ids", HttpMethod.GET, (path, rc) -> this.getModelLastUpdateTimes(path.get(0), rc));
|
(path, rc) -> rc.response().end(multiSession ? "true" : "false")));
|
||||||
Route r5 = new Route("/arbiter/candidateInfo/:id", HttpMethod.GET, (path, rc) -> this.getCandidateInfo(path.get(0), rc));
|
if (multiSession) {
|
||||||
Route r6 = new Route("/arbiter/config", HttpMethod.GET, (path, rc) -> this.getOptimizationConfig(rc));
|
r.add(new Route("/arbiter", HttpMethod.GET, (path, rc) -> this.listSessions(rc)));
|
||||||
Route r7 = new Route("/arbiter/results", HttpMethod.GET, (path, rc) -> this.getSummaryResults(rc));
|
} else {
|
||||||
Route r8 = new Route("/arbiter/summary", HttpMethod.GET, (path, rc) -> this.getSummaryStatus(rc));
|
r.add(new Route("/arbiter", HttpMethod.GET, (path, rc) -> rc.response()
|
||||||
|
.putHeader("content-type", "text/html; charset=utf-8")
|
||||||
|
.sendFile("templates/ArbiterUI.html")));
|
||||||
|
r.add(new Route("/arbiter/lastUpdate", HttpMethod.GET, (path, rc) -> this.getLastUpdateTime(rc)));
|
||||||
|
r.add(new Route("/arbiter/lastUpdate/:ids", HttpMethod.GET,
|
||||||
|
(path, rc) -> this.getModelLastUpdateTimes(path.get(0), rc)));
|
||||||
|
r.add(new Route("/arbiter/candidateInfo/:id", HttpMethod.GET,
|
||||||
|
(path, rc) -> this.getCandidateInfo(path.get(0), rc)));
|
||||||
|
r.add(new Route("/arbiter/config", HttpMethod.GET, (path, rc) -> this.getOptimizationConfig(rc)));
|
||||||
|
r.add(new Route("/arbiter/results", HttpMethod.GET, (path, rc) -> this.getSummaryResults(rc)));
|
||||||
|
r.add(new Route("/arbiter/summary", HttpMethod.GET, (path, rc) -> this.getSummaryStatus(rc)));
|
||||||
|
|
||||||
Route r9a = new Route("/arbiter/sessions/all", HttpMethod.GET, (path, rc) -> this.listSessions(rc));
|
r.add(new Route("/arbiter/sessions/all", HttpMethod.GET, (path, rc) -> this.sessionInfo(rc)));
|
||||||
Route r9b = new Route("/arbiter/sessions/current", HttpMethod.GET, (path, rc) -> this.currentSession(rc));
|
r.add(new Route("/arbiter/sessions/current", HttpMethod.GET, (path, rc) -> this.currentSession(rc)));
|
||||||
Route r9c = new Route("/arbiter/sessions/set/:to", HttpMethod.GET, (path, rc) -> this.setSession(path.get(0), rc));
|
r.add(new Route("/arbiter/sessions/set/:to", HttpMethod.GET,
|
||||||
|
(path, rc) -> this.setSession(path.get(0), rc)));
|
||||||
|
}
|
||||||
|
|
||||||
return Arrays.asList(r1, r3, r4, r5, r6, r7, r8, r9a, r9b, r9c);
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* List optimization sessions. Returns a HTML list of arbiter sessions
|
||||||
|
*/
|
||||||
|
private synchronized void listSessions(RoutingContext rc) {
|
||||||
|
StringBuilder sb = new StringBuilder("<!DOCTYPE html>\n" +
|
||||||
|
"<html lang=\"en\">\n" +
|
||||||
|
"<head>\n" +
|
||||||
|
" <meta charset=\"utf-8\">\n" +
|
||||||
|
" <title>Optimization sessions - DL4J Arbiter UI</title>\n" +
|
||||||
|
" </head>\n" +
|
||||||
|
"\n" +
|
||||||
|
" <body>\n" +
|
||||||
|
" <h1>DL4J Arbiter UI</h1>\n" +
|
||||||
|
" <p>UI server is in multi-session mode." +
|
||||||
|
" To visualize an optimization session, please select one from the following list.</p>\n" +
|
||||||
|
" <h2>List of attached optimization sessions</h2>\n");
|
||||||
|
if (!knownSessionIDs.isEmpty()) {
|
||||||
|
sb.append(" <ul>");
|
||||||
|
for (String sessionId : knownSessionIDs.keySet()) {
|
||||||
|
sb.append(" <li><a href=\"train/")
|
||||||
|
.append(sessionId).append("\">")
|
||||||
|
.append(sessionId).append("</a></li>\n");
|
||||||
|
}
|
||||||
|
sb.append(" </ul>");
|
||||||
|
} else {
|
||||||
|
sb.append("No optimization session attached.");
|
||||||
|
}
|
||||||
|
|
||||||
|
sb.append(" </body>\n" +
|
||||||
|
"</html>\n");
|
||||||
|
|
||||||
|
rc.response()
|
||||||
|
.putHeader("content-type", "text/html; charset=utf-8")
|
||||||
|
.end(sb.toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -201,7 +250,7 @@ public class ArbiterModule implements UIModule {
|
||||||
.end(asJson(sid));
|
.end(asJson(sid));
|
||||||
}
|
}
|
||||||
|
|
||||||
private void listSessions(RoutingContext rc) {
|
private void sessionInfo(RoutingContext rc) {
|
||||||
rc.response()
|
rc.response()
|
||||||
.putHeader("content-type", "application/json")
|
.putHeader("content-type", "application/json")
|
||||||
.end(asJson(knownSessionIDs.keySet()));
|
.end(asJson(knownSessionIDs.keySet()));
|
||||||
|
@ -309,7 +358,6 @@ public class ArbiterModule implements UIModule {
|
||||||
* Get the info for a specific candidate - last section in the UI
|
* Get the info for a specific candidate - last section in the UI
|
||||||
*
|
*
|
||||||
* @param candidateId ID for the candidate
|
* @param candidateId ID for the candidate
|
||||||
* @return Content/info for the candidate
|
|
||||||
*/
|
*/
|
||||||
private void getCandidateInfo(String candidateId, RoutingContext rc){
|
private void getCandidateInfo(String candidateId, RoutingContext rc){
|
||||||
|
|
||||||
|
@ -320,7 +368,8 @@ public class ArbiterModule implements UIModule {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
GlobalConfigPersistable gcp = (GlobalConfigPersistable)ss.getStaticInfo(currentSessionID, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID);;
|
GlobalConfigPersistable gcp = (GlobalConfigPersistable)ss
|
||||||
|
.getStaticInfo(currentSessionID, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID);
|
||||||
OptimizationConfiguration oc = gcp.getOptimizationConfiguration();
|
OptimizationConfiguration oc = gcp.getOptimizationConfiguration();
|
||||||
|
|
||||||
Persistable p = ss.getLatestUpdate(currentSessionID, ARBITER_UI_TYPE_ID, candidateId);
|
Persistable p = ss.getLatestUpdate(currentSessionID, ARBITER_UI_TYPE_ID, candidateId);
|
||||||
|
|
|
@ -16,6 +16,8 @@
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize;
|
package org.deeplearning4j.arbiter.optimize;
|
||||||
|
|
||||||
|
import lombok.NonNull;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.core.storage.StatsStorage;
|
import org.deeplearning4j.core.storage.StatsStorage;
|
||||||
import org.deeplearning4j.arbiter.ComputationGraphSpace;
|
import org.deeplearning4j.arbiter.ComputationGraphSpace;
|
||||||
|
@ -59,6 +61,7 @@ import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import org.nd4j.linalg.function.Function;
|
||||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
|
@ -71,8 +74,14 @@ import java.util.concurrent.TimeUnit;
|
||||||
/**
|
/**
|
||||||
* Created by Alex on 19/07/2017.
|
* Created by Alex on 19/07/2017.
|
||||||
*/
|
*/
|
||||||
|
@Slf4j
|
||||||
public class TestBasic extends BaseDL4JTest {
|
public class TestBasic extends BaseDL4JTest {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long getTimeoutMilliseconds() {
|
||||||
|
return 3600_000L;
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@Ignore
|
@Ignore
|
||||||
public void testBasicUiOnly() throws Exception {
|
public void testBasicUiOnly() throws Exception {
|
||||||
|
@ -82,6 +91,83 @@ public class TestBasic extends BaseDL4JTest {
|
||||||
Thread.sleep(1000000);
|
Thread.sleep(1000000);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@Ignore
|
||||||
|
public void testBasicUiMultiSession() throws Exception {
|
||||||
|
|
||||||
|
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
|
||||||
|
|
||||||
|
MultiLayerSpace mls = new MultiLayerSpace.Builder()
|
||||||
|
.updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.2)))
|
||||||
|
.l2(new ContinuousParameterSpace(0.0001, 0.05))
|
||||||
|
.addLayer(
|
||||||
|
new ConvolutionLayerSpace.Builder().nIn(1)
|
||||||
|
.nOut(new IntegerParameterSpace(5, 30))
|
||||||
|
.kernelSize(new DiscreteParameterSpace<>(new int[]{3, 3},
|
||||||
|
new int[]{4, 4}, new int[]{5, 5}))
|
||||||
|
.stride(new DiscreteParameterSpace<>(new int[]{1, 1},
|
||||||
|
new int[]{2, 2}))
|
||||||
|
.activation(new DiscreteParameterSpace<>(Activation.RELU,
|
||||||
|
Activation.SOFTPLUS, Activation.LEAKYRELU))
|
||||||
|
.build())
|
||||||
|
.addLayer(new DenseLayerSpace.Builder().nOut(new IntegerParameterSpace(32, 128))
|
||||||
|
.activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH))
|
||||||
|
.build(), new IntegerParameterSpace(0, 1), true) //0 to 1 layers
|
||||||
|
.addLayer(new OutputLayerSpace.Builder().nOut(10).activation(Activation.SOFTMAX)
|
||||||
|
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
||||||
|
.setInputType(InputType.convolutionalFlat(28, 28, 1))
|
||||||
|
.build();
|
||||||
|
Map<String, Object> commands = new HashMap<>();
|
||||||
|
// commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, TestDataFactoryProviderMnist.class.getCanonicalName());
|
||||||
|
|
||||||
|
//Define configuration:
|
||||||
|
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls, commands);
|
||||||
|
DataProvider dataProvider = new MnistDataSetProvider();
|
||||||
|
|
||||||
|
|
||||||
|
String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterUiTestBasicMnist\\").getAbsolutePath();
|
||||||
|
|
||||||
|
File f = new File(modelSavePath);
|
||||||
|
if (f.exists())
|
||||||
|
f.delete();
|
||||||
|
f.mkdir();
|
||||||
|
if (!f.exists())
|
||||||
|
throw new RuntimeException();
|
||||||
|
|
||||||
|
OptimizationConfiguration configuration =
|
||||||
|
new OptimizationConfiguration.Builder()
|
||||||
|
.candidateGenerator(candidateGenerator).dataProvider(dataProvider)
|
||||||
|
.modelSaver(new FileModelSaver(modelSavePath))
|
||||||
|
.scoreFunction(new TestSetLossScoreFunction(true))
|
||||||
|
.terminationConditions(new MaxTimeCondition(120, TimeUnit.MINUTES),
|
||||||
|
new MaxCandidatesCondition(100))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
IOptimizationRunner runner =
|
||||||
|
new LocalOptimizationRunner(configuration, new MultiLayerNetworkTaskCreator());
|
||||||
|
|
||||||
|
// add 3 different sessions to the same execution
|
||||||
|
HashMap<String, StatsStorage> statsStorageForSession = new HashMap<>();
|
||||||
|
for (int i = 0; i < 3; i++) {
|
||||||
|
StatsStorage ss = new InMemoryStatsStorage();
|
||||||
|
@NonNull String sessionId = "sid" + i;
|
||||||
|
statsStorageForSession.put(sessionId, ss);
|
||||||
|
StatusListener sl = new ArbiterStatusListener(sessionId, ss);
|
||||||
|
runner.addListeners(sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
Function<String, StatsStorage> statsStorageProvider = statsStorageForSession::get;
|
||||||
|
UIServer uIServer = UIServer.getInstance(true, statsStorageProvider);
|
||||||
|
String serverAddress = uIServer.getAddress();
|
||||||
|
for (String sessionId : statsStorageForSession.keySet()) {
|
||||||
|
log.info("Arbiter session will start at {}/arbiter/{}", serverAddress, sessionId);
|
||||||
|
}
|
||||||
|
|
||||||
|
runner.execute();
|
||||||
|
|
||||||
|
Thread.sleep(1000000);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@Ignore
|
@Ignore
|
||||||
|
|
Loading…
Reference in New Issue