Add tags for junit 5
This commit is contained in:
		
							parent
							
								
									3c205548af
								
							
						
					
					
						commit
						5e8951cd8e
					
				| @ -12,7 +12,7 @@ DL4J was a junit 4 based code based for testing. | ||||
| It's now based on junit 5's jupiter API, which has support for [Tags](https://junit.org/junit5/docs/5.0.1/api/org/junit/jupiter/api/Tag.html). | ||||
| 
 | ||||
| DL4j's code base has a number of different kinds of tests that fall in to several categories: | ||||
| 1. Long and flaky involving distributed systems (spark, parameter server) | ||||
| 1. Long and flaky involving distributed systems (spark, parameter-server) | ||||
| 2. Code that requires large downloads, but runs quickly | ||||
| 3. Quick tests that test basic functionality | ||||
| 4. Comprehensive integration tests that test several parts of a code  base | ||||
| @ -38,8 +38,10 @@ A few kinds of tags exist: | ||||
| 3. Distributed systems: spark, multi-threaded | ||||
| 4. Functional cross-cutting concerns: multi module tests, similar functionality (excludes time based) | ||||
| 5. Platform specific tests that can vary on different hardware: cpu, gpu | ||||
| 6. JVM crash: Tests with native code can crash the JVM for tests. It's useful to be able to turn those off when debugging.: jvm-crash | ||||
| 
 | ||||
| 6. JVM crash: (jvm-crash) Tests with native code can crash the JVM for tests. It's useful to be able to turn those off when debugging.: jvm-crash | ||||
| 7. RNG: (rng) for RNG related tests | ||||
| 8. Samediff:(samediff) samediff related tests | ||||
| 9. Training related functionality | ||||
| 
 | ||||
| 
 | ||||
| ## Consequences | ||||
|  | ||||
| @ -26,6 +26,7 @@ import org.datavec.api.split.FileSplit; | ||||
| import org.datavec.api.writable.Text; | ||||
| import org.datavec.api.writable.Writable; | ||||
| 
 | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.junit.jupiter.api.io.TempDir; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| @ -38,8 +39,11 @@ import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| import org.junit.jupiter.api.DisplayName; | ||||
| import java.nio.file.Path; | ||||
| import org.junit.jupiter.api.extension.ExtendWith; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| @DisplayName("Csv Line Sequence Record Reader Test") | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| class CSVLineSequenceRecordReaderTest extends BaseND4JTest { | ||||
| 
 | ||||
|     @TempDir | ||||
| @ -54,8 +58,8 @@ class CSVLineSequenceRecordReaderTest extends BaseND4JTest { | ||||
|         FileUtils.writeStringToFile(source, str, StandardCharsets.UTF_8); | ||||
|         SequenceRecordReader rr = new CSVLineSequenceRecordReader(); | ||||
|         rr.initialize(new FileSplit(source)); | ||||
|         List<List<Writable>> exp0 = Arrays.asList(Collections.<Writable>singletonList(new Text("a")), Collections.<Writable>singletonList(new Text("b")), Collections.<Writable>singletonList(new Text("c"))); | ||||
|         List<List<Writable>> exp1 = Arrays.asList(Collections.<Writable>singletonList(new Text("1")), Collections.<Writable>singletonList(new Text("2")), Collections.<Writable>singletonList(new Text("3")), Collections.<Writable>singletonList(new Text("4"))); | ||||
|         List<List<Writable>> exp0 = Arrays.asList(Collections.singletonList(new Text("a")), Collections.singletonList(new Text("b")), Collections.<Writable>singletonList(new Text("c"))); | ||||
|         List<List<Writable>> exp1 = Arrays.asList(Collections.singletonList(new Text("1")), Collections.singletonList(new Text("2")), Collections.<Writable>singletonList(new Text("3")), Collections.<Writable>singletonList(new Text("4"))); | ||||
|         for (int i = 0; i < 3; i++) { | ||||
|             int count = 0; | ||||
|             while (rr.hasNext()) { | ||||
|  | ||||
| @ -27,6 +27,7 @@ import org.datavec.api.writable.Text; | ||||
| import org.datavec.api.writable.Writable; | ||||
| 
 | ||||
| import org.junit.jupiter.api.Disabled; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.junit.jupiter.api.io.TempDir; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| @ -41,8 +42,11 @@ import static org.junit.jupiter.api.Assertions.assertFalse; | ||||
| import org.junit.jupiter.api.DisplayName; | ||||
| import java.nio.file.Path; | ||||
| import org.junit.jupiter.api.extension.ExtendWith; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| @DisplayName("Csv Multi Sequence Record Reader Test") | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| class CSVMultiSequenceRecordReaderTest extends BaseND4JTest { | ||||
| 
 | ||||
|     @TempDir | ||||
|  | ||||
| @ -26,6 +26,7 @@ import org.datavec.api.records.reader.impl.csv.CSVNLinesSequenceRecordReader; | ||||
| import org.datavec.api.records.reader.impl.csv.CSVRecordReader; | ||||
| import org.datavec.api.split.FileSplit; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.common.io.ClassPathResource; | ||||
| @ -34,12 +35,15 @@ import java.util.List; | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| import org.junit.jupiter.api.DisplayName; | ||||
| import org.junit.jupiter.api.extension.ExtendWith; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| @DisplayName("Csvn Lines Sequence Record Reader Test") | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| class CSVNLinesSequenceRecordReaderTest extends BaseND4JTest { | ||||
| 
 | ||||
|     @Test | ||||
|     @DisplayName("Test CSVN Lines Sequence Record Reader") | ||||
|     @DisplayName("Test CSV Lines Sequence Record Reader") | ||||
|     void testCSVNLinesSequenceRecordReader() throws Exception { | ||||
|         int nLinesPerSequence = 10; | ||||
|         SequenceRecordReader seqRR = new CSVNLinesSequenceRecordReader(nLinesPerSequence); | ||||
|  | ||||
| @ -34,9 +34,12 @@ import org.datavec.api.writable.IntWritable; | ||||
| import org.datavec.api.writable.Text; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.junit.jupiter.api.DisplayName; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.common.io.ClassPathResource; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import java.io.File; | ||||
| import java.io.IOException; | ||||
| import java.nio.file.Files; | ||||
| @ -50,6 +53,8 @@ import static org.junit.jupiter.api.Assertions.*; | ||||
| 
 | ||||
| 
 | ||||
| @DisplayName("Csv Record Reader Test") | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| class CSVRecordReaderTest extends BaseND4JTest { | ||||
| 
 | ||||
|     @Test | ||||
|  | ||||
| @ -27,6 +27,7 @@ import org.datavec.api.split.InputSplit; | ||||
| import org.datavec.api.split.NumberedFileInputSplit; | ||||
| import org.datavec.api.writable.Writable; | ||||
| 
 | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.junit.jupiter.api.io.TempDir; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| @ -43,8 +44,11 @@ import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| import org.junit.jupiter.api.DisplayName; | ||||
| import java.nio.file.Path; | ||||
| import org.junit.jupiter.api.extension.ExtendWith; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| @DisplayName("Csv Sequence Record Reader Test") | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| class CSVSequenceRecordReaderTest extends BaseND4JTest { | ||||
| 
 | ||||
|     @TempDir | ||||
|  | ||||
| @ -24,6 +24,7 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader; | ||||
| import org.datavec.api.records.reader.impl.csv.CSVVariableSlidingWindowRecordReader; | ||||
| import org.datavec.api.split.FileSplit; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.common.io.ClassPathResource; | ||||
| @ -32,8 +33,11 @@ import java.util.List; | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| import org.junit.jupiter.api.DisplayName; | ||||
| import org.junit.jupiter.api.extension.ExtendWith; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| @DisplayName("Csv Variable Sliding Window Record Reader Test") | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| class CSVVariableSlidingWindowRecordReaderTest extends BaseND4JTest { | ||||
| 
 | ||||
|     @Test | ||||
|  | ||||
| @ -28,6 +28,7 @@ import org.datavec.api.records.reader.impl.filebatch.FileBatchRecordReader; | ||||
| import org.datavec.api.records.reader.impl.filebatch.FileBatchSequenceRecordReader; | ||||
| import org.datavec.api.writable.Writable; | ||||
| 
 | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.junit.jupiter.api.io.TempDir; | ||||
| import org.junit.jupiter.params.ParameterizedTest; | ||||
| @ -42,9 +43,12 @@ import static org.junit.jupiter.api.Assertions.*; | ||||
| import org.junit.jupiter.api.DisplayName; | ||||
| import java.nio.file.Path; | ||||
| import org.junit.jupiter.api.extension.ExtendWith; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| import org.nd4j.linalg.factory.Nd4jBackend; | ||||
| 
 | ||||
| @DisplayName("File Batch Record Reader Test") | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| public class FileBatchRecordReaderTest extends BaseND4JTest { | ||||
|     @TempDir  Path testDir; | ||||
| 
 | ||||
|  | ||||
| @ -25,6 +25,7 @@ import org.datavec.api.split.CollectionInputSplit; | ||||
| import org.datavec.api.split.FileSplit; | ||||
| import org.datavec.api.split.InputSplit; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.common.io.ClassPathResource; | ||||
| @ -36,8 +37,11 @@ import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| import static org.junit.jupiter.api.Assertions.assertFalse; | ||||
| import org.junit.jupiter.api.DisplayName; | ||||
| import org.junit.jupiter.api.extension.ExtendWith; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| @DisplayName("File Record Reader Test") | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| class FileRecordReaderTest extends BaseND4JTest { | ||||
| 
 | ||||
|     @Test | ||||
|  | ||||
| @ -28,10 +28,12 @@ import org.datavec.api.split.FileSplit; | ||||
| import org.datavec.api.writable.Text; | ||||
| import org.datavec.api.writable.Writable; | ||||
| 
 | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.junit.jupiter.api.io.TempDir; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.common.io.ClassPathResource; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| import org.nd4j.shade.jackson.core.JsonFactory; | ||||
| import org.nd4j.shade.jackson.databind.ObjectMapper; | ||||
| import java.io.File; | ||||
| @ -45,6 +47,8 @@ import java.nio.file.Path; | ||||
| import org.junit.jupiter.api.extension.ExtendWith; | ||||
| 
 | ||||
| @DisplayName("Jackson Line Record Reader Test") | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| class JacksonLineRecordReaderTest extends BaseND4JTest { | ||||
| 
 | ||||
|     @TempDir | ||||
|  | ||||
| @ -31,10 +31,12 @@ import org.datavec.api.writable.IntWritable; | ||||
| import org.datavec.api.writable.Text; | ||||
| import org.datavec.api.writable.Writable; | ||||
| 
 | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.junit.jupiter.api.io.TempDir; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.common.io.ClassPathResource; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| import org.nd4j.shade.jackson.core.JsonFactory; | ||||
| import org.nd4j.shade.jackson.databind.ObjectMapper; | ||||
| import org.nd4j.shade.jackson.dataformat.xml.XmlFactory; | ||||
| @ -51,6 +53,8 @@ import java.nio.file.Path; | ||||
| import org.junit.jupiter.api.extension.ExtendWith; | ||||
| 
 | ||||
| @DisplayName("Jackson Record Reader Test") | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| class JacksonRecordReaderTest extends BaseND4JTest { | ||||
| 
 | ||||
|     @TempDir | ||||
|  | ||||
| @ -26,6 +26,7 @@ import org.datavec.api.split.FileSplit; | ||||
| import org.datavec.api.writable.DoubleWritable; | ||||
| import org.datavec.api.writable.IntWritable; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.common.io.ClassPathResource; | ||||
| @ -35,9 +36,13 @@ import static org.datavec.api.records.reader.impl.misc.LibSvmRecordReader.*; | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| import org.junit.jupiter.api.DisplayName; | ||||
| import org.junit.jupiter.api.extension.ExtendWith; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertThrows; | ||||
| 
 | ||||
| @DisplayName("Lib Svm Record Reader Test") | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| class LibSvmRecordReaderTest extends BaseND4JTest { | ||||
| 
 | ||||
|     @Test | ||||
|  | ||||
| @ -30,6 +30,7 @@ import org.datavec.api.split.InputSplit; | ||||
| import org.datavec.api.split.InputStreamInputSplit; | ||||
| import org.datavec.api.writable.Writable; | ||||
| 
 | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.junit.jupiter.api.io.TempDir; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| @ -47,8 +48,11 @@ import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| import org.junit.jupiter.api.DisplayName; | ||||
| import java.nio.file.Path; | ||||
| import org.junit.jupiter.api.extension.ExtendWith; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| @DisplayName("Line Reader Test") | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| class LineReaderTest extends BaseND4JTest { | ||||
| 
 | ||||
| 
 | ||||
|  | ||||
| @ -33,6 +33,7 @@ import org.datavec.api.split.NumberedFileInputSplit; | ||||
| import org.datavec.api.writable.Text; | ||||
| import org.datavec.api.writable.Writable; | ||||
| 
 | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.junit.jupiter.api.io.TempDir; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| @ -46,8 +47,11 @@ import static org.junit.jupiter.api.Assertions.assertFalse; | ||||
| import org.junit.jupiter.api.DisplayName; | ||||
| import java.nio.file.Path; | ||||
| import org.junit.jupiter.api.extension.ExtendWith; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| @DisplayName("Regex Record Reader Test") | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| class RegexRecordReaderTest extends BaseND4JTest { | ||||
| 
 | ||||
|     @TempDir | ||||
|  | ||||
| @ -26,6 +26,7 @@ import org.datavec.api.split.FileSplit; | ||||
| import org.datavec.api.writable.DoubleWritable; | ||||
| import org.datavec.api.writable.IntWritable; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.common.io.ClassPathResource; | ||||
| @ -35,9 +36,13 @@ import static org.datavec.api.records.reader.impl.misc.SVMLightRecordReader.*; | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| import org.junit.jupiter.api.DisplayName; | ||||
| import org.junit.jupiter.api.extension.ExtendWith; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertThrows; | ||||
| 
 | ||||
| @DisplayName("Svm Light Record Reader Test") | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| class SVMLightRecordReaderTest extends BaseND4JTest { | ||||
| 
 | ||||
|     @Test | ||||
|  | ||||
| @ -26,15 +26,18 @@ import org.datavec.api.records.reader.SequenceRecordReader; | ||||
| import org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader; | ||||
| import org.datavec.api.writable.IntWritable; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import java.util.ArrayList; | ||||
| import java.util.Arrays; | ||||
| import java.util.List; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.*; | ||||
| 
 | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| public class TestCollectionRecordReaders extends BaseND4JTest { | ||||
| 
 | ||||
|     @Test | ||||
|  | ||||
| @ -23,12 +23,15 @@ package org.datavec.api.records.reader.impl; | ||||
| import org.datavec.api.records.reader.RecordReader; | ||||
| import org.datavec.api.records.reader.impl.csv.CSVRecordReader; | ||||
| import org.datavec.api.split.FileSplit; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.common.io.ClassPathResource; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| 
 | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| public class TestConcatenatingRecordReader extends BaseND4JTest { | ||||
| 
 | ||||
|     @Test | ||||
|  | ||||
| @ -37,9 +37,11 @@ import org.datavec.api.transform.TransformProcess; | ||||
| import org.datavec.api.transform.schema.Schema; | ||||
| import org.datavec.api.writable.Text; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.common.io.ClassPathResource; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| import org.nd4j.shade.jackson.core.JsonFactory; | ||||
| import org.nd4j.shade.jackson.databind.ObjectMapper; | ||||
| 
 | ||||
| @ -48,7 +50,8 @@ import java.util.ArrayList; | ||||
| import java.util.List; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| 
 | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| public class TestSerialization extends BaseND4JTest { | ||||
| 
 | ||||
|     @Test | ||||
|  | ||||
| @ -30,9 +30,11 @@ import org.datavec.api.writable.IntWritable; | ||||
| import org.datavec.api.writable.LongWritable; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.joda.time.DateTimeZone; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.common.io.ClassPathResource; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import java.util.ArrayList; | ||||
| import java.util.Arrays; | ||||
| @ -41,6 +43,8 @@ import java.util.List; | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| import static org.junit.jupiter.api.Assertions.assertTrue; | ||||
| 
 | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| public class TransformProcessRecordReaderTests extends BaseND4JTest { | ||||
| 
 | ||||
|     @Test | ||||
| @ -74,11 +78,11 @@ public class TransformProcessRecordReaderTests extends BaseND4JTest { | ||||
|     public void simpleTransformTestSequence() { | ||||
|         List<List<Writable>> sequence = new ArrayList<>(); | ||||
|         //First window: | ||||
|         sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L), new IntWritable(0), | ||||
|         sequence.add(Arrays.asList(new LongWritable(1451606400000L), new IntWritable(0), | ||||
|                         new IntWritable(0))); | ||||
|         sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 100L), new IntWritable(1), | ||||
|         sequence.add(Arrays.asList(new LongWritable(1451606400000L + 100L), new IntWritable(1), | ||||
|                         new IntWritable(0))); | ||||
|         sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 200L), new IntWritable(2), | ||||
|         sequence.add(Arrays.asList(new LongWritable(1451606400000L + 200L), new IntWritable(2), | ||||
|                         new IntWritable(0))); | ||||
| 
 | ||||
|         Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC) | ||||
|  | ||||
| @ -26,6 +26,7 @@ import org.datavec.api.split.partition.NumberOfRecordsPartitioner; | ||||
| import org.datavec.api.writable.Text; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.junit.jupiter.api.BeforeEach; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import java.io.File; | ||||
| @ -34,8 +35,11 @@ import java.util.List; | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| import org.junit.jupiter.api.DisplayName; | ||||
| import org.junit.jupiter.api.extension.ExtendWith; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| @DisplayName("Csv Record Writer Test") | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| class CSVRecordWriterTest extends BaseND4JTest { | ||||
| 
 | ||||
|     @BeforeEach | ||||
|  | ||||
| @ -29,8 +29,10 @@ import org.datavec.api.writable.DoubleWritable; | ||||
| import org.datavec.api.writable.IntWritable; | ||||
| import org.datavec.api.writable.NDArrayWritable; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| import org.nd4j.common.io.ClassPathResource; | ||||
| @ -46,6 +48,8 @@ import org.junit.jupiter.api.extension.ExtendWith; | ||||
| import static org.junit.jupiter.api.Assertions.assertThrows; | ||||
| 
 | ||||
| @DisplayName("Lib Svm Record Writer Test") | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| class LibSvmRecordWriterTest extends BaseND4JTest { | ||||
| 
 | ||||
|     @Test | ||||
|  | ||||
| @ -26,8 +26,10 @@ import org.datavec.api.records.writer.impl.misc.SVMLightRecordWriter; | ||||
| import org.datavec.api.split.FileSplit; | ||||
| import org.datavec.api.split.partition.NumberOfRecordsPartitioner; | ||||
| import org.datavec.api.writable.*; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| import org.nd4j.common.io.ClassPathResource; | ||||
| @ -43,6 +45,8 @@ import org.junit.jupiter.api.extension.ExtendWith; | ||||
| import static org.junit.jupiter.api.Assertions.assertThrows; | ||||
| 
 | ||||
| @DisplayName("Svm Light Record Writer Test") | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| class SVMLightRecordWriterTest extends BaseND4JTest { | ||||
| 
 | ||||
|     @Test | ||||
|  | ||||
| @ -20,7 +20,9 @@ | ||||
| 
 | ||||
| package org.datavec.api.split; | ||||
| 
 | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| import org.nd4j.shade.guava.io.Files; | ||||
| import org.datavec.api.io.filters.BalancedPathFilter; | ||||
| import org.datavec.api.io.filters.RandomPathFilter; | ||||
| @ -42,6 +44,8 @@ import static org.junit.jupiter.api.Assertions.assertTrue; | ||||
|  * | ||||
|  * @author saudet | ||||
|  */ | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| public class InputSplitTests extends BaseND4JTest { | ||||
| 
 | ||||
|     @Test | ||||
|  | ||||
| @ -20,13 +20,16 @@ | ||||
| 
 | ||||
| package org.datavec.api.split; | ||||
| 
 | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import java.net.URI; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.*; | ||||
| 
 | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| public class NumberedFileInputSplitTests  extends BaseND4JTest { | ||||
|     @Test | ||||
|     public void testNumberedFileInputSplitBasic() { | ||||
|  | ||||
| @ -26,11 +26,13 @@ import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader; | ||||
| import org.datavec.api.writable.Text; | ||||
| import org.datavec.api.writable.Writable; | ||||
| 
 | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| 
 | ||||
| import org.junit.jupiter.api.io.TempDir; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.common.function.Function; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import java.io.File; | ||||
| import java.io.FileInputStream; | ||||
| @ -46,7 +48,8 @@ import java.util.Random; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| import static org.junit.jupiter.api.Assertions.assertNotEquals; | ||||
| 
 | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| public class TestStreamInputSplit extends BaseND4JTest { | ||||
| 
 | ||||
| 
 | ||||
|  | ||||
| @ -19,6 +19,7 @@ | ||||
|  */ | ||||
| package org.datavec.api.split; | ||||
| 
 | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import java.net.URI; | ||||
| @ -28,11 +29,14 @@ import static java.util.Arrays.asList; | ||||
| import static org.junit.jupiter.api.Assertions.assertArrayEquals; | ||||
| import org.junit.jupiter.api.DisplayName; | ||||
| import org.junit.jupiter.api.extension.ExtendWith; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| /** | ||||
|  * @author Ede Meijer | ||||
|  */ | ||||
| @DisplayName("Transform Split Test") | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| class TransformSplitTest extends BaseND4JTest { | ||||
| 
 | ||||
|     @Test | ||||
|  | ||||
| @ -20,7 +20,9 @@ | ||||
| 
 | ||||
| package org.datavec.api.split.parittion; | ||||
| 
 | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| import org.nd4j.shade.guava.io.Files; | ||||
| import org.datavec.api.conf.Configuration; | ||||
| import org.datavec.api.split.FileSplit; | ||||
| @ -33,7 +35,8 @@ import java.io.File; | ||||
| import java.io.OutputStream; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.*; | ||||
| 
 | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| public class PartitionerTests extends BaseND4JTest { | ||||
|     @Test | ||||
|     public void testRecordsPerFilePartition() { | ||||
|  | ||||
| @ -29,13 +29,16 @@ import org.datavec.api.writable.DoubleWritable; | ||||
| import org.datavec.api.writable.IntWritable; | ||||
| import org.datavec.api.writable.Text; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import java.util.*; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| 
 | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| public class TestTransformProcess extends BaseND4JTest { | ||||
| 
 | ||||
|     @Test | ||||
|  | ||||
| @ -27,14 +27,17 @@ import org.datavec.api.transform.condition.string.StringRegexColumnCondition; | ||||
| import org.datavec.api.transform.schema.Schema; | ||||
| import org.datavec.api.transform.transform.TestTransforms; | ||||
| import org.datavec.api.writable.*; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import java.util.*; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertFalse; | ||||
| import static org.junit.jupiter.api.Assertions.assertTrue; | ||||
| 
 | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| public class TestConditions extends BaseND4JTest { | ||||
| 
 | ||||
|     @Test | ||||
|  | ||||
| @ -27,8 +27,10 @@ import org.datavec.api.transform.schema.Schema; | ||||
| import org.datavec.api.writable.DoubleWritable; | ||||
| import org.datavec.api.writable.IntWritable; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import java.util.ArrayList; | ||||
| import java.util.Arrays; | ||||
| @ -38,7 +40,8 @@ import java.util.List; | ||||
| import static java.util.Arrays.asList; | ||||
| import static org.junit.jupiter.api.Assertions.assertFalse; | ||||
| import static org.junit.jupiter.api.Assertions.assertTrue; | ||||
| 
 | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| public class TestFilters  extends BaseND4JTest { | ||||
| 
 | ||||
| 
 | ||||
|  | ||||
| @ -26,9 +26,11 @@ import org.datavec.api.writable.IntWritable; | ||||
| import org.datavec.api.writable.NullWritable; | ||||
| import org.datavec.api.writable.Text; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.junit.jupiter.api.io.TempDir; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import java.nio.file.Path; | ||||
| import java.util.ArrayList; | ||||
| @ -37,7 +39,8 @@ import java.util.List; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| import static org.junit.jupiter.api.Assertions.assertThrows; | ||||
| 
 | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| public class TestJoin extends BaseND4JTest { | ||||
| 
 | ||||
|     @Test | ||||
|  | ||||
| @ -33,14 +33,17 @@ import org.datavec.api.transform.ops.IAggregableReduceOp; | ||||
| import org.datavec.api.transform.schema.Schema; | ||||
| import org.datavec.api.writable.*; | ||||
| import org.junit.jupiter.api.Disabled; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import java.util.*; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| import static org.junit.jupiter.api.Assertions.fail; | ||||
| 
 | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| public class TestMultiOpReduce extends BaseND4JTest { | ||||
| 
 | ||||
|     @Test | ||||
|  | ||||
| @ -24,14 +24,17 @@ import org.datavec.api.transform.ops.IAggregableReduceOp; | ||||
| import org.datavec.api.transform.reduce.impl.GeographicMidpointReduction; | ||||
| import org.datavec.api.writable.Text; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import java.util.Arrays; | ||||
| import java.util.List; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| 
 | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| public class TestReductions extends BaseND4JTest { | ||||
| 
 | ||||
|     @Test | ||||
|  | ||||
| @ -22,11 +22,15 @@ package org.datavec.api.transform.schema; | ||||
| 
 | ||||
| import org.datavec.api.transform.metadata.ColumnMetaData; | ||||
| import org.joda.time.DateTimeZone; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| 
 | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| @Tag(TagNames.JACKSON_SERDE) | ||||
| public class TestJsonYaml extends BaseND4JTest { | ||||
| 
 | ||||
|     @Test | ||||
|  | ||||
| @ -21,11 +21,14 @@ | ||||
| package org.datavec.api.transform.schema; | ||||
| 
 | ||||
| import org.datavec.api.transform.ColumnType; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| 
 | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| public class TestSchemaMethods extends BaseND4JTest { | ||||
| 
 | ||||
|     @Test | ||||
|  | ||||
| @ -33,8 +33,10 @@ import org.datavec.api.writable.LongWritable; | ||||
| import org.datavec.api.writable.NullWritable; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.joda.time.DateTimeZone; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import java.util.ArrayList; | ||||
| import java.util.Arrays; | ||||
| @ -42,7 +44,8 @@ import java.util.List; | ||||
| import java.util.concurrent.TimeUnit; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| 
 | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| public class TestReduceSequenceByWindowFunction extends BaseND4JTest { | ||||
| 
 | ||||
|     @Test | ||||
|  | ||||
| @ -27,8 +27,10 @@ import org.datavec.api.writable.LongWritable; | ||||
| import org.datavec.api.writable.Text; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.joda.time.DateTimeZone; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import java.util.ArrayList; | ||||
| import java.util.Arrays; | ||||
| @ -36,7 +38,8 @@ import java.util.List; | ||||
| import java.util.concurrent.TimeUnit; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| 
 | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| public class TestSequenceSplit extends BaseND4JTest { | ||||
| 
 | ||||
|     @Test | ||||
| @ -46,13 +49,13 @@ public class TestSequenceSplit extends BaseND4JTest { | ||||
|                         .build(); | ||||
| 
 | ||||
|         List<List<Writable>> inputSequence = new ArrayList<>(); | ||||
|         inputSequence.add(Arrays.asList((Writable) new LongWritable(0), new Text("t0"))); | ||||
|         inputSequence.add(Arrays.asList((Writable) new LongWritable(1000), new Text("t1"))); | ||||
|         inputSequence.add(Arrays.asList(new LongWritable(0), new Text("t0"))); | ||||
|         inputSequence.add(Arrays.asList(new LongWritable(1000), new Text("t1"))); | ||||
|         //Second split: 74 seconds later | ||||
|         inputSequence.add(Arrays.asList((Writable) new LongWritable(75000), new Text("t2"))); | ||||
|         inputSequence.add(Arrays.asList((Writable) new LongWritable(100000), new Text("t3"))); | ||||
|         inputSequence.add(Arrays.asList(new LongWritable(75000), new Text("t2"))); | ||||
|         inputSequence.add(Arrays.asList(new LongWritable(100000), new Text("t3"))); | ||||
|         //Third split: 1 minute and 1 milliseconds later | ||||
|         inputSequence.add(Arrays.asList((Writable) new LongWritable(160001), new Text("t4"))); | ||||
|         inputSequence.add(Arrays.asList(new LongWritable(160001), new Text("t4"))); | ||||
| 
 | ||||
|         SequenceSplit seqSplit = new SequenceSplitTimeSeparation("time", 1, TimeUnit.MINUTES); | ||||
|         seqSplit.setInputSchema(schema); | ||||
| @ -61,13 +64,13 @@ public class TestSequenceSplit extends BaseND4JTest { | ||||
|         assertEquals(3, splits.size()); | ||||
| 
 | ||||
|         List<List<Writable>> exp0 = new ArrayList<>(); | ||||
|         exp0.add(Arrays.asList((Writable) new LongWritable(0), new Text("t0"))); | ||||
|         exp0.add(Arrays.asList((Writable) new LongWritable(1000), new Text("t1"))); | ||||
|         exp0.add(Arrays.asList(new LongWritable(0), new Text("t0"))); | ||||
|         exp0.add(Arrays.asList(new LongWritable(1000), new Text("t1"))); | ||||
|         List<List<Writable>> exp1 = new ArrayList<>(); | ||||
|         exp1.add(Arrays.asList((Writable) new LongWritable(75000), new Text("t2"))); | ||||
|         exp1.add(Arrays.asList((Writable) new LongWritable(100000), new Text("t3"))); | ||||
|         exp1.add(Arrays.asList(new LongWritable(75000), new Text("t2"))); | ||||
|         exp1.add(Arrays.asList(new LongWritable(100000), new Text("t3"))); | ||||
|         List<List<Writable>> exp2 = new ArrayList<>(); | ||||
|         exp2.add(Arrays.asList((Writable) new LongWritable(160001), new Text("t4"))); | ||||
|         exp2.add(Arrays.asList(new LongWritable(160001), new Text("t4"))); | ||||
| 
 | ||||
|         assertEquals(exp0, splits.get(0)); | ||||
|         assertEquals(exp1, splits.get(1)); | ||||
|  | ||||
| @ -29,8 +29,10 @@ import org.datavec.api.writable.IntWritable; | ||||
| import org.datavec.api.writable.LongWritable; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.joda.time.DateTimeZone; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import java.util.ArrayList; | ||||
| import java.util.Arrays; | ||||
| @ -38,7 +40,8 @@ import java.util.List; | ||||
| import java.util.concurrent.TimeUnit; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| 
 | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| public class TestWindowFunctions extends BaseND4JTest { | ||||
| 
 | ||||
|     @Test | ||||
| @ -49,15 +52,15 @@ public class TestWindowFunctions extends BaseND4JTest { | ||||
|         //Create some data. | ||||
|         List<List<Writable>> sequence = new ArrayList<>(); | ||||
|         //First window: | ||||
|         sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L), new IntWritable(0))); | ||||
|         sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 100L), new IntWritable(1))); | ||||
|         sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 200L), new IntWritable(2))); | ||||
|         sequence.add(Arrays.asList(new LongWritable(1451606400000L), new IntWritable(0))); | ||||
|         sequence.add(Arrays.asList(new LongWritable(1451606400000L + 100L), new IntWritable(1))); | ||||
|         sequence.add(Arrays.asList(new LongWritable(1451606400000L + 200L), new IntWritable(2))); | ||||
|         //Second window: | ||||
|         sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 1000L), new IntWritable(3))); | ||||
|         sequence.add(Arrays.asList(new LongWritable(1451606400000L + 1000L), new IntWritable(3))); | ||||
|         //Third window: empty | ||||
|         //Fourth window: | ||||
|         sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 3000L), new IntWritable(4))); | ||||
|         sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 3100L), new IntWritable(5))); | ||||
|         sequence.add(Arrays.asList(new LongWritable(1451606400000L + 3000L), new IntWritable(4))); | ||||
|         sequence.add(Arrays.asList(new LongWritable(1451606400000L + 3100L), new IntWritable(5))); | ||||
| 
 | ||||
|         Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC) | ||||
|                         .addColumnInteger("intcolumn").build(); | ||||
| @ -100,15 +103,15 @@ public class TestWindowFunctions extends BaseND4JTest { | ||||
|         //Create some data. | ||||
|         List<List<Writable>> sequence = new ArrayList<>(); | ||||
|         //First window: | ||||
|         sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L), new IntWritable(0))); | ||||
|         sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 100L), new IntWritable(1))); | ||||
|         sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 200L), new IntWritable(2))); | ||||
|         sequence.add(Arrays.asList(new LongWritable(1451606400000L), new IntWritable(0))); | ||||
|         sequence.add(Arrays.asList(new LongWritable(1451606400000L + 100L), new IntWritable(1))); | ||||
|         sequence.add(Arrays.asList(new LongWritable(1451606400000L + 200L), new IntWritable(2))); | ||||
|         //Second window: | ||||
|         sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 1000L), new IntWritable(3))); | ||||
|         sequence.add(Arrays.asList(new LongWritable(1451606400000L + 1000L), new IntWritable(3))); | ||||
|         //Third window: empty | ||||
|         //Fourth window: | ||||
|         sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 3000L), new IntWritable(4))); | ||||
|         sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 3100L), new IntWritable(5))); | ||||
|         sequence.add(Arrays.asList(new LongWritable(1451606400000L + 3000L), new IntWritable(4))); | ||||
|         sequence.add(Arrays.asList(new LongWritable(1451606400000L + 3100L), new IntWritable(5))); | ||||
| 
 | ||||
|         Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC) | ||||
|                         .addColumnInteger("intcolumn").build(); | ||||
| @ -150,15 +153,15 @@ public class TestWindowFunctions extends BaseND4JTest { | ||||
|         //Create some data. | ||||
|         List<List<Writable>> sequence = new ArrayList<>(); | ||||
|         //First window: | ||||
|         sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L), new IntWritable(0))); | ||||
|         sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 100L), new IntWritable(1))); | ||||
|         sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 200L), new IntWritable(2))); | ||||
|         sequence.add(Arrays.asList(new LongWritable(1451606400000L), new IntWritable(0))); | ||||
|         sequence.add(Arrays.asList(new LongWritable(1451606400000L + 100L), new IntWritable(1))); | ||||
|         sequence.add(Arrays.asList(new LongWritable(1451606400000L + 200L), new IntWritable(2))); | ||||
|         //Second window: | ||||
|         sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 1000L), new IntWritable(3))); | ||||
|         sequence.add(Arrays.asList(new LongWritable(1451606400000L + 1000L), new IntWritable(3))); | ||||
|         //Third window: empty | ||||
|         //Fourth window: | ||||
|         sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 3000L), new IntWritable(4))); | ||||
|         sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 3100L), new IntWritable(5))); | ||||
|         sequence.add(Arrays.asList(new LongWritable(1451606400000L + 3000L), new IntWritable(4))); | ||||
|         sequence.add(Arrays.asList(new LongWritable(1451606400000L + 3100L), new IntWritable(5))); | ||||
| 
 | ||||
|         Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC) | ||||
|                         .addColumnInteger("intcolumn").build(); | ||||
| @ -188,13 +191,13 @@ public class TestWindowFunctions extends BaseND4JTest { | ||||
|         //Create some data. | ||||
|         List<List<Writable>> sequence = new ArrayList<>(); | ||||
|         //First window: | ||||
|         sequence.add(Arrays.asList((Writable) new LongWritable(0), new IntWritable(0))); | ||||
|         sequence.add(Arrays.asList((Writable) new LongWritable(100), new IntWritable(1))); | ||||
|         sequence.add(Arrays.asList((Writable) new LongWritable(200), new IntWritable(2))); | ||||
|         sequence.add(Arrays.asList((Writable) new LongWritable(1000), new IntWritable(3))); | ||||
|         sequence.add(Arrays.asList((Writable) new LongWritable(1500), new IntWritable(4))); | ||||
|         sequence.add(Arrays.asList((Writable) new LongWritable(2000), new IntWritable(5))); | ||||
|         sequence.add(Arrays.asList((Writable) new LongWritable(5000), new IntWritable(7))); | ||||
|         sequence.add(Arrays.asList(new LongWritable(0), new IntWritable(0))); | ||||
|         sequence.add(Arrays.asList(new LongWritable(100), new IntWritable(1))); | ||||
|         sequence.add(Arrays.asList(new LongWritable(200), new IntWritable(2))); | ||||
|         sequence.add(Arrays.asList(new LongWritable(1000), new IntWritable(3))); | ||||
|         sequence.add(Arrays.asList(new LongWritable(1500), new IntWritable(4))); | ||||
|         sequence.add(Arrays.asList(new LongWritable(2000), new IntWritable(5))); | ||||
|         sequence.add(Arrays.asList(new LongWritable(5000), new IntWritable(7))); | ||||
| 
 | ||||
| 
 | ||||
|         Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC) | ||||
| @ -207,32 +210,32 @@ public class TestWindowFunctions extends BaseND4JTest { | ||||
| 
 | ||||
|         //First window: -1000 to 1000 | ||||
|         List<List<Writable>> exp0 = new ArrayList<>(); | ||||
|         exp0.add(Arrays.asList((Writable) new LongWritable(0), new IntWritable(0))); | ||||
|         exp0.add(Arrays.asList((Writable) new LongWritable(100), new IntWritable(1))); | ||||
|         exp0.add(Arrays.asList((Writable) new LongWritable(200), new IntWritable(2))); | ||||
|         exp0.add(Arrays.asList(new LongWritable(0), new IntWritable(0))); | ||||
|         exp0.add(Arrays.asList(new LongWritable(100), new IntWritable(1))); | ||||
|         exp0.add(Arrays.asList(new LongWritable(200), new IntWritable(2))); | ||||
|         //Second window: 0 to 2000 | ||||
|         List<List<Writable>> exp1 = new ArrayList<>(); | ||||
|         exp1.add(Arrays.asList((Writable) new LongWritable(0), new IntWritable(0))); | ||||
|         exp1.add(Arrays.asList((Writable) new LongWritable(100), new IntWritable(1))); | ||||
|         exp1.add(Arrays.asList((Writable) new LongWritable(200), new IntWritable(2))); | ||||
|         exp1.add(Arrays.asList((Writable) new LongWritable(1000), new IntWritable(3))); | ||||
|         exp1.add(Arrays.asList((Writable) new LongWritable(1500), new IntWritable(4))); | ||||
|         exp1.add(Arrays.asList(new LongWritable(0), new IntWritable(0))); | ||||
|         exp1.add(Arrays.asList(new LongWritable(100), new IntWritable(1))); | ||||
|         exp1.add(Arrays.asList(new LongWritable(200), new IntWritable(2))); | ||||
|         exp1.add(Arrays.asList(new LongWritable(1000), new IntWritable(3))); | ||||
|         exp1.add(Arrays.asList(new LongWritable(1500), new IntWritable(4))); | ||||
|         //Third window: 1000 to 3000 | ||||
|         List<List<Writable>> exp2 = new ArrayList<>(); | ||||
|         exp2.add(Arrays.asList((Writable) new LongWritable(1000), new IntWritable(3))); | ||||
|         exp2.add(Arrays.asList((Writable) new LongWritable(1500), new IntWritable(4))); | ||||
|         exp2.add(Arrays.asList((Writable) new LongWritable(2000), new IntWritable(5))); | ||||
|         exp2.add(Arrays.asList(new LongWritable(1000), new IntWritable(3))); | ||||
|         exp2.add(Arrays.asList(new LongWritable(1500), new IntWritable(4))); | ||||
|         exp2.add(Arrays.asList(new LongWritable(2000), new IntWritable(5))); | ||||
|         //Fourth window: 2000 to 4000 | ||||
|         List<List<Writable>> exp3 = new ArrayList<>(); | ||||
|         exp3.add(Arrays.asList((Writable) new LongWritable(2000), new IntWritable(5))); | ||||
|         exp3.add(Arrays.asList(new LongWritable(2000), new IntWritable(5))); | ||||
|         //Fifth window: 3000 to 5000 | ||||
|         List<List<Writable>> exp4 = new ArrayList<>(); | ||||
|         //Sixth window: 4000 to 6000 | ||||
|         List<List<Writable>> exp5 = new ArrayList<>(); | ||||
|         exp5.add(Arrays.asList((Writable) new LongWritable(5000), new IntWritable(7))); | ||||
|         exp5.add(Arrays.asList(new LongWritable(5000), new IntWritable(7))); | ||||
|         //Seventh window: 5000 to 7000 | ||||
|         List<List<Writable>> exp6 = new ArrayList<>(); | ||||
|         exp6.add(Arrays.asList((Writable) new LongWritable(5000), new IntWritable(7))); | ||||
|         exp6.add(Arrays.asList(new LongWritable(5000), new IntWritable(7))); | ||||
| 
 | ||||
|         List<List<List<Writable>>> windowsExp = Arrays.asList(exp0, exp1, exp2, exp3, exp4, exp5, exp6); | ||||
| 
 | ||||
| @ -250,13 +253,13 @@ public class TestWindowFunctions extends BaseND4JTest { | ||||
|         //Create some data. | ||||
|         List<List<Writable>> sequence = new ArrayList<>(); | ||||
|         //First window: | ||||
|         sequence.add(Arrays.asList((Writable) new LongWritable(0), new IntWritable(0))); | ||||
|         sequence.add(Arrays.asList((Writable) new LongWritable(100), new IntWritable(1))); | ||||
|         sequence.add(Arrays.asList((Writable) new LongWritable(200), new IntWritable(2))); | ||||
|         sequence.add(Arrays.asList((Writable) new LongWritable(1000), new IntWritable(3))); | ||||
|         sequence.add(Arrays.asList((Writable) new LongWritable(1500), new IntWritable(4))); | ||||
|         sequence.add(Arrays.asList((Writable) new LongWritable(2000), new IntWritable(5))); | ||||
|         sequence.add(Arrays.asList((Writable) new LongWritable(5000), new IntWritable(7))); | ||||
|         sequence.add(Arrays.asList(new LongWritable(0), new IntWritable(0))); | ||||
|         sequence.add(Arrays.asList(new LongWritable(100), new IntWritable(1))); | ||||
|         sequence.add(Arrays.asList(new LongWritable(200), new IntWritable(2))); | ||||
|         sequence.add(Arrays.asList(new LongWritable(1000), new IntWritable(3))); | ||||
|         sequence.add(Arrays.asList(new LongWritable(1500), new IntWritable(4))); | ||||
|         sequence.add(Arrays.asList(new LongWritable(2000), new IntWritable(5))); | ||||
|         sequence.add(Arrays.asList(new LongWritable(5000), new IntWritable(7))); | ||||
| 
 | ||||
| 
 | ||||
|         Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC) | ||||
| @ -272,31 +275,31 @@ public class TestWindowFunctions extends BaseND4JTest { | ||||
| 
 | ||||
|         //First window: -1000 to 1000 | ||||
|         List<List<Writable>> exp0 = new ArrayList<>(); | ||||
|         exp0.add(Arrays.asList((Writable) new LongWritable(0), new IntWritable(0))); | ||||
|         exp0.add(Arrays.asList((Writable) new LongWritable(100), new IntWritable(1))); | ||||
|         exp0.add(Arrays.asList((Writable) new LongWritable(200), new IntWritable(2))); | ||||
|         exp0.add(Arrays.asList(new LongWritable(0), new IntWritable(0))); | ||||
|         exp0.add(Arrays.asList(new LongWritable(100), new IntWritable(1))); | ||||
|         exp0.add(Arrays.asList(new LongWritable(200), new IntWritable(2))); | ||||
|         //Second window: 0 to 2000 | ||||
|         List<List<Writable>> exp1 = new ArrayList<>(); | ||||
|         exp1.add(Arrays.asList((Writable) new LongWritable(0), new IntWritable(0))); | ||||
|         exp1.add(Arrays.asList((Writable) new LongWritable(100), new IntWritable(1))); | ||||
|         exp1.add(Arrays.asList((Writable) new LongWritable(200), new IntWritable(2))); | ||||
|         exp1.add(Arrays.asList((Writable) new LongWritable(1000), new IntWritable(3))); | ||||
|         exp1.add(Arrays.asList((Writable) new LongWritable(1500), new IntWritable(4))); | ||||
|         exp1.add(Arrays.asList(new LongWritable(0), new IntWritable(0))); | ||||
|         exp1.add(Arrays.asList(new LongWritable(100), new IntWritable(1))); | ||||
|         exp1.add(Arrays.asList(new LongWritable(200), new IntWritable(2))); | ||||
|         exp1.add(Arrays.asList(new LongWritable(1000), new IntWritable(3))); | ||||
|         exp1.add(Arrays.asList(new LongWritable(1500), new IntWritable(4))); | ||||
|         //Third window: 1000 to 3000 | ||||
|         List<List<Writable>> exp2 = new ArrayList<>(); | ||||
|         exp2.add(Arrays.asList((Writable) new LongWritable(1000), new IntWritable(3))); | ||||
|         exp2.add(Arrays.asList((Writable) new LongWritable(1500), new IntWritable(4))); | ||||
|         exp2.add(Arrays.asList((Writable) new LongWritable(2000), new IntWritable(5))); | ||||
|         exp2.add(Arrays.asList(new LongWritable(1000), new IntWritable(3))); | ||||
|         exp2.add(Arrays.asList(new LongWritable(1500), new IntWritable(4))); | ||||
|         exp2.add(Arrays.asList(new LongWritable(2000), new IntWritable(5))); | ||||
|         //Fourth window: 2000 to 4000 | ||||
|         List<List<Writable>> exp3 = new ArrayList<>(); | ||||
|         exp3.add(Arrays.asList((Writable) new LongWritable(2000), new IntWritable(5))); | ||||
|         exp3.add(Arrays.asList(new LongWritable(2000), new IntWritable(5))); | ||||
|         //Fifth window: 3000 to 5000 -> Empty: excluded | ||||
|         //Sixth window: 4000 to 6000 | ||||
|         List<List<Writable>> exp5 = new ArrayList<>(); | ||||
|         exp5.add(Arrays.asList((Writable) new LongWritable(5000), new IntWritable(7))); | ||||
|         exp5.add(Arrays.asList(new LongWritable(5000), new IntWritable(7))); | ||||
|         //Seventh window: 5000 to 7000 | ||||
|         List<List<Writable>> exp6 = new ArrayList<>(); | ||||
|         exp6.add(Arrays.asList((Writable) new LongWritable(5000), new IntWritable(7))); | ||||
|         exp6.add(Arrays.asList(new LongWritable(5000), new IntWritable(7))); | ||||
| 
 | ||||
|         List<List<List<Writable>>> windowsExp = Arrays.asList(exp0, exp1, exp2, exp3, exp5, exp6); | ||||
| 
 | ||||
|  | ||||
| @ -26,11 +26,17 @@ import org.datavec.api.transform.schema.Schema; | ||||
| import org.datavec.api.transform.serde.testClasses.CustomCondition; | ||||
| import org.datavec.api.transform.serde.testClasses.CustomFilter; | ||||
| import org.datavec.api.transform.serde.testClasses.CustomTransform; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| 
 | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| @Tag(TagNames.JACKSON_SERDE) | ||||
| @Tag(TagNames.CUSTOM_FUNCTIONALITY) | ||||
| public class TestCustomTransformJsonYaml extends BaseND4JTest { | ||||
| 
 | ||||
|     @Test | ||||
|  | ||||
| @ -64,14 +64,19 @@ import org.datavec.api.transform.transform.time.TimeMathOpTransform; | ||||
| import org.datavec.api.writable.comparator.DoubleWritableComparator; | ||||
| import org.joda.time.DateTimeFieldType; | ||||
| import org.joda.time.DateTimeZone; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import java.util.*; | ||||
| import java.util.concurrent.TimeUnit; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| 
 | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| @Tag(TagNames.JACKSON_SERDE) | ||||
| public class TestYamlJsonSerde  extends BaseND4JTest { | ||||
| 
 | ||||
|     public static YamlSerializer y = new YamlSerializer(); | ||||
|  | ||||
| @ -24,22 +24,26 @@ import org.datavec.api.transform.StringReduceOp; | ||||
| import org.datavec.api.transform.schema.Schema; | ||||
| import org.datavec.api.writable.Text; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import java.util.*; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| 
 | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| public class TestReduce extends BaseND4JTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testReducerDouble() { | ||||
| 
 | ||||
|         List<List<Writable>> inputs = new ArrayList<>(); | ||||
|         inputs.add(Arrays.asList((Writable) new Text("1"), new Text("2"))); | ||||
|         inputs.add(Arrays.asList((Writable) new Text("1"), new Text("2"))); | ||||
|         inputs.add(Arrays.asList((Writable) new Text("1"), new Text("2"))); | ||||
|         inputs.add(Arrays.asList(new Text("1"), new Text("2"))); | ||||
|         inputs.add(Arrays.asList(new Text("1"), new Text("2"))); | ||||
|         inputs.add(Arrays.asList(new Text("1"), new Text("2"))); | ||||
| 
 | ||||
|         Map<StringReduceOp, String> exp = new LinkedHashMap<>(); | ||||
|         exp.put(StringReduceOp.MERGE, "12"); | ||||
|  | ||||
| @ -37,10 +37,12 @@ import org.datavec.api.writable.Writable; | ||||
| import org.joda.time.DateTimeZone; | ||||
| import org.junit.jupiter.api.Disabled; | ||||
| 
 | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| 
 | ||||
| import org.junit.jupiter.api.io.TempDir; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import java.io.File; | ||||
| import java.nio.file.Path; | ||||
| @ -49,7 +51,9 @@ import java.util.Arrays; | ||||
| import java.util.List; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| 
 | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| @Tag(TagNames.UI) | ||||
| public class TestUI extends BaseND4JTest { | ||||
| 
 | ||||
| 
 | ||||
|  | ||||
| @ -20,6 +20,7 @@ | ||||
| package org.datavec.api.util; | ||||
| 
 | ||||
| import org.junit.jupiter.api.BeforeEach; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import java.io.BufferedReader; | ||||
| @ -33,8 +34,11 @@ import static org.hamcrest.core.AnyOf.anyOf; | ||||
| import static org.hamcrest.core.IsEqual.equalTo; | ||||
| import org.junit.jupiter.api.DisplayName; | ||||
| import org.junit.jupiter.api.extension.ExtendWith; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| @DisplayName("Class Path Resource Test") | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| class ClassPathResourceTest extends BaseND4JTest { | ||||
| 
 | ||||
|     // File sizes are reported slightly different on Linux vs. Windows | ||||
|  | ||||
| @ -22,8 +22,10 @@ package org.datavec.api.util; | ||||
| import org.datavec.api.timeseries.util.TimeSeriesWritableUtils; | ||||
| import org.datavec.api.writable.DoubleWritable; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import java.util.ArrayList; | ||||
| import java.util.List; | ||||
| @ -32,6 +34,8 @@ import org.junit.jupiter.api.DisplayName; | ||||
| import org.junit.jupiter.api.extension.ExtendWith; | ||||
| 
 | ||||
| @DisplayName("Time Series Utils Test") | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| class TimeSeriesUtilsTest extends BaseND4JTest { | ||||
| 
 | ||||
|     @Test | ||||
|  | ||||
| @ -19,7 +19,9 @@ | ||||
|  */ | ||||
| package org.datavec.api.writable; | ||||
| 
 | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| import org.nd4j.shade.guava.collect.Lists; | ||||
| import org.datavec.api.transform.schema.Schema; | ||||
| import org.datavec.api.util.ndarray.RecordConverter; | ||||
| @ -36,6 +38,8 @@ import org.junit.jupiter.api.DisplayName; | ||||
| import org.junit.jupiter.api.extension.ExtendWith; | ||||
| 
 | ||||
| @DisplayName("Record Converter Test") | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| class RecordConverterTest extends BaseND4JTest { | ||||
| 
 | ||||
|     @Test | ||||
|  | ||||
| @ -21,15 +21,18 @@ | ||||
| package org.datavec.api.writable; | ||||
| 
 | ||||
| import org.datavec.api.transform.metadata.NDArrayMetaData; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| 
 | ||||
| import java.io.*; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.*; | ||||
| 
 | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| public class TestNDArrayWritableAndSerialization extends BaseND4JTest { | ||||
| 
 | ||||
|     @Test | ||||
|  | ||||
| @ -20,8 +20,10 @@ | ||||
| package org.datavec.api.writable; | ||||
| 
 | ||||
| import org.datavec.api.writable.batch.NDArrayRecordBatch; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| import org.nd4j.linalg.api.buffer.DataBuffer; | ||||
| import org.nd4j.linalg.api.buffer.DataType; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| @ -37,6 +39,8 @@ import org.junit.jupiter.api.DisplayName; | ||||
| import static org.junit.jupiter.api.Assertions.*; | ||||
| 
 | ||||
| @DisplayName("Writable Test") | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| class WritableTest extends BaseND4JTest { | ||||
| 
 | ||||
|     @Test | ||||
|  | ||||
| @ -42,9 +42,11 @@ import org.datavec.api.writable.*; | ||||
| import org.datavec.arrow.recordreader.ArrowRecordReader; | ||||
| import org.datavec.arrow.recordreader.ArrowWritableRecordBatch; | ||||
| 
 | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.junit.jupiter.api.io.TempDir; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| import org.nd4j.common.primitives.Pair; | ||||
| @ -62,6 +64,8 @@ import org.junit.jupiter.api.extension.ExtendWith; | ||||
| 
 | ||||
| @Slf4j | ||||
| @DisplayName("Arrow Converter Test") | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| class ArrowConverterTest extends BaseND4JTest { | ||||
| 
 | ||||
|     private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE); | ||||
| @ -142,8 +146,8 @@ class ArrowConverterTest extends BaseND4JTest { | ||||
|         List<FieldVector> fieldVectorsBatch = ArrowConverter.toArrowColumnsString(bufferAllocator, schema.build(), batch); | ||||
|         List<List<Writable>> batchRecords = ArrowConverter.toArrowWritables(fieldVectorsBatch, schema.build()); | ||||
|         List<List<Writable>> assertionBatch = new ArrayList<>(); | ||||
|         assertionBatch.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(0))); | ||||
|         assertionBatch.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1))); | ||||
|         assertionBatch.add(Arrays.asList(new IntWritable(0), new IntWritable(0))); | ||||
|         assertionBatch.add(Arrays.asList(new IntWritable(1), new IntWritable(1))); | ||||
|         assertEquals(assertionBatch, batchRecords); | ||||
|     } | ||||
| 
 | ||||
| @ -156,11 +160,11 @@ class ArrowConverterTest extends BaseND4JTest { | ||||
|             schema.addColumnTime(String.valueOf(i), TimeZone.getDefault()); | ||||
|             single.add(String.valueOf(i)); | ||||
|         } | ||||
|         List<List<Writable>> input = Arrays.asList(Arrays.<Writable>asList(new LongWritable(0), new LongWritable(1)), Arrays.<Writable>asList(new LongWritable(2), new LongWritable(3))); | ||||
|         List<List<Writable>> input = Arrays.asList(Arrays.asList(new LongWritable(0), new LongWritable(1)), Arrays.<Writable>asList(new LongWritable(2), new LongWritable(3))); | ||||
|         List<FieldVector> fieldVector = ArrowConverter.toArrowColumns(bufferAllocator, schema.build(), input); | ||||
|         ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector, schema.build()); | ||||
|         List<Writable> assertion = Arrays.<Writable>asList(new LongWritable(4), new LongWritable(5)); | ||||
|         writableRecordBatch.set(1, Arrays.<Writable>asList(new LongWritable(4), new LongWritable(5))); | ||||
|         List<Writable> assertion = Arrays.asList(new LongWritable(4), new LongWritable(5)); | ||||
|         writableRecordBatch.set(1, Arrays.asList(new LongWritable(4), new LongWritable(5))); | ||||
|         List<Writable> recordTest = writableRecordBatch.get(1); | ||||
|         assertEquals(assertion, recordTest); | ||||
|     } | ||||
| @ -174,11 +178,11 @@ class ArrowConverterTest extends BaseND4JTest { | ||||
|             schema.addColumnInteger(String.valueOf(i)); | ||||
|             single.add(String.valueOf(i)); | ||||
|         } | ||||
|         List<List<Writable>> input = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(1)), Arrays.<Writable>asList(new IntWritable(2), new IntWritable(3))); | ||||
|         List<List<Writable>> input = Arrays.asList(Arrays.asList(new IntWritable(0), new IntWritable(1)), Arrays.<Writable>asList(new IntWritable(2), new IntWritable(3))); | ||||
|         List<FieldVector> fieldVector = ArrowConverter.toArrowColumns(bufferAllocator, schema.build(), input); | ||||
|         ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector, schema.build()); | ||||
|         List<Writable> assertion = Arrays.<Writable>asList(new IntWritable(4), new IntWritable(5)); | ||||
|         writableRecordBatch.set(1, Arrays.<Writable>asList(new IntWritable(4), new IntWritable(5))); | ||||
|         List<Writable> assertion = Arrays.asList(new IntWritable(4), new IntWritable(5)); | ||||
|         writableRecordBatch.set(1, Arrays.asList(new IntWritable(4), new IntWritable(5))); | ||||
|         List<Writable> recordTest = writableRecordBatch.get(1); | ||||
|         assertEquals(assertion, recordTest); | ||||
|     } | ||||
|  | ||||
| @ -33,6 +33,7 @@ import org.datavec.api.writable.IntWritable; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.datavec.arrow.recordreader.ArrowRecordReader; | ||||
| import org.datavec.arrow.recordreader.ArrowRecordWriter; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.common.primitives.Triple; | ||||
| @ -44,8 +45,11 @@ import java.util.List; | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| import org.junit.jupiter.api.DisplayName; | ||||
| import org.junit.jupiter.api.extension.ExtendWith; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| @DisplayName("Record Mapper Test") | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| class RecordMapperTest extends BaseND4JTest { | ||||
| 
 | ||||
|     @Test | ||||
|  | ||||
| @ -30,8 +30,10 @@ import org.datavec.api.writable.Text; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.datavec.arrow.ArrowConverter; | ||||
| import org.junit.jupiter.api.Disabled; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import java.util.ArrayList; | ||||
| import java.util.Arrays; | ||||
| @ -39,7 +41,8 @@ import java.util.List; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| import static org.junit.jupiter.api.Assertions.assertFalse; | ||||
| 
 | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest { | ||||
| 
 | ||||
|     private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE); | ||||
|  | ||||
| @ -25,6 +25,7 @@ import org.datavec.api.split.FileSplit; | ||||
| import org.datavec.image.recordreader.ImageRecordReader; | ||||
| 
 | ||||
| import org.junit.jupiter.api.Disabled; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.junit.jupiter.api.io.TempDir; | ||||
| import org.nd4j.common.io.ClassPathResource; | ||||
| @ -36,8 +37,12 @@ import static org.junit.jupiter.api.Assertions.assertTrue; | ||||
| import org.junit.jupiter.api.DisplayName; | ||||
| import java.nio.file.Path; | ||||
| import org.junit.jupiter.api.extension.ExtendWith; | ||||
| import org.nd4j.common.tests.tags.NativeTag; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| @DisplayName("Label Generator Test") | ||||
| @NativeTag | ||||
| @Tag(TagNames.FILE_IO) | ||||
| class LabelGeneratorTest { | ||||
| 
 | ||||
| 
 | ||||
|  | ||||
| @ -23,7 +23,10 @@ package org.datavec.image.loader; | ||||
| import org.apache.commons.io.FilenameUtils; | ||||
| import org.datavec.api.records.reader.RecordReader; | ||||
| import org.junit.jupiter.api.Disabled; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.tags.NativeTag; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| import org.nd4j.linalg.dataset.DataSet; | ||||
| 
 | ||||
| import java.io.File; | ||||
| @ -39,6 +42,8 @@ import static org.junit.jupiter.api.Assertions.assertTrue; | ||||
| /** | ||||
|  * | ||||
|  */ | ||||
| @NativeTag | ||||
| @Tag(TagNames.FILE_IO) | ||||
| public class LoaderTests { | ||||
| 
 | ||||
|     private static void ensureDataAvailable(){ | ||||
| @ -81,7 +86,7 @@ public class LoaderTests { | ||||
|         String subDir = "cifar/cifar-10-batches-bin/data_batch_1.bin"; | ||||
|         String path = FilenameUtils.concat(System.getProperty("user.home"), subDir); | ||||
|         byte[] fullDataExpected = new byte[3073]; | ||||
|         FileInputStream inExpected = new FileInputStream(new File(path)); | ||||
|         FileInputStream inExpected = new FileInputStream(path); | ||||
|         inExpected.read(fullDataExpected); | ||||
| 
 | ||||
|         byte[] fullDataActual = new byte[3073]; | ||||
| @ -94,7 +99,7 @@ public class LoaderTests { | ||||
|         subDir = "cifar/cifar-10-batches-bin/test_batch.bin"; | ||||
|         path = FilenameUtils.concat(System.getProperty("user.home"), subDir); | ||||
|         fullDataExpected = new byte[3073]; | ||||
|         inExpected = new FileInputStream(new File(path)); | ||||
|         inExpected = new FileInputStream(path); | ||||
|         inExpected.read(fullDataExpected); | ||||
| 
 | ||||
|         fullDataActual = new byte[3073]; | ||||
|  | ||||
| @ -21,8 +21,11 @@ | ||||
| package org.datavec.image.loader; | ||||
| 
 | ||||
| import org.datavec.image.data.Image; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.resources.Resources; | ||||
| import org.nd4j.common.tests.tags.NativeTag; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| 
 | ||||
| import java.awt.image.BufferedImage; | ||||
| @ -34,7 +37,8 @@ import java.util.Random; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| 
 | ||||
| 
 | ||||
| @NativeTag | ||||
| @Tag(TagNames.FILE_IO) | ||||
| public class TestImageLoader { | ||||
| 
 | ||||
|     private static long seed = 10; | ||||
|  | ||||
| @ -31,10 +31,13 @@ import org.bytedeco.javacv.OpenCVFrameConverter; | ||||
| import org.datavec.image.data.Image; | ||||
| import org.datavec.image.data.ImageWritable; | ||||
| 
 | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| 
 | ||||
| import org.junit.jupiter.api.io.TempDir; | ||||
| import org.nd4j.common.resources.Resources; | ||||
| import org.nd4j.common.tests.tags.NativeTag; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| import org.nd4j.linalg.api.buffer.DataType; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| @ -60,6 +63,8 @@ import static org.junit.jupiter.api.Assertions.fail; | ||||
|  * @author saudet | ||||
|  */ | ||||
| @Slf4j | ||||
| @NativeTag | ||||
| @Tag(TagNames.FILE_IO) | ||||
| public class TestNativeImageLoader { | ||||
|     static final long seed = 10; | ||||
|     static final Random rng = new Random(seed); | ||||
|  | ||||
| @ -28,9 +28,12 @@ import org.datavec.api.writable.NDArrayWritable; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.datavec.image.loader.NativeImageLoader; | ||||
| 
 | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.junit.jupiter.api.io.TempDir; | ||||
| import org.nd4j.common.loader.FileBatch; | ||||
| import org.nd4j.common.tests.tags.NativeTag; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.common.io.ClassPathResource; | ||||
| import java.io.File; | ||||
| @ -41,6 +44,8 @@ import java.nio.file.Path; | ||||
| import org.junit.jupiter.api.extension.ExtendWith; | ||||
| 
 | ||||
| @DisplayName("File Batch Record Reader Test") | ||||
| @NativeTag | ||||
| @Tag(TagNames.FILE_IO) | ||||
| class FileBatchRecordReaderTest { | ||||
| 
 | ||||
|     @TempDir | ||||
|  | ||||
| @ -37,9 +37,12 @@ import org.datavec.api.writable.NDArrayWritable; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.datavec.api.writable.batch.NDArrayRecordBatch; | ||||
| 
 | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| 
 | ||||
| import org.junit.jupiter.api.io.TempDir; | ||||
| import org.nd4j.common.tests.tags.NativeTag; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| import org.nd4j.linalg.api.buffer.DataType; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| @ -54,7 +57,8 @@ import java.util.List; | ||||
| import java.util.Random; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.*; | ||||
| 
 | ||||
| @NativeTag | ||||
| @Tag(TagNames.FILE_IO) | ||||
| public class TestImageRecordReader { | ||||
| 
 | ||||
| 
 | ||||
|  | ||||
| @ -36,9 +36,12 @@ import org.datavec.image.transform.ImageTransform; | ||||
| import org.datavec.image.transform.PipelineImageTransform; | ||||
| import org.datavec.image.transform.ResizeImageTransform; | ||||
| 
 | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| 
 | ||||
| import org.junit.jupiter.api.io.TempDir; | ||||
| import org.nd4j.common.tests.tags.NativeTag; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| import org.nd4j.linalg.indexing.BooleanIndexing; | ||||
| @ -54,7 +57,8 @@ import java.util.Collections; | ||||
| import java.util.List; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.*; | ||||
| 
 | ||||
| @NativeTag | ||||
| @Tag(TagNames.FILE_IO) | ||||
| public class TestObjectDetectionRecordReader { | ||||
| 
 | ||||
| 
 | ||||
|  | ||||
| @ -22,10 +22,13 @@ package org.datavec.image.recordreader.objdetect; | ||||
| 
 | ||||
| import org.datavec.image.recordreader.objdetect.impl.VocLabelProvider; | ||||
| 
 | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| 
 | ||||
| import org.junit.jupiter.api.io.TempDir; | ||||
| import org.nd4j.common.io.ClassPathResource; | ||||
| import org.nd4j.common.tests.tags.NativeTag; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import java.io.File; | ||||
| import java.nio.file.Path; | ||||
| @ -34,7 +37,8 @@ import java.util.Collections; | ||||
| import java.util.List; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| 
 | ||||
| @Tag(TagNames.FILE_IO) | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| public class TestVocLabelProvider { | ||||
| 
 | ||||
| 
 | ||||
|  | ||||
| @ -20,6 +20,7 @@ | ||||
| package org.datavec.image.transform; | ||||
| 
 | ||||
| import org.datavec.image.data.ImageWritable; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import java.io.IOException; | ||||
| import java.util.List; | ||||
| @ -28,8 +29,13 @@ import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| import static org.junit.jupiter.api.Assertions.assertTrue; | ||||
| import org.junit.jupiter.api.DisplayName; | ||||
| import org.junit.jupiter.api.extension.ExtendWith; | ||||
| import org.nd4j.common.tests.tags.NativeTag; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| @DisplayName("Json Yaml Test") | ||||
| @NativeTag | ||||
| @Tag(TagNames.FILE_IO) | ||||
| @Tag(TagNames.JACKSON_SERDE) | ||||
| class JsonYamlTest { | ||||
| 
 | ||||
|     @Test | ||||
|  | ||||
| @ -22,12 +22,17 @@ package org.datavec.image.transform; | ||||
| import org.bytedeco.javacv.Frame; | ||||
| import org.datavec.image.data.ImageWritable; | ||||
| import org.junit.jupiter.api.BeforeEach; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| import org.junit.jupiter.api.DisplayName; | ||||
| import org.junit.jupiter.api.extension.ExtendWith; | ||||
| import org.nd4j.common.tests.tags.NativeTag; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| @DisplayName("Resize Image Transform Test") | ||||
| @NativeTag | ||||
| @Tag(TagNames.FILE_IO) | ||||
| class ResizeImageTransformTest { | ||||
| 
 | ||||
|     @BeforeEach | ||||
|  | ||||
| @ -24,6 +24,7 @@ import org.bytedeco.javacpp.indexer.UByteIndexer; | ||||
| import org.bytedeco.javacv.CanvasFrame; | ||||
| import org.bytedeco.javacv.Frame; | ||||
| import org.bytedeco.javacv.OpenCVFrameConverter; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.nd4j.common.io.ClassPathResource; | ||||
| import org.nd4j.common.primitives.Pair; | ||||
| import org.datavec.image.data.ImageWritable; | ||||
| @ -37,6 +38,8 @@ import java.util.List; | ||||
| import java.util.Random; | ||||
| 
 | ||||
| import org.bytedeco.opencv.opencv_core.*; | ||||
| import org.nd4j.common.tests.tags.NativeTag; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import static org.bytedeco.opencv.global.opencv_core.*; | ||||
| import static org.bytedeco.opencv.global.opencv_imgproc.*; | ||||
| @ -46,6 +49,8 @@ import static org.junit.jupiter.api.Assertions.*; | ||||
|  * | ||||
|  * @author saudet | ||||
|  */ | ||||
| @NativeTag | ||||
| @Tag(TagNames.FILE_IO) | ||||
| public class TestImageTransform { | ||||
|     static final long seed = 10; | ||||
|     static final Random rng = new Random(seed); | ||||
|  | ||||
| @ -22,6 +22,7 @@ package org.datavec.poi.excel; | ||||
| import org.datavec.api.records.reader.RecordReader; | ||||
| import org.datavec.api.split.FileSplit; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.io.ClassPathResource; | ||||
| import java.util.List; | ||||
| @ -29,8 +30,12 @@ import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| import static org.junit.jupiter.api.Assertions.assertTrue; | ||||
| import org.junit.jupiter.api.DisplayName; | ||||
| import org.junit.jupiter.api.extension.ExtendWith; | ||||
| import org.nd4j.common.tests.tags.NativeTag; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| @DisplayName("Excel Record Reader Test") | ||||
| @Tag(TagNames.FILE_IO) | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| class ExcelRecordReaderTest { | ||||
| 
 | ||||
|     @Test | ||||
|  | ||||
| @ -26,6 +26,7 @@ import org.datavec.api.transform.schema.Schema; | ||||
| import org.datavec.api.writable.IntWritable; | ||||
| import org.datavec.api.writable.Writable; | ||||
| 
 | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.junit.jupiter.api.io.TempDir; | ||||
| import org.nd4j.common.primitives.Triple; | ||||
| @ -36,8 +37,11 @@ import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| import org.junit.jupiter.api.DisplayName; | ||||
| import java.nio.file.Path; | ||||
| import org.junit.jupiter.api.extension.ExtendWith; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| @DisplayName("Excel Record Writer Test") | ||||
| @Tag(TagNames.FILE_IO) | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| class ExcelRecordWriterTest { | ||||
| 
 | ||||
|     @TempDir | ||||
|  | ||||
| @ -47,17 +47,19 @@ import org.datavec.api.writable.IntWritable; | ||||
| import org.datavec.api.writable.LongWritable; | ||||
| import org.datavec.api.writable.Text; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.junit.jupiter.api.AfterEach; | ||||
| import org.junit.jupiter.api.BeforeEach; | ||||
| import org.junit.jupiter.api.*; | ||||
| 
 | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.junit.jupiter.api.io.TempDir; | ||||
| import org.junit.jupiter.api.DisplayName; | ||||
| 
 | ||||
| import java.nio.file.Path; | ||||
| import org.junit.jupiter.api.extension.ExtendWith; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertThrows; | ||||
| 
 | ||||
| @DisplayName("Jdbc Record Reader Test") | ||||
| @Tag(TagNames.FILE_IO) | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| class JDBCRecordReaderTest { | ||||
| 
 | ||||
|     @TempDir | ||||
|  | ||||
| @ -36,15 +36,18 @@ import org.datavec.api.writable.LongWritable; | ||||
| import org.datavec.api.writable.Text; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.joda.time.DateTimeZone; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.io.ClassPathResource; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import java.util.ArrayList; | ||||
| import java.util.Arrays; | ||||
| import java.util.List; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| 
 | ||||
| @Tag(TagNames.FILE_IO) | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| public class LocalTransformProcessRecordReaderTests { | ||||
| 
 | ||||
|     @Test | ||||
| @ -64,11 +67,11 @@ public class LocalTransformProcessRecordReaderTests { | ||||
|     public void simpleTransformTestSequence() { | ||||
|         List<List<Writable>> sequence = new ArrayList<>(); | ||||
|         //First window: | ||||
|         sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L), new IntWritable(0), | ||||
|         sequence.add(Arrays.asList(new LongWritable(1451606400000L), new IntWritable(0), | ||||
|                 new IntWritable(0))); | ||||
|         sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 100L), new IntWritable(1), | ||||
|         sequence.add(Arrays.asList(new LongWritable(1451606400000L + 100L), new IntWritable(1), | ||||
|                 new IntWritable(0))); | ||||
|         sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 200L), new IntWritable(2), | ||||
|         sequence.add(Arrays.asList(new LongWritable(1451606400000L + 200L), new IntWritable(2), | ||||
|                 new IntWritable(0))); | ||||
| 
 | ||||
|         Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC) | ||||
|  | ||||
| @ -30,8 +30,10 @@ import org.datavec.api.util.ndarray.RecordConverter; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.datavec.local.transforms.AnalyzeLocal; | ||||
| 
 | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| 
 | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| import org.nd4j.linalg.api.buffer.DataType; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.common.io.ClassPathResource; | ||||
| @ -40,7 +42,8 @@ import java.util.ArrayList; | ||||
| import java.util.List; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| 
 | ||||
| @Tag(TagNames.FILE_IO) | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| public class TestAnalyzeLocal { | ||||
| 
 | ||||
| 
 | ||||
|  | ||||
| @ -27,8 +27,10 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader; | ||||
| import org.datavec.api.split.FileSplit; | ||||
| import org.datavec.api.writable.Writable; | ||||
| 
 | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.io.ClassPathResource; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import java.io.File; | ||||
| import java.util.HashSet; | ||||
| @ -38,7 +40,8 @@ import java.util.stream.Collectors; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| import static org.junit.jupiter.api.Assertions.assertTrue; | ||||
| 
 | ||||
| @Tag(TagNames.FILE_IO) | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| public class TestLineRecordReaderFunction  { | ||||
| 
 | ||||
|     @Test | ||||
|  | ||||
| @ -25,7 +25,10 @@ import org.datavec.api.writable.NDArrayWritable; | ||||
| import org.datavec.api.writable.Writable; | ||||
| 
 | ||||
| import org.datavec.local.transforms.misc.NDArrayToWritablesFunction; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.tags.NativeTag; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| 
 | ||||
| @ -34,7 +37,8 @@ import java.util.Arrays; | ||||
| import java.util.List; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| 
 | ||||
| @Tag(TagNames.FILE_IO) | ||||
| @NativeTag | ||||
| public class TestNDArrayToWritablesFunction { | ||||
| 
 | ||||
|     @Test | ||||
| @ -50,7 +54,7 @@ public class TestNDArrayToWritablesFunction { | ||||
|     @Test | ||||
|     public void testNDArrayToWritablesArray() throws Exception { | ||||
|         INDArray arr = Nd4j.arange(5); | ||||
|         List<Writable> expected = Arrays.asList((Writable) new NDArrayWritable(arr)); | ||||
|         List<Writable> expected = Arrays.asList(new NDArrayWritable(arr)); | ||||
|         List<Writable> actual = new NDArrayToWritablesFunction(true).apply(arr); | ||||
|         assertEquals(expected, actual); | ||||
|     } | ||||
|  | ||||
| @ -25,7 +25,10 @@ import org.datavec.api.writable.NDArrayWritable; | ||||
| import org.datavec.api.writable.Writable; | ||||
| 
 | ||||
| import org.datavec.local.transforms.misc.WritablesToNDArrayFunction; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.tags.NativeTag; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| import org.nd4j.linalg.api.buffer.DataType; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| @ -34,7 +37,8 @@ import java.util.ArrayList; | ||||
| import java.util.List; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| 
 | ||||
| @Tag(TagNames.FILE_IO) | ||||
| @NativeTag | ||||
| public class TestWritablesToNDArrayFunction { | ||||
| 
 | ||||
|     @Test | ||||
|  | ||||
| @ -30,13 +30,16 @@ import org.datavec.api.writable.Writable; | ||||
| 
 | ||||
| import org.datavec.local.transforms.misc.SequenceWritablesToStringFunction; | ||||
| import org.datavec.local.transforms.misc.WritablesToStringFunction; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import java.util.Arrays; | ||||
| import java.util.List; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| 
 | ||||
| @Tag(TagNames.FILE_IO) | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| public class TestWritablesToStringFunctions  { | ||||
| 
 | ||||
| 
 | ||||
| @ -44,7 +47,7 @@ public class TestWritablesToStringFunctions  { | ||||
|     @Test | ||||
|     public void testWritablesToString() throws Exception { | ||||
| 
 | ||||
|         List<Writable> l = Arrays.<Writable>asList(new DoubleWritable(1.5), new Text("someValue")); | ||||
|         List<Writable> l = Arrays.asList(new DoubleWritable(1.5), new Text("someValue")); | ||||
|         String expected = l.get(0).toString() + "," + l.get(1).toString(); | ||||
| 
 | ||||
|         assertEquals(expected, new WritablesToStringFunction(",").apply(l)); | ||||
| @ -53,8 +56,8 @@ public class TestWritablesToStringFunctions  { | ||||
|     @Test | ||||
|     public void testSequenceWritablesToString() throws Exception { | ||||
| 
 | ||||
|         List<List<Writable>> l = Arrays.asList(Arrays.<Writable>asList(new DoubleWritable(1.5), new Text("someValue")), | ||||
|                         Arrays.<Writable>asList(new DoubleWritable(2.5), new Text("otherValue"))); | ||||
|         List<List<Writable>> l = Arrays.asList(Arrays.asList(new DoubleWritable(1.5), new Text("someValue")), | ||||
|                         Arrays.asList(new DoubleWritable(2.5), new Text("otherValue"))); | ||||
| 
 | ||||
|         String expected = l.get(0).get(0).toString() + "," + l.get(0).get(1).toString() + "\n" | ||||
|                         + l.get(1).get(0).toString() + "," + l.get(1).get(1).toString(); | ||||
|  | ||||
| @ -31,7 +31,10 @@ import org.datavec.api.transform.schema.SequenceSchema; | ||||
| import org.datavec.api.writable.*; | ||||
| import org.datavec.local.transforms.LocalTransformExecutor; | ||||
| import org.junit.jupiter.api.Disabled; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.tags.NativeTag; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| import org.nd4j.linalg.ops.transforms.Transforms; | ||||
| @ -42,6 +45,8 @@ import static java.time.Duration.ofMillis; | ||||
| import static org.junit.jupiter.api.Assertions.assertTimeout; | ||||
| 
 | ||||
| @DisplayName("Execution Test") | ||||
| @Tag(TagNames.FILE_IO) | ||||
| @NativeTag | ||||
| class ExecutionTest { | ||||
| 
 | ||||
|     @Test | ||||
| @ -71,18 +76,12 @@ class ExecutionTest { | ||||
|         Schema schema = new Schema.Builder().addColumnInteger("col0").addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").addColumnFloat("col3").build(); | ||||
|         TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1").doubleMathOp("col2", MathOp.Add, 10.0).floatMathOp("col3", MathOp.Add, 5f).build(); | ||||
|         List<List<Writable>> inputData = new ArrayList<>(); | ||||
|         inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1), new FloatWritable(0.3f))); | ||||
|         inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1), new FloatWritable(1.7f))); | ||||
|         inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1), new FloatWritable(3.6f))); | ||||
|         inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1), new FloatWritable(0.3f))); | ||||
|         inputData.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1), new FloatWritable(1.7f))); | ||||
|         inputData.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1), new FloatWritable(3.6f))); | ||||
|         List<List<Writable>> rdd = (inputData); | ||||
|         List<List<Writable>> out = new ArrayList<>(LocalTransformExecutor.execute(rdd, tp)); | ||||
|         Collections.sort(out, new Comparator<List<Writable>>() { | ||||
| 
 | ||||
|             @Override | ||||
|             public int compare(List<Writable> o1, List<Writable> o2) { | ||||
|                 return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt()); | ||||
|             } | ||||
|         }); | ||||
|         Collections.sort(out, (o1, o2) -> Integer.compare(o1.get(0).toInt(), o2.get(0).toInt())); | ||||
|         List<List<Writable>> expected = new ArrayList<>(); | ||||
|         expected.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1), new FloatWritable(5.3f))); | ||||
|         expected.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1), new FloatWritable(6.7f))); | ||||
| @ -95,9 +94,9 @@ class ExecutionTest { | ||||
|     void testFilter() { | ||||
|         Schema filterSchema = new Schema.Builder().addColumnDouble("col1").addColumnDouble("col2").addColumnDouble("col3").build(); | ||||
|         List<List<Writable>> inputData = new ArrayList<>(); | ||||
|         inputData.add(Arrays.<Writable>asList(new IntWritable(0), new DoubleWritable(1), new DoubleWritable(0.1))); | ||||
|         inputData.add(Arrays.<Writable>asList(new IntWritable(1), new DoubleWritable(3), new DoubleWritable(1.1))); | ||||
|         inputData.add(Arrays.<Writable>asList(new IntWritable(2), new DoubleWritable(3), new DoubleWritable(2.1))); | ||||
|         inputData.add(Arrays.asList(new IntWritable(0), new DoubleWritable(1), new DoubleWritable(0.1))); | ||||
|         inputData.add(Arrays.asList(new IntWritable(1), new DoubleWritable(3), new DoubleWritable(1.1))); | ||||
|         inputData.add(Arrays.asList(new IntWritable(2), new DoubleWritable(3), new DoubleWritable(2.1))); | ||||
|         TransformProcess transformProcess = new TransformProcess.Builder(filterSchema).filter(new DoubleColumnCondition("col1", ConditionOp.LessThan, 1)).build(); | ||||
|         List<List<Writable>> execute = LocalTransformExecutor.execute(inputData, transformProcess); | ||||
|         assertEquals(2, execute.size()); | ||||
| @ -110,31 +109,25 @@ class ExecutionTest { | ||||
|         TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1").doubleMathOp("col2", MathOp.Add, 10.0).build(); | ||||
|         List<List<List<Writable>>> inputSequences = new ArrayList<>(); | ||||
|         List<List<Writable>> seq1 = new ArrayList<>(); | ||||
|         seq1.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1))); | ||||
|         seq1.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1))); | ||||
|         seq1.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1))); | ||||
|         seq1.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1))); | ||||
|         seq1.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1))); | ||||
|         seq1.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1))); | ||||
|         List<List<Writable>> seq2 = new ArrayList<>(); | ||||
|         seq2.add(Arrays.<Writable>asList(new IntWritable(3), new Text("state0"), new DoubleWritable(3.1))); | ||||
|         seq2.add(Arrays.<Writable>asList(new IntWritable(4), new Text("state1"), new DoubleWritable(4.1))); | ||||
|         seq2.add(Arrays.asList(new IntWritable(3), new Text("state0"), new DoubleWritable(3.1))); | ||||
|         seq2.add(Arrays.asList(new IntWritable(4), new Text("state1"), new DoubleWritable(4.1))); | ||||
|         inputSequences.add(seq1); | ||||
|         inputSequences.add(seq2); | ||||
|         List<List<List<Writable>>> rdd = (inputSequences); | ||||
|         List<List<List<Writable>>> out = LocalTransformExecutor.executeSequenceToSequence(rdd, tp); | ||||
|         Collections.sort(out, new Comparator<List<List<Writable>>>() { | ||||
| 
 | ||||
|             @Override | ||||
|             public int compare(List<List<Writable>> o1, List<List<Writable>> o2) { | ||||
|                 return -Integer.compare(o1.size(), o2.size()); | ||||
|             } | ||||
|         }); | ||||
|         Collections.sort(out, (o1, o2) -> -Integer.compare(o1.size(), o2.size())); | ||||
|         List<List<List<Writable>>> expectedSequence = new ArrayList<>(); | ||||
|         List<List<Writable>> seq1e = new ArrayList<>(); | ||||
|         seq1e.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1))); | ||||
|         seq1e.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1))); | ||||
|         seq1e.add(Arrays.<Writable>asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1))); | ||||
|         seq1e.add(Arrays.asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1))); | ||||
|         seq1e.add(Arrays.asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1))); | ||||
|         seq1e.add(Arrays.asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1))); | ||||
|         List<List<Writable>> seq2e = new ArrayList<>(); | ||||
|         seq2e.add(Arrays.<Writable>asList(new IntWritable(3), new IntWritable(0), new DoubleWritable(13.1))); | ||||
|         seq2e.add(Arrays.<Writable>asList(new IntWritable(4), new IntWritable(1), new DoubleWritable(14.1))); | ||||
|         seq2e.add(Arrays.asList(new IntWritable(3), new IntWritable(0), new DoubleWritable(13.1))); | ||||
|         seq2e.add(Arrays.asList(new IntWritable(4), new IntWritable(1), new DoubleWritable(14.1))); | ||||
|         expectedSequence.add(seq1e); | ||||
|         expectedSequence.add(seq2e); | ||||
|         assertEquals(expectedSequence, out); | ||||
| @ -143,26 +136,26 @@ class ExecutionTest { | ||||
|     @Test | ||||
|     @DisplayName("Test Reduction Global") | ||||
|     void testReductionGlobal() { | ||||
|         List<List<Writable>> in = Arrays.asList(Arrays.<Writable>asList(new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new Text("second"), new DoubleWritable(5.0))); | ||||
|         List<List<Writable>> in = Arrays.asList(Arrays.asList(new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new Text("second"), new DoubleWritable(5.0))); | ||||
|         List<List<Writable>> inData = in; | ||||
|         Schema s = new Schema.Builder().addColumnString("textCol").addColumnDouble("doubleCol").build(); | ||||
|         TransformProcess tp = new TransformProcess.Builder(s).reduce(new Reducer.Builder(ReduceOp.TakeFirst).takeFirstColumns("textCol").meanColumns("doubleCol").build()).build(); | ||||
|         List<List<Writable>> outRdd = LocalTransformExecutor.execute(inData, tp); | ||||
|         List<List<Writable>> out = outRdd; | ||||
|         List<List<Writable>> expOut = Collections.singletonList(Arrays.<Writable>asList(new Text("first"), new DoubleWritable(4.0))); | ||||
|         List<List<Writable>> expOut = Collections.singletonList(Arrays.asList(new Text("first"), new DoubleWritable(4.0))); | ||||
|         assertEquals(expOut, out); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     @DisplayName("Test Reduction By Key") | ||||
|     void testReductionByKey() { | ||||
|         List<List<Writable>> in = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new IntWritable(0), new Text("second"), new DoubleWritable(5.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(30.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("s"), new DoubleWritable(50.0))); | ||||
|         List<List<Writable>> in = Arrays.asList(Arrays.asList(new IntWritable(0), new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new IntWritable(0), new Text("second"), new DoubleWritable(5.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(30.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("s"), new DoubleWritable(50.0))); | ||||
|         List<List<Writable>> inData = in; | ||||
|         Schema s = new Schema.Builder().addColumnInteger("intCol").addColumnString("textCol").addColumnDouble("doubleCol").build(); | ||||
|         TransformProcess tp = new TransformProcess.Builder(s).reduce(new Reducer.Builder(ReduceOp.TakeFirst).keyColumns("intCol").takeFirstColumns("textCol").meanColumns("doubleCol").build()).build(); | ||||
|         List<List<Writable>> outRdd = LocalTransformExecutor.execute(inData, tp); | ||||
|         List<List<Writable>> out = outRdd; | ||||
|         List<List<Writable>> expOut = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0))); | ||||
|         List<List<Writable>> expOut = Arrays.asList(Arrays.asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0))); | ||||
|         out = new ArrayList<>(out); | ||||
|         Collections.sort(out, Comparator.comparingInt(o -> o.get(0).toInt())); | ||||
|         assertEquals(expOut, out); | ||||
|  | ||||
| @ -28,12 +28,15 @@ import org.datavec.api.writable.*; | ||||
| 
 | ||||
| 
 | ||||
| import org.datavec.local.transforms.LocalTransformExecutor; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import java.util.*; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| 
 | ||||
| @Tag(TagNames.FILE_IO) | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| public class TestJoin  { | ||||
| 
 | ||||
|     @Test | ||||
| @ -46,27 +49,27 @@ public class TestJoin  { | ||||
|                         .addColumnDouble("amount").build(); | ||||
| 
 | ||||
|         List<List<Writable>> infoList = new ArrayList<>(); | ||||
|         infoList.add(Arrays.<Writable>asList(new LongWritable(12345), new Text("Customer12345"))); | ||||
|         infoList.add(Arrays.<Writable>asList(new LongWritable(98765), new Text("Customer98765"))); | ||||
|         infoList.add(Arrays.<Writable>asList(new LongWritable(50000), new Text("Customer50000"))); | ||||
|         infoList.add(Arrays.asList(new LongWritable(12345), new Text("Customer12345"))); | ||||
|         infoList.add(Arrays.asList(new LongWritable(98765), new Text("Customer98765"))); | ||||
|         infoList.add(Arrays.asList(new LongWritable(50000), new Text("Customer50000"))); | ||||
| 
 | ||||
|         List<List<Writable>> purchaseList = new ArrayList<>(); | ||||
|         purchaseList.add(Arrays.<Writable>asList(new LongWritable(1000000), new LongWritable(12345), | ||||
|         purchaseList.add(Arrays.asList(new LongWritable(1000000), new LongWritable(12345), | ||||
|                         new DoubleWritable(10.00))); | ||||
|         purchaseList.add(Arrays.<Writable>asList(new LongWritable(1000001), new LongWritable(12345), | ||||
|         purchaseList.add(Arrays.asList(new LongWritable(1000001), new LongWritable(12345), | ||||
|                         new DoubleWritable(20.00))); | ||||
|         purchaseList.add(Arrays.<Writable>asList(new LongWritable(1000002), new LongWritable(98765), | ||||
|         purchaseList.add(Arrays.asList(new LongWritable(1000002), new LongWritable(98765), | ||||
|                         new DoubleWritable(30.00))); | ||||
| 
 | ||||
|         Join join = new Join.Builder(Join.JoinType.RightOuter).setJoinColumns("customerID") | ||||
|                         .setSchemas(customerInfoSchema, purchasesSchema).build(); | ||||
| 
 | ||||
|         List<List<Writable>> expected = new ArrayList<>(); | ||||
|         expected.add(Arrays.<Writable>asList(new LongWritable(12345), new Text("Customer12345"), | ||||
|         expected.add(Arrays.asList(new LongWritable(12345), new Text("Customer12345"), | ||||
|                         new LongWritable(1000000), new DoubleWritable(10.00))); | ||||
|         expected.add(Arrays.<Writable>asList(new LongWritable(12345), new Text("Customer12345"), | ||||
|         expected.add(Arrays.asList(new LongWritable(12345), new Text("Customer12345"), | ||||
|                         new LongWritable(1000001), new DoubleWritable(20.00))); | ||||
|         expected.add(Arrays.<Writable>asList(new LongWritable(98765), new Text("Customer98765"), | ||||
|         expected.add(Arrays.asList(new LongWritable(98765), new Text("Customer98765"), | ||||
|                         new LongWritable(1000002), new DoubleWritable(30.00))); | ||||
| 
 | ||||
| 
 | ||||
| @ -77,12 +80,7 @@ public class TestJoin  { | ||||
|         List<List<Writable>> joined = LocalTransformExecutor.executeJoin(join, info, purchases); | ||||
|         List<List<Writable>> joinedList = new ArrayList<>(joined); | ||||
|         //Sort by order ID (column 3, index 2) | ||||
|         Collections.sort(joinedList, new Comparator<List<Writable>>() { | ||||
|             @Override | ||||
|             public int compare(List<Writable> o1, List<Writable> o2) { | ||||
|                 return Long.compare(o1.get(2).toLong(), o2.get(2).toLong()); | ||||
|             } | ||||
|         }); | ||||
|         Collections.sort(joinedList, (o1, o2) -> Long.compare(o1.get(2).toLong(), o2.get(2).toLong())); | ||||
|         assertEquals(expected, joinedList); | ||||
| 
 | ||||
|         assertEquals(3, joinedList.size()); | ||||
| @ -110,12 +108,7 @@ public class TestJoin  { | ||||
|         List<List<Writable>> joined2 = LocalTransformExecutor.executeJoin(join2, purchases, info); | ||||
|         List<List<Writable>> joinedList2 = new ArrayList<>(joined2); | ||||
|         //Sort by order ID (column 0) | ||||
|         Collections.sort(joinedList2, new Comparator<List<Writable>>() { | ||||
|             @Override | ||||
|             public int compare(List<Writable> o1, List<Writable> o2) { | ||||
|                 return Long.compare(o1.get(0).toLong(), o2.get(0).toLong()); | ||||
|             } | ||||
|         }); | ||||
|         Collections.sort(joinedList2, (o1, o2) -> Long.compare(o1.get(0).toLong(), o2.get(0).toLong())); | ||||
|         assertEquals(3, joinedList2.size()); | ||||
| 
 | ||||
|         assertEquals(expectedManyToOne, joinedList2); | ||||
| @ -189,29 +182,26 @@ public class TestJoin  { | ||||
|                             new ArrayList<>(LocalTransformExecutor.executeJoin(join, firstRDD, secondRDD)); | ||||
| 
 | ||||
|             //Sort output by column 0, then column 1, then column 2 for comparison to expected... | ||||
|             Collections.sort(out, new Comparator<List<Writable>>() { | ||||
|                 @Override | ||||
|                 public int compare(List<Writable> o1, List<Writable> o2) { | ||||
|                     Writable w1 = o1.get(0); | ||||
|                     Writable w2 = o2.get(0); | ||||
|                     if (w1 instanceof NullWritable) | ||||
|                         return 1; | ||||
|                     else if (w2 instanceof NullWritable) | ||||
|                         return -1; | ||||
|                     int c = Long.compare(w1.toLong(), w2.toLong()); | ||||
|                     if (c != 0) | ||||
|                         return c; | ||||
|                     c = o1.get(1).toString().compareTo(o2.get(1).toString()); | ||||
|                     if (c != 0) | ||||
|                         return c; | ||||
|                     w1 = o1.get(2); | ||||
|                     w2 = o2.get(2); | ||||
|                     if (w1 instanceof NullWritable) | ||||
|                         return 1; | ||||
|                     else if (w2 instanceof NullWritable) | ||||
|                         return -1; | ||||
|                     return Long.compare(w1.toLong(), w2.toLong()); | ||||
|                 } | ||||
|             Collections.sort(out, (o1, o2) -> { | ||||
|                 Writable w1 = o1.get(0); | ||||
|                 Writable w2 = o2.get(0); | ||||
|                 if (w1 instanceof NullWritable) | ||||
|                     return 1; | ||||
|                 else if (w2 instanceof NullWritable) | ||||
|                     return -1; | ||||
|                 int c = Long.compare(w1.toLong(), w2.toLong()); | ||||
|                 if (c != 0) | ||||
|                     return c; | ||||
|                 c = o1.get(1).toString().compareTo(o2.get(1).toString()); | ||||
|                 if (c != 0) | ||||
|                     return c; | ||||
|                 w1 = o1.get(2); | ||||
|                 w2 = o2.get(2); | ||||
|                 if (w1 instanceof NullWritable) | ||||
|                     return 1; | ||||
|                 else if (w2 instanceof NullWritable) | ||||
|                     return -1; | ||||
|                 return Long.compare(w1.toLong(), w2.toLong()); | ||||
|             }); | ||||
| 
 | ||||
|             switch (jt) { | ||||
|  | ||||
| @ -31,14 +31,17 @@ import org.datavec.api.writable.comparator.DoubleWritableComparator; | ||||
| 
 | ||||
| 
 | ||||
| import org.datavec.local.transforms.LocalTransformExecutor; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import java.util.ArrayList; | ||||
| import java.util.Arrays; | ||||
| import java.util.List; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| 
 | ||||
| @Tag(TagNames.FILE_IO) | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| public class TestCalculateSortedRank  { | ||||
| 
 | ||||
|     @Test | ||||
|  | ||||
| @ -31,7 +31,9 @@ import org.datavec.api.writable.Writable; | ||||
| 
 | ||||
| import org.datavec.arrow.recordreader.ArrowWritableRecordTimeSeriesBatch; | ||||
| import org.datavec.local.transforms.LocalTransformExecutor; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import java.util.Arrays; | ||||
| import java.util.Collections; | ||||
| @ -39,7 +41,8 @@ import java.util.List; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| import static org.junit.jupiter.api.Assertions.assertTrue; | ||||
| 
 | ||||
| @Tag(TagNames.FILE_IO) | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| public class TestConvertToSequence  { | ||||
| 
 | ||||
|     @Test | ||||
| @ -48,12 +51,12 @@ public class TestConvertToSequence  { | ||||
|         Schema s = new Schema.Builder().addColumnsString("key1", "key2").addColumnLong("time").build(); | ||||
| 
 | ||||
|         List<List<Writable>> allExamples = | ||||
|                         Arrays.asList(Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), new LongWritable(10)), | ||||
|                                         Arrays.<Writable>asList(new Text("k1b"), new Text("k2b"), new LongWritable(10)), | ||||
|                                         Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), | ||||
|                         Arrays.asList(Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(10)), | ||||
|                                         Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(10)), | ||||
|                                         Arrays.asList(new Text("k1a"), new Text("k2a"), | ||||
|                                                         new LongWritable(-10)), | ||||
|                                         Arrays.<Writable>asList(new Text("k1b"), new Text("k2b"), new LongWritable(5)), | ||||
|                                         Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), new LongWritable(0))); | ||||
|                                         Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(5)), | ||||
|                                         Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(0))); | ||||
| 
 | ||||
|         TransformProcess tp = new TransformProcess.Builder(s) | ||||
|                         .convertToSequence(Arrays.asList("key1", "key2"), new NumericalColumnComparator("time")) | ||||
| @ -75,13 +78,13 @@ public class TestConvertToSequence  { | ||||
|         } | ||||
| 
 | ||||
|         List<List<Writable>> expSeq0 = Arrays.asList( | ||||
|                         Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), new LongWritable(-10)), | ||||
|                         Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), new LongWritable(0)), | ||||
|                         Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), new LongWritable(10))); | ||||
|                         Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(-10)), | ||||
|                         Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(0)), | ||||
|                         Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(10))); | ||||
| 
 | ||||
|         List<List<Writable>> expSeq1 = Arrays.asList( | ||||
|                         Arrays.<Writable>asList(new Text("k1b"), new Text("k2b"), new LongWritable(5)), | ||||
|                         Arrays.<Writable>asList(new Text("k1b"), new Text("k2b"), new LongWritable(10))); | ||||
|                         Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(5)), | ||||
|                         Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(10))); | ||||
| 
 | ||||
|         assertEquals(expSeq0, seq0); | ||||
|         assertEquals(expSeq1, seq1); | ||||
| @ -96,9 +99,9 @@ public class TestConvertToSequence  { | ||||
|                 .build(); | ||||
| 
 | ||||
|         List<List<Writable>> allExamples = Arrays.asList( | ||||
|                 Arrays.<Writable>asList(new Text("a"), new LongWritable(0)), | ||||
|                 Arrays.<Writable>asList(new Text("b"), new LongWritable(1)), | ||||
|                 Arrays.<Writable>asList(new Text("c"), new LongWritable(2))); | ||||
|                 Arrays.asList(new Text("a"), new LongWritable(0)), | ||||
|                 Arrays.asList(new Text("b"), new LongWritable(1)), | ||||
|                 Arrays.asList(new Text("c"), new LongWritable(2))); | ||||
| 
 | ||||
|         TransformProcess tp = new TransformProcess.Builder(s) | ||||
|                 .convertToSequence() | ||||
|  | ||||
| @ -25,8 +25,10 @@ import org.apache.spark.serializer.SerializerInstance; | ||||
| import org.datavec.api.records.reader.RecordReader; | ||||
| import org.datavec.api.records.reader.impl.csv.CSVRecordReader; | ||||
| import org.datavec.api.split.FileSplit; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.io.ClassPathResource; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import java.io.File; | ||||
| import java.nio.ByteBuffer; | ||||
| @ -34,7 +36,10 @@ import java.nio.ByteBuffer; | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| import static org.junit.jupiter.api.Assertions.assertFalse; | ||||
| import static org.junit.jupiter.api.Assertions.assertTrue; | ||||
| 
 | ||||
| @Tag(TagNames.FILE_IO) | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.SPARK) | ||||
| @Tag(TagNames.DIST_SYSTEMS) | ||||
| public class TestKryoSerialization extends BaseSparkTest { | ||||
| 
 | ||||
|     @Override | ||||
|  | ||||
| @ -27,8 +27,10 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader; | ||||
| import org.datavec.api.split.FileSplit; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.datavec.spark.BaseSparkTest; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.io.ClassPathResource; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import java.io.File; | ||||
| import java.util.HashSet; | ||||
| @ -37,7 +39,10 @@ import java.util.Set; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| import static org.junit.jupiter.api.Assertions.assertTrue; | ||||
| 
 | ||||
| @Tag(TagNames.FILE_IO) | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.SPARK) | ||||
| @Tag(TagNames.DIST_SYSTEMS) | ||||
| public class TestLineRecordReaderFunction extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|  | ||||
| @ -24,7 +24,10 @@ import org.datavec.api.writable.DoubleWritable; | ||||
| import org.datavec.api.writable.NDArrayWritable; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.datavec.spark.transform.misc.NDArrayToWritablesFunction; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.tags.NativeTag; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| 
 | ||||
| @ -33,7 +36,10 @@ import java.util.Arrays; | ||||
| import java.util.List; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| 
 | ||||
| @Tag(TagNames.FILE_IO) | ||||
| @Tag(TagNames.SPARK) | ||||
| @Tag(TagNames.DIST_SYSTEMS) | ||||
| @NativeTag | ||||
| public class TestNDArrayToWritablesFunction { | ||||
| 
 | ||||
|     @Test | ||||
|  | ||||
| @ -39,10 +39,12 @@ import org.datavec.spark.functions.pairdata.PathToKeyConverter; | ||||
| import org.datavec.spark.functions.pairdata.PathToKeyConverterFilename; | ||||
| import org.datavec.spark.util.DataVecSparkUtil; | ||||
| 
 | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| 
 | ||||
| import org.junit.jupiter.api.io.TempDir; | ||||
| import org.nd4j.common.io.ClassPathResource; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| import scala.Tuple2; | ||||
| 
 | ||||
| import java.io.File; | ||||
| @ -53,7 +55,10 @@ import java.util.List; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| import static org.junit.jupiter.api.Assertions.fail; | ||||
| 
 | ||||
| @Tag(TagNames.FILE_IO) | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.SPARK) | ||||
| @Tag(TagNames.DIST_SYSTEMS) | ||||
| public class TestPairSequenceRecordReaderBytesFunction extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|  | ||||
| @ -37,10 +37,12 @@ import org.datavec.spark.BaseSparkTest; | ||||
| import org.datavec.spark.functions.data.FilesAsBytesFunction; | ||||
| import org.datavec.spark.functions.data.RecordReaderBytesFunction; | ||||
| 
 | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| 
 | ||||
| import org.junit.jupiter.api.io.TempDir; | ||||
| import org.nd4j.common.io.ClassPathResource; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import java.io.File; | ||||
| import java.nio.file.Files; | ||||
| @ -51,7 +53,10 @@ import java.util.List; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| import static org.junit.jupiter.api.Assertions.fail; | ||||
| 
 | ||||
| @Tag(TagNames.FILE_IO) | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.SPARK) | ||||
| @Tag(TagNames.DIST_SYSTEMS) | ||||
| public class TestRecordReaderBytesFunction extends BaseSparkTest { | ||||
| 
 | ||||
| 
 | ||||
|  | ||||
| @ -32,10 +32,12 @@ import org.datavec.api.writable.Writable; | ||||
| import org.datavec.image.recordreader.ImageRecordReader; | ||||
| import org.datavec.spark.BaseSparkTest; | ||||
| 
 | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| 
 | ||||
| import org.junit.jupiter.api.io.TempDir; | ||||
| import org.nd4j.common.io.ClassPathResource; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import java.io.File; | ||||
| import java.nio.file.Path; | ||||
| @ -45,7 +47,10 @@ import java.util.List; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| import static org.junit.jupiter.api.Assertions.fail; | ||||
| 
 | ||||
| @Tag(TagNames.FILE_IO) | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.SPARK) | ||||
| @Tag(TagNames.DIST_SYSTEMS) | ||||
| public class TestRecordReaderFunction extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|  | ||||
| @ -37,10 +37,12 @@ import org.datavec.spark.BaseSparkTest; | ||||
| import org.datavec.spark.functions.data.FilesAsBytesFunction; | ||||
| import org.datavec.spark.functions.data.SequenceRecordReaderBytesFunction; | ||||
| 
 | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| 
 | ||||
| import org.junit.jupiter.api.io.TempDir; | ||||
| import org.nd4j.common.io.ClassPathResource; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import java.io.File; | ||||
| import java.nio.file.Files; | ||||
| @ -50,7 +52,10 @@ import java.util.List; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| import static org.junit.jupiter.api.Assertions.fail; | ||||
| 
 | ||||
| @Tag(TagNames.FILE_IO) | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.SPARK) | ||||
| @Tag(TagNames.DIST_SYSTEMS) | ||||
| public class TestSequenceRecordReaderBytesFunction extends BaseSparkTest { | ||||
| 
 | ||||
| 
 | ||||
|  | ||||
| @ -34,10 +34,12 @@ import org.datavec.api.writable.Writable; | ||||
| import org.datavec.codec.reader.CodecRecordReader; | ||||
| import org.datavec.spark.BaseSparkTest; | ||||
| 
 | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| 
 | ||||
| import org.junit.jupiter.api.io.TempDir; | ||||
| import org.nd4j.common.io.ClassPathResource; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import java.io.File; | ||||
| import java.nio.file.Path; | ||||
| @ -46,7 +48,10 @@ import java.util.List; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| import static org.junit.jupiter.api.Assertions.fail; | ||||
| 
 | ||||
| @Tag(TagNames.FILE_IO) | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.SPARK) | ||||
| @Tag(TagNames.DIST_SYSTEMS) | ||||
| public class TestSequenceRecordReaderFunction extends BaseSparkTest { | ||||
| 
 | ||||
| 
 | ||||
|  | ||||
| @ -22,7 +22,10 @@ package org.datavec.spark.functions; | ||||
| 
 | ||||
| import org.datavec.api.writable.*; | ||||
| import org.datavec.spark.transform.misc.WritablesToNDArrayFunction; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.tags.NativeTag; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| import org.nd4j.linalg.api.buffer.DataType; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| @ -31,7 +34,10 @@ import java.util.ArrayList; | ||||
| import java.util.List; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| 
 | ||||
| @Tag(TagNames.FILE_IO) | ||||
| @Tag(TagNames.SPARK) | ||||
| @Tag(TagNames.DIST_SYSTEMS) | ||||
| @NativeTag | ||||
| public class TestWritablesToNDArrayFunction { | ||||
| 
 | ||||
|     @Test | ||||
|  | ||||
| @ -29,7 +29,9 @@ import org.datavec.api.writable.Writable; | ||||
| import org.datavec.spark.BaseSparkTest; | ||||
| import org.datavec.spark.transform.misc.SequenceWritablesToStringFunction; | ||||
| import org.datavec.spark.transform.misc.WritablesToStringFunction; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| import scala.Tuple2; | ||||
| 
 | ||||
| import java.util.ArrayList; | ||||
| @ -37,7 +39,10 @@ import java.util.Arrays; | ||||
| import java.util.List; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| 
 | ||||
| @Tag(TagNames.FILE_IO) | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.SPARK) | ||||
| @Tag(TagNames.DIST_SYSTEMS) | ||||
| public class TestWritablesToStringFunctions extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
| @ -57,19 +62,9 @@ public class TestWritablesToStringFunctions extends BaseSparkTest { | ||||
| 
 | ||||
| 
 | ||||
|         JavaSparkContext sc = getContext(); | ||||
|         JavaPairRDD<String, String> left = sc.parallelize(leftMap).mapToPair(new PairFunction<Tuple2<String, String>, String, String>() { | ||||
|             @Override | ||||
|             public Tuple2<String, String> call(Tuple2<String, String> stringStringTuple2) throws Exception { | ||||
|                 return stringStringTuple2; | ||||
|             } | ||||
|         }); | ||||
|         JavaPairRDD<String, String> left = sc.parallelize(leftMap).mapToPair((PairFunction<Tuple2<String, String>, String, String>) stringStringTuple2 -> stringStringTuple2); | ||||
| 
 | ||||
|         JavaPairRDD<String, String> right = sc.parallelize(rightMap).mapToPair(new PairFunction<Tuple2<String, String>, String, String>() { | ||||
|             @Override | ||||
|             public Tuple2<String, String> call(Tuple2<String, String> stringStringTuple2) throws Exception { | ||||
|                 return stringStringTuple2; | ||||
|             } | ||||
|         }); | ||||
|         JavaPairRDD<String, String> right = sc.parallelize(rightMap).mapToPair((PairFunction<Tuple2<String, String>, String, String>) stringStringTuple2 -> stringStringTuple2); | ||||
| 
 | ||||
|         System.out.println(left.cogroup(right).collect()); | ||||
|     } | ||||
| @ -77,7 +72,7 @@ public class TestWritablesToStringFunctions extends BaseSparkTest { | ||||
|     @Test | ||||
|     public void testWritablesToString() throws Exception { | ||||
| 
 | ||||
|         List<Writable> l = Arrays.<Writable>asList(new DoubleWritable(1.5), new Text("someValue")); | ||||
|         List<Writable> l = Arrays.asList(new DoubleWritable(1.5), new Text("someValue")); | ||||
|         String expected = l.get(0).toString() + "," + l.get(1).toString(); | ||||
| 
 | ||||
|         assertEquals(expected, new WritablesToStringFunction(",").call(l)); | ||||
| @ -86,8 +81,8 @@ public class TestWritablesToStringFunctions extends BaseSparkTest { | ||||
|     @Test | ||||
|     public void testSequenceWritablesToString() throws Exception { | ||||
| 
 | ||||
|         List<List<Writable>> l = Arrays.asList(Arrays.<Writable>asList(new DoubleWritable(1.5), new Text("someValue")), | ||||
|                         Arrays.<Writable>asList(new DoubleWritable(2.5), new Text("otherValue"))); | ||||
|         List<List<Writable>> l = Arrays.asList(Arrays.asList(new DoubleWritable(1.5), new Text("someValue")), | ||||
|                         Arrays.asList(new DoubleWritable(2.5), new Text("otherValue"))); | ||||
| 
 | ||||
|         String expected = l.get(0).get(0).toString() + "," + l.get(0).get(1).toString() + "\n" | ||||
|                         + l.get(1).get(0).toString() + "," + l.get(1).get(1).toString(); | ||||
|  | ||||
| @ -21,6 +21,8 @@ | ||||
| package org.datavec.spark.storage; | ||||
| 
 | ||||
| import com.sun.jna.Platform; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| import org.nd4j.shade.guava.io.Files; | ||||
| import org.apache.spark.api.java.JavaPairRDD; | ||||
| import org.apache.spark.api.java.JavaRDD; | ||||
| @ -37,7 +39,10 @@ import java.util.Map; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| import static org.junit.jupiter.api.Assertions.assertTrue; | ||||
| 
 | ||||
| @Tag(TagNames.FILE_IO) | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.SPARK) | ||||
| @Tag(TagNames.DIST_SYSTEMS) | ||||
| public class TestSparkStorageUtils extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
| @ -46,11 +51,11 @@ public class TestSparkStorageUtils extends BaseSparkTest { | ||||
|             return; | ||||
|         } | ||||
|         List<List<Writable>> l = new ArrayList<>(); | ||||
|         l.add(Arrays.<org.datavec.api.writable.Writable>asList(new Text("zero"), new IntWritable(0), | ||||
|         l.add(Arrays.asList(new Text("zero"), new IntWritable(0), | ||||
|                         new DoubleWritable(0), new NDArrayWritable(Nd4j.valueArrayOf(10, 0.0)))); | ||||
|         l.add(Arrays.<org.datavec.api.writable.Writable>asList(new Text("one"), new IntWritable(11), | ||||
|         l.add(Arrays.asList(new Text("one"), new IntWritable(11), | ||||
|                         new DoubleWritable(11.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 11.0)))); | ||||
|         l.add(Arrays.<org.datavec.api.writable.Writable>asList(new Text("two"), new IntWritable(22), | ||||
|         l.add(Arrays.asList(new Text("two"), new IntWritable(22), | ||||
|                         new DoubleWritable(22.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 22.0)))); | ||||
| 
 | ||||
|         JavaRDD<List<Writable>> rdd = sc.parallelize(l); | ||||
| @ -92,17 +97,17 @@ public class TestSparkStorageUtils extends BaseSparkTest { | ||||
|         } | ||||
|         List<List<List<Writable>>> l = new ArrayList<>(); | ||||
|         l.add(Arrays.asList( | ||||
|                         Arrays.<org.datavec.api.writable.Writable>asList(new Text("zero"), new IntWritable(0), | ||||
|                         Arrays.asList(new Text("zero"), new IntWritable(0), | ||||
|                                         new DoubleWritable(0), new NDArrayWritable(Nd4j.valueArrayOf(10, 0.0))), | ||||
|                         Arrays.<org.datavec.api.writable.Writable>asList(new Text("one"), new IntWritable(1), | ||||
|                         Arrays.asList(new Text("one"), new IntWritable(1), | ||||
|                                         new DoubleWritable(1.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 1.0))), | ||||
|                         Arrays.<org.datavec.api.writable.Writable>asList(new Text("two"), new IntWritable(2), | ||||
|                         Arrays.asList(new Text("two"), new IntWritable(2), | ||||
|                                         new DoubleWritable(2.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 2.0))))); | ||||
| 
 | ||||
|         l.add(Arrays.asList( | ||||
|                         Arrays.<org.datavec.api.writable.Writable>asList(new Text("Bzero"), new IntWritable(10), | ||||
|                         Arrays.asList(new Text("Bzero"), new IntWritable(10), | ||||
|                                         new DoubleWritable(10), new NDArrayWritable(Nd4j.valueArrayOf(10, 10.0))), | ||||
|                         Arrays.<org.datavec.api.writable.Writable>asList(new Text("Bone"), new IntWritable(11), | ||||
|                         Arrays.asList(new Text("Bone"), new IntWritable(11), | ||||
|                                         new DoubleWritable(11.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 11.0))), | ||||
|                         Arrays.<org.datavec.api.writable.Writable>asList(new Text("Btwo"), new IntWritable(12), | ||||
|                                         new DoubleWritable(12.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 12.0))))); | ||||
|  | ||||
| @ -30,14 +30,19 @@ import org.datavec.api.util.ndarray.RecordConverter; | ||||
| import org.datavec.api.writable.DoubleWritable; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.datavec.spark.BaseSparkTest; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| 
 | ||||
| import java.util.*; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| 
 | ||||
| @Tag(TagNames.FILE_IO) | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.SPARK) | ||||
| @Tag(TagNames.DIST_SYSTEMS) | ||||
| public class DataFramesTests extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
| @ -110,15 +115,15 @@ public class DataFramesTests extends BaseSparkTest { | ||||
|     public void testNormalize() { | ||||
|         List<List<Writable>> data = new ArrayList<>(); | ||||
| 
 | ||||
|         data.add(Arrays.<Writable>asList(new DoubleWritable(1), new DoubleWritable(10))); | ||||
|         data.add(Arrays.<Writable>asList(new DoubleWritable(2), new DoubleWritable(20))); | ||||
|         data.add(Arrays.<Writable>asList(new DoubleWritable(3), new DoubleWritable(30))); | ||||
|         data.add(Arrays.asList(new DoubleWritable(1), new DoubleWritable(10))); | ||||
|         data.add(Arrays.asList(new DoubleWritable(2), new DoubleWritable(20))); | ||||
|         data.add(Arrays.asList(new DoubleWritable(3), new DoubleWritable(30))); | ||||
| 
 | ||||
| 
 | ||||
|         List<List<Writable>> expMinMax = new ArrayList<>(); | ||||
|         expMinMax.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new DoubleWritable(0.0))); | ||||
|         expMinMax.add(Arrays.<Writable>asList(new DoubleWritable(0.5), new DoubleWritable(0.5))); | ||||
|         expMinMax.add(Arrays.<Writable>asList(new DoubleWritable(1.0), new DoubleWritable(1.0))); | ||||
|         expMinMax.add(Arrays.asList(new DoubleWritable(0.0), new DoubleWritable(0.0))); | ||||
|         expMinMax.add(Arrays.asList(new DoubleWritable(0.5), new DoubleWritable(0.5))); | ||||
|         expMinMax.add(Arrays.asList(new DoubleWritable(1.0), new DoubleWritable(1.0))); | ||||
| 
 | ||||
|         double m1 = (1 + 2 + 3) / 3.0; | ||||
|         double s1 = new StandardDeviation().evaluate(new double[] {1, 2, 3}); | ||||
| @ -127,11 +132,11 @@ public class DataFramesTests extends BaseSparkTest { | ||||
| 
 | ||||
|         List<List<Writable>> expStandardize = new ArrayList<>(); | ||||
|         expStandardize.add( | ||||
|                         Arrays.<Writable>asList(new DoubleWritable((1 - m1) / s1), new DoubleWritable((10 - m2) / s2))); | ||||
|                         Arrays.asList(new DoubleWritable((1 - m1) / s1), new DoubleWritable((10 - m2) / s2))); | ||||
|         expStandardize.add( | ||||
|                         Arrays.<Writable>asList(new DoubleWritable((2 - m1) / s1), new DoubleWritable((20 - m2) / s2))); | ||||
|                         Arrays.asList(new DoubleWritable((2 - m1) / s1), new DoubleWritable((20 - m2) / s2))); | ||||
|         expStandardize.add( | ||||
|                         Arrays.<Writable>asList(new DoubleWritable((3 - m1) / s1), new DoubleWritable((30 - m2) / s2))); | ||||
|                         Arrays.asList(new DoubleWritable((3 - m1) / s1), new DoubleWritable((30 - m2) / s2))); | ||||
| 
 | ||||
|         JavaRDD<List<Writable>> rdd = sc.parallelize(data); | ||||
| 
 | ||||
| @ -178,13 +183,13 @@ public class DataFramesTests extends BaseSparkTest { | ||||
|         List<List<List<Writable>>> sequences = new ArrayList<>(); | ||||
| 
 | ||||
|         List<List<Writable>> seq1 = new ArrayList<>(); | ||||
|         seq1.add(Arrays.<Writable>asList(new DoubleWritable(1), new DoubleWritable(10), new DoubleWritable(100))); | ||||
|         seq1.add(Arrays.<Writable>asList(new DoubleWritable(2), new DoubleWritable(20), new DoubleWritable(200))); | ||||
|         seq1.add(Arrays.<Writable>asList(new DoubleWritable(3), new DoubleWritable(30), new DoubleWritable(300))); | ||||
|         seq1.add(Arrays.asList(new DoubleWritable(1), new DoubleWritable(10), new DoubleWritable(100))); | ||||
|         seq1.add(Arrays.asList(new DoubleWritable(2), new DoubleWritable(20), new DoubleWritable(200))); | ||||
|         seq1.add(Arrays.asList(new DoubleWritable(3), new DoubleWritable(30), new DoubleWritable(300))); | ||||
| 
 | ||||
|         List<List<Writable>> seq2 = new ArrayList<>(); | ||||
|         seq2.add(Arrays.<Writable>asList(new DoubleWritable(4), new DoubleWritable(40), new DoubleWritable(400))); | ||||
|         seq2.add(Arrays.<Writable>asList(new DoubleWritable(5), new DoubleWritable(50), new DoubleWritable(500))); | ||||
|         seq2.add(Arrays.asList(new DoubleWritable(4), new DoubleWritable(40), new DoubleWritable(400))); | ||||
|         seq2.add(Arrays.asList(new DoubleWritable(5), new DoubleWritable(50), new DoubleWritable(500))); | ||||
| 
 | ||||
|         sequences.add(seq1); | ||||
|         sequences.add(seq2); | ||||
| @ -199,21 +204,21 @@ public class DataFramesTests extends BaseSparkTest { | ||||
| 
 | ||||
|         //Min/max normalization: | ||||
|         List<List<Writable>> expSeq1MinMax = new ArrayList<>(); | ||||
|         expSeq1MinMax.add(Arrays.<Writable>asList(new DoubleWritable((1 - 1.0) / (5.0 - 1.0)), | ||||
|         expSeq1MinMax.add(Arrays.asList(new DoubleWritable((1 - 1.0) / (5.0 - 1.0)), | ||||
|                         new DoubleWritable((10 - 10.0) / (50.0 - 10.0)), | ||||
|                         new DoubleWritable((100 - 100.0) / (500.0 - 100.0)))); | ||||
|         expSeq1MinMax.add(Arrays.<Writable>asList(new DoubleWritable((2 - 1.0) / (5.0 - 1.0)), | ||||
|         expSeq1MinMax.add(Arrays.asList(new DoubleWritable((2 - 1.0) / (5.0 - 1.0)), | ||||
|                         new DoubleWritable((20 - 10.0) / (50.0 - 10.0)), | ||||
|                         new DoubleWritable((200 - 100.0) / (500.0 - 100.0)))); | ||||
|         expSeq1MinMax.add(Arrays.<Writable>asList(new DoubleWritable((3 - 1.0) / (5.0 - 1.0)), | ||||
|         expSeq1MinMax.add(Arrays.asList(new DoubleWritable((3 - 1.0) / (5.0 - 1.0)), | ||||
|                         new DoubleWritable((30 - 10.0) / (50.0 - 10.0)), | ||||
|                         new DoubleWritable((300 - 100.0) / (500.0 - 100.0)))); | ||||
| 
 | ||||
|         List<List<Writable>> expSeq2MinMax = new ArrayList<>(); | ||||
|         expSeq2MinMax.add(Arrays.<Writable>asList(new DoubleWritable((4 - 1.0) / (5.0 - 1.0)), | ||||
|         expSeq2MinMax.add(Arrays.asList(new DoubleWritable((4 - 1.0) / (5.0 - 1.0)), | ||||
|                         new DoubleWritable((40 - 10.0) / (50.0 - 10.0)), | ||||
|                         new DoubleWritable((400 - 100.0) / (500.0 - 100.0)))); | ||||
|         expSeq2MinMax.add(Arrays.<Writable>asList(new DoubleWritable((5 - 1.0) / (5.0 - 1.0)), | ||||
|         expSeq2MinMax.add(Arrays.asList(new DoubleWritable((5 - 1.0) / (5.0 - 1.0)), | ||||
|                         new DoubleWritable((50 - 10.0) / (50.0 - 10.0)), | ||||
|                         new DoubleWritable((500 - 100.0) / (500.0 - 100.0)))); | ||||
| 
 | ||||
| @ -246,17 +251,17 @@ public class DataFramesTests extends BaseSparkTest { | ||||
|         double s3 = new StandardDeviation().evaluate(new double[] {100, 200, 300, 400, 500}); | ||||
| 
 | ||||
|         List<List<Writable>> expSeq1Std = new ArrayList<>(); | ||||
|         expSeq1Std.add(Arrays.<Writable>asList(new DoubleWritable((1 - m1) / s1), new DoubleWritable((10 - m2) / s2), | ||||
|         expSeq1Std.add(Arrays.asList(new DoubleWritable((1 - m1) / s1), new DoubleWritable((10 - m2) / s2), | ||||
|                         new DoubleWritable((100 - m3) / s3))); | ||||
|         expSeq1Std.add(Arrays.<Writable>asList(new DoubleWritable((2 - m1) / s1), new DoubleWritable((20 - m2) / s2), | ||||
|         expSeq1Std.add(Arrays.asList(new DoubleWritable((2 - m1) / s1), new DoubleWritable((20 - m2) / s2), | ||||
|                         new DoubleWritable((200 - m3) / s3))); | ||||
|         expSeq1Std.add(Arrays.<Writable>asList(new DoubleWritable((3 - m1) / s1), new DoubleWritable((30 - m2) / s2), | ||||
|         expSeq1Std.add(Arrays.asList(new DoubleWritable((3 - m1) / s1), new DoubleWritable((30 - m2) / s2), | ||||
|                         new DoubleWritable((300 - m3) / s3))); | ||||
| 
 | ||||
|         List<List<Writable>> expSeq2Std = new ArrayList<>(); | ||||
|         expSeq2Std.add(Arrays.<Writable>asList(new DoubleWritable((4 - m1) / s1), new DoubleWritable((40 - m2) / s2), | ||||
|         expSeq2Std.add(Arrays.asList(new DoubleWritable((4 - m1) / s1), new DoubleWritable((40 - m2) / s2), | ||||
|                         new DoubleWritable((400 - m3) / s3))); | ||||
|         expSeq2Std.add(Arrays.<Writable>asList(new DoubleWritable((5 - m1) / s1), new DoubleWritable((50 - m2) / s2), | ||||
|         expSeq2Std.add(Arrays.asList(new DoubleWritable((5 - m1) / s1), new DoubleWritable((50 - m2) / s2), | ||||
|                         new DoubleWritable((500 - m3) / s3))); | ||||
| 
 | ||||
| 
 | ||||
|  | ||||
| @ -31,22 +31,24 @@ import org.datavec.api.writable.DoubleWritable; | ||||
| import org.datavec.api.writable.IntWritable; | ||||
| import org.datavec.api.writable.Text; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.datavec.api.writable.NDArrayWritable; | ||||
| import org.datavec.spark.BaseSparkTest; | ||||
| import org.datavec.python.PythonTransform; | ||||
| import org.datavec.spark.BaseSparkTest; | ||||
| import org.junit.jupiter.api.Disabled; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| import java.util.*; | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| import static org.junit.jupiter.api.Assertions.assertTrue; | ||||
| import org.junit.jupiter.api.DisplayName; | ||||
| import org.junit.jupiter.api.extension.ExtendWith; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import java.util.*; | ||||
| 
 | ||||
| import static java.time.Duration.ofMillis; | ||||
| import static org.junit.jupiter.api.Assertions.assertTimeout; | ||||
| import static org.junit.jupiter.api.Assertions.*; | ||||
| 
 | ||||
| @DisplayName("Execution Test") | ||||
| @Tag(TagNames.FILE_IO) | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.SPARK) | ||||
| @Tag(TagNames.DIST_SYSTEMS) | ||||
| class ExecutionTest extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
| @ -55,22 +57,16 @@ class ExecutionTest extends BaseSparkTest { | ||||
|         Schema schema = new Schema.Builder().addColumnInteger("col0").addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").build(); | ||||
|         TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1").doubleMathOp("col2", MathOp.Add, 10.0).build(); | ||||
|         List<List<Writable>> inputData = new ArrayList<>(); | ||||
|         inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1))); | ||||
|         inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1))); | ||||
|         inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1))); | ||||
|         inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1))); | ||||
|         inputData.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1))); | ||||
|         inputData.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1))); | ||||
|         JavaRDD<List<Writable>> rdd = sc.parallelize(inputData); | ||||
|         List<List<Writable>> out = new ArrayList<>(SparkTransformExecutor.execute(rdd, tp).collect()); | ||||
|         Collections.sort(out, new Comparator<List<Writable>>() { | ||||
| 
 | ||||
|             @Override | ||||
|             public int compare(List<Writable> o1, List<Writable> o2) { | ||||
|                 return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt()); | ||||
|             } | ||||
|         }); | ||||
|         Collections.sort(out, Comparator.comparingInt(o -> o.get(0).toInt())); | ||||
|         List<List<Writable>> expected = new ArrayList<>(); | ||||
|         expected.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1))); | ||||
|         expected.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1))); | ||||
|         expected.add(Arrays.<Writable>asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1))); | ||||
|         expected.add(Arrays.asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1))); | ||||
|         expected.add(Arrays.asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1))); | ||||
|         expected.add(Arrays.asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1))); | ||||
|         assertEquals(expected, out); | ||||
|     } | ||||
| 
 | ||||
| @ -81,31 +77,25 @@ class ExecutionTest extends BaseSparkTest { | ||||
|         TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1").doubleMathOp("col2", MathOp.Add, 10.0).build(); | ||||
|         List<List<List<Writable>>> inputSequences = new ArrayList<>(); | ||||
|         List<List<Writable>> seq1 = new ArrayList<>(); | ||||
|         seq1.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1))); | ||||
|         seq1.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1))); | ||||
|         seq1.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1))); | ||||
|         seq1.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1))); | ||||
|         seq1.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1))); | ||||
|         seq1.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1))); | ||||
|         List<List<Writable>> seq2 = new ArrayList<>(); | ||||
|         seq2.add(Arrays.<Writable>asList(new IntWritable(3), new Text("state0"), new DoubleWritable(3.1))); | ||||
|         seq2.add(Arrays.<Writable>asList(new IntWritable(4), new Text("state1"), new DoubleWritable(4.1))); | ||||
|         seq2.add(Arrays.asList(new IntWritable(4), new Text("state1"), new DoubleWritable(4.1))); | ||||
|         seq2.add(Arrays.asList(new IntWritable(3), new Text("state0"), new DoubleWritable(3.1))); | ||||
|         inputSequences.add(seq1); | ||||
|         inputSequences.add(seq2); | ||||
|         JavaRDD<List<List<Writable>>> rdd = sc.parallelize(inputSequences); | ||||
|         List<List<List<Writable>>> out = new ArrayList<>(SparkTransformExecutor.executeSequenceToSequence(rdd, tp).collect()); | ||||
|         Collections.sort(out, new Comparator<List<List<Writable>>>() { | ||||
| 
 | ||||
|             @Override | ||||
|             public int compare(List<List<Writable>> o1, List<List<Writable>> o2) { | ||||
|                 return -Integer.compare(o1.size(), o2.size()); | ||||
|             } | ||||
|         }); | ||||
|         Collections.sort(out, (o1, o2) -> -Integer.compare(o1.size(), o2.size())); | ||||
|         List<List<List<Writable>>> expectedSequence = new ArrayList<>(); | ||||
|         List<List<Writable>> seq1e = new ArrayList<>(); | ||||
|         seq1e.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1))); | ||||
|         seq1e.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1))); | ||||
|         seq1e.add(Arrays.<Writable>asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1))); | ||||
|         seq1e.add(Arrays.asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1))); | ||||
|         seq1e.add(Arrays.asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1))); | ||||
|         seq1e.add(Arrays.asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1))); | ||||
|         List<List<Writable>> seq2e = new ArrayList<>(); | ||||
|         seq2e.add(Arrays.<Writable>asList(new IntWritable(3), new IntWritable(0), new DoubleWritable(13.1))); | ||||
|         seq2e.add(Arrays.<Writable>asList(new IntWritable(4), new IntWritable(1), new DoubleWritable(14.1))); | ||||
|         seq2e.add(Arrays.asList(new IntWritable(3), new IntWritable(0), new DoubleWritable(13.1))); | ||||
|         seq2e.add(Arrays.asList(new IntWritable(4), new IntWritable(1), new DoubleWritable(14.1))); | ||||
|         expectedSequence.add(seq1e); | ||||
|         expectedSequence.add(seq2e); | ||||
|         assertEquals(expectedSequence, out); | ||||
| @ -114,34 +104,28 @@ class ExecutionTest extends BaseSparkTest { | ||||
|     @Test | ||||
|     @DisplayName("Test Reduction Global") | ||||
|     void testReductionGlobal() { | ||||
|         List<List<Writable>> in = Arrays.asList(Arrays.<Writable>asList(new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new Text("second"), new DoubleWritable(5.0))); | ||||
|         List<List<Writable>> in = Arrays.asList(Arrays.asList(new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new Text("second"), new DoubleWritable(5.0))); | ||||
|         JavaRDD<List<Writable>> inData = sc.parallelize(in); | ||||
|         Schema s = new Schema.Builder().addColumnString("textCol").addColumnDouble("doubleCol").build(); | ||||
|         TransformProcess tp = new TransformProcess.Builder(s).reduce(new Reducer.Builder(ReduceOp.TakeFirst).takeFirstColumns("textCol").meanColumns("doubleCol").build()).build(); | ||||
|         JavaRDD<List<Writable>> outRdd = SparkTransformExecutor.execute(inData, tp); | ||||
|         List<List<Writable>> out = outRdd.collect(); | ||||
|         List<List<Writable>> expOut = Collections.singletonList(Arrays.<Writable>asList(new Text("first"), new DoubleWritable(4.0))); | ||||
|         List<List<Writable>> expOut = Collections.singletonList(Arrays.asList(new Text("first"), new DoubleWritable(4.0))); | ||||
|         assertEquals(expOut, out); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     @DisplayName("Test Reduction By Key") | ||||
|     void testReductionByKey() { | ||||
|         List<List<Writable>> in = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new IntWritable(0), new Text("second"), new DoubleWritable(5.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(30.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("s"), new DoubleWritable(50.0))); | ||||
|         List<List<Writable>> in = Arrays.asList(Arrays.asList(new IntWritable(0), new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new IntWritable(0), new Text("second"), new DoubleWritable(5.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(30.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("s"), new DoubleWritable(50.0))); | ||||
|         JavaRDD<List<Writable>> inData = sc.parallelize(in); | ||||
|         Schema s = new Schema.Builder().addColumnInteger("intCol").addColumnString("textCol").addColumnDouble("doubleCol").build(); | ||||
|         TransformProcess tp = new TransformProcess.Builder(s).reduce(new Reducer.Builder(ReduceOp.TakeFirst).keyColumns("intCol").takeFirstColumns("textCol").meanColumns("doubleCol").build()).build(); | ||||
|         JavaRDD<List<Writable>> outRdd = SparkTransformExecutor.execute(inData, tp); | ||||
|         List<List<Writable>> out = outRdd.collect(); | ||||
|         List<List<Writable>> expOut = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0))); | ||||
|         List<List<Writable>> expOut = Arrays.asList(Arrays.asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0))); | ||||
|         out = new ArrayList<>(out); | ||||
|         Collections.sort(out, new Comparator<List<Writable>>() { | ||||
| 
 | ||||
|             @Override | ||||
|             public int compare(List<Writable> o1, List<Writable> o2) { | ||||
|                 return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt()); | ||||
|             } | ||||
|         }); | ||||
|         Collections.sort(out, (o1, o2) -> Integer.compare(o1.get(0).toInt(), o2.get(0).toInt())); | ||||
|         assertEquals(expOut, out); | ||||
|     } | ||||
| 
 | ||||
| @ -150,15 +134,15 @@ class ExecutionTest extends BaseSparkTest { | ||||
|     void testUniqueMultiCol() { | ||||
|         Schema schema = new Schema.Builder().addColumnInteger("col0").addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").build(); | ||||
|         List<List<Writable>> inputData = new ArrayList<>(); | ||||
|         inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1))); | ||||
|         inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1))); | ||||
|         inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1))); | ||||
|         inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1))); | ||||
|         inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1))); | ||||
|         inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1))); | ||||
|         inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1))); | ||||
|         inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1))); | ||||
|         inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1))); | ||||
|         inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1))); | ||||
|         inputData.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1))); | ||||
|         inputData.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1))); | ||||
|         inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1))); | ||||
|         inputData.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1))); | ||||
|         inputData.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1))); | ||||
|         inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1))); | ||||
|         inputData.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1))); | ||||
|         inputData.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1))); | ||||
|         JavaRDD<List<Writable>> rdd = sc.parallelize(inputData); | ||||
|         Map<String, List<Writable>> l = AnalyzeSpark.getUnique(Arrays.asList("col0", "col1"), schema, rdd); | ||||
|         assertEquals(2, l.size()); | ||||
| @ -180,58 +164,20 @@ class ExecutionTest extends BaseSparkTest { | ||||
|             String pythonCode = "col1 = ['state0', 'state1', 'state2'].index(col1)\ncol2 += 10.0"; | ||||
|             TransformProcess tp = new TransformProcess.Builder(schema).transform(PythonTransform.builder().code("first = np.sin(first)\nsecond = np.cos(second)").outputSchema(finalSchema).build()).build(); | ||||
|             List<List<Writable>> inputData = new ArrayList<>(); | ||||
|             inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1))); | ||||
|             inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1))); | ||||
|             inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1))); | ||||
|             inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1))); | ||||
|             inputData.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1))); | ||||
|             inputData.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1))); | ||||
|             JavaRDD<List<Writable>> rdd = sc.parallelize(inputData); | ||||
|             List<List<Writable>> out = new ArrayList<>(SparkTransformExecutor.execute(rdd, tp).collect()); | ||||
|             Collections.sort(out, new Comparator<List<Writable>>() { | ||||
| 
 | ||||
|                 @Override | ||||
|                 public int compare(List<Writable> o1, List<Writable> o2) { | ||||
|                     return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt()); | ||||
|                 } | ||||
|             }); | ||||
|             Collections.sort(out, Comparator.comparingInt(o -> o.get(0).toInt())); | ||||
|             List<List<Writable>> expected = new ArrayList<>(); | ||||
|             expected.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1))); | ||||
|             expected.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1))); | ||||
|             expected.add(Arrays.<Writable>asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1))); | ||||
|             expected.add(Arrays.asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1))); | ||||
|             expected.add(Arrays.asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1))); | ||||
|             expected.add(Arrays.asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1))); | ||||
|             assertEquals(expected, out); | ||||
|         }); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     @Disabled("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771") | ||||
|     @DisplayName("Test Python Execution With ND Arrays") | ||||
|     void testPythonExecutionWithNDArrays() { | ||||
|         assertTimeout(ofMillis(60000), () -> { | ||||
|             long[] shape = new long[] { 3, 2 }; | ||||
|             Schema schema = new Schema.Builder().addColumnInteger("id").addColumnNDArray("col1", shape).addColumnNDArray("col2", shape).build(); | ||||
|             Schema finalSchema = new Schema.Builder().addColumnInteger("id").addColumnNDArray("col1", shape).addColumnNDArray("col2", shape).addColumnNDArray("col3", shape).build(); | ||||
|             String pythonCode = "col3 = col1 + col2"; | ||||
|             TransformProcess tp = new TransformProcess.Builder(schema).transform(PythonTransform.builder().code("first = np.sin(first)\nsecond = np.cos(second)").outputSchema(schema).build()).build(); | ||||
|             INDArray zeros = Nd4j.zeros(shape); | ||||
|             INDArray ones = Nd4j.ones(shape); | ||||
|             INDArray twos = ones.add(ones); | ||||
|             List<List<Writable>> inputData = new ArrayList<>(); | ||||
|             inputData.add(Arrays.<Writable>asList(new IntWritable(0), new NDArrayWritable(zeros), new NDArrayWritable(zeros))); | ||||
|             inputData.add(Arrays.<Writable>asList(new IntWritable(1), new NDArrayWritable(zeros), new NDArrayWritable(ones))); | ||||
|             inputData.add(Arrays.<Writable>asList(new IntWritable(2), new NDArrayWritable(ones), new NDArrayWritable(ones))); | ||||
|             JavaRDD<List<Writable>> rdd = sc.parallelize(inputData); | ||||
|             List<List<Writable>> out = new ArrayList<>(SparkTransformExecutor.execute(rdd, tp).collect()); | ||||
|             Collections.sort(out, new Comparator<List<Writable>>() { | ||||
| 
 | ||||
|                 @Override | ||||
|                 public int compare(List<Writable> o1, List<Writable> o2) { | ||||
|                     return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt()); | ||||
|                 } | ||||
|             }); | ||||
|             List<List<Writable>> expected = new ArrayList<>(); | ||||
|             expected.add(Arrays.<Writable>asList(new IntWritable(0), new NDArrayWritable(zeros), new NDArrayWritable(zeros), new NDArrayWritable(zeros))); | ||||
|             expected.add(Arrays.<Writable>asList(new IntWritable(1), new NDArrayWritable(zeros), new NDArrayWritable(ones), new NDArrayWritable(ones))); | ||||
|             expected.add(Arrays.<Writable>asList(new IntWritable(2), new NDArrayWritable(ones), new NDArrayWritable(ones), new NDArrayWritable(twos))); | ||||
|         }); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     @DisplayName("Test First Digit Transform Benfords Law") | ||||
|  | ||||
| @ -28,7 +28,10 @@ import org.datavec.api.util.ndarray.RecordConverter; | ||||
| import org.datavec.api.writable.DoubleWritable; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.datavec.spark.BaseSparkTest; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.tags.NativeTag; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| import org.nd4j.linalg.api.buffer.DataType; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.dataset.DataSet; | ||||
| @ -43,7 +46,10 @@ import java.util.List; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| import static org.junit.jupiter.api.Assertions.assertTrue; | ||||
| 
 | ||||
| @Tag(TagNames.FILE_IO) | ||||
| @Tag(TagNames.SPARK) | ||||
| @Tag(TagNames.DIST_SYSTEMS) | ||||
| @NativeTag | ||||
| public class NormalizationTests extends BaseSparkTest { | ||||
| 
 | ||||
| 
 | ||||
|  | ||||
| @ -38,7 +38,9 @@ import org.datavec.local.transforms.AnalyzeLocal; | ||||
| import org.datavec.spark.BaseSparkTest; | ||||
| import org.datavec.spark.transform.AnalyzeSpark; | ||||
| import org.joda.time.DateTimeZone; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| import org.nd4j.linalg.api.buffer.DataType; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| import org.nd4j.common.io.ClassPathResource; | ||||
| @ -48,7 +50,10 @@ import java.nio.file.Files; | ||||
| import java.util.*; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.*; | ||||
| 
 | ||||
| @Tag(TagNames.FILE_IO) | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.SPARK) | ||||
| @Tag(TagNames.DIST_SYSTEMS) | ||||
| public class TestAnalysis extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|  | ||||
| @ -27,12 +27,17 @@ import org.datavec.api.transform.schema.Schema; | ||||
| import org.datavec.api.writable.*; | ||||
| import org.datavec.spark.BaseSparkTest; | ||||
| import org.datavec.spark.transform.SparkTransformExecutor; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import java.util.*; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| 
 | ||||
| @Tag(TagNames.FILE_IO) | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.SPARK) | ||||
| @Tag(TagNames.DIST_SYSTEMS) | ||||
| public class TestJoin extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|  | ||||
| @ -30,24 +30,29 @@ import org.datavec.api.writable.Writable; | ||||
| import org.datavec.api.writable.comparator.DoubleWritableComparator; | ||||
| import org.datavec.spark.BaseSparkTest; | ||||
| import org.datavec.spark.transform.SparkTransformExecutor; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import java.util.ArrayList; | ||||
| import java.util.Arrays; | ||||
| import java.util.List; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| 
 | ||||
| @Tag(TagNames.FILE_IO) | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.SPARK) | ||||
| @Tag(TagNames.DIST_SYSTEMS) | ||||
| public class TestCalculateSortedRank extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testCalculateSortedRank() { | ||||
| 
 | ||||
|         List<List<Writable>> data = new ArrayList<>(); | ||||
|         data.add(Arrays.asList((Writable) new Text("0"), new DoubleWritable(0.0))); | ||||
|         data.add(Arrays.asList((Writable) new Text("3"), new DoubleWritable(0.3))); | ||||
|         data.add(Arrays.asList((Writable) new Text("2"), new DoubleWritable(0.2))); | ||||
|         data.add(Arrays.asList((Writable) new Text("1"), new DoubleWritable(0.1))); | ||||
|         data.add(Arrays.asList(new Text("0"), new DoubleWritable(0.0))); | ||||
|         data.add(Arrays.asList(new Text("3"), new DoubleWritable(0.3))); | ||||
|         data.add(Arrays.asList(new Text("2"), new DoubleWritable(0.2))); | ||||
|         data.add(Arrays.asList(new Text("1"), new DoubleWritable(0.1))); | ||||
| 
 | ||||
|         JavaRDD<List<Writable>> rdd = sc.parallelize(data); | ||||
| 
 | ||||
|  | ||||
| @ -29,7 +29,9 @@ import org.datavec.api.writable.Text; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.datavec.spark.BaseSparkTest; | ||||
| import org.datavec.spark.transform.SparkTransformExecutor; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import java.util.Arrays; | ||||
| import java.util.Collections; | ||||
| @ -37,7 +39,10 @@ import java.util.List; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| import static org.junit.jupiter.api.Assertions.assertTrue; | ||||
| 
 | ||||
| @Tag(TagNames.FILE_IO) | ||||
| @Tag(TagNames.JAVA_ONLY) | ||||
| @Tag(TagNames.SPARK) | ||||
| @Tag(TagNames.DIST_SYSTEMS) | ||||
| public class TestConvertToSequence extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
| @ -45,13 +50,13 @@ public class TestConvertToSequence extends BaseSparkTest { | ||||
| 
 | ||||
|         Schema s = new Schema.Builder().addColumnsString("key1", "key2").addColumnLong("time").build(); | ||||
| 
 | ||||
|         List<List<Writable>> allExamples = | ||||
|                         Arrays.asList(Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), new LongWritable(10)), | ||||
|                                         Arrays.<Writable>asList(new Text("k1b"), new Text("k2b"), new LongWritable(10)), | ||||
|                                         Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), | ||||
|                                                         new LongWritable(-10)), | ||||
|                                         Arrays.<Writable>asList(new Text("k1b"), new Text("k2b"), new LongWritable(5)), | ||||
|                                         Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), new LongWritable(0))); | ||||
|         List<List<Writable>> allExamples; | ||||
|         allExamples = Arrays.asList(Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(10)), | ||||
|                         Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(10)), | ||||
|                         Arrays.asList(new Text("k1a"), new Text("k2a"), | ||||
|                                         new LongWritable(-10)), | ||||
|                         Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(5)), | ||||
|                         Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(0))); | ||||
| 
 | ||||
|         TransformProcess tp = new TransformProcess.Builder(s) | ||||
|                         .convertToSequence(Arrays.asList("key1", "key2"), new NumericalColumnComparator("time")) | ||||
| @ -73,13 +78,13 @@ public class TestConvertToSequence extends BaseSparkTest { | ||||
|         } | ||||
| 
 | ||||
|         List<List<Writable>> expSeq0 = Arrays.asList( | ||||
|                         Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), new LongWritable(-10)), | ||||
|                         Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), new LongWritable(0)), | ||||
|                         Arrays.<Writable>asList(new Text("k1a"), new Text("k2a"), new LongWritable(10))); | ||||
|                         Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(-10)), | ||||
|                         Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(0)), | ||||
|                         Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(10))); | ||||
| 
 | ||||
|         List<List<Writable>> expSeq1 = Arrays.asList( | ||||
|                         Arrays.<Writable>asList(new Text("k1b"), new Text("k2b"), new LongWritable(5)), | ||||
|                         Arrays.<Writable>asList(new Text("k1b"), new Text("k2b"), new LongWritable(10))); | ||||
|                         Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(5)), | ||||
|                         Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(10))); | ||||
| 
 | ||||
|         assertEquals(expSeq0, seq0); | ||||
|         assertEquals(expSeq1, seq1); | ||||
| @ -94,9 +99,9 @@ public class TestConvertToSequence extends BaseSparkTest { | ||||
|                 .build(); | ||||
| 
 | ||||
|         List<List<Writable>> allExamples = Arrays.asList( | ||||
|                 Arrays.<Writable>asList(new Text("a"), new LongWritable(0)), | ||||
|                 Arrays.<Writable>asList(new Text("b"), new LongWritable(1)), | ||||
|                 Arrays.<Writable>asList(new Text("c"), new LongWritable(2))); | ||||
|                 Arrays.asList(new Text("a"), new LongWritable(0)), | ||||
|                 Arrays.asList(new Text("b"), new LongWritable(1)), | ||||
|                 Arrays.asList(new Text("c"), new LongWritable(2))); | ||||
| 
 | ||||
|         TransformProcess tp = new TransformProcess.Builder(s) | ||||
|                 .convertToSequence() | ||||
|  | ||||
| @ -28,7 +28,9 @@ import org.datavec.api.writable.Text; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.datavec.spark.BaseSparkTest; | ||||
| import org.datavec.spark.transform.utils.SparkUtils; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| import java.io.File; | ||||
| import java.io.FileInputStream; | ||||
| @ -38,6 +40,9 @@ import java.util.List; | ||||
| 
 | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| 
 | ||||
| @Tag(TagNames.SPARK) | ||||
| @Tag(TagNames.FILE_IO) | ||||
| @Tag(TagNames.DIST_SYSTEMS) | ||||
| public class TestSparkUtil extends BaseSparkTest { | ||||
| 
 | ||||
|     @Test | ||||
| @ -46,8 +51,8 @@ public class TestSparkUtil extends BaseSparkTest { | ||||
|            return; | ||||
|        } | ||||
|         List<List<Writable>> l = new ArrayList<>(); | ||||
|         l.add(Arrays.<Writable>asList(new Text("abc"), new DoubleWritable(2.0), new IntWritable(-1))); | ||||
|         l.add(Arrays.<Writable>asList(new Text("def"), new DoubleWritable(4.0), new IntWritable(-2))); | ||||
|         l.add(Arrays.asList(new Text("abc"), new DoubleWritable(2.0), new IntWritable(-1))); | ||||
|         l.add(Arrays.asList(new Text("def"), new DoubleWritable(4.0), new IntWritable(-2))); | ||||
| 
 | ||||
|         File f = File.createTempFile("testSparkUtil", "txt"); | ||||
|         f.deleteOnExit(); | ||||
|  | ||||
| @ -27,8 +27,11 @@ import org.deeplearning4j.nn.conf.MultiLayerConfiguration; | ||||
| import org.deeplearning4j.nn.conf.NeuralNetConfiguration; | ||||
| import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; | ||||
| import org.junit.jupiter.api.Disabled; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.resources.Resources; | ||||
| import org.nd4j.common.tests.tags.NativeTag; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| import org.nd4j.linalg.activations.Activation; | ||||
| import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| @ -39,6 +42,8 @@ import java.nio.file.Files; | ||||
| import java.util.concurrent.CountDownLatch; | ||||
| 
 | ||||
| @Disabled | ||||
| @NativeTag | ||||
| @Tag(TagNames.RNG) | ||||
| public class RandomTests extends BaseDL4JTest { | ||||
| 
 | ||||
|     @Test | ||||
|  | ||||
| @ -23,11 +23,11 @@ import org.deeplearning4j.BaseDL4JTest; | ||||
| import org.deeplearning4j.datasets.base.MnistFetcher; | ||||
| import org.deeplearning4j.common.resources.DL4JResources; | ||||
| import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; | ||||
| import org.junit.jupiter.api.AfterAll; | ||||
| import org.junit.jupiter.api.BeforeAll; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.junit.jupiter.api.*; | ||||
| import org.junit.jupiter.api.io.TempDir; | ||||
| 
 | ||||
| import org.nd4j.common.tests.tags.NativeTag; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition; | ||||
| import org.nd4j.linalg.dataset.DataSet; | ||||
| @ -41,10 +41,13 @@ import java.util.Set; | ||||
| import static org.junit.jupiter.api.Assertions.assertEquals; | ||||
| import static org.junit.jupiter.api.Assertions.assertFalse; | ||||
| import static org.junit.jupiter.api.Assertions.assertTrue; | ||||
| import org.junit.jupiter.api.DisplayName; | ||||
| 
 | ||||
| import org.junit.jupiter.api.extension.ExtendWith; | ||||
| 
 | ||||
| @DisplayName("Mnist Fetcher Test") | ||||
| @NativeTag | ||||
| @Tag(TagNames.FILE_IO) | ||||
| @Tag(TagNames.NDARRAY_ETL) | ||||
| class MnistFetcherTest extends BaseDL4JTest { | ||||
| 
 | ||||
| 
 | ||||
|  | ||||
| @ -23,8 +23,14 @@ package org.deeplearning4j.datasets; | ||||
| import org.deeplearning4j.BaseDL4JTest; | ||||
| import org.deeplearning4j.datasets.fetchers.Cifar10Fetcher; | ||||
| import org.deeplearning4j.datasets.fetchers.TinyImageNetFetcher; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.tags.NativeTag; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| @NativeTag | ||||
| @Tag(TagNames.FILE_IO) | ||||
| @Tag(TagNames.NDARRAY_ETL) | ||||
| public class TestDataSets extends BaseDL4JTest { | ||||
| 
 | ||||
|     @Override | ||||
|  | ||||
| @ -20,8 +20,11 @@ | ||||
| package org.deeplearning4j.datasets.datavec; | ||||
| 
 | ||||
| 
 | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.params.ParameterizedTest; | ||||
| import org.junit.jupiter.params.provider.MethodSource; | ||||
| import org.nd4j.common.tests.tags.NativeTag; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| import org.nd4j.linalg.factory.Nd4jBackend; | ||||
| import org.nd4j.shade.guava.io.Files; | ||||
| import lombok.extern.slf4j.Slf4j; | ||||
| @ -76,6 +79,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows; | ||||
| @Slf4j | ||||
| @DisplayName("Record Reader Data Setiterator Test") | ||||
| @Disabled | ||||
| @NativeTag | ||||
| class RecordReaderDataSetiteratorTest extends BaseDL4JTest { | ||||
| 
 | ||||
|     @Override | ||||
| @ -148,6 +152,7 @@ class RecordReaderDataSetiteratorTest extends BaseDL4JTest { | ||||
|     @ParameterizedTest | ||||
|     @MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs") | ||||
|     @DisplayName("Test Sequence Record Reader") | ||||
|     @Tag(TagNames.NDARRAY_INDEXING) | ||||
|     void testSequenceRecordReader(Nd4jBackend backend) throws Exception { | ||||
|         File rootDir = temporaryFolder.toFile(); | ||||
|         // need to manually extract | ||||
|  | ||||
| @ -21,6 +21,8 @@ package org.deeplearning4j.datasets.datavec; | ||||
| 
 | ||||
| 
 | ||||
| import org.junit.jupiter.api.Disabled; | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| import org.nd4j.shade.guava.io.Files; | ||||
| import org.apache.commons.io.FileUtils; | ||||
| import org.apache.commons.io.FilenameUtils; | ||||
| @ -70,6 +72,7 @@ import org.junit.jupiter.api.extension.ExtendWith; | ||||
| 
 | ||||
| @DisplayName("Record Reader Multi Data Set Iterator Test") | ||||
| @Disabled | ||||
| @Tag(TagNames.FILE_IO) | ||||
| class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest { | ||||
| 
 | ||||
|     @TempDir | ||||
|  | ||||
| @ -21,6 +21,7 @@ package org.deeplearning4j.datasets.fetchers; | ||||
| 
 | ||||
| import org.deeplearning4j.BaseDL4JTest; | ||||
| 
 | ||||
| import org.junit.jupiter.api.Tag; | ||||
| import org.junit.jupiter.api.Test; | ||||
| 
 | ||||
| import java.io.File; | ||||
| @ -28,11 +29,15 @@ import static org.junit.jupiter.api.Assertions.assertTrue; | ||||
| import static org.junit.jupiter.api.Assumptions.assumeTrue; | ||||
| import org.junit.jupiter.api.DisplayName; | ||||
| import org.junit.jupiter.api.extension.ExtendWith; | ||||
| import org.nd4j.common.tests.tags.NativeTag; | ||||
| import org.nd4j.common.tests.tags.TagNames; | ||||
| 
 | ||||
| /** | ||||
|  * @author saudet | ||||
|  */ | ||||
| @DisplayName("Svhn Data Fetcher Test") | ||||
| @Tag(TagNames.FILE_IO) | ||||
| @NativeTag | ||||
| class SvhnDataFetcherTest extends BaseDL4JTest { | ||||
| 
 | ||||
|     @Override | ||||
|  | ||||
| @ -26,6 +26,7 @@ import org.deeplearning4j.datasets.iterator.tools.VariableTimeseriesGenerator; | ||||
| import org.deeplearning4j.nn.util.TestDataSetConsumer; | ||||
| import org.junit.jupiter.api.BeforeEach; | ||||
| import org.junit.jupiter.api.Test; | ||||
| import org.nd4j.common.tests.tags.NativeTag; | ||||
| import org.nd4j.linalg.dataset.DataSet; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| import java.util.ArrayList; | ||||
| @ -40,6 +41,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows; | ||||
| 
 | ||||
| @Slf4j | ||||
| @DisplayName("Async Data Set Iterator Test") | ||||
| @NativeTag | ||||
| class AsyncDataSetIteratorTest extends BaseDL4JTest { | ||||
| 
 | ||||
|     private ExistingDataSetIterator backIterator; | ||||
|  | ||||
Some files were not shown because too many files have changed in this diff Show More
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user