368 lines
14 KiB
Java
Raw Normal View History

2021-02-01 14:31:20 +09:00
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
2021-02-01 17:47:29 +09:00
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
2021-02-01 14:31:20 +09:00
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
2019-06-06 15:21:15 +03:00
package org.deeplearning4j.datasets.iterator;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.CifarLoader;
import org.datavec.image.loader.LFWLoader;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.fetchers.DataSetType;
import org.deeplearning4j.datasets.iterator.impl.*;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.CollectScoresIterationListener;
2019-06-06 15:21:15 +03:00
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
Refactor packages to fix split package issues (#411) * Refactor nd4j-common: org.nd4j.* -> org.nd4j.common.* Signed-off-by: Alex Black <blacka101@gmail.com> * Fix CUDA (missed nd4j-common package refactoring changes) Signed-off-by: Alex Black <blacka101@gmail.com> * nd4j-kryo: org.nd4j -> org.nd4j.kryo Signed-off-by: Alex Black <blacka101@gmail.com> * Fix nd4j-common for deeplearning4j-cuda Signed-off-by: Alex Black <blacka101@gmail.com> * nd4j-grppc-client: org.nd4j.graph -> org.nd4j.remote.grpc Signed-off-by: Alex Black <blacka101@gmail.com> * deeplearning4j-common: org.deeplearning4.* -> org.deeplearning4j.common.* Signed-off-by: Alex Black <blacka101@gmail.com> * deeplearning4j-core: org.deeplearning4j.* -> org.deeplearning.core.* Signed-off-by: Alex Black <blacka101@gmail.com> * deeplearning4j-cuda: org.deeplearning4j.nn.layers.* -> org.deeplearning4j.cuda.* Signed-off-by: Alex Black <blacka101@gmail.com> * Import fixes Signed-off-by: Alex Black <blacka101@gmail.com> * deeplearning4j-nlp-*: org.deeplearning4.text.* -> org.deeplearning4j.nlp.(language).* Signed-off-by: Alex Black <blacka101@gmail.com> * deeplearning4j-ui-model: org.deeplearning4j.ui -> org.deeplearning4j.ui.model Signed-off-by: Alex Black <blacka101@gmail.com> * datavec-spark-inference-{server/model/client}: org.datavec.spark.transform -> org.datavec.spark.inference.{server/model/client} Signed-off-by: Alex Black <blacka101@gmail.com> * datavec-jdbc: org.datavec.api -> org.datavec.jdbc Signed-off-by: Alex Black <blacka101@gmail.com> * Delete org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter in favor of (essentially identical) org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter Signed-off-by: Alex Black <blacka101@gmail.com> * ND4S fixes Signed-off-by: Alex Black <blacka101@gmail.com> * Fixes Signed-off-by: Alex Black <blacka101@gmail.com> * nd4j-common-tests: org.nd4j.* -> org.nd4j.common.tests Signed-off-by: Alex Black <blacka101@gmail.com> * Trigger CI Signed-off-by: Alex Black <blacka101@gmail.com> * Fixes Signed-off-by: Alex Black <blacka101@gmail.com> * #8878 Ignore CUDA tests on modules with 'nd4j-native under cuda' issue Signed-off-by: Alex Black <blacka101@gmail.com> * Fix bad imports in tests Signed-off-by: Alex Black <blacka101@gmail.com> * Add ignore on test (already failing) due to #8882 Signed-off-by: Alex Black <blacka101@gmail.com> * Import fixes Signed-off-by: Alex Black <blacka101@gmail.com> * Additional import fixes Signed-off-by: Alex Black <blacka101@gmail.com>
2020-04-29 11:19:26 +10:00
import org.nd4j.common.io.ClassPathResource;
2019-06-06 15:21:15 +03:00
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import static org.junit.Assert.*;
public class DataSetIteratorTest extends BaseDL4JTest {
Various fixes (#143) * #8568 ArrayUtil optimization Signed-off-by: AlexDBlack <blacka101@gmail.com> * #6171 Keras ReLU and ELU support Signed-off-by: AlexDBlack <blacka101@gmail.com> * Keras softmax layer import Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8549 Webjars dependency management Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix for TF import names ':0' suffix issue / NPE Signed-off-by: AlexDBlack <blacka101@gmail.com> * BiasAdd: fix default data format for TF import Signed-off-by: AlexDBlack <blacka101@gmail.com> * Update zoo test ignores Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8509 SameDiff Listener API - provide frame + iteration Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8520 ND4J Environment Signed-off-by: AlexDBlack <blacka101@gmail.com> * Deconv3d Signed-off-by: AlexDBlack <blacka101@gmail.com> * Deconv3d fixes + gradient check Signed-off-by: AlexDBlack <blacka101@gmail.com> * Conv3d fixes + deconv3d DType test Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix issue with deconv3d gradinet check weight init Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8579 Fix BaseCudaDataBuffer constructor fix for UINT16 Signed-off-by: AlexDBlack <blacka101@gmail.com> * DataType.isNumerical() returns false for BOOL type Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8504 Reduce Spark log spam for tests Signed-off-by: AlexDBlack <blacka101@gmail.com> * Clean up DL4J gradient check test spam Signed-off-by: AlexDBlack <blacka101@gmail.com> * More Gradient check spam reduction Signed-off-by: AlexDBlack <blacka101@gmail.com> * SameDiff test spam reduction Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fixes for FlatBuffers mapping Signed-off-by: AlexDBlack <blacka101@gmail.com> * SameDiff log spam cleanup Signed-off-by: AlexDBlack <blacka101@gmail.com> * Tests should extend BaseNd4jTest Signed-off-by: AlexDBlack <blacka101@gmail.com> * Remove debug line in c++ op Signed-off-by: AlexDBlack <blacka101@gmail.com> * ND4J test spam cleanup Signed-off-by: AlexDBlack <blacka101@gmail.com> * DL4J test spam reduction Signed-off-by: AlexDBlack <blacka101@gmail.com> * More Dl4J and datavec test spam cleanup Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix for bad conv3d test Signed-off-by: AlexDBlack <blacka101@gmail.com> * Additional test Signed-off-by: AlexDBlack <blacka101@gmail.com> * Embedding layers: don't inherit global default activation function Signed-off-by: AlexDBlack <blacka101@gmail.com> * Trigger CI Signed-off-by: AlexDBlack <blacka101@gmail.com> * Consolidate all BaseDL4JTest classes to single class used everywhere; make timeout configurable per class Signed-off-by: AlexDBlack <blacka101@gmail.com> * Test fixes and timeout increases Signed-off-by: AlexDBlack <blacka101@gmail.com> * Timeouts and PReLU fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * Restore libnd4j build threads arg for CUDA build Signed-off-by: AlexDBlack <blacka101@gmail.com> * Increase timeouts on a few tests to avoid spurious failures on some CI machines Signed-off-by: AlexDBlack <blacka101@gmail.com> * More timeout fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * More test timeout fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * Tweak timeout for one more test Signed-off-by: AlexDBlack <blacka101@gmail.com> * Final tweaks Signed-off-by: AlexDBlack <blacka101@gmail.com> * One more ignore Signed-off-by: AlexDBlack <blacka101@gmail.com>
2020-01-04 13:45:07 +11:00
@Override
public long getTimeoutMilliseconds() {
return 360000; //Should run quickly; increased to large timeout due to occasonal slow CI downloads
Various fixes (#143) * #8568 ArrayUtil optimization Signed-off-by: AlexDBlack <blacka101@gmail.com> * #6171 Keras ReLU and ELU support Signed-off-by: AlexDBlack <blacka101@gmail.com> * Keras softmax layer import Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8549 Webjars dependency management Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix for TF import names ':0' suffix issue / NPE Signed-off-by: AlexDBlack <blacka101@gmail.com> * BiasAdd: fix default data format for TF import Signed-off-by: AlexDBlack <blacka101@gmail.com> * Update zoo test ignores Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8509 SameDiff Listener API - provide frame + iteration Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8520 ND4J Environment Signed-off-by: AlexDBlack <blacka101@gmail.com> * Deconv3d Signed-off-by: AlexDBlack <blacka101@gmail.com> * Deconv3d fixes + gradient check Signed-off-by: AlexDBlack <blacka101@gmail.com> * Conv3d fixes + deconv3d DType test Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix issue with deconv3d gradinet check weight init Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8579 Fix BaseCudaDataBuffer constructor fix for UINT16 Signed-off-by: AlexDBlack <blacka101@gmail.com> * DataType.isNumerical() returns false for BOOL type Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8504 Reduce Spark log spam for tests Signed-off-by: AlexDBlack <blacka101@gmail.com> * Clean up DL4J gradient check test spam Signed-off-by: AlexDBlack <blacka101@gmail.com> * More Gradient check spam reduction Signed-off-by: AlexDBlack <blacka101@gmail.com> * SameDiff test spam reduction Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fixes for FlatBuffers mapping Signed-off-by: AlexDBlack <blacka101@gmail.com> * SameDiff log spam cleanup Signed-off-by: AlexDBlack <blacka101@gmail.com> * Tests should extend BaseNd4jTest Signed-off-by: AlexDBlack <blacka101@gmail.com> * Remove debug line in c++ op Signed-off-by: AlexDBlack <blacka101@gmail.com> * ND4J test spam cleanup Signed-off-by: AlexDBlack <blacka101@gmail.com> * DL4J test spam reduction Signed-off-by: AlexDBlack <blacka101@gmail.com> * More Dl4J and datavec test spam cleanup Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix for bad conv3d test Signed-off-by: AlexDBlack <blacka101@gmail.com> * Additional test Signed-off-by: AlexDBlack <blacka101@gmail.com> * Embedding layers: don't inherit global default activation function Signed-off-by: AlexDBlack <blacka101@gmail.com> * Trigger CI Signed-off-by: AlexDBlack <blacka101@gmail.com> * Consolidate all BaseDL4JTest classes to single class used everywhere; make timeout configurable per class Signed-off-by: AlexDBlack <blacka101@gmail.com> * Test fixes and timeout increases Signed-off-by: AlexDBlack <blacka101@gmail.com> * Timeouts and PReLU fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * Restore libnd4j build threads arg for CUDA build Signed-off-by: AlexDBlack <blacka101@gmail.com> * Increase timeouts on a few tests to avoid spurious failures on some CI machines Signed-off-by: AlexDBlack <blacka101@gmail.com> * More timeout fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * More test timeout fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * Tweak timeout for one more test Signed-off-by: AlexDBlack <blacka101@gmail.com> * Final tweaks Signed-off-by: AlexDBlack <blacka101@gmail.com> * One more ignore Signed-off-by: AlexDBlack <blacka101@gmail.com>
2020-01-04 13:45:07 +11:00
}
2019-06-06 15:21:15 +03:00
@Test
public void testBatchSizeOfOneIris() throws Exception {
//Test for (a) iterators returning correct number of examples, and
//(b) Labels are a proper one-hot vector (i.e., sum is 1.0)
//Iris:
DataSetIterator iris = new IrisDataSetIterator(1, 5);
int irisC = 0;
while (iris.hasNext()) {
irisC++;
DataSet ds = iris.next();
assertTrue(ds.getLabels().sum(Integer.MAX_VALUE).getDouble(0) == 1.0);
}
assertEquals(5, irisC);
}
@Test
public void testBatchSizeOfOneMnist() throws Exception {
//MNIST:
DataSetIterator mnist = new MnistDataSetIterator(1, 5);
int mnistC = 0;
while (mnist.hasNext()) {
mnistC++;
DataSet ds = mnist.next();
assertTrue(ds.getLabels().sum(Integer.MAX_VALUE).getDouble(0) == 1.0);
}
assertEquals(5, mnistC);
}
@Test
public void testMnist() throws Exception {
ClassPathResource cpr = new ClassPathResource("mnist_first_200.txt");
CSVRecordReader rr = new CSVRecordReader(0, ',');
rr.initialize(new FileSplit(cpr.getTempFileFromArchive()));
RecordReaderDataSetIterator dsi = new RecordReaderDataSetIterator(rr, 10, 0, 10);
MnistDataSetIterator iter = new MnistDataSetIterator(10, 200, false, true, false, 0);
while (dsi.hasNext()) {
DataSet dsExp = dsi.next();
DataSet dsAct = iter.next();
INDArray fExp = dsExp.getFeatures();
fExp.divi(255);
INDArray lExp = dsExp.getLabels();
INDArray fAct = dsAct.getFeatures();
INDArray lAct = dsAct.getLabels();
assertEquals(fExp, fAct.castTo(fExp.dataType()));
assertEquals(lExp, lAct.castTo(lExp.dataType()));
}
assertFalse(iter.hasNext());
}
@Test
public void testLfwIterator() throws Exception {
int numExamples = 1;
int row = 28;
int col = 28;
int channels = 1;
LFWDataSetIterator iter = new LFWDataSetIterator(numExamples, new int[] {row, col, channels}, true);
assertTrue(iter.hasNext());
DataSet data = iter.next();
assertEquals(numExamples, data.getLabels().size(0));
assertEquals(row, data.getFeatures().size(2));
}
@Test
public void testTinyImageNetIterator() throws Exception {
int numClasses = 200;
int row = 64;
int col = 64;
int channels = 3;
TinyImageNetDataSetIterator iter = new TinyImageNetDataSetIterator(1, DataSetType.TEST);
assertTrue(iter.hasNext());
DataSet data = iter.next();
assertEquals(numClasses, data.getLabels().size(1));
assertArrayEquals(new long[]{1, channels, row, col}, data.getFeatures().shape());
}
@Test
public void testTinyImageNetIterator2() throws Exception {
int numClasses = 200;
int row = 224;
int col = 224;
int channels = 3;
TinyImageNetDataSetIterator iter = new TinyImageNetDataSetIterator(1, new int[]{row, col}, DataSetType.TEST);
assertTrue(iter.hasNext());
DataSet data = iter.next();
assertEquals(numClasses, data.getLabels().size(1));
assertArrayEquals(new long[]{1, channels, row, col}, data.getFeatures().shape());
}
@Test
public void testLfwModel() throws Exception {
final int numRows = 28;
final int numColumns = 28;
int numChannels = 3;
int outputNum = LFWLoader.NUM_LABELS;
int numSamples = LFWLoader.NUM_IMAGES;
int batchSize = 2;
int seed = 123;
int listenerFreq = 1;
LFWDataSetIterator lfw = new LFWDataSetIterator(batchSize, numSamples,
new int[] {numRows, numColumns, numChannels}, outputNum, false, true, 1.0, new Random(seed));
MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed)
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list()
.layer(0, new ConvolutionLayer.Builder(5, 5).nIn(numChannels).nOut(6)
.weightInit(WeightInit.XAVIER).activation(Activation.RELU).build())
.layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] {2, 2})
.stride(1, 1).build())
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.nOut(outputNum).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX)
.build())
.setInputType(InputType.convolutionalFlat(numRows, numColumns, numChannels))
;
MultiLayerNetwork model = new MultiLayerNetwork(builder.build());
model.init();
model.setListeners(new ScoreIterationListener(listenerFreq));
model.fit(lfw.next());
DataSet dataTest = lfw.next();
INDArray output = model.output(dataTest.getFeatures());
Evaluation eval = new Evaluation(outputNum);
eval.eval(dataTest.getLabels(), output);
Various fixes (#143) * #8568 ArrayUtil optimization Signed-off-by: AlexDBlack <blacka101@gmail.com> * #6171 Keras ReLU and ELU support Signed-off-by: AlexDBlack <blacka101@gmail.com> * Keras softmax layer import Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8549 Webjars dependency management Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix for TF import names ':0' suffix issue / NPE Signed-off-by: AlexDBlack <blacka101@gmail.com> * BiasAdd: fix default data format for TF import Signed-off-by: AlexDBlack <blacka101@gmail.com> * Update zoo test ignores Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8509 SameDiff Listener API - provide frame + iteration Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8520 ND4J Environment Signed-off-by: AlexDBlack <blacka101@gmail.com> * Deconv3d Signed-off-by: AlexDBlack <blacka101@gmail.com> * Deconv3d fixes + gradient check Signed-off-by: AlexDBlack <blacka101@gmail.com> * Conv3d fixes + deconv3d DType test Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix issue with deconv3d gradinet check weight init Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8579 Fix BaseCudaDataBuffer constructor fix for UINT16 Signed-off-by: AlexDBlack <blacka101@gmail.com> * DataType.isNumerical() returns false for BOOL type Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8504 Reduce Spark log spam for tests Signed-off-by: AlexDBlack <blacka101@gmail.com> * Clean up DL4J gradient check test spam Signed-off-by: AlexDBlack <blacka101@gmail.com> * More Gradient check spam reduction Signed-off-by: AlexDBlack <blacka101@gmail.com> * SameDiff test spam reduction Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fixes for FlatBuffers mapping Signed-off-by: AlexDBlack <blacka101@gmail.com> * SameDiff log spam cleanup Signed-off-by: AlexDBlack <blacka101@gmail.com> * Tests should extend BaseNd4jTest Signed-off-by: AlexDBlack <blacka101@gmail.com> * Remove debug line in c++ op Signed-off-by: AlexDBlack <blacka101@gmail.com> * ND4J test spam cleanup Signed-off-by: AlexDBlack <blacka101@gmail.com> * DL4J test spam reduction Signed-off-by: AlexDBlack <blacka101@gmail.com> * More Dl4J and datavec test spam cleanup Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix for bad conv3d test Signed-off-by: AlexDBlack <blacka101@gmail.com> * Additional test Signed-off-by: AlexDBlack <blacka101@gmail.com> * Embedding layers: don't inherit global default activation function Signed-off-by: AlexDBlack <blacka101@gmail.com> * Trigger CI Signed-off-by: AlexDBlack <blacka101@gmail.com> * Consolidate all BaseDL4JTest classes to single class used everywhere; make timeout configurable per class Signed-off-by: AlexDBlack <blacka101@gmail.com> * Test fixes and timeout increases Signed-off-by: AlexDBlack <blacka101@gmail.com> * Timeouts and PReLU fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * Restore libnd4j build threads arg for CUDA build Signed-off-by: AlexDBlack <blacka101@gmail.com> * Increase timeouts on a few tests to avoid spurious failures on some CI machines Signed-off-by: AlexDBlack <blacka101@gmail.com> * More timeout fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * More test timeout fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * Tweak timeout for one more test Signed-off-by: AlexDBlack <blacka101@gmail.com> * Final tweaks Signed-off-by: AlexDBlack <blacka101@gmail.com> * One more ignore Signed-off-by: AlexDBlack <blacka101@gmail.com>
2020-01-04 13:45:07 +11:00
// System.out.println(eval.stats());
2019-06-06 15:21:15 +03:00
}
@Test
public void testCifar10Iterator() throws Exception {
int numExamples = 1;
int row = 32;
int col = 32;
int channels = 3;
Cifar10DataSetIterator iter = new Cifar10DataSetIterator(numExamples);
assertTrue(iter.hasNext());
DataSet data = iter.next();
assertEquals(numExamples, data.getLabels().size(0));
assertEquals(channels * row * col, data.getFeatures().ravel().length());
}
@Test @Ignore //Ignored for now - CIFAR iterator needs work - https://github.com/deeplearning4j/deeplearning4j/issues/4673
public void testCifarModel() throws Exception {
// Streaming
runCifar(false);
// Preprocess
runCifar(true);
}
public void runCifar(boolean preProcessCifar) throws Exception {
final int height = 32;
final int width = 32;
int channels = 3;
int outputNum = CifarLoader.NUM_LABELS;
int batchSize = 5;
int seed = 123;
int listenerFreq = 1;
Cifar10DataSetIterator cifar = new Cifar10DataSetIterator(batchSize);
MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed)
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list()
.layer(0, new ConvolutionLayer.Builder(5, 5).nIn(channels).nOut(6).weightInit(WeightInit.XAVIER)
.activation(Activation.RELU).build())
.layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] {2, 2})
.build())
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.nOut(outputNum).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX)
.build())
.setInputType(InputType.convolutionalFlat(height, width, channels));
MultiLayerNetwork model = new MultiLayerNetwork(builder.build());
model.init();
//model.setListeners(Arrays.asList((TrainingListener) new ScoreIterationListener(listenerFreq)));
CollectScoresIterationListener listener = new CollectScoresIterationListener(listenerFreq);
model.setListeners(listener);
2019-06-06 15:21:15 +03:00
model.fit(cifar);
cifar = new Cifar10DataSetIterator(batchSize);
Evaluation eval = new Evaluation(cifar.getLabels());
while (cifar.hasNext()) {
DataSet testDS = cifar.next(batchSize);
INDArray output = model.output(testDS.getFeatures());
eval.eval(testDS.getLabels(), output);
}
Various fixes (#143) * #8568 ArrayUtil optimization Signed-off-by: AlexDBlack <blacka101@gmail.com> * #6171 Keras ReLU and ELU support Signed-off-by: AlexDBlack <blacka101@gmail.com> * Keras softmax layer import Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8549 Webjars dependency management Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix for TF import names ':0' suffix issue / NPE Signed-off-by: AlexDBlack <blacka101@gmail.com> * BiasAdd: fix default data format for TF import Signed-off-by: AlexDBlack <blacka101@gmail.com> * Update zoo test ignores Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8509 SameDiff Listener API - provide frame + iteration Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8520 ND4J Environment Signed-off-by: AlexDBlack <blacka101@gmail.com> * Deconv3d Signed-off-by: AlexDBlack <blacka101@gmail.com> * Deconv3d fixes + gradient check Signed-off-by: AlexDBlack <blacka101@gmail.com> * Conv3d fixes + deconv3d DType test Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix issue with deconv3d gradinet check weight init Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8579 Fix BaseCudaDataBuffer constructor fix for UINT16 Signed-off-by: AlexDBlack <blacka101@gmail.com> * DataType.isNumerical() returns false for BOOL type Signed-off-by: AlexDBlack <blacka101@gmail.com> * #8504 Reduce Spark log spam for tests Signed-off-by: AlexDBlack <blacka101@gmail.com> * Clean up DL4J gradient check test spam Signed-off-by: AlexDBlack <blacka101@gmail.com> * More Gradient check spam reduction Signed-off-by: AlexDBlack <blacka101@gmail.com> * SameDiff test spam reduction Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fixes for FlatBuffers mapping Signed-off-by: AlexDBlack <blacka101@gmail.com> * SameDiff log spam cleanup Signed-off-by: AlexDBlack <blacka101@gmail.com> * Tests should extend BaseNd4jTest Signed-off-by: AlexDBlack <blacka101@gmail.com> * Remove debug line in c++ op Signed-off-by: AlexDBlack <blacka101@gmail.com> * ND4J test spam cleanup Signed-off-by: AlexDBlack <blacka101@gmail.com> * DL4J test spam reduction Signed-off-by: AlexDBlack <blacka101@gmail.com> * More Dl4J and datavec test spam cleanup Signed-off-by: AlexDBlack <blacka101@gmail.com> * Fix for bad conv3d test Signed-off-by: AlexDBlack <blacka101@gmail.com> * Additional test Signed-off-by: AlexDBlack <blacka101@gmail.com> * Embedding layers: don't inherit global default activation function Signed-off-by: AlexDBlack <blacka101@gmail.com> * Trigger CI Signed-off-by: AlexDBlack <blacka101@gmail.com> * Consolidate all BaseDL4JTest classes to single class used everywhere; make timeout configurable per class Signed-off-by: AlexDBlack <blacka101@gmail.com> * Test fixes and timeout increases Signed-off-by: AlexDBlack <blacka101@gmail.com> * Timeouts and PReLU fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * Restore libnd4j build threads arg for CUDA build Signed-off-by: AlexDBlack <blacka101@gmail.com> * Increase timeouts on a few tests to avoid spurious failures on some CI machines Signed-off-by: AlexDBlack <blacka101@gmail.com> * More timeout fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * More test timeout fixes Signed-off-by: AlexDBlack <blacka101@gmail.com> * Tweak timeout for one more test Signed-off-by: AlexDBlack <blacka101@gmail.com> * Final tweaks Signed-off-by: AlexDBlack <blacka101@gmail.com> * One more ignore Signed-off-by: AlexDBlack <blacka101@gmail.com>
2020-01-04 13:45:07 +11:00
// System.out.println(eval.stats(true));
listener.exportScores(System.out);
2019-06-06 15:21:15 +03:00
}
@Test
public void testIteratorDataSetIteratorCombining() {
//Test combining of a bunch of small (size 1) data sets together
int batchSize = 3;
int numBatches = 4;
int featureSize = 5;
int labelSize = 6;
Nd4j.getRandom().setSeed(12345);
List<DataSet> orig = new ArrayList<>();
for (int i = 0; i < batchSize * numBatches; i++) {
INDArray features = Nd4j.rand(1, featureSize);
INDArray labels = Nd4j.rand(1, labelSize);
orig.add(new DataSet(features, labels));
}
DataSetIterator iter = new IteratorDataSetIterator(orig.iterator(), batchSize);
int count = 0;
while (iter.hasNext()) {
DataSet ds = iter.next();
assertArrayEquals(new long[] {batchSize, featureSize}, ds.getFeatures().shape());
assertArrayEquals(new long[] {batchSize, labelSize}, ds.getLabels().shape());
List<INDArray> fList = new ArrayList<>();
List<INDArray> lList = new ArrayList<>();
for (int i = 0; i < batchSize; i++) {
DataSet dsOrig = orig.get(count * batchSize + i);
fList.add(dsOrig.getFeatures());
lList.add(dsOrig.getLabels());
}
INDArray fExp = Nd4j.vstack(fList);
INDArray lExp = Nd4j.vstack(lList);
assertEquals(fExp, ds.getFeatures());
assertEquals(lExp, ds.getLabels());
count++;
}
assertEquals(count, numBatches);
}
@Test
public void testIteratorDataSetIteratorSplitting() {
//Test splitting large data sets into smaller ones
int origBatchSize = 4;
int origNumDSs = 3;
int batchSize = 3;
int numBatches = 4;
int featureSize = 5;
int labelSize = 6;
Nd4j.getRandom().setSeed(12345);
List<DataSet> orig = new ArrayList<>();
for (int i = 0; i < origNumDSs; i++) {
INDArray features = Nd4j.rand(origBatchSize, featureSize);
INDArray labels = Nd4j.rand(origBatchSize, labelSize);
orig.add(new DataSet(features, labels));
}
List<DataSet> expected = new ArrayList<>();
expected.add(new DataSet(orig.get(0).getFeatures().getRows(0, 1, 2),
orig.get(0).getLabels().getRows(0, 1, 2)));
expected.add(new DataSet(
Nd4j.vstack(orig.get(0).getFeatures().getRows(3),
orig.get(1).getFeatures().getRows(0, 1)),
Nd4j.vstack(orig.get(0).getLabels().getRows(3), orig.get(1).getLabels().getRows(0, 1))));
expected.add(new DataSet(
Nd4j.vstack(orig.get(1).getFeatures().getRows(2, 3),
orig.get(2).getFeatures().getRows(0)),
Nd4j.vstack(orig.get(1).getLabels().getRows(2, 3), orig.get(2).getLabels().getRows(0))));
expected.add(new DataSet(orig.get(2).getFeatures().getRows(1, 2, 3),
orig.get(2).getLabels().getRows(1, 2, 3)));
DataSetIterator iter = new IteratorDataSetIterator(orig.iterator(), batchSize);
int count = 0;
while (iter.hasNext()) {
DataSet ds = iter.next();
assertEquals(expected.get(count), ds);
count++;
}
assertEquals(count, numBatches);
}
}