Re-add UI auto-attach functionality with test; also fixes (#80)
Signed-off-by: AlexDBlack <blacka101@gmail.com>master
parent
0e3fcdc24d
commit
4b50b920c7
|
@ -58,8 +58,20 @@ public class StatsListener extends BaseStatsListener {
|
|||
* @param listenerFrequency Frequency with which to collect stats information
|
||||
*/
|
||||
public StatsListener(StatsStorageRouter router, int listenerFrequency) {
|
||||
this(router, listenerFrequency, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a StatsListener with network information collected every n >= 1 time steps
|
||||
*
|
||||
* @param router Where/how to store the calculated stats. For example, {@link org.deeplearning4j.ui.storage.InMemoryStatsStorage} or
|
||||
* {@link org.deeplearning4j.ui.storage.FileStatsStorage}
|
||||
* @param listenerFrequency Frequency with which to collect stats information
|
||||
* @param sessionId The Session ID for storing the stats, optional (may be null)
|
||||
*/
|
||||
public StatsListener(StatsStorageRouter router, int listenerFrequency, String sessionId) {
|
||||
this(router, null, new DefaultStatsUpdateConfiguration.Builder().reportingFrequency(listenerFrequency).build(),
|
||||
null, null);
|
||||
sessionId, null);
|
||||
}
|
||||
|
||||
public StatsListener(StatsStorageRouter router, StatsInitializationConfiguration initConfig,
|
||||
|
|
|
@ -73,6 +73,8 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
|||
|
||||
private static Integer instancePort;
|
||||
|
||||
private TrainModule trainModule;
|
||||
|
||||
public static VertxUIServer getInstance() {
|
||||
return getInstance(null, multiSession.get(), null);
|
||||
}
|
||||
|
@ -136,6 +138,17 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
|||
instance.stop();
|
||||
}
|
||||
|
||||
/**
|
||||
* Auto-attach StatsStorage if an unknown session ID is passed as URL path parameter in multi-session mode
|
||||
* @param statsStorageProvider function that returns a StatsStorage containing the given session ID
|
||||
*/
|
||||
public void autoAttachStatsStorageBySessionId(Function<String, StatsStorage> statsStorageProvider) {
|
||||
if (statsStorageProvider != null) {
|
||||
this.statsStorageLoader = new StatsStorageLoader(statsStorageProvider);
|
||||
this.trainModule.setSessionLoader(this.statsStorageLoader);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void start() throws Exception {
|
||||
//Create REST endpoints
|
||||
|
@ -181,7 +194,8 @@ public class VertxUIServer extends AbstractVerticle implements UIServer {
|
|||
|
||||
|
||||
uiModules.add(new DefaultModule(isMultiSession())); //For: navigation page "/"
|
||||
uiModules.add(new TrainModule(isMultiSession(), statsStorageLoader, this::getAddress));
|
||||
trainModule = new TrainModule(isMultiSession(), statsStorageLoader, this::getAddress);
|
||||
uiModules.add(trainModule);
|
||||
uiModules.add(new ConvolutionalListenerModule());
|
||||
uiModules.add(new TsneModule());
|
||||
uiModules.add(new SameDiffModule());
|
||||
|
|
|
@ -26,6 +26,8 @@ import io.vertx.ext.web.RoutingContext;
|
|||
import it.unimi.dsi.fastutil.longs.LongArrayList;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.apache.commons.io.FilenameUtils;
|
||||
|
@ -98,7 +100,8 @@ public class TrainModule implements UIModule {
|
|||
private Map<String, Map<Integer, String>> workerIdxToName = new ConcurrentHashMap<>(); //Key: session ID
|
||||
private Map<String, Long> lastUpdateForSession = new ConcurrentHashMap<>();
|
||||
private final boolean multiSession;
|
||||
private final Function<String, Boolean> sessionLoader;
|
||||
@Getter @Setter
|
||||
private Function<String, Boolean> sessionLoader;
|
||||
|
||||
|
||||
private final Configuration configuration;
|
||||
|
@ -172,7 +175,13 @@ public class TrainModule implements UIModule {
|
|||
sessionNotFound(path.get(0), rc.request().path(), rc);
|
||||
}
|
||||
}));
|
||||
r.add(new Route("/train/:sessionId/overview/data", HttpMethod.GET, (path, rc) -> getOverviewDataForSession(path.get(0), rc)));
|
||||
r.add(new Route("/train/:sessionId/overview/data", HttpMethod.GET, (path, rc) -> {
|
||||
if (knownSessionIDs.containsKey(path.get(0))) {
|
||||
getOverviewDataForSession(path.get(0), rc);
|
||||
} else {
|
||||
sessionNotFound(path.get(0), rc.request().path(), rc);
|
||||
}
|
||||
}));
|
||||
r.add(new Route("/train/:sessionId/model", HttpMethod.GET, (path, rc) -> {
|
||||
if (knownSessionIDs.containsKey(path.get(0))) {
|
||||
renderFtl("TrainingModel.html.ftl", rc);
|
||||
|
@ -275,11 +284,10 @@ public class TrainModule implements UIModule {
|
|||
private void sessionNotFound(String sessionId, String targetPath, RoutingContext rc) {
|
||||
if (sessionLoader != null && sessionLoader.apply(sessionId)) {
|
||||
if (targetPath != null) {
|
||||
rc.reroute("./" + targetPath);
|
||||
rc.reroute(targetPath);
|
||||
} else {
|
||||
rc.response().end();
|
||||
}
|
||||
|
||||
} else {
|
||||
rc.response().setStatusCode(HttpResponseStatus.NOT_FOUND.code())
|
||||
.end("Unknown session ID: " + sessionId);
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
package org.deeplearning4j.ui;
|
||||
|
||||
import org.apache.commons.io.IOUtils;
|
||||
import org.deeplearning4j.api.storage.StatsStorage;
|
||||
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
|
||||
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
||||
|
@ -27,6 +28,7 @@ import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
|||
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
||||
import org.deeplearning4j.nn.conf.layers.variational.GaussianReconstructionDistribution;
|
||||
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
|
||||
import org.deeplearning4j.nn.conf.serde.JsonMappers;
|
||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.deeplearning4j.nn.weights.WeightInit;
|
||||
|
@ -39,9 +41,15 @@ import org.junit.Ignore;
|
|||
import org.junit.Test;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||
import org.nd4j.linalg.function.Function;
|
||||
import org.nd4j.linalg.learning.config.Sgd;
|
||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||
|
||||
import java.net.URL;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
/**
|
||||
|
@ -251,6 +259,70 @@ public class TestVertxUI {
|
|||
Thread.sleep(1000000);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAutoAttach() throws Exception {
|
||||
|
||||
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in")
|
||||
.addLayer("L0", new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(4).build(),
|
||||
"in")
|
||||
.addLayer("L1", new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
|
||||
.activation(Activation.SOFTMAX).nIn(4).nOut(3).build(), "L0")
|
||||
.setOutputs("L1").build();
|
||||
|
||||
ComputationGraph net = new ComputationGraph(conf);
|
||||
net.init();
|
||||
|
||||
StatsStorage ss1 = new InMemoryStatsStorage();
|
||||
|
||||
net.setListeners(new StatsListener(ss1, 1, "ss1"));
|
||||
|
||||
DataSetIterator iter = new IrisDataSetIterator(150, 150);
|
||||
|
||||
for (int i = 0; i < 5; i++) {
|
||||
net.fit(iter);
|
||||
}
|
||||
|
||||
StatsStorage ss2 = new InMemoryStatsStorage();
|
||||
net.setListeners(new StatsListener(ss2, 1, "ss2"));
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
net.fit(iter);
|
||||
}
|
||||
|
||||
UIServer ui = UIServer.getInstance(true, null);
|
||||
try {
|
||||
((VertxUIServer) ui).autoAttachStatsStorageBySessionId(new Function<String, StatsStorage>() {
|
||||
@Override
|
||||
public StatsStorage apply(String s) {
|
||||
if ("ss1".equals(s)) {
|
||||
return ss1;
|
||||
} else if ("ss2".equals(s)) {
|
||||
return ss2;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
});
|
||||
|
||||
String json1 = IOUtils.toString(new URL("http://localhost:9000/train/ss1/overview/data"), StandardCharsets.UTF_8);
|
||||
// System.out.println(json1);
|
||||
|
||||
String json2 = IOUtils.toString(new URL("http://localhost:9000/train/ss2/overview/data"), StandardCharsets.UTF_8);
|
||||
// System.out.println(json2);
|
||||
|
||||
assertNotEquals(json1, json2);
|
||||
|
||||
Map<String, Object> m1 = JsonMappers.getMapper().readValue(json1, Map.class);
|
||||
Map<String, Object> m2 = JsonMappers.getMapper().readValue(json2, Map.class);
|
||||
|
||||
List<Object> s1 = (List<Object>) m1.get("scores");
|
||||
List<Object> s2 = (List<Object>) m2.get("scores");
|
||||
assertEquals(5, s1.size());
|
||||
assertEquals(4, s2.size());
|
||||
} finally {
|
||||
ui.stop();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testUIAttachDetach() throws Exception {
|
||||
StatsStorage ss = new InMemoryStatsStorage();
|
||||
|
|
Loading…
Reference in New Issue