2021-02-01 14:31:20 +09:00
|
|
|
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  * See the NOTICE file distributed with this work for additional
 *  * information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */
|
2019-06-06 15:21:15 +03:00
|
|
|
|
|
|
|
package org.nd4j.evaluation;
|
|
|
|
|
2021-03-20 19:06:24 +09:00
|
|
|
import org.junit.jupiter.api.Tag;
|
2021-03-16 11:57:24 +09:00
|
|
|
import org.junit.jupiter.api.Test;
|
2021-03-16 22:08:35 +09:00
|
|
|
import org.junit.jupiter.params.ParameterizedTest;
|
|
|
|
import org.junit.jupiter.params.provider.MethodSource;
|
2021-03-20 19:06:24 +09:00
|
|
|
import org.nd4j.common.tests.tags.TagNames;
|
2019-06-06 15:21:15 +03:00
|
|
|
import org.nd4j.evaluation.classification.Evaluation;
|
|
|
|
import org.nd4j.evaluation.classification.EvaluationBinary;
|
2021-03-16 22:08:35 +09:00
|
|
|
import org.nd4j.linalg.BaseNd4jTestWithBackends;
|
2019-06-06 15:21:15 +03:00
|
|
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
|
|
import org.nd4j.linalg.api.ops.impl.scalar.ScalarMin;
|
|
|
|
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
|
|
|
|
import org.nd4j.linalg.factory.Nd4j;
|
|
|
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
|
|
|
import org.nd4j.linalg.ops.transforms.Transforms;
|
|
|
|
|
|
|
|
import java.util.Random;
|
|
|
|
|
2021-03-16 11:57:24 +09:00
|
|
|
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
|
|
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
2019-06-06 15:21:15 +03:00
|
|
|
|
2021-03-20 19:06:24 +09:00
|
|
|
@Tag(TagNames.EVAL_METRICS)
|
2021-03-16 22:08:35 +09:00
|
|
|
public class EvalCustomThreshold extends BaseNd4jTestWithBackends {
|
2019-06-06 15:21:15 +03:00
|
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
public char ordering() {
|
|
|
|
return 'c';
|
|
|
|
}
|
|
|
|
|
2021-03-16 22:08:35 +09:00
|
|
|
@ParameterizedTest
|
2021-03-17 20:04:53 +09:00
|
|
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
2021-03-16 22:08:35 +09:00
|
|
|
public void testEvaluationCustomBinaryThreshold(Nd4jBackend backend) {
|
2019-06-06 15:21:15 +03:00
|
|
|
Nd4j.getRandom().setSeed(12345);
|
|
|
|
|
|
|
|
//Sanity checks: 0.5 threshold for 1-output and 2-output binary cases
|
|
|
|
Evaluation e = new Evaluation();
|
|
|
|
Evaluation e05 = new Evaluation(0.5);
|
|
|
|
Evaluation e05v2 = new Evaluation(0.5);
|
|
|
|
|
|
|
|
int nExamples = 20;
|
|
|
|
int nOut = 2;
|
|
|
|
INDArray probs = Nd4j.rand(nExamples, nOut);
|
|
|
|
probs.diviColumnVector(probs.sum(1));
|
|
|
|
INDArray labels = Nd4j.create(nExamples, nOut);
|
|
|
|
Random r = new Random(12345);
|
|
|
|
for (int i = 0; i < nExamples; i++) {
|
|
|
|
labels.putScalar(i, r.nextInt(2), 1.0);
|
|
|
|
}
|
|
|
|
|
|
|
|
e.eval(labels, probs);
|
|
|
|
e05.eval(labels, probs);
|
|
|
|
e05v2.eval(labels.getColumn(1, true), probs.getColumn(1, true)); //"single output binary" case
|
|
|
|
|
|
|
|
for (Evaluation e2 : new Evaluation[] {e05, e05v2}) {
|
|
|
|
assertEquals(e.accuracy(), e2.accuracy(), 1e-6);
|
|
|
|
assertEquals(e.f1(), e2.f1(), 1e-6);
|
|
|
|
assertEquals(e.precision(), e2.precision(), 1e-6);
|
|
|
|
assertEquals(e.recall(), e2.recall(), 1e-6);
|
|
|
|
assertEquals(e.getConfusionMatrix(), e2.getConfusionMatrix());
|
|
|
|
}
|
|
|
|
|
|
|
|
//Check with decision threshold of 0.25
|
|
|
|
//In this test, we'll cheat a bit: multiply class 1 probabilities by 2 (max of 1.0); this should give an
|
|
|
|
// identical result to a threshold of 0.5 vs. no multiplication and threshold of 0.25
|
|
|
|
|
|
|
|
INDArray p2 = probs.dup();
|
|
|
|
INDArray p2c = p2.getColumn(1);
|
|
|
|
p2c.muli(2.0);
|
|
|
|
Nd4j.getExecutioner().exec(new ScalarMin(p2c, null, p2c, 1.0));
|
|
|
|
p2.getColumn(0).assign(p2.getColumn(1).rsub(1.0));
|
|
|
|
|
|
|
|
Evaluation e025 = new Evaluation(0.25);
|
|
|
|
e025.eval(labels, probs);
|
|
|
|
|
|
|
|
Evaluation ex2 = new Evaluation();
|
|
|
|
ex2.eval(labels, p2);
|
|
|
|
|
|
|
|
assertEquals(ex2.accuracy(), e025.accuracy(), 1e-6);
|
|
|
|
assertEquals(ex2.f1(), e025.f1(), 1e-6);
|
|
|
|
assertEquals(ex2.precision(), e025.precision(), 1e-6);
|
|
|
|
assertEquals(ex2.recall(), e025.recall(), 1e-6);
|
|
|
|
assertEquals(ex2.getConfusionMatrix(), e025.getConfusionMatrix());
|
|
|
|
|
|
|
|
|
|
|
|
//Check the same thing, but the single binary output case:
|
|
|
|
|
|
|
|
Evaluation e025v2 = new Evaluation(0.25);
|
|
|
|
e025v2.eval(labels.getColumn(1, true), probs.getColumn(1, true));
|
|
|
|
|
|
|
|
assertEquals(ex2.accuracy(), e025v2.accuracy(), 1e-6);
|
|
|
|
assertEquals(ex2.f1(), e025v2.f1(), 1e-6);
|
|
|
|
assertEquals(ex2.precision(), e025v2.precision(), 1e-6);
|
|
|
|
assertEquals(ex2.recall(), e025v2.recall(), 1e-6);
|
|
|
|
assertEquals(ex2.getConfusionMatrix(), e025v2.getConfusionMatrix());
|
|
|
|
}
|
|
|
|
|
2021-03-16 22:08:35 +09:00
|
|
|
@ParameterizedTest
|
2021-03-17 20:04:53 +09:00
|
|
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
2021-03-16 22:08:35 +09:00
|
|
|
public void testEvaluationCostArray(Nd4jBackend backend) {
|
2019-06-06 15:21:15 +03:00
|
|
|
|
|
|
|
|
|
|
|
int nExamples = 20;
|
|
|
|
int nOut = 3;
|
|
|
|
Nd4j.getRandom().setSeed(12345);
|
|
|
|
INDArray probs = Nd4j.rand(nExamples, nOut);
|
|
|
|
probs.diviColumnVector(probs.sum(1));
|
|
|
|
INDArray labels = Nd4j.create(nExamples, nOut);
|
|
|
|
Random r = new Random(12345);
|
|
|
|
for (int j = 0; j < nExamples; j++) {
|
|
|
|
labels.putScalar(j, r.nextInt(2), 1.0);
|
|
|
|
}
|
|
|
|
|
|
|
|
Evaluation e = new Evaluation();
|
|
|
|
e.eval(labels, probs);
|
|
|
|
|
|
|
|
//Sanity check: "all equal" cost array - equal to no cost array
|
|
|
|
for (int i = 1; i <= 3; i++) {
|
|
|
|
Evaluation e2 = new Evaluation(Nd4j.valueArrayOf(new int[] {1, nOut}, i));
|
|
|
|
e2.eval(labels, probs);
|
|
|
|
|
|
|
|
assertEquals(e.accuracy(), e2.accuracy(), 1e-6);
|
|
|
|
assertEquals(e.f1(), e2.f1(), 1e-6);
|
|
|
|
assertEquals(e.precision(), e2.precision(), 1e-6);
|
|
|
|
assertEquals(e.recall(), e2.recall(), 1e-6);
|
|
|
|
assertEquals(e.getConfusionMatrix(), e2.getConfusionMatrix());
|
|
|
|
}
|
|
|
|
|
|
|
|
//Manual checks:
|
|
|
|
INDArray costArray = Nd4j.create(new double[] {5, 2, 1});
|
|
|
|
labels = Nd4j.create(new double[][] {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}});
|
|
|
|
probs = Nd4j.create(new double[][] {{0.2, 0.3, 0.5}, //1.0, 0.6, 0.5
|
|
|
|
{0.1, 0.4, 0.5}, //0.5, 0.8, 0.5
|
|
|
|
{0.1, 0.1, 0.8}}); //0.5, 0.2, 0.8
|
|
|
|
|
|
|
|
//With no cost array: only last example is predicted correctly
|
|
|
|
e = new Evaluation();
|
|
|
|
e.eval(labels, probs);
|
|
|
|
assertEquals(1.0 / 3, e.accuracy(), 1e-6);
|
|
|
|
|
|
|
|
//With cost array: all examples predicted correctly
|
|
|
|
Evaluation e2 = new Evaluation(costArray);
|
|
|
|
e2.eval(labels, probs);
|
|
|
|
assertEquals(1.0, e2.accuracy(), 1e-6);
|
|
|
|
}
|
|
|
|
|
2021-03-16 22:08:35 +09:00
|
|
|
@ParameterizedTest
|
2021-03-17 20:04:53 +09:00
|
|
|
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
2021-03-16 22:08:35 +09:00
|
|
|
public void testEvaluationBinaryCustomThreshold(Nd4jBackend backend) {
|
2019-06-06 15:21:15 +03:00
|
|
|
|
|
|
|
//Sanity check: same results for 0.5 threshold vs. default (no threshold)
|
|
|
|
int nExamples = 20;
|
|
|
|
int nOut = 2;
|
|
|
|
INDArray probs = Nd4j.rand(nExamples, nOut);
|
|
|
|
INDArray labels = Nd4j.getExecutioner()
|
|
|
|
.exec(new BernoulliDistribution(Nd4j.createUninitialized(nExamples, nOut), 0.5));
|
|
|
|
|
|
|
|
EvaluationBinary eStd = new EvaluationBinary();
|
|
|
|
eStd.eval(labels, probs);
|
|
|
|
|
|
|
|
EvaluationBinary eb05 = new EvaluationBinary(Nd4j.create(new double[] {0.5, 0.5}, new long[]{1,2}));
|
|
|
|
eb05.eval(labels, probs);
|
|
|
|
|
|
|
|
EvaluationBinary eb05v2 = new EvaluationBinary(Nd4j.create(new double[] {0.5, 0.5}, new long[]{1,2}));
|
|
|
|
for (int i = 0; i < nExamples; i++) {
|
|
|
|
eb05v2.eval(labels.getRow(i, true), probs.getRow(i, true));
|
|
|
|
}
|
|
|
|
|
|
|
|
for (EvaluationBinary eb2 : new EvaluationBinary[] {eb05, eb05v2}) {
|
|
|
|
assertArrayEquals(eStd.getCountTruePositive(), eb2.getCountTruePositive());
|
|
|
|
assertArrayEquals(eStd.getCountFalsePositive(), eb2.getCountFalsePositive());
|
|
|
|
assertArrayEquals(eStd.getCountTrueNegative(), eb2.getCountTrueNegative());
|
|
|
|
assertArrayEquals(eStd.getCountFalseNegative(), eb2.getCountFalseNegative());
|
|
|
|
|
|
|
|
for (int j = 0; j < nOut; j++) {
|
|
|
|
assertEquals(eStd.accuracy(j), eb2.accuracy(j), 1e-6);
|
|
|
|
assertEquals(eStd.f1(j), eb2.f1(j), 1e-6);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
//Check with decision threshold of 0.25 and 0.125 (for different outputs)
|
|
|
|
//In this test, we'll cheat a bit: multiply probabilities by 2 (max of 1.0) and threshold of 0.25 should give
|
|
|
|
// an identical result to a threshold of 0.5
|
|
|
|
//Ditto for 4x and 0.125 threshold
|
|
|
|
|
|
|
|
INDArray probs2 = probs.mul(2);
|
|
|
|
probs2 = Transforms.min(probs2, 1.0);
|
|
|
|
|
|
|
|
INDArray probs4 = probs.mul(4);
|
|
|
|
probs4 = Transforms.min(probs4, 1.0);
|
|
|
|
|
|
|
|
EvaluationBinary ebThreshold = new EvaluationBinary(Nd4j.create(new double[] {0.25, 0.125}));
|
|
|
|
ebThreshold.eval(labels, probs);
|
|
|
|
|
|
|
|
EvaluationBinary ebStd2 = new EvaluationBinary();
|
|
|
|
ebStd2.eval(labels, probs2);
|
|
|
|
|
|
|
|
EvaluationBinary ebStd4 = new EvaluationBinary();
|
|
|
|
ebStd4.eval(labels, probs4);
|
|
|
|
|
|
|
|
assertEquals(ebThreshold.truePositives(0), ebStd2.truePositives(0));
|
|
|
|
assertEquals(ebThreshold.trueNegatives(0), ebStd2.trueNegatives(0));
|
|
|
|
assertEquals(ebThreshold.falsePositives(0), ebStd2.falsePositives(0));
|
|
|
|
assertEquals(ebThreshold.falseNegatives(0), ebStd2.falseNegatives(0));
|
|
|
|
|
|
|
|
assertEquals(ebThreshold.truePositives(1), ebStd4.truePositives(1));
|
|
|
|
assertEquals(ebThreshold.trueNegatives(1), ebStd4.trueNegatives(1));
|
|
|
|
assertEquals(ebThreshold.falsePositives(1), ebStd4.falsePositives(1));
|
|
|
|
assertEquals(ebThreshold.falseNegatives(1), ebStd4.falseNegatives(1));
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|