Add tags for junit 5

master
agibsonccc 2021-03-20 19:06:24 +09:00
parent 3c205548af
commit 5e8951cd8e
773 changed files with 4684 additions and 1152 deletions

View File

@ -12,7 +12,7 @@ DL4J was a junit 4 based code based for testing.
It's now based on junit 5's jupiter API, which has support for [Tags](https://junit.org/junit5/docs/5.0.1/api/org/junit/jupiter/api/Tag.html). It's now based on junit 5's jupiter API, which has support for [Tags](https://junit.org/junit5/docs/5.0.1/api/org/junit/jupiter/api/Tag.html).
DL4j's code base has a number of different kinds of tests that fall in to several categories: DL4j's code base has a number of different kinds of tests that fall in to several categories:
1. Long and flaky involving distributed systems (spark, parameter server) 1. Long and flaky involving distributed systems (spark, parameter-server)
2. Code that requires large downloads, but runs quickly 2. Code that requires large downloads, but runs quickly
3. Quick tests that test basic functionality 3. Quick tests that test basic functionality
4. Comprehensive integration tests that test several parts of a code base 4. Comprehensive integration tests that test several parts of a code base
@ -38,8 +38,10 @@ A few kinds of tags exist:
3. Distributed systems: spark, multi-threaded 3. Distributed systems: spark, multi-threaded
4. Functional cross-cutting concerns: multi module tests, similar functionality (excludes time based) 4. Functional cross-cutting concerns: multi module tests, similar functionality (excludes time based)
5. Platform specific tests that can vary on different hardware: cpu, gpu 5. Platform specific tests that can vary on different hardware: cpu, gpu
6. JVM crash: Tests with native code can crash the JVM for tests. It's useful to be able to turn those off when debugging.: jvm-crash 6. JVM crash: (jvm-crash) Tests with native code can crash the JVM for tests. It's useful to be able to turn those off when debugging.: jvm-crash
7. RNG: (rng) for RNG related tests
8. Samediff:(samediff) samediff related tests
9. Training related functionality
## Consequences ## Consequences

View File

@ -26,6 +26,7 @@ import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
@ -38,8 +39,11 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path; import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.TagNames;
@DisplayName("Csv Line Sequence Record Reader Test") @DisplayName("Csv Line Sequence Record Reader Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class CSVLineSequenceRecordReaderTest extends BaseND4JTest { class CSVLineSequenceRecordReaderTest extends BaseND4JTest {
@TempDir @TempDir
@ -54,8 +58,8 @@ class CSVLineSequenceRecordReaderTest extends BaseND4JTest {
FileUtils.writeStringToFile(source, str, StandardCharsets.UTF_8); FileUtils.writeStringToFile(source, str, StandardCharsets.UTF_8);
SequenceRecordReader rr = new CSVLineSequenceRecordReader(); SequenceRecordReader rr = new CSVLineSequenceRecordReader();
rr.initialize(new FileSplit(source)); rr.initialize(new FileSplit(source));
List<List<Writable>> exp0 = Arrays.asList(Collections.<Writable>singletonList(new Text("a")), Collections.<Writable>singletonList(new Text("b")), Collections.<Writable>singletonList(new Text("c"))); List<List<Writable>> exp0 = Arrays.asList(Collections.singletonList(new Text("a")), Collections.singletonList(new Text("b")), Collections.<Writable>singletonList(new Text("c")));
List<List<Writable>> exp1 = Arrays.asList(Collections.<Writable>singletonList(new Text("1")), Collections.<Writable>singletonList(new Text("2")), Collections.<Writable>singletonList(new Text("3")), Collections.<Writable>singletonList(new Text("4"))); List<List<Writable>> exp1 = Arrays.asList(Collections.singletonList(new Text("1")), Collections.singletonList(new Text("2")), Collections.<Writable>singletonList(new Text("3")), Collections.<Writable>singletonList(new Text("4")));
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
int count = 0; int count = 0;
while (rr.hasNext()) { while (rr.hasNext()) {

View File

@ -27,6 +27,7 @@ import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
@ -41,8 +42,11 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path; import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.TagNames;
@DisplayName("Csv Multi Sequence Record Reader Test") @DisplayName("Csv Multi Sequence Record Reader Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class CSVMultiSequenceRecordReaderTest extends BaseND4JTest { class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
@TempDir @TempDir

View File

@ -26,6 +26,7 @@ import org.datavec.api.records.reader.impl.csv.CSVNLinesSequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit; import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
@ -34,12 +35,15 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.TagNames;
@DisplayName("Csvn Lines Sequence Record Reader Test") @DisplayName("Csvn Lines Sequence Record Reader Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class CSVNLinesSequenceRecordReaderTest extends BaseND4JTest { class CSVNLinesSequenceRecordReaderTest extends BaseND4JTest {
@Test @Test
@DisplayName("Test CSVN Lines Sequence Record Reader") @DisplayName("Test CSV Lines Sequence Record Reader")
void testCSVNLinesSequenceRecordReader() throws Exception { void testCSVNLinesSequenceRecordReader() throws Exception {
int nLinesPerSequence = 10; int nLinesPerSequence = 10;
SequenceRecordReader seqRR = new CSVNLinesSequenceRecordReader(nLinesPerSequence); SequenceRecordReader seqRR = new CSVNLinesSequenceRecordReader(nLinesPerSequence);

View File

@ -34,9 +34,12 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.tests.tags.TagNames;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.nio.file.Files; import java.nio.file.Files;
@ -50,6 +53,8 @@ import static org.junit.jupiter.api.Assertions.*;
@DisplayName("Csv Record Reader Test") @DisplayName("Csv Record Reader Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class CSVRecordReaderTest extends BaseND4JTest { class CSVRecordReaderTest extends BaseND4JTest {
@Test @Test

View File

@ -27,6 +27,7 @@ import org.datavec.api.split.InputSplit;
import org.datavec.api.split.NumberedFileInputSplit; import org.datavec.api.split.NumberedFileInputSplit;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
@ -43,8 +44,11 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path; import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.TagNames;
@DisplayName("Csv Sequence Record Reader Test") @DisplayName("Csv Sequence Record Reader Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class CSVSequenceRecordReaderTest extends BaseND4JTest { class CSVSequenceRecordReaderTest extends BaseND4JTest {
@TempDir @TempDir

View File

@ -24,6 +24,7 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVVariableSlidingWindowRecordReader; import org.datavec.api.records.reader.impl.csv.CSVVariableSlidingWindowRecordReader;
import org.datavec.api.split.FileSplit; import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
@ -32,8 +33,11 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.TagNames;
@DisplayName("Csv Variable Sliding Window Record Reader Test") @DisplayName("Csv Variable Sliding Window Record Reader Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class CSVVariableSlidingWindowRecordReaderTest extends BaseND4JTest { class CSVVariableSlidingWindowRecordReaderTest extends BaseND4JTest {
@Test @Test

View File

@ -28,6 +28,7 @@ import org.datavec.api.records.reader.impl.filebatch.FileBatchRecordReader;
import org.datavec.api.records.reader.impl.filebatch.FileBatchSequenceRecordReader; import org.datavec.api.records.reader.impl.filebatch.FileBatchSequenceRecordReader;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
@ -42,9 +43,12 @@ import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path; import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
@DisplayName("File Batch Record Reader Test") @DisplayName("File Batch Record Reader Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class FileBatchRecordReaderTest extends BaseND4JTest { public class FileBatchRecordReaderTest extends BaseND4JTest {
@TempDir Path testDir; @TempDir Path testDir;

View File

@ -25,6 +25,7 @@ import org.datavec.api.split.CollectionInputSplit;
import org.datavec.api.split.FileSplit; import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit; import org.datavec.api.split.InputSplit;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
@ -36,8 +37,11 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.TagNames;
@DisplayName("File Record Reader Test") @DisplayName("File Record Reader Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class FileRecordReaderTest extends BaseND4JTest { class FileRecordReaderTest extends BaseND4JTest {
@Test @Test

View File

@ -28,10 +28,12 @@ import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.shade.jackson.core.JsonFactory; import org.nd4j.shade.jackson.core.JsonFactory;
import org.nd4j.shade.jackson.databind.ObjectMapper; import org.nd4j.shade.jackson.databind.ObjectMapper;
import java.io.File; import java.io.File;
@ -45,6 +47,8 @@ import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
@DisplayName("Jackson Line Record Reader Test") @DisplayName("Jackson Line Record Reader Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class JacksonLineRecordReaderTest extends BaseND4JTest { class JacksonLineRecordReaderTest extends BaseND4JTest {
@TempDir @TempDir

View File

@ -31,10 +31,12 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.shade.jackson.core.JsonFactory; import org.nd4j.shade.jackson.core.JsonFactory;
import org.nd4j.shade.jackson.databind.ObjectMapper; import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.dataformat.xml.XmlFactory; import org.nd4j.shade.jackson.dataformat.xml.XmlFactory;
@ -51,6 +53,8 @@ import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
@DisplayName("Jackson Record Reader Test") @DisplayName("Jackson Record Reader Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class JacksonRecordReaderTest extends BaseND4JTest { class JacksonRecordReaderTest extends BaseND4JTest {
@TempDir @TempDir

View File

@ -26,6 +26,7 @@ import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
@ -35,9 +36,13 @@ import static org.datavec.api.records.reader.impl.misc.LibSvmRecordReader.*;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.TagNames;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
@DisplayName("Lib Svm Record Reader Test") @DisplayName("Lib Svm Record Reader Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class LibSvmRecordReaderTest extends BaseND4JTest { class LibSvmRecordReaderTest extends BaseND4JTest {
@Test @Test

View File

@ -30,6 +30,7 @@ import org.datavec.api.split.InputSplit;
import org.datavec.api.split.InputStreamInputSplit; import org.datavec.api.split.InputStreamInputSplit;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
@ -47,8 +48,11 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path; import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.TagNames;
@DisplayName("Line Reader Test") @DisplayName("Line Reader Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class LineReaderTest extends BaseND4JTest { class LineReaderTest extends BaseND4JTest {

View File

@ -33,6 +33,7 @@ import org.datavec.api.split.NumberedFileInputSplit;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
@ -46,8 +47,11 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path; import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.TagNames;
@DisplayName("Regex Record Reader Test") @DisplayName("Regex Record Reader Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class RegexRecordReaderTest extends BaseND4JTest { class RegexRecordReaderTest extends BaseND4JTest {
@TempDir @TempDir

View File

@ -26,6 +26,7 @@ import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
@ -35,9 +36,13 @@ import static org.datavec.api.records.reader.impl.misc.SVMLightRecordReader.*;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.TagNames;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
@DisplayName("Svm Light Record Reader Test") @DisplayName("Svm Light Record Reader Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class SVMLightRecordReaderTest extends BaseND4JTest { class SVMLightRecordReaderTest extends BaseND4JTest {
@Test @Test

View File

@ -26,15 +26,18 @@ import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader; import org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader;
import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class TestCollectionRecordReaders extends BaseND4JTest { public class TestCollectionRecordReaders extends BaseND4JTest {
@Test @Test

View File

@ -23,12 +23,15 @@ package org.datavec.api.records.reader.impl;
import org.datavec.api.records.reader.RecordReader; import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit; import org.datavec.api.split.FileSplit;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.tests.tags.TagNames;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class TestConcatenatingRecordReader extends BaseND4JTest { public class TestConcatenatingRecordReader extends BaseND4JTest {
@Test @Test

View File

@ -37,9 +37,11 @@ import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.shade.jackson.core.JsonFactory; import org.nd4j.shade.jackson.core.JsonFactory;
import org.nd4j.shade.jackson.databind.ObjectMapper; import org.nd4j.shade.jackson.databind.ObjectMapper;
@ -48,7 +50,8 @@ import java.util.ArrayList;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class TestSerialization extends BaseND4JTest { public class TestSerialization extends BaseND4JTest {
@Test @Test

View File

@ -30,9 +30,11 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.LongWritable; import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.tests.tags.TagNames;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
@ -41,6 +43,8 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class TransformProcessRecordReaderTests extends BaseND4JTest { public class TransformProcessRecordReaderTests extends BaseND4JTest {
@Test @Test
@ -74,11 +78,11 @@ public class TransformProcessRecordReaderTests extends BaseND4JTest {
public void simpleTransformTestSequence() { public void simpleTransformTestSequence() {
List<List<Writable>> sequence = new ArrayList<>(); List<List<Writable>> sequence = new ArrayList<>();
//First window: //First window:
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L), new IntWritable(0), sequence.add(Arrays.asList(new LongWritable(1451606400000L), new IntWritable(0),
new IntWritable(0))); new IntWritable(0)));
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 100L), new IntWritable(1), sequence.add(Arrays.asList(new LongWritable(1451606400000L + 100L), new IntWritable(1),
new IntWritable(0))); new IntWritable(0)));
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 200L), new IntWritable(2), sequence.add(Arrays.asList(new LongWritable(1451606400000L + 200L), new IntWritable(2),
new IntWritable(0))); new IntWritable(0)));
Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC) Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC)

View File

@ -26,6 +26,7 @@ import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.io.File; import java.io.File;
@ -34,8 +35,11 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.TagNames;
@DisplayName("Csv Record Writer Test") @DisplayName("Csv Record Writer Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class CSVRecordWriterTest extends BaseND4JTest { class CSVRecordWriterTest extends BaseND4JTest {
@BeforeEach @BeforeEach

View File

@ -29,8 +29,10 @@ import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.NDArrayWritable; import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
@ -46,6 +48,8 @@ import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
@DisplayName("Lib Svm Record Writer Test") @DisplayName("Lib Svm Record Writer Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class LibSvmRecordWriterTest extends BaseND4JTest { class LibSvmRecordWriterTest extends BaseND4JTest {
@Test @Test

View File

@ -26,8 +26,10 @@ import org.datavec.api.records.writer.impl.misc.SVMLightRecordWriter;
import org.datavec.api.split.FileSplit; import org.datavec.api.split.FileSplit;
import org.datavec.api.split.partition.NumberOfRecordsPartitioner; import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
import org.datavec.api.writable.*; import org.datavec.api.writable.*;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
@ -43,6 +45,8 @@ import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
@DisplayName("Svm Light Record Writer Test") @DisplayName("Svm Light Record Writer Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class SVMLightRecordWriterTest extends BaseND4JTest { class SVMLightRecordWriterTest extends BaseND4JTest {
@Test @Test

View File

@ -20,7 +20,9 @@
package org.datavec.api.split; package org.datavec.api.split;
import org.junit.jupiter.api.Tag;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.shade.guava.io.Files; import org.nd4j.shade.guava.io.Files;
import org.datavec.api.io.filters.BalancedPathFilter; import org.datavec.api.io.filters.BalancedPathFilter;
import org.datavec.api.io.filters.RandomPathFilter; import org.datavec.api.io.filters.RandomPathFilter;
@ -42,6 +44,8 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
* *
* @author saudet * @author saudet
*/ */
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class InputSplitTests extends BaseND4JTest { public class InputSplitTests extends BaseND4JTest {
@Test @Test

View File

@ -20,13 +20,16 @@
package org.datavec.api.split; package org.datavec.api.split;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import java.net.URI; import java.net.URI;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class NumberedFileInputSplitTests extends BaseND4JTest { public class NumberedFileInputSplitTests extends BaseND4JTest {
@Test @Test
public void testNumberedFileInputSplitBasic() { public void testNumberedFileInputSplitBasic() {

View File

@ -26,11 +26,13 @@ import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.function.Function; import org.nd4j.common.function.Function;
import org.nd4j.common.tests.tags.TagNames;
import java.io.File; import java.io.File;
import java.io.FileInputStream; import java.io.FileInputStream;
@ -46,7 +48,8 @@ import java.util.Random;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertNotEquals;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class TestStreamInputSplit extends BaseND4JTest { public class TestStreamInputSplit extends BaseND4JTest {

View File

@ -19,6 +19,7 @@
*/ */
package org.datavec.api.split; package org.datavec.api.split;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.net.URI; import java.net.URI;
@ -28,11 +29,14 @@ import static java.util.Arrays.asList;
import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.TagNames;
/** /**
* @author Ede Meijer * @author Ede Meijer
*/ */
@DisplayName("Transform Split Test") @DisplayName("Transform Split Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class TransformSplitTest extends BaseND4JTest { class TransformSplitTest extends BaseND4JTest {
@Test @Test

View File

@ -20,7 +20,9 @@
package org.datavec.api.split.parittion; package org.datavec.api.split.parittion;
import org.junit.jupiter.api.Tag;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.shade.guava.io.Files; import org.nd4j.shade.guava.io.Files;
import org.datavec.api.conf.Configuration; import org.datavec.api.conf.Configuration;
import org.datavec.api.split.FileSplit; import org.datavec.api.split.FileSplit;
@ -33,7 +35,8 @@ import java.io.File;
import java.io.OutputStream; import java.io.OutputStream;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class PartitionerTests extends BaseND4JTest { public class PartitionerTests extends BaseND4JTest {
@Test @Test
public void testRecordsPerFilePartition() { public void testRecordsPerFilePartition() {

View File

@ -29,13 +29,16 @@ import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import java.util.*; import java.util.*;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class TestTransformProcess extends BaseND4JTest { public class TestTransformProcess extends BaseND4JTest {
@Test @Test

View File

@ -27,14 +27,17 @@ import org.datavec.api.transform.condition.string.StringRegexColumnCondition;
import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.transform.TestTransforms; import org.datavec.api.transform.transform.TestTransforms;
import org.datavec.api.writable.*; import org.datavec.api.writable.*;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import java.util.*; import java.util.*;
import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class TestConditions extends BaseND4JTest { public class TestConditions extends BaseND4JTest {
@Test @Test

View File

@ -27,8 +27,10 @@ import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
@ -38,7 +40,8 @@ import java.util.List;
import static java.util.Arrays.asList; import static java.util.Arrays.asList;
import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class TestFilters extends BaseND4JTest { public class TestFilters extends BaseND4JTest {

View File

@ -26,9 +26,11 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.NullWritable; import org.datavec.api.writable.NullWritable;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import java.nio.file.Path; import java.nio.file.Path;
import java.util.ArrayList; import java.util.ArrayList;
@ -37,7 +39,8 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class TestJoin extends BaseND4JTest { public class TestJoin extends BaseND4JTest {
@Test @Test

View File

@ -33,14 +33,17 @@ import org.datavec.api.transform.ops.IAggregableReduceOp;
import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.*; import org.datavec.api.writable.*;
import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import java.util.*; import java.util.*;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.api.Assertions.fail;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class TestMultiOpReduce extends BaseND4JTest { public class TestMultiOpReduce extends BaseND4JTest {
@Test @Test

View File

@ -24,14 +24,17 @@ import org.datavec.api.transform.ops.IAggregableReduceOp;
import org.datavec.api.transform.reduce.impl.GeographicMidpointReduction; import org.datavec.api.transform.reduce.impl.GeographicMidpointReduction;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class TestReductions extends BaseND4JTest { public class TestReductions extends BaseND4JTest {
@Test @Test

View File

@ -22,11 +22,15 @@ package org.datavec.api.transform.schema;
import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.metadata.ColumnMetaData;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JACKSON_SERDE)
public class TestJsonYaml extends BaseND4JTest { public class TestJsonYaml extends BaseND4JTest {
@Test @Test

View File

@ -21,11 +21,14 @@
package org.datavec.api.transform.schema; package org.datavec.api.transform.schema;
import org.datavec.api.transform.ColumnType; import org.datavec.api.transform.ColumnType;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class TestSchemaMethods extends BaseND4JTest { public class TestSchemaMethods extends BaseND4JTest {
@Test @Test

View File

@ -33,8 +33,10 @@ import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.NullWritable; import org.datavec.api.writable.NullWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
@ -42,7 +44,8 @@ import java.util.List;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class TestReduceSequenceByWindowFunction extends BaseND4JTest { public class TestReduceSequenceByWindowFunction extends BaseND4JTest {
@Test @Test

View File

@ -27,8 +27,10 @@ import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
@ -36,7 +38,8 @@ import java.util.List;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class TestSequenceSplit extends BaseND4JTest { public class TestSequenceSplit extends BaseND4JTest {
@Test @Test
@ -46,13 +49,13 @@ public class TestSequenceSplit extends BaseND4JTest {
.build(); .build();
List<List<Writable>> inputSequence = new ArrayList<>(); List<List<Writable>> inputSequence = new ArrayList<>();
inputSequence.add(Arrays.asList((Writable) new LongWritable(0), new Text("t0"))); inputSequence.add(Arrays.asList(new LongWritable(0), new Text("t0")));
inputSequence.add(Arrays.asList((Writable) new LongWritable(1000), new Text("t1"))); inputSequence.add(Arrays.asList(new LongWritable(1000), new Text("t1")));
//Second split: 74 seconds later //Second split: 74 seconds later
inputSequence.add(Arrays.asList((Writable) new LongWritable(75000), new Text("t2"))); inputSequence.add(Arrays.asList(new LongWritable(75000), new Text("t2")));
inputSequence.add(Arrays.asList((Writable) new LongWritable(100000), new Text("t3"))); inputSequence.add(Arrays.asList(new LongWritable(100000), new Text("t3")));
//Third split: 1 minute and 1 milliseconds later //Third split: 1 minute and 1 milliseconds later
inputSequence.add(Arrays.asList((Writable) new LongWritable(160001), new Text("t4"))); inputSequence.add(Arrays.asList(new LongWritable(160001), new Text("t4")));
SequenceSplit seqSplit = new SequenceSplitTimeSeparation("time", 1, TimeUnit.MINUTES); SequenceSplit seqSplit = new SequenceSplitTimeSeparation("time", 1, TimeUnit.MINUTES);
seqSplit.setInputSchema(schema); seqSplit.setInputSchema(schema);
@ -61,13 +64,13 @@ public class TestSequenceSplit extends BaseND4JTest {
assertEquals(3, splits.size()); assertEquals(3, splits.size());
List<List<Writable>> exp0 = new ArrayList<>(); List<List<Writable>> exp0 = new ArrayList<>();
exp0.add(Arrays.asList((Writable) new LongWritable(0), new Text("t0"))); exp0.add(Arrays.asList(new LongWritable(0), new Text("t0")));
exp0.add(Arrays.asList((Writable) new LongWritable(1000), new Text("t1"))); exp0.add(Arrays.asList(new LongWritable(1000), new Text("t1")));
List<List<Writable>> exp1 = new ArrayList<>(); List<List<Writable>> exp1 = new ArrayList<>();
exp1.add(Arrays.asList((Writable) new LongWritable(75000), new Text("t2"))); exp1.add(Arrays.asList(new LongWritable(75000), new Text("t2")));
exp1.add(Arrays.asList((Writable) new LongWritable(100000), new Text("t3"))); exp1.add(Arrays.asList(new LongWritable(100000), new Text("t3")));
List<List<Writable>> exp2 = new ArrayList<>(); List<List<Writable>> exp2 = new ArrayList<>();
exp2.add(Arrays.asList((Writable) new LongWritable(160001), new Text("t4"))); exp2.add(Arrays.asList(new LongWritable(160001), new Text("t4")));
assertEquals(exp0, splits.get(0)); assertEquals(exp0, splits.get(0));
assertEquals(exp1, splits.get(1)); assertEquals(exp1, splits.get(1));

View File

@ -29,8 +29,10 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.LongWritable; import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
@ -38,7 +40,8 @@ import java.util.List;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class TestWindowFunctions extends BaseND4JTest { public class TestWindowFunctions extends BaseND4JTest {
@Test @Test
@ -49,15 +52,15 @@ public class TestWindowFunctions extends BaseND4JTest {
//Create some data. //Create some data.
List<List<Writable>> sequence = new ArrayList<>(); List<List<Writable>> sequence = new ArrayList<>();
//First window: //First window:
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L), new IntWritable(0))); sequence.add(Arrays.asList(new LongWritable(1451606400000L), new IntWritable(0)));
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 100L), new IntWritable(1))); sequence.add(Arrays.asList(new LongWritable(1451606400000L + 100L), new IntWritable(1)));
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 200L), new IntWritable(2))); sequence.add(Arrays.asList(new LongWritable(1451606400000L + 200L), new IntWritable(2)));
//Second window: //Second window:
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 1000L), new IntWritable(3))); sequence.add(Arrays.asList(new LongWritable(1451606400000L + 1000L), new IntWritable(3)));
//Third window: empty //Third window: empty
//Fourth window: //Fourth window:
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 3000L), new IntWritable(4))); sequence.add(Arrays.asList(new LongWritable(1451606400000L + 3000L), new IntWritable(4)));
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 3100L), new IntWritable(5))); sequence.add(Arrays.asList(new LongWritable(1451606400000L + 3100L), new IntWritable(5)));
Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC) Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC)
.addColumnInteger("intcolumn").build(); .addColumnInteger("intcolumn").build();
@ -100,15 +103,15 @@ public class TestWindowFunctions extends BaseND4JTest {
//Create some data. //Create some data.
List<List<Writable>> sequence = new ArrayList<>(); List<List<Writable>> sequence = new ArrayList<>();
//First window: //First window:
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L), new IntWritable(0))); sequence.add(Arrays.asList(new LongWritable(1451606400000L), new IntWritable(0)));
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 100L), new IntWritable(1))); sequence.add(Arrays.asList(new LongWritable(1451606400000L + 100L), new IntWritable(1)));
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 200L), new IntWritable(2))); sequence.add(Arrays.asList(new LongWritable(1451606400000L + 200L), new IntWritable(2)));
//Second window: //Second window:
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 1000L), new IntWritable(3))); sequence.add(Arrays.asList(new LongWritable(1451606400000L + 1000L), new IntWritable(3)));
//Third window: empty //Third window: empty
//Fourth window: //Fourth window:
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 3000L), new IntWritable(4))); sequence.add(Arrays.asList(new LongWritable(1451606400000L + 3000L), new IntWritable(4)));
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 3100L), new IntWritable(5))); sequence.add(Arrays.asList(new LongWritable(1451606400000L + 3100L), new IntWritable(5)));
Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC) Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC)
.addColumnInteger("intcolumn").build(); .addColumnInteger("intcolumn").build();
@ -150,15 +153,15 @@ public class TestWindowFunctions extends BaseND4JTest {
//Create some data. //Create some data.
List<List<Writable>> sequence = new ArrayList<>(); List<List<Writable>> sequence = new ArrayList<>();
//First window: //First window:
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L), new IntWritable(0))); sequence.add(Arrays.asList(new LongWritable(1451606400000L), new IntWritable(0)));
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 100L), new IntWritable(1))); sequence.add(Arrays.asList(new LongWritable(1451606400000L + 100L), new IntWritable(1)));
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 200L), new IntWritable(2))); sequence.add(Arrays.asList(new LongWritable(1451606400000L + 200L), new IntWritable(2)));
//Second window: //Second window:
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 1000L), new IntWritable(3))); sequence.add(Arrays.asList(new LongWritable(1451606400000L + 1000L), new IntWritable(3)));
//Third window: empty //Third window: empty
//Fourth window: //Fourth window:
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 3000L), new IntWritable(4))); sequence.add(Arrays.asList(new LongWritable(1451606400000L + 3000L), new IntWritable(4)));
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 3100L), new IntWritable(5))); sequence.add(Arrays.asList(new LongWritable(1451606400000L + 3100L), new IntWritable(5)));
Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC) Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC)
.addColumnInteger("intcolumn").build(); .addColumnInteger("intcolumn").build();
@ -188,13 +191,13 @@ public class TestWindowFunctions extends BaseND4JTest {
//Create some data. //Create some data.
List<List<Writable>> sequence = new ArrayList<>(); List<List<Writable>> sequence = new ArrayList<>();
//First window: //First window:
sequence.add(Arrays.asList((Writable) new LongWritable(0), new IntWritable(0))); sequence.add(Arrays.asList(new LongWritable(0), new IntWritable(0)));
sequence.add(Arrays.asList((Writable) new LongWritable(100), new IntWritable(1))); sequence.add(Arrays.asList(new LongWritable(100), new IntWritable(1)));
sequence.add(Arrays.asList((Writable) new LongWritable(200), new IntWritable(2))); sequence.add(Arrays.asList(new LongWritable(200), new IntWritable(2)));
sequence.add(Arrays.asList((Writable) new LongWritable(1000), new IntWritable(3))); sequence.add(Arrays.asList(new LongWritable(1000), new IntWritable(3)));
sequence.add(Arrays.asList((Writable) new LongWritable(1500), new IntWritable(4))); sequence.add(Arrays.asList(new LongWritable(1500), new IntWritable(4)));
sequence.add(Arrays.asList((Writable) new LongWritable(2000), new IntWritable(5))); sequence.add(Arrays.asList(new LongWritable(2000), new IntWritable(5)));
sequence.add(Arrays.asList((Writable) new LongWritable(5000), new IntWritable(7))); sequence.add(Arrays.asList(new LongWritable(5000), new IntWritable(7)));
Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC) Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC)
@ -207,32 +210,32 @@ public class TestWindowFunctions extends BaseND4JTest {
//First window: -1000 to 1000 //First window: -1000 to 1000
List<List<Writable>> exp0 = new ArrayList<>(); List<List<Writable>> exp0 = new ArrayList<>();
exp0.add(Arrays.asList((Writable) new LongWritable(0), new IntWritable(0))); exp0.add(Arrays.asList(new LongWritable(0), new IntWritable(0)));
exp0.add(Arrays.asList((Writable) new LongWritable(100), new IntWritable(1))); exp0.add(Arrays.asList(new LongWritable(100), new IntWritable(1)));
exp0.add(Arrays.asList((Writable) new LongWritable(200), new IntWritable(2))); exp0.add(Arrays.asList(new LongWritable(200), new IntWritable(2)));
//Second window: 0 to 2000 //Second window: 0 to 2000
List<List<Writable>> exp1 = new ArrayList<>(); List<List<Writable>> exp1 = new ArrayList<>();
exp1.add(Arrays.asList((Writable) new LongWritable(0), new IntWritable(0))); exp1.add(Arrays.asList(new LongWritable(0), new IntWritable(0)));
exp1.add(Arrays.asList((Writable) new LongWritable(100), new IntWritable(1))); exp1.add(Arrays.asList(new LongWritable(100), new IntWritable(1)));
exp1.add(Arrays.asList((Writable) new LongWritable(200), new IntWritable(2))); exp1.add(Arrays.asList(new LongWritable(200), new IntWritable(2)));
exp1.add(Arrays.asList((Writable) new LongWritable(1000), new IntWritable(3))); exp1.add(Arrays.asList(new LongWritable(1000), new IntWritable(3)));
exp1.add(Arrays.asList((Writable) new LongWritable(1500), new IntWritable(4))); exp1.add(Arrays.asList(new LongWritable(1500), new IntWritable(4)));
//Third window: 1000 to 3000 //Third window: 1000 to 3000
List<List<Writable>> exp2 = new ArrayList<>(); List<List<Writable>> exp2 = new ArrayList<>();
exp2.add(Arrays.asList((Writable) new LongWritable(1000), new IntWritable(3))); exp2.add(Arrays.asList(new LongWritable(1000), new IntWritable(3)));
exp2.add(Arrays.asList((Writable) new LongWritable(1500), new IntWritable(4))); exp2.add(Arrays.asList(new LongWritable(1500), new IntWritable(4)));
exp2.add(Arrays.asList((Writable) new LongWritable(2000), new IntWritable(5))); exp2.add(Arrays.asList(new LongWritable(2000), new IntWritable(5)));
//Fourth window: 2000 to 4000 //Fourth window: 2000 to 4000
List<List<Writable>> exp3 = new ArrayList<>(); List<List<Writable>> exp3 = new ArrayList<>();
exp3.add(Arrays.asList((Writable) new LongWritable(2000), new IntWritable(5))); exp3.add(Arrays.asList(new LongWritable(2000), new IntWritable(5)));
//Fifth window: 3000 to 5000 //Fifth window: 3000 to 5000
List<List<Writable>> exp4 = new ArrayList<>(); List<List<Writable>> exp4 = new ArrayList<>();
//Sixth window: 4000 to 6000 //Sixth window: 4000 to 6000
List<List<Writable>> exp5 = new ArrayList<>(); List<List<Writable>> exp5 = new ArrayList<>();
exp5.add(Arrays.asList((Writable) new LongWritable(5000), new IntWritable(7))); exp5.add(Arrays.asList(new LongWritable(5000), new IntWritable(7)));
//Seventh window: 5000 to 7000 //Seventh window: 5000 to 7000
List<List<Writable>> exp6 = new ArrayList<>(); List<List<Writable>> exp6 = new ArrayList<>();
exp6.add(Arrays.asList((Writable) new LongWritable(5000), new IntWritable(7))); exp6.add(Arrays.asList(new LongWritable(5000), new IntWritable(7)));
List<List<List<Writable>>> windowsExp = Arrays.asList(exp0, exp1, exp2, exp3, exp4, exp5, exp6); List<List<List<Writable>>> windowsExp = Arrays.asList(exp0, exp1, exp2, exp3, exp4, exp5, exp6);
@ -250,13 +253,13 @@ public class TestWindowFunctions extends BaseND4JTest {
//Create some data. //Create some data.
List<List<Writable>> sequence = new ArrayList<>(); List<List<Writable>> sequence = new ArrayList<>();
//First window: //First window:
sequence.add(Arrays.asList((Writable) new LongWritable(0), new IntWritable(0))); sequence.add(Arrays.asList(new LongWritable(0), new IntWritable(0)));
sequence.add(Arrays.asList((Writable) new LongWritable(100), new IntWritable(1))); sequence.add(Arrays.asList(new LongWritable(100), new IntWritable(1)));
sequence.add(Arrays.asList((Writable) new LongWritable(200), new IntWritable(2))); sequence.add(Arrays.asList(new LongWritable(200), new IntWritable(2)));
sequence.add(Arrays.asList((Writable) new LongWritable(1000), new IntWritable(3))); sequence.add(Arrays.asList(new LongWritable(1000), new IntWritable(3)));
sequence.add(Arrays.asList((Writable) new LongWritable(1500), new IntWritable(4))); sequence.add(Arrays.asList(new LongWritable(1500), new IntWritable(4)));
sequence.add(Arrays.asList((Writable) new LongWritable(2000), new IntWritable(5))); sequence.add(Arrays.asList(new LongWritable(2000), new IntWritable(5)));
sequence.add(Arrays.asList((Writable) new LongWritable(5000), new IntWritable(7))); sequence.add(Arrays.asList(new LongWritable(5000), new IntWritable(7)));
Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC) Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC)
@ -272,31 +275,31 @@ public class TestWindowFunctions extends BaseND4JTest {
//First window: -1000 to 1000 //First window: -1000 to 1000
List<List<Writable>> exp0 = new ArrayList<>(); List<List<Writable>> exp0 = new ArrayList<>();
exp0.add(Arrays.asList((Writable) new LongWritable(0), new IntWritable(0))); exp0.add(Arrays.asList(new LongWritable(0), new IntWritable(0)));
exp0.add(Arrays.asList((Writable) new LongWritable(100), new IntWritable(1))); exp0.add(Arrays.asList(new LongWritable(100), new IntWritable(1)));
exp0.add(Arrays.asList((Writable) new LongWritable(200), new IntWritable(2))); exp0.add(Arrays.asList(new LongWritable(200), new IntWritable(2)));
//Second window: 0 to 2000 //Second window: 0 to 2000
List<List<Writable>> exp1 = new ArrayList<>(); List<List<Writable>> exp1 = new ArrayList<>();
exp1.add(Arrays.asList((Writable) new LongWritable(0), new IntWritable(0))); exp1.add(Arrays.asList(new LongWritable(0), new IntWritable(0)));
exp1.add(Arrays.asList((Writable) new LongWritable(100), new IntWritable(1))); exp1.add(Arrays.asList(new LongWritable(100), new IntWritable(1)));
exp1.add(Arrays.asList((Writable) new LongWritable(200), new IntWritable(2))); exp1.add(Arrays.asList(new LongWritable(200), new IntWritable(2)));
exp1.add(Arrays.asList((Writable) new LongWritable(1000), new IntWritable(3))); exp1.add(Arrays.asList(new LongWritable(1000), new IntWritable(3)));
exp1.add(Arrays.asList((Writable) new LongWritable(1500), new IntWritable(4))); exp1.add(Arrays.asList(new LongWritable(1500), new IntWritable(4)));
//Third window: 1000 to 3000 //Third window: 1000 to 3000
List<List<Writable>> exp2 = new ArrayList<>(); List<List<Writable>> exp2 = new ArrayList<>();
exp2.add(Arrays.asList((Writable) new LongWritable(1000), new IntWritable(3))); exp2.add(Arrays.asList(new LongWritable(1000), new IntWritable(3)));
exp2.add(Arrays.asList((Writable) new LongWritable(1500), new IntWritable(4))); exp2.add(Arrays.asList(new LongWritable(1500), new IntWritable(4)));
exp2.add(Arrays.asList((Writable) new LongWritable(2000), new IntWritable(5))); exp2.add(Arrays.asList(new LongWritable(2000), new IntWritable(5)));
//Fourth window: 2000 to 4000 //Fourth window: 2000 to 4000
List<List<Writable>> exp3 = new ArrayList<>(); List<List<Writable>> exp3 = new ArrayList<>();
exp3.add(Arrays.asList((Writable) new LongWritable(2000), new IntWritable(5))); exp3.add(Arrays.asList(new LongWritable(2000), new IntWritable(5)));
//Fifth window: 3000 to 5000 -> Empty: excluded //Fifth window: 3000 to 5000 -> Empty: excluded
//Sixth window: 4000 to 6000 //Sixth window: 4000 to 6000
List<List<Writable>> exp5 = new ArrayList<>(); List<List<Writable>> exp5 = new ArrayList<>();
exp5.add(Arrays.asList((Writable) new LongWritable(5000), new IntWritable(7))); exp5.add(Arrays.asList(new LongWritable(5000), new IntWritable(7)));
//Seventh window: 5000 to 7000 //Seventh window: 5000 to 7000
List<List<Writable>> exp6 = new ArrayList<>(); List<List<Writable>> exp6 = new ArrayList<>();
exp6.add(Arrays.asList((Writable) new LongWritable(5000), new IntWritable(7))); exp6.add(Arrays.asList(new LongWritable(5000), new IntWritable(7)));
List<List<List<Writable>>> windowsExp = Arrays.asList(exp0, exp1, exp2, exp3, exp5, exp6); List<List<List<Writable>>> windowsExp = Arrays.asList(exp0, exp1, exp2, exp3, exp5, exp6);

View File

@ -26,11 +26,17 @@ import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.serde.testClasses.CustomCondition; import org.datavec.api.transform.serde.testClasses.CustomCondition;
import org.datavec.api.transform.serde.testClasses.CustomFilter; import org.datavec.api.transform.serde.testClasses.CustomFilter;
import org.datavec.api.transform.serde.testClasses.CustomTransform; import org.datavec.api.transform.serde.testClasses.CustomTransform;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JACKSON_SERDE)
@Tag(TagNames.CUSTOM_FUNCTIONALITY)
public class TestCustomTransformJsonYaml extends BaseND4JTest { public class TestCustomTransformJsonYaml extends BaseND4JTest {
@Test @Test

View File

@ -64,14 +64,19 @@ import org.datavec.api.transform.transform.time.TimeMathOpTransform;
import org.datavec.api.writable.comparator.DoubleWritableComparator; import org.datavec.api.writable.comparator.DoubleWritableComparator;
import org.joda.time.DateTimeFieldType; import org.joda.time.DateTimeFieldType;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import java.util.*; import java.util.*;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JACKSON_SERDE)
public class TestYamlJsonSerde extends BaseND4JTest { public class TestYamlJsonSerde extends BaseND4JTest {
public static YamlSerializer y = new YamlSerializer(); public static YamlSerializer y = new YamlSerializer();

View File

@ -24,22 +24,26 @@ import org.datavec.api.transform.StringReduceOp;
import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import java.util.*; import java.util.*;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class TestReduce extends BaseND4JTest { public class TestReduce extends BaseND4JTest {
@Test @Test
public void testReducerDouble() { public void testReducerDouble() {
List<List<Writable>> inputs = new ArrayList<>(); List<List<Writable>> inputs = new ArrayList<>();
inputs.add(Arrays.asList((Writable) new Text("1"), new Text("2"))); inputs.add(Arrays.asList(new Text("1"), new Text("2")));
inputs.add(Arrays.asList((Writable) new Text("1"), new Text("2"))); inputs.add(Arrays.asList(new Text("1"), new Text("2")));
inputs.add(Arrays.asList((Writable) new Text("1"), new Text("2"))); inputs.add(Arrays.asList(new Text("1"), new Text("2")));
Map<StringReduceOp, String> exp = new LinkedHashMap<>(); Map<StringReduceOp, String> exp = new LinkedHashMap<>();
exp.put(StringReduceOp.MERGE, "12"); exp.put(StringReduceOp.MERGE, "12");

View File

@ -37,10 +37,12 @@ import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import java.io.File; import java.io.File;
import java.nio.file.Path; import java.nio.file.Path;
@ -49,7 +51,9 @@ import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
@Tag(TagNames.UI)
public class TestUI extends BaseND4JTest { public class TestUI extends BaseND4JTest {

View File

@ -20,6 +20,7 @@
package org.datavec.api.util; package org.datavec.api.util;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import java.io.BufferedReader; import java.io.BufferedReader;
@ -33,8 +34,11 @@ import static org.hamcrest.core.AnyOf.anyOf;
import static org.hamcrest.core.IsEqual.equalTo; import static org.hamcrest.core.IsEqual.equalTo;
import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.TagNames;
@DisplayName("Class Path Resource Test") @DisplayName("Class Path Resource Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class ClassPathResourceTest extends BaseND4JTest { class ClassPathResourceTest extends BaseND4JTest {
// File sizes are reported slightly different on Linux vs. Windows // File sizes are reported slightly different on Linux vs. Windows

View File

@ -22,8 +22,10 @@ package org.datavec.api.util;
import org.datavec.api.timeseries.util.TimeSeriesWritableUtils; import org.datavec.api.timeseries.util.TimeSeriesWritableUtils;
import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
@ -32,6 +34,8 @@ import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
@DisplayName("Time Series Utils Test") @DisplayName("Time Series Utils Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class TimeSeriesUtilsTest extends BaseND4JTest { class TimeSeriesUtilsTest extends BaseND4JTest {
@Test @Test

View File

@ -19,7 +19,9 @@
*/ */
package org.datavec.api.writable; package org.datavec.api.writable;
import org.junit.jupiter.api.Tag;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.shade.guava.collect.Lists; import org.nd4j.shade.guava.collect.Lists;
import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.Schema;
import org.datavec.api.util.ndarray.RecordConverter; import org.datavec.api.util.ndarray.RecordConverter;
@ -36,6 +38,8 @@ import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
@DisplayName("Record Converter Test") @DisplayName("Record Converter Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class RecordConverterTest extends BaseND4JTest { class RecordConverterTest extends BaseND4JTest {
@Test @Test

View File

@ -21,15 +21,18 @@
package org.datavec.api.writable; package org.datavec.api.writable;
import org.datavec.api.transform.metadata.NDArrayMetaData; import org.datavec.api.transform.metadata.NDArrayMetaData;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import java.io.*; import java.io.*;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class TestNDArrayWritableAndSerialization extends BaseND4JTest { public class TestNDArrayWritableAndSerialization extends BaseND4JTest {
@Test @Test

View File

@ -20,8 +20,10 @@
package org.datavec.api.writable; package org.datavec.api.writable;
import org.datavec.api.writable.batch.NDArrayRecordBatch; import org.datavec.api.writable.batch.NDArrayRecordBatch;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -37,6 +39,8 @@ import org.junit.jupiter.api.DisplayName;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
@DisplayName("Writable Test") @DisplayName("Writable Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class WritableTest extends BaseND4JTest { class WritableTest extends BaseND4JTest {
@Test @Test

View File

@ -42,9 +42,11 @@ import org.datavec.api.writable.*;
import org.datavec.arrow.recordreader.ArrowRecordReader; import org.datavec.arrow.recordreader.ArrowRecordReader;
import org.datavec.arrow.recordreader.ArrowWritableRecordBatch; import org.datavec.arrow.recordreader.ArrowWritableRecordBatch;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair; import org.nd4j.common.primitives.Pair;
@ -62,6 +64,8 @@ import org.junit.jupiter.api.extension.ExtendWith;
@Slf4j @Slf4j
@DisplayName("Arrow Converter Test") @DisplayName("Arrow Converter Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class ArrowConverterTest extends BaseND4JTest { class ArrowConverterTest extends BaseND4JTest {
private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE); private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE);
@ -142,8 +146,8 @@ class ArrowConverterTest extends BaseND4JTest {
List<FieldVector> fieldVectorsBatch = ArrowConverter.toArrowColumnsString(bufferAllocator, schema.build(), batch); List<FieldVector> fieldVectorsBatch = ArrowConverter.toArrowColumnsString(bufferAllocator, schema.build(), batch);
List<List<Writable>> batchRecords = ArrowConverter.toArrowWritables(fieldVectorsBatch, schema.build()); List<List<Writable>> batchRecords = ArrowConverter.toArrowWritables(fieldVectorsBatch, schema.build());
List<List<Writable>> assertionBatch = new ArrayList<>(); List<List<Writable>> assertionBatch = new ArrayList<>();
assertionBatch.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(0))); assertionBatch.add(Arrays.asList(new IntWritable(0), new IntWritable(0)));
assertionBatch.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1))); assertionBatch.add(Arrays.asList(new IntWritable(1), new IntWritable(1)));
assertEquals(assertionBatch, batchRecords); assertEquals(assertionBatch, batchRecords);
} }
@ -156,11 +160,11 @@ class ArrowConverterTest extends BaseND4JTest {
schema.addColumnTime(String.valueOf(i), TimeZone.getDefault()); schema.addColumnTime(String.valueOf(i), TimeZone.getDefault());
single.add(String.valueOf(i)); single.add(String.valueOf(i));
} }
List<List<Writable>> input = Arrays.asList(Arrays.<Writable>asList(new LongWritable(0), new LongWritable(1)), Arrays.<Writable>asList(new LongWritable(2), new LongWritable(3))); List<List<Writable>> input = Arrays.asList(Arrays.asList(new LongWritable(0), new LongWritable(1)), Arrays.<Writable>asList(new LongWritable(2), new LongWritable(3)));
List<FieldVector> fieldVector = ArrowConverter.toArrowColumns(bufferAllocator, schema.build(), input); List<FieldVector> fieldVector = ArrowConverter.toArrowColumns(bufferAllocator, schema.build(), input);
ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector, schema.build()); ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector, schema.build());
List<Writable> assertion = Arrays.<Writable>asList(new LongWritable(4), new LongWritable(5)); List<Writable> assertion = Arrays.asList(new LongWritable(4), new LongWritable(5));
writableRecordBatch.set(1, Arrays.<Writable>asList(new LongWritable(4), new LongWritable(5))); writableRecordBatch.set(1, Arrays.asList(new LongWritable(4), new LongWritable(5)));
List<Writable> recordTest = writableRecordBatch.get(1); List<Writable> recordTest = writableRecordBatch.get(1);
assertEquals(assertion, recordTest); assertEquals(assertion, recordTest);
} }
@ -174,11 +178,11 @@ class ArrowConverterTest extends BaseND4JTest {
schema.addColumnInteger(String.valueOf(i)); schema.addColumnInteger(String.valueOf(i));
single.add(String.valueOf(i)); single.add(String.valueOf(i));
} }
List<List<Writable>> input = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(1)), Arrays.<Writable>asList(new IntWritable(2), new IntWritable(3))); List<List<Writable>> input = Arrays.asList(Arrays.asList(new IntWritable(0), new IntWritable(1)), Arrays.<Writable>asList(new IntWritable(2), new IntWritable(3)));
List<FieldVector> fieldVector = ArrowConverter.toArrowColumns(bufferAllocator, schema.build(), input); List<FieldVector> fieldVector = ArrowConverter.toArrowColumns(bufferAllocator, schema.build(), input);
ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector, schema.build()); ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector, schema.build());
List<Writable> assertion = Arrays.<Writable>asList(new IntWritable(4), new IntWritable(5)); List<Writable> assertion = Arrays.asList(new IntWritable(4), new IntWritable(5));
writableRecordBatch.set(1, Arrays.<Writable>asList(new IntWritable(4), new IntWritable(5))); writableRecordBatch.set(1, Arrays.asList(new IntWritable(4), new IntWritable(5)));
List<Writable> recordTest = writableRecordBatch.get(1); List<Writable> recordTest = writableRecordBatch.get(1);
assertEquals(assertion, recordTest); assertEquals(assertion, recordTest);
} }

View File

@ -33,6 +33,7 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.arrow.recordreader.ArrowRecordReader; import org.datavec.arrow.recordreader.ArrowRecordReader;
import org.datavec.arrow.recordreader.ArrowRecordWriter; import org.datavec.arrow.recordreader.ArrowRecordWriter;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.primitives.Triple; import org.nd4j.common.primitives.Triple;
@ -44,8 +45,11 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.TagNames;
@DisplayName("Record Mapper Test") @DisplayName("Record Mapper Test")
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
class RecordMapperTest extends BaseND4JTest { class RecordMapperTest extends BaseND4JTest {
@Test @Test

View File

@ -30,8 +30,10 @@ import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.arrow.ArrowConverter; import org.datavec.arrow.ArrowConverter;
import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.tests.tags.TagNames;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
@ -39,7 +41,8 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.FILE_IO)
public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest { public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest {
private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE); private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE);

View File

@ -25,6 +25,7 @@ import org.datavec.api.split.FileSplit;
import org.datavec.image.recordreader.ImageRecordReader; import org.datavec.image.recordreader.ImageRecordReader;
import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
@ -36,8 +37,12 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path; import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
@DisplayName("Label Generator Test") @DisplayName("Label Generator Test")
@NativeTag
@Tag(TagNames.FILE_IO)
class LabelGeneratorTest { class LabelGeneratorTest {

View File

@ -23,7 +23,10 @@ package org.datavec.image.loader;
import org.apache.commons.io.FilenameUtils; import org.apache.commons.io.FilenameUtils;
import org.datavec.api.records.reader.RecordReader; import org.datavec.api.records.reader.RecordReader;
import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
import java.io.File; import java.io.File;
@ -39,6 +42,8 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
/** /**
* *
*/ */
@NativeTag
@Tag(TagNames.FILE_IO)
public class LoaderTests { public class LoaderTests {
private static void ensureDataAvailable(){ private static void ensureDataAvailable(){
@ -81,7 +86,7 @@ public class LoaderTests {
String subDir = "cifar/cifar-10-batches-bin/data_batch_1.bin"; String subDir = "cifar/cifar-10-batches-bin/data_batch_1.bin";
String path = FilenameUtils.concat(System.getProperty("user.home"), subDir); String path = FilenameUtils.concat(System.getProperty("user.home"), subDir);
byte[] fullDataExpected = new byte[3073]; byte[] fullDataExpected = new byte[3073];
FileInputStream inExpected = new FileInputStream(new File(path)); FileInputStream inExpected = new FileInputStream(path);
inExpected.read(fullDataExpected); inExpected.read(fullDataExpected);
byte[] fullDataActual = new byte[3073]; byte[] fullDataActual = new byte[3073];
@ -94,7 +99,7 @@ public class LoaderTests {
subDir = "cifar/cifar-10-batches-bin/test_batch.bin"; subDir = "cifar/cifar-10-batches-bin/test_batch.bin";
path = FilenameUtils.concat(System.getProperty("user.home"), subDir); path = FilenameUtils.concat(System.getProperty("user.home"), subDir);
fullDataExpected = new byte[3073]; fullDataExpected = new byte[3073];
inExpected = new FileInputStream(new File(path)); inExpected = new FileInputStream(path);
inExpected.read(fullDataExpected); inExpected.read(fullDataExpected);
fullDataActual = new byte[3073]; fullDataActual = new byte[3073];

View File

@ -21,8 +21,11 @@
package org.datavec.image.loader; package org.datavec.image.loader;
import org.datavec.image.data.Image; import org.datavec.image.data.Image;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.resources.Resources; import org.nd4j.common.resources.Resources;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import java.awt.image.BufferedImage; import java.awt.image.BufferedImage;
@ -34,7 +37,8 @@ import java.util.Random;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@NativeTag
@Tag(TagNames.FILE_IO)
public class TestImageLoader { public class TestImageLoader {
private static long seed = 10; private static long seed = 10;

View File

@ -31,10 +31,13 @@ import org.bytedeco.javacv.OpenCVFrameConverter;
import org.datavec.image.data.Image; import org.datavec.image.data.Image;
import org.datavec.image.data.ImageWritable; import org.datavec.image.data.ImageWritable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.resources.Resources; import org.nd4j.common.resources.Resources;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -60,6 +63,8 @@ import static org.junit.jupiter.api.Assertions.fail;
* @author saudet * @author saudet
*/ */
@Slf4j @Slf4j
@NativeTag
@Tag(TagNames.FILE_IO)
public class TestNativeImageLoader { public class TestNativeImageLoader {
static final long seed = 10; static final long seed = 10;
static final Random rng = new Random(seed); static final Random rng = new Random(seed);

View File

@ -28,9 +28,12 @@ import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.image.loader.NativeImageLoader; import org.datavec.image.loader.NativeImageLoader;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.loader.FileBatch; import org.nd4j.common.loader.FileBatch;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import java.io.File; import java.io.File;
@ -41,6 +44,8 @@ import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
@DisplayName("File Batch Record Reader Test") @DisplayName("File Batch Record Reader Test")
@NativeTag
@Tag(TagNames.FILE_IO)
class FileBatchRecordReaderTest { class FileBatchRecordReaderTest {
@TempDir @TempDir

View File

@ -37,9 +37,12 @@ import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.api.writable.batch.NDArrayRecordBatch; import org.datavec.api.writable.batch.NDArrayRecordBatch;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -54,7 +57,8 @@ import java.util.List;
import java.util.Random; import java.util.Random;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
@NativeTag
@Tag(TagNames.FILE_IO)
public class TestImageRecordReader { public class TestImageRecordReader {

View File

@ -36,9 +36,12 @@ import org.datavec.image.transform.ImageTransform;
import org.datavec.image.transform.PipelineImageTransform; import org.datavec.image.transform.PipelineImageTransform;
import org.datavec.image.transform.ResizeImageTransform; import org.datavec.image.transform.ResizeImageTransform;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing; import org.nd4j.linalg.indexing.BooleanIndexing;
@ -54,7 +57,8 @@ import java.util.Collections;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
@NativeTag
@Tag(TagNames.FILE_IO)
public class TestObjectDetectionRecordReader { public class TestObjectDetectionRecordReader {

View File

@ -22,10 +22,13 @@ package org.datavec.image.recordreader.objdetect;
import org.datavec.image.recordreader.objdetect.impl.VocLabelProvider; import org.datavec.image.recordreader.objdetect.impl.VocLabelProvider;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
import java.io.File; import java.io.File;
import java.nio.file.Path; import java.nio.file.Path;
@ -34,7 +37,8 @@ import java.util.Collections;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
public class TestVocLabelProvider { public class TestVocLabelProvider {

View File

@ -20,6 +20,7 @@
package org.datavec.image.transform; package org.datavec.image.transform;
import org.datavec.image.data.ImageWritable; import org.datavec.image.data.ImageWritable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import java.io.IOException; import java.io.IOException;
import java.util.List; import java.util.List;
@ -28,8 +29,13 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
@DisplayName("Json Yaml Test") @DisplayName("Json Yaml Test")
@NativeTag
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JACKSON_SERDE)
class JsonYamlTest { class JsonYamlTest {
@Test @Test

View File

@ -22,12 +22,17 @@ package org.datavec.image.transform;
import org.bytedeco.javacv.Frame; import org.bytedeco.javacv.Frame;
import org.datavec.image.data.ImageWritable; import org.datavec.image.data.ImageWritable;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
@DisplayName("Resize Image Transform Test") @DisplayName("Resize Image Transform Test")
@NativeTag
@Tag(TagNames.FILE_IO)
class ResizeImageTransformTest { class ResizeImageTransformTest {
@BeforeEach @BeforeEach

View File

@ -24,6 +24,7 @@ import org.bytedeco.javacpp.indexer.UByteIndexer;
import org.bytedeco.javacv.CanvasFrame; import org.bytedeco.javacv.CanvasFrame;
import org.bytedeco.javacv.Frame; import org.bytedeco.javacv.Frame;
import org.bytedeco.javacv.OpenCVFrameConverter; import org.bytedeco.javacv.OpenCVFrameConverter;
import org.junit.jupiter.api.Tag;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.primitives.Pair; import org.nd4j.common.primitives.Pair;
import org.datavec.image.data.ImageWritable; import org.datavec.image.data.ImageWritable;
@ -37,6 +38,8 @@ import java.util.List;
import java.util.Random; import java.util.Random;
import org.bytedeco.opencv.opencv_core.*; import org.bytedeco.opencv.opencv_core.*;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
import static org.bytedeco.opencv.global.opencv_core.*; import static org.bytedeco.opencv.global.opencv_core.*;
import static org.bytedeco.opencv.global.opencv_imgproc.*; import static org.bytedeco.opencv.global.opencv_imgproc.*;
@ -46,6 +49,8 @@ import static org.junit.jupiter.api.Assertions.*;
* *
* @author saudet * @author saudet
*/ */
@NativeTag
@Tag(TagNames.FILE_IO)
public class TestImageTransform { public class TestImageTransform {
static final long seed = 10; static final long seed = 10;
static final Random rng = new Random(seed); static final Random rng = new Random(seed);

View File

@ -22,6 +22,7 @@ package org.datavec.poi.excel;
import org.datavec.api.records.reader.RecordReader; import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.FileSplit; import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import java.util.List; import java.util.List;
@ -29,8 +30,12 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
@DisplayName("Excel Record Reader Test") @DisplayName("Excel Record Reader Test")
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
class ExcelRecordReaderTest { class ExcelRecordReaderTest {
@Test @Test

View File

@ -26,6 +26,7 @@ import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.primitives.Triple; import org.nd4j.common.primitives.Triple;
@ -36,8 +37,11 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path; import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.TagNames;
@DisplayName("Excel Record Writer Test") @DisplayName("Excel Record Writer Test")
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
class ExcelRecordWriterTest { class ExcelRecordWriterTest {
@TempDir @TempDir

View File

@ -47,17 +47,19 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.LongWritable; import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.*;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path; import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.TagNames;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
@DisplayName("Jdbc Record Reader Test") @DisplayName("Jdbc Record Reader Test")
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
class JDBCRecordReaderTest { class JDBCRecordReaderTest {
@TempDir @TempDir

View File

@ -36,15 +36,18 @@ import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.tests.tags.TagNames;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
public class LocalTransformProcessRecordReaderTests { public class LocalTransformProcessRecordReaderTests {
@Test @Test
@ -64,11 +67,11 @@ public class LocalTransformProcessRecordReaderTests {
public void simpleTransformTestSequence() { public void simpleTransformTestSequence() {
List<List<Writable>> sequence = new ArrayList<>(); List<List<Writable>> sequence = new ArrayList<>();
//First window: //First window:
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L), new IntWritable(0), sequence.add(Arrays.asList(new LongWritable(1451606400000L), new IntWritable(0),
new IntWritable(0))); new IntWritable(0)));
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 100L), new IntWritable(1), sequence.add(Arrays.asList(new LongWritable(1451606400000L + 100L), new IntWritable(1),
new IntWritable(0))); new IntWritable(0)));
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 200L), new IntWritable(2), sequence.add(Arrays.asList(new LongWritable(1451606400000L + 200L), new IntWritable(2),
new IntWritable(0))); new IntWritable(0)));
Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC) Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC)

View File

@ -30,8 +30,10 @@ import org.datavec.api.util.ndarray.RecordConverter;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.local.transforms.AnalyzeLocal; import org.datavec.local.transforms.AnalyzeLocal;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
@ -40,7 +42,8 @@ import java.util.ArrayList;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
public class TestAnalyzeLocal { public class TestAnalyzeLocal {

View File

@ -27,8 +27,10 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit; import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.tests.tags.TagNames;
import java.io.File; import java.io.File;
import java.util.HashSet; import java.util.HashSet;
@ -38,7 +40,8 @@ import java.util.stream.Collectors;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
public class TestLineRecordReaderFunction { public class TestLineRecordReaderFunction {
@Test @Test

View File

@ -25,7 +25,10 @@ import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.local.transforms.misc.NDArrayToWritablesFunction; import org.datavec.local.transforms.misc.NDArrayToWritablesFunction;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -34,7 +37,8 @@ import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.FILE_IO)
@NativeTag
public class TestNDArrayToWritablesFunction { public class TestNDArrayToWritablesFunction {
@Test @Test
@ -50,7 +54,7 @@ public class TestNDArrayToWritablesFunction {
@Test @Test
public void testNDArrayToWritablesArray() throws Exception { public void testNDArrayToWritablesArray() throws Exception {
INDArray arr = Nd4j.arange(5); INDArray arr = Nd4j.arange(5);
List<Writable> expected = Arrays.asList((Writable) new NDArrayWritable(arr)); List<Writable> expected = Arrays.asList(new NDArrayWritable(arr));
List<Writable> actual = new NDArrayToWritablesFunction(true).apply(arr); List<Writable> actual = new NDArrayToWritablesFunction(true).apply(arr);
assertEquals(expected, actual); assertEquals(expected, actual);
} }

View File

@ -25,7 +25,10 @@ import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.local.transforms.misc.WritablesToNDArrayFunction; import org.datavec.local.transforms.misc.WritablesToNDArrayFunction;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -34,7 +37,8 @@ import java.util.ArrayList;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.FILE_IO)
@NativeTag
public class TestWritablesToNDArrayFunction { public class TestWritablesToNDArrayFunction {
@Test @Test

View File

@ -30,13 +30,16 @@ import org.datavec.api.writable.Writable;
import org.datavec.local.transforms.misc.SequenceWritablesToStringFunction; import org.datavec.local.transforms.misc.SequenceWritablesToStringFunction;
import org.datavec.local.transforms.misc.WritablesToStringFunction; import org.datavec.local.transforms.misc.WritablesToStringFunction;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.TagNames;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
public class TestWritablesToStringFunctions { public class TestWritablesToStringFunctions {
@ -44,7 +47,7 @@ public class TestWritablesToStringFunctions {
@Test @Test
public void testWritablesToString() throws Exception { public void testWritablesToString() throws Exception {
List<Writable> l = Arrays.<Writable>asList(new DoubleWritable(1.5), new Text("someValue")); List<Writable> l = Arrays.asList(new DoubleWritable(1.5), new Text("someValue"));
String expected = l.get(0).toString() + "," + l.get(1).toString(); String expected = l.get(0).toString() + "," + l.get(1).toString();
assertEquals(expected, new WritablesToStringFunction(",").apply(l)); assertEquals(expected, new WritablesToStringFunction(",").apply(l));
@ -53,8 +56,8 @@ public class TestWritablesToStringFunctions {
@Test @Test
public void testSequenceWritablesToString() throws Exception { public void testSequenceWritablesToString() throws Exception {
List<List<Writable>> l = Arrays.asList(Arrays.<Writable>asList(new DoubleWritable(1.5), new Text("someValue")), List<List<Writable>> l = Arrays.asList(Arrays.asList(new DoubleWritable(1.5), new Text("someValue")),
Arrays.<Writable>asList(new DoubleWritable(2.5), new Text("otherValue"))); Arrays.asList(new DoubleWritable(2.5), new Text("otherValue")));
String expected = l.get(0).get(0).toString() + "," + l.get(0).get(1).toString() + "\n" String expected = l.get(0).get(0).toString() + "," + l.get(0).get(1).toString() + "\n"
+ l.get(1).get(0).toString() + "," + l.get(1).get(1).toString(); + l.get(1).get(0).toString() + "," + l.get(1).get(1).toString();

View File

@ -31,7 +31,10 @@ import org.datavec.api.transform.schema.SequenceSchema;
import org.datavec.api.writable.*; import org.datavec.api.writable.*;
import org.datavec.local.transforms.LocalTransformExecutor; import org.datavec.local.transforms.LocalTransformExecutor;
import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.linalg.ops.transforms.Transforms;
@ -42,6 +45,8 @@ import static java.time.Duration.ofMillis;
import static org.junit.jupiter.api.Assertions.assertTimeout; import static org.junit.jupiter.api.Assertions.assertTimeout;
@DisplayName("Execution Test") @DisplayName("Execution Test")
@Tag(TagNames.FILE_IO)
@NativeTag
class ExecutionTest { class ExecutionTest {
@Test @Test
@ -71,18 +76,12 @@ class ExecutionTest {
Schema schema = new Schema.Builder().addColumnInteger("col0").addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").addColumnFloat("col3").build(); Schema schema = new Schema.Builder().addColumnInteger("col0").addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").addColumnFloat("col3").build();
TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1").doubleMathOp("col2", MathOp.Add, 10.0).floatMathOp("col3", MathOp.Add, 5f).build(); TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1").doubleMathOp("col2", MathOp.Add, 10.0).floatMathOp("col3", MathOp.Add, 5f).build();
List<List<Writable>> inputData = new ArrayList<>(); List<List<Writable>> inputData = new ArrayList<>();
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1), new FloatWritable(0.3f))); inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1), new FloatWritable(0.3f)));
inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1), new FloatWritable(1.7f))); inputData.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1), new FloatWritable(1.7f)));
inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1), new FloatWritable(3.6f))); inputData.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1), new FloatWritable(3.6f)));
List<List<Writable>> rdd = (inputData); List<List<Writable>> rdd = (inputData);
List<List<Writable>> out = new ArrayList<>(LocalTransformExecutor.execute(rdd, tp)); List<List<Writable>> out = new ArrayList<>(LocalTransformExecutor.execute(rdd, tp));
Collections.sort(out, new Comparator<List<Writable>>() { Collections.sort(out, (o1, o2) -> Integer.compare(o1.get(0).toInt(), o2.get(0).toInt()));
@Override
public int compare(List<Writable> o1, List<Writable> o2) {
return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
}
});
List<List<Writable>> expected = new ArrayList<>(); List<List<Writable>> expected = new ArrayList<>();
expected.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1), new FloatWritable(5.3f))); expected.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1), new FloatWritable(5.3f)));
expected.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1), new FloatWritable(6.7f))); expected.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1), new FloatWritable(6.7f)));
@ -95,9 +94,9 @@ class ExecutionTest {
void testFilter() { void testFilter() {
Schema filterSchema = new Schema.Builder().addColumnDouble("col1").addColumnDouble("col2").addColumnDouble("col3").build(); Schema filterSchema = new Schema.Builder().addColumnDouble("col1").addColumnDouble("col2").addColumnDouble("col3").build();
List<List<Writable>> inputData = new ArrayList<>(); List<List<Writable>> inputData = new ArrayList<>();
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new DoubleWritable(1), new DoubleWritable(0.1))); inputData.add(Arrays.asList(new IntWritable(0), new DoubleWritable(1), new DoubleWritable(0.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(1), new DoubleWritable(3), new DoubleWritable(1.1))); inputData.add(Arrays.asList(new IntWritable(1), new DoubleWritable(3), new DoubleWritable(1.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(2), new DoubleWritable(3), new DoubleWritable(2.1))); inputData.add(Arrays.asList(new IntWritable(2), new DoubleWritable(3), new DoubleWritable(2.1)));
TransformProcess transformProcess = new TransformProcess.Builder(filterSchema).filter(new DoubleColumnCondition("col1", ConditionOp.LessThan, 1)).build(); TransformProcess transformProcess = new TransformProcess.Builder(filterSchema).filter(new DoubleColumnCondition("col1", ConditionOp.LessThan, 1)).build();
List<List<Writable>> execute = LocalTransformExecutor.execute(inputData, transformProcess); List<List<Writable>> execute = LocalTransformExecutor.execute(inputData, transformProcess);
assertEquals(2, execute.size()); assertEquals(2, execute.size());
@ -110,31 +109,25 @@ class ExecutionTest {
TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1").doubleMathOp("col2", MathOp.Add, 10.0).build(); TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1").doubleMathOp("col2", MathOp.Add, 10.0).build();
List<List<List<Writable>>> inputSequences = new ArrayList<>(); List<List<List<Writable>>> inputSequences = new ArrayList<>();
List<List<Writable>> seq1 = new ArrayList<>(); List<List<Writable>> seq1 = new ArrayList<>();
seq1.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1))); seq1.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
seq1.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1))); seq1.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
seq1.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1))); seq1.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
List<List<Writable>> seq2 = new ArrayList<>(); List<List<Writable>> seq2 = new ArrayList<>();
seq2.add(Arrays.<Writable>asList(new IntWritable(3), new Text("state0"), new DoubleWritable(3.1))); seq2.add(Arrays.asList(new IntWritable(3), new Text("state0"), new DoubleWritable(3.1)));
seq2.add(Arrays.<Writable>asList(new IntWritable(4), new Text("state1"), new DoubleWritable(4.1))); seq2.add(Arrays.asList(new IntWritable(4), new Text("state1"), new DoubleWritable(4.1)));
inputSequences.add(seq1); inputSequences.add(seq1);
inputSequences.add(seq2); inputSequences.add(seq2);
List<List<List<Writable>>> rdd = (inputSequences); List<List<List<Writable>>> rdd = (inputSequences);
List<List<List<Writable>>> out = LocalTransformExecutor.executeSequenceToSequence(rdd, tp); List<List<List<Writable>>> out = LocalTransformExecutor.executeSequenceToSequence(rdd, tp);
Collections.sort(out, new Comparator<List<List<Writable>>>() { Collections.sort(out, (o1, o2) -> -Integer.compare(o1.size(), o2.size()));
@Override
public int compare(List<List<Writable>> o1, List<List<Writable>> o2) {
return -Integer.compare(o1.size(), o2.size());
}
});
List<List<List<Writable>>> expectedSequence = new ArrayList<>(); List<List<List<Writable>>> expectedSequence = new ArrayList<>();
List<List<Writable>> seq1e = new ArrayList<>(); List<List<Writable>> seq1e = new ArrayList<>();
seq1e.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1))); seq1e.add(Arrays.asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1)));
seq1e.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1))); seq1e.add(Arrays.asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1)));
seq1e.add(Arrays.<Writable>asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1))); seq1e.add(Arrays.asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1)));
List<List<Writable>> seq2e = new ArrayList<>(); List<List<Writable>> seq2e = new ArrayList<>();
seq2e.add(Arrays.<Writable>asList(new IntWritable(3), new IntWritable(0), new DoubleWritable(13.1))); seq2e.add(Arrays.asList(new IntWritable(3), new IntWritable(0), new DoubleWritable(13.1)));
seq2e.add(Arrays.<Writable>asList(new IntWritable(4), new IntWritable(1), new DoubleWritable(14.1))); seq2e.add(Arrays.asList(new IntWritable(4), new IntWritable(1), new DoubleWritable(14.1)));
expectedSequence.add(seq1e); expectedSequence.add(seq1e);
expectedSequence.add(seq2e); expectedSequence.add(seq2e);
assertEquals(expectedSequence, out); assertEquals(expectedSequence, out);
@ -143,26 +136,26 @@ class ExecutionTest {
@Test @Test
@DisplayName("Test Reduction Global") @DisplayName("Test Reduction Global")
void testReductionGlobal() { void testReductionGlobal() {
List<List<Writable>> in = Arrays.asList(Arrays.<Writable>asList(new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new Text("second"), new DoubleWritable(5.0))); List<List<Writable>> in = Arrays.asList(Arrays.asList(new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new Text("second"), new DoubleWritable(5.0)));
List<List<Writable>> inData = in; List<List<Writable>> inData = in;
Schema s = new Schema.Builder().addColumnString("textCol").addColumnDouble("doubleCol").build(); Schema s = new Schema.Builder().addColumnString("textCol").addColumnDouble("doubleCol").build();
TransformProcess tp = new TransformProcess.Builder(s).reduce(new Reducer.Builder(ReduceOp.TakeFirst).takeFirstColumns("textCol").meanColumns("doubleCol").build()).build(); TransformProcess tp = new TransformProcess.Builder(s).reduce(new Reducer.Builder(ReduceOp.TakeFirst).takeFirstColumns("textCol").meanColumns("doubleCol").build()).build();
List<List<Writable>> outRdd = LocalTransformExecutor.execute(inData, tp); List<List<Writable>> outRdd = LocalTransformExecutor.execute(inData, tp);
List<List<Writable>> out = outRdd; List<List<Writable>> out = outRdd;
List<List<Writable>> expOut = Collections.singletonList(Arrays.<Writable>asList(new Text("first"), new DoubleWritable(4.0))); List<List<Writable>> expOut = Collections.singletonList(Arrays.asList(new Text("first"), new DoubleWritable(4.0)));
assertEquals(expOut, out); assertEquals(expOut, out);
} }
@Test @Test
@DisplayName("Test Reduction By Key") @DisplayName("Test Reduction By Key")
void testReductionByKey() { void testReductionByKey() {
List<List<Writable>> in = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new IntWritable(0), new Text("second"), new DoubleWritable(5.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(30.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("s"), new DoubleWritable(50.0))); List<List<Writable>> in = Arrays.asList(Arrays.asList(new IntWritable(0), new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new IntWritable(0), new Text("second"), new DoubleWritable(5.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(30.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("s"), new DoubleWritable(50.0)));
List<List<Writable>> inData = in; List<List<Writable>> inData = in;
Schema s = new Schema.Builder().addColumnInteger("intCol").addColumnString("textCol").addColumnDouble("doubleCol").build(); Schema s = new Schema.Builder().addColumnInteger("intCol").addColumnString("textCol").addColumnDouble("doubleCol").build();
TransformProcess tp = new TransformProcess.Builder(s).reduce(new Reducer.Builder(ReduceOp.TakeFirst).keyColumns("intCol").takeFirstColumns("textCol").meanColumns("doubleCol").build()).build(); TransformProcess tp = new TransformProcess.Builder(s).reduce(new Reducer.Builder(ReduceOp.TakeFirst).keyColumns("intCol").takeFirstColumns("textCol").meanColumns("doubleCol").build()).build();
List<List<Writable>> outRdd = LocalTransformExecutor.execute(inData, tp); List<List<Writable>> outRdd = LocalTransformExecutor.execute(inData, tp);
List<List<Writable>> out = outRdd; List<List<Writable>> out = outRdd;
List<List<Writable>> expOut = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0))); List<List<Writable>> expOut = Arrays.asList(Arrays.asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0)));
out = new ArrayList<>(out); out = new ArrayList<>(out);
Collections.sort(out, Comparator.comparingInt(o -> o.get(0).toInt())); Collections.sort(out, Comparator.comparingInt(o -> o.get(0).toInt()));
assertEquals(expOut, out); assertEquals(expOut, out);

View File

@ -28,12 +28,15 @@ import org.datavec.api.writable.*;
import org.datavec.local.transforms.LocalTransformExecutor; import org.datavec.local.transforms.LocalTransformExecutor;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.TagNames;
import java.util.*; import java.util.*;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
public class TestJoin { public class TestJoin {
@Test @Test
@ -46,27 +49,27 @@ public class TestJoin {
.addColumnDouble("amount").build(); .addColumnDouble("amount").build();
List<List<Writable>> infoList = new ArrayList<>(); List<List<Writable>> infoList = new ArrayList<>();
infoList.add(Arrays.<Writable>asList(new LongWritable(12345), new Text("Customer12345"))); infoList.add(Arrays.asList(new LongWritable(12345), new Text("Customer12345")));
infoList.add(Arrays.<Writable>asList(new LongWritable(98765), new Text("Customer98765"))); infoList.add(Arrays.asList(new LongWritable(98765), new Text("Customer98765")));
infoList.add(Arrays.<Writable>asList(new LongWritable(50000), new Text("Customer50000"))); infoList.add(Arrays.asList(new LongWritable(50000), new Text("Customer50000")));
List<List<Writable>> purchaseList = new ArrayList<>(); List<List<Writable>> purchaseList = new ArrayList<>();
purchaseList.add(Arrays.<Writable>asList(new LongWritable(1000000), new LongWritable(12345), purchaseList.add(Arrays.asList(new LongWritable(1000000), new LongWritable(12345),
new DoubleWritable(10.00))); new DoubleWritable(10.00)));
purchaseList.add(Arrays.<Writable>asList(new LongWritable(1000001), new LongWritable(12345), purchaseList.add(Arrays.asList(new LongWritable(1000001), new LongWritable(12345),
new DoubleWritable(20.00))); new DoubleWritable(20.00)));
purchaseList.add(Arrays.<Writable>asList(new LongWritable(1000002), new LongWritable(98765), purchaseList.add(Arrays.asList(new LongWritable(1000002), new LongWritable(98765),
new DoubleWritable(30.00))); new DoubleWritable(30.00)));
Join join = new Join.Builder(Join.JoinType.RightOuter).setJoinColumns("customerID") Join join = new Join.Builder(Join.JoinType.RightOuter).setJoinColumns("customerID")
.setSchemas(customerInfoSchema, purchasesSchema).build(); .setSchemas(customerInfoSchema, purchasesSchema).build();
List<List<Writable>> expected = new ArrayList<>(); List<List<Writable>> expected = new ArrayList<>();
expected.add(Arrays.<Writable>asList(new LongWritable(12345), new Text("Customer12345"), expected.add(Arrays.asList(new LongWritable(12345), new Text("Customer12345"),
new LongWritable(1000000), new DoubleWritable(10.00))); new LongWritable(1000000), new DoubleWritable(10.00)));
expected.add(Arrays.<Writable>asList(new LongWritable(12345), new Text("Customer12345"), expected.add(Arrays.asList(new LongWritable(12345), new Text("Customer12345"),
new LongWritable(1000001), new DoubleWritable(20.00))); new LongWritable(1000001), new DoubleWritable(20.00)));
expected.add(Arrays.<Writable>asList(new LongWritable(98765), new Text("Customer98765"), expected.add(Arrays.asList(new LongWritable(98765), new Text("Customer98765"),
new LongWritable(1000002), new DoubleWritable(30.00))); new LongWritable(1000002), new DoubleWritable(30.00)));
@ -77,12 +80,7 @@ public class TestJoin {
List<List<Writable>> joined = LocalTransformExecutor.executeJoin(join, info, purchases); List<List<Writable>> joined = LocalTransformExecutor.executeJoin(join, info, purchases);
List<List<Writable>> joinedList = new ArrayList<>(joined); List<List<Writable>> joinedList = new ArrayList<>(joined);
//Sort by order ID (column 3, index 2) //Sort by order ID (column 3, index 2)
Collections.sort(joinedList, new Comparator<List<Writable>>() { Collections.sort(joinedList, (o1, o2) -> Long.compare(o1.get(2).toLong(), o2.get(2).toLong()));
@Override
public int compare(List<Writable> o1, List<Writable> o2) {
return Long.compare(o1.get(2).toLong(), o2.get(2).toLong());
}
});
assertEquals(expected, joinedList); assertEquals(expected, joinedList);
assertEquals(3, joinedList.size()); assertEquals(3, joinedList.size());
@ -110,12 +108,7 @@ public class TestJoin {
List<List<Writable>> joined2 = LocalTransformExecutor.executeJoin(join2, purchases, info); List<List<Writable>> joined2 = LocalTransformExecutor.executeJoin(join2, purchases, info);
List<List<Writable>> joinedList2 = new ArrayList<>(joined2); List<List<Writable>> joinedList2 = new ArrayList<>(joined2);
//Sort by order ID (column 0) //Sort by order ID (column 0)
Collections.sort(joinedList2, new Comparator<List<Writable>>() { Collections.sort(joinedList2, (o1, o2) -> Long.compare(o1.get(0).toLong(), o2.get(0).toLong()));
@Override
public int compare(List<Writable> o1, List<Writable> o2) {
return Long.compare(o1.get(0).toLong(), o2.get(0).toLong());
}
});
assertEquals(3, joinedList2.size()); assertEquals(3, joinedList2.size());
assertEquals(expectedManyToOne, joinedList2); assertEquals(expectedManyToOne, joinedList2);
@ -189,29 +182,26 @@ public class TestJoin {
new ArrayList<>(LocalTransformExecutor.executeJoin(join, firstRDD, secondRDD)); new ArrayList<>(LocalTransformExecutor.executeJoin(join, firstRDD, secondRDD));
//Sort output by column 0, then column 1, then column 2 for comparison to expected... //Sort output by column 0, then column 1, then column 2 for comparison to expected...
Collections.sort(out, new Comparator<List<Writable>>() { Collections.sort(out, (o1, o2) -> {
@Override Writable w1 = o1.get(0);
public int compare(List<Writable> o1, List<Writable> o2) { Writable w2 = o2.get(0);
Writable w1 = o1.get(0); if (w1 instanceof NullWritable)
Writable w2 = o2.get(0); return 1;
if (w1 instanceof NullWritable) else if (w2 instanceof NullWritable)
return 1; return -1;
else if (w2 instanceof NullWritable) int c = Long.compare(w1.toLong(), w2.toLong());
return -1; if (c != 0)
int c = Long.compare(w1.toLong(), w2.toLong()); return c;
if (c != 0) c = o1.get(1).toString().compareTo(o2.get(1).toString());
return c; if (c != 0)
c = o1.get(1).toString().compareTo(o2.get(1).toString()); return c;
if (c != 0) w1 = o1.get(2);
return c; w2 = o2.get(2);
w1 = o1.get(2); if (w1 instanceof NullWritable)
w2 = o2.get(2); return 1;
if (w1 instanceof NullWritable) else if (w2 instanceof NullWritable)
return 1; return -1;
else if (w2 instanceof NullWritable) return Long.compare(w1.toLong(), w2.toLong());
return -1;
return Long.compare(w1.toLong(), w2.toLong());
}
}); });
switch (jt) { switch (jt) {

View File

@ -31,14 +31,17 @@ import org.datavec.api.writable.comparator.DoubleWritableComparator;
import org.datavec.local.transforms.LocalTransformExecutor; import org.datavec.local.transforms.LocalTransformExecutor;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.TagNames;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
public class TestCalculateSortedRank { public class TestCalculateSortedRank {
@Test @Test

View File

@ -31,7 +31,9 @@ import org.datavec.api.writable.Writable;
import org.datavec.arrow.recordreader.ArrowWritableRecordTimeSeriesBatch; import org.datavec.arrow.recordreader.ArrowWritableRecordTimeSeriesBatch;
import org.datavec.local.transforms.LocalTransformExecutor; import org.datavec.local.transforms.LocalTransformExecutor;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.TagNames;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
@ -39,7 +41,8 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
public class TestConvertToSequence { public class TestConvertToSequence {
@Test @Test
@ -48,12 +51,12 @@ public class TestConvertToSequence {
Schema s = new Schema.Builder().addColumnsString("key1", "key2").addColumnLong("time").build(); Schema s = new Schema.Builder().addColumnsString("key1", "key2").addColumnLong("time").build();
List<List<Writable>> allExamples = List<List<Writable>> allExamples =
Arrays.asList(Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), new LongWritable(10)), Arrays.asList(Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(10)),
Arrays.<Writable>asList(new Text("k1b"), new Text("k2b"), new LongWritable(10)), Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(10)),
Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), Arrays.asList(new Text("k1a"), new Text("k2a"),
new LongWritable(-10)), new LongWritable(-10)),
Arrays.<Writable>asList(new Text("k1b"), new Text("k2b"), new LongWritable(5)), Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(5)),
Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), new LongWritable(0))); Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(0)));
TransformProcess tp = new TransformProcess.Builder(s) TransformProcess tp = new TransformProcess.Builder(s)
.convertToSequence(Arrays.asList("key1", "key2"), new NumericalColumnComparator("time")) .convertToSequence(Arrays.asList("key1", "key2"), new NumericalColumnComparator("time"))
@ -75,13 +78,13 @@ public class TestConvertToSequence {
} }
List<List<Writable>> expSeq0 = Arrays.asList( List<List<Writable>> expSeq0 = Arrays.asList(
Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), new LongWritable(-10)), Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(-10)),
Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), new LongWritable(0)), Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(0)),
Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), new LongWritable(10))); Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(10)));
List<List<Writable>> expSeq1 = Arrays.asList( List<List<Writable>> expSeq1 = Arrays.asList(
Arrays.<Writable>asList(new Text("k1b"), new Text("k2b"), new LongWritable(5)), Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(5)),
Arrays.<Writable>asList(new Text("k1b"), new Text("k2b"), new LongWritable(10))); Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(10)));
assertEquals(expSeq0, seq0); assertEquals(expSeq0, seq0);
assertEquals(expSeq1, seq1); assertEquals(expSeq1, seq1);
@ -96,9 +99,9 @@ public class TestConvertToSequence {
.build(); .build();
List<List<Writable>> allExamples = Arrays.asList( List<List<Writable>> allExamples = Arrays.asList(
Arrays.<Writable>asList(new Text("a"), new LongWritable(0)), Arrays.asList(new Text("a"), new LongWritable(0)),
Arrays.<Writable>asList(new Text("b"), new LongWritable(1)), Arrays.asList(new Text("b"), new LongWritable(1)),
Arrays.<Writable>asList(new Text("c"), new LongWritable(2))); Arrays.asList(new Text("c"), new LongWritable(2)));
TransformProcess tp = new TransformProcess.Builder(s) TransformProcess tp = new TransformProcess.Builder(s)
.convertToSequence() .convertToSequence()

View File

@ -25,8 +25,10 @@ import org.apache.spark.serializer.SerializerInstance;
import org.datavec.api.records.reader.RecordReader; import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit; import org.datavec.api.split.FileSplit;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.tests.tags.TagNames;
import java.io.File; import java.io.File;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
@ -34,7 +36,10 @@ import java.nio.ByteBuffer;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.SPARK)
@Tag(TagNames.DIST_SYSTEMS)
public class TestKryoSerialization extends BaseSparkTest { public class TestKryoSerialization extends BaseSparkTest {
@Override @Override

View File

@ -27,8 +27,10 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit; import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.spark.BaseSparkTest; import org.datavec.spark.BaseSparkTest;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.tests.tags.TagNames;
import java.io.File; import java.io.File;
import java.util.HashSet; import java.util.HashSet;
@ -37,7 +39,10 @@ import java.util.Set;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.SPARK)
@Tag(TagNames.DIST_SYSTEMS)
public class TestLineRecordReaderFunction extends BaseSparkTest { public class TestLineRecordReaderFunction extends BaseSparkTest {
@Test @Test

View File

@ -24,7 +24,10 @@ import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.NDArrayWritable; import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.spark.transform.misc.NDArrayToWritablesFunction; import org.datavec.spark.transform.misc.NDArrayToWritablesFunction;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -33,7 +36,10 @@ import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.SPARK)
@Tag(TagNames.DIST_SYSTEMS)
@NativeTag
public class TestNDArrayToWritablesFunction { public class TestNDArrayToWritablesFunction {
@Test @Test

View File

@ -39,10 +39,12 @@ import org.datavec.spark.functions.pairdata.PathToKeyConverter;
import org.datavec.spark.functions.pairdata.PathToKeyConverterFilename; import org.datavec.spark.functions.pairdata.PathToKeyConverterFilename;
import org.datavec.spark.util.DataVecSparkUtil; import org.datavec.spark.util.DataVecSparkUtil;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.tests.tags.TagNames;
import scala.Tuple2; import scala.Tuple2;
import java.io.File; import java.io.File;
@ -53,7 +55,10 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.api.Assertions.fail;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.SPARK)
@Tag(TagNames.DIST_SYSTEMS)
public class TestPairSequenceRecordReaderBytesFunction extends BaseSparkTest { public class TestPairSequenceRecordReaderBytesFunction extends BaseSparkTest {
@Test @Test

View File

@ -37,10 +37,12 @@ import org.datavec.spark.BaseSparkTest;
import org.datavec.spark.functions.data.FilesAsBytesFunction; import org.datavec.spark.functions.data.FilesAsBytesFunction;
import org.datavec.spark.functions.data.RecordReaderBytesFunction; import org.datavec.spark.functions.data.RecordReaderBytesFunction;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.tests.tags.TagNames;
import java.io.File; import java.io.File;
import java.nio.file.Files; import java.nio.file.Files;
@ -51,7 +53,10 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.api.Assertions.fail;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.SPARK)
@Tag(TagNames.DIST_SYSTEMS)
public class TestRecordReaderBytesFunction extends BaseSparkTest { public class TestRecordReaderBytesFunction extends BaseSparkTest {

View File

@ -32,10 +32,12 @@ import org.datavec.api.writable.Writable;
import org.datavec.image.recordreader.ImageRecordReader; import org.datavec.image.recordreader.ImageRecordReader;
import org.datavec.spark.BaseSparkTest; import org.datavec.spark.BaseSparkTest;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.tests.tags.TagNames;
import java.io.File; import java.io.File;
import java.nio.file.Path; import java.nio.file.Path;
@ -45,7 +47,10 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.api.Assertions.fail;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.SPARK)
@Tag(TagNames.DIST_SYSTEMS)
public class TestRecordReaderFunction extends BaseSparkTest { public class TestRecordReaderFunction extends BaseSparkTest {
@Test @Test

View File

@ -37,10 +37,12 @@ import org.datavec.spark.BaseSparkTest;
import org.datavec.spark.functions.data.FilesAsBytesFunction; import org.datavec.spark.functions.data.FilesAsBytesFunction;
import org.datavec.spark.functions.data.SequenceRecordReaderBytesFunction; import org.datavec.spark.functions.data.SequenceRecordReaderBytesFunction;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.tests.tags.TagNames;
import java.io.File; import java.io.File;
import java.nio.file.Files; import java.nio.file.Files;
@ -50,7 +52,10 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.api.Assertions.fail;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.SPARK)
@Tag(TagNames.DIST_SYSTEMS)
public class TestSequenceRecordReaderBytesFunction extends BaseSparkTest { public class TestSequenceRecordReaderBytesFunction extends BaseSparkTest {

View File

@ -34,10 +34,12 @@ import org.datavec.api.writable.Writable;
import org.datavec.codec.reader.CodecRecordReader; import org.datavec.codec.reader.CodecRecordReader;
import org.datavec.spark.BaseSparkTest; import org.datavec.spark.BaseSparkTest;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.tests.tags.TagNames;
import java.io.File; import java.io.File;
import java.nio.file.Path; import java.nio.file.Path;
@ -46,7 +48,10 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.api.Assertions.fail;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.SPARK)
@Tag(TagNames.DIST_SYSTEMS)
public class TestSequenceRecordReaderFunction extends BaseSparkTest { public class TestSequenceRecordReaderFunction extends BaseSparkTest {

View File

@ -22,7 +22,10 @@ package org.datavec.spark.functions;
import org.datavec.api.writable.*; import org.datavec.api.writable.*;
import org.datavec.spark.transform.misc.WritablesToNDArrayFunction; import org.datavec.spark.transform.misc.WritablesToNDArrayFunction;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -31,7 +34,10 @@ import java.util.ArrayList;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.SPARK)
@Tag(TagNames.DIST_SYSTEMS)
@NativeTag
public class TestWritablesToNDArrayFunction { public class TestWritablesToNDArrayFunction {
@Test @Test

View File

@ -29,7 +29,9 @@ import org.datavec.api.writable.Writable;
import org.datavec.spark.BaseSparkTest; import org.datavec.spark.BaseSparkTest;
import org.datavec.spark.transform.misc.SequenceWritablesToStringFunction; import org.datavec.spark.transform.misc.SequenceWritablesToStringFunction;
import org.datavec.spark.transform.misc.WritablesToStringFunction; import org.datavec.spark.transform.misc.WritablesToStringFunction;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.TagNames;
import scala.Tuple2; import scala.Tuple2;
import java.util.ArrayList; import java.util.ArrayList;
@ -37,7 +39,10 @@ import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.SPARK)
@Tag(TagNames.DIST_SYSTEMS)
public class TestWritablesToStringFunctions extends BaseSparkTest { public class TestWritablesToStringFunctions extends BaseSparkTest {
@Test @Test
@ -57,19 +62,9 @@ public class TestWritablesToStringFunctions extends BaseSparkTest {
JavaSparkContext sc = getContext(); JavaSparkContext sc = getContext();
JavaPairRDD<String, String> left = sc.parallelize(leftMap).mapToPair(new PairFunction<Tuple2<String, String>, String, String>() { JavaPairRDD<String, String> left = sc.parallelize(leftMap).mapToPair((PairFunction<Tuple2<String, String>, String, String>) stringStringTuple2 -> stringStringTuple2);
@Override
public Tuple2<String, String> call(Tuple2<String, String> stringStringTuple2) throws Exception {
return stringStringTuple2;
}
});
JavaPairRDD<String, String> right = sc.parallelize(rightMap).mapToPair(new PairFunction<Tuple2<String, String>, String, String>() { JavaPairRDD<String, String> right = sc.parallelize(rightMap).mapToPair((PairFunction<Tuple2<String, String>, String, String>) stringStringTuple2 -> stringStringTuple2);
@Override
public Tuple2<String, String> call(Tuple2<String, String> stringStringTuple2) throws Exception {
return stringStringTuple2;
}
});
System.out.println(left.cogroup(right).collect()); System.out.println(left.cogroup(right).collect());
} }
@ -77,7 +72,7 @@ public class TestWritablesToStringFunctions extends BaseSparkTest {
@Test @Test
public void testWritablesToString() throws Exception { public void testWritablesToString() throws Exception {
List<Writable> l = Arrays.<Writable>asList(new DoubleWritable(1.5), new Text("someValue")); List<Writable> l = Arrays.asList(new DoubleWritable(1.5), new Text("someValue"));
String expected = l.get(0).toString() + "," + l.get(1).toString(); String expected = l.get(0).toString() + "," + l.get(1).toString();
assertEquals(expected, new WritablesToStringFunction(",").call(l)); assertEquals(expected, new WritablesToStringFunction(",").call(l));
@ -86,8 +81,8 @@ public class TestWritablesToStringFunctions extends BaseSparkTest {
@Test @Test
public void testSequenceWritablesToString() throws Exception { public void testSequenceWritablesToString() throws Exception {
List<List<Writable>> l = Arrays.asList(Arrays.<Writable>asList(new DoubleWritable(1.5), new Text("someValue")), List<List<Writable>> l = Arrays.asList(Arrays.asList(new DoubleWritable(1.5), new Text("someValue")),
Arrays.<Writable>asList(new DoubleWritable(2.5), new Text("otherValue"))); Arrays.asList(new DoubleWritable(2.5), new Text("otherValue")));
String expected = l.get(0).get(0).toString() + "," + l.get(0).get(1).toString() + "\n" String expected = l.get(0).get(0).toString() + "," + l.get(0).get(1).toString() + "\n"
+ l.get(1).get(0).toString() + "," + l.get(1).get(1).toString(); + l.get(1).get(0).toString() + "," + l.get(1).get(1).toString();

View File

@ -21,6 +21,8 @@
package org.datavec.spark.storage; package org.datavec.spark.storage;
import com.sun.jna.Platform; import com.sun.jna.Platform;
import org.junit.jupiter.api.Tag;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.shade.guava.io.Files; import org.nd4j.shade.guava.io.Files;
import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaRDD;
@ -37,7 +39,10 @@ import java.util.Map;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.SPARK)
@Tag(TagNames.DIST_SYSTEMS)
public class TestSparkStorageUtils extends BaseSparkTest { public class TestSparkStorageUtils extends BaseSparkTest {
@Test @Test
@ -46,11 +51,11 @@ public class TestSparkStorageUtils extends BaseSparkTest {
return; return;
} }
List<List<Writable>> l = new ArrayList<>(); List<List<Writable>> l = new ArrayList<>();
l.add(Arrays.<org.datavec.api.writable.Writable>asList(new Text("zero"), new IntWritable(0), l.add(Arrays.asList(new Text("zero"), new IntWritable(0),
new DoubleWritable(0), new NDArrayWritable(Nd4j.valueArrayOf(10, 0.0)))); new DoubleWritable(0), new NDArrayWritable(Nd4j.valueArrayOf(10, 0.0))));
l.add(Arrays.<org.datavec.api.writable.Writable>asList(new Text("one"), new IntWritable(11), l.add(Arrays.asList(new Text("one"), new IntWritable(11),
new DoubleWritable(11.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 11.0)))); new DoubleWritable(11.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 11.0))));
l.add(Arrays.<org.datavec.api.writable.Writable>asList(new Text("two"), new IntWritable(22), l.add(Arrays.asList(new Text("two"), new IntWritable(22),
new DoubleWritable(22.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 22.0)))); new DoubleWritable(22.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 22.0))));
JavaRDD<List<Writable>> rdd = sc.parallelize(l); JavaRDD<List<Writable>> rdd = sc.parallelize(l);
@ -92,17 +97,17 @@ public class TestSparkStorageUtils extends BaseSparkTest {
} }
List<List<List<Writable>>> l = new ArrayList<>(); List<List<List<Writable>>> l = new ArrayList<>();
l.add(Arrays.asList( l.add(Arrays.asList(
Arrays.<org.datavec.api.writable.Writable>asList(new Text("zero"), new IntWritable(0), Arrays.asList(new Text("zero"), new IntWritable(0),
new DoubleWritable(0), new NDArrayWritable(Nd4j.valueArrayOf(10, 0.0))), new DoubleWritable(0), new NDArrayWritable(Nd4j.valueArrayOf(10, 0.0))),
Arrays.<org.datavec.api.writable.Writable>asList(new Text("one"), new IntWritable(1), Arrays.asList(new Text("one"), new IntWritable(1),
new DoubleWritable(1.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 1.0))), new DoubleWritable(1.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 1.0))),
Arrays.<org.datavec.api.writable.Writable>asList(new Text("two"), new IntWritable(2), Arrays.asList(new Text("two"), new IntWritable(2),
new DoubleWritable(2.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 2.0))))); new DoubleWritable(2.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 2.0)))));
l.add(Arrays.asList( l.add(Arrays.asList(
Arrays.<org.datavec.api.writable.Writable>asList(new Text("Bzero"), new IntWritable(10), Arrays.asList(new Text("Bzero"), new IntWritable(10),
new DoubleWritable(10), new NDArrayWritable(Nd4j.valueArrayOf(10, 10.0))), new DoubleWritable(10), new NDArrayWritable(Nd4j.valueArrayOf(10, 10.0))),
Arrays.<org.datavec.api.writable.Writable>asList(new Text("Bone"), new IntWritable(11), Arrays.asList(new Text("Bone"), new IntWritable(11),
new DoubleWritable(11.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 11.0))), new DoubleWritable(11.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 11.0))),
Arrays.<org.datavec.api.writable.Writable>asList(new Text("Btwo"), new IntWritable(12), Arrays.<org.datavec.api.writable.Writable>asList(new Text("Btwo"), new IntWritable(12),
new DoubleWritable(12.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 12.0))))); new DoubleWritable(12.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 12.0)))));

View File

@ -30,14 +30,19 @@ import org.datavec.api.util.ndarray.RecordConverter;
import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.spark.BaseSparkTest; import org.datavec.spark.BaseSparkTest;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import java.util.*; import java.util.*;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.SPARK)
@Tag(TagNames.DIST_SYSTEMS)
public class DataFramesTests extends BaseSparkTest { public class DataFramesTests extends BaseSparkTest {
@Test @Test
@ -110,15 +115,15 @@ public class DataFramesTests extends BaseSparkTest {
public void testNormalize() { public void testNormalize() {
List<List<Writable>> data = new ArrayList<>(); List<List<Writable>> data = new ArrayList<>();
data.add(Arrays.<Writable>asList(new DoubleWritable(1), new DoubleWritable(10))); data.add(Arrays.asList(new DoubleWritable(1), new DoubleWritable(10)));
data.add(Arrays.<Writable>asList(new DoubleWritable(2), new DoubleWritable(20))); data.add(Arrays.asList(new DoubleWritable(2), new DoubleWritable(20)));
data.add(Arrays.<Writable>asList(new DoubleWritable(3), new DoubleWritable(30))); data.add(Arrays.asList(new DoubleWritable(3), new DoubleWritable(30)));
List<List<Writable>> expMinMax = new ArrayList<>(); List<List<Writable>> expMinMax = new ArrayList<>();
expMinMax.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new DoubleWritable(0.0))); expMinMax.add(Arrays.asList(new DoubleWritable(0.0), new DoubleWritable(0.0)));
expMinMax.add(Arrays.<Writable>asList(new DoubleWritable(0.5), new DoubleWritable(0.5))); expMinMax.add(Arrays.asList(new DoubleWritable(0.5), new DoubleWritable(0.5)));
expMinMax.add(Arrays.<Writable>asList(new DoubleWritable(1.0), new DoubleWritable(1.0))); expMinMax.add(Arrays.asList(new DoubleWritable(1.0), new DoubleWritable(1.0)));
double m1 = (1 + 2 + 3) / 3.0; double m1 = (1 + 2 + 3) / 3.0;
double s1 = new StandardDeviation().evaluate(new double[] {1, 2, 3}); double s1 = new StandardDeviation().evaluate(new double[] {1, 2, 3});
@ -127,11 +132,11 @@ public class DataFramesTests extends BaseSparkTest {
List<List<Writable>> expStandardize = new ArrayList<>(); List<List<Writable>> expStandardize = new ArrayList<>();
expStandardize.add( expStandardize.add(
Arrays.<Writable>asList(new DoubleWritable((1 - m1) / s1), new DoubleWritable((10 - m2) / s2))); Arrays.asList(new DoubleWritable((1 - m1) / s1), new DoubleWritable((10 - m2) / s2)));
expStandardize.add( expStandardize.add(
Arrays.<Writable>asList(new DoubleWritable((2 - m1) / s1), new DoubleWritable((20 - m2) / s2))); Arrays.asList(new DoubleWritable((2 - m1) / s1), new DoubleWritable((20 - m2) / s2)));
expStandardize.add( expStandardize.add(
Arrays.<Writable>asList(new DoubleWritable((3 - m1) / s1), new DoubleWritable((30 - m2) / s2))); Arrays.asList(new DoubleWritable((3 - m1) / s1), new DoubleWritable((30 - m2) / s2)));
JavaRDD<List<Writable>> rdd = sc.parallelize(data); JavaRDD<List<Writable>> rdd = sc.parallelize(data);
@ -178,13 +183,13 @@ public class DataFramesTests extends BaseSparkTest {
List<List<List<Writable>>> sequences = new ArrayList<>(); List<List<List<Writable>>> sequences = new ArrayList<>();
List<List<Writable>> seq1 = new ArrayList<>(); List<List<Writable>> seq1 = new ArrayList<>();
seq1.add(Arrays.<Writable>asList(new DoubleWritable(1), new DoubleWritable(10), new DoubleWritable(100))); seq1.add(Arrays.asList(new DoubleWritable(1), new DoubleWritable(10), new DoubleWritable(100)));
seq1.add(Arrays.<Writable>asList(new DoubleWritable(2), new DoubleWritable(20), new DoubleWritable(200))); seq1.add(Arrays.asList(new DoubleWritable(2), new DoubleWritable(20), new DoubleWritable(200)));
seq1.add(Arrays.<Writable>asList(new DoubleWritable(3), new DoubleWritable(30), new DoubleWritable(300))); seq1.add(Arrays.asList(new DoubleWritable(3), new DoubleWritable(30), new DoubleWritable(300)));
List<List<Writable>> seq2 = new ArrayList<>(); List<List<Writable>> seq2 = new ArrayList<>();
seq2.add(Arrays.<Writable>asList(new DoubleWritable(4), new DoubleWritable(40), new DoubleWritable(400))); seq2.add(Arrays.asList(new DoubleWritable(4), new DoubleWritable(40), new DoubleWritable(400)));
seq2.add(Arrays.<Writable>asList(new DoubleWritable(5), new DoubleWritable(50), new DoubleWritable(500))); seq2.add(Arrays.asList(new DoubleWritable(5), new DoubleWritable(50), new DoubleWritable(500)));
sequences.add(seq1); sequences.add(seq1);
sequences.add(seq2); sequences.add(seq2);
@ -199,21 +204,21 @@ public class DataFramesTests extends BaseSparkTest {
//Min/max normalization: //Min/max normalization:
List<List<Writable>> expSeq1MinMax = new ArrayList<>(); List<List<Writable>> expSeq1MinMax = new ArrayList<>();
expSeq1MinMax.add(Arrays.<Writable>asList(new DoubleWritable((1 - 1.0) / (5.0 - 1.0)), expSeq1MinMax.add(Arrays.asList(new DoubleWritable((1 - 1.0) / (5.0 - 1.0)),
new DoubleWritable((10 - 10.0) / (50.0 - 10.0)), new DoubleWritable((10 - 10.0) / (50.0 - 10.0)),
new DoubleWritable((100 - 100.0) / (500.0 - 100.0)))); new DoubleWritable((100 - 100.0) / (500.0 - 100.0))));
expSeq1MinMax.add(Arrays.<Writable>asList(new DoubleWritable((2 - 1.0) / (5.0 - 1.0)), expSeq1MinMax.add(Arrays.asList(new DoubleWritable((2 - 1.0) / (5.0 - 1.0)),
new DoubleWritable((20 - 10.0) / (50.0 - 10.0)), new DoubleWritable((20 - 10.0) / (50.0 - 10.0)),
new DoubleWritable((200 - 100.0) / (500.0 - 100.0)))); new DoubleWritable((200 - 100.0) / (500.0 - 100.0))));
expSeq1MinMax.add(Arrays.<Writable>asList(new DoubleWritable((3 - 1.0) / (5.0 - 1.0)), expSeq1MinMax.add(Arrays.asList(new DoubleWritable((3 - 1.0) / (5.0 - 1.0)),
new DoubleWritable((30 - 10.0) / (50.0 - 10.0)), new DoubleWritable((30 - 10.0) / (50.0 - 10.0)),
new DoubleWritable((300 - 100.0) / (500.0 - 100.0)))); new DoubleWritable((300 - 100.0) / (500.0 - 100.0))));
List<List<Writable>> expSeq2MinMax = new ArrayList<>(); List<List<Writable>> expSeq2MinMax = new ArrayList<>();
expSeq2MinMax.add(Arrays.<Writable>asList(new DoubleWritable((4 - 1.0) / (5.0 - 1.0)), expSeq2MinMax.add(Arrays.asList(new DoubleWritable((4 - 1.0) / (5.0 - 1.0)),
new DoubleWritable((40 - 10.0) / (50.0 - 10.0)), new DoubleWritable((40 - 10.0) / (50.0 - 10.0)),
new DoubleWritable((400 - 100.0) / (500.0 - 100.0)))); new DoubleWritable((400 - 100.0) / (500.0 - 100.0))));
expSeq2MinMax.add(Arrays.<Writable>asList(new DoubleWritable((5 - 1.0) / (5.0 - 1.0)), expSeq2MinMax.add(Arrays.asList(new DoubleWritable((5 - 1.0) / (5.0 - 1.0)),
new DoubleWritable((50 - 10.0) / (50.0 - 10.0)), new DoubleWritable((50 - 10.0) / (50.0 - 10.0)),
new DoubleWritable((500 - 100.0) / (500.0 - 100.0)))); new DoubleWritable((500 - 100.0) / (500.0 - 100.0))));
@ -246,17 +251,17 @@ public class DataFramesTests extends BaseSparkTest {
double s3 = new StandardDeviation().evaluate(new double[] {100, 200, 300, 400, 500}); double s3 = new StandardDeviation().evaluate(new double[] {100, 200, 300, 400, 500});
List<List<Writable>> expSeq1Std = new ArrayList<>(); List<List<Writable>> expSeq1Std = new ArrayList<>();
expSeq1Std.add(Arrays.<Writable>asList(new DoubleWritable((1 - m1) / s1), new DoubleWritable((10 - m2) / s2), expSeq1Std.add(Arrays.asList(new DoubleWritable((1 - m1) / s1), new DoubleWritable((10 - m2) / s2),
new DoubleWritable((100 - m3) / s3))); new DoubleWritable((100 - m3) / s3)));
expSeq1Std.add(Arrays.<Writable>asList(new DoubleWritable((2 - m1) / s1), new DoubleWritable((20 - m2) / s2), expSeq1Std.add(Arrays.asList(new DoubleWritable((2 - m1) / s1), new DoubleWritable((20 - m2) / s2),
new DoubleWritable((200 - m3) / s3))); new DoubleWritable((200 - m3) / s3)));
expSeq1Std.add(Arrays.<Writable>asList(new DoubleWritable((3 - m1) / s1), new DoubleWritable((30 - m2) / s2), expSeq1Std.add(Arrays.asList(new DoubleWritable((3 - m1) / s1), new DoubleWritable((30 - m2) / s2),
new DoubleWritable((300 - m3) / s3))); new DoubleWritable((300 - m3) / s3)));
List<List<Writable>> expSeq2Std = new ArrayList<>(); List<List<Writable>> expSeq2Std = new ArrayList<>();
expSeq2Std.add(Arrays.<Writable>asList(new DoubleWritable((4 - m1) / s1), new DoubleWritable((40 - m2) / s2), expSeq2Std.add(Arrays.asList(new DoubleWritable((4 - m1) / s1), new DoubleWritable((40 - m2) / s2),
new DoubleWritable((400 - m3) / s3))); new DoubleWritable((400 - m3) / s3)));
expSeq2Std.add(Arrays.<Writable>asList(new DoubleWritable((5 - m1) / s1), new DoubleWritable((50 - m2) / s2), expSeq2Std.add(Arrays.asList(new DoubleWritable((5 - m1) / s1), new DoubleWritable((50 - m2) / s2),
new DoubleWritable((500 - m3) / s3))); new DoubleWritable((500 - m3) / s3)));

View File

@ -31,22 +31,24 @@ import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Text; import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.spark.BaseSparkTest;
import org.datavec.python.PythonTransform; import org.datavec.python.PythonTransform;
import org.datavec.spark.BaseSparkTest;
import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.TagNames;
import java.util.*;
import static java.time.Duration.ofMillis; import static java.time.Duration.ofMillis;
import static org.junit.jupiter.api.Assertions.assertTimeout; import static org.junit.jupiter.api.Assertions.*;
@DisplayName("Execution Test") @DisplayName("Execution Test")
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.SPARK)
@Tag(TagNames.DIST_SYSTEMS)
class ExecutionTest extends BaseSparkTest { class ExecutionTest extends BaseSparkTest {
@Test @Test
@ -55,22 +57,16 @@ class ExecutionTest extends BaseSparkTest {
Schema schema = new Schema.Builder().addColumnInteger("col0").addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").build(); Schema schema = new Schema.Builder().addColumnInteger("col0").addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").build();
TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1").doubleMathOp("col2", MathOp.Add, 10.0).build(); TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1").doubleMathOp("col2", MathOp.Add, 10.0).build();
List<List<Writable>> inputData = new ArrayList<>(); List<List<Writable>> inputData = new ArrayList<>();
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1))); inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1))); inputData.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1))); inputData.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
JavaRDD<List<Writable>> rdd = sc.parallelize(inputData); JavaRDD<List<Writable>> rdd = sc.parallelize(inputData);
List<List<Writable>> out = new ArrayList<>(SparkTransformExecutor.execute(rdd, tp).collect()); List<List<Writable>> out = new ArrayList<>(SparkTransformExecutor.execute(rdd, tp).collect());
Collections.sort(out, new Comparator<List<Writable>>() { Collections.sort(out, Comparator.comparingInt(o -> o.get(0).toInt()));
@Override
public int compare(List<Writable> o1, List<Writable> o2) {
return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
}
});
List<List<Writable>> expected = new ArrayList<>(); List<List<Writable>> expected = new ArrayList<>();
expected.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1))); expected.add(Arrays.asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1)));
expected.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1))); expected.add(Arrays.asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1)));
expected.add(Arrays.<Writable>asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1))); expected.add(Arrays.asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1)));
assertEquals(expected, out); assertEquals(expected, out);
} }
@ -81,31 +77,25 @@ class ExecutionTest extends BaseSparkTest {
TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1").doubleMathOp("col2", MathOp.Add, 10.0).build(); TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1").doubleMathOp("col2", MathOp.Add, 10.0).build();
List<List<List<Writable>>> inputSequences = new ArrayList<>(); List<List<List<Writable>>> inputSequences = new ArrayList<>();
List<List<Writable>> seq1 = new ArrayList<>(); List<List<Writable>> seq1 = new ArrayList<>();
seq1.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1))); seq1.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
seq1.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1))); seq1.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
seq1.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1))); seq1.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
List<List<Writable>> seq2 = new ArrayList<>(); List<List<Writable>> seq2 = new ArrayList<>();
seq2.add(Arrays.<Writable>asList(new IntWritable(3), new Text("state0"), new DoubleWritable(3.1))); seq2.add(Arrays.asList(new IntWritable(4), new Text("state1"), new DoubleWritable(4.1)));
seq2.add(Arrays.<Writable>asList(new IntWritable(4), new Text("state1"), new DoubleWritable(4.1))); seq2.add(Arrays.asList(new IntWritable(3), new Text("state0"), new DoubleWritable(3.1)));
inputSequences.add(seq1); inputSequences.add(seq1);
inputSequences.add(seq2); inputSequences.add(seq2);
JavaRDD<List<List<Writable>>> rdd = sc.parallelize(inputSequences); JavaRDD<List<List<Writable>>> rdd = sc.parallelize(inputSequences);
List<List<List<Writable>>> out = new ArrayList<>(SparkTransformExecutor.executeSequenceToSequence(rdd, tp).collect()); List<List<List<Writable>>> out = new ArrayList<>(SparkTransformExecutor.executeSequenceToSequence(rdd, tp).collect());
Collections.sort(out, new Comparator<List<List<Writable>>>() { Collections.sort(out, (o1, o2) -> -Integer.compare(o1.size(), o2.size()));
@Override
public int compare(List<List<Writable>> o1, List<List<Writable>> o2) {
return -Integer.compare(o1.size(), o2.size());
}
});
List<List<List<Writable>>> expectedSequence = new ArrayList<>(); List<List<List<Writable>>> expectedSequence = new ArrayList<>();
List<List<Writable>> seq1e = new ArrayList<>(); List<List<Writable>> seq1e = new ArrayList<>();
seq1e.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1))); seq1e.add(Arrays.asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1)));
seq1e.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1))); seq1e.add(Arrays.asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1)));
seq1e.add(Arrays.<Writable>asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1))); seq1e.add(Arrays.asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1)));
List<List<Writable>> seq2e = new ArrayList<>(); List<List<Writable>> seq2e = new ArrayList<>();
seq2e.add(Arrays.<Writable>asList(new IntWritable(3), new IntWritable(0), new DoubleWritable(13.1))); seq2e.add(Arrays.asList(new IntWritable(3), new IntWritable(0), new DoubleWritable(13.1)));
seq2e.add(Arrays.<Writable>asList(new IntWritable(4), new IntWritable(1), new DoubleWritable(14.1))); seq2e.add(Arrays.asList(new IntWritable(4), new IntWritable(1), new DoubleWritable(14.1)));
expectedSequence.add(seq1e); expectedSequence.add(seq1e);
expectedSequence.add(seq2e); expectedSequence.add(seq2e);
assertEquals(expectedSequence, out); assertEquals(expectedSequence, out);
@ -114,34 +104,28 @@ class ExecutionTest extends BaseSparkTest {
@Test @Test
@DisplayName("Test Reduction Global") @DisplayName("Test Reduction Global")
void testReductionGlobal() { void testReductionGlobal() {
List<List<Writable>> in = Arrays.asList(Arrays.<Writable>asList(new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new Text("second"), new DoubleWritable(5.0))); List<List<Writable>> in = Arrays.asList(Arrays.asList(new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new Text("second"), new DoubleWritable(5.0)));
JavaRDD<List<Writable>> inData = sc.parallelize(in); JavaRDD<List<Writable>> inData = sc.parallelize(in);
Schema s = new Schema.Builder().addColumnString("textCol").addColumnDouble("doubleCol").build(); Schema s = new Schema.Builder().addColumnString("textCol").addColumnDouble("doubleCol").build();
TransformProcess tp = new TransformProcess.Builder(s).reduce(new Reducer.Builder(ReduceOp.TakeFirst).takeFirstColumns("textCol").meanColumns("doubleCol").build()).build(); TransformProcess tp = new TransformProcess.Builder(s).reduce(new Reducer.Builder(ReduceOp.TakeFirst).takeFirstColumns("textCol").meanColumns("doubleCol").build()).build();
JavaRDD<List<Writable>> outRdd = SparkTransformExecutor.execute(inData, tp); JavaRDD<List<Writable>> outRdd = SparkTransformExecutor.execute(inData, tp);
List<List<Writable>> out = outRdd.collect(); List<List<Writable>> out = outRdd.collect();
List<List<Writable>> expOut = Collections.singletonList(Arrays.<Writable>asList(new Text("first"), new DoubleWritable(4.0))); List<List<Writable>> expOut = Collections.singletonList(Arrays.asList(new Text("first"), new DoubleWritable(4.0)));
assertEquals(expOut, out); assertEquals(expOut, out);
} }
@Test @Test
@DisplayName("Test Reduction By Key") @DisplayName("Test Reduction By Key")
void testReductionByKey() { void testReductionByKey() {
List<List<Writable>> in = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new IntWritable(0), new Text("second"), new DoubleWritable(5.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(30.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("s"), new DoubleWritable(50.0))); List<List<Writable>> in = Arrays.asList(Arrays.asList(new IntWritable(0), new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new IntWritable(0), new Text("second"), new DoubleWritable(5.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(30.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("s"), new DoubleWritable(50.0)));
JavaRDD<List<Writable>> inData = sc.parallelize(in); JavaRDD<List<Writable>> inData = sc.parallelize(in);
Schema s = new Schema.Builder().addColumnInteger("intCol").addColumnString("textCol").addColumnDouble("doubleCol").build(); Schema s = new Schema.Builder().addColumnInteger("intCol").addColumnString("textCol").addColumnDouble("doubleCol").build();
TransformProcess tp = new TransformProcess.Builder(s).reduce(new Reducer.Builder(ReduceOp.TakeFirst).keyColumns("intCol").takeFirstColumns("textCol").meanColumns("doubleCol").build()).build(); TransformProcess tp = new TransformProcess.Builder(s).reduce(new Reducer.Builder(ReduceOp.TakeFirst).keyColumns("intCol").takeFirstColumns("textCol").meanColumns("doubleCol").build()).build();
JavaRDD<List<Writable>> outRdd = SparkTransformExecutor.execute(inData, tp); JavaRDD<List<Writable>> outRdd = SparkTransformExecutor.execute(inData, tp);
List<List<Writable>> out = outRdd.collect(); List<List<Writable>> out = outRdd.collect();
List<List<Writable>> expOut = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0))); List<List<Writable>> expOut = Arrays.asList(Arrays.asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0)));
out = new ArrayList<>(out); out = new ArrayList<>(out);
Collections.sort(out, new Comparator<List<Writable>>() { Collections.sort(out, (o1, o2) -> Integer.compare(o1.get(0).toInt(), o2.get(0).toInt()));
@Override
public int compare(List<Writable> o1, List<Writable> o2) {
return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
}
});
assertEquals(expOut, out); assertEquals(expOut, out);
} }
@ -150,15 +134,15 @@ class ExecutionTest extends BaseSparkTest {
void testUniqueMultiCol() { void testUniqueMultiCol() {
Schema schema = new Schema.Builder().addColumnInteger("col0").addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").build(); Schema schema = new Schema.Builder().addColumnInteger("col0").addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").build();
List<List<Writable>> inputData = new ArrayList<>(); List<List<Writable>> inputData = new ArrayList<>();
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1))); inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1))); inputData.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1))); inputData.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1))); inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1))); inputData.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1))); inputData.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1))); inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1))); inputData.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1))); inputData.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
JavaRDD<List<Writable>> rdd = sc.parallelize(inputData); JavaRDD<List<Writable>> rdd = sc.parallelize(inputData);
Map<String, List<Writable>> l = AnalyzeSpark.getUnique(Arrays.asList("col0", "col1"), schema, rdd); Map<String, List<Writable>> l = AnalyzeSpark.getUnique(Arrays.asList("col0", "col1"), schema, rdd);
assertEquals(2, l.size()); assertEquals(2, l.size());
@ -180,58 +164,20 @@ class ExecutionTest extends BaseSparkTest {
String pythonCode = "col1 = ['state0', 'state1', 'state2'].index(col1)\ncol2 += 10.0"; String pythonCode = "col1 = ['state0', 'state1', 'state2'].index(col1)\ncol2 += 10.0";
TransformProcess tp = new TransformProcess.Builder(schema).transform(PythonTransform.builder().code("first = np.sin(first)\nsecond = np.cos(second)").outputSchema(finalSchema).build()).build(); TransformProcess tp = new TransformProcess.Builder(schema).transform(PythonTransform.builder().code("first = np.sin(first)\nsecond = np.cos(second)").outputSchema(finalSchema).build()).build();
List<List<Writable>> inputData = new ArrayList<>(); List<List<Writable>> inputData = new ArrayList<>();
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1))); inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1))); inputData.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1))); inputData.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
JavaRDD<List<Writable>> rdd = sc.parallelize(inputData); JavaRDD<List<Writable>> rdd = sc.parallelize(inputData);
List<List<Writable>> out = new ArrayList<>(SparkTransformExecutor.execute(rdd, tp).collect()); List<List<Writable>> out = new ArrayList<>(SparkTransformExecutor.execute(rdd, tp).collect());
Collections.sort(out, new Comparator<List<Writable>>() { Collections.sort(out, Comparator.comparingInt(o -> o.get(0).toInt()));
@Override
public int compare(List<Writable> o1, List<Writable> o2) {
return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
}
});
List<List<Writable>> expected = new ArrayList<>(); List<List<Writable>> expected = new ArrayList<>();
expected.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1))); expected.add(Arrays.asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1)));
expected.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1))); expected.add(Arrays.asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1)));
expected.add(Arrays.<Writable>asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1))); expected.add(Arrays.asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1)));
assertEquals(expected, out); assertEquals(expected, out);
}); });
} }
@Test
@Disabled("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771")
@DisplayName("Test Python Execution With ND Arrays")
void testPythonExecutionWithNDArrays() {
assertTimeout(ofMillis(60000), () -> {
long[] shape = new long[] { 3, 2 };
Schema schema = new Schema.Builder().addColumnInteger("id").addColumnNDArray("col1", shape).addColumnNDArray("col2", shape).build();
Schema finalSchema = new Schema.Builder().addColumnInteger("id").addColumnNDArray("col1", shape).addColumnNDArray("col2", shape).addColumnNDArray("col3", shape).build();
String pythonCode = "col3 = col1 + col2";
TransformProcess tp = new TransformProcess.Builder(schema).transform(PythonTransform.builder().code("first = np.sin(first)\nsecond = np.cos(second)").outputSchema(schema).build()).build();
INDArray zeros = Nd4j.zeros(shape);
INDArray ones = Nd4j.ones(shape);
INDArray twos = ones.add(ones);
List<List<Writable>> inputData = new ArrayList<>();
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new NDArrayWritable(zeros), new NDArrayWritable(zeros)));
inputData.add(Arrays.<Writable>asList(new IntWritable(1), new NDArrayWritable(zeros), new NDArrayWritable(ones)));
inputData.add(Arrays.<Writable>asList(new IntWritable(2), new NDArrayWritable(ones), new NDArrayWritable(ones)));
JavaRDD<List<Writable>> rdd = sc.parallelize(inputData);
List<List<Writable>> out = new ArrayList<>(SparkTransformExecutor.execute(rdd, tp).collect());
Collections.sort(out, new Comparator<List<Writable>>() {
@Override
public int compare(List<Writable> o1, List<Writable> o2) {
return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
}
});
List<List<Writable>> expected = new ArrayList<>();
expected.add(Arrays.<Writable>asList(new IntWritable(0), new NDArrayWritable(zeros), new NDArrayWritable(zeros), new NDArrayWritable(zeros)));
expected.add(Arrays.<Writable>asList(new IntWritable(1), new NDArrayWritable(zeros), new NDArrayWritable(ones), new NDArrayWritable(ones)));
expected.add(Arrays.<Writable>asList(new IntWritable(2), new NDArrayWritable(ones), new NDArrayWritable(ones), new NDArrayWritable(twos)));
});
}
@Test @Test
@DisplayName("Test First Digit Transform Benfords Law") @DisplayName("Test First Digit Transform Benfords Law")

View File

@ -28,7 +28,10 @@ import org.datavec.api.util.ndarray.RecordConverter;
import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.spark.BaseSparkTest; import org.datavec.spark.BaseSparkTest;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
@ -43,7 +46,10 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.SPARK)
@Tag(TagNames.DIST_SYSTEMS)
@NativeTag
public class NormalizationTests extends BaseSparkTest { public class NormalizationTests extends BaseSparkTest {

View File

@ -38,7 +38,9 @@ import org.datavec.local.transforms.AnalyzeLocal;
import org.datavec.spark.BaseSparkTest; import org.datavec.spark.BaseSparkTest;
import org.datavec.spark.transform.AnalyzeSpark; import org.datavec.spark.transform.AnalyzeSpark;
import org.joda.time.DateTimeZone; import org.joda.time.DateTimeZone;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.io.ClassPathResource;
@ -48,7 +50,10 @@ import java.nio.file.Files;
import java.util.*; import java.util.*;
import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.*;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.SPARK)
@Tag(TagNames.DIST_SYSTEMS)
public class TestAnalysis extends BaseSparkTest { public class TestAnalysis extends BaseSparkTest {
@Test @Test

View File

@ -27,12 +27,17 @@ import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.*; import org.datavec.api.writable.*;
import org.datavec.spark.BaseSparkTest; import org.datavec.spark.BaseSparkTest;
import org.datavec.spark.transform.SparkTransformExecutor; import org.datavec.spark.transform.SparkTransformExecutor;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.TagNames;
import java.util.*; import java.util.*;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.SPARK)
@Tag(TagNames.DIST_SYSTEMS)
public class TestJoin extends BaseSparkTest { public class TestJoin extends BaseSparkTest {
@Test @Test

View File

@ -30,24 +30,29 @@ import org.datavec.api.writable.Writable;
import org.datavec.api.writable.comparator.DoubleWritableComparator; import org.datavec.api.writable.comparator.DoubleWritableComparator;
import org.datavec.spark.BaseSparkTest; import org.datavec.spark.BaseSparkTest;
import org.datavec.spark.transform.SparkTransformExecutor; import org.datavec.spark.transform.SparkTransformExecutor;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.TagNames;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.SPARK)
@Tag(TagNames.DIST_SYSTEMS)
public class TestCalculateSortedRank extends BaseSparkTest { public class TestCalculateSortedRank extends BaseSparkTest {
@Test @Test
public void testCalculateSortedRank() { public void testCalculateSortedRank() {
List<List<Writable>> data = new ArrayList<>(); List<List<Writable>> data = new ArrayList<>();
data.add(Arrays.asList((Writable) new Text("0"), new DoubleWritable(0.0))); data.add(Arrays.asList(new Text("0"), new DoubleWritable(0.0)));
data.add(Arrays.asList((Writable) new Text("3"), new DoubleWritable(0.3))); data.add(Arrays.asList(new Text("3"), new DoubleWritable(0.3)));
data.add(Arrays.asList((Writable) new Text("2"), new DoubleWritable(0.2))); data.add(Arrays.asList(new Text("2"), new DoubleWritable(0.2)));
data.add(Arrays.asList((Writable) new Text("1"), new DoubleWritable(0.1))); data.add(Arrays.asList(new Text("1"), new DoubleWritable(0.1)));
JavaRDD<List<Writable>> rdd = sc.parallelize(data); JavaRDD<List<Writable>> rdd = sc.parallelize(data);

View File

@ -29,7 +29,9 @@ import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.spark.BaseSparkTest; import org.datavec.spark.BaseSparkTest;
import org.datavec.spark.transform.SparkTransformExecutor; import org.datavec.spark.transform.SparkTransformExecutor;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.TagNames;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
@ -37,7 +39,10 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
@Tag(TagNames.FILE_IO)
@Tag(TagNames.JAVA_ONLY)
@Tag(TagNames.SPARK)
@Tag(TagNames.DIST_SYSTEMS)
public class TestConvertToSequence extends BaseSparkTest { public class TestConvertToSequence extends BaseSparkTest {
@Test @Test
@ -45,13 +50,13 @@ public class TestConvertToSequence extends BaseSparkTest {
Schema s = new Schema.Builder().addColumnsString("key1", "key2").addColumnLong("time").build(); Schema s = new Schema.Builder().addColumnsString("key1", "key2").addColumnLong("time").build();
List<List<Writable>> allExamples = List<List<Writable>> allExamples;
Arrays.asList(Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), new LongWritable(10)), allExamples = Arrays.asList(Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(10)),
Arrays.<Writable>asList(new Text("k1b"), new Text("k2b"), new LongWritable(10)), Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(10)),
Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), Arrays.asList(new Text("k1a"), new Text("k2a"),
new LongWritable(-10)), new LongWritable(-10)),
Arrays.<Writable>asList(new Text("k1b"), new Text("k2b"), new LongWritable(5)), Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(5)),
Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), new LongWritable(0))); Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(0)));
TransformProcess tp = new TransformProcess.Builder(s) TransformProcess tp = new TransformProcess.Builder(s)
.convertToSequence(Arrays.asList("key1", "key2"), new NumericalColumnComparator("time")) .convertToSequence(Arrays.asList("key1", "key2"), new NumericalColumnComparator("time"))
@ -73,13 +78,13 @@ public class TestConvertToSequence extends BaseSparkTest {
} }
List<List<Writable>> expSeq0 = Arrays.asList( List<List<Writable>> expSeq0 = Arrays.asList(
Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), new LongWritable(-10)), Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(-10)),
Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), new LongWritable(0)), Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(0)),
Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), new LongWritable(10))); Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(10)));
List<List<Writable>> expSeq1 = Arrays.asList( List<List<Writable>> expSeq1 = Arrays.asList(
Arrays.<Writable>asList(new Text("k1b"), new Text("k2b"), new LongWritable(5)), Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(5)),
Arrays.<Writable>asList(new Text("k1b"), new Text("k2b"), new LongWritable(10))); Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(10)));
assertEquals(expSeq0, seq0); assertEquals(expSeq0, seq0);
assertEquals(expSeq1, seq1); assertEquals(expSeq1, seq1);
@ -94,9 +99,9 @@ public class TestConvertToSequence extends BaseSparkTest {
.build(); .build();
List<List<Writable>> allExamples = Arrays.asList( List<List<Writable>> allExamples = Arrays.asList(
Arrays.<Writable>asList(new Text("a"), new LongWritable(0)), Arrays.asList(new Text("a"), new LongWritable(0)),
Arrays.<Writable>asList(new Text("b"), new LongWritable(1)), Arrays.asList(new Text("b"), new LongWritable(1)),
Arrays.<Writable>asList(new Text("c"), new LongWritable(2))); Arrays.asList(new Text("c"), new LongWritable(2)));
TransformProcess tp = new TransformProcess.Builder(s) TransformProcess tp = new TransformProcess.Builder(s)
.convertToSequence() .convertToSequence()

View File

@ -28,7 +28,9 @@ import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable; import org.datavec.api.writable.Writable;
import org.datavec.spark.BaseSparkTest; import org.datavec.spark.BaseSparkTest;
import org.datavec.spark.transform.utils.SparkUtils; import org.datavec.spark.transform.utils.SparkUtils;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.TagNames;
import java.io.File; import java.io.File;
import java.io.FileInputStream; import java.io.FileInputStream;
@ -38,6 +40,9 @@ import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
@Tag(TagNames.SPARK)
@Tag(TagNames.FILE_IO)
@Tag(TagNames.DIST_SYSTEMS)
public class TestSparkUtil extends BaseSparkTest { public class TestSparkUtil extends BaseSparkTest {
@Test @Test
@ -46,8 +51,8 @@ public class TestSparkUtil extends BaseSparkTest {
return; return;
} }
List<List<Writable>> l = new ArrayList<>(); List<List<Writable>> l = new ArrayList<>();
l.add(Arrays.<Writable>asList(new Text("abc"), new DoubleWritable(2.0), new IntWritable(-1))); l.add(Arrays.asList(new Text("abc"), new DoubleWritable(2.0), new IntWritable(-1)));
l.add(Arrays.<Writable>asList(new Text("def"), new DoubleWritable(4.0), new IntWritable(-2))); l.add(Arrays.asList(new Text("def"), new DoubleWritable(4.0), new IntWritable(-2)));
File f = File.createTempFile("testSparkUtil", "txt"); File f = File.createTempFile("testSparkUtil", "txt");
f.deleteOnExit(); f.deleteOnExit();

View File

@ -27,8 +27,11 @@ import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.resources.Resources; import org.nd4j.common.resources.Resources;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -39,6 +42,8 @@ import java.nio.file.Files;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
@Disabled @Disabled
@NativeTag
@Tag(TagNames.RNG)
public class RandomTests extends BaseDL4JTest { public class RandomTests extends BaseDL4JTest {
@Test @Test

View File

@ -23,11 +23,11 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.base.MnistFetcher; import org.deeplearning4j.datasets.base.MnistFetcher;
import org.deeplearning4j.common.resources.DL4JResources; import org.deeplearning4j.common.resources.DL4JResources;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.*;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir; import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition; import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
@ -41,10 +41,13 @@ import java.util.Set;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
@DisplayName("Mnist Fetcher Test") @DisplayName("Mnist Fetcher Test")
@NativeTag
@Tag(TagNames.FILE_IO)
@Tag(TagNames.NDARRAY_ETL)
class MnistFetcherTest extends BaseDL4JTest { class MnistFetcherTest extends BaseDL4JTest {

View File

@ -23,8 +23,14 @@ package org.deeplearning4j.datasets;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.datasets.fetchers.Cifar10Fetcher; import org.deeplearning4j.datasets.fetchers.Cifar10Fetcher;
import org.deeplearning4j.datasets.fetchers.TinyImageNetFetcher; import org.deeplearning4j.datasets.fetchers.TinyImageNetFetcher;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
@NativeTag
@Tag(TagNames.FILE_IO)
@Tag(TagNames.NDARRAY_ETL)
public class TestDataSets extends BaseDL4JTest { public class TestDataSets extends BaseDL4JTest {
@Override @Override

View File

@ -20,8 +20,11 @@
package org.deeplearning4j.datasets.datavec; package org.deeplearning4j.datasets.datavec;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.shade.guava.io.Files; import org.nd4j.shade.guava.io.Files;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@ -76,6 +79,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows;
@Slf4j @Slf4j
@DisplayName("Record Reader Data Setiterator Test") @DisplayName("Record Reader Data Setiterator Test")
@Disabled @Disabled
@NativeTag
class RecordReaderDataSetiteratorTest extends BaseDL4JTest { class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
@Override @Override
@ -148,6 +152,7 @@ class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
@ParameterizedTest @ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@DisplayName("Test Sequence Record Reader") @DisplayName("Test Sequence Record Reader")
@Tag(TagNames.NDARRAY_INDEXING)
void testSequenceRecordReader(Nd4jBackend backend) throws Exception { void testSequenceRecordReader(Nd4jBackend backend) throws Exception {
File rootDir = temporaryFolder.toFile(); File rootDir = temporaryFolder.toFile();
// need to manually extract // need to manually extract

View File

@ -21,6 +21,8 @@ package org.deeplearning4j.datasets.datavec;
import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Tag;
import org.nd4j.common.tests.tags.TagNames;
import org.nd4j.shade.guava.io.Files; import org.nd4j.shade.guava.io.Files;
import org.apache.commons.io.FileUtils; import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils; import org.apache.commons.io.FilenameUtils;
@ -70,6 +72,7 @@ import org.junit.jupiter.api.extension.ExtendWith;
@DisplayName("Record Reader Multi Data Set Iterator Test") @DisplayName("Record Reader Multi Data Set Iterator Test")
@Disabled @Disabled
@Tag(TagNames.FILE_IO)
class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest { class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
@TempDir @TempDir

View File

@ -21,6 +21,7 @@ package org.deeplearning4j.datasets.fetchers;
import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.BaseDL4JTest;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import java.io.File; import java.io.File;
@ -28,11 +29,15 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assumptions.assumeTrue; import static org.junit.jupiter.api.Assumptions.assumeTrue;
import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.common.tests.tags.TagNames;
/** /**
* @author saudet * @author saudet
*/ */
@DisplayName("Svhn Data Fetcher Test") @DisplayName("Svhn Data Fetcher Test")
@Tag(TagNames.FILE_IO)
@NativeTag
class SvhnDataFetcherTest extends BaseDL4JTest { class SvhnDataFetcherTest extends BaseDL4JTest {
@Override @Override

View File

@ -26,6 +26,7 @@ import org.deeplearning4j.datasets.iterator.tools.VariableTimeseriesGenerator;
import org.deeplearning4j.nn.util.TestDataSetConsumer; import org.deeplearning4j.nn.util.TestDataSetConsumer;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.tags.NativeTag;
import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList; import java.util.ArrayList;
@ -40,6 +41,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows;
@Slf4j @Slf4j
@DisplayName("Async Data Set Iterator Test") @DisplayName("Async Data Set Iterator Test")
@NativeTag
class AsyncDataSetIteratorTest extends BaseDL4JTest { class AsyncDataSetIteratorTest extends BaseDL4JTest {
private ExistingDataSetIterator backIterator; private ExistingDataSetIterator backIterator;

Some files were not shown because too many files have changed in this diff Show More