103 lines
3.9 KiB
Java
Raw Normal View History

2019-06-06 15:21:15 +03:00
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.earlystopping.scorecalc;
import org.deeplearning4j.earlystopping.scorecalc.base.BaseScoreCalculator;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.feedforward.autoencoder.AutoEncoder;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.evaluation.regression.RegressionEvaluation;
SameDiff: Listener changes and training api update (#99) * example api Signed-off-by: Ryan Nett <rnett@skymind.io> * Lambda based evaluation Signed-off-by: Ryan Nett <rnett@skymind.io> * lambda test Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * partial fixes, use get-variable listener framework, example EvaluationListener Signed-off-by: Ryan Nett <rnett@skymind.io> * javadoc fix and newInstance implementations Signed-off-by: Ryan Nett <rnett@skymind.io> * fit and evaluate methods with validation data (for fit) and listeners Signed-off-by: Ryan Nett <rnett@skymind.io> * output method overloads + listener args Signed-off-by: Ryan Nett <rnett@skymind.io> * history and evaluation helpers Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * more fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * FitConfig and added getters and setters Signed-off-by: Ryan Nett <rnett@skymind.io> * javadocs Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes, javadoc, added activations to history, added latest activation listener Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes, start of tests Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes and updates Signed-off-by: Ryan Nett <rnett@skymind.io> * newInstance fixes, tests Signed-off-by: Ryan Nett <rnett@skymind.io> * test fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * javadocs, getters with SDVariable overrides, CustomEvaluation fix Signed-off-by: Ryan Nett <rnett@skymind.io> * more operation config classes (evaluation, output, exec/single batch output), fix custom eval tests Signed-off-by: Ryan Nett <rnett@skymind.io> * merge fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * fix Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes, most old fit/evaluate/output methods use the builders Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * numerous fixes/cleanup Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * javadoc Signed-off-by: Ryan Nett <rnett@skymind.io> * Polish round 1 Signed-off-by: AlexDBlack <blacka101@gmail.com> * Round 2 Signed-off-by: AlexDBlack <blacka101@gmail.com> * Formatting + round 3 Signed-off-by: AlexDBlack <blacka101@gmail.com> * Round 4 Signed-off-by: AlexDBlack <blacka101@gmail.com>
2019-08-09 22:30:31 -07:00
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
2019-06-06 15:21:15 +03:00
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
/**
* Score function for a MultiLayerNetwork or ComputationGraph with a single
* {@link org.deeplearning4j.nn.conf.layers.AutoEncoder} layer.
SameDiff: Listener changes and training api update (#99) * example api Signed-off-by: Ryan Nett <rnett@skymind.io> * Lambda based evaluation Signed-off-by: Ryan Nett <rnett@skymind.io> * lambda test Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * partial fixes, use get-variable listener framework, example EvaluationListener Signed-off-by: Ryan Nett <rnett@skymind.io> * javadoc fix and newInstance implementations Signed-off-by: Ryan Nett <rnett@skymind.io> * fit and evaluate methods with validation data (for fit) and listeners Signed-off-by: Ryan Nett <rnett@skymind.io> * output method overloads + listener args Signed-off-by: Ryan Nett <rnett@skymind.io> * history and evaluation helpers Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * more fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * FitConfig and added getters and setters Signed-off-by: Ryan Nett <rnett@skymind.io> * javadocs Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes, javadoc, added activations to history, added latest activation listener Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes, start of tests Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes and updates Signed-off-by: Ryan Nett <rnett@skymind.io> * newInstance fixes, tests Signed-off-by: Ryan Nett <rnett@skymind.io> * test fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * javadocs, getters with SDVariable overrides, CustomEvaluation fix Signed-off-by: Ryan Nett <rnett@skymind.io> * more operation config classes (evaluation, output, exec/single batch output), fix custom eval tests Signed-off-by: Ryan Nett <rnett@skymind.io> * merge fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * fix Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes, most old fit/evaluate/output methods use the builders Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * numerous fixes/cleanup Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * javadoc Signed-off-by: Ryan Nett <rnett@skymind.io> * Polish round 1 Signed-off-by: AlexDBlack <blacka101@gmail.com> * Round 2 Signed-off-by: AlexDBlack <blacka101@gmail.com> * Formatting + round 3 Signed-off-by: AlexDBlack <blacka101@gmail.com> * Round 4 Signed-off-by: AlexDBlack <blacka101@gmail.com>
2019-08-09 22:30:31 -07:00
* Calculates the specified {@link Metric} on the layer's reconstructions.
2019-06-06 15:21:15 +03:00
*
* @author Alex Black
*/
public class AutoencoderScoreCalculator extends BaseScoreCalculator<Model> {
SameDiff: Listener changes and training api update (#99) * example api Signed-off-by: Ryan Nett <rnett@skymind.io> * Lambda based evaluation Signed-off-by: Ryan Nett <rnett@skymind.io> * lambda test Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * partial fixes, use get-variable listener framework, example EvaluationListener Signed-off-by: Ryan Nett <rnett@skymind.io> * javadoc fix and newInstance implementations Signed-off-by: Ryan Nett <rnett@skymind.io> * fit and evaluate methods with validation data (for fit) and listeners Signed-off-by: Ryan Nett <rnett@skymind.io> * output method overloads + listener args Signed-off-by: Ryan Nett <rnett@skymind.io> * history and evaluation helpers Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * more fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * FitConfig and added getters and setters Signed-off-by: Ryan Nett <rnett@skymind.io> * javadocs Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes, javadoc, added activations to history, added latest activation listener Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes, start of tests Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes and updates Signed-off-by: Ryan Nett <rnett@skymind.io> * newInstance fixes, tests Signed-off-by: Ryan Nett <rnett@skymind.io> * test fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * javadocs, getters with SDVariable overrides, CustomEvaluation fix Signed-off-by: Ryan Nett <rnett@skymind.io> * more operation config classes (evaluation, output, exec/single batch output), fix custom eval tests Signed-off-by: Ryan Nett <rnett@skymind.io> * merge fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * fix Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes, most old fit/evaluate/output methods use the builders Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * numerous fixes/cleanup Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * javadoc Signed-off-by: Ryan Nett <rnett@skymind.io> * Polish round 1 Signed-off-by: AlexDBlack <blacka101@gmail.com> * Round 2 Signed-off-by: AlexDBlack <blacka101@gmail.com> * Formatting + round 3 Signed-off-by: AlexDBlack <blacka101@gmail.com> * Round 4 Signed-off-by: AlexDBlack <blacka101@gmail.com>
2019-08-09 22:30:31 -07:00
protected final Metric metric;
2019-06-06 15:21:15 +03:00
protected RegressionEvaluation evaluation;
SameDiff: Listener changes and training api update (#99) * example api Signed-off-by: Ryan Nett <rnett@skymind.io> * Lambda based evaluation Signed-off-by: Ryan Nett <rnett@skymind.io> * lambda test Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * partial fixes, use get-variable listener framework, example EvaluationListener Signed-off-by: Ryan Nett <rnett@skymind.io> * javadoc fix and newInstance implementations Signed-off-by: Ryan Nett <rnett@skymind.io> * fit and evaluate methods with validation data (for fit) and listeners Signed-off-by: Ryan Nett <rnett@skymind.io> * output method overloads + listener args Signed-off-by: Ryan Nett <rnett@skymind.io> * history and evaluation helpers Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * more fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * FitConfig and added getters and setters Signed-off-by: Ryan Nett <rnett@skymind.io> * javadocs Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes, javadoc, added activations to history, added latest activation listener Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes, start of tests Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes and updates Signed-off-by: Ryan Nett <rnett@skymind.io> * newInstance fixes, tests Signed-off-by: Ryan Nett <rnett@skymind.io> * test fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * javadocs, getters with SDVariable overrides, CustomEvaluation fix Signed-off-by: Ryan Nett <rnett@skymind.io> * more operation config classes (evaluation, output, exec/single batch output), fix custom eval tests Signed-off-by: Ryan Nett <rnett@skymind.io> * merge fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * fix Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes, most old fit/evaluate/output methods use the builders Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * numerous fixes/cleanup Signed-off-by: Ryan Nett <rnett@skymind.io> * fixes Signed-off-by: Ryan Nett <rnett@skymind.io> * javadoc Signed-off-by: Ryan Nett <rnett@skymind.io> * Polish round 1 Signed-off-by: AlexDBlack <blacka101@gmail.com> * Round 2 Signed-off-by: AlexDBlack <blacka101@gmail.com> * Formatting + round 3 Signed-off-by: AlexDBlack <blacka101@gmail.com> * Round 4 Signed-off-by: AlexDBlack <blacka101@gmail.com>
2019-08-09 22:30:31 -07:00
public AutoencoderScoreCalculator(Metric metric, DataSetIterator iterator){
2019-06-06 15:21:15 +03:00
super(iterator);
this.metric = metric;
}
@Override
protected void reset() {
evaluation = new RegressionEvaluation();
}
@Override
protected INDArray output(Model net, INDArray input, INDArray fMask, INDArray lMask) {
Layer l;
if(net instanceof MultiLayerNetwork) {
MultiLayerNetwork network = (MultiLayerNetwork)net;
l = network.getLayer(0);
} else {
ComputationGraph network = (ComputationGraph)net;
l = network.getLayer(0);
}
if (!(l instanceof AutoEncoder)) {
throw new UnsupportedOperationException("Can only score networks with autoencoder layers as first layer -" +
" got " + l.getClass().getSimpleName());
}
AutoEncoder ae = (AutoEncoder) l;
LayerWorkspaceMgr workspaceMgr = LayerWorkspaceMgr.noWorkspaces();
INDArray encode = ae.encode(input, false, workspaceMgr);
return ae.decode(encode, workspaceMgr);
}
@Override
protected INDArray[] output(Model network, INDArray[] input, INDArray[] fMask, INDArray[] lMask) {
return new INDArray[]{output(network, get0(input), get0(fMask), get0(lMask))};
}
@Override
protected double scoreMinibatch(Model network, INDArray features, INDArray labels, INDArray fMask,
INDArray lMask, INDArray output) {
evaluation.eval(features, output);
return 0.0; //Not used
}
@Override
protected double scoreMinibatch(Model network, INDArray[] features, INDArray[] labels, INDArray[] fMask, INDArray[] lMask, INDArray[] output) {
return scoreMinibatch(network, get0(features), get0(labels), get0(fMask), get0(lMask), get0(output));
}
@Override
protected double finalScore(double scoreSum, int minibatchCount, int exampleCount) {
return evaluation.scoreForMetric(metric);
}
@Override
public boolean minimizeScore() {
return metric.minimize();
}
}