* Add maven profile + base tests methods for integration tests Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Switch from system property to environment variable; seems more reliable in intellij Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Add nd4j-common-tests module, and common base test; cleanup Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Ensure all ND4J tests extend BaseND4JTest Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Test spam reduction, import fix Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Add test logging to nd4j-aeron Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix unintended change Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Reduce sprint test log spam Signed-off-by: AlexDBlack <blacka101@gmail.com>
* More test spam cleanup Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Significantly speed up TSNE tests Signed-off-by: AlexDBlack <blacka101@gmail.com>
* W2V iterator test unit/integration split Signed-off-by: AlexDBlack <blacka101@gmail.com>
* More NLP test speedups Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Avoid debug/verbose mode leaking between tests Signed-off-by: AlexDBlack <blacka101@gmail.com>
* test tweak Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Arbiter extends base DL4J test Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Arbiter test speedup Signed-off-by: AlexDBlack <blacka101@gmail.com>
* nlp-uima test speedup Signed-off-by: AlexDBlack <blacka101@gmail.com>
* More test speedups Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Fix ND4J base test Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Few small ND4J test speed improvements Signed-off-by: AlexDBlack <blacka101@gmail.com>
* DL4J tests speedup Signed-off-by: AlexDBlack <blacka101@gmail.com>
* More tweaks Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Even more test speedups Signed-off-by: AlexDBlack <blacka101@gmail.com>
* More tweaks Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Various test fixes Signed-off-by: Alex Black <blacka101@gmail.com>
* More test fixes Signed-off-by: Alex Black <blacka101@gmail.com>
* Add ability to specify number of threads for C++ ops in BaseDL4JTest and BaseND4JTest Signed-off-by: Alex Black <blacka101@gmail.com>
* nd4j-aeron test profile fix for CUDA Signed-off-by: Alex Black <blacka101@gmail.com>
167 lines
8.9 KiB
Java
167 lines
8.9 KiB
Java
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
|
|
|
|
package org.deeplearning4j.arbiter.multilayernetwork;
|
|
|
|
import org.deeplearning4j.BaseDL4JTest;
|
|
import org.deeplearning4j.arbiter.MultiLayerSpace;
|
|
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
|
|
import org.deeplearning4j.arbiter.layers.ConvolutionLayerSpace;
|
|
import org.deeplearning4j.arbiter.layers.DenseLayerSpace;
|
|
import org.deeplearning4j.arbiter.layers.OutputLayerSpace;
|
|
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
|
|
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
|
|
import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider;
|
|
import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition;
|
|
import org.deeplearning4j.arbiter.optimize.api.termination.MaxTimeCondition;
|
|
import org.deeplearning4j.arbiter.optimize.generator.RandomSearchGenerator;
|
|
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;
|
|
import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace;
|
|
import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace;
|
|
import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace;
|
|
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
|
|
import org.deeplearning4j.arbiter.optimize.runner.LocalOptimizationRunner;
|
|
import org.deeplearning4j.arbiter.saver.local.FileModelSaver;
|
|
import org.deeplearning4j.arbiter.scoring.impl.TestSetLossScoreFunction;
|
|
import org.deeplearning4j.arbiter.task.MultiLayerNetworkTaskCreator;
|
|
import org.deeplearning4j.arbiter.util.TestDataFactoryProviderMnist;
|
|
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
|
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
|
|
import org.deeplearning4j.earlystopping.saver.InMemoryModelSaver;
|
|
import org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator;
|
|
import org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition;
|
|
import org.deeplearning4j.earlystopping.termination.MaxScoreIterationTerminationCondition;
|
|
import org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition;
|
|
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
|
import org.nd4j.linalg.activations.Activation;
|
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
|
|
|
import java.io.File;
|
|
import java.util.HashMap;
|
|
import java.util.Map;
|
|
import java.util.concurrent.TimeUnit;
|
|
|
|
// import org.deeplearning4j.arbiter.optimize.ui.ArbiterUIServer;
|
|
// import org.deeplearning4j.arbiter.optimize.ui.listener.UIOptimizationRunnerStatusListener;
|
|
|
|
/** Not strictly a unit test. Rather: part example, part debugging on MNIST */
|
|
public class MNISTOptimizationTest extends BaseDL4JTest {
|
|
|
|
public static void main(String[] args) throws Exception {
|
|
EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
|
|
new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
|
|
.epochTerminationConditions(new MaxEpochsTerminationCondition(3))
|
|
.iterationTerminationConditions(
|
|
new MaxTimeIterationTerminationCondition(5, TimeUnit.MINUTES),
|
|
new MaxScoreIterationTerminationCondition(4.6) //Random score: -log_e(0.1) ~= 2.3
|
|
).scoreCalculator(new DataSetLossCalculator(new MnistDataSetIterator(64, 2000, false, false, true, 123), true)).modelSaver(new InMemoryModelSaver()).build();
|
|
|
|
//Define: network config (hyperparameter space)
|
|
MultiLayerSpace mls = new MultiLayerSpace.Builder()
|
|
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
|
.updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.2)))
|
|
.l2(new ContinuousParameterSpace(0.0001, 0.05))
|
|
.addLayer(
|
|
new ConvolutionLayerSpace.Builder().nIn(1)
|
|
.nOut(new IntegerParameterSpace(5, 30))
|
|
.kernelSize(new DiscreteParameterSpace<>(new int[] {3, 3},
|
|
new int[] {4, 4}, new int[] {5, 5}))
|
|
.stride(new DiscreteParameterSpace<>(new int[] {1, 1},
|
|
new int[] {2, 2}))
|
|
.activation(new DiscreteParameterSpace<>(Activation.RELU,
|
|
Activation.SOFTPLUS, Activation.LEAKYRELU))
|
|
.build(),
|
|
new IntegerParameterSpace(1, 2)) //1-2 identical layers
|
|
.addLayer(new DenseLayerSpace.Builder().nIn(4).nOut(new IntegerParameterSpace(2, 10))
|
|
.activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH))
|
|
.build(), new IntegerParameterSpace(0, 1)) //0 to 1 layers
|
|
.addLayer(new OutputLayerSpace.Builder().nOut(10).activation(Activation.SOFTMAX)
|
|
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
|
.earlyStoppingConfiguration(esConf).build();
|
|
Map<String, Object> commands = new HashMap<>();
|
|
commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, TestDataFactoryProviderMnist.class.getCanonicalName());
|
|
|
|
//Define configuration:
|
|
CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls, commands);
|
|
DataProvider dataProvider = new MnistDataSetProvider();
|
|
|
|
|
|
String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterMNISTSmall\\").getAbsolutePath();
|
|
|
|
File f = new File(modelSavePath);
|
|
if (f.exists())
|
|
f.delete();
|
|
f.mkdir();
|
|
if (!f.exists())
|
|
throw new RuntimeException();
|
|
|
|
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
|
|
.candidateGenerator(candidateGenerator)
|
|
.dataProvider(dataProvider)
|
|
.modelSaver(new FileModelSaver(modelSavePath)).scoreFunction(new TestSetLossScoreFunction(true))
|
|
.terminationConditions(new MaxTimeCondition(120, TimeUnit.MINUTES),
|
|
new MaxCandidatesCondition(100))
|
|
.build();
|
|
|
|
IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new MultiLayerNetworkTaskCreator());
|
|
|
|
// ArbiterUIServer server = ArbiterUIServer.getInstance();
|
|
// runner.addListeners(new UIOptimizationRunnerStatusListener(server));
|
|
|
|
runner.execute();
|
|
|
|
|
|
System.out.println("----- COMPLETE -----");
|
|
}
|
|
|
|
|
|
private static class MnistDataSetProvider implements DataProvider {
|
|
|
|
@Override
|
|
public DataSetIterator trainData(Map<String, Object> dataParameters) {
|
|
try {
|
|
if (dataParameters == null || dataParameters.isEmpty()) {
|
|
return new MnistDataSetIterator(64, 10000, false, true, true, 123);
|
|
}
|
|
if (dataParameters.containsKey("batchsize")) {
|
|
int b = (Integer) dataParameters.get("batchsize");
|
|
return new MnistDataSetIterator(b, 10000, false, true, true, 123);
|
|
}
|
|
return new MnistDataSetIterator(64, 10000, false, true, true, 123);
|
|
} catch (Exception e) {
|
|
throw new RuntimeException(e);
|
|
}
|
|
}
|
|
|
|
@Override
|
|
public DataSetIterator testData(Map<String, Object> dataParameters) {
|
|
return trainData(dataParameters);
|
|
}
|
|
|
|
@Override
|
|
public Class<?> getDataType() {
|
|
return DataSetIterator.class;
|
|
}
|
|
|
|
@Override
|
|
public String toString() {
|
|
return "MnistDataSetProvider()";
|
|
}
|
|
}
|
|
}
|