* #8568 ArrayUtil optimization Signed-off-by: AlexDBlack <blacka101@gmail.com> * #6171 Keras ReLU and ELU support Signed-off-by: AlexDBlack <blacka101@gmail.com> * Keras softmax layer import Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8549 Webjars dependency management Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix for TF import names ':0' suffix issue / NPE Signed-off-by: AlexDBlack <blacka101@gmail.com> * BiasAdd: fix default data format for TF import Signed-off-by: AlexDBlack <blacka101@gmail.com> * Update zoo test ignores Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8509 SameDiff Listener API - provide frame + iteration Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8520 ND4J Environment Signed-off-by: AlexDBlack <blacka101@gmail.com> * Deconv3d Signed-off-by: AlexDBlack <blacka101@gmail.com> * Deconv3d fixes + gradient check Signed-off-by: AlexDBlack <blacka101@gmail.com> * Conv3d fixes + deconv3d DType test Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix issue with deconv3d gradinet check weight init Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8579 Fix BaseCudaDataBuffer constructor fix for UINT16 Signed-off-by: AlexDBlack <blacka101@gmail.com> * DataType.isNumerical() returns false for BOOL type Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8504 Reduce Spark log spam for tests Signed-off-by: AlexDBlack <blacka101@gmail.com> * Clean up DL4J gradient check test spam Signed-off-by: AlexDBlack <blacka101@gmail.com> * More Gradient check spam reduction Signed-off-by: AlexDBlack <blacka101@gmail.com> * SameDiff test spam reduction Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fixes for FlatBuffers mapping Signed-off-by: AlexDBlack <blacka101@gmail.com> * SameDiff log spam cleanup Signed-off-by: AlexDBlack <blacka101@gmail.com> * Tests should extend BaseNd4jTest Signed-off-by: AlexDBlack <blacka101@gmail.com> * Remove debug line in c++ op Signed-off-by: AlexDBlack <blacka101@gmail.com> * ND4J test spam cleanup Signed-off-by: AlexDBlack <blacka101@gmail.com> * DL4J test spam reduction Signed-off-by: AlexDBlack <blacka101@gmail.com> * More Dl4J and datavec test spam cleanup Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix for bad conv3d test Signed-off-by: AlexDBlack <blacka101@gmail.com> * Additional test Signed-off-by: AlexDBlack <blacka101@gmail.com> * Embedding layers: don't inherit global default activation function Signed-off-by: AlexDBlack <blacka101@gmail.com> * Trigger CI Signed-off-by: AlexDBlack <blacka101@gmail.com> * Consolidate all BaseDL4JTest classes to single class used everywhere; make timeout configurable per class Signed-off-by: AlexDBlack <blacka101@gmail.com> * Test fixes and timeout increases Signed-off-by: AlexDBlack <blacka101@gmail.com> * Timeouts and PReLU fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * Restore libnd4j build threads arg for CUDA build Signed-off-by: AlexDBlack <blacka101@gmail.com> * Increase timeouts on a few tests to avoid spurious failures on some CI machines Signed-off-by: AlexDBlack <blacka101@gmail.com> * More timeout fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * More test timeout fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * Tweak timeout for one more test Signed-off-by: AlexDBlack <blacka101@gmail.com> * Final tweaks Signed-off-by: AlexDBlack <blacka101@gmail.com> * One more ignore Signed-off-by: AlexDBlack <blacka101@gmail.com>
443 lines
17 KiB
Java
443 lines
17 KiB
Java
/*******************************************************************************
|
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
*
|
|
* This program and the accompanying materials are made available under the
|
|
* terms of the Apache License, Version 2.0 which is available at
|
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
* License for the specific language governing permissions and limitations
|
|
* under the License.
|
|
*
|
|
* SPDX-License-Identifier: Apache-2.0
|
|
******************************************************************************/
|
|
|
|
package org.nd4j.evaluation;
|
|
|
|
import org.junit.Test;
|
|
import org.nd4j.evaluation.classification.Evaluation;
|
|
import org.nd4j.evaluation.classification.EvaluationBinary;
|
|
import org.nd4j.linalg.BaseNd4jTest;
|
|
import org.nd4j.linalg.api.buffer.DataType;
|
|
import org.nd4j.linalg.api.iter.NdIndexIterator;
|
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
|
|
import org.nd4j.linalg.factory.Nd4j;
|
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
|
import org.nd4j.linalg.indexing.INDArrayIndex;
|
|
import org.nd4j.linalg.indexing.NDArrayIndex;
|
|
|
|
import java.util.ArrayList;
|
|
import java.util.List;
|
|
|
|
import static org.junit.Assert.assertEquals;
|
|
import static org.nd4j.evaluation.classification.EvaluationBinary.Metric.*;
|
|
/**
|
|
* Created by Alex on 20/03/2017.
|
|
*/
|
|
public class EvaluationBinaryTest extends BaseNd4jTest {
|
|
|
|
public EvaluationBinaryTest(Nd4jBackend backend) {
|
|
super(backend);
|
|
}
|
|
|
|
@Override
|
|
public char ordering() {
|
|
return 'c';
|
|
}
|
|
|
|
@Test
|
|
public void testEvaluationBinary() {
|
|
//Compare EvaluationBinary to Evaluation class
|
|
DataType dtypeBefore = Nd4j.defaultFloatingPointType();
|
|
EvaluationBinary first = null;
|
|
String sFirst = null;
|
|
try {
|
|
for (DataType globalDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.INT}) {
|
|
Nd4j.setDefaultDataTypes(globalDtype, globalDtype.isFPType() ? globalDtype : DataType.DOUBLE);
|
|
for (DataType lpDtype : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) {
|
|
|
|
Nd4j.getRandom().setSeed(12345);
|
|
|
|
int nExamples = 50;
|
|
int nOut = 4;
|
|
long[] shape = {nExamples, nOut};
|
|
|
|
INDArray labels = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(lpDtype, shape), 0.5));
|
|
|
|
INDArray predicted = Nd4j.rand(lpDtype, shape);
|
|
INDArray binaryPredicted = predicted.gt(0.5);
|
|
|
|
EvaluationBinary eb = new EvaluationBinary();
|
|
eb.eval(labels, predicted);
|
|
|
|
//System.out.println(eb.stats());
|
|
|
|
double eps = 1e-6;
|
|
for (int i = 0; i < nOut; i++) {
|
|
INDArray lCol = labels.getColumn(i,true);
|
|
INDArray pCol = predicted.getColumn(i,true);
|
|
INDArray bpCol = binaryPredicted.getColumn(i,true);
|
|
|
|
int countCorrect = 0;
|
|
int tpCount = 0;
|
|
int tnCount = 0;
|
|
for (int j = 0; j < lCol.length(); j++) {
|
|
if (lCol.getDouble(j) == bpCol.getDouble(j)) {
|
|
countCorrect++;
|
|
if (lCol.getDouble(j) == 1) {
|
|
tpCount++;
|
|
} else {
|
|
tnCount++;
|
|
}
|
|
}
|
|
}
|
|
double acc = countCorrect / (double) lCol.length();
|
|
|
|
Evaluation e = new Evaluation();
|
|
e.eval(lCol, pCol);
|
|
|
|
assertEquals(acc, eb.accuracy(i), eps);
|
|
assertEquals(e.accuracy(), eb.scoreForMetric(ACCURACY, i), eps);
|
|
assertEquals(e.precision(1), eb.scoreForMetric(PRECISION, i), eps);
|
|
assertEquals(e.recall(1), eb.scoreForMetric(RECALL, i), eps);
|
|
assertEquals(e.f1(1), eb.scoreForMetric(F1, i), eps);
|
|
assertEquals(e.falseAlarmRate(), eb.scoreForMetric(FAR, i), eps);
|
|
assertEquals(e.falsePositiveRate(1), eb.falsePositiveRate(i), eps);
|
|
|
|
|
|
assertEquals(tpCount, eb.truePositives(i));
|
|
assertEquals(tnCount, eb.trueNegatives(i));
|
|
|
|
assertEquals((int) e.truePositives().get(1), eb.truePositives(i));
|
|
assertEquals((int) e.trueNegatives().get(1), eb.trueNegatives(i));
|
|
assertEquals((int) e.falsePositives().get(1), eb.falsePositives(i));
|
|
assertEquals((int) e.falseNegatives().get(1), eb.falseNegatives(i));
|
|
|
|
assertEquals(nExamples, eb.totalCount(i));
|
|
|
|
String s = eb.stats();
|
|
if(first == null) {
|
|
first = eb;
|
|
sFirst = s;
|
|
} else {
|
|
assertEquals(first, eb);
|
|
assertEquals(sFirst, s);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
} finally {
|
|
Nd4j.setDefaultDataTypes(dtypeBefore, dtypeBefore);
|
|
}
|
|
}
|
|
|
|
@Test
|
|
public void testEvaluationBinaryMerging() {
|
|
int nOut = 4;
|
|
int[] shape1 = {30, nOut};
|
|
int[] shape2 = {50, nOut};
|
|
|
|
Nd4j.getRandom().setSeed(12345);
|
|
INDArray l1 = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(shape1), 0.5));
|
|
INDArray l2 = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(shape2), 0.5));
|
|
INDArray p1 = Nd4j.rand(shape1);
|
|
INDArray p2 = Nd4j.rand(shape2);
|
|
|
|
EvaluationBinary eb = new EvaluationBinary();
|
|
eb.eval(l1, p1);
|
|
eb.eval(l2, p2);
|
|
|
|
EvaluationBinary eb1 = new EvaluationBinary();
|
|
eb1.eval(l1, p1);
|
|
|
|
EvaluationBinary eb2 = new EvaluationBinary();
|
|
eb2.eval(l2, p2);
|
|
|
|
eb1.merge(eb2);
|
|
|
|
assertEquals(eb.stats(), eb1.stats());
|
|
}
|
|
|
|
@Test
|
|
public void testEvaluationBinaryPerOutputMasking() {
|
|
|
|
//Provide a mask array: "ignore" the masked steps
|
|
|
|
INDArray mask = Nd4j.create(new double[][] {{1, 1, 0}, {1, 0, 0}, {1, 1, 0}, {1, 0, 0}, {1, 1, 1}});
|
|
|
|
INDArray labels = Nd4j.create(new double[][] {{1, 1, 1}, {0, 0, 0}, {1, 1, 1}, {0, 1, 1}, {1, 0, 1}});
|
|
|
|
INDArray predicted = Nd4j.create(new double[][] {{0.9, 0.9, 0.9}, {0.7, 0.7, 0.7}, {0.6, 0.6, 0.6},
|
|
{0.4, 0.4, 0.4}, {0.1, 0.1, 0.1}});
|
|
|
|
//Correct?
|
|
// Y Y m
|
|
// N m m
|
|
// Y Y m
|
|
// Y m m
|
|
// N Y N
|
|
|
|
EvaluationBinary eb = new EvaluationBinary();
|
|
eb.eval(labels, predicted, mask);
|
|
|
|
assertEquals(0.6, eb.accuracy(0), 1e-6);
|
|
assertEquals(1.0, eb.accuracy(1), 1e-6);
|
|
assertEquals(0.0, eb.accuracy(2), 1e-6);
|
|
|
|
assertEquals(2, eb.truePositives(0));
|
|
assertEquals(2, eb.truePositives(1));
|
|
assertEquals(0, eb.truePositives(2));
|
|
|
|
assertEquals(1, eb.trueNegatives(0));
|
|
assertEquals(1, eb.trueNegatives(1));
|
|
assertEquals(0, eb.trueNegatives(2));
|
|
|
|
assertEquals(1, eb.falsePositives(0));
|
|
assertEquals(0, eb.falsePositives(1));
|
|
assertEquals(0, eb.falsePositives(2));
|
|
|
|
assertEquals(1, eb.falseNegatives(0));
|
|
assertEquals(0, eb.falseNegatives(1));
|
|
assertEquals(1, eb.falseNegatives(2));
|
|
}
|
|
|
|
@Test
|
|
public void testTimeSeriesEval() {
|
|
|
|
int[] shape = {2, 4, 3};
|
|
Nd4j.getRandom().setSeed(12345);
|
|
INDArray labels = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(shape), 0.5));
|
|
INDArray predicted = Nd4j.rand(shape);
|
|
INDArray mask = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(shape), 0.5));
|
|
|
|
EvaluationBinary eb1 = new EvaluationBinary();
|
|
eb1.eval(labels, predicted, mask);
|
|
|
|
EvaluationBinary eb2 = new EvaluationBinary();
|
|
for (int i = 0; i < shape[2]; i++) {
|
|
INDArray l = labels.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(i));
|
|
INDArray p = predicted.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(i));
|
|
INDArray m = mask.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(i));
|
|
|
|
eb2.eval(l, p, m);
|
|
}
|
|
|
|
assertEquals(eb2.stats(), eb1.stats());
|
|
}
|
|
|
|
@Test
|
|
public void testEvaluationBinaryWithROC() {
|
|
//Simple test for nested ROCBinary in EvaluationBinary
|
|
|
|
Nd4j.getRandom().setSeed(12345);
|
|
INDArray l1 = Nd4j.getExecutioner()
|
|
.exec(new BernoulliDistribution(Nd4j.createUninitialized(new int[] {50, 4}), 0.5));
|
|
INDArray p1 = Nd4j.rand(50, 4);
|
|
|
|
EvaluationBinary eb = new EvaluationBinary(4, 30);
|
|
eb.eval(l1, p1);
|
|
|
|
// System.out.println(eb.stats());
|
|
eb.stats();
|
|
}
|
|
|
|
|
|
@Test
|
|
public void testEvaluationBinary3d() {
|
|
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
|
|
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
|
|
|
|
|
|
List<INDArray> rowsP = new ArrayList<>();
|
|
List<INDArray> rowsL = new ArrayList<>();
|
|
NdIndexIterator iter = new NdIndexIterator(2, 10);
|
|
while (iter.hasNext()) {
|
|
long[] idx = iter.next();
|
|
INDArrayIndex[] idxs = new INDArrayIndex[]{NDArrayIndex.point(idx[0]), NDArrayIndex.all(), NDArrayIndex.point(idx[1])};
|
|
rowsP.add(prediction.get(idxs));
|
|
rowsL.add(label.get(idxs));
|
|
}
|
|
|
|
INDArray p2d = Nd4j.vstack(rowsP);
|
|
INDArray l2d = Nd4j.vstack(rowsL);
|
|
|
|
EvaluationBinary e3d = new EvaluationBinary();
|
|
EvaluationBinary e2d = new EvaluationBinary();
|
|
|
|
e3d.eval(label, prediction);
|
|
e2d.eval(l2d, p2d);
|
|
|
|
for (EvaluationBinary.Metric m : EvaluationBinary.Metric.values()) {
|
|
for( int i=0; i<5; i++ ) {
|
|
double d1 = e3d.scoreForMetric(m, i);
|
|
double d2 = e2d.scoreForMetric(m, i);
|
|
assertEquals(m.toString(), d2, d1, 1e-6);
|
|
}
|
|
}
|
|
}
|
|
|
|
@Test
|
|
public void testEvaluationBinary4d() {
|
|
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
|
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
|
|
|
|
|
List<INDArray> rowsP = new ArrayList<>();
|
|
List<INDArray> rowsL = new ArrayList<>();
|
|
NdIndexIterator iter = new NdIndexIterator(2, 10, 10);
|
|
while (iter.hasNext()) {
|
|
long[] idx = iter.next();
|
|
INDArrayIndex[] idxs = new INDArrayIndex[]{NDArrayIndex.point(idx[0]), NDArrayIndex.all(), NDArrayIndex.point(idx[1]), NDArrayIndex.point(idx[2])};
|
|
rowsP.add(prediction.get(idxs));
|
|
rowsL.add(label.get(idxs));
|
|
}
|
|
|
|
INDArray p2d = Nd4j.vstack(rowsP);
|
|
INDArray l2d = Nd4j.vstack(rowsL);
|
|
|
|
EvaluationBinary e4d = new EvaluationBinary();
|
|
EvaluationBinary e2d = new EvaluationBinary();
|
|
|
|
e4d.eval(label, prediction);
|
|
e2d.eval(l2d, p2d);
|
|
|
|
for (EvaluationBinary.Metric m : EvaluationBinary.Metric.values()) {
|
|
for( int i=0; i<3; i++ ) {
|
|
double d1 = e4d.scoreForMetric(m, i);
|
|
double d2 = e2d.scoreForMetric(m, i);
|
|
assertEquals(m.toString(), d2, d1, 1e-6);
|
|
}
|
|
}
|
|
}
|
|
|
|
@Test
|
|
public void testEvaluationBinary3dMasking() {
|
|
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
|
|
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
|
|
|
|
List<INDArray> rowsP = new ArrayList<>();
|
|
List<INDArray> rowsL = new ArrayList<>();
|
|
|
|
//Check "DL4J-style" 2d per timestep masking [minibatch, seqLength] mask shape
|
|
INDArray mask2d = Nd4j.randomBernoulli(0.5, 2, 10);
|
|
rowsP.clear();
|
|
rowsL.clear();
|
|
NdIndexIterator iter = new NdIndexIterator(2, 10);
|
|
while (iter.hasNext()) {
|
|
long[] idx = iter.next();
|
|
if(mask2d.getDouble(idx[0], idx[1]) != 0.0) {
|
|
INDArrayIndex[] idxs = new INDArrayIndex[]{NDArrayIndex.point(idx[0]), NDArrayIndex.all(), NDArrayIndex.point(idx[1])};
|
|
rowsP.add(prediction.get(idxs));
|
|
rowsL.add(label.get(idxs));
|
|
}
|
|
}
|
|
INDArray p2d = Nd4j.vstack(rowsP);
|
|
INDArray l2d = Nd4j.vstack(rowsL);
|
|
|
|
EvaluationBinary e3d_m2d = new EvaluationBinary();
|
|
EvaluationBinary e2d_m2d = new EvaluationBinary();
|
|
e3d_m2d.eval(label, prediction, mask2d);
|
|
e2d_m2d.eval(l2d, p2d);
|
|
|
|
|
|
|
|
//Check per-output masking:
|
|
INDArray perOutMask = Nd4j.randomBernoulli(0.5, label.shape());
|
|
rowsP.clear();
|
|
rowsL.clear();
|
|
List<INDArray> rowsM = new ArrayList<>();
|
|
iter = new NdIndexIterator(2, 10);
|
|
while (iter.hasNext()) {
|
|
long[] idx = iter.next();
|
|
INDArrayIndex[] idxs = new INDArrayIndex[]{NDArrayIndex.point(idx[0]), NDArrayIndex.all(), NDArrayIndex.point(idx[1])};
|
|
rowsP.add(prediction.get(idxs));
|
|
rowsL.add(label.get(idxs));
|
|
rowsM.add(perOutMask.get(idxs));
|
|
}
|
|
p2d = Nd4j.vstack(rowsP);
|
|
l2d = Nd4j.vstack(rowsL);
|
|
INDArray m2d = Nd4j.vstack(rowsM);
|
|
|
|
EvaluationBinary e4d_m2 = new EvaluationBinary();
|
|
EvaluationBinary e2d_m2 = new EvaluationBinary();
|
|
e4d_m2.eval(label, prediction, perOutMask);
|
|
e2d_m2.eval(l2d, p2d, m2d);
|
|
for(EvaluationBinary.Metric m : EvaluationBinary.Metric.values()){
|
|
for(int i=0; i<3; i++ ) {
|
|
double d1 = e4d_m2.scoreForMetric(m, i);
|
|
double d2 = e2d_m2.scoreForMetric(m, i);
|
|
assertEquals(m.toString(), d2, d1, 1e-6);
|
|
}
|
|
}
|
|
}
|
|
|
|
@Test
|
|
public void testEvaluationBinary4dMasking() {
|
|
INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
|
INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
|
|
|
|
List<INDArray> rowsP = new ArrayList<>();
|
|
List<INDArray> rowsL = new ArrayList<>();
|
|
|
|
//Check per-example masking:
|
|
INDArray mask1dPerEx = Nd4j.createFromArray(1, 0);
|
|
|
|
NdIndexIterator iter = new NdIndexIterator(2, 10, 10);
|
|
while (iter.hasNext()) {
|
|
long[] idx = iter.next();
|
|
if(mask1dPerEx.getDouble(idx[0]) != 0.0) {
|
|
INDArrayIndex[] idxs = new INDArrayIndex[]{NDArrayIndex.point(idx[0]), NDArrayIndex.all(), NDArrayIndex.point(idx[1]), NDArrayIndex.point(idx[2])};
|
|
rowsP.add(prediction.get(idxs));
|
|
rowsL.add(label.get(idxs));
|
|
}
|
|
}
|
|
|
|
INDArray p2d = Nd4j.vstack(rowsP);
|
|
INDArray l2d = Nd4j.vstack(rowsL);
|
|
|
|
EvaluationBinary e4d_m1 = new EvaluationBinary();
|
|
EvaluationBinary e2d_m1 = new EvaluationBinary();
|
|
e4d_m1.eval(label, prediction, mask1dPerEx);
|
|
e2d_m1.eval(l2d, p2d);
|
|
for(EvaluationBinary.Metric m : EvaluationBinary.Metric.values()){
|
|
for( int i=0; i<3; i++ ) {
|
|
double d1 = e4d_m1.scoreForMetric(m, i);
|
|
double d2 = e2d_m1.scoreForMetric(m, i);
|
|
assertEquals(m.toString(), d2, d1, 1e-6);
|
|
}
|
|
}
|
|
|
|
//Check per-output masking:
|
|
INDArray perOutMask = Nd4j.randomBernoulli(0.5, label.shape());
|
|
rowsP.clear();
|
|
rowsL.clear();
|
|
List<INDArray> rowsM = new ArrayList<>();
|
|
iter = new NdIndexIterator(2, 10, 10);
|
|
while (iter.hasNext()) {
|
|
long[] idx = iter.next();
|
|
INDArrayIndex[] idxs = new INDArrayIndex[]{NDArrayIndex.point(idx[0]), NDArrayIndex.all(), NDArrayIndex.point(idx[1]), NDArrayIndex.point(idx[2])};
|
|
rowsP.add(prediction.get(idxs));
|
|
rowsL.add(label.get(idxs));
|
|
rowsM.add(perOutMask.get(idxs));
|
|
}
|
|
p2d = Nd4j.vstack(rowsP);
|
|
l2d = Nd4j.vstack(rowsL);
|
|
INDArray m2d = Nd4j.vstack(rowsM);
|
|
|
|
EvaluationBinary e3d_m2 = new EvaluationBinary();
|
|
EvaluationBinary e2d_m2 = new EvaluationBinary();
|
|
e3d_m2.eval(label, prediction, perOutMask);
|
|
e2d_m2.eval(l2d, p2d, m2d);
|
|
for(EvaluationBinary.Metric m : EvaluationBinary.Metric.values()){
|
|
for( int i=0; i<3; i++) {
|
|
double d1 = e3d_m2.scoreForMetric(m, i);
|
|
double d2 = e2d_m2.scoreForMetric(m, i);
|
|
assertEquals(m.toString(), d2, d1, 1e-6);
|
|
}
|
|
}
|
|
}
|
|
}
|