86 lines
4.3 KiB
Java
86 lines
4.3 KiB
Java
/*******************************************************************************
|
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
*
|
|
* This program and the accompanying materials are made available under the
|
|
* terms of the Apache License, Version 2.0 which is available at
|
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
* License for the specific language governing permissions and limitations
|
|
* under the License.
|
|
*
|
|
* SPDX-License-Identifier: Apache-2.0
|
|
******************************************************************************/
|
|
|
|
package org.deeplearning4j.parallelism;
|
|
|
|
import org.deeplearning4j.api.storage.StatsStorage;
|
|
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
|
|
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
|
|
import org.deeplearning4j.earlystopping.EarlyStoppingModelSaver;
|
|
import org.deeplearning4j.earlystopping.EarlyStoppingResult;
|
|
import org.deeplearning4j.earlystopping.saver.InMemoryModelSaver;
|
|
import org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator;
|
|
import org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition;
|
|
import org.deeplearning4j.earlystopping.trainer.IEarlyStoppingTrainer;
|
|
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
|
import org.deeplearning4j.nn.conf.layers.DenseLayer;
|
|
import org.deeplearning4j.nn.conf.layers.OutputLayer;
|
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
|
import org.deeplearning4j.nn.weights.WeightInit;
|
|
import org.deeplearning4j.ui.api.UIServer;
|
|
import org.deeplearning4j.ui.stats.StatsListener;
|
|
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
|
|
import org.junit.Ignore;
|
|
import org.junit.Test;
|
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
|
import org.nd4j.linalg.learning.config.Sgd;
|
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
|
|
|
import static org.junit.Assert.assertEquals;
|
|
|
|
|
|
public class TestParallelEarlyStoppingUI extends BaseDL4JTest {
|
|
|
|
@Test
|
|
@Ignore //To be run manually
|
|
public void testParallelStatsListenerCompatibility() throws Exception {
|
|
UIServer uiServer = UIServer.getInstance();
|
|
|
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
|
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
|
.updater(new Sgd()).weightInit(WeightInit.XAVIER).list()
|
|
.layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build())
|
|
.layer(1, new OutputLayer.Builder().nIn(3).nOut(3)
|
|
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
|
|
.build();
|
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
|
|
|
// it's important that the UI can report results from parallel training
|
|
// there's potential for StatsListener to fail if certain properties aren't set in the model
|
|
StatsStorage statsStorage = new InMemoryStatsStorage();
|
|
net.setListeners(new StatsListener(statsStorage));
|
|
uiServer.attach(statsStorage);
|
|
|
|
DataSetIterator irisIter = new IrisDataSetIterator(50, 500);
|
|
EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
|
|
EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
|
|
new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
|
|
.epochTerminationConditions(new MaxEpochsTerminationCondition(500))
|
|
.scoreCalculator(new DataSetLossCalculator(irisIter, true))
|
|
.evaluateEveryNEpochs(2).modelSaver(saver).build();
|
|
|
|
IEarlyStoppingTrainer<MultiLayerNetwork> trainer =
|
|
new EarlyStoppingParallelTrainer<>(esConf, net, irisIter, null, 3, 6, 2);
|
|
|
|
EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();
|
|
System.out.println(result);
|
|
|
|
assertEquals(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, result.getTerminationReason());
|
|
}
|
|
}
|