174 lines
6.0 KiB
Java
Raw Normal View History

2019-06-06 15:21:15 +03:00
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.datasets.iterator;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.util.TestDataSetConsumer;
import org.junit.Test;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.resources.Resources;
import java.util.Iterator;
import java.util.concurrent.atomic.AtomicLong;
import static org.junit.Assert.*;
/**
 * Tests for {@link MultipleEpochsIterator}: wrapping an underlying {@link DataSetIterator},
 * wrapping a single pre-loaded {@link DataSet}, and wrapping an {@link ExistingDataSetIterator}
 * backed by a synthetic {@link Iterable}.
 */
public class MultipleEpochsIteratorTest extends BaseDL4JTest {

    /**
     * Drains a {@link MultipleEpochsIterator} built on a record-reader iterator and checks that
     * every returned batch is non-null and that the epoch counter ends at the requested value.
     */
    @Test
    public void testNextAndReset() throws Exception {
        int epochs = 3;

        RecordReader rr = new CSVRecordReader();
        rr.initialize(new FileSplit(Resources.asFile("iris.txt")));
        DataSetIterator iter = new RecordReaderDataSetIterator(rr, 150);
        MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, iter);

        assertTrue(multiIter.hasNext());
        while (multiIter.hasNext()) {
            DataSet path = multiIter.next();
            assertNotNull(path);
        }
        assertEquals(epochs, multiIter.epochs);
    }

    /**
     * Wraps a single 50-example {@link DataSet} and checks that plain {@code next()} hands back
     * the full 50 examples once per epoch.
     */
    @Test
    public void testLoadFullDataSet() throws Exception {
        int epochs = 3;

        RecordReader rr = new CSVRecordReader();
        rr.initialize(new FileSplit(Resources.asFile("iris.txt")));
        DataSetIterator iter = new RecordReaderDataSetIterator(rr, 150);
        DataSet ds = iter.next(50);
        assertEquals(50, ds.getFeatures().size(0));

        MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, ds);
        assertTrue(multiIter.hasNext());

        int count = 0;
        while (multiIter.hasNext()) {
            DataSet path = multiIter.next();
            assertNotNull(path);
            // Integer comparison: no floating-point delta needed.
            assertEquals(50, path.numExamples());
            count++;
        }
        // One next() call per epoch is expected when the whole DataSet is returned each time.
        assertEquals(epochs, count);
        assertEquals(epochs, multiIter.epochs);
    }

    /**
     * Wraps a single 20-example {@link DataSet} and pulls it in batches of 10 via
     * {@code next(int)}, checking each batch size and the final epoch count.
     */
    @Test
    public void testLoadBatchDataSet() throws Exception {
        int epochs = 2;

        RecordReader rr = new CSVRecordReader();
        // Uses the same resource lookup as the other tests in this class.
        rr.initialize(new FileSplit(Resources.asFile("iris.txt")));
        DataSetIterator iter = new RecordReaderDataSetIterator(rr, 150, 4, 3);
        DataSet ds = iter.next(20);
        assertEquals(20, ds.getFeatures().size(0));

        MultipleEpochsIterator multiIter = new MultipleEpochsIterator(epochs, ds);
        while (multiIter.hasNext()) {
            DataSet path = multiIter.next(10);
            assertNotNull(path);
            // Expected value first, actual second, so failure messages read correctly.
            assertEquals(10, path.numExamples());
        }
        assertEquals(epochs, multiIter.epochs);
    }

    /**
     * Consumes 10 epochs over an iterable of 100 data sets and checks that the consumer
     * reports 10 * 100 consumed items.
     */
    @Test
    public void testMEDIWithLoad1() throws Exception {
        ExistingDataSetIterator iter = new ExistingDataSetIterator(new IterableWithoutException(100));
        MultipleEpochsIterator iterator = new MultipleEpochsIterator(10, iter, 24);
        TestDataSetConsumer consumer = new TestDataSetConsumer(iterator, 1);

        long num = consumer.consumeWhileHasNext(true);

        assertEquals(10 * 100, num);
    }

    /**
     * Consumes 150 data sets manually, resets the iterator, then drains it; the test expects
     * the drain after {@code reset()} to yield the full 10 * 100 items again.
     */
    @Test
    public void testMEDIWithLoad2() throws Exception {
        ExistingDataSetIterator iter = new ExistingDataSetIterator(new IterableWithoutException(100));
        MultipleEpochsIterator iterator = new MultipleEpochsIterator(10, iter, 24);
        TestDataSetConsumer consumer = new TestDataSetConsumer(iterator, 2);

        // Partial consumption: 150 items, i.e. more than one pass over the 100-item iterable.
        long num1 = 0;
        for (; num1 < 150; num1++) {
            consumer.consumeOnce(iterator.next(), true);
        }

        iterator.reset();
        long num2 = consumer.consumeWhileHasNext(true);

        assertEquals((10 * 100) + 150, num1 + num2);
    }

    /**
     * Uses the {@code (iterator, 24, 136)} constructor and checks that exactly 136 data sets
     * are produced even though the underlying iterable could supply 10000.
     */
    @Test
    public void testMEDIWithLoad3() throws Exception {
        ExistingDataSetIterator iter = new ExistingDataSetIterator(new IterableWithoutException(10000));
        MultipleEpochsIterator iterator = new MultipleEpochsIterator(iter, 24, 136);
        TestDataSetConsumer consumer = new TestDataSetConsumer(iterator, 2);

        long num1 = 0;
        while (iterator.hasNext()) {
            consumer.consumeOnce(iterator.next(), true);
            num1++;
        }

        assertEquals(136, num1);
    }

    /**
     * Synthetic source of {@link DataSet}s: each iterator it hands out yields freshly created
     * data sets until {@code datasets} of them have been produced.
     *
     * <p>The position counter lives on the iterable, not on the iterator, so calling
     * {@link #iterator()} resets it to zero for every iterator obtained from this instance.
     * Declared {@code static} because it never touches the enclosing test instance.
     */
    private static class IterableWithoutException implements Iterable<DataSet> {
        private final AtomicLong counter = new AtomicLong(0);
        private final int datasets;

        /**
         * @param datasets number of data sets an iterator yields before {@code hasNext()}
         *                 returns {@code false}
         */
        public IterableWithoutException(int datasets) {
            this.datasets = datasets;
        }

        @Override
        public Iterator<DataSet> iterator() {
            // Restart counting so a new iterator begins a fresh pass.
            counter.set(0);
            return new Iterator<DataSet>() {
                @Override
                public boolean hasNext() {
                    return counter.get() < datasets;
                }

                @Override
                public DataSet next() {
                    counter.incrementAndGet();
                    return new DataSet(Nd4j.create(100), Nd4j.create(10));
                }

                @Override
                public void remove() {
                    // Removal is a no-op for this synthetic source.
                }
            };
        }
    }
}