2019-06-06 15:21:15 +03:00

176 lines
7.9 KiB
Java

/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.earlystopping;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.deeplearning4j.earlystopping.saver.InMemoryModelSaver;
import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator;
import org.deeplearning4j.earlystopping.termination.EpochTerminationCondition;
import org.deeplearning4j.earlystopping.termination.IterationTerminationCondition;
import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.nn.api.Model;
import org.nd4j.linalg.function.Supplier;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
/**
 * Early stopping configuration: specifies the various configuration options for running training with early stopping.<br>
 * Users need to specify the following:<br>
 * (a) EarlyStoppingModelSaver: How models will be saved (to disk, to memory, etc) (Default: in memory)<br>
 * (b) Termination conditions: at least one termination condition must be specified<br>
 * (i) Iteration termination conditions: calculated once for each minibatch. For example, maxTime or invalid (NaN/infinite) scores<br>
 * (ii) Epoch termination conditions: calculated once per epoch. For example, maxEpochs or no improvement for N epochs<br>
 * (c) Score calculator: what score should be calculated at every epoch? (For example: test set loss or test set accuracy)<br>
 * (d) How frequently (every N epochs) should scores be calculated? (Default: every epoch)<br>
 * <p>
 * Instances are normally created via {@link Builder}; call {@link #validate()} to check that the
 * configuration is complete.
 *
 * @param <T> Type of model. For example, {@link org.deeplearning4j.nn.multilayer.MultiLayerNetwork} or {@link org.deeplearning4j.nn.graph.ComputationGraph}
 * @author Alex Black
 */
@Data
@NoArgsConstructor
public class EarlyStoppingConfiguration<T extends Model> implements Serializable {

    /** How models are saved. */
    private EarlyStoppingModelSaver<T> modelSaver;
    /** Termination conditions checked at epoch granularity; may be null. */
    private List<EpochTerminationCondition> epochTerminationConditions;
    /** Termination conditions checked at iteration (minibatch) granularity; may be null. */
    private List<IterationTerminationCondition> iterationTerminationConditions;
    /** Whether the most recent model should be saved in addition to the best model. */
    private boolean saveLastModel;
    /** Evaluation frequency, in epochs. */
    private int evaluateEveryNEpochs;
    /** Directly configured score calculator; ignored by {@link #getScoreCalculator()} when a supplier is set. */
    private ScoreCalculator<T> scoreCalculator;
    /** Optional supplier of score calculators; takes precedence over {@link #scoreCalculator}. */
    private Supplier<ScoreCalculator> scoreCalculatorSupplier;

    /** Copies every setting from the given builder. */
    private EarlyStoppingConfiguration(Builder<T> builder) {
        this.modelSaver = builder.modelSaver;
        this.epochTerminationConditions = builder.epochTerminationConditions;
        this.iterationTerminationConditions = builder.iterationTerminationConditions;
        this.saveLastModel = builder.saveLastModel;
        this.evaluateEveryNEpochs = builder.evaluateEveryNEpochs;
        this.scoreCalculator = builder.scoreCalculator;
        this.scoreCalculatorSupplier = builder.scoreCalculatorSupplier;
    }

    /**
     * Returns the score calculator to use.
     *
     * @return the result of invoking the score calculator supplier (on every call) if a supplier is set,
     *         otherwise the directly configured score calculator (which may be null)
     */
    @SuppressWarnings("unchecked") // the supplier is declared with the raw ScoreCalculator type
    public ScoreCalculator<T> getScoreCalculator() {
        if (scoreCalculatorSupplier != null) {
            return scoreCalculatorSupplier.get();
        }
        return scoreCalculator;
    }

    /**
     * Checks that this configuration is complete.
     *
     * @throws DL4JInvalidConfigException if neither a score calculator nor a score calculator supplier is set,
     *                                    if no model saver is set, or if both termination condition lists
     *                                    are null or empty
     */
    public void validate() {
        if (scoreCalculator == null && scoreCalculatorSupplier == null) {
            throw new DL4JInvalidConfigException("A score calculator or score calculator supplier must be defined.");
        }

        if (modelSaver == null) {
            throw new DL4JInvalidConfigException("A model saver must be defined");
        }

        // At least one condition of either kind is required.
        boolean hasIterationCondition =
                iterationTerminationConditions != null && !iterationTerminationConditions.isEmpty();
        boolean hasEpochCondition =
                epochTerminationConditions != null && !epochTerminationConditions.isEmpty();
        if (!hasIterationCondition && !hasEpochCondition) {
            throw new DL4JInvalidConfigException("No termination conditions defined.");
        }
    }

    /**
     * Builder for {@link EarlyStoppingConfiguration}. Defaults: in-memory model saver, no termination
     * conditions, saveLastModel = false, evaluateEveryNEpochs = 1, no score calculator.
     *
     * @param <T> Type of model
     */
    public static class Builder<T extends Model> {

        private EarlyStoppingModelSaver<T> modelSaver = new InMemoryModelSaver<>();
        private List<EpochTerminationCondition> epochTerminationConditions = new ArrayList<>();
        private List<IterationTerminationCondition> iterationTerminationConditions = new ArrayList<>();
        private boolean saveLastModel = false;
        private int evaluateEveryNEpochs = 1;
        private ScoreCalculator<T> scoreCalculator;
        private Supplier<ScoreCalculator> scoreCalculatorSupplier;

        /** How should models be saved? (Default: in memory) */
        public Builder<T> modelSaver(EarlyStoppingModelSaver<T> modelSaver) {
            this.modelSaver = modelSaver;
            return this;
        }

        /**
         * Termination conditions to be evaluated every N epochs, with N set by evaluateEveryNEpochs option.
         * Replaces any previously configured epoch termination conditions.
         */
        public Builder<T> epochTerminationConditions(EpochTerminationCondition... terminationConditions) {
            // Populate a fresh list instead of clearing the current one in place: the current list may be
            // null, or may already be referenced by a configuration built earlier from this builder.
            List<EpochTerminationCondition> conditions = new ArrayList<>();
            Collections.addAll(conditions, terminationConditions);
            this.epochTerminationConditions = conditions;
            return this;
        }

        /**
         * Termination conditions to be evaluated every N epochs, with N set by evaluateEveryNEpochs option.
         * Replaces any previously configured epoch termination conditions. The list is copied, so the
         * builder never modifies (and is not affected by later changes to) the caller's list.
         *
         * @param terminationConditions conditions to use; null is stored as null
         */
        public Builder<T> epochTerminationConditions(List<EpochTerminationCondition> terminationConditions) {
            if (terminationConditions == null) {
                this.epochTerminationConditions = null;
            } else {
                this.epochTerminationConditions = new ArrayList<>(terminationConditions);
            }
            return this;
        }

        /**
         * Termination conditions to be evaluated every iteration (minibatch).
         * Replaces any previously configured iteration termination conditions.
         */
        public Builder<T> iterationTerminationConditions(IterationTerminationCondition... terminationConditions) {
            // Fresh list rather than clear(): keeps already-built configurations unaffected by builder reuse.
            List<IterationTerminationCondition> conditions = new ArrayList<>();
            Collections.addAll(conditions, terminationConditions);
            this.iterationTerminationConditions = conditions;
            return this;
        }

        /**
         * Save the last model? If true: save the most recent model at each epoch, in addition to the best
         * model (whenever the best model improves). If false: only save the best model. Default: false
         * Useful for example if you might want to continue training after a max-time termination condition
         * occurs.
         */
        public Builder<T> saveLastModel(boolean saveLastModel) {
            this.saveLastModel = saveLastModel;
            return this;
        }

        /** How frequently should evaluations be conducted (in terms of epochs)? Defaults to every (1) epochs. */
        public Builder<T> evaluateEveryNEpochs(int everyNEpochs) {
            this.evaluateEveryNEpochs = everyNEpochs;
            return this;
        }

        /**
         * Score calculator. Used to calculate a score (such as loss function on a test set), every N epochs,
         * where N is set by {@link #evaluateEveryNEpochs(int)}
         */
        @SuppressWarnings("unchecked") // raw parameter type retained so existing callers keep compiling
        public Builder<T> scoreCalculator(ScoreCalculator scoreCalculator) {
            this.scoreCalculator = scoreCalculator;
            return this;
        }

        /**
         * Score calculator supplier. When set, the built configuration obtains its score calculator from
         * this supplier. The score is calculated every N epochs, where N is set by
         * {@link #evaluateEveryNEpochs(int)}
         */
        public Builder<T> scoreCalculator(Supplier<ScoreCalculator> scoreCalculatorSupplier) {
            this.scoreCalculatorSupplier = scoreCalculatorSupplier;
            return this;
        }

        /** Create the early stopping configuration */
        public EarlyStoppingConfiguration<T> build() {
            return new EarlyStoppingConfiguration<>(this);
        }
    }
}