Re-add UI auto-attach functionality with test; also fixes (#80)

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-11-25 22:52:02 +11:00 committed by GitHub
parent 0e3fcdc24d
commit 4b50b920c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 112 additions and 6 deletions

View File

@ -58,8 +58,20 @@ public class StatsListener extends BaseStatsListener {
* @param listenerFrequency Frequency with which to collect stats information
*/
public StatsListener(StatsStorageRouter router, int listenerFrequency) {
this(router, listenerFrequency, null);
}
/**
* Create a StatsListener with network information collected every n >= 1 time steps
*
* @param router Where/how to store the calculated stats. For example, {@link org.deeplearning4j.ui.storage.InMemoryStatsStorage} or
* {@link org.deeplearning4j.ui.storage.FileStatsStorage}
* @param listenerFrequency Frequency with which to collect stats information
* @param sessionId The Session ID for storing the stats, optional (may be null)
*/
public StatsListener(StatsStorageRouter router, int listenerFrequency, String sessionId) {
this(router, null, new DefaultStatsUpdateConfiguration.Builder().reportingFrequency(listenerFrequency).build(),
null, null);
sessionId, null);
}
public StatsListener(StatsStorageRouter router, StatsInitializationConfiguration initConfig,

View File

@ -73,6 +73,8 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
private static Integer instancePort;
private TrainModule trainModule;
public static VertxUIServer getInstance() {
return getInstance(null, multiSession.get(), null);
}
@ -136,6 +138,17 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
instance.stop();
}
/**
* Auto-attach StatsStorage if an unknown session ID is passed as URL path parameter in multi-session mode
* @param statsStorageProvider function that returns a StatsStorage containing the given session ID
*/
public void autoAttachStatsStorageBySessionId(Function<String, StatsStorage> statsStorageProvider) {
if (statsStorageProvider != null) {
this.statsStorageLoader = new StatsStorageLoader(statsStorageProvider);
this.trainModule.setSessionLoader(this.statsStorageLoader);
}
}
@Override
public void start() throws Exception {
//Create REST endpoints
@ -181,7 +194,8 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
uiModules.add(new DefaultModule(isMultiSession())); //For: navigation page "/"
uiModules.add(new TrainModule(isMultiSession(), statsStorageLoader, this::getAddress));
trainModule = new TrainModule(isMultiSession(), statsStorageLoader, this::getAddress);
uiModules.add(trainModule);
uiModules.add(new ConvolutionalListenerModule());
uiModules.add(new TsneModule());
uiModules.add(new SameDiffModule());

View File

@ -26,6 +26,8 @@ import io.vertx.ext.web.RoutingContext;
import it.unimi.dsi.fastutil.longs.LongArrayList;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
@ -98,7 +100,8 @@ public class TrainModule implements UIModule {
private Map<String, Map<Integer, String>> workerIdxToName = new ConcurrentHashMap<>(); //Key: session ID
private Map<String, Long> lastUpdateForSession = new ConcurrentHashMap<>();
private final boolean multiSession;
private final Function<String, Boolean> sessionLoader;
@Getter @Setter
private Function<String, Boolean> sessionLoader;
private final Configuration configuration;
@ -172,7 +175,13 @@ public class TrainModule implements UIModule {
sessionNotFound(path.get(0), rc.request().path(), rc);
}
}));
r.add(new Route("/train/:sessionId/overview/data", HttpMethod.GET, (path, rc) -> getOverviewDataForSession(path.get(0), rc)));
r.add(new Route("/train/:sessionId/overview/data", HttpMethod.GET, (path, rc) -> {
if (knownSessionIDs.containsKey(path.get(0))) {
getOverviewDataForSession(path.get(0), rc);
} else {
sessionNotFound(path.get(0), rc.request().path(), rc);
}
}));
r.add(new Route("/train/:sessionId/model", HttpMethod.GET, (path, rc) -> {
if (knownSessionIDs.containsKey(path.get(0))) {
renderFtl("TrainingModel.html.ftl", rc);
@ -275,11 +284,10 @@ public class TrainModule implements UIModule {
private void sessionNotFound(String sessionId, String targetPath, RoutingContext rc) {
if (sessionLoader != null && sessionLoader.apply(sessionId)) {
if (targetPath != null) {
rc.reroute("./" + targetPath);
rc.reroute(targetPath);
} else {
rc.response().end();
}
} else {
rc.response().setStatusCode(HttpResponseStatus.NOT_FOUND.code())
.end("Unknown session ID: " + sessionId);

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.ui;
import org.apache.commons.io.IOUtils;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
@ -27,6 +28,7 @@ import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.variational.GaussianReconstructionDistribution;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.conf.serde.JsonMappers;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
@ -39,9 +41,15 @@ import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.function.Function;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import static org.junit.Assert.*;
/**
@ -251,6 +259,70 @@ public class TestVertxUI {
Thread.sleep(1000000);
}
@Test
public void testAutoAttach() throws Exception {
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in")
.addLayer("L0", new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(4).build(),
"in")
.addLayer("L1", new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(4).nOut(3).build(), "L0")
.setOutputs("L1").build();
ComputationGraph net = new ComputationGraph(conf);
net.init();
StatsStorage ss1 = new InMemoryStatsStorage();
net.setListeners(new StatsListener(ss1, 1, "ss1"));
DataSetIterator iter = new IrisDataSetIterator(150, 150);
for (int i = 0; i < 5; i++) {
net.fit(iter);
}
StatsStorage ss2 = new InMemoryStatsStorage();
net.setListeners(new StatsListener(ss2, 1, "ss2"));
for (int i = 0; i < 4; i++) {
net.fit(iter);
}
UIServer ui = UIServer.getInstance(true, null);
try {
((VertxUIServer) ui).autoAttachStatsStorageBySessionId(new Function<String, StatsStorage>() {
@Override
public StatsStorage apply(String s) {
if ("ss1".equals(s)) {
return ss1;
} else if ("ss2".equals(s)) {
return ss2;
}
return null;
}
});
String json1 = IOUtils.toString(new URL("http://localhost:9000/train/ss1/overview/data"), StandardCharsets.UTF_8);
// System.out.println(json1);
String json2 = IOUtils.toString(new URL("http://localhost:9000/train/ss2/overview/data"), StandardCharsets.UTF_8);
// System.out.println(json2);
assertNotEquals(json1, json2);
Map<String, Object> m1 = JsonMappers.getMapper().readValue(json1, Map.class);
Map<String, Object> m2 = JsonMappers.getMapper().readValue(json2, Map.class);
List<Object> s1 = (List<Object>) m1.get("scores");
List<Object> s2 = (List<Object>) m2.get("scores");
assertEquals(5, s1.size());
assertEquals(4, s2.size());
} finally {
ui.stop();
}
}
@Test
public void testUIAttachDetach() throws Exception {
StatsStorage ss = new InMemoryStatsStorage();