tests for auto-attach and manual attach; improve JavaDoc

Signed-off-by: Tamás Fenyvesi <tamas.fenyvesi@doknet.hu>
master
Tamás Fenyvesi 2020-05-06 14:56:22 +02:00
parent 55be669069
commit cf24728f35
2 changed files with 195 additions and 116 deletions

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.arbiter.optimize; package org.deeplearning4j.arbiter.optimize;
import io.netty.handler.codec.http.HttpResponseStatus;
import lombok.NonNull; import lombok.NonNull;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
@ -65,12 +66,17 @@ import org.nd4j.linalg.function.Function;
import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.File; import java.io.File;
import java.util.Collections; import java.io.IOException;
import java.util.HashMap; import java.io.UnsupportedEncodingException;
import java.util.Map; import java.net.HttpURLConnection;
import java.util.Properties; import java.net.URL;
import java.net.URLEncoder;
import java.util.*;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
/** /**
* Created by Alex on 19/07/2017. * Created by Alex on 19/07/2017.
*/ */
@ -88,112 +94,15 @@ public class TestBasic extends BaseDL4JTest {
UIServer.getInstance(); UIServer.getInstance();
Thread.sleep(1000000); Thread.sleep(1000_000);
} }
@Test
@Ignore
public void testBasicUiMultiSession() throws Exception {
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
MultiLayerSpace mls = new MultiLayerSpace.Builder()
.updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.2)))
.l2(new ContinuousParameterSpace(0.0001, 0.05))
.addLayer(
new ConvolutionLayerSpace.Builder().nIn(1)
.nOut(new IntegerParameterSpace(5, 30))
.kernelSize(new DiscreteParameterSpace<>(new int[]{3, 3},
new int[]{4, 4}, new int[]{5, 5}))
.stride(new DiscreteParameterSpace<>(new int[]{1, 1},
new int[]{2, 2}))
.activation(new DiscreteParameterSpace<>(Activation.RELU,
Activation.SOFTPLUS, Activation.LEAKYRELU))
.build())
.addLayer(new DenseLayerSpace.Builder().nOut(new IntegerParameterSpace(32, 128))
.activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH))
.build(), new IntegerParameterSpace(0, 1), true) //0 to 1 layers
.addLayer(new OutputLayerSpace.Builder().nOut(10).activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.setInputType(InputType.convolutionalFlat(28, 28, 1))
.build();
Map<String, Object> commands = new HashMap<>();
// commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, TestDataFactoryProviderMnist.class.getCanonicalName());
//Define configuration:
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls, commands);
DataProvider dataProvider = new MnistDataSetProvider();
String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterUiTestBasicMnist\\").getAbsolutePath();
File f = new File(modelSavePath);
if (f.exists())
f.delete();
f.mkdir();
if (!f.exists())
throw new RuntimeException();
OptimizationConfiguration configuration =
new OptimizationConfiguration.Builder()
.candidateGenerator(candidateGenerator).dataProvider(dataProvider)
.modelSaver(new FileModelSaver(modelSavePath))
.scoreFunction(new TestSetLossScoreFunction(true))
.terminationConditions(new MaxTimeCondition(120, TimeUnit.MINUTES),
new MaxCandidatesCondition(100))
.build();
IOptimizationRunner runner =
new LocalOptimizationRunner(configuration, new MultiLayerNetworkTaskCreator());
// add 3 different sessions to the same execution
HashMap<String, StatsStorage> statsStorageForSession = new HashMap<>();
for (int i = 0; i < 3; i++) {
StatsStorage ss = new InMemoryStatsStorage();
@NonNull String sessionId = "sid" + i;
statsStorageForSession.put(sessionId, ss);
StatusListener sl = new ArbiterStatusListener(sessionId, ss);
runner.addListeners(sl);
}
Function<String, StatsStorage> statsStorageProvider = statsStorageForSession::get;
UIServer uIServer = UIServer.getInstance(true, statsStorageProvider);
String serverAddress = uIServer.getAddress();
for (String sessionId : statsStorageForSession.keySet()) {
log.info("Arbiter session can be attached at {}/arbiter/{}", serverAddress, sessionId);
}
runner.execute();
Thread.sleep(1000000);
}
@Test @Test
@Ignore @Ignore
public void testBasicMnist() throws Exception { public void testBasicMnist() throws Exception {
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
MultiLayerSpace mls = new MultiLayerSpace.Builder() MultiLayerSpace mls = getMultiLayerSpaceMnist();
.updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.2)))
.l2(new ContinuousParameterSpace(0.0001, 0.05))
.addLayer(
new ConvolutionLayerSpace.Builder().nIn(1)
.nOut(new IntegerParameterSpace(5, 30))
.kernelSize(new DiscreteParameterSpace<>(new int[]{3, 3},
new int[]{4, 4}, new int[]{5, 5}))
.stride(new DiscreteParameterSpace<>(new int[]{1, 1},
new int[]{2, 2}))
.activation(new DiscreteParameterSpace<>(Activation.RELU,
Activation.SOFTPLUS, Activation.LEAKYRELU))
.build())
.addLayer(new DenseLayerSpace.Builder().nOut(new IntegerParameterSpace(32, 128))
.activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH))
.build(), new IntegerParameterSpace(0, 1), true) //0 to 1 layers
.addLayer(new OutputLayerSpace.Builder().nOut(10).activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.setInputType(InputType.convolutionalFlat(28, 28, 1))
.build();
Map<String, Object> commands = new HashMap<>(); Map<String, Object> commands = new HashMap<>();
// commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, TestDataFactoryProviderMnist.class.getCanonicalName()); // commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, TestDataFactoryProviderMnist.class.getCanonicalName());
@ -230,7 +139,30 @@ public class TestBasic extends BaseDL4JTest {
UIServer.getInstance().attach(ss); UIServer.getInstance().attach(ss);
runner.execute(); runner.execute();
Thread.sleep(100000); Thread.sleep(1000_000);
}
private static MultiLayerSpace getMultiLayerSpaceMnist() {
return new MultiLayerSpace.Builder()
.updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.2)))
.l2(new ContinuousParameterSpace(0.0001, 0.05))
.addLayer(
new ConvolutionLayerSpace.Builder().nIn(1)
.nOut(new IntegerParameterSpace(5, 30))
.kernelSize(new DiscreteParameterSpace<>(new int[]{3, 3},
new int[]{4, 4}, new int[]{5, 5}))
.stride(new DiscreteParameterSpace<>(new int[]{1, 1},
new int[]{2, 2}))
.activation(new DiscreteParameterSpace<>(Activation.RELU,
Activation.SOFTPLUS, Activation.LEAKYRELU))
.build())
.addLayer(new DenseLayerSpace.Builder().nOut(new IntegerParameterSpace(32, 128))
.activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH))
.build(), new IntegerParameterSpace(0, 1), true) //0 to 1 layers
.addLayer(new OutputLayerSpace.Builder().nOut(10).activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.setInputType(InputType.convolutionalFlat(28, 28, 1))
.build();
} }
@Test @Test
@ -319,7 +251,7 @@ public class TestBasic extends BaseDL4JTest {
.build(); .build();
//Define configuration: //Define configuration:
CandidateGenerator candidateGenerator = new RandomSearchGenerator(cgs, Collections.EMPTY_MAP); CandidateGenerator candidateGenerator = new RandomSearchGenerator(cgs);
DataProvider dataProvider = new MnistDataSetProvider(); DataProvider dataProvider = new MnistDataSetProvider();
@ -417,7 +349,7 @@ public class TestBasic extends BaseDL4JTest {
UIServer.getInstance().attach(ss); UIServer.getInstance().attach(ss);
runner.execute(); runner.execute();
Thread.sleep(100000); Thread.sleep(1000_000);
} }
@ -482,7 +414,7 @@ public class TestBasic extends BaseDL4JTest {
UIServer.getInstance().attach(ss); UIServer.getInstance().attach(ss);
runner.execute(); runner.execute();
Thread.sleep(100000); Thread.sleep(1000_000);
} }
@ -519,7 +451,7 @@ public class TestBasic extends BaseDL4JTest {
.build(); .build();
//Define configuration: //Define configuration:
CandidateGenerator candidateGenerator = new RandomSearchGenerator(cgs, Collections.EMPTY_MAP); CandidateGenerator candidateGenerator = new RandomSearchGenerator(cgs);
DataProvider dataProvider = new MnistDataSetProvider(); DataProvider dataProvider = new MnistDataSetProvider();
@ -551,13 +483,17 @@ public class TestBasic extends BaseDL4JTest {
UIServer.getInstance().attach(ss); UIServer.getInstance().attach(ss);
runner.execute(); runner.execute();
Thread.sleep(100000); Thread.sleep(1000_000);
} }
/**
* Visualize multiple optimization sessions run one after another on single-session mode UI
* @throws InterruptedException if current thread has been interrupted
*/
@Test @Test
@Ignore @Ignore
public void testBasicMnistMultipleSessions() throws Exception { public void testBasicMnistMultipleSessions() throws InterruptedException {
MultiLayerSpace mls = new MultiLayerSpace.Builder() MultiLayerSpace mls = new MultiLayerSpace.Builder()
.updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.2))) .updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.2)))
@ -585,8 +521,10 @@ public class TestBasic extends BaseDL4JTest {
//Define configuration: //Define configuration:
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls, commands); CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls, commands);
DataProvider dataProvider = new MnistDataSetProvider();
Class<? extends DataSource> ds = MnistDataSource.class;
Properties dsp = new Properties();
dsp.setProperty("minibatch", "8");
String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterUiTestBasicMnist\\").getAbsolutePath(); String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterUiTestBasicMnist\\").getAbsolutePath();
@ -599,7 +537,7 @@ public class TestBasic extends BaseDL4JTest {
OptimizationConfiguration configuration = OptimizationConfiguration configuration =
new OptimizationConfiguration.Builder() new OptimizationConfiguration.Builder()
.candidateGenerator(candidateGenerator).dataProvider(dataProvider) .candidateGenerator(candidateGenerator).dataSource(ds, dsp)
.modelSaver(new FileModelSaver(modelSavePath)) .modelSaver(new FileModelSaver(modelSavePath))
.scoreFunction(new TestSetLossScoreFunction(true)) .scoreFunction(new TestSetLossScoreFunction(true))
.terminationConditions(new MaxTimeCondition(1, TimeUnit.MINUTES), .terminationConditions(new MaxTimeCondition(1, TimeUnit.MINUTES),
@ -621,7 +559,7 @@ public class TestBasic extends BaseDL4JTest {
candidateGenerator = new RandomSearchGenerator(mls, commands); candidateGenerator = new RandomSearchGenerator(mls, commands);
configuration = new OptimizationConfiguration.Builder() configuration = new OptimizationConfiguration.Builder()
.candidateGenerator(candidateGenerator).dataProvider(dataProvider) .candidateGenerator(candidateGenerator).dataSource(ds, dsp)
.modelSaver(new FileModelSaver(modelSavePath)) .modelSaver(new FileModelSaver(modelSavePath))
.scoreFunction(new TestSetLossScoreFunction(true)) .scoreFunction(new TestSetLossScoreFunction(true))
.terminationConditions(new MaxTimeCondition(1, TimeUnit.MINUTES), .terminationConditions(new MaxTimeCondition(1, TimeUnit.MINUTES),
@ -636,7 +574,148 @@ public class TestBasic extends BaseDL4JTest {
runner.execute(); runner.execute();
Thread.sleep(100000); Thread.sleep(1000_000);
}
/**
* Auto-attach multiple optimization sessions to multi-session mode UI
* @throws IOException if could not connect to the server
*/
@Test
public void testUiMultiSessionAutoAttach() throws IOException {
//Define configuration:
MultiLayerSpace mls = getMultiLayerSpaceMnist();
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
Class<? extends DataSource> ds = MnistDataSource.class;
Properties dsp = new Properties();
dsp.setProperty("minibatch", "8");
String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterUiTestMultiSessionAutoAttach\\")
.getAbsolutePath();
File f = new File(modelSavePath);
if (f.exists())
f.delete();
f.mkdir();
if (!f.exists())
throw new RuntimeException();
OptimizationConfiguration configuration =
new OptimizationConfiguration.Builder()
.candidateGenerator(candidateGenerator).dataSource(ds, dsp)
.modelSaver(new FileModelSaver(modelSavePath))
.scoreFunction(new TestSetLossScoreFunction(true))
.terminationConditions(new MaxTimeCondition(10, TimeUnit.SECONDS),
new MaxCandidatesCondition(1))
.build();
IOptimizationRunner runner =
new LocalOptimizationRunner(configuration, new MultiLayerNetworkTaskCreator());
// add 3 different sessions to the same execution
HashMap<String, StatsStorage> statsStorageForSession = new HashMap<>();
for (int i = 0; i < 3; i++) {
StatsStorage ss = new InMemoryStatsStorage();
@NonNull String sessionId = "sid" + i;
statsStorageForSession.put(sessionId, ss);
StatusListener sl = new ArbiterStatusListener(sessionId, ss);
runner.addListeners(sl);
}
Function<String, StatsStorage> statsStorageProvider = statsStorageForSession::get;
UIServer uIServer = UIServer.getInstance(true, statsStorageProvider);
String serverAddress = uIServer.getAddress();
runner.execute();
for (String sessionId : statsStorageForSession.keySet()) {
/*
* Visiting /arbiter/:sessionId to auto-attach StatsStorage
*/
String sessionUrl = sessionUrl(uIServer.getAddress(), sessionId);
HttpURLConnection conn = (HttpURLConnection) new URL(sessionUrl).openConnection();
conn.connect();
log.info("Checking auto-attaching Arbiter session at {}", sessionUrl(serverAddress, sessionId));
assertEquals(HttpResponseStatus.OK.code(), conn.getResponseCode());
assertTrue(uIServer.isAttached(statsStorageForSession.get(sessionId)));
}
}
/**
* Attach multiple optimization sessions to multi-session mode UI by manually visiting session URL
* @throws Exception if an error occurred
*/
@Test
@Ignore
public void testUiMultiSessionManualAttach() throws Exception {
Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
//Define configuration:
MultiLayerSpace mls = getMultiLayerSpaceMnist();
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls);
Class<? extends DataSource> ds = MnistDataSource.class;
Properties dsp = new Properties();
dsp.setProperty("minibatch", "8");
String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterUiTestBasicMnist\\")
.getAbsolutePath();
File f = new File(modelSavePath);
if (f.exists())
f.delete();
f.mkdir();
if (!f.exists())
throw new RuntimeException();
OptimizationConfiguration configuration =
new OptimizationConfiguration.Builder()
.candidateGenerator(candidateGenerator).dataSource(ds, dsp)
.modelSaver(new FileModelSaver(modelSavePath))
.scoreFunction(new TestSetLossScoreFunction(true))
.terminationConditions(new MaxTimeCondition(10, TimeUnit.MINUTES),
new MaxCandidatesCondition(10))
.build();
// parallel execution of multiple optimization sessions
HashMap<String, StatsStorage> statsStorageForSession = new HashMap<>();
for (int i = 0; i < 3; i++) {
String sessionId = "sid" + i;
IOptimizationRunner runner =
new LocalOptimizationRunner(configuration, new MultiLayerNetworkTaskCreator());
StatsStorage ss = new InMemoryStatsStorage();
statsStorageForSession.put(sessionId, ss);
StatusListener sl = new ArbiterStatusListener(sessionId, ss);
runner.addListeners(sl);
// Asynchronous execution
new Thread(runner::execute).start();
}
Function<String, StatsStorage> statsStorageProvider = statsStorageForSession::get;
UIServer uIServer = UIServer.getInstance(true, statsStorageProvider);
String serverAddress = uIServer.getAddress();
for (String sessionId : statsStorageForSession.keySet()) {
log.info("Arbiter session can be attached at {}", sessionUrl(serverAddress, sessionId));
}
Thread.sleep(1000_000);
}
/**
* Get URL for arbiter session on given server address
* @param serverAddress server address, e.g.: http://localhost:9000
* @param sessionId session ID (will be URL-encoded)
* @return URL
* @throws UnsupportedEncodingException if the character encoding is not supported
*/
private static String sessionUrl(String serverAddress, String sessionId) throws UnsupportedEncodingException {
return String.format("%s/arbiter/%s", serverAddress, URLEncoder.encode(sessionId, "UTF-8"));
} }
private static class MnistDataSetProvider implements DataProvider { private static class MnistDataSetProvider implements DataProvider {

View File

@ -197,9 +197,9 @@ public class TestVertxUIMultiSession extends BaseDL4JTest {
} }
/** /**
* Get URL-encoded URL for training session on given server address * Get URL for training session on given server address
* @param serverAddress server address * @param serverAddress server address
* @param sessionId session ID * @param sessionId session ID (will be URL-encoded)
* @return URL * @return URL
* @throws UnsupportedEncodingException if the used encoding is not supported * @throws UnsupportedEncodingException if the used encoding is not supported
*/ */