SameDiff: Listener changes and training api update (#99)
* example api Signed-off-by: Ryan Nett <rnett@skymind.io> * Lambda based evaluation Signed-off-by: Ryan Nett <rnett@skymind.io> * lambda test Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * partial fixes, use get-variable listener framework, example EvaluationListener Signed-off-by: Ryan Nett <rnett@skymind.io> * javadoc fix and newInstance implementations Signed-off-by: Ryan Nett <rnett@skymind.io> * fit and evaluate methods with validation data (for fit) and listeners Signed-off-by: Ryan Nett <rnett@skymind.io> * output method overloads + listener args Signed-off-by: Ryan Nett <rnett@skymind.io> * history and evaluation helpers Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * more fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * FitConfig and added getters and setters Signed-off-by: Ryan Nett <rnett@skymind.io> * javadocs Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes, javadoc, added activations to history, added latest activation listener Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes, start of tests Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes and updates Signed-off-by: Ryan Nett <rnett@skymind.io> * newInstance fixes, tests Signed-off-by: Ryan Nett <rnett@skymind.io> * test fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * javadocs, getters with SDVariable overrides, CustomEvaluation fix Signed-off-by: Ryan Nett <rnett@skymind.io> * more operation config classes (evaluation, output, exec/single batch output), fix custom eval tests Signed-off-by: Ryan Nett <rnett@skymind.io> * merge fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * fix Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes, most old fit/evaluate/output methods use the builders Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * numerous fixes/cleanup Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * javadoc Signed-off-by: Ryan Nett <rnett@skymind.io> * Polish round 1 Signed-off-by: AlexDBlack <blacka101@gmail.com> * Round 2 Signed-off-by: AlexDBlack <blacka101@gmail.com> * Formatting + round 3 Signed-off-by: AlexDBlack <blacka101@gmail.com> * Round 4 Signed-off-by: AlexDBlack <blacka101@gmail.com>master
parent
6ed03217b4
commit
11bddb3825
|
@ -21,12 +21,13 @@ import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
|
|||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.nd4j.evaluation.regression.RegressionEvaluation;
|
||||
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
||||
|
||||
/**
|
||||
* Score function for regression (including multi-label regression) for a MultiLayerNetwork or ComputationGraph
|
||||
* on a test set. Supports all regression metrics: {@link RegressionEvaluation.Metric}
|
||||
* on a test set. Supports all regression metrics: {@link Metric}
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
|
@ -35,13 +36,13 @@ import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
|||
@NoArgsConstructor(access = AccessLevel.PROTECTED) //For JSON
|
||||
public class RegressionScoreFunction extends BaseNetScoreFunction {
|
||||
|
||||
protected RegressionEvaluation.Metric metric;
|
||||
protected Metric metric;
|
||||
|
||||
public RegressionScoreFunction(@NonNull org.deeplearning4j.eval.RegressionEvaluation.Metric metric) {
|
||||
this(metric.toNd4j());
|
||||
}
|
||||
|
||||
public RegressionScoreFunction(@NonNull RegressionEvaluation.Metric metric) {
|
||||
public RegressionScoreFunction(@NonNull Metric metric) {
|
||||
this.metric = metric;
|
||||
}
|
||||
|
||||
|
|
|
@ -51,7 +51,7 @@ import org.junit.Test;
|
|||
import org.junit.rules.TemporaryFolder;
|
||||
import org.nd4j.evaluation.classification.Evaluation;
|
||||
import org.nd4j.evaluation.classification.ROCBinary;
|
||||
import org.nd4j.evaluation.regression.RegressionEvaluation;
|
||||
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -107,7 +107,7 @@ public class TestEarlyStopping extends BaseDL4JTest {
|
|||
min = false;
|
||||
break;
|
||||
case 3:
|
||||
sc = new RegressionScoreCalculator(RegressionEvaluation.Metric.MSE, irisIter);
|
||||
sc = new RegressionScoreCalculator(Metric.MSE, irisIter);
|
||||
min = true;
|
||||
break;
|
||||
case 4:
|
||||
|
@ -561,8 +561,8 @@ public class TestEarlyStopping extends BaseDL4JTest {
|
|||
@Test
|
||||
public void testRegressionScoreFunctionSimple() throws Exception {
|
||||
|
||||
for(RegressionEvaluation.Metric metric : new RegressionEvaluation.Metric[]{RegressionEvaluation.Metric.MSE,
|
||||
RegressionEvaluation.Metric.MAE}) {
|
||||
for(Metric metric : new Metric[]{Metric.MSE,
|
||||
Metric.MAE}) {
|
||||
log.info("Metric: " + metric);
|
||||
|
||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||
|
@ -604,8 +604,8 @@ public class TestEarlyStopping extends BaseDL4JTest {
|
|||
@Test
|
||||
public void testAEScoreFunctionSimple() throws Exception {
|
||||
|
||||
for(RegressionEvaluation.Metric metric : new RegressionEvaluation.Metric[]{RegressionEvaluation.Metric.MSE,
|
||||
RegressionEvaluation.Metric.MAE}) {
|
||||
for(Metric metric : new Metric[]{Metric.MSE,
|
||||
Metric.MAE}) {
|
||||
log.info("Metric: " + metric);
|
||||
|
||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||
|
@ -647,8 +647,8 @@ public class TestEarlyStopping extends BaseDL4JTest {
|
|||
@Test
|
||||
public void testVAEScoreFunctionSimple() throws Exception {
|
||||
|
||||
for(RegressionEvaluation.Metric metric : new RegressionEvaluation.Metric[]{RegressionEvaluation.Metric.MSE,
|
||||
RegressionEvaluation.Metric.MAE}) {
|
||||
for(Metric metric : new Metric[]{Metric.MSE,
|
||||
Metric.MAE}) {
|
||||
log.info("Metric: " + metric);
|
||||
|
||||
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||
|
|
|
@ -43,7 +43,7 @@ import org.deeplearning4j.nn.weights.WeightInit;
|
|||
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.evaluation.classification.Evaluation;
|
||||
import org.nd4j.evaluation.regression.RegressionEvaluation;
|
||||
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.dataset.DataSet;
|
||||
|
@ -289,8 +289,8 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest {
|
|||
@Test
|
||||
public void testRegressionScoreFunctionSimple() throws Exception {
|
||||
|
||||
for(RegressionEvaluation.Metric metric : new RegressionEvaluation.Metric[]{RegressionEvaluation.Metric.MSE,
|
||||
RegressionEvaluation.Metric.MAE}) {
|
||||
for(Metric metric : new Metric[]{Metric.MSE,
|
||||
Metric.MAE}) {
|
||||
log.info("Metric: " + metric);
|
||||
|
||||
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||
|
@ -335,8 +335,8 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest {
|
|||
public void testAEScoreFunctionSimple() throws Exception {
|
||||
DataType dt = Nd4j.defaultFloatingPointType();
|
||||
|
||||
for(RegressionEvaluation.Metric metric : new RegressionEvaluation.Metric[]{RegressionEvaluation.Metric.MSE,
|
||||
RegressionEvaluation.Metric.MAE}) {
|
||||
for(Metric metric : new Metric[]{Metric.MSE,
|
||||
Metric.MAE}) {
|
||||
log.info("Metric: " + metric);
|
||||
|
||||
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||
|
@ -380,8 +380,8 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest {
|
|||
@Test
|
||||
public void testVAEScoreFunctionSimple() throws Exception {
|
||||
|
||||
for(RegressionEvaluation.Metric metric : new RegressionEvaluation.Metric[]{RegressionEvaluation.Metric.MSE,
|
||||
RegressionEvaluation.Metric.MAE}) {
|
||||
for(Metric metric : new Metric[]{Metric.MSE,
|
||||
Metric.MAE}) {
|
||||
log.info("Metric: " + metric);
|
||||
|
||||
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
|
||||
|
|
|
@ -23,6 +23,7 @@ import org.deeplearning4j.nn.graph.ComputationGraph;
|
|||
import org.deeplearning4j.nn.layers.feedforward.autoencoder.AutoEncoder;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.nd4j.evaluation.regression.RegressionEvaluation;
|
||||
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
|
@ -30,16 +31,16 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
|||
/**
|
||||
* Score function for a MultiLayerNetwork or ComputationGraph with a single
|
||||
* {@link org.deeplearning4j.nn.conf.layers.AutoEncoder} layer.
|
||||
* Calculates the specified {@link RegressionEvaluation.Metric} on the layer's reconstructions.
|
||||
* Calculates the specified {@link Metric} on the layer's reconstructions.
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
public class AutoencoderScoreCalculator extends BaseScoreCalculator<Model> {
|
||||
|
||||
protected final RegressionEvaluation.Metric metric;
|
||||
protected final Metric metric;
|
||||
protected RegressionEvaluation evaluation;
|
||||
|
||||
public AutoencoderScoreCalculator(RegressionEvaluation.Metric metric, DataSetIterator iterator){
|
||||
public AutoencoderScoreCalculator(Metric metric, DataSetIterator iterator){
|
||||
super(iterator);
|
||||
this.metric = metric;
|
||||
}
|
||||
|
|
|
@ -19,19 +19,20 @@ package org.deeplearning4j.earlystopping.scorecalc;
|
|||
import org.deeplearning4j.earlystopping.scorecalc.base.BaseIEvaluationScoreCalculator;
|
||||
import org.deeplearning4j.nn.api.Model;
|
||||
import org.nd4j.evaluation.regression.RegressionEvaluation;
|
||||
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||
|
||||
/**
|
||||
* Calculate the regression score of the network (MultiLayerNetwork or ComputationGraph) on a test set, using the
|
||||
* specified regression metric - {@link RegressionEvaluation.Metric}
|
||||
* specified regression metric - {@link Metric}
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
public class RegressionScoreCalculator extends BaseIEvaluationScoreCalculator<Model, RegressionEvaluation> {
|
||||
|
||||
protected final RegressionEvaluation.Metric metric;
|
||||
protected final Metric metric;
|
||||
|
||||
public RegressionScoreCalculator(RegressionEvaluation.Metric metric, DataSetIterator iterator){
|
||||
public RegressionScoreCalculator(Metric metric, DataSetIterator iterator){
|
||||
super(iterator);
|
||||
this.metric = metric;
|
||||
}
|
||||
|
|
|
@ -23,6 +23,7 @@ import org.deeplearning4j.nn.graph.ComputationGraph;
|
|||
import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.nd4j.evaluation.regression.RegressionEvaluation;
|
||||
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
||||
|
@ -35,7 +36,7 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
|||
*/
|
||||
public class VAEReconErrorScoreCalculator extends BaseScoreCalculator<Model> {
|
||||
|
||||
protected final RegressionEvaluation.Metric metric;
|
||||
protected final Metric metric;
|
||||
protected RegressionEvaluation evaluation;
|
||||
|
||||
/**
|
||||
|
@ -44,7 +45,7 @@ public class VAEReconErrorScoreCalculator extends BaseScoreCalculator<Model> {
|
|||
* @param metric
|
||||
* @param iterator
|
||||
*/
|
||||
public VAEReconErrorScoreCalculator(RegressionEvaluation.Metric metric, DataSetIterator iterator) {
|
||||
public VAEReconErrorScoreCalculator(Metric metric, DataSetIterator iterator) {
|
||||
super(iterator);
|
||||
this.metric = metric;
|
||||
}
|
||||
|
|
|
@ -20,6 +20,22 @@ public class At {
|
|||
private int iteration;
|
||||
private int trainingThreadNum;
|
||||
private long javaThreadNum;
|
||||
private Operation operation;
|
||||
|
||||
/**
|
||||
* @return A new instance with everything set to 0, and operation set to INFERENCE
|
||||
*/
|
||||
public static At defaultAt(){
|
||||
return new At(0, 0, 0, 0, Operation.INFERENCE);
|
||||
}
|
||||
|
||||
/**
|
||||
* @param op Operation
|
||||
* @return A new instance with everything set to 0, except for the specified operation
|
||||
*/
|
||||
public static At defaultAt(@NonNull Operation op){
|
||||
return new At(0, 0, 0, 0, op);
|
||||
}
|
||||
|
||||
/**
|
||||
* @return The current training epoch
|
||||
|
@ -48,4 +64,26 @@ public class At {
|
|||
public long javaThreadNum(){
|
||||
return javaThreadNum;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return The current operation
|
||||
*/
|
||||
public Operation operation(){
|
||||
return operation;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return A copy of the current At instance
|
||||
*/
|
||||
public At copy(){
|
||||
return new At(epoch, iteration, trainingThreadNum, javaThreadNum, operation);
|
||||
}
|
||||
|
||||
/**
|
||||
* @param operation Operation to set in the new instance
|
||||
* @return A copy of the current instance, but with the specified operation
|
||||
*/
|
||||
public At copy(Operation operation){
|
||||
return new At(epoch, iteration, trainingThreadNum, javaThreadNum, operation);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,151 @@
|
|||
/*
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
package org.nd4j.autodiff.listeners;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.nd4j.autodiff.listeners.records.EvaluationRecord;
|
||||
import org.nd4j.autodiff.listeners.records.LossCurve;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
|
||||
import org.nd4j.evaluation.IEvaluation;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||
|
||||
/**
|
||||
* A base listener class that will preform the provided evaluations, and provide the results in epochEnd and validationDone
|
||||
*
|
||||
* Instead of overriding requiredVariables, epochStart, epochEnd, validationDone, and/or opExecution,
|
||||
* override otherRequiredVariables, epochStartEvaluations, epochEndEvaluations, validationDoneEvaluations, and/or opExecutionEvaluations
|
||||
*
|
||||
* <strong>If you want to use Evaluations in your listener, extend this class</strong>
|
||||
*/
|
||||
public abstract class BaseEvaluationListener extends BaseListener {
|
||||
|
||||
private Map<String, List<IEvaluation>> trainingEvaluations = new HashMap<>();
|
||||
private Map<String, List<IEvaluation>> validationEvaluations = new HashMap<>();
|
||||
|
||||
/**
|
||||
* Return the requested evaluations. New instances of these evaluations will be made each time they are used
|
||||
*/
|
||||
public abstract ListenerEvaluations evaluations();
|
||||
|
||||
@Override
|
||||
public final ListenerVariables requiredVariables(SameDiff sd) {
|
||||
return evaluations().requiredVariables().merge(otherRequiredVariables(sd));
|
||||
}
|
||||
|
||||
/**
|
||||
* Return any requested variables that are not part of the evaluations
|
||||
*/
|
||||
public ListenerVariables otherRequiredVariables(SameDiff sd){
|
||||
return ListenerVariables.empty();
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public final void epochStart(SameDiff sd, At at) {
|
||||
trainingEvaluations = new HashMap<>();
|
||||
for(Map.Entry<String, List<IEvaluation>> entry : evaluations().trainEvaluations().entrySet()){
|
||||
|
||||
List<IEvaluation> evals = new ArrayList<>();
|
||||
for(IEvaluation ie : entry.getValue())
|
||||
evals.add(ie.newInstance());
|
||||
|
||||
trainingEvaluations.put(entry.getKey(), evals);
|
||||
}
|
||||
validationEvaluations = new HashMap<>();
|
||||
for(Map.Entry<String, List<IEvaluation>> entry : evaluations().validationEvaluations().entrySet()){
|
||||
|
||||
List<IEvaluation> evals = new ArrayList<>();
|
||||
for(IEvaluation ie : entry.getValue())
|
||||
evals.add(ie.newInstance());
|
||||
|
||||
validationEvaluations.put(entry.getKey(), evals);
|
||||
}
|
||||
|
||||
epochStartEvaluations(sd, at);
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link Listener#epochStart(SameDiff, At)}
|
||||
*/
|
||||
public void epochStartEvaluations(SameDiff sd, At at){
|
||||
//No op
|
||||
}
|
||||
|
||||
@Override
|
||||
public final ListenerResponse epochEnd(SameDiff sd, At at, LossCurve lossCurve, long epochTimeMillis) {
|
||||
return epochEndEvaluations(sd, at, lossCurve, epochTimeMillis, new EvaluationRecord(trainingEvaluations));
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link Listener#epochEnd(SameDiff, At, LossCurve, long)}, also provided the requested evaluations
|
||||
*/
|
||||
public ListenerResponse epochEndEvaluations(SameDiff sd, At at, LossCurve lossCurve, long epochTimeMillis, EvaluationRecord evaluations) {
|
||||
//No op
|
||||
return ListenerResponse.CONTINUE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public final ListenerResponse validationDone(SameDiff sd, At at, long validationTimeMillis) {
|
||||
return validationDoneEvaluations(sd, at, validationTimeMillis, new EvaluationRecord(validationEvaluations));
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link Listener#validationDone(SameDiff, At, long)}, also provided the requested evaluations
|
||||
*/
|
||||
public ListenerResponse validationDoneEvaluations(SameDiff sd, At at, long validationTimeMillis, EvaluationRecord evaluations) {
|
||||
//No op
|
||||
return ListenerResponse.CONTINUE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public final void activationAvailable(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, String varName,
|
||||
INDArray activation) {
|
||||
if(at.operation() == Operation.TRAINING) {
|
||||
if (trainingEvaluations.containsKey(varName)) {
|
||||
INDArray labels = batch.getLabels(evaluations().trainEvaluationLabels().get(varName));
|
||||
INDArray mask = batch.getLabelsMaskArray(evaluations().trainEvaluationLabels().get(varName));
|
||||
|
||||
for (IEvaluation e : trainingEvaluations.get(varName))
|
||||
e.eval(labels, activation, mask);
|
||||
}
|
||||
} else if(at.operation() == Operation.TRAINING_VALIDATION) {
|
||||
if (validationEvaluations.containsKey(varName)) {
|
||||
INDArray labels = batch.getLabels(evaluations().validationEvaluationLabels().get(varName));
|
||||
INDArray mask = batch.getLabelsMaskArray(evaluations().validationEvaluationLabels().get(varName));
|
||||
|
||||
for (IEvaluation e : validationEvaluations.get(varName))
|
||||
e.eval(labels, activation, mask);
|
||||
}
|
||||
}
|
||||
|
||||
activationAvailableEvaluations(sd, at, batch, op, varName, activation);
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link Listener#activationAvailable(SameDiff, At, MultiDataSet, SameDiffOp, String, INDArray)}
|
||||
*/
|
||||
public void activationAvailableEvaluations(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, String varName,
|
||||
INDArray activation){
|
||||
//No op
|
||||
}
|
||||
|
||||
}
|
|
@ -1,6 +1,6 @@
|
|||
package org.nd4j.autodiff.listeners;
|
||||
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.listeners.records.LossCurve;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
|
||||
import org.nd4j.autodiff.samediff.internal.Variable;
|
||||
|
@ -11,18 +11,32 @@ import org.nd4j.linalg.dataset.api.MultiDataSet;
|
|||
* A base/abstract {@link Listener} with all methods implemented as no-op.
|
||||
* Extend this for custom listeners to selectively override only the required methods
|
||||
*
|
||||
* <strong>If you want to use evaluations in your listener, use {@link BaseEvaluationListener}</strong>
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
public abstract class BaseListener implements Listener {
|
||||
|
||||
|
||||
@Override
|
||||
public ListenerVariables requiredVariables(SameDiff sd){
|
||||
return ListenerVariables.empty();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void epochStart(SameDiff sd, At at) {
|
||||
//No op
|
||||
}
|
||||
|
||||
@Override
|
||||
public void epochEnd(SameDiff sd, At at) {
|
||||
public ListenerResponse epochEnd(SameDiff sd, At at, LossCurve lossCurve, long epochTimeMillis) {
|
||||
return ListenerResponse.CONTINUE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ListenerResponse validationDone(SameDiff sd, At at, long validationTimeMillis) {
|
||||
//No op
|
||||
return ListenerResponse.CONTINUE;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -36,12 +50,28 @@ public abstract class BaseListener implements Listener {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void preOpExecution(SameDiff sd, At at, boolean training, SameDiffOp op) {
|
||||
public void operationStart(SameDiff sd, Operation op) {
|
||||
//No op
|
||||
}
|
||||
|
||||
@Override
|
||||
public void opExecution(SameDiff sd, At at, boolean training, SameDiffOp op, INDArray[] outputs) {
|
||||
public void operationEnd(SameDiff sd, Operation op) {
|
||||
//No op
|
||||
}
|
||||
|
||||
@Override
|
||||
public void preOpExecution(SameDiff sd, At at, SameDiffOp op) {
|
||||
//No op
|
||||
}
|
||||
|
||||
@Override
|
||||
public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
|
||||
//No op
|
||||
}
|
||||
|
||||
@Override
|
||||
public void activationAvailable(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, String varName,
|
||||
INDArray activation) {
|
||||
//No op
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
package org.nd4j.autodiff.listeners;
|
||||
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.listeners.records.LossCurve;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
|
||||
import org.nd4j.autodiff.samediff.internal.Variable;
|
||||
|
@ -11,10 +11,29 @@ import org.nd4j.linalg.dataset.api.MultiDataSet;
|
|||
* A {@link SameDiff} listener interface that is called during every iteration of training or inference
|
||||
*
|
||||
* @author Alex Black
|
||||
* @see BaseListener BaseListener, for extending
|
||||
* @see BaseListener BaseListener, for extending only the required methods (all others are no-op)
|
||||
* @see BaseEvaluationListener BaseEvaluationListener, for extending if you want to use evaluations
|
||||
*/
|
||||
public interface Listener {
|
||||
|
||||
|
||||
/**
|
||||
* Required variables for this listener.
|
||||
* <p>
|
||||
* Used to ensure these variables end up in the minimum required subgraph calculated by {@link org.nd4j.autodiff.samediff.internal.InferenceSession}.
|
||||
* Otherwise, if the variables weren't required by a loss variable, they would not be calculated.
|
||||
* <p>
|
||||
* Any variables in here are guaranteed to have {@link Listener#activationAvailable(SameDiff, At, MultiDataSet, SameDiffOp, String, INDArray)}
|
||||
* called for them, regardless of whether they would normally be calculated or not.
|
||||
*/
|
||||
ListenerVariables requiredVariables(SameDiff sd);
|
||||
|
||||
/**
|
||||
* Returns whether this listener is active during the given operation. If this returns false for the given operation,
|
||||
* those listener methods will not be called.
|
||||
*/
|
||||
boolean isActive(Operation operation);
|
||||
|
||||
/**
|
||||
* Called at the start of every epoch, when fitting from an iterator
|
||||
*
|
||||
|
@ -26,10 +45,23 @@ public interface Listener {
|
|||
/**
|
||||
* Called at the end of every epoch, when fitting from an iterator
|
||||
*
|
||||
* @param sd The SameDiff instance
|
||||
* @param at Current iteration/epoch etc
|
||||
* @param sd The SameDiff instance
|
||||
* @param at Current iteration/epoch etc
|
||||
* @param lossCurve The losses so far
|
||||
* @param epochTimeMillis How long this epoch took
|
||||
* @return ListenerResponse.STOP to stop training, CONTINUE or null to continue
|
||||
*/
|
||||
void epochEnd(SameDiff sd, At at);
|
||||
ListenerResponse epochEnd(SameDiff sd, At at, LossCurve lossCurve, long epochTimeMillis);
|
||||
|
||||
/**
|
||||
* Called after the end of every epoch, once validation evaluation is done, when training
|
||||
*
|
||||
* @param sd The SameDiff instance
|
||||
* @param at Current iteration/epoch etc
|
||||
* @param validationTimeMillis How long validation took for this epoch
|
||||
* @return ListenerResponse.STOP to stop training, CONTINUE or null to continue
|
||||
*/
|
||||
ListenerResponse validationDone(SameDiff sd, At at, long validationTimeMillis);
|
||||
|
||||
/**
|
||||
* Called at the start of every iteration (minibatch), before any operations have been executed
|
||||
|
@ -45,31 +77,70 @@ public interface Listener {
|
|||
* @param sd The SameDiff instance
|
||||
* @param at Current iteration/epoch etc
|
||||
* @param dataSet The current dataset (minibatch) used for training
|
||||
* @param loss The loss value for the current minibatch
|
||||
* @param loss The loss value for the current minibatch. Will be null except for during training
|
||||
*/
|
||||
void iterationDone(SameDiff sd, At at, MultiDataSet dataSet, Loss loss);
|
||||
|
||||
/**
|
||||
* Called just before each operation is executed (native code called, etc) - after all inputs etc have been set
|
||||
* Called at the start of an operation, e.g. training or validation
|
||||
*
|
||||
* @param sd The SameDiff instance
|
||||
* @param at Current iteration/epoch etc
|
||||
* @param op Operation that has just been executed
|
||||
* @param sd The SameDiff instance
|
||||
* @param op The operation being started
|
||||
*/
|
||||
void preOpExecution(SameDiff sd, At at, boolean training, SameDiffOp op);
|
||||
void operationStart(SameDiff sd, Operation op);
|
||||
|
||||
/**
|
||||
* Called at the end of each operation execution
|
||||
* Called at the end of an operation, e.g. training or validation
|
||||
*
|
||||
* @param sd The SameDiff instance
|
||||
* @param op The operation being started
|
||||
*/
|
||||
void operationEnd(SameDiff sd, Operation op);
|
||||
|
||||
/**
|
||||
* Called just before each operation is executed (native code called, etc) - after all inputs etc have been set
|
||||
*
|
||||
* @param sd The SameDiff instance
|
||||
* @param at Current iteration/epoch etc
|
||||
* @param op Operation that has just been executed
|
||||
*/
|
||||
void preOpExecution(SameDiff sd, At at, SameDiffOp op);
|
||||
|
||||
/**
|
||||
* Called at the end of each operation execution<br>
|
||||
* <p>
|
||||
* Note: Outputs will most likely be freed later, use detach() if you need to save it.
|
||||
*
|
||||
* @param sd The SameDiff instance
|
||||
* @param at Current iteration/epoch etc
|
||||
* @param batch The batch's input data. May be null if not called with a batch
|
||||
* @param op Operation that has just been executed
|
||||
* @param outputs The output arrays for the just-executed operation
|
||||
*/
|
||||
void opExecution(SameDiff sd, At at, boolean training, SameDiffOp op, INDArray[] outputs);
|
||||
void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs);
|
||||
|
||||
/**
|
||||
* Called just before each parameter is to be updated - i.e., just before each parameter is modified
|
||||
* Called when any activation becomes available.
|
||||
* <p>
|
||||
* The activation will most likely be freed later, use detach() if you need to save it.<br>
|
||||
* <br>
|
||||
* Note that this method will be called when any activation becomes available, not just ones from {@link #requiredVariables(SameDiff)}<br>
|
||||
* It is guaranteed to be called for variables from requiredVariables().<br>
|
||||
* <br>
|
||||
* Note that the activations here overlap with {@link #opExecution(SameDiff, At, MultiDataSet, SameDiffOp, INDArray[])} -
|
||||
* both contain the same information/arrays
|
||||
*
|
||||
* @param sd The SameDiff instance
|
||||
* @param at Current iteration/epoch etc
|
||||
* @param batch The batch's input data. May be null if not called with a batch
|
||||
* @param op Operation that has just been executed
|
||||
* @param varName The name of the variable
|
||||
* @param activation The variable's activation
|
||||
*/
|
||||
void activationAvailable(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, String varName, INDArray activation);
|
||||
|
||||
/**
|
||||
* Called just before each parameter is to be updated - i.e., just before each parameter is modified.
|
||||
*
|
||||
* @param sd SameDiff instance
|
||||
* @param at The current iteration/epoch etc
|
||||
|
|
|
@ -0,0 +1,228 @@
|
|||
/*
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
package org.nd4j.autodiff.listeners;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.Getter;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import lombok.Setter;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.evaluation.IEvaluation;
|
||||
|
||||
/**
|
||||
* A class to allow Listeners to define what evaluations they need to run during training<br>
|
||||
* <p>
|
||||
* Usage example - does classification ({@link org.nd4j.evaluation.classification.Evaluation}) evaluation on
|
||||
* the training set (as training proceeds) and also Evaluation/ROCMultiClass evaluation on the test/validation set.
|
||||
* Assumes that the output predictions are called "softmax" and the first DataSet/MultiDataSet labels are those corresponding
|
||||
* to the "softmax" node
|
||||
* <pre>{@code
|
||||
* ListenerEvaluations.builder()
|
||||
* //trainEvaluations: on the training data (in-line, as training proceeds through the epoch)
|
||||
* .trainEvaluation("softmax", 0, new Evaluation(), new ROCMultiClass())
|
||||
* //validationEvaluation: on the test/validation data, at the end of each epoch
|
||||
* .validationEvaluation("softmax", 0, new Evaluation(), new ROCMultiClass())
|
||||
* .build();
|
||||
* }</pre>
|
||||
*/
|
||||
@Getter
|
||||
public class ListenerEvaluations {
|
||||
private Map<String, List<IEvaluation>> trainEvaluations;
|
||||
private Map<String, Integer> trainEvaluationLabels;
|
||||
|
||||
private Map<String, List<IEvaluation>> validationEvaluations;
|
||||
private Map<String, Integer> validationEvaluationLabels;
|
||||
|
||||
public ListenerEvaluations(Map<String, List<IEvaluation>> trainEvaluations,
|
||||
Map<String, Integer> trainEvaluationLabels, Map<String, List<IEvaluation>> validationEvaluations,
|
||||
Map<String, Integer> validationEvaluationLabels) {
|
||||
this.trainEvaluations = trainEvaluations;
|
||||
this.trainEvaluationLabels = trainEvaluationLabels;
|
||||
this.validationEvaluations = validationEvaluations;
|
||||
this.validationEvaluationLabels = validationEvaluationLabels;
|
||||
|
||||
Preconditions.checkArgument(trainEvaluations.keySet().equals(trainEvaluationLabels.keySet()),
|
||||
"Must specify a label index for each train evaluation. Expected: %s, got: %s",
|
||||
trainEvaluations.keySet(), trainEvaluationLabels.keySet());
|
||||
|
||||
Preconditions.checkArgument(validationEvaluations.keySet().equals(validationEvaluationLabels.keySet()),
|
||||
"Must specify a label index for each validation evaluation. Expected: %s, got: %s",
|
||||
validationEvaluations.keySet(), validationEvaluationLabels.keySet());
|
||||
}
|
||||
|
||||
private ListenerEvaluations() {
|
||||
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the requested training evaluations
|
||||
*/
|
||||
public Map<String, List<IEvaluation>> trainEvaluations() {
|
||||
return trainEvaluations;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the label indices for the requested training evaluations
|
||||
*/
|
||||
public Map<String, Integer> trainEvaluationLabels() {
|
||||
return trainEvaluationLabels;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the requested validation evaluations
|
||||
*/
|
||||
public Map<String, List<IEvaluation>> validationEvaluations() {
|
||||
return validationEvaluations;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the label indices for the requested validation evaluations
|
||||
*/
|
||||
public Map<String, Integer> validationEvaluationLabels() {
|
||||
return validationEvaluationLabels;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the required variables for these evaluations
|
||||
*/
|
||||
public ListenerVariables requiredVariables() {
|
||||
return new ListenerVariables(trainEvaluations.keySet(), validationEvaluations.keySet(),
|
||||
new HashSet<String>(), new HashSet<String>());
|
||||
}
|
||||
|
||||
/**
|
||||
* @return true if there are no requested evaluations
|
||||
*/
|
||||
public boolean isEmpty() {
|
||||
return trainEvaluations.isEmpty() && validationEvaluations.isEmpty();
|
||||
}
|
||||
|
||||
@NoArgsConstructor
|
||||
@Getter
|
||||
@Setter
|
||||
public static class Builder {
|
||||
private Map<String, List<IEvaluation>> trainEvaluations = new HashMap<>();
|
||||
private Map<String, Integer> trainEvaluationLabels = new HashMap<>();
|
||||
|
||||
private Map<String, List<IEvaluation>> validationEvaluations = new HashMap<>();
|
||||
private Map<String, Integer> validationEvaluationLabels = new HashMap<>();
|
||||
|
||||
private void addEvaluations(boolean validation, @NonNull Map<String, List<IEvaluation>> evaluationMap, @NonNull Map<String, Integer> labelMap,
|
||||
@NonNull String variableName, int labelIndex, @NonNull IEvaluation... evaluations) {
|
||||
if (evaluationMap.containsKey(variableName) && labelMap.get(variableName) != labelIndex) {
|
||||
String s;
|
||||
|
||||
if (validation) {
|
||||
s = "This ListenerEvaluations.Builder already has validation evaluations for ";
|
||||
} else {
|
||||
s = "This ListenerEvaluations.Builder already has train evaluations for ";
|
||||
}
|
||||
|
||||
throw new IllegalArgumentException(s + "variable " +
|
||||
variableName + " with label index " + labelIndex + ". You can't add " +
|
||||
" evaluations with a different label index. Got label index " + labelIndex);
|
||||
}
|
||||
|
||||
if (evaluationMap.containsKey(variableName)) {
|
||||
evaluationMap.get(variableName).addAll(Arrays.asList(evaluations));
|
||||
} else {
|
||||
evaluationMap.put(variableName, Arrays.asList(evaluations));
|
||||
labelMap.put(variableName, labelIndex);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Add requested training evaluations for a parm/variable
|
||||
*
|
||||
* @param variableName The variable to evaluate
|
||||
* @param labelIndex The index of the label to evaluate against
|
||||
* @param evaluations The evaluations to run
|
||||
*/
|
||||
public Builder trainEvaluation(@NonNull String variableName, int labelIndex, @NonNull IEvaluation... evaluations) {
|
||||
addEvaluations(false, this.trainEvaluations, this.trainEvaluationLabels, variableName,
|
||||
labelIndex, evaluations);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Add requested training evaluations for a parm/variable
|
||||
*
|
||||
* @param variable The variable to evaluate
|
||||
* @param labelIndex The index of the label to evaluate against
|
||||
* @param evaluations The evaluations to run
|
||||
*/
|
||||
public Builder trainEvaluation(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations) {
|
||||
return trainEvaluation(variable.getVarName(), labelIndex, evaluations);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add requested validation evaluations for a parm/variable
|
||||
*
|
||||
* @param variableName The variable to evaluate
|
||||
* @param labelIndex The index of the label to evaluate against
|
||||
* @param evaluations The evaluations to run
|
||||
*/
|
||||
public Builder validationEvaluation(@NonNull String variableName, int labelIndex, @NonNull IEvaluation... evaluations) {
|
||||
addEvaluations(true, this.validationEvaluations, this.validationEvaluationLabels, variableName,
|
||||
labelIndex, evaluations);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Add requested validation evaluations for a parm/variable
|
||||
*
|
||||
* @param variable The variable to evaluate
|
||||
* @param labelIndex The index of the label to evaluate against
|
||||
* @param evaluations The evaluations to run
|
||||
*/
|
||||
public Builder validationEvaluation(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations) {
|
||||
return validationEvaluation(variable.getVarName(), labelIndex, evaluations);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add requested evaluations for a parm/variable, for either training or validation
|
||||
*
|
||||
* @param validation Whether to add these evaluations as validation or training
|
||||
* @param variableName The variable to evaluate
|
||||
* @param labelIndex The index of the label to evaluate against
|
||||
* @param evaluations The evaluations to run
|
||||
*/
|
||||
public Builder addEvaluations(boolean validation, @NonNull String variableName, int labelIndex, @NonNull IEvaluation... evaluations) {
|
||||
if (validation) {
|
||||
return validationEvaluation(variableName, labelIndex, evaluations);
|
||||
} else {
|
||||
return trainEvaluation(variableName, labelIndex, evaluations);
|
||||
}
|
||||
}
|
||||
|
||||
public ListenerEvaluations build() {
|
||||
return new ListenerEvaluations(trainEvaluations, trainEvaluationLabels, validationEvaluations, validationEvaluationLabels);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,26 @@
|
|||
/*
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
package org.nd4j.autodiff.listeners;
|
||||
|
||||
/**
|
||||
* An enum representing feedback given by listeners during the training loop.<br>
|
||||
* CONTINUE: Continue training for more epochs, unless the specified (maximum) number of training epochs have already been completed.<br>
|
||||
* STOP: Terminate training at the current point, irrespective of how many total epochs were specified when calling fit.<br>
|
||||
*/
|
||||
public enum ListenerResponse {
|
||||
CONTINUE, STOP;
|
||||
}
|
|
@ -0,0 +1,236 @@
|
|||
/*
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
package org.nd4j.autodiff.listeners;
|
||||
|
||||
import com.google.common.collect.Sets;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.HashSet;
|
||||
import java.util.Set;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.Setter;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||
|
||||
/**
|
||||
* Specifies a Listener's required variables for each operation.
|
||||
* Used to ensure those variables end up in the minimum required subgraph calculated by {@link org.nd4j.autodiff.samediff.internal.InferenceSession}.
|
||||
* Otherwise, if the variables weren't required by a loss variable, they would not be calculated.
|
||||
* <p>
|
||||
* Any variables in here are guaranteed to have {@link Listener#activationAvailable(SameDiff, At, MultiDataSet, SameDiffOp, String, INDArray)} called for them.
|
||||
*/
|
||||
@RequiredArgsConstructor
|
||||
@Getter
|
||||
public class ListenerVariables {
|
||||
|
||||
public static ListenerVariables empty() {
|
||||
return ListenerVariables.builder().build();
|
||||
}
|
||||
|
||||
@NonNull
|
||||
private Set<String> trainingVariables;
|
||||
@NonNull
|
||||
private Set<String> validationVariables;
|
||||
@NonNull
|
||||
private Set<String> evaluationVariables;
|
||||
@NonNull
|
||||
private Set<String> inferenceVariables;
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get required training variables
|
||||
*/
|
||||
public Set<String> trainingVariables() {
|
||||
return trainingVariables;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get required validation variables
|
||||
*/
|
||||
public Set<String> validationVariables() {
|
||||
return validationVariables;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get required evaluation variables
|
||||
*/
|
||||
public Set<String> evaluationVariables() {
|
||||
return evaluationVariables;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get required inference variables
|
||||
*/
|
||||
public Set<String> inferenceVariables() {
|
||||
return inferenceVariables;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get required variables for specified op
|
||||
*/
|
||||
public Set<String> requiredVariables(Operation op) {
|
||||
switch (op) {
|
||||
case TRAINING:
|
||||
return trainingVariables;
|
||||
case TRAINING_VALIDATION:
|
||||
return validationVariables;
|
||||
case INFERENCE:
|
||||
return inferenceVariables;
|
||||
case EVALUATION:
|
||||
return evaluationVariables;
|
||||
}
|
||||
throw new IllegalArgumentException("Unknown operation " + op);
|
||||
}
|
||||
|
||||
private ListenerVariables() {
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Return a new ListenerVariables that contains the variables of this ListenerVariables and of other
|
||||
*/
|
||||
public ListenerVariables merge(ListenerVariables other) {
|
||||
return new ListenerVariables(
|
||||
Sets.newHashSet(Sets.union(trainingVariables, other.trainingVariables)),
|
||||
Sets.newHashSet(Sets.union(validationVariables, other.validationVariables)),
|
||||
Sets.newHashSet(Sets.union(evaluationVariables, other.evaluationVariables)),
|
||||
Sets.newHashSet(Sets.union(inferenceVariables, other.inferenceVariables)));
|
||||
}
|
||||
|
||||
@NoArgsConstructor
|
||||
@Getter
|
||||
@Setter
|
||||
public static class Builder {
|
||||
@NonNull
|
||||
private Set<String> trainingVariables = new HashSet<>();
|
||||
@NonNull
|
||||
private Set<String> validationVariables = new HashSet<>();
|
||||
@NonNull
|
||||
private Set<String> evaluationVariables = new HashSet<>();
|
||||
@NonNull
|
||||
private Set<String> inferenceVariables = new HashSet<>();
|
||||
|
||||
/**
|
||||
* Add required variables for the specified op
|
||||
*
|
||||
* @param op The op to require the variable for
|
||||
*/
|
||||
public Builder requireVariables(@NonNull Operation op, @NonNull String... variables) {
|
||||
switch (op) {
|
||||
case TRAINING:
|
||||
trainingVariables.addAll(Arrays.asList(variables));
|
||||
break;
|
||||
case TRAINING_VALIDATION:
|
||||
validationVariables.addAll(Arrays.asList(variables));
|
||||
break;
|
||||
case INFERENCE:
|
||||
inferenceVariables.addAll(Arrays.asList(variables));
|
||||
break;
|
||||
case EVALUATION:
|
||||
evaluationVariables.addAll(Arrays.asList(variables));
|
||||
break;
|
||||
}
|
||||
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Add required variables for the specified op
|
||||
*
|
||||
* @param op The op to require the variable for
|
||||
*/
|
||||
public Builder requireVariables(@NonNull Operation op, @NonNull SDVariable... variables) {
|
||||
String[] names = new String[variables.length];
|
||||
|
||||
for (int i = 0; i < variables.length; i++)
|
||||
names[i] = variables[i].getVarName();
|
||||
|
||||
return requireVariables(op, names);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add required variables for training
|
||||
*/
|
||||
public Builder trainingVariables(@NonNull String... variables) {
|
||||
return requireVariables(Operation.TRAINING, variables);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add required variables for training
|
||||
*/
|
||||
public Builder trainingVariables(@NonNull SDVariable... variables) {
|
||||
return requireVariables(Operation.TRAINING, variables);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add required variables for validation
|
||||
*/
|
||||
public Builder validationVariables(@NonNull String... variables) {
|
||||
return requireVariables(Operation.TRAINING_VALIDATION, variables);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add required variables for validation
|
||||
*/
|
||||
public Builder validationVariables(@NonNull SDVariable... variables) {
|
||||
return requireVariables(Operation.TRAINING_VALIDATION, variables);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add required variables for inference
|
||||
*/
|
||||
public Builder inferenceVariables(@NonNull String... variables) {
|
||||
return requireVariables(Operation.INFERENCE, variables);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add required variables for inference
|
||||
*/
|
||||
public Builder inferenceVariables(@NonNull SDVariable... variables) {
|
||||
return requireVariables(Operation.INFERENCE, variables);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add required variables for evaluation
|
||||
*/
|
||||
public Builder evaluationVariables(@NonNull String... variables) {
|
||||
return requireVariables(Operation.EVALUATION, variables);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add required variables for evaluation
|
||||
*/
|
||||
public Builder evaluationVariables(@NonNull SDVariable... variables) {
|
||||
return requireVariables(Operation.EVALUATION, variables);
|
||||
}
|
||||
|
||||
public ListenerVariables build() {
|
||||
return new ListenerVariables(trainingVariables, validationVariables, evaluationVariables, inferenceVariables);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -1,5 +1,8 @@
|
|||
package org.nd4j.autodiff.listeners;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -7,7 +10,7 @@ import org.nd4j.base.Preconditions;
|
|||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Loss class - represents the loss (score) for the network. Provides a breakdown of all the loss components
|
||||
* Loss class - represents the loss (score) for the network, for one iteration. Provides a breakdown of all the loss components
|
||||
*
|
||||
* @author Alex Black
|
||||
*/
|
||||
|
@ -70,4 +73,96 @@ public class Loss {
|
|||
}
|
||||
return sum;
|
||||
}
|
||||
|
||||
public Loss copy() {
|
||||
return new Loss(lossNames, losses);
|
||||
}
|
||||
|
||||
public static Loss sum(List<Loss> losses) {
|
||||
|
||||
if (losses.size() == 0)
|
||||
return new Loss(Collections.<String>emptyList(), new double[0]);
|
||||
|
||||
double[] lossValues = new double[losses.get(0).losses.length];
|
||||
List<String> lossNames = new ArrayList<>(losses.get(0).lossNames);
|
||||
|
||||
for (int i = 0; i < losses.size(); i++) {
|
||||
Loss l = losses.get(i);
|
||||
Preconditions.checkState(l.losses.length == lossValues.length,
|
||||
"Loss %s has %s losses, the others before it had %s.", i, l.losses.length, lossValues.length);
|
||||
|
||||
Preconditions.checkState(l.lossNames.equals(lossNames),
|
||||
"Loss %s has different loss names from the others before it. Expected %s, got %s.",
|
||||
i, lossNames, l.lossNames);
|
||||
|
||||
for (int j = 0; j < lossValues.length; j++)
|
||||
lossValues[j] += l.losses[j];
|
||||
|
||||
}
|
||||
|
||||
return new Loss(lossNames, lossValues);
|
||||
}
|
||||
|
||||
public static Loss average(List<Loss> losses) {
|
||||
Loss sum = sum(losses);
|
||||
|
||||
for (int i = 0; i < sum.losses.length; i++) {
|
||||
sum.losses[i] /= losses.size();
|
||||
}
|
||||
|
||||
return sum;
|
||||
}
|
||||
|
||||
public static Loss add(Loss a, Loss b) {
|
||||
Preconditions.checkState(a.lossNames.equals(b.lossNames),
|
||||
"Loss names differ. First loss has names %s, second has names %s.",
|
||||
a.lossNames, b.lossNames);
|
||||
|
||||
double[] lossValues = new double[a.losses.length];
|
||||
for (int i = 0; i < lossValues.length; i++)
|
||||
lossValues[i] = a.losses[i] + b.losses[i];
|
||||
|
||||
return new Loss(a.lossNames, lossValues);
|
||||
}
|
||||
|
||||
public static Loss sub(Loss a, Loss b) {
|
||||
Preconditions.checkState(a.lossNames.equals(b.lossNames),
|
||||
"Loss names differ. First loss has names %s, second has names %s.",
|
||||
a.lossNames, b.lossNames);
|
||||
|
||||
double[] lossValues = new double[a.losses.length];
|
||||
for (int i = 0; i < lossValues.length; i++)
|
||||
lossValues[i] = a.losses[i] - b.losses[i];
|
||||
|
||||
return new Loss(a.lossNames, lossValues);
|
||||
}
|
||||
|
||||
public static Loss div(Loss a, Number b) {
|
||||
double[] lossValues = new double[a.losses.length];
|
||||
for (int i = 0; i < lossValues.length; i++)
|
||||
lossValues[i] = a.losses[i] / b.doubleValue();
|
||||
|
||||
return new Loss(a.lossNames, lossValues);
|
||||
}
|
||||
|
||||
public Loss add(Loss other) {
|
||||
return add(this, other);
|
||||
}
|
||||
|
||||
public Loss sub(Loss other) {
|
||||
return sub(this, other);
|
||||
}
|
||||
|
||||
public Loss plus(Loss other) {
|
||||
return add(this, other);
|
||||
}
|
||||
|
||||
public Loss minus(Loss other) {
|
||||
return sub(this, other);
|
||||
}
|
||||
|
||||
public Loss div(Number other) {
|
||||
return div(this, other);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
/*
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
package org.nd4j.autodiff.listeners;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
|
||||
/**
|
||||
* An enum representing the operation being done on a SameDiff graph.<br>
|
||||
* <p>
|
||||
* TRAINING: {@link SameDiff#fit()} methods training step (everything except validation)<br>
|
||||
* TRAINING_VALIDATION: the validation step during {@link SameDiff#fit()} methods - i.e., test/validation set evaluation,<br>
|
||||
* INFERENCE: {@link SameDiff#output()}, {@link SameDiff#batchOutput()} and {@link SameDiff#exec(Map, String...)} ()} methods,
|
||||
* including the single batch and placeholder ones. Also {@link SDVariable#eval()}<br>
|
||||
* EVALUATION: {@link SameDiff#evaluate()} methods<br>
|
||||
*/
|
||||
public enum Operation {
|
||||
/**
|
||||
* The training operation: {@link SameDiff#fit()} methods training step (everything except validation).
|
||||
*/
|
||||
TRAINING,
|
||||
/**
|
||||
* The training validation operation: the validation step during {@link SameDiff#fit()} methods.
|
||||
*/
|
||||
TRAINING_VALIDATION,
|
||||
/**
|
||||
* Inference operations: {@link SameDiff#output()}, {@link SameDiff#batchOutput()} and {@link SameDiff#exec(Map, String...)} ()} methods,
|
||||
* as well as {@link SameDiff#execBackwards(Map, Operation, String...)} methods.
|
||||
*/
|
||||
INFERENCE,
|
||||
/**
|
||||
* Evaluation operations: {@link SameDiff#evaluate()} methods.
|
||||
*/
|
||||
EVALUATION;
|
||||
|
||||
public boolean isTrainingPhase() {
|
||||
return this == TRAINING || this == TRAINING_VALIDATION;
|
||||
}
|
||||
|
||||
public boolean isValidation() {
|
||||
return this == TRAINING_VALIDATION;
|
||||
}
|
||||
|
||||
}
|
|
@ -7,13 +7,15 @@ import lombok.extern.slf4j.Slf4j;
|
|||
import org.apache.commons.io.IOUtils;
|
||||
import org.nd4j.autodiff.listeners.At;
|
||||
import org.nd4j.autodiff.listeners.BaseListener;
|
||||
import org.nd4j.autodiff.listeners.ListenerResponse;
|
||||
import org.nd4j.autodiff.listeners.Loss;
|
||||
import org.nd4j.autodiff.listeners.records.LossCurve;
|
||||
import org.nd4j.autodiff.listeners.Operation;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||
|
||||
import java.io.*;
|
||||
import java.nio.charset.Charset;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.text.SimpleDateFormat;
|
||||
import java.util.*;
|
||||
|
@ -148,12 +150,18 @@ public class CheckpointListener extends BaseListener implements Serializable {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void epochEnd(SameDiff sameDiff, At at) {
|
||||
public ListenerResponse epochEnd(SameDiff sameDiff, At at, LossCurve lossCurve, long epochTimeMillis) {
|
||||
if(saveEveryNEpochs != null && (at.epoch()+1) % saveEveryNEpochs == 0){
|
||||
//Save:
|
||||
saveCheckpoint(sameDiff, at);
|
||||
}
|
||||
//General saving conditions: don't need to check here - will check in iterationDone
|
||||
return ListenerResponse.CONTINUE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isActive(Operation operation) {
|
||||
return operation == Operation.TRAINING;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -3,6 +3,7 @@ package org.nd4j.autodiff.listeners.debugging;
|
|||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.listeners.At;
|
||||
import org.nd4j.autodiff.listeners.BaseListener;
|
||||
import org.nd4j.autodiff.listeners.Operation;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
|
@ -70,7 +71,12 @@ public class ExecDebuggingListener extends BaseListener {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void preOpExecution(SameDiff sd, At at, boolean training, SameDiffOp op) {
|
||||
public boolean isActive(Operation operation) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void preOpExecution(SameDiff sd, At at, SameDiffOp op) {
|
||||
if(lastIter != at.iteration()){
|
||||
lastIter = at.iteration();
|
||||
stepThisIter = 0;
|
||||
|
|
|
@ -0,0 +1,121 @@
|
|||
/*
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
package org.nd4j.autodiff.listeners.impl;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import org.nd4j.autodiff.listeners.At;
|
||||
import org.nd4j.autodiff.listeners.BaseEvaluationListener;
|
||||
import org.nd4j.autodiff.listeners.records.EvaluationRecord;
|
||||
import org.nd4j.autodiff.listeners.records.History;
|
||||
import org.nd4j.autodiff.listeners.ListenerEvaluations;
|
||||
import org.nd4j.autodiff.listeners.ListenerResponse;
|
||||
import org.nd4j.autodiff.listeners.records.LossCurve;
|
||||
import org.nd4j.autodiff.listeners.Operation;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.TrainingConfig;
|
||||
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||
|
||||
/**
|
||||
* HistoryListener is mainly used internally to collect information such as the loss curve and evaluations,
|
||||
* which will be reported later in a {@link History} instance
|
||||
*/
|
||||
public class HistoryListener extends BaseEvaluationListener {
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
private ListenerEvaluations evaluations;
|
||||
|
||||
private List<EvaluationRecord> trainingHistory = new ArrayList<>();
|
||||
private List<EvaluationRecord> validationHistory = new ArrayList<>();
|
||||
private LossCurve loss = null;
|
||||
|
||||
private long startTime;
|
||||
private long endTime;
|
||||
|
||||
private List<Long> validationTimes = new ArrayList<>();
|
||||
private long validationStartTime;
|
||||
|
||||
|
||||
public HistoryListener(TrainingConfig tc) {
|
||||
this.evaluations = new ListenerEvaluations(tc.getTrainEvaluations(), tc.getTrainEvaluationLabels(),
|
||||
tc.getValidationEvaluations(), tc.getValidationEvaluationLabels());
|
||||
}
|
||||
|
||||
public HistoryListener(ListenerEvaluations evaluations) {
|
||||
this.evaluations = evaluations;
|
||||
}
|
||||
|
||||
public HistoryListener newInstance() {
|
||||
return new HistoryListener(evaluations);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ListenerEvaluations evaluations() {
|
||||
return evaluations;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isActive(Operation operation) {
|
||||
return operation.isTrainingPhase();
|
||||
}
|
||||
|
||||
@Override
|
||||
public ListenerResponse epochEndEvaluations(SameDiff sd, At at, LossCurve lossCurve, long epochTimeMillis, EvaluationRecord evaluations) {
|
||||
trainingHistory.add(evaluations);
|
||||
loss = lossCurve;
|
||||
|
||||
return ListenerResponse.CONTINUE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ListenerResponse validationDoneEvaluations(SameDiff sd, At at, long validationTimeMillis, EvaluationRecord evaluations) {
|
||||
validationHistory.add(evaluations);
|
||||
return ListenerResponse.CONTINUE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void operationStart(SameDiff sd, Operation op) {
|
||||
if (op == Operation.TRAINING) {
|
||||
startTime = System.currentTimeMillis();
|
||||
} else if (op == Operation.TRAINING_VALIDATION) {
|
||||
validationStartTime = System.currentTimeMillis();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void operationEnd(SameDiff sd, Operation op) {
|
||||
if (op == Operation.TRAINING) {
|
||||
endTime = System.currentTimeMillis();
|
||||
} else if (op == Operation.TRAINING_VALIDATION) {
|
||||
validationTimes.add(System.currentTimeMillis() - validationStartTime);
|
||||
}
|
||||
}
|
||||
|
||||
public History getReport() {
|
||||
return new History(trainingHistory, validationHistory, loss, endTime - startTime, validationTimes);
|
||||
}
|
||||
|
||||
}
|
|
@ -3,7 +3,10 @@ package org.nd4j.autodiff.listeners.impl;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.nd4j.autodiff.listeners.At;
|
||||
import org.nd4j.autodiff.listeners.BaseListener;
|
||||
import org.nd4j.autodiff.listeners.ListenerResponse;
|
||||
import org.nd4j.autodiff.listeners.Loss;
|
||||
import org.nd4j.autodiff.listeners.records.LossCurve;
|
||||
import org.nd4j.autodiff.listeners.Operation;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||
|
@ -32,7 +35,6 @@ public class ScoreListener extends BaseListener {
|
|||
private final boolean reportEpochs;
|
||||
private final boolean reportIterPerformance;
|
||||
|
||||
private long epochStart;
|
||||
private long epochExampleCount;
|
||||
private int epochBatchCount;
|
||||
private long etlTotalTimeEpoch;
|
||||
|
@ -72,10 +74,14 @@ public class ScoreListener extends BaseListener {
|
|||
}
|
||||
|
||||
|
||||
@Override
|
||||
public boolean isActive(Operation operation) {
|
||||
return operation == Operation.TRAINING;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void epochStart(SameDiff sd, At at) {
|
||||
if (reportEpochs) {
|
||||
epochStart = System.currentTimeMillis();
|
||||
epochExampleCount = 0;
|
||||
epochBatchCount = 0;
|
||||
etlTotalTimeEpoch = 0;
|
||||
|
@ -85,17 +91,18 @@ public class ScoreListener extends BaseListener {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void epochEnd(SameDiff sd, At at) {
|
||||
public ListenerResponse epochEnd(SameDiff sd, At at, LossCurve lossCurve, long epochTimeMillis) {
|
||||
if (reportEpochs) {
|
||||
long epochDuration = System.currentTimeMillis() - epochStart;
|
||||
double batchesPerSec = epochBatchCount / (epochDuration / 1000.0);
|
||||
double examplesPerSec = epochExampleCount / (epochDuration / 1000.0);
|
||||
double pcEtl = 100.0 * etlTotalTimeEpoch / (double) epochDuration;
|
||||
double batchesPerSec = epochBatchCount / (epochTimeMillis / 1000.0);
|
||||
double examplesPerSec = epochExampleCount / (epochTimeMillis / 1000.0);
|
||||
double pcEtl = 100.0 * etlTotalTimeEpoch / (double) epochTimeMillis;
|
||||
String etl = formatDurationMs(etlTotalTimeEpoch) + " ETL time" + (etlTotalTimeEpoch > 0 ? "(" + format2dp(pcEtl) + " %)" : "");
|
||||
log.info("Epoch {} complete on iteration {} - {} batches ({} examples) in {} - {} batches/sec, {} examples/sec, {}",
|
||||
at.epoch(), at.iteration(), epochBatchCount, epochExampleCount, formatDurationMs(epochDuration),
|
||||
at.epoch(), at.iteration(), epochBatchCount, epochExampleCount, formatDurationMs(epochTimeMillis),
|
||||
format2dp(batchesPerSec), format2dp(examplesPerSec), etl);
|
||||
}
|
||||
|
||||
return ListenerResponse.CONTINUE;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -4,7 +4,10 @@ import com.google.flatbuffers.Table;
|
|||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.listeners.At;
|
||||
import org.nd4j.autodiff.listeners.BaseListener;
|
||||
import org.nd4j.autodiff.listeners.ListenerResponse;
|
||||
import org.nd4j.autodiff.listeners.Loss;
|
||||
import org.nd4j.autodiff.listeners.records.LossCurve;
|
||||
import org.nd4j.autodiff.listeners.Operation;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
|
||||
|
@ -279,13 +282,18 @@ public class UIListener extends BaseListener {
|
|||
writer.writeFinishStaticMarker();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isActive(Operation operation) {
|
||||
return operation == Operation.TRAINING;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void epochStart(SameDiff sd, At at) {
|
||||
epochTrainEval = null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void epochEnd(SameDiff sd, At at) {
|
||||
public ListenerResponse epochEnd(SameDiff sd, At at, LossCurve lossCurve, long epochTimeMillis) {
|
||||
|
||||
//If any training evaluation, report it here:
|
||||
if(epochTrainEval != null){
|
||||
|
@ -315,6 +323,7 @@ public class UIListener extends BaseListener {
|
|||
}
|
||||
|
||||
epochTrainEval = null;
|
||||
return ListenerResponse.CONTINUE;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -401,13 +410,13 @@ public class UIListener extends BaseListener {
|
|||
|
||||
|
||||
@Override
|
||||
public void opExecution(SameDiff sd, At at, boolean training, SameDiffOp op, INDArray[] outputs) {
|
||||
public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
|
||||
|
||||
|
||||
//Do training set evaluation, if required
|
||||
//Note we'll do it in opExecution not iterationDone because we can't be sure arrays will be stil be around in the future
|
||||
//i.e., we'll eventually add workspaces and clear activation arrays once they have been consumed
|
||||
if(training && trainEvalMetrics != null && trainEvalMetrics.size() > 0){
|
||||
if(at.operation() == Operation.TRAINING && trainEvalMetrics != null && trainEvalMetrics.size() > 0){
|
||||
long time = System.currentTimeMillis();
|
||||
|
||||
//First: check if this op is relevant at all to evaluation...
|
||||
|
|
|
@ -0,0 +1,240 @@
|
|||
/*
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
package org.nd4j.autodiff.listeners.records;
|
||||
|
||||
import com.google.common.base.Predicates;
|
||||
import com.google.common.collect.Collections2;
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import com.google.common.collect.Lists;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import lombok.Getter;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.evaluation.IEvaluation;
|
||||
import org.nd4j.evaluation.IMetric;
|
||||
|
||||
/**
|
||||
* A helper class to hold evaluations and provide methods to easily query them
|
||||
*/
|
||||
@Getter
|
||||
public class EvaluationRecord {
|
||||
|
||||
private ImmutableMap<String, List<IEvaluation>> evaluations;
|
||||
private Map<Class<? extends IEvaluation>, IEvaluation> classEvaluations = new HashMap<>();
|
||||
private boolean isEmpty = true;
|
||||
|
||||
public EvaluationRecord(Map<String, List<IEvaluation>> evaluations) {
|
||||
this.evaluations = ImmutableMap.copyOf(evaluations);
|
||||
|
||||
for (List<IEvaluation> le : evaluations.values()) {
|
||||
for (IEvaluation e : le) {
|
||||
isEmpty = false;
|
||||
if (classEvaluations.containsKey(e.getClass()))
|
||||
classEvaluations.remove(e.getClass());
|
||||
else
|
||||
classEvaluations.put(e.getClass(), e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private EvaluationRecord() {
|
||||
|
||||
}
|
||||
|
||||
public boolean isEmpty() {
|
||||
return isEmpty;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all evaluations
|
||||
*/
|
||||
public ImmutableMap<String, List<IEvaluation>> evaluations() {
|
||||
return evaluations;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get evaluations for a given param/variable
|
||||
*
|
||||
* @param param The target param/variable
|
||||
*/
|
||||
public List<IEvaluation> evaluations(String param) {
|
||||
Preconditions.checkArgument(evaluations.containsKey(param),
|
||||
"No evaluations for %s.", param);
|
||||
|
||||
return evaluations.get(param);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get evaluations for a given param/variable
|
||||
*
|
||||
* @param param The target param/variable
|
||||
*/
|
||||
public List<IEvaluation> evaluations(SDVariable param) {
|
||||
return evaluations(param.getVarName());
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the evaluation for param at the specified index
|
||||
*/
|
||||
public IEvaluation evaluation(String param, int index) {
|
||||
return evaluations(param).get(index);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the evaluation for param at the specified index
|
||||
*/
|
||||
public IEvaluation evaluation(SDVariable param, int index) {
|
||||
return evaluation(param.getVarName(), index);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the evaluation for a given param/variable
|
||||
* <p>
|
||||
* Will throw an exception if there are more than one or no evaluations for the param
|
||||
*
|
||||
* @param param The target param/variable
|
||||
*/
|
||||
public <T extends IEvaluation> T evaluation(String param) {
|
||||
Preconditions.checkArgument(evaluations.containsKey(param),
|
||||
"No evaluations for %s.", param);
|
||||
Preconditions.checkArgument(evaluations.get(param).size() == 1,
|
||||
"Multiple evaluations for %s. Use evaluations().", param);
|
||||
|
||||
return (T) evaluations.get(param).get(0);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the evaluation for a given param/variable
|
||||
* <p>
|
||||
* Will throw an exception if there are more than one or no evaluations for the param
|
||||
*
|
||||
* @param param The target param/variable
|
||||
*/
|
||||
public <T extends IEvaluation> T evaluation(SDVariable param) {
|
||||
return evaluation(param.getVarName());
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the evaluation of a given type
|
||||
* <p>
|
||||
* Will throw an exception if there are more than one or no evaluations of that type
|
||||
*
|
||||
* @param evalClass The type of evaluation to look for
|
||||
*/
|
||||
public <T extends IEvaluation<T>> T evaluation(Class<T> evalClass) {
|
||||
Preconditions.checkArgument(classEvaluations.containsKey(evalClass),
|
||||
"Can't get evaluation for %s. Either no evaluations with that class are present, or more than one are.", evalClass);
|
||||
|
||||
return (T) classEvaluations.get(evalClass);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the evaluation of a given type, for a given param/variable
|
||||
* <p>
|
||||
* Will throw an exception if there are more than one or no evaluations of that type for the given param
|
||||
*
|
||||
* @param param The target param/variable
|
||||
* @param evalClass The type of evaluation to look for
|
||||
*/
|
||||
public <T extends IEvaluation<T>> T evaluation(String param, Class<T> evalClass) {
|
||||
Collection<IEvaluation> evals = Collections2.filter(evaluations(param), Predicates.instanceOf(evalClass));
|
||||
|
||||
Preconditions.checkArgument(evals.size() == 1, "Multiple or no evaluations of type %s for param %s.", evalClass, param);
|
||||
|
||||
return (T) evals.iterator().next();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the evaluation of a given type, for a given param/variable
|
||||
* <p>
|
||||
* Will throw an exception if there are more than one or no evaluations of that type for the given param
|
||||
*
|
||||
* @param param The target param/variable
|
||||
* @param evalClass The type of evaluation to look for
|
||||
*/
|
||||
public <T extends IEvaluation<T>> T evaluation(SDVariable param, Class<T> evalClass) {
|
||||
return evaluation(param.getVarName(), evalClass);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the metric's value for the evaluation of the metric's type
|
||||
* <p>
|
||||
* Will throw an exception if there are more than one or no evaluations of that type
|
||||
*
|
||||
* @param metric The metric to calculate
|
||||
*/
|
||||
public double getValue(IMetric metric) {
|
||||
return evaluation(metric.getEvaluationClass()).getValue(metric);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the metric's value for the evaluation of the metric's type, for a given param/variable
|
||||
* <p>
|
||||
* Will throw an exception if there are more than one or no evaluations of that type for the given param
|
||||
*
|
||||
* @param param The target param/variable
|
||||
* @param metric The metric to calculate
|
||||
*/
|
||||
public double getValue(String param, IMetric metric) {
|
||||
return evaluation(param, metric.getEvaluationClass()).getValue(metric);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the metric's value for the evaluation of the metric's type, for a given param/variable
|
||||
* <p>
|
||||
* Will throw an exception if there are more than one or no evaluations of that type for the given param
|
||||
*
|
||||
* @param param The target param/variable
|
||||
* @param metric The metric to calculate
|
||||
*/
|
||||
public double getValue(SDVariable param, IMetric metric) {
|
||||
return getValue(param.getVarName(), metric);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the metric's value for the evaluation for a given param/variable at the given index
|
||||
* <p>
|
||||
* Will throw an exception if the target evaluation doesn't support the given metric
|
||||
*
|
||||
* @param param The target param/variable
|
||||
* @param index The index of the target evaluation on the param
|
||||
* @param metric The metric to calculate
|
||||
*/
|
||||
public double getValue(String param, int index, IMetric metric) {
|
||||
return evaluation(param, index).getValue(metric);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the metric's value for the evaluation for a given param/variable at the given index
|
||||
* <p>
|
||||
* Will throw an exception if the target evaluation doesn't support the given metric
|
||||
*
|
||||
* @param param The target param/variable
|
||||
* @param index The index of the target evaluation on the param
|
||||
* @param metric The metric to calculate
|
||||
*/
|
||||
public double getValue(SDVariable param, int index, IMetric metric) {
|
||||
return getValue(param.getVarName(), index, metric);
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,356 @@
|
|||
/*
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
package org.nd4j.autodiff.listeners.records;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import lombok.Getter;
|
||||
import org.nd4j.autodiff.listeners.Listener;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.evaluation.IEvaluation;
|
||||
import org.nd4j.evaluation.IMetric;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||
|
||||
/**
|
||||
* An object containing training history for a SameDiff.fit call, such as {@link SameDiff#fit()}, {@link SameDiff#fit(DataSetIterator, int, Listener...)}, etc.<br>
|
||||
* Contains information including:<br>
|
||||
* - Evaluations performed (training set and test set)<br>
|
||||
* - Loss curve (score values at each iteration)<br>
|
||||
* - Training times, and validation times<br>
|
||||
* - Number of epochs performed<br>
|
||||
*/
|
||||
@Getter
|
||||
public class History {
|
||||
|
||||
private List<EvaluationRecord> trainingHistory;
|
||||
private List<EvaluationRecord> validationHistory;
|
||||
|
||||
private LossCurve lossCurve;
|
||||
|
||||
private long trainingTimeMillis;
|
||||
private List<Long> validationTimesMillis;
|
||||
|
||||
public History(List<EvaluationRecord> training, List<EvaluationRecord> validation, LossCurve loss,
|
||||
long trainingTimeMillis, List<Long> validationTimesMillis){
|
||||
trainingHistory = ImmutableList.copyOf(training);
|
||||
validationHistory = ImmutableList.copyOf(validation);
|
||||
this.lossCurve = loss;
|
||||
this.trainingTimeMillis = trainingTimeMillis;
|
||||
this.validationTimesMillis = ImmutableList.copyOf(validationTimesMillis);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the training evaluations
|
||||
*/
|
||||
public List<EvaluationRecord> trainingEval(){
|
||||
return trainingHistory;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the validation evaluations
|
||||
*/
|
||||
public List<EvaluationRecord> validationEval(){
|
||||
return validationHistory;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the loss curve
|
||||
*/
|
||||
public LossCurve lossCurve(){
|
||||
return lossCurve;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the total training time, in milliseconds
|
||||
*/
|
||||
public long trainingTimeMillis(){
|
||||
return trainingTimeMillis;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the total validation time, in milliseconds
|
||||
*/
|
||||
public List<Long> validationTimesMillis(){
|
||||
return validationTimesMillis;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the number of epochs trained for
|
||||
*/
|
||||
public int trainingEpochs(){
|
||||
return trainingHistory.size();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the number of epochs validation was ran on
|
||||
*/
|
||||
public int validationEpochs(){
|
||||
return validationHistory.size();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the results of a training evaluation on a given parameter for a given metric
|
||||
*
|
||||
* Only works if there is only one evaluation with the given metric for param
|
||||
*/
|
||||
public List<Double> trainingEval(String param, IMetric metric){
|
||||
List<Double> data = new ArrayList<>();
|
||||
for(EvaluationRecord er : trainingHistory)
|
||||
data.add(er.getValue(param, metric));
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the results of a training evaluation on a given parameter for a given metric
|
||||
*
|
||||
* Only works if there is only one evaluation with the given metric for param
|
||||
*/
|
||||
public List<Double> trainingEval(SDVariable param, IMetric metric){
|
||||
return trainingEval(param.getVarName(), metric);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the results of a training evaluation on a given parameter at a given index, for a given metric
|
||||
*
|
||||
* Note that it returns all recorded evaluations.
|
||||
* Index determines the evaluation used not the epoch's results to return.
|
||||
*/
|
||||
public List<Double> trainingEval(String param, int index, IMetric metric){
|
||||
List<Double> data = new ArrayList<>();
|
||||
for(EvaluationRecord er : trainingHistory)
|
||||
data.add(er.getValue(param, index, metric));
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the results of a training evaluation on a given parameter at a given index, for a given metric
|
||||
*
|
||||
* Note that it returns all recorded evaluations.
|
||||
* Index determines the evaluation used not the epoch's results to return.
|
||||
*/
|
||||
public List<Double> trainingEval(SDVariable param, int index, IMetric metric){
|
||||
return trainingEval(param.getVarName(), index, metric);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the results of a training evaluation for a given metric
|
||||
*
|
||||
* Only works if there is only one evaluation with the given metric
|
||||
*/
|
||||
public List<Double> trainingEval(IMetric metric){
|
||||
List<Double> data = new ArrayList<>();
|
||||
for(EvaluationRecord er : trainingHistory)
|
||||
data.add(er.getValue(metric));
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the results of a training evaluation on a given parameter
|
||||
*
|
||||
* Only works if there is only one evaluation for param.
|
||||
*/
|
||||
public List<IEvaluation> trainingEval(String param){
|
||||
List<IEvaluation> data = new ArrayList<>();
|
||||
for(EvaluationRecord er : trainingHistory)
|
||||
data.add(er.evaluation(param));
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the results of a training evaluation on a given parameter
|
||||
*
|
||||
* Only works if there is only one evaluation for param.
|
||||
*/
|
||||
public List<IEvaluation> trainingEval(SDVariable param){
|
||||
return trainingEval(param.getVarName());
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the results of a training evaluation on a given parameter at a given index
|
||||
*
|
||||
* Note that it returns all recorded evaluations.
|
||||
* Index determines the evaluation used not the epoch's results to return.
|
||||
*/
|
||||
public List<IEvaluation> trainingEval(String param, int index){
|
||||
List<IEvaluation> data = new ArrayList<>();
|
||||
for(EvaluationRecord er : trainingHistory)
|
||||
data.add(er.evaluation(param, index));
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the results of a training evaluation on a given parameter at a given index
|
||||
*
|
||||
* Note that it returns all recorded evaluations.
|
||||
* Index determines the evaluation used not the epoch's results to return.
|
||||
*/
|
||||
public List<IEvaluation> trainingEval(SDVariable param, int index){
|
||||
return trainingEval(param.getVarName(), index);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the results of a validation evaluation on a given parameter for a given metric
|
||||
*
|
||||
* Only works if there is only one evaluation with the given metric for param
|
||||
*/
|
||||
public List<Double> validationEval(String param, IMetric metric){
|
||||
List<Double> data = new ArrayList<>();
|
||||
for(EvaluationRecord er : validationHistory)
|
||||
data.add(er.getValue(param, metric));
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the results of a validation evaluation on a given parameter for a given metric
|
||||
*
|
||||
* Only works if there is only one evaluation with the given metric for param
|
||||
*/
|
||||
public List<Double> validationEval(SDVariable param, IMetric metric){
|
||||
return validationEval(param.getVarName(), metric);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the results of a validation evaluation on a given parameter at a given index, for a given metric
|
||||
*
|
||||
* Note that it returns all recorded evaluations.
|
||||
* Index determines the evaluation used not the epoch's results to return.
|
||||
*/
|
||||
public List<Double> validationEval(String param, int index, IMetric metric){
|
||||
List<Double> data = new ArrayList<>();
|
||||
for(EvaluationRecord er : validationHistory)
|
||||
data.add(er.getValue(param, index, metric));
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the results of a validation evaluation on a given parameter at a given index, for a given metric
|
||||
*
|
||||
* Note that it returns all recorded evaluations.
|
||||
* Index determines the evaluation used not the epoch's results to return.
|
||||
*/
|
||||
public List<Double> validationEval(SDVariable param, int index, IMetric metric){
|
||||
return validationEval(param.getVarName(), index, metric);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the results of a validation evaluation for a given metric
|
||||
*
|
||||
* Only works if there is only one evaluation with the given metric
|
||||
*/
|
||||
public List<Double> validationEval(IMetric metric){
|
||||
List<Double> data = new ArrayList<>();
|
||||
for(EvaluationRecord er : validationHistory)
|
||||
data.add(er.getValue(metric));
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the results of a validation evaluation on a given parameter
|
||||
*
|
||||
* Only works if there is only one evaluation for param.
|
||||
*/
|
||||
public List<IEvaluation> validationEval(String param){
|
||||
List<IEvaluation> data = new ArrayList<>();
|
||||
for(EvaluationRecord er : validationHistory)
|
||||
data.add(er.evaluation(param));
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the results of a validation evaluation on a given parameter
|
||||
*
|
||||
* Only works if there is only one evaluation for param.
|
||||
*/
|
||||
public List<IEvaluation> validationEval(SDVariable param){
|
||||
return validationEval(param.getVarName());
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the results of a validation evaluation on a given parameter at a given index
|
||||
*
|
||||
* Note that it returns all recorded evaluations.
|
||||
* Index determines the evaluation used not the epoch's results to return.
|
||||
*/
|
||||
public List<IEvaluation> validationEval(String param, int index){
|
||||
List<IEvaluation> data = new ArrayList<>();
|
||||
for(EvaluationRecord er : validationHistory)
|
||||
data.add(er.evaluation(param, index));
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the results of a validation evaluation on a given parameter at a given index
|
||||
*
|
||||
* Note that it returns all recorded evaluations.
|
||||
* Index determines the evaluation used not the epoch's results to return.
|
||||
*/
|
||||
public List<IEvaluation> validationEval(SDVariable param, int index){
|
||||
return validationEval(param.getVarName(), index);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the training evaluations ran during the last epoch
|
||||
*/
|
||||
public EvaluationRecord finalTrainingEvaluations(){
|
||||
return trainingHistory.get(trainingHistory.size() - 1);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the validation evaluations ran during the last epoch
|
||||
*/
|
||||
public EvaluationRecord finalValidationEvaluations(){
|
||||
return validationHistory.get(validationHistory.size() - 1);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the evaluation record for a given epoch.
|
||||
* @param epoch The epoch to get results for. If negative, returns results for the epoch that many epochs from the end.
|
||||
*/
|
||||
public EvaluationRecord trainingEvaluations(int epoch){
|
||||
if(epoch >= 0){
|
||||
return trainingHistory.get(epoch);
|
||||
} else {
|
||||
return trainingHistory.get(trainingHistory.size() - epoch);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the evaluation record for a given epoch.
|
||||
* @param epoch The epoch to get results for. If negative, returns results for the epoch that many epochs from the end.
|
||||
*/
|
||||
public EvaluationRecord validationEvaluations(int epoch){
|
||||
if(epoch >= 0){
|
||||
return trainingHistory.get(epoch);
|
||||
} else {
|
||||
return validationHistory.get(validationHistory.size() - epoch);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,211 @@
|
|||
/*
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
package org.nd4j.autodiff.listeners.records;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import lombok.Getter;
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.listeners.Loss;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
public class LossCurve {
|
||||
@Getter
|
||||
private List<String> lossNames;
|
||||
@Getter
|
||||
private INDArray lossValues;
|
||||
|
||||
public LossCurve(List<Loss> losses){
|
||||
lossNames = ImmutableList.copyOf(losses.get(0).getLossNames());
|
||||
int numLossValues = losses.get(0).lossValues().length;
|
||||
lossValues = Nd4j.create(DataType.FLOAT, losses.size(), losses.get(0).lossValues().length);
|
||||
|
||||
for(int i = 0 ; i < losses.size() ; i++){
|
||||
Loss l = losses.get(i);
|
||||
Preconditions.checkArgument(l.getLossNames().equals(lossNames),
|
||||
"Loss names for loss %s differ from others. Expected %s, got %s",
|
||||
i, lossNames, l.getLossNames());
|
||||
|
||||
Preconditions.checkArgument(l.getLosses().length == numLossValues,
|
||||
"Number of loss values for loss %s differ from others. Expected %s, got %s",
|
||||
i, numLossValues, l.getLosses().length);
|
||||
|
||||
lossValues = lossValues.putRow(i, Nd4j.createFromArray(l.getLosses()).castTo(DataType.FLOAT));
|
||||
}
|
||||
}
|
||||
|
||||
public LossCurve(double[] lossValues, List<String> lossNames){
|
||||
this.lossValues = Nd4j.createFromArray(new double[][]{ lossValues}).castTo(DataType.FLOAT);
|
||||
this.lossNames = lossNames;
|
||||
}
|
||||
|
||||
protected LossCurve(INDArray lossValues, List<String> lossNames){
|
||||
Preconditions.checkArgument(lossValues.rank() == 2, "lossValues must have a rank of 2, got %s", lossValues.rank());
|
||||
Preconditions.checkArgument(lossValues.dataType() == DataType.FLOAT, "lossValues must be type FLOAT, got %s", lossValues.dataType());
|
||||
this.lossValues = lossValues;
|
||||
this.lossNames = lossNames;
|
||||
}
|
||||
|
||||
public List<Loss> losses(){
|
||||
List<Loss> losses = new ArrayList<>();
|
||||
for(int i = 0 ; i < lossValues.size(0) ; i++){
|
||||
losses.add(new Loss(lossNames, lossValues.getRow(i).toDoubleVector()));
|
||||
}
|
||||
return losses;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the mean loss for a given epoch
|
||||
*
|
||||
* If epoch is negative, counts backwards from the end.
|
||||
* E.g. losses(-1) gets the last epoch.
|
||||
*
|
||||
* @param epoch The epoch to get. If negative, returns results for the epoch that many epochs from the end
|
||||
*/
|
||||
public Loss meanLoss(int epoch){
|
||||
if(epoch >= 0){
|
||||
return new Loss(lossNames, lossValues.getRow(epoch).toDoubleVector());
|
||||
} else {
|
||||
return new Loss(lossNames, lossValues.getRow(lossValues.rows() + epoch).toDoubleVector());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the mean loss for the last epoch.
|
||||
*/
|
||||
public Loss lastMeanLoss(){
|
||||
return meanLoss(-1);
|
||||
}
|
||||
|
||||
/**
|
||||
* Return all mean loss values for a given variable
|
||||
*/
|
||||
public float[] meanLoss(@NonNull String lossName){
|
||||
|
||||
int idx = lossNames.indexOf(lossName);
|
||||
|
||||
Preconditions.checkArgument(idx >= 0, "No loss value for %s. Existing losses: %s", lossName, lossNames);
|
||||
|
||||
float[] loss = new float[(int) lossValues.size(0)];
|
||||
for(int i = 0 ; i < lossValues.size(0) ; i++){
|
||||
loss[i] = lossValues.getFloat(i, idx);
|
||||
}
|
||||
return loss;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return all mean loss values for a given variable
|
||||
*/
|
||||
public float[] meanLoss(@NonNull SDVariable loss){
|
||||
return meanLoss(loss.getVarName());
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the mean loss value for a given variable on a given epoch.
|
||||
*
|
||||
* See {@link #meanLoss(int)}
|
||||
*/
|
||||
public float meanLoss(@NonNull String lossName, int epoch){
|
||||
|
||||
int idx = lossNames.indexOf(lossName);
|
||||
|
||||
Preconditions.checkArgument(idx >= 0, "No loss value for %s. Existing losses: %s", lossName, lossNames);
|
||||
|
||||
if(epoch >= 0) {
|
||||
return lossValues.getFloat(epoch, idx);
|
||||
} else {
|
||||
return lossValues.getFloat(lossValues.rows() + epoch, idx);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the mean loss value for a given variable on a given epoch.
|
||||
*
|
||||
* See {@link #meanLoss(int)}
|
||||
*/
|
||||
public float meanLoss(@NonNull SDVariable loss, int epoch){
|
||||
return meanLoss(loss.getVarName(), epoch);
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the mean loss value for a given variable on the last epoch.
|
||||
*/
|
||||
public float lastMeanLoss(@NonNull String lossName){
|
||||
|
||||
int idx = lossNames.indexOf(lossName);
|
||||
|
||||
Preconditions.checkArgument(idx >= 0, "No loss value for %s. Existing losses: %s", lossName, lossNames);
|
||||
|
||||
return lossValues.getFloat(lossValues.rows() - 1, idx);
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the mean loss value for a given variable on the last epoch.
|
||||
*/
|
||||
public float lastMeanLoss(@NonNull SDVariable loss){
|
||||
return lastMeanLoss(loss.getVarName());
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the loss delta between the last epoch and the one before it.
|
||||
* Equivalent to meanLoss(-1) - meanLoss(-2).
|
||||
* A positive delta means the loss is increasing, and a negative delta means it is decreasing.
|
||||
*/
|
||||
public Loss lastMeanDelta(){
|
||||
return lastMeanLoss().sub(meanLoss(-2));
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the loss delta between the last epoch and the one before it, for a given variable.
|
||||
* Equivalent to meanLoss(-1) - meanLoss(-2).
|
||||
* A positive delta means the loss is increasing, and a negative delta means it is decreasing.
|
||||
*/
|
||||
public double lastMeanDelta(String lossName){
|
||||
return lastMeanDelta().getLoss(lossName);
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the loss delta between the last epoch and the one before it, for a given variable.
|
||||
* Equivalent to meanLoss(-1) - meanLoss(-2).
|
||||
* A positive delta means the loss is increasing, and a negative delta means it is decreasing.
|
||||
*/
|
||||
public double lastMeanDelta(SDVariable loss){
|
||||
return lastMeanDelta(loss.getVarName());
|
||||
}
|
||||
|
||||
/**
|
||||
* Return a new LossCurve with the given losses added on as the most recent epoch
|
||||
*/
|
||||
public LossCurve addLossAndCopy(Loss loss){
|
||||
return addLossAndCopy(loss.getLosses(), loss.lossNames());
|
||||
}
|
||||
|
||||
/**
|
||||
* Return a new LossCurve with the given losses added on as the most recent epoch
|
||||
*/
|
||||
public LossCurve addLossAndCopy(double[] values, List<String> lossNames){
|
||||
return new LossCurve(
|
||||
Nd4j.concat(0, lossValues,
|
||||
Nd4j.createFromArray(new double[][]{values}).castTo(DataType.FLOAT)),
|
||||
lossNames);
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
|
@ -18,7 +18,9 @@ package org.nd4j.autodiff.samediff;
|
|||
|
||||
import lombok.*;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.nd4j.autodiff.listeners.ListenerEvaluations;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.evaluation.IEvaluation;
|
||||
import org.nd4j.linalg.learning.config.IUpdater;
|
||||
import org.nd4j.linalg.learning.regularization.L1Regularization;
|
||||
import org.nd4j.linalg.learning.regularization.L2Regularization;
|
||||
|
@ -62,6 +64,12 @@ public class TrainingConfig {
|
|||
private int iterationCount;
|
||||
private int epochCount;
|
||||
|
||||
private Map<String, List<IEvaluation>> trainEvaluations = new HashMap<>();
|
||||
private Map<String, Integer> trainEvaluationLabels = new HashMap<>();
|
||||
|
||||
private Map<String, List<IEvaluation>> validationEvaluations = new HashMap<>();
|
||||
private Map<String, Integer> validationEvaluationLabels = new HashMap<>();
|
||||
|
||||
/**
|
||||
* Create a training configuration suitable for training a single input, single output network.<br>
|
||||
* See also the {@link Builder} for creating a TrainingConfig
|
||||
|
@ -106,6 +114,17 @@ public class TrainingConfig {
|
|||
this.lossVariables = lossVariables;
|
||||
}
|
||||
|
||||
protected TrainingConfig(IUpdater updater, List<Regularization> regularization, boolean minimize, List<String> dataSetFeatureMapping, List<String> dataSetLabelMapping,
|
||||
List<String> dataSetFeatureMaskMapping, List<String> dataSetLabelMaskMapping, List<String> lossVariables,
|
||||
Map<String, List<IEvaluation>> trainEvaluations, Map<String, Integer> trainEvaluationLabels,
|
||||
Map<String, List<IEvaluation>> validationEvaluations, Map<String, Integer> validationEvaluationLabels){
|
||||
this(updater, regularization, minimize, dataSetFeatureMapping, dataSetLabelMapping, dataSetFeatureMaskMapping, dataSetLabelMaskMapping, lossVariables);
|
||||
this.trainEvaluations = trainEvaluations;
|
||||
this.trainEvaluationLabels = trainEvaluationLabels;
|
||||
this.validationEvaluations = validationEvaluations;
|
||||
this.validationEvaluationLabels = validationEvaluationLabels;
|
||||
}
|
||||
|
||||
/**
|
||||
* Increment the iteration count by 1
|
||||
*/
|
||||
|
@ -146,6 +165,12 @@ public class TrainingConfig {
|
|||
private boolean skipValidation = false;
|
||||
private boolean markLabelsUnused = false;
|
||||
|
||||
private Map<String, List<IEvaluation>> trainEvaluations = new HashMap<>();
|
||||
private Map<String, Integer> trainEvaluationLabels = new HashMap<>();
|
||||
|
||||
private Map<String, List<IEvaluation>> validationEvaluations = new HashMap<>();
|
||||
private Map<String, Integer> validationEvaluationLabels = new HashMap<>();
|
||||
|
||||
/**
|
||||
* Set the updater (such as {@link org.nd4j.linalg.learning.config.Adam}, {@link org.nd4j.linalg.learning.config.Nesterovs}
|
||||
* etc. This is also how the learning rate (or learning rate schedule) is set.
|
||||
|
@ -327,7 +352,7 @@ public class TrainingConfig {
|
|||
* Set the name of the placeholders/variables that should be set using the feature mask INDArray(s) from the
|
||||
* DataSet or MultiDataSet. For example, if the network had 2 mask variables called "mask1" and "mask2"
|
||||
* and the MultiDataSet features masks should be mapped with {@code MultiDataSet.getFeatureMaskArray(0)->"mask1"}
|
||||
* and {@code MultiDataSet.getFeatureMaskArray(1)->"mask2"}, then this should be set to {@code "mask2", "mask2"}.
|
||||
* and {@code MultiDataSet.getFeatureMaskArray(1)->"mask2"}, then this should be set to {@code "mask1", "mask2"}.
|
||||
*
|
||||
* @param dataSetFeatureMaskMapping Name of the variables/placeholders that the feature arrays should be mapped to
|
||||
*/
|
||||
|
@ -347,7 +372,7 @@ public class TrainingConfig {
|
|||
* Set the name of the placeholders/variables that should be set using the label mask INDArray(s) from the
|
||||
* DataSet or MultiDataSet. For example, if the network had 2 mask variables called "mask1" and "mask2"
|
||||
* and the MultiDataSet label masks should be mapped with {@code MultiDataSet.getLabelMaskArray(0)->"mask1"}
|
||||
* and {@code MultiDataSet.getLabelMaskArray(1)->"mask2"}, then this should be set to {@code "mask2", "mask2"}.
|
||||
* and {@code MultiDataSet.getLabelMaskArray(1)->"mask2"}, then this should be set to {@code "mask1", "mask2"}.
|
||||
*
|
||||
* @param dataSetLabelMaskMapping Name of the variables/placeholders that the feature arrays should be mapped to
|
||||
*/
|
||||
|
@ -366,6 +391,104 @@ public class TrainingConfig {
|
|||
return this;
|
||||
}
|
||||
|
||||
private void addEvaluations(boolean validation, @NonNull Map<String, List<IEvaluation>> evaluationMap, @NonNull Map<String, Integer> labelMap,
|
||||
@NonNull String variableName, int labelIndex, @NonNull IEvaluation... evaluations){
|
||||
if(evaluationMap.containsKey(variableName) && labelMap.get(variableName) != labelIndex){
|
||||
String s;
|
||||
|
||||
if(validation){
|
||||
s = "This ListenerEvaluations.Builder already has validation evaluations for ";
|
||||
} else {
|
||||
s = "This ListenerEvaluations.Builder already has train evaluations for ";
|
||||
}
|
||||
|
||||
throw new IllegalArgumentException(s + "variable " +
|
||||
variableName + " with label index " + labelIndex + ". You can't add " +
|
||||
" evaluations with a different label index. Got label index " + labelIndex);
|
||||
}
|
||||
|
||||
if(evaluationMap.containsKey(variableName)){
|
||||
evaluationMap.get(variableName).addAll(Arrays.asList(evaluations));
|
||||
} else {
|
||||
evaluationMap.put(variableName, Arrays.asList(evaluations));
|
||||
labelMap.put(variableName, labelIndex);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Add requested History training evaluations for a parm/variable.
|
||||
*
|
||||
* These evaluations will be reported in the {@link org.nd4j.autodiff.listeners.records.History} object returned by fit.
|
||||
*
|
||||
* @param variableName The variable to evaluate
|
||||
* @param labelIndex The index of the label to evaluate against
|
||||
* @param evaluations The evaluations to run
|
||||
*/
|
||||
public Builder trainEvaluation(@NonNull String variableName, int labelIndex, @NonNull IEvaluation... evaluations){
|
||||
addEvaluations(false, this.trainEvaluations, this.trainEvaluationLabels, variableName,
|
||||
labelIndex, evaluations);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Add requested History training evaluations for a parm/variable.
|
||||
*
|
||||
* These evaluations will be reported in the {@link org.nd4j.autodiff.listeners.records.History} object returned by fit.
|
||||
*
|
||||
* @param variable The variable to evaluate
|
||||
* @param labelIndex The index of the label to evaluate against
|
||||
* @param evaluations The evaluations to run
|
||||
*/
|
||||
public Builder trainEvaluation(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations){
|
||||
return trainEvaluation(variable.getVarName(), labelIndex, evaluations);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add requested History validation evaluations for a parm/variable.
|
||||
*
|
||||
* These evaluations will be reported in the {@link org.nd4j.autodiff.listeners.records.History} object returned by fit.
|
||||
*
|
||||
* @param variableName The variable to evaluate
|
||||
* @param labelIndex The index of the label to evaluate against
|
||||
* @param evaluations The evaluations to run
|
||||
*/
|
||||
public Builder validationEvaluation(@NonNull String variableName, int labelIndex, @NonNull IEvaluation... evaluations){
|
||||
addEvaluations(true, this.validationEvaluations, this.validationEvaluationLabels, variableName,
|
||||
labelIndex, evaluations);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Add requested History validation evaluations for a parm/variable.
|
||||
*
|
||||
* These evaluations will be reported in the {@link org.nd4j.autodiff.listeners.records.History} object returned by fit.
|
||||
*
|
||||
* @param variable The variable to evaluate
|
||||
* @param labelIndex The index of the label to evaluate against
|
||||
* @param evaluations The evaluations to run
|
||||
*/
|
||||
public Builder validationEvaluation(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations){
|
||||
return validationEvaluation(variable.getVarName(), labelIndex, evaluations);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add requested evaluations for a parm/variable, for either training or validation.
|
||||
*
|
||||
* These evaluations will be reported in the {@link org.nd4j.autodiff.listeners.records.History} object returned by fit.
|
||||
*
|
||||
* @param validation Whether to add these evaluations as validation or training
|
||||
* @param variableName The variable to evaluate
|
||||
* @param labelIndex The index of the label to evaluate against
|
||||
* @param evaluations The evaluations to run
|
||||
*/
|
||||
public Builder addEvaluations(boolean validation, @NonNull String variableName, int labelIndex, @NonNull IEvaluation... evaluations){
|
||||
if(validation){
|
||||
return validationEvaluation(variableName, labelIndex, evaluations);
|
||||
} else{
|
||||
return trainEvaluation(variableName, labelIndex, evaluations);
|
||||
}
|
||||
}
|
||||
|
||||
public TrainingConfig build(){
|
||||
if(!skipValidation) {
|
||||
Preconditions.checkState(updater != null, "Updater (optimizer) must not be null. Use updater(IUpdater) to set an updater");
|
||||
|
@ -374,10 +497,20 @@ public class TrainingConfig {
|
|||
Preconditions.checkState(markLabelsUnused || dataSetLabelMapping != null, "No DataSet label mapping has been provided. A " +
|
||||
"mapping between DataSet array positions and variables/placeholders must be provided - use dataSetLabelMapping(...) to set this," +
|
||||
" or use markLabelsUnused() to mark labels as unused (for example, for unsupervised learning)");
|
||||
|
||||
|
||||
Preconditions.checkArgument(trainEvaluations.keySet().equals(trainEvaluationLabels.keySet()),
|
||||
"Must specify a label index for each train evaluation. Expected: %s, got: %s",
|
||||
trainEvaluations.keySet(), trainEvaluationLabels.keySet());
|
||||
|
||||
Preconditions.checkArgument(validationEvaluations.keySet().equals(validationEvaluationLabels.keySet()),
|
||||
"Must specify a label index for each validation evaluation. Expected: %s, got: %s",
|
||||
validationEvaluations.keySet(), validationEvaluationLabels.keySet());
|
||||
}
|
||||
|
||||
return new TrainingConfig(updater, regularization, minimize, dataSetFeatureMapping, dataSetLabelMapping,
|
||||
dataSetFeatureMaskMapping, dataSetLabelMaskMapping, lossVariables);
|
||||
dataSetFeatureMaskMapping, dataSetLabelMaskMapping, lossVariables,
|
||||
trainEvaluations, trainEvaluationLabels, validationEvaluations, validationEvaluationLabels);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,152 @@
|
|||
/*
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
package org.nd4j.autodiff.samediff.config;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import lombok.AccessLevel;
|
||||
import lombok.Getter;
|
||||
import lombok.NonNull;
|
||||
import lombok.Setter;
|
||||
import org.nd4j.autodiff.listeners.Listener;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.VariableType;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
||||
|
||||
/**
|
||||
* Configuration for a single batch {@link SameDiff} inference operation.
|
||||
*
|
||||
* Used in {@link SameDiff#batchOutput()}.
|
||||
*/
|
||||
@Getter
|
||||
@Setter
|
||||
public class BatchOutputConfig {
|
||||
|
||||
@Setter(AccessLevel.NONE)
|
||||
private SameDiff sd;
|
||||
|
||||
@NonNull
|
||||
private List<String> outputs = new ArrayList<>();
|
||||
|
||||
private Map<String, INDArray> placeholders = new HashMap<>();
|
||||
|
||||
@NonNull
|
||||
private List<Listener> listeners = new ArrayList<>();
|
||||
|
||||
public BatchOutputConfig(@NonNull SameDiff sd){
|
||||
this.sd = sd;
|
||||
}
|
||||
|
||||
/**
|
||||
* Add required outputs
|
||||
*/
|
||||
public BatchOutputConfig output(@NonNull String... outputs){
|
||||
this.outputs.addAll(Arrays.asList(outputs));
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Add required outputs
|
||||
*/
|
||||
public BatchOutputConfig output(@NonNull SDVariable... outputs){
|
||||
String[] outNames = new String[outputs.length];
|
||||
for(int i = 0 ; i < outputs.length ; i++){
|
||||
outNames[i] = outputs[i].getVarName();
|
||||
}
|
||||
|
||||
return output(outNames);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add all variables as required outputs
|
||||
*/
|
||||
public BatchOutputConfig outputAll(){
|
||||
return output(sd.variables().toArray(new SDVariable[0]));
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a placeholder value for a specified variable
|
||||
*/
|
||||
public BatchOutputConfig input(@NonNull String variable, @NonNull INDArray placeholder){
|
||||
Preconditions.checkState(!placeholders.containsKey(variable),
|
||||
"Placeholder for variable %s already specified", variable);
|
||||
|
||||
Preconditions.checkNotNull(sd.getVariable(variable),
|
||||
"Variable %s does not exist in this SameDiff graph", variable);
|
||||
|
||||
placeholders.put(variable, placeholder);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #input(String, INDArray)}
|
||||
*/
|
||||
public BatchOutputConfig input(@NonNull SDVariable variable, @NonNull INDArray placeholder){
|
||||
return input(variable.getVarName(), placeholder);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calls {@link #input(String, INDArray)} on each entry in the map.
|
||||
*/
|
||||
public BatchOutputConfig inputs(Map<String, INDArray> placeholders){
|
||||
|
||||
if(placeholders == null) {
|
||||
this.placeholders = null;
|
||||
return this;
|
||||
}
|
||||
|
||||
for(Map.Entry<String, INDArray> e : placeholders.entrySet()){
|
||||
input(e.getKey(), e.getValue());
|
||||
}
|
||||
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Add listeners for this operation
|
||||
*/
|
||||
public BatchOutputConfig listeners(@NonNull Listener... listeners){
|
||||
this.listeners.addAll(Arrays.asList(listeners));
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Do inference and return the results
|
||||
*/
|
||||
public Map<String, INDArray> exec(){
|
||||
return sd.output(placeholders, listeners, outputs.toArray(new String[0]));
|
||||
}
|
||||
|
||||
/**
|
||||
* Do inference and return the results for the single output
|
||||
*
|
||||
* Only works if exactly one output is specified
|
||||
*/
|
||||
public INDArray execSingle(){
|
||||
Preconditions.checkState(outputs.size() == 1,
|
||||
"Can only use execSingle() when exactly one output is specified, there were %s", outputs.size());
|
||||
return exec().get(outputs.get(0));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,202 @@
|
|||
/*
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
package org.nd4j.autodiff.samediff.config;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import lombok.AccessLevel;
|
||||
import lombok.Getter;
|
||||
import lombok.NonNull;
|
||||
import lombok.Setter;
|
||||
import org.nd4j.autodiff.listeners.Listener;
|
||||
import org.nd4j.autodiff.listeners.records.EvaluationRecord;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.evaluation.IEvaluation;
|
||||
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
||||
|
||||
/**
|
||||
* Configuration for a {@link SameDiff} evaluation operation.
|
||||
*
|
||||
* Used in {@link SameDiff#evaluate()}.
|
||||
*/
|
||||
@Getter
|
||||
@Setter
|
||||
public class EvaluationConfig {
|
||||
|
||||
@NonNull
|
||||
private Map<String, List<IEvaluation>> evaluations = new HashMap<>();
|
||||
|
||||
@NonNull
|
||||
private Map<String, Integer> labelIndices = new HashMap<>();
|
||||
|
||||
private MultiDataSetIterator data;
|
||||
|
||||
@NonNull
|
||||
private List<Listener> listeners = new ArrayList<>();
|
||||
|
||||
private boolean singleInput = false;
|
||||
|
||||
@Setter(AccessLevel.NONE)
|
||||
private SameDiff sd;
|
||||
|
||||
public EvaluationConfig(@NonNull SameDiff sd){
|
||||
this.sd = sd;
|
||||
}
|
||||
|
||||
/**
|
||||
* Add evaluations to be preformed on a specified variable, and set that variable's label index.
|
||||
*
|
||||
* Setting a label index is required if using a MultiDataSetIterator.
|
||||
*
|
||||
* @param param The param to evaluate
|
||||
* @param labelIndex The label index of that parameter
|
||||
* @param evaluations The evaluations to preform
|
||||
*/
|
||||
public EvaluationConfig evaluate(@NonNull String param, int labelIndex, @NonNull IEvaluation... evaluations){
|
||||
return evaluate(param, evaluations).labelIndex(param, labelIndex);
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #evaluate(String, int, IEvaluation[])}
|
||||
*/
|
||||
public EvaluationConfig evaluate(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations){
|
||||
return evaluate(variable.getVarName(), labelIndex, evaluations);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add evaluations to be preformed on a specified variable, without setting a label index.
|
||||
*
|
||||
* Setting a label index (which is not done here) is required if using a MultiDataSetIterator.
|
||||
*
|
||||
* @param param The param to evaluate
|
||||
* @param evaluations The evaluations to preform
|
||||
*/
|
||||
public EvaluationConfig evaluate(@NonNull String param, @NonNull IEvaluation... evaluations){
|
||||
if(this.evaluations.get(param) == null){
|
||||
this.evaluations.put(param, new ArrayList<IEvaluation>());
|
||||
}
|
||||
|
||||
this.evaluations.get(param).addAll(Arrays.asList(evaluations));
|
||||
return this;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* See {@link #evaluate(String, IEvaluation[])}
|
||||
*/
|
||||
public EvaluationConfig evaluate(@NonNull SDVariable variable, @NonNull IEvaluation... evaluations){
|
||||
return evaluate(variable.getVarName(), evaluations);
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the label index for a parameter
|
||||
*/
|
||||
public EvaluationConfig labelIndex(@NonNull String param, int labelIndex){
|
||||
if(this.labelIndices.get(param) != null){
|
||||
int existingIndex = this.labelIndices.get(param);
|
||||
Preconditions.checkArgument(existingIndex == labelIndex,
|
||||
"Different label index already specified for param %s. Already specified: %s, given: %s",
|
||||
param, existingIndex, labelIndex);
|
||||
}
|
||||
|
||||
labelIndices.put(param, labelIndex);
|
||||
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* See {@link #labelIndex(String, int)}
|
||||
*/
|
||||
public EvaluationConfig labelIndex(@NonNull SDVariable variable, int labelIndex){
|
||||
return labelIndex(variable.getVarName(), labelIndex);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add listeners for this operation
|
||||
*/
|
||||
public EvaluationConfig listeners(@NonNull Listener... listeners){
|
||||
this.listeners.addAll(Arrays.asList(listeners));
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the data to evaluate on.
|
||||
*
|
||||
* Setting a label index for each variable to evaluate is required
|
||||
*/
|
||||
public EvaluationConfig data(@NonNull MultiDataSetIterator data){
|
||||
this.data = data;
|
||||
singleInput = false;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the data to evaluate on.
|
||||
*
|
||||
* Setting a label index for each variable to evaluate is NOT required (since there is only one input)
|
||||
*/
|
||||
public EvaluationConfig data(@NonNull DataSetIterator data){
|
||||
this.data = new MultiDataSetIteratorAdapter(data);
|
||||
singleInput = true;
|
||||
return this;
|
||||
}
|
||||
|
||||
private void validateConfig(){
|
||||
Preconditions.checkNotNull(data, "Must specify data. It may not be null.");
|
||||
|
||||
if(!singleInput){
|
||||
for(String param : this.evaluations.keySet()){
|
||||
Preconditions.checkState(labelIndices.containsKey(param),
|
||||
"Using multiple input dataset iterator without specifying a label index for %s", param);
|
||||
}
|
||||
}
|
||||
|
||||
for(String param : this.evaluations.keySet()){
|
||||
Preconditions.checkState(sd.variableMap().containsKey(param),
|
||||
"Parameter %s not present in this SameDiff graph", param);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Run the evaluation.
|
||||
*
|
||||
* Note that the evaluations in the returned {@link EvaluationRecord} are the evaluations set using {@link #evaluate(String, int, IEvaluation[])},
|
||||
* it does not matter which you use to access results.
|
||||
*
|
||||
* @return The specified listeners, in an {@link EvaluationRecord} for easy access.
|
||||
*/
|
||||
public EvaluationRecord exec(){
|
||||
validateConfig();
|
||||
|
||||
if(singleInput){
|
||||
for(String param : this.evaluations.keySet()){
|
||||
labelIndices.put(param, 0);
|
||||
}
|
||||
}
|
||||
|
||||
sd.evaluate(data, evaluations, labelIndices, listeners.toArray(new Listener[0]));
|
||||
return new EvaluationRecord(evaluations);
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,176 @@
|
|||
/*
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
package org.nd4j.autodiff.samediff.config;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import lombok.AccessLevel;
|
||||
import lombok.Getter;
|
||||
import lombok.NonNull;
|
||||
import lombok.Setter;
|
||||
import org.nd4j.autodiff.listeners.records.History;
|
||||
import org.nd4j.autodiff.listeners.Listener;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.TrainingConfig;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
||||
|
||||
/**
|
||||
* Configuration for a {@link SameDiff} training operation.
|
||||
* <p>
|
||||
* Used in {@link SameDiff#fit()}.
|
||||
*/
|
||||
@Getter
|
||||
@Setter
|
||||
public class FitConfig {
|
||||
|
||||
@Setter(AccessLevel.NONE)
|
||||
private SameDiff sd;
|
||||
|
||||
private MultiDataSetIterator trainingData;
|
||||
|
||||
private MultiDataSetIterator validationData = null;
|
||||
|
||||
private int epochs = -1;
|
||||
|
||||
private int validationFrequency = 1;
|
||||
|
||||
@NonNull
|
||||
private List<Listener> listeners = new ArrayList<>();
|
||||
|
||||
public FitConfig(@NonNull SameDiff sd) {
|
||||
this.sd = sd;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the number of epochs to train for
|
||||
*/
|
||||
public FitConfig epochs(int epochs) {
|
||||
this.epochs = epochs;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the training data
|
||||
*/
|
||||
public FitConfig train(@NonNull MultiDataSetIterator trainingData) {
|
||||
this.trainingData = trainingData;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the training data
|
||||
*/
|
||||
public FitConfig train(@NonNull DataSetIterator trainingData) {
|
||||
return train(new MultiDataSetIteratorAdapter(trainingData));
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the training data and number of epochs
|
||||
*/
|
||||
public FitConfig train(@NonNull MultiDataSetIterator trainingData, int epochs) {
|
||||
return train(trainingData).epochs(epochs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the training data and number of epochs
|
||||
*/
|
||||
public FitConfig train(@NonNull DataSetIterator trainingData, int epochs) {
|
||||
return train(trainingData).epochs(epochs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the validation data
|
||||
*/
|
||||
public FitConfig validate(MultiDataSetIterator validationData) {
|
||||
this.validationData = validationData;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the validation data
|
||||
*/
|
||||
public FitConfig validate(DataSetIterator validationData) {
|
||||
if (validationData == null) {
|
||||
return validate((MultiDataSetIterator) null);
|
||||
} else {
|
||||
return validate(new MultiDataSetIteratorAdapter(validationData));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the validation frequency. Validation will be preformed once every so many epochs.
|
||||
* <p>
|
||||
* Specifically, validation will be preformed when i % validationFrequency == 0
|
||||
*/
|
||||
public FitConfig validationFrequency(int validationFrequency) {
|
||||
this.validationFrequency = validationFrequency;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the validation data and frequency
|
||||
* <p>
|
||||
* Specifically, validation will be preformed when i % validationFrequency == 0
|
||||
*/
|
||||
public FitConfig validate(MultiDataSetIterator validationData, int validationFrequency) {
|
||||
return validate(validationData).validationFrequency(validationFrequency);
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the validation data and frequency
|
||||
* <p>
|
||||
* Specifically, validation will be preformed when i % validationFrequency == 0
|
||||
*/
|
||||
public FitConfig validate(DataSetIterator validationData, int validationFrequency) {
|
||||
return validate(validationData).validationFrequency(validationFrequency);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add listeners for this operation
|
||||
*/
|
||||
public FitConfig listeners(@NonNull Listener... listeners) {
|
||||
this.listeners.addAll(Arrays.asList(listeners));
|
||||
return this;
|
||||
}
|
||||
|
||||
|
||||
private void validateConfig() {
|
||||
Preconditions.checkNotNull(trainingData, "Training data must not be null");
|
||||
Preconditions.checkState(epochs > 0, "Epochs must be > 0, got %s", epochs);
|
||||
|
||||
if (validationData != null)
|
||||
Preconditions.checkState(validationFrequency > 0, "Validation Frequency must be > 0 if validation data is given, got %s", validationFrequency);
|
||||
}
|
||||
|
||||
/**
|
||||
* Do the training.
|
||||
*
|
||||
* @return a {@link History} object containing the history information for this training operation
|
||||
* (evaluations specified in the {@link TrainingConfig}, loss values, and timing information).
|
||||
*/
|
||||
public History exec() {
|
||||
validateConfig();
|
||||
|
||||
return sd.fit(trainingData, epochs, validationData, validationFrequency, listeners.toArray(new Listener[0]));
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,154 @@
|
|||
/*
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
package org.nd4j.autodiff.samediff.config;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import lombok.AccessLevel;
|
||||
import lombok.Getter;
|
||||
import lombok.NonNull;
|
||||
import lombok.Setter;
|
||||
import org.nd4j.autodiff.listeners.Listener;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.util.TrainingUtils;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
||||
|
||||
/**
|
||||
* Configuration for a {@link SameDiff} inference operation.
|
||||
*
|
||||
* Used in {@link SameDiff#output()}.
|
||||
*/
|
||||
@Getter
|
||||
@Setter
|
||||
public class OutputConfig {
|
||||
|
||||
@Setter(AccessLevel.NONE)
|
||||
private SameDiff sd;
|
||||
|
||||
@NonNull
|
||||
private List<String> outputs = new ArrayList<>();
|
||||
|
||||
@NonNull
|
||||
private List<Listener> listeners = new ArrayList<>();
|
||||
|
||||
private MultiDataSetIterator data;
|
||||
|
||||
public OutputConfig(@NonNull SameDiff sd) {
|
||||
this.sd = sd;
|
||||
}
|
||||
|
||||
/**
|
||||
* Add required outputs
|
||||
*/
|
||||
public OutputConfig output(@NonNull String... outputs) {
|
||||
this.outputs.addAll(Arrays.asList(outputs));
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Add required outputs
|
||||
*/
|
||||
public OutputConfig output(@NonNull SDVariable... outputs) {
|
||||
String[] outNames = new String[outputs.length];
|
||||
for (int i = 0; i < outputs.length; i++) {
|
||||
outNames[i] = outputs[i].getVarName();
|
||||
}
|
||||
|
||||
return output(outNames);
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the data to use as input.
|
||||
*/
|
||||
public OutputConfig data(@NonNull MultiDataSetIterator data) {
|
||||
this.data = data;
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the data to use as input.
|
||||
*/
|
||||
public OutputConfig data(@NonNull DataSetIterator data) {
|
||||
this.data = new MultiDataSetIteratorAdapter(data);
|
||||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* Add listeners for this operation
|
||||
*/
|
||||
public OutputConfig listeners(@NonNull Listener... listeners) {
|
||||
this.listeners.addAll(Arrays.asList(listeners));
|
||||
return this;
|
||||
}
|
||||
|
||||
private void validateConfig() {
|
||||
Preconditions.checkNotNull(data, "Must specify data. It may not be null.");
|
||||
}
|
||||
|
||||
/**
|
||||
* Do inference and return the results.
|
||||
*
|
||||
* Uses concatenation on the outputs of {@link #execBatches()} which may cause issues with some inputs. RNNs with
|
||||
* variable time series length and CNNs with variable image sizes will most likely have issues.
|
||||
*/
|
||||
public Map<String, INDArray> exec() {
|
||||
return sd.output(data, listeners, outputs.toArray(new String[0]));
|
||||
}
|
||||
|
||||
/**
|
||||
* Do inference and return the results in batches.
|
||||
*/
|
||||
public List<Map<String, INDArray>> execBatches() {
|
||||
return sd.outputBatches(data, listeners, outputs.toArray(new String[0]));
|
||||
}
|
||||
|
||||
/**
|
||||
* Do inference and return the results for the single output variable specified.
|
||||
*
|
||||
* Only works if exactly one output is specified.
|
||||
*
|
||||
* Uses concatenation on the outputs of {@link #execBatches()} which may cause issues with some inputs. RNNs with
|
||||
* variable time series length and CNNs with variable image sizes will most likely have issues.
|
||||
*/
|
||||
public INDArray execSingle() {
|
||||
Preconditions.checkState(outputs.size() == 1,
|
||||
"Can only use execSingle() when exactly one output is specified, there were %s", outputs.size());
|
||||
|
||||
return sd.output(data, listeners, outputs.toArray(new String[0])).get(outputs.get(0));
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Do inference and return the results (in batches) for the single output variable specified.
|
||||
*
|
||||
* Only works if exactly one output is specified.
|
||||
*/
|
||||
public List<INDArray> execSingleBatches() {
|
||||
Preconditions.checkState(outputs.size() == 1,
|
||||
"Can only use execSingleBatches() when exactly one output is specified, there were %s", outputs.size());
|
||||
|
||||
return TrainingUtils
|
||||
.getSingleOutput(sd.outputBatches(data, listeners, outputs.toArray(new String[0])), outputs.get(0));
|
||||
}
|
||||
}
|
|
@ -24,6 +24,7 @@ import lombok.extern.slf4j.Slf4j;
|
|||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
import org.nd4j.autodiff.listeners.At;
|
||||
import org.nd4j.autodiff.listeners.Listener;
|
||||
import org.nd4j.autodiff.listeners.Operation;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.VariableType;
|
||||
|
@ -31,6 +32,7 @@ import org.nd4j.base.Preconditions;
|
|||
import org.nd4j.linalg.api.ops.impl.controlflow.compat.*;
|
||||
|
||||
import java.util.*;
|
||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
/**
|
||||
|
@ -133,16 +135,42 @@ public abstract class AbstractSession<T, O> {
|
|||
return newVarId(variable, frameIter.getFrame(), frameIter.getIteration(), frameIter.getParentFrame());
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated Use {@link #output(List, Map, MultiDataSet, Collection, List, At)}.
|
||||
*
|
||||
* @param training Uses Operation.TRAINING if true, otherwise Operation.INFERENCE
|
||||
*/
|
||||
@Deprecated
|
||||
public Map<String, T> output(@NonNull List<String> variables, Map<String, T> placeholderValues,
|
||||
MultiDataSet batch, Collection<String> requiredActivations, boolean training, At at){
|
||||
if(at == null){
|
||||
if(training)
|
||||
at = At.defaultAt(Operation.TRAINING);
|
||||
else
|
||||
at = At.defaultAt(Operation.INFERENCE);
|
||||
}
|
||||
return output(variables, placeholderValues, batch, requiredActivations, Collections.<Listener>emptyList(), at);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the output of the session - i.e., perform inference/forward pass
|
||||
*
|
||||
* @param variables Name of the variables we want the arrays/activations for
|
||||
* @param placeholderValues The placeholder values (if any).
|
||||
* @param batch The batch data, used to call Listener.opExecution
|
||||
* @param requiredActivations Additional activations that are required. Won't be outputed, but opExecution will be called. May be null.
|
||||
* @return The specified variable values, optionally in the specified workspace
|
||||
*/
|
||||
public Map<String, T> output(@NonNull List<String> variables, Map<String, T> placeholderValues, List<Listener> listeners, boolean training, At at) {
|
||||
public Map<String, T> output(@NonNull List<String> variables, Map<String, T> placeholderValues,
|
||||
MultiDataSet batch, Collection<String> requiredActivations, List<Listener> listeners, At at) {
|
||||
|
||||
Preconditions.checkState(!variables.isEmpty(), "Variables to perform forward pass for must not be empty");
|
||||
|
||||
if(requiredActivations == null)
|
||||
requiredActivations = Collections.emptyList();
|
||||
|
||||
if(at == null)
|
||||
at = At.defaultAt();
|
||||
|
||||
//Step 0: validation - that variables exist, placeholders have arrays, etc
|
||||
for (String s : variables) {
|
||||
|
@ -164,7 +192,9 @@ public abstract class AbstractSession<T, O> {
|
|||
//Step 1: determine subgraph structure we actually need to execute
|
||||
//Basic plan: work backwards from the variables we want, based on the graph structure, to work out what
|
||||
// we actually need to execute
|
||||
initSubgraph(variables);
|
||||
List<String> allRequired = new ArrayList<>(requiredActivations);
|
||||
allRequired.addAll(variables);
|
||||
initSubgraph(allRequired);
|
||||
|
||||
//Step 1a: Check that we have required placeholders
|
||||
List<String> phNames = sameDiff.inputs();
|
||||
|
@ -198,7 +228,7 @@ public abstract class AbstractSession<T, O> {
|
|||
// Some Keras layers (like GRU) do different things depending on whether the model is training.
|
||||
// We provide this value directly.
|
||||
if(s.endsWith("keras_learning_phase")){
|
||||
placeholderValues.put(s, (T) Nd4j.scalar(training));
|
||||
placeholderValues.put(s, (T) Nd4j.scalar(at.operation().isTrainingPhase()));
|
||||
} else {
|
||||
throw new IllegalStateException(
|
||||
"An input placeholder \"" + s + "\" is required to calculate the requested outputs," +
|
||||
|
@ -302,7 +332,7 @@ public abstract class AbstractSession<T, O> {
|
|||
//Execute op
|
||||
FrameIter frameIter = varToExec.toFrameIter();
|
||||
O parameterizedOp = getAndParameterizeOp(opName, frameIter, inputsToVar, inputsToVarAllIter, constPhForVar, placeholderValues);
|
||||
T[] opOutputValues = getOutputs(parameterizedOp, frameIter, inputsToVar, inputsToVarAllIter, constPhForVar, listeners, training, at);
|
||||
T[] opOutputValues = getOutputs(parameterizedOp, frameIter, inputsToVar, inputsToVarAllIter, constPhForVar, listeners, at, batch);
|
||||
|
||||
|
||||
//Post execution: work out what is now available for exec
|
||||
|
@ -831,7 +861,7 @@ public abstract class AbstractSession<T, O> {
|
|||
* @return The outputs of the op
|
||||
*/
|
||||
public abstract T[] getOutputs(O op, FrameIter outputFrameIter, Set<VarId> inputs, Set<VarId> allIterInputs, Set<String> constAndPhInputs,
|
||||
List<Listener> listeners, boolean training, At at);
|
||||
List<Listener> listeners, At at, MultiDataSet batch);
|
||||
|
||||
/**
|
||||
* This method is used to record that the specified input is required for calculating the specified output.
|
||||
|
|
|
@ -30,6 +30,7 @@ import java.util.ArrayList;
|
|||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||
|
||||
/**
|
||||
* Infer datatypes for all variables.
|
||||
|
@ -80,7 +81,7 @@ public class DataTypesSession extends AbstractSession<DataType, DataTypesSession
|
|||
|
||||
@Override
|
||||
public DataType[] getOutputs(DataTypeCalc op, FrameIter outputFrameIter, Set<VarId> inputs, Set<VarId> allIterInputs,
|
||||
Set<String> constAndPhInputs, List<Listener> listeners, boolean training, At at) {
|
||||
Set<String> constAndPhInputs, List<Listener> listeners, At at, MultiDataSet batch) {
|
||||
List<DataType> outTypes = op.getFn().calculateOutputDataTypes(op.getInputTypes());
|
||||
|
||||
if(dynamicUpdate) {
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
package org.nd4j.autodiff.samediff.internal;
|
||||
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import lombok.NonNull;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
||||
|
@ -38,6 +39,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
|
|||
import org.nd4j.linalg.api.ops.impl.transforms.same.Identity;
|
||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
||||
import org.nd4j.linalg.api.shape.Shape;
|
||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.indexing.INDArrayIndex;
|
||||
|
@ -106,19 +108,34 @@ public class InferenceSession extends AbstractSession<INDArray,DifferentialFunct
|
|||
|
||||
@Override
|
||||
public INDArray[] getOutputs(DifferentialFunction op, FrameIter outputFrameIter, Set<VarId> opInputs, Set<VarId> allIterInputs,
|
||||
Set<String> constAndPhInputs, List<Listener> listeners, boolean training, At at) {
|
||||
Set<String> constAndPhInputs, List<Listener> listeners, At at, MultiDataSet batch) {
|
||||
if(listeners != null && listeners.size() > 0){
|
||||
SameDiffOp sdOp = sameDiff.getOps().get(op.getOwnName());
|
||||
for(Listener l : listeners){
|
||||
l.preOpExecution(sameDiff, at, training, sdOp);
|
||||
if(l.isActive(at.operation()))
|
||||
l.preOpExecution(sameDiff, at, sdOp);
|
||||
}
|
||||
}
|
||||
|
||||
INDArray[] out = getOutputsHelper(op, outputFrameIter, opInputs, allIterInputs, constAndPhInputs);
|
||||
if(listeners != null && listeners.size() > 0){
|
||||
SameDiffOp sdOp = sameDiff.getOps().get(op.getOwnName());
|
||||
|
||||
ImmutableMap.Builder<String, INDArray> namedOutsBuilder = ImmutableMap.builder();
|
||||
|
||||
for(int i = 0 ; i < out.length ; i++)
|
||||
namedOutsBuilder.put(sdOp.outputsOfOp.get(i), out[i]);
|
||||
|
||||
Map<String, INDArray> namedOuts = namedOutsBuilder.build();
|
||||
|
||||
for(Listener l : listeners){
|
||||
l.opExecution(sameDiff, at, training, sdOp, out);
|
||||
if(l.isActive(at.operation())) {
|
||||
l.opExecution(sameDiff, at, batch, sdOp, out);
|
||||
|
||||
for(String varName : namedOuts.keySet()){
|
||||
l.activationAvailable(sameDiff, at, batch, sdOp, varName, namedOuts.get(varName));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return out;
|
||||
|
|
|
@ -5,12 +5,14 @@ import lombok.NoArgsConstructor;
|
|||
import lombok.Setter;
|
||||
import org.nd4j.autodiff.listeners.At;
|
||||
import org.nd4j.autodiff.listeners.BaseListener;
|
||||
import org.nd4j.autodiff.listeners.Operation;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
import java.util.List;
|
||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||
|
||||
/**
|
||||
* A listener used for debugging and testing purposes - specifically for gradient checking activations internally in
|
||||
|
@ -29,7 +31,12 @@ public class ActivationGradientCheckListener extends BaseListener {
|
|||
private double eps;
|
||||
|
||||
@Override
|
||||
public void opExecution(SameDiff sd, At at, boolean training, SameDiffOp op, INDArray[] outputs) {
|
||||
public boolean isActive(Operation operation) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
|
||||
Preconditions.checkState(variableName != null, "No variable name has been set yet. Variable name must be set before using this listener");
|
||||
Preconditions.checkState(eps != 0.0, "Epsilon has not been set");
|
||||
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
package org.nd4j.autodiff.validation.listeners;
|
||||
|
||||
import lombok.Getter;
|
||||
import org.bytedeco.javacpp.Pointer;
|
||||
import org.nd4j.autodiff.listeners.At;
|
||||
import org.nd4j.autodiff.listeners.BaseListener;
|
||||
import org.nd4j.autodiff.listeners.Operation;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
|
||||
import org.nd4j.base.Preconditions;
|
||||
|
@ -14,6 +14,7 @@ import org.nd4j.linalg.api.ops.Op;
|
|||
import java.security.MessageDigest;
|
||||
import java.util.Arrays;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||
|
||||
public class NonInplaceValidationListener extends BaseListener {
|
||||
@Getter
|
||||
|
@ -30,7 +31,7 @@ public class NonInplaceValidationListener extends BaseListener {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void preOpExecution(SameDiff sd, At at, boolean training, SameDiffOp op) {
|
||||
public void preOpExecution(SameDiff sd, At at, SameDiffOp op) {
|
||||
if(op.getOp().isInPlace()){
|
||||
//Don't check inplace op
|
||||
return;
|
||||
|
@ -57,7 +58,7 @@ public class NonInplaceValidationListener extends BaseListener {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void opExecution(SameDiff sd, At at, boolean training, SameDiffOp op, INDArray[] outputs) {
|
||||
public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
|
||||
if(op.getOp().isInPlace()){
|
||||
//Don't check inplace op
|
||||
return;
|
||||
|
@ -124,4 +125,8 @@ public class NonInplaceValidationListener extends BaseListener {
|
|||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isActive(Operation operation) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -99,4 +99,14 @@ public interface IEvaluation<T extends IEvaluation> extends Serializable {
|
|||
* @return
|
||||
*/
|
||||
String toYaml();
|
||||
|
||||
/**
|
||||
* Get the value of a given metric for this evaluation.
|
||||
*/
|
||||
double getValue(IMetric metric);
|
||||
|
||||
/**
|
||||
* Get a new instance of this evaluation, with the same configuration but no data.
|
||||
*/
|
||||
T newInstance();
|
||||
}
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
/*
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
package org.nd4j.evaluation;
|
||||
|
||||
/**
|
||||
* A metric used to get a double value from an {@link IEvaluation}.
|
||||
*
|
||||
* Examples: {@link org.nd4j.evaluation.classification.Evaluation.Metric#ACCURACY}, {@link org.nd4j.evaluation.classification.ROC.Metric#AUPRC}.
|
||||
*/
|
||||
public interface IMetric {
|
||||
|
||||
/**
|
||||
* The {@link IEvaluation} class this metric is for
|
||||
*/
|
||||
public Class<? extends IEvaluation> getEvaluationClass();
|
||||
|
||||
/**
|
||||
* Whether this metric should be minimized (aka whether lower values are better).
|
||||
*/
|
||||
public boolean minimize();
|
||||
}
|
|
@ -22,7 +22,11 @@ import org.nd4j.base.Preconditions;
|
|||
import org.nd4j.evaluation.BaseEvaluation;
|
||||
import org.nd4j.evaluation.EvaluationAveraging;
|
||||
import org.nd4j.evaluation.EvaluationUtils;
|
||||
import org.nd4j.evaluation.IEvaluation;
|
||||
import org.nd4j.evaluation.IMetric;
|
||||
import org.nd4j.evaluation.meta.Prediction;
|
||||
import org.nd4j.evaluation.regression.RegressionEvaluation;
|
||||
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
|
||||
import org.nd4j.evaluation.serde.ConfusionMatrixDeserializer;
|
||||
import org.nd4j.evaluation.serde.ConfusionMatrixSerializer;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
|
@ -83,7 +87,18 @@ import java.util.*;
|
|||
@JsonIgnoreProperties({"confusionMatrixMetaData"})
|
||||
public class Evaluation extends BaseEvaluation<Evaluation> {
|
||||
|
||||
public enum Metric {ACCURACY, F1, PRECISION, RECALL, GMEASURE, MCC}
|
||||
public enum Metric implements IMetric {ACCURACY, F1, PRECISION, RECALL, GMEASURE, MCC;
|
||||
|
||||
@Override
|
||||
public Class<? extends IEvaluation> getEvaluationClass() {
|
||||
return Evaluation.class;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean minimize() {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
//What to output from the precision/recall function when we encounter an edge case
|
||||
protected static final double DEFAULT_EDGE_VALUE = 0.0;
|
||||
|
@ -122,6 +137,17 @@ public class Evaluation extends BaseEvaluation<Evaluation> {
|
|||
@Getter @Setter
|
||||
protected int maxWarningClassesToPrint = 16;
|
||||
|
||||
protected Evaluation(int axis, Integer binaryPositiveClass, int topN, List<String> labelsList,
|
||||
Double binaryDecisionThreshold, INDArray costArray, int maxWarningClassesToPrint){
|
||||
this.axis = axis;
|
||||
this.binaryPositiveClass = binaryPositiveClass;
|
||||
this.topN = topN;
|
||||
this.labelsList = labelsList;
|
||||
this.binaryDecisionThreshold = binaryDecisionThreshold;
|
||||
this.costArray = costArray;
|
||||
this.maxWarningClassesToPrint = maxWarningClassesToPrint;
|
||||
}
|
||||
|
||||
// Empty constructor
|
||||
public Evaluation() {
|
||||
this.topN = 1;
|
||||
|
@ -190,6 +216,7 @@ public class Evaluation extends BaseEvaluation<Evaluation> {
|
|||
if (labels != null) {
|
||||
createConfusion(labels.size());
|
||||
}
|
||||
|
||||
this.topN = topN;
|
||||
if(labels != null && labels.size() == 2){
|
||||
this.binaryPositiveClass = 1;
|
||||
|
@ -1869,4 +1896,17 @@ public class Evaluation extends BaseEvaluation<Evaluation> {
|
|||
public static Evaluation fromYaml(String yaml) {
|
||||
return fromYaml(yaml, Evaluation.class);
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getValue(IMetric metric){
|
||||
if(metric instanceof Metric){
|
||||
return scoreForMetric((Metric) metric);
|
||||
} else
|
||||
throw new IllegalStateException("Can't get value for non-evaluation Metric " + metric);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Evaluation newInstance() {
|
||||
return new Evaluation(axis, binaryPositiveClass, topN, labelsList, binaryDecisionThreshold, costArray, maxWarningClassesToPrint);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,6 +21,9 @@ import lombok.EqualsAndHashCode;
|
|||
import lombok.NoArgsConstructor;
|
||||
import org.nd4j.evaluation.BaseEvaluation;
|
||||
import org.nd4j.evaluation.EvaluationUtils;
|
||||
import org.nd4j.evaluation.IEvaluation;
|
||||
import org.nd4j.evaluation.IMetric;
|
||||
import org.nd4j.evaluation.classification.Evaluation.Metric;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThan;
|
||||
|
@ -58,7 +61,18 @@ import java.util.List;
|
|||
@Data
|
||||
public class EvaluationBinary extends BaseEvaluation<EvaluationBinary> {
|
||||
|
||||
public enum Metric {ACCURACY, F1, PRECISION, RECALL, GMEASURE, MCC, FAR}
|
||||
public enum Metric implements IMetric {ACCURACY, F1, PRECISION, RECALL, GMEASURE, MCC, FAR;
|
||||
|
||||
@Override
|
||||
public Class<? extends IEvaluation> getEvaluationClass() {
|
||||
return EvaluationBinary.class;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean minimize() {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public static final int DEFAULT_PRECISION = 4;
|
||||
public static final double DEFAULT_EDGE_VALUE = 0.0;
|
||||
|
@ -80,6 +94,13 @@ public class EvaluationBinary extends BaseEvaluation<EvaluationBinary> {
|
|||
@JsonDeserialize(using = NDArrayTextDeSerializer.class)
|
||||
private INDArray decisionThreshold;
|
||||
|
||||
protected EvaluationBinary(int axis, ROCBinary rocBinary, List<String> labels, INDArray decisionThreshold){
|
||||
this.axis = axis;
|
||||
this.rocBinary = rocBinary;
|
||||
this.labels = labels;
|
||||
this.decisionThreshold = decisionThreshold;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create an EvaulationBinary instance with an optional decision threshold array.
|
||||
*
|
||||
|
@ -452,10 +473,25 @@ public class EvaluationBinary extends BaseEvaluation<EvaluationBinary> {
|
|||
}
|
||||
|
||||
/**
|
||||
* Calculate the G-measure for the given output
|
||||
* Macro average of the Matthews correlation coefficient (MCC) (see {@link #matthewsCorrelation(int)}) for all labels.
|
||||
*
|
||||
* @return The macro average of the MCC for all labels.
|
||||
*/
|
||||
public double averageMatthewsCorrelation() {
|
||||
double ret = 0.0;
|
||||
for (int i = 0; i < numLabels(); i++) {
|
||||
ret += matthewsCorrelation(i);
|
||||
}
|
||||
|
||||
ret /= (double) numLabels();
|
||||
return ret;
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate the macro average G-measure for the given output
|
||||
*
|
||||
* @param output The specified output
|
||||
* @return The G-measure for the specified output
|
||||
* @return The macro average of the G-measure for the specified output
|
||||
*/
|
||||
public double gMeasure(int output) {
|
||||
double precision = precision(output);
|
||||
|
@ -463,6 +499,21 @@ public class EvaluationBinary extends BaseEvaluation<EvaluationBinary> {
|
|||
return EvaluationUtils.gMeasure(precision, recall);
|
||||
}
|
||||
|
||||
/**
|
||||
* Average G-measure (see {@link #gMeasure(int)}) for all labels.
|
||||
*
|
||||
* @return The G-measure for all labels.
|
||||
*/
|
||||
public double averageGMeasure() {
|
||||
double ret = 0.0;
|
||||
for (int i = 0; i < numLabels(); i++) {
|
||||
ret += gMeasure(i);
|
||||
}
|
||||
|
||||
ret /= (double) numLabels();
|
||||
return ret;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the false positive rate for a given label
|
||||
*
|
||||
|
@ -679,5 +730,37 @@ public class EvaluationBinary extends BaseEvaluation<EvaluationBinary> {
|
|||
return fromYaml(yaml, EvaluationBinary.class);
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getValue(IMetric metric){
|
||||
if(metric instanceof Metric){
|
||||
switch ((Metric) metric){
|
||||
case ACCURACY:
|
||||
return averageAccuracy();
|
||||
case F1:
|
||||
return averageF1();
|
||||
case PRECISION:
|
||||
return averagePrecision();
|
||||
case RECALL:
|
||||
return averageRecall();
|
||||
case GMEASURE:
|
||||
return averageGMeasure();
|
||||
case MCC:
|
||||
return averageMatthewsCorrelation();
|
||||
case FAR:
|
||||
return averageFalseAlarmRate();
|
||||
default:
|
||||
throw new IllegalStateException("Can't get value for non-binary evaluation Metric " + metric);
|
||||
}
|
||||
} else
|
||||
throw new IllegalStateException("Can't get value for non-binary evaluation Metric " + metric);
|
||||
}
|
||||
|
||||
@Override
|
||||
public EvaluationBinary newInstance() {
|
||||
if(rocBinary != null) {
|
||||
return new EvaluationBinary(axis, rocBinary.newInstance(), labels, decisionThreshold);
|
||||
} else {
|
||||
return new EvaluationBinary(axis, null, labels, decisionThreshold);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,6 +21,8 @@ import lombok.Getter;
|
|||
import lombok.val;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.evaluation.BaseEvaluation;
|
||||
import org.nd4j.evaluation.IMetric;
|
||||
import org.nd4j.evaluation.classification.Evaluation.Metric;
|
||||
import org.nd4j.evaluation.curves.Histogram;
|
||||
import org.nd4j.evaluation.curves.ReliabilityDiagram;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
|
@ -105,6 +107,13 @@ public class EvaluationCalibration extends BaseEvaluation<EvaluationCalibration>
|
|||
@JsonDeserialize(using = NDArrayDeSerializer.class)
|
||||
private INDArray probHistogramByLabelClass; //Histogram - for each label class separately
|
||||
|
||||
protected EvaluationCalibration(int axis, int reliabilityDiagNumBins, int histogramNumBins, boolean excludeEmptyBins) {
|
||||
this.axis = axis;
|
||||
this.reliabilityDiagNumBins = reliabilityDiagNumBins;
|
||||
this.histogramNumBins = histogramNumBins;
|
||||
this.excludeEmptyBins = excludeEmptyBins;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create an EvaluationCalibration instance with the default number of bins
|
||||
*/
|
||||
|
@ -476,4 +485,14 @@ public class EvaluationCalibration extends BaseEvaluation<EvaluationCalibration>
|
|||
public static EvaluationCalibration fromJson(String json){
|
||||
return fromJson(json, EvaluationCalibration.class);
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getValue(IMetric metric){
|
||||
throw new IllegalStateException("Can't get value for non-calibration Metric " + metric);
|
||||
}
|
||||
|
||||
@Override
|
||||
public EvaluationCalibration newInstance() {
|
||||
return new EvaluationCalibration(axis, reliabilityDiagNumBins, histogramNumBins, excludeEmptyBins);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,6 +20,9 @@ import lombok.*;
|
|||
import org.apache.commons.lang3.ArrayUtils;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.evaluation.BaseEvaluation;
|
||||
import org.nd4j.evaluation.IEvaluation;
|
||||
import org.nd4j.evaluation.IMetric;
|
||||
import org.nd4j.evaluation.classification.Evaluation.Metric;
|
||||
import org.nd4j.evaluation.curves.PrecisionRecallCurve;
|
||||
import org.nd4j.evaluation.curves.RocCurve;
|
||||
import org.nd4j.evaluation.serde.ROCSerializer;
|
||||
|
@ -77,11 +80,24 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
|
|||
@JsonSerialize(using = ROCSerializer.class)
|
||||
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY)
|
||||
public class ROC extends BaseEvaluation<ROC> {
|
||||
|
||||
/**
|
||||
* AUROC: Area under ROC curve<br>
|
||||
* AUPRC: Area under Precision-Recall Curve
|
||||
*/
|
||||
public enum Metric {AUROC, AUPRC}
|
||||
public enum Metric implements IMetric {
|
||||
AUROC, AUPRC;
|
||||
|
||||
@Override
|
||||
public Class<? extends IEvaluation> getEvaluationClass() {
|
||||
return ROC.class;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean minimize() {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
private static final int DEFAULT_EXACT_ALLOC_BLOCK_SIZE = 2048;
|
||||
private final Map<Double, CountsForThreshold> counts = new LinkedHashMap<>();
|
||||
|
@ -100,6 +116,13 @@ public class ROC extends BaseEvaluation<ROC> {
|
|||
private int exactAllocBlockSize;
|
||||
protected int axis = 1;
|
||||
|
||||
|
||||
|
||||
public ROC(int thresholdSteps, boolean rocRemoveRedundantPts, int exactAllocBlockSize, int axis) {
|
||||
this(thresholdSteps, rocRemoveRedundantPts, exactAllocBlockSize);
|
||||
this.axis = axis;
|
||||
}
|
||||
|
||||
public ROC() {
|
||||
//Default to exact
|
||||
this(0);
|
||||
|
@ -811,4 +834,17 @@ public class ROC extends BaseEvaluation<ROC> {
|
|||
throw new IllegalStateException("Unknown metric: " + metric);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getValue(IMetric metric){
|
||||
if(metric instanceof Metric){
|
||||
return scoreForMetric((Metric) metric);
|
||||
} else
|
||||
throw new IllegalStateException("Can't get value for non-ROC Metric " + metric);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ROC newInstance() {
|
||||
return new ROC(thresholdSteps, rocRemoveRedundantPts, exactAllocBlockSize, axis);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,6 +21,9 @@ import lombok.EqualsAndHashCode;
|
|||
import lombok.val;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.evaluation.BaseEvaluation;
|
||||
import org.nd4j.evaluation.IEvaluation;
|
||||
import org.nd4j.evaluation.IMetric;
|
||||
import org.nd4j.evaluation.classification.ROCMultiClass.Metric;
|
||||
import org.nd4j.evaluation.curves.PrecisionRecallCurve;
|
||||
import org.nd4j.evaluation.curves.RocCurve;
|
||||
import org.nd4j.evaluation.serde.ROCArraySerializer;
|
||||
|
@ -53,7 +56,18 @@ public class ROCBinary extends BaseEvaluation<ROCBinary> {
|
|||
* AUROC: Area under ROC curve<br>
|
||||
* AUPRC: Area under Precision-Recall Curve
|
||||
*/
|
||||
public enum Metric {AUROC, AUPRC}
|
||||
public enum Metric implements IMetric {AUROC, AUPRC;
|
||||
|
||||
@Override
|
||||
public Class<? extends IEvaluation> getEvaluationClass() {
|
||||
return ROCBinary.class;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean minimize() {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@JsonSerialize(using = ROCArraySerializer.class)
|
||||
private ROC[] underlying;
|
||||
|
@ -65,6 +79,13 @@ public class ROCBinary extends BaseEvaluation<ROCBinary> {
|
|||
@EqualsAndHashCode.Exclude //Exclude axis: otherwise 2 Evaluation instances could contain identical stats and fail equality
|
||||
protected int axis = 1;
|
||||
|
||||
protected ROCBinary(int axis, int thresholdSteps, boolean rocRemoveRedundantPts, List<String> labels) {
|
||||
this.thresholdSteps = thresholdSteps;
|
||||
this.rocRemoveRedundantPts = rocRemoveRedundantPts;
|
||||
this.axis = axis;
|
||||
this.labels = labels;
|
||||
}
|
||||
|
||||
public ROCBinary() {
|
||||
this(0);
|
||||
}
|
||||
|
@ -410,4 +431,22 @@ public class ROCBinary extends BaseEvaluation<ROCBinary> {
|
|||
throw new IllegalStateException("Unknown metric: " + metric);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getValue(IMetric metric){
|
||||
if(metric instanceof Metric){
|
||||
if(metric == Metric.AUPRC)
|
||||
return calculateAverageAUCPR();
|
||||
else if(metric == Metric.AUROC)
|
||||
return calculateAverageAuc();
|
||||
else
|
||||
throw new IllegalStateException("Can't get value for non-binary ROC Metric " + metric);
|
||||
} else
|
||||
throw new IllegalStateException("Can't get value for non-binary ROC Metric " + metric);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ROCBinary newInstance() {
|
||||
return new ROCBinary(axis, thresholdSteps, rocRemoveRedundantPts, labels);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,6 +20,9 @@ import lombok.Data;
|
|||
import lombok.EqualsAndHashCode;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.evaluation.BaseEvaluation;
|
||||
import org.nd4j.evaluation.IEvaluation;
|
||||
import org.nd4j.evaluation.IMetric;
|
||||
import org.nd4j.evaluation.classification.ROC.Metric;
|
||||
import org.nd4j.evaluation.curves.PrecisionRecallCurve;
|
||||
import org.nd4j.evaluation.curves.RocCurve;
|
||||
import org.nd4j.evaluation.serde.ROCArraySerializer;
|
||||
|
@ -49,7 +52,18 @@ public class ROCMultiClass extends BaseEvaluation<ROCMultiClass> {
|
|||
* AUROC: Area under ROC curve<br>
|
||||
* AUPRC: Area under Precision-Recall Curve
|
||||
*/
|
||||
public enum Metric {AUROC, AUPRC}
|
||||
public enum Metric implements IMetric {AUROC, AUPRC;
|
||||
|
||||
@Override
|
||||
public Class<? extends IEvaluation> getEvaluationClass() {
|
||||
return ROCMultiClass.class;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean minimize() {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
private int thresholdSteps;
|
||||
private boolean rocRemoveRedundantPts;
|
||||
|
@ -60,6 +74,13 @@ public class ROCMultiClass extends BaseEvaluation<ROCMultiClass> {
|
|||
@EqualsAndHashCode.Exclude //Exclude axis: otherwise 2 Evaluation instances could contain identical stats and fail equality
|
||||
protected int axis = 1;
|
||||
|
||||
protected ROCMultiClass(int axis, int thresholdSteps, boolean rocRemoveRedundantPts, List<String> labels) {
|
||||
this.thresholdSteps = thresholdSteps;
|
||||
this.rocRemoveRedundantPts = rocRemoveRedundantPts;
|
||||
this.axis = axis;
|
||||
this.labels = labels;
|
||||
}
|
||||
|
||||
public ROCMultiClass() {
|
||||
//Default to exact
|
||||
this(0);
|
||||
|
@ -362,4 +383,22 @@ public class ROCMultiClass extends BaseEvaluation<ROCMultiClass> {
|
|||
throw new IllegalStateException("Unknown metric: " + metric);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getValue(IMetric metric){
|
||||
if(metric instanceof Metric){
|
||||
if(metric == Metric.AUPRC)
|
||||
return calculateAverageAUCPR();
|
||||
else if(metric == Metric.AUROC)
|
||||
return calculateAverageAUC();
|
||||
else
|
||||
throw new IllegalStateException("Can't get value for non-ROC Metric " + metric);
|
||||
} else
|
||||
throw new IllegalStateException("Can't get value for non-ROC Metric " + metric);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ROCMultiClass newInstance() {
|
||||
return new ROCMultiClass(axis, thresholdSteps, rocRemoveRedundantPts, labels);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,182 @@
|
|||
/*
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
package org.nd4j.evaluation.custom;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import java.io.Serializable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
import lombok.Getter;
|
||||
import lombok.NonNull;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.nd4j.evaluation.BaseEvaluation;
|
||||
import org.nd4j.evaluation.IEvaluation;
|
||||
import org.nd4j.evaluation.IMetric;
|
||||
import org.nd4j.evaluation.regression.RegressionEvaluation;
|
||||
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
/**
|
||||
* A evaluation using lambdas to calculate the score.
|
||||
*
|
||||
* Uses 3 lambdas:<br>
|
||||
* EvaluationLambda: takes in the labels, predictions, mask, and metadata and returns a value of type T<br>
|
||||
* MergeLambda: takes in two lists of Ts, returns one. Used in merging for distributed training<br>
|
||||
* ResultLambda (in Metric): takes a list of Ts, returns a double value<br>
|
||||
* <br>
|
||||
* The EvaluationLambda will be called on each batch, and the results will be stored in a list.
|
||||
* MergeLambda merges two of those lists for distributed training (think Spark or Map-Reduce).
|
||||
* ResultLambda gets a score from that list.
|
||||
*
|
||||
*/
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
public class CustomEvaluation<T> extends BaseEvaluation<CustomEvaluation> {
|
||||
|
||||
/**
|
||||
* The metric used to get a score for the CustomEvaluation. Uses a ResultLambda
|
||||
*/
|
||||
@AllArgsConstructor
|
||||
@RequiredArgsConstructor
|
||||
public static class Metric<T> implements IMetric{
|
||||
|
||||
@Getter
|
||||
@NonNull private ResultLambda<T> getResult;
|
||||
|
||||
private boolean minimize = false;
|
||||
|
||||
@Override
|
||||
public Class<? extends IEvaluation> getEvaluationClass() {
|
||||
return CustomEvaluation.class;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean minimize() {
|
||||
return minimize;
|
||||
}
|
||||
|
||||
/**
|
||||
* A metric that takes the average of a list of doubles
|
||||
*/
|
||||
public static Metric<Double> doubleAverage(boolean minimize){
|
||||
return new Metric<>(new ResultLambda<Double>() {
|
||||
@Override
|
||||
public double toResult(List<Double> data) {
|
||||
int count = 0;
|
||||
double sum = 0;
|
||||
for (Double d : data) {
|
||||
count++;
|
||||
sum += d;
|
||||
}
|
||||
return sum / count;
|
||||
}
|
||||
}, minimize);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* A metric that takes the max of a list of doubles
|
||||
*/
|
||||
public static Metric<Double> doubleMax(boolean minimize){
|
||||
return new Metric<>(new ResultLambda<Double>() {
|
||||
@Override
|
||||
public double toResult(List<Double> data) {
|
||||
double max = 0;
|
||||
for (Double d : data) {
|
||||
if(d > max)
|
||||
max = d;
|
||||
}
|
||||
return max;
|
||||
}
|
||||
}, minimize);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* A metric that takes the min of a list of doubles
|
||||
*/
|
||||
public static Metric<Double> doubleMin(boolean minimize){
|
||||
return new Metric<>(new ResultLambda<Double>() {
|
||||
@Override
|
||||
public double toResult(List<Double> data) {
|
||||
double max = 0;
|
||||
for (Double d : data) {
|
||||
if(d < max)
|
||||
max = d;
|
||||
}
|
||||
return max;
|
||||
}
|
||||
}, minimize);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A MergeLambda that merges by concatenating the two lists
|
||||
*/
|
||||
public static <R> MergeLambda<R> mergeConcatenate(){
|
||||
return new MergeLambda<R>() {
|
||||
@Override
|
||||
public List<R> merge(List<R> a, List<R> b) {
|
||||
List<R> res = Lists.newArrayList(a);
|
||||
res.addAll(b);
|
||||
return res;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@NonNull private EvaluationLambda<T> evaluationLambda;
|
||||
@NonNull private MergeLambda<T> mergeLambda;
|
||||
|
||||
private List<T> evaluations = new ArrayList<>();
|
||||
|
||||
@Override
|
||||
public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray,
|
||||
List<? extends Serializable> recordMetaData) {
|
||||
evaluations.add(evaluationLambda.eval(labels, networkPredictions, maskArray, recordMetaData));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void merge(CustomEvaluation other) {
|
||||
evaluations = mergeLambda.merge(evaluations, other.evaluations);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void reset() {
|
||||
evaluations = new ArrayList<>();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String stats() {
|
||||
return "";
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getValue(IMetric metric) {
|
||||
if(metric instanceof Metric){
|
||||
return ((Metric<T>) metric).getGetResult().toResult(evaluations);
|
||||
} else
|
||||
throw new IllegalStateException("Can't get value for non-regression Metric " + metric);
|
||||
}
|
||||
|
||||
@Override
|
||||
public CustomEvaluation<T> newInstance() {
|
||||
return new CustomEvaluation<T>(evaluationLambda, mergeLambda);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,31 @@
|
|||
/*
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
package org.nd4j.evaluation.custom;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.List;
|
||||
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
/**
|
||||
* A lambda used to get an evaluation result for a batch
|
||||
* See {@link CustomEvaluation}
|
||||
*/
|
||||
public interface EvaluationLambda<T> {
|
||||
public T eval(INDArray labels, INDArray networkPredictions, INDArray maskArray,
|
||||
List<? extends Serializable> recordMetaData);
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
/*
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
package org.nd4j.evaluation.custom;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* A lambda used to merge two lists of evaluation results
|
||||
* See {@link CustomEvaluation}
|
||||
*/
|
||||
public interface MergeLambda<T> {
|
||||
public List<T> merge(List<T> a, List<T> b);
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
/*
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
package org.nd4j.evaluation.custom;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* A lambda used to get a score from a list of evaluation results
|
||||
* See {@link CustomEvaluation}
|
||||
*/
|
||||
public interface ResultLambda<T> {
|
||||
public double toResult(List<T> data);
|
||||
}
|
|
@ -20,6 +20,8 @@ import lombok.Data;
|
|||
import lombok.EqualsAndHashCode;
|
||||
import lombok.val;
|
||||
import org.nd4j.evaluation.BaseEvaluation;
|
||||
import org.nd4j.evaluation.IEvaluation;
|
||||
import org.nd4j.evaluation.IMetric;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.same.ASum;
|
||||
|
@ -54,12 +56,18 @@ import java.util.List;
|
|||
@EqualsAndHashCode(callSuper = true)
|
||||
public class RegressionEvaluation extends BaseEvaluation<RegressionEvaluation> {
|
||||
|
||||
public enum Metric { MSE, MAE, RMSE, RSE, PC, R2;
|
||||
public enum Metric implements IMetric { MSE, MAE, RMSE, RSE, PC, R2;
|
||||
|
||||
@Override
|
||||
public Class<? extends IEvaluation> getEvaluationClass() {
|
||||
return RegressionEvaluation.class;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return True if the metric should be minimized, or false if the metric should be maximized.
|
||||
* For example, MSE of 0 is best, but R^2 of 1.0 is best
|
||||
*/
|
||||
@Override
|
||||
public boolean minimize(){
|
||||
if(this == R2 || this == PC){
|
||||
return false;
|
||||
|
@ -106,6 +114,12 @@ public class RegressionEvaluation extends BaseEvaluation<RegressionEvaluation> {
|
|||
@JsonDeserialize(using = NDArrayTextDeSerializer.class)
|
||||
private INDArray sumLabels;
|
||||
|
||||
protected RegressionEvaluation(int axis, List<String> columnNames, long precision){
|
||||
this.axis = axis;
|
||||
this.columnNames = columnNames;
|
||||
this.precision = precision;
|
||||
}
|
||||
|
||||
public RegressionEvaluation() {
|
||||
this(null, DEFAULT_PRECISION);
|
||||
}
|
||||
|
@ -568,6 +582,14 @@ public class RegressionEvaluation extends BaseEvaluation<RegressionEvaluation> {
|
|||
return ret / (double) numColumns();
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getValue(IMetric metric){
|
||||
if(metric instanceof Metric){
|
||||
return scoreForMetric((Metric) metric);
|
||||
} else
|
||||
throw new IllegalStateException("Can't get value for non-regression Metric " + metric);
|
||||
}
|
||||
|
||||
public double scoreForMetric(Metric metric){
|
||||
switch (metric){
|
||||
case MSE:
|
||||
|
@ -590,4 +612,9 @@ public class RegressionEvaluation extends BaseEvaluation<RegressionEvaluation> {
|
|||
public static RegressionEvaluation fromJson(String json){
|
||||
return fromJson(json, RegressionEvaluation.class);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RegressionEvaluation newInstance() {
|
||||
return new RegressionEvaluation(axis, columnNames, precision);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -72,7 +72,8 @@ public class TestSessions extends BaseNd4jTest {
|
|||
m.put("x", x);
|
||||
m.put("y", y);
|
||||
|
||||
Map<String,INDArray> outMap = is.output(Collections.singletonList("out"), m, null, true, null);
|
||||
Map<String,INDArray> outMap = is.output(Collections.singletonList("out"), m, null,
|
||||
Collections.<String>emptyList(), true, null);
|
||||
|
||||
assertEquals(1, outMap.size());
|
||||
assertEquals(outExp, outMap.get("out"));
|
||||
|
@ -109,7 +110,8 @@ public class TestSessions extends BaseNd4jTest {
|
|||
m.put("y", y);
|
||||
|
||||
System.out.println("----------------------------------");
|
||||
Map<String,INDArray> outMap = is.output(Collections.singletonList("d"), m, null, false, null);
|
||||
Map<String,INDArray> outMap = is.output(Collections.singletonList("d"), m, null,
|
||||
Collections.<String>emptyList(), false, null);
|
||||
|
||||
assertEquals(1, outMap.size());
|
||||
assertEquals(dExp, outMap.get("d"));
|
||||
|
@ -143,7 +145,8 @@ public class TestSessions extends BaseNd4jTest {
|
|||
InferenceSession is = new InferenceSession(sd);
|
||||
// String outName = merge.getVarName();
|
||||
String outName = outVar.getVarName();
|
||||
Map<String,INDArray> outMap = is.output(Collections.singletonList(outName), m, null, false, null);
|
||||
Map<String,INDArray> outMap = is.output(Collections.singletonList(outName), m, null,
|
||||
Collections.<String>emptyList(), false, null);
|
||||
|
||||
assertEquals(1, outMap.size());
|
||||
INDArray out = outMap.get(outName);
|
||||
|
@ -178,7 +181,8 @@ public class TestSessions extends BaseNd4jTest {
|
|||
String n = merge.getVarName();
|
||||
|
||||
System.out.println("----------------------------------");
|
||||
Map<String,INDArray> outMap = is.output(Collections.singletonList(n), m, null, false, null);
|
||||
Map<String,INDArray> outMap = is.output(Collections.singletonList(n), m, null, Collections.<String>emptyList(),
|
||||
false, null);
|
||||
assertEquals(1, outMap.size());
|
||||
assertEquals(expTrue, outMap.get(n));
|
||||
|
||||
|
@ -187,7 +191,7 @@ public class TestSessions extends BaseNd4jTest {
|
|||
//Check false case:
|
||||
bArr.assign(0);
|
||||
is = new InferenceSession(sd);
|
||||
outMap = is.output(Collections.singletonList(n), m, null, false, null);
|
||||
outMap = is.output(Collections.singletonList(n), m, null, Collections.<String>emptyList(), false, null);
|
||||
assertEquals(1, outMap.size());
|
||||
assertEquals(expFalse, outMap.get(n));
|
||||
}
|
||||
|
@ -218,7 +222,8 @@ public class TestSessions extends BaseNd4jTest {
|
|||
String n = "while/Exit";
|
||||
String n2 = "while/Exit_1";
|
||||
|
||||
Map<String, INDArray> m = is.output(Arrays.asList(n, n2), Collections.emptyMap(), null, false, null);
|
||||
Map<String, INDArray> m = is.output(Arrays.asList(n, n2), Collections.emptyMap(), null,
|
||||
Collections.<String>emptyList(), false, null);
|
||||
assertEquals(2, m.size());
|
||||
|
||||
INDArray exp = Nd4j.scalar((float)numIter);
|
||||
|
|
|
@ -0,0 +1,119 @@
|
|||
/*
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
package org.nd4j.autodiff.samediff.listeners;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.nd4j.autodiff.listeners.Operation;
|
||||
import org.nd4j.autodiff.listeners.impl.ScoreListener;
|
||||
import org.nd4j.autodiff.listeners.records.History;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.TrainingConfig;
|
||||
import org.nd4j.evaluation.classification.Evaluation;
|
||||
import org.nd4j.evaluation.classification.Evaluation.Metric;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.IrisDataSetIterator;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
import org.nd4j.linalg.learning.config.Adam;
|
||||
import org.nd4j.linalg.learning.config.IUpdater;
|
||||
import org.nd4j.weightinit.impl.XavierInitScheme;
|
||||
|
||||
public class ListenerTest extends BaseNd4jTest {
|
||||
|
||||
public ListenerTest(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void irisHistoryTest() {
|
||||
|
||||
DataSetIterator iter = new IrisDataSetIterator(150, 150);
|
||||
NormalizerStandardize std = new NormalizerStandardize();
|
||||
std.fit(iter);
|
||||
iter.setPreProcessor(std);
|
||||
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
SameDiff sd = SameDiff.create();
|
||||
|
||||
SDVariable in = sd.placeHolder("input", DataType.FLOAT, -1, 4);
|
||||
SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 3);
|
||||
|
||||
SDVariable w0 = sd.var("w0", new XavierInitScheme('c', 4, 10), DataType.FLOAT, 4, 10);
|
||||
SDVariable b0 = sd.zero("b0", DataType.FLOAT, 1, 10);
|
||||
|
||||
SDVariable w1 = sd.var("w1", new XavierInitScheme('c', 10, 3), DataType.FLOAT, 10, 3);
|
||||
SDVariable b1 = sd.zero("b1", DataType.FLOAT, 1, 3);
|
||||
|
||||
SDVariable z0 = in.mmul(w0).add(b0);
|
||||
SDVariable a0 = sd.nn().relu(z0, 0);
|
||||
SDVariable z1 = a0.mmul(w1).add(b1);
|
||||
SDVariable predictions = sd.nn().softmax("predictions", z1, 1);
|
||||
|
||||
SDVariable loss = sd.loss.softmaxCrossEntropy("loss", label, predictions);
|
||||
|
||||
sd.setLossVariables("loss");
|
||||
|
||||
IUpdater updater = new Adam(1e-2);
|
||||
|
||||
Evaluation e = new Evaluation();
|
||||
|
||||
TrainingConfig conf = new TrainingConfig.Builder()
|
||||
.l2(1e-4)
|
||||
.updater(updater)
|
||||
.dataSetFeatureMapping("input")
|
||||
.dataSetLabelMapping("label")
|
||||
.trainEvaluation(predictions, 0, e)
|
||||
.build();
|
||||
|
||||
sd.setTrainingConfig(conf);
|
||||
|
||||
sd.setListeners(new ScoreListener(1));
|
||||
|
||||
History hist = sd.fit(iter, 50);
|
||||
// Map<String, List<IEvaluation>> evalMap = new HashMap<>();
|
||||
// evalMap.put("prediction", Collections.singletonList(e));
|
||||
//
|
||||
// sd.evaluateMultiple(iter, evalMap);
|
||||
|
||||
e = (Evaluation) hist.finalTrainingEvaluations().evaluation(predictions);
|
||||
|
||||
System.out.println(e.stats());
|
||||
|
||||
float[] losses = hist.lossCurve().meanLoss(loss);
|
||||
|
||||
System.out.println("Losses: " + Arrays.toString(losses));
|
||||
|
||||
double acc = hist.finalTrainingEvaluations().getValue(Metric.ACCURACY);
|
||||
assertTrue("Accuracy < 75%, was " + acc, acc >= 0.75);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
return 'c';
|
||||
}
|
||||
}
|
|
@ -0,0 +1,65 @@
|
|||
/*
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
package org.nd4j.evaluation;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
import java.util.Arrays;
|
||||
import org.junit.Test;
|
||||
import org.nd4j.evaluation.custom.CustomEvaluation;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
import org.nd4j.linalg.primitives.Pair;
|
||||
|
||||
public class CustomEvaluationTest extends BaseNd4jTest {
|
||||
|
||||
public CustomEvaluationTest(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
return 'c';
|
||||
}
|
||||
|
||||
@Test
|
||||
public void customEvalTest(){
|
||||
CustomEvaluation accuracyEval = new CustomEvaluation<Pair<Number, Long>>(
|
||||
(labels, pred, mask, meta) -> new Pair<>(labels.eq(pred).castTo(DataType.INT).sumNumber(), labels.size(0)),
|
||||
CustomEvaluation.mergeConcatenate());
|
||||
|
||||
accuracyEval.eval(Nd4j.createFromArray(1, 1, 2, 1, 3), Nd4j.createFromArray(1, 1, 4, 1, 2));
|
||||
|
||||
double acc = accuracyEval.getValue(new CustomEvaluation.Metric<Pair<Number, Long>>(
|
||||
(list) -> {
|
||||
int sum = 0;
|
||||
int count = 0;
|
||||
for(Pair<Number, Long> p : list){
|
||||
sum += p.getFirst().intValue();
|
||||
count += p.getSecond();
|
||||
}
|
||||
return ((double) (sum)) / count;
|
||||
}
|
||||
));
|
||||
|
||||
assertEquals("Accuracy", acc, 3.0/5, 0.001);
|
||||
|
||||
}
|
||||
|
||||
}
|
|
@ -3,6 +3,7 @@ package org.nd4j.evaluation;
|
|||
import org.junit.Test;
|
||||
import org.nd4j.evaluation.classification.*;
|
||||
import org.nd4j.evaluation.regression.RegressionEvaluation;
|
||||
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
|
||||
|
@ -40,7 +41,7 @@ public class EmptyEvaluationTests extends BaseNd4jTest {
|
|||
RegressionEvaluation re = new RegressionEvaluation();
|
||||
re.stats();
|
||||
|
||||
for (RegressionEvaluation.Metric m : RegressionEvaluation.Metric.values()) {
|
||||
for (Metric m : Metric.values()) {
|
||||
try {
|
||||
re.scoreForMetric(m);
|
||||
} catch (Throwable t){
|
||||
|
|
|
@ -0,0 +1,117 @@
|
|||
/*
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
package org.nd4j.evaluation;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.nd4j.evaluation.classification.Evaluation;
|
||||
import org.nd4j.evaluation.classification.EvaluationBinary;
|
||||
import org.nd4j.evaluation.classification.EvaluationCalibration;
|
||||
import org.nd4j.evaluation.classification.ROC;
|
||||
import org.nd4j.evaluation.classification.ROCBinary;
|
||||
import org.nd4j.evaluation.classification.ROCMultiClass;
|
||||
import org.nd4j.evaluation.regression.RegressionEvaluation;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
|
||||
public class NewInstanceTest extends BaseNd4jTest {
|
||||
|
||||
public NewInstanceTest(Nd4jBackend backend) {
|
||||
super(backend);
|
||||
}
|
||||
|
||||
@Override
|
||||
public char ordering() {
|
||||
return 'c';
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNewInstances() {
|
||||
boolean print = true;
|
||||
Nd4j.getRandom().setSeed(12345);
|
||||
|
||||
Evaluation evaluation = new Evaluation();
|
||||
EvaluationBinary evaluationBinary = new EvaluationBinary();
|
||||
ROC roc = new ROC(2);
|
||||
ROCBinary roc2 = new ROCBinary(2);
|
||||
ROCMultiClass roc3 = new ROCMultiClass(2);
|
||||
RegressionEvaluation regressionEvaluation = new RegressionEvaluation();
|
||||
EvaluationCalibration ec = new EvaluationCalibration();
|
||||
|
||||
|
||||
IEvaluation[] arr = new IEvaluation[] {evaluation, evaluationBinary, roc, roc2, roc3, regressionEvaluation, ec};
|
||||
|
||||
INDArray evalLabel1 = Nd4j.create(10, 3);
|
||||
for (int i = 0; i < 10; i++) {
|
||||
evalLabel1.putScalar(i, i % 3, 1.0);
|
||||
}
|
||||
INDArray evalProb1 = Nd4j.rand(10, 3);
|
||||
evalProb1.diviColumnVector(evalProb1.sum(1));
|
||||
|
||||
evaluation.eval(evalLabel1, evalProb1);
|
||||
roc3.eval(evalLabel1, evalProb1);
|
||||
ec.eval(evalLabel1, evalProb1);
|
||||
|
||||
INDArray evalLabel2 = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(10, 3), 0.5));
|
||||
INDArray evalProb2 = Nd4j.rand(10, 3);
|
||||
evaluationBinary.eval(evalLabel2, evalProb2);
|
||||
roc2.eval(evalLabel2, evalProb2);
|
||||
|
||||
INDArray evalLabel3 = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(10, 1), 0.5));
|
||||
INDArray evalProb3 = Nd4j.rand(10, 1);
|
||||
roc.eval(evalLabel3, evalProb3);
|
||||
|
||||
INDArray reg1 = Nd4j.rand(10, 3);
|
||||
INDArray reg2 = Nd4j.rand(10, 3);
|
||||
|
||||
regressionEvaluation.eval(reg1, reg2);
|
||||
|
||||
Evaluation evaluation2 = evaluation.newInstance();
|
||||
EvaluationBinary evaluationBinary2 = evaluationBinary.newInstance();
|
||||
ROC roc_2 = roc.newInstance();
|
||||
ROCBinary roc22 = roc2.newInstance();
|
||||
ROCMultiClass roc32 = roc3.newInstance();
|
||||
RegressionEvaluation regressionEvaluation2 = regressionEvaluation.newInstance();
|
||||
EvaluationCalibration ec2 = ec.newInstance();
|
||||
|
||||
IEvaluation[] arr2 = new IEvaluation[] {evaluation2, evaluationBinary2, roc_2, roc22, roc32, regressionEvaluation2, ec2};
|
||||
|
||||
evaluation2.eval(evalLabel1, evalProb1);
|
||||
roc32.eval(evalLabel1, evalProb1);
|
||||
ec2.eval(evalLabel1, evalProb1);
|
||||
|
||||
evaluationBinary2.eval(evalLabel2, evalProb2);
|
||||
roc22.eval(evalLabel2, evalProb2);
|
||||
|
||||
roc_2.eval(evalLabel3, evalProb3);
|
||||
|
||||
regressionEvaluation2.eval(reg1, reg2);
|
||||
|
||||
for (int i = 0 ; i < arr.length ; i++) {
|
||||
|
||||
IEvaluation e = arr[i];
|
||||
IEvaluation e2 = arr2[i];
|
||||
assertEquals("Json not equal ", e.toJson(), e2.toJson());
|
||||
assertEquals("Stats not equal ", e.stats(), e2.stats());
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -17,8 +17,8 @@
|
|||
package org.nd4j.evaluation;
|
||||
|
||||
import org.junit.Test;
|
||||
import org.nd4j.evaluation.classification.EvaluationCalibration;
|
||||
import org.nd4j.evaluation.regression.RegressionEvaluation;
|
||||
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
|
||||
import org.nd4j.linalg.BaseNd4jTest;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.iter.NdIndexIterator;
|
||||
|
@ -256,7 +256,7 @@ public class RegressionEvalTest extends BaseNd4jTest {
|
|||
e3d.eval(label, prediction);
|
||||
e2d.eval(l2d, p2d);
|
||||
|
||||
for (RegressionEvaluation.Metric m : RegressionEvaluation.Metric.values()) {
|
||||
for (Metric m : Metric.values()) {
|
||||
double d1 = e3d.scoreForMetric(m);
|
||||
double d2 = e2d.scoreForMetric(m);
|
||||
assertEquals(m.toString(), d2, d1, 1e-6);
|
||||
|
@ -288,7 +288,7 @@ public class RegressionEvalTest extends BaseNd4jTest {
|
|||
e4d.eval(label, prediction);
|
||||
e2d.eval(l2d, p2d);
|
||||
|
||||
for (RegressionEvaluation.Metric m : RegressionEvaluation.Metric.values()) {
|
||||
for (Metric m : Metric.values()) {
|
||||
double d1 = e4d.scoreForMetric(m);
|
||||
double d2 = e2d.scoreForMetric(m);
|
||||
assertEquals(m.toString(), d2, d1, 1e-6);
|
||||
|
@ -347,7 +347,7 @@ public class RegressionEvalTest extends BaseNd4jTest {
|
|||
RegressionEvaluation e2d_m2 = new RegressionEvaluation();
|
||||
e4d_m2.eval(label, prediction, perOutMask);
|
||||
e2d_m2.eval(l2d, p2d, m2d);
|
||||
for(RegressionEvaluation.Metric m : RegressionEvaluation.Metric.values()){
|
||||
for(Metric m : Metric.values()){
|
||||
double d1 = e4d_m2.scoreForMetric(m);
|
||||
double d2 = e2d_m2.scoreForMetric(m);
|
||||
assertEquals(m.toString(), d2, d1, 1e-6);
|
||||
|
@ -382,7 +382,7 @@ public class RegressionEvalTest extends BaseNd4jTest {
|
|||
RegressionEvaluation e2d_m1 = new RegressionEvaluation();
|
||||
e4d_m1.eval(label, prediction, mask1dPerEx);
|
||||
e2d_m1.eval(l2d, p2d);
|
||||
for(RegressionEvaluation.Metric m : RegressionEvaluation.Metric.values()){
|
||||
for(Metric m : Metric.values()){
|
||||
double d1 = e4d_m1.scoreForMetric(m);
|
||||
double d2 = e2d_m1.scoreForMetric(m);
|
||||
assertEquals(m.toString(), d2, d1, 1e-6);
|
||||
|
@ -409,7 +409,7 @@ public class RegressionEvalTest extends BaseNd4jTest {
|
|||
RegressionEvaluation e2d_m2 = new RegressionEvaluation();
|
||||
e4d_m2.eval(label, prediction, perOutMask);
|
||||
e2d_m2.eval(l2d, p2d, m2d);
|
||||
for(RegressionEvaluation.Metric m : RegressionEvaluation.Metric.values()){
|
||||
for(Metric m : Metric.values()){
|
||||
double d1 = e4d_m2.scoreForMetric(m);
|
||||
double d2 = e2d_m2.scoreForMetric(m);
|
||||
assertEquals(m.toString(), d2, d1, 1e-6);
|
||||
|
|
|
@ -4,11 +4,13 @@ import lombok.Getter;
|
|||
import lombok.Setter;
|
||||
import org.nd4j.autodiff.listeners.At;
|
||||
import org.nd4j.autodiff.listeners.BaseListener;
|
||||
import org.nd4j.autodiff.listeners.Operation;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
import java.util.*;
|
||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||
|
||||
public class OpExecOrderListener extends BaseListener {
|
||||
|
||||
|
@ -22,7 +24,7 @@ public class OpExecOrderListener extends BaseListener {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void opExecution(SameDiff sd, At at, boolean training, SameDiffOp op, INDArray[] outputs) {
|
||||
public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
|
||||
String opName = op.getName();
|
||||
if(!opSet.contains(opName)){
|
||||
opNamesList.add(opName);
|
||||
|
@ -30,4 +32,8 @@ public class OpExecOrderListener extends BaseListener {
|
|||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isActive(Operation operation) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,9 +20,11 @@ import lombok.NonNull;
|
|||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.nd4j.autodiff.listeners.At;
|
||||
import org.nd4j.autodiff.listeners.BaseListener;
|
||||
import org.nd4j.autodiff.listeners.Operation;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.io.File;
|
||||
|
@ -30,6 +32,11 @@ import java.io.File;
|
|||
@Slf4j
|
||||
public class ImportDebugListener extends BaseListener {
|
||||
|
||||
@Override
|
||||
public boolean isActive(Operation operation) {
|
||||
return true;
|
||||
}
|
||||
|
||||
public enum OnFailure {EXCEPTION, LOG};
|
||||
|
||||
private File baseDir;
|
||||
|
@ -49,7 +56,7 @@ public class ImportDebugListener extends BaseListener {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void opExecution(SameDiff sd, At at, boolean training, SameDiffOp op, INDArray[] outputs) {
|
||||
public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
|
||||
//No op
|
||||
|
||||
for( int i=0; i<outputs.length; i++ ) {
|
||||
|
|
Loading…
Reference in New Issue