Fix UIServer features in multi-session mode, synchronous start and stop (#8856)

* fix Freemarker version mismatch: change version requested in TrainModule to 2.3.23 (freemarker.version in pom.xml)

Signed-off-by: Tamás Fenyvesi <tamas.fenyvesi@doknet.hu>

* fix checking multi-session mode in VertxUIServer.getInstance. Tested multiple calls in TestVertxUIMultiSession.

Signed-off-by: Tamás Fenyvesi <tamas.fenyvesi@doknet.hu>

* fix UIServer.getInstance() to return existing instance

Signed-off-by: Tamás Fenyvesi <tamas.fenyvesi@doknet.hu>

* extend timeout for manual UI tests from 30 to 600 seconds

Signed-off-by: Tamás Fenyvesi <tamas.fenyvesi@doknet.hu>

* start and stop UI server synchronously (wait until complete), tests

Signed-off-by: Tamás Fenyvesi <tamas.fenyvesi@doknet.hu>

* fix for auto-attaching StatsStorage given in VertxUIServer#getInstance(Integer, boolean, Function<String,StatsStorage>), test improvements

Signed-off-by: Tamás Fenyvesi <tamas.fenyvesi@doknet.hu>

* exception handling, test improvements

Signed-off-by: Tamás Fenyvesi <tamas.fenyvesi@doknet.hu>

* add asynchronous method to start UI server

Signed-off-by: Tamás Fenyvesi <tamas.fenyvesi@doknet.hu>

* fix UIServer.getInstance() to return existing instance

Signed-off-by: Tamás Fenyvesi <tamas.fenyvesi@doknet.hu>

* fix UI server language setting in multi-session mode

Signed-off-by: Tamás Fenyvesi <tamas.fenyvesi@doknet.hu>

* fix UI server system tab not loading data in multi-session mode

Signed-off-by: Tamás Fenyvesi <tamas.fenyvesi@doknet.hu>

* undo added InterruptedException in UIServer.getInstance()

Signed-off-by: Tamás Fenyvesi <tamas.fenyvesi@doknet.hu>

* fix async stopping of UIServer.stopAsync(Promise<Void>), added test

Signed-off-by: Tamás Fenyvesi <tamas.fenyvesi@doknet.hu>

* restore the daemon thread style behaviour of UIServer: don't keep the process alive just because the UI is running

Signed-off-by: Tamás Fenyvesi <tamas.fenyvesi@doknet.hu>

* speed up and don't @Ignore tests in TestVertxUI and TestVertxUIMultiSession, put longer tests to separate class TestVertxUIManual

Signed-off-by: Tamás Fenyvesi <tamas.fenyvesi@doknet.hu>
master
Tamás Fenyvesi 2020-04-23 02:26:51 +02:00 committed by GitHub
parent 8eccc170ec
commit 722d5a052a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 604 additions and 262 deletions

View File

@ -19,6 +19,8 @@ package org.deeplearning4j.ui;
import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter;
import io.vertx.core.AbstractVerticle;
import io.vertx.core.Future;
import io.vertx.core.Promise;
import io.vertx.core.Vertx;
import io.vertx.core.http.HttpServer;
import io.vertx.core.http.impl.MimeMapping;
@ -35,6 +37,7 @@ import org.deeplearning4j.api.storage.StatsStorageEvent;
import org.deeplearning4j.api.storage.StatsStorageListener;
import org.deeplearning4j.api.storage.StatsStorageRouter;
import org.deeplearning4j.config.DL4JSystemProperties;
import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.ui.api.Route;
import org.deeplearning4j.ui.api.UIModule;
import org.deeplearning4j.ui.api.UIServer;
@ -72,35 +75,69 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
private static Function<String, StatsStorage> statsStorageProvider;
private static Integer instancePort;
private static Thread autoStopThread;
private TrainModule trainModule;
public static VertxUIServer getInstance() {
return getInstance(null, multiSession.get(), null);
/**
* Get (and, initialize if necessary) the UI server. This synchronous function will wait until the server started.
* @param port TCP socket port for {@link HttpServer} to listen
* @param multiSession in multi-session mode, multiple training sessions can be visualized in separate browser tabs.
* <br/>URL path will include session ID as a parameter, i.e.: /train becomes /train/:sessionId
* @param statsStorageProvider function that returns a StatsStorage containing the given session ID.
* <br/>Use this to auto-attach StatsStorage if an unknown session ID is passed
* as URL path parameter in multi-session mode, or leave it {@code null}.
* @return UI instance for this JVM
* @throws DL4JException if UI server failed to start;
* if the instance has already started in a different mode (multi/single-session);
* if interrupted while waiting for completion
*/
public static VertxUIServer getInstance(Integer port, boolean multiSession,
Function<String, StatsStorage> statsStorageProvider) throws DL4JException {
return getInstance(port, multiSession, statsStorageProvider, null);
}
public static VertxUIServer getInstance(Integer port, boolean multiSession, Function<String, StatsStorage> statsStorageProvider){
/**
*
* Get (and, initialize if necessary) the UI server. This function will wait until the server started
* (synchronous way), or pass the given callback to handle success or failure (asynchronous way).
* @param port TCP socket port for {@link HttpServer} to listen
* @param multiSession in multi-session mode, multiple training sessions can be visualized in separate browser tabs.
* <br/>URL path will include session ID as a parameter, i.e.: /train becomes /train/:sessionId
* @param statsStorageProvider function that returns a StatsStorage containing the given session ID.
* <br/>Use this to auto-attach StatsStorage if an unknown session ID is passed
* as URL path parameter in multi-session mode, or leave it {@code null}.
* @param startCallback asynchronous deployment handler callback that will be notify of success or failure.
* If {@code null} given, then this method will wait until deployment is complete.
* If the deployment is successful the result will contain a String representing the
* unique deployment ID of the deployment.
* @return UI server instance
* @throws DL4JException if UI server failed to start;
* if the instance has already started in a different mode (multi/single-session);
* if interrupted while waiting for completion
*/
public static VertxUIServer getInstance(Integer port, boolean multiSession,
Function<String, StatsStorage> statsStorageProvider, Promise<String> startCallback)
throws DL4JException {
if (instance == null || instance.isStopped()) {
VertxUIServer.multiSession.set(multiSession);
VertxUIServer.setStatsStorageProvider(statsStorageProvider);
instancePort = port;
Vertx vertx = Vertx.vertx();
//Launch UI server verticle and wait for it to start
CountDownLatch l = new CountDownLatch(1);
vertx.deployVerticle(VertxUIServer.class.getName(), res -> {
l.countDown();
});
try {
l.await(5000, TimeUnit.MILLISECONDS);
} catch (InterruptedException e){ } //Ignore
if (startCallback != null) {
//Launch UI server verticle and pass asynchronous callback that will be notified of completion
deploy(startCallback);
} else {
//Launch UI server verticle and wait for it to start
deploy();
}
} else if (!instance.isStopped()) {
if (instance.multiSession.get() && !instance.isMultiSession()) {
throw new RuntimeException("Cannot return multi-session instance." +
if (multiSession && !instance.isMultiSession()) {
throw new DL4JException("Cannot return multi-session instance." +
" UIServer has already started in single-session mode at " + instance.getAddress() +
" You may stop the UI server instance, and start a new one.");
} else if (!instance.multiSession.get() && instance.isMultiSession()) {
throw new RuntimeException("Cannot return single-session instance." +
} else if (!multiSession && instance.isMultiSession()) {
throw new DL4JException("Cannot return single-session instance." +
" UIServer has already started in multi-session mode at " + instance.getAddress() +
" You may stop the UI server instance, and start a new one.");
}
@ -109,6 +146,64 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
return instance;
}
/**
* Deploy (start) {@link VertxUIServer}, waiting until starting is complete.
* @throws DL4JException if UI server failed to start;
* if interrupted while waiting for completion
*/
private static void deploy() throws DL4JException {
CountDownLatch l = new CountDownLatch(1);
Promise<String> promise = Promise.promise();
promise.future().compose(
success -> Future.future(prom -> l.countDown()),
failure -> Future.future(prom -> l.countDown())
);
deploy(promise);
// synchronous function
try {
l.await();
} catch (InterruptedException e) {
throw new DL4JException(e);
}
Future<String> future = promise.future();
if (future.failed()) {
throw new DL4JException("Deeplearning4j UI server failed to start.", future.cause());
}
}
/**
* Deploy (start) {@link VertxUIServer},
* and pass callback to handle successful or failed completion of deployment.
* @param startCallback promise that will handle success or failure of deployment.
* If the deployment is successful the result will contain a String representing the unique deployment ID of the
* deployment.
*/
private static void deploy(Promise<String> startCallback) {
log.debug("Deeplearning4j UI server is starting.");
Promise<String> promise = Promise.promise();
promise.future().compose(
success -> Future.future(prom -> startCallback.complete(success)),
failure -> Future.future(prom -> startCallback.fail(new RuntimeException(failure)))
);
Vertx vertx = Vertx.vertx();
vertx.deployVerticle(VertxUIServer.class.getName(), promise);
Thread currentThread = Thread.currentThread();
VertxUIServer.autoStopThread = new Thread(() -> {
try {
currentThread.join();
log.info("Deeplearning4j UI server is auto-stopping.");
if (VertxUIServer.instance != null && !VertxUIServer.instance.isStopped()) {
instance.stop();
}
} catch (InterruptedException e) {
log.error("Deeplearning4j UI server auto-stop thread was interrupted.", e);
}
});
}
private List<UIModule> uiModules = new CopyOnWriteArrayList<>();
private RemoteReceiverModule remoteReceiverModule;
@ -132,10 +227,18 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
instance = this;
}
public static void stopInstance(){
if(instance == null)
public static void stopInstance() throws Exception {
if(instance == null || instance.isStopped())
return;
instance.stop();
VertxUIServer.reset();
}
private static void reset() {
VertxUIServer.instance = null;
VertxUIServer.statsStorageProvider = null;
VertxUIServer.instancePort = null;
VertxUIServer.multiSession.set(false);
}
/**
@ -145,12 +248,14 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
public void autoAttachStatsStorageBySessionId(Function<String, StatsStorage> statsStorageProvider) {
if (statsStorageProvider != null) {
this.statsStorageLoader = new StatsStorageLoader(statsStorageProvider);
this.trainModule.setSessionLoader(this.statsStorageLoader);
if (trainModule != null) {
this.trainModule.setSessionLoader(this.statsStorageLoader);
}
}
}
@Override
public void start() throws Exception {
public void start(Promise<Void> startCallback) throws Exception {
//Create REST endpoints
File uploadDir = new File(System.getProperty("java.io.tmpdir"), "DL4JUI_" + System.currentTimeMillis());
uploadDir.mkdirs();
@ -192,6 +297,9 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
});
}
if (VertxUIServer.statsStorageProvider != null) {
autoAttachStatsStorageBySessionId(VertxUIServer.statsStorageProvider);
}
uiModules.add(new DefaultModule(isMultiSession())); //For: navigation page "/"
trainModule = new TrainModule(isMultiSession(), statsStorageLoader, this::getAddress);
@ -246,17 +354,23 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
}
}
server = vertx.createHttpServer()
.requestHandler(r)
.listen(port);
uiEventRoutingThread = new Thread(new StatsEventRouterRunnable());
uiEventRoutingThread.setDaemon(true);
uiEventRoutingThread.start();
String address = UIServer.getInstance().getAddress();
log.info("Deeplearning4j UI server started at: {}", address);
server = vertx.createHttpServer()
.requestHandler(r)
.listen(port, result -> {
if (result.succeeded()) {
String address = UIServer.getInstance().getAddress();
log.info("Deeplearning4j UI server started at: {}", address);
startCallback.complete();
} else {
startCallback.fail(new RuntimeException("Deeplearning4j UI server failed to listen on port "
+ server.actualPort(), result.cause()));
}
});
VertxUIServer.autoStopThread.start();
}
private List<String> extractArgsFromRoute(String path, RoutingContext rc) {
@ -302,11 +416,33 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
}
@Override
public void stop() {
server.close();
shutdown.set(true);
public void stop() throws InterruptedException {
CountDownLatch l = new CountDownLatch(1);
Promise<Void> promise = Promise.promise();
promise.future().compose(
successEvent -> Future.future(prom -> l.countDown()),
failureEvent -> Future.future(prom -> l.countDown())
);
stopAsync(promise);
// synchronous function should wait until the server is stopped
l.await();
}
@Override
public void stopAsync(Promise<Void> stopCallback) {
/**
* Stop Vertx instance and release any resources held by it.
* Pass promise to {@link #stop(Promise)}.
*/
vertx.close(ar -> stopCallback.handle(ar));
}
@Override
public void stop(Promise<Void> stopCallback) {
shutdown.set(true);
stopCallback.complete();
log.info("Deeplearning4j UI server stopped.");
}
@Override
public boolean isStopped() {
@ -353,8 +489,7 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
if (!statsStorageInstances.contains(statsStorage))
return; //No op
boolean found = false;
for (Iterator<Pair<StatsStorage, StatsStorageListener>> iterator = listeners.iterator(); iterator.hasNext(); ) {
Pair<StatsStorage, StatsStorageListener> p = iterator.next();
for (Pair<StatsStorage, StatsStorageListener> p : listeners) {
if (p.getFirst() == statsStorage) { //Same object, not equality
statsStorage.deregisterStatsStorageListener(p.getSecond());
listeners.remove(p);

View File

@ -17,8 +17,10 @@
package org.deeplearning4j.ui.api;
import io.vertx.core.Promise;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.api.storage.StatsStorageRouter;
import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.ui.VertxUIServer;
import org.nd4j.linalg.function.Function;
@ -32,18 +34,24 @@ import java.util.List;
public interface UIServer {
/**
* Get (and, initialize if necessary) the UI server.
* Get (and, initialize if necessary) the UI server. This synchronous function will wait until the server started.
* Singleton pattern - all calls to getInstance() will return the same UI instance.
*
* @return UI instance for this JVM
* @throws RuntimeException if the instance has already started in a different mode (multi/single-session)
* @throws DL4JException if UI server failed to start;
* if the instance has already started in a different mode (multi/single-session);
* if interrupted while waiting for completion
*/
static UIServer getInstance() throws RuntimeException {
return getInstance(false, null);
static UIServer getInstance() throws DL4JException {
if (VertxUIServer.getInstance() != null && !VertxUIServer.getInstance().isStopped()) {
return VertxUIServer.getInstance();
} else {
return getInstance(false, null);
}
}
/**
* Get (and, initialize if necessary) the UI server.
* Get (and, initialize if necessary) the UI server. This synchronous function will wait until the server started.
* Singleton pattern - all calls to getInstance() will return the same UI instance.
*
* @param multiSession in multi-session mode, multiple training sessions can be visualized in separate browser tabs.
@ -52,16 +60,19 @@ public interface UIServer {
* <br/>Use this to auto-attach StatsStorage if an unknown session ID is passed
* as URL path parameter in multi-session mode, or leave it {@code null}.
* @return UI instance for this JVM
* @throws RuntimeException if the instance has already started in a different mode (multi/single-session)
* @throws DL4JException if UI server failed to start;
* if the instance has already started in a different mode (multi/single-session);
* if interrupted while waiting for completion
*/
static UIServer getInstance(boolean multiSession, Function<String, StatsStorage> statsStorageProvider) throws RuntimeException {
static UIServer getInstance(boolean multiSession, Function<String, StatsStorage> statsStorageProvider)
throws DL4JException {
return VertxUIServer.getInstance(null, multiSession, statsStorageProvider);
}
/**
* Stop UIServer instance, if already running
*/
static void stopInstance() {
static void stopInstance() throws Exception {
VertxUIServer.stopInstance();
}
@ -144,8 +155,15 @@ public interface UIServer {
boolean isRemoteListenerEnabled();
/**
* Stop/shut down the UI server.
* Stop/shut down the UI server. This synchronous function should wait until the server is stopped.
*/
void stop();
void stop() throws Exception;
/**
* Stop/shut down the UI server.
* This asynchronous function should immediately return, and notify the callback {@link Promise} on completion:
* either call {@link Promise#complete} or {@link io.vertx.core.Promise#fail}.
* @param stopCallback callback {@link Promise} to notify on completion
*/
void stopAsync(Promise<Void> stopCallback);
}

View File

@ -69,15 +69,19 @@ public class DefaultI18N implements I18N {
}
/**
* Get instance for session (used in multi-session mode)
* @param sessionId session
* @return instance for session
* Get instance for session
* @param sessionId session ID for multi-session mode, leave it {@code null} for global instance
* @return instance for session, or global instance
*/
public static synchronized I18N getInstance(String sessionId) {
if (!sessionInstances.containsKey(sessionId)) {
sessionInstances.put(sessionId, new DefaultI18N());
if (sessionId == null) {
return getInstance();
} else {
if (!sessionInstances.containsKey(sessionId)) {
sessionInstances.put(sessionId, new DefaultI18N());
}
return sessionInstances.get(sessionId);
}
return sessionInstances.get(sessionId);
}
/**

View File

@ -137,7 +137,7 @@ public class TrainModule implements UIModule {
maxChartPoints = DEFAULT_MAX_CHART_POINTS;
}
configuration = new Configuration(new Version(2, 3, 29));
configuration = new Configuration(new Version(2, 3, 23));
configuration.setDefaultEncoding("UTF-8");
configuration.setLocale(Locale.US);
configuration.setTemplateExceptionHandler(TemplateExceptionHandler.RETHROW_HANDLER);
@ -199,6 +199,7 @@ public class TrainModule implements UIModule {
}
}));
r.add(new Route("/train/:sessionId/info", HttpMethod.GET, (path, rc) -> this.sessionInfoForSession(path.get(0), rc)));
r.add(new Route("/train/:sessionId/system/data", HttpMethod.GET, (path, rc) -> this.getSystemDataForSession(path.get(0), rc)));
} else {
r.add(new Route("/train", HttpMethod.GET, (path, rc) -> rc.reroute("/train/overview")));
r.add(new Route("/train/sessions/current", HttpMethod.GET, (path, rc) -> rc.response().end(currentSessionID == null ? "" : currentSessionID)));
@ -226,7 +227,9 @@ public class TrainModule implements UIModule {
* @param rc Routing context
*/
private void renderFtl(String file, RoutingContext rc) {
Map<String, String> input = DefaultI18N.getInstance().getMessages(DefaultI18N.getInstance().getDefaultLanguage());
String sessionId = rc.request().getParam("sessionID");
String langCode = DefaultI18N.getInstance(sessionId).getDefaultLanguage();
Map<String, String> input = DefaultI18N.getInstance().getMessages(langCode);
String html;
try {
String content = FileUtils.readFileToString(Resources.asFile("templates/" + file), StandardCharsets.UTF_8);

View File

@ -17,10 +17,15 @@
package org.deeplearning4j.ui;
import io.vertx.core.Future;
import io.vertx.core.Promise;
import io.vertx.core.Vertx;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
@ -50,63 +55,31 @@ import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;
import static org.junit.Assert.*;
/**
* Created by Alex on 08/10/2016.
*/
@Slf4j
@Ignore
public class TestVertxUI extends BaseDL4JTest {
@Before
public void setUp() throws Exception {
UIServer.stopInstance();
}
@Test
@Ignore
public void testUI() throws Exception {
StatsStorage ss = new InMemoryStatsStorage();
VertxUIServer uiServer = (VertxUIServer) UIServer.getInstance();
assertEquals(9000, uiServer.getPort());
uiServer.stop();
VertxUIServer vertxUIServer = new VertxUIServer();
// vertxUIServer.runMain(new String[] {"--uiPort", "9100", "-r", "true"});
//
// assertEquals(9100, vertxUIServer.getPort());
// vertxUIServer.stop();
// uiServer.attach(ss);
//
// MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
// .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
// .list()
// .layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(4).build())
// .layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(4).nOut(3).build())
// .build();
//
// MultiLayerNetwork net = new MultiLayerNetwork(conf);
// net.init();
// net.setListeners(new StatsListener(ss, 3), new ScoreIterationListener(1));
//
// DataSetIterator iter = new IrisDataSetIterator(150, 150);
//
// for (int i = 0; i < 500; i++) {
// net.fit(iter);
//// Thread.sleep(100);
// Thread.sleep(100);
// }
//
//// uiServer.stop();
Thread.sleep(100000);
}
@Test
@Ignore
public void testUI_VAE() throws Exception {
//Variational autoencoder - for unsupervised layerwise pretraining
@ -144,13 +117,9 @@ public class TestVertxUI extends BaseDL4JTest {
Thread.sleep(100);
}
Thread.sleep(100000);
}
@Test
@Ignore
public void testUIMultipleSessions() throws Exception {
for (int session = 0; session < 3; session++) {
@ -178,60 +147,10 @@ public class TestVertxUI extends BaseDL4JTest {
Thread.sleep(100);
}
}
Thread.sleep(1000000);
}
@Test
@Ignore
public void testUISequentialSessions() throws Exception {
UIServer uiServer = UIServer.getInstance();
StatsStorage ss = null;
for (int session = 0; session < 3; session++) {
if (ss != null) {
uiServer.detach(ss);
}
ss = new InMemoryStatsStorage();
uiServer.attach(ss);
int numInputs = 4;
int outputNum = 3;
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.activation(Activation.TANH)
.weightInit(WeightInit.XAVIER)
.updater(new Sgd(0.03))
.l2(1e-4)
.list()
.layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(3)
.build())
.layer(1, new DenseLayer.Builder().nIn(3).nOut(3)
.build())
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.activation(Activation.SOFTMAX)
.nIn(3).nOut(outputNum).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
net.setListeners(new StatsListener(ss), new ScoreIterationListener(1));
DataSetIterator iter = new IrisDataSetIterator(150, 150);
for (int i = 0; i < 1000; i++) {
net.fit(iter);
}
Thread.sleep(5000);
}
Thread.sleep(1000000);
}
@Test
@Ignore
public void testUICompGraph() throws Exception {
public void testUICompGraph() {
StatsStorage ss = new InMemoryStatsStorage();
@ -254,10 +173,7 @@ public class TestVertxUI extends BaseDL4JTest {
for (int i = 0; i < 100; i++) {
net.fit(iter);
Thread.sleep(100);
}
Thread.sleep(1000000);
}
@Test
@ -304,11 +220,11 @@ public class TestVertxUI extends BaseDL4JTest {
}
});
String json1 = IOUtils.toString(new URL("http://localhost:9000/train/ss1/overview/data"), StandardCharsets.UTF_8);
// System.out.println(json1);
String json1 = IOUtils.toString(new URL("http://localhost:9000/train/ss1/overview/data"),
StandardCharsets.UTF_8);
String json2 = IOUtils.toString(new URL("http://localhost:9000/train/ss2/overview/data"), StandardCharsets.UTF_8);
// System.out.println(json2);
String json2 = IOUtils.toString(new URL("http://localhost:9000/train/ss2/overview/data"),
StandardCharsets.UTF_8);
assertNotEquals(json1, json2);
@ -336,11 +252,106 @@ public class TestVertxUI extends BaseDL4JTest {
}
@Test
public void testUIServerStop() {
public void testUIServerStop() throws Exception {
UIServer uiServer = UIServer.getInstance(true, null);
assertTrue(uiServer.isMultiSession());
assertFalse(uiServer.isStopped());
long sleepMilliseconds = 1_000;
log.info("Waiting {} ms before stopping.", sleepMilliseconds);
Thread.sleep(sleepMilliseconds);
uiServer.stop();
assertTrue(uiServer.isStopped());
log.info("UI server is stopped. Waiting {} ms before starting new UI server.", sleepMilliseconds);
Thread.sleep(sleepMilliseconds);
uiServer = UIServer.getInstance(false, null);
assertFalse(uiServer.isMultiSession());
assertFalse(uiServer.isStopped());
log.info("Waiting {} ms before stopping.", sleepMilliseconds);
Thread.sleep(sleepMilliseconds);
uiServer.stop();
assertTrue(uiServer.isStopped());
}
@Test
public void testUIServerStopAsync() throws Exception {
UIServer uiServer = UIServer.getInstance(true, null);
assertTrue(uiServer.isMultiSession());
assertFalse(uiServer.isStopped());
long sleepMilliseconds = 1_000;
log.info("Waiting {} ms before stopping.", sleepMilliseconds);
Thread.sleep(sleepMilliseconds);
CountDownLatch latch = new CountDownLatch(1);
Promise<Void> promise = Promise.promise();
promise.future().compose(
success -> Future.future(prom -> latch.countDown()),
failure -> Future.future(prom -> latch.countDown())
);
uiServer.stopAsync(promise);
latch.await();
assertTrue(uiServer.isStopped());
log.info("UI server is stopped. Waiting {} ms before starting new UI server.", sleepMilliseconds);
Thread.sleep(sleepMilliseconds);
uiServer = UIServer.getInstance(false, null);
assertFalse(uiServer.isMultiSession());
log.info("Waiting {} ms before stopping.", sleepMilliseconds);
Thread.sleep(sleepMilliseconds);
uiServer.stop();
}
@Test (expected = DL4JException.class)
public void testUIStartPortAlreadyBound() throws InterruptedException {
CountDownLatch latch = new CountDownLatch(1);
//Create HttpServer that binds the same port
int port = VertxUIServer.DEFAULT_UI_PORT;
Vertx vertx = Vertx.vertx();
vertx.createHttpServer()
.requestHandler(event -> {})
.listen(port, result -> latch.countDown());
latch.await();
try {
//DL4JException signals that the port cannot be bound, UI server cannot start
UIServer.getInstance();
} finally {
vertx.close();
}
}
@Test
public void testUIStartAsync() throws InterruptedException {
CountDownLatch latch = new CountDownLatch(1);
Promise<String> promise = Promise.promise();
promise.future().compose(
success -> Future.future(prom -> latch.countDown()),
failure -> Future.future(prom -> latch.countDown())
);
int port = VertxUIServer.DEFAULT_UI_PORT;
VertxUIServer.getInstance(port, false, null, promise);
latch.await();
if (promise.future().succeeded()) {
String deploymentId = promise.future().result();
log.debug("UI server deployed, deployment ID = {}", deploymentId);
} else {
log.debug("UI server failed to deploy.", promise.future().cause());
}
}
@Test
public void testUIAutoStopOnThreadExit() throws InterruptedException {
AtomicReference<UIServer> uiServer = new AtomicReference<>();
Thread thread = new Thread(() -> uiServer.set(UIServer.getInstance()));
thread.start();
thread.join();
Thread.sleep(1_000);
assertTrue(uiServer.get().isStopped());
}
}

View File

@ -0,0 +1,272 @@
package org.deeplearning4j.ui;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.vertx.core.Future;
import io.vertx.core.Promise;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.function.Function;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.UnsupportedEncodingException;
import java.net.HttpURLConnection;
import java.net.URL;
import java.net.URLEncoder;
import java.util.HashMap;
import java.util.concurrent.CountDownLatch;
import static org.junit.Assert.*;
@Slf4j
@Ignore
public class TestVertxUIManual extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 3600_000L;
}
@Test
@Ignore
public void testUI() throws Exception {
VertxUIServer uiServer = (VertxUIServer) UIServer.getInstance();
assertEquals(9000, uiServer.getPort());
Thread.sleep(3000_000);
uiServer.stop();
}
@Test
@Ignore
public void testUISequentialSessions() throws Exception {
UIServer uiServer = UIServer.getInstance();
StatsStorage ss = null;
for (int session = 0; session < 3; session++) {
if (ss != null) {
uiServer.detach(ss);
}
ss = new InMemoryStatsStorage();
uiServer.attach(ss);
int numInputs = 4;
int outputNum = 3;
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.activation(Activation.TANH)
.weightInit(WeightInit.XAVIER)
.updater(new Sgd(0.03))
.l2(1e-4)
.list()
.layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(3)
.build())
.layer(1, new DenseLayer.Builder().nIn(3).nOut(3)
.build())
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.activation(Activation.SOFTMAX)
.nIn(3).nOut(outputNum).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
net.setListeners(new StatsListener(ss), new ScoreIterationListener(1));
DataSetIterator iter = new IrisDataSetIterator(150, 150);
for (int i = 0; i < 100; i++) {
net.fit(iter);
}
Thread.sleep(5_000);
}
}
@Test
@Ignore
public void testUIServerStop() throws Exception {
UIServer uiServer = UIServer.getInstance(true, null);
assertTrue(uiServer.isMultiSession());
assertFalse(uiServer.isStopped());
long sleepMilliseconds = 30_000;
log.info("Waiting {} ms before stopping.", sleepMilliseconds);
Thread.sleep(sleepMilliseconds);
uiServer.stop();
assertTrue(uiServer.isStopped());
log.info("UI server is stopped. Waiting {} ms before starting new UI server.", sleepMilliseconds);
Thread.sleep(sleepMilliseconds);
uiServer = UIServer.getInstance(false, null);
assertFalse(uiServer.isMultiSession());
assertFalse(uiServer.isStopped());
log.info("Waiting {} ms before stopping.", sleepMilliseconds);
Thread.sleep(sleepMilliseconds);
uiServer.stop();
assertTrue(uiServer.isStopped());
}
@Test
@Ignore
public void testUIServerStopAsync() throws Exception {
UIServer uiServer = UIServer.getInstance(true, null);
assertTrue(uiServer.isMultiSession());
assertFalse(uiServer.isStopped());
long sleepMilliseconds = 30_000;
log.info("Waiting {} ms before stopping.", sleepMilliseconds);
Thread.sleep(sleepMilliseconds);
CountDownLatch latch = new CountDownLatch(1);
Promise<Void> promise = Promise.promise();
promise.future().compose(
success -> Future.future(prom -> latch.countDown()),
failure -> Future.future(prom -> latch.countDown())
);
uiServer.stopAsync(promise);
latch.await();
assertTrue(uiServer.isStopped());
log.info("UI server is stopped. Waiting {} ms before starting new UI server.", sleepMilliseconds);
Thread.sleep(sleepMilliseconds);
uiServer = UIServer.getInstance(false, null);
assertFalse(uiServer.isMultiSession());
log.info("Waiting {} ms before stopping.", sleepMilliseconds);
Thread.sleep(sleepMilliseconds);
uiServer.stop();
}
@Test
@Ignore
public void testUIAutoAttachDetach() throws Exception {
long detachTimeoutMillis = 15_000;
AutoDetachingStatsStorageProvider statsProvider = new AutoDetachingStatsStorageProvider(detachTimeoutMillis);
UIServer uIServer = UIServer.getInstance(true, statsProvider);
statsProvider.setUIServer(uIServer);
InMemoryStatsStorage ss = null;
for (int session = 0; session < 3; session++) {
int layerSize = session + 4;
ss = new InMemoryStatsStorage();
String sessionId = Integer.toString(session);
statsProvider.put(sessionId, ss);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list()
.layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(layerSize).build())
.layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(layerSize).nOut(3).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
StatsListener statsListener = new StatsListener(ss, 1);
statsListener.setSessionID(sessionId);
net.setListeners(statsListener, new ScoreIterationListener(1));
uIServer.attach(ss);
DataSetIterator iter = new IrisDataSetIterator(150, 150);
for (int i = 0; i < 20; i++) {
net.fit(iter);
}
assertTrue(uIServer.isAttached(ss));
uIServer.detach(ss);
assertFalse(uIServer.isAttached(ss));
/*
* Visiting /train/:sessionId to auto-attach StatsStorage
*/
String sessionUrl = trainingSessionUrl(uIServer.getAddress(), sessionId);
HttpURLConnection conn = (HttpURLConnection) new URL(sessionUrl).openConnection();
conn.connect();
assertEquals(HttpResponseStatus.OK.code(), conn.getResponseCode());
assertTrue(uIServer.isAttached(ss));
}
Thread.sleep(detachTimeoutMillis + 60_000);
assertFalse(uIServer.isAttached(ss));
}
/**
* Get URL-encoded URL for training session on given server address
* @param serverAddress server address
* @param sessionId session ID
* @return URL
* @throws UnsupportedEncodingException if the used encoding is not supported
*/
private static String trainingSessionUrl(String serverAddress, String sessionId)
throws UnsupportedEncodingException {
return String.format("%s/train/%s", serverAddress, URLEncoder.encode(sessionId, "UTF-8"));
}
/**
* StatsStorage provider with automatic detaching of StatsStorage after a timeout
* @author Tamas Fenyvesi
*/
private static class AutoDetachingStatsStorageProvider implements Function<String, StatsStorage> {
HashMap<String, InMemoryStatsStorage> storageForSession = new HashMap<>();
UIServer uIServer;
long autoDetachTimeoutMillis;
public AutoDetachingStatsStorageProvider(long autoDetachTimeoutMillis) {
this.autoDetachTimeoutMillis = autoDetachTimeoutMillis;
}
public void put(String sessionId, InMemoryStatsStorage statsStorage) {
storageForSession.put(sessionId, statsStorage);
}
public void setUIServer(UIServer uIServer) {
this.uIServer = uIServer;
}
@Override
public StatsStorage apply(String sessionId) {
StatsStorage statsStorage = storageForSession.get(sessionId);
if (statsStorage != null) {
new Thread(() -> {
try {
log.info("Waiting to detach StatsStorage (session ID: {})" +
" after {} ms ", sessionId, autoDetachTimeoutMillis);
Thread.sleep(autoDetachTimeoutMillis);
} catch (InterruptedException e) {
e.printStackTrace();
} finally {
log.info("Auto-detaching StatsStorage (session ID: {}) after {} ms.",
sessionId, autoDetachTimeoutMillis);
uIServer.detach(statsStorage);
log.info(" To re-attach StatsStorage of training session, visit {}}/train/{}",
uIServer.getAddress(), sessionId);
}
}).start();
}
return statsStorage;
}
}
}

View File

@ -18,9 +18,11 @@
package org.deeplearning4j.ui;
import io.netty.handler.codec.http.HttpResponseStatus;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@ -32,7 +34,6 @@ import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
@ -52,21 +53,22 @@ import static org.junit.Assert.*;
/**
* @author Tamas Fenyvesi
*/
@Ignore
@Slf4j
public class TestVertxUIMultiSession extends BaseDL4JTest {
@Before
public void setUp() throws Exception {
UIServer.stopInstance();
}
@Test
public void testUIMultiSession() throws Exception {
public void testUIMultiSessionParallelTraining() throws Exception {
UIServer uIServer = UIServer.getInstance(true, null);
HashMap<Thread, StatsStorage> statStorageForThread = new HashMap<>();
HashMap<Thread, String> sessionIdForThread = new HashMap<>();
for (int session = 0; session < 10; session++) {
int parallelTrainingCount = 10;
for (int session = 0; session < parallelTrainingCount; session++) {
StatsStorage ss = new InMemoryStatsStorage();
@ -104,8 +106,6 @@ public class TestVertxUIMultiSession extends BaseDL4JTest {
sessionIdForThread.put(training, sessionId);
}
Thread.sleep(10000000);
for (Thread thread: statStorageForThread.keySet()) {
StatsStorage ss = statStorageForThread.get(thread);
String sessionId = sessionIdForThread.get(thread);
@ -120,7 +120,7 @@ public class TestVertxUIMultiSession extends BaseDL4JTest {
assertEquals(HttpResponseStatus.OK.code(), conn.getResponseCode());
assertTrue(uIServer.isAttached(ss));
} catch (InterruptedException | IOException e) {
} catch (IOException e) {
e.printStackTrace();
fail(e.getMessage());
} finally {
@ -128,7 +128,6 @@ public class TestVertxUIMultiSession extends BaseDL4JTest {
assertFalse(uIServer.isAttached(ss));
}
}
}
@Test
@ -181,61 +180,7 @@ public class TestVertxUIMultiSession extends BaseDL4JTest {
}
}
@Test
@Ignore
public void testUIAutoAttachDetach() throws Exception {
long autoDetachTimeoutMillis = 30_000;
AutoDetachingStatsStorageProvider statsProvider = new AutoDetachingStatsStorageProvider(autoDetachTimeoutMillis);
UIServer uIServer = UIServer.getInstance(true, statsProvider);
statsProvider.setUIServer(uIServer);
for (int session = 0; session < 3; session++) {
int layerSize = session + 4;
InMemoryStatsStorage ss = new InMemoryStatsStorage();
String sessionId = Integer.toString(session);
statsProvider.put(sessionId, ss);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list()
.layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(layerSize).build())
.layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(layerSize).nOut(3).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
StatsListener statsListener = new StatsListener(ss, 1);
statsListener.setSessionID(sessionId);
net.setListeners(statsListener, new ScoreIterationListener(1));
uIServer.attach(ss);
DataSetIterator iter = new IrisDataSetIterator(150, 150);
for (int i = 0; i < 20; i++) {
net.fit(iter);
}
assertTrue(uIServer.isAttached(ss));
uIServer.detach(ss);
assertFalse(uIServer.isAttached(ss));
/*
* Visiting /train/:sessionId to auto-attach StatsStorage
*/
String sessionUrl = trainingSessionUrl(uIServer.getAddress(), sessionId);
HttpURLConnection conn = (HttpURLConnection) new URL(sessionUrl).openConnection();
conn.connect();
assertEquals(HttpResponseStatus.OK.code(), conn.getResponseCode());
assertTrue(uIServer.isAttached(ss));
}
Thread.sleep(1_000_000);
}
@Test (expected = RuntimeException.class)
@Test (expected = DL4JException.class)
public void testUIServerGetInstanceMultipleCalls1() {
UIServer uiServer = UIServer.getInstance();
assertFalse(uiServer.isMultiSession());
@ -243,7 +188,7 @@ public class TestVertxUIMultiSession extends BaseDL4JTest {
}
@Test (expected = RuntimeException.class)
@Test (expected = DL4JException.class)
public void testUIServerGetInstanceMultipleCalls2() {
UIServer uiServer = UIServer.getInstance(true, null);
assertTrue(uiServer.isMultiSession());
@ -257,55 +202,9 @@ public class TestVertxUIMultiSession extends BaseDL4JTest {
* @return URL
* @throws UnsupportedEncodingException if the used encoding is not supported
*/
private static String trainingSessionUrl(String serverAddress, String sessionId) throws UnsupportedEncodingException {
private static String trainingSessionUrl(String serverAddress, String sessionId)
throws UnsupportedEncodingException {
return String.format("%s/train/%s", serverAddress, URLEncoder.encode(sessionId, "UTF-8"));
}
/**
* StatsStorage provider with automatic detaching of StatsStorage after a timeout
* @author fenyvesit
*
*/
private static class AutoDetachingStatsStorageProvider implements Function<String, StatsStorage> {
HashMap<String, InMemoryStatsStorage> storageForSession = new HashMap<>();
UIServer uIServer;
long autoDetachTimeoutMillis;
public AutoDetachingStatsStorageProvider(long autoDetachTimeoutMillis) {
this.autoDetachTimeoutMillis = autoDetachTimeoutMillis;
}
public void put(String sessionId, InMemoryStatsStorage statsStorage) {
storageForSession.put(sessionId, statsStorage);
}
public void setUIServer(UIServer uIServer) {
this.uIServer = uIServer;
}
@Override
public StatsStorage apply(String sessionId) {
StatsStorage statsStorage = storageForSession.get(sessionId);
if (statsStorage != null) {
new Thread(() -> {
try {
System.out.println("Waiting to detach StatsStorage (session ID: " + sessionId + ")" +
" after " + autoDetachTimeoutMillis + " ms ");
Thread.sleep(autoDetachTimeoutMillis);
} catch (InterruptedException e) {
e.printStackTrace();
} finally {
System.out.println("Auto-detaching StatsStorage (session ID: " + sessionId + ") after " +
autoDetachTimeoutMillis + " ms.");
uIServer.detach(statsStorage);
System.out.println(" To re-attach StatsStorage of training session, visit " +
uIServer.getAddress() + "/train/" + sessionId);
}
}).start();
}
return statsStorage;
}
}
}