Alex Black 29104083cc
Various fixes (#143)
* #8568 ArrayUtil optimization

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #6171 Keras ReLU and ELU support

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Keras softmax layer import

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8549 Webjars dependency management

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix for TF import names ':0' suffix issue / NPE

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* BiasAdd: fix default data format for TF import

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Update zoo test ignores

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8509 SameDiff Listener API - provide frame + iteration

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8520 ND4J Environment

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Deconv3d

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Deconv3d fixes + gradient check

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Conv3d fixes + deconv3d DType test

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix issue with deconv3d gradient check weight init

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8579 Fix BaseCudaDataBuffer constructor for UINT16

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* DataType.isNumerical() returns false for BOOL type

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* #8504 Reduce Spark log spam for tests

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Clean up DL4J gradient check test spam

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* More Gradient check spam reduction

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* SameDiff test spam reduction

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fixes for FlatBuffers mapping

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* SameDiff log spam cleanup

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Tests should extend BaseNd4jTest

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Remove debug line in c++ op

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* ND4J test spam cleanup

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* DL4J test spam reduction

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* More DL4J and DataVec test spam cleanup

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix for bad conv3d test

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Additional test

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Embedding layers: don't inherit global default activation function

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Trigger CI

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Consolidate all BaseDL4JTest classes to single class used everywhere; make timeout configurable per class

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Test fixes and timeout increases

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Timeouts and PReLU fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Restore libnd4j build threads arg for CUDA build

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Increase timeouts on a few tests to avoid spurious failures on some CI machines

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* More timeout fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* More test timeout fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Tweak timeout for one more test

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Final tweaks

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* One more ignore

Signed-off-by: AlexDBlack <blacka101@gmail.com>
2020-01-04 13:45:07 +11:00

293 lines
11 KiB
Java

/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.evaluation;
import org.junit.Test;
import org.nd4j.evaluation.classification.*;
import org.nd4j.evaluation.curves.Histogram;
import org.nd4j.evaluation.curves.PrecisionRecallCurve;
import org.nd4j.evaluation.curves.RocCurve;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import static junit.framework.TestCase.assertNull;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
/**
 * Round-trip (JSON and YAML) serialization tests for the ND4J evaluation classes, the ROC and
 * precision-recall curves they produce, and the calibration histograms.
 */
public class EvalJsonTest extends BaseNd4jTest {

    public EvalJsonTest(Nd4jBackend backend) {
        super(backend);
    }

    @Override
    public char ordering() {
        return 'c';
    }

    /**
     * Evaluations that have not seen any data must still produce stats and survive a JSON round trip
     * (compared via their re-serialized JSON).
     */
    @Test
    public void testSerdeEmpty() {
        boolean print = false;

        IEvaluation[] arr = new IEvaluation[] {new Evaluation(), new EvaluationBinary(), new ROCBinary(10),
                        new ROCMultiClass(10), new RegressionEvaluation(3), new RegressionEvaluation(),
                        new EvaluationCalibration()};

        for (IEvaluation e : arr) {
            String json = e.toJson();
            // Result is discarded: the call exercises stats() so that a failure there fails the test
            e.stats();
            if (print) {
                System.out.println(e.getClass() + "\n" + json + "\n\n");
            }

            IEvaluation fromJson = BaseEvaluation.fromJson(json, BaseEvaluation.class);
            assertEquals(e.toJson(), fromJson.toJson());
        }
    }

    /**
     * Evaluations populated with random data must survive a JSON round trip
     * (compared via their re-serialized JSON).
     */
    @Test
    public void testSerde() {
        boolean print = false;
        Nd4j.getRandom().setSeed(12345);

        Evaluation evaluation = new Evaluation();
        EvaluationBinary evaluationBinary = new EvaluationBinary();
        ROC roc = new ROC(2);
        ROCBinary roc2 = new ROCBinary(2);
        ROCMultiClass roc3 = new ROCMultiClass(2);
        RegressionEvaluation regressionEvaluation = new RegressionEvaluation();
        EvaluationCalibration ec = new EvaluationCalibration();

        IEvaluation[] arr = new IEvaluation[] {evaluation, evaluationBinary, roc, roc2, roc3, regressionEvaluation, ec};

        // Multi-class data: one-hot labels (class = row % 3), rows of random probabilities normalized to sum to 1
        INDArray evalLabel = Nd4j.create(10, 3);
        for (int i = 0; i < 10; i++) {
            evalLabel.putScalar(i, i % 3, 1.0);
        }
        INDArray evalProb = Nd4j.rand(10, 3);
        evalProb.diviColumnVector(evalProb.sum(1));
        evaluation.eval(evalLabel, evalProb);
        roc3.eval(evalLabel, evalProb);
        ec.eval(evalLabel, evalProb);

        // Multi-output binary data: independent Bernoulli(0.5) labels per column
        evalLabel = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(10, 3), 0.5));
        evalProb = Nd4j.rand(10, 3);
        evaluationBinary.eval(evalLabel, evalProb);
        roc2.eval(evalLabel, evalProb);

        // Single-column binary data for the plain ROC
        evalLabel = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(10, 1), 0.5));
        evalProb = Nd4j.rand(10, 1);
        roc.eval(evalLabel, evalProb);

        regressionEvaluation.eval(Nd4j.rand(10, 3), Nd4j.rand(10, 3));

        for (IEvaluation e : arr) {
            String json = e.toJson();
            if (print) {
                System.out.println(e.getClass() + "\n" + json + "\n\n");
            }
            IEvaluation fromJson = BaseEvaluation.fromJson(json, BaseEvaluation.class);
            assertEquals(e.toJson(), fromJson.toJson());
        }
    }

    /**
     * ROC evaluations created with 0 threshold steps must survive a JSON round trip. The deserialized copies
     * must not carry the raw probability/label data, but must keep the AUC, AUPRC and curves of the originals.
     */
    @Test
    public void testSerdeExactRoc() {
        Nd4j.getRandom().setSeed(12345);
        boolean print = false;

        ROC roc = new ROC(0);
        ROCBinary roc2 = new ROCBinary(0);
        ROCMultiClass roc3 = new ROCMultiClass(0);

        IEvaluation[] arr = new IEvaluation[] {roc, roc2, roc3};

        INDArray evalLabel = Nd4j.create(100, 3);
        for (int i = 0; i < 100; i++) {
            evalLabel.putScalar(i, i % 3, 1.0);
        }
        INDArray evalProb = Nd4j.rand(100, 3);
        evalProb.diviColumnVector(evalProb.sum(1));
        roc3.eval(evalLabel, evalProb);

        evalLabel = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(100, 3), 0.5));
        evalProb = Nd4j.rand(100, 3);
        roc2.eval(evalLabel, evalProb);

        evalLabel = Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(100, 1), 0.5));
        evalProb = Nd4j.rand(100, 1);
        roc.eval(evalLabel, evalProb);

        for (IEvaluation e : arr) {
            String json = e.toJson();
            // Result is discarded: the call exercises stats() so that a failure there fails the test
            e.stats();
            if (print) {
                System.out.println(e.getClass() + "\n" + json + "\n\n");
            }

            IEvaluation fromJson = BaseEvaluation.fromJson(json, BaseEvaluation.class);
            assertEquals(e, fromJson);

            if (e instanceof ROC) {
                ROC origRoc = (ROC) e;
                ROC rocFromJson = (ROC) fromJson;
                //Shouldn't have probAndLabel, but should have stored AUC and AUPRC, AND stored curves
                assertNull(rocFromJson.getProbAndLabel());
                assertTrue(rocFromJson.calculateAUC() > 0.0);
                assertTrue(rocFromJson.calculateAUCPR() > 0.0);
                assertEquals(origRoc.getRocCurve(), rocFromJson.getRocCurve());
                assertEquals(origRoc.getPrecisionRecallCurve(), rocFromJson.getPrecisionRecallCurve());
            } else if (e instanceof ROCBinary) {
                assertUnderlyingRocsMatch(((ROCBinary) e).getUnderlying(),
                                ((ROCBinary) fromJson).getUnderlying());
            } else if (e instanceof ROCMultiClass) {
                assertUnderlyingRocsMatch(((ROCMultiClass) e).getUnderlying(),
                                ((ROCMultiClass) fromJson).getUnderlying());
            }
        }
    }

    /**
     * Asserts that ROCs restored from JSON match the originals element by element. The restored copies must
     * not carry probAndLabel, but must report the same AUC, AUPRC, ROC curve and precision-recall curve
     * as the corresponding original.
     *
     * @param origRocs     underlying ROCs of the evaluation before serialization
     * @param rocsFromJson underlying ROCs of the evaluation deserialized from JSON
     */
    private static void assertUnderlyingRocsMatch(ROC[] origRocs, ROC[] rocsFromJson) {
        assertEquals(origRocs.length, rocsFromJson.length);
        for (int i = 0; i < origRocs.length; i++) {
            ROC origR = origRocs[i];
            ROC r = rocsFromJson[i];
            //Shouldn't have probAndLabel, but should have stored AUC and AUPRC, AND stored curves
            assertNull(r.getProbAndLabel());
            assertEquals(origR.calculateAUC(), r.calculateAUC(), 1e-6);
            assertEquals(origR.calculateAUCPR(), r.calculateAUCPR(), 1e-6);
            assertEquals(origR.getRocCurve(), r.getRocCurve());
            assertEquals(origR.getPrecisionRecallCurve(), r.getPrecisionRecallCurve());
        }
    }

    /**
     * ROC and precision-recall curves must round-trip through JSON. Calibration histograms must round-trip
     * through both JSON and YAML.
     */
    @Test
    public void testJsonYamlCurves() {
        ROC roc = new ROC(0);

        INDArray evalLabel =
                        Nd4j.getExecutioner().exec(new BernoulliDistribution(Nd4j.createUninitialized(100, 1), 0.5));
        INDArray evalProb = Nd4j.rand(100, 1);
        roc.eval(evalLabel, evalProb);

        RocCurve c = roc.getRocCurve();
        PrecisionRecallCurve prc = roc.getPrecisionRecallCurve();

        String json1 = c.toJson();
        String json2 = prc.toJson();

        RocCurve c2 = RocCurve.fromJson(json1);
        PrecisionRecallCurve prc2 = PrecisionRecallCurve.fromJson(json2);

        assertEquals(c, c2);
        assertEquals(prc, prc2);

        //Also test: histograms
        EvaluationCalibration ec = new EvaluationCalibration();

        evalLabel = Nd4j.create(10, 3);
        for (int i = 0; i < 10; i++) {
            evalLabel.putScalar(i, i % 3, 1.0);
        }
        evalProb = Nd4j.rand(10, 3);
        evalProb.diviColumnVector(evalProb.sum(1));
        ec.eval(evalLabel, evalProb);

        Histogram[] histograms = new Histogram[] {ec.getResidualPlotAllClasses(), ec.getResidualPlot(0),
                        ec.getResidualPlot(1), ec.getProbabilityHistogramAllClasses(), ec.getProbabilityHistogram(0),
                        ec.getProbabilityHistogram(1)};

        for (Histogram h : histograms) {
            String json = h.toJson();
            String yaml = h.toYaml();

            Histogram h2 = Histogram.fromJson(json);
            Histogram h3 = Histogram.fromYaml(yaml);

            assertEquals(h, h2);
            assertEquals(h2, h3);
        }
    }

    /**
     * Custom decision settings must survive JSON and YAML round trips. This covers the binary decision
     * threshold and the cost array of {@link Evaluation}, and the per-output thresholds of
     * {@link EvaluationBinary}.
     */
    @Test
    public void testJsonWithCustomThreshold() {
        //Evaluation - binary threshold
        Evaluation e = new Evaluation(0.25);
        String json = e.toJson();
        String yaml = e.toYaml();

        Evaluation eFromJson = Evaluation.fromJson(json);
        Evaluation eFromYaml = Evaluation.fromYaml(yaml);

        assertEquals(0.25, eFromJson.getBinaryDecisionThreshold(), 1e-6);
        assertEquals(0.25, eFromYaml.getBinaryDecisionThreshold(), 1e-6);

        //Evaluation: custom cost array
        INDArray costArray = Nd4j.create(new double[] {1.0, 2.0, 3.0});
        Evaluation e2 = new Evaluation(costArray);

        json = e2.toJson();
        yaml = e2.toYaml();

        eFromJson = Evaluation.fromJson(json);
        eFromYaml = Evaluation.fromYaml(yaml);

        assertEquals(e2.getCostArray(), eFromJson.getCostArray());
        assertEquals(e2.getCostArray(), eFromYaml.getCostArray());

        //EvaluationBinary - per-output binary threshold
        INDArray threshold = Nd4j.create(new double[] {1.0, 0.5, 0.25});
        EvaluationBinary eb = new EvaluationBinary(threshold);

        json = eb.toJson();
        yaml = eb.toYaml();

        EvaluationBinary ebFromJson = EvaluationBinary.fromJson(json);
        EvaluationBinary ebFromYaml = EvaluationBinary.fromYaml(yaml);

        assertEquals(threshold, ebFromJson.getDecisionThreshold());
        assertEquals(threshold, ebFromYaml.getDecisionThreshold());
    }
}