Merge pull request #9233 from eclipse/ag_junit5

Upgrade dl4j to junit 5
master
Adam Gibson 2021-03-19 21:26:43 +09:00 committed by GitHub
commit c505a11ed6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1022 changed files with 28893 additions and 38578 deletions

View File

@ -31,7 +31,7 @@ jobs:
protoc --version
cd dl4j-test-resources-master && mvn clean install -DskipTests && cd ..
export OMP_NUM_THREADS=1
mvn -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
mvn -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
windows-x86_64:
runs-on: windows-2019
@ -44,7 +44,7 @@ jobs:
run: |
set "PATH=C:\msys64\usr\bin;%PATH%"
export OMP_NUM_THREADS=1
mvn -DskipTestResourceEnforcement=true -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
mvn -DskipTestResourceEnforcement=true -Pintegration-tests -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
@ -60,5 +60,5 @@ jobs:
run: |
brew install unzip ccache gcc swig autoconf-archive automake cmake libomp libtool libusb ant maven nasm xz pkg-config sdl gpg1 bison flex perl ragel binutils gradle gmp isl libmpc mpfr wget python
export OMP_NUM_THREADS=1
mvn -Pintegration-tests -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
mvn -Pintegration-tests -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test

View File

@ -31,7 +31,7 @@ jobs:
protoc --version
cd dl4j-test-resources-master && mvn clean install -DskipTests && cd ..
export OMP_NUM_THREADS=1
mvn -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.buildthreads=1 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
mvn -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.buildthreads=1 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
windows-x86_64:
runs-on: windows-2019
@ -44,7 +44,7 @@ jobs:
run: |
set "PATH=C:\msys64\usr\bin;%PATH%"
export OMP_NUM_THREADS=1
mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -DskipTestResourceEnforcement=true -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -DskipTestResourceEnforcement=true -Ptestresources -Dlibnd4j.buildthreads=1 -Dlibnd4j.build="Debug" -Djavacpp.platform=windows-x86_64 -libnd4j.platform=windows-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test
@ -60,5 +60,5 @@ jobs:
run: |
brew install unzip ccache gcc swig autoconf-archive automake cmake libomp libtool libusb ant maven nasm xz pkg-config sdl gpg1 bison flex perl ragel binutils gradle gmp isl libmpc mpfr wget python
export OMP_NUM_THREADS=1
mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Ptest-nd4j-native -Dlibnd4j.chip=cpu clean test
mvn -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Dlibnd4j.build="Debug" -Dlibnd4j.buildthreads=1 -Ptestresources -Djavacpp.platform=macosx-x86_64 -libnd4j.platform=macosx-x86_64 -Pnd4j-tests-cpu -Dlibnd4j.chip=cpu clean test

View File

@ -1,11 +1,3 @@
on:
workflow_dispatch:
jobs:
# Wait for up to a minute for previous run to complete, abort if not done by then
pre-ci:
run
on:
workflow_dispatch:
jobs:
@ -42,5 +34,5 @@ jobs:
cmake --version
protoc --version
export OMP_NUM_THREADS=1
mvn -DskipTestResourceEnforcement=true -Ptestresources -Pintegration-tests -Pdl4j-integration-tests -Pnd4j-tests-cpu clean test
mvn -DskipTestResourceEnforcement=true -Ptestresources -Pintegration-tests -Pnd4j-tests-cpu clean test -rf :rl4j-core

View File

@ -34,5 +34,5 @@ jobs:
cmake --version
protoc --version
export OMP_NUM_THREADS=1
mvn -DskipTestResourceEnforcement=true -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Ptest-nd4j-native --also-make clean test
mvn -DskipTestResourceEnforcement=true -Ptestresources -pl ":deeplearning4j-modelimport,:deeplearning4j-core,:nd4j-native,:samediff-import,:libnd4j" -Pnd4j-tests-cpu --also-make clean test

View File

@ -15,7 +15,7 @@
<commons.dbutils.version>1.7</commons.dbutils.version>
<lombok.version>1.18.8</lombok.version>
<logback.version>1.1.7</logback.version>
<junit.version>4.12</junit.version>
<junit.version>5.8.0-M1</junit.version>
<junit-jupiter.version>5.4.2</junit-jupiter.version>
<java.version>1.8</java.version>
<maven-shade-plugin.version>3.1.1</maven-shade-plugin.version>

View File

@ -17,13 +17,14 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.nd4j.codegen.ir;
public class SerializationTest {
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
public static void main(String...args) {
@DisplayName("Serialization Test")
class SerializationTest {
public static void main(String... args) {
}
}

View File

@ -17,29 +17,23 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.nd4j.codegen.dsl;
import org.apache.commons.lang3.StringUtils;
import org.junit.jupiter.api.Test;
import org.nd4j.codegen.impl.java.DocsGenerator;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
public class DocsGeneratorTest {
@DisplayName("Docs Generator Test")
class DocsGeneratorTest {
@Test
public void testJDtoMDAdapter() {
String original = "{@code %INPUT_TYPE% eye = eye(3,2)\n" +
" eye:\n" +
" [ 1, 0]\n" +
" [ 0, 1]\n" +
" [ 0, 0]}";
String expected = "{ INDArray eye = eye(3,2)\n" +
" eye:\n" +
" [ 1, 0]\n" +
" [ 0, 1]\n" +
" [ 0, 0]}";
@DisplayName("Test J Dto MD Adapter")
void testJDtoMDAdapter() {
String original = "{@code %INPUT_TYPE% eye = eye(3,2)\n" + " eye:\n" + " [ 1, 0]\n" + " [ 0, 1]\n" + " [ 0, 0]}";
String expected = "{ INDArray eye = eye(3,2)\n" + " eye:\n" + " [ 1, 0]\n" + " [ 0, 1]\n" + " [ 0, 0]}";
DocsGenerator.JavaDocToMDAdapter adapter = new DocsGenerator.JavaDocToMDAdapter(original);
String out = adapter.filter("@code", StringUtils.EMPTY).filter("%INPUT_TYPE%", "INDArray").toString();
assertEquals(out, expected);

View File

@ -34,6 +34,14 @@
<artifactId>datavec-api</artifactId>
<dependencies>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-api</artifactId>
</dependency>
<dependency>
<groupId>org.junit.vintage</groupId>
<artifactId>junit-vintage-engine</artifactId>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
@ -101,10 +109,10 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
</profile>
</profiles>
</project>

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.records.reader.impl;
import org.apache.commons.io.FileUtils;
@ -26,47 +25,38 @@ import org.datavec.api.records.reader.impl.csv.CSVLineSequenceRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.common.tests.BaseND4JTest;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
import java.io.File;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals;
@DisplayName("Csv Line Sequence Record Reader Test")
class CSVLineSequenceRecordReaderTest extends BaseND4JTest {
public class CSVLineSequenceRecordReaderTest extends BaseND4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@TempDir
public Path testDir;
@Test
public void test() throws Exception {
File f = testDir.newFolder();
@DisplayName("Test")
void test(@TempDir Path testDir) throws Exception {
File f = testDir.toFile();
File source = new File(f, "temp.csv");
String str = "a,b,c\n1,2,3,4";
FileUtils.writeStringToFile(source, str, StandardCharsets.UTF_8);
SequenceRecordReader rr = new CSVLineSequenceRecordReader();
rr.initialize(new FileSplit(source));
List<List<Writable>> exp0 = Arrays.asList(
Collections.<Writable>singletonList(new Text("a")),
Collections.<Writable>singletonList(new Text("b")),
Collections.<Writable>singletonList(new Text("c")));
List<List<Writable>> exp1 = Arrays.asList(
Collections.<Writable>singletonList(new Text("1")),
Collections.<Writable>singletonList(new Text("2")),
Collections.<Writable>singletonList(new Text("3")),
Collections.<Writable>singletonList(new Text("4")));
for( int i=0; i<3; i++ ) {
List<List<Writable>> exp0 = Arrays.asList(Collections.<Writable>singletonList(new Text("a")), Collections.<Writable>singletonList(new Text("b")), Collections.<Writable>singletonList(new Text("c")));
List<List<Writable>> exp1 = Arrays.asList(Collections.<Writable>singletonList(new Text("1")), Collections.<Writable>singletonList(new Text("2")), Collections.<Writable>singletonList(new Text("3")), Collections.<Writable>singletonList(new Text("4")));
for (int i = 0; i < 3; i++) {
int count = 0;
while (rr.hasNext()) {
List<List<Writable>> next = rr.sequenceRecord();
@ -76,9 +66,7 @@ public class CSVLineSequenceRecordReaderTest extends BaseND4JTest {
assertEquals(exp1, next);
}
}
assertEquals(2, count);
rr.reset();
}
}

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.records.reader.impl;
import org.apache.commons.io.FileUtils;
@ -26,33 +25,37 @@ import org.datavec.api.records.reader.impl.csv.CSVMultiSequenceRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.common.tests.BaseND4JTest;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
import java.io.File;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
@DisplayName("Csv Multi Sequence Record Reader Test")
class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@TempDir
public Path testDir;
@Test
public void testConcatMode() throws Exception {
for( int i=0; i<3; i++ ) {
@DisplayName("Test Concat Mode")
@Disabled
void testConcatMode() throws Exception {
for (int i = 0; i < 3; i++) {
String seqSep;
String seqSepRegex;
switch (i){
switch(i) {
case 0:
seqSep = "";
seqSepRegex = "^$";
@ -68,31 +71,23 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
default:
throw new RuntimeException();
}
String str = "a,b,c\n1,2,3,4\nx,y\n" + seqSep + "\nA,B,C";
File f = testDir.newFile();
File f = testDir.toFile();
FileUtils.writeStringToFile(f, str, StandardCharsets.UTF_8);
SequenceRecordReader seqRR = new CSVMultiSequenceRecordReader(seqSepRegex, CSVMultiSequenceRecordReader.Mode.CONCAT);
seqRR.initialize(new FileSplit(f));
List<List<Writable>> exp0 = new ArrayList<>();
for (String s : "a,b,c,1,2,3,4,x,y".split(",")) {
exp0.add(Collections.<Writable>singletonList(new Text(s)));
}
List<List<Writable>> exp1 = new ArrayList<>();
for (String s : "A,B,C".split(",")) {
exp1.add(Collections.<Writable>singletonList(new Text(s)));
}
assertEquals(exp0, seqRR.sequenceRecord());
assertEquals(exp1, seqRR.sequenceRecord());
assertFalse(seqRR.hasNext());
seqRR.reset();
assertEquals(exp0, seqRR.sequenceRecord());
assertEquals(exp1, seqRR.sequenceRecord());
assertFalse(seqRR.hasNext());
@ -100,13 +95,13 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
}
@Test
public void testEqualLength() throws Exception {
for( int i=0; i<3; i++ ) {
@DisplayName("Test Equal Length")
@Disabled
void testEqualLength() throws Exception {
for (int i = 0; i < 3; i++) {
String seqSep;
String seqSepRegex;
switch (i) {
switch(i) {
case 0:
seqSep = "";
seqSepRegex = "^$";
@ -122,27 +117,17 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
default:
throw new RuntimeException();
}
String str = "a,b\n1,2\nx,y\n" + seqSep + "\nA\nB\nC";
File f = testDir.newFile();
File f = testDir.toFile();
FileUtils.writeStringToFile(f, str, StandardCharsets.UTF_8);
SequenceRecordReader seqRR = new CSVMultiSequenceRecordReader(seqSepRegex, CSVMultiSequenceRecordReader.Mode.EQUAL_LENGTH);
seqRR.initialize(new FileSplit(f));
List<List<Writable>> exp0 = Arrays.asList(
Arrays.<Writable>asList(new Text("a"), new Text("1"), new Text("x")),
Arrays.<Writable>asList(new Text("b"), new Text("2"), new Text("y")));
List<List<Writable>> exp0 = Arrays.asList(Arrays.<Writable>asList(new Text("a"), new Text("1"), new Text("x")), Arrays.<Writable>asList(new Text("b"), new Text("2"), new Text("y")));
List<List<Writable>> exp1 = Collections.singletonList(Arrays.<Writable>asList(new Text("A"), new Text("B"), new Text("C")));
assertEquals(exp0, seqRR.sequenceRecord());
assertEquals(exp1, seqRR.sequenceRecord());
assertFalse(seqRR.hasNext());
seqRR.reset();
assertEquals(exp0, seqRR.sequenceRecord());
assertEquals(exp1, seqRR.sequenceRecord());
assertFalse(seqRR.hasNext());
@ -150,13 +135,13 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
}
@Test
public void testPadding() throws Exception {
for( int i=0; i<3; i++ ) {
@DisplayName("Test Padding")
@Disabled
void testPadding() throws Exception {
for (int i = 0; i < 3; i++) {
String seqSep;
String seqSepRegex;
switch (i) {
switch(i) {
case 0:
seqSep = "";
seqSepRegex = "^$";
@ -172,27 +157,17 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest {
default:
throw new RuntimeException();
}
String str = "a,b\n1\nx\n" + seqSep + "\nA\nB\nC";
File f = testDir.newFile();
File f = testDir.toFile();
FileUtils.writeStringToFile(f, str, StandardCharsets.UTF_8);
SequenceRecordReader seqRR = new CSVMultiSequenceRecordReader(seqSepRegex, CSVMultiSequenceRecordReader.Mode.PAD, new Text("PAD"));
seqRR.initialize(new FileSplit(f));
List<List<Writable>> exp0 = Arrays.asList(
Arrays.<Writable>asList(new Text("a"), new Text("1"), new Text("x")),
Arrays.<Writable>asList(new Text("b"), new Text("PAD"), new Text("PAD")));
List<List<Writable>> exp0 = Arrays.asList(Arrays.<Writable>asList(new Text("a"), new Text("1"), new Text("x")), Arrays.<Writable>asList(new Text("b"), new Text("PAD"), new Text("PAD")));
List<List<Writable>> exp1 = Collections.singletonList(Arrays.<Writable>asList(new Text("A"), new Text("B"), new Text("C")));
assertEquals(exp0, seqRR.sequenceRecord());
assertEquals(exp1, seqRR.sequenceRecord());
assertFalse(seqRR.hasNext());
seqRR.reset();
assertEquals(exp0, seqRR.sequenceRecord());
assertEquals(exp1, seqRR.sequenceRecord());
assertFalse(seqRR.hasNext());

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.records.reader.impl;
import org.datavec.api.records.SequenceRecord;
@ -27,61 +26,53 @@ import org.datavec.api.records.reader.impl.csv.CSVNLinesSequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
import java.util.ArrayList;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals;
public class CSVNLinesSequenceRecordReaderTest extends BaseND4JTest {
@DisplayName("Csvn Lines Sequence Record Reader Test")
class CSVNLinesSequenceRecordReaderTest extends BaseND4JTest {
@Test
public void testCSVNLinesSequenceRecordReader() throws Exception {
@DisplayName("Test CSVN Lines Sequence Record Reader")
void testCSVNLinesSequenceRecordReader() throws Exception {
int nLinesPerSequence = 10;
SequenceRecordReader seqRR = new CSVNLinesSequenceRecordReader(nLinesPerSequence);
seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
CSVRecordReader rr = new CSVRecordReader();
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
int count = 0;
while (seqRR.hasNext()) {
List<List<Writable>> next = seqRR.sequenceRecord();
List<List<Writable>> expected = new ArrayList<>();
for (int i = 0; i < nLinesPerSequence; i++) {
expected.add(rr.next());
}
assertEquals(10, next.size());
assertEquals(expected, next);
count++;
}
assertEquals(150 / nLinesPerSequence, count);
}
@Test
public void testCSVNlinesSequenceRecordReaderMetaData() throws Exception {
@DisplayName("Test CSV Nlines Sequence Record Reader Meta Data")
void testCSVNlinesSequenceRecordReaderMetaData() throws Exception {
int nLinesPerSequence = 10;
SequenceRecordReader seqRR = new CSVNLinesSequenceRecordReader(nLinesPerSequence);
seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
CSVRecordReader rr = new CSVRecordReader();
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
List<List<List<Writable>>> out = new ArrayList<>();
while (seqRR.hasNext()) {
List<List<Writable>> next = seqRR.sequenceRecord();
out.add(next);
}
seqRR.reset();
List<List<List<Writable>>> out2 = new ArrayList<>();
List<SequenceRecord> out3 = new ArrayList<>();
@ -92,11 +83,8 @@ public class CSVNLinesSequenceRecordReaderTest extends BaseND4JTest {
meta.add(seq.getMetaData());
out3.add(seq);
}
assertEquals(out, out2);
List<SequenceRecord> out4 = seqRR.loadSequenceFromMetaData(meta);
assertEquals(out3, out4);
}
}

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.records.reader.impl;
import org.apache.commons.io.FileUtils;
@ -34,10 +33,10 @@ import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
@ -47,41 +46,44 @@ import java.util.Arrays;
import java.util.List;
import java.util.NoSuchElementException;
import static org.junit.Assert.*;
import static org.junit.jupiter.api.Assertions.*;
@DisplayName("Csv Record Reader Test")
class CSVRecordReaderTest extends BaseND4JTest {
public class CSVRecordReaderTest extends BaseND4JTest {
@Test
public void testNext() throws Exception {
@DisplayName("Test Next")
void testNext() throws Exception {
CSVRecordReader reader = new CSVRecordReader();
reader.initialize(new StringSplit("1,1,8.0,,,,14.0,,,,15.0,,,,,,,,,,,,1"));
while (reader.hasNext()) {
List<Writable> vals = reader.next();
List<Writable> arr = new ArrayList<>(vals);
assertEquals("Entry count", 23, vals.size());
assertEquals(23, vals.size(), "Entry count");
Text lastEntry = (Text) arr.get(arr.size() - 1);
assertEquals("Last entry garbage", 1, lastEntry.getLength());
assertEquals(1, lastEntry.getLength(), "Last entry garbage");
}
}
@Test
public void testEmptyEntries() throws Exception {
@DisplayName("Test Empty Entries")
void testEmptyEntries() throws Exception {
CSVRecordReader reader = new CSVRecordReader();
reader.initialize(new StringSplit("1,1,8.0,,,,14.0,,,,15.0,,,,,,,,,,,,"));
while (reader.hasNext()) {
List<Writable> vals = reader.next();
assertEquals("Entry count", 23, vals.size());
assertEquals(23, vals.size(), "Entry count");
}
}
@Test
public void testReset() throws Exception {
@DisplayName("Test Reset")
void testReset() throws Exception {
CSVRecordReader rr = new CSVRecordReader(0, ',');
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
int nResets = 5;
for (int i = 0; i < nResets; i++) {
int lineCount = 0;
while (rr.hasNext()) {
List<Writable> line = rr.next();
@ -95,7 +97,8 @@ public class CSVRecordReaderTest extends BaseND4JTest {
}
@Test
public void testResetWithSkipLines() throws Exception {
@DisplayName("Test Reset With Skip Lines")
void testResetWithSkipLines() throws Exception {
CSVRecordReader rr = new CSVRecordReader(10, ',');
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
int lineCount = 0;
@ -114,7 +117,8 @@ public class CSVRecordReaderTest extends BaseND4JTest {
}
@Test
public void testWrite() throws Exception {
@DisplayName("Test Write")
void testWrite() throws Exception {
List<List<Writable>> list = new ArrayList<>();
StringBuilder sb = new StringBuilder();
for (int i = 0; i < 10; i++) {
@ -130,81 +134,72 @@ public class CSVRecordReaderTest extends BaseND4JTest {
}
list.add(temp);
}
String expected = sb.toString();
Path p = Files.createTempFile("csvwritetest", "csv");
p.toFile().deleteOnExit();
FileRecordWriter writer = new CSVRecordWriter();
FileSplit fileSplit = new FileSplit(p.toFile());
writer.initialize(fileSplit,new NumberOfRecordsPartitioner());
writer.initialize(fileSplit, new NumberOfRecordsPartitioner());
for (List<Writable> c : list) {
writer.write(c);
}
writer.close();
//Read file back in; compare
// Read file back in; compare
String fileContents = FileUtils.readFileToString(p.toFile(), FileRecordWriter.DEFAULT_CHARSET.name());
// System.out.println(expected);
// System.out.println("----------");
// System.out.println(fileContents);
// System.out.println(expected);
// System.out.println("----------");
// System.out.println(fileContents);
assertEquals(expected, fileContents);
}
@Test
public void testTabsAsSplit1() throws Exception {
@DisplayName("Test Tabs As Split 1")
void testTabsAsSplit1() throws Exception {
CSVRecordReader reader = new CSVRecordReader(0, '\t');
reader.initialize(new FileSplit(new ClassPathResource("datavec-api/tabbed.txt").getFile()));
while (reader.hasNext()) {
List<Writable> list = new ArrayList<>(reader.next());
assertEquals(2, list.size());
}
}
@Test
public void testPipesAsSplit() throws Exception {
@DisplayName("Test Pipes As Split")
void testPipesAsSplit() throws Exception {
CSVRecordReader reader = new CSVRecordReader(0, '|');
reader.initialize(new FileSplit(new ClassPathResource("datavec-api/issue414.csv").getFile()));
int lineidx = 0;
List<Integer> sixthColumn = Arrays.asList(13, 95, 15, 25);
while (reader.hasNext()) {
List<Writable> list = new ArrayList<>(reader.next());
assertEquals(10, list.size());
assertEquals((long)sixthColumn.get(lineidx), list.get(5).toInt());
assertEquals((long) sixthColumn.get(lineidx), list.get(5).toInt());
lineidx++;
}
}
@Test
public void testWithQuotes() throws Exception {
@DisplayName("Test With Quotes")
void testWithQuotes() throws Exception {
CSVRecordReader reader = new CSVRecordReader(0, ',', '\"');
reader.initialize(new StringSplit("1,0,3,\"Braund, Mr. Owen Harris\",male,\"\"\"\""));
while (reader.hasNext()) {
List<Writable> vals = reader.next();
assertEquals("Entry count", 6, vals.size());
assertEquals("1", vals.get(0).toString());
assertEquals("0", vals.get(1).toString());
assertEquals("3", vals.get(2).toString());
assertEquals("Braund, Mr. Owen Harris", vals.get(3).toString());
assertEquals("male", vals.get(4).toString());
assertEquals("\"", vals.get(5).toString());
assertEquals(6, vals.size(), "Entry count");
assertEquals(vals.get(0).toString(), "1");
assertEquals(vals.get(1).toString(), "0");
assertEquals(vals.get(2).toString(), "3");
assertEquals(vals.get(3).toString(), "Braund, Mr. Owen Harris");
assertEquals(vals.get(4).toString(), "male");
assertEquals(vals.get(5).toString(), "\"");
}
}
@Test
public void testMeta() throws Exception {
@DisplayName("Test Meta")
void testMeta() throws Exception {
CSVRecordReader rr = new CSVRecordReader(0, ',');
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
int lineCount = 0;
List<RecordMetaData> metaList = new ArrayList<>();
List<List<Writable>> writables = new ArrayList<>();
@ -213,30 +208,25 @@ public class CSVRecordReaderTest extends BaseND4JTest {
assertEquals(5, r.getRecord().size());
lineCount++;
RecordMetaData meta = r.getMetaData();
// System.out.println(r.getRecord() + "\t" + meta.getLocation() + "\t" + meta.getURI());
// System.out.println(r.getRecord() + "\t" + meta.getLocation() + "\t" + meta.getURI());
metaList.add(meta);
writables.add(r.getRecord());
}
assertFalse(rr.hasNext());
assertEquals(150, lineCount);
rr.reset();
System.out.println("\n\n\n--------------------------------");
List<Record> contents = rr.loadFromMetaData(metaList);
assertEquals(150, contents.size());
// for(Record r : contents ){
// System.out.println(r);
// }
// for(Record r : contents ){
// System.out.println(r);
// }
List<RecordMetaData> meta2 = new ArrayList<>();
meta2.add(metaList.get(100));
meta2.add(metaList.get(90));
meta2.add(metaList.get(80));
meta2.add(metaList.get(70));
meta2.add(metaList.get(60));
List<Record> contents2 = rr.loadFromMetaData(meta2);
assertEquals(writables.get(100), contents2.get(0).getRecord());
assertEquals(writables.get(90), contents2.get(1).getRecord());
@ -246,50 +236,49 @@ public class CSVRecordReaderTest extends BaseND4JTest {
}
@Test
public void testRegex() throws Exception {
CSVRecordReader reader = new CSVRegexRecordReader(0, ",", null, new String[] {null, "(.+) (.+) (.+)"});
@DisplayName("Test Regex")
void testRegex() throws Exception {
CSVRecordReader reader = new CSVRegexRecordReader(0, ",", null, new String[] { null, "(.+) (.+) (.+)" });
reader.initialize(new StringSplit("normal,1.2.3.4 space separator"));
while (reader.hasNext()) {
List<Writable> vals = reader.next();
assertEquals("Entry count", 4, vals.size());
assertEquals("normal", vals.get(0).toString());
assertEquals("1.2.3.4", vals.get(1).toString());
assertEquals("space", vals.get(2).toString());
assertEquals("separator", vals.get(3).toString());
assertEquals(4, vals.size(), "Entry count");
assertEquals(vals.get(0).toString(), "normal");
assertEquals(vals.get(1).toString(), "1.2.3.4");
assertEquals(vals.get(2).toString(), "space");
assertEquals(vals.get(3).toString(), "separator");
}
}
@Test(expected = NoSuchElementException.class)
public void testCsvSkipAllLines() throws IOException, InterruptedException {
final int numLines = 4;
final List<Writable> lineList = Arrays.asList((Writable) new IntWritable(numLines - 1),
(Writable) new Text("one"), (Writable) new Text("two"), (Writable) new Text("three"));
String header = ",one,two,three";
List<String> lines = new ArrayList<>();
for (int i = 0; i < numLines; i++)
lines.add(Integer.toString(i) + header);
File tempFile = File.createTempFile("csvSkipLines", ".csv");
FileUtils.writeLines(tempFile, lines);
CSVRecordReader rr = new CSVRecordReader(numLines, ',');
rr.initialize(new FileSplit(tempFile));
rr.reset();
assertTrue(!rr.hasNext());
rr.next();
@Test
@DisplayName("Test Csv Skip All Lines")
void testCsvSkipAllLines() {
assertThrows(NoSuchElementException.class, () -> {
final int numLines = 4;
final List<Writable> lineList = Arrays.asList((Writable) new IntWritable(numLines - 1), (Writable) new Text("one"), (Writable) new Text("two"), (Writable) new Text("three"));
String header = ",one,two,three";
List<String> lines = new ArrayList<>();
for (int i = 0; i < numLines; i++) lines.add(Integer.toString(i) + header);
File tempFile = File.createTempFile("csvSkipLines", ".csv");
FileUtils.writeLines(tempFile, lines);
CSVRecordReader rr = new CSVRecordReader(numLines, ',');
rr.initialize(new FileSplit(tempFile));
rr.reset();
assertTrue(!rr.hasNext());
rr.next();
});
}
@Test
public void testCsvSkipAllButOneLine() throws IOException, InterruptedException {
@DisplayName("Test Csv Skip All But One Line")
void testCsvSkipAllButOneLine() throws IOException, InterruptedException {
final int numLines = 4;
final List<Writable> lineList = Arrays.<Writable>asList(new Text(Integer.toString(numLines - 1)),
new Text("one"), new Text("two"), new Text("three"));
final List<Writable> lineList = Arrays.<Writable>asList(new Text(Integer.toString(numLines - 1)), new Text("one"), new Text("two"), new Text("three"));
String header = ",one,two,three";
List<String> lines = new ArrayList<>();
for (int i = 0; i < numLines; i++)
lines.add(Integer.toString(i) + header);
for (int i = 0; i < numLines; i++) lines.add(Integer.toString(i) + header);
File tempFile = File.createTempFile("csvSkipLines", ".csv");
FileUtils.writeLines(tempFile, lines);
CSVRecordReader rr = new CSVRecordReader(numLines - 1, ',');
rr.initialize(new FileSplit(tempFile));
rr.reset();
@ -297,50 +286,45 @@ public class CSVRecordReaderTest extends BaseND4JTest {
assertEquals(rr.next(), lineList);
}
@Test
public void testStreamReset() throws Exception {
@DisplayName("Test Stream Reset")
void testStreamReset() throws Exception {
CSVRecordReader rr = new CSVRecordReader(0, ',');
rr.initialize(new InputStreamInputSplit(new ClassPathResource("datavec-api/iris.dat").getInputStream()));
int count = 0;
while(rr.hasNext()){
while (rr.hasNext()) {
assertNotNull(rr.next());
count++;
}
assertEquals(150, count);
assertFalse(rr.resetSupported());
try{
try {
rr.reset();
fail("Expected exception");
} catch (Exception e){
} catch (Exception e) {
String msg = e.getMessage();
String msg2 = e.getCause().getMessage();
assertTrue(msg, msg.contains("Error during LineRecordReader reset"));
assertTrue(msg2, msg2.contains("Reset not supported from streams"));
// e.printStackTrace();
assertTrue(msg.contains("Error during LineRecordReader reset"),msg);
assertTrue(msg2.contains("Reset not supported from streams"),msg2);
// e.printStackTrace();
}
}
@Test
public void testUsefulExceptionNoInit(){
@DisplayName("Test Useful Exception No Init")
void testUsefulExceptionNoInit() {
CSVRecordReader rr = new CSVRecordReader(0, ',');
try{
try {
rr.hasNext();
fail("Expected exception");
} catch (Exception e){
assertTrue(e.getMessage(), e.getMessage().contains("initialized"));
} catch (Exception e) {
assertTrue( e.getMessage().contains("initialized"),e.getMessage());
}
try{
try {
rr.next();
fail("Expected exception");
} catch (Exception e){
assertTrue(e.getMessage(), e.getMessage().contains("initialized"));
} catch (Exception e) {
assertTrue(e.getMessage().contains("initialized"),e.getMessage());
}
}
}

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.records.reader.impl;
import org.datavec.api.records.SequenceRecord;
@ -27,12 +26,11 @@ import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.split.InputSplit;
import org.datavec.api.split.NumberedFileInputSplit;
import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
import java.io.File;
import java.io.InputStream;
import java.io.OutputStream;
@ -41,25 +39,27 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals;
@DisplayName("Csv Sequence Record Reader Test")
class CSVSequenceRecordReaderTest extends BaseND4JTest {
public class CSVSequenceRecordReaderTest extends BaseND4JTest {
@Rule
public TemporaryFolder tempDir = new TemporaryFolder();
@TempDir
public Path tempDir;
@Test
public void test() throws Exception {
@DisplayName("Test")
void test() throws Exception {
CSVSequenceRecordReader seqReader = new CSVSequenceRecordReader(1, ",");
seqReader.initialize(new TestInputSplit());
int sequenceCount = 0;
while (seqReader.hasNext()) {
List<List<Writable>> sequence = seqReader.sequenceRecord();
assertEquals(4, sequence.size()); //4 lines, plus 1 header line
// 4 lines, plus 1 header line
assertEquals(4, sequence.size());
Iterator<List<Writable>> timeStepIter = sequence.iterator();
int lineCount = 0;
while (timeStepIter.hasNext()) {
@ -80,19 +80,18 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
}
@Test
public void testReset() throws Exception {
@DisplayName("Test Reset")
void testReset() throws Exception {
CSVSequenceRecordReader seqReader = new CSVSequenceRecordReader(1, ",");
seqReader.initialize(new TestInputSplit());
int nTests = 5;
for (int i = 0; i < nTests; i++) {
seqReader.reset();
int sequenceCount = 0;
while (seqReader.hasNext()) {
List<List<Writable>> sequence = seqReader.sequenceRecord();
assertEquals(4, sequence.size()); //4 lines, plus 1 header line
// 4 lines, plus 1 header line
assertEquals(4, sequence.size());
Iterator<List<Writable>> timeStepIter = sequence.iterator();
int lineCount = 0;
while (timeStepIter.hasNext()) {
@ -107,15 +106,15 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
}
@Test
public void testMetaData() throws Exception {
@DisplayName("Test Meta Data")
void testMetaData() throws Exception {
CSVSequenceRecordReader seqReader = new CSVSequenceRecordReader(1, ",");
seqReader.initialize(new TestInputSplit());
List<List<List<Writable>>> l = new ArrayList<>();
while (seqReader.hasNext()) {
List<List<Writable>> sequence = seqReader.sequenceRecord();
assertEquals(4, sequence.size()); //4 lines, plus 1 header line
// 4 lines, plus 1 header line
assertEquals(4, sequence.size());
Iterator<List<Writable>> timeStepIter = sequence.iterator();
int lineCount = 0;
while (timeStepIter.hasNext()) {
@ -123,10 +122,8 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
lineCount++;
}
assertEquals(4, lineCount);
l.add(sequence);
}
List<SequenceRecord> l2 = new ArrayList<>();
List<RecordMetaData> meta = new ArrayList<>();
seqReader.reset();
@ -136,7 +133,6 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
meta.add(sr.getMetaData());
}
assertEquals(3, l2.size());
List<SequenceRecord> fromMeta = seqReader.loadSequenceFromMetaData(meta);
for (int i = 0; i < 3; i++) {
assertEquals(l.get(i), l2.get(i).getSequenceRecord());
@ -144,8 +140,8 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
}
}
private static class
TestInputSplit implements InputSplit {
@DisplayName("Test Input Split")
private static class TestInputSplit implements InputSplit {
@Override
public boolean canWriteToLocation(URI location) {
@ -164,7 +160,6 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
@Override
public void updateSplitLocations(boolean reset) {
}
@Override
@ -174,7 +169,6 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
@Override
public void bootStrapForWrite() {
}
@Override
@ -222,38 +216,30 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest {
@Override
public void reset() {
//No op
// No op
}
@Override
public boolean resetSupported() {
return true;
}
}
@Test
public void testCsvSeqAndNumberedFileSplit() throws Exception {
File baseDir = tempDir.newFolder();
//Simple sanity check unit test
@DisplayName("Test Csv Seq And Numbered File Split")
void testCsvSeqAndNumberedFileSplit(@TempDir Path tempDir) throws Exception {
File baseDir = tempDir.toFile();
// Simple sanity check unit test
for (int i = 0; i < 3; i++) {
new ClassPathResource(String.format("csvsequence_%d.txt", i)).getTempFileFromArchive(baseDir);
}
//Load time series from CSV sequence files; compare to SequenceRecordReaderDataSetIterator
// Load time series from CSV sequence files; compare to SequenceRecordReaderDataSetIterator
ClassPathResource resource = new ClassPathResource("csvsequence_0.txt");
String featuresPath = new File(baseDir, "csvsequence_%d.txt").getAbsolutePath();
SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
featureReader.initialize(new NumberedFileInputSplit(featuresPath, 0, 2));
while(featureReader.hasNext()){
while (featureReader.hasNext()) {
featureReader.nextSequence();
}
}
}

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.records.reader.impl;
import org.datavec.api.records.reader.SequenceRecordReader;
@ -25,94 +24,87 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVVariableSlidingWindowRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
import java.util.LinkedList;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals;
public class CSVVariableSlidingWindowRecordReaderTest extends BaseND4JTest {
@DisplayName("Csv Variable Sliding Window Record Reader Test")
class CSVVariableSlidingWindowRecordReaderTest extends BaseND4JTest {
@Test
public void testCSVVariableSlidingWindowRecordReader() throws Exception {
@DisplayName("Test CSV Variable Sliding Window Record Reader")
void testCSVVariableSlidingWindowRecordReader() throws Exception {
int maxLinesPerSequence = 3;
SequenceRecordReader seqRR = new CSVVariableSlidingWindowRecordReader(maxLinesPerSequence);
seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
CSVRecordReader rr = new CSVRecordReader();
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
int count = 0;
while (seqRR.hasNext()) {
List<List<Writable>> next = seqRR.sequenceRecord();
if(count==maxLinesPerSequence-1) {
if (count == maxLinesPerSequence - 1) {
LinkedList<List<Writable>> expected = new LinkedList<>();
for (int i = 0; i < maxLinesPerSequence; i++) {
expected.addFirst(rr.next());
}
assertEquals(expected, next);
}
if(count==maxLinesPerSequence) {
if (count == maxLinesPerSequence) {
assertEquals(maxLinesPerSequence, next.size());
}
if(count==0) { // first seq should be length 1
if (count == 0) {
// first seq should be length 1
assertEquals(1, next.size());
}
if(count>151) { // last seq should be length 1
if (count > 151) {
// last seq should be length 1
assertEquals(1, next.size());
}
count++;
}
assertEquals(152, count);
}
@Test
public void testCSVVariableSlidingWindowRecordReaderStride() throws Exception {
@DisplayName("Test CSV Variable Sliding Window Record Reader Stride")
void testCSVVariableSlidingWindowRecordReaderStride() throws Exception {
int maxLinesPerSequence = 3;
int stride = 2;
SequenceRecordReader seqRR = new CSVVariableSlidingWindowRecordReader(maxLinesPerSequence, stride);
seqRR.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
CSVRecordReader rr = new CSVRecordReader();
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
int count = 0;
while (seqRR.hasNext()) {
List<List<Writable>> next = seqRR.sequenceRecord();
if(count==maxLinesPerSequence-1) {
if (count == maxLinesPerSequence - 1) {
LinkedList<List<Writable>> expected = new LinkedList<>();
for(int s = 0; s < stride; s++) {
for (int s = 0; s < stride; s++) {
expected = new LinkedList<>();
for (int i = 0; i < maxLinesPerSequence; i++) {
expected.addFirst(rr.next());
}
}
assertEquals(expected, next);
}
if(count==maxLinesPerSequence) {
if (count == maxLinesPerSequence) {
assertEquals(maxLinesPerSequence, next.size());
}
if(count==0) { // first seq should be length 2
if (count == 0) {
// first seq should be length 2
assertEquals(2, next.size());
}
if(count>151) { // last seq should be length 1
if (count > 151) {
// last seq should be length 1
assertEquals(1, next.size());
}
count++;
}
assertEquals(76, count);
}
}

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.records.reader.impl;
import org.apache.commons.io.FileUtils;
@ -28,45 +27,44 @@ import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.records.reader.impl.filebatch.FileBatchRecordReader;
import org.datavec.api.records.reader.impl.filebatch.FileBatchSequenceRecordReader;
import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.loader.FileBatch;
import java.io.File;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.Assert.*;
public class FileBatchRecordReaderTest extends BaseND4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test
public void testCsv() throws Exception {
//This is an unrealistic use case - one line/record per CSV
File baseDir = testDir.newFolder();
@DisplayName("File Batch Record Reader Test")
public class FileBatchRecordReaderTest extends BaseND4JTest {
@TempDir Path testDir;
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@DisplayName("Test Csv")
void testCsv(Nd4jBackend backend) throws Exception {
// This is an unrealistic use case - one line/record per CSV
File baseDir = testDir.toFile();
List<File> fileList = new ArrayList<>();
for( int i=0; i<10; i++ ){
for (int i = 0; i < 10; i++) {
String s = "file_" + i + "," + i + "," + i;
File f = new File(baseDir, "origFile" + i + ".csv");
FileUtils.writeStringToFile(f, s, StandardCharsets.UTF_8);
fileList.add(f);
}
FileBatch fb = FileBatch.forFiles(fileList);
RecordReader rr = new CSVRecordReader();
FileBatchRecordReader fbrr = new FileBatchRecordReader(rr, fb);
for( int test=0; test<3; test++) {
for (int test = 0; test < 3; test++) {
for (int i = 0; i < 10; i++) {
assertTrue(fbrr.hasNext());
List<Writable> next = fbrr.next();
@ -82,16 +80,17 @@ public class FileBatchRecordReaderTest extends BaseND4JTest {
}
}
@Test
public void testCsvSequence() throws Exception {
//CSV sequence - 3 lines per file, 10 files
File baseDir = testDir.newFolder();
@ParameterizedTest
@MethodSource("org.nd4j.linalg.BaseNd4jTestWithBackends#configs")
@DisplayName("Test Csv Sequence")
void testCsvSequence(Nd4jBackend backend) throws Exception {
// CSV sequence - 3 lines per file, 10 files
File baseDir = testDir.toFile();
List<File> fileList = new ArrayList<>();
for( int i=0; i<10; i++ ){
for (int i = 0; i < 10; i++) {
StringBuilder sb = new StringBuilder();
for( int j=0; j<3; j++ ){
if(j > 0)
for (int j = 0; j < 3; j++) {
if (j > 0)
sb.append("\n");
sb.append("file_" + i + "," + i + "," + j);
}
@ -99,19 +98,16 @@ public class FileBatchRecordReaderTest extends BaseND4JTest {
FileUtils.writeStringToFile(f, sb.toString(), StandardCharsets.UTF_8);
fileList.add(f);
}
FileBatch fb = FileBatch.forFiles(fileList);
SequenceRecordReader rr = new CSVSequenceRecordReader();
FileBatchSequenceRecordReader fbrr = new FileBatchSequenceRecordReader(rr, fb);
for( int test=0; test<3; test++) {
for (int test = 0; test < 3; test++) {
for (int i = 0; i < 10; i++) {
assertTrue(fbrr.hasNext());
List<List<Writable>> next = fbrr.sequenceRecord();
assertEquals(3, next.size());
int count = 0;
for(List<Writable> step : next ){
for (List<Writable> step : next) {
String s1 = "file_" + i;
assertEquals(s1, step.get(0).toString());
assertEquals(String.valueOf(i), step.get(1).toString());
@ -123,5 +119,4 @@ public class FileBatchRecordReaderTest extends BaseND4JTest {
fbrr.reset();
}
}
}

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.records.reader.impl;
import org.datavec.api.records.Record;
@ -26,28 +25,28 @@ import org.datavec.api.split.CollectionInputSplit;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
public class FileRecordReaderTest extends BaseND4JTest {
@DisplayName("File Record Reader Test")
class FileRecordReaderTest extends BaseND4JTest {
@Test
public void testReset() throws Exception {
@DisplayName("Test Reset")
void testReset() throws Exception {
FileRecordReader rr = new FileRecordReader();
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/iris.dat").getFile()));
int nResets = 5;
for (int i = 0; i < nResets; i++) {
int lineCount = 0;
while (rr.hasNext()) {
List<Writable> line = rr.next();
@ -61,25 +60,20 @@ public class FileRecordReaderTest extends BaseND4JTest {
}
@Test
public void testMeta() throws Exception {
@DisplayName("Test Meta")
void testMeta() throws Exception {
FileRecordReader rr = new FileRecordReader();
URI[] arr = new URI[3];
arr[0] = new ClassPathResource("datavec-api/csvsequence_0.txt").getFile().toURI();
arr[1] = new ClassPathResource("datavec-api/csvsequence_1.txt").getFile().toURI();
arr[2] = new ClassPathResource("datavec-api/csvsequence_2.txt").getFile().toURI();
InputSplit is = new CollectionInputSplit(Arrays.asList(arr));
rr.initialize(is);
List<List<Writable>> out = new ArrayList<>();
while (rr.hasNext()) {
out.add(rr.next());
}
assertEquals(3, out.size());
rr.reset();
List<List<Writable>> out2 = new ArrayList<>();
List<Record> out3 = new ArrayList<>();
@ -90,13 +84,10 @@ public class FileRecordReaderTest extends BaseND4JTest {
out2.add(r.getRecord());
out3.add(r);
meta.add(r.getMetaData());
assertEquals(arr[count++], r.getMetaData().getURI());
}
assertEquals(out, out2);
List<Record> fromMeta = rr.loadFromMetaData(meta);
assertEquals(out3, fromMeta);
}
}

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.records.reader.impl;
import org.datavec.api.records.reader.RecordReader;
@ -28,97 +27,81 @@ import org.datavec.api.split.CollectionInputSplit;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.shade.jackson.core.JsonFactory;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import java.io.File;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals;
@DisplayName("Jackson Line Record Reader Test")
class JacksonLineRecordReaderTest extends BaseND4JTest {
public class JacksonLineRecordReaderTest extends BaseND4JTest {
@TempDir
public Path testDir;
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
public JacksonLineRecordReaderTest() {
}
public JacksonLineRecordReaderTest() {
}
private static FieldSelection getFieldSelection() {
return new FieldSelection.Builder().addField("value1").
addField("value2").
addField("value3").
addField("value4").
addField("value5").
addField("value6").
addField("value7").
addField("value8").
addField("value9").
addField("value10").build();
return new FieldSelection.Builder().addField("value1").addField("value2").addField("value3").addField("value4").addField("value5").addField("value6").addField("value7").addField("value8").addField("value9").addField("value10").build();
}
@Test
public void testReadJSON() throws Exception {
@DisplayName("Test Read JSON")
void testReadJSON() throws Exception {
RecordReader rr = new JacksonLineRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()));
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/json/json_test_3.txt").getFile()));
testJacksonRecordReader(rr);
}
private static void testJacksonRecordReader(RecordReader rr) {
while (rr.hasNext()) {
List<Writable> json0 = rr.next();
//System.out.println(json0);
assert(json0.size() > 0);
}
}
private static void testJacksonRecordReader(RecordReader rr) {
while (rr.hasNext()) {
List<Writable> json0 = rr.next();
// System.out.println(json0);
assert (json0.size() > 0);
}
}
@Test
public void testJacksonLineSequenceRecordReader() throws Exception {
File dir = testDir.newFolder();
new ClassPathResource("datavec-api/JacksonLineSequenceRecordReaderTest/").copyDirectory(dir);
FieldSelection f = new FieldSelection.Builder().addField("a").addField(new Text("MISSING_B"), "b")
.addField(new Text("MISSING_CX"), "c", "x").build();
JacksonLineSequenceRecordReader rr = new JacksonLineSequenceRecordReader(f, new ObjectMapper(new JsonFactory()));
File[] files = dir.listFiles();
Arrays.sort(files);
URI[] u = new URI[files.length];
for( int i=0; i<files.length; i++ ){
u[i] = files[i].toURI();
}
rr.initialize(new CollectionInputSplit(u));
List<List<Writable>> expSeq0 = new ArrayList<>();
expSeq0.add(Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0")));
expSeq0.add(Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1")));
expSeq0.add(Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX")));
List<List<Writable>> expSeq1 = new ArrayList<>();
expSeq1.add(Arrays.asList((Writable) new Text("aValue3"), new Text("bValue3"), new Text("cxValue3")));
int count = 0;
while(rr.hasNext()){
List<List<Writable>> next = rr.sequenceRecord();
if(count++ == 0){
assertEquals(expSeq0, next);
} else {
assertEquals(expSeq1, next);
}
}
assertEquals(2, count);
}
@DisplayName("Test Jackson Line Sequence Record Reader")
void testJacksonLineSequenceRecordReader(@TempDir Path testDir) throws Exception {
File dir = testDir.toFile();
new ClassPathResource("datavec-api/JacksonLineSequenceRecordReaderTest/").copyDirectory(dir);
FieldSelection f = new FieldSelection.Builder().addField("a").addField(new Text("MISSING_B"), "b").addField(new Text("MISSING_CX"), "c", "x").build();
JacksonLineSequenceRecordReader rr = new JacksonLineSequenceRecordReader(f, new ObjectMapper(new JsonFactory()));
File[] files = dir.listFiles();
Arrays.sort(files);
URI[] u = new URI[files.length];
for (int i = 0; i < files.length; i++) {
u[i] = files[i].toURI();
}
rr.initialize(new CollectionInputSplit(u));
List<List<Writable>> expSeq0 = new ArrayList<>();
expSeq0.add(Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0")));
expSeq0.add(Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1")));
expSeq0.add(Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX")));
List<List<Writable>> expSeq1 = new ArrayList<>();
expSeq1.add(Arrays.asList((Writable) new Text("aValue3"), new Text("bValue3"), new Text("cxValue3")));
int count = 0;
while (rr.hasNext()) {
List<List<Writable>> next = rr.sequenceRecord();
if (count++ == 0) {
assertEquals(expSeq0, next);
} else {
assertEquals(expSeq1, next);
}
}
assertEquals(2, count);
}
}

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.records.reader.impl;
import org.datavec.api.io.labels.PathLabelGenerator;
@ -31,114 +30,95 @@ import org.datavec.api.split.NumberedFileInputSplit;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.shade.jackson.core.JsonFactory;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.dataformat.xml.XmlFactory;
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
import java.io.File;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
@DisplayName("Jackson Record Reader Test")
class JacksonRecordReaderTest extends BaseND4JTest {
public class JacksonRecordReaderTest extends BaseND4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@TempDir
public Path testDir;
@Test
public void testReadingJson() throws Exception {
//Load 3 values from 3 JSON files
//stricture: a:value, b:value, c:x:value, c:y:value
//And we want to load only a:value, b:value and c:x:value
//For first JSON file: all values are present
//For second JSON file: b:value is missing
//For third JSON file: c:x:value is missing
@DisplayName("Test Reading Json")
void testReadingJson(@TempDir Path testDir) throws Exception {
// Load 3 values from 3 JSON files
// stricture: a:value, b:value, c:x:value, c:y:value
// And we want to load only a:value, b:value and c:x:value
// For first JSON file: all values are present
// For second JSON file: b:value is missing
// For third JSON file: c:x:value is missing
ClassPathResource cpr = new ClassPathResource("datavec-api/json/");
File f = testDir.newFolder();
File f = testDir.toFile();
cpr.copyDirectory(f);
String path = new File(f, "json_test_%d.txt").getAbsolutePath();
InputSplit is = new NumberedFileInputSplit(path, 0, 2);
RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()));
rr.initialize(is);
testJacksonRecordReader(rr);
}
@Test
public void testReadingYaml() throws Exception {
//Exact same information as JSON format, but in YAML format
@DisplayName("Test Reading Yaml")
void testReadingYaml(@TempDir Path testDir) throws Exception {
// Exact same information as JSON format, but in YAML format
ClassPathResource cpr = new ClassPathResource("datavec-api/yaml/");
File f = testDir.newFolder();
File f = testDir.toFile();
cpr.copyDirectory(f);
String path = new File(f, "yaml_test_%d.txt").getAbsolutePath();
InputSplit is = new NumberedFileInputSplit(path, 0, 2);
RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new YAMLFactory()));
rr.initialize(is);
testJacksonRecordReader(rr);
}
@Test
public void testReadingXml() throws Exception {
//Exact same information as JSON format, but in XML format
@DisplayName("Test Reading Xml")
void testReadingXml(@TempDir Path testDir) throws Exception {
// Exact same information as JSON format, but in XML format
ClassPathResource cpr = new ClassPathResource("datavec-api/xml/");
File f = testDir.newFolder();
File f = testDir.toFile();
cpr.copyDirectory(f);
String path = new File(f, "xml_test_%d.txt").getAbsolutePath();
InputSplit is = new NumberedFileInputSplit(path, 0, 2);
RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new XmlFactory()));
rr.initialize(is);
testJacksonRecordReader(rr);
}
private static FieldSelection getFieldSelection() {
return new FieldSelection.Builder().addField("a").addField(new Text("MISSING_B"), "b")
.addField(new Text("MISSING_CX"), "c", "x").build();
return new FieldSelection.Builder().addField("a").addField(new Text("MISSING_B"), "b").addField(new Text("MISSING_CX"), "c", "x").build();
}
private static void testJacksonRecordReader(RecordReader rr) {
List<Writable> json0 = rr.next();
List<Writable> exp0 = Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"));
assertEquals(exp0, json0);
List<Writable> json1 = rr.next();
List<Writable> exp1 =
Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"));
List<Writable> exp1 = Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"));
assertEquals(exp1, json1);
List<Writable> json2 = rr.next();
List<Writable> exp2 =
Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"));
List<Writable> exp2 = Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"));
assertEquals(exp2, json2);
assertFalse(rr.hasNext());
//Test reset
// Test reset
rr.reset();
assertEquals(exp0, rr.next());
assertEquals(exp1, rr.next());
@ -147,72 +127,50 @@ public class JacksonRecordReaderTest extends BaseND4JTest {
}
@Test
public void testAppendingLabels() throws Exception {
@DisplayName("Test Appending Labels")
void testAppendingLabels(@TempDir Path testDir) throws Exception {
ClassPathResource cpr = new ClassPathResource("datavec-api/json/");
File f = testDir.newFolder();
File f = testDir.toFile();
cpr.copyDirectory(f);
String path = new File(f, "json_test_%d.txt").getAbsolutePath();
InputSplit is = new NumberedFileInputSplit(path, 0, 2);
//Insert at the end:
RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1,
new LabelGen());
// Insert at the end:
RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1, new LabelGen());
rr.initialize(is);
List<Writable> exp0 = Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"),
new IntWritable(0));
List<Writable> exp0 = Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"), new IntWritable(0));
assertEquals(exp0, rr.next());
List<Writable> exp1 = Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"),
new IntWritable(1));
List<Writable> exp1 = Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"), new IntWritable(1));
assertEquals(exp1, rr.next());
List<Writable> exp2 = Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"),
new IntWritable(2));
List<Writable> exp2 = Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"), new IntWritable(2));
assertEquals(exp2, rr.next());
//Insert at position 0:
rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1,
new LabelGen(), 0);
// Insert at position 0:
rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1, new LabelGen(), 0);
rr.initialize(is);
exp0 = Arrays.asList((Writable) new IntWritable(0), new Text("aValue0"), new Text("bValue0"),
new Text("cxValue0"));
exp0 = Arrays.asList((Writable) new IntWritable(0), new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"));
assertEquals(exp0, rr.next());
exp1 = Arrays.asList((Writable) new IntWritable(1), new Text("aValue1"), new Text("MISSING_B"),
new Text("cxValue1"));
exp1 = Arrays.asList((Writable) new IntWritable(1), new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"));
assertEquals(exp1, rr.next());
exp2 = Arrays.asList((Writable) new IntWritable(2), new Text("aValue2"), new Text("bValue2"),
new Text("MISSING_CX"));
exp2 = Arrays.asList((Writable) new IntWritable(2), new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"));
assertEquals(exp2, rr.next());
}
@Test
public void testAppendingLabelsMetaData() throws Exception {
@DisplayName("Test Appending Labels Meta Data")
void testAppendingLabelsMetaData(@TempDir Path testDir) throws Exception {
ClassPathResource cpr = new ClassPathResource("datavec-api/json/");
File f = testDir.newFolder();
File f = testDir.toFile();
cpr.copyDirectory(f);
String path = new File(f, "json_test_%d.txt").getAbsolutePath();
InputSplit is = new NumberedFileInputSplit(path, 0, 2);
//Insert at the end:
RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1,
new LabelGen());
// Insert at the end:
RecordReader rr = new JacksonRecordReader(getFieldSelection(), new ObjectMapper(new JsonFactory()), false, -1, new LabelGen());
rr.initialize(is);
List<List<Writable>> out = new ArrayList<>();
while (rr.hasNext()) {
out.add(rr.next());
}
assertEquals(3, out.size());
rr.reset();
List<List<Writable>> out2 = new ArrayList<>();
List<Record> outRecord = new ArrayList<>();
List<RecordMetaData> meta = new ArrayList<>();
@ -222,14 +180,12 @@ public class JacksonRecordReaderTest extends BaseND4JTest {
outRecord.add(r);
meta.add(r.getMetaData());
}
assertEquals(out, out2);
List<Record> fromMeta = rr.loadFromMetaData(meta);
assertEquals(outRecord, fromMeta);
}
@DisplayName("Label Gen")
private static class LabelGen implements PathLabelGenerator {
@Override
@ -252,5 +208,4 @@ public class JacksonRecordReaderTest extends BaseND4JTest {
return true;
}
}
}

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.records.reader.impl;
import org.datavec.api.conf.Configuration;
@ -27,43 +26,30 @@ import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
import java.io.IOException;
import java.util.*;
import static org.datavec.api.records.reader.impl.misc.LibSvmRecordReader.*;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.jupiter.api.Assertions.assertThrows;
public class LibSvmRecordReaderTest extends BaseND4JTest {
@DisplayName("Lib Svm Record Reader Test")
class LibSvmRecordReaderTest extends BaseND4JTest {
@Test
public void testBasicRecord() throws IOException, InterruptedException {
@DisplayName("Test Basic Record")
void testBasicRecord() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>();
// 7 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE,
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5),
new IntWritable(7)));
correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), new IntWritable(7)));
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO,
new IntWritable(2)));
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, new IntWritable(2)));
// 33
correct.put(2, Arrays.asList(ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
new IntWritable(33)));
correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, new IntWritable(33)));
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
@ -80,27 +66,15 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
}
@Test
public void testNoAppendLabel() throws IOException, InterruptedException {
@DisplayName("Test No Append Label")
void testNoAppendLabel() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>();
// 7 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE,
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5)));
correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5)));
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO));
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO));
// 33
correct.put(2, Arrays.asList(ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO));
correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO));
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
@ -117,33 +91,17 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
}
@Test
public void testNoLabel() throws IOException, InterruptedException {
@DisplayName("Test No Label")
void testNoLabel() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>();
// 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE,
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5)));
// qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO));
// 1:1.0
correct.put(2, Arrays.asList(new DoubleWritable(1.0), ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO));
//
correct.put(3, Arrays.asList(ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO));
// 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5)));
// qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO));
// 1:1.0
correct.put(2, Arrays.asList(new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO));
//
correct.put(3, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO));
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
@ -160,33 +118,15 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
}
@Test
public void testMultioutputRecord() throws IOException, InterruptedException {
@DisplayName("Test Multioutput Record")
void testMultioutputRecord() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>();
// 7 2.45,9 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE,
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5),
new IntWritable(7), new DoubleWritable(2.45),
new IntWritable(9)));
correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), new IntWritable(7), new DoubleWritable(2.45), new IntWritable(9)));
// 2,3,4 qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO,
new IntWritable(2), new IntWritable(3),
new IntWritable(4)));
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, new IntWritable(2), new IntWritable(3), new IntWritable(4)));
// 33,32.0,31.9
correct.put(2, Arrays.asList(ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
new IntWritable(33), new DoubleWritable(32.0),
new DoubleWritable(31.9)));
correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, new IntWritable(33), new DoubleWritable(32.0), new DoubleWritable(31.9)));
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
@ -202,51 +142,20 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
assertEquals(i, correct.size());
}
@Test
public void testMultilabelRecord() throws IOException, InterruptedException {
@DisplayName("Test Multilabel Record")
void testMultilabelRecord() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>();
// 1,3 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE,
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5),
LABEL_ONE, LABEL_ZERO,
LABEL_ONE, LABEL_ZERO));
correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), LABEL_ONE, LABEL_ZERO, LABEL_ONE, LABEL_ZERO));
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO,
LABEL_ZERO, LABEL_ONE,
LABEL_ZERO, LABEL_ZERO));
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ZERO));
// 1,2,4
correct.put(2, Arrays.asList(ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ONE, LABEL_ONE,
LABEL_ZERO, LABEL_ONE));
// 1:1.0
correct.put(3, Arrays.asList(new DoubleWritable(1.0), ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ZERO, LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO));
//
correct.put(4, Arrays.asList(ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ZERO, LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO));
correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ONE, LABEL_ONE, LABEL_ZERO, LABEL_ONE));
// 1:1.0
correct.put(3, Arrays.asList(new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
//
correct.put(4, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
@ -265,63 +174,24 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
}
@Test
public void testZeroBasedIndexing() throws IOException, InterruptedException {
@DisplayName("Test Zero Based Indexing")
void testZeroBasedIndexing() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>();
// 1,3 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO,
ZERO, ONE,
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5),
LABEL_ZERO,
LABEL_ONE, LABEL_ZERO,
LABEL_ONE, LABEL_ZERO));
correct.put(0, Arrays.asList(ZERO, ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ONE, LABEL_ZERO));
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(ZERO,
new DoubleWritable(0.1), new DoubleWritable(2),
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO,
LABEL_ZERO,
LABEL_ZERO, LABEL_ONE,
LABEL_ZERO, LABEL_ZERO));
correct.put(1, Arrays.asList(ZERO, new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ZERO));
// 1,2,4
correct.put(2, Arrays.asList(ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ZERO,
LABEL_ONE, LABEL_ONE,
LABEL_ZERO, LABEL_ONE));
// 1:1.0
correct.put(3, Arrays.asList(ZERO,
new DoubleWritable(1.0), ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO));
//
correct.put(4, Arrays.asList(ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO));
correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ONE, LABEL_ZERO, LABEL_ONE));
// 1:1.0
correct.put(3, Arrays.asList(ZERO, new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
//
correct.put(4, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
// Zero-based indexing is default
config.setBoolean(SVMLightRecordReader.ZERO_BASED_LABEL_INDEXING, true); // NOT STANDARD!
// NOT STANDARD!
config.setBoolean(SVMLightRecordReader.ZERO_BASED_LABEL_INDEXING, true);
config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
config.setInt(LibSvmRecordReader.NUM_FEATURES, 11);
config.setBoolean(LibSvmRecordReader.MULTILABEL, true);
@ -336,87 +206,107 @@ public class LibSvmRecordReaderTest extends BaseND4JTest {
assertEquals(i, correct.size());
}
@Test(expected = NoSuchElementException.class)
public void testNoSuchElementException() throws Exception {
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
config.setInt(LibSvmRecordReader.NUM_FEATURES, 11);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
while (rr.hasNext())
@Test
@DisplayName("Test No Such Element Exception")
void testNoSuchElementException() {
assertThrows(NoSuchElementException.class, () -> {
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
config.setInt(LibSvmRecordReader.NUM_FEATURES, 11);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
while (rr.hasNext()) rr.next();
rr.next();
rr.next();
});
}
@Test(expected = UnsupportedOperationException.class)
public void failedToSetNumFeaturesException() throws Exception {
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
while (rr.hasNext())
@Test
@DisplayName("Failed To Set Num Features Exception")
void failedToSetNumFeaturesException() {
assertThrows(UnsupportedOperationException.class, () -> {
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
while (rr.hasNext()) rr.next();
});
}
@Test
@DisplayName("Test Inconsistent Num Labels Exception")
void testInconsistentNumLabelsException() {
assertThrows(UnsupportedOperationException.class, () -> {
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/inconsistentNumLabels.txt").getFile()));
while (rr.hasNext()) rr.next();
});
}
@Test
@DisplayName("Test Inconsistent Num Multiabels Exception")
void testInconsistentNumMultiabelsException() {
assertThrows(UnsupportedOperationException.class, () -> {
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.MULTILABEL, false);
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile()));
while (rr.hasNext()) rr.next();
});
}
@Test
@DisplayName("Test Feature Index Exceeds Num Features")
void testFeatureIndexExceedsNumFeatures() {
assertThrows(IndexOutOfBoundsException.class, () -> {
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
config.setInt(LibSvmRecordReader.NUM_FEATURES, 9);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
rr.next();
});
}
@Test(expected = UnsupportedOperationException.class)
public void testInconsistentNumLabelsException() throws Exception {
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/inconsistentNumLabels.txt").getFile()));
while (rr.hasNext())
@Test
@DisplayName("Test Label Index Exceeds Num Labels")
void testLabelIndexExceedsNumLabels() {
assertThrows(IndexOutOfBoundsException.class, () -> {
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
config.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
config.setInt(LibSvmRecordReader.NUM_LABELS, 6);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
rr.next();
});
}
@Test(expected = UnsupportedOperationException.class)
public void testInconsistentNumMultiabelsException() throws Exception {
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.MULTILABEL, false);
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile()));
while (rr.hasNext())
@Test
@DisplayName("Test Zero Index Feature Without Using Zero Indexing")
void testZeroIndexFeatureWithoutUsingZeroIndexing() {
assertThrows(IndexOutOfBoundsException.class, () -> {
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
config.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexFeature.txt").getFile()));
rr.next();
});
}
@Test(expected = IndexOutOfBoundsException.class)
public void testFeatureIndexExceedsNumFeatures() throws Exception {
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
config.setInt(LibSvmRecordReader.NUM_FEATURES, 9);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
rr.next();
}
@Test(expected = IndexOutOfBoundsException.class)
public void testLabelIndexExceedsNumLabels() throws Exception {
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
config.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
config.setInt(LibSvmRecordReader.NUM_LABELS, 6);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
rr.next();
}
@Test(expected = IndexOutOfBoundsException.class)
public void testZeroIndexFeatureWithoutUsingZeroIndexing() throws Exception {
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
config.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexFeature.txt").getFile()));
rr.next();
}
@Test(expected = IndexOutOfBoundsException.class)
public void testZeroIndexLabelWithoutUsingZeroIndexing() throws Exception {
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
config.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
config.setBoolean(LibSvmRecordReader.MULTILABEL, true);
config.setInt(LibSvmRecordReader.NUM_LABELS, 2);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexLabel.txt").getFile()));
rr.next();
@Test
@DisplayName("Test Zero Index Label Without Using Zero Indexing")
void testZeroIndexLabelWithoutUsingZeroIndexing() {
assertThrows(IndexOutOfBoundsException.class, () -> {
LibSvmRecordReader rr = new LibSvmRecordReader();
Configuration config = new Configuration();
config.setBoolean(LibSvmRecordReader.APPEND_LABEL, true);
config.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
config.setBoolean(LibSvmRecordReader.MULTILABEL, true);
config.setInt(LibSvmRecordReader.NUM_LABELS, 2);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexLabel.txt").getFile()));
rr.next();
});
}
}

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.records.reader.impl;
import org.apache.commons.io.FileUtils;
@ -30,11 +29,10 @@ import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.api.split.InputStreamInputSplit;
import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.common.tests.BaseND4JTest;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
@ -45,34 +43,31 @@ import java.util.Arrays;
import java.util.List;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals;
@DisplayName("Line Reader Test")
class LineReaderTest extends BaseND4JTest {
public class LineReaderTest extends BaseND4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test
public void testLineReader() throws Exception {
File tmpdir = testDir.newFolder();
@DisplayName("Test Line Reader")
void testLineReader(@TempDir Path tmpDir) throws Exception {
File tmpdir = tmpDir.toFile();
if (tmpdir.exists())
tmpdir.delete();
tmpdir.mkdir();
File tmp1 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp1.txt"));
File tmp2 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp2.txt"));
File tmp3 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp3.txt"));
FileUtils.writeLines(tmp1, Arrays.asList("1", "2", "3"));
FileUtils.writeLines(tmp2, Arrays.asList("4", "5", "6"));
FileUtils.writeLines(tmp3, Arrays.asList("7", "8", "9"));
InputSplit split = new FileSplit(tmpdir);
RecordReader reader = new LineRecordReader();
reader.initialize(split);
int count = 0;
List<List<Writable>> list = new ArrayList<>();
while (reader.hasNext()) {
@ -81,34 +76,27 @@ public class LineReaderTest extends BaseND4JTest {
list.add(l);
count++;
}
assertEquals(9, count);
}
@Test
public void testLineReaderMetaData() throws Exception {
File tmpdir = testDir.newFolder();
@DisplayName("Test Line Reader Meta Data")
void testLineReaderMetaData(@TempDir Path tmpDir) throws Exception {
File tmpdir = tmpDir.toFile();
File tmp1 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp1.txt"));
File tmp2 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp2.txt"));
File tmp3 = new File(FilenameUtils.concat(tmpdir.getPath(), "tmp3.txt"));
FileUtils.writeLines(tmp1, Arrays.asList("1", "2", "3"));
FileUtils.writeLines(tmp2, Arrays.asList("4", "5", "6"));
FileUtils.writeLines(tmp3, Arrays.asList("7", "8", "9"));
InputSplit split = new FileSplit(tmpdir);
RecordReader reader = new LineRecordReader();
reader.initialize(split);
List<List<Writable>> list = new ArrayList<>();
while (reader.hasNext()) {
list.add(reader.next());
}
assertEquals(9, list.size());
List<List<Writable>> out2 = new ArrayList<>();
List<Record> out3 = new ArrayList<>();
List<RecordMetaData> meta = new ArrayList<>();
@ -124,13 +112,10 @@ public class LineReaderTest extends BaseND4JTest {
assertEquals(uri, split.locations()[fileIdx]);
count++;
}
assertEquals(list, out2);
List<Record> fromMeta = reader.loadFromMetaData(meta);
assertEquals(out3, fromMeta);
//try: second line of second and third files only...
// try: second line of second and third files only...
List<RecordMetaData> subsetMeta = new ArrayList<>();
subsetMeta.add(meta.get(4));
subsetMeta.add(meta.get(7));
@ -141,27 +126,22 @@ public class LineReaderTest extends BaseND4JTest {
}
@Test
public void testLineReaderWithInputStreamInputSplit() throws Exception {
File tmpdir = testDir.newFolder();
@DisplayName("Test Line Reader With Input Stream Input Split")
void testLineReaderWithInputStreamInputSplit(@TempDir Path testDir) throws Exception {
File tmpdir = testDir.toFile();
File tmp1 = new File(tmpdir, "tmp1.txt.gz");
OutputStream os = new GZIPOutputStream(new FileOutputStream(tmp1, false));
IOUtils.writeLines(Arrays.asList("1", "2", "3", "4", "5", "6", "7", "8", "9"), null, os);
os.flush();
os.close();
InputSplit split = new InputStreamInputSplit(new GZIPInputStream(new FileInputStream(tmp1)));
RecordReader reader = new LineRecordReader();
reader.initialize(split);
int count = 0;
while (reader.hasNext()) {
assertEquals(1, reader.next().size());
count++;
}
assertEquals(9, count);
}
}

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.records.reader.impl;
import org.datavec.api.records.Record;
@ -33,44 +32,41 @@ import org.datavec.api.split.InputSplit;
import org.datavec.api.split.NumberedFileInputSplit;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
@DisplayName("Regex Record Reader Test")
class RegexRecordReaderTest extends BaseND4JTest {
public class RegexRecordReaderTest extends BaseND4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@TempDir
public Path testDir;
@Test
public void testRegexLineRecordReader() throws Exception {
@DisplayName("Test Regex Line Record Reader")
void testRegexLineRecordReader() throws Exception {
String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)";
RecordReader rr = new RegexLineRecordReader(regex, 1);
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/logtestdata/logtestfile0.txt").getFile()));
List<Writable> exp0 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.001"), new Text("1"),
new Text("DEBUG"), new Text("First entry message!"));
List<Writable> exp1 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.002"), new Text("2"),
new Text("INFO"), new Text("Second entry message!"));
List<Writable> exp2 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.003"), new Text("3"),
new Text("WARN"), new Text("Third entry message!"));
List<Writable> exp0 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.001"), new Text("1"), new Text("DEBUG"), new Text("First entry message!"));
List<Writable> exp1 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.002"), new Text("2"), new Text("INFO"), new Text("Second entry message!"));
List<Writable> exp2 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.003"), new Text("3"), new Text("WARN"), new Text("Third entry message!"));
assertEquals(exp0, rr.next());
assertEquals(exp1, rr.next());
assertEquals(exp2, rr.next());
assertFalse(rr.hasNext());
//Test reset:
// Test reset:
rr.reset();
assertEquals(exp0, rr.next());
assertEquals(exp1, rr.next());
@ -79,74 +75,57 @@ public class RegexRecordReaderTest extends BaseND4JTest {
}
@Test
public void testRegexLineRecordReaderMeta() throws Exception {
@DisplayName("Test Regex Line Record Reader Meta")
void testRegexLineRecordReaderMeta() throws Exception {
String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)";
RecordReader rr = new RegexLineRecordReader(regex, 1);
rr.initialize(new FileSplit(new ClassPathResource("datavec-api/logtestdata/logtestfile0.txt").getFile()));
List<List<Writable>> list = new ArrayList<>();
while (rr.hasNext()) {
list.add(rr.next());
}
assertEquals(3, list.size());
List<Record> list2 = new ArrayList<>();
List<List<Writable>> list3 = new ArrayList<>();
List<RecordMetaData> meta = new ArrayList<>();
rr.reset();
int count = 1; //Start by skipping 1 line
// Start by skipping 1 line
int count = 1;
while (rr.hasNext()) {
Record r = rr.nextRecord();
list2.add(r);
list3.add(r.getRecord());
meta.add(r.getMetaData());
assertEquals(count++, ((RecordMetaDataLine) r.getMetaData()).getLineNumber());
}
List<Record> fromMeta = rr.loadFromMetaData(meta);
assertEquals(list, list3);
assertEquals(list2, fromMeta);
}
@Test
public void testRegexSequenceRecordReader() throws Exception {
@DisplayName("Test Regex Sequence Record Reader")
void testRegexSequenceRecordReader(@TempDir Path testDir) throws Exception {
String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)";
ClassPathResource cpr = new ClassPathResource("datavec-api/logtestdata/");
File f = testDir.newFolder();
File f = testDir.toFile();
cpr.copyDirectory(f);
String path = new File(f, "logtestfile%d.txt").getAbsolutePath();
InputSplit is = new NumberedFileInputSplit(path, 0, 1);
SequenceRecordReader rr = new RegexSequenceRecordReader(regex, 1);
rr.initialize(is);
List<List<Writable>> exp0 = new ArrayList<>();
exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.001"), new Text("1"), new Text("DEBUG"),
new Text("First entry message!")));
exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.002"), new Text("2"), new Text("INFO"),
new Text("Second entry message!")));
exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.003"), new Text("3"), new Text("WARN"),
new Text("Third entry message!")));
exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.001"), new Text("1"), new Text("DEBUG"), new Text("First entry message!")));
exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.002"), new Text("2"), new Text("INFO"), new Text("Second entry message!")));
exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.003"), new Text("3"), new Text("WARN"), new Text("Third entry message!")));
List<List<Writable>> exp1 = new ArrayList<>();
exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.011"), new Text("11"), new Text("DEBUG"),
new Text("First entry message!")));
exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.012"), new Text("12"), new Text("INFO"),
new Text("Second entry message!")));
exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.013"), new Text("13"), new Text("WARN"),
new Text("Third entry message!")));
exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.011"), new Text("11"), new Text("DEBUG"), new Text("First entry message!")));
exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.012"), new Text("12"), new Text("INFO"), new Text("Second entry message!")));
exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.013"), new Text("13"), new Text("WARN"), new Text("Third entry message!")));
assertEquals(exp0, rr.sequenceRecord());
assertEquals(exp1, rr.sequenceRecord());
assertFalse(rr.hasNext());
//Test resetting:
// Test resetting:
rr.reset();
assertEquals(exp0, rr.sequenceRecord());
assertEquals(exp1, rr.sequenceRecord());
@ -154,24 +133,20 @@ public class RegexRecordReaderTest extends BaseND4JTest {
}
@Test
public void testRegexSequenceRecordReaderMeta() throws Exception {
@DisplayName("Test Regex Sequence Record Reader Meta")
void testRegexSequenceRecordReaderMeta(@TempDir Path testDir) throws Exception {
String regex = "(\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}\\.\\d{3}) (\\d+) ([A-Z]+) (.*)";
ClassPathResource cpr = new ClassPathResource("datavec-api/logtestdata/");
File f = testDir.newFolder();
File f = testDir.toFile();
cpr.copyDirectory(f);
String path = new File(f, "logtestfile%d.txt").getAbsolutePath();
InputSplit is = new NumberedFileInputSplit(path, 0, 1);
SequenceRecordReader rr = new RegexSequenceRecordReader(regex, 1);
rr.initialize(is);
List<List<List<Writable>>> out = new ArrayList<>();
while (rr.hasNext()) {
out.add(rr.sequenceRecord());
}
assertEquals(2, out.size());
List<List<List<Writable>>> out2 = new ArrayList<>();
List<SequenceRecord> out3 = new ArrayList<>();
@ -183,11 +158,8 @@ public class RegexRecordReaderTest extends BaseND4JTest {
out3.add(seqr);
meta.add(seqr.getMetaData());
}
List<SequenceRecord> fromMeta = rr.loadSequenceFromMetaData(meta);
assertEquals(out, out2);
assertEquals(out3, fromMeta);
}
}

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.records.reader.impl;
import org.datavec.api.conf.Configuration;
@ -27,43 +26,30 @@ import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
import java.io.IOException;
import java.util.*;
import static org.datavec.api.records.reader.impl.misc.SVMLightRecordReader.*;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.jupiter.api.Assertions.assertThrows;
public class SVMLightRecordReaderTest extends BaseND4JTest {
@DisplayName("Svm Light Record Reader Test")
class SVMLightRecordReaderTest extends BaseND4JTest {
@Test
public void testBasicRecord() throws IOException, InterruptedException {
@DisplayName("Test Basic Record")
void testBasicRecord() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>();
// 7 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE,
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5),
new IntWritable(7)));
correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), new IntWritable(7)));
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO,
new IntWritable(2)));
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, new IntWritable(2)));
// 33
correct.put(2, Arrays.asList(ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
new IntWritable(33)));
correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, new IntWritable(33)));
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
@ -79,27 +65,15 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
}
@Test
public void testNoAppendLabel() throws IOException, InterruptedException {
@DisplayName("Test No Append Label")
void testNoAppendLabel() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>();
// 7 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE,
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5)));
correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5)));
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO));
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO));
// 33
correct.put(2, Arrays.asList(ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO));
correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO));
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
@ -116,33 +90,17 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
}
@Test
public void testNoLabel() throws IOException, InterruptedException {
@DisplayName("Test No Label")
void testNoLabel() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>();
// 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE,
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5)));
// qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO));
// 1:1.0
correct.put(2, Arrays.asList(new DoubleWritable(1.0), ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO));
//
correct.put(3, Arrays.asList(ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO));
// 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5)));
// qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO));
// 1:1.0
correct.put(2, Arrays.asList(new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO));
//
correct.put(3, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO));
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
@ -159,33 +117,15 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
}
@Test
public void testMultioutputRecord() throws IOException, InterruptedException {
@DisplayName("Test Multioutput Record")
void testMultioutputRecord() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>();
// 7 2.45,9 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE,
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5),
new IntWritable(7), new DoubleWritable(2.45),
new IntWritable(9)));
correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), new IntWritable(7), new DoubleWritable(2.45), new IntWritable(9)));
// 2,3,4 qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO,
new IntWritable(2), new IntWritable(3),
new IntWritable(4)));
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, new IntWritable(2), new IntWritable(3), new IntWritable(4)));
// 33,32.0,31.9
correct.put(2, Arrays.asList(ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
new IntWritable(33), new DoubleWritable(32.0),
new DoubleWritable(31.9)));
correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, new IntWritable(33), new DoubleWritable(32.0), new DoubleWritable(31.9)));
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
@ -200,51 +140,20 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
assertEquals(i, correct.size());
}
@Test
public void testMultilabelRecord() throws IOException, InterruptedException {
@DisplayName("Test Multilabel Record")
void testMultilabelRecord() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>();
// 1,3 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO, ONE,
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5),
LABEL_ONE, LABEL_ZERO,
LABEL_ONE, LABEL_ZERO));
correct.put(0, Arrays.asList(ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), LABEL_ONE, LABEL_ZERO, LABEL_ONE, LABEL_ZERO));
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2),
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO,
LABEL_ZERO, LABEL_ONE,
LABEL_ZERO, LABEL_ZERO));
correct.put(1, Arrays.asList(new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ZERO));
// 1,2,4
correct.put(2, Arrays.asList(ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ONE, LABEL_ONE,
LABEL_ZERO, LABEL_ONE));
// 1:1.0
correct.put(3, Arrays.asList(new DoubleWritable(1.0), ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ZERO, LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO));
//
correct.put(4, Arrays.asList(ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ZERO, LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO));
correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ONE, LABEL_ONE, LABEL_ZERO, LABEL_ONE));
// 1:1.0
correct.put(3, Arrays.asList(new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
//
correct.put(4, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
@ -262,63 +171,24 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
}
@Test
public void testZeroBasedIndexing() throws IOException, InterruptedException {
@DisplayName("Test Zero Based Indexing")
void testZeroBasedIndexing() throws IOException, InterruptedException {
Map<Integer, List<Writable>> correct = new HashMap<>();
// 1,3 2:1 4:2 6:3 8:4 10:5
correct.put(0, Arrays.asList(ZERO,
ZERO, ONE,
ZERO, new DoubleWritable(2),
ZERO, new DoubleWritable(3),
ZERO, new DoubleWritable(4),
ZERO, new DoubleWritable(5),
LABEL_ZERO,
LABEL_ONE, LABEL_ZERO,
LABEL_ONE, LABEL_ZERO));
correct.put(0, Arrays.asList(ZERO, ZERO, ONE, ZERO, new DoubleWritable(2), ZERO, new DoubleWritable(3), ZERO, new DoubleWritable(4), ZERO, new DoubleWritable(5), LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ONE, LABEL_ZERO));
// 2 qid:42 1:0.1 2:2 6:6.6 8:80
correct.put(1, Arrays.asList(ZERO,
new DoubleWritable(0.1), new DoubleWritable(2),
ZERO, ZERO,
ZERO, new DoubleWritable(6.6),
ZERO, new DoubleWritable(80),
ZERO, ZERO,
LABEL_ZERO,
LABEL_ZERO, LABEL_ONE,
LABEL_ZERO, LABEL_ZERO));
correct.put(1, Arrays.asList(ZERO, new DoubleWritable(0.1), new DoubleWritable(2), ZERO, ZERO, ZERO, new DoubleWritable(6.6), ZERO, new DoubleWritable(80), ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ZERO, LABEL_ZERO));
// 1,2,4
correct.put(2, Arrays.asList(ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ZERO,
LABEL_ONE, LABEL_ONE,
LABEL_ZERO, LABEL_ONE));
// 1:1.0
correct.put(3, Arrays.asList(ZERO,
new DoubleWritable(1.0), ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO));
//
correct.put(4, Arrays.asList(ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
ZERO, ZERO,
LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO,
LABEL_ZERO, LABEL_ZERO));
correct.put(2, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ONE, LABEL_ONE, LABEL_ZERO, LABEL_ONE));
// 1:1.0
correct.put(3, Arrays.asList(ZERO, new DoubleWritable(1.0), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
//
correct.put(4, Arrays.asList(ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO, LABEL_ZERO));
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
// Zero-based indexing is default
config.setBoolean(SVMLightRecordReader.ZERO_BASED_LABEL_INDEXING, true); // NOT STANDARD!
// NOT STANDARD!
config.setBoolean(SVMLightRecordReader.ZERO_BASED_LABEL_INDEXING, true);
config.setInt(SVMLightRecordReader.NUM_FEATURES, 11);
config.setBoolean(SVMLightRecordReader.MULTILABEL, true);
config.setInt(SVMLightRecordReader.NUM_LABELS, 5);
@ -333,20 +203,19 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
}
@Test
public void testNextRecord() throws IOException, InterruptedException {
@DisplayName("Test Next Record")
void testNextRecord() throws IOException, InterruptedException {
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
config.setBoolean(SVMLightRecordReader.APPEND_LABEL, false);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
Record record = rr.nextRecord();
List<Writable> recordList = record.getRecord();
assertEquals(new DoubleWritable(1.0), recordList.get(1));
assertEquals(new DoubleWritable(3.0), recordList.get(5));
assertEquals(new DoubleWritable(4.0), recordList.get(7));
record = rr.nextRecord();
recordList = record.getRecord();
assertEquals(new DoubleWritable(0.1), recordList.get(0));
@ -354,82 +223,102 @@ public class SVMLightRecordReaderTest extends BaseND4JTest {
assertEquals(new DoubleWritable(80.0), recordList.get(7));
}
@Test(expected = NoSuchElementException.class)
public void testNoSuchElementException() throws Exception {
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setInt(SVMLightRecordReader.NUM_FEATURES, 11);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
while (rr.hasNext())
@Test
@DisplayName("Test No Such Element Exception")
void testNoSuchElementException() {
assertThrows(NoSuchElementException.class, () -> {
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setInt(SVMLightRecordReader.NUM_FEATURES, 11);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
while (rr.hasNext()) rr.next();
rr.next();
rr.next();
});
}
@Test(expected = UnsupportedOperationException.class)
public void failedToSetNumFeaturesException() throws Exception {
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
while (rr.hasNext())
@Test
@DisplayName("Failed To Set Num Features Exception")
void failedToSetNumFeaturesException() {
assertThrows(UnsupportedOperationException.class, () -> {
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
while (rr.hasNext()) rr.next();
});
}
@Test
@DisplayName("Test Inconsistent Num Labels Exception")
void testInconsistentNumLabelsException() {
assertThrows(UnsupportedOperationException.class, () -> {
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/inconsistentNumLabels.txt").getFile()));
while (rr.hasNext()) rr.next();
});
}
@Test
@DisplayName("Failed To Set Num Multiabels Exception")
void failedToSetNumMultiabelsException() {
assertThrows(UnsupportedOperationException.class, () -> {
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile()));
while (rr.hasNext()) rr.next();
});
}
@Test
@DisplayName("Test Feature Index Exceeds Num Features")
void testFeatureIndexExceedsNumFeatures() {
assertThrows(IndexOutOfBoundsException.class, () -> {
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setInt(SVMLightRecordReader.NUM_FEATURES, 9);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
rr.next();
});
}
@Test(expected = UnsupportedOperationException.class)
public void testInconsistentNumLabelsException() throws Exception {
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/inconsistentNumLabels.txt").getFile()));
while (rr.hasNext())
@Test
@DisplayName("Test Label Index Exceeds Num Labels")
void testLabelIndexExceedsNumLabels() {
assertThrows(IndexOutOfBoundsException.class, () -> {
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
config.setInt(SVMLightRecordReader.NUM_LABELS, 6);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
rr.next();
});
}
@Test(expected = UnsupportedOperationException.class)
public void failedToSetNumMultiabelsException() throws Exception {
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile()));
while (rr.hasNext())
@Test
@DisplayName("Test Zero Index Feature Without Using Zero Indexing")
void testZeroIndexFeatureWithoutUsingZeroIndexing() {
assertThrows(IndexOutOfBoundsException.class, () -> {
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexFeature.txt").getFile()));
rr.next();
});
}
@Test(expected = IndexOutOfBoundsException.class)
public void testFeatureIndexExceedsNumFeatures() throws Exception {
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setInt(SVMLightRecordReader.NUM_FEATURES, 9);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
rr.next();
}
@Test(expected = IndexOutOfBoundsException.class)
public void testLabelIndexExceedsNumLabels() throws Exception {
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
config.setInt(SVMLightRecordReader.NUM_LABELS, 6);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/basic.txt").getFile()));
rr.next();
}
@Test(expected = IndexOutOfBoundsException.class)
public void testZeroIndexFeatureWithoutUsingZeroIndexing() throws Exception {
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexFeature.txt").getFile()));
rr.next();
}
@Test(expected = IndexOutOfBoundsException.class)
public void testZeroIndexLabelWithoutUsingZeroIndexing() throws Exception {
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
config.setBoolean(SVMLightRecordReader.MULTILABEL, true);
config.setInt(SVMLightRecordReader.NUM_LABELS, 2);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexLabel.txt").getFile()));
rr.next();
@Test
@DisplayName("Test Zero Index Label Without Using Zero Indexing")
void testZeroIndexLabelWithoutUsingZeroIndexing() {
assertThrows(IndexOutOfBoundsException.class, () -> {
SVMLightRecordReader rr = new SVMLightRecordReader();
Configuration config = new Configuration();
config.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
config.setBoolean(SVMLightRecordReader.MULTILABEL, true);
config.setInt(SVMLightRecordReader.NUM_LABELS, 2);
rr.initialize(config, new FileSplit(new ClassPathResource("datavec-api/svmlight/zeroIndexLabel.txt").getFile()));
rr.next();
});
}
}

View File

@ -26,14 +26,14 @@ import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.*;
import static org.junit.jupiter.api.Assertions.*;
public class TestCollectionRecordReaders extends BaseND4JTest {

View File

@ -23,11 +23,11 @@ package org.datavec.api.records.reader.impl;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestConcatenatingRecordReader extends BaseND4JTest {

View File

@ -37,7 +37,7 @@ import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.shade.jackson.core.JsonFactory;
@ -47,7 +47,7 @@ import java.io.*;
import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestSerialization extends BaseND4JTest {

View File

@ -30,7 +30,7 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
@ -38,8 +38,8 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class TransformProcessRecordReaderTests extends BaseND4JTest {

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.records.writer.impl;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
@ -26,44 +25,42 @@ import org.datavec.api.split.FileSplit;
import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Before;
import org.junit.Test;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.io.File;
import java.util.ArrayList;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals;
public class CSVRecordWriterTest extends BaseND4JTest {
@Before
public void setUp() throws Exception {
@DisplayName("Csv Record Writer Test")
class CSVRecordWriterTest extends BaseND4JTest {
@BeforeEach
void setUp() throws Exception {
}
@Test
public void testWrite() throws Exception {
@DisplayName("Test Write")
void testWrite() throws Exception {
File tempFile = File.createTempFile("datavec", "writer");
tempFile.deleteOnExit();
FileSplit fileSplit = new FileSplit(tempFile);
CSVRecordWriter writer = new CSVRecordWriter();
writer.initialize(fileSplit,new NumberOfRecordsPartitioner());
writer.initialize(fileSplit, new NumberOfRecordsPartitioner());
List<Writable> collection = new ArrayList<>();
collection.add(new Text("12"));
collection.add(new Text("13"));
collection.add(new Text("14"));
writer.write(collection);
CSVRecordReader reader = new CSVRecordReader(0);
reader.initialize(new FileSplit(tempFile));
int cnt = 0;
while (reader.hasNext()) {
List<Writable> line = new ArrayList<>(reader.next());
assertEquals(3, line.size());
assertEquals(12, line.get(0).toInt());
assertEquals(13, line.get(1).toInt());
assertEquals(14, line.get(2).toInt());

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.records.writer.impl;
import org.apache.commons.io.FileUtils;
@ -30,93 +29,90 @@ import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.io.ClassPathResource;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.Assert.assertEquals;
public class LibSvmRecordWriterTest extends BaseND4JTest {
@DisplayName("Lib Svm Record Writer Test")
class LibSvmRecordWriterTest extends BaseND4JTest {
@Test
public void testBasic() throws Exception {
@DisplayName("Test Basic")
void testBasic() throws Exception {
Configuration configWriter = new Configuration();
Configuration configReader = new Configuration();
configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
File inputFile = new ClassPathResource("datavec-api/svmlight/basic.txt").getFile();
executeTest(configWriter, configReader, inputFile);
}
@Test
public void testNoLabel() throws Exception {
@DisplayName("Test No Label")
void testNoLabel() throws Exception {
Configuration configWriter = new Configuration();
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 9);
Configuration configReader = new Configuration();
configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
File inputFile = new ClassPathResource("datavec-api/svmlight/basic.txt").getFile();
executeTest(configWriter, configReader, inputFile);
}
@Test
public void testMultioutputRecord() throws Exception {
@DisplayName("Test Multioutput Record")
void testMultioutputRecord() throws Exception {
Configuration configWriter = new Configuration();
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 9);
Configuration configReader = new Configuration();
configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
File inputFile = new ClassPathResource("datavec-api/svmlight/multioutput.txt").getFile();
executeTest(configWriter, configReader, inputFile);
}
@Test
public void testMultilabelRecord() throws Exception {
@DisplayName("Test Multilabel Record")
void testMultilabelRecord() throws Exception {
Configuration configWriter = new Configuration();
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 9);
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
Configuration configReader = new Configuration();
configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 10);
configReader.setBoolean(LibSvmRecordReader.MULTILABEL, true);
configReader.setInt(LibSvmRecordReader.NUM_LABELS, 4);
configReader.setBoolean(LibSvmRecordReader.ZERO_BASED_INDEXING, false);
File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile();
executeTest(configWriter, configReader, inputFile);
}
@Test
public void testZeroBasedIndexing() throws Exception {
@DisplayName("Test Zero Based Indexing")
void testZeroBasedIndexing() throws Exception {
Configuration configWriter = new Configuration();
configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_INDEXING, true);
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 10);
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
Configuration configReader = new Configuration();
configReader.setInt(LibSvmRecordReader.NUM_FEATURES, 11);
configReader.setBoolean(LibSvmRecordReader.MULTILABEL, true);
configReader.setInt(LibSvmRecordReader.NUM_LABELS, 5);
File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile();
executeTest(configWriter, configReader, inputFile);
}
@ -127,10 +123,9 @@ public class LibSvmRecordWriterTest extends BaseND4JTest {
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
LibSvmRecordReader rr = new LibSvmRecordReader();
rr.initialize(configReader, new FileSplit(inputFile));
while (rr.hasNext()) {
@ -138,7 +133,6 @@ public class LibSvmRecordWriterTest extends BaseND4JTest {
writer.write(record);
}
}
Pattern p = Pattern.compile(String.format("%s:\\d+ ", LibSvmRecordReader.QID_PREFIX));
List<String> linesOriginal = new ArrayList<>();
for (String line : FileUtils.readLines(inputFile)) {
@ -159,7 +153,8 @@ public class LibSvmRecordWriterTest extends BaseND4JTest {
}
@Test
public void testNDArrayWritables() throws Exception {
@DisplayName("Test ND Array Writables")
void testNDArrayWritables() throws Exception {
INDArray arr2 = Nd4j.zeros(2);
arr2.putScalar(0, 11);
arr2.putScalar(1, 12);
@ -167,35 +162,28 @@ public class LibSvmRecordWriterTest extends BaseND4JTest {
arr3.putScalar(0, 13);
arr3.putScalar(1, 14);
arr3.putScalar(2, 15);
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1),
new NDArrayWritable(arr2),
new IntWritable(2),
new DoubleWritable(3),
new NDArrayWritable(arr3),
new IntWritable(4));
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new IntWritable(4));
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
tempFile.setWritable(true);
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
String lineOriginal = "13.0,14.0,15.0,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0";
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
Configuration configWriter = new Configuration();
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 3);
FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record);
}
String lineNew = FileUtils.readFileToString(tempFile).trim();
assertEquals(lineOriginal, lineNew);
}
@Test
public void testNDArrayWritablesMultilabel() throws Exception {
@DisplayName("Test ND Array Writables Multilabel")
void testNDArrayWritablesMultilabel() throws Exception {
INDArray arr2 = Nd4j.zeros(2);
arr2.putScalar(0, 11);
arr2.putScalar(1, 12);
@ -203,36 +191,29 @@ public class LibSvmRecordWriterTest extends BaseND4JTest {
arr3.putScalar(0, 0);
arr3.putScalar(1, 1);
arr3.putScalar(2, 0);
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1),
new NDArrayWritable(arr2),
new IntWritable(2),
new DoubleWritable(3),
new NDArrayWritable(arr3),
new DoubleWritable(1));
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new DoubleWritable(1));
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
tempFile.setWritable(true);
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
String lineOriginal = "2,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0";
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
Configuration configWriter = new Configuration();
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 3);
FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record);
}
String lineNew = FileUtils.readFileToString(tempFile).trim();
assertEquals(lineOriginal, lineNew);
}
@Test
public void testNDArrayWritablesZeroIndex() throws Exception {
@DisplayName("Test ND Array Writables Zero Index")
void testNDArrayWritablesZeroIndex() throws Exception {
INDArray arr2 = Nd4j.zeros(2);
arr2.putScalar(0, 11);
arr2.putScalar(1, 12);
@ -240,99 +221,91 @@ public class LibSvmRecordWriterTest extends BaseND4JTest {
arr3.putScalar(0, 0);
arr3.putScalar(1, 1);
arr3.putScalar(2, 0);
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1),
new NDArrayWritable(arr2),
new IntWritable(2),
new DoubleWritable(3),
new NDArrayWritable(arr3),
new DoubleWritable(1));
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new DoubleWritable(1));
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
tempFile.setWritable(true);
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
String lineOriginal = "1,3 0:1.0 1:11.0 2:12.0 3:2.0 4:3.0";
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
Configuration configWriter = new Configuration();
configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_INDEXING, true); // NOT STANDARD!
configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_LABEL_INDEXING, true); // NOT STANDARD!
// NOT STANDARD!
configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_INDEXING, true);
// NOT STANDARD!
configWriter.setBoolean(LibSvmRecordWriter.ZERO_BASED_LABEL_INDEXING, true);
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 3);
FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record);
}
String lineNew = FileUtils.readFileToString(tempFile).trim();
assertEquals(lineOriginal, lineNew);
}
@Test
public void testNonIntegerButValidMultilabel() throws Exception {
List<Writable> record = Arrays.asList((Writable) new IntWritable(3),
new IntWritable(2),
new DoubleWritable(1.0));
@DisplayName("Test Non Integer But Valid Multilabel")
void testNonIntegerButValidMultilabel() throws Exception {
List<Writable> record = Arrays.asList((Writable) new IntWritable(3), new IntWritable(2), new DoubleWritable(1.0));
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
tempFile.setWritable(true);
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
Configuration configWriter = new Configuration();
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1);
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record);
}
}
@Test(expected = NumberFormatException.class)
public void nonIntegerMultilabel() throws Exception {
List<Writable> record = Arrays.asList((Writable) new IntWritable(3),
new IntWritable(2),
new DoubleWritable(1.2));
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
tempFile.setWritable(true);
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
Configuration configWriter = new Configuration();
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1);
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
writer.write(record);
}
@Test
@DisplayName("Non Integer Multilabel")
void nonIntegerMultilabel() {
assertThrows(NumberFormatException.class, () -> {
List<Writable> record = Arrays.asList((Writable) new IntWritable(3), new IntWritable(2), new DoubleWritable(1.2));
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
tempFile.setWritable(true);
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
Configuration configWriter = new Configuration();
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1);
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record);
}
});
}
@Test(expected = NumberFormatException.class)
public void nonBinaryMultilabel() throws Exception {
List<Writable> record = Arrays.asList((Writable) new IntWritable(0),
new IntWritable(1),
new IntWritable(2));
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
tempFile.setWritable(true);
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
Configuration configWriter = new Configuration();
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN,0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN,1);
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL,true);
FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
writer.write(record);
}
@Test
@DisplayName("Non Binary Multilabel")
void nonBinaryMultilabel() {
assertThrows(NumberFormatException.class, () -> {
List<Writable> record = Arrays.asList((Writable) new IntWritable(0), new IntWritable(1), new IntWritable(2));
File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt");
tempFile.setWritable(true);
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
try (LibSvmRecordWriter writer = new LibSvmRecordWriter()) {
Configuration configWriter = new Configuration();
configWriter.setInt(LibSvmRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(LibSvmRecordWriter.FEATURE_LAST_COLUMN, 1);
configWriter.setBoolean(LibSvmRecordWriter.MULTILABEL, true);
FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record);
}
});
}
}

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.records.writer.impl;
import org.apache.commons.io.FileUtils;
@ -27,93 +26,90 @@ import org.datavec.api.records.writer.impl.misc.SVMLightRecordWriter;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
import org.datavec.api.writable.*;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.io.ClassPathResource;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.Assert.assertEquals;
public class SVMLightRecordWriterTest extends BaseND4JTest {
@DisplayName("Svm Light Record Writer Test")
class SVMLightRecordWriterTest extends BaseND4JTest {
@Test
public void testBasic() throws Exception {
@DisplayName("Test Basic")
void testBasic() throws Exception {
Configuration configWriter = new Configuration();
Configuration configReader = new Configuration();
configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
File inputFile = new ClassPathResource("datavec-api/svmlight/basic.txt").getFile();
executeTest(configWriter, configReader, inputFile);
}
@Test
public void testNoLabel() throws Exception {
@DisplayName("Test No Label")
void testNoLabel() throws Exception {
Configuration configWriter = new Configuration();
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 9);
Configuration configReader = new Configuration();
configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
File inputFile = new ClassPathResource("datavec-api/svmlight/noLabels.txt").getFile();
executeTest(configWriter, configReader, inputFile);
}
@Test
public void testMultioutputRecord() throws Exception {
@DisplayName("Test Multioutput Record")
void testMultioutputRecord() throws Exception {
Configuration configWriter = new Configuration();
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 9);
Configuration configReader = new Configuration();
configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
File inputFile = new ClassPathResource("datavec-api/svmlight/multioutput.txt").getFile();
executeTest(configWriter, configReader, inputFile);
}
@Test
public void testMultilabelRecord() throws Exception {
@DisplayName("Test Multilabel Record")
void testMultilabelRecord() throws Exception {
Configuration configWriter = new Configuration();
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 9);
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
Configuration configReader = new Configuration();
configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 10);
configReader.setBoolean(SVMLightRecordReader.MULTILABEL, true);
configReader.setInt(SVMLightRecordReader.NUM_LABELS, 4);
configReader.setBoolean(SVMLightRecordReader.ZERO_BASED_INDEXING, false);
File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile();
executeTest(configWriter, configReader, inputFile);
}
@Test
public void testZeroBasedIndexing() throws Exception {
@DisplayName("Test Zero Based Indexing")
void testZeroBasedIndexing() throws Exception {
Configuration configWriter = new Configuration();
configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_INDEXING, true);
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 10);
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
Configuration configReader = new Configuration();
configReader.setInt(SVMLightRecordReader.NUM_FEATURES, 11);
configReader.setBoolean(SVMLightRecordReader.MULTILABEL, true);
configReader.setInt(SVMLightRecordReader.NUM_LABELS, 5);
File inputFile = new ClassPathResource("datavec-api/svmlight/multilabel.txt").getFile();
executeTest(configWriter, configReader, inputFile);
}
@ -124,10 +120,9 @@ public class SVMLightRecordWriterTest extends BaseND4JTest {
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
SVMLightRecordReader rr = new SVMLightRecordReader();
rr.initialize(configReader, new FileSplit(inputFile));
while (rr.hasNext()) {
@ -135,7 +130,6 @@ public class SVMLightRecordWriterTest extends BaseND4JTest {
writer.write(record);
}
}
Pattern p = Pattern.compile(String.format("%s:\\d+ ", SVMLightRecordReader.QID_PREFIX));
List<String> linesOriginal = new ArrayList<>();
for (String line : FileUtils.readLines(inputFile)) {
@ -156,7 +150,8 @@ public class SVMLightRecordWriterTest extends BaseND4JTest {
}
@Test
public void testNDArrayWritables() throws Exception {
@DisplayName("Test ND Array Writables")
void testNDArrayWritables() throws Exception {
INDArray arr2 = Nd4j.zeros(2);
arr2.putScalar(0, 11);
arr2.putScalar(1, 12);
@ -164,35 +159,28 @@ public class SVMLightRecordWriterTest extends BaseND4JTest {
arr3.putScalar(0, 13);
arr3.putScalar(1, 14);
arr3.putScalar(2, 15);
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1),
new NDArrayWritable(arr2),
new IntWritable(2),
new DoubleWritable(3),
new NDArrayWritable(arr3),
new IntWritable(4));
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new IntWritable(4));
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
tempFile.setWritable(true);
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
String lineOriginal = "13.0,14.0,15.0,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0";
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
Configuration configWriter = new Configuration();
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 3);
FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record);
}
String lineNew = FileUtils.readFileToString(tempFile).trim();
assertEquals(lineOriginal, lineNew);
}
@Test
public void testNDArrayWritablesMultilabel() throws Exception {
@DisplayName("Test ND Array Writables Multilabel")
void testNDArrayWritablesMultilabel() throws Exception {
INDArray arr2 = Nd4j.zeros(2);
arr2.putScalar(0, 11);
arr2.putScalar(1, 12);
@ -200,36 +188,29 @@ public class SVMLightRecordWriterTest extends BaseND4JTest {
arr3.putScalar(0, 0);
arr3.putScalar(1, 1);
arr3.putScalar(2, 0);
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1),
new NDArrayWritable(arr2),
new IntWritable(2),
new DoubleWritable(3),
new NDArrayWritable(arr3),
new DoubleWritable(1));
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new DoubleWritable(1));
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
tempFile.setWritable(true);
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
String lineOriginal = "2,4 1:1.0 2:11.0 3:12.0 4:2.0 5:3.0";
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
Configuration configWriter = new Configuration();
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 3);
FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record);
}
String lineNew = FileUtils.readFileToString(tempFile).trim();
assertEquals(lineOriginal, lineNew);
}
@Test
public void testNDArrayWritablesZeroIndex() throws Exception {
@DisplayName("Test ND Array Writables Zero Index")
void testNDArrayWritablesZeroIndex() throws Exception {
INDArray arr2 = Nd4j.zeros(2);
arr2.putScalar(0, 11);
arr2.putScalar(1, 12);
@ -237,99 +218,91 @@ public class SVMLightRecordWriterTest extends BaseND4JTest {
arr3.putScalar(0, 0);
arr3.putScalar(1, 1);
arr3.putScalar(2, 0);
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1),
new NDArrayWritable(arr2),
new IntWritable(2),
new DoubleWritable(3),
new NDArrayWritable(arr3),
new DoubleWritable(1));
List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), new NDArrayWritable(arr3), new DoubleWritable(1));
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
tempFile.setWritable(true);
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
String lineOriginal = "1,3 0:1.0 1:11.0 2:12.0 3:2.0 4:3.0";
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
Configuration configWriter = new Configuration();
configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_INDEXING, true); // NOT STANDARD!
configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_LABEL_INDEXING, true); // NOT STANDARD!
// NOT STANDARD!
configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_INDEXING, true);
// NOT STANDARD!
configWriter.setBoolean(SVMLightRecordWriter.ZERO_BASED_LABEL_INDEXING, true);
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 3);
FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record);
}
String lineNew = FileUtils.readFileToString(tempFile).trim();
assertEquals(lineOriginal, lineNew);
}
@Test
public void testNonIntegerButValidMultilabel() throws Exception {
List<Writable> record = Arrays.asList((Writable) new IntWritable(3),
new IntWritable(2),
new DoubleWritable(1.0));
@DisplayName("Test Non Integer But Valid Multilabel")
void testNonIntegerButValidMultilabel() throws Exception {
List<Writable> record = Arrays.asList((Writable) new IntWritable(3), new IntWritable(2), new DoubleWritable(1.0));
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
tempFile.setWritable(true);
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
Configuration configWriter = new Configuration();
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 1);
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record);
}
}
@Test(expected = NumberFormatException.class)
public void nonIntegerMultilabel() throws Exception {
List<Writable> record = Arrays.asList((Writable) new IntWritable(3),
new IntWritable(2),
new DoubleWritable(1.2));
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
tempFile.setWritable(true);
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
Configuration configWriter = new Configuration();
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 1);
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
writer.write(record);
}
@Test
@DisplayName("Non Integer Multilabel")
void nonIntegerMultilabel() {
assertThrows(NumberFormatException.class, () -> {
List<Writable> record = Arrays.asList((Writable) new IntWritable(3), new IntWritable(2), new DoubleWritable(1.2));
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
tempFile.setWritable(true);
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
Configuration configWriter = new Configuration();
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 1);
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record);
}
});
}
@Test(expected = NumberFormatException.class)
public void nonBinaryMultilabel() throws Exception {
List<Writable> record = Arrays.asList((Writable) new IntWritable(0),
new IntWritable(1),
new IntWritable(2));
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
tempFile.setWritable(true);
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
Configuration configWriter = new Configuration();
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 1);
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter,outputSplit,new NumberOfRecordsPartitioner());
writer.write(record);
}
@Test
@DisplayName("Non Binary Multilabel")
void nonBinaryMultilabel() {
assertThrows(NumberFormatException.class, () -> {
List<Writable> record = Arrays.asList((Writable) new IntWritable(0), new IntWritable(1), new IntWritable(2));
File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt");
tempFile.setWritable(true);
tempFile.deleteOnExit();
if (tempFile.exists())
tempFile.delete();
try (SVMLightRecordWriter writer = new SVMLightRecordWriter()) {
Configuration configWriter = new Configuration();
configWriter.setInt(SVMLightRecordWriter.FEATURE_FIRST_COLUMN, 0);
configWriter.setInt(SVMLightRecordWriter.FEATURE_LAST_COLUMN, 1);
configWriter.setBoolean(SVMLightRecordWriter.MULTILABEL, true);
FileSplit outputSplit = new FileSplit(tempFile);
writer.initialize(configWriter, outputSplit, new NumberOfRecordsPartitioner());
writer.write(record);
}
});
}
}

View File

@ -26,7 +26,7 @@ import org.datavec.api.io.filters.BalancedPathFilter;
import org.datavec.api.io.filters.RandomPathFilter;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.io.labels.PatternPathLabelGenerator;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import java.io.*;
import java.net.URI;
@ -34,8 +34,9 @@ import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.Random;
import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
/**
*

View File

@ -20,13 +20,12 @@
package org.datavec.api.split;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.net.URI;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.*;
public class NumberedFileInputSplitTests extends BaseND4JTest {
@Test
@ -69,60 +68,81 @@ public class NumberedFileInputSplitTests extends BaseND4JTest {
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
}
@Test(expected = IllegalArgumentException.class)
@Test()
public void testNumberedFileInputSplitWithLeadingSpaces() {
String baseString = "/path/to/files/prefix-%5d.suffix";
int minIdx = 0;
int maxIdx = 10;
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
assertThrows(IllegalArgumentException.class,() -> {
String baseString = "/path/to/files/prefix-%5d.suffix";
int minIdx = 0;
int maxIdx = 10;
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
});
}
@Test(expected = IllegalArgumentException.class)
@Test()
public void testNumberedFileInputSplitWithNoLeadingZeroInPadding() {
String baseString = "/path/to/files/prefix%5d.suffix";
int minIdx = 0;
int maxIdx = 10;
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
assertThrows(IllegalArgumentException.class, () -> {
String baseString = "/path/to/files/prefix%5d.suffix";
int minIdx = 0;
int maxIdx = 10;
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
});
}
@Test(expected = IllegalArgumentException.class)
@Test()
public void testNumberedFileInputSplitWithLeadingPlusInPadding() {
String baseString = "/path/to/files/prefix%+5d.suffix";
int minIdx = 0;
int maxIdx = 10;
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
assertThrows(IllegalArgumentException.class,() -> {
String baseString = "/path/to/files/prefix%+5d.suffix";
int minIdx = 0;
int maxIdx = 10;
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
});
}
@Test(expected = IllegalArgumentException.class)
@Test()
public void testNumberedFileInputSplitWithLeadingMinusInPadding() {
String baseString = "/path/to/files/prefix%-5d.suffix";
int minIdx = 0;
int maxIdx = 10;
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
assertThrows(IllegalArgumentException.class,() -> {
String baseString = "/path/to/files/prefix%-5d.suffix";
int minIdx = 0;
int maxIdx = 10;
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
});
}
@Test(expected = IllegalArgumentException.class)
@Test()
public void testNumberedFileInputSplitWithTwoDigitsInPadding() {
String baseString = "/path/to/files/prefix%011d.suffix";
int minIdx = 0;
int maxIdx = 10;
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
assertThrows(IllegalArgumentException.class,() -> {
String baseString = "/path/to/files/prefix%011d.suffix";
int minIdx = 0;
int maxIdx = 10;
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
});
}
@Test(expected = IllegalArgumentException.class)
@Test()
public void testNumberedFileInputSplitWithInnerZerosInPadding() {
String baseString = "/path/to/files/prefix%101d.suffix";
int minIdx = 0;
int maxIdx = 10;
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
assertThrows(IllegalArgumentException.class,() -> {
String baseString = "/path/to/files/prefix%101d.suffix";
int minIdx = 0;
int maxIdx = 10;
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
});
}
@Test(expected = IllegalArgumentException.class)
@Test()
public void testNumberedFileInputSplitWithRepeatInnerZerosInPadding() {
String baseString = "/path/to/files/prefix%0505d.suffix";
int minIdx = 0;
int maxIdx = 10;
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
assertThrows(IllegalArgumentException.class,() -> {
String baseString = "/path/to/files/prefix%0505d.suffix";
int minIdx = 0;
int maxIdx = 10;
runNumberedFileInputSplitTest(baseString, minIdx, maxIdx);
});
}
@ -135,7 +155,7 @@ public class NumberedFileInputSplitTests extends BaseND4JTest {
String path = locs[j++].getPath();
String exp = String.format(baseString, i);
String msg = exp + " vs " + path;
assertTrue(msg, path.endsWith(exp)); //Note: on Windows, Java can prepend drive to path - "/C:/"
assertTrue(path.endsWith(exp),msg); //Note: on Windows, Java can prepend drive to path - "/C:/"
}
}
}

View File

@ -25,9 +25,10 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.function.Function;
@ -37,22 +38,22 @@ import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
public class TestStreamInputSplit extends BaseND4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test
public void testCsvSimple() throws Exception {
File dir = testDir.newFolder();
public void testCsvSimple(@TempDir Path testDir) throws Exception {
File dir = testDir.toFile();
File f1 = new File(dir, "file1.txt");
File f2 = new File(dir, "file2.txt");
@ -93,9 +94,9 @@ public class TestStreamInputSplit extends BaseND4JTest {
@Test
public void testCsvSequenceSimple() throws Exception {
public void testCsvSequenceSimple(@TempDir Path testDir) throws Exception {
File dir = testDir.newFolder();
File dir = testDir.toFile();
File f1 = new File(dir, "file1.txt");
File f2 = new File(dir, "file2.txt");
@ -137,8 +138,8 @@ public class TestStreamInputSplit extends BaseND4JTest {
}
@Test
public void testShuffle() throws Exception {
File dir = testDir.newFolder();
public void testShuffle(@TempDir Path testDir) throws Exception {
File dir = testDir.toFile();
File f1 = new File(dir, "file1.txt");
File f2 = new File(dir, "file2.txt");
File f3 = new File(dir, "file3.txt");

View File

@ -17,44 +17,43 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.split;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Collection;
import static java.util.Arrays.asList;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
/**
* @author Ede Meijer
*/
public class TransformSplitTest extends BaseND4JTest {
@Test
public void testTransform() throws URISyntaxException {
Collection<URI> inputFiles = asList(new URI("file:///foo/bar/../0.csv"), new URI("file:///foo/1.csv"));
@DisplayName("Transform Split Test")
class TransformSplitTest extends BaseND4JTest {
@Test
@DisplayName("Test Transform")
void testTransform() throws URISyntaxException {
Collection<URI> inputFiles = asList(new URI("file:///foo/bar/../0.csv"), new URI("file:///foo/1.csv"));
InputSplit SUT = new TransformSplit(new CollectionInputSplit(inputFiles), new TransformSplit.URITransform() {
@Override
public URI apply(URI uri) throws URISyntaxException {
return uri.normalize();
}
});
assertArrayEquals(new URI[] {new URI("file:///foo/0.csv"), new URI("file:///foo/1.csv")}, SUT.locations());
assertArrayEquals(new URI[] { new URI("file:///foo/0.csv"), new URI("file:///foo/1.csv") }, SUT.locations());
}
@Test
public void testSearchReplace() throws URISyntaxException {
@DisplayName("Test Search Replace")
void testSearchReplace() throws URISyntaxException {
Collection<URI> inputFiles = asList(new URI("file:///foo/1-in.csv"), new URI("file:///foo/2-in.csv"));
InputSplit SUT = TransformSplit.ofSearchReplace(new CollectionInputSplit(inputFiles), "-in.csv", "-out.csv");
assertArrayEquals(new URI[] {new URI("file:///foo/1-out.csv"), new URI("file:///foo/2-out.csv")},
SUT.locations());
assertArrayEquals(new URI[] { new URI("file:///foo/1-out.csv"), new URI("file:///foo/2-out.csv") }, SUT.locations());
}
}

View File

@ -27,14 +27,12 @@ import org.datavec.api.split.FileSplit;
import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
import org.datavec.api.split.partition.PartitionMetaData;
import org.datavec.api.split.partition.Partitioner;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import java.io.File;
import java.io.OutputStream;
import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.jupiter.api.Assertions.*;
public class PartitionerTests extends BaseND4JTest {
@Test

View File

@ -29,12 +29,12 @@ import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.*;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestTransformProcess extends BaseND4JTest {

View File

@ -27,13 +27,13 @@ import org.datavec.api.transform.condition.string.StringRegexColumnCondition;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.transform.TestTransforms;
import org.datavec.api.writable.*;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.*;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestConditions extends BaseND4JTest {

View File

@ -27,7 +27,7 @@ import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.ArrayList;
@ -36,8 +36,8 @@ import java.util.Collections;
import java.util.List;
import static java.util.Arrays.asList;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestFilters extends BaseND4JTest {

View File

@ -26,19 +26,22 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.NullWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
public class TestJoin extends BaseND4JTest {
@Test
public void testJoin() {
public void testJoin(@TempDir Path testDir) {
Schema firstSchema =
new Schema.Builder().addColumnString("keyColumn").addColumnsInteger("first0", "first1").build();
@ -46,20 +49,20 @@ public class TestJoin extends BaseND4JTest {
Schema secondSchema = new Schema.Builder().addColumnString("keyColumn").addColumnsInteger("second0").build();
List<List<Writable>> first = new ArrayList<>();
first.add(Arrays.asList((Writable) new Text("key0"), new IntWritable(0), new IntWritable(1)));
first.add(Arrays.asList((Writable) new Text("key1"), new IntWritable(10), new IntWritable(11)));
first.add(Arrays.asList(new Text("key0"), new IntWritable(0), new IntWritable(1)));
first.add(Arrays.asList(new Text("key1"), new IntWritable(10), new IntWritable(11)));
List<List<Writable>> second = new ArrayList<>();
second.add(Arrays.asList((Writable) new Text("key0"), new IntWritable(100)));
second.add(Arrays.asList((Writable) new Text("key1"), new IntWritable(110)));
second.add(Arrays.asList(new Text("key0"), new IntWritable(100)));
second.add(Arrays.asList(new Text("key1"), new IntWritable(110)));
Join join = new Join.Builder(Join.JoinType.Inner).setJoinColumns("keyColumn")
.setSchemas(firstSchema, secondSchema).build();
List<List<Writable>> expected = new ArrayList<>();
expected.add(Arrays.asList((Writable) new Text("key0"), new IntWritable(0), new IntWritable(1),
expected.add(Arrays.asList(new Text("key0"), new IntWritable(0), new IntWritable(1),
new IntWritable(100)));
expected.add(Arrays.asList((Writable) new Text("key1"), new IntWritable(10), new IntWritable(11),
expected.add(Arrays.asList(new Text("key1"), new IntWritable(10), new IntWritable(11),
new IntWritable(110)));
@ -94,27 +97,31 @@ public class TestJoin extends BaseND4JTest {
}
@Test(expected = IllegalArgumentException.class)
@Test()
public void testJoinValidation() {
assertThrows(IllegalArgumentException.class,() -> {
Schema firstSchema = new Schema.Builder().addColumnString("keyColumn1").addColumnsInteger("first0", "first1")
.build();
Schema firstSchema = new Schema.Builder().addColumnString("keyColumn1").addColumnsInteger("first0", "first1")
.build();
Schema secondSchema = new Schema.Builder().addColumnString("keyColumn2").addColumnsInteger("second0").build();
Schema secondSchema = new Schema.Builder().addColumnString("keyColumn2").addColumnsInteger("second0").build();
new Join.Builder(Join.JoinType.Inner).setJoinColumns("keyColumn1", "thisDoesntExist")
.setSchemas(firstSchema, secondSchema).build();
});
new Join.Builder(Join.JoinType.Inner).setJoinColumns("keyColumn1", "thisDoesntExist")
.setSchemas(firstSchema, secondSchema).build();
}
@Test(expected = IllegalArgumentException.class)
@Test()
public void testJoinValidation2() {
assertThrows(IllegalArgumentException.class,() -> {
Schema firstSchema = new Schema.Builder().addColumnString("keyColumn1").addColumnsInteger("first0", "first1")
.build();
Schema firstSchema = new Schema.Builder().addColumnString("keyColumn1").addColumnsInteger("first0", "first1")
.build();
Schema secondSchema = new Schema.Builder().addColumnString("keyColumn2").addColumnsInteger("second0").build();
Schema secondSchema = new Schema.Builder().addColumnString("keyColumn2").addColumnsInteger("second0").build();
new Join.Builder(Join.JoinType.Inner).setJoinColumns("keyColumn1").setSchemas(firstSchema, secondSchema)
.build();
});
new Join.Builder(Join.JoinType.Inner).setJoinColumns("keyColumn1").setSchemas(firstSchema, secondSchema)
.build();
}
}

View File

@ -17,32 +17,25 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.transform.ops;
import com.tngtech.archunit.core.importer.ImportOption;
import com.tngtech.archunit.junit.AnalyzeClasses;
import com.tngtech.archunit.junit.ArchTest;
import com.tngtech.archunit.junit.ArchUnitRunner;
import com.tngtech.archunit.lang.ArchRule;
import com.tngtech.archunit.lang.extension.ArchUnitExtension;
import com.tngtech.archunit.lang.extension.ArchUnitExtensions;
import org.junit.runner.RunWith;
import org.nd4j.common.tests.BaseND4JTest;
import java.io.Serializable;
import static com.tngtech.archunit.lang.syntax.ArchRuleDefinition.classes;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
@RunWith(ArchUnitRunner.class)
@AnalyzeClasses(packages = "org.datavec.api.transform.ops", importOptions = {ImportOption.DoNotIncludeTests.class})
public class AggregableMultiOpArchTest extends BaseND4JTest {
@AnalyzeClasses(packages = "org.datavec.api.transform.ops", importOptions = { ImportOption.DoNotIncludeTests.class })
@DisplayName("Aggregable Multi Op Arch Test")
class AggregableMultiOpArchTest extends BaseND4JTest {
@ArchTest
public static final ArchRule ALL_AGGREGATE_OPS_MUST_BE_SERIALIZABLE = classes()
.that().resideInAPackage("org.datavec.api.transform.ops")
.and().doNotHaveSimpleName("AggregatorImpls")
.and().doNotHaveSimpleName("IAggregableReduceOp")
.and().doNotHaveSimpleName("StringAggregatorImpls")
.and().doNotHaveFullyQualifiedName("org.datavec.api.transform.ops.StringAggregatorImpls$1")
.should().implement(Serializable.class)
.because("All aggregate ops must be serializable.");
}
public static final ArchRule ALL_AGGREGATE_OPS_MUST_BE_SERIALIZABLE = classes().that().resideInAPackage("org.datavec.api.transform.ops").and().doNotHaveSimpleName("AggregatorImpls").and().doNotHaveSimpleName("IAggregableReduceOp").and().doNotHaveSimpleName("StringAggregatorImpls").and().doNotHaveFullyQualifiedName("org.datavec.api.transform.ops.StringAggregatorImpls$1").should().implement(Serializable.class).because("All aggregate ops must be serializable.");
}

View File

@ -17,52 +17,46 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.transform.ops;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.*;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertTrue;
public class AggregableMultiOpTest extends BaseND4JTest {
@DisplayName("Aggregable Multi Op Test")
class AggregableMultiOpTest extends BaseND4JTest {
private List<Integer> intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
@Test
public void testMulti() throws Exception {
@DisplayName("Test Multi")
void testMulti() throws Exception {
AggregatorImpls.AggregableFirst<Integer> af = new AggregatorImpls.AggregableFirst<>();
AggregatorImpls.AggregableSum<Integer> as = new AggregatorImpls.AggregableSum<>();
AggregableMultiOp<Integer> multi = new AggregableMultiOp<>(Arrays.asList(af, as));
assertTrue(multi.getOperations().size() == 2);
for (int i = 0; i < intList.size(); i++) {
multi.accept(intList.get(i));
}
// mutablility
assertTrue(as.get().toDouble() == 45D);
assertTrue(af.get().toInt() == 1);
List<Writable> res = multi.get();
assertTrue(res.get(1).toDouble() == 45D);
assertTrue(res.get(0).toInt() == 1);
AggregatorImpls.AggregableFirst<Integer> rf = new AggregatorImpls.AggregableFirst<>();
AggregatorImpls.AggregableSum<Integer> rs = new AggregatorImpls.AggregableSum<>();
AggregableMultiOp<Integer> reverse = new AggregableMultiOp<>(Arrays.asList(rf, rs));
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
}
List<Writable> revRes = reverse.get();
assertTrue(revRes.get(1).toDouble() == 45D);
assertTrue(revRes.get(0).toInt() == 9);
multi.combine(reverse);
List<Writable> combinedRes = multi.get();
assertTrue(combinedRes.get(1).toDouble() == 90D);

View File

@ -17,41 +17,39 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.transform.ops;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.nd4j.common.tests.BaseND4JTest;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import org.junit.jupiter.api.DisplayName;
public class AggregatorImplsTest extends BaseND4JTest {
import static org.junit.jupiter.api.Assertions.*;
@DisplayName("Aggregator Impls Test")
class AggregatorImplsTest extends BaseND4JTest {
private List<Integer> intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
private List<String> stringList = new ArrayList<>(Arrays.asList("arakoa", "abracadabra", "blast", "acceptance"));
@Test
public void aggregableFirstTest() {
@DisplayName("Aggregable First Test")
void aggregableFirstTest() {
AggregatorImpls.AggregableFirst<Integer> first = new AggregatorImpls.AggregableFirst<>();
for (int i = 0; i < intList.size(); i++) {
first.accept(intList.get(i));
}
assertEquals(1, first.get().toInt());
AggregatorImpls.AggregableFirst<String> firstS = new AggregatorImpls.AggregableFirst<>();
for (int i = 0; i < stringList.size(); i++) {
firstS.accept(stringList.get(i));
}
assertTrue(firstS.get().toString().equals("arakoa"));
AggregatorImpls.AggregableFirst<Integer> reverse = new AggregatorImpls.AggregableFirst<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
@ -60,22 +58,19 @@ public class AggregatorImplsTest extends BaseND4JTest {
assertEquals(1, first.get().toInt());
}
@Test
public void aggregableLastTest() {
@DisplayName("Aggregable Last Test")
void aggregableLastTest() {
AggregatorImpls.AggregableLast<Integer> last = new AggregatorImpls.AggregableLast<>();
for (int i = 0; i < intList.size(); i++) {
last.accept(intList.get(i));
}
assertEquals(9, last.get().toInt());
AggregatorImpls.AggregableLast<String> lastS = new AggregatorImpls.AggregableLast<>();
for (int i = 0; i < stringList.size(); i++) {
lastS.accept(stringList.get(i));
}
assertTrue(lastS.get().toString().equals("acceptance"));
AggregatorImpls.AggregableLast<Integer> reverse = new AggregatorImpls.AggregableLast<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
@ -85,20 +80,18 @@ public class AggregatorImplsTest extends BaseND4JTest {
}
@Test
public void aggregableCountTest() {
@DisplayName("Aggregable Count Test")
void aggregableCountTest() {
AggregatorImpls.AggregableCount<Integer> cnt = new AggregatorImpls.AggregableCount<>();
for (int i = 0; i < intList.size(); i++) {
cnt.accept(intList.get(i));
}
assertEquals(9, cnt.get().toInt());
AggregatorImpls.AggregableCount<String> lastS = new AggregatorImpls.AggregableCount<>();
for (int i = 0; i < stringList.size(); i++) {
lastS.accept(stringList.get(i));
}
assertEquals(4, lastS.get().toInt());
AggregatorImpls.AggregableCount<Integer> reverse = new AggregatorImpls.AggregableCount<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
@ -108,14 +101,13 @@ public class AggregatorImplsTest extends BaseND4JTest {
}
@Test
public void aggregableMaxTest() {
@DisplayName("Aggregable Max Test")
void aggregableMaxTest() {
AggregatorImpls.AggregableMax<Integer> mx = new AggregatorImpls.AggregableMax<>();
for (int i = 0; i < intList.size(); i++) {
mx.accept(intList.get(i));
}
assertEquals(9, mx.get().toInt());
AggregatorImpls.AggregableMax<Integer> reverse = new AggregatorImpls.AggregableMax<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
@ -124,16 +116,14 @@ public class AggregatorImplsTest extends BaseND4JTest {
assertEquals(9, mx.get().toInt());
}
@Test
public void aggregableRangeTest() {
@DisplayName("Aggregable Range Test")
void aggregableRangeTest() {
AggregatorImpls.AggregableRange<Integer> mx = new AggregatorImpls.AggregableRange<>();
for (int i = 0; i < intList.size(); i++) {
mx.accept(intList.get(i));
}
assertEquals(8, mx.get().toInt());
AggregatorImpls.AggregableRange<Integer> reverse = new AggregatorImpls.AggregableRange<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1) + 9);
@ -143,14 +133,13 @@ public class AggregatorImplsTest extends BaseND4JTest {
}
@Test
public void aggregableMinTest() {
@DisplayName("Aggregable Min Test")
void aggregableMinTest() {
AggregatorImpls.AggregableMin<Integer> mn = new AggregatorImpls.AggregableMin<>();
for (int i = 0; i < intList.size(); i++) {
mn.accept(intList.get(i));
}
assertEquals(1, mn.get().toInt());
AggregatorImpls.AggregableMin<Integer> reverse = new AggregatorImpls.AggregableMin<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
@ -160,14 +149,13 @@ public class AggregatorImplsTest extends BaseND4JTest {
}
@Test
public void aggregableSumTest() {
@DisplayName("Aggregable Sum Test")
void aggregableSumTest() {
AggregatorImpls.AggregableSum<Integer> sm = new AggregatorImpls.AggregableSum<>();
for (int i = 0; i < intList.size(); i++) {
sm.accept(intList.get(i));
}
assertEquals(45, sm.get().toInt());
AggregatorImpls.AggregableSum<Integer> reverse = new AggregatorImpls.AggregableSum<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
@ -176,17 +164,15 @@ public class AggregatorImplsTest extends BaseND4JTest {
assertEquals(90, sm.get().toInt());
}
@Test
public void aggregableMeanTest() {
@DisplayName("Aggregable Mean Test")
void aggregableMeanTest() {
AggregatorImpls.AggregableMean<Integer> mn = new AggregatorImpls.AggregableMean<>();
for (int i = 0; i < intList.size(); i++) {
mn.accept(intList.get(i));
}
assertEquals(9l, (long) mn.getCount());
assertEquals(5D, mn.get().toDouble(), 0.001);
AggregatorImpls.AggregableMean<Integer> reverse = new AggregatorImpls.AggregableMean<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
@ -197,80 +183,73 @@ public class AggregatorImplsTest extends BaseND4JTest {
}
@Test
public void aggregableStdDevTest() {
@DisplayName("Aggregable Std Dev Test")
void aggregableStdDevTest() {
AggregatorImpls.AggregableStdDev<Integer> sd = new AggregatorImpls.AggregableStdDev<>();
for (int i = 0; i < intList.size(); i++) {
sd.accept(intList.get(i));
}
assertTrue(Math.abs(sd.get().toDouble() - 2.7386) < 0.0001);
AggregatorImpls.AggregableStdDev<Integer> reverse = new AggregatorImpls.AggregableStdDev<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
}
sd.combine(reverse);
assertTrue("" + sd.get().toDouble(), Math.abs(sd.get().toDouble() - 1.8787) < 0.0001);
assertTrue(Math.abs(sd.get().toDouble() - 1.8787) < 0.0001,"" + sd.get().toDouble());
}
@Test
public void aggregableVariance() {
@DisplayName("Aggregable Variance")
void aggregableVariance() {
AggregatorImpls.AggregableVariance<Integer> sd = new AggregatorImpls.AggregableVariance<>();
for (int i = 0; i < intList.size(); i++) {
sd.accept(intList.get(i));
}
assertTrue(Math.abs(sd.get().toDouble() - 60D / 8) < 0.0001);
AggregatorImpls.AggregableVariance<Integer> reverse = new AggregatorImpls.AggregableVariance<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
}
sd.combine(reverse);
assertTrue("" + sd.get().toDouble(), Math.abs(sd.get().toDouble() - 3.5294) < 0.0001);
assertTrue(Math.abs(sd.get().toDouble() - 3.5294) < 0.0001,"" + sd.get().toDouble());
}
@Test
public void aggregableUncorrectedStdDevTest() {
@DisplayName("Aggregable Uncorrected Std Dev Test")
void aggregableUncorrectedStdDevTest() {
AggregatorImpls.AggregableUncorrectedStdDev<Integer> sd = new AggregatorImpls.AggregableUncorrectedStdDev<>();
for (int i = 0; i < intList.size(); i++) {
sd.accept(intList.get(i));
}
assertTrue(Math.abs(sd.get().toDouble() - 2.582) < 0.0001);
AggregatorImpls.AggregableUncorrectedStdDev<Integer> reverse =
new AggregatorImpls.AggregableUncorrectedStdDev<>();
AggregatorImpls.AggregableUncorrectedStdDev<Integer> reverse = new AggregatorImpls.AggregableUncorrectedStdDev<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
}
sd.combine(reverse);
assertTrue("" + sd.get().toDouble(), Math.abs(sd.get().toDouble() - 1.8257) < 0.0001);
assertTrue(Math.abs(sd.get().toDouble() - 1.8257) < 0.0001,"" + sd.get().toDouble());
}
@Test
public void aggregablePopulationVariance() {
@DisplayName("Aggregable Population Variance")
void aggregablePopulationVariance() {
AggregatorImpls.AggregablePopulationVariance<Integer> sd = new AggregatorImpls.AggregablePopulationVariance<>();
for (int i = 0; i < intList.size(); i++) {
sd.accept(intList.get(i));
}
assertTrue(Math.abs(sd.get().toDouble() - 60D / 9) < 0.0001);
AggregatorImpls.AggregablePopulationVariance<Integer> reverse =
new AggregatorImpls.AggregablePopulationVariance<>();
AggregatorImpls.AggregablePopulationVariance<Integer> reverse = new AggregatorImpls.AggregablePopulationVariance<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
}
sd.combine(reverse);
assertTrue("" + sd.get().toDouble(), Math.abs(sd.get().toDouble() - 30D / 9) < 0.0001);
assertTrue(Math.abs(sd.get().toDouble() - 30D / 9) < 0.0001,"" + sd.get().toDouble());
}
@Test
public void aggregableCountUniqueTest() {
@DisplayName("Aggregable Count Unique Test")
void aggregableCountUniqueTest() {
// at this low range, it's linear counting
AggregatorImpls.AggregableCountUnique<Integer> cu = new AggregatorImpls.AggregableCountUnique<>();
for (int i = 0; i < intList.size(); i++) {
cu.accept(intList.get(i));
@ -278,7 +257,6 @@ public class AggregatorImplsTest extends BaseND4JTest {
assertEquals(9, cu.get().toInt());
cu.accept(1);
assertEquals(9, cu.get().toInt());
AggregatorImpls.AggregableCountUnique<Integer> reverse = new AggregatorImpls.AggregableCountUnique<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
@ -287,26 +265,25 @@ public class AggregatorImplsTest extends BaseND4JTest {
assertEquals(9, cu.get().toInt());
}
@Rule
public final ExpectedException exception = ExpectedException.none();
@Test
public void incompatibleAggregatorTest() {
AggregatorImpls.AggregableSum<Integer> sm = new AggregatorImpls.AggregableSum<>();
for (int i = 0; i < intList.size(); i++) {
sm.accept(intList.get(i));
}
assertEquals(45, sm.get().toInt());
@DisplayName("Incompatible Aggregator Test")
void incompatibleAggregatorTest() {
assertThrows(UnsupportedOperationException.class,() -> {
AggregatorImpls.AggregableSum<Integer> sm = new AggregatorImpls.AggregableSum<>();
for (int i = 0; i < intList.size(); i++) {
sm.accept(intList.get(i));
}
assertEquals(45, sm.get().toInt());
AggregatorImpls.AggregableMean<Integer> reverse = new AggregatorImpls.AggregableMean<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
}
sm.combine(reverse);
assertEquals(45, sm.get().toInt());
});
AggregatorImpls.AggregableMean<Integer> reverse = new AggregatorImpls.AggregableMean<>();
for (int i = 0; i < intList.size(); i++) {
reverse.accept(intList.get(intList.size() - i - 1));
}
exception.expect(UnsupportedOperationException.class);
sm.combine(reverse);
assertEquals(45, sm.get().toInt());
}
}

View File

@ -17,77 +17,65 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.transform.ops;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertTrue;
public class DispatchOpTest extends BaseND4JTest {
@DisplayName("Dispatch Op Test")
class DispatchOpTest extends BaseND4JTest {
private List<Integer> intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
private List<String> stringList = new ArrayList<>(Arrays.asList("arakoa", "abracadabra", "blast", "acceptance"));
@Test
public void testDispatchSimple() {
@DisplayName("Test Dispatch Simple")
void testDispatchSimple() {
AggregatorImpls.AggregableFirst<Integer> af = new AggregatorImpls.AggregableFirst<>();
AggregatorImpls.AggregableSum<Integer> as = new AggregatorImpls.AggregableSum<>();
AggregableMultiOp<Integer> multiaf =
new AggregableMultiOp<>(Collections.<IAggregableReduceOp<Integer, Writable>>singletonList(af));
AggregableMultiOp<Integer> multias =
new AggregableMultiOp<>(Collections.<IAggregableReduceOp<Integer, Writable>>singletonList(as));
DispatchOp<Integer, Writable> parallel =
new DispatchOp<>(Arrays.<IAggregableReduceOp<Integer, List<Writable>>>asList(multiaf, multias));
AggregableMultiOp<Integer> multiaf = new AggregableMultiOp<>(Collections.<IAggregableReduceOp<Integer, Writable>>singletonList(af));
AggregableMultiOp<Integer> multias = new AggregableMultiOp<>(Collections.<IAggregableReduceOp<Integer, Writable>>singletonList(as));
DispatchOp<Integer, Writable> parallel = new DispatchOp<>(Arrays.<IAggregableReduceOp<Integer, List<Writable>>>asList(multiaf, multias));
assertTrue(multiaf.getOperations().size() == 1);
assertTrue(multias.getOperations().size() == 1);
assertTrue(parallel.getOperations().size() == 2);
for (int i = 0; i < intList.size(); i++) {
parallel.accept(Arrays.asList(intList.get(i), intList.get(i)));
}
List<Writable> res = parallel.get();
assertTrue(res.get(1).toDouble() == 45D);
assertTrue(res.get(0).toInt() == 1);
}
@Test
public void testDispatchFlatMap() {
@DisplayName("Test Dispatch Flat Map")
void testDispatchFlatMap() {
AggregatorImpls.AggregableFirst<Integer> af = new AggregatorImpls.AggregableFirst<>();
AggregatorImpls.AggregableSum<Integer> as = new AggregatorImpls.AggregableSum<>();
AggregableMultiOp<Integer> multi = new AggregableMultiOp<>(Arrays.asList(af, as));
AggregatorImpls.AggregableLast<Integer> al = new AggregatorImpls.AggregableLast<>();
AggregatorImpls.AggregableMax<Integer> amax = new AggregatorImpls.AggregableMax<>();
AggregableMultiOp<Integer> otherMulti = new AggregableMultiOp<>(Arrays.asList(al, amax));
DispatchOp<Integer, Writable> parallel = new DispatchOp<>(
Arrays.<IAggregableReduceOp<Integer, List<Writable>>>asList(multi, otherMulti));
DispatchOp<Integer, Writable> parallel = new DispatchOp<>(Arrays.<IAggregableReduceOp<Integer, List<Writable>>>asList(multi, otherMulti));
assertTrue(multi.getOperations().size() == 2);
assertTrue(otherMulti.getOperations().size() == 2);
assertTrue(parallel.getOperations().size() == 2);
for (int i = 0; i < intList.size(); i++) {
parallel.accept(Arrays.asList(intList.get(i), intList.get(i)));
}
List<Writable> res = parallel.get();
assertTrue(res.get(1).toDouble() == 45D);
assertTrue(res.get(0).toInt() == 1);
assertTrue(res.get(3).toDouble() == 9);
assertTrue(res.get(2).toInt() == 9);
}
}

View File

@ -32,13 +32,14 @@ import org.datavec.api.transform.ops.AggregableMultiOp;
import org.datavec.api.transform.ops.IAggregableReduceOp;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.*;
import org.junit.Test;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.*;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
public class TestMultiOpReduce extends BaseND4JTest {
@ -46,10 +47,10 @@ public class TestMultiOpReduce extends BaseND4JTest {
public void testMultiOpReducerDouble() {
List<List<Writable>> inputs = new ArrayList<>();
inputs.add(Arrays.asList((Writable) new Text("someKey"), new DoubleWritable(0)));
inputs.add(Arrays.asList((Writable) new Text("someKey"), new DoubleWritable(1)));
inputs.add(Arrays.asList((Writable) new Text("someKey"), new DoubleWritable(2)));
inputs.add(Arrays.asList((Writable) new Text("someKey"), new DoubleWritable(2)));
inputs.add(Arrays.asList(new Text("someKey"), new DoubleWritable(0)));
inputs.add(Arrays.asList(new Text("someKey"), new DoubleWritable(1)));
inputs.add(Arrays.asList(new Text("someKey"), new DoubleWritable(2)));
inputs.add(Arrays.asList(new Text("someKey"), new DoubleWritable(2)));
Map<ReduceOp, Double> exp = new LinkedHashMap<>();
exp.put(ReduceOp.Min, 0.0);
@ -82,7 +83,7 @@ public class TestMultiOpReduce extends BaseND4JTest {
assertEquals(out.get(0), new Text("someKey"));
String msg = op.toString();
assertEquals(msg, exp.get(op), out.get(1).toDouble(), 1e-5);
assertEquals(exp.get(op), out.get(1).toDouble(), 1e-5,msg);
}
}
@ -126,19 +127,20 @@ public class TestMultiOpReduce extends BaseND4JTest {
assertEquals(out.get(0), new Text("someKey"));
String msg = op.toString();
assertEquals(msg, exp.get(op), out.get(1).toDouble(), 1e-5);
assertEquals(exp.get(op), out.get(1).toDouble(), 1e-5,msg);
}
}
@Test
@Disabled
public void testReduceString() {
List<List<Writable>> inputs = new ArrayList<>();
inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("1")));
inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("2")));
inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("3")));
inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("4")));
inputs.add(Arrays.asList(new Text("someKey"), new Text("1")));
inputs.add(Arrays.asList(new Text("someKey"), new Text("2")));
inputs.add(Arrays.asList(new Text("someKey"), new Text("3")));
inputs.add(Arrays.asList(new Text("someKey"), new Text("4")));
Map<ReduceOp, String> exp = new LinkedHashMap<>();
exp.put(ReduceOp.Append, "1234");
@ -210,7 +212,7 @@ public class TestMultiOpReduce extends BaseND4JTest {
assertEquals(out.get(0), new Text("someKey"));
String msg = op.toString();
assertEquals(msg, exp.get(op), out.get(1).toDouble(), 1e-5);
assertEquals(exp.get(op), out.get(1).toDouble(), 1e-5,msg);
}
for (ReduceOp op : Arrays.asList(ReduceOp.Min, ReduceOp.Max, ReduceOp.Range, ReduceOp.Sum, ReduceOp.Mean,

View File

@ -24,13 +24,13 @@ import org.datavec.api.transform.ops.IAggregableReduceOp;
import org.datavec.api.transform.reduce.impl.GeographicMidpointReduction;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestReductions extends BaseND4JTest {

View File

@ -22,10 +22,10 @@ package org.datavec.api.transform.schema;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.joda.time.DateTimeZone;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestJsonYaml extends BaseND4JTest {

View File

@ -21,10 +21,10 @@
package org.datavec.api.transform.schema;
import org.datavec.api.transform.ColumnType;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestSchemaMethods extends BaseND4JTest {

View File

@ -33,7 +33,7 @@ import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.NullWritable;
import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.ArrayList;
@ -41,7 +41,7 @@ import java.util.Arrays;
import java.util.List;
import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestReduceSequenceByWindowFunction extends BaseND4JTest {

View File

@ -27,7 +27,7 @@ import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.ArrayList;
@ -35,7 +35,7 @@ import java.util.Arrays;
import java.util.List;
import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestSequenceSplit extends BaseND4JTest {

View File

@ -29,7 +29,7 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.ArrayList;
@ -37,7 +37,7 @@ import java.util.Arrays;
import java.util.List;
import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestWindowFunctions extends BaseND4JTest {

View File

@ -26,10 +26,10 @@ import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.serde.testClasses.CustomCondition;
import org.datavec.api.transform.serde.testClasses.CustomFilter;
import org.datavec.api.transform.serde.testClasses.CustomTransform;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestCustomTransformJsonYaml extends BaseND4JTest {

View File

@ -64,13 +64,13 @@ import org.datavec.api.transform.transform.time.TimeMathOpTransform;
import org.datavec.api.writable.comparator.DoubleWritableComparator;
import org.joda.time.DateTimeFieldType;
import org.joda.time.DateTimeZone;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.*;
import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestYamlJsonSerde extends BaseND4JTest {

View File

@ -24,12 +24,12 @@ import org.datavec.api.transform.StringReduceOp;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.*;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestReduce extends BaseND4JTest {

View File

@ -50,7 +50,7 @@ import org.datavec.api.writable.Text;
import org.datavec.api.writable.comparator.LongWritableComparator;
import org.joda.time.DateTimeFieldType;
import org.joda.time.DateTimeZone;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.io.ClassPathResource;
@ -61,7 +61,7 @@ import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class RegressionTestJson extends BaseND4JTest {

View File

@ -50,13 +50,13 @@ import org.datavec.api.writable.Text;
import org.datavec.api.writable.comparator.LongWritableComparator;
import org.joda.time.DateTimeFieldType;
import org.joda.time.DateTimeZone;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.*;
import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestJsonYaml extends BaseND4JTest {

View File

@ -58,8 +58,8 @@ import org.datavec.api.transform.transform.time.TimeMathOpTransform;
import org.datavec.api.writable.*;
import org.joda.time.DateTimeFieldType;
import org.joda.time.DateTimeZone;
import org.junit.Assert;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -71,8 +71,8 @@ import java.io.ObjectOutputStream;
import java.util.*;
import java.util.concurrent.TimeUnit;
import static junit.framework.TestCase.assertEquals;
import static org.junit.Assert.*;
import static org.junit.jupiter.api.Assertions.*;
public class TestTransforms extends BaseND4JTest {
@ -277,22 +277,22 @@ public class TestTransforms extends BaseND4JTest {
List<String> outputColumns = new ArrayList<>(ALL_COLUMNS);
outputColumns.add(NEW_COLUMN);
Schema newSchema = transform.transform(schema);
Assert.assertEquals(outputColumns, newSchema.getColumnNames());
assertEquals(outputColumns, newSchema.getColumnNames());
List<Writable> input = new ArrayList<>();
input.addAll(COLUMN_VALUES);
transform.setInputSchema(schema);
List<Writable> transformed = transform.map(input);
Assert.assertEquals(NEW_COLUMN_VALUE, transformed.get(transformed.size() - 1).toString());
assertEquals(NEW_COLUMN_VALUE, transformed.get(transformed.size() - 1).toString());
List<Text> outputColumnValues = new ArrayList<>(COLUMN_VALUES);
outputColumnValues.add(new Text(NEW_COLUMN_VALUE));
Assert.assertEquals(outputColumnValues, transformed);
assertEquals(outputColumnValues, transformed);
String s = JsonMappers.getMapper().writeValueAsString(transform);
Transform transform2 = JsonMappers.getMapper().readValue(s, ConcatenateStringColumns.class);
Assert.assertEquals(transform, transform2);
assertEquals(transform, transform2);
}
@Test
@ -309,7 +309,7 @@ public class TestTransforms extends BaseND4JTest {
transform.setInputSchema(schema);
Schema newSchema = transform.transform(schema);
List<String> outputColumns = new ArrayList<>(ALL_COLUMNS);
Assert.assertEquals(outputColumns, newSchema.getColumnNames());
assertEquals(outputColumns, newSchema.getColumnNames());
transform = new ChangeCaseStringTransform(STRING_COLUMN, ChangeCaseStringTransform.CaseType.LOWER);
transform.setInputSchema(schema);
@ -320,8 +320,8 @@ public class TestTransforms extends BaseND4JTest {
output.add(new Text(TEXT_LOWER_CASE));
output.add(new Text(TEXT_MIXED_CASE));
List<Writable> transformed = transform.map(input);
Assert.assertEquals(transformed.get(0).toString(), TEXT_LOWER_CASE);
Assert.assertEquals(transformed, output);
assertEquals(transformed.get(0).toString(), TEXT_LOWER_CASE);
assertEquals(transformed, output);
transform = new ChangeCaseStringTransform(STRING_COLUMN, ChangeCaseStringTransform.CaseType.UPPER);
transform.setInputSchema(schema);
@ -329,12 +329,12 @@ public class TestTransforms extends BaseND4JTest {
output.add(new Text(TEXT_UPPER_CASE));
output.add(new Text(TEXT_MIXED_CASE));
transformed = transform.map(input);
Assert.assertEquals(transformed.get(0).toString(), TEXT_UPPER_CASE);
Assert.assertEquals(transformed, output);
assertEquals(transformed.get(0).toString(), TEXT_UPPER_CASE);
assertEquals(transformed, output);
String s = JsonMappers.getMapper().writeValueAsString(transform);
Transform transform2 = JsonMappers.getMapper().readValue(s, ChangeCaseStringTransform.class);
Assert.assertEquals(transform, transform2);
assertEquals(transform, transform2);
}
@Test
@ -1530,7 +1530,7 @@ public class TestTransforms extends BaseND4JTest {
String json = JsonMappers.getMapper().writeValueAsString(t);
Transform transform2 = JsonMappers.getMapper().readValue(json, StringListToCountsNDArrayTransform.class);
Assert.assertEquals(t, transform2);
assertEquals(t, transform2);
}
@ -1551,7 +1551,7 @@ public class TestTransforms extends BaseND4JTest {
String json = JsonMappers.getMapper().writeValueAsString(t);
Transform transform2 = JsonMappers.getMapper().readValue(json, StringListToIndicesNDArrayTransform.class);
Assert.assertEquals(t, transform2);
assertEquals(t, transform2);
}

View File

@ -29,7 +29,7 @@ import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -39,7 +39,7 @@ import org.nd4j.linalg.ops.transforms.Transforms;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestNDArrayWritableTransforms extends BaseND4JTest {

View File

@ -30,13 +30,13 @@ import org.datavec.api.transform.ndarray.NDArrayScalarOpTransform;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.serde.JsonSerializer;
import org.datavec.api.transform.serde.YamlSerializer;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestYamlJsonSerde extends BaseND4JTest {

View File

@ -17,29 +17,29 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.transform.transform.parse;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals;
@DisplayName("Parse Double Transform Test")
class ParseDoubleTransformTest extends BaseND4JTest {
public class ParseDoubleTransformTest extends BaseND4JTest {
@Test
public void testDoubleTransform() {
@DisplayName("Test Double Transform")
void testDoubleTransform() {
List<Writable> record = new ArrayList<>();
record.add(new Text("0.0"));
List<Writable> transformed = Arrays.<Writable>asList(new DoubleWritable(0.0));
assertEquals(transformed, new ParseDoubleTransform().map(record));
}
}

View File

@ -35,26 +35,26 @@ import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone;
import org.junit.Ignore;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
import java.io.File;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestUI extends BaseND4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test
public void testUI() throws Exception {
public void testUI(@TempDir Path testDir) throws Exception {
Schema schema = new Schema.Builder().addColumnString("StringColumn").addColumnInteger("IntColumn")
.addColumnInteger("IntColumn2").addColumnInteger("IntColumn3")
.addColumnTime("TimeColumn", DateTimeZone.UTC).build();
@ -92,7 +92,7 @@ public class TestUI extends BaseND4JTest {
DataAnalysis da = new DataAnalysis(schema, list);
File fDir = testDir.newFolder();
File fDir = testDir.toFile();
String tempDir = fDir.getAbsolutePath();
String outPath = FilenameUtils.concat(tempDir, "datavec_transform_UITest.html");
System.out.println(outPath);
@ -143,7 +143,7 @@ public class TestUI extends BaseND4JTest {
@Test
@Ignore
@Disabled
public void testSequencePlot() throws Exception {
Schema schema = new SequenceSchema.Builder().addColumnDouble("sinx")

View File

@ -17,30 +17,31 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.util;
import org.junit.Before;
import org.junit.Test;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.io.BufferedReader;
import java.io.File;
import java.io.InputStream;
import java.io.InputStreamReader;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.core.AnyOf.anyOf;
import static org.hamcrest.core.IsEqual.equalTo;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
public class ClassPathResourceTest extends BaseND4JTest {
@DisplayName("Class Path Resource Test")
class ClassPathResourceTest extends BaseND4JTest {
private boolean isWindows = false; //File sizes are reported slightly different on Linux vs. Windows
// File sizes are reported slightly different on Linux vs. Windows
private boolean isWindows = false;
@Before
public void setUp() throws Exception {
@BeforeEach
void setUp() throws Exception {
String osname = System.getProperty("os.name");
if (osname != null && osname.toLowerCase().contains("win")) {
isWindows = true;
@ -48,9 +49,9 @@ public class ClassPathResourceTest extends BaseND4JTest {
}
@Test
public void testGetFile1() throws Exception {
@DisplayName("Test Get File 1")
void testGetFile1() throws Exception {
File intFile = new ClassPathResource("datavec-api/iris.dat").getFile();
assertTrue(intFile.exists());
if (isWindows) {
assertThat(intFile.length(), anyOf(equalTo(2700L), equalTo(2850L)));
@ -60,9 +61,9 @@ public class ClassPathResourceTest extends BaseND4JTest {
}
@Test
public void testGetFileSlash1() throws Exception {
@DisplayName("Test Get File Slash 1")
void testGetFileSlash1() throws Exception {
File intFile = new ClassPathResource("datavec-api/iris.dat").getFile();
assertTrue(intFile.exists());
if (isWindows) {
assertThat(intFile.length(), anyOf(equalTo(2700L), equalTo(2850L)));
@ -72,11 +73,10 @@ public class ClassPathResourceTest extends BaseND4JTest {
}
@Test
public void testGetFileWithSpace1() throws Exception {
@DisplayName("Test Get File With Space 1")
void testGetFileWithSpace1() throws Exception {
File intFile = new ClassPathResource("datavec-api/csvsequence test.txt").getFile();
assertTrue(intFile.exists());
if (isWindows) {
assertThat(intFile.length(), anyOf(equalTo(60L), equalTo(64L)));
} else {
@ -85,16 +85,15 @@ public class ClassPathResourceTest extends BaseND4JTest {
}
@Test
public void testInputStream() throws Exception {
@DisplayName("Test Input Stream")
void testInputStream() throws Exception {
ClassPathResource resource = new ClassPathResource("datavec-api/csvsequence_1.txt");
File intFile = resource.getFile();
if (isWindows) {
assertThat(intFile.length(), anyOf(equalTo(60L), equalTo(64L)));
} else {
assertEquals(60, intFile.length());
}
InputStream stream = resource.getInputStream();
BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
String line = "";
@ -102,21 +101,19 @@ public class ClassPathResourceTest extends BaseND4JTest {
while ((line = reader.readLine()) != null) {
cnt++;
}
assertEquals(5, cnt);
}
@Test
public void testInputStreamSlash() throws Exception {
@DisplayName("Test Input Stream Slash")
void testInputStreamSlash() throws Exception {
ClassPathResource resource = new ClassPathResource("datavec-api/csvsequence_1.txt");
File intFile = resource.getFile();
if (isWindows) {
assertThat(intFile.length(), anyOf(equalTo(60L), equalTo(64L)));
} else {
assertEquals(60, intFile.length());
}
InputStream stream = resource.getInputStream();
BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
String line = "";
@ -124,7 +121,6 @@ public class ClassPathResourceTest extends BaseND4JTest {
while ((line = reader.readLine()) != null) {
cnt++;
}
assertEquals(5, cnt);
}
}

View File

@ -17,44 +17,41 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.util;
import org.datavec.api.timeseries.util.TimeSeriesWritableUtils;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.ArrayList;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertArrayEquals;
public class TimeSeriesUtilsTest extends BaseND4JTest {
@DisplayName("Time Series Utils Test")
class TimeSeriesUtilsTest extends BaseND4JTest {
@Test
public void testTimeSeriesCreation() {
@DisplayName("Test Time Series Creation")
void testTimeSeriesCreation() {
List<List<List<Writable>>> test = new ArrayList<>();
List<List<Writable>> timeStep = new ArrayList<>();
for(int i = 0; i < 5; i++) {
for (int i = 0; i < 5; i++) {
timeStep.add(getRecord(5));
}
test.add(timeStep);
INDArray arr = TimeSeriesWritableUtils.convertWritablesSequence(test).getFirst();
assertArrayEquals(new long[]{1,5,5},arr.shape());
}
assertArrayEquals(new long[] { 1, 5, 5 }, arr.shape());
}
private List<Writable> getRecord(int length) {
private List<Writable> getRecord(int length) {
List<Writable> ret = new ArrayList<>();
for(int i = 0; i < length; i++) {
for (int i = 0; i < length; i++) {
ret.add(new DoubleWritable(1.0));
}
return ret;
}
}
}

View File

@ -17,52 +17,50 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.writable;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.shade.guava.collect.Lists;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.util.ndarray.RecordConverter;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import java.util.Arrays;
import java.util.List;
import java.util.TimeZone;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals;
@DisplayName("Record Converter Test")
class RecordConverterTest extends BaseND4JTest {
public class RecordConverterTest extends BaseND4JTest {
@Test
public void toRecords_PassInClassificationDataSet_ExpectNDArrayAndIntWritables() {
INDArray feature1 = Nd4j.create(new double[]{4, -5.7, 10, -0.1}, new long[]{1, 4}, DataType.FLOAT);
INDArray feature2 = Nd4j.create(new double[]{11, .7, -1.3, 4}, new long[]{1, 4}, DataType.FLOAT);
INDArray label1 = Nd4j.create(new double[]{0, 0, 1, 0}, new long[]{1, 4}, DataType.FLOAT);
INDArray label2 = Nd4j.create(new double[]{0, 1, 0, 0}, new long[]{1, 4}, DataType.FLOAT);
DataSet dataSet = new DataSet(Nd4j.vstack(Lists.newArrayList(feature1, feature2)),
Nd4j.vstack(Lists.newArrayList(label1, label2)));
@DisplayName("To Records _ Pass In Classification Data Set _ Expect ND Array And Int Writables")
void toRecords_PassInClassificationDataSet_ExpectNDArrayAndIntWritables() {
INDArray feature1 = Nd4j.create(new double[] { 4, -5.7, 10, -0.1 }, new long[] { 1, 4 }, DataType.FLOAT);
INDArray feature2 = Nd4j.create(new double[] { 11, .7, -1.3, 4 }, new long[] { 1, 4 }, DataType.FLOAT);
INDArray label1 = Nd4j.create(new double[] { 0, 0, 1, 0 }, new long[] { 1, 4 }, DataType.FLOAT);
INDArray label2 = Nd4j.create(new double[] { 0, 1, 0, 0 }, new long[] { 1, 4 }, DataType.FLOAT);
DataSet dataSet = new DataSet(Nd4j.vstack(Lists.newArrayList(feature1, feature2)), Nd4j.vstack(Lists.newArrayList(label1, label2)));
List<List<Writable>> writableList = RecordConverter.toRecords(dataSet);
assertEquals(2, writableList.size());
testClassificationWritables(feature1, 2, writableList.get(0));
testClassificationWritables(feature2, 1, writableList.get(1));
}
@Test
public void toRecords_PassInRegressionDataSet_ExpectNDArrayAndDoubleWritables() {
INDArray feature = Nd4j.create(new double[]{4, -5.7, 10, -0.1}, new long[]{1, 4}, DataType.FLOAT);
INDArray label = Nd4j.create(new double[]{.5, 2, 3, .5}, new long[]{1, 4}, DataType.FLOAT);
@DisplayName("To Records _ Pass In Regression Data Set _ Expect ND Array And Double Writables")
void toRecords_PassInRegressionDataSet_ExpectNDArrayAndDoubleWritables() {
INDArray feature = Nd4j.create(new double[] { 4, -5.7, 10, -0.1 }, new long[] { 1, 4 }, DataType.FLOAT);
INDArray label = Nd4j.create(new double[] { .5, 2, 3, .5 }, new long[] { 1, 4 }, DataType.FLOAT);
DataSet dataSet = new DataSet(feature, label);
List<List<Writable>> writableList = RecordConverter.toRecords(dataSet);
List<Writable> results = writableList.get(0);
NDArrayWritable ndArrayWritable = (NDArrayWritable) results.get(0);
assertEquals(1, writableList.size());
assertEquals(5, results.size());
assertEquals(feature, ndArrayWritable.get());
@ -72,62 +70,39 @@ public class RecordConverterTest extends BaseND4JTest {
}
}
private void testClassificationWritables(INDArray expectedFeatureVector, int expectLabelIndex,
List<Writable> writables) {
private void testClassificationWritables(INDArray expectedFeatureVector, int expectLabelIndex, List<Writable> writables) {
NDArrayWritable ndArrayWritable = (NDArrayWritable) writables.get(0);
IntWritable intWritable = (IntWritable) writables.get(1);
assertEquals(2, writables.size());
assertEquals(expectedFeatureVector, ndArrayWritable.get());
assertEquals(expectLabelIndex, intWritable.get());
}
@Test
public void testNDArrayWritableConcat() {
List<Writable> l = Arrays.<Writable>asList(new DoubleWritable(1),
new NDArrayWritable(Nd4j.create(new double[]{2, 3, 4}, new long[]{1, 3}, DataType.FLOAT)), new DoubleWritable(5),
new NDArrayWritable(Nd4j.create(new double[]{6, 7, 8}, new long[]{1, 3}, DataType.FLOAT)), new IntWritable(9),
new IntWritable(1));
INDArray exp = Nd4j.create(new double[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 1}, new long[]{1, 10}, DataType.FLOAT);
@DisplayName("Test ND Array Writable Concat")
void testNDArrayWritableConcat() {
List<Writable> l = Arrays.<Writable>asList(new DoubleWritable(1), new NDArrayWritable(Nd4j.create(new double[] { 2, 3, 4 }, new long[] { 1, 3 }, DataType.FLOAT)), new DoubleWritable(5), new NDArrayWritable(Nd4j.create(new double[] { 6, 7, 8 }, new long[] { 1, 3 }, DataType.FLOAT)), new IntWritable(9), new IntWritable(1));
INDArray exp = Nd4j.create(new double[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 1 }, new long[] { 1, 10 }, DataType.FLOAT);
INDArray act = RecordConverter.toArray(DataType.FLOAT, l);
assertEquals(exp, act);
}
@Test
public void testNDArrayWritableConcatToMatrix(){
List<Writable> l1 = Arrays.<Writable>asList(new DoubleWritable(1), new NDArrayWritable(Nd4j.create(new double[]{2, 3, 4}, new long[]{1,3}, DataType.FLOAT)), new DoubleWritable(5));
List<Writable> l2 = Arrays.<Writable>asList(new DoubleWritable(6), new NDArrayWritable(Nd4j.create(new double[]{7, 8, 9}, new long[]{1,3}, DataType.FLOAT)), new DoubleWritable(10));
INDArray exp = Nd4j.create(new double[][]{
{1,2,3,4,5},
{6,7,8,9,10}}).castTo(DataType.FLOAT);
INDArray act = RecordConverter.toMatrix(DataType.FLOAT, Arrays.asList(l1,l2));
@DisplayName("Test ND Array Writable Concat To Matrix")
void testNDArrayWritableConcatToMatrix() {
List<Writable> l1 = Arrays.<Writable>asList(new DoubleWritable(1), new NDArrayWritable(Nd4j.create(new double[] { 2, 3, 4 }, new long[] { 1, 3 }, DataType.FLOAT)), new DoubleWritable(5));
List<Writable> l2 = Arrays.<Writable>asList(new DoubleWritable(6), new NDArrayWritable(Nd4j.create(new double[] { 7, 8, 9 }, new long[] { 1, 3 }, DataType.FLOAT)), new DoubleWritable(10));
INDArray exp = Nd4j.create(new double[][] { { 1, 2, 3, 4, 5 }, { 6, 7, 8, 9, 10 } }).castTo(DataType.FLOAT);
INDArray act = RecordConverter.toMatrix(DataType.FLOAT, Arrays.asList(l1, l2));
assertEquals(exp, act);
}
@Test
public void testToRecordWithListOfObject(){
final List<Object> list = Arrays.asList((Object)3, 7.0f, "Foo", "Bar", 1.0, 3f, 3L, 7, 0L);
final Schema schema = new Schema.Builder()
.addColumnInteger("a")
.addColumnFloat("b")
.addColumnString("c")
.addColumnCategorical("d", "Bar", "Baz")
.addColumnDouble("e")
.addColumnFloat("f")
.addColumnLong("g")
.addColumnInteger("h")
.addColumnTime("i", TimeZone.getDefault())
.build();
@DisplayName("Test To Record With List Of Object")
void testToRecordWithListOfObject() {
final List<Object> list = Arrays.asList((Object) 3, 7.0f, "Foo", "Bar", 1.0, 3f, 3L, 7, 0L);
final Schema schema = new Schema.Builder().addColumnInteger("a").addColumnFloat("b").addColumnString("c").addColumnCategorical("d", "Bar", "Baz").addColumnDouble("e").addColumnFloat("f").addColumnLong("g").addColumnInteger("h").addColumnTime("i", TimeZone.getDefault()).build();
final List<Writable> record = RecordConverter.toRecord(schema, list);
assertEquals(record.get(0).toInt(), 3);
assertEquals(record.get(1).toFloat(), 7f, 1e-6);
assertEquals(record.get(2).toString(), "Foo");
@ -137,7 +112,5 @@ public class RecordConverterTest extends BaseND4JTest {
assertEquals(record.get(6).toLong(), 3L);
assertEquals(record.get(7).toInt(), 7);
assertEquals(record.get(8).toLong(), 0);
}
}

View File

@ -21,14 +21,14 @@
package org.datavec.api.writable;
import org.datavec.api.transform.metadata.NDArrayMetaData;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.io.*;
import static org.junit.Assert.*;
import static org.junit.jupiter.api.Assertions.*;
public class TestNDArrayWritableAndSerialization extends BaseND4JTest {

View File

@ -17,38 +17,38 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.writable;
import org.datavec.api.writable.batch.NDArrayRecordBatch;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.junit.jupiter.api.DisplayName;
import static org.junit.Assert.*;
import static org.junit.jupiter.api.Assertions.*;
public class WritableTest extends BaseND4JTest {
@DisplayName("Writable Test")
class WritableTest extends BaseND4JTest {
@Test
public void testWritableEqualityReflexive() {
@DisplayName("Test Writable Equality Reflexive")
void testWritableEqualityReflexive() {
assertEquals(new IntWritable(1), new IntWritable(1));
assertEquals(new LongWritable(1), new LongWritable(1));
assertEquals(new DoubleWritable(1), new DoubleWritable(1));
assertEquals(new FloatWritable(1), new FloatWritable(1));
assertEquals(new Text("Hello"), new Text("Hello"));
assertEquals(new BytesWritable("Hello".getBytes()),new BytesWritable("Hello".getBytes()));
INDArray ndArray = Nd4j.rand(new int[]{1, 100});
assertEquals(new BytesWritable("Hello".getBytes()), new BytesWritable("Hello".getBytes()));
INDArray ndArray = Nd4j.rand(new int[] { 1, 100 });
assertEquals(new NDArrayWritable(ndArray), new NDArrayWritable(ndArray));
assertEquals(new NullWritable(), new NullWritable());
assertEquals(new BooleanWritable(true), new BooleanWritable(true));
@ -56,9 +56,9 @@ public class WritableTest extends BaseND4JTest {
assertEquals(new ByteWritable(b), new ByteWritable(b));
}
@Test
public void testBytesWritableIndexing() {
@DisplayName("Test Bytes Writable Indexing")
void testBytesWritableIndexing() {
byte[] doubleWrite = new byte[16];
ByteBuffer wrapped = ByteBuffer.wrap(doubleWrite);
Buffer buffer = (Buffer) wrapped;
@ -66,53 +66,51 @@ public class WritableTest extends BaseND4JTest {
wrapped.putDouble(2.0);
buffer.rewind();
BytesWritable byteWritable = new BytesWritable(doubleWrite);
assertEquals(2,byteWritable.getDouble(1),1e-1);
DataBuffer dataBuffer = Nd4j.createBuffer(new double[] {1,2});
assertEquals(2, byteWritable.getDouble(1), 1e-1);
DataBuffer dataBuffer = Nd4j.createBuffer(new double[] { 1, 2 });
double[] d1 = dataBuffer.asDouble();
double[] d2 = byteWritable.asNd4jBuffer(DataType.DOUBLE,8).asDouble();
double[] d2 = byteWritable.asNd4jBuffer(DataType.DOUBLE, 8).asDouble();
assertArrayEquals(d1, d2, 0.0);
}
@Test
public void testByteWritable() {
@DisplayName("Test Byte Writable")
void testByteWritable() {
byte b = 0xfffffffe;
assertEquals(new IntWritable(-2), new ByteWritable(b));
assertEquals(new LongWritable(-2), new ByteWritable(b));
assertEquals(new ByteWritable(b), new IntWritable(-2));
assertEquals(new ByteWritable(b), new LongWritable(-2));
// those would cast to the same Int
byte minus126 = 0xffffff82;
assertNotEquals(new ByteWritable(minus126), new IntWritable(130));
}
@Test
public void testIntLongWritable() {
@DisplayName("Test Int Long Writable")
void testIntLongWritable() {
assertEquals(new IntWritable(1), new LongWritable(1l));
assertEquals(new LongWritable(2l), new IntWritable(2));
long l = 1L << 34;
// those would cast to the same Int
assertNotEquals(new LongWritable(l), new IntWritable(4));
}
@Test
public void testDoubleFloatWritable() {
@DisplayName("Test Double Float Writable")
void testDoubleFloatWritable() {
assertEquals(new DoubleWritable(1d), new FloatWritable(1f));
assertEquals(new FloatWritable(2f), new DoubleWritable(2d));
// we defer to Java equality for Floats
assertNotEquals(new DoubleWritable(1.1d), new FloatWritable(1.1f));
// same idea as above
assertNotEquals(new DoubleWritable(1.1d), new FloatWritable((float)1.1d));
assertNotEquals(new DoubleWritable((double)Float.MAX_VALUE + 1), new FloatWritable(Float.POSITIVE_INFINITY));
assertNotEquals(new DoubleWritable(1.1d), new FloatWritable((float) 1.1d));
assertNotEquals(new DoubleWritable((double) Float.MAX_VALUE + 1), new FloatWritable(Float.POSITIVE_INFINITY));
}
@Test
public void testFuzzies() {
@DisplayName("Test Fuzzies")
void testFuzzies() {
assertTrue(new DoubleWritable(1.1d).fuzzyEquals(new FloatWritable(1.1f), 1e-6d));
assertTrue(new FloatWritable(1.1f).fuzzyEquals(new DoubleWritable(1.1d), 1e-6d));
byte b = 0xfffffffe;
@ -122,62 +120,57 @@ public class WritableTest extends BaseND4JTest {
assertTrue(new LongWritable(1).fuzzyEquals(new DoubleWritable(1.05f), 1e-1d));
}
@Test
public void testNDArrayRecordBatch(){
@DisplayName("Test ND Array Record Batch")
void testNDArrayRecordBatch() {
Nd4j.getRandom().setSeed(12345);
List<List<INDArray>> orig = new ArrayList<>(); //Outer list over writables/columns, inner list over examples
for( int i=0; i<3; i++ ){
// Outer list over writables/columns, inner list over examples
List<List<INDArray>> orig = new ArrayList<>();
for (int i = 0; i < 3; i++) {
orig.add(new ArrayList<INDArray>());
}
for( int i=0; i<5; i++ ){
orig.get(0).add(Nd4j.rand(1,10));
orig.get(1).add(Nd4j.rand(new int[]{1,5,6}));
orig.get(2).add(Nd4j.rand(new int[]{1,3,4,5}));
for (int i = 0; i < 5; i++) {
orig.get(0).add(Nd4j.rand(1, 10));
orig.get(1).add(Nd4j.rand(new int[] { 1, 5, 6 }));
orig.get(2).add(Nd4j.rand(new int[] { 1, 3, 4, 5 }));
}
List<List<INDArray>> origByExample = new ArrayList<>(); //Outer list over examples, inner list over writables
for( int i=0; i<5; i++ ){
// Outer list over examples, inner list over writables
List<List<INDArray>> origByExample = new ArrayList<>();
for (int i = 0; i < 5; i++) {
origByExample.add(Arrays.asList(orig.get(0).get(i), orig.get(1).get(i), orig.get(2).get(i)));
}
List<INDArray> batched = new ArrayList<>();
for(List<INDArray> l : orig){
for (List<INDArray> l : orig) {
batched.add(Nd4j.concat(0, l.toArray(new INDArray[5])));
}
NDArrayRecordBatch batch = new NDArrayRecordBatch(batched);
assertEquals(5, batch.size());
for( int i=0; i<5; i++ ){
for (int i = 0; i < 5; i++) {
List<Writable> act = batch.get(i);
List<INDArray> unboxed = new ArrayList<>();
for(Writable w : act){
unboxed.add(((NDArrayWritable)w).get());
for (Writable w : act) {
unboxed.add(((NDArrayWritable) w).get());
}
List<INDArray> exp = origByExample.get(i);
assertEquals(exp.size(), unboxed.size());
for( int j=0; j<exp.size(); j++ ){
for (int j = 0; j < exp.size(); j++) {
assertEquals(exp.get(j), unboxed.get(j));
}
}
Iterator<List<Writable>> iter = batch.iterator();
int count = 0;
while(iter.hasNext()){
while (iter.hasNext()) {
List<Writable> next = iter.next();
List<INDArray> unboxed = new ArrayList<>();
for(Writable w : next){
unboxed.add(((NDArrayWritable)w).get());
for (Writable w : next) {
unboxed.add(((NDArrayWritable) w).get());
}
List<INDArray> exp = origByExample.get(count++);
assertEquals(exp.size(), unboxed.size());
for( int j=0; j<exp.size(); j++ ){
for (int j = 0; j < exp.size(); j++) {
assertEquals(exp.get(j), unboxed.get(j));
}
}
assertEquals(5, count);
}
}

View File

@ -60,10 +60,10 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
</profile>
</profiles>
</project>

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.arrow;
import lombok.extern.slf4j.Slf4j;
@ -42,461 +41,397 @@ import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.*;
import org.datavec.arrow.recordreader.ArrowRecordReader;
import org.datavec.arrow.recordreader.ArrowWritableRecordBatch;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.*;
import static java.nio.channels.Channels.newChannel;
import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
@Slf4j
public class ArrowConverterTest extends BaseND4JTest {
@DisplayName("Arrow Converter Test")
class ArrowConverterTest extends BaseND4JTest {
private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE);
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@TempDir
public Path testDir;
@Test
public void testToArrayFromINDArray() {
@DisplayName("Test To Array From IND Array")
void testToArrayFromINDArray() {
Schema.Builder schemaBuilder = new Schema.Builder();
schemaBuilder.addColumnNDArray("outputArray",new long[]{1,4});
schemaBuilder.addColumnNDArray("outputArray", new long[] { 1, 4 });
Schema schema = schemaBuilder.build();
int numRows = 4;
List<List<Writable>> ret = new ArrayList<>(numRows);
for(int i = 0; i < numRows; i++) {
ret.add(Arrays.<Writable>asList(new NDArrayWritable(Nd4j.linspace(1,4,4).reshape(1, 4))));
for (int i = 0; i < numRows; i++) {
ret.add(Arrays.<Writable>asList(new NDArrayWritable(Nd4j.linspace(1, 4, 4).reshape(1, 4))));
}
List<FieldVector> fieldVectors = ArrowConverter.toArrowColumns(bufferAllocator, schema, ret);
ArrowWritableRecordBatch arrowWritableRecordBatch = new ArrowWritableRecordBatch(fieldVectors,schema);
ArrowWritableRecordBatch arrowWritableRecordBatch = new ArrowWritableRecordBatch(fieldVectors, schema);
INDArray array = ArrowConverter.toArray(arrowWritableRecordBatch);
assertArrayEquals(new long[]{4,4},array.shape());
INDArray assertion = Nd4j.repeat(Nd4j.linspace(1,4,4),4).reshape(4,4);
assertEquals(assertion,array);
assertArrayEquals(new long[] { 4, 4 }, array.shape());
INDArray assertion = Nd4j.repeat(Nd4j.linspace(1, 4, 4), 4).reshape(4, 4);
assertEquals(assertion, array);
}
@Test
public void testArrowColumnINDArray() {
@DisplayName("Test Arrow Column IND Array")
void testArrowColumnINDArray() {
Schema.Builder schema = new Schema.Builder();
List<String> single = new ArrayList<>();
int numCols = 2;
INDArray arr = Nd4j.linspace(1,4,4);
for(int i = 0; i < numCols; i++) {
schema.addColumnNDArray(String.valueOf(i),new long[]{1,4});
INDArray arr = Nd4j.linspace(1, 4, 4);
for (int i = 0; i < numCols; i++) {
schema.addColumnNDArray(String.valueOf(i), new long[] { 1, 4 });
single.add(String.valueOf(i));
}
Schema buildSchema = schema.build();
List<List<Writable>> list = new ArrayList<>();
List<Writable> firstRow = new ArrayList<>();
for(int i = 0 ; i < numCols; i++) {
for (int i = 0; i < numCols; i++) {
firstRow.add(new NDArrayWritable(arr));
}
list.add(firstRow);
List<FieldVector> fieldVectors = ArrowConverter.toArrowColumns(bufferAllocator, buildSchema, list);
assertEquals(numCols,fieldVectors.size());
assertEquals(1,fieldVectors.get(0).getValueCount());
assertEquals(numCols, fieldVectors.size());
assertEquals(1, fieldVectors.get(0).getValueCount());
assertFalse(fieldVectors.get(0).isNull(0));
ArrowWritableRecordBatch arrowWritableRecordBatch = ArrowConverter.toArrowWritables(fieldVectors, buildSchema);
assertEquals(1,arrowWritableRecordBatch.size());
assertEquals(1, arrowWritableRecordBatch.size());
Writable writable = arrowWritableRecordBatch.get(0).get(0);
assertTrue(writable instanceof NDArrayWritable);
NDArrayWritable ndArrayWritable = (NDArrayWritable) writable;
assertEquals(arr,ndArrayWritable.get());
assertEquals(arr, ndArrayWritable.get());
Writable writable1 = ArrowConverter.fromEntry(0, fieldVectors.get(0), ColumnType.NDArray);
NDArrayWritable ndArrayWritablewritable1 = (NDArrayWritable) writable1;
System.out.println(ndArrayWritablewritable1.get());
}
@Test
public void testArrowColumnString() {
@DisplayName("Test Arrow Column String")
void testArrowColumnString() {
Schema.Builder schema = new Schema.Builder();
List<String> single = new ArrayList<>();
for(int i = 0; i < 2; i++) {
for (int i = 0; i < 2; i++) {
schema.addColumnInteger(String.valueOf(i));
single.add(String.valueOf(i));
}
List<FieldVector> fieldVectors = ArrowConverter.toArrowColumnsStringSingle(bufferAllocator, schema.build(), single);
List<List<Writable>> records = ArrowConverter.toArrowWritables(fieldVectors, schema.build());
List<List<Writable>> assertion = new ArrayList<>();
assertion.add(Arrays.<Writable>asList(new IntWritable(0),new IntWritable(1)));
assertEquals(assertion,records);
assertion.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(1)));
assertEquals(assertion, records);
List<List<String>> batch = new ArrayList<>();
for(int i = 0; i < 2; i++) {
batch.add(Arrays.asList(String.valueOf(i),String.valueOf(i)));
for (int i = 0; i < 2; i++) {
batch.add(Arrays.asList(String.valueOf(i), String.valueOf(i)));
}
List<FieldVector> fieldVectorsBatch = ArrowConverter.toArrowColumnsString(bufferAllocator, schema.build(), batch);
List<List<Writable>> batchRecords = ArrowConverter.toArrowWritables(fieldVectorsBatch, schema.build());
List<List<Writable>> assertionBatch = new ArrayList<>();
assertionBatch.add(Arrays.<Writable>asList(new IntWritable(0),new IntWritable(0)));
assertionBatch.add(Arrays.<Writable>asList(new IntWritable(1),new IntWritable(1)));
assertEquals(assertionBatch,batchRecords);
assertionBatch.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(0)));
assertionBatch.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1)));
assertEquals(assertionBatch, batchRecords);
}
@Test
public void testArrowBatchSetTime() {
@DisplayName("Test Arrow Batch Set Time")
void testArrowBatchSetTime() {
Schema.Builder schema = new Schema.Builder();
List<String> single = new ArrayList<>();
for(int i = 0; i < 2; i++) {
schema.addColumnTime(String.valueOf(i),TimeZone.getDefault());
for (int i = 0; i < 2; i++) {
schema.addColumnTime(String.valueOf(i), TimeZone.getDefault());
single.add(String.valueOf(i));
}
List<List<Writable>> input = Arrays.asList(
Arrays.<Writable>asList(new LongWritable(0),new LongWritable(1)),
Arrays.<Writable>asList(new LongWritable(2),new LongWritable(3))
);
List<FieldVector> fieldVector = ArrowConverter.toArrowColumns(bufferAllocator,schema.build(),input);
ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector,schema.build());
List<List<Writable>> input = Arrays.asList(Arrays.<Writable>asList(new LongWritable(0), new LongWritable(1)), Arrays.<Writable>asList(new LongWritable(2), new LongWritable(3)));
List<FieldVector> fieldVector = ArrowConverter.toArrowColumns(bufferAllocator, schema.build(), input);
ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector, schema.build());
List<Writable> assertion = Arrays.<Writable>asList(new LongWritable(4), new LongWritable(5));
writableRecordBatch.set(1, Arrays.<Writable>asList(new LongWritable(4),new LongWritable(5)));
writableRecordBatch.set(1, Arrays.<Writable>asList(new LongWritable(4), new LongWritable(5)));
List<Writable> recordTest = writableRecordBatch.get(1);
assertEquals(assertion,recordTest);
assertEquals(assertion, recordTest);
}
@Test
public void testArrowBatchSet() {
@DisplayName("Test Arrow Batch Set")
void testArrowBatchSet() {
Schema.Builder schema = new Schema.Builder();
List<String> single = new ArrayList<>();
for(int i = 0; i < 2; i++) {
for (int i = 0; i < 2; i++) {
schema.addColumnInteger(String.valueOf(i));
single.add(String.valueOf(i));
}
List<List<Writable>> input = Arrays.asList(
Arrays.<Writable>asList(new IntWritable(0),new IntWritable(1)),
Arrays.<Writable>asList(new IntWritable(2),new IntWritable(3))
);
List<FieldVector> fieldVector = ArrowConverter.toArrowColumns(bufferAllocator,schema.build(),input);
ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector,schema.build());
List<List<Writable>> input = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(1)), Arrays.<Writable>asList(new IntWritable(2), new IntWritable(3)));
List<FieldVector> fieldVector = ArrowConverter.toArrowColumns(bufferAllocator, schema.build(), input);
ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector, schema.build());
List<Writable> assertion = Arrays.<Writable>asList(new IntWritable(4), new IntWritable(5));
writableRecordBatch.set(1, Arrays.<Writable>asList(new IntWritable(4),new IntWritable(5)));
writableRecordBatch.set(1, Arrays.<Writable>asList(new IntWritable(4), new IntWritable(5)));
List<Writable> recordTest = writableRecordBatch.get(1);
assertEquals(assertion,recordTest);
assertEquals(assertion, recordTest);
}
@Test
public void testArrowColumnsStringTimeSeries() {
@DisplayName("Test Arrow Columns String Time Series")
void testArrowColumnsStringTimeSeries() {
Schema.Builder schema = new Schema.Builder();
List<List<List<String>>> entries = new ArrayList<>();
for(int i = 0; i < 3; i++) {
for (int i = 0; i < 3; i++) {
schema.addColumnInteger(String.valueOf(i));
}
for(int i = 0; i < 5; i++) {
for (int i = 0; i < 5; i++) {
List<List<String>> arr = Arrays.asList(Arrays.asList(String.valueOf(i), String.valueOf(i), String.valueOf(i)));
entries.add(arr);
}
List<FieldVector> fieldVectors = ArrowConverter.toArrowColumnsStringTimeSeries(bufferAllocator, schema.build(), entries);
assertEquals(3,fieldVectors.size());
assertEquals(5,fieldVectors.get(0).getValueCount());
assertEquals(3, fieldVectors.size());
assertEquals(5, fieldVectors.get(0).getValueCount());
INDArray exp = Nd4j.create(5, 3);
for( int i = 0; i < 5; i++) {
for (int i = 0; i < 5; i++) {
exp.getRow(i).assign(i);
}
//Convert to ArrowWritableRecordBatch - note we can't do this in general with time series...
// Convert to ArrowWritableRecordBatch - note we can't do this in general with time series...
ArrowWritableRecordBatch wri = ArrowConverter.toArrowWritables(fieldVectors, schema.build());
INDArray arr = ArrowConverter.toArray(wri);
assertArrayEquals(new long[] {5,3}, arr.shape());
assertArrayEquals(new long[] { 5, 3 }, arr.shape());
assertEquals(exp, arr);
}
@Test
public void testConvertVector() {
@DisplayName("Test Convert Vector")
void testConvertVector() {
Schema.Builder schema = new Schema.Builder();
List<List<List<String>>> entries = new ArrayList<>();
for(int i = 0; i < 3; i++) {
for (int i = 0; i < 3; i++) {
schema.addColumnInteger(String.valueOf(i));
}
for(int i = 0; i < 5; i++) {
for (int i = 0; i < 5; i++) {
List<List<String>> arr = Arrays.asList(Arrays.asList(String.valueOf(i), String.valueOf(i), String.valueOf(i)));
entries.add(arr);
}
List<FieldVector> fieldVectors = ArrowConverter.toArrowColumnsStringTimeSeries(bufferAllocator, schema.build(), entries);
INDArray arr = ArrowConverter.convertArrowVector(fieldVectors.get(0),schema.build().getType(0));
assertEquals(5,arr.length());
INDArray arr = ArrowConverter.convertArrowVector(fieldVectors.get(0), schema.build().getType(0));
assertEquals(5, arr.length());
}
@Test
public void testCreateNDArray() throws Exception {
@DisplayName("Test Create ND Array")
void testCreateNDArray() throws Exception {
val recordsToWrite = recordToWrite();
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(),recordsToWrite.getFirst(),byteArrayOutputStream);
File f = testDir.newFolder();
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(), recordsToWrite.getFirst(), byteArrayOutputStream);
File f = testDir.toFile();
File tmpFile = new File(f, "tmp-arrow-file-" + UUID.randomUUID().toString() + ".arrorw");
FileOutputStream outputStream = new FileOutputStream(tmpFile);
tmpFile.deleteOnExit();
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(),recordsToWrite.getFirst(),outputStream);
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(), recordsToWrite.getFirst(), outputStream);
outputStream.flush();
outputStream.close();
Pair<Schema, ArrowWritableRecordBatch> schemaArrowWritableRecordBatchPair = ArrowConverter.readFromFile(tmpFile);
assertEquals(recordsToWrite.getFirst(),schemaArrowWritableRecordBatchPair.getFirst());
assertEquals(recordsToWrite.getRight(),schemaArrowWritableRecordBatchPair.getRight().toArrayList());
assertEquals(recordsToWrite.getFirst(), schemaArrowWritableRecordBatchPair.getFirst());
assertEquals(recordsToWrite.getRight(), schemaArrowWritableRecordBatchPair.getRight().toArrayList());
byte[] arr = byteArrayOutputStream.toByteArray();
val read = ArrowConverter.readFromBytes(arr);
assertEquals(recordsToWrite,read);
//send file
File tmp = tmpDataFile(recordsToWrite);
assertEquals(recordsToWrite, read);
// send file
File tmp = tmpDataFile(recordsToWrite);
ArrowRecordReader recordReader = new ArrowRecordReader();
recordReader.initialize(new FileSplit(tmp));
recordReader.next();
ArrowWritableRecordBatch currentBatch = recordReader.getCurrentBatch();
INDArray arr2 = ArrowConverter.toArray(currentBatch);
assertEquals(2,arr2.rows());
assertEquals(2,arr2.columns());
}
@Test
public void testConvertToArrowVectors() {
INDArray matrix = Nd4j.linspace(1,4,4).reshape(2,2);
val vectors = ArrowConverter.convertToArrowVector(matrix,Arrays.asList("test","test2"), ColumnType.Double,bufferAllocator);
assertEquals(matrix.rows(),vectors.size());
INDArray vector = Nd4j.linspace(1,4,4);
val vectors2 = ArrowConverter.convertToArrowVector(vector,Arrays.asList("test"), ColumnType.Double,bufferAllocator);
assertEquals(1,vectors2.size());
assertEquals(matrix.length(),vectors2.get(0).getValueCount());
assertEquals(2, arr2.rows());
assertEquals(2, arr2.columns());
}
@Test
public void testSchemaConversionBasic() {
@DisplayName("Test Convert To Arrow Vectors")
void testConvertToArrowVectors() {
INDArray matrix = Nd4j.linspace(1, 4, 4).reshape(2, 2);
val vectors = ArrowConverter.convertToArrowVector(matrix, Arrays.asList("test", "test2"), ColumnType.Double, bufferAllocator);
assertEquals(matrix.rows(), vectors.size());
INDArray vector = Nd4j.linspace(1, 4, 4);
val vectors2 = ArrowConverter.convertToArrowVector(vector, Arrays.asList("test"), ColumnType.Double, bufferAllocator);
assertEquals(1, vectors2.size());
assertEquals(matrix.length(), vectors2.get(0).getValueCount());
}
@Test
@DisplayName("Test Schema Conversion Basic")
void testSchemaConversionBasic() {
Schema.Builder schemaBuilder = new Schema.Builder();
for(int i = 0; i < 2; i++) {
for (int i = 0; i < 2; i++) {
schemaBuilder.addColumnDouble("test-" + i);
schemaBuilder.addColumnInteger("testi-" + i);
schemaBuilder.addColumnLong("testl-" + i);
schemaBuilder.addColumnFloat("testf-" + i);
}
Schema schema = schemaBuilder.build();
val schema2 = ArrowConverter.toArrowSchema(schema);
assertEquals(8,schema2.getFields().size());
assertEquals(8, schema2.getFields().size());
val convertedSchema = ArrowConverter.toDatavecSchema(schema2);
assertEquals(schema,convertedSchema);
assertEquals(schema, convertedSchema);
}
@Test
public void testReadSchemaAndRecordsFromByteArray() throws Exception {
@DisplayName("Test Read Schema And Records From Byte Array")
void testReadSchemaAndRecordsFromByteArray() throws Exception {
BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
int valueCount = 3;
List<Field> fields = new ArrayList<>();
fields.add(ArrowConverter.field("field1",new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)));
fields.add(ArrowConverter.field("field1", new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)));
fields.add(ArrowConverter.intField("field2"));
List<FieldVector> fieldVectors = new ArrayList<>();
fieldVectors.add(ArrowConverter.vectorFor(allocator,"field1",new float[] {1,2,3}));
fieldVectors.add(ArrowConverter.vectorFor(allocator,"field2",new int[] {1,2,3}));
fieldVectors.add(ArrowConverter.vectorFor(allocator, "field1", new float[] { 1, 2, 3 }));
fieldVectors.add(ArrowConverter.vectorFor(allocator, "field2", new int[] { 1, 2, 3 }));
org.apache.arrow.vector.types.pojo.Schema schema = new org.apache.arrow.vector.types.pojo.Schema(fields);
VectorSchemaRoot schemaRoot1 = new VectorSchemaRoot(schema, fieldVectors, valueCount);
VectorUnloader vectorUnloader = new VectorUnloader(schemaRoot1);
vectorUnloader.getRecordBatch();
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
try(ArrowFileWriter arrowFileWriter = new ArrowFileWriter(schemaRoot1,null,newChannel(byteArrayOutputStream))) {
try (ArrowFileWriter arrowFileWriter = new ArrowFileWriter(schemaRoot1, null, newChannel(byteArrayOutputStream))) {
arrowFileWriter.writeBatch();
} catch (IOException e) {
log.error("",e);
log.error("", e);
}
byte[] arr = byteArrayOutputStream.toByteArray();
val arr2 = ArrowConverter.readFromBytes(arr);
assertEquals(2,arr2.getFirst().numColumns());
assertEquals(3,arr2.getRight().size());
val arrowCols = ArrowConverter.toArrowColumns(allocator,arr2.getFirst(),arr2.getRight());
assertEquals(2,arrowCols.size());
assertEquals(valueCount,arrowCols.get(0).getValueCount());
assertEquals(2, arr2.getFirst().numColumns());
assertEquals(3, arr2.getRight().size());
val arrowCols = ArrowConverter.toArrowColumns(allocator, arr2.getFirst(), arr2.getRight());
assertEquals(2, arrowCols.size());
assertEquals(valueCount, arrowCols.get(0).getValueCount());
}
@Test
public void testVectorForEdgeCases() {
@DisplayName("Test Vector For Edge Cases")
void testVectorForEdgeCases() {
BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
val vector = ArrowConverter.vectorFor(allocator,"field1",new float[]{Float.MIN_VALUE,Float.MAX_VALUE});
assertEquals(Float.MIN_VALUE,vector.get(0),1e-2);
assertEquals(Float.MAX_VALUE,vector.get(1),1e-2);
val vectorInt = ArrowConverter.vectorFor(allocator,"field1",new int[]{Integer.MIN_VALUE,Integer.MAX_VALUE});
assertEquals(Integer.MIN_VALUE,vectorInt.get(0),1e-2);
assertEquals(Integer.MAX_VALUE,vectorInt.get(1),1e-2);
val vector = ArrowConverter.vectorFor(allocator, "field1", new float[] { Float.MIN_VALUE, Float.MAX_VALUE });
assertEquals(Float.MIN_VALUE, vector.get(0), 1e-2);
assertEquals(Float.MAX_VALUE, vector.get(1), 1e-2);
val vectorInt = ArrowConverter.vectorFor(allocator, "field1", new int[] { Integer.MIN_VALUE, Integer.MAX_VALUE });
assertEquals(Integer.MIN_VALUE, vectorInt.get(0), 1e-2);
assertEquals(Integer.MAX_VALUE, vectorInt.get(1), 1e-2);
}
@Test
public void testVectorFor() {
@DisplayName("Test Vector For")
void testVectorFor() {
BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
val vector = ArrowConverter.vectorFor(allocator,"field1",new float[]{1,2,3});
assertEquals(3,vector.getValueCount());
assertEquals(1,vector.get(0),1e-2);
assertEquals(2,vector.get(1),1e-2);
assertEquals(3,vector.get(2),1e-2);
val vectorLong = ArrowConverter.vectorFor(allocator,"field1",new long[]{1,2,3});
assertEquals(3,vectorLong.getValueCount());
assertEquals(1,vectorLong.get(0),1e-2);
assertEquals(2,vectorLong.get(1),1e-2);
assertEquals(3,vectorLong.get(2),1e-2);
val vectorInt = ArrowConverter.vectorFor(allocator,"field1",new int[]{1,2,3});
assertEquals(3,vectorInt.getValueCount());
assertEquals(1,vectorInt.get(0),1e-2);
assertEquals(2,vectorInt.get(1),1e-2);
assertEquals(3,vectorInt.get(2),1e-2);
val vectorDouble = ArrowConverter.vectorFor(allocator,"field1",new double[]{1,2,3});
assertEquals(3,vectorDouble.getValueCount());
assertEquals(1,vectorDouble.get(0),1e-2);
assertEquals(2,vectorDouble.get(1),1e-2);
assertEquals(3,vectorDouble.get(2),1e-2);
val vectorBool = ArrowConverter.vectorFor(allocator,"field1",new boolean[]{true,true,false});
assertEquals(3,vectorBool.getValueCount());
assertEquals(1,vectorBool.get(0),1e-2);
assertEquals(1,vectorBool.get(1),1e-2);
assertEquals(0,vectorBool.get(2),1e-2);
val vector = ArrowConverter.vectorFor(allocator, "field1", new float[] { 1, 2, 3 });
assertEquals(3, vector.getValueCount());
assertEquals(1, vector.get(0), 1e-2);
assertEquals(2, vector.get(1), 1e-2);
assertEquals(3, vector.get(2), 1e-2);
val vectorLong = ArrowConverter.vectorFor(allocator, "field1", new long[] { 1, 2, 3 });
assertEquals(3, vectorLong.getValueCount());
assertEquals(1, vectorLong.get(0), 1e-2);
assertEquals(2, vectorLong.get(1), 1e-2);
assertEquals(3, vectorLong.get(2), 1e-2);
val vectorInt = ArrowConverter.vectorFor(allocator, "field1", new int[] { 1, 2, 3 });
assertEquals(3, vectorInt.getValueCount());
assertEquals(1, vectorInt.get(0), 1e-2);
assertEquals(2, vectorInt.get(1), 1e-2);
assertEquals(3, vectorInt.get(2), 1e-2);
val vectorDouble = ArrowConverter.vectorFor(allocator, "field1", new double[] { 1, 2, 3 });
assertEquals(3, vectorDouble.getValueCount());
assertEquals(1, vectorDouble.get(0), 1e-2);
assertEquals(2, vectorDouble.get(1), 1e-2);
assertEquals(3, vectorDouble.get(2), 1e-2);
val vectorBool = ArrowConverter.vectorFor(allocator, "field1", new boolean[] { true, true, false });
assertEquals(3, vectorBool.getValueCount());
assertEquals(1, vectorBool.get(0), 1e-2);
assertEquals(1, vectorBool.get(1), 1e-2);
assertEquals(0, vectorBool.get(2), 1e-2);
}
@Test
public void testRecordReaderAndWriteFile() throws Exception {
@DisplayName("Test Record Reader And Write File")
void testRecordReaderAndWriteFile() throws Exception {
val recordsToWrite = recordToWrite();
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(),recordsToWrite.getFirst(),byteArrayOutputStream);
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(), recordsToWrite.getFirst(), byteArrayOutputStream);
byte[] arr = byteArrayOutputStream.toByteArray();
val read = ArrowConverter.readFromBytes(arr);
assertEquals(recordsToWrite,read);
//send file
File tmp = tmpDataFile(recordsToWrite);
assertEquals(recordsToWrite, read);
// send file
File tmp = tmpDataFile(recordsToWrite);
RecordReader recordReader = new ArrowRecordReader();
recordReader.initialize(new FileSplit(tmp));
List<Writable> record = recordReader.next();
assertEquals(2,record.size());
assertEquals(2, record.size());
}
@Test
public void testRecordReaderMetaDataList() throws Exception {
@DisplayName("Test Record Reader Meta Data List")
void testRecordReaderMetaDataList() throws Exception {
val recordsToWrite = recordToWrite();
//send file
File tmp = tmpDataFile(recordsToWrite);
// send file
File tmp = tmpDataFile(recordsToWrite);
RecordReader recordReader = new ArrowRecordReader();
RecordMetaDataIndex recordMetaDataIndex = new RecordMetaDataIndex(0,tmp.toURI(),ArrowRecordReader.class);
RecordMetaDataIndex recordMetaDataIndex = new RecordMetaDataIndex(0, tmp.toURI(), ArrowRecordReader.class);
recordReader.loadFromMetaData(Arrays.<RecordMetaData>asList(recordMetaDataIndex));
Record record = recordReader.nextRecord();
assertEquals(2,record.getRecord().size());
assertEquals(2, record.getRecord().size());
}
@Test
public void testDates() {
@DisplayName("Test Dates")
void testDates() {
Date now = new Date();
BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE);
TimeStampMilliVector timeStampMilliVector = ArrowConverter.vectorFor(bufferAllocator, "col1", new Date[]{now});
assertEquals(now.getTime(),timeStampMilliVector.get(0));
TimeStampMilliVector timeStampMilliVector = ArrowConverter.vectorFor(bufferAllocator, "col1", new Date[] { now });
assertEquals(now.getTime(), timeStampMilliVector.get(0));
}
@Test
public void testRecordReaderMetaData() throws Exception {
@DisplayName("Test Record Reader Meta Data")
void testRecordReaderMetaData() throws Exception {
val recordsToWrite = recordToWrite();
//send file
File tmp = tmpDataFile(recordsToWrite);
// send file
File tmp = tmpDataFile(recordsToWrite);
RecordReader recordReader = new ArrowRecordReader();
RecordMetaDataIndex recordMetaDataIndex = new RecordMetaDataIndex(0,tmp.toURI(),ArrowRecordReader.class);
RecordMetaDataIndex recordMetaDataIndex = new RecordMetaDataIndex(0, tmp.toURI(), ArrowRecordReader.class);
recordReader.loadFromMetaData(recordMetaDataIndex);
Record record = recordReader.nextRecord();
assertEquals(2,record.getRecord().size());
assertEquals(2, record.getRecord().size());
}
private File tmpDataFile(Pair<Schema,List<List<Writable>>> recordsToWrite) throws IOException {
File f = testDir.newFolder();
//send file
File tmp = new File(f,"tmp-file-" + UUID.randomUUID().toString());
private File tmpDataFile(Pair<Schema, List<List<Writable>>> recordsToWrite) throws IOException {
File f = testDir.toFile();
// send file
File tmp = new File(f, "tmp-file-" + UUID.randomUUID().toString());
tmp.mkdirs();
File tmpFile = new File(tmp,"data.arrow");
File tmpFile = new File(tmp, "data.arrow");
tmpFile.deleteOnExit();
FileOutputStream bufferedOutputStream = new FileOutputStream(tmpFile);
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(),recordsToWrite.getFirst(),bufferedOutputStream);
ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(), recordsToWrite.getFirst(), bufferedOutputStream);
bufferedOutputStream.flush();
bufferedOutputStream.close();
return tmp;
}
private Pair<Schema,List<List<Writable>>> recordToWrite() {
private Pair<Schema, List<List<Writable>>> recordToWrite() {
List<List<Writable>> records = new ArrayList<>();
records.add(Arrays.<Writable>asList(new DoubleWritable(0.0),new DoubleWritable(0.0)));
records.add(Arrays.<Writable>asList(new DoubleWritable(0.0),new DoubleWritable(0.0)));
records.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new DoubleWritable(0.0)));
records.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new DoubleWritable(0.0)));
Schema.Builder schemaBuilder = new Schema.Builder();
for(int i = 0; i < 2; i++) {
for (int i = 0; i < 2; i++) {
schemaBuilder.addColumnFloat("col-" + i);
}
return Pair.of(schemaBuilder.build(),records);
return Pair.of(schemaBuilder.build(), records);
}
}

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.arrow;
import lombok.val;
@ -34,132 +33,98 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable;
import org.datavec.arrow.recordreader.ArrowRecordReader;
import org.datavec.arrow.recordreader.ArrowRecordWriter;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import org.nd4j.common.primitives.Triple;
import java.io.File;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals;
public class RecordMapperTest extends BaseND4JTest {
@DisplayName("Record Mapper Test")
class RecordMapperTest extends BaseND4JTest {
@Test
public void testMultiWrite() throws Exception {
@DisplayName("Test Multi Write")
void testMultiWrite() throws Exception {
val recordsPair = records();
Path p = Files.createTempFile("arrowwritetest", ".arrow");
FileUtils.write(p.toFile(),recordsPair.getFirst());
FileUtils.write(p.toFile(), recordsPair.getFirst());
p.toFile().deleteOnExit();
int numReaders = 2;
RecordReader[] readers = new RecordReader[numReaders];
InputSplit[] splits = new InputSplit[numReaders];
for(int i = 0; i < readers.length; i++) {
for (int i = 0; i < readers.length; i++) {
FileSplit split = new FileSplit(p.toFile());
ArrowRecordReader arrowRecordReader = new ArrowRecordReader();
readers[i] = arrowRecordReader;
splits[i] = split;
}
ArrowRecordWriter arrowRecordWriter = new ArrowRecordWriter(recordsPair.getMiddle());
FileSplit split = new FileSplit(p.toFile());
arrowRecordWriter.initialize(split,new NumberOfRecordsPartitioner());
arrowRecordWriter.initialize(split, new NumberOfRecordsPartitioner());
arrowRecordWriter.writeBatch(recordsPair.getRight());
CSVRecordWriter csvRecordWriter = new CSVRecordWriter();
Path p2 = Files.createTempFile("arrowwritetest", ".csv");
FileUtils.write(p2.toFile(),recordsPair.getFirst());
FileUtils.write(p2.toFile(), recordsPair.getFirst());
p.toFile().deleteOnExit();
FileSplit outputCsv = new FileSplit(p2.toFile());
RecordMapper mapper = RecordMapper.builder().batchSize(10).inputUrl(split)
.outputUrl(outputCsv)
.partitioner(new NumberOfRecordsPartitioner()).readersToConcat(readers)
.splitPerReader(splits)
.recordWriter(csvRecordWriter)
.build();
RecordMapper mapper = RecordMapper.builder().batchSize(10).inputUrl(split).outputUrl(outputCsv).partitioner(new NumberOfRecordsPartitioner()).readersToConcat(readers).splitPerReader(splits).recordWriter(csvRecordWriter).build();
mapper.copy();
}
@Test
public void testCopyFromArrowToCsv() throws Exception {
@DisplayName("Test Copy From Arrow To Csv")
void testCopyFromArrowToCsv() throws Exception {
val recordsPair = records();
Path p = Files.createTempFile("arrowwritetest", ".arrow");
FileUtils.write(p.toFile(),recordsPair.getFirst());
FileUtils.write(p.toFile(), recordsPair.getFirst());
p.toFile().deleteOnExit();
ArrowRecordWriter arrowRecordWriter = new ArrowRecordWriter(recordsPair.getMiddle());
FileSplit split = new FileSplit(p.toFile());
arrowRecordWriter.initialize(split,new NumberOfRecordsPartitioner());
arrowRecordWriter.initialize(split, new NumberOfRecordsPartitioner());
arrowRecordWriter.writeBatch(recordsPair.getRight());
ArrowRecordReader arrowRecordReader = new ArrowRecordReader();
arrowRecordReader.initialize(split);
CSVRecordWriter csvRecordWriter = new CSVRecordWriter();
Path p2 = Files.createTempFile("arrowwritetest", ".csv");
FileUtils.write(p2.toFile(),recordsPair.getFirst());
FileUtils.write(p2.toFile(), recordsPair.getFirst());
p.toFile().deleteOnExit();
FileSplit outputCsv = new FileSplit(p2.toFile());
RecordMapper mapper = RecordMapper.builder().batchSize(10).inputUrl(split)
.outputUrl(outputCsv)
.partitioner(new NumberOfRecordsPartitioner())
.recordReader(arrowRecordReader).recordWriter(csvRecordWriter)
.build();
RecordMapper mapper = RecordMapper.builder().batchSize(10).inputUrl(split).outputUrl(outputCsv).partitioner(new NumberOfRecordsPartitioner()).recordReader(arrowRecordReader).recordWriter(csvRecordWriter).build();
mapper.copy();
CSVRecordReader recordReader = new CSVRecordReader();
recordReader.initialize(outputCsv);
List<List<Writable>> loadedCSvRecords = recordReader.next(10);
assertEquals(10,loadedCSvRecords.size());
assertEquals(10, loadedCSvRecords.size());
}
@Test
public void testCopyFromCsvToArrow() throws Exception {
@DisplayName("Test Copy From Csv To Arrow")
void testCopyFromCsvToArrow() throws Exception {
val recordsPair = records();
Path p = Files.createTempFile("csvwritetest", ".csv");
FileUtils.write(p.toFile(),recordsPair.getFirst());
FileUtils.write(p.toFile(), recordsPair.getFirst());
p.toFile().deleteOnExit();
CSVRecordReader recordReader = new CSVRecordReader();
FileSplit fileSplit = new FileSplit(p.toFile());
ArrowRecordWriter arrowRecordWriter = new ArrowRecordWriter(recordsPair.getMiddle());
File outputFile = Files.createTempFile("outputarrow","arrow").toFile();
File outputFile = Files.createTempFile("outputarrow", "arrow").toFile();
FileSplit outputFileSplit = new FileSplit(outputFile);
RecordMapper mapper = RecordMapper.builder().batchSize(10).inputUrl(fileSplit)
.outputUrl(outputFileSplit).partitioner(new NumberOfRecordsPartitioner())
.recordReader(recordReader).recordWriter(arrowRecordWriter)
.build();
RecordMapper mapper = RecordMapper.builder().batchSize(10).inputUrl(fileSplit).outputUrl(outputFileSplit).partitioner(new NumberOfRecordsPartitioner()).recordReader(recordReader).recordWriter(arrowRecordWriter).build();
mapper.copy();
ArrowRecordReader arrowRecordReader = new ArrowRecordReader();
arrowRecordReader.initialize(outputFileSplit);
List<List<Writable>> next = arrowRecordReader.next(10);
System.out.println(next);
assertEquals(10,next.size());
assertEquals(10, next.size());
}
private Triple<String,Schema,List<List<Writable>>> records() {
private Triple<String, Schema, List<List<Writable>>> records() {
List<List<Writable>> list = new ArrayList<>();
StringBuilder sb = new StringBuilder();
int numColumns = 3;
@ -176,15 +141,10 @@ public class RecordMapperTest extends BaseND4JTest {
}
list.add(temp);
}
Schema.Builder schemaBuilder = new Schema.Builder();
for(int i = 0; i < numColumns; i++) {
for (int i = 0; i < numColumns; i++) {
schemaBuilder.addColumnInteger(String.valueOf(i));
}
return Triple.of(sb.toString(),schemaBuilder.build(),list);
return Triple.of(sb.toString(), schemaBuilder.build(), list);
}
}

View File

@ -29,16 +29,16 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.datavec.arrow.ArrowConverter;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest {
@ -46,6 +46,7 @@ public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest {
@Test
@Disabled
public void testBasicIndexing() {
Schema.Builder schema = new Schema.Builder();
for(int i = 0; i < 3; i++) {
@ -54,9 +55,9 @@ public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest {
List<List<Writable>> timeStep = Arrays.asList(
Arrays.<Writable>asList(new IntWritable(0),new IntWritable(1),new IntWritable(2)),
Arrays.<Writable>asList(new IntWritable(1),new IntWritable(2),new IntWritable(3)),
Arrays.<Writable>asList(new IntWritable(4),new IntWritable(5),new IntWritable(6))
Arrays.asList(new IntWritable(0),new IntWritable(1),new IntWritable(2)),
Arrays.asList(new IntWritable(1),new IntWritable(2),new IntWritable(3)),
Arrays.asList(new IntWritable(4),new IntWritable(5),new IntWritable(6))
);
int numTimeSteps = 5;
@ -69,7 +70,7 @@ public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest {
assertEquals(3,fieldVectors.size());
for(FieldVector fieldVector : fieldVectors) {
for(int i = 0; i < fieldVector.getValueCount(); i++) {
assertFalse("Index " + i + " was null for field vector " + fieldVector, fieldVector.isNull(i));
assertFalse( fieldVector.isNull(i),"Index " + i + " was null for field vector " + fieldVector);
}
}
@ -79,7 +80,7 @@ public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest {
@Test
//not worried about this till after next release
@Ignore
@Disabled
public void testVariableLengthTS() {
Schema.Builder schema = new Schema.Builder()
.addColumnString("str")

View File

@ -119,10 +119,10 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
</profile>
</profiles>
</project>

View File

@ -17,41 +17,39 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.image;
import org.apache.commons.io.FileUtils;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.image.recordreader.ImageRecordReader;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.common.io.ClassPathResource;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.io.ClassPathResource;
import java.io.File;
import java.util.Arrays;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
@DisplayName("Label Generator Test")
class LabelGeneratorTest {
public class LabelGeneratorTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test
public void testParentPathLabelGenerator() throws Exception {
//https://github.com/deeplearning4j/DataVec/issues/273
@DisplayName("Test Parent Path Label Generator")
@Disabled
void testParentPathLabelGenerator(@TempDir Path testDir) throws Exception {
File orig = new ClassPathResource("datavec-data-image/testimages/class0/0.jpg").getFile();
for(String dirPrefix : new String[]{"m.", "m"}) {
File f = testDir.newFolder();
for (String dirPrefix : new String[] { "m.", "m" }) {
File f = testDir.toFile();
int numDirs = 3;
int filesPerDir = 4;
for (int i = 0; i < numDirs; i++) {
File currentLabelDir = new File(f, dirPrefix + i);
currentLabelDir.mkdirs();
@ -61,14 +59,11 @@ public class LabelGeneratorTest {
assertTrue(f3.exists());
}
}
ImageRecordReader rr = new ImageRecordReader(28, 28, 1, new ParentPathLabelGenerator());
rr.initialize(new FileSplit(f));
List<String> labelsAct = rr.getLabels();
List<String> labelsExp = Arrays.asList(dirPrefix + "0", dirPrefix + "1", dirPrefix + "2");
assertEquals(labelsExp, labelsAct);
int expCount = numDirs * filesPerDir;
int actCount = 0;
while (rr.hasNext()) {

View File

@ -22,8 +22,8 @@ package org.datavec.image.loader;
import org.apache.commons.io.FilenameUtils;
import org.datavec.api.records.reader.RecordReader;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.dataset.DataSet;
import java.io.File;
@ -32,9 +32,9 @@ import java.io.InputStream;
import java.util.List;
import java.util.Random;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
/**
*
@ -182,7 +182,7 @@ public class LoaderTests {
}
@Ignore // Use when confirming data is getting stored
@Disabled // Use when confirming data is getting stored
@Test
public void testProcessCifar() {
int row = 32;
@ -208,15 +208,15 @@ public class LoaderTests {
int minibatch = 100;
int nMinibatches = 50000 / minibatch;
for( int i=0; i<nMinibatches; i++ ){
for( int i=0; i < nMinibatches; i++) {
DataSet ds = loader.next(minibatch);
String s = String.valueOf(i);
assertNotNull(s, ds.getFeatures());
assertNotNull(s, ds.getLabels());
assertNotNull(ds.getFeatures(),s);
assertNotNull(ds.getLabels(),s);
assertEquals(s, minibatch, ds.getFeatures().size(0));
assertEquals(s, minibatch, ds.getLabels().size(0));
assertEquals(s, 10, ds.getLabels().size(1));
assertEquals(minibatch, ds.getFeatures().size(0),s);
assertEquals(minibatch, ds.getLabels().size(0),s);
assertEquals(10, ds.getLabels().size(1),s);
}
}

View File

@ -21,7 +21,7 @@
package org.datavec.image.loader;
import org.datavec.image.data.Image;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.resources.Resources;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -32,7 +32,7 @@ import java.io.FileInputStream;
import java.io.InputStream;
import java.util.Random;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestImageLoader {

View File

@ -30,9 +30,10 @@ import org.bytedeco.javacv.Java2DFrameConverter;
import org.bytedeco.javacv.OpenCVFrameConverter;
import org.datavec.image.data.Image;
import org.datavec.image.data.ImageWritable;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.resources.Resources;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -42,16 +43,17 @@ import org.nd4j.common.io.ClassPathResource;
import java.awt.image.BufferedImage;
import java.io.*;
import java.lang.reflect.Field;
import java.nio.file.Path;
import java.util.Random;
import org.bytedeco.leptonica.*;
import org.bytedeco.opencv.opencv_core.*;
import static org.bytedeco.leptonica.global.lept.*;
import static org.bytedeco.opencv.global.opencv_core.*;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
/**
*
@ -62,8 +64,6 @@ public class TestNativeImageLoader {
static final long seed = 10;
static final Random rng = new Random(seed);
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test
public void testConvertPix() throws Exception {
@ -566,8 +566,8 @@ public class TestNativeImageLoader {
@Test
public void testNativeImageLoaderEmptyStreams() throws Exception {
File dir = testDir.newFolder();
public void testNativeImageLoaderEmptyStreams(@TempDir Path testDir) throws Exception {
File dir = testDir.toFile();
File f = new File(dir, "myFile.jpg");
f.createNewFile();
@ -578,7 +578,7 @@ public class TestNativeImageLoader {
fail("Expected exception");
} catch (IOException e){
String msg = e.getMessage();
assertTrue(msg, msg.contains("decode image"));
assertTrue(msg.contains("decode image"),msg);
}
try(InputStream is = new FileInputStream(f)){
@ -586,7 +586,7 @@ public class TestNativeImageLoader {
fail("Expected exception");
} catch (IOException e){
String msg = e.getMessage();
assertTrue(msg, msg.contains("decode image"));
assertTrue(msg.contains("decode image"),msg);
}
try(InputStream is = new FileInputStream(f)){
@ -594,7 +594,7 @@ public class TestNativeImageLoader {
fail("Expected exception");
} catch (IOException e){
String msg = e.getMessage();
assertTrue(msg, msg.contains("decode image"));
assertTrue(msg.contains("decode image"),msg);
}
try(InputStream is = new FileInputStream(f)){
@ -603,7 +603,7 @@ public class TestNativeImageLoader {
fail("Expected exception");
} catch (IOException e){
String msg = e.getMessage();
assertTrue(msg, msg.contains("decode image"));
assertTrue( msg.contains("decode image"),msg);
}
}

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.image.recordreader;
import org.apache.commons.io.FileUtils;
@ -28,61 +27,56 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.datavec.image.loader.NativeImageLoader;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.loader.FileBatch;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.io.ClassPathResource;
import java.io.File;
import java.util.*;
import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.*;
@DisplayName("File Batch Record Reader Test")
class FileBatchRecordReaderTest {
public class FileBatchRecordReaderTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@TempDir
public Path testDir;
@Test
public void testCsv() throws Exception {
File extractedSourceDir = testDir.newFolder();
@DisplayName("Test Csv")
void testCsv(@TempDir Path testDir,@TempDir Path baseDirPath) throws Exception {
File extractedSourceDir = testDir.toFile();
new ClassPathResource("datavec-data-image/testimages").copyDirectory(extractedSourceDir);
File baseDir = testDir.newFolder();
File baseDir = baseDirPath.toFile();
List<File> c = new ArrayList<>(FileUtils.listFiles(extractedSourceDir, null, true));
assertEquals(6, c.size());
Collections.sort(c, new Comparator<File>() {
@Override
public int compare(File o1, File o2) {
return o1.getPath().compareTo(o2.getPath());
}
});
FileBatch fb = FileBatch.forFiles(c);
File saveFile = new File(baseDir, "saved.zip");
fb.writeAsZip(saveFile);
fb = FileBatch.readFromZip(saveFile);
PathLabelGenerator labelMaker = new ParentPathLabelGenerator();
ImageRecordReader rr = new ImageRecordReader(32, 32, 1, labelMaker);
rr.setLabels(Arrays.asList("class0", "class1"));
FileBatchRecordReader fbrr = new FileBatchRecordReader(rr, fb);
NativeImageLoader il = new NativeImageLoader(32, 32, 1);
for( int test=0; test<3; test++) {
for (int test = 0; test < 3; test++) {
for (int i = 0; i < 6; i++) {
assertTrue(fbrr.hasNext());
List<Writable> next = fbrr.next();
assertEquals(2, next.size());
INDArray exp;
switch (i){
switch(i) {
case 0:
exp = il.asMatrix(new File(extractedSourceDir, "class0/0.jpg"));
break;
@ -105,8 +99,7 @@ public class FileBatchRecordReaderTest {
throw new RuntimeException();
}
Writable expLabel = (i < 3 ? new IntWritable(0) : new IntWritable(1));
assertEquals(((NDArrayWritable)next.get(0)).get(), exp);
assertEquals(((NDArrayWritable) next.get(0)).get(), exp);
assertEquals(expLabel, next.get(1));
}
assertFalse(fbrr.hasNext());
@ -114,5 +107,4 @@ public class FileBatchRecordReaderTest {
fbrr.reset();
}
}
}

View File

@ -36,9 +36,10 @@ import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.datavec.api.writable.batch.NDArrayRecordBatch;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -46,28 +47,30 @@ import org.nd4j.common.io.ClassPathResource;
import java.io.*;
import java.net.URI;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import static org.junit.Assert.*;
import static org.junit.jupiter.api.Assertions.*;
public class TestImageRecordReader {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test(expected = IllegalArgumentException.class)
@Test()
public void testEmptySplit() throws IOException {
InputSplit data = new CollectionInputSplit(new ArrayList<URI>());
new ImageRecordReader().initialize(data, null);
assertThrows(IllegalArgumentException.class,() -> {
InputSplit data = new CollectionInputSplit(new ArrayList<>());
new ImageRecordReader().initialize(data, null);
});
}
@Test
public void testMetaData() throws IOException {
public void testMetaData(@TempDir Path testDir) throws IOException {
File parentDir = testDir.newFolder();
File parentDir = testDir.toFile();
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(parentDir);
// System.out.println(f.getAbsolutePath());
// System.out.println(f.getParentFile().getParentFile().getAbsolutePath());
@ -104,11 +107,11 @@ public class TestImageRecordReader {
}
@Test
public void testImageRecordReaderLabelsOrder() throws Exception {
public void testImageRecordReaderLabelsOrder(@TempDir Path testDir) throws Exception {
//Labels order should be consistent, regardless of file iteration order
//Idea: labels order should be consistent regardless of input file order
File f = testDir.newFolder();
File f = testDir.toFile();
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f);
File f0 = new File(f, "/class0/0.jpg");
File f1 = new File(f, "/class1/A.jpg");
@ -135,11 +138,11 @@ public class TestImageRecordReader {
@Test
public void testImageRecordReaderRandomization() throws Exception {
public void testImageRecordReaderRandomization(@TempDir Path testDir) throws Exception {
//Order of FileSplit+ImageRecordReader should be different after reset
//Idea: labels order should be consistent regardless of input file order
File f0 = testDir.newFolder();
File f0 = testDir.toFile();
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f0);
FileSplit fs = new FileSplit(f0, new Random(12345));
@ -189,13 +192,13 @@ public class TestImageRecordReader {
@Test
public void testImageRecordReaderRegression() throws Exception {
public void testImageRecordReaderRegression(@TempDir Path testDir) throws Exception {
PathLabelGenerator regressionLabelGen = new TestRegressionLabelGen();
ImageRecordReader rr = new ImageRecordReader(28, 28, 3, regressionLabelGen);
File rootDir = testDir.newFolder();
File rootDir = testDir.toFile();
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(rootDir);
FileSplit fs = new FileSplit(rootDir);
rr.initialize(fs);
@ -244,10 +247,10 @@ public class TestImageRecordReader {
}
@Test
public void testListenerInvocationBatch() throws IOException {
public void testListenerInvocationBatch(@TempDir Path testDir) throws IOException {
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
ImageRecordReader rr = new ImageRecordReader(32, 32, 3, labelMaker);
File f = testDir.newFolder();
File f = testDir.toFile();
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f);
File parent = f;
@ -260,10 +263,10 @@ public class TestImageRecordReader {
}
@Test
public void testListenerInvocationSingle() throws IOException {
public void testListenerInvocationSingle(@TempDir Path testDir) throws IOException {
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
ImageRecordReader rr = new ImageRecordReader(32, 32, 3, labelMaker);
File parent = testDir.newFolder();
File parent = testDir.toFile();
new ClassPathResource("datavec-data-image/testimages/class0/").copyDirectory(parent);
int numFiles = parent.list().length;
rr.initialize(new FileSplit(parent));
@ -315,7 +318,7 @@ public class TestImageRecordReader {
@Test
public void testImageRecordReaderPathMultiLabelGenerator() throws Exception {
public void testImageRecordReaderPathMultiLabelGenerator(@TempDir Path testDir) throws Exception {
Nd4j.setDataType(DataType.FLOAT);
//Assumption: 2 multi-class (one hot) classification labels: 2 and 3 classes respectively
// PLUS single value (Writable) regression label
@ -324,7 +327,7 @@ public class TestImageRecordReader {
ImageRecordReader rr = new ImageRecordReader(28, 28, 3, multiLabelGen);
File rootDir = testDir.newFolder();
File rootDir = testDir.toFile();
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(rootDir);
FileSplit fs = new FileSplit(rootDir);
rr.initialize(fs);
@ -471,9 +474,9 @@ public class TestImageRecordReader {
@Test
public void testNCHW_NCHW() throws Exception {
public void testNCHW_NCHW(@TempDir Path testDir) throws Exception {
//Idea: labels order should be consistent regardless of input file order
File f0 = testDir.newFolder();
File f0 = testDir.toFile();
new ClassPathResource("datavec-data-image/testimages/").copyDirectory(f0);
FileSplit fs0 = new FileSplit(f0, new Random(12345));

View File

@ -35,9 +35,10 @@ import org.datavec.image.transform.FlipImageTransform;
import org.datavec.image.transform.ImageTransform;
import org.datavec.image.transform.PipelineImageTransform;
import org.datavec.image.transform.ResizeImageTransform;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
@ -46,24 +47,24 @@ import org.nd4j.common.io.ClassPathResource;
import java.io.File;
import java.net.URI;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static org.junit.Assert.*;
import static org.junit.jupiter.api.Assertions.*;
public class TestObjectDetectionRecordReader {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test
public void test() throws Exception {
public void test(@TempDir Path testDir) throws Exception {
for(boolean nchw : new boolean[]{true, false}) {
ImageObjectLabelProvider lp = new TestImageObjectDetectionLabelProvider();
File f = testDir.newFolder();
File f = testDir.toFile();
new ClassPathResource("datavec-data-image/objdetect/").copyDirectory(f);
String path = new File(f, "000012.jpg").getParent();

View File

@ -21,27 +21,27 @@
package org.datavec.image.recordreader.objdetect;
import org.datavec.image.recordreader.objdetect.impl.VocLabelProvider;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.io.ClassPathResource;
import java.io.File;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestVocLabelProvider {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test
public void testVocLabelProvider() throws Exception {
public void testVocLabelProvider(@TempDir Path testDir) throws Exception {
File f = testDir.newFolder();
File f = testDir.toFile();
new ClassPathResource("datavec-data-image/voc/2007/").copyDirectory(f);
String path = f.getAbsolutePath(); //new ClassPathResource("voc/2007/JPEGImages/000005.jpg").getFile().getParentFile().getParent();

View File

@ -17,106 +17,70 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.image.transform;
import org.datavec.image.data.ImageWritable;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import java.io.IOException;
import java.util.List;
import java.util.Random;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
@DisplayName("Json Yaml Test")
class JsonYamlTest {
public class JsonYamlTest {
@Test
public void testJsonYamlImageTransformProcess() throws IOException {
@DisplayName("Test Json Yaml Image Transform Process")
void testJsonYamlImageTransformProcess() throws IOException {
int seed = 12345;
Random random = new Random(seed);
//from org.bytedeco.javacpp.opencv_imgproc
// from org.bytedeco.javacpp.opencv_imgproc
int COLOR_BGR2Luv = 50;
int CV_BGR2GRAY = 6;
ImageTransformProcess itp = new ImageTransformProcess.Builder().colorConversionTransform(COLOR_BGR2Luv)
.cropImageTransform(10).equalizeHistTransform(CV_BGR2GRAY).flipImageTransform(0)
.resizeImageTransform(300, 300).rotateImageTransform(30).scaleImageTransform(3)
.warpImageTransform((float) 0.5)
// Note : since randomCropTransform use random value
// the results from each case(json, yaml, ImageTransformProcess)
// can be different
// don't use the below line
// if you uncomment it, you will get fail from below assertions
// .randomCropTransform(seed, 50, 50)
// Note : you will get "java.lang.NoClassDefFoundError: Could not initialize class org.bytedeco.javacpp.avutil"
// it needs to add the below dependency
// <dependency>
// <groupId>org.bytedeco</groupId>
// <artifactId>ffmpeg-platform</artifactId>
// </dependency>
// FFmpeg has license issues, be careful to use it
//.filterImageTransform("noise=alls=20:allf=t+u,format=rgba", 100, 100, 4)
.build();
ImageTransformProcess itp = new ImageTransformProcess.Builder().colorConversionTransform(COLOR_BGR2Luv).cropImageTransform(10).equalizeHistTransform(CV_BGR2GRAY).flipImageTransform(0).resizeImageTransform(300, 300).rotateImageTransform(30).scaleImageTransform(3).warpImageTransform((float) 0.5).build();
String asJson = itp.toJson();
String asYaml = itp.toYaml();
// System.out.println(asJson);
// System.out.println("\n\n\n");
// System.out.println(asYaml);
// System.out.println(asJson);
// System.out.println("\n\n\n");
// System.out.println(asYaml);
ImageWritable img = TestImageTransform.makeRandomImage(0, 0, 3);
ImageWritable imgJson = new ImageWritable(img.getFrame().clone());
ImageWritable imgYaml = new ImageWritable(img.getFrame().clone());
ImageWritable imgAll = new ImageWritable(img.getFrame().clone());
ImageTransformProcess itpFromJson = ImageTransformProcess.fromJson(asJson);
ImageTransformProcess itpFromYaml = ImageTransformProcess.fromYaml(asYaml);
List<ImageTransform> transformList = itp.getTransformList();
List<ImageTransform> transformListJson = itpFromJson.getTransformList();
List<ImageTransform> transformListYaml = itpFromYaml.getTransformList();
for (int i = 0; i < transformList.size(); i++) {
ImageTransform it = transformList.get(i);
ImageTransform itJson = transformListJson.get(i);
ImageTransform itYaml = transformListYaml.get(i);
System.out.println(i + "\t" + it);
img = it.transform(img);
imgJson = itJson.transform(imgJson);
imgYaml = itYaml.transform(imgYaml);
if (it instanceof RandomCropTransform) {
assertTrue(img.getFrame().imageHeight == imgJson.getFrame().imageHeight);
assertTrue(img.getFrame().imageWidth == imgJson.getFrame().imageWidth);
assertTrue(img.getFrame().imageHeight == imgYaml.getFrame().imageHeight);
assertTrue(img.getFrame().imageWidth == imgYaml.getFrame().imageWidth);
} else if (it instanceof FilterImageTransform) {
assertEquals(img.getFrame().imageHeight, imgJson.getFrame().imageHeight);
assertEquals(img.getFrame().imageWidth, imgJson.getFrame().imageWidth);
assertEquals(img.getFrame().imageChannels, imgJson.getFrame().imageChannels);
assertEquals(img.getFrame().imageHeight, imgYaml.getFrame().imageHeight);
assertEquals(img.getFrame().imageWidth, imgYaml.getFrame().imageWidth);
assertEquals(img.getFrame().imageChannels, imgYaml.getFrame().imageChannels);
} else {
assertEquals(img, imgJson);
assertEquals(img, imgYaml);
}
}
imgAll = itp.execute(imgAll);
assertEquals(imgAll, img);
}
}

View File

@ -17,56 +17,50 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.image.transform;
import org.bytedeco.javacv.Frame;
import org.datavec.image.data.ImageWritable;
import org.junit.Before;
import org.junit.Test;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals;
public class ResizeImageTransformTest {
@Before
public void setUp() throws Exception {
@DisplayName("Resize Image Transform Test")
class ResizeImageTransformTest {
@BeforeEach
void setUp() throws Exception {
}
@Test
public void testResizeUpscale1() throws Exception {
@DisplayName("Test Resize Upscale 1")
void testResizeUpscale1() throws Exception {
ImageWritable srcImg = TestImageTransform.makeRandomImage(32, 32, 3);
ResizeImageTransform transform = new ResizeImageTransform(200, 200);
ImageWritable dstImg = transform.transform(srcImg);
Frame f = dstImg.getFrame();
assertEquals(f.imageWidth, 200);
assertEquals(f.imageHeight, 200);
float[] coordinates = {100, 200};
float[] coordinates = { 100, 200 };
float[] transformed = transform.query(coordinates);
assertEquals(200f * 100 / 32, transformed[0], 0);
assertEquals(200f * 200 / 32, transformed[1], 0);
}
@Test
public void testResizeDownscale() throws Exception {
@DisplayName("Test Resize Downscale")
void testResizeDownscale() throws Exception {
ImageWritable srcImg = TestImageTransform.makeRandomImage(571, 443, 3);
ResizeImageTransform transform = new ResizeImageTransform(200, 200);
ImageWritable dstImg = transform.transform(srcImg);
Frame f = dstImg.getFrame();
assertEquals(f.imageWidth, 200);
assertEquals(f.imageHeight, 200);
float[] coordinates = {300, 400};
float[] coordinates = { 300, 400 };
float[] transformed = transform.query(coordinates);
assertEquals(200f * 300 / 443, transformed[0], 0);
assertEquals(200f * 400 / 571, transformed[1], 0);
}
}

View File

@ -28,8 +28,8 @@ import org.nd4j.common.io.ClassPathResource;
import org.nd4j.common.primitives.Pair;
import org.datavec.image.data.ImageWritable;
import org.datavec.image.loader.NativeImageLoader;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import java.awt.*;
import java.util.LinkedList;
@ -40,7 +40,7 @@ import org.bytedeco.opencv.opencv_core.*;
import static org.bytedeco.opencv.global.opencv_core.*;
import static org.bytedeco.opencv.global.opencv_imgproc.*;
import static org.junit.Assert.*;
import static org.junit.jupiter.api.Assertions.*;
/**
*
@ -255,7 +255,7 @@ public class TestImageTransform {
assertEquals(22, transformed[1], 0);
}
@Ignore
@Disabled
@Test
public void testFilterImageTransform() throws Exception {
ImageWritable writable = makeRandomImage(0, 0, 4);

View File

@ -59,10 +59,10 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
</profile>
</profiles>
</project>

View File

@ -57,10 +57,10 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
</profile>
</profiles>
</project>

View File

@ -17,37 +17,34 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.poi.excel;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.io.ClassPathResource;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
public class ExcelRecordReaderTest {
@DisplayName("Excel Record Reader Test")
class ExcelRecordReaderTest {
@Test
public void testSimple() throws Exception {
@DisplayName("Test Simple")
void testSimple() throws Exception {
RecordReader excel = new ExcelRecordReader();
excel.initialize(new FileSplit(new ClassPathResource("datavec-excel/testsheet.xlsx").getFile()));
assertTrue(excel.hasNext());
List<Writable> next = excel.next();
assertEquals(3,next.size());
assertEquals(3, next.size());
RecordReader headerReader = new ExcelRecordReader(1);
headerReader.initialize(new FileSplit(new ClassPathResource("datavec-excel/testsheetheader.xlsx").getFile()));
assertTrue(excel.hasNext());
List<Writable> next2 = excel.next();
assertEquals(3,next2.size());
assertEquals(3, next2.size());
}
}

View File

@ -17,7 +17,6 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.poi.excel;
import lombok.val;
@ -26,44 +25,45 @@ import org.datavec.api.split.partition.NumberOfRecordsPartitioner;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.common.primitives.Triple;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.nd4j.common.primitives.Triple;
import java.io.File;
import java.util.ArrayList;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.Assert.assertEquals;
@DisplayName("Excel Record Writer Test")
class ExcelRecordWriterTest {
public class ExcelRecordWriterTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@TempDir
public Path testDir;
@Test
public void testWriter() throws Exception {
@DisplayName("Test Writer")
void testWriter() throws Exception {
ExcelRecordWriter excelRecordWriter = new ExcelRecordWriter();
val records = records();
File tmpDir = testDir.newFolder();
File outputFile = new File(tmpDir,"testexcel.xlsx");
File tmpDir = testDir.toFile();
File outputFile = new File(tmpDir, "testexcel.xlsx");
outputFile.deleteOnExit();
FileSplit fileSplit = new FileSplit(outputFile);
excelRecordWriter.initialize(fileSplit,new NumberOfRecordsPartitioner());
excelRecordWriter.initialize(fileSplit, new NumberOfRecordsPartitioner());
excelRecordWriter.writeBatch(records.getRight());
excelRecordWriter.close();
File parentFile = outputFile.getParentFile();
assertEquals(1,parentFile.list().length);
assertEquals(1, parentFile.list().length);
ExcelRecordReader excelRecordReader = new ExcelRecordReader();
excelRecordReader.initialize(fileSplit);
List<List<Writable>> next = excelRecordReader.next(10);
assertEquals(10,next.size());
assertEquals(10, next.size());
}
private Triple<String,Schema,List<List<Writable>>> records() {
private Triple<String, Schema, List<List<Writable>>> records() {
List<List<Writable>> list = new ArrayList<>();
StringBuilder sb = new StringBuilder();
int numColumns = 3;
@ -80,13 +80,10 @@ public class ExcelRecordWriterTest {
}
list.add(temp);
}
Schema.Builder schemaBuilder = new Schema.Builder();
for(int i = 0; i < numColumns; i++) {
for (int i = 0; i < numColumns; i++) {
schemaBuilder.addColumnInteger(String.valueOf(i));
}
return Triple.of(sb.toString(),schemaBuilder.build(),list);
return Triple.of(sb.toString(), schemaBuilder.build(), list);
}
}

View File

@ -65,10 +65,10 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
</profile>
</profiles>
</project>

View File

@ -17,14 +17,12 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.api.records.reader.impl;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.io.File;
import java.net.URI;
import java.sql.Connection;
@ -49,53 +47,57 @@ import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
public class JDBCRecordReaderTest {
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.junit.jupiter.api.DisplayName;
import java.nio.file.Path;
import org.junit.jupiter.api.extension.ExtendWith;
import static org.junit.jupiter.api.Assertions.assertThrows;
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@DisplayName("Jdbc Record Reader Test")
class JDBCRecordReaderTest {
@TempDir
public Path testDir;
Connection conn;
EmbeddedDataSource dataSource;
private final String dbName = "datavecTests";
private final String driverClassName = "org.apache.derby.jdbc.EmbeddedDriver";
@Before
public void setUp() throws Exception {
File f = testDir.newFolder();
@BeforeEach
void setUp() throws Exception {
File f = testDir.toFile();
System.setProperty("derby.system.home", f.getAbsolutePath());
dataSource = new EmbeddedDataSource();
dataSource.setDatabaseName(dbName);
dataSource.setCreateDatabase("create");
conn = dataSource.getConnection();
TestDb.dropTables(conn);
TestDb.buildCoffeeTable(conn);
}
@After
public void tearDown() throws Exception {
@AfterEach
void tearDown() throws Exception {
DbUtils.closeQuietly(conn);
}
@Test
public void testSimpleIter() throws Exception {
@DisplayName("Test Simple Iter")
void testSimpleIter() throws Exception {
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
List<List<Writable>> records = new ArrayList<>();
while (reader.hasNext()) {
List<Writable> values = reader.next();
records.add(values);
}
assertFalse(records.isEmpty());
List<Writable> first = records.get(0);
assertEquals(new Text("Bolivian Dark"), first.get(0));
assertEquals(new Text("14-001"), first.get(1));
@ -104,39 +106,43 @@ public class JDBCRecordReaderTest {
}
@Test
public void testSimpleWithListener() throws Exception {
@DisplayName("Test Simple With Listener")
void testSimpleWithListener() throws Exception {
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
RecordListener recordListener = new LogRecordListener();
reader.setListeners(recordListener);
reader.next();
assertTrue(recordListener.invoked());
}
}
@Test
public void testReset() throws Exception {
@DisplayName("Test Reset")
void testReset() throws Exception {
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
List<List<Writable>> records = new ArrayList<>();
records.add(reader.next());
reader.reset();
records.add(reader.next());
assertEquals(2, records.size());
assertEquals(new Text("Bolivian Dark"), records.get(0).get(0));
assertEquals(new Text("Bolivian Dark"), records.get(1).get(0));
}
}
@Test(expected = IllegalStateException.class)
public void testLackingDataSourceShouldFail() throws Exception {
try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee")) {
reader.initialize(null);
}
@Test
@DisplayName("Test Lacking Data Source Should Fail")
void testLackingDataSourceShouldFail() {
assertThrows(IllegalStateException.class, () -> {
try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee")) {
reader.initialize(null);
}
});
}
@Test
public void testConfigurationDataSourceInitialization() throws Exception {
@DisplayName("Test Configuration Data Source Initialization")
void testConfigurationDataSourceInitialization() throws Exception {
try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee")) {
Configuration conf = new Configuration();
conf.set(JDBCRecordReader.JDBC_URL, "jdbc:derby:" + dbName + ";create=true");
@ -146,28 +152,33 @@ public class JDBCRecordReaderTest {
}
}
@Test(expected = IllegalArgumentException.class)
public void testInitConfigurationMissingParametersShouldFail() throws Exception {
try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee")) {
Configuration conf = new Configuration();
conf.set(JDBCRecordReader.JDBC_URL, "should fail anyway");
reader.initialize(conf, null);
}
}
@Test(expected = UnsupportedOperationException.class)
public void testRecordDataInputStreamShouldFail() throws Exception {
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
reader.record(null, null);
}
@Test
@DisplayName("Test Init Configuration Missing Parameters Should Fail")
void testInitConfigurationMissingParametersShouldFail() {
assertThrows(IllegalArgumentException.class, () -> {
try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee")) {
Configuration conf = new Configuration();
conf.set(JDBCRecordReader.JDBC_URL, "should fail anyway");
reader.initialize(conf, null);
}
});
}
@Test
public void testLoadFromMetaData() throws Exception {
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
RecordMetaDataJdbc rmd = new RecordMetaDataJdbc(new URI(conn.getMetaData().getURL()),
"SELECT * FROM Coffee WHERE ProdNum = ?", Collections.singletonList("14-001"), reader.getClass());
@DisplayName("Test Record Data Input Stream Should Fail")
void testRecordDataInputStreamShouldFail() {
assertThrows(UnsupportedOperationException.class, () -> {
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
reader.record(null, null);
}
});
}
@Test
@DisplayName("Test Load From Meta Data")
void testLoadFromMetaData() throws Exception {
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
RecordMetaDataJdbc rmd = new RecordMetaDataJdbc(new URI(conn.getMetaData().getURL()), "SELECT * FROM Coffee WHERE ProdNum = ?", Collections.singletonList("14-001"), reader.getClass());
Record res = reader.loadFromMetaData(rmd);
assertNotNull(res);
assertEquals(new Text("Bolivian Dark"), res.getRecord().get(0));
@ -177,7 +188,8 @@ public class JDBCRecordReaderTest {
}
@Test
public void testNextRecord() throws Exception {
@DisplayName("Test Next Record")
void testNextRecord() throws Exception {
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
Record r = reader.nextRecord();
List<Writable> fields = r.getRecord();
@ -193,7 +205,8 @@ public class JDBCRecordReaderTest {
}
@Test
public void testNextRecordAndRecover() throws Exception {
@DisplayName("Test Next Record And Recover")
void testNextRecordAndRecover() throws Exception {
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
Record r = reader.nextRecord();
List<Writable> fields = r.getRecord();
@ -208,69 +221,91 @@ public class JDBCRecordReaderTest {
}
// Resetting the record reader when initialized as forward only should fail
@Test(expected = RuntimeException.class)
public void testResetForwardOnlyShouldFail() throws Exception {
try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee", dataSource)) {
Configuration conf = new Configuration();
conf.setInt(JDBCRecordReader.JDBC_RESULTSET_TYPE, ResultSet.TYPE_FORWARD_ONLY);
reader.initialize(conf, null);
reader.next();
reader.reset();
}
@Test
@DisplayName("Test Reset Forward Only Should Fail")
void testResetForwardOnlyShouldFail() {
assertThrows(RuntimeException.class, () -> {
try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM Coffee", dataSource)) {
Configuration conf = new Configuration();
conf.setInt(JDBCRecordReader.JDBC_RESULTSET_TYPE, ResultSet.TYPE_FORWARD_ONLY);
reader.initialize(conf, null);
reader.next();
reader.reset();
}
});
}
@Test
public void testReadAllTypes() throws Exception {
@DisplayName("Test Read All Types")
void testReadAllTypes() throws Exception {
TestDb.buildAllTypesTable(conn);
try (JDBCRecordReader reader = new JDBCRecordReader("SELECT * FROM AllTypes", dataSource)) {
reader.initialize(null);
List<Writable> item = reader.next();
assertEquals(item.size(), 15);
assertEquals(BooleanWritable.class, item.get(0).getClass()); // boolean to boolean
assertEquals(Text.class, item.get(1).getClass()); // date to text
assertEquals(Text.class, item.get(2).getClass()); // time to text
assertEquals(Text.class, item.get(3).getClass()); // timestamp to text
assertEquals(Text.class, item.get(4).getClass()); // char to text
assertEquals(Text.class, item.get(5).getClass()); // long varchar to text
assertEquals(Text.class, item.get(6).getClass()); // varchar to text
assertEquals(DoubleWritable.class,
item.get(7).getClass()); // float to double (derby's float is an alias of double by default)
assertEquals(FloatWritable.class, item.get(8).getClass()); // real to float
assertEquals(DoubleWritable.class, item.get(9).getClass()); // decimal to double
assertEquals(DoubleWritable.class, item.get(10).getClass()); // numeric to double
assertEquals(DoubleWritable.class, item.get(11).getClass()); // double to double
assertEquals(IntWritable.class, item.get(12).getClass()); // integer to integer
assertEquals(IntWritable.class, item.get(13).getClass()); // small int to integer
assertEquals(LongWritable.class, item.get(14).getClass()); // bigint to long
// boolean to boolean
assertEquals(BooleanWritable.class, item.get(0).getClass());
// date to text
assertEquals(Text.class, item.get(1).getClass());
// time to text
assertEquals(Text.class, item.get(2).getClass());
// timestamp to text
assertEquals(Text.class, item.get(3).getClass());
// char to text
assertEquals(Text.class, item.get(4).getClass());
// long varchar to text
assertEquals(Text.class, item.get(5).getClass());
// varchar to text
assertEquals(Text.class, item.get(6).getClass());
assertEquals(DoubleWritable.class, // float to double (derby's float is an alias of double by default)
item.get(7).getClass());
// real to float
assertEquals(FloatWritable.class, item.get(8).getClass());
// decimal to double
assertEquals(DoubleWritable.class, item.get(9).getClass());
// numeric to double
assertEquals(DoubleWritable.class, item.get(10).getClass());
// double to double
assertEquals(DoubleWritable.class, item.get(11).getClass());
// integer to integer
assertEquals(IntWritable.class, item.get(12).getClass());
// small int to integer
assertEquals(IntWritable.class, item.get(13).getClass());
// bigint to long
assertEquals(LongWritable.class, item.get(14).getClass());
}
}
@Test(expected = RuntimeException.class)
public void testNextNoMoreShouldFail() throws Exception {
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
while (reader.hasNext()) {
@Test
@DisplayName("Test Next No More Should Fail")
void testNextNoMoreShouldFail() {
assertThrows(RuntimeException.class, () -> {
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
while (reader.hasNext()) {
reader.next();
}
reader.next();
}
reader.next();
}
});
}
@Test(expected = IllegalArgumentException.class)
public void testInvalidMetadataShouldFail() throws Exception {
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
RecordMetaDataLine md = new RecordMetaDataLine(1, new URI("file://test"), JDBCRecordReader.class);
reader.loadFromMetaData(md);
}
@Test
@DisplayName("Test Invalid Metadata Should Fail")
void testInvalidMetadataShouldFail() {
assertThrows(IllegalArgumentException.class, () -> {
try (JDBCRecordReader reader = getInitializedReader("SELECT * FROM Coffee")) {
RecordMetaDataLine md = new RecordMetaDataLine(1, new URI("file://test"), JDBCRecordReader.class);
reader.loadFromMetaData(md);
}
});
}
private JDBCRecordReader getInitializedReader(String query) throws Exception {
int[] indices = {1}; // ProdNum column
JDBCRecordReader reader = new JDBCRecordReader(query, dataSource, "SELECT * FROM Coffee WHERE ProdNum = ?",
indices);
// ProdNum column
int[] indices = { 1 };
JDBCRecordReader reader = new JDBCRecordReader(query, dataSource, "SELECT * FROM Coffee WHERE ProdNum = ?", indices);
reader.setTrimStrings(true);
reader.initialize(null);
return reader;
}
}
}

View File

@ -61,25 +61,18 @@
<artifactId>nd4j-common</artifactId>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-geo</artifactId>
<groupId>org.nd4j</groupId>
<artifactId>python4j-numpy</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-python</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
</profile>
</profiles>
</project>

View File

@ -36,14 +36,14 @@ import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.joda.time.DateTimeZone;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.io.ClassPathResource;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class LocalTransformProcessRecordReaderTests {

View File

@ -29,9 +29,9 @@ import org.datavec.api.transform.schema.Schema;
import org.datavec.api.util.ndarray.RecordConverter;
import org.datavec.api.writable.Writable;
import org.datavec.local.transforms.AnalyzeLocal;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.io.ClassPathResource;
@ -39,12 +39,11 @@ import org.nd4j.common.io.ClassPathResource;
import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestAnalyzeLocal {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test
public void testAnalysisBasic() throws Exception {
@ -72,7 +71,7 @@ public class TestAnalyzeLocal {
INDArray mean = arr.mean(0);
INDArray std = arr.std(0);
for( int i=0; i<5; i++ ){
for( int i = 0; i < 5; i++) {
double m = ((NumericalColumnAnalysis)da.getColumnAnalysis().get(i)).getMean();
double stddev = ((NumericalColumnAnalysis)da.getColumnAnalysis().get(i)).getSampleStdev();
assertEquals(mean.getDouble(i), m, 1e-3);

View File

@ -27,7 +27,7 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.common.io.ClassPathResource;
import java.io.File;
@ -36,8 +36,8 @@ import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestLineRecordReaderFunction {

View File

@ -25,7 +25,7 @@ import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.datavec.local.transforms.misc.NDArrayToWritablesFunction;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -33,7 +33,7 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestNDArrayToWritablesFunction {

View File

@ -25,7 +25,7 @@ import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.datavec.local.transforms.misc.WritablesToNDArrayFunction;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@ -33,7 +33,7 @@ import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestWritablesToNDArrayFunction {

View File

@ -30,12 +30,12 @@ import org.datavec.api.writable.Writable;
import org.datavec.local.transforms.misc.SequenceWritablesToStringFunction;
import org.datavec.local.transforms.misc.WritablesToStringFunction;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestWritablesToStringFunctions {

View File

@ -17,10 +17,8 @@
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.local.transforms.transform;
import org.datavec.api.transform.MathFunction;
import org.datavec.api.transform.MathOp;
import org.datavec.api.transform.ReduceOp;
@ -31,108 +29,85 @@ import org.datavec.api.transform.reduce.Reducer;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.schema.SequenceSchema;
import org.datavec.api.writable.*;
import org.datavec.python.PythonTransform;
import org.datavec.local.transforms.LocalTransformExecutor;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import java.util.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.DisplayName;
import static java.time.Duration.ofMillis;
import static org.junit.jupiter.api.Assertions.assertTimeout;
import static org.junit.Assert.assertEquals;
public class ExecutionTest {
@DisplayName("Execution Test")
class ExecutionTest {
@Test
public void testExecutionNdarray() {
Schema schema = new Schema.Builder()
.addColumnNDArray("first",new long[]{1,32577})
.addColumnNDArray("second",new long[]{1,32577}).build();
TransformProcess transformProcess = new TransformProcess.Builder(schema)
.ndArrayMathFunctionTransform("first", MathFunction.SIN)
.ndArrayMathFunctionTransform("second",MathFunction.COS)
.build();
@DisplayName("Test Execution Ndarray")
void testExecutionNdarray() {
Schema schema = new Schema.Builder().addColumnNDArray("first", new long[] { 1, 32577 }).addColumnNDArray("second", new long[] { 1, 32577 }).build();
TransformProcess transformProcess = new TransformProcess.Builder(schema).ndArrayMathFunctionTransform("first", MathFunction.SIN).ndArrayMathFunctionTransform("second", MathFunction.COS).build();
List<List<Writable>> functions = new ArrayList<>();
List<Writable> firstRow = new ArrayList<>();
INDArray firstArr = Nd4j.linspace(1,4,4);
INDArray secondArr = Nd4j.linspace(1,4,4);
INDArray firstArr = Nd4j.linspace(1, 4, 4);
INDArray secondArr = Nd4j.linspace(1, 4, 4);
firstRow.add(new NDArrayWritable(firstArr));
firstRow.add(new NDArrayWritable(secondArr));
functions.add(firstRow);
List<List<Writable>> execute = LocalTransformExecutor.execute(functions, transformProcess);
INDArray firstResult = ((NDArrayWritable) execute.get(0).get(0)).get();
INDArray secondResult = ((NDArrayWritable) execute.get(0).get(1)).get();
INDArray expected = Transforms.sin(firstArr);
INDArray secondExpected = Transforms.cos(secondArr);
assertEquals(expected,firstResult);
assertEquals(secondExpected,secondResult);
assertEquals(expected, firstResult);
assertEquals(secondExpected, secondResult);
}
@Test
public void testExecutionSimple() {
Schema schema = new Schema.Builder().addColumnInteger("col0")
.addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").
addColumnFloat("col3").build();
TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1")
.doubleMathOp("col2", MathOp.Add, 10.0).floatMathOp("col3", MathOp.Add, 5f).build();
@DisplayName("Test Execution Simple")
void testExecutionSimple() {
Schema schema = new Schema.Builder().addColumnInteger("col0").addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").addColumnFloat("col3").build();
TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1").doubleMathOp("col2", MathOp.Add, 10.0).floatMathOp("col3", MathOp.Add, 5f).build();
List<List<Writable>> inputData = new ArrayList<>();
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1), new FloatWritable(0.3f)));
inputData.add(Arrays.<Writable>asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1), new FloatWritable(1.7f)));
inputData.add(Arrays.<Writable>asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1), new FloatWritable(3.6f)));
List<List<Writable>> rdd = (inputData);
List<List<Writable>> out = new ArrayList<>(LocalTransformExecutor.execute(rdd, tp));
Collections.sort(out, new Comparator<List<Writable>>() {
@Override
public int compare(List<Writable> o1, List<Writable> o2) {
return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
}
});
List<List<Writable>> expected = new ArrayList<>();
expected.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1), new FloatWritable(5.3f)));
expected.add(Arrays.<Writable>asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1), new FloatWritable(6.7f)));
expected.add(Arrays.<Writable>asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1), new FloatWritable(8.6f)));
assertEquals(expected, out);
}
@Test
public void testFilter() {
Schema filterSchema = new Schema.Builder()
.addColumnDouble("col1").addColumnDouble("col2")
.addColumnDouble("col3").build();
@DisplayName("Test Filter")
void testFilter() {
Schema filterSchema = new Schema.Builder().addColumnDouble("col1").addColumnDouble("col2").addColumnDouble("col3").build();
List<List<Writable>> inputData = new ArrayList<>();
inputData.add(Arrays.<Writable>asList(new IntWritable(0), new DoubleWritable(1), new DoubleWritable(0.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(1), new DoubleWritable(3), new DoubleWritable(1.1)));
inputData.add(Arrays.<Writable>asList(new IntWritable(2), new DoubleWritable(3), new DoubleWritable(2.1)));
TransformProcess transformProcess = new TransformProcess.Builder(filterSchema)
.filter(new DoubleColumnCondition("col1",ConditionOp.LessThan,1)).build();
TransformProcess transformProcess = new TransformProcess.Builder(filterSchema).filter(new DoubleColumnCondition("col1", ConditionOp.LessThan, 1)).build();
List<List<Writable>> execute = LocalTransformExecutor.execute(inputData, transformProcess);
assertEquals(2,execute.size());
assertEquals(2, execute.size());
}
@Test
public void testExecutionSequence() {
Schema schema = new SequenceSchema.Builder().addColumnInteger("col0")
.addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").build();
TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1")
.doubleMathOp("col2", MathOp.Add, 10.0).build();
@DisplayName("Test Execution Sequence")
void testExecutionSequence() {
Schema schema = new SequenceSchema.Builder().addColumnInteger("col0").addColumnCategorical("col1", "state0", "state1", "state2").addColumnDouble("col2").build();
TransformProcess tp = new TransformProcess.Builder(schema).categoricalToInteger("col1").doubleMathOp("col2", MathOp.Add, 10.0).build();
List<List<List<Writable>>> inputSequences = new ArrayList<>();
List<List<Writable>> seq1 = new ArrayList<>();
seq1.add(Arrays.<Writable>asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
@ -141,21 +116,17 @@ public class ExecutionTest {
List<List<Writable>> seq2 = new ArrayList<>();
seq2.add(Arrays.<Writable>asList(new IntWritable(3), new Text("state0"), new DoubleWritable(3.1)));
seq2.add(Arrays.<Writable>asList(new IntWritable(4), new Text("state1"), new DoubleWritable(4.1)));
inputSequences.add(seq1);
inputSequences.add(seq2);
List<List<List<Writable>>> rdd = (inputSequences);
List<List<List<Writable>>> rdd = (inputSequences);
List<List<List<Writable>>> out = LocalTransformExecutor.executeSequenceToSequence(rdd, tp);
Collections.sort(out, new Comparator<List<List<Writable>>>() {
@Override
public int compare(List<List<Writable>> o1, List<List<Writable>> o2) {
return -Integer.compare(o1.size(), o2.size());
}
});
List<List<List<Writable>>> expectedSequence = new ArrayList<>();
List<List<Writable>> seq1e = new ArrayList<>();
seq1e.add(Arrays.<Writable>asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1)));
@ -164,121 +135,37 @@ public class ExecutionTest {
List<List<Writable>> seq2e = new ArrayList<>();
seq2e.add(Arrays.<Writable>asList(new IntWritable(3), new IntWritable(0), new DoubleWritable(13.1)));
seq2e.add(Arrays.<Writable>asList(new IntWritable(4), new IntWritable(1), new DoubleWritable(14.1)));
expectedSequence.add(seq1e);
expectedSequence.add(seq2e);
assertEquals(expectedSequence, out);
}
@Test
public void testReductionGlobal() {
List<List<Writable>> in = Arrays.asList(
Arrays.<Writable>asList(new Text("first"), new DoubleWritable(3.0)),
Arrays.<Writable>asList(new Text("second"), new DoubleWritable(5.0))
);
@DisplayName("Test Reduction Global")
void testReductionGlobal() {
List<List<Writable>> in = Arrays.asList(Arrays.<Writable>asList(new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new Text("second"), new DoubleWritable(5.0)));
List<List<Writable>> inData = in;
Schema s = new Schema.Builder()
.addColumnString("textCol")
.addColumnDouble("doubleCol")
.build();
TransformProcess tp = new TransformProcess.Builder(s)
.reduce(new Reducer.Builder(ReduceOp.TakeFirst)
.takeFirstColumns("textCol")
.meanColumns("doubleCol").build())
.build();
Schema s = new Schema.Builder().addColumnString("textCol").addColumnDouble("doubleCol").build();
TransformProcess tp = new TransformProcess.Builder(s).reduce(new Reducer.Builder(ReduceOp.TakeFirst).takeFirstColumns("textCol").meanColumns("doubleCol").build()).build();
List<List<Writable>> outRdd = LocalTransformExecutor.execute(inData, tp);
List<List<Writable>> out = outRdd;
List<List<Writable>> expOut = Collections.singletonList(Arrays.<Writable>asList(new Text("first"), new DoubleWritable(4.0)));
assertEquals(expOut, out);
}
@Test
public void testReductionByKey(){
List<List<Writable>> in = Arrays.asList(
Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(3.0)),
Arrays.<Writable>asList(new IntWritable(0), new Text("second"), new DoubleWritable(5.0)),
Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(30.0)),
Arrays.<Writable>asList(new IntWritable(1), new Text("s"), new DoubleWritable(50.0))
);
@DisplayName("Test Reduction By Key")
void testReductionByKey() {
List<List<Writable>> in = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(3.0)), Arrays.<Writable>asList(new IntWritable(0), new Text("second"), new DoubleWritable(5.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(30.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("s"), new DoubleWritable(50.0)));
List<List<Writable>> inData = in;
Schema s = new Schema.Builder()
.addColumnInteger("intCol")
.addColumnString("textCol")
.addColumnDouble("doubleCol")
.build();
TransformProcess tp = new TransformProcess.Builder(s)
.reduce(new Reducer.Builder(ReduceOp.TakeFirst)
.keyColumns("intCol")
.takeFirstColumns("textCol")
.meanColumns("doubleCol").build())
.build();
Schema s = new Schema.Builder().addColumnInteger("intCol").addColumnString("textCol").addColumnDouble("doubleCol").build();
TransformProcess tp = new TransformProcess.Builder(s).reduce(new Reducer.Builder(ReduceOp.TakeFirst).keyColumns("intCol").takeFirstColumns("textCol").meanColumns("doubleCol").build()).build();
List<List<Writable>> outRdd = LocalTransformExecutor.execute(inData, tp);
List<List<Writable>> out = outRdd;
List<List<Writable>> expOut = Arrays.asList(
Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)),
Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0)));
List<List<Writable>> expOut = Arrays.asList(Arrays.<Writable>asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)), Arrays.<Writable>asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0)));
out = new ArrayList<>(out);
Collections.sort(
out, new Comparator<List<Writable>>() {
@Override
public int compare(List<Writable> o1, List<Writable> o2) {
return Integer.compare(o1.get(0).toInt(), o2.get(0).toInt());
}
}
);
Collections.sort(out, Comparator.comparingInt(o -> o.get(0).toInt()));
assertEquals(expOut, out);
}
@Test(timeout = 60000L)
@Ignore("AB 2019/05/21 - Fine locally, timeouts on CI - Issue #7657 and #7771")
public void testPythonExecutionNdarray()throws Exception{
Schema schema = new Schema.Builder()
.addColumnNDArray("first",new long[]{1,32577})
.addColumnNDArray("second",new long[]{1,32577}).build();
TransformProcess transformProcess = new TransformProcess.Builder(schema)
.transform(
PythonTransform.builder().code(
"first = np.sin(first)\nsecond = np.cos(second)")
.outputSchema(schema).build())
.build();
List<List<Writable>> functions = new ArrayList<>();
List<Writable> firstRow = new ArrayList<>();
INDArray firstArr = Nd4j.linspace(1,4,4);
INDArray secondArr = Nd4j.linspace(1,4,4);
firstRow.add(new NDArrayWritable(firstArr));
firstRow.add(new NDArrayWritable(secondArr));
functions.add(firstRow);
List<List<Writable>> execute = LocalTransformExecutor.execute(functions, transformProcess);
INDArray firstResult = ((NDArrayWritable) execute.get(0).get(0)).get();
INDArray secondResult = ((NDArrayWritable) execute.get(0).get(1)).get();
INDArray expected = Transforms.sin(firstArr);
INDArray secondExpected = Transforms.cos(secondArr);
assertEquals(expected,firstResult);
assertEquals(secondExpected,secondResult);
}
}

View File

@ -1,154 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.local.transforms.transform;
import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.Transform;
import org.datavec.api.transform.geo.LocationType;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.transform.geo.CoordinatesDistanceTransform;
import org.datavec.api.transform.transform.geo.IPAddressToCoordinatesTransform;
import org.datavec.api.transform.transform.geo.IPAddressToLocationTransform;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import org.nd4j.common.io.ClassPathResource;
import java.io.*;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static org.junit.Assert.assertEquals;
/**
* @author saudet
*/
public class TestGeoTransforms {
@BeforeClass
public static void beforeClass() throws Exception {
//Use test resources version to avoid tests suddenly failing due to IP/Location DB content changing
File f = new ClassPathResource("datavec-geo/GeoIP2-City-Test.mmdb").getFile();
System.setProperty(IPAddressToLocationTransform.GEOIP_FILE_PROPERTY, f.getPath());
}
@AfterClass
public static void afterClass(){
System.setProperty(IPAddressToLocationTransform.GEOIP_FILE_PROPERTY, "");
}
@Test
public void testCoordinatesDistanceTransform() throws Exception {
Schema schema = new Schema.Builder().addColumnString("point").addColumnString("mean").addColumnString("stddev")
.build();
Transform transform = new CoordinatesDistanceTransform("dist", "point", "mean", "stddev", "\\|");
transform.setInputSchema(schema);
Schema out = transform.transform(schema);
assertEquals(4, out.numColumns());
assertEquals(Arrays.asList("point", "mean", "stddev", "dist"), out.getColumnNames());
assertEquals(Arrays.asList(ColumnType.String, ColumnType.String, ColumnType.String, ColumnType.Double),
out.getColumnTypes());
assertEquals(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10"), new DoubleWritable(5.0)),
transform.map(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10"))));
assertEquals(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"), new Text("10|5"),
new DoubleWritable(Math.sqrt(160))),
transform.map(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"),
new Text("10|5"))));
}
@Test
public void testIPAddressToCoordinatesTransform() throws Exception {
Schema schema = new Schema.Builder().addColumnString("column").build();
Transform transform = new IPAddressToCoordinatesTransform("column", "CUSTOM_DELIMITER");
transform.setInputSchema(schema);
Schema out = transform.transform(schema);
assertEquals(1, out.getColumnMetaData().size());
assertEquals(ColumnType.String, out.getMetaData(0).getColumnType());
String in = "81.2.69.160";
double latitude = 51.5142;
double longitude = -0.0931;
List<Writable> writables = transform.map(Collections.singletonList((Writable) new Text(in)));
assertEquals(1, writables.size());
String[] coordinates = writables.get(0).toString().split("CUSTOM_DELIMITER");
assertEquals(2, coordinates.length);
assertEquals(latitude, Double.parseDouble(coordinates[0]), 0.1);
assertEquals(longitude, Double.parseDouble(coordinates[1]), 0.1);
//Check serialization: things like DatabaseReader etc aren't serializable, hence we need custom serialization :/
ByteArrayOutputStream baos = new ByteArrayOutputStream();
ObjectOutputStream oos = new ObjectOutputStream(baos);
oos.writeObject(transform);
byte[] bytes = baos.toByteArray();
ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
ObjectInputStream ois = new ObjectInputStream(bais);
Transform deserialized = (Transform) ois.readObject();
writables = deserialized.map(Collections.singletonList((Writable) new Text(in)));
assertEquals(1, writables.size());
coordinates = writables.get(0).toString().split("CUSTOM_DELIMITER");
//System.out.println(Arrays.toString(coordinates));
assertEquals(2, coordinates.length);
assertEquals(latitude, Double.parseDouble(coordinates[0]), 0.1);
assertEquals(longitude, Double.parseDouble(coordinates[1]), 0.1);
}
@Test
public void testIPAddressToLocationTransform() throws Exception {
Schema schema = new Schema.Builder().addColumnString("column").build();
LocationType[] locationTypes = LocationType.values();
String in = "81.2.69.160";
String[] locations = {"London", "2643743", "Europe", "6255148", "United Kingdom", "2635167",
"51.5142:-0.0931", "", "England", "6269131"}; //Note: no postcode in this test DB for this record
for (int i = 0; i < locationTypes.length; i++) {
LocationType locationType = locationTypes[i];
String location = locations[i];
Transform transform = new IPAddressToLocationTransform("column", locationType);
transform.setInputSchema(schema);
Schema out = transform.transform(schema);
assertEquals(1, out.getColumnMetaData().size());
assertEquals(ColumnType.String, out.getMetaData(0).getColumnType());
List<Writable> writables = transform.map(Collections.singletonList((Writable) new Text(in)));
assertEquals(1, writables.size());
assertEquals(location, writables.get(0).toString());
//System.out.println(location);
}
}
}

View File

@ -1,379 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.local.transforms.transform;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.condition.Condition;
import org.datavec.api.transform.filter.ConditionFilter;
import org.datavec.api.transform.filter.Filter;
import org.datavec.api.transform.schema.Schema;
import org.datavec.local.transforms.LocalTransformExecutor;
import org.datavec.api.writable.*;
import org.datavec.python.PythonCondition;
import org.datavec.python.PythonTransform;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import javax.annotation.concurrent.NotThreadSafe;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static junit.framework.TestCase.assertTrue;
import static org.datavec.api.transform.schema.Schema.Builder;
import static org.junit.Assert.*;
@NotThreadSafe
public class TestPythonTransformProcess {
@Test()
public void testStringConcat() throws Exception{
Builder schemaBuilder = new Builder();
schemaBuilder
.addColumnString("col1")
.addColumnString("col2");
Schema initialSchema = schemaBuilder.build();
schemaBuilder.addColumnString("col3");
Schema finalSchema = schemaBuilder.build();
String pythonCode = "col3 = col1 + col2";
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
PythonTransform.builder().code(pythonCode)
.outputSchema(finalSchema)
.build()
).build();
List<Writable> inputs = Arrays.asList((Writable)new Text("Hello "), new Text("World!"));
List<Writable> outputs = tp.execute(inputs);
assertEquals((outputs.get(0)).toString(), "Hello ");
assertEquals((outputs.get(1)).toString(), "World!");
assertEquals((outputs.get(2)).toString(), "Hello World!");
}
@Test(timeout = 60000L)
public void testMixedTypes() throws Exception{
Builder schemaBuilder = new Builder();
schemaBuilder
.addColumnInteger("col1")
.addColumnFloat("col2")
.addColumnString("col3")
.addColumnDouble("col4");
Schema initialSchema = schemaBuilder.build();
schemaBuilder.addColumnInteger("col5");
Schema finalSchema = schemaBuilder.build();
String pythonCode = "col5 = (int(col3) + col1 + int(col2)) * int(col4)";
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
PythonTransform.builder().code(pythonCode)
.outputSchema(finalSchema)
.inputSchema(initialSchema)
.build() ).build();
List<Writable> inputs = Arrays.asList((Writable)new IntWritable(10),
new FloatWritable(3.5f),
new Text("5"),
new DoubleWritable(2.0)
);
List<Writable> outputs = tp.execute(inputs);
assertEquals(((LongWritable)outputs.get(4)).get(), 36);
}
@Test(timeout = 60000L)
public void testNDArray() throws Exception{
long[] shape = new long[]{3, 2};
INDArray arr1 = Nd4j.rand(shape);
INDArray arr2 = Nd4j.rand(shape);
INDArray expectedOutput = arr1.add(arr2);
Builder schemaBuilder = new Builder();
schemaBuilder
.addColumnNDArray("col1", shape)
.addColumnNDArray("col2", shape);
Schema initialSchema = schemaBuilder.build();
schemaBuilder.addColumnNDArray("col3", shape);
Schema finalSchema = schemaBuilder.build();
String pythonCode = "col3 = col1 + col2";
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
PythonTransform.builder().code(pythonCode)
.outputSchema(finalSchema)
.build() ).build();
List<Writable> inputs = Arrays.asList(
(Writable)
new NDArrayWritable(arr1),
new NDArrayWritable(arr2)
);
List<Writable> outputs = tp.execute(inputs);
assertEquals(arr1, ((NDArrayWritable)outputs.get(0)).get());
assertEquals(arr2, ((NDArrayWritable)outputs.get(1)).get());
assertEquals(expectedOutput,((NDArrayWritable)outputs.get(2)).get());
}
@Test(timeout = 60000L)
public void testNDArray2() throws Exception{
long[] shape = new long[]{3, 2};
INDArray arr1 = Nd4j.rand(shape);
INDArray arr2 = Nd4j.rand(shape);
INDArray expectedOutput = arr1.add(arr2);
Builder schemaBuilder = new Builder();
schemaBuilder
.addColumnNDArray("col1", shape)
.addColumnNDArray("col2", shape);
Schema initialSchema = schemaBuilder.build();
schemaBuilder.addColumnNDArray("col3", shape);
Schema finalSchema = schemaBuilder.build();
String pythonCode = "col3 = col1 + col2";
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
PythonTransform.builder().code(pythonCode)
.outputSchema(finalSchema)
.build() ).build();
List<Writable> inputs = Arrays.asList(
(Writable)
new NDArrayWritable(arr1),
new NDArrayWritable(arr2)
);
List<Writable> outputs = tp.execute(inputs);
assertEquals(arr1, ((NDArrayWritable)outputs.get(0)).get());
assertEquals(arr2, ((NDArrayWritable)outputs.get(1)).get());
assertEquals(expectedOutput,((NDArrayWritable)outputs.get(2)).get());
}
@Test(timeout = 60000L)
public void testNDArrayMixed() throws Exception{
long[] shape = new long[]{3, 2};
INDArray arr1 = Nd4j.rand(DataType.DOUBLE, shape);
INDArray arr2 = Nd4j.rand(DataType.DOUBLE, shape);
INDArray expectedOutput = arr1.add(arr2.castTo(DataType.DOUBLE));
Builder schemaBuilder = new Builder();
schemaBuilder
.addColumnNDArray("col1", shape)
.addColumnNDArray("col2", shape);
Schema initialSchema = schemaBuilder.build();
schemaBuilder.addColumnNDArray("col3", shape);
Schema finalSchema = schemaBuilder.build();
String pythonCode = "col3 = col1 + col2";
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
PythonTransform.builder().code(pythonCode)
.outputSchema(finalSchema)
.build()
).build();
List<Writable> inputs = Arrays.asList(
(Writable)
new NDArrayWritable(arr1),
new NDArrayWritable(arr2)
);
List<Writable> outputs = tp.execute(inputs);
assertEquals(arr1, ((NDArrayWritable)outputs.get(0)).get());
assertEquals(arr2, ((NDArrayWritable)outputs.get(1)).get());
assertEquals(expectedOutput,((NDArrayWritable)outputs.get(2)).get());
}
@Test(timeout = 60000L)
public void testPythonFilter() {
Schema schema = new Builder().addColumnInteger("column").build();
Condition condition = new PythonCondition(
"f = lambda: column < 0"
);
condition.setInputSchema(schema);
Filter filter = new ConditionFilter(condition);
assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(10))));
assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(1))));
assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(0))));
assertTrue(filter.removeExample(Collections.singletonList(new IntWritable(-1))));
assertTrue(filter.removeExample(Collections.singletonList(new IntWritable(-10))));
}
@Test(timeout = 60000L)
public void testPythonFilterAndTransform() throws Exception{
Builder schemaBuilder = new Builder();
schemaBuilder
.addColumnInteger("col1")
.addColumnFloat("col2")
.addColumnString("col3")
.addColumnDouble("col4");
Schema initialSchema = schemaBuilder.build();
schemaBuilder.addColumnString("col6");
Schema finalSchema = schemaBuilder.build();
Condition condition = new PythonCondition(
"f = lambda: col1 < 0 and col2 > 10.0"
);
condition.setInputSchema(initialSchema);
Filter filter = new ConditionFilter(condition);
String pythonCode = "col6 = str(col1 + col2)";
TransformProcess tp = new TransformProcess.Builder(initialSchema).transform(
PythonTransform.builder().code(pythonCode)
.outputSchema(finalSchema)
.build()
).filter(
filter
).build();
List<List<Writable>> inputs = new ArrayList<>();
inputs.add(
Arrays.asList(
(Writable)
new IntWritable(5),
new FloatWritable(3.0f),
new Text("abcd"),
new DoubleWritable(2.1))
);
inputs.add(
Arrays.asList(
(Writable)
new IntWritable(-3),
new FloatWritable(3.0f),
new Text("abcd"),
new DoubleWritable(2.1))
);
inputs.add(
Arrays.asList(
(Writable)
new IntWritable(5),
new FloatWritable(11.2f),
new Text("abcd"),
new DoubleWritable(2.1))
);
LocalTransformExecutor.execute(inputs,tp);
}
@Test
public void testPythonTransformNoOutputSpecified() throws Exception {
PythonTransform pythonTransform = PythonTransform.builder()
.code("a += 2; b = 'hello world'")
.returnAllInputs(true)
.build();
List<List<Writable>> inputs = new ArrayList<>();
inputs.add(Arrays.asList((Writable)new IntWritable(1)));
Schema inputSchema = new Builder()
.addColumnInteger("a")
.build();
TransformProcess tp = new TransformProcess.Builder(inputSchema)
.transform(pythonTransform)
.build();
List<List<Writable>> execute = LocalTransformExecutor.execute(inputs, tp);
assertEquals(3,execute.get(0).get(0).toInt());
assertEquals("hello world",execute.get(0).get(1).toString());
}
@Test
public void testNumpyTransform() {
PythonTransform pythonTransform = PythonTransform.builder()
.code("a += 2; b = 'hello world'")
.returnAllInputs(true)
.build();
List<List<Writable>> inputs = new ArrayList<>();
inputs.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.scalar(1).reshape(1,1))));
Schema inputSchema = new Builder()
.addColumnNDArray("a",new long[]{1,1})
.build();
TransformProcess tp = new TransformProcess.Builder(inputSchema)
.transform(pythonTransform)
.build();
List<List<Writable>> execute = LocalTransformExecutor.execute(inputs, tp);
assertFalse(execute.isEmpty());
assertNotNull(execute.get(0));
assertNotNull(execute.get(0).get(0));
assertNotNull(execute.get(0).get(1));
assertEquals(Nd4j.scalar(3).reshape(1, 1),((NDArrayWritable)execute.get(0).get(0)).get());
assertEquals("hello world",execute.get(0).get(1).toString());
}
@Test
public void testWithSetupRun() throws Exception {
PythonTransform pythonTransform = PythonTransform.builder()
.code("five=None\n" +
"def setup():\n" +
" global five\n"+
" five = 5\n\n" +
"def run(a, b):\n" +
" c = a + b + five\n"+
" return {'c':c}\n\n")
.returnAllInputs(true)
.setupAndRun(true)
.build();
List<List<Writable>> inputs = new ArrayList<>();
inputs.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.scalar(1).reshape(1,1)),
new NDArrayWritable(Nd4j.scalar(2).reshape(1,1))));
Schema inputSchema = new Builder()
.addColumnNDArray("a",new long[]{1,1})
.addColumnNDArray("b", new long[]{1, 1})
.build();
TransformProcess tp = new TransformProcess.Builder(inputSchema)
.transform(pythonTransform)
.build();
List<List<Writable>> execute = LocalTransformExecutor.execute(inputs, tp);
assertFalse(execute.isEmpty());
assertNotNull(execute.get(0));
assertNotNull(execute.get(0).get(0));
assertEquals(Nd4j.scalar(8).reshape(1, 1),((NDArrayWritable)execute.get(0).get(3)).get());
}
}

View File

@ -28,11 +28,11 @@ import org.datavec.api.writable.*;
import org.datavec.local.transforms.LocalTransformExecutor;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import java.util.*;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestJoin {

View File

@ -31,13 +31,13 @@ import org.datavec.api.writable.comparator.DoubleWritableComparator;
import org.datavec.local.transforms.LocalTransformExecutor;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class TestCalculateSortedRank {

View File

@ -31,14 +31,14 @@ import org.datavec.api.writable.Writable;
import org.datavec.arrow.recordreader.ArrowWritableRecordTimeSeriesBatch;
import org.datavec.local.transforms.LocalTransformExecutor;
import org.junit.Test;
import org.junit.jupiter.api.Test;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class TestConvertToSequence {

View File

@ -41,6 +41,12 @@
</properties>
<dependencies>
<dependency>
<groupId>com.tdunning</groupId>
<artifactId>t-digest</artifactId>
<version>3.2</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-library</artifactId>
@ -122,10 +128,10 @@
<profiles>
<profile>
<id>test-nd4j-native</id>
<id>nd4j-tests-cpu</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<id>nd4j-tests-cuda</id>
</profile>
</profiles>
</project>

Some files were not shown because too many files have changed in this diff Show More