From ef7c21c204f775f43e506730237eefcbcfa370d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20Fenyvesi?= Date: Thu, 23 Apr 2020 13:56:32 +0200 Subject: [PATCH 01/10] fix logging of UI server auto-stop, fix indentation in test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Tamás Fenyvesi --- .../src/main/java/org/deeplearning4j/ui/VertxUIServer.java | 3 ++- .../src/test/java/org/deeplearning4j/ui/TestVertxUIManual.java | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/VertxUIServer.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/VertxUIServer.java index e42c60f9e..1182c8f5f 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/VertxUIServer.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/VertxUIServer.java @@ -194,8 +194,9 @@ public class VertxUIServer extends AbstractVerticle implements UIServer { VertxUIServer.autoStopThread = new Thread(() -> { try { currentThread.join(); - log.info("Deeplearning4j UI server is auto-stopping."); if (VertxUIServer.instance != null && !VertxUIServer.instance.isStopped()) { + log.info("Deeplearning4j UI server is auto-stopping after thread (name: {}) died.", + currentThread.getName()); instance.stop(); } } catch (InterruptedException e) { diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIManual.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIManual.java index c47abbe61..582b2f1ac 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIManual.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIManual.java @@ -40,7 +40,7 @@ import static org.junit.Assert.*; public class TestVertxUIManual extends BaseDL4JTest { @Override - public long getTimeoutMilliseconds() { + public long getTimeoutMilliseconds() { return 3600_000L; } From f1ebced7a119bd39bbdc3833519ae65a91696778 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20Fenyvesi?= Date: Thu, 23 Apr 2020 15:05:29 +0200 Subject: [PATCH 02/10] simplify TrainModule to use getters of VertxUIServer instance for multi-session mode, session loader and server address MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Tamás Fenyvesi --- .../org/deeplearning4j/ui/VertxUIServer.java | 63 +++++++------------ .../ui/module/train/TrainModule.java | 52 ++++++--------- 2 files changed, 42 insertions(+), 73 deletions(-) diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/VertxUIServer.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/VertxUIServer.java index 1182c8f5f..7c6a27f4e 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/VertxUIServer.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/VertxUIServer.java @@ -77,8 +77,6 @@ public class VertxUIServer extends AbstractVerticle implements UIServer { private static Integer instancePort; private static Thread autoStopThread; - private TrainModule trainModule; - /** * Get (and, initialize if necessary) the UI server. This synchronous function will wait until the server started. * @param port TCP socket port for {@link HttpServer} to listen @@ -208,7 +206,11 @@ public class VertxUIServer extends AbstractVerticle implements UIServer { private List uiModules = new CopyOnWriteArrayList<>(); private RemoteReceiverModule remoteReceiverModule; - private StatsStorageLoader statsStorageLoader; + /** + * Loader that attaches {@code StatsStorage} provided by {@code #statsStorageProvider} for the given session ID + */ + @Getter + private Function statsStorageLoader; //typeIDModuleMap: Records which modules are registered for which type IDs private Map> typeIDModuleMap = new ConcurrentHashMap<>(); @@ -248,10 +250,23 @@ public class VertxUIServer extends AbstractVerticle implements UIServer { */ public void autoAttachStatsStorageBySessionId(Function statsStorageProvider) { if (statsStorageProvider != null) { - this.statsStorageLoader = new StatsStorageLoader(statsStorageProvider); - if (trainModule != null) { - this.trainModule.setSessionLoader(this.statsStorageLoader); - } + this.statsStorageLoader = (sessionId) -> { + log.info("Loading StatsStorage via StatsStorageProvider for session ID (" + sessionId + ")."); + StatsStorage statsStorage = statsStorageProvider.apply(sessionId); + if (statsStorage != null) { + if (statsStorage.sessionExists(sessionId)) { + attach(statsStorage); + return true; + } + log.info("Failed to load StatsStorage via StatsStorageProvider for session ID. " + + "Session ID (" + sessionId + ") does not exist in StatsStorage."); + return false; + } else { + log.info("Failed to load StatsStorage via StatsStorageProvider for session ID (" + sessionId + "). " + + "StatsStorageProvider returned null."); + return false; + } + }; } } @@ -303,8 +318,7 @@ public class VertxUIServer extends AbstractVerticle implements UIServer { } uiModules.add(new DefaultModule(isMultiSession())); //For: navigation page "/" - trainModule = new TrainModule(isMultiSession(), statsStorageLoader, this::getAddress); - uiModules.add(trainModule); + uiModules.add(new TrainModule()); uiModules.add(new ConvolutionalListenerModule()); uiModules.add(new TsneModule()); uiModules.add(new SameDiffModule()); @@ -597,37 +611,6 @@ public class VertxUIServer extends AbstractVerticle implements UIServer { } } - /** - * Loader that attaches {@code StatsStorage} provided by {@code #statsStorageProvider} for the given session ID - */ - private class StatsStorageLoader implements Function { - - Function statsStorageProvider; - - StatsStorageLoader(Function statsStorageProvider) { - this.statsStorageProvider = statsStorageProvider; - } - - @Override - public Boolean apply(String sessionId) { - log.info("Loading StatsStorage via StatsStorageProvider for session ID (" + sessionId + ")."); - StatsStorage statsStorage = statsStorageProvider.apply(sessionId); - if (statsStorage != null) { - if (statsStorage.sessionExists(sessionId)) { - attach(statsStorage); - return true; - } - log.info("Failed to load StatsStorage via StatsStorageProvider for session ID. " + - "Session ID (" + sessionId + ") does not exist in StatsStorage."); - return false; - } else { - log.info("Failed to load StatsStorage via StatsStorageProvider for session ID (" + sessionId + "). " + - "StatsStorageProvider returned null."); - return false; - } - } - } - //================================================================================================================== // CLI Launcher diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java index 82cf33b77..e5c8d116e 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java @@ -26,8 +26,6 @@ import io.vertx.ext.web.RoutingContext; import it.unimi.dsi.fastutil.longs.LongArrayList; import lombok.AllArgsConstructor; import lombok.Data; -import lombok.Getter; -import lombok.Setter; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.FileUtils; import org.apache.commons.io.FilenameUtils; @@ -43,6 +41,7 @@ import org.deeplearning4j.nn.conf.graph.GraphVertex; import org.deeplearning4j.nn.conf.graph.LayerVertex; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.serde.JsonMappers; +import org.deeplearning4j.ui.VertxUIServer; import org.deeplearning4j.ui.api.HttpMethod; import org.deeplearning4j.ui.api.I18N; import org.deeplearning4j.ui.api.Route; @@ -56,7 +55,6 @@ import org.deeplearning4j.ui.model.stats.api.StatsInitializationReport; import org.deeplearning4j.ui.model.stats.api.StatsReport; import org.deeplearning4j.ui.model.stats.api.StatsType; import org.nd4j.common.function.Function; -import org.nd4j.common.function.Supplier; import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.common.primitives.Pair; import org.nd4j.common.primitives.Triple; @@ -86,8 +84,6 @@ public class TrainModule implements UIModule { private static final DecimalFormat df2 = new DecimalFormat("#.00"); private static DateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"); - private final Supplier addressSupplier; - private enum ModelType { MLN, CG, Layer } @@ -99,29 +95,14 @@ public class TrainModule implements UIModule { private Map workerIdxCount = new ConcurrentHashMap<>(); //Key: session ID private Map> workerIdxToName = new ConcurrentHashMap<>(); //Key: session ID private Map lastUpdateForSession = new ConcurrentHashMap<>(); - private final boolean multiSession; - @Getter @Setter - private Function sessionLoader; private final Configuration configuration; - public TrainModule() { - this(false, null, null); - } - /** * TrainModule - * - * @param multiSession multi-session mode - * @param sessionLoader StatsStorage loader to call if an unknown session ID is passed as URL path parameter - * in multi-session mode - * @param addressSupplier supplier for server address (server address in PlayUIServer gets initialized after modules) */ - public TrainModule(boolean multiSession, Function sessionLoader, Supplier addressSupplier) { - this.multiSession = multiSession; - this.sessionLoader = sessionLoader; - this.addressSupplier = addressSupplier; + public TrainModule() { String maxChartPointsProp = System.getProperty(DL4JSystemProperties.CHART_MAX_POINTS_PROPERTY); int value = DEFAULT_MAX_CHART_POINTS; if (maxChartPointsProp != null) { @@ -159,8 +140,9 @@ public class TrainModule implements UIModule { @Override public List getRoutes() { List r = new ArrayList<>(); - r.add(new Route("/train/multisession", HttpMethod.GET, (path, rc) -> rc.response().end(multiSession ? "true" : "false"))); - if (multiSession) { + r.add(new Route("/train/multisession", HttpMethod.GET, + (path, rc) -> rc.response().end(VertxUIServer.getInstance().isMultiSession() ? "true" : "false"))); + if (VertxUIServer.getInstance().isMultiSession()) { r.add(new Route("/train", HttpMethod.GET, (path, rc) -> this.listSessions(rc))); r.add(new Route("/train/:sessionId", HttpMethod.GET, (path, rc) -> { rc.response() @@ -264,7 +246,9 @@ public class TrainModule implements UIModule { if (!knownSessionIDs.isEmpty()) { sb.append(" "); } else { @@ -284,9 +268,11 @@ public class TrainModule implements UIModule { * * @param sessionId session ID to look fo with provider * @param targetPath one of overview / model / system, or null + * @param rc routing context */ private void sessionNotFound(String sessionId, String targetPath, RoutingContext rc) { - if (sessionLoader != null && sessionLoader.apply(sessionId)) { + Function loader = VertxUIServer.getInstance().getStatsStorageLoader(); + if (loader != null && loader.apply(sessionId)) { if (targetPath != null) { rc.reroute(targetPath); } else { @@ -306,9 +292,9 @@ public class TrainModule implements UIModule { && StatsListener.TYPE_ID.equals(sse.getTypeID()) && !knownSessionIDs.containsKey(sse.getSessionID())) { knownSessionIDs.put(sse.getSessionID(), sse.getStatsStorage()); - if (multiSession) { + if (VertxUIServer.getInstance().isMultiSession()) { log.info("Adding training session {}/train/{} of StatsStorage instance {}", - addressSupplier.get(), sse.getSessionID(), sse.getStatsStorage()); + VertxUIServer.getInstance().getAddress(), sse.getSessionID(), sse.getStatsStorage()); } } @@ -332,9 +318,9 @@ public class TrainModule implements UIModule { if (!StatsListener.TYPE_ID.equals(typeID)) continue; knownSessionIDs.put(sessionID, statsStorage); - if (multiSession) { + if (VertxUIServer.getInstance().isMultiSession()) { log.info("Adding training session {}/train/{} of StatsStorage instance {}", - addressSupplier.get(), sessionID, statsStorage); + VertxUIServer.getInstance().getAddress(), sessionID, statsStorage); } List latestUpdates = statsStorage.getLatestUpdateAllWorkers(sessionID, typeID); @@ -364,9 +350,9 @@ public class TrainModule implements UIModule { } for (String s : toRemove) { knownSessionIDs.remove(s); - if (multiSession) { + if (VertxUIServer.getInstance().isMultiSession()) { log.info("Removing training session {}/train/{} of StatsStorage instance {}.", - addressSupplier.get(), s, statsStorage); + VertxUIServer.getInstance().getAddress(), s, statsStorage); } lastUpdateForSession.remove(s); } @@ -602,13 +588,13 @@ public class TrainModule implements UIModule { } /** - * Get global {@link I18N} instance if {@link #multiSession} is {@code true}, or instance for session + * Get global {@link I18N} instance if {@link VertxUIServer#isMultiSession()} is {@code true}, or instance for session * * @param sessionId session ID * @return {@link I18N} instance */ private I18N getI18N(String sessionId) { - return multiSession ? I18NProvider.getInstance(sessionId) : I18NProvider.getInstance(); + return VertxUIServer.getInstance().isMultiSession() ? I18NProvider.getInstance(sessionId) : I18NProvider.getInstance(); } From d3c759e03f85df54a3a05e427ccb1ea74a88036d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20Fenyvesi?= Date: Thu, 23 Apr 2020 15:26:57 +0200 Subject: [PATCH 03/10] [WIP] basic handling of multi-session in ArbiterModule (routes: arbiter, arbiter/multisession) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Tamás Fenyvesi --- .../arbiter/ui/module/ArbiterModule.java | 81 +++++++++++++---- .../arbiter/optimize/TestBasic.java | 86 +++++++++++++++++++ 2 files changed, 151 insertions(+), 16 deletions(-) diff --git a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/module/ArbiterModule.java b/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/module/ArbiterModule.java index 31e86d45f..6bcf3acbd 100644 --- a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/module/ArbiterModule.java +++ b/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/module/ArbiterModule.java @@ -36,6 +36,7 @@ import org.deeplearning4j.arbiter.ui.data.ModelInfoPersistable; import org.deeplearning4j.arbiter.ui.misc.UIUtils; import org.deeplearning4j.arbiter.util.ObjectUtils; import org.deeplearning4j.nn.conf.serde.JsonMappers; +import org.deeplearning4j.ui.VertxUIServer; import org.deeplearning4j.ui.api.Component; import org.deeplearning4j.ui.api.*; import org.deeplearning4j.ui.components.chart.ChartLine; @@ -77,7 +78,6 @@ public class ArbiterModule implements UIModule { private Map lastUpdateForSession = Collections.synchronizedMap(new HashMap<>()); - //Styles for UI: private static final StyleTable STYLE_TABLE = new StyleTable.Builder() .width(100, LengthUnit.Percent) @@ -134,20 +134,69 @@ public class ArbiterModule implements UIModule { @Override public List getRoutes() { - Route r1 = new Route("/arbiter", HttpMethod.GET, (path, rc) -> rc.response() - .putHeader("content-type", "text/html; charset=utf-8").sendFile("templates/ArbiterUI.html")); - Route r3 = new Route("/arbiter/lastUpdate", HttpMethod.GET, (path, rc) -> this.getLastUpdateTime(rc)); - Route r4 = new Route("/arbiter/lastUpdate/:ids", HttpMethod.GET, (path, rc) -> this.getModelLastUpdateTimes(path.get(0), rc)); - Route r5 = new Route("/arbiter/candidateInfo/:id", HttpMethod.GET, (path, rc) -> this.getCandidateInfo(path.get(0), rc)); - Route r6 = new Route("/arbiter/config", HttpMethod.GET, (path, rc) -> this.getOptimizationConfig(rc)); - Route r7 = new Route("/arbiter/results", HttpMethod.GET, (path, rc) -> this.getSummaryResults(rc)); - Route r8 = new Route("/arbiter/summary", HttpMethod.GET, (path, rc) -> this.getSummaryStatus(rc)); + boolean multiSession = VertxUIServer.getMultiSession().get(); + List r = new ArrayList<>(); + r.add(new Route("/arbiter/multisession", HttpMethod.GET, + (path, rc) -> rc.response().end(multiSession ? "true" : "false"))); + if (multiSession) { + r.add(new Route("/arbiter", HttpMethod.GET, (path, rc) -> this.listSessions(rc))); + } else { + r.add(new Route("/arbiter", HttpMethod.GET, (path, rc) -> rc.response() + .putHeader("content-type", "text/html; charset=utf-8") + .sendFile("templates/ArbiterUI.html"))); + r.add(new Route("/arbiter/lastUpdate", HttpMethod.GET, (path, rc) -> this.getLastUpdateTime(rc))); + r.add(new Route("/arbiter/lastUpdate/:ids", HttpMethod.GET, + (path, rc) -> this.getModelLastUpdateTimes(path.get(0), rc))); + r.add(new Route("/arbiter/candidateInfo/:id", HttpMethod.GET, + (path, rc) -> this.getCandidateInfo(path.get(0), rc))); + r.add(new Route("/arbiter/config", HttpMethod.GET, (path, rc) -> this.getOptimizationConfig(rc))); + r.add(new Route("/arbiter/results", HttpMethod.GET, (path, rc) -> this.getSummaryResults(rc))); + r.add(new Route("/arbiter/summary", HttpMethod.GET, (path, rc) -> this.getSummaryStatus(rc))); - Route r9a = new Route("/arbiter/sessions/all", HttpMethod.GET, (path, rc) -> this.listSessions(rc)); - Route r9b = new Route("/arbiter/sessions/current", HttpMethod.GET, (path, rc) -> this.currentSession(rc)); - Route r9c = new Route("/arbiter/sessions/set/:to", HttpMethod.GET, (path, rc) -> this.setSession(path.get(0), rc)); + r.add(new Route("/arbiter/sessions/all", HttpMethod.GET, (path, rc) -> this.sessionInfo(rc))); + r.add(new Route("/arbiter/sessions/current", HttpMethod.GET, (path, rc) -> this.currentSession(rc))); + r.add(new Route("/arbiter/sessions/set/:to", HttpMethod.GET, + (path, rc) -> this.setSession(path.get(0), rc))); + } - return Arrays.asList(r1, r3, r4, r5, r6, r7, r8, r9a, r9b, r9c); + return r; + } + + + /** + * List optimization sessions. Returns a HTML list of arbiter sessions + */ + private synchronized void listSessions(RoutingContext rc) { + StringBuilder sb = new StringBuilder("\n" + + "\n" + + "\n" + + " \n" + + " Optimization sessions - DL4J Arbiter UI\n" + + " \n" + + "\n" + + " \n" + + "

DL4J Arbiter UI

\n" + + "

UI server is in multi-session mode." + + " To visualize an optimization session, please select one from the following list.

\n" + + "

List of attached optimization sessions

\n"); + if (!knownSessionIDs.isEmpty()) { + sb.append(" "); + } else { + sb.append("No optimization session attached."); + } + + sb.append(" \n" + + "\n"); + + rc.response() + .putHeader("content-type", "text/html; charset=utf-8") + .end(sb.toString()); } @Override @@ -201,7 +250,7 @@ public class ArbiterModule implements UIModule { .end(asJson(sid)); } - private void listSessions(RoutingContext rc) { + private void sessionInfo(RoutingContext rc) { rc.response() .putHeader("content-type", "application/json") .end(asJson(knownSessionIDs.keySet())); @@ -309,7 +358,6 @@ public class ArbiterModule implements UIModule { * Get the info for a specific candidate - last section in the UI * * @param candidateId ID for the candidate - * @return Content/info for the candidate */ private void getCandidateInfo(String candidateId, RoutingContext rc){ @@ -320,7 +368,8 @@ public class ArbiterModule implements UIModule { return; } - GlobalConfigPersistable gcp = (GlobalConfigPersistable)ss.getStaticInfo(currentSessionID, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID);; + GlobalConfigPersistable gcp = (GlobalConfigPersistable)ss + .getStaticInfo(currentSessionID, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID); OptimizationConfiguration oc = gcp.getOptimizationConfiguration(); Persistable p = ss.getLatestUpdate(currentSessionID, ARBITER_UI_TYPE_ID, candidateId); diff --git a/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java b/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java index 64c873348..cb2374d64 100644 --- a/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java +++ b/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java @@ -16,6 +16,8 @@ package org.deeplearning4j.arbiter.optimize; +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.core.storage.StatsStorage; import org.deeplearning4j.arbiter.ComputationGraphSpace; @@ -59,6 +61,7 @@ import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.function.Function; import org.nd4j.linalg.lossfunctions.LossFunctions; import java.io.File; @@ -71,8 +74,14 @@ import java.util.concurrent.TimeUnit; /** * Created by Alex on 19/07/2017. */ +@Slf4j public class TestBasic extends BaseDL4JTest { + @Override + public long getTimeoutMilliseconds() { + return 3600_000L; + } + @Test @Ignore public void testBasicUiOnly() throws Exception { @@ -82,6 +91,83 @@ public class TestBasic extends BaseDL4JTest { Thread.sleep(1000000); } + @Test + @Ignore + public void testBasicUiMultiSession() throws Exception { + + Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); + + MultiLayerSpace mls = new MultiLayerSpace.Builder() + .updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.2))) + .l2(new ContinuousParameterSpace(0.0001, 0.05)) + .addLayer( + new ConvolutionLayerSpace.Builder().nIn(1) + .nOut(new IntegerParameterSpace(5, 30)) + .kernelSize(new DiscreteParameterSpace<>(new int[]{3, 3}, + new int[]{4, 4}, new int[]{5, 5})) + .stride(new DiscreteParameterSpace<>(new int[]{1, 1}, + new int[]{2, 2})) + .activation(new DiscreteParameterSpace<>(Activation.RELU, + Activation.SOFTPLUS, Activation.LEAKYRELU)) + .build()) + .addLayer(new DenseLayerSpace.Builder().nOut(new IntegerParameterSpace(32, 128)) + .activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH)) + .build(), new IntegerParameterSpace(0, 1), true) //0 to 1 layers + .addLayer(new OutputLayerSpace.Builder().nOut(10).activation(Activation.SOFTMAX) + .lossFunction(LossFunctions.LossFunction.MCXENT).build()) + .setInputType(InputType.convolutionalFlat(28, 28, 1)) + .build(); + Map commands = new HashMap<>(); +// commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, TestDataFactoryProviderMnist.class.getCanonicalName()); + + //Define configuration: + CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls, commands); + DataProvider dataProvider = new MnistDataSetProvider(); + + + String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterUiTestBasicMnist\\").getAbsolutePath(); + + File f = new File(modelSavePath); + if (f.exists()) + f.delete(); + f.mkdir(); + if (!f.exists()) + throw new RuntimeException(); + + OptimizationConfiguration configuration = + new OptimizationConfiguration.Builder() + .candidateGenerator(candidateGenerator).dataProvider(dataProvider) + .modelSaver(new FileModelSaver(modelSavePath)) + .scoreFunction(new TestSetLossScoreFunction(true)) + .terminationConditions(new MaxTimeCondition(120, TimeUnit.MINUTES), + new MaxCandidatesCondition(100)) + .build(); + + IOptimizationRunner runner = + new LocalOptimizationRunner(configuration, new MultiLayerNetworkTaskCreator()); + + // add 3 different sessions to the same execution + HashMap statsStorageForSession = new HashMap<>(); + for (int i = 0; i < 3; i++) { + StatsStorage ss = new InMemoryStatsStorage(); + @NonNull String sessionId = "sid" + i; + statsStorageForSession.put(sessionId, ss); + StatusListener sl = new ArbiterStatusListener(sessionId, ss); + runner.addListeners(sl); + } + + Function statsStorageProvider = statsStorageForSession::get; + UIServer uIServer = UIServer.getInstance(true, statsStorageProvider); + String serverAddress = uIServer.getAddress(); + for (String sessionId : statsStorageForSession.keySet()) { + log.info("Arbiter session will start at {}/arbiter/{}", serverAddress, sessionId); + } + + runner.execute(); + + Thread.sleep(1000000); + } + @Test @Ignore From 64c9008ea765f8f869c43463241807338d8ebd37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20Fenyvesi?= Date: Thu, 30 Apr 2020 11:39:52 +0200 Subject: [PATCH 04/10] multi-session routes in ArbiterModule, optimize getLastUpdateTime MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Tamás Fenyvesi --- .../arbiter/ui/module/ArbiterModule.java | 193 ++++++++++++------ .../arbiter/optimize/TestBasic.java | 2 +- 2 files changed, 130 insertions(+), 65 deletions(-) diff --git a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/module/ArbiterModule.java b/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/module/ArbiterModule.java index 6bcf3acbd..dab0174df 100644 --- a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/module/ArbiterModule.java +++ b/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/module/ArbiterModule.java @@ -51,6 +51,7 @@ import org.deeplearning4j.ui.components.text.style.StyleText; import org.deeplearning4j.ui.i18n.I18NResource; import org.joda.time.format.DateTimeFormat; import org.joda.time.format.DateTimeFormatter; +import org.nd4j.common.function.Function; import org.nd4j.common.primitives.Pair; import org.nd4j.shade.jackson.core.JsonProcessingException; @@ -140,29 +141,89 @@ public class ArbiterModule implements UIModule { (path, rc) -> rc.response().end(multiSession ? "true" : "false"))); if (multiSession) { r.add(new Route("/arbiter", HttpMethod.GET, (path, rc) -> this.listSessions(rc))); + r.add(new Route("/arbiter/:sessionId", HttpMethod.GET, (path, rc) -> rc.response() + .putHeader("content-type", "text/html; charset=utf-8") + .sendFile("templates/ArbiterUI.html"))); + + r.add(new Route("/arbiter/:sessionId/lastUpdate", HttpMethod.GET, (path, rc) -> { + if (knownSessionIDs.containsKey(path.get(0))) { + this.getLastUpdateTime(path.get(0), rc); + } else { + sessionNotFound(path.get(0), rc.request().path(), rc); + } + })); + r.add(new Route("/arbiter/:sessionId/candidateInfo/:id", HttpMethod.GET, (path, rc) -> { + if (knownSessionIDs.containsKey(path.get(0))) { + this.getCandidateInfo(path.get(0), path.get(1), rc); + } else { + sessionNotFound(path.get(0), rc.request().path(), rc); + } + })); + r.add(new Route("/arbiter/:sessionId/config", HttpMethod.GET, (path, rc) -> { + if (knownSessionIDs.containsKey(path.get(0))) { + this.getOptimizationConfig(path.get(0), rc); + } else { + sessionNotFound(path.get(0), rc.request().path(), rc); + } + })); + r.add(new Route("/arbiter/:sessionId/results", HttpMethod.GET, (path, rc) -> { + if (knownSessionIDs.containsKey(path.get(0))) { + this.getSummaryResults(path.get(0), rc); + } else { + sessionNotFound(path.get(0), rc.request().path(), rc); + } + })); + r.add(new Route("/arbiter/:sessionId/summary", HttpMethod.GET, (path, rc) -> { + if (knownSessionIDs.containsKey(path.get(0))) { + this.getSummaryStatus(path.get(0), rc); + } else { + sessionNotFound(path.get(0), rc.request().path(), rc); + } + })); } else { r.add(new Route("/arbiter", HttpMethod.GET, (path, rc) -> rc.response() .putHeader("content-type", "text/html; charset=utf-8") .sendFile("templates/ArbiterUI.html"))); - r.add(new Route("/arbiter/lastUpdate", HttpMethod.GET, (path, rc) -> this.getLastUpdateTime(rc))); - r.add(new Route("/arbiter/lastUpdate/:ids", HttpMethod.GET, - (path, rc) -> this.getModelLastUpdateTimes(path.get(0), rc))); + r.add(new Route("/arbiter/lastUpdate", HttpMethod.GET, (path, rc) -> this.getLastUpdateTime(null, rc))); r.add(new Route("/arbiter/candidateInfo/:id", HttpMethod.GET, - (path, rc) -> this.getCandidateInfo(path.get(0), rc))); - r.add(new Route("/arbiter/config", HttpMethod.GET, (path, rc) -> this.getOptimizationConfig(rc))); - r.add(new Route("/arbiter/results", HttpMethod.GET, (path, rc) -> this.getSummaryResults(rc))); - r.add(new Route("/arbiter/summary", HttpMethod.GET, (path, rc) -> this.getSummaryStatus(rc))); + (path, rc) -> this.getCandidateInfo(null, path.get(0), rc))); + r.add(new Route("/arbiter/config", HttpMethod.GET, (path, rc) -> this.getOptimizationConfig(null, rc))); + r.add(new Route("/arbiter/results", HttpMethod.GET, (path, rc) -> this.getSummaryResults(null, rc))); + r.add(new Route("/arbiter/summary", HttpMethod.GET, (path, rc) -> this.getSummaryStatus(null, rc))); - r.add(new Route("/arbiter/sessions/all", HttpMethod.GET, (path, rc) -> this.sessionInfo(rc))); r.add(new Route("/arbiter/sessions/current", HttpMethod.GET, (path, rc) -> this.currentSession(rc))); r.add(new Route("/arbiter/sessions/set/:to", HttpMethod.GET, (path, rc) -> this.setSession(path.get(0), rc))); } + // common for single- and multi-session mode + r.add(new Route("/arbiter/sessions/all", HttpMethod.GET, (path, rc) -> this.sessionInfo(rc))); return r; } + /** + * Load StatsStorage via provider, or return "not found" + * + * @param sessionId session ID to look fo with provider + * @param targetPath one of overview / model / system, or null + * @param rc routing context + */ + private void sessionNotFound(String sessionId, String targetPath, RoutingContext rc) { + Function loader = VertxUIServer.getInstance().getStatsStorageLoader(); + if (loader != null && loader.apply(sessionId)) { + if (targetPath != null) { + rc.reroute(targetPath); + } else { + rc.response().end(); + } + } else { + rc.response().setStatusCode(HttpResponseStatus.NOT_FOUND.code()) + .end("Unknown session ID: " + sessionId); + } + } + + /** * List optimization sessions. Returns a HTML list of arbiter sessions */ @@ -182,7 +243,7 @@ public class ArbiterModule implements UIModule { if (!knownSessionIDs.isEmpty()) { sb.append("
    "); for (String sessionId : knownSessionIDs.keySet()) { - sb.append("
  • ") .append(sessionId).append("
  • \n"); } @@ -306,10 +367,25 @@ public class ArbiterModule implements UIModule { /** * Return the last update time for the page + * @param sessionId session ID (optional, for multi-session mode) + * @param rc routing context */ - private void getLastUpdateTime(RoutingContext rc){ - //TODO - this forces updates on every request... which is fine, just inefficient - long t = System.currentTimeMillis(); + private void getLastUpdateTime(String sessionId, RoutingContext rc){ + if (sessionId == null) { + sessionId = currentSessionID; + } + StatsStorage ss = knownSessionIDs.get(sessionId); + List latestUpdates = ss.getLatestUpdateAllWorkers(sessionId, ARBITER_UI_TYPE_ID); + long t = 0; + if (latestUpdates.isEmpty()) { + t = System.currentTimeMillis(); + } else { + for (Persistable update : latestUpdates) { + if (update.getTimeStamp() > t) { + t = update.getTimeStamp(); + } + } + } UpdateStatus us = new UpdateStatus(t, t, t); rc.response().putHeader("content-type", "application/json").end(asJson(us)); @@ -323,56 +399,28 @@ public class ArbiterModule implements UIModule { } } - /** - * Get the last update time for the specified model IDs - * @param modelIDs Model IDs to get the update time for - */ - private void getModelLastUpdateTimes(String modelIDs, RoutingContext rc){ - if(currentSessionID == null){ - rc.response().end(); - return; - } - - StatsStorage ss = knownSessionIDs.get(currentSessionID); - if(ss == null){ - log.debug("getModelLastUpdateTimes(): Session ID is unknown: {}", currentSessionID); - rc.response().end("-1"); - return; - } - - String[] split = modelIDs.split(","); - - long[] lastUpdateTimes = new long[split.length]; - for( int i=0; i allModelInfoTemp = new ArrayList<>(ss.getLatestUpdateAllWorkers(currentSessionID, ARBITER_UI_TYPE_ID)); + List allModelInfoTemp = new ArrayList<>(ss.getLatestUpdateAllWorkers(sessionId, ARBITER_UI_TYPE_ID)); List table = new ArrayList<>(); for(Persistable per : allModelInfoTemp){ ModelInfoPersistable mip = (ModelInfoPersistable)per; @@ -663,16 +723,21 @@ public class ArbiterModule implements UIModule { /** * Get summary status information: first section in the page + * @param sessionId session ID (optional, for multi-session mode) + * @param rc routing context */ - private void getSummaryStatus(RoutingContext rc){ - StatsStorage ss = knownSessionIDs.get(currentSessionID); + private void getSummaryStatus(String sessionId, RoutingContext rc){ + if (sessionId == null) { + sessionId = currentSessionID; + } + StatsStorage ss = knownSessionIDs.get(sessionId); if(ss == null){ - log.debug("getOptimizationConfig(): Session ID is unknown: {}", currentSessionID); + log.debug("getOptimizationConfig(): Session ID is unknown: {}", sessionId); rc.response().end(); return; } - Persistable p = ss.getStaticInfo(currentSessionID, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID); + Persistable p = ss.getStaticInfo(sessionId, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID); if(p == null){ log.info("No static info"); @@ -692,7 +757,7 @@ public class ArbiterModule implements UIModule { //How to get this? query all model infos... - List allModelInfoTemp = new ArrayList<>(ss.getLatestUpdateAllWorkers(currentSessionID, ARBITER_UI_TYPE_ID)); + List allModelInfoTemp = new ArrayList<>(ss.getLatestUpdateAllWorkers(sessionId, ARBITER_UI_TYPE_ID)); List allModelInfo = new ArrayList<>(); for(Persistable per : allModelInfoTemp){ ModelInfoPersistable mip = (ModelInfoPersistable)per; diff --git a/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java b/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java index cb2374d64..025ce85c6 100644 --- a/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java +++ b/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java @@ -160,7 +160,7 @@ public class TestBasic extends BaseDL4JTest { UIServer uIServer = UIServer.getInstance(true, statsStorageProvider); String serverAddress = uIServer.getAddress(); for (String sessionId : statsStorageForSession.keySet()) { - log.info("Arbiter session will start at {}/arbiter/{}", serverAddress, sessionId); + log.info("Arbiter session can be attached at {}/arbiter/{}", serverAddress, sessionId); } runner.execute(); From 8f55214d1b343a4e8269d0f9294587b4b0a1573e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20Fenyvesi?= Date: Tue, 5 May 2020 18:54:42 +0200 Subject: [PATCH 05/10] JavaScript side of multi-session support in Arbiter UI MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Tamás Fenyvesi --- .../main/resources/templates/ArbiterUI.html | 229 +++++++++++------- 1 file changed, 143 insertions(+), 86 deletions(-) diff --git a/arbiter/arbiter-ui/src/main/resources/templates/ArbiterUI.html b/arbiter/arbiter-ui/src/main/resources/templates/ArbiterUI.html index 9a5b6a998..a1b4e92a0 100644 --- a/arbiter/arbiter-ui/src/main/resources/templates/ArbiterUI.html +++ b/arbiter/arbiter-ui/src/main/resources/templates/ArbiterUI.html @@ -197,110 +197,167 @@ var resultsTableContent; var selectedCandidateIdx = null; + + //Multi-session mode + var multiSession = null; + //Session selection + var currSession = ""; + + function getSessionIdFromUrl() { + // path is like /arbiter/:sessionId/overview + var sessionIdRegexp = /\/arbiter\/([^\/]+)/g; + var match = sessionIdRegexp.exec(window.location.pathname) + return match[1]; + } + function getCurrSession(callback) { + if (multiSession) { + if (currSession == "") { + // get only once + currSession = getSessionIdFromUrl(); + } + //we don't show session selector in multi-session mode (one can list sessions at /arbiter) + callback(); + } else { + $.ajax({ + url: "/arbiter/sessions/current", + async: true, + error: function (query, status, error) { + console.log("Error getting data: " + error); + }, + success: function (data) { + currSession = data; + console.log("Current session: " + currSession); + + //Update available sessions in session selector + $.get("/arbiter/sessions/all", function(data){ + var keys = data; // JSON.stringify(data); + + if(keys.length > 1){ + $("#sessionSelectDiv").show(); + + var elem = $("#sessionSelect"); + elem.empty(); + + var currSelectedIdx = 0; + for (var i = 0; i < keys.length; i++) { + if(keys[i] == currSession){ + currSelectedIdx = i; + } + elem.append(""); + } + + $("#sessionSelect option[value='" + keys[currSelectedIdx] +"']").attr("selected", "selected"); + $("#sessionSelectDiv").show(); + } + // console.log("Got sessions: " + keys + ", current: " + currSession); + callback(); + }); + + } + }); + } + } + + function getSessionSettings(callback) { + // load only once + if (multiSession != null) { + getCurrSession(callback); + } else { + $.ajax({ + url: "/arbiter/multisession", + async: true, + error: function (query, status, error) { + console.log("Error getting data: " + error); + }, + success: function (data) { + multiSession = data == "true"; + getCurrSession(callback); + } + }); + } + } + + //Initial update + doUpdate(); //Set basic interval function to do updates setInterval(doUpdate,5000); //Loop every 5 seconds - - + function doUpdate(){ //Get the update status, and do something with it: - $.get("/arbiter/lastUpdate",function(data){ - //Encoding: matches names in UpdateStatus class - var jsonObj = JSON.parse(JSON.stringify(data)); - var statusTime = jsonObj['statusUpdateTime']; - var settingsTime = jsonObj['settingsUpdateTime']; - var resultsTime = jsonObj['resultsUpdateTime']; - //console.log("Last update times: " + statusTime + ", " + settingsTime + ", " + resultsTime); + getSessionSettings(function(){ + var sessionUpdateUrl = multiSession ? "/arbiter/" + currSession + "/lastUpdate" : "/arbiter/lastUpdate"; + $.get(sessionUpdateUrl,function(data){ + //Encoding: matches names in UpdateStatus class + var jsonObj = JSON.parse(JSON.stringify(data)); + var statusTime = jsonObj['statusUpdateTime']; + var settingsTime = jsonObj['settingsUpdateTime']; + var resultsTime = jsonObj['resultsUpdateTime']; + //console.log("Last update times: " + statusTime + ", " + settingsTime + ", " + resultsTime); - //Update available sessions: - var currSession; - $.get("/arbiter/sessions/current", function(data){ - currSession = data; //JSON.stringify(data); - console.log("Current: " + currSession); - }); + //Check last update times for each part of document, and update as necessary + //First section: summary status + if(lastStatusUpdateTime != statusTime){ + var summaryStatusUrl = multiSession ? "/arbiter/" + currSession + "/summary" : "/arbiter/summary"; + $.get(summaryStatusUrl,function(data){ + var summaryStatusDiv = $('#statusdiv'); + summaryStatusDiv.html(''); - $.get("/arbiter/sessions/all", function(data){ - var keys = data; // JSON.stringify(data); + var str = JSON.stringify(data); + var component = Component.getComponent(str); + component.render(summaryStatusDiv); + }); - if(keys.length > 1){ - $("#sessionSelectDiv").show(); - - var elem = $("#sessionSelect"); - elem.empty(); - - var currSelectedIdx = 0; - for (var i = 0; i < keys.length; i++) { - if(keys[i] == currSession){ - currSelectedIdx = i; - } - elem.append(""); - } - - $("#sessionSelect option[value='" + keys[currSelectedIdx] +"']").attr("selected", "selected"); - $("#sessionSelectDiv").show(); + lastStatusUpdateTime = statusTime; } -// console.log("Got sessions: " + keys + ", current: " + currSession); - }); + //Second section: Optimization settings + if(lastSettingsUpdateTime != settingsTime){ + //Get JSON for components + var settingsUrl = multiSession ? "/arbiter/" + currSession + "/config" : "/arbiter/config"; + $.get(settingsUrl,function(data){ + var str = JSON.stringify(data); - //Check last update times for each part of document, and update as necessary - //First section: summary status - if(lastStatusUpdateTime != statusTime){ - //Get JSON: address set by SummaryStatusResource - $.get("/arbiter/summary",function(data){ - var summaryStatusDiv = $('#statusdiv'); - summaryStatusDiv.html(''); + var configDiv = $('#settingsdiv'); + configDiv.html(''); - var str = JSON.stringify(data); - var component = Component.getComponent(str); - component.render(summaryStatusDiv); - }); + var component = Component.getComponent(str); + component.render(configDiv); + }); - lastStatusUpdateTime = statusTime; - } + lastSettingsUpdateTime = settingsTime; + } - //Second section: Optimization settings - if(lastSettingsUpdateTime != settingsTime){ - //Get JSON for components - $.get("/arbiter/config",function(data){ - var str = JSON.stringify(data); + //Third section: Summary results table (summary info for each candidate) + if(lastResultsUpdateTime != resultsTime){ + //Get JSON for results table + var resultsUrl = multiSession ? "/arbiter/" + currSession + "/results" : "/arbiter/results"; + $.get(resultsUrl,function(data){ + //Expect an array of CandidateInfo type objects here + resultsTableContent = data; + drawResultTable(); + }); - var configDiv = $('#settingsdiv'); - configDiv.html(''); + lastResultsUpdateTime = resultsTime; + } - var component = Component.getComponent(str); - component.render(configDiv); - }); + //Finally: Currently selected result + if(selectedCandidateIdx != null){ + //Get JSON for components + var candidateInfoUrl = multiSession + ? "/arbiter/" + currSession + "/candidateInfo/" + selectedCandidateIdx + : "/arbiter/candidateInfo/" + selectedCandidateIdx; + $.get(candidateInfoUrl,function(data){ + var str = JSON.stringify(data); - lastSettingsUpdateTime = settingsTime; - } + var resultsViewDiv = $('#resultsviewdiv'); + resultsViewDiv.html(''); - //Third section: Summary results table (summary info for each candidate) - if(lastResultsUpdateTime != resultsTime){ - - //Get JSON; address set by SummaryResultsResource - $.get("/arbiter/results",function(data){ - //Expect an array of CandidateInfo type objects here - resultsTableContent = data; - drawResultTable(); - }); - - lastResultsUpdateTime = resultsTime; - } - - //Finally: Currently selected result - if(selectedCandidateIdx != null){ - //Get JSON for components - $.get("/arbiter/candidateInfo/"+selectedCandidateIdx,function(data){ - var str = JSON.stringify(data); - - var resultsViewDiv = $('#resultsviewdiv'); - resultsViewDiv.html(''); - - var component = Component.getComponent(str); - component.render(resultsViewDiv); - }); - } + var component = Component.getComponent(str); + component.render(resultsViewDiv); + }); + } + }) }) } From cf11c2394a18d3e00e56eef92bd99570d2f135d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20Fenyvesi?= Date: Wed, 6 May 2020 11:27:45 +0200 Subject: [PATCH 06/10] fix session list relative links in TrainModule and ArbiterModule; fix auto-attach in ArbiterModule at arbiter/sid MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Tamás Fenyvesi --- .../arbiter/ui/module/ArbiterModule.java | 14 ++++++++++---- .../ui/module/train/TrainModule.java | 2 +- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/module/ArbiterModule.java b/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/module/ArbiterModule.java index dab0174df..86be0f8d3 100644 --- a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/module/ArbiterModule.java +++ b/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/module/ArbiterModule.java @@ -141,9 +141,15 @@ public class ArbiterModule implements UIModule { (path, rc) -> rc.response().end(multiSession ? "true" : "false"))); if (multiSession) { r.add(new Route("/arbiter", HttpMethod.GET, (path, rc) -> this.listSessions(rc))); - r.add(new Route("/arbiter/:sessionId", HttpMethod.GET, (path, rc) -> rc.response() - .putHeader("content-type", "text/html; charset=utf-8") - .sendFile("templates/ArbiterUI.html"))); + r.add(new Route("/arbiter/:sessionId", HttpMethod.GET, (path, rc) -> { + if (knownSessionIDs.containsKey(path.get(0))) { + rc.response() + .putHeader("content-type", "text/html; charset=utf-8") + .sendFile("templates/ArbiterUI.html"); + } else { + sessionNotFound(path.get(0), rc.request().path(), rc); + } + })); r.add(new Route("/arbiter/:sessionId/lastUpdate", HttpMethod.GET, (path, rc) -> { if (knownSessionIDs.containsKey(path.get(0))) { @@ -243,7 +249,7 @@ public class ArbiterModule implements UIModule { if (!knownSessionIDs.isEmpty()) { sb.append("
      "); for (String sessionId : knownSessionIDs.keySet()) { - sb.append("
    • ") .append(sessionId).append("
    • \n"); } diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java index e5c8d116e..8bde827f5 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java @@ -246,7 +246,7 @@ public class TrainModule implements UIModule { if (!knownSessionIDs.isEmpty()) { sb.append("
        "); for (String sessionId : knownSessionIDs.keySet()) { - sb.append("
      • ") .append(sessionId).append("
      • \n"); } From 55be66906966a49e8f9fbb7236d23636c1250e13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20Fenyvesi?= Date: Wed, 6 May 2020 13:39:36 +0200 Subject: [PATCH 07/10] use last completed task for Total Runtime instead of currentTimeMillis in ArbiterModule Summary Status MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Tamás Fenyvesi --- .../deeplearning4j/arbiter/ui/module/ArbiterModule.java | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/module/ArbiterModule.java b/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/module/ArbiterModule.java index 86be0f8d3..7a8d70387 100644 --- a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/module/ArbiterModule.java +++ b/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/module/ArbiterModule.java @@ -788,7 +788,6 @@ public class ArbiterModule implements UIModule { //TODO: I18N - //TODO don't use currentTimeMillis due to stored data?? long bestTime; Double bestScore = null; String bestModelString = null; @@ -805,7 +804,12 @@ public class ArbiterModule implements UIModule { String execTotalRuntimeStr = ""; if(execStartTime > 0){ execStartTimeStr = TIME_FORMATTER.print(execStartTime); - execTotalRuntimeStr = UIUtils.formatDuration(System.currentTimeMillis() - execStartTime); + // allModelInfo is sorted by Persistable::getTimeStamp + long lastCompleteTime = execStartTime; + if (!allModelInfo.isEmpty()) { + lastCompleteTime = allModelInfo.get(allModelInfo.size() - 1).getTimeStamp(); + } + execTotalRuntimeStr = UIUtils.formatDuration(lastCompleteTime - execStartTime); } From cf24728f353ddd47b6fc56fb7f16ea9b359abe2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20Fenyvesi?= Date: Wed, 6 May 2020 14:56:22 +0200 Subject: [PATCH 08/10] tests for auto-attach and manual attach; improve JavaDoc MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Tamás Fenyvesi --- .../arbiter/optimize/TestBasic.java | 307 +++++++++++------- .../ui/TestVertxUIMultiSession.java | 4 +- 2 files changed, 195 insertions(+), 116 deletions(-) diff --git a/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java b/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java index 025ce85c6..08f130369 100644 --- a/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java +++ b/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java @@ -16,6 +16,7 @@ package org.deeplearning4j.arbiter.optimize; +import io.netty.handler.codec.http.HttpResponseStatus; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.BaseDL4JTest; @@ -65,12 +66,17 @@ import org.nd4j.linalg.function.Function; import org.nd4j.linalg.lossfunctions.LossFunctions; import java.io.File; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; -import java.util.Properties; +import java.io.IOException; +import java.io.UnsupportedEncodingException; +import java.net.HttpURLConnection; +import java.net.URL; +import java.net.URLEncoder; +import java.util.*; import java.util.concurrent.TimeUnit; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + /** * Created by Alex on 19/07/2017. */ @@ -88,112 +94,15 @@ public class TestBasic extends BaseDL4JTest { UIServer.getInstance(); - Thread.sleep(1000000); + Thread.sleep(1000_000); } - @Test - @Ignore - public void testBasicUiMultiSession() throws Exception { - - Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); - - MultiLayerSpace mls = new MultiLayerSpace.Builder() - .updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.2))) - .l2(new ContinuousParameterSpace(0.0001, 0.05)) - .addLayer( - new ConvolutionLayerSpace.Builder().nIn(1) - .nOut(new IntegerParameterSpace(5, 30)) - .kernelSize(new DiscreteParameterSpace<>(new int[]{3, 3}, - new int[]{4, 4}, new int[]{5, 5})) - .stride(new DiscreteParameterSpace<>(new int[]{1, 1}, - new int[]{2, 2})) - .activation(new DiscreteParameterSpace<>(Activation.RELU, - Activation.SOFTPLUS, Activation.LEAKYRELU)) - .build()) - .addLayer(new DenseLayerSpace.Builder().nOut(new IntegerParameterSpace(32, 128)) - .activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH)) - .build(), new IntegerParameterSpace(0, 1), true) //0 to 1 layers - .addLayer(new OutputLayerSpace.Builder().nOut(10).activation(Activation.SOFTMAX) - .lossFunction(LossFunctions.LossFunction.MCXENT).build()) - .setInputType(InputType.convolutionalFlat(28, 28, 1)) - .build(); - Map commands = new HashMap<>(); -// commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, TestDataFactoryProviderMnist.class.getCanonicalName()); - - //Define configuration: - CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls, commands); - DataProvider dataProvider = new MnistDataSetProvider(); - - - String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterUiTestBasicMnist\\").getAbsolutePath(); - - File f = new File(modelSavePath); - if (f.exists()) - f.delete(); - f.mkdir(); - if (!f.exists()) - throw new RuntimeException(); - - OptimizationConfiguration configuration = - new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator).dataProvider(dataProvider) - .modelSaver(new FileModelSaver(modelSavePath)) - .scoreFunction(new TestSetLossScoreFunction(true)) - .terminationConditions(new MaxTimeCondition(120, TimeUnit.MINUTES), - new MaxCandidatesCondition(100)) - .build(); - - IOptimizationRunner runner = - new LocalOptimizationRunner(configuration, new MultiLayerNetworkTaskCreator()); - - // add 3 different sessions to the same execution - HashMap statsStorageForSession = new HashMap<>(); - for (int i = 0; i < 3; i++) { - StatsStorage ss = new InMemoryStatsStorage(); - @NonNull String sessionId = "sid" + i; - statsStorageForSession.put(sessionId, ss); - StatusListener sl = new ArbiterStatusListener(sessionId, ss); - runner.addListeners(sl); - } - - Function statsStorageProvider = statsStorageForSession::get; - UIServer uIServer = UIServer.getInstance(true, statsStorageProvider); - String serverAddress = uIServer.getAddress(); - for (String sessionId : statsStorageForSession.keySet()) { - log.info("Arbiter session can be attached at {}/arbiter/{}", serverAddress, sessionId); - } - - runner.execute(); - - Thread.sleep(1000000); - } - - @Test @Ignore public void testBasicMnist() throws Exception { Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); - MultiLayerSpace mls = new MultiLayerSpace.Builder() - .updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.2))) - .l2(new ContinuousParameterSpace(0.0001, 0.05)) - .addLayer( - new ConvolutionLayerSpace.Builder().nIn(1) - .nOut(new IntegerParameterSpace(5, 30)) - .kernelSize(new DiscreteParameterSpace<>(new int[]{3, 3}, - new int[]{4, 4}, new int[]{5, 5})) - .stride(new DiscreteParameterSpace<>(new int[]{1, 1}, - new int[]{2, 2})) - .activation(new DiscreteParameterSpace<>(Activation.RELU, - Activation.SOFTPLUS, Activation.LEAKYRELU)) - .build()) - .addLayer(new DenseLayerSpace.Builder().nOut(new IntegerParameterSpace(32, 128)) - .activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH)) - .build(), new IntegerParameterSpace(0, 1), true) //0 to 1 layers - .addLayer(new OutputLayerSpace.Builder().nOut(10).activation(Activation.SOFTMAX) - .lossFunction(LossFunctions.LossFunction.MCXENT).build()) - .setInputType(InputType.convolutionalFlat(28, 28, 1)) - .build(); + MultiLayerSpace mls = getMultiLayerSpaceMnist(); Map commands = new HashMap<>(); // commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, TestDataFactoryProviderMnist.class.getCanonicalName()); @@ -230,7 +139,30 @@ public class TestBasic extends BaseDL4JTest { UIServer.getInstance().attach(ss); runner.execute(); - Thread.sleep(100000); + Thread.sleep(1000_000); + } + + private static MultiLayerSpace getMultiLayerSpaceMnist() { + return new MultiLayerSpace.Builder() + .updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.2))) + .l2(new ContinuousParameterSpace(0.0001, 0.05)) + .addLayer( + new ConvolutionLayerSpace.Builder().nIn(1) + .nOut(new IntegerParameterSpace(5, 30)) + .kernelSize(new DiscreteParameterSpace<>(new int[]{3, 3}, + new int[]{4, 4}, new int[]{5, 5})) + .stride(new DiscreteParameterSpace<>(new int[]{1, 1}, + new int[]{2, 2})) + .activation(new DiscreteParameterSpace<>(Activation.RELU, + Activation.SOFTPLUS, Activation.LEAKYRELU)) + .build()) + .addLayer(new DenseLayerSpace.Builder().nOut(new IntegerParameterSpace(32, 128)) + .activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH)) + .build(), new IntegerParameterSpace(0, 1), true) //0 to 1 layers + .addLayer(new OutputLayerSpace.Builder().nOut(10).activation(Activation.SOFTMAX) + .lossFunction(LossFunctions.LossFunction.MCXENT).build()) + .setInputType(InputType.convolutionalFlat(28, 28, 1)) + .build(); } @Test @@ -319,7 +251,7 @@ public class TestBasic extends BaseDL4JTest { .build(); //Define configuration: - CandidateGenerator candidateGenerator = new RandomSearchGenerator(cgs, Collections.EMPTY_MAP); + CandidateGenerator candidateGenerator = new RandomSearchGenerator(cgs); DataProvider dataProvider = new MnistDataSetProvider(); @@ -417,7 +349,7 @@ public class TestBasic extends BaseDL4JTest { UIServer.getInstance().attach(ss); runner.execute(); - Thread.sleep(100000); + Thread.sleep(1000_000); } @@ -482,7 +414,7 @@ public class TestBasic extends BaseDL4JTest { UIServer.getInstance().attach(ss); runner.execute(); - Thread.sleep(100000); + Thread.sleep(1000_000); } @@ -519,7 +451,7 @@ public class TestBasic extends BaseDL4JTest { .build(); //Define configuration: - CandidateGenerator candidateGenerator = new RandomSearchGenerator(cgs, Collections.EMPTY_MAP); + CandidateGenerator candidateGenerator = new RandomSearchGenerator(cgs); DataProvider dataProvider = new MnistDataSetProvider(); @@ -551,13 +483,17 @@ public class TestBasic extends BaseDL4JTest { UIServer.getInstance().attach(ss); runner.execute(); - Thread.sleep(100000); + Thread.sleep(1000_000); } + /** + * Visualize multiple optimization sessions run one after another on single-session mode UI + * @throws InterruptedException if current thread has been interrupted + */ @Test @Ignore - public void testBasicMnistMultipleSessions() throws Exception { + public void testBasicMnistMultipleSessions() throws InterruptedException { MultiLayerSpace mls = new MultiLayerSpace.Builder() .updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.2))) @@ -585,8 +521,10 @@ public class TestBasic extends BaseDL4JTest { //Define configuration: CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls, commands); - DataProvider dataProvider = new MnistDataSetProvider(); + Class ds = MnistDataSource.class; + Properties dsp = new Properties(); + dsp.setProperty("minibatch", "8"); String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterUiTestBasicMnist\\").getAbsolutePath(); @@ -599,7 +537,7 @@ public class TestBasic extends BaseDL4JTest { OptimizationConfiguration configuration = new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator).dataProvider(dataProvider) + .candidateGenerator(candidateGenerator).dataSource(ds, dsp) .modelSaver(new FileModelSaver(modelSavePath)) .scoreFunction(new TestSetLossScoreFunction(true)) .terminationConditions(new MaxTimeCondition(1, TimeUnit.MINUTES), @@ -621,7 +559,7 @@ public class TestBasic extends BaseDL4JTest { candidateGenerator = new RandomSearchGenerator(mls, commands); configuration = new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator).dataProvider(dataProvider) + .candidateGenerator(candidateGenerator).dataSource(ds, dsp) .modelSaver(new FileModelSaver(modelSavePath)) .scoreFunction(new TestSetLossScoreFunction(true)) .terminationConditions(new MaxTimeCondition(1, TimeUnit.MINUTES), @@ -636,7 +574,148 @@ public class TestBasic extends BaseDL4JTest { runner.execute(); - Thread.sleep(100000); + Thread.sleep(1000_000); + } + + /** + * Auto-attach multiple optimization sessions to multi-session mode UI + * @throws IOException if could not connect to the server + */ + @Test + public void testUiMultiSessionAutoAttach() throws IOException { + + //Define configuration: + MultiLayerSpace mls = getMultiLayerSpaceMnist(); + CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls); + + Class ds = MnistDataSource.class; + Properties dsp = new Properties(); + dsp.setProperty("minibatch", "8"); + + String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterUiTestMultiSessionAutoAttach\\") + .getAbsolutePath(); + + File f = new File(modelSavePath); + if (f.exists()) + f.delete(); + f.mkdir(); + if (!f.exists()) + throw new RuntimeException(); + + OptimizationConfiguration configuration = + new OptimizationConfiguration.Builder() + .candidateGenerator(candidateGenerator).dataSource(ds, dsp) + .modelSaver(new FileModelSaver(modelSavePath)) + .scoreFunction(new TestSetLossScoreFunction(true)) + .terminationConditions(new MaxTimeCondition(10, TimeUnit.SECONDS), + new MaxCandidatesCondition(1)) + .build(); + + IOptimizationRunner runner = + new LocalOptimizationRunner(configuration, new MultiLayerNetworkTaskCreator()); + + // add 3 different sessions to the same execution + HashMap statsStorageForSession = new HashMap<>(); + for (int i = 0; i < 3; i++) { + StatsStorage ss = new InMemoryStatsStorage(); + @NonNull String sessionId = "sid" + i; + statsStorageForSession.put(sessionId, ss); + StatusListener sl = new ArbiterStatusListener(sessionId, ss); + runner.addListeners(sl); + } + + Function statsStorageProvider = statsStorageForSession::get; + UIServer uIServer = UIServer.getInstance(true, statsStorageProvider); + String serverAddress = uIServer.getAddress(); + + runner.execute(); + + for (String sessionId : statsStorageForSession.keySet()) { + /* + * Visiting /arbiter/:sessionId to auto-attach StatsStorage + */ + String sessionUrl = sessionUrl(uIServer.getAddress(), sessionId); + HttpURLConnection conn = (HttpURLConnection) new URL(sessionUrl).openConnection(); + conn.connect(); + + log.info("Checking auto-attaching Arbiter session at {}", sessionUrl(serverAddress, sessionId)); + assertEquals(HttpResponseStatus.OK.code(), conn.getResponseCode()); + assertTrue(uIServer.isAttached(statsStorageForSession.get(sessionId))); + } + } + + /** + * Attach multiple optimization sessions to multi-session mode UI by manually visiting session URL + * @throws Exception if an error occurred + */ + @Test + @Ignore + public void testUiMultiSessionManualAttach() throws Exception { + Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); + + //Define configuration: + MultiLayerSpace mls = getMultiLayerSpaceMnist(); + CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls); + + Class ds = MnistDataSource.class; + Properties dsp = new Properties(); + dsp.setProperty("minibatch", "8"); + + String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterUiTestBasicMnist\\") + .getAbsolutePath(); + + File f = new File(modelSavePath); + if (f.exists()) + f.delete(); + f.mkdir(); + if (!f.exists()) + throw new RuntimeException(); + + OptimizationConfiguration configuration = + new OptimizationConfiguration.Builder() + .candidateGenerator(candidateGenerator).dataSource(ds, dsp) + .modelSaver(new FileModelSaver(modelSavePath)) + .scoreFunction(new TestSetLossScoreFunction(true)) + .terminationConditions(new MaxTimeCondition(10, TimeUnit.MINUTES), + new MaxCandidatesCondition(10)) + .build(); + + + // parallel execution of multiple optimization sessions + HashMap statsStorageForSession = new HashMap<>(); + for (int i = 0; i < 3; i++) { + String sessionId = "sid" + i; + IOptimizationRunner runner = + new LocalOptimizationRunner(configuration, new MultiLayerNetworkTaskCreator()); + StatsStorage ss = new InMemoryStatsStorage(); + statsStorageForSession.put(sessionId, ss); + StatusListener sl = new ArbiterStatusListener(sessionId, ss); + runner.addListeners(sl); + // Asynchronous execution + new Thread(runner::execute).start(); + } + + Function statsStorageProvider = statsStorageForSession::get; + UIServer uIServer = UIServer.getInstance(true, statsStorageProvider); + String serverAddress = uIServer.getAddress(); + + for (String sessionId : statsStorageForSession.keySet()) { + log.info("Arbiter session can be attached at {}", sessionUrl(serverAddress, sessionId)); + } + + Thread.sleep(1000_000); + } + + + /** + * Get URL for arbiter session on given server address + * @param serverAddress server address, e.g.: http://localhost:9000 + * @param sessionId session ID (will be URL-encoded) + * @return URL + * @throws UnsupportedEncodingException if the character encoding is not supported + */ + private static String sessionUrl(String serverAddress, String sessionId) throws UnsupportedEncodingException { + return String.format("%s/arbiter/%s", serverAddress, URLEncoder.encode(sessionId, "UTF-8")); } private static class MnistDataSetProvider implements DataProvider { diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIMultiSession.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIMultiSession.java index c9577d4a3..0f3f50d41 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIMultiSession.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIMultiSession.java @@ -197,9 +197,9 @@ public class TestVertxUIMultiSession extends BaseDL4JTest { } /** - * Get URL-encoded URL for training session on given server address + * Get URL for training session on given server address * @param serverAddress server address - * @param sessionId session ID + * @param sessionId session ID (will be URL-encoded) * @return URL * @throws UnsupportedEncodingException if the used encoding is not supported */ From 9f47c8aca8d068480d85d8b48a84c68ea344c7c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20Fenyvesi?= Date: Fri, 8 May 2020 12:51:01 +0200 Subject: [PATCH 09/10] narrow down UIServer.stop() to declare throwing InterruptedException; fix typo in test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Tamás Fenyvesi --- .../src/main/java/org/deeplearning4j/ui/api/UIServer.java | 3 ++- .../src/test/java/org/deeplearning4j/ui/TestVertxUIManual.java | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/api/UIServer.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/api/UIServer.java index 81180990f..2bc98c20c 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/api/UIServer.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/api/UIServer.java @@ -157,8 +157,9 @@ public interface UIServer { /** * Stop/shut down the UI server. This synchronous function should wait until the server is stopped. + * @throws InterruptedException if the current thread is interrupted while waiting */ - void stop() throws Exception; + void stop() throws InterruptedException; /** * Stop/shut down the UI server. diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIManual.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIManual.java index 582b2f1ac..7fb0a041e 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIManual.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIManual.java @@ -259,7 +259,7 @@ public class TestVertxUIManual extends BaseDL4JTest { log.info("Auto-detaching StatsStorage (session ID: {}) after {} ms.", sessionId, autoDetachTimeoutMillis); uIServer.detach(statsStorage); - log.info(" To re-attach StatsStorage of training session, visit {}}/train/{}", + log.info(" To re-attach StatsStorage of training session, visit {}/train/{}", uIServer.getAddress(), sessionId); } }).start(); From ea56ce747138e413d8fdd953b9994b92f973fbfd Mon Sep 17 00:00:00 2001 From: Alex Black Date: Mon, 11 May 2020 12:25:31 +1000 Subject: [PATCH 10/10] Import fix given recent refactoring Signed-off-by: Alex Black --- .../java/org/deeplearning4j/arbiter/optimize/TestBasic.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java b/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java index 08f130369..3ecefe0b3 100644 --- a/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java +++ b/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java @@ -57,12 +57,12 @@ import org.deeplearning4j.ui.api.UIServer; import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage; import org.junit.Ignore; import org.junit.Test; +import org.nd4j.common.function.Function; import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.function.Function; import org.nd4j.linalg.lossfunctions.LossFunctions; import java.io.File;