SameDiff: Listener changes and training api update (#99)

* example api

Signed-off-by: Ryan Nett <rnett@skymind.io>

* Lambda based evaluation

Signed-off-by: Ryan Nett <rnett@skymind.io>

* lambda test

Signed-off-by: Ryan Nett <rnett@skymind.io>

* fixes

Signed-off-by: Ryan Nett <rnett@skymind.io>

* partial fixes, use get-variable listener framework, example EvaluationListener

Signed-off-by: Ryan Nett <rnett@skymind.io>

* javadoc fix and newInstance implementations

Signed-off-by: Ryan Nett <rnett@skymind.io>

* fit and evaluate methods with validation data (for fit) and listeners

Signed-off-by: Ryan Nett <rnett@skymind.io>

* output method overloads + listener args

Signed-off-by: Ryan Nett <rnett@skymind.io>

* history and evaluation helpers

Signed-off-by: Ryan Nett <rnett@skymind.io>

* fixes

Signed-off-by: Ryan Nett <rnett@skymind.io>

* more fixes

Signed-off-by: Ryan Nett <rnett@skymind.io>

* FitConfig and added getters and setters

Signed-off-by: Ryan Nett <rnett@skymind.io>

* javadocs

Signed-off-by: Ryan Nett <rnett@skymind.io>

* fixes, javadoc, added activations to history, added latest activation listener

Signed-off-by: Ryan Nett <rnett@skymind.io>

* fixes, start of tests

Signed-off-by: Ryan Nett <rnett@skymind.io>

* fixes and updates

Signed-off-by: Ryan Nett <rnett@skymind.io>

* newInstance fixes, tests

Signed-off-by: Ryan Nett <rnett@skymind.io>

* test fixes

Signed-off-by: Ryan Nett <rnett@skymind.io>

* javadocs, getters with SDVariable overrides, CustomEvaluation fix

Signed-off-by: Ryan Nett <rnett@skymind.io>

* more operation config classes (evaluation, output, exec/single batch output), fix custom eval tests

Signed-off-by: Ryan Nett <rnett@skymind.io>

* merge fixes

Signed-off-by: Ryan Nett <rnett@skymind.io>

* fix

Signed-off-by: Ryan Nett <rnett@skymind.io>

* fixes, most old fit/evaluate/output methods use the builders

Signed-off-by: Ryan Nett <rnett@skymind.io>

* fixes

Signed-off-by: Ryan Nett <rnett@skymind.io>

* numerous fixes/cleanup

Signed-off-by: Ryan Nett <rnett@skymind.io>

* fixes

Signed-off-by: Ryan Nett <rnett@skymind.io>

* javadoc

Signed-off-by: Ryan Nett <rnett@skymind.io>

* Polish round 1

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Round 2

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Formatting + round 3

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Round 4

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Ryan Nett 2019-08-09 22:30:31 -07:00 committed by Alex Black
parent 6ed03217b4
commit 11bddb3825
55 changed files with 5180 additions and 788 deletions

View File

@ -21,12 +21,13 @@ import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
/**
* Score function for regression (including multi-label regression) for a MultiLayerNetwork or ComputationGraph
* on a test set. Supports all regression metrics: {@link RegressionEvaluation.Metric}
* on a test set. Supports all regression metrics: {@link Metric}
*
* @author Alex Black
*/
@ -35,13 +36,13 @@ import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
@NoArgsConstructor(access = AccessLevel.PROTECTED) //For JSON
public class RegressionScoreFunction extends BaseNetScoreFunction {
protected RegressionEvaluation.Metric metric;
protected Metric metric;
public RegressionScoreFunction(@NonNull org.deeplearning4j.eval.RegressionEvaluation.Metric metric) {
this(metric.toNd4j());
}
public RegressionScoreFunction(@NonNull RegressionEvaluation.Metric metric) {
public RegressionScoreFunction(@NonNull Metric metric) {
this.metric = metric;
}

View File

@ -51,7 +51,7 @@ import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.ROCBinary;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -107,7 +107,7 @@ public class TestEarlyStopping extends BaseDL4JTest {
min = false;
break;
case 3:
sc = new RegressionScoreCalculator(RegressionEvaluation.Metric.MSE, irisIter);
sc = new RegressionScoreCalculator(Metric.MSE, irisIter);
min = true;
break;
case 4:
@ -561,8 +561,8 @@ public class TestEarlyStopping extends BaseDL4JTest {
@Test
public void testRegressionScoreFunctionSimple() throws Exception {
for(RegressionEvaluation.Metric metric : new RegressionEvaluation.Metric[]{RegressionEvaluation.Metric.MSE,
RegressionEvaluation.Metric.MAE}) {
for(Metric metric : new Metric[]{Metric.MSE,
Metric.MAE}) {
log.info("Metric: " + metric);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
@ -604,8 +604,8 @@ public class TestEarlyStopping extends BaseDL4JTest {
@Test
public void testAEScoreFunctionSimple() throws Exception {
for(RegressionEvaluation.Metric metric : new RegressionEvaluation.Metric[]{RegressionEvaluation.Metric.MSE,
RegressionEvaluation.Metric.MAE}) {
for(Metric metric : new Metric[]{Metric.MSE,
Metric.MAE}) {
log.info("Metric: " + metric);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
@ -647,8 +647,8 @@ public class TestEarlyStopping extends BaseDL4JTest {
@Test
public void testVAEScoreFunctionSimple() throws Exception {
for(RegressionEvaluation.Metric metric : new RegressionEvaluation.Metric[]{RegressionEvaluation.Metric.MSE,
RegressionEvaluation.Metric.MAE}) {
for(Metric metric : new Metric[]{Metric.MSE,
Metric.MAE}) {
log.info("Metric: " + metric);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()

View File

@ -43,7 +43,7 @@ import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.junit.Test;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.dataset.DataSet;
@ -289,8 +289,8 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest {
@Test
public void testRegressionScoreFunctionSimple() throws Exception {
for(RegressionEvaluation.Metric metric : new RegressionEvaluation.Metric[]{RegressionEvaluation.Metric.MSE,
RegressionEvaluation.Metric.MAE}) {
for(Metric metric : new Metric[]{Metric.MSE,
Metric.MAE}) {
log.info("Metric: " + metric);
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
@ -335,8 +335,8 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest {
public void testAEScoreFunctionSimple() throws Exception {
DataType dt = Nd4j.defaultFloatingPointType();
for(RegressionEvaluation.Metric metric : new RegressionEvaluation.Metric[]{RegressionEvaluation.Metric.MSE,
RegressionEvaluation.Metric.MAE}) {
for(Metric metric : new Metric[]{Metric.MSE,
Metric.MAE}) {
log.info("Metric: " + metric);
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
@ -380,8 +380,8 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest {
@Test
public void testVAEScoreFunctionSimple() throws Exception {
for(RegressionEvaluation.Metric metric : new RegressionEvaluation.Metric[]{RegressionEvaluation.Metric.MSE,
RegressionEvaluation.Metric.MAE}) {
for(Metric metric : new Metric[]{Metric.MSE,
Metric.MAE}) {
log.info("Metric: " + metric);
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()

View File

@ -23,6 +23,7 @@ import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.feedforward.autoencoder.AutoEncoder;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
@ -30,16 +31,16 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
/**
* Score function for a MultiLayerNetwork or ComputationGraph with a single
* {@link org.deeplearning4j.nn.conf.layers.AutoEncoder} layer.
* Calculates the specified {@link RegressionEvaluation.Metric} on the layer's reconstructions.
* Calculates the specified {@link Metric} on the layer's reconstructions.
*
* @author Alex Black
*/
public class AutoencoderScoreCalculator extends BaseScoreCalculator<Model> {
protected final RegressionEvaluation.Metric metric;
protected final Metric metric;
protected RegressionEvaluation evaluation;
public AutoencoderScoreCalculator(RegressionEvaluation.Metric metric, DataSetIterator iterator){
public AutoencoderScoreCalculator(Metric metric, DataSetIterator iterator){
super(iterator);
this.metric = metric;
}

View File

@ -19,19 +19,20 @@ package org.deeplearning4j.earlystopping.scorecalc;
import org.deeplearning4j.earlystopping.scorecalc.base.BaseIEvaluationScoreCalculator;
import org.deeplearning4j.nn.api.Model;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
/**
* Calculate the regression score of the network (MultiLayerNetwork or ComputationGraph) on a test set, using the
* specified regression metric - {@link RegressionEvaluation.Metric}
* specified regression metric - {@link Metric}
*
* @author Alex Black
*/
public class RegressionScoreCalculator extends BaseIEvaluationScoreCalculator<Model, RegressionEvaluation> {
protected final RegressionEvaluation.Metric metric;
protected final Metric metric;
public RegressionScoreCalculator(RegressionEvaluation.Metric metric, DataSetIterator iterator){
public RegressionScoreCalculator(Metric metric, DataSetIterator iterator){
super(iterator);
this.metric = metric;
}

View File

@ -23,6 +23,7 @@ import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
@ -35,7 +36,7 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
*/
public class VAEReconErrorScoreCalculator extends BaseScoreCalculator<Model> {
protected final RegressionEvaluation.Metric metric;
protected final Metric metric;
protected RegressionEvaluation evaluation;
/**
@ -44,7 +45,7 @@ public class VAEReconErrorScoreCalculator extends BaseScoreCalculator<Model> {
* @param metric
* @param iterator
*/
public VAEReconErrorScoreCalculator(RegressionEvaluation.Metric metric, DataSetIterator iterator) {
public VAEReconErrorScoreCalculator(Metric metric, DataSetIterator iterator) {
super(iterator);
this.metric = metric;
}

View File

@ -20,6 +20,22 @@ public class At {
private int iteration;
private int trainingThreadNum;
private long javaThreadNum;
private Operation operation;
/**
* @return A new instance with everything set to 0, and operation set to INFERENCE
*/
public static At defaultAt(){
return new At(0, 0, 0, 0, Operation.INFERENCE);
}
/**
* @param op Operation
* @return A new instance with everything set to 0, except for the specified operation
*/
public static At defaultAt(@NonNull Operation op){
return new At(0, 0, 0, 0, op);
}
/**
* @return The current training epoch
@ -48,4 +64,26 @@ public class At {
public long javaThreadNum(){
return javaThreadNum;
}
/**
* @return The current operation
*/
public Operation operation(){
return operation;
}
/**
* @return A copy of the current At instance
*/
public At copy(){
return new At(epoch, iteration, trainingThreadNum, javaThreadNum, operation);
}
/**
* @param operation Operation to set in the new instance
* @return A copy of the current instance, but with the specified operation
*/
public At copy(Operation operation){
return new At(epoch, iteration, trainingThreadNum, javaThreadNum, operation);
}
}

View File

@ -0,0 +1,151 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.autodiff.listeners;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.nd4j.autodiff.listeners.records.EvaluationRecord;
import org.nd4j.autodiff.listeners.records.LossCurve;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
/**
* A base listener class that will perform the provided evaluations, and provide the results in epochEnd and validationDone
*
* Instead of overriding requiredVariables, epochStart, epochEnd, validationDone, and/or opExecution,
* override otherRequiredVariables, epochStartEvaluations, epochEndEvaluations, validationDoneEvaluations, and/or opExecutionEvaluations
*
* <strong>If you want to use Evaluations in your listener, extend this class</strong>
*/
public abstract class BaseEvaluationListener extends BaseListener {
private Map<String, List<IEvaluation>> trainingEvaluations = new HashMap<>();
private Map<String, List<IEvaluation>> validationEvaluations = new HashMap<>();
/**
* Return the requested evaluations. New instances of these evaluations will be made each time they are used
*/
public abstract ListenerEvaluations evaluations();
@Override
public final ListenerVariables requiredVariables(SameDiff sd) {
return evaluations().requiredVariables().merge(otherRequiredVariables(sd));
}
/**
* Return any requested variables that are not part of the evaluations
*/
public ListenerVariables otherRequiredVariables(SameDiff sd){
return ListenerVariables.empty();
}
@Override
public final void epochStart(SameDiff sd, At at) {
trainingEvaluations = new HashMap<>();
for(Map.Entry<String, List<IEvaluation>> entry : evaluations().trainEvaluations().entrySet()){
List<IEvaluation> evals = new ArrayList<>();
for(IEvaluation ie : entry.getValue())
evals.add(ie.newInstance());
trainingEvaluations.put(entry.getKey(), evals);
}
validationEvaluations = new HashMap<>();
for(Map.Entry<String, List<IEvaluation>> entry : evaluations().validationEvaluations().entrySet()){
List<IEvaluation> evals = new ArrayList<>();
for(IEvaluation ie : entry.getValue())
evals.add(ie.newInstance());
validationEvaluations.put(entry.getKey(), evals);
}
epochStartEvaluations(sd, at);
}
/**
* See {@link Listener#epochStart(SameDiff, At)}
*/
public void epochStartEvaluations(SameDiff sd, At at){
//No op
}
@Override
public final ListenerResponse epochEnd(SameDiff sd, At at, LossCurve lossCurve, long epochTimeMillis) {
return epochEndEvaluations(sd, at, lossCurve, epochTimeMillis, new EvaluationRecord(trainingEvaluations));
}
/**
* See {@link Listener#epochEnd(SameDiff, At, LossCurve, long)}, also provided the requested evaluations
*/
public ListenerResponse epochEndEvaluations(SameDiff sd, At at, LossCurve lossCurve, long epochTimeMillis, EvaluationRecord evaluations) {
//No op
return ListenerResponse.CONTINUE;
}
@Override
public final ListenerResponse validationDone(SameDiff sd, At at, long validationTimeMillis) {
return validationDoneEvaluations(sd, at, validationTimeMillis, new EvaluationRecord(validationEvaluations));
}
/**
* See {@link Listener#validationDone(SameDiff, At, long)}, also provided the requested evaluations
*/
public ListenerResponse validationDoneEvaluations(SameDiff sd, At at, long validationTimeMillis, EvaluationRecord evaluations) {
//No op
return ListenerResponse.CONTINUE;
}
@Override
public final void activationAvailable(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, String varName,
INDArray activation) {
if(at.operation() == Operation.TRAINING) {
if (trainingEvaluations.containsKey(varName)) {
INDArray labels = batch.getLabels(evaluations().trainEvaluationLabels().get(varName));
INDArray mask = batch.getLabelsMaskArray(evaluations().trainEvaluationLabels().get(varName));
for (IEvaluation e : trainingEvaluations.get(varName))
e.eval(labels, activation, mask);
}
} else if(at.operation() == Operation.TRAINING_VALIDATION) {
if (validationEvaluations.containsKey(varName)) {
INDArray labels = batch.getLabels(evaluations().validationEvaluationLabels().get(varName));
INDArray mask = batch.getLabelsMaskArray(evaluations().validationEvaluationLabels().get(varName));
for (IEvaluation e : validationEvaluations.get(varName))
e.eval(labels, activation, mask);
}
}
activationAvailableEvaluations(sd, at, batch, op, varName, activation);
}
/**
* See {@link Listener#activationAvailable(SameDiff, At, MultiDataSet, SameDiffOp, String, INDArray)}
*/
public void activationAvailableEvaluations(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, String varName,
INDArray activation){
//No op
}
}

View File

@ -1,6 +1,6 @@
package org.nd4j.autodiff.listeners;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.listeners.records.LossCurve;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable;
@ -11,18 +11,32 @@ import org.nd4j.linalg.dataset.api.MultiDataSet;
* A base/abstract {@link Listener} with all methods implemented as no-op.
* Extend this for custom listeners to selectively override only the required methods
*
* <strong>If you want to use evaluations in your listener, use {@link BaseEvaluationListener}</strong>
*
* @author Alex Black
*/
public abstract class BaseListener implements Listener {
@Override
public ListenerVariables requiredVariables(SameDiff sd){
return ListenerVariables.empty();
}
@Override
public void epochStart(SameDiff sd, At at) {
//No op
}
@Override
public void epochEnd(SameDiff sd, At at) {
public ListenerResponse epochEnd(SameDiff sd, At at, LossCurve lossCurve, long epochTimeMillis) {
return ListenerResponse.CONTINUE;
}
@Override
public ListenerResponse validationDone(SameDiff sd, At at, long validationTimeMillis) {
//No op
return ListenerResponse.CONTINUE;
}
@Override
@ -36,12 +50,28 @@ public abstract class BaseListener implements Listener {
}
@Override
public void preOpExecution(SameDiff sd, At at, boolean training, SameDiffOp op) {
public void operationStart(SameDiff sd, Operation op) {
//No op
}
@Override
public void opExecution(SameDiff sd, At at, boolean training, SameDiffOp op, INDArray[] outputs) {
public void operationEnd(SameDiff sd, Operation op) {
//No op
}
@Override
public void preOpExecution(SameDiff sd, At at, SameDiffOp op) {
//No op
}
@Override
public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
//No op
}
@Override
public void activationAvailable(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, String varName,
INDArray activation) {
//No op
}

View File

@ -1,6 +1,6 @@
package org.nd4j.autodiff.listeners;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.listeners.records.LossCurve;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable;
@ -11,10 +11,29 @@ import org.nd4j.linalg.dataset.api.MultiDataSet;
* A {@link SameDiff} listener interface that is called during every iteration of training or inference
*
* @author Alex Black
* @see BaseListener BaseListener, for extending
* @see BaseListener BaseListener, for extending only the required methods (all others are no-op)
* @see BaseEvaluationListener BaseEvaluationListener, for extending if you want to use evaluations
*/
public interface Listener {
/**
* Required variables for this listener.
* <p>
* Used to ensure these variables end up in the minimum required subgraph calculated by {@link org.nd4j.autodiff.samediff.internal.InferenceSession}.
* Otherwise, if the variables weren't required by a loss variable, they would not be calculated.
* <p>
* Any variables in here are guaranteed to have {@link Listener#activationAvailable(SameDiff, At, MultiDataSet, SameDiffOp, String, INDArray)}
* called for them, regardless of whether they would normally be calculated or not.
*/
ListenerVariables requiredVariables(SameDiff sd);
/**
* Returns whether this listener is active during the given operation. If this returns false for the given operation,
* those listener methods will not be called.
*/
boolean isActive(Operation operation);
/**
* Called at the start of every epoch, when fitting from an iterator
*
@ -28,8 +47,21 @@ public interface Listener {
*
* @param sd The SameDiff instance
* @param at Current iteration/epoch etc
* @param lossCurve The losses so far
* @param epochTimeMillis How long this epoch took
* @return ListenerResponse.STOP to stop training, CONTINUE or null to continue
*/
void epochEnd(SameDiff sd, At at);
ListenerResponse epochEnd(SameDiff sd, At at, LossCurve lossCurve, long epochTimeMillis);
/**
* Called after the end of every epoch, once validation evaluation is done, when training
*
* @param sd The SameDiff instance
* @param at Current iteration/epoch etc
* @param validationTimeMillis How long validation took for this epoch
* @return ListenerResponse.STOP to stop training, CONTINUE or null to continue
*/
ListenerResponse validationDone(SameDiff sd, At at, long validationTimeMillis);
/**
* Called at the start of every iteration (minibatch), before any operations have been executed
@ -45,10 +77,26 @@ public interface Listener {
* @param sd The SameDiff instance
* @param at Current iteration/epoch etc
* @param dataSet The current dataset (minibatch) used for training
* @param loss The loss value for the current minibatch
* @param loss The loss value for the current minibatch. Will be null except during training
*/
void iterationDone(SameDiff sd, At at, MultiDataSet dataSet, Loss loss);
/**
* Called at the start of an operation, e.g. training or validation
*
* @param sd The SameDiff instance
* @param op The operation being started
*/
void operationStart(SameDiff sd, Operation op);
/**
* Called at the end of an operation, e.g. training or validation
*
* @param sd The SameDiff instance
* @param op The operation being ended
*/
void operationEnd(SameDiff sd, Operation op);
/**
* Called just before each operation is executed (native code called, etc) - after all inputs etc have been set
*
@ -56,20 +104,43 @@ public interface Listener {
* @param at Current iteration/epoch etc
* @param op Operation that has just been executed
*/
void preOpExecution(SameDiff sd, At at, boolean training, SameDiffOp op);
void preOpExecution(SameDiff sd, At at, SameDiffOp op);
/**
* Called at the end of each operation execution
* Called at the end of each operation execution<br>
* <p>
* Note: Outputs will most likely be freed later, use detach() if you need to save it.
*
* @param sd The SameDiff instance
* @param at Current iteration/epoch etc
* @param batch The batch's input data. May be null if not called with a batch
* @param op Operation that has just been executed
* @param outputs The output arrays for the just-executed operation
*/
void opExecution(SameDiff sd, At at, boolean training, SameDiffOp op, INDArray[] outputs);
void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs);
/**
* Called just before each parameter is to be updated - i.e., just before each parameter is modified
* Called when any activation becomes available.
* <p>
* The activation will most likely be freed later, use detach() if you need to save it.<br>
* <br>
* Note that this method will be called when any activation becomes available, not just ones from {@link #requiredVariables(SameDiff)}<br>
* It is guaranteed to be called for variables from requiredVariables().<br>
* <br>
* Note that the activations here overlap with {@link #opExecution(SameDiff, At, MultiDataSet, SameDiffOp, INDArray[])} -
* both contain the same information/arrays
*
* @param sd The SameDiff instance
* @param at Current iteration/epoch etc
* @param batch The batch's input data. May be null if not called with a batch
* @param op Operation that has just been executed
* @param varName The name of the variable
* @param activation The variable's activation
*/
void activationAvailable(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, String varName, INDArray activation);
/**
* Called just before each parameter is to be updated - i.e., just before each parameter is modified.
*
* @param sd SameDiff instance
* @param at The current iteration/epoch etc

View File

@ -0,0 +1,228 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.autodiff.listeners;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;

import lombok.Data;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.Setter;

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.base.Preconditions;
import org.nd4j.evaluation.IEvaluation;
/**
* A class to allow Listeners to define what evaluations they need to run during training<br>
* <p>
* Usage example - does classification ({@link org.nd4j.evaluation.classification.Evaluation}) evaluation on
* the training set (as training proceeds) and also Evaluation/ROCMultiClass evaluation on the test/validation set.
* Assumes that the output predictions are called "softmax" and the first DataSet/MultiDataSet labels are those corresponding
* to the "softmax" node
* <pre>{@code
* ListenerEvaluations.builder()
* //trainEvaluations: on the training data (in-line, as training proceeds through the epoch)
* .trainEvaluation("softmax", 0, new Evaluation(), new ROCMultiClass())
* //validationEvaluation: on the test/validation data, at the end of each epoch
* .validationEvaluation("softmax", 0, new Evaluation(), new ROCMultiClass())
* .build();
* }</pre>
*/
@Getter
public class ListenerEvaluations {
    //Evaluations run on the training data as training proceeds, keyed by variable name
    private Map<String, List<IEvaluation>> trainEvaluations;
    //Label index (into the MultiDataSet labels) for each training evaluation
    private Map<String, Integer> trainEvaluationLabels;
    //Evaluations run on the validation data at the end of each epoch, keyed by variable name
    private Map<String, List<IEvaluation>> validationEvaluations;
    //Label index (into the MultiDataSet labels) for each validation evaluation
    private Map<String, Integer> validationEvaluationLabels;

    public ListenerEvaluations(Map<String, List<IEvaluation>> trainEvaluations,
            Map<String, Integer> trainEvaluationLabels, Map<String, List<IEvaluation>> validationEvaluations,
            Map<String, Integer> validationEvaluationLabels) {
        this.trainEvaluations = trainEvaluations;
        this.trainEvaluationLabels = trainEvaluationLabels;
        this.validationEvaluations = validationEvaluations;
        this.validationEvaluationLabels = validationEvaluationLabels;

        Preconditions.checkArgument(trainEvaluations.keySet().equals(trainEvaluationLabels.keySet()),
                "Must specify a label index for each train evaluation.  Expected: %s, got: %s",
                trainEvaluations.keySet(), trainEvaluationLabels.keySet());

        Preconditions.checkArgument(validationEvaluations.keySet().equals(validationEvaluationLabels.keySet()),
                "Must specify a label index for each validation evaluation.  Expected: %s, got: %s",
                validationEvaluations.keySet(), validationEvaluationLabels.keySet());
    }

    private ListenerEvaluations() {
    }

    public static Builder builder() {
        return new Builder();
    }

    /**
     * Get the requested training evaluations
     */
    public Map<String, List<IEvaluation>> trainEvaluations() {
        return trainEvaluations;
    }

    /**
     * Get the label indices for the requested training evaluations
     */
    public Map<String, Integer> trainEvaluationLabels() {
        return trainEvaluationLabels;
    }

    /**
     * Get the requested validation evaluations
     */
    public Map<String, List<IEvaluation>> validationEvaluations() {
        return validationEvaluations;
    }

    /**
     * Get the label indices for the requested validation evaluations
     */
    public Map<String, Integer> validationEvaluationLabels() {
        return validationEvaluationLabels;
    }

    /**
     * Get the required variables for these evaluations
     */
    public ListenerVariables requiredVariables() {
        return new ListenerVariables(trainEvaluations.keySet(), validationEvaluations.keySet(),
                new HashSet<String>(), new HashSet<String>());
    }

    /**
     * @return true if there are no requested evaluations
     */
    public boolean isEmpty() {
        return trainEvaluations.isEmpty() && validationEvaluations.isEmpty();
    }

    @NoArgsConstructor
    @Getter
    @Setter
    public static class Builder {
        private Map<String, List<IEvaluation>> trainEvaluations = new HashMap<>();
        private Map<String, Integer> trainEvaluationLabels = new HashMap<>();
        private Map<String, List<IEvaluation>> validationEvaluations = new HashMap<>();
        private Map<String, Integer> validationEvaluationLabels = new HashMap<>();

        private void addEvaluations(boolean validation, @NonNull Map<String, List<IEvaluation>> evaluationMap,
                @NonNull Map<String, Integer> labelMap, @NonNull String variableName, int labelIndex,
                @NonNull IEvaluation... evaluations) {
            if (evaluationMap.containsKey(variableName) && labelMap.get(variableName) != labelIndex) {
                String s;

                if (validation) {
                    s = "This ListenerEvaluations.Builder already has validation evaluations for ";
                } else {
                    s = "This ListenerEvaluations.Builder already has train evaluations for ";
                }

                //Report the previously-registered index, then the conflicting one that was just requested
                throw new IllegalArgumentException(s + "variable " +
                        variableName + " with label index " + labelMap.get(variableName) + ". You can't add" +
                        " evaluations with a different label index. Got label index " + labelIndex);
            }

            if (evaluationMap.containsKey(variableName)) {
                evaluationMap.get(variableName).addAll(Arrays.asList(evaluations));
            } else {
                //Must be a mutable copy: Arrays.asList is fixed-size, which would break later addAll calls
                evaluationMap.put(variableName, new ArrayList<>(Arrays.asList(evaluations)));
                labelMap.put(variableName, labelIndex);
            }
        }

        /**
         * Add requested training evaluations for a param/variable
         *
         * @param variableName The variable to evaluate
         * @param labelIndex   The index of the label to evaluate against
         * @param evaluations  The evaluations to run
         */
        public Builder trainEvaluation(@NonNull String variableName, int labelIndex, @NonNull IEvaluation... evaluations) {
            addEvaluations(false, this.trainEvaluations, this.trainEvaluationLabels, variableName,
                    labelIndex, evaluations);
            return this;
        }

        /**
         * Add requested training evaluations for a param/variable
         *
         * @param variable    The variable to evaluate
         * @param labelIndex  The index of the label to evaluate against
         * @param evaluations The evaluations to run
         */
        public Builder trainEvaluation(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations) {
            return trainEvaluation(variable.getVarName(), labelIndex, evaluations);
        }

        /**
         * Add requested validation evaluations for a param/variable
         *
         * @param variableName The variable to evaluate
         * @param labelIndex   The index of the label to evaluate against
         * @param evaluations  The evaluations to run
         */
        public Builder validationEvaluation(@NonNull String variableName, int labelIndex, @NonNull IEvaluation... evaluations) {
            addEvaluations(true, this.validationEvaluations, this.validationEvaluationLabels, variableName,
                    labelIndex, evaluations);
            return this;
        }

        /**
         * Add requested validation evaluations for a param/variable
         *
         * @param variable    The variable to evaluate
         * @param labelIndex  The index of the label to evaluate against
         * @param evaluations The evaluations to run
         */
        public Builder validationEvaluation(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations) {
            return validationEvaluation(variable.getVarName(), labelIndex, evaluations);
        }

        /**
         * Add requested evaluations for a param/variable, for either training or validation
         *
         * @param validation   Whether to add these evaluations as validation or training
         * @param variableName The variable to evaluate
         * @param labelIndex   The index of the label to evaluate against
         * @param evaluations  The evaluations to run
         */
        public Builder addEvaluations(boolean validation, @NonNull String variableName, int labelIndex,
                @NonNull IEvaluation... evaluations) {
            if (validation) {
                return validationEvaluation(variableName, labelIndex, evaluations);
            } else {
                return trainEvaluation(variableName, labelIndex, evaluations);
            }
        }

        public ListenerEvaluations build() {
            return new ListenerEvaluations(trainEvaluations, trainEvaluationLabels, validationEvaluations,
                    validationEvaluationLabels);
        }
    }
}

View File

@ -0,0 +1,26 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.autodiff.listeners;
/**
* An enum representing feedback given by listeners during the training loop.<br>
* CONTINUE: Continue training for more epochs, unless the specified (maximum) number of training epochs have already been completed.<br>
* STOP: Terminate training at the current point, irrespective of how many total epochs were specified when calling fit.<br>
*/
public enum ListenerResponse {
    /** Continue training, up to the originally requested (maximum) number of epochs. */
    CONTINUE,
    /** Terminate training now, regardless of how many epochs were requested in the fit call. */
    STOP;
}

View File

@ -0,0 +1,236 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.autodiff.listeners;
import com.google.common.collect.Sets;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import lombok.Setter;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
/**
* Specifies a Listener's required variables for each operation.
* Used to ensure those variables end up in the minimum required subgraph calculated by {@link org.nd4j.autodiff.samediff.internal.InferenceSession}.
* Otherwise, if the variables weren't required by a loss variable, they would not be calculated.
* <p>
* Any variables in here are guaranteed to have {@link Listener#activationAvailable(SameDiff, At, MultiDataSet, SameDiffOp, String, INDArray)} called for them.
*/
@RequiredArgsConstructor
@Getter
public class ListenerVariables {
public static ListenerVariables empty() {
return ListenerVariables.builder().build();
}
@NonNull
private Set<String> trainingVariables;
@NonNull
private Set<String> validationVariables;
@NonNull
private Set<String> evaluationVariables;
@NonNull
private Set<String> inferenceVariables;
public static Builder builder() {
return new Builder();
}
/**
* Get required training variables
*/
public Set<String> trainingVariables() {
return trainingVariables;
}
/**
* Get required validation variables
*/
public Set<String> validationVariables() {
return validationVariables;
}
/**
* Get required evaluation variables
*/
public Set<String> evaluationVariables() {
return evaluationVariables;
}
/**
* Get required inference variables
*/
public Set<String> inferenceVariables() {
return inferenceVariables;
}
/**
* Get required variables for specified op
*/
public Set<String> requiredVariables(Operation op) {
switch (op) {
case TRAINING:
return trainingVariables;
case TRAINING_VALIDATION:
return validationVariables;
case INFERENCE:
return inferenceVariables;
case EVALUATION:
return evaluationVariables;
}
throw new IllegalArgumentException("Unknown operation " + op);
}
private ListenerVariables() {
}
/**
* Return a new ListenerVariables that contains the variables of this ListenerVariables and of other
*/
public ListenerVariables merge(ListenerVariables other) {
return new ListenerVariables(
Sets.newHashSet(Sets.union(trainingVariables, other.trainingVariables)),
Sets.newHashSet(Sets.union(validationVariables, other.validationVariables)),
Sets.newHashSet(Sets.union(evaluationVariables, other.evaluationVariables)),
Sets.newHashSet(Sets.union(inferenceVariables, other.inferenceVariables)));
}
@NoArgsConstructor
@Getter
@Setter
public static class Builder {
@NonNull
private Set<String> trainingVariables = new HashSet<>();
@NonNull
private Set<String> validationVariables = new HashSet<>();
@NonNull
private Set<String> evaluationVariables = new HashSet<>();
@NonNull
private Set<String> inferenceVariables = new HashSet<>();
/**
* Add required variables for the specified op
*
* @param op The op to require the variable for
*/
public Builder requireVariables(@NonNull Operation op, @NonNull String... variables) {
switch (op) {
case TRAINING:
trainingVariables.addAll(Arrays.asList(variables));
break;
case TRAINING_VALIDATION:
validationVariables.addAll(Arrays.asList(variables));
break;
case INFERENCE:
inferenceVariables.addAll(Arrays.asList(variables));
break;
case EVALUATION:
evaluationVariables.addAll(Arrays.asList(variables));
break;
}
return this;
}
/**
* Add required variables for the specified op
*
* @param op The op to require the variable for
*/
public Builder requireVariables(@NonNull Operation op, @NonNull SDVariable... variables) {
String[] names = new String[variables.length];
for (int i = 0; i < variables.length; i++)
names[i] = variables[i].getVarName();
return requireVariables(op, names);
}
/**
* Add required variables for training
*/
public Builder trainingVariables(@NonNull String... variables) {
return requireVariables(Operation.TRAINING, variables);
}
/**
* Add required variables for training
*/
public Builder trainingVariables(@NonNull SDVariable... variables) {
return requireVariables(Operation.TRAINING, variables);
}
/**
* Add required variables for validation
*/
public Builder validationVariables(@NonNull String... variables) {
return requireVariables(Operation.TRAINING_VALIDATION, variables);
}
/**
* Add required variables for validation
*/
public Builder validationVariables(@NonNull SDVariable... variables) {
return requireVariables(Operation.TRAINING_VALIDATION, variables);
}
/**
* Add required variables for inference
*/
public Builder inferenceVariables(@NonNull String... variables) {
return requireVariables(Operation.INFERENCE, variables);
}
/**
* Add required variables for inference
*/
public Builder inferenceVariables(@NonNull SDVariable... variables) {
return requireVariables(Operation.INFERENCE, variables);
}
/**
* Add required variables for evaluation
*/
public Builder evaluationVariables(@NonNull String... variables) {
return requireVariables(Operation.EVALUATION, variables);
}
/**
* Add required variables for evaluation
*/
public Builder evaluationVariables(@NonNull SDVariable... variables) {
return requireVariables(Operation.EVALUATION, variables);
}
public ListenerVariables build() {
return new ListenerVariables(trainingVariables, validationVariables, evaluationVariables, inferenceVariables);
}
}
}

View File

@ -1,5 +1,8 @@
package org.nd4j.autodiff.listeners;
import java.util.ArrayList;
import java.util.Collections;
import lombok.Data;
import lombok.NonNull;
import org.nd4j.base.Preconditions;
@ -7,7 +10,7 @@ import org.nd4j.base.Preconditions;
import java.util.List;
/**
* Loss class - represents the loss (score) for the network. Provides a breakdown of all the loss components
* Loss class - represents the loss (score) for the network, for one iteration. Provides a breakdown of all the loss components
*
* @author Alex Black
*/
@ -70,4 +73,96 @@ public class Loss {
}
return sum;
}
/**
 * Create an independent copy of this Loss.
 * <p>
 * The loss values array and name list are copied: the previous implementation passed this
 * instance's own references, so in-place mutation (as done by {@code average(List)}) on either
 * instance would be visible through the other.
 */
public Loss copy() {
    return new Loss(new ArrayList<>(lossNames), losses.clone());
}
/**
 * Element-wise sum of the given losses.
 * All entries must have the same number of loss values and the same loss names.
 *
 * @param losses The losses to sum; may be empty (returns an empty Loss)
 * @return A new Loss holding the per-component totals
 */
public static Loss sum(List<Loss> losses) {
    if (losses.size() == 0)
        return new Loss(Collections.<String>emptyList(), new double[0]);

    Loss first = losses.get(0);
    List<String> names = new ArrayList<>(first.lossNames);
    double[] totals = new double[first.losses.length];

    int idx = 0;
    for (Loss current : losses) {
        // Every entry must match the first one in both size and names
        Preconditions.checkState(current.losses.length == totals.length,
                "Loss %s has %s losses, the others before it had %s.", idx, current.losses.length, totals.length);
        Preconditions.checkState(current.lossNames.equals(names),
                "Loss %s has different loss names from the others before it. Expected %s, got %s.",
                idx, names, current.lossNames);
        for (int j = 0; j < totals.length; j++)
            totals[j] += current.losses[j];
        idx++;
    }
    return new Loss(names, totals);
}
/**
 * Element-wise average of the given losses: {@link #sum(List)} divided by the number of entries.
 *
 * @param losses The losses to average
 * @return A new Loss holding the per-component means
 */
public static Loss average(List<Loss> losses) {
    Loss total = sum(losses);
    int n = losses.size();
    for (int i = 0; i < total.losses.length; i++) {
        total.losses[i] /= n;
    }
    return total;
}
/**
 * Element-wise sum of two losses. Both must have identical loss names.
 *
 * @return A new Loss with a's names and the per-component sums
 */
public static Loss add(Loss a, Loss b) {
    Preconditions.checkState(a.lossNames.equals(b.lossNames),
            "Loss names differ. First loss has names %s, second has names %s.",
            a.lossNames, b.lossNames);
    double[] combined = new double[a.losses.length];
    for (int i = 0; i < combined.length; i++) {
        combined[i] = a.losses[i] + b.losses[i];
    }
    return new Loss(a.lossNames, combined);
}
/**
 * Element-wise difference of two losses (a - b). Both must have identical loss names.
 *
 * @return A new Loss with a's names and the per-component differences
 */
public static Loss sub(Loss a, Loss b) {
    Preconditions.checkState(a.lossNames.equals(b.lossNames),
            "Loss names differ. First loss has names %s, second has names %s.",
            a.lossNames, b.lossNames);
    double[] delta = new double[a.losses.length];
    for (int i = 0; i < delta.length; i++) {
        delta[i] = a.losses[i] - b.losses[i];
    }
    return new Loss(a.lossNames, delta);
}
/**
 * Divide each loss value of a by the scalar b.
 *
 * @return A new Loss with a's names and the scaled values
 */
public static Loss div(Loss a, Number b) {
    double divisor = b.doubleValue();
    double[] scaled = new double[a.losses.length];
    for (int i = 0; i < scaled.length; i++) {
        scaled[i] = a.losses[i] / divisor;
    }
    return new Loss(a.lossNames, scaled);
}
/**
 * Element-wise sum of this loss and other; see {@link #add(Loss, Loss)}.
 */
public Loss add(Loss other) {
    return add(this, other);
}

/**
 * Element-wise difference of this loss and other; see {@link #sub(Loss, Loss)}.
 */
public Loss sub(Loss other) {
    return sub(this, other);
}

/**
 * Alias for {@link #add(Loss)}.
 */
public Loss plus(Loss other) {
    return add(this, other);
}

/**
 * Alias for {@link #sub(Loss)}.
 */
public Loss minus(Loss other) {
    return sub(this, other);
}

/**
 * Divide each loss value by the given scalar; see {@link #div(Loss, Number)}.
 */
public Loss div(Number other) {
    return div(this, other);
}
}

View File

@ -0,0 +1,60 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.autodiff.listeners;
import java.util.Map;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
/**
* An enum representing the operation being done on a SameDiff graph.<br>
* <p>
* TRAINING: {@link SameDiff#fit()} methods training step (everything except validation)<br>
* TRAINING_VALIDATION: the validation step during {@link SameDiff#fit()} methods - i.e., test/validation set evaluation,<br>
* INFERENCE: {@link SameDiff#output()}, {@link SameDiff#batchOutput()} and {@link SameDiff#exec(Map, String...)} methods,
* including the single batch and placeholder ones. Also {@link SDVariable#eval()}<br>
* EVALUATION: {@link SameDiff#evaluate()} methods<br>
*/
public enum Operation {
    /**
     * The training operation: {@link SameDiff#fit()} methods training step (everything except validation).
     */
    TRAINING,
    /**
     * The training validation operation: the validation step during {@link SameDiff#fit()} methods.
     */
    TRAINING_VALIDATION,
    /**
     * Inference operations: {@link SameDiff#output()}, {@link SameDiff#batchOutput()} and {@link SameDiff#exec(Map, String...)} methods,
     * as well as {@link SameDiff#execBackwards(Map, Operation, String...)} methods.
     */
    INFERENCE,
    /**
     * Evaluation operations: {@link SameDiff#evaluate()} methods.
     */
    EVALUATION;

    /**
     * @return True if this operation is part of the fit/training loop: TRAINING or TRAINING_VALIDATION
     */
    public boolean isTrainingPhase() {
        return this == TRAINING || this == TRAINING_VALIDATION;
    }

    /**
     * @return True if this is the TRAINING_VALIDATION operation
     */
    public boolean isValidation() {
        return this == TRAINING_VALIDATION;
    }
}

View File

@ -7,13 +7,15 @@ import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.ListenerResponse;
import org.nd4j.autodiff.listeners.Loss;
import org.nd4j.autodiff.listeners.records.LossCurve;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import java.io.*;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.text.SimpleDateFormat;
import java.util.*;
@ -148,12 +150,18 @@ public class CheckpointListener extends BaseListener implements Serializable {
}
@Override
public void epochEnd(SameDiff sameDiff, At at) {
public ListenerResponse epochEnd(SameDiff sameDiff, At at, LossCurve lossCurve, long epochTimeMillis) {
    // Epoch-count-based checkpointing: save every saveEveryNEpochs epochs, if configured
    if(saveEveryNEpochs != null && (at.epoch()+1) % saveEveryNEpochs == 0){
        //Save:
        saveCheckpoint(sameDiff, at);
    }
    //General saving conditions: don't need to check here - will check in iterationDone
    // Checkpointing never requests early termination of training
    return ListenerResponse.CONTINUE;
}
@Override
public boolean isActive(Operation operation) {
    // Checkpoints are only saved during the training loop, not inference/evaluation
    return operation == Operation.TRAINING;
}
@Override

View File

@ -3,6 +3,7 @@ package org.nd4j.autodiff.listeners.debugging;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.linalg.api.buffer.DataType;
@ -70,7 +71,12 @@ public class ExecDebuggingListener extends BaseListener {
}
@Override
public void preOpExecution(SameDiff sd, At at, boolean training, SameDiffOp op) {
public boolean isActive(Operation operation) {
    // Debug logging applies to every operation type (training, validation, inference, evaluation)
    return true;
}
@Override
public void preOpExecution(SameDiff sd, At at, SameDiffOp op) {
if(lastIter != at.iteration()){
lastIter = at.iteration();
stepThisIter = 0;

View File

@ -0,0 +1,121 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.autodiff.listeners.impl;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import lombok.Getter;
import lombok.Setter;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseEvaluationListener;
import org.nd4j.autodiff.listeners.records.EvaluationRecord;
import org.nd4j.autodiff.listeners.records.History;
import org.nd4j.autodiff.listeners.ListenerEvaluations;
import org.nd4j.autodiff.listeners.ListenerResponse;
import org.nd4j.autodiff.listeners.records.LossCurve;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
/**
* HistoryListener is mainly used internally to collect information such as the loss curve and evaluations,
* which will be reported later in a {@link History} instance
*/
public class HistoryListener extends BaseEvaluationListener {

    // The train/validation evaluations this listener performs
    @Getter
    @Setter
    private ListenerEvaluations evaluations;

    // Accumulated evaluation records; one appended per epochEndEvaluations / validationDoneEvaluations call
    private List<EvaluationRecord> trainingHistory = new ArrayList<>();
    private List<EvaluationRecord> validationHistory = new ArrayList<>();

    // The most recently reported loss curve (from epochEndEvaluations)
    private LossCurve loss = null;

    // Wall-clock start/end of the TRAINING operation, in ms since epoch
    private long startTime;
    private long endTime;

    // Duration of each TRAINING_VALIDATION run, in milliseconds
    private List<Long> validationTimes = new ArrayList<>();
    private long validationStartTime;

    /**
     * Create a HistoryListener using the train/validation evaluations configured in the given TrainingConfig.
     */
    public HistoryListener(TrainingConfig tc) {
        this.evaluations = new ListenerEvaluations(tc.getTrainEvaluations(), tc.getTrainEvaluationLabels(),
                tc.getValidationEvaluations(), tc.getValidationEvaluationLabels());
    }

    public HistoryListener(ListenerEvaluations evaluations) {
        this.evaluations = evaluations;
    }

    /**
     * Create a new HistoryListener with the same evaluations but no accumulated history.
     */
    public HistoryListener newInstance() {
        return new HistoryListener(evaluations);
    }

    @Override
    public ListenerEvaluations evaluations() {
        return evaluations;
    }

    @Override
    public boolean isActive(Operation operation) {
        // History is only collected during fit(): TRAINING and TRAINING_VALIDATION
        return operation.isTrainingPhase();
    }

    @Override
    public ListenerResponse epochEndEvaluations(SameDiff sd, At at, LossCurve lossCurve, long epochTimeMillis, EvaluationRecord evaluations) {
        // Record this epoch's training evaluations and keep the latest loss curve
        trainingHistory.add(evaluations);
        loss = lossCurve;
        return ListenerResponse.CONTINUE;
    }

    @Override
    public ListenerResponse validationDoneEvaluations(SameDiff sd, At at, long validationTimeMillis, EvaluationRecord evaluations) {
        validationHistory.add(evaluations);
        return ListenerResponse.CONTINUE;
    }

    @Override
    public void operationStart(SameDiff sd, Operation op) {
        if (op == Operation.TRAINING) {
            startTime = System.currentTimeMillis();
        } else if (op == Operation.TRAINING_VALIDATION) {
            validationStartTime = System.currentTimeMillis();
        }
    }

    @Override
    public void operationEnd(SameDiff sd, Operation op) {
        if (op == Operation.TRAINING) {
            endTime = System.currentTimeMillis();
        } else if (op == Operation.TRAINING_VALIDATION) {
            validationTimes.add(System.currentTimeMillis() - validationStartTime);
        }
    }

    /**
     * Build the History report from the accumulated records.
     * NOTE(review): assumes operationStart/operationEnd were both called for TRAINING -
     * otherwise the reported training time (endTime - startTime) is meaningless; confirm with callers.
     */
    public History getReport() {
        return new History(trainingHistory, validationHistory, loss, endTime - startTime, validationTimes);
    }
}

View File

@ -3,7 +3,10 @@ package org.nd4j.autodiff.listeners.impl;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.ListenerResponse;
import org.nd4j.autodiff.listeners.Loss;
import org.nd4j.autodiff.listeners.records.LossCurve;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.dataset.api.MultiDataSet;
@ -32,7 +35,6 @@ public class ScoreListener extends BaseListener {
private final boolean reportEpochs;
private final boolean reportIterPerformance;
private long epochStart;
private long epochExampleCount;
private int epochBatchCount;
private long etlTotalTimeEpoch;
@ -72,10 +74,14 @@ public class ScoreListener extends BaseListener {
}
@Override
public boolean isActive(Operation operation) {
    // Score/performance reporting only applies to the training loop
    return operation == Operation.TRAINING;
}
@Override
public void epochStart(SameDiff sd, At at) {
if (reportEpochs) {
epochStart = System.currentTimeMillis();
epochExampleCount = 0;
epochBatchCount = 0;
etlTotalTimeEpoch = 0;
@ -85,17 +91,18 @@ public class ScoreListener extends BaseListener {
}
@Override
public void epochEnd(SameDiff sd, At at) {
public ListenerResponse epochEnd(SameDiff sd, At at, LossCurve lossCurve, long epochTimeMillis) {
if (reportEpochs) {
long epochDuration = System.currentTimeMillis() - epochStart;
double batchesPerSec = epochBatchCount / (epochDuration / 1000.0);
double examplesPerSec = epochExampleCount / (epochDuration / 1000.0);
double pcEtl = 100.0 * etlTotalTimeEpoch / (double) epochDuration;
double batchesPerSec = epochBatchCount / (epochTimeMillis / 1000.0);
double examplesPerSec = epochExampleCount / (epochTimeMillis / 1000.0);
double pcEtl = 100.0 * etlTotalTimeEpoch / (double) epochTimeMillis;
String etl = formatDurationMs(etlTotalTimeEpoch) + " ETL time" + (etlTotalTimeEpoch > 0 ? "(" + format2dp(pcEtl) + " %)" : "");
log.info("Epoch {} complete on iteration {} - {} batches ({} examples) in {} - {} batches/sec, {} examples/sec, {}",
at.epoch(), at.iteration(), epochBatchCount, epochExampleCount, formatDurationMs(epochDuration),
at.epoch(), at.iteration(), epochBatchCount, epochExampleCount, formatDurationMs(epochTimeMillis),
format2dp(batchesPerSec), format2dp(examplesPerSec), etl);
}
return ListenerResponse.CONTINUE;
}
@Override

View File

@ -4,7 +4,10 @@ import com.google.flatbuffers.Table;
import lombok.NonNull;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.ListenerResponse;
import org.nd4j.autodiff.listeners.Loss;
import org.nd4j.autodiff.listeners.records.LossCurve;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
@ -279,13 +282,18 @@ public class UIListener extends BaseListener {
writer.writeFinishStaticMarker();
}
@Override
public boolean isActive(Operation operation) {
    // UI data (loss curve, evaluations, etc.) is only recorded while training
    return operation == Operation.TRAINING;
}
@Override
public void epochStart(SameDiff sd, At at) {
epochTrainEval = null;
}
@Override
public void epochEnd(SameDiff sd, At at) {
public ListenerResponse epochEnd(SameDiff sd, At at, LossCurve lossCurve, long epochTimeMillis) {
//If any training evaluation, report it here:
if(epochTrainEval != null){
@ -315,6 +323,7 @@ public class UIListener extends BaseListener {
}
epochTrainEval = null;
return ListenerResponse.CONTINUE;
}
@Override
@ -401,13 +410,13 @@ public class UIListener extends BaseListener {
@Override
public void opExecution(SameDiff sd, At at, boolean training, SameDiffOp op, INDArray[] outputs) {
public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
//Do training set evaluation, if required
//Note we'll do it in opExecution not iterationDone because we can't be sure arrays will still be around in the future
//i.e., we'll eventually add workspaces and clear activation arrays once they have been consumed
if(training && trainEvalMetrics != null && trainEvalMetrics.size() > 0){
if(at.operation() == Operation.TRAINING && trainEvalMetrics != null && trainEvalMetrics.size() > 0){
long time = System.currentTimeMillis();
//First: check if this op is relevant at all to evaluation...

View File

@ -0,0 +1,240 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.autodiff.listeners.records;
import com.google.common.base.Predicates;
import com.google.common.collect.Collections2;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.Getter;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.base.Preconditions;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.IMetric;
/**
* A helper class to hold evaluations and provide methods to easily query them
*/
@Getter
public class EvaluationRecord {
private ImmutableMap<String, List<IEvaluation>> evaluations;
private Map<Class<? extends IEvaluation>, IEvaluation> classEvaluations = new HashMap<>();
private boolean isEmpty = true;
public EvaluationRecord(Map<String, List<IEvaluation>> evaluations) {
this.evaluations = ImmutableMap.copyOf(evaluations);
for (List<IEvaluation> le : evaluations.values()) {
for (IEvaluation e : le) {
isEmpty = false;
if (classEvaluations.containsKey(e.getClass()))
classEvaluations.remove(e.getClass());
else
classEvaluations.put(e.getClass(), e);
}
}
}
private EvaluationRecord() {
}
public boolean isEmpty() {
return isEmpty;
}
/**
* Get all evaluations
*/
public ImmutableMap<String, List<IEvaluation>> evaluations() {
return evaluations;
}
/**
* Get evaluations for a given param/variable
*
* @param param The target param/variable
*/
public List<IEvaluation> evaluations(String param) {
Preconditions.checkArgument(evaluations.containsKey(param),
"No evaluations for %s.", param);
return evaluations.get(param);
}
/**
* Get evaluations for a given param/variable
*
* @param param The target param/variable
*/
public List<IEvaluation> evaluations(SDVariable param) {
return evaluations(param.getVarName());
}
/**
* Get the evaluation for param at the specified index
*/
public IEvaluation evaluation(String param, int index) {
return evaluations(param).get(index);
}
/**
* Get the evaluation for param at the specified index
*/
public IEvaluation evaluation(SDVariable param, int index) {
return evaluation(param.getVarName(), index);
}
/**
* Get the evaluation for a given param/variable
* <p>
* Will throw an exception if there are more than one or no evaluations for the param
*
* @param param The target param/variable
*/
public <T extends IEvaluation> T evaluation(String param) {
Preconditions.checkArgument(evaluations.containsKey(param),
"No evaluations for %s.", param);
Preconditions.checkArgument(evaluations.get(param).size() == 1,
"Multiple evaluations for %s. Use evaluations().", param);
return (T) evaluations.get(param).get(0);
}
/**
* Get the evaluation for a given param/variable
* <p>
* Will throw an exception if there are more than one or no evaluations for the param
*
* @param param The target param/variable
*/
public <T extends IEvaluation> T evaluation(SDVariable param) {
return evaluation(param.getVarName());
}
/**
* Get the evaluation of a given type
* <p>
* Will throw an exception if there are more than one or no evaluations of that type
*
* @param evalClass The type of evaluation to look for
*/
public <T extends IEvaluation<T>> T evaluation(Class<T> evalClass) {
Preconditions.checkArgument(classEvaluations.containsKey(evalClass),
"Can't get evaluation for %s. Either no evaluations with that class are present, or more than one are.", evalClass);
return (T) classEvaluations.get(evalClass);
}
/**
* Get the evaluation of a given type, for a given param/variable
* <p>
* Will throw an exception if there are more than one or no evaluations of that type for the given param
*
* @param param The target param/variable
* @param evalClass The type of evaluation to look for
*/
public <T extends IEvaluation<T>> T evaluation(String param, Class<T> evalClass) {
Collection<IEvaluation> evals = Collections2.filter(evaluations(param), Predicates.instanceOf(evalClass));
Preconditions.checkArgument(evals.size() == 1, "Multiple or no evaluations of type %s for param %s.", evalClass, param);
return (T) evals.iterator().next();
}
/**
* Get the evaluation of a given type, for a given param/variable
* <p>
* Will throw an exception if there are more than one or no evaluations of that type for the given param
*
* @param param The target param/variable
* @param evalClass The type of evaluation to look for
*/
public <T extends IEvaluation<T>> T evaluation(SDVariable param, Class<T> evalClass) {
return evaluation(param.getVarName(), evalClass);
}
/**
* Get the metric's value for the evaluation of the metric's type
* <p>
* Will throw an exception if there are more than one or no evaluations of that type
*
* @param metric The metric to calculate
*/
public double getValue(IMetric metric) {
return evaluation(metric.getEvaluationClass()).getValue(metric);
}
/**
* Get the metric's value for the evaluation of the metric's type, for a given param/variable
* <p>
* Will throw an exception if there are more than one or no evaluations of that type for the given param
*
* @param param The target param/variable
* @param metric The metric to calculate
*/
public double getValue(String param, IMetric metric) {
return evaluation(param, metric.getEvaluationClass()).getValue(metric);
}
/**
* Get the metric's value for the evaluation of the metric's type, for a given param/variable
* <p>
* Will throw an exception if there are more than one or no evaluations of that type for the given param
*
* @param param The target param/variable
* @param metric The metric to calculate
*/
public double getValue(SDVariable param, IMetric metric) {
return getValue(param.getVarName(), metric);
}
/**
* Get the metric's value for the evaluation for a given param/variable at the given index
* <p>
* Will throw an exception if the target evaluation doesn't support the given metric
*
* @param param The target param/variable
* @param index The index of the target evaluation on the param
* @param metric The metric to calculate
*/
public double getValue(String param, int index, IMetric metric) {
return evaluation(param, index).getValue(metric);
}
/**
* Get the metric's value for the evaluation for a given param/variable at the given index
* <p>
* Will throw an exception if the target evaluation doesn't support the given metric
*
* @param param The target param/variable
* @param index The index of the target evaluation on the param
* @param metric The metric to calculate
*/
public double getValue(SDVariable param, int index, IMetric metric) {
return getValue(param.getVarName(), index, metric);
}
}

View File

@ -0,0 +1,356 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.autodiff.listeners.records;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.List;
import lombok.Getter;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.IMetric;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
/**
 * An object containing training history for a SameDiff.fit call, such as {@link SameDiff#fit()}, {@link SameDiff#fit(DataSetIterator, int, Listener...)}, etc.<br>
 * Contains information including:<br>
 * - Evaluations performed (training set and test set)<br>
 * - Loss curve (score values at each iteration)<br>
 * - Training times, and validation times<br>
 * - Number of epochs performed<br>
 */
@Getter
public class History {

    // One EvaluationRecord per epoch, oldest first.  Stored as immutable snapshots.
    private List<EvaluationRecord> trainingHistory;
    private List<EvaluationRecord> validationHistory;

    private LossCurve lossCurve;

    // Total time spent training, and the per-epoch validation times
    private long trainingTimeMillis;
    private List<Long> validationTimesMillis;

    public History(List<EvaluationRecord> training, List<EvaluationRecord> validation, LossCurve loss,
                   long trainingTimeMillis, List<Long> validationTimesMillis){
        trainingHistory = ImmutableList.copyOf(training);
        validationHistory = ImmutableList.copyOf(validation);
        this.lossCurve = loss;
        this.trainingTimeMillis = trainingTimeMillis;
        this.validationTimesMillis = ImmutableList.copyOf(validationTimesMillis);
    }

    /**
     * Get the training evaluations (one record per epoch, oldest first)
     */
    public List<EvaluationRecord> trainingEval(){
        return trainingHistory;
    }

    /**
     * Get the validation evaluations (one record per epoch, oldest first)
     */
    public List<EvaluationRecord> validationEval(){
        return validationHistory;
    }

    /**
     * Get the loss curve
     */
    public LossCurve lossCurve(){
        return lossCurve;
    }

    /**
     * Get the total training time, in milliseconds
     */
    public long trainingTimeMillis(){
        return trainingTimeMillis;
    }

    /**
     * Get the validation time for each epoch, in milliseconds
     */
    public List<Long> validationTimesMillis(){
        return validationTimesMillis;
    }

    /**
     * Get the number of epochs trained for
     */
    public int trainingEpochs(){
        return trainingHistory.size();
    }

    /**
     * Get the number of epochs validation was ran on
     */
    public int validationEpochs(){
        return validationHistory.size();
    }

    /**
     * Get the results of a training evaluation on a given parameter for a given metric, one value per epoch.
     *
     * Only works if there is only one evaluation with the given metric for param
     */
    public List<Double> trainingEval(String param, IMetric metric){
        List<Double> data = new ArrayList<>();
        for(EvaluationRecord er : trainingHistory)
            data.add(er.getValue(param, metric));
        return data;
    }

    /**
     * Get the results of a training evaluation on a given parameter for a given metric, one value per epoch.
     *
     * Only works if there is only one evaluation with the given metric for param
     */
    public List<Double> trainingEval(SDVariable param, IMetric metric){
        return trainingEval(param.getVarName(), metric);
    }

    /**
     * Get the results of a training evaluation on a given parameter at a given index, for a given metric
     *
     * Note that it returns all recorded evaluations.
     * Index determines the evaluation used not the epoch's results to return.
     */
    public List<Double> trainingEval(String param, int index, IMetric metric){
        List<Double> data = new ArrayList<>();
        for(EvaluationRecord er : trainingHistory)
            data.add(er.getValue(param, index, metric));
        return data;
    }

    /**
     * Get the results of a training evaluation on a given parameter at a given index, for a given metric
     *
     * Note that it returns all recorded evaluations.
     * Index determines the evaluation used not the epoch's results to return.
     */
    public List<Double> trainingEval(SDVariable param, int index, IMetric metric){
        return trainingEval(param.getVarName(), index, metric);
    }

    /**
     * Get the results of a training evaluation for a given metric, one value per epoch.
     *
     * Only works if there is only one evaluation with the given metric
     */
    public List<Double> trainingEval(IMetric metric){
        List<Double> data = new ArrayList<>();
        for(EvaluationRecord er : trainingHistory)
            data.add(er.getValue(metric));
        return data;
    }

    /**
     * Get the results of a training evaluation on a given parameter, one evaluation per epoch.
     *
     * Only works if there is only one evaluation for param.
     */
    public List<IEvaluation> trainingEval(String param){
        List<IEvaluation> data = new ArrayList<>();
        for(EvaluationRecord er : trainingHistory)
            data.add(er.evaluation(param));
        return data;
    }

    /**
     * Get the results of a training evaluation on a given parameter, one evaluation per epoch.
     *
     * Only works if there is only one evaluation for param.
     */
    public List<IEvaluation> trainingEval(SDVariable param){
        return trainingEval(param.getVarName());
    }

    /**
     * Get the results of a training evaluation on a given parameter at a given index
     *
     * Note that it returns all recorded evaluations.
     * Index determines the evaluation used not the epoch's results to return.
     */
    public List<IEvaluation> trainingEval(String param, int index){
        List<IEvaluation> data = new ArrayList<>();
        for(EvaluationRecord er : trainingHistory)
            data.add(er.evaluation(param, index));
        return data;
    }

    /**
     * Get the results of a training evaluation on a given parameter at a given index
     *
     * Note that it returns all recorded evaluations.
     * Index determines the evaluation used not the epoch's results to return.
     */
    public List<IEvaluation> trainingEval(SDVariable param, int index){
        return trainingEval(param.getVarName(), index);
    }

    /**
     * Get the results of a validation evaluation on a given parameter for a given metric, one value per epoch.
     *
     * Only works if there is only one evaluation with the given metric for param
     */
    public List<Double> validationEval(String param, IMetric metric){
        List<Double> data = new ArrayList<>();
        for(EvaluationRecord er : validationHistory)
            data.add(er.getValue(param, metric));
        return data;
    }

    /**
     * Get the results of a validation evaluation on a given parameter for a given metric, one value per epoch.
     *
     * Only works if there is only one evaluation with the given metric for param
     */
    public List<Double> validationEval(SDVariable param, IMetric metric){
        return validationEval(param.getVarName(), metric);
    }

    /**
     * Get the results of a validation evaluation on a given parameter at a given index, for a given metric
     *
     * Note that it returns all recorded evaluations.
     * Index determines the evaluation used not the epoch's results to return.
     */
    public List<Double> validationEval(String param, int index, IMetric metric){
        List<Double> data = new ArrayList<>();
        for(EvaluationRecord er : validationHistory)
            data.add(er.getValue(param, index, metric));
        return data;
    }

    /**
     * Get the results of a validation evaluation on a given parameter at a given index, for a given metric
     *
     * Note that it returns all recorded evaluations.
     * Index determines the evaluation used not the epoch's results to return.
     */
    public List<Double> validationEval(SDVariable param, int index, IMetric metric){
        return validationEval(param.getVarName(), index, metric);
    }

    /**
     * Get the results of a validation evaluation for a given metric, one value per epoch.
     *
     * Only works if there is only one evaluation with the given metric
     */
    public List<Double> validationEval(IMetric metric){
        List<Double> data = new ArrayList<>();
        for(EvaluationRecord er : validationHistory)
            data.add(er.getValue(metric));
        return data;
    }

    /**
     * Get the results of a validation evaluation on a given parameter, one evaluation per epoch.
     *
     * Only works if there is only one evaluation for param.
     */
    public List<IEvaluation> validationEval(String param){
        List<IEvaluation> data = new ArrayList<>();
        for(EvaluationRecord er : validationHistory)
            data.add(er.evaluation(param));
        return data;
    }

    /**
     * Get the results of a validation evaluation on a given parameter, one evaluation per epoch.
     *
     * Only works if there is only one evaluation for param.
     */
    public List<IEvaluation> validationEval(SDVariable param){
        return validationEval(param.getVarName());
    }

    /**
     * Get the results of a validation evaluation on a given parameter at a given index
     *
     * Note that it returns all recorded evaluations.
     * Index determines the evaluation used not the epoch's results to return.
     */
    public List<IEvaluation> validationEval(String param, int index){
        List<IEvaluation> data = new ArrayList<>();
        for(EvaluationRecord er : validationHistory)
            data.add(er.evaluation(param, index));
        return data;
    }

    /**
     * Get the results of a validation evaluation on a given parameter at a given index
     *
     * Note that it returns all recorded evaluations.
     * Index determines the evaluation used not the epoch's results to return.
     */
    public List<IEvaluation> validationEval(SDVariable param, int index){
        return validationEval(param.getVarName(), index);
    }

    /**
     * Gets the training evaluations ran during the last epoch
     */
    public EvaluationRecord finalTrainingEvaluations(){
        return trainingHistory.get(trainingHistory.size() - 1);
    }

    /**
     * Gets the validation evaluations ran during the last epoch
     */
    public EvaluationRecord finalValidationEvaluations(){
        return validationHistory.get(validationHistory.size() - 1);
    }

    /**
     * Gets the training evaluation record for a given epoch.
     * @param epoch The epoch to get results for. If negative, returns results for the epoch that many epochs from the end.
     */
    public EvaluationRecord trainingEvaluations(int epoch){
        if(epoch >= 0){
            return trainingHistory.get(epoch);
        } else {
            // Fixed: was "size() - epoch", which for a negative epoch adds to size and is always out of bounds.
            // -1 must map to the last element, i.e. size() + epoch.
            return trainingHistory.get(trainingHistory.size() + epoch);
        }
    }

    /**
     * Gets the validation evaluation record for a given epoch.
     * @param epoch The epoch to get results for. If negative, returns results for the epoch that many epochs from the end.
     */
    public EvaluationRecord validationEvaluations(int epoch){
        if(epoch >= 0){
            // Fixed: previously indexed trainingHistory by mistake (copy-paste error)
            return validationHistory.get(epoch);
        } else {
            // Fixed: was "size() - epoch"; -1 must map to the last element, i.e. size() + epoch
            return validationHistory.get(validationHistory.size() + epoch);
        }
    }
}

View File

@ -0,0 +1,211 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.autodiff.listeners.records;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.List;
import lombok.Getter;
import lombok.NonNull;
import org.nd4j.autodiff.listeners.Loss;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
/**
 * Holds the mean loss values recorded for each epoch of a SameDiff.fit call.
 *
 * Internally stores a rank-2 FLOAT array of shape [numEpochs, numLossValues], plus the
 * (fixed) list of loss names that gives meaning to the second dimension.
 */
public class LossCurve {
    @Getter
    private List<String> lossNames;
    // Rank-2 FLOAT array: row i holds the loss values for epoch i, columns ordered per lossNames
    @Getter
    private INDArray lossValues;

    /**
     * Create a loss curve from one Loss per epoch.  All entries must have the same loss
     * names and the same number of loss values as the first entry.
     *
     * @param losses Per-epoch losses, oldest first.  Must be non-empty.
     */
    public LossCurve(List<Loss> losses){
        // Guard against a raw IndexOutOfBoundsException on losses.get(0) below
        Preconditions.checkArgument(losses != null && !losses.isEmpty(), "Losses must be non-null and non-empty");
        lossNames = ImmutableList.copyOf(losses.get(0).getLossNames());
        int numLossValues = losses.get(0).lossValues().length;
        lossValues = Nd4j.create(DataType.FLOAT, losses.size(), losses.get(0).lossValues().length);

        for(int i = 0 ; i < losses.size() ; i++){
            Loss l = losses.get(i);
            Preconditions.checkArgument(l.getLossNames().equals(lossNames),
                    "Loss names for loss %s differ from others.  Expected %s, got %s",
                    i, lossNames, l.getLossNames());

            Preconditions.checkArgument(l.getLosses().length == numLossValues,
                    "Number of loss values for loss %s differ from others.  Expected %s, got %s",
                    i, numLossValues, l.getLosses().length);

            lossValues = lossValues.putRow(i, Nd4j.createFromArray(l.getLosses()).castTo(DataType.FLOAT));
        }
    }

    /**
     * Create a single-epoch loss curve from raw values.
     *
     * @param lossValues The loss values for the single epoch
     * @param lossNames  The names of the loss values, in the same order
     */
    public LossCurve(double[] lossValues, List<String> lossNames){
        this.lossValues = Nd4j.createFromArray(new double[][]{ lossValues}).castTo(DataType.FLOAT);
        // Consistency fix: snapshot the names like the List<Loss> constructor does
        this.lossNames = ImmutableList.copyOf(lossNames);
    }

    /**
     * Internal constructor used by {@link #addLossAndCopy(double[], List)}.
     *
     * @param lossValues Rank-2 FLOAT array [epoch, lossValueIndex]
     * @param lossNames  Names for the second dimension
     */
    protected LossCurve(INDArray lossValues, List<String> lossNames){
        Preconditions.checkArgument(lossValues.rank() == 2, "lossValues must have a rank of 2, got %s", lossValues.rank());
        Preconditions.checkArgument(lossValues.dataType() == DataType.FLOAT, "lossValues must be type FLOAT, got %s", lossValues.dataType());
        this.lossValues = lossValues;
        // Consistency fix: snapshot the names like the List<Loss> constructor does
        this.lossNames = ImmutableList.copyOf(lossNames);
    }

    /**
     * Reconstruct one Loss object per recorded epoch, oldest first.
     */
    public List<Loss> losses(){
        List<Loss> losses = new ArrayList<>();
        for(int i = 0 ; i < lossValues.size(0) ; i++){
            losses.add(new Loss(lossNames, lossValues.getRow(i).toDoubleVector()));
        }
        return losses;
    }

    /**
     * Get the mean loss for a given epoch
     *
     * If epoch is negative, counts backwards from the end.
     * E.g. losses(-1) gets the last epoch.
     *
     * @param epoch The epoch to get.  If negative, returns results for the epoch that many epochs from the end
     */
    public Loss meanLoss(int epoch){
        if(epoch >= 0){
            return new Loss(lossNames, lossValues.getRow(epoch).toDoubleVector());
        } else {
            return new Loss(lossNames, lossValues.getRow(lossValues.rows() + epoch).toDoubleVector());
        }
    }

    /**
     * Get the mean loss for the last epoch.
     */
    public Loss lastMeanLoss(){
        return meanLoss(-1);
    }

    /**
     * Return all mean loss values for a given variable, one per epoch
     */
    public float[] meanLoss(@NonNull String lossName){
        int idx = lossNames.indexOf(lossName);
        Preconditions.checkArgument(idx >= 0, "No loss value for %s.  Existing losses: %s", lossName, lossNames);

        float[] loss = new float[(int) lossValues.size(0)];
        for(int i = 0 ; i < lossValues.size(0) ; i++){
            loss[i] = lossValues.getFloat(i, idx);
        }
        return loss;
    }

    /**
     * Return all mean loss values for a given variable, one per epoch
     */
    public float[] meanLoss(@NonNull SDVariable loss){
        return meanLoss(loss.getVarName());
    }

    /**
     * Return the mean loss value for a given variable on a given epoch.
     *
     * See {@link #meanLoss(int)}
     */
    public float meanLoss(@NonNull String lossName, int epoch){
        int idx = lossNames.indexOf(lossName);
        Preconditions.checkArgument(idx >= 0, "No loss value for %s.  Existing losses: %s", lossName, lossNames);

        if(epoch >= 0) {
            return lossValues.getFloat(epoch, idx);
        } else {
            return lossValues.getFloat(lossValues.rows() + epoch, idx);
        }
    }

    /**
     * Return the mean loss value for a given variable on a given epoch.
     *
     * See {@link #meanLoss(int)}
     */
    public float meanLoss(@NonNull SDVariable loss, int epoch){
        return meanLoss(loss.getVarName(), epoch);
    }

    /**
     * Return the mean loss value for a given variable on the last epoch.
     */
    public float lastMeanLoss(@NonNull String lossName){
        int idx = lossNames.indexOf(lossName);
        Preconditions.checkArgument(idx >= 0, "No loss value for %s.  Existing losses: %s", lossName, lossNames);

        return lossValues.getFloat(lossValues.rows() - 1, idx);
    }

    /**
     * Return the mean loss value for a given variable on the last epoch.
     */
    public float lastMeanLoss(@NonNull SDVariable loss){
        return lastMeanLoss(loss.getVarName());
    }

    /**
     * Return the loss delta between the last epoch and the one before it.
     * Equivalent to meanLoss(-1) - meanLoss(-2).
     * A positive delta means the loss is increasing, and a negative delta means it is decreasing.
     *
     * Note: requires at least two recorded epochs.
     */
    public Loss lastMeanDelta(){
        return lastMeanLoss().sub(meanLoss(-2));
    }

    /**
     * Return the loss delta between the last epoch and the one before it, for a given variable.
     * Equivalent to meanLoss(-1) - meanLoss(-2).
     * A positive delta means the loss is increasing, and a negative delta means it is decreasing.
     */
    public double lastMeanDelta(String lossName){
        return lastMeanDelta().getLoss(lossName);
    }

    /**
     * Return the loss delta between the last epoch and the one before it, for a given variable.
     * Equivalent to meanLoss(-1) - meanLoss(-2).
     * A positive delta means the loss is increasing, and a negative delta means it is decreasing.
     */
    public double lastMeanDelta(SDVariable loss){
        return lastMeanDelta(loss.getVarName());
    }

    /**
     * Return a new LossCurve with the given losses added on as the most recent epoch
     */
    public LossCurve addLossAndCopy(Loss loss){
        return addLossAndCopy(loss.getLosses(), loss.lossNames());
    }

    /**
     * Return a new LossCurve with the given losses added on as the most recent epoch
     */
    public LossCurve addLossAndCopy(double[] values, List<String> lossNames){
        return new LossCurve(
                Nd4j.concat(0, lossValues,
                        Nd4j.createFromArray(new double[][]{values}).castTo(DataType.FLOAT)),
                lossNames);
    }
}

View File

@ -18,7 +18,9 @@ package org.nd4j.autodiff.samediff;
import java.util.ArrayList;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.listeners.ListenerEvaluations;
import org.nd4j.base.Preconditions;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.regularization.L1Regularization;
import org.nd4j.linalg.learning.regularization.L2Regularization;
@ -62,6 +64,12 @@ public class TrainingConfig {
private int iterationCount;
private int epochCount;
private Map<String, List<IEvaluation>> trainEvaluations = new HashMap<>();
private Map<String, Integer> trainEvaluationLabels = new HashMap<>();
private Map<String, List<IEvaluation>> validationEvaluations = new HashMap<>();
private Map<String, Integer> validationEvaluationLabels = new HashMap<>();
/**
* Create a training configuration suitable for training a single input, single output network.<br>
* See also the {@link Builder} for creating a TrainingConfig
@ -106,6 +114,17 @@ public class TrainingConfig {
this.lossVariables = lossVariables;
}
/**
 * Full constructor used by {@link Builder#build()}: delegates the core training settings to the
 * 8-arg constructor, then attaches the requested History train/validation evaluations.
 *
 * @param updater                    Updater (optimizer) for training
 * @param regularization             Regularization terms (L1/L2/weight decay etc.)
 * @param minimize                   Whether the loss should be minimized (vs maximized)
 * @param dataSetFeatureMapping      Placeholder names the DataSet/MultiDataSet features map to
 * @param dataSetLabelMapping        Placeholder names the labels map to
 * @param dataSetFeatureMaskMapping  Placeholder names the feature masks map to
 * @param dataSetLabelMaskMapping    Placeholder names the label masks map to
 * @param lossVariables              Names of the loss variables
 * @param trainEvaluations           Evaluations to run on training data, keyed by variable name
 * @param trainEvaluationLabels      Label index for each train-evaluation variable
 * @param validationEvaluations      Evaluations to run on validation data, keyed by variable name
 * @param validationEvaluationLabels Label index for each validation-evaluation variable
 */
protected TrainingConfig(IUpdater updater, List<Regularization> regularization, boolean minimize, List<String> dataSetFeatureMapping, List<String> dataSetLabelMapping,
List<String> dataSetFeatureMaskMapping, List<String> dataSetLabelMaskMapping, List<String> lossVariables,
Map<String, List<IEvaluation>> trainEvaluations, Map<String, Integer> trainEvaluationLabels,
Map<String, List<IEvaluation>> validationEvaluations, Map<String, Integer> validationEvaluationLabels){
this(updater, regularization, minimize, dataSetFeatureMapping, dataSetLabelMapping, dataSetFeatureMaskMapping, dataSetLabelMaskMapping, lossVariables);
// NOTE: the maps are stored as given (no defensive copy) — callers should not mutate them afterwards
this.trainEvaluations = trainEvaluations;
this.trainEvaluationLabels = trainEvaluationLabels;
this.validationEvaluations = validationEvaluations;
this.validationEvaluationLabels = validationEvaluationLabels;
}
/**
* Increment the iteration count by 1
*/
@ -146,6 +165,12 @@ public class TrainingConfig {
private boolean skipValidation = false;
private boolean markLabelsUnused = false;
private Map<String, List<IEvaluation>> trainEvaluations = new HashMap<>();
private Map<String, Integer> trainEvaluationLabels = new HashMap<>();
private Map<String, List<IEvaluation>> validationEvaluations = new HashMap<>();
private Map<String, Integer> validationEvaluationLabels = new HashMap<>();
/**
* Set the updater (such as {@link org.nd4j.linalg.learning.config.Adam}, {@link org.nd4j.linalg.learning.config.Nesterovs}
* etc. This is also how the learning rate (or learning rate schedule) is set.
@ -327,7 +352,7 @@ public class TrainingConfig {
* Set the name of the placeholders/variables that should be set using the feature mask INDArray(s) from the
* DataSet or MultiDataSet. For example, if the network had 2 mask variables called "mask1" and "mask2"
* and the MultiDataSet features masks should be mapped with {@code MultiDataSet.getFeatureMaskArray(0)->"mask1"}
* and {@code MultiDataSet.getFeatureMaskArray(1)->"mask2"}, then this should be set to {@code "mask2", "mask2"}.
* and {@code MultiDataSet.getFeatureMaskArray(1)->"mask2"}, then this should be set to {@code "mask1", "mask2"}.
*
* @param dataSetFeatureMaskMapping Name of the variables/placeholders that the feature arrays should be mapped to
*/
@ -347,7 +372,7 @@ public class TrainingConfig {
* Set the name of the placeholders/variables that should be set using the label mask INDArray(s) from the
* DataSet or MultiDataSet. For example, if the network had 2 mask variables called "mask1" and "mask2"
* and the MultiDataSet label masks should be mapped with {@code MultiDataSet.getLabelMaskArray(0)->"mask1"}
* and {@code MultiDataSet.getLabelMaskArray(1)->"mask2"}, then this should be set to {@code "mask2", "mask2"}.
* and {@code MultiDataSet.getLabelMaskArray(1)->"mask2"}, then this should be set to {@code "mask1", "mask2"}.
*
* @param dataSetLabelMaskMapping Name of the variables/placeholders that the feature arrays should be mapped to
*/
@ -366,6 +391,104 @@ public class TrainingConfig {
return this;
}
/**
 * Register evaluations (and the label index) for a variable, in either the train or validation maps.
 *
 * Throws {@link IllegalArgumentException} if evaluations were already registered for the variable
 * with a different label index.
 *
 * @param validation    Whether these are validation (true) or train (false) evaluations — used for the error message only
 * @param evaluationMap Map of variable name to its registered evaluations
 * @param labelMap      Map of variable name to its label index
 * @param variableName  The variable to evaluate
 * @param labelIndex    The index of the label to evaluate against
 * @param evaluations   The evaluations to add
 */
private void addEvaluations(boolean validation, @NonNull Map<String, List<IEvaluation>> evaluationMap, @NonNull Map<String, Integer> labelMap,
        @NonNull String variableName, int labelIndex, @NonNull IEvaluation... evaluations){
    if(evaluationMap.containsKey(variableName) && labelMap.get(variableName) != labelIndex){
        String s;

        // Fixed: message previously said "ListenerEvaluations.Builder" (copy-paste) and reported
        // the new labelIndex twice instead of the previously-registered one
        if(validation){
            s = "This TrainingConfig.Builder already has validation evaluations for ";
        } else {
            s = "This TrainingConfig.Builder already has train evaluations for ";
        }
        throw new IllegalArgumentException(s + "variable " +
                variableName + " with label index " + labelMap.get(variableName) + ". You can't add " +
                " evaluations with a different label index. Got label index " + labelIndex);
    }

    if(evaluationMap.containsKey(variableName)){
        evaluationMap.get(variableName).addAll(Arrays.asList(evaluations));
    } else {
        // Fixed: previously stored Arrays.asList(...) directly, which is a fixed-size list —
        // a second call for the same variable would then throw UnsupportedOperationException in addAll
        evaluationMap.put(variableName, new ArrayList<>(Arrays.asList(evaluations)));
        labelMap.put(variableName, labelIndex);
    }
}
/**
 * Add requested History training evaluations for a param/variable.
 *
 * These evaluations will be reported in the {@link org.nd4j.autodiff.listeners.records.History} object returned by fit.
 *
 * @param variableName The variable to evaluate
 * @param labelIndex   The index of the label to evaluate against
 * @param evaluations  The evaluations to run
 */
public Builder trainEvaluation(@NonNull String variableName, int labelIndex, @NonNull IEvaluation... evaluations){
    addEvaluations(false, this.trainEvaluations, this.trainEvaluationLabels, variableName, labelIndex, evaluations);
    return this;
}

/**
 * Add requested History training evaluations for a param/variable.
 *
 * These evaluations will be reported in the {@link org.nd4j.autodiff.listeners.records.History} object returned by fit.
 *
 * @param variable    The variable to evaluate
 * @param labelIndex  The index of the label to evaluate against
 * @param evaluations The evaluations to run
 */
public Builder trainEvaluation(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations){
    String name = variable.getVarName();
    return trainEvaluation(name, labelIndex, evaluations);
}

/**
 * Add requested History validation evaluations for a param/variable.
 *
 * These evaluations will be reported in the {@link org.nd4j.autodiff.listeners.records.History} object returned by fit.
 *
 * @param variableName The variable to evaluate
 * @param labelIndex   The index of the label to evaluate against
 * @param evaluations  The evaluations to run
 */
public Builder validationEvaluation(@NonNull String variableName, int labelIndex, @NonNull IEvaluation... evaluations){
    addEvaluations(true, this.validationEvaluations, this.validationEvaluationLabels, variableName, labelIndex, evaluations);
    return this;
}

/**
 * Add requested History validation evaluations for a param/variable.
 *
 * These evaluations will be reported in the {@link org.nd4j.autodiff.listeners.records.History} object returned by fit.
 *
 * @param variable    The variable to evaluate
 * @param labelIndex  The index of the label to evaluate against
 * @param evaluations The evaluations to run
 */
public Builder validationEvaluation(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations){
    String name = variable.getVarName();
    return validationEvaluation(name, labelIndex, evaluations);
}

/**
 * Add requested evaluations for a param/variable, for either training or validation.
 *
 * These evaluations will be reported in the {@link org.nd4j.autodiff.listeners.records.History} object returned by fit.
 *
 * @param validation   Whether to add these evaluations as validation or training
 * @param variableName The variable to evaluate
 * @param labelIndex   The index of the label to evaluate against
 * @param evaluations  The evaluations to run
 */
public Builder addEvaluations(boolean validation, @NonNull String variableName, int labelIndex, @NonNull IEvaluation... evaluations){
    return validation
            ? validationEvaluation(variableName, labelIndex, evaluations)
            : trainEvaluation(variableName, labelIndex, evaluations);
}
public TrainingConfig build(){
if(!skipValidation) {
Preconditions.checkState(updater != null, "Updater (optimizer) must not be null. Use updater(IUpdater) to set an updater");
@ -374,10 +497,20 @@ public class TrainingConfig {
Preconditions.checkState(markLabelsUnused || dataSetLabelMapping != null, "No DataSet label mapping has been provided. A " +
"mapping between DataSet array positions and variables/placeholders must be provided - use dataSetLabelMapping(...) to set this," +
" or use markLabelsUnused() to mark labels as unused (for example, for unsupervised learning)");
Preconditions.checkArgument(trainEvaluations.keySet().equals(trainEvaluationLabels.keySet()),
"Must specify a label index for each train evaluation. Expected: %s, got: %s",
trainEvaluations.keySet(), trainEvaluationLabels.keySet());
Preconditions.checkArgument(validationEvaluations.keySet().equals(validationEvaluationLabels.keySet()),
"Must specify a label index for each validation evaluation. Expected: %s, got: %s",
validationEvaluations.keySet(), validationEvaluationLabels.keySet());
}
return new TrainingConfig(updater, regularization, minimize, dataSetFeatureMapping, dataSetLabelMapping,
dataSetFeatureMaskMapping, dataSetLabelMaskMapping, lossVariables);
dataSetFeatureMaskMapping, dataSetLabelMaskMapping, lossVariables,
trainEvaluations, trainEvaluationLabels, validationEvaluations, validationEvaluationLabels);
}
}

View File

@ -0,0 +1,152 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.autodiff.samediff.config;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
/**
 * Configuration for a single batch {@link SameDiff} inference operation.
 *
 * Used in {@link SameDiff#batchOutput()}.
 */
@Getter
@Setter
public class BatchOutputConfig {

    // The graph this operation runs on; fixed at construction time
    @Setter(AccessLevel.NONE)
    private SameDiff sd;

    @NonNull
    private List<String> outputs = new ArrayList<>();
    private Map<String, INDArray> placeholders = new HashMap<>();
    @NonNull
    private List<Listener> listeners = new ArrayList<>();

    public BatchOutputConfig(@NonNull SameDiff sd){
        this.sd = sd;
    }

    /**
     * Add required outputs, by name
     */
    public BatchOutputConfig output(@NonNull String... outputs){
        for(String out : outputs){
            this.outputs.add(out);
        }
        return this;
    }

    /**
     * Add required outputs, by variable
     */
    public BatchOutputConfig output(@NonNull SDVariable... outputs){
        List<String> names = new ArrayList<>(outputs.length);
        for(SDVariable v : outputs){
            names.add(v.getVarName());
        }
        return output(names.toArray(new String[0]));
    }

    /**
     * Add all of the graph's variables as required outputs
     */
    public BatchOutputConfig outputAll(){
        List<SDVariable> vars = sd.variables();
        return output(vars.toArray(new SDVariable[0]));
    }

    /**
     * Add a placeholder value for a specified variable.
     * The variable must exist in the graph and must not already have a placeholder set.
     */
    public BatchOutputConfig input(@NonNull String variable, @NonNull INDArray placeholder){
        Preconditions.checkState(!placeholders.containsKey(variable),
                "Placeholder for variable %s already specified", variable);

        Preconditions.checkNotNull(sd.getVariable(variable),
                "Variable %s does not exist in this SameDiff graph", variable);

        placeholders.put(variable, placeholder);
        return this;
    }

    /**
     * See {@link #input(String, INDArray)}
     */
    public BatchOutputConfig input(@NonNull SDVariable variable, @NonNull INDArray placeholder){
        String name = variable.getVarName();
        return input(name, placeholder);
    }

    /**
     * Calls {@link #input(String, INDArray)} on each entry in the map.
     * A null map clears the placeholders entirely.
     */
    public BatchOutputConfig inputs(Map<String, INDArray> placeholders){
        if(placeholders == null) {
            this.placeholders = null;
            return this;
        }

        for(Map.Entry<String, INDArray> entry : placeholders.entrySet()){
            input(entry.getKey(), entry.getValue());
        }

        return this;
    }

    /**
     * Add listeners for this operation
     */
    public BatchOutputConfig listeners(@NonNull Listener... listeners){
        for(Listener l : listeners){
            this.listeners.add(l);
        }
        return this;
    }

    /**
     * Do inference and return the results, keyed by output name
     */
    public Map<String, INDArray> exec(){
        String[] requested = outputs.toArray(new String[0]);
        return sd.output(placeholders, listeners, requested);
    }

    /**
     * Do inference and return the results for the single output
     *
     * Only works if exactly one output is specified
     */
    public INDArray execSingle(){
        Preconditions.checkState(outputs.size() == 1,
                "Can only use execSingle() when exactly one output is specified, there were %s", outputs.size());

        Map<String, INDArray> results = exec();
        return results.get(outputs.get(0));
    }
}

View File

@ -0,0 +1,202 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.autodiff.samediff.config;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.listeners.records.EvaluationRecord;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
/**
* Configuration for a {@link SameDiff} evaluation operation.
*
* Used in {@link SameDiff#evaluate()}.
*/
@Getter
@Setter
public class EvaluationConfig {
@NonNull
private Map<String, List<IEvaluation>> evaluations = new HashMap<>();
@NonNull
private Map<String, Integer> labelIndices = new HashMap<>();
private MultiDataSetIterator data;
@NonNull
private List<Listener> listeners = new ArrayList<>();
private boolean singleInput = false;
@Setter(AccessLevel.NONE)
private SameDiff sd;
public EvaluationConfig(@NonNull SameDiff sd){
this.sd = sd;
}
/**
 * Add evaluations to be performed on a specified variable, and set that variable's label index.
 *
 * Setting a label index is required if using a MultiDataSetIterator.
 *
 * @param param       The param to evaluate
 * @param labelIndex  The label index of that parameter
 * @param evaluations The evaluations to perform
 */
public EvaluationConfig evaluate(@NonNull String param, int labelIndex, @NonNull IEvaluation... evaluations){
    // Register the evaluations, then attach the label index to the same param
    return evaluate(param, evaluations).labelIndex(param, labelIndex);
}

/**
 * See {@link #evaluate(String, int, IEvaluation[])}
 */
public EvaluationConfig evaluate(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations){
    String name = variable.getVarName();
    return evaluate(name, labelIndex, evaluations);
}

/**
 * Add evaluations to be performed on a specified variable, without setting a label index.
 *
 * Setting a label index (which is not done here) is required if using a MultiDataSetIterator.
 *
 * @param param       The param to evaluate
 * @param evaluations The evaluations to perform
 */
public EvaluationConfig evaluate(@NonNull String param, @NonNull IEvaluation... evaluations){
    List<IEvaluation> forParam = this.evaluations.get(param);
    if(forParam == null){
        forParam = new ArrayList<IEvaluation>();
        this.evaluations.put(param, forParam);
    }
    forParam.addAll(Arrays.asList(evaluations));
    return this;
}

/**
 * See {@link #evaluate(String, IEvaluation[])}
 */
public EvaluationConfig evaluate(@NonNull SDVariable variable, @NonNull IEvaluation... evaluations){
    String name = variable.getVarName();
    return evaluate(name, evaluations);
}
/**
* Set the label index for a parameter
*/
public EvaluationConfig labelIndex(@NonNull String param, int labelIndex){
if(this.labelIndices.get(param) != null){
int existingIndex = this.labelIndices.get(param);
Preconditions.checkArgument(existingIndex == labelIndex,
"Different label index already specified for param %s. Already specified: %s, given: %s",
param, existingIndex, labelIndex);
}
labelIndices.put(param, labelIndex);
return this;
}
/**
* See {@link #labelIndex(String, int)}
*/
public EvaluationConfig labelIndex(@NonNull SDVariable variable, int labelIndex){
return labelIndex(variable.getVarName(), labelIndex);
}
/**
* Add listeners for this operation
*/
public EvaluationConfig listeners(@NonNull Listener... listeners){
this.listeners.addAll(Arrays.asList(listeners));
return this;
}
/**
* Set the data to evaluate on.
*
* Setting a label index for each variable to evaluate is required
*/
public EvaluationConfig data(@NonNull MultiDataSetIterator data){
this.data = data;
singleInput = false;
return this;
}
/**
* Set the data to evaluate on.
*
* Setting a label index for each variable to evaluate is NOT required (since there is only one input)
*/
public EvaluationConfig data(@NonNull DataSetIterator data){
this.data = new MultiDataSetIteratorAdapter(data);
singleInput = true;
return this;
}
private void validateConfig(){
Preconditions.checkNotNull(data, "Must specify data. It may not be null.");
if(!singleInput){
for(String param : this.evaluations.keySet()){
Preconditions.checkState(labelIndices.containsKey(param),
"Using multiple input dataset iterator without specifying a label index for %s", param);
}
}
for(String param : this.evaluations.keySet()){
Preconditions.checkState(sd.variableMap().containsKey(param),
"Parameter %s not present in this SameDiff graph", param);
}
}
/**
* Run the evaluation.
*
* Note that the evaluations in the returned {@link EvaluationRecord} are the evaluations set using {@link #evaluate(String, int, IEvaluation[])},
* it does not matter which you use to access results.
*
* @return The specified listeners, in an {@link EvaluationRecord} for easy access.
*/
public EvaluationRecord exec(){
validateConfig();
if(singleInput){
for(String param : this.evaluations.keySet()){
labelIndices.put(param, 0);
}
}
sd.evaluate(data, evaluations, labelIndices, listeners.toArray(new Listener[0]));
return new EvaluationRecord(evaluations);
}
}

View File

@ -0,0 +1,176 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.autodiff.samediff.config;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import org.nd4j.autodiff.listeners.records.History;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
/**
 * Configuration for a {@link SameDiff} training operation.
 * <p>
 * Used in {@link SameDiff#fit()}.
 */
@Getter
@Setter
public class FitConfig {

    @Setter(AccessLevel.NONE)
    private SameDiff sd;
    private MultiDataSetIterator trainingData;
    private MultiDataSetIterator validationData = null;
    // -1 means "not yet set"; validateConfig() rejects non-positive values
    private int epochs = -1;
    private int validationFrequency = 1;
    @NonNull
    private List<Listener> listeners = new ArrayList<>();

    public FitConfig(@NonNull SameDiff sd) {
        this.sd = sd;
    }

    /**
     * Set the number of epochs to train for
     */
    public FitConfig epochs(int epochs) {
        this.epochs = epochs;
        return this;
    }

    /**
     * Set the training data
     */
    public FitConfig train(@NonNull MultiDataSetIterator trainingData) {
        this.trainingData = trainingData;
        return this;
    }

    /**
     * Set the training data
     */
    public FitConfig train(@NonNull DataSetIterator trainingData) {
        return train(new MultiDataSetIteratorAdapter(trainingData));
    }

    /**
     * Set the training data and number of epochs
     */
    public FitConfig train(@NonNull MultiDataSetIterator trainingData, int epochs) {
        return train(trainingData).epochs(epochs);
    }

    /**
     * Set the training data and number of epochs
     */
    public FitConfig train(@NonNull DataSetIterator trainingData, int epochs) {
        return train(trainingData).epochs(epochs);
    }

    /**
     * Set the validation data.  May be null (no validation).
     */
    public FitConfig validate(MultiDataSetIterator validationData) {
        this.validationData = validationData;
        return this;
    }

    /**
     * Set the validation data.  May be null (no validation).
     */
    public FitConfig validate(DataSetIterator validationData) {
        return validationData == null
                ? validate((MultiDataSetIterator) null)
                : validate(new MultiDataSetIteratorAdapter(validationData));
    }

    /**
     * Set the validation frequency.  Validation will be performed once every so many epochs.
     * <p>
     * Specifically, validation will be performed when i % validationFrequency == 0
     */
    public FitConfig validationFrequency(int validationFrequency) {
        this.validationFrequency = validationFrequency;
        return this;
    }

    /**
     * Set the validation data and frequency
     * <p>
     * Specifically, validation will be performed when i % validationFrequency == 0
     */
    public FitConfig validate(MultiDataSetIterator validationData, int validationFrequency) {
        return validate(validationData).validationFrequency(validationFrequency);
    }

    /**
     * Set the validation data and frequency
     * <p>
     * Specifically, validation will be performed when i % validationFrequency == 0
     */
    public FitConfig validate(DataSetIterator validationData, int validationFrequency) {
        return validate(validationData).validationFrequency(validationFrequency);
    }

    /**
     * Add listeners for this operation
     */
    public FitConfig listeners(@NonNull Listener... listeners) {
        Collections.addAll(this.listeners, listeners);
        return this;
    }

    // Fail fast on an unusable configuration: training data and epoch count are
    // mandatory; validation frequency must be positive only when validating.
    private void validateConfig() {
        Preconditions.checkNotNull(trainingData, "Training data must not be null");
        Preconditions.checkState(epochs > 0, "Epochs must be > 0, got %s", epochs);

        if (validationData != null) {
            Preconditions.checkState(validationFrequency > 0,
                    "Validation Frequency must be > 0 if validation data is given, got %s", validationFrequency);
        }
    }

    /**
     * Do the training.
     *
     * @return a {@link History} object containing the history information for this training operation
     * (evaluations specified in the {@link TrainingConfig}, loss values, and timing information).
     */
    public History exec() {
        validateConfig();
        return sd.fit(trainingData, epochs, validationData, validationFrequency, listeners.toArray(new Listener[0]));
    }
}

View File

@ -0,0 +1,154 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.autodiff.samediff.config;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.util.TrainingUtils;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
/**
 * Configuration for a {@link SameDiff} inference operation.
 *
 * Used in {@link SameDiff#output()}.
 */
@Getter
@Setter
public class OutputConfig {

    @Setter(AccessLevel.NONE)
    private SameDiff sd;
    // Names of the variables whose values should be returned
    @NonNull
    private List<String> outputs = new ArrayList<>();
    @NonNull
    private List<Listener> listeners = new ArrayList<>();
    private MultiDataSetIterator data;

    public OutputConfig(@NonNull SameDiff sd) {
        this.sd = sd;
    }

    /**
     * Add required outputs
     */
    public OutputConfig output(@NonNull String... outputs) {
        Collections.addAll(this.outputs, outputs);
        return this;
    }

    /**
     * Add required outputs
     */
    public OutputConfig output(@NonNull SDVariable... outputs) {
        List<String> names = new ArrayList<>(outputs.length);
        for (SDVariable variable : outputs) {
            names.add(variable.getVarName());
        }
        this.outputs.addAll(names);
        return this;
    }

    /**
     * Set the data to use as input.
     */
    public OutputConfig data(@NonNull MultiDataSetIterator data) {
        this.data = data;
        return this;
    }

    /**
     * Set the data to use as input.
     */
    public OutputConfig data(@NonNull DataSetIterator data) {
        this.data = new MultiDataSetIteratorAdapter(data);
        return this;
    }

    /**
     * Add listeners for this operation
     */
    public OutputConfig listeners(@NonNull Listener... listeners) {
        Collections.addAll(this.listeners, listeners);
        return this;
    }

    // Input data is the only hard requirement; output names are checked by SameDiff itself
    private void validateConfig() {
        Preconditions.checkNotNull(data, "Must specify data.  It may not be null.");
    }

    /**
     * Do inference and return the results.
     *
     * Uses concatenation on the outputs of {@link #execBatches()} which may cause issues with some inputs. RNNs with
     * variable time series length and CNNs with variable image sizes will most likely have issues.
     */
    public Map<String, INDArray> exec() {
        return sd.output(data, listeners, outputs.toArray(new String[0]));
    }

    /**
     * Do inference and return the results in batches.
     */
    public List<Map<String, INDArray>> execBatches() {
        return sd.outputBatches(data, listeners, outputs.toArray(new String[0]));
    }

    /**
     * Do inference and return the results for the single output variable specified.
     *
     * Only works if exactly one output is specified.
     *
     * Uses concatenation on the outputs of {@link #execBatches()} which may cause issues with some inputs. RNNs with
     * variable time series length and CNNs with variable image sizes will most likely have issues.
     */
    public INDArray execSingle() {
        Preconditions.checkState(outputs.size() == 1,
                "Can only use execSingle() when exactly one output is specified, there were %s", outputs.size());

        String name = outputs.get(0);
        return sd.output(data, listeners, outputs.toArray(new String[0])).get(name);
    }

    /**
     * Do inference and return the results (in batches) for the single output variable specified.
     *
     * Only works if exactly one output is specified.
     */
    public List<INDArray> execSingleBatches() {
        Preconditions.checkState(outputs.size() == 1,
                "Can only use execSingleBatches() when exactly one output is specified, there were %s", outputs.size());

        List<Map<String, INDArray>> batches = sd.outputBatches(data, listeners, outputs.toArray(new String[0]));
        return TrainingUtils.getSingleOutput(batches, outputs.get(0));
    }
}

View File

@ -24,6 +24,7 @@ import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
@ -31,6 +32,7 @@ import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.*;
import java.util.*;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
/**
@ -133,16 +135,42 @@ public abstract class AbstractSession<T, O> {
return newVarId(variable, frameIter.getFrame(), frameIter.getIteration(), frameIter.getParentFrame());
}
/**
* @deprecated Use {@link #output(List, Map, MultiDataSet, Collection, List, At)}.
*
* @param training Uses Operation.TRAINING if true, otherwise Operation.INFERENCE
*/
@Deprecated
public Map<String, T> output(@NonNull List<String> variables, Map<String, T> placeholderValues,
MultiDataSet batch, Collection<String> requiredActivations, boolean training, At at){
if(at == null){
if(training)
at = At.defaultAt(Operation.TRAINING);
else
at = At.defaultAt(Operation.INFERENCE);
}
return output(variables, placeholderValues, batch, requiredActivations, Collections.<Listener>emptyList(), at);
}
/**
* Get the output of the session - i.e., perform inference/forward pass
*
* @param variables Name of the variables we want the arrays/activations for
* @param placeholderValues The placeholder values (if any).
* @param batch The batch data, used to call Listener.opExecution
* @param requiredActivations Additional activations that are required.  Won't be output, but opExecution will be called.  May be null.
* @return The specified variable values, optionally in the specified workspace
*/
public Map<String, T> output(@NonNull List<String> variables, Map<String, T> placeholderValues, List<Listener> listeners, boolean training, At at) {
public Map<String, T> output(@NonNull List<String> variables, Map<String, T> placeholderValues,
MultiDataSet batch, Collection<String> requiredActivations, List<Listener> listeners, At at) {
Preconditions.checkState(!variables.isEmpty(), "Variables to perform forward pass for must not be empty");
if(requiredActivations == null)
requiredActivations = Collections.emptyList();
if(at == null)
at = At.defaultAt();
//Step 0: validation - that variables exist, placeholders have arrays, etc
for (String s : variables) {
@ -164,7 +192,9 @@ public abstract class AbstractSession<T, O> {
//Step 1: determine subgraph structure we actually need to execute
//Basic plan: work backwards from the variables we want, based on the graph structure, to work out what
// we actually need to execute
initSubgraph(variables);
List<String> allRequired = new ArrayList<>(requiredActivations);
allRequired.addAll(variables);
initSubgraph(allRequired);
//Step 1a: Check that we have required placeholders
List<String> phNames = sameDiff.inputs();
@ -198,7 +228,7 @@ public abstract class AbstractSession<T, O> {
// Some Keras layers (like GRU) do different things depending on whether the model is training.
// We provide this value directly.
if(s.endsWith("keras_learning_phase")){
placeholderValues.put(s, (T) Nd4j.scalar(training));
placeholderValues.put(s, (T) Nd4j.scalar(at.operation().isTrainingPhase()));
} else {
throw new IllegalStateException(
"An input placeholder \"" + s + "\" is required to calculate the requested outputs," +
@ -302,7 +332,7 @@ public abstract class AbstractSession<T, O> {
//Execute op
FrameIter frameIter = varToExec.toFrameIter();
O parameterizedOp = getAndParameterizeOp(opName, frameIter, inputsToVar, inputsToVarAllIter, constPhForVar, placeholderValues);
T[] opOutputValues = getOutputs(parameterizedOp, frameIter, inputsToVar, inputsToVarAllIter, constPhForVar, listeners, training, at);
T[] opOutputValues = getOutputs(parameterizedOp, frameIter, inputsToVar, inputsToVarAllIter, constPhForVar, listeners, at, batch);
//Post execution: work out what is now available for exec
@ -831,7 +861,7 @@ public abstract class AbstractSession<T, O> {
* @return The outputs of the op
*/
public abstract T[] getOutputs(O op, FrameIter outputFrameIter, Set<VarId> inputs, Set<VarId> allIterInputs, Set<String> constAndPhInputs,
List<Listener> listeners, boolean training, At at);
List<Listener> listeners, At at, MultiDataSet batch);
/**
* This method is used to record that the specified input is required for calculating the specified output.

View File

@ -30,6 +30,7 @@ import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.nd4j.linalg.dataset.api.MultiDataSet;
/**
* Infer datatypes for all variables.
@ -80,7 +81,7 @@ public class DataTypesSession extends AbstractSession<DataType, DataTypesSession
@Override
public DataType[] getOutputs(DataTypeCalc op, FrameIter outputFrameIter, Set<VarId> inputs, Set<VarId> allIterInputs,
Set<String> constAndPhInputs, List<Listener> listeners, boolean training, At at) {
Set<String> constAndPhInputs, List<Listener> listeners, At at, MultiDataSet batch) {
List<DataType> outTypes = op.getFn().calculateOutputDataTypes(op.getInputTypes());
if(dynamicUpdate) {

View File

@ -16,6 +16,7 @@
package org.nd4j.autodiff.samediff.internal;
import com.google.common.collect.ImmutableMap;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.functions.DifferentialFunction;
@ -38,6 +39,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
import org.nd4j.linalg.api.ops.impl.transforms.same.Identity;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
@ -106,19 +108,34 @@ public class InferenceSession extends AbstractSession<INDArray,DifferentialFunct
@Override
public INDArray[] getOutputs(DifferentialFunction op, FrameIter outputFrameIter, Set<VarId> opInputs, Set<VarId> allIterInputs,
Set<String> constAndPhInputs, List<Listener> listeners, boolean training, At at) {
Set<String> constAndPhInputs, List<Listener> listeners, At at, MultiDataSet batch) {
if(listeners != null && listeners.size() > 0){
SameDiffOp sdOp = sameDiff.getOps().get(op.getOwnName());
for(Listener l : listeners){
l.preOpExecution(sameDiff, at, training, sdOp);
if(l.isActive(at.operation()))
l.preOpExecution(sameDiff, at, sdOp);
}
}
INDArray[] out = getOutputsHelper(op, outputFrameIter, opInputs, allIterInputs, constAndPhInputs);
if(listeners != null && listeners.size() > 0){
SameDiffOp sdOp = sameDiff.getOps().get(op.getOwnName());
ImmutableMap.Builder<String, INDArray> namedOutsBuilder = ImmutableMap.builder();
for(int i = 0 ; i < out.length ; i++)
namedOutsBuilder.put(sdOp.outputsOfOp.get(i), out[i]);
Map<String, INDArray> namedOuts = namedOutsBuilder.build();
for(Listener l : listeners){
l.opExecution(sameDiff, at, training, sdOp, out);
if(l.isActive(at.operation())) {
l.opExecution(sameDiff, at, batch, sdOp, out);
for(String varName : namedOuts.keySet()){
l.activationAvailable(sameDiff, at, batch, sdOp, varName, namedOuts.get(varName));
}
}
}
}
return out;

View File

@ -5,12 +5,14 @@ import lombok.NoArgsConstructor;
import lombok.Setter;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.List;
import org.nd4j.linalg.dataset.api.MultiDataSet;
/**
* A listener used for debugging and testing purposes - specifically for gradient checking activations internally in
@ -29,7 +31,12 @@ public class ActivationGradientCheckListener extends BaseListener {
private double eps;
@Override
public void opExecution(SameDiff sd, At at, boolean training, SameDiffOp op, INDArray[] outputs) {
public boolean isActive(Operation operation) {
return true;
}
@Override
public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
Preconditions.checkState(variableName != null, "No variable name has been set yet. Variable name must be set before using this listener");
Preconditions.checkState(eps != 0.0, "Epsilon has not been set");

View File

@ -1,9 +1,9 @@
package org.nd4j.autodiff.validation.listeners;
import lombok.Getter;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.base.Preconditions;
@ -14,6 +14,7 @@ import org.nd4j.linalg.api.ops.Op;
import java.security.MessageDigest;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicInteger;
import org.nd4j.linalg.dataset.api.MultiDataSet;
public class NonInplaceValidationListener extends BaseListener {
@Getter
@ -30,7 +31,7 @@ public class NonInplaceValidationListener extends BaseListener {
}
@Override
public void preOpExecution(SameDiff sd, At at, boolean training, SameDiffOp op) {
public void preOpExecution(SameDiff sd, At at, SameDiffOp op) {
if(op.getOp().isInPlace()){
//Don't check inplace op
return;
@ -57,7 +58,7 @@ public class NonInplaceValidationListener extends BaseListener {
}
@Override
public void opExecution(SameDiff sd, At at, boolean training, SameDiffOp op, INDArray[] outputs) {
public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
if(op.getOp().isInPlace()){
//Don't check inplace op
return;
@ -124,4 +125,8 @@ public class NonInplaceValidationListener extends BaseListener {
}
}
@Override
public boolean isActive(Operation operation) {
return true;
}
}

View File

@ -99,4 +99,14 @@ public interface IEvaluation<T extends IEvaluation> extends Serializable {
* @return
*/
String toYaml();
/**
* Get the value of a given metric for this evaluation.
*/
double getValue(IMetric metric);
/**
* Get a new instance of this evaluation, with the same configuration but no data.
*/
T newInstance();
}

View File

@ -0,0 +1,35 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.evaluation;
/**
 * A metric used to get a double value from an {@link IEvaluation}.
 *
 * Examples: {@link org.nd4j.evaluation.classification.Evaluation.Metric#ACCURACY}, {@link org.nd4j.evaluation.classification.ROC.Metric#AUPRC}.
 */
public interface IMetric {

    /**
     * The {@link IEvaluation} class this metric is for.
     *
     * @return The evaluation class that can compute a value for this metric
     */
    // Interface members are implicitly public - redundant modifiers removed (JLS 9.4)
    Class<? extends IEvaluation> getEvaluationClass();

    /**
     * Whether this metric should be minimized (aka whether lower values are better).
     *
     * @return true if lower values of this metric are better, false otherwise
     */
    boolean minimize();
}

View File

@ -22,7 +22,11 @@ import org.nd4j.base.Preconditions;
import org.nd4j.evaluation.BaseEvaluation;
import org.nd4j.evaluation.EvaluationAveraging;
import org.nd4j.evaluation.EvaluationUtils;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.IMetric;
import org.nd4j.evaluation.meta.Prediction;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
import org.nd4j.evaluation.serde.ConfusionMatrixDeserializer;
import org.nd4j.evaluation.serde.ConfusionMatrixSerializer;
import org.nd4j.linalg.api.buffer.DataType;
@ -83,7 +87,18 @@ import java.util.*;
@JsonIgnoreProperties({"confusionMatrixMetaData"})
public class Evaluation extends BaseEvaluation<Evaluation> {
public enum Metric {ACCURACY, F1, PRECISION, RECALL, GMEASURE, MCC}
public enum Metric implements IMetric {ACCURACY, F1, PRECISION, RECALL, GMEASURE, MCC;
@Override
public Class<? extends IEvaluation> getEvaluationClass() {
return Evaluation.class;
}
@Override
public boolean minimize() {
return false;
}
}
//What to output from the precision/recall function when we encounter an edge case
protected static final double DEFAULT_EDGE_VALUE = 0.0;
@ -122,6 +137,17 @@ public class Evaluation extends BaseEvaluation<Evaluation> {
@Getter @Setter
protected int maxWarningClassesToPrint = 16;
protected Evaluation(int axis, Integer binaryPositiveClass, int topN, List<String> labelsList,
Double binaryDecisionThreshold, INDArray costArray, int maxWarningClassesToPrint){
this.axis = axis;
this.binaryPositiveClass = binaryPositiveClass;
this.topN = topN;
this.labelsList = labelsList;
this.binaryDecisionThreshold = binaryDecisionThreshold;
this.costArray = costArray;
this.maxWarningClassesToPrint = maxWarningClassesToPrint;
}
// Empty constructor
public Evaluation() {
this.topN = 1;
@ -190,6 +216,7 @@ public class Evaluation extends BaseEvaluation<Evaluation> {
if (labels != null) {
createConfusion(labels.size());
}
this.topN = topN;
if(labels != null && labels.size() == 2){
this.binaryPositiveClass = 1;
@ -1869,4 +1896,17 @@ public class Evaluation extends BaseEvaluation<Evaluation> {
public static Evaluation fromYaml(String yaml) {
return fromYaml(yaml, Evaluation.class);
}
@Override
public double getValue(IMetric metric){
if(metric instanceof Metric){
return scoreForMetric((Metric) metric);
} else
throw new IllegalStateException("Can't get value for non-evaluation Metric " + metric);
}
@Override
public Evaluation newInstance() {
return new Evaluation(axis, binaryPositiveClass, topN, labelsList, binaryDecisionThreshold, costArray, maxWarningClassesToPrint);
}
}

View File

@ -21,6 +21,9 @@ import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import org.nd4j.evaluation.BaseEvaluation;
import org.nd4j.evaluation.EvaluationUtils;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.IMetric;
import org.nd4j.evaluation.classification.Evaluation.Metric;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThan;
@ -58,7 +61,18 @@ import java.util.List;
@Data
public class EvaluationBinary extends BaseEvaluation<EvaluationBinary> {
public enum Metric {ACCURACY, F1, PRECISION, RECALL, GMEASURE, MCC, FAR}
public enum Metric implements IMetric {ACCURACY, F1, PRECISION, RECALL, GMEASURE, MCC, FAR;
@Override
public Class<? extends IEvaluation> getEvaluationClass() {
return EvaluationBinary.class;
}
@Override
public boolean minimize() {
return false;
}
}
public static final int DEFAULT_PRECISION = 4;
public static final double DEFAULT_EDGE_VALUE = 0.0;
@ -80,6 +94,13 @@ public class EvaluationBinary extends BaseEvaluation<EvaluationBinary> {
@JsonDeserialize(using = NDArrayTextDeSerializer.class)
private INDArray decisionThreshold;
protected EvaluationBinary(int axis, ROCBinary rocBinary, List<String> labels, INDArray decisionThreshold){
this.axis = axis;
this.rocBinary = rocBinary;
this.labels = labels;
this.decisionThreshold = decisionThreshold;
}
/**
* Create an EvaulationBinary instance with an optional decision threshold array.
*
@ -452,10 +473,25 @@ public class EvaluationBinary extends BaseEvaluation<EvaluationBinary> {
}
/**
* Calculate the G-measure for the given output
* Macro average of the Matthews correlation coefficient (MCC) (see {@link #matthewsCorrelation(int)}) for all labels.
*
* @return The macro average of the MCC for all labels.
*/
public double averageMatthewsCorrelation() {
double ret = 0.0;
for (int i = 0; i < numLabels(); i++) {
ret += matthewsCorrelation(i);
}
ret /= (double) numLabels();
return ret;
}
/**
* Calculate the macro average G-measure for the given output
*
* @param output The specified output
* @return The G-measure for the specified output
* @return The macro average of the G-measure for the specified output
*/
public double gMeasure(int output) {
double precision = precision(output);
@ -463,6 +499,21 @@ public class EvaluationBinary extends BaseEvaluation<EvaluationBinary> {
return EvaluationUtils.gMeasure(precision, recall);
}
/**
* Average G-measure (see {@link #gMeasure(int)}) for all labels.
*
* @return The G-measure for all labels.
*/
public double averageGMeasure() {
double ret = 0.0;
for (int i = 0; i < numLabels(); i++) {
ret += gMeasure(i);
}
ret /= (double) numLabels();
return ret;
}
/**
* Returns the false positive rate for a given label
*
@ -679,5 +730,37 @@ public class EvaluationBinary extends BaseEvaluation<EvaluationBinary> {
return fromYaml(yaml, EvaluationBinary.class);
}
@Override
public double getValue(IMetric metric){
if(metric instanceof Metric){
switch ((Metric) metric){
case ACCURACY:
return averageAccuracy();
case F1:
return averageF1();
case PRECISION:
return averagePrecision();
case RECALL:
return averageRecall();
case GMEASURE:
return averageGMeasure();
case MCC:
return averageMatthewsCorrelation();
case FAR:
return averageFalseAlarmRate();
default:
throw new IllegalStateException("Can't get value for non-binary evaluation Metric " + metric);
}
} else
throw new IllegalStateException("Can't get value for non-binary evaluation Metric " + metric);
}
@Override
public EvaluationBinary newInstance() {
if(rocBinary != null) {
return new EvaluationBinary(axis, rocBinary.newInstance(), labels, decisionThreshold);
} else {
return new EvaluationBinary(axis, null, labels, decisionThreshold);
}
}
}

View File

@ -21,6 +21,8 @@ import lombok.Getter;
import lombok.val;
import org.nd4j.base.Preconditions;
import org.nd4j.evaluation.BaseEvaluation;
import org.nd4j.evaluation.IMetric;
import org.nd4j.evaluation.classification.Evaluation.Metric;
import org.nd4j.evaluation.curves.Histogram;
import org.nd4j.evaluation.curves.ReliabilityDiagram;
import org.nd4j.linalg.api.buffer.DataType;
@ -105,6 +107,13 @@ public class EvaluationCalibration extends BaseEvaluation<EvaluationCalibration>
@JsonDeserialize(using = NDArrayDeSerializer.class)
private INDArray probHistogramByLabelClass; //Histogram - for each label class separately
protected EvaluationCalibration(int axis, int reliabilityDiagNumBins, int histogramNumBins, boolean excludeEmptyBins) {
this.axis = axis;
this.reliabilityDiagNumBins = reliabilityDiagNumBins;
this.histogramNumBins = histogramNumBins;
this.excludeEmptyBins = excludeEmptyBins;
}
/**
* Create an EvaluationCalibration instance with the default number of bins
*/
@ -476,4 +485,14 @@ public class EvaluationCalibration extends BaseEvaluation<EvaluationCalibration>
public static EvaluationCalibration fromJson(String json){
return fromJson(json, EvaluationCalibration.class);
}
@Override
public double getValue(IMetric metric){
throw new IllegalStateException("Can't get value for non-calibration Metric " + metric);
}
@Override
public EvaluationCalibration newInstance() {
return new EvaluationCalibration(axis, reliabilityDiagNumBins, histogramNumBins, excludeEmptyBins);
}
}

View File

@ -20,6 +20,9 @@ import lombok.*;
import org.apache.commons.lang3.ArrayUtils;
import org.nd4j.base.Preconditions;
import org.nd4j.evaluation.BaseEvaluation;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.IMetric;
import org.nd4j.evaluation.classification.Evaluation.Metric;
import org.nd4j.evaluation.curves.PrecisionRecallCurve;
import org.nd4j.evaluation.curves.RocCurve;
import org.nd4j.evaluation.serde.ROCSerializer;
@ -77,11 +80,24 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
@JsonSerialize(using = ROCSerializer.class)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY)
public class ROC extends BaseEvaluation<ROC> {
/**
* AUROC: Area under ROC curve<br>
* AUPRC: Area under Precision-Recall Curve
*/
public enum Metric {AUROC, AUPRC}
public enum Metric implements IMetric {
    AUROC, AUPRC;

    /** Both areas-under-curve are maximized, never minimized. */
    @Override
    public boolean minimize() {
        return false;
    }

    /** The evaluation class these metrics are computed by. */
    @Override
    public Class<? extends IEvaluation> getEvaluationClass() {
        return ROC.class;
    }
}
private static final int DEFAULT_EXACT_ALLOC_BLOCK_SIZE = 2048;
private final Map<Double, CountsForThreshold> counts = new LinkedHashMap<>();
@ -100,6 +116,13 @@ public class ROC extends BaseEvaluation<ROC> {
private int exactAllocBlockSize;
protected int axis = 1;
/**
 * Create a ROC with the given configuration, additionally specifying the axis.
 * Delegates to the 3-arg constructor, then overrides the axis field.
 */
public ROC(int thresholdSteps, boolean rocRemoveRedundantPts, int exactAllocBlockSize, int axis) {
    this(thresholdSteps, rocRemoveRedundantPts, exactAllocBlockSize);
    this.axis = axis;
}
public ROC() {
//Default to exact
this(0);
@ -811,4 +834,17 @@ public class ROC extends BaseEvaluation<ROC> {
throw new IllegalStateException("Unknown metric: " + metric);
}
}
@Override
public double getValue(IMetric metric){
    // Guard clause: only ROC.Metric values can be scored here
    if (!(metric instanceof Metric)) {
        throw new IllegalStateException("Can't get value for non-ROC Metric " + metric);
    }
    return scoreForMetric((Metric) metric);
}
@Override
public ROC newInstance() {
    // Empty clone sharing only this instance's configuration (no accumulated counts)
    return new ROC(thresholdSteps, rocRemoveRedundantPts, exactAllocBlockSize, axis);
}
}

View File

@ -21,6 +21,9 @@ import lombok.EqualsAndHashCode;
import lombok.val;
import org.nd4j.base.Preconditions;
import org.nd4j.evaluation.BaseEvaluation;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.IMetric;
import org.nd4j.evaluation.classification.ROCMultiClass.Metric;
import org.nd4j.evaluation.curves.PrecisionRecallCurve;
import org.nd4j.evaluation.curves.RocCurve;
import org.nd4j.evaluation.serde.ROCArraySerializer;
@ -53,7 +56,18 @@ public class ROCBinary extends BaseEvaluation<ROCBinary> {
* AUROC: Area under ROC curve<br>
* AUPRC: Area under Precision-Recall Curve
*/
public enum Metric {AUROC, AUPRC}
public enum Metric implements IMetric {
    AUROC, AUPRC;

    /** Both areas-under-curve are maximized, never minimized. */
    @Override
    public boolean minimize() {
        return false;
    }

    /** The evaluation class these metrics are computed by. */
    @Override
    public Class<? extends IEvaluation> getEvaluationClass() {
        return ROCBinary.class;
    }
}
@JsonSerialize(using = ROCArraySerializer.class)
private ROC[] underlying;
@ -65,6 +79,13 @@ public class ROCBinary extends BaseEvaluation<ROCBinary> {
@EqualsAndHashCode.Exclude //Exclude axis: otherwise 2 Evaluation instances could contain identical stats and fail equality
protected int axis = 1;
/**
 * Configuration-only constructor; used by {@code newInstance()} to clone settings.
 */
protected ROCBinary(int axis, int thresholdSteps, boolean rocRemoveRedundantPts, List<String> labels) {
    this.axis = axis;
    this.labels = labels;
    this.rocRemoveRedundantPts = rocRemoveRedundantPts;
    this.thresholdSteps = thresholdSteps;
}
public ROCBinary() {
    // 0 threshold steps - presumably "exact" mode as in ROC's default ctor; confirm
    this(0);
}
@ -410,4 +431,22 @@ public class ROCBinary extends BaseEvaluation<ROCBinary> {
throw new IllegalStateException("Unknown metric: " + metric);
}
}
@Override
public double getValue(IMetric metric){
    // Dispatch on the two supported ROCBinary metrics
    if(metric instanceof Metric){
        switch ((Metric) metric) {
            case AUPRC:
                return calculateAverageAUCPR();
            case AUROC:
                return calculateAverageAuc();
        }
    }
    // Not a ROCBinary.Metric (or an unhandled enum value)
    throw new IllegalStateException("Can't get value for non-binary ROC Metric " + metric);
}
@Override
public ROCBinary newInstance() {
    // Empty clone sharing only this instance's configuration (no accumulated counts)
    return new ROCBinary(axis, thresholdSteps, rocRemoveRedundantPts, labels);
}
}

View File

@ -20,6 +20,9 @@ import lombok.Data;
import lombok.EqualsAndHashCode;
import org.nd4j.base.Preconditions;
import org.nd4j.evaluation.BaseEvaluation;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.IMetric;
import org.nd4j.evaluation.classification.ROC.Metric;
import org.nd4j.evaluation.curves.PrecisionRecallCurve;
import org.nd4j.evaluation.curves.RocCurve;
import org.nd4j.evaluation.serde.ROCArraySerializer;
@ -49,7 +52,18 @@ public class ROCMultiClass extends BaseEvaluation<ROCMultiClass> {
* AUROC: Area under ROC curve<br>
* AUPRC: Area under Precision-Recall Curve
*/
public enum Metric {AUROC, AUPRC}
public enum Metric implements IMetric {
    AUROC, AUPRC;

    /** Both areas-under-curve are maximized, never minimized. */
    @Override
    public boolean minimize() {
        return false;
    }

    /** The evaluation class these metrics are computed by. */
    @Override
    public Class<? extends IEvaluation> getEvaluationClass() {
        return ROCMultiClass.class;
    }
}
private int thresholdSteps;
private boolean rocRemoveRedundantPts;
@ -60,6 +74,13 @@ public class ROCMultiClass extends BaseEvaluation<ROCMultiClass> {
@EqualsAndHashCode.Exclude //Exclude axis: otherwise 2 Evaluation instances could contain identical stats and fail equality
protected int axis = 1;
/**
 * Configuration-only constructor; used by {@code newInstance()} to clone settings.
 */
protected ROCMultiClass(int axis, int thresholdSteps, boolean rocRemoveRedundantPts, List<String> labels) {
    this.axis = axis;
    this.labels = labels;
    this.rocRemoveRedundantPts = rocRemoveRedundantPts;
    this.thresholdSteps = thresholdSteps;
}
public ROCMultiClass() {
//Default to exact
this(0);
@ -362,4 +383,22 @@ public class ROCMultiClass extends BaseEvaluation<ROCMultiClass> {
throw new IllegalStateException("Unknown metric: " + metric);
}
}
@Override
public double getValue(IMetric metric){
    // Dispatch on the two supported ROCMultiClass metrics
    if(metric instanceof Metric){
        switch ((Metric) metric) {
            case AUPRC:
                return calculateAverageAUCPR();
            case AUROC:
                return calculateAverageAUC();
        }
    }
    // Not a ROCMultiClass.Metric (or an unhandled enum value)
    throw new IllegalStateException("Can't get value for non-ROC Metric " + metric);
}
@Override
public ROCMultiClass newInstance() {
    // Empty clone sharing only this instance's configuration (no accumulated counts)
    return new ROCMultiClass(axis, thresholdSteps, rocRemoveRedundantPts, labels);
}
}

View File

@ -0,0 +1,182 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.evaluation.custom;
import com.google.common.collect.Lists;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import org.nd4j.evaluation.BaseEvaluation;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.IMetric;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
import org.nd4j.linalg.api.ndarray.INDArray;
/**
* A evaluation using lambdas to calculate the score.
*
* Uses 3 lambdas:<br>
* EvaluationLambda: takes in the labels, predictions, mask, and metadata and returns a value of type T<br>
* MergeLambda: takes in two lists of Ts, returns one. Used in merging for distributed training<br>
* ResultLambda (in Metric): takes a list of Ts, returns a double value<br>
* <br>
* The EvaluationLambda will be called on each batch, and the results will be stored in a list.
* MergeLambda merges two of those lists for distributed training (think Spark or Map-Reduce).
* ResultLambda gets a score from that list.
*
*/
@Data
@EqualsAndHashCode(callSuper = true)
public class CustomEvaluation<T> extends BaseEvaluation<CustomEvaluation> {
/**
* The metric used to get a score for the CustomEvaluation. Uses a ResultLambda
*/
@AllArgsConstructor
@RequiredArgsConstructor
public static class Metric<T> implements IMetric{
@Getter
@NonNull private ResultLambda<T> getResult;
private boolean minimize = false;
@Override
public Class<? extends IEvaluation> getEvaluationClass() {
return CustomEvaluation.class;
}
@Override
public boolean minimize() {
return minimize;
}
/**
* A metric that takes the average of a list of doubles
*/
public static Metric<Double> doubleAverage(boolean minimize){
return new Metric<>(new ResultLambda<Double>() {
@Override
public double toResult(List<Double> data) {
int count = 0;
double sum = 0;
for (Double d : data) {
count++;
sum += d;
}
return sum / count;
}
}, minimize);
}
/**
* A metric that takes the max of a list of doubles
*/
public static Metric<Double> doubleMax(boolean minimize){
return new Metric<>(new ResultLambda<Double>() {
@Override
public double toResult(List<Double> data) {
double max = 0;
for (Double d : data) {
if(d > max)
max = d;
}
return max;
}
}, minimize);
}
/**
* A metric that takes the min of a list of doubles
*/
public static Metric<Double> doubleMin(boolean minimize){
return new Metric<>(new ResultLambda<Double>() {
@Override
public double toResult(List<Double> data) {
double max = 0;
for (Double d : data) {
if(d < max)
max = d;
}
return max;
}
}, minimize);
}
}
/**
* A MergeLambda that merges by concatenating the two lists
*/
public static <R> MergeLambda<R> mergeConcatenate(){
return new MergeLambda<R>() {
@Override
public List<R> merge(List<R> a, List<R> b) {
List<R> res = Lists.newArrayList(a);
res.addAll(b);
return res;
}
};
}
@NonNull private EvaluationLambda<T> evaluationLambda;
@NonNull private MergeLambda<T> mergeLambda;
private List<T> evaluations = new ArrayList<>();
@Override
public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray,
List<? extends Serializable> recordMetaData) {
evaluations.add(evaluationLambda.eval(labels, networkPredictions, maskArray, recordMetaData));
}
@Override
public void merge(CustomEvaluation other) {
evaluations = mergeLambda.merge(evaluations, other.evaluations);
}
@Override
public void reset() {
evaluations = new ArrayList<>();
}
@Override
public String stats() {
return "";
}
@Override
public double getValue(IMetric metric) {
if(metric instanceof Metric){
return ((Metric<T>) metric).getGetResult().toResult(evaluations);
} else
throw new IllegalStateException("Can't get value for non-regression Metric " + metric);
}
@Override
public CustomEvaluation<T> newInstance() {
return new CustomEvaluation<T>(evaluationLambda, mergeLambda);
}
}

View File

@ -0,0 +1,31 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.evaluation.custom;
import java.io.Serializable;
import java.util.List;
import org.nd4j.linalg.api.ndarray.INDArray;
/**
* A lambda used to get an evaluation result for a batch
* See {@link CustomEvaluation}
*/
@FunctionalInterface // exactly one abstract method - enables lambda implementations
public interface EvaluationLambda<T> {
    /**
     * Compute the evaluation result for a single batch.
     *
     * @param labels             labels for the batch
     * @param networkPredictions network predictions for the batch
     * @param maskArray          mask array, may be null
     * @param recordMetaData     record metadata, may be null
     * @return the per-batch result accumulated by CustomEvaluation
     */
    T eval(INDArray labels, INDArray networkPredictions, INDArray maskArray,
            List<? extends Serializable> recordMetaData);
}

View File

@ -0,0 +1,29 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.evaluation.custom;
import com.google.common.collect.Lists;
import java.util.List;
/**
* A lambda used to merge two lists of evaluation results
* See {@link CustomEvaluation}
*/
public interface MergeLambda<T> {
public List<T> merge(List<T> a, List<T> b);
}

View File

@ -0,0 +1,27 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.evaluation.custom;
import java.util.List;
/**
* A lambda used to get a score from a list of evaluation results
* See {@link CustomEvaluation}
*/
@FunctionalInterface // exactly one abstract method - enables lambda implementations
public interface ResultLambda<T> {
    /**
     * Reduce the accumulated per-batch results to a single score.
     *
     * @param data list of per-batch results
     * @return the score
     */
    double toResult(List<T> data);
}

View File

@ -20,6 +20,8 @@ import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.val;
import org.nd4j.evaluation.BaseEvaluation;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.IMetric;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.reduce.same.ASum;
@ -54,12 +56,18 @@ import java.util.List;
@EqualsAndHashCode(callSuper = true)
public class RegressionEvaluation extends BaseEvaluation<RegressionEvaluation> {
public enum Metric { MSE, MAE, RMSE, RSE, PC, R2;
public enum Metric implements IMetric { MSE, MAE, RMSE, RSE, PC, R2;
@Override
public Class<? extends IEvaluation> getEvaluationClass() {
return RegressionEvaluation.class;
}
/**
* @return True if the metric should be minimized, or false if the metric should be maximized.
* For example, MSE of 0 is best, but R^2 of 1.0 is best
*/
@Override
public boolean minimize(){
if(this == R2 || this == PC){
return false;
@ -106,6 +114,12 @@ public class RegressionEvaluation extends BaseEvaluation<RegressionEvaluation> {
@JsonDeserialize(using = NDArrayTextDeSerializer.class)
private INDArray sumLabels;
/**
 * Configuration-only constructor; used by {@code newInstance()} to clone settings.
 */
protected RegressionEvaluation(int axis, List<String> columnNames, long precision){
    this.precision = precision;
    this.columnNames = columnNames;
    this.axis = axis;
}
public RegressionEvaluation() {
    // Default: no column names, default stats-display precision
    this(null, DEFAULT_PRECISION);
}
@ -568,6 +582,14 @@ public class RegressionEvaluation extends BaseEvaluation<RegressionEvaluation> {
return ret / (double) numColumns();
}
@Override
public double getValue(IMetric metric){
    // Guard clause: only RegressionEvaluation.Metric values can be scored here
    if (!(metric instanceof Metric)) {
        throw new IllegalStateException("Can't get value for non-regression Metric " + metric);
    }
    return scoreForMetric((Metric) metric);
}
public double scoreForMetric(Metric metric){
switch (metric){
case MSE:
@ -590,4 +612,9 @@ public class RegressionEvaluation extends BaseEvaluation<RegressionEvaluation> {
public static RegressionEvaluation fromJson(String json){
return fromJson(json, RegressionEvaluation.class);
}
@Override
public RegressionEvaluation newInstance() {
    // Empty clone sharing only this instance's configuration (no accumulated stats)
    return new RegressionEvaluation(axis, columnNames, precision);
}
}

View File

@ -72,7 +72,8 @@ public class TestSessions extends BaseNd4jTest {
m.put("x", x);
m.put("y", y);
Map<String,INDArray> outMap = is.output(Collections.singletonList("out"), m, null, true, null);
Map<String,INDArray> outMap = is.output(Collections.singletonList("out"), m, null,
Collections.<String>emptyList(), true, null);
assertEquals(1, outMap.size());
assertEquals(outExp, outMap.get("out"));
@ -109,7 +110,8 @@ public class TestSessions extends BaseNd4jTest {
m.put("y", y);
System.out.println("----------------------------------");
Map<String,INDArray> outMap = is.output(Collections.singletonList("d"), m, null, false, null);
Map<String,INDArray> outMap = is.output(Collections.singletonList("d"), m, null,
Collections.<String>emptyList(), false, null);
assertEquals(1, outMap.size());
assertEquals(dExp, outMap.get("d"));
@ -143,7 +145,8 @@ public class TestSessions extends BaseNd4jTest {
InferenceSession is = new InferenceSession(sd);
// String outName = merge.getVarName();
String outName = outVar.getVarName();
Map<String,INDArray> outMap = is.output(Collections.singletonList(outName), m, null, false, null);
Map<String,INDArray> outMap = is.output(Collections.singletonList(outName), m, null,
Collections.<String>emptyList(), false, null);
assertEquals(1, outMap.size());
INDArray out = outMap.get(outName);
@ -178,7 +181,8 @@ public class TestSessions extends BaseNd4jTest {
String n = merge.getVarName();
System.out.println("----------------------------------");
Map<String,INDArray> outMap = is.output(Collections.singletonList(n), m, null, false, null);
Map<String,INDArray> outMap = is.output(Collections.singletonList(n), m, null, Collections.<String>emptyList(),
false, null);
assertEquals(1, outMap.size());
assertEquals(expTrue, outMap.get(n));
@ -187,7 +191,7 @@ public class TestSessions extends BaseNd4jTest {
//Check false case:
bArr.assign(0);
is = new InferenceSession(sd);
outMap = is.output(Collections.singletonList(n), m, null, false, null);
outMap = is.output(Collections.singletonList(n), m, null, Collections.<String>emptyList(), false, null);
assertEquals(1, outMap.size());
assertEquals(expFalse, outMap.get(n));
}
@ -218,7 +222,8 @@ public class TestSessions extends BaseNd4jTest {
String n = "while/Exit";
String n2 = "while/Exit_1";
Map<String, INDArray> m = is.output(Arrays.asList(n, n2), Collections.emptyMap(), null, false, null);
Map<String, INDArray> m = is.output(Arrays.asList(n, n2), Collections.emptyMap(), null,
Collections.<String>emptyList(), false, null);
assertEquals(2, m.size());
INDArray exp = Nd4j.scalar((float)numIter);

View File

@ -0,0 +1,119 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.autodiff.samediff.listeners;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import java.util.Arrays;
import java.util.List;
import org.junit.Test;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.listeners.impl.ScoreListener;
import org.nd4j.autodiff.listeners.records.History;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.Evaluation.Metric;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.IrisDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.weightinit.impl.XavierInitScheme;
public class ListenerTest extends BaseNd4jTest {

    public ListenerTest(Nd4jBackend backend) {
        super(backend);
    }

    /**
     * Trains a small 2-layer MLP on Iris for 50 epochs and checks that the History
     * returned by fit() exposes the final training evaluation, the loss curve,
     * and metric values (accuracy >= 75%).
     */
    @Test
    public void irisHistoryTest() {
        // Single batch containing all 150 Iris examples, standardized
        DataSetIterator iter = new IrisDataSetIterator(150, 150);
        NormalizerStandardize std = new NormalizerStandardize();
        std.fit(iter);
        iter.setPreProcessor(std);

        Nd4j.getRandom().setSeed(12345);
        SameDiff sd = SameDiff.create();

        // Network: input(4) -> dense(10, relu) -> dense(3, softmax)
        SDVariable in = sd.placeHolder("input", DataType.FLOAT, -1, 4);
        SDVariable label = sd.placeHolder("label", DataType.FLOAT, -1, 3);

        SDVariable w0 = sd.var("w0", new XavierInitScheme('c', 4, 10), DataType.FLOAT, 4, 10);
        SDVariable b0 = sd.zero("b0", DataType.FLOAT, 1, 10);

        SDVariable w1 = sd.var("w1", new XavierInitScheme('c', 10, 3), DataType.FLOAT, 10, 3);
        SDVariable b1 = sd.zero("b1", DataType.FLOAT, 1, 3);

        SDVariable z0 = in.mmul(w0).add(b0);
        SDVariable a0 = sd.nn().relu(z0, 0);
        SDVariable z1 = a0.mmul(w1).add(b1);
        SDVariable predictions = sd.nn().softmax("predictions", z1, 1);

        SDVariable loss = sd.loss.softmaxCrossEntropy("loss", label, predictions);

        sd.setLossVariables("loss");

        IUpdater updater = new Adam(1e-2);

        // Register an Evaluation to be computed on the training data during fit()
        Evaluation e = new Evaluation();

        TrainingConfig conf = new TrainingConfig.Builder()
                .l2(1e-4)
                .updater(updater)
                .dataSetFeatureMapping("input")
                .dataSetLabelMapping("label")
                .trainEvaluation(predictions, 0, e)
                .build();

        sd.setTrainingConfig(conf);

        sd.setListeners(new ScoreListener(1));

        History hist = sd.fit(iter, 50);

        // The History exposes the evaluation computed during the final epoch
        e = (Evaluation) hist.finalTrainingEvaluations().evaluation(predictions);

        System.out.println(e.stats());

        // Per-epoch mean loss for the "loss" variable
        float[] losses = hist.lossCurve().meanLoss(loss);

        System.out.println("Losses: " + Arrays.toString(losses));

        double acc = hist.finalTrainingEvaluations().getValue(Metric.ACCURACY);

        assertTrue("Accuracy < 75%, was " + acc, acc >= 0.75);
    }

    @Override
    public char ordering() {
        return 'c';
    }
}

View File

@ -0,0 +1,65 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.evaluation;
import static org.junit.Assert.assertEquals;
import java.util.Arrays;
import org.junit.Test;
import org.nd4j.evaluation.custom.CustomEvaluation;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.primitives.Pair;
public class CustomEvaluationTest extends BaseNd4jTest {

    public CustomEvaluationTest(Nd4jBackend backend) {
        super(backend);
    }

    @Override
    public char ordering() {
        return 'c';
    }

    /**
     * Accuracy computed via CustomEvaluation: each batch contributes a
     * (numCorrect, numExamples) pair, and the metric reduces those pairs
     * to sum(correct) / sum(total).
     */
    @Test
    public void customEvalTest(){
        // Fix: declare with type arguments rather than a raw CustomEvaluation
        CustomEvaluation<Pair<Number, Long>> accuracyEval = new CustomEvaluation<Pair<Number, Long>>(
                (labels, pred, mask, meta) -> new Pair<>(labels.eq(pred).castTo(DataType.INT).sumNumber(), labels.size(0)),
                CustomEvaluation.mergeConcatenate());

        // 3 of the 5 predictions match the labels
        accuracyEval.eval(Nd4j.createFromArray(1, 1, 2, 1, 3), Nd4j.createFromArray(1, 1, 4, 1, 2));

        double acc = accuracyEval.getValue(new CustomEvaluation.Metric<Pair<Number, Long>>(
                (list) -> {
                    int sum = 0;
                    int count = 0;
                    for(Pair<Number, Long> p : list){
                        sum += p.getFirst().intValue();
                        count += p.getSecond();
                    }
                    return ((double) (sum)) / count;
                }
        ));

        // Fix: JUnit's signature is assertEquals(message, expected, actual, delta);
        // the expected (3/5) and actual (acc) arguments were previously swapped
        assertEquals("Accuracy", 3.0/5, acc, 0.001);
    }
}

View File

@ -3,6 +3,7 @@ package org.nd4j.evaluation;
import org.junit.Test;
import org.nd4j.evaluation.classification.*;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.factory.Nd4jBackend;
@ -40,7 +41,7 @@ public class EmptyEvaluationTests extends BaseNd4jTest {
RegressionEvaluation re = new RegressionEvaluation();
re.stats();
for (RegressionEvaluation.Metric m : RegressionEvaluation.Metric.values()) {
for (Metric m : Metric.values()) {
try {
re.scoreForMetric(m);
} catch (Throwable t){

View File

@ -0,0 +1,117 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.evaluation;
import static org.junit.Assert.assertEquals;
import org.junit.Test;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.EvaluationBinary;
import org.nd4j.evaluation.classification.EvaluationCalibration;
import org.nd4j.evaluation.classification.ROC;
import org.nd4j.evaluation.classification.ROCBinary;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
public class NewInstanceTest extends BaseNd4jTest {
public NewInstanceTest(Nd4jBackend backend) {
super(backend);
}
@Override
public char ordering() {
return 'c';
}
@Test
public void testNewInstances() {
boolean print = true;
Nd4j.getRandom().setSeed(12345);
Evaluation evaluation = new Evaluation();
EvaluationBinary evaluationBinary = new EvaluationBinary();
ROC roc = new ROC(2);
ROCBinary roc2 = new ROCBinary(2);
ROCMultiClass roc3 = new ROCMultiClass(2);
RegressionEvaluation regressionEvaluation = new RegressionEvaluation();
EvaluationCalibration ec = new EvaluationCalibration();
IEvaluation[] arr = new IEvaluation[] {evaluation, evaluationBinary, roc, roc2, roc3, regressionEvaluation, ec};
INDArray evalLabel1 = Nd4j.create(10, 3);
for (int i = 0; i < 10; i++) {
evalLabel1.putScalar(i, i % 3, 1.0);
}
INDArray evalProb1 = Nd4j.rand(10, 3);
evalProb1.diviColumnVector(evalProb1.sum(1));
evaluation.eval(evalLabel1, evalProb1);
roc3.eval(evalLabel1, evalProb1);
ec.eval(evalLabel1, evalProb1);
INDArray evalLabel2 = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(10, 3), 0.5));
INDArray evalProb2 = Nd4j.rand(10, 3);
evaluationBinary.eval(evalLabel2, evalProb2);
roc2.eval(evalLabel2, evalProb2);
INDArray evalLabel3 = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(10, 1), 0.5));
INDArray evalProb3 = Nd4j.rand(10, 1);
roc.eval(evalLabel3, evalProb3);
INDArray reg1 = Nd4j.rand(10, 3);
INDArray reg2 = Nd4j.rand(10, 3);
regressionEvaluation.eval(reg1, reg2);
Evaluation evaluation2 = evaluation.newInstance();
EvaluationBinary evaluationBinary2 = evaluationBinary.newInstance();
ROC roc_2 = roc.newInstance();
ROCBinary roc22 = roc2.newInstance();
ROCMultiClass roc32 = roc3.newInstance();
RegressionEvaluation regressionEvaluation2 = regressionEvaluation.newInstance();
EvaluationCalibration ec2 = ec.newInstance();
IEvaluation[] arr2 = new IEvaluation[] {evaluation2, evaluationBinary2, roc_2, roc22, roc32, regressionEvaluation2, ec2};
evaluation2.eval(evalLabel1, evalProb1);
roc32.eval(evalLabel1, evalProb1);
ec2.eval(evalLabel1, evalProb1);
evaluationBinary2.eval(evalLabel2, evalProb2);
roc22.eval(evalLabel2, evalProb2);
roc_2.eval(evalLabel3, evalProb3);
regressionEvaluation2.eval(reg1, reg2);
for (int i = 0 ; i < arr.length ; i++) {
IEvaluation e = arr[i];
IEvaluation e2 = arr2[i];
assertEquals("Json not equal ", e.toJson(), e2.toJson());
assertEquals("Stats not equal ", e.stats(), e2.stats());
}
}
}

View File

@ -17,8 +17,8 @@
package org.nd4j.evaluation;
import org.junit.Test;
import org.nd4j.evaluation.classification.EvaluationCalibration;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.iter.NdIndexIterator;
@ -256,7 +256,7 @@ public class RegressionEvalTest extends BaseNd4jTest {
e3d.eval(label, prediction);
e2d.eval(l2d, p2d);
for (RegressionEvaluation.Metric m : RegressionEvaluation.Metric.values()) {
for (Metric m : Metric.values()) {
double d1 = e3d.scoreForMetric(m);
double d2 = e2d.scoreForMetric(m);
assertEquals(m.toString(), d2, d1, 1e-6);
@ -288,7 +288,7 @@ public class RegressionEvalTest extends BaseNd4jTest {
e4d.eval(label, prediction);
e2d.eval(l2d, p2d);
for (RegressionEvaluation.Metric m : RegressionEvaluation.Metric.values()) {
for (Metric m : Metric.values()) {
double d1 = e4d.scoreForMetric(m);
double d2 = e2d.scoreForMetric(m);
assertEquals(m.toString(), d2, d1, 1e-6);
@ -347,7 +347,7 @@ public class RegressionEvalTest extends BaseNd4jTest {
RegressionEvaluation e2d_m2 = new RegressionEvaluation();
e4d_m2.eval(label, prediction, perOutMask);
e2d_m2.eval(l2d, p2d, m2d);
for(RegressionEvaluation.Metric m : RegressionEvaluation.Metric.values()){
for(Metric m : Metric.values()){
double d1 = e4d_m2.scoreForMetric(m);
double d2 = e2d_m2.scoreForMetric(m);
assertEquals(m.toString(), d2, d1, 1e-6);
@ -382,7 +382,7 @@ public class RegressionEvalTest extends BaseNd4jTest {
RegressionEvaluation e2d_m1 = new RegressionEvaluation();
e4d_m1.eval(label, prediction, mask1dPerEx);
e2d_m1.eval(l2d, p2d);
for(RegressionEvaluation.Metric m : RegressionEvaluation.Metric.values()){
for(Metric m : Metric.values()){
double d1 = e4d_m1.scoreForMetric(m);
double d2 = e2d_m1.scoreForMetric(m);
assertEquals(m.toString(), d2, d1, 1e-6);
@ -409,7 +409,7 @@ public class RegressionEvalTest extends BaseNd4jTest {
RegressionEvaluation e2d_m2 = new RegressionEvaluation();
e4d_m2.eval(label, prediction, perOutMask);
e2d_m2.eval(l2d, p2d, m2d);
for(RegressionEvaluation.Metric m : RegressionEvaluation.Metric.values()){
for(Metric m : Metric.values()){
double d1 = e4d_m2.scoreForMetric(m);
double d2 = e2d_m2.scoreForMetric(m);
assertEquals(m.toString(), d2, d1, 1e-6);

View File

@ -4,11 +4,13 @@ import lombok.Getter;
import lombok.Setter;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.*;
import org.nd4j.linalg.dataset.api.MultiDataSet;
public class OpExecOrderListener extends BaseListener {
@ -22,7 +24,7 @@ public class OpExecOrderListener extends BaseListener {
}
@Override
public void opExecution(SameDiff sd, At at, boolean training, SameDiffOp op, INDArray[] outputs) {
public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
String opName = op.getName();
if(!opSet.contains(opName)){
opNamesList.add(opName);
@ -30,4 +32,8 @@ public class OpExecOrderListener extends BaseListener {
}
}
@Override
public boolean isActive(Operation operation) {
    // Record op execution order for every operation type
    return true;
}
}

View File

@ -20,9 +20,11 @@ import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import java.io.File;
@ -30,6 +32,11 @@ import java.io.File;
@Slf4j
public class ImportDebugListener extends BaseListener {
@Override
public boolean isActive(Operation operation) {
    // Import-debug checking applies to every operation type
    return true;
}
public enum OnFailure {EXCEPTION, LOG};
private File baseDir;
@ -49,7 +56,7 @@ public class ImportDebugListener extends BaseListener {
}
@Override
public void opExecution(SameDiff sd, At at, boolean training, SameDiffOp op, INDArray[] outputs) {
public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) {
//No op
for( int i=0; i<outputs.length; i++ ) {