2019-06-06 15:21:15 +03:00
|
|
|
/*******************************************************************************
|
|
|
|
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
|
|
|
*
|
|
|
|
|
* This program and the accompanying materials are made available under the
|
|
|
|
|
* terms of the Apache License, Version 2.0 which is available at
|
|
|
|
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
|
|
|
*
|
|
|
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
|
|
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
|
|
|
* License for the specific language governing permissions and limitations
|
|
|
|
|
* under the License.
|
|
|
|
|
*
|
|
|
|
|
* SPDX-License-Identifier: Apache-2.0
|
|
|
|
|
******************************************************************************/
|
|
|
|
|
|
|
|
|
|
package org.deeplearning4j.earlystopping.scorecalc;
|
|
|
|
|
|
|
|
|
|
import org.deeplearning4j.earlystopping.scorecalc.base.BaseScoreCalculator;
|
|
|
|
|
import org.deeplearning4j.nn.api.Layer;
|
|
|
|
|
import org.deeplearning4j.nn.api.Model;
|
|
|
|
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
|
|
|
|
import org.deeplearning4j.nn.layers.feedforward.autoencoder.AutoEncoder;
|
|
|
|
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
|
|
|
|
import org.nd4j.evaluation.regression.RegressionEvaluation;
|
SameDiff: Listener changes and training api update (#99)
* example api
Signed-off-by: Ryan Nett <rnett@skymind.io>
* Lambda based evaluation
Signed-off-by: Ryan Nett <rnett@skymind.io>
* lambda test
Signed-off-by: Ryan Nett <rnett@skymind.io>
* fixes
Signed-off-by: Ryan Nett <rnett@skymind.io>
* partial fixes, use get-variable listener framework, example EvaluationListener
Signed-off-by: Ryan Nett <rnett@skymind.io>
* javadoc fix and newInstance implementations
Signed-off-by: Ryan Nett <rnett@skymind.io>
* fit and evaluate methods with validation data (for fit) and listeners
Signed-off-by: Ryan Nett <rnett@skymind.io>
* output method overloads + listener args
Signed-off-by: Ryan Nett <rnett@skymind.io>
* history and evaluation helpers
Signed-off-by: Ryan Nett <rnett@skymind.io>
* fixes
Signed-off-by: Ryan Nett <rnett@skymind.io>
* more fixes
Signed-off-by: Ryan Nett <rnett@skymind.io>
* FitConfig and added getters and setters
Signed-off-by: Ryan Nett <rnett@skymind.io>
* javadocs
Signed-off-by: Ryan Nett <rnett@skymind.io>
* fixes, javadoc, added activations to history, added latest activation listener
Signed-off-by: Ryan Nett <rnett@skymind.io>
* fixes, start of tests
Signed-off-by: Ryan Nett <rnett@skymind.io>
* fixes and updates
Signed-off-by: Ryan Nett <rnett@skymind.io>
* newInstance fixes, tests
Signed-off-by: Ryan Nett <rnett@skymind.io>
* test fixes
Signed-off-by: Ryan Nett <rnett@skymind.io>
* javadocs, getters with SDVariable overrides, CustomEvaluation fix
Signed-off-by: Ryan Nett <rnett@skymind.io>
* more operation config classes (evaluation, output, exec/single batch output), fix custom eval tests
Signed-off-by: Ryan Nett <rnett@skymind.io>
* merge fixes
Signed-off-by: Ryan Nett <rnett@skymind.io>
* fix
Signed-off-by: Ryan Nett <rnett@skymind.io>
* fixes, most old fit/evaluate/output methods use the builders
Signed-off-by: Ryan Nett <rnett@skymind.io>
* fixes
Signed-off-by: Ryan Nett <rnett@skymind.io>
* numerous fixes/cleanup
Signed-off-by: Ryan Nett <rnett@skymind.io>
* fixes
Signed-off-by: Ryan Nett <rnett@skymind.io>
* javadoc
Signed-off-by: Ryan Nett <rnett@skymind.io>
* Polish round 1
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Round 2
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Formatting + round 3
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Round 4
Signed-off-by: AlexDBlack <blacka101@gmail.com>
2019-08-09 22:30:31 -07:00
|
|
|
import org.nd4j.evaluation.regression.RegressionEvaluation.Metric;
|
2019-06-06 15:21:15 +03:00
|
|
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
|
|
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
|
|
|
|
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Score function for a MultiLayerNetwork or ComputationGraph with a single
|
|
|
|
|
* {@link org.deeplearning4j.nn.conf.layers.AutoEncoder} layer.
|
SameDiff: Listener changes and training api update (#99)
* example api
Signed-off-by: Ryan Nett <rnett@skymind.io>
* Lambda based evaluation
Signed-off-by: Ryan Nett <rnett@skymind.io>
* lambda test
Signed-off-by: Ryan Nett <rnett@skymind.io>
* fixes
Signed-off-by: Ryan Nett <rnett@skymind.io>
* partial fixes, use get-variable listener framework, example EvaluationListener
Signed-off-by: Ryan Nett <rnett@skymind.io>
* javadoc fix and newInstance implementations
Signed-off-by: Ryan Nett <rnett@skymind.io>
* fit and evaluate methods with validation data (for fit) and listeners
Signed-off-by: Ryan Nett <rnett@skymind.io>
* output method overloads + listener args
Signed-off-by: Ryan Nett <rnett@skymind.io>
* history and evaluation helpers
Signed-off-by: Ryan Nett <rnett@skymind.io>
* fixes
Signed-off-by: Ryan Nett <rnett@skymind.io>
* more fixes
Signed-off-by: Ryan Nett <rnett@skymind.io>
* FitConfig and added getters and setters
Signed-off-by: Ryan Nett <rnett@skymind.io>
* javadocs
Signed-off-by: Ryan Nett <rnett@skymind.io>
* fixes, javadoc, added activations to history, added latest activation listener
Signed-off-by: Ryan Nett <rnett@skymind.io>
* fixes, start of tests
Signed-off-by: Ryan Nett <rnett@skymind.io>
* fixes and updates
Signed-off-by: Ryan Nett <rnett@skymind.io>
* newInstance fixes, tests
Signed-off-by: Ryan Nett <rnett@skymind.io>
* test fixes
Signed-off-by: Ryan Nett <rnett@skymind.io>
* javadocs, getters with SDVariable overrides, CustomEvaluation fix
Signed-off-by: Ryan Nett <rnett@skymind.io>
* more operation config classes (evaluation, output, exec/single batch output), fix custom eval tests
Signed-off-by: Ryan Nett <rnett@skymind.io>
* merge fixes
Signed-off-by: Ryan Nett <rnett@skymind.io>
* fix
Signed-off-by: Ryan Nett <rnett@skymind.io>
* fixes, most old fit/evaluate/output methods use the builders
Signed-off-by: Ryan Nett <rnett@skymind.io>
* fixes
Signed-off-by: Ryan Nett <rnett@skymind.io>
* numerous fixes/cleanup
Signed-off-by: Ryan Nett <rnett@skymind.io>
* fixes
Signed-off-by: Ryan Nett <rnett@skymind.io>
* javadoc
Signed-off-by: Ryan Nett <rnett@skymind.io>
* Polish round 1
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Round 2
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Formatting + round 3
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Round 4
Signed-off-by: AlexDBlack <blacka101@gmail.com>
2019-08-09 22:30:31 -07:00
|
|
|
* Calculates the specified {@link Metric} on the layer's reconstructions.
|
2019-06-06 15:21:15 +03:00
|
|
|
*
|
|
|
|
|
* @author Alex Black
|
|
|
|
|
*/
|
|
|
|
|
public class AutoencoderScoreCalculator extends BaseScoreCalculator<Model> {
|
|
|
|
|
|
SameDiff: Listener changes and training api update (#99)
* example api
Signed-off-by: Ryan Nett <rnett@skymind.io>
* Lambda based evaluation
Signed-off-by: Ryan Nett <rnett@skymind.io>
* lambda test
Signed-off-by: Ryan Nett <rnett@skymind.io>
* fixes
Signed-off-by: Ryan Nett <rnett@skymind.io>
* partial fixes, use get-variable listener framework, example EvaluationListener
Signed-off-by: Ryan Nett <rnett@skymind.io>
* javadoc fix and newInstance implementations
Signed-off-by: Ryan Nett <rnett@skymind.io>
* fit and evaluate methods with validation data (for fit) and listeners
Signed-off-by: Ryan Nett <rnett@skymind.io>
* output method overloads + listener args
Signed-off-by: Ryan Nett <rnett@skymind.io>
* history and evaluation helpers
Signed-off-by: Ryan Nett <rnett@skymind.io>
* fixes
Signed-off-by: Ryan Nett <rnett@skymind.io>
* more fixes
Signed-off-by: Ryan Nett <rnett@skymind.io>
* FitConfig and added getters and setters
Signed-off-by: Ryan Nett <rnett@skymind.io>
* javadocs
Signed-off-by: Ryan Nett <rnett@skymind.io>
* fixes, javadoc, added activations to history, added latest activation listener
Signed-off-by: Ryan Nett <rnett@skymind.io>
* fixes, start of tests
Signed-off-by: Ryan Nett <rnett@skymind.io>
* fixes and updates
Signed-off-by: Ryan Nett <rnett@skymind.io>
* newInstance fixes, tests
Signed-off-by: Ryan Nett <rnett@skymind.io>
* test fixes
Signed-off-by: Ryan Nett <rnett@skymind.io>
* javadocs, getters with SDVariable overrides, CustomEvaluation fix
Signed-off-by: Ryan Nett <rnett@skymind.io>
* more operation config classes (evaluation, output, exec/single batch output), fix custom eval tests
Signed-off-by: Ryan Nett <rnett@skymind.io>
* merge fixes
Signed-off-by: Ryan Nett <rnett@skymind.io>
* fix
Signed-off-by: Ryan Nett <rnett@skymind.io>
* fixes, most old fit/evaluate/output methods use the builders
Signed-off-by: Ryan Nett <rnett@skymind.io>
* fixes
Signed-off-by: Ryan Nett <rnett@skymind.io>
* numerous fixes/cleanup
Signed-off-by: Ryan Nett <rnett@skymind.io>
* fixes
Signed-off-by: Ryan Nett <rnett@skymind.io>
* javadoc
Signed-off-by: Ryan Nett <rnett@skymind.io>
* Polish round 1
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Round 2
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Formatting + round 3
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Round 4
Signed-off-by: AlexDBlack <blacka101@gmail.com>
2019-08-09 22:30:31 -07:00
|
|
|
protected final Metric metric;
|
2019-06-06 15:21:15 +03:00
|
|
|
protected RegressionEvaluation evaluation;
|
|
|
|
|
|
SameDiff: Listener changes and training api update (#99)
* example api
Signed-off-by: Ryan Nett <rnett@skymind.io>
* Lambda based evaluation
Signed-off-by: Ryan Nett <rnett@skymind.io>
* lambda test
Signed-off-by: Ryan Nett <rnett@skymind.io>
* fixes
Signed-off-by: Ryan Nett <rnett@skymind.io>
* partial fixes, use get-variable listener framework, example EvaluationListener
Signed-off-by: Ryan Nett <rnett@skymind.io>
* javadoc fix and newInstance implementations
Signed-off-by: Ryan Nett <rnett@skymind.io>
* fit and evaluate methods with validation data (for fit) and listeners
Signed-off-by: Ryan Nett <rnett@skymind.io>
* output method overloads + listener args
Signed-off-by: Ryan Nett <rnett@skymind.io>
* history and evaluation helpers
Signed-off-by: Ryan Nett <rnett@skymind.io>
* fixes
Signed-off-by: Ryan Nett <rnett@skymind.io>
* more fixes
Signed-off-by: Ryan Nett <rnett@skymind.io>
* FitConfig and added getters and setters
Signed-off-by: Ryan Nett <rnett@skymind.io>
* javadocs
Signed-off-by: Ryan Nett <rnett@skymind.io>
* fixes, javadoc, added activations to history, added latest activation listener
Signed-off-by: Ryan Nett <rnett@skymind.io>
* fixes, start of tests
Signed-off-by: Ryan Nett <rnett@skymind.io>
* fixes and updates
Signed-off-by: Ryan Nett <rnett@skymind.io>
* newInstance fixes, tests
Signed-off-by: Ryan Nett <rnett@skymind.io>
* test fixes
Signed-off-by: Ryan Nett <rnett@skymind.io>
* javadocs, getters with SDVariable overrides, CustomEvaluation fix
Signed-off-by: Ryan Nett <rnett@skymind.io>
* more operation config classes (evaluation, output, exec/single batch output), fix custom eval tests
Signed-off-by: Ryan Nett <rnett@skymind.io>
* merge fixes
Signed-off-by: Ryan Nett <rnett@skymind.io>
* fix
Signed-off-by: Ryan Nett <rnett@skymind.io>
* fixes, most old fit/evaluate/output methods use the builders
Signed-off-by: Ryan Nett <rnett@skymind.io>
* fixes
Signed-off-by: Ryan Nett <rnett@skymind.io>
* numerous fixes/cleanup
Signed-off-by: Ryan Nett <rnett@skymind.io>
* fixes
Signed-off-by: Ryan Nett <rnett@skymind.io>
* javadoc
Signed-off-by: Ryan Nett <rnett@skymind.io>
* Polish round 1
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Round 2
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Formatting + round 3
Signed-off-by: AlexDBlack <blacka101@gmail.com>
* Round 4
Signed-off-by: AlexDBlack <blacka101@gmail.com>
2019-08-09 22:30:31 -07:00
|
|
|
public AutoencoderScoreCalculator(Metric metric, DataSetIterator iterator){
|
2019-06-06 15:21:15 +03:00
|
|
|
super(iterator);
|
|
|
|
|
this.metric = metric;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
protected void reset() {
|
|
|
|
|
evaluation = new RegressionEvaluation();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
protected INDArray output(Model net, INDArray input, INDArray fMask, INDArray lMask) {
|
|
|
|
|
|
|
|
|
|
Layer l;
|
|
|
|
|
if(net instanceof MultiLayerNetwork) {
|
|
|
|
|
MultiLayerNetwork network = (MultiLayerNetwork)net;
|
|
|
|
|
l = network.getLayer(0);
|
|
|
|
|
} else {
|
|
|
|
|
ComputationGraph network = (ComputationGraph)net;
|
|
|
|
|
l = network.getLayer(0);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!(l instanceof AutoEncoder)) {
|
|
|
|
|
throw new UnsupportedOperationException("Can only score networks with autoencoder layers as first layer -" +
|
|
|
|
|
" got " + l.getClass().getSimpleName());
|
|
|
|
|
}
|
|
|
|
|
AutoEncoder ae = (AutoEncoder) l;
|
|
|
|
|
|
|
|
|
|
LayerWorkspaceMgr workspaceMgr = LayerWorkspaceMgr.noWorkspaces();
|
|
|
|
|
INDArray encode = ae.encode(input, false, workspaceMgr);
|
|
|
|
|
return ae.decode(encode, workspaceMgr);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
protected INDArray[] output(Model network, INDArray[] input, INDArray[] fMask, INDArray[] lMask) {
|
|
|
|
|
return new INDArray[]{output(network, get0(input), get0(fMask), get0(lMask))};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
protected double scoreMinibatch(Model network, INDArray features, INDArray labels, INDArray fMask,
|
|
|
|
|
INDArray lMask, INDArray output) {
|
|
|
|
|
evaluation.eval(features, output);
|
|
|
|
|
return 0.0; //Not used
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
protected double scoreMinibatch(Model network, INDArray[] features, INDArray[] labels, INDArray[] fMask, INDArray[] lMask, INDArray[] output) {
|
|
|
|
|
return scoreMinibatch(network, get0(features), get0(labels), get0(fMask), get0(lMask), get0(output));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
protected double finalScore(double scoreSum, int minibatchCount, int exampleCount) {
|
|
|
|
|
return evaluation.scoreForMetric(metric);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public boolean minimizeScore() {
|
|
|
|
|
return metric.minimize();
|
|
|
|
|
}
|
|
|
|
|
}
|