tests for auto-attach and manual attach; improve JavaDoc
Signed-off-by: Tamás Fenyvesi <tamas.fenyvesi@doknet.hu>master
parent
55be669069
commit
cf24728f35
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.arbiter.optimize;
|
package org.deeplearning4j.arbiter.optimize;
|
||||||
|
|
||||||
|
import io.netty.handler.codec.http.HttpResponseStatus;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
|
@ -65,12 +66,17 @@ import org.nd4j.linalg.function.Function;
|
||||||
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.util.Collections;
|
import java.io.IOException;
|
||||||
import java.util.HashMap;
|
import java.io.UnsupportedEncodingException;
|
||||||
import java.util.Map;
|
import java.net.HttpURLConnection;
|
||||||
import java.util.Properties;
|
import java.net.URL;
|
||||||
|
import java.net.URLEncoder;
|
||||||
|
import java.util.*;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
import static org.junit.Assert.assertTrue;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Created by Alex on 19/07/2017.
|
* Created by Alex on 19/07/2017.
|
||||||
*/
|
*/
|
||||||
|
@ -88,112 +94,15 @@ public class TestBasic extends BaseDL4JTest {
|
||||||
|
|
||||||
UIServer.getInstance();
|
UIServer.getInstance();
|
||||||
|
|
||||||
Thread.sleep(1000000);
|
Thread.sleep(1000_000);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
@Ignore
|
|
||||||
public void testBasicUiMultiSession() throws Exception {
|
|
||||||
|
|
||||||
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
|
|
||||||
|
|
||||||
MultiLayerSpace mls = new MultiLayerSpace.Builder()
|
|
||||||
.updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.2)))
|
|
||||||
.l2(new ContinuousParameterSpace(0.0001, 0.05))
|
|
||||||
.addLayer(
|
|
||||||
new ConvolutionLayerSpace.Builder().nIn(1)
|
|
||||||
.nOut(new IntegerParameterSpace(5, 30))
|
|
||||||
.kernelSize(new DiscreteParameterSpace<>(new int[]{3, 3},
|
|
||||||
new int[]{4, 4}, new int[]{5, 5}))
|
|
||||||
.stride(new DiscreteParameterSpace<>(new int[]{1, 1},
|
|
||||||
new int[]{2, 2}))
|
|
||||||
.activation(new DiscreteParameterSpace<>(Activation.RELU,
|
|
||||||
Activation.SOFTPLUS, Activation.LEAKYRELU))
|
|
||||||
.build())
|
|
||||||
.addLayer(new DenseLayerSpace.Builder().nOut(new IntegerParameterSpace(32, 128))
|
|
||||||
.activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH))
|
|
||||||
.build(), new IntegerParameterSpace(0, 1), true) //0 to 1 layers
|
|
||||||
.addLayer(new OutputLayerSpace.Builder().nOut(10).activation(Activation.SOFTMAX)
|
|
||||||
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
|
||||||
.setInputType(InputType.convolutionalFlat(28, 28, 1))
|
|
||||||
.build();
|
|
||||||
Map<String, Object> commands = new HashMap<>();
|
|
||||||
// commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, TestDataFactoryProviderMnist.class.getCanonicalName());
|
|
||||||
|
|
||||||
//Define configuration:
|
|
||||||
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls, commands);
|
|
||||||
DataProvider dataProvider = new MnistDataSetProvider();
|
|
||||||
|
|
||||||
|
|
||||||
String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterUiTestBasicMnist\\").getAbsolutePath();
|
|
||||||
|
|
||||||
File f = new File(modelSavePath);
|
|
||||||
if (f.exists())
|
|
||||||
f.delete();
|
|
||||||
f.mkdir();
|
|
||||||
if (!f.exists())
|
|
||||||
throw new RuntimeException();
|
|
||||||
|
|
||||||
OptimizationConfiguration configuration =
|
|
||||||
new OptimizationConfiguration.Builder()
|
|
||||||
.candidateGenerator(candidateGenerator).dataProvider(dataProvider)
|
|
||||||
.modelSaver(new FileModelSaver(modelSavePath))
|
|
||||||
.scoreFunction(new TestSetLossScoreFunction(true))
|
|
||||||
.terminationConditions(new MaxTimeCondition(120, TimeUnit.MINUTES),
|
|
||||||
new MaxCandidatesCondition(100))
|
|
||||||
.build();
|
|
||||||
|
|
||||||
IOptimizationRunner runner =
|
|
||||||
new LocalOptimizationRunner(configuration, new MultiLayerNetworkTaskCreator());
|
|
||||||
|
|
||||||
// add 3 different sessions to the same execution
|
|
||||||
HashMap<String, StatsStorage> statsStorageForSession = new HashMap<>();
|
|
||||||
for (int i = 0; i < 3; i++) {
|
|
||||||
StatsStorage ss = new InMemoryStatsStorage();
|
|
||||||
@NonNull String sessionId = "sid" + i;
|
|
||||||
statsStorageForSession.put(sessionId, ss);
|
|
||||||
StatusListener sl = new ArbiterStatusListener(sessionId, ss);
|
|
||||||
runner.addListeners(sl);
|
|
||||||
}
|
|
||||||
|
|
||||||
Function<String, StatsStorage> statsStorageProvider = statsStorageForSession::get;
|
|
||||||
UIServer uIServer = UIServer.getInstance(true, statsStorageProvider);
|
|
||||||
String serverAddress = uIServer.getAddress();
|
|
||||||
for (String sessionId : statsStorageForSession.keySet()) {
|
|
||||||
log.info("Arbiter session can be attached at {}/arbiter/{}", serverAddress, sessionId);
|
|
||||||
}
|
|
||||||
|
|
||||||
runner.execute();
|
|
||||||
|
|
||||||
Thread.sleep(1000000);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@Ignore
|
@Ignore
|
||||||
public void testBasicMnist() throws Exception {
|
public void testBasicMnist() throws Exception {
|
||||||
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
|
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
|
||||||
|
|
||||||
MultiLayerSpace mls = new MultiLayerSpace.Builder()
|
MultiLayerSpace mls = getMultiLayerSpaceMnist();
|
||||||
.updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.2)))
|
|
||||||
.l2(new ContinuousParameterSpace(0.0001, 0.05))
|
|
||||||
.addLayer(
|
|
||||||
new ConvolutionLayerSpace.Builder().nIn(1)
|
|
||||||
.nOut(new IntegerParameterSpace(5, 30))
|
|
||||||
.kernelSize(new DiscreteParameterSpace<>(new int[]{3, 3},
|
|
||||||
new int[]{4, 4}, new int[]{5, 5}))
|
|
||||||
.stride(new DiscreteParameterSpace<>(new int[]{1, 1},
|
|
||||||
new int[]{2, 2}))
|
|
||||||
.activation(new DiscreteParameterSpace<>(Activation.RELU,
|
|
||||||
Activation.SOFTPLUS, Activation.LEAKYRELU))
|
|
||||||
.build())
|
|
||||||
.addLayer(new DenseLayerSpace.Builder().nOut(new IntegerParameterSpace(32, 128))
|
|
||||||
.activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH))
|
|
||||||
.build(), new IntegerParameterSpace(0, 1), true) //0 to 1 layers
|
|
||||||
.addLayer(new OutputLayerSpace.Builder().nOut(10).activation(Activation.SOFTMAX)
|
|
||||||
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
|
||||||
.setInputType(InputType.convolutionalFlat(28, 28, 1))
|
|
||||||
.build();
|
|
||||||
Map<String, Object> commands = new HashMap<>();
|
Map<String, Object> commands = new HashMap<>();
|
||||||
// commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, TestDataFactoryProviderMnist.class.getCanonicalName());
|
// commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, TestDataFactoryProviderMnist.class.getCanonicalName());
|
||||||
|
|
||||||
|
@ -230,7 +139,30 @@ public class TestBasic extends BaseDL4JTest {
|
||||||
UIServer.getInstance().attach(ss);
|
UIServer.getInstance().attach(ss);
|
||||||
|
|
||||||
runner.execute();
|
runner.execute();
|
||||||
Thread.sleep(100000);
|
Thread.sleep(1000_000);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static MultiLayerSpace getMultiLayerSpaceMnist() {
|
||||||
|
return new MultiLayerSpace.Builder()
|
||||||
|
.updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.2)))
|
||||||
|
.l2(new ContinuousParameterSpace(0.0001, 0.05))
|
||||||
|
.addLayer(
|
||||||
|
new ConvolutionLayerSpace.Builder().nIn(1)
|
||||||
|
.nOut(new IntegerParameterSpace(5, 30))
|
||||||
|
.kernelSize(new DiscreteParameterSpace<>(new int[]{3, 3},
|
||||||
|
new int[]{4, 4}, new int[]{5, 5}))
|
||||||
|
.stride(new DiscreteParameterSpace<>(new int[]{1, 1},
|
||||||
|
new int[]{2, 2}))
|
||||||
|
.activation(new DiscreteParameterSpace<>(Activation.RELU,
|
||||||
|
Activation.SOFTPLUS, Activation.LEAKYRELU))
|
||||||
|
.build())
|
||||||
|
.addLayer(new DenseLayerSpace.Builder().nOut(new IntegerParameterSpace(32, 128))
|
||||||
|
.activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH))
|
||||||
|
.build(), new IntegerParameterSpace(0, 1), true) //0 to 1 layers
|
||||||
|
.addLayer(new OutputLayerSpace.Builder().nOut(10).activation(Activation.SOFTMAX)
|
||||||
|
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
||||||
|
.setInputType(InputType.convolutionalFlat(28, 28, 1))
|
||||||
|
.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -319,7 +251,7 @@ public class TestBasic extends BaseDL4JTest {
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
//Define configuration:
|
//Define configuration:
|
||||||
CandidateGenerator candidateGenerator = new RandomSearchGenerator(cgs, Collections.EMPTY_MAP);
|
CandidateGenerator candidateGenerator = new RandomSearchGenerator(cgs);
|
||||||
DataProvider dataProvider = new MnistDataSetProvider();
|
DataProvider dataProvider = new MnistDataSetProvider();
|
||||||
|
|
||||||
|
|
||||||
|
@ -417,7 +349,7 @@ public class TestBasic extends BaseDL4JTest {
|
||||||
UIServer.getInstance().attach(ss);
|
UIServer.getInstance().attach(ss);
|
||||||
|
|
||||||
runner.execute();
|
runner.execute();
|
||||||
Thread.sleep(100000);
|
Thread.sleep(1000_000);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -482,7 +414,7 @@ public class TestBasic extends BaseDL4JTest {
|
||||||
UIServer.getInstance().attach(ss);
|
UIServer.getInstance().attach(ss);
|
||||||
|
|
||||||
runner.execute();
|
runner.execute();
|
||||||
Thread.sleep(100000);
|
Thread.sleep(1000_000);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -519,7 +451,7 @@ public class TestBasic extends BaseDL4JTest {
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
//Define configuration:
|
//Define configuration:
|
||||||
CandidateGenerator candidateGenerator = new RandomSearchGenerator(cgs, Collections.EMPTY_MAP);
|
CandidateGenerator candidateGenerator = new RandomSearchGenerator(cgs);
|
||||||
DataProvider dataProvider = new MnistDataSetProvider();
|
DataProvider dataProvider = new MnistDataSetProvider();
|
||||||
|
|
||||||
|
|
||||||
|
@ -551,13 +483,17 @@ public class TestBasic extends BaseDL4JTest {
|
||||||
UIServer.getInstance().attach(ss);
|
UIServer.getInstance().attach(ss);
|
||||||
|
|
||||||
runner.execute();
|
runner.execute();
|
||||||
Thread.sleep(100000);
|
Thread.sleep(1000_000);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Visualize multiple optimization sessions run one after another on single-session mode UI
|
||||||
|
* @throws InterruptedException if current thread has been interrupted
|
||||||
|
*/
|
||||||
@Test
|
@Test
|
||||||
@Ignore
|
@Ignore
|
||||||
public void testBasicMnistMultipleSessions() throws Exception {
|
public void testBasicMnistMultipleSessions() throws InterruptedException {
|
||||||
|
|
||||||
MultiLayerSpace mls = new MultiLayerSpace.Builder()
|
MultiLayerSpace mls = new MultiLayerSpace.Builder()
|
||||||
.updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.2)))
|
.updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.2)))
|
||||||
|
@ -585,8 +521,10 @@ public class TestBasic extends BaseDL4JTest {
|
||||||
|
|
||||||
//Define configuration:
|
//Define configuration:
|
||||||
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls, commands);
|
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls, commands);
|
||||||
DataProvider dataProvider = new MnistDataSetProvider();
|
|
||||||
|
|
||||||
|
Class<? extends DataSource> ds = MnistDataSource.class;
|
||||||
|
Properties dsp = new Properties();
|
||||||
|
dsp.setProperty("minibatch", "8");
|
||||||
|
|
||||||
String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterUiTestBasicMnist\\").getAbsolutePath();
|
String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterUiTestBasicMnist\\").getAbsolutePath();
|
||||||
|
|
||||||
|
@ -599,7 +537,7 @@ public class TestBasic extends BaseDL4JTest {
|
||||||
|
|
||||||
OptimizationConfiguration configuration =
|
OptimizationConfiguration configuration =
|
||||||
new OptimizationConfiguration.Builder()
|
new OptimizationConfiguration.Builder()
|
||||||
.candidateGenerator(candidateGenerator).dataProvider(dataProvider)
|
.candidateGenerator(candidateGenerator).dataSource(ds, dsp)
|
||||||
.modelSaver(new FileModelSaver(modelSavePath))
|
.modelSaver(new FileModelSaver(modelSavePath))
|
||||||
.scoreFunction(new TestSetLossScoreFunction(true))
|
.scoreFunction(new TestSetLossScoreFunction(true))
|
||||||
.terminationConditions(new MaxTimeCondition(1, TimeUnit.MINUTES),
|
.terminationConditions(new MaxTimeCondition(1, TimeUnit.MINUTES),
|
||||||
|
@ -621,7 +559,7 @@ public class TestBasic extends BaseDL4JTest {
|
||||||
|
|
||||||
candidateGenerator = new RandomSearchGenerator(mls, commands);
|
candidateGenerator = new RandomSearchGenerator(mls, commands);
|
||||||
configuration = new OptimizationConfiguration.Builder()
|
configuration = new OptimizationConfiguration.Builder()
|
||||||
.candidateGenerator(candidateGenerator).dataProvider(dataProvider)
|
.candidateGenerator(candidateGenerator).dataSource(ds, dsp)
|
||||||
.modelSaver(new FileModelSaver(modelSavePath))
|
.modelSaver(new FileModelSaver(modelSavePath))
|
||||||
.scoreFunction(new TestSetLossScoreFunction(true))
|
.scoreFunction(new TestSetLossScoreFunction(true))
|
||||||
.terminationConditions(new MaxTimeCondition(1, TimeUnit.MINUTES),
|
.terminationConditions(new MaxTimeCondition(1, TimeUnit.MINUTES),
|
||||||
|
@ -636,7 +574,148 @@ public class TestBasic extends BaseDL4JTest {
|
||||||
|
|
||||||
runner.execute();
|
runner.execute();
|
||||||
|
|
||||||
Thread.sleep(100000);
|
Thread.sleep(1000_000);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Auto-attach multiple optimization sessions to multi-session mode UI
|
||||||
|
* @throws IOException if could not connect to the server
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void testUiMultiSessionAutoAttach() throws IOException {
|
||||||
|
|
||||||
|
//Define configuration:
|
||||||
|
MultiLayerSpace mls = getMultiLayerSpaceMnist();
|
||||||
|
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
|
||||||
|
|
||||||
|
Class<? extends DataSource> ds = MnistDataSource.class;
|
||||||
|
Properties dsp = new Properties();
|
||||||
|
dsp.setProperty("minibatch", "8");
|
||||||
|
|
||||||
|
String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterUiTestMultiSessionAutoAttach\\")
|
||||||
|
.getAbsolutePath();
|
||||||
|
|
||||||
|
File f = new File(modelSavePath);
|
||||||
|
if (f.exists())
|
||||||
|
f.delete();
|
||||||
|
f.mkdir();
|
||||||
|
if (!f.exists())
|
||||||
|
throw new RuntimeException();
|
||||||
|
|
||||||
|
OptimizationConfiguration configuration =
|
||||||
|
new OptimizationConfiguration.Builder()
|
||||||
|
.candidateGenerator(candidateGenerator).dataSource(ds, dsp)
|
||||||
|
.modelSaver(new FileModelSaver(modelSavePath))
|
||||||
|
.scoreFunction(new TestSetLossScoreFunction(true))
|
||||||
|
.terminationConditions(new MaxTimeCondition(10, TimeUnit.SECONDS),
|
||||||
|
new MaxCandidatesCondition(1))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
IOptimizationRunner runner =
|
||||||
|
new LocalOptimizationRunner(configuration, new MultiLayerNetworkTaskCreator());
|
||||||
|
|
||||||
|
// add 3 different sessions to the same execution
|
||||||
|
HashMap<String, StatsStorage> statsStorageForSession = new HashMap<>();
|
||||||
|
for (int i = 0; i < 3; i++) {
|
||||||
|
StatsStorage ss = new InMemoryStatsStorage();
|
||||||
|
@NonNull String sessionId = "sid" + i;
|
||||||
|
statsStorageForSession.put(sessionId, ss);
|
||||||
|
StatusListener sl = new ArbiterStatusListener(sessionId, ss);
|
||||||
|
runner.addListeners(sl);
|
||||||
|
}
|
||||||
|
|
||||||
|
Function<String, StatsStorage> statsStorageProvider = statsStorageForSession::get;
|
||||||
|
UIServer uIServer = UIServer.getInstance(true, statsStorageProvider);
|
||||||
|
String serverAddress = uIServer.getAddress();
|
||||||
|
|
||||||
|
runner.execute();
|
||||||
|
|
||||||
|
for (String sessionId : statsStorageForSession.keySet()) {
|
||||||
|
/*
|
||||||
|
* Visiting /arbiter/:sessionId to auto-attach StatsStorage
|
||||||
|
*/
|
||||||
|
String sessionUrl = sessionUrl(uIServer.getAddress(), sessionId);
|
||||||
|
HttpURLConnection conn = (HttpURLConnection) new URL(sessionUrl).openConnection();
|
||||||
|
conn.connect();
|
||||||
|
|
||||||
|
log.info("Checking auto-attaching Arbiter session at {}", sessionUrl(serverAddress, sessionId));
|
||||||
|
assertEquals(HttpResponseStatus.OK.code(), conn.getResponseCode());
|
||||||
|
assertTrue(uIServer.isAttached(statsStorageForSession.get(sessionId)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Attach multiple optimization sessions to multi-session mode UI by manually visiting session URL
|
||||||
|
* @throws Exception if an error occurred
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
@Ignore
|
||||||
|
public void testUiMultiSessionManualAttach() throws Exception {
|
||||||
|
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
|
||||||
|
|
||||||
|
//Define configuration:
|
||||||
|
MultiLayerSpace mls = getMultiLayerSpaceMnist();
|
||||||
|
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
|
||||||
|
|
||||||
|
Class<? extends DataSource> ds = MnistDataSource.class;
|
||||||
|
Properties dsp = new Properties();
|
||||||
|
dsp.setProperty("minibatch", "8");
|
||||||
|
|
||||||
|
String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterUiTestBasicMnist\\")
|
||||||
|
.getAbsolutePath();
|
||||||
|
|
||||||
|
File f = new File(modelSavePath);
|
||||||
|
if (f.exists())
|
||||||
|
f.delete();
|
||||||
|
f.mkdir();
|
||||||
|
if (!f.exists())
|
||||||
|
throw new RuntimeException();
|
||||||
|
|
||||||
|
OptimizationConfiguration configuration =
|
||||||
|
new OptimizationConfiguration.Builder()
|
||||||
|
.candidateGenerator(candidateGenerator).dataSource(ds, dsp)
|
||||||
|
.modelSaver(new FileModelSaver(modelSavePath))
|
||||||
|
.scoreFunction(new TestSetLossScoreFunction(true))
|
||||||
|
.terminationConditions(new MaxTimeCondition(10, TimeUnit.MINUTES),
|
||||||
|
new MaxCandidatesCondition(10))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
|
||||||
|
// parallel execution of multiple optimization sessions
|
||||||
|
HashMap<String, StatsStorage> statsStorageForSession = new HashMap<>();
|
||||||
|
for (int i = 0; i < 3; i++) {
|
||||||
|
String sessionId = "sid" + i;
|
||||||
|
IOptimizationRunner runner =
|
||||||
|
new LocalOptimizationRunner(configuration, new MultiLayerNetworkTaskCreator());
|
||||||
|
StatsStorage ss = new InMemoryStatsStorage();
|
||||||
|
statsStorageForSession.put(sessionId, ss);
|
||||||
|
StatusListener sl = new ArbiterStatusListener(sessionId, ss);
|
||||||
|
runner.addListeners(sl);
|
||||||
|
// Asynchronous execution
|
||||||
|
new Thread(runner::execute).start();
|
||||||
|
}
|
||||||
|
|
||||||
|
Function<String, StatsStorage> statsStorageProvider = statsStorageForSession::get;
|
||||||
|
UIServer uIServer = UIServer.getInstance(true, statsStorageProvider);
|
||||||
|
String serverAddress = uIServer.getAddress();
|
||||||
|
|
||||||
|
for (String sessionId : statsStorageForSession.keySet()) {
|
||||||
|
log.info("Arbiter session can be attached at {}", sessionUrl(serverAddress, sessionId));
|
||||||
|
}
|
||||||
|
|
||||||
|
Thread.sleep(1000_000);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get URL for arbiter session on given server address
|
||||||
|
* @param serverAddress server address, e.g.: http://localhost:9000
|
||||||
|
* @param sessionId session ID (will be URL-encoded)
|
||||||
|
* @return URL
|
||||||
|
* @throws UnsupportedEncodingException if the character encoding is not supported
|
||||||
|
*/
|
||||||
|
private static String sessionUrl(String serverAddress, String sessionId) throws UnsupportedEncodingException {
|
||||||
|
return String.format("%s/arbiter/%s", serverAddress, URLEncoder.encode(sessionId, "UTF-8"));
|
||||||
}
|
}
|
||||||
|
|
||||||
private static class MnistDataSetProvider implements DataProvider {
|
private static class MnistDataSetProvider implements DataProvider {
|
||||||
|
|
|
@ -197,9 +197,9 @@ public class TestVertxUIMultiSession extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get URL-encoded URL for training session on given server address
|
* Get URL for training session on given server address
|
||||||
* @param serverAddress server address
|
* @param serverAddress server address
|
||||||
* @param sessionId session ID
|
* @param sessionId session ID (will be URL-encoded)
|
||||||
* @return URL
|
* @return URL
|
||||||
* @throws UnsupportedEncodingException if the used encoding is not supported
|
* @throws UnsupportedEncodingException if the used encoding is not supported
|
||||||
*/
|
*/
|
||||||
|
|
Loading…
Reference in New Issue