cavis/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/EarlyTerminationMultiDataSetIteratorTest.java

115 lines
4.0 KiB
Java
Raw Normal View History

2021-02-01 14:31:20 +09:00
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
2021-02-01 17:47:29 +09:00
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
2021-02-01 14:31:20 +09:00
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
2019-06-06 15:21:15 +03:00
package org.deeplearning4j.datasets.iterator;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.junit.jupiter.api.Assertions;
2021-03-15 13:02:01 +09:00
import org.junit.jupiter.api.Test;
Refactor packages to fix split package issues (#411) * Refactor nd4j-common: org.nd4j.* -> org.nd4j.common.* Signed-off-by: Alex Black <blacka101@gmail.com> * Fix CUDA (missed nd4j-common package refactoring changes) Signed-off-by: Alex Black <blacka101@gmail.com> * nd4j-kryo: org.nd4j -> org.nd4j.kryo Signed-off-by: Alex Black <blacka101@gmail.com> * Fix nd4j-common for deeplearning4j-cuda Signed-off-by: Alex Black <blacka101@gmail.com> * nd4j-grppc-client: org.nd4j.graph -> org.nd4j.remote.grpc Signed-off-by: Alex Black <blacka101@gmail.com> * deeplearning4j-common: org.deeplearning4.* -> org.deeplearning4j.common.* Signed-off-by: Alex Black <blacka101@gmail.com> * deeplearning4j-core: org.deeplearning4j.* -> org.deeplearning.core.* Signed-off-by: Alex Black <blacka101@gmail.com> * deeplearning4j-cuda: org.deeplearning4j.nn.layers.* -> org.deeplearning4j.cuda.* Signed-off-by: Alex Black <blacka101@gmail.com> * Import fixes Signed-off-by: Alex Black <blacka101@gmail.com> * deeplearning4j-nlp-*: org.deeplearning4.text.* -> org.deeplearning4j.nlp.(language).* Signed-off-by: Alex Black <blacka101@gmail.com> * deeplearning4j-ui-model: org.deeplearning4j.ui -> org.deeplearning4j.ui.model Signed-off-by: Alex Black <blacka101@gmail.com> * datavec-spark-inference-{server/model/client}: org.datavec.spark.transform -> org.datavec.spark.inference.{server/model/client} Signed-off-by: Alex Black <blacka101@gmail.com> * datavec-jdbc: org.datavec.api -> org.datavec.jdbc Signed-off-by: Alex Black <blacka101@gmail.com> * Delete org.deeplearning4j.datasets.iterator.impl.MultiDataSetIteratorAdapter in favor of (essentially identical) org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter Signed-off-by: Alex Black <blacka101@gmail.com> * ND4S fixes Signed-off-by: Alex Black <blacka101@gmail.com> * Fixes Signed-off-by: Alex Black <blacka101@gmail.com> * nd4j-common-tests: org.nd4j.* -> org.nd4j.common.tests Signed-off-by: Alex Black <blacka101@gmail.com> * Trigger CI Signed-off-by: Alex Black <blacka101@gmail.com> * Fixes Signed-off-by: Alex Black <blacka101@gmail.com> * #8878 Ignore CUDA tests on modules with 'nd4j-native under cuda' issue Signed-off-by: Alex Black <blacka101@gmail.com> * Fix bad imports in tests Signed-off-by: Alex Black <blacka101@gmail.com> * Add ignore on test (already failing) due to #8882 Signed-off-by: Alex Black <blacka101@gmail.com> * Import fixes Signed-off-by: Alex Black <blacka101@gmail.com> * Additional import fixes Signed-off-by: Alex Black <blacka101@gmail.com>
2020-04-29 11:19:26 +10:00
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
2019-06-06 15:21:15 +03:00
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
2019-06-06 15:21:15 +03:00
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
2021-03-16 11:57:24 +09:00
2022-10-21 15:19:32 +02:00
import static org.junit.jupiter.api.Assertions.*;
2021-03-16 11:57:24 +09:00
public class EarlyTerminationMultiDataSetIteratorTest extends BaseDL4JTest {
2019-06-06 15:21:15 +03:00
int minibatchSize = 5;
int numExamples = 105;
2021-03-15 13:02:01 +09:00
2019-06-06 15:21:15 +03:00
@Test
public void testNextAndReset() throws Exception {
2019-06-06 15:21:15 +03:00
int terminateAfter = 2;
MultiDataSetIterator iter =
new MultiDataSetIteratorAdapter(new MnistDataSetIterator(minibatchSize, numExamples));
2019-06-06 15:21:15 +03:00
int count = 0;
List<MultiDataSet> seenMDS = new ArrayList<>();
while (count < terminateAfter) {
seenMDS.add(iter.next());
count++;
}
iter.reset();
EarlyTerminationMultiDataSetIterator earlyEndIter =
new EarlyTerminationMultiDataSetIterator(iter, terminateAfter);
2019-06-06 15:21:15 +03:00
assertTrue(earlyEndIter.hasNext());
count = 0;
while (earlyEndIter.hasNext()) {
MultiDataSet path = earlyEndIter.next();
assertEquals(path.getFeatures()[0], seenMDS.get(count).getFeatures()[0]);
assertEquals(path.getLabels()[0], seenMDS.get(count).getLabels()[0]);
count++;
}
assertEquals(count, terminateAfter);
//check data is repeated
2019-06-06 15:21:15 +03:00
earlyEndIter.reset();
count = 0;
while (earlyEndIter.hasNext()) {
MultiDataSet path = earlyEndIter.next();
assertEquals(path.getFeatures()[0], seenMDS.get(count).getFeatures()[0]);
assertEquals(path.getLabels()[0], seenMDS.get(count).getLabels()[0]);
count++;
}
}
@Test
public void testNextNum() throws IOException {
2019-06-06 15:21:15 +03:00
int terminateAfter = 1;
MultiDataSetIterator iter =
new MultiDataSetIteratorAdapter(new MnistDataSetIterator(minibatchSize, numExamples));
EarlyTerminationMultiDataSetIterator earlyEndIter =
new EarlyTerminationMultiDataSetIterator(iter, terminateAfter);
2019-06-06 15:21:15 +03:00
earlyEndIter.next(10);
2022-10-21 15:19:32 +02:00
assertFalse(earlyEndIter.hasNext());
2019-06-06 15:21:15 +03:00
earlyEndIter.reset();
2022-10-21 15:19:32 +02:00
assertTrue(earlyEndIter.hasNext());
2019-06-06 15:21:15 +03:00
}
@Test
public void testCallstoNextNotAllowed() throws IOException {
int terminateAfter = 1;
MultiDataSetIterator iter =
new MultiDataSetIteratorAdapter(new MnistDataSetIterator(minibatchSize, numExamples));
EarlyTerminationMultiDataSetIterator earlyEndIter =
new EarlyTerminationMultiDataSetIterator(iter, terminateAfter);
earlyEndIter.next(10);
iter.reset();
Assertions.assertThrows(RuntimeException.class, () -> {
2021-03-16 11:57:24 +09:00
earlyEndIter.next(10);
});
2019-06-06 15:21:15 +03:00
}
2019-06-06 15:21:15 +03:00
}