223 lines
8.3 KiB
Java
Raw Normal View History

2019-06-06 15:21:15 +03:00
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.datasets.iterator;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.callbacks.InterleavedDataSetCallback;
import org.deeplearning4j.datasets.iterator.tools.VariableTimeseriesGenerator;
import org.deeplearning4j.nn.util.TestDataSetConsumer;
import org.junit.Before;
import org.junit.Test;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
/**
* @author raver119@gmail.com
*/
@Slf4j
public class AsyncDataSetIteratorTest extends BaseDL4JTest {
private ExistingDataSetIterator backIterator;
private static final int TEST_SIZE = 100;
private static final int ITERATIONS = 10;
// time spent in consumer thread, milliseconds
private static final long EXECUTION_TIME = 5;
private static final long EXECUTION_SMALL = 1;
@Before
public void setUp() throws Exception {
List<DataSet> iterable = new ArrayList<>();
for (int i = 0; i < TEST_SIZE; i++) {
iterable.add(new DataSet(Nd4j.create(new float[100]), Nd4j.create(new float[10])));
}
backIterator = new ExistingDataSetIterator(iterable);
}
@Test
public void hasNext1() throws Exception {
for (int iter = 0; iter < ITERATIONS; iter++) {
for (int prefetchSize = 2; prefetchSize <= 8; prefetchSize++) {
AsyncDataSetIterator iterator = new AsyncDataSetIterator(backIterator, prefetchSize);
int cnt = 0;
while (iterator.hasNext()) {
DataSet ds = iterator.next();
assertNotEquals(null, ds);
cnt++;
}
assertEquals("Failed on iteration: " + iter + ", prefetchSize: " + prefetchSize, TEST_SIZE, cnt);
iterator.shutdown();
}
}
}
@Test
public void hasNextWithResetAndLoad() throws Exception {
for (int iter = 0; iter < ITERATIONS; iter++) {
for (int prefetchSize = 2; prefetchSize <= 8; prefetchSize++) {
AsyncDataSetIterator iterator = new AsyncDataSetIterator(backIterator, prefetchSize);
TestDataSetConsumer consumer = new TestDataSetConsumer(EXECUTION_SMALL);
int cnt = 0;
while (iterator.hasNext()) {
DataSet ds = iterator.next();
consumer.consumeOnce(ds, false);
cnt++;
if (cnt == TEST_SIZE / 2)
iterator.reset();
}
assertEquals(TEST_SIZE + (TEST_SIZE / 2), cnt);
iterator.shutdown();
}
}
}
@Test
public void testWithLoad() {
for (int iter = 0; iter < ITERATIONS; iter++) {
AsyncDataSetIterator iterator = new AsyncDataSetIterator(backIterator, 8);
TestDataSetConsumer consumer = new TestDataSetConsumer(iterator, EXECUTION_TIME);
consumer.consumeWhileHasNext(true);
assertEquals(TEST_SIZE, consumer.getCount());
iterator.shutdown();
}
}
@Test(expected = ArrayIndexOutOfBoundsException.class)
public void testWithException() {
ExistingDataSetIterator crashingIterator = new ExistingDataSetIterator(new IterableWithException(100));
AsyncDataSetIterator iterator = new AsyncDataSetIterator(crashingIterator, 8);
TestDataSetConsumer consumer = new TestDataSetConsumer(iterator, EXECUTION_SMALL);
consumer.consumeWhileHasNext(true);
iterator.shutdown();
}
private class IterableWithException implements Iterable<DataSet> {
private final AtomicLong counter = new AtomicLong(0);
private final int crashIteration;
public IterableWithException(int iteration) {
crashIteration = iteration;
}
@Override
public Iterator<DataSet> iterator() {
counter.set(0);
return new Iterator<DataSet>() {
@Override
public boolean hasNext() {
return true;
}
@Override
public DataSet next() {
if (counter.incrementAndGet() >= crashIteration)
throw new ArrayIndexOutOfBoundsException("Thrown as expected");
return new DataSet(Nd4j.create(10), Nd4j.create(10));
}
@Override
public void remove() {
}
};
}
}
@Test
public void testVariableTimeSeries1() throws Exception {
AsyncDataSetIterator adsi = new AsyncDataSetIterator(
new VariableTimeseriesGenerator(1192, 1000, 32, 128, 10, 500, 10), 2, true);
for (int e = 0; e < 10; e++) {
int cnt = 0;
while (adsi.hasNext()) {
DataSet ds = adsi.next();
//log.info("Features ptr: {}", AtomicAllocator.getInstance().getPointer(mds.getFeatures()[0].data()).address());
assertEquals("Failed on epoch " + e + "; iteration: " + cnt + ";", (double) cnt,
ds.getFeatures().meanNumber().doubleValue(), 1e-10);
assertEquals("Failed on epoch " + e + "; iteration: " + cnt + ";", (double) cnt + 0.25,
ds.getLabels().meanNumber().doubleValue(), 1e-10);
assertEquals("Failed on epoch " + e + "; iteration: " + cnt + ";", (double) cnt + 0.5,
ds.getFeaturesMaskArray().meanNumber().doubleValue(), 1e-10);
assertEquals("Failed on epoch " + e + "; iteration: " + cnt + ";", (double) cnt + 0.75,
ds.getLabelsMaskArray().meanNumber().doubleValue(), 1e-10);
cnt++;
}
adsi.reset();
log.info("Epoch {} finished...", e);
}
}
@Test
public void testVariableTimeSeries2() throws Exception {
AsyncDataSetIterator adsi =
new AsyncDataSetIterator(new VariableTimeseriesGenerator(1192, 100, 32, 128, 100, 100, 100), 2,
true, new InterleavedDataSetCallback(2 * 2));
for (int e = 0; e < 5; e++) {
int cnt = 0;
while (adsi.hasNext()) {
DataSet ds = adsi.next();
ds.detach();
//log.info("Features ptr: {}", AtomicAllocator.getInstance().getPointer(mds.getFeatures()[0].data()).address());
assertEquals("Failed on epoch " + e + "; iteration: " + cnt + ";", (double) cnt,
ds.getFeatures().meanNumber().doubleValue(), 1e-10);
assertEquals("Failed on epoch " + e + "; iteration: " + cnt + ";", (double) cnt + 0.25,
ds.getLabels().meanNumber().doubleValue(), 1e-10);
assertEquals("Failed on epoch " + e + "; iteration: " + cnt + ";", (double) cnt + 0.5,
ds.getFeaturesMaskArray().meanNumber().doubleValue(), 1e-10);
assertEquals("Failed on epoch " + e + "; iteration: " + cnt + ";", (double) cnt + 0.75,
ds.getLabelsMaskArray().meanNumber().doubleValue(), 1e-10);
cnt++;
}
adsi.reset();
log.info("Epoch {} finished...", e);
}
}
}