2021-03-18 16:45:39 +09:00

738 lines
39 KiB
Java

/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.datasets.datavec;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.datavec.api.conf.Configuration;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.records.Record;
import org.datavec.api.records.metadata.RecordMetaData;
import org.datavec.api.records.reader.BaseRecordReader;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.split.CollectionInputSplit;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.api.split.NumberedFileInputSplit;
import org.datavec.api.util.ndarray.RecordConverter;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.resources.Resources;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.shade.guava.io.Files;

import java.io.*;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.util.*;

import static org.junit.jupiter.api.Assertions.*;
import static org.nd4j.linalg.indexing.NDArrayIndex.all;
import static org.nd4j.linalg.indexing.NDArrayIndex.interval;
import static org.nd4j.linalg.indexing.NDArrayIndex.point;
/**
 * Tests for {@link RecordReaderMultiDataSetIterator}.
 *
 * <p>Most tests build a {@code RecordReaderMultiDataSetIterator} alongside a single-input /
 * single-output {@link RecordReaderDataSetIterator} or {@link SequenceRecordReaderDataSetIterator}
 * over the same data, and assert that both produce identical arrays. The "Meta" variants
 * additionally check that {@code loadFromMetaData} reproduces each {@link MultiDataSet}.
 */
@DisplayName("Record Reader Multi Data Set Iterator Test")
@Disabled
class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {

    /** Per-test scratch directory injected by JUnit via {@code @TempDir}. */
    @TempDir
    public Path temporaryFolder;

    /**
     * Copies the {@code csvsequence_i.txt}, {@code csvsequencelabels_i.txt} and
     * {@code csvsequencelabelsShort_i.txt} classpath resources (i = 0..2) into {@code rootDir}.
     * They then exist on disk and can be addressed by a {@link NumberedFileInputSplit} path pattern.
     *
     * @param rootDir directory to copy the resources into
     */
    private static void extractCsvSequenceResources(File rootDir) throws Exception {
        for (int i = 0; i < 3; i++) {
            new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(rootDir);
            new ClassPathResource(String.format("csvsequencelabels_%d.txt", i)).getTempFileFromArchive(rootDir);
            new ClassPathResource(String.format("csvsequencelabelsShort_%d.txt", i)).getTempFileFromArchive(rootDir);
        }
    }

    /**
     * Writes the two LFW test images into {@code parentDir}, one per sub-directory
     * ({@code Zico/} and {@code Ziwang_Xu/}). The parent directory name can then serve as the
     * label for a {@link ParentPathLabelGenerator}.
     *
     * @param parentDir directory under which the two label directories are created
     */
    private static void writeLfwTestImages(File parentDir) throws Exception {
        File zicoDir = new File(FilenameUtils.concat(parentDir.getAbsolutePath(), "Zico/"));
        File ziwangDir = new File(FilenameUtils.concat(parentDir.getAbsolutePath(), "Ziwang_Xu/"));
        zicoDir.mkdirs();
        ziwangDir.mkdirs();
        TestUtils.writeStreamToFile(new File(FilenameUtils.concat(zicoDir.getPath(), "Zico_0001.jpg")),
                new ClassPathResource("lfwtest/Zico/Zico_0001.jpg").getInputStream());
        TestUtils.writeStreamToFile(new File(FilenameUtils.concat(ziwangDir.getPath(), "Ziwang_Xu_0001.jpg")),
                new ClassPathResource("lfwtest/Ziwang_Xu/Ziwang_Xu_0001.jpg").getInputStream());
    }

    @Test
    @DisplayName("Tests Basic")
    void testsBasic() throws Exception {
        // Load details from CSV files; single input/output -> compare to RecordReaderDataSetIterator
        RecordReader rr = new CSVRecordReader(0, ',');
        rr.initialize(new FileSplit(Resources.asFile("iris.txt")));
        RecordReaderDataSetIterator rrdsi = new RecordReaderDataSetIterator(rr, 10, 4, 3);
        RecordReader rr2 = new CSVRecordReader(0, ',');
        rr2.initialize(new FileSplit(Resources.asFile("iris.txt")));
        MultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10)
                .addReader("reader", rr2)
                .addInput("reader", 0, 3)
                .addOutputOneHot("reader", 4, 3)
                .build();
        while (rrdsi.hasNext()) {
            DataSet ds = rrdsi.next();
            INDArray fds = ds.getFeatures();
            INDArray lds = ds.getLabels();
            MultiDataSet mds = rrmdsi.next();
            assertEquals(1, mds.getFeatures().length);
            assertEquals(1, mds.getLabels().length);
            assertNull(mds.getFeaturesMaskArrays());
            assertNull(mds.getLabelsMaskArrays());
            INDArray fmds = mds.getFeatures(0);
            INDArray lmds = mds.getLabels(0);
            assertNotNull(fmds);
            assertNotNull(lmds);
            assertEquals(fds, fmds);
            assertEquals(lds, lmds);
        }
        assertFalse(rrmdsi.hasNext());

        // The sequence resources must be copied to disk so NumberedFileInputSplit can address them by path
        File rootDir = temporaryFolder.toFile();
        extractCsvSequenceResources(rootDir);
        // Load time series from CSV sequence files; compare to SequenceRecordReaderDataSetIterator
        String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt");
        String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabels_%d.txt");
        SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
        SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
        featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
        labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
        SequenceRecordReaderDataSetIterator iter =
                new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false);
        SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ",");
        SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ",");
        featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
        labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
        MultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1)
                .addSequenceReader("in", featureReader2)
                .addSequenceReader("out", labelReader2)
                .addInput("in")
                .addOutputOneHot("out", 0, 4)
                .build();
        while (iter.hasNext()) {
            DataSet ds = iter.next();
            INDArray fds = ds.getFeatures();
            INDArray lds = ds.getLabels();
            MultiDataSet mds = srrmdsi.next();
            assertEquals(1, mds.getFeatures().length);
            assertEquals(1, mds.getLabels().length);
            assertNull(mds.getFeaturesMaskArrays());
            assertNull(mds.getLabelsMaskArrays());
            INDArray fmds = mds.getFeatures(0);
            INDArray lmds = mds.getLabels(0);
            assertNotNull(fmds);
            assertNotNull(lmds);
            assertEquals(fds, fmds);
            assertEquals(lds, lmds);
        }
        assertFalse(srrmdsi.hasNext());
    }

    @Test
    @DisplayName("Tests Basic Meta")
    void testsBasicMeta() throws Exception {
        // As per testBasic - but also loading metadata
        RecordReader rr2 = new CSVRecordReader(0, ',');
        rr2.initialize(new FileSplit(Resources.asFile("iris.txt")));
        RecordReaderMultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10)
                .addReader("reader", rr2)
                .addInput("reader", 0, 3)
                .addOutputOneHot("reader", 4, 3)
                .build();
        rrmdsi.setCollectMetaData(true);
        int count = 0;
        while (rrmdsi.hasNext()) {
            MultiDataSet mds = rrmdsi.next();
            MultiDataSet fromMeta = rrmdsi.loadFromMetaData(mds.getExampleMetaData(RecordMetaData.class));
            assertEquals(mds, fromMeta);
            count++;
        }
        // 150 iris examples in minibatches of 10
        assertEquals(150 / 10, count);
    }

    @Test
    @DisplayName("Test Splitting CSV")
    void testSplittingCSV() throws Exception {
        // Here's the idea: take Iris, and split it up into 2 inputs and 2 output arrays
        // Inputs: columns 0 and 1-2
        // Outputs: columns 3, and 4->OneHot
        RecordReader rr = new CSVRecordReader(0, ',');
        rr.initialize(new FileSplit(Resources.asFile("iris.txt")));
        RecordReaderDataSetIterator rrdsi = new RecordReaderDataSetIterator(rr, 10, 4, 3);
        RecordReader rr2 = new CSVRecordReader(0, ',');
        rr2.initialize(new FileSplit(Resources.asFile("iris.txt")));
        MultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10)
                .addReader("reader", rr2)
                .addInput("reader", 0, 0)
                .addInput("reader", 1, 2)
                .addOutput("reader", 3, 3)
                .addOutputOneHot("reader", 4, 3)
                .build();
        while (rrdsi.hasNext()) {
            DataSet ds = rrdsi.next();
            INDArray fds = ds.getFeatures();
            INDArray lds = ds.getLabels();
            MultiDataSet mds = rrmdsi.next();
            assertEquals(2, mds.getFeatures().length);
            assertEquals(2, mds.getLabels().length);
            assertNull(mds.getFeaturesMaskArrays());
            assertNull(mds.getLabelsMaskArrays());
            INDArray[] fmds = mds.getFeatures();
            INDArray[] lmds = mds.getLabels();
            assertNotNull(fmds);
            assertNotNull(lmds);
            for (INDArray f : fmds) {
                assertNotNull(f);
            }
            for (INDArray l : lmds) {
                assertNotNull(l);
            }
            // Get the subsets of the original iris data
            INDArray expIn1 = fds.get(all(), interval(0, 0, true));
            INDArray expIn2 = fds.get(all(), interval(1, 2, true));
            INDArray expOut1 = fds.get(all(), interval(3, 3, true));
            INDArray expOut2 = lds;
            assertEquals(expIn1, fmds[0]);
            assertEquals(expIn2, fmds[1]);
            assertEquals(expOut1, lmds[0]);
            assertEquals(expOut2, lmds[1]);
        }
        assertFalse(rrmdsi.hasNext());
    }

    @Test
    @DisplayName("Test Splitting CSV Meta")
    void testSplittingCSVMeta() throws Exception {
        // Here's the idea: take Iris, and split it up into 2 inputs and 2 output arrays
        // Inputs: columns 0 and 1-2
        // Outputs: columns 3, and 4->OneHot
        RecordReader rr2 = new CSVRecordReader(0, ',');
        rr2.initialize(new FileSplit(Resources.asFile("iris.txt")));
        RecordReaderMultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10)
                .addReader("reader", rr2)
                .addInput("reader", 0, 0)
                .addInput("reader", 1, 2)
                .addOutput("reader", 3, 3)
                .addOutputOneHot("reader", 4, 3)
                .build();
        rrmdsi.setCollectMetaData(true);
        int count = 0;
        while (rrmdsi.hasNext()) {
            MultiDataSet mds = rrmdsi.next();
            MultiDataSet fromMeta = rrmdsi.loadFromMetaData(mds.getExampleMetaData(RecordMetaData.class));
            assertEquals(mds, fromMeta);
            count++;
        }
        // 150 iris examples in minibatches of 10
        assertEquals(150 / 10, count);
    }

    @Test
    @DisplayName("Test Splitting CSV Sequence")
    void testSplittingCSVSequence() throws Exception {
        // Idea: take CSV sequences, and split "csvsequence_i.txt" into two separate inputs; keep
        // "csvsequencelabels_i.txt" as standard one-hot output
        File rootDir = temporaryFolder.toFile();
        extractCsvSequenceResources(rootDir);
        String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt");
        String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabels_%d.txt");
        SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
        SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
        featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
        labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
        SequenceRecordReaderDataSetIterator iter =
                new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 1, 4, false);
        SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ",");
        SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ",");
        featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
        labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
        MultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1)
                .addSequenceReader("seq1", featureReader2)
                .addSequenceReader("seq2", labelReader2)
                .addInput("seq1", 0, 1)
                .addInput("seq1", 2, 2)
                .addOutputOneHot("seq2", 0, 4)
                .build();
        while (iter.hasNext()) {
            DataSet ds = iter.next();
            INDArray fds = ds.getFeatures();
            INDArray lds = ds.getLabels();
            MultiDataSet mds = srrmdsi.next();
            assertEquals(2, mds.getFeatures().length);
            assertEquals(1, mds.getLabels().length);
            assertNull(mds.getFeaturesMaskArrays());
            assertNull(mds.getLabelsMaskArrays());
            INDArray[] fmds = mds.getFeatures();
            INDArray[] lmds = mds.getLabels();
            assertNotNull(fmds);
            assertNotNull(lmds);
            for (INDArray f : fmds) {
                assertNotNull(f);
            }
            for (INDArray l : lmds) {
                assertNotNull(l);
            }
            // Sequence features are indexed [example, column, timestep]; split on the column axis
            INDArray expIn1 = fds.get(all(), NDArrayIndex.interval(0, 1, true), all());
            INDArray expIn2 = fds.get(all(), NDArrayIndex.interval(2, 2, true), all());
            assertEquals(expIn1, fmds[0]);
            assertEquals(expIn2, fmds[1]);
            assertEquals(lds, lmds[0]);
        }
        assertFalse(srrmdsi.hasNext());
    }

    @Test
    @DisplayName("Test Splitting CSV Sequence Meta")
    void testSplittingCSVSequenceMeta() throws Exception {
        // Idea: take CSV sequences, and split "csvsequence_i.txt" into two separate inputs; keep
        // "csvsequencelabels_i.txt" as standard one-hot output
        File rootDir = temporaryFolder.toFile();
        extractCsvSequenceResources(rootDir);
        String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt");
        String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabels_%d.txt");
        SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
        SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
        featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
        labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
        SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ",");
        SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ",");
        featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
        labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
        RecordReaderMultiDataSetIterator srrmdsi = new RecordReaderMultiDataSetIterator.Builder(1)
                .addSequenceReader("seq1", featureReader2)
                .addSequenceReader("seq2", labelReader2)
                .addInput("seq1", 0, 1)
                .addInput("seq1", 2, 2)
                .addOutputOneHot("seq2", 0, 4)
                .build();
        srrmdsi.setCollectMetaData(true);
        int count = 0;
        while (srrmdsi.hasNext()) {
            MultiDataSet mds = srrmdsi.next();
            MultiDataSet fromMeta = srrmdsi.loadFromMetaData(mds.getExampleMetaData(RecordMetaData.class));
            assertEquals(mds, fromMeta);
            count++;
        }
        // 3 sequence files, minibatch size 1
        assertEquals(3, count);
    }

    @Test
    @DisplayName("Test Input Validation")
    void testInputValidation() throws Exception {
        // Test: no readers
        try {
            new RecordReaderMultiDataSetIterator.Builder(1).addInput("something").addOutput("something").build();
            fail("Should have thrown exception");
        } catch (Exception expected) {
            // expected: a configuration without any reader must be rejected
        }

        // Reader setup is kept outside the try blocks so that a failure to load iris.txt fails the
        // test instead of being mistaken for the expected validation exception.

        // Test: reference to reader that doesn't exist
        RecordReader missingRefReader = new CSVRecordReader(0, ',');
        missingRefReader.initialize(new FileSplit(Resources.asFile("iris.txt")));
        try {
            new RecordReaderMultiDataSetIterator.Builder(1)
                    .addReader("iris", missingRefReader)
                    .addInput("thisDoesntExist", 0, 3)
                    .addOutputOneHot("iris", 4, 3)
                    .build();
            fail("Should have thrown exception");
        } catch (Exception expected) {
            // expected: input refers to an unregistered reader name
        }

        // Test: no inputs or outputs
        RecordReader noIoReader = new CSVRecordReader(0, ',');
        noIoReader.initialize(new FileSplit(Resources.asFile("iris.txt")));
        try {
            new RecordReaderMultiDataSetIterator.Builder(1).addReader("iris", noIoReader).build();
            fail("Should have thrown exception");
        } catch (Exception expected) {
            // expected: a configuration without inputs/outputs must be rejected
        }
    }

    @Test
    @DisplayName("Test Variable Length TS")
    void testVariableLengthTS() throws Exception {
        File rootDir = temporaryFolder.toFile();
        extractCsvSequenceResources(rootDir);
        String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt");
        String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabelsShort_%d.txt");
        // Set up SequenceRecordReaderDataSetIterators for comparison
        SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
        SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");
        featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
        labelReader.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
        SequenceRecordReader featureReader2 = new CSVSequenceRecordReader(1, ",");
        SequenceRecordReader labelReader2 = new CSVSequenceRecordReader(1, ",");
        featureReader2.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
        labelReader2.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
        SequenceRecordReaderDataSetIterator iterAlignStart = new SequenceRecordReaderDataSetIterator(featureReader,
                labelReader, 1, 4, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_START);
        SequenceRecordReaderDataSetIterator iterAlignEnd = new SequenceRecordReaderDataSetIterator(featureReader2,
                labelReader2, 1, 4, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
        // Set up the equivalent RecordReaderMultiDataSetIterators (one per alignment mode)
        SequenceRecordReader featureReader3 = new CSVSequenceRecordReader(1, ",");
        SequenceRecordReader labelReader3 = new CSVSequenceRecordReader(1, ",");
        featureReader3.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
        labelReader3.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
        SequenceRecordReader featureReader4 = new CSVSequenceRecordReader(1, ",");
        SequenceRecordReader labelReader4 = new CSVSequenceRecordReader(1, ",");
        featureReader4.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
        labelReader4.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
        RecordReaderMultiDataSetIterator rrmdsiStart = new RecordReaderMultiDataSetIterator.Builder(1)
                .addSequenceReader("in", featureReader3)
                .addSequenceReader("out", labelReader3)
                .addInput("in")
                .addOutputOneHot("out", 0, 4)
                .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_START)
                .build();
        RecordReaderMultiDataSetIterator rrmdsiEnd = new RecordReaderMultiDataSetIterator.Builder(1)
                .addSequenceReader("in", featureReader4)
                .addSequenceReader("out", labelReader4)
                .addInput("in")
                .addOutputOneHot("out", 0, 4)
                .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_END)
                .build();
        while (iterAlignStart.hasNext()) {
            DataSet dsStart = iterAlignStart.next();
            DataSet dsEnd = iterAlignEnd.next();
            MultiDataSet mdsStart = rrmdsiStart.next();
            MultiDataSet mdsEnd = rrmdsiEnd.next();
            // Feature sequences are always the longer ones here, so only label masks are asserted;
            // feature mask arrays are not needed.
            assertEquals(1, mdsStart.getFeatures().length);
            assertEquals(1, mdsStart.getLabels().length);
            assertEquals(1, mdsStart.getLabelsMaskArrays().length);
            assertEquals(1, mdsEnd.getFeatures().length);
            assertEquals(1, mdsEnd.getLabels().length);
            assertEquals(1, mdsEnd.getLabelsMaskArrays().length);
            assertEquals(dsStart.getFeatures(), mdsStart.getFeatures(0));
            assertEquals(dsStart.getLabels(), mdsStart.getLabels(0));
            assertEquals(dsStart.getLabelsMaskArray(), mdsStart.getLabelsMaskArray(0));
            assertEquals(dsEnd.getFeatures(), mdsEnd.getFeatures(0));
            assertEquals(dsEnd.getLabels(), mdsEnd.getLabels(0));
            assertEquals(dsEnd.getLabelsMaskArray(), mdsEnd.getLabelsMaskArray(0));
        }
        assertFalse(rrmdsiStart.hasNext());
        assertFalse(rrmdsiEnd.hasNext());
    }

    @Test
    @DisplayName("Test Variable Length TS Meta")
    void testVariableLengthTSMeta() throws Exception {
        File rootDir = temporaryFolder.toFile();
        extractCsvSequenceResources(rootDir);
        String featuresPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequence_%d.txt");
        String labelsPath = FilenameUtils.concat(rootDir.getAbsolutePath(), "csvsequencelabelsShort_%d.txt");
        // Set up one RecordReaderMultiDataSetIterator per alignment mode
        SequenceRecordReader featureReader3 = new CSVSequenceRecordReader(1, ",");
        SequenceRecordReader labelReader3 = new CSVSequenceRecordReader(1, ",");
        featureReader3.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
        labelReader3.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
        SequenceRecordReader featureReader4 = new CSVSequenceRecordReader(1, ",");
        SequenceRecordReader labelReader4 = new CSVSequenceRecordReader(1, ",");
        featureReader4.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
        labelReader4.initialize(new NumberedFileInputSplit(labelsPath, 0, 2));
        RecordReaderMultiDataSetIterator rrmdsiStart = new RecordReaderMultiDataSetIterator.Builder(1)
                .addSequenceReader("in", featureReader3)
                .addSequenceReader("out", labelReader3)
                .addInput("in")
                .addOutputOneHot("out", 0, 4)
                .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_START)
                .build();
        RecordReaderMultiDataSetIterator rrmdsiEnd = new RecordReaderMultiDataSetIterator.Builder(1)
                .addSequenceReader("in", featureReader4)
                .addSequenceReader("out", labelReader4)
                .addInput("in")
                .addOutputOneHot("out", 0, 4)
                .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_END)
                .build();
        rrmdsiStart.setCollectMetaData(true);
        rrmdsiEnd.setCollectMetaData(true);
        int count = 0;
        while (rrmdsiStart.hasNext()) {
            MultiDataSet mdsStart = rrmdsiStart.next();
            MultiDataSet mdsEnd = rrmdsiEnd.next();
            MultiDataSet mdsStartFromMeta =
                    rrmdsiStart.loadFromMetaData(mdsStart.getExampleMetaData(RecordMetaData.class));
            MultiDataSet mdsEndFromMeta =
                    rrmdsiEnd.loadFromMetaData(mdsEnd.getExampleMetaData(RecordMetaData.class));
            assertEquals(mdsStart, mdsStartFromMeta);
            assertEquals(mdsEnd, mdsEndFromMeta);
            count++;
        }
        assertFalse(rrmdsiStart.hasNext());
        assertFalse(rrmdsiEnd.hasNext());
        assertEquals(3, count);
    }

    @Test
    @DisplayName("Test Images RRDMSI")
    void testImagesRRDMSI() throws Exception {
        File parentDir = temporaryFolder.toFile();
        writeLfwTestImages(parentDir);
        int outputNum = 2;
        ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
        ImageRecordReader rr1 = new ImageRecordReader(10, 10, 1, labelMaker);
        ImageRecordReader rr1s = new ImageRecordReader(5, 5, 1, labelMaker);
        rr1.initialize(new FileSplit(parentDir));
        rr1s.initialize(new FileSplit(parentDir));
        MultiDataSetIterator trainDataIterator = new RecordReaderMultiDataSetIterator.Builder(1)
                .addReader("rr1", rr1)
                .addReader("rr1s", rr1s)
                .addInput("rr1", 0, 0)
                .addInput("rr1s", 0, 0)
                .addOutputOneHot("rr1s", 1, outputNum)
                .build();
        // Now, do the same thing with ImageRecordReader, and check we get the same results:
        ImageRecordReader rr1_b = new ImageRecordReader(10, 10, 1, labelMaker);
        ImageRecordReader rr1s_b = new ImageRecordReader(5, 5, 1, labelMaker);
        rr1_b.initialize(new FileSplit(parentDir));
        rr1s_b.initialize(new FileSplit(parentDir));
        DataSetIterator dsi1 = new RecordReaderDataSetIterator(rr1_b, 1, 1, 2);
        DataSetIterator dsi2 = new RecordReaderDataSetIterator(rr1s_b, 1, 1, 2);
        for (int i = 0; i < 2; i++) {
            MultiDataSet mds = trainDataIterator.next();
            DataSet d1 = dsi1.next();
            DataSet d2 = dsi2.next();
            assertEquals(d1.getFeatures(), mds.getFeatures(0));
            assertEquals(d2.getFeatures(), mds.getFeatures(1));
            assertEquals(d1.getLabels(), mds.getLabels(0));
        }
    }

    @Test
    @DisplayName("Test Images RRDMSI _ Batched")
    void testImagesRRDMSI_Batched() throws Exception {
        File parentDir = temporaryFolder.toFile();
        writeLfwTestImages(parentDir);
        int outputNum = 2;
        ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
        ImageRecordReader rr1 = new ImageRecordReader(10, 10, 1, labelMaker);
        ImageRecordReader rr1s = new ImageRecordReader(5, 5, 1, labelMaker);
        URI[] uris = new FileSplit(parentDir).locations();
        rr1.initialize(new CollectionInputSplit(uris));
        rr1s.initialize(new CollectionInputSplit(uris));
        MultiDataSetIterator trainDataIterator = new RecordReaderMultiDataSetIterator.Builder(2)
                .addReader("rr1", rr1)
                .addReader("rr1s", rr1s)
                .addInput("rr1", 0, 0)
                .addInput("rr1s", 0, 0)
                .addOutputOneHot("rr1s", 1, outputNum)
                .build();
        // Now, do the same thing with ImageRecordReader, and check we get the same results:
        ImageRecordReader rr1_b = new ImageRecordReader(10, 10, 1, labelMaker);
        ImageRecordReader rr1s_b = new ImageRecordReader(5, 5, 1, labelMaker);
        rr1_b.initialize(new FileSplit(parentDir));
        rr1s_b.initialize(new FileSplit(parentDir));
        DataSetIterator dsi1 = new RecordReaderDataSetIterator(rr1_b, 2, 1, 2);
        DataSetIterator dsi2 = new RecordReaderDataSetIterator(rr1s_b, 2, 1, 2);
        MultiDataSet mds = trainDataIterator.next();
        DataSet d1 = dsi1.next();
        DataSet d2 = dsi2.next();
        assertEquals(d1.getFeatures(), mds.getFeatures(0));
        assertEquals(d2.getFeatures(), mds.getFeatures(1));
        assertEquals(d1.getLabels(), mds.getLabels(0));
        // Check label assignment: the last file read determines the order of the two examples
        File currentFile = rr1_b.getCurrentFile();
        INDArray expLabels;
        if (currentFile.getAbsolutePath().contains("Zico")) {
            expLabels = Nd4j.create(new double[][] { { 0, 1 }, { 1, 0 } });
        } else {
            expLabels = Nd4j.create(new double[][] { { 1, 0 }, { 0, 1 } });
        }
        assertEquals(expLabels, d1.getLabels());
        assertEquals(expLabels, d2.getLabels());
    }

    @Test
    @DisplayName("Test Time Series Random Offset")
    void testTimeSeriesRandomOffset() {
        // 3 sequences of length [1,3,5]; 2 columns per time step: column 0 -> input, column 1 -> output
        List<List<Writable>> seq1 = Arrays.asList(
                Arrays.<Writable>asList(new DoubleWritable(1.0), new DoubleWritable(2.0)));
        List<List<Writable>> seq2 = Arrays.asList(
                Arrays.<Writable>asList(new DoubleWritable(10.0), new DoubleWritable(11.0)),
                Arrays.<Writable>asList(new DoubleWritable(20.0), new DoubleWritable(21.0)),
                Arrays.<Writable>asList(new DoubleWritable(30.0), new DoubleWritable(31.0)));
        List<List<Writable>> seq3 = Arrays.asList(
                Arrays.<Writable>asList(new DoubleWritable(100.0), new DoubleWritable(101.0)),
                Arrays.<Writable>asList(new DoubleWritable(200.0), new DoubleWritable(201.0)),
                Arrays.<Writable>asList(new DoubleWritable(300.0), new DoubleWritable(301.0)),
                Arrays.<Writable>asList(new DoubleWritable(400.0), new DoubleWritable(401.0)),
                Arrays.<Writable>asList(new DoubleWritable(500.0), new DoubleWritable(501.0)));
        Collection<List<List<Writable>>> seqs = Arrays.asList(seq1, seq2, seq3);
        SequenceRecordReader rr = new CollectionSequenceRecordReader(seqs);
        RecordReaderMultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(3)
                .addSequenceReader("rr", rr)
                .addInput("rr", 0, 0)
                .addOutput("rr", 1, 1)
                .timeSeriesRandomOffset(true, 1234L)
                .build();
        // Provides seed for each minibatch
        Random r = new Random(1234);
        long seed = r.nextLong();
        // Use same RNG seed in new RNG for each minibatch
        Random r2 = new Random(seed);
        // 0 to 4 inclusive
        int expOffsetSeq1 = r2.nextInt(5 - 1 + 1);
        // 0 to 2 inclusive
        int expOffsetSeq2 = r2.nextInt(5 - 3 + 1);
        // Longest TS, always 0
        int expOffsetSeq3 = 0;
        // With current seed: 3, 1, 0
        MultiDataSet mds = rrmdsi.next();
        INDArray expMask = Nd4j.create(new double[][] { { 0, 0, 0, 1, 0 }, { 0, 1, 1, 1, 0 }, { 1, 1, 1, 1, 1 } });
        assertEquals(expMask, mds.getFeaturesMaskArray(0));
        assertEquals(expMask, mds.getLabelsMaskArray(0));
        INDArray f = mds.getFeatures(0);
        INDArray l = mds.getLabels(0);
        INDArray expF1 = Nd4j.create(new double[] { 1.0 }, new int[] { 1, 1 });
        INDArray expL1 = Nd4j.create(new double[] { 2.0 }, new int[] { 1, 1 });
        INDArray expF2 = Nd4j.create(new double[] { 10, 20, 30 }, new int[] { 1, 3 });
        INDArray expL2 = Nd4j.create(new double[] { 11, 21, 31 }, new int[] { 1, 3 });
        INDArray expF3 = Nd4j.create(new double[] { 100, 200, 300, 400, 500 }, new int[] { 1, 5 });
        INDArray expL3 = Nd4j.create(new double[] { 101, 201, 301, 401, 501 }, new int[] { 1, 5 });
        assertEquals(expF1, f.get(point(0), all(), NDArrayIndex.interval(expOffsetSeq1, expOffsetSeq1 + 1)));
        assertEquals(expL1, l.get(point(0), all(), NDArrayIndex.interval(expOffsetSeq1, expOffsetSeq1 + 1)));
        assertEquals(expF2, f.get(point(1), all(), NDArrayIndex.interval(expOffsetSeq2, expOffsetSeq2 + 3)));
        assertEquals(expL2, l.get(point(1), all(), NDArrayIndex.interval(expOffsetSeq2, expOffsetSeq2 + 3)));
        assertEquals(expF3, f.get(point(2), all(), NDArrayIndex.interval(expOffsetSeq3, expOffsetSeq3 + 5)));
        assertEquals(expL3, l.get(point(2), all(), NDArrayIndex.interval(expOffsetSeq3, expOffsetSeq3 + 5)));
    }

    @Test
    @DisplayName("Test Seq RRDSI Masking")
    void testSeqRRDSIMasking() {
        // Checks ALIGN_END masking of SequenceRecordReaderDataSetIterator with unequal-length sequences.
        // NOTE(review): this presumably also exercises RecordReaderMultiDataSetIterator, assuming
        // SequenceRecordReaderDataSetIterator delegates to it internally - verify.
        List<List<List<Writable>>> features = new ArrayList<>();
        List<List<List<Writable>>> labels = new ArrayList<>();
        features.add(Arrays.asList(l(new DoubleWritable(1)), l(new DoubleWritable(2)), l(new DoubleWritable(3))));
        features.add(Arrays.asList(l(new DoubleWritable(4)), l(new DoubleWritable(5))));
        labels.add(Arrays.asList(l(new IntWritable(0))));
        labels.add(Arrays.asList(l(new IntWritable(1))));
        CollectionSequenceRecordReader fR = new CollectionSequenceRecordReader(features);
        CollectionSequenceRecordReader lR = new CollectionSequenceRecordReader(labels);
        SequenceRecordReaderDataSetIterator seqRRDSI = new SequenceRecordReaderDataSetIterator(fR, lR, 2, 2, false,
                SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
        DataSet ds = seqRRDSI.next();
        INDArray fMask = Nd4j.create(new double[][] { { 1, 1, 1 }, { 1, 1, 0 } });
        INDArray lMask = Nd4j.create(new double[][] { { 0, 0, 1 }, { 0, 1, 0 } });
        assertEquals(fMask, ds.getFeaturesMaskArray());
        assertEquals(lMask, ds.getLabelsMaskArray());
        INDArray f = Nd4j.create(new double[][] { { 1, 2, 3 }, { 4, 5, 0 } });
        INDArray l = Nd4j.create(2, 2, 3);
        l.putScalar(0, 0, 2, 1.0);
        l.putScalar(1, 1, 1, 1.0);
        assertEquals(f, ds.getFeatures().get(all(), point(0), all()));
        assertEquals(l, ds.getLabels());
    }

    /** Shorthand for building a single time step (a list of writables). */
    private static List<Writable> l(Writable... in) {
        return Arrays.asList(in);
    }

    @Test
    @DisplayName("Test Exclude String Col CSV")
    void testExcludeStringColCSV() throws Exception {
        // Write to a file INSIDE the temp directory: the @TempDir path itself is a directory,
        // and writing a string to a directory fails.
        File csvFile = temporaryFolder.resolve("excludeStringCol.csv").toFile();
        StringBuilder sb = new StringBuilder();
        for (int i = 1; i <= 10; i++) {
            if (i > 1) {
                sb.append("\n");
            }
            // Column 0 is a non-numeric string that must be ignored (not referenced by any input/output)
            sb.append("skip_").append(i).append(",").append(i).append(",").append(i + 0.5);
        }
        FileUtils.writeStringToFile(csvFile, sb.toString(), StandardCharsets.UTF_8);
        RecordReader rr = new CSVRecordReader();
        rr.initialize(new FileSplit(csvFile));
        RecordReaderMultiDataSetIterator rrmdsi = new RecordReaderMultiDataSetIterator.Builder(10)
                .addReader("rr", rr)
                .addInput("rr", 1, 1)
                .addOutput("rr", 2, 2)
                .build();
        INDArray expFeatures = Nd4j.linspace(1, 10, 10).reshape(1, 10).transpose();
        INDArray expLabels = Nd4j.linspace(1, 10, 10).addi(0.5).reshape(1, 10).transpose();
        MultiDataSet mds = rrmdsi.next();
        assertFalse(rrmdsi.hasNext());
        assertEquals(expFeatures, mds.getFeatures(0).castTo(expFeatures.dataType()));
        assertEquals(expLabels, mds.getLabels(0).castTo(expLabels.dataType()));
    }

    private static final int nX = 32;
    private static final int nY = 32;
    private static final int nZ = 28;

    @Test
    @DisplayName("Test RRMDSI 5 D")
    void testRRMDSI5D() {
        int batchSize = 5;
        CustomRecordReader recordReader = new CustomRecordReader();
        DataSetIterator dataIter = new RecordReaderDataSetIterator(recordReader, batchSize,
                /* labelIndex */ 1, /* numPossibleLabels */ 2);
        int count = 0;
        while (dataIter.hasNext()) {
            DataSet ds = dataIter.next();
            // CustomRecordReader fills record k with the value k, so example i of minibatch
            // 'count' holds the value (batchSize * count + i)
            int offset = batchSize * count;
            for (int i = 0; i < batchSize; i++) {
                INDArray act = ds.getFeatures().get(interval(i, i, true), all(), all(), all(), all());
                // Same shape as produced by CustomRecordReader.next()
                INDArray exp = Nd4j.valueArrayOf(new int[] { 1, 1, nZ, nY, nX }, i + offset);
                assertEquals(exp, act);
            }
            count++;
        }
        // CustomRecordReader yields 10 records -> 2 minibatches of 5
        assertEquals(2, count);
    }

    /**
     * Minimal in-memory {@link BaseRecordReader} producing 10 records. Record {@code n} is a
     * {@code [1, 1, nZ, nY, nX]} array filled with {@code n}, followed by the label
     * {@code IntWritable(0)}. Metadata loading and batched reads are not supported.
     */
    @DisplayName("Custom Record Reader")
    static class CustomRecordReader extends BaseRecordReader {

        /** Index of the next record to produce; reset by {@link #reset()} and {@code initialize}. */
        int n = 0;

        private static final ArrayList<String> labels = new ArrayList<>(2);

        static {
            labels.add("lbl0");
            labels.add("lbl1");
        }

        CustomRecordReader() {
        }

        @Override
        public boolean batchesSupported() {
            return false;
        }

        @Override
        public List<List<Writable>> next(int num) {
            throw new RuntimeException("Not implemented");
        }

        @Override
        public List<Writable> next() {
            INDArray nd = Nd4j.create(new float[nZ * nY * nX], new int[] { 1, 1, nZ, nY, nX }, 'c').assign(n);
            final List<Writable> res = RecordConverter.toRecord(nd);
            res.add(new IntWritable(0));
            n++;
            return res;
        }

        @Override
        public boolean hasNext() {
            return n < 10;
        }

        @Override
        public List<String> getLabels() {
            return labels;
        }

        @Override
        public void reset() {
            n = 0;
        }

        @Override
        public boolean resetSupported() {
            return true;
        }

        @Override
        public List<Writable> record(URI uri, DataInputStream dataInputStream) {
            // Input stream is ignored; the next synthetic record is returned instead
            return next();
        }

        @Override
        public Record nextRecord() {
            List<Writable> r = next();
            return new org.datavec.api.records.impl.Record(r, null);
        }

        @Override
        public Record loadFromMetaData(RecordMetaData recordMetaData) throws IOException {
            throw new RuntimeException("Not implemented");
        }

        @Override
        public List<Record> loadFromMetaData(List<RecordMetaData> recordMetaDatas) {
            throw new RuntimeException("Not implemented");
        }

        @Override
        public void close() {
        }

        @Override
        public void setConf(Configuration conf) {
        }

        @Override
        public Configuration getConf() {
            return null;
        }

        @Override
        public void initialize(InputSplit split) {
            n = 0;
        }

        @Override
        public void initialize(Configuration conf, InputSplit split) {
            n = 0;
        }
    }
}