diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/VertxUIServer.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/VertxUIServer.java index 64b033133..77382d6f4 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/VertxUIServer.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/VertxUIServer.java @@ -19,6 +19,8 @@ package org.deeplearning4j.ui; import com.beust.jcommander.JCommander; import com.beust.jcommander.Parameter; import io.vertx.core.AbstractVerticle; +import io.vertx.core.Future; +import io.vertx.core.Promise; import io.vertx.core.Vertx; import io.vertx.core.http.HttpServer; import io.vertx.core.http.impl.MimeMapping; @@ -35,6 +37,7 @@ import org.deeplearning4j.api.storage.StatsStorageEvent; import org.deeplearning4j.api.storage.StatsStorageListener; import org.deeplearning4j.api.storage.StatsStorageRouter; import org.deeplearning4j.config.DL4JSystemProperties; +import org.deeplearning4j.exception.DL4JException; import org.deeplearning4j.ui.api.Route; import org.deeplearning4j.ui.api.UIModule; import org.deeplearning4j.ui.api.UIServer; @@ -72,35 +75,69 @@ public class VertxUIServer extends AbstractVerticle implements UIServer { private static Function statsStorageProvider; private static Integer instancePort; + private static Thread autoStopThread; private TrainModule trainModule; - public static VertxUIServer getInstance() { - return getInstance(null, multiSession.get(), null); + /** + * Get (and, initialize if necessary) the UI server. This synchronous function will wait until the server started. + * @param port TCP socket port for {@link HttpServer} to listen + * @param multiSession in multi-session mode, multiple training sessions can be visualized in separate browser tabs. + *
URL path will include session ID as a parameter, i.e.: /train becomes /train/:sessionId + * @param statsStorageProvider function that returns a StatsStorage containing the given session ID. + *
Use this to auto-attach StatsStorage if an unknown session ID is passed + * as URL path parameter in multi-session mode, or leave it {@code null}. + * @return UI instance for this JVM + * @throws DL4JException if UI server failed to start; + * if the instance has already started in a different mode (multi/single-session); + * if interrupted while waiting for completion + */ + public static VertxUIServer getInstance(Integer port, boolean multiSession, + Function statsStorageProvider) throws DL4JException { + return getInstance(port, multiSession, statsStorageProvider, null); } - public static VertxUIServer getInstance(Integer port, boolean multiSession, Function statsStorageProvider){ + /** + * + * Get (and, initialize if necessary) the UI server. This function will wait until the server started + * (synchronous way), or pass the given callback to handle success or failure (asynchronous way). + * @param port TCP socket port for {@link HttpServer} to listen + * @param multiSession in multi-session mode, multiple training sessions can be visualized in separate browser tabs. + *
URL path will include session ID as a parameter, i.e.: /train becomes /train/:sessionId + * @param statsStorageProvider function that returns a StatsStorage containing the given session ID. + *
Use this to auto-attach StatsStorage if an unknown session ID is passed + * as URL path parameter in multi-session mode, or leave it {@code null}. + * @param startCallback asynchronous deployment handler callback that will be notify of success or failure. + * If {@code null} given, then this method will wait until deployment is complete. + * If the deployment is successful the result will contain a String representing the + * unique deployment ID of the deployment. + * @return UI server instance + * @throws DL4JException if UI server failed to start; + * if the instance has already started in a different mode (multi/single-session); + * if interrupted while waiting for completion + */ + public static VertxUIServer getInstance(Integer port, boolean multiSession, + Function statsStorageProvider, Promise startCallback) + throws DL4JException { if (instance == null || instance.isStopped()) { VertxUIServer.multiSession.set(multiSession); VertxUIServer.setStatsStorageProvider(statsStorageProvider); instancePort = port; - Vertx vertx = Vertx.vertx(); - //Launch UI server verticle and wait for it to start - CountDownLatch l = new CountDownLatch(1); - vertx.deployVerticle(VertxUIServer.class.getName(), res -> { - l.countDown(); - }); - try { - l.await(5000, TimeUnit.MILLISECONDS); - } catch (InterruptedException e){ } //Ignore + if (startCallback != null) { + //Launch UI server verticle and pass asynchronous callback that will be notified of completion + deploy(startCallback); + } else { + //Launch UI server verticle and wait for it to start + deploy(); + } } else if (!instance.isStopped()) { - if (instance.multiSession.get() && !instance.isMultiSession()) { - throw new RuntimeException("Cannot return multi-session instance." + + if (multiSession && !instance.isMultiSession()) { + throw new DL4JException("Cannot return multi-session instance." + " UIServer has already started in single-session mode at " + instance.getAddress() + " You may stop the UI server instance, and start a new one."); - } else if (!instance.multiSession.get() && instance.isMultiSession()) { - throw new RuntimeException("Cannot return single-session instance." + + } else if (!multiSession && instance.isMultiSession()) { + throw new DL4JException("Cannot return single-session instance." + " UIServer has already started in multi-session mode at " + instance.getAddress() + " You may stop the UI server instance, and start a new one."); } @@ -109,6 +146,64 @@ public class VertxUIServer extends AbstractVerticle implements UIServer { return instance; } + /** + * Deploy (start) {@link VertxUIServer}, waiting until starting is complete. + * @throws DL4JException if UI server failed to start; + * if interrupted while waiting for completion + */ + private static void deploy() throws DL4JException { + CountDownLatch l = new CountDownLatch(1); + Promise promise = Promise.promise(); + promise.future().compose( + success -> Future.future(prom -> l.countDown()), + failure -> Future.future(prom -> l.countDown()) + ); + deploy(promise); + // synchronous function + try { + l.await(); + } catch (InterruptedException e) { + throw new DL4JException(e); + } + + Future future = promise.future(); + if (future.failed()) { + throw new DL4JException("Deeplearning4j UI server failed to start.", future.cause()); + } + } + + /** + * Deploy (start) {@link VertxUIServer}, + * and pass callback to handle successful or failed completion of deployment. + * @param startCallback promise that will handle success or failure of deployment. + * If the deployment is successful the result will contain a String representing the unique deployment ID of the + * deployment. + */ + private static void deploy(Promise startCallback) { + log.debug("Deeplearning4j UI server is starting."); + Promise promise = Promise.promise(); + promise.future().compose( + success -> Future.future(prom -> startCallback.complete(success)), + failure -> Future.future(prom -> startCallback.fail(new RuntimeException(failure))) + ); + + Vertx vertx = Vertx.vertx(); + vertx.deployVerticle(VertxUIServer.class.getName(), promise); + + Thread currentThread = Thread.currentThread(); + VertxUIServer.autoStopThread = new Thread(() -> { + try { + currentThread.join(); + log.info("Deeplearning4j UI server is auto-stopping."); + if (VertxUIServer.instance != null && !VertxUIServer.instance.isStopped()) { + instance.stop(); + } + } catch (InterruptedException e) { + log.error("Deeplearning4j UI server auto-stop thread was interrupted.", e); + } + }); + } + private List uiModules = new CopyOnWriteArrayList<>(); private RemoteReceiverModule remoteReceiverModule; @@ -132,10 +227,18 @@ public class VertxUIServer extends AbstractVerticle implements UIServer { instance = this; } - public static void stopInstance(){ - if(instance == null) + public static void stopInstance() throws Exception { + if(instance == null || instance.isStopped()) return; instance.stop(); + VertxUIServer.reset(); + } + + private static void reset() { + VertxUIServer.instance = null; + VertxUIServer.statsStorageProvider = null; + VertxUIServer.instancePort = null; + VertxUIServer.multiSession.set(false); } /** @@ -145,12 +248,14 @@ public class VertxUIServer extends AbstractVerticle implements UIServer { public void autoAttachStatsStorageBySessionId(Function statsStorageProvider) { if (statsStorageProvider != null) { this.statsStorageLoader = new StatsStorageLoader(statsStorageProvider); - this.trainModule.setSessionLoader(this.statsStorageLoader); + if (trainModule != null) { + this.trainModule.setSessionLoader(this.statsStorageLoader); + } } } @Override - public void start() throws Exception { + public void start(Promise startCallback) throws Exception { //Create REST endpoints File uploadDir = new File(System.getProperty("java.io.tmpdir"), "DL4JUI_" + System.currentTimeMillis()); uploadDir.mkdirs(); @@ -192,6 +297,9 @@ public class VertxUIServer extends AbstractVerticle implements UIServer { }); } + if (VertxUIServer.statsStorageProvider != null) { + autoAttachStatsStorageBySessionId(VertxUIServer.statsStorageProvider); + } uiModules.add(new DefaultModule(isMultiSession())); //For: navigation page "/" trainModule = new TrainModule(isMultiSession(), statsStorageLoader, this::getAddress); @@ -246,17 +354,23 @@ public class VertxUIServer extends AbstractVerticle implements UIServer { } } - - server = vertx.createHttpServer() - .requestHandler(r) - .listen(port); - uiEventRoutingThread = new Thread(new StatsEventRouterRunnable()); uiEventRoutingThread.setDaemon(true); uiEventRoutingThread.start(); - String address = UIServer.getInstance().getAddress(); - log.info("Deeplearning4j UI server started at: {}", address); + server = vertx.createHttpServer() + .requestHandler(r) + .listen(port, result -> { + if (result.succeeded()) { + String address = UIServer.getInstance().getAddress(); + log.info("Deeplearning4j UI server started at: {}", address); + startCallback.complete(); + } else { + startCallback.fail(new RuntimeException("Deeplearning4j UI server failed to listen on port " + + server.actualPort(), result.cause())); + } + }); + VertxUIServer.autoStopThread.start(); } private List extractArgsFromRoute(String path, RoutingContext rc) { @@ -302,11 +416,33 @@ public class VertxUIServer extends AbstractVerticle implements UIServer { } @Override - public void stop() { - server.close(); - shutdown.set(true); + public void stop() throws InterruptedException { + CountDownLatch l = new CountDownLatch(1); + Promise promise = Promise.promise(); + promise.future().compose( + successEvent -> Future.future(prom -> l.countDown()), + failureEvent -> Future.future(prom -> l.countDown()) + ); + stopAsync(promise); + // synchronous function should wait until the server is stopped + l.await(); } + @Override + public void stopAsync(Promise stopCallback) { + /** + * Stop Vertx instance and release any resources held by it. + * Pass promise to {@link #stop(Promise)}. + */ + vertx.close(ar -> stopCallback.handle(ar)); + } + + @Override + public void stop(Promise stopCallback) { + shutdown.set(true); + stopCallback.complete(); + log.info("Deeplearning4j UI server stopped."); + } @Override public boolean isStopped() { @@ -353,8 +489,7 @@ public class VertxUIServer extends AbstractVerticle implements UIServer { if (!statsStorageInstances.contains(statsStorage)) return; //No op boolean found = false; - for (Iterator> iterator = listeners.iterator(); iterator.hasNext(); ) { - Pair p = iterator.next(); + for (Pair p : listeners) { if (p.getFirst() == statsStorage) { //Same object, not equality statsStorage.deregisterStatsStorageListener(p.getSecond()); listeners.remove(p); diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/api/UIServer.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/api/UIServer.java index 368863333..e287111c2 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/api/UIServer.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/api/UIServer.java @@ -17,8 +17,10 @@ package org.deeplearning4j.ui.api; +import io.vertx.core.Promise; import org.deeplearning4j.api.storage.StatsStorage; import org.deeplearning4j.api.storage.StatsStorageRouter; +import org.deeplearning4j.exception.DL4JException; import org.deeplearning4j.ui.VertxUIServer; import org.nd4j.linalg.function.Function; @@ -32,18 +34,24 @@ import java.util.List; public interface UIServer { /** - * Get (and, initialize if necessary) the UI server. + * Get (and, initialize if necessary) the UI server. This synchronous function will wait until the server started. * Singleton pattern - all calls to getInstance() will return the same UI instance. * * @return UI instance for this JVM - * @throws RuntimeException if the instance has already started in a different mode (multi/single-session) + * @throws DL4JException if UI server failed to start; + * if the instance has already started in a different mode (multi/single-session); + * if interrupted while waiting for completion */ - static UIServer getInstance() throws RuntimeException { - return getInstance(false, null); + static UIServer getInstance() throws DL4JException { + if (VertxUIServer.getInstance() != null && !VertxUIServer.getInstance().isStopped()) { + return VertxUIServer.getInstance(); + } else { + return getInstance(false, null); + } } /** - * Get (and, initialize if necessary) the UI server. + * Get (and, initialize if necessary) the UI server. This synchronous function will wait until the server started. * Singleton pattern - all calls to getInstance() will return the same UI instance. * * @param multiSession in multi-session mode, multiple training sessions can be visualized in separate browser tabs. @@ -52,16 +60,19 @@ public interface UIServer { *
Use this to auto-attach StatsStorage if an unknown session ID is passed * as URL path parameter in multi-session mode, or leave it {@code null}. * @return UI instance for this JVM - * @throws RuntimeException if the instance has already started in a different mode (multi/single-session) + * @throws DL4JException if UI server failed to start; + * if the instance has already started in a different mode (multi/single-session); + * if interrupted while waiting for completion */ - static UIServer getInstance(boolean multiSession, Function statsStorageProvider) throws RuntimeException { + static UIServer getInstance(boolean multiSession, Function statsStorageProvider) + throws DL4JException { return VertxUIServer.getInstance(null, multiSession, statsStorageProvider); } /** * Stop UIServer instance, if already running */ - static void stopInstance() { + static void stopInstance() throws Exception { VertxUIServer.stopInstance(); } @@ -144,8 +155,15 @@ public interface UIServer { boolean isRemoteListenerEnabled(); /** - * Stop/shut down the UI server. + * Stop/shut down the UI server. This synchronous function should wait until the server is stopped. */ - void stop(); + void stop() throws Exception; + /** + * Stop/shut down the UI server. + * This asynchronous function should immediately return, and notify the callback {@link Promise} on completion: + * either call {@link Promise#complete} or {@link io.vertx.core.Promise#fail}. + * @param stopCallback callback {@link Promise} to notify on completion + */ + void stopAsync(Promise stopCallback); } diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/i18n/DefaultI18N.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/i18n/DefaultI18N.java index dd57017bf..523dcd48e 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/i18n/DefaultI18N.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/i18n/DefaultI18N.java @@ -69,15 +69,19 @@ public class DefaultI18N implements I18N { } /** - * Get instance for session (used in multi-session mode) - * @param sessionId session - * @return instance for session + * Get instance for session + * @param sessionId session ID for multi-session mode, leave it {@code null} for global instance + * @return instance for session, or global instance */ public static synchronized I18N getInstance(String sessionId) { - if (!sessionInstances.containsKey(sessionId)) { - sessionInstances.put(sessionId, new DefaultI18N()); + if (sessionId == null) { + return getInstance(); + } else { + if (!sessionInstances.containsKey(sessionId)) { + sessionInstances.put(sessionId, new DefaultI18N()); + } + return sessionInstances.get(sessionId); } - return sessionInstances.get(sessionId); } /** diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java index a644f0c1d..f90360d26 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java @@ -137,7 +137,7 @@ public class TrainModule implements UIModule { maxChartPoints = DEFAULT_MAX_CHART_POINTS; } - configuration = new Configuration(new Version(2, 3, 29)); + configuration = new Configuration(new Version(2, 3, 23)); configuration.setDefaultEncoding("UTF-8"); configuration.setLocale(Locale.US); configuration.setTemplateExceptionHandler(TemplateExceptionHandler.RETHROW_HANDLER); @@ -199,6 +199,7 @@ public class TrainModule implements UIModule { } })); r.add(new Route("/train/:sessionId/info", HttpMethod.GET, (path, rc) -> this.sessionInfoForSession(path.get(0), rc))); + r.add(new Route("/train/:sessionId/system/data", HttpMethod.GET, (path, rc) -> this.getSystemDataForSession(path.get(0), rc))); } else { r.add(new Route("/train", HttpMethod.GET, (path, rc) -> rc.reroute("/train/overview"))); r.add(new Route("/train/sessions/current", HttpMethod.GET, (path, rc) -> rc.response().end(currentSessionID == null ? "" : currentSessionID))); @@ -226,7 +227,9 @@ public class TrainModule implements UIModule { * @param rc Routing context */ private void renderFtl(String file, RoutingContext rc) { - Map input = DefaultI18N.getInstance().getMessages(DefaultI18N.getInstance().getDefaultLanguage()); + String sessionId = rc.request().getParam("sessionID"); + String langCode = DefaultI18N.getInstance(sessionId).getDefaultLanguage(); + Map input = DefaultI18N.getInstance().getMessages(langCode); String html; try { String content = FileUtils.readFileToString(Resources.asFile("templates/" + file), StandardCharsets.UTF_8); diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUI.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUI.java index 43e6c76df..7f11ef970 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUI.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUI.java @@ -17,10 +17,15 @@ package org.deeplearning4j.ui; +import io.vertx.core.Future; +import io.vertx.core.Promise; +import io.vertx.core.Vertx; +import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.IOUtils; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.api.storage.StatsStorage; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; +import org.deeplearning4j.exception.DL4JException; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; @@ -50,63 +55,31 @@ import java.net.URL; import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; import static org.junit.Assert.*; /** * Created by Alex on 08/10/2016. */ +@Slf4j @Ignore public class TestVertxUI extends BaseDL4JTest { + @Before public void setUp() throws Exception { UIServer.stopInstance(); } @Test - @Ignore public void testUI() throws Exception { - - StatsStorage ss = new InMemoryStatsStorage(); - VertxUIServer uiServer = (VertxUIServer) UIServer.getInstance(); assertEquals(9000, uiServer.getPort()); uiServer.stop(); - VertxUIServer vertxUIServer = new VertxUIServer(); -// vertxUIServer.runMain(new String[] {"--uiPort", "9100", "-r", "true"}); -// -// assertEquals(9100, vertxUIServer.getPort()); -// vertxUIServer.stop(); - - - // uiServer.attach(ss); - // - // MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - // .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - // .list() - // .layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(4).build()) - // .layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(4).nOut(3).build()) - // .build(); - // - // MultiLayerNetwork net = new MultiLayerNetwork(conf); - // net.init(); - // net.setListeners(new StatsListener(ss, 3), new ScoreIterationListener(1)); - // - // DataSetIterator iter = new IrisDataSetIterator(150, 150); - // - // for (int i = 0; i < 500; i++) { - // net.fit(iter); - //// Thread.sleep(100); - // Thread.sleep(100); - // } - // - //// uiServer.stop(); - - Thread.sleep(100000); } @Test - @Ignore public void testUI_VAE() throws Exception { //Variational autoencoder - for unsupervised layerwise pretraining @@ -144,13 +117,9 @@ public class TestVertxUI extends BaseDL4JTest { Thread.sleep(100); } - - Thread.sleep(100000); } - @Test - @Ignore public void testUIMultipleSessions() throws Exception { for (int session = 0; session < 3; session++) { @@ -178,60 +147,10 @@ public class TestVertxUI extends BaseDL4JTest { Thread.sleep(100); } } - - - Thread.sleep(1000000); - } - - @Test - @Ignore - public void testUISequentialSessions() throws Exception { - UIServer uiServer = UIServer.getInstance(); - StatsStorage ss = null; - for (int session = 0; session < 3; session++) { - - if (ss != null) { - uiServer.detach(ss); - } - ss = new InMemoryStatsStorage(); - uiServer.attach(ss); - - int numInputs = 4; - int outputNum = 3; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .activation(Activation.TANH) - .weightInit(WeightInit.XAVIER) - .updater(new Sgd(0.03)) - .l2(1e-4) - .list() - .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(3) - .build()) - .layer(1, new DenseLayer.Builder().nIn(3).nOut(3) - .build()) - .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) - .activation(Activation.SOFTMAX) - .nIn(3).nOut(outputNum).build()) - .build(); - - MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.init(); - net.setListeners(new StatsListener(ss), new ScoreIterationListener(1)); - - DataSetIterator iter = new IrisDataSetIterator(150, 150); - - for (int i = 0; i < 1000; i++) { - net.fit(iter); - } - Thread.sleep(5000); - } - - - Thread.sleep(1000000); } @Test - @Ignore - public void testUICompGraph() throws Exception { + public void testUICompGraph() { StatsStorage ss = new InMemoryStatsStorage(); @@ -254,10 +173,7 @@ public class TestVertxUI extends BaseDL4JTest { for (int i = 0; i < 100; i++) { net.fit(iter); - Thread.sleep(100); } - - Thread.sleep(1000000); } @Test @@ -304,11 +220,11 @@ public class TestVertxUI extends BaseDL4JTest { } }); - String json1 = IOUtils.toString(new URL("http://localhost:9000/train/ss1/overview/data"), StandardCharsets.UTF_8); -// System.out.println(json1); + String json1 = IOUtils.toString(new URL("http://localhost:9000/train/ss1/overview/data"), + StandardCharsets.UTF_8); - String json2 = IOUtils.toString(new URL("http://localhost:9000/train/ss2/overview/data"), StandardCharsets.UTF_8); -// System.out.println(json2); + String json2 = IOUtils.toString(new URL("http://localhost:9000/train/ss2/overview/data"), + StandardCharsets.UTF_8); assertNotEquals(json1, json2); @@ -336,11 +252,106 @@ public class TestVertxUI extends BaseDL4JTest { } @Test - public void testUIServerStop() { + public void testUIServerStop() throws Exception { UIServer uiServer = UIServer.getInstance(true, null); assertTrue(uiServer.isMultiSession()); + assertFalse(uiServer.isStopped()); + + long sleepMilliseconds = 1_000; + log.info("Waiting {} ms before stopping.", sleepMilliseconds); + Thread.sleep(sleepMilliseconds); uiServer.stop(); + assertTrue(uiServer.isStopped()); + + log.info("UI server is stopped. Waiting {} ms before starting new UI server.", sleepMilliseconds); + Thread.sleep(sleepMilliseconds); uiServer = UIServer.getInstance(false, null); assertFalse(uiServer.isMultiSession()); + assertFalse(uiServer.isStopped()); + + log.info("Waiting {} ms before stopping.", sleepMilliseconds); + Thread.sleep(sleepMilliseconds); + uiServer.stop(); + assertTrue(uiServer.isStopped()); + } + + + @Test + public void testUIServerStopAsync() throws Exception { + UIServer uiServer = UIServer.getInstance(true, null); + assertTrue(uiServer.isMultiSession()); + assertFalse(uiServer.isStopped()); + + long sleepMilliseconds = 1_000; + log.info("Waiting {} ms before stopping.", sleepMilliseconds); + Thread.sleep(sleepMilliseconds); + + CountDownLatch latch = new CountDownLatch(1); + Promise promise = Promise.promise(); + promise.future().compose( + success -> Future.future(prom -> latch.countDown()), + failure -> Future.future(prom -> latch.countDown()) + ); + + uiServer.stopAsync(promise); + latch.await(); + assertTrue(uiServer.isStopped()); + + log.info("UI server is stopped. Waiting {} ms before starting new UI server.", sleepMilliseconds); + Thread.sleep(sleepMilliseconds); + uiServer = UIServer.getInstance(false, null); + assertFalse(uiServer.isMultiSession()); + + log.info("Waiting {} ms before stopping.", sleepMilliseconds); + Thread.sleep(sleepMilliseconds); + uiServer.stop(); + } + + @Test (expected = DL4JException.class) + public void testUIStartPortAlreadyBound() throws InterruptedException { + CountDownLatch latch = new CountDownLatch(1); + //Create HttpServer that binds the same port + int port = VertxUIServer.DEFAULT_UI_PORT; + Vertx vertx = Vertx.vertx(); + vertx.createHttpServer() + .requestHandler(event -> {}) + .listen(port, result -> latch.countDown()); + latch.await(); + + try { + //DL4JException signals that the port cannot be bound, UI server cannot start + UIServer.getInstance(); + } finally { + vertx.close(); + } + } + + @Test + public void testUIStartAsync() throws InterruptedException { + CountDownLatch latch = new CountDownLatch(1); + Promise promise = Promise.promise(); + promise.future().compose( + success -> Future.future(prom -> latch.countDown()), + failure -> Future.future(prom -> latch.countDown()) + ); + int port = VertxUIServer.DEFAULT_UI_PORT; + VertxUIServer.getInstance(port, false, null, promise); + latch.await(); + if (promise.future().succeeded()) { + String deploymentId = promise.future().result(); + log.debug("UI server deployed, deployment ID = {}", deploymentId); + } else { + log.debug("UI server failed to deploy.", promise.future().cause()); + } + } + + @Test + public void testUIAutoStopOnThreadExit() throws InterruptedException { + AtomicReference uiServer = new AtomicReference<>(); + Thread thread = new Thread(() -> uiServer.set(UIServer.getInstance())); + thread.start(); + thread.join(); + Thread.sleep(1_000); + assertTrue(uiServer.get().isStopped()); } } diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIManual.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIManual.java new file mode 100644 index 000000000..befd8a093 --- /dev/null +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIManual.java @@ -0,0 +1,272 @@ +package org.deeplearning4j.ui; + +import io.netty.handler.codec.http.HttpResponseStatus; +import io.vertx.core.Future; +import io.vertx.core.Promise; +import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.BaseDL4JTest; +import org.deeplearning4j.api.storage.StatsStorage; +import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; +import org.deeplearning4j.nn.api.OptimizationAlgorithm; +import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.DenseLayer; +import org.deeplearning4j.nn.conf.layers.OutputLayer; +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.optimize.listeners.ScoreIterationListener; +import org.deeplearning4j.ui.api.UIServer; +import org.deeplearning4j.ui.stats.StatsListener; +import org.deeplearning4j.ui.storage.InMemoryStatsStorage; +import org.junit.Ignore; +import org.junit.Test; +import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; +import org.nd4j.linalg.function.Function; +import org.nd4j.linalg.learning.config.Sgd; +import org.nd4j.linalg.lossfunctions.LossFunctions; + +import java.io.UnsupportedEncodingException; +import java.net.HttpURLConnection; +import java.net.URL; +import java.net.URLEncoder; +import java.util.HashMap; +import java.util.concurrent.CountDownLatch; + +import static org.junit.Assert.*; + +@Slf4j +@Ignore +public class TestVertxUIManual extends BaseDL4JTest { + + @Override + public long getTimeoutMilliseconds() { + return 3600_000L; + } + + @Test + @Ignore + public void testUI() throws Exception { + VertxUIServer uiServer = (VertxUIServer) UIServer.getInstance(); + assertEquals(9000, uiServer.getPort()); + + Thread.sleep(3000_000); + uiServer.stop(); + } + + @Test + @Ignore + public void testUISequentialSessions() throws Exception { + UIServer uiServer = UIServer.getInstance(); + StatsStorage ss = null; + for (int session = 0; session < 3; session++) { + + if (ss != null) { + uiServer.detach(ss); + } + ss = new InMemoryStatsStorage(); + uiServer.attach(ss); + + int numInputs = 4; + int outputNum = 3; + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .activation(Activation.TANH) + .weightInit(WeightInit.XAVIER) + .updater(new Sgd(0.03)) + .l2(1e-4) + .list() + .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(3) + .build()) + .layer(1, new DenseLayer.Builder().nIn(3).nOut(3) + .build()) + .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) + .activation(Activation.SOFTMAX) + .nIn(3).nOut(outputNum).build()) + .build(); + + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + net.setListeners(new StatsListener(ss), new ScoreIterationListener(1)); + + DataSetIterator iter = new IrisDataSetIterator(150, 150); + + for (int i = 0; i < 100; i++) { + net.fit(iter); + } + Thread.sleep(5_000); + } + } + + @Test + @Ignore + public void testUIServerStop() throws Exception { + UIServer uiServer = UIServer.getInstance(true, null); + assertTrue(uiServer.isMultiSession()); + assertFalse(uiServer.isStopped()); + + long sleepMilliseconds = 30_000; + log.info("Waiting {} ms before stopping.", sleepMilliseconds); + Thread.sleep(sleepMilliseconds); + uiServer.stop(); + assertTrue(uiServer.isStopped()); + + log.info("UI server is stopped. Waiting {} ms before starting new UI server.", sleepMilliseconds); + Thread.sleep(sleepMilliseconds); + uiServer = UIServer.getInstance(false, null); + assertFalse(uiServer.isMultiSession()); + assertFalse(uiServer.isStopped()); + + log.info("Waiting {} ms before stopping.", sleepMilliseconds); + Thread.sleep(sleepMilliseconds); + uiServer.stop(); + assertTrue(uiServer.isStopped()); + } + + + @Test + @Ignore + public void testUIServerStopAsync() throws Exception { + UIServer uiServer = UIServer.getInstance(true, null); + assertTrue(uiServer.isMultiSession()); + assertFalse(uiServer.isStopped()); + + long sleepMilliseconds = 30_000; + log.info("Waiting {} ms before stopping.", sleepMilliseconds); + Thread.sleep(sleepMilliseconds); + + CountDownLatch latch = new CountDownLatch(1); + Promise promise = Promise.promise(); + promise.future().compose( + success -> Future.future(prom -> latch.countDown()), + failure -> Future.future(prom -> latch.countDown()) + ); + + uiServer.stopAsync(promise); + latch.await(); + assertTrue(uiServer.isStopped()); + + log.info("UI server is stopped. Waiting {} ms before starting new UI server.", sleepMilliseconds); + Thread.sleep(sleepMilliseconds); + uiServer = UIServer.getInstance(false, null); + assertFalse(uiServer.isMultiSession()); + + log.info("Waiting {} ms before stopping.", sleepMilliseconds); + Thread.sleep(sleepMilliseconds); + uiServer.stop(); + } + + @Test + @Ignore + public void testUIAutoAttachDetach() throws Exception { + long detachTimeoutMillis = 15_000; + AutoDetachingStatsStorageProvider statsProvider = new AutoDetachingStatsStorageProvider(detachTimeoutMillis); + UIServer uIServer = UIServer.getInstance(true, statsProvider); + statsProvider.setUIServer(uIServer); + InMemoryStatsStorage ss = null; + for (int session = 0; session < 3; session++) { + int layerSize = session + 4; + + ss = new InMemoryStatsStorage(); + String sessionId = Integer.toString(session); + statsProvider.put(sessionId, ss); + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list() + .layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(layerSize).build()) + .layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT) + .activation(Activation.SOFTMAX).nIn(layerSize).nOut(3).build()) + .build(); + + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + + StatsListener statsListener = new StatsListener(ss, 1); + statsListener.setSessionID(sessionId); + net.setListeners(statsListener, new ScoreIterationListener(1)); + uIServer.attach(ss); + + DataSetIterator iter = new IrisDataSetIterator(150, 150); + + for (int i = 0; i < 20; i++) { + net.fit(iter); + } + + assertTrue(uIServer.isAttached(ss)); + uIServer.detach(ss); + assertFalse(uIServer.isAttached(ss)); + + /* + * Visiting /train/:sessionId to auto-attach StatsStorage + */ + String sessionUrl = trainingSessionUrl(uIServer.getAddress(), sessionId); + HttpURLConnection conn = (HttpURLConnection) new URL(sessionUrl).openConnection(); + conn.connect(); + + assertEquals(HttpResponseStatus.OK.code(), conn.getResponseCode()); + assertTrue(uIServer.isAttached(ss)); + } + + Thread.sleep(detachTimeoutMillis + 60_000); + assertFalse(uIServer.isAttached(ss)); + } + + + /** + * Get URL-encoded URL for training session on given server address + * @param serverAddress server address + * @param sessionId session ID + * @return URL + * @throws UnsupportedEncodingException if the used encoding is not supported + */ + private static String trainingSessionUrl(String serverAddress, String sessionId) + throws UnsupportedEncodingException { + return String.format("%s/train/%s", serverAddress, URLEncoder.encode(sessionId, "UTF-8")); + } + + /** + * StatsStorage provider with automatic detaching of StatsStorage after a timeout + * @author Tamas Fenyvesi + */ + private static class AutoDetachingStatsStorageProvider implements Function { + HashMap storageForSession = new HashMap<>(); + UIServer uIServer; + long autoDetachTimeoutMillis; + + public AutoDetachingStatsStorageProvider(long autoDetachTimeoutMillis) { + this.autoDetachTimeoutMillis = autoDetachTimeoutMillis; + } + + public void put(String sessionId, InMemoryStatsStorage statsStorage) { + storageForSession.put(sessionId, statsStorage); + } + + public void setUIServer(UIServer uIServer) { + this.uIServer = uIServer; + } + + @Override + public StatsStorage apply(String sessionId) { + StatsStorage statsStorage = storageForSession.get(sessionId); + + if (statsStorage != null) { + new Thread(() -> { + try { + log.info("Waiting to detach StatsStorage (session ID: {})" + + " after {} ms ", sessionId, autoDetachTimeoutMillis); + Thread.sleep(autoDetachTimeoutMillis); + } catch (InterruptedException e) { + e.printStackTrace(); + } finally { + log.info("Auto-detaching StatsStorage (session ID: {}) after {} ms.", + sessionId, autoDetachTimeoutMillis); + uIServer.detach(statsStorage); + log.info(" To re-attach StatsStorage of training session, visit {}}/train/{}", + uIServer.getAddress(), sessionId); + } + }).start(); + } + + return statsStorage; + } + } + +} diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIMultiSession.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIMultiSession.java index 0ca870107..b0be7beb7 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIMultiSession.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIMultiSession.java @@ -22,6 +22,7 @@ import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.api.storage.StatsStorage; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; +import org.deeplearning4j.exception.DL4JException; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; @@ -33,7 +34,6 @@ import org.deeplearning4j.ui.api.UIServer; import org.deeplearning4j.ui.stats.StatsListener; import org.deeplearning4j.ui.storage.InMemoryStatsStorage; import org.junit.Before; -import org.junit.Ignore; import org.junit.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; @@ -53,22 +53,22 @@ import static org.junit.Assert.*; /** * @author Tamas Fenyvesi */ -@Ignore @Slf4j public class TestVertxUIMultiSession extends BaseDL4JTest { + @Before public void setUp() throws Exception { UIServer.stopInstance(); } @Test - public void testUIMultiSession() throws Exception { - + public void testUIMultiSessionParallelTraining() throws Exception { UIServer uIServer = UIServer.getInstance(true, null); HashMap statStorageForThread = new HashMap<>(); HashMap sessionIdForThread = new HashMap<>(); - for (int session = 0; session < 10; session++) { + int parallelTrainingCount = 10; + for (int session = 0; session < parallelTrainingCount; session++) { StatsStorage ss = new InMemoryStatsStorage(); @@ -106,8 +106,6 @@ public class TestVertxUIMultiSession extends BaseDL4JTest { sessionIdForThread.put(training, sessionId); } - Thread.sleep(10000000); - for (Thread thread: statStorageForThread.keySet()) { StatsStorage ss = statStorageForThread.get(thread); String sessionId = sessionIdForThread.get(thread); @@ -122,7 +120,7 @@ public class TestVertxUIMultiSession extends BaseDL4JTest { assertEquals(HttpResponseStatus.OK.code(), conn.getResponseCode()); assertTrue(uIServer.isAttached(ss)); - } catch (InterruptedException | IOException e) { + } catch (IOException e) { log.error("",e); fail(e.getMessage()); } finally { @@ -130,7 +128,6 @@ public class TestVertxUIMultiSession extends BaseDL4JTest { assertFalse(uIServer.isAttached(ss)); } } - } @Test @@ -183,61 +180,7 @@ public class TestVertxUIMultiSession extends BaseDL4JTest { } } - @Test - @Ignore - public void testUIAutoAttachDetach() throws Exception { - - long autoDetachTimeoutMillis = 30_000; - AutoDetachingStatsStorageProvider statsProvider = new AutoDetachingStatsStorageProvider(autoDetachTimeoutMillis); - UIServer uIServer = UIServer.getInstance(true, statsProvider); - statsProvider.setUIServer(uIServer); - - for (int session = 0; session < 3; session++) { - int layerSize = session + 4; - - InMemoryStatsStorage ss = new InMemoryStatsStorage(); - String sessionId = Integer.toString(session); - statsProvider.put(sessionId, ss); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list() - .layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(layerSize).build()) - .layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(layerSize).nOut(3).build()) - .build(); - - MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.init(); - - StatsListener statsListener = new StatsListener(ss, 1); - statsListener.setSessionID(sessionId); - net.setListeners(statsListener, new ScoreIterationListener(1)); - uIServer.attach(ss); - - DataSetIterator iter = new IrisDataSetIterator(150, 150); - - for (int i = 0; i < 20; i++) { - net.fit(iter); - } - - assertTrue(uIServer.isAttached(ss)); - uIServer.detach(ss); - assertFalse(uIServer.isAttached(ss)); - - /* - * Visiting /train/:sessionId to auto-attach StatsStorage - */ - String sessionUrl = trainingSessionUrl(uIServer.getAddress(), sessionId); - HttpURLConnection conn = (HttpURLConnection) new URL(sessionUrl).openConnection(); - conn.connect(); - - assertEquals(HttpResponseStatus.OK.code(), conn.getResponseCode()); - assertTrue(uIServer.isAttached(ss)); - } - - Thread.sleep(1_000_000); - } - - @Test (expected = RuntimeException.class) + @Test (expected = DL4JException.class) public void testUIServerGetInstanceMultipleCalls1() { UIServer uiServer = UIServer.getInstance(); assertFalse(uiServer.isMultiSession()); @@ -245,7 +188,7 @@ public class TestVertxUIMultiSession extends BaseDL4JTest { } - @Test (expected = RuntimeException.class) + @Test (expected = DL4JException.class) public void testUIServerGetInstanceMultipleCalls2() { UIServer uiServer = UIServer.getInstance(true, null); assertTrue(uiServer.isMultiSession()); @@ -259,55 +202,8 @@ public class TestVertxUIMultiSession extends BaseDL4JTest { * @return URL * @throws UnsupportedEncodingException if the used encoding is not supported */ - private static String trainingSessionUrl(String serverAddress, String sessionId) throws UnsupportedEncodingException { + private static String trainingSessionUrl(String serverAddress, String sessionId) + throws UnsupportedEncodingException { return String.format("%s/train/%s", serverAddress, URLEncoder.encode(sessionId, "UTF-8")); } - - /** - * StatsStorage provider with automatic detaching of StatsStorage after a timeout - * @author fenyvesit - * - */ - private static class AutoDetachingStatsStorageProvider implements Function { - HashMap storageForSession = new HashMap<>(); - UIServer uIServer; - long autoDetachTimeoutMillis; - - public AutoDetachingStatsStorageProvider(long autoDetachTimeoutMillis) { - this.autoDetachTimeoutMillis = autoDetachTimeoutMillis; - } - - public void put(String sessionId, InMemoryStatsStorage statsStorage) { - storageForSession.put(sessionId, statsStorage); - } - - public void setUIServer(UIServer uIServer) { - this.uIServer = uIServer; - } - - @Override - public StatsStorage apply(String sessionId) { - StatsStorage statsStorage = storageForSession.get(sessionId); - - if (statsStorage != null) { - new Thread(() -> { - try { - System.out.println("Waiting to detach StatsStorage (session ID: " + sessionId + ")" + - " after " + autoDetachTimeoutMillis + " ms "); - Thread.sleep(autoDetachTimeoutMillis); - } catch (InterruptedException e) { - log.error("",e); - } finally { - System.out.println("Auto-detaching StatsStorage (session ID: " + sessionId + ") after " + - autoDetachTimeoutMillis + " ms."); - uIServer.detach(statsStorage); - System.out.println(" To re-attach StatsStorage of training session, visit " + - uIServer.getAddress() + "/train/" + sessionId); - } - }).start(); - } - - return statsStorage; - } - } }