/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
|
||
|
|
||
|
package org.deeplearning4j.integration.testcases;
|
||
|
|
||
|
import org.datavec.api.split.FileSplit;
|
||
|
import org.datavec.image.loader.NativeImageLoader;
|
||
|
import org.datavec.image.recordreader.objdetect.ObjectDetectionRecordReader;
|
||
|
import org.datavec.image.recordreader.objdetect.impl.SvhnLabelProvider;
|
||
|
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
|
||
|
import org.deeplearning4j.datasets.fetchers.SvhnDataFetcher;
|
||
|
import org.deeplearning4j.integration.TestCase;
|
||
|
import org.deeplearning4j.datasets.fetchers.DataSetType;
|
||
|
import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator;
|
||
|
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||
|
import org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter;
|
||
|
import org.deeplearning4j.datasets.iterator.impl.TinyImageNetDataSetIterator;
|
||
|
import org.deeplearning4j.eval.Evaluation;
|
||
|
import org.deeplearning4j.eval.EvaluationCalibration;
|
||
|
import org.deeplearning4j.eval.IEvaluation;
|
||
|
import org.deeplearning4j.eval.ROCMultiClass;
|
||
|
import org.deeplearning4j.nn.api.Model;
|
||
|
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
|
||
|
import org.deeplearning4j.nn.conf.*;
|
||
|
import org.deeplearning4j.nn.conf.inputs.InputType;
|
||
|
import org.deeplearning4j.nn.conf.layers.*;
|
||
|
import org.deeplearning4j.nn.conf.layers.objdetect.Yolo2OutputLayer;
|
||
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
||
|
import org.deeplearning4j.nn.graph.util.ComputationGraphUtil;
|
||
|
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||
|
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
|
||
|
import org.deeplearning4j.nn.transferlearning.TransferLearning;
|
||
|
import org.deeplearning4j.nn.weights.WeightInit;
|
||
|
import org.deeplearning4j.zoo.PretrainedType;
|
||
|
import org.deeplearning4j.zoo.model.TinyYOLO;
|
||
|
import org.deeplearning4j.zoo.model.VGG16;
|
||
|
import org.nd4j.linalg.activations.Activation;
|
||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||
|
import org.nd4j.linalg.dataset.DataSet;
|
||
|
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
||
|
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||
|
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
||
|
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
|
||
|
import org.nd4j.linalg.dataset.api.preprocessor.VGG16ImagePreProcessor;
|
||
|
import org.nd4j.linalg.factory.Nd4j;
|
||
|
import org.nd4j.linalg.learning.config.Adam;
|
||
|
import org.nd4j.linalg.learning.config.Nesterovs;
|
||
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
||
|
import org.nd4j.linalg.primitives.Pair;
|
||
|
|
||
|
import java.io.File;
|
||
|
import java.util.*;
|
||
|
|
||
|
public class CNN2DTestCases {
|
||
|
|
||
|
/**
|
||
|
* Essentially: LeNet MNIST example
|
||
|
*/
|
||
|
public static TestCase getLenetMnist() {
|
||
|
return new TestCase() {
|
||
|
{
|
||
|
testName = "LenetMnist";
|
||
|
testType = TestType.RANDOM_INIT;
|
||
|
testPredictions = true;
|
||
|
testTrainingCurves = true;
|
||
|
testGradients = true;
|
||
|
testParamsPostTraining = true;
|
||
|
testEvaluation = true;
|
||
|
testOverfitting = false;
|
||
|
}
|
||
|
|
||
|
public Object getConfiguration() throws Exception {
|
||
|
int nChannels = 1; // Number of input channels
|
||
|
int outputNum = 10; // The number of possible outcomes
|
||
|
int seed = 123;
|
||
|
|
||
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||
|
.seed(seed)
|
||
|
.l2(0.0005)
|
||
|
.weightInit(WeightInit.XAVIER)
|
||
|
.updater(new Nesterovs(0.01, 0.9))
|
||
|
.list()
|
||
|
.layer(0, new ConvolutionLayer.Builder(5, 5)
|
||
|
//nIn and nOut specify depth. nIn here is the nChannels and nOut is the number of filters to be applied
|
||
|
.nIn(nChannels)
|
||
|
.stride(1, 1)
|
||
|
.nOut(20)
|
||
|
.activation(Activation.IDENTITY)
|
||
|
.build())
|
||
|
.layer(1, new SubsamplingLayer.Builder(PoolingType.MAX)
|
||
|
.kernelSize(2, 2)
|
||
|
.stride(2, 2)
|
||
|
.build())
|
||
|
.layer(2, new ConvolutionLayer.Builder(5, 5)
|
||
|
//Note that nIn need not be specified in later layers
|
||
|
.stride(1, 1)
|
||
|
.nOut(50)
|
||
|
.activation(Activation.IDENTITY)
|
||
|
.build())
|
||
|
.layer(3, new SubsamplingLayer.Builder(PoolingType.MAX)
|
||
|
.kernelSize(2, 2)
|
||
|
.stride(2, 2)
|
||
|
.build())
|
||
|
.layer(4, new DenseLayer.Builder().activation(Activation.RELU)
|
||
|
.nOut(500).build())
|
||
|
.layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
|
||
|
.nOut(outputNum)
|
||
|
.activation(Activation.SOFTMAX)
|
||
|
.build())
|
||
|
.setInputType(InputType.convolutionalFlat(28, 28, 1)) //See note below
|
||
|
.build();
|
||
|
|
||
|
return conf;
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public MultiDataSet getGradientsTestData() throws Exception {
|
||
|
DataSet ds = new MnistDataSetIterator(8, false, 12345).next();
|
||
|
return new org.nd4j.linalg.dataset.MultiDataSet(ds.getFeatures(), ds.getLabels());
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public MultiDataSetIterator getTrainingData() throws Exception {
|
||
|
DataSetIterator iter = new MnistDataSetIterator(16, true, 12345);
|
||
|
|
||
|
iter = new EarlyTerminationDataSetIterator(iter, 60);
|
||
|
return new MultiDataSetIteratorAdapter(iter);
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public MultiDataSetIterator getEvaluationTestData() throws Exception {
|
||
|
return new MultiDataSetIteratorAdapter(new EarlyTerminationDataSetIterator(new MnistDataSetIterator(32, false, 12345), 10));
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public List<Pair<INDArray[],INDArray[]>> getPredictionsTestData() throws Exception {
|
||
|
DataSetIterator iter = new MnistDataSetIterator(8, true, 12345);
|
||
|
List<Pair<INDArray[], INDArray[]>> list = new ArrayList<>();
|
||
|
|
||
|
DataSet ds = iter.next();
|
||
|
ds = ds.asList().get(0);
|
||
|
list.add(new Pair<>(new INDArray[]{ds.getFeatures()}, null));
|
||
|
|
||
|
ds = iter.next();
|
||
|
list.add(new Pair<>(new INDArray[]{ds.getFeatures()}, null));
|
||
|
return list;
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public IEvaluation[] getNewEvaluations(){
|
||
|
return new IEvaluation[]{
|
||
|
new Evaluation(),
|
||
|
new ROCMultiClass()};
|
||
|
}
|
||
|
|
||
|
};
|
||
|
}
|
||
|
|
||
|
|
||
|
/**
|
||
|
* VGG16 + transfer learning + tiny imagenet
|
||
|
*/
|
||
|
public static TestCase getVGG16TransferTinyImagenet() {
|
||
|
return new TestCase() {
|
||
|
|
||
|
{
|
||
|
testName = "VGG16TransferTinyImagenet_224";
|
||
|
testType = TestType.PRETRAINED;
|
||
|
testPredictions = true;
|
||
|
testTrainingCurves = true;
|
||
|
testGradients = false; //Skip - requires saving approx 1GB of data (gradients x2)
|
||
|
testParamsPostTraining = false; //Skip - requires saving all params (approx 500mb)
|
||
|
testEvaluation = false;
|
||
|
testOverfitting = false;
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public Model getPretrainedModel() throws Exception {
|
||
|
VGG16 vgg16 = VGG16.builder()
|
||
|
.seed(12345)
|
||
|
.build();
|
||
|
|
||
|
ComputationGraph pretrained = (ComputationGraph) vgg16.initPretrained(PretrainedType.IMAGENET);
|
||
|
|
||
|
//Transfer learning
|
||
|
ComputationGraph newGraph = new TransferLearning.GraphBuilder(pretrained)
|
||
|
.fineTuneConfiguration(new FineTuneConfiguration.Builder()
|
||
|
.updater(new Adam(1e-3))
|
||
|
.seed(12345)
|
||
|
.build())
|
||
|
.removeVertexKeepConnections("predictions")
|
||
|
.addLayer("predictions", new OutputLayer.Builder()
|
||
|
.nIn(4096)
|
||
|
.nOut(200) //Tiny imagenet
|
||
|
.build(), "fc2")
|
||
|
.build();
|
||
|
|
||
|
return newGraph;
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public List<Pair<INDArray[], INDArray[]>> getPredictionsTestData() throws Exception {
|
||
|
List<Pair<INDArray[], INDArray[]>> out = new ArrayList<>();
|
||
|
|
||
|
DataSetIterator iter = new TinyImageNetDataSetIterator(1, new int[]{224, 224}, DataSetType.TRAIN, null, 12345);
|
||
|
iter.setPreProcessor(new VGG16ImagePreProcessor());
|
||
|
DataSet ds = iter.next();
|
||
|
out.add(new Pair<>(new INDArray[]{ds.getFeatures()}, null));
|
||
|
|
||
|
iter = new TinyImageNetDataSetIterator(3, new int[]{224, 224}, DataSetType.TRAIN, null, 54321);
|
||
|
iter.setPreProcessor(new VGG16ImagePreProcessor());
|
||
|
ds = iter.next();
|
||
|
out.add(new Pair<>(new INDArray[]{ds.getFeatures()}, null));
|
||
|
|
||
|
return out;
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public MultiDataSet getGradientsTestData() throws Exception {
|
||
|
DataSet ds = new TinyImageNetDataSetIterator(8, new int[]{224, 224}, DataSetType.TRAIN, null, 12345).next();
|
||
|
return new org.nd4j.linalg.dataset.MultiDataSet(ds.getFeatures(), ds.getLabels());
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public MultiDataSetIterator getTrainingData() throws Exception {
|
||
|
DataSetIterator iter = new TinyImageNetDataSetIterator(4, new int[]{224, 224}, DataSetType.TRAIN, null, 12345);
|
||
|
iter.setPreProcessor(new VGG16ImagePreProcessor());
|
||
|
|
||
|
iter = new EarlyTerminationDataSetIterator(iter, 2);
|
||
|
return new MultiDataSetIteratorAdapter(iter);
|
||
|
}
|
||
|
};
|
||
|
}
|
||
|
|
||
|
|
||
|
/**
|
||
|
* Basically a cut-down version of the YOLO house numbers example
|
||
|
*/
|
||
|
public static TestCase getYoloHouseNumbers() {
|
||
|
return new TestCase() {
|
||
|
|
||
|
private int width = 416;
|
||
|
private int height = 416;
|
||
|
private int nChannels = 3;
|
||
|
private int gridWidth = 13;
|
||
|
private int gridHeight = 13;
|
||
|
|
||
|
{
|
||
|
testName = "YOLOHouseNumbers";
|
||
|
testType = TestType.PRETRAINED;
|
||
|
testPredictions = true;
|
||
|
testTrainingCurves = true;
|
||
|
testGradients = false; //Skip - requires saving approx 1GB of data (gradients x2)
|
||
|
testParamsPostTraining = false; //Skip - requires saving all params (approx 500mb)
|
||
|
testEvaluation = false;
|
||
|
testOverfitting = false;
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public Model getPretrainedModel() throws Exception {
|
||
|
int nClasses = 10;
|
||
|
int nBoxes = 5;
|
||
|
double lambdaNoObj = 0.5;
|
||
|
double lambdaCoord = 1.0;
|
||
|
double[][] priorBoxes = {{2, 5}, {2.5, 6}, {3, 7}, {3.5, 8}, {4, 9}};
|
||
|
double learningRate = 1e-4;
|
||
|
ComputationGraph pretrained = (ComputationGraph) TinyYOLO.builder().build().initPretrained();
|
||
|
INDArray priors = Nd4j.create(priorBoxes);
|
||
|
|
||
|
FineTuneConfiguration fineTuneConf = new FineTuneConfiguration.Builder()
|
||
|
.seed(12345)
|
||
|
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
|
||
|
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
|
||
|
.gradientNormalizationThreshold(1.0)
|
||
|
.updater(new Adam(learningRate))
|
||
|
.l2(0.00001)
|
||
|
.activation(Activation.IDENTITY)
|
||
|
.trainingWorkspaceMode(WorkspaceMode.ENABLED)
|
||
|
.inferenceWorkspaceMode(WorkspaceMode.ENABLED)
|
||
|
.build();
|
||
|
|
||
|
ComputationGraph model = new TransferLearning.GraphBuilder(pretrained)
|
||
|
.fineTuneConfiguration(fineTuneConf)
|
||
|
.removeVertexKeepConnections("conv2d_9")
|
||
|
.addLayer("convolution2d_9",
|
||
|
new ConvolutionLayer.Builder(1,1)
|
||
|
.nIn(1024)
|
||
|
.nOut(nBoxes * (5 + nClasses))
|
||
|
.stride(1,1)
|
||
|
.convolutionMode(ConvolutionMode.Same)
|
||
|
.weightInit(WeightInit.XAVIER)
|
||
|
.activation(Activation.IDENTITY)
|
||
|
.build(),
|
||
|
"leaky_re_lu_8")
|
||
|
.addLayer("outputs",
|
||
|
new Yolo2OutputLayer.Builder()
|
||
|
.lambdaNoObj(lambdaNoObj)
|
||
|
.lambdaCoord(lambdaCoord)
|
||
|
.boundingBoxPriors(priors)
|
||
|
.build(),
|
||
|
"convolution2d_9")
|
||
|
.setOutputs("outputs")
|
||
|
.build();
|
||
|
|
||
|
return model;
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public List<Pair<INDArray[], INDArray[]>> getPredictionsTestData() throws Exception {
|
||
|
MultiDataSet mds = getTrainingData().next();
|
||
|
return Collections.singletonList(new Pair<>(mds.getFeatures(), null));
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public MultiDataSet getGradientsTestData() throws Exception {
|
||
|
return getTrainingData().next();
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public MultiDataSetIterator getTrainingData() throws Exception {
|
||
|
SvhnDataFetcher fetcher = new SvhnDataFetcher();
|
||
|
File testDir = fetcher.getDataSetPath(DataSetType.TEST);
|
||
|
|
||
|
FileSplit testData = new FileSplit(testDir, NativeImageLoader.ALLOWED_FORMATS, new Random(12345));
|
||
|
ObjectDetectionRecordReader recordReaderTest = new ObjectDetectionRecordReader(height, width, nChannels,
|
||
|
gridHeight, gridWidth, new SvhnLabelProvider(testDir));
|
||
|
recordReaderTest.initialize(testData);
|
||
|
RecordReaderDataSetIterator test = new RecordReaderDataSetIterator(recordReaderTest, 2, 1, 1, true);
|
||
|
test.setPreProcessor(new ImagePreProcessingScaler(0, 1));
|
||
|
|
||
|
return new MultiDataSetIteratorAdapter(new EarlyTerminationDataSetIterator(test, 2));
|
||
|
}
|
||
|
};
|
||
|
}
|
||
|
|
||
|
|
||
|
/**
|
||
|
* A synthetic 2D CNN that uses all layers:
|
||
|
* Convolution, Subsampling, Upsampling, Cropping, Depthwise conv, separable conv, deconv, space to batch,
|
||
|
* space to depth, zero padding, batch norm, LRN
|
||
|
*/
|
||
|
public static TestCase getCnn2DSynthetic() {
|
||
|
|
||
|
throw new UnsupportedOperationException("Not yet implemented");
|
||
|
}
|
||
|
|
||
|
|
||
|
public static TestCase testLenetTransferDropoutRepeatability() {
|
||
|
return new TestCase() {
|
||
|
|
||
|
{
|
||
|
testName = "LenetDropoutRepeatability";
|
||
|
testType = TestType.PRETRAINED;
|
||
|
testPredictions = true;
|
||
|
testTrainingCurves = true;
|
||
|
testGradients = true;
|
||
|
testParamsPostTraining = true;
|
||
|
testEvaluation = true;
|
||
|
testOverfitting = true;
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public Model getPretrainedModel() throws Exception {
|
||
|
|
||
|
Map<Integer, Double> lrSchedule = new HashMap<>();
|
||
|
lrSchedule.put(0, 0.01);
|
||
|
lrSchedule.put(1000, 0.005);
|
||
|
lrSchedule.put(3000, 0.001);
|
||
|
|
||
|
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
|
||
|
.seed(12345)
|
||
|
.l2(0.0005)
|
||
|
.weightInit(WeightInit.XAVIER)
|
||
|
.updater(new Nesterovs(0.01, 0.9))
|
||
|
.list()
|
||
|
.layer(0, new ConvolutionLayer.Builder(5, 5)
|
||
|
//nIn and nOut specify depth. nIn here is the nChannels and nOut is the number of filters to be applied
|
||
|
.nIn(1)
|
||
|
.stride(1, 1)
|
||
|
.nOut(20)
|
||
|
.activation(Activation.IDENTITY)
|
||
|
.build())
|
||
|
.layer(1, new SubsamplingLayer.Builder(PoolingType.MAX)
|
||
|
.kernelSize(2, 2)
|
||
|
.stride(2, 2)
|
||
|
.build())
|
||
|
.layer(2, new ConvolutionLayer.Builder(5, 5)
|
||
|
//Note that nIn need not be specified in later layers
|
||
|
.stride(1, 1)
|
||
|
.nOut(50)
|
||
|
.activation(Activation.IDENTITY)
|
||
|
.dropOut(0.5) //**** Dropout on conv layer
|
||
|
.build())
|
||
|
.layer(3, new SubsamplingLayer.Builder(PoolingType.MAX)
|
||
|
.kernelSize(2, 2)
|
||
|
.stride(2, 2)
|
||
|
.build())
|
||
|
.layer(4, new DenseLayer.Builder().activation(Activation.RELU)
|
||
|
.dropOut(0.5) //**** Dropout on dense layer
|
||
|
.nOut(500).build())
|
||
|
.layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
|
||
|
.nOut(10)
|
||
|
.activation(Activation.SOFTMAX)
|
||
|
.build())
|
||
|
.setInputType(InputType.convolutionalFlat(28, 28, 1)) //See note below
|
||
|
.build();
|
||
|
|
||
|
|
||
|
MultiLayerNetwork net = new MultiLayerNetwork(conf);
|
||
|
net.init();
|
||
|
|
||
|
DataSetIterator iter = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(16, true, 12345), 10);
|
||
|
net.fit(iter);
|
||
|
|
||
|
MultiLayerNetwork pretrained = new TransferLearning.Builder(net)
|
||
|
.fineTuneConfiguration(
|
||
|
FineTuneConfiguration.builder()
|
||
|
.updater(new Nesterovs(0.01, 0.9))
|
||
|
.seed(98765)
|
||
|
.build())
|
||
|
.nOutReplace(5, 10, WeightInit.XAVIER)
|
||
|
.build();
|
||
|
|
||
|
return pretrained;
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public List<Pair<INDArray[], INDArray[]>> getPredictionsTestData() throws Exception {
|
||
|
MnistDataSetIterator iter = new MnistDataSetIterator(1, true, 12345);
|
||
|
List<Pair<INDArray[], INDArray[]>> out = new ArrayList<>();
|
||
|
out.add(new Pair<>(new INDArray[]{iter.next().getFeatures()}, null));
|
||
|
|
||
|
iter = new MnistDataSetIterator(10, true, 12345);
|
||
|
out.add(new Pair<>(new INDArray[]{iter.next().getFeatures()}, null));
|
||
|
return out;
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public MultiDataSet getGradientsTestData() throws Exception {
|
||
|
DataSet ds = new MnistDataSetIterator(10, true, 12345).next();
|
||
|
return new org.nd4j.linalg.dataset.MultiDataSet(ds.getFeatures(), ds.getLabels());
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public MultiDataSetIterator getTrainingData() throws Exception {
|
||
|
DataSetIterator iter = new MnistDataSetIterator(16, true, 12345);
|
||
|
iter = new EarlyTerminationDataSetIterator(iter, 32);
|
||
|
return new MultiDataSetIteratorAdapter(iter);
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public IEvaluation[] getNewEvaluations() {
|
||
|
return new IEvaluation[]{
|
||
|
new Evaluation(),
|
||
|
new ROCMultiClass(),
|
||
|
new EvaluationCalibration()
|
||
|
};
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public MultiDataSetIterator getEvaluationTestData() throws Exception {
|
||
|
DataSetIterator iter = new MnistDataSetIterator(16, true, 12345);
|
||
|
iter = new EarlyTerminationDataSetIterator(iter, 10);
|
||
|
return new MultiDataSetIteratorAdapter(iter);
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public MultiDataSet getOverfittingData() throws Exception {
|
||
|
DataSet ds = new MnistDataSetIterator(1, true, 12345).next();
|
||
|
return ComputationGraphUtil.toMultiDataSet(ds);
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public int getOverfitNumIterations() {
|
||
|
return 200;
|
||
|
}
|
||
|
};
|
||
|
}
|
||
|
}
|