236 lines
9.4 KiB
Java
Raw Normal View History

2019-06-06 15:21:15 +03:00
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.optimize.listeners;
import lombok.Builder;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.optimize.api.BaseTrainingListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.Map;
/**
 * An iteration listener that provides details on parameters and gradients at each iteration during training.
 * Attempts to provide much of the same information as the UI histogram iteration listener, but in a text-based
 * format (for example, when learning on a system accessed via SSH etc).
 * i.e., is intended to aid network tuning and debugging<br>
 * This iteration listener is set up to calculate mean, min, max, and mean absolute value
 * of each type of parameter and gradient in the network at each iteration.<br>
 * <br>
 * Output is one delimited row per reported iteration: the iteration count, the score, and then, for every
 * parameter in the model's parameter table, the enabled statistics of the parameter followed by the same
 * statistics of its gradient. An optional header row names the columns.
 *
 * @author Alex Black
 * @deprecated StatsListener can be used instead, storing data using FileStatsStorage - UI is not required
 */
@Deprecated
public class ParamAndGradientIterationListener extends BaseTrainingListener {
    /** Maximum number of file write failures that are reported before further failures are suppressed. */
    private static final int MAX_WRITE_FAILURE_MESSAGES = 10;
    /** Delimiter used when none (null) is supplied. */
    private static final String DEFAULT_DELIMITER = "\t";
    private static final Logger logger = LoggerFactory.getLogger(ParamAndGradientIterationListener.class);

    private final int iterations;
    private final boolean printMean;
    private final boolean printHeader;
    private final boolean printMinMax;
    private final boolean printMeanAbsValue;
    private final File file;
    private final Path filePath;
    private final boolean outputToConsole;
    private final boolean outputToFile;
    private final boolean outputToLogger;
    private final String delimiter;

    /** Number of times {@link #iterationDone(Model, int, int)} has been called on this instance. */
    private long totalIterationCount = 0;
    /** Number of write failures reported so far; never exceeds {@link #MAX_WRITE_FAILURE_MESSAGES}. */
    private int writeFailureCount = 0;

    /** Default constructor for output to console only every iteration, tab delimited */
    public ParamAndGradientIterationListener() {
        this(1, true, true, true, true, true, false, false, null, DEFAULT_DELIMITER);
    }

    /**
     * Full constructor with all options.
     * Note also: ParamAndGradientIterationListener.builder() can be used instead of this constructor.
     *
     * @param iterations        calculate and report values every 'iterations' iterations. Values below 1
     *                          (for example a value left unset on the builder) are treated as 1
     * @param printHeader       Whether to output a header row (i.e., names for each column)
     * @param printMean         Calculate and display the mean of parameters and gradients
     * @param printMinMax       Calculate and display the min/max of the parameters and gradients
     * @param printMeanAbsValue Calculate and display the mean absolute value
     * @param outputToConsole   If true, display the values to the console (System.out)
     * @param outputToFile      If true, write the values to a file, one per line. Ignored (with a warning)
     *                          if {@code file} is null
     * @param outputToLogger    If true, log the values
     * @param file              File to write values to. May be null, not used if outputToFile == false
     * @param delimiter         delimiter (for example, "\t" or "," etc). If null, a tab is used
     */
    @Builder
    public ParamAndGradientIterationListener(int iterations, boolean printHeader, boolean printMean,
                    boolean printMinMax, boolean printMeanAbsValue, boolean outputToConsole, boolean outputToFile,
                    boolean outputToLogger, File file, String delimiter) {
        this.printHeader = printHeader;
        this.printMean = printMean;
        this.printMinMax = printMinMax;
        this.printMeanAbsValue = printMeanAbsValue;
        // A reporting interval of 0 would make the modulo in iterationDone throw ArithmeticException
        this.iterations = Math.max(1, iterations);
        this.file = file;
        this.filePath = (file != null) ? file.toPath() : null;
        this.outputToConsole = outputToConsole;
        if (outputToFile && file == null) {
            // Without a target path every write would fail with a NullPointerException during training
            logger.warn("outputToFile is enabled but no file was provided; file output is disabled");
        }
        this.outputToFile = outputToFile && file != null;
        this.outputToLogger = outputToLogger;
        // Appending a null String to a StringBuilder would emit the literal text "null" between columns
        this.delimiter = (delimiter != null) ? delimiter : DEFAULT_DELIMITER;
    }

    /**
     * Called after each training iteration. On the first call, emits the header row (if enabled); on every
     * 'iterations'-th call, emits one row of statistics for the model's parameters and gradients.
     *
     * @param model     model to read the parameter table, gradient and score from
     * @param iteration not used; rows are numbered by this listener's own call counter
     * @param epoch     not used
     */
    @Override
    public void iterationDone(Model model, int iteration, int epoch) {
        totalIterationCount++;

        if (totalIterationCount == 1 && printHeader) {
            // The header starts a fresh file; data rows are appended after it
            emit(buildHeader(model.paramTable()), StandardOpenOption.TRUNCATE_EXISTING);
        }

        if (totalIterationCount % iterations != 0) {
            return; //No op this iteration
        }

        emit(buildRow(model), StandardOpenOption.APPEND);
    }

    /** Builds the newline-terminated header row naming every column that {@link #buildRow(Model)} emits. */
    private String buildHeader(Map<String, INDArray> params) {
        StringBuilder sb = new StringBuilder();
        sb.append("n").append(delimiter).append("score");
        for (String name : params.keySet()) {
            appendColumnNames(sb, name, ""); //Parameters
            appendColumnNames(sb, name, "G"); //Gradients
        }
        sb.append("\n");
        return sb.toString();
    }

    /** Appends the names of the enabled statistic columns for one variable, each ending with {@code suffix}. */
    private void appendColumnNames(StringBuilder sb, String name, String suffix) {
        if (printMean) {
            sb.append(delimiter).append(name).append("_mean").append(suffix);
        }
        if (printMinMax) {
            sb.append(delimiter).append(name).append("_min").append(suffix);
            sb.append(delimiter).append(name).append("_max").append(suffix);
        }
        if (printMeanAbsValue) {
            sb.append(delimiter).append(name).append("_meanAbsValue").append(suffix);
        }
    }

    /** Builds the newline-terminated data row (call count, score, then per-parameter statistics). */
    private String buildRow(Model model) {
        Map<String, INDArray> params = model.paramTable();
        Map<String, INDArray> grads = model.gradient().gradientForVariable();

        StringBuilder sb = new StringBuilder();
        sb.append(totalIterationCount).append(delimiter).append(model.score());
        for (Map.Entry<String, INDArray> entry : params.entrySet()) {
            appendStats(sb, entry.getValue()); //Parameters
            // Map.get returns null if there is no gradient stored under this parameter's key
            appendStats(sb, grads.get(entry.getKey())); //Gradients
        }
        sb.append("\n");
        return sb.toString();
    }

    /**
     * Appends the enabled statistics (mean, min/max, mean absolute value) of the given array.
     * If {@code values} is null, NaN is written for each enabled column so that the row stays aligned
     * with the header.
     */
    private void appendStats(StringBuilder sb, INDArray values) {
        boolean missing = (values == null);
        if (printMean) {
            sb.append(delimiter).append(missing ? Double.NaN : values.meanNumber().doubleValue());
        }
        if (printMinMax) {
            sb.append(delimiter).append(missing ? Double.NaN : values.minNumber().doubleValue());
            sb.append(delimiter).append(missing ? Double.NaN : values.maxNumber().doubleValue());
        }
        if (printMeanAbsValue) {
            // abs is applied to a copy so the model's own array is not modified
            sb.append(delimiter)
                            .append(missing ? Double.NaN : Transforms.abs(values.dup()).meanNumber().doubleValue());
        }
    }

    /**
     * Sends one already newline-terminated line to every enabled output.
     *
     * @param line      text to output, including its trailing newline
     * @param writeMode how the file is opened when file output is enabled (truncate or append)
     */
    private void emit(String line, StandardOpenOption writeMode) {
        if (outputToLogger) {
            logger.info(line);
        }
        if (outputToConsole) {
            // print, not println: the line already ends with a newline
            System.out.print(line);
        }
        if (outputToFile) {
            writeToFile(line, writeMode);
        }
    }

    /**
     * Writes the line to the configured file, creating the file if necessary. I/O failures are logged
     * (at most {@link #MAX_WRITE_FAILURE_MESSAGES} times) rather than propagated, so that a failing
     * output file does not interrupt training.
     */
    private void writeToFile(String line, StandardOpenOption writeMode) {
        try {
            Files.write(filePath, line.getBytes(StandardCharsets.UTF_8), StandardOpenOption.CREATE, writeMode);
        } catch (IOException e) {
            if (writeFailureCount < MAX_WRITE_FAILURE_MESSAGES) {
                writeFailureCount++;
                // The path fills the placeholder; the trailing exception is logged with its stack trace
                logger.warn("Error writing to file: {}", filePath, e);
                if (writeFailureCount == MAX_WRITE_FAILURE_MESSAGES) {
                    logger.warn("Max file write messages displayed. No more failure messages will be printed");
                }
            }
        }
    }
}