Fix UIServer features in multi-session mode, synchronous start and stop (#8856)
* fix Freemarker version mismatch: change version requested in TrainModule to 2.3.23 (freemarker.version in pom.xml) Signed-off-by: Tamás Fenyvesi <tamas.fenyvesi@doknet.hu> * fix checking multi-session mode in VertxUIServer.getInstance. Tested multiple calls in TestVertxUIMultiSession. Signed-off-by: Tamás Fenyvesi <tamas.fenyvesi@doknet.hu> * fix UIServer.getInstance() to return existing instance Signed-off-by: Tamás Fenyvesi <tamas.fenyvesi@doknet.hu> * extend timeout for manual UI tests from 30 to 600 seconds Signed-off-by: Tamás Fenyvesi <tamas.fenyvesi@doknet.hu> * start and stop UI server synchronously (wait until complete), tests Signed-off-by: Tamás Fenyvesi <tamas.fenyvesi@doknet.hu> * fix for auto-attaching StatsStorage given in VertxUIServer#getInstance(Integer, boolean, Function<String,StatsStorage>), test improvements Signed-off-by: Tamás Fenyvesi <tamas.fenyvesi@doknet.hu> * exception handling, test improvements Signed-off-by: Tamás Fenyvesi <tamas.fenyvesi@doknet.hu> * add asynchronous method to start UI server Signed-off-by: Tamás Fenyvesi <tamas.fenyvesi@doknet.hu> * fix UIServer.getInstance() to return existing instance Signed-off-by: Tamás Fenyvesi <tamas.fenyvesi@doknet.hu> * fix UI server language setting in multi-session mode Signed-off-by: Tamás Fenyvesi <tamas.fenyvesi@doknet.hu> * fix UI server system tab not loading data in multi-session mode Signed-off-by: Tamás Fenyvesi <tamas.fenyvesi@doknet.hu> * undo added InterruptedException in UIServer.getInstance() Signed-off-by: Tamás Fenyvesi <tamas.fenyvesi@doknet.hu> * fix async stopping of UIServer.stopAsync(Promise<Void>), added test Signed-off-by: Tamás Fenyvesi <tamas.fenyvesi@doknet.hu> * restore the daemon thread style behaviour of UIServer: don't keep the process alive just because the UI is running Signed-off-by: Tamás Fenyvesi <tamas.fenyvesi@doknet.hu> * speed up and don't @Ignore tests in TestVertxUI and TestVertxUIMultiSession, put longer tests to separate class TestVertxUIManual Signed-off-by: Tamás Fenyvesi <tamas.fenyvesi@doknet.hu>master
parent
8eccc170ec
commit
722d5a052a
|
@ -19,6 +19,8 @@ package org.deeplearning4j.ui;
|
|||
import com.beust.jcommander.JCommander;
|
||||
import com.beust.jcommander.Parameter;
|
||||
import io.vertx.core.AbstractVerticle;
|
||||
import io.vertx.core.Future;
|
||||
import io.vertx.core.Promise;
|
||||
import io.vertx.core.Vertx;
|
||||
import io.vertx.core.http.HttpServer;
|
||||
import io.vertx.core.http.impl.MimeMapping;
|
||||
|
@ -35,6 +37,7 @@ import org.deeplearning4j.api.storage.StatsStorageEvent;
|
|||
import org.deeplearning4j.api.storage.StatsStorageListener;
|
||||
import org.deeplearning4j.api.storage.StatsStorageRouter;
|
||||
import org.deeplearning4j.config.DL4JSystemProperties;
|
||||
import org.deeplearning4j.exception.DL4JException;
|
||||
import org.deeplearning4j.ui.api.Route;
|
||||
import org.deeplearning4j.ui.api.UIModule;
|
||||
import org.deeplearning4j.ui.api.UIServer;
|
||||
|
@ -72,35 +75,69 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
|||
private static Function<String, StatsStorage> statsStorageProvider;
|
||||
|
||||
private static Integer instancePort;
|
||||
private static Thread autoStopThread;
|
||||
|
||||
private TrainModule trainModule;
|
||||
|
||||
public static VertxUIServer getInstance() {
|
||||
return getInstance(null, multiSession.get(), null);
|
||||
/**
|
||||
* Get (and, initialize if necessary) the UI server. This synchronous function will wait until the server started.
|
||||
* @param port TCP socket port for {@link HttpServer} to listen
|
||||
* @param multiSession in multi-session mode, multiple training sessions can be visualized in separate browser tabs.
|
||||
* <br/>URL path will include session ID as a parameter, i.e.: /train becomes /train/:sessionId
|
||||
* @param statsStorageProvider function that returns a StatsStorage containing the given session ID.
|
||||
* <br/>Use this to auto-attach StatsStorage if an unknown session ID is passed
|
||||
* as URL path parameter in multi-session mode, or leave it {@code null}.
|
||||
* @return UI instance for this JVM
|
||||
* @throws DL4JException if UI server failed to start;
|
||||
* if the instance has already started in a different mode (multi/single-session);
|
||||
* if interrupted while waiting for completion
|
||||
*/
|
||||
public static VertxUIServer getInstance(Integer port, boolean multiSession,
|
||||
Function<String, StatsStorage> statsStorageProvider) throws DL4JException {
|
||||
return getInstance(port, multiSession, statsStorageProvider, null);
|
||||
}
|
||||
|
||||
public static VertxUIServer getInstance(Integer port, boolean multiSession, Function<String, StatsStorage> statsStorageProvider){
|
||||
/**
|
||||
*
|
||||
* Get (and, initialize if necessary) the UI server. This function will wait until the server started
|
||||
* (synchronous way), or pass the given callback to handle success or failure (asynchronous way).
|
||||
* @param port TCP socket port for {@link HttpServer} to listen
|
||||
* @param multiSession in multi-session mode, multiple training sessions can be visualized in separate browser tabs.
|
||||
* <br/>URL path will include session ID as a parameter, i.e.: /train becomes /train/:sessionId
|
||||
* @param statsStorageProvider function that returns a StatsStorage containing the given session ID.
|
||||
* <br/>Use this to auto-attach StatsStorage if an unknown session ID is passed
|
||||
* as URL path parameter in multi-session mode, or leave it {@code null}.
|
||||
* @param startCallback asynchronous deployment handler callback that will be notify of success or failure.
|
||||
* If {@code null} given, then this method will wait until deployment is complete.
|
||||
* If the deployment is successful the result will contain a String representing the
|
||||
* unique deployment ID of the deployment.
|
||||
* @return UI server instance
|
||||
* @throws DL4JException if UI server failed to start;
|
||||
* if the instance has already started in a different mode (multi/single-session);
|
||||
* if interrupted while waiting for completion
|
||||
*/
|
||||
public static VertxUIServer getInstance(Integer port, boolean multiSession,
|
||||
Function<String, StatsStorage> statsStorageProvider, Promise<String> startCallback)
|
||||
throws DL4JException {
|
||||
if (instance == null || instance.isStopped()) {
|
||||
VertxUIServer.multiSession.set(multiSession);
|
||||
VertxUIServer.setStatsStorageProvider(statsStorageProvider);
|
||||
instancePort = port;
|
||||
Vertx vertx = Vertx.vertx();
|
||||
|
||||
if (startCallback != null) {
|
||||
//Launch UI server verticle and pass asynchronous callback that will be notified of completion
|
||||
deploy(startCallback);
|
||||
} else {
|
||||
//Launch UI server verticle and wait for it to start
|
||||
CountDownLatch l = new CountDownLatch(1);
|
||||
vertx.deployVerticle(VertxUIServer.class.getName(), res -> {
|
||||
l.countDown();
|
||||
});
|
||||
try {
|
||||
l.await(5000, TimeUnit.MILLISECONDS);
|
||||
} catch (InterruptedException e){ } //Ignore
|
||||
deploy();
|
||||
}
|
||||
} else if (!instance.isStopped()) {
|
||||
if (instance.multiSession.get() && !instance.isMultiSession()) {
|
||||
throw new RuntimeException("Cannot return multi-session instance." +
|
||||
if (multiSession && !instance.isMultiSession()) {
|
||||
throw new DL4JException("Cannot return multi-session instance." +
|
||||
" UIServer has already started in single-session mode at " + instance.getAddress() +
|
||||
" You may stop the UI server instance, and start a new one.");
|
||||
} else if (!instance.multiSession.get() && instance.isMultiSession()) {
|
||||
throw new RuntimeException("Cannot return single-session instance." +
|
||||
} else if (!multiSession && instance.isMultiSession()) {
|
||||
throw new DL4JException("Cannot return single-session instance." +
|
||||
" UIServer has already started in multi-session mode at " + instance.getAddress() +
|
||||
" You may stop the UI server instance, and start a new one.");
|
||||
}
|
||||
|
@ -109,6 +146,64 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
|||
return instance;
|
||||
}
|
||||
|
||||
/**
|
||||
* Deploy (start) {@link VertxUIServer}, waiting until starting is complete.
|
||||
* @throws DL4JException if UI server failed to start;
|
||||
* if interrupted while waiting for completion
|
||||
*/
|
||||
private static void deploy() throws DL4JException {
|
||||
CountDownLatch l = new CountDownLatch(1);
|
||||
Promise<String> promise = Promise.promise();
|
||||
promise.future().compose(
|
||||
success -> Future.future(prom -> l.countDown()),
|
||||
failure -> Future.future(prom -> l.countDown())
|
||||
);
|
||||
deploy(promise);
|
||||
// synchronous function
|
||||
try {
|
||||
l.await();
|
||||
} catch (InterruptedException e) {
|
||||
throw new DL4JException(e);
|
||||
}
|
||||
|
||||
Future<String> future = promise.future();
|
||||
if (future.failed()) {
|
||||
throw new DL4JException("Deeplearning4j UI server failed to start.", future.cause());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Deploy (start) {@link VertxUIServer},
|
||||
* and pass callback to handle successful or failed completion of deployment.
|
||||
* @param startCallback promise that will handle success or failure of deployment.
|
||||
* If the deployment is successful the result will contain a String representing the unique deployment ID of the
|
||||
* deployment.
|
||||
*/
|
||||
private static void deploy(Promise<String> startCallback) {
|
||||
log.debug("Deeplearning4j UI server is starting.");
|
||||
Promise<String> promise = Promise.promise();
|
||||
promise.future().compose(
|
||||
success -> Future.future(prom -> startCallback.complete(success)),
|
||||
failure -> Future.future(prom -> startCallback.fail(new RuntimeException(failure)))
|
||||
);
|
||||
|
||||
Vertx vertx = Vertx.vertx();
|
||||
vertx.deployVerticle(VertxUIServer.class.getName(), promise);
|
||||
|
||||
Thread currentThread = Thread.currentThread();
|
||||
VertxUIServer.autoStopThread = new Thread(() -> {
|
||||
try {
|
||||
currentThread.join();
|
||||
log.info("Deeplearning4j UI server is auto-stopping.");
|
||||
if (VertxUIServer.instance != null && !VertxUIServer.instance.isStopped()) {
|
||||
instance.stop();
|
||||
}
|
||||
} catch (InterruptedException e) {
|
||||
log.error("Deeplearning4j UI server auto-stop thread was interrupted.", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
private List<UIModule> uiModules = new CopyOnWriteArrayList<>();
|
||||
private RemoteReceiverModule remoteReceiverModule;
|
||||
|
@ -132,10 +227,18 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
|||
instance = this;
|
||||
}
|
||||
|
||||
public static void stopInstance(){
|
||||
if(instance == null)
|
||||
public static void stopInstance() throws Exception {
|
||||
if(instance == null || instance.isStopped())
|
||||
return;
|
||||
instance.stop();
|
||||
VertxUIServer.reset();
|
||||
}
|
||||
|
||||
private static void reset() {
|
||||
VertxUIServer.instance = null;
|
||||
VertxUIServer.statsStorageProvider = null;
|
||||
VertxUIServer.instancePort = null;
|
||||
VertxUIServer.multiSession.set(false);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -145,12 +248,14 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
|||
public void autoAttachStatsStorageBySessionId(Function<String, StatsStorage> statsStorageProvider) {
|
||||
if (statsStorageProvider != null) {
|
||||
this.statsStorageLoader = new StatsStorageLoader(statsStorageProvider);
|
||||
if (trainModule != null) {
|
||||
this.trainModule.setSessionLoader(this.statsStorageLoader);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void start() throws Exception {
|
||||
public void start(Promise<Void> startCallback) throws Exception {
|
||||
//Create REST endpoints
|
||||
File uploadDir = new File(System.getProperty("java.io.tmpdir"), "DL4JUI_" + System.currentTimeMillis());
|
||||
uploadDir.mkdirs();
|
||||
|
@ -192,6 +297,9 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
|||
});
|
||||
}
|
||||
|
||||
if (VertxUIServer.statsStorageProvider != null) {
|
||||
autoAttachStatsStorageBySessionId(VertxUIServer.statsStorageProvider);
|
||||
}
|
||||
|
||||
uiModules.add(new DefaultModule(isMultiSession())); //For: navigation page "/"
|
||||
trainModule = new TrainModule(isMultiSession(), statsStorageLoader, this::getAddress);
|
||||
|
@ -246,17 +354,23 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
server = vertx.createHttpServer()
|
||||
.requestHandler(r)
|
||||
.listen(port);
|
||||
|
||||
uiEventRoutingThread = new Thread(new StatsEventRouterRunnable());
|
||||
uiEventRoutingThread.setDaemon(true);
|
||||
uiEventRoutingThread.start();
|
||||
|
||||
server = vertx.createHttpServer()
|
||||
.requestHandler(r)
|
||||
.listen(port, result -> {
|
||||
if (result.succeeded()) {
|
||||
String address = UIServer.getInstance().getAddress();
|
||||
log.info("Deeplearning4j UI server started at: {}", address);
|
||||
startCallback.complete();
|
||||
} else {
|
||||
startCallback.fail(new RuntimeException("Deeplearning4j UI server failed to listen on port "
|
||||
+ server.actualPort(), result.cause()));
|
||||
}
|
||||
});
|
||||
VertxUIServer.autoStopThread.start();
|
||||
}
|
||||
|
||||
private List<String> extractArgsFromRoute(String path, RoutingContext rc) {
|
||||
|
@ -302,11 +416,33 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void stop() {
|
||||
server.close();
|
||||
shutdown.set(true);
|
||||
public void stop() throws InterruptedException {
|
||||
CountDownLatch l = new CountDownLatch(1);
|
||||
Promise<Void> promise = Promise.promise();
|
||||
promise.future().compose(
|
||||
successEvent -> Future.future(prom -> l.countDown()),
|
||||
failureEvent -> Future.future(prom -> l.countDown())
|
||||
);
|
||||
stopAsync(promise);
|
||||
// synchronous function should wait until the server is stopped
|
||||
l.await();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void stopAsync(Promise<Void> stopCallback) {
|
||||
/**
|
||||
* Stop Vertx instance and release any resources held by it.
|
||||
* Pass promise to {@link #stop(Promise)}.
|
||||
*/
|
||||
vertx.close(ar -> stopCallback.handle(ar));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void stop(Promise<Void> stopCallback) {
|
||||
shutdown.set(true);
|
||||
stopCallback.complete();
|
||||
log.info("Deeplearning4j UI server stopped.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isStopped() {
|
||||
|
@ -353,8 +489,7 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
|||
if (!statsStorageInstances.contains(statsStorage))
|
||||
return; //No op
|
||||
boolean found = false;
|
||||
for (Iterator<Pair<StatsStorage, StatsStorageListener>> iterator = listeners.iterator(); iterator.hasNext(); ) {
|
||||
Pair<StatsStorage, StatsStorageListener> p = iterator.next();
|
||||
for (Pair<StatsStorage, StatsStorageListener> p : listeners) {
|
||||
if (p.getFirst() == statsStorage) { //Same object, not equality
|
||||
statsStorage.deregisterStatsStorageListener(p.getSecond());
|
||||
listeners.remove(p);
|
||||
|
|
|
@ -17,8 +17,10 @@
|
|||
|
||||
package org.deeplearning4j.ui.api;
|
||||
|
||||
import io.vertx.core.Promise;
|
||||
import org.deeplearning4j.api.storage.StatsStorage;
|
||||
import org.deeplearning4j.api.storage.StatsStorageRouter;
|
||||
import org.deeplearning4j.exception.DL4JException;
|
||||
import org.deeplearning4j.ui.VertxUIServer;
|
||||
import org.nd4j.linalg.function.Function;
|
||||
|
||||
|
@ -32,18 +34,24 @@ import java.util.List;
|
|||
public interface UIServer {
|
||||
|
||||
/**
|
||||
* Get (and, initialize if necessary) the UI server.
|
||||
* Get (and, initialize if necessary) the UI server. This synchronous function will wait until the server started.
|
||||
* Singleton pattern - all calls to getInstance() will return the same UI instance.
|
||||
*
|
||||
* @return UI instance for this JVM
|
||||
* @throws RuntimeException if the instance has already started in a different mode (multi/single-session)
|
||||
* @throws DL4JException if UI server failed to start;
|
||||
* if the instance has already started in a different mode (multi/single-session);
|
||||
* if interrupted while waiting for completion
|
||||
*/
|
||||
static UIServer getInstance() throws RuntimeException {
|
||||
static UIServer getInstance() throws DL4JException {
|
||||
if (VertxUIServer.getInstance() != null && !VertxUIServer.getInstance().isStopped()) {
|
||||
return VertxUIServer.getInstance();
|
||||
} else {
|
||||
return getInstance(false, null);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get (and, initialize if necessary) the UI server.
|
||||
* Get (and, initialize if necessary) the UI server. This synchronous function will wait until the server started.
|
||||
* Singleton pattern - all calls to getInstance() will return the same UI instance.
|
||||
*
|
||||
* @param multiSession in multi-session mode, multiple training sessions can be visualized in separate browser tabs.
|
||||
|
@ -52,16 +60,19 @@ public interface UIServer {
|
|||
* <br/>Use this to auto-attach StatsStorage if an unknown session ID is passed
|
||||
* as URL path parameter in multi-session mode, or leave it {@code null}.
|
||||
* @return UI instance for this JVM
|
||||
* @throws RuntimeException if the instance has already started in a different mode (multi/single-session)
|
||||
* @throws DL4JException if UI server failed to start;
|
||||
* if the instance has already started in a different mode (multi/single-session);
|
||||
* if interrupted while waiting for completion
|
||||
*/
|
||||
static UIServer getInstance(boolean multiSession, Function<String, StatsStorage> statsStorageProvider) throws RuntimeException {
|
||||
static UIServer getInstance(boolean multiSession, Function<String, StatsStorage> statsStorageProvider)
|
||||
throws DL4JException {
|
||||
return VertxUIServer.getInstance(null, multiSession, statsStorageProvider);
|
||||
}
|
||||
|
||||
/**
|
||||
* Stop UIServer instance, if already running
|
||||
*/
|
||||
static void stopInstance() {
|
||||
static void stopInstance() throws Exception {
|
||||
VertxUIServer.stopInstance();
|
||||
}
|
||||
|
||||
|
@ -144,8 +155,15 @@ public interface UIServer {
|
|||
boolean isRemoteListenerEnabled();
|
||||
|
||||
/**
|
||||
* Stop/shut down the UI server.
|
||||
* Stop/shut down the UI server. This synchronous function should wait until the server is stopped.
|
||||
*/
|
||||
void stop();
|
||||
void stop() throws Exception;
|
||||
|
||||
/**
|
||||
* Stop/shut down the UI server.
|
||||
* This asynchronous function should immediately return, and notify the callback {@link Promise} on completion:
|
||||
* either call {@link Promise#complete} or {@link io.vertx.core.Promise#fail}.
|
||||
* @param stopCallback callback {@link Promise} to notify on completion
|
||||
*/
|
||||
void stopAsync(Promise<Void> stopCallback);
|
||||
}
|
||||
|
|
|
@ -69,16 +69,20 @@ public class DefaultI18N implements I18N {
|
|||
}
|
||||
|
||||
/**
|
||||
* Get instance for session (used in multi-session mode)
|
||||
* @param sessionId session
|
||||
* @return instance for session
|
||||
* Get instance for session
|
||||
* @param sessionId session ID for multi-session mode, leave it {@code null} for global instance
|
||||
* @return instance for session, or global instance
|
||||
*/
|
||||
public static synchronized I18N getInstance(String sessionId) {
|
||||
if (sessionId == null) {
|
||||
return getInstance();
|
||||
} else {
|
||||
if (!sessionInstances.containsKey(sessionId)) {
|
||||
sessionInstances.put(sessionId, new DefaultI18N());
|
||||
}
|
||||
return sessionInstances.get(sessionId);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove I18N instance for session
|
||||
|
|
|
@ -137,7 +137,7 @@ public class TrainModule implements UIModule {
|
|||
maxChartPoints = DEFAULT_MAX_CHART_POINTS;
|
||||
}
|
||||
|
||||
configuration = new Configuration(new Version(2, 3, 29));
|
||||
configuration = new Configuration(new Version(2, 3, 23));
|
||||
configuration.setDefaultEncoding("UTF-8");
|
||||
configuration.setLocale(Locale.US);
|
||||
configuration.setTemplateExceptionHandler(TemplateExceptionHandler.RETHROW_HANDLER);
|
||||
|
@ -199,6 +199,7 @@ public class TrainModule implements UIModule {
|
|||
}
|
||||
}));
|
||||
r.add(new Route("/train/:sessionId/info", HttpMethod.GET, (path, rc) -> this.sessionInfoForSession(path.get(0), rc)));
|
||||
r.add(new Route("/train/:sessionId/system/data", HttpMethod.GET, (path, rc) -> this.getSystemDataForSession(path.get(0), rc)));
|
||||
} else {
|
||||
r.add(new Route("/train", HttpMethod.GET, (path, rc) -> rc.reroute("/train/overview")));
|
||||
r.add(new Route("/train/sessions/current", HttpMethod.GET, (path, rc) -> rc.response().end(currentSessionID == null ? "" : currentSessionID)));
|
||||
|
@ -226,7 +227,9 @@ public class TrainModule implements UIModule {
|
|||
* @param rc Routing context
|
||||
*/
|
||||
private void renderFtl(String file, RoutingContext rc) {
|
||||
Map<String, String> input = DefaultI18N.getInstance().getMessages(DefaultI18N.getInstance().getDefaultLanguage());
|
||||
String sessionId = rc.request().getParam("sessionID");
|
||||
String langCode = DefaultI18N.getInstance(sessionId).getDefaultLanguage();
|
||||
Map<String, String> input = DefaultI18N.getInstance().getMessages(langCode);
|
||||
String html;
|
||||
try {
|
||||
String content = FileUtils.readFileToString(Resources.asFile("templates/" + file), StandardCharsets.UTF_8);
|
||||
|
|
|
@ -17,10 +17,15 @@
|
|||
|
||||
package org.deeplearning4j.ui;
|
||||
|
||||
import io.vertx.core.Future;
|
||||
import io.vertx.core.Promise;
|
||||
import io.vertx.core.Vertx;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.io.IOUtils;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.api.storage.StatsStorage;
|
||||
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
|
||||
import org.deeplearning4j.exception.DL4JException;
|
||||
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
||||
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||
|
@ -50,63 +55,31 @@ import java.net.URL;
|
|||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
/**
|
||||
* Created by Alex on 08/10/2016.
|
||||
*/
|
||||
@Slf4j
|
||||
@Ignore
|
||||
public class TestVertxUI extends BaseDL4JTest {
|
||||
|
||||
@Before
|
||||
public void setUp() throws Exception {
|
||||
UIServer.stopInstance();
|
||||
}
|
||||
|
||||
@Test
|
||||
@Ignore
|
||||
public void testUI() throws Exception {
|
||||
|
||||
StatsStorage ss = new InMemoryStatsStorage();
|
||||
|
||||
VertxUIServer uiServer = (VertxUIServer) UIServer.getInstance();
|
||||
assertEquals(9000, uiServer.getPort());
|
||||
uiServer.stop();
|
||||
VertxUIServer vertxUIServer = new VertxUIServer();
|
||||
// vertxUIServer.runMain(new String[] {"--uiPort", "9100", "-r", "true"});
|
||||
//
|
||||
// assertEquals(9100, vertxUIServer.getPort());
|
||||
// vertxUIServer.stop();
|
||||
|
||||
|
||||
// uiServer.attach(ss);
|
||||
//
|
||||
// MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||
// .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
||||
// .list()
|
||||
// .layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(4).build())
|
||||
// .layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(4).nOut(3).build())
|
||||
// .build();
|
||||
//
|
||||
// MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||
// net.init();
|
||||
// net.setListeners(new StatsListener(ss, 3), new ScoreIterationListener(1));
|
||||
//
|
||||
// DataSetIterator iter = new IrisDataSetIterator(150, 150);
|
||||
//
|
||||
// for (int i = 0; i < 500; i++) {
|
||||
// net.fit(iter);
|
||||
//// Thread.sleep(100);
|
||||
// Thread.sleep(100);
|
||||
// }
|
||||
//
|
||||
//// uiServer.stop();
|
||||
|
||||
Thread.sleep(100000);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Ignore
|
||||
public void testUI_VAE() throws Exception {
|
||||
//Variational autoencoder - for unsupervised layerwise pretraining
|
||||
|
||||
|
@ -144,13 +117,9 @@ public class TestVertxUI extends BaseDL4JTest {
|
|||
Thread.sleep(100);
|
||||
}
|
||||
|
||||
|
||||
Thread.sleep(100000);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
@Ignore
|
||||
public void testUIMultipleSessions() throws Exception {
|
||||
|
||||
for (int session = 0; session < 3; session++) {
|
||||
|
@ -178,60 +147,10 @@ public class TestVertxUI extends BaseDL4JTest {
|
|||
Thread.sleep(100);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Thread.sleep(1000000);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Ignore
|
||||
public void testUISequentialSessions() throws Exception {
|
||||
UIServer uiServer = UIServer.getInstance();
|
||||
StatsStorage ss = null;
|
||||
for (int session = 0; session < 3; session++) {
|
||||
|
||||
if (ss != null) {
|
||||
uiServer.detach(ss);
|
||||
}
|
||||
ss = new InMemoryStatsStorage();
|
||||
uiServer.attach(ss);
|
||||
|
||||
int numInputs = 4;
|
||||
int outputNum = 3;
|
||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||
.activation(Activation.TANH)
|
||||
.weightInit(WeightInit.XAVIER)
|
||||
.updater(new Sgd(0.03))
|
||||
.l2(1e-4)
|
||||
.list()
|
||||
.layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(3)
|
||||
.build())
|
||||
.layer(1, new DenseLayer.Builder().nIn(3).nOut(3)
|
||||
.build())
|
||||
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
|
||||
.activation(Activation.SOFTMAX)
|
||||
.nIn(3).nOut(outputNum).build())
|
||||
.build();
|
||||
|
||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||
net.init();
|
||||
net.setListeners(new StatsListener(ss), new ScoreIterationListener(1));
|
||||
|
||||
DataSetIterator iter = new IrisDataSetIterator(150, 150);
|
||||
|
||||
for (int i = 0; i < 1000; i++) {
|
||||
net.fit(iter);
|
||||
}
|
||||
Thread.sleep(5000);
|
||||
}
|
||||
|
||||
|
||||
Thread.sleep(1000000);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Ignore
|
||||
public void testUICompGraph() throws Exception {
|
||||
public void testUICompGraph() {
|
||||
|
||||
StatsStorage ss = new InMemoryStatsStorage();
|
||||
|
||||
|
@ -254,10 +173,7 @@ public class TestVertxUI extends BaseDL4JTest {
|
|||
|
||||
for (int i = 0; i < 100; i++) {
|
||||
net.fit(iter);
|
||||
Thread.sleep(100);
|
||||
}
|
||||
|
||||
Thread.sleep(1000000);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -304,11 +220,11 @@ public class TestVertxUI extends BaseDL4JTest {
|
|||
}
|
||||
});
|
||||
|
||||
String json1 = IOUtils.toString(new URL("http://localhost:9000/train/ss1/overview/data"), StandardCharsets.UTF_8);
|
||||
// System.out.println(json1);
|
||||
String json1 = IOUtils.toString(new URL("http://localhost:9000/train/ss1/overview/data"),
|
||||
StandardCharsets.UTF_8);
|
||||
|
||||
String json2 = IOUtils.toString(new URL("http://localhost:9000/train/ss2/overview/data"), StandardCharsets.UTF_8);
|
||||
// System.out.println(json2);
|
||||
String json2 = IOUtils.toString(new URL("http://localhost:9000/train/ss2/overview/data"),
|
||||
StandardCharsets.UTF_8);
|
||||
|
||||
assertNotEquals(json1, json2);
|
||||
|
||||
|
@ -336,11 +252,106 @@ public class TestVertxUI extends BaseDL4JTest {
|
|||
}
|
||||
|
||||
@Test
|
||||
public void testUIServerStop() {
|
||||
public void testUIServerStop() throws Exception {
|
||||
UIServer uiServer = UIServer.getInstance(true, null);
|
||||
assertTrue(uiServer.isMultiSession());
|
||||
assertFalse(uiServer.isStopped());
|
||||
|
||||
long sleepMilliseconds = 1_000;
|
||||
log.info("Waiting {} ms before stopping.", sleepMilliseconds);
|
||||
Thread.sleep(sleepMilliseconds);
|
||||
uiServer.stop();
|
||||
assertTrue(uiServer.isStopped());
|
||||
|
||||
log.info("UI server is stopped. Waiting {} ms before starting new UI server.", sleepMilliseconds);
|
||||
Thread.sleep(sleepMilliseconds);
|
||||
uiServer = UIServer.getInstance(false, null);
|
||||
assertFalse(uiServer.isMultiSession());
|
||||
assertFalse(uiServer.isStopped());
|
||||
|
||||
log.info("Waiting {} ms before stopping.", sleepMilliseconds);
|
||||
Thread.sleep(sleepMilliseconds);
|
||||
uiServer.stop();
|
||||
assertTrue(uiServer.isStopped());
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testUIServerStopAsync() throws Exception {
|
||||
UIServer uiServer = UIServer.getInstance(true, null);
|
||||
assertTrue(uiServer.isMultiSession());
|
||||
assertFalse(uiServer.isStopped());
|
||||
|
||||
long sleepMilliseconds = 1_000;
|
||||
log.info("Waiting {} ms before stopping.", sleepMilliseconds);
|
||||
Thread.sleep(sleepMilliseconds);
|
||||
|
||||
CountDownLatch latch = new CountDownLatch(1);
|
||||
Promise<Void> promise = Promise.promise();
|
||||
promise.future().compose(
|
||||
success -> Future.future(prom -> latch.countDown()),
|
||||
failure -> Future.future(prom -> latch.countDown())
|
||||
);
|
||||
|
||||
uiServer.stopAsync(promise);
|
||||
latch.await();
|
||||
assertTrue(uiServer.isStopped());
|
||||
|
||||
log.info("UI server is stopped. Waiting {} ms before starting new UI server.", sleepMilliseconds);
|
||||
Thread.sleep(sleepMilliseconds);
|
||||
uiServer = UIServer.getInstance(false, null);
|
||||
assertFalse(uiServer.isMultiSession());
|
||||
|
||||
log.info("Waiting {} ms before stopping.", sleepMilliseconds);
|
||||
Thread.sleep(sleepMilliseconds);
|
||||
uiServer.stop();
|
||||
}
|
||||
|
||||
@Test (expected = DL4JException.class)
|
||||
public void testUIStartPortAlreadyBound() throws InterruptedException {
|
||||
CountDownLatch latch = new CountDownLatch(1);
|
||||
//Create HttpServer that binds the same port
|
||||
int port = VertxUIServer.DEFAULT_UI_PORT;
|
||||
Vertx vertx = Vertx.vertx();
|
||||
vertx.createHttpServer()
|
||||
.requestHandler(event -> {})
|
||||
.listen(port, result -> latch.countDown());
|
||||
latch.await();
|
||||
|
||||
try {
|
||||
//DL4JException signals that the port cannot be bound, UI server cannot start
|
||||
UIServer.getInstance();
|
||||
} finally {
|
||||
vertx.close();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testUIStartAsync() throws InterruptedException {
|
||||
CountDownLatch latch = new CountDownLatch(1);
|
||||
Promise<String> promise = Promise.promise();
|
||||
promise.future().compose(
|
||||
success -> Future.future(prom -> latch.countDown()),
|
||||
failure -> Future.future(prom -> latch.countDown())
|
||||
);
|
||||
int port = VertxUIServer.DEFAULT_UI_PORT;
|
||||
VertxUIServer.getInstance(port, false, null, promise);
|
||||
latch.await();
|
||||
if (promise.future().succeeded()) {
|
||||
String deploymentId = promise.future().result();
|
||||
log.debug("UI server deployed, deployment ID = {}", deploymentId);
|
||||
} else {
|
||||
log.debug("UI server failed to deploy.", promise.future().cause());
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testUIAutoStopOnThreadExit() throws InterruptedException {
|
||||
AtomicReference<UIServer> uiServer = new AtomicReference<>();
|
||||
Thread thread = new Thread(() -> uiServer.set(UIServer.getInstance()));
|
||||
thread.start();
|
||||
thread.join();
|
||||
Thread.sleep(1_000);
|
||||
assertTrue(uiServer.get().isStopped());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,272 @@
|
|||
package org.deeplearning4j.ui;
|
||||
|
||||
import io.netty.handler.codec.http.HttpResponseStatus;
|
||||
import io.vertx.core.Future;
|
||||
import io.vertx.core.Promise;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.api.storage.StatsStorage;
|
||||
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
|
||||
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
|
||||
import org.deeplearning4j.ui.api.UIServer;
|
||||
import org.deeplearning4j.ui.stats.StatsListener;
|
||||
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
|
||||
import org.junit.Ignore;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||
import org.nd4j.linalg.function.Function;
|
||||
import org.nd4j.linalg.learning.config.Sgd;
|
||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||
|
||||
import java.io.UnsupportedEncodingException;
|
||||
import java.net.HttpURLConnection;
|
||||
import java.net.URL;
|
||||
import java.net.URLEncoder;
|
||||
import java.util.HashMap;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
@Slf4j
|
||||
@Ignore
|
||||
public class TestVertxUIManual extends BaseDL4JTest {
|
||||
|
||||
@Override
|
||||
public long getTimeoutMilliseconds() {
|
||||
return 3600_000L;
|
||||
}
|
||||
|
||||
@Test
|
||||
@Ignore
|
||||
public void testUI() throws Exception {
|
||||
VertxUIServer uiServer = (VertxUIServer) UIServer.getInstance();
|
||||
assertEquals(9000, uiServer.getPort());
|
||||
|
||||
Thread.sleep(3000_000);
|
||||
uiServer.stop();
|
||||
}
|
||||
|
||||
@Test
|
||||
@Ignore
|
||||
public void testUISequentialSessions() throws Exception {
|
||||
UIServer uiServer = UIServer.getInstance();
|
||||
StatsStorage ss = null;
|
||||
for (int session = 0; session < 3; session++) {
|
||||
|
||||
if (ss != null) {
|
||||
uiServer.detach(ss);
|
||||
}
|
||||
ss = new InMemoryStatsStorage();
|
||||
uiServer.attach(ss);
|
||||
|
||||
int numInputs = 4;
|
||||
int outputNum = 3;
|
||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||
.activation(Activation.TANH)
|
||||
.weightInit(WeightInit.XAVIER)
|
||||
.updater(new Sgd(0.03))
|
||||
.l2(1e-4)
|
||||
.list()
|
||||
.layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(3)
|
||||
.build())
|
||||
.layer(1, new DenseLayer.Builder().nIn(3).nOut(3)
|
||||
.build())
|
||||
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
|
||||
.activation(Activation.SOFTMAX)
|
||||
.nIn(3).nOut(outputNum).build())
|
||||
.build();
|
||||
|
||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||
net.init();
|
||||
net.setListeners(new StatsListener(ss), new ScoreIterationListener(1));
|
||||
|
||||
DataSetIterator iter = new IrisDataSetIterator(150, 150);
|
||||
|
||||
for (int i = 0; i < 100; i++) {
|
||||
net.fit(iter);
|
||||
}
|
||||
Thread.sleep(5_000);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@Ignore
|
||||
public void testUIServerStop() throws Exception {
|
||||
UIServer uiServer = UIServer.getInstance(true, null);
|
||||
assertTrue(uiServer.isMultiSession());
|
||||
assertFalse(uiServer.isStopped());
|
||||
|
||||
long sleepMilliseconds = 30_000;
|
||||
log.info("Waiting {} ms before stopping.", sleepMilliseconds);
|
||||
Thread.sleep(sleepMilliseconds);
|
||||
uiServer.stop();
|
||||
assertTrue(uiServer.isStopped());
|
||||
|
||||
log.info("UI server is stopped. Waiting {} ms before starting new UI server.", sleepMilliseconds);
|
||||
Thread.sleep(sleepMilliseconds);
|
||||
uiServer = UIServer.getInstance(false, null);
|
||||
assertFalse(uiServer.isMultiSession());
|
||||
assertFalse(uiServer.isStopped());
|
||||
|
||||
log.info("Waiting {} ms before stopping.", sleepMilliseconds);
|
||||
Thread.sleep(sleepMilliseconds);
|
||||
uiServer.stop();
|
||||
assertTrue(uiServer.isStopped());
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
@Ignore
|
||||
public void testUIServerStopAsync() throws Exception {
|
||||
UIServer uiServer = UIServer.getInstance(true, null);
|
||||
assertTrue(uiServer.isMultiSession());
|
||||
assertFalse(uiServer.isStopped());
|
||||
|
||||
long sleepMilliseconds = 30_000;
|
||||
log.info("Waiting {} ms before stopping.", sleepMilliseconds);
|
||||
Thread.sleep(sleepMilliseconds);
|
||||
|
||||
CountDownLatch latch = new CountDownLatch(1);
|
||||
Promise<Void> promise = Promise.promise();
|
||||
promise.future().compose(
|
||||
success -> Future.future(prom -> latch.countDown()),
|
||||
failure -> Future.future(prom -> latch.countDown())
|
||||
);
|
||||
|
||||
uiServer.stopAsync(promise);
|
||||
latch.await();
|
||||
assertTrue(uiServer.isStopped());
|
||||
|
||||
log.info("UI server is stopped. Waiting {} ms before starting new UI server.", sleepMilliseconds);
|
||||
Thread.sleep(sleepMilliseconds);
|
||||
uiServer = UIServer.getInstance(false, null);
|
||||
assertFalse(uiServer.isMultiSession());
|
||||
|
||||
log.info("Waiting {} ms before stopping.", sleepMilliseconds);
|
||||
Thread.sleep(sleepMilliseconds);
|
||||
uiServer.stop();
|
||||
}
|
||||
|
||||
@Test
|
||||
@Ignore
|
||||
public void testUIAutoAttachDetach() throws Exception {
|
||||
long detachTimeoutMillis = 15_000;
|
||||
AutoDetachingStatsStorageProvider statsProvider = new AutoDetachingStatsStorageProvider(detachTimeoutMillis);
|
||||
UIServer uIServer = UIServer.getInstance(true, statsProvider);
|
||||
statsProvider.setUIServer(uIServer);
|
||||
InMemoryStatsStorage ss = null;
|
||||
for (int session = 0; session < 3; session++) {
|
||||
int layerSize = session + 4;
|
||||
|
||||
ss = new InMemoryStatsStorage();
|
||||
String sessionId = Integer.toString(session);
|
||||
statsProvider.put(sessionId, ss);
|
||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list()
|
||||
.layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(layerSize).build())
|
||||
.layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
|
||||
.activation(Activation.SOFTMAX).nIn(layerSize).nOut(3).build())
|
||||
.build();
|
||||
|
||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||
net.init();
|
||||
|
||||
StatsListener statsListener = new StatsListener(ss, 1);
|
||||
statsListener.setSessionID(sessionId);
|
||||
net.setListeners(statsListener, new ScoreIterationListener(1));
|
||||
uIServer.attach(ss);
|
||||
|
||||
DataSetIterator iter = new IrisDataSetIterator(150, 150);
|
||||
|
||||
for (int i = 0; i < 20; i++) {
|
||||
net.fit(iter);
|
||||
}
|
||||
|
||||
assertTrue(uIServer.isAttached(ss));
|
||||
uIServer.detach(ss);
|
||||
assertFalse(uIServer.isAttached(ss));
|
||||
|
||||
/*
|
||||
* Visiting /train/:sessionId to auto-attach StatsStorage
|
||||
*/
|
||||
String sessionUrl = trainingSessionUrl(uIServer.getAddress(), sessionId);
|
||||
HttpURLConnection conn = (HttpURLConnection) new URL(sessionUrl).openConnection();
|
||||
conn.connect();
|
||||
|
||||
assertEquals(HttpResponseStatus.OK.code(), conn.getResponseCode());
|
||||
assertTrue(uIServer.isAttached(ss));
|
||||
}
|
||||
|
||||
Thread.sleep(detachTimeoutMillis + 60_000);
|
||||
assertFalse(uIServer.isAttached(ss));
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Get URL-encoded URL for training session on given server address
|
||||
* @param serverAddress server address
|
||||
* @param sessionId session ID
|
||||
* @return URL
|
||||
* @throws UnsupportedEncodingException if the used encoding is not supported
|
||||
*/
|
||||
private static String trainingSessionUrl(String serverAddress, String sessionId)
|
||||
throws UnsupportedEncodingException {
|
||||
return String.format("%s/train/%s", serverAddress, URLEncoder.encode(sessionId, "UTF-8"));
|
||||
}
|
||||
|
||||
/**
|
||||
* StatsStorage provider with automatic detaching of StatsStorage after a timeout
|
||||
* @author Tamas Fenyvesi
|
||||
*/
|
||||
private static class AutoDetachingStatsStorageProvider implements Function<String, StatsStorage> {
|
||||
HashMap<String, InMemoryStatsStorage> storageForSession = new HashMap<>();
|
||||
UIServer uIServer;
|
||||
long autoDetachTimeoutMillis;
|
||||
|
||||
public AutoDetachingStatsStorageProvider(long autoDetachTimeoutMillis) {
|
||||
this.autoDetachTimeoutMillis = autoDetachTimeoutMillis;
|
||||
}
|
||||
|
||||
public void put(String sessionId, InMemoryStatsStorage statsStorage) {
|
||||
storageForSession.put(sessionId, statsStorage);
|
||||
}
|
||||
|
||||
public void setUIServer(UIServer uIServer) {
|
||||
this.uIServer = uIServer;
|
||||
}
|
||||
|
||||
@Override
|
||||
public StatsStorage apply(String sessionId) {
|
||||
StatsStorage statsStorage = storageForSession.get(sessionId);
|
||||
|
||||
if (statsStorage != null) {
|
||||
new Thread(() -> {
|
||||
try {
|
||||
log.info("Waiting to detach StatsStorage (session ID: {})" +
|
||||
" after {} ms ", sessionId, autoDetachTimeoutMillis);
|
||||
Thread.sleep(autoDetachTimeoutMillis);
|
||||
} catch (InterruptedException e) {
|
||||
e.printStackTrace();
|
||||
} finally {
|
||||
log.info("Auto-detaching StatsStorage (session ID: {}) after {} ms.",
|
||||
sessionId, autoDetachTimeoutMillis);
|
||||
uIServer.detach(statsStorage);
|
||||
log.info(" To re-attach StatsStorage of training session, visit {}}/train/{}",
|
||||
uIServer.getAddress(), sessionId);
|
||||
}
|
||||
}).start();
|
||||
}
|
||||
|
||||
return statsStorage;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -18,9 +18,11 @@
|
|||
package org.deeplearning4j.ui;
|
||||
|
||||
import io.netty.handler.codec.http.HttpResponseStatus;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.api.storage.StatsStorage;
|
||||
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
|
||||
import org.deeplearning4j.exception.DL4JException;
|
||||
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
|
@ -32,7 +34,6 @@ import org.deeplearning4j.ui.api.UIServer;
|
|||
import org.deeplearning4j.ui.stats.StatsListener;
|
||||
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
|
||||
import org.junit.Before;
|
||||
import org.junit.Ignore;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||
|
@ -52,21 +53,22 @@ import static org.junit.Assert.*;
|
|||
/**
|
||||
* @author Tamas Fenyvesi
|
||||
*/
|
||||
@Ignore
|
||||
@Slf4j
|
||||
public class TestVertxUIMultiSession extends BaseDL4JTest {
|
||||
|
||||
@Before
|
||||
public void setUp() throws Exception {
|
||||
UIServer.stopInstance();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testUIMultiSession() throws Exception {
|
||||
|
||||
public void testUIMultiSessionParallelTraining() throws Exception {
|
||||
UIServer uIServer = UIServer.getInstance(true, null);
|
||||
HashMap<Thread, StatsStorage> statStorageForThread = new HashMap<>();
|
||||
HashMap<Thread, String> sessionIdForThread = new HashMap<>();
|
||||
|
||||
for (int session = 0; session < 10; session++) {
|
||||
int parallelTrainingCount = 10;
|
||||
for (int session = 0; session < parallelTrainingCount; session++) {
|
||||
|
||||
StatsStorage ss = new InMemoryStatsStorage();
|
||||
|
||||
|
@ -104,8 +106,6 @@ public class TestVertxUIMultiSession extends BaseDL4JTest {
|
|||
sessionIdForThread.put(training, sessionId);
|
||||
}
|
||||
|
||||
Thread.sleep(10000000);
|
||||
|
||||
for (Thread thread: statStorageForThread.keySet()) {
|
||||
StatsStorage ss = statStorageForThread.get(thread);
|
||||
String sessionId = sessionIdForThread.get(thread);
|
||||
|
@ -120,7 +120,7 @@ public class TestVertxUIMultiSession extends BaseDL4JTest {
|
|||
|
||||
assertEquals(HttpResponseStatus.OK.code(), conn.getResponseCode());
|
||||
assertTrue(uIServer.isAttached(ss));
|
||||
} catch (InterruptedException | IOException e) {
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
fail(e.getMessage());
|
||||
} finally {
|
||||
|
@ -128,7 +128,6 @@ public class TestVertxUIMultiSession extends BaseDL4JTest {
|
|||
assertFalse(uIServer.isAttached(ss));
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -181,61 +180,7 @@ public class TestVertxUIMultiSession extends BaseDL4JTest {
|
|||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@Ignore
|
||||
public void testUIAutoAttachDetach() throws Exception {
|
||||
|
||||
long autoDetachTimeoutMillis = 30_000;
|
||||
AutoDetachingStatsStorageProvider statsProvider = new AutoDetachingStatsStorageProvider(autoDetachTimeoutMillis);
|
||||
UIServer uIServer = UIServer.getInstance(true, statsProvider);
|
||||
statsProvider.setUIServer(uIServer);
|
||||
|
||||
for (int session = 0; session < 3; session++) {
|
||||
int layerSize = session + 4;
|
||||
|
||||
InMemoryStatsStorage ss = new InMemoryStatsStorage();
|
||||
String sessionId = Integer.toString(session);
|
||||
statsProvider.put(sessionId, ss);
|
||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list()
|
||||
.layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(layerSize).build())
|
||||
.layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
|
||||
.activation(Activation.SOFTMAX).nIn(layerSize).nOut(3).build())
|
||||
.build();
|
||||
|
||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||
net.init();
|
||||
|
||||
StatsListener statsListener = new StatsListener(ss, 1);
|
||||
statsListener.setSessionID(sessionId);
|
||||
net.setListeners(statsListener, new ScoreIterationListener(1));
|
||||
uIServer.attach(ss);
|
||||
|
||||
DataSetIterator iter = new IrisDataSetIterator(150, 150);
|
||||
|
||||
for (int i = 0; i < 20; i++) {
|
||||
net.fit(iter);
|
||||
}
|
||||
|
||||
assertTrue(uIServer.isAttached(ss));
|
||||
uIServer.detach(ss);
|
||||
assertFalse(uIServer.isAttached(ss));
|
||||
|
||||
/*
|
||||
* Visiting /train/:sessionId to auto-attach StatsStorage
|
||||
*/
|
||||
String sessionUrl = trainingSessionUrl(uIServer.getAddress(), sessionId);
|
||||
HttpURLConnection conn = (HttpURLConnection) new URL(sessionUrl).openConnection();
|
||||
conn.connect();
|
||||
|
||||
assertEquals(HttpResponseStatus.OK.code(), conn.getResponseCode());
|
||||
assertTrue(uIServer.isAttached(ss));
|
||||
}
|
||||
|
||||
Thread.sleep(1_000_000);
|
||||
}
|
||||
|
||||
@Test (expected = RuntimeException.class)
|
||||
@Test (expected = DL4JException.class)
|
||||
public void testUIServerGetInstanceMultipleCalls1() {
|
||||
UIServer uiServer = UIServer.getInstance();
|
||||
assertFalse(uiServer.isMultiSession());
|
||||
|
@ -243,7 +188,7 @@ public class TestVertxUIMultiSession extends BaseDL4JTest {
|
|||
|
||||
}
|
||||
|
||||
@Test (expected = RuntimeException.class)
|
||||
@Test (expected = DL4JException.class)
|
||||
public void testUIServerGetInstanceMultipleCalls2() {
|
||||
UIServer uiServer = UIServer.getInstance(true, null);
|
||||
assertTrue(uiServer.isMultiSession());
|
||||
|
@ -257,55 +202,9 @@ public class TestVertxUIMultiSession extends BaseDL4JTest {
|
|||
* @return URL
|
||||
* @throws UnsupportedEncodingException if the used encoding is not supported
|
||||
*/
|
||||
private static String trainingSessionUrl(String serverAddress, String sessionId) throws UnsupportedEncodingException {
|
||||
private static String trainingSessionUrl(String serverAddress, String sessionId)
|
||||
throws UnsupportedEncodingException {
|
||||
return String.format("%s/train/%s", serverAddress, URLEncoder.encode(sessionId, "UTF-8"));
|
||||
}
|
||||
|
||||
/**
|
||||
* StatsStorage provider with automatic detaching of StatsStorage after a timeout
|
||||
* @author fenyvesit
|
||||
*
|
||||
*/
|
||||
private static class AutoDetachingStatsStorageProvider implements Function<String, StatsStorage> {
|
||||
HashMap<String, InMemoryStatsStorage> storageForSession = new HashMap<>();
|
||||
UIServer uIServer;
|
||||
long autoDetachTimeoutMillis;
|
||||
|
||||
public AutoDetachingStatsStorageProvider(long autoDetachTimeoutMillis) {
|
||||
this.autoDetachTimeoutMillis = autoDetachTimeoutMillis;
|
||||
}
|
||||
|
||||
public void put(String sessionId, InMemoryStatsStorage statsStorage) {
|
||||
storageForSession.put(sessionId, statsStorage);
|
||||
}
|
||||
|
||||
public void setUIServer(UIServer uIServer) {
|
||||
this.uIServer = uIServer;
|
||||
}
|
||||
|
||||
@Override
|
||||
public StatsStorage apply(String sessionId) {
|
||||
StatsStorage statsStorage = storageForSession.get(sessionId);
|
||||
|
||||
if (statsStorage != null) {
|
||||
new Thread(() -> {
|
||||
try {
|
||||
System.out.println("Waiting to detach StatsStorage (session ID: " + sessionId + ")" +
|
||||
" after " + autoDetachTimeoutMillis + " ms ");
|
||||
Thread.sleep(autoDetachTimeoutMillis);
|
||||
} catch (InterruptedException e) {
|
||||
e.printStackTrace();
|
||||
} finally {
|
||||
System.out.println("Auto-detaching StatsStorage (session ID: " + sessionId + ") after " +
|
||||
autoDetachTimeoutMillis + " ms.");
|
||||
uIServer.detach(statsStorage);
|
||||
System.out.println(" To re-attach StatsStorage of training session, visit " +
|
||||
uIServer.getAddress() + "/train/" + sessionId);
|
||||
}
|
||||
}).start();
|
||||
}
|
||||
|
||||
return statsStorage;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue