Add tags for junit 5

master
agibsonccc 2021-03-20 19:06:24 +09:00
parent 3c205548af
commit 5e8951cd8e
773 changed files with 4684 additions and 1152 deletions

View File

@ -12,7 +12,7 @@ DL4J was a junit 4 based code based for testing.
It's now based on junit 5's jupiter API, which has support for [Tags](https://junit.org/junit5/docs/5.0.1/api/org/junit/jupiter/api/Tag.html).
DL4j's code base has a number of different kinds of tests that fall in to several categories:
1. Long and flaky involving distributed systems (spark, parameter server)
1. Long and flaky involving distributed systems (spark, parameter-server)
2. Code that requires large downloads, but runs quickly
3. Quick tests that test basic functionality
4. Comprehensive integration tests that test several parts of a code base
@ -38,8 +38,10 @@ A few kinds of tags exist:
3. Distributed systems: spark, multi-threaded
4. Functional cross-cutting concerns: multi module tests, similar functionality (excludes time based)
5. Platform specific tests that can vary on different hardware: cpu, gpu
6. JVM crash: Tests with native code can crash the JVM for tests. It's useful to be able to turn those off when debugging.: jvm-crash
6. JVM crash: (jvm-crash) Tests with native code can crash the JVM for tests. It's useful to be able to turn those off when debugging.: jvm-crash
7. RNG: (rng) for RNG related tests
8. Samediff:(samediff) samediff related tests
9. Training related functionality
## Consequences

View File

@ -26,6 +26,7 @@ import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
@ -38,8 +39,11 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.TagNames;
@DisplayName("Csv Line Sequence Record Reader Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class CSVLineSequenceRecordReaderTest extends BaseND4JTest {
@TempDir
@ -54,8 +58,8 @@ class CSVLineSequenceRecordReaderTest extends BaseND4JTest {
FileUtils.writeStringToFile(source, str, StandardCharsets.UTF_8);
SequenceRecordReader rr = new CSVLineSequenceRecordReader();
rr.initialize(new FileSplit(source));
List<List<Writable>> exp0 = Arrays.asList(Collections.<Writable>singletonList(new Text("a")), Collections.<Writable>singletonList(new Text("b")), Collections.<Writable>singletonList(new Text("c")));
List<List<Writable>> exp1 = Arrays.asList(Collections.<Writable>singletonList(new Text("1")), Collections.<Writable>singletonList(new Text("2")), Collections.<Writable>singletonList(new Text("3")), Collections.<Writable>singletonList(new Text("4")));
List<List<Writable>> exp0 = Arrays.asList(Collections.singletonList(new Text("a")), Collections.singletonList(new Text("b")), Collections.<Writable>singletonList(new Text("c")));
List<List<Writable>> exp1 = Arrays.asList(Collections.singletonList(new Text("1")), Collections.singletonList(new Text("2")), Collections.<Writable>singletonList(new Text("3")), Collections.<Writable>singletonList(new Text("4")));
for (int i = 0; i < 3; i++) {
int count = 0;
while (rr.hasNext()) {

View File

@ -27,6 +27,7 @@ import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
@ -41,8 +42,11 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.TagNames;
@DisplayName("Csv Multi Sequence Record Reader Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
@TempDir

View File

@ -26,6 +26,7 @@ import org.datavec.api.records.reader.impl.csv.CSVNLinesSequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
@ -34,12 +35,15 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.TagNames;
@DisplayName("Csvn Lines Sequence Record Reader Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class CSVNLinesSequenceRecordReaderTest extends BaseND4JTest {
@Test
@DisplayName("Test CSVN Lines Sequence Record Reader")
@DisplayName("Test CSV Lines Sequence Record Reader")
void testCSVNLinesSequenceRecordReader() throws Exception {
int nLinesPerSequence = 10;
SequenceRecordReader seqRR = new CSVNLinesSequenceRecordReader(nLinesPerSequence);

View File

@ -34,9 +34,12 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.tests.tags.TagNames;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
@ -50,6 +53,8 @@ import static org.junit.jupiter.api.Assertions.*;
@DisplayName("Csv Record Reader Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class CSVRecordReaderTest extends BaseND4JTest {
@Test

View File

@ -27,6 +27,7 @@ import org.datavec.api.split.InputSplit;
import org.datavec.api.split.NumberedFileInputSplit;
import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
@ -43,8 +44,11 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.TagNames;
@DisplayName("Csv Sequence Record Reader Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class CSVSequenceRecordReaderTest extends BaseND4JTest {
@TempDir

View File

@ -24,6 +24,7 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVVariableSlidingWindowRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
@ -32,8 +33,11 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.TagNames;
@DisplayName("Csv Variable Sliding Window Record Reader Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class CSVVariableSlidingWindowRecordReaderTest extends BaseND4JTest {
@Test

View File

@ -28,6 +28,7 @@ import org.datavec.api.records.reader.impl.filebatch.FileBatchRecordReader;
import org.datavec.api.records.reader.impl.filebatch.FileBatchSequenceRecordReader;
import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
@ -42,9 +43,12 @@ import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.factory.Nd4jBackend;
@DisplayName("File Batch Record Reader Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class FileBatchRecordReaderTest extends BaseND4JTest {
@TempDir Path testDir;

View File

@ -25,6 +25,7 @@ import org.datavec.api.split.CollectionInputSplit;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
@ -36,8 +37,11 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.TagNames;
@DisplayName("File Record Reader Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class FileRecordReaderTest extends BaseND4JTest {
@Test

View File

@ -28,10 +28,12 @@ import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.shade.jackson.core.JsonFactory;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import java.io.File;
@ -45,6 +47,8 @@ import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
@DisplayName("Jackson Line Record Reader Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class JacksonLineRecordReaderTest extends BaseND4JTest {
@TempDir

View File

@ -31,10 +31,12 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.shade.jackson.core.JsonFactory;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.dataformat.xml.XmlFactory;
@ -51,6 +53,8 @@ import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
@DisplayName("Jackson Record Reader Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class JacksonRecordReaderTest extends BaseND4JTest {
@TempDir

View File

@ -26,6 +26,7 @@ import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
@ -35,9 +36,13 @@ import static org.datavec.api.records.reader.impl.misc.LibSvmRecordReader.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.TagNames;
import static org.junit.jupiter.api.Assertions.assertThrows;
@DisplayName("Lib Svm Record Reader Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class LibSvmRecordReaderTest extends BaseND4JTest {
@Test

View File

@ -30,6 +30,7 @@ import org.datavec.api.split.InputSplit;
import org.datavec.api.split.InputStreamInputSplit;
import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
@ -47,8 +48,11 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.TagNames;
@DisplayName("Line Reader Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class LineReaderTest extends BaseND4JTest {

View File

@ -33,6 +33,7 @@ import org.datavec.api.split.NumberedFileInputSplit;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
@ -46,8 +47,11 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.TagNames;
@DisplayName("Regex Record Reader Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class RegexRecordReaderTest extends BaseND4JTest {
@TempDir

View File

@ -26,6 +26,7 @@ import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
@ -35,9 +36,13 @@ import static org.datavec.api.records.reader.impl.misc.SVMLightRecordReader.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.TagNames;
import static org.junit.jupiter.api.Assertions.assertThrows;
@DisplayName("Svm Light Record Reader Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class SVMLightRecordReaderTest extends BaseND4JTest {
@Test

View File

@ -26,15 +26,18 @@ import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class TestCollectionRecordReaders extends BaseND4JTest {
@Test

View File

@ -23,12 +23,15 @@ package org.datavec.api.records.reader.impl;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.tests.tags.TagNames;
import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class TestConcatenatingRecordReader extends BaseND4JTest {
@Test

View File

@ -37,9 +37,11 @@ import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.shade.jackson.core.JsonFactory;
import org.nd4j.shade.jackson.databind.ObjectMapper;
@ -48,7 +50,8 @@ import java.util.ArrayList;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class TestSerialization extends BaseND4JTest {
@Test

View File

@ -30,9 +30,11 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.tests.tags.TagNames;
import java.util.ArrayList;
import java.util.Arrays;
@ -41,6 +43,8 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class TransformProcessRecordReaderTests extends BaseND4JTest {
@Test
@ -74,11 +78,11 @@ public class TransformProcessRecordReaderTests extends BaseND4JTest {
public void simpleTransformTestSequence() {
List<List<Writable>> sequence = new ArrayList<>();
//First window:
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L), new IntWritable(0),
sequence.add(Arrays.asList(new LongWritable(1451606400000L), new IntWritable(0),
new IntWritable(0)));
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 100L), new IntWritable(1),
sequence.add(Arrays.asList(new LongWritable(1451606400000L + 100L), new IntWritable(1),
new IntWritable(0)));
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 200L), new IntWritable(2),
sequence.add(Arrays.asList(new LongWritable(1451606400000L + 200L), new IntWritable(2),
new IntWritable(0)));
Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC)

View File

@ -26,6 +26,7 @@ import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.io.File;
@ -34,8 +35,11 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.TagNames;
@DisplayName("Csv Record Writer Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class CSVRecordWriterTest extends BaseND4JTest {
@BeforeEach

View File

@ -29,8 +29,10 @@ import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.io.ClassPathResource;
@ -46,6 +48,8 @@ import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.jupiter.api.Assertions.assertThrows;
@DisplayName("Lib Svm Record Writer Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class LibSvmRecordWriterTest extends BaseND4JTest {
@Test

View File

@ -26,8 +26,10 @@ import org.datavec.api.records.writer.impl.misc.SVMLightRecordWriter;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
import org.datavec.api.writable.*;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.io.ClassPathResource;
@ -43,6 +45,8 @@ import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.jupiter.api.Assertions.assertThrows;
@DisplayName("Svm Light Record Writer Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class SVMLightRecordWriterTest extends BaseND4JTest {
@Test

View File

@ -20,7 +20,9 @@
package org.datavec.api.split;
import org.junit.jupiter.api.Tag;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.shade.guava.io.Files;
import org.datavec.api.io.filters.BalancedPathFilter;
import org.datavec.api.io.filters.RandomPathFilter;
@ -42,6 +44,8 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
*
* @author saudet
*/
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class InputSplitTests extends BaseND4JTest {
@Test

View File

@ -20,13 +20,16 @@
package org.datavec.api.split;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import java.net.URI;
import static org.junit.jupiter.api.Assertions.*;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class NumberedFileInputSplitTests extends BaseND4JTest {
@Test
public void testNumberedFileInputSplitBasic() {

View File

@ -26,11 +26,13 @@ import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.function.Function;
import org.nd4j.common.tests.tags.TagNames;
import java.io.File;
import java.io.FileInputStream;
@ -46,7 +48,8 @@ import java.util.Random;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class TestStreamInputSplit extends BaseND4JTest {

View File

@ -19,6 +19,7 @@
*/
package org.datavec.api.split;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.net.URI;
@ -28,11 +29,14 @@ import static java.util.Arrays.asList;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.TagNames;
/**
* @author Ede Meijer
*/
@DisplayName("Transform Split Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class TransformSplitTest extends BaseND4JTest {
@Test

View File

@ -20,7 +20,9 @@
package org.datavec.api.split.parittion;
import org.junit.jupiter.api.Tag;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.shade.guava.io.Files;
import org.datavec.api.conf.Configuration;
import org.datavec.api.split.FileSplit;
@ -33,7 +35,8 @@ import java.io.File;
import java.io.OutputStream;
import static org.junit.jupiter.api.Assertions.*;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class PartitionerTests extends BaseND4JTest {
@Test
public void testRecordsPerFilePartition() {

View File

@ -29,13 +29,16 @@ import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import java.util.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class TestTransformProcess extends BaseND4JTest {
@Test

View File

@ -27,14 +27,17 @@ import org.datavec.api.transform.condition.string.StringRegexColumnCondition;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.transform.TestTransforms;
import org.datavec.api.writable.*;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import java.util.*;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class TestConditions extends BaseND4JTest {
@Test

View File

@ -27,8 +27,10 @@ import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import java.util.ArrayList;
import java.util.Arrays;
@ -38,7 +40,8 @@ import java.util.List;
import static java.util.Arrays.asList;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class TestFilters extends BaseND4JTest {

View File

@ -26,9 +26,11 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.NullWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import java.nio.file.Path;
import java.util.ArrayList;
@ -37,7 +39,8 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class TestJoin extends BaseND4JTest {
@Test

View File

@ -33,14 +33,17 @@ import org.datavec.api.transform.ops.IAggregableReduceOp;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.*;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import java.util.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class TestMultiOpReduce extends BaseND4JTest {
@Test

View File

@ -24,14 +24,17 @@ import org.datavec.api.transform.ops.IAggregableReduceOp;
import org.datavec.api.transform.reduce.impl.GeographicMidpointReduction;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import java.util.Arrays;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class TestReductions extends BaseND4JTest {
@Test

View File

@ -22,11 +22,15 @@ package org.datavec.api.transform.schema;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.joda.time.DateTimeZone;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JACKSON_SERDE)
public class TestJsonYaml extends BaseND4JTest {
@Test

View File

@ -21,11 +21,14 @@
package org.datavec.api.transform.schema;
import org.datavec.api.transform.ColumnType;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class TestSchemaMethods extends BaseND4JTest {
@Test

View File

@ -33,8 +33,10 @@ import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.NullWritable;
import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import java.util.ArrayList;
import java.util.Arrays;
@ -42,7 +44,8 @@ import java.util.List;
import java.util.concurrent.TimeUnit;
import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class TestReduceSequenceByWindowFunction extends BaseND4JTest {
@Test

View File

@ -27,8 +27,10 @@ import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import java.util.ArrayList;
import java.util.Arrays;
@ -36,7 +38,8 @@ import java.util.List;
import java.util.concurrent.TimeUnit;
import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class TestSequenceSplit extends BaseND4JTest {
@Test
@ -46,13 +49,13 @@ public class TestSequenceSplit extends BaseND4JTest {
.build();
List<List<Writable>> inputSequence = new ArrayList<>();
inputSequence.add(Arrays.asList((Writable) new LongWritable(0), new Text("t0")));
inputSequence.add(Arrays.asList((Writable) new LongWritable(1000), new Text("t1")));
inputSequence.add(Arrays.asList(new LongWritable(0), new Text("t0")));
inputSequence.add(Arrays.asList(new LongWritable(1000), new Text("t1")));
//Second split: 74 seconds later
inputSequence.add(Arrays.asList((Writable) new LongWritable(75000), new Text("t2")));
inputSequence.add(Arrays.asList((Writable) new LongWritable(100000), new Text("t3")));
inputSequence.add(Arrays.asList(new LongWritable(75000), new Text("t2")));
inputSequence.add(Arrays.asList(new LongWritable(100000), new Text("t3")));
//Third split: 1 minute and 1 milliseconds later
inputSequence.add(Arrays.asList((Writable) new LongWritable(160001), new Text("t4")));
inputSequence.add(Arrays.asList(new LongWritable(160001), new Text("t4")));
SequenceSplit seqSplit = new SequenceSplitTimeSeparation("time", 1, TimeUnit.MINUTES);
seqSplit.setInputSchema(schema);
@ -61,13 +64,13 @@ public class TestSequenceSplit extends BaseND4JTest {
assertEquals(3, splits.size());
List<List<Writable>> exp0 = new ArrayList<>();
exp0.add(Arrays.asList((Writable) new LongWritable(0), new Text("t0")));
exp0.add(Arrays.asList((Writable) new LongWritable(1000), new Text("t1")));
exp0.add(Arrays.asList(new LongWritable(0), new Text("t0")));
exp0.add(Arrays.asList(new LongWritable(1000), new Text("t1")));
List<List<Writable>> exp1 = new ArrayList<>();
exp1.add(Arrays.asList((Writable) new LongWritable(75000), new Text("t2")));
exp1.add(Arrays.asList((Writable) new LongWritable(100000), new Text("t3")));
exp1.add(Arrays.asList(new LongWritable(75000), new Text("t2")));
exp1.add(Arrays.asList(new LongWritable(100000), new Text("t3")));
List<List<Writable>> exp2 = new ArrayList<>();
exp2.add(Arrays.asList((Writable) new LongWritable(160001), new Text("t4")));
exp2.add(Arrays.asList(new LongWritable(160001), new Text("t4")));
assertEquals(exp0, splits.get(0));
assertEquals(exp1, splits.get(1));

View File

@ -29,8 +29,10 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import java.util.ArrayList;
import java.util.Arrays;
@ -38,7 +40,8 @@ import java.util.List;
import java.util.concurrent.TimeUnit;
import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class TestWindowFunctions extends BaseND4JTest {
@Test
@ -49,15 +52,15 @@ public class TestWindowFunctions extends BaseND4JTest {
//Create some data.
List<List<Writable>> sequence = new ArrayList<>();
//First window:
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L), new IntWritable(0)));
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 100L), new IntWritable(1)));
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 200L), new IntWritable(2)));
sequence.add(Arrays.asList(new LongWritable(1451606400000L), new IntWritable(0)));
sequence.add(Arrays.asList(new LongWritable(1451606400000L + 100L), new IntWritable(1)));
sequence.add(Arrays.asList(new LongWritable(1451606400000L + 200L), new IntWritable(2)));
//Second window:
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 1000L), new IntWritable(3)));
sequence.add(Arrays.asList(new LongWritable(1451606400000L + 1000L), new IntWritable(3)));
//Third window: empty
//Fourth window:
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 3000L), new IntWritable(4)));
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 3100L), new IntWritable(5)));
sequence.add(Arrays.asList(new LongWritable(1451606400000L + 3000L), new IntWritable(4)));
sequence.add(Arrays.asList(new LongWritable(1451606400000L + 3100L), new IntWritable(5)));
Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC)
.addColumnInteger("intcolumn").build();
@ -100,15 +103,15 @@ public class TestWindowFunctions extends BaseND4JTest {
//Create some data.
List<List<Writable>> sequence = new ArrayList<>();
//First window:
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L), new IntWritable(0)));
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 100L), new IntWritable(1)));
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 200L), new IntWritable(2)));
sequence.add(Arrays.asList(new LongWritable(1451606400000L), new IntWritable(0)));
sequence.add(Arrays.asList(new LongWritable(1451606400000L + 100L), new IntWritable(1)));
sequence.add(Arrays.asList(new LongWritable(1451606400000L + 200L), new IntWritable(2)));
//Second window:
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 1000L), new IntWritable(3)));
sequence.add(Arrays.asList(new LongWritable(1451606400000L + 1000L), new IntWritable(3)));
//Third window: empty
//Fourth window:
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 3000L), new IntWritable(4)));
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 3100L), new IntWritable(5)));
sequence.add(Arrays.asList(new LongWritable(1451606400000L + 3000L), new IntWritable(4)));
sequence.add(Arrays.asList(new LongWritable(1451606400000L + 3100L), new IntWritable(5)));
Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC)
.addColumnInteger("intcolumn").build();
@ -150,15 +153,15 @@ public class TestWindowFunctions extends BaseND4JTest {
//Create some data.
List<List<Writable>> sequence = new ArrayList<>();
//First window:
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L), new IntWritable(0)));
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 100L), new IntWritable(1)));
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 200L), new IntWritable(2)));
sequence.add(Arrays.asList(new LongWritable(1451606400000L), new IntWritable(0)));
sequence.add(Arrays.asList(new LongWritable(1451606400000L + 100L), new IntWritable(1)));
sequence.add(Arrays.asList(new LongWritable(1451606400000L + 200L), new IntWritable(2)));
//Second window:
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 1000L), new IntWritable(3)));
sequence.add(Arrays.asList(new LongWritable(1451606400000L + 1000L), new IntWritable(3)));
//Third window: empty
//Fourth window:
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 3000L), new IntWritable(4)));
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 3100L), new IntWritable(5)));
sequence.add(Arrays.asList(new LongWritable(1451606400000L + 3000L), new IntWritable(4)));
sequence.add(Arrays.asList(new LongWritable(1451606400000L + 3100L), new IntWritable(5)));
Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC)
.addColumnInteger("intcolumn").build();
@ -188,13 +191,13 @@ public class TestWindowFunctions extends BaseND4JTest {
//Create some data.
List<List<Writable>> sequence = new ArrayList<>();
//First window:
sequence.add(Arrays.asList((Writable) new LongWritable(0), new IntWritable(0)));
sequence.add(Arrays.asList((Writable) new LongWritable(100), new IntWritable(1)));
sequence.add(Arrays.asList((Writable) new LongWritable(200), new IntWritable(2)));
sequence.add(Arrays.asList((Writable) new LongWritable(1000), new IntWritable(3)));
sequence.add(Arrays.asList((Writable) new LongWritable(1500), new IntWritable(4)));
sequence.add(Arrays.asList((Writable) new LongWritable(2000), new IntWritable(5)));
sequence.add(Arrays.asList((Writable) new LongWritable(5000), new IntWritable(7)));
sequence.add(Arrays.asList(new LongWritable(0), new IntWritable(0)));
sequence.add(Arrays.asList(new LongWritable(100), new IntWritable(1)));
sequence.add(Arrays.asList(new LongWritable(200), new IntWritable(2)));
sequence.add(Arrays.asList(new LongWritable(1000), new IntWritable(3)));
sequence.add(Arrays.asList(new LongWritable(1500), new IntWritable(4)));
sequence.add(Arrays.asList(new LongWritable(2000), new IntWritable(5)));
sequence.add(Arrays.asList(new LongWritable(5000), new IntWritable(7)));
Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC)
@ -207,32 +210,32 @@ public class TestWindowFunctions extends BaseND4JTest {
//First window: -1000 to 1000
List<List<Writable>> exp0 = new ArrayList<>();
exp0.add(Arrays.asList((Writable) new LongWritable(0), new IntWritable(0)));
exp0.add(Arrays.asList((Writable) new LongWritable(100), new IntWritable(1)));
exp0.add(Arrays.asList((Writable) new LongWritable(200), new IntWritable(2)));
exp0.add(Arrays.asList(new LongWritable(0), new IntWritable(0)));
exp0.add(Arrays.asList(new LongWritable(100), new IntWritable(1)));
exp0.add(Arrays.asList(new LongWritable(200), new IntWritable(2)));
//Second window: 0 to 2000
List<List<Writable>> exp1 = new ArrayList<>();
exp1.add(Arrays.asList((Writable) new LongWritable(0), new IntWritable(0)));
exp1.add(Arrays.asList((Writable) new LongWritable(100), new IntWritable(1)));
exp1.add(Arrays.asList((Writable) new LongWritable(200), new IntWritable(2)));
exp1.add(Arrays.asList((Writable) new LongWritable(1000), new IntWritable(3)));
exp1.add(Arrays.asList((Writable) new LongWritable(1500), new IntWritable(4)));
exp1.add(Arrays.asList(new LongWritable(0), new IntWritable(0)));
exp1.add(Arrays.asList(new LongWritable(100), new IntWritable(1)));
exp1.add(Arrays.asList(new LongWritable(200), new IntWritable(2)));
exp1.add(Arrays.asList(new LongWritable(1000), new IntWritable(3)));
exp1.add(Arrays.asList(new LongWritable(1500), new IntWritable(4)));
//Third window: 1000 to 3000
List<List<Writable>> exp2 = new ArrayList<>();
exp2.add(Arrays.asList((Writable) new LongWritable(1000), new IntWritable(3)));
exp2.add(Arrays.asList((Writable) new LongWritable(1500), new IntWritable(4)));
exp2.add(Arrays.asList((Writable) new LongWritable(2000), new IntWritable(5)));
exp2.add(Arrays.asList(new LongWritable(1000), new IntWritable(3)));
exp2.add(Arrays.asList(new LongWritable(1500), new IntWritable(4)));
exp2.add(Arrays.asList(new LongWritable(2000), new IntWritable(5)));
//Fourth window: 2000 to 4000
List<List<Writable>> exp3 = new ArrayList<>();
exp3.add(Arrays.asList((Writable) new LongWritable(2000), new IntWritable(5)));
exp3.add(Arrays.asList(new LongWritable(2000), new IntWritable(5)));
//Fifth window: 3000 to 5000
List<List<Writable>> exp4 = new ArrayList<>();
//Sixth window: 4000 to 6000
List<List<Writable>> exp5 = new ArrayList<>();
exp5.add(Arrays.asList((Writable) new LongWritable(5000), new IntWritable(7)));
exp5.add(Arrays.asList(new LongWritable(5000), new IntWritable(7)));
//Seventh window: 5000 to 7000
List<List<Writable>> exp6 = new ArrayList<>();
exp6.add(Arrays.asList((Writable) new LongWritable(5000), new IntWritable(7)));
exp6.add(Arrays.asList(new LongWritable(5000), new IntWritable(7)));
List<List<List<Writable>>> windowsExp = Arrays.asList(exp0, exp1, exp2, exp3, exp4, exp5, exp6);
@ -250,13 +253,13 @@ public class TestWindowFunctions extends BaseND4JTest {
//Create some data.
List<List<Writable>> sequence = new ArrayList<>();
//First window:
sequence.add(Arrays.asList((Writable) new LongWritable(0), new IntWritable(0)));
sequence.add(Arrays.asList((Writable) new LongWritable(100), new IntWritable(1)));
sequence.add(Arrays.asList((Writable) new LongWritable(200), new IntWritable(2)));
sequence.add(Arrays.asList((Writable) new LongWritable(1000), new IntWritable(3)));
sequence.add(Arrays.asList((Writable) new LongWritable(1500), new IntWritable(4)));
sequence.add(Arrays.asList((Writable) new LongWritable(2000), new IntWritable(5)));
sequence.add(Arrays.asList((Writable) new LongWritable(5000), new IntWritable(7)));
sequence.add(Arrays.asList(new LongWritable(0), new IntWritable(0)));
sequence.add(Arrays.asList(new LongWritable(100), new IntWritable(1)));
sequence.add(Arrays.asList(new LongWritable(200), new IntWritable(2)));
sequence.add(Arrays.asList(new LongWritable(1000), new IntWritable(3)));
sequence.add(Arrays.asList(new LongWritable(1500), new IntWritable(4)));
sequence.add(Arrays.asList(new LongWritable(2000), new IntWritable(5)));
sequence.add(Arrays.asList(new LongWritable(5000), new IntWritable(7)));
Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC)
@ -272,31 +275,31 @@ public class TestWindowFunctions extends BaseND4JTest {
//First window: -1000 to 1000
List<List<Writable>> exp0 = new ArrayList<>();
exp0.add(Arrays.asList((Writable) new LongWritable(0), new IntWritable(0)));
exp0.add(Arrays.asList((Writable) new LongWritable(100), new IntWritable(1)));
exp0.add(Arrays.asList((Writable) new LongWritable(200), new IntWritable(2)));
exp0.add(Arrays.asList(new LongWritable(0), new IntWritable(0)));
exp0.add(Arrays.asList(new LongWritable(100), new IntWritable(1)));
exp0.add(Arrays.asList(new LongWritable(200), new IntWritable(2)));
//Second window: 0 to 2000
List<List<Writable>> exp1 = new ArrayList<>();
exp1.add(Arrays.asList((Writable) new LongWritable(0), new IntWritable(0)));
exp1.add(Arrays.asList((Writable) new LongWritable(100), new IntWritable(1)));
exp1.add(Arrays.asList((Writable) new LongWritable(200), new IntWritable(2)));
exp1.add(Arrays.asList((Writable) new LongWritable(1000), new IntWritable(3)));
exp1.add(Arrays.asList((Writable) new LongWritable(1500), new IntWritable(4)));
exp1.add(Arrays.asList(new LongWritable(0), new IntWritable(0)));
exp1.add(Arrays.asList(new LongWritable(100), new IntWritable(1)));
exp1.add(Arrays.asList(new LongWritable(200), new IntWritable(2)));
exp1.add(Arrays.asList(new LongWritable(1000), new IntWritable(3)));
exp1.add(Arrays.asList(new LongWritable(1500), new IntWritable(4)));
//Third window: 1000 to 3000
List<List<Writable>> exp2 = new ArrayList<>();
exp2.add(Arrays.asList((Writable) new LongWritable(1000), new IntWritable(3)));
exp2.add(Arrays.asList((Writable) new LongWritable(1500), new IntWritable(4)));
exp2.add(Arrays.asList((Writable) new LongWritable(2000), new IntWritable(5)));
exp2.add(Arrays.asList(new LongWritable(1000), new IntWritable(3)));
exp2.add(Arrays.asList(new LongWritable(1500), new IntWritable(4)));
exp2.add(Arrays.asList(new LongWritable(2000), new IntWritable(5)));
//Fourth window: 2000 to 4000
List<List<Writable>> exp3 = new ArrayList<>();
exp3.add(Arrays.asList((Writable) new LongWritable(2000), new IntWritable(5)));
exp3.add(Arrays.asList(new LongWritable(2000), new IntWritable(5)));
//Fifth window: 3000 to 5000 -> Empty: excluded
//Sixth window: 4000 to 6000
List<List<Writable>> exp5 = new ArrayList<>();
exp5.add(Arrays.asList((Writable) new LongWritable(5000), new IntWritable(7)));
exp5.add(Arrays.asList(new LongWritable(5000), new IntWritable(7)));
//Seventh window: 5000 to 7000
List<List<Writable>> exp6 = new ArrayList<>();
exp6.add(Arrays.asList((Writable) new LongWritable(5000), new IntWritable(7)));
exp6.add(Arrays.asList(new LongWritable(5000), new IntWritable(7)));
List<List<List<Writable>>> windowsExp = Arrays.asList(exp0, exp1, exp2, exp3, exp5, exp6);

View File

@ -26,11 +26,17 @@ import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.serde.testClasses.CustomCondition;
import org.datavec.api.transform.serde.testClasses.CustomFilter;
import org.datavec.api.transform.serde.testClasses.CustomTransform;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JACKSON_SERDE)
@Tag(TagNames.CUSTOM_FUNCTIONALITY)
public class TestCustomTransformJsonYaml extends BaseND4JTest {
@Test

View File

@ -64,14 +64,19 @@ import org.datavec.api.transform.transform.time.TimeMathOpTransform;
import org.datavec.api.writable.comparator.DoubleWritableComparator;
import org.joda.time.DateTimeFieldType;
import org.joda.time.DateTimeZone;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import java.util.*;
import java.util.concurrent.TimeUnit;
import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JACKSON_SERDE)
public class TestYamlJsonSerde extends BaseND4JTest {
public static YamlSerializer y = new YamlSerializer();

View File

@ -24,22 +24,26 @@ import org.datavec.api.transform.StringReduceOp;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import java.util.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class TestReduce extends BaseND4JTest {
@Test
public void testReducerDouble() {
List<List<Writable>> inputs = new ArrayList<>();
inputs.add(Arrays.asList((Writable) new Text("1"), new Text("2")));
inputs.add(Arrays.asList((Writable) new Text("1"), new Text("2")));
inputs.add(Arrays.asList((Writable) new Text("1"), new Text("2")));
inputs.add(Arrays.asList(new Text("1"), new Text("2")));
inputs.add(Arrays.asList(new Text("1"), new Text("2")));
inputs.add(Arrays.asList(new Text("1"), new Text("2")));
Map<StringReduceOp, String> exp = new LinkedHashMap<>();
exp.put(StringReduceOp.MERGE, "12");

View File

@ -37,10 +37,12 @@ import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import java.io.File;
import java.nio.file.Path;
@ -49,7 +51,9 @@ import java.util.Arrays;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
@Tag(TagNames.UI)
public class TestUI extends BaseND4JTest {

View File

@ -20,6 +20,7 @@
package org.datavec.api.util;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.io.BufferedReader;
@ -33,8 +34,11 @@ import static org.hamcrest.core.AnyOf.anyOf;
import static org.hamcrest.core.IsEqual.equalTo;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.TagNames;
@DisplayName("Class Path Resource Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class ClassPathResourceTest extends BaseND4JTest {
// File sizes are reported slightly different on Linux vs. Windows

View File

@ -22,8 +22,10 @@ package org.datavec.api.util;
import org.datavec.api.timeseries.util.TimeSeriesWritableUtils;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.ArrayList;
import java.util.List;
@ -32,6 +34,8 @@ import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@DisplayName("Time Series Utils Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class TimeSeriesUtilsTest extends BaseND4JTest {
@Test

View File

@ -19,7 +19,9 @@
*/
package org.datavec.api.writable;
import org.junit.jupiter.api.Tag;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.shade.guava.collect.Lists;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.util.ndarray.RecordConverter;
@ -36,6 +38,8 @@ import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@DisplayName("Record Converter Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class RecordConverterTest extends BaseND4JTest {
@Test

View File

@ -21,15 +21,18 @@
package org.datavec.api.writable;
import org.datavec.api.transform.metadata.NDArrayMetaData;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.io.*;
import static org.junit.jupiter.api.Assertions.*;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class TestNDArrayWritableAndSerialization extends BaseND4JTest {
@Test

View File

@ -20,8 +20,10 @@
package org.datavec.api.writable;
import org.datavec.api.writable.batch.NDArrayRecordBatch;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -37,6 +39,8 @@ import org.junit.jupiter.api.DisplayName;
import static org.junit.jupiter.api.Assertions.*;
@DisplayName("Writable Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class WritableTest extends BaseND4JTest {
@Test

View File

@ -42,9 +42,11 @@ import org.datavec.api.writable.*;
import org.datavec.arrow.recordreader.ArrowRecordReader;
import org.datavec.arrow.recordreader.ArrowWritableRecordBatch;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair;
@ -62,6 +64,8 @@ import org.junit.jupiter.api.extension.ExtendWith;
@Slf4j
@DisplayName("Arrow Converter Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class ArrowConverterTest extends BaseND4JTest {
private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE);
@ -142,8 +146,8 @@ class ArrowConverterTest extends BaseND4JTest {
List<FieldVector> fieldVectorsBatch = ArrowConverter.toArrowColumnsString(bufferAllocator, schema.build(), batch);
List<List<Writable>> batchRecords = ArrowConverter.toArrowWritables(fieldVectorsBatch, schema.build());
List<List<Writable>> assertionBatch = new ArrayList<>();
assertionBatch.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(0)));
assertionBatch.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1)));
assertionBatch.add(Arrays.asList(new IntWritable(0), new IntWritable(0)));
assertionBatch.add(Arrays.asList(new IntWritable(1), new IntWritable(1)));
assertEquals(assertionBatch, batchRecords);
}
@ -156,11 +160,11 @@ class ArrowConverterTest extends BaseND4JTest {
schema.addColumnTime(String.valueOf(i), TimeZone.getDefault());
single.add(String.valueOf(i));
}
List<List<Writable>> input = Arrays.asList(Arrays.<Writable>asList(new LongWritable(0), new LongWritable(1)), Arrays.<Writable>asList(new LongWritable(2), new LongWritable(3)));
List<List<Writable>> input = Arrays.asList(Arrays.asList(new LongWritable(0), new LongWritable(1)), Arrays.<Writable>asList(new LongWritable(2), new LongWritable(3)));
List<FieldVector> fieldVector = ArrowConverter.toArrowColumns(bufferAllocator, schema.build(), input);
ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector, schema.build());
List<Writable> assertion = Arrays.<Writable>asList(new LongWritable(4), new LongWritable(5));
writableRecordBatch.set(1, Arrays.<Writable>asList(new LongWritable(4), new LongWritable(5)));
List<Writable> assertion = Arrays.asList(new LongWritable(4), new LongWritable(5));
writableRecordBatch.set(1, Arrays.asList(new LongWritable(4), new LongWritable(5)));
List<Writable> recordTest = writableRecordBatch.get(1);
assertEquals(assertion, recordTest);
}
@ -174,11 +178,11 @@ class ArrowConverterTest extends BaseND4JTest {
schema.addColumnInteger(String.valueOf(i));
single.add(String.valueOf(i));
}
List<List<Writable>> input = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(1)), Arrays.<Writable>asList(new IntWritable(2), new IntWritable(3)));
List<List<Writable>> input = Arrays.asList(Arrays.asList(new IntWritable(0), new IntWritable(1)), Arrays.<Writable>asList(new IntWritable(2), new IntWritable(3)));
List<FieldVector> fieldVector = ArrowConverter.toArrowColumns(bufferAllocator, schema.build(), input);
ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector, schema.build());
List<Writable> assertion = Arrays.<Writable>asList(new IntWritable(4), new IntWritable(5));
writableRecordBatch.set(1, Arrays.<Writable>asList(new IntWritable(4), new IntWritable(5)));
List<Writable> assertion = Arrays.asList(new IntWritable(4), new IntWritable(5));
writableRecordBatch.set(1, Arrays.asList(new IntWritable(4), new IntWritable(5)));
List<Writable> recordTest = writableRecordBatch.get(1);
assertEquals(assertion, recordTest);
}

View File

@ -33,6 +33,7 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable;
import org.datavec.arrow.recordreader.ArrowRecordReader;
import org.datavec.arrow.recordreader.ArrowRecordWriter;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.primitives.Triple;
@ -44,8 +45,11 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.TagNames;
@DisplayName("Record Mapper Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class RecordMapperTest extends BaseND4JTest {
@Test

View File

@ -30,8 +30,10 @@ import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.datavec.arrow.ArrowConverter;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import java.util.ArrayList;
import java.util.Arrays;
@ -39,7 +41,8 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest {
private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE);

View File

@ -25,6 +25,7 @@ import org.datavec.api.split.FileSplit;
import org.datavec.image.recordreader.ImageRecordReader;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.io.ClassPathResource;
@ -36,8 +37,12 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
@DisplayName("Label Generator Test")
@NativeTag
@Tag(TagNames.FILE_IO)
class LabelGeneratorTest {

View File

@ -23,7 +23,10 @@ package org.datavec.image.loader;
import org.apache.commons.io.FilenameUtils;
import org.datavec.api.records.reader.RecordReader;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.dataset.DataSet;
import java.io.File;
@ -39,6 +42,8 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
/**
*
*/
@NativeTag
@Tag(TagNames.FILE_IO)
public class LoaderTests {
private static void ensureDataAvailable(){
@ -81,7 +86,7 @@ public class LoaderTests {
String subDir = "cifar/cifar-10-batches-bin/data_batch_1.bin";
String path = FilenameUtils.concat(System.getProperty("user.home"), subDir);
byte[] fullDataExpected = new byte[3073];
FileInputStream inExpected = new FileInputStream(new File(path));
FileInputStream inExpected = new FileInputStream(path);
inExpected.read(fullDataExpected);
byte[] fullDataActual = new byte[3073];
@ -94,7 +99,7 @@ public class LoaderTests {
subDir = "cifar/cifar-10-batches-bin/test_batch.bin";
path = FilenameUtils.concat(System.getProperty("user.home"), subDir);
fullDataExpected = new byte[3073];
inExpected = new FileInputStream(new File(path));
inExpected = new FileInputStream(path);
inExpected.read(fullDataExpected);
fullDataActual = new byte[3073];

View File

@ -21,8 +21,11 @@
package org.datavec.image.loader;
import org.datavec.image.data.Image;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.resources.Resources;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.awt.image.BufferedImage;
@ -34,7 +37,8 @@ import java.util.Random;
import static org.junit.jupiter.api.Assertions.assertEquals;
@NativeTag
@Tag(TagNames.FILE_IO)
public class TestImageLoader {
private static long seed = 10;

View File

@ -31,10 +31,13 @@ import org.bytedeco.javacv.OpenCVFrameConverter;
import org.datavec.image.data.Image;
import org.datavec.image.data.ImageWritable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.resources.Resources;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -60,6 +63,8 @@ import static org.junit.jupiter.api.Assertions.fail;
* @author saudet
*/
@Slf4j
@NativeTag
@Tag(TagNames.FILE_IO)
public class TestNativeImageLoader {
static final long seed = 10;
static final Random rng = new Random(seed);

View File

@ -28,9 +28,12 @@ import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.datavec.image.loader.NativeImageLoader;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.loader.FileBatch;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.io.ClassPathResource;
import java.io.File;
@ -41,6 +44,8 @@ import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
@DisplayName("File Batch Record Reader Test")
@NativeTag
@Tag(TagNames.FILE_IO)
class FileBatchRecordReaderTest {
@TempDir

View File

@ -37,9 +37,12 @@ import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.datavec.api.writable.batch.NDArrayRecordBatch;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -54,7 +57,8 @@ import java.util.List;
import java.util.Random;
import static org.junit.jupiter.api.Assertions.*;
@NativeTag
@Tag(TagNames.FILE_IO)
public class TestImageRecordReader {

View File

@ -36,9 +36,12 @@ import org.datavec.image.transform.ImageTransform;
import org.datavec.image.transform.PipelineImageTransform;
import org.datavec.image.transform.ResizeImageTransform;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
@ -54,7 +57,8 @@ import java.util.Collections;
import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
@NativeTag
@Tag(TagNames.FILE_IO)
public class TestObjectDetectionRecordReader {

View File

@ -22,10 +22,13 @@ package org.datavec.image.recordreader.objdetect;
import org.datavec.image.recordreader.objdetect.impl.VocLabelProvider;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
import java.io.File;
import java.nio.file.Path;
@ -34,7 +37,8 @@ import java.util.Collections;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
public class TestVocLabelProvider {

View File

@ -20,6 +20,7 @@
package org.datavec.image.transform;
import org.datavec.image.data.ImageWritable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import java.io.IOException;
import java.util.List;
@ -28,8 +29,13 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
@DisplayName("Json Yaml Test")
@NativeTag
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JACKSON_SERDE)
class JsonYamlTest {
@Test

View File

@ -22,12 +22,17 @@ package org.datavec.image.transform;
import org.bytedeco.javacv.Frame;
import org.datavec.image.data.ImageWritable;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
@DisplayName("Resize Image Transform Test")
@NativeTag
@Tag(TagNames.FILE_IO)
class ResizeImageTransformTest {
@BeforeEach

View File

@ -24,6 +24,7 @@ import org.bytedeco.javacpp.indexer.UByteIndexer;
import org.bytedeco.javacv.CanvasFrame;
import org.bytedeco.javacv.Frame;
import org.bytedeco.javacv.OpenCVFrameConverter;
import org.junit.jupiter.api.Tag;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.primitives.Pair;
import org.datavec.image.data.ImageWritable;
@ -37,6 +38,8 @@ import java.util.List;
import java.util.Random;
import org.bytedeco.opencv.opencv_core.*;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
import static org.bytedeco.opencv.global.opencv_core.*;
import static org.bytedeco.opencv.global.opencv_imgproc.*;
@ -46,6 +49,8 @@ import static org.junit.jupiter.api.Assertions.*;
*
* @author saudet
*/
@NativeTag
@Tag(TagNames.FILE_IO)
public class TestImageTransform {
static final long seed = 10;
static final Random rng = new Random(seed);

View File

@ -22,6 +22,7 @@ package org.datavec.poi.excel;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.io.ClassPathResource;
import java.util.List;
@ -29,8 +30,12 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
@DisplayName("Excel Record Reader Test")
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
class ExcelRecordReaderTest {
@Test

View File

@ -26,6 +26,7 @@ import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.primitives.Triple;
@ -36,8 +37,11 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.TagNames;
@DisplayName("Excel Record Writer Test")
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
class ExcelRecordWriterTest {
@TempDir

View File

@ -47,17 +47,19 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.*;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.TagNames;
import static org.junit.jupiter.api.Assertions.assertThrows;
@DisplayName("Jdbc Record Reader Test")
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
class JDBCRecordReaderTest {
@TempDir

View File

@ -36,15 +36,18 @@ import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.tests.tags.TagNames;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
public class LocalTransformProcessRecordReaderTests {
@Test
@ -64,11 +67,11 @@ public class LocalTransformProcessRecordReaderTests {
public void simpleTransformTestSequence() {
List<List<Writable>> sequence = new ArrayList<>();
//First window:
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L), new IntWritable(0),
sequence.add(Arrays.asList(new LongWritable(1451606400000L), new IntWritable(0),
new IntWritable(0)));
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 100L), new IntWritable(1),
sequence.add(Arrays.asList(new LongWritable(1451606400000L + 100L), new IntWritable(1),
new IntWritable(0)));
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 200L), new IntWritable(2),
sequence.add(Arrays.asList(new LongWritable(1451606400000L + 200L), new IntWritable(2),
new IntWritable(0)));
Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC)

View File

@ -30,8 +30,10 @@ import org.datavec.api.util.ndarray.RecordConverter;
import org.datavec.api.writable.Writable;
import org.datavec.local.transforms.AnalyzeLocal;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.io.ClassPathResource;
@ -40,7 +42,8 @@ import java.util.ArrayList;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
public class TestAnalyzeLocal {

View File

@ -27,8 +27,10 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.tests.tags.TagNames;
import java.io.File;
import java.util.HashSet;
@ -38,7 +40,8 @@ import java.util.stream.Collectors;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
public class TestLineRecordReaderFunction {
@Test

View File

@ -25,7 +25,10 @@ import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.datavec.local.transforms.misc.NDArrayToWritablesFunction;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -34,7 +37,8 @@ import java.util.Arrays;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.FILE_IO)
@NativeTag
public class TestNDArrayToWritablesFunction {
@Test
@ -50,7 +54,7 @@ public class TestNDArrayToWritablesFunction {
@Test
public void testNDArrayToWritablesArray() throws Exception {
INDArray arr = Nd4j.arange(5);
List<Writable> expected = Arrays.asList((Writable) new NDArrayWritable(arr));
List<Writable> expected = Arrays.asList(new NDArrayWritable(arr));
List<Writable> actual = new NDArrayToWritablesFunction(true).apply(arr);
assertEquals(expected, actual);
}

View File

@ -25,7 +25,10 @@ import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.datavec.local.transforms.misc.WritablesToNDArrayFunction;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -34,7 +37,8 @@ import java.util.ArrayList;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.FILE_IO)
@NativeTag
public class TestWritablesToNDArrayFunction {
@Test

View File

@ -30,13 +30,16 @@ import org.datavec.api.writable.Writable;
import org.datavec.local.transforms.misc.SequenceWritablesToStringFunction;
import org.datavec.local.transforms.misc.WritablesToStringFunction;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.TagNames;
import java.util.Arrays;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
public class TestWritablesToStringFunctions {
@ -44,7 +47,7 @@ public class TestWritablesToStringFunctions {
@Test
public void testWritablesToString() throws Exception {
List<Writable> l = Arrays.<Writable>asList(new DoubleWritable(1.5), new Text("someValue"));
List<Writable> l = Arrays.asList(new DoubleWritable(1.5), new Text("someValue"));
String expected = l.get(0).toString() + "," + l.get(1).toString();
assertEquals(expected, new WritablesToStringFunction(",").apply(l));
@ -53,8 +56,8 @@ public class TestWritablesToStringFunctions {
@Test
public void testSequenceWritablesToString() throws Exception {
List<List<Writable>> l = Arrays.asList(Arrays.<Writable>asList(new DoubleWritable(1.5), new Text("someValue")),
Arrays.<Writable>asList(new DoubleWritable(2.5), new Text("otherValue")));
List<List<Writable>> l = Arrays.asList(Arrays.asList(new DoubleWritable(1.5), new Text("someValue")),
Arrays.asList(new DoubleWritable(2.5), new Text("otherValue")));
String expected = l.get(0).get(0).toString() + "," + l.get(0).get(1).toString() + "\n"
+ l.get(1).get(0).toString() + "," + l.get(1).get(1).toString();

View File

@ -31,7 +31,10 @@ import org.datavec.api.transform.schema.SequenceSchema;
import org.datavec.api.writable.*;
import org.datavec.local.transforms.LocalTransformExecutor;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
@ -42,6 +45,8 @@ import static java.time.Duration.ofMillis;
import static org.junit.jupiter.api.Assertions.assertTimeout;
@DisplayName("Execution Test")
@Tag(TagNames.FILE_IO)
@NativeTag
class ExecutionTest {
@Test
@ -71,18 +76,12 @@ class ExecutionTest {
Schema schema = new Schema.Builder().addColumnInteger("col0").addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").addColumnFloat("col3").build();
TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1").doubleMathOp("col2", MathOp.Add, 10.0).floatMathOp("col3", MathOp.Add, 5f).build();
List<List<Writable>> inputData = new ArrayList<>();
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1), new FloatWritable(0.3f)));
inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1), new FloatWritable(1.7f)));
inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1), new FloatWritable(3.6f)));
inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1), new FloatWritable(0.3f)));
inputData.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1), new FloatWritable(1.7f)));
inputData.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1), new FloatWritable(3.6f)));
List<List<Writable>> rdd = (inputData);
List<List<Writable>> out = new ArrayList<>(LocalTransformExecutor.execute(rdd, tp));
Collections.sort(out, new Comparator<List<Writable>>() {
@Override
public int compare(List<Writable> o1, List<Writable> o2) {
return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
}
});
Collections.sort(out, (o1, o2) -> Integer.compare(o1.get(0).toInt(), o2.get(0).toInt()));
List<List<Writable>> expected = new ArrayList<>();
expected.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1), new FloatWritable(5.3f)));
expected.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1), new FloatWritable(6.7f)));
@ -95,9 +94,9 @@ class ExecutionTest {
void testFilter() {
Schema filterSchema = new Schema.Builder().addColumnDouble("col1").addColumnDouble("col2").addColumnDouble("col3").build();
List<List<Writable>> inputData = new ArrayList<>();
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new DoubleWritable(1), new DoubleWritable(0.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(1), new DoubleWritable(3), new DoubleWritable(1.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(2), new DoubleWritable(3), new DoubleWritable(2.1)));
inputData.add(Arrays.asList(new IntWritable(0), new DoubleWritable(1), new DoubleWritable(0.1)));
inputData.add(Arrays.asList(new IntWritable(1), new DoubleWritable(3), new DoubleWritable(1.1)));
inputData.add(Arrays.asList(new IntWritable(2), new DoubleWritable(3), new DoubleWritable(2.1)));
TransformProcess transformProcess = new TransformProcess.Builder(filterSchema).filter(new DoubleColumnCondition("col1", ConditionOp.LessThan, 1)).build();
List<List<Writable>> execute = LocalTransformExecutor.execute(inputData, transformProcess);
assertEquals(2, execute.size());
@ -110,31 +109,25 @@ class ExecutionTest {
TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1").doubleMathOp("col2", MathOp.Add, 10.0).build();
List<List<List<Writable>>> inputSequences = new ArrayList<>();
List<List<Writable>> seq1 = new ArrayList<>();
seq1.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
seq1.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
seq1.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
seq1.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
seq1.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
seq1.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
List<List<Writable>> seq2 = new ArrayList<>();
seq2.add(Arrays.<Writable>asList(new IntWritable(3), new Text("state0"), new DoubleWritable(3.1)));
seq2.add(Arrays.<Writable>asList(new IntWritable(4), new Text("state1"), new DoubleWritable(4.1)));
seq2.add(Arrays.asList(new IntWritable(3), new Text("state0"), new DoubleWritable(3.1)));
seq2.add(Arrays.asList(new IntWritable(4), new Text("state1"), new DoubleWritable(4.1)));
inputSequences.add(seq1);
inputSequences.add(seq2);
List<List<List<Writable>>> rdd = (inputSequences);
List<List<List<Writable>>> out = LocalTransformExecutor.executeSequenceToSequence(rdd, tp);
Collections.sort(out, new Comparator<List<List<Writable>>>() {
@Override
public int compare(List<List<Writable>> o1, List<List<Writable>> o2) {
return -Integer.compare(o1.size(), o2.size());
}
});
Collections.sort(out, (o1, o2) -> -Integer.compare(o1.size(), o2.size()));
List<List<List<Writable>>> expectedSequence = new ArrayList<>();
List<List<Writable>> seq1e = new ArrayList<>();
seq1e.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1)));
seq1e.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1)));
seq1e.add(Arrays.<Writable>asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1)));
seq1e.add(Arrays.asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1)));
seq1e.add(Arrays.asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1)));
seq1e.add(Arrays.asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1)));
List<List<Writable>> seq2e = new ArrayList<>();
seq2e.add(Arrays.<Writable>asList(new IntWritable(3), new IntWritable(0), new DoubleWritable(13.1)));
seq2e.add(Arrays.<Writable>asList(new IntWritable(4), new IntWritable(1), new DoubleWritable(14.1)));
seq2e.add(Arrays.asList(new IntWritable(3), new IntWritable(0), new DoubleWritable(13.1)));
seq2e.add(Arrays.asList(new IntWritable(4), new IntWritable(1), new DoubleWritable(14.1)));
expectedSequence.add(seq1e);
expectedSequence.add(seq2e);
assertEquals(expectedSequence, out);
@ -143,26 +136,26 @@ class ExecutionTest {
@Test
@DisplayName("Test Reduction Global")
void testReductionGlobal() {
List<List<Writable>> in = Arrays.asList(Arrays.<Writable>asList(new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new Text("second"), new DoubleWritable(5.0)));
List<List<Writable>> in = Arrays.asList(Arrays.asList(new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new Text("second"), new DoubleWritable(5.0)));
List<List<Writable>> inData = in;
Schema s = new Schema.Builder().addColumnString("textCol").addColumnDouble("doubleCol").build();
TransformProcess tp = new TransformProcess.Builder(s).reduce(new Reducer.Builder(ReduceOp.TakeFirst).takeFirstColumns("textCol").meanColumns("doubleCol").build()).build();
List<List<Writable>> outRdd = LocalTransformExecutor.execute(inData, tp);
List<List<Writable>> out = outRdd;
List<List<Writable>> expOut = Collections.singletonList(Arrays.<Writable>asList(new Text("first"), new DoubleWritable(4.0)));
List<List<Writable>> expOut = Collections.singletonList(Arrays.asList(new Text("first"), new DoubleWritable(4.0)));
assertEquals(expOut, out);
}
@Test
@DisplayName("Test Reduction By Key")
void testReductionByKey() {
List<List<Writable>> in = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new IntWritable(0), new Text("second"), new DoubleWritable(5.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(30.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("s"), new DoubleWritable(50.0)));
List<List<Writable>> in = Arrays.asList(Arrays.asList(new IntWritable(0), new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new IntWritable(0), new Text("second"), new DoubleWritable(5.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(30.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("s"), new DoubleWritable(50.0)));
List<List<Writable>> inData = in;
Schema s = new Schema.Builder().addColumnInteger("intCol").addColumnString("textCol").addColumnDouble("doubleCol").build();
TransformProcess tp = new TransformProcess.Builder(s).reduce(new Reducer.Builder(ReduceOp.TakeFirst).keyColumns("intCol").takeFirstColumns("textCol").meanColumns("doubleCol").build()).build();
List<List<Writable>> outRdd = LocalTransformExecutor.execute(inData, tp);
List<List<Writable>> out = outRdd;
List<List<Writable>> expOut = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0)));
List<List<Writable>> expOut = Arrays.asList(Arrays.asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0)));
out = new ArrayList<>(out);
Collections.sort(out, Comparator.comparingInt(o -> o.get(0).toInt()));
assertEquals(expOut, out);

View File

@ -28,12 +28,15 @@ import org.datavec.api.writable.*;
import org.datavec.local.transforms.LocalTransformExecutor;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.TagNames;
import java.util.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
public class TestJoin {
@Test
@ -46,27 +49,27 @@ public class TestJoin {
.addColumnDouble("amount").build();
List<List<Writable>> infoList = new ArrayList<>();
infoList.add(Arrays.<Writable>asList(new LongWritable(12345), new Text("Customer12345")));
infoList.add(Arrays.<Writable>asList(new LongWritable(98765), new Text("Customer98765")));
infoList.add(Arrays.<Writable>asList(new LongWritable(50000), new Text("Customer50000")));
infoList.add(Arrays.asList(new LongWritable(12345), new Text("Customer12345")));
infoList.add(Arrays.asList(new LongWritable(98765), new Text("Customer98765")));
infoList.add(Arrays.asList(new LongWritable(50000), new Text("Customer50000")));
List<List<Writable>> purchaseList = new ArrayList<>();
purchaseList.add(Arrays.<Writable>asList(new LongWritable(1000000), new LongWritable(12345),
purchaseList.add(Arrays.asList(new LongWritable(1000000), new LongWritable(12345),
new DoubleWritable(10.00)));
purchaseList.add(Arrays.<Writable>asList(new LongWritable(1000001), new LongWritable(12345),
purchaseList.add(Arrays.asList(new LongWritable(1000001), new LongWritable(12345),
new DoubleWritable(20.00)));
purchaseList.add(Arrays.<Writable>asList(new LongWritable(1000002), new LongWritable(98765),
purchaseList.add(Arrays.asList(new LongWritable(1000002), new LongWritable(98765),
new DoubleWritable(30.00)));
Join join = new Join.Builder(Join.JoinType.RightOuter).setJoinColumns("customerID")
.setSchemas(customerInfoSchema, purchasesSchema).build();
List<List<Writable>> expected = new ArrayList<>();
expected.add(Arrays.<Writable>asList(new LongWritable(12345), new Text("Customer12345"),
expected.add(Arrays.asList(new LongWritable(12345), new Text("Customer12345"),
new LongWritable(1000000), new DoubleWritable(10.00)));
expected.add(Arrays.<Writable>asList(new LongWritable(12345), new Text("Customer12345"),
expected.add(Arrays.asList(new LongWritable(12345), new Text("Customer12345"),
new LongWritable(1000001), new DoubleWritable(20.00)));
expected.add(Arrays.<Writable>asList(new LongWritable(98765), new Text("Customer98765"),
expected.add(Arrays.asList(new LongWritable(98765), new Text("Customer98765"),
new LongWritable(1000002), new DoubleWritable(30.00)));
@ -77,12 +80,7 @@ public class TestJoin {
List<List<Writable>> joined = LocalTransformExecutor.executeJoin(join, info, purchases);
List<List<Writable>> joinedList = new ArrayList<>(joined);
//Sort by order ID (column 3, index 2)
Collections.sort(joinedList, new Comparator<List<Writable>>() {
@Override
public int compare(List<Writable> o1, List<Writable> o2) {
return Long.compare(o1.get(2).toLong(), o2.get(2).toLong());
}
});
Collections.sort(joinedList, (o1, o2) -> Long.compare(o1.get(2).toLong(), o2.get(2).toLong()));
assertEquals(expected, joinedList);
assertEquals(3, joinedList.size());
@ -110,12 +108,7 @@ public class TestJoin {
List<List<Writable>> joined2 = LocalTransformExecutor.executeJoin(join2, purchases, info);
List<List<Writable>> joinedList2 = new ArrayList<>(joined2);
//Sort by order ID (column 0)
Collections.sort(joinedList2, new Comparator<List<Writable>>() {
@Override
public int compare(List<Writable> o1, List<Writable> o2) {
return Long.compare(o1.get(0).toLong(), o2.get(0).toLong());
}
});
Collections.sort(joinedList2, (o1, o2) -> Long.compare(o1.get(0).toLong(), o2.get(0).toLong()));
assertEquals(3, joinedList2.size());
assertEquals(expectedManyToOne, joinedList2);
@ -189,29 +182,26 @@ public class TestJoin {
new ArrayList<>(LocalTransformExecutor.executeJoin(join, firstRDD, secondRDD));
//Sort output by column 0, then column 1, then column 2 for comparison to expected...
Collections.sort(out, new Comparator<List<Writable>>() {
@Override
public int compare(List<Writable> o1, List<Writable> o2) {
Writable w1 = o1.get(0);
Writable w2 = o2.get(0);
if (w1 instanceof NullWritable)
return 1;
else if (w2 instanceof NullWritable)
return -1;
int c = Long.compare(w1.toLong(), w2.toLong());
if (c != 0)
return c;
c = o1.get(1).toString().compareTo(o2.get(1).toString());
if (c != 0)
return c;
w1 = o1.get(2);
w2 = o2.get(2);
if (w1 instanceof NullWritable)
return 1;
else if (w2 instanceof NullWritable)
return -1;
return Long.compare(w1.toLong(), w2.toLong());
}
Collections.sort(out, (o1, o2) -> {
Writable w1 = o1.get(0);
Writable w2 = o2.get(0);
if (w1 instanceof NullWritable)
return 1;
else if (w2 instanceof NullWritable)
return -1;
int c = Long.compare(w1.toLong(), w2.toLong());
if (c != 0)
return c;
c = o1.get(1).toString().compareTo(o2.get(1).toString());
if (c != 0)
return c;
w1 = o1.get(2);
w2 = o2.get(2);
if (w1 instanceof NullWritable)
return 1;
else if (w2 instanceof NullWritable)
return -1;
return Long.compare(w1.toLong(), w2.toLong());
});
switch (jt) {

View File

@ -31,14 +31,17 @@ import org.datavec.api.writable.comparator.DoubleWritableComparator;
import org.datavec.local.transforms.LocalTransformExecutor;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.TagNames;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
public class TestCalculateSortedRank {
@Test

View File

@ -31,7 +31,9 @@ import org.datavec.api.writable.Writable;
import org.datavec.arrow.recordreader.ArrowWritableRecordTimeSeriesBatch;
import org.datavec.local.transforms.LocalTransformExecutor;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.TagNames;
import java.util.Arrays;
import java.util.Collections;
@ -39,7 +41,8 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
public class TestConvertToSequence {
@Test
@ -48,12 +51,12 @@ public class TestConvertToSequence {
Schema s = new Schema.Builder().addColumnsString("key1", "key2").addColumnLong("time").build();
List<List<Writable>> allExamples =
Arrays.asList(Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), new LongWritable(10)),
Arrays.<Writable>asList(new Text("k1b"), new Text("k2b"), new LongWritable(10)),
Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"),
Arrays.asList(Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(10)),
Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(10)),
Arrays.asList(new Text("k1a"), new Text("k2a"),
new LongWritable(-10)),
Arrays.<Writable>asList(new Text("k1b"), new Text("k2b"), new LongWritable(5)),
Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), new LongWritable(0)));
Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(5)),
Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(0)));
TransformProcess tp = new TransformProcess.Builder(s)
.convertToSequence(Arrays.asList("key1", "key2"), new NumericalColumnComparator("time"))
@ -75,13 +78,13 @@ public class TestConvertToSequence {
}
List<List<Writable>> expSeq0 = Arrays.asList(
Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), new LongWritable(-10)),
Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), new LongWritable(0)),
Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), new LongWritable(10)));
Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(-10)),
Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(0)),
Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(10)));
List<List<Writable>> expSeq1 = Arrays.asList(
Arrays.<Writable>asList(new Text("k1b"), new Text("k2b"), new LongWritable(5)),
Arrays.<Writable>asList(new Text("k1b"), new Text("k2b"), new LongWritable(10)));
Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(5)),
Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(10)));
assertEquals(expSeq0, seq0);
assertEquals(expSeq1, seq1);
@ -96,9 +99,9 @@ public class TestConvertToSequence {
.build();
List<List<Writable>> allExamples = Arrays.asList(
Arrays.<Writable>asList(new Text("a"), new LongWritable(0)),
Arrays.<Writable>asList(new Text("b"), new LongWritable(1)),
Arrays.<Writable>asList(new Text("c"), new LongWritable(2)));
Arrays.asList(new Text("a"), new LongWritable(0)),
Arrays.asList(new Text("b"), new LongWritable(1)),
Arrays.asList(new Text("c"), new LongWritable(2)));
TransformProcess tp = new TransformProcess.Builder(s)
.convertToSequence()

View File

@ -25,8 +25,10 @@ import org.apache.spark.serializer.SerializerInstance;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.tests.tags.TagNames;
import java.io.File;
import java.nio.ByteBuffer;
@ -34,7 +36,10 @@ import java.nio.ByteBuffer;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.SPARK)
@Tag(TagNames.DIST_SYSTEMS)
public class TestKryoSerialization extends BaseSparkTest {
@Override

View File

@ -27,8 +27,10 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable;
import org.datavec.spark.BaseSparkTest;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.tests.tags.TagNames;
import java.io.File;
import java.util.HashSet;
@ -37,7 +39,10 @@ import java.util.Set;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.SPARK)
@Tag(TagNames.DIST_SYSTEMS)
public class TestLineRecordReaderFunction extends BaseSparkTest {
@Test

View File

@ -24,7 +24,10 @@ import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.datavec.spark.transform.misc.NDArrayToWritablesFunction;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -33,7 +36,10 @@ import java.util.Arrays;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.SPARK)
@Tag(TagNames.DIST_SYSTEMS)
@NativeTag
public class TestNDArrayToWritablesFunction {
@Test

View File

@ -39,10 +39,12 @@ import org.datavec.spark.functions.pairdata.PathToKeyConverter;
import org.datavec.spark.functions.pairdata.PathToKeyConverterFilename;
import org.datavec.spark.util.DataVecSparkUtil;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.tests.tags.TagNames;
import scala.Tuple2;
import java.io.File;
@ -53,7 +55,10 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.SPARK)
@Tag(TagNames.DIST_SYSTEMS)
public class TestPairSequenceRecordReaderBytesFunction extends BaseSparkTest {
@Test

View File

@ -37,10 +37,12 @@ import org.datavec.spark.BaseSparkTest;
import org.datavec.spark.functions.data.FilesAsBytesFunction;
import org.datavec.spark.functions.data.RecordReaderBytesFunction;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.tests.tags.TagNames;
import java.io.File;
import java.nio.file.Files;
@ -51,7 +53,10 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.SPARK)
@Tag(TagNames.DIST_SYSTEMS)
public class TestRecordReaderBytesFunction extends BaseSparkTest {

View File

@ -32,10 +32,12 @@ import org.datavec.api.writable.Writable;
import org.datavec.image.recordreader.ImageRecordReader;
import org.datavec.spark.BaseSparkTest;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.tests.tags.TagNames;
import java.io.File;
import java.nio.file.Path;
@ -45,7 +47,10 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.SPARK)
@Tag(TagNames.DIST_SYSTEMS)
public class TestRecordReaderFunction extends BaseSparkTest {
@Test

View File

@ -37,10 +37,12 @@ import org.datavec.spark.BaseSparkTest;
import org.datavec.spark.functions.data.FilesAsBytesFunction;
import org.datavec.spark.functions.data.SequenceRecordReaderBytesFunction;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.tests.tags.TagNames;
import java.io.File;
import java.nio.file.Files;
@ -50,7 +52,10 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.SPARK)
@Tag(TagNames.DIST_SYSTEMS)
public class TestSequenceRecordReaderBytesFunction extends BaseSparkTest {

View File

@ -34,10 +34,12 @@ import org.datavec.api.writable.Writable;
import org.datavec.codec.reader.CodecRecordReader;
import org.datavec.spark.BaseSparkTest;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.tests.tags.TagNames;
import java.io.File;
import java.nio.file.Path;
@ -46,7 +48,10 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.SPARK)
@Tag(TagNames.DIST_SYSTEMS)
public class TestSequenceRecordReaderFunction extends BaseSparkTest {

View File

@ -22,7 +22,10 @@ package org.datavec.spark.functions;
import org.datavec.api.writable.*;
import org.datavec.spark.transform.misc.WritablesToNDArrayFunction;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -31,7 +34,10 @@ import java.util.ArrayList;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.SPARK)
@Tag(TagNames.DIST_SYSTEMS)
@NativeTag
public class TestWritablesToNDArrayFunction {
@Test

View File

@ -29,7 +29,9 @@ import org.datavec.api.writable.Writable;
import org.datavec.spark.BaseSparkTest;
import org.datavec.spark.transform.misc.SequenceWritablesToStringFunction;
import org.datavec.spark.transform.misc.WritablesToStringFunction;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.TagNames;
import scala.Tuple2;
import java.util.ArrayList;
@ -37,7 +39,10 @@ import java.util.Arrays;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.SPARK)
@Tag(TagNames.DIST_SYSTEMS)
public class TestWritablesToStringFunctions extends BaseSparkTest {
@Test
@ -57,19 +62,9 @@ public class TestWritablesToStringFunctions extends BaseSparkTest {
JavaSparkContext sc = getContext();
JavaPairRDD<String, String> left = sc.parallelize(leftMap).mapToPair(new PairFunction<Tuple2<String, String>, String, String>() {
@Override
public Tuple2<String, String> call(Tuple2<String, String> stringStringTuple2) throws Exception {
return stringStringTuple2;
}
});
JavaPairRDD<String, String> left = sc.parallelize(leftMap).mapToPair((PairFunction<Tuple2<String, String>, String, String>) stringStringTuple2 -> stringStringTuple2);
JavaPairRDD<String, String> right = sc.parallelize(rightMap).mapToPair(new PairFunction<Tuple2<String, String>, String, String>() {
@Override
public Tuple2<String, String> call(Tuple2<String, String> stringStringTuple2) throws Exception {
return stringStringTuple2;
}
});
JavaPairRDD<String, String> right = sc.parallelize(rightMap).mapToPair((PairFunction<Tuple2<String, String>, String, String>) stringStringTuple2 -> stringStringTuple2);
System.out.println(left.cogroup(right).collect());
}
@ -77,7 +72,7 @@ public class TestWritablesToStringFunctions extends BaseSparkTest {
@Test
public void testWritablesToString() throws Exception {
List<Writable> l = Arrays.<Writable>asList(new DoubleWritable(1.5), new Text("someValue"));
List<Writable> l = Arrays.asList(new DoubleWritable(1.5), new Text("someValue"));
String expected = l.get(0).toString() + "," + l.get(1).toString();
assertEquals(expected, new WritablesToStringFunction(",").call(l));
@ -86,8 +81,8 @@ public class TestWritablesToStringFunctions extends BaseSparkTest {
@Test
public void testSequenceWritablesToString() throws Exception {
List<List<Writable>> l = Arrays.asList(Arrays.<Writable>asList(new DoubleWritable(1.5), new Text("someValue")),
Arrays.<Writable>asList(new DoubleWritable(2.5), new Text("otherValue")));
List<List<Writable>> l = Arrays.asList(Arrays.asList(new DoubleWritable(1.5), new Text("someValue")),
Arrays.asList(new DoubleWritable(2.5), new Text("otherValue")));
String expected = l.get(0).get(0).toString() + "," + l.get(0).get(1).toString() + "\n"
+ l.get(1).get(0).toString() + "," + l.get(1).get(1).toString();

View File

@ -21,6 +21,8 @@
package org.datavec.spark.storage;
import com.sun.jna.Platform;
import org.junit.jupiter.api.Tag;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.shade.guava.io.Files;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
@ -37,7 +39,10 @@ import java.util.Map;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.SPARK)
@Tag(TagNames.DIST_SYSTEMS)
public class TestSparkStorageUtils extends BaseSparkTest {
@Test
@ -46,11 +51,11 @@ public class TestSparkStorageUtils extends BaseSparkTest {
return;
}
List<List<Writable>> l = new ArrayList<>();
l.add(Arrays.<org.datavec.api.writable.Writable>asList(new Text("zero"), new IntWritable(0),
l.add(Arrays.asList(new Text("zero"), new IntWritable(0),
new DoubleWritable(0), new NDArrayWritable(Nd4j.valueArrayOf(10, 0.0))));
l.add(Arrays.<org.datavec.api.writable.Writable>asList(new Text("one"), new IntWritable(11),
l.add(Arrays.asList(new Text("one"), new IntWritable(11),
new DoubleWritable(11.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 11.0))));
l.add(Arrays.<org.datavec.api.writable.Writable>asList(new Text("two"), new IntWritable(22),
l.add(Arrays.asList(new Text("two"), new IntWritable(22),
new DoubleWritable(22.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 22.0))));
JavaRDD<List<Writable>> rdd = sc.parallelize(l);
@ -92,17 +97,17 @@ public class TestSparkStorageUtils extends BaseSparkTest {
}
List<List<List<Writable>>> l = new ArrayList<>();
l.add(Arrays.asList(
Arrays.<org.datavec.api.writable.Writable>asList(new Text("zero"), new IntWritable(0),
Arrays.asList(new Text("zero"), new IntWritable(0),
new DoubleWritable(0), new NDArrayWritable(Nd4j.valueArrayOf(10, 0.0))),
Arrays.<org.datavec.api.writable.Writable>asList(new Text("one"), new IntWritable(1),
Arrays.asList(new Text("one"), new IntWritable(1),
new DoubleWritable(1.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 1.0))),
Arrays.<org.datavec.api.writable.Writable>asList(new Text("two"), new IntWritable(2),
Arrays.asList(new Text("two"), new IntWritable(2),
new DoubleWritable(2.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 2.0)))));
l.add(Arrays.asList(
Arrays.<org.datavec.api.writable.Writable>asList(new Text("Bzero"), new IntWritable(10),
Arrays.asList(new Text("Bzero"), new IntWritable(10),
new DoubleWritable(10), new NDArrayWritable(Nd4j.valueArrayOf(10, 10.0))),
Arrays.<org.datavec.api.writable.Writable>asList(new Text("Bone"), new IntWritable(11),
Arrays.asList(new Text("Bone"), new IntWritable(11),
new DoubleWritable(11.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 11.0))),
Arrays.<org.datavec.api.writable.Writable>asList(new Text("Btwo"), new IntWritable(12),
new DoubleWritable(12.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 12.0)))));

View File

@ -30,14 +30,19 @@ import org.datavec.api.util.ndarray.RecordConverter;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Writable;
import org.datavec.spark.BaseSparkTest;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.SPARK)
@Tag(TagNames.DIST_SYSTEMS)
public class DataFramesTests extends BaseSparkTest {
@Test
@ -110,15 +115,15 @@ public class DataFramesTests extends BaseSparkTest {
public void testNormalize() {
List<List<Writable>> data = new ArrayList<>();
data.add(Arrays.<Writable>asList(new DoubleWritable(1), new DoubleWritable(10)));
data.add(Arrays.<Writable>asList(new DoubleWritable(2), new DoubleWritable(20)));
data.add(Arrays.<Writable>asList(new DoubleWritable(3), new DoubleWritable(30)));
data.add(Arrays.asList(new DoubleWritable(1), new DoubleWritable(10)));
data.add(Arrays.asList(new DoubleWritable(2), new DoubleWritable(20)));
data.add(Arrays.asList(new DoubleWritable(3), new DoubleWritable(30)));
List<List<Writable>> expMinMax = new ArrayList<>();
expMinMax.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new DoubleWritable(0.0)));
expMinMax.add(Arrays.<Writable>asList(new DoubleWritable(0.5), new DoubleWritable(0.5)));
expMinMax.add(Arrays.<Writable>asList(new DoubleWritable(1.0), new DoubleWritable(1.0)));
expMinMax.add(Arrays.asList(new DoubleWritable(0.0), new DoubleWritable(0.0)));
expMinMax.add(Arrays.asList(new DoubleWritable(0.5), new DoubleWritable(0.5)));
expMinMax.add(Arrays.asList(new DoubleWritable(1.0), new DoubleWritable(1.0)));
double m1 = (1 + 2 + 3) / 3.0;
double s1 = new StandardDeviation().evaluate(new double[] {1, 2, 3});
@ -127,11 +132,11 @@ public class DataFramesTests extends BaseSparkTest {
List<List<Writable>> expStandardize = new ArrayList<>();
expStandardize.add(
Arrays.<Writable>asList(new DoubleWritable((1 - m1) / s1), new DoubleWritable((10 - m2) / s2)));
Arrays.asList(new DoubleWritable((1 - m1) / s1), new DoubleWritable((10 - m2) / s2)));
expStandardize.add(
Arrays.<Writable>asList(new DoubleWritable((2 - m1) / s1), new DoubleWritable((20 - m2) / s2)));
Arrays.asList(new DoubleWritable((2 - m1) / s1), new DoubleWritable((20 - m2) / s2)));
expStandardize.add(
Arrays.<Writable>asList(new DoubleWritable((3 - m1) / s1), new DoubleWritable((30 - m2) / s2)));
Arrays.asList(new DoubleWritable((3 - m1) / s1), new DoubleWritable((30 - m2) / s2)));
JavaRDD<List<Writable>> rdd = sc.parallelize(data);
@ -178,13 +183,13 @@ public class DataFramesTests extends BaseSparkTest {
List<List<List<Writable>>> sequences = new ArrayList<>();
List<List<Writable>> seq1 = new ArrayList<>();
seq1.add(Arrays.<Writable>asList(new DoubleWritable(1), new DoubleWritable(10), new DoubleWritable(100)));
seq1.add(Arrays.<Writable>asList(new DoubleWritable(2), new DoubleWritable(20), new DoubleWritable(200)));
seq1.add(Arrays.<Writable>asList(new DoubleWritable(3), new DoubleWritable(30), new DoubleWritable(300)));
seq1.add(Arrays.asList(new DoubleWritable(1), new DoubleWritable(10), new DoubleWritable(100)));
seq1.add(Arrays.asList(new DoubleWritable(2), new DoubleWritable(20), new DoubleWritable(200)));
seq1.add(Arrays.asList(new DoubleWritable(3), new DoubleWritable(30), new DoubleWritable(300)));
List<List<Writable>> seq2 = new ArrayList<>();
seq2.add(Arrays.<Writable>asList(new DoubleWritable(4), new DoubleWritable(40), new DoubleWritable(400)));
seq2.add(Arrays.<Writable>asList(new DoubleWritable(5), new DoubleWritable(50), new DoubleWritable(500)));
seq2.add(Arrays.asList(new DoubleWritable(4), new DoubleWritable(40), new DoubleWritable(400)));
seq2.add(Arrays.asList(new DoubleWritable(5), new DoubleWritable(50), new DoubleWritable(500)));
sequences.add(seq1);
sequences.add(seq2);
@ -199,21 +204,21 @@ public class DataFramesTests extends BaseSparkTest {
//Min/max normalization:
List<List<Writable>> expSeq1MinMax = new ArrayList<>();
expSeq1MinMax.add(Arrays.<Writable>asList(new DoubleWritable((1 - 1.0) / (5.0 - 1.0)),
expSeq1MinMax.add(Arrays.asList(new DoubleWritable((1 - 1.0) / (5.0 - 1.0)),
new DoubleWritable((10 - 10.0) / (50.0 - 10.0)),
new DoubleWritable((100 - 100.0) / (500.0 - 100.0))));
expSeq1MinMax.add(Arrays.<Writable>asList(new DoubleWritable((2 - 1.0) / (5.0 - 1.0)),
expSeq1MinMax.add(Arrays.asList(new DoubleWritable((2 - 1.0) / (5.0 - 1.0)),
new DoubleWritable((20 - 10.0) / (50.0 - 10.0)),
new DoubleWritable((200 - 100.0) / (500.0 - 100.0))));
expSeq1MinMax.add(Arrays.<Writable>asList(new DoubleWritable((3 - 1.0) / (5.0 - 1.0)),
expSeq1MinMax.add(Arrays.asList(new DoubleWritable((3 - 1.0) / (5.0 - 1.0)),
new DoubleWritable((30 - 10.0) / (50.0 - 10.0)),
new DoubleWritable((300 - 100.0) / (500.0 - 100.0))));
List<List<Writable>> expSeq2MinMax = new ArrayList<>();
expSeq2MinMax.add(Arrays.<Writable>asList(new DoubleWritable((4 - 1.0) / (5.0 - 1.0)),
expSeq2MinMax.add(Arrays.asList(new DoubleWritable((4 - 1.0) / (5.0 - 1.0)),
new DoubleWritable((40 - 10.0) / (50.0 - 10.0)),
new DoubleWritable((400 - 100.0) / (500.0 - 100.0))));
expSeq2MinMax.add(Arrays.<Writable>asList(new DoubleWritable((5 - 1.0) / (5.0 - 1.0)),
expSeq2MinMax.add(Arrays.asList(new DoubleWritable((5 - 1.0) / (5.0 - 1.0)),
new DoubleWritable((50 - 10.0) / (50.0 - 10.0)),
new DoubleWritable((500 - 100.0) / (500.0 - 100.0))));
@ -246,17 +251,17 @@ public class DataFramesTests extends BaseSparkTest {
double s3 = new StandardDeviation().evaluate(new double[] {100, 200, 300, 400, 500});
List<List<Writable>> expSeq1Std = new ArrayList<>();
expSeq1Std.add(Arrays.<Writable>asList(new DoubleWritable((1 - m1) / s1), new DoubleWritable((10 - m2) / s2),
expSeq1Std.add(Arrays.asList(new DoubleWritable((1 - m1) / s1), new DoubleWritable((10 - m2) / s2),
new DoubleWritable((100 - m3) / s3)));
expSeq1Std.add(Arrays.<Writable>asList(new DoubleWritable((2 - m1) / s1), new DoubleWritable((20 - m2) / s2),
expSeq1Std.add(Arrays.asList(new DoubleWritable((2 - m1) / s1), new DoubleWritable((20 - m2) / s2),
new DoubleWritable((200 - m3) / s3)));
expSeq1Std.add(Arrays.<Writable>asList(new DoubleWritable((3 - m1) / s1), new DoubleWritable((30 - m2) / s2),
expSeq1Std.add(Arrays.asList(new DoubleWritable((3 - m1) / s1), new DoubleWritable((30 - m2) / s2),
new DoubleWritable((300 - m3) / s3)));
List<List<Writable>> expSeq2Std = new ArrayList<>();
expSeq2Std.add(Arrays.<Writable>asList(new DoubleWritable((4 - m1) / s1), new DoubleWritable((40 - m2) / s2),
expSeq2Std.add(Arrays.asList(new DoubleWritable((4 - m1) / s1), new DoubleWritable((40 - m2) / s2),
new DoubleWritable((400 - m3) / s3)));
expSeq2Std.add(Arrays.<Writable>asList(new DoubleWritable((5 - m1) / s1), new DoubleWritable((50 - m2) / s2),
expSeq2Std.add(Arrays.asList(new DoubleWritable((5 - m1) / s1), new DoubleWritable((50 - m2) / s2),
new DoubleWritable((500 - m3) / s3)));

View File

@ -31,22 +31,24 @@ import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.spark.BaseSparkTest;
import org.datavec.python.PythonTransform;
import org.datavec.spark.BaseSparkTest;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.TagNames;
import java.util.*;
import static java.time.Duration.ofMillis;
import static org.junit.jupiter.api.Assertions.assertTimeout;
import static org.junit.jupiter.api.Assertions.*;
@DisplayName("Execution Test")
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.SPARK)
@Tag(TagNames.DIST_SYSTEMS)
class ExecutionTest extends BaseSparkTest {
@Test
@ -55,22 +57,16 @@ class ExecutionTest extends BaseSparkTest {
Schema schema = new Schema.Builder().addColumnInteger("col0").addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").build();
TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1").doubleMathOp("col2", MathOp.Add, 10.0).build();
List<List<Writable>> inputData = new ArrayList<>();
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
inputData.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
inputData.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
JavaRDD<List<Writable>> rdd = sc.parallelize(inputData);
List<List<Writable>> out = new ArrayList<>(SparkTransformExecutor.execute(rdd, tp).collect());
Collections.sort(out, new Comparator<List<Writable>>() {
@Override
public int compare(List<Writable> o1, List<Writable> o2) {
return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
}
});
Collections.sort(out, Comparator.comparingInt(o -> o.get(0).toInt()));
List<List<Writable>> expected = new ArrayList<>();
expected.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1)));
expected.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1)));
expected.add(Arrays.<Writable>asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1)));
expected.add(Arrays.asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1)));
expected.add(Arrays.asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1)));
expected.add(Arrays.asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1)));
assertEquals(expected, out);
}
@ -81,31 +77,25 @@ class ExecutionTest extends BaseSparkTest {
TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1").doubleMathOp("col2", MathOp.Add, 10.0).build();
List<List<List<Writable>>> inputSequences = new ArrayList<>();
List<List<Writable>> seq1 = new ArrayList<>();
seq1.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
seq1.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
seq1.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
seq1.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
seq1.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
seq1.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
List<List<Writable>> seq2 = new ArrayList<>();
seq2.add(Arrays.<Writable>asList(new IntWritable(3), new Text("state0"), new DoubleWritable(3.1)));
seq2.add(Arrays.<Writable>asList(new IntWritable(4), new Text("state1"), new DoubleWritable(4.1)));
seq2.add(Arrays.asList(new IntWritable(4), new Text("state1"), new DoubleWritable(4.1)));
seq2.add(Arrays.asList(new IntWritable(3), new Text("state0"), new DoubleWritable(3.1)));
inputSequences.add(seq1);
inputSequences.add(seq2);
JavaRDD<List<List<Writable>>> rdd = sc.parallelize(inputSequences);
List<List<List<Writable>>> out = new ArrayList<>(SparkTransformExecutor.executeSequenceToSequence(rdd, tp).collect());
Collections.sort(out, new Comparator<List<List<Writable>>>() {
@Override
public int compare(List<List<Writable>> o1, List<List<Writable>> o2) {
return -Integer.compare(o1.size(), o2.size());
}
});
Collections.sort(out, (o1, o2) -> -Integer.compare(o1.size(), o2.size()));
List<List<List<Writable>>> expectedSequence = new ArrayList<>();
List<List<Writable>> seq1e = new ArrayList<>();
seq1e.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1)));
seq1e.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1)));
seq1e.add(Arrays.<Writable>asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1)));
seq1e.add(Arrays.asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1)));
seq1e.add(Arrays.asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1)));
seq1e.add(Arrays.asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1)));
List<List<Writable>> seq2e = new ArrayList<>();
seq2e.add(Arrays.<Writable>asList(new IntWritable(3), new IntWritable(0), new DoubleWritable(13.1)));
seq2e.add(Arrays.<Writable>asList(new IntWritable(4), new IntWritable(1), new DoubleWritable(14.1)));
seq2e.add(Arrays.asList(new IntWritable(3), new IntWritable(0), new DoubleWritable(13.1)));
seq2e.add(Arrays.asList(new IntWritable(4), new IntWritable(1), new DoubleWritable(14.1)));
expectedSequence.add(seq1e);
expectedSequence.add(seq2e);
assertEquals(expectedSequence, out);
@ -114,34 +104,28 @@ class ExecutionTest extends BaseSparkTest {
@Test
@DisplayName("Test Reduction Global")
void testReductionGlobal() {
List<List<Writable>> in = Arrays.asList(Arrays.<Writable>asList(new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new Text("second"), new DoubleWritable(5.0)));
List<List<Writable>> in = Arrays.asList(Arrays.asList(new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new Text("second"), new DoubleWritable(5.0)));
JavaRDD<List<Writable>> inData = sc.parallelize(in);
Schema s = new Schema.Builder().addColumnString("textCol").addColumnDouble("doubleCol").build();
TransformProcess tp = new TransformProcess.Builder(s).reduce(new Reducer.Builder(ReduceOp.TakeFirst).takeFirstColumns("textCol").meanColumns("doubleCol").build()).build();
JavaRDD<List<Writable>> outRdd = SparkTransformExecutor.execute(inData, tp);
List<List<Writable>> out = outRdd.collect();
List<List<Writable>> expOut = Collections.singletonList(Arrays.<Writable>asList(new Text("first"), new DoubleWritable(4.0)));
List<List<Writable>> expOut = Collections.singletonList(Arrays.asList(new Text("first"), new DoubleWritable(4.0)));
assertEquals(expOut, out);
}
@Test
@DisplayName("Test Reduction By Key")
void testReductionByKey() {
List<List<Writable>> in = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new IntWritable(0), new Text("second"), new DoubleWritable(5.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(30.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("s"), new DoubleWritable(50.0)));
List<List<Writable>> in = Arrays.asList(Arrays.asList(new IntWritable(0), new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new IntWritable(0), new Text("second"), new DoubleWritable(5.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(30.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("s"), new DoubleWritable(50.0)));
JavaRDD<List<Writable>> inData = sc.parallelize(in);
Schema s = new Schema.Builder().addColumnInteger("intCol").addColumnString("textCol").addColumnDouble("doubleCol").build();
TransformProcess tp = new TransformProcess.Builder(s).reduce(new Reducer.Builder(ReduceOp.TakeFirst).keyColumns("intCol").takeFirstColumns("textCol").meanColumns("doubleCol").build()).build();
JavaRDD<List<Writable>> outRdd = SparkTransformExecutor.execute(inData, tp);
List<List<Writable>> out = outRdd.collect();
List<List<Writable>> expOut = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0)));
List<List<Writable>> expOut = Arrays.asList(Arrays.asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0)));
out = new ArrayList<>(out);
Collections.sort(out, new Comparator<List<Writable>>() {
@Override
public int compare(List<Writable> o1, List<Writable> o2) {
return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
}
});
Collections.sort(out, (o1, o2) -> Integer.compare(o1.get(0).toInt(), o2.get(0).toInt()));
assertEquals(expOut, out);
}
@ -150,15 +134,15 @@ class ExecutionTest extends BaseSparkTest {
void testUniqueMultiCol() {
Schema schema = new Schema.Builder().addColumnInteger("col0").addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").build();
List<List<Writable>> inputData = new ArrayList<>();
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
inputData.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
inputData.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
inputData.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
inputData.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
inputData.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
inputData.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
JavaRDD<List<Writable>> rdd = sc.parallelize(inputData);
Map<String, List<Writable>> l = AnalyzeSpark.getUnique(Arrays.asList("col0", "col1"), schema, rdd);
assertEquals(2, l.size());
@ -180,58 +164,20 @@ class ExecutionTest extends BaseSparkTest {
String pythonCode = "col1 = ['state0', 'state1', 'state2'].index(col1)\ncol2 += 10.0";
TransformProcess tp = new TransformProcess.Builder(schema).transform(PythonTransform.builder().code("first = np.sin(first)\nsecond = np.cos(second)").outputSchema(finalSchema).build()).build();
List<List<Writable>> inputData = new ArrayList<>();
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
inputData.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
inputData.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
JavaRDD<List<Writable>> rdd = sc.parallelize(inputData);
List<List<Writable>> out = new ArrayList<>(SparkTransformExecutor.execute(rdd, tp).collect());
Collections.sort(out, new Comparator<List<Writable>>() {
@Override
public int compare(List<Writable> o1, List<Writable> o2) {
return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
}
});
Collections.sort(out, Comparator.comparingInt(o -> o.get(0).toInt()));
List<List<Writable>> expected = new ArrayList<>();
expected.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1)));
expected.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1)));
expected.add(Arrays.<Writable>asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1)));
expected.add(Arrays.asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1)));
expected.add(Arrays.asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1)));
expected.add(Arrays.asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1)));
assertEquals(expected, out);
});
}
@Test
@Disabled("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771")
@DisplayName("Test Python Execution With ND Arrays")
void testPythonExecutionWithNDArrays() {
assertTimeout(ofMillis(60000), () -> {
long[] shape = new long[] { 3, 2 };
Schema schema = new Schema.Builder().addColumnInteger("id").addColumnNDArray("col1", shape).addColumnNDArray("col2", shape).build();
Schema finalSchema = new Schema.Builder().addColumnInteger("id").addColumnNDArray("col1", shape).addColumnNDArray("col2", shape).addColumnNDArray("col3", shape).build();
String pythonCode = "col3 = col1 + col2";
TransformProcess tp = new TransformProcess.Builder(schema).transform(PythonTransform.builder().code("first = np.sin(first)\nsecond = np.cos(second)").outputSchema(schema).build()).build();
INDArray zeros = Nd4j.zeros(shape);
INDArray ones = Nd4j.ones(shape);
INDArray twos = ones.add(ones);
List<List<Writable>> inputData = new ArrayList<>();
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new NDArrayWritable(zeros), new NDArrayWritable(zeros)));
inputData.add(Arrays.<Writable>asList(new IntWritable(1), new NDArrayWritable(zeros), new NDArrayWritable(ones)));
inputData.add(Arrays.<Writable>asList(new IntWritable(2), new NDArrayWritable(ones), new NDArrayWritable(ones)));
JavaRDD<List<Writable>> rdd = sc.parallelize(inputData);
List<List<Writable>> out = new ArrayList<>(SparkTransformExecutor.execute(rdd, tp).collect());
Collections.sort(out, new Comparator<List<Writable>>() {
@Override
public int compare(List<Writable> o1, List<Writable> o2) {
return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
}
});
List<List<Writable>> expected = new ArrayList<>();
expected.add(Arrays.<Writable>asList(new IntWritable(0), new NDArrayWritable(zeros), new NDArrayWritable(zeros), new NDArrayWritable(zeros)));
expected.add(Arrays.<Writable>asList(new IntWritable(1), new NDArrayWritable(zeros), new NDArrayWritable(ones), new NDArrayWritable(ones)));
expected.add(Arrays.<Writable>asList(new IntWritable(2), new NDArrayWritable(ones), new NDArrayWritable(ones), new NDArrayWritable(twos)));
});
}
@Test
@DisplayName("Test First Digit Transform Benfords Law")

View File

@ -28,7 +28,10 @@ import org.datavec.api.util.ndarray.RecordConverter;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Writable;
import org.datavec.spark.BaseSparkTest;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
@ -43,7 +46,10 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.SPARK)
@Tag(TagNames.DIST_SYSTEMS)
@NativeTag
public class NormalizationTests extends BaseSparkTest {

View File

@ -38,7 +38,9 @@ import org.datavec.local.transforms.AnalyzeLocal;
import org.datavec.spark.BaseSparkTest;
import org.datavec.spark.transform.AnalyzeSpark;
import org.joda.time.DateTimeZone;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.io.ClassPathResource;
@ -48,7 +50,10 @@ import java.nio.file.Files;
import java.util.*;
import static org.junit.jupiter.api.Assertions.*;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.SPARK)
@Tag(TagNames.DIST_SYSTEMS)
public class TestAnalysis extends BaseSparkTest {
@Test

View File

@ -27,12 +27,17 @@ import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.*;
import org.datavec.spark.BaseSparkTest;
import org.datavec.spark.transform.SparkTransformExecutor;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.TagNames;
import java.util.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.SPARK)
@Tag(TagNames.DIST_SYSTEMS)
public class TestJoin extends BaseSparkTest {
@Test

View File

@ -30,24 +30,29 @@ import org.datavec.api.writable.Writable;
import org.datavec.api.writable.comparator.DoubleWritableComparator;
import org.datavec.spark.BaseSparkTest;
import org.datavec.spark.transform.SparkTransformExecutor;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.TagNames;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.SPARK)
@Tag(TagNames.DIST_SYSTEMS)
public class TestCalculateSortedRank extends BaseSparkTest {
@Test
public void testCalculateSortedRank() {
List<List<Writable>> data = new ArrayList<>();
data.add(Arrays.asList((Writable) new Text("0"), new DoubleWritable(0.0)));
data.add(Arrays.asList((Writable) new Text("3"), new DoubleWritable(0.3)));
data.add(Arrays.asList((Writable) new Text("2"), new DoubleWritable(0.2)));
data.add(Arrays.asList((Writable) new Text("1"), new DoubleWritable(0.1)));
data.add(Arrays.asList(new Text("0"), new DoubleWritable(0.0)));
data.add(Arrays.asList(new Text("3"), new DoubleWritable(0.3)));
data.add(Arrays.asList(new Text("2"), new DoubleWritable(0.2)));
data.add(Arrays.asList(new Text("1"), new DoubleWritable(0.1)));
JavaRDD<List<Writable>> rdd = sc.parallelize(data);

View File

@ -29,7 +29,9 @@ import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.datavec.spark.BaseSparkTest;
import org.datavec.spark.transform.SparkTransformExecutor;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.TagNames;
import java.util.Arrays;
import java.util.Collections;
@ -37,7 +39,10 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.SPARK)
@Tag(TagNames.DIST_SYSTEMS)
public class TestConvertToSequence extends BaseSparkTest {
@Test
@ -45,13 +50,13 @@ public class TestConvertToSequence extends BaseSparkTest {
Schema s = new Schema.Builder().addColumnsString("key1", "key2").addColumnLong("time").build();
List<List<Writable>> allExamples =
Arrays.asList(Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), new LongWritable(10)),
Arrays.<Writable>asList(new Text("k1b"), new Text("k2b"), new LongWritable(10)),
Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"),
new LongWritable(-10)),
Arrays.<Writable>asList(new Text("k1b"), new Text("k2b"), new LongWritable(5)),
Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), new LongWritable(0)));
List<List<Writable>> allExamples;
allExamples = Arrays.asList(Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(10)),
Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(10)),
Arrays.asList(new Text("k1a"), new Text("k2a"),
new LongWritable(-10)),
Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(5)),
Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(0)));
TransformProcess tp = new TransformProcess.Builder(s)
.convertToSequence(Arrays.asList("key1", "key2"), new NumericalColumnComparator("time"))
@ -73,13 +78,13 @@ public class TestConvertToSequence extends BaseSparkTest {
}
List<List<Writable>> expSeq0 = Arrays.asList(
Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), new LongWritable(-10)),
Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), new LongWritable(0)),
Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), new LongWritable(10)));
Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(-10)),
Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(0)),
Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(10)));
List<List<Writable>> expSeq1 = Arrays.asList(
Arrays.<Writable>asList(new Text("k1b"), new Text("k2b"), new LongWritable(5)),
Arrays.<Writable>asList(new Text("k1b"), new Text("k2b"), new LongWritable(10)));
Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(5)),
Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(10)));
assertEquals(expSeq0, seq0);
assertEquals(expSeq1, seq1);
@ -94,9 +99,9 @@ public class TestConvertToSequence extends BaseSparkTest {
.build();
List<List<Writable>> allExamples = Arrays.asList(
Arrays.<Writable>asList(new Text("a"), new LongWritable(0)),
Arrays.<Writable>asList(new Text("b"), new LongWritable(1)),
Arrays.<Writable>asList(new Text("c"), new LongWritable(2)));
Arrays.asList(new Text("a"), new LongWritable(0)),
Arrays.asList(new Text("b"), new LongWritable(1)),
Arrays.asList(new Text("c"), new LongWritable(2)));
TransformProcess tp = new TransformProcess.Builder(s)
.convertToSequence()

View File

@ -28,7 +28,9 @@ import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.datavec.spark.BaseSparkTest;
import org.datavec.spark.transform.utils.SparkUtils;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.TagNames;
import java.io.File;
import java.io.FileInputStream;
@ -38,6 +40,9 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.SPARK)
@Tag(TagNames.FILE_IO)
@Tag(TagNames.DIST_SYSTEMS)
public class TestSparkUtil extends BaseSparkTest {
@Test
@ -46,8 +51,8 @@ public class TestSparkUtil extends BaseSparkTest {
return;
}
List<List<Writable>> l = new ArrayList<>();
l.add(Arrays.<Writable>asList(new Text("abc"), new DoubleWritable(2.0), new IntWritable(-1)));
l.add(Arrays.<Writable>asList(new Text("def"), new DoubleWritable(4.0), new IntWritable(-2)));
l.add(Arrays.asList(new Text("abc"), new DoubleWritable(2.0), new IntWritable(-1)));
l.add(Arrays.asList(new Text("def"), new DoubleWritable(4.0), new IntWritable(-2)));
File f = File.createTempFile("testSparkUtil", "txt");
f.deleteOnExit();

View File

@ -27,8 +27,11 @@ import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.resources.Resources;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
@ -39,6 +42,8 @@ import java.nio.file.Files;
import java.util.concurrent.CountDownLatch;
@Disabled
@NativeTag
@Tag(TagNames.RNG)
public class RandomTests extends BaseDL4JTest {
@Test

View File

@ -23,11 +23,11 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.base.MnistFetcher;
import org.deeplearning4j.common.resources.DL4JResources;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.*;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
import org.nd4j.linalg.dataset.DataSet;
@ -41,10 +41,13 @@ import java.util.Set;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@DisplayName("Mnist Fetcher Test")
@NativeTag
@Tag(TagNames.FILE_IO)
@Tag(TagNames.NDARRAY_ETL)
class MnistFetcherTest extends BaseDL4JTest {

View File

@ -23,8 +23,14 @@ package org.deeplearning4j.datasets;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.fetchers.Cifar10Fetcher;
import org.deeplearning4j.datasets.fetchers.TinyImageNetFetcher;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
@NativeTag
@Tag(TagNames.FILE_IO)
@Tag(TagNames.NDARRAY_ETL)
public class TestDataSets extends BaseDL4JTest {
@Override

View File

@ -20,8 +20,11 @@
package org.deeplearning4j.datasets.datavec;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.shade.guava.io.Files;
import lombok.extern.slf4j.Slf4j;
@ -76,6 +79,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows;
@Slf4j
@DisplayName("Record Reader Data Setiterator Test")
@Disabled
@NativeTag
class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
@Override
@ -148,6 +152,7 @@ class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@DisplayName("Test Sequence Record Reader")
@Tag(TagNames.NDARRAY_INDEXING)
void testSequenceRecordReader(Nd4jBackend backend) throws Exception {
File rootDir = temporaryFolder.toFile();
// need to manually extract

View File

@ -21,6 +21,8 @@ package org.deeplearning4j.datasets.datavec;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Tag;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.shade.guava.io.Files;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
@ -70,6 +72,7 @@ import org.junit.jupiter.api.extension.ExtendWith;
@DisplayName("Record Reader Multi Data Set Iterator Test")
@Disabled
@Tag(TagNames.FILE_IO)
class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
@TempDir

View File

@ -21,6 +21,7 @@ package org.deeplearning4j.datasets.fetchers;
import org.deeplearning4j.BaseDL4JTest;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import java.io.File;
@ -28,11 +29,15 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assumptions.assumeTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
/**
* @author saudet
*/
@DisplayName("Svhn Data Fetcher Test")
@Tag(TagNames.FILE_IO)
@NativeTag
class SvhnDataFetcherTest extends BaseDL4JTest {
@Override

View File

@ -26,6 +26,7 @@ import org.deeplearning4j.datasets.iterator.tools.VariableTimeseriesGenerator;
import org.deeplearning4j.nn.util.TestDataSetConsumer;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
@ -40,6 +41,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows;
@Slf4j
@DisplayName("Async Data Set Iterator Test")
@NativeTag
class AsyncDataSetIteratorTest extends BaseDL4JTest {
private ExistingDataSetIterator backIterator;

Some files were not shown because too many files have changed in this diff Show More