From d3c759e03f85df54a3a05e427ccb1ea74a88036d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20Fenyvesi?= Date: Thu, 23 Apr 2020 15:26:57 +0200 Subject: [PATCH] [WIP] basic handling of multi-session in ArbiterModule (routes: arbiter, arbiter/multisession) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Tamás Fenyvesi --- .../arbiter/ui/module/ArbiterModule.java | 81 +++++++++++++---- .../arbiter/optimize/TestBasic.java | 86 +++++++++++++++++++ 2 files changed, 151 insertions(+), 16 deletions(-) diff --git a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/module/ArbiterModule.java b/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/module/ArbiterModule.java index 31e86d45f..6bcf3acbd 100644 --- a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/module/ArbiterModule.java +++ b/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/module/ArbiterModule.java @@ -36,6 +36,7 @@ import org.deeplearning4j.arbiter.ui.data.ModelInfoPersistable; import org.deeplearning4j.arbiter.ui.misc.UIUtils; import org.deeplearning4j.arbiter.util.ObjectUtils; import org.deeplearning4j.nn.conf.serde.JsonMappers; +import org.deeplearning4j.ui.VertxUIServer; import org.deeplearning4j.ui.api.Component; import org.deeplearning4j.ui.api.*; import org.deeplearning4j.ui.components.chart.ChartLine; @@ -77,7 +78,6 @@ public class ArbiterModule implements UIModule { private Map lastUpdateForSession = Collections.synchronizedMap(new HashMap<>()); - //Styles for UI: private static final StyleTable STYLE_TABLE = new StyleTable.Builder() .width(100, LengthUnit.Percent) @@ -134,20 +134,69 @@ public class ArbiterModule implements UIModule { @Override public List getRoutes() { - Route r1 = new Route("/arbiter", HttpMethod.GET, (path, rc) -> rc.response() - .putHeader("content-type", "text/html; charset=utf-8").sendFile("templates/ArbiterUI.html")); - Route r3 = new Route("/arbiter/lastUpdate", HttpMethod.GET, (path, rc) -> this.getLastUpdateTime(rc)); - Route r4 = new Route("/arbiter/lastUpdate/:ids", HttpMethod.GET, (path, rc) -> this.getModelLastUpdateTimes(path.get(0), rc)); - Route r5 = new Route("/arbiter/candidateInfo/:id", HttpMethod.GET, (path, rc) -> this.getCandidateInfo(path.get(0), rc)); - Route r6 = new Route("/arbiter/config", HttpMethod.GET, (path, rc) -> this.getOptimizationConfig(rc)); - Route r7 = new Route("/arbiter/results", HttpMethod.GET, (path, rc) -> this.getSummaryResults(rc)); - Route r8 = new Route("/arbiter/summary", HttpMethod.GET, (path, rc) -> this.getSummaryStatus(rc)); + boolean multiSession = VertxUIServer.getMultiSession().get(); + List r = new ArrayList<>(); + r.add(new Route("/arbiter/multisession", HttpMethod.GET, + (path, rc) -> rc.response().end(multiSession ? "true" : "false"))); + if (multiSession) { + r.add(new Route("/arbiter", HttpMethod.GET, (path, rc) -> this.listSessions(rc))); + } else { + r.add(new Route("/arbiter", HttpMethod.GET, (path, rc) -> rc.response() + .putHeader("content-type", "text/html; charset=utf-8") + .sendFile("templates/ArbiterUI.html"))); + r.add(new Route("/arbiter/lastUpdate", HttpMethod.GET, (path, rc) -> this.getLastUpdateTime(rc))); + r.add(new Route("/arbiter/lastUpdate/:ids", HttpMethod.GET, + (path, rc) -> this.getModelLastUpdateTimes(path.get(0), rc))); + r.add(new Route("/arbiter/candidateInfo/:id", HttpMethod.GET, + (path, rc) -> this.getCandidateInfo(path.get(0), rc))); + r.add(new Route("/arbiter/config", HttpMethod.GET, (path, rc) -> this.getOptimizationConfig(rc))); + r.add(new Route("/arbiter/results", HttpMethod.GET, (path, rc) -> this.getSummaryResults(rc))); + r.add(new Route("/arbiter/summary", HttpMethod.GET, (path, rc) -> this.getSummaryStatus(rc))); - Route r9a = new Route("/arbiter/sessions/all", HttpMethod.GET, (path, rc) -> this.listSessions(rc)); - Route r9b = new Route("/arbiter/sessions/current", HttpMethod.GET, (path, rc) -> this.currentSession(rc)); - Route r9c = new Route("/arbiter/sessions/set/:to", HttpMethod.GET, (path, rc) -> this.setSession(path.get(0), rc)); + r.add(new Route("/arbiter/sessions/all", HttpMethod.GET, (path, rc) -> this.sessionInfo(rc))); + r.add(new Route("/arbiter/sessions/current", HttpMethod.GET, (path, rc) -> this.currentSession(rc))); + r.add(new Route("/arbiter/sessions/set/:to", HttpMethod.GET, + (path, rc) -> this.setSession(path.get(0), rc))); + } - return Arrays.asList(r1, r3, r4, r5, r6, r7, r8, r9a, r9b, r9c); + return r; + } + + + /** + * List optimization sessions. Returns a HTML list of arbiter sessions + */ + private synchronized void listSessions(RoutingContext rc) { + StringBuilder sb = new StringBuilder("\n" + + "\n" + + "\n" + + " \n" + + " Optimization sessions - DL4J Arbiter UI\n" + + " \n" + + "\n" + + " \n" + + "

DL4J Arbiter UI

\n" + + "

UI server is in multi-session mode." + + " To visualize an optimization session, please select one from the following list.

\n" + + "

List of attached optimization sessions

\n"); + if (!knownSessionIDs.isEmpty()) { + sb.append(" "); + } else { + sb.append("No optimization session attached."); + } + + sb.append(" \n" + + "\n"); + + rc.response() + .putHeader("content-type", "text/html; charset=utf-8") + .end(sb.toString()); } @Override @@ -201,7 +250,7 @@ public class ArbiterModule implements UIModule { .end(asJson(sid)); } - private void listSessions(RoutingContext rc) { + private void sessionInfo(RoutingContext rc) { rc.response() .putHeader("content-type", "application/json") .end(asJson(knownSessionIDs.keySet())); @@ -309,7 +358,6 @@ public class ArbiterModule implements UIModule { * Get the info for a specific candidate - last section in the UI * * @param candidateId ID for the candidate - * @return Content/info for the candidate */ private void getCandidateInfo(String candidateId, RoutingContext rc){ @@ -320,7 +368,8 @@ public class ArbiterModule implements UIModule { return; } - GlobalConfigPersistable gcp = (GlobalConfigPersistable)ss.getStaticInfo(currentSessionID, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID);; + GlobalConfigPersistable gcp = (GlobalConfigPersistable)ss + .getStaticInfo(currentSessionID, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID); OptimizationConfiguration oc = gcp.getOptimizationConfiguration(); Persistable p = ss.getLatestUpdate(currentSessionID, ARBITER_UI_TYPE_ID, candidateId); diff --git a/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java b/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java index 64c873348..cb2374d64 100644 --- a/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java +++ b/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java @@ -16,6 +16,8 @@ package org.deeplearning4j.arbiter.optimize; +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.core.storage.StatsStorage; import org.deeplearning4j.arbiter.ComputationGraphSpace; @@ -59,6 +61,7 @@ import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.function.Function; import org.nd4j.linalg.lossfunctions.LossFunctions; import java.io.File; @@ -71,8 +74,14 @@ import java.util.concurrent.TimeUnit; /** * Created by Alex on 19/07/2017. */ +@Slf4j public class TestBasic extends BaseDL4JTest { + @Override + public long getTimeoutMilliseconds() { + return 3600_000L; + } + @Test @Ignore public void testBasicUiOnly() throws Exception { @@ -82,6 +91,83 @@ public class TestBasic extends BaseDL4JTest { Thread.sleep(1000000); } + @Test + @Ignore + public void testBasicUiMultiSession() throws Exception { + + Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); + + MultiLayerSpace mls = new MultiLayerSpace.Builder() + .updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.2))) + .l2(new ContinuousParameterSpace(0.0001, 0.05)) + .addLayer( + new ConvolutionLayerSpace.Builder().nIn(1) + .nOut(new IntegerParameterSpace(5, 30)) + .kernelSize(new DiscreteParameterSpace<>(new int[]{3, 3}, + new int[]{4, 4}, new int[]{5, 5})) + .stride(new DiscreteParameterSpace<>(new int[]{1, 1}, + new int[]{2, 2})) + .activation(new DiscreteParameterSpace<>(Activation.RELU, + Activation.SOFTPLUS, Activation.LEAKYRELU)) + .build()) + .addLayer(new DenseLayerSpace.Builder().nOut(new IntegerParameterSpace(32, 128)) + .activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH)) + .build(), new IntegerParameterSpace(0, 1), true) //0 to 1 layers + .addLayer(new OutputLayerSpace.Builder().nOut(10).activation(Activation.SOFTMAX) + .lossFunction(LossFunctions.LossFunction.MCXENT).build()) + .setInputType(InputType.convolutionalFlat(28, 28, 1)) + .build(); + Map commands = new HashMap<>(); +// commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, TestDataFactoryProviderMnist.class.getCanonicalName()); + + //Define configuration: + CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls, commands); + DataProvider dataProvider = new MnistDataSetProvider(); + + + String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterUiTestBasicMnist\\").getAbsolutePath(); + + File f = new File(modelSavePath); + if (f.exists()) + f.delete(); + f.mkdir(); + if (!f.exists()) + throw new RuntimeException(); + + OptimizationConfiguration configuration = + new OptimizationConfiguration.Builder() + .candidateGenerator(candidateGenerator).dataProvider(dataProvider) + .modelSaver(new FileModelSaver(modelSavePath)) + .scoreFunction(new TestSetLossScoreFunction(true)) + .terminationConditions(new MaxTimeCondition(120, TimeUnit.MINUTES), + new MaxCandidatesCondition(100)) + .build(); + + IOptimizationRunner runner = + new LocalOptimizationRunner(configuration, new MultiLayerNetworkTaskCreator()); + + // add 3 different sessions to the same execution + HashMap statsStorageForSession = new HashMap<>(); + for (int i = 0; i < 3; i++) { + StatsStorage ss = new InMemoryStatsStorage(); + @NonNull String sessionId = "sid" + i; + statsStorageForSession.put(sessionId, ss); + StatusListener sl = new ArbiterStatusListener(sessionId, ss); + runner.addListeners(sl); + } + + Function statsStorageProvider = statsStorageForSession::get; + UIServer uIServer = UIServer.getInstance(true, statsStorageProvider); + String serverAddress = uIServer.getAddress(); + for (String sessionId : statsStorageForSession.keySet()) { + log.info("Arbiter session will start at {}/arbiter/{}", serverAddress, sessionId); + } + + runner.execute(); + + Thread.sleep(1000000); + } + @Test @Ignore