diff --git a/contrib/codegen-tools/codegen/pom.xml b/contrib/codegen-tools/codegen/pom.xml
index cbd00a825..5f367d8e4 100644
--- a/contrib/codegen-tools/codegen/pom.xml
+++ b/contrib/codegen-tools/codegen/pom.xml
@@ -15,7 +15,7 @@
1.7
1.18.8
1.1.7
- 4.12
+ 5.8.0-M1
5.4.2
1.8
3.1.1
diff --git a/contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/ir/SerializationTest.java b/contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/ir/SerializationTest.java
index f41bd93a6..cbe4e265c 100644
--- a/contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/ir/SerializationTest.java
+++ b/contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/ir/SerializationTest.java
@@ -17,13 +17,14 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
-
package org.nd4j.codegen.ir;
-public class SerializationTest {
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.extension.ExtendWith;
- public static void main(String...args) {
+@DisplayName("Serialization Test")
+class SerializationTest {
+ public static void main(String... args) {
}
-
}
diff --git a/contrib/codegen-tools/codegen/src/test/java/org/nd4j/codegen/dsl/DocsGeneratorTest.java b/contrib/codegen-tools/codegen/src/test/java/org/nd4j/codegen/dsl/DocsGeneratorTest.java
index 5d8e12885..7eeef5717 100644
--- a/contrib/codegen-tools/codegen/src/test/java/org/nd4j/codegen/dsl/DocsGeneratorTest.java
+++ b/contrib/codegen-tools/codegen/src/test/java/org/nd4j/codegen/dsl/DocsGeneratorTest.java
@@ -17,29 +17,23 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
-
package org.nd4j.codegen.dsl;
import org.apache.commons.lang3.StringUtils;
import org.junit.jupiter.api.Test;
import org.nd4j.codegen.impl.java.DocsGenerator;
-
import static org.junit.jupiter.api.Assertions.assertEquals;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.extension.ExtendWith;
-public class DocsGeneratorTest {
+@DisplayName("Docs Generator Test")
+class DocsGeneratorTest {
@Test
- public void testJDtoMDAdapter() {
- String original = "{@code %INPUT_TYPE% eye = eye(3,2)\n" +
- " eye:\n" +
- " [ 1, 0]\n" +
- " [ 0, 1]\n" +
- " [ 0, 0]}";
- String expected = "{ INDArray eye = eye(3,2)\n" +
- " eye:\n" +
- " [ 1, 0]\n" +
- " [ 0, 1]\n" +
- " [ 0, 0]}";
+ @DisplayName("Test J Dto MD Adapter")
+ void testJDtoMDAdapter() {
+ String original = "{@code %INPUT_TYPE% eye = eye(3,2)\n" + " eye:\n" + " [ 1, 0]\n" + " [ 0, 1]\n" + " [ 0, 0]}";
+ String expected = "{ INDArray eye = eye(3,2)\n" + " eye:\n" + " [ 1, 0]\n" + " [ 0, 1]\n" + " [ 0, 0]}";
DocsGenerator.JavaDocToMDAdapter adapter = new DocsGenerator.JavaDocToMDAdapter(original);
String out = adapter.filter("@code", StringUtils.EMPTY).filter("%INPUT_TYPE%", "INDArray").toString();
assertEquals(out, expected);
diff --git a/datavec/datavec-api/pom.xml b/datavec/datavec-api/pom.xml
index d7fcf5a47..fc091c5dd 100644
--- a/datavec/datavec-api/pom.xml
+++ b/datavec/datavec-api/pom.xml
@@ -34,6 +34,14 @@
datavec-api
+
+ org.junit.jupiter
+ junit-jupiter-api
+
+
+ org.junit.vintage
+ junit-vintage-engine
+
org.apache.commons
commons-lang3
diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVLineSequenceRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVLineSequenceRecordReaderTest.java
index 7aef92158..5ce4cb254 100644
--- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVLineSequenceRecordReaderTest.java
+++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVLineSequenceRecordReaderTest.java
@@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
-
package org.datavec.api.records.reader.impl;
import org.apache.commons.io.FileUtils;
@@ -27,46 +26,37 @@ import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Rule;
-import org.junit.Test;
-import org.junit.rules.TemporaryFolder;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
-
import java.io.File;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import org.junit.jupiter.api.DisplayName;
+import java.nio.file.Path;
+import org.junit.jupiter.api.extension.ExtendWith;
-import static org.junit.Assert.assertEquals;
+@DisplayName("Csv Line Sequence Record Reader Test")
+class CSVLineSequenceRecordReaderTest extends BaseND4JTest {
-public class CSVLineSequenceRecordReaderTest extends BaseND4JTest {
-
- @Rule
- public TemporaryFolder testDir = new TemporaryFolder();
+ @TempDir
+ public Path testDir;
@Test
- public void test() throws Exception {
-
- File f = testDir.newFolder();
+ @DisplayName("Test")
+ void test(@TempDir Path testDir) throws Exception {
+ File f = testDir.toFile();
File source = new File(f, "temp.csv");
String str = "a,b,c\n1,2,3,4";
FileUtils.writeStringToFile(source, str, StandardCharsets.UTF_8);
-
SequenceRecordReader rr = new CSVLineSequenceRecordReader();
rr.initialize(new FileSplit(source));
-
- List> exp0 = Arrays.asList(
- Collections.singletonList(new Text("a")),
- Collections.singletonList(new Text("b")),
- Collections.singletonList(new Text("c")));
-
- List> exp1 = Arrays.asList(
- Collections.singletonList(new Text("1")),
- Collections.singletonList(new Text("2")),
- Collections.singletonList(new Text("3")),
- Collections.singletonList(new Text("4")));
-
- for( int i=0; i<3; i++ ) {
+ List> exp0 = Arrays.asList(Collections.singletonList(new Text("a")), Collections.singletonList(new Text("b")), Collections.singletonList(new Text("c")));
+ List> exp1 = Arrays.asList(Collections.singletonList(new Text("1")), Collections.singletonList(new Text("2")), Collections.singletonList(new Text("3")), Collections.singletonList(new Text("4")));
+ for (int i = 0; i < 3; i++) {
int count = 0;
while (rr.hasNext()) {
List> next = rr.sequenceRecord();
@@ -76,9 +66,7 @@ public class CSVLineSequenceRecordReaderTest extends BaseND4JTest {
assertEquals(exp1, next);
}
}
-
assertEquals(2, count);
-
rr.reset();
}
}
diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVMultiSequenceRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVMultiSequenceRecordReaderTest.java
index f78676627..f108a4438 100644
--- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVMultiSequenceRecordReaderTest.java
+++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVMultiSequenceRecordReaderTest.java
@@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
-
package org.datavec.api.records.reader.impl;
import org.apache.commons.io.FileUtils;
@@ -27,32 +26,34 @@ import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Rule;
-import org.junit.Test;
-import org.junit.rules.TemporaryFolder;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
-
import java.io.File;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import org.junit.jupiter.api.DisplayName;
+import java.nio.file.Path;
+import org.junit.jupiter.api.extension.ExtendWith;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
+@DisplayName("Csv Multi Sequence Record Reader Test")
+class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
-public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
-
- @Rule
- public TemporaryFolder testDir = new TemporaryFolder();
+ @TempDir
+ public Path testDir;
@Test
- public void testConcatMode() throws Exception {
- for( int i=0; i<3; i++ ) {
-
+ @DisplayName("Test Concat Mode")
+ void testConcatMode() throws Exception {
+ for (int i = 0; i < 3; i++) {
String seqSep;
String seqSepRegex;
- switch (i){
+ switch(i) {
case 0:
seqSep = "";
seqSepRegex = "^$";
@@ -68,31 +69,23 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
default:
throw new RuntimeException();
}
-
String str = "a,b,c\n1,2,3,4\nx,y\n" + seqSep + "\nA,B,C";
- File f = testDir.newFile();
+ File f = testDir.toFile();
FileUtils.writeStringToFile(f, str, StandardCharsets.UTF_8);
-
SequenceRecordReader seqRR = new CSVMultiSequenceRecordReader(seqSepRegex, CSVMultiSequenceRecordReader.Mode.CONCAT);
seqRR.initialize(new FileSplit(f));
-
-
List> exp0 = new ArrayList<>();
for (String s : "a,b,c,1,2,3,4,x,y".split(",")) {
exp0.add(Collections.singletonList(new Text(s)));
}
-
List> exp1 = new ArrayList<>();
for (String s : "A,B,C".split(",")) {
exp1.add(Collections.singletonList(new Text(s)));
}
-
assertEquals(exp0, seqRR.sequenceRecord());
assertEquals(exp1, seqRR.sequenceRecord());
assertFalse(seqRR.hasNext());
-
seqRR.reset();
-
assertEquals(exp0, seqRR.sequenceRecord());
assertEquals(exp1, seqRR.sequenceRecord());
assertFalse(seqRR.hasNext());
@@ -100,13 +93,12 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
}
@Test
- public void testEqualLength() throws Exception {
-
- for( int i=0; i<3; i++ ) {
-
+ @DisplayName("Test Equal Length")
+ void testEqualLength() throws Exception {
+ for (int i = 0; i < 3; i++) {
String seqSep;
String seqSepRegex;
- switch (i) {
+ switch(i) {
case 0:
seqSep = "";
seqSepRegex = "^$";
@@ -122,27 +114,17 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
default:
throw new RuntimeException();
}
-
String str = "a,b\n1,2\nx,y\n" + seqSep + "\nA\nB\nC";
- File f = testDir.newFile();
+ File f = testDir.toFile();
FileUtils.writeStringToFile(f, str, StandardCharsets.UTF_8);
-
SequenceRecordReader seqRR = new CSVMultiSequenceRecordReader(seqSepRegex, CSVMultiSequenceRecordReader.Mode.EQUAL_LENGTH);
seqRR.initialize(new FileSplit(f));
-
-
- List> exp0 = Arrays.asList(
- Arrays.asList(new Text("a"), new Text("1"), new Text("x")),
- Arrays.asList(new Text("b"), new Text("2"), new Text("y")));
-
+ List> exp0 = Arrays.asList(Arrays.asList(new Text("a"), new Text("1"), new Text("x")), Arrays.asList(new Text("b"), new Text("2"), new Text("y")));
List> exp1 = Collections.singletonList(Arrays.asList(new Text("A"), new Text("B"), new Text("C")));
-
assertEquals(exp0, seqRR.sequenceRecord());
assertEquals(exp1, seqRR.sequenceRecord());
assertFalse(seqRR.hasNext());
-
seqRR.reset();
-
assertEquals(exp0, seqRR.sequenceRecord());
assertEquals(exp1, seqRR.sequenceRecord());
assertFalse(seqRR.hasNext());
@@ -150,13 +132,12 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
}
@Test
- public void testPadding() throws Exception {
-
- for( int i=0; i<3; i++ ) {
-
+ @DisplayName("Test Padding")
+ void testPadding() throws Exception {
+ for (int i = 0; i < 3; i++) {
String seqSep;
String seqSepRegex;
- switch (i) {
+ switch(i) {
case 0:
seqSep = "";
seqSepRegex = "^$";
@@ -172,27 +153,17 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
default:
throw new RuntimeException();
}
-
String str = "a,b\n1\nx\n" + seqSep + "\nA\nB\nC";
- File f = testDir.newFile();
+ File f = testDir.toFile();
FileUtils.writeStringToFile(f, str, StandardCharsets.UTF_8);
-
SequenceRecordReader seqRR = new CSVMultiSequenceRecordReader(seqSepRegex, CSVMultiSequenceRecordReader.Mode.PAD, new Text("PAD"));
seqRR.initialize(new FileSplit(f));
-
-
- List> exp0 = Arrays.asList(
- Arrays.asList(new Text("a"), new Text("1"), new Text("x")),
- Arrays.asList(new Text("b"), new Text("PAD"), new Text("PAD")));
-
+ List> exp0 = Arrays.asList(Arrays.asList(new Text("a"), new Text("1"), new Text("x")), Arrays.asList(new Text("b"), new Text("PAD"), new Text("PAD")));
List> exp1 = Collections.singletonList(Arrays.asList(new Text("A"), new Text("B"), new Text("C")));
-
assertEquals(exp0, seqRR.sequenceRecord());
assertEquals(exp1, seqRR.sequenceRecord());
assertFalse(seqRR.hasNext());
-
seqRR.reset();
-
assertEquals(exp0, seqRR.sequenceRecord());
assertEquals(exp1, seqRR.sequenceRecord());
assertFalse(seqRR.hasNext());
diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVNLinesSequenceRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVNLinesSequenceRecordReaderTest.java
index 80c75c830..184462c8c 100644
--- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVNLinesSequenceRecordReaderTest.java
+++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVNLinesSequenceRecordReaderTest.java
@@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
-
package org.datavec.api.records.reader.impl;
import org.datavec.api.records.SequenceRecord;
@@ -27,61 +26,53 @@ import org.datavec.api.records.reader.impl.csv.CSVNLinesSequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable;
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
-
import java.util.ArrayList;
import java.util.List;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.extension.ExtendWith;
-import static org.junit.Assert.assertEquals;
-
-public class CSVNLinesSequenceRecordReaderTest extends BaseND4JTest {
+@DisplayName("Csvn Lines Sequence Record Reader Test")
+class CSVNLinesSequenceRecordReaderTest extends BaseND4JTest {
@Test
- public void testCSVNLinesSequenceRecordReader() throws Exception {
+ @DisplayName("Test CSVN Lines Sequence Record Reader")
+ void testCSVNLinesSequenceRecordReader() throws Exception {
int nLinesPerSequence = 10;
-
SequenceRecordReader seqRR = new CSVNLinesSequenceRecordReader(nLinesPerSequence);
seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
-
CSVRecordReader rr = new CSVRecordReader();
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
-
int count = 0;
while (seqRR.hasNext()) {
List> next = seqRR.sequenceRecord();
-
List> expected = new ArrayList<>();
for (int i = 0; i < nLinesPerSequence; i++) {
expected.add(rr.next());
}
-
assertEquals(10, next.size());
assertEquals(expected, next);
-
count++;
}
-
assertEquals(150 / nLinesPerSequence, count);
}
@Test
- public void testCSVNlinesSequenceRecordReaderMetaData() throws Exception {
+ @DisplayName("Test CSV Nlines Sequence Record Reader Meta Data")
+ void testCSVNlinesSequenceRecordReaderMetaData() throws Exception {
int nLinesPerSequence = 10;
-
SequenceRecordReader seqRR = new CSVNLinesSequenceRecordReader(nLinesPerSequence);
seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
-
CSVRecordReader rr = new CSVRecordReader();
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
-
List>> out = new ArrayList<>();
while (seqRR.hasNext()) {
List> next = seqRR.sequenceRecord();
out.add(next);
}
-
seqRR.reset();
List>> out2 = new ArrayList<>();
List out3 = new ArrayList<>();
@@ -92,11 +83,8 @@ public class CSVNLinesSequenceRecordReaderTest extends BaseND4JTest {
meta.add(seq.getMetaData());
out3.add(seq);
}
-
assertEquals(out, out2);
-
List out4 = seqRR.loadSequenceFromMetaData(meta);
assertEquals(out3, out4);
}
-
}
diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVRecordReaderTest.java
index 85f20f3ad..c7e840c42 100644
--- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVRecordReaderTest.java
+++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVRecordReaderTest.java
@@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
-
package org.datavec.api.records.reader.impl;
import org.apache.commons.io.FileUtils;
@@ -34,10 +33,10 @@ import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
-import org.junit.Test;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
-
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
@@ -47,41 +46,44 @@ import java.util.Arrays;
import java.util.List;
import java.util.NoSuchElementException;
-import static org.junit.Assert.*;
+import static org.junit.jupiter.api.Assertions.*;
+
+
+@DisplayName("Csv Record Reader Test")
+class CSVRecordReaderTest extends BaseND4JTest {
-public class CSVRecordReaderTest extends BaseND4JTest {
@Test
- public void testNext() throws Exception {
+ @DisplayName("Test Next")
+ void testNext() throws Exception {
CSVRecordReader reader = new CSVRecordReader();
reader.initialize(new StringSplit("1,1,8.0,,,,14.0,,,,15.0,,,,,,,,,,,,1"));
while (reader.hasNext()) {
List vals = reader.next();
List arr = new ArrayList<>(vals);
-
- assertEquals("Entry count", 23, vals.size());
+ assertEquals(23, vals.size(), "Entry count");
Text lastEntry = (Text) arr.get(arr.size() - 1);
- assertEquals("Last entry garbage", 1, lastEntry.getLength());
+ assertEquals(1, lastEntry.getLength(), "Last entry garbage");
}
}
@Test
- public void testEmptyEntries() throws Exception {
+ @DisplayName("Test Empty Entries")
+ void testEmptyEntries() throws Exception {
CSVRecordReader reader = new CSVRecordReader();
reader.initialize(new StringSplit("1,1,8.0,,,,14.0,,,,15.0,,,,,,,,,,,,"));
while (reader.hasNext()) {
List vals = reader.next();
- assertEquals("Entry count", 23, vals.size());
+ assertEquals(23, vals.size(), "Entry count");
}
}
@Test
- public void testReset() throws Exception {
+ @DisplayName("Test Reset")
+ void testReset() throws Exception {
CSVRecordReader rr = new CSVRecordReader(0, ',');
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
-
int nResets = 5;
for (int i = 0; i < nResets; i++) {
-
int lineCount = 0;
while (rr.hasNext()) {
List line = rr.next();
@@ -95,7 +97,8 @@ public class CSVRecordReaderTest extends BaseND4JTest {
}
@Test
- public void testResetWithSkipLines() throws Exception {
+ @DisplayName("Test Reset With Skip Lines")
+ void testResetWithSkipLines() throws Exception {
CSVRecordReader rr = new CSVRecordReader(10, ',');
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
int lineCount = 0;
@@ -114,7 +117,8 @@ public class CSVRecordReaderTest extends BaseND4JTest {
}
@Test
- public void testWrite() throws Exception {
+ @DisplayName("Test Write")
+ void testWrite() throws Exception {
List> list = new ArrayList<>();
StringBuilder sb = new StringBuilder();
for (int i = 0; i < 10; i++) {
@@ -130,81 +134,72 @@ public class CSVRecordReaderTest extends BaseND4JTest {
}
list.add(temp);
}
-
String expected = sb.toString();
-
Path p = Files.createTempFile("csvwritetest", "csv");
p.toFile().deleteOnExit();
-
FileRecordWriter writer = new CSVRecordWriter();
FileSplit fileSplit = new FileSplit(p.toFile());
- writer.initialize(fileSplit,new NumberOfRecordsPartitioner());
+ writer.initialize(fileSplit, new NumberOfRecordsPartitioner());
for (List c : list) {
writer.write(c);
}
writer.close();
-
- //Read file back in; compare
+ // Read file back in; compare
String fileContents = FileUtils.readFileToString(p.toFile(), FileRecordWriter.DEFAULT_CHARSET.name());
-
- // System.out.println(expected);
- // System.out.println("----------");
- // System.out.println(fileContents);
-
+ // System.out.println(expected);
+ // System.out.println("----------");
+ // System.out.println(fileContents);
assertEquals(expected, fileContents);
}
@Test
- public void testTabsAsSplit1() throws Exception {
-
+ @DisplayName("Test Tabs As Split 1")
+ void testTabsAsSplit1() throws Exception {
CSVRecordReader reader = new CSVRecordReader(0, '\t');
reader.initialize(new FileSplit(new ClassPathResource("datavec-api/tabbed.txt").getFile()));
while (reader.hasNext()) {
List list = new ArrayList<>(reader.next());
-
assertEquals(2, list.size());
}
}
@Test
- public void testPipesAsSplit() throws Exception {
-
+ @DisplayName("Test Pipes As Split")
+ void testPipesAsSplit() throws Exception {
CSVRecordReader reader = new CSVRecordReader(0, '|');
reader.initialize(new FileSplit(new ClassPathResource("datavec-api/issue414.csv").getFile()));
int lineidx = 0;
List sixthColumn = Arrays.asList(13, 95, 15, 25);
while (reader.hasNext()) {
List list = new ArrayList<>(reader.next());
-
assertEquals(10, list.size());
- assertEquals((long)sixthColumn.get(lineidx), list.get(5).toInt());
+ assertEquals((long) sixthColumn.get(lineidx), list.get(5).toInt());
lineidx++;
}
}
-
@Test
- public void testWithQuotes() throws Exception {
+ @DisplayName("Test With Quotes")
+ void testWithQuotes() throws Exception {
CSVRecordReader reader = new CSVRecordReader(0, ',', '\"');
reader.initialize(new StringSplit("1,0,3,\"Braund, Mr. Owen Harris\",male,\"\"\"\""));
while (reader.hasNext()) {
List vals = reader.next();
- assertEquals("Entry count", 6, vals.size());
- assertEquals("1", vals.get(0).toString());
- assertEquals("0", vals.get(1).toString());
- assertEquals("3", vals.get(2).toString());
- assertEquals("Braund, Mr. Owen Harris", vals.get(3).toString());
- assertEquals("male", vals.get(4).toString());
- assertEquals("\"", vals.get(5).toString());
+ assertEquals(6, vals.size(), "Entry count");
+ assertEquals(vals.get(0).toString(), "1");
+ assertEquals(vals.get(1).toString(), "0");
+ assertEquals(vals.get(2).toString(), "3");
+ assertEquals(vals.get(3).toString(), "Braund, Mr. Owen Harris");
+ assertEquals(vals.get(4).toString(), "male");
+ assertEquals(vals.get(5).toString(), "\"");
}
}
-
@Test
- public void testMeta() throws Exception {
+ @DisplayName("Test Meta")
+ void testMeta() throws Exception {
CSVRecordReader rr = new CSVRecordReader(0, ',');
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
-
int lineCount = 0;
List metaList = new ArrayList<>();
List> writables = new ArrayList<>();
@@ -213,30 +208,25 @@ public class CSVRecordReaderTest extends BaseND4JTest {
assertEquals(5, r.getRecord().size());
lineCount++;
RecordMetaData meta = r.getMetaData();
- // System.out.println(r.getRecord() + "\t" + meta.getLocation() + "\t" + meta.getURI());
-
+ // System.out.println(r.getRecord() + "\t" + meta.getLocation() + "\t" + meta.getURI());
metaList.add(meta);
writables.add(r.getRecord());
}
assertFalse(rr.hasNext());
assertEquals(150, lineCount);
rr.reset();
-
-
System.out.println("\n\n\n--------------------------------");
List contents = rr.loadFromMetaData(metaList);
assertEquals(150, contents.size());
- // for(Record r : contents ){
- // System.out.println(r);
- // }
-
+ // for(Record r : contents ){
+ // System.out.println(r);
+ // }
List meta2 = new ArrayList<>();
meta2.add(metaList.get(100));
meta2.add(metaList.get(90));
meta2.add(metaList.get(80));
meta2.add(metaList.get(70));
meta2.add(metaList.get(60));
-
List contents2 = rr.loadFromMetaData(meta2);
assertEquals(writables.get(100), contents2.get(0).getRecord());
assertEquals(writables.get(90), contents2.get(1).getRecord());
@@ -246,50 +236,49 @@ public class CSVRecordReaderTest extends BaseND4JTest {
}
@Test
- public void testRegex() throws Exception {
- CSVRecordReader reader = new CSVRegexRecordReader(0, ",", null, new String[] {null, "(.+) (.+) (.+)"});
+ @DisplayName("Test Regex")
+ void testRegex() throws Exception {
+ CSVRecordReader reader = new CSVRegexRecordReader(0, ",", null, new String[] { null, "(.+) (.+) (.+)" });
reader.initialize(new StringSplit("normal,1.2.3.4 space separator"));
while (reader.hasNext()) {
List vals = reader.next();
- assertEquals("Entry count", 4, vals.size());
- assertEquals("normal", vals.get(0).toString());
- assertEquals("1.2.3.4", vals.get(1).toString());
- assertEquals("space", vals.get(2).toString());
- assertEquals("separator", vals.get(3).toString());
+ assertEquals(4, vals.size(), "Entry count");
+ assertEquals(vals.get(0).toString(), "normal");
+ assertEquals(vals.get(1).toString(), "1.2.3.4");
+ assertEquals(vals.get(2).toString(), "space");
+ assertEquals(vals.get(3).toString(), "separator");
}
}
- @Test(expected = NoSuchElementException.class)
- public void testCsvSkipAllLines() throws IOException, InterruptedException {
- final int numLines = 4;
- final List lineList = Arrays.asList((Writable) new IntWritable(numLines - 1),
- (Writable) new Text("one"), (Writable) new Text("two"), (Writable) new Text("three"));
- String header = ",one,two,three";
- List lines = new ArrayList<>();
- for (int i = 0; i < numLines; i++)
- lines.add(Integer.toString(i) + header);
- File tempFile = File.createTempFile("csvSkipLines", ".csv");
- FileUtils.writeLines(tempFile, lines);
-
- CSVRecordReader rr = new CSVRecordReader(numLines, ',');
- rr.initialize(new FileSplit(tempFile));
- rr.reset();
- assertTrue(!rr.hasNext());
- rr.next();
+ @Test
+ @DisplayName("Test Csv Skip All Lines")
+ void testCsvSkipAllLines() {
+ assertThrows(NoSuchElementException.class, () -> {
+ final int numLines = 4;
+ final List lineList = Arrays.asList((Writable) new IntWritable(numLines - 1), (Writable) new Text("one"), (Writable) new Text("two"), (Writable) new Text("three"));
+ String header = ",one,two,three";
+ List lines = new ArrayList<>();
+ for (int i = 0; i < numLines; i++) lines.add(Integer.toString(i) + header);
+ File tempFile = File.createTempFile("csvSkipLines", ".csv");
+ FileUtils.writeLines(tempFile, lines);
+ CSVRecordReader rr = new CSVRecordReader(numLines, ',');
+ rr.initialize(new FileSplit(tempFile));
+ rr.reset();
+ assertTrue(!rr.hasNext());
+ rr.next();
+ });
}
@Test
- public void testCsvSkipAllButOneLine() throws IOException, InterruptedException {
+ @DisplayName("Test Csv Skip All But One Line")
+ void testCsvSkipAllButOneLine() throws IOException, InterruptedException {
final int numLines = 4;
- final List lineList = Arrays.asList(new Text(Integer.toString(numLines - 1)),
- new Text("one"), new Text("two"), new Text("three"));
+ final List lineList = Arrays.asList(new Text(Integer.toString(numLines - 1)), new Text("one"), new Text("two"), new Text("three"));
String header = ",one,two,three";
List lines = new ArrayList<>();
- for (int i = 0; i < numLines; i++)
- lines.add(Integer.toString(i) + header);
+ for (int i = 0; i < numLines; i++) lines.add(Integer.toString(i) + header);
File tempFile = File.createTempFile("csvSkipLines", ".csv");
FileUtils.writeLines(tempFile, lines);
-
CSVRecordReader rr = new CSVRecordReader(numLines - 1, ',');
rr.initialize(new FileSplit(tempFile));
rr.reset();
@@ -297,50 +286,45 @@ public class CSVRecordReaderTest extends BaseND4JTest {
assertEquals(rr.next(), lineList);
}
-
@Test
- public void testStreamReset() throws Exception {
+ @DisplayName("Test Stream Reset")
+ void testStreamReset() throws Exception {
CSVRecordReader rr = new CSVRecordReader(0, ',');
rr.initialize(new InputStreamInputSplit(new ClassPathResource("datavec-api/iris.dat").getInputStream()));
-
int count = 0;
- while(rr.hasNext()){
+ while (rr.hasNext()) {
assertNotNull(rr.next());
count++;
}
assertEquals(150, count);
-
assertFalse(rr.resetSupported());
-
- try{
+ try {
rr.reset();
fail("Expected exception");
- } catch (Exception e){
+ } catch (Exception e) {
String msg = e.getMessage();
String msg2 = e.getCause().getMessage();
- assertTrue(msg, msg.contains("Error during LineRecordReader reset"));
- assertTrue(msg2, msg2.contains("Reset not supported from streams"));
-// e.printStackTrace();
+ assertTrue(msg.contains("Error during LineRecordReader reset"),msg);
+ assertTrue(msg2.contains("Reset not supported from streams"),msg2);
+ // e.printStackTrace();
}
}
@Test
- public void testUsefulExceptionNoInit(){
-
+ @DisplayName("Test Useful Exception No Init")
+ void testUsefulExceptionNoInit() {
CSVRecordReader rr = new CSVRecordReader(0, ',');
-
- try{
+ try {
rr.hasNext();
fail("Expected exception");
- } catch (Exception e){
- assertTrue(e.getMessage(), e.getMessage().contains("initialized"));
+ } catch (Exception e) {
+ assertTrue( e.getMessage().contains("initialized"),e.getMessage());
}
-
- try{
+ try {
rr.next();
fail("Expected exception");
- } catch (Exception e){
- assertTrue(e.getMessage(), e.getMessage().contains("initialized"));
+ } catch (Exception e) {
+ assertTrue(e.getMessage().contains("initialized"),e.getMessage());
}
}
}
diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVSequenceRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVSequenceRecordReaderTest.java
index 70a774165..e022746e0 100644
--- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVSequenceRecordReaderTest.java
+++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVSequenceRecordReaderTest.java
@@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
-
package org.datavec.api.records.reader.impl;
import org.datavec.api.records.SequenceRecord;
@@ -28,11 +27,10 @@ import org.datavec.api.split.InputSplit;
import org.datavec.api.split.NumberedFileInputSplit;
import org.datavec.api.writable.Writable;
import org.junit.Rule;
-import org.junit.Test;
-import org.junit.rules.TemporaryFolder;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
-
import java.io.File;
import java.io.InputStream;
import java.io.OutputStream;
@@ -41,25 +39,27 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import org.junit.jupiter.api.DisplayName;
+import java.nio.file.Path;
+import org.junit.jupiter.api.extension.ExtendWith;
-import static org.junit.Assert.assertEquals;
+@DisplayName("Csv Sequence Record Reader Test")
+class CSVSequenceRecordReaderTest extends BaseND4JTest {
-public class CSVSequenceRecordReaderTest extends BaseND4JTest {
-
- @Rule
- public TemporaryFolder tempDir = new TemporaryFolder();
+ @TempDir
+ public Path tempDir;
@Test
- public void test() throws Exception {
-
+ @DisplayName("Test")
+ void test() throws Exception {
CSVSequenceRecordReader seqReader = new CSVSequenceRecordReader(1, ",");
seqReader.initialize(new TestInputSplit());
-
int sequenceCount = 0;
while (seqReader.hasNext()) {
List> sequence = seqReader.sequenceRecord();
- assertEquals(4, sequence.size()); //4 lines, plus 1 header line
-
+ // 4 lines, plus 1 header line
+ assertEquals(4, sequence.size());
Iterator> timeStepIter = sequence.iterator();
int lineCount = 0;
while (timeStepIter.hasNext()) {
@@ -80,19 +80,18 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
}
@Test
- public void testReset() throws Exception {
+ @DisplayName("Test Reset")
+ void testReset() throws Exception {
CSVSequenceRecordReader seqReader = new CSVSequenceRecordReader(1, ",");
seqReader.initialize(new TestInputSplit());
-
int nTests = 5;
for (int i = 0; i < nTests; i++) {
seqReader.reset();
-
int sequenceCount = 0;
while (seqReader.hasNext()) {
List> sequence = seqReader.sequenceRecord();
- assertEquals(4, sequence.size()); //4 lines, plus 1 header line
-
+ // 4 lines, plus 1 header line
+ assertEquals(4, sequence.size());
Iterator> timeStepIter = sequence.iterator();
int lineCount = 0;
while (timeStepIter.hasNext()) {
@@ -107,15 +106,15 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
}
@Test
- public void testMetaData() throws Exception {
+ @DisplayName("Test Meta Data")
+ void testMetaData() throws Exception {
CSVSequenceRecordReader seqReader = new CSVSequenceRecordReader(1, ",");
seqReader.initialize(new TestInputSplit());
-
List>> l = new ArrayList<>();
while (seqReader.hasNext()) {
List> sequence = seqReader.sequenceRecord();
- assertEquals(4, sequence.size()); //4 lines, plus 1 header line
-
+ // 4 lines, plus 1 header line
+ assertEquals(4, sequence.size());
Iterator> timeStepIter = sequence.iterator();
int lineCount = 0;
while (timeStepIter.hasNext()) {
@@ -123,10 +122,8 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
lineCount++;
}
assertEquals(4, lineCount);
-
l.add(sequence);
}
-
List l2 = new ArrayList<>();
List meta = new ArrayList<>();
seqReader.reset();
@@ -136,7 +133,6 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
meta.add(sr.getMetaData());
}
assertEquals(3, l2.size());
-
List fromMeta = seqReader.loadSequenceFromMetaData(meta);
for (int i = 0; i < 3; i++) {
assertEquals(l.get(i), l2.get(i).getSequenceRecord());
@@ -144,8 +140,8 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
}
}
- private static class
- TestInputSplit implements InputSplit {
+ @DisplayName("Test Input Split")
+ private static class TestInputSplit implements InputSplit {
@Override
public boolean canWriteToLocation(URI location) {
@@ -164,7 +160,6 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
@Override
public void updateSplitLocations(boolean reset) {
-
}
@Override
@@ -174,7 +169,6 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
@Override
public void bootStrapForWrite() {
-
}
@Override
@@ -222,38 +216,30 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
@Override
public void reset() {
- //No op
+ // No op
}
@Override
public boolean resetSupported() {
return true;
}
-
-
-
-
}
-
@Test
- public void testCsvSeqAndNumberedFileSplit() throws Exception {
- File baseDir = tempDir.newFolder();
- //Simple sanity check unit test
+ @DisplayName("Test Csv Seq And Numbered File Split")
+ void testCsvSeqAndNumberedFileSplit(@TempDir Path tempDir) throws Exception {
+ File baseDir = tempDir.toFile();
+ // Simple sanity check unit test
for (int i = 0; i < 3; i++) {
new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(baseDir);
}
-
- //Load time series from CSV sequence files; compare to SequenceRecordReaderDataSetIterator
+ // Load time series from CSV sequence files; compare to SequenceRecordReaderDataSetIterator
ClassPathResource resource = new ClassPathResource("csvsequence_0.txt");
String featuresPath = new File(baseDir, "csvsequence_%d.txt").getAbsolutePath();
-
SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
-
- while(featureReader.hasNext()){
+ while (featureReader.hasNext()) {
featureReader.nextSequence();
}
-
}
}
diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVVariableSlidingWindowRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVVariableSlidingWindowRecordReaderTest.java
index cab012faf..148f8ff0b 100644
--- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVVariableSlidingWindowRecordReaderTest.java
+++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVVariableSlidingWindowRecordReaderTest.java
@@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
-
package org.datavec.api.records.reader.impl;
import org.datavec.api.records.reader.SequenceRecordReader;
@@ -25,94 +24,87 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVVariableSlidingWindowRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable;
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
-
import java.util.LinkedList;
import java.util.List;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.extension.ExtendWith;
-import static org.junit.Assert.assertEquals;
-
-public class CSVVariableSlidingWindowRecordReaderTest extends BaseND4JTest {
+@DisplayName("Csv Variable Sliding Window Record Reader Test")
+class CSVVariableSlidingWindowRecordReaderTest extends BaseND4JTest {
@Test
- public void testCSVVariableSlidingWindowRecordReader() throws Exception {
+ @DisplayName("Test CSV Variable Sliding Window Record Reader")
+ void testCSVVariableSlidingWindowRecordReader() throws Exception {
int maxLinesPerSequence = 3;
-
SequenceRecordReader seqRR = new CSVVariableSlidingWindowRecordReader(maxLinesPerSequence);
seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
-
CSVRecordReader rr = new CSVRecordReader();
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
-
int count = 0;
while (seqRR.hasNext()) {
List> next = seqRR.sequenceRecord();
-
- if(count==maxLinesPerSequence-1) {
+ if (count == maxLinesPerSequence - 1) {
LinkedList> expected = new LinkedList<>();
for (int i = 0; i < maxLinesPerSequence; i++) {
expected.addFirst(rr.next());
}
assertEquals(expected, next);
-
}
- if(count==maxLinesPerSequence) {
+ if (count == maxLinesPerSequence) {
assertEquals(maxLinesPerSequence, next.size());
}
- if(count==0) { // first seq should be length 1
+ if (count == 0) {
+ // first seq should be length 1
assertEquals(1, next.size());
}
- if(count>151) { // last seq should be length 1
+ if (count > 151) {
+ // last seq should be length 1
assertEquals(1, next.size());
}
-
count++;
}
-
assertEquals(152, count);
}
@Test
- public void testCSVVariableSlidingWindowRecordReaderStride() throws Exception {
+ @DisplayName("Test CSV Variable Sliding Window Record Reader Stride")
+ void testCSVVariableSlidingWindowRecordReaderStride() throws Exception {
int maxLinesPerSequence = 3;
int stride = 2;
-
SequenceRecordReader seqRR = new CSVVariableSlidingWindowRecordReader(maxLinesPerSequence, stride);
seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
-
CSVRecordReader rr = new CSVRecordReader();
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
-
int count = 0;
while (seqRR.hasNext()) {
List> next = seqRR.sequenceRecord();
-
- if(count==maxLinesPerSequence-1) {
+ if (count == maxLinesPerSequence - 1) {
LinkedList> expected = new LinkedList<>();
- for(int s = 0; s < stride; s++) {
+ for (int s = 0; s < stride; s++) {
expected = new LinkedList<>();
for (int i = 0; i < maxLinesPerSequence; i++) {
expected.addFirst(rr.next());
}
}
assertEquals(expected, next);
-
}
- if(count==maxLinesPerSequence) {
+ if (count == maxLinesPerSequence) {
assertEquals(maxLinesPerSequence, next.size());
}
- if(count==0) { // first seq should be length 2
+ if (count == 0) {
+ // first seq should be length 2
assertEquals(2, next.size());
}
- if(count>151) { // last seq should be length 1
+ if (count > 151) {
+ // last seq should be length 1
assertEquals(1, next.size());
}
-
count++;
}
-
assertEquals(76, count);
}
}
diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileBatchRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileBatchRecordReaderTest.java
index 036e23475..1acbf2fac 100644
--- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileBatchRecordReaderTest.java
+++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileBatchRecordReaderTest.java
@@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
-
package org.datavec.api.records.reader.impl;
import org.apache.commons.io.FileUtils;
@@ -29,44 +28,38 @@ import org.datavec.api.records.reader.impl.filebatch.FileBatchRecordReader;
import org.datavec.api.records.reader.impl.filebatch.FileBatchSequenceRecordReader;
import org.datavec.api.writable.Writable;
import org.junit.Rule;
-import org.junit.Test;
-import org.junit.rules.TemporaryFolder;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.loader.FileBatch;
-
import java.io.File;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
+import static org.junit.jupiter.api.Assertions.*;
+import org.junit.jupiter.api.DisplayName;
+import java.nio.file.Path;
+import org.junit.jupiter.api.extension.ExtendWith;
-import static org.junit.Assert.*;
-
-public class FileBatchRecordReaderTest extends BaseND4JTest {
-
- @Rule
- public TemporaryFolder testDir = new TemporaryFolder();
+@DisplayName("File Batch Record Reader Test")
+class FileBatchRecordReaderTest extends BaseND4JTest {
@Test
- public void testCsv() throws Exception {
-
- //This is an unrealistic use case - one line/record per CSV
- File baseDir = testDir.newFolder();
-
+ @DisplayName("Test Csv")
+ void testCsv(@TempDir Path testDir) throws Exception {
+ // This is an unrealistic use case - one line/record per CSV
+ File baseDir = testDir.toFile();
List fileList = new ArrayList<>();
- for( int i=0; i<10; i++ ){
+ for (int i = 0; i < 10; i++) {
String s = "file_" + i + "," + i + "," + i;
File f = new File(baseDir, "origFile" + i + ".csv");
FileUtils.writeStringToFile(f, s, StandardCharsets.UTF_8);
fileList.add(f);
}
-
FileBatch fb = FileBatch.forFiles(fileList);
-
RecordReader rr = new CSVRecordReader();
FileBatchRecordReader fbrr = new FileBatchRecordReader(rr, fb);
-
-
- for( int test=0; test<3; test++) {
+ for (int test = 0; test < 3; test++) {
for (int i = 0; i < 10; i++) {
assertTrue(fbrr.hasNext());
List next = fbrr.next();
@@ -83,15 +76,15 @@ public class FileBatchRecordReaderTest extends BaseND4JTest {
}
@Test
- public void testCsvSequence() throws Exception {
- //CSV sequence - 3 lines per file, 10 files
- File baseDir = testDir.newFolder();
-
+ @DisplayName("Test Csv Sequence")
+ void testCsvSequence(@TempDir Path testDir) throws Exception {
+ // CSV sequence - 3 lines per file, 10 files
+ File baseDir = testDir.toFile();
List fileList = new ArrayList<>();
- for( int i=0; i<10; i++ ){
+ for (int i = 0; i < 10; i++) {
StringBuilder sb = new StringBuilder();
- for( int j=0; j<3; j++ ){
- if(j > 0)
+ for (int j = 0; j < 3; j++) {
+ if (j > 0)
sb.append("\n");
sb.append("file_" + i + "," + i + "," + j);
}
@@ -99,19 +92,16 @@ public class FileBatchRecordReaderTest extends BaseND4JTest {
FileUtils.writeStringToFile(f, sb.toString(), StandardCharsets.UTF_8);
fileList.add(f);
}
-
FileBatch fb = FileBatch.forFiles(fileList);
SequenceRecordReader rr = new CSVSequenceRecordReader();
FileBatchSequenceRecordReader fbrr = new FileBatchSequenceRecordReader(rr, fb);
-
-
- for( int test=0; test<3; test++) {
+ for (int test = 0; test < 3; test++) {
for (int i = 0; i < 10; i++) {
assertTrue(fbrr.hasNext());
List> next = fbrr.sequenceRecord();
assertEquals(3, next.size());
int count = 0;
- for(List step : next ){
+ for (List step : next) {
String s1 = "file_" + i;
assertEquals(s1, step.get(0).toString());
assertEquals(String.valueOf(i), step.get(1).toString());
@@ -123,5 +113,4 @@ public class FileBatchRecordReaderTest extends BaseND4JTest {
fbrr.reset();
}
}
-
}
diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileRecordReaderTest.java
index 910fc31b2..d914cd95f 100644
--- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileRecordReaderTest.java
+++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileRecordReaderTest.java
@@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
-
package org.datavec.api.records.reader.impl;
import org.datavec.api.records.Record;
@@ -26,28 +25,28 @@ import org.datavec.api.split.CollectionInputSplit;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.api.writable.Writable;
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
-
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.extension.ExtendWith;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-
-public class FileRecordReaderTest extends BaseND4JTest {
+@DisplayName("File Record Reader Test")
+class FileRecordReaderTest extends BaseND4JTest {
@Test
- public void testReset() throws Exception {
+ @DisplayName("Test Reset")
+ void testReset() throws Exception {
FileRecordReader rr = new FileRecordReader();
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
-
int nResets = 5;
for (int i = 0; i < nResets; i++) {
-
int lineCount = 0;
while (rr.hasNext()) {
List line = rr.next();
@@ -61,25 +60,20 @@ public class FileRecordReaderTest extends BaseND4JTest {
}
@Test
- public void testMeta() throws Exception {
+ @DisplayName("Test Meta")
+ void testMeta() throws Exception {
FileRecordReader rr = new FileRecordReader();
-
-
URI[] arr = new URI[3];
arr[0] = new ClassPathResource("datavec-api/csvsequence_0.txt").getFile().toURI();
arr[1] = new ClassPathResource("datavec-api/csvsequence_1.txt").getFile().toURI();
arr[2] = new ClassPathResource("datavec-api/csvsequence_2.txt").getFile().toURI();
-
InputSplit is = new CollectionInputSplit(Arrays.asList(arr));
rr.initialize(is);
-
List> out = new ArrayList<>();
while (rr.hasNext()) {
out.add(rr.next());
}
-
assertEquals(3, out.size());
-
rr.reset();
List> out2 = new ArrayList<>();
List out3 = new ArrayList<>();
@@ -90,13 +84,10 @@ public class FileRecordReaderTest extends BaseND4JTest {
out2.add(r.getRecord());
out3.add(r);
meta.add(r.getMetaData());
-
assertEquals(arr[count++], r.getMetaData().getURI());
}
-
assertEquals(out, out2);
List fromMeta = rr.loadFromMetaData(meta);
assertEquals(out3, fromMeta);
}
-
}
diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonLineRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonLineRecordReaderTest.java
index 9d5b76688..4095d1af7 100644
--- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonLineRecordReaderTest.java
+++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonLineRecordReaderTest.java
@@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
-
package org.datavec.api.records.reader.impl;
import org.datavec.api.records.reader.RecordReader;
@@ -29,96 +28,80 @@ import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Rule;
-import org.junit.Test;
-import org.junit.rules.TemporaryFolder;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.shade.jackson.core.JsonFactory;
import org.nd4j.shade.jackson.databind.ObjectMapper;
-
import java.io.File;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import org.junit.jupiter.api.DisplayName;
+import java.nio.file.Path;
+import org.junit.jupiter.api.extension.ExtendWith;
-import static org.junit.Assert.assertEquals;
+@DisplayName("Jackson Line Record Reader Test")
+class JacksonLineRecordReaderTest extends BaseND4JTest {
-public class JacksonLineRecordReaderTest extends BaseND4JTest {
+ @TempDir
+ public Path testDir;
- @Rule
- public TemporaryFolder testDir = new TemporaryFolder();
-
- public JacksonLineRecordReaderTest() {
- }
+ public JacksonLineRecordReaderTest() {
+ }
private static FieldSelection getFieldSelection() {
- return new FieldSelection.Builder().addField("value1").
- addField("value2").
- addField("value3").
- addField("value4").
- addField("value5").
- addField("value6").
- addField("value7").
- addField("value8").
- addField("value9").
- addField("value10").build();
+ return new FieldSelection.Builder().addField("value1").addField("value2").addField("value3").addField("value4").addField("value5").addField("value6").addField("value7").addField("value8").addField("value9").addField("value10").build();
}
-
+
@Test
- public void testReadJSON() throws Exception {
-
+ @DisplayName("Test Read JSON")
+ void testReadJSON() throws Exception {
RecordReader rr = new JacksonLineRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()));
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/json/json_test_3.txt").getFile()));
-
testJacksonRecordReader(rr);
- }
-
- private static void testJacksonRecordReader(RecordReader rr) {
- while (rr.hasNext()) {
- List json0 = rr.next();
- //System.out.println(json0);
- assert(json0.size() > 0);
- }
}
+ private static void testJacksonRecordReader(RecordReader rr) {
+ while (rr.hasNext()) {
+ List json0 = rr.next();
+ // System.out.println(json0);
+ assert (json0.size() > 0);
+ }
+ }
@Test
- public void testJacksonLineSequenceRecordReader() throws Exception {
- File dir = testDir.newFolder();
- new ClassPathResource("datavec-api/JacksonLineSequenceRecordReaderTest/").copyDirectory(dir);
-
- FieldSelection f = new FieldSelection.Builder().addField("a").addField(new Text("MISSING_B"), "b")
- .addField(new Text("MISSING_CX"), "c", "x").build();
-
- JacksonLineSequenceRecordReader rr = new JacksonLineSequenceRecordReader(f, new ObjectMapper(new JsonFactory()));
- File[] files = dir.listFiles();
- Arrays.sort(files);
- URI[] u = new URI[files.length];
- for( int i=0; i> expSeq0 = new ArrayList<>();
- expSeq0.add(Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0")));
- expSeq0.add(Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1")));
- expSeq0.add(Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX")));
-
- List> expSeq1 = new ArrayList<>();
- expSeq1.add(Arrays.asList((Writable) new Text("aValue3"), new Text("bValue3"), new Text("cxValue3")));
-
-
- int count = 0;
- while(rr.hasNext()){
- List> next = rr.sequenceRecord();
- if(count++ == 0){
- assertEquals(expSeq0, next);
- } else {
- assertEquals(expSeq1, next);
- }
- }
-
- assertEquals(2, count);
- }
+ @DisplayName("Test Jackson Line Sequence Record Reader")
+ void testJacksonLineSequenceRecordReader(@TempDir Path testDir) throws Exception {
+ File dir = testDir.toFile();
+ new ClassPathResource("datavec-api/JacksonLineSequenceRecordReaderTest/").copyDirectory(dir);
+ FieldSelection f = new FieldSelection.Builder().addField("a").addField(new Text("MISSING_B"), "b").addField(new Text("MISSING_CX"), "c", "x").build();
+ JacksonLineSequenceRecordReader rr = new JacksonLineSequenceRecordReader(f, new ObjectMapper(new JsonFactory()));
+ File[] files = dir.listFiles();
+ Arrays.sort(files);
+ URI[] u = new URI[files.length];
+ for (int i = 0; i < files.length; i++) {
+ u[i] = files[i].toURI();
+ }
+ rr.initialize(new CollectionInputSplit(u));
+ List> expSeq0 = new ArrayList<>();
+ expSeq0.add(Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0")));
+ expSeq0.add(Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1")));
+ expSeq0.add(Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX")));
+ List> expSeq1 = new ArrayList<>();
+ expSeq1.add(Arrays.asList((Writable) new Text("aValue3"), new Text("bValue3"), new Text("cxValue3")));
+ int count = 0;
+ while (rr.hasNext()) {
+ List> next = rr.sequenceRecord();
+ if (count++ == 0) {
+ assertEquals(expSeq0, next);
+ } else {
+ assertEquals(expSeq1, next);
+ }
+ }
+ assertEquals(2, count);
+ }
}
diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonRecordReaderTest.java
index 5b91c4523..2e4a2261b 100644
--- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonRecordReaderTest.java
+++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonRecordReaderTest.java
@@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
-
package org.datavec.api.records.reader.impl;
import org.datavec.api.io.labels.PathLabelGenerator;
@@ -32,113 +31,94 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Rule;
-import org.junit.Test;
-import org.junit.rules.TemporaryFolder;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.shade.jackson.core.JsonFactory;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.dataformat.xml.XmlFactory;
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
-
import java.io.File;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import org.junit.jupiter.api.DisplayName;
+import java.nio.file.Path;
+import org.junit.jupiter.api.extension.ExtendWith;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
+@DisplayName("Jackson Record Reader Test")
+class JacksonRecordReaderTest extends BaseND4JTest {
-public class JacksonRecordReaderTest extends BaseND4JTest {
-
- @Rule
- public TemporaryFolder testDir = new TemporaryFolder();
+ @TempDir
+ public Path testDir;
@Test
- public void testReadingJson() throws Exception {
- //Load 3 values from 3 JSON files
- //stricture: a:value, b:value, c:x:value, c:y:value
- //And we want to load only a:value, b:value and c:x:value
- //For first JSON file: all values are present
- //For second JSON file: b:value is missing
- //For third JSON file: c:x:value is missing
-
+ @DisplayName("Test Reading Json")
+ void testReadingJson(@TempDir Path testDir) throws Exception {
+ // Load 3 values from 3 JSON files
+ // stricture: a:value, b:value, c:x:value, c:y:value
+ // And we want to load only a:value, b:value and c:x:value
+ // For first JSON file: all values are present
+ // For second JSON file: b:value is missing
+ // For third JSON file: c:x:value is missing
ClassPathResource cpr = new ClassPathResource("datavec-api/json/");
- File f = testDir.newFolder();
+ File f = testDir.toFile();
cpr.copyDirectory(f);
String path = new File(f, "json_test_%d.txt").getAbsolutePath();
-
InputSplit is = new NumberedFileInputSplit(path, 0, 2);
-
RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()));
rr.initialize(is);
-
testJacksonRecordReader(rr);
}
@Test
- public void testReadingYaml() throws Exception {
- //Exact same information as JSON format, but in YAML format
-
+ @DisplayName("Test Reading Yaml")
+ void testReadingYaml(@TempDir Path testDir) throws Exception {
+ // Exact same information as JSON format, but in YAML format
ClassPathResource cpr = new ClassPathResource("datavec-api/yaml/");
- File f = testDir.newFolder();
+ File f = testDir.toFile();
cpr.copyDirectory(f);
String path = new File(f, "yaml_test_%d.txt").getAbsolutePath();
-
-
InputSplit is = new NumberedFileInputSplit(path, 0, 2);
-
RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new YAMLFactory()));
rr.initialize(is);
-
testJacksonRecordReader(rr);
}
@Test
- public void testReadingXml() throws Exception {
- //Exact same information as JSON format, but in XML format
-
+ @DisplayName("Test Reading Xml")
+ void testReadingXml(@TempDir Path testDir) throws Exception {
+ // Exact same information as JSON format, but in XML format
ClassPathResource cpr = new ClassPathResource("datavec-api/xml/");
- File f = testDir.newFolder();
+ File f = testDir.toFile();
cpr.copyDirectory(f);
String path = new File(f, "xml_test_%d.txt").getAbsolutePath();
-
InputSplit is = new NumberedFileInputSplit(path, 0, 2);
-
RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new XmlFactory()));
rr.initialize(is);
-
testJacksonRecordReader(rr);
}
-
private static FieldSelection getFieldSelection() {
- return new FieldSelection.Builder().addField("a").addField(new Text("MISSING_B"), "b")
- .addField(new Text("MISSING_CX"), "c", "x").build();
+ return new FieldSelection.Builder().addField("a").addField(new Text("MISSING_B"), "b").addField(new Text("MISSING_CX"), "c", "x").build();
}
-
-
private static void testJacksonRecordReader(RecordReader rr) {
-
List json0 = rr.next();
List exp0 = Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"));
assertEquals(exp0, json0);
-
List json1 = rr.next();
- List exp1 =
- Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"));
+ List exp1 = Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"));
assertEquals(exp1, json1);
-
List json2 = rr.next();
- List exp2 =
- Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"));
+ List exp2 = Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"));
assertEquals(exp2, json2);
-
assertFalse(rr.hasNext());
-
- //Test reset
+ // Test reset
rr.reset();
assertEquals(exp0, rr.next());
assertEquals(exp1, rr.next());
@@ -147,72 +127,50 @@ public class JacksonRecordReaderTest extends BaseND4JTest {
}
@Test
- public void testAppendingLabels() throws Exception {
-
+ @DisplayName("Test Appending Labels")
+ void testAppendingLabels(@TempDir Path testDir) throws Exception {
ClassPathResource cpr = new ClassPathResource("datavec-api/json/");
- File f = testDir.newFolder();
+ File f = testDir.toFile();
cpr.copyDirectory(f);
String path = new File(f, "json_test_%d.txt").getAbsolutePath();
-
InputSplit is = new NumberedFileInputSplit(path, 0, 2);
-
- //Insert at the end:
- RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1,
- new LabelGen());
+ // Insert at the end:
+ RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1, new LabelGen());
rr.initialize(is);
-
- List exp0 = Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"),
- new IntWritable(0));
+ List exp0 = Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"), new IntWritable(0));
assertEquals(exp0, rr.next());
-
- List exp1 = Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"),
- new IntWritable(1));
+ List exp1 = Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"), new IntWritable(1));
assertEquals(exp1, rr.next());
-
- List exp2 = Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"),
- new IntWritable(2));
+ List exp2 = Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"), new IntWritable(2));
assertEquals(exp2, rr.next());
-
- //Insert at position 0:
- rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1,
- new LabelGen(), 0);
+ // Insert at position 0:
+ rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1, new LabelGen(), 0);
rr.initialize(is);
-
- exp0 = Arrays.asList((Writable) new IntWritable(0), new Text("aValue0"), new Text("bValue0"),
- new Text("cxValue0"));
+ exp0 = Arrays.asList((Writable) new IntWritable(0), new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"));
assertEquals(exp0, rr.next());
-
- exp1 = Arrays.asList((Writable) new IntWritable(1), new Text("aValue1"), new Text("MISSING_B"),
- new Text("cxValue1"));
+ exp1 = Arrays.asList((Writable) new IntWritable(1), new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"));
assertEquals(exp1, rr.next());
-
- exp2 = Arrays.asList((Writable) new IntWritable(2), new Text("aValue2"), new Text("bValue2"),
- new Text("MISSING_CX"));
+ exp2 = Arrays.asList((Writable) new IntWritable(2), new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"));
assertEquals(exp2, rr.next());
}
@Test
- public void testAppendingLabelsMetaData() throws Exception {
+ @DisplayName("Test Appending Labels Meta Data")
+ void testAppendingLabelsMetaData(@TempDir Path testDir) throws Exception {
ClassPathResource cpr = new ClassPathResource("datavec-api/json/");
- File f = testDir.newFolder();
+ File f = testDir.toFile();
cpr.copyDirectory(f);
String path = new File(f, "json_test_%d.txt").getAbsolutePath();
-
InputSplit is = new NumberedFileInputSplit(path, 0, 2);
-
- //Insert at the end:
- RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1,
- new LabelGen());
+ // Insert at the end:
+ RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1, new LabelGen());
rr.initialize(is);
-
List> out = new ArrayList<>();
while (rr.hasNext()) {
out.add(rr.next());
}
assertEquals(3, out.size());
-
rr.reset();
-
List> out2 = new ArrayList<>();
List outRecord = new ArrayList<>();
List meta = new ArrayList<>();
@@ -222,14 +180,12 @@ public class JacksonRecordReaderTest extends BaseND4JTest {
outRecord.add(r);
meta.add(r.getMetaData());
}
-
assertEquals(out, out2);
-
List fromMeta = rr.loadFromMetaData(meta);
assertEquals(outRecord, fromMeta);
}
-
+ @DisplayName("Label Gen")
private static class LabelGen implements PathLabelGenerator {
@Override
@@ -252,5 +208,4 @@ public class JacksonRecordReaderTest extends BaseND4JTest {
return true;
}
}
-
}
diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LibSvmRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LibSvmRecordReaderTest.java
index e7fe410c8..9d3ae9663 100644
--- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LibSvmRecordReaderTest.java
+++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LibSvmRecordReaderTest.java
@@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
-
package org.datavec.api.records.reader.impl;
import org.datavec.api.conf.Configuration;
@@ -27,43 +26,30 @@ import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable;
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
-
import java.io.IOException;
import java.util.*;
-
import static org.datavec.api.records.reader.impl.misc.LibSvmRecordReader.*;
-import static org.junit.Assert.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.extension.ExtendWith;
+import static org.junit.jupiter.api.Assertions.assertThrows;
-public class LibSvmRecordReaderTest extends BaseND4JTest {
+@DisplayName("Lib Svm Record Reader Test")
+class LibSvmRecordReaderTest extends BaseND4JTest {
@Test
- public void testBasicRecord() throws IOException, InterruptedException {
+ @DisplayName("Test Basic Record")
+ void testBasicRecord() throws IOException, InterruptedException {
Map> correct = new HashMap<>();
// 7 2:1 4:2 6:3 8:4 10:5
- correct.put(0, Arrays.asList(ZERO, ONE,
- ZERO, new DoubleWritable(2),
- ZERO, new DoubleWritable(3),
- ZERO, new DoubleWritable(4),
- ZERO, new DoubleWritable(5),
- new IntWritable(7)));
+ correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), new IntWritable(7)));
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
- correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
- ZERO, ZERO,
- ZERO, new DoubleWritable(6.6),
- ZERO, new DoubleWritable(80),
- ZERO, ZERO,
- new IntWritable(2)));
+ correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, new IntWritable(2)));
// 33
- correct.put(2, Arrays.asList(ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- new IntWritable(33)));
-
+ correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, new IntWritable(33)));
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
@@ -80,27 +66,15 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
}
@Test
- public void testNoAppendLabel() throws IOException, InterruptedException {
+ @DisplayName("Test No Append Label")
+ void testNoAppendLabel() throws IOException, InterruptedException {
Map> correct = new HashMap<>();
// 7 2:1 4:2 6:3 8:4 10:5
- correct.put(0, Arrays.asList(ZERO, ONE,
- ZERO, new DoubleWritable(2),
- ZERO, new DoubleWritable(3),
- ZERO, new DoubleWritable(4),
- ZERO, new DoubleWritable(5)));
+ correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5)));
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
- correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
- ZERO, ZERO,
- ZERO, new DoubleWritable(6.6),
- ZERO, new DoubleWritable(80),
- ZERO, ZERO));
+ correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO));
// 33
- correct.put(2, Arrays.asList(ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO));
-
+ correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO));
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
@@ -117,33 +91,17 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
}
@Test
- public void testNoLabel() throws IOException, InterruptedException {
+ @DisplayName("Test No Label")
+ void testNoLabel() throws IOException, InterruptedException {
Map> correct = new HashMap<>();
- // 2:1 4:2 6:3 8:4 10:5
- correct.put(0, Arrays.asList(ZERO, ONE,
- ZERO, new DoubleWritable(2),
- ZERO, new DoubleWritable(3),
- ZERO, new DoubleWritable(4),
- ZERO, new DoubleWritable(5)));
- // qid:42 1:0.1 2:2 6:6.6 8:80
- correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
- ZERO, ZERO,
- ZERO, new DoubleWritable(6.6),
- ZERO, new DoubleWritable(80),
- ZERO, ZERO));
- // 1:1.0
- correct.put(2, Arrays.asList(new DoubleWritable(1.0), ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO));
- //
- correct.put(3, Arrays.asList(ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO));
-
+ // 2:1 4:2 6:3 8:4 10:5
+ correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5)));
+ // qid:42 1:0.1 2:2 6:6.6 8:80
+ correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO));
+ // 1:1.0
+ correct.put(2, Arrays.asList(new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO));
+ //
+ correct.put(3, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO));
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
@@ -160,33 +118,15 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
}
@Test
- public void testMultioutputRecord() throws IOException, InterruptedException {
+ @DisplayName("Test Multioutput Record")
+ void testMultioutputRecord() throws IOException, InterruptedException {
Map> correct = new HashMap<>();
// 7 2.45,9 2:1 4:2 6:3 8:4 10:5
- correct.put(0, Arrays.asList(ZERO, ONE,
- ZERO, new DoubleWritable(2),
- ZERO, new DoubleWritable(3),
- ZERO, new DoubleWritable(4),
- ZERO, new DoubleWritable(5),
- new IntWritable(7), new DoubleWritable(2.45),
- new IntWritable(9)));
+ correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), new IntWritable(7), new DoubleWritable(2.45), new IntWritable(9)));
// 2,3,4 qid:42 1:0.1 2:2 6:6.6 8:80
- correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
- ZERO, ZERO,
- ZERO, new DoubleWritable(6.6),
- ZERO, new DoubleWritable(80),
- ZERO, ZERO,
- new IntWritable(2), new IntWritable(3),
- new IntWritable(4)));
+ correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, new IntWritable(2), new IntWritable(3), new IntWritable(4)));
// 33,32.0,31.9
- correct.put(2, Arrays.asList(ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- new IntWritable(33), new DoubleWritable(32.0),
- new DoubleWritable(31.9)));
-
+ correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, new IntWritable(33), new DoubleWritable(32.0), new DoubleWritable(31.9)));
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
@@ -202,51 +142,20 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
assertEquals(i, correct.size());
}
-
@Test
- public void testMultilabelRecord() throws IOException, InterruptedException {
+ @DisplayName("Test Multilabel Record")
+ void testMultilabelRecord() throws IOException, InterruptedException {
Map> correct = new HashMap<>();
// 1,3 2:1 4:2 6:3 8:4 10:5
- correct.put(0, Arrays.asList(ZERO, ONE,
- ZERO, new DoubleWritable(2),
- ZERO, new DoubleWritable(3),
- ZERO, new DoubleWritable(4),
- ZERO, new DoubleWritable(5),
- LABEL_ONE, LABEL_ZERO,
- LABEL_ONE, LABEL_ZERO));
+ correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), LABEL_ONE, LABEL_ZERO, LABEL_ONE, LABEL_ZERO));
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
- correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
- ZERO, ZERO,
- ZERO, new DoubleWritable(6.6),
- ZERO, new DoubleWritable(80),
- ZERO, ZERO,
- LABEL_ZERO, LABEL_ONE,
- LABEL_ZERO, LABEL_ZERO));
+ correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ZERO));
// 1,2,4
- correct.put(2, Arrays.asList(ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- LABEL_ONE, LABEL_ONE,
- LABEL_ZERO, LABEL_ONE));
- // 1:1.0
- correct.put(3, Arrays.asList(new DoubleWritable(1.0), ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- LABEL_ZERO, LABEL_ZERO,
- LABEL_ZERO, LABEL_ZERO));
- //
- correct.put(4, Arrays.asList(ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- LABEL_ZERO, LABEL_ZERO,
- LABEL_ZERO, LABEL_ZERO));
-
+ correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ONE, LABEL_ONE, LABEL_ZERO, LABEL_ONE));
+ // 1:1.0
+ correct.put(3, Arrays.asList(new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
+ //
+ correct.put(4, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
@@ -265,63 +174,24 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
}
@Test
- public void testZeroBasedIndexing() throws IOException, InterruptedException {
+ @DisplayName("Test Zero Based Indexing")
+ void testZeroBasedIndexing() throws IOException, InterruptedException {
Map> correct = new HashMap<>();
// 1,3 2:1 4:2 6:3 8:4 10:5
- correct.put(0, Arrays.asList(ZERO,
- ZERO, ONE,
- ZERO, new DoubleWritable(2),
- ZERO, new DoubleWritable(3),
- ZERO, new DoubleWritable(4),
- ZERO, new DoubleWritable(5),
- LABEL_ZERO,
- LABEL_ONE, LABEL_ZERO,
- LABEL_ONE, LABEL_ZERO));
+ correct.put(0, Arrays.asList(ZERO, ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ONE, LABEL_ZERO));
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
- correct.put(1, Arrays.asList(ZERO,
- new DoubleWritable(0.1), new DoubleWritable(2),
- ZERO, ZERO,
- ZERO, new DoubleWritable(6.6),
- ZERO, new DoubleWritable(80),
- ZERO, ZERO,
- LABEL_ZERO,
- LABEL_ZERO, LABEL_ONE,
- LABEL_ZERO, LABEL_ZERO));
+ correct.put(1, Arrays.asList(ZERO, new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ZERO));
// 1,2,4
- correct.put(2, Arrays.asList(ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- LABEL_ZERO,
- LABEL_ONE, LABEL_ONE,
- LABEL_ZERO, LABEL_ONE));
- // 1:1.0
- correct.put(3, Arrays.asList(ZERO,
- new DoubleWritable(1.0), ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- LABEL_ZERO,
- LABEL_ZERO, LABEL_ZERO,
- LABEL_ZERO, LABEL_ZERO));
- //
- correct.put(4, Arrays.asList(ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- LABEL_ZERO,
- LABEL_ZERO, LABEL_ZERO,
- LABEL_ZERO, LABEL_ZERO));
-
+ correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ONE, LABEL_ZERO, LABEL_ONE));
+ // 1:1.0
+ correct.put(3, Arrays.asList(ZERO, new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
+ //
+ correct.put(4, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
// Zero-based indexing is default
- config.setBoolean(SVMLightRecordReader.ZERO_BASED_LABEL_INDEXING, true); // NOT STANDARD!
+ // NOT STANDARD!
+ config.setBoolean(SVMLightRecordReader.ZERO_BASED_LABEL_INDEXING, true);
config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
config.setInt(LibSvmRecordReader.NUM_FEATURES, 11);
config.setBoolean(LibSvmRecordReader.MULTILABEL, true);
@@ -336,87 +206,107 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
assertEquals(i, correct.size());
}
- @Test(expected = NoSuchElementException.class)
- public void testNoSuchElementException() throws Exception {
- LibSvmRecordReader rr = new LibSvmRecordReader();
- Configuration config = new Configuration();
- config.setInt(LibSvmRecordReader.NUM_FEATURES, 11);
- rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
- while (rr.hasNext())
+ @Test
+ @DisplayName("Test No Such Element Exception")
+ void testNoSuchElementException() {
+ assertThrows(NoSuchElementException.class, () -> {
+ LibSvmRecordReader rr = new LibSvmRecordReader();
+ Configuration config = new Configuration();
+ config.setInt(LibSvmRecordReader.NUM_FEATURES, 11);
+ rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
+ while (rr.hasNext()) rr.next();
rr.next();
- rr.next();
+ });
}
- @Test(expected = UnsupportedOperationException.class)
- public void failedToSetNumFeaturesException() throws Exception {
- LibSvmRecordReader rr = new LibSvmRecordReader();
- Configuration config = new Configuration();
- rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
- while (rr.hasNext())
+ @Test
+ @DisplayName("Failed To Set Num Features Exception")
+ void failedToSetNumFeaturesException() {
+ assertThrows(UnsupportedOperationException.class, () -> {
+ LibSvmRecordReader rr = new LibSvmRecordReader();
+ Configuration config = new Configuration();
+ rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
+ while (rr.hasNext()) rr.next();
+ });
+ }
+
+ @Test
+ @DisplayName("Test Inconsistent Num Labels Exception")
+ void testInconsistentNumLabelsException() {
+ assertThrows(UnsupportedOperationException.class, () -> {
+ LibSvmRecordReader rr = new LibSvmRecordReader();
+ Configuration config = new Configuration();
+ config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
+ rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/inconsistentNumLabels.txt").getFile()));
+ while (rr.hasNext()) rr.next();
+ });
+ }
+
+ @Test
+ @DisplayName("Test Inconsistent Num Multiabels Exception")
+ void testInconsistentNumMultiabelsException() {
+ assertThrows(UnsupportedOperationException.class, () -> {
+ LibSvmRecordReader rr = new LibSvmRecordReader();
+ Configuration config = new Configuration();
+ config.setBoolean(LibSvmRecordReader.MULTILABEL, false);
+ config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
+ rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile()));
+ while (rr.hasNext()) rr.next();
+ });
+ }
+
+ @Test
+ @DisplayName("Test Feature Index Exceeds Num Features")
+ void testFeatureIndexExceedsNumFeatures() {
+ assertThrows(IndexOutOfBoundsException.class, () -> {
+ LibSvmRecordReader rr = new LibSvmRecordReader();
+ Configuration config = new Configuration();
+ config.setInt(LibSvmRecordReader.NUM_FEATURES, 9);
+ rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
rr.next();
+ });
}
- @Test(expected = UnsupportedOperationException.class)
- public void testInconsistentNumLabelsException() throws Exception {
- LibSvmRecordReader rr = new LibSvmRecordReader();
- Configuration config = new Configuration();
- config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
- rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/inconsistentNumLabels.txt").getFile()));
- while (rr.hasNext())
+ @Test
+ @DisplayName("Test Label Index Exceeds Num Labels")
+ void testLabelIndexExceedsNumLabels() {
+ assertThrows(IndexOutOfBoundsException.class, () -> {
+ LibSvmRecordReader rr = new LibSvmRecordReader();
+ Configuration config = new Configuration();
+ config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
+ config.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
+ config.setInt(LibSvmRecordReader.NUM_LABELS, 6);
+ rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
rr.next();
+ });
}
- @Test(expected = UnsupportedOperationException.class)
- public void testInconsistentNumMultiabelsException() throws Exception {
- LibSvmRecordReader rr = new LibSvmRecordReader();
- Configuration config = new Configuration();
- config.setBoolean(LibSvmRecordReader.MULTILABEL, false);
- config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
- rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile()));
- while (rr.hasNext())
+ @Test
+ @DisplayName("Test Zero Index Feature Without Using Zero Indexing")
+ void testZeroIndexFeatureWithoutUsingZeroIndexing() {
+ assertThrows(IndexOutOfBoundsException.class, () -> {
+ LibSvmRecordReader rr = new LibSvmRecordReader();
+ Configuration config = new Configuration();
+ config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
+ config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
+ config.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
+ rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexFeature.txt").getFile()));
rr.next();
+ });
}
- @Test(expected = IndexOutOfBoundsException.class)
- public void testFeatureIndexExceedsNumFeatures() throws Exception {
- LibSvmRecordReader rr = new LibSvmRecordReader();
- Configuration config = new Configuration();
- config.setInt(LibSvmRecordReader.NUM_FEATURES, 9);
- rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
- rr.next();
- }
-
- @Test(expected = IndexOutOfBoundsException.class)
- public void testLabelIndexExceedsNumLabels() throws Exception {
- LibSvmRecordReader rr = new LibSvmRecordReader();
- Configuration config = new Configuration();
- config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
- config.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
- config.setInt(LibSvmRecordReader.NUM_LABELS, 6);
- rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
- rr.next();
- }
-
- @Test(expected = IndexOutOfBoundsException.class)
- public void testZeroIndexFeatureWithoutUsingZeroIndexing() throws Exception {
- LibSvmRecordReader rr = new LibSvmRecordReader();
- Configuration config = new Configuration();
- config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
- config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
- config.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
- rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexFeature.txt").getFile()));
- rr.next();
- }
-
- @Test(expected = IndexOutOfBoundsException.class)
- public void testZeroIndexLabelWithoutUsingZeroIndexing() throws Exception {
- LibSvmRecordReader rr = new LibSvmRecordReader();
- Configuration config = new Configuration();
- config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
- config.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
- config.setBoolean(LibSvmRecordReader.MULTILABEL, true);
- config.setInt(LibSvmRecordReader.NUM_LABELS, 2);
- rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexLabel.txt").getFile()));
- rr.next();
+ @Test
+ @DisplayName("Test Zero Index Label Without Using Zero Indexing")
+ void testZeroIndexLabelWithoutUsingZeroIndexing() {
+ assertThrows(IndexOutOfBoundsException.class, () -> {
+ LibSvmRecordReader rr = new LibSvmRecordReader();
+ Configuration config = new Configuration();
+ config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
+ config.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
+ config.setBoolean(LibSvmRecordReader.MULTILABEL, true);
+ config.setInt(LibSvmRecordReader.NUM_LABELS, 2);
+ rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexLabel.txt").getFile()));
+ rr.next();
+ });
}
}
diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LineReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LineReaderTest.java
index 18dc8b0fd..dd81758d0 100644
--- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LineReaderTest.java
+++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LineReaderTest.java
@@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
-
package org.datavec.api.records.reader.impl;
import org.apache.commons.io.FileUtils;
@@ -31,10 +30,9 @@ import org.datavec.api.split.InputSplit;
import org.datavec.api.split.InputStreamInputSplit;
import org.datavec.api.writable.Writable;
import org.junit.Rule;
-import org.junit.Test;
-import org.junit.rules.TemporaryFolder;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
-
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
@@ -45,34 +43,31 @@ import java.util.Arrays;
import java.util.List;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import org.junit.jupiter.api.DisplayName;
+import java.nio.file.Path;
+import org.junit.jupiter.api.extension.ExtendWith;
-import static org.junit.Assert.assertEquals;
+@DisplayName("Line Reader Test")
+class LineReaderTest extends BaseND4JTest {
-public class LineReaderTest extends BaseND4JTest {
-
- @Rule
- public TemporaryFolder testDir = new TemporaryFolder();
@Test
- public void testLineReader() throws Exception {
- File tmpdir = testDir.newFolder();
+ @DisplayName("Test Line Reader")
+ void testLineReader(@TempDir Path tmpDir) throws Exception {
+ File tmpdir = tmpDir.toFile();
if (tmpdir.exists())
tmpdir.delete();
tmpdir.mkdir();
-
File tmp1 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp1.txt"));
File tmp2 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp2.txt"));
File tmp3 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp3.txt"));
-
FileUtils.writeLines(tmp1, Arrays.asList("1", "2", "3"));
FileUtils.writeLines(tmp2, Arrays.asList("4", "5", "6"));
FileUtils.writeLines(tmp3, Arrays.asList("7", "8", "9"));
-
InputSplit split = new FileSplit(tmpdir);
-
RecordReader reader = new LineRecordReader();
reader.initialize(split);
-
int count = 0;
List> list = new ArrayList<>();
while (reader.hasNext()) {
@@ -81,34 +76,27 @@ public class LineReaderTest extends BaseND4JTest {
list.add(l);
count++;
}
-
assertEquals(9, count);
}
@Test
- public void testLineReaderMetaData() throws Exception {
- File tmpdir = testDir.newFolder();
-
+ @DisplayName("Test Line Reader Meta Data")
+ void testLineReaderMetaData(@TempDir Path tmpDir) throws Exception {
+ File tmpdir = tmpDir.toFile();
File tmp1 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp1.txt"));
File tmp2 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp2.txt"));
File tmp3 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp3.txt"));
-
FileUtils.writeLines(tmp1, Arrays.asList("1", "2", "3"));
FileUtils.writeLines(tmp2, Arrays.asList("4", "5", "6"));
FileUtils.writeLines(tmp3, Arrays.asList("7", "8", "9"));
-
InputSplit split = new FileSplit(tmpdir);
-
RecordReader reader = new LineRecordReader();
reader.initialize(split);
-
List> list = new ArrayList<>();
while (reader.hasNext()) {
list.add(reader.next());
}
assertEquals(9, list.size());
-
-
List> out2 = new ArrayList<>();
List out3 = new ArrayList<>();
List meta = new ArrayList<>();
@@ -124,13 +112,10 @@ public class LineReaderTest extends BaseND4JTest {
assertEquals(uri, split.locations()[fileIdx]);
count++;
}
-
assertEquals(list, out2);
-
List fromMeta = reader.loadFromMetaData(meta);
assertEquals(out3, fromMeta);
-
- //try: second line of second and third files only...
+ // try: second line of second and third files only...
List subsetMeta = new ArrayList<>();
subsetMeta.add(meta.get(4));
subsetMeta.add(meta.get(7));
@@ -141,27 +126,22 @@ public class LineReaderTest extends BaseND4JTest {
}
@Test
- public void testLineReaderWithInputStreamInputSplit() throws Exception {
- File tmpdir = testDir.newFolder();
-
+ @DisplayName("Test Line Reader With Input Stream Input Split")
+ void testLineReaderWithInputStreamInputSplit(@TempDir Path testDir) throws Exception {
+ File tmpdir = testDir.toFile();
File tmp1 = new File(tmpdir, "tmp1.txt.gz");
-
OutputStream os = new GZIPOutputStream(new FileOutputStream(tmp1, false));
IOUtils.writeLines(Arrays.asList("1", "2", "3", "4", "5", "6", "7", "8", "9"), null, os);
os.flush();
os.close();
-
InputSplit split = new InputStreamInputSplit(new GZIPInputStream(new FileInputStream(tmp1)));
-
RecordReader reader = new LineRecordReader();
reader.initialize(split);
-
int count = 0;
while (reader.hasNext()) {
assertEquals(1, reader.next().size());
count++;
}
-
assertEquals(9, count);
}
}
diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/RegexRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/RegexRecordReaderTest.java
index 97e1a854a..997a6de10 100644
--- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/RegexRecordReaderTest.java
+++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/RegexRecordReaderTest.java
@@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
-
package org.datavec.api.records.reader.impl;
import org.datavec.api.records.Record;
@@ -34,43 +33,40 @@ import org.datavec.api.split.NumberedFileInputSplit;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Rule;
-import org.junit.Test;
-import org.junit.rules.TemporaryFolder;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
-
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import org.junit.jupiter.api.DisplayName;
+import java.nio.file.Path;
+import org.junit.jupiter.api.extension.ExtendWith;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
+@DisplayName("Regex Record Reader Test")
+class RegexRecordReaderTest extends BaseND4JTest {
-public class RegexRecordReaderTest extends BaseND4JTest {
-
- @Rule
- public TemporaryFolder testDir = new TemporaryFolder();
+ @TempDir
+ public Path testDir;
@Test
- public void testRegexLineRecordReader() throws Exception {
+ @DisplayName("Test Regex Line Record Reader")
+ void testRegexLineRecordReader() throws Exception {
String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)";
-
RecordReader rr = new RegexLineRecordReader(regex, 1);
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/logtestdata/logtestfile0.txt").getFile()));
-
- List exp0 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.001"), new Text("1"),
- new Text("DEBUG"), new Text("First entry message!"));
- List exp1 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.002"), new Text("2"),
- new Text("INFO"), new Text("Second entry message!"));
- List exp2 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.003"), new Text("3"),
- new Text("WARN"), new Text("Third entry message!"));
+ List exp0 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.001"), new Text("1"), new Text("DEBUG"), new Text("First entry message!"));
+ List exp1 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.002"), new Text("2"), new Text("INFO"), new Text("Second entry message!"));
+ List exp2 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.003"), new Text("3"), new Text("WARN"), new Text("Third entry message!"));
assertEquals(exp0, rr.next());
assertEquals(exp1, rr.next());
assertEquals(exp2, rr.next());
assertFalse(rr.hasNext());
-
- //Test reset:
+ // Test reset:
rr.reset();
assertEquals(exp0, rr.next());
assertEquals(exp1, rr.next());
@@ -79,74 +75,57 @@ public class RegexRecordReaderTest extends BaseND4JTest {
}
@Test
- public void testRegexLineRecordReaderMeta() throws Exception {
+ @DisplayName("Test Regex Line Record Reader Meta")
+ void testRegexLineRecordReaderMeta() throws Exception {
String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)";
-
RecordReader rr = new RegexLineRecordReader(regex, 1);
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/logtestdata/logtestfile0.txt").getFile()));
-
List> list = new ArrayList<>();
while (rr.hasNext()) {
list.add(rr.next());
}
assertEquals(3, list.size());
-
List list2 = new ArrayList<>();
List> list3 = new ArrayList<>();
List meta = new ArrayList<>();
rr.reset();
- int count = 1; //Start by skipping 1 line
+ // Start by skipping 1 line
+ int count = 1;
while (rr.hasNext()) {
Record r = rr.nextRecord();
list2.add(r);
list3.add(r.getRecord());
meta.add(r.getMetaData());
-
assertEquals(count++, ((RecordMetaDataLine) r.getMetaData()).getLineNumber());
}
-
List fromMeta = rr.loadFromMetaData(meta);
-
assertEquals(list, list3);
assertEquals(list2, fromMeta);
}
@Test
- public void testRegexSequenceRecordReader() throws Exception {
+ @DisplayName("Test Regex Sequence Record Reader")
+ void testRegexSequenceRecordReader(@TempDir Path testDir) throws Exception {
String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)";
-
ClassPathResource cpr = new ClassPathResource("datavec-api/logtestdata/");
- File f = testDir.newFolder();
+ File f = testDir.toFile();
cpr.copyDirectory(f);
String path = new File(f, "logtestfile%d.txt").getAbsolutePath();
-
InputSplit is = new NumberedFileInputSplit(path, 0, 1);
-
SequenceRecordReader rr = new RegexSequenceRecordReader(regex, 1);
rr.initialize(is);
-
List> exp0 = new ArrayList<>();
- exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.001"), new Text("1"), new Text("DEBUG"),
- new Text("First entry message!")));
- exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.002"), new Text("2"), new Text("INFO"),
- new Text("Second entry message!")));
- exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.003"), new Text("3"), new Text("WARN"),
- new Text("Third entry message!")));
-
-
+ exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.001"), new Text("1"), new Text("DEBUG"), new Text("First entry message!")));
+ exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.002"), new Text("2"), new Text("INFO"), new Text("Second entry message!")));
+ exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.003"), new Text("3"), new Text("WARN"), new Text("Third entry message!")));
List> exp1 = new ArrayList<>();
- exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.011"), new Text("11"), new Text("DEBUG"),
- new Text("First entry message!")));
- exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.012"), new Text("12"), new Text("INFO"),
- new Text("Second entry message!")));
- exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.013"), new Text("13"), new Text("WARN"),
- new Text("Third entry message!")));
-
+ exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.011"), new Text("11"), new Text("DEBUG"), new Text("First entry message!")));
+ exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.012"), new Text("12"), new Text("INFO"), new Text("Second entry message!")));
+ exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.013"), new Text("13"), new Text("WARN"), new Text("Third entry message!")));
assertEquals(exp0, rr.sequenceRecord());
assertEquals(exp1, rr.sequenceRecord());
assertFalse(rr.hasNext());
-
- //Test resetting:
+ // Test resetting:
rr.reset();
assertEquals(exp0, rr.sequenceRecord());
assertEquals(exp1, rr.sequenceRecord());
@@ -154,24 +133,20 @@ public class RegexRecordReaderTest extends BaseND4JTest {
}
@Test
- public void testRegexSequenceRecordReaderMeta() throws Exception {
+ @DisplayName("Test Regex Sequence Record Reader Meta")
+ void testRegexSequenceRecordReaderMeta(@TempDir Path testDir) throws Exception {
String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)";
-
ClassPathResource cpr = new ClassPathResource("datavec-api/logtestdata/");
- File f = testDir.newFolder();
+ File f = testDir.toFile();
cpr.copyDirectory(f);
String path = new File(f, "logtestfile%d.txt").getAbsolutePath();
-
InputSplit is = new NumberedFileInputSplit(path, 0, 1);
-
SequenceRecordReader rr = new RegexSequenceRecordReader(regex, 1);
rr.initialize(is);
-
List>> out = new ArrayList<>();
while (rr.hasNext()) {
out.add(rr.sequenceRecord());
}
-
assertEquals(2, out.size());
List>> out2 = new ArrayList<>();
List out3 = new ArrayList<>();
@@ -183,11 +158,8 @@ public class RegexRecordReaderTest extends BaseND4JTest {
out3.add(seqr);
meta.add(seqr.getMetaData());
}
-
List fromMeta = rr.loadSequenceFromMetaData(meta);
-
assertEquals(out, out2);
assertEquals(out3, fromMeta);
}
-
}
diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/SVMLightRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/SVMLightRecordReaderTest.java
index 35b2d6a46..c072cea97 100644
--- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/SVMLightRecordReaderTest.java
+++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/SVMLightRecordReaderTest.java
@@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
-
package org.datavec.api.records.reader.impl;
import org.datavec.api.conf.Configuration;
@@ -27,43 +26,30 @@ import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable;
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
-
import java.io.IOException;
import java.util.*;
-
import static org.datavec.api.records.reader.impl.misc.SVMLightRecordReader.*;
-import static org.junit.Assert.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.extension.ExtendWith;
+import static org.junit.jupiter.api.Assertions.assertThrows;
-public class SVMLightRecordReaderTest extends BaseND4JTest {
+@DisplayName("Svm Light Record Reader Test")
+class SVMLightRecordReaderTest extends BaseND4JTest {
@Test
- public void testBasicRecord() throws IOException, InterruptedException {
+ @DisplayName("Test Basic Record")
+ void testBasicRecord() throws IOException, InterruptedException {
Map> correct = new HashMap<>();
// 7 2:1 4:2 6:3 8:4 10:5
- correct.put(0, Arrays.asList(ZERO, ONE,
- ZERO, new DoubleWritable(2),
- ZERO, new DoubleWritable(3),
- ZERO, new DoubleWritable(4),
- ZERO, new DoubleWritable(5),
- new IntWritable(7)));
+ correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), new IntWritable(7)));
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
- correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
- ZERO, ZERO,
- ZERO, new DoubleWritable(6.6),
- ZERO, new DoubleWritable(80),
- ZERO, ZERO,
- new IntWritable(2)));
+ correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, new IntWritable(2)));
// 33
- correct.put(2, Arrays.asList(ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- new IntWritable(33)));
-
+ correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, new IntWritable(33)));
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
@@ -79,27 +65,15 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
}
@Test
- public void testNoAppendLabel() throws IOException, InterruptedException {
+ @DisplayName("Test No Append Label")
+ void testNoAppendLabel() throws IOException, InterruptedException {
Map> correct = new HashMap<>();
// 7 2:1 4:2 6:3 8:4 10:5
- correct.put(0, Arrays.asList(ZERO, ONE,
- ZERO, new DoubleWritable(2),
- ZERO, new DoubleWritable(3),
- ZERO, new DoubleWritable(4),
- ZERO, new DoubleWritable(5)));
+ correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5)));
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
- correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
- ZERO, ZERO,
- ZERO, new DoubleWritable(6.6),
- ZERO, new DoubleWritable(80),
- ZERO, ZERO));
+ correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO));
// 33
- correct.put(2, Arrays.asList(ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO));
-
+ correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO));
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
@@ -116,33 +90,17 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
}
@Test
- public void testNoLabel() throws IOException, InterruptedException {
+ @DisplayName("Test No Label")
+ void testNoLabel() throws IOException, InterruptedException {
Map> correct = new HashMap<>();
- // 2:1 4:2 6:3 8:4 10:5
- correct.put(0, Arrays.asList(ZERO, ONE,
- ZERO, new DoubleWritable(2),
- ZERO, new DoubleWritable(3),
- ZERO, new DoubleWritable(4),
- ZERO, new DoubleWritable(5)));
- // qid:42 1:0.1 2:2 6:6.6 8:80
- correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
- ZERO, ZERO,
- ZERO, new DoubleWritable(6.6),
- ZERO, new DoubleWritable(80),
- ZERO, ZERO));
- // 1:1.0
- correct.put(2, Arrays.asList(new DoubleWritable(1.0), ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO));
- //
- correct.put(3, Arrays.asList(ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO));
-
+ // 2:1 4:2 6:3 8:4 10:5
+ correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5)));
+ // qid:42 1:0.1 2:2 6:6.6 8:80
+ correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO));
+ // 1:1.0
+ correct.put(2, Arrays.asList(new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO));
+ //
+ correct.put(3, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO));
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
@@ -159,33 +117,15 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
}
@Test
- public void testMultioutputRecord() throws IOException, InterruptedException {
+ @DisplayName("Test Multioutput Record")
+ void testMultioutputRecord() throws IOException, InterruptedException {
Map> correct = new HashMap<>();
// 7 2.45,9 2:1 4:2 6:3 8:4 10:5
- correct.put(0, Arrays.asList(ZERO, ONE,
- ZERO, new DoubleWritable(2),
- ZERO, new DoubleWritable(3),
- ZERO, new DoubleWritable(4),
- ZERO, new DoubleWritable(5),
- new IntWritable(7), new DoubleWritable(2.45),
- new IntWritable(9)));
+ correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), new IntWritable(7), new DoubleWritable(2.45), new IntWritable(9)));
// 2,3,4 qid:42 1:0.1 2:2 6:6.6 8:80
- correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
- ZERO, ZERO,
- ZERO, new DoubleWritable(6.6),
- ZERO, new DoubleWritable(80),
- ZERO, ZERO,
- new IntWritable(2), new IntWritable(3),
- new IntWritable(4)));
+ correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, new IntWritable(2), new IntWritable(3), new IntWritable(4)));
// 33,32.0,31.9
- correct.put(2, Arrays.asList(ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- new IntWritable(33), new DoubleWritable(32.0),
- new DoubleWritable(31.9)));
-
+ correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, new IntWritable(33), new DoubleWritable(32.0), new DoubleWritable(31.9)));
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
@@ -200,51 +140,20 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
assertEquals(i, correct.size());
}
-
@Test
- public void testMultilabelRecord() throws IOException, InterruptedException {
+ @DisplayName("Test Multilabel Record")
+ void testMultilabelRecord() throws IOException, InterruptedException {
Map> correct = new HashMap<>();
// 1,3 2:1 4:2 6:3 8:4 10:5
- correct.put(0, Arrays.asList(ZERO, ONE,
- ZERO, new DoubleWritable(2),
- ZERO, new DoubleWritable(3),
- ZERO, new DoubleWritable(4),
- ZERO, new DoubleWritable(5),
- LABEL_ONE, LABEL_ZERO,
- LABEL_ONE, LABEL_ZERO));
+ correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), LABEL_ONE, LABEL_ZERO, LABEL_ONE, LABEL_ZERO));
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
- correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
- ZERO, ZERO,
- ZERO, new DoubleWritable(6.6),
- ZERO, new DoubleWritable(80),
- ZERO, ZERO,
- LABEL_ZERO, LABEL_ONE,
- LABEL_ZERO, LABEL_ZERO));
+ correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ZERO));
// 1,2,4
- correct.put(2, Arrays.asList(ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- LABEL_ONE, LABEL_ONE,
- LABEL_ZERO, LABEL_ONE));
- // 1:1.0
- correct.put(3, Arrays.asList(new DoubleWritable(1.0), ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- LABEL_ZERO, LABEL_ZERO,
- LABEL_ZERO, LABEL_ZERO));
- //
- correct.put(4, Arrays.asList(ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- LABEL_ZERO, LABEL_ZERO,
- LABEL_ZERO, LABEL_ZERO));
-
+ correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ONE, LABEL_ONE, LABEL_ZERO, LABEL_ONE));
+ // 1:1.0
+ correct.put(3, Arrays.asList(new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
+ //
+ correct.put(4, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
@@ -262,63 +171,24 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
}
@Test
- public void testZeroBasedIndexing() throws IOException, InterruptedException {
+ @DisplayName("Test Zero Based Indexing")
+ void testZeroBasedIndexing() throws IOException, InterruptedException {
Map> correct = new HashMap<>();
// 1,3 2:1 4:2 6:3 8:4 10:5
- correct.put(0, Arrays.asList(ZERO,
- ZERO, ONE,
- ZERO, new DoubleWritable(2),
- ZERO, new DoubleWritable(3),
- ZERO, new DoubleWritable(4),
- ZERO, new DoubleWritable(5),
- LABEL_ZERO,
- LABEL_ONE, LABEL_ZERO,
- LABEL_ONE, LABEL_ZERO));
+ correct.put(0, Arrays.asList(ZERO, ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ONE, LABEL_ZERO));
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
- correct.put(1, Arrays.asList(ZERO,
- new DoubleWritable(0.1), new DoubleWritable(2),
- ZERO, ZERO,
- ZERO, new DoubleWritable(6.6),
- ZERO, new DoubleWritable(80),
- ZERO, ZERO,
- LABEL_ZERO,
- LABEL_ZERO, LABEL_ONE,
- LABEL_ZERO, LABEL_ZERO));
+ correct.put(1, Arrays.asList(ZERO, new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ZERO));
// 1,2,4
- correct.put(2, Arrays.asList(ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- LABEL_ZERO,
- LABEL_ONE, LABEL_ONE,
- LABEL_ZERO, LABEL_ONE));
- // 1:1.0
- correct.put(3, Arrays.asList(ZERO,
- new DoubleWritable(1.0), ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- LABEL_ZERO,
- LABEL_ZERO, LABEL_ZERO,
- LABEL_ZERO, LABEL_ZERO));
- //
- correct.put(4, Arrays.asList(ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- ZERO, ZERO,
- LABEL_ZERO,
- LABEL_ZERO, LABEL_ZERO,
- LABEL_ZERO, LABEL_ZERO));
-
+ correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ONE, LABEL_ZERO, LABEL_ONE));
+ // 1:1.0
+ correct.put(3, Arrays.asList(ZERO, new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
+ //
+ correct.put(4, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
// Zero-based indexing is default
- config.setBoolean(SVMLightRecordReader.ZERO_BASED_LABEL_INDEXING, true); // NOT STANDARD!
+ // NOT STANDARD!
+ config.setBoolean(SVMLightRecordReader.ZERO_BASED_LABEL_INDEXING, true);
config.setInt(SVMLightRecordReader.NUM_FEATURES, 11);
config.setBoolean(SVMLightRecordReader.MULTILABEL, true);
config.setInt(SVMLightRecordReader.NUM_LABELS, 5);
@@ -333,20 +203,19 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
}
@Test
- public void testNextRecord() throws IOException, InterruptedException {
+ @DisplayName("Test Next Record")
+ void testNextRecord() throws IOException, InterruptedException {
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
config.setBoolean(SVMLightRecordReader.APPEND_LABEL, false);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
-
Record record = rr.nextRecord();
List recordList = record.getRecord();
assertEquals(new DoubleWritable(1.0), recordList.get(1));
assertEquals(new DoubleWritable(3.0), recordList.get(5));
assertEquals(new DoubleWritable(4.0), recordList.get(7));
-
record = rr.nextRecord();
recordList = record.getRecord();
assertEquals(new DoubleWritable(0.1), recordList.get(0));
@@ -354,82 +223,102 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
assertEquals(new DoubleWritable(80.0), recordList.get(7));
}
- @Test(expected = NoSuchElementException.class)
- public void testNoSuchElementException() throws Exception {
- SVMLightRecordReader rr = new SVMLightRecordReader();
- Configuration config = new Configuration();
- config.setInt(SVMLightRecordReader.NUM_FEATURES, 11);
- rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
- while (rr.hasNext())
+ @Test
+ @DisplayName("Test No Such Element Exception")
+ void testNoSuchElementException() {
+ assertThrows(NoSuchElementException.class, () -> {
+ SVMLightRecordReader rr = new SVMLightRecordReader();
+ Configuration config = new Configuration();
+ config.setInt(SVMLightRecordReader.NUM_FEATURES, 11);
+ rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
+ while (rr.hasNext()) rr.next();
rr.next();
- rr.next();
+ });
}
- @Test(expected = UnsupportedOperationException.class)
- public void failedToSetNumFeaturesException() throws Exception {
- SVMLightRecordReader rr = new SVMLightRecordReader();
- Configuration config = new Configuration();
- rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
- while (rr.hasNext())
+ @Test
+ @DisplayName("Failed To Set Num Features Exception")
+ void failedToSetNumFeaturesException() {
+ assertThrows(UnsupportedOperationException.class, () -> {
+ SVMLightRecordReader rr = new SVMLightRecordReader();
+ Configuration config = new Configuration();
+ rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
+ while (rr.hasNext()) rr.next();
+ });
+ }
+
+ @Test
+ @DisplayName("Test Inconsistent Num Labels Exception")
+ void testInconsistentNumLabelsException() {
+ assertThrows(UnsupportedOperationException.class, () -> {
+ SVMLightRecordReader rr = new SVMLightRecordReader();
+ Configuration config = new Configuration();
+ config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
+ rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/inconsistentNumLabels.txt").getFile()));
+ while (rr.hasNext()) rr.next();
+ });
+ }
+
+ @Test
+ @DisplayName("Failed To Set Num Multiabels Exception")
+ void failedToSetNumMultiabelsException() {
+ assertThrows(UnsupportedOperationException.class, () -> {
+ SVMLightRecordReader rr = new SVMLightRecordReader();
+ Configuration config = new Configuration();
+ rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile()));
+ while (rr.hasNext()) rr.next();
+ });
+ }
+
+ @Test
+ @DisplayName("Test Feature Index Exceeds Num Features")
+ void testFeatureIndexExceedsNumFeatures() {
+ assertThrows(IndexOutOfBoundsException.class, () -> {
+ SVMLightRecordReader rr = new SVMLightRecordReader();
+ Configuration config = new Configuration();
+ config.setInt(SVMLightRecordReader.NUM_FEATURES, 9);
+ rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
rr.next();
+ });
}
- @Test(expected = UnsupportedOperationException.class)
- public void testInconsistentNumLabelsException() throws Exception {
- SVMLightRecordReader rr = new SVMLightRecordReader();
- Configuration config = new Configuration();
- config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
- rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/inconsistentNumLabels.txt").getFile()));
- while (rr.hasNext())
+ @Test
+ @DisplayName("Test Label Index Exceeds Num Labels")
+ void testLabelIndexExceedsNumLabels() {
+ assertThrows(IndexOutOfBoundsException.class, () -> {
+ SVMLightRecordReader rr = new SVMLightRecordReader();
+ Configuration config = new Configuration();
+ config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
+ config.setInt(SVMLightRecordReader.NUM_LABELS, 6);
+ rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
rr.next();
+ });
}
- @Test(expected = UnsupportedOperationException.class)
- public void failedToSetNumMultiabelsException() throws Exception {
- SVMLightRecordReader rr = new SVMLightRecordReader();
- Configuration config = new Configuration();
- rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile()));
- while (rr.hasNext())
+ @Test
+ @DisplayName("Test Zero Index Feature Without Using Zero Indexing")
+ void testZeroIndexFeatureWithoutUsingZeroIndexing() {
+ assertThrows(IndexOutOfBoundsException.class, () -> {
+ SVMLightRecordReader rr = new SVMLightRecordReader();
+ Configuration config = new Configuration();
+ config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
+ config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
+ rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexFeature.txt").getFile()));
rr.next();
+ });
}
- @Test(expected = IndexOutOfBoundsException.class)
- public void testFeatureIndexExceedsNumFeatures() throws Exception {
- SVMLightRecordReader rr = new SVMLightRecordReader();
- Configuration config = new Configuration();
- config.setInt(SVMLightRecordReader.NUM_FEATURES, 9);
- rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
- rr.next();
- }
-
- @Test(expected = IndexOutOfBoundsException.class)
- public void testLabelIndexExceedsNumLabels() throws Exception {
- SVMLightRecordReader rr = new SVMLightRecordReader();
- Configuration config = new Configuration();
- config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
- config.setInt(SVMLightRecordReader.NUM_LABELS, 6);
- rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
- rr.next();
- }
-
- @Test(expected = IndexOutOfBoundsException.class)
- public void testZeroIndexFeatureWithoutUsingZeroIndexing() throws Exception {
- SVMLightRecordReader rr = new SVMLightRecordReader();
- Configuration config = new Configuration();
- config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
- config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
- rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexFeature.txt").getFile()));
- rr.next();
- }
-
- @Test(expected = IndexOutOfBoundsException.class)
- public void testZeroIndexLabelWithoutUsingZeroIndexing() throws Exception {
- SVMLightRecordReader rr = new SVMLightRecordReader();
- Configuration config = new Configuration();
- config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
- config.setBoolean(SVMLightRecordReader.MULTILABEL, true);
- config.setInt(SVMLightRecordReader.NUM_LABELS, 2);
- rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexLabel.txt").getFile()));
- rr.next();
+ @Test
+ @DisplayName("Test Zero Index Label Without Using Zero Indexing")
+ void testZeroIndexLabelWithoutUsingZeroIndexing() {
+ assertThrows(IndexOutOfBoundsException.class, () -> {
+ SVMLightRecordReader rr = new SVMLightRecordReader();
+ Configuration config = new Configuration();
+ config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
+ config.setBoolean(SVMLightRecordReader.MULTILABEL, true);
+ config.setInt(SVMLightRecordReader.NUM_LABELS, 2);
+ rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexLabel.txt").getFile()));
+ rr.next();
+ });
}
}
diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/CSVRecordWriterTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/CSVRecordWriterTest.java
index 5890722b3..c63240896 100644
--- a/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/CSVRecordWriterTest.java
+++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/CSVRecordWriterTest.java
@@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
-
package org.datavec.api.records.writer.impl;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
@@ -26,44 +25,42 @@ import org.datavec.api.split.FileSplit;
import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
-import org.junit.Before;
-import org.junit.Test;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
-
import java.io.File;
import java.util.ArrayList;
import java.util.List;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.extension.ExtendWith;
-import static org.junit.Assert.assertEquals;
-
-public class CSVRecordWriterTest extends BaseND4JTest {
-
- @Before
- public void setUp() throws Exception {
+@DisplayName("Csv Record Writer Test")
+class CSVRecordWriterTest extends BaseND4JTest {
+ @BeforeEach
+ void setUp() throws Exception {
}
@Test
- public void testWrite() throws Exception {
+ @DisplayName("Test Write")
+ void testWrite() throws Exception {
File tempFile = File.createTempFile("datavec", "writer");
tempFile.deleteOnExit();
FileSplit fileSplit = new FileSplit(tempFile);
CSVRecordWriter writer = new CSVRecordWriter();
- writer.initialize(fileSplit,new NumberOfRecordsPartitioner());
+ writer.initialize(fileSplit, new NumberOfRecordsPartitioner());
List collection = new ArrayList<>();
collection.add(new Text("12"));
collection.add(new Text("13"));
collection.add(new Text("14"));
-
writer.write(collection);
-
CSVRecordReader reader = new CSVRecordReader(0);
reader.initialize(new FileSplit(tempFile));
int cnt = 0;
while (reader.hasNext()) {
List line = new ArrayList<>(reader.next());
assertEquals(3, line.size());
-
assertEquals(12, line.get(0).toInt());
assertEquals(13, line.get(1).toInt());
assertEquals(14, line.get(2).toInt());
diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/LibSvmRecordWriterTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/LibSvmRecordWriterTest.java
index 0e80e10b7..66c9ab3d2 100644
--- a/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/LibSvmRecordWriterTest.java
+++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/LibSvmRecordWriterTest.java
@@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
-
package org.datavec.api.records.writer.impl;
import org.apache.commons.io.FileUtils;
@@ -30,93 +29,90 @@ import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.io.ClassPathResource;
-
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.extension.ExtendWith;
+import static org.junit.jupiter.api.Assertions.assertThrows;
-import static org.junit.Assert.assertEquals;
-
-public class LibSvmRecordWriterTest extends BaseND4JTest {
+@DisplayName("Lib Svm Record Writer Test")
+class LibSvmRecordWriterTest extends BaseND4JTest {
@Test
- public void testBasic() throws Exception {
+ @DisplayName("Test Basic")
+ void testBasic() throws Exception {
Configuration configWriter = new Configuration();
-
Configuration configReader = new Configuration();
configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
-
File inputFile = new ClassPathResource("datavec-api/svmlight/basic.txt").getFile();
executeTest(configWriter, configReader, inputFile);
}
@Test
- public void testNoLabel() throws Exception {
+ @DisplayName("Test No Label")
+ void testNoLabel() throws Exception {
Configuration configWriter = new Configuration();
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 9);
-
Configuration configReader = new Configuration();
configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
-
File inputFile = new ClassPathResource("datavec-api/svmlight/basic.txt").getFile();
executeTest(configWriter, configReader, inputFile);
}
@Test
- public void testMultioutputRecord() throws Exception {
+ @DisplayName("Test Multioutput Record")
+ void testMultioutputRecord() throws Exception {
Configuration configWriter = new Configuration();
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 9);
-
Configuration configReader = new Configuration();
configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
-
File inputFile = new ClassPathResource("datavec-api/svmlight/multioutput.txt").getFile();
executeTest(configWriter, configReader, inputFile);
}
@Test
- public void testMultilabelRecord() throws Exception {
+ @DisplayName("Test Multilabel Record")
+ void testMultilabelRecord() throws Exception {
Configuration configWriter = new Configuration();
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 9);
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
-
Configuration configReader = new Configuration();
configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
configReader.setBoolean(LibSvmRecordReader.MULTILABEL, true);
configReader.setInt(LibSvmRecordReader.NUM_LABELS, 4);
configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
-
File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile();
executeTest(configWriter, configReader, inputFile);
}
@Test
- public void testZeroBasedIndexing() throws Exception {
+ @DisplayName("Test Zero Based Indexing")
+ void testZeroBasedIndexing() throws Exception {
Configuration configWriter = new Configuration();
configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_INDEXING, true);
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 10);
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
-
Configuration configReader = new Configuration();
configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 11);
configReader.setBoolean(LibSvmRecordReader.MULTILABEL, true);
configReader.setInt(LibSvmRecordReader.NUM_LABELS, 5);
-
File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile();
executeTest(configWriter, configReader, inputFile);
}
@@ -127,10 +123,9 @@ public class LibSvmRecordWriterTest extends BaseND4JTest {
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
-
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
- FileSplit outputSplit = new FileSplit(tempFile);
- writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
+ FileSplit outputSplit = new FileSplit(tempFile);
+ writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
LibSvmRecordReader rr = new LibSvmRecordReader();
rr.initialize(configReader, new FileSplit(inputFile));
while (rr.hasNext()) {
@@ -138,7 +133,6 @@ public class LibSvmRecordWriterTest extends BaseND4JTest {
writer.write(record);
}
}
-
Pattern p = Pattern.compile(String.format("%s:\\d+ ", LibSvmRecordReader.QID_PREFIX));
List linesOriginal = new ArrayList<>();
for (String line : FileUtils.readLines(inputFile)) {
@@ -159,7 +153,8 @@ public class LibSvmRecordWriterTest extends BaseND4JTest {
}
@Test
- public void testNDArrayWritables() throws Exception {
+ @DisplayName("Test ND Array Writables")
+ void testNDArrayWritables() throws Exception {
INDArray arr2 = Nd4j.zeros(2);
arr2.putScalar(0, 11);
arr2.putScalar(1, 12);
@@ -167,35 +162,28 @@ public class LibSvmRecordWriterTest extends BaseND4JTest {
arr3.putScalar(0, 13);
arr3.putScalar(1, 14);
arr3.putScalar(2, 15);
- List record = Arrays.asList((Writable) new DoubleWritable(1),
- new NDArrayWritable(arr2),
- new IntWritable(2),
- new DoubleWritable(3),
- new NDArrayWritable(arr3),
- new IntWritable(4));
+ List record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new IntWritable(4));
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
tempFile.setWritable(true);
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
-
String lineOriginal = "13.0,14.0,15.0,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0";
-
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
Configuration configWriter = new Configuration();
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 3);
FileSplit outputSplit = new FileSplit(tempFile);
- writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
+ writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record);
}
-
String lineNew = FileUtils.readFileToString(tempFile).trim();
assertEquals(lineOriginal, lineNew);
}
@Test
- public void testNDArrayWritablesMultilabel() throws Exception {
+ @DisplayName("Test ND Array Writables Multilabel")
+ void testNDArrayWritablesMultilabel() throws Exception {
INDArray arr2 = Nd4j.zeros(2);
arr2.putScalar(0, 11);
arr2.putScalar(1, 12);
@@ -203,36 +191,29 @@ public class LibSvmRecordWriterTest extends BaseND4JTest {
arr3.putScalar(0, 0);
arr3.putScalar(1, 1);
arr3.putScalar(2, 0);
- List record = Arrays.asList((Writable) new DoubleWritable(1),
- new NDArrayWritable(arr2),
- new IntWritable(2),
- new DoubleWritable(3),
- new NDArrayWritable(arr3),
- new DoubleWritable(1));
+ List record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new DoubleWritable(1));
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
tempFile.setWritable(true);
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
-
String lineOriginal = "2,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0";
-
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
Configuration configWriter = new Configuration();
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 3);
FileSplit outputSplit = new FileSplit(tempFile);
- writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
+ writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record);
}
-
String lineNew = FileUtils.readFileToString(tempFile).trim();
assertEquals(lineOriginal, lineNew);
}
@Test
- public void testNDArrayWritablesZeroIndex() throws Exception {
+ @DisplayName("Test ND Array Writables Zero Index")
+ void testNDArrayWritablesZeroIndex() throws Exception {
INDArray arr2 = Nd4j.zeros(2);
arr2.putScalar(0, 11);
arr2.putScalar(1, 12);
@@ -240,99 +221,91 @@ public class LibSvmRecordWriterTest extends BaseND4JTest {
arr3.putScalar(0, 0);
arr3.putScalar(1, 1);
arr3.putScalar(2, 0);
- List record = Arrays.asList((Writable) new DoubleWritable(1),
- new NDArrayWritable(arr2),
- new IntWritable(2),
- new DoubleWritable(3),
- new NDArrayWritable(arr3),
- new DoubleWritable(1));
+ List record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new DoubleWritable(1));
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
tempFile.setWritable(true);
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
-
String lineOriginal = "1,3 0:1.0 1:11.0 2:12.0 3:2.0 4:3.0";
-
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
Configuration configWriter = new Configuration();
- configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_INDEXING, true); // NOT STANDARD!
- configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_LABEL_INDEXING, true); // NOT STANDARD!
+ // NOT STANDARD!
+ configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_INDEXING, true);
+ // NOT STANDARD!
+ configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_LABEL_INDEXING, true);
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 3);
FileSplit outputSplit = new FileSplit(tempFile);
- writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
+ writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record);
}
-
String lineNew = FileUtils.readFileToString(tempFile).trim();
assertEquals(lineOriginal, lineNew);
}
@Test
- public void testNonIntegerButValidMultilabel() throws Exception {
- List record = Arrays.asList((Writable) new IntWritable(3),
- new IntWritable(2),
- new DoubleWritable(1.0));
+ @DisplayName("Test Non Integer But Valid Multilabel")
+ void testNonIntegerButValidMultilabel() throws Exception {
+ List record = Arrays.asList((Writable) new IntWritable(3), new IntWritable(2), new DoubleWritable(1.0));
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
tempFile.setWritable(true);
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
-
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
Configuration configWriter = new Configuration();
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1);
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
FileSplit outputSplit = new FileSplit(tempFile);
- writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
+ writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record);
}
}
- @Test(expected = NumberFormatException.class)
- public void nonIntegerMultilabel() throws Exception {
- List record = Arrays.asList((Writable) new IntWritable(3),
- new IntWritable(2),
- new DoubleWritable(1.2));
- File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
- tempFile.setWritable(true);
- tempFile.deleteOnExit();
- if (tempFile.exists())
- tempFile.delete();
-
- try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
- Configuration configWriter = new Configuration();
- configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
- configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1);
- configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
- FileSplit outputSplit = new FileSplit(tempFile);
- writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
- writer.write(record);
- }
+ @Test
+ @DisplayName("Non Integer Multilabel")
+ void nonIntegerMultilabel() {
+ assertThrows(NumberFormatException.class, () -> {
+ List record = Arrays.asList((Writable) new IntWritable(3), new IntWritable(2), new DoubleWritable(1.2));
+ File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
+ tempFile.setWritable(true);
+ tempFile.deleteOnExit();
+ if (tempFile.exists())
+ tempFile.delete();
+ try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
+ Configuration configWriter = new Configuration();
+ configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
+ configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1);
+ configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
+ FileSplit outputSplit = new FileSplit(tempFile);
+ writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
+ writer.write(record);
+ }
+ });
}
- @Test(expected = NumberFormatException.class)
- public void nonBinaryMultilabel() throws Exception {
- List record = Arrays.asList((Writable) new IntWritable(0),
- new IntWritable(1),
- new IntWritable(2));
- File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
- tempFile.setWritable(true);
- tempFile.deleteOnExit();
- if (tempFile.exists())
- tempFile.delete();
-
- try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
- Configuration configWriter = new Configuration();
- configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN,0);
- configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN,1);
- configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL,true);
- FileSplit outputSplit = new FileSplit(tempFile);
- writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
- writer.write(record);
- }
+ @Test
+ @DisplayName("Non Binary Multilabel")
+ void nonBinaryMultilabel() {
+ assertThrows(NumberFormatException.class, () -> {
+ List record = Arrays.asList((Writable) new IntWritable(0), new IntWritable(1), new IntWritable(2));
+ File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
+ tempFile.setWritable(true);
+ tempFile.deleteOnExit();
+ if (tempFile.exists())
+ tempFile.delete();
+ try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
+ Configuration configWriter = new Configuration();
+ configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
+ configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1);
+ configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
+ FileSplit outputSplit = new FileSplit(tempFile);
+ writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
+ writer.write(record);
+ }
+ });
}
}
diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/SVMLightRecordWriterTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/SVMLightRecordWriterTest.java
index 8efb2a539..56a130465 100644
--- a/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/SVMLightRecordWriterTest.java
+++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/SVMLightRecordWriterTest.java
@@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
-
package org.datavec.api.records.writer.impl;
import org.apache.commons.io.FileUtils;
@@ -27,93 +26,90 @@ import org.datavec.api.records.writer.impl.misc.SVMLightRecordWriter;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
import org.datavec.api.writable.*;
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.io.ClassPathResource;
-
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.extension.ExtendWith;
+import static org.junit.jupiter.api.Assertions.assertThrows;
-import static org.junit.Assert.assertEquals;
-
-public class SVMLightRecordWriterTest extends BaseND4JTest {
+@DisplayName("Svm Light Record Writer Test")
+class SVMLightRecordWriterTest extends BaseND4JTest {
@Test
- public void testBasic() throws Exception {
+ @DisplayName("Test Basic")
+ void testBasic() throws Exception {
Configuration configWriter = new Configuration();
-
Configuration configReader = new Configuration();
configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
-
File inputFile = new ClassPathResource("datavec-api/svmlight/basic.txt").getFile();
executeTest(configWriter, configReader, inputFile);
}
@Test
- public void testNoLabel() throws Exception {
+ @DisplayName("Test No Label")
+ void testNoLabel() throws Exception {
Configuration configWriter = new Configuration();
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 9);
-
Configuration configReader = new Configuration();
configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
-
File inputFile = new ClassPathResource("datavec-api/svmlight/noLabels.txt").getFile();
executeTest(configWriter, configReader, inputFile);
}
@Test
- public void testMultioutputRecord() throws Exception {
+ @DisplayName("Test Multioutput Record")
+ void testMultioutputRecord() throws Exception {
Configuration configWriter = new Configuration();
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 9);
-
Configuration configReader = new Configuration();
configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
-
File inputFile = new ClassPathResource("datavec-api/svmlight/multioutput.txt").getFile();
executeTest(configWriter, configReader, inputFile);
}
@Test
- public void testMultilabelRecord() throws Exception {
+ @DisplayName("Test Multilabel Record")
+ void testMultilabelRecord() throws Exception {
Configuration configWriter = new Configuration();
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 9);
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
-
Configuration configReader = new Configuration();
configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
configReader.setBoolean(SVMLightRecordReader.MULTILABEL, true);
configReader.setInt(SVMLightRecordReader.NUM_LABELS, 4);
configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
-
File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile();
executeTest(configWriter, configReader, inputFile);
}
@Test
- public void testZeroBasedIndexing() throws Exception {
+ @DisplayName("Test Zero Based Indexing")
+ void testZeroBasedIndexing() throws Exception {
Configuration configWriter = new Configuration();
configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_INDEXING, true);
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 10);
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
-
Configuration configReader = new Configuration();
configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 11);
configReader.setBoolean(SVMLightRecordReader.MULTILABEL, true);
configReader.setInt(SVMLightRecordReader.NUM_LABELS, 5);
-
File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile();
executeTest(configWriter, configReader, inputFile);
}
@@ -124,10 +120,9 @@ public class SVMLightRecordWriterTest extends BaseND4JTest {
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
-
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
FileSplit outputSplit = new FileSplit(tempFile);
- writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
+ writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
SVMLightRecordReader rr = new SVMLightRecordReader();
rr.initialize(configReader, new FileSplit(inputFile));
while (rr.hasNext()) {
@@ -135,7 +130,6 @@ public class SVMLightRecordWriterTest extends BaseND4JTest {
writer.write(record);
}
}
-
Pattern p = Pattern.compile(String.format("%s:\\d+ ", SVMLightRecordReader.QID_PREFIX));
List linesOriginal = new ArrayList<>();
for (String line : FileUtils.readLines(inputFile)) {
@@ -156,7 +150,8 @@ public class SVMLightRecordWriterTest extends BaseND4JTest {
}
@Test
- public void testNDArrayWritables() throws Exception {
+ @DisplayName("Test ND Array Writables")
+ void testNDArrayWritables() throws Exception {
INDArray arr2 = Nd4j.zeros(2);
arr2.putScalar(0, 11);
arr2.putScalar(1, 12);
@@ -164,35 +159,28 @@ public class SVMLightRecordWriterTest extends BaseND4JTest {
arr3.putScalar(0, 13);
arr3.putScalar(1, 14);
arr3.putScalar(2, 15);
- List record = Arrays.asList((Writable) new DoubleWritable(1),
- new NDArrayWritable(arr2),
- new IntWritable(2),
- new DoubleWritable(3),
- new NDArrayWritable(arr3),
- new IntWritable(4));
+ List record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new IntWritable(4));
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
tempFile.setWritable(true);
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
-
String lineOriginal = "13.0,14.0,15.0,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0";
-
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
Configuration configWriter = new Configuration();
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 3);
FileSplit outputSplit = new FileSplit(tempFile);
- writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
+ writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record);
}
-
String lineNew = FileUtils.readFileToString(tempFile).trim();
assertEquals(lineOriginal, lineNew);
}
@Test
- public void testNDArrayWritablesMultilabel() throws Exception {
+ @DisplayName("Test ND Array Writables Multilabel")
+ void testNDArrayWritablesMultilabel() throws Exception {
INDArray arr2 = Nd4j.zeros(2);
arr2.putScalar(0, 11);
arr2.putScalar(1, 12);
@@ -200,36 +188,29 @@ public class SVMLightRecordWriterTest extends BaseND4JTest {
arr3.putScalar(0, 0);
arr3.putScalar(1, 1);
arr3.putScalar(2, 0);
- List record = Arrays.asList((Writable) new DoubleWritable(1),
- new NDArrayWritable(arr2),
- new IntWritable(2),
- new DoubleWritable(3),
- new NDArrayWritable(arr3),
- new DoubleWritable(1));
+ List record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new DoubleWritable(1));
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
tempFile.setWritable(true);
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
-
String lineOriginal = "2,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0";
-
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
Configuration configWriter = new Configuration();
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 3);
FileSplit outputSplit = new FileSplit(tempFile);
- writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
+ writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record);
}
-
String lineNew = FileUtils.readFileToString(tempFile).trim();
assertEquals(lineOriginal, lineNew);
}
@Test
- public void testNDArrayWritablesZeroIndex() throws Exception {
+ @DisplayName("Test ND Array Writables Zero Index")
+ void testNDArrayWritablesZeroIndex() throws Exception {
INDArray arr2 = Nd4j.zeros(2);
arr2.putScalar(0, 11);
arr2.putScalar(1, 12);
@@ -237,99 +218,91 @@ public class SVMLightRecordWriterTest extends BaseND4JTest {
arr3.putScalar(0, 0);
arr3.putScalar(1, 1);
arr3.putScalar(2, 0);
- List record = Arrays.asList((Writable) new DoubleWritable(1),
- new NDArrayWritable(arr2),
- new IntWritable(2),
- new DoubleWritable(3),
- new NDArrayWritable(arr3),
- new DoubleWritable(1));
+ List record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new DoubleWritable(1));
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
tempFile.setWritable(true);
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
-
String lineOriginal = "1,3 0:1.0 1:11.0 2:12.0 3:2.0 4:3.0";
-
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
Configuration configWriter = new Configuration();
- configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_INDEXING, true); // NOT STANDARD!
- configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_LABEL_INDEXING, true); // NOT STANDARD!
+ // NOT STANDARD!
+ configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_INDEXING, true);
+ // NOT STANDARD!
+ configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_LABEL_INDEXING, true);
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 3);
FileSplit outputSplit = new FileSplit(tempFile);
- writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
+ writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record);
}
-
String lineNew = FileUtils.readFileToString(tempFile).trim();
assertEquals(lineOriginal, lineNew);
}
@Test
- public void testNonIntegerButValidMultilabel() throws Exception {
- List record = Arrays.asList((Writable) new IntWritable(3),
- new IntWritable(2),
- new DoubleWritable(1.0));
+ @DisplayName("Test Non Integer But Valid Multilabel")
+ void testNonIntegerButValidMultilabel() throws Exception {
+ List record = Arrays.asList((Writable) new IntWritable(3), new IntWritable(2), new DoubleWritable(1.0));
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
tempFile.setWritable(true);
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
-
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
Configuration configWriter = new Configuration();
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 1);
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
FileSplit outputSplit = new FileSplit(tempFile);
- writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
+ writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record);
}
}
- @Test(expected = NumberFormatException.class)
- public void nonIntegerMultilabel() throws Exception {
- List record = Arrays.asList((Writable) new IntWritable(3),
- new IntWritable(2),
- new DoubleWritable(1.2));
- File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
- tempFile.setWritable(true);
- tempFile.deleteOnExit();
- if (tempFile.exists())
- tempFile.delete();
-
- try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
- Configuration configWriter = new Configuration();
- configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
- configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 1);
- configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
- FileSplit outputSplit = new FileSplit(tempFile);
- writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
- writer.write(record);
- }
+ @Test
+ @DisplayName("Non Integer Multilabel")
+ void nonIntegerMultilabel() {
+ assertThrows(NumberFormatException.class, () -> {
+ List record = Arrays.asList((Writable) new IntWritable(3), new IntWritable(2), new DoubleWritable(1.2));
+ File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
+ tempFile.setWritable(true);
+ tempFile.deleteOnExit();
+ if (tempFile.exists())
+ tempFile.delete();
+ try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
+ Configuration configWriter = new Configuration();
+ configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
+ configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 1);
+ configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
+ FileSplit outputSplit = new FileSplit(tempFile);
+ writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
+ writer.write(record);
+ }
+ });
}
- @Test(expected = NumberFormatException.class)
- public void nonBinaryMultilabel() throws Exception {
- List record = Arrays.asList((Writable) new IntWritable(0),
- new IntWritable(1),
- new IntWritable(2));
- File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
- tempFile.setWritable(true);
- tempFile.deleteOnExit();
- if (tempFile.exists())
- tempFile.delete();
-
- try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
- Configuration configWriter = new Configuration();
- configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
- configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 1);
- configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
- FileSplit outputSplit = new FileSplit(tempFile);
- writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
- writer.write(record);
- }
+ @Test
+ @DisplayName("Non Binary Multilabel")
+ void nonBinaryMultilabel() {
+ assertThrows(NumberFormatException.class, () -> {
+ List record = Arrays.asList((Writable) new IntWritable(0), new IntWritable(1), new IntWritable(2));
+ File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
+ tempFile.setWritable(true);
+ tempFile.deleteOnExit();
+ if (tempFile.exists())
+ tempFile.delete();
+ try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
+ Configuration configWriter = new Configuration();
+ configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
+ configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 1);
+ configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
+ FileSplit outputSplit = new FileSplit(tempFile);
+ writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
+ writer.write(record);
+ }
+ });
}
}
diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/split/TransformSplitTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/split/TransformSplitTest.java
index 79c799fd5..253eb98f4 100644
--- a/datavec/datavec-api/src/test/java/org/datavec/api/split/TransformSplitTest.java
+++ b/datavec/datavec-api/src/test/java/org/datavec/api/split/TransformSplitTest.java
@@ -17,44 +17,43 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
-
package org.datavec.api.split;
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
-
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Collection;
-
import static java.util.Arrays.asList;
-import static org.junit.Assert.assertArrayEquals;
+import static org.junit.jupiter.api.Assertions.assertArrayEquals;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.extension.ExtendWith;
/**
* @author Ede Meijer
*/
-public class TransformSplitTest extends BaseND4JTest {
- @Test
- public void testTransform() throws URISyntaxException {
- Collection inputFiles = asList(new URI("file:///foo/bar/../0.csv"), new URI("file:///foo/1.csv"));
+@DisplayName("Transform Split Test")
+class TransformSplitTest extends BaseND4JTest {
+ @Test
+ @DisplayName("Test Transform")
+ void testTransform() throws URISyntaxException {
+ Collection inputFiles = asList(new URI("file:///foo/bar/../0.csv"), new URI("file:///foo/1.csv"));
InputSplit SUT = new TransformSplit(new CollectionInputSplit(inputFiles), new TransformSplit.URITransform() {
+
@Override
public URI apply(URI uri) throws URISyntaxException {
return uri.normalize();
}
});
-
- assertArrayEquals(new URI[] {new URI("file:///foo/0.csv"), new URI("file:///foo/1.csv")}, SUT.locations());
+ assertArrayEquals(new URI[] { new URI("file:///foo/0.csv"), new URI("file:///foo/1.csv") }, SUT.locations());
}
@Test
- public void testSearchReplace() throws URISyntaxException {
+ @DisplayName("Test Search Replace")
+ void testSearchReplace() throws URISyntaxException {
Collection inputFiles = asList(new URI("file:///foo/1-in.csv"), new URI("file:///foo/2-in.csv"));
-
InputSplit SUT = TransformSplit.ofSearchReplace(new CollectionInputSplit(inputFiles), "-in.csv", "-out.csv");
-
- assertArrayEquals(new URI[] {new URI("file:///foo/1-out.csv"), new URI("file:///foo/2-out.csv")},
- SUT.locations());
+ assertArrayEquals(new URI[] { new URI("file:///foo/1-out.csv"), new URI("file:///foo/2-out.csv") }, SUT.locations());
}
}
diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpArchTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpArchTest.java
index e67722f78..42351fd9a 100644
--- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpArchTest.java
+++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpArchTest.java
@@ -17,32 +17,25 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
-
package org.datavec.api.transform.ops;
import com.tngtech.archunit.core.importer.ImportOption;
import com.tngtech.archunit.junit.AnalyzeClasses;
import com.tngtech.archunit.junit.ArchTest;
-import com.tngtech.archunit.junit.ArchUnitRunner;
import com.tngtech.archunit.lang.ArchRule;
+import com.tngtech.archunit.lang.extension.ArchUnitExtension;
+import com.tngtech.archunit.lang.extension.ArchUnitExtensions;
import org.junit.runner.RunWith;
import org.nd4j.common.tests.BaseND4JTest;
-
import java.io.Serializable;
-
import static com.tngtech.archunit.lang.syntax.ArchRuleDefinition.classes;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.extension.ExtendWith;
-@RunWith(ArchUnitRunner.class)
-@AnalyzeClasses(packages = "org.datavec.api.transform.ops", importOptions = {ImportOption.DoNotIncludeTests.class})
-public class AggregableMultiOpArchTest extends BaseND4JTest {
+@AnalyzeClasses(packages = "org.datavec.api.transform.ops", importOptions = { ImportOption.DoNotIncludeTests.class })
+@DisplayName("Aggregable Multi Op Arch Test")
+class AggregableMultiOpArchTest extends BaseND4JTest {
@ArchTest
- public static final ArchRule ALL_AGGREGATE_OPS_MUST_BE_SERIALIZABLE = classes()
- .that().resideInAPackage("org.datavec.api.transform.ops")
- .and().doNotHaveSimpleName("AggregatorImpls")
- .and().doNotHaveSimpleName("IAggregableReduceOp")
- .and().doNotHaveSimpleName("StringAggregatorImpls")
- .and().doNotHaveFullyQualifiedName("org.datavec.api.transform.ops.StringAggregatorImpls$1")
- .should().implement(Serializable.class)
- .because("All aggregate ops must be serializable.");
-}
\ No newline at end of file
+ public static final ArchRule ALL_AGGREGATE_OPS_MUST_BE_SERIALIZABLE = classes().that().resideInAPackage("org.datavec.api.transform.ops").and().doNotHaveSimpleName("AggregatorImpls").and().doNotHaveSimpleName("IAggregableReduceOp").and().doNotHaveSimpleName("StringAggregatorImpls").and().doNotHaveFullyQualifiedName("org.datavec.api.transform.ops.StringAggregatorImpls$1").should().implement(Serializable.class).because("All aggregate ops must be serializable.");
+}
diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpTest.java
index cb4fdeb04..acd2971ac 100644
--- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpTest.java
+++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpTest.java
@@ -17,52 +17,46 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
-
package org.datavec.api.transform.ops;
import org.datavec.api.writable.Writable;
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
-
import java.util.*;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.extension.ExtendWith;
-import static org.junit.Assert.assertTrue;
-
-public class AggregableMultiOpTest extends BaseND4JTest {
+@DisplayName("Aggregable Multi Op Test")
+class AggregableMultiOpTest extends BaseND4JTest {
private List intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
@Test
- public void testMulti() throws Exception {
+ @DisplayName("Test Multi")
+ void testMulti() throws Exception {
AggregatorImpls.AggregableFirst af = new AggregatorImpls.AggregableFirst<>();
AggregatorImpls.AggregableSum as = new AggregatorImpls.AggregableSum<>();
AggregableMultiOp multi = new AggregableMultiOp<>(Arrays.asList(af, as));
-
assertTrue(multi.getOperations().size() == 2);
for (int i = 0; i < intList.size(); i++) {
multi.accept(intList.get(i));
}
-
// mutablility
assertTrue(as.get().toDouble() == 45D);
assertTrue(af.get().toInt() == 1);
-
List res = multi.get();
assertTrue(res.get(1).toDouble() == 45D);
assertTrue(res.get(0).toInt() == 1);
-
AggregatorImpls.AggregableFirst rf = new AggregatorImpls.AggregableFirst<>();
AggregatorImpls.AggregableSum rs = new AggregatorImpls.AggregableSum<>();
AggregableMultiOp reverse = new AggregableMultiOp<>(Arrays.asList(rf, rs));
-
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
}
-
List revRes = reverse.get();
assertTrue(revRes.get(1).toDouble() == 45D);
assertTrue(revRes.get(0).toInt() == 9);
-
multi.combine(reverse);
List combinedRes = multi.get();
assertTrue(combinedRes.get(1).toDouble() == 90D);
diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregatorImplsTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregatorImplsTest.java
index 47da27bdc..e7c8de557 100644
--- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregatorImplsTest.java
+++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregatorImplsTest.java
@@ -17,41 +17,39 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
-
package org.datavec.api.transform.ops;
import org.junit.Rule;
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
import org.junit.rules.ExpectedException;
import org.nd4j.common.tests.BaseND4JTest;
-
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import org.junit.jupiter.api.DisplayName;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertTrue;
-
-public class AggregatorImplsTest extends BaseND4JTest {
+@DisplayName("Aggregator Impls Test")
+class AggregatorImplsTest extends BaseND4JTest {
private List intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
+
private List stringList = new ArrayList<>(Arrays.asList("arakoa", "abracadabra", "blast", "acceptance"));
@Test
- public void aggregableFirstTest() {
+ @DisplayName("Aggregable First Test")
+ void aggregableFirstTest() {
AggregatorImpls.AggregableFirst first = new AggregatorImpls.AggregableFirst<>();
for (int i = 0; i < intList.size(); i++) {
first.accept(intList.get(i));
}
assertEquals(1, first.get().toInt());
-
AggregatorImpls.AggregableFirst firstS = new AggregatorImpls.AggregableFirst<>();
for (int i = 0; i < stringList.size(); i++) {
firstS.accept(stringList.get(i));
}
assertTrue(firstS.get().toString().equals("arakoa"));
-
-
AggregatorImpls.AggregableFirst reverse = new AggregatorImpls.AggregableFirst<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
@@ -60,22 +58,19 @@ public class AggregatorImplsTest extends BaseND4JTest {
assertEquals(1, first.get().toInt());
}
-
@Test
- public void aggregableLastTest() {
+ @DisplayName("Aggregable Last Test")
+ void aggregableLastTest() {
AggregatorImpls.AggregableLast last = new AggregatorImpls.AggregableLast<>();
for (int i = 0; i < intList.size(); i++) {
last.accept(intList.get(i));
}
assertEquals(9, last.get().toInt());
-
AggregatorImpls.AggregableLast lastS = new AggregatorImpls.AggregableLast<>();
for (int i = 0; i < stringList.size(); i++) {
lastS.accept(stringList.get(i));
}
assertTrue(lastS.get().toString().equals("acceptance"));
-
-
AggregatorImpls.AggregableLast reverse = new AggregatorImpls.AggregableLast<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
@@ -85,20 +80,18 @@ public class AggregatorImplsTest extends BaseND4JTest {
}
@Test
- public void aggregableCountTest() {
+ @DisplayName("Aggregable Count Test")
+ void aggregableCountTest() {
AggregatorImpls.AggregableCount cnt = new AggregatorImpls.AggregableCount<>();
for (int i = 0; i < intList.size(); i++) {
cnt.accept(intList.get(i));
}
assertEquals(9, cnt.get().toInt());
-
AggregatorImpls.AggregableCount lastS = new AggregatorImpls.AggregableCount<>();
for (int i = 0; i < stringList.size(); i++) {
lastS.accept(stringList.get(i));
}
assertEquals(4, lastS.get().toInt());
-
-
AggregatorImpls.AggregableCount reverse = new AggregatorImpls.AggregableCount<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
@@ -108,14 +101,13 @@ public class AggregatorImplsTest extends BaseND4JTest {
}
@Test
- public void aggregableMaxTest() {
+ @DisplayName("Aggregable Max Test")
+ void aggregableMaxTest() {
AggregatorImpls.AggregableMax mx = new AggregatorImpls.AggregableMax<>();
for (int i = 0; i < intList.size(); i++) {
mx.accept(intList.get(i));
}
assertEquals(9, mx.get().toInt());
-
-
AggregatorImpls.AggregableMax reverse = new AggregatorImpls.AggregableMax<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
@@ -124,16 +116,14 @@ public class AggregatorImplsTest extends BaseND4JTest {
assertEquals(9, mx.get().toInt());
}
-
@Test
- public void aggregableRangeTest() {
+ @DisplayName("Aggregable Range Test")
+ void aggregableRangeTest() {
AggregatorImpls.AggregableRange mx = new AggregatorImpls.AggregableRange<>();
for (int i = 0; i < intList.size(); i++) {
mx.accept(intList.get(i));
}
assertEquals(8, mx.get().toInt());
-
-
AggregatorImpls.AggregableRange reverse = new AggregatorImpls.AggregableRange<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1) + 9);
@@ -143,14 +133,13 @@ public class AggregatorImplsTest extends BaseND4JTest {
}
@Test
- public void aggregableMinTest() {
+ @DisplayName("Aggregable Min Test")
+ void aggregableMinTest() {
AggregatorImpls.AggregableMin mn = new AggregatorImpls.AggregableMin<>();
for (int i = 0; i < intList.size(); i++) {
mn.accept(intList.get(i));
}
assertEquals(1, mn.get().toInt());
-
-
AggregatorImpls.AggregableMin reverse = new AggregatorImpls.AggregableMin<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
@@ -160,14 +149,13 @@ public class AggregatorImplsTest extends BaseND4JTest {
}
@Test
- public void aggregableSumTest() {
+ @DisplayName("Aggregable Sum Test")
+ void aggregableSumTest() {
AggregatorImpls.AggregableSum sm = new AggregatorImpls.AggregableSum<>();
for (int i = 0; i < intList.size(); i++) {
sm.accept(intList.get(i));
}
assertEquals(45, sm.get().toInt());
-
-
AggregatorImpls.AggregableSum reverse = new AggregatorImpls.AggregableSum<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
@@ -176,17 +164,15 @@ public class AggregatorImplsTest extends BaseND4JTest {
assertEquals(90, sm.get().toInt());
}
-
@Test
- public void aggregableMeanTest() {
+ @DisplayName("Aggregable Mean Test")
+ void aggregableMeanTest() {
AggregatorImpls.AggregableMean mn = new AggregatorImpls.AggregableMean<>();
for (int i = 0; i < intList.size(); i++) {
mn.accept(intList.get(i));
}
assertEquals(9l, (long) mn.getCount());
assertEquals(5D, mn.get().toDouble(), 0.001);
-
-
AggregatorImpls.AggregableMean reverse = new AggregatorImpls.AggregableMean<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
@@ -197,80 +183,73 @@ public class AggregatorImplsTest extends BaseND4JTest {
}
@Test
- public void aggregableStdDevTest() {
+ @DisplayName("Aggregable Std Dev Test")
+ void aggregableStdDevTest() {
AggregatorImpls.AggregableStdDev sd = new AggregatorImpls.AggregableStdDev<>();
for (int i = 0; i < intList.size(); i++) {
sd.accept(intList.get(i));
}
assertTrue(Math.abs(sd.get().toDouble() - 2.7386) < 0.0001);
-
-
AggregatorImpls.AggregableStdDev reverse = new AggregatorImpls.AggregableStdDev<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
}
sd.combine(reverse);
- assertTrue("" + sd.get().toDouble(), Math.abs(sd.get().toDouble() - 1.8787) < 0.0001);
+ assertTrue(Math.abs(sd.get().toDouble() - 1.8787) < 0.0001,"" + sd.get().toDouble());
}
@Test
- public void aggregableVariance() {
+ @DisplayName("Aggregable Variance")
+ void aggregableVariance() {
AggregatorImpls.AggregableVariance sd = new AggregatorImpls.AggregableVariance<>();
for (int i = 0; i < intList.size(); i++) {
sd.accept(intList.get(i));
}
assertTrue(Math.abs(sd.get().toDouble() - 60D / 8) < 0.0001);
-
-
AggregatorImpls.AggregableVariance reverse = new AggregatorImpls.AggregableVariance<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
}
sd.combine(reverse);
- assertTrue("" + sd.get().toDouble(), Math.abs(sd.get().toDouble() - 3.5294) < 0.0001);
+ assertTrue(Math.abs(sd.get().toDouble() - 3.5294) < 0.0001,"" + sd.get().toDouble());
}
@Test
- public void aggregableUncorrectedStdDevTest() {
+ @DisplayName("Aggregable Uncorrected Std Dev Test")
+ void aggregableUncorrectedStdDevTest() {
AggregatorImpls.AggregableUncorrectedStdDev sd = new AggregatorImpls.AggregableUncorrectedStdDev<>();
for (int i = 0; i < intList.size(); i++) {
sd.accept(intList.get(i));
}
assertTrue(Math.abs(sd.get().toDouble() - 2.582) < 0.0001);
-
-
- AggregatorImpls.AggregableUncorrectedStdDev reverse =
- new AggregatorImpls.AggregableUncorrectedStdDev<>();
+ AggregatorImpls.AggregableUncorrectedStdDev reverse = new AggregatorImpls.AggregableUncorrectedStdDev<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
}
sd.combine(reverse);
- assertTrue("" + sd.get().toDouble(), Math.abs(sd.get().toDouble() - 1.8257) < 0.0001);
+ assertTrue(Math.abs(sd.get().toDouble() - 1.8257) < 0.0001,"" + sd.get().toDouble());
}
-
@Test
- public void aggregablePopulationVariance() {
+ @DisplayName("Aggregable Population Variance")
+ void aggregablePopulationVariance() {
AggregatorImpls.AggregablePopulationVariance sd = new AggregatorImpls.AggregablePopulationVariance<>();
for (int i = 0; i < intList.size(); i++) {
sd.accept(intList.get(i));
}
assertTrue(Math.abs(sd.get().toDouble() - 60D / 9) < 0.0001);
-
-
- AggregatorImpls.AggregablePopulationVariance reverse =
- new AggregatorImpls.AggregablePopulationVariance<>();
+ AggregatorImpls.AggregablePopulationVariance reverse = new AggregatorImpls.AggregablePopulationVariance<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
}
sd.combine(reverse);
- assertTrue("" + sd.get().toDouble(), Math.abs(sd.get().toDouble() - 30D / 9) < 0.0001);
+ assertTrue(Math.abs(sd.get().toDouble() - 30D / 9) < 0.0001,"" + sd.get().toDouble());
}
@Test
- public void aggregableCountUniqueTest() {
+ @DisplayName("Aggregable Count Unique Test")
+ void aggregableCountUniqueTest() {
// at this low range, it's linear counting
-
AggregatorImpls.AggregableCountUnique cu = new AggregatorImpls.AggregableCountUnique<>();
for (int i = 0; i < intList.size(); i++) {
cu.accept(intList.get(i));
@@ -278,7 +257,6 @@ public class AggregatorImplsTest extends BaseND4JTest {
assertEquals(9, cu.get().toInt());
cu.accept(1);
assertEquals(9, cu.get().toInt());
-
AggregatorImpls.AggregableCountUnique reverse = new AggregatorImpls.AggregableCountUnique<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
@@ -290,16 +268,14 @@ public class AggregatorImplsTest extends BaseND4JTest {
@Rule
public final ExpectedException exception = ExpectedException.none();
-
@Test
- public void incompatibleAggregatorTest() {
+ @DisplayName("Incompatible Aggregator Test")
+ void incompatibleAggregatorTest() {
AggregatorImpls.AggregableSum sm = new AggregatorImpls.AggregableSum<>();
for (int i = 0; i < intList.size(); i++) {
sm.accept(intList.get(i));
}
assertEquals(45, sm.get().toInt());
-
-
AggregatorImpls.AggregableMean reverse = new AggregatorImpls.AggregableMean<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
@@ -308,5 +284,4 @@ public class AggregatorImplsTest extends BaseND4JTest {
sm.combine(reverse);
assertEquals(45, sm.get().toInt());
}
-
}
diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/DispatchOpTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/DispatchOpTest.java
index 098c5635a..6a444923d 100644
--- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/DispatchOpTest.java
+++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/DispatchOpTest.java
@@ -17,77 +17,65 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
-
package org.datavec.api.transform.ops;
import org.datavec.api.writable.Writable;
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
-
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import org.junit.jupiter.api.DisplayName;
+import org.junit.jupiter.api.extension.ExtendWith;
-import static org.junit.Assert.assertTrue;
-
-public class DispatchOpTest extends BaseND4JTest {
+@DisplayName("Dispatch Op Test")
+class DispatchOpTest extends BaseND4JTest {
private List intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
+
private List stringList = new ArrayList<>(Arrays.asList("arakoa", "abracadabra", "blast", "acceptance"));
@Test
- public void testDispatchSimple() {
+ @DisplayName("Test Dispatch Simple")
+ void testDispatchSimple() {
AggregatorImpls.AggregableFirst af = new AggregatorImpls.AggregableFirst<>();
AggregatorImpls.AggregableSum as = new AggregatorImpls.AggregableSum<>();
- AggregableMultiOp multiaf =
- new AggregableMultiOp<>(Collections.>singletonList(af));
- AggregableMultiOp multias =
- new AggregableMultiOp<>(Collections.>singletonList(as));
-
- DispatchOp parallel =
- new DispatchOp<>(Arrays.>>asList(multiaf, multias));
-
+ AggregableMultiOp multiaf = new AggregableMultiOp<>(Collections.>singletonList(af));
+ AggregableMultiOp multias = new AggregableMultiOp<>(Collections.>singletonList(as));
+ DispatchOp parallel = new DispatchOp<>(Arrays.>>asList(multiaf, multias));
assertTrue(multiaf.getOperations().size() == 1);
assertTrue(multias.getOperations().size() == 1);
assertTrue(parallel.getOperations().size() == 2);
for (int i = 0; i < intList.size(); i++) {
parallel.accept(Arrays.asList(intList.get(i), intList.get(i)));
}
-
List