Merge & fix small test conflict

Signed-off-by: Alex Black <blacka101@gmail.com>
master
Alex Black 2020-04-23 10:31:15 +10:00
commit 0a1f75a5d9
7 changed files with 602 additions and 263 deletions

View File

@ -19,6 +19,8 @@ package org.deeplearning4j.ui;
import com.beust.jcommander.JCommander; import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter; import com.beust.jcommander.Parameter;
import io.vertx.core.AbstractVerticle; import io.vertx.core.AbstractVerticle;
import io.vertx.core.Future;
import io.vertx.core.Promise;
import io.vertx.core.Vertx; import io.vertx.core.Vertx;
import io.vertx.core.http.HttpServer; import io.vertx.core.http.HttpServer;
import io.vertx.core.http.impl.MimeMapping; import io.vertx.core.http.impl.MimeMapping;
@ -35,6 +37,7 @@ import org.deeplearning4j.api.storage.StatsStorageEvent;
import org.deeplearning4j.api.storage.StatsStorageListener; import org.deeplearning4j.api.storage.StatsStorageListener;
import org.deeplearning4j.api.storage.StatsStorageRouter; import org.deeplearning4j.api.storage.StatsStorageRouter;
import org.deeplearning4j.config.DL4JSystemProperties; import org.deeplearning4j.config.DL4JSystemProperties;
import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.ui.api.Route; import org.deeplearning4j.ui.api.Route;
import org.deeplearning4j.ui.api.UIModule; import org.deeplearning4j.ui.api.UIModule;
import org.deeplearning4j.ui.api.UIServer; import org.deeplearning4j.ui.api.UIServer;
@ -72,35 +75,69 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
private static Function<String, StatsStorage> statsStorageProvider; private static Function<String, StatsStorage> statsStorageProvider;
private static Integer instancePort; private static Integer instancePort;
private static Thread autoStopThread;
private TrainModule trainModule; private TrainModule trainModule;
public static VertxUIServer getInstance() { /**
return getInstance(null, multiSession.get(), null); * Get (and, initialize if necessary) the UI server. This synchronous function will wait until the server started.
* @param port TCP socket port for {@link HttpServer} to listen
* @param multiSession in multi-session mode, multiple training sessions can be visualized in separate browser tabs.
* <br/>URL path will include session ID as a parameter, i.e.: /train becomes /train/:sessionId
* @param statsStorageProvider function that returns a StatsStorage containing the given session ID.
* <br/>Use this to auto-attach StatsStorage if an unknown session ID is passed
* as URL path parameter in multi-session mode, or leave it {@code null}.
* @return UI instance for this JVM
* @throws DL4JException if UI server failed to start;
* if the instance has already started in a different mode (multi/single-session);
* if interrupted while waiting for completion
*/
public static VertxUIServer getInstance(Integer port, boolean multiSession,
Function<String, StatsStorage> statsStorageProvider) throws DL4JException {
return getInstance(port, multiSession, statsStorageProvider, null);
} }
public static VertxUIServer getInstance(Integer port, boolean multiSession, Function<String, StatsStorage> statsStorageProvider){ /**
*
* Get (and, initialize if necessary) the UI server. This function will wait until the server started
* (synchronous way), or pass the given callback to handle success or failure (asynchronous way).
* @param port TCP socket port for {@link HttpServer} to listen
* @param multiSession in multi-session mode, multiple training sessions can be visualized in separate browser tabs.
* <br/>URL path will include session ID as a parameter, i.e.: /train becomes /train/:sessionId
* @param statsStorageProvider function that returns a StatsStorage containing the given session ID.
* <br/>Use this to auto-attach StatsStorage if an unknown session ID is passed
* as URL path parameter in multi-session mode, or leave it {@code null}.
* @param startCallback asynchronous deployment handler callback that will be notify of success or failure.
* If {@code null} given, then this method will wait until deployment is complete.
* If the deployment is successful the result will contain a String representing the
* unique deployment ID of the deployment.
* @return UI server instance
* @throws DL4JException if UI server failed to start;
* if the instance has already started in a different mode (multi/single-session);
* if interrupted while waiting for completion
*/
public static VertxUIServer getInstance(Integer port, boolean multiSession,
Function<String, StatsStorage> statsStorageProvider, Promise<String> startCallback)
throws DL4JException {
if (instance == null || instance.isStopped()) { if (instance == null || instance.isStopped()) {
VertxUIServer.multiSession.set(multiSession); VertxUIServer.multiSession.set(multiSession);
VertxUIServer.setStatsStorageProvider(statsStorageProvider); VertxUIServer.setStatsStorageProvider(statsStorageProvider);
instancePort = port; instancePort = port;
Vertx vertx = Vertx.vertx();
//Launch UI server verticle and wait for it to start if (startCallback != null) {
CountDownLatch l = new CountDownLatch(1); //Launch UI server verticle and pass asynchronous callback that will be notified of completion
vertx.deployVerticle(VertxUIServer.class.getName(), res -> { deploy(startCallback);
l.countDown(); } else {
}); //Launch UI server verticle and wait for it to start
try { deploy();
l.await(5000, TimeUnit.MILLISECONDS); }
} catch (InterruptedException e){ } //Ignore
} else if (!instance.isStopped()) { } else if (!instance.isStopped()) {
if (instance.multiSession.get() && !instance.isMultiSession()) { if (multiSession && !instance.isMultiSession()) {
throw new RuntimeException("Cannot return multi-session instance." + throw new DL4JException("Cannot return multi-session instance." +
" UIServer has already started in single-session mode at " + instance.getAddress() + " UIServer has already started in single-session mode at " + instance.getAddress() +
" You may stop the UI server instance, and start a new one."); " You may stop the UI server instance, and start a new one.");
} else if (!instance.multiSession.get() && instance.isMultiSession()) { } else if (!multiSession && instance.isMultiSession()) {
throw new RuntimeException("Cannot return single-session instance." + throw new DL4JException("Cannot return single-session instance." +
" UIServer has already started in multi-session mode at " + instance.getAddress() + " UIServer has already started in multi-session mode at " + instance.getAddress() +
" You may stop the UI server instance, and start a new one."); " You may stop the UI server instance, and start a new one.");
} }
@ -109,6 +146,64 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
return instance; return instance;
} }
/**
* Deploy (start) {@link VertxUIServer}, waiting until starting is complete.
* @throws DL4JException if UI server failed to start;
* if interrupted while waiting for completion
*/
private static void deploy() throws DL4JException {
CountDownLatch l = new CountDownLatch(1);
Promise<String> promise = Promise.promise();
promise.future().compose(
success -> Future.future(prom -> l.countDown()),
failure -> Future.future(prom -> l.countDown())
);
deploy(promise);
// synchronous function
try {
l.await();
} catch (InterruptedException e) {
throw new DL4JException(e);
}
Future<String> future = promise.future();
if (future.failed()) {
throw new DL4JException("Deeplearning4j UI server failed to start.", future.cause());
}
}
/**
* Deploy (start) {@link VertxUIServer},
* and pass callback to handle successful or failed completion of deployment.
* @param startCallback promise that will handle success or failure of deployment.
* If the deployment is successful the result will contain a String representing the unique deployment ID of the
* deployment.
*/
private static void deploy(Promise<String> startCallback) {
log.debug("Deeplearning4j UI server is starting.");
Promise<String> promise = Promise.promise();
promise.future().compose(
success -> Future.future(prom -> startCallback.complete(success)),
failure -> Future.future(prom -> startCallback.fail(new RuntimeException(failure)))
);
Vertx vertx = Vertx.vertx();
vertx.deployVerticle(VertxUIServer.class.getName(), promise);
Thread currentThread = Thread.currentThread();
VertxUIServer.autoStopThread = new Thread(() -> {
try {
currentThread.join();
log.info("Deeplearning4j UI server is auto-stopping.");
if (VertxUIServer.instance != null && !VertxUIServer.instance.isStopped()) {
instance.stop();
}
} catch (InterruptedException e) {
log.error("Deeplearning4j UI server auto-stop thread was interrupted.", e);
}
});
}
private List<UIModule> uiModules = new CopyOnWriteArrayList<>(); private List<UIModule> uiModules = new CopyOnWriteArrayList<>();
private RemoteReceiverModule remoteReceiverModule; private RemoteReceiverModule remoteReceiverModule;
@ -132,10 +227,18 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
instance = this; instance = this;
} }
public static void stopInstance(){ public static void stopInstance() throws Exception {
if(instance == null) if(instance == null || instance.isStopped())
return; return;
instance.stop(); instance.stop();
VertxUIServer.reset();
}
private static void reset() {
VertxUIServer.instance = null;
VertxUIServer.statsStorageProvider = null;
VertxUIServer.instancePort = null;
VertxUIServer.multiSession.set(false);
} }
/** /**
@ -145,12 +248,14 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
public void autoAttachStatsStorageBySessionId(Function<String, StatsStorage> statsStorageProvider) { public void autoAttachStatsStorageBySessionId(Function<String, StatsStorage> statsStorageProvider) {
if (statsStorageProvider != null) { if (statsStorageProvider != null) {
this.statsStorageLoader = new StatsStorageLoader(statsStorageProvider); this.statsStorageLoader = new StatsStorageLoader(statsStorageProvider);
this.trainModule.setSessionLoader(this.statsStorageLoader); if (trainModule != null) {
this.trainModule.setSessionLoader(this.statsStorageLoader);
}
} }
} }
@Override @Override
public void start() throws Exception { public void start(Promise<Void> startCallback) throws Exception {
//Create REST endpoints //Create REST endpoints
File uploadDir = new File(System.getProperty("java.io.tmpdir"), "DL4JUI_" + System.currentTimeMillis()); File uploadDir = new File(System.getProperty("java.io.tmpdir"), "DL4JUI_" + System.currentTimeMillis());
uploadDir.mkdirs(); uploadDir.mkdirs();
@ -192,6 +297,9 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
}); });
} }
if (VertxUIServer.statsStorageProvider != null) {
autoAttachStatsStorageBySessionId(VertxUIServer.statsStorageProvider);
}
uiModules.add(new DefaultModule(isMultiSession())); //For: navigation page "/" uiModules.add(new DefaultModule(isMultiSession())); //For: navigation page "/"
trainModule = new TrainModule(isMultiSession(), statsStorageLoader, this::getAddress); trainModule = new TrainModule(isMultiSession(), statsStorageLoader, this::getAddress);
@ -246,17 +354,23 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
} }
} }
server = vertx.createHttpServer()
.requestHandler(r)
.listen(port);
uiEventRoutingThread = new Thread(new StatsEventRouterRunnable()); uiEventRoutingThread = new Thread(new StatsEventRouterRunnable());
uiEventRoutingThread.setDaemon(true); uiEventRoutingThread.setDaemon(true);
uiEventRoutingThread.start(); uiEventRoutingThread.start();
String address = UIServer.getInstance().getAddress(); server = vertx.createHttpServer()
log.info("Deeplearning4j UI server started at: {}", address); .requestHandler(r)
.listen(port, result -> {
if (result.succeeded()) {
String address = UIServer.getInstance().getAddress();
log.info("Deeplearning4j UI server started at: {}", address);
startCallback.complete();
} else {
startCallback.fail(new RuntimeException("Deeplearning4j UI server failed to listen on port "
+ server.actualPort(), result.cause()));
}
});
VertxUIServer.autoStopThread.start();
} }
private List<String> extractArgsFromRoute(String path, RoutingContext rc) { private List<String> extractArgsFromRoute(String path, RoutingContext rc) {
@ -302,11 +416,33 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
} }
@Override @Override
public void stop() { public void stop() throws InterruptedException {
server.close(); CountDownLatch l = new CountDownLatch(1);
shutdown.set(true); Promise<Void> promise = Promise.promise();
promise.future().compose(
successEvent -> Future.future(prom -> l.countDown()),
failureEvent -> Future.future(prom -> l.countDown())
);
stopAsync(promise);
// synchronous function should wait until the server is stopped
l.await();
} }
@Override
public void stopAsync(Promise<Void> stopCallback) {
/**
* Stop Vertx instance and release any resources held by it.
* Pass promise to {@link #stop(Promise)}.
*/
vertx.close(ar -> stopCallback.handle(ar));
}
@Override
public void stop(Promise<Void> stopCallback) {
shutdown.set(true);
stopCallback.complete();
log.info("Deeplearning4j UI server stopped.");
}
@Override @Override
public boolean isStopped() { public boolean isStopped() {
@ -353,8 +489,7 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
if (!statsStorageInstances.contains(statsStorage)) if (!statsStorageInstances.contains(statsStorage))
return; //No op return; //No op
boolean found = false; boolean found = false;
for (Iterator<Pair<StatsStorage, StatsStorageListener>> iterator = listeners.iterator(); iterator.hasNext(); ) { for (Pair<StatsStorage, StatsStorageListener> p : listeners) {
Pair<StatsStorage, StatsStorageListener> p = iterator.next();
if (p.getFirst() == statsStorage) { //Same object, not equality if (p.getFirst() == statsStorage) { //Same object, not equality
statsStorage.deregisterStatsStorageListener(p.getSecond()); statsStorage.deregisterStatsStorageListener(p.getSecond());
listeners.remove(p); listeners.remove(p);

View File

@ -17,8 +17,10 @@
package org.deeplearning4j.ui.api; package org.deeplearning4j.ui.api;
import io.vertx.core.Promise;
import org.deeplearning4j.api.storage.StatsStorage; import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.api.storage.StatsStorageRouter; import org.deeplearning4j.api.storage.StatsStorageRouter;
import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.ui.VertxUIServer; import org.deeplearning4j.ui.VertxUIServer;
import org.nd4j.linalg.function.Function; import org.nd4j.linalg.function.Function;
@ -32,18 +34,24 @@ import java.util.List;
public interface UIServer { public interface UIServer {
/** /**
* Get (and, initialize if necessary) the UI server. * Get (and, initialize if necessary) the UI server. This synchronous function will wait until the server started.
* Singleton pattern - all calls to getInstance() will return the same UI instance. * Singleton pattern - all calls to getInstance() will return the same UI instance.
* *
* @return UI instance for this JVM * @return UI instance for this JVM
* @throws RuntimeException if the instance has already started in a different mode (multi/single-session) * @throws DL4JException if UI server failed to start;
* if the instance has already started in a different mode (multi/single-session);
* if interrupted while waiting for completion
*/ */
static UIServer getInstance() throws RuntimeException { static UIServer getInstance() throws DL4JException {
return getInstance(false, null); if (VertxUIServer.getInstance() != null && !VertxUIServer.getInstance().isStopped()) {
return VertxUIServer.getInstance();
} else {
return getInstance(false, null);
}
} }
/** /**
* Get (and, initialize if necessary) the UI server. * Get (and, initialize if necessary) the UI server. This synchronous function will wait until the server started.
* Singleton pattern - all calls to getInstance() will return the same UI instance. * Singleton pattern - all calls to getInstance() will return the same UI instance.
* *
* @param multiSession in multi-session mode, multiple training sessions can be visualized in separate browser tabs. * @param multiSession in multi-session mode, multiple training sessions can be visualized in separate browser tabs.
@ -52,16 +60,19 @@ public interface UIServer {
* <br/>Use this to auto-attach StatsStorage if an unknown session ID is passed * <br/>Use this to auto-attach StatsStorage if an unknown session ID is passed
* as URL path parameter in multi-session mode, or leave it {@code null}. * as URL path parameter in multi-session mode, or leave it {@code null}.
* @return UI instance for this JVM * @return UI instance for this JVM
* @throws RuntimeException if the instance has already started in a different mode (multi/single-session) * @throws DL4JException if UI server failed to start;
* if the instance has already started in a different mode (multi/single-session);
* if interrupted while waiting for completion
*/ */
static UIServer getInstance(boolean multiSession, Function<String, StatsStorage> statsStorageProvider) throws RuntimeException { static UIServer getInstance(boolean multiSession, Function<String, StatsStorage> statsStorageProvider)
throws DL4JException {
return VertxUIServer.getInstance(null, multiSession, statsStorageProvider); return VertxUIServer.getInstance(null, multiSession, statsStorageProvider);
} }
/** /**
* Stop UIServer instance, if already running * Stop UIServer instance, if already running
*/ */
static void stopInstance() { static void stopInstance() throws Exception {
VertxUIServer.stopInstance(); VertxUIServer.stopInstance();
} }
@ -144,8 +155,15 @@ public interface UIServer {
boolean isRemoteListenerEnabled(); boolean isRemoteListenerEnabled();
/** /**
* Stop/shut down the UI server. * Stop/shut down the UI server. This synchronous function should wait until the server is stopped.
*/ */
void stop(); void stop() throws Exception;
/**
* Stop/shut down the UI server.
* This asynchronous function should immediately return, and notify the callback {@link Promise} on completion:
* either call {@link Promise#complete} or {@link io.vertx.core.Promise#fail}.
* @param stopCallback callback {@link Promise} to notify on completion
*/
void stopAsync(Promise<Void> stopCallback);
} }

View File

@ -69,15 +69,19 @@ public class DefaultI18N implements I18N {
} }
/** /**
* Get instance for session (used in multi-session mode) * Get instance for session
* @param sessionId session * @param sessionId session ID for multi-session mode, leave it {@code null} for global instance
* @return instance for session * @return instance for session, or global instance
*/ */
public static synchronized I18N getInstance(String sessionId) { public static synchronized I18N getInstance(String sessionId) {
if (!sessionInstances.containsKey(sessionId)) { if (sessionId == null) {
sessionInstances.put(sessionId, new DefaultI18N()); return getInstance();
} else {
if (!sessionInstances.containsKey(sessionId)) {
sessionInstances.put(sessionId, new DefaultI18N());
}
return sessionInstances.get(sessionId);
} }
return sessionInstances.get(sessionId);
} }
/** /**

View File

@ -137,7 +137,7 @@ public class TrainModule implements UIModule {
maxChartPoints = DEFAULT_MAX_CHART_POINTS; maxChartPoints = DEFAULT_MAX_CHART_POINTS;
} }
configuration = new Configuration(new Version(2, 3, 29)); configuration = new Configuration(new Version(2, 3, 23));
configuration.setDefaultEncoding("UTF-8"); configuration.setDefaultEncoding("UTF-8");
configuration.setLocale(Locale.US); configuration.setLocale(Locale.US);
configuration.setTemplateExceptionHandler(TemplateExceptionHandler.RETHROW_HANDLER); configuration.setTemplateExceptionHandler(TemplateExceptionHandler.RETHROW_HANDLER);
@ -199,6 +199,7 @@ public class TrainModule implements UIModule {
} }
})); }));
r.add(new Route("/train/:sessionId/info", HttpMethod.GET, (path, rc) -> this.sessionInfoForSession(path.get(0), rc))); r.add(new Route("/train/:sessionId/info", HttpMethod.GET, (path, rc) -> this.sessionInfoForSession(path.get(0), rc)));
r.add(new Route("/train/:sessionId/system/data", HttpMethod.GET, (path, rc) -> this.getSystemDataForSession(path.get(0), rc)));
} else { } else {
r.add(new Route("/train", HttpMethod.GET, (path, rc) -> rc.reroute("/train/overview"))); r.add(new Route("/train", HttpMethod.GET, (path, rc) -> rc.reroute("/train/overview")));
r.add(new Route("/train/sessions/current", HttpMethod.GET, (path, rc) -> rc.response().end(currentSessionID == null ? "" : currentSessionID))); r.add(new Route("/train/sessions/current", HttpMethod.GET, (path, rc) -> rc.response().end(currentSessionID == null ? "" : currentSessionID)));
@ -226,7 +227,9 @@ public class TrainModule implements UIModule {
* @param rc Routing context * @param rc Routing context
*/ */
private void renderFtl(String file, RoutingContext rc) { private void renderFtl(String file, RoutingContext rc) {
Map<String, String> input = DefaultI18N.getInstance().getMessages(DefaultI18N.getInstance().getDefaultLanguage()); String sessionId = rc.request().getParam("sessionID");
String langCode = DefaultI18N.getInstance(sessionId).getDefaultLanguage();
Map<String, String> input = DefaultI18N.getInstance().getMessages(langCode);
String html; String html;
try { try {
String content = FileUtils.readFileToString(Resources.asFile("templates/" + file), StandardCharsets.UTF_8); String content = FileUtils.readFileToString(Resources.asFile("templates/" + file), StandardCharsets.UTF_8);

View File

@ -17,10 +17,15 @@
package org.deeplearning4j.ui; package org.deeplearning4j.ui;
import io.vertx.core.Future;
import io.vertx.core.Promise;
import io.vertx.core.Vertx;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils; import org.apache.commons.io.IOUtils;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.api.storage.StatsStorage; import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
@ -50,63 +55,31 @@ import java.net.URL;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;
import static org.junit.Assert.*; import static org.junit.Assert.*;
/** /**
* Created by Alex on 08/10/2016. * Created by Alex on 08/10/2016.
*/ */
@Slf4j
@Ignore @Ignore
public class TestVertxUI extends BaseDL4JTest { public class TestVertxUI extends BaseDL4JTest {
@Before @Before
public void setUp() throws Exception { public void setUp() throws Exception {
UIServer.stopInstance(); UIServer.stopInstance();
} }
@Test @Test
@Ignore
public void testUI() throws Exception { public void testUI() throws Exception {
StatsStorage ss = new InMemoryStatsStorage();
VertxUIServer uiServer = (VertxUIServer) UIServer.getInstance(); VertxUIServer uiServer = (VertxUIServer) UIServer.getInstance();
assertEquals(9000, uiServer.getPort()); assertEquals(9000, uiServer.getPort());
uiServer.stop(); uiServer.stop();
VertxUIServer vertxUIServer = new VertxUIServer();
// vertxUIServer.runMain(new String[] {"--uiPort", "9100", "-r", "true"});
//
// assertEquals(9100, vertxUIServer.getPort());
// vertxUIServer.stop();
// uiServer.attach(ss);
//
// MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
// .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
// .list()
// .layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(4).build())
// .layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(4).nOut(3).build())
// .build();
//
// MultiLayerNetwork net = new MultiLayerNetwork(conf);
// net.init();
// net.setListeners(new StatsListener(ss, 3), new ScoreIterationListener(1));
//
// DataSetIterator iter = new IrisDataSetIterator(150, 150);
//
// for (int i = 0; i < 500; i++) {
// net.fit(iter);
//// Thread.sleep(100);
// Thread.sleep(100);
// }
//
//// uiServer.stop();
Thread.sleep(100000);
} }
@Test @Test
@Ignore
public void testUI_VAE() throws Exception { public void testUI_VAE() throws Exception {
//Variational autoencoder - for unsupervised layerwise pretraining //Variational autoencoder - for unsupervised layerwise pretraining
@ -144,13 +117,9 @@ public class TestVertxUI extends BaseDL4JTest {
Thread.sleep(100); Thread.sleep(100);
} }
Thread.sleep(100000);
} }
@Test @Test
@Ignore
public void testUIMultipleSessions() throws Exception { public void testUIMultipleSessions() throws Exception {
for (int session = 0; session < 3; session++) { for (int session = 0; session < 3; session++) {
@ -178,60 +147,10 @@ public class TestVertxUI extends BaseDL4JTest {
Thread.sleep(100); Thread.sleep(100);
} }
} }
Thread.sleep(1000000);
}
@Test
@Ignore
public void testUISequentialSessions() throws Exception {
UIServer uiServer = UIServer.getInstance();
StatsStorage ss = null;
for (int session = 0; session < 3; session++) {
if (ss != null) {
uiServer.detach(ss);
}
ss = new InMemoryStatsStorage();
uiServer.attach(ss);
int numInputs = 4;
int outputNum = 3;
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.activation(Activation.TANH)
.weightInit(WeightInit.XAVIER)
.updater(new Sgd(0.03))
.l2(1e-4)
.list()
.layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(3)
.build())
.layer(1, new DenseLayer.Builder().nIn(3).nOut(3)
.build())
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.activation(Activation.SOFTMAX)
.nIn(3).nOut(outputNum).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
net.setListeners(new StatsListener(ss), new ScoreIterationListener(1));
DataSetIterator iter = new IrisDataSetIterator(150, 150);
for (int i = 0; i < 1000; i++) {
net.fit(iter);
}
Thread.sleep(5000);
}
Thread.sleep(1000000);
} }
@Test @Test
@Ignore public void testUICompGraph() {
public void testUICompGraph() throws Exception {
StatsStorage ss = new InMemoryStatsStorage(); StatsStorage ss = new InMemoryStatsStorage();
@ -254,10 +173,7 @@ public class TestVertxUI extends BaseDL4JTest {
for (int i = 0; i < 100; i++) { for (int i = 0; i < 100; i++) {
net.fit(iter); net.fit(iter);
Thread.sleep(100);
} }
Thread.sleep(1000000);
} }
@Test @Test
@ -304,11 +220,11 @@ public class TestVertxUI extends BaseDL4JTest {
} }
}); });
String json1 = IOUtils.toString(new URL("http://localhost:9000/train/ss1/overview/data"), StandardCharsets.UTF_8); String json1 = IOUtils.toString(new URL("http://localhost:9000/train/ss1/overview/data"),
// System.out.println(json1); StandardCharsets.UTF_8);
String json2 = IOUtils.toString(new URL("http://localhost:9000/train/ss2/overview/data"), StandardCharsets.UTF_8); String json2 = IOUtils.toString(new URL("http://localhost:9000/train/ss2/overview/data"),
// System.out.println(json2); StandardCharsets.UTF_8);
assertNotEquals(json1, json2); assertNotEquals(json1, json2);
@ -336,11 +252,106 @@ public class TestVertxUI extends BaseDL4JTest {
} }
@Test @Test
public void testUIServerStop() { public void testUIServerStop() throws Exception {
UIServer uiServer = UIServer.getInstance(true, null); UIServer uiServer = UIServer.getInstance(true, null);
assertTrue(uiServer.isMultiSession()); assertTrue(uiServer.isMultiSession());
assertFalse(uiServer.isStopped());
long sleepMilliseconds = 1_000;
log.info("Waiting {} ms before stopping.", sleepMilliseconds);
Thread.sleep(sleepMilliseconds);
uiServer.stop(); uiServer.stop();
assertTrue(uiServer.isStopped());
log.info("UI server is stopped. Waiting {} ms before starting new UI server.", sleepMilliseconds);
Thread.sleep(sleepMilliseconds);
uiServer = UIServer.getInstance(false, null); uiServer = UIServer.getInstance(false, null);
assertFalse(uiServer.isMultiSession()); assertFalse(uiServer.isMultiSession());
assertFalse(uiServer.isStopped());
log.info("Waiting {} ms before stopping.", sleepMilliseconds);
Thread.sleep(sleepMilliseconds);
uiServer.stop();
assertTrue(uiServer.isStopped());
}
@Test
public void testUIServerStopAsync() throws Exception {
UIServer uiServer = UIServer.getInstance(true, null);
assertTrue(uiServer.isMultiSession());
assertFalse(uiServer.isStopped());
long sleepMilliseconds = 1_000;
log.info("Waiting {} ms before stopping.", sleepMilliseconds);
Thread.sleep(sleepMilliseconds);
CountDownLatch latch = new CountDownLatch(1);
Promise<Void> promise = Promise.promise();
promise.future().compose(
success -> Future.future(prom -> latch.countDown()),
failure -> Future.future(prom -> latch.countDown())
);
uiServer.stopAsync(promise);
latch.await();
assertTrue(uiServer.isStopped());
log.info("UI server is stopped. Waiting {} ms before starting new UI server.", sleepMilliseconds);
Thread.sleep(sleepMilliseconds);
uiServer = UIServer.getInstance(false, null);
assertFalse(uiServer.isMultiSession());
log.info("Waiting {} ms before stopping.", sleepMilliseconds);
Thread.sleep(sleepMilliseconds);
uiServer.stop();
}
@Test (expected = DL4JException.class)
public void testUIStartPortAlreadyBound() throws InterruptedException {
CountDownLatch latch = new CountDownLatch(1);
//Create HttpServer that binds the same port
int port = VertxUIServer.DEFAULT_UI_PORT;
Vertx vertx = Vertx.vertx();
vertx.createHttpServer()
.requestHandler(event -> {})
.listen(port, result -> latch.countDown());
latch.await();
try {
//DL4JException signals that the port cannot be bound, UI server cannot start
UIServer.getInstance();
} finally {
vertx.close();
}
}
@Test
public void testUIStartAsync() throws InterruptedException {
CountDownLatch latch = new CountDownLatch(1);
Promise<String> promise = Promise.promise();
promise.future().compose(
success -> Future.future(prom -> latch.countDown()),
failure -> Future.future(prom -> latch.countDown())
);
int port = VertxUIServer.DEFAULT_UI_PORT;
VertxUIServer.getInstance(port, false, null, promise);
latch.await();
if (promise.future().succeeded()) {
String deploymentId = promise.future().result();
log.debug("UI server deployed, deployment ID = {}", deploymentId);
} else {
log.debug("UI server failed to deploy.", promise.future().cause());
}
}
@Test
public void testUIAutoStopOnThreadExit() throws InterruptedException {
AtomicReference<UIServer> uiServer = new AtomicReference<>();
Thread thread = new Thread(() -> uiServer.set(UIServer.getInstance()));
thread.start();
thread.join();
Thread.sleep(1_000);
assertTrue(uiServer.get().isStopped());
} }
} }

View File

@ -0,0 +1,272 @@
package org.deeplearning4j.ui;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.vertx.core.Future;
import io.vertx.core.Promise;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.function.Function;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.UnsupportedEncodingException;
import java.net.HttpURLConnection;
import java.net.URL;
import java.net.URLEncoder;
import java.util.HashMap;
import java.util.concurrent.CountDownLatch;
import static org.junit.Assert.*;
@Slf4j
@Ignore
public class TestVertxUIManual extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 3600_000L;
}
@Test
@Ignore
public void testUI() throws Exception {
VertxUIServer uiServer = (VertxUIServer) UIServer.getInstance();
assertEquals(9000, uiServer.getPort());
Thread.sleep(3000_000);
uiServer.stop();
}
@Test
@Ignore
public void testUISequentialSessions() throws Exception {
UIServer uiServer = UIServer.getInstance();
StatsStorage ss = null;
for (int session = 0; session < 3; session++) {
if (ss != null) {
uiServer.detach(ss);
}
ss = new InMemoryStatsStorage();
uiServer.attach(ss);
int numInputs = 4;
int outputNum = 3;
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.activation(Activation.TANH)
.weightInit(WeightInit.XAVIER)
.updater(new Sgd(0.03))
.l2(1e-4)
.list()
.layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(3)
.build())
.layer(1, new DenseLayer.Builder().nIn(3).nOut(3)
.build())
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.activation(Activation.SOFTMAX)
.nIn(3).nOut(outputNum).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
net.setListeners(new StatsListener(ss), new ScoreIterationListener(1));
DataSetIterator iter = new IrisDataSetIterator(150, 150);
for (int i = 0; i < 100; i++) {
net.fit(iter);
}
Thread.sleep(5_000);
}
}
@Test
@Ignore
public void testUIServerStop() throws Exception {
UIServer uiServer = UIServer.getInstance(true, null);
assertTrue(uiServer.isMultiSession());
assertFalse(uiServer.isStopped());
long sleepMilliseconds = 30_000;
log.info("Waiting {} ms before stopping.", sleepMilliseconds);
Thread.sleep(sleepMilliseconds);
uiServer.stop();
assertTrue(uiServer.isStopped());
log.info("UI server is stopped. Waiting {} ms before starting new UI server.", sleepMilliseconds);
Thread.sleep(sleepMilliseconds);
uiServer = UIServer.getInstance(false, null);
assertFalse(uiServer.isMultiSession());
assertFalse(uiServer.isStopped());
log.info("Waiting {} ms before stopping.", sleepMilliseconds);
Thread.sleep(sleepMilliseconds);
uiServer.stop();
assertTrue(uiServer.isStopped());
}
@Test
@Ignore
public void testUIServerStopAsync() throws Exception {
UIServer uiServer = UIServer.getInstance(true, null);
assertTrue(uiServer.isMultiSession());
assertFalse(uiServer.isStopped());
long sleepMilliseconds = 30_000;
log.info("Waiting {} ms before stopping.", sleepMilliseconds);
Thread.sleep(sleepMilliseconds);
CountDownLatch latch = new CountDownLatch(1);
Promise<Void> promise = Promise.promise();
promise.future().compose(
success -> Future.future(prom -> latch.countDown()),
failure -> Future.future(prom -> latch.countDown())
);
uiServer.stopAsync(promise);
latch.await();
assertTrue(uiServer.isStopped());
log.info("UI server is stopped. Waiting {} ms before starting new UI server.", sleepMilliseconds);
Thread.sleep(sleepMilliseconds);
uiServer = UIServer.getInstance(false, null);
assertFalse(uiServer.isMultiSession());
log.info("Waiting {} ms before stopping.", sleepMilliseconds);
Thread.sleep(sleepMilliseconds);
uiServer.stop();
}
@Test
@Ignore
public void testUIAutoAttachDetach() throws Exception {
long detachTimeoutMillis = 15_000;
AutoDetachingStatsStorageProvider statsProvider = new AutoDetachingStatsStorageProvider(detachTimeoutMillis);
UIServer uIServer = UIServer.getInstance(true, statsProvider);
statsProvider.setUIServer(uIServer);
InMemoryStatsStorage ss = null;
for (int session = 0; session < 3; session++) {
int layerSize = session + 4;
ss = new InMemoryStatsStorage();
String sessionId = Integer.toString(session);
statsProvider.put(sessionId, ss);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list()
.layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(layerSize).build())
.layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(layerSize).nOut(3).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
StatsListener statsListener = new StatsListener(ss, 1);
statsListener.setSessionID(sessionId);
net.setListeners(statsListener, new ScoreIterationListener(1));
uIServer.attach(ss);
DataSetIterator iter = new IrisDataSetIterator(150, 150);
for (int i = 0; i < 20; i++) {
net.fit(iter);
}
assertTrue(uIServer.isAttached(ss));
uIServer.detach(ss);
assertFalse(uIServer.isAttached(ss));
/*
* Visiting /train/:sessionId to auto-attach StatsStorage
*/
String sessionUrl = trainingSessionUrl(uIServer.getAddress(), sessionId);
HttpURLConnection conn = (HttpURLConnection) new URL(sessionUrl).openConnection();
conn.connect();
assertEquals(HttpResponseStatus.OK.code(), conn.getResponseCode());
assertTrue(uIServer.isAttached(ss));
}
Thread.sleep(detachTimeoutMillis + 60_000);
assertFalse(uIServer.isAttached(ss));
}
/**
* Get URL-encoded URL for training session on given server address
* @param serverAddress server address
* @param sessionId session ID
* @return URL
* @throws UnsupportedEncodingException if the used encoding is not supported
*/
private static String trainingSessionUrl(String serverAddress, String sessionId)
throws UnsupportedEncodingException {
return String.format("%s/train/%s", serverAddress, URLEncoder.encode(sessionId, "UTF-8"));
}
/**
* StatsStorage provider with automatic detaching of StatsStorage after a timeout
* @author Tamas Fenyvesi
*/
private static class AutoDetachingStatsStorageProvider implements Function<String, StatsStorage> {
HashMap<String, InMemoryStatsStorage> storageForSession = new HashMap<>();
UIServer uIServer;
long autoDetachTimeoutMillis;
public AutoDetachingStatsStorageProvider(long autoDetachTimeoutMillis) {
this.autoDetachTimeoutMillis = autoDetachTimeoutMillis;
}
public void put(String sessionId, InMemoryStatsStorage statsStorage) {
storageForSession.put(sessionId, statsStorage);
}
public void setUIServer(UIServer uIServer) {
this.uIServer = uIServer;
}
@Override
public StatsStorage apply(String sessionId) {
StatsStorage statsStorage = storageForSession.get(sessionId);
if (statsStorage != null) {
new Thread(() -> {
try {
log.info("Waiting to detach StatsStorage (session ID: {})" +
" after {} ms ", sessionId, autoDetachTimeoutMillis);
Thread.sleep(autoDetachTimeoutMillis);
} catch (InterruptedException e) {
e.printStackTrace();
} finally {
log.info("Auto-detaching StatsStorage (session ID: {}) after {} ms.",
sessionId, autoDetachTimeoutMillis);
uIServer.detach(statsStorage);
log.info(" To re-attach StatsStorage of training session, visit {}}/train/{}",
uIServer.getAddress(), sessionId);
}
}).start();
}
return statsStorage;
}
}
}

View File

@ -22,6 +22,7 @@ import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.api.storage.StatsStorage; import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
@ -33,7 +34,6 @@ import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.stats.StatsListener; import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage; import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
import org.junit.Before; import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test; import org.junit.Test;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
@ -53,22 +53,22 @@ import static org.junit.Assert.*;
/** /**
* @author Tamas Fenyvesi * @author Tamas Fenyvesi
*/ */
@Ignore
@Slf4j @Slf4j
public class TestVertxUIMultiSession extends BaseDL4JTest { public class TestVertxUIMultiSession extends BaseDL4JTest {
@Before @Before
public void setUp() throws Exception { public void setUp() throws Exception {
UIServer.stopInstance(); UIServer.stopInstance();
} }
@Test @Test
public void testUIMultiSession() throws Exception { public void testUIMultiSessionParallelTraining() throws Exception {
UIServer uIServer = UIServer.getInstance(true, null); UIServer uIServer = UIServer.getInstance(true, null);
HashMap<Thread, StatsStorage> statStorageForThread = new HashMap<>(); HashMap<Thread, StatsStorage> statStorageForThread = new HashMap<>();
HashMap<Thread, String> sessionIdForThread = new HashMap<>(); HashMap<Thread, String> sessionIdForThread = new HashMap<>();
for (int session = 0; session < 10; session++) { int parallelTrainingCount = 10;
for (int session = 0; session < parallelTrainingCount; session++) {
StatsStorage ss = new InMemoryStatsStorage(); StatsStorage ss = new InMemoryStatsStorage();
@ -106,8 +106,6 @@ public class TestVertxUIMultiSession extends BaseDL4JTest {
sessionIdForThread.put(training, sessionId); sessionIdForThread.put(training, sessionId);
} }
Thread.sleep(10000000);
for (Thread thread: statStorageForThread.keySet()) { for (Thread thread: statStorageForThread.keySet()) {
StatsStorage ss = statStorageForThread.get(thread); StatsStorage ss = statStorageForThread.get(thread);
String sessionId = sessionIdForThread.get(thread); String sessionId = sessionIdForThread.get(thread);
@ -122,7 +120,7 @@ public class TestVertxUIMultiSession extends BaseDL4JTest {
assertEquals(HttpResponseStatus.OK.code(), conn.getResponseCode()); assertEquals(HttpResponseStatus.OK.code(), conn.getResponseCode());
assertTrue(uIServer.isAttached(ss)); assertTrue(uIServer.isAttached(ss));
} catch (InterruptedException | IOException e) { } catch (IOException e) {
log.error("",e); log.error("",e);
fail(e.getMessage()); fail(e.getMessage());
} finally { } finally {
@ -130,7 +128,6 @@ public class TestVertxUIMultiSession extends BaseDL4JTest {
assertFalse(uIServer.isAttached(ss)); assertFalse(uIServer.isAttached(ss));
} }
} }
} }
@Test @Test
@ -183,61 +180,7 @@ public class TestVertxUIMultiSession extends BaseDL4JTest {
} }
} }
@Test @Test (expected = DL4JException.class)
@Ignore
public void testUIAutoAttachDetach() throws Exception {
long autoDetachTimeoutMillis = 30_000;
AutoDetachingStatsStorageProvider statsProvider = new AutoDetachingStatsStorageProvider(autoDetachTimeoutMillis);
UIServer uIServer = UIServer.getInstance(true, statsProvider);
statsProvider.setUIServer(uIServer);
for (int session = 0; session < 3; session++) {
int layerSize = session + 4;
InMemoryStatsStorage ss = new InMemoryStatsStorage();
String sessionId = Integer.toString(session);
statsProvider.put(sessionId, ss);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list()
.layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(layerSize).build())
.layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(layerSize).nOut(3).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
StatsListener statsListener = new StatsListener(ss, 1);
statsListener.setSessionID(sessionId);
net.setListeners(statsListener, new ScoreIterationListener(1));
uIServer.attach(ss);
DataSetIterator iter = new IrisDataSetIterator(150, 150);
for (int i = 0; i < 20; i++) {
net.fit(iter);
}
assertTrue(uIServer.isAttached(ss));
uIServer.detach(ss);
assertFalse(uIServer.isAttached(ss));
/*
* Visiting /train/:sessionId to auto-attach StatsStorage
*/
String sessionUrl = trainingSessionUrl(uIServer.getAddress(), sessionId);
HttpURLConnection conn = (HttpURLConnection) new URL(sessionUrl).openConnection();
conn.connect();
assertEquals(HttpResponseStatus.OK.code(), conn.getResponseCode());
assertTrue(uIServer.isAttached(ss));
}
Thread.sleep(1_000_000);
}
@Test (expected = RuntimeException.class)
public void testUIServerGetInstanceMultipleCalls1() { public void testUIServerGetInstanceMultipleCalls1() {
UIServer uiServer = UIServer.getInstance(); UIServer uiServer = UIServer.getInstance();
assertFalse(uiServer.isMultiSession()); assertFalse(uiServer.isMultiSession());
@ -245,7 +188,7 @@ public class TestVertxUIMultiSession extends BaseDL4JTest {
} }
@Test (expected = RuntimeException.class) @Test (expected = DL4JException.class)
public void testUIServerGetInstanceMultipleCalls2() { public void testUIServerGetInstanceMultipleCalls2() {
UIServer uiServer = UIServer.getInstance(true, null); UIServer uiServer = UIServer.getInstance(true, null);
assertTrue(uiServer.isMultiSession()); assertTrue(uiServer.isMultiSession());
@ -259,55 +202,8 @@ public class TestVertxUIMultiSession extends BaseDL4JTest {
* @return URL * @return URL
* @throws UnsupportedEncodingException if the used encoding is not supported * @throws UnsupportedEncodingException if the used encoding is not supported
*/ */
private static String trainingSessionUrl(String serverAddress, String sessionId) throws UnsupportedEncodingException { private static String trainingSessionUrl(String serverAddress, String sessionId)
throws UnsupportedEncodingException {
return String.format("%s/train/%s", serverAddress, URLEncoder.encode(sessionId, "UTF-8")); return String.format("%s/train/%s", serverAddress, URLEncoder.encode(sessionId, "UTF-8"));
} }
/**
* StatsStorage provider with automatic detaching of StatsStorage after a timeout
* @author fenyvesit
*
*/
private static class AutoDetachingStatsStorageProvider implements Function<String, StatsStorage> {
HashMap<String, InMemoryStatsStorage> storageForSession = new HashMap<>();
UIServer uIServer;
long autoDetachTimeoutMillis;
public AutoDetachingStatsStorageProvider(long autoDetachTimeoutMillis) {
this.autoDetachTimeoutMillis = autoDetachTimeoutMillis;
}
public void put(String sessionId, InMemoryStatsStorage statsStorage) {
storageForSession.put(sessionId, statsStorage);
}
public void setUIServer(UIServer uIServer) {
this.uIServer = uIServer;
}
@Override
public StatsStorage apply(String sessionId) {
StatsStorage statsStorage = storageForSession.get(sessionId);
if (statsStorage != null) {
new Thread(() -> {
try {
System.out.println("Waiting to detach StatsStorage (session ID: " + sessionId + ")" +
" after " + autoDetachTimeoutMillis + " ms ");
Thread.sleep(autoDetachTimeoutMillis);
} catch (InterruptedException e) {
log.error("",e);
} finally {
System.out.println("Auto-detaching StatsStorage (session ID: " + sessionId + ") after " +
autoDetachTimeoutMillis + " ms.");
uIServer.detach(statsStorage);
System.out.println(" To re-attach StatsStorage of training session, visit " +
uIServer.getAddress() + "/train/" + sessionId);
}
}).start();
}
return statsStorage;
}
}
} }