Add tags for junit 5
parent
3c205548af
commit
5e8951cd8e
|
@ -12,7 +12,7 @@ DL4J was a junit 4 based code based for testing.
|
|||
It's now based on junit 5's jupiter API, which has support for [Tags](https://junit.org/junit5/docs/5.0.1/api/org/junit/jupiter/api/Tag.html).
|
||||
|
||||
DL4j's code base has a number of different kinds of tests that fall in to several categories:
|
||||
1. Long and flaky involving distributed systems (spark, parameter server)
|
||||
1. Long and flaky involving distributed systems (spark, parameter-server)
|
||||
2. Code that requires large downloads, but runs quickly
|
||||
3. Quick tests that test basic functionality
|
||||
4. Comprehensive integration tests that test several parts of a code base
|
||||
|
@ -38,8 +38,10 @@ A few kinds of tags exist:
|
|||
3. Distributed systems: spark, multi-threaded
|
||||
4. Functional cross-cutting concerns: multi module tests, similar functionality (excludes time based)
|
||||
5. Platform specific tests that can vary on different hardware: cpu, gpu
|
||||
6. JVM crash: Tests with native code can crash the JVM for tests. It's useful to be able to turn those off when debugging.: jvm-crash
|
||||
|
||||
6. JVM crash: (jvm-crash) Tests with native code can crash the JVM for tests. It's useful to be able to turn those off when debugging.: jvm-crash
|
||||
7. RNG: (rng) for RNG related tests
|
||||
8. Samediff:(samediff) samediff related tests
|
||||
9. Training related functionality
|
||||
|
||||
|
||||
## Consequences
|
||||
|
|
|
@ -26,6 +26,7 @@ import org.datavec.api.split.FileSplit;
|
|||
import org.datavec.api.writable.Text;
|
||||
import org.datavec.api.writable.Writable;
|
||||
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
|
@ -38,8 +39,11 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
|
|||
import org.junit.jupiter.api.DisplayName;
|
||||
import java.nio.file.Path;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
@DisplayName("Csv Line Sequence Record Reader Test")
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
class CSVLineSequenceRecordReaderTest extends BaseND4JTest {
|
||||
|
||||
@TempDir
|
||||
|
@ -54,8 +58,8 @@ class CSVLineSequenceRecordReaderTest extends BaseND4JTest {
|
|||
FileUtils.writeStringToFile(source, str, StandardCharsets.UTF_8);
|
||||
SequenceRecordReader rr = new CSVLineSequenceRecordReader();
|
||||
rr.initialize(new FileSplit(source));
|
||||
List<List<Writable>> exp0 = Arrays.asList(Collections.<Writable>singletonList(new Text("a")), Collections.<Writable>singletonList(new Text("b")), Collections.<Writable>singletonList(new Text("c")));
|
||||
List<List<Writable>> exp1 = Arrays.asList(Collections.<Writable>singletonList(new Text("1")), Collections.<Writable>singletonList(new Text("2")), Collections.<Writable>singletonList(new Text("3")), Collections.<Writable>singletonList(new Text("4")));
|
||||
List<List<Writable>> exp0 = Arrays.asList(Collections.singletonList(new Text("a")), Collections.singletonList(new Text("b")), Collections.<Writable>singletonList(new Text("c")));
|
||||
List<List<Writable>> exp1 = Arrays.asList(Collections.singletonList(new Text("1")), Collections.singletonList(new Text("2")), Collections.<Writable>singletonList(new Text("3")), Collections.<Writable>singletonList(new Text("4")));
|
||||
for (int i = 0; i < 3; i++) {
|
||||
int count = 0;
|
||||
while (rr.hasNext()) {
|
||||
|
|
|
@ -27,6 +27,7 @@ import org.datavec.api.writable.Text;
|
|||
import org.datavec.api.writable.Writable;
|
||||
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
|
@ -41,8 +42,11 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
|
|||
import org.junit.jupiter.api.DisplayName;
|
||||
import java.nio.file.Path;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
@DisplayName("Csv Multi Sequence Record Reader Test")
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
|
||||
|
||||
@TempDir
|
||||
|
|
|
@ -26,6 +26,7 @@ import org.datavec.api.records.reader.impl.csv.CSVNLinesSequenceRecordReader;
|
|||
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
|
||||
import org.datavec.api.split.FileSplit;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
|
@ -34,12 +35,15 @@ import java.util.List;
|
|||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
@DisplayName("Csvn Lines Sequence Record Reader Test")
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
class CSVNLinesSequenceRecordReaderTest extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
@DisplayName("Test CSVN Lines Sequence Record Reader")
|
||||
@DisplayName("Test CSV Lines Sequence Record Reader")
|
||||
void testCSVNLinesSequenceRecordReader() throws Exception {
|
||||
int nLinesPerSequence = 10;
|
||||
SequenceRecordReader seqRR = new CSVNLinesSequenceRecordReader(nLinesPerSequence);
|
||||
|
|
|
@ -34,9 +34,12 @@ import org.datavec.api.writable.IntWritable;
|
|||
import org.datavec.api.writable.Text;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Files;
|
||||
|
@ -50,6 +53,8 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
|
||||
|
||||
@DisplayName("Csv Record Reader Test")
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
class CSVRecordReaderTest extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
|
|
|
@ -27,6 +27,7 @@ import org.datavec.api.split.InputSplit;
|
|||
import org.datavec.api.split.NumberedFileInputSplit;
|
||||
import org.datavec.api.writable.Writable;
|
||||
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
|
@ -43,8 +44,11 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
|
|||
import org.junit.jupiter.api.DisplayName;
|
||||
import java.nio.file.Path;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
@DisplayName("Csv Sequence Record Reader Test")
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
class CSVSequenceRecordReaderTest extends BaseND4JTest {
|
||||
|
||||
@TempDir
|
||||
|
|
|
@ -24,6 +24,7 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
|
|||
import org.datavec.api.records.reader.impl.csv.CSVVariableSlidingWindowRecordReader;
|
||||
import org.datavec.api.split.FileSplit;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
|
@ -32,8 +33,11 @@ import java.util.List;
|
|||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
@DisplayName("Csv Variable Sliding Window Record Reader Test")
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
class CSVVariableSlidingWindowRecordReaderTest extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
|
|
|
@ -28,6 +28,7 @@ import org.datavec.api.records.reader.impl.filebatch.FileBatchRecordReader;
|
|||
import org.datavec.api.records.reader.impl.filebatch.FileBatchSequenceRecordReader;
|
||||
import org.datavec.api.writable.Writable;
|
||||
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
|
@ -42,9 +43,12 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
import org.junit.jupiter.api.DisplayName;
|
||||
import java.nio.file.Path;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
|
||||
@DisplayName("File Batch Record Reader Test")
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
public class FileBatchRecordReaderTest extends BaseND4JTest {
|
||||
@TempDir Path testDir;
|
||||
|
||||
|
|
|
@ -25,6 +25,7 @@ import org.datavec.api.split.CollectionInputSplit;
|
|||
import org.datavec.api.split.FileSplit;
|
||||
import org.datavec.api.split.InputSplit;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
|
@ -36,8 +37,11 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
|
|||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
@DisplayName("File Record Reader Test")
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
class FileRecordReaderTest extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
|
|
|
@ -28,10 +28,12 @@ import org.datavec.api.split.FileSplit;
|
|||
import org.datavec.api.writable.Text;
|
||||
import org.datavec.api.writable.Writable;
|
||||
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
import org.nd4j.shade.jackson.core.JsonFactory;
|
||||
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
||||
import java.io.File;
|
||||
|
@ -45,6 +47,8 @@ import java.nio.file.Path;
|
|||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
@DisplayName("Jackson Line Record Reader Test")
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
class JacksonLineRecordReaderTest extends BaseND4JTest {
|
||||
|
||||
@TempDir
|
||||
|
|
|
@ -31,10 +31,12 @@ import org.datavec.api.writable.IntWritable;
|
|||
import org.datavec.api.writable.Text;
|
||||
import org.datavec.api.writable.Writable;
|
||||
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
import org.nd4j.shade.jackson.core.JsonFactory;
|
||||
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
||||
import org.nd4j.shade.jackson.dataformat.xml.XmlFactory;
|
||||
|
@ -51,6 +53,8 @@ import java.nio.file.Path;
|
|||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
@DisplayName("Jackson Record Reader Test")
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
class JacksonRecordReaderTest extends BaseND4JTest {
|
||||
|
||||
@TempDir
|
||||
|
|
|
@ -26,6 +26,7 @@ import org.datavec.api.split.FileSplit;
|
|||
import org.datavec.api.writable.DoubleWritable;
|
||||
import org.datavec.api.writable.IntWritable;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
|
@ -35,9 +36,13 @@ import static org.datavec.api.records.reader.impl.misc.LibSvmRecordReader.*;
|
|||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
|
||||
@DisplayName("Lib Svm Record Reader Test")
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
class LibSvmRecordReaderTest extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
|
|
|
@ -30,6 +30,7 @@ import org.datavec.api.split.InputSplit;
|
|||
import org.datavec.api.split.InputStreamInputSplit;
|
||||
import org.datavec.api.writable.Writable;
|
||||
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
|
@ -47,8 +48,11 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
|
|||
import org.junit.jupiter.api.DisplayName;
|
||||
import java.nio.file.Path;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
@DisplayName("Line Reader Test")
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
class LineReaderTest extends BaseND4JTest {
|
||||
|
||||
|
||||
|
|
|
@ -33,6 +33,7 @@ import org.datavec.api.split.NumberedFileInputSplit;
|
|||
import org.datavec.api.writable.Text;
|
||||
import org.datavec.api.writable.Writable;
|
||||
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
|
@ -46,8 +47,11 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
|
|||
import org.junit.jupiter.api.DisplayName;
|
||||
import java.nio.file.Path;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
@DisplayName("Regex Record Reader Test")
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
class RegexRecordReaderTest extends BaseND4JTest {
|
||||
|
||||
@TempDir
|
||||
|
|
|
@ -26,6 +26,7 @@ import org.datavec.api.split.FileSplit;
|
|||
import org.datavec.api.writable.DoubleWritable;
|
||||
import org.datavec.api.writable.IntWritable;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
|
@ -35,9 +36,13 @@ import static org.datavec.api.records.reader.impl.misc.SVMLightRecordReader.*;
|
|||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
|
||||
@DisplayName("Svm Light Record Reader Test")
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
class SVMLightRecordReaderTest extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
|
|
|
@ -26,15 +26,18 @@ import org.datavec.api.records.reader.SequenceRecordReader;
|
|||
import org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader;
|
||||
import org.datavec.api.writable.IntWritable;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
public class TestCollectionRecordReaders extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
|
|
|
@ -23,12 +23,15 @@ package org.datavec.api.records.reader.impl;
|
|||
import org.datavec.api.records.reader.RecordReader;
|
||||
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
|
||||
import org.datavec.api.split.FileSplit;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
public class TestConcatenatingRecordReader extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
|
|
|
@ -37,9 +37,11 @@ import org.datavec.api.transform.TransformProcess;
|
|||
import org.datavec.api.transform.schema.Schema;
|
||||
import org.datavec.api.writable.Text;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
import org.nd4j.shade.jackson.core.JsonFactory;
|
||||
import org.nd4j.shade.jackson.databind.ObjectMapper;
|
||||
|
||||
|
@ -48,7 +50,8 @@ import java.util.ArrayList;
|
|||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
public class TestSerialization extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
|
|
|
@ -30,9 +30,11 @@ import org.datavec.api.writable.IntWritable;
|
|||
import org.datavec.api.writable.LongWritable;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.joda.time.DateTimeZone;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
|
@ -41,6 +43,8 @@ import java.util.List;
|
|||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
public class TransformProcessRecordReaderTests extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
|
@ -74,11 +78,11 @@ public class TransformProcessRecordReaderTests extends BaseND4JTest {
|
|||
public void simpleTransformTestSequence() {
|
||||
List<List<Writable>> sequence = new ArrayList<>();
|
||||
//First window:
|
||||
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L), new IntWritable(0),
|
||||
sequence.add(Arrays.asList(new LongWritable(1451606400000L), new IntWritable(0),
|
||||
new IntWritable(0)));
|
||||
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 100L), new IntWritable(1),
|
||||
sequence.add(Arrays.asList(new LongWritable(1451606400000L + 100L), new IntWritable(1),
|
||||
new IntWritable(0)));
|
||||
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 200L), new IntWritable(2),
|
||||
sequence.add(Arrays.asList(new LongWritable(1451606400000L + 200L), new IntWritable(2),
|
||||
new IntWritable(0)));
|
||||
|
||||
Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC)
|
||||
|
|
|
@ -26,6 +26,7 @@ import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
|
|||
import org.datavec.api.writable.Text;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import java.io.File;
|
||||
|
@ -34,8 +35,11 @@ import java.util.List;
|
|||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
@DisplayName("Csv Record Writer Test")
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
class CSVRecordWriterTest extends BaseND4JTest {
|
||||
|
||||
@BeforeEach
|
||||
|
|
|
@ -29,8 +29,10 @@ import org.datavec.api.writable.DoubleWritable;
|
|||
import org.datavec.api.writable.IntWritable;
|
||||
import org.datavec.api.writable.NDArrayWritable;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
|
@ -46,6 +48,8 @@ import org.junit.jupiter.api.extension.ExtendWith;
|
|||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
|
||||
@DisplayName("Lib Svm Record Writer Test")
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
class LibSvmRecordWriterTest extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
|
|
|
@ -26,8 +26,10 @@ import org.datavec.api.records.writer.impl.misc.SVMLightRecordWriter;
|
|||
import org.datavec.api.split.FileSplit;
|
||||
import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
|
||||
import org.datavec.api.writable.*;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
|
@ -43,6 +45,8 @@ import org.junit.jupiter.api.extension.ExtendWith;
|
|||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
|
||||
@DisplayName("Svm Light Record Writer Test")
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
class SVMLightRecordWriterTest extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
|
|
|
@ -20,7 +20,9 @@
|
|||
|
||||
package org.datavec.api.split;
|
||||
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
import org.nd4j.shade.guava.io.Files;
|
||||
import org.datavec.api.io.filters.BalancedPathFilter;
|
||||
import org.datavec.api.io.filters.RandomPathFilter;
|
||||
|
@ -42,6 +44,8 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
|
|||
*
|
||||
* @author saudet
|
||||
*/
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
public class InputSplitTests extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
|
|
|
@ -20,13 +20,16 @@
|
|||
|
||||
package org.datavec.api.split;
|
||||
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import java.net.URI;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
public class NumberedFileInputSplitTests extends BaseND4JTest {
|
||||
@Test
|
||||
public void testNumberedFileInputSplitBasic() {
|
||||
|
|
|
@ -26,11 +26,13 @@ import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
|
|||
import org.datavec.api.writable.Text;
|
||||
import org.datavec.api.writable.Writable;
|
||||
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.function.Function;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
|
@ -46,7 +48,8 @@ import java.util.Random;
|
|||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertNotEquals;
|
||||
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
public class TestStreamInputSplit extends BaseND4JTest {
|
||||
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
*/
|
||||
package org.datavec.api.split;
|
||||
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import java.net.URI;
|
||||
|
@ -28,11 +29,14 @@ import static java.util.Arrays.asList;
|
|||
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
/**
|
||||
* @author Ede Meijer
|
||||
*/
|
||||
@DisplayName("Transform Split Test")
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
class TransformSplitTest extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
|
|
|
@ -20,7 +20,9 @@
|
|||
|
||||
package org.datavec.api.split.parittion;
|
||||
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
import org.nd4j.shade.guava.io.Files;
|
||||
import org.datavec.api.conf.Configuration;
|
||||
import org.datavec.api.split.FileSplit;
|
||||
|
@ -33,7 +35,8 @@ import java.io.File;
|
|||
import java.io.OutputStream;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
public class PartitionerTests extends BaseND4JTest {
|
||||
@Test
|
||||
public void testRecordsPerFilePartition() {
|
||||
|
|
|
@ -29,13 +29,16 @@ import org.datavec.api.writable.DoubleWritable;
|
|||
import org.datavec.api.writable.IntWritable;
|
||||
import org.datavec.api.writable.Text;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
public class TestTransformProcess extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
|
|
|
@ -27,14 +27,17 @@ import org.datavec.api.transform.condition.string.StringRegexColumnCondition;
|
|||
import org.datavec.api.transform.schema.Schema;
|
||||
import org.datavec.api.transform.transform.TestTransforms;
|
||||
import org.datavec.api.writable.*;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
public class TestConditions extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
|
|
|
@ -27,8 +27,10 @@ import org.datavec.api.transform.schema.Schema;
|
|||
import org.datavec.api.writable.DoubleWritable;
|
||||
import org.datavec.api.writable.IntWritable;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
|
@ -38,7 +40,8 @@ import java.util.List;
|
|||
import static java.util.Arrays.asList;
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
public class TestFilters extends BaseND4JTest {
|
||||
|
||||
|
||||
|
|
|
@ -26,9 +26,11 @@ import org.datavec.api.writable.IntWritable;
|
|||
import org.datavec.api.writable.NullWritable;
|
||||
import org.datavec.api.writable.Text;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import java.nio.file.Path;
|
||||
import java.util.ArrayList;
|
||||
|
@ -37,7 +39,8 @@ import java.util.List;
|
|||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
public class TestJoin extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
|
|
|
@ -33,14 +33,17 @@ import org.datavec.api.transform.ops.IAggregableReduceOp;
|
|||
import org.datavec.api.transform.schema.Schema;
|
||||
import org.datavec.api.writable.*;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.fail;
|
||||
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
public class TestMultiOpReduce extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
|
|
|
@ -24,14 +24,17 @@ import org.datavec.api.transform.ops.IAggregableReduceOp;
|
|||
import org.datavec.api.transform.reduce.impl.GeographicMidpointReduction;
|
||||
import org.datavec.api.writable.Text;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
public class TestReductions extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
|
|
|
@ -22,11 +22,15 @@ package org.datavec.api.transform.schema;
|
|||
|
||||
import org.datavec.api.transform.metadata.ColumnMetaData;
|
||||
import org.joda.time.DateTimeZone;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
@Tag(TagNames.JACKSON_SERDE)
|
||||
public class TestJsonYaml extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
|
|
|
@ -21,11 +21,14 @@
|
|||
package org.datavec.api.transform.schema;
|
||||
|
||||
import org.datavec.api.transform.ColumnType;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
public class TestSchemaMethods extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
|
|
|
@ -33,8 +33,10 @@ import org.datavec.api.writable.LongWritable;
|
|||
import org.datavec.api.writable.NullWritable;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.joda.time.DateTimeZone;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
|
@ -42,7 +44,8 @@ import java.util.List;
|
|||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
public class TestReduceSequenceByWindowFunction extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
|
|
|
@ -27,8 +27,10 @@ import org.datavec.api.writable.LongWritable;
|
|||
import org.datavec.api.writable.Text;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.joda.time.DateTimeZone;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
|
@ -36,7 +38,8 @@ import java.util.List;
|
|||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
public class TestSequenceSplit extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
|
@ -46,13 +49,13 @@ public class TestSequenceSplit extends BaseND4JTest {
|
|||
.build();
|
||||
|
||||
List<List<Writable>> inputSequence = new ArrayList<>();
|
||||
inputSequence.add(Arrays.asList((Writable) new LongWritable(0), new Text("t0")));
|
||||
inputSequence.add(Arrays.asList((Writable) new LongWritable(1000), new Text("t1")));
|
||||
inputSequence.add(Arrays.asList(new LongWritable(0), new Text("t0")));
|
||||
inputSequence.add(Arrays.asList(new LongWritable(1000), new Text("t1")));
|
||||
//Second split: 74 seconds later
|
||||
inputSequence.add(Arrays.asList((Writable) new LongWritable(75000), new Text("t2")));
|
||||
inputSequence.add(Arrays.asList((Writable) new LongWritable(100000), new Text("t3")));
|
||||
inputSequence.add(Arrays.asList(new LongWritable(75000), new Text("t2")));
|
||||
inputSequence.add(Arrays.asList(new LongWritable(100000), new Text("t3")));
|
||||
//Third split: 1 minute and 1 milliseconds later
|
||||
inputSequence.add(Arrays.asList((Writable) new LongWritable(160001), new Text("t4")));
|
||||
inputSequence.add(Arrays.asList(new LongWritable(160001), new Text("t4")));
|
||||
|
||||
SequenceSplit seqSplit = new SequenceSplitTimeSeparation("time", 1, TimeUnit.MINUTES);
|
||||
seqSplit.setInputSchema(schema);
|
||||
|
@ -61,13 +64,13 @@ public class TestSequenceSplit extends BaseND4JTest {
|
|||
assertEquals(3, splits.size());
|
||||
|
||||
List<List<Writable>> exp0 = new ArrayList<>();
|
||||
exp0.add(Arrays.asList((Writable) new LongWritable(0), new Text("t0")));
|
||||
exp0.add(Arrays.asList((Writable) new LongWritable(1000), new Text("t1")));
|
||||
exp0.add(Arrays.asList(new LongWritable(0), new Text("t0")));
|
||||
exp0.add(Arrays.asList(new LongWritable(1000), new Text("t1")));
|
||||
List<List<Writable>> exp1 = new ArrayList<>();
|
||||
exp1.add(Arrays.asList((Writable) new LongWritable(75000), new Text("t2")));
|
||||
exp1.add(Arrays.asList((Writable) new LongWritable(100000), new Text("t3")));
|
||||
exp1.add(Arrays.asList(new LongWritable(75000), new Text("t2")));
|
||||
exp1.add(Arrays.asList(new LongWritable(100000), new Text("t3")));
|
||||
List<List<Writable>> exp2 = new ArrayList<>();
|
||||
exp2.add(Arrays.asList((Writable) new LongWritable(160001), new Text("t4")));
|
||||
exp2.add(Arrays.asList(new LongWritable(160001), new Text("t4")));
|
||||
|
||||
assertEquals(exp0, splits.get(0));
|
||||
assertEquals(exp1, splits.get(1));
|
||||
|
|
|
@ -29,8 +29,10 @@ import org.datavec.api.writable.IntWritable;
|
|||
import org.datavec.api.writable.LongWritable;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.joda.time.DateTimeZone;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
|
@ -38,7 +40,8 @@ import java.util.List;
|
|||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
public class TestWindowFunctions extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
|
@ -49,15 +52,15 @@ public class TestWindowFunctions extends BaseND4JTest {
|
|||
//Create some data.
|
||||
List<List<Writable>> sequence = new ArrayList<>();
|
||||
//First window:
|
||||
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L), new IntWritable(0)));
|
||||
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 100L), new IntWritable(1)));
|
||||
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 200L), new IntWritable(2)));
|
||||
sequence.add(Arrays.asList(new LongWritable(1451606400000L), new IntWritable(0)));
|
||||
sequence.add(Arrays.asList(new LongWritable(1451606400000L + 100L), new IntWritable(1)));
|
||||
sequence.add(Arrays.asList(new LongWritable(1451606400000L + 200L), new IntWritable(2)));
|
||||
//Second window:
|
||||
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 1000L), new IntWritable(3)));
|
||||
sequence.add(Arrays.asList(new LongWritable(1451606400000L + 1000L), new IntWritable(3)));
|
||||
//Third window: empty
|
||||
//Fourth window:
|
||||
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 3000L), new IntWritable(4)));
|
||||
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 3100L), new IntWritable(5)));
|
||||
sequence.add(Arrays.asList(new LongWritable(1451606400000L + 3000L), new IntWritable(4)));
|
||||
sequence.add(Arrays.asList(new LongWritable(1451606400000L + 3100L), new IntWritable(5)));
|
||||
|
||||
Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC)
|
||||
.addColumnInteger("intcolumn").build();
|
||||
|
@ -100,15 +103,15 @@ public class TestWindowFunctions extends BaseND4JTest {
|
|||
//Create some data.
|
||||
List<List<Writable>> sequence = new ArrayList<>();
|
||||
//First window:
|
||||
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L), new IntWritable(0)));
|
||||
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 100L), new IntWritable(1)));
|
||||
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 200L), new IntWritable(2)));
|
||||
sequence.add(Arrays.asList(new LongWritable(1451606400000L), new IntWritable(0)));
|
||||
sequence.add(Arrays.asList(new LongWritable(1451606400000L + 100L), new IntWritable(1)));
|
||||
sequence.add(Arrays.asList(new LongWritable(1451606400000L + 200L), new IntWritable(2)));
|
||||
//Second window:
|
||||
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 1000L), new IntWritable(3)));
|
||||
sequence.add(Arrays.asList(new LongWritable(1451606400000L + 1000L), new IntWritable(3)));
|
||||
//Third window: empty
|
||||
//Fourth window:
|
||||
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 3000L), new IntWritable(4)));
|
||||
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 3100L), new IntWritable(5)));
|
||||
sequence.add(Arrays.asList(new LongWritable(1451606400000L + 3000L), new IntWritable(4)));
|
||||
sequence.add(Arrays.asList(new LongWritable(1451606400000L + 3100L), new IntWritable(5)));
|
||||
|
||||
Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC)
|
||||
.addColumnInteger("intcolumn").build();
|
||||
|
@ -150,15 +153,15 @@ public class TestWindowFunctions extends BaseND4JTest {
|
|||
//Create some data.
|
||||
List<List<Writable>> sequence = new ArrayList<>();
|
||||
//First window:
|
||||
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L), new IntWritable(0)));
|
||||
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 100L), new IntWritable(1)));
|
||||
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 200L), new IntWritable(2)));
|
||||
sequence.add(Arrays.asList(new LongWritable(1451606400000L), new IntWritable(0)));
|
||||
sequence.add(Arrays.asList(new LongWritable(1451606400000L + 100L), new IntWritable(1)));
|
||||
sequence.add(Arrays.asList(new LongWritable(1451606400000L + 200L), new IntWritable(2)));
|
||||
//Second window:
|
||||
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 1000L), new IntWritable(3)));
|
||||
sequence.add(Arrays.asList(new LongWritable(1451606400000L + 1000L), new IntWritable(3)));
|
||||
//Third window: empty
|
||||
//Fourth window:
|
||||
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 3000L), new IntWritable(4)));
|
||||
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 3100L), new IntWritable(5)));
|
||||
sequence.add(Arrays.asList(new LongWritable(1451606400000L + 3000L), new IntWritable(4)));
|
||||
sequence.add(Arrays.asList(new LongWritable(1451606400000L + 3100L), new IntWritable(5)));
|
||||
|
||||
Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC)
|
||||
.addColumnInteger("intcolumn").build();
|
||||
|
@ -188,13 +191,13 @@ public class TestWindowFunctions extends BaseND4JTest {
|
|||
//Create some data.
|
||||
List<List<Writable>> sequence = new ArrayList<>();
|
||||
//First window:
|
||||
sequence.add(Arrays.asList((Writable) new LongWritable(0), new IntWritable(0)));
|
||||
sequence.add(Arrays.asList((Writable) new LongWritable(100), new IntWritable(1)));
|
||||
sequence.add(Arrays.asList((Writable) new LongWritable(200), new IntWritable(2)));
|
||||
sequence.add(Arrays.asList((Writable) new LongWritable(1000), new IntWritable(3)));
|
||||
sequence.add(Arrays.asList((Writable) new LongWritable(1500), new IntWritable(4)));
|
||||
sequence.add(Arrays.asList((Writable) new LongWritable(2000), new IntWritable(5)));
|
||||
sequence.add(Arrays.asList((Writable) new LongWritable(5000), new IntWritable(7)));
|
||||
sequence.add(Arrays.asList(new LongWritable(0), new IntWritable(0)));
|
||||
sequence.add(Arrays.asList(new LongWritable(100), new IntWritable(1)));
|
||||
sequence.add(Arrays.asList(new LongWritable(200), new IntWritable(2)));
|
||||
sequence.add(Arrays.asList(new LongWritable(1000), new IntWritable(3)));
|
||||
sequence.add(Arrays.asList(new LongWritable(1500), new IntWritable(4)));
|
||||
sequence.add(Arrays.asList(new LongWritable(2000), new IntWritable(5)));
|
||||
sequence.add(Arrays.asList(new LongWritable(5000), new IntWritable(7)));
|
||||
|
||||
|
||||
Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC)
|
||||
|
@ -207,32 +210,32 @@ public class TestWindowFunctions extends BaseND4JTest {
|
|||
|
||||
//First window: -1000 to 1000
|
||||
List<List<Writable>> exp0 = new ArrayList<>();
|
||||
exp0.add(Arrays.asList((Writable) new LongWritable(0), new IntWritable(0)));
|
||||
exp0.add(Arrays.asList((Writable) new LongWritable(100), new IntWritable(1)));
|
||||
exp0.add(Arrays.asList((Writable) new LongWritable(200), new IntWritable(2)));
|
||||
exp0.add(Arrays.asList(new LongWritable(0), new IntWritable(0)));
|
||||
exp0.add(Arrays.asList(new LongWritable(100), new IntWritable(1)));
|
||||
exp0.add(Arrays.asList(new LongWritable(200), new IntWritable(2)));
|
||||
//Second window: 0 to 2000
|
||||
List<List<Writable>> exp1 = new ArrayList<>();
|
||||
exp1.add(Arrays.asList((Writable) new LongWritable(0), new IntWritable(0)));
|
||||
exp1.add(Arrays.asList((Writable) new LongWritable(100), new IntWritable(1)));
|
||||
exp1.add(Arrays.asList((Writable) new LongWritable(200), new IntWritable(2)));
|
||||
exp1.add(Arrays.asList((Writable) new LongWritable(1000), new IntWritable(3)));
|
||||
exp1.add(Arrays.asList((Writable) new LongWritable(1500), new IntWritable(4)));
|
||||
exp1.add(Arrays.asList(new LongWritable(0), new IntWritable(0)));
|
||||
exp1.add(Arrays.asList(new LongWritable(100), new IntWritable(1)));
|
||||
exp1.add(Arrays.asList(new LongWritable(200), new IntWritable(2)));
|
||||
exp1.add(Arrays.asList(new LongWritable(1000), new IntWritable(3)));
|
||||
exp1.add(Arrays.asList(new LongWritable(1500), new IntWritable(4)));
|
||||
//Third window: 1000 to 3000
|
||||
List<List<Writable>> exp2 = new ArrayList<>();
|
||||
exp2.add(Arrays.asList((Writable) new LongWritable(1000), new IntWritable(3)));
|
||||
exp2.add(Arrays.asList((Writable) new LongWritable(1500), new IntWritable(4)));
|
||||
exp2.add(Arrays.asList((Writable) new LongWritable(2000), new IntWritable(5)));
|
||||
exp2.add(Arrays.asList(new LongWritable(1000), new IntWritable(3)));
|
||||
exp2.add(Arrays.asList(new LongWritable(1500), new IntWritable(4)));
|
||||
exp2.add(Arrays.asList(new LongWritable(2000), new IntWritable(5)));
|
||||
//Fourth window: 2000 to 4000
|
||||
List<List<Writable>> exp3 = new ArrayList<>();
|
||||
exp3.add(Arrays.asList((Writable) new LongWritable(2000), new IntWritable(5)));
|
||||
exp3.add(Arrays.asList(new LongWritable(2000), new IntWritable(5)));
|
||||
//Fifth window: 3000 to 5000
|
||||
List<List<Writable>> exp4 = new ArrayList<>();
|
||||
//Sixth window: 4000 to 6000
|
||||
List<List<Writable>> exp5 = new ArrayList<>();
|
||||
exp5.add(Arrays.asList((Writable) new LongWritable(5000), new IntWritable(7)));
|
||||
exp5.add(Arrays.asList(new LongWritable(5000), new IntWritable(7)));
|
||||
//Seventh window: 5000 to 7000
|
||||
List<List<Writable>> exp6 = new ArrayList<>();
|
||||
exp6.add(Arrays.asList((Writable) new LongWritable(5000), new IntWritable(7)));
|
||||
exp6.add(Arrays.asList(new LongWritable(5000), new IntWritable(7)));
|
||||
|
||||
List<List<List<Writable>>> windowsExp = Arrays.asList(exp0, exp1, exp2, exp3, exp4, exp5, exp6);
|
||||
|
||||
|
@ -250,13 +253,13 @@ public class TestWindowFunctions extends BaseND4JTest {
|
|||
//Create some data.
|
||||
List<List<Writable>> sequence = new ArrayList<>();
|
||||
//First window:
|
||||
sequence.add(Arrays.asList((Writable) new LongWritable(0), new IntWritable(0)));
|
||||
sequence.add(Arrays.asList((Writable) new LongWritable(100), new IntWritable(1)));
|
||||
sequence.add(Arrays.asList((Writable) new LongWritable(200), new IntWritable(2)));
|
||||
sequence.add(Arrays.asList((Writable) new LongWritable(1000), new IntWritable(3)));
|
||||
sequence.add(Arrays.asList((Writable) new LongWritable(1500), new IntWritable(4)));
|
||||
sequence.add(Arrays.asList((Writable) new LongWritable(2000), new IntWritable(5)));
|
||||
sequence.add(Arrays.asList((Writable) new LongWritable(5000), new IntWritable(7)));
|
||||
sequence.add(Arrays.asList(new LongWritable(0), new IntWritable(0)));
|
||||
sequence.add(Arrays.asList(new LongWritable(100), new IntWritable(1)));
|
||||
sequence.add(Arrays.asList(new LongWritable(200), new IntWritable(2)));
|
||||
sequence.add(Arrays.asList(new LongWritable(1000), new IntWritable(3)));
|
||||
sequence.add(Arrays.asList(new LongWritable(1500), new IntWritable(4)));
|
||||
sequence.add(Arrays.asList(new LongWritable(2000), new IntWritable(5)));
|
||||
sequence.add(Arrays.asList(new LongWritable(5000), new IntWritable(7)));
|
||||
|
||||
|
||||
Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC)
|
||||
|
@ -272,31 +275,31 @@ public class TestWindowFunctions extends BaseND4JTest {
|
|||
|
||||
//First window: -1000 to 1000
|
||||
List<List<Writable>> exp0 = new ArrayList<>();
|
||||
exp0.add(Arrays.asList((Writable) new LongWritable(0), new IntWritable(0)));
|
||||
exp0.add(Arrays.asList((Writable) new LongWritable(100), new IntWritable(1)));
|
||||
exp0.add(Arrays.asList((Writable) new LongWritable(200), new IntWritable(2)));
|
||||
exp0.add(Arrays.asList(new LongWritable(0), new IntWritable(0)));
|
||||
exp0.add(Arrays.asList(new LongWritable(100), new IntWritable(1)));
|
||||
exp0.add(Arrays.asList(new LongWritable(200), new IntWritable(2)));
|
||||
//Second window: 0 to 2000
|
||||
List<List<Writable>> exp1 = new ArrayList<>();
|
||||
exp1.add(Arrays.asList((Writable) new LongWritable(0), new IntWritable(0)));
|
||||
exp1.add(Arrays.asList((Writable) new LongWritable(100), new IntWritable(1)));
|
||||
exp1.add(Arrays.asList((Writable) new LongWritable(200), new IntWritable(2)));
|
||||
exp1.add(Arrays.asList((Writable) new LongWritable(1000), new IntWritable(3)));
|
||||
exp1.add(Arrays.asList((Writable) new LongWritable(1500), new IntWritable(4)));
|
||||
exp1.add(Arrays.asList(new LongWritable(0), new IntWritable(0)));
|
||||
exp1.add(Arrays.asList(new LongWritable(100), new IntWritable(1)));
|
||||
exp1.add(Arrays.asList(new LongWritable(200), new IntWritable(2)));
|
||||
exp1.add(Arrays.asList(new LongWritable(1000), new IntWritable(3)));
|
||||
exp1.add(Arrays.asList(new LongWritable(1500), new IntWritable(4)));
|
||||
//Third window: 1000 to 3000
|
||||
List<List<Writable>> exp2 = new ArrayList<>();
|
||||
exp2.add(Arrays.asList((Writable) new LongWritable(1000), new IntWritable(3)));
|
||||
exp2.add(Arrays.asList((Writable) new LongWritable(1500), new IntWritable(4)));
|
||||
exp2.add(Arrays.asList((Writable) new LongWritable(2000), new IntWritable(5)));
|
||||
exp2.add(Arrays.asList(new LongWritable(1000), new IntWritable(3)));
|
||||
exp2.add(Arrays.asList(new LongWritable(1500), new IntWritable(4)));
|
||||
exp2.add(Arrays.asList(new LongWritable(2000), new IntWritable(5)));
|
||||
//Fourth window: 2000 to 4000
|
||||
List<List<Writable>> exp3 = new ArrayList<>();
|
||||
exp3.add(Arrays.asList((Writable) new LongWritable(2000), new IntWritable(5)));
|
||||
exp3.add(Arrays.asList(new LongWritable(2000), new IntWritable(5)));
|
||||
//Fifth window: 3000 to 5000 -> Empty: excluded
|
||||
//Sixth window: 4000 to 6000
|
||||
List<List<Writable>> exp5 = new ArrayList<>();
|
||||
exp5.add(Arrays.asList((Writable) new LongWritable(5000), new IntWritable(7)));
|
||||
exp5.add(Arrays.asList(new LongWritable(5000), new IntWritable(7)));
|
||||
//Seventh window: 5000 to 7000
|
||||
List<List<Writable>> exp6 = new ArrayList<>();
|
||||
exp6.add(Arrays.asList((Writable) new LongWritable(5000), new IntWritable(7)));
|
||||
exp6.add(Arrays.asList(new LongWritable(5000), new IntWritable(7)));
|
||||
|
||||
List<List<List<Writable>>> windowsExp = Arrays.asList(exp0, exp1, exp2, exp3, exp5, exp6);
|
||||
|
||||
|
|
|
@ -26,11 +26,17 @@ import org.datavec.api.transform.schema.Schema;
|
|||
import org.datavec.api.transform.serde.testClasses.CustomCondition;
|
||||
import org.datavec.api.transform.serde.testClasses.CustomFilter;
|
||||
import org.datavec.api.transform.serde.testClasses.CustomTransform;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
@Tag(TagNames.JACKSON_SERDE)
|
||||
@Tag(TagNames.CUSTOM_FUNCTIONALITY)
|
||||
public class TestCustomTransformJsonYaml extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
|
|
|
@ -64,14 +64,19 @@ import org.datavec.api.transform.transform.time.TimeMathOpTransform;
|
|||
import org.datavec.api.writable.comparator.DoubleWritableComparator;
|
||||
import org.joda.time.DateTimeFieldType;
|
||||
import org.joda.time.DateTimeZone;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
@Tag(TagNames.JACKSON_SERDE)
|
||||
public class TestYamlJsonSerde extends BaseND4JTest {
|
||||
|
||||
public static YamlSerializer y = new YamlSerializer();
|
||||
|
|
|
@ -24,22 +24,26 @@ import org.datavec.api.transform.StringReduceOp;
|
|||
import org.datavec.api.transform.schema.Schema;
|
||||
import org.datavec.api.writable.Text;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
public class TestReduce extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
public void testReducerDouble() {
|
||||
|
||||
List<List<Writable>> inputs = new ArrayList<>();
|
||||
inputs.add(Arrays.asList((Writable) new Text("1"), new Text("2")));
|
||||
inputs.add(Arrays.asList((Writable) new Text("1"), new Text("2")));
|
||||
inputs.add(Arrays.asList((Writable) new Text("1"), new Text("2")));
|
||||
inputs.add(Arrays.asList(new Text("1"), new Text("2")));
|
||||
inputs.add(Arrays.asList(new Text("1"), new Text("2")));
|
||||
inputs.add(Arrays.asList(new Text("1"), new Text("2")));
|
||||
|
||||
Map<StringReduceOp, String> exp = new LinkedHashMap<>();
|
||||
exp.put(StringReduceOp.MERGE, "12");
|
||||
|
|
|
@ -37,10 +37,12 @@ import org.datavec.api.writable.Writable;
|
|||
import org.joda.time.DateTimeZone;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import java.io.File;
|
||||
import java.nio.file.Path;
|
||||
|
@ -49,7 +51,9 @@ import java.util.Arrays;
|
|||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
@Tag(TagNames.UI)
|
||||
public class TestUI extends BaseND4JTest {
|
||||
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
package org.datavec.api.util;
|
||||
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import java.io.BufferedReader;
|
||||
|
@ -33,8 +34,11 @@ import static org.hamcrest.core.AnyOf.anyOf;
|
|||
import static org.hamcrest.core.IsEqual.equalTo;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
@DisplayName("Class Path Resource Test")
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
class ClassPathResourceTest extends BaseND4JTest {
|
||||
|
||||
// File sizes are reported slightly different on Linux vs. Windows
|
||||
|
|
|
@ -22,8 +22,10 @@ package org.datavec.api.util;
|
|||
import org.datavec.api.timeseries.util.TimeSeriesWritableUtils;
|
||||
import org.datavec.api.writable.DoubleWritable;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
@ -32,6 +34,8 @@ import org.junit.jupiter.api.DisplayName;
|
|||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
@DisplayName("Time Series Utils Test")
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
class TimeSeriesUtilsTest extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
|
|
|
@ -19,7 +19,9 @@
|
|||
*/
|
||||
package org.datavec.api.writable;
|
||||
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
import org.nd4j.shade.guava.collect.Lists;
|
||||
import org.datavec.api.transform.schema.Schema;
|
||||
import org.datavec.api.util.ndarray.RecordConverter;
|
||||
|
@ -36,6 +38,8 @@ import org.junit.jupiter.api.DisplayName;
|
|||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
@DisplayName("Record Converter Test")
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
class RecordConverterTest extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
|
|
|
@ -21,15 +21,18 @@
|
|||
package org.datavec.api.writable;
|
||||
|
||||
import org.datavec.api.transform.metadata.NDArrayMetaData;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.io.*;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
public class TestNDArrayWritableAndSerialization extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
|
|
|
@ -20,8 +20,10 @@
|
|||
package org.datavec.api.writable;
|
||||
|
||||
import org.datavec.api.writable.batch.NDArrayRecordBatch;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
@ -37,6 +39,8 @@ import org.junit.jupiter.api.DisplayName;
|
|||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
@DisplayName("Writable Test")
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
class WritableTest extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
|
|
|
@ -42,9 +42,11 @@ import org.datavec.api.writable.*;
|
|||
import org.datavec.arrow.recordreader.ArrowRecordReader;
|
||||
import org.datavec.arrow.recordreader.ArrowWritableRecordBatch;
|
||||
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.common.primitives.Pair;
|
||||
|
@ -62,6 +64,8 @@ import org.junit.jupiter.api.extension.ExtendWith;
|
|||
|
||||
@Slf4j
|
||||
@DisplayName("Arrow Converter Test")
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
class ArrowConverterTest extends BaseND4JTest {
|
||||
|
||||
private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE);
|
||||
|
@ -142,8 +146,8 @@ class ArrowConverterTest extends BaseND4JTest {
|
|||
List<FieldVector> fieldVectorsBatch = ArrowConverter.toArrowColumnsString(bufferAllocator, schema.build(), batch);
|
||||
List<List<Writable>> batchRecords = ArrowConverter.toArrowWritables(fieldVectorsBatch, schema.build());
|
||||
List<List<Writable>> assertionBatch = new ArrayList<>();
|
||||
assertionBatch.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(0)));
|
||||
assertionBatch.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1)));
|
||||
assertionBatch.add(Arrays.asList(new IntWritable(0), new IntWritable(0)));
|
||||
assertionBatch.add(Arrays.asList(new IntWritable(1), new IntWritable(1)));
|
||||
assertEquals(assertionBatch, batchRecords);
|
||||
}
|
||||
|
||||
|
@ -156,11 +160,11 @@ class ArrowConverterTest extends BaseND4JTest {
|
|||
schema.addColumnTime(String.valueOf(i), TimeZone.getDefault());
|
||||
single.add(String.valueOf(i));
|
||||
}
|
||||
List<List<Writable>> input = Arrays.asList(Arrays.<Writable>asList(new LongWritable(0), new LongWritable(1)), Arrays.<Writable>asList(new LongWritable(2), new LongWritable(3)));
|
||||
List<List<Writable>> input = Arrays.asList(Arrays.asList(new LongWritable(0), new LongWritable(1)), Arrays.<Writable>asList(new LongWritable(2), new LongWritable(3)));
|
||||
List<FieldVector> fieldVector = ArrowConverter.toArrowColumns(bufferAllocator, schema.build(), input);
|
||||
ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector, schema.build());
|
||||
List<Writable> assertion = Arrays.<Writable>asList(new LongWritable(4), new LongWritable(5));
|
||||
writableRecordBatch.set(1, Arrays.<Writable>asList(new LongWritable(4), new LongWritable(5)));
|
||||
List<Writable> assertion = Arrays.asList(new LongWritable(4), new LongWritable(5));
|
||||
writableRecordBatch.set(1, Arrays.asList(new LongWritable(4), new LongWritable(5)));
|
||||
List<Writable> recordTest = writableRecordBatch.get(1);
|
||||
assertEquals(assertion, recordTest);
|
||||
}
|
||||
|
@ -174,11 +178,11 @@ class ArrowConverterTest extends BaseND4JTest {
|
|||
schema.addColumnInteger(String.valueOf(i));
|
||||
single.add(String.valueOf(i));
|
||||
}
|
||||
List<List<Writable>> input = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(1)), Arrays.<Writable>asList(new IntWritable(2), new IntWritable(3)));
|
||||
List<List<Writable>> input = Arrays.asList(Arrays.asList(new IntWritable(0), new IntWritable(1)), Arrays.<Writable>asList(new IntWritable(2), new IntWritable(3)));
|
||||
List<FieldVector> fieldVector = ArrowConverter.toArrowColumns(bufferAllocator, schema.build(), input);
|
||||
ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector, schema.build());
|
||||
List<Writable> assertion = Arrays.<Writable>asList(new IntWritable(4), new IntWritable(5));
|
||||
writableRecordBatch.set(1, Arrays.<Writable>asList(new IntWritable(4), new IntWritable(5)));
|
||||
List<Writable> assertion = Arrays.asList(new IntWritable(4), new IntWritable(5));
|
||||
writableRecordBatch.set(1, Arrays.asList(new IntWritable(4), new IntWritable(5)));
|
||||
List<Writable> recordTest = writableRecordBatch.get(1);
|
||||
assertEquals(assertion, recordTest);
|
||||
}
|
||||
|
|
|
@ -33,6 +33,7 @@ import org.datavec.api.writable.IntWritable;
|
|||
import org.datavec.api.writable.Writable;
|
||||
import org.datavec.arrow.recordreader.ArrowRecordReader;
|
||||
import org.datavec.arrow.recordreader.ArrowRecordWriter;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.primitives.Triple;
|
||||
|
@ -44,8 +45,11 @@ import java.util.List;
|
|||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
@DisplayName("Record Mapper Test")
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
class RecordMapperTest extends BaseND4JTest {
|
||||
|
||||
@Test
|
||||
|
|
|
@ -30,8 +30,10 @@ import org.datavec.api.writable.Text;
|
|||
import org.datavec.api.writable.Writable;
|
||||
import org.datavec.arrow.ArrowConverter;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.BaseND4JTest;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
|
@ -39,7 +41,8 @@ import java.util.List;
|
|||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest {
|
||||
|
||||
private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE);
|
||||
|
|
|
@ -25,6 +25,7 @@ import org.datavec.api.split.FileSplit;
|
|||
import org.datavec.image.recordreader.ImageRecordReader;
|
||||
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
|
@ -36,8 +37,12 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
|
|||
import org.junit.jupiter.api.DisplayName;
|
||||
import java.nio.file.Path;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.nd4j.common.tests.tags.NativeTag;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
@DisplayName("Label Generator Test")
|
||||
@NativeTag
|
||||
@Tag(TagNames.FILE_IO)
|
||||
class LabelGeneratorTest {
|
||||
|
||||
|
||||
|
|
|
@ -23,7 +23,10 @@ package org.datavec.image.loader;
|
|||
import org.apache.commons.io.FilenameUtils;
|
||||
import org.datavec.api.records.reader.RecordReader;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.tags.NativeTag;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
import org.nd4j.linalg.dataset.DataSet;
|
||||
|
||||
import java.io.File;
|
||||
|
@ -39,6 +42,8 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
|
|||
/**
|
||||
*
|
||||
*/
|
||||
@NativeTag
|
||||
@Tag(TagNames.FILE_IO)
|
||||
public class LoaderTests {
|
||||
|
||||
private static void ensureDataAvailable(){
|
||||
|
@ -81,7 +86,7 @@ public class LoaderTests {
|
|||
String subDir = "cifar/cifar-10-batches-bin/data_batch_1.bin";
|
||||
String path = FilenameUtils.concat(System.getProperty("user.home"), subDir);
|
||||
byte[] fullDataExpected = new byte[3073];
|
||||
FileInputStream inExpected = new FileInputStream(new File(path));
|
||||
FileInputStream inExpected = new FileInputStream(path);
|
||||
inExpected.read(fullDataExpected);
|
||||
|
||||
byte[] fullDataActual = new byte[3073];
|
||||
|
@ -94,7 +99,7 @@ public class LoaderTests {
|
|||
subDir = "cifar/cifar-10-batches-bin/test_batch.bin";
|
||||
path = FilenameUtils.concat(System.getProperty("user.home"), subDir);
|
||||
fullDataExpected = new byte[3073];
|
||||
inExpected = new FileInputStream(new File(path));
|
||||
inExpected = new FileInputStream(path);
|
||||
inExpected.read(fullDataExpected);
|
||||
|
||||
fullDataActual = new byte[3073];
|
||||
|
|
|
@ -21,8 +21,11 @@
|
|||
package org.datavec.image.loader;
|
||||
|
||||
import org.datavec.image.data.Image;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.resources.Resources;
|
||||
import org.nd4j.common.tests.tags.NativeTag;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
|
||||
import java.awt.image.BufferedImage;
|
||||
|
@ -34,7 +37,8 @@ import java.util.Random;
|
|||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
|
||||
@NativeTag
|
||||
@Tag(TagNames.FILE_IO)
|
||||
public class TestImageLoader {
|
||||
|
||||
private static long seed = 10;
|
||||
|
|
|
@ -31,10 +31,13 @@ import org.bytedeco.javacv.OpenCVFrameConverter;
|
|||
import org.datavec.image.data.Image;
|
||||
import org.datavec.image.data.ImageWritable;
|
||||
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.common.resources.Resources;
|
||||
import org.nd4j.common.tests.tags.NativeTag;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -60,6 +63,8 @@ import static org.junit.jupiter.api.Assertions.fail;
|
|||
* @author saudet
|
||||
*/
|
||||
@Slf4j
|
||||
@NativeTag
|
||||
@Tag(TagNames.FILE_IO)
|
||||
public class TestNativeImageLoader {
|
||||
static final long seed = 10;
|
||||
static final Random rng = new Random(seed);
|
||||
|
|
|
@ -28,9 +28,12 @@ import org.datavec.api.writable.NDArrayWritable;
|
|||
import org.datavec.api.writable.Writable;
|
||||
import org.datavec.image.loader.NativeImageLoader;
|
||||
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.common.loader.FileBatch;
|
||||
import org.nd4j.common.tests.tags.NativeTag;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
import java.io.File;
|
||||
|
@ -41,6 +44,8 @@ import java.nio.file.Path;
|
|||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
@DisplayName("File Batch Record Reader Test")
|
||||
@NativeTag
|
||||
@Tag(TagNames.FILE_IO)
|
||||
class FileBatchRecordReaderTest {
|
||||
|
||||
@TempDir
|
||||
|
|
|
@ -37,9 +37,12 @@ import org.datavec.api.writable.NDArrayWritable;
|
|||
import org.datavec.api.writable.Writable;
|
||||
import org.datavec.api.writable.batch.NDArrayRecordBatch;
|
||||
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.common.tests.tags.NativeTag;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -54,7 +57,8 @@ import java.util.List;
|
|||
import java.util.Random;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
@NativeTag
|
||||
@Tag(TagNames.FILE_IO)
|
||||
public class TestImageRecordReader {
|
||||
|
||||
|
||||
|
|
|
@ -36,9 +36,12 @@ import org.datavec.image.transform.ImageTransform;
|
|||
import org.datavec.image.transform.PipelineImageTransform;
|
||||
import org.datavec.image.transform.ResizeImageTransform;
|
||||
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.common.tests.tags.NativeTag;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.indexing.BooleanIndexing;
|
||||
|
@ -54,7 +57,8 @@ import java.util.Collections;
|
|||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
@NativeTag
|
||||
@Tag(TagNames.FILE_IO)
|
||||
public class TestObjectDetectionRecordReader {
|
||||
|
||||
|
||||
|
|
|
@ -22,10 +22,13 @@ package org.datavec.image.recordreader.objdetect;
|
|||
|
||||
import org.datavec.image.recordreader.objdetect.impl.VocLabelProvider;
|
||||
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
import org.nd4j.common.tests.tags.NativeTag;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import java.io.File;
|
||||
import java.nio.file.Path;
|
||||
|
@ -34,7 +37,8 @@ import java.util.Collections;
|
|||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@Tag(TagNames.FILE_IO)
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
public class TestVocLabelProvider {
|
||||
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
package org.datavec.image.transform;
|
||||
|
||||
import org.datavec.image.data.ImageWritable;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
@ -28,8 +29,13 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
|
|||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.nd4j.common.tests.tags.NativeTag;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
@DisplayName("Json Yaml Test")
|
||||
@NativeTag
|
||||
@Tag(TagNames.FILE_IO)
|
||||
@Tag(TagNames.JACKSON_SERDE)
|
||||
class JsonYamlTest {
|
||||
|
||||
@Test
|
||||
|
|
|
@ -22,12 +22,17 @@ package org.datavec.image.transform;
|
|||
import org.bytedeco.javacv.Frame;
|
||||
import org.datavec.image.data.ImageWritable;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.nd4j.common.tests.tags.NativeTag;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
@DisplayName("Resize Image Transform Test")
|
||||
@NativeTag
|
||||
@Tag(TagNames.FILE_IO)
|
||||
class ResizeImageTransformTest {
|
||||
|
||||
@BeforeEach
|
||||
|
|
|
@ -24,6 +24,7 @@ import org.bytedeco.javacpp.indexer.UByteIndexer;
|
|||
import org.bytedeco.javacv.CanvasFrame;
|
||||
import org.bytedeco.javacv.Frame;
|
||||
import org.bytedeco.javacv.OpenCVFrameConverter;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
import org.nd4j.common.primitives.Pair;
|
||||
import org.datavec.image.data.ImageWritable;
|
||||
|
@ -37,6 +38,8 @@ import java.util.List;
|
|||
import java.util.Random;
|
||||
|
||||
import org.bytedeco.opencv.opencv_core.*;
|
||||
import org.nd4j.common.tests.tags.NativeTag;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import static org.bytedeco.opencv.global.opencv_core.*;
|
||||
import static org.bytedeco.opencv.global.opencv_imgproc.*;
|
||||
|
@ -46,6 +49,8 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
*
|
||||
* @author saudet
|
||||
*/
|
||||
@NativeTag
|
||||
@Tag(TagNames.FILE_IO)
|
||||
public class TestImageTransform {
|
||||
static final long seed = 10;
|
||||
static final Random rng = new Random(seed);
|
||||
|
|
|
@ -22,6 +22,7 @@ package org.datavec.poi.excel;
|
|||
import org.datavec.api.records.reader.RecordReader;
|
||||
import org.datavec.api.split.FileSplit;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
import java.util.List;
|
||||
|
@ -29,8 +30,12 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
|
|||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.nd4j.common.tests.tags.NativeTag;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
@DisplayName("Excel Record Reader Test")
|
||||
@Tag(TagNames.FILE_IO)
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
class ExcelRecordReaderTest {
|
||||
|
||||
@Test
|
||||
|
|
|
@ -26,6 +26,7 @@ import org.datavec.api.transform.schema.Schema;
|
|||
import org.datavec.api.writable.IntWritable;
|
||||
import org.datavec.api.writable.Writable;
|
||||
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.common.primitives.Triple;
|
||||
|
@ -36,8 +37,11 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
|
|||
import org.junit.jupiter.api.DisplayName;
|
||||
import java.nio.file.Path;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
@DisplayName("Excel Record Writer Test")
|
||||
@Tag(TagNames.FILE_IO)
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
class ExcelRecordWriterTest {
|
||||
|
||||
@TempDir
|
||||
|
|
|
@ -47,17 +47,19 @@ import org.datavec.api.writable.IntWritable;
|
|||
import org.datavec.api.writable.LongWritable;
|
||||
import org.datavec.api.writable.Text;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.*;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
|
||||
import java.nio.file.Path;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
|
||||
@DisplayName("Jdbc Record Reader Test")
|
||||
@Tag(TagNames.FILE_IO)
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
class JDBCRecordReaderTest {
|
||||
|
||||
@TempDir
|
||||
|
|
|
@ -36,15 +36,18 @@ import org.datavec.api.writable.LongWritable;
|
|||
import org.datavec.api.writable.Text;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.joda.time.DateTimeZone;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@Tag(TagNames.FILE_IO)
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
public class LocalTransformProcessRecordReaderTests {
|
||||
|
||||
@Test
|
||||
|
@ -64,11 +67,11 @@ public class LocalTransformProcessRecordReaderTests {
|
|||
public void simpleTransformTestSequence() {
|
||||
List<List<Writable>> sequence = new ArrayList<>();
|
||||
//First window:
|
||||
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L), new IntWritable(0),
|
||||
sequence.add(Arrays.asList(new LongWritable(1451606400000L), new IntWritable(0),
|
||||
new IntWritable(0)));
|
||||
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 100L), new IntWritable(1),
|
||||
sequence.add(Arrays.asList(new LongWritable(1451606400000L + 100L), new IntWritable(1),
|
||||
new IntWritable(0)));
|
||||
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 200L), new IntWritable(2),
|
||||
sequence.add(Arrays.asList(new LongWritable(1451606400000L + 200L), new IntWritable(2),
|
||||
new IntWritable(0)));
|
||||
|
||||
Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC)
|
||||
|
|
|
@ -30,8 +30,10 @@ import org.datavec.api.util.ndarray.RecordConverter;
|
|||
import org.datavec.api.writable.Writable;
|
||||
import org.datavec.local.transforms.AnalyzeLocal;
|
||||
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
|
@ -40,7 +42,8 @@ import java.util.ArrayList;
|
|||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@Tag(TagNames.FILE_IO)
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
public class TestAnalyzeLocal {
|
||||
|
||||
|
||||
|
|
|
@ -27,8 +27,10 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
|
|||
import org.datavec.api.split.FileSplit;
|
||||
import org.datavec.api.writable.Writable;
|
||||
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.HashSet;
|
||||
|
@ -38,7 +40,8 @@ import java.util.stream.Collectors;
|
|||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
@Tag(TagNames.FILE_IO)
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
public class TestLineRecordReaderFunction {
|
||||
|
||||
@Test
|
||||
|
|
|
@ -25,7 +25,10 @@ import org.datavec.api.writable.NDArrayWritable;
|
|||
import org.datavec.api.writable.Writable;
|
||||
|
||||
import org.datavec.local.transforms.misc.NDArrayToWritablesFunction;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.tags.NativeTag;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
|
@ -34,7 +37,8 @@ import java.util.Arrays;
|
|||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@Tag(TagNames.FILE_IO)
|
||||
@NativeTag
|
||||
public class TestNDArrayToWritablesFunction {
|
||||
|
||||
@Test
|
||||
|
@ -50,7 +54,7 @@ public class TestNDArrayToWritablesFunction {
|
|||
@Test
|
||||
public void testNDArrayToWritablesArray() throws Exception {
|
||||
INDArray arr = Nd4j.arange(5);
|
||||
List<Writable> expected = Arrays.asList((Writable) new NDArrayWritable(arr));
|
||||
List<Writable> expected = Arrays.asList(new NDArrayWritable(arr));
|
||||
List<Writable> actual = new NDArrayToWritablesFunction(true).apply(arr);
|
||||
assertEquals(expected, actual);
|
||||
}
|
||||
|
|
|
@ -25,7 +25,10 @@ import org.datavec.api.writable.NDArrayWritable;
|
|||
import org.datavec.api.writable.Writable;
|
||||
|
||||
import org.datavec.local.transforms.misc.WritablesToNDArrayFunction;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.tags.NativeTag;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -34,7 +37,8 @@ import java.util.ArrayList;
|
|||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@Tag(TagNames.FILE_IO)
|
||||
@NativeTag
|
||||
public class TestWritablesToNDArrayFunction {
|
||||
|
||||
@Test
|
||||
|
|
|
@ -30,13 +30,16 @@ import org.datavec.api.writable.Writable;
|
|||
|
||||
import org.datavec.local.transforms.misc.SequenceWritablesToStringFunction;
|
||||
import org.datavec.local.transforms.misc.WritablesToStringFunction;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@Tag(TagNames.FILE_IO)
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
public class TestWritablesToStringFunctions {
|
||||
|
||||
|
||||
|
@ -44,7 +47,7 @@ public class TestWritablesToStringFunctions {
|
|||
@Test
|
||||
public void testWritablesToString() throws Exception {
|
||||
|
||||
List<Writable> l = Arrays.<Writable>asList(new DoubleWritable(1.5), new Text("someValue"));
|
||||
List<Writable> l = Arrays.asList(new DoubleWritable(1.5), new Text("someValue"));
|
||||
String expected = l.get(0).toString() + "," + l.get(1).toString();
|
||||
|
||||
assertEquals(expected, new WritablesToStringFunction(",").apply(l));
|
||||
|
@ -53,8 +56,8 @@ public class TestWritablesToStringFunctions {
|
|||
@Test
|
||||
public void testSequenceWritablesToString() throws Exception {
|
||||
|
||||
List<List<Writable>> l = Arrays.asList(Arrays.<Writable>asList(new DoubleWritable(1.5), new Text("someValue")),
|
||||
Arrays.<Writable>asList(new DoubleWritable(2.5), new Text("otherValue")));
|
||||
List<List<Writable>> l = Arrays.asList(Arrays.asList(new DoubleWritable(1.5), new Text("someValue")),
|
||||
Arrays.asList(new DoubleWritable(2.5), new Text("otherValue")));
|
||||
|
||||
String expected = l.get(0).get(0).toString() + "," + l.get(0).get(1).toString() + "\n"
|
||||
+ l.get(1).get(0).toString() + "," + l.get(1).get(1).toString();
|
||||
|
|
|
@ -31,7 +31,10 @@ import org.datavec.api.transform.schema.SequenceSchema;
|
|||
import org.datavec.api.writable.*;
|
||||
import org.datavec.local.transforms.LocalTransformExecutor;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.tags.NativeTag;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.ops.transforms.Transforms;
|
||||
|
@ -42,6 +45,8 @@ import static java.time.Duration.ofMillis;
|
|||
import static org.junit.jupiter.api.Assertions.assertTimeout;
|
||||
|
||||
@DisplayName("Execution Test")
|
||||
@Tag(TagNames.FILE_IO)
|
||||
@NativeTag
|
||||
class ExecutionTest {
|
||||
|
||||
@Test
|
||||
|
@ -71,18 +76,12 @@ class ExecutionTest {
|
|||
Schema schema = new Schema.Builder().addColumnInteger("col0").addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").addColumnFloat("col3").build();
|
||||
TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1").doubleMathOp("col2", MathOp.Add, 10.0).floatMathOp("col3", MathOp.Add, 5f).build();
|
||||
List<List<Writable>> inputData = new ArrayList<>();
|
||||
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1), new FloatWritable(0.3f)));
|
||||
inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1), new FloatWritable(1.7f)));
|
||||
inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1), new FloatWritable(3.6f)));
|
||||
inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1), new FloatWritable(0.3f)));
|
||||
inputData.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1), new FloatWritable(1.7f)));
|
||||
inputData.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1), new FloatWritable(3.6f)));
|
||||
List<List<Writable>> rdd = (inputData);
|
||||
List<List<Writable>> out = new ArrayList<>(LocalTransformExecutor.execute(rdd, tp));
|
||||
Collections.sort(out, new Comparator<List<Writable>>() {
|
||||
|
||||
@Override
|
||||
public int compare(List<Writable> o1, List<Writable> o2) {
|
||||
return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
|
||||
}
|
||||
});
|
||||
Collections.sort(out, (o1, o2) -> Integer.compare(o1.get(0).toInt(), o2.get(0).toInt()));
|
||||
List<List<Writable>> expected = new ArrayList<>();
|
||||
expected.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1), new FloatWritable(5.3f)));
|
||||
expected.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1), new FloatWritable(6.7f)));
|
||||
|
@ -95,9 +94,9 @@ class ExecutionTest {
|
|||
void testFilter() {
|
||||
Schema filterSchema = new Schema.Builder().addColumnDouble("col1").addColumnDouble("col2").addColumnDouble("col3").build();
|
||||
List<List<Writable>> inputData = new ArrayList<>();
|
||||
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new DoubleWritable(1), new DoubleWritable(0.1)));
|
||||
inputData.add(Arrays.<Writable>asList(new IntWritable(1), new DoubleWritable(3), new DoubleWritable(1.1)));
|
||||
inputData.add(Arrays.<Writable>asList(new IntWritable(2), new DoubleWritable(3), new DoubleWritable(2.1)));
|
||||
inputData.add(Arrays.asList(new IntWritable(0), new DoubleWritable(1), new DoubleWritable(0.1)));
|
||||
inputData.add(Arrays.asList(new IntWritable(1), new DoubleWritable(3), new DoubleWritable(1.1)));
|
||||
inputData.add(Arrays.asList(new IntWritable(2), new DoubleWritable(3), new DoubleWritable(2.1)));
|
||||
TransformProcess transformProcess = new TransformProcess.Builder(filterSchema).filter(new DoubleColumnCondition("col1", ConditionOp.LessThan, 1)).build();
|
||||
List<List<Writable>> execute = LocalTransformExecutor.execute(inputData, transformProcess);
|
||||
assertEquals(2, execute.size());
|
||||
|
@ -110,31 +109,25 @@ class ExecutionTest {
|
|||
TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1").doubleMathOp("col2", MathOp.Add, 10.0).build();
|
||||
List<List<List<Writable>>> inputSequences = new ArrayList<>();
|
||||
List<List<Writable>> seq1 = new ArrayList<>();
|
||||
seq1.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
|
||||
seq1.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
|
||||
seq1.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
|
||||
seq1.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
|
||||
seq1.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
|
||||
seq1.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
|
||||
List<List<Writable>> seq2 = new ArrayList<>();
|
||||
seq2.add(Arrays.<Writable>asList(new IntWritable(3), new Text("state0"), new DoubleWritable(3.1)));
|
||||
seq2.add(Arrays.<Writable>asList(new IntWritable(4), new Text("state1"), new DoubleWritable(4.1)));
|
||||
seq2.add(Arrays.asList(new IntWritable(3), new Text("state0"), new DoubleWritable(3.1)));
|
||||
seq2.add(Arrays.asList(new IntWritable(4), new Text("state1"), new DoubleWritable(4.1)));
|
||||
inputSequences.add(seq1);
|
||||
inputSequences.add(seq2);
|
||||
List<List<List<Writable>>> rdd = (inputSequences);
|
||||
List<List<List<Writable>>> out = LocalTransformExecutor.executeSequenceToSequence(rdd, tp);
|
||||
Collections.sort(out, new Comparator<List<List<Writable>>>() {
|
||||
|
||||
@Override
|
||||
public int compare(List<List<Writable>> o1, List<List<Writable>> o2) {
|
||||
return -Integer.compare(o1.size(), o2.size());
|
||||
}
|
||||
});
|
||||
Collections.sort(out, (o1, o2) -> -Integer.compare(o1.size(), o2.size()));
|
||||
List<List<List<Writable>>> expectedSequence = new ArrayList<>();
|
||||
List<List<Writable>> seq1e = new ArrayList<>();
|
||||
seq1e.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1)));
|
||||
seq1e.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1)));
|
||||
seq1e.add(Arrays.<Writable>asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1)));
|
||||
seq1e.add(Arrays.asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1)));
|
||||
seq1e.add(Arrays.asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1)));
|
||||
seq1e.add(Arrays.asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1)));
|
||||
List<List<Writable>> seq2e = new ArrayList<>();
|
||||
seq2e.add(Arrays.<Writable>asList(new IntWritable(3), new IntWritable(0), new DoubleWritable(13.1)));
|
||||
seq2e.add(Arrays.<Writable>asList(new IntWritable(4), new IntWritable(1), new DoubleWritable(14.1)));
|
||||
seq2e.add(Arrays.asList(new IntWritable(3), new IntWritable(0), new DoubleWritable(13.1)));
|
||||
seq2e.add(Arrays.asList(new IntWritable(4), new IntWritable(1), new DoubleWritable(14.1)));
|
||||
expectedSequence.add(seq1e);
|
||||
expectedSequence.add(seq2e);
|
||||
assertEquals(expectedSequence, out);
|
||||
|
@ -143,26 +136,26 @@ class ExecutionTest {
|
|||
@Test
|
||||
@DisplayName("Test Reduction Global")
|
||||
void testReductionGlobal() {
|
||||
List<List<Writable>> in = Arrays.asList(Arrays.<Writable>asList(new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new Text("second"), new DoubleWritable(5.0)));
|
||||
List<List<Writable>> in = Arrays.asList(Arrays.asList(new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new Text("second"), new DoubleWritable(5.0)));
|
||||
List<List<Writable>> inData = in;
|
||||
Schema s = new Schema.Builder().addColumnString("textCol").addColumnDouble("doubleCol").build();
|
||||
TransformProcess tp = new TransformProcess.Builder(s).reduce(new Reducer.Builder(ReduceOp.TakeFirst).takeFirstColumns("textCol").meanColumns("doubleCol").build()).build();
|
||||
List<List<Writable>> outRdd = LocalTransformExecutor.execute(inData, tp);
|
||||
List<List<Writable>> out = outRdd;
|
||||
List<List<Writable>> expOut = Collections.singletonList(Arrays.<Writable>asList(new Text("first"), new DoubleWritable(4.0)));
|
||||
List<List<Writable>> expOut = Collections.singletonList(Arrays.asList(new Text("first"), new DoubleWritable(4.0)));
|
||||
assertEquals(expOut, out);
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("Test Reduction By Key")
|
||||
void testReductionByKey() {
|
||||
List<List<Writable>> in = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new IntWritable(0), new Text("second"), new DoubleWritable(5.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(30.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("s"), new DoubleWritable(50.0)));
|
||||
List<List<Writable>> in = Arrays.asList(Arrays.asList(new IntWritable(0), new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new IntWritable(0), new Text("second"), new DoubleWritable(5.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(30.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("s"), new DoubleWritable(50.0)));
|
||||
List<List<Writable>> inData = in;
|
||||
Schema s = new Schema.Builder().addColumnInteger("intCol").addColumnString("textCol").addColumnDouble("doubleCol").build();
|
||||
TransformProcess tp = new TransformProcess.Builder(s).reduce(new Reducer.Builder(ReduceOp.TakeFirst).keyColumns("intCol").takeFirstColumns("textCol").meanColumns("doubleCol").build()).build();
|
||||
List<List<Writable>> outRdd = LocalTransformExecutor.execute(inData, tp);
|
||||
List<List<Writable>> out = outRdd;
|
||||
List<List<Writable>> expOut = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0)));
|
||||
List<List<Writable>> expOut = Arrays.asList(Arrays.asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0)));
|
||||
out = new ArrayList<>(out);
|
||||
Collections.sort(out, Comparator.comparingInt(o -> o.get(0).toInt()));
|
||||
assertEquals(expOut, out);
|
||||
|
|
|
@ -28,12 +28,15 @@ import org.datavec.api.writable.*;
|
|||
|
||||
|
||||
import org.datavec.local.transforms.LocalTransformExecutor;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@Tag(TagNames.FILE_IO)
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
public class TestJoin {
|
||||
|
||||
@Test
|
||||
|
@ -46,27 +49,27 @@ public class TestJoin {
|
|||
.addColumnDouble("amount").build();
|
||||
|
||||
List<List<Writable>> infoList = new ArrayList<>();
|
||||
infoList.add(Arrays.<Writable>asList(new LongWritable(12345), new Text("Customer12345")));
|
||||
infoList.add(Arrays.<Writable>asList(new LongWritable(98765), new Text("Customer98765")));
|
||||
infoList.add(Arrays.<Writable>asList(new LongWritable(50000), new Text("Customer50000")));
|
||||
infoList.add(Arrays.asList(new LongWritable(12345), new Text("Customer12345")));
|
||||
infoList.add(Arrays.asList(new LongWritable(98765), new Text("Customer98765")));
|
||||
infoList.add(Arrays.asList(new LongWritable(50000), new Text("Customer50000")));
|
||||
|
||||
List<List<Writable>> purchaseList = new ArrayList<>();
|
||||
purchaseList.add(Arrays.<Writable>asList(new LongWritable(1000000), new LongWritable(12345),
|
||||
purchaseList.add(Arrays.asList(new LongWritable(1000000), new LongWritable(12345),
|
||||
new DoubleWritable(10.00)));
|
||||
purchaseList.add(Arrays.<Writable>asList(new LongWritable(1000001), new LongWritable(12345),
|
||||
purchaseList.add(Arrays.asList(new LongWritable(1000001), new LongWritable(12345),
|
||||
new DoubleWritable(20.00)));
|
||||
purchaseList.add(Arrays.<Writable>asList(new LongWritable(1000002), new LongWritable(98765),
|
||||
purchaseList.add(Arrays.asList(new LongWritable(1000002), new LongWritable(98765),
|
||||
new DoubleWritable(30.00)));
|
||||
|
||||
Join join = new Join.Builder(Join.JoinType.RightOuter).setJoinColumns("customerID")
|
||||
.setSchemas(customerInfoSchema, purchasesSchema).build();
|
||||
|
||||
List<List<Writable>> expected = new ArrayList<>();
|
||||
expected.add(Arrays.<Writable>asList(new LongWritable(12345), new Text("Customer12345"),
|
||||
expected.add(Arrays.asList(new LongWritable(12345), new Text("Customer12345"),
|
||||
new LongWritable(1000000), new DoubleWritable(10.00)));
|
||||
expected.add(Arrays.<Writable>asList(new LongWritable(12345), new Text("Customer12345"),
|
||||
expected.add(Arrays.asList(new LongWritable(12345), new Text("Customer12345"),
|
||||
new LongWritable(1000001), new DoubleWritable(20.00)));
|
||||
expected.add(Arrays.<Writable>asList(new LongWritable(98765), new Text("Customer98765"),
|
||||
expected.add(Arrays.asList(new LongWritable(98765), new Text("Customer98765"),
|
||||
new LongWritable(1000002), new DoubleWritable(30.00)));
|
||||
|
||||
|
||||
|
@ -77,12 +80,7 @@ public class TestJoin {
|
|||
List<List<Writable>> joined = LocalTransformExecutor.executeJoin(join, info, purchases);
|
||||
List<List<Writable>> joinedList = new ArrayList<>(joined);
|
||||
//Sort by order ID (column 3, index 2)
|
||||
Collections.sort(joinedList, new Comparator<List<Writable>>() {
|
||||
@Override
|
||||
public int compare(List<Writable> o1, List<Writable> o2) {
|
||||
return Long.compare(o1.get(2).toLong(), o2.get(2).toLong());
|
||||
}
|
||||
});
|
||||
Collections.sort(joinedList, (o1, o2) -> Long.compare(o1.get(2).toLong(), o2.get(2).toLong()));
|
||||
assertEquals(expected, joinedList);
|
||||
|
||||
assertEquals(3, joinedList.size());
|
||||
|
@ -110,12 +108,7 @@ public class TestJoin {
|
|||
List<List<Writable>> joined2 = LocalTransformExecutor.executeJoin(join2, purchases, info);
|
||||
List<List<Writable>> joinedList2 = new ArrayList<>(joined2);
|
||||
//Sort by order ID (column 0)
|
||||
Collections.sort(joinedList2, new Comparator<List<Writable>>() {
|
||||
@Override
|
||||
public int compare(List<Writable> o1, List<Writable> o2) {
|
||||
return Long.compare(o1.get(0).toLong(), o2.get(0).toLong());
|
||||
}
|
||||
});
|
||||
Collections.sort(joinedList2, (o1, o2) -> Long.compare(o1.get(0).toLong(), o2.get(0).toLong()));
|
||||
assertEquals(3, joinedList2.size());
|
||||
|
||||
assertEquals(expectedManyToOne, joinedList2);
|
||||
|
@ -189,29 +182,26 @@ public class TestJoin {
|
|||
new ArrayList<>(LocalTransformExecutor.executeJoin(join, firstRDD, secondRDD));
|
||||
|
||||
//Sort output by column 0, then column 1, then column 2 for comparison to expected...
|
||||
Collections.sort(out, new Comparator<List<Writable>>() {
|
||||
@Override
|
||||
public int compare(List<Writable> o1, List<Writable> o2) {
|
||||
Writable w1 = o1.get(0);
|
||||
Writable w2 = o2.get(0);
|
||||
if (w1 instanceof NullWritable)
|
||||
return 1;
|
||||
else if (w2 instanceof NullWritable)
|
||||
return -1;
|
||||
int c = Long.compare(w1.toLong(), w2.toLong());
|
||||
if (c != 0)
|
||||
return c;
|
||||
c = o1.get(1).toString().compareTo(o2.get(1).toString());
|
||||
if (c != 0)
|
||||
return c;
|
||||
w1 = o1.get(2);
|
||||
w2 = o2.get(2);
|
||||
if (w1 instanceof NullWritable)
|
||||
return 1;
|
||||
else if (w2 instanceof NullWritable)
|
||||
return -1;
|
||||
return Long.compare(w1.toLong(), w2.toLong());
|
||||
}
|
||||
Collections.sort(out, (o1, o2) -> {
|
||||
Writable w1 = o1.get(0);
|
||||
Writable w2 = o2.get(0);
|
||||
if (w1 instanceof NullWritable)
|
||||
return 1;
|
||||
else if (w2 instanceof NullWritable)
|
||||
return -1;
|
||||
int c = Long.compare(w1.toLong(), w2.toLong());
|
||||
if (c != 0)
|
||||
return c;
|
||||
c = o1.get(1).toString().compareTo(o2.get(1).toString());
|
||||
if (c != 0)
|
||||
return c;
|
||||
w1 = o1.get(2);
|
||||
w2 = o2.get(2);
|
||||
if (w1 instanceof NullWritable)
|
||||
return 1;
|
||||
else if (w2 instanceof NullWritable)
|
||||
return -1;
|
||||
return Long.compare(w1.toLong(), w2.toLong());
|
||||
});
|
||||
|
||||
switch (jt) {
|
||||
|
|
|
@ -31,14 +31,17 @@ import org.datavec.api.writable.comparator.DoubleWritableComparator;
|
|||
|
||||
|
||||
import org.datavec.local.transforms.LocalTransformExecutor;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@Tag(TagNames.FILE_IO)
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
public class TestCalculateSortedRank {
|
||||
|
||||
@Test
|
||||
|
|
|
@ -31,7 +31,9 @@ import org.datavec.api.writable.Writable;
|
|||
|
||||
import org.datavec.arrow.recordreader.ArrowWritableRecordTimeSeriesBatch;
|
||||
import org.datavec.local.transforms.LocalTransformExecutor;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
|
@ -39,7 +41,8 @@ import java.util.List;
|
|||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
@Tag(TagNames.FILE_IO)
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
public class TestConvertToSequence {
|
||||
|
||||
@Test
|
||||
|
@ -48,12 +51,12 @@ public class TestConvertToSequence {
|
|||
Schema s = new Schema.Builder().addColumnsString("key1", "key2").addColumnLong("time").build();
|
||||
|
||||
List<List<Writable>> allExamples =
|
||||
Arrays.asList(Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), new LongWritable(10)),
|
||||
Arrays.<Writable>asList(new Text("k1b"), new Text("k2b"), new LongWritable(10)),
|
||||
Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"),
|
||||
Arrays.asList(Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(10)),
|
||||
Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(10)),
|
||||
Arrays.asList(new Text("k1a"), new Text("k2a"),
|
||||
new LongWritable(-10)),
|
||||
Arrays.<Writable>asList(new Text("k1b"), new Text("k2b"), new LongWritable(5)),
|
||||
Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), new LongWritable(0)));
|
||||
Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(5)),
|
||||
Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(0)));
|
||||
|
||||
TransformProcess tp = new TransformProcess.Builder(s)
|
||||
.convertToSequence(Arrays.asList("key1", "key2"), new NumericalColumnComparator("time"))
|
||||
|
@ -75,13 +78,13 @@ public class TestConvertToSequence {
|
|||
}
|
||||
|
||||
List<List<Writable>> expSeq0 = Arrays.asList(
|
||||
Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), new LongWritable(-10)),
|
||||
Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), new LongWritable(0)),
|
||||
Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), new LongWritable(10)));
|
||||
Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(-10)),
|
||||
Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(0)),
|
||||
Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(10)));
|
||||
|
||||
List<List<Writable>> expSeq1 = Arrays.asList(
|
||||
Arrays.<Writable>asList(new Text("k1b"), new Text("k2b"), new LongWritable(5)),
|
||||
Arrays.<Writable>asList(new Text("k1b"), new Text("k2b"), new LongWritable(10)));
|
||||
Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(5)),
|
||||
Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(10)));
|
||||
|
||||
assertEquals(expSeq0, seq0);
|
||||
assertEquals(expSeq1, seq1);
|
||||
|
@ -96,9 +99,9 @@ public class TestConvertToSequence {
|
|||
.build();
|
||||
|
||||
List<List<Writable>> allExamples = Arrays.asList(
|
||||
Arrays.<Writable>asList(new Text("a"), new LongWritable(0)),
|
||||
Arrays.<Writable>asList(new Text("b"), new LongWritable(1)),
|
||||
Arrays.<Writable>asList(new Text("c"), new LongWritable(2)));
|
||||
Arrays.asList(new Text("a"), new LongWritable(0)),
|
||||
Arrays.asList(new Text("b"), new LongWritable(1)),
|
||||
Arrays.asList(new Text("c"), new LongWritable(2)));
|
||||
|
||||
TransformProcess tp = new TransformProcess.Builder(s)
|
||||
.convertToSequence()
|
||||
|
|
|
@ -25,8 +25,10 @@ import org.apache.spark.serializer.SerializerInstance;
|
|||
import org.datavec.api.records.reader.RecordReader;
|
||||
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
|
||||
import org.datavec.api.split.FileSplit;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import java.io.File;
|
||||
import java.nio.ByteBuffer;
|
||||
|
@ -34,7 +36,10 @@ import java.nio.ByteBuffer;
|
|||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
@Tag(TagNames.FILE_IO)
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.SPARK)
|
||||
@Tag(TagNames.DIST_SYSTEMS)
|
||||
public class TestKryoSerialization extends BaseSparkTest {
|
||||
|
||||
@Override
|
||||
|
|
|
@ -27,8 +27,10 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
|
|||
import org.datavec.api.split.FileSplit;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.datavec.spark.BaseSparkTest;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.HashSet;
|
||||
|
@ -37,7 +39,10 @@ import java.util.Set;
|
|||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
@Tag(TagNames.FILE_IO)
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.SPARK)
|
||||
@Tag(TagNames.DIST_SYSTEMS)
|
||||
public class TestLineRecordReaderFunction extends BaseSparkTest {
|
||||
|
||||
@Test
|
||||
|
|
|
@ -24,7 +24,10 @@ import org.datavec.api.writable.DoubleWritable;
|
|||
import org.datavec.api.writable.NDArrayWritable;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.datavec.spark.transform.misc.NDArrayToWritablesFunction;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.tags.NativeTag;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
|
@ -33,7 +36,10 @@ import java.util.Arrays;
|
|||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@Tag(TagNames.FILE_IO)
|
||||
@Tag(TagNames.SPARK)
|
||||
@Tag(TagNames.DIST_SYSTEMS)
|
||||
@NativeTag
|
||||
public class TestNDArrayToWritablesFunction {
|
||||
|
||||
@Test
|
||||
|
|
|
@ -39,10 +39,12 @@ import org.datavec.spark.functions.pairdata.PathToKeyConverter;
|
|||
import org.datavec.spark.functions.pairdata.PathToKeyConverterFilename;
|
||||
import org.datavec.spark.util.DataVecSparkUtil;
|
||||
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
import scala.Tuple2;
|
||||
|
||||
import java.io.File;
|
||||
|
@ -53,7 +55,10 @@ import java.util.List;
|
|||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.fail;
|
||||
|
||||
@Tag(TagNames.FILE_IO)
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.SPARK)
|
||||
@Tag(TagNames.DIST_SYSTEMS)
|
||||
public class TestPairSequenceRecordReaderBytesFunction extends BaseSparkTest {
|
||||
|
||||
@Test
|
||||
|
|
|
@ -37,10 +37,12 @@ import org.datavec.spark.BaseSparkTest;
|
|||
import org.datavec.spark.functions.data.FilesAsBytesFunction;
|
||||
import org.datavec.spark.functions.data.RecordReaderBytesFunction;
|
||||
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import java.io.File;
|
||||
import java.nio.file.Files;
|
||||
|
@ -51,7 +53,10 @@ import java.util.List;
|
|||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.fail;
|
||||
|
||||
@Tag(TagNames.FILE_IO)
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.SPARK)
|
||||
@Tag(TagNames.DIST_SYSTEMS)
|
||||
public class TestRecordReaderBytesFunction extends BaseSparkTest {
|
||||
|
||||
|
||||
|
|
|
@ -32,10 +32,12 @@ import org.datavec.api.writable.Writable;
|
|||
import org.datavec.image.recordreader.ImageRecordReader;
|
||||
import org.datavec.spark.BaseSparkTest;
|
||||
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import java.io.File;
|
||||
import java.nio.file.Path;
|
||||
|
@ -45,7 +47,10 @@ import java.util.List;
|
|||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.fail;
|
||||
|
||||
@Tag(TagNames.FILE_IO)
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.SPARK)
|
||||
@Tag(TagNames.DIST_SYSTEMS)
|
||||
public class TestRecordReaderFunction extends BaseSparkTest {
|
||||
|
||||
@Test
|
||||
|
|
|
@ -37,10 +37,12 @@ import org.datavec.spark.BaseSparkTest;
|
|||
import org.datavec.spark.functions.data.FilesAsBytesFunction;
|
||||
import org.datavec.spark.functions.data.SequenceRecordReaderBytesFunction;
|
||||
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import java.io.File;
|
||||
import java.nio.file.Files;
|
||||
|
@ -50,7 +52,10 @@ import java.util.List;
|
|||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.fail;
|
||||
|
||||
@Tag(TagNames.FILE_IO)
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.SPARK)
|
||||
@Tag(TagNames.DIST_SYSTEMS)
|
||||
public class TestSequenceRecordReaderBytesFunction extends BaseSparkTest {
|
||||
|
||||
|
||||
|
|
|
@ -34,10 +34,12 @@ import org.datavec.api.writable.Writable;
|
|||
import org.datavec.codec.reader.CodecRecordReader;
|
||||
import org.datavec.spark.BaseSparkTest;
|
||||
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import java.io.File;
|
||||
import java.nio.file.Path;
|
||||
|
@ -46,7 +48,10 @@ import java.util.List;
|
|||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.fail;
|
||||
|
||||
@Tag(TagNames.FILE_IO)
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.SPARK)
|
||||
@Tag(TagNames.DIST_SYSTEMS)
|
||||
public class TestSequenceRecordReaderFunction extends BaseSparkTest {
|
||||
|
||||
|
||||
|
|
|
@ -22,7 +22,10 @@ package org.datavec.spark.functions;
|
|||
|
||||
import org.datavec.api.writable.*;
|
||||
import org.datavec.spark.transform.misc.WritablesToNDArrayFunction;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.tags.NativeTag;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -31,7 +34,10 @@ import java.util.ArrayList;
|
|||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@Tag(TagNames.FILE_IO)
|
||||
@Tag(TagNames.SPARK)
|
||||
@Tag(TagNames.DIST_SYSTEMS)
|
||||
@NativeTag
|
||||
public class TestWritablesToNDArrayFunction {
|
||||
|
||||
@Test
|
||||
|
|
|
@ -29,7 +29,9 @@ import org.datavec.api.writable.Writable;
|
|||
import org.datavec.spark.BaseSparkTest;
|
||||
import org.datavec.spark.transform.misc.SequenceWritablesToStringFunction;
|
||||
import org.datavec.spark.transform.misc.WritablesToStringFunction;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
import scala.Tuple2;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
@ -37,7 +39,10 @@ import java.util.Arrays;
|
|||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@Tag(TagNames.FILE_IO)
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.SPARK)
|
||||
@Tag(TagNames.DIST_SYSTEMS)
|
||||
public class TestWritablesToStringFunctions extends BaseSparkTest {
|
||||
|
||||
@Test
|
||||
|
@ -57,19 +62,9 @@ public class TestWritablesToStringFunctions extends BaseSparkTest {
|
|||
|
||||
|
||||
JavaSparkContext sc = getContext();
|
||||
JavaPairRDD<String, String> left = sc.parallelize(leftMap).mapToPair(new PairFunction<Tuple2<String, String>, String, String>() {
|
||||
@Override
|
||||
public Tuple2<String, String> call(Tuple2<String, String> stringStringTuple2) throws Exception {
|
||||
return stringStringTuple2;
|
||||
}
|
||||
});
|
||||
JavaPairRDD<String, String> left = sc.parallelize(leftMap).mapToPair((PairFunction<Tuple2<String, String>, String, String>) stringStringTuple2 -> stringStringTuple2);
|
||||
|
||||
JavaPairRDD<String, String> right = sc.parallelize(rightMap).mapToPair(new PairFunction<Tuple2<String, String>, String, String>() {
|
||||
@Override
|
||||
public Tuple2<String, String> call(Tuple2<String, String> stringStringTuple2) throws Exception {
|
||||
return stringStringTuple2;
|
||||
}
|
||||
});
|
||||
JavaPairRDD<String, String> right = sc.parallelize(rightMap).mapToPair((PairFunction<Tuple2<String, String>, String, String>) stringStringTuple2 -> stringStringTuple2);
|
||||
|
||||
System.out.println(left.cogroup(right).collect());
|
||||
}
|
||||
|
@ -77,7 +72,7 @@ public class TestWritablesToStringFunctions extends BaseSparkTest {
|
|||
@Test
|
||||
public void testWritablesToString() throws Exception {
|
||||
|
||||
List<Writable> l = Arrays.<Writable>asList(new DoubleWritable(1.5), new Text("someValue"));
|
||||
List<Writable> l = Arrays.asList(new DoubleWritable(1.5), new Text("someValue"));
|
||||
String expected = l.get(0).toString() + "," + l.get(1).toString();
|
||||
|
||||
assertEquals(expected, new WritablesToStringFunction(",").call(l));
|
||||
|
@ -86,8 +81,8 @@ public class TestWritablesToStringFunctions extends BaseSparkTest {
|
|||
@Test
|
||||
public void testSequenceWritablesToString() throws Exception {
|
||||
|
||||
List<List<Writable>> l = Arrays.asList(Arrays.<Writable>asList(new DoubleWritable(1.5), new Text("someValue")),
|
||||
Arrays.<Writable>asList(new DoubleWritable(2.5), new Text("otherValue")));
|
||||
List<List<Writable>> l = Arrays.asList(Arrays.asList(new DoubleWritable(1.5), new Text("someValue")),
|
||||
Arrays.asList(new DoubleWritable(2.5), new Text("otherValue")));
|
||||
|
||||
String expected = l.get(0).get(0).toString() + "," + l.get(0).get(1).toString() + "\n"
|
||||
+ l.get(1).get(0).toString() + "," + l.get(1).get(1).toString();
|
||||
|
|
|
@ -21,6 +21,8 @@
|
|||
package org.datavec.spark.storage;
|
||||
|
||||
import com.sun.jna.Platform;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
import org.nd4j.shade.guava.io.Files;
|
||||
import org.apache.spark.api.java.JavaPairRDD;
|
||||
import org.apache.spark.api.java.JavaRDD;
|
||||
|
@ -37,7 +39,10 @@ import java.util.Map;
|
|||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
@Tag(TagNames.FILE_IO)
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.SPARK)
|
||||
@Tag(TagNames.DIST_SYSTEMS)
|
||||
public class TestSparkStorageUtils extends BaseSparkTest {
|
||||
|
||||
@Test
|
||||
|
@ -46,11 +51,11 @@ public class TestSparkStorageUtils extends BaseSparkTest {
|
|||
return;
|
||||
}
|
||||
List<List<Writable>> l = new ArrayList<>();
|
||||
l.add(Arrays.<org.datavec.api.writable.Writable>asList(new Text("zero"), new IntWritable(0),
|
||||
l.add(Arrays.asList(new Text("zero"), new IntWritable(0),
|
||||
new DoubleWritable(0), new NDArrayWritable(Nd4j.valueArrayOf(10, 0.0))));
|
||||
l.add(Arrays.<org.datavec.api.writable.Writable>asList(new Text("one"), new IntWritable(11),
|
||||
l.add(Arrays.asList(new Text("one"), new IntWritable(11),
|
||||
new DoubleWritable(11.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 11.0))));
|
||||
l.add(Arrays.<org.datavec.api.writable.Writable>asList(new Text("two"), new IntWritable(22),
|
||||
l.add(Arrays.asList(new Text("two"), new IntWritable(22),
|
||||
new DoubleWritable(22.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 22.0))));
|
||||
|
||||
JavaRDD<List<Writable>> rdd = sc.parallelize(l);
|
||||
|
@ -92,17 +97,17 @@ public class TestSparkStorageUtils extends BaseSparkTest {
|
|||
}
|
||||
List<List<List<Writable>>> l = new ArrayList<>();
|
||||
l.add(Arrays.asList(
|
||||
Arrays.<org.datavec.api.writable.Writable>asList(new Text("zero"), new IntWritable(0),
|
||||
Arrays.asList(new Text("zero"), new IntWritable(0),
|
||||
new DoubleWritable(0), new NDArrayWritable(Nd4j.valueArrayOf(10, 0.0))),
|
||||
Arrays.<org.datavec.api.writable.Writable>asList(new Text("one"), new IntWritable(1),
|
||||
Arrays.asList(new Text("one"), new IntWritable(1),
|
||||
new DoubleWritable(1.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 1.0))),
|
||||
Arrays.<org.datavec.api.writable.Writable>asList(new Text("two"), new IntWritable(2),
|
||||
Arrays.asList(new Text("two"), new IntWritable(2),
|
||||
new DoubleWritable(2.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 2.0)))));
|
||||
|
||||
l.add(Arrays.asList(
|
||||
Arrays.<org.datavec.api.writable.Writable>asList(new Text("Bzero"), new IntWritable(10),
|
||||
Arrays.asList(new Text("Bzero"), new IntWritable(10),
|
||||
new DoubleWritable(10), new NDArrayWritable(Nd4j.valueArrayOf(10, 10.0))),
|
||||
Arrays.<org.datavec.api.writable.Writable>asList(new Text("Bone"), new IntWritable(11),
|
||||
Arrays.asList(new Text("Bone"), new IntWritable(11),
|
||||
new DoubleWritable(11.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 11.0))),
|
||||
Arrays.<org.datavec.api.writable.Writable>asList(new Text("Btwo"), new IntWritable(12),
|
||||
new DoubleWritable(12.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 12.0)))));
|
||||
|
|
|
@ -30,14 +30,19 @@ import org.datavec.api.util.ndarray.RecordConverter;
|
|||
import org.datavec.api.writable.DoubleWritable;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.datavec.spark.BaseSparkTest;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@Tag(TagNames.FILE_IO)
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.SPARK)
|
||||
@Tag(TagNames.DIST_SYSTEMS)
|
||||
public class DataFramesTests extends BaseSparkTest {
|
||||
|
||||
@Test
|
||||
|
@ -110,15 +115,15 @@ public class DataFramesTests extends BaseSparkTest {
|
|||
public void testNormalize() {
|
||||
List<List<Writable>> data = new ArrayList<>();
|
||||
|
||||
data.add(Arrays.<Writable>asList(new DoubleWritable(1), new DoubleWritable(10)));
|
||||
data.add(Arrays.<Writable>asList(new DoubleWritable(2), new DoubleWritable(20)));
|
||||
data.add(Arrays.<Writable>asList(new DoubleWritable(3), new DoubleWritable(30)));
|
||||
data.add(Arrays.asList(new DoubleWritable(1), new DoubleWritable(10)));
|
||||
data.add(Arrays.asList(new DoubleWritable(2), new DoubleWritable(20)));
|
||||
data.add(Arrays.asList(new DoubleWritable(3), new DoubleWritable(30)));
|
||||
|
||||
|
||||
List<List<Writable>> expMinMax = new ArrayList<>();
|
||||
expMinMax.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new DoubleWritable(0.0)));
|
||||
expMinMax.add(Arrays.<Writable>asList(new DoubleWritable(0.5), new DoubleWritable(0.5)));
|
||||
expMinMax.add(Arrays.<Writable>asList(new DoubleWritable(1.0), new DoubleWritable(1.0)));
|
||||
expMinMax.add(Arrays.asList(new DoubleWritable(0.0), new DoubleWritable(0.0)));
|
||||
expMinMax.add(Arrays.asList(new DoubleWritable(0.5), new DoubleWritable(0.5)));
|
||||
expMinMax.add(Arrays.asList(new DoubleWritable(1.0), new DoubleWritable(1.0)));
|
||||
|
||||
double m1 = (1 + 2 + 3) / 3.0;
|
||||
double s1 = new StandardDeviation().evaluate(new double[] {1, 2, 3});
|
||||
|
@ -127,11 +132,11 @@ public class DataFramesTests extends BaseSparkTest {
|
|||
|
||||
List<List<Writable>> expStandardize = new ArrayList<>();
|
||||
expStandardize.add(
|
||||
Arrays.<Writable>asList(new DoubleWritable((1 - m1) / s1), new DoubleWritable((10 - m2) / s2)));
|
||||
Arrays.asList(new DoubleWritable((1 - m1) / s1), new DoubleWritable((10 - m2) / s2)));
|
||||
expStandardize.add(
|
||||
Arrays.<Writable>asList(new DoubleWritable((2 - m1) / s1), new DoubleWritable((20 - m2) / s2)));
|
||||
Arrays.asList(new DoubleWritable((2 - m1) / s1), new DoubleWritable((20 - m2) / s2)));
|
||||
expStandardize.add(
|
||||
Arrays.<Writable>asList(new DoubleWritable((3 - m1) / s1), new DoubleWritable((30 - m2) / s2)));
|
||||
Arrays.asList(new DoubleWritable((3 - m1) / s1), new DoubleWritable((30 - m2) / s2)));
|
||||
|
||||
JavaRDD<List<Writable>> rdd = sc.parallelize(data);
|
||||
|
||||
|
@ -178,13 +183,13 @@ public class DataFramesTests extends BaseSparkTest {
|
|||
List<List<List<Writable>>> sequences = new ArrayList<>();
|
||||
|
||||
List<List<Writable>> seq1 = new ArrayList<>();
|
||||
seq1.add(Arrays.<Writable>asList(new DoubleWritable(1), new DoubleWritable(10), new DoubleWritable(100)));
|
||||
seq1.add(Arrays.<Writable>asList(new DoubleWritable(2), new DoubleWritable(20), new DoubleWritable(200)));
|
||||
seq1.add(Arrays.<Writable>asList(new DoubleWritable(3), new DoubleWritable(30), new DoubleWritable(300)));
|
||||
seq1.add(Arrays.asList(new DoubleWritable(1), new DoubleWritable(10), new DoubleWritable(100)));
|
||||
seq1.add(Arrays.asList(new DoubleWritable(2), new DoubleWritable(20), new DoubleWritable(200)));
|
||||
seq1.add(Arrays.asList(new DoubleWritable(3), new DoubleWritable(30), new DoubleWritable(300)));
|
||||
|
||||
List<List<Writable>> seq2 = new ArrayList<>();
|
||||
seq2.add(Arrays.<Writable>asList(new DoubleWritable(4), new DoubleWritable(40), new DoubleWritable(400)));
|
||||
seq2.add(Arrays.<Writable>asList(new DoubleWritable(5), new DoubleWritable(50), new DoubleWritable(500)));
|
||||
seq2.add(Arrays.asList(new DoubleWritable(4), new DoubleWritable(40), new DoubleWritable(400)));
|
||||
seq2.add(Arrays.asList(new DoubleWritable(5), new DoubleWritable(50), new DoubleWritable(500)));
|
||||
|
||||
sequences.add(seq1);
|
||||
sequences.add(seq2);
|
||||
|
@ -199,21 +204,21 @@ public class DataFramesTests extends BaseSparkTest {
|
|||
|
||||
//Min/max normalization:
|
||||
List<List<Writable>> expSeq1MinMax = new ArrayList<>();
|
||||
expSeq1MinMax.add(Arrays.<Writable>asList(new DoubleWritable((1 - 1.0) / (5.0 - 1.0)),
|
||||
expSeq1MinMax.add(Arrays.asList(new DoubleWritable((1 - 1.0) / (5.0 - 1.0)),
|
||||
new DoubleWritable((10 - 10.0) / (50.0 - 10.0)),
|
||||
new DoubleWritable((100 - 100.0) / (500.0 - 100.0))));
|
||||
expSeq1MinMax.add(Arrays.<Writable>asList(new DoubleWritable((2 - 1.0) / (5.0 - 1.0)),
|
||||
expSeq1MinMax.add(Arrays.asList(new DoubleWritable((2 - 1.0) / (5.0 - 1.0)),
|
||||
new DoubleWritable((20 - 10.0) / (50.0 - 10.0)),
|
||||
new DoubleWritable((200 - 100.0) / (500.0 - 100.0))));
|
||||
expSeq1MinMax.add(Arrays.<Writable>asList(new DoubleWritable((3 - 1.0) / (5.0 - 1.0)),
|
||||
expSeq1MinMax.add(Arrays.asList(new DoubleWritable((3 - 1.0) / (5.0 - 1.0)),
|
||||
new DoubleWritable((30 - 10.0) / (50.0 - 10.0)),
|
||||
new DoubleWritable((300 - 100.0) / (500.0 - 100.0))));
|
||||
|
||||
List<List<Writable>> expSeq2MinMax = new ArrayList<>();
|
||||
expSeq2MinMax.add(Arrays.<Writable>asList(new DoubleWritable((4 - 1.0) / (5.0 - 1.0)),
|
||||
expSeq2MinMax.add(Arrays.asList(new DoubleWritable((4 - 1.0) / (5.0 - 1.0)),
|
||||
new DoubleWritable((40 - 10.0) / (50.0 - 10.0)),
|
||||
new DoubleWritable((400 - 100.0) / (500.0 - 100.0))));
|
||||
expSeq2MinMax.add(Arrays.<Writable>asList(new DoubleWritable((5 - 1.0) / (5.0 - 1.0)),
|
||||
expSeq2MinMax.add(Arrays.asList(new DoubleWritable((5 - 1.0) / (5.0 - 1.0)),
|
||||
new DoubleWritable((50 - 10.0) / (50.0 - 10.0)),
|
||||
new DoubleWritable((500 - 100.0) / (500.0 - 100.0))));
|
||||
|
||||
|
@ -246,17 +251,17 @@ public class DataFramesTests extends BaseSparkTest {
|
|||
double s3 = new StandardDeviation().evaluate(new double[] {100, 200, 300, 400, 500});
|
||||
|
||||
List<List<Writable>> expSeq1Std = new ArrayList<>();
|
||||
expSeq1Std.add(Arrays.<Writable>asList(new DoubleWritable((1 - m1) / s1), new DoubleWritable((10 - m2) / s2),
|
||||
expSeq1Std.add(Arrays.asList(new DoubleWritable((1 - m1) / s1), new DoubleWritable((10 - m2) / s2),
|
||||
new DoubleWritable((100 - m3) / s3)));
|
||||
expSeq1Std.add(Arrays.<Writable>asList(new DoubleWritable((2 - m1) / s1), new DoubleWritable((20 - m2) / s2),
|
||||
expSeq1Std.add(Arrays.asList(new DoubleWritable((2 - m1) / s1), new DoubleWritable((20 - m2) / s2),
|
||||
new DoubleWritable((200 - m3) / s3)));
|
||||
expSeq1Std.add(Arrays.<Writable>asList(new DoubleWritable((3 - m1) / s1), new DoubleWritable((30 - m2) / s2),
|
||||
expSeq1Std.add(Arrays.asList(new DoubleWritable((3 - m1) / s1), new DoubleWritable((30 - m2) / s2),
|
||||
new DoubleWritable((300 - m3) / s3)));
|
||||
|
||||
List<List<Writable>> expSeq2Std = new ArrayList<>();
|
||||
expSeq2Std.add(Arrays.<Writable>asList(new DoubleWritable((4 - m1) / s1), new DoubleWritable((40 - m2) / s2),
|
||||
expSeq2Std.add(Arrays.asList(new DoubleWritable((4 - m1) / s1), new DoubleWritable((40 - m2) / s2),
|
||||
new DoubleWritable((400 - m3) / s3)));
|
||||
expSeq2Std.add(Arrays.<Writable>asList(new DoubleWritable((5 - m1) / s1), new DoubleWritable((50 - m2) / s2),
|
||||
expSeq2Std.add(Arrays.asList(new DoubleWritable((5 - m1) / s1), new DoubleWritable((50 - m2) / s2),
|
||||
new DoubleWritable((500 - m3) / s3)));
|
||||
|
||||
|
||||
|
|
|
@ -31,22 +31,24 @@ import org.datavec.api.writable.DoubleWritable;
|
|||
import org.datavec.api.writable.IntWritable;
|
||||
import org.datavec.api.writable.Text;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.datavec.api.writable.NDArrayWritable;
|
||||
import org.datavec.spark.BaseSparkTest;
|
||||
import org.datavec.python.PythonTransform;
|
||||
import org.datavec.spark.BaseSparkTest;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import java.util.*;
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
import static java.time.Duration.ofMillis;
|
||||
import static org.junit.jupiter.api.Assertions.assertTimeout;
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
@DisplayName("Execution Test")
|
||||
@Tag(TagNames.FILE_IO)
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.SPARK)
|
||||
@Tag(TagNames.DIST_SYSTEMS)
|
||||
class ExecutionTest extends BaseSparkTest {
|
||||
|
||||
@Test
|
||||
|
@ -55,22 +57,16 @@ class ExecutionTest extends BaseSparkTest {
|
|||
Schema schema = new Schema.Builder().addColumnInteger("col0").addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").build();
|
||||
TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1").doubleMathOp("col2", MathOp.Add, 10.0).build();
|
||||
List<List<Writable>> inputData = new ArrayList<>();
|
||||
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
|
||||
inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
|
||||
inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
|
||||
inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
|
||||
inputData.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
|
||||
inputData.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
|
||||
JavaRDD<List<Writable>> rdd = sc.parallelize(inputData);
|
||||
List<List<Writable>> out = new ArrayList<>(SparkTransformExecutor.execute(rdd, tp).collect());
|
||||
Collections.sort(out, new Comparator<List<Writable>>() {
|
||||
|
||||
@Override
|
||||
public int compare(List<Writable> o1, List<Writable> o2) {
|
||||
return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
|
||||
}
|
||||
});
|
||||
Collections.sort(out, Comparator.comparingInt(o -> o.get(0).toInt()));
|
||||
List<List<Writable>> expected = new ArrayList<>();
|
||||
expected.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1)));
|
||||
expected.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1)));
|
||||
expected.add(Arrays.<Writable>asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1)));
|
||||
expected.add(Arrays.asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1)));
|
||||
expected.add(Arrays.asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1)));
|
||||
expected.add(Arrays.asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1)));
|
||||
assertEquals(expected, out);
|
||||
}
|
||||
|
||||
|
@ -81,31 +77,25 @@ class ExecutionTest extends BaseSparkTest {
|
|||
TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1").doubleMathOp("col2", MathOp.Add, 10.0).build();
|
||||
List<List<List<Writable>>> inputSequences = new ArrayList<>();
|
||||
List<List<Writable>> seq1 = new ArrayList<>();
|
||||
seq1.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
|
||||
seq1.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
|
||||
seq1.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
|
||||
seq1.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
|
||||
seq1.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
|
||||
seq1.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
|
||||
List<List<Writable>> seq2 = new ArrayList<>();
|
||||
seq2.add(Arrays.<Writable>asList(new IntWritable(3), new Text("state0"), new DoubleWritable(3.1)));
|
||||
seq2.add(Arrays.<Writable>asList(new IntWritable(4), new Text("state1"), new DoubleWritable(4.1)));
|
||||
seq2.add(Arrays.asList(new IntWritable(4), new Text("state1"), new DoubleWritable(4.1)));
|
||||
seq2.add(Arrays.asList(new IntWritable(3), new Text("state0"), new DoubleWritable(3.1)));
|
||||
inputSequences.add(seq1);
|
||||
inputSequences.add(seq2);
|
||||
JavaRDD<List<List<Writable>>> rdd = sc.parallelize(inputSequences);
|
||||
List<List<List<Writable>>> out = new ArrayList<>(SparkTransformExecutor.executeSequenceToSequence(rdd, tp).collect());
|
||||
Collections.sort(out, new Comparator<List<List<Writable>>>() {
|
||||
|
||||
@Override
|
||||
public int compare(List<List<Writable>> o1, List<List<Writable>> o2) {
|
||||
return -Integer.compare(o1.size(), o2.size());
|
||||
}
|
||||
});
|
||||
Collections.sort(out, (o1, o2) -> -Integer.compare(o1.size(), o2.size()));
|
||||
List<List<List<Writable>>> expectedSequence = new ArrayList<>();
|
||||
List<List<Writable>> seq1e = new ArrayList<>();
|
||||
seq1e.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1)));
|
||||
seq1e.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1)));
|
||||
seq1e.add(Arrays.<Writable>asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1)));
|
||||
seq1e.add(Arrays.asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1)));
|
||||
seq1e.add(Arrays.asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1)));
|
||||
seq1e.add(Arrays.asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1)));
|
||||
List<List<Writable>> seq2e = new ArrayList<>();
|
||||
seq2e.add(Arrays.<Writable>asList(new IntWritable(3), new IntWritable(0), new DoubleWritable(13.1)));
|
||||
seq2e.add(Arrays.<Writable>asList(new IntWritable(4), new IntWritable(1), new DoubleWritable(14.1)));
|
||||
seq2e.add(Arrays.asList(new IntWritable(3), new IntWritable(0), new DoubleWritable(13.1)));
|
||||
seq2e.add(Arrays.asList(new IntWritable(4), new IntWritable(1), new DoubleWritable(14.1)));
|
||||
expectedSequence.add(seq1e);
|
||||
expectedSequence.add(seq2e);
|
||||
assertEquals(expectedSequence, out);
|
||||
|
@ -114,34 +104,28 @@ class ExecutionTest extends BaseSparkTest {
|
|||
@Test
|
||||
@DisplayName("Test Reduction Global")
|
||||
void testReductionGlobal() {
|
||||
List<List<Writable>> in = Arrays.asList(Arrays.<Writable>asList(new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new Text("second"), new DoubleWritable(5.0)));
|
||||
List<List<Writable>> in = Arrays.asList(Arrays.asList(new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new Text("second"), new DoubleWritable(5.0)));
|
||||
JavaRDD<List<Writable>> inData = sc.parallelize(in);
|
||||
Schema s = new Schema.Builder().addColumnString("textCol").addColumnDouble("doubleCol").build();
|
||||
TransformProcess tp = new TransformProcess.Builder(s).reduce(new Reducer.Builder(ReduceOp.TakeFirst).takeFirstColumns("textCol").meanColumns("doubleCol").build()).build();
|
||||
JavaRDD<List<Writable>> outRdd = SparkTransformExecutor.execute(inData, tp);
|
||||
List<List<Writable>> out = outRdd.collect();
|
||||
List<List<Writable>> expOut = Collections.singletonList(Arrays.<Writable>asList(new Text("first"), new DoubleWritable(4.0)));
|
||||
List<List<Writable>> expOut = Collections.singletonList(Arrays.asList(new Text("first"), new DoubleWritable(4.0)));
|
||||
assertEquals(expOut, out);
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("Test Reduction By Key")
|
||||
void testReductionByKey() {
|
||||
List<List<Writable>> in = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new IntWritable(0), new Text("second"), new DoubleWritable(5.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(30.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("s"), new DoubleWritable(50.0)));
|
||||
List<List<Writable>> in = Arrays.asList(Arrays.asList(new IntWritable(0), new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new IntWritable(0), new Text("second"), new DoubleWritable(5.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(30.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("s"), new DoubleWritable(50.0)));
|
||||
JavaRDD<List<Writable>> inData = sc.parallelize(in);
|
||||
Schema s = new Schema.Builder().addColumnInteger("intCol").addColumnString("textCol").addColumnDouble("doubleCol").build();
|
||||
TransformProcess tp = new TransformProcess.Builder(s).reduce(new Reducer.Builder(ReduceOp.TakeFirst).keyColumns("intCol").takeFirstColumns("textCol").meanColumns("doubleCol").build()).build();
|
||||
JavaRDD<List<Writable>> outRdd = SparkTransformExecutor.execute(inData, tp);
|
||||
List<List<Writable>> out = outRdd.collect();
|
||||
List<List<Writable>> expOut = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0)));
|
||||
List<List<Writable>> expOut = Arrays.asList(Arrays.asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0)));
|
||||
out = new ArrayList<>(out);
|
||||
Collections.sort(out, new Comparator<List<Writable>>() {
|
||||
|
||||
@Override
|
||||
public int compare(List<Writable> o1, List<Writable> o2) {
|
||||
return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
|
||||
}
|
||||
});
|
||||
Collections.sort(out, (o1, o2) -> Integer.compare(o1.get(0).toInt(), o2.get(0).toInt()));
|
||||
assertEquals(expOut, out);
|
||||
}
|
||||
|
||||
|
@ -150,15 +134,15 @@ class ExecutionTest extends BaseSparkTest {
|
|||
void testUniqueMultiCol() {
|
||||
Schema schema = new Schema.Builder().addColumnInteger("col0").addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").build();
|
||||
List<List<Writable>> inputData = new ArrayList<>();
|
||||
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
|
||||
inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
|
||||
inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
|
||||
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
|
||||
inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
|
||||
inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
|
||||
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
|
||||
inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
|
||||
inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
|
||||
inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
|
||||
inputData.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
|
||||
inputData.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
|
||||
inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
|
||||
inputData.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
|
||||
inputData.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
|
||||
inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
|
||||
inputData.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
|
||||
inputData.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
|
||||
JavaRDD<List<Writable>> rdd = sc.parallelize(inputData);
|
||||
Map<String, List<Writable>> l = AnalyzeSpark.getUnique(Arrays.asList("col0", "col1"), schema, rdd);
|
||||
assertEquals(2, l.size());
|
||||
|
@ -180,58 +164,20 @@ class ExecutionTest extends BaseSparkTest {
|
|||
String pythonCode = "col1 = ['state0', 'state1', 'state2'].index(col1)\ncol2 += 10.0";
|
||||
TransformProcess tp = new TransformProcess.Builder(schema).transform(PythonTransform.builder().code("first = np.sin(first)\nsecond = np.cos(second)").outputSchema(finalSchema).build()).build();
|
||||
List<List<Writable>> inputData = new ArrayList<>();
|
||||
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
|
||||
inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
|
||||
inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
|
||||
inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
|
||||
inputData.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
|
||||
inputData.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
|
||||
JavaRDD<List<Writable>> rdd = sc.parallelize(inputData);
|
||||
List<List<Writable>> out = new ArrayList<>(SparkTransformExecutor.execute(rdd, tp).collect());
|
||||
Collections.sort(out, new Comparator<List<Writable>>() {
|
||||
|
||||
@Override
|
||||
public int compare(List<Writable> o1, List<Writable> o2) {
|
||||
return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
|
||||
}
|
||||
});
|
||||
Collections.sort(out, Comparator.comparingInt(o -> o.get(0).toInt()));
|
||||
List<List<Writable>> expected = new ArrayList<>();
|
||||
expected.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1)));
|
||||
expected.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1)));
|
||||
expected.add(Arrays.<Writable>asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1)));
|
||||
expected.add(Arrays.asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1)));
|
||||
expected.add(Arrays.asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1)));
|
||||
expected.add(Arrays.asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1)));
|
||||
assertEquals(expected, out);
|
||||
});
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771")
|
||||
@DisplayName("Test Python Execution With ND Arrays")
|
||||
void testPythonExecutionWithNDArrays() {
|
||||
assertTimeout(ofMillis(60000), () -> {
|
||||
long[] shape = new long[] { 3, 2 };
|
||||
Schema schema = new Schema.Builder().addColumnInteger("id").addColumnNDArray("col1", shape).addColumnNDArray("col2", shape).build();
|
||||
Schema finalSchema = new Schema.Builder().addColumnInteger("id").addColumnNDArray("col1", shape).addColumnNDArray("col2", shape).addColumnNDArray("col3", shape).build();
|
||||
String pythonCode = "col3 = col1 + col2";
|
||||
TransformProcess tp = new TransformProcess.Builder(schema).transform(PythonTransform.builder().code("first = np.sin(first)\nsecond = np.cos(second)").outputSchema(schema).build()).build();
|
||||
INDArray zeros = Nd4j.zeros(shape);
|
||||
INDArray ones = Nd4j.ones(shape);
|
||||
INDArray twos = ones.add(ones);
|
||||
List<List<Writable>> inputData = new ArrayList<>();
|
||||
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new NDArrayWritable(zeros), new NDArrayWritable(zeros)));
|
||||
inputData.add(Arrays.<Writable>asList(new IntWritable(1), new NDArrayWritable(zeros), new NDArrayWritable(ones)));
|
||||
inputData.add(Arrays.<Writable>asList(new IntWritable(2), new NDArrayWritable(ones), new NDArrayWritable(ones)));
|
||||
JavaRDD<List<Writable>> rdd = sc.parallelize(inputData);
|
||||
List<List<Writable>> out = new ArrayList<>(SparkTransformExecutor.execute(rdd, tp).collect());
|
||||
Collections.sort(out, new Comparator<List<Writable>>() {
|
||||
|
||||
@Override
|
||||
public int compare(List<Writable> o1, List<Writable> o2) {
|
||||
return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
|
||||
}
|
||||
});
|
||||
List<List<Writable>> expected = new ArrayList<>();
|
||||
expected.add(Arrays.<Writable>asList(new IntWritable(0), new NDArrayWritable(zeros), new NDArrayWritable(zeros), new NDArrayWritable(zeros)));
|
||||
expected.add(Arrays.<Writable>asList(new IntWritable(1), new NDArrayWritable(zeros), new NDArrayWritable(ones), new NDArrayWritable(ones)));
|
||||
expected.add(Arrays.<Writable>asList(new IntWritable(2), new NDArrayWritable(ones), new NDArrayWritable(ones), new NDArrayWritable(twos)));
|
||||
});
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("Test First Digit Transform Benfords Law")
|
||||
|
|
|
@ -28,7 +28,10 @@ import org.datavec.api.util.ndarray.RecordConverter;
|
|||
import org.datavec.api.writable.DoubleWritable;
|
||||
import org.datavec.api.writable.Writable;
|
||||
import org.datavec.spark.BaseSparkTest;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.tags.NativeTag;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.dataset.DataSet;
|
||||
|
@ -43,7 +46,10 @@ import java.util.List;
|
|||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
@Tag(TagNames.FILE_IO)
|
||||
@Tag(TagNames.SPARK)
|
||||
@Tag(TagNames.DIST_SYSTEMS)
|
||||
@NativeTag
|
||||
public class NormalizationTests extends BaseSparkTest {
|
||||
|
||||
|
||||
|
|
|
@ -38,7 +38,9 @@ import org.datavec.local.transforms.AnalyzeLocal;
|
|||
import org.datavec.spark.BaseSparkTest;
|
||||
import org.datavec.spark.transform.AnalyzeSpark;
|
||||
import org.joda.time.DateTimeZone;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
|
@ -48,7 +50,10 @@ import java.nio.file.Files;
|
|||
import java.util.*;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
@Tag(TagNames.FILE_IO)
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.SPARK)
|
||||
@Tag(TagNames.DIST_SYSTEMS)
|
||||
public class TestAnalysis extends BaseSparkTest {
|
||||
|
||||
@Test
|
||||
|
|
|
@ -27,12 +27,17 @@ import org.datavec.api.transform.schema.Schema;
|
|||
import org.datavec.api.writable.*;
|
||||
import org.datavec.spark.BaseSparkTest;
|
||||
import org.datavec.spark.transform.SparkTransformExecutor;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@Tag(TagNames.FILE_IO)
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.SPARK)
|
||||
@Tag(TagNames.DIST_SYSTEMS)
|
||||
public class TestJoin extends BaseSparkTest {
|
||||
|
||||
@Test
|
||||
|
|
|
@ -30,24 +30,29 @@ import org.datavec.api.writable.Writable;
|
|||
import org.datavec.api.writable.comparator.DoubleWritableComparator;
|
||||
import org.datavec.spark.BaseSparkTest;
|
||||
import org.datavec.spark.transform.SparkTransformExecutor;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@Tag(TagNames.FILE_IO)
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.SPARK)
|
||||
@Tag(TagNames.DIST_SYSTEMS)
|
||||
public class TestCalculateSortedRank extends BaseSparkTest {
|
||||
|
||||
@Test
|
||||
public void testCalculateSortedRank() {
|
||||
|
||||
List<List<Writable>> data = new ArrayList<>();
|
||||
data.add(Arrays.asList((Writable) new Text("0"), new DoubleWritable(0.0)));
|
||||
data.add(Arrays.asList((Writable) new Text("3"), new DoubleWritable(0.3)));
|
||||
data.add(Arrays.asList((Writable) new Text("2"), new DoubleWritable(0.2)));
|
||||
data.add(Arrays.asList((Writable) new Text("1"), new DoubleWritable(0.1)));
|
||||
data.add(Arrays.asList(new Text("0"), new DoubleWritable(0.0)));
|
||||
data.add(Arrays.asList(new Text("3"), new DoubleWritable(0.3)));
|
||||
data.add(Arrays.asList(new Text("2"), new DoubleWritable(0.2)));
|
||||
data.add(Arrays.asList(new Text("1"), new DoubleWritable(0.1)));
|
||||
|
||||
JavaRDD<List<Writable>> rdd = sc.parallelize(data);
|
||||
|
||||
|
|
|
@ -29,7 +29,9 @@ import org.datavec.api.writable.Text;
|
|||
import org.datavec.api.writable.Writable;
|
||||
import org.datavec.spark.BaseSparkTest;
|
||||
import org.datavec.spark.transform.SparkTransformExecutor;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
|
@ -37,7 +39,10 @@ import java.util.List;
|
|||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
@Tag(TagNames.FILE_IO)
|
||||
@Tag(TagNames.JAVA_ONLY)
|
||||
@Tag(TagNames.SPARK)
|
||||
@Tag(TagNames.DIST_SYSTEMS)
|
||||
public class TestConvertToSequence extends BaseSparkTest {
|
||||
|
||||
@Test
|
||||
|
@ -45,13 +50,13 @@ public class TestConvertToSequence extends BaseSparkTest {
|
|||
|
||||
Schema s = new Schema.Builder().addColumnsString("key1", "key2").addColumnLong("time").build();
|
||||
|
||||
List<List<Writable>> allExamples =
|
||||
Arrays.asList(Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), new LongWritable(10)),
|
||||
Arrays.<Writable>asList(new Text("k1b"), new Text("k2b"), new LongWritable(10)),
|
||||
Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"),
|
||||
new LongWritable(-10)),
|
||||
Arrays.<Writable>asList(new Text("k1b"), new Text("k2b"), new LongWritable(5)),
|
||||
Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), new LongWritable(0)));
|
||||
List<List<Writable>> allExamples;
|
||||
allExamples = Arrays.asList(Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(10)),
|
||||
Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(10)),
|
||||
Arrays.asList(new Text("k1a"), new Text("k2a"),
|
||||
new LongWritable(-10)),
|
||||
Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(5)),
|
||||
Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(0)));
|
||||
|
||||
TransformProcess tp = new TransformProcess.Builder(s)
|
||||
.convertToSequence(Arrays.asList("key1", "key2"), new NumericalColumnComparator("time"))
|
||||
|
@ -73,13 +78,13 @@ public class TestConvertToSequence extends BaseSparkTest {
|
|||
}
|
||||
|
||||
List<List<Writable>> expSeq0 = Arrays.asList(
|
||||
Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), new LongWritable(-10)),
|
||||
Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), new LongWritable(0)),
|
||||
Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), new LongWritable(10)));
|
||||
Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(-10)),
|
||||
Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(0)),
|
||||
Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(10)));
|
||||
|
||||
List<List<Writable>> expSeq1 = Arrays.asList(
|
||||
Arrays.<Writable>asList(new Text("k1b"), new Text("k2b"), new LongWritable(5)),
|
||||
Arrays.<Writable>asList(new Text("k1b"), new Text("k2b"), new LongWritable(10)));
|
||||
Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(5)),
|
||||
Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(10)));
|
||||
|
||||
assertEquals(expSeq0, seq0);
|
||||
assertEquals(expSeq1, seq1);
|
||||
|
@ -94,9 +99,9 @@ public class TestConvertToSequence extends BaseSparkTest {
|
|||
.build();
|
||||
|
||||
List<List<Writable>> allExamples = Arrays.asList(
|
||||
Arrays.<Writable>asList(new Text("a"), new LongWritable(0)),
|
||||
Arrays.<Writable>asList(new Text("b"), new LongWritable(1)),
|
||||
Arrays.<Writable>asList(new Text("c"), new LongWritable(2)));
|
||||
Arrays.asList(new Text("a"), new LongWritable(0)),
|
||||
Arrays.asList(new Text("b"), new LongWritable(1)),
|
||||
Arrays.asList(new Text("c"), new LongWritable(2)));
|
||||
|
||||
TransformProcess tp = new TransformProcess.Builder(s)
|
||||
.convertToSequence()
|
||||
|
|
|
@ -28,7 +28,9 @@ import org.datavec.api.writable.Text;
|
|||
import org.datavec.api.writable.Writable;
|
||||
import org.datavec.spark.BaseSparkTest;
|
||||
import org.datavec.spark.transform.utils.SparkUtils;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
|
@ -38,6 +40,9 @@ import java.util.List;
|
|||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
@Tag(TagNames.SPARK)
|
||||
@Tag(TagNames.FILE_IO)
|
||||
@Tag(TagNames.DIST_SYSTEMS)
|
||||
public class TestSparkUtil extends BaseSparkTest {
|
||||
|
||||
@Test
|
||||
|
@ -46,8 +51,8 @@ public class TestSparkUtil extends BaseSparkTest {
|
|||
return;
|
||||
}
|
||||
List<List<Writable>> l = new ArrayList<>();
|
||||
l.add(Arrays.<Writable>asList(new Text("abc"), new DoubleWritable(2.0), new IntWritable(-1)));
|
||||
l.add(Arrays.<Writable>asList(new Text("def"), new DoubleWritable(4.0), new IntWritable(-2)));
|
||||
l.add(Arrays.asList(new Text("abc"), new DoubleWritable(2.0), new IntWritable(-1)));
|
||||
l.add(Arrays.asList(new Text("def"), new DoubleWritable(4.0), new IntWritable(-2)));
|
||||
|
||||
File f = File.createTempFile("testSparkUtil", "txt");
|
||||
f.deleteOnExit();
|
||||
|
|
|
@ -27,8 +27,11 @@ import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
|||
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.resources.Resources;
|
||||
import org.nd4j.common.tests.tags.NativeTag;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
import org.nd4j.linalg.activations.Activation;
|
||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
@ -39,6 +42,8 @@ import java.nio.file.Files;
|
|||
import java.util.concurrent.CountDownLatch;
|
||||
|
||||
@Disabled
|
||||
@NativeTag
|
||||
@Tag(TagNames.RNG)
|
||||
public class RandomTests extends BaseDL4JTest {
|
||||
|
||||
@Test
|
||||
|
|
|
@ -23,11 +23,11 @@ import org.deeplearning4j.BaseDL4JTest;
|
|||
import org.deeplearning4j.datasets.base.MnistFetcher;
|
||||
import org.deeplearning4j.common.resources.DL4JResources;
|
||||
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
|
||||
import org.junit.jupiter.api.AfterAll;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.*;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
|
||||
import org.nd4j.common.tests.tags.NativeTag;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
|
||||
import org.nd4j.linalg.dataset.DataSet;
|
||||
|
@ -41,10 +41,13 @@ import java.util.Set;
|
|||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
|
||||
@DisplayName("Mnist Fetcher Test")
|
||||
@NativeTag
|
||||
@Tag(TagNames.FILE_IO)
|
||||
@Tag(TagNames.NDARRAY_ETL)
|
||||
class MnistFetcherTest extends BaseDL4JTest {
|
||||
|
||||
|
||||
|
|
|
@ -23,8 +23,14 @@ package org.deeplearning4j.datasets;
|
|||
import org.deeplearning4j.BaseDL4JTest;
|
||||
import org.deeplearning4j.datasets.fetchers.Cifar10Fetcher;
|
||||
import org.deeplearning4j.datasets.fetchers.TinyImageNetFetcher;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.tags.NativeTag;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
@NativeTag
|
||||
@Tag(TagNames.FILE_IO)
|
||||
@Tag(TagNames.NDARRAY_ETL)
|
||||
public class TestDataSets extends BaseDL4JTest {
|
||||
|
||||
@Override
|
||||
|
|
|
@ -20,8 +20,11 @@
|
|||
package org.deeplearning4j.datasets.datavec;
|
||||
|
||||
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
import org.nd4j.common.tests.tags.NativeTag;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||
import org.nd4j.shade.guava.io.Files;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
@ -76,6 +79,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows;
|
|||
@Slf4j
|
||||
@DisplayName("Record Reader Data Setiterator Test")
|
||||
@Disabled
|
||||
@NativeTag
|
||||
class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
|
||||
|
||||
@Override
|
||||
|
@ -148,6 +152,7 @@ class RecordReaderDataSetiteratorTest extends BaseDL4JTest {
|
|||
@ParameterizedTest
|
||||
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
|
||||
@DisplayName("Test Sequence Record Reader")
|
||||
@Tag(TagNames.NDARRAY_INDEXING)
|
||||
void testSequenceRecordReader(Nd4jBackend backend) throws Exception {
|
||||
File rootDir = temporaryFolder.toFile();
|
||||
// need to manually extract
|
||||
|
|
|
@ -21,6 +21,8 @@ package org.deeplearning4j.datasets.datavec;
|
|||
|
||||
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
import org.nd4j.shade.guava.io.Files;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.apache.commons.io.FilenameUtils;
|
||||
|
@ -70,6 +72,7 @@ import org.junit.jupiter.api.extension.ExtendWith;
|
|||
|
||||
@DisplayName("Record Reader Multi Data Set Iterator Test")
|
||||
@Disabled
|
||||
@Tag(TagNames.FILE_IO)
|
||||
class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest {
|
||||
|
||||
@TempDir
|
||||
|
|
|
@ -21,6 +21,7 @@ package org.deeplearning4j.datasets.fetchers;
|
|||
|
||||
import org.deeplearning4j.BaseDL4JTest;
|
||||
|
||||
import org.junit.jupiter.api.Tag;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.io.File;
|
||||
|
@ -28,11 +29,15 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
|
|||
import static org.junit.jupiter.api.Assumptions.assumeTrue;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.extension.ExtendWith;
|
||||
import org.nd4j.common.tests.tags.NativeTag;
|
||||
import org.nd4j.common.tests.tags.TagNames;
|
||||
|
||||
/**
|
||||
* @author saudet
|
||||
*/
|
||||
@DisplayName("Svhn Data Fetcher Test")
|
||||
@Tag(TagNames.FILE_IO)
|
||||
@NativeTag
|
||||
class SvhnDataFetcherTest extends BaseDL4JTest {
|
||||
|
||||
@Override
|
||||
|
|
|
@ -26,6 +26,7 @@ import org.deeplearning4j.datasets.iterator.tools.VariableTimeseriesGenerator;
|
|||
import org.deeplearning4j.nn.util.TestDataSetConsumer;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.nd4j.common.tests.tags.NativeTag;
|
||||
import org.nd4j.linalg.dataset.DataSet;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import java.util.ArrayList;
|
||||
|
@ -40,6 +41,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows;
|
|||
|
||||
@Slf4j
|
||||
@DisplayName("Async Data Set Iterator Test")
|
||||
@NativeTag
|
||||
class AsyncDataSetIteratorTest extends BaseDL4JTest {
|
||||
|
||||
private ExistingDataSetIterator backIterator;
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue