Merge & fix small test conflict
Signed-off-by: Alex Black <blacka101@gmail.com>master
commit
0a1f75a5d9
|
@ -19,6 +19,8 @@ package org.deeplearning4j.ui;
|
||||||
import com.beust.jcommander.JCommander;
|
import com.beust.jcommander.JCommander;
|
||||||
import com.beust.jcommander.Parameter;
|
import com.beust.jcommander.Parameter;
|
||||||
import io.vertx.core.AbstractVerticle;
|
import io.vertx.core.AbstractVerticle;
|
||||||
|
import io.vertx.core.Future;
|
||||||
|
import io.vertx.core.Promise;
|
||||||
import io.vertx.core.Vertx;
|
import io.vertx.core.Vertx;
|
||||||
import io.vertx.core.http.HttpServer;
|
import io.vertx.core.http.HttpServer;
|
||||||
import io.vertx.core.http.impl.MimeMapping;
|
import io.vertx.core.http.impl.MimeMapping;
|
||||||
|
@ -35,6 +37,7 @@ import org.deeplearning4j.api.storage.StatsStorageEvent;
|
||||||
import org.deeplearning4j.api.storage.StatsStorageListener;
|
import org.deeplearning4j.api.storage.StatsStorageListener;
|
||||||
import org.deeplearning4j.api.storage.StatsStorageRouter;
|
import org.deeplearning4j.api.storage.StatsStorageRouter;
|
||||||
import org.deeplearning4j.config.DL4JSystemProperties;
|
import org.deeplearning4j.config.DL4JSystemProperties;
|
||||||
|
import org.deeplearning4j.exception.DL4JException;
|
||||||
import org.deeplearning4j.ui.api.Route;
|
import org.deeplearning4j.ui.api.Route;
|
||||||
import org.deeplearning4j.ui.api.UIModule;
|
import org.deeplearning4j.ui.api.UIModule;
|
||||||
import org.deeplearning4j.ui.api.UIServer;
|
import org.deeplearning4j.ui.api.UIServer;
|
||||||
|
@ -72,35 +75,69 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
||||||
private static Function<String, StatsStorage> statsStorageProvider;
|
private static Function<String, StatsStorage> statsStorageProvider;
|
||||||
|
|
||||||
private static Integer instancePort;
|
private static Integer instancePort;
|
||||||
|
private static Thread autoStopThread;
|
||||||
|
|
||||||
private TrainModule trainModule;
|
private TrainModule trainModule;
|
||||||
|
|
||||||
public static VertxUIServer getInstance() {
|
/**
|
||||||
return getInstance(null, multiSession.get(), null);
|
* Get (and, initialize if necessary) the UI server. This synchronous function will wait until the server started.
|
||||||
|
* @param port TCP socket port for {@link HttpServer} to listen
|
||||||
|
* @param multiSession in multi-session mode, multiple training sessions can be visualized in separate browser tabs.
|
||||||
|
* <br/>URL path will include session ID as a parameter, i.e.: /train becomes /train/:sessionId
|
||||||
|
* @param statsStorageProvider function that returns a StatsStorage containing the given session ID.
|
||||||
|
* <br/>Use this to auto-attach StatsStorage if an unknown session ID is passed
|
||||||
|
* as URL path parameter in multi-session mode, or leave it {@code null}.
|
||||||
|
* @return UI instance for this JVM
|
||||||
|
* @throws DL4JException if UI server failed to start;
|
||||||
|
* if the instance has already started in a different mode (multi/single-session);
|
||||||
|
* if interrupted while waiting for completion
|
||||||
|
*/
|
||||||
|
public static VertxUIServer getInstance(Integer port, boolean multiSession,
|
||||||
|
Function<String, StatsStorage> statsStorageProvider) throws DL4JException {
|
||||||
|
return getInstance(port, multiSession, statsStorageProvider, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static VertxUIServer getInstance(Integer port, boolean multiSession, Function<String, StatsStorage> statsStorageProvider){
|
/**
|
||||||
|
*
|
||||||
|
* Get (and, initialize if necessary) the UI server. This function will wait until the server started
|
||||||
|
* (synchronous way), or pass the given callback to handle success or failure (asynchronous way).
|
||||||
|
* @param port TCP socket port for {@link HttpServer} to listen
|
||||||
|
* @param multiSession in multi-session mode, multiple training sessions can be visualized in separate browser tabs.
|
||||||
|
* <br/>URL path will include session ID as a parameter, i.e.: /train becomes /train/:sessionId
|
||||||
|
* @param statsStorageProvider function that returns a StatsStorage containing the given session ID.
|
||||||
|
* <br/>Use this to auto-attach StatsStorage if an unknown session ID is passed
|
||||||
|
* as URL path parameter in multi-session mode, or leave it {@code null}.
|
||||||
|
* @param startCallback asynchronous deployment handler callback that will be notify of success or failure.
|
||||||
|
* If {@code null} given, then this method will wait until deployment is complete.
|
||||||
|
* If the deployment is successful the result will contain a String representing the
|
||||||
|
* unique deployment ID of the deployment.
|
||||||
|
* @return UI server instance
|
||||||
|
* @throws DL4JException if UI server failed to start;
|
||||||
|
* if the instance has already started in a different mode (multi/single-session);
|
||||||
|
* if interrupted while waiting for completion
|
||||||
|
*/
|
||||||
|
public static VertxUIServer getInstance(Integer port, boolean multiSession,
|
||||||
|
Function<String, StatsStorage> statsStorageProvider, Promise<String> startCallback)
|
||||||
|
throws DL4JException {
|
||||||
if (instance == null || instance.isStopped()) {
|
if (instance == null || instance.isStopped()) {
|
||||||
VertxUIServer.multiSession.set(multiSession);
|
VertxUIServer.multiSession.set(multiSession);
|
||||||
VertxUIServer.setStatsStorageProvider(statsStorageProvider);
|
VertxUIServer.setStatsStorageProvider(statsStorageProvider);
|
||||||
instancePort = port;
|
instancePort = port;
|
||||||
Vertx vertx = Vertx.vertx();
|
|
||||||
|
|
||||||
//Launch UI server verticle and wait for it to start
|
if (startCallback != null) {
|
||||||
CountDownLatch l = new CountDownLatch(1);
|
//Launch UI server verticle and pass asynchronous callback that will be notified of completion
|
||||||
vertx.deployVerticle(VertxUIServer.class.getName(), res -> {
|
deploy(startCallback);
|
||||||
l.countDown();
|
} else {
|
||||||
});
|
//Launch UI server verticle and wait for it to start
|
||||||
try {
|
deploy();
|
||||||
l.await(5000, TimeUnit.MILLISECONDS);
|
}
|
||||||
} catch (InterruptedException e){ } //Ignore
|
|
||||||
} else if (!instance.isStopped()) {
|
} else if (!instance.isStopped()) {
|
||||||
if (instance.multiSession.get() && !instance.isMultiSession()) {
|
if (multiSession && !instance.isMultiSession()) {
|
||||||
throw new RuntimeException("Cannot return multi-session instance." +
|
throw new DL4JException("Cannot return multi-session instance." +
|
||||||
" UIServer has already started in single-session mode at " + instance.getAddress() +
|
" UIServer has already started in single-session mode at " + instance.getAddress() +
|
||||||
" You may stop the UI server instance, and start a new one.");
|
" You may stop the UI server instance, and start a new one.");
|
||||||
} else if (!instance.multiSession.get() && instance.isMultiSession()) {
|
} else if (!multiSession && instance.isMultiSession()) {
|
||||||
throw new RuntimeException("Cannot return single-session instance." +
|
throw new DL4JException("Cannot return single-session instance." +
|
||||||
" UIServer has already started in multi-session mode at " + instance.getAddress() +
|
" UIServer has already started in multi-session mode at " + instance.getAddress() +
|
||||||
" You may stop the UI server instance, and start a new one.");
|
" You may stop the UI server instance, and start a new one.");
|
||||||
}
|
}
|
||||||
|
@ -109,6 +146,64 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
||||||
return instance;
|
return instance;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Deploy (start) {@link VertxUIServer}, waiting until starting is complete.
|
||||||
|
* @throws DL4JException if UI server failed to start;
|
||||||
|
* if interrupted while waiting for completion
|
||||||
|
*/
|
||||||
|
private static void deploy() throws DL4JException {
|
||||||
|
CountDownLatch l = new CountDownLatch(1);
|
||||||
|
Promise<String> promise = Promise.promise();
|
||||||
|
promise.future().compose(
|
||||||
|
success -> Future.future(prom -> l.countDown()),
|
||||||
|
failure -> Future.future(prom -> l.countDown())
|
||||||
|
);
|
||||||
|
deploy(promise);
|
||||||
|
// synchronous function
|
||||||
|
try {
|
||||||
|
l.await();
|
||||||
|
} catch (InterruptedException e) {
|
||||||
|
throw new DL4JException(e);
|
||||||
|
}
|
||||||
|
|
||||||
|
Future<String> future = promise.future();
|
||||||
|
if (future.failed()) {
|
||||||
|
throw new DL4JException("Deeplearning4j UI server failed to start.", future.cause());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Deploy (start) {@link VertxUIServer},
|
||||||
|
* and pass callback to handle successful or failed completion of deployment.
|
||||||
|
* @param startCallback promise that will handle success or failure of deployment.
|
||||||
|
* If the deployment is successful the result will contain a String representing the unique deployment ID of the
|
||||||
|
* deployment.
|
||||||
|
*/
|
||||||
|
private static void deploy(Promise<String> startCallback) {
|
||||||
|
log.debug("Deeplearning4j UI server is starting.");
|
||||||
|
Promise<String> promise = Promise.promise();
|
||||||
|
promise.future().compose(
|
||||||
|
success -> Future.future(prom -> startCallback.complete(success)),
|
||||||
|
failure -> Future.future(prom -> startCallback.fail(new RuntimeException(failure)))
|
||||||
|
);
|
||||||
|
|
||||||
|
Vertx vertx = Vertx.vertx();
|
||||||
|
vertx.deployVerticle(VertxUIServer.class.getName(), promise);
|
||||||
|
|
||||||
|
Thread currentThread = Thread.currentThread();
|
||||||
|
VertxUIServer.autoStopThread = new Thread(() -> {
|
||||||
|
try {
|
||||||
|
currentThread.join();
|
||||||
|
log.info("Deeplearning4j UI server is auto-stopping.");
|
||||||
|
if (VertxUIServer.instance != null && !VertxUIServer.instance.isStopped()) {
|
||||||
|
instance.stop();
|
||||||
|
}
|
||||||
|
} catch (InterruptedException e) {
|
||||||
|
log.error("Deeplearning4j UI server auto-stop thread was interrupted.", e);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
private List<UIModule> uiModules = new CopyOnWriteArrayList<>();
|
private List<UIModule> uiModules = new CopyOnWriteArrayList<>();
|
||||||
private RemoteReceiverModule remoteReceiverModule;
|
private RemoteReceiverModule remoteReceiverModule;
|
||||||
|
@ -132,10 +227,18 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
||||||
instance = this;
|
instance = this;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static void stopInstance(){
|
public static void stopInstance() throws Exception {
|
||||||
if(instance == null)
|
if(instance == null || instance.isStopped())
|
||||||
return;
|
return;
|
||||||
instance.stop();
|
instance.stop();
|
||||||
|
VertxUIServer.reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void reset() {
|
||||||
|
VertxUIServer.instance = null;
|
||||||
|
VertxUIServer.statsStorageProvider = null;
|
||||||
|
VertxUIServer.instancePort = null;
|
||||||
|
VertxUIServer.multiSession.set(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -145,12 +248,14 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
||||||
public void autoAttachStatsStorageBySessionId(Function<String, StatsStorage> statsStorageProvider) {
|
public void autoAttachStatsStorageBySessionId(Function<String, StatsStorage> statsStorageProvider) {
|
||||||
if (statsStorageProvider != null) {
|
if (statsStorageProvider != null) {
|
||||||
this.statsStorageLoader = new StatsStorageLoader(statsStorageProvider);
|
this.statsStorageLoader = new StatsStorageLoader(statsStorageProvider);
|
||||||
this.trainModule.setSessionLoader(this.statsStorageLoader);
|
if (trainModule != null) {
|
||||||
|
this.trainModule.setSessionLoader(this.statsStorageLoader);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void start() throws Exception {
|
public void start(Promise<Void> startCallback) throws Exception {
|
||||||
//Create REST endpoints
|
//Create REST endpoints
|
||||||
File uploadDir = new File(System.getProperty("java.io.tmpdir"), "DL4JUI_" + System.currentTimeMillis());
|
File uploadDir = new File(System.getProperty("java.io.tmpdir"), "DL4JUI_" + System.currentTimeMillis());
|
||||||
uploadDir.mkdirs();
|
uploadDir.mkdirs();
|
||||||
|
@ -192,6 +297,9 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (VertxUIServer.statsStorageProvider != null) {
|
||||||
|
autoAttachStatsStorageBySessionId(VertxUIServer.statsStorageProvider);
|
||||||
|
}
|
||||||
|
|
||||||
uiModules.add(new DefaultModule(isMultiSession())); //For: navigation page "/"
|
uiModules.add(new DefaultModule(isMultiSession())); //For: navigation page "/"
|
||||||
trainModule = new TrainModule(isMultiSession(), statsStorageLoader, this::getAddress);
|
trainModule = new TrainModule(isMultiSession(), statsStorageLoader, this::getAddress);
|
||||||
|
@ -246,17 +354,23 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
server = vertx.createHttpServer()
|
|
||||||
.requestHandler(r)
|
|
||||||
.listen(port);
|
|
||||||
|
|
||||||
uiEventRoutingThread = new Thread(new StatsEventRouterRunnable());
|
uiEventRoutingThread = new Thread(new StatsEventRouterRunnable());
|
||||||
uiEventRoutingThread.setDaemon(true);
|
uiEventRoutingThread.setDaemon(true);
|
||||||
uiEventRoutingThread.start();
|
uiEventRoutingThread.start();
|
||||||
|
|
||||||
String address = UIServer.getInstance().getAddress();
|
server = vertx.createHttpServer()
|
||||||
log.info("Deeplearning4j UI server started at: {}", address);
|
.requestHandler(r)
|
||||||
|
.listen(port, result -> {
|
||||||
|
if (result.succeeded()) {
|
||||||
|
String address = UIServer.getInstance().getAddress();
|
||||||
|
log.info("Deeplearning4j UI server started at: {}", address);
|
||||||
|
startCallback.complete();
|
||||||
|
} else {
|
||||||
|
startCallback.fail(new RuntimeException("Deeplearning4j UI server failed to listen on port "
|
||||||
|
+ server.actualPort(), result.cause()));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
VertxUIServer.autoStopThread.start();
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<String> extractArgsFromRoute(String path, RoutingContext rc) {
|
private List<String> extractArgsFromRoute(String path, RoutingContext rc) {
|
||||||
|
@ -302,11 +416,33 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void stop() {
|
public void stop() throws InterruptedException {
|
||||||
server.close();
|
CountDownLatch l = new CountDownLatch(1);
|
||||||
shutdown.set(true);
|
Promise<Void> promise = Promise.promise();
|
||||||
|
promise.future().compose(
|
||||||
|
successEvent -> Future.future(prom -> l.countDown()),
|
||||||
|
failureEvent -> Future.future(prom -> l.countDown())
|
||||||
|
);
|
||||||
|
stopAsync(promise);
|
||||||
|
// synchronous function should wait until the server is stopped
|
||||||
|
l.await();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void stopAsync(Promise<Void> stopCallback) {
|
||||||
|
/**
|
||||||
|
* Stop Vertx instance and release any resources held by it.
|
||||||
|
* Pass promise to {@link #stop(Promise)}.
|
||||||
|
*/
|
||||||
|
vertx.close(ar -> stopCallback.handle(ar));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void stop(Promise<Void> stopCallback) {
|
||||||
|
shutdown.set(true);
|
||||||
|
stopCallback.complete();
|
||||||
|
log.info("Deeplearning4j UI server stopped.");
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean isStopped() {
|
public boolean isStopped() {
|
||||||
|
@ -353,8 +489,7 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
||||||
if (!statsStorageInstances.contains(statsStorage))
|
if (!statsStorageInstances.contains(statsStorage))
|
||||||
return; //No op
|
return; //No op
|
||||||
boolean found = false;
|
boolean found = false;
|
||||||
for (Iterator<Pair<StatsStorage, StatsStorageListener>> iterator = listeners.iterator(); iterator.hasNext(); ) {
|
for (Pair<StatsStorage, StatsStorageListener> p : listeners) {
|
||||||
Pair<StatsStorage, StatsStorageListener> p = iterator.next();
|
|
||||||
if (p.getFirst() == statsStorage) { //Same object, not equality
|
if (p.getFirst() == statsStorage) { //Same object, not equality
|
||||||
statsStorage.deregisterStatsStorageListener(p.getSecond());
|
statsStorage.deregisterStatsStorageListener(p.getSecond());
|
||||||
listeners.remove(p);
|
listeners.remove(p);
|
||||||
|
|
|
@ -17,8 +17,10 @@
|
||||||
|
|
||||||
package org.deeplearning4j.ui.api;
|
package org.deeplearning4j.ui.api;
|
||||||
|
|
||||||
|
import io.vertx.core.Promise;
|
||||||
import org.deeplearning4j.api.storage.StatsStorage;
|
import org.deeplearning4j.api.storage.StatsStorage;
|
||||||
import org.deeplearning4j.api.storage.StatsStorageRouter;
|
import org.deeplearning4j.api.storage.StatsStorageRouter;
|
||||||
|
import org.deeplearning4j.exception.DL4JException;
|
||||||
import org.deeplearning4j.ui.VertxUIServer;
|
import org.deeplearning4j.ui.VertxUIServer;
|
||||||
import org.nd4j.linalg.function.Function;
|
import org.nd4j.linalg.function.Function;
|
||||||
|
|
||||||
|
@ -32,18 +34,24 @@ import java.util.List;
|
||||||
public interface UIServer {
|
public interface UIServer {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get (and, initialize if necessary) the UI server.
|
* Get (and, initialize if necessary) the UI server. This synchronous function will wait until the server started.
|
||||||
* Singleton pattern - all calls to getInstance() will return the same UI instance.
|
* Singleton pattern - all calls to getInstance() will return the same UI instance.
|
||||||
*
|
*
|
||||||
* @return UI instance for this JVM
|
* @return UI instance for this JVM
|
||||||
* @throws RuntimeException if the instance has already started in a different mode (multi/single-session)
|
* @throws DL4JException if UI server failed to start;
|
||||||
|
* if the instance has already started in a different mode (multi/single-session);
|
||||||
|
* if interrupted while waiting for completion
|
||||||
*/
|
*/
|
||||||
static UIServer getInstance() throws RuntimeException {
|
static UIServer getInstance() throws DL4JException {
|
||||||
return getInstance(false, null);
|
if (VertxUIServer.getInstance() != null && !VertxUIServer.getInstance().isStopped()) {
|
||||||
|
return VertxUIServer.getInstance();
|
||||||
|
} else {
|
||||||
|
return getInstance(false, null);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get (and, initialize if necessary) the UI server.
|
* Get (and, initialize if necessary) the UI server. This synchronous function will wait until the server started.
|
||||||
* Singleton pattern - all calls to getInstance() will return the same UI instance.
|
* Singleton pattern - all calls to getInstance() will return the same UI instance.
|
||||||
*
|
*
|
||||||
* @param multiSession in multi-session mode, multiple training sessions can be visualized in separate browser tabs.
|
* @param multiSession in multi-session mode, multiple training sessions can be visualized in separate browser tabs.
|
||||||
|
@ -52,16 +60,19 @@ public interface UIServer {
|
||||||
* <br/>Use this to auto-attach StatsStorage if an unknown session ID is passed
|
* <br/>Use this to auto-attach StatsStorage if an unknown session ID is passed
|
||||||
* as URL path parameter in multi-session mode, or leave it {@code null}.
|
* as URL path parameter in multi-session mode, or leave it {@code null}.
|
||||||
* @return UI instance for this JVM
|
* @return UI instance for this JVM
|
||||||
* @throws RuntimeException if the instance has already started in a different mode (multi/single-session)
|
* @throws DL4JException if UI server failed to start;
|
||||||
|
* if the instance has already started in a different mode (multi/single-session);
|
||||||
|
* if interrupted while waiting for completion
|
||||||
*/
|
*/
|
||||||
static UIServer getInstance(boolean multiSession, Function<String, StatsStorage> statsStorageProvider) throws RuntimeException {
|
static UIServer getInstance(boolean multiSession, Function<String, StatsStorage> statsStorageProvider)
|
||||||
|
throws DL4JException {
|
||||||
return VertxUIServer.getInstance(null, multiSession, statsStorageProvider);
|
return VertxUIServer.getInstance(null, multiSession, statsStorageProvider);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Stop UIServer instance, if already running
|
* Stop UIServer instance, if already running
|
||||||
*/
|
*/
|
||||||
static void stopInstance() {
|
static void stopInstance() throws Exception {
|
||||||
VertxUIServer.stopInstance();
|
VertxUIServer.stopInstance();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -144,8 +155,15 @@ public interface UIServer {
|
||||||
boolean isRemoteListenerEnabled();
|
boolean isRemoteListenerEnabled();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Stop/shut down the UI server.
|
* Stop/shut down the UI server. This synchronous function should wait until the server is stopped.
|
||||||
*/
|
*/
|
||||||
void stop();
|
void stop() throws Exception;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Stop/shut down the UI server.
|
||||||
|
* This asynchronous function should immediately return, and notify the callback {@link Promise} on completion:
|
||||||
|
* either call {@link Promise#complete} or {@link io.vertx.core.Promise#fail}.
|
||||||
|
* @param stopCallback callback {@link Promise} to notify on completion
|
||||||
|
*/
|
||||||
|
void stopAsync(Promise<Void> stopCallback);
|
||||||
}
|
}
|
||||||
|
|
|
@ -69,15 +69,19 @@ public class DefaultI18N implements I18N {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get instance for session (used in multi-session mode)
|
* Get instance for session
|
||||||
* @param sessionId session
|
* @param sessionId session ID for multi-session mode, leave it {@code null} for global instance
|
||||||
* @return instance for session
|
* @return instance for session, or global instance
|
||||||
*/
|
*/
|
||||||
public static synchronized I18N getInstance(String sessionId) {
|
public static synchronized I18N getInstance(String sessionId) {
|
||||||
if (!sessionInstances.containsKey(sessionId)) {
|
if (sessionId == null) {
|
||||||
sessionInstances.put(sessionId, new DefaultI18N());
|
return getInstance();
|
||||||
|
} else {
|
||||||
|
if (!sessionInstances.containsKey(sessionId)) {
|
||||||
|
sessionInstances.put(sessionId, new DefaultI18N());
|
||||||
|
}
|
||||||
|
return sessionInstances.get(sessionId);
|
||||||
}
|
}
|
||||||
return sessionInstances.get(sessionId);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -137,7 +137,7 @@ public class TrainModule implements UIModule {
|
||||||
maxChartPoints = DEFAULT_MAX_CHART_POINTS;
|
maxChartPoints = DEFAULT_MAX_CHART_POINTS;
|
||||||
}
|
}
|
||||||
|
|
||||||
configuration = new Configuration(new Version(2, 3, 29));
|
configuration = new Configuration(new Version(2, 3, 23));
|
||||||
configuration.setDefaultEncoding("UTF-8");
|
configuration.setDefaultEncoding("UTF-8");
|
||||||
configuration.setLocale(Locale.US);
|
configuration.setLocale(Locale.US);
|
||||||
configuration.setTemplateExceptionHandler(TemplateExceptionHandler.RETHROW_HANDLER);
|
configuration.setTemplateExceptionHandler(TemplateExceptionHandler.RETHROW_HANDLER);
|
||||||
|
@ -199,6 +199,7 @@ public class TrainModule implements UIModule {
|
||||||
}
|
}
|
||||||
}));
|
}));
|
||||||
r.add(new Route("/train/:sessionId/info", HttpMethod.GET, (path, rc) -> this.sessionInfoForSession(path.get(0), rc)));
|
r.add(new Route("/train/:sessionId/info", HttpMethod.GET, (path, rc) -> this.sessionInfoForSession(path.get(0), rc)));
|
||||||
|
r.add(new Route("/train/:sessionId/system/data", HttpMethod.GET, (path, rc) -> this.getSystemDataForSession(path.get(0), rc)));
|
||||||
} else {
|
} else {
|
||||||
r.add(new Route("/train", HttpMethod.GET, (path, rc) -> rc.reroute("/train/overview")));
|
r.add(new Route("/train", HttpMethod.GET, (path, rc) -> rc.reroute("/train/overview")));
|
||||||
r.add(new Route("/train/sessions/current", HttpMethod.GET, (path, rc) -> rc.response().end(currentSessionID == null ? "" : currentSessionID)));
|
r.add(new Route("/train/sessions/current", HttpMethod.GET, (path, rc) -> rc.response().end(currentSessionID == null ? "" : currentSessionID)));
|
||||||
|
@ -226,7 +227,9 @@ public class TrainModule implements UIModule {
|
||||||
* @param rc Routing context
|
* @param rc Routing context
|
||||||
*/
|
*/
|
||||||
private void renderFtl(String file, RoutingContext rc) {
|
private void renderFtl(String file, RoutingContext rc) {
|
||||||
Map<String, String> input = DefaultI18N.getInstance().getMessages(DefaultI18N.getInstance().getDefaultLanguage());
|
String sessionId = rc.request().getParam("sessionID");
|
||||||
|
String langCode = DefaultI18N.getInstance(sessionId).getDefaultLanguage();
|
||||||
|
Map<String, String> input = DefaultI18N.getInstance().getMessages(langCode);
|
||||||
String html;
|
String html;
|
||||||
try {
|
try {
|
||||||
String content = FileUtils.readFileToString(Resources.asFile("templates/" + file), StandardCharsets.UTF_8);
|
String content = FileUtils.readFileToString(Resources.asFile("templates/" + file), StandardCharsets.UTF_8);
|
||||||
|
|
|
@ -17,10 +17,15 @@
|
||||||
|
|
||||||
package org.deeplearning4j.ui;
|
package org.deeplearning4j.ui;
|
||||||
|
|
||||||
|
import io.vertx.core.Future;
|
||||||
|
import io.vertx.core.Promise;
|
||||||
|
import io.vertx.core.Vertx;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.io.IOUtils;
|
import org.apache.commons.io.IOUtils;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.api.storage.StatsStorage;
|
import org.deeplearning4j.api.storage.StatsStorage;
|
||||||
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
|
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
|
||||||
|
import org.deeplearning4j.exception.DL4JException;
|
||||||
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
||||||
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||||
|
@ -50,63 +55,31 @@ import java.net.URL;
|
||||||
import java.nio.charset.StandardCharsets;
|
import java.nio.charset.StandardCharsets;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.concurrent.CountDownLatch;
|
||||||
|
import java.util.concurrent.atomic.AtomicReference;
|
||||||
|
|
||||||
import static org.junit.Assert.*;
|
import static org.junit.Assert.*;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Created by Alex on 08/10/2016.
|
* Created by Alex on 08/10/2016.
|
||||||
*/
|
*/
|
||||||
|
@Slf4j
|
||||||
@Ignore
|
@Ignore
|
||||||
public class TestVertxUI extends BaseDL4JTest {
|
public class TestVertxUI extends BaseDL4JTest {
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() throws Exception {
|
public void setUp() throws Exception {
|
||||||
UIServer.stopInstance();
|
UIServer.stopInstance();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@Ignore
|
|
||||||
public void testUI() throws Exception {
|
public void testUI() throws Exception {
|
||||||
|
|
||||||
StatsStorage ss = new InMemoryStatsStorage();
|
|
||||||
|
|
||||||
VertxUIServer uiServer = (VertxUIServer) UIServer.getInstance();
|
VertxUIServer uiServer = (VertxUIServer) UIServer.getInstance();
|
||||||
assertEquals(9000, uiServer.getPort());
|
assertEquals(9000, uiServer.getPort());
|
||||||
uiServer.stop();
|
uiServer.stop();
|
||||||
VertxUIServer vertxUIServer = new VertxUIServer();
|
|
||||||
// vertxUIServer.runMain(new String[] {"--uiPort", "9100", "-r", "true"});
|
|
||||||
//
|
|
||||||
// assertEquals(9100, vertxUIServer.getPort());
|
|
||||||
// vertxUIServer.stop();
|
|
||||||
|
|
||||||
|
|
||||||
// uiServer.attach(ss);
|
|
||||||
//
|
|
||||||
// MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
|
||||||
// .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
|
||||||
// .list()
|
|
||||||
// .layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(4).build())
|
|
||||||
// .layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(4).nOut(3).build())
|
|
||||||
// .build();
|
|
||||||
//
|
|
||||||
// MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
|
||||||
// net.init();
|
|
||||||
// net.setListeners(new StatsListener(ss, 3), new ScoreIterationListener(1));
|
|
||||||
//
|
|
||||||
// DataSetIterator iter = new IrisDataSetIterator(150, 150);
|
|
||||||
//
|
|
||||||
// for (int i = 0; i < 500; i++) {
|
|
||||||
// net.fit(iter);
|
|
||||||
//// Thread.sleep(100);
|
|
||||||
// Thread.sleep(100);
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
//// uiServer.stop();
|
|
||||||
|
|
||||||
Thread.sleep(100000);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@Ignore
|
|
||||||
public void testUI_VAE() throws Exception {
|
public void testUI_VAE() throws Exception {
|
||||||
//Variational autoencoder - for unsupervised layerwise pretraining
|
//Variational autoencoder - for unsupervised layerwise pretraining
|
||||||
|
|
||||||
|
@ -144,13 +117,9 @@ public class TestVertxUI extends BaseDL4JTest {
|
||||||
Thread.sleep(100);
|
Thread.sleep(100);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
Thread.sleep(100000);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@Ignore
|
|
||||||
public void testUIMultipleSessions() throws Exception {
|
public void testUIMultipleSessions() throws Exception {
|
||||||
|
|
||||||
for (int session = 0; session < 3; session++) {
|
for (int session = 0; session < 3; session++) {
|
||||||
|
@ -178,60 +147,10 @@ public class TestVertxUI extends BaseDL4JTest {
|
||||||
Thread.sleep(100);
|
Thread.sleep(100);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
Thread.sleep(1000000);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
@Ignore
|
|
||||||
public void testUISequentialSessions() throws Exception {
|
|
||||||
UIServer uiServer = UIServer.getInstance();
|
|
||||||
StatsStorage ss = null;
|
|
||||||
for (int session = 0; session < 3; session++) {
|
|
||||||
|
|
||||||
if (ss != null) {
|
|
||||||
uiServer.detach(ss);
|
|
||||||
}
|
|
||||||
ss = new InMemoryStatsStorage();
|
|
||||||
uiServer.attach(ss);
|
|
||||||
|
|
||||||
int numInputs = 4;
|
|
||||||
int outputNum = 3;
|
|
||||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
|
||||||
.activation(Activation.TANH)
|
|
||||||
.weightInit(WeightInit.XAVIER)
|
|
||||||
.updater(new Sgd(0.03))
|
|
||||||
.l2(1e-4)
|
|
||||||
.list()
|
|
||||||
.layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(3)
|
|
||||||
.build())
|
|
||||||
.layer(1, new DenseLayer.Builder().nIn(3).nOut(3)
|
|
||||||
.build())
|
|
||||||
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
|
|
||||||
.activation(Activation.SOFTMAX)
|
|
||||||
.nIn(3).nOut(outputNum).build())
|
|
||||||
.build();
|
|
||||||
|
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
|
||||||
net.init();
|
|
||||||
net.setListeners(new StatsListener(ss), new ScoreIterationListener(1));
|
|
||||||
|
|
||||||
DataSetIterator iter = new IrisDataSetIterator(150, 150);
|
|
||||||
|
|
||||||
for (int i = 0; i < 1000; i++) {
|
|
||||||
net.fit(iter);
|
|
||||||
}
|
|
||||||
Thread.sleep(5000);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
Thread.sleep(1000000);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@Ignore
|
public void testUICompGraph() {
|
||||||
public void testUICompGraph() throws Exception {
|
|
||||||
|
|
||||||
StatsStorage ss = new InMemoryStatsStorage();
|
StatsStorage ss = new InMemoryStatsStorage();
|
||||||
|
|
||||||
|
@ -254,10 +173,7 @@ public class TestVertxUI extends BaseDL4JTest {
|
||||||
|
|
||||||
for (int i = 0; i < 100; i++) {
|
for (int i = 0; i < 100; i++) {
|
||||||
net.fit(iter);
|
net.fit(iter);
|
||||||
Thread.sleep(100);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Thread.sleep(1000000);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -304,11 +220,11 @@ public class TestVertxUI extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
String json1 = IOUtils.toString(new URL("http://localhost:9000/train/ss1/overview/data"), StandardCharsets.UTF_8);
|
String json1 = IOUtils.toString(new URL("http://localhost:9000/train/ss1/overview/data"),
|
||||||
// System.out.println(json1);
|
StandardCharsets.UTF_8);
|
||||||
|
|
||||||
String json2 = IOUtils.toString(new URL("http://localhost:9000/train/ss2/overview/data"), StandardCharsets.UTF_8);
|
String json2 = IOUtils.toString(new URL("http://localhost:9000/train/ss2/overview/data"),
|
||||||
// System.out.println(json2);
|
StandardCharsets.UTF_8);
|
||||||
|
|
||||||
assertNotEquals(json1, json2);
|
assertNotEquals(json1, json2);
|
||||||
|
|
||||||
|
@ -336,11 +252,106 @@ public class TestVertxUI extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testUIServerStop() {
|
public void testUIServerStop() throws Exception {
|
||||||
UIServer uiServer = UIServer.getInstance(true, null);
|
UIServer uiServer = UIServer.getInstance(true, null);
|
||||||
assertTrue(uiServer.isMultiSession());
|
assertTrue(uiServer.isMultiSession());
|
||||||
|
assertFalse(uiServer.isStopped());
|
||||||
|
|
||||||
|
long sleepMilliseconds = 1_000;
|
||||||
|
log.info("Waiting {} ms before stopping.", sleepMilliseconds);
|
||||||
|
Thread.sleep(sleepMilliseconds);
|
||||||
uiServer.stop();
|
uiServer.stop();
|
||||||
|
assertTrue(uiServer.isStopped());
|
||||||
|
|
||||||
|
log.info("UI server is stopped. Waiting {} ms before starting new UI server.", sleepMilliseconds);
|
||||||
|
Thread.sleep(sleepMilliseconds);
|
||||||
uiServer = UIServer.getInstance(false, null);
|
uiServer = UIServer.getInstance(false, null);
|
||||||
assertFalse(uiServer.isMultiSession());
|
assertFalse(uiServer.isMultiSession());
|
||||||
|
assertFalse(uiServer.isStopped());
|
||||||
|
|
||||||
|
log.info("Waiting {} ms before stopping.", sleepMilliseconds);
|
||||||
|
Thread.sleep(sleepMilliseconds);
|
||||||
|
uiServer.stop();
|
||||||
|
assertTrue(uiServer.isStopped());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testUIServerStopAsync() throws Exception {
|
||||||
|
UIServer uiServer = UIServer.getInstance(true, null);
|
||||||
|
assertTrue(uiServer.isMultiSession());
|
||||||
|
assertFalse(uiServer.isStopped());
|
||||||
|
|
||||||
|
long sleepMilliseconds = 1_000;
|
||||||
|
log.info("Waiting {} ms before stopping.", sleepMilliseconds);
|
||||||
|
Thread.sleep(sleepMilliseconds);
|
||||||
|
|
||||||
|
CountDownLatch latch = new CountDownLatch(1);
|
||||||
|
Promise<Void> promise = Promise.promise();
|
||||||
|
promise.future().compose(
|
||||||
|
success -> Future.future(prom -> latch.countDown()),
|
||||||
|
failure -> Future.future(prom -> latch.countDown())
|
||||||
|
);
|
||||||
|
|
||||||
|
uiServer.stopAsync(promise);
|
||||||
|
latch.await();
|
||||||
|
assertTrue(uiServer.isStopped());
|
||||||
|
|
||||||
|
log.info("UI server is stopped. Waiting {} ms before starting new UI server.", sleepMilliseconds);
|
||||||
|
Thread.sleep(sleepMilliseconds);
|
||||||
|
uiServer = UIServer.getInstance(false, null);
|
||||||
|
assertFalse(uiServer.isMultiSession());
|
||||||
|
|
||||||
|
log.info("Waiting {} ms before stopping.", sleepMilliseconds);
|
||||||
|
Thread.sleep(sleepMilliseconds);
|
||||||
|
uiServer.stop();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test (expected = DL4JException.class)
|
||||||
|
public void testUIStartPortAlreadyBound() throws InterruptedException {
|
||||||
|
CountDownLatch latch = new CountDownLatch(1);
|
||||||
|
//Create HttpServer that binds the same port
|
||||||
|
int port = VertxUIServer.DEFAULT_UI_PORT;
|
||||||
|
Vertx vertx = Vertx.vertx();
|
||||||
|
vertx.createHttpServer()
|
||||||
|
.requestHandler(event -> {})
|
||||||
|
.listen(port, result -> latch.countDown());
|
||||||
|
latch.await();
|
||||||
|
|
||||||
|
try {
|
||||||
|
//DL4JException signals that the port cannot be bound, UI server cannot start
|
||||||
|
UIServer.getInstance();
|
||||||
|
} finally {
|
||||||
|
vertx.close();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testUIStartAsync() throws InterruptedException {
|
||||||
|
CountDownLatch latch = new CountDownLatch(1);
|
||||||
|
Promise<String> promise = Promise.promise();
|
||||||
|
promise.future().compose(
|
||||||
|
success -> Future.future(prom -> latch.countDown()),
|
||||||
|
failure -> Future.future(prom -> latch.countDown())
|
||||||
|
);
|
||||||
|
int port = VertxUIServer.DEFAULT_UI_PORT;
|
||||||
|
VertxUIServer.getInstance(port, false, null, promise);
|
||||||
|
latch.await();
|
||||||
|
if (promise.future().succeeded()) {
|
||||||
|
String deploymentId = promise.future().result();
|
||||||
|
log.debug("UI server deployed, deployment ID = {}", deploymentId);
|
||||||
|
} else {
|
||||||
|
log.debug("UI server failed to deploy.", promise.future().cause());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testUIAutoStopOnThreadExit() throws InterruptedException {
|
||||||
|
AtomicReference<UIServer> uiServer = new AtomicReference<>();
|
||||||
|
Thread thread = new Thread(() -> uiServer.set(UIServer.getInstance()));
|
||||||
|
thread.start();
|
||||||
|
thread.join();
|
||||||
|
Thread.sleep(1_000);
|
||||||
|
assertTrue(uiServer.get().isStopped());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,272 @@
|
||||||
|
package org.deeplearning4j.ui;
|
||||||
|
|
||||||
|
import io.netty.handler.codec.http.HttpResponseStatus;
|
||||||
|
import io.vertx.core.Future;
|
||||||
|
import io.vertx.core.Promise;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
|
import org.deeplearning4j.api.storage.StatsStorage;
|
||||||
|
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
|
||||||
|
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
||||||
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
||||||
|
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
||||||
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||||
|
import org.deeplearning4j.nn.weights.WeightInit;
|
||||||
|
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
|
||||||
|
import org.deeplearning4j.ui.api.UIServer;
|
||||||
|
import org.deeplearning4j.ui.stats.StatsListener;
|
||||||
|
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
|
||||||
|
import org.junit.Ignore;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.activations.Activation;
|
||||||
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
|
import org.nd4j.linalg.function.Function;
|
||||||
|
import org.nd4j.linalg.learning.config.Sgd;
|
||||||
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
|
|
||||||
|
import java.io.UnsupportedEncodingException;
|
||||||
|
import java.net.HttpURLConnection;
|
||||||
|
import java.net.URL;
|
||||||
|
import java.net.URLEncoder;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.concurrent.CountDownLatch;
|
||||||
|
|
||||||
|
import static org.junit.Assert.*;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
@Ignore
|
||||||
|
public class TestVertxUIManual extends BaseDL4JTest {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long getTimeoutMilliseconds() {
|
||||||
|
return 3600_000L;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@Ignore
|
||||||
|
public void testUI() throws Exception {
|
||||||
|
VertxUIServer uiServer = (VertxUIServer) UIServer.getInstance();
|
||||||
|
assertEquals(9000, uiServer.getPort());
|
||||||
|
|
||||||
|
Thread.sleep(3000_000);
|
||||||
|
uiServer.stop();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@Ignore
|
||||||
|
public void testUISequentialSessions() throws Exception {
|
||||||
|
UIServer uiServer = UIServer.getInstance();
|
||||||
|
StatsStorage ss = null;
|
||||||
|
for (int session = 0; session < 3; session++) {
|
||||||
|
|
||||||
|
if (ss != null) {
|
||||||
|
uiServer.detach(ss);
|
||||||
|
}
|
||||||
|
ss = new InMemoryStatsStorage();
|
||||||
|
uiServer.attach(ss);
|
||||||
|
|
||||||
|
int numInputs = 4;
|
||||||
|
int outputNum = 3;
|
||||||
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||||
|
.activation(Activation.TANH)
|
||||||
|
.weightInit(WeightInit.XAVIER)
|
||||||
|
.updater(new Sgd(0.03))
|
||||||
|
.l2(1e-4)
|
||||||
|
.list()
|
||||||
|
.layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(3)
|
||||||
|
.build())
|
||||||
|
.layer(1, new DenseLayer.Builder().nIn(3).nOut(3)
|
||||||
|
.build())
|
||||||
|
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
|
||||||
|
.activation(Activation.SOFTMAX)
|
||||||
|
.nIn(3).nOut(outputNum).build())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
|
net.init();
|
||||||
|
net.setListeners(new StatsListener(ss), new ScoreIterationListener(1));
|
||||||
|
|
||||||
|
DataSetIterator iter = new IrisDataSetIterator(150, 150);
|
||||||
|
|
||||||
|
for (int i = 0; i < 100; i++) {
|
||||||
|
net.fit(iter);
|
||||||
|
}
|
||||||
|
Thread.sleep(5_000);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@Ignore
|
||||||
|
public void testUIServerStop() throws Exception {
|
||||||
|
UIServer uiServer = UIServer.getInstance(true, null);
|
||||||
|
assertTrue(uiServer.isMultiSession());
|
||||||
|
assertFalse(uiServer.isStopped());
|
||||||
|
|
||||||
|
long sleepMilliseconds = 30_000;
|
||||||
|
log.info("Waiting {} ms before stopping.", sleepMilliseconds);
|
||||||
|
Thread.sleep(sleepMilliseconds);
|
||||||
|
uiServer.stop();
|
||||||
|
assertTrue(uiServer.isStopped());
|
||||||
|
|
||||||
|
log.info("UI server is stopped. Waiting {} ms before starting new UI server.", sleepMilliseconds);
|
||||||
|
Thread.sleep(sleepMilliseconds);
|
||||||
|
uiServer = UIServer.getInstance(false, null);
|
||||||
|
assertFalse(uiServer.isMultiSession());
|
||||||
|
assertFalse(uiServer.isStopped());
|
||||||
|
|
||||||
|
log.info("Waiting {} ms before stopping.", sleepMilliseconds);
|
||||||
|
Thread.sleep(sleepMilliseconds);
|
||||||
|
uiServer.stop();
|
||||||
|
assertTrue(uiServer.isStopped());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@Ignore
|
||||||
|
public void testUIServerStopAsync() throws Exception {
|
||||||
|
UIServer uiServer = UIServer.getInstance(true, null);
|
||||||
|
assertTrue(uiServer.isMultiSession());
|
||||||
|
assertFalse(uiServer.isStopped());
|
||||||
|
|
||||||
|
long sleepMilliseconds = 30_000;
|
||||||
|
log.info("Waiting {} ms before stopping.", sleepMilliseconds);
|
||||||
|
Thread.sleep(sleepMilliseconds);
|
||||||
|
|
||||||
|
CountDownLatch latch = new CountDownLatch(1);
|
||||||
|
Promise<Void> promise = Promise.promise();
|
||||||
|
promise.future().compose(
|
||||||
|
success -> Future.future(prom -> latch.countDown()),
|
||||||
|
failure -> Future.future(prom -> latch.countDown())
|
||||||
|
);
|
||||||
|
|
||||||
|
uiServer.stopAsync(promise);
|
||||||
|
latch.await();
|
||||||
|
assertTrue(uiServer.isStopped());
|
||||||
|
|
||||||
|
log.info("UI server is stopped. Waiting {} ms before starting new UI server.", sleepMilliseconds);
|
||||||
|
Thread.sleep(sleepMilliseconds);
|
||||||
|
uiServer = UIServer.getInstance(false, null);
|
||||||
|
assertFalse(uiServer.isMultiSession());
|
||||||
|
|
||||||
|
log.info("Waiting {} ms before stopping.", sleepMilliseconds);
|
||||||
|
Thread.sleep(sleepMilliseconds);
|
||||||
|
uiServer.stop();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
@Ignore
|
||||||
|
public void testUIAutoAttachDetach() throws Exception {
|
||||||
|
long detachTimeoutMillis = 15_000;
|
||||||
|
AutoDetachingStatsStorageProvider statsProvider = new AutoDetachingStatsStorageProvider(detachTimeoutMillis);
|
||||||
|
UIServer uIServer = UIServer.getInstance(true, statsProvider);
|
||||||
|
statsProvider.setUIServer(uIServer);
|
||||||
|
InMemoryStatsStorage ss = null;
|
||||||
|
for (int session = 0; session < 3; session++) {
|
||||||
|
int layerSize = session + 4;
|
||||||
|
|
||||||
|
ss = new InMemoryStatsStorage();
|
||||||
|
String sessionId = Integer.toString(session);
|
||||||
|
statsProvider.put(sessionId, ss);
|
||||||
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||||
|
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list()
|
||||||
|
.layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(layerSize).build())
|
||||||
|
.layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
|
||||||
|
.activation(Activation.SOFTMAX).nIn(layerSize).nOut(3).build())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||||||
|
net.init();
|
||||||
|
|
||||||
|
StatsListener statsListener = new StatsListener(ss, 1);
|
||||||
|
statsListener.setSessionID(sessionId);
|
||||||
|
net.setListeners(statsListener, new ScoreIterationListener(1));
|
||||||
|
uIServer.attach(ss);
|
||||||
|
|
||||||
|
DataSetIterator iter = new IrisDataSetIterator(150, 150);
|
||||||
|
|
||||||
|
for (int i = 0; i < 20; i++) {
|
||||||
|
net.fit(iter);
|
||||||
|
}
|
||||||
|
|
||||||
|
assertTrue(uIServer.isAttached(ss));
|
||||||
|
uIServer.detach(ss);
|
||||||
|
assertFalse(uIServer.isAttached(ss));
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Visiting /train/:sessionId to auto-attach StatsStorage
|
||||||
|
*/
|
||||||
|
String sessionUrl = trainingSessionUrl(uIServer.getAddress(), sessionId);
|
||||||
|
HttpURLConnection conn = (HttpURLConnection) new URL(sessionUrl).openConnection();
|
||||||
|
conn.connect();
|
||||||
|
|
||||||
|
assertEquals(HttpResponseStatus.OK.code(), conn.getResponseCode());
|
||||||
|
assertTrue(uIServer.isAttached(ss));
|
||||||
|
}
|
||||||
|
|
||||||
|
Thread.sleep(detachTimeoutMillis + 60_000);
|
||||||
|
assertFalse(uIServer.isAttached(ss));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get URL-encoded URL for training session on given server address
|
||||||
|
* @param serverAddress server address
|
||||||
|
* @param sessionId session ID
|
||||||
|
* @return URL
|
||||||
|
* @throws UnsupportedEncodingException if the used encoding is not supported
|
||||||
|
*/
|
||||||
|
private static String trainingSessionUrl(String serverAddress, String sessionId)
|
||||||
|
throws UnsupportedEncodingException {
|
||||||
|
return String.format("%s/train/%s", serverAddress, URLEncoder.encode(sessionId, "UTF-8"));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* StatsStorage provider with automatic detaching of StatsStorage after a timeout
|
||||||
|
* @author Tamas Fenyvesi
|
||||||
|
*/
|
||||||
|
private static class AutoDetachingStatsStorageProvider implements Function<String, StatsStorage> {
|
||||||
|
HashMap<String, InMemoryStatsStorage> storageForSession = new HashMap<>();
|
||||||
|
UIServer uIServer;
|
||||||
|
long autoDetachTimeoutMillis;
|
||||||
|
|
||||||
|
public AutoDetachingStatsStorageProvider(long autoDetachTimeoutMillis) {
|
||||||
|
this.autoDetachTimeoutMillis = autoDetachTimeoutMillis;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void put(String sessionId, InMemoryStatsStorage statsStorage) {
|
||||||
|
storageForSession.put(sessionId, statsStorage);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setUIServer(UIServer uIServer) {
|
||||||
|
this.uIServer = uIServer;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public StatsStorage apply(String sessionId) {
|
||||||
|
StatsStorage statsStorage = storageForSession.get(sessionId);
|
||||||
|
|
||||||
|
if (statsStorage != null) {
|
||||||
|
new Thread(() -> {
|
||||||
|
try {
|
||||||
|
log.info("Waiting to detach StatsStorage (session ID: {})" +
|
||||||
|
" after {} ms ", sessionId, autoDetachTimeoutMillis);
|
||||||
|
Thread.sleep(autoDetachTimeoutMillis);
|
||||||
|
} catch (InterruptedException e) {
|
||||||
|
e.printStackTrace();
|
||||||
|
} finally {
|
||||||
|
log.info("Auto-detaching StatsStorage (session ID: {}) after {} ms.",
|
||||||
|
sessionId, autoDetachTimeoutMillis);
|
||||||
|
uIServer.detach(statsStorage);
|
||||||
|
log.info(" To re-attach StatsStorage of training session, visit {}}/train/{}",
|
||||||
|
uIServer.getAddress(), sessionId);
|
||||||
|
}
|
||||||
|
}).start();
|
||||||
|
}
|
||||||
|
|
||||||
|
return statsStorage;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -22,6 +22,7 @@ import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.api.storage.StatsStorage;
|
import org.deeplearning4j.api.storage.StatsStorage;
|
||||||
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
|
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
|
||||||
|
import org.deeplearning4j.exception.DL4JException;
|
||||||
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
||||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
||||||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||||
|
@ -33,7 +34,6 @@ import org.deeplearning4j.ui.api.UIServer;
|
||||||
import org.deeplearning4j.ui.stats.StatsListener;
|
import org.deeplearning4j.ui.stats.StatsListener;
|
||||||
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
|
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
|
||||||
import org.junit.Before;
|
import org.junit.Before;
|
||||||
import org.junit.Ignore;
|
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.linalg.activations.Activation;
|
import org.nd4j.linalg.activations.Activation;
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||||
|
@ -53,22 +53,22 @@ import static org.junit.Assert.*;
|
||||||
/**
|
/**
|
||||||
* @author Tamas Fenyvesi
|
* @author Tamas Fenyvesi
|
||||||
*/
|
*/
|
||||||
@Ignore
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class TestVertxUIMultiSession extends BaseDL4JTest {
|
public class TestVertxUIMultiSession extends BaseDL4JTest {
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() throws Exception {
|
public void setUp() throws Exception {
|
||||||
UIServer.stopInstance();
|
UIServer.stopInstance();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testUIMultiSession() throws Exception {
|
public void testUIMultiSessionParallelTraining() throws Exception {
|
||||||
|
|
||||||
UIServer uIServer = UIServer.getInstance(true, null);
|
UIServer uIServer = UIServer.getInstance(true, null);
|
||||||
HashMap<Thread, StatsStorage> statStorageForThread = new HashMap<>();
|
HashMap<Thread, StatsStorage> statStorageForThread = new HashMap<>();
|
||||||
HashMap<Thread, String> sessionIdForThread = new HashMap<>();
|
HashMap<Thread, String> sessionIdForThread = new HashMap<>();
|
||||||
|
|
||||||
for (int session = 0; session < 10; session++) {
|
int parallelTrainingCount = 10;
|
||||||
|
for (int session = 0; session < parallelTrainingCount; session++) {
|
||||||
|
|
||||||
StatsStorage ss = new InMemoryStatsStorage();
|
StatsStorage ss = new InMemoryStatsStorage();
|
||||||
|
|
||||||
|
@ -106,8 +106,6 @@ public class TestVertxUIMultiSession extends BaseDL4JTest {
|
||||||
sessionIdForThread.put(training, sessionId);
|
sessionIdForThread.put(training, sessionId);
|
||||||
}
|
}
|
||||||
|
|
||||||
Thread.sleep(10000000);
|
|
||||||
|
|
||||||
for (Thread thread: statStorageForThread.keySet()) {
|
for (Thread thread: statStorageForThread.keySet()) {
|
||||||
StatsStorage ss = statStorageForThread.get(thread);
|
StatsStorage ss = statStorageForThread.get(thread);
|
||||||
String sessionId = sessionIdForThread.get(thread);
|
String sessionId = sessionIdForThread.get(thread);
|
||||||
|
@ -122,7 +120,7 @@ public class TestVertxUIMultiSession extends BaseDL4JTest {
|
||||||
|
|
||||||
assertEquals(HttpResponseStatus.OK.code(), conn.getResponseCode());
|
assertEquals(HttpResponseStatus.OK.code(), conn.getResponseCode());
|
||||||
assertTrue(uIServer.isAttached(ss));
|
assertTrue(uIServer.isAttached(ss));
|
||||||
} catch (InterruptedException | IOException e) {
|
} catch (IOException e) {
|
||||||
log.error("",e);
|
log.error("",e);
|
||||||
fail(e.getMessage());
|
fail(e.getMessage());
|
||||||
} finally {
|
} finally {
|
||||||
|
@ -130,7 +128,6 @@ public class TestVertxUIMultiSession extends BaseDL4JTest {
|
||||||
assertFalse(uIServer.isAttached(ss));
|
assertFalse(uIServer.isAttached(ss));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -183,61 +180,7 @@ public class TestVertxUIMultiSession extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test (expected = DL4JException.class)
|
||||||
@Ignore
|
|
||||||
public void testUIAutoAttachDetach() throws Exception {
|
|
||||||
|
|
||||||
long autoDetachTimeoutMillis = 30_000;
|
|
||||||
AutoDetachingStatsStorageProvider statsProvider = new AutoDetachingStatsStorageProvider(autoDetachTimeoutMillis);
|
|
||||||
UIServer uIServer = UIServer.getInstance(true, statsProvider);
|
|
||||||
statsProvider.setUIServer(uIServer);
|
|
||||||
|
|
||||||
for (int session = 0; session < 3; session++) {
|
|
||||||
int layerSize = session + 4;
|
|
||||||
|
|
||||||
InMemoryStatsStorage ss = new InMemoryStatsStorage();
|
|
||||||
String sessionId = Integer.toString(session);
|
|
||||||
statsProvider.put(sessionId, ss);
|
|
||||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
|
||||||
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list()
|
|
||||||
.layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(layerSize).build())
|
|
||||||
.layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
|
|
||||||
.activation(Activation.SOFTMAX).nIn(layerSize).nOut(3).build())
|
|
||||||
.build();
|
|
||||||
|
|
||||||
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
|
||||||
net.init();
|
|
||||||
|
|
||||||
StatsListener statsListener = new StatsListener(ss, 1);
|
|
||||||
statsListener.setSessionID(sessionId);
|
|
||||||
net.setListeners(statsListener, new ScoreIterationListener(1));
|
|
||||||
uIServer.attach(ss);
|
|
||||||
|
|
||||||
DataSetIterator iter = new IrisDataSetIterator(150, 150);
|
|
||||||
|
|
||||||
for (int i = 0; i < 20; i++) {
|
|
||||||
net.fit(iter);
|
|
||||||
}
|
|
||||||
|
|
||||||
assertTrue(uIServer.isAttached(ss));
|
|
||||||
uIServer.detach(ss);
|
|
||||||
assertFalse(uIServer.isAttached(ss));
|
|
||||||
|
|
||||||
/*
|
|
||||||
* Visiting /train/:sessionId to auto-attach StatsStorage
|
|
||||||
*/
|
|
||||||
String sessionUrl = trainingSessionUrl(uIServer.getAddress(), sessionId);
|
|
||||||
HttpURLConnection conn = (HttpURLConnection) new URL(sessionUrl).openConnection();
|
|
||||||
conn.connect();
|
|
||||||
|
|
||||||
assertEquals(HttpResponseStatus.OK.code(), conn.getResponseCode());
|
|
||||||
assertTrue(uIServer.isAttached(ss));
|
|
||||||
}
|
|
||||||
|
|
||||||
Thread.sleep(1_000_000);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test (expected = RuntimeException.class)
|
|
||||||
public void testUIServerGetInstanceMultipleCalls1() {
|
public void testUIServerGetInstanceMultipleCalls1() {
|
||||||
UIServer uiServer = UIServer.getInstance();
|
UIServer uiServer = UIServer.getInstance();
|
||||||
assertFalse(uiServer.isMultiSession());
|
assertFalse(uiServer.isMultiSession());
|
||||||
|
@ -245,7 +188,7 @@ public class TestVertxUIMultiSession extends BaseDL4JTest {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test (expected = RuntimeException.class)
|
@Test (expected = DL4JException.class)
|
||||||
public void testUIServerGetInstanceMultipleCalls2() {
|
public void testUIServerGetInstanceMultipleCalls2() {
|
||||||
UIServer uiServer = UIServer.getInstance(true, null);
|
UIServer uiServer = UIServer.getInstance(true, null);
|
||||||
assertTrue(uiServer.isMultiSession());
|
assertTrue(uiServer.isMultiSession());
|
||||||
|
@ -259,55 +202,8 @@ public class TestVertxUIMultiSession extends BaseDL4JTest {
|
||||||
* @return URL
|
* @return URL
|
||||||
* @throws UnsupportedEncodingException if the used encoding is not supported
|
* @throws UnsupportedEncodingException if the used encoding is not supported
|
||||||
*/
|
*/
|
||||||
private static String trainingSessionUrl(String serverAddress, String sessionId) throws UnsupportedEncodingException {
|
private static String trainingSessionUrl(String serverAddress, String sessionId)
|
||||||
|
throws UnsupportedEncodingException {
|
||||||
return String.format("%s/train/%s", serverAddress, URLEncoder.encode(sessionId, "UTF-8"));
|
return String.format("%s/train/%s", serverAddress, URLEncoder.encode(sessionId, "UTF-8"));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* StatsStorage provider with automatic detaching of StatsStorage after a timeout
|
|
||||||
* @author fenyvesit
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
private static class AutoDetachingStatsStorageProvider implements Function<String, StatsStorage> {
|
|
||||||
HashMap<String, InMemoryStatsStorage> storageForSession = new HashMap<>();
|
|
||||||
UIServer uIServer;
|
|
||||||
long autoDetachTimeoutMillis;
|
|
||||||
|
|
||||||
public AutoDetachingStatsStorageProvider(long autoDetachTimeoutMillis) {
|
|
||||||
this.autoDetachTimeoutMillis = autoDetachTimeoutMillis;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void put(String sessionId, InMemoryStatsStorage statsStorage) {
|
|
||||||
storageForSession.put(sessionId, statsStorage);
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setUIServer(UIServer uIServer) {
|
|
||||||
this.uIServer = uIServer;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public StatsStorage apply(String sessionId) {
|
|
||||||
StatsStorage statsStorage = storageForSession.get(sessionId);
|
|
||||||
|
|
||||||
if (statsStorage != null) {
|
|
||||||
new Thread(() -> {
|
|
||||||
try {
|
|
||||||
System.out.println("Waiting to detach StatsStorage (session ID: " + sessionId + ")" +
|
|
||||||
" after " + autoDetachTimeoutMillis + " ms ");
|
|
||||||
Thread.sleep(autoDetachTimeoutMillis);
|
|
||||||
} catch (InterruptedException e) {
|
|
||||||
log.error("",e);
|
|
||||||
} finally {
|
|
||||||
System.out.println("Auto-detaching StatsStorage (session ID: " + sessionId + ") after " +
|
|
||||||
autoDetachTimeoutMillis + " ms.");
|
|
||||||
uIServer.detach(statsStorage);
|
|
||||||
System.out.println(" To re-attach StatsStorage of training session, visit " +
|
|
||||||
uIServer.getAddress() + "/train/" + sessionId);
|
|
||||||
}
|
|
||||||
}).start();
|
|
||||||
}
|
|
||||||
|
|
||||||
return statsStorage;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue