/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */
|
2022-09-20 15:40:53 +02:00
|
|
|
|
2019-06-06 15:21:15 +03:00
|
|
|
package org.deeplearning4j.datasets.iterator;
|
|
|
|
|
|
|
|
|
|
import org.deeplearning4j.BaseDL4JTest;
|
|
|
|
|
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
2022-09-20 15:40:53 +02:00
|
|
|
import org.junit.jupiter.api.Assertions;
|
2021-03-15 13:02:01 +09:00
|
|
|
import org.junit.jupiter.api.Test;
|
2020-04-29 11:19:26 +10:00
|
|
|
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
|
2019-06-06 15:21:15 +03:00
|
|
|
import org.nd4j.linalg.dataset.api.MultiDataSet;
|
|
|
|
|
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
|
2021-03-18 10:58:50 +09:00
|
|
|
|
2019-06-06 15:21:15 +03:00
|
|
|
import java.io.IOException;
|
|
|
|
|
import java.util.ArrayList;
|
|
|
|
|
import java.util.List;
|
2021-03-16 11:57:24 +09:00
|
|
|
|
2022-10-21 15:19:32 +02:00
|
|
|
import static org.junit.jupiter.api.Assertions.*;
|
2021-03-16 11:57:24 +09:00
|
|
|
|
2022-09-20 15:40:53 +02:00
|
|
|
public class EarlyTerminationMultiDataSetIteratorTest extends BaseDL4JTest {
|
2019-06-06 15:21:15 +03:00
|
|
|
|
|
|
|
|
int minibatchSize = 5;
|
|
|
|
|
int numExamples = 105;
|
2021-03-15 13:02:01 +09:00
|
|
|
|
2019-06-06 15:21:15 +03:00
|
|
|
@Test
|
2022-09-20 15:40:53 +02:00
|
|
|
public void testNextAndReset() throws Exception {
|
|
|
|
|
|
2019-06-06 15:21:15 +03:00
|
|
|
int terminateAfter = 2;
|
2022-09-20 15:40:53 +02:00
|
|
|
|
|
|
|
|
MultiDataSetIterator iter =
|
|
|
|
|
new MultiDataSetIteratorAdapter(new MnistDataSetIterator(minibatchSize, numExamples));
|
|
|
|
|
|
2019-06-06 15:21:15 +03:00
|
|
|
int count = 0;
|
|
|
|
|
List<MultiDataSet> seenMDS = new ArrayList<>();
|
|
|
|
|
while (count < terminateAfter) {
|
|
|
|
|
seenMDS.add(iter.next());
|
|
|
|
|
count++;
|
|
|
|
|
}
|
|
|
|
|
iter.reset();
|
2022-09-20 15:40:53 +02:00
|
|
|
|
|
|
|
|
EarlyTerminationMultiDataSetIterator earlyEndIter =
|
|
|
|
|
new EarlyTerminationMultiDataSetIterator(iter, terminateAfter);
|
|
|
|
|
|
2019-06-06 15:21:15 +03:00
|
|
|
assertTrue(earlyEndIter.hasNext());
|
|
|
|
|
count = 0;
|
|
|
|
|
while (earlyEndIter.hasNext()) {
|
|
|
|
|
MultiDataSet path = earlyEndIter.next();
|
|
|
|
|
assertEquals(path.getFeatures()[0], seenMDS.get(count).getFeatures()[0]);
|
|
|
|
|
assertEquals(path.getLabels()[0], seenMDS.get(count).getLabels()[0]);
|
|
|
|
|
count++;
|
|
|
|
|
}
|
|
|
|
|
assertEquals(count, terminateAfter);
|
2022-09-20 15:40:53 +02:00
|
|
|
|
|
|
|
|
//check data is repeated
|
2019-06-06 15:21:15 +03:00
|
|
|
earlyEndIter.reset();
|
|
|
|
|
count = 0;
|
|
|
|
|
while (earlyEndIter.hasNext()) {
|
|
|
|
|
MultiDataSet path = earlyEndIter.next();
|
|
|
|
|
assertEquals(path.getFeatures()[0], seenMDS.get(count).getFeatures()[0]);
|
|
|
|
|
assertEquals(path.getLabels()[0], seenMDS.get(count).getLabels()[0]);
|
|
|
|
|
count++;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Test
|
2022-09-20 15:40:53 +02:00
|
|
|
public void testNextNum() throws IOException {
|
2019-06-06 15:21:15 +03:00
|
|
|
int terminateAfter = 1;
|
2022-09-20 15:40:53 +02:00
|
|
|
|
|
|
|
|
MultiDataSetIterator iter =
|
|
|
|
|
new MultiDataSetIteratorAdapter(new MnistDataSetIterator(minibatchSize, numExamples));
|
|
|
|
|
EarlyTerminationMultiDataSetIterator earlyEndIter =
|
|
|
|
|
new EarlyTerminationMultiDataSetIterator(iter, terminateAfter);
|
|
|
|
|
|
2019-06-06 15:21:15 +03:00
|
|
|
earlyEndIter.next(10);
|
2022-10-21 15:19:32 +02:00
|
|
|
assertFalse(earlyEndIter.hasNext());
|
2022-09-20 15:40:53 +02:00
|
|
|
|
2019-06-06 15:21:15 +03:00
|
|
|
earlyEndIter.reset();
|
2022-10-21 15:19:32 +02:00
|
|
|
assertTrue(earlyEndIter.hasNext());
|
2019-06-06 15:21:15 +03:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Test
|
2022-09-20 15:40:53 +02:00
|
|
|
public void testCallstoNextNotAllowed() throws IOException {
|
|
|
|
|
int terminateAfter = 1;
|
|
|
|
|
|
|
|
|
|
MultiDataSetIterator iter =
|
|
|
|
|
new MultiDataSetIteratorAdapter(new MnistDataSetIterator(minibatchSize, numExamples));
|
|
|
|
|
EarlyTerminationMultiDataSetIterator earlyEndIter =
|
|
|
|
|
new EarlyTerminationMultiDataSetIterator(iter, terminateAfter);
|
|
|
|
|
|
|
|
|
|
earlyEndIter.next(10);
|
|
|
|
|
iter.reset();
|
|
|
|
|
Assertions.assertThrows(RuntimeException.class, () -> {
|
2021-03-16 11:57:24 +09:00
|
|
|
earlyEndIter.next(10);
|
|
|
|
|
});
|
2019-06-06 15:21:15 +03:00
|
|
|
}
|
2022-09-20 15:40:53 +02:00
|
|
|
|
2019-06-06 15:21:15 +03:00
|
|
|
}
|