From 4b50b920c7d75895649353182c154c2e19bf3489 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Mon, 25 Nov 2019 22:52:02 +1100 Subject: [PATCH] Re-add UI auto-attach functionality with test; also fixes (#80) Signed-off-by: AlexDBlack --- .../ui/stats/StatsListener.java | 14 +++- .../org/deeplearning4j/ui/VertxUIServer.java | 16 ++++- .../ui/module/train/TrainModule.java | 16 +++-- .../org/deeplearning4j/ui/TestVertxUI.java | 72 +++++++++++++++++++ 4 files changed, 112 insertions(+), 6 deletions(-) diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/main/java/org/deeplearning4j/ui/stats/StatsListener.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/main/java/org/deeplearning4j/ui/stats/StatsListener.java index a59271bdb..94bafa8c1 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/main/java/org/deeplearning4j/ui/stats/StatsListener.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/main/java/org/deeplearning4j/ui/stats/StatsListener.java @@ -58,8 +58,20 @@ public class StatsListener extends BaseStatsListener { * @param listenerFrequency Frequency with which to collect stats information */ public StatsListener(StatsStorageRouter router, int listenerFrequency) { + this(router, listenerFrequency, null); + } + + /** + * Create a StatsListener with network information collected every n >= 1 time steps + * + * @param router Where/how to store the calculated stats. For example, {@link org.deeplearning4j.ui.storage.InMemoryStatsStorage} or + * {@link org.deeplearning4j.ui.storage.FileStatsStorage} + * @param listenerFrequency Frequency with which to collect stats information + * @param sessionId The Session ID for storing the stats, optional (may be null) + */ + public StatsListener(StatsStorageRouter router, int listenerFrequency, String sessionId) { this(router, null, new DefaultStatsUpdateConfiguration.Builder().reportingFrequency(listenerFrequency).build(), - null, null); + sessionId, null); } public StatsListener(StatsStorageRouter router, StatsInitializationConfiguration initConfig, diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/VertxUIServer.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/VertxUIServer.java index 81d6ecb0b..2aec66a77 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/VertxUIServer.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/VertxUIServer.java @@ -73,6 +73,8 @@ public class VertxUIServer extends AbstractVerticle implements UIServer { private static Integer instancePort; + private TrainModule trainModule; + public static VertxUIServer getInstance() { return getInstance(null, multiSession.get(), null); } @@ -136,6 +138,17 @@ public class VertxUIServer extends AbstractVerticle implements UIServer { instance.stop(); } + /** + * Auto-attach StatsStorage if an unknown session ID is passed as URL path parameter in multi-session mode + * @param statsStorageProvider function that returns a StatsStorage containing the given session ID + */ + public void autoAttachStatsStorageBySessionId(Function statsStorageProvider) { + if (statsStorageProvider != null) { + this.statsStorageLoader = new StatsStorageLoader(statsStorageProvider); + this.trainModule.setSessionLoader(this.statsStorageLoader); + } + } + @Override public void start() throws Exception { //Create REST endpoints @@ -181,7 +194,8 @@ public class VertxUIServer extends AbstractVerticle implements UIServer { uiModules.add(new DefaultModule(isMultiSession())); //For: navigation page "/" - uiModules.add(new TrainModule(isMultiSession(), statsStorageLoader, this::getAddress)); + trainModule = new TrainModule(isMultiSession(), statsStorageLoader, this::getAddress); + uiModules.add(trainModule); uiModules.add(new ConvolutionalListenerModule()); uiModules.add(new TsneModule()); uiModules.add(new SameDiffModule()); diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java index 3c6b1f366..5648de738 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java @@ -26,6 +26,8 @@ import io.vertx.ext.web.RoutingContext; import it.unimi.dsi.fastutil.longs.LongArrayList; import lombok.AllArgsConstructor; import lombok.Data; +import lombok.Getter; +import lombok.Setter; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.FileUtils; import org.apache.commons.io.FilenameUtils; @@ -98,7 +100,8 @@ public class TrainModule implements UIModule { private Map> workerIdxToName = new ConcurrentHashMap<>(); //Key: session ID private Map lastUpdateForSession = new ConcurrentHashMap<>(); private final boolean multiSession; - private final Function sessionLoader; + @Getter @Setter + private Function sessionLoader; private final Configuration configuration; @@ -172,7 +175,13 @@ public class TrainModule implements UIModule { sessionNotFound(path.get(0), rc.request().path(), rc); } })); - r.add(new Route("/train/:sessionId/overview/data", HttpMethod.GET, (path, rc) -> getOverviewDataForSession(path.get(0), rc))); + r.add(new Route("/train/:sessionId/overview/data", HttpMethod.GET, (path, rc) -> { + if (knownSessionIDs.containsKey(path.get(0))) { + getOverviewDataForSession(path.get(0), rc); + } else { + sessionNotFound(path.get(0), rc.request().path(), rc); + } + })); r.add(new Route("/train/:sessionId/model", HttpMethod.GET, (path, rc) -> { if (knownSessionIDs.containsKey(path.get(0))) { renderFtl("TrainingModel.html.ftl", rc); @@ -275,11 +284,10 @@ public class TrainModule implements UIModule { private void sessionNotFound(String sessionId, String targetPath, RoutingContext rc) { if (sessionLoader != null && sessionLoader.apply(sessionId)) { if (targetPath != null) { - rc.reroute("./" + targetPath); + rc.reroute(targetPath); } else { rc.response().end(); } - } else { rc.response().setStatusCode(HttpResponseStatus.NOT_FOUND.code()) .end("Unknown session ID: " + sessionId); diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUI.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUI.java index a60147d0f..fc5c3c4ac 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUI.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUI.java @@ -17,6 +17,7 @@ package org.deeplearning4j.ui; +import org.apache.commons.io.IOUtils; import org.deeplearning4j.api.storage.StatsStorage; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.nn.api.OptimizationAlgorithm; @@ -27,6 +28,7 @@ import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.conf.layers.variational.GaussianReconstructionDistribution; import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder; +import org.deeplearning4j.nn.conf.serde.JsonMappers; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; @@ -39,9 +41,15 @@ import org.junit.Ignore; import org.junit.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; +import org.nd4j.linalg.function.Function; import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.lossfunctions.LossFunctions; +import java.net.URL; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; + import static org.junit.Assert.*; /** @@ -251,6 +259,70 @@ public class TestVertxUI { Thread.sleep(1000000); } + @Test + public void testAutoAttach() throws Exception { + + ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in") + .addLayer("L0", new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(4).build(), + "in") + .addLayer("L1", new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT) + .activation(Activation.SOFTMAX).nIn(4).nOut(3).build(), "L0") + .setOutputs("L1").build(); + + ComputationGraph net = new ComputationGraph(conf); + net.init(); + + StatsStorage ss1 = new InMemoryStatsStorage(); + + net.setListeners(new StatsListener(ss1, 1, "ss1")); + + DataSetIterator iter = new IrisDataSetIterator(150, 150); + + for (int i = 0; i < 5; i++) { + net.fit(iter); + } + + StatsStorage ss2 = new InMemoryStatsStorage(); + net.setListeners(new StatsListener(ss2, 1, "ss2")); + + for (int i = 0; i < 4; i++) { + net.fit(iter); + } + + UIServer ui = UIServer.getInstance(true, null); + try { + ((VertxUIServer) ui).autoAttachStatsStorageBySessionId(new Function() { + @Override + public StatsStorage apply(String s) { + if ("ss1".equals(s)) { + return ss1; + } else if ("ss2".equals(s)) { + return ss2; + } + return null; + } + }); + + String json1 = IOUtils.toString(new URL("http://localhost:9000/train/ss1/overview/data"), StandardCharsets.UTF_8); +// System.out.println(json1); + + String json2 = IOUtils.toString(new URL("http://localhost:9000/train/ss2/overview/data"), StandardCharsets.UTF_8); +// System.out.println(json2); + + assertNotEquals(json1, json2); + + Map m1 = JsonMappers.getMapper().readValue(json1, Map.class); + Map m2 = JsonMappers.getMapper().readValue(json2, Map.class); + + List s1 = (List) m1.get("scores"); + List s2 = (List) m2.get("scores"); + assertEquals(5, s1.size()); + assertEquals(4, s2.size()); + } finally { + ui.stop(); + } + } + @Test public void testUIAttachDetach() throws Exception { StatsStorage ss = new InMemoryStatsStorage();