cavis/datavec/datavec-spark/src/test/java/org/datavec/spark/functions/TestSequenceRecordReaderBytesFunction.java

/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.spark.functions;

import org.apache.hadoop.io.BytesWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.input.PortableDataStream;
import org.datavec.api.conf.Configuration;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.api.writable.Writable;
import org.datavec.codec.reader.CodecRecordReader;
import org.datavec.spark.BaseSparkTest;
import org.datavec.spark.functions.data.FilesAsBytesFunction;
import org.datavec.spark.functions.data.SequenceRecordReaderBytesFunction;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.common.io.ClassPathResource;
import java.io.File;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
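
/**
 * Tests {@link SequenceRecordReaderBytesFunction} by writing video files into a Hadoop
 * SequenceFile as raw bytes, reading the sequences back via Spark, and checking that the
 * resulting records match those produced by reading the same files locally with a
 * {@link CodecRecordReader}.
 */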
public class TestSequenceRecordReaderBytesFunction extends BaseSparkTest {
    @Rule
    public TemporaryFolder testDir = new TemporaryFolder();

    @Test
    public void testRecordReaderBytesFunction() throws Exception {
        //Local file path
        File f = testDir.newFolder();
        new ClassPathResource("datavec-spark/video/").copyDirectory(f);
        String path = f.getAbsolutePath() + "/*";

        //Load binary data from the local file system and convert it to bytes:
        JavaPairRDD<String, PortableDataStream> origData = sc.binaryFiles(path);
        JavaPairRDD<Text, BytesWritable> filesAsBytes = origData.mapToPair(new FilesAsBytesFunction());

        //Write the sequence file:
        Path p = Files.createTempDirectory("dl4j_rrbytesTest");
        p.toFile().deleteOnExit();
        String outPath = p.toString() + "/out";
        filesAsBytes.saveAsNewAPIHadoopFile(outPath, Text.class, BytesWritable.class, SequenceFileOutputFormat.class);

        //Load data from sequence file, parse via SequenceRecordReader:
        JavaPairRDD<Text, BytesWritable> fromSeqFile = sc.sequenceFile(outPath, Text.class, BytesWritable.class);
        SequenceRecordReader seqRR = new CodecRecordReader();
        Configuration conf = new Configuration();
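        //Configure the codec reader: decode frames 0-24 of each video as 64x64 images;
        //RAVEL flattens each frame into a single row of writables.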
        conf.set(CodecRecordReader.RAVEL, "true");
        conf.set(CodecRecordReader.START_FRAME, "0");
        conf.set(CodecRecordReader.TOTAL_FRAMES, "25");
        conf.set(CodecRecordReader.ROWS, "64");
        conf.set(CodecRecordReader.COLUMNS, "64");
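        //Copy the configuration before it is attached to the Spark-side reader, so the
        //local comparison reader below can be configured identically.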
        Configuration confCopy = new Configuration(conf);
        seqRR.setConf(conf);
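
        //Apply the record reader to each (filename, bytes) pair: one sequence of frames per video file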
        JavaRDD<List<List<Writable>>> dataVecData = fromSeqFile.map(new SequenceRecordReaderBytesFunction(seqRR));

        //Next: do the same thing locally, and compare the results
        InputSplit is = new FileSplit(f, new String[] {"mp4"}, true);
        SequenceRecordReader srr = new CodecRecordReader();
        srr.initialize(is);
        srr.setConf(confCopy);

        List<List<List<Writable>>> list = new ArrayList<>(4);
        while (srr.hasNext()) {
            list.add(srr.sequenceRecord());
        }
        assertEquals(4, list.size());

        List<List<List<Writable>>> fromSequenceFile = dataVecData.collect();
        assertEquals(4, fromSequenceFile.size());
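
        //Spark does not guarantee output order, so match each Spark-loaded sequence against
        //the locally loaded sequences and require exactly one match per sequence.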
        boolean[] found = new boolean[4];
        for (int i = 0; i < 4; i++) {
            int foundIndex = -1;
            List<List<Writable>> collection = fromSequenceFile.get(i);
            for (int j = 0; j < 4; j++) {
                if (collection.equals(list.get(j))) {
                    if (foundIndex != -1)
                        fail(); //Already matched this Spark sequence -> it equals two or more local sequences (shouldn't happen)
                    foundIndex = j;
                    if (found[foundIndex])
                        fail(); //Another Spark sequence already matched this local one -> duplicates in the Spark output
                    found[foundIndex] = true; //Mark this local sequence as matched
                }
            }
        }

        int count = 0;
        for (boolean b : found)
            if (b)
                count++;
        assertEquals(4, count); //Expect exactly 4 one-to-one matches between the Spark and local sequences
    }
}