146 lines
5.4 KiB
Java

/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.eval;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.core.evaluation.EvaluationTools;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.Arrays;
import java.util.Random;
/**
 * Smoke tests for the HTML rendering helpers in {@link EvaluationTools}.
 *
 * <p>Each test builds an evaluation object (ROC, multi-class ROC or calibration), renders it to
 * HTML and checks that a non-empty document came back. The content of the HTML itself is not
 * inspected.
 */
public class EvaluationToolsTests extends BaseDL4JTest {

    /** Batch size and total example count requested from the Iris iterator. */
    private static final int IRIS_EXAMPLES = 150;

    /** Number of times each network is fitted on the single Iris batch before evaluation. */
    private static final int FIT_ITERATIONS = 30;

    /**
     * Threshold-step settings passed to the ROC constructors.
     * NOTE(review): 0 presumably selects the exact (non-thresholded) ROC mode - confirm against
     * the ROC documentation.
     */
    private static final int[] ROC_STEP_SETTINGS = {20, 0};

    /**
     * Renders a binary ROC chart. The three Iris label columns are collapsed into two
     * (columns 0 and 1 are summed into the first output column, column 2 becomes the second).
     */
    @Test
    public void testRocHtml() {
        MultiLayerNetwork net = buildNetwork(2);
        DataSet ds = loadStandardizedIris();

        // Merge the first two label columns so the problem becomes two-class.
        INDArray irisLabels = ds.getLabels();
        INDArray newLabels = Nd4j.create(IRIS_EXAMPLES, 2);
        newLabels.getColumn(0).assign(irisLabels.getColumn(0));
        newLabels.getColumn(0).addi(irisLabels.getColumn(1));
        newLabels.getColumn(1).assign(irisLabels.getColumn(2));
        ds.setLabels(newLabels);

        fitRepeatedly(net, ds);

        for (int numSteps : ROC_STEP_SETTINGS) {
            ROC roc = new ROC(numSteps);
            roc.eval(ds.getLabels(), net.output(ds.getFeatures()));
            String html = EvaluationTools.rocChartToHtml(roc);
            requireNonEmptyHtml(html, "ROC chart (numSteps=" + numSteps + ")");
        }
    }

    /**
     * Renders a multi-class ROC chart for a three-output network, supplying one class name per
     * output.
     *
     * @throws Exception declared for compatibility; nothing in the body is expected to throw a
     *                   checked exception
     */
    @Test
    public void testRocMultiToHtml() throws Exception {
        MultiLayerNetwork net = buildNetwork(3);
        DataSet ds = loadStandardizedIris();
        fitRepeatedly(net, ds);

        for (int numSteps : ROC_STEP_SETTINGS) {
            ROCMultiClass roc = new ROCMultiClass(numSteps);
            roc.eval(ds.getLabels(), net.output(ds.getFeatures()));
            String html = EvaluationTools.rocChartToHtml(roc,
                    Arrays.asList("setosa", "versicolor", "virginica"));
            requireNonEmptyHtml(html, "multi-class ROC chart (numSteps=" + numSteps + ")");
        }
    }

    /**
     * Renders an {@link EvaluationCalibration} built from random predictions and random one-hot
     * labels.
     *
     * @throws Exception declared for compatibility; nothing in the body is expected to throw a
     *                   checked exception
     */
    @Test
    public void testEvaluationCalibrationToHtml() throws Exception {
        int minibatch = 1000;
        int nClasses = 3;

        // Random predictions, with every row divided by its own sum so that it adds up to 1.
        INDArray arr = Nd4j.rand(minibatch, nClasses);
        arr.diviColumnVector(arr.sum(1));

        // One-hot labels with a randomly chosen class per row (fixed seed for the label draw).
        INDArray labels = Nd4j.zeros(minibatch, nClasses);
        Random r = new Random(12345);
        for (int i = 0; i < minibatch; i++) {
            labels.putScalar(i, r.nextInt(nClasses), 1.0);
        }

        int numBins = 10;
        EvaluationCalibration ec = new EvaluationCalibration(numBins, numBins);
        ec.eval(labels, arr);

        String html = EvaluationTools.evaluationCalibrationToHtml(ec);
        requireNonEmptyHtml(html, "evaluation calibration page");
    }

    /**
     * Builds and initialises the small network shared by the ROC tests: a 4-in/4-out tanh dense
     * layer followed by a softmax output layer trained with MCXENT loss.
     *
     * @param numOutputs number of units in the output layer
     * @return the initialised network
     */
    private static MultiLayerNetwork buildNetwork(int numOutputs) {
        NeuralNetConfiguration conf = NeuralNetConfiguration.builder().weightInit(WeightInit.XAVIER).list()
                .layer(0, DenseLayer.builder().nIn(4).nOut(4).activation(Activation.TANH).build())
                .layer(1, OutputLayer.builder().nIn(4).nOut(numOutputs).activation(Activation.SOFTMAX)
                        .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        return net;
    }

    /**
     * Pulls the first batch from an {@link IrisDataSetIterator} and standardises it in place with
     * a {@link NormalizerStandardize} fitted on that same batch.
     *
     * @return the standardised batch
     */
    private static DataSet loadStandardizedIris() {
        DataSetIterator iter = new IrisDataSetIterator(IRIS_EXAMPLES, IRIS_EXAMPLES);
        DataSet ds = iter.next();
        NormalizerStandardize ns = new NormalizerStandardize();
        ns.fit(ds);
        ns.transform(ds);
        return ds;
    }

    /**
     * Fits the network on the given data set {@link #FIT_ITERATIONS} times.
     *
     * @param net network to train
     * @param ds  data to fit on at every iteration
     */
    private static void fitRepeatedly(MultiLayerNetwork net, DataSet ds) {
        for (int i = 0; i < FIT_ITERATIONS; i++) {
            net.fit(ds);
        }
    }

    /**
     * Fails the calling test when the rendered HTML is missing or blank. An
     * {@link AssertionError} is thrown directly so the check needs nothing beyond
     * {@code java.lang}.
     *
     * @param html        rendered output to check
     * @param description short label used in the failure message
     * @throws AssertionError if {@code html} is {@code null} or contains only whitespace
     */
    private static void requireNonEmptyHtml(String html, String description) {
        if (html == null || html.trim().isEmpty()) {
            throw new AssertionError(description + " rendered to null or empty HTML");
        }
    }
}