2021-02-01 14:31:20 +09:00
|
|
|
/*
|
|
|
|
|
* ******************************************************************************
|
|
|
|
|
* *
|
|
|
|
|
* *
|
|
|
|
|
* * This program and the accompanying materials are made available under the
|
|
|
|
|
* * terms of the Apache License, Version 2.0 which is available at
|
|
|
|
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
|
|
|
|
* *
|
2021-02-01 17:47:29 +09:00
|
|
|
* * See the NOTICE file distributed with this work for additional
|
|
|
|
|
* * information regarding copyright ownership.
|
2021-02-01 14:31:20 +09:00
|
|
|
* * Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
|
|
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
|
|
|
* * License for the specific language governing permissions and limitations
|
|
|
|
|
* * under the License.
|
|
|
|
|
* *
|
|
|
|
|
* * SPDX-License-Identifier: Apache-2.0
|
|
|
|
|
* *****************************************************************************
|
|
|
|
|
*/
|
2019-06-06 15:21:15 +03:00
|
|
|
|
2019-11-22 23:50:34 +11:00
|
|
|
package org.deeplearning4j.ui;
|
2019-06-06 15:21:15 +03:00
|
|
|
|
2020-04-23 02:26:51 +02:00
|
|
|
import io.vertx.core.Future;
|
|
|
|
|
import io.vertx.core.Promise;
|
|
|
|
|
import io.vertx.core.Vertx;
|
2022-09-20 15:40:53 +02:00
|
|
|
import lombok.extern.slf4j.Slf4j;
|
2019-11-25 22:52:02 +11:00
|
|
|
import org.apache.commons.io.IOUtils;
|
2020-01-04 13:45:07 +11:00
|
|
|
import org.deeplearning4j.BaseDL4JTest;
|
2020-04-29 11:19:26 +10:00
|
|
|
import org.deeplearning4j.core.storage.StatsStorage;
|
2019-06-06 15:21:15 +03:00
|
|
|
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
|
2020-04-23 02:26:51 +02:00
|
|
|
import org.deeplearning4j.exception.DL4JException;
|
2019-06-06 15:21:15 +03:00
|
|
|
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
|
|
|
|
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
|
|
|
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
|
|
|
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
|
|
|
|
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
|
|
|
|
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
|
|
|
|
import org.deeplearning4j.nn.conf.layers.variational.GaussianReconstructionDistribution;
|
|
|
|
|
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;
|
2019-11-25 22:52:02 +11:00
|
|
|
import org.deeplearning4j.nn.conf.serde.JsonMappers;
|
2019-06-06 15:21:15 +03:00
|
|
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
|
|
|
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
|
|
|
|
import org.deeplearning4j.nn.weights.WeightInit;
|
|
|
|
|
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
|
|
|
|
|
import org.deeplearning4j.ui.api.UIServer;
|
2020-04-29 11:19:26 +10:00
|
|
|
import org.deeplearning4j.ui.model.stats.StatsListener;
|
|
|
|
|
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
|
2022-09-20 15:40:53 +02:00
|
|
|
import org.junit.jupiter.api.AfterAll;
|
2021-03-16 11:57:24 +09:00
|
|
|
import org.junit.jupiter.api.BeforeEach;
|
2022-09-20 15:40:53 +02:00
|
|
|
|
2021-03-16 11:57:24 +09:00
|
|
|
import org.junit.jupiter.api.Test;
|
2019-06-06 15:21:15 +03:00
|
|
|
import org.nd4j.linalg.activations.Activation;
|
|
|
|
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
2020-04-29 11:19:26 +10:00
|
|
|
import org.nd4j.common.function.Function;
|
2019-06-06 15:21:15 +03:00
|
|
|
import org.nd4j.linalg.learning.config.Sgd;
|
|
|
|
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
|
|
|
|
|
2019-11-25 22:52:02 +11:00
|
|
|
import java.net.URL;
|
|
|
|
|
import java.nio.charset.StandardCharsets;
|
|
|
|
|
import java.util.List;
|
|
|
|
|
import java.util.Map;
|
2020-04-23 02:26:51 +02:00
|
|
|
import java.util.concurrent.CountDownLatch;
|
2022-09-20 15:40:53 +02:00
|
|
|
import java.util.concurrent.atomic.AtomicReference;
|
2019-11-25 22:52:02 +11:00
|
|
|
|
2021-03-16 11:57:24 +09:00
|
|
|
import static org.junit.jupiter.api.Assertions.*;
|
2019-06-06 15:21:15 +03:00
|
|
|
|
2022-09-20 15:40:53 +02:00
|
|
|
@Slf4j
|
|
|
|
|
//@Ignore
|
2020-01-04 13:45:07 +11:00
|
|
|
public class TestVertxUI extends BaseDL4JTest {
|
2021-03-16 11:57:24 +09:00
|
|
|
|
|
|
|
|
@BeforeEach
|
2019-06-06 15:21:15 +03:00
|
|
|
public void setUp() throws Exception {
|
|
|
|
|
UIServer.stopInstance();
|
|
|
|
|
}
|
|
|
|
|
|
2022-09-20 15:40:53 +02:00
|
|
|
@AfterAll
|
|
|
|
|
public void shutdownServer() throws InterruptedException {
|
|
|
|
|
UIServer.getInstance().stop();
|
|
|
|
|
}
|
|
|
|
|
|
2019-06-06 15:21:15 +03:00
|
|
|
@Test
|
|
|
|
|
public void testUI() throws Exception {
|
2019-11-22 23:50:34 +11:00
|
|
|
VertxUIServer uiServer = (VertxUIServer) UIServer.getInstance();
|
2019-06-06 15:21:15 +03:00
|
|
|
assertEquals(9000, uiServer.getPort());
|
|
|
|
|
uiServer.stop();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Test
|
|
|
|
|
public void testUI_VAE() throws Exception {
|
|
|
|
|
//Variational autoencoder - for unsupervised layerwise pretraining
|
|
|
|
|
|
|
|
|
|
StatsStorage ss = new InMemoryStatsStorage();
|
|
|
|
|
|
|
|
|
|
UIServer uiServer = UIServer.getInstance();
|
|
|
|
|
uiServer.attach(ss);
|
|
|
|
|
|
|
|
|
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
|
|
|
|
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
|
|
|
|
.updater(new Sgd(1e-5))
|
|
|
|
|
.list().layer(0,
|
|
|
|
|
new VariationalAutoencoder.Builder().nIn(4).nOut(3).encoderLayerSizes(10, 11)
|
|
|
|
|
.decoderLayerSizes(12, 13).weightInit(WeightInit.XAVIER)
|
|
|
|
|
.pzxActivationFunction(Activation.IDENTITY)
|
|
|
|
|
.reconstructionDistribution(
|
|
|
|
|
new GaussianReconstructionDistribution())
|
|
|
|
|
.activation(Activation.LEAKYRELU).build())
|
|
|
|
|
.layer(1, new VariationalAutoencoder.Builder().nIn(3).nOut(3).encoderLayerSizes(7)
|
|
|
|
|
.decoderLayerSizes(8).weightInit(WeightInit.XAVIER)
|
|
|
|
|
.pzxActivationFunction(Activation.IDENTITY)
|
|
|
|
|
.reconstructionDistribution(new GaussianReconstructionDistribution())
|
|
|
|
|
.activation(Activation.LEAKYRELU).build())
|
|
|
|
|
.layer(2, new OutputLayer.Builder().nIn(3).nOut(3).build())
|
|
|
|
|
.build();
|
|
|
|
|
|
|
|
|
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
|
|
|
|
net.init();
|
|
|
|
|
net.setListeners(new StatsListener(ss), new ScoreIterationListener(1));
|
|
|
|
|
|
|
|
|
|
DataSetIterator iter = new IrisDataSetIterator(150, 150);
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < 50; i++) {
|
|
|
|
|
net.fit(iter);
|
|
|
|
|
Thread.sleep(100);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Test
|
|
|
|
|
public void testUIMultipleSessions() throws Exception {
|
|
|
|
|
|
|
|
|
|
for (int session = 0; session < 3; session++) {
|
|
|
|
|
|
|
|
|
|
StatsStorage ss = new InMemoryStatsStorage();
|
|
|
|
|
|
|
|
|
|
UIServer uiServer = UIServer.getInstance();
|
|
|
|
|
uiServer.attach(ss);
|
|
|
|
|
|
|
|
|
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
|
|
|
|
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list()
|
|
|
|
|
.layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(4).build())
|
|
|
|
|
.layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
|
|
|
|
|
.activation(Activation.SOFTMAX).nIn(4).nOut(3).build())
|
|
|
|
|
.build();
|
|
|
|
|
|
|
|
|
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
|
|
|
|
net.init();
|
|
|
|
|
net.setListeners(new StatsListener(ss, 1), new ScoreIterationListener(1));
|
|
|
|
|
|
|
|
|
|
DataSetIterator iter = new IrisDataSetIterator(150, 150);
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < 20; i++) {
|
|
|
|
|
net.fit(iter);
|
|
|
|
|
Thread.sleep(100);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Test
|
2020-04-23 02:26:51 +02:00
|
|
|
public void testUICompGraph() {
|
2019-06-06 15:21:15 +03:00
|
|
|
|
|
|
|
|
StatsStorage ss = new InMemoryStatsStorage();
|
|
|
|
|
|
|
|
|
|
UIServer uiServer = UIServer.getInstance();
|
|
|
|
|
uiServer.attach(ss);
|
|
|
|
|
|
|
|
|
|
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in")
|
|
|
|
|
.addLayer("L0", new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(4).build(),
|
|
|
|
|
"in")
|
|
|
|
|
.addLayer("L1", new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
|
|
|
|
|
.activation(Activation.SOFTMAX).nIn(4).nOut(3).build(), "L0")
|
|
|
|
|
.setOutputs("L1").build();
|
|
|
|
|
|
|
|
|
|
ComputationGraph net = new ComputationGraph(conf);
|
|
|
|
|
net.init();
|
|
|
|
|
|
|
|
|
|
net.setListeners(new StatsListener(ss), new ScoreIterationListener(1));
|
|
|
|
|
|
|
|
|
|
DataSetIterator iter = new IrisDataSetIterator(150, 150);
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < 100; i++) {
|
|
|
|
|
net.fit(iter);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2019-11-25 22:52:02 +11:00
|
|
|
@Test
|
|
|
|
|
public void testAutoAttach() throws Exception {
|
|
|
|
|
|
|
|
|
|
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in")
|
|
|
|
|
.addLayer("L0", new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(4).build(),
|
|
|
|
|
"in")
|
|
|
|
|
.addLayer("L1", new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
|
|
|
|
|
.activation(Activation.SOFTMAX).nIn(4).nOut(3).build(), "L0")
|
|
|
|
|
.setOutputs("L1").build();
|
|
|
|
|
|
|
|
|
|
ComputationGraph net = new ComputationGraph(conf);
|
|
|
|
|
net.init();
|
|
|
|
|
|
|
|
|
|
StatsStorage ss1 = new InMemoryStatsStorage();
|
|
|
|
|
|
|
|
|
|
net.setListeners(new StatsListener(ss1, 1, "ss1"));
|
|
|
|
|
|
|
|
|
|
DataSetIterator iter = new IrisDataSetIterator(150, 150);
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < 5; i++) {
|
|
|
|
|
net.fit(iter);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
StatsStorage ss2 = new InMemoryStatsStorage();
|
|
|
|
|
net.setListeners(new StatsListener(ss2, 1, "ss2"));
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < 4; i++) {
|
|
|
|
|
net.fit(iter);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
UIServer ui = UIServer.getInstance(true, null);
|
|
|
|
|
try {
|
|
|
|
|
((VertxUIServer) ui).autoAttachStatsStorageBySessionId(new Function<String, StatsStorage>() {
|
|
|
|
|
@Override
|
|
|
|
|
public StatsStorage apply(String s) {
|
|
|
|
|
if ("ss1".equals(s)) {
|
|
|
|
|
return ss1;
|
|
|
|
|
} else if ("ss2".equals(s)) {
|
|
|
|
|
return ss2;
|
|
|
|
|
}
|
|
|
|
|
return null;
|
|
|
|
|
}
|
|
|
|
|
});
|
|
|
|
|
|
2020-04-23 02:26:51 +02:00
|
|
|
String json1 = IOUtils.toString(new URL("http://localhost:9000/train/ss1/overview/data"),
|
|
|
|
|
StandardCharsets.UTF_8);
|
2019-11-25 22:52:02 +11:00
|
|
|
|
2020-04-23 02:26:51 +02:00
|
|
|
String json2 = IOUtils.toString(new URL("http://localhost:9000/train/ss2/overview/data"),
|
|
|
|
|
StandardCharsets.UTF_8);
|
2019-11-25 22:52:02 +11:00
|
|
|
|
|
|
|
|
assertNotEquals(json1, json2);
|
|
|
|
|
|
|
|
|
|
Map<String, Object> m1 = JsonMappers.getMapper().readValue(json1, Map.class);
|
|
|
|
|
Map<String, Object> m2 = JsonMappers.getMapper().readValue(json2, Map.class);
|
|
|
|
|
|
|
|
|
|
List<Object> s1 = (List<Object>) m1.get("scores");
|
|
|
|
|
List<Object> s2 = (List<Object>) m2.get("scores");
|
|
|
|
|
assertEquals(5, s1.size());
|
|
|
|
|
assertEquals(4, s2.size());
|
|
|
|
|
} finally {
|
|
|
|
|
ui.stop();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2019-06-06 15:21:15 +03:00
|
|
|
@Test
|
|
|
|
|
public void testUIAttachDetach() throws Exception {
|
|
|
|
|
StatsStorage ss = new InMemoryStatsStorage();
|
|
|
|
|
|
|
|
|
|
UIServer uiServer = UIServer.getInstance();
|
|
|
|
|
uiServer.attach(ss);
|
|
|
|
|
assertFalse(uiServer.getStatsStorageInstances().isEmpty());
|
|
|
|
|
uiServer.detach(ss);
|
|
|
|
|
assertTrue(uiServer.getStatsStorageInstances().isEmpty());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Test
|
2020-04-23 02:26:51 +02:00
|
|
|
public void testUIServerStop() throws Exception {
|
2019-06-06 15:21:15 +03:00
|
|
|
UIServer uiServer = UIServer.getInstance(true, null);
|
|
|
|
|
assertTrue(uiServer.isMultiSession());
|
2020-04-23 02:26:51 +02:00
|
|
|
assertFalse(uiServer.isStopped());
|
|
|
|
|
|
|
|
|
|
long sleepMilliseconds = 1_000;
|
|
|
|
|
log.info("Waiting {} ms before stopping.", sleepMilliseconds);
|
|
|
|
|
Thread.sleep(sleepMilliseconds);
|
2019-06-06 15:21:15 +03:00
|
|
|
uiServer.stop();
|
2020-04-23 02:26:51 +02:00
|
|
|
assertTrue(uiServer.isStopped());
|
|
|
|
|
|
|
|
|
|
log.info("UI server is stopped. Waiting {} ms before starting new UI server.", sleepMilliseconds);
|
|
|
|
|
Thread.sleep(sleepMilliseconds);
|
|
|
|
|
uiServer = UIServer.getInstance(false, null);
|
|
|
|
|
assertFalse(uiServer.isMultiSession());
|
|
|
|
|
assertFalse(uiServer.isStopped());
|
|
|
|
|
|
|
|
|
|
log.info("Waiting {} ms before stopping.", sleepMilliseconds);
|
|
|
|
|
Thread.sleep(sleepMilliseconds);
|
|
|
|
|
uiServer.stop();
|
|
|
|
|
assertTrue(uiServer.isStopped());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@Test
|
|
|
|
|
public void testUIServerStopAsync() throws Exception {
|
|
|
|
|
UIServer uiServer = UIServer.getInstance(true, null);
|
|
|
|
|
assertTrue(uiServer.isMultiSession());
|
|
|
|
|
assertFalse(uiServer.isStopped());
|
|
|
|
|
|
|
|
|
|
long sleepMilliseconds = 1_000;
|
|
|
|
|
log.info("Waiting {} ms before stopping.", sleepMilliseconds);
|
|
|
|
|
Thread.sleep(sleepMilliseconds);
|
|
|
|
|
|
|
|
|
|
CountDownLatch latch = new CountDownLatch(1);
|
|
|
|
|
Promise<Void> promise = Promise.promise();
|
|
|
|
|
promise.future().compose(
|
|
|
|
|
success -> Future.future(prom -> latch.countDown()),
|
|
|
|
|
failure -> Future.future(prom -> latch.countDown())
|
|
|
|
|
);
|
|
|
|
|
|
|
|
|
|
uiServer.stopAsync(promise);
|
|
|
|
|
latch.await();
|
|
|
|
|
assertTrue(uiServer.isStopped());
|
|
|
|
|
|
|
|
|
|
log.info("UI server is stopped. Waiting {} ms before starting new UI server.", sleepMilliseconds);
|
|
|
|
|
Thread.sleep(sleepMilliseconds);
|
2019-06-06 15:21:15 +03:00
|
|
|
uiServer = UIServer.getInstance(false, null);
|
|
|
|
|
assertFalse(uiServer.isMultiSession());
|
2020-04-23 02:26:51 +02:00
|
|
|
|
|
|
|
|
log.info("Waiting {} ms before stopping.", sleepMilliseconds);
|
|
|
|
|
Thread.sleep(sleepMilliseconds);
|
|
|
|
|
uiServer.stop();
|
|
|
|
|
}
|
|
|
|
|
|
2022-09-20 15:40:53 +02:00
|
|
|
@Test
|
2020-04-23 02:26:51 +02:00
|
|
|
public void testUIStartPortAlreadyBound() throws InterruptedException {
|
2022-09-20 15:40:53 +02:00
|
|
|
assertThrows(DL4JException.class, () -> {
|
2021-03-16 11:57:24 +09:00
|
|
|
CountDownLatch latch = new CountDownLatch(1);
|
2022-09-20 15:40:53 +02:00
|
|
|
|
2021-03-16 11:57:24 +09:00
|
|
|
//Create HttpServer that binds the same port
|
|
|
|
|
int port = VertxUIServer.DEFAULT_UI_PORT;
|
|
|
|
|
Vertx vertx = Vertx.vertx();
|
|
|
|
|
vertx.createHttpServer()
|
2022-09-20 15:40:53 +02:00
|
|
|
.requestHandler(event -> {
|
|
|
|
|
})
|
2021-03-16 11:57:24 +09:00
|
|
|
.listen(port, result -> latch.countDown());
|
|
|
|
|
latch.await();
|
|
|
|
|
|
|
|
|
|
try {
|
|
|
|
|
//DL4JException signals that the port cannot be bound, UI server cannot start
|
|
|
|
|
UIServer.getInstance();
|
|
|
|
|
} finally {
|
|
|
|
|
vertx.close();
|
|
|
|
|
}
|
|
|
|
|
});
|
2020-04-23 02:26:51 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Test
|
|
|
|
|
public void testUIStartAsync() throws InterruptedException {
|
|
|
|
|
CountDownLatch latch = new CountDownLatch(1);
|
|
|
|
|
Promise<String> promise = Promise.promise();
|
|
|
|
|
promise.future().compose(
|
|
|
|
|
success -> Future.future(prom -> latch.countDown()),
|
|
|
|
|
failure -> Future.future(prom -> latch.countDown())
|
|
|
|
|
);
|
|
|
|
|
int port = VertxUIServer.DEFAULT_UI_PORT;
|
|
|
|
|
VertxUIServer.getInstance(port, false, null, promise);
|
|
|
|
|
latch.await();
|
|
|
|
|
if (promise.future().succeeded()) {
|
|
|
|
|
String deploymentId = promise.future().result();
|
|
|
|
|
log.debug("UI server deployed, deployment ID = {}", deploymentId);
|
|
|
|
|
} else {
|
|
|
|
|
log.debug("UI server failed to deploy.", promise.future().cause());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Test
|
2020-07-25 13:30:49 +02:00
|
|
|
public void testUIShutdownHook() throws InterruptedException {
|
|
|
|
|
UIServer uIServer = UIServer.getInstance();
|
|
|
|
|
Thread shutdownHook = UIServer.getShutdownHook();
|
|
|
|
|
shutdownHook.start();
|
|
|
|
|
shutdownHook.join();
|
|
|
|
|
/*
|
|
|
|
|
* running the shutdown hook thread before the Runtime is terminated
|
|
|
|
|
* enables us to check if the UI server has been shut down or not
|
|
|
|
|
*/
|
|
|
|
|
assertTrue(uIServer.isStopped());
|
|
|
|
|
log.info("Deeplearning4j UI server stopped in shutdown hook.");
|
2019-06-06 15:21:15 +03:00
|
|
|
}
|
|
|
|
|
}
|