// 2019-06-06 15:21:15 +03:00
//
// 1094 lines
// 43 KiB
// Java

/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.evaluation;
import org.junit.Test;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.util.FeatureUtil;
import java.text.DecimalFormat;
import java.util.*;
import static org.junit.Assert.*;
import static org.nd4j.linalg.indexing.NDArrayIndex.all;
import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
/**
 * Tests for {@link Evaluation}: confusion-matrix counts, derived metrics (accuracy, precision,
 * recall, F-beta, G-measure, Matthews correlation), micro/macro averaging, merging, time-series
 * masking, top-N accuracy, binary-case handling, per-axis (segmentation) evaluation and the
 * formatting of {@code stats()} / {@code confusionMatrix()}.
 *
 * Created by agibsonccc on 12/22/14.
 */
public class EvalTest extends BaseNd4jTest {

    public EvalTest(Nd4jBackend backend) {
        super(backend);
    }

    @Override
    public char ordering() {
        return 'c';
    }

    /**
     * Accumulates two single-row evaluations over 5 classes and checks per-class counts, F1 and
     * accuracy. Also verifies that successive eval() calls accumulate rather than reset.
     */
    @Test
    public void testEval() {
        int classNum = 5;
        Evaluation eval = new Evaluation(classNum);
        // Testing the edge case when some classes do not have true positive
        INDArray trueOutcome = FeatureUtil.toOutcomeVector(0, 5); //[1,0,0,0,0]
        INDArray predictedOutcome = FeatureUtil.toOutcomeVector(0, 5); //[1,0,0,0,0]
        eval.eval(trueOutcome, predictedOutcome);
        assertEquals(1, eval.classCount(0));
        assertEquals(1.0, eval.f1(), 1e-1);
        // Testing more than one sample. eval() does not reset the Evaluation instance
        INDArray trueOutcome2 = FeatureUtil.toOutcomeVector(1, 5); //[0,1,0,0,0]
        INDArray predictedOutcome2 = FeatureUtil.toOutcomeVector(0, 5); //[1,0,0,0,0]
        eval.eval(trueOutcome2, predictedOutcome2);
        // Verified with sklearn in Python
        // from sklearn.metrics import classification_report
        // classification_report(['a', 'a'], ['a', 'b'], labels=['a', 'b', 'c', 'd', 'e'])
        // Expected value first, actual second (JUnit convention) so failure messages read correctly
        assertEquals(0.6, eval.f1(), 1e-1);
        // The first sample's actual label is class 0
        assertEquals(1, eval.classCount(0));
        // The second sample's actual label is class 1
        assertEquals(1, eval.classCount(1));
        // Class 0: one positive, one negative -> (one true positive, one false positive); no true/false negatives
        assertEquals(1, eval.positive().get(0), 0);
        assertEquals(1, eval.negative().get(0), 0);
        assertEquals(1, eval.truePositives().get(0), 0);
        assertEquals(1, eval.falsePositives().get(0), 0);
        assertEquals(0, eval.trueNegatives().get(0), 0);
        assertEquals(0, eval.falseNegatives().get(0), 0);
        // The rest are negative
        // NOTE(review): this re-checks class 0 and duplicates the negative() assertion above
        assertEquals(1, eval.negative().get(0), 0);
        // 2 rows and only the first is correct
        assertEquals(0.5, eval.accuracy(), 0);
    }

    /**
     * Builds the same 2-class confusion matrix under every combination of global default dtype and
     * label/prediction dtype. Checks that counts, accuracy, equals() and stats() are identical
     * across all combinations. Both global default dtypes are restored afterwards.
     */
    @Test
    public void testEval2() {
        // Capture both global defaults separately so the finally block restores them faithfully
        DataType defaultDtypeBefore = Nd4j.dataType();
        DataType fpDtypeBefore = Nd4j.defaultFloatingPointType();
        Evaluation first = null;
        String sFirst = null;
        try {
            for (DataType globalDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.INT}) {
                // The floating point default must be an FP type; fall back to DOUBLE for INT
                Nd4j.setDefaultDataTypes(globalDtype, globalDtype.isFPType() ? globalDtype : DataType.DOUBLE);
                for (DataType lpDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
                    //Confusion matrix:
                    //actual 0      20      3
                    //actual 1      10      5
                    Evaluation evaluation = new Evaluation(Arrays.asList("class0", "class1"));
                    INDArray predicted0 = Nd4j.create(new double[]{1, 0}, new long[]{1, 2}).castTo(lpDtype);
                    INDArray predicted1 = Nd4j.create(new double[]{0, 1}, new long[]{1, 2}).castTo(lpDtype);
                    INDArray actual0 = Nd4j.create(new double[]{1, 0}, new long[]{1, 2}).castTo(lpDtype);
                    INDArray actual1 = Nd4j.create(new double[]{0, 1}, new long[]{1, 2}).castTo(lpDtype);
                    for (int i = 0; i < 20; i++) {
                        evaluation.eval(actual0, predicted0);
                    }
                    for (int i = 0; i < 3; i++) {
                        evaluation.eval(actual0, predicted1);
                    }
                    for (int i = 0; i < 10; i++) {
                        evaluation.eval(actual1, predicted0);
                    }
                    for (int i = 0; i < 5; i++) {
                        evaluation.eval(actual1, predicted1);
                    }
                    assertEquals(20, evaluation.truePositives().get(0), 0);
                    assertEquals(3, evaluation.falseNegatives().get(0), 0);
                    assertEquals(10, evaluation.falsePositives().get(0), 0);
                    assertEquals(5, evaluation.trueNegatives().get(0), 0);
                    assertEquals((20.0 + 5) / (20 + 3 + 10 + 5), evaluation.accuracy(), 1e-6);
                    String s = evaluation.stats();
                    if (first == null) {
                        first = evaluation;
                        sFirst = s;
                    } else {
                        // Results must not depend on either the global or the array dtype
                        assertEquals(first, evaluation);
                        assertEquals(sFirst, s);
                    }
                }
            }
        } finally {
            Nd4j.setDefaultDataTypes(defaultDtypeBefore, fpDtypeBefore);
        }
    }

    /** Labels supplied as a List are reported by index via getClassLabel(). */
    @Test
    public void testStringListLabels() {
        INDArray trueOutcome = FeatureUtil.toOutcomeVector(0, 2);
        INDArray predictedOutcome = FeatureUtil.toOutcomeVector(0, 2);
        List<String> labelsList = new ArrayList<>();
        labelsList.add("hobbs");
        labelsList.add("cal");
        Evaluation eval = new Evaluation(labelsList);
        eval.eval(trueOutcome, predictedOutcome);
        assertEquals(1, eval.classCount(0));
        assertEquals(labelsList.get(0), eval.getClassLabel(0));
    }

    /** Labels supplied as an index-to-name Map are reported via getClassLabel(). */
    @Test
    public void testStringHashLabels() {
        INDArray trueOutcome = FeatureUtil.toOutcomeVector(0, 2);
        INDArray predictedOutcome = FeatureUtil.toOutcomeVector(0, 2);
        Map<Integer, String> labelsMap = new HashMap<>();
        labelsMap.put(0, "hobbs");
        labelsMap.put(1, "cal");
        Evaluation eval = new Evaluation(labelsMap);
        eval.eval(trueOutcome, predictedOutcome);
        assertEquals(1, eval.classCount(0));
        assertEquals(labelsMap.get(0), eval.getClassLabel(0));
    }

    /**
     * Time-series evaluation with a labels mask must match unmasked evaluation of the same data
     * when the mask only removes padding time steps added at the start and end.
     */
    @Test
    public void testEvalMasking() {
        int miniBatch = 5;
        int nOut = 3;
        int tsLength = 6;
        INDArray labels = Nd4j.zeros(miniBatch, nOut, tsLength);
        INDArray predicted = Nd4j.zeros(miniBatch, nOut, tsLength);
        Nd4j.getRandom().setSeed(12345);
        Random r = new Random(12345);
        for (int i = 0; i < miniBatch; i++) {
            for (int j = 0; j < tsLength; j++) {
                // Random probability distribution over nOut classes for this (example, time step)
                INDArray rand = Nd4j.rand(1, nOut);
                rand.divi(rand.sumNumber());
                predicted.put(new INDArrayIndex[] {NDArrayIndex.point(i), all(), NDArrayIndex.point(j)},
                        rand);
                int idx = r.nextInt(nOut);
                labels.putScalar(new int[] {i, idx, j}, 1.0);
            }
        }
        //Create a longer labels/predicted with mask for first and last time step
        //Expect masked evaluation to be identical to original evaluation
        INDArray labels2 = Nd4j.zeros(miniBatch, nOut, tsLength + 2);
        labels2.put(new INDArrayIndex[] {all(), all(),
                interval(1, tsLength + 1)}, labels);
        INDArray predicted2 = Nd4j.zeros(miniBatch, nOut, tsLength + 2);
        predicted2.put(new INDArrayIndex[] {all(), all(),
                interval(1, tsLength + 1)}, predicted);
        INDArray labelsMask = Nd4j.ones(miniBatch, tsLength + 2);
        for (int i = 0; i < miniBatch; i++) {
            labelsMask.putScalar(new int[] {i, 0}, 0.0);
            labelsMask.putScalar(new int[] {i, tsLength + 1}, 0.0);
        }
        Evaluation evaluation = new Evaluation();
        evaluation.evalTimeSeries(labels, predicted);
        Evaluation evaluation2 = new Evaluation();
        evaluation2.evalTimeSeries(labels2, predicted2, labelsMask);
        System.out.println(evaluation.stats());
        System.out.println(evaluation2.stats());
        assertEquals(evaluation.accuracy(), evaluation2.accuracy(), 1e-12);
        assertEquals(evaluation.f1(), evaluation2.f1(), 1e-12);
        assertMapEquals(evaluation.falsePositives(), evaluation2.falsePositives());
        assertMapEquals(evaluation.falseNegatives(), evaluation2.falseNegatives());
        assertMapEquals(evaluation.truePositives(), evaluation2.truePositives());
        assertMapEquals(evaluation.trueNegatives(), evaluation2.trueNegatives());
        for (int i = 0; i < nOut; i++) {
            assertEquals(evaluation.classCount(i), evaluation2.classCount(i));
        }
    }

    /**
     * Asserts that two count maps have the same key set and equal values for every key.
     *
     * @param first  expected map
     * @param second actual map
     */
    private static void assertMapEquals(Map<Integer, Integer> first, Map<Integer, Integer> second) {
        assertEquals(first.keySet(), second.keySet());
        for (Integer i : first.keySet()) {
            assertEquals(first.get(i), second.get(i));
        }
    }

    /**
     * When the model always predicts the same class, recall over random labels must not come
     * out as exactly 1.0.
     */
    @Test
    public void testFalsePerfectRecall() {
        int testSize = 100;
        int numClasses = 5;
        int winner = 1;
        int seed = 241;
        INDArray labels = Nd4j.zeros(testSize, numClasses);
        INDArray predicted = Nd4j.zeros(testSize, numClasses);
        Nd4j.getRandom().setSeed(seed);
        Random r = new Random(seed);
        //Modelling the situation when system predicts the same class every time
        for (int i = 0; i < testSize; i++) {
            //Generating random prediction but with a guaranteed winner
            INDArray rand = Nd4j.rand(1, numClasses);
            rand.put(0, winner, rand.sumNumber());
            rand.divi(rand.sumNumber());
            predicted.put(new INDArrayIndex[] {NDArrayIndex.point(i), all()}, rand);
            //Generating random label
            int label = r.nextInt(numClasses);
            labels.putScalar(new int[] {i, label}, 1.0);
        }
        //Explicitly specify the amount of classes
        Evaluation eval = new Evaluation(numClasses);
        eval.eval(labels, predicted);
        //For sure we shouldn't arrive at 100% recall unless we guessed everything right for every class
        assertNotEquals(1.0, eval.recall());
    }

    /**
     * Evaluating in three chunks and merging must equal evaluating all rows at once. Merging with
     * empty Evaluation instances (in either direction) must not change the result.
     */
    @Test
    public void testEvaluationMerging() {
        int nRows = 20;
        int nCols = 3;
        Random r = new Random(12345);
        INDArray actual = Nd4j.create(nRows, nCols);
        INDArray predicted = Nd4j.create(nRows, nCols);
        for (int i = 0; i < nRows; i++) {
            int x1 = r.nextInt(nCols);
            int x2 = r.nextInt(nCols);
            actual.putScalar(new int[] {i, x1}, 1.0);
            predicted.putScalar(new int[] {i, x2}, 1.0);
        }
        Evaluation evalExpected = new Evaluation();
        evalExpected.eval(actual, predicted);
        //Now: split into 3 separate evaluation objects -> expect identical values after merging
        Evaluation eval1 = new Evaluation();
        eval1.eval(actual.get(interval(0, 5), all()),
                predicted.get(interval(0, 5), all()));
        Evaluation eval2 = new Evaluation();
        eval2.eval(actual.get(interval(5, 10), all()),
                predicted.get(interval(5, 10), all()));
        Evaluation eval3 = new Evaluation();
        eval3.eval(actual.get(interval(10, nRows), all()),
                predicted.get(interval(10, nRows), all()));
        eval1.merge(eval2);
        eval1.merge(eval3);
        checkEvaluationEquality(evalExpected, eval1);
        //Next: check evaluation merging with empty, and empty merging with non-empty
        eval1 = new Evaluation();
        eval1.eval(actual.get(interval(0, 5), all()),
                predicted.get(interval(0, 5), all()));
        Evaluation evalInitiallyEmpty = new Evaluation();
        evalInitiallyEmpty.merge(eval1);
        evalInitiallyEmpty.merge(eval2);
        evalInitiallyEmpty.merge(eval3);
        checkEvaluationEquality(evalExpected, evalInitiallyEmpty);
        eval1.merge(new Evaluation());
        eval1.merge(eval2);
        eval1.merge(new Evaluation());
        eval1.merge(eval3);
        checkEvaluationEquality(evalExpected, eval1);
    }

    /**
     * Asserts that two Evaluation instances agree on the summary metrics, the per-class count maps
     * and the confusion matrix.
     *
     * @param evalExpected reference evaluation
     * @param evalActual   evaluation under test
     */
    private static void checkEvaluationEquality(Evaluation evalExpected, Evaluation evalActual) {
        assertEquals(evalExpected.accuracy(), evalActual.accuracy(), 1e-3);
        assertEquals(evalExpected.f1(), evalActual.f1(), 1e-3);
        assertEquals(evalExpected.getNumRowCounter(), evalActual.getNumRowCounter(), 1e-3);
        assertMapEquals(evalExpected.falseNegatives(), evalActual.falseNegatives());
        assertMapEquals(evalExpected.falsePositives(), evalActual.falsePositives());
        assertMapEquals(evalExpected.trueNegatives(), evalActual.trueNegatives());
        assertMapEquals(evalExpected.truePositives(), evalActual.truePositives());
        assertEquals(evalExpected.precision(), evalActual.precision(), 1e-3);
        assertEquals(evalExpected.recall(), evalActual.recall(), 1e-3);
        assertEquals(evalExpected.falsePositiveRate(), evalActual.falsePositiveRate(), 1e-3);
        assertEquals(evalExpected.falseNegativeRate(), evalActual.falseNegativeRate(), 1e-3);
        assertEquals(evalExpected.falseAlarmRate(), evalActual.falseAlarmRate(), 1e-3);
        assertEquals(evalExpected.getConfusionMatrix(), evalActual.getConfusionMatrix());
    }

    /**
     * Single-column (nOut = 1) input is handled as binary classification. Repeated three times
     * with reset() in between to check that reset() fully clears state.
     */
    @Test
    public void testSingleClassBinaryClassification() {
        Evaluation eval = new Evaluation(1);
        for (int xe = 0; xe < 3; xe++) {
            INDArray zero = Nd4j.create(1,1);
            INDArray one = Nd4j.ones(1,1);
            //One incorrect, three correct
            eval.eval(one, zero);
            eval.eval(one, one);
            eval.eval(one, one);
            eval.eval(zero, zero);
            System.out.println(eval.stats());
            assertEquals(0.75, eval.accuracy(), 1e-6);
            assertEquals(4, eval.getNumRowCounter());
            assertEquals(1, (int) eval.truePositives().get(0));
            assertEquals(2, (int) eval.truePositives().get(1));
            assertEquals(1, (int) eval.falseNegatives().get(1));
            eval.reset();
        }
    }

    /** stats() output must not contain the Unicode replacement character U+FFFD. */
    @Test
    public void testEvalInvalid() {
        Evaluation e = new Evaluation(5);
        e.eval(0, 1);
        e.eval(1, 0);
        e.eval(1, 1);
        System.out.println(e.stats());
        char c = "\uFFFD".toCharArray()[0];
        System.out.println(c);
        assertFalse(e.stats().contains("\uFFFD"));
    }

    /**
     * eval(INDArray actual, INDArray predicted) and eval(int predicted, int actual) must produce
     * the same confusion matrix and stats(). Note the opposite argument order of the two overloads.
     */
    @Test
    public void testEvalMethods() {
        //Check eval(int,int) vs. eval(INDArray,INDArray)
        Evaluation e1 = new Evaluation(4);
        Evaluation e2 = new Evaluation(4);
        INDArray i0 = Nd4j.create(new double[] {1, 0, 0, 0}, new long[]{1, 4});
        INDArray i1 = Nd4j.create(new double[] {0, 1, 0, 0}, new long[]{1, 4});
        INDArray i2 = Nd4j.create(new double[] {0, 0, 1, 0}, new long[]{1, 4});
        INDArray i3 = Nd4j.create(new double[] {0, 0, 0, 1}, new long[]{1, 4});
        e1.eval(i0, i0); //order: actual, predicted
        e2.eval(0, 0); //order: predicted, actual
        e1.eval(i0, i2);
        e2.eval(2, 0);
        e1.eval(i0, i2);
        e2.eval(2, 0);
        e1.eval(i1, i2);
        e2.eval(2, 1);
        e1.eval(i3, i3);
        e2.eval(3, 3);
        e1.eval(i3, i0);
        e2.eval(0, 3);
        e1.eval(i3, i0);
        e2.eval(0, 3);
        org.nd4j.evaluation.classification.ConfusionMatrix<Integer> cm = e1.getConfusionMatrix();
        assertEquals(1, cm.getCount(0, 0)); //Order: actual, predicted
        assertEquals(2, cm.getCount(0, 2));
        assertEquals(1, cm.getCount(1, 2));
        assertEquals(1, cm.getCount(3, 3));
        assertEquals(2, cm.getCount(3, 0));
        System.out.println(e1.stats());
        System.out.println(e2.stats());
        assertEquals(e1.stats(), e2.stats());
    }

    /**
     * Top-3 accuracy: a prediction counts as top-N correct when the true class is among the
     * 3 highest-probability classes. Checks running counts after each row.
     */
    @Test
    public void testTopNAccuracy() {
        Evaluation e = new Evaluation(null, 3);
        INDArray i0 = Nd4j.create(new double[] {1, 0, 0, 0, 0}, new long[]{1, 5});
        INDArray i1 = Nd4j.create(new double[] {0, 1, 0, 0, 0}, new long[]{1, 5});
        INDArray p0_0 = Nd4j.create(new double[] {0.8, 0.05, 0.05, 0.05, 0.05}, new long[]{1, 5}); //class 0: highest prob
        INDArray p0_1 = Nd4j.create(new double[] {0.4, 0.45, 0.05, 0.05, 0.05}, new long[]{1, 5}); //class 0: 2nd highest prob
        INDArray p0_2 = Nd4j.create(new double[] {0.1, 0.45, 0.35, 0.05, 0.05}, new long[]{1, 5}); //class 0: 3rd highest prob
        INDArray p0_3 = Nd4j.create(new double[] {0.1, 0.40, 0.30, 0.15, 0.05}, new long[]{1, 5}); //class 0: 4th highest prob
        INDArray p1_0 = Nd4j.create(new double[] {0.05, 0.80, 0.05, 0.05, 0.05}, new long[]{1, 5}); //class 1: highest prob
        INDArray p1_1 = Nd4j.create(new double[] {0.45, 0.40, 0.05, 0.05, 0.05}, new long[]{1, 5}); //class 1: 2nd highest prob
        INDArray p1_2 = Nd4j.create(new double[] {0.35, 0.10, 0.45, 0.05, 0.05}, new long[]{1, 5}); //class 1: 3rd highest prob
        INDArray p1_3 = Nd4j.create(new double[] {0.40, 0.10, 0.30, 0.15, 0.05}, new long[]{1, 5}); //class 1: 4th highest prob
        // Correct TopNCorrect Total
        e.eval(i0, p0_0); // 1 1 1
        assertEquals(1.0, e.accuracy(), 1e-6);
        assertEquals(1.0, e.topNAccuracy(), 1e-6);
        assertEquals(1, e.getTopNCorrectCount());
        assertEquals(1, e.getTopNTotalCount());
        e.eval(i0, p0_1); // 1 2 2
        assertEquals(0.5, e.accuracy(), 1e-6);
        assertEquals(1.0, e.topNAccuracy(), 1e-6);
        assertEquals(2, e.getTopNCorrectCount());
        assertEquals(2, e.getTopNTotalCount());
        e.eval(i0, p0_2); // 1 3 3
        assertEquals(1.0 / 3, e.accuracy(), 1e-6);
        assertEquals(1.0, e.topNAccuracy(), 1e-6);
        assertEquals(3, e.getTopNCorrectCount());
        assertEquals(3, e.getTopNTotalCount());
        e.eval(i0, p0_3); // 1 3 4
        assertEquals(0.25, e.accuracy(), 1e-6);
        assertEquals(0.75, e.topNAccuracy(), 1e-6);
        assertEquals(3, e.getTopNCorrectCount());
        assertEquals(4, e.getTopNTotalCount());
        e.eval(i1, p1_0); // 2 4 5
        assertEquals(2.0 / 5, e.accuracy(), 1e-6);
        assertEquals(4.0 / 5, e.topNAccuracy(), 1e-6);
        e.eval(i1, p1_1); // 2 5 6
        assertEquals(2.0 / 6, e.accuracy(), 1e-6);
        assertEquals(5.0 / 6, e.topNAccuracy(), 1e-6);
        e.eval(i1, p1_2); // 2 6 7
        assertEquals(2.0 / 7, e.accuracy(), 1e-6);
        assertEquals(6.0 / 7, e.topNAccuracy(), 1e-6);
        e.eval(i1, p1_3); // 2 6 8
        assertEquals(2.0 / 8, e.accuracy(), 1e-6);
        assertEquals(6.0 / 8, e.topNAccuracy(), 1e-6);
        assertEquals(6, e.getTopNCorrectCount());
        assertEquals(8, e.getTopNTotalCount());
        System.out.println(e.stats());
    }

    /** Top-N counters must be summed correctly when two Evaluation instances are merged. */
    @Test
    public void testTopNAccuracyMerging() {
        Evaluation e1 = new Evaluation(null, 3);
        Evaluation e2 = new Evaluation(null, 3);
        INDArray i0 = Nd4j.create(new double[] {1, 0, 0, 0, 0}, new long[]{1, 5});
        INDArray i1 = Nd4j.create(new double[] {0, 1, 0, 0, 0}, new long[]{1, 5});
        INDArray p0_0 = Nd4j.create(new double[] {0.8, 0.05, 0.05, 0.05, 0.05}, new long[]{1, 5}); //class 0: highest prob
        INDArray p0_1 = Nd4j.create(new double[] {0.4, 0.45, 0.05, 0.05, 0.05}, new long[]{1, 5}); //class 0: 2nd highest prob
        INDArray p0_2 = Nd4j.create(new double[] {0.1, 0.45, 0.35, 0.05, 0.05}, new long[]{1, 5}); //class 0: 3rd highest prob
        INDArray p0_3 = Nd4j.create(new double[] {0.1, 0.40, 0.30, 0.15, 0.05}, new long[]{1, 5}); //class 0: 4th highest prob
        INDArray p1_0 = Nd4j.create(new double[] {0.05, 0.80, 0.05, 0.05, 0.05}, new long[]{1, 5}); //class 1: highest prob
        INDArray p1_1 = Nd4j.create(new double[] {0.45, 0.40, 0.05, 0.05, 0.05}, new long[]{1, 5}); //class 1: 2nd highest prob
        INDArray p1_2 = Nd4j.create(new double[] {0.35, 0.10, 0.45, 0.05, 0.05}, new long[]{1, 5}); //class 1: 3rd highest prob
        INDArray p1_3 = Nd4j.create(new double[] {0.40, 0.10, 0.30, 0.15, 0.05}, new long[]{1, 5}); //class 1: 4th highest prob
        // Correct TopNCorrect Total
        e1.eval(i0, p0_0); // 1 1 1
        e1.eval(i0, p0_1); // 1 2 2
        e1.eval(i0, p0_2); // 1 3 3
        e1.eval(i0, p0_3); // 1 3 4
        assertEquals(0.25, e1.accuracy(), 1e-6);
        assertEquals(0.75, e1.topNAccuracy(), 1e-6);
        assertEquals(3, e1.getTopNCorrectCount());
        assertEquals(4, e1.getTopNTotalCount());
        e2.eval(i1, p1_0); // 1 1 1
        e2.eval(i1, p1_1); // 1 2 2
        e2.eval(i1, p1_2); // 1 3 3
        e2.eval(i1, p1_3); // 1 3 4
        assertEquals(1.0 / 4, e2.accuracy(), 1e-6);
        assertEquals(3.0 / 4, e2.topNAccuracy(), 1e-6);
        assertEquals(3, e2.getTopNCorrectCount());
        assertEquals(4, e2.getTopNTotalCount());
        e1.merge(e2);
        assertEquals(8, e1.getNumRowCounter());
        assertEquals(8, e1.getTopNTotalCount());
        assertEquals(6, e1.getTopNCorrectCount());
        assertEquals(2.0 / 8, e1.accuracy(), 1e-6);
        assertEquals(6.0 / 8, e1.topNAccuracy(), 1e-6);
    }

    /**
     * Single-column binary input: counts for class 1 are the raw TP/FN/FP/TN. Counts for class 0
     * are the same counts with the roles of positive and negative swapped.
     */
    @Test
    public void testBinaryCase() {
        INDArray ones10 = Nd4j.ones(10, 1);
        INDArray ones4 = Nd4j.ones(4, 1);
        INDArray zeros4 = Nd4j.zeros(4, 1);
        INDArray ones3 = Nd4j.ones(3, 1);
        INDArray zeros3 = Nd4j.zeros(3, 1);
        INDArray zeros2 = Nd4j.zeros(2, 1);
        Evaluation e = new Evaluation();
        e.eval(ones10, ones10); //10 true positives
        e.eval(ones3, zeros3); //3 false negatives
        e.eval(zeros4, ones4); //4 false positives
        e.eval(zeros2, zeros2); //2 true negatives
        assertEquals((10 + 2) / (double) (10 + 3 + 4 + 2), e.accuracy(), 1e-6);
        assertEquals(10, (int) e.truePositives().get(1));
        assertEquals(3, (int) e.falseNegatives().get(1));
        assertEquals(4, (int) e.falsePositives().get(1));
        assertEquals(2, (int) e.trueNegatives().get(1));
        //If we switch the label around: tp becomes tn, fp becomes fn, etc
        assertEquals(10, (int) e.trueNegatives().get(0));
        assertEquals(3, (int) e.falsePositives().get(0));
        assertEquals(4, (int) e.falseNegatives().get(0));
        assertEquals(2, (int) e.truePositives().get(0));
    }

    /**
     * Per-class F1, F-beta, G-measure and MCC are checked against hand-computed values for a
     * known 3-class confusion matrix, along with their micro- and macro-averaged aggregates.
     */
    @Test
    public void testF1FBeta_MicroMacroAveraging() {
        //Confusion matrix: rows = actual, columns = predicted
        //[3, 1, 0]
        //[2, 2, 1]
        //[0, 3, 4]
        INDArray zero = Nd4j.create(new double[] {1, 0, 0}, new long[]{1, 3});
        INDArray one = Nd4j.create(new double[] {0, 1, 0}, new long[]{1, 3});
        INDArray two = Nd4j.create(new double[] {0, 0, 1}, new long[]{1, 3});
        Evaluation e = new Evaluation();
        apply(e, 3, zero, zero);
        apply(e, 1, one, zero);
        apply(e, 2, zero, one);
        apply(e, 2, one, one);
        apply(e, 1, two, one);
        apply(e, 3, one, two);
        apply(e, 4, two, two);
        assertEquals(3, e.getConfusionMatrix().getCount(0, 0));
        assertEquals(1, e.getConfusionMatrix().getCount(0, 1));
        assertEquals(0, e.getConfusionMatrix().getCount(0, 2));
        assertEquals(2, e.getConfusionMatrix().getCount(1, 0));
        assertEquals(2, e.getConfusionMatrix().getCount(1, 1));
        assertEquals(1, e.getConfusionMatrix().getCount(1, 2));
        assertEquals(0, e.getConfusionMatrix().getCount(2, 0));
        assertEquals(3, e.getConfusionMatrix().getCount(2, 1));
        assertEquals(4, e.getConfusionMatrix().getCount(2, 2));
        double beta = 3.5;
        double[] prec = new double[3];
        double[] rec = new double[3];
        for (int i = 0; i < 3; i++) {
            prec[i] = e.truePositives().get(i) / (double) (e.truePositives().get(i) + e.falsePositives().get(i));
            rec[i] = e.truePositives().get(i) / (double) (e.truePositives().get(i) + e.falseNegatives().get(i));
        }
        //Binarized confusion
        //class 0:
        // [3, 1]     [tp fn]
        // [2, 10]    [fp tn]
        assertEquals(3, (int) e.truePositives().get(0));
        assertEquals(1, (int) e.falseNegatives().get(0));
        assertEquals(2, (int) e.falsePositives().get(0));
        assertEquals(10, (int) e.trueNegatives().get(0));
        //class 1:
        // [2, 3]     [tp fn]
        // [4, 7]     [fp tn]
        assertEquals(2, (int) e.truePositives().get(1));
        assertEquals(3, (int) e.falseNegatives().get(1));
        assertEquals(4, (int) e.falsePositives().get(1));
        assertEquals(7, (int) e.trueNegatives().get(1));
        //class 2:
        // [4, 3]     [tp fn]
        // [1, 8]     [fp tn]
        assertEquals(4, (int) e.truePositives().get(2));
        assertEquals(3, (int) e.falseNegatives().get(2));
        assertEquals(1, (int) e.falsePositives().get(2));
        assertEquals(8, (int) e.trueNegatives().get(2));
        double[] fBeta = new double[3];
        double[] f1 = new double[3];
        double[] mcc = new double[3];
        for (int i = 0; i < 3; i++) {
            fBeta[i] = (1 + beta * beta) * prec[i] * rec[i] / (beta * beta * prec[i] + rec[i]);
            f1[i] = 2 * prec[i] * rec[i] / (prec[i] + rec[i]);
            assertEquals(fBeta[i], e.fBeta(beta, i), 1e-6);
            assertEquals(f1[i], e.f1(i), 1e-6);
            double gmeasure = Math.sqrt(prec[i] * rec[i]);
            assertEquals(gmeasure, e.gMeasure(i), 1e-6);
            double tp = e.truePositives().get(i);
            double tn = e.trueNegatives().get(i);
            double fp = e.falsePositives().get(i);
            double fn = e.falseNegatives().get(i);
            mcc[i] = (tp * tn - fp * fn) / Math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn));
            assertEquals(mcc[i], e.matthewsCorrelation(i), 1e-6);
        }
        //Test macro and micro averaging:
        int tp = 0;
        int fn = 0;
        int fp = 0;
        int tn = 0;
        double macroPrecision = 0.0;
        double macroRecall = 0.0;
        double macroF1 = 0.0;
        double macroFBeta = 0.0;
        double macroMcc = 0.0;
        for (int i = 0; i < 3; i++) {
            tp += e.truePositives().get(i);
            fn += e.falseNegatives().get(i);
            fp += e.falsePositives().get(i);
            tn += e.trueNegatives().get(i);
            macroPrecision += prec[i];
            macroRecall += rec[i];
            macroF1 += f1[i];
            macroFBeta += fBeta[i];
            macroMcc += mcc[i];
        }
        macroPrecision /= 3;
        macroRecall /= 3;
        macroF1 /= 3;
        macroFBeta /= 3;
        macroMcc /= 3;
        double microPrecision = tp / (double) (tp + fp);
        double microRecall = tp / (double) (tp + fn);
        double microFBeta =
                (1 + beta * beta) * microPrecision * microRecall / (beta * beta * microPrecision + microRecall);
        double microF1 = 2 * microPrecision * microRecall / (microPrecision + microRecall);
        double microMcc = (tp * tn - fp * fn) / Math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn));
        assertEquals(microPrecision, e.precision(EvaluationAveraging.Micro), 1e-6);
        assertEquals(microRecall, e.recall(EvaluationAveraging.Micro), 1e-6);
        assertEquals(macroPrecision, e.precision(EvaluationAveraging.Macro), 1e-6);
        assertEquals(macroRecall, e.recall(EvaluationAveraging.Macro), 1e-6);
        assertEquals(microFBeta, e.fBeta(beta, EvaluationAveraging.Micro), 1e-6);
        assertEquals(macroFBeta, e.fBeta(beta, EvaluationAveraging.Macro), 1e-6);
        assertEquals(microF1, e.f1(EvaluationAveraging.Micro), 1e-6);
        assertEquals(macroF1, e.f1(EvaluationAveraging.Macro), 1e-6);
        assertEquals(microMcc, e.matthewsCorrelation(EvaluationAveraging.Micro), 1e-6);
        assertEquals(macroMcc, e.matthewsCorrelation(EvaluationAveraging.Macro), 1e-6);
    }

    /**
     * Calls {@code e.eval(actual, predicted)} {@code nTimes} times.
     * Note: this helper takes (predicted, actual), the reverse of Evaluation.eval's argument order.
     *
     * @param e         evaluation to accumulate into
     * @param nTimes    number of repetitions
     * @param predicted prediction row
     * @param actual    label row
     */
    private static void apply(Evaluation e, int nTimes, INDArray predicted, INDArray actual) {
        for (int i = 0; i < nTimes; i++) {
            e.eval(actual, predicted);
        }
    }

    /** The confusion matrix rows printed by stats() must reflect the (actual, predicted) counts. */
    @Test
    public void testConfusionMatrixStats() {
        Evaluation e = new Evaluation();
        INDArray c0 = Nd4j.create(new double[] {1, 0, 0}, new long[]{1, 3});
        INDArray c1 = Nd4j.create(new double[] {0, 1, 0}, new long[]{1, 3});
        INDArray c2 = Nd4j.create(new double[] {0, 0, 1}, new long[]{1, 3});
        apply(e, 3, c2, c0); //Predicted class 2 when actually class 0, 3 times
        apply(e, 2, c0, c1); //Predicted class 0 when actually class 1, 2 times
        String s1 = " 0 0 3 | 0 = 0"; //First row: predicted 2, actual 0 - 3 times
        String s2 = " 2 0 0 | 1 = 1"; //Second row: predicted 0, actual 1 - 2 times
        String stats = e.stats();
        assertTrue(stats, stats.contains(s1));
        assertTrue(stats, stats.contains(s2));
    }

    /**
     * Binary metrics must agree across evaluations that use a 2-column input vs. a single-column
     * input, and positive class 0 vs. 1, given the same underlying TP/FN/TN/FP counts. With a
     * null positive class, precision/recall/F1 are compared against macro-averaged values.
     */
    @Test
    public void testEvalBinaryMetrics(){
        Evaluation ePosClass1_nOut2 = new Evaluation(2, 1);
        Evaluation ePosClass0_nOut2 = new Evaluation(2, 0);
        Evaluation ePosClass1_nOut1 = new Evaluation(2, 1);
        Evaluation ePosClass0_nOut1 = new Evaluation(2, 0);
        Evaluation ePosClassNull_nOut2 = new Evaluation(2, null);
        Evaluation ePosClassNull_nOut1 = new Evaluation(2, null);
        Evaluation[] evals = new Evaluation[]{ePosClass1_nOut2, ePosClass0_nOut2, ePosClass1_nOut1, ePosClass0_nOut1};
        int[] posClass = {1,0,1,0,-1,-1};
        //Correct, actual positive class -> TP
        INDArray p1_1 = Nd4j.create(new double[]{0.3, 0.7}, new long[]{1, 2});
        INDArray l1_1 = Nd4j.create(new double[]{0,1}, new long[]{1, 2});
        INDArray p1_0 = Nd4j.create(new double[]{0.7, 0.3}, new long[]{1, 2});
        INDArray l1_0 = Nd4j.create(new double[]{1,0}, new long[]{1, 2});
        //Incorrect, actual positive class -> FN
        INDArray p2_1 = Nd4j.create(new double[]{0.6, 0.4}, new long[]{1, 2});
        INDArray l2_1 = Nd4j.create(new double[]{0,1}, new long[]{1, 2});
        INDArray p2_0 = Nd4j.create(new double[]{0.4, 0.6}, new long[]{1, 2});
        INDArray l2_0 = Nd4j.create(new double[]{1,0}, new long[]{1, 2});
        //Correct, actual negative class -> TN
        INDArray p3_1 = Nd4j.create(new double[]{0.8, 0.2}, new long[]{1, 2});
        INDArray l3_1 = Nd4j.create(new double[]{1,0}, new long[]{1, 2});
        INDArray p3_0 = Nd4j.create(new double[]{0.2, 0.8}, new long[]{1, 2});
        INDArray l3_0 = Nd4j.create(new double[]{0,1}, new long[]{1, 2});
        //Incorrect, actual negative class -> FP
        INDArray p4_1 = Nd4j.create(new double[]{0.45, 0.55}, new long[]{1, 2});
        INDArray l4_1 = Nd4j.create(new double[]{1,0}, new long[]{1, 2});
        INDArray p4_0 = Nd4j.create(new double[]{0.55, 0.45}, new long[]{1, 2});
        INDArray l4_0 = Nd4j.create(new double[]{0,1}, new long[]{1, 2});
        int tp = 7;
        int fn = 5;
        int tn = 3;
        int fp = 1;
        for( int i=0; i<tp; i++ ) {
            ePosClass1_nOut2.eval(l1_1, p1_1);
            ePosClass1_nOut1.eval(l1_1.getColumn(1).reshape(1,-1), p1_1.getColumn(1).reshape(1,-1));
            ePosClass0_nOut2.eval(l1_0, p1_0);
            ePosClass0_nOut1.eval(l1_0.getColumn(1).reshape(1,-1), p1_0.getColumn(1).reshape(1,-1)); //label 0 = instance of positive class
            ePosClassNull_nOut2.eval(l1_1, p1_1);
            ePosClassNull_nOut1.eval(l1_0.getColumn(0).reshape(1,-1), p1_0.getColumn(0).reshape(1,-1));
        }
        for( int i=0; i<fn; i++ ){
            ePosClass1_nOut2.eval(l2_1, p2_1);
            ePosClass1_nOut1.eval(l2_1.getColumn(1).reshape(1,-1), p2_1.getColumn(1).reshape(1,-1));
            ePosClass0_nOut2.eval(l2_0, p2_0);
            ePosClass0_nOut1.eval(l2_0.getColumn(1).reshape(1,-1), p2_0.getColumn(1).reshape(1,-1));
            ePosClassNull_nOut2.eval(l2_1, p2_1);
            ePosClassNull_nOut1.eval(l2_0.getColumn(0).reshape(1,-1), p2_0.getColumn(0).reshape(1,-1));
        }
        for( int i=0; i<tn; i++ ) {
            ePosClass1_nOut2.eval(l3_1, p3_1);
            ePosClass1_nOut1.eval(l3_1.getColumn(1).reshape(1,-1), p3_1.getColumn(1).reshape(1,-1));
            ePosClass0_nOut2.eval(l3_0, p3_0);
            ePosClass0_nOut1.eval(l3_0.getColumn(1).reshape(1,-1), p3_0.getColumn(1).reshape(1,-1));
            ePosClassNull_nOut2.eval(l3_1, p3_1);
            ePosClassNull_nOut1.eval(l3_0.getColumn(0).reshape(1,-1), p3_0.getColumn(0).reshape(1,-1));
        }
        for( int i=0; i<fp; i++ ){
            ePosClass1_nOut2.eval(l4_1, p4_1);
            ePosClass1_nOut1.eval(l4_1.getColumn(1).reshape(1,-1), p4_1.getColumn(1).reshape(1,-1));
            ePosClass0_nOut2.eval(l4_0, p4_0);
            ePosClass0_nOut1.eval(l4_0.getColumn(1).reshape(1,-1), p4_0.getColumn(1).reshape(1,-1));
            ePosClassNull_nOut2.eval(l4_1, p4_1);
            ePosClassNull_nOut1.eval(l4_0.getColumn(0).reshape(1,-1), p4_0.getColumn(0).reshape(1,-1));
        }
        for( int i=0; i<4; i++ ){
            int positiveClass = posClass[i];
            String m = String.valueOf(i);
            int tpAct = evals[i].truePositives().get(positiveClass);
            int tnAct = evals[i].trueNegatives().get(positiveClass);
            int fpAct = evals[i].falsePositives().get(positiveClass);
            int fnAct = evals[i].falseNegatives().get(positiveClass);
            //System.out.println(evals[i].stats());
            assertEquals(m, tp, tpAct);
            assertEquals(m, tn, tnAct);
            assertEquals(m, fp, fpAct);
            assertEquals(m, fn, fnAct);
        }
        double acc = (tp+tn) / (double)(tp+fn+tn+fp);
        double rec = tp / (double)(tp+fn);
        double prec = tp / (double)(tp+fp);
        double f1 = 2 * (prec * rec) / (prec + rec);
        for( int i=0; i<evals.length; i++ ){
            String m = String.valueOf(i);
            assertEquals(m, acc, evals[i].accuracy(), 1e-5);
            assertEquals(m, prec, evals[i].precision(), 1e-5);
            assertEquals(m, rec, evals[i].recall(), 1e-5);
            assertEquals(m, f1, evals[i].f1(), 1e-5);
        }
        //Also check macro-averaged versions (null positive class):
        assertEquals(acc, ePosClassNull_nOut2.accuracy(), 1e-6);
        assertEquals(ePosClass1_nOut2.recall(EvaluationAveraging.Macro), ePosClassNull_nOut2.recall(), 1e-6);
        assertEquals(ePosClass1_nOut2.precision(EvaluationAveraging.Macro), ePosClassNull_nOut2.precision(), 1e-6);
        assertEquals(ePosClass1_nOut2.f1(EvaluationAveraging.Macro), ePosClassNull_nOut2.f1(), 1e-6);
        assertEquals(acc, ePosClassNull_nOut1.accuracy(), 1e-6);
        assertEquals(ePosClass1_nOut2.recall(EvaluationAveraging.Macro), ePosClassNull_nOut1.recall(), 1e-6);
        assertEquals(ePosClass1_nOut2.precision(EvaluationAveraging.Macro), ePosClassNull_nOut1.precision(), 1e-6);
        assertEquals(ePosClass1_nOut2.f1(EvaluationAveraging.Macro), ePosClassNull_nOut1.f1(), 1e-6);
    }

    /**
     * Exact-format check of confusionMatrix() for a labelled 3-class evaluation, plus a smoke run
     * of stats() for a wide (31-class) evaluation.
     */
    @Test
    public void testConfusionMatrixString(){
        Evaluation e = new Evaluation(Arrays.asList("a","b","c"));
        INDArray class0 = Nd4j.create(new double[]{1,0,0}, new long[]{1, 3});
        INDArray class1 = Nd4j.create(new double[]{0,1,0}, new long[]{1, 3});
        INDArray class2 = Nd4j.create(new double[]{0,0,1}, new long[]{1, 3});
        //Predicted class 0, actual class 1 x2
        e.eval(class0, class1);
        e.eval(class0, class1);
        e.eval(class2, class2);
        e.eval(class2, class2);
        e.eval(class2, class2);
        String s = e.confusionMatrix();
        // System.out.println(s);
        String exp =
                " 0 1 2\n" +
                "-------\n" +
                " 0 2 0 | 0 = a\n" + //0 predicted as 1, 2 times
                " 0 0 0 | 1 = b\n" +
                " 0 0 3 | 2 = c\n" + //2 predicted as 2, 3 times
                "\nConfusion matrix format: Actual (rowClass) predicted as (columnClass) N times";
        assertEquals(exp, s);
        System.out.println("============================");
        System.out.println(e.stats());
        System.out.println("\n\n\n\n");
        //Test with 31 classes (> threshold)
        e = new Evaluation();
        class0 = Nd4j.create(1, 31);
        class0.putScalar(0, 1);
        e.eval(class0, class0);
        System.out.println(e.stats());
        System.out.println("\n\n\n\n");
        System.out.println(e.stats(false, true));
    }

    /**
     * Predictions containing NaN are expected to be rejected with an IllegalStateException whose
     * message mentions NaN.
     */
    @Test
    public void testEvaluationNaNs(){
        Evaluation e = new Evaluation();
        INDArray predictions = Nd4j.create(new double[]{0.1, Double.NaN, 0.3}, new long[]{1,3});
        INDArray labels = Nd4j.create(new double[]{0, 0, 1}, new long[]{1,3});
        try {
            e.eval(labels, predictions);
            // Without this the test would pass silently if no exception were thrown at all.
            // fail() throws AssertionError, which is not caught by the IllegalStateException handler.
            fail("Expected IllegalStateException for NaN predictions");
        } catch (IllegalStateException ex){
            assertTrue(ex.getMessage(), ex.getMessage().contains("NaN"));
        }
    }

    /**
     * Segmentation-style 4d evaluation (per-pixel classification) must match row-by-row 2d
     * evaluation, for both the multi-class (c=4) and binary (c=1) cases. It must also match for
     * each channel-axis layout selected via setAxis().
     */
    @Test
    public void testSegmentation(){
        for( int c : new int[]{4, 1}) { //c=1 should be treated as binary classification case
            Nd4j.getRandom().setSeed(12345);
            int mb = 3;
            int h = 3;
            int w = 2;
            //NCHW
            INDArray labels = Nd4j.create(DataType.FLOAT, mb, c, h, w);
            Random r = new Random(12345);
            for (int i = 0; i < mb; i++) {
                for (int j = 0; j < h; j++) {
                    for (int k = 0; k < w; k++) {
                        if(c == 1){
                            labels.putScalar(i, 0, j, k, r.nextInt(2));
                        } else {
                            int classIdx = r.nextInt(c);
                            labels.putScalar(i, classIdx, j, k, 1.0);
                        }
                    }
                }
            }
            INDArray predictions = Nd4j.rand(DataType.FLOAT, mb, c, h, w);
            if(c > 1) {
                // In-place softmax along the channel axis so each pixel's predictions sum to 1
                DynamicCustomOp op = DynamicCustomOp.builder("softmax")
                        .addInputs(predictions)
                        .addOutputs(predictions)
                        .callInplace(true)
                        .addIntegerArguments(1) //Axis
                        .build();
                Nd4j.exec(op);
            }
            Evaluation e2d = new Evaluation();
            Evaluation e4d = new Evaluation();
            e4d.eval(labels, predictions);
            for (int i = 0; i < mb; i++) {
                for (int j = 0; j < h; j++) {
                    for (int k = 0; k < w; k++) {
                        INDArray rowLabel = labels.get(NDArrayIndex.point(i), NDArrayIndex.all(), NDArrayIndex.point(j), NDArrayIndex.point(k));
                        INDArray rowPredictions = predictions.get(NDArrayIndex.point(i), NDArrayIndex.all(), NDArrayIndex.point(j), NDArrayIndex.point(k));
                        rowLabel = rowLabel.reshape(1, rowLabel.length());
                        rowPredictions = rowPredictions.reshape(1, rowPredictions.length());
                        e2d.eval(rowLabel, rowPredictions);
                    }
                }
            }
            assertEquals(e2d, e4d);
            //NHWC, etc
            INDArray lOrig = labels;
            INDArray fOrig = predictions;
            for (int i = 0; i < 4; i++) {
                // i is both the layout selector and the channel axis passed to setAxis()
                switch (i) {
                    case 0:
                        //CNHW - Never really used
                        labels = lOrig.permute(1, 0, 2, 3).dup();
                        predictions = fOrig.permute(1, 0, 2, 3).dup();
                        break;
                    case 1:
                        //NCHW
                        labels = lOrig;
                        predictions = fOrig;
                        break;
                    case 2:
                        //NHCW - Never really used...
                        labels = lOrig.permute(0, 2, 1, 3).dup();
                        predictions = fOrig.permute(0, 2, 1, 3).dup();
                        break;
                    case 3:
                        //NHWC
                        labels = lOrig.permute(0, 2, 3, 1).dup();
                        predictions = fOrig.permute(0, 2, 3, 1).dup();
                        break;
                    default:
                        throw new RuntimeException();
                }
                Evaluation e = new Evaluation();
                e.setAxis(i);
                e.eval(labels, predictions);
                assertEquals(e2d, e);
            }
        }
    }

    /** After reset(), re-evaluating the same data with custom labels must give identical stats(). */
    @Test
    public void testLabelReset(){
        Map<Integer,String> m = new HashMap<>();
        m.put(0, "False");
        m.put(1, "True");
        Evaluation e1 = new Evaluation(m);
        INDArray zero = Nd4j.create(new double[]{1,0}).reshape(1,2);
        INDArray one = Nd4j.create(new double[]{0,1}).reshape(1,2);
        e1.eval(zero, zero);
        e1.eval(zero, zero);
        e1.eval(one, zero);
        e1.eval(one, one);
        e1.eval(one, one);
        e1.eval(one, one);
        String s1 = e1.stats();
        System.out.println(s1);
        e1.reset();
        e1.eval(zero, zero);
        e1.eval(zero, zero);
        e1.eval(one, zero);
        e1.eval(one, one);
        e1.eval(one, one);
        e1.eval(one, one);
        String s2 = e1.stats();
        assertEquals(s1, s2);
    }

    /**
     * For the 2-class case, stats() is expected to report class-1 precision/recall/F1 rather than
     * macro-averaged values.
     */
    @Test
    public void testEvalStatsBinaryCase(){
        //Make sure we report class 1 precision/recall/f1 not macro averaged, for binary case
        Evaluation e = new Evaluation();
        INDArray l0 = Nd4j.createFromArray(new double[]{1,0}).reshape(1,2);
        INDArray l1 = Nd4j.createFromArray(new double[]{0,1}).reshape(1,2);
        e.eval(l1, l1);
        e.eval(l1, l1);
        e.eval(l1, l1);
        e.eval(l0, l0);
        e.eval(l1, l0);
        e.eval(l1, l0);
        e.eval(l0, l1);
        double tp = 3;
        double fp = 1;
        double fn = 2;
        double prec = tp / (tp + fp);
        double rec = tp / (tp + fn);
        double f1 = 2 * prec * rec / (prec + rec);
        assertEquals(prec, e.precision(), 1e-6);
        assertEquals(rec, e.recall(), 1e-6);
        DecimalFormat df = new DecimalFormat("0.0000");
        String stats = e.stats();
        //System.out.println(stats);
        // Collapse runs of spaces so the check is independent of column alignment in stats()
        String stats2 = stats.replaceAll("( )+", " ");
        String recS = " Recall: " + df.format(rec);
        String preS = " Precision: " + df.format(prec);
        String f1S = "F1 Score: " + df.format(f1);
        assertTrue(stats2, stats2.contains(recS));
        assertTrue(stats2, stats2.contains(preS));
        assertTrue(stats2, stats2.contains(f1S));
    }
}